Merge pull request #387 from NIC619/fix_inconsistent_pubsub_peer_record_update

Store peer ids in set instead of list and check if peer id exist before access
This commit is contained in:
NIC Lin 2019-12-20 00:34:14 +08:00 committed by GitHub
commit 28da206aea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 89 additions and 89 deletions

View File

@ -77,16 +77,20 @@ class FloodSub(IPubsubRouter):
:param pubsub_msg: pubsub message in protobuf.
"""
peers_gen = self._get_peers_to_send(
peers_gen = set(
self._get_peers_to_send(
pubsub_msg.topicIDs,
msg_forwarder=msg_forwarder,
origin=ID(pubsub_msg.from_id),
)
)
rpc_msg = rpc_pb2.RPC(publish=[pubsub_msg])
logger.debug("publishing message %s", pubsub_msg)
for peer_id in peers_gen:
if peer_id not in self.pubsub.peers:
continue
stream = self.pubsub.peers[peer_id]
# FIXME: We should add a `WriteMsg` similar to write delimited messages.
# Ref: https://github.com/libp2p/go-libp2p-pubsub/blob/master/comm.go#L107
@ -94,6 +98,7 @@ class FloodSub(IPubsubRouter):
await stream.write(encode_varint_prefixed(rpc_msg.SerializeToString()))
except StreamClosed:
logger.debug("Fail to publish message to %s: stream closed", peer_id)
self.pubsub._handle_dead_peer(peer_id)
async def join(self, topic: str) -> None:
"""

View File

@ -3,7 +3,7 @@ import asyncio
from collections import defaultdict
import logging
import random
from typing import Any, DefaultDict, Dict, Iterable, List, Sequence, Tuple
from typing import Any, DefaultDict, Dict, Iterable, List, Sequence, Set, Tuple
from libp2p.network.stream.exceptions import StreamClosed
from libp2p.peer.id import ID
@ -32,16 +32,14 @@ class GossipSub(IPubsubRouter):
time_to_live: int
mesh: Dict[str, List[ID]]
fanout: Dict[str, List[ID]]
mesh: Dict[str, Set[ID]]
fanout: Dict[str, Set[ID]]
peers_to_protocol: Dict[ID, str]
# The protocol peer supports
peer_protocol: Dict[ID, TProtocol]
time_since_last_publish: Dict[str, int]
peers_gossipsub: List[ID]
peers_floodsub: List[ID]
mcache: MessageCache
heartbeat_initial_delay: float
@ -75,14 +73,11 @@ class GossipSub(IPubsubRouter):
self.fanout = {}
# Create peer --> protocol mapping
self.peers_to_protocol = {}
self.peer_protocol = {}
# Create topic --> time since last publish map
self.time_since_last_publish = {}
self.peers_gossipsub = []
self.peers_floodsub = []
# Create message cache
self.mcache = MessageCache(gossip_window, gossip_history)
@ -121,17 +116,13 @@ class GossipSub(IPubsubRouter):
"""
logger.debug("adding peer %s with protocol %s", peer_id, protocol_id)
if protocol_id == PROTOCOL_ID:
self.peers_gossipsub.append(peer_id)
elif protocol_id == floodsub.PROTOCOL_ID:
self.peers_floodsub.append(peer_id)
else:
if protocol_id not in (PROTOCOL_ID, floodsub.PROTOCOL_ID):
# We should never enter here. Becuase the `protocol_id` is registered by your pubsub
# instance in multistream-select, but it is not the protocol that gossipsub supports.
# In this case, probably we registered gossipsub to a wrong `protocol_id`
# in multistream-select, or wrong versions.
raise Exception(f"Unreachable: Protocol={protocol_id} is not supported.")
self.peers_to_protocol[peer_id] = protocol_id
raise ValueError(f"Protocol={protocol_id} is not supported.")
self.peer_protocol[peer_id] = protocol_id
def remove_peer(self, peer_id: ID) -> None:
"""
@ -141,21 +132,12 @@ class GossipSub(IPubsubRouter):
"""
logger.debug("removing peer %s", peer_id)
if peer_id in self.peers_gossipsub:
self.peers_gossipsub.remove(peer_id)
elif peer_id in self.peers_floodsub:
self.peers_floodsub.remove(peer_id)
for topic in self.mesh:
if peer_id in self.mesh[topic]:
# Delete the entry if no other peers left
self.mesh[topic].remove(peer_id)
self.mesh[topic].discard(peer_id)
for topic in self.fanout:
if peer_id in self.fanout[topic]:
# Delete the entry if no other peers left
self.fanout[topic].remove(peer_id)
self.fanout[topic].discard(peer_id)
self.peers_to_protocol.pop(peer_id, None)
self.peer_protocol.pop(peer_id, None)
async def handle_rpc(self, rpc: rpc_pb2.RPC, sender_peer_id: ID) -> None:
"""
@ -195,6 +177,8 @@ class GossipSub(IPubsubRouter):
logger.debug("publishing message %s", pubsub_msg)
for peer_id in peers_gen:
if peer_id not in self.pubsub.peers:
continue
stream = self.pubsub.peers[peer_id]
# FIXME: We should add a `WriteMsg` similar to write delimited messages.
# Ref: https://github.com/libp2p/go-libp2p-pubsub/blob/master/comm.go#L107
@ -215,22 +199,21 @@ class GossipSub(IPubsubRouter):
:param origin: peer id of the peer the message originate from.
:return: a generator of the peer ids who we send data to.
"""
send_to: List[ID] = []
send_to: Set[ID] = set()
for topic in topic_ids:
if topic not in self.pubsub.peer_topics:
continue
# floodsub peers
# FIXME: `gossipsub.peers_floodsub` can be changed to `gossipsub.peers` in go.
# This will improve the efficiency when searching for a peer's protocol id.
floodsub_peers: List[ID] = [
floodsub_peers: Set[ID] = set(
peer_id
for peer_id in self.pubsub.peer_topics[topic]
if peer_id in self.peers_floodsub
]
if self.peer_protocol[peer_id] == floodsub.PROTOCOL_ID
)
send_to.update(floodsub_peers)
# gossipsub peers
gossipsub_peers: List[ID] = []
gossipsub_peers: Set[ID] = set()
if topic in self.mesh:
gossipsub_peers = self.mesh[topic]
else:
@ -238,21 +221,23 @@ class GossipSub(IPubsubRouter):
# `self.degree` number of peers who have subscribed to the topic and add them
# as our `fanout` peers.
topic_in_fanout: bool = topic in self.fanout
fanout_peers: List[ID] = self.fanout[topic] if topic_in_fanout else []
fanout_peers: Set[ID] = self.fanout[topic] if topic_in_fanout else set()
fanout_size = len(fanout_peers)
if not topic_in_fanout or (
topic_in_fanout and fanout_size < self.degree
):
if topic in self.pubsub.peer_topics:
# Combine fanout peers with selected peers
fanout_peers += self._get_in_topic_gossipsub_peers_from_minus(
fanout_peers.update(
self._get_in_topic_gossipsub_peers_from_minus(
topic, self.degree - fanout_size, fanout_peers
)
)
self.fanout[topic] = fanout_peers
gossipsub_peers = fanout_peers
send_to.extend(floodsub_peers + gossipsub_peers)
send_to.update(gossipsub_peers)
# Excludes `msg_forwarder` and `origin`
yield from set(send_to).difference([msg_forwarder, origin])
yield from send_to.difference([msg_forwarder, origin])
async def join(self, topic: str) -> None:
"""
@ -266,10 +251,10 @@ class GossipSub(IPubsubRouter):
if topic in self.mesh:
return
# Create mesh[topic] if it does not yet exist
self.mesh[topic] = []
self.mesh[topic] = set()
topic_in_fanout: bool = topic in self.fanout
fanout_peers: List[ID] = self.fanout[topic] if topic_in_fanout else []
fanout_peers: Set[ID] = self.fanout[topic] if topic_in_fanout else set()
fanout_size = len(fanout_peers)
if not topic_in_fanout or (topic_in_fanout and fanout_size < self.degree):
# There are less than D peers (let this number be x)
@ -280,11 +265,11 @@ class GossipSub(IPubsubRouter):
topic, self.degree - fanout_size, fanout_peers
)
# Combine fanout peers with selected peers
fanout_peers += selected_peers
fanout_peers.update(selected_peers)
# Add fanout peers to mesh and notifies them with a GRAFT(topic) control message.
for peer in fanout_peers:
self.mesh[topic].append(peer)
self.mesh[topic].add(peer)
await self.emit_graft(topic, peer)
self.fanout.pop(topic, None)
@ -421,7 +406,7 @@ class GossipSub(IPubsubRouter):
for peer in selected_peers:
# Add peer to mesh[topic]
self.mesh[topic].append(peer)
self.mesh[topic].add(peer)
# Emit GRAFT(topic) control message to peer
peers_to_graft[peer].append(topic)
@ -429,11 +414,11 @@ class GossipSub(IPubsubRouter):
if num_mesh_peers_in_topic > self.degree_high:
# Select |mesh[topic]| - D peers from mesh[topic]
selected_peers = GossipSub.select_from_minus(
num_mesh_peers_in_topic - self.degree, self.mesh[topic], []
num_mesh_peers_in_topic - self.degree, self.mesh[topic], set()
)
for peer in selected_peers:
# Remove peer from mesh[topic]
self.mesh[topic].remove(peer)
self.mesh[topic].discard(peer)
# Emit PRUNE(topic) control message to peer
peers_to_prune[peer].append(topic)
@ -460,7 +445,7 @@ class GossipSub(IPubsubRouter):
for peer in self.fanout[topic]
if peer in self.pubsub.peer_topics[topic]
]
self.fanout[topic] = in_topic_fanout_peers
self.fanout[topic] = set(in_topic_fanout_peers)
num_fanout_peers_in_topic = len(self.fanout[topic])
# If |fanout[topic]| < D
@ -472,7 +457,7 @@ class GossipSub(IPubsubRouter):
self.fanout[topic],
)
# Add the peers to fanout[topic]
self.fanout[topic].extend(selected_peers)
self.fanout[topic].update(selected_peers)
def gossip_heartbeat(self) -> DefaultDict[ID, Dict[str, List[str]]]:
peers_to_gossip: DefaultDict[ID, Dict[str, List[str]]] = defaultdict(dict)
@ -508,7 +493,7 @@ class GossipSub(IPubsubRouter):
@staticmethod
def select_from_minus(
num_to_select: int, pool: Sequence[Any], minus: Sequence[Any]
num_to_select: int, pool: Iterable[Any], minus: Iterable[Any]
) -> List[Any]:
"""
Select at most num_to_select subset of elements from the set (pool - minus) randomly.
@ -527,7 +512,7 @@ class GossipSub(IPubsubRouter):
# If num_to_select > size(selection_pool), then return selection_pool (which has the most
# possible elements s.t. the number of elements is less than num_to_select)
if num_to_select > len(selection_pool):
if num_to_select >= len(selection_pool):
return selection_pool
# Random selection
@ -536,16 +521,14 @@ class GossipSub(IPubsubRouter):
return selection
def _get_in_topic_gossipsub_peers_from_minus(
self, topic: str, num_to_select: int, minus: Sequence[ID]
self, topic: str, num_to_select: int, minus: Iterable[ID]
) -> List[ID]:
gossipsub_peers_in_topic = [
gossipsub_peers_in_topic = set(
peer_id
for peer_id in self.pubsub.peer_topics[topic]
if peer_id in self.peers_gossipsub
]
return self.select_from_minus(
num_to_select, gossipsub_peers_in_topic, list(minus)
if self.peer_protocol[peer_id] == PROTOCOL_ID
)
return self.select_from_minus(num_to_select, gossipsub_peers_in_topic, minus)
# RPC handlers
@ -603,6 +586,12 @@ class GossipSub(IPubsubRouter):
rpc_msg: bytes = packet.SerializeToString()
# 3) Get the stream to this peer
if sender_peer_id not in self.pubsub.peers:
logger.debug(
"Fail to responed to iwant request from %s: peer record not exist",
sender_peer_id,
)
return
peer_stream = self.pubsub.peers[sender_peer_id]
# 4) And write the packet to the stream
@ -623,7 +612,7 @@ class GossipSub(IPubsubRouter):
# Add peer to mesh for topic
if topic in self.mesh:
if sender_peer_id not in self.mesh[topic]:
self.mesh[topic].append(sender_peer_id)
self.mesh[topic].add(sender_peer_id)
else:
# Respond with PRUNE if not subscribed to the topic
await self.emit_prune(topic, sender_peer_id)
@ -633,9 +622,9 @@ class GossipSub(IPubsubRouter):
) -> None:
topic: str = prune_msg.topicID
# Remove peer from mesh for topic, if peer is in topic
if topic in self.mesh and sender_peer_id in self.mesh[topic]:
self.mesh[topic].remove(sender_peer_id)
# Remove peer from mesh for topic
if topic in self.mesh:
self.mesh[topic].discard(sender_peer_id)
# RPC emitters
@ -709,6 +698,11 @@ class GossipSub(IPubsubRouter):
rpc_msg: bytes = packet.SerializeToString()
# Get stream for peer from pubsub
if to_peer not in self.pubsub.peers:
logger.debug(
"Fail to emit control message to %s: peer record not exist", to_peer
)
return
peer_stream = self.pubsub.peers[to_peer]
# Write rpc to stream

View File

@ -8,6 +8,7 @@ from typing import (
Dict,
List,
NamedTuple,
Set,
Tuple,
Union,
cast,
@ -73,7 +74,7 @@ class Pubsub:
my_topics: Dict[str, "asyncio.Queue[rpc_pb2.Message]"]
peer_topics: Dict[str, List[ID]]
peer_topics: Dict[str, Set[ID]]
peers: Dict[ID, INetStream]
topic_validators: Dict[str, TopicValidator]
@ -289,24 +290,22 @@ class Pubsub:
logger.debug("fail to add new peer %s, error %s", peer_id, error)
return
self.peers[peer_id] = stream
# Send hello packet
hello = self.get_hello_packet()
try:
await stream.write(encode_varint_prefixed(hello.SerializeToString()))
except StreamClosed:
logger.debug("Fail to add new peer %s: stream closed", peer_id)
del self.peers[peer_id]
return
# TODO: Check if the peer in black list.
try:
self.router.add_peer(peer_id, stream.get_protocol())
except Exception as error:
logger.debug("fail to add new peer %s, error %s", peer_id, error)
del self.peers[peer_id]
return
self.peers[peer_id] = stream
logger.debug("added new peer %s", peer_id)
def _handle_dead_peer(self, peer_id: ID) -> None:
@ -316,8 +315,7 @@ class Pubsub:
for topic in self.peer_topics:
if peer_id in self.peer_topics[topic]:
# Delete the entry if no other peers left
self.peer_topics[topic].remove(peer_id)
self.peer_topics[topic].discard(peer_id)
self.router.remove_peer(peer_id)
@ -353,15 +351,14 @@ class Pubsub:
"""
if sub_message.subscribe:
if sub_message.topicid not in self.peer_topics:
self.peer_topics[sub_message.topicid] = [origin_id]
self.peer_topics[sub_message.topicid] = set([origin_id])
elif origin_id not in self.peer_topics[sub_message.topicid]:
# Add peer to topic
self.peer_topics[sub_message.topicid].append(origin_id)
self.peer_topics[sub_message.topicid].add(origin_id)
else:
if sub_message.topicid in self.peer_topics:
if origin_id in self.peer_topics[sub_message.topicid]:
# Delete the entry if no other peers left
self.peer_topics[sub_message.topicid].remove(origin_id)
self.peer_topics[sub_message.topicid].discard(origin_id)
# FIXME(mhchia): Change the function name?
async def handle_talk(self, publish_message: rpc_pb2.Message) -> None:

View File

@ -0,0 +1 @@
Store peer ids in ``set`` instead of ``list`` and check if peer id exists in ``dict`` before accessing to prevent ``KeyError``.

View File

@ -4,6 +4,7 @@ import random
import pytest
from libp2p.peer.id import ID
from libp2p.pubsub.gossipsub import PROTOCOL_ID
from libp2p.tools.constants import GOSSIPSUB_PARAMS, GossipsubParams
from libp2p.tools.pubsub.utils import dense_connect, one_to_all_connect
from libp2p.tools.utils import connect
@ -108,7 +109,7 @@ async def test_handle_graft(pubsubs_gsub, hosts, event_loop, monkeypatch):
monkeypatch.setattr(gossipsubs[index_bob], "emit_prune", emit_prune)
# Check that alice is bob's peer but not his mesh peer
assert id_alice in gossipsubs[index_bob].peers_gossipsub
assert gossipsubs[index_bob].peer_protocol[id_alice] == PROTOCOL_ID
assert topic not in gossipsubs[index_bob].mesh
await gossipsubs[index_alice].emit_graft(topic, id_bob)
@ -120,7 +121,7 @@ async def test_handle_graft(pubsubs_gsub, hosts, event_loop, monkeypatch):
# Check that bob is alice's peer but not her mesh peer
assert topic in gossipsubs[index_alice].mesh
assert id_bob not in gossipsubs[index_alice].mesh[topic]
assert id_bob in gossipsubs[index_alice].peers_gossipsub
assert gossipsubs[index_alice].peer_protocol[id_bob] == PROTOCOL_ID
await gossipsubs[index_bob].emit_graft(topic, id_alice)
@ -390,15 +391,16 @@ async def test_mesh_heartbeat(
fake_peer_ids = [
ID((i).to_bytes(2, byteorder="big")) for i in range(total_peer_count)
]
monkeypatch.setattr(pubsubs_gsub[0].router, "peers_gossipsub", fake_peer_ids)
peer_protocol = {peer_id: PROTOCOL_ID for peer_id in fake_peer_ids}
monkeypatch.setattr(pubsubs_gsub[0].router, "peer_protocol", peer_protocol)
peer_topics = {topic: fake_peer_ids}
peer_topics = {topic: set(fake_peer_ids)}
# Monkeypatch the peer subscriptions
monkeypatch.setattr(pubsubs_gsub[0], "peer_topics", peer_topics)
mesh_peer_indices = random.sample(range(total_peer_count), initial_mesh_peer_count)
mesh_peers = [fake_peer_ids[i] for i in mesh_peer_indices]
router_mesh = {topic: list(mesh_peers)}
router_mesh = {topic: set(mesh_peers)}
# Monkeypatch our mesh peers
monkeypatch.setattr(pubsubs_gsub[0].router, "mesh", router_mesh)
@ -437,27 +439,28 @@ async def test_gossip_heartbeat(
fake_peer_ids = [
ID((i).to_bytes(2, byteorder="big")) for i in range(total_peer_count)
]
monkeypatch.setattr(pubsubs_gsub[0].router, "peers_gossipsub", fake_peer_ids)
peer_protocol = {peer_id: PROTOCOL_ID for peer_id in fake_peer_ids}
monkeypatch.setattr(pubsubs_gsub[0].router, "peer_protocol", peer_protocol)
topic_mesh_peer_count = 14
# Split into mesh peers and fanout peers
peer_topics = {
topic_mesh: fake_peer_ids[:topic_mesh_peer_count],
topic_fanout: fake_peer_ids[topic_mesh_peer_count:],
topic_mesh: set(fake_peer_ids[:topic_mesh_peer_count]),
topic_fanout: set(fake_peer_ids[topic_mesh_peer_count:]),
}
# Monkeypatch the peer subscriptions
monkeypatch.setattr(pubsubs_gsub[0], "peer_topics", peer_topics)
mesh_peer_indices = random.sample(range(topic_mesh_peer_count), initial_peer_count)
mesh_peers = [fake_peer_ids[i] for i in mesh_peer_indices]
router_mesh = {topic_mesh: list(mesh_peers)}
router_mesh = {topic_mesh: set(mesh_peers)}
# Monkeypatch our mesh peers
monkeypatch.setattr(pubsubs_gsub[0].router, "mesh", router_mesh)
fanout_peer_indices = random.sample(
range(topic_mesh_peer_count, total_peer_count), initial_peer_count
)
fanout_peers = [fake_peer_ids[i] for i in fanout_peer_indices]
router_fanout = {topic_fanout: list(fanout_peers)}
router_fanout = {topic_fanout: set(fanout_peers)}
# Monkeypatch our fanout peers
monkeypatch.setattr(pubsubs_gsub[0].router, "fanout", router_fanout)

View File

@ -99,7 +99,7 @@ async def test_pubsub(pubsubs, p2pds):
go_0_topic_1_peers = await p2pds[0].control.pubsub_list_peers(TOPIC_1)
assert len(go_0_topic_1_peers) == 1 and py_peer_id == go_0_topic_1_peers[0]
# py
py_topic_0_peers = py_pubsub.peer_topics[TOPIC_0]
py_topic_0_peers = list(py_pubsub.peer_topics[TOPIC_0])
assert len(py_topic_0_peers) == 1 and p2pds[0].peer_id == py_topic_0_peers[0]
# go_1
go_1_topic_1_peers = await p2pds[1].control.pubsub_list_peers(TOPIC_1)