Rewrite factories, made some of the test running

This commit is contained in:
mhchia 2019-11-26 19:24:30 +08:00
parent 417b5e7d61
commit ec43c25b45
No known key found for this signature in database
GPG Key ID: 389EFBEA1362589A
13 changed files with 260 additions and 282 deletions

View File

@ -1,6 +1,8 @@
import asyncio
from typing import TYPE_CHECKING, Any, Awaitable, List, Set, Tuple from typing import TYPE_CHECKING, Any, Awaitable, List, Set, Tuple
import trio
from async_service import Service
from libp2p.network.connection.net_connection_interface import INetConn from libp2p.network.connection.net_connection_interface import INetConn
from libp2p.network.stream.net_stream import NetStream from libp2p.network.stream.net_stream import NetStream
from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream
@ -15,21 +17,17 @@ Reference: https://github.com/libp2p/go-libp2p-swarm/blob/04c86bbdafd390651cb2ee
""" """
class SwarmConn(INetConn): class SwarmConn(INetConn, Service):
muxed_conn: IMuxedConn muxed_conn: IMuxedConn
swarm: "Swarm" swarm: "Swarm"
streams: Set[NetStream] streams: Set[NetStream]
event_closed: asyncio.Event event_closed: trio.Event
_tasks: List["asyncio.Future[Any]"]
def __init__(self, muxed_conn: IMuxedConn, swarm: "Swarm") -> None: def __init__(self, muxed_conn: IMuxedConn, swarm: "Swarm") -> None:
self.muxed_conn = muxed_conn self.muxed_conn = muxed_conn
self.swarm = swarm self.swarm = swarm
self.streams = set() self.streams = set()
self.event_closed = asyncio.Event() self.event_closed = trio.Event()
self._tasks = []
async def close(self) -> None: async def close(self) -> None:
if self.event_closed.is_set(): if self.event_closed.is_set():
@ -45,16 +43,11 @@ class SwarmConn(INetConn):
await stream.reset() await stream.reset()
# Force context switch for stream handlers to process the stream reset event we just emit # Force context switch for stream handlers to process the stream reset event we just emit
# before we cancel the stream handler tasks. # before we cancel the stream handler tasks.
await asyncio.sleep(0.1) await trio.sleep(0.1)
for task in self._tasks: # FIXME: Now let `_notify_disconnected` finish first.
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
# Schedule `self._notify_disconnected` to make it execute after `close` is finished. # Schedule `self._notify_disconnected` to make it execute after `close` is finished.
self._notify_disconnected() await self._notify_disconnected()
async def _handle_new_streams(self) -> None: async def _handle_new_streams(self) -> None:
while True: while True:
@ -65,7 +58,7 @@ class SwarmConn(INetConn):
# we should break the loop and close the connection. # we should break the loop and close the connection.
break break
# Asynchronously handle the accepted stream, to avoid blocking the next stream. # Asynchronously handle the accepted stream, to avoid blocking the next stream.
await self.run_task(self._handle_muxed_stream(stream)) self.manager.run_task(self._handle_muxed_stream, stream)
await self.close() await self.close()
@ -79,28 +72,26 @@ class SwarmConn(INetConn):
self.remove_stream(net_stream) self.remove_stream(net_stream)
async def _handle_muxed_stream(self, muxed_stream: IMuxedStream) -> None: async def _handle_muxed_stream(self, muxed_stream: IMuxedStream) -> None:
net_stream = self._add_stream(muxed_stream) net_stream = await self._add_stream(muxed_stream)
if self.swarm.common_stream_handler is not None: if self.swarm.common_stream_handler is not None:
await self.run_task(self._call_stream_handler(net_stream)) await self._call_stream_handler(net_stream)
def _add_stream(self, muxed_stream: IMuxedStream) -> NetStream: async def _add_stream(self, muxed_stream: IMuxedStream) -> NetStream:
net_stream = NetStream(muxed_stream) net_stream = NetStream(muxed_stream)
self.streams.add(net_stream) self.streams.add(net_stream)
self.swarm.notify_opened_stream(net_stream) await self.swarm.notify_opened_stream(net_stream)
return net_stream return net_stream
def _notify_disconnected(self) -> None: async def _notify_disconnected(self) -> None:
self.swarm.notify_disconnected(self) await self.swarm.notify_disconnected(self)
async def start(self) -> None: async def run(self) -> None:
await self.run_task(self._handle_new_streams()) self.manager.run_task(self._handle_new_streams)
await self.manager.wait_finished()
async def run_task(self, coro: Awaitable[Any]) -> None:
self._tasks.append(asyncio.ensure_future(coro))
async def new_stream(self) -> NetStream: async def new_stream(self) -> NetStream:
muxed_stream = await self.muxed_conn.open_stream() muxed_stream = await self.muxed_conn.open_stream()
return self._add_stream(muxed_stream) return await self._add_stream(muxed_stream)
async def get_streams(self) -> Tuple[NetStream, ...]: async def get_streams(self) -> Tuple[NetStream, ...]:
return tuple(self.streams) return tuple(self.streams)

View File

