Fix several tests

This commit is contained in:
mhchia 2019-07-26 18:35:25 +08:00
parent 035d08b8bd
commit 65aedcb25a
No known key found for this signature in database
GPG Key ID: 389EFBEA1362589A
7 changed files with 245 additions and 213 deletions

View File

@ -80,6 +80,21 @@ class FloodSub(IPubsubRouter):
# Ref: https://github.com/libp2p/go-libp2p-pubsub/blob/master/comm.go#L107
await stream.write(rpc_msg.SerializeToString())
async def join(self, topic):
"""
Join notifies the router that we want to receive and
forward messages in a topic. It is invoked after the
subscription announcement
:param topic: topic to join
"""
async def leave(self, topic):
"""
Leave notifies the router that we are no longer interested in a topic.
It is invoked after the unsubscription announcement.
:param topic: topic to leave
"""
def _get_peers_to_send(
self,
topic_ids: Iterable[str],
@ -102,18 +117,3 @@ class FloodSub(IPubsubRouter):
if str(peer_id) not in self.pubsub.peers:
continue
yield peer_id
async def join(self, topic):
"""
Join notifies the router that we want to receive and
forward messages in a topic. It is invoked after the
subscription announcement
:param topic: topic to join
"""
async def leave(self, topic):
"""
Leave notifies the router that we are no longer interested in a topic.
It is invoked after the unsubscription announcement.
:param topic: topic to leave
"""

View File

@ -1,10 +1,22 @@
import random
import asyncio
import random
from typing import (
Iterable,
List,
MutableSet,
Sequence,
)
from ast import literal_eval
from libp2p.peer.id import (
ID,
id_b58_decode,
)
from .mcache import MessageCache
from .pb import rpc_pb2
from .pubsub_router_interface import IPubsubRouter
from .mcache import MessageCache
class GossipSub(IPubsubRouter):
@ -107,70 +119,73 @@ class GossipSub(IPubsubRouter):
for prune in control_message.prune:
await self.handle_prune(prune, sender_peer_id)
async def publish(self, sender_peer_id, rpc_message):
async def publish(self, src: ID, pubsub_msg: rpc_pb2.Message) -> None:
# pylint: disable=too-many-locals
"""
Invoked to forward a new message that has been validated.
"""
self.mcache.put(pubsub_msg)
packet = rpc_pb2.RPC()
packet.ParseFromString(rpc_message)
msg_sender = str(sender_peer_id)
peers_gen = self._get_peers_to_send(
pubsub_msg.topicIDs,
src=src,
origin=ID(pubsub_msg.from_id),
)
rpc_msg = rpc_pb2.RPC(
publish=[pubsub_msg],
)
for peer_id in peers_gen:
stream = self.pubsub.peers[str(peer_id)]
# FIXME: We should add a `WriteMsg` similar to write delimited messages.
# Ref: https://github.com/libp2p/go-libp2p-pubsub/blob/master/comm.go#L107
# TODO: Go use `sendRPC`, which possibly piggybacks gossip/control messages.
await stream.write(rpc_msg.SerializeToString())
# Deliver to self if self was origin
# Note: handle_talk checks if self is subscribed to topics in message
for message in packet.publish:
# Add RPC message to cache
self.mcache.put(message)
def _get_peers_to_send(
self,
topic_ids: Iterable[str],
src: ID,
origin: ID) -> Iterable[ID]:
"""
Get the eligible peers to send the data to.
:param src: the peer id of the peer who forwards the message to me.
:param origin: the peer id of the peer who originally broadcast the message.
:return: a generator of the peer ids who we send data to.
"""
to_send: MutableSet[ID] = set()
for topic in topic_ids:
if topic not in self.pubsub.peer_topics:
continue
decoded_from_id = message.from_id.decode('utf-8')
new_packet = rpc_pb2.RPC()
new_packet.publish.extend([message])
new_packet_serialized = new_packet.SerializeToString()
# floodsub peers
for peer_id_str in self.pubsub.peer_topics[topic]:
peer_id = id_b58_decode(peer_id_str)
# FIXME: `gossipsub.peers_floodsub` can be changed to `gossipsub.peers` in go.
# This will improve the efficiency when searching for a peer's protocol id.
if peer_id_str in self.peers_floodsub:
to_send.add(peer_id)
# Deliver to self if needed
if msg_sender == decoded_from_id and msg_sender == str(self.pubsub.host.get_id()):
id_in_seen_msgs = (message.seqno, message.from_id)
# gossipsub peers
# FIXME: Change `str` to `ID`
gossipsub_peers: List[str] = None
# TODO: Do we need to check `topic in self.pubsub.my_topics`?
if topic in self.mesh:
gossipsub_peers = self.mesh[topic]
else:
# TODO(robzajac): Is topic DEFINITELY supposed to be in fanout if we are not
# subscribed?
# I assume there could be short periods between heartbeats where topic may not
# be but we should check that this path gets hit appropriately
if id_in_seen_msgs not in self.pubsub.seen_messages:
self.pubsub.seen_messages[id_in_seen_msgs] = 1
await self.pubsub.handle_talk(message)
# Deliver to peers
for topic in message.topicIDs:
# If topic has floodsub peers, deliver to floodsub peers
# TODO: This can be done more efficiently. Do it more efficiently.
floodsub_peers_in_topic = []
if topic in self.pubsub.peer_topics:
for peer in self.pubsub.peer_topics[topic]:
if str(peer) in self.peers_floodsub:
floodsub_peers_in_topic.append(peer)
await self.deliver_messages_to_peers(floodsub_peers_in_topic, msg_sender,
decoded_from_id, new_packet_serialized)
# If you are subscribed to topic, send to mesh, otherwise send to fanout
if topic in self.pubsub.my_topics and topic in self.mesh:
await self.deliver_messages_to_peers(self.mesh[topic], msg_sender,
decoded_from_id, new_packet_serialized)
else:
# Send to fanout peers
if topic not in self.fanout:
# If no peers in fanout, choose some peers from gossipsub peers in topic
gossipsub_peers_in_topic = [peer for peer in self.pubsub.peer_topics[topic]
if peer in self.peers_gossipsub]
selected = \
GossipSub.select_from_minus(self.degree, gossipsub_peers_in_topic, [])
self.fanout[topic] = selected
# TODO: Is topic DEFINITELY supposed to be in fanout if we are not subscribed?
# I assume there could be short periods between heartbeats where topic may not
# be but we should check that this path gets hit appropriately
await self.deliver_messages_to_peers(self.fanout[topic], msg_sender,
decoded_from_id, new_packet_serialized)
# pylint: disable=len-as-condition
if (topic not in self.fanout) or (len(self.fanout[topic]) == 0):
# If no peers in fanout, choose some peers from gossipsub peers in topic.
self.fanout[topic] = self._get_peers_from_minus(topic, self.degree, [])
gossipsub_peers = self.fanout[topic]
for peer_id_str in gossipsub_peers:
to_send.add(id_b58_decode(peer_id_str))
# Excludes `src` and `origin`
yield from to_send.difference([src, origin])
async def join(self, topic):
# Note: the comments here are the near-exact algorithm description from the spec
@ -401,6 +416,22 @@ class GossipSub(IPubsubRouter):
return selection
def _get_peers_from_minus(
self,
topic: str,
num_to_select: int,
minus: Sequence[ID]) -> List[ID]:
gossipsub_peers_in_topic = [
peer_str
for peer_str in self.pubsub.peer_topics[topic]
if peer_str in self.peers_gossipsub
]
return self.select_from_minus(
num_to_select,
gossipsub_peers_in_topic,
list(minus),
)
# RPC handlers
async def handle_ihave(self, ihave_msg, sender_peer_id):

View File

@ -2,11 +2,14 @@ import asyncio
import multiaddr
import uuid
from utils import message_id_generator, generate_RPC_packet
from libp2p import new_node
from libp2p.host.host_interface import IHost
from libp2p.pubsub.pubsub import Pubsub
from libp2p.pubsub.floodsub import FloodSub
from .utils import message_id_generator, generate_RPC_packet
SUPPORTED_PUBSUB_PROTOCOLS = ["/floodsub/1.0.0"]
CRYPTO_TOPIC = "ethereum"
@ -17,14 +20,25 @@ CRYPTO_TOPIC = "ethereum"
# Ex. set,rob,5
# Determine message type by looking at first item before first comma
class DummyAccountNode():
class DummyAccountNode:
"""
Node which has an internal balance mapping, meant to serve as
Node which has an internal balance mapping, meant to serve as
a dummy crypto blockchain. There is no actual blockchain, just a simple
map indicating how much crypto each user in the mappings holds
"""
libp2p_node: IHost
pubsub: Pubsub
floodsub: FloodSub
def __init__(self):
def __init__(
self,
libp2p_node: IHost,
pubsub: Pubsub,
floodsub: FloodSub):
self.libp2p_node = libp2p_node
self.pubsub = pubsub
self.floodsub = floodsub
self.balances = {}
self.next_msg_id_func = message_id_generator(0)
self.node_id = str(uuid.uuid1())
@ -38,16 +52,21 @@ class DummyAccountNode():
We use create as this serves as a factory function and allows us
to use async await, unlike the init function
"""
self = DummyAccountNode()
libp2p_node = await new_node(transport_opt=["/ip4/127.0.0.1/tcp/0"])
await libp2p_node.get_network().listen(multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/0"))
self.libp2p_node = libp2p_node
self.floodsub = FloodSub(SUPPORTED_PUBSUB_PROTOCOLS)
self.pubsub = Pubsub(self.libp2p_node, self.floodsub, "a")
return self
floodsub = FloodSub(SUPPORTED_PUBSUB_PROTOCOLS)
pubsub = Pubsub(
libp2p_node,
floodsub,
"a",
)
return cls(
libp2p_node=libp2p_node,
pubsub=pubsub,
floodsub=floodsub,
)
async def handle_incoming_msgs(self):
"""
@ -78,10 +97,8 @@ class DummyAccountNode():
:param dest_user: user to send crypto to
:param amount: amount of crypto to send
"""
my_id = str(self.libp2p_node.get_id())
msg_contents = "send," + source_user + "," + dest_user + "," + str(amount)
packet = generate_RPC_packet(my_id, [CRYPTO_TOPIC], msg_contents, self.next_msg_id_func())
await self.floodsub.publish(my_id, packet.SerializeToString())
await self.pubsub.publish(CRYPTO_TOPIC, msg_contents.encode())
async def publish_set_crypto(self, user, amount):
"""
@ -89,18 +106,15 @@ class DummyAccountNode():
:param user: user to set crypto for
:param amount: amount of crypto
"""
my_id = str(self.libp2p_node.get_id())
msg_contents = "set," + user + "," + str(amount)
packet = generate_RPC_packet(my_id, [CRYPTO_TOPIC], msg_contents, self.next_msg_id_func())
await self.floodsub.publish(my_id, packet.SerializeToString())
await self.pubsub.publish(CRYPTO_TOPIC, msg_contents.encode())
def handle_send_crypto(self, source_user, dest_user, amount):
"""
Handle incoming send_crypto message
:param source_user: user to send crypto from
:param dest_user: user to send crypto to
:param amount: amount of crypto to send
:param amount: amount of crypto to send
"""
if source_user in self.balances:
self.balances[source_user] -= amount

View File

@ -1,35 +1,37 @@
import asyncio
from threading import Thread
import multiaddr
import pytest
from threading import Thread
from tests.utils import cleanup
from libp2p import new_node
from libp2p.peer.peerinfo import info_from_p2p_addr
from libp2p.pubsub.pubsub import Pubsub
from libp2p.pubsub.floodsub import FloodSub
from dummy_account_node import DummyAccountNode
from tests.utils import (
cleanup,
connect,
)
from .dummy_account_node import DummyAccountNode
# pylint: disable=too-many-locals
async def connect(node1, node2):
# node1 connects to node2
addr = node2.get_addrs()[0]
info = info_from_p2p_addr(addr)
await node1.connect(info)
def create_setup_in_new_thread_func(dummy_node):
def setup_in_new_thread():
asyncio.ensure_future(dummy_node.setup_crypto_networking())
return setup_in_new_thread
async def perform_test(num_nodes, adjacency_map, action_func, assertion_func):
"""
Helper function to allow for easy construction of custom tests for dummy account nodes
in various network topologies
:param num_nodes: number of nodes in the test
:param adjacency_map: adjacency map defining each node and its list of neighbors
:param action_func: function to execute that includes actions by the nodes,
:param action_func: function to execute that includes actions by the nodes,
such as send crypto and set crypto
:param assertion_func: assertions for testing the results of the actions are correct
"""
@ -73,6 +75,7 @@ async def perform_test(num_nodes, adjacency_map, action_func, assertion_func):
# Success, terminate pending tasks.
await cleanup()
@pytest.mark.asyncio
async def test_simple_two_nodes():
num_nodes = 2
@ -86,6 +89,7 @@ async def test_simple_two_nodes():
await perform_test(num_nodes, adj_map, action_func, assertion_func)
@pytest.mark.asyncio
async def test_simple_three_nodes_line_topography():
num_nodes = 3
@ -99,6 +103,7 @@ async def test_simple_three_nodes_line_topography():
await perform_test(num_nodes, adj_map, action_func, assertion_func)
@pytest.mark.asyncio
async def test_simple_three_nodes_triangle_topography():
num_nodes = 3
@ -112,6 +117,7 @@ async def test_simple_three_nodes_triangle_topography():
await perform_test(num_nodes, adj_map, action_func, assertion_func)
@pytest.mark.asyncio
async def test_simple_seven_nodes_tree_topography():
num_nodes = 7
@ -125,6 +131,7 @@ async def test_simple_seven_nodes_tree_topography():
await perform_test(num_nodes, adj_map, action_func, assertion_func)
@pytest.mark.asyncio
async def test_set_then_send_from_root_seven_nodes_tree_topography():
num_nodes = 7

View File

@ -1,11 +1,16 @@
import asyncio
import pytest
import random
from utils import message_id_generator, generate_RPC_packet, \
import pytest
from tests.utils import (
cleanup,
connect,
)
from .utils import message_id_generator, generate_RPC_packet, \
create_libp2p_hosts, create_pubsub_and_gossipsub_instances, sparse_connect, dense_connect, \
connect, one_to_all_connect
from tests.utils import cleanup
one_to_all_connect
SUPPORTED_PROTOCOLS = ["/gossipsub/1.0.0"]
@ -41,13 +46,8 @@ async def test_join():
# Central node publish to the topic so that this topic
# is added to central node's fanout
next_msg_id_func = message_id_generator(0)
msg_content = ""
host_id = str(libp2p_hosts[central_node_index].get_id())
# Generate message packet
packet = generate_RPC_packet(host_id, [topic], msg_content, next_msg_id_func())
# publish from the randomly chosen host
await gossipsubs[central_node_index].publish(host_id, packet.SerializeToString())
await pubsubs[central_node_index].publish(topic, b"")
# Check that the gossipsub of central node has fanout for the topic
assert topic in gossipsubs[central_node_index].fanout
@ -86,6 +86,8 @@ async def test_leave():
gossipsub = gossipsubs[0]
topic = "test_leave"
assert topic not in gossipsub.mesh
await gossipsub.join(topic)
assert topic in gossipsub.mesh
@ -205,14 +207,12 @@ async def test_handle_prune():
@pytest.mark.asyncio
async def test_dense():
# Create libp2p hosts
next_msg_id_func = message_id_generator(0)
num_hosts = 10
num_msgs = 5
libp2p_hosts = await create_libp2p_hosts(num_hosts)
# Create pubsub, gossipsub instances
pubsubs, gossipsubs = create_pubsub_and_gossipsub_instances(libp2p_hosts, \
pubsubs, _ = create_pubsub_and_gossipsub_instances(libp2p_hosts, \
SUPPORTED_PROTOCOLS, \
10, 9, 11, 30, 3, 5, 0.5)
@ -231,41 +231,35 @@ async def test_dense():
await asyncio.sleep(2)
for i in range(num_msgs):
msg_content = "foo " + str(i)
msg_content = b"foo " + i.to_bytes(1, 'big')
# randomly pick a message origin
origin_idx = random.randint(0, num_hosts - 1)
origin_host = libp2p_hosts[origin_idx]
host_id = str(origin_host.get_id())
# Generate message packet
packet = generate_RPC_packet(host_id, ["foobar"], msg_content, next_msg_id_func())
# publish from the randomly chosen host
await gossipsubs[origin_idx].publish(host_id, packet.SerializeToString())
await pubsubs[origin_idx].publish("foobar", msg_content)
await asyncio.sleep(0.5)
# Assert that all blocking queues receive the message
for queue in queues:
msg = await queue.get()
assert msg.data == packet.publish[0].data
assert msg.data == msg_content
await cleanup()
@pytest.mark.asyncio
async def test_fanout():
# Create libp2p hosts
next_msg_id_func = message_id_generator(0)
num_hosts = 10
num_msgs = 5
libp2p_hosts = await create_libp2p_hosts(num_hosts)
# Create pubsub, gossipsub instances
pubsubs, gossipsubs = create_pubsub_and_gossipsub_instances(libp2p_hosts, \
pubsubs, _ = create_pubsub_and_gossipsub_instances(libp2p_hosts, \
SUPPORTED_PROTOCOLS, \
10, 9, 11, 30, 3, 5, 0.5)
# All pubsub subscribe to foobar
# All pubsub subscribe to foobar except for `pubsubs[0]`
queues = []
for i in range(1, len(pubsubs)):
q = await pubsubs[i].subscribe("foobar")
@ -279,71 +273,61 @@ async def test_fanout():
# Wait 2 seconds for heartbeat to allow mesh to connect
await asyncio.sleep(2)
topic = "foobar"
# Send messages with origin not subscribed
for i in range(num_msgs):
msg_content = "foo " + str(i)
msg_content = b"foo " + i.to_bytes(1, "big")
# Pick the message origin to the node that is not subscribed to 'foobar'
origin_idx = 0
origin_host = libp2p_hosts[origin_idx]
host_id = str(origin_host.get_id())
# Generate message packet
packet = generate_RPC_packet(host_id, ["foobar"], msg_content, next_msg_id_func())
# publish from the randomly chosen host
await gossipsubs[origin_idx].publish(host_id, packet.SerializeToString())
await pubsubs[origin_idx].publish(topic, msg_content)
await asyncio.sleep(0.5)
# Assert that all blocking queues receive the message
for queue in queues:
msg = await queue.get()
assert msg.SerializeToString() == packet.publish[0].SerializeToString()
assert msg.data == msg_content
# Subscribe message origin
queues.append(await pubsubs[0].subscribe("foobar"))
queues.insert(0, await pubsubs[0].subscribe(topic))
# Send messages again
for i in range(num_msgs):
msg_content = "foo " + str(i)
msg_content = b"bar " + i.to_bytes(1, 'big')
# Pick the message origin to the node that is not subscribed to 'foobar'
origin_idx = 0
origin_host = libp2p_hosts[origin_idx]
host_id = str(origin_host.get_id())
# Generate message packet
packet = generate_RPC_packet(host_id, ["foobar"], msg_content, next_msg_id_func())
# publish from the randomly chosen host
await gossipsubs[origin_idx].publish(host_id, packet.SerializeToString())
await pubsubs[origin_idx].publish(topic, msg_content)
await asyncio.sleep(0.5)
# Assert that all blocking queues receive the message
for queue in queues:
msg = await queue.get()
assert msg.SerializeToString() == packet.publish[0].SerializeToString()
assert msg.data == msg_content
await cleanup()
@pytest.mark.asyncio
async def test_fanout_maintenance():
# Create libp2p hosts
next_msg_id_func = message_id_generator(0)
num_hosts = 10
num_msgs = 5
libp2p_hosts = await create_libp2p_hosts(num_hosts)
# Create pubsub, gossipsub instances
pubsubs, gossipsubs = create_pubsub_and_gossipsub_instances(libp2p_hosts, \
pubsubs, _ = create_pubsub_and_gossipsub_instances(libp2p_hosts, \
SUPPORTED_PROTOCOLS, \
10, 9, 11, 30, 3, 5, 0.5)
# All pubsub subscribe to foobar
queues = []
topic = "foobar"
for i in range(1, len(pubsubs)):
q = await pubsubs[i].subscribe("foobar")
q = await pubsubs[i].subscribe(topic)
# Add each blocking queue to an array of blocking queues
queues.append(q)
@ -356,27 +340,22 @@ async def test_fanout_maintenance():
# Send messages with origin not subscribed
for i in range(num_msgs):
msg_content = "foo " + str(i)
msg_content = b"foo " + i.to_bytes(1, 'big')
# Pick the message origin to the node that is not subscribed to 'foobar'
origin_idx = 0
origin_host = libp2p_hosts[origin_idx]
host_id = str(origin_host.get_id())
# Generate message packet
packet = generate_RPC_packet(host_id, ["foobar"], msg_content, next_msg_id_func())
# publish from the randomly chosen host
await gossipsubs[origin_idx].publish(host_id, packet.SerializeToString())
await pubsubs[origin_idx].publish(topic, msg_content)
await asyncio.sleep(0.5)
# Assert that all blocking queues receive the message
for queue in queues:
msg = await queue.get()
assert msg.SerializeToString() == packet.publish[0].SerializeToString()
assert msg.data == msg_content
for sub in pubsubs:
await sub.unsubscribe('foobar')
await sub.unsubscribe(topic)
queues = []
@ -384,7 +363,7 @@ async def test_fanout_maintenance():
# Resub and repeat
for i in range(1, len(pubsubs)):
q = await pubsubs[i].subscribe("foobar")
q = await pubsubs[i].subscribe(topic)
# Add each blocking queue to an array of blocking queues
queues.append(q)
@ -393,65 +372,61 @@ async def test_fanout_maintenance():
# Check messages can still be sent
for i in range(num_msgs):
msg_content = "foo " + str(i)
msg_content = b"bar " + i.to_bytes(1, 'big')
# Pick the message origin to the node that is not subscribed to 'foobar'
origin_idx = 0
origin_host = libp2p_hosts[origin_idx]
host_id = str(origin_host.get_id())
# Generate message packet
packet = generate_RPC_packet(host_id, ["foobar"], msg_content, next_msg_id_func())
# publish from the randomly chosen host
await gossipsubs[origin_idx].publish(host_id, packet.SerializeToString())
await pubsubs[origin_idx].publish(topic, msg_content)
await asyncio.sleep(0.5)
# Assert that all blocking queues receive the message
for queue in queues:
msg = await queue.get()
assert msg.SerializeToString() == packet.publish[0].SerializeToString()
assert msg.data == msg_content
await cleanup()
@pytest.mark.asyncio
async def test_gossip_propagation():
# Create libp2p hosts
next_msg_id_func = message_id_generator(0)
num_hosts = 2
libp2p_hosts = await create_libp2p_hosts(num_hosts)
hosts = await create_libp2p_hosts(num_hosts)
# Create pubsub, gossipsub instances
pubsubs, gossipsubs = create_pubsub_and_gossipsub_instances(libp2p_hosts, \
SUPPORTED_PROTOCOLS, \
1, 0, 2, 30, 50, 100, 0.5)
node1, node2 = libp2p_hosts[0], libp2p_hosts[1]
sub1, sub2 = pubsubs[0], pubsubs[1]
gsub1, gsub2 = gossipsubs[0], gossipsubs[1]
pubsubs, _ = create_pubsub_and_gossipsub_instances(
hosts,
SUPPORTED_PROTOCOLS,
1,
0,
2,
30,
50,
100,
0.5,
)
node1_queue = await sub1.subscribe('foo')
topic = "foo"
await pubsubs[0].subscribe(topic)
# node 1 publish to topic
msg_content = 'foo_msg'
node1_id = str(node1.get_id())
# Generate message packet
packet = generate_RPC_packet(node1_id, ["foo"], msg_content, next_msg_id_func())
# node 0 publish to topic
msg_content = b'foo_msg'
# publish from the randomly chosen host
await gsub1.publish(node1_id, packet.SerializeToString())
await pubsubs[0].publish(topic, msg_content)
# now node 2 subscribes
node2_queue = await sub2.subscribe('foo')
# now node 1 subscribes
queue_1 = await pubsubs[1].subscribe(topic)
await connect(node2, node1)
await connect(hosts[0], hosts[1])
# wait for gossip heartbeat
await asyncio.sleep(2)
# should be able to read message
msg = await node2_queue.get()
assert msg.SerializeToString() == packet.publish[0].SerializeToString()
msg = await queue_1.get()
assert msg.data == msg_content
await cleanup()

View File

@ -8,18 +8,17 @@ from libp2p.pubsub.gossipsub import GossipSub
from libp2p.pubsub.floodsub import FloodSub
from libp2p.pubsub.pb import rpc_pb2
from libp2p.pubsub.pubsub import Pubsub
from utils import message_id_generator, generate_RPC_packet
from tests.utils import cleanup
# pylint: disable=too-many-locals
from .utils import (
connect,
message_id_generator,
generate_RPC_packet,
)
async def connect(node1, node2):
"""
Connect node1 to node2
"""
addr = node2.get_addrs()[0]
info = info_from_p2p_addr(addr)
await node1.connect(info)
# pylint: disable=too-many-locals
@pytest.mark.asyncio
async def test_init():
@ -37,11 +36,12 @@ async def test_init():
await cleanup()
async def perform_test_from_obj(obj):
"""
Perform a floodsub test from a test obj.
test obj are composed as follows:
{
"supported_protocols": ["supported/protocol/1.0.0",...],
"adj_list": {
@ -95,7 +95,7 @@ async def perform_test_from_obj(obj):
if neighbor_id not in node_map:
neighbor_node = await new_node(transport_opt=["/ip4/127.0.0.1/tcp/0"])
await neighbor_node.get_network().listen(multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/0"))
node_map[neighbor_id] = neighbor_node
gossipsub = GossipSub(supported_protocols, 3, 2, 4, 30)
@ -104,7 +104,7 @@ async def perform_test_from_obj(obj):
pubsub_map[neighbor_id] = pubsub
# Connect node and neighbor
tasks_connect.append(asyncio.ensure_future(connect(node_map[start_node_id], node_map[neighbor_id])))
tasks_connect.append(connect(node_map[start_node_id], node_map[neighbor_id]))
tasks_connect.append(asyncio.sleep(2))
await asyncio.gather(*tasks_connect)
@ -130,7 +130,7 @@ async def perform_test_from_obj(obj):
# Store queue in topic-queue map for node
queues_map[node_id][topic] = q
"""
tasks_topic.append(asyncio.ensure_future(pubsub_map[node_id].subscribe(topic)))
tasks_topic.append(pubsub_map[node_id].subscribe(topic))
tasks_topic_data.append((node_id, topic))
tasks_topic.append(asyncio.sleep(2))
@ -152,29 +152,27 @@ async def perform_test_from_obj(obj):
topics_in_msgs_ordered = []
messages = obj["messages"]
tasks_publish = []
next_msg_id_func = message_id_generator(0)
for msg in messages:
topics = msg["topics"]
data = msg["data"]
node_id = msg["node_id"]
# Get actual id for sender node (not the id from the test obj)
actual_node_id = str(node_map[node_id].get_id())
# Create correctly formatted message
msg_talk = generate_RPC_packet(actual_node_id, topics, data, next_msg_id_func())
# Publish message
tasks_publish.append(asyncio.ensure_future(gossipsub_map[node_id].publish(\
actual_node_id, msg_talk.SerializeToString())))
# FIXME: This should be one RPC packet with several topics
for topic in topics:
tasks_publish.append(
pubsub_map[node_id].publish(
topic,
data,
)
)
# For each topic in topics, add topic, msg_talk tuple to ordered test list
# TODO: Update message sender to be correct message sender before
# adding msg_talk to this list
for topic in topics:
topics_in_msgs_ordered.append((topic, msg_talk))
topics_in_msgs_ordered.append((topic, data))
# Allow time for publishing before continuing
# await asyncio.sleep(0.4)
@ -183,14 +181,12 @@ async def perform_test_from_obj(obj):
# Step 4) Check that all messages were received correctly.
# TODO: Check message sender too
for i in range(len(topics_in_msgs_ordered)):
topic, actual_msg = topics_in_msgs_ordered[i]
for topic, data in topics_in_msgs_ordered:
# Look at each node in each topic
for node_id in topic_map[topic]:
# Get message from subscription queue
msg_on_node = await queues_map[node_id][topic].get()
assert actual_msg.publish[0].SerializeToString() == msg_on_node.SerializeToString()
assert msg_on_node.data == data
# Success, terminate pending tasks.
await cleanup()

View File

@ -14,6 +14,8 @@ from libp2p.peer.id import ID
from libp2p.pubsub.pubsub import Pubsub
from libp2p.pubsub.gossipsub import GossipSub
from tests.utils import connect
def message_id_generator(start_val):
"""
@ -22,6 +24,7 @@ def message_id_generator(start_val):
:return: message id
"""
val = start_val
def generator():
# Allow manipulation of val within closure
nonlocal val
@ -105,6 +108,10 @@ def create_pubsub_and_gossipsub_instances(libp2p_hosts, supported_protocols, deg
return pubsubs, gossipsubs
# FIXME: There is no difference between `sparse_connect` and `dense_connect`,
# before `connect_some` is fixed.
async def sparse_connect(hosts):
await connect_some(hosts, 3)
@ -113,6 +120,7 @@ async def dense_connect(hosts):
await connect_some(hosts, 10)
# FIXME: `degree` is not used at all
async def connect_some(hosts, degree):
for i, host in enumerate(hosts):
for j, host2 in enumerate(hosts):
@ -135,6 +143,7 @@ async def connect_some(hosts, degree):
# j += 1
async def one_to_all_connect(hosts, central_host_index):
for i, host in enumerate(hosts):
if i != central_host_index: