diff --git a/libp2p/network/connection/net_connection_interface.py b/libp2p/network/connection/net_connection_interface.py index e308ad6..f1bcac2 100644 --- a/libp2p/network/connection/net_connection_interface.py +++ b/libp2p/network/connection/net_connection_interface.py @@ -1,6 +1,8 @@ from abc import abstractmethod from typing import Tuple +import trio + from libp2p.io.abc import Closer from libp2p.network.stream.net_stream_interface import INetStream from libp2p.stream_muxer.abc import IMuxedConn @@ -8,11 +10,12 @@ from libp2p.stream_muxer.abc import IMuxedConn class INetConn(Closer): muxed_conn: IMuxedConn + event_started: trio.Event @abstractmethod async def new_stream(self) -> INetStream: ... @abstractmethod - async def get_streams(self) -> Tuple[INetStream, ...]: + def get_streams(self) -> Tuple[INetStream, ...]: ... diff --git a/libp2p/network/connection/swarm_connection.py b/libp2p/network/connection/swarm_connection.py index 90cb823..7cf3c0f 100644 --- a/libp2p/network/connection/swarm_connection.py +++ b/libp2p/network/connection/swarm_connection.py @@ -1,6 +1,5 @@ from typing import TYPE_CHECKING, Set, Tuple -from async_service import Service import trio from libp2p.network.connection.net_connection_interface import INetConn @@ -17,10 +16,11 @@ Reference: https://github.com/libp2p/go-libp2p-swarm/blob/04c86bbdafd390651cb2ee """ -class SwarmConn(INetConn, Service): +class SwarmConn(INetConn): muxed_conn: IMuxedConn swarm: "Swarm" streams: Set[NetStream] + event_started: trio.Event event_closed: trio.Event def __init__(self, muxed_conn: IMuxedConn, swarm: "Swarm") -> None: @@ -28,6 +28,7 @@ class SwarmConn(INetConn, Service): self.swarm = swarm self.streams = set() self.event_closed = trio.Event() + self.event_started = trio.Event() @property def is_closed(self) -> bool: @@ -38,8 +39,6 @@ class SwarmConn(INetConn, Service): return self.event_closed.set() await self._cleanup() - # Cancel service - await self.manager.stop() async def _cleanup(self) -> None: self.swarm.remove_conn(self) @@ -57,13 +56,14 @@ class SwarmConn(INetConn, Service): self._notify_disconnected() async def _handle_new_streams(self) -> None: - while self.manager.is_running: + self.event_started.set() + while True: try: stream = await self.muxed_conn.accept_stream() # Asynchronously handle the accepted stream, to avoid blocking the next stream. except MuxedConnUnavailable: break - self.manager.run_task(self._handle_muxed_stream, stream) + self.swarm.manager.run_task(self._handle_muxed_stream, stream) await self.close() @@ -87,15 +87,14 @@ class SwarmConn(INetConn, Service): def _notify_disconnected(self) -> None: self.swarm.notify_disconnected(self) - async def run(self) -> None: - self.manager.run_task(self._handle_new_streams) - await self.manager.wait_finished() + async def start(self) -> None: + await self._handle_new_streams() async def new_stream(self) -> NetStream: muxed_stream = await self.muxed_conn.open_stream() return self._add_stream(muxed_stream) - async def get_streams(self) -> Tuple[NetStream, ...]: + def get_streams(self) -> Tuple[NetStream, ...]: return tuple(self.streams) def remove_stream(self, stream: NetStream) -> None: diff --git a/libp2p/network/swarm.py b/libp2p/network/swarm.py index 0904774..45d85b1 100644 --- a/libp2p/network/swarm.py +++ b/libp2p/network/swarm.py @@ -2,7 +2,6 @@ import logging from typing import Dict, List, Optional from multiaddr import Multiaddr -import trio from libp2p.io.abc import ReadWriteCloser from libp2p.network.connection.net_connection_interface import INetConn @@ -44,7 +43,6 @@ class Swarm(INetworkService): common_stream_handler: Optional[StreamHandlerFn] notifees: List[INotifee] - event_closed: trio.Event def __init__( self, @@ -63,8 +61,6 @@ class Swarm(INetworkService): # Create Notifee array self.notifees = [] - self.event_closed = trio.Event() - self.common_stream_handler = None async def run(self) -> None: @@ -158,13 +154,11 @@ class Swarm(INetworkService): try: muxed_conn = await self.upgrader.upgrade_connection(secured_conn, peer_id) - self.manager.run_child_service(muxed_conn) except MuxerUpgradeFailure as error: error_msg = "fail to upgrade mux for peer %s" logger.debug(error_msg, peer_id) await secured_conn.close() raise SwarmException(error_msg % peer_id) from error - logger.debug("upgraded mux for peer %s", peer_id) swarm_conn = await self.add_conn(muxed_conn) @@ -226,7 +220,6 @@ class Swarm(INetworkService): muxed_conn = await self.upgrader.upgrade_connection( secured_conn, peer_id ) - self.manager.run_child_service(muxed_conn) except MuxerUpgradeFailure as error: error_msg = "fail to upgrade mux for peer %s" logger.debug(error_msg, peer_id) @@ -235,8 +228,8 @@ class Swarm(INetworkService): logger.debug("upgraded mux for peer %s", peer_id) await self.add_conn(muxed_conn) - logger.debug("successfully opened connection to peer %s", peer_id) + # NOTE: This is a intentional barrier to prevent from the handler exiting and # closing the connection. await self.manager.wait_finished() @@ -261,26 +254,12 @@ class Swarm(INetworkService): return False async def close(self) -> None: - if self.event_closed.is_set(): - return - self.event_closed.set() - # Reference: https://github.com/libp2p/go-libp2p-swarm/blob/8be680aef8dea0a4497283f2f98470c2aeae6b65/swarm.go#L124-L134 # noqa: E501 - async with trio.open_nursery() as nursery: - for conn in self.connections.values(): - nursery.start_soon(conn.close) - async with trio.open_nursery() as nursery: - for listener in self.listeners.values(): - nursery.start_soon(listener.close) - - # Cancel tasks await self.manager.stop() logger.debug("swarm successfully closed") async def close_peer(self, peer_id: ID) -> None: if peer_id not in self.connections: return - # TODO: Should be changed to close multisple connections, - # if we have several connections per peer in the future. connection = self.connections[peer_id] # NOTE: `connection.close` will delete `peer_id` from `self.connections` # and `notify_disconnected` for us. @@ -293,12 +272,14 @@ class Swarm(INetworkService): and start to monitor the connection for its new streams and disconnection.""" swarm_conn = SwarmConn(muxed_conn, self) - manager = self.manager.run_child_service(swarm_conn) + self.manager.run_task(muxed_conn.start) + await muxed_conn.event_started.wait() + self.manager.run_task(swarm_conn.start) + await swarm_conn.event_started.wait() # Store muxed_conn with peer id self.connections[muxed_conn.peer_id] = swarm_conn # Call notifiers since event occurred self.notify_connected(swarm_conn) - await manager.wait_started() return swarm_conn def remove_conn(self, swarm_conn: SwarmConn) -> None: @@ -307,8 +288,6 @@ class Swarm(INetworkService): peer_id = swarm_conn.muxed_conn.peer_id if peer_id not in self.connections: return - # TODO: Should be changed to remove the exact connection, - # if we have several connections per peer in the future. del self.connections[peer_id] # Notifee diff --git a/libp2p/stream_muxer/abc.py b/libp2p/stream_muxer/abc.py index e34295c..82140ff 100644 --- a/libp2p/stream_muxer/abc.py +++ b/libp2p/stream_muxer/abc.py @@ -1,18 +1,19 @@ -from abc import abstractmethod +from abc import ABC, abstractmethod -from async_service import ServiceAPI +import trio from libp2p.io.abc import ReadWriteCloser from libp2p.peer.id import ID from libp2p.security.secure_conn_interface import ISecureConn -class IMuxedConn(ServiceAPI): +class IMuxedConn(ABC): """ reference: https://github.com/libp2p/go-stream-muxer/blob/master/muxer.go """ peer_id: ID + event_started: trio.Event @abstractmethod def __init__(self, conn: ISecureConn, peer_id: ID) -> None: @@ -27,7 +28,11 @@ class IMuxedConn(ServiceAPI): @property @abstractmethod def is_initiator(self) -> bool: - pass + """if this connection is the initiator.""" + + @abstractmethod + async def start(self) -> None: + """start the multiplexer.""" @abstractmethod async def close(self) -> None: diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index 6523b48..486fd3f 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -2,7 +2,6 @@ import logging import math from typing import Dict, Optional, Tuple -from async_service import Service import trio from libp2p.exceptions import ParseError @@ -29,7 +28,7 @@ MPLEX_PROTOCOL_ID = TProtocol("/mplex/6.7.0") logger = logging.getLogger("libp2p.stream_muxer.mplex.mplex") -class Mplex(IMuxedConn, Service): +class Mplex(IMuxedConn): """ reference: https://github.com/libp2p/go-mplex/blob/master/multiplex.go """ @@ -45,6 +44,7 @@ class Mplex(IMuxedConn, Service): event_shutting_down: trio.Event event_closed: trio.Event + event_started: trio.Event def __init__(self, secured_conn: ISecureConn, peer_id: ID) -> None: """ @@ -73,10 +73,10 @@ class Mplex(IMuxedConn, Service): self.new_stream_send_channel, self.new_stream_receive_channel = channels self.event_shutting_down = trio.Event() self.event_closed = trio.Event() + self.event_started = trio.Event() - async def run(self) -> None: - self.manager.run_task(self.handle_incoming) - await self.manager.wait_finished() + async def start(self) -> None: + await self.handle_incoming() @property def is_initiator(self) -> bool: @@ -91,7 +91,6 @@ class Mplex(IMuxedConn, Service): await self.secured_conn.close() # Blocked until `close` is finally set. await self.event_closed.wait() - await self.manager.stop() @property def is_closed(self) -> bool: @@ -178,8 +177,8 @@ class Mplex(IMuxedConn, Service): async def handle_incoming(self) -> None: """Read a message off of the secured connection and add it to the corresponding message buffer.""" - - while self.manager.is_running: + self.event_started.set() + while True: try: await self._handle_incoming_message() except MplexUnavailable as e: diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index fc3d274..ae6f7ea 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -188,9 +188,7 @@ class MplexStream(IMuxedStream): if self.is_initiator else HeaderTags.ResetReceiver ) - self.muxed_conn.manager.run_task( - self.muxed_conn.send_message, flag, None, self.stream_id - ) + await self.muxed_conn.send_message(flag, None, self.stream_id) self.event_local_closed.set() self.event_remote_closed.set() diff --git a/tests/network/test_swarm_conn.py b/tests/network/test_swarm_conn.py index 1bfd7d8..dc692f4 100644 --- a/tests/network/test_swarm_conn.py +++ b/tests/network/test_swarm_conn.py @@ -14,7 +14,6 @@ async def test_swarm_conn_close(swarm_conn_pair): await trio.sleep(0.1) await wait_all_tasks_blocked() - await conn_0.manager.wait_finished() assert conn_0.is_closed assert conn_1.is_closed @@ -26,22 +25,22 @@ async def test_swarm_conn_close(swarm_conn_pair): async def test_swarm_conn_streams(swarm_conn_pair): conn_0, conn_1 = swarm_conn_pair - assert len(await conn_0.get_streams()) == 0 - assert len(await conn_1.get_streams()) == 0 + assert len(conn_0.get_streams()) == 0 + assert len(conn_1.get_streams()) == 0 stream_0_0 = await conn_0.new_stream() await trio.sleep(0.01) - assert len(await conn_0.get_streams()) == 1 - assert len(await conn_1.get_streams()) == 1 + assert len(conn_0.get_streams()) == 1 + assert len(conn_1.get_streams()) == 1 stream_0_1 = await conn_0.new_stream() await trio.sleep(0.01) - assert len(await conn_0.get_streams()) == 2 - assert len(await conn_1.get_streams()) == 2 + assert len(conn_0.get_streams()) == 2 + assert len(conn_1.get_streams()) == 2 conn_0.remove_stream(stream_0_0) - assert len(await conn_0.get_streams()) == 1 + assert len(conn_0.get_streams()) == 1 conn_0.remove_stream(stream_0_1) - assert len(await conn_0.get_streams()) == 0 + assert len(conn_0.get_streams()) == 0 # Nothing happen if `stream_0_1` is not present or already removed. conn_0.remove_stream(stream_0_1)