`_is_subscribed_to_msg` need only subscribe to one of the topics
This commit is contained in:
NIC619 2019-08-05 18:20:04 +08:00
parent a2efd03dfa
commit b96ef0e6c7
No known key found for this signature in database
GPG Key ID: 570C35F5C2D51B17
2 changed files with 4 additions and 22 deletions

View File

@ -1,17 +1,7 @@
import asyncio
from collections import namedtuple
import time
from typing import (
Any,
Awaitable,
Callable,
Dict,
Iterable,
List,
Tuple,
Union,
TYPE_CHECKING,
)
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Iterable, List, Tuple, Union
from lru import LRU
@ -381,9 +371,7 @@ class Pubsub:
async_topic_validator_futures = []
for topic_validator in self.get_msg_validators(msg):
if topic_validator.is_async:
async_topic_validator_futures.append(
topic_validator.validator(msg_forwarder, msg)
)
async_topic_validator_futures.append(topic_validator.validator(msg_forwarder, msg))
else:
sync_topic_validators.append(topic_validator.validator)
@ -448,4 +436,4 @@ class Pubsub:
def _is_subscribed_to_msg(self, msg: rpc_pb2.Message) -> bool:
if not self.my_topics:
return False
return all([topic in self.my_topics for topic in msg.topicIDs])
return any([topic in self.my_topics for topic in msg.topicIDs])

View File

@ -183,16 +183,10 @@ async def test_get_msg_validators(pubsubs_fsub):
@pytest.mark.parametrize("num_hosts", (1,))
@pytest.mark.parametrize(
"is_topic_1_val_passed, is_topic_2_val_passed",
(
(False, True),
(True, False),
(True, True),
)
"is_topic_1_val_passed, is_topic_2_val_passed", ((False, True), (True, False), (True, True))
)
@pytest.mark.asyncio
async def test_validate_msg(pubsubs_fsub, is_topic_1_val_passed, is_topic_2_val_passed):
def passed_sync_validator(peer_id, msg):
return True