diff --git a/libp2p/pubsub/floodsub.py b/libp2p/pubsub/floodsub.py index 0f9ebf4..40aafc5 100644 --- a/libp2p/pubsub/floodsub.py +++ b/libp2p/pubsub/floodsub.py @@ -80,6 +80,21 @@ class FloodSub(IPubsubRouter): # Ref: https://github.com/libp2p/go-libp2p-pubsub/blob/master/comm.go#L107 await stream.write(rpc_msg.SerializeToString()) + async def join(self, topic): + """ + Join notifies the router that we want to receive and + forward messages in a topic. It is invoked after the + subscription announcement + :param topic: topic to join + """ + + async def leave(self, topic): + """ + Leave notifies the router that we are no longer interested in a topic. + It is invoked after the unsubscription announcement. + :param topic: topic to leave + """ + def _get_peers_to_send( self, topic_ids: Iterable[str], @@ -102,18 +117,3 @@ class FloodSub(IPubsubRouter): if str(peer_id) not in self.pubsub.peers: continue yield peer_id - - async def join(self, topic): - """ - Join notifies the router that we want to receive and - forward messages in a topic. It is invoked after the - subscription announcement - :param topic: topic to join - """ - - async def leave(self, topic): - """ - Leave notifies the router that we are no longer interested in a topic. - It is invoked after the unsubscription announcement. - :param topic: topic to leave - """ diff --git a/libp2p/pubsub/gossipsub.py b/libp2p/pubsub/gossipsub.py index 8f593a6..8051a3b 100644 --- a/libp2p/pubsub/gossipsub.py +++ b/libp2p/pubsub/gossipsub.py @@ -1,10 +1,22 @@ -import random import asyncio +import random +from typing import ( + Iterable, + List, + MutableSet, + Sequence, +) from ast import literal_eval + +from libp2p.peer.id import ( + ID, + id_b58_decode, +) + +from .mcache import MessageCache from .pb import rpc_pb2 from .pubsub_router_interface import IPubsubRouter -from .mcache import MessageCache class GossipSub(IPubsubRouter): @@ -107,70 +119,73 @@ class GossipSub(IPubsubRouter): for prune in control_message.prune: await self.handle_prune(prune, sender_peer_id) - async def publish(self, sender_peer_id, rpc_message): + async def publish(self, src: ID, pubsub_msg: rpc_pb2.Message) -> None: # pylint: disable=too-many-locals """ Invoked to forward a new message that has been validated. """ + self.mcache.put(pubsub_msg) - packet = rpc_pb2.RPC() - packet.ParseFromString(rpc_message) - msg_sender = str(sender_peer_id) + peers_gen = self._get_peers_to_send( + pubsub_msg.topicIDs, + src=src, + origin=ID(pubsub_msg.from_id), + ) + rpc_msg = rpc_pb2.RPC( + publish=[pubsub_msg], + ) + for peer_id in peers_gen: + stream = self.pubsub.peers[str(peer_id)] + # FIXME: We should add a `WriteMsg` similar to write delimited messages. + # Ref: https://github.com/libp2p/go-libp2p-pubsub/blob/master/comm.go#L107 + # TODO: Go use `sendRPC`, which possibly piggybacks gossip/control messages. + await stream.write(rpc_msg.SerializeToString()) - # Deliver to self if self was origin - # Note: handle_talk checks if self is subscribed to topics in message - for message in packet.publish: - # Add RPC message to cache - self.mcache.put(message) + def _get_peers_to_send( + self, + topic_ids: Iterable[str], + src: ID, + origin: ID) -> Iterable[ID]: + """ + Get the eligible peers to send the data to. + :param src: the peer id of the peer who forwards the message to me. + :param origin: the peer id of the peer who originally broadcast the message. + :return: a generator of the peer ids who we send data to. + """ + to_send: MutableSet[ID] = set() + for topic in topic_ids: + if topic not in self.pubsub.peer_topics: + continue - decoded_from_id = message.from_id.decode('utf-8') - new_packet = rpc_pb2.RPC() - new_packet.publish.extend([message]) - new_packet_serialized = new_packet.SerializeToString() + # floodsub peers + for peer_id_str in self.pubsub.peer_topics[topic]: + peer_id = id_b58_decode(peer_id_str) + # FIXME: `gossipsub.peers_floodsub` can be changed to `gossipsub.peers` in go. + # This will improve the efficiency when searching for a peer's protocol id. + if peer_id_str in self.peers_floodsub: + to_send.add(peer_id) - # Deliver to self if needed - if msg_sender == decoded_from_id and msg_sender == str(self.pubsub.host.get_id()): - id_in_seen_msgs = (message.seqno, message.from_id) + # gossipsub peers + # FIXME: Change `str` to `ID` + gossipsub_peers: List[str] = None + # TODO: Do we need to check `topic in self.pubsub.my_topics`? + if topic in self.mesh: + gossipsub_peers = self.mesh[topic] + else: + # TODO(robzajac): Is topic DEFINITELY supposed to be in fanout if we are not + # subscribed? + # I assume there could be short periods between heartbeats where topic may not + # be but we should check that this path gets hit appropriately - if id_in_seen_msgs not in self.pubsub.seen_messages: - self.pubsub.seen_messages[id_in_seen_msgs] = 1 - - await self.pubsub.handle_talk(message) - - # Deliver to peers - for topic in message.topicIDs: - # If topic has floodsub peers, deliver to floodsub peers - # TODO: This can be done more efficiently. Do it more efficiently. - floodsub_peers_in_topic = [] - if topic in self.pubsub.peer_topics: - for peer in self.pubsub.peer_topics[topic]: - if str(peer) in self.peers_floodsub: - floodsub_peers_in_topic.append(peer) - - await self.deliver_messages_to_peers(floodsub_peers_in_topic, msg_sender, - decoded_from_id, new_packet_serialized) - - # If you are subscribed to topic, send to mesh, otherwise send to fanout - if topic in self.pubsub.my_topics and topic in self.mesh: - await self.deliver_messages_to_peers(self.mesh[topic], msg_sender, - decoded_from_id, new_packet_serialized) - else: - # Send to fanout peers - if topic not in self.fanout: - # If no peers in fanout, choose some peers from gossipsub peers in topic - gossipsub_peers_in_topic = [peer for peer in self.pubsub.peer_topics[topic] - if peer in self.peers_gossipsub] - - selected = \ - GossipSub.select_from_minus(self.degree, gossipsub_peers_in_topic, []) - self.fanout[topic] = selected - - # TODO: Is topic DEFINITELY supposed to be in fanout if we are not subscribed? - # I assume there could be short periods between heartbeats where topic may not - # be but we should check that this path gets hit appropriately - - await self.deliver_messages_to_peers(self.fanout[topic], msg_sender, - decoded_from_id, new_packet_serialized) + # pylint: disable=len-as-condition + if (topic not in self.fanout) or (len(self.fanout[topic]) == 0): + # If no peers in fanout, choose some peers from gossipsub peers in topic. + self.fanout[topic] = self._get_peers_from_minus(topic, self.degree, []) + gossipsub_peers = self.fanout[topic] + for peer_id_str in gossipsub_peers: + to_send.add(id_b58_decode(peer_id_str)) + # Excludes `src` and `origin` + yield from to_send.difference([src, origin]) async def join(self, topic): # Note: the comments here are the near-exact algorithm description from the spec @@ -401,6 +416,22 @@ class GossipSub(IPubsubRouter): return selection + def _get_peers_from_minus( + self, + topic: str, + num_to_select: int, + minus: Sequence[ID]) -> List[ID]: + gossipsub_peers_in_topic = [ + peer_str + for peer_str in self.pubsub.peer_topics[topic] + if peer_str in self.peers_gossipsub + ] + return self.select_from_minus( + num_to_select, + gossipsub_peers_in_topic, + list(minus), + ) + # RPC handlers async def handle_ihave(self, ihave_msg, sender_peer_id): diff --git a/tests/pubsub/dummy_account_node.py b/tests/pubsub/dummy_account_node.py index a640997..8446028 100644 --- a/tests/pubsub/dummy_account_node.py +++ b/tests/pubsub/dummy_account_node.py @@ -2,11 +2,14 @@ import asyncio import multiaddr import uuid -from utils import message_id_generator, generate_RPC_packet from libp2p import new_node +from libp2p.host.host_interface import IHost from libp2p.pubsub.pubsub import Pubsub from libp2p.pubsub.floodsub import FloodSub +from .utils import message_id_generator, generate_RPC_packet + + SUPPORTED_PUBSUB_PROTOCOLS = ["/floodsub/1.0.0"] CRYPTO_TOPIC = "ethereum" @@ -17,14 +20,25 @@ CRYPTO_TOPIC = "ethereum" # Ex. set,rob,5 # Determine message type by looking at first item before first comma -class DummyAccountNode(): + +class DummyAccountNode: """ - Node which has an internal balance mapping, meant to serve as + Node which has an internal balance mapping, meant to serve as a dummy crypto blockchain. There is no actual blockchain, just a simple map indicating how much crypto each user in the mappings holds """ + libp2p_node: IHost + pubsub: Pubsub + floodsub: FloodSub - def __init__(self): + def __init__( + self, + libp2p_node: IHost, + pubsub: Pubsub, + floodsub: FloodSub): + self.libp2p_node = libp2p_node + self.pubsub = pubsub + self.floodsub = floodsub self.balances = {} self.next_msg_id_func = message_id_generator(0) self.node_id = str(uuid.uuid1()) @@ -38,16 +52,21 @@ class DummyAccountNode(): We use create as this serves as a factory function and allows us to use async await, unlike the init function """ - self = DummyAccountNode() libp2p_node = await new_node(transport_opt=["/ip4/127.0.0.1/tcp/0"]) await libp2p_node.get_network().listen(multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/0")) - self.libp2p_node = libp2p_node - - self.floodsub = FloodSub(SUPPORTED_PUBSUB_PROTOCOLS) - self.pubsub = Pubsub(self.libp2p_node, self.floodsub, "a") - return self + floodsub = FloodSub(SUPPORTED_PUBSUB_PROTOCOLS) + pubsub = Pubsub( + libp2p_node, + floodsub, + "a", + ) + return cls( + libp2p_node=libp2p_node, + pubsub=pubsub, + floodsub=floodsub, + ) async def handle_incoming_msgs(self): """ @@ -78,10 +97,8 @@ class DummyAccountNode(): :param dest_user: user to send crypto to :param amount: amount of crypto to send """ - my_id = str(self.libp2p_node.get_id()) msg_contents = "send," + source_user + "," + dest_user + "," + str(amount) - packet = generate_RPC_packet(my_id, [CRYPTO_TOPIC], msg_contents, self.next_msg_id_func()) - await self.floodsub.publish(my_id, packet.SerializeToString()) + await self.pubsub.publish(CRYPTO_TOPIC, msg_contents.encode()) async def publish_set_crypto(self, user, amount): """ @@ -89,18 +106,15 @@ class DummyAccountNode(): :param user: user to set crypto for :param amount: amount of crypto """ - my_id = str(self.libp2p_node.get_id()) msg_contents = "set," + user + "," + str(amount) - packet = generate_RPC_packet(my_id, [CRYPTO_TOPIC], msg_contents, self.next_msg_id_func()) - - await self.floodsub.publish(my_id, packet.SerializeToString()) + await self.pubsub.publish(CRYPTO_TOPIC, msg_contents.encode()) def handle_send_crypto(self, source_user, dest_user, amount): """ Handle incoming send_crypto message :param source_user: user to send crypto from :param dest_user: user to send crypto to - :param amount: amount of crypto to send + :param amount: amount of crypto to send """ if source_user in self.balances: self.balances[source_user] -= amount diff --git a/tests/pubsub/test_dummyaccount_demo.py b/tests/pubsub/test_dummyaccount_demo.py index 9fa2aa7..b1c4a64 100644 --- a/tests/pubsub/test_dummyaccount_demo.py +++ b/tests/pubsub/test_dummyaccount_demo.py @@ -1,35 +1,37 @@ import asyncio +from threading import Thread + import multiaddr + import pytest -from threading import Thread -from tests.utils import cleanup from libp2p import new_node from libp2p.peer.peerinfo import info_from_p2p_addr from libp2p.pubsub.pubsub import Pubsub from libp2p.pubsub.floodsub import FloodSub -from dummy_account_node import DummyAccountNode + +from tests.utils import ( + cleanup, + connect, +) +from .dummy_account_node import DummyAccountNode # pylint: disable=too-many-locals -async def connect(node1, node2): - # node1 connects to node2 - addr = node2.get_addrs()[0] - info = info_from_p2p_addr(addr) - await node1.connect(info) def create_setup_in_new_thread_func(dummy_node): def setup_in_new_thread(): asyncio.ensure_future(dummy_node.setup_crypto_networking()) return setup_in_new_thread + async def perform_test(num_nodes, adjacency_map, action_func, assertion_func): """ Helper function to allow for easy construction of custom tests for dummy account nodes in various network topologies :param num_nodes: number of nodes in the test :param adjacency_map: adjacency map defining each node and its list of neighbors - :param action_func: function to execute that includes actions by the nodes, + :param action_func: function to execute that includes actions by the nodes, such as send crypto and set crypto :param assertion_func: assertions for testing the results of the actions are correct """ @@ -73,6 +75,7 @@ async def perform_test(num_nodes, adjacency_map, action_func, assertion_func): # Success, terminate pending tasks. await cleanup() + @pytest.mark.asyncio async def test_simple_two_nodes(): num_nodes = 2 @@ -86,6 +89,7 @@ async def test_simple_two_nodes(): await perform_test(num_nodes, adj_map, action_func, assertion_func) + @pytest.mark.asyncio async def test_simple_three_nodes_line_topography(): num_nodes = 3 @@ -99,6 +103,7 @@ async def test_simple_three_nodes_line_topography(): await perform_test(num_nodes, adj_map, action_func, assertion_func) + @pytest.mark.asyncio async def test_simple_three_nodes_triangle_topography(): num_nodes = 3 @@ -112,6 +117,7 @@ async def test_simple_three_nodes_triangle_topography(): await perform_test(num_nodes, adj_map, action_func, assertion_func) + @pytest.mark.asyncio async def test_simple_seven_nodes_tree_topography(): num_nodes = 7 @@ -125,6 +131,7 @@ async def test_simple_seven_nodes_tree_topography(): await perform_test(num_nodes, adj_map, action_func, assertion_func) + @pytest.mark.asyncio async def test_set_then_send_from_root_seven_nodes_tree_topography(): num_nodes = 7 diff --git a/tests/pubsub/test_gossipsub.py b/tests/pubsub/test_gossipsub.py index bb47135..7e2c757 100644 --- a/tests/pubsub/test_gossipsub.py +++ b/tests/pubsub/test_gossipsub.py @@ -1,11 +1,16 @@ import asyncio -import pytest import random -from utils import message_id_generator, generate_RPC_packet, \ +import pytest + +from tests.utils import ( + cleanup, + connect, +) + +from .utils import message_id_generator, generate_RPC_packet, \ create_libp2p_hosts, create_pubsub_and_gossipsub_instances, sparse_connect, dense_connect, \ - connect, one_to_all_connect -from tests.utils import cleanup + one_to_all_connect SUPPORTED_PROTOCOLS = ["/gossipsub/1.0.0"] @@ -41,13 +46,8 @@ async def test_join(): # Central node publish to the topic so that this topic # is added to central node's fanout - next_msg_id_func = message_id_generator(0) - msg_content = "" - host_id = str(libp2p_hosts[central_node_index].get_id()) - # Generate message packet - packet = generate_RPC_packet(host_id, [topic], msg_content, next_msg_id_func()) # publish from the randomly chosen host - await gossipsubs[central_node_index].publish(host_id, packet.SerializeToString()) + await pubsubs[central_node_index].publish(topic, b"") # Check that the gossipsub of central node has fanout for the topic assert topic in gossipsubs[central_node_index].fanout @@ -86,6 +86,8 @@ async def test_leave(): gossipsub = gossipsubs[0] topic = "test_leave" + assert topic not in gossipsub.mesh + await gossipsub.join(topic) assert topic in gossipsub.mesh @@ -205,14 +207,12 @@ async def test_handle_prune(): @pytest.mark.asyncio async def test_dense(): # Create libp2p hosts - next_msg_id_func = message_id_generator(0) - num_hosts = 10 num_msgs = 5 libp2p_hosts = await create_libp2p_hosts(num_hosts) # Create pubsub, gossipsub instances - pubsubs, gossipsubs = create_pubsub_and_gossipsub_instances(libp2p_hosts, \ + pubsubs, _ = create_pubsub_and_gossipsub_instances(libp2p_hosts, \ SUPPORTED_PROTOCOLS, \ 10, 9, 11, 30, 3, 5, 0.5) @@ -231,41 +231,35 @@ async def test_dense(): await asyncio.sleep(2) for i in range(num_msgs): - msg_content = "foo " + str(i) + msg_content = b"foo " + i.to_bytes(1, 'big') # randomly pick a message origin origin_idx = random.randint(0, num_hosts - 1) - origin_host = libp2p_hosts[origin_idx] - host_id = str(origin_host.get_id()) - - # Generate message packet - packet = generate_RPC_packet(host_id, ["foobar"], msg_content, next_msg_id_func()) # publish from the randomly chosen host - await gossipsubs[origin_idx].publish(host_id, packet.SerializeToString()) + await pubsubs[origin_idx].publish("foobar", msg_content) await asyncio.sleep(0.5) # Assert that all blocking queues receive the message for queue in queues: msg = await queue.get() - assert msg.data == packet.publish[0].data + assert msg.data == msg_content await cleanup() + @pytest.mark.asyncio async def test_fanout(): # Create libp2p hosts - next_msg_id_func = message_id_generator(0) - num_hosts = 10 num_msgs = 5 libp2p_hosts = await create_libp2p_hosts(num_hosts) # Create pubsub, gossipsub instances - pubsubs, gossipsubs = create_pubsub_and_gossipsub_instances(libp2p_hosts, \ + pubsubs, _ = create_pubsub_and_gossipsub_instances(libp2p_hosts, \ SUPPORTED_PROTOCOLS, \ 10, 9, 11, 30, 3, 5, 0.5) - # All pubsub subscribe to foobar + # All pubsub subscribe to foobar except for `pubsubs[0]` queues = [] for i in range(1, len(pubsubs)): q = await pubsubs[i].subscribe("foobar") @@ -279,71 +273,61 @@ async def test_fanout(): # Wait 2 seconds for heartbeat to allow mesh to connect await asyncio.sleep(2) + topic = "foobar" # Send messages with origin not subscribed for i in range(num_msgs): - msg_content = "foo " + str(i) + msg_content = b"foo " + i.to_bytes(1, "big") # Pick the message origin to the node that is not subscribed to 'foobar' origin_idx = 0 - origin_host = libp2p_hosts[origin_idx] - host_id = str(origin_host.get_id()) - - # Generate message packet - packet = generate_RPC_packet(host_id, ["foobar"], msg_content, next_msg_id_func()) # publish from the randomly chosen host - await gossipsubs[origin_idx].publish(host_id, packet.SerializeToString()) + await pubsubs[origin_idx].publish(topic, msg_content) await asyncio.sleep(0.5) # Assert that all blocking queues receive the message for queue in queues: msg = await queue.get() - assert msg.SerializeToString() == packet.publish[0].SerializeToString() + assert msg.data == msg_content # Subscribe message origin - queues.append(await pubsubs[0].subscribe("foobar")) + queues.insert(0, await pubsubs[0].subscribe(topic)) # Send messages again for i in range(num_msgs): - msg_content = "foo " + str(i) + msg_content = b"bar " + i.to_bytes(1, 'big') # Pick the message origin to the node that is not subscribed to 'foobar' origin_idx = 0 - origin_host = libp2p_hosts[origin_idx] - host_id = str(origin_host.get_id()) - - # Generate message packet - packet = generate_RPC_packet(host_id, ["foobar"], msg_content, next_msg_id_func()) # publish from the randomly chosen host - await gossipsubs[origin_idx].publish(host_id, packet.SerializeToString()) + await pubsubs[origin_idx].publish(topic, msg_content) await asyncio.sleep(0.5) # Assert that all blocking queues receive the message for queue in queues: msg = await queue.get() - assert msg.SerializeToString() == packet.publish[0].SerializeToString() + assert msg.data == msg_content await cleanup() @pytest.mark.asyncio async def test_fanout_maintenance(): # Create libp2p hosts - next_msg_id_func = message_id_generator(0) - num_hosts = 10 num_msgs = 5 libp2p_hosts = await create_libp2p_hosts(num_hosts) # Create pubsub, gossipsub instances - pubsubs, gossipsubs = create_pubsub_and_gossipsub_instances(libp2p_hosts, \ + pubsubs, _ = create_pubsub_and_gossipsub_instances(libp2p_hosts, \ SUPPORTED_PROTOCOLS, \ 10, 9, 11, 30, 3, 5, 0.5) # All pubsub subscribe to foobar queues = [] + topic = "foobar" for i in range(1, len(pubsubs)): - q = await pubsubs[i].subscribe("foobar") + q = await pubsubs[i].subscribe(topic) # Add each blocking queue to an array of blocking queues queues.append(q) @@ -356,27 +340,22 @@ async def test_fanout_maintenance(): # Send messages with origin not subscribed for i in range(num_msgs): - msg_content = "foo " + str(i) + msg_content = b"foo " + i.to_bytes(1, 'big') # Pick the message origin to the node that is not subscribed to 'foobar' origin_idx = 0 - origin_host = libp2p_hosts[origin_idx] - host_id = str(origin_host.get_id()) - - # Generate message packet - packet = generate_RPC_packet(host_id, ["foobar"], msg_content, next_msg_id_func()) # publish from the randomly chosen host - await gossipsubs[origin_idx].publish(host_id, packet.SerializeToString()) + await pubsubs[origin_idx].publish(topic, msg_content) await asyncio.sleep(0.5) # Assert that all blocking queues receive the message for queue in queues: msg = await queue.get() - assert msg.SerializeToString() == packet.publish[0].SerializeToString() + assert msg.data == msg_content for sub in pubsubs: - await sub.unsubscribe('foobar') + await sub.unsubscribe(topic) queues = [] @@ -384,7 +363,7 @@ async def test_fanout_maintenance(): # Resub and repeat for i in range(1, len(pubsubs)): - q = await pubsubs[i].subscribe("foobar") + q = await pubsubs[i].subscribe(topic) # Add each blocking queue to an array of blocking queues queues.append(q) @@ -393,65 +372,61 @@ async def test_fanout_maintenance(): # Check messages can still be sent for i in range(num_msgs): - msg_content = "foo " + str(i) + msg_content = b"bar " + i.to_bytes(1, 'big') # Pick the message origin to the node that is not subscribed to 'foobar' origin_idx = 0 - origin_host = libp2p_hosts[origin_idx] - host_id = str(origin_host.get_id()) - - # Generate message packet - packet = generate_RPC_packet(host_id, ["foobar"], msg_content, next_msg_id_func()) # publish from the randomly chosen host - await gossipsubs[origin_idx].publish(host_id, packet.SerializeToString()) + await pubsubs[origin_idx].publish(topic, msg_content) await asyncio.sleep(0.5) # Assert that all blocking queues receive the message for queue in queues: msg = await queue.get() - assert msg.SerializeToString() == packet.publish[0].SerializeToString() + assert msg.data == msg_content await cleanup() + @pytest.mark.asyncio async def test_gossip_propagation(): # Create libp2p hosts - next_msg_id_func = message_id_generator(0) - num_hosts = 2 - libp2p_hosts = await create_libp2p_hosts(num_hosts) + hosts = await create_libp2p_hosts(num_hosts) # Create pubsub, gossipsub instances - pubsubs, gossipsubs = create_pubsub_and_gossipsub_instances(libp2p_hosts, \ - SUPPORTED_PROTOCOLS, \ - 1, 0, 2, 30, 50, 100, 0.5) - node1, node2 = libp2p_hosts[0], libp2p_hosts[1] - sub1, sub2 = pubsubs[0], pubsubs[1] - gsub1, gsub2 = gossipsubs[0], gossipsubs[1] + pubsubs, _ = create_pubsub_and_gossipsub_instances( + hosts, + SUPPORTED_PROTOCOLS, + 1, + 0, + 2, + 30, + 50, + 100, + 0.5, + ) - node1_queue = await sub1.subscribe('foo') + topic = "foo" + await pubsubs[0].subscribe(topic) - # node 1 publish to topic - msg_content = 'foo_msg' - node1_id = str(node1.get_id()) - - # Generate message packet - packet = generate_RPC_packet(node1_id, ["foo"], msg_content, next_msg_id_func()) + # node 0 publish to topic + msg_content = b'foo_msg' # publish from the randomly chosen host - await gsub1.publish(node1_id, packet.SerializeToString()) + await pubsubs[0].publish(topic, msg_content) - # now node 2 subscribes - node2_queue = await sub2.subscribe('foo') + # now node 1 subscribes + queue_1 = await pubsubs[1].subscribe(topic) - await connect(node2, node1) + await connect(hosts[0], hosts[1]) # wait for gossip heartbeat await asyncio.sleep(2) # should be able to read message - msg = await node2_queue.get() - assert msg.SerializeToString() == packet.publish[0].SerializeToString() + msg = await queue_1.get() + assert msg.data == msg_content await cleanup() diff --git a/tests/pubsub/test_gossipsub_backward_compatibility.py b/tests/pubsub/test_gossipsub_backward_compatibility.py index 468e25f..060973a 100644 --- a/tests/pubsub/test_gossipsub_backward_compatibility.py +++ b/tests/pubsub/test_gossipsub_backward_compatibility.py @@ -8,18 +8,17 @@ from libp2p.pubsub.gossipsub import GossipSub from libp2p.pubsub.floodsub import FloodSub from libp2p.pubsub.pb import rpc_pb2 from libp2p.pubsub.pubsub import Pubsub -from utils import message_id_generator, generate_RPC_packet + from tests.utils import cleanup -# pylint: disable=too-many-locals +from .utils import ( + connect, + message_id_generator, + generate_RPC_packet, +) -async def connect(node1, node2): - """ - Connect node1 to node2 - """ - addr = node2.get_addrs()[0] - info = info_from_p2p_addr(addr) - await node1.connect(info) + +# pylint: disable=too-many-locals @pytest.mark.asyncio async def test_init(): @@ -37,11 +36,12 @@ async def test_init(): await cleanup() + async def perform_test_from_obj(obj): """ Perform a floodsub test from a test obj. test obj are composed as follows: - + { "supported_protocols": ["supported/protocol/1.0.0",...], "adj_list": { @@ -95,7 +95,7 @@ async def perform_test_from_obj(obj): if neighbor_id not in node_map: neighbor_node = await new_node(transport_opt=["/ip4/127.0.0.1/tcp/0"]) await neighbor_node.get_network().listen(multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/0")) - + node_map[neighbor_id] = neighbor_node gossipsub = GossipSub(supported_protocols, 3, 2, 4, 30) @@ -104,7 +104,7 @@ async def perform_test_from_obj(obj): pubsub_map[neighbor_id] = pubsub # Connect node and neighbor - tasks_connect.append(asyncio.ensure_future(connect(node_map[start_node_id], node_map[neighbor_id]))) + tasks_connect.append(connect(node_map[start_node_id], node_map[neighbor_id])) tasks_connect.append(asyncio.sleep(2)) await asyncio.gather(*tasks_connect) @@ -130,7 +130,7 @@ async def perform_test_from_obj(obj): # Store queue in topic-queue map for node queues_map[node_id][topic] = q """ - tasks_topic.append(asyncio.ensure_future(pubsub_map[node_id].subscribe(topic))) + tasks_topic.append(pubsub_map[node_id].subscribe(topic)) tasks_topic_data.append((node_id, topic)) tasks_topic.append(asyncio.sleep(2)) @@ -152,29 +152,27 @@ async def perform_test_from_obj(obj): topics_in_msgs_ordered = [] messages = obj["messages"] tasks_publish = [] - next_msg_id_func = message_id_generator(0) for msg in messages: topics = msg["topics"] - data = msg["data"] node_id = msg["node_id"] - # Get actual id for sender node (not the id from the test obj) - actual_node_id = str(node_map[node_id].get_id()) - - # Create correctly formatted message - msg_talk = generate_RPC_packet(actual_node_id, topics, data, next_msg_id_func()) - # Publish message - tasks_publish.append(asyncio.ensure_future(gossipsub_map[node_id].publish(\ - actual_node_id, msg_talk.SerializeToString()))) + # FIXME: This should be one RPC packet with several topics + for topic in topics: + tasks_publish.append( + pubsub_map[node_id].publish( + topic, + data, + ) + ) # For each topic in topics, add topic, msg_talk tuple to ordered test list # TODO: Update message sender to be correct message sender before # adding msg_talk to this list for topic in topics: - topics_in_msgs_ordered.append((topic, msg_talk)) + topics_in_msgs_ordered.append((topic, data)) # Allow time for publishing before continuing # await asyncio.sleep(0.4) @@ -183,14 +181,12 @@ async def perform_test_from_obj(obj): # Step 4) Check that all messages were received correctly. # TODO: Check message sender too - for i in range(len(topics_in_msgs_ordered)): - topic, actual_msg = topics_in_msgs_ordered[i] - + for topic, data in topics_in_msgs_ordered: # Look at each node in each topic for node_id in topic_map[topic]: # Get message from subscription queue msg_on_node = await queues_map[node_id][topic].get() - assert actual_msg.publish[0].SerializeToString() == msg_on_node.SerializeToString() + assert msg_on_node.data == data # Success, terminate pending tasks. await cleanup() diff --git a/tests/pubsub/utils.py b/tests/pubsub/utils.py index 72d796f..bb49a2f 100644 --- a/tests/pubsub/utils.py +++ b/tests/pubsub/utils.py @@ -14,6 +14,8 @@ from libp2p.peer.id import ID from libp2p.pubsub.pubsub import Pubsub from libp2p.pubsub.gossipsub import GossipSub +from tests.utils import connect + def message_id_generator(start_val): """ @@ -22,6 +24,7 @@ def message_id_generator(start_val): :return: message id """ val = start_val + def generator(): # Allow manipulation of val within closure nonlocal val @@ -105,6 +108,10 @@ def create_pubsub_and_gossipsub_instances(libp2p_hosts, supported_protocols, deg return pubsubs, gossipsubs + +# FIXME: There is no difference between `sparse_connect` and `dense_connect`, +# before `connect_some` is fixed. + async def sparse_connect(hosts): await connect_some(hosts, 3) @@ -113,6 +120,7 @@ async def dense_connect(hosts): await connect_some(hosts, 10) +# FIXME: `degree` is not used at all async def connect_some(hosts, degree): for i, host in enumerate(hosts): for j, host2 in enumerate(hosts): @@ -135,6 +143,7 @@ async def connect_some(hosts, degree): # j += 1 + async def one_to_all_connect(hosts, central_host_index): for i, host in enumerate(hosts): if i != central_host_index: