Combine test_subscription.py to test_pubsub.py

And add a bunch of tests for pubsub
This commit is contained in:
mhchia 2019-07-28 18:07:48 +08:00
parent 96563c0d84
commit 550289a439
No known key found for this signature in database
GPG Key ID: 389EFBEA1362589A
5 changed files with 125 additions and 20 deletions

View File

@ -118,12 +118,14 @@ class Pubsub:
Generate subscription message with all topics we are subscribed to Generate subscription message with all topics we are subscribed to
only send hello packet if we have subscribed topics only send hello packet if we have subscribed topics
""" """
packet: rpc_pb2.RPC = rpc_pb2.RPC() packet = rpc_pb2.RPC()
if self.my_topics: for topic_id in self.my_topics:
for topic_id in self.my_topics: packet.subscriptions.extend([
packet.subscriptions.extend([rpc_pb2.RPC.SubOpts( rpc_pb2.RPC.SubOpts(
subscribe=True, topicid=topic_id)]) subscribe=True,
topicid=topic_id,
)
])
return packet.SerializeToString() return packet.SerializeToString()
async def continuously_read_stream(self, stream: INetStream) -> None: async def continuously_read_stream(self, stream: INetStream) -> None:
@ -182,6 +184,8 @@ class Pubsub:
await stream.write(hello) await stream.write(hello)
# Pass stream off to stream reader # Pass stream off to stream reader
asyncio.ensure_future(self.continuously_read_stream(stream)) asyncio.ensure_future(self.continuously_read_stream(stream))
# Force context switch
await asyncio.sleep(0)
async def handle_peer_queue(self) -> None: async def handle_peer_queue(self) -> None:
""" """
@ -208,6 +212,9 @@ class Pubsub:
hello: bytes = self.get_hello_packet() hello: bytes = self.get_hello_packet()
await stream.write(hello) await stream.write(hello)
# pylint: disable=line-too-long
# TODO: Investigate whether this should be replaced by `handlePeerEOF`
# Ref: https://github.com/libp2p/go-libp2p-pubsub/blob/49274b0e8aecdf6cad59d768e5702ff00aa48488/comm.go#L80 # noqa: E501
# Pass stream off to stream reader # Pass stream off to stream reader
asyncio.ensure_future(self.continuously_read_stream(stream)) asyncio.ensure_future(self.continuously_read_stream(stream))

View File

@ -27,21 +27,25 @@ def num_hosts():
@pytest.fixture @pytest.fixture
async def hosts(num_hosts): async def hosts(num_hosts):
new_node_coros = tuple( _hosts = await asyncio.gather(*[
new_node(transport_opt=[str(LISTEN_MADDR)]) new_node(transport_opt=[str(LISTEN_MADDR)])
for _ in range(num_hosts) for _ in range(num_hosts)
) ])
_hosts = await asyncio.gather(*new_node_coros)
await asyncio.gather(*[ await asyncio.gather(*[
_host.get_network().listen(LISTEN_MADDR) _host.get_network().listen(LISTEN_MADDR)
for _host in _hosts for _host in _hosts
]) ])
yield _hosts yield _hosts
# Clean up # Clean up
listeners = []
for _host in _hosts: for _host in _hosts:
for listener in _host.get_network().listeners.values(): for listener in _host.get_network().listeners.values():
listener.server.close() listener.server.close()
await listener.server.wait_closed() listeners.append(listener)
await asyncio.gather(*[
listener.server.wait_closed()
for listener in listeners
])
@pytest.fixture @pytest.fixture

View File

@ -3,8 +3,8 @@ import pytest
from libp2p.pubsub.mcache import MessageCache from libp2p.pubsub.mcache import MessageCache
# pylint: disable=too-few-public-methods
class Msg: class Msg:
def __init__(self, topicIDs, seqno, from_id): def __init__(self, topicIDs, seqno, from_id):
# pylint: disable=invalid-name # pylint: disable=invalid-name
self.topicIDs = topicIDs self.topicIDs = topicIDs
@ -15,8 +15,7 @@ class Msg:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_mcache(): async def test_mcache():
# Ported from: # Ported from:
# https://github.com/libp2p/go-libp2p-pubsub # https://github.com/libp2p/go-libp2p-pubsub/blob/51b7501433411b5096cac2b4994a36a68515fc03/mcache_test.go
# /blob/51b7501433411b5096cac2b4994a36a68515fc03/mcache_test.go
mcache = MessageCache(3, 5) mcache = MessageCache(3, 5)
msgs = [] msgs = []

