From dadcf8138eb4ff4ba49bd76d875f2a6d14902308 Mon Sep 17 00:00:00 2001
From: mhchia
Date: Thu, 25 Jul 2019 16:58:00 +0800
Subject: [PATCH] Fix the tests according to `Pubsub.publish`

Also refactored a bit.

---
 libp2p/pubsub/gossipsub.py    |   1 +
 libp2p/pubsub/pubsub.py       |  29 ++++-----
 tests/pubsub/test_floodsub.py | 118 ++++++++++++++--------------------
 3 files changed, 62 insertions(+), 86 deletions(-)

diff --git a/libp2p/pubsub/gossipsub.py b/libp2p/pubsub/gossipsub.py
index 31ab606..8f593a6 100644
--- a/libp2p/pubsub/gossipsub.py
+++ b/libp2p/pubsub/gossipsub.py
@@ -91,6 +91,7 @@ class GossipSub(IPubsubRouter):
         :param rpc: rpc message
         """
         control_message = rpc.control
+        sender_peer_id = str(sender_peer_id)
 
         # Relay each rpc control to the appropriate handler
         if control_message.ihave:
diff --git a/libp2p/pubsub/pubsub.py b/libp2p/pubsub/pubsub.py
index a7e542c..560b8c2 100644
--- a/libp2p/pubsub/pubsub.py
+++ b/libp2p/pubsub/pubsub.py
@@ -128,26 +128,21 @@ class Pubsub:
         messages from other nodes
         :param stream: stream to continuously read from
         """
-
-        # TODO check on types here
-        peer_id = str(stream.mplex_conn.peer_id)
+        peer_id = stream.mplex_conn.peer_id
 
         while True:
             incoming = (await stream.read())
             rpc_incoming = rpc_pb2.RPC()
             rpc_incoming.ParseFromString(incoming)
 
-            should_publish = False
-
-            if rpc_incoming.publish:
-                # deal with RPC.publish
-                for message in rpc_incoming.publish:
-                    id_in_seen_msgs = (message.seqno, message.from_id)
-                    if id_in_seen_msgs not in self.seen_messages:
-                        should_publish = True
-                        self.seen_messages[id_in_seen_msgs] = 1
-
-                        await self.handle_talk(message)
+            for msg in rpc_incoming.publish:
+                if not self._is_subscribed_to_msg(msg):
+                    continue
+                # TODO(mhchia): This will block this read_stream loop until all data are pushed.
+                #   Should investigate further whether this is an issue.
+                await self.push_msg(src=peer_id, msg=msg)
 
             if rpc_incoming.subscriptions:
                 # deal with RPC.subscriptions
@@ -158,10 +153,6 @@ class Pubsub:
             for message in rpc_incoming.subscriptions:
                 self.handle_subscription(peer_id, message)
 
-            if should_publish:
-                # relay message to peers with router
-                await self.router.publish(peer_id, incoming)
-
             if rpc_incoming.control:
                 # Pass rpc to router so router could perform custom logic
                 await self.router.handle_rpc(rpc_incoming, peer_id)
@@ -228,6 +219,7 @@ class Pubsub:
         :param origin_id: id of the peer who subscribes to the message
         :param sub_message: RPC.SubOpts
         """
+        origin_id = str(origin_id)
         if sub_message.subscribe:
             if sub_message.topicid not in self.peer_topics:
                 self.peer_topics[sub_message.topicid] = [origin_id]
@@ -379,3 +371,8 @@ class Pubsub:
         # FIXME: Mapping `msg_id` to `1` is quite awkward. Should investigate if there is a
         #   more appropriate way.
         self.seen_messages[msg_id] = 1
+
+    def _is_subscribed_to_msg(self, msg: rpc_pb2.Message) -> bool:
+        if not self.my_topics:
+            return False
+        return all(topic in self.my_topics for topic in msg.topicIDs)
diff --git a/tests/pubsub/test_floodsub.py b/tests/pubsub/test_floodsub.py
index 389c217..82831b8 100644
--- a/tests/pubsub/test_floodsub.py
+++ b/tests/pubsub/test_floodsub.py
@@ -4,17 +4,17 @@ import pytest
 from tests.utils import cleanup
 
 from libp2p import new_node
+from libp2p.peer.id import ID
 from libp2p.peer.peerinfo import info_from_p2p_addr
-from libp2p.pubsub.pb import rpc_pb2
 from libp2p.pubsub.pubsub import Pubsub
 from libp2p.pubsub.floodsub import FloodSub
 from .utils import (
-    make_pubsub_msg,
     message_id_generator,
     generate_RPC_packet,
 )
+
 # pylint: disable=too-many-locals
 
 async def connect(node1, node2):
@@ -39,106 +39,84 @@ async def test_simple_two_nodes():
     data = b"some data"
 
     floodsub_a = FloodSub(supported_protocols)
-    pubsub_a = Pubsub(node_a, floodsub_a, "a")
+    pubsub_a = Pubsub(node_a, floodsub_a, ID(b"a" * 32))
     floodsub_b = FloodSub(supported_protocols)
-    pubsub_b = Pubsub(node_b, floodsub_b, "b")
+    pubsub_b = Pubsub(node_b, floodsub_b, ID(b"b" * 32))
 
     await connect(node_a, node_b)
-
     await asyncio.sleep(0.25)
+
     sub_b = await pubsub_b.subscribe(topic)
-
+    # Sleep to let node_a know of node_b's subscription
     await asyncio.sleep(0.25)
 
-    next_msg_id_func = message_id_generator(0)
-    msg = make_pubsub_msg(
-        origin_id=node_a.get_id(),
-        topic_ids=[topic],
-        data=data,
-        seqno=next_msg_id_func(),
-    )
-    await floodsub_a.publish(node_a.get_id(), msg)
-    await asyncio.sleep(0.25)
+    await pubsub_a.publish(topic, data)
 
     res_b = await sub_b.get()
 
     # Check that the msg received by node_b is the same
     # as the message sent by node_a
-    assert res_b.SerializeToString() == msg.SerializeToString()
+    assert ID(res_b.from_id) == node_a.get_id()
+    assert res_b.data == data
+    assert res_b.topicIDs == [topic]
 
     # Success, terminate pending tasks.
     await cleanup()
 
 
 @pytest.mark.asyncio
-async def test_lru_cache_two_nodes():
+async def test_lru_cache_two_nodes(monkeypatch):
     # two nodes with cache_size of 4
-    # node_a send the following messages to node_b
-    # [1, 1, 2, 1, 3, 1, 4, 1, 5, 1]
-    # node_b should only receive the following
-    # [1, 2, 3, 4, 5, 1]
-    node_a = await new_node(transport_opt=["/ip4/127.0.0.1/tcp/0"])
-    node_b = await new_node(transport_opt=["/ip4/127.0.0.1/tcp/0"])
+    # `node_a` sends the following messages to `node_b`
+    message_indices = [1, 1, 2, 1, 3, 1, 4, 1, 5, 1]
+    # `node_b` should only receive the following
+    expected_received_indices = [1, 2, 3, 4, 5, 1]
 
-    await node_a.get_network().listen(multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/0"))
-    await node_b.get_network().listen(multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/0"))
+    listen_maddr = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/0")
+    node_a = await new_node()
+    node_b = await new_node()
+
+    await node_a.get_network().listen(listen_maddr)
+    await node_b.get_network().listen(listen_maddr)
 
     supported_protocols = ["/floodsub/1.0.0"]
+    topic = "my_topic"
 
-    # initialize PubSub with a cache_size of 4
+    # Mock `get_msg_id` to make it easier to manipulate `msg_id` via `data`.
+    def get_msg_id(msg):
+        # The original `msg_id` is `(msg.seqno, msg.from_id)`.
+        return (msg.data, msg.from_id)
+    import libp2p.pubsub.pubsub
+    monkeypatch.setattr(libp2p.pubsub.pubsub, "get_msg_id", get_msg_id)
+
+    # Initialize Pubsub with a cache_size of 4
+    cache_size = 4
     floodsub_a = FloodSub(supported_protocols)
-    pubsub_a = Pubsub(node_a, floodsub_a, "a", 4)
+    pubsub_a = Pubsub(node_a, floodsub_a, ID(b"a" * 32), cache_size)
+
     floodsub_b = FloodSub(supported_protocols)
-    pubsub_b = Pubsub(node_b, floodsub_b, "b", 4)
+    pubsub_b = Pubsub(node_b, floodsub_b, ID(b"b" * 32), cache_size)
 
     await connect(node_a, node_b)
-
-    await asyncio.sleep(0.25)
-    qb = await pubsub_b.subscribe("my_topic")
-    await asyncio.sleep(0.25)
-    node_a_id = str(node_a.get_id())
-
-    # initialize message_id_generator
-    # store first message
-    next_msg_id_func = message_id_generator(0)
-    first_message = generate_RPC_packet(node_a_id, ["my_topic"], "some data 1", next_msg_id_func())
-
-    await floodsub_a.publish(node_a_id, first_message.SerializeToString())
-    await asyncio.sleep(0.25)
-    print (first_message)
-
-    messages = [first_message]
-    # for the next 5 messages
-    for i in range(2, 6):
-        # write first message
-        await floodsub_a.publish(node_a_id, first_message.SerializeToString())
-        await asyncio.sleep(0.25)
-
-        # generate and write next message
-        msg = generate_RPC_packet(node_a_id, ["my_topic"], "some data " + str(i), next_msg_id_func())
-        messages.append(msg)
-
-        await floodsub_a.publish(node_a_id, msg.SerializeToString())
-        await asyncio.sleep(0.25)
-
-        # write first message again
-        await floodsub_a.publish(node_a_id, first_message.SerializeToString())
+    sub_b = await pubsub_b.subscribe(topic)
     await asyncio.sleep(0.25)
 
-    # check the first five messages in queue
-    # should only see 1 first_message
-    for i in range(5):
-        # Check that the msg received by node_b is the same
-        # as the message sent by node_a
-        res_b = await qb.get()
-        assert res_b.SerializeToString() == messages[i].publish[0].SerializeToString()
+    def _make_testing_data(i: int) -> bytes:
+        num_int_bytes = 4
+        if i >= 2**(num_int_bytes * 8):
+            raise ValueError(f"index {i} does not fit in {num_int_bytes} bytes")
+        return b"data" + i.to_bytes(num_int_bytes, "big")
 
-    # the 6th message should be first_message
-    res_b = await qb.get()
-    assert res_b.SerializeToString() == first_message.publish[0].SerializeToString()
-    assert qb.empty()
+    for index in message_indices:
+        await pubsub_a.publish(topic, _make_testing_data(index))
+    await asyncio.sleep(0.25)
+
+    for index in expected_received_indices:
+        res_b = await sub_b.get()
+        assert res_b.data == _make_testing_data(index)
+    assert sub_b.empty()
 
     # Success, terminate pending tasks.
     await cleanup()
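
Note on the new flow in `read_stream`: dedup no longer happens inline; each message that passes `_is_subscribed_to_msg` is handed to `pubsub.push_msg`, whose body is outside this diff. The sketch below is a hedged reconstruction of the dedup-and-relay flow the surrounding changes imply, not the actual implementation: only `get_msg_id` (per the test comment, `(msg.seqno, msg.from_id)` by default), the `seen_messages[msg_id] = 1` write, and the router relay are grounded in the patch, while `_deliver_to_subscribers` is a hypothetical name for the local-delivery step (the old code called `handle_talk`).

    # Hedged sketch only; the real `push_msg` body is not shown in this diff.
    async def push_msg(self, src, msg):
        msg_id = get_msg_id(msg)
        if msg_id in self.seen_messages:
            # Duplicate: drop it, as the removed `should_publish` logic did.
            return
        self.seen_messages[msg_id] = 1  # see the FIXME about mapping to `1`
        await self._deliver_to_subscribers(msg)  # hypothetical local delivery
        await self.router.publish(src, msg)  # relay onward via the router

This also explains the hunk at `@@ -158,10 +153,6 @@`: once dedup-and-relay lives in `push_msg`, the trailing `should_publish` / `router.publish(peer_id, incoming)` block in `read_stream` is redundant, so it is deleted.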
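
On the LRU test's expectation: [1, 1, 2, 1, 3, 1, 4, 1, 5, 1] collapses to [1, 2, 3, 4, 5, 1] only if duplicates are dropped without touching the cache and inserting a fifth distinct id into the size-4 `seen_messages` evicts the oldest one, making the final "1" look new again. The diff does not show the type of `seen_messages`, so the toy cache below is an assumption used purely to make the expected sequence checkable:

    from collections import OrderedDict

    class ToySeenCache:
        """Toy stand-in for `seen_messages` (its real type is not in this
        diff): fixed capacity, inserting past capacity evicts the oldest
        entry, and membership tests do not refresh recency."""

        def __init__(self, size: int) -> None:
            self._size = size
            self._store = OrderedDict()

        def __contains__(self, key) -> bool:
            return key in self._store

        def add(self, key) -> None:
            self._store[key] = 1
            if len(self._store) > self._size:
                self._store.popitem(last=False)  # evict the oldest msg_id

    seen = ToySeenCache(4)  # the test's cache_size
    received = []
    for i in [1, 1, 2, 1, 3, 1, 4, 1, 5, 1]:
        # Mimics the mocked msg_id: (msg.data, msg.from_id).
        msg_id = (b"data" + i.to_bytes(4, "big"), b"node_a")
        if msg_id in seen:
            continue  # duplicate: dropped, cache untouched
        seen.add(msg_id)
        received.append(i)

    assert received == [1, 2, 3, 4, 5, 1]

If `seen_messages` instead refreshed recency on duplicate hits, the id for "1" would stay warm and the final "1" would be swallowed; the test therefore also pins down that the dedup path only writes an id the first time it is seen.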