Refactor multiselect out of Swarm to BasicHost

This commit is contained in:
mhchia 2019-09-12 14:30:39 +08:00
parent 0bd213bbb7
commit 6cb033fd1f
No known key found for this signature in database
GPG Key ID: 389EFBEA1362589A
5 changed files with 50 additions and 84 deletions

View File

@ -1,4 +1,5 @@
from typing import Any, List, Sequence import asyncio
from typing import List, Sequence
import multiaddr import multiaddr
@ -7,6 +8,9 @@ from libp2p.network.stream.net_stream_interface import INetStream
from libp2p.peer.id import ID from libp2p.peer.id import ID
from libp2p.peer.peerinfo import PeerInfo from libp2p.peer.peerinfo import PeerInfo
from libp2p.peer.peerstore_interface import IPeerStore from libp2p.peer.peerstore_interface import IPeerStore
from libp2p.protocol_muxer.multiselect import Multiselect
from libp2p.protocol_muxer.multiselect_client import MultiselectClient
from libp2p.protocol_muxer.multiselect_communicator import MultiselectCommunicator
from libp2p.routing.kademlia.kademlia_peer_router import KadmeliaPeerRouter from libp2p.routing.kademlia.kademlia_peer_router import KadmeliaPeerRouter
from libp2p.typing import StreamHandlerFn, TProtocol from libp2p.typing import StreamHandlerFn, TProtocol
@ -24,11 +28,18 @@ class BasicHost(IHost):
_router: KadmeliaPeerRouter _router: KadmeliaPeerRouter
peerstore: IPeerStore peerstore: IPeerStore
multiselect: Multiselect
multiselect_client: MultiselectClient
# default options constructor # default options constructor
def __init__(self, network: INetwork, router: KadmeliaPeerRouter = None) -> None: def __init__(self, network: INetwork, router: KadmeliaPeerRouter = None) -> None:
self._network = network self._network = network
self._network.set_stream_handler(self._swarm_stream_handler)
self._router = router self._router = router
self.peerstore = self._network.peerstore self.peerstore = self._network.peerstore
# Protocol muxing
self.multiselect = Multiselect()
self.multiselect_client = MultiselectClient()
def get_id(self) -> ID: def get_id(self) -> ID:
""" """
@ -48,11 +59,11 @@ class BasicHost(IHost):
""" """
return self.peerstore return self.peerstore
# FIXME: Replace with correct return type def get_mux(self) -> Multiselect:
def get_mux(self) -> Any:
""" """
:return: mux instance of host :return: mux instance of host
""" """
return self.multiselect
def get_addrs(self) -> List[multiaddr.Multiaddr]: def get_addrs(self) -> List[multiaddr.Multiaddr]:
""" """
@ -74,7 +85,7 @@ class BasicHost(IHost):
:param protocol_id: protocol id used on stream :param protocol_id: protocol id used on stream
:param stream_handler: a stream handler function :param stream_handler: a stream handler function
""" """
self._network.set_stream_handler(protocol_id, stream_handler) self.multiselect.add_handler(protocol_id, stream_handler)
# `protocol_ids` can be a list of `protocol_id` # `protocol_ids` can be a list of `protocol_id`
# stream will decide which `protocol_id` to run on # stream will decide which `protocol_id` to run on
@ -86,7 +97,16 @@ class BasicHost(IHost):
:param protocol_ids: available protocol ids to use for stream :param protocol_ids: available protocol ids to use for stream
:return: stream: new stream created :return: stream: new stream created
""" """
return await self._network.new_stream(peer_id, protocol_ids)
net_stream = await self._network.new_stream(peer_id, protocol_ids)
# Perform protocol muxing to determine protocol to use
selected_protocol = await self.multiselect_client.select_one_of(
list(protocol_ids), MultiselectCommunicator(net_stream)
)
net_stream.set_protocol(selected_protocol)
return net_stream
async def connect(self, peer_info: PeerInfo) -> None: async def connect(self, peer_info: PeerInfo) -> None:
""" """
@ -111,3 +131,12 @@ class BasicHost(IHost):
async def close(self) -> None: async def close(self) -> None:
await self._network.close() await self._network.close()
# Reference: `BasicHost.newStreamHandler` in Go.
async def _swarm_stream_handler(self, net_stream: INetStream) -> None:
# Perform protocol muxing to determine protocol to use
protocol, handler = await self.multiselect.negotiate(
MultiselectCommunicator(net_stream)
)
net_stream.set_protocol(protocol)
asyncio.ensure_future(handler(net_stream))

