diff --git a/tests/pubsub/test_gossipsub.py b/tests/pubsub/test_gossipsub.py index 2121f8f..3ff602d 100644 --- a/tests/pubsub/test_gossipsub.py +++ b/tests/pubsub/test_gossipsub.py @@ -3,7 +3,8 @@ import random import pytest -from libp2p.tools.constants import GossipsubParams +from libp2p.peer.id import ID +from libp2p.tools.constants import GOSSIPSUB_PARAMS, GossipsubParams from libp2p.tools.pubsub.utils import dense_connect, one_to_all_connect from libp2p.tools.utils import connect @@ -366,3 +367,42 @@ async def test_gossip_propagation(hosts, pubsubs_gsub): # should be able to read message msg = await queue_1.get() assert msg.data == msg_content + + +@pytest.mark.parametrize( + "num_hosts, gossipsub_params", ((1, GossipsubParams(heartbeat_initial_delay=100)),) +) +@pytest.mark.parametrize("initial_mesh_peer_count", (7, 10, 13)) +@pytest.mark.asyncio +async def test_mesh_heartbeat( + num_hosts, initial_mesh_peer_count, pubsubs_gsub, hosts, monkeypatch +): + total_peer_count = 14 + topic = "TEST_MESH_HEARTBEAT" + + fake_peer_ids = [ + ID((i).to_bytes(2, byteorder="big")) for i in range(total_peer_count) + ] + monkeypatch.setattr(pubsubs_gsub[0].router, "peers_gossipsub", fake_peer_ids) + + peer_topics = {topic: fake_peer_ids} + monkeypatch.setattr(pubsubs_gsub[0], "peer_topics", peer_topics) + + mesh_peer_indices = random.sample(range(total_peer_count), initial_mesh_peer_count) + mesh_peers = [fake_peer_ids[i] for i in mesh_peer_indices] + router_mesh = {topic: list(mesh_peers)} + monkeypatch.setattr(pubsubs_gsub[0].router, "mesh", router_mesh) + + peers_to_graft, peers_to_prune = pubsubs_gsub[0].router.mesh_heartbeat() + if initial_mesh_peer_count > GOSSIPSUB_PARAMS.degree: + assert len(peers_to_graft) == 0 + assert len(peers_to_prune) == initial_mesh_peer_count - GOSSIPSUB_PARAMS.degree + for peer in peers_to_prune: + assert peer in mesh_peers + elif initial_mesh_peer_count < GOSSIPSUB_PARAMS.degree: + assert len(peers_to_prune) == 0 + assert len(peers_to_graft) == GOSSIPSUB_PARAMS.degree - initial_mesh_peer_count + for peer in peers_to_graft: + assert peer not in mesh_peers + else: + assert len(peers_to_prune) == 0 and len(peers_to_graft) == 0