diff --git a/libp2p/network/connection/swarm_connection.py b/libp2p/network/connection/swarm_connection.py index 1e31033..a4fc8be 100644 --- a/libp2p/network/connection/swarm_connection.py +++ b/libp2p/network/connection/swarm_connection.py @@ -29,10 +29,19 @@ class SwarmConn(INetConn, Service): self.streams = set() self.event_closed = trio.Event() + @property + def is_closed(self) -> bool: + return self.event_closed.is_set() + async def close(self) -> None: if self.event_closed.is_set(): return self.event_closed.set() + await self._cleanup() + # Cancel service + await self.manager.stop() + + async def _cleanup(self) -> None: self.swarm.remove_conn(self) await self.muxed_conn.close() @@ -51,28 +60,23 @@ class SwarmConn(INetConn, Service): while self.manager.is_running: try: stream = await self.muxed_conn.accept_stream() - except MuxedConnUnavailable: - # If there is anything wrong in the MuxedConn, - # we should break the loop and close the connection. - break # Asynchronously handle the accepted stream, to avoid blocking the next stream. + except MuxedConnUnavailable: + break self.manager.run_task(self._handle_muxed_stream, stream) await self.close() - async def _call_stream_handler(self, net_stream: NetStream) -> None: - try: - await self.swarm.common_stream_handler(net_stream) - # TODO: More exact exceptions - except Exception: - # TODO: Emit logs. - # TODO: Clean up and remove the stream from SwarmConn if there is anything wrong. - self.remove_stream(net_stream) - async def _handle_muxed_stream(self, muxed_stream: IMuxedStream) -> None: net_stream = await self._add_stream(muxed_stream) if self.swarm.common_stream_handler is not None: - await self._call_stream_handler(net_stream) + try: + await self.swarm.common_stream_handler(net_stream) + # TODO: More exact exceptions + except Exception: + # TODO: Emit logs. + # TODO: Clean up and remove the stream from SwarmConn if there is anything wrong. + self.remove_stream(net_stream) async def _add_stream(self, muxed_stream: IMuxedStream) -> NetStream: net_stream = NetStream(muxed_stream) @@ -84,7 +88,8 @@ class SwarmConn(INetConn, Service): await self.swarm.notify_disconnected(self) async def run(self) -> None: - await self._handle_new_streams() + self.manager.run_task(self._handle_new_streams) + await self.manager.wait_finished() async def new_stream(self) -> NetStream: muxed_stream = await self.muxed_conn.open_stream() diff --git a/libp2p/network/swarm.py b/libp2p/network/swarm.py index 4bf86dd..cdb80a4 100644 --- a/libp2p/network/swarm.py +++ b/libp2p/network/swarm.py @@ -44,6 +44,7 @@ class Swarm(INetwork, Service): common_stream_handler: Optional[StreamHandlerFn] notifees: List[INotifee] + event_closed: trio.Event def __init__( self, @@ -62,6 +63,8 @@ class Swarm(INetwork, Service): # Create Notifee array self.notifees = [] + self.event_closed = trio.Event() + self.common_stream_handler = None async def run(self) -> None: @@ -227,10 +230,19 @@ class Swarm(INetwork, Service): return False async def close(self) -> None: - # TODO: Prevent from new listeners and conns being added. + if self.event_closed.is_set(): + return + self.event_closed.set() # Reference: https://github.com/libp2p/go-libp2p-swarm/blob/8be680aef8dea0a4497283f2f98470c2aeae6b65/swarm.go#L124-L134 # noqa: E501 + async with trio.open_nursery() as nursery: + for conn in self.connections.values(): + nursery.start_soon(conn.close) + async with trio.open_nursery() as nursery: + for listener in self.listeners.values(): + nursery.start_soon(listener.close) + + # Cancel tasks await self.manager.stop() - await self.manager.wait_finished() logger.debug("swarm successfully closed") async def close_peer(self, peer_id: ID) -> None: @@ -270,8 +282,6 @@ class Swarm(INetwork, Service): # Notifee - # TODO: Remeber the spawn notifying tasks and clean them up when closing. - def register_notifee(self, notifee: INotifee) -> None: """ :param notifee: object implementing Notifee interface diff --git a/libp2p/pubsub/pubsub.py b/libp2p/pubsub/pubsub.py index 3370ea3..0c3b162 100644 --- a/libp2p/pubsub/pubsub.py +++ b/libp2p/pubsub/pubsub.py @@ -58,11 +58,11 @@ class TopicValidator(NamedTuple): # TODO: Add interface for Pubsub -class BasePubsub(ABC): +class IPubsub(ABC): pass -class Pubsub(BasePubsub, Service): +class Pubsub(IPubsub, Service): host: IHost diff --git a/libp2p/stream_muxer/abc.py b/libp2p/stream_muxer/abc.py index 12a8f80..e34295c 100644 --- a/libp2p/stream_muxer/abc.py +++ b/libp2p/stream_muxer/abc.py @@ -33,6 +33,7 @@ class IMuxedConn(ServiceAPI): async def close(self) -> None: """close connection.""" + @property @abstractmethod def is_closed(self) -> bool: """ diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index e23da00..b7b3a3a 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -91,7 +91,9 @@ class Mplex(IMuxedConn, Service): await self.secured_conn.close() # Blocked until `close` is finally set. await self.event_closed.wait() + await self.manager.stop() + @property def is_closed(self) -> bool: """ check connection is fully closed. @@ -213,10 +215,6 @@ class Mplex(IMuxedConn, Service): return channel_id, flag, message - @property - def _id(self) -> int: - return 0 if self.is_initiator else 1 - async def _handle_incoming_message(self) -> None: """ Read and handle a new incoming message. diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index eeefc42..011cd3a 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -156,26 +156,22 @@ class MplexStream(IMuxedStream): if self.event_local_closed.is_set(): return - print(f"!@# stream.close: {self.muxed_conn._id}: step=0") flag = ( HeaderTags.CloseInitiator if self.is_initiator else HeaderTags.CloseReceiver ) # TODO: Raise when `muxed_conn.send_message` fails and `Mplex` isn't shutdown. await self.muxed_conn.send_message(flag, None, self.stream_id) - print(f"!@# stream.close: {self.muxed_conn._id}: step=1") _is_remote_closed: bool async with self.close_lock: self.event_local_closed.set() _is_remote_closed = self.event_remote_closed.is_set() - print(f"!@# stream.close: {self.muxed_conn._id}: step=2") if _is_remote_closed: # Both sides are closed, we can safely remove the buffer from the dict. async with self.muxed_conn.streams_lock: if self.stream_id in self.muxed_conn.streams: del self.muxed_conn.streams[self.stream_id] - print(f"!@# stream.close: {self.muxed_conn._id}: step=3") async def reset(self) -> None: """closes both ends of the stream tells this remote side to hang up.""" diff --git a/libp2p/tools/factories.py b/libp2p/tools/factories.py index e179889..a9eb6a5 100644 --- a/libp2p/tools/factories.py +++ b/libp2p/tools/factories.py @@ -69,9 +69,8 @@ async def raw_conn_factory( tcp_transport = TCP() listener = tcp_transport.create_listener(tcp_stream_handler) await listener.listen(LISTEN_MADDR, nursery) - listening_maddr = listener.multiaddrs[0] + listening_maddr = listener.get_addrs()[0] conn_0 = await tcp_transport.dial(listening_maddr) - print("raw_conn_factory") yield conn_0, conn_1 diff --git a/libp2p/tools/utils.py b/libp2p/tools/utils.py index a66155c..216fdd8 100644 --- a/libp2p/tools/utils.py +++ b/libp2p/tools/utils.py @@ -39,3 +39,6 @@ def create_echo_stream_handler( await stream.write(resp.encode()) return echo_stream_handler + + +# TODO: Service `external_api` diff --git a/libp2p/transport/listener_interface.py b/libp2p/transport/listener_interface.py index 6d73723..d170d1d 100644 --- a/libp2p/transport/listener_interface.py +++ b/libp2p/transport/listener_interface.py @@ -22,3 +22,7 @@ class IListener(ABC): :return: return list of addrs """ + + @abstractmethod + async def close(self) -> None: + ... diff --git a/libp2p/transport/tcp/tcp.py b/libp2p/transport/tcp/tcp.py index 04d8874..8c46a4a 100644 --- a/libp2p/transport/tcp/tcp.py +++ b/libp2p/transport/tcp/tcp.py @@ -16,10 +16,10 @@ logger = logging.getLogger("libp2p.transport.tcp") class TCPListener(IListener): - multiaddrs: List[Multiaddr] + listeners: List[trio.SocketListener] def __init__(self, handler_function: THandler) -> None: - self.multiaddrs = [] + self.listeners = [] self.handler = handler_function # TODO: Get rid of `nursery`? @@ -50,8 +50,7 @@ class TCPListener(IListener): int(maddr.value_for_protocol("tcp")), maddr.value_for_protocol("ip4"), ) - socket = listeners[0].socket - self.multiaddrs.append(_multiaddr_from_socket(socket)) + self.listeners.extend(listeners) def get_addrs(self) -> Tuple[Multiaddr, ...]: """ @@ -59,7 +58,14 @@ class TCPListener(IListener): :return: return list of addrs """ - return tuple(self.multiaddrs) + return tuple( + _multiaddr_from_socket(listener.socket) for listener in self.listeners + ) + + async def close(self) -> None: + async with trio.open_nursery() as nursery: + for listener in self.listeners: + nursery.start_soon(listener.aclose) class TCP(ITransport): diff --git a/tests/network/test_swarm_conn.py b/tests/network/test_swarm_conn.py index e1c1285..1bfd7d8 100644 --- a/tests/network/test_swarm_conn.py +++ b/tests/network/test_swarm_conn.py @@ -1,20 +1,23 @@ import pytest import trio +from trio.testing import wait_all_tasks_blocked @pytest.mark.trio async def test_swarm_conn_close(swarm_conn_pair): conn_0, conn_1 = swarm_conn_pair - assert not conn_0.event_closed.is_set() - assert not conn_1.event_closed.is_set() + assert not conn_0.is_closed + assert not conn_1.is_closed await conn_0.close() - await trio.sleep(0.01) + await trio.sleep(0.1) + await wait_all_tasks_blocked() + await conn_0.manager.wait_finished() - assert conn_0.event_closed.is_set() - assert conn_1.event_closed.is_set() + assert conn_0.is_closed + assert conn_1.is_closed assert conn_0 not in conn_0.swarm.connections.values() assert conn_1 not in conn_1.swarm.connections.values() diff --git a/tests/stream_muxer/test_mplex_conn.py b/tests/stream_muxer/test_mplex_conn.py index 4cedc36..4bff2d6 100644 --- a/tests/stream_muxer/test_mplex_conn.py +++ b/tests/stream_muxer/test_mplex_conn.py @@ -8,10 +8,6 @@ async def test_mplex_conn(mplex_conn_pair): assert len(conn_0.streams) == 0 assert len(conn_1.streams) == 0 - assert not conn_0.event_shutting_down.is_set() - assert not conn_1.event_shutting_down.is_set() - assert not conn_0.event_closed.is_set() - assert not conn_1.event_closed.is_set() # Test: Open a stream, and both side get 1 more stream. stream_0 = await conn_0.open_stream() @@ -29,10 +25,8 @@ async def test_mplex_conn(mplex_conn_pair): # Sleep for a while for both side to handle `close`. await trio.sleep(0.01) # Test: Both side is closed. - assert conn_0.event_shutting_down.is_set() - assert conn_0.event_closed.is_set() - assert conn_1.event_shutting_down.is_set() - assert conn_1.event_closed.is_set() + assert conn_0.is_closed + assert conn_1.is_closed # Test: All streams should have been closed. assert stream_0.event_remote_closed.is_set() assert stream_0.event_reset.is_set() diff --git a/tests/transport/test_tcp.py b/tests/transport/test_tcp.py index abd5884..247b5f9 100644 --- a/tests/transport/test_tcp.py +++ b/tests/transport/test_tcp.py @@ -38,8 +38,9 @@ async def test_tcp_dial(nursery): listener = transport.create_listener(handler) await listener.listen(LISTEN_MADDR, nursery) - assert len(listener.multiaddrs) == 1 - listen_addr = listener.multiaddrs[0] + addrs = listener.get_addrs() + assert len(addrs) == 1 + listen_addr = addrs[0] raw_conn = await transport.dial(listen_addr) data = b"123"