Fix the tests according to pubsub.Publish
And refactored a bit.
This commit is contained in:
parent
cae4f34034
commit
dadcf8138e
|
@ -91,6 +91,7 @@ class GossipSub(IPubsubRouter):
|
||||||
:param rpc: rpc message
|
:param rpc: rpc message
|
||||||
"""
|
"""
|
||||||
control_message = rpc.control
|
control_message = rpc.control
|
||||||
|
sender_peer_id = str(sender_peer_id)
|
||||||
|
|
||||||
# Relay each rpc control to the appropriate handler
|
# Relay each rpc control to the appropriate handler
|
||||||
if control_message.ihave:
|
if control_message.ihave:
|
||||||
|
|
|
@ -128,26 +128,21 @@ class Pubsub:
|
||||||
messages from other nodes
|
messages from other nodes
|
||||||
:param stream: stream to continously read from
|
:param stream: stream to continously read from
|
||||||
"""
|
"""
|
||||||
|
peer_id = stream.mplex_conn.peer_id
|
||||||
# TODO check on types here
|
|
||||||
peer_id = str(stream.mplex_conn.peer_id)
|
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
incoming = (await stream.read())
|
incoming = (await stream.read())
|
||||||
rpc_incoming = rpc_pb2.RPC()
|
rpc_incoming = rpc_pb2.RPC()
|
||||||
rpc_incoming.ParseFromString(incoming)
|
rpc_incoming.ParseFromString(incoming)
|
||||||
|
|
||||||
should_publish = False
|
|
||||||
|
|
||||||
if rpc_incoming.publish:
|
if rpc_incoming.publish:
|
||||||
# deal with RPC.publish
|
# deal with RPC.publish
|
||||||
for message in rpc_incoming.publish:
|
for msg in rpc_incoming.publish:
|
||||||
id_in_seen_msgs = (message.seqno, message.from_id)
|
if not self._is_subscribed_to_msg(msg):
|
||||||
if id_in_seen_msgs not in self.seen_messages:
|
continue
|
||||||
should_publish = True
|
# TODO(mhchia): This will block this read_stream loop until all data are pushed.
|
||||||
self.seen_messages[id_in_seen_msgs] = 1
|
# Should investigate further if this is an issue.
|
||||||
|
await self.push_msg(src=peer_id, msg=msg)
|
||||||
await self.handle_talk(message)
|
|
||||||
|
|
||||||
if rpc_incoming.subscriptions:
|
if rpc_incoming.subscriptions:
|
||||||
# deal with RPC.subscriptions
|
# deal with RPC.subscriptions
|
||||||
|
@ -158,10 +153,6 @@ class Pubsub:
|
||||||
for message in rpc_incoming.subscriptions:
|
for message in rpc_incoming.subscriptions:
|
||||||
self.handle_subscription(peer_id, message)
|
self.handle_subscription(peer_id, message)
|
||||||
|
|
||||||
if should_publish:
|
|
||||||
# relay message to peers with router
|
|
||||||
await self.router.publish(peer_id, incoming)
|
|
||||||
|
|
||||||
if rpc_incoming.control:
|
if rpc_incoming.control:
|
||||||
# Pass rpc to router so router could perform custom logic
|
# Pass rpc to router so router could perform custom logic
|
||||||
await self.router.handle_rpc(rpc_incoming, peer_id)
|
await self.router.handle_rpc(rpc_incoming, peer_id)
|
||||||
|
@ -228,6 +219,7 @@ class Pubsub:
|
||||||
:param origin_id: id of the peer who subscribe to the message
|
:param origin_id: id of the peer who subscribe to the message
|
||||||
:param sub_message: RPC.SubOpts
|
:param sub_message: RPC.SubOpts
|
||||||
"""
|
"""
|
||||||
|
origin_id = str(origin_id)
|
||||||
if sub_message.subscribe:
|
if sub_message.subscribe:
|
||||||
if sub_message.topicid not in self.peer_topics:
|
if sub_message.topicid not in self.peer_topics:
|
||||||
self.peer_topics[sub_message.topicid] = [origin_id]
|
self.peer_topics[sub_message.topicid] = [origin_id]
|
||||||
|
@ -379,3 +371,8 @@ class Pubsub:
|
||||||
# FIXME: Mapping `msg_id` to `1` is quite awkward. Should investigate if there is a
|
# FIXME: Mapping `msg_id` to `1` is quite awkward. Should investigate if there is a
|
||||||
# more appropriate way.
|
# more appropriate way.
|
||||||
self.seen_messages[msg_id] = 1
|
self.seen_messages[msg_id] = 1
|
||||||
|
|
||||||
|
def _is_subscribed_to_msg(self, msg: rpc_pb2.Message) -> bool:
|
||||||
|
if len(self.my_topics) == 0:
|
||||||
|
return False
|
||||||
|
return all([topic in self.my_topics for topic in msg.topicIDs])
|
||||||
|
|
|
@ -4,17 +4,17 @@ import pytest
|
||||||
|
|
||||||
from tests.utils import cleanup
|
from tests.utils import cleanup
|
||||||
from libp2p import new_node
|
from libp2p import new_node
|
||||||
|
from libp2p.peer.id import ID
|
||||||
from libp2p.peer.peerinfo import info_from_p2p_addr
|
from libp2p.peer.peerinfo import info_from_p2p_addr
|
||||||
from libp2p.pubsub.pb import rpc_pb2
|
|
||||||
from libp2p.pubsub.pubsub import Pubsub
|
from libp2p.pubsub.pubsub import Pubsub
|
||||||
from libp2p.pubsub.floodsub import FloodSub
|
from libp2p.pubsub.floodsub import FloodSub
|
||||||
|
|
||||||
from .utils import (
|
from .utils import (
|
||||||
make_pubsub_msg,
|
|
||||||
message_id_generator,
|
message_id_generator,
|
||||||
generate_RPC_packet,
|
generate_RPC_packet,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=too-many-locals
|
# pylint: disable=too-many-locals
|
||||||
|
|
||||||
async def connect(node1, node2):
|
async def connect(node1, node2):
|
||||||
|
@ -39,106 +39,84 @@ async def test_simple_two_nodes():
|
||||||
data = b"some data"
|
data = b"some data"
|
||||||
|
|
||||||
floodsub_a = FloodSub(supported_protocols)
|
floodsub_a = FloodSub(supported_protocols)
|
||||||
pubsub_a = Pubsub(node_a, floodsub_a, "a")
|
pubsub_a = Pubsub(node_a, floodsub_a, ID(b"a" * 32))
|
||||||
floodsub_b = FloodSub(supported_protocols)
|
floodsub_b = FloodSub(supported_protocols)
|
||||||
pubsub_b = Pubsub(node_b, floodsub_b, "b")
|
pubsub_b = Pubsub(node_b, floodsub_b, ID(b"b" * 32))
|
||||||
|
|
||||||
await connect(node_a, node_b)
|
await connect(node_a, node_b)
|
||||||
|
|
||||||
await asyncio.sleep(0.25)
|
await asyncio.sleep(0.25)
|
||||||
|
|
||||||
sub_b = await pubsub_b.subscribe(topic)
|
sub_b = await pubsub_b.subscribe(topic)
|
||||||
|
# Sleep to let a know of b's subscription
|
||||||
await asyncio.sleep(0.25)
|
await asyncio.sleep(0.25)
|
||||||
|
|
||||||
next_msg_id_func = message_id_generator(0)
|
await pubsub_a.publish(topic, data)
|
||||||
msg = make_pubsub_msg(
|
|
||||||
origin_id=node_a.get_id(),
|
|
||||||
topic_ids=[topic],
|
|
||||||
data=data,
|
|
||||||
seqno=next_msg_id_func(),
|
|
||||||
)
|
|
||||||
await floodsub_a.publish(node_a.get_id(), msg)
|
|
||||||
await asyncio.sleep(0.25)
|
|
||||||
|
|
||||||
res_b = await sub_b.get()
|
res_b = await sub_b.get()
|
||||||
|
|
||||||
# Check that the msg received by node_b is the same
|
# Check that the msg received by node_b is the same
|
||||||
# as the message sent by node_a
|
# as the message sent by node_a
|
||||||
assert res_b.SerializeToString() == msg.SerializeToString()
|
assert ID(res_b.from_id) == node_a.get_id()
|
||||||
|
assert res_b.data == data
|
||||||
|
assert res_b.topicIDs == [topic]
|
||||||
|
|
||||||
# Success, terminate pending tasks.
|
# Success, terminate pending tasks.
|
||||||
await cleanup()
|
await cleanup()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_lru_cache_two_nodes():
|
async def test_lru_cache_two_nodes(monkeypatch):
|
||||||
# two nodes with cache_size of 4
|
# two nodes with cache_size of 4
|
||||||
# node_a send the following messages to node_b
|
# `node_a` send the following messages to node_b
|
||||||
# [1, 1, 2, 1, 3, 1, 4, 1, 5, 1]
|
message_indices = [1, 1, 2, 1, 3, 1, 4, 1, 5, 1]
|
||||||
# node_b should only receive the following
|
# `node_b` should only receive the following
|
||||||
# [1, 2, 3, 4, 5, 1]
|
expected_received_indices = [1, 2, 3, 4, 5, 1]
|
||||||
node_a = await new_node(transport_opt=["/ip4/127.0.0.1/tcp/0"])
|
|
||||||
node_b = await new_node(transport_opt=["/ip4/127.0.0.1/tcp/0"])
|
|
||||||
|
|
||||||
await node_a.get_network().listen(multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/0"))
|
listen_maddr = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/0")
|
||||||
await node_b.get_network().listen(multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/0"))
|
node_a = await new_node()
|
||||||
|
node_b = await new_node()
|
||||||
|
|
||||||
|
await node_a.get_network().listen(listen_maddr)
|
||||||
|
await node_b.get_network().listen(listen_maddr)
|
||||||
|
|
||||||
supported_protocols = ["/floodsub/1.0.0"]
|
supported_protocols = ["/floodsub/1.0.0"]
|
||||||
|
topic = "my_topic"
|
||||||
|
|
||||||
# initialize PubSub with a cache_size of 4
|
# Mock `get_msg_id` to make us easier to manipulate `msg_id` by `data`.
|
||||||
|
def get_msg_id(msg):
|
||||||
|
# Originally it is `(msg.seqno, msg.from_id)`
|
||||||
|
return (msg.data, msg.from_id)
|
||||||
|
import libp2p.pubsub.pubsub
|
||||||
|
monkeypatch.setattr(libp2p.pubsub.pubsub, "get_msg_id", get_msg_id)
|
||||||
|
|
||||||
|
# Initialize Pubsub with a cache_size of 4
|
||||||
|
cache_size = 4
|
||||||
floodsub_a = FloodSub(supported_protocols)
|
floodsub_a = FloodSub(supported_protocols)
|
||||||
pubsub_a = Pubsub(node_a, floodsub_a, "a", 4)
|
pubsub_a = Pubsub(node_a, floodsub_a, ID(b"a" * 32), cache_size)
|
||||||
|
|
||||||
floodsub_b = FloodSub(supported_protocols)
|
floodsub_b = FloodSub(supported_protocols)
|
||||||
pubsub_b = Pubsub(node_b, floodsub_b, "b", 4)
|
pubsub_b = Pubsub(node_b, floodsub_b, ID(b"b" * 32), cache_size)
|
||||||
|
|
||||||
await connect(node_a, node_b)
|
await connect(node_a, node_b)
|
||||||
|
|
||||||
await asyncio.sleep(0.25)
|
|
||||||
qb = await pubsub_b.subscribe("my_topic")
|
|
||||||
|
|
||||||
await asyncio.sleep(0.25)
|
await asyncio.sleep(0.25)
|
||||||
|
|
||||||
node_a_id = str(node_a.get_id())
|
sub_b = await pubsub_b.subscribe(topic)
|
||||||
|
|
||||||
# initialize message_id_generator
|
|
||||||
# store first message
|
|
||||||
next_msg_id_func = message_id_generator(0)
|
|
||||||
first_message = generate_RPC_packet(node_a_id, ["my_topic"], "some data 1", next_msg_id_func())
|
|
||||||
|
|
||||||
await floodsub_a.publish(node_a_id, first_message.SerializeToString())
|
|
||||||
await asyncio.sleep(0.25)
|
|
||||||
print (first_message)
|
|
||||||
|
|
||||||
messages = [first_message]
|
|
||||||
# for the next 5 messages
|
|
||||||
for i in range(2, 6):
|
|
||||||
# write first message
|
|
||||||
await floodsub_a.publish(node_a_id, first_message.SerializeToString())
|
|
||||||
await asyncio.sleep(0.25)
|
|
||||||
|
|
||||||
# generate and write next message
|
|
||||||
msg = generate_RPC_packet(node_a_id, ["my_topic"], "some data " + str(i), next_msg_id_func())
|
|
||||||
messages.append(msg)
|
|
||||||
|
|
||||||
await floodsub_a.publish(node_a_id, msg.SerializeToString())
|
|
||||||
await asyncio.sleep(0.25)
|
|
||||||
|
|
||||||
# write first message again
|
|
||||||
await floodsub_a.publish(node_a_id, first_message.SerializeToString())
|
|
||||||
await asyncio.sleep(0.25)
|
await asyncio.sleep(0.25)
|
||||||
|
|
||||||
# check the first five messages in queue
|
def _make_testing_data(i: int) -> bytes:
|
||||||
# should only see 1 first_message
|
num_int_bytes = 4
|
||||||
for i in range(5):
|
if i >= 2**(num_int_bytes * 8):
|
||||||
# Check that the msg received by node_b is the same
|
raise ValueError("")
|
||||||
# as the message sent by node_a
|
return b"data" + i.to_bytes(num_int_bytes, "big")
|
||||||
res_b = await qb.get()
|
|
||||||
assert res_b.SerializeToString() == messages[i].publish[0].SerializeToString()
|
|
||||||
|
|
||||||
# the 6th message should be first_message
|
for index in message_indices:
|
||||||
res_b = await qb.get()
|
await pubsub_a.publish(topic, _make_testing_data(index))
|
||||||
assert res_b.SerializeToString() == first_message.publish[0].SerializeToString()
|
await asyncio.sleep(0.25)
|
||||||
assert qb.empty()
|
|
||||||
|
for index in expected_received_indices:
|
||||||
|
res_b = await sub_b.get()
|
||||||
|
assert res_b.data == _make_testing_data(index)
|
||||||
|
assert sub_b.empty()
|
||||||
|
|
||||||
# Success, terminate pending tasks.
|
# Success, terminate pending tasks.
|
||||||
await cleanup()
|
await cleanup()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user