diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index 5d09b75..d85a7c2 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -6,6 +6,7 @@ from libp2p.exceptions import ParseError from libp2p.io.exceptions import IncompleteReadError from libp2p.network.typing import GenericProtocolHandlerFn from libp2p.peer.id import ID +from libp2p.protocol_muxer.exceptions import MultiselectError from libp2p.security.secure_conn_interface import ISecureConn from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream from libp2p.typing import TProtocol @@ -102,12 +103,6 @@ class Mplex(IMuxedConn): self.next_channel_id += 1 return next_id - async def _initialize_stream(self, stream_id: StreamID, name: str) -> MplexStream: - async with self.streams_lock: - stream = MplexStream(name, stream_id, self) - self.streams[stream_id] = stream - return stream - async def open_stream(self) -> IMuxedStream: """ creates a new muxed_stream @@ -117,17 +112,28 @@ class Mplex(IMuxedConn): stream_id = StreamID(channel_id=channel_id, is_initiator=True) # Default stream name is the `channel_id` name = str(channel_id) - stream = await self._initialize_stream(stream_id, name) + async with self.streams_lock: + stream = MplexStream(name, stream_id, self) await self.send_message(HeaderTags.NewStream, name.encode(), stream_id) + # TODO: is there a way to know if the peer accepted the stream? + # then we can safely register the stream + self.streams[stream_id] = stream return stream async def accept_stream(self, stream_id: StreamID, name: str) -> None: """ accepts a muxed stream opened by the other end """ - stream = await self._initialize_stream(stream_id, name) + async with self.streams_lock: + stream = MplexStream(name, stream_id, self) # Perform protocol negotiation for the stream. - self._tasks.append(asyncio.ensure_future(self.generic_protocol_handler(stream))) + try: + await self.generic_protocol_handler(stream) + except MultiselectError: + # TODO: what to do when stream protocol negotiation fail? + return + + self.streams[stream_id] = stream async def send_message( self, flag: HeaderTags, data: Optional[bytes], stream_id: StreamID @@ -180,7 +186,11 @@ class Mplex(IMuxedConn): # `NewStream` for the same id is received twice... # TODO: Shutdown pass - await self.accept_stream(stream_id, message.decode()) + self._tasks.append( + asyncio.ensure_future( + self.accept_stream(stream_id, message.decode()) + ) + ) elif flag in ( HeaderTags.MessageInitiator.value, HeaderTags.MessageReceiver.value,