Modify NetStream to read n bytes

This commit is contained in:
mhchia 2019-08-07 15:23:20 +08:00 committed by Kevin Mai-Husan Chia
parent dbdbcf7440
commit 2485a00e24
4 changed files with 34 additions and 8 deletions

View File

@ -28,12 +28,13 @@ class NetStream(INetStream):
""" """
self.protocol_id = protocol_id self.protocol_id = protocol_id
async def read(self) -> bytes: async def read(self, n: int = -1) -> bytes:
""" """
read from stream reads from stream
:return: bytes of input until EOF :param n: number of bytes to read
:return: bytes of input
""" """
return await self.muxed_stream.read() return await self.muxed_stream.read(n)
async def write(self, data: bytes) -> int: async def write(self, data: bytes) -> int:
""" """

View File

@ -22,9 +22,10 @@ class INetStream(ABC):
""" """
@abstractmethod @abstractmethod
async def read(self) -> bytes: async def read(self, n: int = -1) -> bytes:
""" """
reads from the underlying muxed_stream reads from the underlying muxed_stream
:param n: number of bytes to read
:return: bytes of input :return: bytes of input
""" """

View File

@ -88,9 +88,10 @@ class IMuxedStream(ABC):
mplex_conn: IMuxedConn mplex_conn: IMuxedConn
@abstractmethod @abstractmethod
async def read(self) -> bytes: async def read(self, n: int = -1) -> bytes:
""" """
reads from the underlying muxed_conn reads from the underlying muxed_conn
:param n: number of bytes to read
:return: bytes of input :return: bytes of input
""" """

View File

@ -1,4 +1,5 @@
import asyncio import asyncio
from io import BytesIO
from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream
@ -19,6 +20,8 @@ class MplexStream(IMuxedStream):
remote_closed: bool remote_closed: bool
stream_lock: asyncio.Lock stream_lock: asyncio.Lock
_buf: bytes
def __init__(self, stream_id: int, initiator: bool, mplex_conn: IMuxedConn) -> None: def __init__(self, stream_id: int, initiator: bool, mplex_conn: IMuxedConn) -> None:
""" """
create new MuxedStream in muxer create new MuxedStream in muxer
@ -34,13 +37,33 @@ class MplexStream(IMuxedStream):
self.local_closed = False self.local_closed = False
self.remote_closed = False self.remote_closed = False
self.stream_lock = asyncio.Lock() self.stream_lock = asyncio.Lock()
self._buf = None
async def read(self) -> bytes: async def read(self, n: int = -1) -> bytes:
""" """
read messages associated with stream from buffer til end of file read messages associated with stream from buffer til end of file
:param n: number of bytes to read
:return: bytes of input :return: bytes of input
""" """
return await self.mplex_conn.read_buffer(self.stream_id) if n == -1:
return await self.mplex_conn.read_buffer(self.stream_id)
return await self.read_bytes(n)
async def read_bytes(self, n: int) -> bytes:
if self._buf is None:
self._buf = await self.mplex_conn.read_buffer(self.stream_id)
n_read = 0
bytes_buf = BytesIO()
while self._buf is not None and n_read < n:
n_to_read = min(n - n_read, len(self._buf))
bytes_buf.write(self._buf[:n_to_read])
if n_to_read == n - n_read:
self._buf = self._buf[n_to_read:]
else:
self._buf = None
self._buf = await self.mplex_conn.read_buffer(self.stream_id)
n_read += n_to_read
return bytes_buf.getvalue()
async def write(self, data: bytes) -> int: async def write(self, data: bytes) -> int:
""" """