diff --git a/libp2p/security/noise/connection.py b/libp2p/security/noise/connection.py index 8f94a0f..29bbc8b 100644 --- a/libp2p/security/noise/connection.py +++ b/libp2p/security/noise/connection.py @@ -1,3 +1,5 @@ +import io + from noise.connection import NoiseConnection as NoiseState from libp2p.crypto.keys import PrivateKey @@ -8,6 +10,10 @@ from libp2p.security.noise.io import MsgReadWriter, NoiseTransportReadWriter class NoiseConnection(BaseSession): + buf: io.BytesIO + low_watermark: int + high_watermark: int + read_writer: IRawConnection noise_state: NoiseState @@ -19,20 +25,64 @@ class NoiseConnection(BaseSession): conn: IRawConnection, is_initiator: bool, noise_state: NoiseState, + # remote_permanent_pubkey ) -> None: super().__init__(local_peer, local_private_key, is_initiator, remote_peer) self.conn = conn self.noise_state = noise_state + self._reset_internal_buffer() def get_msg_read_writer(self) -> MsgReadWriter: return NoiseTransportReadWriter(self.conn, self.noise_state) + async def close(self) -> None: + await self.conn.close() + + def _reset_internal_buffer(self) -> None: + self.buf = io.BytesIO() + self.low_watermark = 0 + self.high_watermark = 0 + + def _drain(self, n: int) -> bytes: + if self.low_watermark == self.high_watermark: + return bytes() + + data = self.buf.getbuffer()[self.low_watermark : self.high_watermark] + + if n is None: + n = len(data) + result = data[:n].tobytes() + self.low_watermark += len(result) + + if self.low_watermark == self.high_watermark: + del data # free the memoryview so we can free the underlying BytesIO + self.buf.close() + self._reset_internal_buffer() + return result + async def read(self, n: int = None) -> bytes: - # TODO: Use a buffer to handle buffered messages. + if n == 0: + return bytes() + + data_from_buffer = self._drain(n) + if len(data_from_buffer) > 0: + return data_from_buffer + + msg = await self.read_msg() + + if n < len(msg): + self.buf.write(msg) + self.low_watermark = 0 + self.high_watermark = len(msg) + return self._drain(n) + else: + return msg + + async def read_msg(self) -> bytes: return await self.get_msg_read_writer().read_msg() async def write(self, data: bytes) -> None: - await self.get_msg_read_writer().write_msg(data) + await self.write_msg(data) - async def close(self) -> None: - await self.conn.close() + async def write_msg(self, msg: bytes) -> None: + await self.get_msg_read_writer().write_msg(msg) diff --git a/tests/security/noise/test_noise.py b/tests/security/noise/test_noise.py index 1c5eebb..f1d208c 100644 --- a/tests/security/noise/test_noise.py +++ b/tests/security/noise/test_noise.py @@ -3,7 +3,9 @@ import pytest from libp2p.security.noise.messages import NoiseHandshakePayload from libp2p.tools.factories import noise_conn_factory, noise_handshake_payload_factory -DATA = b"testing_123" +DATA_0 = b"data_0" +DATA_1 = b"1" * 1000 +DATA_2 = b"data_2" @pytest.mark.trio @@ -16,9 +18,12 @@ async def test_noise_transport(nursery): async def test_noise_connection(nursery): async with noise_conn_factory(nursery) as conns: local_conn, remote_conn = conns - await local_conn.write(DATA) - read_data = await remote_conn.read(len(DATA)) - assert read_data == DATA + await local_conn.write(DATA_0) + await local_conn.write(DATA_1) + assert DATA_0 == (await remote_conn.read(len(DATA_0))) + assert DATA_1 == (await remote_conn.read(len(DATA_1))) + await local_conn.write(DATA_2) + assert DATA_2 == (await remote_conn.read(len(DATA_2))) def test_noise_handshake_payload():