@ -1,7 +1,8 @@
import asyncio
import logging import logging
from typing import Dict, List, Optional from typing import Dict, List, Optional
from async_service import Service
from multiaddr import Multiaddr from multiaddr import Multiaddr
import trio import trio
@ -31,7 +32,7 @@ from .stream.net_stream_interface import INetStream
logger = logging.getLogger("libp2p.network.swarm") logger = logging.getLogger("libp2p.network.swarm")
class Swarm(INetwork): class Swarm(INetwork, Service):
self_id: ID self_id: ID
peerstore: IPeerStore peerstore: IPeerStore
@ -64,13 +65,16 @@ class Swarm(INetwork):
self.common_stream_handler = None self.common_stream_handler = None
async def run(self) -> None:
await self.manager.wait_finished()
def get_peer_id(self) -> ID: def get_peer_id(self) -> ID:
return self.self_id return self.self_id
def set_stream_handler(self, stream_handler: StreamHandlerFn) -> None: def set_stream_handler(self, stream_handler: StreamHandlerFn) -> None:
self.common_stream_handler = stream_handler self.common_stream_handler = stream_handler
async def dial_peer(self, peer_id: ID, nursery) -> INetConn: async def dial_peer(self, peer_id: ID) -> INetConn:
""" """
dial_peer try to create a connection to peer_id. dial_peer try to create a connection to peer_id.
@ -122,7 +126,7 @@ class Swarm(INetwork):
try: try:
muxed_conn = await self.upgrader.upgrade_connection(secured_conn, peer_id) muxed_conn = await self.upgrader.upgrade_connection(secured_conn, peer_id)
muxed_conn.run(nursery) self.manager.run_child_service(muxed_conn)
except MuxerUpgradeFailure as error: except MuxerUpgradeFailure as error:
error_msg = "fail to upgrade mux for peer %s" error_msg = "fail to upgrade mux for peer %s"
logger.debug(error_msg, peer_id) logger.debug(error_msg, peer_id)
@ -137,7 +141,7 @@ class Swarm(INetwork):
return swarm_conn return swarm_conn
async def new_stream(self, peer_id: ID, nursery) -> INetStream: async def new_stream(self, peer_id: ID) -> INetStream:
""" """
:param peer_id: peer_id of destination :param peer_id: peer_id of destination
:param protocol_id: protocol id :param protocol_id: protocol id
@ -146,13 +150,13 @@ class Swarm(INetwork):
""" """
logger.debug("attempting to open a stream to peer %s", peer_id) logger.debug("attempting to open a stream to peer %s", peer_id)
swarm_conn = await self.dial_peer(peer_id, nursery) swarm_conn = await self.dial_peer(peer_id)
net_stream = await swarm_conn.new_stream() net_stream = await swarm_conn.new_stream()
logger.debug("successfully opened a stream to peer %s", peer_id) logger.debug("successfully opened a stream to peer %s", peer_id)
return net_stream return net_stream
async def listen(self, *multiaddrs: Multiaddr, nursery) -> bool: async def listen(self, *multiaddrs: Multiaddr) -> bool:
""" """
:param multiaddrs: one or many multiaddrs to start listening on :param multiaddrs: one or many multiaddrs to start listening on
:return: true if at least one success :return: true if at least one success
@ -189,7 +193,7 @@ class Swarm(INetwork):
muxed_conn = await self.upgrader.upgrade_connection( muxed_conn = await self.upgrader.upgrade_connection(
secured_conn, peer_id secured_conn, peer_id
) )
muxed_conn.run(nursery) self.manager.run_child_service(muxed_conn)
except MuxerUpgradeFailure as error: except MuxerUpgradeFailure as error:
error_msg = "fail to upgrade mux for peer %s" error_msg = "fail to upgrade mux for peer %s"
logger.debug(error_msg, peer_id) logger.debug(error_msg, peer_id)
@ -200,6 +204,8 @@ class Swarm(INetwork):
await self.add_conn(muxed_conn) await self.add_conn(muxed_conn)
logger.debug("successfully opened connection to peer %s", peer_id) logger.debug("successfully opened connection to peer %s", peer_id)
# FIXME: This is a intentional barrier to prevent from the handler exiting and
# closing the connection.
event = trio.Event() event = trio.Event()
await event.wait() await event.wait()
@ -207,10 +213,11 @@ class Swarm(INetwork):
# Success # Success
listener = self.transport.create_listener(conn_handler) listener = self.transport.create_listener(conn_handler)
self.listeners[str(maddr)] = listener self.listeners[str(maddr)] = listener
await listener.listen(maddr, nursery) # FIXME: Hack
await listener.listen(maddr, self.manager._task_nursery)
# Call notifiers since event occurred # Call notifiers since event occurred
self.notify_listen(maddr) await self.notify_listen(maddr)
return True return True
except IOError: except IOError:
@ -225,15 +232,16 @@ class Swarm(INetwork):
# Reference: https://github.com/libp2p/go-libp2p-swarm/blob/8be680aef8dea0a4497283f2f98470c2aeae6b65/swarm.go#L124-L134 # noqa: E501 # Reference: https://github.com/libp2p/go-libp2p-swarm/blob/8be680aef8dea0a4497283f2f98470c2aeae6b65/swarm.go#L124-L134 # noqa: E501
# Close listeners # Close listeners
await asyncio.gather( # await asyncio.gather(
*[listener.close() for listener in self.listeners.values()] # *[listener.close() for listener in self.listeners.values()]
) # )
# Close connections
await asyncio.gather(
*[connection.close() for connection in self.connections.values()]
)
# # Close connections
# await asyncio.gather(
# *[connection.close() for connection in self.connections.values()]
# )
self.manager.stop()
await self.manager.wait_finished()
logger.debug("swarm successfully closed") logger.debug("swarm successfully closed")
async def close_peer(self, peer_id: ID) -> None: async def close_peer(self, peer_id: ID) -> None:
@ -253,11 +261,12 @@ class Swarm(INetwork):
and start to monitor the connection for its new streams and and start to monitor the connection for its new streams and
disconnection.""" disconnection."""
swarm_conn = SwarmConn(muxed_conn, self) swarm_conn = SwarmConn(muxed_conn, self)
manager = self.manager.run_child_service(swarm_conn)
# Store muxed_conn with peer id # Store muxed_conn with peer id
self.connections[muxed_conn.peer_id] = swarm_conn self.connections[muxed_conn.peer_id] = swarm_conn
# Call notifiers since event occurred # Call notifiers since event occurred
self.notify_connected(swarm_conn) self.manager.run_task(self.notify_connected, swarm_conn)
await swarm_conn.start() await manager.wait_started()
return swarm_conn return swarm_conn
def remove_conn(self, swarm_conn: SwarmConn) -> None: def remove_conn(self, swarm_conn: SwarmConn) -> None:
@ -281,20 +290,26 @@ class Swarm(INetwork):
""" """
self.notifees.append(notifee) self.notifees.append(notifee)
def notify_opened_stream(self, stream: INetStream) -> None: async def notify_opened_stream(self, stream: INetStream) -> None:
asyncio.gather( async with trio.open_nursery() as nursery:
*[notifee.opened_stream(self, stream) for notifee in self.notifees] for notifee in self.notifees:
) nursery.start_soon(notifee.opened_stream, self, stream)
# TODO: `notify_closed_stream` # TODO: `notify_closed_stream`
def notify_connected(self, conn: INetConn) -> None: async def notify_connected(self, conn: INetConn) -> None:
asyncio.gather(*[notifee.connected(self, conn) for notifee in self.notifees]) async with trio.open_nursery() as nursery:
for notifee in self.notifees:
nursery.start_soon(notifee.connected, self, conn)
def notify_disconnected(self, conn: INetConn) -> None: async def notify_disconnected(self, conn: INetConn) -> None:
asyncio.gather(*[notifee.disconnected(self, conn) for notifee in self.notifees]) async with trio.open_nursery() as nursery:
for notifee in self.notifees:
nursery.start_soon(notifee.disconnected, self, conn)
def notify_listen(self, multiaddr: Multiaddr) -> None: async def notify_listen(self, multiaddr: Multiaddr) -> None:
asyncio.gather(*[notifee.listen(self, multiaddr) for notifee in self.notifees]) async with trio.open_nursery() as nursery:
for notifee in self.notifees:
nursery.start_soon(notifee.listen, self, multiaddr)
# TODO: `notify_listen_close` # TODO: `notify_listen_close`

