Fix Pubsub
This commit is contained in:
parent
bdbb7b2394
commit
e9ab0646e3
|
@ -1,6 +1,8 @@
|
||||||
import logging
|
import logging
|
||||||
from typing import Iterable, List, Sequence
|
from typing import Iterable, List, Sequence
|
||||||
|
|
||||||
|
import trio
|
||||||
|
|
||||||
from libp2p.network.stream.exceptions import StreamClosed
|
from libp2p.network.stream.exceptions import StreamClosed
|
||||||
from libp2p.peer.id import ID
|
from libp2p.peer.id import ID
|
||||||
from libp2p.typing import TProtocol
|
from libp2p.typing import TProtocol
|
||||||
|
@ -61,6 +63,8 @@ class FloodSub(IPubsubRouter):
|
||||||
|
|
||||||
:param rpc: rpc message
|
:param rpc: rpc message
|
||||||
"""
|
"""
|
||||||
|
# Checkpoint
|
||||||
|
await trio.sleep(0)
|
||||||
|
|
||||||
async def publish(self, msg_forwarder: ID, pubsub_msg: rpc_pb2.Message) -> None:
|
async def publish(self, msg_forwarder: ID, pubsub_msg: rpc_pb2.Message) -> None:
|
||||||
"""
|
"""
|
||||||
|
@ -102,6 +106,8 @@ class FloodSub(IPubsubRouter):
|
||||||
|
|
||||||
:param topic: topic to join
|
:param topic: topic to join
|
||||||
"""
|
"""
|
||||||
|
# Checkpoint
|
||||||
|
await trio.sleep(0)
|
||||||
|
|
||||||
async def leave(self, topic: str) -> None:
|
async def leave(self, topic: str) -> None:
|
||||||
"""
|
"""
|
||||||
|
@ -110,6 +116,8 @@ class FloodSub(IPubsubRouter):
|
||||||
|
|
||||||
:param topic: topic to leave
|
:param topic: topic to leave
|
||||||
"""
|
"""
|
||||||
|
# Checkpoint
|
||||||
|
await trio.sleep(0)
|
||||||
|
|
||||||
def _get_peers_to_send(
|
def _get_peers_to_send(
|
||||||
self, topic_ids: Iterable[str], msg_forwarder: ID, origin: ID
|
self, topic_ids: Iterable[str], msg_forwarder: ID, origin: ID
|
||||||
|
|
|
@ -1,11 +1,13 @@
|
||||||
import asyncio
|
from abc import ABC, abstractmethod
|
||||||
import logging
|
import logging
|
||||||
|
import math
|
||||||
import time
|
import time
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Awaitable,
|
Awaitable,
|
||||||
Callable,
|
Callable,
|
||||||
Dict,
|
Dict,
|
||||||
|
KeysView,
|
||||||
List,
|
List,
|
||||||
NamedTuple,
|
NamedTuple,
|
||||||
Tuple,
|
Tuple,
|
||||||
|
@ -13,8 +15,10 @@ from typing import (
|
||||||
cast,
|
cast,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from async_service import Service
|
||||||
import base58
|
import base58
|
||||||
from lru import LRU
|
from lru import LRU
|
||||||
|
import trio
|
||||||
|
|
||||||
from libp2p.exceptions import ParseError, ValidationError
|
from libp2p.exceptions import ParseError, ValidationError
|
||||||
from libp2p.host.host_interface import IHost
|
from libp2p.host.host_interface import IHost
|
||||||
|
@ -53,24 +57,24 @@ class TopicValidator(NamedTuple):
|
||||||
is_async: bool
|
is_async: bool
|
||||||
|
|
||||||
|
|
||||||
class Pubsub:
|
class BasePubsub(ABC):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class Pubsub(BasePubsub, Service):
|
||||||
|
|
||||||
host: IHost
|
host: IHost
|
||||||
my_id: ID
|
|
||||||
|
|
||||||
router: "IPubsubRouter"
|
router: "IPubsubRouter"
|
||||||
|
|
||||||
peer_queue: "asyncio.Queue[ID]"
|
peer_receive_channel: "trio.MemoryReceiveChannel[ID]"
|
||||||
dead_peer_queue: "asyncio.Queue[ID]"
|
dead_peer_receive_channel: "trio.MemoryReceiveChannel[ID]"
|
||||||
|
|
||||||
protocols: List[TProtocol]
|
|
||||||
|
|
||||||
incoming_msgs_from_peers: "asyncio.Queue[rpc_pb2.Message]"
|
|
||||||
outgoing_messages: "asyncio.Queue[rpc_pb2.Message]"
|
|
||||||
|
|
||||||
seen_messages: LRU
|
seen_messages: LRU
|
||||||
|
|
||||||
my_topics: Dict[str, "asyncio.Queue[rpc_pb2.Message]"]
|
# TODO: Implement `trio.abc.Channel`?
|
||||||
|
subscribed_topics_send: Dict[str, "trio.MemorySendChannel[rpc_pb2.Message]"]
|
||||||
|
subscribed_topics_receive: Dict[str, "trio.MemoryReceiveChannel[rpc_pb2.Message]"]
|
||||||
|
|
||||||
peer_topics: Dict[str, List[ID]]
|
peer_topics: Dict[str, List[ID]]
|
||||||
peers: Dict[ID, INetStream]
|
peers: Dict[ID, INetStream]
|
||||||
|
@ -80,10 +84,8 @@ class Pubsub:
|
||||||
# TODO: Be sure it is increased atomically everytime.
|
# TODO: Be sure it is increased atomically everytime.
|
||||||
counter: int # uint64
|
counter: int # uint64
|
||||||
|
|
||||||
_tasks: List["asyncio.Future[Any]"]
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, host: IHost, router: "IPubsubRouter", my_id: ID, cache_size: int = None
|
self, host: IHost, router: "IPubsubRouter", cache_size: int = None
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Construct a new Pubsub object, which is responsible for handling all
|
Construct a new Pubsub object, which is responsible for handling all
|
||||||
|
@ -97,28 +99,26 @@ class Pubsub:
|
||||||
"""
|
"""
|
||||||
self.host = host
|
self.host = host
|
||||||
self.router = router
|
self.router = router
|
||||||
self.my_id = my_id
|
|
||||||
|
|
||||||
# Attach this new Pubsub object to the router
|
# Attach this new Pubsub object to the router
|
||||||
self.router.attach(self)
|
self.router.attach(self)
|
||||||
|
|
||||||
|
peer_send_channel, peer_receive_channel = trio.open_memory_channel(0)
|
||||||
|
dead_peer_send_channel, dead_peer_receive_channel = trio.open_memory_channel(0)
|
||||||
|
# Only keep the receive channels in `Pubsub`.
|
||||||
|
# Therefore, we can only close from the receive side.
|
||||||
|
self.peer_receive_channel = peer_receive_channel
|
||||||
|
self.dead_peer_receive_channel = dead_peer_receive_channel
|
||||||
# Register a notifee
|
# Register a notifee
|
||||||
self.peer_queue = asyncio.Queue()
|
|
||||||
self.dead_peer_queue = asyncio.Queue()
|
|
||||||
self.host.get_network().register_notifee(
|
self.host.get_network().register_notifee(
|
||||||
PubsubNotifee(self.peer_queue, self.dead_peer_queue)
|
PubsubNotifee(peer_send_channel, dead_peer_send_channel)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Register stream handlers for each pubsub router protocol to handle
|
# Register stream handlers for each pubsub router protocol to handle
|
||||||
# the pubsub streams opened on those protocols
|
# the pubsub streams opened on those protocols
|
||||||
self.protocols = self.router.get_protocols()
|
for protocol in router.protocols:
|
||||||
for protocol in self.protocols:
|
|
||||||
self.host.set_stream_handler(protocol, self.stream_handler)
|
self.host.set_stream_handler(protocol, self.stream_handler)
|
||||||
|
|
||||||
# Use asyncio queues for proper context switching
|
|
||||||
self.incoming_msgs_from_peers = asyncio.Queue()
|
|
||||||
self.outgoing_messages = asyncio.Queue()
|
|
||||||
|
|
||||||
# keeps track of seen messages as LRU cache
|
# keeps track of seen messages as LRU cache
|
||||||
if cache_size is None:
|
if cache_size is None:
|
||||||
self.cache_size = 128
|
self.cache_size = 128
|
||||||
|
@ -129,7 +129,8 @@ class Pubsub:
|
||||||
|
|
||||||
# Map of topics we are subscribed to blocking queues
|
# Map of topics we are subscribed to blocking queues
|
||||||
# for when the given topic receives a message
|
# for when the given topic receives a message
|
||||||
self.my_topics = {}
|
self.subscribed_topics_send = {}
|
||||||
|
self.subscribed_topics_receive = {}
|
||||||
|
|
||||||
# Map of topic to peers to keep track of what peers are subscribed to
|
# Map of topic to peers to keep track of what peers are subscribed to
|
||||||
self.peer_topics = {}
|
self.peer_topics = {}
|
||||||
|
@ -142,16 +143,28 @@ class Pubsub:
|
||||||
|
|
||||||
self.counter = time.time_ns()
|
self.counter = time.time_ns()
|
||||||
|
|
||||||
self._tasks = []
|
async def run(self) -> None:
|
||||||
# Call handle peer to keep waiting for updates to peer queue
|
self.manager.run_daemon_task(self.handle_peer_queue)
|
||||||
self._tasks.append(asyncio.ensure_future(self.handle_peer_queue()))
|
self.manager.run_daemon_task(self.handle_dead_peer_queue)
|
||||||
self._tasks.append(asyncio.ensure_future(self.handle_dead_peer_queue()))
|
await self.manager.wait_finished()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def my_id(self) -> ID:
|
||||||
|
return self.host.get_id()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def protocols(self) -> Tuple[TProtocol, ...]:
|
||||||
|
return tuple(self.router.get_protocols())
|
||||||
|
|
||||||
|
@property
|
||||||
|
def topic_ids(self) -> KeysView[str]:
|
||||||
|
return self.subscribed_topics_receive.keys()
|
||||||
|
|
||||||
def get_hello_packet(self) -> rpc_pb2.RPC:
|
def get_hello_packet(self) -> rpc_pb2.RPC:
|
||||||
"""Generate subscription message with all topics we are subscribed to
|
"""Generate subscription message with all topics we are subscribed to
|
||||||
only send hello packet if we have subscribed topics."""
|
only send hello packet if we have subscribed topics."""
|
||||||
packet = rpc_pb2.RPC()
|
packet = rpc_pb2.RPC()
|
||||||
for topic_id in self.my_topics:
|
for topic_id in self.topic_ids:
|
||||||
packet.subscriptions.extend(
|
packet.subscriptions.extend(
|
||||||
[rpc_pb2.RPC.SubOpts(subscribe=True, topicid=topic_id)]
|
[rpc_pb2.RPC.SubOpts(subscribe=True, topicid=topic_id)]
|
||||||
)
|
)
|
||||||
|
@ -166,7 +179,7 @@ class Pubsub:
|
||||||
"""
|
"""
|
||||||
peer_id = stream.muxed_conn.peer_id
|
peer_id = stream.muxed_conn.peer_id
|
||||||
|
|
||||||
while True:
|
while self.manager.is_running:
|
||||||
incoming: bytes = await read_varint_prefixed_bytes(stream)
|
incoming: bytes = await read_varint_prefixed_bytes(stream)
|
||||||
rpc_incoming: rpc_pb2.RPC = rpc_pb2.RPC()
|
rpc_incoming: rpc_pb2.RPC = rpc_pb2.RPC()
|
||||||
rpc_incoming.ParseFromString(incoming)
|
rpc_incoming.ParseFromString(incoming)
|
||||||
|
@ -178,11 +191,7 @@ class Pubsub:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"received `publish` message %s from peer %s", msg, peer_id
|
"received `publish` message %s from peer %s", msg, peer_id
|
||||||
)
|
)
|
||||||
self._tasks.append(
|
self.manager.run_task(self.push_msg, peer_id, msg)
|
||||||
asyncio.ensure_future(
|
|
||||||
self.push_msg(msg_forwarder=peer_id, msg=msg)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if rpc_incoming.subscriptions:
|
if rpc_incoming.subscriptions:
|
||||||
# deal with RPC.subscriptions
|
# deal with RPC.subscriptions
|
||||||
|
@ -210,9 +219,6 @@ class Pubsub:
|
||||||
)
|
)
|
||||||
await self.router.handle_rpc(rpc_incoming, peer_id)
|
await self.router.handle_rpc(rpc_incoming, peer_id)
|
||||||
|
|
||||||
# Force context switch
|
|
||||||
await asyncio.sleep(0)
|
|
||||||
|
|
||||||
def set_topic_validator(
|
def set_topic_validator(
|
||||||
self, topic: str, validator: ValidatorFn, is_async_validator: bool
|
self, topic: str, validator: ValidatorFn, is_async_validator: bool
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -285,7 +291,6 @@ class Pubsub:
|
||||||
logger.debug("Fail to add new peer %s: stream closed", peer_id)
|
logger.debug("Fail to add new peer %s: stream closed", peer_id)
|
||||||
del self.peers[peer_id]
|
del self.peers[peer_id]
|
||||||
return
|
return
|
||||||
# TODO: Check EOF of this stream.
|
|
||||||
# TODO: Check if the peer in black list.
|
# TODO: Check if the peer in black list.
|
||||||
try:
|
try:
|
||||||
self.router.add_peer(peer_id, stream.get_protocol())
|
self.router.add_peer(peer_id, stream.get_protocol())
|
||||||
|
@ -311,23 +316,25 @@ class Pubsub:
|
||||||
|
|
||||||
async def handle_peer_queue(self) -> None:
|
async def handle_peer_queue(self) -> None:
|
||||||
"""
|
"""
|
||||||
Continuously read from peer queue and each time a new peer is found,
|
Continuously read from peer channel and each time a new peer is found,
|
||||||
open a stream to the peer using a supported pubsub protocol
|
open a stream to the peer using a supported pubsub protocol
|
||||||
TODO: Handle failure for when the peer does not support any of the
|
TODO: Handle failure for when the peer does not support any of the
|
||||||
pubsub protocols we support
|
pubsub protocols we support
|
||||||
"""
|
"""
|
||||||
while True:
|
async with self.peer_receive_channel:
|
||||||
peer_id: ID = await self.peer_queue.get()
|
while self.manager.is_running:
|
||||||
# Add Peer
|
peer_id: ID = await self.peer_receive_channel.receive()
|
||||||
self._tasks.append(asyncio.ensure_future(self._handle_new_peer(peer_id)))
|
# Add Peer
|
||||||
|
self.manager.run_task(self._handle_new_peer, peer_id)
|
||||||
|
|
||||||
async def handle_dead_peer_queue(self) -> None:
|
async def handle_dead_peer_queue(self) -> None:
|
||||||
"""Continuously read from dead peer queue and close the stream between
|
"""Continuously read from dead peer channel and close the stream between
|
||||||
that peer and remove peer info from pubsub and pubsub router."""
|
that peer and remove peer info from pubsub and pubsub router."""
|
||||||
while True:
|
async with self.dead_peer_receive_channel:
|
||||||
peer_id: ID = await self.dead_peer_queue.get()
|
while self.manager.is_running:
|
||||||
# Remove Peer
|
peer_id: ID = await self.dead_peer_receive_channel.receive()
|
||||||
self._handle_dead_peer(peer_id)
|
# Remove Peer
|
||||||
|
self._handle_dead_peer(peer_id)
|
||||||
|
|
||||||
def handle_subscription(
|
def handle_subscription(
|
||||||
self, origin_id: ID, sub_message: rpc_pb2.RPC.SubOpts
|
self, origin_id: ID, sub_message: rpc_pb2.RPC.SubOpts
|
||||||
|
@ -361,13 +368,16 @@ class Pubsub:
|
||||||
|
|
||||||
# Check if this message has any topics that we are subscribed to
|
# Check if this message has any topics that we are subscribed to
|
||||||
for topic in publish_message.topicIDs:
|
for topic in publish_message.topicIDs:
|
||||||
if topic in self.my_topics:
|
if topic in self.topic_ids:
|
||||||
# we are subscribed to a topic this message was sent for,
|
# we are subscribed to a topic this message was sent for,
|
||||||
# so add message to the subscription output queue
|
# so add message to the subscription output queue
|
||||||
# for each topic
|
# for each topic
|
||||||
await self.my_topics[topic].put(publish_message)
|
await self.subscribed_topics_send[topic].send(publish_message)
|
||||||
|
|
||||||
async def subscribe(self, topic_id: str) -> "asyncio.Queue[rpc_pb2.Message]":
|
# TODO: Change to return an `AsyncIterable` to be I/O-agnostic?
|
||||||
|
async def subscribe(
|
||||||
|
self, topic_id: str
|
||||||
|
) -> "trio.MemoryReceiveChannel[rpc_pb2.Message]":
|
||||||
"""
|
"""
|
||||||
Subscribe ourself to a topic.
|
Subscribe ourself to a topic.
|
||||||
|
|
||||||
|
@ -377,11 +387,13 @@ class Pubsub:
|
||||||
logger.debug("subscribing to topic %s", topic_id)
|
logger.debug("subscribing to topic %s", topic_id)
|
||||||
|
|
||||||
# Already subscribed
|
# Already subscribed
|
||||||
if topic_id in self.my_topics:
|
if topic_id in self.topic_ids:
|
||||||
return self.my_topics[topic_id]
|
return self.subscribed_topics_receive[topic_id]
|
||||||
|
|
||||||
# Map topic_id to blocking queue
|
# Map topic_id to a blocking channel
|
||||||
self.my_topics[topic_id] = asyncio.Queue()
|
send_channel, receive_channel = trio.open_memory_channel(math.inf)
|
||||||
|
self.subscribed_topics_send[topic_id] = send_channel
|
||||||
|
self.subscribed_topics_receive[topic_id] = receive_channel
|
||||||
|
|
||||||
# Create subscribe message
|
# Create subscribe message
|
||||||
packet: rpc_pb2.RPC = rpc_pb2.RPC()
|
packet: rpc_pb2.RPC = rpc_pb2.RPC()
|
||||||
|
@ -395,8 +407,8 @@ class Pubsub:
|
||||||
# Tell router we are joining this topic
|
# Tell router we are joining this topic
|
||||||
await self.router.join(topic_id)
|
await self.router.join(topic_id)
|
||||||
|
|
||||||
# Return the asyncio queue for messages on this topic
|
# Return the trio channel for messages on this topic
|
||||||
return self.my_topics[topic_id]
|
return receive_channel
|
||||||
|
|
||||||
async def unsubscribe(self, topic_id: str) -> None:
|
async def unsubscribe(self, topic_id: str) -> None:
|
||||||
"""
|
"""
|
||||||
|
@ -408,10 +420,14 @@ class Pubsub:
|
||||||
logger.debug("unsubscribing from topic %s", topic_id)
|
logger.debug("unsubscribing from topic %s", topic_id)
|
||||||
|
|
||||||
# Return if we already unsubscribed from the topic
|
# Return if we already unsubscribed from the topic
|
||||||
if topic_id not in self.my_topics:
|
if topic_id not in self.topic_ids:
|
||||||
return
|
return
|
||||||
# Remove topic_id from map if present
|
# Remove topic_id from the maps before yielding
|
||||||
del self.my_topics[topic_id]
|
send_channel = self.subscribed_topics_send[topic_id]
|
||||||
|
del self.subscribed_topics_send[topic_id]
|
||||||
|
del self.subscribed_topics_receive[topic_id]
|
||||||
|
# Only close the send side
|
||||||
|
await send_channel.aclose()
|
||||||
|
|
||||||
# Create unsubscribe message
|
# Create unsubscribe message
|
||||||
packet: rpc_pb2.RPC = rpc_pb2.RPC()
|
packet: rpc_pb2.RPC = rpc_pb2.RPC()
|
||||||
|
@ -453,13 +469,13 @@ class Pubsub:
|
||||||
data=data,
|
data=data,
|
||||||
topicIDs=[topic_id],
|
topicIDs=[topic_id],
|
||||||
# Origin is ourself.
|
# Origin is ourself.
|
||||||
from_id=self.host.get_id().to_bytes(),
|
from_id=self.my_id.to_bytes(),
|
||||||
seqno=self._next_seqno(),
|
seqno=self._next_seqno(),
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO: Sign with our signing key
|
# TODO: Sign with our signing key
|
||||||
|
|
||||||
await self.push_msg(self.host.get_id(), msg)
|
await self.push_msg(self.my_id, msg)
|
||||||
|
|
||||||
logger.debug("successfully published message %s", msg)
|
logger.debug("successfully published message %s", msg)
|
||||||
|
|
||||||
|
@ -470,12 +486,12 @@ class Pubsub:
|
||||||
:param msg_forwarder: the peer who forward us the message.
|
:param msg_forwarder: the peer who forward us the message.
|
||||||
:param msg: the message.
|
:param msg: the message.
|
||||||
"""
|
"""
|
||||||
sync_topic_validators = []
|
sync_topic_validators: List[SyncValidatorFn] = []
|
||||||
async_topic_validator_futures: List[Awaitable[bool]] = []
|
async_topic_validators: List[AsyncValidatorFn] = []
|
||||||
for topic_validator in self.get_msg_validators(msg):
|
for topic_validator in self.get_msg_validators(msg):
|
||||||
if topic_validator.is_async:
|
if topic_validator.is_async:
|
||||||
async_topic_validator_futures.append(
|
async_topic_validators.append(
|
||||||
cast(Awaitable[bool], topic_validator.validator(msg_forwarder, msg))
|
cast(AsyncValidatorFn, topic_validator.validator)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
sync_topic_validators.append(
|
sync_topic_validators.append(
|
||||||
|
@ -488,9 +504,20 @@ class Pubsub:
|
||||||
|
|
||||||
# TODO: Implement throttle on async validators
|
# TODO: Implement throttle on async validators
|
||||||
|
|
||||||
if len(async_topic_validator_futures) > 0:
|
if len(async_topic_validators) > 0:
|
||||||
results = await asyncio.gather(*async_topic_validator_futures)
|
# TODO: Use a better pattern
|
||||||
if not all(results):
|
final_result = True
|
||||||
|
|
||||||
|
async def run_async_validator(func: AsyncValidatorFn) -> None:
|
||||||
|
nonlocal final_result
|
||||||
|
result = await func(msg_forwarder, msg)
|
||||||
|
final_result = final_result and result
|
||||||
|
|
||||||
|
async with trio.open_nursery() as nursery:
|
||||||
|
for validator in async_topic_validators:
|
||||||
|
nursery.start_soon(run_async_validator, validator)
|
||||||
|
|
||||||
|
if not final_result:
|
||||||
raise ValidationError(f"Validation failed for msg={msg}")
|
raise ValidationError(f"Validation failed for msg={msg}")
|
||||||
|
|
||||||
async def push_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None:
|
async def push_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None:
|
||||||
|
@ -551,14 +578,4 @@ class Pubsub:
|
||||||
self.seen_messages[msg_id] = 1
|
self.seen_messages[msg_id] = 1
|
||||||
|
|
||||||
def _is_subscribed_to_msg(self, msg: rpc_pb2.Message) -> bool:
|
def _is_subscribed_to_msg(self, msg: rpc_pb2.Message) -> bool:
|
||||||
if not self.my_topics:
|
return any(topic in self.topic_ids for topic in msg.topicIDs)
|
||||||
return False
|
|
||||||
return any(topic in self.my_topics for topic in msg.topicIDs)
|
|
||||||
|
|
||||||
async def close(self) -> None:
|
|
||||||
for task in self._tasks:
|
|
||||||
task.cancel()
|
|
||||||
try:
|
|
||||||
await task
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
|
|
|
@ -8,19 +8,19 @@ from libp2p.network.notifee_interface import INotifee
|
||||||
from libp2p.network.stream.net_stream_interface import INetStream
|
from libp2p.network.stream.net_stream_interface import INetStream
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
import asyncio # noqa: F401
|
import trio # noqa: F401
|
||||||
from libp2p.peer.id import ID # noqa: F401
|
from libp2p.peer.id import ID # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
class PubsubNotifee(INotifee):
|
class PubsubNotifee(INotifee):
|
||||||
|
|
||||||
initiator_peers_queue: "asyncio.Queue[ID]"
|
initiator_peers_queue: "trio.MemorySendChannel[ID]"
|
||||||
dead_peers_queue: "asyncio.Queue[ID]"
|
dead_peers_queue: "trio.MemorySendChannel[ID]"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
initiator_peers_queue: "asyncio.Queue[ID]",
|
initiator_peers_queue: "trio.MemorySendChannel[ID]",
|
||||||
dead_peers_queue: "asyncio.Queue[ID]",
|
dead_peers_queue: "trio.MemorySendChannel[ID]",
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
:param initiator_peers_queue: queue to add new peers to so that pubsub
|
:param initiator_peers_queue: queue to add new peers to so that pubsub
|
||||||
|
@ -46,7 +46,12 @@ class PubsubNotifee(INotifee):
|
||||||
:param network: network the connection was opened on
|
:param network: network the connection was opened on
|
||||||
:param conn: connection that was opened
|
:param conn: connection that was opened
|
||||||
"""
|
"""
|
||||||
await self.initiator_peers_queue.put(conn.muxed_conn.peer_id)
|
try:
|
||||||
|
await self.initiator_peers_queue.send(conn.muxed_conn.peer_id)
|
||||||
|
except trio.BrokenResourceError:
|
||||||
|
# Raised when the receive channel is closed.
|
||||||
|
# TODO: Do something with loggers?
|
||||||
|
...
|
||||||
|
|
||||||
async def disconnected(self, network: INetwork, conn: INetConn) -> None:
|
async def disconnected(self, network: INetwork, conn: INetConn) -> None:
|
||||||
"""
|
"""
|
||||||
|
@ -56,7 +61,7 @@ class PubsubNotifee(INotifee):
|
||||||
:param network: network the connection was opened on
|
:param network: network the connection was opened on
|
||||||
:param conn: connection that was opened
|
:param conn: connection that was opened
|
||||||
"""
|
"""
|
||||||
await self.dead_peers_queue.put(conn.muxed_conn.peer_id)
|
await self.dead_peers_queue.send(conn.muxed_conn.peer_id)
|
||||||
|
|
||||||
async def listen(self, network: INetwork, multiaddr: Multiaddr) -> None:
|
async def listen(self, network: INetwork, multiaddr: Multiaddr) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -281,7 +281,7 @@ class Mplex(IMuxedConn, Service):
|
||||||
mplex_stream = await self._initialize_stream(stream_id, message.decode())
|
mplex_stream = await self._initialize_stream(stream_id, message.decode())
|
||||||
try:
|
try:
|
||||||
await self.new_stream_send_channel.send(mplex_stream)
|
await self.new_stream_send_channel.send(mplex_stream)
|
||||||
except (trio.BrokenResourceError, trio.EndOfChannel):
|
except (trio.BrokenResourceError, trio.ClosedResourceError):
|
||||||
raise MplexUnavailable
|
raise MplexUnavailable
|
||||||
|
|
||||||
async def _handle_message(self, stream_id: StreamID, message: bytes) -> None:
|
async def _handle_message(self, stream_id: StreamID, message: bytes) -> None:
|
||||||
|
|
|
@ -5,6 +5,7 @@ from async_service import background_trio_service
|
||||||
import factory
|
import factory
|
||||||
import trio
|
import trio
|
||||||
|
|
||||||
|
from libp2p.tools.constants import GOSSIPSUB_PARAMS
|
||||||
from libp2p import generate_new_rsa_identity, generate_peer_id_from
|
from libp2p import generate_new_rsa_identity, generate_peer_id_from
|
||||||
from libp2p.crypto.keys import KeyPair
|
from libp2p.crypto.keys import KeyPair
|
||||||
from libp2p.host.basic_host import BasicHost
|
from libp2p.host.basic_host import BasicHost
|
||||||
|
@ -15,6 +16,7 @@ from libp2p.network.connection.swarm_connection import SwarmConn
|
||||||
from libp2p.network.stream.net_stream_interface import INetStream
|
from libp2p.network.stream.net_stream_interface import INetStream
|
||||||
from libp2p.network.swarm import Swarm
|
from libp2p.network.swarm import Swarm
|
||||||
from libp2p.peer.peerstore import PeerStore
|
from libp2p.peer.peerstore import PeerStore
|
||||||
|
from libp2p.peer.id import ID
|
||||||
from libp2p.pubsub.floodsub import FloodSub
|
from libp2p.pubsub.floodsub import FloodSub
|
||||||
from libp2p.pubsub.gossipsub import GossipSub
|
from libp2p.pubsub.gossipsub import GossipSub
|
||||||
from libp2p.pubsub.pubsub import Pubsub
|
from libp2p.pubsub.pubsub import Pubsub
|
||||||
|
@ -28,15 +30,19 @@ from libp2p.transport.typing import TMuxerOptions
|
||||||
from libp2p.transport.upgrader import TransportUpgrader
|
from libp2p.transport.upgrader import TransportUpgrader
|
||||||
from libp2p.typing import TProtocol
|
from libp2p.typing import TProtocol
|
||||||
|
|
||||||
from .constants import (
|
from .constants import FLOODSUB_PROTOCOL_ID, GOSSIPSUB_PROTOCOL_ID, LISTEN_MADDR
|
||||||
FLOODSUB_PROTOCOL_ID,
|
|
||||||
GOSSIPSUB_PARAMS,
|
|
||||||
GOSSIPSUB_PROTOCOL_ID,
|
|
||||||
LISTEN_MADDR,
|
|
||||||
)
|
|
||||||
from .utils import connect, connect_swarm
|
from .utils import connect, connect_swarm
|
||||||
|
|
||||||
|
|
||||||
|
class IDFactory(factory.Factory):
|
||||||
|
class Meta:
|
||||||
|
model = ID
|
||||||
|
|
||||||
|
peer_id_bytes = factory.LazyFunction(
|
||||||
|
lambda: generate_peer_id_from(generate_new_rsa_identity())
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def security_transport_factory(
|
def security_transport_factory(
|
||||||
is_secure: bool, key_pair: KeyPair
|
is_secure: bool, key_pair: KeyPair
|
||||||
) -> Dict[TProtocol, BaseSecureTransport]:
|
) -> Dict[TProtocol, BaseSecureTransport]:
|
||||||
|
@ -181,9 +187,38 @@ class PubsubFactory(factory.Factory):
|
||||||
|
|
||||||
host = factory.SubFactory(HostFactory)
|
host = factory.SubFactory(HostFactory)
|
||||||
router = None
|
router = None
|
||||||
my_id = factory.LazyAttribute(lambda obj: obj.host.get_id())
|
|
||||||
cache_size = None
|
cache_size = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@asynccontextmanager
|
||||||
|
async def create_and_start(cls, host, router, cache_size):
|
||||||
|
pubsub = PubsubFactory(host=host, router=router, cache_size=cache_size)
|
||||||
|
async with background_trio_service(pubsub):
|
||||||
|
yield pubsub
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@asynccontextmanager
|
||||||
|
async def create_batch_with_floodsub(
|
||||||
|
cls, number: int, is_secure: bool = False, cache_size: int = None
|
||||||
|
):
|
||||||
|
floodsubs = FloodsubFactory.create_batch(number)
|
||||||
|
async with HostFactory.create_batch_and_listen(is_secure, number) as hosts:
|
||||||
|
# Pubsubs should exit before hosts
|
||||||
|
async with AsyncExitStack() as stack:
|
||||||
|
pubsubs = [
|
||||||
|
await stack.enter_async_context(
|
||||||
|
cls.create_and_start(host, router, cache_size)
|
||||||
|
)
|
||||||
|
for host, router in zip(hosts, floodsubs)
|
||||||
|
]
|
||||||
|
yield pubsubs
|
||||||
|
|
||||||
|
# @classmethod
|
||||||
|
# async def create_batch_with_gossipsub(
|
||||||
|
# cls, number: int, cache_size: int = None, gossipsub_params=GOSSIPSUB_PARAMS
|
||||||
|
# ):
|
||||||
|
# ...
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def swarm_pair_factory(
|
async def swarm_pair_factory(
|
||||||
|
|
|
@ -4,18 +4,6 @@ from libp2p.tools.constants import GOSSIPSUB_PARAMS
|
||||||
from libp2p.tools.factories import FloodsubFactory, GossipsubFactory, PubsubFactory
|
from libp2p.tools.factories import FloodsubFactory, GossipsubFactory, PubsubFactory
|
||||||
|
|
||||||
|
|
||||||
def _make_pubsubs(hosts, pubsub_routers, cache_size):
|
|
||||||
if len(pubsub_routers) != len(hosts):
|
|
||||||
raise ValueError(
|
|
||||||
f"lenght of pubsub_routers={pubsub_routers} should be equaled to the "
|
|
||||||
f"length of hosts={len(hosts)}"
|
|
||||||
)
|
|
||||||
return tuple(
|
|
||||||
PubsubFactory(host=host, router=router, cache_size=cache_size)
|
|
||||||
for host, router in zip(hosts, pubsub_routers)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def pubsub_cache_size():
|
def pubsub_cache_size():
|
||||||
return None # default
|
return None # default
|
||||||
|
@ -26,17 +14,9 @@ def gossipsub_params():
|
||||||
return GOSSIPSUB_PARAMS
|
return GOSSIPSUB_PARAMS
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
# @pytest.fixture
|
||||||
def pubsubs_fsub(num_hosts, hosts, pubsub_cache_size):
|
# def pubsubs_gsub(num_hosts, hosts, pubsub_cache_size, gossipsub_params):
|
||||||
floodsubs = FloodsubFactory.create_batch(num_hosts)
|
# gossipsubs = GossipsubFactory.create_batch(num_hosts, **gossipsub_params._asdict())
|
||||||
_pubsubs_fsub = _make_pubsubs(hosts, floodsubs, pubsub_cache_size)
|
# _pubsubs_gsub = _make_pubsubs(hosts, gossipsubs, pubsub_cache_size)
|
||||||
yield _pubsubs_fsub
|
# yield _pubsubs_gsub
|
||||||
# TODO: Clean up
|
# # TODO: Clean up
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def pubsubs_gsub(num_hosts, hosts, pubsub_cache_size, gossipsub_params):
|
|
||||||
gossipsubs = GossipsubFactory.create_batch(num_hosts, **gossipsub_params._asdict())
|
|
||||||
_pubsubs_gsub = _make_pubsubs(hosts, gossipsubs, pubsub_cache_size)
|
|
||||||
yield _pubsubs_gsub
|
|
||||||
# TODO: Clean up
|
|
||||||
|
|
|
@ -1,348 +1,328 @@
|
||||||
import asyncio
|
from contextlib import contextmanager
|
||||||
from typing import NamedTuple
|
from typing import NamedTuple
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import trio
|
||||||
|
|
||||||
from libp2p.exceptions import ValidationError
|
from libp2p.exceptions import ValidationError
|
||||||
from libp2p.peer.id import ID
|
from libp2p.peer.id import ID
|
||||||
from libp2p.pubsub.pb import rpc_pb2
|
from libp2p.pubsub.pb import rpc_pb2
|
||||||
from libp2p.tools.pubsub.utils import make_pubsub_msg
|
from libp2p.tools.pubsub.utils import make_pubsub_msg
|
||||||
from libp2p.tools.utils import connect
|
from libp2p.tools.utils import connect
|
||||||
|
from libp2p.tools.constants import MAX_READ_LEN
|
||||||
|
from libp2p.tools.factories import PubsubFactory, net_stream_pair_factory, IDFactory
|
||||||
from libp2p.utils import encode_varint_prefixed
|
from libp2p.utils import encode_varint_prefixed
|
||||||
|
|
||||||
TESTING_TOPIC = "TEST_SUBSCRIBE"
|
TESTING_TOPIC = "TEST_SUBSCRIBE"
|
||||||
TESTING_DATA = b"data"
|
TESTING_DATA = b"data"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("num_hosts", (1,))
|
@pytest.mark.trio
|
||||||
@pytest.mark.asyncio
|
async def test_subscribe_and_unsubscribe():
|
||||||
async def test_subscribe_and_unsubscribe(pubsubs_fsub):
|
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
|
||||||
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
|
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
|
||||||
assert TESTING_TOPIC in pubsubs_fsub[0].my_topics
|
assert TESTING_TOPIC in pubsubs_fsub[0].topic_ids
|
||||||
|
|
||||||
await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC)
|
await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC)
|
||||||
assert TESTING_TOPIC not in pubsubs_fsub[0].my_topics
|
assert TESTING_TOPIC not in pubsubs_fsub[0].topic_ids
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("num_hosts", (1,))
|
@pytest.mark.trio
|
||||||
@pytest.mark.asyncio
|
async def test_re_subscribe():
|
||||||
async def test_re_subscribe(pubsubs_fsub):
|
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
|
||||||
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
|
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
|
||||||
assert TESTING_TOPIC in pubsubs_fsub[0].my_topics
|
assert TESTING_TOPIC in pubsubs_fsub[0].topic_ids
|
||||||
|
|
||||||
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
|
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
|
||||||
assert TESTING_TOPIC in pubsubs_fsub[0].my_topics
|
assert TESTING_TOPIC in pubsubs_fsub[0].topic_ids
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("num_hosts", (1,))
|
@pytest.mark.trio
|
||||||
@pytest.mark.asyncio
|
async def test_re_unsubscribe():
|
||||||
async def test_re_unsubscribe(pubsubs_fsub):
|
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
|
||||||
# Unsubscribe from topic we didn't even subscribe to
|
# Unsubscribe from topic we didn't even subscribe to
|
||||||
assert "NOT_MY_TOPIC" not in pubsubs_fsub[0].my_topics
|
assert "NOT_MY_TOPIC" not in pubsubs_fsub[0].topic_ids
|
||||||
await pubsubs_fsub[0].unsubscribe("NOT_MY_TOPIC")
|
await pubsubs_fsub[0].unsubscribe("NOT_MY_TOPIC")
|
||||||
assert "NOT_MY_TOPIC" not in pubsubs_fsub[0].my_topics
|
assert "NOT_MY_TOPIC" not in pubsubs_fsub[0].topic_ids
|
||||||
|
|
||||||
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
|
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
|
||||||
assert TESTING_TOPIC in pubsubs_fsub[0].my_topics
|
assert TESTING_TOPIC in pubsubs_fsub[0].topic_ids
|
||||||
|
|
||||||
await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC)
|
await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC)
|
||||||
assert TESTING_TOPIC not in pubsubs_fsub[0].my_topics
|
assert TESTING_TOPIC not in pubsubs_fsub[0].topic_ids
|
||||||
|
|
||||||
await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC)
|
await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC)
|
||||||
assert TESTING_TOPIC not in pubsubs_fsub[0].my_topics
|
assert TESTING_TOPIC not in pubsubs_fsub[0].topic_ids
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.trio
|
||||||
async def test_peers_subscribe(pubsubs_fsub):
|
async def test_peers_subscribe():
|
||||||
await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host)
|
async with PubsubFactory.create_batch_with_floodsub(2) as pubsubs_fsub:
|
||||||
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
|
await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host)
|
||||||
# Yield to let 0 notify 1
|
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
|
||||||
await asyncio.sleep(0.1)
|
# Yield to let 0 notify 1
|
||||||
assert pubsubs_fsub[0].my_id in pubsubs_fsub[1].peer_topics[TESTING_TOPIC]
|
await trio.sleep(0.1)
|
||||||
await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC)
|
assert pubsubs_fsub[0].my_id in pubsubs_fsub[1].peer_topics[TESTING_TOPIC]
|
||||||
# Yield to let 0 notify 1
|
await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC)
|
||||||
await asyncio.sleep(0.1)
|
# Yield to let 0 notify 1
|
||||||
assert pubsubs_fsub[0].my_id not in pubsubs_fsub[1].peer_topics[TESTING_TOPIC]
|
await trio.sleep(0.1)
|
||||||
|
assert pubsubs_fsub[0].my_id not in pubsubs_fsub[1].peer_topics[TESTING_TOPIC]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("num_hosts", (1,))
|
@pytest.mark.trio
|
||||||
@pytest.mark.asyncio
|
async def test_get_hello_packet():
|
||||||
async def test_get_hello_packet(pubsubs_fsub):
|
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
|
||||||
def _get_hello_packet_topic_ids():
|
|
||||||
packet = pubsubs_fsub[0].get_hello_packet()
|
|
||||||
return tuple(sub.topicid for sub in packet.subscriptions)
|
|
||||||
|
|
||||||
# Test: No subscription, so there should not be any topic ids in the hello packet.
|
def _get_hello_packet_topic_ids():
|
||||||
assert len(_get_hello_packet_topic_ids()) == 0
|
packet = pubsubs_fsub[0].get_hello_packet()
|
||||||
|
return tuple(sub.topicid for sub in packet.subscriptions)
|
||||||
|
|
||||||
# Test: After subscriptions, topic ids should be in the hello packet.
|
# Test: No subscription, so there should not be any topic ids in the hello packet.
|
||||||
topic_ids = ["t", "o", "p", "i", "c"]
|
assert len(_get_hello_packet_topic_ids()) == 0
|
||||||
await asyncio.gather(*[pubsubs_fsub[0].subscribe(topic) for topic in topic_ids])
|
|
||||||
topic_ids_in_hello = _get_hello_packet_topic_ids()
|
# Test: After subscriptions, topic ids should be in the hello packet.
|
||||||
for topic in topic_ids:
|
topic_ids = ["t", "o", "p", "i", "c"]
|
||||||
assert topic in topic_ids_in_hello
|
for topic in topic_ids:
|
||||||
|
await pubsubs_fsub[0].subscribe(topic)
|
||||||
|
topic_ids_in_hello = _get_hello_packet_topic_ids()
|
||||||
|
for topic in topic_ids:
|
||||||
|
assert topic in topic_ids_in_hello
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("num_hosts", (1,))
|
@pytest.mark.trio
|
||||||
@pytest.mark.asyncio
|
async def test_set_and_remove_topic_validator():
|
||||||
async def test_set_and_remove_topic_validator(pubsubs_fsub):
|
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
|
||||||
|
is_sync_validator_called = False
|
||||||
|
|
||||||
is_sync_validator_called = False
|
def sync_validator(peer_id, msg):
|
||||||
|
nonlocal is_sync_validator_called
|
||||||
|
is_sync_validator_called = True
|
||||||
|
|
||||||
def sync_validator(peer_id, msg):
|
is_async_validator_called = False
|
||||||
nonlocal is_sync_validator_called
|
|
||||||
is_sync_validator_called = True
|
|
||||||
|
|
||||||
is_async_validator_called = False
|
async def async_validator(peer_id, msg):
|
||||||
|
nonlocal is_async_validator_called
|
||||||
|
is_async_validator_called = True
|
||||||
|
|
||||||
async def async_validator(peer_id, msg):
|
topic = "TEST_VALIDATOR"
|
||||||
nonlocal is_async_validator_called
|
|
||||||
is_async_validator_called = True
|
|
||||||
|
|
||||||
topic = "TEST_VALIDATOR"
|
assert topic not in pubsubs_fsub[0].topic_validators
|
||||||
|
|
||||||
assert topic not in pubsubs_fsub[0].topic_validators
|
# Register sync validator
|
||||||
|
pubsubs_fsub[0].set_topic_validator(topic, sync_validator, False)
|
||||||
|
|
||||||
# Register sync validator
|
assert topic in pubsubs_fsub[0].topic_validators
|
||||||
pubsubs_fsub[0].set_topic_validator(topic, sync_validator, False)
|
topic_validator = pubsubs_fsub[0].topic_validators[topic]
|
||||||
|
assert not topic_validator.is_async
|
||||||
|
|
||||||
assert topic in pubsubs_fsub[0].topic_validators
|
# Validate with sync validator
|
||||||
topic_validator = pubsubs_fsub[0].topic_validators[topic]
|
topic_validator.validator(peer_id=IDFactory(), msg="msg")
|
||||||
assert not topic_validator.is_async
|
|
||||||
|
|
||||||
# Validate with sync validator
|
assert is_sync_validator_called
|
||||||
topic_validator.validator(peer_id=ID(b"peer"), msg="msg")
|
assert not is_async_validator_called
|
||||||
|
|
||||||
assert is_sync_validator_called
|
# Register with async validator
|
||||||
assert not is_async_validator_called
|
pubsubs_fsub[0].set_topic_validator(topic, async_validator, True)
|
||||||
|
|
||||||
# Register with async validator
|
is_sync_validator_called = False
|
||||||
pubsubs_fsub[0].set_topic_validator(topic, async_validator, True)
|
assert topic in pubsubs_fsub[0].topic_validators
|
||||||
|
topic_validator = pubsubs_fsub[0].topic_validators[topic]
|
||||||
|
assert topic_validator.is_async
|
||||||
|
|
||||||
is_sync_validator_called = False
|
# Validate with async validator
|
||||||
assert topic in pubsubs_fsub[0].topic_validators
|
await topic_validator.validator(peer_id=IDFactory(), msg="msg")
|
||||||
topic_validator = pubsubs_fsub[0].topic_validators[topic]
|
|
||||||
assert topic_validator.is_async
|
|
||||||
|
|
||||||
# Validate with async validator
|
assert is_async_validator_called
|
||||||
await topic_validator.validator(peer_id=ID(b"peer"), msg="msg")
|
assert not is_sync_validator_called
|
||||||
|
|
||||||
assert is_async_validator_called
|
# Remove validator
|
||||||
assert not is_sync_validator_called
|
pubsubs_fsub[0].remove_topic_validator(topic)
|
||||||
|
assert topic not in pubsubs_fsub[0].topic_validators
|
||||||
# Remove validator
|
|
||||||
pubsubs_fsub[0].remove_topic_validator(topic)
|
|
||||||
assert topic not in pubsubs_fsub[0].topic_validators
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("num_hosts", (1,))
|
@pytest.mark.trio
|
||||||
@pytest.mark.asyncio
|
async def test_get_msg_validators():
|
||||||
async def test_get_msg_validators(pubsubs_fsub):
|
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
|
||||||
|
times_sync_validator_called = 0
|
||||||
|
|
||||||
times_sync_validator_called = 0
|
def sync_validator(peer_id, msg):
|
||||||
|
nonlocal times_sync_validator_called
|
||||||
|
times_sync_validator_called += 1
|
||||||
|
|
||||||
def sync_validator(peer_id, msg):
|
times_async_validator_called = 0
|
||||||
nonlocal times_sync_validator_called
|
|
||||||
times_sync_validator_called += 1
|
|
||||||
|
|
||||||
times_async_validator_called = 0
|
async def async_validator(peer_id, msg):
|
||||||
|
nonlocal times_async_validator_called
|
||||||
|
times_async_validator_called += 1
|
||||||
|
|
||||||
async def async_validator(peer_id, msg):
|
topic_1 = "TEST_VALIDATOR_1"
|
||||||
nonlocal times_async_validator_called
|
topic_2 = "TEST_VALIDATOR_2"
|
||||||
times_async_validator_called += 1
|
topic_3 = "TEST_VALIDATOR_3"
|
||||||
|
|
||||||
topic_1 = "TEST_VALIDATOR_1"
|
# Register sync validator for topic 1 and 2
|
||||||
topic_2 = "TEST_VALIDATOR_2"
|
pubsubs_fsub[0].set_topic_validator(topic_1, sync_validator, False)
|
||||||
topic_3 = "TEST_VALIDATOR_3"
|
pubsubs_fsub[0].set_topic_validator(topic_2, sync_validator, False)
|
||||||
|
|
||||||
# Register sync validator for topic 1 and 2
|
# Register async validator for topic 3
|
||||||
pubsubs_fsub[0].set_topic_validator(topic_1, sync_validator, False)
|
pubsubs_fsub[0].set_topic_validator(topic_3, async_validator, True)
|
||||||
pubsubs_fsub[0].set_topic_validator(topic_2, sync_validator, False)
|
|
||||||
|
|
||||||
# Register async validator for topic 3
|
msg = make_pubsub_msg(
|
||||||
pubsubs_fsub[0].set_topic_validator(topic_3, async_validator, True)
|
origin_id=pubsubs_fsub[0].my_id,
|
||||||
|
topic_ids=[topic_1, topic_2, topic_3],
|
||||||
|
data=b"1234",
|
||||||
|
seqno=b"\x00" * 8,
|
||||||
|
)
|
||||||
|
|
||||||
msg = make_pubsub_msg(
|
topic_validators = pubsubs_fsub[0].get_msg_validators(msg)
|
||||||
origin_id=pubsubs_fsub[0].my_id,
|
for topic_validator in topic_validators:
|
||||||
topic_ids=[topic_1, topic_2, topic_3],
|
if topic_validator.is_async:
|
||||||
data=b"1234",
|
await topic_validator.validator(peer_id=IDFactory(), msg="msg")
|
||||||
seqno=b"\x00" * 8,
|
else:
|
||||||
)
|
topic_validator.validator(peer_id=IDFactory(), msg="msg")
|
||||||
|
|
||||||
topic_validators = pubsubs_fsub[0].get_msg_validators(msg)
|
assert times_sync_validator_called == 2
|
||||||
for topic_validator in topic_validators:
|
assert times_async_validator_called == 1
|
||||||
if topic_validator.is_async:
|
|
||||||
await topic_validator.validator(peer_id=ID(b"peer"), msg="msg")
|
|
||||||
else:
|
|
||||||
topic_validator.validator(peer_id=ID(b"peer"), msg="msg")
|
|
||||||
|
|
||||||
assert times_sync_validator_called == 2
|
|
||||||
assert times_async_validator_called == 1
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("num_hosts", (1,))
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"is_topic_1_val_passed, is_topic_2_val_passed",
|
"is_topic_1_val_passed, is_topic_2_val_passed",
|
||||||
((False, True), (True, False), (True, True)),
|
((False, True), (True, False), (True, True)),
|
||||||
)
|
)
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.trio
|
||||||
async def test_validate_msg(pubsubs_fsub, is_topic_1_val_passed, is_topic_2_val_passed):
|
async def test_validate_msg(is_topic_1_val_passed, is_topic_2_val_passed):
|
||||||
def passed_sync_validator(peer_id, msg):
|
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
|
||||||
return True
|
|
||||||
|
|
||||||
def failed_sync_validator(peer_id, msg):
|
def passed_sync_validator(peer_id, msg):
|
||||||
return False
|
return True
|
||||||
|
|
||||||
async def passed_async_validator(peer_id, msg):
|
def failed_sync_validator(peer_id, msg):
|
||||||
return True
|
return False
|
||||||
|
|
||||||
async def failed_async_validator(peer_id, msg):
|
async def passed_async_validator(peer_id, msg):
|
||||||
return False
|
return True
|
||||||
|
|
||||||
topic_1 = "TEST_SYNC_VALIDATOR"
|
async def failed_async_validator(peer_id, msg):
|
||||||
topic_2 = "TEST_ASYNC_VALIDATOR"
|
return False
|
||||||
|
|
||||||
if is_topic_1_val_passed:
|
topic_1 = "TEST_SYNC_VALIDATOR"
|
||||||
pubsubs_fsub[0].set_topic_validator(topic_1, passed_sync_validator, False)
|
topic_2 = "TEST_ASYNC_VALIDATOR"
|
||||||
else:
|
|
||||||
pubsubs_fsub[0].set_topic_validator(topic_1, failed_sync_validator, False)
|
|
||||||
|
|
||||||
if is_topic_2_val_passed:
|
if is_topic_1_val_passed:
|
||||||
pubsubs_fsub[0].set_topic_validator(topic_2, passed_async_validator, True)
|
pubsubs_fsub[0].set_topic_validator(topic_1, passed_sync_validator, False)
|
||||||
else:
|
|
||||||
pubsubs_fsub[0].set_topic_validator(topic_2, failed_async_validator, True)
|
|
||||||
|
|
||||||
msg = make_pubsub_msg(
|
|
||||||
origin_id=pubsubs_fsub[0].my_id,
|
|
||||||
topic_ids=[topic_1, topic_2],
|
|
||||||
data=b"1234",
|
|
||||||
seqno=b"\x00" * 8,
|
|
||||||
)
|
|
||||||
|
|
||||||
if is_topic_1_val_passed and is_topic_2_val_passed:
|
|
||||||
await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg)
|
|
||||||
else:
|
|
||||||
with pytest.raises(ValidationError):
|
|
||||||
await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg)
|
|
||||||
|
|
||||||
|
|
||||||
class FakeNetStream:
|
|
||||||
_queue: asyncio.Queue
|
|
||||||
|
|
||||||
class FakeMplexConn(NamedTuple):
|
|
||||||
peer_id: ID = ID(b"\x12\x20" + b"\x00" * 32)
|
|
||||||
|
|
||||||
muxed_conn = FakeMplexConn()
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self._queue = asyncio.Queue()
|
|
||||||
|
|
||||||
async def read(self, n: int = -1) -> bytes:
|
|
||||||
buf = bytearray()
|
|
||||||
# Force to blocking wait if no data available now.
|
|
||||||
if self._queue.empty():
|
|
||||||
first_byte = await self._queue.get()
|
|
||||||
buf.extend(first_byte)
|
|
||||||
# If `n == -1`, read until no data is in the buffer(_queue).
|
|
||||||
# Else, read until no data is in the buffer(_queue) or we have read `n` bytes.
|
|
||||||
while (n == -1) or (len(buf) < n):
|
|
||||||
if self._queue.empty():
|
|
||||||
break
|
|
||||||
buf.extend(await self._queue.get())
|
|
||||||
return bytes(buf)
|
|
||||||
|
|
||||||
async def write(self, data: bytes) -> int:
|
|
||||||
for i in data:
|
|
||||||
await self._queue.put(i.to_bytes(1, "big"))
|
|
||||||
return len(data)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("num_hosts", (1,))
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_continuously_read_stream(pubsubs_fsub, monkeypatch):
|
|
||||||
stream = FakeNetStream()
|
|
||||||
|
|
||||||
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
|
|
||||||
|
|
||||||
event_push_msg = asyncio.Event()
|
|
||||||
event_handle_subscription = asyncio.Event()
|
|
||||||
event_handle_rpc = asyncio.Event()
|
|
||||||
|
|
||||||
async def mock_push_msg(msg_forwarder, msg):
|
|
||||||
event_push_msg.set()
|
|
||||||
|
|
||||||
def mock_handle_subscription(origin_id, sub_message):
|
|
||||||
event_handle_subscription.set()
|
|
||||||
|
|
||||||
async def mock_handle_rpc(rpc, sender_peer_id):
|
|
||||||
event_handle_rpc.set()
|
|
||||||
|
|
||||||
monkeypatch.setattr(pubsubs_fsub[0], "push_msg", mock_push_msg)
|
|
||||||
monkeypatch.setattr(
|
|
||||||
pubsubs_fsub[0], "handle_subscription", mock_handle_subscription
|
|
||||||
)
|
|
||||||
monkeypatch.setattr(pubsubs_fsub[0].router, "handle_rpc", mock_handle_rpc)
|
|
||||||
|
|
||||||
async def wait_for_event_occurring(event):
|
|
||||||
try:
|
|
||||||
await asyncio.wait_for(event.wait(), timeout=1)
|
|
||||||
except asyncio.TimeoutError as error:
|
|
||||||
event.clear()
|
|
||||||
raise asyncio.TimeoutError(
|
|
||||||
f"Event {event} is not set before the timeout. "
|
|
||||||
"This indicates the mocked functions are not called properly."
|
|
||||||
) from error
|
|
||||||
else:
|
else:
|
||||||
event.clear()
|
pubsubs_fsub[0].set_topic_validator(topic_1, failed_sync_validator, False)
|
||||||
|
|
||||||
# Kick off the task `continuously_read_stream`
|
if is_topic_2_val_passed:
|
||||||
task = asyncio.ensure_future(pubsubs_fsub[0].continuously_read_stream(stream))
|
pubsubs_fsub[0].set_topic_validator(topic_2, passed_async_validator, True)
|
||||||
|
else:
|
||||||
|
pubsubs_fsub[0].set_topic_validator(topic_2, failed_async_validator, True)
|
||||||
|
|
||||||
# Test: `push_msg` is called when publishing to a subscribed topic.
|
msg = make_pubsub_msg(
|
||||||
publish_subscribed_topic = rpc_pb2.RPC(
|
origin_id=pubsubs_fsub[0].my_id,
|
||||||
publish=[rpc_pb2.Message(topicIDs=[TESTING_TOPIC])]
|
topic_ids=[topic_1, topic_2],
|
||||||
)
|
data=b"1234",
|
||||||
await stream.write(
|
seqno=b"\x00" * 8,
|
||||||
encode_varint_prefixed(publish_subscribed_topic.SerializeToString())
|
)
|
||||||
)
|
|
||||||
await wait_for_event_occurring(event_push_msg)
|
|
||||||
# Make sure the other events are not emitted.
|
|
||||||
with pytest.raises(asyncio.TimeoutError):
|
|
||||||
await wait_for_event_occurring(event_handle_subscription)
|
|
||||||
with pytest.raises(asyncio.TimeoutError):
|
|
||||||
await wait_for_event_occurring(event_handle_rpc)
|
|
||||||
|
|
||||||
# Test: `push_msg` is not called when publishing to a topic-not-subscribed.
|
if is_topic_1_val_passed and is_topic_2_val_passed:
|
||||||
publish_not_subscribed_topic = rpc_pb2.RPC(
|
await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg)
|
||||||
publish=[rpc_pb2.Message(topicIDs=["NOT_SUBSCRIBED"])]
|
else:
|
||||||
)
|
with pytest.raises(ValidationError):
|
||||||
await stream.write(
|
await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg)
|
||||||
encode_varint_prefixed(publish_not_subscribed_topic.SerializeToString())
|
|
||||||
)
|
|
||||||
with pytest.raises(asyncio.TimeoutError):
|
|
||||||
await wait_for_event_occurring(event_push_msg)
|
|
||||||
|
|
||||||
# Test: `handle_subscription` is called when a subscription message is received.
|
|
||||||
subscription_msg = rpc_pb2.RPC(subscriptions=[rpc_pb2.RPC.SubOpts()])
|
|
||||||
await stream.write(encode_varint_prefixed(subscription_msg.SerializeToString()))
|
|
||||||
await wait_for_event_occurring(event_handle_subscription)
|
|
||||||
# Make sure the other events are not emitted.
|
|
||||||
with pytest.raises(asyncio.TimeoutError):
|
|
||||||
await wait_for_event_occurring(event_push_msg)
|
|
||||||
with pytest.raises(asyncio.TimeoutError):
|
|
||||||
await wait_for_event_occurring(event_handle_rpc)
|
|
||||||
|
|
||||||
# Test: `handle_rpc` is called when a control message is received.
|
@pytest.mark.trio
|
||||||
control_msg = rpc_pb2.RPC(control=rpc_pb2.ControlMessage())
|
async def test_continuously_read_stream(monkeypatch, nursery, is_host_secure):
|
||||||
await stream.write(encode_varint_prefixed(control_msg.SerializeToString()))
|
async def wait_for_event_occurring(event):
|
||||||
await wait_for_event_occurring(event_handle_rpc)
|
with trio.fail_after(0.1):
|
||||||
# Make sure the other events are not emitted.
|
await event.wait()
|
||||||
with pytest.raises(asyncio.TimeoutError):
|
|
||||||
await wait_for_event_occurring(event_push_msg)
|
|
||||||
with pytest.raises(asyncio.TimeoutError):
|
|
||||||
await wait_for_event_occurring(event_handle_subscription)
|
|
||||||
|
|
||||||
task.cancel()
|
class Events(NamedTuple):
|
||||||
|
push_msg: trio.Event
|
||||||
|
handle_subscription: trio.Event
|
||||||
|
handle_rpc: trio.Event
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def mock_methods():
|
||||||
|
event_push_msg = trio.Event()
|
||||||
|
event_handle_subscription = trio.Event()
|
||||||
|
event_handle_rpc = trio.Event()
|
||||||
|
|
||||||
|
async def mock_push_msg(msg_forwarder, msg):
|
||||||
|
event_push_msg.set()
|
||||||
|
await trio.sleep(0)
|
||||||
|
|
||||||
|
def mock_handle_subscription(origin_id, sub_message):
|
||||||
|
event_handle_subscription.set()
|
||||||
|
|
||||||
|
async def mock_handle_rpc(rpc, sender_peer_id):
|
||||||
|
event_handle_rpc.set()
|
||||||
|
await trio.sleep(0)
|
||||||
|
|
||||||
|
with monkeypatch.context() as m:
|
||||||
|
m.setattr(pubsubs_fsub[0], "push_msg", mock_push_msg)
|
||||||
|
m.setattr(pubsubs_fsub[0], "handle_subscription", mock_handle_subscription)
|
||||||
|
m.setattr(pubsubs_fsub[0].router, "handle_rpc", mock_handle_rpc)
|
||||||
|
yield Events(event_push_msg, event_handle_subscription, event_handle_rpc)
|
||||||
|
|
||||||
|
async with PubsubFactory.create_batch_with_floodsub(
|
||||||
|
1, is_secure=is_host_secure
|
||||||
|
) as pubsubs_fsub, net_stream_pair_factory(is_secure=is_host_secure) as stream_pair:
|
||||||
|
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
|
||||||
|
# Kick off the task `continuously_read_stream`
|
||||||
|
nursery.start_soon(pubsubs_fsub[0].continuously_read_stream, stream_pair[0])
|
||||||
|
|
||||||
|
# Test: `push_msg` is called when publishing to a subscribed topic.
|
||||||
|
publish_subscribed_topic = rpc_pb2.RPC(
|
||||||
|
publish=[rpc_pb2.Message(topicIDs=[TESTING_TOPIC])]
|
||||||
|
)
|
||||||
|
with mock_methods() as events:
|
||||||
|
await stream_pair[1].write(
|
||||||
|
encode_varint_prefixed(publish_subscribed_topic.SerializeToString())
|
||||||
|
)
|
||||||
|
await wait_for_event_occurring(events.push_msg)
|
||||||
|
# Make sure the other events are not emitted.
|
||||||
|
with pytest.raises(trio.TooSlowError):
|
||||||
|
await wait_for_event_occurring(events.handle_subscription)
|
||||||
|
with pytest.raises(trio.TooSlowError):
|
||||||
|
await wait_for_event_occurring(events.handle_rpc)
|
||||||
|
|
||||||
|
# Test: `push_msg` is not called when publishing to a topic-not-subscribed.
|
||||||
|
publish_not_subscribed_topic = rpc_pb2.RPC(
|
||||||
|
publish=[rpc_pb2.Message(topicIDs=["NOT_SUBSCRIBED"])]
|
||||||
|
)
|
||||||
|
with mock_methods() as events:
|
||||||
|
await stream_pair[1].write(
|
||||||
|
encode_varint_prefixed(publish_not_subscribed_topic.SerializeToString())
|
||||||
|
)
|
||||||
|
with pytest.raises(trio.TooSlowError):
|
||||||
|
await wait_for_event_occurring(events.push_msg)
|
||||||
|
|
||||||
|
# Test: `handle_subscription` is called when a subscription message is received.
|
||||||
|
subscription_msg = rpc_pb2.RPC(subscriptions=[rpc_pb2.RPC.SubOpts()])
|
||||||
|
with mock_methods() as events:
|
||||||
|
await stream_pair[1].write(
|
||||||
|
encode_varint_prefixed(subscription_msg.SerializeToString())
|
||||||
|
)
|
||||||
|
await wait_for_event_occurring(events.handle_subscription)
|
||||||
|
# Make sure the other events are not emitted.
|
||||||
|
with pytest.raises(trio.TooSlowError):
|
||||||
|
await wait_for_event_occurring(events.push_msg)
|
||||||
|
with pytest.raises(trio.TooSlowError):
|
||||||
|
await wait_for_event_occurring(events.handle_rpc)
|
||||||
|
|
||||||
|
# Test: `handle_rpc` is called when a control message is received.
|
||||||
|
control_msg = rpc_pb2.RPC(control=rpc_pb2.ControlMessage())
|
||||||
|
with mock_methods() as events:
|
||||||
|
await stream_pair[1].write(
|
||||||
|
encode_varint_prefixed(control_msg.SerializeToString())
|
||||||
|
)
|
||||||
|
await wait_for_event_occurring(events.handle_rpc)
|
||||||
|
# Make sure the other events are not emitted.
|
||||||
|
with pytest.raises(trio.TooSlowError):
|
||||||
|
await wait_for_event_occurring(events.push_msg)
|
||||||
|
with pytest.raises(trio.TooSlowError):
|
||||||
|
await wait_for_event_occurring(events.handle_subscription)
|
||||||
|
|
||||||
|
|
||||||
# TODO: Add the following tests after they are aligned with Go.
|
# TODO: Add the following tests after they are aligned with Go.
|
||||||
|
@ -351,81 +331,84 @@ async def test_continuously_read_stream(pubsubs_fsub, monkeypatch):
|
||||||
# - `test_handle_peer_queue`
|
# - `test_handle_peer_queue`
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("num_hosts", (1,))
|
@pytest.mark.trio
|
||||||
def test_handle_subscription(pubsubs_fsub):
|
async def test_handle_subscription():
|
||||||
assert len(pubsubs_fsub[0].peer_topics) == 0
|
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
|
||||||
sub_msg_0 = rpc_pb2.RPC.SubOpts(subscribe=True, topicid=TESTING_TOPIC)
|
assert len(pubsubs_fsub[0].peer_topics) == 0
|
||||||
peer_ids = [ID(b"\x12\x20" + i.to_bytes(32, "big")) for i in range(2)]
|
sub_msg_0 = rpc_pb2.RPC.SubOpts(subscribe=True, topicid=TESTING_TOPIC)
|
||||||
# Test: One peer is subscribed
|
peer_ids = [IDFactory() for _ in range(2)]
|
||||||
pubsubs_fsub[0].handle_subscription(peer_ids[0], sub_msg_0)
|
# Test: One peer is subscribed
|
||||||
assert (
|
pubsubs_fsub[0].handle_subscription(peer_ids[0], sub_msg_0)
|
||||||
len(pubsubs_fsub[0].peer_topics) == 1
|
assert (
|
||||||
and TESTING_TOPIC in pubsubs_fsub[0].peer_topics
|
len(pubsubs_fsub[0].peer_topics) == 1
|
||||||
)
|
and TESTING_TOPIC in pubsubs_fsub[0].peer_topics
|
||||||
assert len(pubsubs_fsub[0].peer_topics[TESTING_TOPIC]) == 1
|
)
|
||||||
assert peer_ids[0] in pubsubs_fsub[0].peer_topics[TESTING_TOPIC]
|
assert len(pubsubs_fsub[0].peer_topics[TESTING_TOPIC]) == 1
|
||||||
# Test: Another peer is subscribed
|
assert peer_ids[0] in pubsubs_fsub[0].peer_topics[TESTING_TOPIC]
|
||||||
pubsubs_fsub[0].handle_subscription(peer_ids[1], sub_msg_0)
|
# Test: Another peer is subscribed
|
||||||
assert len(pubsubs_fsub[0].peer_topics) == 1
|
pubsubs_fsub[0].handle_subscription(peer_ids[1], sub_msg_0)
|
||||||
assert len(pubsubs_fsub[0].peer_topics[TESTING_TOPIC]) == 2
|
assert len(pubsubs_fsub[0].peer_topics) == 1
|
||||||
assert peer_ids[1] in pubsubs_fsub[0].peer_topics[TESTING_TOPIC]
|
assert len(pubsubs_fsub[0].peer_topics[TESTING_TOPIC]) == 2
|
||||||
# Test: Subscribe to another topic
|
assert peer_ids[1] in pubsubs_fsub[0].peer_topics[TESTING_TOPIC]
|
||||||
another_topic = "ANOTHER_TOPIC"
|
# Test: Subscribe to another topic
|
||||||
sub_msg_1 = rpc_pb2.RPC.SubOpts(subscribe=True, topicid=another_topic)
|
another_topic = "ANOTHER_TOPIC"
|
||||||
pubsubs_fsub[0].handle_subscription(peer_ids[0], sub_msg_1)
|
sub_msg_1 = rpc_pb2.RPC.SubOpts(subscribe=True, topicid=another_topic)
|
||||||
assert len(pubsubs_fsub[0].peer_topics) == 2
|
pubsubs_fsub[0].handle_subscription(peer_ids[0], sub_msg_1)
|
||||||
assert another_topic in pubsubs_fsub[0].peer_topics
|
assert len(pubsubs_fsub[0].peer_topics) == 2
|
||||||
assert peer_ids[0] in pubsubs_fsub[0].peer_topics[another_topic]
|
assert another_topic in pubsubs_fsub[0].peer_topics
|
||||||
# Test: unsubscribe
|
assert peer_ids[0] in pubsubs_fsub[0].peer_topics[another_topic]
|
||||||
unsub_msg = rpc_pb2.RPC.SubOpts(subscribe=False, topicid=TESTING_TOPIC)
|
# Test: unsubscribe
|
||||||
pubsubs_fsub[0].handle_subscription(peer_ids[0], unsub_msg)
|
unsub_msg = rpc_pb2.RPC.SubOpts(subscribe=False, topicid=TESTING_TOPIC)
|
||||||
assert peer_ids[0] not in pubsubs_fsub[0].peer_topics[TESTING_TOPIC]
|
pubsubs_fsub[0].handle_subscription(peer_ids[0], unsub_msg)
|
||||||
|
assert peer_ids[0] not in pubsubs_fsub[0].peer_topics[TESTING_TOPIC]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("num_hosts", (1,))
|
@pytest.mark.trio
|
||||||
@pytest.mark.asyncio
|
async def test_handle_talk():
|
||||||
async def test_handle_talk(pubsubs_fsub):
|
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
|
||||||
sub = await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
|
sub = await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
|
||||||
msg_0 = make_pubsub_msg(
|
msg_0 = make_pubsub_msg(
|
||||||
origin_id=pubsubs_fsub[0].my_id,
|
origin_id=pubsubs_fsub[0].my_id,
|
||||||
topic_ids=[TESTING_TOPIC],
|
topic_ids=[TESTING_TOPIC],
|
||||||
data=b"1234",
|
data=b"1234",
|
||||||
seqno=b"\x00" * 8,
|
seqno=b"\x00" * 8,
|
||||||
)
|
)
|
||||||
await pubsubs_fsub[0].handle_talk(msg_0)
|
await pubsubs_fsub[0].handle_talk(msg_0)
|
||||||
msg_1 = make_pubsub_msg(
|
msg_1 = make_pubsub_msg(
|
||||||
origin_id=pubsubs_fsub[0].my_id,
|
origin_id=pubsubs_fsub[0].my_id,
|
||||||
topic_ids=["NOT_SUBSCRIBED"],
|
topic_ids=["NOT_SUBSCRIBED"],
|
||||||
data=b"1234",
|
data=b"1234",
|
||||||
seqno=b"\x11" * 8,
|
seqno=b"\x11" * 8,
|
||||||
)
|
)
|
||||||
await pubsubs_fsub[0].handle_talk(msg_1)
|
await pubsubs_fsub[0].handle_talk(msg_1)
|
||||||
assert (
|
assert (
|
||||||
len(pubsubs_fsub[0].my_topics) == 1
|
len(pubsubs_fsub[0].topic_ids) == 1
|
||||||
and sub == pubsubs_fsub[0].my_topics[TESTING_TOPIC]
|
and sub == pubsubs_fsub[0].subscribed_topics_receive[TESTING_TOPIC]
|
||||||
)
|
)
|
||||||
assert sub.qsize() == 1
|
assert (await sub.receive()) == msg_0
|
||||||
assert (await sub.get()) == msg_0
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("num_hosts", (1,))
|
@pytest.mark.trio
|
||||||
@pytest.mark.asyncio
|
async def test_message_all_peers(monkeypatch, is_host_secure):
|
||||||
async def test_message_all_peers(pubsubs_fsub, monkeypatch):
|
async with PubsubFactory.create_batch_with_floodsub(
|
||||||
peer_ids = [ID(b"\x12\x20" + i.to_bytes(32, "big")) for i in range(10)]
|
1, is_secure=is_host_secure
|
||||||
mock_peers = {peer_id: FakeNetStream() for peer_id in peer_ids}
|
) as pubsubs_fsub, net_stream_pair_factory(is_secure=is_host_secure) as stream_pair:
|
||||||
monkeypatch.setattr(pubsubs_fsub[0], "peers", mock_peers)
|
peer_id = IDFactory()
|
||||||
|
mock_peers = {peer_id: stream_pair[0]}
|
||||||
|
with monkeypatch.context() as m:
|
||||||
|
m.setattr(pubsubs_fsub[0], "peers", mock_peers)
|
||||||
|
|
||||||
empty_rpc = rpc_pb2.RPC()
|
empty_rpc = rpc_pb2.RPC()
|
||||||
empty_rpc_bytes = empty_rpc.SerializeToString()
|
empty_rpc_bytes = empty_rpc.SerializeToString()
|
||||||
empty_rpc_bytes_len_prefixed = encode_varint_prefixed(empty_rpc_bytes)
|
empty_rpc_bytes_len_prefixed = encode_varint_prefixed(empty_rpc_bytes)
|
||||||
await pubsubs_fsub[0].message_all_peers(empty_rpc_bytes)
|
await pubsubs_fsub[0].message_all_peers(empty_rpc_bytes)
|
||||||
for stream in mock_peers.values():
|
assert (
|
||||||
assert (await stream.read()) == empty_rpc_bytes_len_prefixed
|
await stream_pair[1].read(MAX_READ_LEN)
|
||||||
|
) == empty_rpc_bytes_len_prefixed
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("num_hosts", (1,))
|
@pytest.mark.trio
|
||||||
@pytest.mark.asyncio
|
async def test_publish(monkeypatch):
|
||||||
async def test_publish(pubsubs_fsub, monkeypatch):
|
|
||||||
msg_forwarders = []
|
msg_forwarders = []
|
||||||
msgs = []
|
msgs = []
|
||||||
|
|
||||||
|
@ -433,80 +416,97 @@ async def test_publish(pubsubs_fsub, monkeypatch):
|
||||||
msg_forwarders.append(msg_forwarder)
|
msg_forwarders.append(msg_forwarder)
|
||||||
msgs.append(msg)
|
msgs.append(msg)
|
||||||
|
|
||||||
monkeypatch.setattr(pubsubs_fsub[0], "push_msg", push_msg)
|
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
|
||||||
|
with monkeypatch.context() as m:
|
||||||
|
m.setattr(pubsubs_fsub[0], "push_msg", push_msg)
|
||||||
|
|
||||||
await pubsubs_fsub[0].publish(TESTING_TOPIC, TESTING_DATA)
|
await pubsubs_fsub[0].publish(TESTING_TOPIC, TESTING_DATA)
|
||||||
await pubsubs_fsub[0].publish(TESTING_TOPIC, TESTING_DATA)
|
await pubsubs_fsub[0].publish(TESTING_TOPIC, TESTING_DATA)
|
||||||
|
|
||||||
assert len(msgs) == 2, "`push_msg` should be called every time `publish` is called"
|
assert (
|
||||||
assert (msg_forwarders[0] == msg_forwarders[1]) and (
|
len(msgs) == 2
|
||||||
msg_forwarders[1] == pubsubs_fsub[0].my_id
|
), "`push_msg` should be called every time `publish` is called"
|
||||||
)
|
assert (msg_forwarders[0] == msg_forwarders[1]) and (
|
||||||
assert msgs[0].seqno != msgs[1].seqno, "`seqno` should be different every time"
|
msg_forwarders[1] == pubsubs_fsub[0].my_id
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
msgs[0].seqno != msgs[1].seqno
|
||||||
|
), "`seqno` should be different every time"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("num_hosts", (1,))
|
@pytest.mark.trio
|
||||||
@pytest.mark.asyncio
|
async def test_push_msg(monkeypatch):
|
||||||
async def test_push_msg(pubsubs_fsub, monkeypatch):
|
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
|
||||||
msg_0 = make_pubsub_msg(
|
msg_0 = make_pubsub_msg(
|
||||||
origin_id=pubsubs_fsub[0].my_id,
|
origin_id=pubsubs_fsub[0].my_id,
|
||||||
topic_ids=[TESTING_TOPIC],
|
topic_ids=[TESTING_TOPIC],
|
||||||
data=TESTING_DATA,
|
data=TESTING_DATA,
|
||||||
seqno=b"\x00" * 8,
|
seqno=b"\x00" * 8,
|
||||||
)
|
)
|
||||||
|
|
||||||
event = asyncio.Event()
|
@contextmanager
|
||||||
|
def mock_router_publish():
|
||||||
|
|
||||||
async def router_publish(*args, **kwargs):
|
event = trio.Event()
|
||||||
event.set()
|
|
||||||
|
|
||||||
monkeypatch.setattr(pubsubs_fsub[0].router, "publish", router_publish)
|
async def router_publish(*args, **kwargs):
|
||||||
|
event.set()
|
||||||
|
await trio.sleep(0)
|
||||||
|
|
||||||
# Test: `msg` is not seen before `push_msg`, and is seen after `push_msg`.
|
with monkeypatch.context() as m:
|
||||||
assert not pubsubs_fsub[0]._is_msg_seen(msg_0)
|
m.setattr(pubsubs_fsub[0].router, "publish", router_publish)
|
||||||
await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_0)
|
yield event
|
||||||
assert pubsubs_fsub[0]._is_msg_seen(msg_0)
|
|
||||||
# Test: Ensure `router.publish` is called in `push_msg`
|
|
||||||
await asyncio.wait_for(event.wait(), timeout=0.1)
|
|
||||||
|
|
||||||
# Test: `push_msg` the message again and it will be reject.
|
with mock_router_publish() as event:
|
||||||
# `router_publish` is not called then.
|
# Test: `msg` is not seen before `push_msg`, and is seen after `push_msg`.
|
||||||
event.clear()
|
assert not pubsubs_fsub[0]._is_msg_seen(msg_0)
|
||||||
await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_0)
|
await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_0)
|
||||||
await asyncio.sleep(0.01)
|
assert pubsubs_fsub[0]._is_msg_seen(msg_0)
|
||||||
assert not event.is_set()
|
# Test: Ensure `router.publish` is called in `push_msg`
|
||||||
|
with trio.fail_after(0.1):
|
||||||
|
await event.wait()
|
||||||
|
|
||||||
sub = await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
|
with mock_router_publish() as event:
|
||||||
# Test: `push_msg` succeeds with another unseen msg.
|
# Test: `push_msg` the message again and it will be reject.
|
||||||
msg_1 = make_pubsub_msg(
|
# `router_publish` is not called then.
|
||||||
origin_id=pubsubs_fsub[0].my_id,
|
await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_0)
|
||||||
topic_ids=[TESTING_TOPIC],
|
await trio.sleep(0.01)
|
||||||
data=TESTING_DATA,
|
assert not event.is_set()
|
||||||
seqno=b"\x11" * 8,
|
|
||||||
)
|
|
||||||
assert not pubsubs_fsub[0]._is_msg_seen(msg_1)
|
|
||||||
await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_1)
|
|
||||||
assert pubsubs_fsub[0]._is_msg_seen(msg_1)
|
|
||||||
await asyncio.wait_for(event.wait(), timeout=0.1)
|
|
||||||
# Test: Subscribers are notified when `push_msg` new messages.
|
|
||||||
assert (await sub.get()) == msg_1
|
|
||||||
|
|
||||||
# Test: add a topic validator and `push_msg` the message that
|
sub = await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
|
||||||
# does not pass the validation.
|
# Test: `push_msg` succeeds with another unseen msg.
|
||||||
# `router_publish` is not called then.
|
msg_1 = make_pubsub_msg(
|
||||||
def failed_sync_validator(peer_id, msg):
|
origin_id=pubsubs_fsub[0].my_id,
|
||||||
return False
|
topic_ids=[TESTING_TOPIC],
|
||||||
|
data=TESTING_DATA,
|
||||||
|
seqno=b"\x11" * 8,
|
||||||
|
)
|
||||||
|
assert not pubsubs_fsub[0]._is_msg_seen(msg_1)
|
||||||
|
await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_1)
|
||||||
|
assert pubsubs_fsub[0]._is_msg_seen(msg_1)
|
||||||
|
with trio.fail_after(0.1):
|
||||||
|
await event.wait()
|
||||||
|
# Test: Subscribers are notified when `push_msg` new messages.
|
||||||
|
assert (await sub.receive()) == msg_1
|
||||||
|
|
||||||
pubsubs_fsub[0].set_topic_validator(TESTING_TOPIC, failed_sync_validator, False)
|
with mock_router_publish() as event:
|
||||||
|
# Test: add a topic validator and `push_msg` the message that
|
||||||
|
# does not pass the validation.
|
||||||
|
# `router_publish` is not called then.
|
||||||
|
def failed_sync_validator(peer_id, msg):
|
||||||
|
return False
|
||||||
|
|
||||||
msg_2 = make_pubsub_msg(
|
pubsubs_fsub[0].set_topic_validator(
|
||||||
origin_id=pubsubs_fsub[0].my_id,
|
TESTING_TOPIC, failed_sync_validator, False
|
||||||
topic_ids=[TESTING_TOPIC],
|
)
|
||||||
data=TESTING_DATA,
|
|
||||||
seqno=b"\x22" * 8,
|
|
||||||
)
|
|
||||||
|
|
||||||
event.clear()
|
msg_2 = make_pubsub_msg(
|
||||||
await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_2)
|
origin_id=pubsubs_fsub[0].my_id,
|
||||||
await asyncio.sleep(0.01)
|
topic_ids=[TESTING_TOPIC],
|
||||||
assert not event.is_set()
|
data=TESTING_DATA,
|
||||||
|
seqno=b"\x22" * 8,
|
||||||
|
)
|
||||||
|
|
||||||
|
await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_2)
|
||||||
|
await trio.sleep(0.01)
|
||||||
|
assert not event.is_set()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user