Fix the tests according to pubsub.Publish
And refactored a bit.
This commit is contained in:
parent
cae4f34034
commit
dadcf8138e
|
@ -91,6 +91,7 @@ class GossipSub(IPubsubRouter):
|
|||
:param rpc: rpc message
|
||||
"""
|
||||
control_message = rpc.control
|
||||
sender_peer_id = str(sender_peer_id)
|
||||
|
||||
# Relay each rpc control to the appropriate handler
|
||||
if control_message.ihave:
|
||||
|
|
|
@ -128,26 +128,21 @@ class Pubsub:
|
|||
messages from other nodes
|
||||
:param stream: stream to continously read from
|
||||
"""
|
||||
|
||||
# TODO check on types here
|
||||
peer_id = str(stream.mplex_conn.peer_id)
|
||||
peer_id = stream.mplex_conn.peer_id
|
||||
|
||||
while True:
|
||||
incoming = (await stream.read())
|
||||
rpc_incoming = rpc_pb2.RPC()
|
||||
rpc_incoming.ParseFromString(incoming)
|
||||
|
||||
should_publish = False
|
||||
|
||||
if rpc_incoming.publish:
|
||||
# deal with RPC.publish
|
||||
for message in rpc_incoming.publish:
|
||||
id_in_seen_msgs = (message.seqno, message.from_id)
|
||||
if id_in_seen_msgs not in self.seen_messages:
|
||||
should_publish = True
|
||||
self.seen_messages[id_in_seen_msgs] = 1
|
||||
|
||||
await self.handle_talk(message)
|
||||
for msg in rpc_incoming.publish:
|
||||
if not self._is_subscribed_to_msg(msg):
|
||||
continue
|
||||
# TODO(mhchia): This will block this read_stream loop until all data are pushed.
|
||||
# Should investigate further if this is an issue.
|
||||
await self.push_msg(src=peer_id, msg=msg)
|
||||
|
||||
if rpc_incoming.subscriptions:
|
||||
# deal with RPC.subscriptions
|
||||
|
@ -158,10 +153,6 @@ class Pubsub:
|
|||
for message in rpc_incoming.subscriptions:
|
||||
self.handle_subscription(peer_id, message)
|
||||
|
||||
if should_publish:
|
||||
# relay message to peers with router
|
||||
await self.router.publish(peer_id, incoming)
|
||||
|
||||
if rpc_incoming.control:
|
||||
# Pass rpc to router so router could perform custom logic
|
||||
await self.router.handle_rpc(rpc_incoming, peer_id)
|
||||
|
@ -228,6 +219,7 @@ class Pubsub:
|
|||
:param origin_id: id of the peer who subscribe to the message
|
||||
:param sub_message: RPC.SubOpts
|
||||
"""
|
||||
origin_id = str(origin_id)
|
||||
if sub_message.subscribe:
|
||||
if sub_message.topicid not in self.peer_topics:
|
||||
self.peer_topics[sub_message.topicid] = [origin_id]
|
||||
|
@ -379,3 +371,8 @@ class Pubsub:
|
|||
# FIXME: Mapping `msg_id` to `1` is quite awkward. Should investigate if there is a
|
||||
# more appropriate way.
|
||||
self.seen_messages[msg_id] = 1
|
||||
|
||||
def _is_subscribed_to_msg(self, msg: rpc_pb2.Message) -> bool:
|
||||
if len(self.my_topics) == 0:
|
||||
return False
|
||||
return all([topic in self.my_topics for topic in msg.topicIDs])
|
||||
|
|
|
@ -4,17 +4,17 @@ import pytest
|
|||
|
||||
from tests.utils import cleanup
|
||||
from libp2p import new_node
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.peer.peerinfo import info_from_p2p_addr
|
||||
from libp2p.pubsub.pb import rpc_pb2
|
||||
from libp2p.pubsub.pubsub import Pubsub
|
||||
from libp2p.pubsub.floodsub import FloodSub
|
||||
|
||||
from .utils import (
|
||||
make_pubsub_msg,
|
||||
message_id_generator,
|
||||
generate_RPC_packet,
|
||||
)
|
||||
|
||||
|
||||
# pylint: disable=too-many-locals
|
||||
|
||||
async def connect(node1, node2):
|
||||
|
@ -39,106 +39,84 @@ async def test_simple_two_nodes():
|
|||
data = b"some data"
|
||||
|
||||
floodsub_a = FloodSub(supported_protocols)
|
||||
pubsub_a = Pubsub(node_a, floodsub_a, "a")
|
||||
pubsub_a = Pubsub(node_a, floodsub_a, ID(b"a" * 32))
|
||||
floodsub_b = FloodSub(supported_protocols)
|
||||
pubsub_b = Pubsub(node_b, floodsub_b, "b")
|
||||
pubsub_b = Pubsub(node_b, floodsub_b, ID(b"b" * 32))
|
||||
|
||||
await connect(node_a, node_b)
|
||||
|
||||
await asyncio.sleep(0.25)
|
||||
|
||||
sub_b = await pubsub_b.subscribe(topic)
|
||||
|
||||
# Sleep to let a know of b's subscription
|
||||
await asyncio.sleep(0.25)
|
||||
|
||||
next_msg_id_func = message_id_generator(0)
|
||||
msg = make_pubsub_msg(
|
||||
origin_id=node_a.get_id(),
|
||||
topic_ids=[topic],
|
||||
data=data,
|
||||
seqno=next_msg_id_func(),
|
||||
)
|
||||
await floodsub_a.publish(node_a.get_id(), msg)
|
||||
await asyncio.sleep(0.25)
|
||||
await pubsub_a.publish(topic, data)
|
||||
|
||||
res_b = await sub_b.get()
|
||||
|
||||
# Check that the msg received by node_b is the same
|
||||
# as the message sent by node_a
|
||||
assert res_b.SerializeToString() == msg.SerializeToString()
|
||||
assert ID(res_b.from_id) == node_a.get_id()
|
||||
assert res_b.data == data
|
||||
assert res_b.topicIDs == [topic]
|
||||
|
||||
# Success, terminate pending tasks.
|
||||
await cleanup()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lru_cache_two_nodes():
|
||||
async def test_lru_cache_two_nodes(monkeypatch):
|
||||
# two nodes with cache_size of 4
|
||||
# node_a send the following messages to node_b
|
||||
# [1, 1, 2, 1, 3, 1, 4, 1, 5, 1]
|
||||
# node_b should only receive the following
|
||||
# [1, 2, 3, 4, 5, 1]
|
||||
node_a = await new_node(transport_opt=["/ip4/127.0.0.1/tcp/0"])
|
||||
node_b = await new_node(transport_opt=["/ip4/127.0.0.1/tcp/0"])
|
||||
# `node_a` send the following messages to node_b
|
||||
message_indices = [1, 1, 2, 1, 3, 1, 4, 1, 5, 1]
|
||||
# `node_b` should only receive the following
|
||||
expected_received_indices = [1, 2, 3, 4, 5, 1]
|
||||
|
||||
await node_a.get_network().listen(multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/0"))
|
||||
await node_b.get_network().listen(multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/0"))
|
||||
listen_maddr = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/0")
|
||||
node_a = await new_node()
|
||||
node_b = await new_node()
|
||||
|
||||
await node_a.get_network().listen(listen_maddr)
|
||||
await node_b.get_network().listen(listen_maddr)
|
||||
|
||||
supported_protocols = ["/floodsub/1.0.0"]
|
||||
topic = "my_topic"
|
||||
|
||||
# initialize PubSub with a cache_size of 4
|
||||
# Mock `get_msg_id` to make us easier to manipulate `msg_id` by `data`.
|
||||
def get_msg_id(msg):
|
||||
# Originally it is `(msg.seqno, msg.from_id)`
|
||||
return (msg.data, msg.from_id)
|
||||
import libp2p.pubsub.pubsub
|
||||
monkeypatch.setattr(libp2p.pubsub.pubsub, "get_msg_id", get_msg_id)
|
||||
|
||||
# Initialize Pubsub with a cache_size of 4
|
||||
cache_size = 4
|
||||
floodsub_a = FloodSub(supported_protocols)
|
||||
pubsub_a = Pubsub(node_a, floodsub_a, "a", 4)
|
||||
pubsub_a = Pubsub(node_a, floodsub_a, ID(b"a" * 32), cache_size)
|
||||
|
||||
floodsub_b = FloodSub(supported_protocols)
|
||||
pubsub_b = Pubsub(node_b, floodsub_b, "b", 4)
|
||||
pubsub_b = Pubsub(node_b, floodsub_b, ID(b"b" * 32), cache_size)
|
||||
|
||||
await connect(node_a, node_b)
|
||||
|
||||
await asyncio.sleep(0.25)
|
||||
qb = await pubsub_b.subscribe("my_topic")
|
||||
|
||||
await asyncio.sleep(0.25)
|
||||
|
||||
node_a_id = str(node_a.get_id())
|
||||
|
||||
# initialize message_id_generator
|
||||
# store first message
|
||||
next_msg_id_func = message_id_generator(0)
|
||||
first_message = generate_RPC_packet(node_a_id, ["my_topic"], "some data 1", next_msg_id_func())
|
||||
|
||||
await floodsub_a.publish(node_a_id, first_message.SerializeToString())
|
||||
await asyncio.sleep(0.25)
|
||||
print (first_message)
|
||||
|
||||
messages = [first_message]
|
||||
# for the next 5 messages
|
||||
for i in range(2, 6):
|
||||
# write first message
|
||||
await floodsub_a.publish(node_a_id, first_message.SerializeToString())
|
||||
await asyncio.sleep(0.25)
|
||||
|
||||
# generate and write next message
|
||||
msg = generate_RPC_packet(node_a_id, ["my_topic"], "some data " + str(i), next_msg_id_func())
|
||||
messages.append(msg)
|
||||
|
||||
await floodsub_a.publish(node_a_id, msg.SerializeToString())
|
||||
await asyncio.sleep(0.25)
|
||||
|
||||
# write first message again
|
||||
await floodsub_a.publish(node_a_id, first_message.SerializeToString())
|
||||
sub_b = await pubsub_b.subscribe(topic)
|
||||
await asyncio.sleep(0.25)
|
||||
|
||||
# check the first five messages in queue
|
||||
# should only see 1 first_message
|
||||
for i in range(5):
|
||||
# Check that the msg received by node_b is the same
|
||||
# as the message sent by node_a
|
||||
res_b = await qb.get()
|
||||
assert res_b.SerializeToString() == messages[i].publish[0].SerializeToString()
|
||||
def _make_testing_data(i: int) -> bytes:
|
||||
num_int_bytes = 4
|
||||
if i >= 2**(num_int_bytes * 8):
|
||||
raise ValueError("")
|
||||
return b"data" + i.to_bytes(num_int_bytes, "big")
|
||||
|
||||
# the 6th message should be first_message
|
||||
res_b = await qb.get()
|
||||
assert res_b.SerializeToString() == first_message.publish[0].SerializeToString()
|
||||
assert qb.empty()
|
||||
for index in message_indices:
|
||||
await pubsub_a.publish(topic, _make_testing_data(index))
|
||||
await asyncio.sleep(0.25)
|
||||
|
||||
for index in expected_received_indices:
|
||||
res_b = await sub_b.get()
|
||||
assert res_b.data == _make_testing_data(index)
|
||||
assert sub_b.empty()
|
||||
|
||||
# Success, terminate pending tasks.
|
||||
await cleanup()
|
||||
|
|
Loading…
Reference in New Issue
Block a user