Fix IPubsub and add IPubsub.wait_until_ready

This commit is contained in:
mhchia 2020-01-27 00:10:33 +08:00
parent e3a1dd62e4
commit 92ea35e147
No known key found for this signature in database
GPG Key ID: 389EFBEA1362589A
4 changed files with 76 additions and 32 deletions

View File

@ -1,18 +1,35 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, AsyncContextManager, AsyncIterable, List from typing import (
TYPE_CHECKING,
AsyncContextManager,
AsyncIterable,
KeysView,
List,
Tuple,
)
from async_service import ServiceAPI
from libp2p.peer.id import ID from libp2p.peer.id import ID
from libp2p.typing import TProtocol from libp2p.typing import TProtocol
from .pb import rpc_pb2 from .pb import rpc_pb2
from .typing import ValidatorFn
if TYPE_CHECKING: if TYPE_CHECKING:
from .pubsub import Pubsub # noqa: F401 from .pubsub import Pubsub # noqa: F401
# TODO: Add interface for Pubsub class ISubscriptionAPI(
class IPubsub(ABC): AsyncContextManager["ISubscriptionAPI"], AsyncIterable[rpc_pb2.Message]
pass ):
@abstractmethod
async def cancel(self) -> None:
...
@abstractmethod
async def get(self) -> rpc_pb2.Message:
...
class IPubsubRouter(ABC): class IPubsubRouter(ABC):
@ -86,13 +103,44 @@ class IPubsubRouter(ABC):
""" """
class ISubscriptionAPI( class IPubsub(ServiceAPI):
AsyncContextManager["ISubscriptionAPI"], AsyncIterable[rpc_pb2.Message] @property
):
@abstractmethod @abstractmethod
async def cancel(self) -> None: def my_id(self) -> ID:
...
@property
@abstractmethod
def protocols(self) -> Tuple[TProtocol, ...]:
...
@property
@abstractmethod
def topic_ids(self) -> KeysView[str]:
... ...
@abstractmethod @abstractmethod
async def get(self) -> rpc_pb2.Message: def set_topic_validator(
self, topic: str, validator: ValidatorFn, is_async_validator: bool
) -> None:
...
@abstractmethod
def remove_topic_validator(self, topic: str) -> None:
...
@abstractmethod
async def wait_until_ready(self) -> None:
...
@abstractmethod
async def subscribe(self, topic_id: str) -> ISubscriptionAPI:
...
@abstractmethod
async def unsubscribe(self, topic_id: str) -> None:
...
@abstractmethod
async def publish(self, topic_id: str, data: bytes) -> None:
... ...

View File

@ -1,19 +1,7 @@
import logging import logging
import math import math
import time import time
from typing import ( from typing import TYPE_CHECKING, Dict, KeysView, List, NamedTuple, Set, Tuple, cast
TYPE_CHECKING,
Awaitable,
Callable,
Dict,
KeysView,
List,
NamedTuple,
Set,
Tuple,
Union,
cast,
)
from async_service import Service from async_service import Service
import base58 import base58
@ -35,6 +23,7 @@ from .abc import IPubsub, ISubscriptionAPI
from .pb import rpc_pb2 from .pb import rpc_pb2
from .pubsub_notifee import PubsubNotifee from .pubsub_notifee import PubsubNotifee
from .subscription import TrioSubscriptionAPI from .subscription import TrioSubscriptionAPI
from .typing import AsyncValidatorFn, SyncValidatorFn, ValidatorFn
from .validators import PUBSUB_SIGNING_PREFIX, signature_validator from .validators import PUBSUB_SIGNING_PREFIX, signature_validator
if TYPE_CHECKING: if TYPE_CHECKING:
@ -50,17 +39,12 @@ def get_msg_id(msg: rpc_pb2.Message) -> Tuple[bytes, bytes]:
return (msg.seqno, msg.from_id) return (msg.seqno, msg.from_id)
SyncValidatorFn = Callable[[ID, rpc_pb2.Message], bool]
AsyncValidatorFn = Callable[[ID, rpc_pb2.Message], Awaitable[bool]]
ValidatorFn = Union[SyncValidatorFn, AsyncValidatorFn]
class TopicValidator(NamedTuple): class TopicValidator(NamedTuple):
validator: ValidatorFn validator: ValidatorFn
is_async: bool is_async: bool
class Pubsub(IPubsub, Service): class Pubsub(Service, IPubsub):
host: IHost host: IHost
@ -290,6 +274,10 @@ class Pubsub(IPubsub, Service):
await stream.reset() await stream.reset()
self._handle_dead_peer(peer_id) self._handle_dead_peer(peer_id)
async def wait_until_ready(self) -> None:
await self.event_handle_peer_queue_started.wait()
await self.event_handle_dead_peer_queue_started.wait()
async def _handle_new_peer(self, peer_id: ID) -> None: async def _handle_new_peer(self, peer_id: ID) -> None:
try: try:
stream: INetStream = await self.host.new_stream(peer_id, self.protocols) stream: INetStream = await self.host.new_stream(peer_id, self.protocols)
@ -332,18 +320,18 @@ class Pubsub(IPubsub, Service):
"""Continuously read from peer queue and each time a new peer is found, """Continuously read from peer queue and each time a new peer is found,
open a stream to the peer using a supported pubsub protocol pubsub open a stream to the peer using a supported pubsub protocol pubsub
protocols we support.""" protocols we support."""
self.event_handle_peer_queue_started.set()
async with self.peer_receive_channel: async with self.peer_receive_channel:
self.event_handle_peer_queue_started.set()
async for peer_id in self.peer_receive_channel: async for peer_id in self.peer_receive_channel:
# Add Peer # Add Peer
self.manager.run_task(self._handle_new_peer, peer_id) self.manager.run_task(self._handle_new_peer, peer_id)
async def handle_dead_peer_queue(self) -> None: async def handle_dead_peer_queue(self) -> None:
self.event_handle_dead_peer_queue_started.set()
"""Continuously read from dead peer channel and close the stream """Continuously read from dead peer channel and close the stream
between that peer and remove peer info from pubsub and pubsub between that peer and remove peer info from pubsub and pubsub
router.""" router."""
async with self.dead_peer_receive_channel: async with self.dead_peer_receive_channel:
self.event_handle_dead_peer_queue_started.set()
async for peer_id in self.dead_peer_receive_channel: async for peer_id in self.dead_peer_receive_channel:
# Remove Peer # Remove Peer
self._handle_dead_peer(peer_id) self._handle_dead_peer(peer_id)

9
libp2p/pubsub/typing.py Normal file
View File

@ -0,0 +1,9 @@
from typing import Awaitable, Callable, Union
from libp2p.peer.id import ID
from .pb import rpc_pb2
SyncValidatorFn = Callable[[ID, rpc_pb2.Message], bool]
AsyncValidatorFn = Callable[[ID, rpc_pb2.Message], Awaitable[bool]]
ValidatorFn = Union[SyncValidatorFn, AsyncValidatorFn]

View File

@ -245,8 +245,7 @@ class PubsubFactory(factory.Factory):
strict_signing=strict_signing, strict_signing=strict_signing,
) )
async with background_trio_service(pubsub): async with background_trio_service(pubsub):
await pubsub.event_handle_peer_queue_started.wait() await pubsub.wait_until_ready()
await pubsub.event_handle_dead_peer_queue_started.wait()
yield pubsub yield pubsub
@classmethod @classmethod