Combine test_subscription.py to test_pubsub.py
And add a bunch of tests for pubsub
This commit is contained in:
parent
96563c0d84
commit
550289a439
|
@ -118,12 +118,14 @@ class Pubsub:
|
|||
Generate subscription message with all topics we are subscribed to
|
||||
only send hello packet if we have subscribed topics
|
||||
"""
|
||||
packet: rpc_pb2.RPC = rpc_pb2.RPC()
|
||||
if self.my_topics:
|
||||
packet = rpc_pb2.RPC()
|
||||
for topic_id in self.my_topics:
|
||||
packet.subscriptions.extend([rpc_pb2.RPC.SubOpts(
|
||||
subscribe=True, topicid=topic_id)])
|
||||
|
||||
packet.subscriptions.extend([
|
||||
rpc_pb2.RPC.SubOpts(
|
||||
subscribe=True,
|
||||
topicid=topic_id,
|
||||
)
|
||||
])
|
||||
return packet.SerializeToString()
|
||||
|
||||
async def continuously_read_stream(self, stream: INetStream) -> None:
|
||||
|
@ -182,6 +184,8 @@ class Pubsub:
|
|||
await stream.write(hello)
|
||||
# Pass stream off to stream reader
|
||||
asyncio.ensure_future(self.continuously_read_stream(stream))
|
||||
# Force context switch
|
||||
await asyncio.sleep(0)
|
||||
|
||||
async def handle_peer_queue(self) -> None:
|
||||
"""
|
||||
|
@ -208,6 +212,9 @@ class Pubsub:
|
|||
hello: bytes = self.get_hello_packet()
|
||||
await stream.write(hello)
|
||||
|
||||
# pylint: disable=line-too-long
|
||||
# TODO: Investigate whether this should be replaced by `handlePeerEOF`
|
||||
# Ref: https://github.com/libp2p/go-libp2p-pubsub/blob/49274b0e8aecdf6cad59d768e5702ff00aa48488/comm.go#L80 # noqa: E501
|
||||
# Pass stream off to stream reader
|
||||
asyncio.ensure_future(self.continuously_read_stream(stream))
|
||||
|
||||
|
|
|
@ -27,21 +27,25 @@ def num_hosts():
|
|||
|
||||
@pytest.fixture
|
||||
async def hosts(num_hosts):
|
||||
new_node_coros = tuple(
|
||||
_hosts = await asyncio.gather(*[
|
||||
new_node(transport_opt=[str(LISTEN_MADDR)])
|
||||
for _ in range(num_hosts)
|
||||
)
|
||||
_hosts = await asyncio.gather(*new_node_coros)
|
||||
])
|
||||
await asyncio.gather(*[
|
||||
_host.get_network().listen(LISTEN_MADDR)
|
||||
for _host in _hosts
|
||||
])
|
||||
yield _hosts
|
||||
# Clean up
|
||||
listeners = []
|
||||
for _host in _hosts:
|
||||
for listener in _host.get_network().listeners.values():
|
||||
listener.server.close()
|
||||
await listener.server.wait_closed()
|
||||
listeners.append(listener)
|
||||
await asyncio.gather(*[
|
||||
listener.server.wait_closed()
|
||||
for listener in listeners
|
||||
])
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
|
@ -3,8 +3,8 @@ import pytest
|
|||
from libp2p.pubsub.mcache import MessageCache
|
||||
|
||||
|
||||
# pylint: disable=too-few-public-methods
|
||||
class Msg:
|
||||
|
||||
def __init__(self, topicIDs, seqno, from_id):
|
||||
# pylint: disable=invalid-name
|
||||
self.topicIDs = topicIDs
|
||||
|
@ -15,8 +15,7 @@ class Msg:
|
|||
@pytest.mark.asyncio
|
||||
async def test_mcache():
|
||||
# Ported from:
|
||||
# https://github.com/libp2p/go-libp2p-pubsub
|
||||
# /blob/51b7501433411b5096cac2b4994a36a68515fc03/mcache_test.go
|
||||
# https://github.com/libp2p/go-libp2p-pubsub/blob/51b7501433411b5096cac2b4994a36a68515fc03/mcache_test.go
|
||||
mcache = MessageCache(3, 5)
|
||||
msgs = []
|
||||
|
||||
|
|
|
@ -1,11 +1,103 @@
|
|||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from libp2p.pubsub.pb import rpc_pb2
|
||||
|
||||
from tests.utils import (
|
||||
connect,
|
||||
)
|
||||
|
||||
|
||||
TESTING_TOPIC = "TEST_SUBSCRIBE"
|
||||
TESTIND_DATA = b"data"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"num_hosts",
|
||||
(1,),
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_subscribe_and_unsubscribe(pubsubs_fsub):
|
||||
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
|
||||
assert TESTING_TOPIC in pubsubs_fsub[0].my_topics
|
||||
|
||||
await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC)
|
||||
assert TESTING_TOPIC not in pubsubs_fsub[0].my_topics
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"num_hosts",
|
||||
(1,),
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_re_subscribe(pubsubs_fsub):
|
||||
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
|
||||
assert TESTING_TOPIC in pubsubs_fsub[0].my_topics
|
||||
|
||||
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
|
||||
assert TESTING_TOPIC in pubsubs_fsub[0].my_topics
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"num_hosts",
|
||||
(1,),
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_re_unsubscribe(pubsubs_fsub):
|
||||
# Unsubscribe from topic we didn't even subscribe to
|
||||
assert "NOT_MY_TOPIC" not in pubsubs_fsub[0].my_topics
|
||||
await pubsubs_fsub[0].unsubscribe("NOT_MY_TOPIC")
|
||||
assert "NOT_MY_TOPIC" not in pubsubs_fsub[0].my_topics
|
||||
|
||||
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
|
||||
assert TESTING_TOPIC in pubsubs_fsub[0].my_topics
|
||||
|
||||
await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC)
|
||||
assert TESTING_TOPIC not in pubsubs_fsub[0].my_topics
|
||||
|
||||
await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC)
|
||||
assert TESTING_TOPIC not in pubsubs_fsub[0].my_topics
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_test(pubsubs_fsub):
|
||||
topic = "topic"
|
||||
data = b"data"
|
||||
sub = await pubsubs_fsub[0].subscribe(topic)
|
||||
await pubsubs_fsub[0].publish(topic, data)
|
||||
msg = await sub.get()
|
||||
assert msg.data == data
|
||||
async def test_peers_subscribe(pubsubs_fsub):
|
||||
await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host)
|
||||
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
|
||||
# Yield to let 0 notify 1
|
||||
await asyncio.sleep(0.1)
|
||||
assert str(pubsubs_fsub[0].my_id) in pubsubs_fsub[1].peer_topics[TESTING_TOPIC]
|
||||
await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC)
|
||||
# Yield to let 0 notify 1
|
||||
await asyncio.sleep(0.1)
|
||||
assert str(pubsubs_fsub[0].my_id) not in pubsubs_fsub[1].peer_topics[TESTING_TOPIC]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"num_hosts",
|
||||
(1,),
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_hello_packet(pubsubs_fsub):
|
||||
def _get_hello_packet_topic_ids():
|
||||
packet = rpc_pb2.RPC()
|
||||
packet.ParseFromString(pubsubs_fsub[0].get_hello_packet())
|
||||
return tuple(
|
||||
sub.topicid
|
||||
for sub in packet.subscriptions
|
||||
)
|
||||
|
||||
# pylint: disable=len-as-condition
|
||||
# Test: No subscription, so there should not be any topic ids in the hello packet.
|
||||
assert len(_get_hello_packet_topic_ids()) == 0
|
||||
|
||||
# Test: After subscriptions, topic ids should be in the hello packet.
|
||||
topic_ids = ["t", "o", "p", "i", "c"]
|
||||
await asyncio.gather(*[
|
||||
pubsubs_fsub[0].subscribe(topic)
|
||||
for topic in topic_ids
|
||||
])
|
||||
topic_ids_in_hello = _get_hello_packet_topic_ids()
|
||||
for topic in topic_ids:
|
||||
assert topic in topic_ids_in_hello
|
||||
|
||||
|
|
|
@ -13,6 +13,8 @@ async def connect(node1, node2):
|
|||
addr = node2.get_addrs()[0]
|
||||
info = info_from_p2p_addr(addr)
|
||||
await node1.connect(info)
|
||||
assert node1.get_id() in node2.get_network().connections
|
||||
assert node2.get_id() in node1.get_network().connections
|
||||
|
||||
|
||||
async def cleanup():
|
||||
|
@ -25,6 +27,7 @@ async def cleanup():
|
|||
with suppress(asyncio.CancelledError):
|
||||
await task
|
||||
|
||||
|
||||
async def set_up_nodes_by_transport_opt(transport_opt_list):
|
||||
nodes_list = []
|
||||
for transport_opt in transport_opt_list:
|
||||
|
|
Loading…
Reference in New Issue
Block a user