From 3a91f114abc2f301a742b7acc14655f289cbc9da Mon Sep 17 00:00:00 2001 From: mhchia Date: Tue, 4 Feb 2020 17:05:53 +0800 Subject: [PATCH] Swarm: add `default_stream_handler` Advantage: - To avoid `None` checks - If users forget to register a stream handler for `Swarm`, with the default stream handler, opened streams aren't removed until `Swarm` finishes. --- libp2p/network/connection/swarm_connection.py | 11 ++++++----- libp2p/network/swarm.py | 17 +++++++++++++---- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/libp2p/network/connection/swarm_connection.py b/libp2p/network/connection/swarm_connection.py index 0e930f5..cc13dcc 100644 --- a/libp2p/network/connection/swarm_connection.py +++ b/libp2p/network/connection/swarm_connection.py @@ -68,11 +68,12 @@ class SwarmConn(INetConn): async def _handle_muxed_stream(self, muxed_stream: IMuxedStream) -> None: net_stream = self._add_stream(muxed_stream) - if self.swarm.common_stream_handler is not None: - try: - await self.swarm.common_stream_handler(net_stream) - finally: - self.remove_stream(net_stream) + try: + # Ignore type here since mypy complains: https://github.com/python/mypy/issues/2427 + await self.swarm.common_stream_handler(net_stream) # type: ignore + finally: + # As long as `common_stream_handler`, remove the stream. + self.remove_stream(net_stream) def _add_stream(self, muxed_stream: IMuxedStream) -> NetStream: net_stream = NetStream(muxed_stream) diff --git a/libp2p/network/swarm.py b/libp2p/network/swarm.py index e03d65b..9a0279d 100644 --- a/libp2p/network/swarm.py +++ b/libp2p/network/swarm.py @@ -1,5 +1,5 @@ import logging -from typing import Dict, List, Optional +from typing import Dict, List from async_service import Service from multiaddr import Multiaddr @@ -31,6 +31,13 @@ from .stream.net_stream_interface import INetStream logger = logging.getLogger("libp2p.network.swarm") +def create_default_stream_handler(network: INetworkService) -> StreamHandlerFn: + async def stream_handler(stream: INetStream) -> None: + await network.get_manager().wait_finished() + + return stream_handler + + class Swarm(Service, INetworkService): self_id: ID @@ -41,7 +48,7 @@ class Swarm(Service, INetworkService): # whereas in Go one `peer_id` may point to multiple connections. connections: Dict[ID, INetConn] listeners: Dict[str, IListener] - common_stream_handler: Optional[StreamHandlerFn] + common_stream_handler: StreamHandlerFn notifees: List[INotifee] @@ -62,7 +69,8 @@ class Swarm(Service, INetworkService): # Create Notifee array self.notifees = [] - self.common_stream_handler = None + # Ignore type here since mypy complains: https://github.com/python/mypy/issues/2427 + self.common_stream_handler = create_default_stream_handler(self) # type: ignore async def run(self) -> None: await self.manager.wait_finished() @@ -71,7 +79,8 @@ class Swarm(Service, INetworkService): return self.self_id def set_stream_handler(self, stream_handler: StreamHandlerFn) -> None: - self.common_stream_handler = stream_handler + # Ignore type here since mypy complains: https://github.com/python/mypy/issues/2427 + self.common_stream_handler = stream_handler # type: ignore async def dial_peer(self, peer_id: ID) -> INetConn: """