Fix all modules except for security

mhchia 2019-12-06 17:06:37 +08:00
parent e9ab0646e3
commit 1929f307fb
28 changed files with 764 additions and 955 deletions

View File

@ -1,6 +1,7 @@
import trio
import logging
import trio
from libp2p.network.stream.exceptions import StreamClosed, StreamEOF, StreamReset
from libp2p.network.stream.net_stream_interface import INetStream
from libp2p.peer.id import ID as PeerID

View File

@ -1,7 +1,6 @@
import logging
import trio
from trio import SocketStream
from libp2p.io.abc import ReadWriteCloser
from libp2p.io.exceptions import IOException
@ -9,29 +8,48 @@ from libp2p.io.exceptions import IOException
logger = logging.getLogger("libp2p.io.trio")
class TrioReadWriteCloser(ReadWriteCloser):
stream: SocketStream
class TrioTCPStream(ReadWriteCloser):
stream: trio.SocketStream
# NOTE: Add both read and write locks to avoid `trio.BusyResourceError`
read_lock: trio.Lock
write_lock: trio.Lock
def __init__(self, stream: SocketStream) -> None:
def __init__(self, stream: trio.SocketStream) -> None:
self.stream = stream
self.read_lock = trio.Lock()
self.write_lock = trio.Lock()
async def write(self, data: bytes) -> None:
"""Raise `RawConnError` if the underlying connection breaks."""
try:
await self.stream.send_all(data)
except (trio.ClosedResourceError, trio.BrokenResourceError) as error:
raise IOException(error)
async with self.write_lock:
try:
await self.stream.send_all(data)
except (trio.ClosedResourceError, trio.BrokenResourceError) as error:
raise IOException from error
except trio.BusyResourceError as error:
# This should never happen, since we already access streams with read/write locks.
raise Exception(
"this should never happen "
"since we already access streams with read/write locks."
) from error
async def read(self, n: int = -1) -> bytes:
if n == 0:
# Check point
await trio.sleep(0)
return b""
max_bytes = n if n != -1 else None
try:
return await self.stream.receive_some(max_bytes)
except (trio.ClosedResourceError, trio.BrokenResourceError) as error:
raise IOException(error)
async with self.read_lock:
if n == 0:
# Checkpoint
await trio.hazmat.checkpoint()
return b""
max_bytes = n if n != -1 else None
try:
return await self.stream.receive_some(max_bytes)
except (trio.ClosedResourceError, trio.BrokenResourceError) as error:
raise IOException from error
except trio.BusyResourceError as error:
# This should never happen, since we already access streams with read/write locks.
raise Exception(
"this should never happen "
"since we already access streams with read/write locks."
) from error
async def close(self) -> None:
await self.stream.aclose()
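
The read/write locks above exist because trio raises `trio.BusyResourceError` when two tasks use the same end of a stream concurrently. A minimal sketch of the pattern, assuming a TCP server is already listening locally (host and port are illustrative):

import trio

async def demo() -> None:
    stream = await trio.open_tcp_stream("127.0.0.1", 8000)
    write_lock = trio.Lock()

    async def writer(data: bytes) -> None:
        # Without the lock, two concurrent `send_all` calls on the same
        # stream would raise `trio.BusyResourceError`.
        async with write_lock:
            await stream.send_all(data)

    async with trio.open_nursery() as nursery:
        nursery.start_soon(writer, b"a")
        nursery.start_soon(writer, b"b")

trio.run(demo)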

View File

@ -1,5 +1,3 @@
import trio
from libp2p.io.abc import ReadWriteCloser
from libp2p.io.exceptions import IOException
@ -8,17 +6,17 @@ from .raw_connection_interface import IRawConnection
class RawConnection(IRawConnection):
read_write_closer: ReadWriteCloser
stream: ReadWriteCloser
is_initiator: bool
def __init__(self, read_write_closer: ReadWriteCloser, initiator: bool) -> None:
self.read_write_closer = read_write_closer
def __init__(self, stream: ReadWriteCloser, initiator: bool) -> None:
self.stream = stream
self.is_initiator = initiator
async def write(self, data: bytes) -> None:
"""Raise `RawConnError` if the underlying connection breaks."""
try:
await self.read_write_closer.write(data)
await self.stream.write(data)
except IOException as error:
raise RawConnError(error)
@ -30,9 +28,9 @@ class RawConnection(IRawConnection):
Raise `RawConnError` if the underlying connection breaks
"""
try:
return await self.read_write_closer.read(n)
return await self.stream.read(n)
except IOException as error:
raise RawConnError(error)
async def close(self) -> None:
await self.read_write_closer.close()
await self.stream.close()

View File

@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any, Awaitable, List, Set, Tuple
from typing import TYPE_CHECKING, Set, Tuple
from async_service import Service
import trio
@ -45,16 +45,11 @@ class SwarmConn(INetConn, Service):
# before we cancel the stream handler tasks.
await trio.sleep(0.1)
# FIXME: Now let `_notify_disconnected` finish first.
# Schedule `self._notify_disconnected` to make it execute after `close` is finished.
await self._notify_disconnected()
async def _handle_new_streams(self) -> None:
while self.manager.is_running:
try:
print(
f"!@# SwarmConn._handle_new_streams: {self.muxed_conn._id}: waiting for new streams"
)
stream = await self.muxed_conn.accept_stream()
except MuxedConnUnavailable:
# If there is anything wrong in the MuxedConn,
@ -63,9 +58,6 @@ class SwarmConn(INetConn, Service):
# Asynchronously handle the accepted stream, to avoid blocking the next stream.
self.manager.run_task(self._handle_muxed_stream, stream)
print(
f"!@# SwarmConn._handle_new_streams: {self.muxed_conn._id}: out of the loop"
)
await self.close()
async def _call_stream_handler(self, net_stream: NetStream) -> None:
@ -92,8 +84,7 @@ class SwarmConn(INetConn, Service):
await self.swarm.notify_disconnected(self)
async def run(self) -> None:
self.manager.run_task(self._handle_new_streams)
await self.manager.wait_finished()
await self._handle_new_streams()
async def new_stream(self) -> NetStream:
muxed_stream = await self.muxed_conn.open_stream()

View File

@ -203,16 +203,17 @@ class Swarm(INetwork, Service):
await self.add_conn(muxed_conn)
logger.debug("successfully opened connection to peer %s", peer_id)
# FIXME: This is a intentional barrier to prevent from the handler exiting and
# closing the connection. Probably change to `Service.manager.wait_finished`?
await trio.sleep_forever()
# NOTE: This is an intentional barrier to prevent the handler from exiting
# and closing the connection.
await self.manager.wait_finished()
try:
# Success
listener = self.transport.create_listener(conn_handler)
self.listeners[str(maddr)] = listener
# FIXME: Hack
await listener.listen(maddr, self.manager._task_nursery)
# TODO: `listener.listen` is not bound to the nursery. If we want to be
# I/O agnostic, we should change the API.
await listener.listen(maddr, self.manager._task_nursery) # type: ignore
# Call notifiers since event occurred
await self.notify_listen(maddr)
@ -278,6 +279,7 @@ class Swarm(INetwork, Service):
"""
self.notifees.append(notifee)
# TODO: Use `run_task`.
async def notify_opened_stream(self, stream: INetStream) -> None:
async with trio.open_nursery() as nursery:
for notifee in self.notifees:
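
The `conn_handler` now blocks on `self.manager.wait_finished()` rather than `trio.sleep_forever()`, tying the connection's lifetime to the `Swarm` service itself. A minimal sketch of that barrier pattern with `async-service`, matching the `Service` API used throughout this commit:

import trio
from async_service import Service, background_trio_service

class Echo(Service):
    async def run(self) -> None:
        self.manager.run_task(self._work)
        # Barrier: returns only once the service is cancelled or finished,
        # keeping everything scoped to `run` alive until then.
        await self.manager.wait_finished()

    async def _work(self) -> None:
        while self.manager.is_running:
            await trio.sleep(1)

async def main() -> None:
    async with background_trio_service(Echo()):
        await trio.sleep(3)  # the service runs in the background here

trio.run(main)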

View File

@ -64,7 +64,7 @@ class FloodSub(IPubsubRouter):
:param rpc: rpc message
"""
# Checkpoint
await trio.sleep(0)
await trio.hazmat.checkpoint()
async def publish(self, msg_forwarder: ID, pubsub_msg: rpc_pb2.Message) -> None:
"""
@ -107,7 +107,7 @@ class FloodSub(IPubsubRouter):
:param topic: topic to join
"""
# Checkpoint
await trio.sleep(0)
await trio.hazmat.checkpoint()
async def leave(self, topic: str) -> None:
"""
@ -117,7 +117,7 @@ class FloodSub(IPubsubRouter):
:param topic: topic to leave
"""
# Checkpoint
await trio.sleep(0)
await trio.hazmat.checkpoint()
def _get_peers_to_send(
self, topic_ids: Iterable[str], msg_forwarder: ID, origin: ID
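
`trio.sleep(0)` and `trio.hazmat.checkpoint()` are both scheduling points that honor cancellation; the commit switches to the explicit checkpoint to make the intent visible (newer trio releases expose the same call as `trio.lowlevel.checkpoint()`). A sketch of why a branch that returns without awaiting should still checkpoint:

import trio

async def leave(topic: str, joined_topics: set) -> None:
    if topic not in joined_topics:
        # Nothing to do, but still give trio a scheduling point so the
        # call remains cancellable and fair like any other async call.
        await trio.hazmat.checkpoint()
        return
    joined_topics.remove(topic)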

View File

@ -1,15 +1,18 @@
from ast import literal_eval
import asyncio
import logging
import random
from typing import Any, Dict, Iterable, List, Sequence, Set
from async_service import Service
import trio
from libp2p.network.stream.exceptions import StreamClosed
from libp2p.peer.id import ID
from libp2p.pubsub import floodsub
from libp2p.typing import TProtocol
from libp2p.utils import encode_varint_prefixed
from .exceptions import NoPubsubAttached
from .mcache import MessageCache
from .pb import rpc_pb2
from .pubsub import Pubsub
@ -20,8 +23,7 @@ PROTOCOL_ID = TProtocol("/meshsub/1.0.0")
logger = logging.getLogger("libp2p.pubsub.gossipsub")
class GossipSub(IPubsubRouter):
class GossipSub(IPubsubRouter, Service):
protocols: List[TProtocol]
pubsub: Pubsub
@ -86,6 +88,12 @@ class GossipSub(IPubsubRouter):
# Create heartbeat timer
self.heartbeat_interval = heartbeat_interval
async def run(self) -> None:
if self.pubsub is None:
raise NoPubsubAttached
self.manager.run_task(self.heartbeat)
await self.manager.wait_finished()
# Interface functions
def get_protocols(self) -> List[TProtocol]:
@ -105,10 +113,6 @@ class GossipSub(IPubsubRouter):
logger.debug("attached to pusub")
# Start heartbeat now that we have a pubsub instance
# TODO: Start after delay
asyncio.ensure_future(self.heartbeat())
def add_peer(self, peer_id: ID, protocol_id: TProtocol) -> None:
"""
Notifies the router that a new peer has been connected.
@ -310,7 +314,7 @@ class GossipSub(IPubsubRouter):
await self.fanout_heartbeat()
await self.gossip_heartbeat()
await asyncio.sleep(self.heartbeat_interval)
await trio.sleep(self.heartbeat_interval)
async def mesh_heartbeat(self) -> None:
# Note: the comments here are the exact pseudocode from the spec
@ -338,7 +342,7 @@ class GossipSub(IPubsubRouter):
if num_mesh_peers_in_topic > self.degree_high:
# Select |mesh[topic]| - D peers from mesh[topic]
selected_peers = GossipSub.select_from_minus(
selected_peers = self.select_from_minus(
num_mesh_peers_in_topic - self.degree, self.mesh[topic], []
)
for peer in selected_peers:
@ -353,7 +357,10 @@ class GossipSub(IPubsubRouter):
for topic in self.fanout:
# If time since last published > ttl
# TODO: there's no way time_since_last_publish gets set anywhere yet
if self.time_since_last_publish[topic] > self.time_to_live:
if (
topic in self.time_since_last_publish
and self.time_since_last_publish[topic] > self.time_to_live
):
# Remove topic from fanout
del self.fanout[topic]
del self.time_since_last_publish[topic]
@ -407,11 +414,7 @@ class GossipSub(IPubsubRouter):
topic, self.degree, []
)
for peer in peers_to_emit_ihave_to:
if (
peer not in self.mesh[topic]
and peer not in self.fanout[topic]
):
if peer not in self.fanout[topic]:
msg_id_strs = [str(msg) for msg in msg_ids]
await self.emit_ihave(topic, msg_id_strs, peer)
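
With `GossipSub` subclassing `Service`, the heartbeat leaves `asyncio.ensure_future` and moves into `run`, where the manager cancels it together with the service. A reduced sketch of that periodic-task shape under the same `async-service` API (the class and loop body are illustrative):

import trio
from async_service import Service

class HeartbeatService(Service):
    def __init__(self, heartbeat_interval: float) -> None:
        self.heartbeat_interval = heartbeat_interval

    async def run(self) -> None:
        self.manager.run_task(self.heartbeat)
        await self.manager.wait_finished()

    async def heartbeat(self) -> None:
        while self.manager.is_running:
            # mesh/fanout/gossip maintenance would run here
            await trio.sleep(self.heartbeat_interval)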

View File

@ -1,4 +1,4 @@
from abc import ABC, abstractmethod
from abc import ABC
import logging
import math
import time
@ -57,6 +57,7 @@ class TopicValidator(NamedTuple):
is_async: bool
# TODO: Add interface for Pubsub
class BasePubsub(ABC):
pass
@ -103,20 +104,24 @@ class Pubsub(BasePubsub, Service):
# Attach this new Pubsub object to the router
self.router.attach(self)
peer_send_channel, peer_receive_channel = trio.open_memory_channel(0)
dead_peer_send_channel, dead_peer_receive_channel = trio.open_memory_channel(0)
peer_channels: Tuple[
"trio.MemorySendChannel[ID]", "trio.MemoryReceiveChannel[ID]"
] = trio.open_memory_channel(0)
dead_peer_channels: Tuple[
"trio.MemorySendChannel[ID]", "trio.MemoryReceiveChannel[ID]"
] = trio.open_memory_channel(0)
# Only keep the receive channels in `Pubsub`.
# Therefore, we can only close from the receive side.
self.peer_receive_channel = peer_receive_channel
self.dead_peer_receive_channel = dead_peer_receive_channel
self.peer_receive_channel = peer_channels[1]
self.dead_peer_receive_channel = dead_peer_channels[1]
# Register a notifee
self.host.get_network().register_notifee(
PubsubNotifee(peer_send_channel, dead_peer_send_channel)
PubsubNotifee(peer_channels[0], dead_peer_channels[0])
)
# Register stream handlers for each pubsub router protocol to handle
# the pubsub streams opened on those protocols
for protocol in router.protocols:
for protocol in router.get_protocols():
self.host.set_stream_handler(protocol, self.stream_handler)
# keeps track of seen messages as LRU cache
@ -328,8 +333,9 @@ class Pubsub(BasePubsub, Service):
self.manager.run_task(self._handle_new_peer, peer_id)
async def handle_dead_peer_queue(self) -> None:
"""Continuously read from dead peer channel and close the stream between
that peer and remove peer info from pubsub and pubsub router."""
"""Continuously read from dead peer channel and close the stream
between that peer and remove peer info from pubsub and pubsub
router."""
async with self.dead_peer_receive_channel:
while self.manager.is_running:
peer_id: ID = await self.dead_peer_receive_channel.receive()
@ -391,7 +397,11 @@ class Pubsub(BasePubsub, Service):
return self.subscribed_topics_receive[topic_id]
# Map topic_id to a blocking channel
send_channel, receive_channel = trio.open_memory_channel(math.inf)
channels: Tuple[
"trio.MemorySendChannel[rpc_pb2.Message]",
"trio.MemoryReceiveChannel[rpc_pb2.Message]",
] = trio.open_memory_channel(math.inf)
send_channel, receive_channel = channels
self.subscribed_topics_send[topic_id] = send_channel
self.subscribed_topics_receive[topic_id] = receive_channel
@ -506,7 +516,7 @@ class Pubsub(BasePubsub, Service):
if len(async_topic_validators) > 0:
# TODO: Use a better pattern
final_result = True
final_result: bool = True
async def run_async_validator(func: AsyncValidatorFn) -> None:
nonlocal final_result
@ -514,8 +524,8 @@ class Pubsub(BasePubsub, Service):
final_result = final_result and result
async with trio.open_nursery() as nursery:
for validator in async_topic_validators:
nursery.start_soon(run_async_validator, validator)
for async_validator in async_topic_validators:
nursery.start_soon(run_async_validator, async_validator)
if not final_result:
raise ValidationError(f"Validation failed for msg={msg}")
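
The validator hunk aggregates concurrent async results through a `nonlocal` flag inside a nursery, replacing an `asyncio.gather`. A condensed, self-contained sketch of that pattern (the validator signature is illustrative):

import trio
from typing import Awaitable, Callable, Sequence

AsyncValidatorFn = Callable[[bytes], Awaitable[bool]]

async def validate(msg: bytes, validators: Sequence[AsyncValidatorFn]) -> bool:
    final_result: bool = True

    async def run_async_validator(func: AsyncValidatorFn) -> None:
        nonlocal final_result
        # All validators run concurrently; any False result fails the message.
        final_result = final_result and await func(msg)

    async with trio.open_nursery() as nursery:
        for validator in validators:
            nursery.start_soon(run_async_validator, validator)
    return final_result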

View File

@ -1,11 +1,13 @@
from abc import ABC, abstractmethod
from abc import abstractmethod
from async_service import ServiceAPI
from libp2p.io.abc import ReadWriteCloser
from libp2p.peer.id import ID
from libp2p.security.secure_conn_interface import ISecureConn
class IMuxedConn(ABC):
class IMuxedConn(ServiceAPI):
"""
reference: https://github.com/libp2p/go-stream-muxer/blob/master/muxer.go
"""

View File

@ -1,7 +1,6 @@
import logging
import math
from typing import Any # noqa: F401
from typing import Awaitable, Dict, List, Optional, Tuple
from typing import Dict, Optional, Tuple
from async_service import Service
import trio
@ -67,13 +66,15 @@ class Mplex(IMuxedConn, Service):
self.streams = {}
self.streams_lock = trio.Lock()
self.streams_msg_channels = {}
send_channel, receive_channel = trio.open_memory_channel(math.inf)
self.new_stream_send_channel = send_channel
self.new_stream_receive_channel = receive_channel
channels: Tuple[
"trio.MemorySendChannel[IMuxedStream]",
"trio.MemoryReceiveChannel[IMuxedStream]",
] = trio.open_memory_channel(math.inf)
self.new_stream_send_channel, self.new_stream_receive_channel = channels
self.event_shutting_down = trio.Event()
self.event_closed = trio.Event()
async def run(self):
async def run(self) -> None:
self.manager.run_task(self.handle_incoming)
await self.manager.wait_finished()
@ -112,11 +113,13 @@ class Mplex(IMuxedConn, Service):
async def _initialize_stream(self, stream_id: StreamID, name: str) -> MplexStream:
# Use an unbounded buffer, to avoid `handle_incoming` being blocked when doing
# `send_channel.send`.
send_channel, receive_channel = trio.open_memory_channel(math.inf)
stream = MplexStream(name, stream_id, self, receive_channel)
channels: Tuple[
"trio.MemorySendChannel[bytes]", "trio.MemoryReceiveChannel[bytes]"
] = trio.open_memory_channel(math.inf)
stream = MplexStream(name, stream_id, self, channels[1])
async with self.streams_lock:
self.streams[stream_id] = stream
self.streams_msg_channels[stream_id] = send_channel
self.streams_msg_channels[stream_id] = channels[0]
return stream
async def open_stream(self) -> IMuxedStream:
@ -150,9 +153,6 @@ class Mplex(IMuxedConn, Service):
:param data: data to send in the message
:param stream_id: stream the message is in
"""
print(
f"!@# send_message: {self._id}: flag={flag}, data={data}, stream_id={stream_id}"
)
# << by 3, then or with flag
header = encode_uvarint((stream_id.channel_id << 3) | flag.value)
@ -179,19 +179,10 @@ class Mplex(IMuxedConn, Service):
while self.manager.is_running:
try:
print(
f"!@# handle_incoming: {self._id}: before _handle_incoming_message"
)
await self._handle_incoming_message()
print(
f"!@# handle_incoming: {self._id}: after _handle_incoming_message"
)
except MplexUnavailable as e:
logger.debug("mplex unavailable while waiting for incoming: %s", e)
print(f"!@# handle_incoming: {self._id}: MplexUnavailable: {e}")
break
print(f"!@# handle_incoming: {self._id}: leaving")
# If we enter here, it means this connection is shutting down.
# We should clean things up.
await self._cleanup()
@ -232,44 +223,27 @@ class Mplex(IMuxedConn, Service):
:raise MplexUnavailable: `Mplex` encounters fatal error or is shutting down.
"""
print(f"!@# _handle_incoming_message: {self._id}: before reading")
channel_id, flag, message = await self.read_message()
print(
f"!@# _handle_incoming_message: {self._id}: channel_id={channel_id}, flag={flag}, message={message}"
)
stream_id = StreamID(channel_id=channel_id, is_initiator=bool(flag & 1))
print(f"!@# _handle_incoming_message: {self._id}: 2")
if flag == HeaderTags.NewStream.value:
print(f"!@# _handle_incoming_message: {self._id}: 3")
await self._handle_new_stream(stream_id, message)
print(f"!@# _handle_incoming_message: {self._id}: 4")
elif flag in (
HeaderTags.MessageInitiator.value,
HeaderTags.MessageReceiver.value,
):
print(f"!@# _handle_incoming_message: {self._id}: 5")
await self._handle_message(stream_id, message)
print(f"!@# _handle_incoming_message: {self._id}: 6")
elif flag in (HeaderTags.CloseInitiator.value, HeaderTags.CloseReceiver.value):
print(f"!@# _handle_incoming_message: {self._id}: 7")
await self._handle_close(stream_id)
print(f"!@# _handle_incoming_message: {self._id}: 8")
elif flag in (HeaderTags.ResetInitiator.value, HeaderTags.ResetReceiver.value):
print(f"!@# _handle_incoming_message: {self._id}: 9")
await self._handle_reset(stream_id)
print(f"!@# _handle_incoming_message: {self._id}: 10")
else:
print(f"!@# _handle_incoming_message: {self._id}: 11")
# Receives messages with an unknown flag
# TODO: logging
async with self.streams_lock:
print(f"!@# _handle_incoming_message: {self._id}: 12")
if stream_id in self.streams:
print(f"!@# _handle_incoming_message: {self._id}: 13")
stream = self.streams[stream_id]
await stream.reset()
print(f"!@# _handle_incoming_message: {self._id}: 14")
async def _handle_new_stream(self, stream_id: StreamID, message: bytes) -> None:
async with self.streams_lock:
@ -285,59 +259,43 @@ class Mplex(IMuxedConn, Service):
raise MplexUnavailable
async def _handle_message(self, stream_id: StreamID, message: bytes) -> None:
print(
f"!@# _handle_message: {self._id}: stream_id={stream_id}, message={message}"
)
async with self.streams_lock:
print(f"!@# _handle_message: {self._id}: 1")
if stream_id not in self.streams:
# We received a message for stream `stream_id`, which was never
# accepted before. This is abnormal. Possibly disconnect?
# TODO: Warn and emit logs about this.
print(f"!@# _handle_message: {self._id}: 2")
return
print(f"!@# _handle_message: {self._id}: 3")
stream = self.streams[stream_id]
send_channel = self.streams_msg_channels[stream_id]
async with stream.close_lock:
print(f"!@# _handle_message: {self._id}: 4")
if stream.event_remote_closed.is_set():
print(f"!@# _handle_message: {self._id}: 5")
# TODO: Warn "Received data from remote after stream was closed by them. (len = %d)" # noqa: E501
return
print(f"!@# _handle_message: {self._id}: 6")
await send_channel.send(message)
print(f"!@# _handle_message: {self._id}: 7")
async def _handle_close(self, stream_id: StreamID) -> None:
print(f"!@# _handle_close: {self._id}: step=0")
async with self.streams_lock:
if stream_id not in self.streams:
# Ignore unmatched messages for now.
return
stream = self.streams[stream_id]
send_channel = self.streams_msg_channels[stream_id]
print(f"!@# _handle_close: {self._id}: step=1")
await send_channel.aclose()
print(f"!@# _handle_close: {self._id}: step=2")
# NOTE: If the remote is already closed, then return. This is technically
# a bug on the other side; we should consider killing the connection.
async with stream.close_lock:
if stream.event_remote_closed.is_set():
return
print(f"!@# _handle_close: {self._id}: step=3")
is_local_closed: bool
async with stream.close_lock:
stream.event_remote_closed.set()
is_local_closed = stream.event_local_closed.is_set()
print(f"!@# _handle_close: {self._id}: step=4")
# If local is also closed, both sides are closed. Then, we should clean up
# the entry of this stream, to avoid others from accessing it.
if is_local_closed:
async with self.streams_lock:
if stream_id in self.streams:
del self.streams[stream_id]
print(f"!@# _handle_close: {self._id}: step=5")
async def _handle_reset(self, stream_id: StreamID) -> None:
async with self.streams_lock:
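
Two details in this file are worth calling out: the `Tuple[...]` annotation is a workaround for `trio.open_memory_channel` returning an untyped pair under the trio-typing stubs of this era, and the `math.inf` buffer keeps `handle_incoming` from ever blocking on `send_channel.send` (at the cost of unbounded memory if a stream is never drained). A minimal sketch, assuming only `trio`:

import math
from typing import Tuple

import trio

channels: Tuple[
    "trio.MemorySendChannel[bytes]", "trio.MemoryReceiveChannel[bytes]"
] = trio.open_memory_channel(math.inf)
send_channel, receive_channel = channels

async def producer() -> None:
    # Never blocks: the buffer is unbounded.
    await send_channel.send(b"frame")

async def consumer() -> bytes:
    return await receive_channel.receive()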

View File

@ -1,30 +1,29 @@
from contextlib import AsyncExitStack, asynccontextmanager
from typing import Any, AsyncIterator, Dict, Tuple, cast
from typing import Any, AsyncIterator, Dict, Sequence, Tuple, cast
from async_service import background_trio_service
import factory
import trio
from libp2p.tools.constants import GOSSIPSUB_PARAMS
from libp2p import generate_new_rsa_identity, generate_peer_id_from
from libp2p.crypto.keys import KeyPair
from libp2p.host.basic_host import BasicHost
from libp2p.host.routed_host import RoutedHost
from libp2p.tools.utils import set_up_routers
from libp2p.kademlia.network import KademliaServer
from libp2p.host.host_interface import IHost
from libp2p.network.connection.swarm_connection import SwarmConn
from libp2p.network.stream.net_stream_interface import INetStream
from libp2p.network.swarm import Swarm
from libp2p.peer.peerstore import PeerStore
from libp2p.peer.id import ID
from libp2p.peer.peerstore import PeerStore
from libp2p.pubsub.floodsub import FloodSub
from libp2p.pubsub.gossipsub import GossipSub
from libp2p.pubsub.pubsub import Pubsub
from libp2p.pubsub.pubsub_router_interface import IPubsubRouter
from libp2p.security.base_transport import BaseSecureTransport
from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport
import libp2p.security.secio.transport as secio
from libp2p.stream_muxer.mplex.mplex import MPLEX_PROTOCOL_ID, Mplex
from libp2p.stream_muxer.mplex.mplex_stream import MplexStream
from libp2p.tools.constants import GOSSIPSUB_PARAMS
from libp2p.transport.tcp.tcp import TCP
from libp2p.transport.typing import TMuxerOptions
from libp2p.transport.upgrader import TransportUpgrader
@ -74,7 +73,7 @@ class SwarmFactory(factory.Factory):
@asynccontextmanager
async def create_and_listen(
cls, is_secure: bool, key_pair: KeyPair = None, muxer_opt: TMuxerOptions = None
) -> Swarm:
) -> AsyncIterator[Swarm]:
# `factory.Factory.__init__` does *not* prepare a *default value* if we pass
# an argument explicitly with `None`. If an argument is `None`, we don't pass it to
# `factory.Factory.__init__`, in order to let the function initialize it.
@ -92,7 +91,7 @@ class SwarmFactory(factory.Factory):
@asynccontextmanager
async def create_batch_and_listen(
cls, is_secure: bool, number: int, muxer_opt: TMuxerOptions = None
) -> Tuple[Swarm, ...]:
) -> AsyncIterator[Tuple[Swarm, ...]]:
async with AsyncExitStack() as stack:
ctx_mgrs = [
await stack.enter_async_context(
@ -100,7 +99,7 @@ class SwarmFactory(factory.Factory):
)
for _ in range(number)
]
yield ctx_mgrs
yield tuple(ctx_mgrs)
class HostFactory(factory.Factory):
@ -120,7 +119,7 @@ class HostFactory(factory.Factory):
@asynccontextmanager
async def create_batch_and_listen(
cls, is_secure: bool, number: int
) -> Tuple[BasicHost, ...]:
) -> AsyncIterator[Tuple[BasicHost, ...]]:
key_pairs = [generate_new_rsa_identity() for _ in range(number)]
async with AsyncExitStack() as stack:
swarms = [
@ -136,30 +135,6 @@ class HostFactory(factory.Factory):
yield hosts
class RoutedHostFactory(factory.Factory):
class Meta:
model = RoutedHost
public_key = factory.LazyAttribute(lambda o: o.key_pair.public_key)
network = factory.LazyAttribute(
lambda o: SwarmFactory(is_secure=o.is_secure, key_pair=o.key_pair)
)
router = factory.LazyFunction(KademliaServer)
@classmethod
@asynccontextmanager
async def create_batch_and_listen(
cls, is_secure: bool, number: int
) -> Tuple[RoutedHost, ...]:
key_pairs = [generate_new_rsa_identity() for _ in range(number)]
routers = await set_up_routers((0,) * number)
async with SwarmFactory.create_batch_and_listen(is_secure, number) as swarms:
yield tuple(
RoutedHost(key_pair.public_key, swarm, router)
for key_pair, swarm, router in zip(key_pairs, swarms, routers)
)
class FloodsubFactory(factory.Factory):
class Meta:
model = FloodSub
@ -191,17 +166,22 @@ class PubsubFactory(factory.Factory):
@classmethod
@asynccontextmanager
async def create_and_start(cls, host, router, cache_size):
async def create_and_start(
cls, host: IHost, router: IPubsubRouter, cache_size: int
) -> AsyncIterator[Pubsub]:
pubsub = PubsubFactory(host=host, router=router, cache_size=cache_size)
async with background_trio_service(pubsub):
yield pubsub
@classmethod
@asynccontextmanager
async def create_batch_with_floodsub(
cls, number: int, is_secure: bool = False, cache_size: int = None
):
floodsubs = FloodsubFactory.create_batch(number)
async def _create_batch_with_router(
cls,
number: int,
routers: Sequence[IPubsubRouter],
is_secure: bool = False,
cache_size: int = None,
) -> AsyncIterator[Tuple[Pubsub, ...]]:
async with HostFactory.create_batch_and_listen(is_secure, number) as hosts:
# Pubsubs should exit before hosts
async with AsyncExitStack() as stack:
@ -209,21 +189,80 @@ class PubsubFactory(factory.Factory):
await stack.enter_async_context(
cls.create_and_start(host, router, cache_size)
)
for host, router in zip(hosts, floodsubs)
for host, router in zip(hosts, routers)
]
yield pubsubs
yield tuple(pubsubs)
# @classmethod
# async def create_batch_with_gossipsub(
# cls, number: int, cache_size: int = None, gossipsub_params=GOSSIPSUB_PARAMS
# ):
# ...
@classmethod
@asynccontextmanager
async def create_batch_with_floodsub(
cls,
number: int,
is_secure: bool = False,
cache_size: int = None,
protocols: Sequence[TProtocol] = None,
) -> AsyncIterator[Tuple[Pubsub, ...]]:
if protocols is not None:
floodsubs = FloodsubFactory.create_batch(number, protocols=list(protocols))
else:
floodsubs = FloodsubFactory.create_batch(number)
async with cls._create_batch_with_router(
number, floodsubs, is_secure, cache_size
) as pubsubs:
yield pubsubs
@classmethod
@asynccontextmanager
async def create_batch_with_gossipsub(
cls,
number: int,
*,
is_secure: bool = False,
cache_size: int = None,
protocols: Sequence[TProtocol] = None,
degree: int = GOSSIPSUB_PARAMS.degree,
degree_low: int = GOSSIPSUB_PARAMS.degree_low,
degree_high: int = GOSSIPSUB_PARAMS.degree_high,
time_to_live: int = GOSSIPSUB_PARAMS.time_to_live,
gossip_window: int = GOSSIPSUB_PARAMS.gossip_window,
gossip_history: int = GOSSIPSUB_PARAMS.gossip_history,
heartbeat_interval: float = GOSSIPSUB_PARAMS.heartbeat_interval,
) -> AsyncIterator[Tuple[Pubsub, ...]]:
if protocols is not None:
gossipsubs = GossipsubFactory.create_batch(
number,
protocols=protocols,
degree=degree,
degree_low=degree_low,
degree_high=degree_high,
time_to_live=time_to_live,
gossip_window=gossip_window,
heartbeat_interval=heartbeat_interval,
)
else:
gossipsubs = GossipsubFactory.create_batch(
number,
degree=degree,
degree_low=degree_low,
degree_high=degree_high,
time_to_live=time_to_live,
gossip_window=gossip_window,
heartbeat_interval=heartbeat_interval,
)
async with cls._create_batch_with_router(
number, gossipsubs, is_secure, cache_size
) as pubsubs:
async with AsyncExitStack() as stack:
for router in gossipsubs:
await stack.enter_async_context(background_trio_service(router))
yield pubsubs
@asynccontextmanager
async def swarm_pair_factory(
is_secure: bool, muxer_opt: TMuxerOptions = None
) -> Tuple[Swarm, Swarm]:
) -> AsyncIterator[Tuple[Swarm, Swarm]]:
async with SwarmFactory.create_batch_and_listen(
is_secure, 2, muxer_opt=muxer_opt
) as swarms:
@ -232,7 +271,9 @@ async def swarm_pair_factory(
@asynccontextmanager
async def host_pair_factory(is_secure: bool) -> Tuple[BasicHost, BasicHost]:
async def host_pair_factory(
is_secure: bool
) -> AsyncIterator[Tuple[BasicHost, BasicHost]]:
async with HostFactory.create_batch_and_listen(is_secure, 2) as hosts:
await connect(hosts[0], hosts[1])
yield hosts[0], hosts[1]
@ -241,7 +282,7 @@ async def host_pair_factory(is_secure: bool) -> Tuple[BasicHost, BasicHost]:
@asynccontextmanager
async def swarm_conn_pair_factory(
is_secure: bool, muxer_opt: TMuxerOptions = None
) -> Tuple[SwarmConn, SwarmConn]:
) -> AsyncIterator[Tuple[SwarmConn, SwarmConn]]:
async with swarm_pair_factory(is_secure) as swarms:
conn_0 = swarms[0].connections[swarms[1].get_peer_id()]
conn_1 = swarms[1].connections[swarms[0].get_peer_id()]
@ -249,7 +290,9 @@ async def swarm_conn_pair_factory(
@asynccontextmanager
async def mplex_conn_pair_factory(is_secure: bool) -> Tuple[Mplex, Mplex]:
async def mplex_conn_pair_factory(
is_secure: bool
) -> AsyncIterator[Tuple[Mplex, Mplex]]:
muxer_opt = {MPLEX_PROTOCOL_ID: Mplex}
async with swarm_conn_pair_factory(is_secure, muxer_opt=muxer_opt) as swarm_pair:
yield (
@ -259,21 +302,25 @@ async def mplex_conn_pair_factory(is_secure: bool) -> Tuple[Mplex, Mplex]:
@asynccontextmanager
async def mplex_stream_pair_factory(is_secure: bool) -> Tuple[MplexStream, MplexStream]:
async def mplex_stream_pair_factory(
is_secure: bool
) -> AsyncIterator[Tuple[MplexStream, MplexStream]]:
async with mplex_conn_pair_factory(is_secure) as mplex_conn_pair_info:
mplex_conn_0, mplex_conn_1 = mplex_conn_pair_info
stream_0 = await mplex_conn_0.open_stream()
stream_0 = cast(MplexStream, await mplex_conn_0.open_stream())
await trio.sleep(0.01)
stream_1: MplexStream
async with mplex_conn_1.streams_lock:
if len(mplex_conn_1.streams) != 1:
raise Exception("Mplex should not have any other stream")
stream_1 = tuple(mplex_conn_1.streams.values())[0]
yield cast(MplexStream, stream_0), cast(MplexStream, stream_1)
yield stream_0, stream_1
@asynccontextmanager
async def net_stream_pair_factory(is_secure: bool) -> Tuple[INetStream, INetStream]:
async def net_stream_pair_factory(
is_secure: bool
) -> AsyncIterator[Tuple[INetStream, INetStream]]:
protocol_id = TProtocol("/example/id/1")
stream_1: INetStream
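
The recurring signature change in this file is the return annotation on `@asynccontextmanager` functions: the decorated function is an async generator, so it is annotated as `AsyncIterator` over the yielded type rather than the type itself. A small sketch of the convention with a placeholder resource:

from contextlib import asynccontextmanager
from typing import AsyncIterator, Tuple

@asynccontextmanager
async def resource_pair() -> AsyncIterator[Tuple[str, str]]:
    # Annotated with AsyncIterator[...] because, before decoration, this
    # function is an async generator yielding exactly one value.
    try:
        yield ("left", "right")
    finally:
        pass  # teardown would go here

async def use() -> None:
    async with resource_pair() as (left, right):
        print(left, right)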

View File

@ -1,12 +1,11 @@
import asyncio
from typing import Dict
import uuid
from contextlib import AsyncExitStack, asynccontextmanager
from typing import AsyncIterator, Dict, Tuple
from async_service import Service, background_trio_service
from libp2p.host.host_interface import IHost
from libp2p.pubsub.floodsub import FloodSub
from libp2p.pubsub.pubsub import Pubsub
from libp2p.tools.constants import LISTEN_MADDR
from libp2p.tools.factories import FloodsubFactory, PubsubFactory
from libp2p.tools.factories import PubsubFactory
CRYPTO_TOPIC = "ethereum"
@ -18,7 +17,7 @@ CRYPTO_TOPIC = "ethereum"
# Determine message type by looking at first item before first comma
class DummyAccountNode:
class DummyAccountNode(Service):
"""
Node which has an internal balance mapping, meant to serve as a dummy
crypto blockchain.
@ -27,19 +26,24 @@ class DummyAccountNode:
crypto each user in the mappings holds
"""
libp2p_node: IHost
pubsub: Pubsub
floodsub: FloodSub
def __init__(self, libp2p_node: IHost, pubsub: Pubsub, floodsub: FloodSub):
self.libp2p_node = libp2p_node
def __init__(self, pubsub: Pubsub) -> None:
self.pubsub = pubsub
self.floodsub = floodsub
self.balances: Dict[str, int] = {}
self.node_id = str(uuid.uuid1())
@property
def host(self) -> IHost:
return self.pubsub.host
async def run(self) -> None:
self.subscription = await self.pubsub.subscribe(CRYPTO_TOPIC)
self.manager.run_daemon_task(self.handle_incoming_msgs)
await self.manager.wait_finished()
@classmethod
async def create(cls) -> "DummyAccountNode":
@asynccontextmanager
async def create(cls, number: int) -> AsyncIterator[Tuple["DummyAccountNode", ...]]:
"""
Create `number` DummyAccountNodes, each with a libp2p node, a floodsub,
and a pubsub instance attached.
@ -47,15 +51,17 @@ class DummyAccountNode:
We use `create` because it serves as a factory function and allows us
to use async/await, unlike `__init__`.
"""
pubsub = PubsubFactory(router=FloodsubFactory())
await pubsub.host.get_network().listen(LISTEN_MADDR)
return cls(libp2p_node=pubsub.host, pubsub=pubsub, floodsub=pubsub.router)
async with PubsubFactory.create_batch_with_floodsub(number) as pubsubs:
async with AsyncExitStack() as stack:
dummy_account_nodes = tuple(cls(pubsub) for pubsub in pubsubs)
for node in dummy_account_nodes:
await stack.enter_async_context(background_trio_service(node))
yield dummy_account_nodes
async def handle_incoming_msgs(self) -> None:
"""Handle all incoming messages on the CRYPTO_TOPIC from peers."""
while True:
incoming = await self.q.get()
incoming = await self.subscription.receive()
msg_comps = incoming.data.decode("utf-8").split(",")
if msg_comps[0] == "send":
@ -63,13 +69,6 @@ class DummyAccountNode:
elif msg_comps[0] == "set":
self.handle_set_crypto(msg_comps[1], int(msg_comps[2]))
async def setup_crypto_networking(self) -> None:
"""Subscribe to CRYPTO_TOPIC and perform call to function that handles
all incoming messages on said topic."""
self.q = await self.pubsub.subscribe(CRYPTO_TOPIC)
asyncio.ensure_future(self.handle_incoming_msgs())
async def publish_send_crypto(
self, source_user: str, dest_user: str, amount: int
) -> None:
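
`DummyAccountNode.create` now assembles its whole batch through nested context managers: pubsubs first, then one `background_trio_service` per node pushed onto a shared `AsyncExitStack`, which tears everything down in reverse order on exit. A reduced sketch of that stacking pattern (the `Service` subclass is a stand-in):

from contextlib import AsyncExitStack, asynccontextmanager
from typing import AsyncIterator, Tuple

from async_service import Service, background_trio_service

class Node(Service):
    async def run(self) -> None:
        await self.manager.wait_finished()

@asynccontextmanager
async def node_batch(number: int) -> AsyncIterator[Tuple[Node, ...]]:
    nodes = tuple(Node() for _ in range(number))
    async with AsyncExitStack() as stack:
        for node in nodes:
            # Entered in order; exited in reverse when the stack closes.
            await stack.enter_async_context(background_trio_service(node))
        yield nodes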

View File

@ -1,12 +1,10 @@
# type: ignore
# Typing is better added to this module after the test cases are refactored into classes
import asyncio
import pytest
import trio
from libp2p.tools.constants import FLOODSUB_PROTOCOL_ID, LISTEN_MADDR
from libp2p.tools.factories import PubsubFactory
from libp2p.tools.constants import FLOODSUB_PROTOCOL_ID
from libp2p.tools.utils import connect
SUPPORTED_PROTOCOLS = [FLOODSUB_PROTOCOL_ID]
@ -15,6 +13,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [
{
"name": "simple_two_nodes",
"supported_protocols": SUPPORTED_PROTOCOLS,
"nodes": ["A", "B"],
"adj_list": {"A": ["B"]},
"topic_map": {"topic1": ["B"]},
"messages": [{"topics": ["topic1"], "data": b"foo", "node_id": "A"}],
@ -22,6 +21,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [
{
"name": "three_nodes_two_topics",
"supported_protocols": SUPPORTED_PROTOCOLS,
"nodes": ["A", "B", "C"],
"adj_list": {"A": ["B"], "B": ["C"]},
"topic_map": {"topic1": ["B", "C"], "topic2": ["B", "C"]},
"messages": [
@ -32,6 +32,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [
{
"name": "two_nodes_one_topic_single_subscriber_is_sender",
"supported_protocols": SUPPORTED_PROTOCOLS,
"nodes": ["A", "B"],
"adj_list": {"A": ["B"]},
"topic_map": {"topic1": ["B"]},
"messages": [{"topics": ["topic1"], "data": b"Alex is tall", "node_id": "B"}],
@ -39,6 +40,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [
{
"name": "two_nodes_one_topic_two_msgs",
"supported_protocols": SUPPORTED_PROTOCOLS,
"nodes": ["A", "B"],
"adj_list": {"A": ["B"]},
"topic_map": {"topic1": ["B"]},
"messages": [
@ -49,6 +51,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [
{
"name": "seven_nodes_tree_one_topics",
"supported_protocols": SUPPORTED_PROTOCOLS,
"nodes": ["1", "2", "3", "4", "5", "6", "7"],
"adj_list": {"1": ["2", "3"], "2": ["4", "5"], "3": ["6", "7"]},
"topic_map": {"astrophysics": ["2", "3", "4", "5", "6", "7"]},
"messages": [{"topics": ["astrophysics"], "data": b"e=mc^2", "node_id": "1"}],
@ -56,6 +59,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [
{
"name": "seven_nodes_tree_three_topics",
"supported_protocols": SUPPORTED_PROTOCOLS,
"nodes": ["1", "2", "3", "4", "5", "6", "7"],
"adj_list": {"1": ["2", "3"], "2": ["4", "5"], "3": ["6", "7"]},
"topic_map": {
"astrophysics": ["2", "3", "4", "5", "6", "7"],
@ -71,6 +75,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [
{
"name": "seven_nodes_tree_three_topics_diff_origin",
"supported_protocols": SUPPORTED_PROTOCOLS,
"nodes": ["1", "2", "3", "4", "5", "6", "7"],
"adj_list": {"1": ["2", "3"], "2": ["4", "5"], "3": ["6", "7"]},
"topic_map": {
"astrophysics": ["1", "2", "3", "4", "5", "6", "7"],
@ -86,6 +91,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [
{
"name": "three_nodes_clique_two_topic_diff_origin",
"supported_protocols": SUPPORTED_PROTOCOLS,
"nodes": ["1", "2", "3"],
"adj_list": {"1": ["2", "3"], "2": ["3"]},
"topic_map": {"astrophysics": ["1", "2", "3"], "school": ["1", "2", "3"]},
"messages": [
@ -97,6 +103,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [
{
"name": "four_nodes_clique_two_topic_diff_origin_many_msgs",
"supported_protocols": SUPPORTED_PROTOCOLS,
"nodes": ["1", "2", "3", "4"],
"adj_list": {
"1": ["2", "3", "4"],
"2": ["1", "3", "4"],
@ -120,6 +127,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [
{
"name": "five_nodes_ring_two_topic_diff_origin_many_msgs",
"supported_protocols": SUPPORTED_PROTOCOLS,
"nodes": ["1", "2", "3", "4", "5"],
"adj_list": {"1": ["2"], "2": ["3"], "3": ["4"], "4": ["5"], "5": ["1"]},
"topic_map": {
"astrophysics": ["1", "2", "3", "4", "5"],
@ -143,7 +151,7 @@ floodsub_protocol_pytest_params = [
]
async def perform_test_from_obj(obj, router_factory) -> None:
async def perform_test_from_obj(obj, pubsub_factory) -> None:
"""
Perform pubsub tests from a test obj.
Test objects are composed as follows:
@ -174,88 +182,75 @@ async def perform_test_from_obj(obj, router_factory) -> None:
# Step 1) Create graph
adj_list = obj["adj_list"]
node_list = obj["nodes"]
node_map = {}
pubsub_map = {}
async def add_node(node_id_str: str) -> None:
pubsub_router = router_factory(protocols=obj["supported_protocols"])
pubsub = PubsubFactory(router=pubsub_router)
await pubsub.host.get_network().listen(LISTEN_MADDR)
node_map[node_id_str] = pubsub.host
pubsub_map[node_id_str] = pubsub
async with pubsub_factory(
number=len(node_list), protocols=obj["supported_protocols"]
) as pubsubs:
for node_id_str, pubsub in zip(node_list, pubsubs):
node_map[node_id_str] = pubsub.host
pubsub_map[node_id_str] = pubsub
tasks_connect = []
for start_node_id in adj_list:
# Create node if node does not yet exist
if start_node_id not in node_map:
await add_node(start_node_id)
# Connect nodes and wait at least for 2 seconds
async with trio.open_nursery() as nursery:
for start_node_id in adj_list:
# Connect start_node to each of its neighbors.
for neighbor_id in adj_list[start_node_id]:
nursery.start_soon(
connect, node_map[start_node_id], node_map[neighbor_id]
)
nursery.start_soon(trio.sleep, 2)
# For each neighbor of start_node, create if does not yet exist,
# then connect start_node to neighbor
for neighbor_id in adj_list[start_node_id]:
# Create neighbor if neighbor does not yet exist
if neighbor_id not in node_map:
await add_node(neighbor_id)
tasks_connect.append(
connect(node_map[start_node_id], node_map[neighbor_id])
)
# Connect nodes and wait at least for 2 seconds
await asyncio.gather(*tasks_connect, asyncio.sleep(2))
# Step 2) Subscribe to topics
queues_map = {}
topic_map = obj["topic_map"]
# Step 2) Subscribe to topics
queues_map = {}
topic_map = obj["topic_map"]
async def subscribe_node(node_id, topic):
if node_id not in queues_map:
queues_map[node_id] = {}
# Avoid repeated work
if topic in queues_map[node_id]:
# Checkpoint
await trio.hazmat.checkpoint()
return
sub = await pubsub_map[node_id].subscribe(topic)
queues_map[node_id][topic] = sub
tasks_topic = []
tasks_topic_data = []
for topic, node_ids in topic_map.items():
for node_id in node_ids:
tasks_topic.append(pubsub_map[node_id].subscribe(topic))
tasks_topic_data.append((node_id, topic))
tasks_topic.append(asyncio.sleep(2))
async with trio.open_nursery() as nursery:
for topic, node_ids in topic_map.items():
for node_id in node_ids:
nursery.start_soon(subscribe_node, node_id, topic)
nursery.start_soon(trio.sleep, 2)
# Gather is like Promise.all
responses = await asyncio.gather(*tasks_topic)
for i in range(len(responses) - 1):
node_id, topic = tasks_topic_data[i]
if node_id not in queues_map:
queues_map[node_id] = {}
# Store queue in topic-queue map for node
queues_map[node_id][topic] = responses[i]
# Step 3) Publish messages
topics_in_msgs_ordered = []
messages = obj["messages"]
# Allow time for subscribing before continuing
await asyncio.sleep(0.01)
for msg in messages:
topics = msg["topics"]
data = msg["data"]
node_id = msg["node_id"]
# Step 3) Publish messages
topics_in_msgs_ordered = []
messages = obj["messages"]
tasks_publish = []
# Publish message
# TODO: Should be single RPC package with several topics
for topic in topics:
await pubsub_map[node_id].publish(topic, data)
for msg in messages:
topics = msg["topics"]
data = msg["data"]
node_id = msg["node_id"]
# For each topic in topics, add (topic, node_id, data) tuple to ordered test list
for topic in topics:
topics_in_msgs_ordered.append((topic, node_id, data))
# Allow time for publishing before continuing
await trio.sleep(1)
# Publish message
# TODO: Should be single RPC package with several topics
for topic in topics:
tasks_publish.append(pubsub_map[node_id].publish(topic, data))
# For each topic in topics, add (topic, node_id, data) tuple to ordered test list
for topic in topics:
topics_in_msgs_ordered.append((topic, node_id, data))
# Allow time for publishing before continuing
await asyncio.gather(*tasks_publish, asyncio.sleep(2))
# Step 4) Check that all messages were received correctly.
for topic, origin_node_id, data in topics_in_msgs_ordered:
# Look at each node in each topic
for node_id in topic_map[topic]:
# Get message from subscription queue
msg = await queues_map[node_id][topic].get()
assert data == msg.data
# Check the message origin
assert node_map[origin_node_id].get_id().to_bytes() == msg.from_id
# Success, terminate pending tasks.
# Step 4) Check that all messages were received correctly.
for topic, origin_node_id, data in topics_in_msgs_ordered:
# Look at each node in each topic
for node_id in topic_map[topic]:
# Get message from subscription queue
msg = await queues_map[node_id][topic].receive()
assert data == msg.data
# Check the message origin
assert node_map[origin_node_id].get_id().to_bytes() == msg.from_id
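
Step 1 above translates `asyncio.gather(*tasks_connect, asyncio.sleep(2))` into a nursery that runs every `connect` plus a two-second sleep as sibling tasks; the `async with` block exits only when all of them complete. A self-contained sketch of that gather-to-nursery translation (`connect` here is a stand-in coroutine):

import trio

async def connect(a: str, b: str) -> None:
    await trio.sleep(0.1)  # stand-in for dialing a -> b

async def connect_all(adj_list: dict) -> None:
    async with trio.open_nursery() as nursery:
        for start, neighbors in adj_list.items():
            for neighbor in neighbors:
                nursery.start_soon(connect, start, neighbor)
        # Sibling task: keeps the block open at least 2 seconds, mirroring
        # the old `asyncio.gather(*tasks, asyncio.sleep(2))`.
        nursery.start_soon(trio.sleep, 2)

trio.run(connect_all, {"A": ["B"], "B": ["C"]})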

View File

@ -1,17 +1,9 @@
from typing import Callable, List, Sequence, Tuple
from typing import Awaitable, Callable
import multiaddr
import trio
from libp2p import new_node
from libp2p.host.basic_host import BasicHost
from libp2p.host.host_interface import IHost
from libp2p.kademlia.network import KademliaServer
from libp2p.network.stream.net_stream_interface import INetStream
from libp2p.network.swarm import Swarm
from libp2p.peer.peerinfo import info_from_p2p_addr
from libp2p.routing.interfaces import IPeerRouting
from libp2p.routing.kademlia.kademlia_peer_router import KadmeliaPeerRouter
from .constants import MAX_READ_LEN
@ -36,49 +28,9 @@ async def connect(node1: IHost, node2: IHost) -> None:
await node1.connect(info)
async def set_up_nodes_by_transport_opt(
transport_opt_list: Sequence[Sequence[str]], nursery: trio.Nursery
) -> Tuple[BasicHost, ...]:
nodes_list = []
for transport_opt in transport_opt_list:
node = new_node(transport_opt=transport_opt)
await node.get_network().listen(
multiaddr.Multiaddr(transport_opt[0]), nursery=nursery
)
nodes_list.append(node)
return tuple(nodes_list)
async def set_up_nodes_by_transport_and_disc_opt(
transport_disc_opt_list: Sequence[Tuple[Sequence[str], IPeerRouting]]
) -> Tuple[BasicHost, ...]:
nodes_list = []
for transport_opt, disc_opt in transport_disc_opt_list:
node = await new_node(transport_opt=transport_opt, disc_opt=disc_opt)
await node.get_network().listen(multiaddr.Multiaddr(transport_opt[0]))
nodes_list.append(node)
return tuple(nodes_list)
async def set_up_routers(
router_ports: Tuple[int, ...] = (0, 0)
) -> List[KadmeliaPeerRouter]:
"""The default ``router_confs`` selects two free ports local to this
machine."""
bootstrap_node = KademliaServer() # type: ignore
await bootstrap_node.listen(router_ports[0])
routers = [KadmeliaPeerRouter(bootstrap_node)]
for port in router_ports[1:]:
node = KademliaServer() # type: ignore
await node.listen(port)
await node.bootstrap_node(bootstrap_node.address)
routers.append(KadmeliaPeerRouter(node))
return routers
def create_echo_stream_handler(ack_prefix: str) -> Callable[[INetStream], None]:
def create_echo_stream_handler(
ack_prefix: str
) -> Callable[[INetStream], Awaitable[None]]:
async def echo_stream_handler(stream: INetStream) -> None:
while True:
read_string = (await stream.read(MAX_READ_LEN)).decode()

View File

@ -1,12 +1,13 @@
from abc import ABC, abstractmethod
from typing import List
from typing import Tuple
from multiaddr import Multiaddr
import trio
class IListener(ABC):
@abstractmethod
async def listen(self, maddr: Multiaddr) -> bool:
async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> bool:
"""
put listener in listening mode and wait for incoming connections.
@ -15,14 +16,9 @@ class IListener(ABC):
"""
@abstractmethod
def get_addrs(self) -> List[Multiaddr]:
def get_addrs(self) -> Tuple[Multiaddr, ...]:
"""
retrieve list of addresses the listener is listening on.
:return: list of addrs
"""
@abstractmethod
async def close(self) -> None:
"""close the listener such that no more connections can be open on this
transport instance."""

View File

@ -1,14 +1,13 @@
import logging
from socket import socket
from typing import List
from typing import Awaitable, Callable, List, Sequence, Tuple
from multiaddr import Multiaddr
import trio
from trio_typing import TaskStatus
from libp2p.io.trio import TrioReadWriteCloser
from libp2p.io.trio import TrioTCPStream
from libp2p.network.connection.raw_connection import RawConnection
from libp2p.network.connection.raw_connection_interface import IRawConnection
from libp2p.transport.exceptions import OpenConnectionError
from libp2p.transport.listener_interface import IListener
from libp2p.transport.transport_interface import ITransport
from libp2p.transport.typing import THandler
@ -18,14 +17,12 @@ logger = logging.getLogger("libp2p.transport.tcp")
class TCPListener(IListener):
multiaddrs: List[Multiaddr]
server = None
def __init__(self, handler_function: THandler) -> None:
self.multiaddrs = []
self.server = None
self.handler = handler_function
# TODO: Fix handling?
# TODO: Get rid of `nursery`?
async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> None:
"""
put listener in listening mode and wait for incoming connections.
@ -34,13 +31,18 @@ class TCPListener(IListener):
:return: True if successful
"""
async def serve_tcp(handler, port, host, task_status=None):
async def serve_tcp(
handler: Callable[[trio.SocketStream], Awaitable[None]],
port: int,
host: str,
task_status: TaskStatus[Sequence[trio.SocketListener]] = None,
) -> None:
logger.debug("serve_tcp %s %s", host, port)
await trio.serve_tcp(handler, port, host=host, task_status=task_status)
async def handler(stream):
read_write_closer = TrioReadWriteCloser(stream)
await self.handler(read_write_closer)
async def handler(stream: trio.SocketStream) -> None:
tcp_stream = TrioTCPStream(stream)
await self.handler(tcp_stream)
listeners = await nursery.start(
serve_tcp,
@ -51,7 +53,7 @@ class TCPListener(IListener):
socket = listeners[0].socket
self.multiaddrs.append(_multiaddr_from_socket(socket))
def get_addrs(self) -> List[Multiaddr]:
def get_addrs(self) -> Tuple[Multiaddr, ...]:
"""
retrieve list of addresses the listener is listening on.
@ -59,15 +61,6 @@ class TCPListener(IListener):
"""
return tuple(self.multiaddrs)
async def close(self) -> None:
"""close the listener such that no more connections can be open on this
transport instance."""
if self.server is None:
return
self.server.close()
await self.server.wait_closed()
self.server = None
class TCP(ITransport):
async def dial(self, maddr: Multiaddr) -> IRawConnection:
@ -82,7 +75,7 @@ class TCP(ITransport):
self.port = int(maddr.value_for_protocol("tcp"))
stream = await trio.open_tcp_stream(self.host, self.port)
read_write_closer = TrioReadWriteCloser(stream)
read_write_closer = TrioTCPStream(stream)
return RawConnection(read_write_closer, True)
@ -97,5 +90,6 @@ class TCP(ITransport):
return TCPListener(handler_function)
def _multiaddr_from_socket(socket: socket) -> Multiaddr:
return Multiaddr("/ip4/%s/tcp/%s" % socket.getsockname())
def _multiaddr_from_socket(socket: trio.socket.SocketType) -> Multiaddr:
ip, port = socket.getsockname() # type: ignore
return Multiaddr(f"/ip4/{ip}/tcp/{port}")
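
`TCPListener.listen` leans on `nursery.start`: `trio.serve_tcp` reports readiness by calling `task_status.started(listeners)`, and `start` returns those live `SocketListener`s to the caller, which is how the bound port is recovered for the multiaddr. A minimal sketch of that handoff using nothing beyond `trio`:

import trio

async def handler(stream: trio.SocketStream) -> None:
    await stream.send_all(b"hello\n")

async def main() -> None:
    async with trio.open_nursery() as nursery:
        # `nursery.start` waits until `trio.serve_tcp` calls
        # `task_status.started(...)` with the bound listeners.
        listeners = await nursery.start(trio.serve_tcp, handler, 0)
        port = listeners[0].socket.getsockname()[1]  # port 0: OS picked one
        print("listening on", port)
        nursery.cancel_scope.cancel()  # stop serving for this demo

trio.run(main)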

View File

@ -1,7 +1,7 @@
import trio
import secrets
import pytest
import trio
from libp2p.host.ping import ID, PING_LENGTH
from libp2p.tools.factories import host_pair_factory

View File

@ -1,73 +0,0 @@
import pytest
from libp2p.host.exceptions import ConnectionFailure
from libp2p.peer.peerinfo import PeerInfo
from libp2p.routing.kademlia.kademlia_peer_router import peer_info_to_str
from libp2p.tools.utils import (
set_up_nodes_by_transport_and_disc_opt,
set_up_nodes_by_transport_opt,
set_up_routers,
)
from libp2p.tools.factories import RoutedHostFactory
# FIXME:
# TODO: Kademlia is full of asyncio code. Skip it for now
@pytest.mark.skip
@pytest.mark.trio
async def test_host_routing_success(is_host_secure):
async with RoutedHostFactory.create_batch_and_listen(
is_host_secure, 2
) as routed_hosts:
# Set routing info
await routed_hosts[0]._router.server.set(
routed_hosts[0].get_id().xor_id,
peer_info_to_str(
PeerInfo(routed_hosts[0].get_id(), routed_hosts[0].get_addrs())
),
)
await routed_hosts[1]._router.server.set(
routed_hosts[1].get_id().xor_id,
peer_info_to_str(
PeerInfo(routed_hosts[1].get_id(), routed_hosts[1].get_addrs())
),
)
# forces to use routing as no addrs are provided
await routed_hosts[0].connect(PeerInfo(routed_hosts[1].get_id(), []))
await routed_hosts[1].connect(PeerInfo(routed_hosts[0].get_id(), []))
# TODO: Kademlia is full of asyncio code. Skip it for now
@pytest.mark.skip
@pytest.mark.trio
async def test_host_routing_fail():
routers = await set_up_routers()
transports = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]]
transport_disc_opt_list = zip(transports, routers)
(host_a, host_b) = await set_up_nodes_by_transport_and_disc_opt(
transport_disc_opt_list
)
host_c = (await set_up_nodes_by_transport_opt([["/ip4/127.0.0.1/tcp/0"]]))[0]
# Set routing info
await routers[0].server.set(
host_a.get_id().xor_id,
peer_info_to_str(PeerInfo(host_a.get_id(), host_a.get_addrs())),
)
await routers[1].server.set(
host_b.get_id().xor_id,
peer_info_to_str(PeerInfo(host_b.get_id(), host_b.get_addrs())),
)
# routing fails because host_c does not use routing
with pytest.raises(ConnectionFailure):
await host_a.connect(PeerInfo(host_c.get_id(), []))
with pytest.raises(ConnectionFailure):
await host_b.connect(PeerInfo(host_c.get_id(), []))
# Clean up
routers[0].server.stop()
routers[1].server.stop()

View File

@ -4,7 +4,6 @@ from libp2p.host.exceptions import StreamFailure
from libp2p.tools.factories import HostFactory
from libp2p.tools.utils import create_echo_stream_handler
PROTOCOL_ECHO = "/echo/1.0.0"
PROTOCOL_POTATO = "/potato/1.0.0"
PROTOCOL_FOO = "/foo/1.0.0"

View File

@ -1,22 +0,0 @@
import pytest
from libp2p.tools.constants import GOSSIPSUB_PARAMS
from libp2p.tools.factories import FloodsubFactory, GossipsubFactory, PubsubFactory
@pytest.fixture
def pubsub_cache_size():
return None # default
@pytest.fixture
def gossipsub_params():
return GOSSIPSUB_PARAMS
# @pytest.fixture
# def pubsubs_gsub(num_hosts, hosts, pubsub_cache_size, gossipsub_params):
# gossipsubs = GossipsubFactory.create_batch(num_hosts, **gossipsub_params._asdict())
# _pubsubs_gsub = _make_pubsubs(hosts, gossipsubs, pubsub_cache_size)
# yield _pubsubs_gsub
# # TODO: Clean up

View File

@ -1,19 +1,10 @@
import asyncio
from threading import Thread
import pytest
import trio
from libp2p.tools.pubsub.dummy_account_node import DummyAccountNode
from libp2p.tools.utils import connect
def create_setup_in_new_thread_func(dummy_node):
def setup_in_new_thread():
asyncio.ensure_future(dummy_node.setup_crypto_networking())
return setup_in_new_thread
async def perform_test(num_nodes, adjacency_map, action_func, assertion_func):
"""
Helper function to allow for easy construction of custom tests for dummy
@ -26,47 +17,35 @@ async def perform_test(num_nodes, adjacency_map, action_func, assertion_func):
:param assertion_func: assertions for testing the results of the actions are correct
"""
# Create nodes
dummy_nodes = []
for _ in range(num_nodes):
dummy_nodes.append(await DummyAccountNode.create())
async with DummyAccountNode.create(num_nodes) as dummy_nodes:
# Create connections between nodes according to `adjacency_map`
async with trio.open_nursery() as nursery:
for source_num in adjacency_map:
target_nums = adjacency_map[source_num]
for target_num in target_nums:
nursery.start_soon(
connect,
dummy_nodes[source_num].host,
dummy_nodes[target_num].host,
)
# Create network
for source_num in adjacency_map:
target_nums = adjacency_map[source_num]
for target_num in target_nums:
await connect(
dummy_nodes[source_num].libp2p_node, dummy_nodes[target_num].libp2p_node
)
# Allow time for network creation to take place
await trio.sleep(0.25)
# Allow time for network creation to take place
await asyncio.sleep(0.25)
# Perform action function
await action_func(dummy_nodes)
# Start a thread for each node so that each node can listen and respond
# to messages on its own thread, which will avoid waiting indefinitely
# on the main thread. On this thread, call the setup func for the node,
# which subscribes the node to the CRYPTO_TOPIC topic
for dummy_node in dummy_nodes:
thread = Thread(target=create_setup_in_new_thread_func(dummy_node))
thread.run()
# Allow time for action function to be performed (i.e. messages to propagate)
await trio.sleep(1)
# Allow time for nodes to subscribe to CRYPTO_TOPIC topic
await asyncio.sleep(0.25)
# Perform action function
await action_func(dummy_nodes)
# Allow time for action function to be performed (i.e. messages to propagate)
await asyncio.sleep(1)
# Perform assertion function
for dummy_node in dummy_nodes:
assertion_func(dummy_node)
# Perform assertion function
for dummy_node in dummy_nodes:
assertion_func(dummy_node)
# Success, terminate pending tasks.
@pytest.mark.asyncio
@pytest.mark.trio
async def test_simple_two_nodes():
num_nodes = 2
adj_map = {0: [1]}
@ -80,7 +59,7 @@ async def test_simple_two_nodes():
await perform_test(num_nodes, adj_map, action_func, assertion_func)
@pytest.mark.asyncio
@pytest.mark.trio
async def test_simple_three_nodes_line_topography():
num_nodes = 3
adj_map = {0: [1], 1: [2]}
@ -94,7 +73,7 @@ async def test_simple_three_nodes_line_topography():
await perform_test(num_nodes, adj_map, action_func, assertion_func)
@pytest.mark.asyncio
@pytest.mark.trio
async def test_simple_three_nodes_triangle_topography():
num_nodes = 3
adj_map = {0: [1, 2], 1: [2]}
@ -108,7 +87,7 @@ async def test_simple_three_nodes_triangle_topography():
await perform_test(num_nodes, adj_map, action_func, assertion_func)
@pytest.mark.asyncio
@pytest.mark.trio
async def test_simple_seven_nodes_tree_topography():
num_nodes = 7
adj_map = {0: [1, 2], 1: [3, 4], 2: [5, 6]}
@ -122,14 +101,14 @@ async def test_simple_seven_nodes_tree_topography():
await perform_test(num_nodes, adj_map, action_func, assertion_func)
@pytest.mark.asyncio
@pytest.mark.trio
async def test_set_then_send_from_root_seven_nodes_tree_topography():
num_nodes = 7
adj_map = {0: [1, 2], 1: [3, 4], 2: [5, 6]}
async def action_func(dummy_nodes):
await dummy_nodes[0].publish_set_crypto("aspyn", 20)
await asyncio.sleep(0.25)
await trio.sleep(0.25)
await dummy_nodes[0].publish_send_crypto("aspyn", "alex", 5)
def assertion_func(dummy_node):
@ -139,14 +118,14 @@ async def test_set_then_send_from_root_seven_nodes_tree_topography():
await perform_test(num_nodes, adj_map, action_func, assertion_func)
@pytest.mark.asyncio
@pytest.mark.trio
async def test_set_then_send_from_different_leafs_seven_nodes_tree_topography():
num_nodes = 7
adj_map = {0: [1, 2], 1: [3, 4], 2: [5, 6]}
async def action_func(dummy_nodes):
await dummy_nodes[6].publish_set_crypto("aspyn", 20)
await asyncio.sleep(0.25)
await trio.sleep(0.25)
await dummy_nodes[4].publish_send_crypto("aspyn", "alex", 5)
def assertion_func(dummy_node):
@ -156,7 +135,7 @@ async def test_set_then_send_from_different_leafs_seven_nodes_tree_topography():
await perform_test(num_nodes, adj_map, action_func, assertion_func)
@pytest.mark.asyncio
@pytest.mark.trio
async def test_simple_five_nodes_ring_topography():
num_nodes = 5
adj_map = {0: [1], 1: [2], 2: [3], 3: [4], 4: [0]}
@ -170,14 +149,14 @@ async def test_simple_five_nodes_ring_topography():
await perform_test(num_nodes, adj_map, action_func, assertion_func)
@pytest.mark.asyncio
@pytest.mark.trio
async def test_set_then_send_from_diff_nodes_five_nodes_ring_topography():
num_nodes = 5
adj_map = {0: [1], 1: [2], 2: [3], 3: [4], 4: [0]}
async def action_func(dummy_nodes):
await dummy_nodes[0].publish_set_crypto("alex", 20)
await asyncio.sleep(0.25)
await trio.sleep(0.25)
await dummy_nodes[3].publish_send_crypto("alex", "rob", 12)
def assertion_func(dummy_node):
@ -187,7 +166,7 @@ async def test_set_then_send_from_diff_nodes_five_nodes_ring_topography():
await perform_test(num_nodes, adj_map, action_func, assertion_func)
@pytest.mark.asyncio
@pytest.mark.trio
@pytest.mark.slow
async def test_set_then_send_from_five_diff_nodes_five_nodes_ring_topography():
num_nodes = 5
@ -195,13 +174,13 @@ async def test_set_then_send_from_five_diff_nodes_five_nodes_ring_topography():
async def action_func(dummy_nodes):
await dummy_nodes[0].publish_set_crypto("alex", 20)
await asyncio.sleep(1)
await trio.sleep(1)
await dummy_nodes[1].publish_send_crypto("alex", "rob", 3)
await asyncio.sleep(1)
await trio.sleep(1)
await dummy_nodes[2].publish_send_crypto("rob", "aspyn", 2)
await asyncio.sleep(1)
await trio.sleep(1)
await dummy_nodes[3].publish_send_crypto("aspyn", "zx", 1)
await asyncio.sleep(1)
await trio.sleep(1)
await dummy_nodes[4].publish_send_crypto("zx", "raul", 1)
def assertion_func(dummy_node):

View File

@ -1,9 +1,10 @@
import asyncio
import functools
import pytest
import trio
from libp2p.peer.id import ID
from libp2p.tools.factories import FloodsubFactory
from libp2p.tools.factories import PubsubFactory
from libp2p.tools.pubsub.floodsub_integration_test_settings import (
floodsub_protocol_pytest_params,
perform_test_from_obj,
@ -11,79 +12,83 @@ from libp2p.tools.pubsub.floodsub_integration_test_settings import (
from libp2p.tools.utils import connect
@pytest.mark.parametrize("num_hosts", (2,))
@pytest.mark.asyncio
async def test_simple_two_nodes(pubsubs_fsub):
topic = "my_topic"
data = b"some data"
@pytest.mark.trio
async def test_simple_two_nodes():
async with PubsubFactory.create_batch_with_floodsub(2) as pubsubs_fsub:
topic = "my_topic"
data = b"some data"
await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host)
await asyncio.sleep(0.25)
await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host)
await trio.sleep(0.25)
sub_b = await pubsubs_fsub[1].subscribe(topic)
# Sleep to let node a learn of node b's subscription
await asyncio.sleep(0.25)
sub_b = await pubsubs_fsub[1].subscribe(topic)
# Sleep to let node a learn of node b's subscription
await trio.sleep(0.25)
await pubsubs_fsub[0].publish(topic, data)
await pubsubs_fsub[0].publish(topic, data)
res_b = await sub_b.get()
res_b = await sub_b.receive()
# Check that the msg received by node_b is the same
# as the message sent by node_a
assert ID(res_b.from_id) == pubsubs_fsub[0].host.get_id()
assert res_b.data == data
assert res_b.topicIDs == [topic]
# Success, terminate pending tasks.
# Check that the msg received by node_b is the same
# as the message sent by node_a
assert ID(res_b.from_id) == pubsubs_fsub[0].host.get_id()
assert res_b.data == data
assert res_b.topicIDs == [topic]
# Initialize Pubsub with a cache_size of 4
@pytest.mark.parametrize("num_hosts, pubsub_cache_size", ((2, 4),))
@pytest.mark.asyncio
async def test_lru_cache_two_nodes(pubsubs_fsub, monkeypatch):
@pytest.mark.trio
async def test_lru_cache_two_nodes(monkeypatch):
# two nodes with cache_size of 4
# `node_a` send the following messages to node_b
message_indices = [1, 1, 2, 1, 3, 1, 4, 1, 5, 1]
# `node_b` should only receive the following
expected_received_indices = [1, 2, 3, 4, 5, 1]
async with PubsubFactory.create_batch_with_floodsub(
2, cache_size=4
) as pubsubs_fsub:
# `node_a` send the following messages to node_b
message_indices = [1, 1, 2, 1, 3, 1, 4, 1, 5, 1]
# `node_b` should only receive the following
expected_received_indices = [1, 2, 3, 4, 5, 1]
topic = "my_topic"
topic = "my_topic"
# Mock `get_msg_id` to make it easier to manipulate `msg_id` via `data`.
def get_msg_id(msg):
# Originally it is `(msg.seqno, msg.from_id)`
return (msg.data, msg.from_id)
# Mock `get_msg_id` to make it easier to manipulate `msg_id` via `data`.
def get_msg_id(msg):
# Originally it is `(msg.seqno, msg.from_id)`
return (msg.data, msg.from_id)
import libp2p.pubsub.pubsub
import libp2p.pubsub.pubsub
monkeypatch.setattr(libp2p.pubsub.pubsub, "get_msg_id", get_msg_id)
monkeypatch.setattr(libp2p.pubsub.pubsub, "get_msg_id", get_msg_id)
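With `msg_id` keyed on `(data, from_id)`, the behaviour under test reduces to an LRU seen-cache of size 4. A minimal illustrative sketch, not the real `libp2p.pubsub` cache (which expires entries per heartbeat), reproducing the expected indices:

from collections import OrderedDict

class LastSeenCache:
    # Illustrative LRU-style seen-cache keyed by msg_id.
    def __init__(self, cache_size: int) -> None:
        self.cache_size = cache_size
        self._seen: "OrderedDict[bytes, None]" = OrderedDict()

    def add(self, msg_id: bytes) -> bool:
        # True if the message is new, False if it is a repeat.
        if msg_id in self._seen:
            return False
        self._seen[msg_id] = None
        if len(self._seen) > self.cache_size:
            self._seen.popitem(last=False)  # evict the oldest id
        return True

cache = LastSeenCache(4)
received = [
    i
    for i in [1, 1, 2, 1, 3, 1, 4, 1, 5, 1]
    if cache.add(i.to_bytes(4, "big"))
]
assert received == [1, 2, 3, 4, 5, 1]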
await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host)
await asyncio.sleep(0.25)
await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host)
await trio.sleep(0.25)
sub_b = await pubsubs_fsub[1].subscribe(topic)
await asyncio.sleep(0.25)
sub_b = await pubsubs_fsub[1].subscribe(topic)
await trio.sleep(0.25)
def _make_testing_data(i: int) -> bytes:
num_int_bytes = 4
if i >= 2 ** (num_int_bytes * 8):
raise ValueError("integer is too large to be serialized")
return b"data" + i.to_bytes(num_int_bytes, "big")
def _make_testing_data(i: int) -> bytes:
num_int_bytes = 4
if i >= 2 ** (num_int_bytes * 8):
raise ValueError("integer is too large to be serialized")
return b"data" + i.to_bytes(num_int_bytes, "big")
for index in message_indices:
await pubsubs_fsub[0].publish(topic, _make_testing_data(index))
await asyncio.sleep(0.25)
for index in message_indices:
await pubsubs_fsub[0].publish(topic, _make_testing_data(index))
await trio.sleep(0.25)
for index in expected_received_indices:
res_b = await sub_b.get()
assert res_b.data == _make_testing_data(index)
assert sub_b.empty()
for index in expected_received_indices:
res_b = await sub_b.receive()
assert res_b.data == _make_testing_data(index)
# Success, terminate pending tasks.
with pytest.raises(trio.WouldBlock):
sub_b.receive_nowait()
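Under trio, subscriptions are backed by memory channels, so "queue is empty" is expressed as `receive_nowait()` raising `trio.WouldBlock` rather than asyncio's `Queue.empty()`. A standalone sketch of that pattern:

import trio

async def _demo() -> None:
    send, receive = trio.open_memory_channel(1)
    await send.send("msg")
    assert receive.receive_nowait() == "msg"
    try:
        receive.receive_nowait()  # channel is drained now
    except trio.WouldBlock:
        pass

trio.run(_demo)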
@pytest.mark.parametrize("test_case_obj", floodsub_protocol_pytest_params)
@pytest.mark.asyncio
@pytest.mark.trio
@pytest.mark.slow
async def test_gossipsub_run_with_floodsub_tests(test_case_obj):
await perform_test_from_obj(test_case_obj, FloodsubFactory)
async def test_gossipsub_run_with_floodsub_tests(test_case_obj, is_host_secure):
await perform_test_from_obj(
test_case_obj,
functools.partial(
PubsubFactory.create_batch_with_floodsub, is_secure=is_host_secure
),
)
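`functools.partial` freezes the factory's keyword options so `perform_test_from_obj` can invoke it with just a count. The same pattern in isolation, with a hypothetical `create_batch` standing in for the PubsubFactory helpers:

import functools

def create_batch(n, *, degree=6, time_to_live=60):
    # Hypothetical factory standing in for PubsubFactory helpers.
    return [{"degree": degree, "ttl": time_to_live} for _ in range(n)]

factory = functools.partial(create_batch, degree=3, time_to_live=30)
assert factory(2) == [{"degree": 3, "ttl": 30}, {"degree": 3, "ttl": 30}]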

View File

@ -1,368 +1,350 @@
import asyncio
import random
import pytest
import trio
from libp2p.tools.constants import GossipsubParams
from libp2p.tools.factories import PubsubFactory
from libp2p.tools.pubsub.utils import dense_connect, one_to_all_connect
from libp2p.tools.utils import connect
@pytest.mark.parametrize(
"num_hosts, gossipsub_params",
((4, GossipsubParams(degree=4, degree_low=3, degree_high=5)),),
)
@pytest.mark.asyncio
async def test_join(num_hosts, hosts, pubsubs_gsub):
gossipsubs = tuple(pubsub.router for pubsub in pubsubs_gsub)
hosts_indices = list(range(num_hosts))
@pytest.mark.trio
async def test_join():
async with PubsubFactory.create_batch_with_gossipsub(
4, degree=4, degree_low=3, degree_high=5
) as pubsubs_gsub:
gossipsubs = [pubsub.router for pubsub in pubsubs_gsub]
hosts = [pubsub.host for pubsub in pubsubs_gsub]
hosts_indices = list(range(len(pubsubs_gsub)))
topic = "test_join"
central_node_index = 0
# Remove index of central host from the indices
hosts_indices.remove(central_node_index)
num_subscribed_peer = 2
subscribed_peer_indices = random.sample(hosts_indices, num_subscribed_peer)
topic = "test_join"
central_node_index = 0
# Remove index of central host from the indices
hosts_indices.remove(central_node_index)
num_subscribed_peer = 2
subscribed_peer_indices = random.sample(hosts_indices, num_subscribed_peer)
# All pubsub except the one of central node subscribe to topic
for i in subscribed_peer_indices:
await pubsubs_gsub[i].subscribe(topic)
# All pubsub except the one of central node subscribe to topic
for i in subscribed_peer_indices:
await pubsubs_gsub[i].subscribe(topic)
# Connect central host to all other hosts
await one_to_all_connect(hosts, central_node_index)
# Connect central host to all other hosts
await one_to_all_connect(hosts, central_node_index)
# Wait 2 seconds for heartbeat to allow mesh to connect
await asyncio.sleep(2)
# Central node publishes to the topic so that this topic
# is added to the central node's fanout
# publish from the randomly chosen host
await pubsubs_gsub[central_node_index].publish(topic, b"data")
# Check that the gossipsub of central node has fanout for the topic
assert topic in gossipsubs[central_node_index].fanout
# Check that the gossipsub of central node does not have a mesh for the topic
assert topic not in gossipsubs[central_node_index].mesh
# Central node subscribes to the topic
await pubsubs_gsub[central_node_index].subscribe(topic)
await asyncio.sleep(2)
# Check that the gossipsub of central node no longer has fanout for the topic
assert topic not in gossipsubs[central_node_index].fanout
for i in hosts_indices:
if i in subscribed_peer_indices:
assert hosts[i].get_id() in gossipsubs[central_node_index].mesh[topic]
assert hosts[central_node_index].get_id() in gossipsubs[i].mesh[topic]
else:
assert hosts[i].get_id() not in gossipsubs[central_node_index].mesh[topic]
assert topic not in gossipsubs[i].mesh
@pytest.mark.parametrize("num_hosts", (1,))
@pytest.mark.asyncio
async def test_leave(pubsubs_gsub):
gossipsub = pubsubs_gsub[0].router
topic = "test_leave"
assert topic not in gossipsub.mesh
await gossipsub.join(topic)
assert topic in gossipsub.mesh
await gossipsub.leave(topic)
assert topic not in gossipsub.mesh
# Test re-leave
await gossipsub.leave(topic)
@pytest.mark.parametrize("num_hosts", (2,))
@pytest.mark.asyncio
async def test_handle_graft(pubsubs_gsub, hosts, event_loop, monkeypatch):
gossipsubs = tuple(pubsub.router for pubsub in pubsubs_gsub)
index_alice = 0
id_alice = hosts[index_alice].get_id()
index_bob = 1
id_bob = hosts[index_bob].get_id()
await connect(hosts[index_alice], hosts[index_bob])
# Wait 2 seconds for heartbeat to allow mesh to connect
await asyncio.sleep(2)
topic = "test_handle_graft"
# Only alice subscribes to the topic
await gossipsubs[index_alice].join(topic)
# Monkey patch bob's `emit_prune` function so we can
# check if it is called in `handle_graft`
event_emit_prune = asyncio.Event()
async def emit_prune(topic, sender_peer_id):
event_emit_prune.set()
monkeypatch.setattr(gossipsubs[index_bob], "emit_prune", emit_prune)
# Check that alice is bob's peer but not his mesh peer
assert id_alice in gossipsubs[index_bob].peers_gossipsub
assert topic not in gossipsubs[index_bob].mesh
await gossipsubs[index_alice].emit_graft(topic, id_bob)
# Check that `emit_prune` is called
await asyncio.wait_for(event_emit_prune.wait(), timeout=1, loop=event_loop)
assert event_emit_prune.is_set()
# Check that bob is alice's peer but not her mesh peer
assert topic in gossipsubs[index_alice].mesh
assert id_bob not in gossipsubs[index_alice].mesh[topic]
assert id_bob in gossipsubs[index_alice].peers_gossipsub
await gossipsubs[index_bob].emit_graft(topic, id_alice)
await asyncio.sleep(1)
# Check that bob is now alice's mesh peer
assert id_bob in gossipsubs[index_alice].mesh[topic]
@pytest.mark.parametrize(
"num_hosts, gossipsub_params", ((2, GossipsubParams(heartbeat_interval=3)),)
)
@pytest.mark.asyncio
async def test_handle_prune(pubsubs_gsub, hosts):
gossipsubs = tuple(pubsub.router for pubsub in pubsubs_gsub)
index_alice = 0
id_alice = hosts[index_alice].get_id()
index_bob = 1
id_bob = hosts[index_bob].get_id()
topic = "test_handle_prune"
for pubsub in pubsubs_gsub:
await pubsub.subscribe(topic)
await connect(hosts[index_alice], hosts[index_bob])
# Wait 3 seconds for heartbeat to allow mesh to connect
await asyncio.sleep(3)
# Check that they are each other's mesh peer
assert id_alice in gossipsubs[index_bob].mesh[topic]
assert id_bob in gossipsubs[index_alice].mesh[topic]
# alice emits a prune message to bob; alice should be removed
# from bob's mesh peers
await gossipsubs[index_alice].emit_prune(topic, id_bob)
# FIXME: This test currently works because the heartbeat interval
# is increased to 3 seconds, so alice won't get added back into
# bob's mesh peers during heartbeat.
await asyncio.sleep(1)
# Check that alice is no longer bob's mesh peer
assert id_alice not in gossipsubs[index_bob].mesh[topic]
assert id_bob in gossipsubs[index_alice].mesh[topic]
@pytest.mark.parametrize("num_hosts", (10,))
@pytest.mark.asyncio
async def test_dense(num_hosts, pubsubs_gsub, hosts):
num_msgs = 5
# All pubsub subscribe to foobar
queues = []
for pubsub in pubsubs_gsub:
q = await pubsub.subscribe("foobar")
# Add each blocking queue to an array of blocking queues
queues.append(q)
# Densely connect libp2p hosts in a random way
await dense_connect(hosts)
# Wait 2 seconds for heartbeat to allow mesh to connect
await asyncio.sleep(2)
for i in range(num_msgs):
msg_content = b"foo " + i.to_bytes(1, "big")
# randomly pick a message origin
origin_idx = random.randint(0, num_hosts - 1)
# Wait 2 seconds for heartbeat to allow mesh to connect
await trio.sleep(2)
# Central node publishes to the topic so that this topic
# is added to the central node's fanout
# publish from the randomly chosen host
await pubsubs_gsub[origin_idx].publish("foobar", msg_content)
await pubsubs_gsub[central_node_index].publish(topic, b"data")
await asyncio.sleep(0.5)
# Assert that all blocking queues receive the message
for queue in queues:
msg = await queue.get()
assert msg.data == msg_content
# Check that the gossipsub of central node has fanout for the topic
assert topic in gossipsubs[central_node_index].fanout
# Check that the gossipsub of central node does not have a mesh for the topic
assert topic not in gossipsubs[central_node_index].mesh
# Central node subscribes to the topic
await pubsubs_gsub[central_node_index].subscribe(topic)
await trio.sleep(2)
# Check that the gossipsub of central node no longer has fanout for the topic
assert topic not in gossipsubs[central_node_index].fanout
for i in hosts_indices:
if i in subscribed_peer_indices:
assert hosts[i].get_id() in gossipsubs[central_node_index].mesh[topic]
assert hosts[central_node_index].get_id() in gossipsubs[i].mesh[topic]
else:
assert (
hosts[i].get_id() not in gossipsubs[central_node_index].mesh[topic]
)
assert topic not in gossipsubs[i].mesh
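The assertions above encode gossipsub's split between `fanout` (topics a node publishes to without subscribing) and `mesh` (topics it has joined). A minimal illustrative model of that promotion, independent of the real router (field names mirror the test, not the actual implementation):

class MiniGossipsub:
    def __init__(self) -> None:
        self.mesh = {}    # topic -> set of mesh peers (subscribed topics)
        self.fanout = {}  # topic -> set of fanout peers (publish-only topics)

    def publish(self, topic, peers) -> None:
        if topic not in self.mesh:
            # Publishing without a subscription populates fanout.
            self.fanout.setdefault(topic, set()).update(peers)

    def join(self, topic) -> None:
        # Subscribing promotes fanout peers into the mesh.
        self.mesh[topic] = self.fanout.pop(topic, set())

router = MiniGossipsub()
router.publish("test_join", {"peer_a", "peer_b"})
assert "test_join" in router.fanout and "test_join" not in router.mesh
router.join("test_join")
assert "test_join" in router.mesh and "test_join" not in router.fanout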
@pytest.mark.parametrize("num_hosts", (10,))
@pytest.mark.asyncio
async def test_fanout(hosts, pubsubs_gsub):
num_msgs = 5
@pytest.mark.trio
async def test_leave():
async with PubsubFactory.create_batch_with_gossipsub(1) as pubsubs_gsub:
gossipsub = pubsubs_gsub[0].router
topic = "test_leave"
# All pubsub subscribe to foobar except for `pubsubs_gsub[0]`
queues = []
for i in range(1, len(pubsubs_gsub)):
q = await pubsubs_gsub[i].subscribe("foobar")
assert topic not in gossipsub.mesh
# Add each blocking queue to an array of blocking queues
queues.append(q)
await gossipsub.join(topic)
assert topic in gossipsub.mesh
# Sparsely connect libp2p hosts in a random way
await dense_connect(hosts)
await gossipsub.leave(topic)
assert topic not in gossipsub.mesh
# Wait 2 seconds for heartbeat to allow mesh to connect
await asyncio.sleep(2)
topic = "foobar"
# Send messages with origin not subscribed
for i in range(num_msgs):
msg_content = b"foo " + i.to_bytes(1, "big")
# Set the message origin to the node that is not subscribed to 'foobar'
origin_idx = 0
# publish from the randomly chosen host
await pubsubs_gsub[origin_idx].publish(topic, msg_content)
await asyncio.sleep(0.5)
# Assert that all blocking queues receive the message
for queue in queues:
msg = await queue.get()
assert msg.data == msg_content
# Subscribe message origin
queues.insert(0, await pubsubs_gsub[0].subscribe(topic))
# Send messages again
for i in range(num_msgs):
msg_content = b"bar " + i.to_bytes(1, "big")
# Set the message origin to the node that is not subscribed to 'foobar'
origin_idx = 0
# publish from the randomly chosen host
await pubsubs_gsub[origin_idx].publish(topic, msg_content)
await asyncio.sleep(0.5)
# Assert that all blocking queues receive the message
for queue in queues:
msg = await queue.get()
assert msg.data == msg_content
# Test re-leave
await gossipsub.leave(topic)
@pytest.mark.parametrize("num_hosts", (10,))
@pytest.mark.asyncio
@pytest.mark.trio
async def test_handle_graft(monkeypatch):
async with PubsubFactory.create_batch_with_gossipsub(2) as pubsubs_gsub:
gossipsubs = tuple(pubsub.router for pubsub in pubsubs_gsub)
index_alice = 0
id_alice = pubsubs_gsub[index_alice].my_id
index_bob = 1
id_bob = pubsubs_gsub[index_bob].my_id
await connect(pubsubs_gsub[index_alice].host, pubsubs_gsub[index_bob].host)
# Wait 2 seconds for heartbeat to allow mesh to connect
await trio.sleep(2)
topic = "test_handle_graft"
# Only alice subscribes to the topic
await gossipsubs[index_alice].join(topic)
# Monkey patch bob's `emit_prune` function so we can
# check if it is called in `handle_graft`
event_emit_prune = trio.Event()
async def emit_prune(topic, sender_peer_id):
event_emit_prune.set()
monkeypatch.setattr(gossipsubs[index_bob], "emit_prune", emit_prune)
# Check that alice is bob's peer but not his mesh peer
assert id_alice in gossipsubs[index_bob].peers_gossipsub
assert topic not in gossipsubs[index_bob].mesh
await gossipsubs[index_alice].emit_graft(topic, id_bob)
# Check that `emit_prune` is called
await event_emit_prune.wait()
# Check that bob is alice's peer but not her mesh peer
assert topic in gossipsubs[index_alice].mesh
assert id_bob not in gossipsubs[index_alice].mesh[topic]
assert id_bob in gossipsubs[index_alice].peers_gossipsub
await gossipsubs[index_bob].emit_graft(topic, id_alice)
await trio.sleep(1)
# Check that bob is now alice's mesh peer
assert id_bob in gossipsubs[index_alice].mesh[topic]
@pytest.mark.trio
async def test_handle_prune():
async with PubsubFactory.create_batch_with_gossipsub(
2, heartbeat_interval=3
) as pubsubs_gsub:
gossipsubs = tuple(pubsub.router for pubsub in pubsubs_gsub)
index_alice = 0
id_alice = pubsubs_gsub[index_alice].my_id
index_bob = 1
id_bob = pubsubs_gsub[index_bob].my_id
topic = "test_handle_prune"
for pubsub in pubsubs_gsub:
await pubsub.subscribe(topic)
await connect(pubsubs_gsub[index_alice].host, pubsubs_gsub[index_bob].host)
# Wait 3 seconds for heartbeat to allow mesh to connect
await trio.sleep(3)
# Check that they are each other's mesh peer
assert id_alice in gossipsubs[index_bob].mesh[topic]
assert id_bob in gossipsubs[index_alice].mesh[topic]
# alice emits a prune message to bob; alice should be removed
# from bob's mesh peers
await gossipsubs[index_alice].emit_prune(topic, id_bob)
# FIXME: This test currently works because the heartbeat interval
# is increased to 3 seconds, so alice won't get added back into
# bob's mesh peers during heartbeat.
await trio.sleep(1)
# Check that alice is no longer bob's mesh peer
assert id_alice not in gossipsubs[index_bob].mesh[topic]
assert id_bob in gossipsubs[index_alice].mesh[topic]
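The asymmetry checked here, where PRUNE removes the sender from the receiver's mesh while the sender keeps its own entry until the next heartbeat rebalances it, can be stated in a few lines (illustrative only):

def handle_prune(mesh, topic, sender) -> None:
    # Illustrative: drop the pruning peer from our mesh view for the topic.
    mesh.get(topic, set()).discard(sender)

bob_mesh = {"test_handle_prune": {"alice"}}
alice_mesh = {"test_handle_prune": {"bob"}}
handle_prune(bob_mesh, "test_handle_prune", "alice")
assert "alice" not in bob_mesh["test_handle_prune"]
assert "bob" in alice_mesh["test_handle_prune"]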
@pytest.mark.trio
async def test_dense():
async with PubsubFactory.create_batch_with_gossipsub(10) as pubsubs_gsub:
hosts = [pubsub.host for pubsub in pubsubs_gsub]
num_msgs = 5
# All pubsub subscribe to foobar
queues = [await pubsub.subscribe("foobar") for pubsub in pubsubs_gsub]
# Densely connect libp2p hosts in a random way
await dense_connect(hosts)
# Wait 2 seconds for heartbeat to allow mesh to connect
await trio.sleep(2)
for i in range(num_msgs):
msg_content = b"foo " + i.to_bytes(1, "big")
# randomly pick a message origin
origin_idx = random.randint(0, len(hosts) - 1)
# publish from the randomly chosen host
await pubsubs_gsub[origin_idx].publish("foobar", msg_content)
await trio.sleep(0.5)
# Assert that all blocking queues receive the message
for queue in queues:
msg = await queue.receive()
assert msg.data == msg_content
@pytest.mark.trio
async def test_fanout():
async with PubsubFactory.create_batch_with_gossipsub(10) as pubsubs_gsub:
hosts = [pubsub.host for pubsub in pubsubs_gsub]
num_msgs = 5
# All pubsub subscribe to foobar except for `pubsubs_gsub[0]`
subs = [await pubsub.subscribe("foobar") for pubsub in pubsubs_gsub[1:]]
# Sparsely connect libp2p hosts in a random way
await dense_connect(hosts)
# Wait 2 seconds for heartbeat to allow mesh to connect
await trio.sleep(2)
topic = "foobar"
# Send messages with origin not subscribed
for i in range(num_msgs):
msg_content = b"foo " + i.to_bytes(1, "big")
# Set the message origin to the node that is not subscribed to 'foobar'
origin_idx = 0
# publish from the randomly chosen host
await pubsubs_gsub[origin_idx].publish(topic, msg_content)
await trio.sleep(0.5)
# Assert that all blocking queues receive the message
for sub in subs:
msg = await sub.receive()
assert msg.data == msg_content
# Subscribe message origin
subs.insert(0, await pubsubs_gsub[0].subscribe(topic))
# Send messages again
for i in range(num_msgs):
msg_content = b"bar " + i.to_bytes(1, "big")
# Set the message origin to the node that is not subscribed to 'foobar'
origin_idx = 0
# publish from the randomly chosen host
await pubsubs_gsub[origin_idx].publish(topic, msg_content)
await trio.sleep(0.5)
# Assert that all blocking queues receive the message
for sub in subs:
msg = await sub.receive()
assert msg.data == msg_content
@pytest.mark.trio
@pytest.mark.slow
async def test_fanout_maintenance(hosts, pubsubs_gsub):
num_msgs = 5
async def test_fanout_maintenance():
async with PubsubFactory.create_batch_with_gossipsub(10) as pubsubs_gsub:
hosts = [pubsub.host for pubsub in pubsubs_gsub]
num_msgs = 5
# All pubsub subscribe to foobar
queues = []
topic = "foobar"
for i in range(1, len(pubsubs_gsub)):
q = await pubsubs_gsub[i].subscribe(topic)
# All pubsub subscribe to foobar
queues = []
topic = "foobar"
for i in range(1, len(pubsubs_gsub)):
q = await pubsubs_gsub[i].subscribe(topic)
# Add each blocking queue to an array of blocking queues
queues.append(q)
# Add each blocking queue to an array of blocking queues
queues.append(q)
# Sparsely connect libp2p hosts in a random way
await dense_connect(hosts)
# Sparsely connect libp2p hosts in a random way
await dense_connect(hosts)
# Wait 2 seconds for heartbeat to allow mesh to connect
await asyncio.sleep(2)
# Wait 2 seconds for heartbeat to allow mesh to connect
await trio.sleep(2)
# Send messages with origin not subscribed
for i in range(num_msgs):
msg_content = b"foo " + i.to_bytes(1, "big")
# Send messages with origin not subscribed
for i in range(num_msgs):
msg_content = b"foo " + i.to_bytes(1, "big")
# Set the message origin to the node that is not subscribed to 'foobar'
origin_idx = 0
# Set the message origin to the node that is not subscribed to 'foobar'
origin_idx = 0
# publish from the randomly chosen host
await pubsubs_gsub[origin_idx].publish(topic, msg_content)
await trio.sleep(0.5)
# Assert that all blocking queues receive the message
for queue in queues:
msg = await queue.receive()
assert msg.data == msg_content
for sub in pubsubs_gsub:
await sub.unsubscribe(topic)
queues = []
await trio.sleep(2)
# Resub and repeat
for i in range(1, len(pubsubs_gsub)):
q = await pubsubs_gsub[i].subscribe(topic)
# Add each blocking queue to an array of blocking queues
queues.append(q)
await trio.sleep(2)
# Check messages can still be sent
for i in range(num_msgs):
msg_content = b"bar " + i.to_bytes(1, "big")
# Set the message origin to the node that is not subscribed to 'foobar'
origin_idx = 0
# publish from the randomly chosen host
await pubsubs_gsub[origin_idx].publish(topic, msg_content)
await trio.sleep(0.5)
# Assert that all blocking queues receive the message
for queue in queues:
msg = await queue.receive()
assert msg.data == msg_content
@pytest.mark.trio
async def test_gossip_propagation():
async with PubsubFactory.create_batch_with_gossipsub(
2, degree=1, degree_low=0, degree_high=2, gossip_window=50, gossip_history=100
) as pubsubs_gsub:
topic = "foo"
await pubsubs_gsub[0].subscribe(topic)
# node 0 publish to topic
msg_content = b"foo_msg"
# publish from the randomly chosen host
await pubsubs_gsub[origin_idx].publish(topic, msg_content)
await pubsubs_gsub[0].publish(topic, msg_content)
await asyncio.sleep(0.5)
# Assert that all blocking queues receive the message
for queue in queues:
msg = await queue.get()
assert msg.data == msg_content
# now node 1 subscribes
queue_1 = await pubsubs_gsub[1].subscribe(topic)
for sub in pubsubs_gsub:
await sub.unsubscribe(topic)
await connect(pubsubs_gsub[0].host, pubsubs_gsub[1].host)
queues = []
# wait for gossip heartbeat
await trio.sleep(2)
await asyncio.sleep(2)
# Resub and repeat
for i in range(1, len(pubsubs_gsub)):
q = await pubsubs_gsub[i].subscribe(topic)
# Add each blocking queue to an array of blocking queues
queues.append(q)
await asyncio.sleep(2)
# Check messages can still be sent
for i in range(num_msgs):
msg_content = b"bar " + i.to_bytes(1, "big")
# Set the message origin to the node that is not subscribed to 'foobar'
origin_idx = 0
# publish from the randomly chosen host
await pubsubs_gsub[origin_idx].publish(topic, msg_content)
await asyncio.sleep(0.5)
# Assert that all blocking queues receive the message
for queue in queues:
msg = await queue.get()
assert msg.data == msg_content
@pytest.mark.parametrize(
"num_hosts, gossipsub_params",
(
(
2,
GossipsubParams(
degree=1,
degree_low=0,
degree_high=2,
gossip_window=50,
gossip_history=100,
),
),
),
)
@pytest.mark.asyncio
async def test_gossip_propagation(hosts, pubsubs_gsub):
topic = "foo"
await pubsubs_gsub[0].subscribe(topic)
# node 0 publish to topic
msg_content = b"foo_msg"
# publish from the randomly chosen host
await pubsubs_gsub[0].publish(topic, msg_content)
# now node 1 subscribes
queue_1 = await pubsubs_gsub[1].subscribe(topic)
await connect(hosts[0], hosts[1])
# wait for gossip heartbeat
await asyncio.sleep(2)
# should be able to read message
msg = await queue_1.get()
assert msg.data == msg_content
# should be able to read message
msg = await queue_1.receive()
assert msg.data == msg_content

View File

@ -3,25 +3,25 @@ import functools
import pytest
from libp2p.tools.constants import FLOODSUB_PROTOCOL_ID
from libp2p.tools.factories import GossipsubFactory
from libp2p.tools.factories import PubsubFactory
from libp2p.tools.pubsub.floodsub_integration_test_settings import (
floodsub_protocol_pytest_params,
perform_test_from_obj,
)
@pytest.mark.asyncio
async def test_gossipsub_initialize_with_floodsub_protocol():
GossipsubFactory(protocols=[FLOODSUB_PROTOCOL_ID])
@pytest.mark.parametrize("test_case_obj", floodsub_protocol_pytest_params)
@pytest.mark.asyncio
@pytest.mark.trio
@pytest.mark.slow
async def test_gossipsub_run_with_floodsub_tests(test_case_obj):
await perform_test_from_obj(
test_case_obj,
functools.partial(
GossipsubFactory, degree=3, degree_low=2, degree_high=4, time_to_live=30
PubsubFactory.create_batch_with_gossipsub,
protocols=[FLOODSUB_PROTOCOL_ID],
degree=3,
degree_low=2,
degree_high=4,
time_to_live=30,
),
)

View File

@ -1,5 +1,3 @@
import pytest
from libp2p.pubsub.mcache import MessageCache
@ -12,8 +10,7 @@ class Msg:
self.from_id = from_id
@pytest.mark.asyncio
async def test_mcache():
def test_mcache():
# Ported from:
# https://github.com/libp2p/go-libp2p-pubsub/blob/51b7501433411b5096cac2b4994a36a68515fc03/mcache_test.go
mcache = MessageCache(3, 5)
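`MessageCache(3, 5)` keeps five heartbeats of history but gossips only the three most recent windows. A minimal sketch of that sliding-window idea (illustrative, not the real implementation):

class MiniMessageCache:
    def __init__(self, window: int, history: int) -> None:
        self.window = window
        self.frames = [[] for _ in range(history)]  # one frame per heartbeat
        self.msgs = {}

    def put(self, msg_id, data) -> None:
        self.msgs[msg_id] = data
        self.frames[0].append(msg_id)

    def shift(self) -> None:
        # Called once per heartbeat: expire the oldest frame.
        for msg_id in self.frames[-1]:
            self.msgs.pop(msg_id, None)
        self.frames = [[]] + self.frames[:-1]

    def window_ids(self):
        # Only ids from the most recent `window` frames are gossiped (IHAVE).
        return [m for frame in self.frames[: self.window] for m in frame]

cache = MiniMessageCache(3, 5)
cache.put(b"m1", b"payload")
for _ in range(3):
    cache.shift()
assert b"m1" in cache.msgs              # still retrievable via IWANT
assert b"m1" not in cache.window_ids()  # but no longer advertised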

View File

@ -5,12 +5,11 @@ import pytest
import trio
from libp2p.exceptions import ValidationError
from libp2p.peer.id import ID
from libp2p.pubsub.pb import rpc_pb2
from libp2p.tools.constants import MAX_READ_LEN
from libp2p.tools.factories import IDFactory, PubsubFactory, net_stream_pair_factory
from libp2p.tools.pubsub.utils import make_pubsub_msg
from libp2p.tools.utils import connect
from libp2p.tools.constants import MAX_READ_LEN
from libp2p.tools.factories import PubsubFactory, net_stream_pair_factory, IDFactory
from libp2p.utils import encode_varint_prefixed
TESTING_TOPIC = "TEST_SUBSCRIBE"
@ -250,14 +249,14 @@ async def test_continuously_read_stream(monkeypatch, nursery, is_host_secure):
async def mock_push_msg(msg_forwarder, msg):
event_push_msg.set()
await trio.sleep(0)
await trio.hazmat.checkpoint()
def mock_handle_subscription(origin_id, sub_message):
event_handle_subscription.set()
async def mock_handle_rpc(rpc, sender_peer_id):
event_handle_rpc.set()
await trio.sleep(0)
await trio.hazmat.checkpoint()
with monkeypatch.context() as m:
m.setattr(pubsubs_fsub[0], "push_msg", mock_push_msg)
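`trio.hazmat.checkpoint()` (renamed `trio.lowlevel.checkpoint()` in later trio releases) inserts a scheduling point so the async mocks above still yield to the scheduler. A standalone sketch of the pattern:

import trio

async def _demo() -> None:
    event = trio.Event()

    async def mock_handler() -> None:
        event.set()
        # Yield to the scheduler so waiting tasks can observe the event.
        await trio.hazmat.checkpoint()

    async with trio.open_nursery() as nursery:
        nursery.start_soon(mock_handler)
        await event.wait()

trio.run(_demo)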

View File

@ -69,33 +69,10 @@ async def test_mplex_stream_read_after_remote_closed(mplex_stream_pair):
await stream_0.close()
assert stream_0.event_local_closed.is_set()
await trio.sleep(0.01)
print(
"!@# ",
stream_0.muxed_conn.event_shutting_down.is_set(),
stream_0.muxed_conn.event_closed.is_set(),
stream_1.muxed_conn.event_shutting_down.is_set(),
stream_1.muxed_conn.event_closed.is_set(),
)
# await trio.sleep(100000)
await wait_all_tasks_blocked()
print(
"!@# ",
stream_0.muxed_conn.event_shutting_down.is_set(),
stream_0.muxed_conn.event_closed.is_set(),
stream_1.muxed_conn.event_shutting_down.is_set(),
stream_1.muxed_conn.event_closed.is_set(),
)
print("!@# sleeping")
print("!@# result=", stream_1.event_remote_closed.is_set())
# await trio.sleep_forever()
assert stream_1.event_remote_closed.is_set()
print(
"!@# ",
stream_0.muxed_conn.event_shutting_down.is_set(),
stream_0.muxed_conn.event_closed.is_set(),
stream_1.muxed_conn.event_shutting_down.is_set(),
stream_1.muxed_conn.event_closed.is_set(),
)
assert (await stream_1.read(MAX_READ_LEN)) == DATA
with pytest.raises(MplexStreamEOF):
await stream_1.read(MAX_READ_LEN)

View File

@ -3,7 +3,7 @@ import pytest
import trio
from libp2p.network.connection.raw_connection import RawConnection
from libp2p.tools.constants import LISTEN_MADDR, MAX_READ_LEN
from libp2p.tools.constants import LISTEN_MADDR
from libp2p.transport.tcp.tcp import TCP