Fix Pubsub

mhchia 2019-12-03 17:27:49 +08:00
parent bdbb7b2394
commit e9ab0646e3
No known key found for this signature in database
GPG Key ID: 389EFBEA1362589A
7 changed files with 568 additions and 523 deletions

View File

@@ -1,6 +1,8 @@
import logging
from typing import Iterable, List, Sequence

+import trio
+
from libp2p.network.stream.exceptions import StreamClosed
from libp2p.peer.id import ID
from libp2p.typing import TProtocol
@@ -61,6 +63,8 @@ class FloodSub(IPubsubRouter):
        :param rpc: rpc message
        """
+        # Checkpoint
+        await trio.sleep(0)

    async def publish(self, msg_forwarder: ID, pubsub_msg: rpc_pb2.Message) -> None:
        """
@@ -102,6 +106,8 @@ class FloodSub(IPubsubRouter):
        :param topic: topic to join
        """
+        # Checkpoint
+        await trio.sleep(0)

    async def leave(self, topic: str) -> None:
        """
@@ -110,6 +116,8 @@ class FloodSub(IPubsubRouter):
        :param topic: topic to leave
        """
+        # Checkpoint
+        await trio.sleep(0)

    def _get_peers_to_send(
        self, topic_ids: Iterable[str], msg_forwarder: ID, origin: ID
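The `# Checkpoint` additions above are trio's explicit scheduling points. A minimal standalone sketch of the idea, using an illustrative handler name that is not from the diff:

import trio

async def join(topic: str) -> None:
    # `await trio.sleep(0)` is the idiomatic trio checkpoint: it yields to the
    # scheduler and lets cancellation be delivered, even though this handler
    # has no real I/O of its own yet.
    await trio.sleep(0)

async def main() -> None:
    with trio.move_on_after(1):
        await join("example-topic")

trio.run(main)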

View File

@@ -1,11 +1,13 @@
-import asyncio
+from abc import ABC, abstractmethod
import logging
+import math
import time
from typing import (
    TYPE_CHECKING,
    Awaitable,
    Callable,
    Dict,
+    KeysView,
    List,
    NamedTuple,
    Tuple,
@@ -13,8 +15,10 @@ from typing import (
    cast,
)

+from async_service import Service
import base58
from lru import LRU
+import trio

from libp2p.exceptions import ParseError, ValidationError
from libp2p.host.host_interface import IHost
@@ -53,24 +57,24 @@ class TopicValidator(NamedTuple):
    is_async: bool


-class Pubsub:
+class BasePubsub(ABC):
+    pass
+
+
+class Pubsub(BasePubsub, Service):
    host: IHost
-    my_id: ID
    router: "IPubsubRouter"

-    peer_queue: "asyncio.Queue[ID]"
-    dead_peer_queue: "asyncio.Queue[ID]"
+    peer_receive_channel: "trio.MemoryReceiveChannel[ID]"
+    dead_peer_receive_channel: "trio.MemoryReceiveChannel[ID]"

-    protocols: List[TProtocol]
-
-    incoming_msgs_from_peers: "asyncio.Queue[rpc_pb2.Message]"
-    outgoing_messages: "asyncio.Queue[rpc_pb2.Message]"
    seen_messages: LRU

-    my_topics: Dict[str, "asyncio.Queue[rpc_pb2.Message]"]
+    # TODO: Implement `trio.abc.Channel`?
+    subscribed_topics_send: Dict[str, "trio.MemorySendChannel[rpc_pb2.Message]"]
+    subscribed_topics_receive: Dict[str, "trio.MemoryReceiveChannel[rpc_pb2.Message]"]

    peer_topics: Dict[str, List[ID]]
    peers: Dict[ID, INetStream]
@@ -80,10 +84,8 @@ class Pubsub:
    # TODO: Be sure it is increased atomically everytime.
    counter: int  # uint64

-    _tasks: List["asyncio.Future[Any]"]
-
    def __init__(
-        self, host: IHost, router: "IPubsubRouter", my_id: ID, cache_size: int = None
+        self, host: IHost, router: "IPubsubRouter", cache_size: int = None
    ) -> None:
        """
        Construct a new Pubsub object, which is responsible for handling all
@@ -97,28 +99,26 @@ class Pubsub:
        """
        self.host = host
        self.router = router
-        self.my_id = my_id

        # Attach this new Pubsub object to the router
        self.router.attach(self)

+        peer_send_channel, peer_receive_channel = trio.open_memory_channel(0)
+        dead_peer_send_channel, dead_peer_receive_channel = trio.open_memory_channel(0)
+        # Only keep the receive channels in `Pubsub`.
+        # Therefore, we can only close from the receive side.
+        self.peer_receive_channel = peer_receive_channel
+        self.dead_peer_receive_channel = dead_peer_receive_channel
        # Register a notifee
-        self.peer_queue = asyncio.Queue()
-        self.dead_peer_queue = asyncio.Queue()
        self.host.get_network().register_notifee(
-            PubsubNotifee(self.peer_queue, self.dead_peer_queue)
+            PubsubNotifee(peer_send_channel, dead_peer_send_channel)
        )

        # Register stream handlers for each pubsub router protocol to handle
        # the pubsub streams opened on those protocols
-        self.protocols = self.router.get_protocols()
-        for protocol in self.protocols:
+        for protocol in router.protocols:
            self.host.set_stream_handler(protocol, self.stream_handler)

-        # Use asyncio queues for proper context switching
-        self.incoming_msgs_from_peers = asyncio.Queue()
-        self.outgoing_messages = asyncio.Queue()
-
        # keeps track of seen messages as LRU cache
        if cache_size is None:
            self.cache_size = 128
@@ -129,7 +129,8 @@ class Pubsub:

        # Map of topics we are subscribed to blocking queues
        # for when the given topic receives a message
-        self.my_topics = {}
+        self.subscribed_topics_send = {}
+        self.subscribed_topics_receive = {}

        # Map of topic to peers to keep track of what peers are subscribed to
        self.peer_topics = {}
@@ -142,16 +143,28 @@ class Pubsub:
        self.counter = time.time_ns()

-        self._tasks = []
-        # Call handle peer to keep waiting for updates to peer queue
-        self._tasks.append(asyncio.ensure_future(self.handle_peer_queue()))
-        self._tasks.append(asyncio.ensure_future(self.handle_dead_peer_queue()))
+    async def run(self) -> None:
+        self.manager.run_daemon_task(self.handle_peer_queue)
+        self.manager.run_daemon_task(self.handle_dead_peer_queue)
+        await self.manager.wait_finished()
+
+    @property
+    def my_id(self) -> ID:
+        return self.host.get_id()
+
+    @property
+    def protocols(self) -> Tuple[TProtocol, ...]:
+        return tuple(self.router.get_protocols())
+
+    @property
+    def topic_ids(self) -> KeysView[str]:
+        return self.subscribed_topics_receive.keys()

    def get_hello_packet(self) -> rpc_pb2.RPC:
        """Generate subscription message with all topics we are subscribed to
        only send hello packet if we have subscribed topics."""
        packet = rpc_pb2.RPC()
-        for topic_id in self.my_topics:
+        for topic_id in self.topic_ids:
            packet.subscriptions.extend(
                [rpc_pb2.RPC.SubOpts(subscribe=True, topicid=topic_id)]
            )
@@ -166,7 +179,7 @@ class Pubsub:
        """
        peer_id = stream.muxed_conn.peer_id

-        while True:
+        while self.manager.is_running:
            incoming: bytes = await read_varint_prefixed_bytes(stream)
            rpc_incoming: rpc_pb2.RPC = rpc_pb2.RPC()
            rpc_incoming.ParseFromString(incoming)
@@ -178,11 +191,7 @@ class Pubsub:
                    logger.debug(
                        "received `publish` message %s from peer %s", msg, peer_id
                    )
-                    self._tasks.append(
-                        asyncio.ensure_future(
-                            self.push_msg(msg_forwarder=peer_id, msg=msg)
-                        )
-                    )
+                    self.manager.run_task(self.push_msg, peer_id, msg)

            if rpc_incoming.subscriptions:
                # deal with RPC.subscriptions
@@ -210,9 +219,6 @@ class Pubsub:
                )
                await self.router.handle_rpc(rpc_incoming, peer_id)

-            # Force context switch
-            await asyncio.sleep(0)
-
    def set_topic_validator(
        self, topic: str, validator: ValidatorFn, is_async_validator: bool
    ) -> None:
@@ -285,7 +291,6 @@ class Pubsub:
            logger.debug("Fail to add new peer %s: stream closed", peer_id)
            del self.peers[peer_id]
            return
-        # TODO: Check EOF of this stream.
        # TODO: Check if the peer in black list.
        try:
            self.router.add_peer(peer_id, stream.get_protocol())
@@ -311,23 +316,25 @@ class Pubsub:
    async def handle_peer_queue(self) -> None:
        """
-        Continuously read from peer queue and each time a new peer is found,
+        Continuously read from peer channel and each time a new peer is found,
        open a stream to the peer using a supported pubsub protocol

        TODO: Handle failure for when the peer does not support any of the
        pubsub protocols we support
        """
-        while True:
-            peer_id: ID = await self.peer_queue.get()
-            # Add Peer
-            self._tasks.append(asyncio.ensure_future(self._handle_new_peer(peer_id)))
+        async with self.peer_receive_channel:
+            while self.manager.is_running:
+                peer_id: ID = await self.peer_receive_channel.receive()
+                # Add Peer
+                self.manager.run_task(self._handle_new_peer, peer_id)

    async def handle_dead_peer_queue(self) -> None:
-        """Continuously read from dead peer queue and close the stream between
+        """Continuously read from dead peer channel and close the stream between
        that peer and remove peer info from pubsub and pubsub router."""
-        while True:
-            peer_id: ID = await self.dead_peer_queue.get()
-            # Remove Peer
-            self._handle_dead_peer(peer_id)
+        async with self.dead_peer_receive_channel:
+            while self.manager.is_running:
+                peer_id: ID = await self.dead_peer_receive_channel.receive()
+                # Remove Peer
+                self._handle_dead_peer(peer_id)

    def handle_subscription(
        self, origin_id: ID, sub_message: rpc_pb2.RPC.SubOpts
@@ -361,13 +368,16 @@ class Pubsub:
        # Check if this message has any topics that we are subscribed to
        for topic in publish_message.topicIDs:
-            if topic in self.my_topics:
+            if topic in self.topic_ids:
                # we are subscribed to a topic this message was sent for,
                # so add message to the subscription output queue
                # for each topic
-                await self.my_topics[topic].put(publish_message)
+                await self.subscribed_topics_send[topic].send(publish_message)

-    async def subscribe(self, topic_id: str) -> "asyncio.Queue[rpc_pb2.Message]":
+    # TODO: Change to return an `AsyncIterable` to be I/O-agnostic?
+    async def subscribe(
+        self, topic_id: str
+    ) -> "trio.MemoryReceiveChannel[rpc_pb2.Message]":
        """
        Subscribe ourself to a topic.
@@ -377,11 +387,13 @@ class Pubsub:
        logger.debug("subscribing to topic %s", topic_id)

        # Already subscribed
-        if topic_id in self.my_topics:
-            return self.my_topics[topic_id]
+        if topic_id in self.topic_ids:
+            return self.subscribed_topics_receive[topic_id]

-        # Map topic_id to blocking queue
-        self.my_topics[topic_id] = asyncio.Queue()
+        # Map topic_id to a blocking channel
+        send_channel, receive_channel = trio.open_memory_channel(math.inf)
+        self.subscribed_topics_send[topic_id] = send_channel
+        self.subscribed_topics_receive[topic_id] = receive_channel

        # Create subscribe message
        packet: rpc_pb2.RPC = rpc_pb2.RPC()
@@ -395,8 +407,8 @@ class Pubsub:
        # Tell router we are joining this topic
        await self.router.join(topic_id)

-        # Return the asyncio queue for messages on this topic
-        return self.my_topics[topic_id]
+        # Return the trio channel for messages on this topic
+        return receive_channel

    async def unsubscribe(self, topic_id: str) -> None:
        """
@@ -408,10 +420,14 @@ class Pubsub:
        logger.debug("unsubscribing from topic %s", topic_id)

        # Return if we already unsubscribed from the topic
-        if topic_id not in self.my_topics:
+        if topic_id not in self.topic_ids:
            return

-        # Remove topic_id from map if present
-        del self.my_topics[topic_id]
+        # Remove topic_id from the maps before yielding
+        send_channel = self.subscribed_topics_send[topic_id]
+        del self.subscribed_topics_send[topic_id]
+        del self.subscribed_topics_receive[topic_id]
+        # Only close the send side
+        await send_channel.aclose()

        # Create unsubscribe message
        packet: rpc_pb2.RPC = rpc_pb2.RPC()
@@ -453,13 +469,13 @@ class Pubsub:
            data=data,
            topicIDs=[topic_id],
            # Origin is ourself.
-            from_id=self.host.get_id().to_bytes(),
+            from_id=self.my_id.to_bytes(),
            seqno=self._next_seqno(),
        )

        # TODO: Sign with our signing key

-        await self.push_msg(self.host.get_id(), msg)
+        await self.push_msg(self.my_id, msg)

        logger.debug("successfully published message %s", msg)
@@ -470,12 +486,12 @@ class Pubsub:
        :param msg_forwarder: the peer who forward us the message.
        :param msg: the message.
        """
-        sync_topic_validators = []
-        async_topic_validator_futures: List[Awaitable[bool]] = []
+        sync_topic_validators: List[SyncValidatorFn] = []
+        async_topic_validators: List[AsyncValidatorFn] = []
        for topic_validator in self.get_msg_validators(msg):
            if topic_validator.is_async:
-                async_topic_validator_futures.append(
-                    cast(Awaitable[bool], topic_validator.validator(msg_forwarder, msg))
+                async_topic_validators.append(
+                    cast(AsyncValidatorFn, topic_validator.validator)
                )
            else:
                sync_topic_validators.append(
@@ -488,9 +504,20 @@ class Pubsub:
        # TODO: Implement throttle on async validators
-        if len(async_topic_validator_futures) > 0:
-            results = await asyncio.gather(*async_topic_validator_futures)
-            if not all(results):
+        if len(async_topic_validators) > 0:
+            # TODO: Use a better pattern
+            final_result = True
+
+            async def run_async_validator(func: AsyncValidatorFn) -> None:
+                nonlocal final_result
+                result = await func(msg_forwarder, msg)
+                final_result = final_result and result
+
+            async with trio.open_nursery() as nursery:
+                for validator in async_topic_validators:
+                    nursery.start_soon(run_async_validator, validator)
+
+            if not final_result:
                raise ValidationError(f"Validation failed for msg={msg}")

    async def push_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None:
@@ -551,14 +578,4 @@ class Pubsub:
        self.seen_messages[msg_id] = 1

    def _is_subscribed_to_msg(self, msg: rpc_pb2.Message) -> bool:
-        if not self.my_topics:
-            return False
-        return any(topic in self.my_topics for topic in msg.topicIDs)
-
-    async def close(self) -> None:
-        for task in self._tasks:
-            task.cancel()
-            try:
-                await task
-            except asyncio.CancelledError:
-                pass
+        return any(topic in self.topic_ids for topic in msg.topicIDs)
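Per-topic subscriptions now hand out the receive half of a trio memory channel instead of an asyncio.Queue, and unsubscribing closes only the send half. A minimal standalone sketch of that pattern, assuming nothing from the repo beyond plain trio:

import math

import trio

async def main() -> None:
    # One unbounded channel pair per topic, mirroring
    # `subscribed_topics_send` / `subscribed_topics_receive` above.
    send_channel, receive_channel = trio.open_memory_channel(math.inf)

    await send_channel.send("msg-1")
    assert await receive_channel.receive() == "msg-1"

    # "Unsubscribe": close only the send side. A reader that keeps receiving
    # then gets `trio.EndOfChannel` instead of blocking forever.
    await send_channel.aclose()
    try:
        await receive_channel.receive()
    except trio.EndOfChannel:
        print("send side closed, no more messages")

trio.run(main)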

View File

@@ -8,19 +8,19 @@ from libp2p.network.notifee_interface import INotifee
from libp2p.network.stream.net_stream_interface import INetStream

if TYPE_CHECKING:
-    import asyncio  # noqa: F401
+    import trio  # noqa: F401
    from libp2p.peer.id import ID  # noqa: F401


class PubsubNotifee(INotifee):

-    initiator_peers_queue: "asyncio.Queue[ID]"
-    dead_peers_queue: "asyncio.Queue[ID]"
+    initiator_peers_queue: "trio.MemorySendChannel[ID]"
+    dead_peers_queue: "trio.MemorySendChannel[ID]"

    def __init__(
        self,
-        initiator_peers_queue: "asyncio.Queue[ID]",
-        dead_peers_queue: "asyncio.Queue[ID]",
+        initiator_peers_queue: "trio.MemorySendChannel[ID]",
+        dead_peers_queue: "trio.MemorySendChannel[ID]",
    ) -> None:
        """
        :param initiator_peers_queue: queue to add new peers to so that pubsub
@@ -46,7 +46,12 @@ class PubsubNotifee(INotifee):
        :param network: network the connection was opened on
        :param conn: connection that was opened
        """
-        await self.initiator_peers_queue.put(conn.muxed_conn.peer_id)
+        try:
+            await self.initiator_peers_queue.send(conn.muxed_conn.peer_id)
+        except trio.BrokenResourceError:
+            # Raised when the receive channel is closed.
+            # TODO: Do something with loggers?
+            ...

    async def disconnected(self, network: INetwork, conn: INetConn) -> None:
        """
@@ -56,7 +61,7 @@ class PubsubNotifee(INotifee):
        :param network: network the connection was opened on
        :param conn: connection that was opened
        """
-        await self.dead_peers_queue.put(conn.muxed_conn.peer_id)
+        await self.dead_peers_queue.send(conn.muxed_conn.peer_id)

    async def listen(self, network: INetwork, multiaddr: Multiaddr) -> None:
        pass
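The new try/except around `send()` covers the case where `Pubsub` has already shut down and closed its receive channel. A small sketch of when `trio.BrokenResourceError` fires, independent of the libp2p types:

import trio

async def main() -> None:
    send_channel, receive_channel = trio.open_memory_channel(0)
    # Simulate Pubsub going away: the receive side is closed first.
    await receive_channel.aclose()
    try:
        await send_channel.send("new-peer")
    except trio.BrokenResourceError:
        # The same situation the notifee swallows: nobody is listening anymore.
        print("receive channel already closed")

trio.run(main)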

View File

@@ -281,7 +281,7 @@ class Mplex(IMuxedConn, Service):
        mplex_stream = await self._initialize_stream(stream_id, message.decode())
        try:
            await self.new_stream_send_channel.send(mplex_stream)
-        except (trio.BrokenResourceError, trio.EndOfChannel):
+        except (trio.BrokenResourceError, trio.ClosedResourceError):
            raise MplexUnavailable

    async def _handle_message(self, stream_id: StreamID, message: bytes) -> None:
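Swapping `trio.EndOfChannel` for `trio.ClosedResourceError` matches trio's semantics: `EndOfChannel` is only ever raised on the receive side, while a sender sees `ClosedResourceError` when its own half was closed and `BrokenResourceError` when the other half was. A small sketch of the distinction:

import trio

async def main() -> None:
    send_channel, receive_channel = trio.open_memory_channel(0)

    # Closing our own send half: further sends raise ClosedResourceError.
    await send_channel.aclose()
    try:
        await send_channel.send("stream")
    except trio.ClosedResourceError:
        print("send channel was closed locally")

    # EndOfChannel, by contrast, only shows up on receive() once the
    # send half is gone.
    try:
        await receive_channel.receive()
    except trio.EndOfChannel:
        print("no more values will arrive")

trio.run(main)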

View File

@@ -5,6 +5,7 @@ from async_service import background_trio_service
import factory
import trio

+from libp2p.tools.constants import GOSSIPSUB_PARAMS
from libp2p import generate_new_rsa_identity, generate_peer_id_from
from libp2p.crypto.keys import KeyPair
from libp2p.host.basic_host import BasicHost
@@ -15,6 +16,7 @@ from libp2p.network.connection.swarm_connection import SwarmConn
from libp2p.network.stream.net_stream_interface import INetStream
from libp2p.network.swarm import Swarm
from libp2p.peer.peerstore import PeerStore
+from libp2p.peer.id import ID
from libp2p.pubsub.floodsub import FloodSub
from libp2p.pubsub.gossipsub import GossipSub
from libp2p.pubsub.pubsub import Pubsub
@@ -28,15 +30,19 @@ from libp2p.transport.typing import TMuxerOptions
from libp2p.transport.upgrader import TransportUpgrader
from libp2p.typing import TProtocol

-from .constants import (
-    FLOODSUB_PROTOCOL_ID,
-    GOSSIPSUB_PARAMS,
-    GOSSIPSUB_PROTOCOL_ID,
-    LISTEN_MADDR,
-)
+from .constants import FLOODSUB_PROTOCOL_ID, GOSSIPSUB_PROTOCOL_ID, LISTEN_MADDR
from .utils import connect, connect_swarm


+class IDFactory(factory.Factory):
+    class Meta:
+        model = ID
+
+    peer_id_bytes = factory.LazyFunction(
+        lambda: generate_peer_id_from(generate_new_rsa_identity())
+    )
+
+
def security_transport_factory(
    is_secure: bool, key_pair: KeyPair
) -> Dict[TProtocol, BaseSecureTransport]:
@@ -181,9 +187,38 @@ class PubsubFactory(factory.Factory):

    host = factory.SubFactory(HostFactory)
    router = None
-    my_id = factory.LazyAttribute(lambda obj: obj.host.get_id())
    cache_size = None

+    @classmethod
+    @asynccontextmanager
+    async def create_and_start(cls, host, router, cache_size):
+        pubsub = PubsubFactory(host=host, router=router, cache_size=cache_size)
+        async with background_trio_service(pubsub):
+            yield pubsub
+
+    @classmethod
+    @asynccontextmanager
+    async def create_batch_with_floodsub(
+        cls, number: int, is_secure: bool = False, cache_size: int = None
+    ):
+        floodsubs = FloodsubFactory.create_batch(number)
+        async with HostFactory.create_batch_and_listen(is_secure, number) as hosts:
+            # Pubsubs should exit before hosts
+            async with AsyncExitStack() as stack:
+                pubsubs = [
+                    await stack.enter_async_context(
+                        cls.create_and_start(host, router, cache_size)
+                    )
+                    for host, router in zip(hosts, floodsubs)
+                ]
+                yield pubsubs
+
+    # @classmethod
+    # async def create_batch_with_gossipsub(
+    #     cls, number: int, cache_size: int = None, gossipsub_params=GOSSIPSUB_PARAMS
+    # ):
+    #     ...
+

@asynccontextmanager
async def swarm_pair_factory(
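`create_batch_with_floodsub` above layers async context managers so that every pubsub is running before the caller's block starts and is torn down before its host. A generic sketch of the `AsyncExitStack` pattern it relies on, with stand-in service names rather than the real factories:

from contextlib import AsyncExitStack, asynccontextmanager

import trio

@asynccontextmanager
async def start_service(name: str):
    # Stand-in for `background_trio_service(pubsub)` used by the factory.
    print(f"starting {name}")
    try:
        yield name
    finally:
        print(f"stopping {name}")

async def main() -> None:
    # Contexts entered on the stack are exited in reverse order, which is how
    # the factory keeps "pubsubs should exit before hosts".
    async with AsyncExitStack() as stack:
        services = [
            await stack.enter_async_context(start_service(f"pubsub-{i}"))
            for i in range(2)
        ]
        print("running with", services)

trio.run(main)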

View File

@@ -4,18 +4,6 @@ from libp2p.tools.constants import GOSSIPSUB_PARAMS
from libp2p.tools.factories import FloodsubFactory, GossipsubFactory, PubsubFactory


-def _make_pubsubs(hosts, pubsub_routers, cache_size):
-    if len(pubsub_routers) != len(hosts):
-        raise ValueError(
-            f"lenght of pubsub_routers={pubsub_routers} should be equaled to the "
-            f"length of hosts={len(hosts)}"
-        )
-    return tuple(
-        PubsubFactory(host=host, router=router, cache_size=cache_size)
-        for host, router in zip(hosts, pubsub_routers)
-    )
-
-
@pytest.fixture
def pubsub_cache_size():
    return None  # default
@@ -26,17 +14,9 @@ def gossipsub_params():
    return GOSSIPSUB_PARAMS


-@pytest.fixture
-def pubsubs_fsub(num_hosts, hosts, pubsub_cache_size):
-    floodsubs = FloodsubFactory.create_batch(num_hosts)
-    _pubsubs_fsub = _make_pubsubs(hosts, floodsubs, pubsub_cache_size)
-    yield _pubsubs_fsub
-    # TODO: Clean up
-
-
-@pytest.fixture
-def pubsubs_gsub(num_hosts, hosts, pubsub_cache_size, gossipsub_params):
-    gossipsubs = GossipsubFactory.create_batch(num_hosts, **gossipsub_params._asdict())
-    _pubsubs_gsub = _make_pubsubs(hosts, gossipsubs, pubsub_cache_size)
-    yield _pubsubs_gsub
-    # TODO: Clean up
+# @pytest.fixture
+# def pubsubs_gsub(num_hosts, hosts, pubsub_cache_size, gossipsub_params):
+#     gossipsubs = GossipsubFactory.create_batch(num_hosts, **gossipsub_params._asdict())
+#     _pubsubs_gsub = _make_pubsubs(hosts, gossipsubs, pubsub_cache_size)
+#     yield _pubsubs_gsub
+#     # TODO: Clean up

View File

@@ -1,348 +1,328 @@
-import asyncio
+from contextlib import contextmanager
from typing import NamedTuple

import pytest
+import trio

from libp2p.exceptions import ValidationError
from libp2p.peer.id import ID
from libp2p.pubsub.pb import rpc_pb2
from libp2p.tools.pubsub.utils import make_pubsub_msg
from libp2p.tools.utils import connect
+from libp2p.tools.constants import MAX_READ_LEN
+from libp2p.tools.factories import PubsubFactory, net_stream_pair_factory, IDFactory
from libp2p.utils import encode_varint_prefixed

TESTING_TOPIC = "TEST_SUBSCRIBE"
TESTING_DATA = b"data"
-@pytest.mark.parametrize("num_hosts", (1,))
-@pytest.mark.asyncio
-async def test_subscribe_and_unsubscribe(pubsubs_fsub):
-    await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
-    assert TESTING_TOPIC in pubsubs_fsub[0].my_topics
-    await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC)
-    assert TESTING_TOPIC not in pubsubs_fsub[0].my_topics
+@pytest.mark.trio
+async def test_subscribe_and_unsubscribe():
+    async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
+        await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
+        assert TESTING_TOPIC in pubsubs_fsub[0].topic_ids
+        await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC)
+        assert TESTING_TOPIC not in pubsubs_fsub[0].topic_ids


-@pytest.mark.parametrize("num_hosts", (1,))
-@pytest.mark.asyncio
-async def test_re_subscribe(pubsubs_fsub):
-    await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
-    assert TESTING_TOPIC in pubsubs_fsub[0].my_topics
-    await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
-    assert TESTING_TOPIC in pubsubs_fsub[0].my_topics
+@pytest.mark.trio
+async def test_re_subscribe():
+    async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
+        await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
+        assert TESTING_TOPIC in pubsubs_fsub[0].topic_ids
+        await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
+        assert TESTING_TOPIC in pubsubs_fsub[0].topic_ids


-@pytest.mark.parametrize("num_hosts", (1,))
-@pytest.mark.asyncio
-async def test_re_unsubscribe(pubsubs_fsub):
-    # Unsubscribe from topic we didn't even subscribe to
-    assert "NOT_MY_TOPIC" not in pubsubs_fsub[0].my_topics
-    await pubsubs_fsub[0].unsubscribe("NOT_MY_TOPIC")
-    assert "NOT_MY_TOPIC" not in pubsubs_fsub[0].my_topics
-
-    await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
-    assert TESTING_TOPIC in pubsubs_fsub[0].my_topics
-    await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC)
-    assert TESTING_TOPIC not in pubsubs_fsub[0].my_topics
-    await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC)
-    assert TESTING_TOPIC not in pubsubs_fsub[0].my_topics
+@pytest.mark.trio
+async def test_re_unsubscribe():
+    async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
+        # Unsubscribe from topic we didn't even subscribe to
+        assert "NOT_MY_TOPIC" not in pubsubs_fsub[0].topic_ids
+        await pubsubs_fsub[0].unsubscribe("NOT_MY_TOPIC")
+        assert "NOT_MY_TOPIC" not in pubsubs_fsub[0].topic_ids
+
+        await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
+        assert TESTING_TOPIC in pubsubs_fsub[0].topic_ids
+        await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC)
+        assert TESTING_TOPIC not in pubsubs_fsub[0].topic_ids
+        await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC)
+        assert TESTING_TOPIC not in pubsubs_fsub[0].topic_ids
-@pytest.mark.asyncio
-async def test_peers_subscribe(pubsubs_fsub):
-    await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host)
-    await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
-    # Yield to let 0 notify 1
-    await asyncio.sleep(0.1)
-    assert pubsubs_fsub[0].my_id in pubsubs_fsub[1].peer_topics[TESTING_TOPIC]
-    await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC)
-    # Yield to let 0 notify 1
-    await asyncio.sleep(0.1)
-    assert pubsubs_fsub[0].my_id not in pubsubs_fsub[1].peer_topics[TESTING_TOPIC]
+@pytest.mark.trio
+async def test_peers_subscribe():
+    async with PubsubFactory.create_batch_with_floodsub(2) as pubsubs_fsub:
+        await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host)
+        await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
+        # Yield to let 0 notify 1
+        await trio.sleep(0.1)
+        assert pubsubs_fsub[0].my_id in pubsubs_fsub[1].peer_topics[TESTING_TOPIC]
+        await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC)
+        # Yield to let 0 notify 1
+        await trio.sleep(0.1)
+        assert pubsubs_fsub[0].my_id not in pubsubs_fsub[1].peer_topics[TESTING_TOPIC]


-@pytest.mark.parametrize("num_hosts", (1,))
-@pytest.mark.asyncio
-async def test_get_hello_packet(pubsubs_fsub):
-    def _get_hello_packet_topic_ids():
-        packet = pubsubs_fsub[0].get_hello_packet()
-        return tuple(sub.topicid for sub in packet.subscriptions)
-
-    # Test: No subscription, so there should not be any topic ids in the hello packet.
-    assert len(_get_hello_packet_topic_ids()) == 0
-
-    # Test: After subscriptions, topic ids should be in the hello packet.
-    topic_ids = ["t", "o", "p", "i", "c"]
-    await asyncio.gather(*[pubsubs_fsub[0].subscribe(topic) for topic in topic_ids])
-    topic_ids_in_hello = _get_hello_packet_topic_ids()
-    for topic in topic_ids:
-        assert topic in topic_ids_in_hello
+@pytest.mark.trio
+async def test_get_hello_packet():
+    async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
+
+        def _get_hello_packet_topic_ids():
+            packet = pubsubs_fsub[0].get_hello_packet()
+            return tuple(sub.topicid for sub in packet.subscriptions)
+
+        # Test: No subscription, so there should not be any topic ids in the hello packet.
+        assert len(_get_hello_packet_topic_ids()) == 0
+
+        # Test: After subscriptions, topic ids should be in the hello packet.
+        topic_ids = ["t", "o", "p", "i", "c"]
+        for topic in topic_ids:
+            await pubsubs_fsub[0].subscribe(topic)
+        topic_ids_in_hello = _get_hello_packet_topic_ids()
+        for topic in topic_ids:
+            assert topic in topic_ids_in_hello
-@pytest.mark.parametrize("num_hosts", (1,))
-@pytest.mark.asyncio
-async def test_set_and_remove_topic_validator(pubsubs_fsub):
+@pytest.mark.trio
+async def test_set_and_remove_topic_validator():
+    async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
        is_sync_validator_called = False

        def sync_validator(peer_id, msg):
            nonlocal is_sync_validator_called
            is_sync_validator_called = True

        is_async_validator_called = False

        async def async_validator(peer_id, msg):
            nonlocal is_async_validator_called
            is_async_validator_called = True

        topic = "TEST_VALIDATOR"

        assert topic not in pubsubs_fsub[0].topic_validators

        # Register sync validator
        pubsubs_fsub[0].set_topic_validator(topic, sync_validator, False)

        assert topic in pubsubs_fsub[0].topic_validators
        topic_validator = pubsubs_fsub[0].topic_validators[topic]
        assert not topic_validator.is_async

        # Validate with sync validator
-        topic_validator.validator(peer_id=ID(b"peer"), msg="msg")
+        topic_validator.validator(peer_id=IDFactory(), msg="msg")

        assert is_sync_validator_called
        assert not is_async_validator_called

        # Register with async validator
        pubsubs_fsub[0].set_topic_validator(topic, async_validator, True)

        is_sync_validator_called = False
        assert topic in pubsubs_fsub[0].topic_validators
        topic_validator = pubsubs_fsub[0].topic_validators[topic]
        assert topic_validator.is_async

        # Validate with async validator
-        await topic_validator.validator(peer_id=ID(b"peer"), msg="msg")
+        await topic_validator.validator(peer_id=IDFactory(), msg="msg")

        assert is_async_validator_called
        assert not is_sync_validator_called

        # Remove validator
        pubsubs_fsub[0].remove_topic_validator(topic)
        assert topic not in pubsubs_fsub[0].topic_validators
-@pytest.mark.parametrize("num_hosts", (1,))
-@pytest.mark.asyncio
-async def test_get_msg_validators(pubsubs_fsub):
+@pytest.mark.trio
+async def test_get_msg_validators():
+    async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
        times_sync_validator_called = 0

        def sync_validator(peer_id, msg):
            nonlocal times_sync_validator_called
            times_sync_validator_called += 1

        times_async_validator_called = 0

        async def async_validator(peer_id, msg):
            nonlocal times_async_validator_called
            times_async_validator_called += 1

        topic_1 = "TEST_VALIDATOR_1"
        topic_2 = "TEST_VALIDATOR_2"
        topic_3 = "TEST_VALIDATOR_3"

        # Register sync validator for topic 1 and 2
        pubsubs_fsub[0].set_topic_validator(topic_1, sync_validator, False)
        pubsubs_fsub[0].set_topic_validator(topic_2, sync_validator, False)

        # Register async validator for topic 3
        pubsubs_fsub[0].set_topic_validator(topic_3, async_validator, True)

        msg = make_pubsub_msg(
            origin_id=pubsubs_fsub[0].my_id,
            topic_ids=[topic_1, topic_2, topic_3],
            data=b"1234",
            seqno=b"\x00" * 8,
        )

        topic_validators = pubsubs_fsub[0].get_msg_validators(msg)
        for topic_validator in topic_validators:
            if topic_validator.is_async:
-                await topic_validator.validator(peer_id=ID(b"peer"), msg="msg")
+                await topic_validator.validator(peer_id=IDFactory(), msg="msg")
            else:
-                topic_validator.validator(peer_id=ID(b"peer"), msg="msg")
+                topic_validator.validator(peer_id=IDFactory(), msg="msg")

        assert times_sync_validator_called == 2
        assert times_async_validator_called == 1
-@pytest.mark.parametrize("num_hosts", (1,))
@pytest.mark.parametrize(
    "is_topic_1_val_passed, is_topic_2_val_passed",
    ((False, True), (True, False), (True, True)),
)
-@pytest.mark.asyncio
-async def test_validate_msg(pubsubs_fsub, is_topic_1_val_passed, is_topic_2_val_passed):
+@pytest.mark.trio
+async def test_validate_msg(is_topic_1_val_passed, is_topic_2_val_passed):
+    async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:

        def passed_sync_validator(peer_id, msg):
            return True

        def failed_sync_validator(peer_id, msg):
            return False

        async def passed_async_validator(peer_id, msg):
            return True

        async def failed_async_validator(peer_id, msg):
            return False

        topic_1 = "TEST_SYNC_VALIDATOR"
        topic_2 = "TEST_ASYNC_VALIDATOR"

        if is_topic_1_val_passed:
            pubsubs_fsub[0].set_topic_validator(topic_1, passed_sync_validator, False)
        else:
            pubsubs_fsub[0].set_topic_validator(topic_1, failed_sync_validator, False)

        if is_topic_2_val_passed:
            pubsubs_fsub[0].set_topic_validator(topic_2, passed_async_validator, True)
        else:
            pubsubs_fsub[0].set_topic_validator(topic_2, failed_async_validator, True)

        msg = make_pubsub_msg(
            origin_id=pubsubs_fsub[0].my_id,
            topic_ids=[topic_1, topic_2],
            data=b"1234",
            seqno=b"\x00" * 8,
        )

        if is_topic_1_val_passed and is_topic_2_val_passed:
            await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg)
        else:
            with pytest.raises(ValidationError):
                await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg)


-class FakeNetStream:
-    _queue: asyncio.Queue
-
-    class FakeMplexConn(NamedTuple):
-        peer_id: ID = ID(b"\x12\x20" + b"\x00" * 32)
-
-    muxed_conn = FakeMplexConn()
-
-    def __init__(self) -> None:
-        self._queue = asyncio.Queue()
-
-    async def read(self, n: int = -1) -> bytes:
-        buf = bytearray()
-        # Force to blocking wait if no data available now.
-        if self._queue.empty():
-            first_byte = await self._queue.get()
-            buf.extend(first_byte)
-        # If `n == -1`, read until no data is in the buffer(_queue).
-        # Else, read until no data is in the buffer(_queue) or we have read `n` bytes.
-        while (n == -1) or (len(buf) < n):
-            if self._queue.empty():
-                break
-            buf.extend(await self._queue.get())
-        return bytes(buf)
-
-    async def write(self, data: bytes) -> int:
-        for i in data:
-            await self._queue.put(i.to_bytes(1, "big"))
-        return len(data)
-
-
-@pytest.mark.parametrize("num_hosts", (1,))
-@pytest.mark.asyncio
-async def test_continuously_read_stream(pubsubs_fsub, monkeypatch):
-    stream = FakeNetStream()
-
-    await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
-
-    event_push_msg = asyncio.Event()
-    event_handle_subscription = asyncio.Event()
-    event_handle_rpc = asyncio.Event()
-
-    async def mock_push_msg(msg_forwarder, msg):
-        event_push_msg.set()
-
-    def mock_handle_subscription(origin_id, sub_message):
-        event_handle_subscription.set()
-
-    async def mock_handle_rpc(rpc, sender_peer_id):
-        event_handle_rpc.set()
-
-    monkeypatch.setattr(pubsubs_fsub[0], "push_msg", mock_push_msg)
-    monkeypatch.setattr(
-        pubsubs_fsub[0], "handle_subscription", mock_handle_subscription
-    )
-    monkeypatch.setattr(pubsubs_fsub[0].router, "handle_rpc", mock_handle_rpc)
-
-    async def wait_for_event_occurring(event):
-        try:
-            await asyncio.wait_for(event.wait(), timeout=1)
-        except asyncio.TimeoutError as error:
-            event.clear()
-            raise asyncio.TimeoutError(
-                f"Event {event} is not set before the timeout. "
-                "This indicates the mocked functions are not called properly."
-            ) from error
-        else:
-            event.clear()
-
-    # Kick off the task `continuously_read_stream`
-    task = asyncio.ensure_future(pubsubs_fsub[0].continuously_read_stream(stream))
-
-    # Test: `push_msg` is called when publishing to a subscribed topic.
-    publish_subscribed_topic = rpc_pb2.RPC(
-        publish=[rpc_pb2.Message(topicIDs=[TESTING_TOPIC])]
-    )
-    await stream.write(
-        encode_varint_prefixed(publish_subscribed_topic.SerializeToString())
-    )
-    await wait_for_event_occurring(event_push_msg)
-    # Make sure the other events are not emitted.
-    with pytest.raises(asyncio.TimeoutError):
-        await wait_for_event_occurring(event_handle_subscription)
-    with pytest.raises(asyncio.TimeoutError):
-        await wait_for_event_occurring(event_handle_rpc)
-
-    # Test: `push_msg` is not called when publishing to a topic-not-subscribed.
-    publish_not_subscribed_topic = rpc_pb2.RPC(
-        publish=[rpc_pb2.Message(topicIDs=["NOT_SUBSCRIBED"])]
-    )
-    await stream.write(
-        encode_varint_prefixed(publish_not_subscribed_topic.SerializeToString())
-    )
-    with pytest.raises(asyncio.TimeoutError):
-        await wait_for_event_occurring(event_push_msg)
-
-    # Test: `handle_subscription` is called when a subscription message is received.
-    subscription_msg = rpc_pb2.RPC(subscriptions=[rpc_pb2.RPC.SubOpts()])
-    await stream.write(encode_varint_prefixed(subscription_msg.SerializeToString()))
-    await wait_for_event_occurring(event_handle_subscription)
-    # Make sure the other events are not emitted.
-    with pytest.raises(asyncio.TimeoutError):
-        await wait_for_event_occurring(event_push_msg)
-    with pytest.raises(asyncio.TimeoutError):
-        await wait_for_event_occurring(event_handle_rpc)
-
-    # Test: `handle_rpc` is called when a control message is received.
-    control_msg = rpc_pb2.RPC(control=rpc_pb2.ControlMessage())
-    await stream.write(encode_varint_prefixed(control_msg.SerializeToString()))
-    await wait_for_event_occurring(event_handle_rpc)
-    # Make sure the other events are not emitted.
-    with pytest.raises(asyncio.TimeoutError):
-        await wait_for_event_occurring(event_push_msg)
-    with pytest.raises(asyncio.TimeoutError):
-        await wait_for_event_occurring(event_handle_subscription)
-
-    task.cancel()
+@pytest.mark.trio
+async def test_continuously_read_stream(monkeypatch, nursery, is_host_secure):
+    async def wait_for_event_occurring(event):
+        with trio.fail_after(0.1):
+            await event.wait()
+
+    class Events(NamedTuple):
+        push_msg: trio.Event
+        handle_subscription: trio.Event
+        handle_rpc: trio.Event
+
+    @contextmanager
+    def mock_methods():
+        event_push_msg = trio.Event()
+        event_handle_subscription = trio.Event()
+        event_handle_rpc = trio.Event()
+
+        async def mock_push_msg(msg_forwarder, msg):
+            event_push_msg.set()
+            await trio.sleep(0)
+
+        def mock_handle_subscription(origin_id, sub_message):
+            event_handle_subscription.set()
+
+        async def mock_handle_rpc(rpc, sender_peer_id):
+            event_handle_rpc.set()
+            await trio.sleep(0)
+
+        with monkeypatch.context() as m:
+            m.setattr(pubsubs_fsub[0], "push_msg", mock_push_msg)
+            m.setattr(pubsubs_fsub[0], "handle_subscription", mock_handle_subscription)
+            m.setattr(pubsubs_fsub[0].router, "handle_rpc", mock_handle_rpc)
+            yield Events(event_push_msg, event_handle_subscription, event_handle_rpc)
+
+    async with PubsubFactory.create_batch_with_floodsub(
+        1, is_secure=is_host_secure
+    ) as pubsubs_fsub, net_stream_pair_factory(is_secure=is_host_secure) as stream_pair:
+        await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
+        # Kick off the task `continuously_read_stream`
+        nursery.start_soon(pubsubs_fsub[0].continuously_read_stream, stream_pair[0])
+
+        # Test: `push_msg` is called when publishing to a subscribed topic.
+        publish_subscribed_topic = rpc_pb2.RPC(
+            publish=[rpc_pb2.Message(topicIDs=[TESTING_TOPIC])]
+        )
+        with mock_methods() as events:
+            await stream_pair[1].write(
+                encode_varint_prefixed(publish_subscribed_topic.SerializeToString())
+            )
+            await wait_for_event_occurring(events.push_msg)
+            # Make sure the other events are not emitted.
+            with pytest.raises(trio.TooSlowError):
+                await wait_for_event_occurring(events.handle_subscription)
+            with pytest.raises(trio.TooSlowError):
+                await wait_for_event_occurring(events.handle_rpc)
+
+        # Test: `push_msg` is not called when publishing to a topic-not-subscribed.
+        publish_not_subscribed_topic = rpc_pb2.RPC(
+            publish=[rpc_pb2.Message(topicIDs=["NOT_SUBSCRIBED"])]
+        )
+        with mock_methods() as events:
+            await stream_pair[1].write(
+                encode_varint_prefixed(publish_not_subscribed_topic.SerializeToString())
+            )
+            with pytest.raises(trio.TooSlowError):
+                await wait_for_event_occurring(events.push_msg)
+
+        # Test: `handle_subscription` is called when a subscription message is received.
+        subscription_msg = rpc_pb2.RPC(subscriptions=[rpc_pb2.RPC.SubOpts()])
+        with mock_methods() as events:
+            await stream_pair[1].write(
+                encode_varint_prefixed(subscription_msg.SerializeToString())
+            )
+            await wait_for_event_occurring(events.handle_subscription)
+            # Make sure the other events are not emitted.
+            with pytest.raises(trio.TooSlowError):
+                await wait_for_event_occurring(events.push_msg)
+            with pytest.raises(trio.TooSlowError):
+                await wait_for_event_occurring(events.handle_rpc)
+
+        # Test: `handle_rpc` is called when a control message is received.
+        control_msg = rpc_pb2.RPC(control=rpc_pb2.ControlMessage())
+        with mock_methods() as events:
+            await stream_pair[1].write(
+                encode_varint_prefixed(control_msg.SerializeToString())
+            )
+            await wait_for_event_occurring(events.handle_rpc)
+            # Make sure the other events are not emitted.
+            with pytest.raises(trio.TooSlowError):
+                await wait_for_event_occurring(events.push_msg)
+            with pytest.raises(trio.TooSlowError):
+                await wait_for_event_occurring(events.handle_subscription)


# TODO: Add the following tests after they are aligned with Go.
@@ -351,81 +331,84 @@ async def test_continuously_read_stream(pubsubs_fsub, monkeypatch):
# - `test_handle_peer_queue`


-@pytest.mark.parametrize("num_hosts", (1,))
-def test_handle_subscription(pubsubs_fsub):
+@pytest.mark.trio
+async def test_handle_subscription():
+    async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
        assert len(pubsubs_fsub[0].peer_topics) == 0
        sub_msg_0 = rpc_pb2.RPC.SubOpts(subscribe=True, topicid=TESTING_TOPIC)
-    peer_ids = [ID(b"\x12\x20" + i.to_bytes(32, "big")) for i in range(2)]
+        peer_ids = [IDFactory() for _ in range(2)]
        # Test: One peer is subscribed
        pubsubs_fsub[0].handle_subscription(peer_ids[0], sub_msg_0)
        assert (
            len(pubsubs_fsub[0].peer_topics) == 1
            and TESTING_TOPIC in pubsubs_fsub[0].peer_topics
        )
        assert len(pubsubs_fsub[0].peer_topics[TESTING_TOPIC]) == 1
        assert peer_ids[0] in pubsubs_fsub[0].peer_topics[TESTING_TOPIC]
        # Test: Another peer is subscribed
        pubsubs_fsub[0].handle_subscription(peer_ids[1], sub_msg_0)
        assert len(pubsubs_fsub[0].peer_topics) == 1
        assert len(pubsubs_fsub[0].peer_topics[TESTING_TOPIC]) == 2
        assert peer_ids[1] in pubsubs_fsub[0].peer_topics[TESTING_TOPIC]
        # Test: Subscribe to another topic
        another_topic = "ANOTHER_TOPIC"
        sub_msg_1 = rpc_pb2.RPC.SubOpts(subscribe=True, topicid=another_topic)
        pubsubs_fsub[0].handle_subscription(peer_ids[0], sub_msg_1)
        assert len(pubsubs_fsub[0].peer_topics) == 2
        assert another_topic in pubsubs_fsub[0].peer_topics
        assert peer_ids[0] in pubsubs_fsub[0].peer_topics[another_topic]
        # Test: unsubscribe
        unsub_msg = rpc_pb2.RPC.SubOpts(subscribe=False, topicid=TESTING_TOPIC)
        pubsubs_fsub[0].handle_subscription(peer_ids[0], unsub_msg)
        assert peer_ids[0] not in pubsubs_fsub[0].peer_topics[TESTING_TOPIC]
-@pytest.mark.parametrize("num_hosts", (1,))
-@pytest.mark.asyncio
-async def test_handle_talk(pubsubs_fsub):
+@pytest.mark.trio
+async def test_handle_talk():
+    async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
        sub = await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
        msg_0 = make_pubsub_msg(
            origin_id=pubsubs_fsub[0].my_id,
            topic_ids=[TESTING_TOPIC],
            data=b"1234",
            seqno=b"\x00" * 8,
        )
        await pubsubs_fsub[0].handle_talk(msg_0)

        msg_1 = make_pubsub_msg(
            origin_id=pubsubs_fsub[0].my_id,
            topic_ids=["NOT_SUBSCRIBED"],
            data=b"1234",
            seqno=b"\x11" * 8,
        )
        await pubsubs_fsub[0].handle_talk(msg_1)

        assert (
-        len(pubsubs_fsub[0].my_topics) == 1
-        and sub == pubsubs_fsub[0].my_topics[TESTING_TOPIC]
+            len(pubsubs_fsub[0].topic_ids) == 1
+            and sub == pubsubs_fsub[0].subscribed_topics_receive[TESTING_TOPIC]
        )
-    assert sub.qsize() == 1
-    assert (await sub.get()) == msg_0
+        assert (await sub.receive()) == msg_0
-@pytest.mark.parametrize("num_hosts", (1,))
-@pytest.mark.asyncio
-async def test_message_all_peers(pubsubs_fsub, monkeypatch):
-    peer_ids = [ID(b"\x12\x20" + i.to_bytes(32, "big")) for i in range(10)]
-    mock_peers = {peer_id: FakeNetStream() for peer_id in peer_ids}
-    monkeypatch.setattr(pubsubs_fsub[0], "peers", mock_peers)
-
-    empty_rpc = rpc_pb2.RPC()
-    empty_rpc_bytes = empty_rpc.SerializeToString()
-    empty_rpc_bytes_len_prefixed = encode_varint_prefixed(empty_rpc_bytes)
-    await pubsubs_fsub[0].message_all_peers(empty_rpc_bytes)
-    for stream in mock_peers.values():
-        assert (await stream.read()) == empty_rpc_bytes_len_prefixed
+@pytest.mark.trio
+async def test_message_all_peers(monkeypatch, is_host_secure):
+    async with PubsubFactory.create_batch_with_floodsub(
+        1, is_secure=is_host_secure
+    ) as pubsubs_fsub, net_stream_pair_factory(is_secure=is_host_secure) as stream_pair:
+        peer_id = IDFactory()
+        mock_peers = {peer_id: stream_pair[0]}
+        with monkeypatch.context() as m:
+            m.setattr(pubsubs_fsub[0], "peers", mock_peers)

+            empty_rpc = rpc_pb2.RPC()
+            empty_rpc_bytes = empty_rpc.SerializeToString()
+            empty_rpc_bytes_len_prefixed = encode_varint_prefixed(empty_rpc_bytes)
+            await pubsubs_fsub[0].message_all_peers(empty_rpc_bytes)
+            assert (
+                await stream_pair[1].read(MAX_READ_LEN)
+            ) == empty_rpc_bytes_len_prefixed
-@pytest.mark.parametrize("num_hosts", (1,))
-@pytest.mark.asyncio
-async def test_publish(pubsubs_fsub, monkeypatch):
+@pytest.mark.trio
+async def test_publish(monkeypatch):
    msg_forwarders = []
    msgs = []

@@ -433,80 +416,97 @@ async def test_publish(pubsubs_fsub, monkeypatch):
        msg_forwarders.append(msg_forwarder)
        msgs.append(msg)

-    monkeypatch.setattr(pubsubs_fsub[0], "push_msg", push_msg)
+    async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
+        with monkeypatch.context() as m:
+            m.setattr(pubsubs_fsub[0], "push_msg", push_msg)

-    await pubsubs_fsub[0].publish(TESTING_TOPIC, TESTING_DATA)
-    await pubsubs_fsub[0].publish(TESTING_TOPIC, TESTING_DATA)
+            await pubsubs_fsub[0].publish(TESTING_TOPIC, TESTING_DATA)
+            await pubsubs_fsub[0].publish(TESTING_TOPIC, TESTING_DATA)

-    assert len(msgs) == 2, "`push_msg` should be called every time `publish` is called"
-    assert (msg_forwarders[0] == msg_forwarders[1]) and (
-        msg_forwarders[1] == pubsubs_fsub[0].my_id
-    )
-    assert msgs[0].seqno != msgs[1].seqno, "`seqno` should be different every time"
+            assert (
+                len(msgs) == 2
+            ), "`push_msg` should be called every time `publish` is called"
+            assert (msg_forwarders[0] == msg_forwarders[1]) and (
+                msg_forwarders[1] == pubsubs_fsub[0].my_id
+            )
+            assert (
+                msgs[0].seqno != msgs[1].seqno
+            ), "`seqno` should be different every time"
-@pytest.mark.parametrize("num_hosts", (1,))
-@pytest.mark.asyncio
-async def test_push_msg(pubsubs_fsub, monkeypatch):
-    msg_0 = make_pubsub_msg(
-        origin_id=pubsubs_fsub[0].my_id,
-        topic_ids=[TESTING_TOPIC],
-        data=TESTING_DATA,
-        seqno=b"\x00" * 8,
-    )
-
-    event = asyncio.Event()
-
-    async def router_publish(*args, **kwargs):
-        event.set()
-
-    monkeypatch.setattr(pubsubs_fsub[0].router, "publish", router_publish)
-
-    # Test: `msg` is not seen before `push_msg`, and is seen after `push_msg`.
-    assert not pubsubs_fsub[0]._is_msg_seen(msg_0)
-    await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_0)
-    assert pubsubs_fsub[0]._is_msg_seen(msg_0)
-    # Test: Ensure `router.publish` is called in `push_msg`
-    await asyncio.wait_for(event.wait(), timeout=0.1)
-
-    # Test: `push_msg` the message again and it will be reject.
-    # `router_publish` is not called then.
-    event.clear()
-    await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_0)
-    await asyncio.sleep(0.01)
-    assert not event.is_set()
-
-    sub = await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
-    # Test: `push_msg` succeeds with another unseen msg.
-    msg_1 = make_pubsub_msg(
-        origin_id=pubsubs_fsub[0].my_id,
-        topic_ids=[TESTING_TOPIC],
-        data=TESTING_DATA,
-        seqno=b"\x11" * 8,
-    )
-    assert not pubsubs_fsub[0]._is_msg_seen(msg_1)
-    await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_1)
-    assert pubsubs_fsub[0]._is_msg_seen(msg_1)
-    await asyncio.wait_for(event.wait(), timeout=0.1)
-    # Test: Subscribers are notified when `push_msg` new messages.
-    assert (await sub.get()) == msg_1
-
-    # Test: add a topic validator and `push_msg` the message that
-    # does not pass the validation.
-    # `router_publish` is not called then.
-    def failed_sync_validator(peer_id, msg):
-        return False
-
-    pubsubs_fsub[0].set_topic_validator(TESTING_TOPIC, failed_sync_validator, False)
-
-    msg_2 = make_pubsub_msg(
-        origin_id=pubsubs_fsub[0].my_id,
-        topic_ids=[TESTING_TOPIC],
-        data=TESTING_DATA,
-        seqno=b"\x22" * 8,
-    )
-
-    event.clear()
-    await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_2)
-    await asyncio.sleep(0.01)
-    assert not event.is_set()
+@pytest.mark.trio
+async def test_push_msg(monkeypatch):
+    async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
+        msg_0 = make_pubsub_msg(
+            origin_id=pubsubs_fsub[0].my_id,
+            topic_ids=[TESTING_TOPIC],
+            data=TESTING_DATA,
+            seqno=b"\x00" * 8,
+        )
+
+        @contextmanager
+        def mock_router_publish():
+
+            event = trio.Event()
+
+            async def router_publish(*args, **kwargs):
+                event.set()
+                await trio.sleep(0)
+
+            with monkeypatch.context() as m:
+                m.setattr(pubsubs_fsub[0].router, "publish", router_publish)
+                yield event
+
+        with mock_router_publish() as event:
+            # Test: `msg` is not seen before `push_msg`, and is seen after `push_msg`.
+            assert not pubsubs_fsub[0]._is_msg_seen(msg_0)
+            await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_0)
+            assert pubsubs_fsub[0]._is_msg_seen(msg_0)
+            # Test: Ensure `router.publish` is called in `push_msg`
+            with trio.fail_after(0.1):
+                await event.wait()
+
+        with mock_router_publish() as event:
+            # Test: `push_msg` the message again and it will be reject.
+            # `router_publish` is not called then.
+            await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_0)
+            await trio.sleep(0.01)
+            assert not event.is_set()
+
+            sub = await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
+            # Test: `push_msg` succeeds with another unseen msg.
+            msg_1 = make_pubsub_msg(
+                origin_id=pubsubs_fsub[0].my_id,
+                topic_ids=[TESTING_TOPIC],
+                data=TESTING_DATA,
+                seqno=b"\x11" * 8,
+            )
+            assert not pubsubs_fsub[0]._is_msg_seen(msg_1)
+            await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_1)
+            assert pubsubs_fsub[0]._is_msg_seen(msg_1)
+            with trio.fail_after(0.1):
+                await event.wait()
+            # Test: Subscribers are notified when `push_msg` new messages.
+            assert (await sub.receive()) == msg_1
+
+        with mock_router_publish() as event:
+            # Test: add a topic validator and `push_msg` the message that
+            # does not pass the validation.
+            # `router_publish` is not called then.
+            def failed_sync_validator(peer_id, msg):
+                return False
+
+            pubsubs_fsub[0].set_topic_validator(
+                TESTING_TOPIC, failed_sync_validator, False
+            )
+
+            msg_2 = make_pubsub_msg(
+                origin_id=pubsubs_fsub[0].my_id,
+                topic_ids=[TESTING_TOPIC],
+                data=TESTING_DATA,
+                seqno=b"\x22" * 8,
+            )
+
+            await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_2)
+            await trio.sleep(0.01)
+            assert not event.is_set()
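The timeout handling above replaces `asyncio.wait_for` with `trio.fail_after`, and the tests now run under pytest-trio's `@pytest.mark.trio` marker and `nursery` fixture. A minimal sketch of that pattern on its own, assuming pytest-trio is installed as these tests require:

import pytest
import trio

@pytest.mark.trio
async def test_event_is_set_before_timeout(nursery):
    event = trio.Event()

    async def setter() -> None:
        await trio.sleep(0.01)
        event.set()

    nursery.start_soon(setter)
    # `trio.fail_after` raises `trio.TooSlowError` if the block overruns,
    # which is what the rewritten tests assert with `pytest.raises`.
    with trio.fail_after(1):
        await event.wait()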