View File

@ -1,11 +1,103 @@
import asyncio
import pytest import pytest
from libp2p.pubsub.pb import rpc_pb2
from tests.utils import (
connect,
)
TESTING_TOPIC = "TEST_SUBSCRIBE"
TESTIND_DATA = b"data"
@pytest.mark.parametrize(
"num_hosts",
(1,),
)
@pytest.mark.asyncio
async def test_subscribe_and_unsubscribe(pubsubs_fsub):
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
assert TESTING_TOPIC in pubsubs_fsub[0].my_topics
await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC)
assert TESTING_TOPIC not in pubsubs_fsub[0].my_topics
@pytest.mark.parametrize(
"num_hosts",
(1,),
)
@pytest.mark.asyncio
async def test_re_subscribe(pubsubs_fsub):
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
assert TESTING_TOPIC in pubsubs_fsub[0].my_topics
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
assert TESTING_TOPIC in pubsubs_fsub[0].my_topics
@pytest.mark.parametrize(
"num_hosts",
(1,),
)
@pytest.mark.asyncio
async def test_re_unsubscribe(pubsubs_fsub):
# Unsubscribe from topic we didn't even subscribe to
assert "NOT_MY_TOPIC" not in pubsubs_fsub[0].my_topics
await pubsubs_fsub[0].unsubscribe("NOT_MY_TOPIC")
assert "NOT_MY_TOPIC" not in pubsubs_fsub[0].my_topics
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
assert TESTING_TOPIC in pubsubs_fsub[0].my_topics
await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC)
assert TESTING_TOPIC not in pubsubs_fsub[0].my_topics
await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC)
assert TESTING_TOPIC not in pubsubs_fsub[0].my_topics
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_test(pubsubs_fsub): async def test_peers_subscribe(pubsubs_fsub):
topic = "topic" await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host)
data = b"data" await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
sub = await pubsubs_fsub[0].subscribe(topic) # Yield to let 0 notify 1
await pubsubs_fsub[0].publish(topic, data) await asyncio.sleep(0.1)
msg = await sub.get() assert str(pubsubs_fsub[0].my_id) in pubsubs_fsub[1].peer_topics[TESTING_TOPIC]
assert msg.data == data await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC)
# Yield to let 0 notify 1
await asyncio.sleep(0.1)
assert str(pubsubs_fsub[0].my_id) not in pubsubs_fsub[1].peer_topics[TESTING_TOPIC]
@pytest.mark.parametrize(
"num_hosts",
(1,),
)
@pytest.mark.asyncio
async def test_get_hello_packet(pubsubs_fsub):
def _get_hello_packet_topic_ids():
packet = rpc_pb2.RPC()
packet.ParseFromString(pubsubs_fsub[0].get_hello_packet())
return tuple(
sub.topicid
for sub in packet.subscriptions
)
# pylint: disable=len-as-condition
# Test: No subscription, so there should not be any topic ids in the hello packet.
assert len(_get_hello_packet_topic_ids()) == 0
# Test: After subscriptions, topic ids should be in the hello packet.
topic_ids = ["t", "o", "p", "i", "c"]
await asyncio.gather(*[
pubsubs_fsub[0].subscribe(topic)
for topic in topic_ids
])
topic_ids_in_hello = _get_hello_packet_topic_ids()
for topic in topic_ids:
assert topic in topic_ids_in_hello

View File

@ -13,6 +13,8 @@ async def connect(node1, node2):
addr = node2.get_addrs()[0] addr = node2.get_addrs()[0]
info = info_from_p2p_addr(addr) info = info_from_p2p_addr(addr)
await node1.connect(info) await node1.connect(info)
assert node1.get_id() in node2.get_network().connections
assert node2.get_id() in node1.get_network().connections
async def cleanup(): async def cleanup():
@ -25,6 +27,7 @@ async def cleanup():
with suppress(asyncio.CancelledError): with suppress(asyncio.CancelledError):
await task await task
async def set_up_nodes_by_transport_opt(transport_opt_list): async def set_up_nodes_by_transport_opt(transport_opt_list):
nodes_list = [] nodes_list = []
for transport_opt in transport_opt_list: for transport_opt in transport_opt_list: