From 550289a439b29c5dcf412b3da5ce7323c75295e2 Mon Sep 17 00:00:00 2001 From: mhchia Date: Sun, 28 Jul 2019 18:07:48 +0800 Subject: [PATCH] Combine test_subscription.py to test_pubsub.py And add a bunch of tests for pubsub --- libp2p/pubsub/pubsub.py | 19 +++++-- tests/pubsub/conftest.py | 12 ++-- tests/pubsub/test_mcache.py | 5 +- tests/pubsub/test_pubsub.py | 106 +++++++++++++++++++++++++++++++++--- tests/utils.py | 3 + 5 files changed, 125 insertions(+), 20 deletions(-) diff --git a/libp2p/pubsub/pubsub.py b/libp2p/pubsub/pubsub.py index 9534a95..265635b 100644 --- a/libp2p/pubsub/pubsub.py +++ b/libp2p/pubsub/pubsub.py @@ -118,12 +118,14 @@ class Pubsub: Generate subscription message with all topics we are subscribed to only send hello packet if we have subscribed topics """ - packet: rpc_pb2.RPC = rpc_pb2.RPC() - if self.my_topics: - for topic_id in self.my_topics: - packet.subscriptions.extend([rpc_pb2.RPC.SubOpts( - subscribe=True, topicid=topic_id)]) - + packet = rpc_pb2.RPC() + for topic_id in self.my_topics: + packet.subscriptions.extend([ + rpc_pb2.RPC.SubOpts( + subscribe=True, + topicid=topic_id, + ) + ]) return packet.SerializeToString() async def continuously_read_stream(self, stream: INetStream) -> None: @@ -182,6 +184,8 @@ class Pubsub: await stream.write(hello) # Pass stream off to stream reader asyncio.ensure_future(self.continuously_read_stream(stream)) + # Force context switch + await asyncio.sleep(0) async def handle_peer_queue(self) -> None: """ @@ -208,6 +212,9 @@ class Pubsub: hello: bytes = self.get_hello_packet() await stream.write(hello) + # pylint: disable=line-too-long + # TODO: Investigate whether this should be replaced by `handlePeerEOF` + # Ref: https://github.com/libp2p/go-libp2p-pubsub/blob/49274b0e8aecdf6cad59d768e5702ff00aa48488/comm.go#L80 # noqa: E501 # Pass stream off to stream reader asyncio.ensure_future(self.continuously_read_stream(stream)) diff --git a/tests/pubsub/conftest.py b/tests/pubsub/conftest.py index a2a3419..4a28888 100644 --- a/tests/pubsub/conftest.py +++ b/tests/pubsub/conftest.py @@ -27,21 +27,25 @@ def num_hosts(): @pytest.fixture async def hosts(num_hosts): - new_node_coros = tuple( + _hosts = await asyncio.gather(*[ new_node(transport_opt=[str(LISTEN_MADDR)]) for _ in range(num_hosts) - ) - _hosts = await asyncio.gather(*new_node_coros) + ]) await asyncio.gather(*[ _host.get_network().listen(LISTEN_MADDR) for _host in _hosts ]) yield _hosts # Clean up + listeners = [] for _host in _hosts: for listener in _host.get_network().listeners.values(): listener.server.close() - await listener.server.wait_closed() + listeners.append(listener) + await asyncio.gather(*[ + listener.server.wait_closed() + for listener in listeners + ]) @pytest.fixture diff --git a/tests/pubsub/test_mcache.py b/tests/pubsub/test_mcache.py index 0e73222..ebfe957 100644 --- a/tests/pubsub/test_mcache.py +++ b/tests/pubsub/test_mcache.py @@ -3,8 +3,8 @@ import pytest from libp2p.pubsub.mcache import MessageCache +# pylint: disable=too-few-public-methods class Msg: - def __init__(self, topicIDs, seqno, from_id): # pylint: disable=invalid-name self.topicIDs = topicIDs @@ -15,8 +15,7 @@ class Msg: @pytest.mark.asyncio async def test_mcache(): # Ported from: - # https://github.com/libp2p/go-libp2p-pubsub - # /blob/51b7501433411b5096cac2b4994a36a68515fc03/mcache_test.go + # https://github.com/libp2p/go-libp2p-pubsub/blob/51b7501433411b5096cac2b4994a36a68515fc03/mcache_test.go mcache = MessageCache(3, 5) msgs = [] diff --git a/tests/pubsub/test_pubsub.py b/tests/pubsub/test_pubsub.py index fd0515a..4de5cf8 100644 --- a/tests/pubsub/test_pubsub.py +++ b/tests/pubsub/test_pubsub.py @@ -1,11 +1,103 @@ +import asyncio + import pytest +from libp2p.pubsub.pb import rpc_pb2 + +from tests.utils import ( + connect, +) + + +TESTING_TOPIC = "TEST_SUBSCRIBE" +TESTIND_DATA = b"data" + + +@pytest.mark.parametrize( + "num_hosts", + (1,), +) +@pytest.mark.asyncio +async def test_subscribe_and_unsubscribe(pubsubs_fsub): + await pubsubs_fsub[0].subscribe(TESTING_TOPIC) + assert TESTING_TOPIC in pubsubs_fsub[0].my_topics + + await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC) + assert TESTING_TOPIC not in pubsubs_fsub[0].my_topics + + +@pytest.mark.parametrize( + "num_hosts", + (1,), +) +@pytest.mark.asyncio +async def test_re_subscribe(pubsubs_fsub): + await pubsubs_fsub[0].subscribe(TESTING_TOPIC) + assert TESTING_TOPIC in pubsubs_fsub[0].my_topics + + await pubsubs_fsub[0].subscribe(TESTING_TOPIC) + assert TESTING_TOPIC in pubsubs_fsub[0].my_topics + + +@pytest.mark.parametrize( + "num_hosts", + (1,), +) +@pytest.mark.asyncio +async def test_re_unsubscribe(pubsubs_fsub): + # Unsubscribe from topic we didn't even subscribe to + assert "NOT_MY_TOPIC" not in pubsubs_fsub[0].my_topics + await pubsubs_fsub[0].unsubscribe("NOT_MY_TOPIC") + assert "NOT_MY_TOPIC" not in pubsubs_fsub[0].my_topics + + await pubsubs_fsub[0].subscribe(TESTING_TOPIC) + assert TESTING_TOPIC in pubsubs_fsub[0].my_topics + + await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC) + assert TESTING_TOPIC not in pubsubs_fsub[0].my_topics + + await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC) + assert TESTING_TOPIC not in pubsubs_fsub[0].my_topics + @pytest.mark.asyncio -async def test_test(pubsubs_fsub): - topic = "topic" - data = b"data" - sub = await pubsubs_fsub[0].subscribe(topic) - await pubsubs_fsub[0].publish(topic, data) - msg = await sub.get() - assert msg.data == data +async def test_peers_subscribe(pubsubs_fsub): + await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host) + await pubsubs_fsub[0].subscribe(TESTING_TOPIC) + # Yield to let 0 notify 1 + await asyncio.sleep(0.1) + assert str(pubsubs_fsub[0].my_id) in pubsubs_fsub[1].peer_topics[TESTING_TOPIC] + await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC) + # Yield to let 0 notify 1 + await asyncio.sleep(0.1) + assert str(pubsubs_fsub[0].my_id) not in pubsubs_fsub[1].peer_topics[TESTING_TOPIC] + + +@pytest.mark.parametrize( + "num_hosts", + (1,), +) +@pytest.mark.asyncio +async def test_get_hello_packet(pubsubs_fsub): + def _get_hello_packet_topic_ids(): + packet = rpc_pb2.RPC() + packet.ParseFromString(pubsubs_fsub[0].get_hello_packet()) + return tuple( + sub.topicid + for sub in packet.subscriptions + ) + + # pylint: disable=len-as-condition + # Test: No subscription, so there should not be any topic ids in the hello packet. + assert len(_get_hello_packet_topic_ids()) == 0 + + # Test: After subscriptions, topic ids should be in the hello packet. + topic_ids = ["t", "o", "p", "i", "c"] + await asyncio.gather(*[ + pubsubs_fsub[0].subscribe(topic) + for topic in topic_ids + ]) + topic_ids_in_hello = _get_hello_packet_topic_ids() + for topic in topic_ids: + assert topic in topic_ids_in_hello + diff --git a/tests/utils.py b/tests/utils.py index 686d086..6e2d64c 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -13,6 +13,8 @@ async def connect(node1, node2): addr = node2.get_addrs()[0] info = info_from_p2p_addr(addr) await node1.connect(info) + assert node1.get_id() in node2.get_network().connections + assert node2.get_id() in node1.get_network().connections async def cleanup(): @@ -25,6 +27,7 @@ async def cleanup(): with suppress(asyncio.CancelledError): await task + async def set_up_nodes_by_transport_opt(transport_opt_list): nodes_list = [] for transport_opt in transport_opt_list: