diff --git a/libp2p/exceptions.py b/libp2p/exceptions.py
new file mode 100644
index 0000000..8d8af44
--- /dev/null
+++ b/libp2p/exceptions.py
@@ -0,0 +1,6 @@
+class ValidationError(Exception):
+    """
+    Raised when something (e.g. a received pubsub message) does not pass a validation check.
+    """
+
+    pass
diff --git a/libp2p/pubsub/pubsub.py b/libp2p/pubsub/pubsub.py
index 9a70b40..d0dad89 100644
--- a/libp2p/pubsub/pubsub.py
+++ b/libp2p/pubsub/pubsub.py
@@ -1,25 +1,39 @@
 import asyncio
+import logging
 import time
-from typing import TYPE_CHECKING, Any, Dict, List, Tuple
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, NamedTuple, Tuple, Union
 
 from lru import LRU
 
+from libp2p.exceptions import ValidationError
 from libp2p.host.host_interface import IHost
 from libp2p.network.stream.net_stream_interface import INetStream
 from libp2p.peer.id import ID
 
 from .pb import rpc_pb2
 from .pubsub_notifee import PubsubNotifee
+from .validators import signature_validator
 
 if TYPE_CHECKING:
     from .pubsub_router_interface import IPubsubRouter
 
 
+log = logging.getLogger(__name__)
+
+
 def get_msg_id(msg: rpc_pb2.Message) -> Tuple[bytes, bytes]:
     # NOTE: `string(from, seqno)` in Go
     return (msg.seqno, msg.from_id)
 
 
+SyncValidatorFn = Callable[[ID, rpc_pb2.Message], bool]  # (msg_forwarder, msg) -> is valid
+AsyncValidatorFn = Callable[[ID, rpc_pb2.Message], Awaitable[bool]]  # awaited for the verdict
+ValidatorFn = Union[SyncValidatorFn, AsyncValidatorFn]
+
+
+TopicValidator = NamedTuple("TopicValidator", (("validator", ValidatorFn), ("is_async", bool)))
+
+
 class Pubsub:
 
     host: IHost
@@ -41,6 +55,8 @@ class Pubsub:
 
     peer_topics: Dict[str, List[ID]]
     peers: Dict[ID, INetStream]
 
+    topic_validators: Dict[str, TopicValidator]
+
     # NOTE: Be sure it is increased atomically everytime.
     counter: int  # uint64
@@ -93,6 +109,9 @@ class Pubsub:
         # Create peers map, which maps peer_id (as string) to stream (to a given peer)
         self.peers = {}
 
+        # Map of topic to topic validator
+        self.topic_validators = {}
+
         self.counter = time.time_ns()
 
         # Call handle peer to keep waiting for updates to peer queue
@@ -128,7 +147,7 @@ class Pubsub:
                         continue
                     # TODO(mhchia): This will block this read_stream loop until all data are pushed.
                     # Should investigate further if this is an issue.
-                    await self.push_msg(msg_forwarder=peer_id, msg=msg)
+                    asyncio.ensure_future(self.push_msg(msg_forwarder=peer_id, msg=msg))
 
             if rpc_incoming.subscriptions:
                 # deal with RPC.subscriptions
@@ -149,6 +168,35 @@ class Pubsub:
             # Force context switch
             await asyncio.sleep(0)
 
+    def set_topic_validator(
+        self, topic: str, validator: ValidatorFn, is_async_validator: bool
+    ) -> None:
+        """
+        Register a validator under the given topic. One topic can only have one validator.
+        :param topic: the topic to register validator under
+        :param validator: the validator used to validate messages published to the topic
+        :param is_async_validator: indicate if the validator is an asynchronous validator
+        """
+        self.topic_validators[topic] = TopicValidator(validator, is_async_validator)
+
+    def remove_topic_validator(self, topic: str) -> None:
+        """
+        Remove the validator from the given topic.
+        :param topic: the topic to remove validator from
+        """
+        if topic in self.topic_validators:
+            del self.topic_validators[topic]
+
+    def get_msg_validators(self, msg: rpc_pb2.Message) -> Tuple[TopicValidator, ...]:
+        """
+        Get all validators corresponding to the topics in the message.
+        :param msg: the message published to the topic
+        :return: a tuple with the validator of every topic of `msg` that has one registered
+        """
+        return tuple(
+            self.topic_validators[topic] for topic in msg.topicIDs if topic in self.topic_validators
+        )
+
     async def stream_handler(self, stream: INetStream) -> None:
         """
         Stream handler for pubsub. Gets invoked whenever a new stream is created
@@ -320,6 +368,37 @@ class Pubsub:
 
         await self.push_msg(self.host.get_id(), msg)
 
+    async def validate_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None:
+        """
+        Validate the received message with the validators of all its topics.
+        Sync validators run first; async validators are only started if all of them pass.
+        :param msg_forwarder: the peer who forward us the message.
+        :param msg: the message.
+        :raise ValidationError: if any validator rejects the message.
+        """
+        sync_topic_validators = []
+        async_topic_validators = []
+        for topic_validator in self.get_msg_validators(msg):
+            if topic_validator.is_async:
+                async_topic_validators.append(topic_validator.validator)
+            else:
+                sync_topic_validators.append(topic_validator.validator)
+
+        for validator in sync_topic_validators:
+            if not validator(msg_forwarder, msg):
+                raise ValidationError(f"Validation failed for msg={msg}")
+
+        # TODO: Implement throttle on async validators
+
+        if async_topic_validators:
+            # Create the coroutines only here: creating them before the sync checks would
+            # leave them un-awaited (RuntimeWarning) whenever a sync validator fails.
+            results = await asyncio.gather(
+                *(validator(msg_forwarder, msg) for validator in async_topic_validators)
+            )
+            if not all(results):
+                raise ValidationError(f"Validation failed for msg={msg}")
+
     async def push_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None:
         """
         Push a pubsub message to others.
@@ -332,10 +411,23 @@ class Pubsub:
 
         # TODO: Check if signing is required and if so signature should be attached.
 
+        # If the message is processed before, return(i.e., don't further process the message).
         if self._is_msg_seen(msg):
             return
 
         # TODO: - Validate the message. If failed, reject it.
+        # Validate the signature of the message
+        # FIXME: `signature_validator` is currently a stub.
+        if not signature_validator(msg.key, msg.SerializeToString()):
+            log.debug(f"Signature validation failed for msg={msg}")
+            return
+        # Validate the message with registered topic validators.
+        # If the validation failed, return(i.e., don't further process the message).
+        try:
+            await self.validate_msg(msg_forwarder, msg)
+        except ValidationError:
+            log.debug(f"Topic validation failed for msg={msg}")
+            return
 
         self._mark_msg_seen(msg)
         await self.handle_talk(msg)
@@ -361,4 +453,4 @@ class Pubsub:
     def _is_subscribed_to_msg(self, msg: rpc_pb2.Message) -> bool:
         if not self.my_topics:
             return False
-        return all([topic in self.my_topics for topic in msg.topicIDs])
+        return any(topic in self.my_topics for topic in msg.topicIDs)
diff --git a/libp2p/pubsub/validators.py b/libp2p/pubsub/validators.py
new file mode 100644
index 0000000..e575980
--- /dev/null
+++ b/libp2p/pubsub/validators.py
@@ -0,0 +1,9 @@
+# FIXME: Replace the type of `pubkey` with a custom type `Pubkey`
+def signature_validator(pubkey: bytes, msg: bytes) -> bool:
+    """
+    Verify the message against the given public key.
+    :param pubkey: the public key which signs the message.
+    :param msg: the message signed.
+    """
+    # TODO: Implement the signature validation
+    return True
diff --git a/tests/pubsub/floodsub_integration_test_settings.py b/tests/pubsub/floodsub_integration_test_settings.py
index f72dc22..736e725 100644
--- a/tests/pubsub/floodsub_integration_test_settings.py
+++ b/tests/pubsub/floodsub_integration_test_settings.py
@@ -208,7 +208,7 @@ async def perform_test_from_obj(obj, router_factory):
     tasks_topic.append(asyncio.sleep(2))
 
     # Gather is like Promise.all
