From e1b86904e3f26f21e8f7ae0bc7d77af3cc1ae476 Mon Sep 17 00:00:00 2001 From: NIC619 Date: Sun, 4 Aug 2019 18:13:23 +0800 Subject: [PATCH] Add `validate_msg` and test --- libp2p/pubsub/pubsub.py | 32 ++++++++--- libp2p/pubsub/validators.py | 9 +++ .../floodsub_integration_test_settings.py | 2 +- tests/pubsub/test_pubsub.py | 57 +++++++++++++++++-- 4 files changed, 87 insertions(+), 13 deletions(-) create mode 100644 libp2p/pubsub/validators.py diff --git a/libp2p/pubsub/pubsub.py b/libp2p/pubsub/pubsub.py index 0e6cbaf..84df48f 100644 --- a/libp2p/pubsub/pubsub.py +++ b/libp2p/pubsub/pubsub.py @@ -1,7 +1,7 @@ import asyncio from collections import namedtuple import time -from typing import Any, Awaitable, Callable, Dict, List, Tuple, Union, TYPE_CHECKING +from typing import Any, Awaitable, Callable, Dict, Iterable, List, Tuple, Union, TYPE_CHECKING from lru import LRU @@ -176,15 +176,13 @@ class Pubsub: if topic in self.topic_validators: del self.topic_validators[topic] - def get_msg_validators(self, msg: rpc_pb2.Message) -> Tuple[TopicValidator, ...]: + def get_msg_validators(self, msg: rpc_pb2.Message) -> Iterable[TopicValidator]: """ Get all validators corresponding to the topics in the message. """ - return ( - self.topic_validators[topic] - for topic in msg.topicIDs - if topic in self.topic_validators - ) + for topic in msg.topicIDs: + if topic in self.topic_validators: + yield self.topic_validators[topic] async def stream_handler(self, stream: INetStream) -> None: """ @@ -357,6 +355,26 @@ class Pubsub: await self.push_msg(self.host.get_id(), msg) + async def validate_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> bool: + sync_topic_validators = [] + async_topic_validator_futures = [] + for topic_validator in self.get_msg_validators(msg): + if topic_validator.is_async: + async_topic_validator_futures.append( + topic_validator.validator(msg_forwarder, msg) + ) + else: + sync_topic_validators.append(topic_validator.validator) + + for validator in sync_topic_validators: + if not validator(msg_forwarder, msg): + return False + + # TODO: Implement throttle on async validators + + results = await asyncio.gather(*async_topic_validator_futures) + return all(results) + async def push_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None: """ Push a pubsub message to others. diff --git a/libp2p/pubsub/validators.py b/libp2p/pubsub/validators.py new file mode 100644 index 0000000..e575980 --- /dev/null +++ b/libp2p/pubsub/validators.py @@ -0,0 +1,9 @@ +# FIXME: Replace the type of `pubkey` with a custom type `Pubkey` +def signature_validator(pubkey: bytes, msg: bytes) -> bool: + """ + Verify the message against the given public key. + :param pubkey: the public key which signs the message. + :param msg: the message signed. + """ + # TODO: Implement the signature validation + return True diff --git a/tests/pubsub/floodsub_integration_test_settings.py b/tests/pubsub/floodsub_integration_test_settings.py index f72dc22..736e725 100644 --- a/tests/pubsub/floodsub_integration_test_settings.py +++ b/tests/pubsub/floodsub_integration_test_settings.py @@ -208,7 +208,7 @@ async def perform_test_from_obj(obj, router_factory): tasks_topic.append(asyncio.sleep(2)) # Gather is like Promise.all - responses = await asyncio.gather(*tasks_topic, return_exceptions=True) + responses = await asyncio.gather(*tasks_topic) for i in range(len(responses) - 1): node_id, topic = tasks_topic_data[i] if node_id not in queues_map: diff --git a/tests/pubsub/test_pubsub.py b/tests/pubsub/test_pubsub.py index 907f3b8..d66960a 100644 --- a/tests/pubsub/test_pubsub.py +++ b/tests/pubsub/test_pubsub.py @@ -160,14 +160,9 @@ async def test_get_msg_validators(pubsubs_fsub): pubsubs_fsub[0].set_topic_validator(topic_1, sync_validator, False) pubsubs_fsub[0].set_topic_validator(topic_2, sync_validator, False) - assert topic_1 in pubsubs_fsub[0].topic_validators - assert topic_2 in pubsubs_fsub[0].topic_validators - # Register async validator for topic 3 pubsubs_fsub[0].set_topic_validator(topic_3, async_validator, True) - assert topic_3 in pubsubs_fsub[0].topic_validators - msg = make_pubsub_msg( origin_id=pubsubs_fsub[0].my_id, topic_ids=[topic_1, topic_2, topic_3], @@ -186,6 +181,58 @@ async def test_get_msg_validators(pubsubs_fsub): assert times_async_validator_called == 1 +@pytest.mark.parametrize("num_hosts", (1,)) +@pytest.mark.parametrize( + "is_topic_1_val_passed, is_topic_2_val_passed", + ( + (False, True), + (True, False), + (True, True), + ) +) +@pytest.mark.asyncio +async def test_validate_msg(pubsubs_fsub, is_topic_1_val_passed, is_topic_2_val_passed): + + def passed_sync_validator(peer_id, msg): + return True + + def failed_sync_validator(peer_id, msg): + return False + + async def passed_async_validator(peer_id, msg): + return True + + async def failed_async_validator(peer_id, msg): + return False + + topic_1 = "TEST_SYNC_VALIDATOR" + topic_2 = "TEST_ASYNC_VALIDATOR" + + if is_topic_1_val_passed: + pubsubs_fsub[0].set_topic_validator(topic_1, passed_sync_validator, False) + else: + pubsubs_fsub[0].set_topic_validator(topic_1, failed_sync_validator, False) + + if is_topic_2_val_passed: + pubsubs_fsub[0].set_topic_validator(topic_2, passed_async_validator, True) + else: + pubsubs_fsub[0].set_topic_validator(topic_2, failed_async_validator, True) + + msg = make_pubsub_msg( + origin_id=pubsubs_fsub[0].my_id, + topic_ids=[topic_1, topic_2], + data=b"1234", + seqno=b"\x00" * 8, + ) + + is_validation_passed = await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg) + + if is_topic_1_val_passed and is_topic_2_val_passed: + assert is_validation_passed + else: + assert not is_validation_passed + + class FakeNetStream: _queue: asyncio.Queue