Make the pubsub message-ID function configurable.

`Pubsub` now takes a `msg_id_constructor` callable (Message -> bytes) and uses
it in `_is_msg_seen` / `_mark_msg_seen` instead of the module-level
`get_msg_id`. Two constructors are provided:

  * `get_peer_and_seqno_msg_id`    -- `seqno + from_id` (the default)
  * `get_content_addressed_msg_id` -- base64 of the SHA-256 of `msg.data`

`PubsubFactory` forwards the option, and the LRU-cache floodsub test passes its
own constructor instead of monkeypatching the module.

Review notes (changes relative to the patch as first submitted):

  * `PubsubFactory.create_and_start` falls back to `get_peer_and_seqno_msg_id`
    when `msg_id_constructor` is None, so the `None` default is never handed
    to `Pubsub`, where calling it would raise TypeError.
  * The test's `get_msg_id` returns bytes (`data + from_id`) rather than a
    tuple, matching the `Callable[[rpc_pb2.Message], bytes]` contract.
  * The `index` blob hashes below are those of the original submission.

diff --git a/libp2p/pubsub/pubsub.py b/libp2p/pubsub/pubsub.py
index 2045d48..965fc58 100644
--- a/libp2p/pubsub/pubsub.py
+++ b/libp2p/pubsub/pubsub.py
@@ -1,7 +1,19 @@
+import base64
 import functools
+import hashlib
 import logging
 import time
-from typing import TYPE_CHECKING, Dict, KeysView, List, NamedTuple, Set, Tuple, cast
+from typing import (
+    TYPE_CHECKING,
+    Callable,
+    Dict,
+    KeysView,
+    List,
+    NamedTuple,
+    Set,
+    Tuple,
+    cast,
+)
 
 from async_service import Service
 import base58
@@ -37,9 +49,13 @@ SUBSCRIPTION_CHANNEL_SIZE = 32
 logger = logging.getLogger("libp2p.pubsub")
 
 
-def get_msg_id(msg: rpc_pb2.Message) -> Tuple[bytes, bytes]:
+def get_peer_and_seqno_msg_id(msg: rpc_pb2.Message) -> bytes:
     # NOTE: `string(from, seqno)` in Go
-    return (msg.seqno, msg.from_id)
+    return msg.seqno + msg.from_id
+
+
+def get_content_addressed_msg_id(msg: rpc_pb2.Message) -> bytes:
+    return base64.b64encode(hashlib.sha256(msg.data).digest())
 
 
 class TopicValidator(NamedTuple):
@@ -81,6 +97,9 @@ class Pubsub(Service, IPubsub):
         router: "IPubsubRouter",
         cache_size: int = None,
         strict_signing: bool = True,
