Add SubscriptionAPI

And `TrioSubscriptionAPI`, to make subscription io-agnostic.
This commit is contained in:
mhchia 2019-12-17 18:17:28 +08:00
parent fb0519129d
commit 47d10e186f
No known key found for this signature in database
GPG Key ID: 389EFBEA1362589A
12 changed files with 158 additions and 36 deletions

View File

@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, List
from typing import TYPE_CHECKING, AsyncContextManager, AsyncIterable, List
from libp2p.peer.id import ID
from libp2p.typing import TProtocol
@ -10,6 +10,11 @@ if TYPE_CHECKING:
from .pubsub import Pubsub # noqa: F401
# TODO: Add interface for Pubsub
class IPubsub(ABC):
pass
class IPubsubRouter(ABC):
@abstractmethod
def get_protocols(self) -> List[TProtocol]:
@ -53,7 +58,6 @@ class IPubsubRouter(ABC):
:param rpc: rpc message
"""
# FIXME: Should be changed to type 'peer.ID'
@abstractmethod
async def publish(self, msg_forwarder: ID, pubsub_msg: rpc_pb2.Message) -> None:
"""
@ -80,3 +84,15 @@ class IPubsubRouter(ABC):
:param topic: topic to leave
"""
class ISubscriptionAPI(
AsyncContextManager["ISubscriptionAPI"], AsyncIterable[rpc_pb2.Message]
):
@abstractmethod
async def cancel(self) -> None:
...
@abstractmethod
async def get(self) -> rpc_pb2.Message:
...

View File

@ -8,9 +8,9 @@ from libp2p.peer.id import ID
from libp2p.typing import TProtocol
from libp2p.utils import encode_varint_prefixed
from .abc import IPubsubRouter
from .pb import rpc_pb2
from .pubsub import Pubsub
from .pubsub_router_interface import IPubsubRouter
PROTOCOL_ID = TProtocol("/floodsub/1.0.0")

View File

@ -12,11 +12,11 @@ from libp2p.pubsub import floodsub
from libp2p.typing import TProtocol
from libp2p.utils import encode_varint_prefixed
from .abc import IPubsubRouter
from .exceptions import NoPubsubAttached
from .mcache import MessageCache
from .pb import rpc_pb2
from .pubsub import Pubsub
from .pubsub_router_interface import IPubsubRouter
PROTOCOL_ID = TProtocol("/meshsub/1.0.0")

View File

@ -1,4 +1,3 @@
from abc import ABC
import logging
import math
import time
@ -30,12 +29,14 @@ from libp2p.peer.id import ID
from libp2p.typing import TProtocol
from libp2p.utils import encode_varint_prefixed, read_varint_prefixed_bytes
from .abc import IPubsub, ISubscriptionAPI
from .pb import rpc_pb2
from .pubsub_notifee import PubsubNotifee
from .subscription import TrioSubscriptionAPI
from .validators import signature_validator
if TYPE_CHECKING:
from .pubsub_router_interface import IPubsubRouter # noqa: F401
from .abc import IPubsubRouter # noqa: F401
from typing import Any # noqa: F401
@ -57,11 +58,6 @@ class TopicValidator(NamedTuple):
is_async: bool
# TODO: Add interface for Pubsub
class IPubsub(ABC):
pass
class Pubsub(IPubsub, Service):
host: IHost
@ -75,7 +71,7 @@ class Pubsub(IPubsub, Service):
# TODO: Implement `trio.abc.Channel`?
subscribed_topics_send: Dict[str, "trio.MemorySendChannel[rpc_pb2.Message]"]
subscribed_topics_receive: Dict[str, "trio.MemoryReceiveChannel[rpc_pb2.Message]"]
subscribed_topics_receive: Dict[str, "TrioSubscriptionAPI"]
peer_topics: Dict[str, List[ID]]
peers: Dict[ID, INetStream]
@ -380,10 +376,7 @@ class Pubsub(IPubsub, Service):
# for each topic
await self.subscribed_topics_send[topic].send(publish_message)
# TODO: Change to return an `AsyncIterable` to be I/O-agnostic?
async def subscribe(
self, topic_id: str
) -> "trio.MemoryReceiveChannel[rpc_pb2.Message]":
async def subscribe(self, topic_id: str) -> ISubscriptionAPI:
"""
Subscribe ourself to a topic.
@ -396,14 +389,14 @@ class Pubsub(IPubsub, Service):
if topic_id in self.topic_ids:
return self.subscribed_topics_receive[topic_id]
# Map topic_id to a blocking channel
channels: Tuple[
"trio.MemorySendChannel[rpc_pb2.Message]",
"trio.MemoryReceiveChannel[rpc_pb2.Message]",
] = trio.open_memory_channel(math.inf)
send_channel, receive_channel = channels
subscription = TrioSubscriptionAPI(receive_channel)
self.subscribed_topics_send[topic_id] = send_channel
self.subscribed_topics_receive[topic_id] = receive_channel
self.subscribed_topics_receive[topic_id] = subscription
# Create subscribe message
packet: rpc_pb2.RPC = rpc_pb2.RPC()
@ -417,8 +410,8 @@ class Pubsub(IPubsub, Service):
# Tell router we are joining this topic
await self.router.join(topic_id)
# Return the trio channel for messages on this topic
return receive_channel
# Return the subscription for messages on this topic
return subscription
async def unsubscribe(self, topic_id: str) -> None:
"""

View File

@ -0,0 +1,39 @@
from types import TracebackType
from typing import AsyncIterator, Optional, Type
import trio
from .abc import ISubscriptionAPI
from .pb import rpc_pb2
class BaseSubscriptionAPI(ISubscriptionAPI):
async def __aenter__(self) -> "BaseSubscriptionAPI":
await trio.hazmat.checkpoint()
return self
async def __aexit__(
self,
exc_type: "Optional[Type[BaseException]]",
exc_value: "Optional[BaseException]",
traceback: "Optional[TracebackType]",
) -> None:
await self.cancel()
class TrioSubscriptionAPI(BaseSubscriptionAPI):
receive_channel: "trio.MemoryReceiveChannel[rpc_pb2.Message]"
def __init__(
self, receive_channel: "trio.MemoryReceiveChannel[rpc_pb2.Message]"
) -> None:
self.receive_channel = receive_channel
async def cancel(self) -> None:
await self.receive_channel.aclose()
def __aiter__(self) -> AsyncIterator[rpc_pb2.Message]:
return self.receive_channel.__aiter__()
async def get(self) -> rpc_pb2.Message:
return await self.receive_channel.receive()

View File

@ -17,10 +17,10 @@ from libp2p.network.stream.net_stream_interface import INetStream
from libp2p.network.swarm import Swarm
from libp2p.peer.id import ID
from libp2p.peer.peerstore import PeerStore
from libp2p.pubsub.abc import IPubsubRouter
from libp2p.pubsub.floodsub import FloodSub
from libp2p.pubsub.gossipsub import GossipSub
from libp2p.pubsub.pubsub import Pubsub
from libp2p.pubsub.pubsub_router_interface import IPubsubRouter
from libp2p.security.base_transport import BaseSecureTransport
from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport
import libp2p.security.secio.transport as secio

View File

@ -61,7 +61,7 @@ class DummyAccountNode(Service):
async def handle_incoming_msgs(self) -> None:
"""Handle all incoming messages on the CRYPTO_TOPIC from peers."""
while True:
incoming = await self.subscription.receive()
incoming = await self.subscription.get()
msg_comps = incoming.data.decode("utf-8").split(",")
if msg_comps[0] == "send":

View File

@ -250,7 +250,7 @@ async def perform_test_from_obj(obj, pubsub_factory) -> None:
# Look at each node in each topic
for node_id in topic_map[topic]:
# Get message from subscription queue
msg = await queues_map[node_id][topic].receive()
msg = await queues_map[node_id][topic].get()
assert data == msg.data
# Check the message origin
assert node_map[origin_node_id].get_id().to_bytes() == msg.from_id

View File

@ -27,7 +27,7 @@ async def test_simple_two_nodes():
await pubsubs_fsub[0].publish(topic, data)
res_b = await sub_b.receive()
res_b = await sub_b.get()
# Check that the msg received by node_b is the same
# as the message sent by node_a
@ -75,12 +75,9 @@ async def test_lru_cache_two_nodes(monkeypatch):
await trio.sleep(0.25)
for index in expected_received_indices:
res_b = await sub_b.receive()
res_b = await sub_b.get()
assert res_b.data == _make_testing_data(index)
with pytest.raises(trio.WouldBlock):
sub_b.receive_nowait()
@pytest.mark.parametrize("test_case_obj", floodsub_protocol_pytest_params)
@pytest.mark.trio

View File

@ -196,7 +196,7 @@ async def test_dense():
await trio.sleep(0.5)
# Assert that all blocking queues receive the message
for queue in queues:
msg = await queue.receive()
msg = await queue.get()
assert msg.data == msg_content
@ -229,7 +229,7 @@ async def test_fanout():
await trio.sleep(0.5)
# Assert that all blocking queues receive the message
for sub in subs:
msg = await sub.receive()
msg = await sub.get()
assert msg.data == msg_content
# Subscribe message origin
@ -248,7 +248,7 @@ async def test_fanout():
await trio.sleep(0.5)
# Assert that all blocking queues receive the message
for sub in subs:
msg = await sub.receive()
msg = await sub.get()
assert msg.data == msg_content
@ -287,7 +287,7 @@ async def test_fanout_maintenance():
await trio.sleep(0.5)
# Assert that all blocking queues receive the message
for queue in queues:
msg = await queue.receive()
msg = await queue.get()
assert msg.data == msg_content
for sub in pubsubs_gsub:
@ -319,7 +319,7 @@ async def test_fanout_maintenance():
await trio.sleep(0.5)
# Assert that all blocking queues receive the message
for queue in queues:
msg = await queue.receive()
msg = await queue.get()
assert msg.data == msg_content
@ -346,5 +346,5 @@ async def test_gossip_propagation():
await trio.sleep(2)
# should be able to read message
msg = await queue_1.receive()
msg = await queue_1.get()
assert msg.data == msg_content

View File

@ -384,7 +384,7 @@ async def test_handle_talk():
len(pubsubs_fsub[0].topic_ids) == 1
and sub == pubsubs_fsub[0].subscribed_topics_receive[TESTING_TOPIC]
)
assert (await sub.receive()) == msg_0
assert (await sub.get()) == msg_0
@pytest.mark.trio
@ -486,7 +486,7 @@ async def test_push_msg(monkeypatch):
with trio.fail_after(0.1):
await event.wait()
# Test: Subscribers are notified when `push_msg` new messages.
assert (await sub.receive()) == msg_1
assert (await sub.get()) == msg_1
with mock_router_publish() as event:
# Test: add a topic validator and `push_msg` the message that

View File

@ -0,0 +1,77 @@
import math
import pytest
import trio
from libp2p.pubsub.pb import rpc_pb2
from libp2p.pubsub.subscription import TrioSubscriptionAPI
GET_TIMEOUT = 0.001
def make_trio_subscription():
send_channel, receive_channel = trio.open_memory_channel(math.inf)
return send_channel, TrioSubscriptionAPI(receive_channel)
def make_pubsub_msg():
return rpc_pb2.Message()
async def send_something(send_channel):
msg = make_pubsub_msg()
await send_channel.send(msg)
return msg
@pytest.mark.trio
async def test_trio_subscription_get():
send_channel, sub = make_trio_subscription()
data_0 = await send_something(send_channel)
data_1 = await send_something(send_channel)
assert data_0 == await sub.get()
assert data_1 == await sub.get()
# No more message
with pytest.raises(trio.TooSlowError):
with trio.fail_after(GET_TIMEOUT):
await sub.get()
@pytest.mark.trio
async def test_trio_subscription_iter():
send_channel, sub = make_trio_subscription()
received_data = []
async def iter_subscriptions(subscription):
async for data in sub:
received_data.append(data)
async with trio.open_nursery() as nursery:
nursery.start_soon(iter_subscriptions, sub)
await send_something(send_channel)
await send_something(send_channel)
await send_channel.aclose()
assert len(received_data) == 2
@pytest.mark.trio
async def test_trio_subscription_cancel():
send_channel, sub = make_trio_subscription()
await sub.cancel()
# Test: If the subscription is cancelled, `send_channel` should be broken.
with pytest.raises(trio.BrokenResourceError):
await send_something(send_channel)
# Test: No side effect when cancelled twice.
await sub.cancel()
@pytest.mark.trio
async def test_trio_subscription_async_context_manager():
send_channel, sub = make_trio_subscription()
async with sub:
# Test: `sub` is not cancelled yet, so `send_something` works fine.
await send_something(send_channel)
# Test: `sub` is cancelled, `send_something` fails
with pytest.raises(trio.BrokenResourceError):
await send_something(send_channel)