Refactor mplex and start to add close detection

This commit is contained in:
mhchia 2019-09-12 00:38:12 +08:00
parent 7483da762e
commit 0bd213bbb7
No known key found for this signature in database
GPG Key ID: 389EFBEA1362589A
8 changed files with 153 additions and 109 deletions

View File

@ -0,0 +1,18 @@
from abc import abstractmethod
from typing import Tuple
from libp2p.io.abc import Closer
from libp2p.network.stream.net_stream_interface import INetStream
from libp2p.stream_muxer.abc import IMuxedConn
class INetConn(Closer):
conn: IMuxedConn
@abstractmethod
async def new_stream(self) -> INetStream:
...
@abstractmethod
async def get_streams(self) -> Tuple[INetStream, ...]:
...

View File

@ -0,0 +1,74 @@
import asyncio
from typing import TYPE_CHECKING, Any, Awaitable, List, Set, Tuple
from libp2p.network.connection.net_connection_interface import INetConn
from libp2p.network.stream.net_stream import NetStream
from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream
if TYPE_CHECKING:
from libp2p.network.swarm import Swarm # noqa: F401
"""
Reference: https://github.com/libp2p/go-libp2p-swarm/blob/04c86bbdafd390651cb2ee14e334f7caeedad722/swarm_conn.go # noqa: E501
"""
class SwarmConn(INetConn):
conn: IMuxedConn
swarm: "Swarm"
streams: Set[NetStream]
event_closed: asyncio.Event
_tasks: List["asyncio.Future[Any]"]
def __init__(self, conn: IMuxedConn, swarm: "Swarm") -> None:
self.conn = conn
self.swarm = swarm
self.streams = set()
self.event_closed = asyncio.Event()
self._tasks = []
async def close(self) -> None:
if self.event_closed.is_set():
return
self.event_closed.set()
await self.conn.close()
for task in self._tasks:
task.cancel()
# TODO: Reset streams for local.
# TODO: Notify closed.
async def _handle_new_streams(self) -> None:
while True:
print("!@# SwarmConn._handle_new_streams")
stream = await self.conn.accept_stream()
print("!@# SwarmConn._handle_new_streams: accept_stream:", stream)
net_stream = await self._add_stream(stream)
print("!@# SwarmConn.calling swarm_stream_handler")
await self.run_task(self.swarm.swarm_stream_handler(net_stream))
await self.close()
async def _add_stream(self, muxed_stream: IMuxedStream) -> NetStream:
print("!@# SwarmConn._add_stream:", muxed_stream)
net_stream = NetStream(muxed_stream)
# Call notifiers since event occurred
for notifee in self.swarm.notifees:
await notifee.opened_stream(self.swarm, net_stream)
return net_stream
async def start(self) -> None:
print("!@# SwarmConn.start")
await self.run_task(self._handle_new_streams())
async def run_task(self, coro: Awaitable[Any]) -> None:
self._tasks.append(asyncio.ensure_future(coro))
async def new_stream(self) -> NetStream:
muxed_stream = await self.conn.open_stream()
return await self._add_stream(muxed_stream)
async def get_streams(self) -> Tuple[NetStream, ...]:
return tuple(self.streams)

View File

@ -3,9 +3,9 @@ from typing import TYPE_CHECKING, Dict, Sequence
from multiaddr import Multiaddr from multiaddr import Multiaddr
from libp2p.network.connection.net_connection_interface import INetConn
from libp2p.peer.id import ID from libp2p.peer.id import ID
from libp2p.peer.peerstore_interface import IPeerStore from libp2p.peer.peerstore_interface import IPeerStore
from libp2p.stream_muxer.abc import IMuxedConn
from libp2p.transport.listener_interface import IListener from libp2p.transport.listener_interface import IListener
from libp2p.typing import StreamHandlerFn, TProtocol from libp2p.typing import StreamHandlerFn, TProtocol
@ -18,7 +18,7 @@ if TYPE_CHECKING:
class INetwork(ABC): class INetwork(ABC):
peerstore: IPeerStore peerstore: IPeerStore
connections: Dict[ID, IMuxedConn] connections: Dict[ID, INetConn]
listeners: Dict[str, IListener] listeners: Dict[str, IListener]
@abstractmethod @abstractmethod
@ -28,7 +28,7 @@ class INetwork(ABC):
""" """
@abstractmethod @abstractmethod
async def dial_peer(self, peer_id: ID) -> IMuxedConn: async def dial_peer(self, peer_id: ID) -> INetConn:
""" """
dial_peer try to create a connection to peer_id dial_peer try to create a connection to peer_id

