Fix the tests according to pubsub.Publish

And refactored a bit.
mhchia 2019-07-25 16:58:00 +08:00
parent cae4f34034
commit dadcf8138e
3 changed files with 62 additions and 86 deletions
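
For orientation, the tests below now call Pubsub.publish(topic, data) directly instead of hand-building an rpc_pb2 message and handing it to the router. A condensed two-node sketch of that flow, mirroring the updated tests (the dial step is inlined from the tests' connect helper; anything not visible in the diff, such as default arguments, is an assumption):

import asyncio

import multiaddr

from libp2p import new_node
from libp2p.peer.id import ID
from libp2p.peer.peerinfo import info_from_p2p_addr
from libp2p.pubsub.floodsub import FloodSub
from libp2p.pubsub.pubsub import Pubsub


async def publish_demo() -> None:
    listen_maddr = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/0")
    node_a = await new_node()
    node_b = await new_node()
    await node_a.get_network().listen(listen_maddr)
    await node_b.get_network().listen(listen_maddr)

    pubsub_a = Pubsub(node_a, FloodSub(["/floodsub/1.0.0"]), ID(b"a" * 32))
    pubsub_b = Pubsub(node_b, FloodSub(["/floodsub/1.0.0"]), ID(b"b" * 32))

    # Dial node_b from node_a (the tests wrap this in a `connect` helper).
    await node_a.connect(info_from_p2p_addr(node_b.get_addrs()[0]))
    await asyncio.sleep(0.25)

    sub_b = await pubsub_b.subscribe("my_topic")
    await asyncio.sleep(0.25)  # let node_a learn of node_b's subscription

    # Publish raw bytes; Pubsub builds and routes the message itself.
    await pubsub_a.publish("my_topic", b"some data")

    res_b = await sub_b.get()
    assert res_b.data == b"some data"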

View File

@@ -91,6 +91,7 @@ class GossipSub(IPubsubRouter):
         :param rpc: rpc message
         """
         control_message = rpc.control
+        sender_peer_id = str(sender_peer_id)
 
         # Relay each rpc control to the appropriate handler
         if control_message.ihave:

View File

@@ -128,26 +128,21 @@ class Pubsub:
         messages from other nodes
         :param stream: stream to continously read from
         """
-        peer_id = stream.mplex_conn.peer_id
+        # TODO check on types here
+        peer_id = str(stream.mplex_conn.peer_id)
 
         while True:
             incoming = (await stream.read())
             rpc_incoming = rpc_pb2.RPC()
             rpc_incoming.ParseFromString(incoming)
 
-            should_publish = False
-
             if rpc_incoming.publish:
                 # deal with RPC.publish
-                for message in rpc_incoming.publish:
-                    id_in_seen_msgs = (message.seqno, message.from_id)
-                    if id_in_seen_msgs not in self.seen_messages:
-                        should_publish = True
-                        self.seen_messages[id_in_seen_msgs] = 1
-
-                    await self.handle_talk(message)
+                for msg in rpc_incoming.publish:
+                    if not self._is_subscribed_to_msg(msg):
+                        continue
+                    # TODO(mhchia): This will block this read_stream loop until all data are pushed.
+                    #   Should investigate further if this is an issue.
+                    await self.push_msg(src=peer_id, msg=msg)
 
             if rpc_incoming.subscriptions:
                 # deal with RPC.subscriptions
@@ -158,10 +153,6 @@ class Pubsub:
                 for message in rpc_incoming.subscriptions:
                     self.handle_subscription(peer_id, message)
 
-            if should_publish:
-                # relay message to peers with router
-                await self.router.publish(peer_id, incoming)
-
             if rpc_incoming.control:
                 # Pass rpc to router so router could perform custom logic
                 await self.router.handle_rpc(rpc_incoming, peer_id)
@@ -228,6 +219,7 @@ class Pubsub:
         :param origin_id: id of the peer who subscribe to the message
         :param sub_message: RPC.SubOpts
         """
+        origin_id = str(origin_id)
         if sub_message.subscribe:
             if sub_message.topicid not in self.peer_topics:
                 self.peer_topics[sub_message.topicid] = [origin_id]
@@ -379,3 +371,8 @@ class Pubsub:
         # FIXME: Mapping `msg_id` to `1` is quite awkward. Should investigate if there is a
         #   more appropriate way.
         self.seen_messages[msg_id] = 1
+
+    def _is_subscribed_to_msg(self, msg: rpc_pb2.Message) -> bool:
+        if len(self.my_topics) == 0:
+            return False
+        return all([topic in self.my_topics for topic in msg.topicIDs])
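
The hunks above replace the inline seen_messages bookkeeping in the read loop with a call to push_msg, gated by the new _is_subscribed_to_msg helper. push_msg itself is not part of this diff, so the following is only a sketch of the implied per-message flow, pieced together from seen_messages, the helper above, and the default get_msg_id mentioned in the test changes below; the real cache is bounded by the cache_size constructor argument rather than a plain dict:

import asyncio
from typing import Dict, Tuple

from libp2p.pubsub.pb import rpc_pb2


def get_msg_id(msg: rpc_pb2.Message) -> Tuple[bytes, bytes]:
    # Default id per the test comment below: `(msg.seqno, msg.from_id)`.
    return (msg.seqno, msg.from_id)


class PubsubSketch:
    # Illustration only; the real class is libp2p.pubsub.pubsub.Pubsub.

    def __init__(self) -> None:
        self.seen_messages: Dict[Tuple[bytes, bytes], int] = {}
        self.my_topics: Dict[str, "asyncio.Queue[rpc_pb2.Message]"] = {}

    def _is_subscribed_to_msg(self, msg: rpc_pb2.Message) -> bool:
        # Same logic as the method added in the hunk above.
        if len(self.my_topics) == 0:
            return False
        return all([topic in self.my_topics for topic in msg.topicIDs])

    async def push_msg(self, src: str, msg: rpc_pb2.Message) -> None:
        msg_id = get_msg_id(msg)
        if msg_id in self.seen_messages:
            return  # duplicate: drop it
        self.seen_messages[msg_id] = 1  # see the FIXME above about the value 1
        for topic in msg.topicIDs:
            if topic in self.my_topics:
                await self.my_topics[topic].put(msg)
        # Forwarding onwards through the router (FloodSub/GossipSub) is omitted here.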

View File

@@ -4,17 +4,17 @@ import pytest
 from tests.utils import cleanup
 
 from libp2p import new_node
+from libp2p.peer.id import ID
 from libp2p.peer.peerinfo import info_from_p2p_addr
+from libp2p.pubsub.pb import rpc_pb2
 from libp2p.pubsub.pubsub import Pubsub
 from libp2p.pubsub.floodsub import FloodSub
 
 from .utils import (
-    make_pubsub_msg,
     message_id_generator,
     generate_RPC_packet,
 )
 
 # pylint: disable=too-many-locals
 
 async def connect(node1, node2):
@@ -39,106 +39,84 @@ async def test_simple_two_nodes():
     data = b"some data"
 
     floodsub_a = FloodSub(supported_protocols)
-    pubsub_a = Pubsub(node_a, floodsub_a, "a")
+    pubsub_a = Pubsub(node_a, floodsub_a, ID(b"a" * 32))
     floodsub_b = FloodSub(supported_protocols)
-    pubsub_b = Pubsub(node_b, floodsub_b, "b")
+    pubsub_b = Pubsub(node_b, floodsub_b, ID(b"b" * 32))
 
     await connect(node_a, node_b)
 
     await asyncio.sleep(0.25)
 
     sub_b = await pubsub_b.subscribe(topic)
+    # Sleep to let a know of b's subscription
     await asyncio.sleep(0.25)
 
-    next_msg_id_func = message_id_generator(0)
-    msg = make_pubsub_msg(
-        origin_id=node_a.get_id(),
-        topic_ids=[topic],
-        data=data,
-        seqno=next_msg_id_func(),
-    )
-    await floodsub_a.publish(node_a.get_id(), msg)
-    await asyncio.sleep(0.25)
+    await pubsub_a.publish(topic, data)
 
     res_b = await sub_b.get()
 
     # Check that the msg received by node_b is the same
     # as the message sent by node_a
-    assert res_b.SerializeToString() == msg.SerializeToString()
+    assert ID(res_b.from_id) == node_a.get_id()
+    assert res_b.data == data
+    assert res_b.topicIDs == [topic]
 
     # Success, terminate pending tasks.
     await cleanup()
 
 
 @pytest.mark.asyncio
-async def test_lru_cache_two_nodes():
+async def test_lru_cache_two_nodes(monkeypatch):
     # two nodes with cache_size of 4
-    # node_a send the following messages to node_b
-    # [1, 1, 2, 1, 3, 1, 4, 1, 5, 1]
-    # node_b should only receive the following
-    # [1, 2, 3, 4, 5, 1]
-    node_a = await new_node(transport_opt=["/ip4/127.0.0.1/tcp/0"])
-    node_b = await new_node(transport_opt=["/ip4/127.0.0.1/tcp/0"])
-
-    await node_a.get_network().listen(multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/0"))
-    await node_b.get_network().listen(multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/0"))
+    # `node_a` send the following messages to node_b
+    message_indices = [1, 1, 2, 1, 3, 1, 4, 1, 5, 1]
+    # `node_b` should only receive the following
+    expected_received_indices = [1, 2, 3, 4, 5, 1]
+
+    listen_maddr = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/0")
+    node_a = await new_node()
+    node_b = await new_node()
+
+    await node_a.get_network().listen(listen_maddr)
+    await node_b.get_network().listen(listen_maddr)
 
     supported_protocols = ["/floodsub/1.0.0"]
+    topic = "my_topic"
 
-    # initialize PubSub with a cache_size of 4
+    # Mock `get_msg_id` to make us easier to manipulate `msg_id` by `data`.
+    def get_msg_id(msg):
+        # Originally it is `(msg.seqno, msg.from_id)`
+        return (msg.data, msg.from_id)
+    import libp2p.pubsub.pubsub
+    monkeypatch.setattr(libp2p.pubsub.pubsub, "get_msg_id", get_msg_id)
+
+    # Initialize Pubsub with a cache_size of 4
+    cache_size = 4
     floodsub_a = FloodSub(supported_protocols)
-    pubsub_a = Pubsub(node_a, floodsub_a, "a", 4)
+    pubsub_a = Pubsub(node_a, floodsub_a, ID(b"a" * 32), cache_size)
     floodsub_b = FloodSub(supported_protocols)
-    pubsub_b = Pubsub(node_b, floodsub_b, "b", 4)
+    pubsub_b = Pubsub(node_b, floodsub_b, ID(b"b" * 32), cache_size)
 
     await connect(node_a, node_b)
 
-    await asyncio.sleep(0.25)
-    qb = await pubsub_b.subscribe("my_topic")
-
     await asyncio.sleep(0.25)
-    node_a_id = str(node_a.get_id())
+    sub_b = await pubsub_b.subscribe(topic)
 
-    # initialize message_id_generator
-    # store first message
-    next_msg_id_func = message_id_generator(0)
-    first_message = generate_RPC_packet(node_a_id, ["my_topic"], "some data 1", next_msg_id_func())
-
-    await floodsub_a.publish(node_a_id, first_message.SerializeToString())
-    await asyncio.sleep(0.25)
-    print (first_message)
-
-    messages = [first_message]
-    # for the next 5 messages
-    for i in range(2, 6):
-        # write first message
-        await floodsub_a.publish(node_a_id, first_message.SerializeToString())
-        await asyncio.sleep(0.25)
-
-        # generate and write next message
-        msg = generate_RPC_packet(node_a_id, ["my_topic"], "some data " + str(i), next_msg_id_func())
-        messages.append(msg)
-
-        await floodsub_a.publish(node_a_id, msg.SerializeToString())
-        await asyncio.sleep(0.25)
-
-    # write first message again
-    await floodsub_a.publish(node_a_id, first_message.SerializeToString())
     await asyncio.sleep(0.25)
 
-    # check the first five messages in queue
-    # should only see 1 first_message
-    for i in range(5):
-        # Check that the msg received by node_b is the same
-        # as the message sent by node_a
-        res_b = await qb.get()
-        assert res_b.SerializeToString() == messages[i].publish[0].SerializeToString()
-
-    # the 6th message should be first_message
-    res_b = await qb.get()
-    assert res_b.SerializeToString() == first_message.publish[0].SerializeToString()
-    assert qb.empty()
+    def _make_testing_data(i: int) -> bytes:
+        num_int_bytes = 4
+        if i >= 2**(num_int_bytes * 8):
+            raise ValueError("")
+        return b"data" + i.to_bytes(num_int_bytes, "big")
+
+    for index in message_indices:
+        await pubsub_a.publish(topic, _make_testing_data(index))
+        await asyncio.sleep(0.25)
+
+    for index in expected_received_indices:
+        res_b = await sub_b.get()
+        assert res_b.data == _make_testing_data(index)
+    assert sub_b.empty()
 
     # Success, terminate pending tasks.
     await cleanup()
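
The expected sequence [1, 2, 3, 4, 5, 1] follows from the bounded seen-cache: with cache_size 4, the id of message 1 has been evicted by the time messages 2 through 5 have all been seen, so the final resend of 1 is treated as new. A small simulation of that reasoning (the cache's eviction policy is not shown in this commit; first-in-first-out eviction is assumed):

from collections import OrderedDict


def simulate(indices, cache_size=4):
    seen = OrderedDict()
    delivered = []
    for i in indices:
        if i in seen:
            continue  # id still in the cache: dropped as a duplicate
        if len(seen) == cache_size:
            seen.popitem(last=False)  # evict the oldest id
        seen[i] = 1
        delivered.append(i)
    return delivered


assert simulate([1, 1, 2, 1, 3, 1, 4, 1, 5, 1]) == [1, 2, 3, 4, 5, 1]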