From e9ab0646e38a1e13fd0ba5e2d6362da459b241db Mon Sep 17 00:00:00 2001 From: mhchia Date: Tue, 3 Dec 2019 17:27:49 +0800 Subject: [PATCH] Fix Pubsub --- libp2p/pubsub/floodsub.py | 8 + libp2p/pubsub/pubsub.py | 181 ++++--- libp2p/pubsub/pubsub_notifee.py | 19 +- libp2p/stream_muxer/mplex/mplex.py | 2 +- libp2p/tools/factories.py | 49 +- tests/pubsub/conftest.py | 32 +- tests/pubsub/test_pubsub.py | 800 ++++++++++++++--------------- 7 files changed, 568 insertions(+), 523 deletions(-) diff --git a/libp2p/pubsub/floodsub.py b/libp2p/pubsub/floodsub.py index bac0bd7..8c15a44 100644 --- a/libp2p/pubsub/floodsub.py +++ b/libp2p/pubsub/floodsub.py @@ -1,6 +1,8 @@ import logging from typing import Iterable, List, Sequence +import trio + from libp2p.network.stream.exceptions import StreamClosed from libp2p.peer.id import ID from libp2p.typing import TProtocol @@ -61,6 +63,8 @@ class FloodSub(IPubsubRouter): :param rpc: rpc message """ + # Checkpoint + await trio.sleep(0) async def publish(self, msg_forwarder: ID, pubsub_msg: rpc_pb2.Message) -> None: """ @@ -102,6 +106,8 @@ class FloodSub(IPubsubRouter): :param topic: topic to join """ + # Checkpoint + await trio.sleep(0) async def leave(self, topic: str) -> None: """ @@ -110,6 +116,8 @@ class FloodSub(IPubsubRouter): :param topic: topic to leave """ + # Checkpoint + await trio.sleep(0) def _get_peers_to_send( self, topic_ids: Iterable[str], msg_forwarder: ID, origin: ID diff --git a/libp2p/pubsub/pubsub.py b/libp2p/pubsub/pubsub.py index 3834eb4..7c4b50d 100644 --- a/libp2p/pubsub/pubsub.py +++ b/libp2p/pubsub/pubsub.py @@ -1,11 +1,13 @@ -import asyncio +from abc import ABC, abstractmethod import logging +import math import time from typing import ( TYPE_CHECKING, Awaitable, Callable, Dict, + KeysView, List, NamedTuple, Tuple, @@ -13,8 +15,10 @@ from typing import ( cast, ) +from async_service import Service import base58 from lru import LRU +import trio from libp2p.exceptions import ParseError, ValidationError from libp2p.host.host_interface import IHost @@ -53,24 +57,24 @@ class TopicValidator(NamedTuple): is_async: bool -class Pubsub: +class BasePubsub(ABC): + pass + + +class Pubsub(BasePubsub, Service): host: IHost - my_id: ID router: "IPubsubRouter" - peer_queue: "asyncio.Queue[ID]" - dead_peer_queue: "asyncio.Queue[ID]" - - protocols: List[TProtocol] - - incoming_msgs_from_peers: "asyncio.Queue[rpc_pb2.Message]" - outgoing_messages: "asyncio.Queue[rpc_pb2.Message]" + peer_receive_channel: "trio.MemoryReceiveChannel[ID]" + dead_peer_receive_channel: "trio.MemoryReceiveChannel[ID]" seen_messages: LRU - my_topics: Dict[str, "asyncio.Queue[rpc_pb2.Message]"] + # TODO: Implement `trio.abc.Channel`? + subscribed_topics_send: Dict[str, "trio.MemorySendChannel[rpc_pb2.Message]"] + subscribed_topics_receive: Dict[str, "trio.MemoryReceiveChannel[rpc_pb2.Message]"] peer_topics: Dict[str, List[ID]] peers: Dict[ID, INetStream] @@ -80,10 +84,8 @@ class Pubsub: # TODO: Be sure it is increased atomically everytime. counter: int # uint64 - _tasks: List["asyncio.Future[Any]"] - def __init__( - self, host: IHost, router: "IPubsubRouter", my_id: ID, cache_size: int = None + self, host: IHost, router: "IPubsubRouter", cache_size: int = None ) -> None: """ Construct a new Pubsub object, which is responsible for handling all @@ -97,28 +99,26 @@ class Pubsub: """ self.host = host self.router = router - self.my_id = my_id # Attach this new Pubsub object to the router self.router.attach(self) + peer_send_channel, peer_receive_channel = trio.open_memory_channel(0) + dead_peer_send_channel, dead_peer_receive_channel = trio.open_memory_channel(0) + # Only keep the receive channels in `Pubsub`. + # Therefore, we can only close from the receive side. + self.peer_receive_channel = peer_receive_channel + self.dead_peer_receive_channel = dead_peer_receive_channel # Register a notifee - self.peer_queue = asyncio.Queue() - self.dead_peer_queue = asyncio.Queue() self.host.get_network().register_notifee( - PubsubNotifee(self.peer_queue, self.dead_peer_queue) + PubsubNotifee(peer_send_channel, dead_peer_send_channel) ) # Register stream handlers for each pubsub router protocol to handle # the pubsub streams opened on those protocols - self.protocols = self.router.get_protocols() - for protocol in self.protocols: + for protocol in router.protocols: self.host.set_stream_handler(protocol, self.stream_handler) - # Use asyncio queues for proper context switching - self.incoming_msgs_from_peers = asyncio.Queue() - self.outgoing_messages = asyncio.Queue() - # keeps track of seen messages as LRU cache if cache_size is None: self.cache_size = 128 @@ -129,7 +129,8 @@ class Pubsub: # Map of topics we are subscribed to blocking queues # for when the given topic receives a message - self.my_topics = {} + self.subscribed_topics_send = {} + self.subscribed_topics_receive = {} # Map of topic to peers to keep track of what peers are subscribed to self.peer_topics = {} @@ -142,16 +143,28 @@ class Pubsub: self.counter = time.time_ns() - self._tasks = [] - # Call handle peer to keep waiting for updates to peer queue - self._tasks.append(asyncio.ensure_future(self.handle_peer_queue())) - self._tasks.append(asyncio.ensure_future(self.handle_dead_peer_queue())) + async def run(self) -> None: + self.manager.run_daemon_task(self.handle_peer_queue) + self.manager.run_daemon_task(self.handle_dead_peer_queue) + await self.manager.wait_finished() + + @property + def my_id(self) -> ID: + return self.host.get_id() + + @property + def protocols(self) -> Tuple[TProtocol, ...]: + return tuple(self.router.get_protocols()) + + @property + def topic_ids(self) -> KeysView[str]: + return self.subscribed_topics_receive.keys() def get_hello_packet(self) -> rpc_pb2.RPC: """Generate subscription message with all topics we are subscribed to only send hello packet if we have subscribed topics.""" packet = rpc_pb2.RPC() - for topic_id in self.my_topics: + for topic_id in self.topic_ids: packet.subscriptions.extend( [rpc_pb2.RPC.SubOpts(subscribe=True, topicid=topic_id)] ) @@ -166,7 +179,7 @@ class Pubsub: """ peer_id = stream.muxed_conn.peer_id - while True: + while self.manager.is_running: incoming: bytes = await read_varint_prefixed_bytes(stream) rpc_incoming: rpc_pb2.RPC = rpc_pb2.RPC() rpc_incoming.ParseFromString(incoming) @@ -178,11 +191,7 @@ class Pubsub: logger.debug( "received `publish` message %s from peer %s", msg, peer_id ) - self._tasks.append( - asyncio.ensure_future( - self.push_msg(msg_forwarder=peer_id, msg=msg) - ) - ) + self.manager.run_task(self.push_msg, peer_id, msg) if rpc_incoming.subscriptions: # deal with RPC.subscriptions @@ -210,9 +219,6 @@ class Pubsub: ) await self.router.handle_rpc(rpc_incoming, peer_id) - # Force context switch - await asyncio.sleep(0) - def set_topic_validator( self, topic: str, validator: ValidatorFn, is_async_validator: bool ) -> None: @@ -285,7 +291,6 @@ class Pubsub: logger.debug("Fail to add new peer %s: stream closed", peer_id) del self.peers[peer_id] return - # TODO: Check EOF of this stream. # TODO: Check if the peer in black list. try: self.router.add_peer(peer_id, stream.get_protocol()) @@ -311,23 +316,25 @@ class Pubsub: async def handle_peer_queue(self) -> None: """ - Continuously read from peer queue and each time a new peer is found, + Continuously read from peer channel and each time a new peer is found, open a stream to the peer using a supported pubsub protocol TODO: Handle failure for when the peer does not support any of the pubsub protocols we support """ - while True: - peer_id: ID = await self.peer_queue.get() - # Add Peer - self._tasks.append(asyncio.ensure_future(self._handle_new_peer(peer_id))) + async with self.peer_receive_channel: + while self.manager.is_running: + peer_id: ID = await self.peer_receive_channel.receive() + # Add Peer + self.manager.run_task(self._handle_new_peer, peer_id) async def handle_dead_peer_queue(self) -> None: - """Continuously read from dead peer queue and close the stream between + """Continuously read from dead peer channel and close the stream between that peer and remove peer info from pubsub and pubsub router.""" - while True: - peer_id: ID = await self.dead_peer_queue.get() - # Remove Peer - self._handle_dead_peer(peer_id) + async with self.dead_peer_receive_channel: + while self.manager.is_running: + peer_id: ID = await self.dead_peer_receive_channel.receive() + # Remove Peer + self._handle_dead_peer(peer_id) def handle_subscription( self, origin_id: ID, sub_message: rpc_pb2.RPC.SubOpts @@ -361,13 +368,16 @@ class Pubsub: # Check if this message has any topics that we are subscribed to for topic in publish_message.topicIDs: - if topic in self.my_topics: + if topic in self.topic_ids: # we are subscribed to a topic this message was sent for, # so add message to the subscription output queue # for each topic - await self.my_topics[topic].put(publish_message) + await self.subscribed_topics_send[topic].send(publish_message) - async def subscribe(self, topic_id: str) -> "asyncio.Queue[rpc_pb2.Message]": + # TODO: Change to return an `AsyncIterable` to be I/O-agnostic? + async def subscribe( + self, topic_id: str + ) -> "trio.MemoryReceiveChannel[rpc_pb2.Message]": """ Subscribe ourself to a topic. @@ -377,11 +387,13 @@ class Pubsub: logger.debug("subscribing to topic %s", topic_id) # Already subscribed - if topic_id in self.my_topics: - return self.my_topics[topic_id] + if topic_id in self.topic_ids: + return self.subscribed_topics_receive[topic_id] - # Map topic_id to blocking queue - self.my_topics[topic_id] = asyncio.Queue() + # Map topic_id to a blocking channel + send_channel, receive_channel = trio.open_memory_channel(math.inf) + self.subscribed_topics_send[topic_id] = send_channel + self.subscribed_topics_receive[topic_id] = receive_channel # Create subscribe message packet: rpc_pb2.RPC = rpc_pb2.RPC() @@ -395,8 +407,8 @@ class Pubsub: # Tell router we are joining this topic await self.router.join(topic_id) - # Return the asyncio queue for messages on this topic - return self.my_topics[topic_id] + # Return the trio channel for messages on this topic + return receive_channel async def unsubscribe(self, topic_id: str) -> None: """ @@ -408,10 +420,14 @@ class Pubsub: logger.debug("unsubscribing from topic %s", topic_id) # Return if we already unsubscribed from the topic - if topic_id not in self.my_topics: + if topic_id not in self.topic_ids: return - # Remove topic_id from map if present - del self.my_topics[topic_id] + # Remove topic_id from the maps before yielding + send_channel = self.subscribed_topics_send[topic_id] + del self.subscribed_topics_send[topic_id] + del self.subscribed_topics_receive[topic_id] + # Only close the send side + await send_channel.aclose() # Create unsubscribe message packet: rpc_pb2.RPC = rpc_pb2.RPC() @@ -453,13 +469,13 @@ class Pubsub: data=data, topicIDs=[topic_id], # Origin is ourself. - from_id=self.host.get_id().to_bytes(), + from_id=self.my_id.to_bytes(), seqno=self._next_seqno(), ) # TODO: Sign with our signing key - await self.push_msg(self.host.get_id(), msg) + await self.push_msg(self.my_id, msg) logger.debug("successfully published message %s", msg) @@ -470,12 +486,12 @@ class Pubsub: :param msg_forwarder: the peer who forward us the message. :param msg: the message. """ - sync_topic_validators = [] - async_topic_validator_futures: List[Awaitable[bool]] = [] + sync_topic_validators: List[SyncValidatorFn] = [] + async_topic_validators: List[AsyncValidatorFn] = [] for topic_validator in self.get_msg_validators(msg): if topic_validator.is_async: - async_topic_validator_futures.append( - cast(Awaitable[bool], topic_validator.validator(msg_forwarder, msg)) + async_topic_validators.append( + cast(AsyncValidatorFn, topic_validator.validator) ) else: sync_topic_validators.append( @@ -488,9 +504,20 @@ class Pubsub: # TODO: Implement throttle on async validators - if len(async_topic_validator_futures) > 0: - results = await asyncio.gather(*async_topic_validator_futures) - if not all(results): + if len(async_topic_validators) > 0: + # TODO: Use a better pattern + final_result = True + + async def run_async_validator(func: AsyncValidatorFn) -> None: + nonlocal final_result + result = await func(msg_forwarder, msg) + final_result = final_result and result + + async with trio.open_nursery() as nursery: + for validator in async_topic_validators: + nursery.start_soon(run_async_validator, validator) + + if not final_result: raise ValidationError(f"Validation failed for msg={msg}") async def push_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None: @@ -551,14 +578,4 @@ class Pubsub: self.seen_messages[msg_id] = 1 def _is_subscribed_to_msg(self, msg: rpc_pb2.Message) -> bool: - if not self.my_topics: - return False - return any(topic in self.my_topics for topic in msg.topicIDs) - - async def close(self) -> None: - for task in self._tasks: - task.cancel() - try: - await task - except asyncio.CancelledError: - pass + return any(topic in self.topic_ids for topic in msg.topicIDs) diff --git a/libp2p/pubsub/pubsub_notifee.py b/libp2p/pubsub/pubsub_notifee.py index 6afa9ad..7394736 100644 --- a/libp2p/pubsub/pubsub_notifee.py +++ b/libp2p/pubsub/pubsub_notifee.py @@ -8,19 +8,19 @@ from libp2p.network.notifee_interface import INotifee from libp2p.network.stream.net_stream_interface import INetStream if TYPE_CHECKING: - import asyncio # noqa: F401 + import trio # noqa: F401 from libp2p.peer.id import ID # noqa: F401 class PubsubNotifee(INotifee): - initiator_peers_queue: "asyncio.Queue[ID]" - dead_peers_queue: "asyncio.Queue[ID]" + initiator_peers_queue: "trio.MemorySendChannel[ID]" + dead_peers_queue: "trio.MemorySendChannel[ID]" def __init__( self, - initiator_peers_queue: "asyncio.Queue[ID]", - dead_peers_queue: "asyncio.Queue[ID]", + initiator_peers_queue: "trio.MemorySendChannel[ID]", + dead_peers_queue: "trio.MemorySendChannel[ID]", ) -> None: """ :param initiator_peers_queue: queue to add new peers to so that pubsub @@ -46,7 +46,12 @@ class PubsubNotifee(INotifee): :param network: network the connection was opened on :param conn: connection that was opened """ - await self.initiator_peers_queue.put(conn.muxed_conn.peer_id) + try: + await self.initiator_peers_queue.send(conn.muxed_conn.peer_id) + except trio.BrokenResourceError: + # Raised when the receive channel is closed. + # TODO: Do something with loggers? + ... async def disconnected(self, network: INetwork, conn: INetConn) -> None: """ @@ -56,7 +61,7 @@ class PubsubNotifee(INotifee): :param network: network the connection was opened on :param conn: connection that was opened """ - await self.dead_peers_queue.put(conn.muxed_conn.peer_id) + await self.dead_peers_queue.send(conn.muxed_conn.peer_id) async def listen(self, network: INetwork, multiaddr: Multiaddr) -> None: pass diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index 6d8a64e..ac6cdcd 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -281,7 +281,7 @@ class Mplex(IMuxedConn, Service): mplex_stream = await self._initialize_stream(stream_id, message.decode()) try: await self.new_stream_send_channel.send(mplex_stream) - except (trio.BrokenResourceError, trio.EndOfChannel): + except (trio.BrokenResourceError, trio.ClosedResourceError): raise MplexUnavailable async def _handle_message(self, stream_id: StreamID, message: bytes) -> None: diff --git a/libp2p/tools/factories.py b/libp2p/tools/factories.py index 6b6c78a..ac24301 100644 --- a/libp2p/tools/factories.py +++ b/libp2p/tools/factories.py @@ -5,6 +5,7 @@ from async_service import background_trio_service import factory import trio +from libp2p.tools.constants import GOSSIPSUB_PARAMS from libp2p import generate_new_rsa_identity, generate_peer_id_from from libp2p.crypto.keys import KeyPair from libp2p.host.basic_host import BasicHost @@ -15,6 +16,7 @@ from libp2p.network.connection.swarm_connection import SwarmConn from libp2p.network.stream.net_stream_interface import INetStream from libp2p.network.swarm import Swarm from libp2p.peer.peerstore import PeerStore +from libp2p.peer.id import ID from libp2p.pubsub.floodsub import FloodSub from libp2p.pubsub.gossipsub import GossipSub from libp2p.pubsub.pubsub import Pubsub @@ -28,15 +30,19 @@ from libp2p.transport.typing import TMuxerOptions from libp2p.transport.upgrader import TransportUpgrader from libp2p.typing import TProtocol -from .constants import ( - FLOODSUB_PROTOCOL_ID, - GOSSIPSUB_PARAMS, - GOSSIPSUB_PROTOCOL_ID, - LISTEN_MADDR, -) +from .constants import FLOODSUB_PROTOCOL_ID, GOSSIPSUB_PROTOCOL_ID, LISTEN_MADDR from .utils import connect, connect_swarm +class IDFactory(factory.Factory): + class Meta: + model = ID + + peer_id_bytes = factory.LazyFunction( + lambda: generate_peer_id_from(generate_new_rsa_identity()) + ) + + def security_transport_factory( is_secure: bool, key_pair: KeyPair ) -> Dict[TProtocol, BaseSecureTransport]: @@ -181,9 +187,38 @@ class PubsubFactory(factory.Factory): host = factory.SubFactory(HostFactory) router = None - my_id = factory.LazyAttribute(lambda obj: obj.host.get_id()) cache_size = None + @classmethod + @asynccontextmanager + async def create_and_start(cls, host, router, cache_size): + pubsub = PubsubFactory(host=host, router=router, cache_size=cache_size) + async with background_trio_service(pubsub): + yield pubsub + + @classmethod + @asynccontextmanager + async def create_batch_with_floodsub( + cls, number: int, is_secure: bool = False, cache_size: int = None + ): + floodsubs = FloodsubFactory.create_batch(number) + async with HostFactory.create_batch_and_listen(is_secure, number) as hosts: + # Pubsubs should exit before hosts + async with AsyncExitStack() as stack: + pubsubs = [ + await stack.enter_async_context( + cls.create_and_start(host, router, cache_size) + ) + for host, router in zip(hosts, floodsubs) + ] + yield pubsubs + + # @classmethod + # async def create_batch_with_gossipsub( + # cls, number: int, cache_size: int = None, gossipsub_params=GOSSIPSUB_PARAMS + # ): + # ... + @asynccontextmanager async def swarm_pair_factory( diff --git a/tests/pubsub/conftest.py b/tests/pubsub/conftest.py index 9dbe90b..6c08dd7 100644 --- a/tests/pubsub/conftest.py +++ b/tests/pubsub/conftest.py @@ -4,18 +4,6 @@ from libp2p.tools.constants import GOSSIPSUB_PARAMS from libp2p.tools.factories import FloodsubFactory, GossipsubFactory, PubsubFactory -def _make_pubsubs(hosts, pubsub_routers, cache_size): - if len(pubsub_routers) != len(hosts): - raise ValueError( - f"lenght of pubsub_routers={pubsub_routers} should be equaled to the " - f"length of hosts={len(hosts)}" - ) - return tuple( - PubsubFactory(host=host, router=router, cache_size=cache_size) - for host, router in zip(hosts, pubsub_routers) - ) - - @pytest.fixture def pubsub_cache_size(): return None # default @@ -26,17 +14,9 @@ def gossipsub_params(): return GOSSIPSUB_PARAMS -@pytest.fixture -def pubsubs_fsub(num_hosts, hosts, pubsub_cache_size): - floodsubs = FloodsubFactory.create_batch(num_hosts) - _pubsubs_fsub = _make_pubsubs(hosts, floodsubs, pubsub_cache_size) - yield _pubsubs_fsub - # TODO: Clean up - - -@pytest.fixture -def pubsubs_gsub(num_hosts, hosts, pubsub_cache_size, gossipsub_params): - gossipsubs = GossipsubFactory.create_batch(num_hosts, **gossipsub_params._asdict()) - _pubsubs_gsub = _make_pubsubs(hosts, gossipsubs, pubsub_cache_size) - yield _pubsubs_gsub - # TODO: Clean up +# @pytest.fixture +# def pubsubs_gsub(num_hosts, hosts, pubsub_cache_size, gossipsub_params): +# gossipsubs = GossipsubFactory.create_batch(num_hosts, **gossipsub_params._asdict()) +# _pubsubs_gsub = _make_pubsubs(hosts, gossipsubs, pubsub_cache_size) +# yield _pubsubs_gsub +# # TODO: Clean up diff --git a/tests/pubsub/test_pubsub.py b/tests/pubsub/test_pubsub.py index ebe2003..22cea0c 100644 --- a/tests/pubsub/test_pubsub.py +++ b/tests/pubsub/test_pubsub.py @@ -1,348 +1,328 @@ -import asyncio +from contextlib import contextmanager from typing import NamedTuple import pytest +import trio from libp2p.exceptions import ValidationError from libp2p.peer.id import ID from libp2p.pubsub.pb import rpc_pb2 from libp2p.tools.pubsub.utils import make_pubsub_msg from libp2p.tools.utils import connect +from libp2p.tools.constants import MAX_READ_LEN +from libp2p.tools.factories import PubsubFactory, net_stream_pair_factory, IDFactory from libp2p.utils import encode_varint_prefixed TESTING_TOPIC = "TEST_SUBSCRIBE" TESTING_DATA = b"data" -@pytest.mark.parametrize("num_hosts", (1,)) -@pytest.mark.asyncio -async def test_subscribe_and_unsubscribe(pubsubs_fsub): - await pubsubs_fsub[0].subscribe(TESTING_TOPIC) - assert TESTING_TOPIC in pubsubs_fsub[0].my_topics +@pytest.mark.trio +async def test_subscribe_and_unsubscribe(): + async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub: + await pubsubs_fsub[0].subscribe(TESTING_TOPIC) + assert TESTING_TOPIC in pubsubs_fsub[0].topic_ids - await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC) - assert TESTING_TOPIC not in pubsubs_fsub[0].my_topics + await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC) + assert TESTING_TOPIC not in pubsubs_fsub[0].topic_ids -@pytest.mark.parametrize("num_hosts", (1,)) -@pytest.mark.asyncio -async def test_re_subscribe(pubsubs_fsub): - await pubsubs_fsub[0].subscribe(TESTING_TOPIC) - assert TESTING_TOPIC in pubsubs_fsub[0].my_topics +@pytest.mark.trio +async def test_re_subscribe(): + async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub: + await pubsubs_fsub[0].subscribe(TESTING_TOPIC) + assert TESTING_TOPIC in pubsubs_fsub[0].topic_ids - await pubsubs_fsub[0].subscribe(TESTING_TOPIC) - assert TESTING_TOPIC in pubsubs_fsub[0].my_topics + await pubsubs_fsub[0].subscribe(TESTING_TOPIC) + assert TESTING_TOPIC in pubsubs_fsub[0].topic_ids -@pytest.mark.parametrize("num_hosts", (1,)) -@pytest.mark.asyncio -async def test_re_unsubscribe(pubsubs_fsub): - # Unsubscribe from topic we didn't even subscribe to - assert "NOT_MY_TOPIC" not in pubsubs_fsub[0].my_topics - await pubsubs_fsub[0].unsubscribe("NOT_MY_TOPIC") - assert "NOT_MY_TOPIC" not in pubsubs_fsub[0].my_topics +@pytest.mark.trio +async def test_re_unsubscribe(): + async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub: + # Unsubscribe from topic we didn't even subscribe to + assert "NOT_MY_TOPIC" not in pubsubs_fsub[0].topic_ids + await pubsubs_fsub[0].unsubscribe("NOT_MY_TOPIC") + assert "NOT_MY_TOPIC" not in pubsubs_fsub[0].topic_ids - await pubsubs_fsub[0].subscribe(TESTING_TOPIC) - assert TESTING_TOPIC in pubsubs_fsub[0].my_topics + await pubsubs_fsub[0].subscribe(TESTING_TOPIC) + assert TESTING_TOPIC in pubsubs_fsub[0].topic_ids - await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC) - assert TESTING_TOPIC not in pubsubs_fsub[0].my_topics + await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC) + assert TESTING_TOPIC not in pubsubs_fsub[0].topic_ids - await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC) - assert TESTING_TOPIC not in pubsubs_fsub[0].my_topics + await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC) + assert TESTING_TOPIC not in pubsubs_fsub[0].topic_ids -@pytest.mark.asyncio -async def test_peers_subscribe(pubsubs_fsub): - await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host) - await pubsubs_fsub[0].subscribe(TESTING_TOPIC) - # Yield to let 0 notify 1 - await asyncio.sleep(0.1) - assert pubsubs_fsub[0].my_id in pubsubs_fsub[1].peer_topics[TESTING_TOPIC] - await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC) - # Yield to let 0 notify 1 - await asyncio.sleep(0.1) - assert pubsubs_fsub[0].my_id not in pubsubs_fsub[1].peer_topics[TESTING_TOPIC] +@pytest.mark.trio +async def test_peers_subscribe(): + async with PubsubFactory.create_batch_with_floodsub(2) as pubsubs_fsub: + await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host) + await pubsubs_fsub[0].subscribe(TESTING_TOPIC) + # Yield to let 0 notify 1 + await trio.sleep(0.1) + assert pubsubs_fsub[0].my_id in pubsubs_fsub[1].peer_topics[TESTING_TOPIC] + await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC) + # Yield to let 0 notify 1 + await trio.sleep(0.1) + assert pubsubs_fsub[0].my_id not in pubsubs_fsub[1].peer_topics[TESTING_TOPIC] -@pytest.mark.parametrize("num_hosts", (1,)) -@pytest.mark.asyncio -async def test_get_hello_packet(pubsubs_fsub): - def _get_hello_packet_topic_ids(): - packet = pubsubs_fsub[0].get_hello_packet() - return tuple(sub.topicid for sub in packet.subscriptions) +@pytest.mark.trio +async def test_get_hello_packet(): + async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub: - # Test: No subscription, so there should not be any topic ids in the hello packet. - assert len(_get_hello_packet_topic_ids()) == 0 + def _get_hello_packet_topic_ids(): + packet = pubsubs_fsub[0].get_hello_packet() + return tuple(sub.topicid for sub in packet.subscriptions) - # Test: After subscriptions, topic ids should be in the hello packet. - topic_ids = ["t", "o", "p", "i", "c"] - await asyncio.gather(*[pubsubs_fsub[0].subscribe(topic) for topic in topic_ids]) - topic_ids_in_hello = _get_hello_packet_topic_ids() - for topic in topic_ids: - assert topic in topic_ids_in_hello + # Test: No subscription, so there should not be any topic ids in the hello packet. + assert len(_get_hello_packet_topic_ids()) == 0 + + # Test: After subscriptions, topic ids should be in the hello packet. + topic_ids = ["t", "o", "p", "i", "c"] + for topic in topic_ids: + await pubsubs_fsub[0].subscribe(topic) + topic_ids_in_hello = _get_hello_packet_topic_ids() + for topic in topic_ids: + assert topic in topic_ids_in_hello -@pytest.mark.parametrize("num_hosts", (1,)) -@pytest.mark.asyncio -async def test_set_and_remove_topic_validator(pubsubs_fsub): +@pytest.mark.trio +async def test_set_and_remove_topic_validator(): + async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub: + is_sync_validator_called = False - is_sync_validator_called = False + def sync_validator(peer_id, msg): + nonlocal is_sync_validator_called + is_sync_validator_called = True - def sync_validator(peer_id, msg): - nonlocal is_sync_validator_called - is_sync_validator_called = True + is_async_validator_called = False - is_async_validator_called = False + async def async_validator(peer_id, msg): + nonlocal is_async_validator_called + is_async_validator_called = True - async def async_validator(peer_id, msg): - nonlocal is_async_validator_called - is_async_validator_called = True + topic = "TEST_VALIDATOR" - topic = "TEST_VALIDATOR" + assert topic not in pubsubs_fsub[0].topic_validators - assert topic not in pubsubs_fsub[0].topic_validators + # Register sync validator + pubsubs_fsub[0].set_topic_validator(topic, sync_validator, False) - # Register sync validator - pubsubs_fsub[0].set_topic_validator(topic, sync_validator, False) + assert topic in pubsubs_fsub[0].topic_validators + topic_validator = pubsubs_fsub[0].topic_validators[topic] + assert not topic_validator.is_async - assert topic in pubsubs_fsub[0].topic_validators - topic_validator = pubsubs_fsub[0].topic_validators[topic] - assert not topic_validator.is_async + # Validate with sync validator + topic_validator.validator(peer_id=IDFactory(), msg="msg") - # Validate with sync validator - topic_validator.validator(peer_id=ID(b"peer"), msg="msg") + assert is_sync_validator_called + assert not is_async_validator_called - assert is_sync_validator_called - assert not is_async_validator_called + # Register with async validator + pubsubs_fsub[0].set_topic_validator(topic, async_validator, True) - # Register with async validator - pubsubs_fsub[0].set_topic_validator(topic, async_validator, True) + is_sync_validator_called = False + assert topic in pubsubs_fsub[0].topic_validators + topic_validator = pubsubs_fsub[0].topic_validators[topic] + assert topic_validator.is_async - is_sync_validator_called = False - assert topic in pubsubs_fsub[0].topic_validators - topic_validator = pubsubs_fsub[0].topic_validators[topic] - assert topic_validator.is_async + # Validate with async validator + await topic_validator.validator(peer_id=IDFactory(), msg="msg") - # Validate with async validator - await topic_validator.validator(peer_id=ID(b"peer"), msg="msg") + assert is_async_validator_called + assert not is_sync_validator_called - assert is_async_validator_called - assert not is_sync_validator_called - - # Remove validator - pubsubs_fsub[0].remove_topic_validator(topic) - assert topic not in pubsubs_fsub[0].topic_validators + # Remove validator + pubsubs_fsub[0].remove_topic_validator(topic) + assert topic not in pubsubs_fsub[0].topic_validators -@pytest.mark.parametrize("num_hosts", (1,)) -@pytest.mark.asyncio -async def test_get_msg_validators(pubsubs_fsub): +@pytest.mark.trio +async def test_get_msg_validators(): + async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub: + times_sync_validator_called = 0 - times_sync_validator_called = 0 + def sync_validator(peer_id, msg): + nonlocal times_sync_validator_called + times_sync_validator_called += 1 - def sync_validator(peer_id, msg): - nonlocal times_sync_validator_called - times_sync_validator_called += 1 + times_async_validator_called = 0 - times_async_validator_called = 0 + async def async_validator(peer_id, msg): + nonlocal times_async_validator_called + times_async_validator_called += 1 - async def async_validator(peer_id, msg): - nonlocal times_async_validator_called - times_async_validator_called += 1 + topic_1 = "TEST_VALIDATOR_1" + topic_2 = "TEST_VALIDATOR_2" + topic_3 = "TEST_VALIDATOR_3" - topic_1 = "TEST_VALIDATOR_1" - topic_2 = "TEST_VALIDATOR_2" - topic_3 = "TEST_VALIDATOR_3" + # Register sync validator for topic 1 and 2 + pubsubs_fsub[0].set_topic_validator(topic_1, sync_validator, False) + pubsubs_fsub[0].set_topic_validator(topic_2, sync_validator, False) - # Register sync validator for topic 1 and 2 - pubsubs_fsub[0].set_topic_validator(topic_1, sync_validator, False) - pubsubs_fsub[0].set_topic_validator(topic_2, sync_validator, False) + # Register async validator for topic 3 + pubsubs_fsub[0].set_topic_validator(topic_3, async_validator, True) - # Register async validator for topic 3 - pubsubs_fsub[0].set_topic_validator(topic_3, async_validator, True) + msg = make_pubsub_msg( + origin_id=pubsubs_fsub[0].my_id, + topic_ids=[topic_1, topic_2, topic_3], + data=b"1234", + seqno=b"\x00" * 8, + ) - msg = make_pubsub_msg( - origin_id=pubsubs_fsub[0].my_id, - topic_ids=[topic_1, topic_2, topic_3], - data=b"1234", - seqno=b"\x00" * 8, - ) + topic_validators = pubsubs_fsub[0].get_msg_validators(msg) + for topic_validator in topic_validators: + if topic_validator.is_async: + await topic_validator.validator(peer_id=IDFactory(), msg="msg") + else: + topic_validator.validator(peer_id=IDFactory(), msg="msg") - topic_validators = pubsubs_fsub[0].get_msg_validators(msg) - for topic_validator in topic_validators: - if topic_validator.is_async: - await topic_validator.validator(peer_id=ID(b"peer"), msg="msg") - else: - topic_validator.validator(peer_id=ID(b"peer"), msg="msg") - - assert times_sync_validator_called == 2 - assert times_async_validator_called == 1 + assert times_sync_validator_called == 2 + assert times_async_validator_called == 1 -@pytest.mark.parametrize("num_hosts", (1,)) @pytest.mark.parametrize( "is_topic_1_val_passed, is_topic_2_val_passed", ((False, True), (True, False), (True, True)), ) -@pytest.mark.asyncio -async def test_validate_msg(pubsubs_fsub, is_topic_1_val_passed, is_topic_2_val_passed): - def passed_sync_validator(peer_id, msg): - return True +@pytest.mark.trio +async def test_validate_msg(is_topic_1_val_passed, is_topic_2_val_passed): + async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub: - def failed_sync_validator(peer_id, msg): - return False + def passed_sync_validator(peer_id, msg): + return True - async def passed_async_validator(peer_id, msg): - return True + def failed_sync_validator(peer_id, msg): + return False - async def failed_async_validator(peer_id, msg): - return False + async def passed_async_validator(peer_id, msg): + return True - topic_1 = "TEST_SYNC_VALIDATOR" - topic_2 = "TEST_ASYNC_VALIDATOR" + async def failed_async_validator(peer_id, msg): + return False - if is_topic_1_val_passed: - pubsubs_fsub[0].set_topic_validator(topic_1, passed_sync_validator, False) - else: - pubsubs_fsub[0].set_topic_validator(topic_1, failed_sync_validator, False) + topic_1 = "TEST_SYNC_VALIDATOR" + topic_2 = "TEST_ASYNC_VALIDATOR" - if is_topic_2_val_passed: - pubsubs_fsub[0].set_topic_validator(topic_2, passed_async_validator, True) - else: - pubsubs_fsub[0].set_topic_validator(topic_2, failed_async_validator, True) - - msg = make_pubsub_msg( - origin_id=pubsubs_fsub[0].my_id, - topic_ids=[topic_1, topic_2], - data=b"1234", - seqno=b"\x00" * 8, - ) - - if is_topic_1_val_passed and is_topic_2_val_passed: - await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg) - else: - with pytest.raises(ValidationError): - await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg) - - -class FakeNetStream: - _queue: asyncio.Queue - - class FakeMplexConn(NamedTuple): - peer_id: ID = ID(b"\x12\x20" + b"\x00" * 32) - - muxed_conn = FakeMplexConn() - - def __init__(self) -> None: - self._queue = asyncio.Queue() - - async def read(self, n: int = -1) -> bytes: - buf = bytearray() - # Force to blocking wait if no data available now. - if self._queue.empty(): - first_byte = await self._queue.get() - buf.extend(first_byte) - # If `n == -1`, read until no data is in the buffer(_queue). - # Else, read until no data is in the buffer(_queue) or we have read `n` bytes. - while (n == -1) or (len(buf) < n): - if self._queue.empty(): - break - buf.extend(await self._queue.get()) - return bytes(buf) - - async def write(self, data: bytes) -> int: - for i in data: - await self._queue.put(i.to_bytes(1, "big")) - return len(data) - - -@pytest.mark.parametrize("num_hosts", (1,)) -@pytest.mark.asyncio -async def test_continuously_read_stream(pubsubs_fsub, monkeypatch): - stream = FakeNetStream() - - await pubsubs_fsub[0].subscribe(TESTING_TOPIC) - - event_push_msg = asyncio.Event() - event_handle_subscription = asyncio.Event() - event_handle_rpc = asyncio.Event() - - async def mock_push_msg(msg_forwarder, msg): - event_push_msg.set() - - def mock_handle_subscription(origin_id, sub_message): - event_handle_subscription.set() - - async def mock_handle_rpc(rpc, sender_peer_id): - event_handle_rpc.set() - - monkeypatch.setattr(pubsubs_fsub[0], "push_msg", mock_push_msg) - monkeypatch.setattr( - pubsubs_fsub[0], "handle_subscription", mock_handle_subscription - ) - monkeypatch.setattr(pubsubs_fsub[0].router, "handle_rpc", mock_handle_rpc) - - async def wait_for_event_occurring(event): - try: - await asyncio.wait_for(event.wait(), timeout=1) - except asyncio.TimeoutError as error: - event.clear() - raise asyncio.TimeoutError( - f"Event {event} is not set before the timeout. " - "This indicates the mocked functions are not called properly." - ) from error + if is_topic_1_val_passed: + pubsubs_fsub[0].set_topic_validator(topic_1, passed_sync_validator, False) else: - event.clear() + pubsubs_fsub[0].set_topic_validator(topic_1, failed_sync_validator, False) - # Kick off the task `continuously_read_stream` - task = asyncio.ensure_future(pubsubs_fsub[0].continuously_read_stream(stream)) + if is_topic_2_val_passed: + pubsubs_fsub[0].set_topic_validator(topic_2, passed_async_validator, True) + else: + pubsubs_fsub[0].set_topic_validator(topic_2, failed_async_validator, True) - # Test: `push_msg` is called when publishing to a subscribed topic. - publish_subscribed_topic = rpc_pb2.RPC( - publish=[rpc_pb2.Message(topicIDs=[TESTING_TOPIC])] - ) - await stream.write( - encode_varint_prefixed(publish_subscribed_topic.SerializeToString()) - ) - await wait_for_event_occurring(event_push_msg) - # Make sure the other events are not emitted. - with pytest.raises(asyncio.TimeoutError): - await wait_for_event_occurring(event_handle_subscription) - with pytest.raises(asyncio.TimeoutError): - await wait_for_event_occurring(event_handle_rpc) + msg = make_pubsub_msg( + origin_id=pubsubs_fsub[0].my_id, + topic_ids=[topic_1, topic_2], + data=b"1234", + seqno=b"\x00" * 8, + ) - # Test: `push_msg` is not called when publishing to a topic-not-subscribed. - publish_not_subscribed_topic = rpc_pb2.RPC( - publish=[rpc_pb2.Message(topicIDs=["NOT_SUBSCRIBED"])] - ) - await stream.write( - encode_varint_prefixed(publish_not_subscribed_topic.SerializeToString()) - ) - with pytest.raises(asyncio.TimeoutError): - await wait_for_event_occurring(event_push_msg) + if is_topic_1_val_passed and is_topic_2_val_passed: + await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg) + else: + with pytest.raises(ValidationError): + await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg) - # Test: `handle_subscription` is called when a subscription message is received. - subscription_msg = rpc_pb2.RPC(subscriptions=[rpc_pb2.RPC.SubOpts()]) - await stream.write(encode_varint_prefixed(subscription_msg.SerializeToString())) - await wait_for_event_occurring(event_handle_subscription) - # Make sure the other events are not emitted. - with pytest.raises(asyncio.TimeoutError): - await wait_for_event_occurring(event_push_msg) - with pytest.raises(asyncio.TimeoutError): - await wait_for_event_occurring(event_handle_rpc) - # Test: `handle_rpc` is called when a control message is received. - control_msg = rpc_pb2.RPC(control=rpc_pb2.ControlMessage()) - await stream.write(encode_varint_prefixed(control_msg.SerializeToString())) - await wait_for_event_occurring(event_handle_rpc) - # Make sure the other events are not emitted. - with pytest.raises(asyncio.TimeoutError): - await wait_for_event_occurring(event_push_msg) - with pytest.raises(asyncio.TimeoutError): - await wait_for_event_occurring(event_handle_subscription) +@pytest.mark.trio +async def test_continuously_read_stream(monkeypatch, nursery, is_host_secure): + async def wait_for_event_occurring(event): + with trio.fail_after(0.1): + await event.wait() - task.cancel() + class Events(NamedTuple): + push_msg: trio.Event + handle_subscription: trio.Event + handle_rpc: trio.Event + + @contextmanager + def mock_methods(): + event_push_msg = trio.Event() + event_handle_subscription = trio.Event() + event_handle_rpc = trio.Event() + + async def mock_push_msg(msg_forwarder, msg): + event_push_msg.set() + await trio.sleep(0) + + def mock_handle_subscription(origin_id, sub_message): + event_handle_subscription.set() + + async def mock_handle_rpc(rpc, sender_peer_id): + event_handle_rpc.set() + await trio.sleep(0) + + with monkeypatch.context() as m: + m.setattr(pubsubs_fsub[0], "push_msg", mock_push_msg) + m.setattr(pubsubs_fsub[0], "handle_subscription", mock_handle_subscription) + m.setattr(pubsubs_fsub[0].router, "handle_rpc", mock_handle_rpc) + yield Events(event_push_msg, event_handle_subscription, event_handle_rpc) + + async with PubsubFactory.create_batch_with_floodsub( + 1, is_secure=is_host_secure + ) as pubsubs_fsub, net_stream_pair_factory(is_secure=is_host_secure) as stream_pair: + await pubsubs_fsub[0].subscribe(TESTING_TOPIC) + # Kick off the task `continuously_read_stream` + nursery.start_soon(pubsubs_fsub[0].continuously_read_stream, stream_pair[0]) + + # Test: `push_msg` is called when publishing to a subscribed topic. + publish_subscribed_topic = rpc_pb2.RPC( + publish=[rpc_pb2.Message(topicIDs=[TESTING_TOPIC])] + ) + with mock_methods() as events: + await stream_pair[1].write( + encode_varint_prefixed(publish_subscribed_topic.SerializeToString()) + ) + await wait_for_event_occurring(events.push_msg) + # Make sure the other events are not emitted. + with pytest.raises(trio.TooSlowError): + await wait_for_event_occurring(events.handle_subscription) + with pytest.raises(trio.TooSlowError): + await wait_for_event_occurring(events.handle_rpc) + + # Test: `push_msg` is not called when publishing to a topic-not-subscribed. + publish_not_subscribed_topic = rpc_pb2.RPC( + publish=[rpc_pb2.Message(topicIDs=["NOT_SUBSCRIBED"])] + ) + with mock_methods() as events: + await stream_pair[1].write( + encode_varint_prefixed(publish_not_subscribed_topic.SerializeToString()) + ) + with pytest.raises(trio.TooSlowError): + await wait_for_event_occurring(events.push_msg) + + # Test: `handle_subscription` is called when a subscription message is received. + subscription_msg = rpc_pb2.RPC(subscriptions=[rpc_pb2.RPC.SubOpts()]) + with mock_methods() as events: + await stream_pair[1].write( + encode_varint_prefixed(subscription_msg.SerializeToString()) + ) + await wait_for_event_occurring(events.handle_subscription) + # Make sure the other events are not emitted. + with pytest.raises(trio.TooSlowError): + await wait_for_event_occurring(events.push_msg) + with pytest.raises(trio.TooSlowError): + await wait_for_event_occurring(events.handle_rpc) + + # Test: `handle_rpc` is called when a control message is received. + control_msg = rpc_pb2.RPC(control=rpc_pb2.ControlMessage()) + with mock_methods() as events: + await stream_pair[1].write( + encode_varint_prefixed(control_msg.SerializeToString()) + ) + await wait_for_event_occurring(events.handle_rpc) + # Make sure the other events are not emitted. + with pytest.raises(trio.TooSlowError): + await wait_for_event_occurring(events.push_msg) + with pytest.raises(trio.TooSlowError): + await wait_for_event_occurring(events.handle_subscription) # TODO: Add the following tests after they are aligned with Go. @@ -351,81 +331,84 @@ async def test_continuously_read_stream(pubsubs_fsub, monkeypatch): # - `test_handle_peer_queue` -@pytest.mark.parametrize("num_hosts", (1,)) -def test_handle_subscription(pubsubs_fsub): - assert len(pubsubs_fsub[0].peer_topics) == 0 - sub_msg_0 = rpc_pb2.RPC.SubOpts(subscribe=True, topicid=TESTING_TOPIC) - peer_ids = [ID(b"\x12\x20" + i.to_bytes(32, "big")) for i in range(2)] - # Test: One peer is subscribed - pubsubs_fsub[0].handle_subscription(peer_ids[0], sub_msg_0) - assert ( - len(pubsubs_fsub[0].peer_topics) == 1 - and TESTING_TOPIC in pubsubs_fsub[0].peer_topics - ) - assert len(pubsubs_fsub[0].peer_topics[TESTING_TOPIC]) == 1 - assert peer_ids[0] in pubsubs_fsub[0].peer_topics[TESTING_TOPIC] - # Test: Another peer is subscribed - pubsubs_fsub[0].handle_subscription(peer_ids[1], sub_msg_0) - assert len(pubsubs_fsub[0].peer_topics) == 1 - assert len(pubsubs_fsub[0].peer_topics[TESTING_TOPIC]) == 2 - assert peer_ids[1] in pubsubs_fsub[0].peer_topics[TESTING_TOPIC] - # Test: Subscribe to another topic - another_topic = "ANOTHER_TOPIC" - sub_msg_1 = rpc_pb2.RPC.SubOpts(subscribe=True, topicid=another_topic) - pubsubs_fsub[0].handle_subscription(peer_ids[0], sub_msg_1) - assert len(pubsubs_fsub[0].peer_topics) == 2 - assert another_topic in pubsubs_fsub[0].peer_topics - assert peer_ids[0] in pubsubs_fsub[0].peer_topics[another_topic] - # Test: unsubscribe - unsub_msg = rpc_pb2.RPC.SubOpts(subscribe=False, topicid=TESTING_TOPIC) - pubsubs_fsub[0].handle_subscription(peer_ids[0], unsub_msg) - assert peer_ids[0] not in pubsubs_fsub[0].peer_topics[TESTING_TOPIC] +@pytest.mark.trio +async def test_handle_subscription(): + async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub: + assert len(pubsubs_fsub[0].peer_topics) == 0 + sub_msg_0 = rpc_pb2.RPC.SubOpts(subscribe=True, topicid=TESTING_TOPIC) + peer_ids = [IDFactory() for _ in range(2)] + # Test: One peer is subscribed + pubsubs_fsub[0].handle_subscription(peer_ids[0], sub_msg_0) + assert ( + len(pubsubs_fsub[0].peer_topics) == 1 + and TESTING_TOPIC in pubsubs_fsub[0].peer_topics + ) + assert len(pubsubs_fsub[0].peer_topics[TESTING_TOPIC]) == 1 + assert peer_ids[0] in pubsubs_fsub[0].peer_topics[TESTING_TOPIC] + # Test: Another peer is subscribed + pubsubs_fsub[0].handle_subscription(peer_ids[1], sub_msg_0) + assert len(pubsubs_fsub[0].peer_topics) == 1 + assert len(pubsubs_fsub[0].peer_topics[TESTING_TOPIC]) == 2 + assert peer_ids[1] in pubsubs_fsub[0].peer_topics[TESTING_TOPIC] + # Test: Subscribe to another topic + another_topic = "ANOTHER_TOPIC" + sub_msg_1 = rpc_pb2.RPC.SubOpts(subscribe=True, topicid=another_topic) + pubsubs_fsub[0].handle_subscription(peer_ids[0], sub_msg_1) + assert len(pubsubs_fsub[0].peer_topics) == 2 + assert another_topic in pubsubs_fsub[0].peer_topics + assert peer_ids[0] in pubsubs_fsub[0].peer_topics[another_topic] + # Test: unsubscribe + unsub_msg = rpc_pb2.RPC.SubOpts(subscribe=False, topicid=TESTING_TOPIC) + pubsubs_fsub[0].handle_subscription(peer_ids[0], unsub_msg) + assert peer_ids[0] not in pubsubs_fsub[0].peer_topics[TESTING_TOPIC] -@pytest.mark.parametrize("num_hosts", (1,)) -@pytest.mark.asyncio -async def test_handle_talk(pubsubs_fsub): - sub = await pubsubs_fsub[0].subscribe(TESTING_TOPIC) - msg_0 = make_pubsub_msg( - origin_id=pubsubs_fsub[0].my_id, - topic_ids=[TESTING_TOPIC], - data=b"1234", - seqno=b"\x00" * 8, - ) - await pubsubs_fsub[0].handle_talk(msg_0) - msg_1 = make_pubsub_msg( - origin_id=pubsubs_fsub[0].my_id, - topic_ids=["NOT_SUBSCRIBED"], - data=b"1234", - seqno=b"\x11" * 8, - ) - await pubsubs_fsub[0].handle_talk(msg_1) - assert ( - len(pubsubs_fsub[0].my_topics) == 1 - and sub == pubsubs_fsub[0].my_topics[TESTING_TOPIC] - ) - assert sub.qsize() == 1 - assert (await sub.get()) == msg_0 +@pytest.mark.trio +async def test_handle_talk(): + async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub: + sub = await pubsubs_fsub[0].subscribe(TESTING_TOPIC) + msg_0 = make_pubsub_msg( + origin_id=pubsubs_fsub[0].my_id, + topic_ids=[TESTING_TOPIC], + data=b"1234", + seqno=b"\x00" * 8, + ) + await pubsubs_fsub[0].handle_talk(msg_0) + msg_1 = make_pubsub_msg( + origin_id=pubsubs_fsub[0].my_id, + topic_ids=["NOT_SUBSCRIBED"], + data=b"1234", + seqno=b"\x11" * 8, + ) + await pubsubs_fsub[0].handle_talk(msg_1) + assert ( + len(pubsubs_fsub[0].topic_ids) == 1 + and sub == pubsubs_fsub[0].subscribed_topics_receive[TESTING_TOPIC] + ) + assert (await sub.receive()) == msg_0 -@pytest.mark.parametrize("num_hosts", (1,)) -@pytest.mark.asyncio -async def test_message_all_peers(pubsubs_fsub, monkeypatch): - peer_ids = [ID(b"\x12\x20" + i.to_bytes(32, "big")) for i in range(10)] - mock_peers = {peer_id: FakeNetStream() for peer_id in peer_ids} - monkeypatch.setattr(pubsubs_fsub[0], "peers", mock_peers) +@pytest.mark.trio +async def test_message_all_peers(monkeypatch, is_host_secure): + async with PubsubFactory.create_batch_with_floodsub( + 1, is_secure=is_host_secure + ) as pubsubs_fsub, net_stream_pair_factory(is_secure=is_host_secure) as stream_pair: + peer_id = IDFactory() + mock_peers = {peer_id: stream_pair[0]} + with monkeypatch.context() as m: + m.setattr(pubsubs_fsub[0], "peers", mock_peers) - empty_rpc = rpc_pb2.RPC() - empty_rpc_bytes = empty_rpc.SerializeToString() - empty_rpc_bytes_len_prefixed = encode_varint_prefixed(empty_rpc_bytes) - await pubsubs_fsub[0].message_all_peers(empty_rpc_bytes) - for stream in mock_peers.values(): - assert (await stream.read()) == empty_rpc_bytes_len_prefixed + empty_rpc = rpc_pb2.RPC() + empty_rpc_bytes = empty_rpc.SerializeToString() + empty_rpc_bytes_len_prefixed = encode_varint_prefixed(empty_rpc_bytes) + await pubsubs_fsub[0].message_all_peers(empty_rpc_bytes) + assert ( + await stream_pair[1].read(MAX_READ_LEN) + ) == empty_rpc_bytes_len_prefixed -@pytest.mark.parametrize("num_hosts", (1,)) -@pytest.mark.asyncio -async def test_publish(pubsubs_fsub, monkeypatch): +@pytest.mark.trio +async def test_publish(monkeypatch): msg_forwarders = [] msgs = [] @@ -433,80 +416,97 @@ async def test_publish(pubsubs_fsub, monkeypatch): msg_forwarders.append(msg_forwarder) msgs.append(msg) - monkeypatch.setattr(pubsubs_fsub[0], "push_msg", push_msg) + async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub: + with monkeypatch.context() as m: + m.setattr(pubsubs_fsub[0], "push_msg", push_msg) - await pubsubs_fsub[0].publish(TESTING_TOPIC, TESTING_DATA) - await pubsubs_fsub[0].publish(TESTING_TOPIC, TESTING_DATA) + await pubsubs_fsub[0].publish(TESTING_TOPIC, TESTING_DATA) + await pubsubs_fsub[0].publish(TESTING_TOPIC, TESTING_DATA) - assert len(msgs) == 2, "`push_msg` should be called every time `publish` is called" - assert (msg_forwarders[0] == msg_forwarders[1]) and ( - msg_forwarders[1] == pubsubs_fsub[0].my_id - ) - assert msgs[0].seqno != msgs[1].seqno, "`seqno` should be different every time" + assert ( + len(msgs) == 2 + ), "`push_msg` should be called every time `publish` is called" + assert (msg_forwarders[0] == msg_forwarders[1]) and ( + msg_forwarders[1] == pubsubs_fsub[0].my_id + ) + assert ( + msgs[0].seqno != msgs[1].seqno + ), "`seqno` should be different every time" -@pytest.mark.parametrize("num_hosts", (1,)) -@pytest.mark.asyncio -async def test_push_msg(pubsubs_fsub, monkeypatch): - msg_0 = make_pubsub_msg( - origin_id=pubsubs_fsub[0].my_id, - topic_ids=[TESTING_TOPIC], - data=TESTING_DATA, - seqno=b"\x00" * 8, - ) +@pytest.mark.trio +async def test_push_msg(monkeypatch): + async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub: + msg_0 = make_pubsub_msg( + origin_id=pubsubs_fsub[0].my_id, + topic_ids=[TESTING_TOPIC], + data=TESTING_DATA, + seqno=b"\x00" * 8, + ) - event = asyncio.Event() + @contextmanager + def mock_router_publish(): - async def router_publish(*args, **kwargs): - event.set() + event = trio.Event() - monkeypatch.setattr(pubsubs_fsub[0].router, "publish", router_publish) + async def router_publish(*args, **kwargs): + event.set() + await trio.sleep(0) - # Test: `msg` is not seen before `push_msg`, and is seen after `push_msg`. - assert not pubsubs_fsub[0]._is_msg_seen(msg_0) - await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_0) - assert pubsubs_fsub[0]._is_msg_seen(msg_0) - # Test: Ensure `router.publish` is called in `push_msg` - await asyncio.wait_for(event.wait(), timeout=0.1) + with monkeypatch.context() as m: + m.setattr(pubsubs_fsub[0].router, "publish", router_publish) + yield event - # Test: `push_msg` the message again and it will be reject. - # `router_publish` is not called then. - event.clear() - await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_0) - await asyncio.sleep(0.01) - assert not event.is_set() + with mock_router_publish() as event: + # Test: `msg` is not seen before `push_msg`, and is seen after `push_msg`. + assert not pubsubs_fsub[0]._is_msg_seen(msg_0) + await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_0) + assert pubsubs_fsub[0]._is_msg_seen(msg_0) + # Test: Ensure `router.publish` is called in `push_msg` + with trio.fail_after(0.1): + await event.wait() - sub = await pubsubs_fsub[0].subscribe(TESTING_TOPIC) - # Test: `push_msg` succeeds with another unseen msg. - msg_1 = make_pubsub_msg( - origin_id=pubsubs_fsub[0].my_id, - topic_ids=[TESTING_TOPIC], - data=TESTING_DATA, - seqno=b"\x11" * 8, - ) - assert not pubsubs_fsub[0]._is_msg_seen(msg_1) - await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_1) - assert pubsubs_fsub[0]._is_msg_seen(msg_1) - await asyncio.wait_for(event.wait(), timeout=0.1) - # Test: Subscribers are notified when `push_msg` new messages. - assert (await sub.get()) == msg_1 + with mock_router_publish() as event: + # Test: `push_msg` the message again and it will be reject. + # `router_publish` is not called then. + await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_0) + await trio.sleep(0.01) + assert not event.is_set() - # Test: add a topic validator and `push_msg` the message that - # does not pass the validation. - # `router_publish` is not called then. - def failed_sync_validator(peer_id, msg): - return False + sub = await pubsubs_fsub[0].subscribe(TESTING_TOPIC) + # Test: `push_msg` succeeds with another unseen msg. + msg_1 = make_pubsub_msg( + origin_id=pubsubs_fsub[0].my_id, + topic_ids=[TESTING_TOPIC], + data=TESTING_DATA, + seqno=b"\x11" * 8, + ) + assert not pubsubs_fsub[0]._is_msg_seen(msg_1) + await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_1) + assert pubsubs_fsub[0]._is_msg_seen(msg_1) + with trio.fail_after(0.1): + await event.wait() + # Test: Subscribers are notified when `push_msg` new messages. + assert (await sub.receive()) == msg_1 - pubsubs_fsub[0].set_topic_validator(TESTING_TOPIC, failed_sync_validator, False) + with mock_router_publish() as event: + # Test: add a topic validator and `push_msg` the message that + # does not pass the validation. + # `router_publish` is not called then. + def failed_sync_validator(peer_id, msg): + return False - msg_2 = make_pubsub_msg( - origin_id=pubsubs_fsub[0].my_id, - topic_ids=[TESTING_TOPIC], - data=TESTING_DATA, - seqno=b"\x22" * 8, - ) + pubsubs_fsub[0].set_topic_validator( + TESTING_TOPIC, failed_sync_validator, False + ) - event.clear() - await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_2) - await asyncio.sleep(0.01) - assert not event.is_set() + msg_2 = make_pubsub_msg( + origin_id=pubsubs_fsub[0].my_id, + topic_ids=[TESTING_TOPIC], + data=TESTING_DATA, + seqno=b"\x22" * 8, + ) + + await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_2) + await trio.sleep(0.01) + assert not event.is_set()