View File

@ -4,6 +4,7 @@ from typing import Any # noqa: F401
from typing import Awaitable, Dict, List, Optional, Tuple from typing import Awaitable, Dict, List, Optional, Tuple
import trio import trio
from async_service import Service
from libp2p.exceptions import ParseError from libp2p.exceptions import ParseError
from libp2p.io.exceptions import IncompleteReadError from libp2p.io.exceptions import IncompleteReadError
@ -17,6 +18,7 @@ from libp2p.utils import (
encode_uvarint, encode_uvarint,
encode_varint_prefixed, encode_varint_prefixed,
read_varint_prefixed_bytes, read_varint_prefixed_bytes,
TrioQueue,
) )
from .constants import HeaderTags from .constants import HeaderTags
@ -29,7 +31,7 @@ MPLEX_PROTOCOL_ID = TProtocol("/mplex/6.7.0")
logger = logging.getLogger("libp2p.stream_muxer.mplex.mplex") logger = logging.getLogger("libp2p.stream_muxer.mplex.mplex")
class Mplex(IMuxedConn): class Mplex(IMuxedConn, Service):
""" """
reference: https://github.com/libp2p/go-mplex/blob/master/multiplex.go reference: https://github.com/libp2p/go-mplex/blob/master/multiplex.go
""" """
@ -38,10 +40,10 @@ class Mplex(IMuxedConn):
peer_id: ID peer_id: ID
next_channel_id: int next_channel_id: int
streams: Dict[StreamID, MplexStream] streams: Dict[StreamID, MplexStream]
streams_lock: asyncio.Lock streams_lock: trio.Lock
new_stream_queue: "asyncio.Queue[IMuxedStream]" new_stream_queue: "TrioQueue[IMuxedStream]"
event_shutting_down: asyncio.Event event_shutting_down: trio.Event
event_closed: asyncio.Event event_closed: trio.Event
def __init__(self, secured_conn: ISecureConn, peer_id: ID) -> None: def __init__(self, secured_conn: ISecureConn, peer_id: ID) -> None:
""" """
@ -61,13 +63,14 @@ class Mplex(IMuxedConn):
# Mapping from stream ID -> buffer of messages for that stream # Mapping from stream ID -> buffer of messages for that stream
self.streams = {} self.streams = {}
self.streams_lock = asyncio.Lock() self.streams_lock = trio.Lock()
self.new_stream_queue = asyncio.Queue() self.new_stream_queue = TrioQueue()
self.event_shutting_down = asyncio.Event() self.event_shutting_down = trio.Event()
self.event_closed = asyncio.Event() self.event_closed = trio.Event()
def run(self, nursery): async def run(self):
nursery.start_soon(self.handle_incoming) self.manager.run_task(self.handle_incoming)
await self.manager.wait_finished()
@property @property
def is_initiator(self) -> bool: def is_initiator(self) -> bool:

View File

@ -136,7 +136,6 @@ class MplexStream(IMuxedStream):
nursery.start_soon( nursery.start_soon(
self.muxed_conn.send_message, flag, None, self.stream_id self.muxed_conn.send_message, flag, None, self.stream_id
) )
await trio.sleep(0)
self.event_local_closed.set() self.event_local_closed.set()
self.event_remote_closed.set() self.event_remote_closed.set()

View File

