Rewrite factories, made some of the test running
This commit is contained in:
parent
417b5e7d61
commit
ec43c25b45
|
@ -1,6 +1,8 @@
|
||||||
import asyncio
|
|
||||||
from typing import TYPE_CHECKING, Any, Awaitable, List, Set, Tuple
|
from typing import TYPE_CHECKING, Any, Awaitable, List, Set, Tuple
|
||||||
|
|
||||||
|
import trio
|
||||||
|
from async_service import Service
|
||||||
|
|
||||||
from libp2p.network.connection.net_connection_interface import INetConn
|
from libp2p.network.connection.net_connection_interface import INetConn
|
||||||
from libp2p.network.stream.net_stream import NetStream
|
from libp2p.network.stream.net_stream import NetStream
|
||||||
from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream
|
from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream
|
||||||
|
@ -15,21 +17,17 @@ Reference: https://github.com/libp2p/go-libp2p-swarm/blob/04c86bbdafd390651cb2ee
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class SwarmConn(INetConn):
|
class SwarmConn(INetConn, Service):
|
||||||
muxed_conn: IMuxedConn
|
muxed_conn: IMuxedConn
|
||||||
swarm: "Swarm"
|
swarm: "Swarm"
|
||||||
streams: Set[NetStream]
|
streams: Set[NetStream]
|
||||||
event_closed: asyncio.Event
|
event_closed: trio.Event
|
||||||
|
|
||||||
_tasks: List["asyncio.Future[Any]"]
|
|
||||||
|
|
||||||
def __init__(self, muxed_conn: IMuxedConn, swarm: "Swarm") -> None:
|
def __init__(self, muxed_conn: IMuxedConn, swarm: "Swarm") -> None:
|
||||||
self.muxed_conn = muxed_conn
|
self.muxed_conn = muxed_conn
|
||||||
self.swarm = swarm
|
self.swarm = swarm
|
||||||
self.streams = set()
|
self.streams = set()
|
||||||
self.event_closed = asyncio.Event()
|
self.event_closed = trio.Event()
|
||||||
|
|
||||||
self._tasks = []
|
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
if self.event_closed.is_set():
|
if self.event_closed.is_set():
|
||||||
|
@ -45,16 +43,11 @@ class SwarmConn(INetConn):
|
||||||
await stream.reset()
|
await stream.reset()
|
||||||
# Force context switch for stream handlers to process the stream reset event we just emit
|
# Force context switch for stream handlers to process the stream reset event we just emit
|
||||||
# before we cancel the stream handler tasks.
|
# before we cancel the stream handler tasks.
|
||||||
await asyncio.sleep(0.1)
|
await trio.sleep(0.1)
|
||||||
|
|
||||||
for task in self._tasks:
|
# FIXME: Now let `_notify_disconnected` finish first.
|
||||||
task.cancel()
|
|
||||||
try:
|
|
||||||
await task
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
# Schedule `self._notify_disconnected` to make it execute after `close` is finished.
|
# Schedule `self._notify_disconnected` to make it execute after `close` is finished.
|
||||||
self._notify_disconnected()
|
await self._notify_disconnected()
|
||||||
|
|
||||||
async def _handle_new_streams(self) -> None:
|
async def _handle_new_streams(self) -> None:
|
||||||
while True:
|
while True:
|
||||||
|
@ -65,7 +58,7 @@ class SwarmConn(INetConn):
|
||||||
# we should break the loop and close the connection.
|
# we should break the loop and close the connection.
|
||||||
break
|
break
|
||||||
# Asynchronously handle the accepted stream, to avoid blocking the next stream.
|
# Asynchronously handle the accepted stream, to avoid blocking the next stream.
|
||||||
await self.run_task(self._handle_muxed_stream(stream))
|
self.manager.run_task(self._handle_muxed_stream, stream)
|
||||||
|
|
||||||
await self.close()
|
await self.close()
|
||||||
|
|
||||||
|
@ -79,28 +72,26 @@ class SwarmConn(INetConn):
|
||||||
self.remove_stream(net_stream)
|
self.remove_stream(net_stream)
|
||||||
|
|
||||||
async def _handle_muxed_stream(self, muxed_stream: IMuxedStream) -> None:
|
async def _handle_muxed_stream(self, muxed_stream: IMuxedStream) -> None:
|
||||||
net_stream = self._add_stream(muxed_stream)
|
net_stream = await self._add_stream(muxed_stream)
|
||||||
if self.swarm.common_stream_handler is not None:
|
if self.swarm.common_stream_handler is not None:
|
||||||
await self.run_task(self._call_stream_handler(net_stream))
|
await self._call_stream_handler(net_stream)
|
||||||
|
|
||||||
def _add_stream(self, muxed_stream: IMuxedStream) -> NetStream:
|
async def _add_stream(self, muxed_stream: IMuxedStream) -> NetStream:
|
||||||
net_stream = NetStream(muxed_stream)
|
net_stream = NetStream(muxed_stream)
|
||||||
self.streams.add(net_stream)
|
self.streams.add(net_stream)
|
||||||
self.swarm.notify_opened_stream(net_stream)
|
await self.swarm.notify_opened_stream(net_stream)
|
||||||
return net_stream
|
return net_stream
|
||||||
|
|
||||||
def _notify_disconnected(self) -> None:
|
async def _notify_disconnected(self) -> None:
|
||||||
self.swarm.notify_disconnected(self)
|
await self.swarm.notify_disconnected(self)
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def run(self) -> None:
|
||||||
await self.run_task(self._handle_new_streams())
|
self.manager.run_task(self._handle_new_streams)
|
||||||
|
await self.manager.wait_finished()
|
||||||
async def run_task(self, coro: Awaitable[Any]) -> None:
|
|
||||||
self._tasks.append(asyncio.ensure_future(coro))
|
|
||||||
|
|
||||||
async def new_stream(self) -> NetStream:
|
async def new_stream(self) -> NetStream:
|
||||||
muxed_stream = await self.muxed_conn.open_stream()
|
muxed_stream = await self.muxed_conn.open_stream()
|
||||||
return self._add_stream(muxed_stream)
|
return await self._add_stream(muxed_stream)
|
||||||
|
|
||||||
async def get_streams(self) -> Tuple[NetStream, ...]:
|
async def get_streams(self) -> Tuple[NetStream, ...]:
|
||||||
return tuple(self.streams)
|
return tuple(self.streams)
|
||||||
|
|
|
@ -1,7 +1,8 @@
|
||||||
import asyncio
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
from async_service import Service
|
||||||
|
|
||||||
from multiaddr import Multiaddr
|
from multiaddr import Multiaddr
|
||||||
import trio
|
import trio
|
||||||
|
|
||||||
|
@ -31,7 +32,7 @@ from .stream.net_stream_interface import INetStream
|
||||||
logger = logging.getLogger("libp2p.network.swarm")
|
logger = logging.getLogger("libp2p.network.swarm")
|
||||||
|
|
||||||
|
|
||||||
class Swarm(INetwork):
|
class Swarm(INetwork, Service):
|
||||||
|
|
||||||
self_id: ID
|
self_id: ID
|
||||||
peerstore: IPeerStore
|
peerstore: IPeerStore
|
||||||
|
@ -64,13 +65,16 @@ class Swarm(INetwork):
|
||||||
|
|
||||||
self.common_stream_handler = None
|
self.common_stream_handler = None
|
||||||
|
|
||||||
|
async def run(self) -> None:
|
||||||
|
await self.manager.wait_finished()
|
||||||
|
|
||||||
def get_peer_id(self) -> ID:
|
def get_peer_id(self) -> ID:
|
||||||
return self.self_id
|
return self.self_id
|
||||||
|
|
||||||
def set_stream_handler(self, stream_handler: StreamHandlerFn) -> None:
|
def set_stream_handler(self, stream_handler: StreamHandlerFn) -> None:
|
||||||
self.common_stream_handler = stream_handler
|
self.common_stream_handler = stream_handler
|
||||||
|
|
||||||
async def dial_peer(self, peer_id: ID, nursery) -> INetConn:
|
async def dial_peer(self, peer_id: ID) -> INetConn:
|
||||||
"""
|
"""
|
||||||
dial_peer try to create a connection to peer_id.
|
dial_peer try to create a connection to peer_id.
|
||||||
|
|
||||||
|
@ -122,7 +126,7 @@ class Swarm(INetwork):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
muxed_conn = await self.upgrader.upgrade_connection(secured_conn, peer_id)
|
muxed_conn = await self.upgrader.upgrade_connection(secured_conn, peer_id)
|
||||||
muxed_conn.run(nursery)
|
self.manager.run_child_service(muxed_conn)
|
||||||
except MuxerUpgradeFailure as error:
|
except MuxerUpgradeFailure as error:
|
||||||
error_msg = "fail to upgrade mux for peer %s"
|
error_msg = "fail to upgrade mux for peer %s"
|
||||||
logger.debug(error_msg, peer_id)
|
logger.debug(error_msg, peer_id)
|
||||||
|
@ -137,7 +141,7 @@ class Swarm(INetwork):
|
||||||
|
|
||||||
return swarm_conn
|
return swarm_conn
|
||||||
|
|
||||||
async def new_stream(self, peer_id: ID, nursery) -> INetStream:
|
async def new_stream(self, peer_id: ID) -> INetStream:
|
||||||
"""
|
"""
|
||||||
:param peer_id: peer_id of destination
|
:param peer_id: peer_id of destination
|
||||||
:param protocol_id: protocol id
|
:param protocol_id: protocol id
|
||||||
|
@ -146,13 +150,13 @@ class Swarm(INetwork):
|
||||||
"""
|
"""
|
||||||
logger.debug("attempting to open a stream to peer %s", peer_id)
|
logger.debug("attempting to open a stream to peer %s", peer_id)
|
||||||
|
|
||||||
swarm_conn = await self.dial_peer(peer_id, nursery)
|
swarm_conn = await self.dial_peer(peer_id)
|
||||||
|
|
||||||
net_stream = await swarm_conn.new_stream()
|
net_stream = await swarm_conn.new_stream()
|
||||||
logger.debug("successfully opened a stream to peer %s", peer_id)
|
logger.debug("successfully opened a stream to peer %s", peer_id)
|
||||||
return net_stream
|
return net_stream
|
||||||
|
|
||||||
async def listen(self, *multiaddrs: Multiaddr, nursery) -> bool:
|
async def listen(self, *multiaddrs: Multiaddr) -> bool:
|
||||||
"""
|
"""
|
||||||
:param multiaddrs: one or many multiaddrs to start listening on
|
:param multiaddrs: one or many multiaddrs to start listening on
|
||||||
:return: true if at least one success
|
:return: true if at least one success
|
||||||
|
@ -189,7 +193,7 @@ class Swarm(INetwork):
|
||||||
muxed_conn = await self.upgrader.upgrade_connection(
|
muxed_conn = await self.upgrader.upgrade_connection(
|
||||||
secured_conn, peer_id
|
secured_conn, peer_id
|
||||||
)
|
)
|
||||||
muxed_conn.run(nursery)
|
self.manager.run_child_service(muxed_conn)
|
||||||
except MuxerUpgradeFailure as error:
|
except MuxerUpgradeFailure as error:
|
||||||
error_msg = "fail to upgrade mux for peer %s"
|
error_msg = "fail to upgrade mux for peer %s"
|
||||||
logger.debug(error_msg, peer_id)
|
logger.debug(error_msg, peer_id)
|
||||||
|
@ -200,6 +204,8 @@ class Swarm(INetwork):
|
||||||
await self.add_conn(muxed_conn)
|
await self.add_conn(muxed_conn)
|
||||||
|
|
||||||
logger.debug("successfully opened connection to peer %s", peer_id)
|
logger.debug("successfully opened connection to peer %s", peer_id)
|
||||||
|
# FIXME: This is a intentional barrier to prevent from the handler exiting and
|
||||||
|
# closing the connection.
|
||||||
event = trio.Event()
|
event = trio.Event()
|
||||||
await event.wait()
|
await event.wait()
|
||||||
|
|
||||||
|
@ -207,10 +213,11 @@ class Swarm(INetwork):
|
||||||
# Success
|
# Success
|
||||||
listener = self.transport.create_listener(conn_handler)
|
listener = self.transport.create_listener(conn_handler)
|
||||||
self.listeners[str(maddr)] = listener
|
self.listeners[str(maddr)] = listener
|
||||||
await listener.listen(maddr, nursery)
|
# FIXME: Hack
|
||||||
|
await listener.listen(maddr, self.manager._task_nursery)
|
||||||
|
|
||||||
# Call notifiers since event occurred
|
# Call notifiers since event occurred
|
||||||
self.notify_listen(maddr)
|
await self.notify_listen(maddr)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
except IOError:
|
except IOError:
|
||||||
|
@ -225,15 +232,16 @@ class Swarm(INetwork):
|
||||||
# Reference: https://github.com/libp2p/go-libp2p-swarm/blob/8be680aef8dea0a4497283f2f98470c2aeae6b65/swarm.go#L124-L134 # noqa: E501
|
# Reference: https://github.com/libp2p/go-libp2p-swarm/blob/8be680aef8dea0a4497283f2f98470c2aeae6b65/swarm.go#L124-L134 # noqa: E501
|
||||||
|
|
||||||
# Close listeners
|
# Close listeners
|
||||||
await asyncio.gather(
|
# await asyncio.gather(
|
||||||
*[listener.close() for listener in self.listeners.values()]
|
# *[listener.close() for listener in self.listeners.values()]
|
||||||
)
|
# )
|
||||||
|
|
||||||
# Close connections
|
|
||||||
await asyncio.gather(
|
|
||||||
*[connection.close() for connection in self.connections.values()]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
# # Close connections
|
||||||
|
# await asyncio.gather(
|
||||||
|
# *[connection.close() for connection in self.connections.values()]
|
||||||
|
# )
|
||||||
|
self.manager.stop()
|
||||||
|
await self.manager.wait_finished()
|
||||||
logger.debug("swarm successfully closed")
|
logger.debug("swarm successfully closed")
|
||||||
|
|
||||||
async def close_peer(self, peer_id: ID) -> None:
|
async def close_peer(self, peer_id: ID) -> None:
|
||||||
|
@ -253,11 +261,12 @@ class Swarm(INetwork):
|
||||||
and start to monitor the connection for its new streams and
|
and start to monitor the connection for its new streams and
|
||||||
disconnection."""
|
disconnection."""
|
||||||
swarm_conn = SwarmConn(muxed_conn, self)
|
swarm_conn = SwarmConn(muxed_conn, self)
|
||||||
|
manager = self.manager.run_child_service(swarm_conn)
|
||||||
# Store muxed_conn with peer id
|
# Store muxed_conn with peer id
|
||||||
self.connections[muxed_conn.peer_id] = swarm_conn
|
self.connections[muxed_conn.peer_id] = swarm_conn
|
||||||
# Call notifiers since event occurred
|
# Call notifiers since event occurred
|
||||||
self.notify_connected(swarm_conn)
|
self.manager.run_task(self.notify_connected, swarm_conn)
|
||||||
await swarm_conn.start()
|
await manager.wait_started()
|
||||||
return swarm_conn
|
return swarm_conn
|
||||||
|
|
||||||
def remove_conn(self, swarm_conn: SwarmConn) -> None:
|
def remove_conn(self, swarm_conn: SwarmConn) -> None:
|
||||||
|
@ -281,20 +290,26 @@ class Swarm(INetwork):
|
||||||
"""
|
"""
|
||||||
self.notifees.append(notifee)
|
self.notifees.append(notifee)
|
||||||
|
|
||||||
def notify_opened_stream(self, stream: INetStream) -> None:
|
async def notify_opened_stream(self, stream: INetStream) -> None:
|
||||||
asyncio.gather(
|
async with trio.open_nursery() as nursery:
|
||||||
*[notifee.opened_stream(self, stream) for notifee in self.notifees]
|
for notifee in self.notifees:
|
||||||
)
|
nursery.start_soon(notifee.opened_stream, self, stream)
|
||||||
|
|
||||||
# TODO: `notify_closed_stream`
|
# TODO: `notify_closed_stream`
|
||||||
|
|
||||||
def notify_connected(self, conn: INetConn) -> None:
|
async def notify_connected(self, conn: INetConn) -> None:
|
||||||
asyncio.gather(*[notifee.connected(self, conn) for notifee in self.notifees])
|
async with trio.open_nursery() as nursery:
|
||||||
|
for notifee in self.notifees:
|
||||||
|
nursery.start_soon(notifee.connected, self, conn)
|
||||||
|
|
||||||
def notify_disconnected(self, conn: INetConn) -> None:
|
async def notify_disconnected(self, conn: INetConn) -> None:
|
||||||
asyncio.gather(*[notifee.disconnected(self, conn) for notifee in self.notifees])
|
async with trio.open_nursery() as nursery:
|
||||||
|
for notifee in self.notifees:
|
||||||
|
nursery.start_soon(notifee.disconnected, self, conn)
|
||||||
|
|
||||||
def notify_listen(self, multiaddr: Multiaddr) -> None:
|
async def notify_listen(self, multiaddr: Multiaddr) -> None:
|
||||||
asyncio.gather(*[notifee.listen(self, multiaddr) for notifee in self.notifees])
|
async with trio.open_nursery() as nursery:
|
||||||
|
for notifee in self.notifees:
|
||||||
|
nursery.start_soon(notifee.listen, self, multiaddr)
|
||||||
|
|
||||||
# TODO: `notify_listen_close`
|
# TODO: `notify_listen_close`
|
||||||
|
|
|
@ -4,6 +4,7 @@ from typing import Any # noqa: F401
|
||||||
from typing import Awaitable, Dict, List, Optional, Tuple
|
from typing import Awaitable, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import trio
|
import trio
|
||||||
|
from async_service import Service
|
||||||
|
|
||||||
from libp2p.exceptions import ParseError
|
from libp2p.exceptions import ParseError
|
||||||
from libp2p.io.exceptions import IncompleteReadError
|
from libp2p.io.exceptions import IncompleteReadError
|
||||||
|
@ -17,6 +18,7 @@ from libp2p.utils import (
|
||||||
encode_uvarint,
|
encode_uvarint,
|
||||||
encode_varint_prefixed,
|
encode_varint_prefixed,
|
||||||
read_varint_prefixed_bytes,
|
read_varint_prefixed_bytes,
|
||||||
|
TrioQueue,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .constants import HeaderTags
|
from .constants import HeaderTags
|
||||||
|
@ -29,7 +31,7 @@ MPLEX_PROTOCOL_ID = TProtocol("/mplex/6.7.0")
|
||||||
logger = logging.getLogger("libp2p.stream_muxer.mplex.mplex")
|
logger = logging.getLogger("libp2p.stream_muxer.mplex.mplex")
|
||||||
|
|
||||||
|
|
||||||
class Mplex(IMuxedConn):
|
class Mplex(IMuxedConn, Service):
|
||||||
"""
|
"""
|
||||||
reference: https://github.com/libp2p/go-mplex/blob/master/multiplex.go
|
reference: https://github.com/libp2p/go-mplex/blob/master/multiplex.go
|
||||||
"""
|
"""
|
||||||
|
@ -38,10 +40,10 @@ class Mplex(IMuxedConn):
|
||||||
peer_id: ID
|
peer_id: ID
|
||||||
next_channel_id: int
|
next_channel_id: int
|
||||||
streams: Dict[StreamID, MplexStream]
|
streams: Dict[StreamID, MplexStream]
|
||||||
streams_lock: asyncio.Lock
|
streams_lock: trio.Lock
|
||||||
new_stream_queue: "asyncio.Queue[IMuxedStream]"
|
new_stream_queue: "TrioQueue[IMuxedStream]"
|
||||||
event_shutting_down: asyncio.Event
|
event_shutting_down: trio.Event
|
||||||
event_closed: asyncio.Event
|
event_closed: trio.Event
|
||||||
|
|
||||||
def __init__(self, secured_conn: ISecureConn, peer_id: ID) -> None:
|
def __init__(self, secured_conn: ISecureConn, peer_id: ID) -> None:
|
||||||
"""
|
"""
|
||||||
|
@ -61,13 +63,14 @@ class Mplex(IMuxedConn):
|
||||||
|
|
||||||
# Mapping from stream ID -> buffer of messages for that stream
|
# Mapping from stream ID -> buffer of messages for that stream
|
||||||
self.streams = {}
|
self.streams = {}
|
||||||
self.streams_lock = asyncio.Lock()
|
self.streams_lock = trio.Lock()
|
||||||
self.new_stream_queue = asyncio.Queue()
|
self.new_stream_queue = TrioQueue()
|
||||||
self.event_shutting_down = asyncio.Event()
|
self.event_shutting_down = trio.Event()
|
||||||
self.event_closed = asyncio.Event()
|
self.event_closed = trio.Event()
|
||||||
|
|
||||||
def run(self, nursery):
|
async def run(self):
|
||||||
nursery.start_soon(self.handle_incoming)
|
self.manager.run_task(self.handle_incoming)
|
||||||
|
await self.manager.wait_finished()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_initiator(self) -> bool:
|
def is_initiator(self) -> bool:
|
||||||
|
|
|
@ -136,7 +136,6 @@ class MplexStream(IMuxedStream):
|
||||||
nursery.start_soon(
|
nursery.start_soon(
|
||||||
self.muxed_conn.send_message, flag, None, self.stream_id
|
self.muxed_conn.send_message, flag, None, self.stream_id
|
||||||
)
|
)
|
||||||
await trio.sleep(0)
|
|
||||||
|
|
||||||
self.event_local_closed.set()
|
self.event_local_closed.set()
|
||||||
self.event_remote_closed.set()
|
self.event_remote_closed.set()
|
||||||
|
|
|
@ -1,8 +1,9 @@
|
||||||
import asyncio
|
import trio
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager, AsyncExitStack
|
||||||
from typing import Any, AsyncIterator, Dict, Tuple, cast
|
from typing import Any, AsyncIterator, Dict, Tuple, cast
|
||||||
|
|
||||||
import factory
|
import factory
|
||||||
|
from async_service import background_trio_service
|
||||||
|
|
||||||
from libp2p import generate_new_rsa_identity, generate_peer_id_from
|
from libp2p import generate_new_rsa_identity, generate_peer_id_from
|
||||||
from libp2p.crypto.keys import KeyPair
|
from libp2p.crypto.keys import KeyPair
|
||||||
|
@ -61,6 +62,7 @@ class SwarmFactory(factory.Factory):
|
||||||
transport = factory.LazyFunction(TCP)
|
transport = factory.LazyFunction(TCP)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@asynccontextmanager
|
||||||
async def create_and_listen(
|
async def create_and_listen(
|
||||||
cls, is_secure: bool, key_pair: KeyPair = None, muxer_opt: TMuxerOptions = None
|
cls, is_secure: bool, key_pair: KeyPair = None, muxer_opt: TMuxerOptions = None
|
||||||
) -> Swarm:
|
) -> Swarm:
|
||||||
|
@ -73,20 +75,23 @@ class SwarmFactory(factory.Factory):
|
||||||
if muxer_opt is not None:
|
if muxer_opt is not None:
|
||||||
optional_kwargs["muxer_opt"] = muxer_opt
|
optional_kwargs["muxer_opt"] = muxer_opt
|
||||||
swarm = cls(is_secure=is_secure, **optional_kwargs)
|
swarm = cls(is_secure=is_secure, **optional_kwargs)
|
||||||
await swarm.listen(LISTEN_MADDR)
|
async with background_trio_service(swarm):
|
||||||
return swarm
|
await swarm.listen(LISTEN_MADDR)
|
||||||
|
yield swarm
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@asynccontextmanager
|
||||||
async def create_batch_and_listen(
|
async def create_batch_and_listen(
|
||||||
cls, is_secure: bool, number: int, muxer_opt: TMuxerOptions = None
|
cls, is_secure: bool, number: int, muxer_opt: TMuxerOptions = None
|
||||||
) -> Tuple[Swarm, ...]:
|
) -> Tuple[Swarm, ...]:
|
||||||
# Ignore typing since we are removing asyncio soon
|
async with AsyncExitStack() as stack:
|
||||||
return await asyncio.gather( # type: ignore
|
ctx_mgrs = [
|
||||||
*[
|
await stack.enter_async_context(
|
||||||
cls.create_and_listen(is_secure=is_secure, muxer_opt=muxer_opt)
|
cls.create_and_listen(is_secure=is_secure, muxer_opt=muxer_opt)
|
||||||
|
)
|
||||||
for _ in range(number)
|
for _ in range(number)
|
||||||
]
|
]
|
||||||
)
|
yield ctx_mgrs
|
||||||
|
|
||||||
|
|
||||||
class HostFactory(factory.Factory):
|
class HostFactory(factory.Factory):
|
||||||
|
@ -103,20 +108,23 @@ class HostFactory(factory.Factory):
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@asynccontextmanager
|
||||||
async def create_batch_and_listen(
|
async def create_batch_and_listen(
|
||||||
cls, is_secure: bool, number: int
|
cls, is_secure: bool, number: int
|
||||||
) -> Tuple[BasicHost, ...]:
|
) -> Tuple[BasicHost, ...]:
|
||||||
key_pairs = [generate_new_rsa_identity() for _ in range(number)]
|
key_pairs = [generate_new_rsa_identity() for _ in range(number)]
|
||||||
swarms = await asyncio.gather(
|
async with AsyncExitStack() as stack:
|
||||||
*[
|
swarms = [
|
||||||
SwarmFactory.create_and_listen(is_secure, key_pair)
|
await stack.enter_async_context(
|
||||||
|
SwarmFactory.create_and_listen(is_secure, key_pair)
|
||||||
|
)
|
||||||
for key_pair in key_pairs
|
for key_pair in key_pairs
|
||||||
]
|
]
|
||||||
)
|
hosts = tuple(
|
||||||
return tuple(
|
BasicHost(key_pair.public_key, swarm)
|
||||||
BasicHost(key_pair.public_key, swarm)
|
for key_pair, swarm in zip(key_pairs, swarms)
|
||||||
for key_pair, swarm in zip(key_pairs, swarms)
|
)
|
||||||
)
|
yield hosts
|
||||||
|
|
||||||
|
|
||||||
class FloodsubFactory(factory.Factory):
|
class FloodsubFactory(factory.Factory):
|
||||||
|
@ -150,73 +158,60 @@ class PubsubFactory(factory.Factory):
|
||||||
cache_size = None
|
cache_size = None
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
async def swarm_pair_factory(
|
async def swarm_pair_factory(
|
||||||
is_secure: bool, muxer_opt: TMuxerOptions = None
|
is_secure: bool, muxer_opt: TMuxerOptions = None
|
||||||
) -> Tuple[Swarm, Swarm]:
|
) -> Tuple[Swarm, Swarm]:
|
||||||
swarms = await SwarmFactory.create_batch_and_listen(
|
async with SwarmFactory.create_batch_and_listen(
|
||||||
is_secure, 2, muxer_opt=muxer_opt
|
is_secure, 2, muxer_opt=muxer_opt
|
||||||
)
|
) as swarms:
|
||||||
await connect_swarm(swarms[0], swarms[1])
|
await connect_swarm(swarms[0], swarms[1])
|
||||||
return swarms[0], swarms[1]
|
yield swarms[0], swarms[1]
|
||||||
|
|
||||||
|
|
||||||
async def host_pair_factory(is_secure: bool) -> Tuple[BasicHost, BasicHost]:
|
|
||||||
hosts = await HostFactory.create_batch_and_listen(is_secure, 2)
|
|
||||||
await connect(hosts[0], hosts[1])
|
|
||||||
return hosts[0], hosts[1]
|
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def pair_of_connected_hosts(
|
async def host_pair_factory(is_secure: bool) -> Tuple[BasicHost, BasicHost]:
|
||||||
is_secure: bool = True
|
async with HostFactory.create_batch_and_listen(is_secure, 2) as hosts:
|
||||||
) -> AsyncIterator[Tuple[BasicHost, BasicHost]]:
|
await connect(hosts[0], hosts[1])
|
||||||
a, b = await host_pair_factory(is_secure)
|
yield hosts[0], hosts[1]
|
||||||
yield a, b
|
|
||||||
close_tasks = (a.close(), b.close())
|
|
||||||
await asyncio.gather(*close_tasks)
|
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
async def swarm_conn_pair_factory(
|
async def swarm_conn_pair_factory(
|
||||||
is_secure: bool, muxer_opt: TMuxerOptions = None
|
is_secure: bool, muxer_opt: TMuxerOptions = None
|
||||||
) -> Tuple[SwarmConn, Swarm, SwarmConn, Swarm]:
|
) -> Tuple[SwarmConn, SwarmConn]:
|
||||||
swarms = await swarm_pair_factory(is_secure)
|
async with swarm_pair_factory(is_secure) as swarms:
|
||||||
conn_0 = swarms[0].connections[swarms[1].get_peer_id()]
|
conn_0 = swarms[0].connections[swarms[1].get_peer_id()]
|
||||||
conn_1 = swarms[1].connections[swarms[0].get_peer_id()]
|
conn_1 = swarms[1].connections[swarms[0].get_peer_id()]
|
||||||
return cast(SwarmConn, conn_0), swarms[0], cast(SwarmConn, conn_1), swarms[1]
|
yield cast(SwarmConn, conn_0), cast(SwarmConn, conn_1)
|
||||||
|
|
||||||
|
|
||||||
async def mplex_conn_pair_factory(is_secure: bool) -> Tuple[Mplex, Swarm, Mplex, Swarm]:
|
@asynccontextmanager
|
||||||
|
async def mplex_conn_pair_factory(is_secure: bool) -> Tuple[Mplex, Mplex]:
|
||||||
muxer_opt = {MPLEX_PROTOCOL_ID: Mplex}
|
muxer_opt = {MPLEX_PROTOCOL_ID: Mplex}
|
||||||
conn_0, swarm_0, conn_1, swarm_1 = await swarm_conn_pair_factory(
|
async with swarm_conn_pair_factory(is_secure, muxer_opt=muxer_opt) as swarm_pair:
|
||||||
is_secure, muxer_opt=muxer_opt
|
yield (
|
||||||
)
|
cast(Mplex, swarm_pair[0].muxed_conn),
|
||||||
return (
|
cast(Mplex, swarm_pair[1].muxed_conn),
|
||||||
cast(Mplex, conn_0.muxed_conn),
|
)
|
||||||
swarm_0,
|
|
||||||
cast(Mplex, conn_1.muxed_conn),
|
|
||||||
swarm_1,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def mplex_stream_pair_factory(
|
@asynccontextmanager
|
||||||
is_secure: bool
|
async def mplex_stream_pair_factory(is_secure: bool) -> Tuple[MplexStream, MplexStream]:
|
||||||
) -> Tuple[MplexStream, Swarm, MplexStream, Swarm]:
|
async with mplex_conn_pair_factory(is_secure) as mplex_conn_pair_info:
|
||||||
mplex_conn_0, swarm_0, mplex_conn_1, swarm_1 = await mplex_conn_pair_factory(
|
mplex_conn_0, mplex_conn_1 = mplex_conn_pair_info
|
||||||
is_secure
|
stream_0 = await mplex_conn_0.open_stream()
|
||||||
)
|
await trio.sleep(0.01)
|
||||||
stream_0 = await mplex_conn_0.open_stream()
|
stream_1: MplexStream
|
||||||
await asyncio.sleep(0.01)
|
async with mplex_conn_1.streams_lock:
|
||||||
stream_1: MplexStream
|
if len(mplex_conn_1.streams) != 1:
|
||||||
async with mplex_conn_1.streams_lock:
|
raise Exception("Mplex should not have any stream upon connection")
|
||||||
if len(mplex_conn_1.streams) != 1:
|
stream_1 = tuple(mplex_conn_1.streams.values())[0]
|
||||||
raise Exception("Mplex should not have any stream upon connection")
|
yield cast(MplexStream, stream_0), cast(MplexStream, stream_1)
|
||||||
stream_1 = tuple(mplex_conn_1.streams.values())[0]
|
|
||||||
return cast(MplexStream, stream_0), swarm_0, stream_1, swarm_1
|
|
||||||
|
|
||||||
|
|
||||||
async def net_stream_pair_factory(
|
@asynccontextmanager
|
||||||
is_secure: bool
|
async def net_stream_pair_factory(is_secure: bool) -> Tuple[INetStream, INetStream]:
|
||||||
) -> Tuple[INetStream, BasicHost, INetStream, BasicHost]:
|
|
||||||
protocol_id = TProtocol("/example/id/1")
|
protocol_id = TProtocol("/example/id/1")
|
||||||
|
|
||||||
stream_1: INetStream
|
stream_1: INetStream
|
||||||
|
@ -226,8 +221,8 @@ async def net_stream_pair_factory(
|
||||||
nonlocal stream_1
|
nonlocal stream_1
|
||||||
stream_1 = stream
|
stream_1 = stream
|
||||||
|
|
||||||
host_0, host_1 = await host_pair_factory(is_secure)
|
async with host_pair_factory(is_secure) as hosts:
|
||||||
host_1.set_stream_handler(protocol_id, handler)
|
hosts[1].set_stream_handler(protocol_id, handler)
|
||||||
|
|
||||||
stream_0 = await host_0.new_stream(host_1.get_id(), [protocol_id])
|
stream_0 = await hosts[0].new_stream(hosts[1].get_id(), [protocol_id])
|
||||||
return stream_0, host_0, stream_1, host_1
|
yield stream_0, stream_1
|
||||||
|
|
|
@ -17,7 +17,7 @@ from libp2p.typing import StreamHandlerFn, TProtocol
|
||||||
from .constants import MAX_READ_LEN
|
from .constants import MAX_READ_LEN
|
||||||
|
|
||||||
|
|
||||||
async def connect_swarm(swarm_0: Swarm, swarm_1: Swarm, nursery: trio.Nursery) -> None:
|
async def connect_swarm(swarm_0: Swarm, swarm_1: Swarm) -> None:
|
||||||
peer_id = swarm_1.get_peer_id()
|
peer_id = swarm_1.get_peer_id()
|
||||||
addrs = tuple(
|
addrs = tuple(
|
||||||
addr
|
addr
|
||||||
|
@ -25,7 +25,7 @@ async def connect_swarm(swarm_0: Swarm, swarm_1: Swarm, nursery: trio.Nursery) -
|
||||||
for addr in transport.get_addrs()
|
for addr in transport.get_addrs()
|
||||||
)
|
)
|
||||||
swarm_0.peerstore.add_addrs(peer_id, addrs, 10000)
|
swarm_0.peerstore.add_addrs(peer_id, addrs, 10000)
|
||||||
await swarm_0.dial_peer(peer_id, nursery)
|
await swarm_0.dial_peer(peer_id)
|
||||||
assert swarm_0.get_peer_id() in swarm_1.connections
|
assert swarm_0.get_peer_id() in swarm_1.connections
|
||||||
assert swarm_1.get_peer_id() in swarm_0.connections
|
assert swarm_1.get_peer_id() in swarm_0.connections
|
||||||
|
|
||||||
|
|
|
@ -25,7 +25,7 @@ class TCPListener(IListener):
|
||||||
self.server = None
|
self.server = None
|
||||||
self.handler = handler_function
|
self.handler = handler_function
|
||||||
|
|
||||||
async def listen(self, maddr: Multiaddr, nursery) -> bool:
|
async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> bool:
|
||||||
"""
|
"""
|
||||||
put listener in listening mode and wait for incoming connections.
|
put listener in listening mode and wait for incoming connections.
|
||||||
|
|
||||||
|
|
|
@ -4,12 +4,12 @@ import secrets
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from libp2p.host.ping import ID, PING_LENGTH
|
from libp2p.host.ping import ID, PING_LENGTH
|
||||||
from libp2p.tools.factories import pair_of_connected_hosts
|
from libp2p.tools.factories import host_pair_factory
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_ping_once():
|
async def test_ping_once():
|
||||||
async with pair_of_connected_hosts() as (host_a, host_b):
|
async with host_pair_factory() as (host_a, host_b):
|
||||||
stream = await host_b.new_stream(host_a.get_id(), (ID,))
|
stream = await host_b.new_stream(host_a.get_id(), (ID,))
|
||||||
some_ping = secrets.token_bytes(PING_LENGTH)
|
some_ping = secrets.token_bytes(PING_LENGTH)
|
||||||
await stream.write(some_ping)
|
await stream.write(some_ping)
|
||||||
|
@ -23,7 +23,7 @@ SOME_PING_COUNT = 3
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_ping_several():
|
async def test_ping_several():
|
||||||
async with pair_of_connected_hosts() as (host_a, host_b):
|
async with host_pair_factory() as (host_a, host_b):
|
||||||
stream = await host_b.new_stream(host_a.get_id(), (ID,))
|
stream = await host_b.new_stream(host_a.get_id(), (ID,))
|
||||||
for _ in range(SOME_PING_COUNT):
|
for _ in range(SOME_PING_COUNT):
|
||||||
some_ping = secrets.token_bytes(PING_LENGTH)
|
some_ping = secrets.token_bytes(PING_LENGTH)
|
||||||
|
|
|
@ -2,12 +2,12 @@ import pytest
|
||||||
|
|
||||||
from libp2p.identity.identify.pb.identify_pb2 import Identify
|
from libp2p.identity.identify.pb.identify_pb2 import Identify
|
||||||
from libp2p.identity.identify.protocol import ID, _mk_identify_protobuf
|
from libp2p.identity.identify.protocol import ID, _mk_identify_protobuf
|
||||||
from libp2p.tools.factories import pair_of_connected_hosts
|
from libp2p.tools.factories import host_pair_factory
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_identify_protocol():
|
async def test_identify_protocol():
|
||||||
async with pair_of_connected_hosts() as (host_a, host_b):
|
async with host_pair_factory() as (host_a, host_b):
|
||||||
stream = await host_b.new_stream(host_a.get_id(), (ID,))
|
stream = await host_b.new_stream(host_a.get_id(), (ID,))
|
||||||
response = await stream.read()
|
response = await stream.read()
|
||||||
await stream.close()
|
await stream.close()
|
||||||
|
|
|
@ -11,26 +11,17 @@ from libp2p.tools.factories import (
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def net_stream_pair(is_host_secure):
|
async def net_stream_pair(is_host_secure):
|
||||||
stream_0, host_0, stream_1, host_1 = await net_stream_pair_factory(is_host_secure)
|
async with net_stream_pair_factory(is_host_secure) as net_stream_pair:
|
||||||
try:
|
yield net_stream_pair
|
||||||
yield stream_0, stream_1
|
|
||||||
finally:
|
|
||||||
await asyncio.gather(*[host_0.close(), host_1.close()])
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def swarm_pair(is_host_secure):
|
async def swarm_pair(is_host_secure):
|
||||||
swarm_0, swarm_1 = await swarm_pair_factory(is_host_secure)
|
async with swarm_pair_factory(is_host_secure) as swarms:
|
||||||
try:
|
yield swarms
|
||||||
yield swarm_0, swarm_1
|
|
||||||
finally:
|
|
||||||
await asyncio.gather(*[swarm_0.close(), swarm_1.close()])
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def swarm_conn_pair(is_host_secure):
|
async def swarm_conn_pair(is_host_secure):
|
||||||
conn_0, swarm_0, conn_1, swarm_1 = await swarm_conn_pair_factory(is_host_secure)
|
async with swarm_conn_pair_factory(is_host_secure) as swarm_conn_pair:
|
||||||
try:
|
yield swarm_conn_pair
|
||||||
yield conn_0, conn_1
|
|
||||||
finally:
|
|
||||||
await asyncio.gather(*[swarm_0.close(), swarm_1.close()])
|
|
||||||
|
|
|
@ -1,88 +1,83 @@
|
||||||
import asyncio
|
import trio
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from trio.testing import wait_all_tasks_blocked
|
||||||
|
|
||||||
from libp2p.network.exceptions import SwarmException
|
from libp2p.network.exceptions import SwarmException
|
||||||
from libp2p.tools.factories import SwarmFactory
|
from libp2p.tools.factories import SwarmFactory
|
||||||
from libp2p.tools.utils import connect_swarm
|
from libp2p.tools.utils import connect_swarm
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.trio
|
||||||
async def test_swarm_dial_peer(is_host_secure):
|
async def test_swarm_dial_peer(is_host_secure):
|
||||||
swarms = await SwarmFactory.create_batch_and_listen(is_host_secure, 3)
|
async with SwarmFactory.create_batch_and_listen(is_host_secure, 3) as swarms:
|
||||||
# Test: No addr found.
|
# Test: No addr found.
|
||||||
with pytest.raises(SwarmException):
|
with pytest.raises(SwarmException):
|
||||||
|
await swarms[0].dial_peer(swarms[1].get_peer_id())
|
||||||
|
|
||||||
|
# Test: len(addr) in the peerstore is 0.
|
||||||
|
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), [], 10000)
|
||||||
|
with pytest.raises(SwarmException):
|
||||||
|
await swarms[0].dial_peer(swarms[1].get_peer_id())
|
||||||
|
|
||||||
|
# Test: Succeed if addrs of the peer_id are present in the peerstore.
|
||||||
|
addrs = tuple(
|
||||||
|
addr
|
||||||
|
for transport in swarms[1].listeners.values()
|
||||||
|
for addr in transport.get_addrs()
|
||||||
|
)
|
||||||
|
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs, 10000)
|
||||||
await swarms[0].dial_peer(swarms[1].get_peer_id())
|
await swarms[0].dial_peer(swarms[1].get_peer_id())
|
||||||
|
assert swarms[0].get_peer_id() in swarms[1].connections
|
||||||
|
assert swarms[1].get_peer_id() in swarms[0].connections
|
||||||
|
|
||||||
# Test: len(addr) in the peerstore is 0.
|
# Test: Reuse connections when we already have ones with a peer.
|
||||||
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), [], 10000)
|
conn_to_1 = swarms[0].connections[swarms[1].get_peer_id()]
|
||||||
with pytest.raises(SwarmException):
|
conn = await swarms[0].dial_peer(swarms[1].get_peer_id())
|
||||||
await swarms[0].dial_peer(swarms[1].get_peer_id())
|
assert conn is conn_to_1
|
||||||
|
|
||||||
# Test: Succeed if addrs of the peer_id are present in the peerstore.
|
|
||||||
addrs = tuple(
|
|
||||||
addr
|
|
||||||
for transport in swarms[1].listeners.values()
|
|
||||||
for addr in transport.get_addrs()
|
|
||||||
)
|
|
||||||
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs, 10000)
|
|
||||||
await swarms[0].dial_peer(swarms[1].get_peer_id())
|
|
||||||
assert swarms[0].get_peer_id() in swarms[1].connections
|
|
||||||
assert swarms[1].get_peer_id() in swarms[0].connections
|
|
||||||
|
|
||||||
# Test: Reuse connections when we already have ones with a peer.
|
|
||||||
conn_to_1 = swarms[0].connections[swarms[1].get_peer_id()]
|
|
||||||
conn = await swarms[0].dial_peer(swarms[1].get_peer_id())
|
|
||||||
assert conn is conn_to_1
|
|
||||||
|
|
||||||
# Clean up
|
|
||||||
await asyncio.gather(*[swarm.close() for swarm in swarms])
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.trio
|
||||||
async def test_swarm_close_peer(is_host_secure):
|
async def test_swarm_close_peer(is_host_secure):
|
||||||
swarms = await SwarmFactory.create_batch_and_listen(is_host_secure, 3)
|
async with SwarmFactory.create_batch_and_listen(is_host_secure, 3) as swarms:
|
||||||
# 0 <> 1 <> 2
|
# 0 <> 1 <> 2
|
||||||
await connect_swarm(swarms[0], swarms[1])
|
await connect_swarm(swarms[0], swarms[1])
|
||||||
await connect_swarm(swarms[1], swarms[2])
|
await connect_swarm(swarms[1], swarms[2])
|
||||||
|
|
||||||
# peer 1 closes peer 0
|
# peer 1 closes peer 0
|
||||||
await swarms[1].close_peer(swarms[0].get_peer_id())
|
await swarms[1].close_peer(swarms[0].get_peer_id())
|
||||||
await asyncio.sleep(0.01)
|
await trio.sleep(0.01)
|
||||||
# 0 1 <> 2
|
await wait_all_tasks_blocked()
|
||||||
assert len(swarms[0].connections) == 0
|
# 0 1 <> 2
|
||||||
assert (
|
assert len(swarms[0].connections) == 0
|
||||||
len(swarms[1].connections) == 1
|
assert (
|
||||||
and swarms[2].get_peer_id() in swarms[1].connections
|
len(swarms[1].connections) == 1
|
||||||
)
|
and swarms[2].get_peer_id() in swarms[1].connections
|
||||||
|
)
|
||||||
|
|
||||||
# peer 1 is closed by peer 2
|
# peer 1 is closed by peer 2
|
||||||
await swarms[2].close_peer(swarms[1].get_peer_id())
|
await swarms[2].close_peer(swarms[1].get_peer_id())
|
||||||
await asyncio.sleep(0.01)
|
await trio.sleep(0.01)
|
||||||
# 0 1 2
|
# 0 1 2
|
||||||
assert len(swarms[1].connections) == 0 and len(swarms[2].connections) == 0
|
assert len(swarms[1].connections) == 0 and len(swarms[2].connections) == 0
|
||||||
|
|
||||||
await connect_swarm(swarms[0], swarms[1])
|
await connect_swarm(swarms[0], swarms[1])
|
||||||
# 0 <> 1 2
|
# 0 <> 1 2
|
||||||
assert (
|
assert (
|
||||||
len(swarms[0].connections) == 1
|
len(swarms[0].connections) == 1
|
||||||
and swarms[1].get_peer_id() in swarms[0].connections
|
and swarms[1].get_peer_id() in swarms[0].connections
|
||||||
)
|
)
|
||||||
assert (
|
assert (
|
||||||
len(swarms[1].connections) == 1
|
len(swarms[1].connections) == 1
|
||||||
and swarms[0].get_peer_id() in swarms[1].connections
|
and swarms[0].get_peer_id() in swarms[1].connections
|
||||||
)
|
)
|
||||||
# peer 0 closes peer 1
|
# peer 0 closes peer 1
|
||||||
await swarms[0].close_peer(swarms[1].get_peer_id())
|
await swarms[0].close_peer(swarms[1].get_peer_id())
|
||||||
await asyncio.sleep(0.01)
|
await trio.sleep(0.01)
|
||||||
# 0 1 2
|
# 0 1 2
|
||||||
assert len(swarms[1].connections) == 0 and len(swarms[2].connections) == 0
|
assert len(swarms[1].connections) == 0 and len(swarms[2].connections) == 0
|
||||||
|
|
||||||
# Clean up
|
|
||||||
await asyncio.gather(*[swarm.close() for swarm in swarms])
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.trio
|
||||||
async def test_swarm_remove_conn(swarm_pair):
|
async def test_swarm_remove_conn(swarm_pair):
|
||||||
swarm_0, swarm_1 = swarm_pair
|
swarm_0, swarm_1 = swarm_pair
|
||||||
conn_0 = swarm_0.connections[swarm_1.get_peer_id()]
|
conn_0 = swarm_0.connections[swarm_1.get_peer_id()]
|
||||||
|
|
|
@ -7,23 +7,13 @@ from libp2p.tools.factories import mplex_conn_pair_factory, mplex_stream_pair_fa
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def mplex_conn_pair(is_host_secure):
|
async def mplex_conn_pair(is_host_secure):
|
||||||
mplex_conn_0, swarm_0, mplex_conn_1, swarm_1 = await mplex_conn_pair_factory(
|
async with mplex_conn_pair_factory(is_host_secure) as mplex_conn_pair:
|
||||||
is_host_secure
|
assert mplex_conn_pair[0].is_initiator
|
||||||
)
|
assert not mplex_conn_pair[1].is_initiator
|
||||||
assert mplex_conn_0.is_initiator
|
yield mplex_conn_pair[0], mplex_conn_pair[1]
|
||||||
assert not mplex_conn_1.is_initiator
|
|
||||||
try:
|
|
||||||
yield mplex_conn_0, mplex_conn_1
|
|
||||||
finally:
|
|
||||||
await asyncio.gather(*[swarm_0.close(), swarm_1.close()])
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def mplex_stream_pair(is_host_secure):
|
async def mplex_stream_pair(is_host_secure):
|
||||||
mplex_stream_0, swarm_0, mplex_stream_1, swarm_1 = await mplex_stream_pair_factory(
|
async with mplex_stream_pair_factory(is_host_secure) as mplex_stream_pair:
|
||||||
is_host_secure
|
yield mplex_stream_pair
|
||||||
)
|
|
||||||
try:
|
|
||||||
yield mplex_stream_0, mplex_stream_1
|
|
||||||
finally:
|
|
||||||
await asyncio.gather(*[swarm_0.close(), swarm_1.close()])
|
|
||||||
|
|
|
@ -1,5 +1,3 @@
|
||||||
import asyncio
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import trio
|
import trio
|
||||||
|
|
||||||
|
@ -12,25 +10,26 @@ from libp2p.tools.constants import MAX_READ_LEN, LISTEN_MADDR
|
||||||
from libp2p.tools.factories import SwarmFactory
|
from libp2p.tools.factories import SwarmFactory
|
||||||
from libp2p.tools.utils import connect_swarm
|
from libp2p.tools.utils import connect_swarm
|
||||||
|
|
||||||
|
|
||||||
DATA = b"data_123"
|
DATA = b"data_123"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.trio
|
@pytest.mark.trio
|
||||||
async def test_mplex_stream_read_write(nursery):
|
async def test_mplex_stream_read_write():
|
||||||
swarm0, swarm1 = SwarmFactory(), SwarmFactory()
|
async with SwarmFactory.create_batch_and_listen(False, 2) as swarms:
|
||||||
await swarm0.listen(LISTEN_MADDR, nursery=nursery)
|
await swarms[0].listen(LISTEN_MADDR)
|
||||||
await swarm1.listen(LISTEN_MADDR, nursery=nursery)
|
await swarms[1].listen(LISTEN_MADDR)
|
||||||
await connect_swarm(swarm0, swarm1, nursery)
|
await connect_swarm(swarms[0], swarms[1])
|
||||||
conn_0 = swarm0.connections[swarm1.get_peer_id()]
|
conn_0 = swarms[0].connections[swarms[1].get_peer_id()]
|
||||||
conn_1 = swarm1.connections[swarm0.get_peer_id()]
|
conn_1 = swarms[1].connections[swarms[0].get_peer_id()]
|
||||||
stream_0 = await conn_0.muxed_conn.open_stream()
|
stream_0 = await conn_0.muxed_conn.open_stream()
|
||||||
await trio.sleep(1)
|
await trio.sleep(1)
|
||||||
stream_1 = tuple(conn_1.muxed_conn.streams.values())[0]
|
stream_1 = tuple(conn_1.muxed_conn.streams.values())[0]
|
||||||
await stream_0.write(DATA)
|
await stream_0.write(DATA)
|
||||||
assert (await stream_1.read(MAX_READ_LEN)) == DATA
|
assert (await stream_1.read(MAX_READ_LEN)) == DATA
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.trio
|
||||||
async def test_mplex_stream_pair_read_until_eof(mplex_stream_pair):
|
async def test_mplex_stream_pair_read_until_eof(mplex_stream_pair):
|
||||||
read_bytes = bytearray()
|
read_bytes = bytearray()
|
||||||
stream_0, stream_1 = mplex_stream_pair
|
stream_0, stream_1 = mplex_stream_pair
|
||||||
|
@ -38,43 +37,43 @@ async def test_mplex_stream_pair_read_until_eof(mplex_stream_pair):
|
||||||
async def read_until_eof():
|
async def read_until_eof():
|
||||||
read_bytes.extend(await stream_1.read())
|
read_bytes.extend(await stream_1.read())
|
||||||
|
|
||||||
task = asyncio.ensure_future(read_until_eof())
|
task = trio.ensure_future(read_until_eof())
|
||||||
|
|
||||||
expected_data = bytearray()
|
expected_data = bytearray()
|
||||||
|
|
||||||
# Test: `read` doesn't return before `close` is called.
|
# Test: `read` doesn't return before `close` is called.
|
||||||
await stream_0.write(DATA)
|
await stream_0.write(DATA)
|
||||||
expected_data.extend(DATA)
|
expected_data.extend(DATA)
|
||||||
await asyncio.sleep(0.01)
|
await trio.sleep(0.01)
|
||||||
assert len(read_bytes) == 0
|
assert len(read_bytes) == 0
|
||||||
# Test: `read` doesn't return before `close` is called.
|
# Test: `read` doesn't return before `close` is called.
|
||||||
await stream_0.write(DATA)
|
await stream_0.write(DATA)
|
||||||
expected_data.extend(DATA)
|
expected_data.extend(DATA)
|
||||||
await asyncio.sleep(0.01)
|
await trio.sleep(0.01)
|
||||||
assert len(read_bytes) == 0
|
assert len(read_bytes) == 0
|
||||||
|
|
||||||
# Test: Close the stream, `read` returns, and receive previous sent data.
|
# Test: Close the stream, `read` returns, and receive previous sent data.
|
||||||
await stream_0.close()
|
await stream_0.close()
|
||||||
await asyncio.sleep(0.01)
|
await trio.sleep(0.01)
|
||||||
assert read_bytes == expected_data
|
assert read_bytes == expected_data
|
||||||
|
|
||||||
task.cancel()
|
task.cancel()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.trio
|
||||||
async def test_mplex_stream_read_after_remote_closed(mplex_stream_pair):
|
async def test_mplex_stream_read_after_remote_closed(mplex_stream_pair):
|
||||||
stream_0, stream_1 = mplex_stream_pair
|
stream_0, stream_1 = mplex_stream_pair
|
||||||
assert not stream_1.event_remote_closed.is_set()
|
assert not stream_1.event_remote_closed.is_set()
|
||||||
await stream_0.write(DATA)
|
await stream_0.write(DATA)
|
||||||
await stream_0.close()
|
await stream_0.close()
|
||||||
await asyncio.sleep(0.01)
|
await trio.sleep(0.01)
|
||||||
assert stream_1.event_remote_closed.is_set()
|
assert stream_1.event_remote_closed.is_set()
|
||||||
assert (await stream_1.read(MAX_READ_LEN)) == DATA
|
assert (await stream_1.read(MAX_READ_LEN)) == DATA
|
||||||
with pytest.raises(MplexStreamEOF):
|
with pytest.raises(MplexStreamEOF):
|
||||||
await stream_1.read(MAX_READ_LEN)
|
await stream_1.read(MAX_READ_LEN)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.trio
|
||||||
async def test_mplex_stream_read_after_local_reset(mplex_stream_pair):
|
async def test_mplex_stream_read_after_local_reset(mplex_stream_pair):
|
||||||
stream_0, stream_1 = mplex_stream_pair
|
stream_0, stream_1 = mplex_stream_pair
|
||||||
await stream_0.reset()
|
await stream_0.reset()
|
||||||
|
@ -82,29 +81,29 @@ async def test_mplex_stream_read_after_local_reset(mplex_stream_pair):
|
||||||
await stream_0.read(MAX_READ_LEN)
|
await stream_0.read(MAX_READ_LEN)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.trio
|
||||||
async def test_mplex_stream_read_after_remote_reset(mplex_stream_pair):
|
async def test_mplex_stream_read_after_remote_reset(mplex_stream_pair):
|
||||||
stream_0, stream_1 = mplex_stream_pair
|
stream_0, stream_1 = mplex_stream_pair
|
||||||
await stream_0.write(DATA)
|
await stream_0.write(DATA)
|
||||||
await stream_0.reset()
|
await stream_0.reset()
|
||||||
# Sleep to let `stream_1` receive the message.
|
# Sleep to let `stream_1` receive the message.
|
||||||
await asyncio.sleep(0.01)
|
await trio.sleep(0.01)
|
||||||
with pytest.raises(MplexStreamReset):
|
with pytest.raises(MplexStreamReset):
|
||||||
await stream_1.read(MAX_READ_LEN)
|
await stream_1.read(MAX_READ_LEN)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.trio
|
||||||
async def test_mplex_stream_read_after_remote_closed_and_reset(mplex_stream_pair):
|
async def test_mplex_stream_read_after_remote_closed_and_reset(mplex_stream_pair):
|
||||||
stream_0, stream_1 = mplex_stream_pair
|
stream_0, stream_1 = mplex_stream_pair
|
||||||
await stream_0.write(DATA)
|
await stream_0.write(DATA)
|
||||||
await stream_0.close()
|
await stream_0.close()
|
||||||
await stream_0.reset()
|
await stream_0.reset()
|
||||||
# Sleep to let `stream_1` receive the message.
|
# Sleep to let `stream_1` receive the message.
|
||||||
await asyncio.sleep(0.01)
|
await trio.sleep(0.01)
|
||||||
assert (await stream_1.read(MAX_READ_LEN)) == DATA
|
assert (await stream_1.read(MAX_READ_LEN)) == DATA
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.trio
|
||||||
async def test_mplex_stream_write_after_local_closed(mplex_stream_pair):
|
async def test_mplex_stream_write_after_local_closed(mplex_stream_pair):
|
||||||
stream_0, stream_1 = mplex_stream_pair
|
stream_0, stream_1 = mplex_stream_pair
|
||||||
await stream_0.write(DATA)
|
await stream_0.write(DATA)
|
||||||
|
@ -113,7 +112,7 @@ async def test_mplex_stream_write_after_local_closed(mplex_stream_pair):
|
||||||
await stream_0.write(DATA)
|
await stream_0.write(DATA)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.trio
|
||||||
async def test_mplex_stream_write_after_local_reset(mplex_stream_pair):
|
async def test_mplex_stream_write_after_local_reset(mplex_stream_pair):
|
||||||
stream_0, stream_1 = mplex_stream_pair
|
stream_0, stream_1 = mplex_stream_pair
|
||||||
await stream_0.reset()
|
await stream_0.reset()
|
||||||
|
@ -121,16 +120,16 @@ async def test_mplex_stream_write_after_local_reset(mplex_stream_pair):
|
||||||
await stream_0.write(DATA)
|
await stream_0.write(DATA)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.trio
|
||||||
async def test_mplex_stream_write_after_remote_reset(mplex_stream_pair):
|
async def test_mplex_stream_write_after_remote_reset(mplex_stream_pair):
|
||||||
stream_0, stream_1 = mplex_stream_pair
|
stream_0, stream_1 = mplex_stream_pair
|
||||||
await stream_1.reset()
|
await stream_1.reset()
|
||||||
await asyncio.sleep(0.01)
|
await trio.sleep(0.01)
|
||||||
with pytest.raises(MplexStreamClosed):
|
with pytest.raises(MplexStreamClosed):
|
||||||
await stream_0.write(DATA)
|
await stream_0.write(DATA)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.trio
|
||||||
async def test_mplex_stream_both_close(mplex_stream_pair):
|
async def test_mplex_stream_both_close(mplex_stream_pair):
|
||||||
stream_0, stream_1 = mplex_stream_pair
|
stream_0, stream_1 = mplex_stream_pair
|
||||||
# Flags are not set initially.
|
# Flags are not set initially.
|
||||||
|
@ -144,7 +143,7 @@ async def test_mplex_stream_both_close(mplex_stream_pair):
|
||||||
|
|
||||||
# Test: Close one side.
|
# Test: Close one side.
|
||||||
await stream_0.close()
|
await stream_0.close()
|
||||||
await asyncio.sleep(0.01)
|
await trio.sleep(0.01)
|
||||||
|
|
||||||
assert stream_0.event_local_closed.is_set()
|
assert stream_0.event_local_closed.is_set()
|
||||||
assert not stream_1.event_local_closed.is_set()
|
assert not stream_1.event_local_closed.is_set()
|
||||||
|
@ -156,7 +155,7 @@ async def test_mplex_stream_both_close(mplex_stream_pair):
|
||||||
|
|
||||||
# Test: Close the other side.
|
# Test: Close the other side.
|
||||||
await stream_1.close()
|
await stream_1.close()
|
||||||
await asyncio.sleep(0.01)
|
await trio.sleep(0.01)
|
||||||
# Both sides are closed.
|
# Both sides are closed.
|
||||||
assert stream_0.event_local_closed.is_set()
|
assert stream_0.event_local_closed.is_set()
|
||||||
assert stream_1.event_local_closed.is_set()
|
assert stream_1.event_local_closed.is_set()
|
||||||
|
@ -170,11 +169,11 @@ async def test_mplex_stream_both_close(mplex_stream_pair):
|
||||||
await stream_0.reset()
|
await stream_0.reset()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.trio
|
||||||
async def test_mplex_stream_reset(mplex_stream_pair):
|
async def test_mplex_stream_reset(mplex_stream_pair):
|
||||||
stream_0, stream_1 = mplex_stream_pair
|
stream_0, stream_1 = mplex_stream_pair
|
||||||
await stream_0.reset()
|
await stream_0.reset()
|
||||||
await asyncio.sleep(0.01)
|
await trio.sleep(0.01)
|
||||||
|
|
||||||
# Both sides are closed.
|
# Both sides are closed.
|
||||||
assert stream_0.event_local_closed.is_set()
|
assert stream_0.event_local_closed.is_set()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user