Add SubscriptionAPI
And `TrioSubscriptionAPI`, to make subscription io-agnostic.
This commit is contained in:
parent
fb0519129d
commit
47d10e186f
@ -1,5 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, List
|
||||
from typing import TYPE_CHECKING, AsyncContextManager, AsyncIterable, List
|
||||
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.typing import TProtocol
|
||||
@ -10,6 +10,11 @@ if TYPE_CHECKING:
|
||||
from .pubsub import Pubsub # noqa: F401
|
||||
|
||||
|
||||
# TODO: Add interface for Pubsub
|
||||
class IPubsub(ABC):
|
||||
pass
|
||||
|
||||
|
||||
class IPubsubRouter(ABC):
|
||||
@abstractmethod
|
||||
def get_protocols(self) -> List[TProtocol]:
|
||||
@ -53,7 +58,6 @@ class IPubsubRouter(ABC):
|
||||
:param rpc: rpc message
|
||||
"""
|
||||
|
||||
# FIXME: Should be changed to type 'peer.ID'
|
||||
@abstractmethod
|
||||
async def publish(self, msg_forwarder: ID, pubsub_msg: rpc_pb2.Message) -> None:
|
||||
"""
|
||||
@ -80,3 +84,15 @@ class IPubsubRouter(ABC):
|
||||
|
||||
:param topic: topic to leave
|
||||
"""
|
||||
|
||||
|
||||
class ISubscriptionAPI(
|
||||
AsyncContextManager["ISubscriptionAPI"], AsyncIterable[rpc_pb2.Message]
|
||||
):
|
||||
@abstractmethod
|
||||
async def cancel(self) -> None:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def get(self) -> rpc_pb2.Message:
|
||||
...
|
@ -8,9 +8,9 @@ from libp2p.peer.id import ID
|
||||
from libp2p.typing import TProtocol
|
||||
from libp2p.utils import encode_varint_prefixed
|
||||
|
||||
from .abc import IPubsubRouter
|
||||
from .pb import rpc_pb2
|
||||
from .pubsub import Pubsub
|
||||
from .pubsub_router_interface import IPubsubRouter
|
||||
|
||||
PROTOCOL_ID = TProtocol("/floodsub/1.0.0")
|
||||
|
||||
|
@ -12,11 +12,11 @@ from libp2p.pubsub import floodsub
|
||||
from libp2p.typing import TProtocol
|
||||
from libp2p.utils import encode_varint_prefixed
|
||||
|
||||
from .abc import IPubsubRouter
|
||||
from .exceptions import NoPubsubAttached
|
||||
from .mcache import MessageCache
|
||||
from .pb import rpc_pb2
|
||||
from .pubsub import Pubsub
|
||||
from .pubsub_router_interface import IPubsubRouter
|
||||
|
||||
PROTOCOL_ID = TProtocol("/meshsub/1.0.0")
|
||||
|
||||
|
@ -1,4 +1,3 @@
|
||||
from abc import ABC
|
||||
import logging
|
||||
import math
|
||||
import time
|
||||
@ -30,12 +29,14 @@ from libp2p.peer.id import ID
|
||||
from libp2p.typing import TProtocol
|
||||
from libp2p.utils import encode_varint_prefixed, read_varint_prefixed_bytes
|
||||
|
||||
from .abc import IPubsub, ISubscriptionAPI
|
||||
from .pb import rpc_pb2
|
||||
from .pubsub_notifee import PubsubNotifee
|
||||
from .subscription import TrioSubscriptionAPI
|
||||
from .validators import signature_validator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .pubsub_router_interface import IPubsubRouter # noqa: F401
|
||||
from .abc import IPubsubRouter # noqa: F401
|
||||
from typing import Any # noqa: F401
|
||||
|
||||
|
||||
@ -57,11 +58,6 @@ class TopicValidator(NamedTuple):
|
||||
is_async: bool
|
||||
|
||||
|
||||
# TODO: Add interface for Pubsub
|
||||
class IPubsub(ABC):
|
||||
pass
|
||||
|
||||
|
||||
class Pubsub(IPubsub, Service):
|
||||
|
||||
host: IHost
|
||||
@ -75,7 +71,7 @@ class Pubsub(IPubsub, Service):
|
||||
|
||||
# TODO: Implement `trio.abc.Channel`?
|
||||
subscribed_topics_send: Dict[str, "trio.MemorySendChannel[rpc_pb2.Message]"]
|
||||
subscribed_topics_receive: Dict[str, "trio.MemoryReceiveChannel[rpc_pb2.Message]"]
|
||||
subscribed_topics_receive: Dict[str, "TrioSubscriptionAPI"]
|
||||
|
||||
peer_topics: Dict[str, List[ID]]
|
||||
peers: Dict[ID, INetStream]
|
||||
@ -380,10 +376,7 @@ class Pubsub(IPubsub, Service):
|
||||
# for each topic
|
||||
await self.subscribed_topics_send[topic].send(publish_message)
|
||||
|
||||
# TODO: Change to return an `AsyncIterable` to be I/O-agnostic?
|
||||
async def subscribe(
|
||||
self, topic_id: str
|
||||
) -> "trio.MemoryReceiveChannel[rpc_pb2.Message]":
|
||||
async def subscribe(self, topic_id: str) -> ISubscriptionAPI:
|
||||
"""
|
||||
Subscribe ourself to a topic.
|
||||
|
||||
@ -396,14 +389,14 @@ class Pubsub(IPubsub, Service):
|
||||
if topic_id in self.topic_ids:
|
||||
return self.subscribed_topics_receive[topic_id]
|
||||
|
||||
# Map topic_id to a blocking channel
|
||||
channels: Tuple[
|
||||
"trio.MemorySendChannel[rpc_pb2.Message]",
|
||||
"trio.MemoryReceiveChannel[rpc_pb2.Message]",
|
||||
] = trio.open_memory_channel(math.inf)
|
||||
send_channel, receive_channel = channels
|
||||
subscription = TrioSubscriptionAPI(receive_channel)
|
||||
self.subscribed_topics_send[topic_id] = send_channel
|
||||
self.subscribed_topics_receive[topic_id] = receive_channel
|
||||
self.subscribed_topics_receive[topic_id] = subscription
|
||||
|
||||
# Create subscribe message
|
||||
packet: rpc_pb2.RPC = rpc_pb2.RPC()
|
||||
@ -417,8 +410,8 @@ class Pubsub(IPubsub, Service):
|
||||
# Tell router we are joining this topic
|
||||
await self.router.join(topic_id)
|
||||
|
||||
# Return the trio channel for messages on this topic
|
||||
return receive_channel
|
||||
# Return the subscription for messages on this topic
|
||||
return subscription
|
||||
|
||||
async def unsubscribe(self, topic_id: str) -> None:
|
||||
"""
|
||||
|
39
libp2p/pubsub/subscription.py
Normal file
39
libp2p/pubsub/subscription.py
Normal file
@ -0,0 +1,39 @@
|
||||
from types import TracebackType
|
||||
from typing import AsyncIterator, Optional, Type
|
||||
|
||||
import trio
|
||||
|
||||
from .abc import ISubscriptionAPI
|
||||
from .pb import rpc_pb2
|
||||
|
||||
|
||||
class BaseSubscriptionAPI(ISubscriptionAPI):
|
||||
async def __aenter__(self) -> "BaseSubscriptionAPI":
|
||||
await trio.hazmat.checkpoint()
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: "Optional[Type[BaseException]]",
|
||||
exc_value: "Optional[BaseException]",
|
||||
traceback: "Optional[TracebackType]",
|
||||
) -> None:
|
||||
await self.cancel()
|
||||
|
||||
|
||||
class TrioSubscriptionAPI(BaseSubscriptionAPI):
|
||||
receive_channel: "trio.MemoryReceiveChannel[rpc_pb2.Message]"
|
||||
|
||||
def __init__(
|
||||
self, receive_channel: "trio.MemoryReceiveChannel[rpc_pb2.Message]"
|
||||
) -> None:
|
||||
self.receive_channel = receive_channel
|
||||
|
||||
async def cancel(self) -> None:
|
||||
await self.receive_channel.aclose()
|
||||
|
||||
def __aiter__(self) -> AsyncIterator[rpc_pb2.Message]:
|
||||
return self.receive_channel.__aiter__()
|
||||
|
||||
async def get(self) -> rpc_pb2.Message:
|
||||
return await self.receive_channel.receive()
|
@ -17,10 +17,10 @@ from libp2p.network.stream.net_stream_interface import INetStream
|
||||
from libp2p.network.swarm import Swarm
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.peer.peerstore import PeerStore
|
||||
from libp2p.pubsub.abc import IPubsubRouter
|
||||
from libp2p.pubsub.floodsub import FloodSub
|
||||
from libp2p.pubsub.gossipsub import GossipSub
|
||||
from libp2p.pubsub.pubsub import Pubsub
|
||||
from libp2p.pubsub.pubsub_router_interface import IPubsubRouter
|
||||
from libp2p.security.base_transport import BaseSecureTransport
|
||||
from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport
|
||||
import libp2p.security.secio.transport as secio
|
||||
|
@ -61,7 +61,7 @@ class DummyAccountNode(Service):
|
||||
async def handle_incoming_msgs(self) -> None:
|
||||
"""Handle all incoming messages on the CRYPTO_TOPIC from peers."""
|
||||
while True:
|
||||
incoming = await self.subscription.receive()
|
||||
incoming = await self.subscription.get()
|
||||
msg_comps = incoming.data.decode("utf-8").split(",")
|
||||
|
||||
if msg_comps[0] == "send":
|
||||
|
@ -250,7 +250,7 @@ async def perform_test_from_obj(obj, pubsub_factory) -> None:
|
||||
# Look at each node in each topic
|
||||
for node_id in topic_map[topic]:
|
||||
# Get message from subscription queue
|
||||
msg = await queues_map[node_id][topic].receive()
|
||||
msg = await queues_map[node_id][topic].get()
|
||||
assert data == msg.data
|
||||
# Check the message origin
|
||||
assert node_map[origin_node_id].get_id().to_bytes() == msg.from_id
|
||||
|
@ -27,7 +27,7 @@ async def test_simple_two_nodes():
|
||||
|
||||
await pubsubs_fsub[0].publish(topic, data)
|
||||
|
||||
res_b = await sub_b.receive()
|
||||
res_b = await sub_b.get()
|
||||
|
||||
# Check that the msg received by node_b is the same
|
||||
# as the message sent by node_a
|
||||
@ -75,12 +75,9 @@ async def test_lru_cache_two_nodes(monkeypatch):
|
||||
await trio.sleep(0.25)
|
||||
|
||||
for index in expected_received_indices:
|
||||
res_b = await sub_b.receive()
|
||||
res_b = await sub_b.get()
|
||||
assert res_b.data == _make_testing_data(index)
|
||||
|
||||
with pytest.raises(trio.WouldBlock):
|
||||
sub_b.receive_nowait()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("test_case_obj", floodsub_protocol_pytest_params)
|
||||
@pytest.mark.trio
|
||||
|
@ -196,7 +196,7 @@ async def test_dense():
|
||||
await trio.sleep(0.5)
|
||||
# Assert that all blocking queues receive the message
|
||||
for queue in queues:
|
||||
msg = await queue.receive()
|
||||
msg = await queue.get()
|
||||
assert msg.data == msg_content
|
||||
|
||||
|
||||
@ -229,7 +229,7 @@ async def test_fanout():
|
||||
await trio.sleep(0.5)
|
||||
# Assert that all blocking queues receive the message
|
||||
for sub in subs:
|
||||
msg = await sub.receive()
|
||||
msg = await sub.get()
|
||||
assert msg.data == msg_content
|
||||
|
||||
# Subscribe message origin
|
||||
@ -248,7 +248,7 @@ async def test_fanout():
|
||||
await trio.sleep(0.5)
|
||||
# Assert that all blocking queues receive the message
|
||||
for sub in subs:
|
||||
msg = await sub.receive()
|
||||
msg = await sub.get()
|
||||
assert msg.data == msg_content
|
||||
|
||||
|
||||
@ -287,7 +287,7 @@ async def test_fanout_maintenance():
|
||||
await trio.sleep(0.5)
|
||||
# Assert that all blocking queues receive the message
|
||||
for queue in queues:
|
||||
msg = await queue.receive()
|
||||
msg = await queue.get()
|
||||
assert msg.data == msg_content
|
||||
|
||||
for sub in pubsubs_gsub:
|
||||
@ -319,7 +319,7 @@ async def test_fanout_maintenance():
|
||||
await trio.sleep(0.5)
|
||||
# Assert that all blocking queues receive the message
|
||||
for queue in queues:
|
||||
msg = await queue.receive()
|
||||
msg = await queue.get()
|
||||
assert msg.data == msg_content
|
||||
|
||||
|
||||
@ -346,5 +346,5 @@ async def test_gossip_propagation():
|
||||
await trio.sleep(2)
|
||||
|
||||
# should be able to read message
|
||||
msg = await queue_1.receive()
|
||||
msg = await queue_1.get()
|
||||
assert msg.data == msg_content
|
||||
|
@ -384,7 +384,7 @@ async def test_handle_talk():
|
||||
len(pubsubs_fsub[0].topic_ids) == 1
|
||||
and sub == pubsubs_fsub[0].subscribed_topics_receive[TESTING_TOPIC]
|
||||
)
|
||||
assert (await sub.receive()) == msg_0
|
||||
assert (await sub.get()) == msg_0
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
@ -486,7 +486,7 @@ async def test_push_msg(monkeypatch):
|
||||
with trio.fail_after(0.1):
|
||||
await event.wait()
|
||||
# Test: Subscribers are notified when `push_msg` new messages.
|
||||
assert (await sub.receive()) == msg_1
|
||||
assert (await sub.get()) == msg_1
|
||||
|
||||
with mock_router_publish() as event:
|
||||
# Test: add a topic validator and `push_msg` the message that
|
||||
|
77
tests/pubsub/test_subscription.py
Normal file
77
tests/pubsub/test_subscription.py
Normal file
@ -0,0 +1,77 @@
|
||||
import math
|
||||
|
||||
import pytest
|
||||
import trio
|
||||
|
||||
from libp2p.pubsub.pb import rpc_pb2
|
||||
from libp2p.pubsub.subscription import TrioSubscriptionAPI
|
||||
|
||||
GET_TIMEOUT = 0.001
|
||||
|
||||
|
||||
def make_trio_subscription():
|
||||
send_channel, receive_channel = trio.open_memory_channel(math.inf)
|
||||
return send_channel, TrioSubscriptionAPI(receive_channel)
|
||||
|
||||
|
||||
def make_pubsub_msg():
|
||||
return rpc_pb2.Message()
|
||||
|
||||
|
||||
async def send_something(send_channel):
|
||||
msg = make_pubsub_msg()
|
||||
await send_channel.send(msg)
|
||||
return msg
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_trio_subscription_get():
|
||||
send_channel, sub = make_trio_subscription()
|
||||
data_0 = await send_something(send_channel)
|
||||
data_1 = await send_something(send_channel)
|
||||
assert data_0 == await sub.get()
|
||||
assert data_1 == await sub.get()
|
||||
# No more message
|
||||
with pytest.raises(trio.TooSlowError):
|
||||
with trio.fail_after(GET_TIMEOUT):
|
||||
await sub.get()
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_trio_subscription_iter():
|
||||
send_channel, sub = make_trio_subscription()
|
||||
received_data = []
|
||||
|
||||
async def iter_subscriptions(subscription):
|
||||
async for data in sub:
|
||||
received_data.append(data)
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(iter_subscriptions, sub)
|
||||
await send_something(send_channel)
|
||||
await send_something(send_channel)
|
||||
await send_channel.aclose()
|
||||
|
||||
assert len(received_data) == 2
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_trio_subscription_cancel():
|
||||
send_channel, sub = make_trio_subscription()
|
||||
await sub.cancel()
|
||||
# Test: If the subscription is cancelled, `send_channel` should be broken.
|
||||
with pytest.raises(trio.BrokenResourceError):
|
||||
await send_something(send_channel)
|
||||
# Test: No side effect when cancelled twice.
|
||||
await sub.cancel()
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_trio_subscription_async_context_manager():
|
||||
send_channel, sub = make_trio_subscription()
|
||||
async with sub:
|
||||
# Test: `sub` is not cancelled yet, so `send_something` works fine.
|
||||
await send_something(send_channel)
|
||||
# Test: `sub` is cancelled, `send_something` fails
|
||||
with pytest.raises(trio.BrokenResourceError):
|
||||
await send_something(send_channel)
|
Loading…
x
Reference in New Issue
Block a user