@ -1,8 +1,9 @@
import asyncio import trio
from contextlib import asynccontextmanager from contextlib import asynccontextmanager, AsyncExitStack
from typing import Any, AsyncIterator, Dict, Tuple, cast from typing import Any, AsyncIterator, Dict, Tuple, cast
import factory import factory
from async_service import background_trio_service
from libp2p import generate_new_rsa_identity, generate_peer_id_from from libp2p import generate_new_rsa_identity, generate_peer_id_from
from libp2p.crypto.keys import KeyPair from libp2p.crypto.keys import KeyPair
@ -61,6 +62,7 @@ class SwarmFactory(factory.Factory):
transport = factory.LazyFunction(TCP) transport = factory.LazyFunction(TCP)
@classmethod @classmethod
@asynccontextmanager
async def create_and_listen( async def create_and_listen(
cls, is_secure: bool, key_pair: KeyPair = None, muxer_opt: TMuxerOptions = None cls, is_secure: bool, key_pair: KeyPair = None, muxer_opt: TMuxerOptions = None
) -> Swarm: ) -> Swarm:
@ -73,20 +75,23 @@ class SwarmFactory(factory.Factory):
if muxer_opt is not None: if muxer_opt is not None:
optional_kwargs["muxer_opt"] = muxer_opt optional_kwargs["muxer_opt"] = muxer_opt
swarm = cls(is_secure=is_secure, **optional_kwargs) swarm = cls(is_secure=is_secure, **optional_kwargs)
await swarm.listen(LISTEN_MADDR) async with background_trio_service(swarm):
return swarm await swarm.listen(LISTEN_MADDR)
yield swarm
@classmethod @classmethod
@asynccontextmanager
async def create_batch_and_listen( async def create_batch_and_listen(
cls, is_secure: bool, number: int, muxer_opt: TMuxerOptions = None cls, is_secure: bool, number: int, muxer_opt: TMuxerOptions = None
) -> Tuple[Swarm, ...]: ) -> Tuple[Swarm, ...]:
# Ignore typing since we are removing asyncio soon async with AsyncExitStack() as stack:
return await asyncio.gather( # type: ignore ctx_mgrs = [
*[ await stack.enter_async_context(
cls.create_and_listen(is_secure=is_secure, muxer_opt=muxer_opt) cls.create_and_listen(is_secure=is_secure, muxer_opt=muxer_opt)
)
for _ in range(number) for _ in range(number)
] ]
) yield ctx_mgrs
class HostFactory(factory.Factory): class HostFactory(factory.Factory):
@ -103,20 +108,23 @@ class HostFactory(factory.Factory):
) )
@classmethod @classmethod
@asynccontextmanager
async def create_batch_and_listen( async def create_batch_and_listen(
cls, is_secure: bool, number: int cls, is_secure: bool, number: int
) -> Tuple[BasicHost, ...]: ) -> Tuple[BasicHost, ...]:
key_pairs = [generate_new_rsa_identity() for _ in range(number)] key_pairs = [generate_new_rsa_identity() for _ in range(number)]
swarms = await asyncio.gather( async with AsyncExitStack() as stack:
*[ swarms = [
SwarmFactory.create_and_listen(is_secure, key_pair) await stack.enter_async_context(
SwarmFactory.create_and_listen(is_secure, key_pair)
)
for key_pair in key_pairs for key_pair in key_pairs
] ]
) hosts = tuple(
return tuple( BasicHost(key_pair.public_key, swarm)
BasicHost(key_pair.public_key, swarm) for key_pair, swarm in zip(key_pairs, swarms)
for key_pair, swarm in zip(key_pairs, swarms) )
) yield hosts
class FloodsubFactory(factory.Factory): class FloodsubFactory(factory.Factory):
@ -150,73 +158,60 @@ class PubsubFactory(factory.Factory):
cache_size = None cache_size = None
@asynccontextmanager
async def swarm_pair_factory( async def swarm_pair_factory(
is_secure: bool, muxer_opt: TMuxerOptions = None is_secure: bool, muxer_opt: TMuxerOptions = None
) -> Tuple[Swarm, Swarm]: ) -> Tuple[Swarm, Swarm]:
swarms = await SwarmFactory.create_batch_and_listen( async with SwarmFactory.create_batch_and_listen(
is_secure, 2, muxer_opt=muxer_opt is_secure, 2, muxer_opt=muxer_opt
) ) as swarms:
await connect_swarm(swarms[0], swarms[1]) await connect_swarm(swarms[0], swarms[1])
return swarms[0], swarms[1] yield swarms[0], swarms[1]
async def host_pair_factory(is_secure: bool) -> Tuple[BasicHost, BasicHost]:
hosts = await HostFactory.create_batch_and_listen(is_secure, 2)
await connect(hosts[0], hosts[1])
return hosts[0], hosts[1]
@asynccontextmanager @asynccontextmanager
async def pair_of_connected_hosts( async def host_pair_factory(is_secure: bool) -> Tuple[BasicHost, BasicHost]:
is_secure: bool = True async with HostFactory.create_batch_and_listen(is_secure, 2) as hosts:
) -> AsyncIterator[Tuple[BasicHost, BasicHost]]: await connect(hosts[0], hosts[1])
a, b = await host_pair_factory(is_secure) yield hosts[0], hosts[1]
yield a, b
close_tasks = (a.close(), b.close())
await asyncio.gather(*close_tasks)
@asynccontextmanager
async def swarm_conn_pair_factory( async def swarm_conn_pair_factory(
is_secure: bool, muxer_opt: TMuxerOptions = None is_secure: bool, muxer_opt: TMuxerOptions = None
) -> Tuple[SwarmConn, Swarm, SwarmConn, Swarm]: ) -> Tuple[SwarmConn, SwarmConn]:
swarms = await swarm_pair_factory(is_secure) async with swarm_pair_factory(is_secure) as swarms:
conn_0 = swarms[0].connections[swarms[1].get_peer_id()] conn_0 = swarms[0].connections[swarms[1].get_peer_id()]
conn_1 = swarms[1].connections[swarms[0].get_peer_id()] conn_1 = swarms[1].connections[swarms[0].get_peer_id()]
return cast(SwarmConn, conn_0), swarms[0], cast(SwarmConn, conn_1), swarms[1] yield cast(SwarmConn, conn_0), cast(SwarmConn, conn_1)
async def mplex_conn_pair_factory(is_secure: bool) -> Tuple[Mplex, Swarm, Mplex, Swarm]: @asynccontextmanager
async def mplex_conn_pair_factory(is_secure: bool) -> Tuple[Mplex, Mplex]:
muxer_opt = {MPLEX_PROTOCOL_ID: Mplex} muxer_opt = {MPLEX_PROTOCOL_ID: Mplex}
conn_0, swarm_0, conn_1, swarm_1 = await swarm_conn_pair_factory( async with swarm_conn_pair_factory(is_secure, muxer_opt=muxer_opt) as swarm_pair:
is_secure, muxer_opt=muxer_opt yield (
) cast(Mplex, swarm_pair[0].muxed_conn),
return ( cast(Mplex, swarm_pair[1].muxed_conn),
cast(Mplex, conn_0.muxed_conn), )
swarm_0,
cast(Mplex, conn_1.muxed_conn),
swarm_1,
)
async def mplex_stream_pair_factory( @asynccontextmanager
is_secure: bool async def mplex_stream_pair_factory(is_secure: bool) -> Tuple[MplexStream, MplexStream]:
) -> Tuple[MplexStream, Swarm, MplexStream, Swarm]: async with mplex_conn_pair_factory(is_secure) as mplex_conn_pair_info:
mplex_conn_0, swarm_0, mplex_conn_1, swarm_1 = await mplex_conn_pair_factory( mplex_conn_0, mplex_conn_1 = mplex_conn_pair_info
is_secure stream_0 = await mplex_conn_0.open_stream()
) await trio.sleep(0.01)
stream_0 = await mplex_conn_0.open_stream() stream_1: MplexStream
await asyncio.sleep(0.01) async with mplex_conn_1.streams_lock:
stream_1: MplexStream if len(mplex_conn_1.streams) != 1:
async with mplex_conn_1.streams_lock: raise Exception("Mplex should not have any stream upon connection")
if len(mplex_conn_1.streams) != 1: stream_1 = tuple(mplex_conn_1.streams.values())[0]
raise Exception("Mplex should not have any stream upon connection") yield cast(MplexStream, stream_0), cast(MplexStream, stream_1)
stream_1 = tuple(mplex_conn_1.streams.values())[0]
return cast(MplexStream, stream_0), swarm_0, stream_1, swarm_1
async def net_stream_pair_factory( @asynccontextmanager
is_secure: bool async def net_stream_pair_factory(is_secure: bool) -> Tuple[INetStream, INetStream]:
) -> Tuple[INetStream, BasicHost, INetStream, BasicHost]:
protocol_id = TProtocol("/example/id/1") protocol_id = TProtocol("/example/id/1")
stream_1: INetStream stream_1: INetStream
@ -226,8 +221,8 @@ async def net_stream_pair_factory(
nonlocal stream_1 nonlocal stream_1
stream_1 = stream stream_1 = stream
host_0, host_1 = await host_pair_factory(is_secure) async with host_pair_factory(is_secure) as hosts:
host_1.set_stream_handler(protocol_id, handler) hosts[1].set_stream_handler(protocol_id, handler)
stream_0 = await host_0.new_stream(host_1.get_id(), [protocol_id]) stream_0 = await hosts[0].new_stream(hosts[1].get_id(), [protocol_id])
return stream_0, host_0, stream_1, host_1 yield stream_0, stream_1

View File

@ -17,7 +17,7 @@ from libp2p.typing import StreamHandlerFn, TProtocol
from .constants import MAX_READ_LEN from .constants import MAX_READ_LEN
async def connect_swarm(swarm_0: Swarm, swarm_1: Swarm, nursery: trio.Nursery) -> None: async def connect_swarm(swarm_0: Swarm, swarm_1: Swarm) -> None:
peer_id = swarm_1.get_peer_id() peer_id = swarm_1.get_peer_id()
addrs = tuple( addrs = tuple(
addr addr
@ -25,7 +25,7 @@ async def connect_swarm(swarm_0: Swarm, swarm_1: Swarm, nursery: trio.Nursery) -
for addr in transport.get_addrs() for addr in transport.get_addrs()
) )
swarm_0.peerstore.add_addrs(peer_id, addrs, 10000) swarm_0.peerstore.add_addrs(peer_id, addrs, 10000)
await swarm_0.dial_peer(peer_id, nursery) await swarm_0.dial_peer(peer_id)
assert swarm_0.get_peer_id() in swarm_1.connections assert swarm_0.get_peer_id() in swarm_1.connections
assert swarm_1.get_peer_id() in swarm_0.connections assert swarm_1.get_peer_id() in swarm_0.connections

View File

@ -25,7 +25,7 @@ class TCPListener(IListener):
self.server = None self.server = None
self.handler = handler_function self.handler = handler_function
async def listen(self, maddr: Multiaddr, nursery) -> bool: async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> bool:
""" """
put listener in listening mode and wait for incoming connections. put listener in listening mode and wait for incoming connections.

View File

@ -4,12 +4,12 @@ import secrets
import pytest import pytest
from libp2p.host.ping import ID, PING_LENGTH from libp2p.host.ping import ID, PING_LENGTH
from libp2p.tools.factories import pair_of_connected_hosts from libp2p.tools.factories import host_pair_factory
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_ping_once(): async def test_ping_once():
async with pair_of_connected_hosts() as (host_a, host_b): async with host_pair_factory() as (host_a, host_b):
stream = await host_b.new_stream(host_a.get_id(), (ID,)) stream = await host_b.new_stream(host_a.get_id(), (ID,))
some_ping = secrets.token_bytes(PING_LENGTH) some_ping = secrets.token_bytes(PING_LENGTH)
await stream.write(some_ping) await stream.write(some_ping)
@ -23,7 +23,7 @@ SOME_PING_COUNT = 3
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_ping_several(): async def test_ping_several():
async with pair_of_connected_hosts() as (host_a, host_b): async with host_pair_factory() as (host_a, host_b):
stream = await host_b.new_stream(host_a.get_id(), (ID,)) stream = await host_b.new_stream(host_a.get_id(), (ID,))
for _ in range(SOME_PING_COUNT): for _ in range(SOME_PING_COUNT):
some_ping = secrets.token_bytes(PING_LENGTH) some_ping = secrets.token_bytes(PING_LENGTH)

View File

@ -2,12 +2,12 @@ import pytest
from libp2p.identity.identify.pb.identify_pb2 import Identify from libp2p.identity.identify.pb.identify_pb2 import Identify
from libp2p.identity.identify.protocol import ID, _mk_identify_protobuf from libp2p.identity.identify.protocol import ID, _mk_identify_protobuf
from libp2p.tools.factories import pair_of_connected_hosts from libp2p.tools.factories import host_pair_factory
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_identify_protocol(): async def test_identify_protocol():
async with pair_of_connected_hosts() as (host_a, host_b): async with host_pair_factory() as (host_a, host_b):
stream = await host_b.new_stream(host_a.get_id(), (ID,)) stream = await host_b.new_stream(host_a.get_id(), (ID,))
response = await stream.read() response = await stream.read()
await stream.close() await stream.close()

View File

@ -11,26 +11,17 @@ from libp2p.tools.factories import (
@pytest.fixture @pytest.fixture
async def net_stream_pair(is_host_secure): async def net_stream_pair(is_host_secure):
stream_0, host_0, stream_1, host_1 = await net_stream_pair_factory(is_host_secure) async with net_stream_pair_factory(is_host_secure) as net_stream_pair:
try: yield net_stream_pair
yield stream_0, stream_1
finally:
await asyncio.gather(*[host_0.close(), host_1.close()])
@pytest.fixture @pytest.fixture
async def swarm_pair(is_host_secure): async def swarm_pair(is_host_secure):
swarm_0, swarm_1 = await swarm_pair_factory(is_host_secure) async with swarm_pair_factory(is_host_secure) as swarms:
try: yield swarms
yield swarm_0, swarm_1
finally:
await asyncio.gather(*[swarm_0.close(), swarm_1.close()])
@pytest.fixture @pytest.fixture
async def swarm_conn_pair(is_host_secure): async def swarm_conn_pair(is_host_secure):
conn_0, swarm_0, conn_1, swarm_1 = await swarm_conn_pair_factory(is_host_secure) async with swarm_conn_pair_factory(is_host_secure) as swarm_conn_pair:
try: yield swarm_conn_pair
yield conn_0, conn_1
finally:
await asyncio.gather(*[swarm_0.close(), swarm_1.close()])

View File

@ -1,88 +1,83 @@
import asyncio import trio
import pytest import pytest
from trio.testing import wait_all_tasks_blocked
from libp2p.network.exceptions import SwarmException from libp2p.network.exceptions import SwarmException
from libp2p.tools.factories import SwarmFactory from libp2p.tools.factories import SwarmFactory
from libp2p.tools.utils import connect_swarm from libp2p.tools.utils import connect_swarm
@pytest.mark.asyncio @pytest.mark.trio
async def test_swarm_dial_peer(is_host_secure): async def test_swarm_dial_peer(is_host_secure):
swarms = await SwarmFactory.create_batch_and_listen(is_host_secure, 3) async with SwarmFactory.create_batch_and_listen(is_host_secure, 3) as swarms:
# Test: No addr found. # Test: No addr found.
with pytest.raises(SwarmException): with pytest.raises(SwarmException):
await swarms[0].dial_peer(swarms[1].get_peer_id())
# Test: len(addr) in the peerstore is 0.
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), [], 10000)
with pytest.raises(SwarmException):
await swarms[0].dial_peer(swarms[1].get_peer_id())
# Test: Succeed if addrs of the peer_id are present in the peerstore.
addrs = tuple(
addr
for transport in swarms[1].listeners.values()
for addr in transport.get_addrs()
)
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs, 10000)
await swarms[0].dial_peer(swarms[1].get_peer_id()) await swarms[0].dial_peer(swarms[1].get_peer_id())
assert swarms[0].get_peer_id() in swarms[1].connections
assert swarms[1].get_peer_id() in swarms[0].connections
# Test: len(addr) in the peerstore is 0. # Test: Reuse connections when we already have ones with a peer.
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), [], 10000) conn_to_1 = swarms[0].connections[swarms[1].get_peer_id()]
with pytest.raises(SwarmException): conn = await swarms[0].dial_peer(swarms[1].get_peer_id())
await swarms[0].dial_peer(swarms[1].get_peer_id()) assert conn is conn_to_1
# Test: Succeed if addrs of the peer_id are present in the peerstore.
addrs = tuple(
addr
for transport in swarms[1].listeners.values()
for addr in transport.get_addrs()
)
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs, 10000)
await swarms[0].dial_peer(swarms[1].get_peer_id())
assert swarms[0].get_peer_id() in swarms[1].connections
assert swarms[1].get_peer_id() in swarms[0].connections
# Test: Reuse connections when we already have ones with a peer.
conn_to_1 = swarms[0].connections[swarms[1].get_peer_id()]
conn = await swarms[0].dial_peer(swarms[1].get_peer_id())
assert conn is conn_to_1
# Clean up
await asyncio.gather(*[swarm.close() for swarm in swarms])
@pytest.mark.asyncio @pytest.mark.trio
async def test_swarm_close_peer(is_host_secure): async def test_swarm_close_peer(is_host_secure):
swarms = await SwarmFactory.create_batch_and_listen(is_host_secure, 3) async with SwarmFactory.create_batch_and_listen(is_host_secure, 3) as swarms:
# 0 <> 1 <> 2 # 0 <> 1 <> 2
await connect_swarm(swarms[0], swarms[1]) await connect_swarm(swarms[0], swarms[1])
await connect_swarm(swarms[1], swarms[2]) await connect_swarm(swarms[1], swarms[2])
# peer 1 closes peer 0 # peer 1 closes peer 0
await swarms[1].close_peer(swarms[0].get_peer_id()) await swarms[1].close_peer(swarms[0].get_peer_id())
await asyncio.sleep(0.01) await trio.sleep(0.01)
# 0 1 <> 2 await wait_all_tasks_blocked()
assert len(swarms[0].connections) == 0 # 0 1 <> 2
assert ( assert len(swarms[0].connections) == 0
len(swarms[1].connections) == 1 assert (
and swarms[2].get_peer_id() in swarms[1].connections len(swarms[1].connections) == 1
) and swarms[2].get_peer_id() in swarms[1].connections
)
# peer 1 is closed by peer 2 # peer 1 is closed by peer 2
await swarms[2].close_peer(swarms[1].get_peer_id()) await swarms[2].close_peer(swarms[1].get_peer_id())
await asyncio.sleep(0.01) await trio.sleep(0.01)
# 0 1 2 # 0 1 2
assert len(swarms[1].connections) == 0 and len(swarms[2].connections) == 0 assert len(swarms[1].connections) == 0 and len(swarms[2].connections) == 0
await connect_swarm(swarms[0], swarms[1]) await connect_swarm(swarms[0], swarms[1])
# 0 <> 1 2 # 0 <> 1 2
assert ( assert (
len(swarms[0].connections) == 1 len(swarms[0].connections) == 1
and swarms[1].get_peer_id() in swarms[0].connections and swarms[1].get_peer_id() in swarms[0].connections
) )
assert ( assert (
len(swarms[1].connections) == 1 len(swarms[1].connections) == 1
and swarms[0].get_peer_id() in swarms[1].connections and swarms[0].get_peer_id() in swarms[1].connections
) )
# peer 0 closes peer 1 # peer 0 closes peer 1
await swarms[0].close_peer(swarms[1].get_peer_id()) await swarms[0].close_peer(swarms[1].get_peer_id())
await asyncio.sleep(0.01) await trio.sleep(0.01)
# 0 1 2 # 0 1 2
assert len(swarms[1].connections) == 0 and len(swarms[2].connections) == 0 assert len(swarms[1].connections) == 0 and len(swarms[2].connections) == 0
# Clean up
await asyncio.gather(*[swarm.close() for swarm in swarms])
@pytest.mark.asyncio @pytest.mark.trio
async def test_swarm_remove_conn(swarm_pair): async def test_swarm_remove_conn(swarm_pair):
swarm_0, swarm_1 = swarm_pair swarm_0, swarm_1 = swarm_pair
conn_0 = swarm_0.connections[swarm_1.get_peer_id()] conn_0 = swarm_0.connections[swarm_1.get_peer_id()]

View File

@ -7,23 +7,13 @@ from libp2p.tools.factories import mplex_conn_pair_factory, mplex_stream_pair_fa
@pytest.fixture @pytest.fixture
async def mplex_conn_pair(is_host_secure): async def mplex_conn_pair(is_host_secure):
mplex_conn_0, swarm_0, mplex_conn_1, swarm_1 = await mplex_conn_pair_factory( async with mplex_conn_pair_factory(is_host_secure) as mplex_conn_pair:
is_host_secure assert mplex_conn_pair[0].is_initiator
) assert not mplex_conn_pair[1].is_initiator
assert mplex_conn_0.is_initiator yield mplex_conn_pair[0], mplex_conn_pair[1]
assert not mplex_conn_1.is_initiator
try:
yield mplex_conn_0, mplex_conn_1
finally:
await asyncio.gather(*[swarm_0.close(), swarm_1.close()])
@pytest.fixture @pytest.fixture
async def mplex_stream_pair(is_host_secure): async def mplex_stream_pair(is_host_secure):
mplex_stream_0, swarm_0, mplex_stream_1, swarm_1 = await mplex_stream_pair_factory( async with mplex_stream_pair_factory(is_host_secure) as mplex_stream_pair:
is_host_secure yield mplex_stream_pair
)
try:
yield mplex_stream_0, mplex_stream_1
finally:
await asyncio.gather(*[swarm_0.close(), swarm_1.close()])