View File

@ -1,9 +1,10 @@
import asyncio import asyncio
import logging import logging
from typing import Callable, Dict, List, Sequence from typing import Callable, Dict, List, Optional, Sequence
from multiaddr import Multiaddr from multiaddr import Multiaddr
from libp2p.network.connection.net_connection_interface import INetConn
from libp2p.peer.id import ID from libp2p.peer.id import ID
from libp2p.peer.peerstore import PeerStoreError from libp2p.peer.peerstore import PeerStoreError
from libp2p.peer.peerstore_interface import IPeerStore from libp2p.peer.peerstore_interface import IPeerStore
@ -11,7 +12,7 @@ from libp2p.protocol_muxer.multiselect import Multiselect
from libp2p.protocol_muxer.multiselect_client import MultiselectClient from libp2p.protocol_muxer.multiselect_client import MultiselectClient
from libp2p.protocol_muxer.multiselect_communicator import MultiselectCommunicator from libp2p.protocol_muxer.multiselect_communicator import MultiselectCommunicator
from libp2p.routing.interfaces import IPeerRouting from libp2p.routing.interfaces import IPeerRouting
from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream from libp2p.stream_muxer.abc import IMuxedConn
from libp2p.transport.exceptions import MuxerUpgradeFailure, SecurityUpgradeFailure from libp2p.transport.exceptions import MuxerUpgradeFailure, SecurityUpgradeFailure
from libp2p.transport.listener_interface import IListener from libp2p.transport.listener_interface import IListener
from libp2p.transport.transport_interface import ITransport from libp2p.transport.transport_interface import ITransport
@ -19,10 +20,10 @@ from libp2p.transport.upgrader import TransportUpgrader
from libp2p.typing import StreamHandlerFn, TProtocol from libp2p.typing import StreamHandlerFn, TProtocol
from .connection.raw_connection import RawConnection from .connection.raw_connection import RawConnection
from .connection.swarm_connection import SwarmConn
from .exceptions import SwarmException from .exceptions import SwarmException
from .network_interface import INetwork from .network_interface import INetwork
from .notifee_interface import INotifee from .notifee_interface import INotifee
from .stream.net_stream import NetStream
from .stream.net_stream_interface import INetStream from .stream.net_stream_interface import INetStream
from .typing import GenericProtocolHandlerFn from .typing import GenericProtocolHandlerFn
@ -39,9 +40,9 @@ class Swarm(INetwork):
router: IPeerRouting router: IPeerRouting
# TODO: Connection and `peer_id` are 1-1 mapping in our implementation, # TODO: Connection and `peer_id` are 1-1 mapping in our implementation,
# whereas in Go one `peer_id` may point to multiple connections. # whereas in Go one `peer_id` may point to multiple connections.
connections: Dict[ID, IMuxedConn] connections: Dict[ID, INetConn]
listeners: Dict[str, IListener] listeners: Dict[str, IListener]
stream_handlers: Dict[INetStream, Callable[[INetStream], None]] swarm_stream_handler: Optional[Callable[[INetStream], None]]
multiselect: Multiselect multiselect: Multiselect
multiselect_client: MultiselectClient multiselect_client: MultiselectClient
@ -63,7 +64,6 @@ class Swarm(INetwork):
self.router = router self.router = router
self.connections = dict() self.connections = dict()
self.listeners = dict() self.listeners = dict()
self.stream_handlers = dict()
# Protocol muxing # Protocol muxing
self.multiselect = Multiselect() self.multiselect = Multiselect()
@ -73,7 +73,9 @@ class Swarm(INetwork):
self.notifees = [] self.notifees = []
# Create generic protocol handler # Create generic protocol handler
self.generic_protocol_handler = create_generic_protocol_handler(self) self.swarm_stream_handler = (
self.generic_protocol_handler
) = create_generic_protocol_handler(self)
def get_peer_id(self) -> ID: def get_peer_id(self) -> ID:
return self.self_id return self.self_id
@ -87,7 +89,7 @@ class Swarm(INetwork):
""" """
self.multiselect.add_handler(protocol_id, stream_handler) self.multiselect.add_handler(protocol_id, stream_handler)
async def dial_peer(self, peer_id: ID) -> IMuxedConn: async def dial_peer(self, peer_id: ID) -> INetConn:
""" """
dial_peer try to create a connection to peer_id dial_peer try to create a connection to peer_id
:param peer_id: peer if we want to dial :param peer_id: peer if we want to dial
@ -134,9 +136,7 @@ class Swarm(INetwork):
logger.debug("upgraded security for peer %s", peer_id) logger.debug("upgraded security for peer %s", peer_id)
try: try:
muxed_conn = await self.upgrader.upgrade_connection( muxed_conn = await self.upgrader.upgrade_connection(secured_conn, peer_id)
secured_conn, self.generic_protocol_handler, peer_id
)
except MuxerUpgradeFailure as error: except MuxerUpgradeFailure as error:
error_msg = "fail to upgrade mux for peer %s" error_msg = "fail to upgrade mux for peer %s"
logger.debug(error_msg, peer_id) logger.debug(error_msg, peer_id)
@ -145,20 +145,15 @@ class Swarm(INetwork):
logger.debug("upgraded mux for peer %s", peer_id) logger.debug("upgraded mux for peer %s", peer_id)
# Store muxed connection in connections swarm_conn = await self.add_conn(muxed_conn)
self.connections[peer_id] = muxed_conn
# Call notifiers since event occurred
for notifee in self.notifees:
await notifee.connected(self, muxed_conn)
logger.debug("successfully dialed peer %s", peer_id) logger.debug("successfully dialed peer %s", peer_id)
return muxed_conn return swarm_conn
async def new_stream( async def new_stream(
self, peer_id: ID, protocol_ids: Sequence[TProtocol] self, peer_id: ID, protocol_ids: Sequence[TProtocol]
) -> NetStream: ) -> INetStream:
""" """
:param peer_id: peer_id of destination :param peer_id: peer_id of destination
:param protocol_id: protocol id :param protocol_id: protocol id
@ -170,18 +165,20 @@ class Swarm(INetwork):
protocol_ids, protocol_ids,
) )
muxed_conn = await self.dial_peer(peer_id) print(f"!@# swarm.new_stream: 0")
swarm_conn = await self.dial_peer(peer_id)
print(f"!@# swarm.new_stream: 1")
# Use muxed conn to open stream, which returns a muxed stream # Use muxed conn to open stream, which returns a muxed stream
muxed_stream = await muxed_conn.open_stream() net_stream = await swarm_conn.new_stream()
print(f"!@# swarm.new_stream: 2")
# Perform protocol muxing to determine protocol to use # Perform protocol muxing to determine protocol to use
selected_protocol = await self.multiselect_client.select_one_of( selected_protocol = await self.multiselect_client.select_one_of(
list(protocol_ids), MultiselectCommunicator(muxed_stream) list(protocol_ids), MultiselectCommunicator(net_stream)
) )
print(f"!@# swarm.new_stream: 3")
# Create a net stream with the selected protocol
net_stream = NetStream(muxed_stream)
net_stream.set_protocol(selected_protocol) net_stream.set_protocol(selected_protocol)
logger.debug( logger.debug(
@ -189,11 +186,6 @@ class Swarm(INetwork):
peer_id, peer_id,
selected_protocol, selected_protocol,
) )
# Call notifiers since event occurred
for notifee in self.notifees:
await notifee.opened_stream(self, net_stream)
return net_stream return net_stream
async def listen(self, *multiaddrs: Multiaddr) -> bool: async def listen(self, *multiaddrs: Multiaddr) -> bool:
@ -243,7 +235,7 @@ class Swarm(INetwork):
try: try:
muxed_conn = await self.upgrader.upgrade_connection( muxed_conn = await self.upgrader.upgrade_connection(
secured_conn, self.generic_protocol_handler, peer_id secured_conn, peer_id
) )
except MuxerUpgradeFailure as error: except MuxerUpgradeFailure as error:
error_msg = "fail to upgrade mux for peer %s" error_msg = "fail to upgrade mux for peer %s"
@ -251,11 +243,8 @@ class Swarm(INetwork):
await secured_conn.close() await secured_conn.close()
raise SwarmException(error_msg % peer_id) from error raise SwarmException(error_msg % peer_id) from error
logger.debug("upgraded mux for peer %s", peer_id) logger.debug("upgraded mux for peer %s", peer_id)
# Store muxed_conn with peer id
self.connections[peer_id] = muxed_conn await self.add_conn(muxed_conn)
# Call notifiers since event occurred
for notifee in self.notifees:
await notifee.connected(self, muxed_conn)
logger.debug("successfully opened connection to peer %s", peer_id) logger.debug("successfully opened connection to peer %s", peer_id)
@ -315,7 +304,19 @@ class Swarm(INetwork):
logger.debug("successfully close the connection to peer %s", peer_id) logger.debug("successfully close the connection to peer %s", peer_id)
async def add_conn(self, muxed_conn: IMuxedConn) -> SwarmConn:
swarm_conn = SwarmConn(muxed_conn, self)
# Store muxed_conn with peer id
self.connections[muxed_conn.peer_id] = swarm_conn
# Call notifiers since event occurred
for notifee in self.notifees:
# TODO: Call with other type of conn?
await notifee.connected(self, muxed_conn)
await swarm_conn.start()
return swarm_conn
# TODO: Move to `BasicHost`
def create_generic_protocol_handler(swarm: Swarm) -> GenericProtocolHandlerFn: def create_generic_protocol_handler(swarm: Swarm) -> GenericProtocolHandlerFn:
""" """
Create a generic protocol handler from the given swarm. We use swarm Create a generic protocol handler from the given swarm. We use swarm
@ -325,20 +326,13 @@ def create_generic_protocol_handler(swarm: Swarm) -> GenericProtocolHandlerFn:
""" """
multiselect = swarm.multiselect multiselect = swarm.multiselect
async def generic_protocol_handler(muxed_stream: IMuxedStream) -> None: # Reference: `BasicHost.newStreamHandler` in Go.
async def generic_protocol_handler(net_stream: INetStream) -> None:
# Perform protocol muxing to determine protocol to use # Perform protocol muxing to determine protocol to use
protocol, handler = await multiselect.negotiate( protocol, handler = await multiselect.negotiate(
MultiselectCommunicator(muxed_stream) MultiselectCommunicator(net_stream)
) )
net_stream = NetStream(muxed_stream)
net_stream.set_protocol(protocol) net_stream.set_protocol(protocol)
# Call notifiers since event occurred
for notifee in swarm.notifees:
await notifee.opened_stream(swarm, net_stream)
# Give to stream handler
asyncio.ensure_future(handler(net_stream)) asyncio.ensure_future(handler(net_stream))
return generic_protocol_handler return generic_protocol_handler