-    responses = await asyncio.gather(*tasks_topic, return_exceptions=True)
+    responses = await asyncio.gather(*tasks_topic)
     for i in range(len(responses) - 1):
         node_id, topic = tasks_topic_data[i]
         if node_id not in queues_map:
diff --git a/tests/pubsub/test_pubsub.py b/tests/pubsub/test_pubsub.py
index 530677b..170b72b 100644
--- a/tests/pubsub/test_pubsub.py
+++ b/tests/pubsub/test_pubsub.py
@@ -4,6 +4,7 @@ from typing import NamedTuple
 
 import pytest
 
+from libp2p.exceptions import ValidationError
 from libp2p.peer.id import ID
 from libp2p.pubsub.pb import rpc_pb2
 from tests.utils import connect
@@ -84,6 +85,149 @@ async def test_get_hello_packet(pubsubs_fsub):
         assert topic in topic_ids_in_hello
 
 
+@pytest.mark.parametrize("num_hosts", (1,))
+@pytest.mark.asyncio
+async def test_set_and_remove_topic_validator(pubsubs_fsub):
+
+    is_sync_validator_called = False
+
+    def sync_validator(peer_id, msg):
+        nonlocal is_sync_validator_called
+        is_sync_validator_called = True
+
+    is_async_validator_called = False
+
+    async def async_validator(peer_id, msg):
+        nonlocal is_async_validator_called
+        is_async_validator_called = True
+
+    topic = "TEST_VALIDATOR"
+
+    assert topic not in pubsubs_fsub[0].topic_validators
+
+    # Register sync validator
+    pubsubs_fsub[0].set_topic_validator(topic, sync_validator, False)
+
+    assert topic in pubsubs_fsub[0].topic_validators
+    topic_validator = pubsubs_fsub[0].topic_validators[topic]
+    assert not topic_validator.is_async
+
+    # Validate with sync validator
+    topic_validator.validator(peer_id=ID(b"peer"), msg="msg")
+
+    assert is_sync_validator_called
+    assert not is_async_validator_called
+
+    # Register with async validator
+    pubsubs_fsub[0].set_topic_validator(topic, async_validator, True)
+
+    is_sync_validator_called = False
+    assert topic in pubsubs_fsub[0].topic_validators
+    topic_validator = pubsubs_fsub[0].topic_validators[topic]
+    assert topic_validator.is_async
+
+    # Validate with async validator
+    await topic_validator.validator(peer_id=ID(b"peer"), msg="msg")
+
+    assert is_async_validator_called
+    assert not is_sync_validator_called
+
+    # Remove validator
+    pubsubs_fsub[0].remove_topic_validator(topic)
+    assert topic not in pubsubs_fsub[0].topic_validators
+
+
+@pytest.mark.parametrize("num_hosts", (1,))
+@pytest.mark.asyncio
+async def test_get_msg_validators(pubsubs_fsub):
+
+    times_sync_validator_called = 0
+
+    def sync_validator(peer_id, msg):
+        nonlocal times_sync_validator_called
+        times_sync_validator_called += 1
+
+    times_async_validator_called = 0
+
+    async def async_validator(peer_id, msg):
+        nonlocal times_async_validator_called
+        times_async_validator_called += 1
+
+    topic_1 = "TEST_VALIDATOR_1"
+    topic_2 = "TEST_VALIDATOR_2"
+    topic_3 = "TEST_VALIDATOR_3"
+
+    # Register sync validator for topic 1 and 2
+    pubsubs_fsub[0].set_topic_validator(topic_1, sync_validator, False)
+    pubsubs_fsub[0].set_topic_validator(topic_2, sync_validator, False)
+
+    # Register async validator for topic 3
+    pubsubs_fsub[0].set_topic_validator(topic_3, async_validator, True)
+
+    msg = make_pubsub_msg(
+        origin_id=pubsubs_fsub[0].my_id,
+        topic_ids=[topic_1, topic_2, topic_3],
+        data=b"1234",
+        seqno=b"\x00" * 8,
+    )
+
+    topic_validators = pubsubs_fsub[0].get_msg_validators(msg)
+    for topic_validator in topic_validators:
+        if topic_validator.is_async:
+            await topic_validator.validator(peer_id=ID(b"peer"), msg="msg")
+        else:
+            topic_validator.validator(peer_id=ID(b"peer"), msg="msg")
+
+    assert times_sync_validator_called == 2
+    assert times_async_validator_called == 1
+
+
+@pytest.mark.parametrize("num_hosts", (1,))
+@pytest.mark.parametrize(
+    "is_topic_1_val_passed, is_topic_2_val_passed",
+    ((False, False), (False, True), (True, False), (True, True)),
+)
+@pytest.mark.asyncio
+async def test_validate_msg(pubsubs_fsub, is_topic_1_val_passed, is_topic_2_val_passed):
+    def passed_sync_validator(peer_id, msg):
+        return True
+
+    def failed_sync_validator(peer_id, msg):
+        return False
+
+    async def passed_async_validator(peer_id, msg):
+        return True
+
+    async def failed_async_validator(peer_id, msg):
+        return False
+
+    topic_1 = "TEST_SYNC_VALIDATOR"
+    topic_2 = "TEST_ASYNC_VALIDATOR"
+
+    if is_topic_1_val_passed:
+        pubsubs_fsub[0].set_topic_validator(topic_1, passed_sync_validator, False)
+    else:
+        pubsubs_fsub[0].set_topic_validator(topic_1, failed_sync_validator, False)
+
+    if is_topic_2_val_passed:
+        pubsubs_fsub[0].set_topic_validator(topic_2, passed_async_validator, True)
+    else:
+        pubsubs_fsub[0].set_topic_validator(topic_2, failed_async_validator, True)
+
+    msg = make_pubsub_msg(
+        origin_id=pubsubs_fsub[0].my_id,
+        topic_ids=[topic_1, topic_2],
+        data=b"1234",
+        seqno=b"\x00" * 8,
+    )
+
+    if is_topic_1_val_passed and is_topic_2_val_passed:
+        await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg)
+    else:
+        with pytest.raises(ValidationError):
+            await pubsubs_fsub[0].validate_msg(pubsubs_fsub[0].my_id, msg)
+
+
 class FakeNetStream:
     _queue: asyncio.Queue
 
@@ -319,3 +463,23 @@ async def test_push_msg(pubsubs_fsub, monkeypatch):
     await asyncio.wait_for(event.wait(), timeout=0.1)
     # Test: Subscribers are notified when `push_msg` new messages.
     assert (await sub.get()) == msg_1
+
+    # Test: add a topic validator and `push_msg` the message that
+    # does not pass the validation.
+    # `router_publish` is not called then.
+    def failed_sync_validator(peer_id, msg):
+        return False
+
+    pubsubs_fsub[0].set_topic_validator(TESTING_TOPIC, failed_sync_validator, False)
+
+    msg_2 = make_pubsub_msg(
+        origin_id=pubsubs_fsub[0].my_id,
+        topic_ids=[TESTING_TOPIC],
+        data=TESTING_DATA,
+        seqno=b"\x22" * 8,
+    )
+
+    event.clear()
+    await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_2)
+    await asyncio.sleep(0.01)
+    assert not event.is_set()