Allow Pubsub
creator to supply a custom msg_id
This commit is contained in:
parent
9d68de8c21
commit
ef666267bd
|
@ -1,7 +1,17 @@
|
|||
import functools
|
||||
import logging
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Dict, KeysView, List, NamedTuple, Set, Tuple, cast
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Callable,
|
||||
Dict,
|
||||
KeysView,
|
||||
List,
|
||||
NamedTuple,
|
||||
Set,
|
||||
Tuple,
|
||||
cast,
|
||||
)
|
||||
|
||||
from async_service import Service
|
||||
import base58
|
||||
|
@ -37,9 +47,9 @@ SUBSCRIPTION_CHANNEL_SIZE = 32
|
|||
logger = logging.getLogger("libp2p.pubsub")
|
||||
|
||||
|
||||
def get_msg_id(msg: rpc_pb2.Message) -> Tuple[bytes, bytes]:
|
||||
def get_peer_and_seqno_msg_id(msg: rpc_pb2.Message) -> bytes:
|
||||
# NOTE: `string(from, seqno)` in Go
|
||||
return (msg.seqno, msg.from_id)
|
||||
return msg.seqno + msg.from_id
|
||||
|
||||
|
||||
class TopicValidator(NamedTuple):
|
||||
|
@ -81,6 +91,9 @@ class Pubsub(Service, IPubsub):
|
|||
router: "IPubsubRouter",
|
||||
cache_size: int = None,
|
||||
strict_signing: bool = True,
|
||||
msg_id_constructor: Callable[
|
||||
[rpc_pb2.Message], bytes
|
||||
] = get_peer_and_seqno_msg_id,
|
||||
) -> None:
|
||||
"""
|
||||
Construct a new Pubsub object, which is responsible for handling all
|
||||
|
@ -95,6 +108,8 @@ class Pubsub(Service, IPubsub):
|
|||
self.host = host
|
||||
self.router = router
|
||||
|
||||
self._msg_id_constructor = msg_id_constructor
|
||||
|
||||
# Attach this new Pubsub object to the router
|
||||
self.router.attach(self)
|
||||
|
||||
|
@ -586,11 +601,11 @@ class Pubsub(Service, IPubsub):
|
|||
return self.counter.to_bytes(8, "big")
|
||||
|
||||
def _is_msg_seen(self, msg: rpc_pb2.Message) -> bool:
|
||||
msg_id = get_msg_id(msg)
|
||||
msg_id = self._msg_id_constructor(msg)
|
||||
return msg_id in self.seen_messages
|
||||
|
||||
def _mark_msg_seen(self, msg: rpc_pb2.Message) -> None:
|
||||
msg_id = get_msg_id(msg)
|
||||
msg_id = self._msg_id_constructor(msg)
|
||||
# FIXME: Mapping `msg_id` to `1` is quite awkward. Should investigate if there is a
|
||||
# more appropriate way.
|
||||
self.seen_messages[msg_id] = 1
|
||||
|
|
|
@ -26,7 +26,8 @@ from libp2p.peer.peerstore import PeerStore
|
|||
from libp2p.pubsub.abc import IPubsubRouter
|
||||
from libp2p.pubsub.floodsub import FloodSub
|
||||
from libp2p.pubsub.gossipsub import GossipSub
|
||||
from libp2p.pubsub.pubsub import Pubsub
|
||||
import libp2p.pubsub.pb.rpc_pb2 as rpc_pb2
|
||||
from libp2p.pubsub.pubsub import Pubsub, get_peer_and_seqno_msg_id
|
||||
from libp2p.routing.interfaces import IPeerRouting
|
||||
from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport
|
||||
from libp2p.security.noise.messages import (
|
||||
|
@ -370,13 +371,19 @@ class PubsubFactory(factory.Factory):
|
|||
@classmethod
|
||||
@asynccontextmanager
|
||||
async def create_and_start(
|
||||
cls, host: IHost, router: IPubsubRouter, cache_size: int, strict_signing: bool
|
||||
cls,
|
||||
host: IHost,
|
||||
router: IPubsubRouter,
|
||||
cache_size: int,
|
||||
strict_signing: bool,
|
||||
msg_id_constructor: Callable[[rpc_pb2.Message], bytes] = None,
|
||||
) -> AsyncIterator[Pubsub]:
|
||||
pubsub = cls(
|
||||
host=host,
|
||||
router=router,
|
||||
cache_size=cache_size,
|
||||
strict_signing=strict_signing,
|
||||
msg_id_constructor=msg_id_constructor,
|
||||
)
|
||||
async with background_trio_service(pubsub):
|
||||
await pubsub.wait_until_ready()
|
||||
|
@ -392,6 +399,7 @@ class PubsubFactory(factory.Factory):
|
|||
strict_signing: bool = False,
|
||||
security_protocol: TProtocol = None,
|
||||
muxer_opt: TMuxerOptions = None,
|
||||
msg_id_constructor: Callable[[rpc_pb2.Message], bytes] = None,
|
||||
) -> AsyncIterator[Tuple[Pubsub, ...]]:
|
||||
async with HostFactory.create_batch_and_listen(
|
||||
number, security_protocol=security_protocol, muxer_opt=muxer_opt
|
||||
|
@ -400,7 +408,9 @@ class PubsubFactory(factory.Factory):
|
|||
async with AsyncExitStack() as stack:
|
||||
pubsubs = [
|
||||
await stack.enter_async_context(
|
||||
cls.create_and_start(host, router, cache_size, strict_signing)
|
||||
cls.create_and_start(
|
||||
host, router, cache_size, strict_signing, msg_id_constructor
|
||||
)
|
||||
)
|
||||
for host, router in zip(hosts, routers)
|
||||
]
|
||||
|
@ -416,6 +426,9 @@ class PubsubFactory(factory.Factory):
|
|||
protocols: Sequence[TProtocol] = None,
|
||||
security_protocol: TProtocol = None,
|
||||
muxer_opt: TMuxerOptions = None,
|
||||
msg_id_constructor: Callable[
|
||||
[rpc_pb2.Message], bytes
|
||||
] = get_peer_and_seqno_msg_id,
|
||||
) -> AsyncIterator[Tuple[Pubsub, ...]]:
|
||||
if protocols is not None:
|
||||
floodsubs = FloodsubFactory.create_batch(number, protocols=list(protocols))
|
||||
|
@ -428,6 +441,7 @@ class PubsubFactory(factory.Factory):
|
|||
strict_signing,
|
||||
security_protocol=security_protocol,
|
||||
muxer_opt=muxer_opt,
|
||||
msg_id_constructor=msg_id_constructor,
|
||||
) as pubsubs:
|
||||
yield pubsubs
|
||||
|
||||
|
@ -450,6 +464,9 @@ class PubsubFactory(factory.Factory):
|
|||
heartbeat_initial_delay: float = GOSSIPSUB_PARAMS.heartbeat_initial_delay,
|
||||
security_protocol: TProtocol = None,
|
||||
muxer_opt: TMuxerOptions = None,
|
||||
msg_id_constructor: Callable[
|
||||
[rpc_pb2.Message], bytes
|
||||
] = get_peer_and_seqno_msg_id,
|
||||
) -> AsyncIterator[Tuple[Pubsub, ...]]:
|
||||
if protocols is not None:
|
||||
gossipsubs = GossipsubFactory.create_batch(
|
||||
|
@ -480,6 +497,7 @@ class PubsubFactory(factory.Factory):
|
|||
strict_signing,
|
||||
security_protocol=security_protocol,
|
||||
muxer_opt=muxer_opt,
|
||||
msg_id_constructor=msg_id_constructor,
|
||||
) as pubsubs:
|
||||
async with AsyncExitStack() as stack:
|
||||
for router in gossipsubs:
|
||||
|
|
|
@ -37,10 +37,16 @@ async def test_simple_two_nodes():
|
|||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_lru_cache_two_nodes(monkeypatch):
|
||||
async def test_lru_cache_two_nodes():
|
||||
# two nodes with cache_size of 4
|
||||
|
||||
# Mock `get_msg_id` to make us easier to manipulate `msg_id` by `data`.
|
||||
def get_msg_id(msg):
|
||||
# Originally it is `(msg.seqno, msg.from_id)`
|
||||
return (msg.data, msg.from_id)
|
||||
|
||||
async with PubsubFactory.create_batch_with_floodsub(
|
||||
2, cache_size=4
|
||||
2, cache_size=4, msg_id_constructor=get_msg_id
|
||||
) as pubsubs_fsub:
|
||||
# `node_a` send the following messages to node_b
|
||||
message_indices = [1, 1, 2, 1, 3, 1, 4, 1, 5, 1]
|
||||
|
@ -49,15 +55,6 @@ async def test_lru_cache_two_nodes(monkeypatch):
|
|||
|
||||
topic = "my_topic"
|
||||
|
||||
# Mock `get_msg_id` to make us easier to manipulate `msg_id` by `data`.
|
||||
def get_msg_id(msg):
|
||||
# Originally it is `(msg.seqno, msg.from_id)`
|
||||
return (msg.data, msg.from_id)
|
||||
|
||||
import libp2p.pubsub.pubsub
|
||||
|
||||
monkeypatch.setattr(libp2p.pubsub.pubsub, "get_msg_id", get_msg_id)
|
||||
|
||||
await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host)
|
||||
await trio.sleep(0.25)
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user