Modify NetStream
to read n bytes
This commit is contained in:
parent
dbdbcf7440
commit
2485a00e24
|
@ -28,12 +28,13 @@ class NetStream(INetStream):
|
|||
"""
|
||||
self.protocol_id = protocol_id
|
||||
|
||||
async def read(self) -> bytes:
|
||||
async def read(self, n: int = -1) -> bytes:
|
||||
"""
|
||||
read from stream
|
||||
:return: bytes of input until EOF
|
||||
reads from stream
|
||||
:param n: number of bytes to read
|
||||
:return: bytes of input
|
||||
"""
|
||||
return await self.muxed_stream.read()
|
||||
return await self.muxed_stream.read(n)
|
||||
|
||||
async def write(self, data: bytes) -> int:
|
||||
"""
|
||||
|
|
|
@ -22,9 +22,10 @@ class INetStream(ABC):
|
|||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def read(self) -> bytes:
|
||||
async def read(self, n: int = -1) -> bytes:
|
||||
"""
|
||||
reads from the underlying muxed_stream
|
||||
:param n: number of bytes to read
|
||||
:return: bytes of input
|
||||
"""
|
||||
|
||||
|
|
|
@ -88,9 +88,10 @@ class IMuxedStream(ABC):
|
|||
mplex_conn: IMuxedConn
|
||||
|
||||
@abstractmethod
|
||||
async def read(self) -> bytes:
|
||||
async def read(self, n: int = -1) -> bytes:
|
||||
"""
|
||||
reads from the underlying muxed_conn
|
||||
:param n: number of bytes to read
|
||||
:return: bytes of input
|
||||
"""
|
||||
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import asyncio
|
||||
from io import BytesIO
|
||||
|
||||
from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream
|
||||
|
||||
|
@ -19,6 +20,8 @@ class MplexStream(IMuxedStream):
|
|||
remote_closed: bool
|
||||
stream_lock: asyncio.Lock
|
||||
|
||||
_buf: bytes
|
||||
|
||||
def __init__(self, stream_id: int, initiator: bool, mplex_conn: IMuxedConn) -> None:
|
||||
"""
|
||||
create new MuxedStream in muxer
|
||||
|
@ -34,13 +37,33 @@ class MplexStream(IMuxedStream):
|
|||
self.local_closed = False
|
||||
self.remote_closed = False
|
||||
self.stream_lock = asyncio.Lock()
|
||||
self._buf = None
|
||||
|
||||
async def read(self) -> bytes:
|
||||
async def read(self, n: int = -1) -> bytes:
|
||||
"""
|
||||
read messages associated with stream from buffer til end of file
|
||||
:param n: number of bytes to read
|
||||
:return: bytes of input
|
||||
"""
|
||||
return await self.mplex_conn.read_buffer(self.stream_id)
|
||||
if n == -1:
|
||||
return await self.mplex_conn.read_buffer(self.stream_id)
|
||||
return await self.read_bytes(n)
|
||||
|
||||
async def read_bytes(self, n: int) -> bytes:
|
||||
if self._buf is None:
|
||||
self._buf = await self.mplex_conn.read_buffer(self.stream_id)
|
||||
n_read = 0
|
||||
bytes_buf = BytesIO()
|
||||
while self._buf is not None and n_read < n:
|
||||
n_to_read = min(n - n_read, len(self._buf))
|
||||
bytes_buf.write(self._buf[:n_to_read])
|
||||
if n_to_read == n - n_read:
|
||||
self._buf = self._buf[n_to_read:]
|
||||
else:
|
||||
self._buf = None
|
||||
self._buf = await self.mplex_conn.read_buffer(self.stream_id)
|
||||
n_read += n_to_read
|
||||
return bytes_buf.getvalue()
|
||||
|
||||
async def write(self, data: bytes) -> int:
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue
Block a user