From 9cb6ec1c48e1db934a3ab2ca25e6ee04e6144152 Mon Sep 17 00:00:00 2001 From: mhchia Date: Sun, 11 Aug 2019 23:49:58 +0800 Subject: [PATCH] Modify the behavior of `MplexStream.read` --- libp2p/stream_muxer/mplex/mplex.py | 29 ++++++++++------ libp2p/stream_muxer/mplex/mplex_stream.py | 42 +++++++++++------------ 2 files changed, 38 insertions(+), 33 deletions(-) diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index 95336d9..ecddedd 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -1,5 +1,5 @@ import asyncio -from typing import Dict, Tuple +from typing import Dict, Tuple, Optional from multiaddr import Multiaddr @@ -10,6 +10,7 @@ from libp2p.security.secure_conn_interface import ISecureConn from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream from .constants import HeaderTags +from .exceptions import StreamNotFound from .mplex_stream import MplexStream from .utils import decode_uvarint_from_stream, encode_uvarint @@ -23,6 +24,9 @@ class Mplex(IMuxedConn): raw_conn: IRawConnection initiator: bool peer_id: ID + # TODO: `dataIn` in go implementation. Should be size of 8. + # TODO: Also, `dataIn` is closed indicating EOF in Go. We don't have similar strategies + # to let the `MplexStream`s know that EOF arrived (#235). buffers: Dict[int, "asyncio.Queue[bytes]"] stream_queue: "asyncio.Queue[int]" @@ -75,17 +79,19 @@ class Mplex(IMuxedConn): :param stream_id: stream id of stream to read from :return: message read """ - # TODO: propagate up timeout exception and catch - # TODO: pass down timeout from user and use that - if stream_id in self.buffers: - try: - data = await asyncio.wait_for(self.buffers[stream_id].get(), timeout=8) - return data - except asyncio.TimeoutError: - return None + if stream_id not in self.buffers: + raise StreamNotFound(f"stream {stream_id} is not found") + return await self.buffers[stream_id].get() - # Stream not created yet - return None + async def read_buffer_nonblocking(self, stream_id: int) -> Optional[bytes]: + """ + Read a message from `stream_id`'s buffer, non-blockingly. + """ + if stream_id not in self.buffers: + raise StreamNotFound(f"stream {stream_id} is not found") + if self.buffers[stream_id].empty(): + return None + return await self.buffers[stream_id].get() async def open_stream( self, protocol_id: str, multi_addr: Multiaddr @@ -170,6 +176,7 @@ class Mplex(IMuxedConn): :return: stream_id, flag, message contents """ + # FIXME: No timeout is used in Go implementation. # Timeout is set to a relatively small value to alleviate wait time to exit # loop in handle_incoming timeout = 0.1 diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index 6df8227..afc49ee 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -37,33 +37,31 @@ class MplexStream(IMuxedStream): self.local_closed = False self.remote_closed = False self.stream_lock = asyncio.Lock() - self._buf = None + self._buf = b"" - async def read(self, n: int = -1) -> bytes: + async def read(self, n) -> bytes: """ - read messages associated with stream from buffer til end of file + Read up to n bytes. Read possibly returns fewer than `n` bytes, + if there are not enough bytes in the Mplex buffer. :param n: number of bytes to read - :return: bytes of input + :return: bytes actually read """ - if n == -1: - return await self.mplex_conn.read_buffer(self.stream_id) - return await self.read_bytes(n) - - async def read_bytes(self, n: int) -> bytes: - if self._buf is None: + # If the buffer is empty at first, blocking wait for data. + if len(self._buf) == 0: self._buf = await self.mplex_conn.read_buffer(self.stream_id) - n_read = 0 - bytes_buf = BytesIO() - while self._buf is not None and n_read < n: - n_to_read = min(n - n_read, len(self._buf)) - bytes_buf.write(self._buf[:n_to_read]) - if n_to_read == n - n_read: - self._buf = self._buf[n_to_read:] - else: - self._buf = None - self._buf = await self.mplex_conn.read_buffer(self.stream_id) - n_read += n_to_read - return bytes_buf.getvalue() + # Here, `self._buf` should never be `None`. + if self._buf is None or len(self._buf) == 0: + raise Exception("start to `read_buffer_nonblocking` only when there are bytes read.") + + while len(self._buf) < n: + new_bytes = await self.mplex_conn.read_buffer_nonblocking(self.stream_id) + if new_bytes is None: + # Nothing to read in the `MplexConn` buffer + break + self._buf += new_bytes + payload = self._buf[:n] + self._buf = self._buf[n:] + return payload async def write(self, data: bytes) -> int: """