Allow Pubsub
creator to supply a custom msg_id
This commit is contained in:
parent
9d68de8c21
commit
ef666267bd
|
@ -1,7 +1,17 @@
|
||||||
import functools
|
import functools
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from typing import TYPE_CHECKING, Dict, KeysView, List, NamedTuple, Set, Tuple, cast
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Callable,
|
||||||
|
Dict,
|
||||||
|
KeysView,
|
||||||
|
List,
|
||||||
|
NamedTuple,
|
||||||
|
Set,
|
||||||
|
Tuple,
|
||||||
|
cast,
|
||||||
|
)
|
||||||
|
|
||||||
from async_service import Service
|
from async_service import Service
|
||||||
import base58
|
import base58
|
||||||
|
@ -37,9 +47,9 @@ SUBSCRIPTION_CHANNEL_SIZE = 32
|
||||||
logger = logging.getLogger("libp2p.pubsub")
|
logger = logging.getLogger("libp2p.pubsub")
|
||||||
|
|
||||||
|
|
||||||
def get_msg_id(msg: rpc_pb2.Message) -> Tuple[bytes, bytes]:
|
def get_peer_and_seqno_msg_id(msg: rpc_pb2.Message) -> bytes:
|
||||||
# NOTE: `string(from, seqno)` in Go
|
# NOTE: `string(from, seqno)` in Go
|
||||||
return (msg.seqno, msg.from_id)
|
return msg.seqno + msg.from_id
|
||||||
|
|
||||||
|
|
||||||
class TopicValidator(NamedTuple):
|
class TopicValidator(NamedTuple):
|
||||||
|
@ -81,6 +91,9 @@ class Pubsub(Service, IPubsub):
|
||||||
router: "IPubsubRouter",
|
router: "IPubsubRouter",
|
||||||
cache_size: int = None,
|
cache_size: int = None,
|
||||||
strict_signing: bool = True,
|
strict_signing: bool = True,
|
||||||
|
msg_id_constructor: Callable[
|
||||||
|
[rpc_pb2.Message], bytes
|
||||||
|
] = get_peer_and_seqno_msg_id,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Construct a new Pubsub object, which is responsible for handling all
|
Construct a new Pubsub object, which is responsible for handling all
|
||||||
|
@ -95,6 +108,8 @@ class Pubsub(Service, IPubsub):
|
||||||
self.host = host
|
self.host = host
|
||||||
self.router = router
|
self.router = router
|
||||||
|
|
||||||
|
self._msg_id_constructor = msg_id_constructor
|
||||||
|
|
||||||
# Attach this new Pubsub object to the router
|
# Attach this new Pubsub object to the router
|
||||||
self.router.attach(self)
|
self.router.attach(self)
|
||||||
|
|
||||||
|
@ -586,11 +601,11 @@ class Pubsub(Service, IPubsub):
|
||||||
return self.counter.to_bytes(8, "big")
|
return self.counter.to_bytes(8, "big")
|
||||||
|
|
||||||
def _is_msg_seen(self, msg: rpc_pb2.Message) -> bool:
|
def _is_msg_seen(self, msg: rpc_pb2.Message) -> bool:
|
||||||
msg_id = get_msg_id(msg)
|
msg_id = self._msg_id_constructor(msg)
|
||||||
return msg_id in self.seen_messages
|
return msg_id in self.seen_messages
|
||||||
|
|
||||||
def _mark_msg_seen(self, msg: rpc_pb2.Message) -> None:
|
def _mark_msg_seen(self, msg: rpc_pb2.Message) -> None:
|
||||||
msg_id = get_msg_id(msg)
|
msg_id = self._msg_id_constructor(msg)
|
||||||
# FIXME: Mapping `msg_id` to `1` is quite awkward. Should investigate if there is a
|
# FIXME: Mapping `msg_id` to `1` is quite awkward. Should investigate if there is a
|
||||||
# more appropriate way.
|
# more appropriate way.
|
||||||
self.seen_messages[msg_id] = 1
|
self.seen_messages[msg_id] = 1
|
||||||
|
|
|
@ -26,7 +26,8 @@ from libp2p.peer.peerstore import PeerStore
|
||||||
from libp2p.pubsub.abc import IPubsubRouter
|
from libp2p.pubsub.abc import IPubsubRouter
|
||||||
from libp2p.pubsub.floodsub import FloodSub
|
from libp2p.pubsub.floodsub import FloodSub
|
||||||
from libp2p.pubsub.gossipsub import GossipSub
|
from libp2p.pubsub.gossipsub import GossipSub
|
||||||
from libp2p.pubsub.pubsub import Pubsub
|
import libp2p.pubsub.pb.rpc_pb2 as rpc_pb2
|
||||||
|
from libp2p.pubsub.pubsub import Pubsub, get_peer_and_seqno_msg_id
|
||||||
from libp2p.routing.interfaces import IPeerRouting
|
from libp2p.routing.interfaces import IPeerRouting
|
||||||
from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport
|
from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport
|
||||||
from libp2p.security.noise.messages import (
|
from libp2p.security.noise.messages import (
|
||||||
|
@ -370,13 +371,19 @@ class PubsubFactory(factory.Factory):
|
||||||
@classmethod
|
@classmethod
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def create_and_start(
|
async def create_and_start(
|
||||||
cls, host: IHost, router: IPubsubRouter, cache_size: int, strict_signing: bool
|
cls,
|
||||||
|
host: IHost,
|
||||||
|
router: IPubsubRouter,
|
||||||
|
cache_size: int,
|
||||||
|
strict_signing: bool,
|
||||||
|
msg_id_constructor: Callable[[rpc_pb2.Message], bytes] = None,
|
||||||
) -> AsyncIterator[Pubsub]:
|
) -> AsyncIterator[Pubsub]:
|
||||||
pubsub = cls(
|
pubsub = cls(
|
||||||
host=host,
|
host=host,
|
||||||
router=router,
|
router=router,
|
||||||
cache_size=cache_size,
|
cache_size=cache_size,
|
||||||
strict_signing=strict_signing,
|
strict_signing=strict_signing,
|
||||||
|
msg_id_constructor=msg_id_constructor,
|
||||||
)
|
)
|
||||||
async with background_trio_service(pubsub):
|
async with background_trio_service(pubsub):
|
||||||
await pubsub.wait_until_ready()
|
await pubsub.wait_until_ready()
|
||||||
|
@ -392,6 +399,7 @@ class PubsubFactory(factory.Factory):
|
||||||
strict_signing: bool = False,
|
strict_signing: bool = False,
|
||||||
security_protocol: TProtocol = None,
|
security_protocol: TProtocol = None,
|
||||||
muxer_opt: TMuxerOptions = None,
|
muxer_opt: TMuxerOptions = None,
|
||||||
|
msg_id_constructor: Callable[[rpc_pb2.Message], bytes] = None,
|
||||||
) -> AsyncIterator[Tuple[Pubsub, ...]]:
|
) -> AsyncIterator[Tuple[Pubsub, ...]]:
|
||||||
async with HostFactory.create_batch_and_listen(
|
async with HostFactory.create_batch_and_listen(
|
||||||
number, security_protocol=security_protocol, muxer_opt=muxer_opt
|
number, security_protocol=security_protocol, muxer_opt=muxer_opt
|
||||||
|
@ -400,7 +408,9 @@ class PubsubFactory(factory.Factory):
|
||||||
async with AsyncExitStack() as stack:
|
async with AsyncExitStack() as stack:
|
||||||
pubsubs = [
|
pubsubs = [
|
||||||
await stack.enter_async_context(
|
await stack.enter_async_context(
|
||||||
cls.create_and_start(host, router, cache_size, strict_signing)
|
cls.create_and_start(
|
||||||
|
host, router, cache_size, strict_signing, msg_id_constructor
|
||||||
|
)
|
||||||
)
|
)
|
||||||
for host, router in zip(hosts, routers)
|
for host, router in zip(hosts, routers)
|
||||||
]
|
]
|
||||||
|
@ -416,6 +426,9 @@ class PubsubFactory(factory.Factory):
|
||||||
protocols: Sequence[TProtocol] = None,
|
protocols: Sequence[TProtocol] = None,
|
||||||
security_protocol: TProtocol = None,
|
security_protocol: TProtocol = None,
|
||||||
muxer_opt: TMuxerOptions = None,
|
muxer_opt: TMuxerOptions = None,
|
||||||
|
msg_id_constructor: Callable[
|
||||||
|
[rpc_pb2.Message], bytes
|
||||||
|
] = get_peer_and_seqno_msg_id,
|
||||||
) -> AsyncIterator[Tuple[Pubsub, ...]]:
|
) -> AsyncIterator[Tuple[Pubsub, ...]]:
|
||||||
if protocols is not None:
|
if protocols is not None:
|
||||||
floodsubs = FloodsubFactory.create_batch(number, protocols=list(protocols))
|
floodsubs = FloodsubFactory.create_batch(number, protocols=list(protocols))
|
||||||
|
@ -428,6 +441,7 @@ class PubsubFactory(factory.Factory):
|
||||||
strict_signing,
|
strict_signing,
|
||||||
security_protocol=security_protocol,
|
security_protocol=security_protocol,
|
||||||
muxer_opt=muxer_opt,
|
muxer_opt=muxer_opt,
|
||||||
|
msg_id_constructor=msg_id_constructor,
|
||||||
) as pubsubs:
|
) as pubsubs:
|
||||||
yield pubsubs
|
yield pubsubs
|
||||||
|
|
||||||
|
@ -450,6 +464,9 @@ class PubsubFactory(factory.Factory):
|
||||||
heartbeat_initial_delay: float = GOSSIPSUB_PARAMS.heartbeat_initial_delay,
|
heartbeat_initial_delay: float = GOSSIPSUB_PARAMS.heartbeat_initial_delay,
|
||||||
security_protocol: TProtocol = None,
|
security_protocol: TProtocol = None,
|
||||||
muxer_opt: TMuxerOptions = None,
|
muxer_opt: TMuxerOptions = None,
|
||||||
|
msg_id_constructor: Callable[
|
||||||
|
[rpc_pb2.Message], bytes
|
||||||
|
] = get_peer_and_seqno_msg_id,
|
||||||
) -> AsyncIterator[Tuple[Pubsub, ...]]:
|
) -> AsyncIterator[Tuple[Pubsub, ...]]:
|
||||||
if protocols is not None:
|
if protocols is not None:
|
||||||
gossipsubs = GossipsubFactory.create_batch(
|
gossipsubs = GossipsubFactory.create_batch(
|
||||||
|
@ -480,6 +497,7 @@ class PubsubFactory(factory.Factory):
|
||||||
strict_signing,
|
strict_signing,
|
||||||
security_protocol=security_protocol,
|
security_protocol=security_protocol,
|
||||||
muxer_opt=muxer_opt,
|
muxer_opt=muxer_opt,
|
||||||
|
msg_id_constructor=msg_id_constructor,
|
||||||
) as pubsubs:
|
) as pubsubs:
|
||||||
async with AsyncExitStack() as stack:
|
async with AsyncExitStack() as stack:
|
||||||
for router in gossipsubs:
|
for router in gossipsubs:
|
||||||
|
|
|
@ -37,10 +37,16 @@ async def test_simple_two_nodes():
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.trio
|
@pytest.mark.trio
|
||||||
async def test_lru_cache_two_nodes(monkeypatch):
|
async def test_lru_cache_two_nodes():
|
||||||
# two nodes with cache_size of 4
|
# two nodes with cache_size of 4
|
||||||
|
|
||||||
|
# Mock `get_msg_id` to make us easier to manipulate `msg_id` by `data`.
|
||||||
|
def get_msg_id(msg):
|
||||||
|
# Originally it is `(msg.seqno, msg.from_id)`
|
||||||
|
return (msg.data, msg.from_id)
|
||||||
|
|
||||||
async with PubsubFactory.create_batch_with_floodsub(
|
async with PubsubFactory.create_batch_with_floodsub(
|
||||||
2, cache_size=4
|
2, cache_size=4, msg_id_constructor=get_msg_id
|
||||||
) as pubsubs_fsub:
|
) as pubsubs_fsub:
|
||||||
# `node_a` send the following messages to node_b
|
# `node_a` send the following messages to node_b
|
||||||
message_indices = [1, 1, 2, 1, 3, 1, 4, 1, 5, 1]
|
message_indices = [1, 1, 2, 1, 3, 1, 4, 1, 5, 1]
|
||||||
|
@ -49,15 +55,6 @@ async def test_lru_cache_two_nodes(monkeypatch):
|
||||||
|
|
||||||
topic = "my_topic"
|
topic = "my_topic"
|
||||||
|
|
||||||
# Mock `get_msg_id` to make us easier to manipulate `msg_id` by `data`.
|
|
||||||
def get_msg_id(msg):
|
|
||||||
# Originally it is `(msg.seqno, msg.from_id)`
|
|
||||||
return (msg.data, msg.from_id)
|
|
||||||
|
|
||||||
import libp2p.pubsub.pubsub
|
|
||||||
|
|
||||||
monkeypatch.setattr(libp2p.pubsub.pubsub, "get_msg_id", get_msg_id)
|
|
||||||
|
|
||||||
await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host)
|
await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host)
|
||||||
await trio.sleep(0.25)
|
await trio.sleep(0.25)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user