From 1929f307fb64b6d526b0e5c07160dd53cd3cbfcf Mon Sep 17 00:00:00 2001 From: mhchia Date: Fri, 6 Dec 2019 17:06:37 +0800 Subject: [PATCH] Fix all modules except for security --- libp2p/host/ping.py | 3 +- libp2p/io/trio.py | 52 +- libp2p/network/connection/raw_connection.py | 14 +- libp2p/network/connection/swarm_connection.py | 13 +- libp2p/network/swarm.py | 12 +- libp2p/pubsub/floodsub.py | 6 +- libp2p/pubsub/gossipsub.py | 33 +- libp2p/pubsub/pubsub.py | 36 +- libp2p/stream_muxer/abc.py | 6 +- libp2p/stream_muxer/mplex/mplex.py | 66 +- libp2p/tools/factories.py | 155 +++-- libp2p/tools/pubsub/dummy_account_node.py | 51 +- .../floodsub_integration_test_settings.py | 151 ++-- libp2p/tools/utils.py | 56 +- libp2p/transport/listener_interface.py | 12 +- libp2p/transport/tcp/tcp.py | 42 +- tests/host/test_ping.py | 2 +- tests/host/test_routed_host.py | 73 -- tests/protocol_muxer/test_protocol_muxer.py | 1 - tests/pubsub/conftest.py | 22 - tests/pubsub/test_dummyaccount_demo.py | 95 +-- tests/pubsub/test_floodsub.py | 117 ++-- tests/pubsub/test_gossipsub.py | 646 +++++++++--------- .../test_gossipsub_backward_compatibility.py | 16 +- tests/pubsub/test_mcache.py | 5 +- tests/pubsub/test_pubsub.py | 9 +- tests/stream_muxer/test_mplex_stream.py | 23 - tests/transport/test_tcp.py | 2 +- 28 files changed, 764 insertions(+), 955 deletions(-) delete mode 100644 tests/host/test_routed_host.py delete mode 100644 tests/pubsub/conftest.py diff --git a/libp2p/host/ping.py b/libp2p/host/ping.py index 589fc91..9e23f1c 100644 --- a/libp2p/host/ping.py +++ b/libp2p/host/ping.py @@ -1,6 +1,7 @@ -import trio import logging +import trio + from libp2p.network.stream.exceptions import StreamClosed, StreamEOF, StreamReset from libp2p.network.stream.net_stream_interface import INetStream from libp2p.peer.id import ID as PeerID diff --git a/libp2p/io/trio.py b/libp2p/io/trio.py index e74e9ed..840c3bc 100644 --- a/libp2p/io/trio.py +++ b/libp2p/io/trio.py @@ -1,7 +1,6 @@ import logging import trio -from trio import SocketStream from libp2p.io.abc import ReadWriteCloser from libp2p.io.exceptions import IOException @@ -9,29 +8,48 @@ from libp2p.io.exceptions import IOException logger = logging.getLogger("libp2p.io.trio") -class TrioReadWriteCloser(ReadWriteCloser): - stream: SocketStream +class TrioTCPStream(ReadWriteCloser): + stream: trio.SocketStream + # NOTE: Add both read and write lock to avoid `trio.BusyResourceError` + read_lock: trio.Lock + write_lock: trio.Lock - def __init__(self, stream: SocketStream) -> None: + def __init__(self, stream: trio.SocketStream) -> None: self.stream = stream + self.read_lock = trio.Lock() + self.write_lock = trio.Lock() async def write(self, data: bytes) -> None: """Raise `RawConnError` if the underlying connection breaks.""" - try: - await self.stream.send_all(data) - except (trio.ClosedResourceError, trio.BrokenResourceError) as error: - raise IOException(error) + async with self.write_lock: + try: + await self.stream.send_all(data) + except (trio.ClosedResourceError, trio.BrokenResourceError) as error: + raise IOException from error + except trio.BusyResourceError as error: + # This should never happen, since we already access streams with read/write locks. + raise Exception( + "this should never happen " + "since we already access streams with read/write locks." + ) from error async def read(self, n: int = -1) -> bytes: - if n == 0: - # Check point - await trio.sleep(0) - return b"" - max_bytes = n if n != -1 else None - try: - return await self.stream.receive_some(max_bytes) - except (trio.ClosedResourceError, trio.BrokenResourceError) as error: - raise IOException(error) + async with self.read_lock: + if n == 0: + # Checkpoint + await trio.hazmat.checkpoint() + return b"" + max_bytes = n if n != -1 else None + try: + return await self.stream.receive_some(max_bytes) + except (trio.ClosedResourceError, trio.BrokenResourceError) as error: + raise IOException from error + except trio.BusyResourceError as error: + # This should never happen, since we already access streams with read/write locks. + raise Exception( + "this should never happen " + "since we already access streams with read/write locks." + ) from error async def close(self) -> None: await self.stream.aclose() diff --git a/libp2p/network/connection/raw_connection.py b/libp2p/network/connection/raw_connection.py index 2bdb3b1..25b1049 100644 --- a/libp2p/network/connection/raw_connection.py +++ b/libp2p/network/connection/raw_connection.py @@ -1,5 +1,3 @@ -import trio - from libp2p.io.abc import ReadWriteCloser from libp2p.io.exceptions import IOException @@ -8,17 +6,17 @@ from .raw_connection_interface import IRawConnection class RawConnection(IRawConnection): - read_write_closer: ReadWriteCloser + stream: ReadWriteCloser is_initiator: bool - def __init__(self, read_write_closer: ReadWriteCloser, initiator: bool) -> None: - self.read_write_closer = read_write_closer + def __init__(self, stream: ReadWriteCloser, initiator: bool) -> None: + self.stream = stream self.is_initiator = initiator async def write(self, data: bytes) -> None: """Raise `RawConnError` if the underlying connection breaks.""" try: - await self.read_write_closer.write(data) + await self.stream.write(data) except IOException as error: raise RawConnError(error) @@ -30,9 +28,9 @@ class RawConnection(IRawConnection): Raise `RawConnError` if the underlying connection breaks """ try: - return await self.read_write_closer.read(n) + return await self.stream.read(n) except IOException as error: raise RawConnError(error) async def close(self) -> None: - await self.read_write_closer.close() + await self.stream.close() diff --git a/libp2p/network/connection/swarm_connection.py b/libp2p/network/connection/swarm_connection.py index 48774ec..1e31033 100644 --- a/libp2p/network/connection/swarm_connection.py +++ b/libp2p/network/connection/swarm_connection.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, Awaitable, List, Set, Tuple +from typing import TYPE_CHECKING, Set, Tuple from async_service import Service import trio @@ -45,16 +45,11 @@ class SwarmConn(INetConn, Service): # before we cancel the stream handler tasks. await trio.sleep(0.1) - # FIXME: Now let `_notify_disconnected` finish first. - # Schedule `self._notify_disconnected` to make it execute after `close` is finished. await self._notify_disconnected() async def _handle_new_streams(self) -> None: while self.manager.is_running: try: - print( - f"!@# SwarmConn._handle_new_streams: {self.muxed_conn._id}: waiting for new streams" - ) stream = await self.muxed_conn.accept_stream() except MuxedConnUnavailable: # If there is anything wrong in the MuxedConn, @@ -63,9 +58,6 @@ class SwarmConn(INetConn, Service): # Asynchronously handle the accepted stream, to avoid blocking the next stream. self.manager.run_task(self._handle_muxed_stream, stream) - print( - f"!@# SwarmConn._handle_new_streams: {self.muxed_conn._id}: out of the loop" - ) await self.close() async def _call_stream_handler(self, net_stream: NetStream) -> None: @@ -92,8 +84,7 @@ class SwarmConn(INetConn, Service): await self.swarm.notify_disconnected(self) async def run(self) -> None: - self.manager.run_task(self._handle_new_streams) - await self.manager.wait_finished() + await self._handle_new_streams() async def new_stream(self) -> NetStream: muxed_stream = await self.muxed_conn.open_stream() diff --git a/libp2p/network/swarm.py b/libp2p/network/swarm.py index 37614bc..4bf86dd 100644 --- a/libp2p/network/swarm.py +++ b/libp2p/network/swarm.py @@ -203,16 +203,17 @@ class Swarm(INetwork, Service): await self.add_conn(muxed_conn) logger.debug("successfully opened connection to peer %s", peer_id) - # FIXME: This is a intentional barrier to prevent from the handler exiting and - # closing the connection. Probably change to `Service.manager.wait_finished`? - await trio.sleep_forever() + # NOTE: This is a intentional barrier to prevent from the handler exiting and + # closing the connection. + await self.manager.wait_finished() try: # Success listener = self.transport.create_listener(conn_handler) self.listeners[str(maddr)] = listener - # FIXME: Hack - await listener.listen(maddr, self.manager._task_nursery) + # TODO: `listener.listen` is not bounded with nursery. If we want to be + # I/O agnostic, we should change the API. + await listener.listen(maddr, self.manager._task_nursery) # type: ignore # Call notifiers since event occurred await self.notify_listen(maddr) @@ -278,6 +279,7 @@ class Swarm(INetwork, Service): """ self.notifees.append(notifee) + # TODO: Use `run_task`. async def notify_opened_stream(self, stream: INetStream) -> None: async with trio.open_nursery() as nursery: for notifee in self.notifees: diff --git a/libp2p/pubsub/floodsub.py b/libp2p/pubsub/floodsub.py index 8c15a44..9e323eb 100644 --- a/libp2p/pubsub/floodsub.py +++ b/libp2p/pubsub/floodsub.py @@ -64,7 +64,7 @@ class FloodSub(IPubsubRouter): :param rpc: rpc message """ # Checkpoint - await trio.sleep(0) + await trio.hazmat.checkpoint() async def publish(self, msg_forwarder: ID, pubsub_msg: rpc_pb2.Message) -> None: """ @@ -107,7 +107,7 @@ class FloodSub(IPubsubRouter): :param topic: topic to join """ # Checkpoint - await trio.sleep(0) + await trio.hazmat.checkpoint() async def leave(self, topic: str) -> None: """ @@ -117,7 +117,7 @@ class FloodSub(IPubsubRouter): :param topic: topic to leave """ # Checkpoint - await trio.sleep(0) + await trio.hazmat.checkpoint() def _get_peers_to_send( self, topic_ids: Iterable[str], msg_forwarder: ID, origin: ID diff --git a/libp2p/pubsub/gossipsub.py b/libp2p/pubsub/gossipsub.py index 93faebd..df0f83f 100644 --- a/libp2p/pubsub/gossipsub.py +++ b/libp2p/pubsub/gossipsub.py @@ -1,15 +1,18 @@ from ast import literal_eval -import asyncio import logging import random from typing import Any, Dict, Iterable, List, Sequence, Set +from async_service import Service +import trio + from libp2p.network.stream.exceptions import StreamClosed from libp2p.peer.id import ID from libp2p.pubsub import floodsub from libp2p.typing import TProtocol from libp2p.utils import encode_varint_prefixed +from .exceptions import NoPubsubAttached from .mcache import MessageCache from .pb import rpc_pb2 from .pubsub import Pubsub @@ -20,8 +23,7 @@ PROTOCOL_ID = TProtocol("/meshsub/1.0.0") logger = logging.getLogger("libp2p.pubsub.gossipsub") -class GossipSub(IPubsubRouter): - +class GossipSub(IPubsubRouter, Service): protocols: List[TProtocol] pubsub: Pubsub @@ -86,6 +88,12 @@ class GossipSub(IPubsubRouter): # Create heartbeat timer self.heartbeat_interval = heartbeat_interval + async def run(self) -> None: + if self.pubsub is None: + raise NoPubsubAttached + self.manager.run_task(self.heartbeat) + await self.manager.wait_finished() + # Interface functions def get_protocols(self) -> List[TProtocol]: @@ -105,10 +113,6 @@ class GossipSub(IPubsubRouter): logger.debug("attached to pusub") - # Start heartbeat now that we have a pubsub instance - # TODO: Start after delay - asyncio.ensure_future(self.heartbeat()) - def add_peer(self, peer_id: ID, protocol_id: TProtocol) -> None: """ Notifies the router that a new peer has been connected. @@ -310,7 +314,7 @@ class GossipSub(IPubsubRouter): await self.fanout_heartbeat() await self.gossip_heartbeat() - await asyncio.sleep(self.heartbeat_interval) + await trio.sleep(self.heartbeat_interval) async def mesh_heartbeat(self) -> None: # Note: the comments here are the exact pseudocode from the spec @@ -338,7 +342,7 @@ class GossipSub(IPubsubRouter): if num_mesh_peers_in_topic > self.degree_high: # Select |mesh[topic]| - D peers from mesh[topic] - selected_peers = GossipSub.select_from_minus( + selected_peers = self.select_from_minus( num_mesh_peers_in_topic - self.degree, self.mesh[topic], [] ) for peer in selected_peers: @@ -353,7 +357,10 @@ class GossipSub(IPubsubRouter): for topic in self.fanout: # If time since last published > ttl # TODO: there's no way time_since_last_publish gets set anywhere yet - if self.time_since_last_publish[topic] > self.time_to_live: + if ( + topic in self.time_since_last_publish + and self.time_since_last_publish[topic] > self.time_to_live + ): # Remove topic from fanout del self.fanout[topic] del self.time_since_last_publish[topic] @@ -407,11 +414,7 @@ class GossipSub(IPubsubRouter): topic, self.degree, [] ) for peer in peers_to_emit_ihave_to: - if ( - peer not in self.mesh[topic] - and peer not in self.fanout[topic] - ): - + if peer not in self.fanout[topic]: msg_id_strs = [str(msg) for msg in msg_ids] await self.emit_ihave(topic, msg_id_strs, peer) diff --git a/libp2p/pubsub/pubsub.py b/libp2p/pubsub/pubsub.py index 7c4b50d..3370ea3 100644 --- a/libp2p/pubsub/pubsub.py +++ b/libp2p/pubsub/pubsub.py @@ -1,4 +1,4 @@ -from abc import ABC, abstractmethod +from abc import ABC import logging import math import time @@ -57,6 +57,7 @@ class TopicValidator(NamedTuple): is_async: bool +# TODO: Add interface for Pubsub class BasePubsub(ABC): pass @@ -103,20 +104,24 @@ class Pubsub(BasePubsub, Service): # Attach this new Pubsub object to the router self.router.attach(self) - peer_send_channel, peer_receive_channel = trio.open_memory_channel(0) - dead_peer_send_channel, dead_peer_receive_channel = trio.open_memory_channel(0) + peer_channels: Tuple[ + "trio.MemorySendChannel[ID]", "trio.MemoryReceiveChannel[ID]" + ] = trio.open_memory_channel(0) + dead_peer_channels: Tuple[ + "trio.MemorySendChannel[ID]", "trio.MemoryReceiveChannel[ID]" + ] = trio.open_memory_channel(0) # Only keep the receive channels in `Pubsub`. # Therefore, we can only close from the receive side. - self.peer_receive_channel = peer_receive_channel - self.dead_peer_receive_channel = dead_peer_receive_channel + self.peer_receive_channel = peer_channels[1] + self.dead_peer_receive_channel = dead_peer_channels[1] # Register a notifee self.host.get_network().register_notifee( - PubsubNotifee(peer_send_channel, dead_peer_send_channel) + PubsubNotifee(peer_channels[0], dead_peer_channels[0]) ) # Register stream handlers for each pubsub router protocol to handle # the pubsub streams opened on those protocols - for protocol in router.protocols: + for protocol in router.get_protocols(): self.host.set_stream_handler(protocol, self.stream_handler) # keeps track of seen messages as LRU cache @@ -328,8 +333,9 @@ class Pubsub(BasePubsub, Service): self.manager.run_task(self._handle_new_peer, peer_id) async def handle_dead_peer_queue(self) -> None: - """Continuously read from dead peer channel and close the stream between - that peer and remove peer info from pubsub and pubsub router.""" + """Continuously read from dead peer channel and close the stream + between that peer and remove peer info from pubsub and pubsub + router.""" async with self.dead_peer_receive_channel: while self.manager.is_running: peer_id: ID = await self.dead_peer_receive_channel.receive() @@ -391,7 +397,11 @@ class Pubsub(BasePubsub, Service): return self.subscribed_topics_receive[topic_id] # Map topic_id to a blocking channel - send_channel, receive_channel = trio.open_memory_channel(math.inf) + channels: Tuple[ + "trio.MemorySendChannel[rpc_pb2.Message]", + "trio.MemoryReceiveChannel[rpc_pb2.Message]", + ] = trio.open_memory_channel(math.inf) + send_channel, receive_channel = channels self.subscribed_topics_send[topic_id] = send_channel self.subscribed_topics_receive[topic_id] = receive_channel @@ -506,7 +516,7 @@ class Pubsub(BasePubsub, Service): if len(async_topic_validators) > 0: # TODO: Use a better pattern - final_result = True + final_result: bool = True async def run_async_validator(func: AsyncValidatorFn) -> None: nonlocal final_result @@ -514,8 +524,8 @@ class Pubsub(BasePubsub, Service): final_result = final_result and result async with trio.open_nursery() as nursery: - for validator in async_topic_validators: - nursery.start_soon(run_async_validator, validator) + for async_validator in async_topic_validators: + nursery.start_soon(run_async_validator, async_validator) if not final_result: raise ValidationError(f"Validation failed for msg={msg}") diff --git a/libp2p/stream_muxer/abc.py b/libp2p/stream_muxer/abc.py index 71704c1..12a8f80 100644 --- a/libp2p/stream_muxer/abc.py +++ b/libp2p/stream_muxer/abc.py @@ -1,11 +1,13 @@ -from abc import ABC, abstractmethod +from abc import abstractmethod + +from async_service import ServiceAPI from libp2p.io.abc import ReadWriteCloser from libp2p.peer.id import ID from libp2p.security.secure_conn_interface import ISecureConn -class IMuxedConn(ABC): +class IMuxedConn(ServiceAPI): """ reference: https://github.com/libp2p/go-stream-muxer/blob/master/muxer.go """ diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index ac6cdcd..e23da00 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -1,7 +1,6 @@ import logging import math -from typing import Any # noqa: F401 -from typing import Awaitable, Dict, List, Optional, Tuple +from typing import Dict, Optional, Tuple from async_service import Service import trio @@ -67,13 +66,15 @@ class Mplex(IMuxedConn, Service): self.streams = {} self.streams_lock = trio.Lock() self.streams_msg_channels = {} - send_channel, receive_channel = trio.open_memory_channel(math.inf) - self.new_stream_send_channel = send_channel - self.new_stream_receive_channel = receive_channel + channels: Tuple[ + "trio.MemorySendChannel[IMuxedStream]", + "trio.MemoryReceiveChannel[IMuxedStream]", + ] = trio.open_memory_channel(math.inf) + self.new_stream_send_channel, self.new_stream_receive_channel = channels self.event_shutting_down = trio.Event() self.event_closed = trio.Event() - async def run(self): + async def run(self) -> None: self.manager.run_task(self.handle_incoming) await self.manager.wait_finished() @@ -112,11 +113,13 @@ class Mplex(IMuxedConn, Service): async def _initialize_stream(self, stream_id: StreamID, name: str) -> MplexStream: # Use an unbounded buffer, to avoid `handle_incoming` being blocked when doing # `send_channel.send`. - send_channel, receive_channel = trio.open_memory_channel(math.inf) - stream = MplexStream(name, stream_id, self, receive_channel) + channels: Tuple[ + "trio.MemorySendChannel[bytes]", "trio.MemoryReceiveChannel[bytes]" + ] = trio.open_memory_channel(math.inf) + stream = MplexStream(name, stream_id, self, channels[1]) async with self.streams_lock: self.streams[stream_id] = stream - self.streams_msg_channels[stream_id] = send_channel + self.streams_msg_channels[stream_id] = channels[0] return stream async def open_stream(self) -> IMuxedStream: @@ -150,9 +153,6 @@ class Mplex(IMuxedConn, Service): :param data: data to send in the message :param stream_id: stream the message is in """ - print( - f"!@# send_message: {self._id}: flag={flag}, data={data}, stream_id={stream_id}" - ) # << by 3, then or with flag header = encode_uvarint((stream_id.channel_id << 3) | flag.value) @@ -179,19 +179,10 @@ class Mplex(IMuxedConn, Service): while self.manager.is_running: try: - print( - f"!@# handle_incoming: {self._id}: before _handle_incoming_message" - ) await self._handle_incoming_message() - print( - f"!@# handle_incoming: {self._id}: after _handle_incoming_message" - ) except MplexUnavailable as e: logger.debug("mplex unavailable while waiting for incoming: %s", e) - print(f"!@# handle_incoming: {self._id}: MplexUnavailable: {e}") break - - print(f"!@# handle_incoming: {self._id}: leaving") # If we enter here, it means this connection is shutting down. # We should clean things up. await self._cleanup() @@ -232,44 +223,27 @@ class Mplex(IMuxedConn, Service): :raise MplexUnavailable: `Mplex` encounters fatal error or is shutting down. """ - print(f"!@# _handle_incoming_message: {self._id}: before reading") channel_id, flag, message = await self.read_message() - print( - f"!@# _handle_incoming_message: {self._id}: channel_id={channel_id}, flag={flag}, message={message}" - ) stream_id = StreamID(channel_id=channel_id, is_initiator=bool(flag & 1)) - print(f"!@# _handle_incoming_message: {self._id}: 2") if flag == HeaderTags.NewStream.value: - print(f"!@# _handle_incoming_message: {self._id}: 3") await self._handle_new_stream(stream_id, message) - print(f"!@# _handle_incoming_message: {self._id}: 4") elif flag in ( HeaderTags.MessageInitiator.value, HeaderTags.MessageReceiver.value, ): - print(f"!@# _handle_incoming_message: {self._id}: 5") await self._handle_message(stream_id, message) - print(f"!@# _handle_incoming_message: {self._id}: 6") elif flag in (HeaderTags.CloseInitiator.value, HeaderTags.CloseReceiver.value): - print(f"!@# _handle_incoming_message: {self._id}: 7") await self._handle_close(stream_id) - print(f"!@# _handle_incoming_message: {self._id}: 8") elif flag in (HeaderTags.ResetInitiator.value, HeaderTags.ResetReceiver.value): - print(f"!@# _handle_incoming_message: {self._id}: 9") await self._handle_reset(stream_id) - print(f"!@# _handle_incoming_message: {self._id}: 10") else: - print(f"!@# _handle_incoming_message: {self._id}: 11") # Receives messages with an unknown flag # TODO: logging async with self.streams_lock: - print(f"!@# _handle_incoming_message: {self._id}: 12") if stream_id in self.streams: - print(f"!@# _handle_incoming_message: {self._id}: 13") stream = self.streams[stream_id] await stream.reset() - print(f"!@# _handle_incoming_message: {self._id}: 14") async def _handle_new_stream(self, stream_id: StreamID, message: bytes) -> None: async with self.streams_lock: @@ -285,59 +259,43 @@ class Mplex(IMuxedConn, Service): raise MplexUnavailable async def _handle_message(self, stream_id: StreamID, message: bytes) -> None: - print( - f"!@# _handle_message: {self._id}: stream_id={stream_id}, message={message}" - ) async with self.streams_lock: - print(f"!@# _handle_message: {self._id}: 1") if stream_id not in self.streams: # We receive a message of the stream `stream_id` which is not accepted # before. It is abnormal. Possibly disconnect? # TODO: Warn and emit logs about this. - print(f"!@# _handle_message: {self._id}: 2") return - print(f"!@# _handle_message: {self._id}: 3") stream = self.streams[stream_id] send_channel = self.streams_msg_channels[stream_id] async with stream.close_lock: - print(f"!@# _handle_message: {self._id}: 4") if stream.event_remote_closed.is_set(): - print(f"!@# _handle_message: {self._id}: 5") # TODO: Warn "Received data from remote after stream was closed by them. (len = %d)" # noqa: E501 return - print(f"!@# _handle_message: {self._id}: 6") await send_channel.send(message) - print(f"!@# _handle_message: {self._id}: 7") async def _handle_close(self, stream_id: StreamID) -> None: - print(f"!@# _handle_close: {self._id}: step=0") async with self.streams_lock: if stream_id not in self.streams: # Ignore unmatched messages for now. return stream = self.streams[stream_id] send_channel = self.streams_msg_channels[stream_id] - print(f"!@# _handle_close: {self._id}: step=1") await send_channel.aclose() - print(f"!@# _handle_close: {self._id}: step=2") # NOTE: If remote is already closed, then return: Technically a bug # on the other side. We should consider killing the connection. async with stream.close_lock: if stream.event_remote_closed.is_set(): return - print(f"!@# _handle_close: {self._id}: step=3") is_local_closed: bool async with stream.close_lock: stream.event_remote_closed.set() is_local_closed = stream.event_local_closed.is_set() - print(f"!@# _handle_close: {self._id}: step=4") # If local is also closed, both sides are closed. Then, we should clean up # the entry of this stream, to avoid others from accessing it. if is_local_closed: async with self.streams_lock: if stream_id in self.streams: del self.streams[stream_id] - print(f"!@# _handle_close: {self._id}: step=5") async def _handle_reset(self, stream_id: StreamID) -> None: async with self.streams_lock: diff --git a/libp2p/tools/factories.py b/libp2p/tools/factories.py index ac24301..568a276 100644 --- a/libp2p/tools/factories.py +++ b/libp2p/tools/factories.py @@ -1,30 +1,29 @@ from contextlib import AsyncExitStack, asynccontextmanager -from typing import Any, AsyncIterator, Dict, Tuple, cast +from typing import Any, AsyncIterator, Dict, Sequence, Tuple, cast from async_service import background_trio_service import factory import trio -from libp2p.tools.constants import GOSSIPSUB_PARAMS from libp2p import generate_new_rsa_identity, generate_peer_id_from from libp2p.crypto.keys import KeyPair from libp2p.host.basic_host import BasicHost -from libp2p.host.routed_host import RoutedHost -from libp2p.tools.utils import set_up_routers -from libp2p.kademlia.network import KademliaServer +from libp2p.host.host_interface import IHost from libp2p.network.connection.swarm_connection import SwarmConn from libp2p.network.stream.net_stream_interface import INetStream from libp2p.network.swarm import Swarm -from libp2p.peer.peerstore import PeerStore from libp2p.peer.id import ID +from libp2p.peer.peerstore import PeerStore from libp2p.pubsub.floodsub import FloodSub from libp2p.pubsub.gossipsub import GossipSub from libp2p.pubsub.pubsub import Pubsub +from libp2p.pubsub.pubsub_router_interface import IPubsubRouter from libp2p.security.base_transport import BaseSecureTransport from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport import libp2p.security.secio.transport as secio from libp2p.stream_muxer.mplex.mplex import MPLEX_PROTOCOL_ID, Mplex from libp2p.stream_muxer.mplex.mplex_stream import MplexStream +from libp2p.tools.constants import GOSSIPSUB_PARAMS from libp2p.transport.tcp.tcp import TCP from libp2p.transport.typing import TMuxerOptions from libp2p.transport.upgrader import TransportUpgrader @@ -74,7 +73,7 @@ class SwarmFactory(factory.Factory): @asynccontextmanager async def create_and_listen( cls, is_secure: bool, key_pair: KeyPair = None, muxer_opt: TMuxerOptions = None - ) -> Swarm: + ) -> AsyncIterator[Swarm]: # `factory.Factory.__init__` does *not* prepare a *default value* if we pass # an argument explicitly with `None`. If an argument is `None`, we don't pass it to # `factory.Factory.__init__`, in order to let the function initialize it. @@ -92,7 +91,7 @@ class SwarmFactory(factory.Factory): @asynccontextmanager async def create_batch_and_listen( cls, is_secure: bool, number: int, muxer_opt: TMuxerOptions = None - ) -> Tuple[Swarm, ...]: + ) -> AsyncIterator[Tuple[Swarm, ...]]: async with AsyncExitStack() as stack: ctx_mgrs = [ await stack.enter_async_context( @@ -100,7 +99,7 @@ class SwarmFactory(factory.Factory): ) for _ in range(number) ] - yield ctx_mgrs + yield tuple(ctx_mgrs) class HostFactory(factory.Factory): @@ -120,7 +119,7 @@ class HostFactory(factory.Factory): @asynccontextmanager async def create_batch_and_listen( cls, is_secure: bool, number: int - ) -> Tuple[BasicHost, ...]: + ) -> AsyncIterator[Tuple[BasicHost, ...]]: key_pairs = [generate_new_rsa_identity() for _ in range(number)] async with AsyncExitStack() as stack: swarms = [ @@ -136,30 +135,6 @@ class HostFactory(factory.Factory): yield hosts -class RoutedHostFactory(factory.Factory): - class Meta: - model = RoutedHost - - public_key = factory.LazyAttribute(lambda o: o.key_pair.public_key) - network = factory.LazyAttribute( - lambda o: SwarmFactory(is_secure=o.is_secure, key_pair=o.key_pair) - ) - router = factory.LazyFunction(KademliaServer) - - @classmethod - @asynccontextmanager - async def create_batch_and_listen( - cls, is_secure: bool, number: int - ) -> Tuple[RoutedHost, ...]: - key_pairs = [generate_new_rsa_identity() for _ in range(number)] - routers = await set_up_routers((0,) * number) - async with SwarmFactory.create_batch_and_listen(is_secure, number) as swarms: - yield tuple( - RoutedHost(key_pair.public_key, swarm, router) - for key_pair, swarm, router in zip(key_pairs, swarms, routers) - ) - - class FloodsubFactory(factory.Factory): class Meta: model = FloodSub @@ -191,17 +166,22 @@ class PubsubFactory(factory.Factory): @classmethod @asynccontextmanager - async def create_and_start(cls, host, router, cache_size): + async def create_and_start( + cls, host: IHost, router: IPubsubRouter, cache_size: int + ) -> AsyncIterator[Pubsub]: pubsub = PubsubFactory(host=host, router=router, cache_size=cache_size) async with background_trio_service(pubsub): yield pubsub @classmethod @asynccontextmanager - async def create_batch_with_floodsub( - cls, number: int, is_secure: bool = False, cache_size: int = None - ): - floodsubs = FloodsubFactory.create_batch(number) + async def _create_batch_with_router( + cls, + number: int, + routers: Sequence[IPubsubRouter], + is_secure: bool = False, + cache_size: int = None, + ) -> AsyncIterator[Tuple[Pubsub, ...]]: async with HostFactory.create_batch_and_listen(is_secure, number) as hosts: # Pubsubs should exit before hosts async with AsyncExitStack() as stack: @@ -209,21 +189,80 @@ class PubsubFactory(factory.Factory): await stack.enter_async_context( cls.create_and_start(host, router, cache_size) ) - for host, router in zip(hosts, floodsubs) + for host, router in zip(hosts, routers) ] - yield pubsubs + yield tuple(pubsubs) - # @classmethod - # async def create_batch_with_gossipsub( - # cls, number: int, cache_size: int = None, gossipsub_params=GOSSIPSUB_PARAMS - # ): - # ... + @classmethod + @asynccontextmanager + async def create_batch_with_floodsub( + cls, + number: int, + is_secure: bool = False, + cache_size: int = None, + protocols: Sequence[TProtocol] = None, + ) -> AsyncIterator[Tuple[Pubsub, ...]]: + if protocols is not None: + floodsubs = FloodsubFactory.create_batch(number, protocols=list(protocols)) + else: + floodsubs = FloodsubFactory.create_batch(number) + async with cls._create_batch_with_router( + number, floodsubs, is_secure, cache_size + ) as pubsubs: + yield pubsubs + + @classmethod + @asynccontextmanager + async def create_batch_with_gossipsub( + cls, + number: int, + *, + is_secure: bool = False, + cache_size: int = None, + protocols: Sequence[TProtocol] = None, + degree: int = GOSSIPSUB_PARAMS.degree, + degree_low: int = GOSSIPSUB_PARAMS.degree_low, + degree_high: int = GOSSIPSUB_PARAMS.degree_high, + time_to_live: int = GOSSIPSUB_PARAMS.time_to_live, + gossip_window: int = GOSSIPSUB_PARAMS.gossip_window, + gossip_history: int = GOSSIPSUB_PARAMS.gossip_history, + heartbeat_interval: float = GOSSIPSUB_PARAMS.heartbeat_interval, + ) -> AsyncIterator[Tuple[Pubsub, ...]]: + if protocols is not None: + gossipsubs = GossipsubFactory.create_batch( + number, + protocols=protocols, + degree=degree, + degree_low=degree_low, + degree_high=degree_high, + time_to_live=time_to_live, + gossip_window=gossip_window, + heartbeat_interval=heartbeat_interval, + ) + else: + gossipsubs = GossipsubFactory.create_batch( + number, + degree=degree, + degree_low=degree_low, + degree_high=degree_high, + time_to_live=time_to_live, + gossip_window=gossip_window, + heartbeat_interval=heartbeat_interval, + ) + + async with cls._create_batch_with_router( + number, gossipsubs, is_secure, cache_size + ) as pubsubs: + async with AsyncExitStack() as stack: + for router in gossipsubs: + await stack.enter_async_context(background_trio_service(router)) + yield pubsubs @asynccontextmanager async def swarm_pair_factory( is_secure: bool, muxer_opt: TMuxerOptions = None -) -> Tuple[Swarm, Swarm]: +) -> AsyncIterator[Tuple[Swarm, Swarm]]: async with SwarmFactory.create_batch_and_listen( is_secure, 2, muxer_opt=muxer_opt ) as swarms: @@ -232,7 +271,9 @@ async def swarm_pair_factory( @asynccontextmanager -async def host_pair_factory(is_secure: bool) -> Tuple[BasicHost, BasicHost]: +async def host_pair_factory( + is_secure: bool +) -> AsyncIterator[Tuple[BasicHost, BasicHost]]: async with HostFactory.create_batch_and_listen(is_secure, 2) as hosts: await connect(hosts[0], hosts[1]) yield hosts[0], hosts[1] @@ -241,7 +282,7 @@ async def host_pair_factory(is_secure: bool) -> Tuple[BasicHost, BasicHost]: @asynccontextmanager async def swarm_conn_pair_factory( is_secure: bool, muxer_opt: TMuxerOptions = None -) -> Tuple[SwarmConn, SwarmConn]: +) -> AsyncIterator[Tuple[SwarmConn, SwarmConn]]: async with swarm_pair_factory(is_secure) as swarms: conn_0 = swarms[0].connections[swarms[1].get_peer_id()] conn_1 = swarms[1].connections[swarms[0].get_peer_id()] @@ -249,7 +290,9 @@ async def swarm_conn_pair_factory( @asynccontextmanager -async def mplex_conn_pair_factory(is_secure: bool) -> Tuple[Mplex, Mplex]: +async def mplex_conn_pair_factory( + is_secure: bool +) -> AsyncIterator[Tuple[Mplex, Mplex]]: muxer_opt = {MPLEX_PROTOCOL_ID: Mplex} async with swarm_conn_pair_factory(is_secure, muxer_opt=muxer_opt) as swarm_pair: yield ( @@ -259,21 +302,25 @@ async def mplex_conn_pair_factory(is_secure: bool) -> Tuple[Mplex, Mplex]: @asynccontextmanager -async def mplex_stream_pair_factory(is_secure: bool) -> Tuple[MplexStream, MplexStream]: +async def mplex_stream_pair_factory( + is_secure: bool +) -> AsyncIterator[Tuple[MplexStream, MplexStream]]: async with mplex_conn_pair_factory(is_secure) as mplex_conn_pair_info: mplex_conn_0, mplex_conn_1 = mplex_conn_pair_info - stream_0 = await mplex_conn_0.open_stream() + stream_0 = cast(MplexStream, await mplex_conn_0.open_stream()) await trio.sleep(0.01) stream_1: MplexStream async with mplex_conn_1.streams_lock: if len(mplex_conn_1.streams) != 1: raise Exception("Mplex should not have any other stream") stream_1 = tuple(mplex_conn_1.streams.values())[0] - yield cast(MplexStream, stream_0), cast(MplexStream, stream_1) + yield stream_0, stream_1 @asynccontextmanager -async def net_stream_pair_factory(is_secure: bool) -> Tuple[INetStream, INetStream]: +async def net_stream_pair_factory( + is_secure: bool +) -> AsyncIterator[Tuple[INetStream, INetStream]]: protocol_id = TProtocol("/example/id/1") stream_1: INetStream diff --git a/libp2p/tools/pubsub/dummy_account_node.py b/libp2p/tools/pubsub/dummy_account_node.py index 94f6576..5a61ed6 100644 --- a/libp2p/tools/pubsub/dummy_account_node.py +++ b/libp2p/tools/pubsub/dummy_account_node.py @@ -1,12 +1,11 @@ -import asyncio -from typing import Dict -import uuid +from contextlib import AsyncExitStack, asynccontextmanager +from typing import AsyncIterator, Dict, Tuple + +from async_service import Service, background_trio_service from libp2p.host.host_interface import IHost -from libp2p.pubsub.floodsub import FloodSub from libp2p.pubsub.pubsub import Pubsub -from libp2p.tools.constants import LISTEN_MADDR -from libp2p.tools.factories import FloodsubFactory, PubsubFactory +from libp2p.tools.factories import PubsubFactory CRYPTO_TOPIC = "ethereum" @@ -18,7 +17,7 @@ CRYPTO_TOPIC = "ethereum" # Determine message type by looking at first item before first comma -class DummyAccountNode: +class DummyAccountNode(Service): """ Node which has an internal balance mapping, meant to serve as a dummy crypto blockchain. @@ -27,19 +26,24 @@ class DummyAccountNode: crypto each user in the mappings holds """ - libp2p_node: IHost pubsub: Pubsub - floodsub: FloodSub - def __init__(self, libp2p_node: IHost, pubsub: Pubsub, floodsub: FloodSub): - self.libp2p_node = libp2p_node + def __init__(self, pubsub: Pubsub) -> None: self.pubsub = pubsub - self.floodsub = floodsub self.balances: Dict[str, int] = {} - self.node_id = str(uuid.uuid1()) + + @property + def host(self) -> IHost: + return self.pubsub.host + + async def run(self) -> None: + self.subscription = await self.pubsub.subscribe(CRYPTO_TOPIC) + self.manager.run_daemon_task(self.handle_incoming_msgs) + await self.manager.wait_finished() @classmethod - async def create(cls) -> "DummyAccountNode": + @asynccontextmanager + async def create(cls, number: int) -> AsyncIterator[Tuple["DummyAccountNode", ...]]: """ Create a new DummyAccountNode and attach a libp2p node, a floodsub, and a pubsub instance to this new node. @@ -47,15 +51,17 @@ class DummyAccountNode: We use create as this serves as a factory function and allows us to use async await, unlike the init function """ - - pubsub = PubsubFactory(router=FloodsubFactory()) - await pubsub.host.get_network().listen(LISTEN_MADDR) - return cls(libp2p_node=pubsub.host, pubsub=pubsub, floodsub=pubsub.router) + async with PubsubFactory.create_batch_with_floodsub(number) as pubsubs: + async with AsyncExitStack() as stack: + dummy_acount_nodes = tuple(cls(pubsub) for pubsub in pubsubs) + for node in dummy_acount_nodes: + await stack.enter_async_context(background_trio_service(node)) + yield dummy_acount_nodes async def handle_incoming_msgs(self) -> None: """Handle all incoming messages on the CRYPTO_TOPIC from peers.""" while True: - incoming = await self.q.get() + incoming = await self.subscription.receive() msg_comps = incoming.data.decode("utf-8").split(",") if msg_comps[0] == "send": @@ -63,13 +69,6 @@ class DummyAccountNode: elif msg_comps[0] == "set": self.handle_set_crypto(msg_comps[1], int(msg_comps[2])) - async def setup_crypto_networking(self) -> None: - """Subscribe to CRYPTO_TOPIC and perform call to function that handles - all incoming messages on said topic.""" - self.q = await self.pubsub.subscribe(CRYPTO_TOPIC) - - asyncio.ensure_future(self.handle_incoming_msgs()) - async def publish_send_crypto( self, source_user: str, dest_user: str, amount: int ) -> None: diff --git a/libp2p/tools/pubsub/floodsub_integration_test_settings.py b/libp2p/tools/pubsub/floodsub_integration_test_settings.py index 90939de..58a5b24 100644 --- a/libp2p/tools/pubsub/floodsub_integration_test_settings.py +++ b/libp2p/tools/pubsub/floodsub_integration_test_settings.py @@ -1,12 +1,10 @@ # type: ignore # To add typing to this module, it's better to do it after refactoring test cases into classes -import asyncio - import pytest +import trio -from libp2p.tools.constants import FLOODSUB_PROTOCOL_ID, LISTEN_MADDR -from libp2p.tools.factories import PubsubFactory +from libp2p.tools.constants import FLOODSUB_PROTOCOL_ID from libp2p.tools.utils import connect SUPPORTED_PROTOCOLS = [FLOODSUB_PROTOCOL_ID] @@ -15,6 +13,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [ { "name": "simple_two_nodes", "supported_protocols": SUPPORTED_PROTOCOLS, + "nodes": ["A", "B"], "adj_list": {"A": ["B"]}, "topic_map": {"topic1": ["B"]}, "messages": [{"topics": ["topic1"], "data": b"foo", "node_id": "A"}], @@ -22,6 +21,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [ { "name": "three_nodes_two_topics", "supported_protocols": SUPPORTED_PROTOCOLS, + "nodes": ["A", "B", "C"], "adj_list": {"A": ["B"], "B": ["C"]}, "topic_map": {"topic1": ["B", "C"], "topic2": ["B", "C"]}, "messages": [ @@ -32,6 +32,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [ { "name": "two_nodes_one_topic_single_subscriber_is_sender", "supported_protocols": SUPPORTED_PROTOCOLS, + "nodes": ["A", "B"], "adj_list": {"A": ["B"]}, "topic_map": {"topic1": ["B"]}, "messages": [{"topics": ["topic1"], "data": b"Alex is tall", "node_id": "B"}], @@ -39,6 +40,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [ { "name": "two_nodes_one_topic_two_msgs", "supported_protocols": SUPPORTED_PROTOCOLS, + "nodes": ["A", "B"], "adj_list": {"A": ["B"]}, "topic_map": {"topic1": ["B"]}, "messages": [ @@ -49,6 +51,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [ { "name": "seven_nodes_tree_one_topics", "supported_protocols": SUPPORTED_PROTOCOLS, + "nodes": ["1", "2", "3", "4", "5", "6", "7"], "adj_list": {"1": ["2", "3"], "2": ["4", "5"], "3": ["6", "7"]}, "topic_map": {"astrophysics": ["2", "3", "4", "5", "6", "7"]}, "messages": [{"topics": ["astrophysics"], "data": b"e=mc^2", "node_id": "1"}], @@ -56,6 +59,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [ { "name": "seven_nodes_tree_three_topics", "supported_protocols": SUPPORTED_PROTOCOLS, + "nodes": ["1", "2", "3", "4", "5", "6", "7"], "adj_list": {"1": ["2", "3"], "2": ["4", "5"], "3": ["6", "7"]}, "topic_map": { "astrophysics": ["2", "3", "4", "5", "6", "7"], @@ -71,6 +75,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [ { "name": "seven_nodes_tree_three_topics_diff_origin", "supported_protocols": SUPPORTED_PROTOCOLS, + "nodes": ["1", "2", "3", "4", "5", "6", "7"], "adj_list": {"1": ["2", "3"], "2": ["4", "5"], "3": ["6", "7"]}, "topic_map": { "astrophysics": ["1", "2", "3", "4", "5", "6", "7"], @@ -86,6 +91,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [ { "name": "three_nodes_clique_two_topic_diff_origin", "supported_protocols": SUPPORTED_PROTOCOLS, + "nodes": ["1", "2", "3"], "adj_list": {"1": ["2", "3"], "2": ["3"]}, "topic_map": {"astrophysics": ["1", "2", "3"], "school": ["1", "2", "3"]}, "messages": [ @@ -97,6 +103,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [ { "name": "four_nodes_clique_two_topic_diff_origin_many_msgs", "supported_protocols": SUPPORTED_PROTOCOLS, + "nodes": ["1", "2", "3", "4"], "adj_list": { "1": ["2", "3", "4"], "2": ["1", "3", "4"], @@ -120,6 +127,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [ { "name": "five_nodes_ring_two_topic_diff_origin_many_msgs", "supported_protocols": SUPPORTED_PROTOCOLS, + "nodes": ["1", "2", "3", "4", "5"], "adj_list": {"1": ["2"], "2": ["3"], "3": ["4"], "4": ["5"], "5": ["1"]}, "topic_map": { "astrophysics": ["1", "2", "3", "4", "5"], @@ -143,7 +151,7 @@ floodsub_protocol_pytest_params = [ ] -async def perform_test_from_obj(obj, router_factory) -> None: +async def perform_test_from_obj(obj, pubsub_factory) -> None: """ Perform pubsub tests from a test obj. test obj are composed as follows: @@ -174,88 +182,75 @@ async def perform_test_from_obj(obj, router_factory) -> None: # Step 1) Create graph adj_list = obj["adj_list"] + node_list = obj["nodes"] node_map = {} pubsub_map = {} - async def add_node(node_id_str: str) -> None: - pubsub_router = router_factory(protocols=obj["supported_protocols"]) - pubsub = PubsubFactory(router=pubsub_router) - await pubsub.host.get_network().listen(LISTEN_MADDR) - node_map[node_id_str] = pubsub.host - pubsub_map[node_id_str] = pubsub + async with pubsub_factory( + number=len(node_list), protocols=obj["supported_protocols"] + ) as pubsubs: + for node_id_str, pubsub in zip(node_list, pubsubs): + node_map[node_id_str] = pubsub.host + pubsub_map[node_id_str] = pubsub - tasks_connect = [] - for start_node_id in adj_list: - # Create node if node does not yet exist - if start_node_id not in node_map: - await add_node(start_node_id) + # Connect nodes and wait at least for 2 seconds + async with trio.open_nursery() as nursery: + for start_node_id in adj_list: + # For each neighbor of start_node, create if does not yet exist, + # then connect start_node to neighbor + for neighbor_id in adj_list[start_node_id]: + nursery.start_soon( + connect, node_map[start_node_id], node_map[neighbor_id] + ) + nursery.start_soon(trio.sleep, 2) - # For each neighbor of start_node, create if does not yet exist, - # then connect start_node to neighbor - for neighbor_id in adj_list[start_node_id]: - # Create neighbor if neighbor does not yet exist - if neighbor_id not in node_map: - await add_node(neighbor_id) - tasks_connect.append( - connect(node_map[start_node_id], node_map[neighbor_id]) - ) - # Connect nodes and wait at least for 2 seconds - await asyncio.gather(*tasks_connect, asyncio.sleep(2)) + # Step 2) Subscribe to topics + queues_map = {} + topic_map = obj["topic_map"] - # Step 2) Subscribe to topics - queues_map = {} - topic_map = obj["topic_map"] + async def subscribe_node(node_id, topic): + if node_id not in queues_map: + queues_map[node_id] = {} + # Avoid repeated works + if topic in queues_map[node_id]: + # Checkpoint + await trio.hazmat.checkpoint() + return + sub = await pubsub_map[node_id].subscribe(topic) + queues_map[node_id][topic] = sub - tasks_topic = [] - tasks_topic_data = [] - for topic, node_ids in topic_map.items(): - for node_id in node_ids: - tasks_topic.append(pubsub_map[node_id].subscribe(topic)) - tasks_topic_data.append((node_id, topic)) - tasks_topic.append(asyncio.sleep(2)) + async with trio.open_nursery() as nursery: + for topic, node_ids in topic_map.items(): + for node_id in node_ids: + nursery.start_soon(subscribe_node, node_id, topic) + nursery.start_soon(trio.sleep, 2) - # Gather is like Promise.all - responses = await asyncio.gather(*tasks_topic) - for i in range(len(responses) - 1): - node_id, topic = tasks_topic_data[i] - if node_id not in queues_map: - queues_map[node_id] = {} - # Store queue in topic-queue map for node - queues_map[node_id][topic] = responses[i] + # Step 3) Publish messages + topics_in_msgs_ordered = [] + messages = obj["messages"] - # Allow time for subscribing before continuing - await asyncio.sleep(0.01) + for msg in messages: + topics = msg["topics"] + data = msg["data"] + node_id = msg["node_id"] - # Step 3) Publish messages - topics_in_msgs_ordered = [] - messages = obj["messages"] - tasks_publish = [] + # Publish message + # TODO: Should be single RPC package with several topics + for topic in topics: + await pubsub_map[node_id].publish(topic, data) - for msg in messages: - topics = msg["topics"] - data = msg["data"] - node_id = msg["node_id"] + # For each topic in topics, add (topic, node_id, data) tuple to ordered test list + for topic in topics: + topics_in_msgs_ordered.append((topic, node_id, data)) + # Allow time for publishing before continuing + await trio.sleep(1) - # Publish message - # TODO: Should be single RPC package with several topics - for topic in topics: - tasks_publish.append(pubsub_map[node_id].publish(topic, data)) - - # For each topic in topics, add (topic, node_id, data) tuple to ordered test list - for topic in topics: - topics_in_msgs_ordered.append((topic, node_id, data)) - - # Allow time for publishing before continuing - await asyncio.gather(*tasks_publish, asyncio.sleep(2)) - - # Step 4) Check that all messages were received correctly. - for topic, origin_node_id, data in topics_in_msgs_ordered: - # Look at each node in each topic - for node_id in topic_map[topic]: - # Get message from subscription queue - msg = await queues_map[node_id][topic].get() - assert data == msg.data - # Check the message origin - assert node_map[origin_node_id].get_id().to_bytes() == msg.from_id - - # Success, terminate pending tasks. + # Step 4) Check that all messages were received correctly. + for topic, origin_node_id, data in topics_in_msgs_ordered: + # Look at each node in each topic + for node_id in topic_map[topic]: + # Get message from subscription queue + msg = await queues_map[node_id][topic].receive() + assert data == msg.data + # Check the message origin + assert node_map[origin_node_id].get_id().to_bytes() == msg.from_id diff --git a/libp2p/tools/utils.py b/libp2p/tools/utils.py index 9ad6815..a66155c 100644 --- a/libp2p/tools/utils.py +++ b/libp2p/tools/utils.py @@ -1,17 +1,9 @@ -from typing import Callable, List, Sequence, Tuple +from typing import Awaitable, Callable -import multiaddr -import trio - -from libp2p import new_node -from libp2p.host.basic_host import BasicHost from libp2p.host.host_interface import IHost -from libp2p.kademlia.network import KademliaServer from libp2p.network.stream.net_stream_interface import INetStream from libp2p.network.swarm import Swarm from libp2p.peer.peerinfo import info_from_p2p_addr -from libp2p.routing.interfaces import IPeerRouting -from libp2p.routing.kademlia.kademlia_peer_router import KadmeliaPeerRouter from .constants import MAX_READ_LEN @@ -36,49 +28,9 @@ async def connect(node1: IHost, node2: IHost) -> None: await node1.connect(info) -async def set_up_nodes_by_transport_opt( - transport_opt_list: Sequence[Sequence[str]], nursery: trio.Nursery -) -> Tuple[BasicHost, ...]: - nodes_list = [] - for transport_opt in transport_opt_list: - node = new_node(transport_opt=transport_opt) - await node.get_network().listen( - multiaddr.Multiaddr(transport_opt[0]), nursery=nursery - ) - nodes_list.append(node) - return tuple(nodes_list) - - -async def set_up_nodes_by_transport_and_disc_opt( - transport_disc_opt_list: Sequence[Tuple[Sequence[str], IPeerRouting]] -) -> Tuple[BasicHost, ...]: - nodes_list = [] - for transport_opt, disc_opt in transport_disc_opt_list: - node = await new_node(transport_opt=transport_opt, disc_opt=disc_opt) - await node.get_network().listen(multiaddr.Multiaddr(transport_opt[0])) - nodes_list.append(node) - return tuple(nodes_list) - - -async def set_up_routers( - router_ports: Tuple[int, ...] = (0, 0) -) -> List[KadmeliaPeerRouter]: - """The default ``router_confs`` selects two free ports local to this - machine.""" - bootstrap_node = KademliaServer() # type: ignore - await bootstrap_node.listen(router_ports[0]) - - routers = [KadmeliaPeerRouter(bootstrap_node)] - for port in router_ports[1:]: - node = KademliaServer() # type: ignore - await node.listen(port) - - await node.bootstrap_node(bootstrap_node.address) - routers.append(KadmeliaPeerRouter(node)) - return routers - - -def create_echo_stream_handler(ack_prefix: str) -> Callable[[INetStream], None]: +def create_echo_stream_handler( + ack_prefix: str +) -> Callable[[INetStream], Awaitable[None]]: async def echo_stream_handler(stream: INetStream) -> None: while True: read_string = (await stream.read(MAX_READ_LEN)).decode() diff --git a/libp2p/transport/listener_interface.py b/libp2p/transport/listener_interface.py index 1b22531..6d73723 100644 --- a/libp2p/transport/listener_interface.py +++ b/libp2p/transport/listener_interface.py @@ -1,12 +1,13 @@ from abc import ABC, abstractmethod -from typing import List +from typing import Tuple from multiaddr import Multiaddr +import trio class IListener(ABC): @abstractmethod - async def listen(self, maddr: Multiaddr) -> bool: + async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> bool: """ put listener in listening mode and wait for incoming connections. @@ -15,14 +16,9 @@ class IListener(ABC): """ @abstractmethod - def get_addrs(self) -> List[Multiaddr]: + def get_addrs(self) -> Tuple[Multiaddr, ...]: """ retrieve list of addresses the listener is listening on. :return: return list of addrs """ - - @abstractmethod - async def close(self) -> None: - """close the listener such that no more connections can be open on this - transport instance.""" diff --git a/libp2p/transport/tcp/tcp.py b/libp2p/transport/tcp/tcp.py index 745bafe..04d8874 100644 --- a/libp2p/transport/tcp/tcp.py +++ b/libp2p/transport/tcp/tcp.py @@ -1,14 +1,13 @@ import logging -from socket import socket -from typing import List +from typing import Awaitable, Callable, List, Sequence, Tuple from multiaddr import Multiaddr import trio +from trio_typing import TaskStatus -from libp2p.io.trio import TrioReadWriteCloser +from libp2p.io.trio import TrioTCPStream from libp2p.network.connection.raw_connection import RawConnection from libp2p.network.connection.raw_connection_interface import IRawConnection -from libp2p.transport.exceptions import OpenConnectionError from libp2p.transport.listener_interface import IListener from libp2p.transport.transport_interface import ITransport from libp2p.transport.typing import THandler @@ -18,14 +17,12 @@ logger = logging.getLogger("libp2p.transport.tcp") class TCPListener(IListener): multiaddrs: List[Multiaddr] - server = None def __init__(self, handler_function: THandler) -> None: self.multiaddrs = [] - self.server = None self.handler = handler_function - # TODO: Fix handling? + # TODO: Get rid of `nursery`? async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> None: """ put listener in listening mode and wait for incoming connections. @@ -34,13 +31,18 @@ class TCPListener(IListener): :return: return True if successful """ - async def serve_tcp(handler, port, host, task_status=None): + async def serve_tcp( + handler: Callable[[trio.SocketStream], Awaitable[None]], + port: int, + host: str, + task_status: TaskStatus[Sequence[trio.SocketListener]] = None, + ) -> None: logger.debug("serve_tcp %s %s", host, port) await trio.serve_tcp(handler, port, host=host, task_status=task_status) - async def handler(stream): - read_write_closer = TrioReadWriteCloser(stream) - await self.handler(read_write_closer) + async def handler(stream: trio.SocketStream) -> None: + tcp_stream = TrioTCPStream(stream) + await self.handler(tcp_stream) listeners = await nursery.start( serve_tcp, @@ -51,7 +53,7 @@ class TCPListener(IListener): socket = listeners[0].socket self.multiaddrs.append(_multiaddr_from_socket(socket)) - def get_addrs(self) -> List[Multiaddr]: + def get_addrs(self) -> Tuple[Multiaddr, ...]: """ retrieve list of addresses the listener is listening on. @@ -59,15 +61,6 @@ class TCPListener(IListener): """ return tuple(self.multiaddrs) - async def close(self) -> None: - """close the listener such that no more connections can be open on this - transport instance.""" - if self.server is None: - return - self.server.close() - await self.server.wait_closed() - self.server = None - class TCP(ITransport): async def dial(self, maddr: Multiaddr) -> IRawConnection: @@ -82,7 +75,7 @@ class TCP(ITransport): self.port = int(maddr.value_for_protocol("tcp")) stream = await trio.open_tcp_stream(self.host, self.port) - read_write_closer = TrioReadWriteCloser(stream) + read_write_closer = TrioTCPStream(stream) return RawConnection(read_write_closer, True) @@ -97,5 +90,6 @@ class TCP(ITransport): return TCPListener(handler_function) -def _multiaddr_from_socket(socket: socket) -> Multiaddr: - return Multiaddr("/ip4/%s/tcp/%s" % socket.getsockname()) +def _multiaddr_from_socket(socket: trio.socket.SocketType) -> Multiaddr: + ip, port = socket.getsockname() # type: ignore + return Multiaddr(f"/ip4/{ip}/tcp/{port}") diff --git a/tests/host/test_ping.py b/tests/host/test_ping.py index 2913514..7a0f8db 100644 --- a/tests/host/test_ping.py +++ b/tests/host/test_ping.py @@ -1,7 +1,7 @@ -import trio import secrets import pytest +import trio from libp2p.host.ping import ID, PING_LENGTH from libp2p.tools.factories import host_pair_factory diff --git a/tests/host/test_routed_host.py b/tests/host/test_routed_host.py deleted file mode 100644 index 006dd22..0000000 --- a/tests/host/test_routed_host.py +++ /dev/null @@ -1,73 +0,0 @@ -import pytest - -from libp2p.host.exceptions import ConnectionFailure -from libp2p.peer.peerinfo import PeerInfo -from libp2p.routing.kademlia.kademlia_peer_router import peer_info_to_str -from libp2p.tools.utils import ( - set_up_nodes_by_transport_and_disc_opt, - set_up_nodes_by_transport_opt, - set_up_routers, -) -from libp2p.tools.factories import RoutedHostFactory - - -# FIXME: - -# TODO: Kademlia is full of asyncio code. Skip it for now -@pytest.mark.skip -@pytest.mark.trio -async def test_host_routing_success(is_host_secure): - async with RoutedHostFactory.create_batch_and_listen( - is_host_secure, 2 - ) as routed_hosts: - # Set routing info - await routed_hosts[0]._router.server.set( - routed_hosts[0].get_id().xor_id, - peer_info_to_str( - PeerInfo(routed_hosts[0].get_id(), routed_hosts[0].get_addrs()) - ), - ) - await routed_hosts[1]._router.server.set( - routed_hosts[1].get_id().xor_id, - peer_info_to_str( - PeerInfo(routed_hosts[1].get_id(), routed_hosts[1].get_addrs()) - ), - ) - - # forces to use routing as no addrs are provided - await routed_hosts[0].connect(PeerInfo(routed_hosts[1].get_id(), [])) - await routed_hosts[1].connect(PeerInfo(routed_hosts[0].get_id(), [])) - - -# TODO: Kademlia is full of asyncio code. Skip it for now -@pytest.mark.skip -@pytest.mark.trio -async def test_host_routing_fail(): - routers = await set_up_routers() - transports = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]] - transport_disc_opt_list = zip(transports, routers) - (host_a, host_b) = await set_up_nodes_by_transport_and_disc_opt( - transport_disc_opt_list - ) - - host_c = (await set_up_nodes_by_transport_opt([["/ip4/127.0.0.1/tcp/0"]]))[0] - - # Set routing info - await routers[0].server.set( - host_a.get_id().xor_id, - peer_info_to_str(PeerInfo(host_a.get_id(), host_a.get_addrs())), - ) - await routers[1].server.set( - host_b.get_id().xor_id, - peer_info_to_str(PeerInfo(host_b.get_id(), host_b.get_addrs())), - ) - - # routing fails because host_c does not use routing - with pytest.raises(ConnectionFailure): - await host_a.connect(PeerInfo(host_c.get_id(), [])) - with pytest.raises(ConnectionFailure): - await host_b.connect(PeerInfo(host_c.get_id(), [])) - - # Clean up - routers[0].server.stop() - routers[1].server.stop() diff --git a/tests/protocol_muxer/test_protocol_muxer.py b/tests/protocol_muxer/test_protocol_muxer.py index 9533d1f..cd82652 100644 --- a/tests/protocol_muxer/test_protocol_muxer.py +++ b/tests/protocol_muxer/test_protocol_muxer.py @@ -4,7 +4,6 @@ from libp2p.host.exceptions import StreamFailure from libp2p.tools.factories import HostFactory from libp2p.tools.utils import create_echo_stream_handler - PROTOCOL_ECHO = "/echo/1.0.0" PROTOCOL_POTATO = "/potato/1.0.0" PROTOCOL_FOO = "/foo/1.0.0" diff --git a/tests/pubsub/conftest.py b/tests/pubsub/conftest.py deleted file mode 100644 index 6c08dd7..0000000 --- a/tests/pubsub/conftest.py +++ /dev/null @@ -1,22 +0,0 @@ -import pytest - -from libp2p.tools.constants import GOSSIPSUB_PARAMS -from libp2p.tools.factories import FloodsubFactory, GossipsubFactory, PubsubFactory - - -@pytest.fixture -def pubsub_cache_size(): - return None # default - - -@pytest.fixture -def gossipsub_params(): - return GOSSIPSUB_PARAMS - - -# @pytest.fixture -# def pubsubs_gsub(num_hosts, hosts, pubsub_cache_size, gossipsub_params): -# gossipsubs = GossipsubFactory.create_batch(num_hosts, **gossipsub_params._asdict()) -# _pubsubs_gsub = _make_pubsubs(hosts, gossipsubs, pubsub_cache_size) -# yield _pubsubs_gsub -# # TODO: Clean up diff --git a/tests/pubsub/test_dummyaccount_demo.py b/tests/pubsub/test_dummyaccount_demo.py index cdda603..24d5bd4 100644 --- a/tests/pubsub/test_dummyaccount_demo.py +++ b/tests/pubsub/test_dummyaccount_demo.py @@ -1,19 +1,10 @@ -import asyncio -from threading import Thread - import pytest +import trio from libp2p.tools.pubsub.dummy_account_node import DummyAccountNode from libp2p.tools.utils import connect -def create_setup_in_new_thread_func(dummy_node): - def setup_in_new_thread(): - asyncio.ensure_future(dummy_node.setup_crypto_networking()) - - return setup_in_new_thread - - async def perform_test(num_nodes, adjacency_map, action_func, assertion_func): """ Helper function to allow for easy construction of custom tests for dummy @@ -26,47 +17,35 @@ async def perform_test(num_nodes, adjacency_map, action_func, assertion_func): :param assertion_func: assertions for testing the results of the actions are correct """ - # Create nodes - dummy_nodes = [] - for _ in range(num_nodes): - dummy_nodes.append(await DummyAccountNode.create()) + async with DummyAccountNode.create(num_nodes) as dummy_nodes: + # Create connections between nodes according to `adjacency_map` + async with trio.open_nursery() as nursery: + for source_num in adjacency_map: + target_nums = adjacency_map[source_num] + for target_num in target_nums: + nursery.start_soon( + connect, + dummy_nodes[source_num].host, + dummy_nodes[target_num].host, + ) - # Create network - for source_num in adjacency_map: - target_nums = adjacency_map[source_num] - for target_num in target_nums: - await connect( - dummy_nodes[source_num].libp2p_node, dummy_nodes[target_num].libp2p_node - ) + # Allow time for network creation to take place + await trio.sleep(0.25) - # Allow time for network creation to take place - await asyncio.sleep(0.25) + # Perform action function + await action_func(dummy_nodes) - # Start a thread for each node so that each node can listen and respond - # to messages on its own thread, which will avoid waiting indefinitely - # on the main thread. On this thread, call the setup func for the node, - # which subscribes the node to the CRYPTO_TOPIC topic - for dummy_node in dummy_nodes: - thread = Thread(target=create_setup_in_new_thread_func(dummy_node)) - thread.run() + # Allow time for action function to be performed (i.e. messages to propogate) + await trio.sleep(1) - # Allow time for nodes to subscribe to CRYPTO_TOPIC topic - await asyncio.sleep(0.25) - - # Perform action function - await action_func(dummy_nodes) - - # Allow time for action function to be performed (i.e. messages to propogate) - await asyncio.sleep(1) - - # Perform assertion function - for dummy_node in dummy_nodes: - assertion_func(dummy_node) + # Perform assertion function + for dummy_node in dummy_nodes: + assertion_func(dummy_node) # Success, terminate pending tasks. -@pytest.mark.asyncio +@pytest.mark.trio async def test_simple_two_nodes(): num_nodes = 2 adj_map = {0: [1]} @@ -80,7 +59,7 @@ async def test_simple_two_nodes(): await perform_test(num_nodes, adj_map, action_func, assertion_func) -@pytest.mark.asyncio +@pytest.mark.trio async def test_simple_three_nodes_line_topography(): num_nodes = 3 adj_map = {0: [1], 1: [2]} @@ -94,7 +73,7 @@ async def test_simple_three_nodes_line_topography(): await perform_test(num_nodes, adj_map, action_func, assertion_func) -@pytest.mark.asyncio +@pytest.mark.trio async def test_simple_three_nodes_triangle_topography(): num_nodes = 3 adj_map = {0: [1, 2], 1: [2]} @@ -108,7 +87,7 @@ async def test_simple_three_nodes_triangle_topography(): await perform_test(num_nodes, adj_map, action_func, assertion_func) -@pytest.mark.asyncio +@pytest.mark.trio async def test_simple_seven_nodes_tree_topography(): num_nodes = 7 adj_map = {0: [1, 2], 1: [3, 4], 2: [5, 6]} @@ -122,14 +101,14 @@ async def test_simple_seven_nodes_tree_topography(): await perform_test(num_nodes, adj_map, action_func, assertion_func) -@pytest.mark.asyncio +@pytest.mark.trio async def test_set_then_send_from_root_seven_nodes_tree_topography(): num_nodes = 7 adj_map = {0: [1, 2], 1: [3, 4], 2: [5, 6]} async def action_func(dummy_nodes): await dummy_nodes[0].publish_set_crypto("aspyn", 20) - await asyncio.sleep(0.25) + await trio.sleep(0.25) await dummy_nodes[0].publish_send_crypto("aspyn", "alex", 5) def assertion_func(dummy_node): @@ -139,14 +118,14 @@ async def test_set_then_send_from_root_seven_nodes_tree_topography(): await perform_test(num_nodes, adj_map, action_func, assertion_func) -@pytest.mark.asyncio +@pytest.mark.trio async def test_set_then_send_from_different_leafs_seven_nodes_tree_topography(): num_nodes = 7 adj_map = {0: [1, 2], 1: [3, 4], 2: [5, 6]} async def action_func(dummy_nodes): await dummy_nodes[6].publish_set_crypto("aspyn", 20) - await asyncio.sleep(0.25) + await trio.sleep(0.25) await dummy_nodes[4].publish_send_crypto("aspyn", "alex", 5) def assertion_func(dummy_node): @@ -156,7 +135,7 @@ async def test_set_then_send_from_different_leafs_seven_nodes_tree_topography(): await perform_test(num_nodes, adj_map, action_func, assertion_func) -@pytest.mark.asyncio +@pytest.mark.trio async def test_simple_five_nodes_ring_topography(): num_nodes = 5 adj_map = {0: [1], 1: [2], 2: [3], 3: [4], 4: [0]} @@ -170,14 +149,14 @@ async def test_simple_five_nodes_ring_topography(): await perform_test(num_nodes, adj_map, action_func, assertion_func) -@pytest.mark.asyncio +@pytest.mark.trio async def test_set_then_send_from_diff_nodes_five_nodes_ring_topography(): num_nodes = 5 adj_map = {0: [1], 1: [2], 2: [3], 3: [4], 4: [0]} async def action_func(dummy_nodes): await dummy_nodes[0].publish_set_crypto("alex", 20) - await asyncio.sleep(0.25) + await trio.sleep(0.25) await dummy_nodes[3].publish_send_crypto("alex", "rob", 12) def assertion_func(dummy_node): @@ -187,7 +166,7 @@ async def test_set_then_send_from_diff_nodes_five_nodes_ring_topography(): await perform_test(num_nodes, adj_map, action_func, assertion_func) -@pytest.mark.asyncio +@pytest.mark.trio @pytest.mark.slow async def test_set_then_send_from_five_diff_nodes_five_nodes_ring_topography(): num_nodes = 5 @@ -195,13 +174,13 @@ async def test_set_then_send_from_five_diff_nodes_five_nodes_ring_topography(): async def action_func(dummy_nodes): await dummy_nodes[0].publish_set_crypto("alex", 20) - await asyncio.sleep(1) + await trio.sleep(1) await dummy_nodes[1].publish_send_crypto("alex", "rob", 3) - await asyncio.sleep(1) + await trio.sleep(1) await dummy_nodes[2].publish_send_crypto("rob", "aspyn", 2) - await asyncio.sleep(1) + await trio.sleep(1) await dummy_nodes[3].publish_send_crypto("aspyn", "zx", 1) - await asyncio.sleep(1) + await trio.sleep(1) await dummy_nodes[4].publish_send_crypto("zx", "raul", 1) def assertion_func(dummy_node): diff --git a/tests/pubsub/test_floodsub.py b/tests/pubsub/test_floodsub.py index 7564a94..dbeb683 100644 --- a/tests/pubsub/test_floodsub.py +++ b/tests/pubsub/test_floodsub.py @@ -1,9 +1,10 @@ -import asyncio +import functools import pytest +import trio from libp2p.peer.id import ID -from libp2p.tools.factories import FloodsubFactory +from libp2p.tools.factories import PubsubFactory from libp2p.tools.pubsub.floodsub_integration_test_settings import ( floodsub_protocol_pytest_params, perform_test_from_obj, @@ -11,79 +12,83 @@ from libp2p.tools.pubsub.floodsub_integration_test_settings import ( from libp2p.tools.utils import connect -@pytest.mark.parametrize("num_hosts", (2,)) -@pytest.mark.asyncio -async def test_simple_two_nodes(pubsubs_fsub): - topic = "my_topic" - data = b"some data" +@pytest.mark.trio +async def test_simple_two_nodes(): + async with PubsubFactory.create_batch_with_floodsub(2) as pubsubs_fsub: + topic = "my_topic" + data = b"some data" - await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host) - await asyncio.sleep(0.25) + await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host) + await trio.sleep(0.25) - sub_b = await pubsubs_fsub[1].subscribe(topic) - # Sleep to let a know of b's subscription - await asyncio.sleep(0.25) + sub_b = await pubsubs_fsub[1].subscribe(topic) + # Sleep to let a know of b's subscription + await trio.sleep(0.25) - await pubsubs_fsub[0].publish(topic, data) + await pubsubs_fsub[0].publish(topic, data) - res_b = await sub_b.get() + res_b = await sub_b.receive() - # Check that the msg received by node_b is the same - # as the message sent by node_a - assert ID(res_b.from_id) == pubsubs_fsub[0].host.get_id() - assert res_b.data == data - assert res_b.topicIDs == [topic] - - # Success, terminate pending tasks. + # Check that the msg received by node_b is the same + # as the message sent by node_a + assert ID(res_b.from_id) == pubsubs_fsub[0].host.get_id() + assert res_b.data == data + assert res_b.topicIDs == [topic] -# Initialize Pubsub with a cache_size of 4 -@pytest.mark.parametrize("num_hosts, pubsub_cache_size", ((2, 4),)) -@pytest.mark.asyncio -async def test_lru_cache_two_nodes(pubsubs_fsub, monkeypatch): +@pytest.mark.trio +async def test_lru_cache_two_nodes(monkeypatch): # two nodes with cache_size of 4 - # `node_a` send the following messages to node_b - message_indices = [1, 1, 2, 1, 3, 1, 4, 1, 5, 1] - # `node_b` should only receive the following - expected_received_indices = [1, 2, 3, 4, 5, 1] + async with PubsubFactory.create_batch_with_floodsub( + 2, cache_size=4 + ) as pubsubs_fsub: + # `node_a` send the following messages to node_b + message_indices = [1, 1, 2, 1, 3, 1, 4, 1, 5, 1] + # `node_b` should only receive the following + expected_received_indices = [1, 2, 3, 4, 5, 1] - topic = "my_topic" + topic = "my_topic" - # Mock `get_msg_id` to make us easier to manipulate `msg_id` by `data`. - def get_msg_id(msg): - # Originally it is `(msg.seqno, msg.from_id)` - return (msg.data, msg.from_id) + # Mock `get_msg_id` to make us easier to manipulate `msg_id` by `data`. + def get_msg_id(msg): + # Originally it is `(msg.seqno, msg.from_id)` + return (msg.data, msg.from_id) - import libp2p.pubsub.pubsub + import libp2p.pubsub.pubsub - monkeypatch.setattr(libp2p.pubsub.pubsub, "get_msg_id", get_msg_id) + monkeypatch.setattr(libp2p.pubsub.pubsub, "get_msg_id", get_msg_id) - await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host) - await asyncio.sleep(0.25) + await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host) + await trio.sleep(0.25) - sub_b = await pubsubs_fsub[1].subscribe(topic) - await asyncio.sleep(0.25) + sub_b = await pubsubs_fsub[1].subscribe(topic) + await trio.sleep(0.25) - def _make_testing_data(i: int) -> bytes: - num_int_bytes = 4 - if i >= 2 ** (num_int_bytes * 8): - raise ValueError("integer is too large to be serialized") - return b"data" + i.to_bytes(num_int_bytes, "big") + def _make_testing_data(i: int) -> bytes: + num_int_bytes = 4 + if i >= 2 ** (num_int_bytes * 8): + raise ValueError("integer is too large to be serialized") + return b"data" + i.to_bytes(num_int_bytes, "big") - for index in message_indices: - await pubsubs_fsub[0].publish(topic, _make_testing_data(index)) - await asyncio.sleep(0.25) + for index in message_indices: + await pubsubs_fsub[0].publish(topic, _make_testing_data(index)) + await trio.sleep(0.25) - for index in expected_received_indices: - res_b = await sub_b.get() - assert res_b.data == _make_testing_data(index) - assert sub_b.empty() + for index in expected_received_indices: + res_b = await sub_b.receive() + assert res_b.data == _make_testing_data(index) - # Success, terminate pending tasks. + with pytest.raises(trio.WouldBlock): + sub_b.receive_nowait() @pytest.mark.parametrize("test_case_obj", floodsub_protocol_pytest_params) -@pytest.mark.asyncio +@pytest.mark.trio @pytest.mark.slow -async def test_gossipsub_run_with_floodsub_tests(test_case_obj): - await perform_test_from_obj(test_case_obj, FloodsubFactory) +async def test_gossipsub_run_with_floodsub_tests(test_case_obj, is_host_secure): + await perform_test_from_obj( + test_case_obj, + functools.partial( + PubsubFactory.create_batch_with_floodsub, is_secure=is_host_secure + ), + ) diff --git a/tests/pubsub/test_gossipsub.py b/tests/pubsub/test_gossipsub.py index 2121f8f..b1ed3af 100644 --- a/tests/pubsub/test_gossipsub.py +++ b/tests/pubsub/test_gossipsub.py @@ -1,368 +1,350 @@ -import asyncio import random import pytest +import trio -from libp2p.tools.constants import GossipsubParams +from libp2p.tools.factories import PubsubFactory from libp2p.tools.pubsub.utils import dense_connect, one_to_all_connect from libp2p.tools.utils import connect -@pytest.mark.parametrize( - "num_hosts, gossipsub_params", - ((4, GossipsubParams(degree=4, degree_low=3, degree_high=5)),), -) -@pytest.mark.asyncio -async def test_join(num_hosts, hosts, pubsubs_gsub): - gossipsubs = tuple(pubsub.router for pubsub in pubsubs_gsub) - hosts_indices = list(range(num_hosts)) +@pytest.mark.trio +async def test_join(): + async with PubsubFactory.create_batch_with_gossipsub( + 4, degree=4, degree_low=3, degree_high=5 + ) as pubsubs_gsub: + gossipsubs = [pubsub.router for pubsub in pubsubs_gsub] + hosts = [pubsub.host for pubsub in pubsubs_gsub] + hosts_indices = list(range(len(pubsubs_gsub))) - topic = "test_join" - central_node_index = 0 - # Remove index of central host from the indices - hosts_indices.remove(central_node_index) - num_subscribed_peer = 2 - subscribed_peer_indices = random.sample(hosts_indices, num_subscribed_peer) + topic = "test_join" + central_node_index = 0 + # Remove index of central host from the indices + hosts_indices.remove(central_node_index) + num_subscribed_peer = 2 + subscribed_peer_indices = random.sample(hosts_indices, num_subscribed_peer) - # All pubsub except the one of central node subscribe to topic - for i in subscribed_peer_indices: - await pubsubs_gsub[i].subscribe(topic) + # All pubsub except the one of central node subscribe to topic + for i in subscribed_peer_indices: + await pubsubs_gsub[i].subscribe(topic) - # Connect central host to all other hosts - await one_to_all_connect(hosts, central_node_index) + # Connect central host to all other hosts + await one_to_all_connect(hosts, central_node_index) - # Wait 2 seconds for heartbeat to allow mesh to connect - await asyncio.sleep(2) - - # Central node publish to the topic so that this topic - # is added to central node's fanout - # publish from the randomly chosen host - await pubsubs_gsub[central_node_index].publish(topic, b"data") - - # Check that the gossipsub of central node has fanout for the topic - assert topic in gossipsubs[central_node_index].fanout - # Check that the gossipsub of central node does not have a mesh for the topic - assert topic not in gossipsubs[central_node_index].mesh - - # Central node subscribes the topic - await pubsubs_gsub[central_node_index].subscribe(topic) - - await asyncio.sleep(2) - - # Check that the gossipsub of central node no longer has fanout for the topic - assert topic not in gossipsubs[central_node_index].fanout - - for i in hosts_indices: - if i in subscribed_peer_indices: - assert hosts[i].get_id() in gossipsubs[central_node_index].mesh[topic] - assert hosts[central_node_index].get_id() in gossipsubs[i].mesh[topic] - else: - assert hosts[i].get_id() not in gossipsubs[central_node_index].mesh[topic] - assert topic not in gossipsubs[i].mesh - - -@pytest.mark.parametrize("num_hosts", (1,)) -@pytest.mark.asyncio -async def test_leave(pubsubs_gsub): - gossipsub = pubsubs_gsub[0].router - topic = "test_leave" - - assert topic not in gossipsub.mesh - - await gossipsub.join(topic) - assert topic in gossipsub.mesh - - await gossipsub.leave(topic) - assert topic not in gossipsub.mesh - - # Test re-leave - await gossipsub.leave(topic) - - -@pytest.mark.parametrize("num_hosts", (2,)) -@pytest.mark.asyncio -async def test_handle_graft(pubsubs_gsub, hosts, event_loop, monkeypatch): - gossipsubs = tuple(pubsub.router for pubsub in pubsubs_gsub) - - index_alice = 0 - id_alice = hosts[index_alice].get_id() - index_bob = 1 - id_bob = hosts[index_bob].get_id() - await connect(hosts[index_alice], hosts[index_bob]) - - # Wait 2 seconds for heartbeat to allow mesh to connect - await asyncio.sleep(2) - - topic = "test_handle_graft" - # Only lice subscribe to the topic - await gossipsubs[index_alice].join(topic) - - # Monkey patch bob's `emit_prune` function so we can - # check if it is called in `handle_graft` - event_emit_prune = asyncio.Event() - - async def emit_prune(topic, sender_peer_id): - event_emit_prune.set() - - monkeypatch.setattr(gossipsubs[index_bob], "emit_prune", emit_prune) - - # Check that alice is bob's peer but not his mesh peer - assert id_alice in gossipsubs[index_bob].peers_gossipsub - assert topic not in gossipsubs[index_bob].mesh - - await gossipsubs[index_alice].emit_graft(topic, id_bob) - - # Check that `emit_prune` is called - await asyncio.wait_for(event_emit_prune.wait(), timeout=1, loop=event_loop) - assert event_emit_prune.is_set() - - # Check that bob is alice's peer but not her mesh peer - assert topic in gossipsubs[index_alice].mesh - assert id_bob not in gossipsubs[index_alice].mesh[topic] - assert id_bob in gossipsubs[index_alice].peers_gossipsub - - await gossipsubs[index_bob].emit_graft(topic, id_alice) - - await asyncio.sleep(1) - - # Check that bob is now alice's mesh peer - assert id_bob in gossipsubs[index_alice].mesh[topic] - - -@pytest.mark.parametrize( - "num_hosts, gossipsub_params", ((2, GossipsubParams(heartbeat_interval=3)),) -) -@pytest.mark.asyncio -async def test_handle_prune(pubsubs_gsub, hosts): - gossipsubs = tuple(pubsub.router for pubsub in pubsubs_gsub) - - index_alice = 0 - id_alice = hosts[index_alice].get_id() - index_bob = 1 - id_bob = hosts[index_bob].get_id() - - topic = "test_handle_prune" - for pubsub in pubsubs_gsub: - await pubsub.subscribe(topic) - - await connect(hosts[index_alice], hosts[index_bob]) - - # Wait 3 seconds for heartbeat to allow mesh to connect - await asyncio.sleep(3) - - # Check that they are each other's mesh peer - assert id_alice in gossipsubs[index_bob].mesh[topic] - assert id_bob in gossipsubs[index_alice].mesh[topic] - - # alice emit prune message to bob, alice should be removed - # from bob's mesh peer - await gossipsubs[index_alice].emit_prune(topic, id_bob) - - # FIXME: This test currently works because the heartbeat interval - # is increased to 3 seconds, so alice won't get add back into - # bob's mesh peer during heartbeat. - await asyncio.sleep(1) - - # Check that alice is no longer bob's mesh peer - assert id_alice not in gossipsubs[index_bob].mesh[topic] - assert id_bob in gossipsubs[index_alice].mesh[topic] - - -@pytest.mark.parametrize("num_hosts", (10,)) -@pytest.mark.asyncio -async def test_dense(num_hosts, pubsubs_gsub, hosts): - num_msgs = 5 - - # All pubsub subscribe to foobar - queues = [] - for pubsub in pubsubs_gsub: - q = await pubsub.subscribe("foobar") - - # Add each blocking queue to an array of blocking queues - queues.append(q) - - # Densely connect libp2p hosts in a random way - await dense_connect(hosts) - - # Wait 2 seconds for heartbeat to allow mesh to connect - await asyncio.sleep(2) - - for i in range(num_msgs): - msg_content = b"foo " + i.to_bytes(1, "big") - - # randomly pick a message origin - origin_idx = random.randint(0, num_hosts - 1) + # Wait 2 seconds for heartbeat to allow mesh to connect + await trio.sleep(2) + # Central node publish to the topic so that this topic + # is added to central node's fanout # publish from the randomly chosen host - await pubsubs_gsub[origin_idx].publish("foobar", msg_content) + await pubsubs_gsub[central_node_index].publish(topic, b"data") - await asyncio.sleep(0.5) - # Assert that all blocking queues receive the message - for queue in queues: - msg = await queue.get() - assert msg.data == msg_content + # Check that the gossipsub of central node has fanout for the topic + assert topic in gossipsubs[central_node_index].fanout + # Check that the gossipsub of central node does not have a mesh for the topic + assert topic not in gossipsubs[central_node_index].mesh + + # Central node subscribes the topic + await pubsubs_gsub[central_node_index].subscribe(topic) + + await trio.sleep(2) + + # Check that the gossipsub of central node no longer has fanout for the topic + assert topic not in gossipsubs[central_node_index].fanout + + for i in hosts_indices: + if i in subscribed_peer_indices: + assert hosts[i].get_id() in gossipsubs[central_node_index].mesh[topic] + assert hosts[central_node_index].get_id() in gossipsubs[i].mesh[topic] + else: + assert ( + hosts[i].get_id() not in gossipsubs[central_node_index].mesh[topic] + ) + assert topic not in gossipsubs[i].mesh -@pytest.mark.parametrize("num_hosts", (10,)) -@pytest.mark.asyncio -async def test_fanout(hosts, pubsubs_gsub): - num_msgs = 5 +@pytest.mark.trio +async def test_leave(): + async with PubsubFactory.create_batch_with_gossipsub(1) as pubsubs_gsub: + gossipsub = pubsubs_gsub[0].router + topic = "test_leave" - # All pubsub subscribe to foobar except for `pubsubs_gsub[0]` - queues = [] - for i in range(1, len(pubsubs_gsub)): - q = await pubsubs_gsub[i].subscribe("foobar") + assert topic not in gossipsub.mesh - # Add each blocking queue to an array of blocking queues - queues.append(q) + await gossipsub.join(topic) + assert topic in gossipsub.mesh - # Sparsely connect libp2p hosts in random way - await dense_connect(hosts) + await gossipsub.leave(topic) + assert topic not in gossipsub.mesh - # Wait 2 seconds for heartbeat to allow mesh to connect - await asyncio.sleep(2) - - topic = "foobar" - # Send messages with origin not subscribed - for i in range(num_msgs): - msg_content = b"foo " + i.to_bytes(1, "big") - - # Pick the message origin to the node that is not subscribed to 'foobar' - origin_idx = 0 - - # publish from the randomly chosen host - await pubsubs_gsub[origin_idx].publish(topic, msg_content) - - await asyncio.sleep(0.5) - # Assert that all blocking queues receive the message - for queue in queues: - msg = await queue.get() - assert msg.data == msg_content - - # Subscribe message origin - queues.insert(0, await pubsubs_gsub[0].subscribe(topic)) - - # Send messages again - for i in range(num_msgs): - msg_content = b"bar " + i.to_bytes(1, "big") - - # Pick the message origin to the node that is not subscribed to 'foobar' - origin_idx = 0 - - # publish from the randomly chosen host - await pubsubs_gsub[origin_idx].publish(topic, msg_content) - - await asyncio.sleep(0.5) - # Assert that all blocking queues receive the message - for queue in queues: - msg = await queue.get() - assert msg.data == msg_content + # Test re-leave + await gossipsub.leave(topic) -@pytest.mark.parametrize("num_hosts", (10,)) -@pytest.mark.asyncio +@pytest.mark.trio +async def test_handle_graft(monkeypatch): + async with PubsubFactory.create_batch_with_gossipsub(2) as pubsubs_gsub: + gossipsubs = tuple(pubsub.router for pubsub in pubsubs_gsub) + + index_alice = 0 + id_alice = pubsubs_gsub[index_alice].my_id + index_bob = 1 + id_bob = pubsubs_gsub[index_bob].my_id + await connect(pubsubs_gsub[index_alice].host, pubsubs_gsub[index_bob].host) + + # Wait 2 seconds for heartbeat to allow mesh to connect + await trio.sleep(2) + + topic = "test_handle_graft" + # Only lice subscribe to the topic + await gossipsubs[index_alice].join(topic) + + # Monkey patch bob's `emit_prune` function so we can + # check if it is called in `handle_graft` + event_emit_prune = trio.Event() + + async def emit_prune(topic, sender_peer_id): + event_emit_prune.set() + + monkeypatch.setattr(gossipsubs[index_bob], "emit_prune", emit_prune) + + # Check that alice is bob's peer but not his mesh peer + assert id_alice in gossipsubs[index_bob].peers_gossipsub + assert topic not in gossipsubs[index_bob].mesh + + await gossipsubs[index_alice].emit_graft(topic, id_bob) + + # Check that `emit_prune` is called + await event_emit_prune.wait() + + # Check that bob is alice's peer but not her mesh peer + assert topic in gossipsubs[index_alice].mesh + assert id_bob not in gossipsubs[index_alice].mesh[topic] + assert id_bob in gossipsubs[index_alice].peers_gossipsub + + await gossipsubs[index_bob].emit_graft(topic, id_alice) + + await trio.sleep(1) + + # Check that bob is now alice's mesh peer + assert id_bob in gossipsubs[index_alice].mesh[topic] + + +@pytest.mark.trio +async def test_handle_prune(): + async with PubsubFactory.create_batch_with_gossipsub( + 2, heartbeat_interval=3 + ) as pubsubs_gsub: + gossipsubs = tuple(pubsub.router for pubsub in pubsubs_gsub) + + index_alice = 0 + id_alice = pubsubs_gsub[index_alice].my_id + index_bob = 1 + id_bob = pubsubs_gsub[index_bob].my_id + + topic = "test_handle_prune" + for pubsub in pubsubs_gsub: + await pubsub.subscribe(topic) + + await connect(pubsubs_gsub[index_alice].host, pubsubs_gsub[index_bob].host) + + # Wait 3 seconds for heartbeat to allow mesh to connect + await trio.sleep(3) + + # Check that they are each other's mesh peer + assert id_alice in gossipsubs[index_bob].mesh[topic] + assert id_bob in gossipsubs[index_alice].mesh[topic] + + # alice emit prune message to bob, alice should be removed + # from bob's mesh peer + await gossipsubs[index_alice].emit_prune(topic, id_bob) + + # FIXME: This test currently works because the heartbeat interval + # is increased to 3 seconds, so alice won't get add back into + # bob's mesh peer during heartbeat. + await trio.sleep(1) + + # Check that alice is no longer bob's mesh peer + assert id_alice not in gossipsubs[index_bob].mesh[topic] + assert id_bob in gossipsubs[index_alice].mesh[topic] + + +@pytest.mark.trio +async def test_dense(): + async with PubsubFactory.create_batch_with_gossipsub(10) as pubsubs_gsub: + hosts = [pubsub.host for pubsub in pubsubs_gsub] + num_msgs = 5 + + # All pubsub subscribe to foobar + queues = [await pubsub.subscribe("foobar") for pubsub in pubsubs_gsub] + + # Densely connect libp2p hosts in a random way + await dense_connect(hosts) + + # Wait 2 seconds for heartbeat to allow mesh to connect + await trio.sleep(2) + + for i in range(num_msgs): + msg_content = b"foo " + i.to_bytes(1, "big") + + # randomly pick a message origin + origin_idx = random.randint(0, len(hosts) - 1) + + # publish from the randomly chosen host + await pubsubs_gsub[origin_idx].publish("foobar", msg_content) + + await trio.sleep(0.5) + # Assert that all blocking queues receive the message + for queue in queues: + msg = await queue.receive() + assert msg.data == msg_content + + +@pytest.mark.trio +async def test_fanout(): + async with PubsubFactory.create_batch_with_gossipsub(10) as pubsubs_gsub: + hosts = [pubsub.host for pubsub in pubsubs_gsub] + num_msgs = 5 + + # All pubsub subscribe to foobar except for `pubsubs_gsub[0]` + subs = [await pubsub.subscribe("foobar") for pubsub in pubsubs_gsub[1:]] + + # Sparsely connect libp2p hosts in random way + await dense_connect(hosts) + + # Wait 2 seconds for heartbeat to allow mesh to connect + await trio.sleep(2) + + topic = "foobar" + # Send messages with origin not subscribed + for i in range(num_msgs): + msg_content = b"foo " + i.to_bytes(1, "big") + + # Pick the message origin to the node that is not subscribed to 'foobar' + origin_idx = 0 + + # publish from the randomly chosen host + await pubsubs_gsub[origin_idx].publish(topic, msg_content) + + await trio.sleep(0.5) + # Assert that all blocking queues receive the message + for sub in subs: + msg = await sub.receive() + assert msg.data == msg_content + + # Subscribe message origin + subs.insert(0, await pubsubs_gsub[0].subscribe(topic)) + + # Send messages again + for i in range(num_msgs): + msg_content = b"bar " + i.to_bytes(1, "big") + + # Pick the message origin to the node that is not subscribed to 'foobar' + origin_idx = 0 + + # publish from the randomly chosen host + await pubsubs_gsub[origin_idx].publish(topic, msg_content) + + await trio.sleep(0.5) + # Assert that all blocking queues receive the message + for sub in subs: + msg = await sub.receive() + assert msg.data == msg_content + + +@pytest.mark.trio @pytest.mark.slow -async def test_fanout_maintenance(hosts, pubsubs_gsub): - num_msgs = 5 +async def test_fanout_maintenance(): + async with PubsubFactory.create_batch_with_gossipsub(10) as pubsubs_gsub: + hosts = [pubsub.host for pubsub in pubsubs_gsub] + num_msgs = 5 - # All pubsub subscribe to foobar - queues = [] - topic = "foobar" - for i in range(1, len(pubsubs_gsub)): - q = await pubsubs_gsub[i].subscribe(topic) + # All pubsub subscribe to foobar + queues = [] + topic = "foobar" + for i in range(1, len(pubsubs_gsub)): + q = await pubsubs_gsub[i].subscribe(topic) - # Add each blocking queue to an array of blocking queues - queues.append(q) + # Add each blocking queue to an array of blocking queues + queues.append(q) - # Sparsely connect libp2p hosts in random way - await dense_connect(hosts) + # Sparsely connect libp2p hosts in random way + await dense_connect(hosts) - # Wait 2 seconds for heartbeat to allow mesh to connect - await asyncio.sleep(2) + # Wait 2 seconds for heartbeat to allow mesh to connect + await trio.sleep(2) - # Send messages with origin not subscribed - for i in range(num_msgs): - msg_content = b"foo " + i.to_bytes(1, "big") + # Send messages with origin not subscribed + for i in range(num_msgs): + msg_content = b"foo " + i.to_bytes(1, "big") - # Pick the message origin to the node that is not subscribed to 'foobar' - origin_idx = 0 + # Pick the message origin to the node that is not subscribed to 'foobar' + origin_idx = 0 + + # publish from the randomly chosen host + await pubsubs_gsub[origin_idx].publish(topic, msg_content) + + await trio.sleep(0.5) + # Assert that all blocking queues receive the message + for queue in queues: + msg = await queue.receive() + assert msg.data == msg_content + + for sub in pubsubs_gsub: + await sub.unsubscribe(topic) + + queues = [] + + await trio.sleep(2) + + # Resub and repeat + for i in range(1, len(pubsubs_gsub)): + q = await pubsubs_gsub[i].subscribe(topic) + + # Add each blocking queue to an array of blocking queues + queues.append(q) + + await trio.sleep(2) + + # Check messages can still be sent + for i in range(num_msgs): + msg_content = b"bar " + i.to_bytes(1, "big") + + # Pick the message origin to the node that is not subscribed to 'foobar' + origin_idx = 0 + + # publish from the randomly chosen host + await pubsubs_gsub[origin_idx].publish(topic, msg_content) + + await trio.sleep(0.5) + # Assert that all blocking queues receive the message + for queue in queues: + msg = await queue.receive() + assert msg.data == msg_content + + +@pytest.mark.trio +async def test_gossip_propagation(): + async with PubsubFactory.create_batch_with_gossipsub( + 2, degree=1, degree_low=0, degree_high=2, gossip_window=50, gossip_history=100 + ) as pubsubs_gsub: + topic = "foo" + await pubsubs_gsub[0].subscribe(topic) + + # node 0 publish to topic + msg_content = b"foo_msg" # publish from the randomly chosen host - await pubsubs_gsub[origin_idx].publish(topic, msg_content) + await pubsubs_gsub[0].publish(topic, msg_content) - await asyncio.sleep(0.5) - # Assert that all blocking queues receive the message - for queue in queues: - msg = await queue.get() - assert msg.data == msg_content + # now node 1 subscribes + queue_1 = await pubsubs_gsub[1].subscribe(topic) - for sub in pubsubs_gsub: - await sub.unsubscribe(topic) + await connect(pubsubs_gsub[0].host, pubsubs_gsub[1].host) - queues = [] + # wait for gossip heartbeat + await trio.sleep(2) - await asyncio.sleep(2) - - # Resub and repeat - for i in range(1, len(pubsubs_gsub)): - q = await pubsubs_gsub[i].subscribe(topic) - - # Add each blocking queue to an array of blocking queues - queues.append(q) - - await asyncio.sleep(2) - - # Check messages can still be sent - for i in range(num_msgs): - msg_content = b"bar " + i.to_bytes(1, "big") - - # Pick the message origin to the node that is not subscribed to 'foobar' - origin_idx = 0 - - # publish from the randomly chosen host - await pubsubs_gsub[origin_idx].publish(topic, msg_content) - - await asyncio.sleep(0.5) - # Assert that all blocking queues receive the message - for queue in queues: - msg = await queue.get() - assert msg.data == msg_content - - -@pytest.mark.parametrize( - "num_hosts, gossipsub_params", - ( - ( - 2, - GossipsubParams( - degree=1, - degree_low=0, - degree_high=2, - gossip_window=50, - gossip_history=100, - ), - ), - ), -) -@pytest.mark.asyncio -async def test_gossip_propagation(hosts, pubsubs_gsub): - topic = "foo" - await pubsubs_gsub[0].subscribe(topic) - - # node 0 publish to topic - msg_content = b"foo_msg" - - # publish from the randomly chosen host - await pubsubs_gsub[0].publish(topic, msg_content) - - # now node 1 subscribes - queue_1 = await pubsubs_gsub[1].subscribe(topic) - - await connect(hosts[0], hosts[1]) - - # wait for gossip heartbeat - await asyncio.sleep(2) - - # should be able to read message - msg = await queue_1.get() - assert msg.data == msg_content + # should be able to read message + msg = await queue_1.receive() + assert msg.data == msg_content diff --git a/tests/pubsub/test_gossipsub_backward_compatibility.py b/tests/pubsub/test_gossipsub_backward_compatibility.py index d82fd22..08f0284 100644 --- a/tests/pubsub/test_gossipsub_backward_compatibility.py +++ b/tests/pubsub/test_gossipsub_backward_compatibility.py @@ -3,25 +3,25 @@ import functools import pytest from libp2p.tools.constants import FLOODSUB_PROTOCOL_ID -from libp2p.tools.factories import GossipsubFactory +from libp2p.tools.factories import PubsubFactory from libp2p.tools.pubsub.floodsub_integration_test_settings import ( floodsub_protocol_pytest_params, perform_test_from_obj, ) -@pytest.mark.asyncio -async def test_gossipsub_initialize_with_floodsub_protocol(): - GossipsubFactory(protocols=[FLOODSUB_PROTOCOL_ID]) - - @pytest.mark.parametrize("test_case_obj", floodsub_protocol_pytest_params) -@pytest.mark.asyncio +@pytest.mark.trio @pytest.mark.slow async def test_gossipsub_run_with_floodsub_tests(test_case_obj): await perform_test_from_obj( test_case_obj, functools.partial( - GossipsubFactory, degree=3, degree_low=2, degree_high=4, time_to_live=30 + PubsubFactory.create_batch_with_gossipsub, + protocols=[FLOODSUB_PROTOCOL_ID], + degree=3, + degree_low=2, + degree_high=4, + time_to_live=30, ), ) diff --git a/tests/pubsub/test_mcache.py b/tests/pubsub/test_mcache.py index e80ad27..fb764b3 100644 --- a/tests/pubsub/test_mcache.py +++ b/tests/pubsub/test_mcache.py @@ -1,5 +1,3 @@ -import pytest - from libp2p.pubsub.mcache import MessageCache @@ -12,8 +10,7 @@ class Msg: self.from_id = from_id -@pytest.mark.asyncio -async def test_mcache(): +def test_mcache(): # Ported from: # https://github.com/libp2p/go-libp2p-pubsub/blob/51b7501433411b5096cac2b4994a36a68515fc03/mcache_test.go mcache = MessageCache(3, 5) diff --git a/tests/pubsub/test_pubsub.py b/tests/pubsub/test_pubsub.py index 22cea0c..ea04788 100644 --- a/tests/pubsub/test_pubsub.py +++ b/tests/pubsub/test_pubsub.py @@ -5,12 +5,11 @@ import pytest import trio from libp2p.exceptions import ValidationError -from libp2p.peer.id import ID from libp2p.pubsub.pb import rpc_pb2 +from libp2p.tools.constants import MAX_READ_LEN +from libp2p.tools.factories import IDFactory, PubsubFactory, net_stream_pair_factory from libp2p.tools.pubsub.utils import make_pubsub_msg from libp2p.tools.utils import connect -from libp2p.tools.constants import MAX_READ_LEN -from libp2p.tools.factories import PubsubFactory, net_stream_pair_factory, IDFactory from libp2p.utils import encode_varint_prefixed TESTING_TOPIC = "TEST_SUBSCRIBE" @@ -250,14 +249,14 @@ async def test_continuously_read_stream(monkeypatch, nursery, is_host_secure): async def mock_push_msg(msg_forwarder, msg): event_push_msg.set() - await trio.sleep(0) + await trio.hazmat.checkpoint() def mock_handle_subscription(origin_id, sub_message): event_handle_subscription.set() async def mock_handle_rpc(rpc, sender_peer_id): event_handle_rpc.set() - await trio.sleep(0) + await trio.hazmat.checkpoint() with monkeypatch.context() as m: m.setattr(pubsubs_fsub[0], "push_msg", mock_push_msg) diff --git a/tests/stream_muxer/test_mplex_stream.py b/tests/stream_muxer/test_mplex_stream.py index e47af49..55ee97b 100644 --- a/tests/stream_muxer/test_mplex_stream.py +++ b/tests/stream_muxer/test_mplex_stream.py @@ -69,33 +69,10 @@ async def test_mplex_stream_read_after_remote_closed(mplex_stream_pair): await stream_0.close() assert stream_0.event_local_closed.is_set() await trio.sleep(0.01) - print( - "!@# ", - stream_0.muxed_conn.event_shutting_down.is_set(), - stream_0.muxed_conn.event_closed.is_set(), - stream_1.muxed_conn.event_shutting_down.is_set(), - stream_1.muxed_conn.event_closed.is_set(), - ) # await trio.sleep(100000) await wait_all_tasks_blocked() - print( - "!@# ", - stream_0.muxed_conn.event_shutting_down.is_set(), - stream_0.muxed_conn.event_closed.is_set(), - stream_1.muxed_conn.event_shutting_down.is_set(), - stream_1.muxed_conn.event_closed.is_set(), - ) - print("!@# sleeping") - print("!@# result=", stream_1.event_remote_closed.is_set()) # await trio.sleep_forever() assert stream_1.event_remote_closed.is_set() - print( - "!@# ", - stream_0.muxed_conn.event_shutting_down.is_set(), - stream_0.muxed_conn.event_closed.is_set(), - stream_1.muxed_conn.event_shutting_down.is_set(), - stream_1.muxed_conn.event_closed.is_set(), - ) assert (await stream_1.read(MAX_READ_LEN)) == DATA with pytest.raises(MplexStreamEOF): await stream_1.read(MAX_READ_LEN) diff --git a/tests/transport/test_tcp.py b/tests/transport/test_tcp.py index c8fe6f2..abd5884 100644 --- a/tests/transport/test_tcp.py +++ b/tests/transport/test_tcp.py @@ -3,7 +3,7 @@ import pytest import trio from libp2p.network.connection.raw_connection import RawConnection -from libp2p.tools.constants import LISTEN_MADDR, MAX_READ_LEN +from libp2p.tools.constants import LISTEN_MADDR from libp2p.transport.tcp.tcp import TCP