Merge pull request #190 from mhchia/feature/pubsub-test
Add `test_pubsub.py`
This commit is contained in:
commit
1727ba48d9
|
@ -118,12 +118,14 @@ class Pubsub:
|
|||
Generate subscription message with all topics we are subscribed to
|
||||
only send hello packet if we have subscribed topics
|
||||
"""
|
||||
packet: rpc_pb2.RPC = rpc_pb2.RPC()
|
||||
if self.my_topics:
|
||||
for topic_id in self.my_topics:
|
||||
packet.subscriptions.extend([rpc_pb2.RPC.SubOpts(
|
||||
subscribe=True, topicid=topic_id)])
|
||||
|
||||
packet = rpc_pb2.RPC()
|
||||
for topic_id in self.my_topics:
|
||||
packet.subscriptions.extend([
|
||||
rpc_pb2.RPC.SubOpts(
|
||||
subscribe=True,
|
||||
topicid=topic_id,
|
||||
)
|
||||
])
|
||||
return packet.SerializeToString()
|
||||
|
||||
async def continuously_read_stream(self, stream: INetStream) -> None:
|
||||
|
@ -157,7 +159,11 @@ class Pubsub:
|
|||
for message in rpc_incoming.subscriptions:
|
||||
self.handle_subscription(peer_id, message)
|
||||
|
||||
if rpc_incoming.control:
|
||||
# pylint: disable=line-too-long
|
||||
# NOTE: Check if `rpc_incoming.control` is set through `HasField`.
|
||||
# This is necessary because `control` is an optional field in pb2.
|
||||
# Ref: https://developers.google.com/protocol-buffers/docs/reference/python-generated#singular-fields-proto2
|
||||
if rpc_incoming.HasField("control"):
|
||||
# Pass rpc to router so router could perform custom logic
|
||||
await self.router.handle_rpc(rpc_incoming, peer_id)
|
||||
|
||||
|
@ -182,6 +188,8 @@ class Pubsub:
|
|||
await stream.write(hello)
|
||||
# Pass stream off to stream reader
|
||||
asyncio.ensure_future(self.continuously_read_stream(stream))
|
||||
# Force context switch
|
||||
await asyncio.sleep(0)
|
||||
|
||||
async def handle_peer_queue(self) -> None:
|
||||
"""
|
||||
|
@ -208,6 +216,9 @@ class Pubsub:
|
|||
hello: bytes = self.get_hello_packet()
|
||||
await stream.write(hello)
|
||||
|
||||
# pylint: disable=line-too-long
|
||||
# TODO: Investigate whether this should be replaced by `handlePeerEOF`
|
||||
# Ref: https://github.com/libp2p/go-libp2p-pubsub/blob/49274b0e8aecdf6cad59d768e5702ff00aa48488/comm.go#L80 # noqa: E501
|
||||
# Pass stream off to stream reader
|
||||
asyncio.ensure_future(self.continuously_read_stream(stream))
|
||||
|
||||
|
@ -312,7 +323,7 @@ class Pubsub:
|
|||
"""
|
||||
|
||||
# Broadcast message
|
||||
for _, stream in self.peers.items():
|
||||
for stream in self.peers.values():
|
||||
# Write message to stream
|
||||
await stream.write(raw_msg)
|
||||
|
||||
|
@ -340,11 +351,11 @@ class Pubsub:
|
|||
:param msg_forwarder: the peer who forward us the message.
|
||||
:param msg: the message we are going to push out.
|
||||
"""
|
||||
# TODO: - Check if the `source` is in the blacklist. If yes, reject.
|
||||
# TODO: Check if the `source` is in the blacklist. If yes, reject.
|
||||
|
||||
# TODO: - Check if the `from` is in the blacklist. If yes, reject.
|
||||
# TODO: Check if the `from` is in the blacklist. If yes, reject.
|
||||
|
||||
# TODO: - Check if signing is required and if so signature should be attached.
|
||||
# TODO: Check if signing is required and if so signature should be attached.
|
||||
|
||||
if self._is_msg_seen(msg):
|
||||
return
|
||||
|
|
113
tests/pubsub/conftest.py
Normal file
113
tests/pubsub/conftest.py
Normal file
|
@ -0,0 +1,113 @@
|
|||
import asyncio
|
||||
from typing import NamedTuple
|
||||
|
||||
import pytest
|
||||
|
||||
from multiaddr import Multiaddr
|
||||
|
||||
from libp2p import new_node
|
||||
|
||||
from libp2p.pubsub.floodsub import FloodSub
|
||||
from libp2p.pubsub.gossipsub import GossipSub
|
||||
from libp2p.pubsub.pubsub import Pubsub
|
||||
|
||||
from .configs import (
|
||||
FLOODSUB_PROTOCOL_ID,
|
||||
GOSSIPSUB_PROTOCOL_ID,
|
||||
LISTEN_MADDR,
|
||||
)
|
||||
|
||||
|
||||
# pylint: disable=redefined-outer-name
|
||||
|
||||
@pytest.fixture
|
||||
def num_hosts():
|
||||
return 3
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def hosts(num_hosts):
|
||||
_hosts = await asyncio.gather(*[
|
||||
new_node(transport_opt=[str(LISTEN_MADDR)])
|
||||
for _ in range(num_hosts)
|
||||
])
|
||||
await asyncio.gather(*[
|
||||
_host.get_network().listen(LISTEN_MADDR)
|
||||
for _host in _hosts
|
||||
])
|
||||
yield _hosts
|
||||
# Clean up
|
||||
listeners = []
|
||||
for _host in _hosts:
|
||||
for listener in _host.get_network().listeners.values():
|
||||
listener.server.close()
|
||||
listeners.append(listener)
|
||||
await asyncio.gather(*[
|
||||
listener.server.wait_closed()
|
||||
for listener in listeners
|
||||
])
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def floodsubs(num_hosts):
|
||||
return tuple(
|
||||
FloodSub(protocols=[FLOODSUB_PROTOCOL_ID])
|
||||
for _ in range(num_hosts)
|
||||
)
|
||||
|
||||
|
||||
class GossipsubParams(NamedTuple):
|
||||
degree: int = 10
|
||||
degree_low: int = 9
|
||||
degree_high: int = 11
|
||||
fanout_ttl: int = 30
|
||||
gossip_window: int = 3
|
||||
gossip_history: int = 5
|
||||
heartbeat_interval: float = 0.5
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def gossipsub_params():
|
||||
return GossipsubParams()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def gossipsubs(num_hosts, gossipsub_params):
|
||||
yield tuple(
|
||||
GossipSub(
|
||||
protocols=[GOSSIPSUB_PROTOCOL_ID],
|
||||
**gossipsub_params._asdict(),
|
||||
)
|
||||
for _ in range(num_hosts)
|
||||
)
|
||||
# TODO: Clean up
|
||||
|
||||
|
||||
def _make_pubsubs(hosts, pubsub_routers):
|
||||
if len(pubsub_routers) != len(hosts):
|
||||
raise ValueError(
|
||||
f"lenght of pubsub_routers={pubsub_routers} should be equaled to the "
|
||||
f"length of hosts={len(hosts)}"
|
||||
)
|
||||
return tuple(
|
||||
Pubsub(
|
||||
host=host,
|
||||
router=router,
|
||||
my_id=host.get_id(),
|
||||
)
|
||||
for host, router in zip(hosts, pubsub_routers)
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pubsubs_fsub(hosts, floodsubs):
|
||||
_pubsubs_fsub = _make_pubsubs(hosts, floodsubs)
|
||||
yield _pubsubs_fsub
|
||||
# TODO: Clean up
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pubsubs_gsub(hosts, gossipsubs):
|
||||
_pubsubs_gsub = _make_pubsubs(hosts, gossipsubs)
|
||||
yield _pubsubs_gsub
|
||||
# TODO: Clean up
|
|
@ -39,9 +39,9 @@ async def test_simple_two_nodes():
|
|||
data = b"some data"
|
||||
|
||||
floodsub_a = FloodSub(supported_protocols)
|
||||
pubsub_a = Pubsub(node_a, floodsub_a, ID(b"a" * 32))
|
||||
pubsub_a = Pubsub(node_a, floodsub_a, ID(b"\x12\x20" + b"a" * 32))
|
||||
floodsub_b = FloodSub(supported_protocols)
|
||||
pubsub_b = Pubsub(node_b, floodsub_b, ID(b"b" * 32))
|
||||
pubsub_b = Pubsub(node_b, floodsub_b, ID(b"\x12\x20" + b"a" * 32))
|
||||
|
||||
await connect(node_a, node_b)
|
||||
await asyncio.sleep(0.25)
|
||||
|
|
|
@ -3,8 +3,8 @@ import pytest
|
|||
from libp2p.pubsub.mcache import MessageCache
|
||||
|
||||
|
||||
# pylint: disable=too-few-public-methods
|
||||
class Msg:
|
||||
|
||||
def __init__(self, topicIDs, seqno, from_id):
|
||||
# pylint: disable=invalid-name
|
||||
self.topicIDs = topicIDs
|
||||
|
@ -15,8 +15,7 @@ class Msg:
|
|||
@pytest.mark.asyncio
|
||||
async def test_mcache():
|
||||
# Ported from:
|
||||
# https://github.com/libp2p/go-libp2p-pubsub
|
||||
# /blob/51b7501433411b5096cac2b4994a36a68515fc03/mcache_test.go
|
||||
# https://github.com/libp2p/go-libp2p-pubsub/blob/51b7501433411b5096cac2b4994a36a68515fc03/mcache_test.go
|
||||
mcache = MessageCache(3, 5)
|
||||
msgs = []
|
||||
|
||||
|
|
392
tests/pubsub/test_pubsub.py
Normal file
392
tests/pubsub/test_pubsub.py
Normal file
|
@ -0,0 +1,392 @@
|
|||
import asyncio
|
||||
import io
|
||||
from typing import NamedTuple
|
||||
|
||||
import pytest
|
||||
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.pubsub.pb import rpc_pb2
|
||||
|
||||
from tests.utils import (
|
||||
connect,
|
||||
)
|
||||
|
||||
from .utils import (
|
||||
make_pubsub_msg,
|
||||
)
|
||||
|
||||
|
||||
TESTING_TOPIC = "TEST_SUBSCRIBE"
|
||||
TESTING_DATA = b"data"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"num_hosts",
|
||||
(1,),
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_subscribe_and_unsubscribe(pubsubs_fsub):
|
||||
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
|
||||
assert TESTING_TOPIC in pubsubs_fsub[0].my_topics
|
||||
|
||||
await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC)
|
||||
assert TESTING_TOPIC not in pubsubs_fsub[0].my_topics
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"num_hosts",
|
||||
(1,),
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_re_subscribe(pubsubs_fsub):
|
||||
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
|
||||
assert TESTING_TOPIC in pubsubs_fsub[0].my_topics
|
||||
|
||||
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
|
||||
assert TESTING_TOPIC in pubsubs_fsub[0].my_topics
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"num_hosts",
|
||||
(1,),
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_re_unsubscribe(pubsubs_fsub):
|
||||
# Unsubscribe from topic we didn't even subscribe to
|
||||
assert "NOT_MY_TOPIC" not in pubsubs_fsub[0].my_topics
|
||||
await pubsubs_fsub[0].unsubscribe("NOT_MY_TOPIC")
|
||||
assert "NOT_MY_TOPIC" not in pubsubs_fsub[0].my_topics
|
||||
|
||||
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
|
||||
assert TESTING_TOPIC in pubsubs_fsub[0].my_topics
|
||||
|
||||
await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC)
|
||||
assert TESTING_TOPIC not in pubsubs_fsub[0].my_topics
|
||||
|
||||
await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC)
|
||||
assert TESTING_TOPIC not in pubsubs_fsub[0].my_topics
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_peers_subscribe(pubsubs_fsub):
|
||||
await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host)
|
||||
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
|
||||
# Yield to let 0 notify 1
|
||||
await asyncio.sleep(0.1)
|
||||
assert str(pubsubs_fsub[0].my_id) in pubsubs_fsub[1].peer_topics[TESTING_TOPIC]
|
||||
await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC)
|
||||
# Yield to let 0 notify 1
|
||||
await asyncio.sleep(0.1)
|
||||
assert str(pubsubs_fsub[0].my_id) not in pubsubs_fsub[1].peer_topics[TESTING_TOPIC]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"num_hosts",
|
||||
(1,),
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_hello_packet(pubsubs_fsub):
|
||||
def _get_hello_packet_topic_ids():
|
||||
packet = rpc_pb2.RPC()
|
||||
packet.ParseFromString(pubsubs_fsub[0].get_hello_packet())
|
||||
return tuple(
|
||||
sub.topicid
|
||||
for sub in packet.subscriptions
|
||||
)
|
||||
|
||||
# pylint: disable=len-as-condition
|
||||
# Test: No subscription, so there should not be any topic ids in the hello packet.
|
||||
assert len(_get_hello_packet_topic_ids()) == 0
|
||||
|
||||
# Test: After subscriptions, topic ids should be in the hello packet.
|
||||
topic_ids = ["t", "o", "p", "i", "c"]
|
||||
await asyncio.gather(*[
|
||||
pubsubs_fsub[0].subscribe(topic)
|
||||
for topic in topic_ids
|
||||
])
|
||||
topic_ids_in_hello = _get_hello_packet_topic_ids()
|
||||
for topic in topic_ids:
|
||||
assert topic in topic_ids_in_hello
|
||||
|
||||
|
||||
class FakeNetStream:
|
||||
_queue: asyncio.Queue
|
||||
|
||||
class FakeMplexConn(NamedTuple):
|
||||
peer_id: ID = ID(b"\x12\x20" + b"\x00" * 32)
|
||||
|
||||
mplex_conn = FakeMplexConn()
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._queue = asyncio.Queue()
|
||||
|
||||
async def read(self) -> bytes:
|
||||
buf = io.BytesIO()
|
||||
while not self._queue.empty():
|
||||
buf.write(await self._queue.get())
|
||||
return buf.getvalue()
|
||||
|
||||
async def write(self, data: bytes) -> int:
|
||||
for i in data:
|
||||
await self._queue.put(i.to_bytes(1, 'big'))
|
||||
return len(data)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"num_hosts",
|
||||
(1,),
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_continuously_read_stream(pubsubs_fsub, monkeypatch):
|
||||
stream = FakeNetStream()
|
||||
|
||||
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
|
||||
|
||||
event_push_msg = asyncio.Event()
|
||||
event_handle_subscription = asyncio.Event()
|
||||
event_handle_rpc = asyncio.Event()
|
||||
|
||||
async def mock_push_msg(msg_forwarder, msg):
|
||||
event_push_msg.set()
|
||||
|
||||
def mock_handle_subscription(origin_id, sub_message):
|
||||
event_handle_subscription.set()
|
||||
|
||||
async def mock_handle_rpc(rpc, sender_peer_id):
|
||||
event_handle_rpc.set()
|
||||
|
||||
monkeypatch.setattr(pubsubs_fsub[0], "push_msg", mock_push_msg)
|
||||
monkeypatch.setattr(pubsubs_fsub[0], "handle_subscription", mock_handle_subscription)
|
||||
monkeypatch.setattr(pubsubs_fsub[0].router, "handle_rpc", mock_handle_rpc)
|
||||
|
||||
async def wait_for_event_occurring(event):
|
||||
try:
|
||||
await asyncio.wait_for(event.wait(), timeout=0.01)
|
||||
except asyncio.TimeoutError as error:
|
||||
event.clear()
|
||||
raise asyncio.TimeoutError(
|
||||
f"Event {event} is not set before the timeout. "
|
||||
"This indicates the mocked functions are not called properly."
|
||||
) from error
|
||||
else:
|
||||
event.clear()
|
||||
|
||||
# Kick off the task `continuously_read_stream`
|
||||
task = asyncio.ensure_future(pubsubs_fsub[0].continuously_read_stream(stream))
|
||||
|
||||
# Test: `push_msg` is called when publishing to a subscribed topic.
|
||||
publish_subscribed_topic = rpc_pb2.RPC(
|
||||
publish=[rpc_pb2.Message(
|
||||
topicIDs=[TESTING_TOPIC]
|
||||
)],
|
||||
)
|
||||
await stream.write(publish_subscribed_topic.SerializeToString())
|
||||
await wait_for_event_occurring(event_push_msg)
|
||||
# Make sure the other events are not emitted.
|
||||
with pytest.raises(asyncio.TimeoutError):
|
||||
await wait_for_event_occurring(event_handle_subscription)
|
||||
with pytest.raises(asyncio.TimeoutError):
|
||||
await wait_for_event_occurring(event_handle_rpc)
|
||||
|
||||
# Test: `push_msg` is not called when publishing to a topic-not-subscribed.
|
||||
publish_not_subscribed_topic = rpc_pb2.RPC(
|
||||
publish=[rpc_pb2.Message(
|
||||
topicIDs=["NOT_SUBSCRIBED"]
|
||||
)],
|
||||
)
|
||||
await stream.write(publish_not_subscribed_topic.SerializeToString())
|
||||
with pytest.raises(asyncio.TimeoutError):
|
||||
await wait_for_event_occurring(event_push_msg)
|
||||
|
||||
# Test: `handle_subscription` is called when a subscription message is received.
|
||||
subscription_msg = rpc_pb2.RPC(
|
||||
subscriptions=[rpc_pb2.RPC.SubOpts()],
|
||||
)
|
||||
await stream.write(subscription_msg.SerializeToString())
|
||||
await wait_for_event_occurring(event_handle_subscription)
|
||||
# Make sure the other events are not emitted.
|
||||
with pytest.raises(asyncio.TimeoutError):
|
||||
await wait_for_event_occurring(event_push_msg)
|
||||
with pytest.raises(asyncio.TimeoutError):
|
||||
await wait_for_event_occurring(event_handle_rpc)
|
||||
|
||||
# Test: `handle_rpc` is called when a control message is received.
|
||||
control_msg = rpc_pb2.RPC(control=rpc_pb2.ControlMessage())
|
||||
await stream.write(control_msg.SerializeToString())
|
||||
await wait_for_event_occurring(event_handle_rpc)
|
||||
# Make sure the other events are not emitted.
|
||||
with pytest.raises(asyncio.TimeoutError):
|
||||
await wait_for_event_occurring(event_push_msg)
|
||||
with pytest.raises(asyncio.TimeoutError):
|
||||
await wait_for_event_occurring(event_handle_subscription)
|
||||
|
||||
task.cancel()
|
||||
|
||||
|
||||
# TODO: Add the following tests after they are aligned with Go.
|
||||
# (Issue #191: https://github.com/libp2p/py-libp2p/issues/191)
|
||||
# - `test_stream_handler`
|
||||
# - `test_handle_peer_queue`
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"num_hosts",
|
||||
(1,),
|
||||
)
|
||||
def test_handle_subscription(pubsubs_fsub):
|
||||
assert len(pubsubs_fsub[0].peer_topics) == 0
|
||||
sub_msg_0 = rpc_pb2.RPC.SubOpts(
|
||||
subscribe=True,
|
||||
topicid=TESTING_TOPIC,
|
||||
)
|
||||
peer_ids = [
|
||||
ID(b"\x12\x20" + i.to_bytes(32, "big"))
|
||||
for i in range(2)
|
||||
]
|
||||
# Test: One peer is subscribed
|
||||
pubsubs_fsub[0].handle_subscription(peer_ids[0], sub_msg_0)
|
||||
assert len(pubsubs_fsub[0].peer_topics) == 1 and TESTING_TOPIC in pubsubs_fsub[0].peer_topics
|
||||
assert len(pubsubs_fsub[0].peer_topics[TESTING_TOPIC]) == 1
|
||||
assert str(peer_ids[0]) in pubsubs_fsub[0].peer_topics[TESTING_TOPIC]
|
||||
# Test: Another peer is subscribed
|
||||
pubsubs_fsub[0].handle_subscription(peer_ids[1], sub_msg_0)
|
||||
assert len(pubsubs_fsub[0].peer_topics) == 1
|
||||
assert len(pubsubs_fsub[0].peer_topics[TESTING_TOPIC]) == 2
|
||||
assert str(peer_ids[1]) in pubsubs_fsub[0].peer_topics[TESTING_TOPIC]
|
||||
# Test: Subscribe to another topic
|
||||
another_topic = "ANOTHER_TOPIC"
|
||||
sub_msg_1 = rpc_pb2.RPC.SubOpts(
|
||||
subscribe=True,
|
||||
topicid=another_topic,
|
||||
)
|
||||
pubsubs_fsub[0].handle_subscription(peer_ids[0], sub_msg_1)
|
||||
assert len(pubsubs_fsub[0].peer_topics) == 2
|
||||
assert another_topic in pubsubs_fsub[0].peer_topics
|
||||
assert str(peer_ids[0]) in pubsubs_fsub[0].peer_topics[another_topic]
|
||||
# Test: unsubscribe
|
||||
unsub_msg = rpc_pb2.RPC.SubOpts(
|
||||
subscribe=False,
|
||||
topicid=TESTING_TOPIC,
|
||||
)
|
||||
pubsubs_fsub[0].handle_subscription(peer_ids[0], unsub_msg)
|
||||
assert str(peer_ids[0]) not in pubsubs_fsub[0].peer_topics[TESTING_TOPIC]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"num_hosts",
|
||||
(1,),
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_talk(pubsubs_fsub):
|
||||
sub = await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
|
||||
msg_0 = make_pubsub_msg(
|
||||
origin_id=pubsubs_fsub[0].my_id,
|
||||
topic_ids=[TESTING_TOPIC],
|
||||
data=b"1234",
|
||||
seqno=b"\x00" * 8,
|
||||
)
|
||||
await pubsubs_fsub[0].handle_talk(msg_0)
|
||||
msg_1 = make_pubsub_msg(
|
||||
origin_id=pubsubs_fsub[0].my_id,
|
||||
topic_ids=["NOT_SUBSCRIBED"],
|
||||
data=b"1234",
|
||||
seqno=b"\x11" * 8,
|
||||
)
|
||||
await pubsubs_fsub[0].handle_talk(msg_1)
|
||||
assert len(pubsubs_fsub[0].my_topics) == 1 and sub == pubsubs_fsub[0].my_topics[TESTING_TOPIC]
|
||||
assert sub.qsize() == 1
|
||||
assert (await sub.get()) == msg_0
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"num_hosts",
|
||||
(1,),
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_all_peers(pubsubs_fsub, monkeypatch):
|
||||
peer_ids = [
|
||||
ID(b"\x12\x20" + i.to_bytes(32, "big"))
|
||||
for i in range(10)
|
||||
]
|
||||
mock_peers = {
|
||||
str(peer_id): FakeNetStream()
|
||||
for peer_id in peer_ids
|
||||
}
|
||||
monkeypatch.setattr(pubsubs_fsub[0], "peers", mock_peers)
|
||||
|
||||
empty_rpc = rpc_pb2.RPC()
|
||||
await pubsubs_fsub[0].message_all_peers(empty_rpc.SerializeToString())
|
||||
for stream in mock_peers.values():
|
||||
assert (await stream.read()) == empty_rpc.SerializeToString()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"num_hosts",
|
||||
(1,),
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish(pubsubs_fsub, monkeypatch):
|
||||
msg_forwarders = []
|
||||
msgs = []
|
||||
|
||||
async def push_msg(msg_forwarder, msg):
|
||||
msg_forwarders.append(msg_forwarder)
|
||||
msgs.append(msg)
|
||||
monkeypatch.setattr(pubsubs_fsub[0], "push_msg", push_msg)
|
||||
|
||||
await pubsubs_fsub[0].publish(TESTING_TOPIC, TESTING_DATA)
|
||||
await pubsubs_fsub[0].publish(TESTING_TOPIC, TESTING_DATA)
|
||||
|
||||
assert len(msgs) == 2, "`push_msg` should be called every time `publish` is called"
|
||||
assert (msg_forwarders[0] == msg_forwarders[1]) and (msg_forwarders[1] == pubsubs_fsub[0].my_id)
|
||||
assert msgs[0].seqno != msgs[1].seqno, "`seqno` should be different every time"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"num_hosts",
|
||||
(1,),
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_push_msg(pubsubs_fsub, monkeypatch):
|
||||
# pylint: disable=protected-access
|
||||
msg_0 = make_pubsub_msg(
|
||||
origin_id=pubsubs_fsub[0].my_id,
|
||||
topic_ids=[TESTING_TOPIC],
|
||||
data=TESTING_DATA,
|
||||
seqno=b"\x00" * 8,
|
||||
)
|
||||
|
||||
event = asyncio.Event()
|
||||
|
||||
async def router_publish(*args, **kwargs):
|
||||
event.set()
|
||||
monkeypatch.setattr(pubsubs_fsub[0].router, "publish", router_publish)
|
||||
|
||||
# Test: `msg` is not seen before `push_msg`, and is seen after `push_msg`.
|
||||
assert not pubsubs_fsub[0]._is_msg_seen(msg_0)
|
||||
await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_0)
|
||||
assert pubsubs_fsub[0]._is_msg_seen(msg_0)
|
||||
# Test: Ensure `router.publish` is called in `push_msg`
|
||||
await asyncio.wait_for(event.wait(), timeout=0.1)
|
||||
|
||||
# Test: `push_msg` the message again and it will be reject.
|
||||
# `router_publish` is not called then.
|
||||
event.clear()
|
||||
await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_0)
|
||||
await asyncio.sleep(0.01)
|
||||
assert not event.is_set()
|
||||
|
||||
sub = await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
|
||||
# Test: `push_msg` succeeds with another unseen msg.
|
||||
msg_1 = make_pubsub_msg(
|
||||
origin_id=pubsubs_fsub[0].my_id,
|
||||
topic_ids=[TESTING_TOPIC],
|
||||
data=TESTING_DATA,
|
||||
seqno=b"\x11" * 8,
|
||||
)
|
||||
assert not pubsubs_fsub[0]._is_msg_seen(msg_1)
|
||||
await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_1)
|
||||
assert pubsubs_fsub[0]._is_msg_seen(msg_1)
|
||||
await asyncio.wait_for(event.wait(), timeout=0.1)
|
||||
# Test: Subscribers are notified when `push_msg` new messages.
|
||||
assert (await sub.get()) == msg_1
|
|
@ -14,6 +14,8 @@ from libp2p.pubsub.pubsub import Pubsub
|
|||
|
||||
from tests.utils import connect
|
||||
|
||||
from .configs import LISTEN_MADDR
|
||||
|
||||
|
||||
def message_id_generator(start_val):
|
||||
"""
|
||||
|
@ -80,13 +82,13 @@ async def create_libp2p_hosts(num_hosts):
|
|||
tasks_create = []
|
||||
for i in range(0, num_hosts):
|
||||
# Create node
|
||||
tasks_create.append(asyncio.ensure_future(new_node(transport_opt=["/ip4/127.0.0.1/tcp/0"])))
|
||||
tasks_create.append(new_node(transport_opt=[str(LISTEN_MADDR)]))
|
||||
hosts = await asyncio.gather(*tasks_create)
|
||||
|
||||
tasks_listen = []
|
||||
for node in hosts:
|
||||
# Start listener
|
||||
tasks_listen.append(asyncio.ensure_future(node.get_network().listen(multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/0"))))
|
||||
tasks_listen.append(node.get_network().listen(LISTEN_MADDR))
|
||||
await asyncio.gather(*tasks_listen)
|
||||
|
||||
return hosts
|
||||
|
@ -109,7 +111,7 @@ def create_pubsub_and_gossipsub_instances(
|
|||
degree_low, degree_high, time_to_live,
|
||||
gossip_window, gossip_history,
|
||||
heartbeat_interval)
|
||||
pubsub = Pubsub(node, gossipsub, "a")
|
||||
pubsub = Pubsub(node, gossipsub, node.get_id())
|
||||
pubsubs.append(pubsub)
|
||||
gossipsubs.append(gossipsub)
|
||||
|
||||
|
|
|
@ -13,6 +13,8 @@ async def connect(node1, node2):
|
|||
addr = node2.get_addrs()[0]
|
||||
info = info_from_p2p_addr(addr)
|
||||
await node1.connect(info)
|
||||
assert node1.get_id() in node2.get_network().connections
|
||||
assert node2.get_id() in node1.get_network().connections
|
||||
|
||||
|
||||
async def cleanup():
|
||||
|
@ -25,6 +27,7 @@ async def cleanup():
|
|||
with suppress(asyncio.CancelledError):
|
||||
await task
|
||||
|
||||
|
||||
async def set_up_nodes_by_transport_opt(transport_opt_list):
|
||||
nodes_list = []
|
||||
for transport_opt in transport_opt_list:
|
||||
|
|
Loading…
Reference in New Issue
Block a user