Apply PR feedback

This commit is contained in:
NIC619 2019-08-06 12:32:18 +08:00
parent b96ef0e6c7
commit 47643a67c6
No known key found for this signature in database
GPG Key ID: 570C35F5C2D51B17
2 changed files with 19 additions and 13 deletions

View File

@ -1,7 +1,7 @@
import asyncio import asyncio
from collections import namedtuple import logging
import time import time
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Iterable, List, Tuple, Union from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, NamedTuple, Tuple, Union
from lru import LRU from lru import LRU
@ -17,17 +17,20 @@ if TYPE_CHECKING:
from .pubsub_router_interface import IPubsubRouter from .pubsub_router_interface import IPubsubRouter
log = logging.getLogger(__name__)
def get_msg_id(msg: rpc_pb2.Message) -> Tuple[bytes, bytes]: def get_msg_id(msg: rpc_pb2.Message) -> Tuple[bytes, bytes]:
# NOTE: `string(from, seqno)` in Go # NOTE: `string(from, seqno)` in Go
return (msg.seqno, msg.from_id) return (msg.seqno, msg.from_id)
TopicValidator = namedtuple("TopicValidator", ["validator", "is_async"])
ValidatorFn = Union[Callable[[ID, rpc_pb2.Message], bool], Awaitable[bool]] ValidatorFn = Union[Callable[[ID, rpc_pb2.Message], bool], Awaitable[bool]]
TopicValidator = NamedTuple("TopicValidator", (("validator", ValidatorFn), ("is_async", bool)))
class Pubsub: class Pubsub:
host: IHost host: IHost
@ -181,14 +184,14 @@ class Pubsub:
if topic in self.topic_validators: if topic in self.topic_validators:
del self.topic_validators[topic] del self.topic_validators[topic]
def get_msg_validators(self, msg: rpc_pb2.Message) -> Iterable[TopicValidator]: def get_msg_validators(self, msg: rpc_pb2.Message) -> Tuple[TopicValidator, ...]:
""" """
Get all validators corresponding to the topics in the message. Get all validators corresponding to the topics in the message.
:param msg: the message published to the topic :param msg: the message published to the topic
""" """
for topic in msg.topicIDs: return (
if topic in self.topic_validators: self.topic_validators[topic] for topic in msg.topicIDs if topic in self.topic_validators
yield self.topic_validators[topic] )
async def stream_handler(self, stream: INetStream) -> None: async def stream_handler(self, stream: INetStream) -> None:
""" """
@ -399,15 +402,18 @@ class Pubsub:
# TODO: Check if signing is required and if so signature should be attached. # TODO: Check if signing is required and if so signature should be attached.
# If the message is processed before, return(i.e., don't further process the message).
if self._is_msg_seen(msg): if self._is_msg_seen(msg):
return return
# TODO: - Validate the message. If failed, reject it. # TODO: - Validate the message. If failed, reject it.
# Validate the signature of the message # Validate the signature of the message
# FIXME: `signature_validator` is currently a stub. # FIXME: `signature_validator` is currently a stub.
if not signature_validator(msg.key, msg.SerializeToString()): if not signature_validator(msg.key, msg.SerializeToString(), msg.singature):
log.debug(f"Signature validation failed for msg={msg}")
return return
# Validate the message with registered topic validators # Validate the message with registered topic validators.
# If the validation failed, return(i.e., don't further process the message).
is_validation_passed = await self.validate_msg(msg_forwarder, msg) is_validation_passed = await self.validate_msg(msg_forwarder, msg)
if not is_validation_passed: if not is_validation_passed:
return return
@ -436,4 +442,4 @@ class Pubsub:
def _is_subscribed_to_msg(self, msg: rpc_pb2.Message) -> bool: def _is_subscribed_to_msg(self, msg: rpc_pb2.Message) -> bool:
if not self.my_topics: if not self.my_topics:
return False return False
return any([topic in self.my_topics for topic in msg.topicIDs]) return any(topic in self.my_topics for topic in msg.topicIDs)

View File

@ -1,5 +1,5 @@
# FIXME: Replace the type of `pubkey` with a custom type `Pubkey` # FIXME: Replace the type of `pubkey` with a custom type `Pubkey`
def signature_validator(pubkey: bytes, msg: bytes) -> bool: def signature_validator(pubkey: bytes, msg: bytes, sig: bytes) -> bool:
""" """
Verify the message against the given public key. Verify the message against the given public key.
:param pubkey: the public key which signs the message. :param pubkey: the public key which signs the message.