diff --git a/libp2p/network/connection/swarm_connection.py b/libp2p/network/connection/swarm_connection.py index 29d544e..b91783d 100644 --- a/libp2p/network/connection/swarm_connection.py +++ b/libp2p/network/connection/swarm_connection.py @@ -1,6 +1,8 @@ -import asyncio from typing import TYPE_CHECKING, Any, Awaitable, List, Set, Tuple +import trio +from async_service import Service + from libp2p.network.connection.net_connection_interface import INetConn from libp2p.network.stream.net_stream import NetStream from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream @@ -15,21 +17,17 @@ Reference: https://github.com/libp2p/go-libp2p-swarm/blob/04c86bbdafd390651cb2ee """ -class SwarmConn(INetConn): +class SwarmConn(INetConn, Service): muxed_conn: IMuxedConn swarm: "Swarm" streams: Set[NetStream] - event_closed: asyncio.Event - - _tasks: List["asyncio.Future[Any]"] + event_closed: trio.Event def __init__(self, muxed_conn: IMuxedConn, swarm: "Swarm") -> None: self.muxed_conn = muxed_conn self.swarm = swarm self.streams = set() - self.event_closed = asyncio.Event() - - self._tasks = [] + self.event_closed = trio.Event() async def close(self) -> None: if self.event_closed.is_set(): @@ -45,16 +43,11 @@ class SwarmConn(INetConn): await stream.reset() # Force context switch for stream handlers to process the stream reset event we just emit # before we cancel the stream handler tasks. - await asyncio.sleep(0.1) + await trio.sleep(0.1) - for task in self._tasks: - task.cancel() - try: - await task - except asyncio.CancelledError: - pass + # FIXME: Now let `_notify_disconnected` finish first. # Schedule `self._notify_disconnected` to make it execute after `close` is finished. - self._notify_disconnected() + await self._notify_disconnected() async def _handle_new_streams(self) -> None: while True: @@ -65,7 +58,7 @@ class SwarmConn(INetConn): # we should break the loop and close the connection. break # Asynchronously handle the accepted stream, to avoid blocking the next stream. - await self.run_task(self._handle_muxed_stream(stream)) + self.manager.run_task(self._handle_muxed_stream, stream) await self.close() @@ -79,28 +72,26 @@ class SwarmConn(INetConn): self.remove_stream(net_stream) async def _handle_muxed_stream(self, muxed_stream: IMuxedStream) -> None: - net_stream = self._add_stream(muxed_stream) + net_stream = await self._add_stream(muxed_stream) if self.swarm.common_stream_handler is not None: - await self.run_task(self._call_stream_handler(net_stream)) + await self._call_stream_handler(net_stream) - def _add_stream(self, muxed_stream: IMuxedStream) -> NetStream: + async def _add_stream(self, muxed_stream: IMuxedStream) -> NetStream: net_stream = NetStream(muxed_stream) self.streams.add(net_stream) - self.swarm.notify_opened_stream(net_stream) + await self.swarm.notify_opened_stream(net_stream) return net_stream - def _notify_disconnected(self) -> None: - self.swarm.notify_disconnected(self) + async def _notify_disconnected(self) -> None: + await self.swarm.notify_disconnected(self) - async def start(self) -> None: - await self.run_task(self._handle_new_streams()) - - async def run_task(self, coro: Awaitable[Any]) -> None: - self._tasks.append(asyncio.ensure_future(coro)) + async def run(self) -> None: + self.manager.run_task(self._handle_new_streams) + await self.manager.wait_finished() async def new_stream(self) -> NetStream: muxed_stream = await self.muxed_conn.open_stream() - return self._add_stream(muxed_stream) + return await self._add_stream(muxed_stream) async def get_streams(self) -> Tuple[NetStream, ...]: return tuple(self.streams) diff --git a/libp2p/network/swarm.py b/libp2p/network/swarm.py index e54ad6f..337da89 100644 --- a/libp2p/network/swarm.py +++ b/libp2p/network/swarm.py @@ -1,7 +1,8 @@ -import asyncio import logging from typing import Dict, List, Optional +from async_service import Service + from multiaddr import Multiaddr import trio @@ -31,7 +32,7 @@ from .stream.net_stream_interface import INetStream logger = logging.getLogger("libp2p.network.swarm") -class Swarm(INetwork): +class Swarm(INetwork, Service): self_id: ID peerstore: IPeerStore @@ -64,13 +65,16 @@ class Swarm(INetwork): self.common_stream_handler = None + async def run(self) -> None: + await self.manager.wait_finished() + def get_peer_id(self) -> ID: return self.self_id def set_stream_handler(self, stream_handler: StreamHandlerFn) -> None: self.common_stream_handler = stream_handler - async def dial_peer(self, peer_id: ID, nursery) -> INetConn: + async def dial_peer(self, peer_id: ID) -> INetConn: """ dial_peer try to create a connection to peer_id. @@ -122,7 +126,7 @@ class Swarm(INetwork): try: muxed_conn = await self.upgrader.upgrade_connection(secured_conn, peer_id) - muxed_conn.run(nursery) + self.manager.run_child_service(muxed_conn) except MuxerUpgradeFailure as error: error_msg = "fail to upgrade mux for peer %s" logger.debug(error_msg, peer_id) @@ -137,7 +141,7 @@ class Swarm(INetwork): return swarm_conn - async def new_stream(self, peer_id: ID, nursery) -> INetStream: + async def new_stream(self, peer_id: ID) -> INetStream: """ :param peer_id: peer_id of destination :param protocol_id: protocol id @@ -146,13 +150,13 @@ class Swarm(INetwork): """ logger.debug("attempting to open a stream to peer %s", peer_id) - swarm_conn = await self.dial_peer(peer_id, nursery) + swarm_conn = await self.dial_peer(peer_id) net_stream = await swarm_conn.new_stream() logger.debug("successfully opened a stream to peer %s", peer_id) return net_stream - async def listen(self, *multiaddrs: Multiaddr, nursery) -> bool: + async def listen(self, *multiaddrs: Multiaddr) -> bool: """ :param multiaddrs: one or many multiaddrs to start listening on :return: true if at least one success @@ -189,7 +193,7 @@ class Swarm(INetwork): muxed_conn = await self.upgrader.upgrade_connection( secured_conn, peer_id ) - muxed_conn.run(nursery) + self.manager.run_child_service(muxed_conn) except MuxerUpgradeFailure as error: error_msg = "fail to upgrade mux for peer %s" logger.debug(error_msg, peer_id) @@ -200,6 +204,8 @@ class Swarm(INetwork): await self.add_conn(muxed_conn) logger.debug("successfully opened connection to peer %s", peer_id) + # FIXME: This is a intentional barrier to prevent from the handler exiting and + # closing the connection. event = trio.Event() await event.wait() @@ -207,10 +213,11 @@ class Swarm(INetwork): # Success listener = self.transport.create_listener(conn_handler) self.listeners[str(maddr)] = listener - await listener.listen(maddr, nursery) + # FIXME: Hack + await listener.listen(maddr, self.manager._task_nursery) # Call notifiers since event occurred - self.notify_listen(maddr) + await self.notify_listen(maddr) return True except IOError: @@ -225,15 +232,16 @@ class Swarm(INetwork): # Reference: https://github.com/libp2p/go-libp2p-swarm/blob/8be680aef8dea0a4497283f2f98470c2aeae6b65/swarm.go#L124-L134 # noqa: E501 # Close listeners - await asyncio.gather( - *[listener.close() for listener in self.listeners.values()] - ) - - # Close connections - await asyncio.gather( - *[connection.close() for connection in self.connections.values()] - ) + # await asyncio.gather( + # *[listener.close() for listener in self.listeners.values()] + # ) + # # Close connections + # await asyncio.gather( + # *[connection.close() for connection in self.connections.values()] + # ) + self.manager.stop() + await self.manager.wait_finished() logger.debug("swarm successfully closed") async def close_peer(self, peer_id: ID) -> None: @@ -253,11 +261,12 @@ class Swarm(INetwork): and start to monitor the connection for its new streams and disconnection.""" swarm_conn = SwarmConn(muxed_conn, self) + manager = self.manager.run_child_service(swarm_conn) # Store muxed_conn with peer id self.connections[muxed_conn.peer_id] = swarm_conn # Call notifiers since event occurred - self.notify_connected(swarm_conn) - await swarm_conn.start() + self.manager.run_task(self.notify_connected, swarm_conn) + await manager.wait_started() return swarm_conn def remove_conn(self, swarm_conn: SwarmConn) -> None: @@ -281,20 +290,26 @@ class Swarm(INetwork): """ self.notifees.append(notifee) - def notify_opened_stream(self, stream: INetStream) -> None: - asyncio.gather( - *[notifee.opened_stream(self, stream) for notifee in self.notifees] - ) + async def notify_opened_stream(self, stream: INetStream) -> None: + async with trio.open_nursery() as nursery: + for notifee in self.notifees: + nursery.start_soon(notifee.opened_stream, self, stream) # TODO: `notify_closed_stream` - def notify_connected(self, conn: INetConn) -> None: - asyncio.gather(*[notifee.connected(self, conn) for notifee in self.notifees]) + async def notify_connected(self, conn: INetConn) -> None: + async with trio.open_nursery() as nursery: + for notifee in self.notifees: + nursery.start_soon(notifee.connected, self, conn) - def notify_disconnected(self, conn: INetConn) -> None: - asyncio.gather(*[notifee.disconnected(self, conn) for notifee in self.notifees]) + async def notify_disconnected(self, conn: INetConn) -> None: + async with trio.open_nursery() as nursery: + for notifee in self.notifees: + nursery.start_soon(notifee.disconnected, self, conn) - def notify_listen(self, multiaddr: Multiaddr) -> None: - asyncio.gather(*[notifee.listen(self, multiaddr) for notifee in self.notifees]) + async def notify_listen(self, multiaddr: Multiaddr) -> None: + async with trio.open_nursery() as nursery: + for notifee in self.notifees: + nursery.start_soon(notifee.listen, self, multiaddr) # TODO: `notify_listen_close` diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index b6df526..9ac5614 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -4,6 +4,7 @@ from typing import Any # noqa: F401 from typing import Awaitable, Dict, List, Optional, Tuple import trio +from async_service import Service from libp2p.exceptions import ParseError from libp2p.io.exceptions import IncompleteReadError @@ -17,6 +18,7 @@ from libp2p.utils import ( encode_uvarint, encode_varint_prefixed, read_varint_prefixed_bytes, + TrioQueue, ) from .constants import HeaderTags @@ -29,7 +31,7 @@ MPLEX_PROTOCOL_ID = TProtocol("/mplex/6.7.0") logger = logging.getLogger("libp2p.stream_muxer.mplex.mplex") -class Mplex(IMuxedConn): +class Mplex(IMuxedConn, Service): """ reference: https://github.com/libp2p/go-mplex/blob/master/multiplex.go """ @@ -38,10 +40,10 @@ class Mplex(IMuxedConn): peer_id: ID next_channel_id: int streams: Dict[StreamID, MplexStream] - streams_lock: asyncio.Lock - new_stream_queue: "asyncio.Queue[IMuxedStream]" - event_shutting_down: asyncio.Event - event_closed: asyncio.Event + streams_lock: trio.Lock + new_stream_queue: "TrioQueue[IMuxedStream]" + event_shutting_down: trio.Event + event_closed: trio.Event def __init__(self, secured_conn: ISecureConn, peer_id: ID) -> None: """ @@ -61,13 +63,14 @@ class Mplex(IMuxedConn): # Mapping from stream ID -> buffer of messages for that stream self.streams = {} - self.streams_lock = asyncio.Lock() - self.new_stream_queue = asyncio.Queue() - self.event_shutting_down = asyncio.Event() - self.event_closed = asyncio.Event() + self.streams_lock = trio.Lock() + self.new_stream_queue = TrioQueue() + self.event_shutting_down = trio.Event() + self.event_closed = trio.Event() - def run(self, nursery): - nursery.start_soon(self.handle_incoming) + async def run(self): + self.manager.run_task(self.handle_incoming) + await self.manager.wait_finished() @property def is_initiator(self) -> bool: diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index 44b3aef..58da404 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -136,7 +136,6 @@ class MplexStream(IMuxedStream): nursery.start_soon( self.muxed_conn.send_message, flag, None, self.stream_id ) - await trio.sleep(0) self.event_local_closed.set() self.event_remote_closed.set() diff --git a/libp2p/tools/factories.py b/libp2p/tools/factories.py index 2b63544..f324371 100644 --- a/libp2p/tools/factories.py +++ b/libp2p/tools/factories.py @@ -1,8 +1,9 @@ -import asyncio -from contextlib import asynccontextmanager +import trio +from contextlib import asynccontextmanager, AsyncExitStack from typing import Any, AsyncIterator, Dict, Tuple, cast import factory +from async_service import background_trio_service from libp2p import generate_new_rsa_identity, generate_peer_id_from from libp2p.crypto.keys import KeyPair @@ -61,6 +62,7 @@ class SwarmFactory(factory.Factory): transport = factory.LazyFunction(TCP) @classmethod + @asynccontextmanager async def create_and_listen( cls, is_secure: bool, key_pair: KeyPair = None, muxer_opt: TMuxerOptions = None ) -> Swarm: @@ -73,20 +75,23 @@ class SwarmFactory(factory.Factory): if muxer_opt is not None: optional_kwargs["muxer_opt"] = muxer_opt swarm = cls(is_secure=is_secure, **optional_kwargs) - await swarm.listen(LISTEN_MADDR) - return swarm + async with background_trio_service(swarm): + await swarm.listen(LISTEN_MADDR) + yield swarm @classmethod + @asynccontextmanager async def create_batch_and_listen( cls, is_secure: bool, number: int, muxer_opt: TMuxerOptions = None ) -> Tuple[Swarm, ...]: - # Ignore typing since we are removing asyncio soon - return await asyncio.gather( # type: ignore - *[ - cls.create_and_listen(is_secure=is_secure, muxer_opt=muxer_opt) + async with AsyncExitStack() as stack: + ctx_mgrs = [ + await stack.enter_async_context( + cls.create_and_listen(is_secure=is_secure, muxer_opt=muxer_opt) + ) for _ in range(number) ] - ) + yield ctx_mgrs class HostFactory(factory.Factory): @@ -103,20 +108,23 @@ class HostFactory(factory.Factory): ) @classmethod + @asynccontextmanager async def create_batch_and_listen( cls, is_secure: bool, number: int ) -> Tuple[BasicHost, ...]: key_pairs = [generate_new_rsa_identity() for _ in range(number)] - swarms = await asyncio.gather( - *[ - SwarmFactory.create_and_listen(is_secure, key_pair) + async with AsyncExitStack() as stack: + swarms = [ + await stack.enter_async_context( + SwarmFactory.create_and_listen(is_secure, key_pair) + ) for key_pair in key_pairs ] - ) - return tuple( - BasicHost(key_pair.public_key, swarm) - for key_pair, swarm in zip(key_pairs, swarms) - ) + hosts = tuple( + BasicHost(key_pair.public_key, swarm) + for key_pair, swarm in zip(key_pairs, swarms) + ) + yield hosts class FloodsubFactory(factory.Factory): @@ -150,73 +158,60 @@ class PubsubFactory(factory.Factory): cache_size = None +@asynccontextmanager async def swarm_pair_factory( is_secure: bool, muxer_opt: TMuxerOptions = None ) -> Tuple[Swarm, Swarm]: - swarms = await SwarmFactory.create_batch_and_listen( + async with SwarmFactory.create_batch_and_listen( is_secure, 2, muxer_opt=muxer_opt - ) - await connect_swarm(swarms[0], swarms[1]) - return swarms[0], swarms[1] - - -async def host_pair_factory(is_secure: bool) -> Tuple[BasicHost, BasicHost]: - hosts = await HostFactory.create_batch_and_listen(is_secure, 2) - await connect(hosts[0], hosts[1]) - return hosts[0], hosts[1] + ) as swarms: + await connect_swarm(swarms[0], swarms[1]) + yield swarms[0], swarms[1] @asynccontextmanager -async def pair_of_connected_hosts( - is_secure: bool = True -) -> AsyncIterator[Tuple[BasicHost, BasicHost]]: - a, b = await host_pair_factory(is_secure) - yield a, b - close_tasks = (a.close(), b.close()) - await asyncio.gather(*close_tasks) +async def host_pair_factory(is_secure: bool) -> Tuple[BasicHost, BasicHost]: + async with HostFactory.create_batch_and_listen(is_secure, 2) as hosts: + await connect(hosts[0], hosts[1]) + yield hosts[0], hosts[1] +@asynccontextmanager async def swarm_conn_pair_factory( is_secure: bool, muxer_opt: TMuxerOptions = None -) -> Tuple[SwarmConn, Swarm, SwarmConn, Swarm]: - swarms = await swarm_pair_factory(is_secure) - conn_0 = swarms[0].connections[swarms[1].get_peer_id()] - conn_1 = swarms[1].connections[swarms[0].get_peer_id()] - return cast(SwarmConn, conn_0), swarms[0], cast(SwarmConn, conn_1), swarms[1] +) -> Tuple[SwarmConn, SwarmConn]: + async with swarm_pair_factory(is_secure) as swarms: + conn_0 = swarms[0].connections[swarms[1].get_peer_id()] + conn_1 = swarms[1].connections[swarms[0].get_peer_id()] + yield cast(SwarmConn, conn_0), cast(SwarmConn, conn_1) -async def mplex_conn_pair_factory(is_secure: bool) -> Tuple[Mplex, Swarm, Mplex, Swarm]: +@asynccontextmanager +async def mplex_conn_pair_factory(is_secure: bool) -> Tuple[Mplex, Mplex]: muxer_opt = {MPLEX_PROTOCOL_ID: Mplex} - conn_0, swarm_0, conn_1, swarm_1 = await swarm_conn_pair_factory( - is_secure, muxer_opt=muxer_opt - ) - return ( - cast(Mplex, conn_0.muxed_conn), - swarm_0, - cast(Mplex, conn_1.muxed_conn), - swarm_1, - ) + async with swarm_conn_pair_factory(is_secure, muxer_opt=muxer_opt) as swarm_pair: + yield ( + cast(Mplex, swarm_pair[0].muxed_conn), + cast(Mplex, swarm_pair[1].muxed_conn), + ) -async def mplex_stream_pair_factory( - is_secure: bool -) -> Tuple[MplexStream, Swarm, MplexStream, Swarm]: - mplex_conn_0, swarm_0, mplex_conn_1, swarm_1 = await mplex_conn_pair_factory( - is_secure - ) - stream_0 = await mplex_conn_0.open_stream() - await asyncio.sleep(0.01) - stream_1: MplexStream - async with mplex_conn_1.streams_lock: - if len(mplex_conn_1.streams) != 1: - raise Exception("Mplex should not have any stream upon connection") - stream_1 = tuple(mplex_conn_1.streams.values())[0] - return cast(MplexStream, stream_0), swarm_0, stream_1, swarm_1 +@asynccontextmanager +async def mplex_stream_pair_factory(is_secure: bool) -> Tuple[MplexStream, MplexStream]: + async with mplex_conn_pair_factory(is_secure) as mplex_conn_pair_info: + mplex_conn_0, mplex_conn_1 = mplex_conn_pair_info + stream_0 = await mplex_conn_0.open_stream() + await trio.sleep(0.01) + stream_1: MplexStream + async with mplex_conn_1.streams_lock: + if len(mplex_conn_1.streams) != 1: + raise Exception("Mplex should not have any stream upon connection") + stream_1 = tuple(mplex_conn_1.streams.values())[0] + yield cast(MplexStream, stream_0), cast(MplexStream, stream_1) -async def net_stream_pair_factory( - is_secure: bool -) -> Tuple[INetStream, BasicHost, INetStream, BasicHost]: +@asynccontextmanager +async def net_stream_pair_factory(is_secure: bool) -> Tuple[INetStream, INetStream]: protocol_id = TProtocol("/example/id/1") stream_1: INetStream @@ -226,8 +221,8 @@ async def net_stream_pair_factory( nonlocal stream_1 stream_1 = stream - host_0, host_1 = await host_pair_factory(is_secure) - host_1.set_stream_handler(protocol_id, handler) + async with host_pair_factory(is_secure) as hosts: + hosts[1].set_stream_handler(protocol_id, handler) - stream_0 = await host_0.new_stream(host_1.get_id(), [protocol_id]) - return stream_0, host_0, stream_1, host_1 + stream_0 = await hosts[0].new_stream(hosts[1].get_id(), [protocol_id]) + yield stream_0, stream_1 diff --git a/libp2p/tools/utils.py b/libp2p/tools/utils.py index d1c266a..0d39156 100644 --- a/libp2p/tools/utils.py +++ b/libp2p/tools/utils.py @@ -17,7 +17,7 @@ from libp2p.typing import StreamHandlerFn, TProtocol from .constants import MAX_READ_LEN -async def connect_swarm(swarm_0: Swarm, swarm_1: Swarm, nursery: trio.Nursery) -> None: +async def connect_swarm(swarm_0: Swarm, swarm_1: Swarm) -> None: peer_id = swarm_1.get_peer_id() addrs = tuple( addr @@ -25,7 +25,7 @@ async def connect_swarm(swarm_0: Swarm, swarm_1: Swarm, nursery: trio.Nursery) - for addr in transport.get_addrs() ) swarm_0.peerstore.add_addrs(peer_id, addrs, 10000) - await swarm_0.dial_peer(peer_id, nursery) + await swarm_0.dial_peer(peer_id) assert swarm_0.get_peer_id() in swarm_1.connections assert swarm_1.get_peer_id() in swarm_0.connections diff --git a/libp2p/transport/tcp/tcp.py b/libp2p/transport/tcp/tcp.py index fc00047..1636598 100644 --- a/libp2p/transport/tcp/tcp.py +++ b/libp2p/transport/tcp/tcp.py @@ -25,7 +25,7 @@ class TCPListener(IListener): self.server = None self.handler = handler_function - async def listen(self, maddr: Multiaddr, nursery) -> bool: + async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> bool: """ put listener in listening mode and wait for incoming connections. diff --git a/tests/host/test_ping.py b/tests/host/test_ping.py index fcc5a85..1bd02f0 100644 --- a/tests/host/test_ping.py +++ b/tests/host/test_ping.py @@ -4,12 +4,12 @@ import secrets import pytest from libp2p.host.ping import ID, PING_LENGTH -from libp2p.tools.factories import pair_of_connected_hosts +from libp2p.tools.factories import host_pair_factory @pytest.mark.asyncio async def test_ping_once(): - async with pair_of_connected_hosts() as (host_a, host_b): + async with host_pair_factory() as (host_a, host_b): stream = await host_b.new_stream(host_a.get_id(), (ID,)) some_ping = secrets.token_bytes(PING_LENGTH) await stream.write(some_ping) @@ -23,7 +23,7 @@ SOME_PING_COUNT = 3 @pytest.mark.asyncio async def test_ping_several(): - async with pair_of_connected_hosts() as (host_a, host_b): + async with host_pair_factory() as (host_a, host_b): stream = await host_b.new_stream(host_a.get_id(), (ID,)) for _ in range(SOME_PING_COUNT): some_ping = secrets.token_bytes(PING_LENGTH) diff --git a/tests/identity/identify/test_protocol.py b/tests/identity/identify/test_protocol.py index fab78ec..6136c87 100644 --- a/tests/identity/identify/test_protocol.py +++ b/tests/identity/identify/test_protocol.py @@ -2,12 +2,12 @@ import pytest from libp2p.identity.identify.pb.identify_pb2 import Identify from libp2p.identity.identify.protocol import ID, _mk_identify_protobuf -from libp2p.tools.factories import pair_of_connected_hosts +from libp2p.tools.factories import host_pair_factory @pytest.mark.asyncio async def test_identify_protocol(): - async with pair_of_connected_hosts() as (host_a, host_b): + async with host_pair_factory() as (host_a, host_b): stream = await host_b.new_stream(host_a.get_id(), (ID,)) response = await stream.read() await stream.close() diff --git a/tests/network/conftest.py b/tests/network/conftest.py index 6b75b75..c45dbdb 100644 --- a/tests/network/conftest.py +++ b/tests/network/conftest.py @@ -11,26 +11,17 @@ from libp2p.tools.factories import ( @pytest.fixture async def net_stream_pair(is_host_secure): - stream_0, host_0, stream_1, host_1 = await net_stream_pair_factory(is_host_secure) - try: - yield stream_0, stream_1 - finally: - await asyncio.gather(*[host_0.close(), host_1.close()]) + async with net_stream_pair_factory(is_host_secure) as net_stream_pair: + yield net_stream_pair @pytest.fixture async def swarm_pair(is_host_secure): - swarm_0, swarm_1 = await swarm_pair_factory(is_host_secure) - try: - yield swarm_0, swarm_1 - finally: - await asyncio.gather(*[swarm_0.close(), swarm_1.close()]) + async with swarm_pair_factory(is_host_secure) as swarms: + yield swarms @pytest.fixture async def swarm_conn_pair(is_host_secure): - conn_0, swarm_0, conn_1, swarm_1 = await swarm_conn_pair_factory(is_host_secure) - try: - yield conn_0, conn_1 - finally: - await asyncio.gather(*[swarm_0.close(), swarm_1.close()]) + async with swarm_conn_pair_factory(is_host_secure) as swarm_conn_pair: + yield swarm_conn_pair diff --git a/tests/network/test_swarm.py b/tests/network/test_swarm.py index 6fe2543..de08635 100644 --- a/tests/network/test_swarm.py +++ b/tests/network/test_swarm.py @@ -1,88 +1,83 @@ -import asyncio - +import trio import pytest +from trio.testing import wait_all_tasks_blocked from libp2p.network.exceptions import SwarmException from libp2p.tools.factories import SwarmFactory from libp2p.tools.utils import connect_swarm -@pytest.mark.asyncio +@pytest.mark.trio async def test_swarm_dial_peer(is_host_secure): - swarms = await SwarmFactory.create_batch_and_listen(is_host_secure, 3) - # Test: No addr found. - with pytest.raises(SwarmException): + async with SwarmFactory.create_batch_and_listen(is_host_secure, 3) as swarms: + # Test: No addr found. + with pytest.raises(SwarmException): + await swarms[0].dial_peer(swarms[1].get_peer_id()) + + # Test: len(addr) in the peerstore is 0. + swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), [], 10000) + with pytest.raises(SwarmException): + await swarms[0].dial_peer(swarms[1].get_peer_id()) + + # Test: Succeed if addrs of the peer_id are present in the peerstore. + addrs = tuple( + addr + for transport in swarms[1].listeners.values() + for addr in transport.get_addrs() + ) + swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs, 10000) await swarms[0].dial_peer(swarms[1].get_peer_id()) + assert swarms[0].get_peer_id() in swarms[1].connections + assert swarms[1].get_peer_id() in swarms[0].connections - # Test: len(addr) in the peerstore is 0. - swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), [], 10000) - with pytest.raises(SwarmException): - await swarms[0].dial_peer(swarms[1].get_peer_id()) - - # Test: Succeed if addrs of the peer_id are present in the peerstore. - addrs = tuple( - addr - for transport in swarms[1].listeners.values() - for addr in transport.get_addrs() - ) - swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs, 10000) - await swarms[0].dial_peer(swarms[1].get_peer_id()) - assert swarms[0].get_peer_id() in swarms[1].connections - assert swarms[1].get_peer_id() in swarms[0].connections - - # Test: Reuse connections when we already have ones with a peer. - conn_to_1 = swarms[0].connections[swarms[1].get_peer_id()] - conn = await swarms[0].dial_peer(swarms[1].get_peer_id()) - assert conn is conn_to_1 - - # Clean up - await asyncio.gather(*[swarm.close() for swarm in swarms]) + # Test: Reuse connections when we already have ones with a peer. + conn_to_1 = swarms[0].connections[swarms[1].get_peer_id()] + conn = await swarms[0].dial_peer(swarms[1].get_peer_id()) + assert conn is conn_to_1 -@pytest.mark.asyncio +@pytest.mark.trio async def test_swarm_close_peer(is_host_secure): - swarms = await SwarmFactory.create_batch_and_listen(is_host_secure, 3) - # 0 <> 1 <> 2 - await connect_swarm(swarms[0], swarms[1]) - await connect_swarm(swarms[1], swarms[2]) + async with SwarmFactory.create_batch_and_listen(is_host_secure, 3) as swarms: + # 0 <> 1 <> 2 + await connect_swarm(swarms[0], swarms[1]) + await connect_swarm(swarms[1], swarms[2]) - # peer 1 closes peer 0 - await swarms[1].close_peer(swarms[0].get_peer_id()) - await asyncio.sleep(0.01) - # 0 1 <> 2 - assert len(swarms[0].connections) == 0 - assert ( - len(swarms[1].connections) == 1 - and swarms[2].get_peer_id() in swarms[1].connections - ) + # peer 1 closes peer 0 + await swarms[1].close_peer(swarms[0].get_peer_id()) + await trio.sleep(0.01) + await wait_all_tasks_blocked() + # 0 1 <> 2 + assert len(swarms[0].connections) == 0 + assert ( + len(swarms[1].connections) == 1 + and swarms[2].get_peer_id() in swarms[1].connections + ) - # peer 1 is closed by peer 2 - await swarms[2].close_peer(swarms[1].get_peer_id()) - await asyncio.sleep(0.01) - # 0 1 2 - assert len(swarms[1].connections) == 0 and len(swarms[2].connections) == 0 + # peer 1 is closed by peer 2 + await swarms[2].close_peer(swarms[1].get_peer_id()) + await trio.sleep(0.01) + # 0 1 2 + assert len(swarms[1].connections) == 0 and len(swarms[2].connections) == 0 - await connect_swarm(swarms[0], swarms[1]) - # 0 <> 1 2 - assert ( - len(swarms[0].connections) == 1 - and swarms[1].get_peer_id() in swarms[0].connections - ) - assert ( - len(swarms[1].connections) == 1 - and swarms[0].get_peer_id() in swarms[1].connections - ) - # peer 0 closes peer 1 - await swarms[0].close_peer(swarms[1].get_peer_id()) - await asyncio.sleep(0.01) - # 0 1 2 - assert len(swarms[1].connections) == 0 and len(swarms[2].connections) == 0 - - # Clean up - await asyncio.gather(*[swarm.close() for swarm in swarms]) + await connect_swarm(swarms[0], swarms[1]) + # 0 <> 1 2 + assert ( + len(swarms[0].connections) == 1 + and swarms[1].get_peer_id() in swarms[0].connections + ) + assert ( + len(swarms[1].connections) == 1 + and swarms[0].get_peer_id() in swarms[1].connections + ) + # peer 0 closes peer 1 + await swarms[0].close_peer(swarms[1].get_peer_id()) + await trio.sleep(0.01) + # 0 1 2 + assert len(swarms[1].connections) == 0 and len(swarms[2].connections) == 0 -@pytest.mark.asyncio +@pytest.mark.trio async def test_swarm_remove_conn(swarm_pair): swarm_0, swarm_1 = swarm_pair conn_0 = swarm_0.connections[swarm_1.get_peer_id()] diff --git a/tests/stream_muxer/conftest.py b/tests/stream_muxer/conftest.py index cdb57e8..5c5bc2b 100644 --- a/tests/stream_muxer/conftest.py +++ b/tests/stream_muxer/conftest.py @@ -7,23 +7,13 @@ from libp2p.tools.factories import mplex_conn_pair_factory, mplex_stream_pair_fa @pytest.fixture async def mplex_conn_pair(is_host_secure): - mplex_conn_0, swarm_0, mplex_conn_1, swarm_1 = await mplex_conn_pair_factory( - is_host_secure - ) - assert mplex_conn_0.is_initiator - assert not mplex_conn_1.is_initiator - try: - yield mplex_conn_0, mplex_conn_1 - finally: - await asyncio.gather(*[swarm_0.close(), swarm_1.close()]) + async with mplex_conn_pair_factory(is_host_secure) as mplex_conn_pair: + assert mplex_conn_pair[0].is_initiator + assert not mplex_conn_pair[1].is_initiator + yield mplex_conn_pair[0], mplex_conn_pair[1] @pytest.fixture async def mplex_stream_pair(is_host_secure): - mplex_stream_0, swarm_0, mplex_stream_1, swarm_1 = await mplex_stream_pair_factory( - is_host_secure - ) - try: - yield mplex_stream_0, mplex_stream_1 - finally: - await asyncio.gather(*[swarm_0.close(), swarm_1.close()]) + async with mplex_stream_pair_factory(is_host_secure) as mplex_stream_pair: + yield mplex_stream_pair diff --git a/tests/stream_muxer/test_mplex_stream.py b/tests/stream_muxer/test_mplex_stream.py index 27c7f45..dae6657 100644 --- a/tests/stream_muxer/test_mplex_stream.py +++ b/tests/stream_muxer/test_mplex_stream.py @@ -1,5 +1,3 @@ -import asyncio - import pytest import trio @@ -12,25 +10,26 @@ from libp2p.tools.constants import MAX_READ_LEN, LISTEN_MADDR from libp2p.tools.factories import SwarmFactory from libp2p.tools.utils import connect_swarm + DATA = b"data_123" @pytest.mark.trio -async def test_mplex_stream_read_write(nursery): - swarm0, swarm1 = SwarmFactory(), SwarmFactory() - await swarm0.listen(LISTEN_MADDR, nursery=nursery) - await swarm1.listen(LISTEN_MADDR, nursery=nursery) - await connect_swarm(swarm0, swarm1, nursery) - conn_0 = swarm0.connections[swarm1.get_peer_id()] - conn_1 = swarm1.connections[swarm0.get_peer_id()] - stream_0 = await conn_0.muxed_conn.open_stream() - await trio.sleep(1) - stream_1 = tuple(conn_1.muxed_conn.streams.values())[0] - await stream_0.write(DATA) - assert (await stream_1.read(MAX_READ_LEN)) == DATA +async def test_mplex_stream_read_write(): + async with SwarmFactory.create_batch_and_listen(False, 2) as swarms: + await swarms[0].listen(LISTEN_MADDR) + await swarms[1].listen(LISTEN_MADDR) + await connect_swarm(swarms[0], swarms[1]) + conn_0 = swarms[0].connections[swarms[1].get_peer_id()] + conn_1 = swarms[1].connections[swarms[0].get_peer_id()] + stream_0 = await conn_0.muxed_conn.open_stream() + await trio.sleep(1) + stream_1 = tuple(conn_1.muxed_conn.streams.values())[0] + await stream_0.write(DATA) + assert (await stream_1.read(MAX_READ_LEN)) == DATA -@pytest.mark.asyncio +@pytest.mark.trio async def test_mplex_stream_pair_read_until_eof(mplex_stream_pair): read_bytes = bytearray() stream_0, stream_1 = mplex_stream_pair @@ -38,43 +37,43 @@ async def test_mplex_stream_pair_read_until_eof(mplex_stream_pair): async def read_until_eof(): read_bytes.extend(await stream_1.read()) - task = asyncio.ensure_future(read_until_eof()) + task = trio.ensure_future(read_until_eof()) expected_data = bytearray() # Test: `read` doesn't return before `close` is called. await stream_0.write(DATA) expected_data.extend(DATA) - await asyncio.sleep(0.01) + await trio.sleep(0.01) assert len(read_bytes) == 0 # Test: `read` doesn't return before `close` is called. await stream_0.write(DATA) expected_data.extend(DATA) - await asyncio.sleep(0.01) + await trio.sleep(0.01) assert len(read_bytes) == 0 # Test: Close the stream, `read` returns, and receive previous sent data. await stream_0.close() - await asyncio.sleep(0.01) + await trio.sleep(0.01) assert read_bytes == expected_data task.cancel() -@pytest.mark.asyncio +@pytest.mark.trio async def test_mplex_stream_read_after_remote_closed(mplex_stream_pair): stream_0, stream_1 = mplex_stream_pair assert not stream_1.event_remote_closed.is_set() await stream_0.write(DATA) await stream_0.close() - await asyncio.sleep(0.01) + await trio.sleep(0.01) assert stream_1.event_remote_closed.is_set() assert (await stream_1.read(MAX_READ_LEN)) == DATA with pytest.raises(MplexStreamEOF): await stream_1.read(MAX_READ_LEN) -@pytest.mark.asyncio +@pytest.mark.trio async def test_mplex_stream_read_after_local_reset(mplex_stream_pair): stream_0, stream_1 = mplex_stream_pair await stream_0.reset() @@ -82,29 +81,29 @@ async def test_mplex_stream_read_after_local_reset(mplex_stream_pair): await stream_0.read(MAX_READ_LEN) -@pytest.mark.asyncio +@pytest.mark.trio async def test_mplex_stream_read_after_remote_reset(mplex_stream_pair): stream_0, stream_1 = mplex_stream_pair await stream_0.write(DATA) await stream_0.reset() # Sleep to let `stream_1` receive the message. - await asyncio.sleep(0.01) + await trio.sleep(0.01) with pytest.raises(MplexStreamReset): await stream_1.read(MAX_READ_LEN) -@pytest.mark.asyncio +@pytest.mark.trio async def test_mplex_stream_read_after_remote_closed_and_reset(mplex_stream_pair): stream_0, stream_1 = mplex_stream_pair await stream_0.write(DATA) await stream_0.close() await stream_0.reset() # Sleep to let `stream_1` receive the message. - await asyncio.sleep(0.01) + await trio.sleep(0.01) assert (await stream_1.read(MAX_READ_LEN)) == DATA -@pytest.mark.asyncio +@pytest.mark.trio async def test_mplex_stream_write_after_local_closed(mplex_stream_pair): stream_0, stream_1 = mplex_stream_pair await stream_0.write(DATA) @@ -113,7 +112,7 @@ async def test_mplex_stream_write_after_local_closed(mplex_stream_pair): await stream_0.write(DATA) -@pytest.mark.asyncio +@pytest.mark.trio async def test_mplex_stream_write_after_local_reset(mplex_stream_pair): stream_0, stream_1 = mplex_stream_pair await stream_0.reset() @@ -121,16 +120,16 @@ async def test_mplex_stream_write_after_local_reset(mplex_stream_pair): await stream_0.write(DATA) -@pytest.mark.asyncio +@pytest.mark.trio async def test_mplex_stream_write_after_remote_reset(mplex_stream_pair): stream_0, stream_1 = mplex_stream_pair await stream_1.reset() - await asyncio.sleep(0.01) + await trio.sleep(0.01) with pytest.raises(MplexStreamClosed): await stream_0.write(DATA) -@pytest.mark.asyncio +@pytest.mark.trio async def test_mplex_stream_both_close(mplex_stream_pair): stream_0, stream_1 = mplex_stream_pair # Flags are not set initially. @@ -144,7 +143,7 @@ async def test_mplex_stream_both_close(mplex_stream_pair): # Test: Close one side. await stream_0.close() - await asyncio.sleep(0.01) + await trio.sleep(0.01) assert stream_0.event_local_closed.is_set() assert not stream_1.event_local_closed.is_set() @@ -156,7 +155,7 @@ async def test_mplex_stream_both_close(mplex_stream_pair): # Test: Close the other side. await stream_1.close() - await asyncio.sleep(0.01) + await trio.sleep(0.01) # Both sides are closed. assert stream_0.event_local_closed.is_set() assert stream_1.event_local_closed.is_set() @@ -170,11 +169,11 @@ async def test_mplex_stream_both_close(mplex_stream_pair): await stream_0.reset() -@pytest.mark.asyncio +@pytest.mark.trio async def test_mplex_stream_reset(mplex_stream_pair): stream_0, stream_1 = mplex_stream_pair await stream_0.reset() - await asyncio.sleep(0.01) + await trio.sleep(0.01) # Both sides are closed. assert stream_0.event_local_closed.is_set()