diff --git a/libp2p/stream_muxer/abc.py b/libp2p/stream_muxer/abc.py index 0600dee..6e7737e 100644 --- a/libp2p/stream_muxer/abc.py +++ b/libp2p/stream_muxer/abc.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from libp2p.peer.id import ID from libp2p.security.secure_conn_interface import ISecureConn @@ -51,20 +51,6 @@ class IMuxedConn(ABC): :return: true if successful """ - @abstractmethod - async def read_buffer(self, stream_id: StreamID) -> bytes: - """ - Read a message from stream_id's buffer, check raw connection for new messages - :param stream_id: stream id of stream to read from - :return: message read - """ - - @abstractmethod - async def read_buffer_nonblocking(self, stream_id: StreamID) -> Optional[bytes]: - """ - Read a message from `stream_id`'s buffer, non-blockingly. - """ - @abstractmethod async def open_stream(self) -> "IMuxedStream": """ @@ -73,7 +59,7 @@ class IMuxedConn(ABC): """ @abstractmethod - async def accept_stream(self, name: str) -> None: + async def accept_stream(self, stream_id: StreamID, name: str) -> None: """ accepts a muxed stream opened by the other end """ diff --git a/libp2p/stream_muxer/mplex/exceptions.py b/libp2p/stream_muxer/mplex/exceptions.py index 74a6ade..bd4ceb5 100644 --- a/libp2p/stream_muxer/mplex/exceptions.py +++ b/libp2p/stream_muxer/mplex/exceptions.py @@ -1,2 +1,13 @@ -class StreamNotFound(Exception): +from libp2p.exceptions import BaseLibp2pError + + +class MplexError(BaseLibp2pError): + pass + + +class MplexShutdown(MplexError): + pass + + +class StreamNotFound(MplexError): pass diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index cf1ec91..af1282e 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -16,7 +16,6 @@ from libp2p.utils import ( from .constants import HeaderTags from .datastructures import StreamID -from .exceptions import StreamNotFound from .mplex_stream import MplexStream MPLEX_PROTOCOL_ID = TProtocol("/mplex/6.7.0") @@ -32,10 +31,9 @@ class Mplex(IMuxedConn): # TODO: `dataIn` in go implementation. Should be size of 8. # TODO: Also, `dataIn` is closed indicating EOF in Go. We don't have similar strategies # to let the `MplexStream`s know that EOF arrived (#235). - buffers: Dict[StreamID, "asyncio.Queue[bytes]"] - stream_queue: "asyncio.Queue[StreamID]" next_channel_id: int - buffers_lock: asyncio.Lock + streams: Dict[StreamID, MplexStream] + streams_lock: asyncio.Lock shutdown: asyncio.Event _tasks: List["asyncio.Future[Any]"] @@ -65,12 +63,10 @@ class Mplex(IMuxedConn): self.peer_id = peer_id # Mapping from stream ID -> buffer of messages for that stream - self.buffers = {} - self.buffers_lock = asyncio.Lock() + self.streams = {} + self.streams_lock = asyncio.Lock() self.shutdown = asyncio.Event() - self.stream_queue = asyncio.Queue() - self._tasks = [] # Kick off reading @@ -95,29 +91,6 @@ class Mplex(IMuxedConn): """ raise NotImplementedError() - async def read_buffer(self, stream_id: StreamID) -> bytes: - """ - Read a message from buffer of the stream specified by `stream_id`, - check secured connection for new messages. - `StreamNotFound` is raised when stream `stream_id` is not found in `Mplex`. - :param stream_id: stream id of stream to read from - :return: message read - """ - if stream_id not in self.buffers: - raise StreamNotFound(f"stream {stream_id} is not found") - return await self.buffers[stream_id].get() - - async def read_buffer_nonblocking(self, stream_id: StreamID) -> Optional[bytes]: - """ - Read a message from buffer of the stream specified by `stream_id`, non-blockingly. - `StreamNotFound` is raised when stream `stream_id` is not found in `Mplex`. - """ - if stream_id not in self.buffers: - raise StreamNotFound(f"stream {stream_id} is not found") - if self.buffers[stream_id].empty(): - return None - return await self.buffers[stream_id].get() - def _get_next_channel_id(self) -> int: """ Get next available stream id @@ -127,6 +100,12 @@ class Mplex(IMuxedConn): self.next_channel_id += 1 return next_id + async def _initialize_stream(self, stream_id: StreamID, name: str) -> MplexStream: + async with self.streams_lock: + stream = MplexStream(name, stream_id, self) + self.streams[stream_id] = stream + return stream + async def open_stream(self) -> IMuxedStream: """ creates a new muxed_stream @@ -134,19 +113,18 @@ class Mplex(IMuxedConn): """ channel_id = self._get_next_channel_id() stream_id = StreamID(channel_id=channel_id, is_initiator=True) - name = str(channel_id) - stream = MplexStream(name, stream_id, self) - self.buffers[stream_id] = asyncio.Queue() # Default stream name is the `channel_id` + name = str(channel_id) + stream = await self._initialize_stream(stream_id, name) await self.send_message(HeaderTags.NewStream, name.encode(), stream_id) return stream - async def accept_stream(self, name: str) -> None: + async def accept_stream(self, stream_id: StreamID, name: str) -> None: """ accepts a muxed stream opened by the other end """ - stream_id = await self.stream_queue.get() - stream = MplexStream(name, stream_id, self) + stream = await self._initialize_stream(stream_id, name) + # Perform protocol negotiation for the stream. self._tasks.append(asyncio.ensure_future(self.generic_protocol_handler(stream))) async def send_message( @@ -185,22 +163,30 @@ class Mplex(IMuxedConn): while True: channel_id, flag, message = await self.read_message() - if channel_id is not None and flag is not None and message is not None: stream_id = StreamID(channel_id=channel_id, is_initiator=bool(flag & 1)) - if stream_id not in self.buffers: - self.buffers[stream_id] = asyncio.Queue() - await self.stream_queue.put(stream_id) - + is_stream_id_seen: bool + async with self.streams_lock: + is_stream_id_seen = stream_id in self.streams + # Other consequent stream message should wait until the stream get accepted # TODO: Handle more tags, and refactor `HeaderTags` if flag == HeaderTags.NewStream.value: - # new stream detected on connection - await self.accept_stream(message.decode()) + if is_stream_id_seen: + # `NewStream` for the same id is received twice... + pass + await self.accept_stream(stream_id, message.decode()) elif flag in ( HeaderTags.MessageInitiator.value, HeaderTags.MessageReceiver.value, ): - await self.buffers[stream_id].put(message) + if not is_stream_id_seen: + # We receive a message of the stream `stream_id` which is not accepted + # before. It is abnormal. Possibly disconnect? + # TODO: Warn and emit logs about this. + continue + async with self.streams_lock: + stream = self.streams[stream_id] + await stream.incoming_data.put(message) # elif flag in ( # HeaderTags.CloseInitiator.value, # HeaderTags.CloseReceiver.value diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index d0f0801..d257297 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -23,6 +23,8 @@ class MplexStream(IMuxedStream): close_lock: asyncio.Lock + incoming_data: "asyncio.Queue[bytes]" + event_local_closed: asyncio.Event event_remote_closed: asyncio.Event event_reset: asyncio.Event @@ -44,6 +46,7 @@ class MplexStream(IMuxedStream): self.event_remote_closed = asyncio.Event() self.event_reset = asyncio.Event() self.close_lock = asyncio.Lock() + self.incoming_data = asyncio.Queue() self._buf = bytearray() @property @@ -58,7 +61,6 @@ class MplexStream(IMuxedStream): :param n: number of bytes to read :return: bytes actually read """ - # TODO: Handle `StreamNotFound` raised in `self.mplex_conn.read_buffer`. # TODO: Add exceptions and handle/raise them in this class. if n < 0 and n != -1: raise ValueError( @@ -66,17 +68,16 @@ class MplexStream(IMuxedStream): ) # If the buffer is empty at first, blocking wait for data. if len(self._buf) == 0: - self._buf.extend(await self.mplex_conn.read_buffer(self.stream_id)) + self._buf.extend(await self.incoming_data.get()) # FIXME: If `n == -1`, we should blocking read until EOF, instead of returning when # no message is available. # If `n >= 0`, read up to `n` bytes. # Else, read until no message is available. while len(self._buf) < n or n == -1: - new_bytes = await self.mplex_conn.read_buffer_nonblocking(self.stream_id) - if new_bytes is None: - # Nothing to read in the `MplexConn` buffer + if self.incoming_data.empty(): break + new_bytes = await self.incoming_data.get() self._buf.extend(new_bytes) payload: bytearray if n == -1: @@ -122,8 +123,8 @@ class MplexStream(IMuxedStream): if _is_remote_closed: # Both sides are closed, we can safely remove the buffer from the dict. - async with self.mplex_conn.buffers_lock: - del self.mplex_conn.buffers[self.stream_id] + async with self.mplex_conn.streams_lock: + del self.mplex_conn.streams[self.stream_id] async def reset(self) -> None: """ @@ -152,8 +153,8 @@ class MplexStream(IMuxedStream): self.event_local_closed.set() self.event_remote_closed.set() - async with self.mplex_conn.buffers_lock: - del self.mplex_conn.buffers[self.stream_id] + async with self.mplex_conn.streams_lock: + del self.mplex_conn.streams[self.stream_id] # TODO deadline not in use def set_deadline(self, ttl: int) -> bool: