Fix several tests

parent 035d08b8bd
commit 65aedcb25a
@@ -80,6 +80,21 @@ class FloodSub(IPubsubRouter):
             # Ref: https://github.com/libp2p/go-libp2p-pubsub/blob/master/comm.go#L107
             await stream.write(rpc_msg.SerializeToString())

+    async def join(self, topic):
+        """
+        Join notifies the router that we want to receive and
+        forward messages in a topic. It is invoked after the
+        subscription announcement
+        :param topic: topic to join
+        """
+
+    async def leave(self, topic):
+        """
+        Leave notifies the router that we are no longer interested in a topic.
+        It is invoked after the unsubscription announcement.
+        :param topic: topic to leave
+        """
+
     def _get_peers_to_send(
             self,
             topic_ids: Iterable[str],
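Note: the FIXME referenced above asks for a `WriteMsg`-style helper; go-libp2p-pubsub frames each RPC with a varint length prefix (protobuf's delimited convention). A minimal sketch of such a helper, assuming only the `stream.write(bytes)` API used in this hunk; `encode_uvarint` and `write_delimited` are hypothetical names, not part of this commit:

    def encode_uvarint(value: int) -> bytes:
        # Protobuf-style unsigned varint: 7 bits per byte, MSB = continuation.
        out = b""
        while True:
            byte = value & 0x7F
            value >>= 7
            if value:
                out += bytes([byte | 0x80])
            else:
                return out + bytes([byte])

    async def write_delimited(stream, msg) -> None:
        # Serialize the protobuf message and prefix it with its length,
        # mirroring go-libp2p-pubsub's delimited writes.
        data = msg.SerializeToString()
        await stream.write(encode_uvarint(len(data)) + data)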
@@ -102,18 +117,3 @@ class FloodSub(IPubsubRouter):
             if str(peer_id) not in self.pubsub.peers:
                 continue
             yield peer_id
-
-    async def join(self, topic):
-        """
-        Join notifies the router that we want to receive and
-        forward messages in a topic. It is invoked after the
-        subscription announcement
-        :param topic: topic to join
-        """
-
-    async def leave(self, topic):
-        """
-        Leave notifies the router that we are no longer interested in a topic.
-        It is invoked after the unsubscription announcement.
-        :param topic: topic to leave
-        """
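The hunk above shows only the tail of floodsub's `_get_peers_to_send`: walk each topic's subscriber list and yield only peers we still hold a stream for. A self-contained sketch of that selection logic, with plain dicts standing in for the pubsub state (names here are illustrative, not the library's):

    from typing import Dict, Iterable, Iterator, List

    def peers_to_send(topic_ids: Iterable[str],
                      peer_topics: Dict[str, List[str]],
                      live_peers: Dict[str, object]) -> Iterator[str]:
        # Yield every subscriber of every topic that still has a live stream.
        for topic in topic_ids:
            for peer in peer_topics.get(topic, []):
                if peer in live_peers:
                    yield peer

    # list(peers_to_send(["foobar"], {"foobar": ["QmA", "QmB"]}, {"QmA": 1}))
    # -> ["QmA"]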
@@ -1,10 +1,22 @@
-import random
+import asyncio
+import random
+from typing import (
+    Iterable,
+    List,
+    MutableSet,
+    Sequence,
+)

 from ast import literal_eval

+from libp2p.peer.id import (
+    ID,
+    id_b58_decode,
+)
+
+from .mcache import MessageCache
 from .pb import rpc_pb2
 from .pubsub_router_interface import IPubsubRouter
-from .mcache import MessageCache


 class GossipSub(IPubsubRouter):
@@ -107,70 +119,73 @@ class GossipSub(IPubsubRouter):
         for prune in control_message.prune:
             await self.handle_prune(prune, sender_peer_id)

-    async def publish(self, sender_peer_id, rpc_message):
+    async def publish(self, src: ID, pubsub_msg: rpc_pb2.Message) -> None:
+        # pylint: disable=too-many-locals
         """
         Invoked to forward a new message that has been validated.
         """
+        self.mcache.put(pubsub_msg)

-        packet = rpc_pb2.RPC()
-        packet.ParseFromString(rpc_message)
-        msg_sender = str(sender_peer_id)
+        peers_gen = self._get_peers_to_send(
+            pubsub_msg.topicIDs,
+            src=src,
+            origin=ID(pubsub_msg.from_id),
+        )
+        rpc_msg = rpc_pb2.RPC(
+            publish=[pubsub_msg],
+        )
+        for peer_id in peers_gen:
+            stream = self.pubsub.peers[str(peer_id)]
+            # FIXME: We should add a `WriteMsg` similar to write delimited messages.
+            # Ref: https://github.com/libp2p/go-libp2p-pubsub/blob/master/comm.go#L107
+            # TODO: Go use `sendRPC`, which possibly piggybacks gossip/control messages.
+            await stream.write(rpc_msg.SerializeToString())

-        # Deliver to self if self was origin
-        # Note: handle_talk checks if self is subscribed to topics in message
-        for message in packet.publish:
-            # Add RPC message to cache
-            self.mcache.put(message)
+    def _get_peers_to_send(
+            self,
+            topic_ids: Iterable[str],
+            src: ID,
+            origin: ID) -> Iterable[ID]:
+        """
+        Get the eligible peers to send the data to.
+        :param src: the peer id of the peer who forwards the message to me.
+        :param origin: the peer id of the peer who originally broadcast the message.
+        :return: a generator of the peer ids who we send data to.
+        """
+        to_send: MutableSet[ID] = set()
+        for topic in topic_ids:
+            if topic not in self.pubsub.peer_topics:
+                continue

-            decoded_from_id = message.from_id.decode('utf-8')
-            new_packet = rpc_pb2.RPC()
-            new_packet.publish.extend([message])
-            new_packet_serialized = new_packet.SerializeToString()
+            # floodsub peers
+            for peer_id_str in self.pubsub.peer_topics[topic]:
+                peer_id = id_b58_decode(peer_id_str)
+                # FIXME: `gossipsub.peers_floodsub` can be changed to `gossipsub.peers` in go.
+                #   This will improve the efficiency when searching for a peer's protocol id.
+                if peer_id_str in self.peers_floodsub:
+                    to_send.add(peer_id)

-            # Deliver to self if needed
-            if msg_sender == decoded_from_id and msg_sender == str(self.pubsub.host.get_id()):
-                id_in_seen_msgs = (message.seqno, message.from_id)
+            # gossipsub peers
+            # FIXME: Change `str` to `ID`
+            gossipsub_peers: List[str] = None
+            # TODO: Do we need to check `topic in self.pubsub.my_topics`?
+            if topic in self.mesh:
+                gossipsub_peers = self.mesh[topic]
+            else:
+                # TODO(robzajac): Is topic DEFINITELY supposed to be in fanout if we are not
+                #   subscribed?
+                # I assume there could be short periods between heartbeats where topic may not
+                #   be but we should check that this path gets hit appropriately

-                if id_in_seen_msgs not in self.pubsub.seen_messages:
-                    self.pubsub.seen_messages[id_in_seen_msgs] = 1
-
-                await self.pubsub.handle_talk(message)
-
-            # Deliver to peers
-            for topic in message.topicIDs:
-                # If topic has floodsub peers, deliver to floodsub peers
-                # TODO: This can be done more efficiently. Do it more efficiently.
-                floodsub_peers_in_topic = []
-                if topic in self.pubsub.peer_topics:
-                    for peer in self.pubsub.peer_topics[topic]:
-                        if str(peer) in self.peers_floodsub:
-                            floodsub_peers_in_topic.append(peer)
-
-                await self.deliver_messages_to_peers(floodsub_peers_in_topic, msg_sender,
-                                                     decoded_from_id, new_packet_serialized)
-
-                # If you are subscribed to topic, send to mesh, otherwise send to fanout
-                if topic in self.pubsub.my_topics and topic in self.mesh:
-                    await self.deliver_messages_to_peers(self.mesh[topic], msg_sender,
-                                                         decoded_from_id, new_packet_serialized)
-                else:
-                    # Send to fanout peers
-                    if topic not in self.fanout:
-                        # If no peers in fanout, choose some peers from gossipsub peers in topic
-                        gossipsub_peers_in_topic = [peer for peer in self.pubsub.peer_topics[topic]
-                                                    if peer in self.peers_gossipsub]
-
-                        selected = \
-                            GossipSub.select_from_minus(self.degree, gossipsub_peers_in_topic, [])
-                        self.fanout[topic] = selected
-
-                    # TODO: Is topic DEFINITELY supposed to be in fanout if we are not subscribed?
-                    # I assume there could be short periods between heartbeats where topic may not
-                    # be but we should check that this path gets hit appropriately
-
-                    await self.deliver_messages_to_peers(self.fanout[topic], msg_sender,
-                                                         decoded_from_id, new_packet_serialized)
+                # pylint: disable=len-as-condition
+                if (topic not in self.fanout) or (len(self.fanout[topic]) == 0):
+                    # If no peers in fanout, choose some peers from gossipsub peers in topic.
+                    self.fanout[topic] = self._get_peers_from_minus(topic, self.degree, [])
+                gossipsub_peers = self.fanout[topic]
+            for peer_id_str in gossipsub_peers:
+                to_send.add(id_b58_decode(peer_id_str))
+        # Excludes `src` and `origin`
+        yield from to_send.difference([src, origin])

     async def join(self, topic):
         # Note: the comments here are the near-exact algorithm description from the spec
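Note: the new `_get_peers_to_send` routes through the mesh when we are subscribed to the topic, and otherwise falls back to (and lazily populates) the fanout. A minimal standalone sketch of that decision, with plain dicts standing in for the router state (names illustrative):

    import random
    from typing import Dict, List, Sequence

    def gossip_targets(topic: str,
                       mesh: Dict[str, List[str]],
                       fanout: Dict[str, List[str]],
                       topic_peers: Sequence[str],
                       degree: int) -> List[str]:
        # Subscribed: the mesh for the topic is the send set.
        if topic in mesh:
            return mesh[topic]
        # Not subscribed: reuse the fanout, seeding it on first use
        # with up to `degree` random peers known to be in the topic.
        if not fanout.get(topic):
            fanout[topic] = random.sample(list(topic_peers),
                                          min(degree, len(topic_peers)))
        return fanout[topic]

    fanout: Dict[str, List[str]] = {}
    print(gossip_targets("foobar", {}, fanout, ["QmA", "QmB", "QmC"], 2))
    print(fanout)  # the same selection is reused on the next publish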
@@ -401,6 +416,22 @@ class GossipSub(IPubsubRouter):

         return selection

+    def _get_peers_from_minus(
+            self,
+            topic: str,
+            num_to_select: int,
+            minus: Sequence[ID]) -> List[ID]:
+        gossipsub_peers_in_topic = [
+            peer_str
+            for peer_str in self.pubsub.peer_topics[topic]
+            if peer_str in self.peers_gossipsub
+        ]
+        return self.select_from_minus(
+            num_to_select,
+            gossipsub_peers_in_topic,
+            list(minus),
+        )
+
     # RPC handlers

     async def handle_ihave(self, ihave_msg, sender_peer_id):
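`_get_peers_from_minus` delegates to `select_from_minus`, which this diff does not show. Its contract (pick up to N candidates from a pool, excluding `minus`) suggests something like the sketch below; this is an assumption, not the library's code:

    import random
    from typing import List, Sequence, TypeVar

    T = TypeVar("T")

    def select_from_minus(num_to_select: int,
                          pool: Sequence[T],
                          minus: Sequence[T]) -> List[T]:
        # Candidates are the pool with the excluded items removed.
        candidates = [x for x in pool if x not in minus]
        # If we cannot fill the quota, return every candidate.
        if num_to_select >= len(candidates):
            return candidates
        return random.sample(candidates, num_to_select)

    # select_from_minus(1, ["a", "b", "c"], ["b"]) -> one of ["a", "c"]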
@@ -2,11 +2,14 @@ import asyncio
 import multiaddr
 import uuid

-from utils import message_id_generator, generate_RPC_packet
 from libp2p import new_node
+from libp2p.host.host_interface import IHost
 from libp2p.pubsub.pubsub import Pubsub
 from libp2p.pubsub.floodsub import FloodSub

+from .utils import message_id_generator, generate_RPC_packet
+

 SUPPORTED_PUBSUB_PROTOCOLS = ["/floodsub/1.0.0"]
 CRYPTO_TOPIC = "ethereum"
@@ -17,14 +20,25 @@ CRYPTO_TOPIC = "ethereum"
 # Ex. set,rob,5
 # Determine message type by looking at first item before first comma

-class DummyAccountNode():
+class DummyAccountNode:
     """
     Node which has an internal balance mapping, meant to serve as
     a dummy crypto blockchain. There is no actual blockchain, just a simple
     map indicating how much crypto each user in the mappings holds
     """
+    libp2p_node: IHost
+    pubsub: Pubsub
+    floodsub: FloodSub

-    def __init__(self):
+    def __init__(
+            self,
+            libp2p_node: IHost,
+            pubsub: Pubsub,
+            floodsub: FloodSub):
+        self.libp2p_node = libp2p_node
+        self.pubsub = pubsub
+        self.floodsub = floodsub
         self.balances = {}
         self.next_msg_id_func = message_id_generator(0)
         self.node_id = str(uuid.uuid1())
@@ -38,16 +52,21 @@ class DummyAccountNode():
         We use create as this serves as a factory function and allows us
         to use async await, unlike the init function
         """
-        self = DummyAccountNode()
-
         libp2p_node = await new_node(transport_opt=["/ip4/127.0.0.1/tcp/0"])
         await libp2p_node.get_network().listen(multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/0"))

-        self.libp2p_node = libp2p_node
-
-        self.floodsub = FloodSub(SUPPORTED_PUBSUB_PROTOCOLS)
-        self.pubsub = Pubsub(self.libp2p_node, self.floodsub, "a")
-        return self
+        floodsub = FloodSub(SUPPORTED_PUBSUB_PROTOCOLS)
+        pubsub = Pubsub(
+            libp2p_node,
+            floodsub,
+            "a",
+        )
+        return cls(
+            libp2p_node=libp2p_node,
+            pubsub=pubsub,
+            floodsub=floodsub,
+        )

     async def handle_incoming_msgs(self):
         """
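The rewritten `create` is the standard async-factory pattern: since `__init__` cannot await, acquire the async resources first and hand fully-built collaborators to the constructor, instead of mutating a blank instance. A minimal sketch of the pattern with a stand-in async resource:

    import asyncio

    class Node:
        def __init__(self, conn):
            # By the time __init__ runs, every awaited dependency is ready.
            self.conn = conn

        @classmethod
        async def create(cls) -> "Node":
            # Do the awaits here, then call the constructor.
            conn = await open_connection()
            return cls(conn=conn)

    async def open_connection():
        await asyncio.sleep(0)  # stand-in for real async setup
        return object()

    # node = asyncio.run(Node.create())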
@@ -78,10 +97,8 @@ class DummyAccountNode():
         :param dest_user: user to send crypto to
         :param amount: amount of crypto to send
         """
-        my_id = str(self.libp2p_node.get_id())
         msg_contents = "send," + source_user + "," + dest_user + "," + str(amount)
-        packet = generate_RPC_packet(my_id, [CRYPTO_TOPIC], msg_contents, self.next_msg_id_func())
-        await self.floodsub.publish(my_id, packet.SerializeToString())
+        await self.pubsub.publish(CRYPTO_TOPIC, msg_contents.encode())

     async def publish_set_crypto(self, user, amount):
         """
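The tests now call `pubsub.publish(topic, data)` instead of hand-assembling an RPC packet with `generate_RPC_packet`. Based only on the `rpc_pb2.Message` fields this diff touches (`from_id`, `seqno`, `topicIDs`, `data`), a plausible shape for such a method is sketched below; this is a hedged guess at the moving parts, not the library's actual `Pubsub.publish`, and `next_seqno` is an assumed helper:

    # Hedged sketch, not the real implementation.
    async def publish(pubsub, topic_id: str, data: bytes) -> None:
        msg = rpc_pb2.Message(
            data=data,
            topicIDs=[topic_id],
            from_id=pubsub.host.get_id().to_bytes(),  # assumed ID-to-bytes accessor
            seqno=pubsub.next_seqno(),                # assumed monotonic counter
        )
        # Deliver locally if subscribed, then let the router fan it out.
        await pubsub.router.publish(pubsub.host.get_id(), msg)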
@@ -89,18 +106,15 @@ class DummyAccountNode():
         :param user: user to set crypto for
         :param amount: amount of crypto
         """
-        my_id = str(self.libp2p_node.get_id())
         msg_contents = "set," + user + "," + str(amount)
-        packet = generate_RPC_packet(my_id, [CRYPTO_TOPIC], msg_contents, self.next_msg_id_func())
-
-        await self.floodsub.publish(my_id, packet.SerializeToString())
+        await self.pubsub.publish(CRYPTO_TOPIC, msg_contents.encode())

     def handle_send_crypto(self, source_user, dest_user, amount):
         """
         Handle incoming send_crypto message
         :param source_user: user to send crypto from
         :param dest_user: user to send crypto to
         :param amount: amount of crypto to send
         """
         if source_user in self.balances:
             self.balances[source_user] -= amount
@@ -1,35 +1,37 @@
 import asyncio
+from threading import Thread

 import multiaddr

 import pytest

-from threading import Thread
-from tests.utils import cleanup
-from libp2p import new_node
-from libp2p.peer.peerinfo import info_from_p2p_addr
-from libp2p.pubsub.pubsub import Pubsub
-from libp2p.pubsub.floodsub import FloodSub
-from dummy_account_node import DummyAccountNode
+from tests.utils import (
+    cleanup,
+    connect,
+)
+from .dummy_account_node import DummyAccountNode

 # pylint: disable=too-many-locals

-async def connect(node1, node2):
-    # node1 connects to node2
-    addr = node2.get_addrs()[0]
-    info = info_from_p2p_addr(addr)
-    await node1.connect(info)

 def create_setup_in_new_thread_func(dummy_node):
     def setup_in_new_thread():
         asyncio.ensure_future(dummy_node.setup_crypto_networking())
     return setup_in_new_thread


 async def perform_test(num_nodes, adjacency_map, action_func, assertion_func):
     """
     Helper function to allow for easy construction of custom tests for dummy account nodes
     in various network topologies
     :param num_nodes: number of nodes in the test
     :param adjacency_map: adjacency map defining each node and its list of neighbors
     :param action_func: function to execute that includes actions by the nodes,
                         such as send crypto and set crypto
     :param assertion_func: assertions for testing the results of the actions are correct
     """
@@ -73,6 +75,7 @@ async def perform_test(num_nodes, adjacency_map, action_func, assertion_func):
     # Success, terminate pending tasks.
     await cleanup()

+
 @pytest.mark.asyncio
 async def test_simple_two_nodes():
     num_nodes = 2
@@ -86,6 +89,7 @@ async def test_simple_two_nodes():

     await perform_test(num_nodes, adj_map, action_func, assertion_func)

+
 @pytest.mark.asyncio
 async def test_simple_three_nodes_line_topography():
     num_nodes = 3
@@ -99,6 +103,7 @@ async def test_simple_three_nodes_line_topography():

     await perform_test(num_nodes, adj_map, action_func, assertion_func)

+
 @pytest.mark.asyncio
 async def test_simple_three_nodes_triangle_topography():
     num_nodes = 3
@@ -112,6 +117,7 @@ async def test_simple_three_nodes_triangle_topography():

     await perform_test(num_nodes, adj_map, action_func, assertion_func)

+
 @pytest.mark.asyncio
 async def test_simple_seven_nodes_tree_topography():
     num_nodes = 7
@@ -125,6 +131,7 @@ async def test_simple_seven_nodes_tree_topography():

     await perform_test(num_nodes, adj_map, action_func, assertion_func)

+
 @pytest.mark.asyncio
 async def test_set_then_send_from_root_seven_nodes_tree_topography():
     num_nodes = 7
@@ -1,11 +1,16 @@
 import asyncio
-import pytest
 import random

-from utils import message_id_generator, generate_RPC_packet, \
+import pytest
+
+from tests.utils import (
+    cleanup,
+    connect,
+)
+
+from .utils import message_id_generator, generate_RPC_packet, \
     create_libp2p_hosts, create_pubsub_and_gossipsub_instances, sparse_connect, dense_connect, \
-    connect, one_to_all_connect
-from tests.utils import cleanup
+    one_to_all_connect

 SUPPORTED_PROTOCOLS = ["/gossipsub/1.0.0"]
@@ -41,13 +46,8 @@ async def test_join():

     # Central node publish to the topic so that this topic
     # is added to central node's fanout
-    next_msg_id_func = message_id_generator(0)
-    msg_content = ""
-    host_id = str(libp2p_hosts[central_node_index].get_id())
-    # Generate message packet
-    packet = generate_RPC_packet(host_id, [topic], msg_content, next_msg_id_func())
-    # publish from the randomly chosen host
-    await gossipsubs[central_node_index].publish(host_id, packet.SerializeToString())
+    await pubsubs[central_node_index].publish(topic, b"")

     # Check that the gossipsub of central node has fanout for the topic
     assert topic in gossipsubs[central_node_index].fanout
@@ -86,6 +86,8 @@ async def test_leave():
     gossipsub = gossipsubs[0]
     topic = "test_leave"

+    assert topic not in gossipsub.mesh
+
     await gossipsub.join(topic)
     assert topic in gossipsub.mesh

@@ -205,14 +207,12 @@ async def test_handle_prune():
 @pytest.mark.asyncio
 async def test_dense():
     # Create libp2p hosts
-    next_msg_id_func = message_id_generator(0)
-
     num_hosts = 10
     num_msgs = 5
     libp2p_hosts = await create_libp2p_hosts(num_hosts)

     # Create pubsub, gossipsub instances
-    pubsubs, gossipsubs = create_pubsub_and_gossipsub_instances(libp2p_hosts, \
+    pubsubs, _ = create_pubsub_and_gossipsub_instances(libp2p_hosts, \
                                                        SUPPORTED_PROTOCOLS, \
                                                        10, 9, 11, 30, 3, 5, 0.5)
@@ -231,41 +231,35 @@ async def test_dense():
     await asyncio.sleep(2)

     for i in range(num_msgs):
-        msg_content = "foo " + str(i)
+        msg_content = b"foo " + i.to_bytes(1, 'big')

         # randomly pick a message origin
         origin_idx = random.randint(0, num_hosts - 1)
         origin_host = libp2p_hosts[origin_idx]
-        host_id = str(origin_host.get_id())
-
-        # Generate message packet
-        packet = generate_RPC_packet(host_id, ["foobar"], msg_content, next_msg_id_func())

         # publish from the randomly chosen host
-        await gossipsubs[origin_idx].publish(host_id, packet.SerializeToString())
+        await pubsubs[origin_idx].publish("foobar", msg_content)

         await asyncio.sleep(0.5)
         # Assert that all blocking queues receive the message
         for queue in queues:
             msg = await queue.get()
-            assert msg.data == packet.publish[0].data
+            assert msg.data == msg_content
     await cleanup()


 @pytest.mark.asyncio
 async def test_fanout():
     # Create libp2p hosts
-    next_msg_id_func = message_id_generator(0)
-
     num_hosts = 10
     num_msgs = 5
     libp2p_hosts = await create_libp2p_hosts(num_hosts)

     # Create pubsub, gossipsub instances
-    pubsubs, gossipsubs = create_pubsub_and_gossipsub_instances(libp2p_hosts, \
+    pubsubs, _ = create_pubsub_and_gossipsub_instances(libp2p_hosts, \
                                                        SUPPORTED_PROTOCOLS, \
                                                        10, 9, 11, 30, 3, 5, 0.5)

-    # All pubsub subscribe to foobar
+    # All pubsub subscribe to foobar except for `pubsubs[0]`
     queues = []
     for i in range(1, len(pubsubs)):
         q = await pubsubs[i].subscribe("foobar")
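Aside: the payloads switch from `str` to `bytes` because message data now travels as raw bytes end to end; `i.to_bytes(1, 'big')` gives each message a distinct one-byte suffix (valid only for 0 <= i <= 255):

    for i in range(3):
        print(b"foo " + i.to_bytes(1, 'big'))
    # b'foo \x00'  b'foo \x01'  b'foo \x02'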
@@ -279,71 +273,61 @@ async def test_fanout():
     # Wait 2 seconds for heartbeat to allow mesh to connect
     await asyncio.sleep(2)

+    topic = "foobar"
     # Send messages with origin not subscribed
     for i in range(num_msgs):
-        msg_content = "foo " + str(i)
+        msg_content = b"foo " + i.to_bytes(1, "big")

         # Pick the message origin to the node that is not subscribed to 'foobar'
         origin_idx = 0
-        origin_host = libp2p_hosts[origin_idx]
-        host_id = str(origin_host.get_id())
-
-        # Generate message packet
-        packet = generate_RPC_packet(host_id, ["foobar"], msg_content, next_msg_id_func())

         # publish from the randomly chosen host
-        await gossipsubs[origin_idx].publish(host_id, packet.SerializeToString())
+        await pubsubs[origin_idx].publish(topic, msg_content)

         await asyncio.sleep(0.5)
         # Assert that all blocking queues receive the message
         for queue in queues:
             msg = await queue.get()
-            assert msg.SerializeToString() == packet.publish[0].SerializeToString()
+            assert msg.data == msg_content

     # Subscribe message origin
-    queues.append(await pubsubs[0].subscribe("foobar"))
+    queues.insert(0, await pubsubs[0].subscribe(topic))

     # Send messages again
     for i in range(num_msgs):
-        msg_content = "foo " + str(i)
+        msg_content = b"bar " + i.to_bytes(1, 'big')

         # Pick the message origin to the node that is not subscribed to 'foobar'
         origin_idx = 0
-        origin_host = libp2p_hosts[origin_idx]
-        host_id = str(origin_host.get_id())
-
-        # Generate message packet
-        packet = generate_RPC_packet(host_id, ["foobar"], msg_content, next_msg_id_func())

         # publish from the randomly chosen host
-        await gossipsubs[origin_idx].publish(host_id, packet.SerializeToString())
+        await pubsubs[origin_idx].publish(topic, msg_content)

         await asyncio.sleep(0.5)
         # Assert that all blocking queues receive the message
         for queue in queues:
             msg = await queue.get()
-            assert msg.SerializeToString() == packet.publish[0].SerializeToString()
+            assert msg.data == msg_content

     await cleanup()


 @pytest.mark.asyncio
 async def test_fanout_maintenance():
     # Create libp2p hosts
-    next_msg_id_func = message_id_generator(0)
-
     num_hosts = 10
     num_msgs = 5
     libp2p_hosts = await create_libp2p_hosts(num_hosts)

     # Create pubsub, gossipsub instances
-    pubsubs, gossipsubs = create_pubsub_and_gossipsub_instances(libp2p_hosts, \
+    pubsubs, _ = create_pubsub_and_gossipsub_instances(libp2p_hosts, \
                                                        SUPPORTED_PROTOCOLS, \
                                                        10, 9, 11, 30, 3, 5, 0.5)

     # All pubsub subscribe to foobar
     queues = []
+    topic = "foobar"
     for i in range(1, len(pubsubs)):
-        q = await pubsubs[i].subscribe("foobar")
+        q = await pubsubs[i].subscribe(topic)

         # Add each blocking queue to an array of blocking queues
         queues.append(q)
@@ -356,27 +340,22 @@ async def test_fanout_maintenance():

     # Send messages with origin not subscribed
     for i in range(num_msgs):
-        msg_content = "foo " + str(i)
+        msg_content = b"foo " + i.to_bytes(1, 'big')

         # Pick the message origin to the node that is not subscribed to 'foobar'
         origin_idx = 0
-        origin_host = libp2p_hosts[origin_idx]
-        host_id = str(origin_host.get_id())
-
-        # Generate message packet
-        packet = generate_RPC_packet(host_id, ["foobar"], msg_content, next_msg_id_func())

         # publish from the randomly chosen host
-        await gossipsubs[origin_idx].publish(host_id, packet.SerializeToString())
+        await pubsubs[origin_idx].publish(topic, msg_content)

         await asyncio.sleep(0.5)
         # Assert that all blocking queues receive the message
         for queue in queues:
             msg = await queue.get()
-            assert msg.SerializeToString() == packet.publish[0].SerializeToString()
+            assert msg.data == msg_content

     for sub in pubsubs:
-        await sub.unsubscribe('foobar')
+        await sub.unsubscribe(topic)

     queues = []
@@ -384,7 +363,7 @@ async def test_fanout_maintenance():

     # Resub and repeat
     for i in range(1, len(pubsubs)):
-        q = await pubsubs[i].subscribe("foobar")
+        q = await pubsubs[i].subscribe(topic)

         # Add each blocking queue to an array of blocking queues
         queues.append(q)
@@ -393,65 +372,61 @@ async def test_fanout_maintenance():

     # Check messages can still be sent
     for i in range(num_msgs):
-        msg_content = "foo " + str(i)
+        msg_content = b"bar " + i.to_bytes(1, 'big')

         # Pick the message origin to the node that is not subscribed to 'foobar'
         origin_idx = 0
-        origin_host = libp2p_hosts[origin_idx]
-        host_id = str(origin_host.get_id())
-
-        # Generate message packet
-        packet = generate_RPC_packet(host_id, ["foobar"], msg_content, next_msg_id_func())

         # publish from the randomly chosen host
-        await gossipsubs[origin_idx].publish(host_id, packet.SerializeToString())
+        await pubsubs[origin_idx].publish(topic, msg_content)

         await asyncio.sleep(0.5)
         # Assert that all blocking queues receive the message
         for queue in queues:
             msg = await queue.get()
-            assert msg.SerializeToString() == packet.publish[0].SerializeToString()
+            assert msg.data == msg_content

     await cleanup()


 @pytest.mark.asyncio
 async def test_gossip_propagation():
     # Create libp2p hosts
-    next_msg_id_func = message_id_generator(0)
-
     num_hosts = 2
-    libp2p_hosts = await create_libp2p_hosts(num_hosts)
+    hosts = await create_libp2p_hosts(num_hosts)

     # Create pubsub, gossipsub instances
-    pubsubs, gossipsubs = create_pubsub_and_gossipsub_instances(libp2p_hosts, \
-                                                                SUPPORTED_PROTOCOLS, \
-                                                                1, 0, 2, 30, 50, 100, 0.5)
-    node1, node2 = libp2p_hosts[0], libp2p_hosts[1]
-    sub1, sub2 = pubsubs[0], pubsubs[1]
-    gsub1, gsub2 = gossipsubs[0], gossipsubs[1]
+    pubsubs, _ = create_pubsub_and_gossipsub_instances(
+        hosts,
+        SUPPORTED_PROTOCOLS,
+        1,
+        0,
+        2,
+        30,
+        50,
+        100,
+        0.5,
+    )

-    node1_queue = await sub1.subscribe('foo')
+    topic = "foo"
+    await pubsubs[0].subscribe(topic)

-    # node 1 publish to topic
-    msg_content = 'foo_msg'
-    node1_id = str(node1.get_id())
-
-    # Generate message packet
-    packet = generate_RPC_packet(node1_id, ["foo"], msg_content, next_msg_id_func())
+    # node 0 publish to topic
+    msg_content = b'foo_msg'

     # publish from the randomly chosen host
-    await gsub1.publish(node1_id, packet.SerializeToString())
+    await pubsubs[0].publish(topic, msg_content)

-    # now node 2 subscribes
-    node2_queue = await sub2.subscribe('foo')
+    # now node 1 subscribes
+    queue_1 = await pubsubs[1].subscribe(topic)

-    await connect(node2, node1)
+    await connect(hosts[0], hosts[1])

     # wait for gossip heartbeat
     await asyncio.sleep(2)

     # should be able to read message
-    msg = await node2_queue.get()
-    assert msg.SerializeToString() == packet.publish[0].SerializeToString()
+    msg = await queue_1.get()
+    assert msg.data == msg_content

     await cleanup()
@@ -8,18 +8,17 @@ from libp2p.pubsub.gossipsub import GossipSub
 from libp2p.pubsub.floodsub import FloodSub
 from libp2p.pubsub.pb import rpc_pb2
 from libp2p.pubsub.pubsub import Pubsub
-from utils import message_id_generator, generate_RPC_packet

 from tests.utils import cleanup

-# pylint: disable=too-many-locals
+from .utils import (
+    connect,
+    message_id_generator,
+    generate_RPC_packet,
+)

-async def connect(node1, node2):
-    """
-    Connect node1 to node2
-    """
-    addr = node2.get_addrs()[0]
-    info = info_from_p2p_addr(addr)
-    await node1.connect(info)
+# pylint: disable=too-many-locals

 @pytest.mark.asyncio
 async def test_init():
@@ -37,11 +36,12 @@ async def test_init():

     await cleanup()

+
 async def perform_test_from_obj(obj):
     """
     Perform a floodsub test from a test obj.
     test obj are composed as follows:

     {
       "supported_protocols": ["supported/protocol/1.0.0",...],
       "adj_list": {
@@ -95,7 +95,7 @@ async def perform_test_from_obj(obj):
             if neighbor_id not in node_map:
                 neighbor_node = await new_node(transport_opt=["/ip4/127.0.0.1/tcp/0"])
                 await neighbor_node.get_network().listen(multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/0"))
-
+
                 node_map[neighbor_id] = neighbor_node

                 gossipsub = GossipSub(supported_protocols, 3, 2, 4, 30)
@@ -104,7 +104,7 @@ async def perform_test_from_obj(obj):
             pubsub_map[neighbor_id] = pubsub

             # Connect node and neighbor
-            tasks_connect.append(asyncio.ensure_future(connect(node_map[start_node_id], node_map[neighbor_id])))
+            tasks_connect.append(connect(node_map[start_node_id], node_map[neighbor_id]))
     tasks_connect.append(asyncio.sleep(2))
     await asyncio.gather(*tasks_connect)
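Dropping `asyncio.ensure_future` here works because `asyncio.gather` wraps bare coroutines in tasks itself, so nothing starts running until the `gather` call. A minimal demonstration:

    import asyncio

    async def greet(name: str) -> str:
        await asyncio.sleep(0)
        return "hello " + name

    async def main() -> None:
        # gather accepts plain coroutine objects; no ensure_future needed.
        print(await asyncio.gather(greet("a"), greet("b")))

    asyncio.run(main())  # ['hello a', 'hello b']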
@@ -130,7 +130,7 @@ async def perform_test_from_obj(obj):
             # Store queue in topic-queue map for node
             queues_map[node_id][topic] = q
             """
-            tasks_topic.append(asyncio.ensure_future(pubsub_map[node_id].subscribe(topic)))
+            tasks_topic.append(pubsub_map[node_id].subscribe(topic))
             tasks_topic_data.append((node_id, topic))
     tasks_topic.append(asyncio.sleep(2))
@@ -152,29 +152,27 @@ async def perform_test_from_obj(obj):
     topics_in_msgs_ordered = []
     messages = obj["messages"]
     tasks_publish = []
-    next_msg_id_func = message_id_generator(0)

     for msg in messages:
         topics = msg["topics"]
         data = msg["data"]
         node_id = msg["node_id"]

-        # Get actual id for sender node (not the id from the test obj)
-        actual_node_id = str(node_map[node_id].get_id())
-
-        # Create correctly formatted message
-        msg_talk = generate_RPC_packet(actual_node_id, topics, data, next_msg_id_func())
-
         # Publish message
-        tasks_publish.append(asyncio.ensure_future(gossipsub_map[node_id].publish(\
-            actual_node_id, msg_talk.SerializeToString())))
+        # FIXME: This should be one RPC packet with several topics
+        for topic in topics:
+            tasks_publish.append(
+                pubsub_map[node_id].publish(
+                    topic,
+                    data,
+                )
+            )

         # For each topic in topics, add topic, msg_talk tuple to ordered test list
         # TODO: Update message sender to be correct message sender before
         # adding msg_talk to this list
         for topic in topics:
-            topics_in_msgs_ordered.append((topic, msg_talk))
+            topics_in_msgs_ordered.append((topic, data))

     # Allow time for publishing before continuing
     # await asyncio.sleep(0.4)
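On the FIXME above: `rpc_pb2.Message.topicIDs` is a repeated field, so the per-topic loop could in principle collapse into one message tagged with every topic. A hedged sketch of building such a message (the constructor-kwarg style is standard protobuf; a publish API accepting it is not part of this commit):

    from libp2p.pubsub.pb import rpc_pb2

    def build_multi_topic_msg(from_id: bytes, seqno: bytes,
                              topics, data: bytes) -> rpc_pb2.Message:
        # One message carrying all topics at once, instead of one copy per topic.
        return rpc_pb2.Message(
            from_id=from_id,
            seqno=seqno,
            data=data,
            topicIDs=list(topics),
        )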
@@ -183,14 +181,12 @@ async def perform_test_from_obj(obj):

     # Step 4) Check that all messages were received correctly.
     # TODO: Check message sender too
-    for i in range(len(topics_in_msgs_ordered)):
-        topic, actual_msg = topics_in_msgs_ordered[i]
-
+    for topic, data in topics_in_msgs_ordered:
         # Look at each node in each topic
         for node_id in topic_map[topic]:
             # Get message from subscription queue
             msg_on_node = await queues_map[node_id][topic].get()
-            assert actual_msg.publish[0].SerializeToString() == msg_on_node.SerializeToString()
+            assert msg_on_node.data == data

     # Success, terminate pending tasks.
     await cleanup()
@@ -14,6 +14,8 @@ from libp2p.peer.id import ID
 from libp2p.pubsub.pubsub import Pubsub
 from libp2p.pubsub.gossipsub import GossipSub

+from tests.utils import connect
+

 def message_id_generator(start_val):
     """
@@ -22,6 +24,7 @@ def message_id_generator(start_val):
     :return: message id
     """
     val = start_val
+
     def generator():
         # Allow manipulation of val within closure
         nonlocal val
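For context, `message_id_generator` is a closure over `val`; the added blank line just separates the state from the inner function. A sketch consistent with the lines shown (the exact return encoding is assumed):

    def message_id_generator(start_val):
        # Return a function producing a fresh, increasing message id per call.
        val = start_val

        def generator():
            # Allow manipulation of val within closure
            nonlocal val
            val += 1
            return val  # the real helper may encode this, e.g. as bytes

        return generator

    next_msg_id = message_id_generator(0)
    # next_msg_id() -> 1, next_msg_id() -> 2, ...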
@@ -105,6 +108,10 @@ def create_pubsub_and_gossipsub_instances(libp2p_hosts, supported_protocols, deg

     return pubsubs, gossipsubs


+# FIXME: There is no difference between `sparse_connect` and `dense_connect`,
+#   before `connect_some` is fixed.
+
+
 async def sparse_connect(hosts):
     await connect_some(hosts, 3)
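One way to make `connect_some` honor `degree` (resolving the FIXMEs above and below, so sparse and dense topologies actually differ) would be to cap each host's outgoing connections; a sketch assuming the `connect(host_a, host_b)` helper imported from tests.utils:

    async def connect_some(hosts, degree):
        # Connect each host to at most `degree` of the hosts after it.
        for i, host in enumerate(hosts):
            for host2 in hosts[i + 1:i + 1 + degree]:
                await connect(host, host2)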
@@ -113,6 +120,7 @@ async def dense_connect(hosts):
     await connect_some(hosts, 10)


+# FIXME: `degree` is not used at all
 async def connect_some(hosts, degree):
     for i, host in enumerate(hosts):
         for j, host2 in enumerate(hosts):
@@ -135,6 +143,7 @@ async def connect_some(hosts, degree):

     # j += 1


+
 async def one_to_all_connect(hosts, central_host_index):
     for i, host in enumerate(hosts):
         if i != central_host_index: