diff --git a/libp2p/pubsub/pubsub.py b/libp2p/pubsub/pubsub.py index 40c4036..2109a9c 100644 --- a/libp2p/pubsub/pubsub.py +++ b/libp2p/pubsub/pubsub.py @@ -1,4 +1,6 @@ +# pylint: disable=no-name-in-module import asyncio +from lru import LRU from .pb import rpc_pb2 from .pubsub_notifee import PubsubNotifee @@ -7,7 +9,7 @@ from .pubsub_notifee import PubsubNotifee class Pubsub(): # pylint: disable=too-many-instance-attributes, no-member - def __init__(self, host, router, my_id): + def __init__(self, host, router, my_id, cache_size=None): """ Construct a new Pubsub object, which is responsible for handling all Pubsub-related messages and relaying messages as appropriate to the @@ -37,8 +39,12 @@ class Pubsub(): self.incoming_msgs_from_peers = asyncio.Queue() self.outgoing_messages = asyncio.Queue() - # TODO: Make seen_messages a cache (LRU cache?) - self.seen_messages = [] + # keeps track of seen messages as LRU cache + if cache_size is None: + self.cache_size = 128 + else: + self.cache_size = cache_size + self.seen_messages = LRU(self.cache_size) # Map of topics we are subscribed to to handler functions # for when the given topic receives a message @@ -89,7 +95,7 @@ class Pubsub(): id_in_seen_msgs = (message.seqno, message.from_id) if id_in_seen_msgs not in self.seen_messages: should_publish = True - self.seen_messages.append(id_in_seen_msgs) + self.seen_messages[id_in_seen_msgs] = 1 await self.handle_talk(message) if rpc_incoming.subscriptions: diff --git a/requirements_dev.txt b/requirements_dev.txt index e3d8c0e..3d5ca3c 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -5,3 +5,4 @@ pytest-asyncio pylint grpcio grpcio-tools +lru-dict>=1.1.6 \ No newline at end of file diff --git a/setup.py b/setup.py index c44a276..0c62145 100644 --- a/setup.py +++ b/setup.py @@ -23,7 +23,8 @@ setuptools.setup( "pymultihash", "multiaddr", "grpcio", - "grpcio-tools" + "grpcio-tools", + "lru-dict>=1.1.6" ], packages=["libp2p"], zip_safe=False, diff --git 
a/tests/pubsub/test_floodsub.py b/tests/pubsub/test_floodsub.py index 69b2f8c..06c1cbd 100644 --- a/tests/pubsub/test_floodsub.py +++ b/tests/pubsub/test_floodsub.py @@ -21,7 +21,7 @@ async def connect(node1, node2): await node1.connect(info) @pytest.mark.asyncio -async def test_simple_two_nodes_RPC(): +async def test_simple_two_nodes(): node_a = await new_node(transport_opt=["/ip4/127.0.0.1/tcp/0"]) node_b = await new_node(transport_opt=["/ip4/127.0.0.1/tcp/0"]) @@ -58,6 +58,80 @@ async def test_simple_two_nodes_RPC(): # Success, terminate pending tasks. await cleanup() +@pytest.mark.asyncio +async def test_lru_cache_two_nodes(): + # two nodes with cache_size of 4 + # node_a sends the following messages to node_b + # [1, 1, 2, 1, 3, 1, 4, 1, 5, 1] + # node_b should only receive the following + # [1, 2, 3, 4, 5, 1] + node_a = await new_node(transport_opt=["/ip4/127.0.0.1/tcp/0"]) + node_b = await new_node(transport_opt=["/ip4/127.0.0.1/tcp/0"]) + + await node_a.get_network().listen(multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/0")) + await node_b.get_network().listen(multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/0")) + + supported_protocols = ["/floodsub/1.0.0"] + + # initialize Pubsub with a cache_size of 4 + floodsub_a = FloodSub(supported_protocols) + pubsub_a = Pubsub(node_a, floodsub_a, "a", 4) + floodsub_b = FloodSub(supported_protocols) + pubsub_b = Pubsub(node_b, floodsub_b, "b", 4) + + await connect(node_a, node_b) + + await asyncio.sleep(0.25) + qb = await pubsub_b.subscribe("my_topic") + + await asyncio.sleep(0.25) + + node_a_id = str(node_a.get_id()) + + # initialize message_id_generator + # store first message + next_msg_id_func = message_id_generator(0) + first_message = generate_RPC_packet(node_a_id, ["my_topic"], "some data 1", next_msg_id_func()) + + await floodsub_a.publish(node_a_id, first_message.SerializeToString()) + await asyncio.sleep(0.25) + print (first_message) + + messages = [first_message] + # for the next 4 messages (i = 2..5) + for i in range(2, 6): + # 
write first message + await floodsub_a.publish(node_a_id, first_message.SerializeToString()) + await asyncio.sleep(0.25) + + # generate and write next message + msg = generate_RPC_packet(node_a_id, ["my_topic"], "some data " + str(i), next_msg_id_func()) + messages.append(msg) + + await floodsub_a.publish(node_a_id, msg.SerializeToString()) + await asyncio.sleep(0.25) + + # write first message again + await floodsub_a.publish(node_a_id, first_message.SerializeToString()) + await asyncio.sleep(0.25) + + # check the first five messages in queue + # should only see 1 first_message + for i in range(5): + # Check that the msg received by node_b is the same + # as the message sent by node_a + res_b = await qb.get() + assert res_b.SerializeToString() == messages[i].publish[0].SerializeToString() + + # the 6th message should be first_message + res_b = await qb.get() + assert res_b.SerializeToString() == first_message.publish[0].SerializeToString() + assert qb.empty() + + # Success, terminate pending tasks. + await cleanup() + + async def perform_test_from_obj(obj): """ Perform a floodsub test from a test obj.