Combine test_subscription.py to test_pubsub.py

And add a bunch of tests for pubsub
This commit is contained in:
mhchia 2019-07-28 18:07:48 +08:00
parent 96563c0d84
commit 550289a439
No known key found for this signature in database
GPG Key ID: 389EFBEA1362589A
5 changed files with 125 additions and 20 deletions

View File

@ -118,12 +118,14 @@ class Pubsub:
Generate subscription message with all topics we are subscribed to
only send hello packet if we have subscribed topics
"""
packet: rpc_pb2.RPC = rpc_pb2.RPC()
if self.my_topics:
for topic_id in self.my_topics:
packet.subscriptions.extend([rpc_pb2.RPC.SubOpts(
subscribe=True, topicid=topic_id)])
packet = rpc_pb2.RPC()
for topic_id in self.my_topics:
packet.subscriptions.extend([
rpc_pb2.RPC.SubOpts(
subscribe=True,
topicid=topic_id,
)
])
return packet.SerializeToString()
async def continuously_read_stream(self, stream: INetStream) -> None:
@ -182,6 +184,8 @@ class Pubsub:
await stream.write(hello)
# Pass stream off to stream reader
asyncio.ensure_future(self.continuously_read_stream(stream))
# Force context switch
await asyncio.sleep(0)
async def handle_peer_queue(self) -> None:
"""
@ -208,6 +212,9 @@ class Pubsub:
hello: bytes = self.get_hello_packet()
await stream.write(hello)
# pylint: disable=line-too-long
# TODO: Investigate whether this should be replaced by `handlePeerEOF`
# Ref: https://github.com/libp2p/go-libp2p-pubsub/blob/49274b0e8aecdf6cad59d768e5702ff00aa48488/comm.go#L80 # noqa: E501
# Pass stream off to stream reader
asyncio.ensure_future(self.continuously_read_stream(stream))

View File

@ -27,21 +27,25 @@ def num_hosts():
@pytest.fixture
async def hosts(num_hosts):
new_node_coros = tuple(
_hosts = await asyncio.gather(*[
new_node(transport_opt=[str(LISTEN_MADDR)])
for _ in range(num_hosts)
)
_hosts = await asyncio.gather(*new_node_coros)
])
await asyncio.gather(*[
_host.get_network().listen(LISTEN_MADDR)
for _host in _hosts
])
yield _hosts
# Clean up
listeners = []
for _host in _hosts:
for listener in _host.get_network().listeners.values():
listener.server.close()
await listener.server.wait_closed()
listeners.append(listener)
await asyncio.gather(*[
listener.server.wait_closed()
for listener in listeners
])
@pytest.fixture

View File

@ -3,8 +3,8 @@ import pytest
from libp2p.pubsub.mcache import MessageCache
# pylint: disable=too-few-public-methods
class Msg:
def __init__(self, topicIDs, seqno, from_id):
# pylint: disable=invalid-name
self.topicIDs = topicIDs
@ -15,8 +15,7 @@ class Msg:
@pytest.mark.asyncio
async def test_mcache():
# Ported from:
# https://github.com/libp2p/go-libp2p-pubsub
# /blob/51b7501433411b5096cac2b4994a36a68515fc03/mcache_test.go
# https://github.com/libp2p/go-libp2p-pubsub/blob/51b7501433411b5096cac2b4994a36a68515fc03/mcache_test.go
mcache = MessageCache(3, 5)
msgs = []

View File

@ -1,11 +1,103 @@
import asyncio
import pytest
from libp2p.pubsub.pb import rpc_pb2
from tests.utils import (
connect,
)
TESTING_TOPIC = "TEST_SUBSCRIBE"
TESTIND_DATA = b"data"
@pytest.mark.parametrize(
"num_hosts",
(1,),
)
@pytest.mark.asyncio
async def test_subscribe_and_unsubscribe(pubsubs_fsub):
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
assert TESTING_TOPIC in pubsubs_fsub[0].my_topics
await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC)
assert TESTING_TOPIC not in pubsubs_fsub[0].my_topics
@pytest.mark.parametrize(
"num_hosts",
(1,),
)
@pytest.mark.asyncio
async def test_re_subscribe(pubsubs_fsub):
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
assert TESTING_TOPIC in pubsubs_fsub[0].my_topics
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
assert TESTING_TOPIC in pubsubs_fsub[0].my_topics
@pytest.mark.parametrize(
"num_hosts",
(1,),
)
@pytest.mark.asyncio
async def test_re_unsubscribe(pubsubs_fsub):
# Unsubscribe from topic we didn't even subscribe to
assert "NOT_MY_TOPIC" not in pubsubs_fsub[0].my_topics
await pubsubs_fsub[0].unsubscribe("NOT_MY_TOPIC")
assert "NOT_MY_TOPIC" not in pubsubs_fsub[0].my_topics
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
assert TESTING_TOPIC in pubsubs_fsub[0].my_topics
await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC)
assert TESTING_TOPIC not in pubsubs_fsub[0].my_topics
await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC)
assert TESTING_TOPIC not in pubsubs_fsub[0].my_topics
@pytest.mark.asyncio
async def test_test(pubsubs_fsub):
topic = "topic"
data = b"data"
sub = await pubsubs_fsub[0].subscribe(topic)
await pubsubs_fsub[0].publish(topic, data)
msg = await sub.get()
assert msg.data == data
async def test_peers_subscribe(pubsubs_fsub):
await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host)
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
# Yield to let 0 notify 1
await asyncio.sleep(0.1)
assert str(pubsubs_fsub[0].my_id) in pubsubs_fsub[1].peer_topics[TESTING_TOPIC]
await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC)
# Yield to let 0 notify 1
await asyncio.sleep(0.1)
assert str(pubsubs_fsub[0].my_id) not in pubsubs_fsub[1].peer_topics[TESTING_TOPIC]
@pytest.mark.parametrize(
"num_hosts",
(1,),
)
@pytest.mark.asyncio
async def test_get_hello_packet(pubsubs_fsub):
def _get_hello_packet_topic_ids():
packet = rpc_pb2.RPC()
packet.ParseFromString(pubsubs_fsub[0].get_hello_packet())
return tuple(
sub.topicid
for sub in packet.subscriptions
)
# pylint: disable=len-as-condition
# Test: No subscription, so there should not be any topic ids in the hello packet.
assert len(_get_hello_packet_topic_ids()) == 0
# Test: After subscriptions, topic ids should be in the hello packet.
topic_ids = ["t", "o", "p", "i", "c"]
await asyncio.gather(*[
pubsubs_fsub[0].subscribe(topic)
for topic in topic_ids
])
topic_ids_in_hello = _get_hello_packet_topic_ids()
for topic in topic_ids:
assert topic in topic_ids_in_hello

View File

@ -13,6 +13,8 @@ async def connect(node1, node2):
addr = node2.get_addrs()[0]
info = info_from_p2p_addr(addr)
await node1.connect(info)
assert node1.get_id() in node2.get_network().connections
assert node2.get_id() in node1.get_network().connections
async def cleanup():
@ -25,6 +27,7 @@ async def cleanup():
with suppress(asyncio.CancelledError):
await task
async def set_up_nodes_by_transport_opt(transport_opt_list):
nodes_list = []
for transport_opt in transport_opt_list: