From 6c1f77dc1a946733411ff0103e8f0ae06d6ddcab Mon Sep 17 00:00:00 2001 From: mhchia Date: Fri, 6 Sep 2019 21:35:15 +0800 Subject: [PATCH] Fix: Change the `event.close` to `event.set` And add missing parts. --- libp2p/pubsub/pubsub.py | 1 - libp2p/stream_muxer/mplex/mplex.py | 28 ++++++++++----------- libp2p/stream_muxer/mplex/mplex_stream.py | 4 +-- tests/libp2p/test_libp2p.py | 2 +- tests/libp2p/test_notify.py | 2 +- tests/protocol_muxer/test_protocol_muxer.py | 2 +- tests/utils.py | 1 - 7 files changed, 19 insertions(+), 21 deletions(-) diff --git a/libp2p/pubsub/pubsub.py b/libp2p/pubsub/pubsub.py index c55a183..5c0466c 100644 --- a/libp2p/pubsub/pubsub.py +++ b/libp2p/pubsub/pubsub.py @@ -3,7 +3,6 @@ import logging import time from typing import ( TYPE_CHECKING, - Any, Awaitable, Callable, Dict, diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index f342978..1e8823a 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -166,8 +166,11 @@ class Mplex(IMuxedConn): if channel_id is not None and flag is not None and message is not None: stream_id = StreamID(channel_id=channel_id, is_initiator=bool(flag & 1)) is_stream_id_seen: bool + stream: MplexStream async with self.streams_lock: is_stream_id_seen = stream_id in self.streams + if is_stream_id_seen: + stream = self.streams[stream_id] # Other consequent stream message should wait until the stream get accepted # TODO: Handle more tags, and refactor `HeaderTags` if flag == HeaderTags.NewStream.value: @@ -185,8 +188,6 @@ class Mplex(IMuxedConn): # before. It is abnormal. Possibly disconnect? # TODO: Warn and emit logs about this. continue - async with self.streams_lock: - stream = self.streams[stream_id] await stream.incoming_data.put(message) elif flag in ( HeaderTags.CloseInitiator.value, @@ -194,15 +195,17 @@ class Mplex(IMuxedConn): ): if not is_stream_id_seen: continue - stream: MplexStream - async with self.streams_lock: - stream = self.streams[stream_id] + # NOTE: If remote is already closed, then return: Technically a bug + # on the other side. We should consider killing the connection. + async with stream.close_lock: + if stream.event_remote_closed.is_set(): + continue is_local_closed: bool async with stream.close_lock: stream.event_remote_closed.set() is_local_closed = stream.event_local_closed.is_set() # If local is also closed, both sides are closed. Then, we should clean up - # this stream. + # the entry of this stream, to avoid others from accessing it. if is_local_closed: async with self.streams_lock: del self.streams[stream_id] @@ -213,24 +216,21 @@ class Mplex(IMuxedConn): if not is_stream_id_seen: # This is *ok*. We forget the stream on reset. continue - stream: MplexStream - async with self.streams_lock: - stream = self.streams[stream_id] async with stream.close_lock: if not stream.event_remote_closed.is_set(): + # TODO: Why? Only if remote is not closed before then reset. stream.event_reset.set() + stream.event_remote_closed.set() + # If local is not closed, we should close it. if not stream.event_local_closed.is_set(): - stream.event_local_closed.close() + stream.event_local_closed.set() async with self.streams_lock: del self.streams[stream_id] else: # TODO: logging - print(f"message with unknown header on stream {stream_id}") if is_stream_id_seen: - async with self.streams_lock: - stream = self.streams[stream_id] - await stream.reset() + await stream.reset() # Force context switch await asyncio.sleep(0) diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index e537dda..18c8ff0 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -4,8 +4,8 @@ from typing import TYPE_CHECKING from libp2p.stream_muxer.abc import IMuxedStream from .constants import HeaderTags -from .exceptions import MplexStreamReset, MplexStreamEOF from .datastructures import StreamID +from .exceptions import MplexStreamEOF, MplexStreamReset if TYPE_CHECKING: from libp2p.stream_muxer.mplex.mplex import Mplex @@ -55,7 +55,7 @@ class MplexStream(IMuxedStream): return self.stream_id.is_initiator async def _wait_for_data(self) -> None: - done, pending = await asyncio.wait( + done, pending = await asyncio.wait( # type: ignore [ self.event_reset.wait(), self.event_remote_closed.wait(), diff --git a/tests/libp2p/test_libp2p.py b/tests/libp2p/test_libp2p.py index b4a643d..8090f5e 100644 --- a/tests/libp2p/test_libp2p.py +++ b/tests/libp2p/test_libp2p.py @@ -2,8 +2,8 @@ import multiaddr import pytest from libp2p.peer.peerinfo import info_from_p2p_addr -from tests.utils import cleanup, set_up_nodes_by_transport_opt from tests.constants import MAX_READ_LEN +from tests.utils import cleanup, set_up_nodes_by_transport_opt @pytest.mark.asyncio diff --git a/tests/libp2p/test_notify.py b/tests/libp2p/test_notify.py index 206f3e3..e21030a 100644 --- a/tests/libp2p/test_notify.py +++ b/tests/libp2p/test_notify.py @@ -16,8 +16,8 @@ from libp2p import initialize_default_swarm, new_node from libp2p.crypto.rsa import create_new_key_pair from libp2p.host.basic_host import BasicHost from libp2p.network.notifee_interface import INotifee -from tests.utils import cleanup, perform_two_host_set_up from tests.constants import MAX_READ_LEN +from tests.utils import cleanup, perform_two_host_set_up ACK = "ack:" diff --git a/tests/protocol_muxer/test_protocol_muxer.py b/tests/protocol_muxer/test_protocol_muxer.py index 8fb1537..7830aaa 100644 --- a/tests/protocol_muxer/test_protocol_muxer.py +++ b/tests/protocol_muxer/test_protocol_muxer.py @@ -1,7 +1,7 @@ import pytest from libp2p.protocol_muxer.exceptions import MultiselectClientError -from tests.utils import cleanup, set_up_nodes_by_transport_opt, echo_stream_handler +from tests.utils import cleanup, echo_stream_handler, set_up_nodes_by_transport_opt # TODO: Add tests for multiple streams being opened on different # protocols through the same connection diff --git a/tests/utils.py b/tests/utils.py index 1f1cfc4..a26ebc5 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -5,7 +5,6 @@ import multiaddr from libp2p import new_node from libp2p.peer.peerinfo import info_from_p2p_addr - from tests.constants import MAX_READ_LEN