Add tests for Pubsub
- `test_get_hello_packet` - `test_continuously_read_stream` - `test_publish` - `test_push_msg`
This commit is contained in:
parent
550289a439
commit
037b95252d
@ -1,16 +1,23 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
import io
|
||||||
|
from typing import NamedTuple
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from libp2p.peer.id import ID
|
||||||
from libp2p.pubsub.pb import rpc_pb2
|
from libp2p.pubsub.pb import rpc_pb2
|
||||||
|
|
||||||
from tests.utils import (
|
from tests.utils import (
|
||||||
connect,
|
connect,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from .utils import (
|
||||||
|
make_pubsub_msg,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
TESTING_TOPIC = "TEST_SUBSCRIBE"
|
TESTING_TOPIC = "TEST_SUBSCRIBE"
|
||||||
TESTIND_DATA = b"data"
|
TESTING_DATA = b"data"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@ -101,3 +108,188 @@ async def test_get_hello_packet(pubsubs_fsub):
|
|||||||
for topic in topic_ids:
|
for topic in topic_ids:
|
||||||
assert topic in topic_ids_in_hello
|
assert topic in topic_ids_in_hello
|
||||||
|
|
||||||
|
|
||||||
|
class FakeNetStream:
|
||||||
|
_queue: asyncio.Queue
|
||||||
|
|
||||||
|
class FakeMplexConn(NamedTuple):
|
||||||
|
peer_id: ID = ID(b"\x12\x20" + b"\x00" * 32)
|
||||||
|
|
||||||
|
mplex_conn = FakeMplexConn()
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._queue = asyncio.Queue()
|
||||||
|
|
||||||
|
async def read(self) -> bytes:
|
||||||
|
buf = io.BytesIO()
|
||||||
|
while not self._queue.empty():
|
||||||
|
buf.write(await self._queue.get())
|
||||||
|
return buf.getvalue()
|
||||||
|
|
||||||
|
async def write(self, data: bytes) -> int:
|
||||||
|
for i in data:
|
||||||
|
await self._queue.put(i.to_bytes(1, 'big'))
|
||||||
|
return len(data)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"num_hosts",
|
||||||
|
(1,),
|
||||||
|
)
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_continuously_read_stream(pubsubs_fsub, monkeypatch):
|
||||||
|
s = FakeNetStream()
|
||||||
|
|
||||||
|
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
|
||||||
|
|
||||||
|
event_push_msg = asyncio.Event()
|
||||||
|
event_handle_subscription = asyncio.Event()
|
||||||
|
event_handle_rpc = asyncio.Event()
|
||||||
|
|
||||||
|
async def mock_push_msg(msg_forwarder, msg):
|
||||||
|
event_push_msg.set()
|
||||||
|
|
||||||
|
def mock_handle_subscription(origin_id, sub_message):
|
||||||
|
event_handle_subscription.set()
|
||||||
|
|
||||||
|
async def mock_handle_rpc(rpc, sender_peer_id):
|
||||||
|
event_handle_rpc.set()
|
||||||
|
|
||||||
|
monkeypatch.setattr(pubsubs_fsub[0], "push_msg", mock_push_msg)
|
||||||
|
monkeypatch.setattr(pubsubs_fsub[0], "handle_subscription", mock_handle_subscription)
|
||||||
|
monkeypatch.setattr(pubsubs_fsub[0].router, "handle_rpc", mock_handle_rpc)
|
||||||
|
|
||||||
|
async def wait_for_event_occurring(event):
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(event.wait(), timeout=0.01)
|
||||||
|
except asyncio.TimeoutError as error:
|
||||||
|
event.clear()
|
||||||
|
raise asyncio.TimeoutError(
|
||||||
|
f"Event {event} is not set before the timeout. "
|
||||||
|
"This indicates the mocked functions are not called properly."
|
||||||
|
) from error
|
||||||
|
else:
|
||||||
|
event.clear()
|
||||||
|
|
||||||
|
# Kick off the task `continuously_read_stream`
|
||||||
|
task = asyncio.ensure_future(pubsubs_fsub[0].continuously_read_stream(s))
|
||||||
|
|
||||||
|
# Test: `push_msg` is called when publishing to a subscribed topic.
|
||||||
|
publish_subscribed_topic = rpc_pb2.RPC(
|
||||||
|
publish=[rpc_pb2.Message(
|
||||||
|
topicIDs=[TESTING_TOPIC]
|
||||||
|
)],
|
||||||
|
)
|
||||||
|
await s.write(publish_subscribed_topic.SerializeToString())
|
||||||
|
await wait_for_event_occurring(event_push_msg)
|
||||||
|
# Make sure the other events are not emitted.
|
||||||
|
with pytest.raises(asyncio.TimeoutError):
|
||||||
|
await wait_for_event_occurring(event_handle_subscription)
|
||||||
|
with pytest.raises(asyncio.TimeoutError):
|
||||||
|
await wait_for_event_occurring(event_handle_rpc)
|
||||||
|
|
||||||
|
# Test: `push_msg` is not called when publishing to a topic-not-subscribed.
|
||||||
|
publish_not_subscribed_topic = rpc_pb2.RPC(
|
||||||
|
publish=[rpc_pb2.Message(
|
||||||
|
topicIDs=["NOT_SUBSCRIBED"]
|
||||||
|
)],
|
||||||
|
)
|
||||||
|
await s.write(publish_not_subscribed_topic.SerializeToString())
|
||||||
|
with pytest.raises(asyncio.TimeoutError):
|
||||||
|
await wait_for_event_occurring(event_push_msg)
|
||||||
|
|
||||||
|
# Test: `handle_subscription` is called when a subscription message is received.
|
||||||
|
subscription_msg = rpc_pb2.RPC(
|
||||||
|
subscriptions=[rpc_pb2.RPC.SubOpts()],
|
||||||
|
)
|
||||||
|
await s.write(subscription_msg.SerializeToString())
|
||||||
|
await wait_for_event_occurring(event_handle_subscription)
|
||||||
|
# Make sure the other events are not emitted.
|
||||||
|
with pytest.raises(asyncio.TimeoutError):
|
||||||
|
await wait_for_event_occurring(event_push_msg)
|
||||||
|
with pytest.raises(asyncio.TimeoutError):
|
||||||
|
await wait_for_event_occurring(event_handle_rpc)
|
||||||
|
|
||||||
|
# Test: `handle_rpc` is called when a control message is received.
|
||||||
|
control_msg = rpc_pb2.RPC(control=rpc_pb2.ControlMessage())
|
||||||
|
await s.write(control_msg.SerializeToString())
|
||||||
|
await wait_for_event_occurring(event_handle_rpc)
|
||||||
|
# Make sure the other events are not emitted.
|
||||||
|
with pytest.raises(asyncio.TimeoutError):
|
||||||
|
await wait_for_event_occurring(event_push_msg)
|
||||||
|
with pytest.raises(asyncio.TimeoutError):
|
||||||
|
await wait_for_event_occurring(event_handle_subscription)
|
||||||
|
|
||||||
|
task.cancel()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"num_hosts",
|
||||||
|
(2,),
|
||||||
|
)
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_publish(pubsubs_fsub, monkeypatch):
|
||||||
|
msg_forwarders = []
|
||||||
|
msgs = []
|
||||||
|
|
||||||
|
async def push_msg(msg_forwarder, msg):
|
||||||
|
msg_forwarders.append(msg_forwarder)
|
||||||
|
msgs.append(msg)
|
||||||
|
monkeypatch.setattr(pubsubs_fsub[0], "push_msg", push_msg)
|
||||||
|
|
||||||
|
await pubsubs_fsub[0].publish(TESTING_TOPIC, TESTING_DATA)
|
||||||
|
await pubsubs_fsub[0].publish(TESTING_TOPIC, TESTING_DATA)
|
||||||
|
|
||||||
|
assert len(msgs) == 2, "`push_msg` should be called every time `publish` is called"
|
||||||
|
assert (msg_forwarders[0] == msg_forwarders[1]) and (msg_forwarders[1] == pubsubs_fsub[0].my_id)
|
||||||
|
assert msgs[0].seqno != msgs[1].seqno, "`seqno` should be different every time"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"num_hosts",
|
||||||
|
(1,),
|
||||||
|
)
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_push_msg(pubsubs_fsub, monkeypatch):
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
msg_0 = make_pubsub_msg(
|
||||||
|
origin_id=pubsubs_fsub[0].my_id,
|
||||||
|
topic_ids=[TESTING_TOPIC],
|
||||||
|
data=TESTING_DATA,
|
||||||
|
seqno=b"\x00" * 8,
|
||||||
|
)
|
||||||
|
|
||||||
|
event = asyncio.Event()
|
||||||
|
|
||||||
|
async def router_publish(*args, **kwargs):
|
||||||
|
event.set()
|
||||||
|
monkeypatch.setattr(pubsubs_fsub[0].router, "publish", router_publish)
|
||||||
|
|
||||||
|
# Test: `msg` is not seen before `push_msg`, and is seen after `push_msg`.
|
||||||
|
assert not pubsubs_fsub[0]._is_msg_seen(msg_0)
|
||||||
|
await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_0)
|
||||||
|
assert pubsubs_fsub[0]._is_msg_seen(msg_0)
|
||||||
|
# Test: Ensure `router.publish` is called in `push_msg`
|
||||||
|
await asyncio.wait_for(event.wait(), timeout=0.1)
|
||||||
|
|
||||||
|
# Test: `push_msg` the message again and it will be reject.
|
||||||
|
# `router_publish` is not called then.
|
||||||
|
event.clear()
|
||||||
|
await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_0)
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
assert not event.is_set()
|
||||||
|
|
||||||
|
sub = await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
|
||||||
|
# Test: `push_msg` succeeds with another unseen msg.
|
||||||
|
msg_1 = make_pubsub_msg(
|
||||||
|
origin_id=pubsubs_fsub[0].my_id,
|
||||||
|
topic_ids=[TESTING_TOPIC],
|
||||||
|
data=TESTING_DATA,
|
||||||
|
seqno=b"\x11" * 8,
|
||||||
|
)
|
||||||
|
assert not pubsubs_fsub[0]._is_msg_seen(msg_1)
|
||||||
|
await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_1)
|
||||||
|
assert pubsubs_fsub[0]._is_msg_seen(msg_1)
|
||||||
|
await asyncio.wait_for(event.wait(), timeout=0.1)
|
||||||
|
# Test: Subscribers are notified when `push_msg` new messages.
|
||||||
|
assert (await sub.get()) == msg_1
|
||||||
|
@ -14,6 +14,8 @@ from libp2p.pubsub.pubsub import Pubsub
|
|||||||
|
|
||||||
from tests.utils import connect
|
from tests.utils import connect
|
||||||
|
|
||||||
|
from .configs import LISTEN_MADDR
|
||||||
|
|
||||||
|
|
||||||
def message_id_generator(start_val):
|
def message_id_generator(start_val):
|
||||||
"""
|
"""
|
||||||
@ -80,13 +82,13 @@ async def create_libp2p_hosts(num_hosts):
|
|||||||
tasks_create = []
|
tasks_create = []
|
||||||
for i in range(0, num_hosts):
|
for i in range(0, num_hosts):
|
||||||
# Create node
|
# Create node
|
||||||
tasks_create.append(new_node(transport_opt=["/ip4/127.0.0.1/tcp/0"]))
|
tasks_create.append(new_node(transport_opt=[str(LISTEN_MADDR)]))
|
||||||
hosts = await asyncio.gather(*tasks_create)
|
hosts = await asyncio.gather(*tasks_create)
|
||||||
|
|
||||||
tasks_listen = []
|
tasks_listen = []
|
||||||
for node in hosts:
|
for node in hosts:
|
||||||
# Start listener
|
# Start listener
|
||||||
tasks_listen.append(asyncio.ensure_future(node.get_network().listen(multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/0"))))
|
tasks_listen.append(node.get_network().listen(LISTEN_MADDR))
|
||||||
await asyncio.gather(*tasks_listen)
|
await asyncio.gather(*tasks_listen)
|
||||||
|
|
||||||
return hosts
|
return hosts
|
||||||
@ -109,7 +111,7 @@ def create_pubsub_and_gossipsub_instances(
|
|||||||
degree_low, degree_high, time_to_live,
|
degree_low, degree_high, time_to_live,
|
||||||
gossip_window, gossip_history,
|
gossip_window, gossip_history,
|
||||||
heartbeat_interval)
|
heartbeat_interval)
|
||||||
pubsub = Pubsub(node, gossipsub, "a")
|
pubsub = Pubsub(node, gossipsub, node.get_id())
|
||||||
pubsubs.append(pubsub)
|
pubsubs.append(pubsub)
|
||||||
gossipsubs.append(gossipsub)
|
gossipsubs.append(gossipsub)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user