View File

@ -1,5 +1,3 @@
import asyncio
import pytest import pytest
import trio import trio
@ -12,25 +10,26 @@ from libp2p.tools.constants import MAX_READ_LEN, LISTEN_MADDR
from libp2p.tools.factories import SwarmFactory from libp2p.tools.factories import SwarmFactory
from libp2p.tools.utils import connect_swarm from libp2p.tools.utils import connect_swarm
DATA = b"data_123" DATA = b"data_123"
@pytest.mark.trio @pytest.mark.trio
async def test_mplex_stream_read_write(nursery): async def test_mplex_stream_read_write():
swarm0, swarm1 = SwarmFactory(), SwarmFactory() async with SwarmFactory.create_batch_and_listen(False, 2) as swarms:
await swarm0.listen(LISTEN_MADDR, nursery=nursery) await swarms[0].listen(LISTEN_MADDR)
await swarm1.listen(LISTEN_MADDR, nursery=nursery) await swarms[1].listen(LISTEN_MADDR)
await connect_swarm(swarm0, swarm1, nursery) await connect_swarm(swarms[0], swarms[1])
conn_0 = swarm0.connections[swarm1.get_peer_id()] conn_0 = swarms[0].connections[swarms[1].get_peer_id()]
conn_1 = swarm1.connections[swarm0.get_peer_id()] conn_1 = swarms[1].connections[swarms[0].get_peer_id()]
stream_0 = await conn_0.muxed_conn.open_stream() stream_0 = await conn_0.muxed_conn.open_stream()
await trio.sleep(1) await trio.sleep(1)
stream_1 = tuple(conn_1.muxed_conn.streams.values())[0] stream_1 = tuple(conn_1.muxed_conn.streams.values())[0]
await stream_0.write(DATA) await stream_0.write(DATA)
assert (await stream_1.read(MAX_READ_LEN)) == DATA assert (await stream_1.read(MAX_READ_LEN)) == DATA
@pytest.mark.asyncio @pytest.mark.trio
async def test_mplex_stream_pair_read_until_eof(mplex_stream_pair): async def test_mplex_stream_pair_read_until_eof(mplex_stream_pair):
read_bytes = bytearray() read_bytes = bytearray()
stream_0, stream_1 = mplex_stream_pair stream_0, stream_1 = mplex_stream_pair
@ -38,43 +37,43 @@ async def test_mplex_stream_pair_read_until_eof(mplex_stream_pair):
async def read_until_eof(): async def read_until_eof():
read_bytes.extend(await stream_1.read()) read_bytes.extend(await stream_1.read())
task = asyncio.ensure_future(read_until_eof()) task = trio.ensure_future(read_until_eof())
expected_data = bytearray() expected_data = bytearray()
# Test: `read` doesn't return before `close` is called. # Test: `read` doesn't return before `close` is called.
await stream_0.write(DATA) await stream_0.write(DATA)
expected_data.extend(DATA) expected_data.extend(DATA)
await asyncio.sleep(0.01) await trio.sleep(0.01)
assert len(read_bytes) == 0 assert len(read_bytes) == 0
# Test: `read` doesn't return before `close` is called. # Test: `read` doesn't return before `close` is called.
await stream_0.write(DATA) await stream_0.write(DATA)
expected_data.extend(DATA) expected_data.extend(DATA)
await asyncio.sleep(0.01) await trio.sleep(0.01)
assert len(read_bytes) == 0 assert len(read_bytes) == 0
# Test: Close the stream, `read` returns, and receive previous sent data. # Test: Close the stream, `read` returns, and receive previous sent data.
await stream_0.close() await stream_0.close()
await asyncio.sleep(0.01) await trio.sleep(0.01)
assert read_bytes == expected_data assert read_bytes == expected_data
task.cancel() task.cancel()
@pytest.mark.asyncio @pytest.mark.trio
async def test_mplex_stream_read_after_remote_closed(mplex_stream_pair): async def test_mplex_stream_read_after_remote_closed(mplex_stream_pair):
stream_0, stream_1 = mplex_stream_pair stream_0, stream_1 = mplex_stream_pair
assert not stream_1.event_remote_closed.is_set() assert not stream_1.event_remote_closed.is_set()
await stream_0.write(DATA) await stream_0.write(DATA)
await stream_0.close() await stream_0.close()
await asyncio.sleep(0.01) await trio.sleep(0.01)
assert stream_1.event_remote_closed.is_set() assert stream_1.event_remote_closed.is_set()
assert (await stream_1.read(MAX_READ_LEN)) == DATA assert (await stream_1.read(MAX_READ_LEN)) == DATA
with pytest.raises(MplexStreamEOF): with pytest.raises(MplexStreamEOF):
await stream_1.read(MAX_READ_LEN) await stream_1.read(MAX_READ_LEN)
@pytest.mark.asyncio @pytest.mark.trio
async def test_mplex_stream_read_after_local_reset(mplex_stream_pair): async def test_mplex_stream_read_after_local_reset(mplex_stream_pair):
stream_0, stream_1 = mplex_stream_pair stream_0, stream_1 = mplex_stream_pair
await stream_0.reset() await stream_0.reset()
@ -82,29 +81,29 @@ async def test_mplex_stream_read_after_local_reset(mplex_stream_pair):
await stream_0.read(MAX_READ_LEN) await stream_0.read(MAX_READ_LEN)
@pytest.mark.asyncio @pytest.mark.trio
async def test_mplex_stream_read_after_remote_reset(mplex_stream_pair): async def test_mplex_stream_read_after_remote_reset(mplex_stream_pair):
stream_0, stream_1 = mplex_stream_pair stream_0, stream_1 = mplex_stream_pair
await stream_0.write(DATA) await stream_0.write(DATA)
await stream_0.reset() await stream_0.reset()
# Sleep to let `stream_1` receive the message. # Sleep to let `stream_1` receive the message.
await asyncio.sleep(0.01) await trio.sleep(0.01)
with pytest.raises(MplexStreamReset): with pytest.raises(MplexStreamReset):
await stream_1.read(MAX_READ_LEN) await stream_1.read(MAX_READ_LEN)
@pytest.mark.asyncio @pytest.mark.trio
async def test_mplex_stream_read_after_remote_closed_and_reset(mplex_stream_pair): async def test_mplex_stream_read_after_remote_closed_and_reset(mplex_stream_pair):
stream_0, stream_1 = mplex_stream_pair stream_0, stream_1 = mplex_stream_pair
await stream_0.write(DATA) await stream_0.write(DATA)
await stream_0.close() await stream_0.close()
await stream_0.reset() await stream_0.reset()
# Sleep to let `stream_1` receive the message. # Sleep to let `stream_1` receive the message.
await asyncio.sleep(0.01) await trio.sleep(0.01)
assert (await stream_1.read(MAX_READ_LEN)) == DATA assert (await stream_1.read(MAX_READ_LEN)) == DATA
@pytest.mark.asyncio @pytest.mark.trio
async def test_mplex_stream_write_after_local_closed(mplex_stream_pair): async def test_mplex_stream_write_after_local_closed(mplex_stream_pair):
stream_0, stream_1 = mplex_stream_pair stream_0, stream_1 = mplex_stream_pair
await stream_0.write(DATA) await stream_0.write(DATA)
@ -113,7 +112,7 @@ async def test_mplex_stream_write_after_local_closed(mplex_stream_pair):
await stream_0.write(DATA) await stream_0.write(DATA)
@pytest.mark.asyncio @pytest.mark.trio
async def test_mplex_stream_write_after_local_reset(mplex_stream_pair): async def test_mplex_stream_write_after_local_reset(mplex_stream_pair):
stream_0, stream_1 = mplex_stream_pair stream_0, stream_1 = mplex_stream_pair
await stream_0.reset() await stream_0.reset()
@ -121,16 +120,16 @@ async def test_mplex_stream_write_after_local_reset(mplex_stream_pair):
await stream_0.write(DATA) await stream_0.write(DATA)
@pytest.mark.asyncio @pytest.mark.trio
async def test_mplex_stream_write_after_remote_reset(mplex_stream_pair): async def test_mplex_stream_write_after_remote_reset(mplex_stream_pair):
stream_0, stream_1 = mplex_stream_pair stream_0, stream_1 = mplex_stream_pair
await stream_1.reset() await stream_1.reset()
await asyncio.sleep(0.01) await trio.sleep(0.01)
with pytest.raises(MplexStreamClosed): with pytest.raises(MplexStreamClosed):
await stream_0.write(DATA) await stream_0.write(DATA)
@pytest.mark.asyncio @pytest.mark.trio
async def test_mplex_stream_both_close(mplex_stream_pair): async def test_mplex_stream_both_close(mplex_stream_pair):
stream_0, stream_1 = mplex_stream_pair stream_0, stream_1 = mplex_stream_pair
# Flags are not set initially. # Flags are not set initially.
@ -144,7 +143,7 @@ async def test_mplex_stream_both_close(mplex_stream_pair):
# Test: Close one side. # Test: Close one side.
await stream_0.close() await stream_0.close()
await asyncio.sleep(0.01) await trio.sleep(0.01)
assert stream_0.event_local_closed.is_set() assert stream_0.event_local_closed.is_set()
assert not stream_1.event_local_closed.is_set() assert not stream_1.event_local_closed.is_set()
@ -156,7 +155,7 @@ async def test_mplex_stream_both_close(mplex_stream_pair):
# Test: Close the other side. # Test: Close the other side.
await stream_1.close() await stream_1.close()
await asyncio.sleep(0.01) await trio.sleep(0.01)
# Both sides are closed. # Both sides are closed.
assert stream_0.event_local_closed.is_set() assert stream_0.event_local_closed.is_set()
assert stream_1.event_local_closed.is_set() assert stream_1.event_local_closed.is_set()
@ -170,11 +169,11 @@ async def test_mplex_stream_both_close(mplex_stream_pair):
await stream_0.reset() await stream_0.reset()
@pytest.mark.asyncio @pytest.mark.trio
async def test_mplex_stream_reset(mplex_stream_pair): async def test_mplex_stream_reset(mplex_stream_pair):
stream_0, stream_1 = mplex_stream_pair stream_0, stream_1 = mplex_stream_pair
await stream_0.reset() await stream_0.reset()
await asyncio.sleep(0.01) await trio.sleep(0.01)
# Both sides are closed. # Both sides are closed.
assert stream_0.event_local_closed.is_set() assert stream_0.event_local_closed.is_set()