Merge pull request #410 from ralexstokes/configurable-msg-id

Allow `Pubsub` creator to supply a custom msg_id
pull/411/head
Alex Stokes 2020-02-28 08:42:49 -08:00 committed by GitHub
commit 379a157d6b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 55 additions and 19 deletions

View File

@ -1,7 +1,19 @@
import base64
import functools
import hashlib
import logging
import time
from typing import TYPE_CHECKING, Dict, KeysView, List, NamedTuple, Set, Tuple, cast
from typing import (
TYPE_CHECKING,
Callable,
Dict,
KeysView,
List,
NamedTuple,
Set,
Tuple,
cast,
)
from async_service import Service
import base58
@ -37,9 +49,13 @@ SUBSCRIPTION_CHANNEL_SIZE = 32
logger = logging.getLogger("libp2p.pubsub")
def get_msg_id(msg: rpc_pb2.Message) -> Tuple[bytes, bytes]:
def get_peer_and_seqno_msg_id(msg: rpc_pb2.Message) -> bytes:
# NOTE: `string(from, seqno)` in Go
return (msg.seqno, msg.from_id)
return msg.seqno + msg.from_id
def get_content_addressed_msg_id(msg: rpc_pb2.Message) -> bytes:
return base64.b64encode(hashlib.sha256(msg.data).digest())
class TopicValidator(NamedTuple):
@ -81,6 +97,9 @@ class Pubsub(Service, IPubsub):
router: "IPubsubRouter",
cache_size: int = None,
strict_signing: bool = True,
msg_id_constructor: Callable[
[rpc_pb2.Message], bytes
] = get_peer_and_seqno_msg_id,
) -> None:
"""
Construct a new Pubsub object, which is responsible for handling all
@ -95,6 +114,8 @@ class Pubsub(Service, IPubsub):
self.host = host
self.router = router
self._msg_id_constructor = msg_id_constructor
# Attach this new Pubsub object to the router
self.router.attach(self)
@ -586,11 +607,11 @@ class Pubsub(Service, IPubsub):
return self.counter.to_bytes(8, "big")
def _is_msg_seen(self, msg: rpc_pb2.Message) -> bool:
msg_id = get_msg_id(msg)
msg_id = self._msg_id_constructor(msg)
return msg_id in self.seen_messages
def _mark_msg_seen(self, msg: rpc_pb2.Message) -> None:
msg_id = get_msg_id(msg)
msg_id = self._msg_id_constructor(msg)
# FIXME: Mapping `msg_id` to `1` is quite awkward. Should investigate if there is a
# more appropriate way.
self.seen_messages[msg_id] = 1

View File

@ -26,7 +26,8 @@ from libp2p.peer.peerstore import PeerStore
from libp2p.pubsub.abc import IPubsubRouter
from libp2p.pubsub.floodsub import FloodSub
from libp2p.pubsub.gossipsub import GossipSub
from libp2p.pubsub.pubsub import Pubsub
import libp2p.pubsub.pb.rpc_pb2 as rpc_pb2
from libp2p.pubsub.pubsub import Pubsub, get_peer_and_seqno_msg_id
from libp2p.routing.interfaces import IPeerRouting
from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport
from libp2p.security.noise.messages import (
@ -370,13 +371,19 @@ class PubsubFactory(factory.Factory):
@classmethod
@asynccontextmanager
async def create_and_start(
cls, host: IHost, router: IPubsubRouter, cache_size: int, strict_signing: bool
cls,
host: IHost,
router: IPubsubRouter,
cache_size: int,
strict_signing: bool,
msg_id_constructor: Callable[[rpc_pb2.Message], bytes] = None,
) -> AsyncIterator[Pubsub]:
pubsub = cls(
host=host,
router=router,
cache_size=cache_size,
strict_signing=strict_signing,
msg_id_constructor=msg_id_constructor,
)
async with background_trio_service(pubsub):
await pubsub.wait_until_ready()
@ -392,6 +399,7 @@ class PubsubFactory(factory.Factory):
strict_signing: bool = False,
security_protocol: TProtocol = None,
muxer_opt: TMuxerOptions = None,
msg_id_constructor: Callable[[rpc_pb2.Message], bytes] = None,
) -> AsyncIterator[Tuple[Pubsub, ...]]:
async with HostFactory.create_batch_and_listen(
number, security_protocol=security_protocol, muxer_opt=muxer_opt
@ -400,7 +408,9 @@ class PubsubFactory(factory.Factory):
async with AsyncExitStack() as stack:
pubsubs = [
await stack.enter_async_context(
cls.create_and_start(host, router, cache_size, strict_signing)
cls.create_and_start(
host, router, cache_size, strict_signing, msg_id_constructor
)
)
for host, router in zip(hosts, routers)
]
@ -416,6 +426,9 @@ class PubsubFactory(factory.Factory):
protocols: Sequence[TProtocol] = None,
security_protocol: TProtocol = None,
muxer_opt: TMuxerOptions = None,
msg_id_constructor: Callable[
[rpc_pb2.Message], bytes
] = get_peer_and_seqno_msg_id,
) -> AsyncIterator[Tuple[Pubsub, ...]]:
if protocols is not None:
floodsubs = FloodsubFactory.create_batch(number, protocols=list(protocols))
@ -428,6 +441,7 @@ class PubsubFactory(factory.Factory):
strict_signing,
security_protocol=security_protocol,
muxer_opt=muxer_opt,
msg_id_constructor=msg_id_constructor,
) as pubsubs:
yield pubsubs
@ -450,6 +464,9 @@ class PubsubFactory(factory.Factory):
heartbeat_initial_delay: float = GOSSIPSUB_PARAMS.heartbeat_initial_delay,
security_protocol: TProtocol = None,
muxer_opt: TMuxerOptions = None,
msg_id_constructor: Callable[
[rpc_pb2.Message], bytes
] = get_peer_and_seqno_msg_id,
) -> AsyncIterator[Tuple[Pubsub, ...]]:
if protocols is not None:
gossipsubs = GossipsubFactory.create_batch(
@ -480,6 +497,7 @@ class PubsubFactory(factory.Factory):
strict_signing,
security_protocol=security_protocol,
muxer_opt=muxer_opt,
msg_id_constructor=msg_id_constructor,
) as pubsubs:
async with AsyncExitStack() as stack:
for router in gossipsubs:

View File

@ -37,10 +37,16 @@ async def test_simple_two_nodes():
@pytest.mark.trio
async def test_lru_cache_two_nodes(monkeypatch):
async def test_lru_cache_two_nodes():
# two nodes with cache_size of 4
# Mock `get_msg_id` to make us easier to manipulate `msg_id` by `data`.
def get_msg_id(msg):
# Originally it is `(msg.seqno, msg.from_id)`
return (msg.data, msg.from_id)
async with PubsubFactory.create_batch_with_floodsub(
2, cache_size=4
2, cache_size=4, msg_id_constructor=get_msg_id
) as pubsubs_fsub:
# `node_a` send the following messages to node_b
message_indices = [1, 1, 2, 1, 3, 1, 4, 1, 5, 1]
@ -49,15 +55,6 @@ async def test_lru_cache_two_nodes(monkeypatch):
topic = "my_topic"
# Mock `get_msg_id` to make us easier to manipulate `msg_id` by `data`.
def get_msg_id(msg):
# Originally it is `(msg.seqno, msg.from_id)`
return (msg.data, msg.from_id)
import libp2p.pubsub.pubsub
monkeypatch.setattr(libp2p.pubsub.pubsub, "get_msg_id", get_msg_id)
await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host)
await trio.sleep(0.25)