View File

@ -42,13 +42,15 @@ class SwarmConn(INetConn):
# TODO: Notify closed. # TODO: Notify closed.
async def _handle_new_streams(self) -> None: async def _handle_new_streams(self) -> None:
# TODO: Break the loop when anything wrong in the connection.
while True: while True:
print("!@# SwarmConn._handle_new_streams") print("!@# SwarmConn._handle_new_streams")
stream = await self.conn.accept_stream() stream = await self.conn.accept_stream()
print("!@# SwarmConn._handle_new_streams: accept_stream:", stream) print("!@# SwarmConn._handle_new_streams: accept_stream:", stream)
net_stream = await self._add_stream(stream) net_stream = await self._add_stream(stream)
print("!@# SwarmConn.calling swarm_stream_handler") print("!@# SwarmConn.calling common_stream_handler")
await self.run_task(self.swarm.swarm_stream_handler(net_stream)) if self.swarm.common_stream_handler is not None:
await self.run_task(self.swarm.common_stream_handler(net_stream))
await self.close() await self.close()
async def _add_stream(self, muxed_stream: IMuxedStream) -> NetStream: async def _add_stream(self, muxed_stream: IMuxedStream) -> NetStream:

View File

@ -37,15 +37,6 @@ class INetwork(ABC):
:return: muxed connection :return: muxed connection
""" """
@abstractmethod
def set_stream_handler(
self, protocol_id: TProtocol, stream_handler: StreamHandlerFn
) -> None:
"""
:param protocol_id: protocol id used on stream
:param stream_handler: a stream handler instance
"""
@abstractmethod @abstractmethod
async def new_stream( async def new_stream(
self, peer_id: ID, protocol_ids: Sequence[TProtocol] self, peer_id: ID, protocol_ids: Sequence[TProtocol]
@ -56,6 +47,12 @@ class INetwork(ABC):
:return: net stream instance :return: net stream instance
""" """
@abstractmethod
def set_stream_handler(self, stream_handler: StreamHandlerFn) -> None:
"""
Set the stream handler for all incoming streams.
"""
@abstractmethod @abstractmethod
async def listen(self, *multiaddrs: Sequence[Multiaddr]) -> bool: async def listen(self, *multiaddrs: Sequence[Multiaddr]) -> bool:
""" """

View File

@ -1,6 +1,6 @@
import asyncio import asyncio
import logging import logging
from typing import Callable, Dict, List, Optional, Sequence from typing import Dict, List, Optional, Sequence
from multiaddr import Multiaddr from multiaddr import Multiaddr
@ -8,9 +8,6 @@ from libp2p.network.connection.net_connection_interface import INetConn
from libp2p.peer.id import ID from libp2p.peer.id import ID
from libp2p.peer.peerstore import PeerStoreError from libp2p.peer.peerstore import PeerStoreError
from libp2p.peer.peerstore_interface import IPeerStore from libp2p.peer.peerstore_interface import IPeerStore
from libp2p.protocol_muxer.multiselect import Multiselect
from libp2p.protocol_muxer.multiselect_client import MultiselectClient
from libp2p.protocol_muxer.multiselect_communicator import MultiselectCommunicator
from libp2p.routing.interfaces import IPeerRouting from libp2p.routing.interfaces import IPeerRouting
from libp2p.stream_muxer.abc import IMuxedConn from libp2p.stream_muxer.abc import IMuxedConn
from libp2p.transport.exceptions import MuxerUpgradeFailure, SecurityUpgradeFailure from libp2p.transport.exceptions import MuxerUpgradeFailure, SecurityUpgradeFailure
@ -25,7 +22,6 @@ from .exceptions import SwarmException
from .network_interface import INetwork from .network_interface import INetwork
from .notifee_interface import INotifee from .notifee_interface import INotifee
from .stream.net_stream_interface import INetStream from .stream.net_stream_interface import INetStream
from .typing import GenericProtocolHandlerFn
logger = logging.getLogger("libp2p.network.swarm") logger = logging.getLogger("libp2p.network.swarm")
logger.setLevel(logging.DEBUG) logger.setLevel(logging.DEBUG)
@ -42,10 +38,7 @@ class Swarm(INetwork):
# whereas in Go one `peer_id` may point to multiple connections. # whereas in Go one `peer_id` may point to multiple connections.
connections: Dict[ID, INetConn] connections: Dict[ID, INetConn]
listeners: Dict[str, IListener] listeners: Dict[str, IListener]
swarm_stream_handler: Optional[Callable[[INetStream], None]] common_stream_handler: Optional[StreamHandlerFn]
multiselect: Multiselect
multiselect_client: MultiselectClient
notifees: List[INotifee] notifees: List[INotifee]
@ -65,29 +58,16 @@ class Swarm(INetwork):
self.connections = dict() self.connections = dict()
self.listeners = dict() self.listeners = dict()
# Protocol muxing
self.multiselect = Multiselect()
self.multiselect_client = MultiselectClient()
# Create Notifee array # Create Notifee array
self.notifees = [] self.notifees = []
# Create generic protocol handler self.common_stream_handler = None
self.swarm_stream_handler = (
self.generic_protocol_handler
) = create_generic_protocol_handler(self)
def get_peer_id(self) -> ID: def get_peer_id(self) -> ID:
return self.self_id return self.self_id
def set_stream_handler( def set_stream_handler(self, stream_handler: StreamHandlerFn) -> None:
self, protocol_id: TProtocol, stream_handler: StreamHandlerFn self.common_stream_handler = stream_handler
) -> None:
"""
:param protocol_id: protocol id used on stream
:param stream_handler: a stream handler instance
"""
self.multiselect.add_handler(protocol_id, stream_handler)
async def dial_peer(self, peer_id: ID) -> INetConn: async def dial_peer(self, peer_id: ID) -> INetConn:
""" """
@ -169,23 +149,8 @@ class Swarm(INetwork):
swarm_conn = await self.dial_peer(peer_id) swarm_conn = await self.dial_peer(peer_id)
print(f"!@# swarm.new_stream: 1") print(f"!@# swarm.new_stream: 1")
# Use muxed conn to open stream, which returns a muxed stream
net_stream = await swarm_conn.new_stream() net_stream = await swarm_conn.new_stream()
print(f"!@# swarm.new_stream: 2") logger.debug("successfully opened a stream to peer %s", peer_id)
# Perform protocol muxing to determine protocol to use
selected_protocol = await self.multiselect_client.select_one_of(
list(protocol_ids), MultiselectCommunicator(net_stream)
)
print(f"!@# swarm.new_stream: 3")
net_stream.set_protocol(selected_protocol)
logger.debug(
"successfully opened a stream to peer %s, over protocol %s",
peer_id,
selected_protocol,
)
return net_stream return net_stream
async def listen(self, *multiaddrs: Multiaddr) -> bool: async def listen(self, *multiaddrs: Multiaddr) -> bool:
@ -314,25 +279,3 @@ class Swarm(INetwork):
await notifee.connected(self, muxed_conn) await notifee.connected(self, muxed_conn)
await swarm_conn.start() await swarm_conn.start()
return swarm_conn return swarm_conn
# TODO: Move to `BasicHost`
def create_generic_protocol_handler(swarm: Swarm) -> GenericProtocolHandlerFn:
"""
Create a generic protocol handler from the given swarm. We use swarm
to extract the multiselect module so that generic_protocol_handler
can use multiselect when generic_protocol_handler is called
from a different class
"""
multiselect = swarm.multiselect
# Reference: `BasicHost.newStreamHandler` in Go.
async def generic_protocol_handler(net_stream: INetStream) -> None:
# Perform protocol muxing to determine protocol to use
protocol, handler = await multiselect.negotiate(
MultiselectCommunicator(net_stream)
)
net_stream.set_protocol(protocol)
asyncio.ensure_future(handler(net_stream))
return generic_protocol_handler

View File

@ -1,5 +0,0 @@
from typing import Awaitable, Callable
from libp2p.stream_muxer.abc import IMuxedStream
GenericProtocolHandlerFn = Callable[[IMuxedStream], Awaitable[None]]