From 9683d5e8ac6e08a1e22f0049e90e50a8c4cba734 Mon Sep 17 00:00:00 2001 From: mhchia Date: Tue, 30 Jul 2019 13:33:48 +0800 Subject: [PATCH] Add tests for `Pubsub` - `test_handle_subscription` - `test_handle_talk` - `test_message_all_peers` --- libp2p/pubsub/pubsub.py | 2 +- tests/pubsub/test_pubsub.py | 111 +++++++++++++++++++++++++++++++++--- 2 files changed, 105 insertions(+), 8 deletions(-) diff --git a/libp2p/pubsub/pubsub.py b/libp2p/pubsub/pubsub.py index b46acda..7569786 100644 --- a/libp2p/pubsub/pubsub.py +++ b/libp2p/pubsub/pubsub.py @@ -323,7 +323,7 @@ class Pubsub: """ # Broadcast message - for _, stream in self.peers.items(): + for stream in self.peers.values(): # Write message to stream await stream.write(raw_msg) diff --git a/tests/pubsub/test_pubsub.py b/tests/pubsub/test_pubsub.py index 497fd86..874922e 100644 --- a/tests/pubsub/test_pubsub.py +++ b/tests/pubsub/test_pubsub.py @@ -138,7 +138,7 @@ class FakeNetStream: ) @pytest.mark.asyncio async def test_continuously_read_stream(pubsubs_fsub, monkeypatch): - s = FakeNetStream() + stream = FakeNetStream() await pubsubs_fsub[0].subscribe(TESTING_TOPIC) @@ -172,7 +172,7 @@ async def test_continuously_read_stream(pubsubs_fsub, monkeypatch): event.clear() # Kick off the task `continuously_read_stream` - task = asyncio.ensure_future(pubsubs_fsub[0].continuously_read_stream(s)) + task = asyncio.ensure_future(pubsubs_fsub[0].continuously_read_stream(stream)) # Test: `push_msg` is called when publishing to a subscribed topic. publish_subscribed_topic = rpc_pb2.RPC( @@ -180,7 +180,7 @@ async def test_continuously_read_stream(pubsubs_fsub, monkeypatch): topicIDs=[TESTING_TOPIC] )], ) - await s.write(publish_subscribed_topic.SerializeToString()) + await stream.write(publish_subscribed_topic.SerializeToString()) await wait_for_event_occurring(event_push_msg) # Make sure the other events are not emitted. with pytest.raises(asyncio.TimeoutError): @@ -194,7 +194,7 @@ async def test_continuously_read_stream(pubsubs_fsub, monkeypatch): topicIDs=["NOT_SUBSCRIBED"] )], ) - await s.write(publish_not_subscribed_topic.SerializeToString()) + await stream.write(publish_not_subscribed_topic.SerializeToString()) with pytest.raises(asyncio.TimeoutError): await wait_for_event_occurring(event_push_msg) @@ -202,7 +202,7 @@ async def test_continuously_read_stream(pubsubs_fsub, monkeypatch): subscription_msg = rpc_pb2.RPC( subscriptions=[rpc_pb2.RPC.SubOpts()], ) - await s.write(subscription_msg.SerializeToString()) + await stream.write(subscription_msg.SerializeToString()) await wait_for_event_occurring(event_handle_subscription) # Make sure the other events are not emitted. with pytest.raises(asyncio.TimeoutError): @@ -212,7 +212,7 @@ async def test_continuously_read_stream(pubsubs_fsub, monkeypatch): # Test: `handle_rpc` is called when a control message is received. control_msg = rpc_pb2.RPC(control=rpc_pb2.ControlMessage()) - await s.write(control_msg.SerializeToString()) + await stream.write(control_msg.SerializeToString()) await wait_for_event_occurring(event_handle_rpc) # Make sure the other events are not emitted. with pytest.raises(asyncio.TimeoutError): @@ -223,9 +223,106 @@ async def test_continuously_read_stream(pubsubs_fsub, monkeypatch): task.cancel() +# TODO: Add the following tests after they are aligned with Go. +# (Issue #191: https://github.com/libp2p/py-libp2p/issues/191) +# - `test_stream_handler` +# - `test_handle_peer_queue` + + @pytest.mark.parametrize( "num_hosts", - (2,), + (1,), +) +def test_handle_subscription(pubsubs_fsub): + assert len(pubsubs_fsub[0].peer_topics) == 0 + sub_msg_0 = rpc_pb2.RPC.SubOpts( + subscribe=True, + topicid=TESTING_TOPIC, + ) + peer_ids = [ + ID(b"\x12\x20" + i.to_bytes(32, "big")) + for i in range(2) + ] + # Test: One peer is subscribed + pubsubs_fsub[0].handle_subscription(peer_ids[0], sub_msg_0) + assert len(pubsubs_fsub[0].peer_topics) == 1 and TESTING_TOPIC in pubsubs_fsub[0].peer_topics + assert len(pubsubs_fsub[0].peer_topics[TESTING_TOPIC]) == 1 + assert str(peer_ids[0]) in pubsubs_fsub[0].peer_topics[TESTING_TOPIC] + # Test: Another peer is subscribed + pubsubs_fsub[0].handle_subscription(peer_ids[1], sub_msg_0) + assert len(pubsubs_fsub[0].peer_topics) == 1 + assert len(pubsubs_fsub[0].peer_topics[TESTING_TOPIC]) == 2 + assert str(peer_ids[1]) in pubsubs_fsub[0].peer_topics[TESTING_TOPIC] + # Test: Subscribe to another topic + another_topic = "ANOTHER_TOPIC" + sub_msg_1 = rpc_pb2.RPC.SubOpts( + subscribe=True, + topicid=another_topic, + ) + pubsubs_fsub[0].handle_subscription(peer_ids[0], sub_msg_1) + assert len(pubsubs_fsub[0].peer_topics) == 2 + assert another_topic in pubsubs_fsub[0].peer_topics + assert str(peer_ids[0]) in pubsubs_fsub[0].peer_topics[another_topic] + # Test: unsubscribe + unsub_msg = rpc_pb2.RPC.SubOpts( + subscribe=False, + topicid=TESTING_TOPIC, + ) + pubsubs_fsub[0].handle_subscription(peer_ids[0], unsub_msg) + assert str(peer_ids[0]) not in pubsubs_fsub[0].peer_topics[TESTING_TOPIC] + + +@pytest.mark.parametrize( + "num_hosts", + (1,), +) +@pytest.mark.asyncio +async def test_handle_talk(pubsubs_fsub): + sub = await pubsubs_fsub[0].subscribe(TESTING_TOPIC) + msg_0 = make_pubsub_msg( + origin_id=pubsubs_fsub[0].my_id, + topic_ids=[TESTING_TOPIC], + data=b"1234", + seqno=b"\x00" * 8, + ) + await pubsubs_fsub[0].handle_talk(msg_0) + msg_1 = make_pubsub_msg( + origin_id=pubsubs_fsub[0].my_id, + topic_ids=["NOT_SUBSCRIBED"], + data=b"1234", + seqno=b"\x11" * 8, + ) + await pubsubs_fsub[0].handle_talk(msg_1) + assert len(pubsubs_fsub[0].my_topics) == 1 and sub == pubsubs_fsub[0].my_topics[TESTING_TOPIC] + assert sub.qsize() == 1 + assert (await sub.get()) == msg_0 + + +@pytest.mark.parametrize( + "num_hosts", + (1,), +) +@pytest.mark.asyncio +async def test_message_all_peers(pubsubs_fsub, monkeypatch): + peer_ids = [ + ID(b"\x12\x20" + i.to_bytes(32, "big")) + for i in range(10) + ] + mock_peers = { + str(peer_id): FakeNetStream() + for peer_id in peer_ids + } + monkeypatch.setattr(pubsubs_fsub[0], "peers", mock_peers) + + empty_rpc = rpc_pb2.RPC() + await pubsubs_fsub[0].message_all_peers(empty_rpc.SerializeToString()) + for stream in mock_peers.values(): + assert (await stream.read()) == empty_rpc.SerializeToString() + + +@pytest.mark.parametrize( + "num_hosts", + (1,), ) @pytest.mark.asyncio async def test_publish(pubsubs_fsub, monkeypatch):