View File

@ -1,15 +1,8 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
from libp2p.io.abc import ReadWriteCloser from libp2p.io.abc import ReadWriteCloser
from libp2p.peer.id import ID from libp2p.peer.id import ID
from libp2p.security.secure_conn_interface import ISecureConn from libp2p.security.secure_conn_interface import ISecureConn
from libp2p.stream_muxer.mplex.constants import HeaderTags
from libp2p.stream_muxer.mplex.datastructures import StreamID
if TYPE_CHECKING:
# Prevent GenericProtocolHandlerFn introducing circular dependencies
from libp2p.network.typing import GenericProtocolHandlerFn # noqa: F401
class IMuxedConn(ABC): class IMuxedConn(ABC):
@ -20,16 +13,10 @@ class IMuxedConn(ABC):
peer_id: ID peer_id: ID
@abstractmethod @abstractmethod
def __init__( def __init__(self, conn: ISecureConn, peer_id: ID) -> None:
self,
conn: ISecureConn,
generic_protocol_handler: "GenericProtocolHandlerFn",
peer_id: ID,
) -> None:
""" """
create a new muxed connection create a new muxed connection
:param conn: an instance of secured connection :param conn: an instance of secured connection
:param generic_protocol_handler: generic protocol handler
for new muxed streams for new muxed streams
:param peer_id: peer_id of peer the connection is to :param peer_id: peer_id of peer the connection is to
""" """
@ -60,22 +47,11 @@ class IMuxedConn(ABC):
""" """
@abstractmethod @abstractmethod
async def accept_stream(self, stream_id: StreamID, name: str) -> None: async def accept_stream(self) -> "IMuxedStream":
""" """
accepts a muxed stream opened by the other end accepts a muxed stream opened by the other end
""" """
@abstractmethod
async def send_message(
self, flag: HeaderTags, data: bytes, stream_id: StreamID
) -> int:
"""
sends a message over the connection
:param header: header to use
:param data: data to send in the message
:param stream_id: stream the message is in
"""
class IMuxedStream(ReadWriteCloser): class IMuxedStream(ReadWriteCloser):

View File

@ -2,7 +2,6 @@ import asyncio
from typing import Any # noqa: F401 from typing import Any # noqa: F401
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
from libp2p.network.typing import GenericProtocolHandlerFn
from libp2p.peer.id import ID from libp2p.peer.id import ID
from libp2p.security.secure_conn_interface import ISecureConn from libp2p.security.secure_conn_interface import ISecureConn
from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream
@ -34,17 +33,13 @@ class Mplex(IMuxedConn):
next_channel_id: int next_channel_id: int
streams: Dict[StreamID, MplexStream] streams: Dict[StreamID, MplexStream]
streams_lock: asyncio.Lock streams_lock: asyncio.Lock
new_stream_queue: "asyncio.Queue[IMuxedStream]"
shutdown: asyncio.Event shutdown: asyncio.Event
_tasks: List["asyncio.Future[Any]"] _tasks: List["asyncio.Future[Any]"]
# TODO: `generic_protocol_handler` should be refactored out of mplex conn. # TODO: `generic_protocol_handler` should be refactored out of mplex conn.
def __init__( def __init__(self, secured_conn: ISecureConn, peer_id: ID) -> None:
self,
secured_conn: ISecureConn,
generic_protocol_handler: GenericProtocolHandlerFn,
peer_id: ID,
) -> None:
""" """
create a new muxed connection create a new muxed connection
:param secured_conn: an instance of ``ISecureConn`` :param secured_conn: an instance of ``ISecureConn``
@ -56,15 +51,13 @@ class Mplex(IMuxedConn):
self.next_channel_id = 0 self.next_channel_id = 0
# Store generic protocol handler
self.generic_protocol_handler = generic_protocol_handler
# Set peer_id # Set peer_id
self.peer_id = peer_id self.peer_id = peer_id
# Mapping from stream ID -> buffer of messages for that stream # Mapping from stream ID -> buffer of messages for that stream
self.streams = {} self.streams = {}
self.streams_lock = asyncio.Lock() self.streams_lock = asyncio.Lock()
self.new_stream_queue = asyncio.Queue()
self.shutdown = asyncio.Event() self.shutdown = asyncio.Event()
self._tasks = [] self._tasks = []
@ -101,9 +94,10 @@ class Mplex(IMuxedConn):
return next_id return next_id
async def _initialize_stream(self, stream_id: StreamID, name: str) -> MplexStream: async def _initialize_stream(self, stream_id: StreamID, name: str) -> MplexStream:
stream = MplexStream(name, stream_id, self)
async with self.streams_lock: async with self.streams_lock:
stream = MplexStream(name, stream_id, self) self.streams[stream_id] = stream
self.streams[stream_id] = stream print(f"!@# _initialize_stream: stream_id={stream_id}, name={name}")
return stream return stream
async def open_stream(self) -> IMuxedStream: async def open_stream(self) -> IMuxedStream:
@ -119,13 +113,11 @@ class Mplex(IMuxedConn):
await self.send_message(HeaderTags.NewStream, name.encode(), stream_id) await self.send_message(HeaderTags.NewStream, name.encode(), stream_id)
return stream return stream
async def accept_stream(self, stream_id: StreamID, name: str) -> None: async def accept_stream(self) -> IMuxedStream:
""" """
accepts a muxed stream opened by the other end accepts a muxed stream opened by the other end
""" """
stream = await self._initialize_stream(stream_id, name) return await self.new_stream_queue.get()
# Perform protocol negotiation for the stream.
self._tasks.append(asyncio.ensure_future(self.generic_protocol_handler(stream)))
async def send_message( async def send_message(
self, flag: HeaderTags, data: Optional[bytes], stream_id: StreamID self, flag: HeaderTags, data: Optional[bytes], stream_id: StreamID
@ -178,7 +170,11 @@ class Mplex(IMuxedConn):
# `NewStream` for the same id is received twice... # `NewStream` for the same id is received twice...
# TODO: Shutdown # TODO: Shutdown
pass pass
await self.accept_stream(stream_id, message.decode()) mplex_stream = await self._initialize_stream(
stream_id, message.decode()
)
# TODO: Check if `self` is shutdown.
await self.new_stream_queue.put(mplex_stream)
elif flag in ( elif flag in (
HeaderTags.MessageInitiator.value, HeaderTags.MessageInitiator.value,
HeaderTags.MessageReceiver.value, HeaderTags.MessageReceiver.value,

View File

@ -2,7 +2,6 @@ from collections import OrderedDict
from typing import Mapping, Type from typing import Mapping, Type
from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.network.connection.raw_connection_interface import IRawConnection
from libp2p.network.typing import GenericProtocolHandlerFn
from libp2p.peer.id import ID from libp2p.peer.id import ID
from libp2p.protocol_muxer.multiselect import Multiselect from libp2p.protocol_muxer.multiselect import Multiselect
from libp2p.protocol_muxer.multiselect_client import MultiselectClient from libp2p.protocol_muxer.multiselect_client import MultiselectClient
@ -69,11 +68,6 @@ class MuxerMultistream:
protocol, _ = await self.multiselect.negotiate(communicator) protocol, _ = await self.multiselect.negotiate(communicator)
return self.transports[protocol] return self.transports[protocol]
async def new_conn( async def new_conn(self, conn: ISecureConn, peer_id: ID) -> IMuxedConn:
self,
conn: ISecureConn,
generic_protocol_handler: GenericProtocolHandlerFn,
peer_id: ID,
) -> IMuxedConn:
transport_class = await self.select_transport(conn) transport_class = await self.select_transport(conn)
return transport_class(conn, generic_protocol_handler, peer_id) return transport_class(conn, peer_id)

View File

@ -1,7 +1,6 @@
from typing import Mapping from typing import Mapping
from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.network.connection.raw_connection_interface import IRawConnection
from libp2p.network.typing import GenericProtocolHandlerFn
from libp2p.peer.id import ID from libp2p.peer.id import ID
from libp2p.protocol_muxer.exceptions import MultiselectClientError, MultiselectError from libp2p.protocol_muxer.exceptions import MultiselectClientError, MultiselectError
from libp2p.security.secure_conn_interface import ISecureConn from libp2p.security.secure_conn_interface import ISecureConn
@ -60,19 +59,12 @@ class TransportUpgrader:
"handshake failed when upgrading to secure connection" "handshake failed when upgrading to secure connection"
) from error ) from error
async def upgrade_connection( async def upgrade_connection(self, conn: ISecureConn, peer_id: ID) -> IMuxedConn:
self,
conn: ISecureConn,
generic_protocol_handler: GenericProtocolHandlerFn,
peer_id: ID,
) -> IMuxedConn:
""" """
Upgrade secured connection to a muxed connection Upgrade secured connection to a muxed connection
""" """
try: try:
return await self.muxer_multistream.new_conn( return await self.muxer_multistream.new_conn(conn, peer_id)
conn, generic_protocol_handler, peer_id
)
except (MultiselectError, MultiselectClientError) as error: except (MultiselectError, MultiselectClientError) as error:
raise MuxerUpgradeFailure( raise MuxerUpgradeFailure(
"failed to negotiate the multiplexer protocol" "failed to negotiate the multiplexer protocol"