Swarm: add default_stream_handler

Advantage:
- To avoid `None` checks
- If users forget to register a stream handler for `Swarm`,
with the default stream handler, opened streams aren't removed
until `Swarm` finishes.
This commit is contained in:
mhchia 2020-02-04 17:05:53 +08:00
parent 3fc60cb312
commit 3a91f114ab
No known key found for this signature in database
GPG Key ID: 389EFBEA1362589A
2 changed files with 19 additions and 9 deletions

View File

@ -68,11 +68,12 @@ class SwarmConn(INetConn):
async def _handle_muxed_stream(self, muxed_stream: IMuxedStream) -> None: async def _handle_muxed_stream(self, muxed_stream: IMuxedStream) -> None:
net_stream = self._add_stream(muxed_stream) net_stream = self._add_stream(muxed_stream)
if self.swarm.common_stream_handler is not None: try:
try: # Ignore type here since mypy complains: https://github.com/python/mypy/issues/2427
await self.swarm.common_stream_handler(net_stream) await self.swarm.common_stream_handler(net_stream) # type: ignore
finally: finally:
self.remove_stream(net_stream) # As long as `common_stream_handler`, remove the stream.
self.remove_stream(net_stream)
def _add_stream(self, muxed_stream: IMuxedStream) -> NetStream: def _add_stream(self, muxed_stream: IMuxedStream) -> NetStream:
net_stream = NetStream(muxed_stream) net_stream = NetStream(muxed_stream)

View File

@ -1,5 +1,5 @@
import logging import logging
from typing import Dict, List, Optional from typing import Dict, List
from async_service import Service from async_service import Service
from multiaddr import Multiaddr from multiaddr import Multiaddr
@ -31,6 +31,13 @@ from .stream.net_stream_interface import INetStream
logger = logging.getLogger("libp2p.network.swarm") logger = logging.getLogger("libp2p.network.swarm")
def create_default_stream_handler(network: INetworkService) -> StreamHandlerFn:
async def stream_handler(stream: INetStream) -> None:
await network.get_manager().wait_finished()
return stream_handler
class Swarm(Service, INetworkService): class Swarm(Service, INetworkService):
self_id: ID self_id: ID
@ -41,7 +48,7 @@ class Swarm(Service, INetworkService):
# whereas in Go one `peer_id` may point to multiple connections. # whereas in Go one `peer_id` may point to multiple connections.
connections: Dict[ID, INetConn] connections: Dict[ID, INetConn]
listeners: Dict[str, IListener] listeners: Dict[str, IListener]
common_stream_handler: Optional[StreamHandlerFn] common_stream_handler: StreamHandlerFn
notifees: List[INotifee] notifees: List[INotifee]
@ -62,7 +69,8 @@ class Swarm(Service, INetworkService):
# Create Notifee array # Create Notifee array
self.notifees = [] self.notifees = []
self.common_stream_handler = None # Ignore type here since mypy complains: https://github.com/python/mypy/issues/2427
self.common_stream_handler = create_default_stream_handler(self) # type: ignore
async def run(self) -> None: async def run(self) -> None:
await self.manager.wait_finished() await self.manager.wait_finished()
@ -71,7 +79,8 @@ class Swarm(Service, INetworkService):
return self.self_id return self.self_id
def set_stream_handler(self, stream_handler: StreamHandlerFn) -> None: def set_stream_handler(self, stream_handler: StreamHandlerFn) -> None:
self.common_stream_handler = stream_handler # Ignore type here since mypy complains: https://github.com/python/mypy/issues/2427
self.common_stream_handler = stream_handler # type: ignore
async def dial_peer(self, peer_id: ID) -> INetConn: async def dial_peer(self, peer_id: ID) -> INetConn:
""" """