Allow `Pubsub` creator to supply a custom msg_id

Branch: pull/410/head
Author: Alex Stokes
Date: 2020-02-27 11:57:00 -08:00
Parent: 9d68de8c21
Commit: ef666267bd
3 changed files with 49 additions and 19 deletions

--- a/libp2p/pubsub/pubsub.py
+++ b/libp2p/pubsub/pubsub.py
@@ -1,7 +1,17 @@
 import functools
 import logging
 import time
-from typing import TYPE_CHECKING, Dict, KeysView, List, NamedTuple, Set, Tuple, cast
+from typing import (
+    TYPE_CHECKING,
+    Callable,
+    Dict,
+    KeysView,
+    List,
+    NamedTuple,
+    Set,
+    Tuple,
+    cast,
+)

 from async_service import Service
 import base58
@@ -37,9 +47,9 @@ SUBSCRIPTION_CHANNEL_SIZE = 32
 logger = logging.getLogger("libp2p.pubsub")


-def get_msg_id(msg: rpc_pb2.Message) -> Tuple[bytes, bytes]:
+def get_peer_and_seqno_msg_id(msg: rpc_pb2.Message) -> bytes:
     # NOTE: `string(from, seqno)` in Go
-    return (msg.seqno, msg.from_id)
+    return msg.seqno + msg.from_id


 class TopicValidator(NamedTuple):
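For reference, the renamed default concatenates the raw bytes of `seqno` and `from_id`, mirroring go-libp2p's `string(from, seqno)` id. A quick sanity check, with made-up peer id and seqno values:

import libp2p.pubsub.pb.rpc_pb2 as rpc_pb2
from libp2p.pubsub.pubsub import get_peer_and_seqno_msg_id

# Hypothetical values; real messages carry a peer id and an 8-byte counter.
msg = rpc_pb2.Message(from_id=b"peer-a", seqno=(1).to_bytes(8, "big"))
assert get_peer_and_seqno_msg_id(msg) == (1).to_bytes(8, "big") + b"peer-a"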
@@ -81,6 +91,9 @@ class Pubsub(Service, IPubsub):
         router: "IPubsubRouter",
         cache_size: int = None,
         strict_signing: bool = True,
+        msg_id_constructor: Callable[
+            [rpc_pb2.Message], bytes
+        ] = get_peer_and_seqno_msg_id,
     ) -> None:
         """
         Construct a new Pubsub object, which is responsible for handling all
@@ -95,6 +108,8 @@ class Pubsub(Service, IPubsub):
         self.host = host
         self.router = router

+        self._msg_id_constructor = msg_id_constructor
+
         # Attach this new Pubsub object to the router
         self.router.attach(self)

@@ -586,11 +601,11 @@ class Pubsub(Service, IPubsub):
         return self.counter.to_bytes(8, "big")

     def _is_msg_seen(self, msg: rpc_pb2.Message) -> bool:
-        msg_id = get_msg_id(msg)
+        msg_id = self._msg_id_constructor(msg)
         return msg_id in self.seen_messages

     def _mark_msg_seen(self, msg: rpc_pb2.Message) -> None:
-        msg_id = get_msg_id(msg)
+        msg_id = self._msg_id_constructor(msg)
         # FIXME: Mapping `msg_id` to `1` is quite awkward. Should investigate if there is a
         # more appropriate way.
         self.seen_messages[msg_id] = 1
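With the id function threaded through `_is_msg_seen` and `_mark_msg_seen`, a caller can swap in any `Callable[[rpc_pb2.Message], bytes]`. A minimal sketch of a content-addressed scheme; the function name and the choice of sha256 are illustrative, not part of this change:

from hashlib import sha256

import libp2p.pubsub.pb.rpc_pb2 as rpc_pb2

def content_addressed_msg_id(msg: rpc_pb2.Message) -> bytes:
    # Identical payloads collapse to one id, regardless of sender or seqno.
    return sha256(msg.data).digest()

It would be passed at construction time, e.g. `Pubsub(host, router, msg_id_constructor=content_addressed_msg_id)`.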

--- a/libp2p/tools/factories.py
+++ b/libp2p/tools/factories.py
@@ -26,7 +26,8 @@ from libp2p.peer.peerstore import PeerStore
 from libp2p.pubsub.abc import IPubsubRouter
 from libp2p.pubsub.floodsub import FloodSub
 from libp2p.pubsub.gossipsub import GossipSub
-from libp2p.pubsub.pubsub import Pubsub
+import libp2p.pubsub.pb.rpc_pb2 as rpc_pb2
+from libp2p.pubsub.pubsub import Pubsub, get_peer_and_seqno_msg_id
 from libp2p.routing.interfaces import IPeerRouting
 from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport
 from libp2p.security.noise.messages import (
@@ -370,13 +371,19 @@ class PubsubFactory(factory.Factory):
     @classmethod
     @asynccontextmanager
     async def create_and_start(
-        cls, host: IHost, router: IPubsubRouter, cache_size: int, strict_signing: bool
+        cls,
+        host: IHost,
+        router: IPubsubRouter,
+        cache_size: int,
+        strict_signing: bool,
+        msg_id_constructor: Callable[[rpc_pb2.Message], bytes] = None,
     ) -> AsyncIterator[Pubsub]:
         pubsub = cls(
             host=host,
             router=router,
             cache_size=cache_size,
             strict_signing=strict_signing,
+            msg_id_constructor=msg_id_constructor,
         )
         async with background_trio_service(pubsub):
             await pubsub.wait_until_ready()
@@ -392,6 +399,7 @@ class PubsubFactory(factory.Factory):
         strict_signing: bool = False,
         security_protocol: TProtocol = None,
         muxer_opt: TMuxerOptions = None,
+        msg_id_constructor: Callable[[rpc_pb2.Message], bytes] = None,
     ) -> AsyncIterator[Tuple[Pubsub, ...]]:
         async with HostFactory.create_batch_and_listen(
             number, security_protocol=security_protocol, muxer_opt=muxer_opt
@@ -400,7 +408,9 @@ class PubsubFactory(factory.Factory):
             async with AsyncExitStack() as stack:
                 pubsubs = [
                     await stack.enter_async_context(
-                        cls.create_and_start(host, router, cache_size, strict_signing)
+                        cls.create_and_start(
+                            host, router, cache_size, strict_signing, msg_id_constructor
+                        )
                     )
                     for host, router in zip(hosts, routers)
                 ]
@@ -416,6 +426,9 @@ class PubsubFactory(factory.Factory):
         protocols: Sequence[TProtocol] = None,
         security_protocol: TProtocol = None,
         muxer_opt: TMuxerOptions = None,
+        msg_id_constructor: Callable[
+            [rpc_pb2.Message], bytes
+        ] = get_peer_and_seqno_msg_id,
     ) -> AsyncIterator[Tuple[Pubsub, ...]]:
         if protocols is not None:
             floodsubs = FloodsubFactory.create_batch(number, protocols=list(protocols))
@@ -428,6 +441,7 @@ class PubsubFactory(factory.Factory):
             strict_signing,
             security_protocol=security_protocol,
             muxer_opt=muxer_opt,
+            msg_id_constructor=msg_id_constructor,
         ) as pubsubs:
             yield pubsubs

@@ -450,6 +464,9 @@ class PubsubFactory(factory.Factory):
         heartbeat_initial_delay: float = GOSSIPSUB_PARAMS.heartbeat_initial_delay,
         security_protocol: TProtocol = None,
         muxer_opt: TMuxerOptions = None,
+        msg_id_constructor: Callable[
+            [rpc_pb2.Message], bytes
+        ] = get_peer_and_seqno_msg_id,
     ) -> AsyncIterator[Tuple[Pubsub, ...]]:
         if protocols is not None:
             gossipsubs = GossipsubFactory.create_batch(
@@ -480,6 +497,7 @@ class PubsubFactory(factory.Factory):
             strict_signing,
             security_protocol=security_protocol,
             muxer_opt=muxer_opt,
+            msg_id_constructor=msg_id_constructor,
         ) as pubsubs:
             async with AsyncExitStack() as stack:
                 for router in gossipsubs:
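The factories only forward the callable, so a test can exercise a custom id end to end. A hedged usage sketch, assuming the `libp2p.tools.factories` module path and a trio event loop; `data_based_msg_id` is a hypothetical helper:

import trio

import libp2p.pubsub.pb.rpc_pb2 as rpc_pb2
from libp2p.tools.factories import PubsubFactory

def data_based_msg_id(msg: rpc_pb2.Message) -> bytes:
    # Illustrative only: key the seen-messages cache on the payload alone.
    return msg.data

async def main() -> None:
    async with PubsubFactory.create_batch_with_floodsub(
        2, msg_id_constructor=data_based_msg_id
    ) as pubsubs:
        ...  # connect the hosts, subscribe, publish

trio.run(main)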

--- a/tests/pubsub/test_pubsub.py
+++ b/tests/pubsub/test_pubsub.py
@@ -37,10 +37,16 @@ async def test_simple_two_nodes():

 @pytest.mark.trio
-async def test_lru_cache_two_nodes(monkeypatch):
+async def test_lru_cache_two_nodes():
     # two nodes with cache_size of 4
+
+    # Supply a custom `msg_id` constructor so the id can be driven by `data`.
+    def get_msg_id(msg):
+        # Originally it is `(msg.seqno, msg.from_id)`.
+        return (msg.data, msg.from_id)
+
     async with PubsubFactory.create_batch_with_floodsub(
-        2, cache_size=4
+        2, cache_size=4, msg_id_constructor=get_msg_id
     ) as pubsubs_fsub:
         # `node_a` send the following messages to node_b
         message_indices = [1, 1, 2, 1, 3, 1, 4, 1, 5, 1]

@@ -49,15 +55,6 @@ async def test_lru_cache_two_nodes(monkeypatch):

         topic = "my_topic"

-        # Mock `get_msg_id` to make us easier to manipulate `msg_id` by `data`.
-        def get_msg_id(msg):
-            # Originally it is `(msg.seqno, msg.from_id)`
-            return (msg.data, msg.from_id)
-
-        import libp2p.pubsub.pubsub
-
-        monkeypatch.setattr(libp2p.pubsub.pubsub, "get_msg_id", get_msg_id)
-
         await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host)
         await trio.sleep(0.25)

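Note that the test's id is a tuple rather than `bytes`; tuples are hashable, so it still works as a `seen_messages` key, and deriving the id from `data` means resending the same payload hits the same cache slot. A standalone sketch of that dedup behavior, with made-up values:

import libp2p.pubsub.pb.rpc_pb2 as rpc_pb2

def get_msg_id(msg):
    return (msg.data, msg.from_id)

seen = {}
first = rpc_pb2.Message(from_id=b"peer-a", seqno=(1).to_bytes(8, "big"), data=b"1")
second = rpc_pb2.Message(from_id=b"peer-a", seqno=(2).to_bytes(8, "big"), data=b"1")

seen[get_msg_id(first)] = 1
# Same data and sender => same id, despite a fresh seqno.
assert get_msg_id(second) in seen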