Fix several tests

mhchia 2019-07-26 18:35:25 +08:00
parent 035d08b8bd
commit 65aedcb25a
No known key found for this signature in database
GPG Key ID: 389EFBEA1362589A
7 changed files with 245 additions and 213 deletions

View File

@@ -80,6 +80,21 @@ class FloodSub(IPubsubRouter):
         # Ref: https://github.com/libp2p/go-libp2p-pubsub/blob/master/comm.go#L107
         await stream.write(rpc_msg.SerializeToString())
 
+    async def join(self, topic):
+        """
+        Join notifies the router that we want to receive and
+        forward messages in a topic. It is invoked after the
+        subscription announcement
+        :param topic: topic to join
+        """
+
+    async def leave(self, topic):
+        """
+        Leave notifies the router that we are no longer interested in a topic.
+        It is invoked after the unsubscription announcement.
+        :param topic: topic to leave
+        """
+
     def _get_peers_to_send(
             self,
             topic_ids: Iterable[str],
@@ -102,18 +117,3 @@ class FloodSub(IPubsubRouter):
             if str(peer_id) not in self.pubsub.peers:
                 continue
             yield peer_id
-
-    async def join(self, topic):
-        """
-        Join notifies the router that we want to receive and
-        forward messages in a topic. It is invoked after the
-        subscription announcement
-        :param topic: topic to join
-        """
-
-    async def leave(self, topic):
-        """
-        Leave notifies the router that we are no longer interested in a topic.
-        It is invoked after the unsubscription announcement.
-        :param topic: topic to leave
-        """

View File

@@ -1,10 +1,22 @@
-import random
 import asyncio
+import random
+from typing import (
+    Iterable,
+    List,
+    MutableSet,
+    Sequence,
+)
 from ast import literal_eval
 
+from libp2p.peer.id import (
+    ID,
+    id_b58_decode,
+)
+
+from .mcache import MessageCache
 from .pb import rpc_pb2
 from .pubsub_router_interface import IPubsubRouter
-from .mcache import MessageCache
 
 
 class GossipSub(IPubsubRouter):
@@ -107,70 +119,73 @@ class GossipSub(IPubsubRouter):
         for prune in control_message.prune:
             await self.handle_prune(prune, sender_peer_id)
 
-    async def publish(self, sender_peer_id, rpc_message):
+    async def publish(self, src: ID, pubsub_msg: rpc_pb2.Message) -> None:
         # pylint: disable=too-many-locals
         """
         Invoked to forward a new message that has been validated.
         """
+        self.mcache.put(pubsub_msg)
 
-        packet = rpc_pb2.RPC()
-        packet.ParseFromString(rpc_message)
-        msg_sender = str(sender_peer_id)
-
-        # Deliver to self if self was origin
-        # Note: handle_talk checks if self is subscribed to topics in message
-        for message in packet.publish:
-            # Add RPC message to cache
-            self.mcache.put(message)
-
-            decoded_from_id = message.from_id.decode('utf-8')
-            new_packet = rpc_pb2.RPC()
-            new_packet.publish.extend([message])
-            new_packet_serialized = new_packet.SerializeToString()
-
-            # Deliver to self if needed
-            if msg_sender == decoded_from_id and msg_sender == str(self.pubsub.host.get_id()):
-                id_in_seen_msgs = (message.seqno, message.from_id)
-
-                if id_in_seen_msgs not in self.pubsub.seen_messages:
-                    self.pubsub.seen_messages[id_in_seen_msgs] = 1
-
-                await self.pubsub.handle_talk(message)
-
-            # Deliver to peers
-            for topic in message.topicIDs:
-                # If topic has floodsub peers, deliver to floodsub peers
-                # TODO: This can be done more efficiently. Do it more efficiently.
-                floodsub_peers_in_topic = []
-                if topic in self.pubsub.peer_topics:
-                    for peer in self.pubsub.peer_topics[topic]:
-                        if str(peer) in self.peers_floodsub:
-                            floodsub_peers_in_topic.append(peer)
-
-                await self.deliver_messages_to_peers(floodsub_peers_in_topic, msg_sender,
-                                                     decoded_from_id, new_packet_serialized)
-
-                # If you are subscribed to topic, send to mesh, otherwise send to fanout
-                if topic in self.pubsub.my_topics and topic in self.mesh:
-                    await self.deliver_messages_to_peers(self.mesh[topic], msg_sender,
-                                                         decoded_from_id, new_packet_serialized)
-                else:
-                    # Send to fanout peers
-                    if topic not in self.fanout:
-                        # If no peers in fanout, choose some peers from gossipsub peers in topic
-                        gossipsub_peers_in_topic = [peer for peer in self.pubsub.peer_topics[topic]
-                                                    if peer in self.peers_gossipsub]
-                        selected = \
-                            GossipSub.select_from_minus(self.degree, gossipsub_peers_in_topic, [])
-                        self.fanout[topic] = selected
-
-                    # TODO: Is topic DEFINITELY supposed to be in fanout if we are not subscribed?
-                    # I assume there could be short periods between heartbeats where topic may not
-                    # be but we should check that this path gets hit appropriately
-                    await self.deliver_messages_to_peers(self.fanout[topic], msg_sender,
-                                                         decoded_from_id, new_packet_serialized)
+        peers_gen = self._get_peers_to_send(
+            pubsub_msg.topicIDs,
+            src=src,
+            origin=ID(pubsub_msg.from_id),
+        )
+        rpc_msg = rpc_pb2.RPC(
+            publish=[pubsub_msg],
+        )
+        for peer_id in peers_gen:
+            stream = self.pubsub.peers[str(peer_id)]
+            # FIXME: We should add a `WriteMsg` similar to write delimited messages.
+            # Ref: https://github.com/libp2p/go-libp2p-pubsub/blob/master/comm.go#L107
+            # TODO: Go use `sendRPC`, which possibly piggybacks gossip/control messages.
+            await stream.write(rpc_msg.SerializeToString())
+
+    def _get_peers_to_send(
+            self,
+            topic_ids: Iterable[str],
+            src: ID,
+            origin: ID) -> Iterable[ID]:
+        """
+        Get the eligible peers to send the data to.
+        :param src: the peer id of the peer who forwards the message to me.
+        :param origin: the peer id of the peer who originally broadcast the message.
+        :return: a generator of the peer ids who we send data to.
+        """
+        to_send: MutableSet[ID] = set()
+        for topic in topic_ids:
+            if topic not in self.pubsub.peer_topics:
+                continue
+
+            # floodsub peers
+            for peer_id_str in self.pubsub.peer_topics[topic]:
+                peer_id = id_b58_decode(peer_id_str)
+                # FIXME: `gossipsub.peers_floodsub` can be changed to `gossipsub.peers` in go.
+                # This will improve the efficiency when searching for a peer's protocol id.
+                if peer_id_str in self.peers_floodsub:
+                    to_send.add(peer_id)
+
+            # gossipsub peers
+            # FIXME: Change `str` to `ID`
+            gossipsub_peers: List[str] = None
+            # TODO: Do we need to check `topic in self.pubsub.my_topics`?
+            if topic in self.mesh:
+                gossipsub_peers = self.mesh[topic]
+            else:
+                # TODO(robzajac): Is topic DEFINITELY supposed to be in fanout if we are not
+                #   subscribed?
+                # I assume there could be short periods between heartbeats where topic may not
+                # be but we should check that this path gets hit appropriately
+                # pylint: disable=len-as-condition
+                if (topic not in self.fanout) or (len(self.fanout[topic]) == 0):
+                    # If no peers in fanout, choose some peers from gossipsub peers in topic.
+                    self.fanout[topic] = self._get_peers_from_minus(topic, self.degree, [])
+                gossipsub_peers = self.fanout[topic]
+            for peer_id_str in gossipsub_peers:
+                to_send.add(id_b58_decode(peer_id_str))
+
+        # Excludes `src` and `origin`
+        yield from to_send.difference([src, origin])
 
     async def join(self, topic):
         # Note: the comments here are the near-exact algorithm description from the spec
@@ -401,6 +416,22 @@ class GossipSub(IPubsubRouter):
 
         return selection
 
+    def _get_peers_from_minus(
+            self,
+            topic: str,
+            num_to_select: int,
+            minus: Sequence[ID]) -> List[ID]:
+        gossipsub_peers_in_topic = [
+            peer_str
+            for peer_str in self.pubsub.peer_topics[topic]
+            if peer_str in self.peers_gossipsub
+        ]
+        return self.select_from_minus(
+            num_to_select,
+            gossipsub_peers_in_topic,
+            list(minus),
+        )
+
     # RPC handlers
 
     async def handle_ihave(self, ihave_msg, sender_peer_id):
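
The exclusion rule at the end of `_get_peers_to_send` is easy to check in isolation; a minimal self-contained sketch (illustrative names, not part of the diff):

    def peers_to_send(candidates, src, origin):
        # Union of floodsub/mesh/fanout candidates, minus the peer that
        # forwarded the message to us and the peer that originated it.
        return set(candidates).difference([src, origin])

    assert peers_to_send({"A", "B", "C"}, src="B", origin="C") == {"A"}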

View File

@@ -2,11 +2,14 @@ import asyncio
 import multiaddr
 import uuid
 
-from utils import message_id_generator, generate_RPC_packet
 from libp2p import new_node
+from libp2p.host.host_interface import IHost
 from libp2p.pubsub.pubsub import Pubsub
 from libp2p.pubsub.floodsub import FloodSub
 
+from .utils import message_id_generator, generate_RPC_packet
+
 SUPPORTED_PUBSUB_PROTOCOLS = ["/floodsub/1.0.0"]
 CRYPTO_TOPIC = "ethereum"
@@ -17,14 +20,25 @@ CRYPTO_TOPIC = "ethereum"
 # Ex. set,rob,5
 # Determine message type by looking at first item before first comma
 
-class DummyAccountNode():
+class DummyAccountNode:
     """
     Node which has an internal balance mapping, meant to serve as
     a dummy crypto blockchain. There is no actual blockchain, just a simple
     map indicating how much crypto each user in the mappings holds
     """
+    libp2p_node: IHost
+    pubsub: Pubsub
+    floodsub: FloodSub
 
-    def __init__(self):
+    def __init__(
+            self,
+            libp2p_node: IHost,
+            pubsub: Pubsub,
+            floodsub: FloodSub):
+        self.libp2p_node = libp2p_node
+        self.pubsub = pubsub
+        self.floodsub = floodsub
         self.balances = {}
         self.next_msg_id_func = message_id_generator(0)
         self.node_id = str(uuid.uuid1())
@@ -38,16 +52,21 @@ class DummyAccountNode:
         We use create as this serves as a factory function and allows us
         to use async await, unlike the init function
         """
-        self = DummyAccountNode()
-
         libp2p_node = await new_node(transport_opt=["/ip4/127.0.0.1/tcp/0"])
         await libp2p_node.get_network().listen(multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/0"))
-        self.libp2p_node = libp2p_node
 
-        self.floodsub = FloodSub(SUPPORTED_PUBSUB_PROTOCOLS)
-        self.pubsub = Pubsub(self.libp2p_node, self.floodsub, "a")
-        return self
+        floodsub = FloodSub(SUPPORTED_PUBSUB_PROTOCOLS)
+        pubsub = Pubsub(
+            libp2p_node,
+            floodsub,
+            "a",
+        )
+        return cls(
+            libp2p_node=libp2p_node,
+            pubsub=pubsub,
+            floodsub=floodsub,
+        )
 
     async def handle_incoming_msgs(self):
         """
@@ -78,10 +97,8 @@ class DummyAccountNode:
         :param dest_user: user to send crypto to
         :param amount: amount of crypto to send
         """
-        my_id = str(self.libp2p_node.get_id())
         msg_contents = "send," + source_user + "," + dest_user + "," + str(amount)
-        packet = generate_RPC_packet(my_id, [CRYPTO_TOPIC], msg_contents, self.next_msg_id_func())
-        await self.floodsub.publish(my_id, packet.SerializeToString())
+        await self.pubsub.publish(CRYPTO_TOPIC, msg_contents.encode())
 
     async def publish_set_crypto(self, user, amount):
         """
@@ -89,18 +106,15 @@ class DummyAccountNode:
         :param user: user to set crypto for
         :param amount: amount of crypto
         """
-        my_id = str(self.libp2p_node.get_id())
         msg_contents = "set," + user + "," + str(amount)
-        packet = generate_RPC_packet(my_id, [CRYPTO_TOPIC], msg_contents, self.next_msg_id_func())
-        await self.floodsub.publish(my_id, packet.SerializeToString())
+        await self.pubsub.publish(CRYPTO_TOPIC, msg_contents.encode())
 
     def handle_send_crypto(self, source_user, dest_user, amount):
         """
         Handle incoming send_crypto message
         :param source_user: user to send crypto from
         :param dest_user: user to send crypto to
         :param amount: amount of crypto to send
         """
         if source_user in self.balances:
             self.balances[source_user] -= amount
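
Both publishers now put the plain comma-separated payload on the wire (e.g. b"set,rob,5"), matching the format described at the top of this file. A minimal sketch of the corresponding parse step (illustrative helper, not part of the diff):

    def parse_dummy_msg(data: bytes):
        fields = data.decode().split(",")
        if fields[0] == "set":
            return ("set", fields[1], int(fields[2]))
        if fields[0] == "send":
            return ("send", fields[1], fields[2], int(fields[3]))
        raise ValueError("unknown operation: " + fields[0])

    assert parse_dummy_msg(b"send,alice,bob,3") == ("send", "alice", "bob", 3)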

View File

@@ -1,35 +1,37 @@
 import asyncio
+from threading import Thread
+
 import multiaddr
 import pytest
-from threading import Thread
 
-from tests.utils import cleanup
 from libp2p import new_node
 from libp2p.peer.peerinfo import info_from_p2p_addr
 from libp2p.pubsub.pubsub import Pubsub
 from libp2p.pubsub.floodsub import FloodSub
-from dummy_account_node import DummyAccountNode
+
+from tests.utils import (
+    cleanup,
+    connect,
+)
+from .dummy_account_node import DummyAccountNode
 
 # pylint: disable=too-many-locals
 
-async def connect(node1, node2):
-    # node1 connects to node2
-    addr = node2.get_addrs()[0]
-    info = info_from_p2p_addr(addr)
-    await node1.connect(info)
 
 def create_setup_in_new_thread_func(dummy_node):
     def setup_in_new_thread():
         asyncio.ensure_future(dummy_node.setup_crypto_networking())
     return setup_in_new_thread
 
+
 async def perform_test(num_nodes, adjacency_map, action_func, assertion_func):
     """
     Helper function to allow for easy construction of custom tests for dummy account nodes
     in various network topologies
     :param num_nodes: number of nodes in the test
     :param adjacency_map: adjacency map defining each node and its list of neighbors
     :param action_func: function to execute that includes actions by the nodes,
     such as send crypto and set crypto
     :param assertion_func: assertions for testing the results of the actions are correct
     """
@@ -73,6 +75,7 @@ async def perform_test(num_nodes, adjacency_map, action_func, assertion_func):
     # Success, terminate pending tasks.
     await cleanup()
 
+
 @pytest.mark.asyncio
 async def test_simple_two_nodes():
     num_nodes = 2
@@ -86,6 +89,7 @@ async def test_simple_two_nodes():
     await perform_test(num_nodes, adj_map, action_func, assertion_func)
 
+
 @pytest.mark.asyncio
 async def test_simple_three_nodes_line_topography():
     num_nodes = 3
@@ -99,6 +103,7 @@ async def test_simple_three_nodes_line_topography():
     await perform_test(num_nodes, adj_map, action_func, assertion_func)
 
+
 @pytest.mark.asyncio
 async def test_simple_three_nodes_triangle_topography():
     num_nodes = 3
@@ -112,6 +117,7 @@ async def test_simple_three_nodes_triangle_topography():
     await perform_test(num_nodes, adj_map, action_func, assertion_func)
 
+
 @pytest.mark.asyncio
 async def test_simple_seven_nodes_tree_topography():
     num_nodes = 7
@@ -125,6 +131,7 @@ async def test_simple_seven_nodes_tree_topography():
     await perform_test(num_nodes, adj_map, action_func, assertion_func)
 
+
 @pytest.mark.asyncio
 async def test_set_then_send_from_root_seven_nodes_tree_topography():
     num_nodes = 7

View File

@@ -1,11 +1,16 @@
 import asyncio
-import pytest
 import random
 
-from utils import message_id_generator, generate_RPC_packet, \
+import pytest
+
+from tests.utils import (
+    cleanup,
+    connect,
+)
+from .utils import message_id_generator, generate_RPC_packet, \
     create_libp2p_hosts, create_pubsub_and_gossipsub_instances, sparse_connect, dense_connect, \
-    connect, one_to_all_connect
-from tests.utils import cleanup
+    one_to_all_connect
 
 SUPPORTED_PROTOCOLS = ["/gossipsub/1.0.0"]
@@ -41,13 +46,8 @@ async def test_join():
 
     # Central node publish to the topic so that this topic
     # is added to central node's fanout
-    next_msg_id_func = message_id_generator(0)
-    msg_content = ""
-    host_id = str(libp2p_hosts[central_node_index].get_id())
-    # Generate message packet
-    packet = generate_RPC_packet(host_id, [topic], msg_content, next_msg_id_func())
     # publish from the randomly chosen host
-    await gossipsubs[central_node_index].publish(host_id, packet.SerializeToString())
+    await pubsubs[central_node_index].publish(topic, b"")
 
     # Check that the gossipsub of central node has fanout for the topic
     assert topic in gossipsubs[central_node_index].fanout
@@ -86,6 +86,8 @@ async def test_leave():
     gossipsub = gossipsubs[0]
     topic = "test_leave"
 
+    assert topic not in gossipsub.mesh
+
     await gossipsub.join(topic)
     assert topic in gossipsub.mesh
@@ -205,14 +207,12 @@ async def test_handle_prune():
 
 @pytest.mark.asyncio
 async def test_dense():
     # Create libp2p hosts
-    next_msg_id_func = message_id_generator(0)
     num_hosts = 10
     num_msgs = 5
     libp2p_hosts = await create_libp2p_hosts(num_hosts)
 
     # Create pubsub, gossipsub instances
-    pubsubs, gossipsubs = create_pubsub_and_gossipsub_instances(libp2p_hosts, \
+    pubsubs, _ = create_pubsub_and_gossipsub_instances(libp2p_hosts, \
                                                                SUPPORTED_PROTOCOLS, \
                                                                10, 9, 11, 30, 3, 5, 0.5)
@@ -231,41 +231,35 @@ async def test_dense():
     await asyncio.sleep(2)
 
     for i in range(num_msgs):
-        msg_content = "foo " + str(i)
+        msg_content = b"foo " + i.to_bytes(1, 'big')
 
         # randomly pick a message origin
         origin_idx = random.randint(0, num_hosts - 1)
-        origin_host = libp2p_hosts[origin_idx]
-        host_id = str(origin_host.get_id())
-
-        # Generate message packet
-        packet = generate_RPC_packet(host_id, ["foobar"], msg_content, next_msg_id_func())
 
         # publish from the randomly chosen host
-        await gossipsubs[origin_idx].publish(host_id, packet.SerializeToString())
+        await pubsubs[origin_idx].publish("foobar", msg_content)
 
         await asyncio.sleep(0.5)
         # Assert that all blocking queues receive the message
         for queue in queues:
             msg = await queue.get()
-            assert msg.data == packet.publish[0].data
+            assert msg.data == msg_content
     await cleanup()
 
+
 @pytest.mark.asyncio
 async def test_fanout():
     # Create libp2p hosts
-    next_msg_id_func = message_id_generator(0)
     num_hosts = 10
     num_msgs = 5
     libp2p_hosts = await create_libp2p_hosts(num_hosts)
 
     # Create pubsub, gossipsub instances
-    pubsubs, gossipsubs = create_pubsub_and_gossipsub_instances(libp2p_hosts, \
+    pubsubs, _ = create_pubsub_and_gossipsub_instances(libp2p_hosts, \
                                                                SUPPORTED_PROTOCOLS, \
                                                                10, 9, 11, 30, 3, 5, 0.5)
 
-    # All pubsub subscribe to foobar
+    # All pubsub subscribe to foobar except for `pubsubs[0]`
     queues = []
     for i in range(1, len(pubsubs)):
         q = await pubsubs[i].subscribe("foobar")
@@ -279,71 +273,61 @@ async def test_fanout():
     # Wait 2 seconds for heartbeat to allow mesh to connect
     await asyncio.sleep(2)
 
+    topic = "foobar"
     # Send messages with origin not subscribed
     for i in range(num_msgs):
-        msg_content = "foo " + str(i)
+        msg_content = b"foo " + i.to_bytes(1, "big")
 
         # Pick the message origin to the node that is not subscribed to 'foobar'
         origin_idx = 0
-        origin_host = libp2p_hosts[origin_idx]
-        host_id = str(origin_host.get_id())
-
-        # Generate message packet
-        packet = generate_RPC_packet(host_id, ["foobar"], msg_content, next_msg_id_func())
 
         # publish from the randomly chosen host
-        await gossipsubs[origin_idx].publish(host_id, packet.SerializeToString())
+        await pubsubs[origin_idx].publish(topic, msg_content)
 
         await asyncio.sleep(0.5)
         # Assert that all blocking queues receive the message
         for queue in queues:
             msg = await queue.get()
-            assert msg.SerializeToString() == packet.publish[0].SerializeToString()
+            assert msg.data == msg_content
 
     # Subscribe message origin
-    queues.append(await pubsubs[0].subscribe("foobar"))
+    queues.insert(0, await pubsubs[0].subscribe(topic))
 
     # Send messages again
     for i in range(num_msgs):
-        msg_content = "foo " + str(i)
+        msg_content = b"bar " + i.to_bytes(1, 'big')
 
         # Pick the message origin to the node that is not subscribed to 'foobar'
         origin_idx = 0
-        origin_host = libp2p_hosts[origin_idx]
-        host_id = str(origin_host.get_id())
-
-        # Generate message packet
-        packet = generate_RPC_packet(host_id, ["foobar"], msg_content, next_msg_id_func())
 
         # publish from the randomly chosen host
-        await gossipsubs[origin_idx].publish(host_id, packet.SerializeToString())
+        await pubsubs[origin_idx].publish(topic, msg_content)
 
         await asyncio.sleep(0.5)
         # Assert that all blocking queues receive the message
         for queue in queues:
             msg = await queue.get()
-            assert msg.SerializeToString() == packet.publish[0].SerializeToString()
+            assert msg.data == msg_content
 
     await cleanup()
 
+
 @pytest.mark.asyncio
 async def test_fanout_maintenance():
     # Create libp2p hosts
-    next_msg_id_func = message_id_generator(0)
     num_hosts = 10
     num_msgs = 5
     libp2p_hosts = await create_libp2p_hosts(num_hosts)
 
     # Create pubsub, gossipsub instances
-    pubsubs, gossipsubs = create_pubsub_and_gossipsub_instances(libp2p_hosts, \
+    pubsubs, _ = create_pubsub_and_gossipsub_instances(libp2p_hosts, \
                                                                SUPPORTED_PROTOCOLS, \
                                                                10, 9, 11, 30, 3, 5, 0.5)
 
     # All pubsub subscribe to foobar
     queues = []
+    topic = "foobar"
     for i in range(1, len(pubsubs)):
-        q = await pubsubs[i].subscribe("foobar")
+        q = await pubsubs[i].subscribe(topic)
 
         # Add each blocking queue to an array of blocking queues
         queues.append(q)
@@ -356,27 +340,22 @@ async def test_fanout_maintenance():
 
     # Send messages with origin not subscribed
     for i in range(num_msgs):
-        msg_content = "foo " + str(i)
+        msg_content = b"foo " + i.to_bytes(1, 'big')
 
         # Pick the message origin to the node that is not subscribed to 'foobar'
         origin_idx = 0
-        origin_host = libp2p_hosts[origin_idx]
-        host_id = str(origin_host.get_id())
-
-        # Generate message packet
-        packet = generate_RPC_packet(host_id, ["foobar"], msg_content, next_msg_id_func())
 
         # publish from the randomly chosen host
-        await gossipsubs[origin_idx].publish(host_id, packet.SerializeToString())
+        await pubsubs[origin_idx].publish(topic, msg_content)
 
         await asyncio.sleep(0.5)
         # Assert that all blocking queues receive the message
         for queue in queues:
             msg = await queue.get()
-            assert msg.SerializeToString() == packet.publish[0].SerializeToString()
+            assert msg.data == msg_content
 
     for sub in pubsubs:
-        await sub.unsubscribe('foobar')
+        await sub.unsubscribe(topic)
 
     queues = []
@@ -384,7 +363,7 @@ async def test_fanout_maintenance():
 
     # Resub and repeat
     for i in range(1, len(pubsubs)):
-        q = await pubsubs[i].subscribe("foobar")
+        q = await pubsubs[i].subscribe(topic)
 
         # Add each blocking queue to an array of blocking queues
         queues.append(q)
@@ -393,65 +372,61 @@ async def test_fanout_maintenance():
 
     # Check messages can still be sent
     for i in range(num_msgs):
-        msg_content = "foo " + str(i)
+        msg_content = b"bar " + i.to_bytes(1, 'big')
 
         # Pick the message origin to the node that is not subscribed to 'foobar'
         origin_idx = 0
-        origin_host = libp2p_hosts[origin_idx]
-        host_id = str(origin_host.get_id())
-
-        # Generate message packet
-        packet = generate_RPC_packet(host_id, ["foobar"], msg_content, next_msg_id_func())
 
         # publish from the randomly chosen host
-        await gossipsubs[origin_idx].publish(host_id, packet.SerializeToString())
+        await pubsubs[origin_idx].publish(topic, msg_content)
 
         await asyncio.sleep(0.5)
         # Assert that all blocking queues receive the message
         for queue in queues:
             msg = await queue.get()
-            assert msg.SerializeToString() == packet.publish[0].SerializeToString()
+            assert msg.data == msg_content
 
     await cleanup()
 
+
 @pytest.mark.asyncio
 async def test_gossip_propagation():
     # Create libp2p hosts
-    next_msg_id_func = message_id_generator(0)
     num_hosts = 2
-    libp2p_hosts = await create_libp2p_hosts(num_hosts)
+    hosts = await create_libp2p_hosts(num_hosts)
 
     # Create pubsub, gossipsub instances
-    pubsubs, gossipsubs = create_pubsub_and_gossipsub_instances(libp2p_hosts, \
-                                                                SUPPORTED_PROTOCOLS, \
-                                                                1, 0, 2, 30, 50, 100, 0.5)
-    node1, node2 = libp2p_hosts[0], libp2p_hosts[1]
-    sub1, sub2 = pubsubs[0], pubsubs[1]
-    gsub1, gsub2 = gossipsubs[0], gossipsubs[1]
+    pubsubs, _ = create_pubsub_and_gossipsub_instances(
+        hosts,
+        SUPPORTED_PROTOCOLS,
+        1,
+        0,
+        2,
+        30,
+        50,
+        100,
+        0.5,
+    )
 
-    node1_queue = await sub1.subscribe('foo')
+    topic = "foo"
+    await pubsubs[0].subscribe(topic)
 
-    # node 1 publish to topic
-    msg_content = 'foo_msg'
-    node1_id = str(node1.get_id())
-
-    # Generate message packet
-    packet = generate_RPC_packet(node1_id, ["foo"], msg_content, next_msg_id_func())
+    # node 0 publish to topic
+    msg_content = b'foo_msg'
 
     # publish from the randomly chosen host
-    await gsub1.publish(node1_id, packet.SerializeToString())
+    await pubsubs[0].publish(topic, msg_content)
 
-    # now node 2 subscribes
-    node2_queue = await sub2.subscribe('foo')
+    # now node 1 subscribes
+    queue_1 = await pubsubs[1].subscribe(topic)
 
-    await connect(node2, node1)
+    await connect(hosts[0], hosts[1])
 
     # wait for gossip heartbeat
     await asyncio.sleep(2)
 
     # should be able to read message
-    msg = await node2_queue.get()
-    assert msg.SerializeToString() == packet.publish[0].SerializeToString()
+    msg = await queue_1.get()
+    assert msg.data == msg_content
 
     await cleanup()
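
Throughout these rewritten tests, payloads are raw bytes because the reworked `Pubsub.publish(topic, data)` API (used everywhere above) takes bytes rather than a pre-built RPC packet; `i.to_bytes(1, 'big')` tags each message with its loop index so queue contents can be matched exactly. A quick illustration:

    for i in range(3):
        msg_content = b"foo " + i.to_bytes(1, 'big')
        # yields b'foo \x00', b'foo \x01', b'foo \x02'
        assert msg_content[-1] == i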

View File

@@ -8,18 +8,17 @@ from libp2p.pubsub.gossipsub import GossipSub
 from libp2p.pubsub.floodsub import FloodSub
 from libp2p.pubsub.pb import rpc_pb2
 from libp2p.pubsub.pubsub import Pubsub
-from utils import message_id_generator, generate_RPC_packet
 from tests.utils import cleanup
+from .utils import (
+    connect,
+    message_id_generator,
+    generate_RPC_packet,
+)
 
-# pylint: disable=too-many-locals
-
-async def connect(node1, node2):
-    """
-    Connect node1 to node2
-    """
-    addr = node2.get_addrs()[0]
-    info = info_from_p2p_addr(addr)
-    await node1.connect(info)
+# pylint: disable=too-many-locals
 
 
 @pytest.mark.asyncio
 async def test_init():
@@ -37,11 +36,12 @@ async def test_init():
 
     await cleanup()
 
+
 async def perform_test_from_obj(obj):
     """
     Perform a floodsub test from a test obj.
     test obj are composed as follows:
     {
         "supported_protocols": ["supported/protocol/1.0.0",...],
         "adj_list": {
@@ -95,7 +95,7 @@ async def perform_test_from_obj(obj):
             if neighbor_id not in node_map:
                 neighbor_node = await new_node(transport_opt=["/ip4/127.0.0.1/tcp/0"])
                 await neighbor_node.get_network().listen(multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/0"))
 
                 node_map[neighbor_id] = neighbor_node
 
                 gossipsub = GossipSub(supported_protocols, 3, 2, 4, 30)
@@ -104,7 +104,7 @@ async def perform_test_from_obj(obj):
                 pubsub_map[neighbor_id] = pubsub
 
             # Connect node and neighbor
-            tasks_connect.append(asyncio.ensure_future(connect(node_map[start_node_id], node_map[neighbor_id])))
+            tasks_connect.append(connect(node_map[start_node_id], node_map[neighbor_id]))
 
     tasks_connect.append(asyncio.sleep(2))
     await asyncio.gather(*tasks_connect)
@@ -130,7 +130,7 @@ async def perform_test_from_obj(obj):
             # Store queue in topic-queue map for node
             queues_map[node_id][topic] = q
             """
-            tasks_topic.append(asyncio.ensure_future(pubsub_map[node_id].subscribe(topic)))
+            tasks_topic.append(pubsub_map[node_id].subscribe(topic))
             tasks_topic_data.append((node_id, topic))
 
     tasks_topic.append(asyncio.sleep(2))
@@ -152,29 +152,27 @@ async def perform_test_from_obj(obj):
     topics_in_msgs_ordered = []
     messages = obj["messages"]
     tasks_publish = []
-    next_msg_id_func = message_id_generator(0)
 
     for msg in messages:
         topics = msg["topics"]
         data = msg["data"]
         node_id = msg["node_id"]
-        # Get actual id for sender node (not the id from the test obj)
-        actual_node_id = str(node_map[node_id].get_id())
-
-        # Create correctly formatted message
-        msg_talk = generate_RPC_packet(actual_node_id, topics, data, next_msg_id_func())
 
         # Publish message
-        tasks_publish.append(asyncio.ensure_future(gossipsub_map[node_id].publish(\
-            actual_node_id, msg_talk.SerializeToString())))
+        # FIXME: This should be one RPC packet with several topics
+        for topic in topics:
+            tasks_publish.append(
+                pubsub_map[node_id].publish(
+                    topic,
+                    data,
+                )
+            )
 
         # For each topic in topics, add topic, msg_talk tuple to ordered test list
         # TODO: Update message sender to be correct message sender before
        # adding msg_talk to this list
         for topic in topics:
-            topics_in_msgs_ordered.append((topic, msg_talk))
+            topics_in_msgs_ordered.append((topic, data))
 
     # Allow time for publishing before continuing
     # await asyncio.sleep(0.4)
@@ -183,14 +181,12 @@ async def perform_test_from_obj(obj):
 
     # Step 4) Check that all messages were received correctly.
     # TODO: Check message sender too
-    for i in range(len(topics_in_msgs_ordered)):
-        topic, actual_msg = topics_in_msgs_ordered[i]
-
+    for topic, data in topics_in_msgs_ordered:
         # Look at each node in each topic
         for node_id in topic_map[topic]:
             # Get message from subscription queue
             msg_on_node = await queues_map[node_id][topic].get()
-            assert actual_msg.publish[0].SerializeToString() == msg_on_node.SerializeToString()
+            assert msg_on_node.data == data
 
     # Success, terminate pending tasks.
     await cleanup()
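
For reference, a hedged example of the `obj` this helper consumes, following the docstring's shape above; the field set beyond `supported_protocols`, `adj_list`, and `messages` is an assumption here:

    test_obj = {
        "supported_protocols": ["/floodsub/1.0.0"],
        "adj_list": {"1": ["2"], "2": ["1"]},   # node ids -> neighbor node ids
        "topic_map": {"topic1": ["1", "2"]},    # assumed field: topic -> subscribers
        "messages": [
            {"topics": ["topic1"], "data": b"some data", "node_id": "1"},
        ],
    }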

View File

@@ -14,6 +14,8 @@ from libp2p.peer.id import ID
 from libp2p.pubsub.pubsub import Pubsub
 from libp2p.pubsub.gossipsub import GossipSub
 
+from tests.utils import connect
+
 
 def message_id_generator(start_val):
     """
@@ -22,6 +24,7 @@ def message_id_generator(start_val):
     :return: message id
     """
     val = start_val
+
     def generator():
         # Allow manipulation of val within closure
         nonlocal val
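
Usage sketch for this closure-based counter: each call returns the next id derived from the incremented `val` (the exact encoding of the return value is not shown in this hunk):

    next_msg_id_func = message_id_generator(0)
    first_id = next_msg_id_func()
    second_id = next_msg_id_func()
    assert first_id != second_id  # ids increase monotonically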
@@ -105,6 +108,10 @@ def create_pubsub_and_gossipsub_instances(libp2p_hosts, supported_protocols, deg
 
     return pubsubs, gossipsubs
 
 
+# FIXME: There is no difference between `sparse_connect` and `dense_connect`,
+# before `connect_some` is fixed.
+
+
 async def sparse_connect(hosts):
     await connect_some(hosts, 3)
@@ -113,6 +120,7 @@ async def dense_connect(hosts):
     await connect_some(hosts, 10)
 
 
+# FIXME: `degree` is not used at all
 async def connect_some(hosts, degree):
     for i, host in enumerate(hosts):
         for j, host2 in enumerate(hosts):
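
Given the FIXME above, a hedged sketch of what a degree-respecting `connect_some` could look like (an assumption, not part of this commit):

    async def connect_some_sketch(hosts, degree):
        # connect each host to at most `degree` of the hosts after it
        for i, host in enumerate(hosts):
            for host2 in hosts[i + 1:i + 1 + degree]:
                await connect(host, host2)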
@@ -135,6 +143,7 @@ async def connect_some(hosts, degree):
     # j += 1
 
 
+
 async def one_to_all_connect(hosts, central_host_index):
     for i, host in enumerate(hosts):
         if i != central_host_index:
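
The diff is truncated mid-function here. Presumably the loop body connects each non-central host to the central one; a sketch under that assumption, using the shared `connect` helper:

    async def one_to_all_connect_sketch(hosts, central_host_index):
        for i, host in enumerate(hosts):
            if i != central_host_index:
                await connect(hosts[central_host_index], host)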