Fix Pubsub

This commit is contained in:
mhchia 2019-12-03 17:27:49 +08:00
parent bdbb7b2394
commit e9ab0646e3
No known key found for this signature in database
GPG Key ID: 389EFBEA1362589A
7 changed files with 568 additions and 523 deletions

View File

@ -1,6 +1,8 @@
import logging
from typing import Iterable, List, Sequence
import trio
from libp2p.network.stream.exceptions import StreamClosed
from libp2p.peer.id import ID
from libp2p.typing import TProtocol
@ -61,6 +63,8 @@ class FloodSub(IPubsubRouter):
:param rpc: rpc message
"""
# Checkpoint
await trio.sleep(0)
async def publish(self, msg_forwarder: ID, pubsub_msg: rpc_pb2.Message) -> None:
"""
@ -102,6 +106,8 @@ class FloodSub(IPubsubRouter):
:param topic: topic to join
"""
# Checkpoint
await trio.sleep(0)
async def leave(self, topic: str) -> None:
"""
@ -110,6 +116,8 @@ class FloodSub(IPubsubRouter):
:param topic: topic to leave
"""
# Checkpoint
await trio.sleep(0)
def _get_peers_to_send(
self, topic_ids: Iterable[str], msg_forwarder: ID, origin: ID

View File

@ -1,11 +1,13 @@
import asyncio
from abc import ABC, abstractmethod
import logging
import math
import time
from typing import (
TYPE_CHECKING,
Awaitable,
Callable,
Dict,
KeysView,
List,
NamedTuple,
Tuple,
@ -13,8 +15,10 @@ from typing import (
cast,
)
from async_service import Service
import base58
from lru import LRU
import trio
from libp2p.exceptions import ParseError, ValidationError
from libp2p.host.host_interface import IHost
@ -53,24 +57,24 @@ class TopicValidator(NamedTuple):
is_async: bool
class Pubsub:
class BasePubsub(ABC):
pass
class Pubsub(BasePubsub, Service):
host: IHost
my_id: ID
router: "IPubsubRouter"
peer_queue: "asyncio.Queue[ID]"
dead_peer_queue: "asyncio.Queue[ID]"
protocols: List[TProtocol]
incoming_msgs_from_peers: "asyncio.Queue[rpc_pb2.Message]"
outgoing_messages: "asyncio.Queue[rpc_pb2.Message]"
peer_receive_channel: "trio.MemoryReceiveChannel[ID]"
dead_peer_receive_channel: "trio.MemoryReceiveChannel[ID]"
seen_messages: LRU
my_topics: Dict[str, "asyncio.Queue[rpc_pb2.Message]"]
# TODO: Implement `trio.abc.Channel`?
subscribed_topics_send: Dict[str, "trio.MemorySendChannel[rpc_pb2.Message]"]
subscribed_topics_receive: Dict[str, "trio.MemoryReceiveChannel[rpc_pb2.Message]"]
peer_topics: Dict[str, List[ID]]
peers: Dict[ID, INetStream]
@ -80,10 +84,8 @@ class Pubsub:
# TODO: Be sure it is increased atomically everytime.
counter: int # uint64
_tasks: List["asyncio.Future[Any]"]
def __init__(
self, host: IHost, router: "IPubsubRouter", my_id: ID, cache_size: int = None
self, host: IHost, router: "IPubsubRouter", cache_size: int = None
) -> None:
"""
Construct a new Pubsub object, which is responsible for handling all
@ -97,28 +99,26 @@ class Pubsub:
"""
self.host = host
self.router = router
self.my_id = my_id
# Attach this new Pubsub object to the router
self.router.attach(self)
peer_send_channel, peer_receive_channel = trio.open_memory_channel(0)
dead_peer_send_channel, dead_peer_receive_channel = trio.open_memory_channel(0)
# Only keep the receive channels in `Pubsub`.
# Therefore, we can only close from the receive side.
self.peer_receive_channel = peer_receive_channel
self.dead_peer_receive_channel = dead_peer_receive_channel
# Register a notifee
self.peer_queue = asyncio.Queue()
self.dead_peer_queue = asyncio.Queue()
self.host.get_network().register_notifee(
PubsubNotifee(self.peer_queue, self.dead_peer_queue)
PubsubNotifee(peer_send_channel, dead_peer_send_channel)
)
# Register stream handlers for each pubsub router protocol to handle
# the pubsub streams opened on those protocols
self.protocols = self.router.get_protocols()
for protocol in self.protocols:
for protocol in router.protocols:
self.host.set_stream_handler(protocol, self.stream_handler)
# Use asyncio queues for proper context switching
self.incoming_msgs_from_peers = asyncio.Queue()
self.outgoing_messages = asyncio.Queue()
# keeps track of seen messages as LRU cache
if cache_size is None:
self.cache_size = 128
@ -129,7 +129,8 @@ class Pubsub:
# Map of topics we are subscribed to blocking queues
# for when the given topic receives a message
self.my_topics = {}
self.subscribed_topics_send = {}
self.subscribed_topics_receive = {}
# Map of topic to peers to keep track of what peers are subscribed to
self.peer_topics = {}
@ -142,16 +143,28 @@ class Pubsub:
self.counter = time.time_ns()
self._tasks = []
# Call handle peer to keep waiting for updates to peer queue
self._tasks.append(asyncio.ensure_future(self.handle_peer_queue()))
self._tasks.append(asyncio.ensure_future(self.handle_dead_peer_queue()))
async def run(self) -> None:
self.manager.run_daemon_task(self.handle_peer_queue)
self.manager.run_daemon_task(self.handle_dead_peer_queue)
await self.manager.wait_finished()
@property
def my_id(self) -> ID:
return self.host.get_id()
@property
def protocols(self) -> Tuple[TProtocol, ...]:
return tuple(self.router.get_protocols())
@property
def topic_ids(self) -> KeysView[str]:
return self.subscribed_topics_receive.keys()
def get_hello_packet(self) -> rpc_pb2.RPC:
"""Generate subscription message with all topics we are subscribed to
only send hello packet if we have subscribed topics."""
packet = rpc_pb2.RPC()
for topic_id in self.my_topics:
for topic_id in self.topic_ids:
packet.subscriptions.extend(
[rpc_pb2.RPC.SubOpts(subscribe=True, topicid=topic_id)]
)
@ -166,7 +179,7 @@ class Pubsub:
"""
peer_id = stream.muxed_conn.peer_id
while True:
while self.manager.is_running:
incoming: bytes = await read_varint_prefixed_bytes(stream)
rpc_incoming: rpc_pb2.RPC = rpc_pb2.RPC()
rpc_incoming.ParseFromString(incoming)
@ -178,11 +191,7 @@ class Pubsub:
logger.debug(
"received `publish` message %s from peer %s", msg, peer_id
)
self._tasks.append(
asyncio.ensure_future(
self.push_msg(msg_forwarder=peer_id, msg=msg)
)
)
self.manager.run_task(self.push_msg, peer_id, msg)
if rpc_incoming.subscriptions:
# deal with RPC.subscriptions
@ -210,9 +219,6 @@ class Pubsub:
)
await self.router.handle_rpc(rpc_incoming, peer_id)
# Force context switch
await asyncio.sleep(0)
def set_topic_validator(
self, topic: str, validator: ValidatorFn, is_async_validator: bool
) -> None:
@ -285,7 +291,6 @@ class Pubsub:
logger.debug("Fail to add new peer %s: stream closed", peer_id)
del self.peers[peer_id]
return
# TODO: Check EOF of this stream.
# TODO: Check if the peer in black list.
try:
self.router.add_peer(peer_id, stream.get_protocol())
@ -311,21 +316,23 @@ class Pubsub:
async def handle_peer_queue(self) -> None:
"""
Continuously read from peer queue and each time a new peer is found,
Continuously read from peer channel and each time a new peer is found,
open a stream to the peer using a supported pubsub protocol
TODO: Handle failure for when the peer does not support any of the
pubsub protocols we support
"""
while True:
peer_id: ID = await self.peer_queue.get()
async with self.peer_receive_channel:
while self.manager.is_running:
peer_id: ID = await self.peer_receive_channel.receive()
# Add Peer
self._tasks.append(asyncio.ensure_future(self._handle_new_peer(peer_id)))
self.manager.run_task(self._handle_new_peer, peer_id)
async def handle_dead_peer_queue(self) -> None:
"""Continuously read from dead peer queue and close the stream between
"""Continuously read from dead peer channel and close the stream between
that peer and remove peer info from pubsub and pubsub router."""
while True:
peer_id: ID = await self.dead_peer_queue.get()
async with self.dead_peer_receive_channel:
while self.manager.is_running:
peer_id: ID = await self.dead_peer_receive_channel.receive()
# Remove Peer
self._handle_dead_peer(peer_id)
@ -361,13 +368,16 @@ class Pubsub:
# Check if this message has any topics that we are subscribed to
for topic in publish_message.topicIDs:
if topic in self.my_topics:
if topic in self.topic_ids:
# we are subscribed to a topic this message was sent for,
# so add message to the subscription output queue
# for each topic
await self.my_topics[topic].put(publish_message)
await self.subscribed_topics_send[topic].send(publish_message)
async def subscribe(self, topic_id: str) -> "asyncio.Queue[rpc_pb2.Message]":
# TODO: Change to return an `AsyncIterable` to be I/O-agnostic?
async def subscribe(
self, topic_id: str
) -> "trio.MemoryReceiveChannel[rpc_pb2.Message]":
"""
Subscribe ourself to a topic.
@ -377,11 +387,13 @@ class Pubsub:
logger.debug("subscribing to topic %s", topic_id)
# Already subscribed
if topic_id in self.my_topics:
return self.my_topics[topic_id]
if topic_id in self.topic_ids:
return self.subscribed_topics_receive[topic_id]
# Map topic_id to blocking queue
self.my_topics[topic_id] = asyncio.Queue()
# Map topic_id to a blocking channel
send_channel, receive_channel = trio.open_memory_channel(math.inf)
self.subscribed_topics_send[topic_id] = send_channel
self.subscribed_topics_receive[topic_id] = receive_channel
# Create subscribe message
packet: rpc_pb2.RPC = rpc_pb2.RPC()
@ -395,8 +407,8 @@ class Pubsub:
# Tell router we are joining this topic
await self.router.join(topic_id)
# Return the asyncio queue for messages on this topic
return self.my_topics[topic_id]
# Return the trio channel for messages on this topic
return receive_channel
async def unsubscribe(self, topic_id: str) -> None:
"""
@ -408,10 +420,14 @@ class Pubsub:
logger.debug("unsubscribing from topic %s", topic_id)
# Return if we already unsubscribed from the topic
if topic_id not in self.my_topics:
if topic_id not in self.topic_ids:
return
# Remove topic_id from map if present
del self.my_topics[topic_id]
# Remove topic_id from the maps before yielding
send_channel = self.subscribed_topics_send[topic_id]
del self.subscribed_topics_send[topic_id]
del self.subscribed_topics_receive[topic_id]
# Only close the send side
await send_channel.aclose()
# Create unsubscribe message
packet: rpc_pb2.RPC = rpc_pb2.RPC()
@ -453,13 +469,13 @@ class Pubsub:
data=data,
topicIDs=[topic_id],
# Origin is ourself.
from_id=self.host.get_id().to_bytes(),
from_id=self.my_id.to_bytes(),
seqno=self._next_seqno(),
)
# TODO: Sign with our signing key
await self.push_msg(self.host.get_id(), msg)
await self.push_msg(self.my_id, msg)
logger.debug("successfully published message %s", msg)
@ -470,12 +486,12 @@ class Pubsub:
:param msg_forwarder: the peer who forward us the message.
:param msg: the message.
"""
sync_topic_validators = []
async_topic_validator_futures: List[Awaitable[bool]] = []
sync_topic_validators: List[SyncValidatorFn] = []
async_topic_validators: List[AsyncValidatorFn] = []
for topic_validator in self.get_msg_validators(msg):
if topic_validator.is_async:
async_topic_validator_futures.append(
cast(Awaitable[bool], topic_validator.validator(msg_forwarder, msg))
async_topic_validators.append(
cast(AsyncValidatorFn, topic_validator.validator)
)
else:
sync_topic_validators.append(
@ -488,9 +504,20 @@ class Pubsub:
# TODO: Implement throttle on async validators
if len(async_topic_validator_futures) > 0:
results = await asyncio.gather(*async_topic_validator_futures)
if not all(results):
if len(async_topic_validators) > 0:
# TODO: Use a better pattern
final_result = True
async def run_async_validator(func: AsyncValidatorFn) -> None:
nonlocal final_result
result = await func(msg_forwarder, msg)
final_result = final_result and result
async with trio.open_nursery() as nursery:
for validator in async_topic_validators:
nursery.start_soon(run_async_validator, validator)
if not final_result:
raise ValidationError(f"Validation failed for msg={msg}")
async def push_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None:
@ -551,14 +578,4 @@ class Pubsub:
self.seen_messages[msg_id] = 1
def _is_subscribed_to_msg(self, msg: rpc_pb2.Message) -> bool:
if not self.my_topics:
return False
return any(topic in self.my_topics for topic in msg.topicIDs)
async def close(self) -> None:
for task in self._tasks:
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
return any(topic in self.topic_ids for topic in msg.topicIDs)

View File

@ -8,19 +8,19 @@ from libp2p.network.notifee_interface import INotifee
from libp2p.network.stream.net_stream_interface import INetStream
if TYPE_CHECKING:
import asyncio # noqa: F401
import trio # noqa: F401
from libp2p.peer.id import ID # noqa: F401
class PubsubNotifee(INotifee):
initiator_peers_queue: "asyncio.Queue[ID]"
dead_peers_queue: "asyncio.Queue[ID]"
initiator_peers_queue: "trio.MemorySendChannel[ID]"
dead_peers_queue: "trio.MemorySendChannel[ID]"
def __init__(
self,
initiator_peers_queue: "asyncio.Queue[ID]",
dead_peers_queue: "asyncio.Queue[ID]",
initiator_peers_queue: "trio.MemorySendChannel[ID]",
dead_peers_queue: "trio.MemorySendChannel[ID]",
) -> None:
"""
:param initiator_peers_queue: queue to add new peers to so that pubsub
@ -46,7 +46,12 @@ class PubsubNotifee(INotifee):
:param network: network the connection was opened on
:param conn: connection that was opened
"""
await self.initiator_peers_queue.put(conn.muxed_conn.peer_id)
try:
await self.initiator_peers_queue.send(conn.muxed_conn.peer_id)
except trio.BrokenResourceError:
# Raised when the receive channel is closed.
# TODO: Do something with loggers?
...
async def disconnected(self, network: INetwork, conn: INetConn) -> None:
"""
@ -56,7 +61,7 @@ class PubsubNotifee(INotifee):
:param network: network the connection was opened on
:param conn: connection that was opened
"""
await self.dead_peers_queue.put(conn.muxed_conn.peer_id)
await self.dead_peers_queue.send(conn.muxed_conn.peer_id)
async def listen(self, network: INetwork, multiaddr: Multiaddr) -> None:
pass

View File

@ -281,7 +281,7 @@ class Mplex(IMuxedConn, Service):
mplex_stream = await self._initialize_stream(stream_id, message.decode())
try:
await self.new_stream_send_channel.send(mplex_stream)
except (trio.BrokenResourceError, trio.EndOfChannel):
except (trio.BrokenResourceError, trio.ClosedResourceError):
raise MplexUnavailable
async def _handle_message(self, stream_id: StreamID, message: bytes) -> None:

View File

@ -5,6 +5,7 @@ from async_service import background_trio_service
import factory
import trio
from libp2p.tools.constants import GOSSIPSUB_PARAMS
from libp2p import generate_new_rsa_identity, generate_peer_id_from
from libp2p.crypto.keys import KeyPair
from libp2p.host.basic_host import BasicHost
@ -15,6 +16,7 @@ from libp2p.network.connection.swarm_connection import SwarmConn
from libp2p.network.stream.net_stream_interface import INetStream
from libp2p.network.swarm import Swarm
from libp2p.peer.peerstore import PeerStore
from libp2p.peer.id import ID
from libp2p.pubsub.floodsub import FloodSub
from libp2p.pubsub.gossipsub import GossipSub
from libp2p.pubsub.pubsub import Pubsub
@ -28,15 +30,19 @@ from libp2p.transport.typing import TMuxerOptions
from libp2p.transport.upgrader import TransportUpgrader
from libp2p.typing import TProtocol
from .constants import (
FLOODSUB_PROTOCOL_ID,
GOSSIPSUB_PARAMS,
GOSSIPSUB_PROTOCOL_ID,
LISTEN_MADDR,
)
from .constants import FLOODSUB_PROTOCOL_ID, GOSSIPSUB_PROTOCOL_ID, LISTEN_MADDR
from .utils import connect, connect_swarm
class IDFactory(factory.Factory):
class Meta:
model = ID
peer_id_bytes = factory.LazyFunction(
lambda: generate_peer_id_from(generate_new_rsa_identity())
)
def security_transport_factory(
is_secure: bool, key_pair: KeyPair
) -> Dict[TProtocol, BaseSecureTransport]:
@ -181,9 +187,38 @@ class PubsubFactory(factory.Factory):
host = factory.SubFactory(HostFactory)
router = None
my_id = factory.LazyAttribute(lambda obj: obj.host.get_id())
cache_size = None
@classmethod
@asynccontextmanager
async def create_and_start(cls, host, router, cache_size):
pubsub = PubsubFactory(host=host, router=router, cache_size=cache_size)
async with background_trio_service(pubsub):
yield pubsub
@classmethod
@asynccontextmanager
async def create_batch_with_floodsub(
cls, number: int, is_secure: bool = False, cache_size: int = None
):
floodsubs = FloodsubFactory.create_batch(number)
async with HostFactory.create_batch_and_listen(is_secure, number) as hosts:
# Pubsubs should exit before hosts
async with AsyncExitStack() as stack:
pubsubs = [
await stack.enter_async_context(
cls.create_and_start(host, router, cache_size)
)
for host, router in zip(hosts, floodsubs)
]
yield pubsubs
# @classmethod
# async def create_batch_with_gossipsub(
# cls, number: int, cache_size: int = None, gossipsub_params=GOSSIPSUB_PARAMS
# ):
# ...
@asynccontextmanager
async def swarm_pair_factory(

View File

@ -4,18 +4,6 @@ from libp2p.tools.constants import GOSSIPSUB_PARAMS
from libp2p.tools.factories import FloodsubFactory, GossipsubFactory, PubsubFactory
def _make_pubsubs(hosts, pubsub_routers, cache_size):
if len(pubsub_routers) != len(hosts):
raise ValueError(
f"lenght of pubsub_routers={pubsub_routers} should be equaled to the "
f"length of hosts={len(hosts)}"
)
return tuple(
PubsubFactory(host=host, router=router, cache_size=cache_size)
for host, router in zip(hosts, pubsub_routers)
)
@pytest.fixture
def pubsub_cache_size():
return None # default
@ -26,17 +14,9 @@ def gossipsub_params():
return GOSSIPSUB_PARAMS
@pytest.fixture
def pubsubs_fsub(num_hosts, hosts, pubsub_cache_size):
floodsubs = FloodsubFactory.create_batch(num_hosts)
_pubsubs_fsub = _make_pubsubs(hosts, floodsubs, pubsub_cache_size)
yield _pubsubs_fsub
# TODO: Clean up
@pytest.fixture
def pubsubs_gsub(num_hosts, hosts, pubsub_cache_size, gossipsub_params):
gossipsubs = GossipsubFactory.create_batch(num_hosts, **gossipsub_params._asdict())
_pubsubs_gsub = _make_pubsubs(hosts, gossipsubs, pubsub_cache_size)
yield _pubsubs_gsub
# TODO: Clean up
# @pytest.fixture
# def pubsubs_gsub(num_hosts, hosts, pubsub_cache_size, gossipsub_params):
# gossipsubs = GossipsubFactory.create_batch(num_hosts, **gossipsub_params._asdict())
# _pubsubs_gsub = _make_pubsubs(hosts, gossipsubs, pubsub_cache_size)
# yield _pubsubs_gsub
# # TODO: Clean up

View File

@ -1,73 +1,78 @@
import asyncio
from contextlib import contextmanager
from typing import NamedTuple
import pytest
import trio
from libp2p.exceptions import ValidationError
from libp2p.peer.id import ID
from libp2p.pubsub.pb import rpc_pb2
from libp2p.tools.pubsub.utils import make_pubsub_msg
from libp2p.tools.utils import connect
from libp2p.tools.constants import MAX_READ_LEN
from libp2p.tools.factories import PubsubFactory, net_stream_pair_factory, IDFactory
from libp2p.utils import encode_varint_prefixed
TESTING_TOPIC = "TEST_SUBSCRIBE"
TESTING_DATA = b"data"
@pytest.mark.parametrize("num_hosts", (1,))
@pytest.mark.asyncio
async def test_subscribe_and_unsubscribe(pubsubs_fsub):
@pytest.mark.trio
async def test_subscribe_and_unsubscribe():
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
assert TESTING_TOPIC in pubsubs_fsub[0].my_topics
assert TESTING_TOPIC in pubsubs_fsub[0].topic_ids
await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC)
assert TESTING_TOPIC not in pubsubs_fsub[0].my_topics
assert TESTING_TOPIC not in pubsubs_fsub[0].topic_ids
@pytest.mark.parametrize("num_hosts", (1,))
@pytest.mark.asyncio
async def test_re_subscribe(pubsubs_fsub):
@pytest.mark.trio
async def test_re_subscribe():
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
assert TESTING_TOPIC in pubsubs_fsub[0].my_topics
assert TESTING_TOPIC in pubsubs_fsub[0].topic_ids
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
assert TESTING_TOPIC in pubsubs_fsub[0].my_topics
assert TESTING_TOPIC in pubsubs_fsub[0].topic_ids
@pytest.mark.parametrize("num_hosts", (1,))
@pytest.mark.asyncio
async def test_re_unsubscribe(pubsubs_fsub):
@pytest.mark.trio
async def test_re_unsubscribe():
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
# Unsubscribe from topic we didn't even subscribe to
assert "NOT_MY_TOPIC" not in pubsubs_fsub[0].my_topics
assert "NOT_MY_TOPIC" not in pubsubs_fsub[0].topic_ids
await pubsubs_fsub[0].unsubscribe("NOT_MY_TOPIC")
assert "NOT_MY_TOPIC" not in pubsubs_fsub[0].my_topics
assert "NOT_MY_TOPIC" not in pubsubs_fsub[0].topic_ids
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
assert TESTING_TOPIC in pubsubs_fsub[0].my_topics
assert TESTING_TOPIC in pubsubs_fsub[0].topic_ids
await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC)
assert TESTING_TOPIC not in pubsubs_fsub[0].my_topics
assert TESTING_TOPIC not in pubsubs_fsub[0].topic_ids
await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC)
assert TESTING_TOPIC not in pubsubs_fsub[0].my_topics
assert TESTING_TOPIC not in pubsubs_fsub[0].topic_ids
@pytest.mark.asyncio
async def test_peers_subscribe(pubsubs_fsub):
@pytest.mark.trio
async def test_peers_subscribe():
async with PubsubFactory.create_batch_with_floodsub(2) as pubsubs_fsub:
await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host)
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
# Yield to let 0 notify 1
await asyncio.sleep(0.1)
await trio.sleep(0.1)
assert pubsubs_fsub[0].my_id in pubsubs_fsub[1].peer_topics[TESTING_TOPIC]
await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC)
# Yield to let 0 notify 1
await asyncio.sleep(0.1)
await trio.sleep(0.1)
assert pubsubs_fsub[0].my_id not in pubsubs_fsub[1].peer_topics[TESTING_TOPIC]
@pytest.mark.parametrize("num_hosts", (1,))
@pytest.mark.asyncio
async def test_get_hello_packet(pubsubs_fsub):
@pytest.mark.trio
async def test_get_hello_packet():
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
def _get_hello_packet_topic_ids():
packet = pubsubs_fsub[0].get_hello_packet()
return tuple(sub.topicid for sub in packet.subscriptions)
@ -77,16 +82,16 @@ async def test_get_hello_packet(pubsubs_fsub):
# Test: After subscriptions, topic ids should be in the hello packet.
topic_ids = ["t", "o", "p", "i", "c"]
await asyncio.gather(*[pubsubs_fsub[0].subscribe(topic) for topic in topic_ids])
for topic in topic_ids:
await pubsubs_fsub[0].subscribe(topic)
topic_ids_in_hello = _get_hello_packet_topic_ids()
for topic in topic_ids:
assert topic in topic_ids_in_hello
@pytest.mark.parametrize("num_hosts", (1,))
@pytest.mark.asyncio
async def test_set_and_remove_topic_validator(pubsubs_fsub):
@pytest.mark.trio
async def test_set_and_remove_topic_validator():
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
is_sync_validator_called = False
def sync_validator(peer_id, msg):
@ -111,7 +116,7 @@ async def test_set_and_remove_topic_validator(pubsubs_fsub):
assert not topic_validator.is_async
# Validate with sync validator
topic_validator.validator(peer_id=ID(b"peer"), msg="msg")
topic_validator.validator(peer_id=IDFactory(), msg="msg")
assert is_sync_validator_called
assert not is_async_validator_called
@ -125,7 +130,7 @@ async def test_set_and_remove_topic_validator(pubsubs_fsub):
assert topic_validator.is_async
# Validate with async validator
await topic_validator.validator(peer_id=ID(b"peer"), msg="msg")
await topic_validator.validator(peer_id=IDFactory(), msg="msg")
assert is_async_validator_called
assert not is_sync_validator_called
@ -135,10 +140,9 @@ async def test_set_and_remove_topic_validator(pubsubs_fsub):
assert topic not in pubsubs_fsub[0].topic_validators
@pytest.mark.parametrize("num_hosts", (1,))
@pytest.mark.asyncio
async def test_get_msg_validators(pubsubs_fsub):
@pytest.mark.trio
async def test_get_msg_validators():
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
times_sync_validator_called = 0
def sync_validator(peer_id, msg):
@ -172,21 +176,22 @@ async def test_get_msg_validators(pubsubs_fsub):
topic_validators = pubsubs_fsub[0].get_msg_validators(msg)
for topic_validator in topic_validators:
if topic_validator.is_async:
await topic_validator.validator(peer_id=ID(b"peer"), msg="msg")
await topic_validator.validator(peer_id=IDFactory(), msg="msg")
else:
topic_validator.validator(peer_id=ID(b"peer"), msg="msg")
topic_validator.validator(peer_id=IDFactory(), msg="msg")
assert times_sync_validator_called == 2
assert times_async_validator_called == 1
@pytest.mark.parametrize("num_hosts", (1,))
@pytest.mark.parametrize(
"is_topic_1_val_passed, is_topic_2_val_passed",
((False, True), (True, False), (True, True)),
)
@pytest.mark.asyncio
async def test_validate_msg(pubsubs_fsub, is_topic_1_val_passed, is_topic_2_val_passed):
@pytest.mark.trio
async def test_validate_msg(is_topic_1_val_passed, is_topic_2_val_passed):
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
def passed_sync_validator(peer_id, msg):
return True
@ -226,123 +231,98 @@ async def test_validate_msg(pubsubs_fsub, is_topic_1_val_passed, is_topic_2_val_
await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg)
class FakeNetStream:
_queue: asyncio.Queue
@pytest.mark.trio
async def test_continuously_read_stream(monkeypatch, nursery, is_host_secure):
async def wait_for_event_occurring(event):
with trio.fail_after(0.1):
await event.wait()
class FakeMplexConn(NamedTuple):
peer_id: ID = ID(b"\x12\x20" + b"\x00" * 32)
class Events(NamedTuple):
push_msg: trio.Event
handle_subscription: trio.Event
handle_rpc: trio.Event
muxed_conn = FakeMplexConn()
def __init__(self) -> None:
self._queue = asyncio.Queue()
async def read(self, n: int = -1) -> bytes:
buf = bytearray()
# Force to blocking wait if no data available now.
if self._queue.empty():
first_byte = await self._queue.get()
buf.extend(first_byte)
# If `n == -1`, read until no data is in the buffer(_queue).
# Else, read until no data is in the buffer(_queue) or we have read `n` bytes.
while (n == -1) or (len(buf) < n):
if self._queue.empty():
break
buf.extend(await self._queue.get())
return bytes(buf)
async def write(self, data: bytes) -> int:
for i in data:
await self._queue.put(i.to_bytes(1, "big"))
return len(data)
@pytest.mark.parametrize("num_hosts", (1,))
@pytest.mark.asyncio
async def test_continuously_read_stream(pubsubs_fsub, monkeypatch):
stream = FakeNetStream()
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
event_push_msg = asyncio.Event()
event_handle_subscription = asyncio.Event()
event_handle_rpc = asyncio.Event()
@contextmanager
def mock_methods():
event_push_msg = trio.Event()
event_handle_subscription = trio.Event()
event_handle_rpc = trio.Event()
async def mock_push_msg(msg_forwarder, msg):
event_push_msg.set()
await trio.sleep(0)
def mock_handle_subscription(origin_id, sub_message):
event_handle_subscription.set()
async def mock_handle_rpc(rpc, sender_peer_id):
event_handle_rpc.set()
await trio.sleep(0)
monkeypatch.setattr(pubsubs_fsub[0], "push_msg", mock_push_msg)
monkeypatch.setattr(
pubsubs_fsub[0], "handle_subscription", mock_handle_subscription
)
monkeypatch.setattr(pubsubs_fsub[0].router, "handle_rpc", mock_handle_rpc)
async def wait_for_event_occurring(event):
try:
await asyncio.wait_for(event.wait(), timeout=1)
except asyncio.TimeoutError as error:
event.clear()
raise asyncio.TimeoutError(
f"Event {event} is not set before the timeout. "
"This indicates the mocked functions are not called properly."
) from error
else:
event.clear()
with monkeypatch.context() as m:
m.setattr(pubsubs_fsub[0], "push_msg", mock_push_msg)
m.setattr(pubsubs_fsub[0], "handle_subscription", mock_handle_subscription)
m.setattr(pubsubs_fsub[0].router, "handle_rpc", mock_handle_rpc)
yield Events(event_push_msg, event_handle_subscription, event_handle_rpc)
async with PubsubFactory.create_batch_with_floodsub(
1, is_secure=is_host_secure
) as pubsubs_fsub, net_stream_pair_factory(is_secure=is_host_secure) as stream_pair:
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
# Kick off the task `continuously_read_stream`
task = asyncio.ensure_future(pubsubs_fsub[0].continuously_read_stream(stream))
nursery.start_soon(pubsubs_fsub[0].continuously_read_stream, stream_pair[0])
# Test: `push_msg` is called when publishing to a subscribed topic.
publish_subscribed_topic = rpc_pb2.RPC(
publish=[rpc_pb2.Message(topicIDs=[TESTING_TOPIC])]
)
await stream.write(
with mock_methods() as events:
await stream_pair[1].write(
encode_varint_prefixed(publish_subscribed_topic.SerializeToString())
)
await wait_for_event_occurring(event_push_msg)
await wait_for_event_occurring(events.push_msg)
# Make sure the other events are not emitted.
with pytest.raises(asyncio.TimeoutError):
await wait_for_event_occurring(event_handle_subscription)
with pytest.raises(asyncio.TimeoutError):
await wait_for_event_occurring(event_handle_rpc)
with pytest.raises(trio.TooSlowError):
await wait_for_event_occurring(events.handle_subscription)
with pytest.raises(trio.TooSlowError):
await wait_for_event_occurring(events.handle_rpc)
# Test: `push_msg` is not called when publishing to a topic-not-subscribed.
publish_not_subscribed_topic = rpc_pb2.RPC(
publish=[rpc_pb2.Message(topicIDs=["NOT_SUBSCRIBED"])]
)
await stream.write(
with mock_methods() as events:
await stream_pair[1].write(
encode_varint_prefixed(publish_not_subscribed_topic.SerializeToString())
)
with pytest.raises(asyncio.TimeoutError):
await wait_for_event_occurring(event_push_msg)
with pytest.raises(trio.TooSlowError):
await wait_for_event_occurring(events.push_msg)
# Test: `handle_subscription` is called when a subscription message is received.
subscription_msg = rpc_pb2.RPC(subscriptions=[rpc_pb2.RPC.SubOpts()])
await stream.write(encode_varint_prefixed(subscription_msg.SerializeToString()))
await wait_for_event_occurring(event_handle_subscription)
with mock_methods() as events:
await stream_pair[1].write(
encode_varint_prefixed(subscription_msg.SerializeToString())
)
await wait_for_event_occurring(events.handle_subscription)
# Make sure the other events are not emitted.
with pytest.raises(asyncio.TimeoutError):
await wait_for_event_occurring(event_push_msg)
with pytest.raises(asyncio.TimeoutError):
await wait_for_event_occurring(event_handle_rpc)
with pytest.raises(trio.TooSlowError):
await wait_for_event_occurring(events.push_msg)
with pytest.raises(trio.TooSlowError):
await wait_for_event_occurring(events.handle_rpc)
# Test: `handle_rpc` is called when a control message is received.
control_msg = rpc_pb2.RPC(control=rpc_pb2.ControlMessage())
await stream.write(encode_varint_prefixed(control_msg.SerializeToString()))
await wait_for_event_occurring(event_handle_rpc)
with mock_methods() as events:
await stream_pair[1].write(
encode_varint_prefixed(control_msg.SerializeToString())
)
await wait_for_event_occurring(events.handle_rpc)
# Make sure the other events are not emitted.
with pytest.raises(asyncio.TimeoutError):
await wait_for_event_occurring(event_push_msg)
with pytest.raises(asyncio.TimeoutError):
await wait_for_event_occurring(event_handle_subscription)
task.cancel()
with pytest.raises(trio.TooSlowError):
await wait_for_event_occurring(events.push_msg)
with pytest.raises(trio.TooSlowError):
await wait_for_event_occurring(events.handle_subscription)
# TODO: Add the following tests after they are aligned with Go.
@ -351,11 +331,12 @@ async def test_continuously_read_stream(pubsubs_fsub, monkeypatch):
# - `test_handle_peer_queue`
@pytest.mark.parametrize("num_hosts", (1,))
def test_handle_subscription(pubsubs_fsub):
@pytest.mark.trio
async def test_handle_subscription():
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
assert len(pubsubs_fsub[0].peer_topics) == 0
sub_msg_0 = rpc_pb2.RPC.SubOpts(subscribe=True, topicid=TESTING_TOPIC)
peer_ids = [ID(b"\x12\x20" + i.to_bytes(32, "big")) for i in range(2)]
peer_ids = [IDFactory() for _ in range(2)]
# Test: One peer is subscribed
pubsubs_fsub[0].handle_subscription(peer_ids[0], sub_msg_0)
assert (
@ -382,9 +363,9 @@ def test_handle_subscription(pubsubs_fsub):
assert peer_ids[0] not in pubsubs_fsub[0].peer_topics[TESTING_TOPIC]
@pytest.mark.parametrize("num_hosts", (1,))
@pytest.mark.asyncio
async def test_handle_talk(pubsubs_fsub):
@pytest.mark.trio
async def test_handle_talk():
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
sub = await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
msg_0 = make_pubsub_msg(
origin_id=pubsubs_fsub[0].my_id,
@ -401,31 +382,33 @@ async def test_handle_talk(pubsubs_fsub):
)
await pubsubs_fsub[0].handle_talk(msg_1)
assert (
len(pubsubs_fsub[0].my_topics) == 1
and sub == pubsubs_fsub[0].my_topics[TESTING_TOPIC]
len(pubsubs_fsub[0].topic_ids) == 1
and sub == pubsubs_fsub[0].subscribed_topics_receive[TESTING_TOPIC]
)
assert sub.qsize() == 1
assert (await sub.get()) == msg_0
assert (await sub.receive()) == msg_0
@pytest.mark.parametrize("num_hosts", (1,))
@pytest.mark.asyncio
async def test_message_all_peers(pubsubs_fsub, monkeypatch):
peer_ids = [ID(b"\x12\x20" + i.to_bytes(32, "big")) for i in range(10)]
mock_peers = {peer_id: FakeNetStream() for peer_id in peer_ids}
monkeypatch.setattr(pubsubs_fsub[0], "peers", mock_peers)
@pytest.mark.trio
async def test_message_all_peers(monkeypatch, is_host_secure):
async with PubsubFactory.create_batch_with_floodsub(
1, is_secure=is_host_secure
) as pubsubs_fsub, net_stream_pair_factory(is_secure=is_host_secure) as stream_pair:
peer_id = IDFactory()
mock_peers = {peer_id: stream_pair[0]}
with monkeypatch.context() as m:
m.setattr(pubsubs_fsub[0], "peers", mock_peers)
empty_rpc = rpc_pb2.RPC()
empty_rpc_bytes = empty_rpc.SerializeToString()
empty_rpc_bytes_len_prefixed = encode_varint_prefixed(empty_rpc_bytes)
await pubsubs_fsub[0].message_all_peers(empty_rpc_bytes)
for stream in mock_peers.values():
assert (await stream.read()) == empty_rpc_bytes_len_prefixed
assert (
await stream_pair[1].read(MAX_READ_LEN)
) == empty_rpc_bytes_len_prefixed
@pytest.mark.parametrize("num_hosts", (1,))
@pytest.mark.asyncio
async def test_publish(pubsubs_fsub, monkeypatch):
@pytest.mark.trio
async def test_publish(monkeypatch):
msg_forwarders = []
msgs = []
@ -433,21 +416,27 @@ async def test_publish(pubsubs_fsub, monkeypatch):
msg_forwarders.append(msg_forwarder)
msgs.append(msg)
monkeypatch.setattr(pubsubs_fsub[0], "push_msg", push_msg)
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
with monkeypatch.context() as m:
m.setattr(pubsubs_fsub[0], "push_msg", push_msg)
await pubsubs_fsub[0].publish(TESTING_TOPIC, TESTING_DATA)
await pubsubs_fsub[0].publish(TESTING_TOPIC, TESTING_DATA)
assert len(msgs) == 2, "`push_msg` should be called every time `publish` is called"
assert (
len(msgs) == 2
), "`push_msg` should be called every time `publish` is called"
assert (msg_forwarders[0] == msg_forwarders[1]) and (
msg_forwarders[1] == pubsubs_fsub[0].my_id
)
assert msgs[0].seqno != msgs[1].seqno, "`seqno` should be different every time"
assert (
msgs[0].seqno != msgs[1].seqno
), "`seqno` should be different every time"
@pytest.mark.parametrize("num_hosts", (1,))
@pytest.mark.asyncio
async def test_push_msg(pubsubs_fsub, monkeypatch):
@pytest.mark.trio
async def test_push_msg(monkeypatch):
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
msg_0 = make_pubsub_msg(
origin_id=pubsubs_fsub[0].my_id,
topic_ids=[TESTING_TOPIC],
@ -455,25 +444,33 @@ async def test_push_msg(pubsubs_fsub, monkeypatch):
seqno=b"\x00" * 8,
)
event = asyncio.Event()
@contextmanager
def mock_router_publish():
event = trio.Event()
async def router_publish(*args, **kwargs):
event.set()
await trio.sleep(0)
monkeypatch.setattr(pubsubs_fsub[0].router, "publish", router_publish)
with monkeypatch.context() as m:
m.setattr(pubsubs_fsub[0].router, "publish", router_publish)
yield event
with mock_router_publish() as event:
# Test: `msg` is not seen before `push_msg`, and is seen after `push_msg`.
assert not pubsubs_fsub[0]._is_msg_seen(msg_0)
await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_0)
assert pubsubs_fsub[0]._is_msg_seen(msg_0)
# Test: Ensure `router.publish` is called in `push_msg`
await asyncio.wait_for(event.wait(), timeout=0.1)
with trio.fail_after(0.1):
await event.wait()
with mock_router_publish() as event:
# Test: `push_msg` the message again and it will be reject.
# `router_publish` is not called then.
event.clear()
await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_0)
await asyncio.sleep(0.01)
await trio.sleep(0.01)
assert not event.is_set()
sub = await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
@ -487,17 +484,21 @@ async def test_push_msg(pubsubs_fsub, monkeypatch):
assert not pubsubs_fsub[0]._is_msg_seen(msg_1)
await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_1)
assert pubsubs_fsub[0]._is_msg_seen(msg_1)
await asyncio.wait_for(event.wait(), timeout=0.1)
with trio.fail_after(0.1):
await event.wait()
# Test: Subscribers are notified when `push_msg` new messages.
assert (await sub.get()) == msg_1
assert (await sub.receive()) == msg_1
with mock_router_publish() as event:
# Test: add a topic validator and `push_msg` the message that
# does not pass the validation.
# `router_publish` is not called then.
def failed_sync_validator(peer_id, msg):
return False
pubsubs_fsub[0].set_topic_validator(TESTING_TOPIC, failed_sync_validator, False)
pubsubs_fsub[0].set_topic_validator(
TESTING_TOPIC, failed_sync_validator, False
)
msg_2 = make_pubsub_msg(
origin_id=pubsubs_fsub[0].my_id,
@ -506,7 +507,6 @@ async def test_push_msg(pubsubs_fsub, monkeypatch):
seqno=b"\x22" * 8,
)
event.clear()
await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_2)
await asyncio.sleep(0.01)
await trio.sleep(0.01)
assert not event.is_set()