Fix Pubsub
This commit is contained in:
parent
bdbb7b2394
commit
e9ab0646e3
|
@ -1,6 +1,8 @@
|
|||
import logging
|
||||
from typing import Iterable, List, Sequence
|
||||
|
||||
import trio
|
||||
|
||||
from libp2p.network.stream.exceptions import StreamClosed
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.typing import TProtocol
|
||||
|
@ -61,6 +63,8 @@ class FloodSub(IPubsubRouter):
|
|||
|
||||
:param rpc: rpc message
|
||||
"""
|
||||
# Checkpoint
|
||||
await trio.sleep(0)
|
||||
|
||||
async def publish(self, msg_forwarder: ID, pubsub_msg: rpc_pb2.Message) -> None:
|
||||
"""
|
||||
|
@ -102,6 +106,8 @@ class FloodSub(IPubsubRouter):
|
|||
|
||||
:param topic: topic to join
|
||||
"""
|
||||
# Checkpoint
|
||||
await trio.sleep(0)
|
||||
|
||||
async def leave(self, topic: str) -> None:
|
||||
"""
|
||||
|
@ -110,6 +116,8 @@ class FloodSub(IPubsubRouter):
|
|||
|
||||
:param topic: topic to leave
|
||||
"""
|
||||
# Checkpoint
|
||||
await trio.sleep(0)
|
||||
|
||||
def _get_peers_to_send(
|
||||
self, topic_ids: Iterable[str], msg_forwarder: ID, origin: ID
|
||||
|
|
|
@ -1,11 +1,13 @@
|
|||
import asyncio
|
||||
from abc import ABC, abstractmethod
|
||||
import logging
|
||||
import math
|
||||
import time
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
KeysView,
|
||||
List,
|
||||
NamedTuple,
|
||||
Tuple,
|
||||
|
@ -13,8 +15,10 @@ from typing import (
|
|||
cast,
|
||||
)
|
||||
|
||||
from async_service import Service
|
||||
import base58
|
||||
from lru import LRU
|
||||
import trio
|
||||
|
||||
from libp2p.exceptions import ParseError, ValidationError
|
||||
from libp2p.host.host_interface import IHost
|
||||
|
@ -53,24 +57,24 @@ class TopicValidator(NamedTuple):
|
|||
is_async: bool
|
||||
|
||||
|
||||
class Pubsub:
|
||||
class BasePubsub(ABC):
|
||||
pass
|
||||
|
||||
|
||||
class Pubsub(BasePubsub, Service):
|
||||
|
||||
host: IHost
|
||||
my_id: ID
|
||||
|
||||
router: "IPubsubRouter"
|
||||
|
||||
peer_queue: "asyncio.Queue[ID]"
|
||||
dead_peer_queue: "asyncio.Queue[ID]"
|
||||
|
||||
protocols: List[TProtocol]
|
||||
|
||||
incoming_msgs_from_peers: "asyncio.Queue[rpc_pb2.Message]"
|
||||
outgoing_messages: "asyncio.Queue[rpc_pb2.Message]"
|
||||
peer_receive_channel: "trio.MemoryReceiveChannel[ID]"
|
||||
dead_peer_receive_channel: "trio.MemoryReceiveChannel[ID]"
|
||||
|
||||
seen_messages: LRU
|
||||
|
||||
my_topics: Dict[str, "asyncio.Queue[rpc_pb2.Message]"]
|
||||
# TODO: Implement `trio.abc.Channel`?
|
||||
subscribed_topics_send: Dict[str, "trio.MemorySendChannel[rpc_pb2.Message]"]
|
||||
subscribed_topics_receive: Dict[str, "trio.MemoryReceiveChannel[rpc_pb2.Message]"]
|
||||
|
||||
peer_topics: Dict[str, List[ID]]
|
||||
peers: Dict[ID, INetStream]
|
||||
|
@ -80,10 +84,8 @@ class Pubsub:
|
|||
# TODO: Be sure it is increased atomically everytime.
|
||||
counter: int # uint64
|
||||
|
||||
_tasks: List["asyncio.Future[Any]"]
|
||||
|
||||
def __init__(
|
||||
self, host: IHost, router: "IPubsubRouter", my_id: ID, cache_size: int = None
|
||||
self, host: IHost, router: "IPubsubRouter", cache_size: int = None
|
||||
) -> None:
|
||||
"""
|
||||
Construct a new Pubsub object, which is responsible for handling all
|
||||
|
@ -97,28 +99,26 @@ class Pubsub:
|
|||
"""
|
||||
self.host = host
|
||||
self.router = router
|
||||
self.my_id = my_id
|
||||
|
||||
# Attach this new Pubsub object to the router
|
||||
self.router.attach(self)
|
||||
|
||||
peer_send_channel, peer_receive_channel = trio.open_memory_channel(0)
|
||||
dead_peer_send_channel, dead_peer_receive_channel = trio.open_memory_channel(0)
|
||||
# Only keep the receive channels in `Pubsub`.
|
||||
# Therefore, we can only close from the receive side.
|
||||
self.peer_receive_channel = peer_receive_channel
|
||||
self.dead_peer_receive_channel = dead_peer_receive_channel
|
||||
# Register a notifee
|
||||
self.peer_queue = asyncio.Queue()
|
||||
self.dead_peer_queue = asyncio.Queue()
|
||||
self.host.get_network().register_notifee(
|
||||
PubsubNotifee(self.peer_queue, self.dead_peer_queue)
|
||||
PubsubNotifee(peer_send_channel, dead_peer_send_channel)
|
||||
)
|
||||
|
||||
# Register stream handlers for each pubsub router protocol to handle
|
||||
# the pubsub streams opened on those protocols
|
||||
self.protocols = self.router.get_protocols()
|
||||
for protocol in self.protocols:
|
||||
for protocol in router.protocols:
|
||||
self.host.set_stream_handler(protocol, self.stream_handler)
|
||||
|
||||
# Use asyncio queues for proper context switching
|
||||
self.incoming_msgs_from_peers = asyncio.Queue()
|
||||
self.outgoing_messages = asyncio.Queue()
|
||||
|
||||
# keeps track of seen messages as LRU cache
|
||||
if cache_size is None:
|
||||
self.cache_size = 128
|
||||
|
@ -129,7 +129,8 @@ class Pubsub:
|
|||
|
||||
# Map of topics we are subscribed to blocking queues
|
||||
# for when the given topic receives a message
|
||||
self.my_topics = {}
|
||||
self.subscribed_topics_send = {}
|
||||
self.subscribed_topics_receive = {}
|
||||
|
||||
# Map of topic to peers to keep track of what peers are subscribed to
|
||||
self.peer_topics = {}
|
||||
|
@ -142,16 +143,28 @@ class Pubsub:
|
|||
|
||||
self.counter = time.time_ns()
|
||||
|
||||
self._tasks = []
|
||||
# Call handle peer to keep waiting for updates to peer queue
|
||||
self._tasks.append(asyncio.ensure_future(self.handle_peer_queue()))
|
||||
self._tasks.append(asyncio.ensure_future(self.handle_dead_peer_queue()))
|
||||
async def run(self) -> None:
|
||||
self.manager.run_daemon_task(self.handle_peer_queue)
|
||||
self.manager.run_daemon_task(self.handle_dead_peer_queue)
|
||||
await self.manager.wait_finished()
|
||||
|
||||
@property
|
||||
def my_id(self) -> ID:
|
||||
return self.host.get_id()
|
||||
|
||||
@property
|
||||
def protocols(self) -> Tuple[TProtocol, ...]:
|
||||
return tuple(self.router.get_protocols())
|
||||
|
||||
@property
|
||||
def topic_ids(self) -> KeysView[str]:
|
||||
return self.subscribed_topics_receive.keys()
|
||||
|
||||
def get_hello_packet(self) -> rpc_pb2.RPC:
|
||||
"""Generate subscription message with all topics we are subscribed to
|
||||
only send hello packet if we have subscribed topics."""
|
||||
packet = rpc_pb2.RPC()
|
||||
for topic_id in self.my_topics:
|
||||
for topic_id in self.topic_ids:
|
||||
packet.subscriptions.extend(
|
||||
[rpc_pb2.RPC.SubOpts(subscribe=True, topicid=topic_id)]
|
||||
)
|
||||
|
@ -166,7 +179,7 @@ class Pubsub:
|
|||
"""
|
||||
peer_id = stream.muxed_conn.peer_id
|
||||
|
||||
while True:
|
||||
while self.manager.is_running:
|
||||
incoming: bytes = await read_varint_prefixed_bytes(stream)
|
||||
rpc_incoming: rpc_pb2.RPC = rpc_pb2.RPC()
|
||||
rpc_incoming.ParseFromString(incoming)
|
||||
|
@ -178,11 +191,7 @@ class Pubsub:
|
|||
logger.debug(
|
||||
"received `publish` message %s from peer %s", msg, peer_id
|
||||
)
|
||||
self._tasks.append(
|
||||
asyncio.ensure_future(
|
||||
self.push_msg(msg_forwarder=peer_id, msg=msg)
|
||||
)
|
||||
)
|
||||
self.manager.run_task(self.push_msg, peer_id, msg)
|
||||
|
||||
if rpc_incoming.subscriptions:
|
||||
# deal with RPC.subscriptions
|
||||
|
@ -210,9 +219,6 @@ class Pubsub:
|
|||
)
|
||||
await self.router.handle_rpc(rpc_incoming, peer_id)
|
||||
|
||||
# Force context switch
|
||||
await asyncio.sleep(0)
|
||||
|
||||
def set_topic_validator(
|
||||
self, topic: str, validator: ValidatorFn, is_async_validator: bool
|
||||
) -> None:
|
||||
|
@ -285,7 +291,6 @@ class Pubsub:
|
|||
logger.debug("Fail to add new peer %s: stream closed", peer_id)
|
||||
del self.peers[peer_id]
|
||||
return
|
||||
# TODO: Check EOF of this stream.
|
||||
# TODO: Check if the peer in black list.
|
||||
try:
|
||||
self.router.add_peer(peer_id, stream.get_protocol())
|
||||
|
@ -311,21 +316,23 @@ class Pubsub:
|
|||
|
||||
async def handle_peer_queue(self) -> None:
|
||||
"""
|
||||
Continuously read from peer queue and each time a new peer is found,
|
||||
Continuously read from peer channel and each time a new peer is found,
|
||||
open a stream to the peer using a supported pubsub protocol
|
||||
TODO: Handle failure for when the peer does not support any of the
|
||||
pubsub protocols we support
|
||||
"""
|
||||
while True:
|
||||
peer_id: ID = await self.peer_queue.get()
|
||||
async with self.peer_receive_channel:
|
||||
while self.manager.is_running:
|
||||
peer_id: ID = await self.peer_receive_channel.receive()
|
||||
# Add Peer
|
||||
self._tasks.append(asyncio.ensure_future(self._handle_new_peer(peer_id)))
|
||||
self.manager.run_task(self._handle_new_peer, peer_id)
|
||||
|
||||
async def handle_dead_peer_queue(self) -> None:
|
||||
"""Continuously read from dead peer queue and close the stream between
|
||||
"""Continuously read from dead peer channel and close the stream between
|
||||
that peer and remove peer info from pubsub and pubsub router."""
|
||||
while True:
|
||||
peer_id: ID = await self.dead_peer_queue.get()
|
||||
async with self.dead_peer_receive_channel:
|
||||
while self.manager.is_running:
|
||||
peer_id: ID = await self.dead_peer_receive_channel.receive()
|
||||
# Remove Peer
|
||||
self._handle_dead_peer(peer_id)
|
||||
|
||||
|
@ -361,13 +368,16 @@ class Pubsub:
|
|||
|
||||
# Check if this message has any topics that we are subscribed to
|
||||
for topic in publish_message.topicIDs:
|
||||
if topic in self.my_topics:
|
||||
if topic in self.topic_ids:
|
||||
# we are subscribed to a topic this message was sent for,
|
||||
# so add message to the subscription output queue
|
||||
# for each topic
|
||||
await self.my_topics[topic].put(publish_message)
|
||||
await self.subscribed_topics_send[topic].send(publish_message)
|
||||
|
||||
async def subscribe(self, topic_id: str) -> "asyncio.Queue[rpc_pb2.Message]":
|
||||
# TODO: Change to return an `AsyncIterable` to be I/O-agnostic?
|
||||
async def subscribe(
|
||||
self, topic_id: str
|
||||
) -> "trio.MemoryReceiveChannel[rpc_pb2.Message]":
|
||||
"""
|
||||
Subscribe ourself to a topic.
|
||||
|
||||
|
@ -377,11 +387,13 @@ class Pubsub:
|
|||
logger.debug("subscribing to topic %s", topic_id)
|
||||
|
||||
# Already subscribed
|
||||
if topic_id in self.my_topics:
|
||||
return self.my_topics[topic_id]
|
||||
if topic_id in self.topic_ids:
|
||||
return self.subscribed_topics_receive[topic_id]
|
||||
|
||||
# Map topic_id to blocking queue
|
||||
self.my_topics[topic_id] = asyncio.Queue()
|
||||
# Map topic_id to a blocking channel
|
||||
send_channel, receive_channel = trio.open_memory_channel(math.inf)
|
||||
self.subscribed_topics_send[topic_id] = send_channel
|
||||
self.subscribed_topics_receive[topic_id] = receive_channel
|
||||
|
||||
# Create subscribe message
|
||||
packet: rpc_pb2.RPC = rpc_pb2.RPC()
|
||||
|
@ -395,8 +407,8 @@ class Pubsub:
|
|||
# Tell router we are joining this topic
|
||||
await self.router.join(topic_id)
|
||||
|
||||
# Return the asyncio queue for messages on this topic
|
||||
return self.my_topics[topic_id]
|
||||
# Return the trio channel for messages on this topic
|
||||
return receive_channel
|
||||
|
||||
async def unsubscribe(self, topic_id: str) -> None:
|
||||
"""
|
||||
|
@ -408,10 +420,14 @@ class Pubsub:
|
|||
logger.debug("unsubscribing from topic %s", topic_id)
|
||||
|
||||
# Return if we already unsubscribed from the topic
|
||||
if topic_id not in self.my_topics:
|
||||
if topic_id not in self.topic_ids:
|
||||
return
|
||||
# Remove topic_id from map if present
|
||||
del self.my_topics[topic_id]
|
||||
# Remove topic_id from the maps before yielding
|
||||
send_channel = self.subscribed_topics_send[topic_id]
|
||||
del self.subscribed_topics_send[topic_id]
|
||||
del self.subscribed_topics_receive[topic_id]
|
||||
# Only close the send side
|
||||
await send_channel.aclose()
|
||||
|
||||
# Create unsubscribe message
|
||||
packet: rpc_pb2.RPC = rpc_pb2.RPC()
|
||||
|
@ -453,13 +469,13 @@ class Pubsub:
|
|||
data=data,
|
||||
topicIDs=[topic_id],
|
||||
# Origin is ourself.
|
||||
from_id=self.host.get_id().to_bytes(),
|
||||
from_id=self.my_id.to_bytes(),
|
||||
seqno=self._next_seqno(),
|
||||
)
|
||||
|
||||
# TODO: Sign with our signing key
|
||||
|
||||
await self.push_msg(self.host.get_id(), msg)
|
||||
await self.push_msg(self.my_id, msg)
|
||||
|
||||
logger.debug("successfully published message %s", msg)
|
||||
|
||||
|
@ -470,12 +486,12 @@ class Pubsub:
|
|||
:param msg_forwarder: the peer who forward us the message.
|
||||
:param msg: the message.
|
||||
"""
|
||||
sync_topic_validators = []
|
||||
async_topic_validator_futures: List[Awaitable[bool]] = []
|
||||
sync_topic_validators: List[SyncValidatorFn] = []
|
||||
async_topic_validators: List[AsyncValidatorFn] = []
|
||||
for topic_validator in self.get_msg_validators(msg):
|
||||
if topic_validator.is_async:
|
||||
async_topic_validator_futures.append(
|
||||
cast(Awaitable[bool], topic_validator.validator(msg_forwarder, msg))
|
||||
async_topic_validators.append(
|
||||
cast(AsyncValidatorFn, topic_validator.validator)
|
||||
)
|
||||
else:
|
||||
sync_topic_validators.append(
|
||||
|
@ -488,9 +504,20 @@ class Pubsub:
|
|||
|
||||
# TODO: Implement throttle on async validators
|
||||
|
||||
if len(async_topic_validator_futures) > 0:
|
||||
results = await asyncio.gather(*async_topic_validator_futures)
|
||||
if not all(results):
|
||||
if len(async_topic_validators) > 0:
|
||||
# TODO: Use a better pattern
|
||||
final_result = True
|
||||
|
||||
async def run_async_validator(func: AsyncValidatorFn) -> None:
|
||||
nonlocal final_result
|
||||
result = await func(msg_forwarder, msg)
|
||||
final_result = final_result and result
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
for validator in async_topic_validators:
|
||||
nursery.start_soon(run_async_validator, validator)
|
||||
|
||||
if not final_result:
|
||||
raise ValidationError(f"Validation failed for msg={msg}")
|
||||
|
||||
async def push_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None:
|
||||
|
@ -551,14 +578,4 @@ class Pubsub:
|
|||
self.seen_messages[msg_id] = 1
|
||||
|
||||
def _is_subscribed_to_msg(self, msg: rpc_pb2.Message) -> bool:
|
||||
if not self.my_topics:
|
||||
return False
|
||||
return any(topic in self.my_topics for topic in msg.topicIDs)
|
||||
|
||||
async def close(self) -> None:
|
||||
for task in self._tasks:
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
return any(topic in self.topic_ids for topic in msg.topicIDs)
|
||||
|
|
|
@ -8,19 +8,19 @@ from libp2p.network.notifee_interface import INotifee
|
|||
from libp2p.network.stream.net_stream_interface import INetStream
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import asyncio # noqa: F401
|
||||
import trio # noqa: F401
|
||||
from libp2p.peer.id import ID # noqa: F401
|
||||
|
||||
|
||||
class PubsubNotifee(INotifee):
|
||||
|
||||
initiator_peers_queue: "asyncio.Queue[ID]"
|
||||
dead_peers_queue: "asyncio.Queue[ID]"
|
||||
initiator_peers_queue: "trio.MemorySendChannel[ID]"
|
||||
dead_peers_queue: "trio.MemorySendChannel[ID]"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
initiator_peers_queue: "asyncio.Queue[ID]",
|
||||
dead_peers_queue: "asyncio.Queue[ID]",
|
||||
initiator_peers_queue: "trio.MemorySendChannel[ID]",
|
||||
dead_peers_queue: "trio.MemorySendChannel[ID]",
|
||||
) -> None:
|
||||
"""
|
||||
:param initiator_peers_queue: queue to add new peers to so that pubsub
|
||||
|
@ -46,7 +46,12 @@ class PubsubNotifee(INotifee):
|
|||
:param network: network the connection was opened on
|
||||
:param conn: connection that was opened
|
||||
"""
|
||||
await self.initiator_peers_queue.put(conn.muxed_conn.peer_id)
|
||||
try:
|
||||
await self.initiator_peers_queue.send(conn.muxed_conn.peer_id)
|
||||
except trio.BrokenResourceError:
|
||||
# Raised when the receive channel is closed.
|
||||
# TODO: Do something with loggers?
|
||||
...
|
||||
|
||||
async def disconnected(self, network: INetwork, conn: INetConn) -> None:
|
||||
"""
|
||||
|
@ -56,7 +61,7 @@ class PubsubNotifee(INotifee):
|
|||
:param network: network the connection was opened on
|
||||
:param conn: connection that was opened
|
||||
"""
|
||||
await self.dead_peers_queue.put(conn.muxed_conn.peer_id)
|
||||
await self.dead_peers_queue.send(conn.muxed_conn.peer_id)
|
||||
|
||||
async def listen(self, network: INetwork, multiaddr: Multiaddr) -> None:
|
||||
pass
|
||||
|
|
|
@ -281,7 +281,7 @@ class Mplex(IMuxedConn, Service):
|
|||
mplex_stream = await self._initialize_stream(stream_id, message.decode())
|
||||
try:
|
||||
await self.new_stream_send_channel.send(mplex_stream)
|
||||
except (trio.BrokenResourceError, trio.EndOfChannel):
|
||||
except (trio.BrokenResourceError, trio.ClosedResourceError):
|
||||
raise MplexUnavailable
|
||||
|
||||
async def _handle_message(self, stream_id: StreamID, message: bytes) -> None:
|
||||
|
|
|
@ -5,6 +5,7 @@ from async_service import background_trio_service
|
|||
import factory
|
||||
import trio
|
||||
|
||||
from libp2p.tools.constants import GOSSIPSUB_PARAMS
|
||||
from libp2p import generate_new_rsa_identity, generate_peer_id_from
|
||||
from libp2p.crypto.keys import KeyPair
|
||||
from libp2p.host.basic_host import BasicHost
|
||||
|
@ -15,6 +16,7 @@ from libp2p.network.connection.swarm_connection import SwarmConn
|
|||
from libp2p.network.stream.net_stream_interface import INetStream
|
||||
from libp2p.network.swarm import Swarm
|
||||
from libp2p.peer.peerstore import PeerStore
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.pubsub.floodsub import FloodSub
|
||||
from libp2p.pubsub.gossipsub import GossipSub
|
||||
from libp2p.pubsub.pubsub import Pubsub
|
||||
|
@ -28,15 +30,19 @@ from libp2p.transport.typing import TMuxerOptions
|
|||
from libp2p.transport.upgrader import TransportUpgrader
|
||||
from libp2p.typing import TProtocol
|
||||
|
||||
from .constants import (
|
||||
FLOODSUB_PROTOCOL_ID,
|
||||
GOSSIPSUB_PARAMS,
|
||||
GOSSIPSUB_PROTOCOL_ID,
|
||||
LISTEN_MADDR,
|
||||
)
|
||||
from .constants import FLOODSUB_PROTOCOL_ID, GOSSIPSUB_PROTOCOL_ID, LISTEN_MADDR
|
||||
from .utils import connect, connect_swarm
|
||||
|
||||
|
||||
class IDFactory(factory.Factory):
|
||||
class Meta:
|
||||
model = ID
|
||||
|
||||
peer_id_bytes = factory.LazyFunction(
|
||||
lambda: generate_peer_id_from(generate_new_rsa_identity())
|
||||
)
|
||||
|
||||
|
||||
def security_transport_factory(
|
||||
is_secure: bool, key_pair: KeyPair
|
||||
) -> Dict[TProtocol, BaseSecureTransport]:
|
||||
|
@ -181,9 +187,38 @@ class PubsubFactory(factory.Factory):
|
|||
|
||||
host = factory.SubFactory(HostFactory)
|
||||
router = None
|
||||
my_id = factory.LazyAttribute(lambda obj: obj.host.get_id())
|
||||
cache_size = None
|
||||
|
||||
@classmethod
|
||||
@asynccontextmanager
|
||||
async def create_and_start(cls, host, router, cache_size):
|
||||
pubsub = PubsubFactory(host=host, router=router, cache_size=cache_size)
|
||||
async with background_trio_service(pubsub):
|
||||
yield pubsub
|
||||
|
||||
@classmethod
|
||||
@asynccontextmanager
|
||||
async def create_batch_with_floodsub(
|
||||
cls, number: int, is_secure: bool = False, cache_size: int = None
|
||||
):
|
||||
floodsubs = FloodsubFactory.create_batch(number)
|
||||
async with HostFactory.create_batch_and_listen(is_secure, number) as hosts:
|
||||
# Pubsubs should exit before hosts
|
||||
async with AsyncExitStack() as stack:
|
||||
pubsubs = [
|
||||
await stack.enter_async_context(
|
||||
cls.create_and_start(host, router, cache_size)
|
||||
)
|
||||
for host, router in zip(hosts, floodsubs)
|
||||
]
|
||||
yield pubsubs
|
||||
|
||||
# @classmethod
|
||||
# async def create_batch_with_gossipsub(
|
||||
# cls, number: int, cache_size: int = None, gossipsub_params=GOSSIPSUB_PARAMS
|
||||
# ):
|
||||
# ...
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def swarm_pair_factory(
|
||||
|
|
|
@ -4,18 +4,6 @@ from libp2p.tools.constants import GOSSIPSUB_PARAMS
|
|||
from libp2p.tools.factories import FloodsubFactory, GossipsubFactory, PubsubFactory
|
||||
|
||||
|
||||
def _make_pubsubs(hosts, pubsub_routers, cache_size):
|
||||
if len(pubsub_routers) != len(hosts):
|
||||
raise ValueError(
|
||||
f"lenght of pubsub_routers={pubsub_routers} should be equaled to the "
|
||||
f"length of hosts={len(hosts)}"
|
||||
)
|
||||
return tuple(
|
||||
PubsubFactory(host=host, router=router, cache_size=cache_size)
|
||||
for host, router in zip(hosts, pubsub_routers)
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pubsub_cache_size():
|
||||
return None # default
|
||||
|
@ -26,17 +14,9 @@ def gossipsub_params():
|
|||
return GOSSIPSUB_PARAMS
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pubsubs_fsub(num_hosts, hosts, pubsub_cache_size):
|
||||
floodsubs = FloodsubFactory.create_batch(num_hosts)
|
||||
_pubsubs_fsub = _make_pubsubs(hosts, floodsubs, pubsub_cache_size)
|
||||
yield _pubsubs_fsub
|
||||
# TODO: Clean up
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pubsubs_gsub(num_hosts, hosts, pubsub_cache_size, gossipsub_params):
|
||||
gossipsubs = GossipsubFactory.create_batch(num_hosts, **gossipsub_params._asdict())
|
||||
_pubsubs_gsub = _make_pubsubs(hosts, gossipsubs, pubsub_cache_size)
|
||||
yield _pubsubs_gsub
|
||||
# TODO: Clean up
|
||||
# @pytest.fixture
|
||||
# def pubsubs_gsub(num_hosts, hosts, pubsub_cache_size, gossipsub_params):
|
||||
# gossipsubs = GossipsubFactory.create_batch(num_hosts, **gossipsub_params._asdict())
|
||||
# _pubsubs_gsub = _make_pubsubs(hosts, gossipsubs, pubsub_cache_size)
|
||||
# yield _pubsubs_gsub
|
||||
# # TODO: Clean up
|
||||
|
|
|
@ -1,73 +1,78 @@
|
|||
import asyncio
|
||||
from contextlib import contextmanager
|
||||
from typing import NamedTuple
|
||||
|
||||
import pytest
|
||||
import trio
|
||||
|
||||
from libp2p.exceptions import ValidationError
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.pubsub.pb import rpc_pb2
|
||||
from libp2p.tools.pubsub.utils import make_pubsub_msg
|
||||
from libp2p.tools.utils import connect
|
||||
from libp2p.tools.constants import MAX_READ_LEN
|
||||
from libp2p.tools.factories import PubsubFactory, net_stream_pair_factory, IDFactory
|
||||
from libp2p.utils import encode_varint_prefixed
|
||||
|
||||
TESTING_TOPIC = "TEST_SUBSCRIBE"
|
||||
TESTING_DATA = b"data"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_hosts", (1,))
|
||||
@pytest.mark.asyncio
|
||||
async def test_subscribe_and_unsubscribe(pubsubs_fsub):
|
||||
@pytest.mark.trio
|
||||
async def test_subscribe_and_unsubscribe():
|
||||
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
|
||||
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
|
||||
assert TESTING_TOPIC in pubsubs_fsub[0].my_topics
|
||||
assert TESTING_TOPIC in pubsubs_fsub[0].topic_ids
|
||||
|
||||
await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC)
|
||||
assert TESTING_TOPIC not in pubsubs_fsub[0].my_topics
|
||||
assert TESTING_TOPIC not in pubsubs_fsub[0].topic_ids
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_hosts", (1,))
|
||||
@pytest.mark.asyncio
|
||||
async def test_re_subscribe(pubsubs_fsub):
|
||||
@pytest.mark.trio
|
||||
async def test_re_subscribe():
|
||||
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
|
||||
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
|
||||
assert TESTING_TOPIC in pubsubs_fsub[0].my_topics
|
||||
assert TESTING_TOPIC in pubsubs_fsub[0].topic_ids
|
||||
|
||||
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
|
||||
assert TESTING_TOPIC in pubsubs_fsub[0].my_topics
|
||||
assert TESTING_TOPIC in pubsubs_fsub[0].topic_ids
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_hosts", (1,))
|
||||
@pytest.mark.asyncio
|
||||
async def test_re_unsubscribe(pubsubs_fsub):
|
||||
@pytest.mark.trio
|
||||
async def test_re_unsubscribe():
|
||||
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
|
||||
# Unsubscribe from topic we didn't even subscribe to
|
||||
assert "NOT_MY_TOPIC" not in pubsubs_fsub[0].my_topics
|
||||
assert "NOT_MY_TOPIC" not in pubsubs_fsub[0].topic_ids
|
||||
await pubsubs_fsub[0].unsubscribe("NOT_MY_TOPIC")
|
||||
assert "NOT_MY_TOPIC" not in pubsubs_fsub[0].my_topics
|
||||
assert "NOT_MY_TOPIC" not in pubsubs_fsub[0].topic_ids
|
||||
|
||||
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
|
||||
assert TESTING_TOPIC in pubsubs_fsub[0].my_topics
|
||||
assert TESTING_TOPIC in pubsubs_fsub[0].topic_ids
|
||||
|
||||
await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC)
|
||||
assert TESTING_TOPIC not in pubsubs_fsub[0].my_topics
|
||||
assert TESTING_TOPIC not in pubsubs_fsub[0].topic_ids
|
||||
|
||||
await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC)
|
||||
assert TESTING_TOPIC not in pubsubs_fsub[0].my_topics
|
||||
assert TESTING_TOPIC not in pubsubs_fsub[0].topic_ids
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_peers_subscribe(pubsubs_fsub):
|
||||
@pytest.mark.trio
|
||||
async def test_peers_subscribe():
|
||||
async with PubsubFactory.create_batch_with_floodsub(2) as pubsubs_fsub:
|
||||
await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host)
|
||||
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
|
||||
# Yield to let 0 notify 1
|
||||
await asyncio.sleep(0.1)
|
||||
await trio.sleep(0.1)
|
||||
assert pubsubs_fsub[0].my_id in pubsubs_fsub[1].peer_topics[TESTING_TOPIC]
|
||||
await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC)
|
||||
# Yield to let 0 notify 1
|
||||
await asyncio.sleep(0.1)
|
||||
await trio.sleep(0.1)
|
||||
assert pubsubs_fsub[0].my_id not in pubsubs_fsub[1].peer_topics[TESTING_TOPIC]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_hosts", (1,))
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_hello_packet(pubsubs_fsub):
|
||||
@pytest.mark.trio
|
||||
async def test_get_hello_packet():
|
||||
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
|
||||
|
||||
def _get_hello_packet_topic_ids():
|
||||
packet = pubsubs_fsub[0].get_hello_packet()
|
||||
return tuple(sub.topicid for sub in packet.subscriptions)
|
||||
|
@ -77,16 +82,16 @@ async def test_get_hello_packet(pubsubs_fsub):
|
|||
|
||||
# Test: After subscriptions, topic ids should be in the hello packet.
|
||||
topic_ids = ["t", "o", "p", "i", "c"]
|
||||
await asyncio.gather(*[pubsubs_fsub[0].subscribe(topic) for topic in topic_ids])
|
||||
for topic in topic_ids:
|
||||
await pubsubs_fsub[0].subscribe(topic)
|
||||
topic_ids_in_hello = _get_hello_packet_topic_ids()
|
||||
for topic in topic_ids:
|
||||
assert topic in topic_ids_in_hello
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_hosts", (1,))
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_and_remove_topic_validator(pubsubs_fsub):
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_set_and_remove_topic_validator():
|
||||
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
|
||||
is_sync_validator_called = False
|
||||
|
||||
def sync_validator(peer_id, msg):
|
||||
|
@ -111,7 +116,7 @@ async def test_set_and_remove_topic_validator(pubsubs_fsub):
|
|||
assert not topic_validator.is_async
|
||||
|
||||
# Validate with sync validator
|
||||
topic_validator.validator(peer_id=ID(b"peer"), msg="msg")
|
||||
topic_validator.validator(peer_id=IDFactory(), msg="msg")
|
||||
|
||||
assert is_sync_validator_called
|
||||
assert not is_async_validator_called
|
||||
|
@ -125,7 +130,7 @@ async def test_set_and_remove_topic_validator(pubsubs_fsub):
|
|||
assert topic_validator.is_async
|
||||
|
||||
# Validate with async validator
|
||||
await topic_validator.validator(peer_id=ID(b"peer"), msg="msg")
|
||||
await topic_validator.validator(peer_id=IDFactory(), msg="msg")
|
||||
|
||||
assert is_async_validator_called
|
||||
assert not is_sync_validator_called
|
||||
|
@ -135,10 +140,9 @@ async def test_set_and_remove_topic_validator(pubsubs_fsub):
|
|||
assert topic not in pubsubs_fsub[0].topic_validators
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_hosts", (1,))
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_msg_validators(pubsubs_fsub):
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_get_msg_validators():
|
||||
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
|
||||
times_sync_validator_called = 0
|
||||
|
||||
def sync_validator(peer_id, msg):
|
||||
|
@ -172,21 +176,22 @@ async def test_get_msg_validators(pubsubs_fsub):
|
|||
topic_validators = pubsubs_fsub[0].get_msg_validators(msg)
|
||||
for topic_validator in topic_validators:
|
||||
if topic_validator.is_async:
|
||||
await topic_validator.validator(peer_id=ID(b"peer"), msg="msg")
|
||||
await topic_validator.validator(peer_id=IDFactory(), msg="msg")
|
||||
else:
|
||||
topic_validator.validator(peer_id=ID(b"peer"), msg="msg")
|
||||
topic_validator.validator(peer_id=IDFactory(), msg="msg")
|
||||
|
||||
assert times_sync_validator_called == 2
|
||||
assert times_async_validator_called == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_hosts", (1,))
|
||||
@pytest.mark.parametrize(
|
||||
"is_topic_1_val_passed, is_topic_2_val_passed",
|
||||
((False, True), (True, False), (True, True)),
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_msg(pubsubs_fsub, is_topic_1_val_passed, is_topic_2_val_passed):
|
||||
@pytest.mark.trio
|
||||
async def test_validate_msg(is_topic_1_val_passed, is_topic_2_val_passed):
|
||||
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
|
||||
|
||||
def passed_sync_validator(peer_id, msg):
|
||||
return True
|
||||
|
||||
|
@ -226,123 +231,98 @@ async def test_validate_msg(pubsubs_fsub, is_topic_1_val_passed, is_topic_2_val_
|
|||
await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg)
|
||||
|
||||
|
||||
class FakeNetStream:
|
||||
_queue: asyncio.Queue
|
||||
@pytest.mark.trio
|
||||
async def test_continuously_read_stream(monkeypatch, nursery, is_host_secure):
|
||||
async def wait_for_event_occurring(event):
|
||||
with trio.fail_after(0.1):
|
||||
await event.wait()
|
||||
|
||||
class FakeMplexConn(NamedTuple):
|
||||
peer_id: ID = ID(b"\x12\x20" + b"\x00" * 32)
|
||||
class Events(NamedTuple):
|
||||
push_msg: trio.Event
|
||||
handle_subscription: trio.Event
|
||||
handle_rpc: trio.Event
|
||||
|
||||
muxed_conn = FakeMplexConn()
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._queue = asyncio.Queue()
|
||||
|
||||
async def read(self, n: int = -1) -> bytes:
|
||||
buf = bytearray()
|
||||
# Force to blocking wait if no data available now.
|
||||
if self._queue.empty():
|
||||
first_byte = await self._queue.get()
|
||||
buf.extend(first_byte)
|
||||
# If `n == -1`, read until no data is in the buffer(_queue).
|
||||
# Else, read until no data is in the buffer(_queue) or we have read `n` bytes.
|
||||
while (n == -1) or (len(buf) < n):
|
||||
if self._queue.empty():
|
||||
break
|
||||
buf.extend(await self._queue.get())
|
||||
return bytes(buf)
|
||||
|
||||
async def write(self, data: bytes) -> int:
|
||||
for i in data:
|
||||
await self._queue.put(i.to_bytes(1, "big"))
|
||||
return len(data)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_hosts", (1,))
|
||||
@pytest.mark.asyncio
|
||||
async def test_continuously_read_stream(pubsubs_fsub, monkeypatch):
|
||||
stream = FakeNetStream()
|
||||
|
||||
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
|
||||
|
||||
event_push_msg = asyncio.Event()
|
||||
event_handle_subscription = asyncio.Event()
|
||||
event_handle_rpc = asyncio.Event()
|
||||
@contextmanager
|
||||
def mock_methods():
|
||||
event_push_msg = trio.Event()
|
||||
event_handle_subscription = trio.Event()
|
||||
event_handle_rpc = trio.Event()
|
||||
|
||||
async def mock_push_msg(msg_forwarder, msg):
|
||||
event_push_msg.set()
|
||||
await trio.sleep(0)
|
||||
|
||||
def mock_handle_subscription(origin_id, sub_message):
|
||||
event_handle_subscription.set()
|
||||
|
||||
async def mock_handle_rpc(rpc, sender_peer_id):
|
||||
event_handle_rpc.set()
|
||||
await trio.sleep(0)
|
||||
|
||||
monkeypatch.setattr(pubsubs_fsub[0], "push_msg", mock_push_msg)
|
||||
monkeypatch.setattr(
|
||||
pubsubs_fsub[0], "handle_subscription", mock_handle_subscription
|
||||
)
|
||||
monkeypatch.setattr(pubsubs_fsub[0].router, "handle_rpc", mock_handle_rpc)
|
||||
|
||||
async def wait_for_event_occurring(event):
|
||||
try:
|
||||
await asyncio.wait_for(event.wait(), timeout=1)
|
||||
except asyncio.TimeoutError as error:
|
||||
event.clear()
|
||||
raise asyncio.TimeoutError(
|
||||
f"Event {event} is not set before the timeout. "
|
||||
"This indicates the mocked functions are not called properly."
|
||||
) from error
|
||||
else:
|
||||
event.clear()
|
||||
with monkeypatch.context() as m:
|
||||
m.setattr(pubsubs_fsub[0], "push_msg", mock_push_msg)
|
||||
m.setattr(pubsubs_fsub[0], "handle_subscription", mock_handle_subscription)
|
||||
m.setattr(pubsubs_fsub[0].router, "handle_rpc", mock_handle_rpc)
|
||||
yield Events(event_push_msg, event_handle_subscription, event_handle_rpc)
|
||||
|
||||
async with PubsubFactory.create_batch_with_floodsub(
|
||||
1, is_secure=is_host_secure
|
||||
) as pubsubs_fsub, net_stream_pair_factory(is_secure=is_host_secure) as stream_pair:
|
||||
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
|
||||
# Kick off the task `continuously_read_stream`
|
||||
task = asyncio.ensure_future(pubsubs_fsub[0].continuously_read_stream(stream))
|
||||
nursery.start_soon(pubsubs_fsub[0].continuously_read_stream, stream_pair[0])
|
||||
|
||||
# Test: `push_msg` is called when publishing to a subscribed topic.
|
||||
publish_subscribed_topic = rpc_pb2.RPC(
|
||||
publish=[rpc_pb2.Message(topicIDs=[TESTING_TOPIC])]
|
||||
)
|
||||
await stream.write(
|
||||
with mock_methods() as events:
|
||||
await stream_pair[1].write(
|
||||
encode_varint_prefixed(publish_subscribed_topic.SerializeToString())
|
||||
)
|
||||
await wait_for_event_occurring(event_push_msg)
|
||||
await wait_for_event_occurring(events.push_msg)
|
||||
# Make sure the other events are not emitted.
|
||||
with pytest.raises(asyncio.TimeoutError):
|
||||
await wait_for_event_occurring(event_handle_subscription)
|
||||
with pytest.raises(asyncio.TimeoutError):
|
||||
await wait_for_event_occurring(event_handle_rpc)
|
||||
with pytest.raises(trio.TooSlowError):
|
||||
await wait_for_event_occurring(events.handle_subscription)
|
||||
with pytest.raises(trio.TooSlowError):
|
||||
await wait_for_event_occurring(events.handle_rpc)
|
||||
|
||||
# Test: `push_msg` is not called when publishing to a topic-not-subscribed.
|
||||
publish_not_subscribed_topic = rpc_pb2.RPC(
|
||||
publish=[rpc_pb2.Message(topicIDs=["NOT_SUBSCRIBED"])]
|
||||
)
|
||||
await stream.write(
|
||||
with mock_methods() as events:
|
||||
await stream_pair[1].write(
|
||||
encode_varint_prefixed(publish_not_subscribed_topic.SerializeToString())
|
||||
)
|
||||
with pytest.raises(asyncio.TimeoutError):
|
||||
await wait_for_event_occurring(event_push_msg)
|
||||
with pytest.raises(trio.TooSlowError):
|
||||
await wait_for_event_occurring(events.push_msg)
|
||||
|
||||
# Test: `handle_subscription` is called when a subscription message is received.
|
||||
subscription_msg = rpc_pb2.RPC(subscriptions=[rpc_pb2.RPC.SubOpts()])
|
||||
await stream.write(encode_varint_prefixed(subscription_msg.SerializeToString()))
|
||||
await wait_for_event_occurring(event_handle_subscription)
|
||||
with mock_methods() as events:
|
||||
await stream_pair[1].write(
|
||||
encode_varint_prefixed(subscription_msg.SerializeToString())
|
||||
)
|
||||
await wait_for_event_occurring(events.handle_subscription)
|
||||
# Make sure the other events are not emitted.
|
||||
with pytest.raises(asyncio.TimeoutError):
|
||||
await wait_for_event_occurring(event_push_msg)
|
||||
with pytest.raises(asyncio.TimeoutError):
|
||||
await wait_for_event_occurring(event_handle_rpc)
|
||||
with pytest.raises(trio.TooSlowError):
|
||||
await wait_for_event_occurring(events.push_msg)
|
||||
with pytest.raises(trio.TooSlowError):
|
||||
await wait_for_event_occurring(events.handle_rpc)
|
||||
|
||||
# Test: `handle_rpc` is called when a control message is received.
|
||||
control_msg = rpc_pb2.RPC(control=rpc_pb2.ControlMessage())
|
||||
await stream.write(encode_varint_prefixed(control_msg.SerializeToString()))
|
||||
await wait_for_event_occurring(event_handle_rpc)
|
||||
with mock_methods() as events:
|
||||
await stream_pair[1].write(
|
||||
encode_varint_prefixed(control_msg.SerializeToString())
|
||||
)
|
||||
await wait_for_event_occurring(events.handle_rpc)
|
||||
# Make sure the other events are not emitted.
|
||||
with pytest.raises(asyncio.TimeoutError):
|
||||
await wait_for_event_occurring(event_push_msg)
|
||||
with pytest.raises(asyncio.TimeoutError):
|
||||
await wait_for_event_occurring(event_handle_subscription)
|
||||
|
||||
task.cancel()
|
||||
with pytest.raises(trio.TooSlowError):
|
||||
await wait_for_event_occurring(events.push_msg)
|
||||
with pytest.raises(trio.TooSlowError):
|
||||
await wait_for_event_occurring(events.handle_subscription)
|
||||
|
||||
|
||||
# TODO: Add the following tests after they are aligned with Go.
|
||||
|
@ -351,11 +331,12 @@ async def test_continuously_read_stream(pubsubs_fsub, monkeypatch):
|
|||
# - `test_handle_peer_queue`
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_hosts", (1,))
|
||||
def test_handle_subscription(pubsubs_fsub):
|
||||
@pytest.mark.trio
|
||||
async def test_handle_subscription():
|
||||
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
|
||||
assert len(pubsubs_fsub[0].peer_topics) == 0
|
||||
sub_msg_0 = rpc_pb2.RPC.SubOpts(subscribe=True, topicid=TESTING_TOPIC)
|
||||
peer_ids = [ID(b"\x12\x20" + i.to_bytes(32, "big")) for i in range(2)]
|
||||
peer_ids = [IDFactory() for _ in range(2)]
|
||||
# Test: One peer is subscribed
|
||||
pubsubs_fsub[0].handle_subscription(peer_ids[0], sub_msg_0)
|
||||
assert (
|
||||
|
@ -382,9 +363,9 @@ def test_handle_subscription(pubsubs_fsub):
|
|||
assert peer_ids[0] not in pubsubs_fsub[0].peer_topics[TESTING_TOPIC]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_hosts", (1,))
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_talk(pubsubs_fsub):
|
||||
@pytest.mark.trio
|
||||
async def test_handle_talk():
|
||||
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
|
||||
sub = await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
|
||||
msg_0 = make_pubsub_msg(
|
||||
origin_id=pubsubs_fsub[0].my_id,
|
||||
|
@ -401,31 +382,33 @@ async def test_handle_talk(pubsubs_fsub):
|
|||
)
|
||||
await pubsubs_fsub[0].handle_talk(msg_1)
|
||||
assert (
|
||||
len(pubsubs_fsub[0].my_topics) == 1
|
||||
and sub == pubsubs_fsub[0].my_topics[TESTING_TOPIC]
|
||||
len(pubsubs_fsub[0].topic_ids) == 1
|
||||
and sub == pubsubs_fsub[0].subscribed_topics_receive[TESTING_TOPIC]
|
||||
)
|
||||
assert sub.qsize() == 1
|
||||
assert (await sub.get()) == msg_0
|
||||
assert (await sub.receive()) == msg_0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_hosts", (1,))
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_all_peers(pubsubs_fsub, monkeypatch):
|
||||
peer_ids = [ID(b"\x12\x20" + i.to_bytes(32, "big")) for i in range(10)]
|
||||
mock_peers = {peer_id: FakeNetStream() for peer_id in peer_ids}
|
||||
monkeypatch.setattr(pubsubs_fsub[0], "peers", mock_peers)
|
||||
@pytest.mark.trio
|
||||
async def test_message_all_peers(monkeypatch, is_host_secure):
|
||||
async with PubsubFactory.create_batch_with_floodsub(
|
||||
1, is_secure=is_host_secure
|
||||
) as pubsubs_fsub, net_stream_pair_factory(is_secure=is_host_secure) as stream_pair:
|
||||
peer_id = IDFactory()
|
||||
mock_peers = {peer_id: stream_pair[0]}
|
||||
with monkeypatch.context() as m:
|
||||
m.setattr(pubsubs_fsub[0], "peers", mock_peers)
|
||||
|
||||
empty_rpc = rpc_pb2.RPC()
|
||||
empty_rpc_bytes = empty_rpc.SerializeToString()
|
||||
empty_rpc_bytes_len_prefixed = encode_varint_prefixed(empty_rpc_bytes)
|
||||
await pubsubs_fsub[0].message_all_peers(empty_rpc_bytes)
|
||||
for stream in mock_peers.values():
|
||||
assert (await stream.read()) == empty_rpc_bytes_len_prefixed
|
||||
assert (
|
||||
await stream_pair[1].read(MAX_READ_LEN)
|
||||
) == empty_rpc_bytes_len_prefixed
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_hosts", (1,))
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish(pubsubs_fsub, monkeypatch):
|
||||
@pytest.mark.trio
|
||||
async def test_publish(monkeypatch):
|
||||
msg_forwarders = []
|
||||
msgs = []
|
||||
|
||||
|
@ -433,21 +416,27 @@ async def test_publish(pubsubs_fsub, monkeypatch):
|
|||
msg_forwarders.append(msg_forwarder)
|
||||
msgs.append(msg)
|
||||
|
||||
monkeypatch.setattr(pubsubs_fsub[0], "push_msg", push_msg)
|
||||
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
|
||||
with monkeypatch.context() as m:
|
||||
m.setattr(pubsubs_fsub[0], "push_msg", push_msg)
|
||||
|
||||
await pubsubs_fsub[0].publish(TESTING_TOPIC, TESTING_DATA)
|
||||
await pubsubs_fsub[0].publish(TESTING_TOPIC, TESTING_DATA)
|
||||
|
||||
assert len(msgs) == 2, "`push_msg` should be called every time `publish` is called"
|
||||
assert (
|
||||
len(msgs) == 2
|
||||
), "`push_msg` should be called every time `publish` is called"
|
||||
assert (msg_forwarders[0] == msg_forwarders[1]) and (
|
||||
msg_forwarders[1] == pubsubs_fsub[0].my_id
|
||||
)
|
||||
assert msgs[0].seqno != msgs[1].seqno, "`seqno` should be different every time"
|
||||
assert (
|
||||
msgs[0].seqno != msgs[1].seqno
|
||||
), "`seqno` should be different every time"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_hosts", (1,))
|
||||
@pytest.mark.asyncio
|
||||
async def test_push_msg(pubsubs_fsub, monkeypatch):
|
||||
@pytest.mark.trio
|
||||
async def test_push_msg(monkeypatch):
|
||||
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
|
||||
msg_0 = make_pubsub_msg(
|
||||
origin_id=pubsubs_fsub[0].my_id,
|
||||
topic_ids=[TESTING_TOPIC],
|
||||
|
@ -455,25 +444,33 @@ async def test_push_msg(pubsubs_fsub, monkeypatch):
|
|||
seqno=b"\x00" * 8,
|
||||
)
|
||||
|
||||
event = asyncio.Event()
|
||||
@contextmanager
|
||||
def mock_router_publish():
|
||||
|
||||
event = trio.Event()
|
||||
|
||||
async def router_publish(*args, **kwargs):
|
||||
event.set()
|
||||
await trio.sleep(0)
|
||||
|
||||
monkeypatch.setattr(pubsubs_fsub[0].router, "publish", router_publish)
|
||||
with monkeypatch.context() as m:
|
||||
m.setattr(pubsubs_fsub[0].router, "publish", router_publish)
|
||||
yield event
|
||||
|
||||
with mock_router_publish() as event:
|
||||
# Test: `msg` is not seen before `push_msg`, and is seen after `push_msg`.
|
||||
assert not pubsubs_fsub[0]._is_msg_seen(msg_0)
|
||||
await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_0)
|
||||
assert pubsubs_fsub[0]._is_msg_seen(msg_0)
|
||||
# Test: Ensure `router.publish` is called in `push_msg`
|
||||
await asyncio.wait_for(event.wait(), timeout=0.1)
|
||||
with trio.fail_after(0.1):
|
||||
await event.wait()
|
||||
|
||||
with mock_router_publish() as event:
|
||||
# Test: `push_msg` the message again and it will be reject.
|
||||
# `router_publish` is not called then.
|
||||
event.clear()
|
||||
await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_0)
|
||||
await asyncio.sleep(0.01)
|
||||
await trio.sleep(0.01)
|
||||
assert not event.is_set()
|
||||
|
||||
sub = await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
|
||||
|
@ -487,17 +484,21 @@ async def test_push_msg(pubsubs_fsub, monkeypatch):
|
|||
assert not pubsubs_fsub[0]._is_msg_seen(msg_1)
|
||||
await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_1)
|
||||
assert pubsubs_fsub[0]._is_msg_seen(msg_1)
|
||||
await asyncio.wait_for(event.wait(), timeout=0.1)
|
||||
with trio.fail_after(0.1):
|
||||
await event.wait()
|
||||
# Test: Subscribers are notified when `push_msg` new messages.
|
||||
assert (await sub.get()) == msg_1
|
||||
assert (await sub.receive()) == msg_1
|
||||
|
||||
with mock_router_publish() as event:
|
||||
# Test: add a topic validator and `push_msg` the message that
|
||||
# does not pass the validation.
|
||||
# `router_publish` is not called then.
|
||||
def failed_sync_validator(peer_id, msg):
|
||||
return False
|
||||
|
||||
pubsubs_fsub[0].set_topic_validator(TESTING_TOPIC, failed_sync_validator, False)
|
||||
pubsubs_fsub[0].set_topic_validator(
|
||||
TESTING_TOPIC, failed_sync_validator, False
|
||||
)
|
||||
|
||||
msg_2 = make_pubsub_msg(
|
||||
origin_id=pubsubs_fsub[0].my_id,
|
||||
|
@ -506,7 +507,6 @@ async def test_push_msg(pubsubs_fsub, monkeypatch):
|
|||
seqno=b"\x22" * 8,
|
||||
)
|
||||
|
||||
event.clear()
|
||||
await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_2)
|
||||
await asyncio.sleep(0.01)
|
||||
await trio.sleep(0.01)
|
||||
assert not event.is_set()
|
||||
|
|
Loading…
Reference in New Issue
Block a user