+        msg_id_constructor: Callable[
+            [rpc_pb2.Message], bytes
+        ] = get_peer_and_seqno_msg_id,
     ) -> None:
         """
         Construct a new Pubsub object, which is responsible for handling all
@@ -95,6 +114,8 @@ class Pubsub(Service, IPubsub):
         self.host = host
         self.router = router
 
+        self._msg_id_constructor = msg_id_constructor
+
         # Attach this new Pubsub object to the router
         self.router.attach(self)
 
@@ -586,11 +607,11 @@ class Pubsub(Service, IPubsub):
         return self.counter.to_bytes(8, "big")
 
     def _is_msg_seen(self, msg: rpc_pb2.Message) -> bool:
-        msg_id = get_msg_id(msg)
+        msg_id = self._msg_id_constructor(msg)
         return msg_id in self.seen_messages
 
     def _mark_msg_seen(self, msg: rpc_pb2.Message) -> None:
-        msg_id = get_msg_id(msg)
+        msg_id = self._msg_id_constructor(msg)
         # FIXME: Mapping `msg_id` to `1` is quite awkward. Should investigate if there is a
         # more appropriate way.
         self.seen_messages[msg_id] = 1
diff --git a/libp2p/tools/factories.py b/libp2p/tools/factories.py
index a97c888..d488a5a 100644
--- a/libp2p/tools/factories.py
+++ b/libp2p/tools/factories.py
@@ -26,7 +26,8 @@ from libp2p.peer.peerstore import PeerStore
 from libp2p.pubsub.abc import IPubsubRouter
 from libp2p.pubsub.floodsub import FloodSub
 from libp2p.pubsub.gossipsub import GossipSub
-from libp2p.pubsub.pubsub import Pubsub
+import libp2p.pubsub.pb.rpc_pb2 as rpc_pb2
+from libp2p.pubsub.pubsub import Pubsub, get_peer_and_seqno_msg_id
 from libp2p.routing.interfaces import IPeerRouting
 from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport
 from libp2p.security.noise.messages import (
@@ -370,13 +371,19 @@ class PubsubFactory(factory.Factory):
     @classmethod
     @asynccontextmanager
     async def create_and_start(
-        cls, host: IHost, router: IPubsubRouter, cache_size: int, strict_signing: bool
+        cls,
+        host: IHost,
+        router: IPubsubRouter,
+        cache_size: int,
+        strict_signing: bool,
+        msg_id_constructor: Callable[[rpc_pb2.Message], bytes] = None,
     ) -> AsyncIterator[Pubsub]:
         pubsub = cls(
             host=host,
             router=router,
             cache_size=cache_size,
             strict_signing=strict_signing,
+            msg_id_constructor=msg_id_constructor or get_peer_and_seqno_msg_id,
         )
         async with background_trio_service(pubsub):
             await pubsub.wait_until_ready()
@@ -392,6 +399,7 @@ class PubsubFactory(factory.Factory):
         strict_signing: bool = False,
         security_protocol: TProtocol = None,
         muxer_opt: TMuxerOptions = None,
+        msg_id_constructor: Callable[[rpc_pb2.Message], bytes] = None,
     ) -> AsyncIterator[Tuple[Pubsub, ...]]:
         async with HostFactory.create_batch_and_listen(
             number, security_protocol=security_protocol, muxer_opt=muxer_opt
@@ -400,7 +408,9 @@ class PubsubFactory(factory.Factory):
             async with AsyncExitStack() as stack:
                 pubsubs = [
                     await stack.enter_async_context(
-                        cls.create_and_start(host, router, cache_size, strict_signing)
+                        cls.create_and_start(
+                            host, router, cache_size, strict_signing, msg_id_constructor
+                        )
                     )
                     for host, router in zip(hosts, routers)
                 ]
@@ -416,6 +426,9 @@ class PubsubFactory(factory.Factory):
         protocols: Sequence[TProtocol] = None,
         security_protocol: TProtocol = None,
         muxer_opt: TMuxerOptions = None,
+        msg_id_constructor: Callable[
+            [rpc_pb2.Message], bytes
+        ] = get_peer_and_seqno_msg_id,
     ) -> AsyncIterator[Tuple[Pubsub, ...]]:
         if protocols is not None:
             floodsubs = FloodsubFactory.create_batch(number, protocols=list(protocols))
@@ -428,6 +441,7 @@ class PubsubFactory(factory.Factory):
             strict_signing,
             security_protocol=security_protocol,
             muxer_opt=muxer_opt,
+            msg_id_constructor=msg_id_constructor,
         ) as pubsubs:
             yield pubsubs
 
@@ -450,6 +464,9 @@ class PubsubFactory(factory.Factory):
         heartbeat_initial_delay: float = GOSSIPSUB_PARAMS.heartbeat_initial_delay,
         security_protocol: TProtocol = None,
         muxer_opt: TMuxerOptions = None,
+        msg_id_constructor: Callable[
+            [rpc_pb2.Message], bytes
+        ] = get_peer_and_seqno_msg_id,
     ) -> AsyncIterator[Tuple[Pubsub, ...]]:
         if protocols is not None:
             gossipsubs = GossipsubFactory.create_batch(
@@ -480,6 +497,7 @@ class PubsubFactory(factory.Factory):
             strict_signing,
             security_protocol=security_protocol,
             muxer_opt=muxer_opt,
+            msg_id_constructor=msg_id_constructor,
         ) as pubsubs:
             async with AsyncExitStack() as stack:
                 for router in gossipsubs:
diff --git a/tests/pubsub/test_floodsub.py b/tests/pubsub/test_floodsub.py
index 36bf6fc..fe9f9cf 100644
--- a/tests/pubsub/test_floodsub.py
+++ b/tests/pubsub/test_floodsub.py
@@ -37,10 +37,16 @@ async def test_simple_two_nodes():
 
 
 @pytest.mark.trio
-async def test_lru_cache_two_nodes(monkeypatch):
+async def test_lru_cache_two_nodes():
     # two nodes with cache_size of 4
+
+    # Custom msg-id constructor so that `msg_id` is determined by `data`.
+    def get_msg_id(msg):
+        # The default is `msg.seqno + msg.from_id`
+        return msg.data + msg.from_id
+
     async with PubsubFactory.create_batch_with_floodsub(
-        2, cache_size=4
+        2, cache_size=4, msg_id_constructor=get_msg_id
     ) as pubsubs_fsub:
         # `node_a` send the following messages to node_b
         message_indices = [1, 1, 2, 1, 3, 1, 4, 1, 5, 1]
@@ -49,15 +55,6 @@ async def test_lru_cache_two_nodes(monkeypatch):
 
         topic = "my_topic"
 
-        # Mock `get_msg_id` to make us easier to manipulate `msg_id` by `data`.
-        def get_msg_id(msg):
-            # Originally it is `(msg.seqno, msg.from_id)`
-            return (msg.data, msg.from_id)
-
-        import libp2p.pubsub.pubsub
-
-        monkeypatch.setattr(libp2p.pubsub.pubsub, "get_msg_id", get_msg_id)
-
         await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host)
         await trio.sleep(0.25)
 