Merge pull request #182 from NIC619/fix_refactor_gossipsub_join
Fix and refactor: `gossipsub.join`
This commit is contained in:
commit
84824fd566
|
@ -184,43 +184,31 @@ class GossipSub(IPubsubRouter):
|
|||
# Create mesh[topic] if it does not yet exist
|
||||
self.mesh[topic] = []
|
||||
|
||||
if topic in self.fanout and len(self.fanout[topic]) == self.degree:
|
||||
# If router already has D peers from the fanout peers of a topic
|
||||
# TODO: Do we remove all peers from fanout[topic]?
|
||||
|
||||
# Add them to mesh[topic], and notifies them with a
|
||||
# GRAFT(topic) control message.
|
||||
for peer in self.fanout[topic]:
|
||||
self.mesh[topic].append(peer)
|
||||
await self.emit_graft(topic, peer)
|
||||
else:
|
||||
# Otherwise, if there are less than D peers
|
||||
# (let this number be x) in the fanout for a topic (or the topic is not in the fanout),
|
||||
fanout_size = 0
|
||||
if topic in self.fanout:
|
||||
fanout_size = len(self.fanout[topic])
|
||||
# then it still adds them as above (if there are any)
|
||||
for peer in self.fanout[topic]:
|
||||
self.mesh[topic].append(peer)
|
||||
await self.emit_graft(topic, peer)
|
||||
|
||||
if topic in self.peers_gossipsub:
|
||||
# TODO: Should we have self.fanout[topic] here or [] (as the minus variable)?
|
||||
# Selects the remaining number of peers (D-x) from peers.gossipsub[topic]
|
||||
topic_in_fanout = topic in self.fanout
|
||||
fanout_peers = self.fanout[topic] if topic_in_fanout else []
|
||||
fanout_size = len(fanout_peers)
|
||||
if not topic_in_fanout or (topic_in_fanout and fanout_size < self.degree):
|
||||
# There are less than D peers (let this number be x)
|
||||
# in the fanout for a topic (or the topic is not in the fanout).
|
||||
# Selects the remaining number of peers (D-x) from peers.gossipsub[topic].
|
||||
if topic in self.pubsub.peer_topics:
|
||||
gossipsub_peers_in_topic = [peer for peer in self.pubsub.peer_topics[topic]
|
||||
if peer in self.peers_gossipsub]
|
||||
selected_peers = \
|
||||
GossipSub.select_from_minus(self.degree - fanout_size,
|
||||
gossipsub_peers_in_topic,
|
||||
self.fanout[topic] if topic in self.fanout else [])
|
||||
fanout_peers)
|
||||
|
||||
# And likewise adds them to mesh[topic] and notifies them with a
|
||||
# GRAFT(topic) control message.
|
||||
for peer in selected_peers:
|
||||
self.mesh[topic].append(peer)
|
||||
await self.emit_graft(topic, peer)
|
||||
# Combine fanout peers with selected peers
|
||||
fanout_peers += selected_peers
|
||||
|
||||
# TODO: Do we remove all peers from fanout[topic]?
|
||||
# Add fanout peers to mesh and notifies them with a GRAFT(topic) control message.
|
||||
for peer in fanout_peers:
|
||||
self.mesh[topic].append(peer)
|
||||
await self.emit_graft(topic, peer)
|
||||
|
||||
if topic_in_fanout:
|
||||
del self.fanout[topic]
|
||||
|
||||
async def leave(self, topic):
|
||||
# Note: the comments here are the near-exact algorithm description from the spec
|
||||
|
@ -277,13 +265,21 @@ class GossipSub(IPubsubRouter):
|
|||
async def mesh_heartbeat(self):
|
||||
# Note: the comments here are the exact pseudocode from the spec
|
||||
for topic in self.mesh:
|
||||
# Skip if no peers have subscribed to the topic
|
||||
if topic not in self.pubsub.peer_topics:
|
||||
continue
|
||||
|
||||
num_mesh_peers_in_topic = len(self.mesh[topic])
|
||||
if num_mesh_peers_in_topic < self.degree_low:
|
||||
gossipsub_peers_in_topic = [peer for peer in self.pubsub.peer_topics[topic]
|
||||
if peer in self.peers_gossipsub]
|
||||
|
||||
# Select D - |mesh[topic]| peers from peers.gossipsub[topic] - mesh[topic]
|
||||
selected_peers = GossipSub.select_from_minus(self.degree - num_mesh_peers_in_topic,
|
||||
self.peers_gossipsub, self.mesh[topic])
|
||||
selected_peers = GossipSub.select_from_minus(
|
||||
self.degree - num_mesh_peers_in_topic,
|
||||
gossipsub_peers_in_topic,
|
||||
self.mesh[topic]
|
||||
)
|
||||
|
||||
for peer in selected_peers:
|
||||
# Add peer to mesh[topic]
|
||||
|
@ -310,7 +306,7 @@ class GossipSub(IPubsubRouter):
|
|||
# TODO: there's no way time_since_last_publish gets set anywhere yet
|
||||
if self.time_since_last_publish[topic] > self.time_to_live:
|
||||
# Remove topic from fanout
|
||||
self.fanout.remove(topic)
|
||||
del self.fanout[topic]
|
||||
self.time_since_last_publish.remove(topic)
|
||||
else:
|
||||
num_fanout_peers_in_topic = len(self.fanout[topic])
|
||||
|
|
|
@ -2,13 +2,9 @@ import asyncio
|
|||
import pytest
|
||||
import random
|
||||
|
||||
from libp2p.pubsub.gossipsub import GossipSub
|
||||
from libp2p.pubsub.floodsub import FloodSub
|
||||
from libp2p.pubsub.pb import rpc_pb2
|
||||
from libp2p.pubsub.pubsub import Pubsub
|
||||
from utils import message_id_generator, generate_RPC_packet, \
|
||||
create_libp2p_hosts, create_pubsub_and_gossipsub_instances, sparse_connect, dense_connect, \
|
||||
connect
|
||||
connect, one_to_all_connect
|
||||
from tests.utils import cleanup
|
||||
|
||||
SUPPORTED_PROTOCOLS = ["/gossipsub/1.0.0"]
|
||||
|
@ -16,23 +12,63 @@ SUPPORTED_PROTOCOLS = ["/gossipsub/1.0.0"]
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_join():
|
||||
num_hosts = 1
|
||||
# Create libp2p hosts
|
||||
num_hosts = 4
|
||||
hosts_indices = list(range(num_hosts))
|
||||
libp2p_hosts = await create_libp2p_hosts(num_hosts)
|
||||
|
||||
# Create pubsub, gossipsub instances
|
||||
_, gossipsubs = create_pubsub_and_gossipsub_instances(libp2p_hosts, \
|
||||
pubsubs, gossipsubs = create_pubsub_and_gossipsub_instances(libp2p_hosts, \
|
||||
SUPPORTED_PROTOCOLS, \
|
||||
10, 9, 11, 30, 3, 5, 0.5)
|
||||
4, 3, 5, 30, 3, 5, 0.5)
|
||||
|
||||
gossipsub = gossipsubs[0]
|
||||
topic = "test_join"
|
||||
central_node_index = 0
|
||||
# Remove index of central host from the indices
|
||||
hosts_indices.remove(central_node_index)
|
||||
num_subscribed_peer = 2
|
||||
subscribed_peer_indices = random.sample(hosts_indices, num_subscribed_peer)
|
||||
|
||||
assert topic not in gossipsub.mesh
|
||||
await gossipsub.join(topic)
|
||||
assert topic in gossipsub.mesh
|
||||
# All pubsub except the one of central node subscribe to topic
|
||||
for i in subscribed_peer_indices:
|
||||
await pubsubs[i].subscribe(topic)
|
||||
|
||||
# Test re-join
|
||||
await gossipsub.join(topic)
|
||||
# Connect central host to all other hosts
|
||||
await one_to_all_connect(libp2p_hosts, central_node_index)
|
||||
|
||||
# Wait 2 seconds for heartbeat to allow mesh to connect
|
||||
await asyncio.sleep(2)
|
||||
|
||||
# Central node publish to the topic so that this topic
|
||||
# is added to central node's fanout
|
||||
next_msg_id_func = message_id_generator(0)
|
||||
msg_content = ""
|
||||
host_id = str(libp2p_hosts[central_node_index].get_id())
|
||||
# Generate message packet
|
||||
packet = generate_RPC_packet(host_id, [topic], msg_content, next_msg_id_func())
|
||||
# publish from the randomly chosen host
|
||||
await gossipsubs[central_node_index].publish(host_id, packet.SerializeToString())
|
||||
|
||||
# Check that the gossipsub of central node has fanout for the topic
|
||||
assert topic in gossipsubs[central_node_index].fanout
|
||||
# Check that the gossipsub of central node does not have a mesh for the topic
|
||||
assert topic not in gossipsubs[central_node_index].mesh
|
||||
|
||||
# Central node subscribes the topic
|
||||
await pubsubs[central_node_index].subscribe(topic)
|
||||
|
||||
await asyncio.sleep(2)
|
||||
|
||||
# Check that the gossipsub of central node no longer has fanout for the topic
|
||||
assert topic not in gossipsubs[central_node_index].fanout
|
||||
|
||||
for i in hosts_indices:
|
||||
if i in subscribed_peer_indices:
|
||||
assert str(libp2p_hosts[i].get_id()) in gossipsubs[central_node_index].mesh[topic]
|
||||
assert str(libp2p_hosts[central_node_index].get_id()) in gossipsubs[i].mesh[topic]
|
||||
else:
|
||||
assert str(libp2p_hosts[i].get_id()) not in gossipsubs[central_node_index].mesh[topic]
|
||||
assert topic not in gossipsubs[i].mesh
|
||||
|
||||
await cleanup()
|
||||
|
||||
|
@ -106,11 +142,9 @@ async def test_dense():
|
|||
|
||||
await asyncio.sleep(0.5)
|
||||
# Assert that all blocking queues receive the message
|
||||
items = []
|
||||
for queue in queues:
|
||||
msg = await queue.get()
|
||||
assert msg.data == packet.publish[0].data
|
||||
items.append(msg.data)
|
||||
await cleanup()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
@ -122,3 +122,8 @@ async def connect_some(hosts, degree):
|
|||
# await connect(host, neighbor)
|
||||
|
||||
# j += 1
|
||||
|
||||
async def one_to_all_connect(hosts, central_host_index):
|
||||
for i, host in enumerate(hosts):
|
||||
if i != central_host_index:
|
||||
await connect(hosts[central_host_index], host)
|
||||
|
|
Loading…
Reference in New Issue
Block a user