diff --git a/libp2p/stream_muxer/abc.py b/libp2p/stream_muxer/abc.py index 2c577c7..547e917 100644 --- a/libp2p/stream_muxer/abc.py +++ b/libp2p/stream_muxer/abc.py @@ -82,11 +82,10 @@ class IMuxedStream(ReadWriteCloser): mplex_conn: IMuxedConn @abstractmethod - async def reset(self) -> bool: + async def reset(self) -> None: """ closes both ends of the stream tells this remote side to hang up - :return: true if successful """ @abstractmethod diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index af1282e..f342978 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -173,6 +173,7 @@ class Mplex(IMuxedConn): if flag == HeaderTags.NewStream.value: if is_stream_id_seen: # `NewStream` for the same id is received twice... + # TODO: Shutdown pass await self.accept_stream(stream_id, message.decode()) elif flag in ( @@ -187,10 +188,49 @@ class Mplex(IMuxedConn): async with self.streams_lock: stream = self.streams[stream_id] await stream.incoming_data.put(message) - # elif flag in ( - # HeaderTags.CloseInitiator.value, - # HeaderTags.CloseReceiver.value - # ): + elif flag in ( + HeaderTags.CloseInitiator.value, + HeaderTags.CloseReceiver.value, + ): + if not is_stream_id_seen: + continue + stream: MplexStream + async with self.streams_lock: + stream = self.streams[stream_id] + is_local_closed: bool + async with stream.close_lock: + stream.event_remote_closed.set() + is_local_closed = stream.event_local_closed.is_set() + # If local is also closed, both sides are closed. Then, we should clean up + # this stream. + if is_local_closed: + async with self.streams_lock: + del self.streams[stream_id] + elif flag in ( + HeaderTags.ResetInitiator.value, + HeaderTags.ResetReceiver.value, + ): + if not is_stream_id_seen: + # This is *ok*. We forget the stream on reset. + continue + stream: MplexStream + async with self.streams_lock: + stream = self.streams[stream_id] + async with stream.close_lock: + if not stream.event_remote_closed.is_set(): + stream.event_reset.set() + stream.event_remote_closed.set() + if not stream.event_local_closed.is_set(): + stream.event_local_closed.close() + async with self.streams_lock: + del self.streams[stream_id] + else: + # TODO: logging + print(f"message with unknown header on stream {stream_id}") + if is_stream_id_seen: + async with self.streams_lock: + stream = self.streams[stream_id] + await stream.reset() # Force context switch await asyncio.sleep(0)