diff --git a/libp2p/network/stream/net_stream.py b/libp2p/network/stream/net_stream.py index f4e078e..cac0d48 100644 --- a/libp2p/network/stream/net_stream.py +++ b/libp2p/network/stream/net_stream.py @@ -28,12 +28,13 @@ class NetStream(INetStream): """ self.protocol_id = protocol_id - async def read(self) -> bytes: + async def read(self, n: int = -1) -> bytes: """ - read from stream - :return: bytes of input until EOF + reads from stream + :param n: number of bytes to read + :return: bytes of input """ - return await self.muxed_stream.read() + return await self.muxed_stream.read(n) async def write(self, data: bytes) -> int: """ diff --git a/libp2p/network/stream/net_stream_interface.py b/libp2p/network/stream/net_stream_interface.py index 6bf25ea..43bbc53 100644 --- a/libp2p/network/stream/net_stream_interface.py +++ b/libp2p/network/stream/net_stream_interface.py @@ -22,9 +22,10 @@ class INetStream(ABC): """ @abstractmethod - async def read(self) -> bytes: + async def read(self, n: int = -1) -> bytes: """ reads from the underlying muxed_stream + :param n: number of bytes to read :return: bytes of input """ diff --git a/libp2p/stream_muxer/abc.py b/libp2p/stream_muxer/abc.py index 0903633..5c11107 100644 --- a/libp2p/stream_muxer/abc.py +++ b/libp2p/stream_muxer/abc.py @@ -88,9 +88,10 @@ class IMuxedStream(ABC): mplex_conn: IMuxedConn @abstractmethod - async def read(self) -> bytes: + async def read(self, n: int = -1) -> bytes: """ reads from the underlying muxed_conn + :param n: number of bytes to read :return: bytes of input """ diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index 7fc1361..6df8227 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -1,4 +1,5 @@ import asyncio +from io import BytesIO from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream @@ -19,6 +20,8 @@ class MplexStream(IMuxedStream): remote_closed: bool stream_lock: asyncio.Lock + _buf: bytes + def __init__(self, stream_id: int, initiator: bool, mplex_conn: IMuxedConn) -> None: """ create new MuxedStream in muxer @@ -34,13 +37,33 @@ class MplexStream(IMuxedStream): self.local_closed = False self.remote_closed = False self.stream_lock = asyncio.Lock() + self._buf = None - async def read(self) -> bytes: + async def read(self, n: int = -1) -> bytes: """ read messages associated with stream from buffer til end of file + :param n: number of bytes to read :return: bytes of input """ - return await self.mplex_conn.read_buffer(self.stream_id) + if n == -1: + return await self.mplex_conn.read_buffer(self.stream_id) + return await self.read_bytes(n) + + async def read_bytes(self, n: int) -> bytes: + if self._buf is None: + self._buf = await self.mplex_conn.read_buffer(self.stream_id) + n_read = 0 + bytes_buf = BytesIO() + while self._buf is not None and n_read < n: + n_to_read = min(n - n_read, len(self._buf)) + bytes_buf.write(self._buf[:n_to_read]) + if n_to_read == n - n_read: + self._buf = self._buf[n_to_read:] + else: + self._buf = None + self._buf = await self.mplex_conn.read_buffer(self.stream_id) + n_read += n_to_read + return bytes_buf.getvalue() async def write(self, data: bytes) -> int: """