Merge pull request #404 from libp2p/feature/trio

Merge `feature/trio` into `master`
pull/403/head
Kevin Mai-Husan Chia 2020-02-06 10:49:53 +08:00 committed by GitHub
commit e63584c387
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
80 changed files with 3563 additions and 3378 deletions

View File

@ -5,16 +5,16 @@ matrix:
- python: 3.6-dev
dist: xenial
env: TOXENV=py36-test
- python: 3.7-dev
- python: 3.7
dist: xenial
env: TOXENV=py37-test
- python: 3.7-dev
- python: 3.7
dist: xenial
env: TOXENV=lint
- python: 3.7-dev
- python: 3.7
dist: xenial
env: TOXENV=docs
- python: 3.7-dev
- python: 3.7
dist: xenial
env: TOXENV=py37-interop
sudo: true

View File

@ -51,7 +51,7 @@ lint:
black --check $(FILES_TO_LINT)
isort --recursive --check-only --diff $(FILES_TO_LINT)
docformatter --pre-summary-newline --check --recursive $(FILES_TO_LINT)
tox -elint # This is probably redundant, but just in case...
tox -e lint # This is probably redundant, but just in case...
lint-roll:
isort --recursive $(FILES_TO_LINT)

View File

@ -11,6 +11,22 @@ Subpackages
Submodules
----------
libp2p.pubsub.abc module
------------------------
.. automodule:: libp2p.pubsub.abc
:members:
:undoc-members:
:show-inheritance:
libp2p.pubsub.exceptions module
-------------------------------
.. automodule:: libp2p.pubsub.exceptions
:members:
:undoc-members:
:show-inheritance:
libp2p.pubsub.floodsub module
-----------------------------
@ -51,10 +67,10 @@ libp2p.pubsub.pubsub\_notifee module
:undoc-members:
:show-inheritance:
libp2p.pubsub.pubsub\_router\_interface module
----------------------------------------------
libp2p.pubsub.subscription module
---------------------------------
.. automodule:: libp2p.pubsub.pubsub_router_interface
.. automodule:: libp2p.pubsub.subscription
:members:
:undoc-members:
:show-inheritance:

View File

@ -1,11 +1,10 @@
import argparse
import asyncio
import sys
import urllib.request
import multiaddr
import trio
from libp2p import new_node
from libp2p import new_host
from libp2p.network.stream.net_stream_interface import INetStream
from libp2p.peer.peerinfo import info_from_p2p_addr
from libp2p.typing import TProtocol
@ -26,53 +25,47 @@ async def read_data(stream: INetStream) -> None:
async def write_data(stream: INetStream) -> None:
loop = asyncio.get_event_loop()
async_f = trio.wrap_file(sys.stdin)
while True:
line = await loop.run_in_executor(None, sys.stdin.readline)
line = await async_f.readline()
await stream.write(line.encode())
async def run(port: int, destination: str, localhost: bool) -> None:
if localhost:
ip = "127.0.0.1"
else:
ip = urllib.request.urlopen("https://v4.ident.me/").read().decode("utf8")
transport_opt = f"/ip4/{ip}/tcp/{port}"
host = await new_node(transport_opt=[transport_opt])
async def run(port: int, destination: str) -> None:
localhost_ip = "127.0.0.1"
listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
host = new_host()
async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery:
if not destination: # its the server
await host.get_network().listen(multiaddr.Multiaddr(transport_opt))
async def stream_handler(stream: INetStream) -> None:
nursery.start_soon(read_data, stream)
nursery.start_soon(write_data, stream)
if not destination: # its the server
host.set_stream_handler(PROTOCOL_ID, stream_handler)
async def stream_handler(stream: INetStream) -> None:
asyncio.ensure_future(read_data(stream))
asyncio.ensure_future(write_data(stream))
print(
f"Run 'python ./examples/chat/chat.py "
f"-p {int(port) + 1} "
f"-d /ip4/{localhost_ip}/tcp/{port}/p2p/{host.get_id().pretty()}' "
"on another console."
)
print("Waiting for incoming connection...")
host.set_stream_handler(PROTOCOL_ID, stream_handler)
else: # its the client
maddr = multiaddr.Multiaddr(destination)
info = info_from_p2p_addr(maddr)
# Associate the peer with local ip address
await host.connect(info)
# Start a stream with the destination.
# Multiaddress of the destination peer is fetched from the peerstore using 'peerId'.
stream = await host.new_stream(info.peer_id, [PROTOCOL_ID])
localhost_opt = " --localhost" if localhost else ""
nursery.start_soon(read_data, stream)
nursery.start_soon(write_data, stream)
print(f"Connected to peer {info.addrs[0]}")
print(
f"Run 'python ./examples/chat/chat.py"
+ localhost_opt
+ f" -p {int(port) + 1} -d /ip4/{ip}/tcp/{port}/p2p/{host.get_id().pretty()}'"
+ " on another console."
)
print("Waiting for incoming connection...")
else: # its the client
maddr = multiaddr.Multiaddr(destination)
info = info_from_p2p_addr(maddr)
# Associate the peer with local ip address
await host.connect(info)
# Start a stream with the destination.
# Multiaddress of the destination peer is fetched from the peerstore using 'peerId'.
stream = await host.new_stream(info.peer_id, [PROTOCOL_ID])
asyncio.ensure_future(read_data(stream))
asyncio.ensure_future(write_data(stream))
print("Connected to peer %s" % info.addrs[0])
await trio.sleep_forever()
def main() -> None:
@ -86,11 +79,6 @@ def main() -> None:
"/ip4/127.0.0.1/tcp/8000/p2p/QmQn4SwGkDZKkUEpBRBvTmheQycxAHJUNmVEnjA2v1qe8Q"
)
parser = argparse.ArgumentParser(description=description)
parser.add_argument(
"--debug",
action="store_true",
help="generate the same node ID on every execution",
)
parser.add_argument(
"-p", "--port", default=8000, type=int, help="source port number"
)
@ -100,26 +88,15 @@ def main() -> None:
type=str,
help=f"destination multiaddr string, e.g. {example_maddr}",
)
parser.add_argument(
"-l",
"--localhost",
dest="localhost",
action="store_true",
help="flag indicating if localhost should be used or an external IP",
)
args = parser.parse_args()
if not args.port:
raise RuntimeError("was not able to determine a local port")
loop = asyncio.get_event_loop()
try:
asyncio.ensure_future(run(args.port, args.destination, args.localhost))
loop.run_forever()
trio.run(run, *(args.port, args.destination))
except KeyboardInterrupt:
pass
finally:
loop.close()
if __name__ == "__main__":

View File

@ -1,10 +1,9 @@
import argparse
import asyncio
import urllib.request
import multiaddr
import trio
from libp2p import new_node
from libp2p import new_host
from libp2p.crypto.secp256k1 import create_new_key_pair
from libp2p.network.stream.net_stream_interface import INetStream
from libp2p.peer.peerinfo import info_from_p2p_addr
@ -20,12 +19,9 @@ async def _echo_stream_handler(stream: INetStream) -> None:
await stream.close()
async def run(port: int, destination: str, localhost: bool, seed: int = None) -> None:
if localhost:
ip = "127.0.0.1"
else:
ip = urllib.request.urlopen("https://v4.ident.me/").read().decode("utf8")
transport_opt = f"/ip4/{ip}/tcp/{port}"
async def run(port: int, destination: str, seed: int = None) -> None:
localhost_ip = "127.0.0.1"
listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
if seed:
import random
@ -38,47 +34,43 @@ async def run(port: int, destination: str, localhost: bool, seed: int = None) ->
secret = secrets.token_bytes(32)
host = await new_node(
key_pair=create_new_key_pair(secret), transport_opt=[transport_opt]
)
host = new_host(key_pair=create_new_key_pair(secret))
async with host.run(listen_addrs=[listen_addr]):
print(f"I am {host.get_id().to_string()}")
print(f"I am {host.get_id().to_string()}")
await host.get_network().listen(multiaddr.Multiaddr(transport_opt))
if not destination: # its the server
if not destination: # its the server
host.set_stream_handler(PROTOCOL_ID, _echo_stream_handler)
host.set_stream_handler(PROTOCOL_ID, _echo_stream_handler)
print(
f"Run 'python ./examples/echo/echo.py "
f"-p {int(port) + 1} "
f"-d /ip4/{localhost_ip}/tcp/{port}/p2p/{host.get_id().pretty()}' "
"on another console."
)
print("Waiting for incoming connections...")
await trio.sleep_forever()
localhost_opt = " --localhost" if localhost else ""
else: # its the client
maddr = multiaddr.Multiaddr(destination)
info = info_from_p2p_addr(maddr)
# Associate the peer with local ip address
await host.connect(info)
print(
f"Run 'python ./examples/echo/echo.py"
+ localhost_opt
+ f" -p {int(port) + 1} -d /ip4/{ip}/tcp/{port}/p2p/{host.get_id().pretty()}'"
+ " on another console."
)
print("Waiting for incoming connections...")
# Start a stream with the destination.
# Multiaddress of the destination peer is fetched from the peerstore using 'peerId'.
stream = await host.new_stream(info.peer_id, [PROTOCOL_ID])
else: # its the client
maddr = multiaddr.Multiaddr(destination)
info = info_from_p2p_addr(maddr)
# Associate the peer with local ip address
await host.connect(info)
msg = b"hi, there!\n"
# Start a stream with the destination.
# Multiaddress of the destination peer is fetched from the peerstore using 'peerId'.
stream = await host.new_stream(info.peer_id, [PROTOCOL_ID])
await stream.write(msg)
# Notify the other side about EOF
await stream.close()
response = await stream.read()
msg = b"hi, there!\n"
await stream.write(msg)
# Notify the other side about EOF
await stream.close()
response = await stream.read()
print(f"Sent: {msg}")
print(f"Got: {response}")
print(f"Sent: {msg}")
print(f"Got: {response}")
def main() -> None:
@ -94,11 +86,6 @@ def main() -> None:
"/ip4/127.0.0.1/tcp/8000/p2p/QmQn4SwGkDZKkUEpBRBvTmheQycxAHJUNmVEnjA2v1qe8Q"
)
parser = argparse.ArgumentParser(description=description)
parser.add_argument(
"--debug",
action="store_true",
help="generate the same node ID on every execution",
)
parser.add_argument(
"-p", "--port", default=8000, type=int, help="source port number"
)
@ -108,13 +95,6 @@ def main() -> None:
type=str,
help=f"destination multiaddr string, e.g. {example_maddr}",
)
parser.add_argument(
"-l",
"--localhost",
dest="localhost",
action="store_true",
help="flag indicating if localhost should be used or an external IP",
)
parser.add_argument(
"-s",
"--seed",
@ -126,16 +106,10 @@ def main() -> None:
if not args.port:
raise RuntimeError("was not able to determine a local port")
loop = asyncio.get_event_loop()
try:
asyncio.ensure_future(
run(args.port, args.destination, args.localhost, args.seed)
)
loop.run_forever()
trio.run(run, args.port, args.destination, args.seed)
except KeyboardInterrupt:
pass
finally:
loop.close()
if __name__ == "__main__":

View File

@ -1,12 +1,9 @@
import asyncio
from typing import Sequence
from libp2p.crypto.keys import KeyPair
from libp2p.crypto.rsa import create_new_key_pair
from libp2p.host.basic_host import BasicHost
from libp2p.host.host_interface import IHost
from libp2p.host.routed_host import RoutedHost
from libp2p.network.network_interface import INetwork
from libp2p.network.network_interface import INetworkService
from libp2p.network.swarm import Swarm
from libp2p.peer.id import ID
from libp2p.peer.peerstore import PeerStore
@ -21,18 +18,6 @@ from libp2p.transport.upgrader import TransportUpgrader
from libp2p.typing import TProtocol
async def cleanup_done_tasks() -> None:
"""clean up asyncio done tasks to free up resources."""
while True:
for task in asyncio.all_tasks():
if task.done():
await task
# Need not run often
# Some sleep necessary to context switch
await asyncio.sleep(3)
def generate_new_rsa_identity() -> KeyPair:
return create_new_key_pair()
@ -42,29 +27,28 @@ def generate_peer_id_from(key_pair: KeyPair) -> ID:
return ID.from_pubkey(public_key)
def initialize_default_swarm(
key_pair: KeyPair,
id_opt: ID = None,
transport_opt: Sequence[str] = None,
def new_swarm(
key_pair: KeyPair = None,
muxer_opt: TMuxerOptions = None,
sec_opt: TSecurityOptions = None,
peerstore_opt: IPeerStore = None,
) -> Swarm:
) -> INetworkService:
"""
initialize swarm when no swarm is passed in.
Create a swarm instance based on the parameters.
:param id_opt: optional id for host
:param transport_opt: optional choice of transport upgrade
:param key_pair: optional choice of the ``KeyPair``
:param muxer_opt: optional choice of stream muxer
:param sec_opt: optional choice of security upgrade
:param peerstore_opt: optional peerstore
:return: return a default swarm instance
"""
if not id_opt:
id_opt = generate_peer_id_from(key_pair)
if key_pair is None:
key_pair = generate_new_rsa_identity()
# TODO: Parse `transport_opt` to determine transport
id_opt = generate_peer_id_from(key_pair)
# TODO: Parse `listen_addrs` to determine transport
transport = TCP()
muxer_transports_by_protocol = muxer_opt or {MPLEX_PROTOCOL_ID: Mplex}
@ -80,57 +64,35 @@ def initialize_default_swarm(
# Store our key pair in peerstore
peerstore.add_key_pair(id_opt, key_pair)
# TODO: Initialize discovery if not presented
return Swarm(id_opt, peerstore, upgrader, transport)
async def new_node(
def new_host(
key_pair: KeyPair = None,
swarm_opt: INetwork = None,
transport_opt: Sequence[str] = None,
muxer_opt: TMuxerOptions = None,
sec_opt: TSecurityOptions = None,
peerstore_opt: IPeerStore = None,
disc_opt: IPeerRouting = None,
) -> BasicHost:
) -> IHost:
"""
create new libp2p node.
Create a new libp2p host based on the given parameters.
:param key_pair: key pair for deriving an identity
:param swarm_opt: optional swarm
:param id_opt: optional id for host
:param transport_opt: optional choice of transport upgrade
:param key_pair: optional choice of the ``KeyPair``
:param muxer_opt: optional choice of stream muxer
:param sec_opt: optional choice of security upgrade
:param peerstore_opt: optional peerstore
:param disc_opt: optional discovery
:return: return a host instance
"""
if not key_pair:
key_pair = generate_new_rsa_identity()
id_opt = generate_peer_id_from(key_pair)
if not swarm_opt:
swarm_opt = initialize_default_swarm(
key_pair=key_pair,
id_opt=id_opt,
transport_opt=transport_opt,
muxer_opt=muxer_opt,
sec_opt=sec_opt,
peerstore_opt=peerstore_opt,
)
# TODO enable support for other host type
# TODO routing unimplemented
host: IHost # If not explicitly typed, MyPy raises error
swarm = new_swarm(
key_pair=key_pair,
muxer_opt=muxer_opt,
sec_opt=sec_opt,
peerstore_opt=peerstore_opt,
)
host: IHost
if disc_opt:
host = RoutedHost(swarm_opt, disc_opt)
host = RoutedHost(swarm, disc_opt)
else:
host = BasicHost(swarm_opt)
# Kick off cleanup job
asyncio.ensure_future(cleanup_done_tasks())
host = BasicHost(swarm)
return host

View File

@ -1,12 +1,14 @@
import logging
from typing import TYPE_CHECKING, List, Sequence
from typing import TYPE_CHECKING, AsyncIterator, List, Sequence
from async_generator import asynccontextmanager
from async_service import background_trio_service
import multiaddr
from libp2p.crypto.keys import PrivateKey, PublicKey
from libp2p.host.defaults import get_default_protocols
from libp2p.host.exceptions import StreamFailure
from libp2p.network.network_interface import INetwork
from libp2p.network.network_interface import INetworkService
from libp2p.network.stream.net_stream_interface import INetStream
from libp2p.peer.id import ID
from libp2p.peer.peerinfo import PeerInfo
@ -39,7 +41,7 @@ class BasicHost(IHost):
right after a stream is initialized.
"""
_network: INetwork
_network: INetworkService
peerstore: IPeerStore
multiselect: Multiselect
@ -47,7 +49,7 @@ class BasicHost(IHost):
def __init__(
self,
network: INetwork,
network: INetworkService,
default_protocols: "OrderedDict[TProtocol, StreamHandlerFn]" = None,
) -> None:
self._network = network
@ -70,7 +72,7 @@ class BasicHost(IHost):
def get_private_key(self) -> PrivateKey:
return self.peerstore.privkey(self.get_id())
def get_network(self) -> INetwork:
def get_network(self) -> INetworkService:
"""
:return: network instance of host
"""
@ -101,6 +103,20 @@ class BasicHost(IHost):
addrs.append(addr.encapsulate(p2p_part))
return addrs
@asynccontextmanager
async def run(
self, listen_addrs: Sequence[multiaddr.Multiaddr]
) -> AsyncIterator[None]:
"""
run the host instance and listen to ``listen_addrs``.
:param listen_addrs: a sequence of multiaddrs that we want to listen to
"""
network = self.get_network()
async with background_trio_service(network):
await network.listen(*listen_addrs)
yield
def set_stream_handler(
self, protocol_id: TProtocol, stream_handler: StreamHandlerFn
) -> None:

View File

@ -1,10 +1,10 @@
from abc import ABC, abstractmethod
from typing import Any, List, Sequence
from typing import Any, AsyncContextManager, List, Sequence
import multiaddr
from libp2p.crypto.keys import PrivateKey, PublicKey
from libp2p.network.network_interface import INetwork
from libp2p.network.network_interface import INetworkService
from libp2p.network.stream.net_stream_interface import INetStream
from libp2p.peer.id import ID
from libp2p.peer.peerinfo import PeerInfo
@ -31,7 +31,7 @@ class IHost(ABC):
"""
@abstractmethod
def get_network(self) -> INetwork:
def get_network(self) -> INetworkService:
"""
:return: network instance of host
"""
@ -49,6 +49,16 @@ class IHost(ABC):
:return: all the multiaddr addresses this host is listening to
"""
@abstractmethod
def run(
self, listen_addrs: Sequence[multiaddr.Multiaddr]
) -> AsyncContextManager[None]:
"""
run the host instance and listen to ``listen_addrs``.
:param listen_addrs: a sequence of multiaddrs that we want to listen to
"""
@abstractmethod
def set_stream_handler(
self, protocol_id: TProtocol, stream_handler: StreamHandlerFn

View File

@ -1,6 +1,7 @@
import asyncio
import logging
import trio
from libp2p.network.stream.exceptions import StreamClosed, StreamEOF, StreamReset
from libp2p.network.stream.net_stream_interface import INetStream
from libp2p.peer.id import ID as PeerID
@ -17,8 +18,9 @@ async def _handle_ping(stream: INetStream, peer_id: PeerID) -> bool:
"""Return a boolean indicating if we expect more pings from the peer at
``peer_id``."""
try:
payload = await asyncio.wait_for(stream.read(PING_LENGTH), RESP_TIMEOUT)
except asyncio.TimeoutError as error:
with trio.fail_after(RESP_TIMEOUT):
payload = await stream.read(PING_LENGTH)
except trio.TooSlowError as error:
logger.debug("Timed out waiting for ping from %s: %s", peer_id, error)
raise
except StreamEOF:

View File

@ -1,6 +1,6 @@
from libp2p.host.basic_host import BasicHost
from libp2p.host.exceptions import ConnectionFailure
from libp2p.network.network_interface import INetwork
from libp2p.network.network_interface import INetworkService
from libp2p.peer.peerinfo import PeerInfo
from libp2p.routing.interfaces import IPeerRouting
@ -10,7 +10,7 @@ from libp2p.routing.interfaces import IPeerRouting
class RoutedHost(BasicHost):
_router: IPeerRouting
def __init__(self, network: INetwork, router: IPeerRouting):
def __init__(self, network: INetworkService, router: IPeerRouting):
super().__init__(network)
self._router = router

View File

@ -8,7 +8,7 @@ class Closer(ABC):
class Reader(ABC):
@abstractmethod
async def read(self, n: int = -1) -> bytes:
async def read(self, n: int = None) -> bytes:
...

View File

@ -54,7 +54,7 @@ class MsgIOReader(ReadCloser):
self.read_closer = read_closer
self.next_length = None
async def read(self, n: int = -1) -> bytes:
async def read(self, n: int = None) -> bytes:
return await self.read_msg()
async def read_msg(self) -> bytes:

40
libp2p/io/trio.py Normal file
View File

@ -0,0 +1,40 @@
import logging
import trio
from libp2p.io.abc import ReadWriteCloser
from libp2p.io.exceptions import IOException
logger = logging.getLogger("libp2p.io.trio")
class TrioTCPStream(ReadWriteCloser):
stream: trio.SocketStream
# NOTE: Add both read and write lock to avoid `trio.BusyResourceError`
read_lock: trio.Lock
write_lock: trio.Lock
def __init__(self, stream: trio.SocketStream) -> None:
self.stream = stream
self.read_lock = trio.Lock()
self.write_lock = trio.Lock()
async def write(self, data: bytes) -> None:
"""Raise `RawConnError` if the underlying connection breaks."""
async with self.write_lock:
try:
await self.stream.send_all(data)
except (trio.ClosedResourceError, trio.BrokenResourceError) as error:
raise IOException from error
async def read(self, n: int = None) -> bytes:
async with self.read_lock:
if n is not None and n == 0:
return b""
try:
return await self.stream.receive_some(n)
except (trio.ClosedResourceError, trio.BrokenResourceError) as error:
raise IOException from error
async def close(self) -> None:
await self.stream.aclose()

View File

@ -1,6 +1,8 @@
from abc import abstractmethod
from typing import Tuple
import trio
from libp2p.io.abc import Closer
from libp2p.network.stream.net_stream_interface import INetStream
from libp2p.stream_muxer.abc import IMuxedConn
@ -8,11 +10,12 @@ from libp2p.stream_muxer.abc import IMuxedConn
class INetConn(Closer):
muxed_conn: IMuxedConn
event_started: trio.Event
@abstractmethod
async def new_stream(self) -> INetStream:
...
@abstractmethod
async def get_streams(self) -> Tuple[INetStream, ...]:
def get_streams(self) -> Tuple[INetStream, ...]:
...

View File

@ -1,46 +1,26 @@
import asyncio
import sys
from libp2p.io.abc import ReadWriteCloser
from libp2p.io.exceptions import IOException
from .exceptions import RawConnError
from .raw_connection_interface import IRawConnection
class RawConnection(IRawConnection):
reader: asyncio.StreamReader
writer: asyncio.StreamWriter
stream: ReadWriteCloser
is_initiator: bool
_drain_lock: asyncio.Lock
def __init__(
self,
reader: asyncio.StreamReader,
writer: asyncio.StreamWriter,
initiator: bool,
) -> None:
self.reader = reader
self.writer = writer
def __init__(self, stream: ReadWriteCloser, initiator: bool) -> None:
self.stream = stream
self.is_initiator = initiator
self._drain_lock = asyncio.Lock()
async def write(self, data: bytes) -> None:
"""Raise `RawConnError` if the underlying connection breaks."""
# Detect if underlying transport is closing before write data to it
# ref: https://github.com/ethereum/trinity/pull/614
if self.writer.transport.is_closing():
raise RawConnError("Transport is closing")
self.writer.write(data)
# Reference: https://github.com/ethereum/lahja/blob/93610b2eb46969ff1797e0748c7ac2595e130aef/lahja/asyncio/endpoint.py#L99-L102 # noqa: E501
# Use a lock to serialize drain() calls. Circumvents this bug:
# https://bugs.python.org/issue29930
async with self._drain_lock:
try:
await self.writer.drain()
except ConnectionResetError as error:
raise RawConnError() from error
try:
await self.stream.write(data)
except IOException as error:
raise RawConnError from error
async def read(self, n: int = -1) -> bytes:
async def read(self, n: int = None) -> bytes:
"""
Read up to ``n`` bytes from the underlying stream. This call is
delegated directly to the underlying ``self.reader``.
@ -48,18 +28,9 @@ class RawConnection(IRawConnection):
Raise `RawConnError` if the underlying connection breaks
"""
try:
return await self.reader.read(n)
except ConnectionResetError as error:
raise RawConnError() from error
return await self.stream.read(n)
except IOException as error:
raise RawConnError from error
async def close(self) -> None:
if self.writer.transport.is_closing():
return
self.writer.close()
if sys.version_info < (3, 7):
return
try:
await self.writer.wait_closed()
# In case the connection is already reset.
except ConnectionResetError:
return
await self.stream.close()

View File

@ -1,5 +1,6 @@
import asyncio
from typing import TYPE_CHECKING, Any, Awaitable, List, Set, Tuple
from typing import TYPE_CHECKING, Set, Tuple
import trio
from libp2p.network.connection.net_connection_interface import INetConn
from libp2p.network.stream.net_stream import NetStream
@ -19,90 +20,78 @@ class SwarmConn(INetConn):
muxed_conn: IMuxedConn
swarm: "Swarm"
streams: Set[NetStream]
event_closed: asyncio.Event
_tasks: List["asyncio.Future[Any]"]
event_closed: trio.Event
def __init__(self, muxed_conn: IMuxedConn, swarm: "Swarm") -> None:
self.muxed_conn = muxed_conn
self.swarm = swarm
self.streams = set()
self.event_closed = asyncio.Event()
self.event_closed = trio.Event()
self.event_started = trio.Event()
self._tasks = []
@property
def is_closed(self) -> bool:
return self.event_closed.is_set()
async def close(self) -> None:
if self.event_closed.is_set():
return
self.event_closed.set()
await self._cleanup()
async def _cleanup(self) -> None:
self.swarm.remove_conn(self)
await self.muxed_conn.close()
# This is just for cleaning up state. The connection has already been closed.
# We *could* optimize this but it really isn't worth it.
for stream in self.streams:
for stream in self.streams.copy():
await stream.reset()
# Force context switch for stream handlers to process the stream reset event we just emit
# before we cancel the stream handler tasks.
await asyncio.sleep(0.1)
await trio.sleep(0.1)
for task in self._tasks:
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
# Schedule `self._notify_disconnected` to make it execute after `close` is finished.
self._notify_disconnected()
await self._notify_disconnected()
async def _handle_new_streams(self) -> None:
while True:
try:
stream = await self.muxed_conn.accept_stream()
except MuxedConnUnavailable:
# If there is anything wrong in the MuxedConn,
# we should break the loop and close the connection.
break
# Asynchronously handle the accepted stream, to avoid blocking the next stream.
await self.run_task(self._handle_muxed_stream(stream))
await self.close()
async def _call_stream_handler(self, net_stream: NetStream) -> None:
try:
await self.swarm.common_stream_handler(net_stream)
# TODO: More exact exceptions
except Exception:
# TODO: Emit logs.
# TODO: Clean up and remove the stream from SwarmConn if there is anything wrong.
self.remove_stream(net_stream)
self.event_started.set()
async with trio.open_nursery() as nursery:
while True:
try:
stream = await self.muxed_conn.accept_stream()
except MuxedConnUnavailable:
await self.close()
break
# Asynchronously handle the accepted stream, to avoid blocking the next stream.
nursery.start_soon(self._handle_muxed_stream, stream)
async def _handle_muxed_stream(self, muxed_stream: IMuxedStream) -> None:
net_stream = self._add_stream(muxed_stream)
if self.swarm.common_stream_handler is not None:
await self.run_task(self._call_stream_handler(net_stream))
net_stream = await self._add_stream(muxed_stream)
try:
# Ignore type here since mypy complains: https://github.com/python/mypy/issues/2427
await self.swarm.common_stream_handler(net_stream) # type: ignore
finally:
# As long as `common_stream_handler`, remove the stream.
self.remove_stream(net_stream)
def _add_stream(self, muxed_stream: IMuxedStream) -> NetStream:
async def _add_stream(self, muxed_stream: IMuxedStream) -> NetStream:
net_stream = NetStream(muxed_stream)
self.streams.add(net_stream)
self.swarm.notify_opened_stream(net_stream)
await self.swarm.notify_opened_stream(net_stream)
return net_stream
def _notify_disconnected(self) -> None:
self.swarm.notify_disconnected(self)
async def _notify_disconnected(self) -> None:
await self.swarm.notify_disconnected(self)
async def start(self) -> None:
await self.run_task(self._handle_new_streams())
async def run_task(self, coro: Awaitable[Any]) -> None:
self._tasks.append(asyncio.ensure_future(coro))
await self._handle_new_streams()
async def new_stream(self) -> NetStream:
muxed_stream = await self.muxed_conn.open_stream()
return self._add_stream(muxed_stream)
return await self._add_stream(muxed_stream)
async def get_streams(self) -> Tuple[NetStream, ...]:
def get_streams(self) -> Tuple[NetStream, ...]:
return tuple(self.streams)
def remove_stream(self, stream: NetStream) -> None:

View File

@ -1,6 +1,7 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Dict, Sequence
from async_service import ServiceAPI
from multiaddr import Multiaddr
from libp2p.network.connection.net_connection_interface import INetConn
@ -70,3 +71,7 @@ class INetwork(ABC):
@abstractmethod
async def close_peer(self, peer_id: ID) -> None:
pass
class INetworkService(INetwork, ServiceAPI):
pass

View File

@ -37,7 +37,7 @@ class NetStream(INetStream):
"""
self.protocol_id = protocol_id
async def read(self, n: int = -1) -> bytes:
async def read(self, n: int = None) -> bytes:
"""
reads from stream.

View File

@ -1,9 +1,11 @@
import asyncio
import logging
from typing import Dict, List, Optional
from async_service import Service
from multiaddr import Multiaddr
import trio
from libp2p.io.abc import ReadWriteCloser
from libp2p.network.connection.net_connection_interface import INetConn
from libp2p.peer.id import ID
from libp2p.peer.peerstore import PeerStoreError
@ -23,14 +25,21 @@ from ..exceptions import MultiError
from .connection.raw_connection import RawConnection
from .connection.swarm_connection import SwarmConn
from .exceptions import SwarmException
from .network_interface import INetwork
from .network_interface import INetworkService
from .notifee_interface import INotifee
from .stream.net_stream_interface import INetStream
logger = logging.getLogger("libp2p.network.swarm")
class Swarm(INetwork):
def create_default_stream_handler(network: INetworkService) -> StreamHandlerFn:
async def stream_handler(stream: INetStream) -> None:
await network.get_manager().wait_finished()
return stream_handler
class Swarm(Service, INetworkService):
self_id: ID
peerstore: IPeerStore
@ -40,7 +49,9 @@ class Swarm(INetwork):
# whereas in Go one `peer_id` may point to multiple connections.
connections: Dict[ID, INetConn]
listeners: Dict[str, IListener]
common_stream_handler: Optional[StreamHandlerFn]
common_stream_handler: StreamHandlerFn
listener_nursery: Optional[trio.Nursery]
event_listener_nursery_created: trio.Event
notifees: List[INotifee]
@ -61,13 +72,31 @@ class Swarm(INetwork):
# Create Notifee array
self.notifees = []
self.common_stream_handler = None
# Ignore type here since mypy complains: https://github.com/python/mypy/issues/2427
self.common_stream_handler = create_default_stream_handler(self) # type: ignore
self.listener_nursery = None
self.event_listener_nursery_created = trio.Event()
async def run(self) -> None:
async with trio.open_nursery() as nursery:
# Create a nursery for listener tasks.
self.listener_nursery = nursery
self.event_listener_nursery_created.set()
try:
await self.manager.wait_finished()
finally:
# The service ended. Cancel listener tasks.
nursery.cancel_scope.cancel()
# Indicate that the nursery has been cancelled.
self.listener_nursery = None
def get_peer_id(self) -> ID:
return self.self_id
def set_stream_handler(self, stream_handler: StreamHandlerFn) -> None:
self.common_stream_handler = stream_handler
# Ignore type here since mypy complains: https://github.com/python/mypy/issues/2427
self.common_stream_handler = stream_handler # type: ignore
async def dial_peer(self, peer_id: ID) -> INetConn:
"""
@ -195,19 +224,15 @@ class Swarm(INetwork):
- Call listener listen with the multiaddr
- Map multiaddr to listener
"""
# We need to wait until `self.listener_nursery` is created.
await self.event_listener_nursery_created.wait()
for maddr in multiaddrs:
if str(maddr) in self.listeners:
return True
async def conn_handler(
reader: asyncio.StreamReader, writer: asyncio.StreamWriter
) -> None:
connection_info = writer.get_extra_info("peername")
# TODO make a proper multiaddr
peer_addr = f"/ip4/{connection_info[0]}/tcp/{connection_info[1]}"
logger.debug("inbound connection at %s", peer_addr)
# logger.debug("inbound connection request", peer_id)
raw_conn = RawConnection(reader, writer, False)
async def conn_handler(read_write_closer: ReadWriteCloser) -> None:
raw_conn = RawConnection(read_write_closer, False)
# Per, https://discuss.libp2p.io/t/multistream-security/130, we first secure
# the conn and then mux the conn
@ -217,16 +242,13 @@ class Swarm(INetwork):
raw_conn, ID(b""), False
)
except SecurityUpgradeFailure as error:
logger.debug("failed to upgrade security for peer at %s", peer_addr)
logger.debug("failed to upgrade security for peer at %s", maddr)
await raw_conn.close()
raise SwarmException(
f"failed to upgrade security for peer at {peer_addr}"
f"failed to upgrade security for peer at {maddr}"
) from error
peer_id = secured_conn.get_remote_peer()
logger.debug("upgraded security for peer at %s", peer_addr)
logger.debug("identified peer at %s as %s", peer_addr, peer_id)
try:
muxed_conn = await self.upgrader.upgrade_connection(
secured_conn, peer_id
@ -240,17 +262,24 @@ class Swarm(INetwork):
logger.debug("upgraded mux for peer %s", peer_id)
await self.add_conn(muxed_conn)
logger.debug("successfully opened connection to peer %s", peer_id)
# NOTE: This is a intentional barrier to prevent from the handler exiting and
# closing the connection.
await self.manager.wait_finished()
try:
# Success
listener = self.transport.create_listener(conn_handler)
self.listeners[str(maddr)] = listener
await listener.listen(maddr)
# TODO: `listener.listen` is not bounded with nursery. If we want to be
# I/O agnostic, we should change the API.
if self.listener_nursery is None:
raise SwarmException("swarm instance hasn't been run")
await listener.listen(maddr, self.listener_nursery)
# Call notifiers since event occurred
self.notify_listen(maddr)
await self.notify_listen(maddr)
return True
except IOError:
@ -261,26 +290,12 @@ class Swarm(INetwork):
return False
async def close(self) -> None:
# TODO: Prevent from new listeners and conns being added.
# Reference: https://github.com/libp2p/go-libp2p-swarm/blob/8be680aef8dea0a4497283f2f98470c2aeae6b65/swarm.go#L124-L134 # noqa: E501
# Close listeners
await asyncio.gather(
*[listener.close() for listener in self.listeners.values()]
)
# Close connections
await asyncio.gather(
*[connection.close() for connection in self.connections.values()]
)
await self.manager.stop()
logger.debug("swarm successfully closed")
async def close_peer(self, peer_id: ID) -> None:
if peer_id not in self.connections:
return
# TODO: Should be changed to close multisple connections,
# if we have several connections per peer in the future.
connection = self.connections[peer_id]
# NOTE: `connection.close` will delete `peer_id` from `self.connections`
# and `notify_disconnected` for us.
@ -293,11 +308,14 @@ class Swarm(INetwork):
and start to monitor the connection for its new streams and
disconnection."""
swarm_conn = SwarmConn(muxed_conn, self)
self.manager.run_task(muxed_conn.start)
await muxed_conn.event_started.wait()
self.manager.run_task(swarm_conn.start)
await swarm_conn.event_started.wait()
# Store muxed_conn with peer id
self.connections[muxed_conn.peer_id] = swarm_conn
# Call notifiers since event occurred
self.notify_connected(swarm_conn)
await swarm_conn.start()
await self.notify_connected(swarm_conn)
return swarm_conn
def remove_conn(self, swarm_conn: SwarmConn) -> None:
@ -306,14 +324,10 @@ class Swarm(INetwork):
peer_id = swarm_conn.muxed_conn.peer_id
if peer_id not in self.connections:
return
# TODO: Should be changed to remove the exact connection,
# if we have several connections per peer in the future.
del self.connections[peer_id]
# Notifee
# TODO: Remeber the spawn notifying tasks and clean them up when closing.
def register_notifee(self, notifee: INotifee) -> None:
"""
:param notifee: object implementing Notifee interface
@ -321,20 +335,28 @@ class Swarm(INetwork):
"""
self.notifees.append(notifee)
def notify_opened_stream(self, stream: INetStream) -> None:
asyncio.gather(
*[notifee.opened_stream(self, stream) for notifee in self.notifees]
)
async def notify_opened_stream(self, stream: INetStream) -> None:
async with trio.open_nursery() as nursery:
for notifee in self.notifees:
nursery.start_soon(notifee.opened_stream, self, stream)
# TODO: `notify_closed_stream`
async def notify_connected(self, conn: INetConn) -> None:
async with trio.open_nursery() as nursery:
for notifee in self.notifees:
nursery.start_soon(notifee.connected, self, conn)
def notify_connected(self, conn: INetConn) -> None:
asyncio.gather(*[notifee.connected(self, conn) for notifee in self.notifees])
async def notify_disconnected(self, conn: INetConn) -> None:
async with trio.open_nursery() as nursery:
for notifee in self.notifees:
nursery.start_soon(notifee.disconnected, self, conn)
def notify_disconnected(self, conn: INetConn) -> None:
asyncio.gather(*[notifee.disconnected(self, conn) for notifee in self.notifees])
async def notify_listen(self, multiaddr: Multiaddr) -> None:
async with trio.open_nursery() as nursery:
for notifee in self.notifees:
nursery.start_soon(notifee.listen, self, multiaddr)
def notify_listen(self, multiaddr: Multiaddr) -> None:
asyncio.gather(*[notifee.listen(self, multiaddr) for notifee in self.notifees])
async def notify_closed_stream(self, stream: INetStream) -> None:
raise NotImplementedError
# TODO: `notify_listen_close`
async def notify_listen_close(self, multiaddr: Multiaddr) -> None:
raise NotImplementedError

View File

@ -25,9 +25,6 @@ def info_from_p2p_addr(addr: multiaddr.Multiaddr) -> PeerInfo:
if not addr:
raise InvalidAddrError("`addr` should not be `None`")
if not isinstance(addr, multiaddr.Multiaddr):
raise InvalidAddrError(f"`addr`={addr} should be of type `Multiaddr`")
parts = addr.split()
if not parts:
raise InvalidAddrError(

View File

@ -1,15 +1,37 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, List
from typing import (
TYPE_CHECKING,
AsyncContextManager,
AsyncIterable,
KeysView,
List,
Tuple,
)
from async_service import ServiceAPI
from libp2p.peer.id import ID
from libp2p.typing import TProtocol
from .pb import rpc_pb2
from .typing import ValidatorFn
if TYPE_CHECKING:
from .pubsub import Pubsub # noqa: F401
class ISubscriptionAPI(
AsyncContextManager["ISubscriptionAPI"], AsyncIterable[rpc_pb2.Message]
):
@abstractmethod
async def unsubscribe(self) -> None:
...
@abstractmethod
async def get(self) -> rpc_pb2.Message:
...
class IPubsubRouter(ABC):
@abstractmethod
def get_protocols(self) -> List[TProtocol]:
@ -53,7 +75,6 @@ class IPubsubRouter(ABC):
:param rpc: rpc message
"""
# FIXME: Should be changed to type 'peer.ID'
@abstractmethod
async def publish(self, msg_forwarder: ID, pubsub_msg: rpc_pb2.Message) -> None:
"""
@ -80,3 +101,46 @@ class IPubsubRouter(ABC):
:param topic: topic to leave
"""
class IPubsub(ServiceAPI):
@property
@abstractmethod
def my_id(self) -> ID:
...
@property
@abstractmethod
def protocols(self) -> Tuple[TProtocol, ...]:
...
@property
@abstractmethod
def topic_ids(self) -> KeysView[str]:
...
@abstractmethod
def set_topic_validator(
self, topic: str, validator: ValidatorFn, is_async_validator: bool
) -> None:
...
@abstractmethod
def remove_topic_validator(self, topic: str) -> None:
...
@abstractmethod
async def wait_until_ready(self) -> None:
...
@abstractmethod
async def subscribe(self, topic_id: str) -> ISubscriptionAPI:
...
@abstractmethod
async def unsubscribe(self, topic_id: str) -> None:
...
@abstractmethod
async def publish(self, topic_id: str, data: bytes) -> None:
...

View File

@ -0,0 +1,9 @@
from libp2p.exceptions import BaseLibp2pError
class PubsubRouterError(BaseLibp2pError):
pass
class NoPubsubAttached(PubsubRouterError):
pass

View File

@ -1,14 +1,16 @@
import logging
from typing import Iterable, List, Sequence
import trio
from libp2p.network.stream.exceptions import StreamClosed
from libp2p.peer.id import ID
from libp2p.typing import TProtocol
from libp2p.utils import encode_varint_prefixed
from .abc import IPubsubRouter
from .pb import rpc_pb2
from .pubsub import Pubsub
from .pubsub_router_interface import IPubsubRouter
PROTOCOL_ID = TProtocol("/floodsub/1.0.0")
@ -61,6 +63,8 @@ class FloodSub(IPubsubRouter):
:param rpc: rpc message
"""
# Checkpoint
await trio.hazmat.checkpoint()
async def publish(self, msg_forwarder: ID, pubsub_msg: rpc_pb2.Message) -> None:
"""
@ -107,6 +111,8 @@ class FloodSub(IPubsubRouter):
:param topic: topic to join
"""
# Checkpoint
await trio.hazmat.checkpoint()
async def leave(self, topic: str) -> None:
"""
@ -115,6 +121,8 @@ class FloodSub(IPubsubRouter):
:param topic: topic to leave
"""
# Checkpoint
await trio.hazmat.checkpoint()
def _get_peers_to_send(
self, topic_ids: Iterable[str], msg_forwarder: ID, origin: ID

View File

@ -1,28 +1,30 @@
from ast import literal_eval
import asyncio
from collections import defaultdict
import logging
import random
from typing import Any, DefaultDict, Dict, Iterable, List, Sequence, Set, Tuple
from async_service import Service
import trio
from libp2p.network.stream.exceptions import StreamClosed
from libp2p.peer.id import ID
from libp2p.pubsub import floodsub
from libp2p.typing import TProtocol
from libp2p.utils import encode_varint_prefixed
from .abc import IPubsubRouter
from .exceptions import NoPubsubAttached
from .mcache import MessageCache
from .pb import rpc_pb2
from .pubsub import Pubsub
from .pubsub_router_interface import IPubsubRouter
PROTOCOL_ID = TProtocol("/meshsub/1.0.0")
logger = logging.getLogger("libp2p.pubsub.gossipsub")
class GossipSub(IPubsubRouter):
class GossipSub(IPubsubRouter, Service):
protocols: List[TProtocol]
pubsub: Pubsub
@ -38,7 +40,8 @@ class GossipSub(IPubsubRouter):
# The protocol peer supports
peer_protocol: Dict[ID, TProtocol]
time_since_last_publish: Dict[str, int]
# TODO: Add `time_since_last_publish`
# Create topic --> time since last publish map.
mcache: MessageCache
@ -75,9 +78,6 @@ class GossipSub(IPubsubRouter):
# Create peer --> protocol mapping
self.peer_protocol = {}
# Create topic --> time since last publish map
self.time_since_last_publish = {}
# Create message cache
self.mcache = MessageCache(gossip_window, gossip_history)
@ -85,6 +85,12 @@ class GossipSub(IPubsubRouter):
self.heartbeat_initial_delay = heartbeat_initial_delay
self.heartbeat_interval = heartbeat_interval
async def run(self) -> None:
if self.pubsub is None:
raise NoPubsubAttached
self.manager.run_daemon_task(self.heartbeat)
await self.manager.wait_finished()
# Interface functions
def get_protocols(self) -> List[TProtocol]:
@ -104,9 +110,6 @@ class GossipSub(IPubsubRouter):
logger.debug("attached to pusub")
# Start heartbeat now that we have a pubsub instance
asyncio.ensure_future(self.heartbeat())
def add_peer(self, peer_id: ID, protocol_id: TProtocol) -> None:
"""
Notifies the router that a new peer has been connected.
@ -370,7 +373,7 @@ class GossipSub(IPubsubRouter):
state changes in the preceding heartbeat
"""
# Start after a delay. Ref: https://github.com/libp2p/go-libp2p-pubsub/blob/01b9825fbee1848751d90a8469e3f5f43bac8466/gossipsub.go#L410 # Noqa: E501
await asyncio.sleep(self.heartbeat_initial_delay)
await trio.sleep(self.heartbeat_initial_delay)
while True:
# Maintain mesh and keep track of which peers to send GRAFT or PRUNE to
peers_to_graft, peers_to_prune = self.mesh_heartbeat()
@ -385,7 +388,7 @@ class GossipSub(IPubsubRouter):
self.mcache.shift()
await asyncio.sleep(self.heartbeat_interval)
await trio.sleep(self.heartbeat_interval)
def mesh_heartbeat(
self
@ -413,7 +416,7 @@ class GossipSub(IPubsubRouter):
if num_mesh_peers_in_topic > self.degree_high:
# Select |mesh[topic]| - D peers from mesh[topic]
selected_peers = GossipSub.select_from_minus(
selected_peers = self.select_from_minus(
num_mesh_peers_in_topic - self.degree, self.mesh[topic], set()
)
for peer in selected_peers:
@ -428,15 +431,10 @@ class GossipSub(IPubsubRouter):
# Note: the comments here are the exact pseudocode from the spec
for topic in self.fanout:
# Delete topic entry if it's not in `pubsub.peer_topics`
# or if it's time-since-last-published > ttl
# TODO: there's no way time_since_last_publish gets set anywhere yet
if (
topic not in self.pubsub.peer_topics
or self.time_since_last_publish[topic] > self.time_to_live
):
# or (TODO) if it's time-since-last-published > ttl
if topic not in self.pubsub.peer_topics:
# Remove topic from fanout
del self.fanout[topic]
del self.time_since_last_publish[topic]
else:
# Check if fanout peers are still in the topic and remove the ones that are not
# ref: https://github.com/libp2p/go-libp2p-pubsub/blob/01b9825fbee1848751d90a8469e3f5f43bac8466/gossipsub.go#L498-L504 # noqa: E501

View File

@ -1,21 +1,12 @@
import asyncio
import functools
import logging
import time
from typing import (
TYPE_CHECKING,
Awaitable,
Callable,
Dict,
List,
NamedTuple,
Set,
Tuple,
Union,
cast,
)
from typing import TYPE_CHECKING, Dict, KeysView, List, NamedTuple, Set, Tuple, cast
from async_service import Service
import base58
from lru import LRU
import trio
from libp2p.crypto.keys import PrivateKey
from libp2p.exceptions import ParseError, ValidationError
@ -28,15 +19,21 @@ from libp2p.peer.id import ID
from libp2p.typing import TProtocol
from libp2p.utils import encode_varint_prefixed, read_varint_prefixed_bytes
from .abc import IPubsub, ISubscriptionAPI
from .pb import rpc_pb2
from .pubsub_notifee import PubsubNotifee
from .subscription import TrioSubscriptionAPI
from .typing import AsyncValidatorFn, SyncValidatorFn, ValidatorFn
from .validators import PUBSUB_SIGNING_PREFIX, signature_validator
if TYPE_CHECKING:
from .pubsub_router_interface import IPubsubRouter # noqa: F401
from .abc import IPubsubRouter # noqa: F401
from typing import Any # noqa: F401
# Ref: https://github.com/libp2p/go-libp2p-pubsub/blob/40e1c94708658b155f30cf99e4574f384756d83c/topic.go#L97 # noqa: E501
SUBSCRIPTION_CHANNEL_SIZE = 32
logger = logging.getLogger("libp2p.pubsub")
@ -45,34 +42,24 @@ def get_msg_id(msg: rpc_pb2.Message) -> Tuple[bytes, bytes]:
return (msg.seqno, msg.from_id)
SyncValidatorFn = Callable[[ID, rpc_pb2.Message], bool]
AsyncValidatorFn = Callable[[ID, rpc_pb2.Message], Awaitable[bool]]
ValidatorFn = Union[SyncValidatorFn, AsyncValidatorFn]
class TopicValidator(NamedTuple):
validator: ValidatorFn
is_async: bool
class Pubsub:
class Pubsub(Service, IPubsub):
host: IHost
my_id: ID
router: "IPubsubRouter"
peer_queue: "asyncio.Queue[ID]"
dead_peer_queue: "asyncio.Queue[ID]"
protocols: List[TProtocol]
incoming_msgs_from_peers: "asyncio.Queue[rpc_pb2.Message]"
outgoing_messages: "asyncio.Queue[rpc_pb2.Message]"
peer_receive_channel: "trio.MemoryReceiveChannel[ID]"
dead_peer_receive_channel: "trio.MemoryReceiveChannel[ID]"
seen_messages: LRU
my_topics: Dict[str, "asyncio.Queue[rpc_pb2.Message]"]
subscribed_topics_send: Dict[str, "trio.MemorySendChannel[rpc_pb2.Message]"]
subscribed_topics_receive: Dict[str, "TrioSubscriptionAPI"]
peer_topics: Dict[str, Set[ID]]
peers: Dict[ID, INetStream]
@ -81,17 +68,17 @@ class Pubsub:
counter: int # uint64
_tasks: List["asyncio.Future[Any]"]
# Indicate if we should enforce signature verification
strict_signing: bool
sign_key: PrivateKey
event_handle_peer_queue_started: trio.Event
event_handle_dead_peer_queue_started: trio.Event
def __init__(
self,
host: IHost,
router: "IPubsubRouter",
my_id: ID,
cache_size: int = None,
strict_signing: bool = True,
) -> None:
@ -107,39 +94,44 @@ class Pubsub:
"""
self.host = host
self.router = router
self.my_id = my_id
# Attach this new Pubsub object to the router
self.router.attach(self)
peer_send, peer_receive = trio.open_memory_channel[ID](0)
dead_peer_send, dead_peer_receive = trio.open_memory_channel[ID](0)
# Only keep the receive channels in `Pubsub`.
# Therefore, we can only close from the receive side.
self.peer_receive_channel = peer_receive
self.dead_peer_receive_channel = dead_peer_receive
# Register a notifee
self.peer_queue = asyncio.Queue()
self.dead_peer_queue = asyncio.Queue()
self.host.get_network().register_notifee(
PubsubNotifee(self.peer_queue, self.dead_peer_queue)
PubsubNotifee(peer_send, dead_peer_send)
)
# Register stream handlers for each pubsub router protocol to handle
# the pubsub streams opened on those protocols
self.protocols = self.router.get_protocols()
for protocol in self.protocols:
for protocol in router.get_protocols():
self.host.set_stream_handler(protocol, self.stream_handler)
# Use asyncio queues for proper context switching
self.incoming_msgs_from_peers = asyncio.Queue()
self.outgoing_messages = asyncio.Queue()
# keeps track of seen messages as LRU cache
if cache_size is None:
self.cache_size = 128
else:
self.cache_size = cache_size
self.strict_signing = strict_signing
if strict_signing:
self.sign_key = self.host.get_private_key()
else:
self.sign_key = None
self.seen_messages = LRU(self.cache_size)
# Map of topics we are subscribed to blocking queues
# for when the given topic receives a message
self.my_topics = {}
self.subscribed_topics_send = {}
self.subscribed_topics_receive = {}
# Map of topic to peers to keep track of what peers are subscribed to
self.peer_topics = {}
@ -152,22 +144,31 @@ class Pubsub:
self.counter = int(time.time())
self._tasks = []
# Call handle peer to keep waiting for updates to peer queue
self._tasks.append(asyncio.ensure_future(self.handle_peer_queue()))
self._tasks.append(asyncio.ensure_future(self.handle_dead_peer_queue()))
self.event_handle_peer_queue_started = trio.Event()
self.event_handle_dead_peer_queue_started = trio.Event()
self.strict_signing = strict_signing
if strict_signing:
self.sign_key = self.host.get_private_key()
else:
self.sign_key = None
async def run(self) -> None:
self.manager.run_daemon_task(self.handle_peer_queue)
self.manager.run_daemon_task(self.handle_dead_peer_queue)
await self.manager.wait_finished()
@property
def my_id(self) -> ID:
return self.host.get_id()
@property
def protocols(self) -> Tuple[TProtocol, ...]:
return tuple(self.router.get_protocols())
@property
def topic_ids(self) -> KeysView[str]:
return self.subscribed_topics_receive.keys()
def get_hello_packet(self) -> rpc_pb2.RPC:
"""Generate subscription message with all topics we are subscribed to
only send hello packet if we have subscribed topics."""
packet = rpc_pb2.RPC()
for topic_id in self.my_topics:
for topic_id in self.topic_ids:
packet.subscriptions.extend(
[rpc_pb2.RPC.SubOpts(subscribe=True, topicid=topic_id)]
)
@ -182,7 +183,7 @@ class Pubsub:
"""
peer_id = stream.muxed_conn.peer_id
while True:
while self.manager.is_running:
incoming: bytes = await read_varint_prefixed_bytes(stream)
rpc_incoming: rpc_pb2.RPC = rpc_pb2.RPC()
rpc_incoming.ParseFromString(incoming)
@ -194,11 +195,7 @@ class Pubsub:
logger.debug(
"received `publish` message %s from peer %s", msg, peer_id
)
self._tasks.append(
asyncio.ensure_future(
self.push_msg(msg_forwarder=peer_id, msg=msg)
)
)
self.manager.run_task(self.push_msg, peer_id, msg)
if rpc_incoming.subscriptions:
# deal with RPC.subscriptions
@ -226,9 +223,6 @@ class Pubsub:
)
await self.router.handle_rpc(rpc_incoming, peer_id)
# Force context switch
await asyncio.sleep(0)
def set_topic_validator(
self, topic: str, validator: ValidatorFn, is_async_validator: bool
) -> None:
@ -283,6 +277,10 @@ class Pubsub:
await stream.reset()
self._handle_dead_peer(peer_id)
async def wait_until_ready(self) -> None:
await self.event_handle_peer_queue_started.wait()
await self.event_handle_dead_peer_queue_started.wait()
async def _handle_new_peer(self, peer_id: ID) -> None:
try:
stream: INetStream = await self.host.new_stream(peer_id, self.protocols)
@ -325,18 +323,21 @@ class Pubsub:
"""Continuously read from peer queue and each time a new peer is found,
open a stream to the peer using a supported pubsub protocol pubsub
protocols we support."""
while True:
peer_id: ID = await self.peer_queue.get()
# Add Peer
self._tasks.append(asyncio.ensure_future(self._handle_new_peer(peer_id)))
async with self.peer_receive_channel:
self.event_handle_peer_queue_started.set()
async for peer_id in self.peer_receive_channel:
# Add Peer
self.manager.run_task(self._handle_new_peer, peer_id)
async def handle_dead_peer_queue(self) -> None:
"""Continuously read from dead peer queue and close the stream between
that peer and remove peer info from pubsub and pubsub router."""
while True:
peer_id: ID = await self.dead_peer_queue.get()
# Remove Peer
self._handle_dead_peer(peer_id)
"""Continuously read from dead peer channel and close the stream
between that peer and remove peer info from pubsub and pubsub
router."""
async with self.dead_peer_receive_channel:
self.event_handle_dead_peer_queue_started.set()
async for peer_id in self.dead_peer_receive_channel:
# Remove Peer
self._handle_dead_peer(peer_id)
def handle_subscription(
self, origin_id: ID, sub_message: rpc_pb2.RPC.SubOpts
@ -360,8 +361,7 @@ class Pubsub:
if origin_id in self.peer_topics[sub_message.topicid]:
self.peer_topics[sub_message.topicid].discard(origin_id)
# FIXME(mhchia): Change the function name?
async def handle_talk(self, publish_message: rpc_pb2.Message) -> None:
def notify_subscriptions(self, publish_message: rpc_pb2.Message) -> None:
"""
Put incoming message from a peer onto my blocking queue.
@ -370,13 +370,19 @@ class Pubsub:
# Check if this message has any topics that we are subscribed to
for topic in publish_message.topicIDs:
if topic in self.my_topics:
if topic in self.topic_ids:
# we are subscribed to a topic this message was sent for,
# so add message to the subscription output queue
# for each topic
await self.my_topics[topic].put(publish_message)
try:
self.subscribed_topics_send[topic].send_nowait(publish_message)
except trio.WouldBlock:
# Channel is full, ignore this message.
logger.warning(
"fail to deliver message to subscription for topic %s", topic
)
async def subscribe(self, topic_id: str) -> "asyncio.Queue[rpc_pb2.Message]":
async def subscribe(self, topic_id: str) -> ISubscriptionAPI:
"""
Subscribe ourself to a topic.
@ -386,11 +392,19 @@ class Pubsub:
logger.debug("subscribing to topic %s", topic_id)
# Already subscribed
if topic_id in self.my_topics:
return self.my_topics[topic_id]
if topic_id in self.topic_ids:
return self.subscribed_topics_receive[topic_id]
# Map topic_id to blocking queue
self.my_topics[topic_id] = asyncio.Queue()
send_channel, receive_channel = trio.open_memory_channel[rpc_pb2.Message](
SUBSCRIPTION_CHANNEL_SIZE
)
subscription = TrioSubscriptionAPI(
receive_channel,
unsubscribe_fn=functools.partial(self.unsubscribe, topic_id),
)
self.subscribed_topics_send[topic_id] = send_channel
self.subscribed_topics_receive[topic_id] = subscription
# Create subscribe message
packet: rpc_pb2.RPC = rpc_pb2.RPC()
@ -404,8 +418,8 @@ class Pubsub:
# Tell router we are joining this topic
await self.router.join(topic_id)
# Return the asyncio queue for messages on this topic
return self.my_topics[topic_id]
# Return the subscription for messages on this topic
return subscription
async def unsubscribe(self, topic_id: str) -> None:
"""
@ -417,10 +431,14 @@ class Pubsub:
logger.debug("unsubscribing from topic %s", topic_id)
# Return if we already unsubscribed from the topic
if topic_id not in self.my_topics:
if topic_id not in self.topic_ids:
return
# Remove topic_id from map if present
del self.my_topics[topic_id]
# Remove topic_id from the maps before yielding
send_channel = self.subscribed_topics_send[topic_id]
del self.subscribed_topics_send[topic_id]
del self.subscribed_topics_receive[topic_id]
# Only close the send side
await send_channel.aclose()
# Create unsubscribe message
packet: rpc_pb2.RPC = rpc_pb2.RPC()
@ -462,7 +480,7 @@ class Pubsub:
data=data,
topicIDs=[topic_id],
# Origin is ourself.
from_id=self.host.get_id().to_bytes(),
from_id=self.my_id.to_bytes(),
seqno=self._next_seqno(),
)
@ -474,7 +492,7 @@ class Pubsub:
msg.key = self.host.get_public_key().serialize()
msg.signature = signature
await self.push_msg(self.host.get_id(), msg)
await self.push_msg(self.my_id, msg)
logger.debug("successfully published message %s", msg)
@ -485,12 +503,12 @@ class Pubsub:
:param msg_forwarder: the peer who forward us the message.
:param msg: the message.
"""
sync_topic_validators = []
async_topic_validator_futures: List[Awaitable[bool]] = []
sync_topic_validators: List[SyncValidatorFn] = []
async_topic_validators: List[AsyncValidatorFn] = []
for topic_validator in self.get_msg_validators(msg):
if topic_validator.is_async:
async_topic_validator_futures.append(
cast(Awaitable[bool], topic_validator.validator(msg_forwarder, msg))
async_topic_validators.append(
cast(AsyncValidatorFn, topic_validator.validator)
)
else:
sync_topic_validators.append(
@ -503,9 +521,20 @@ class Pubsub:
# TODO: Implement throttle on async validators
if len(async_topic_validator_futures) > 0:
results = await asyncio.gather(*async_topic_validator_futures)
if not all(results):
if len(async_topic_validators) > 0:
# TODO: Use a better pattern
final_result: bool = True
async def run_async_validator(func: AsyncValidatorFn) -> None:
nonlocal final_result
result = await func(msg_forwarder, msg)
final_result = final_result and result
async with trio.open_nursery() as nursery:
for async_validator in async_topic_validators:
nursery.start_soon(run_async_validator, async_validator)
if not final_result:
raise ValidationError(f"Validation failed for msg={msg}")
async def push_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None:
@ -548,7 +577,7 @@ class Pubsub:
return
self._mark_msg_seen(msg)
await self.handle_talk(msg)
self.notify_subscriptions(msg)
await self.router.publish(msg_forwarder, msg)
def _next_seqno(self) -> bytes:
@ -567,14 +596,4 @@ class Pubsub:
self.seen_messages[msg_id] = 1
def _is_subscribed_to_msg(self, msg: rpc_pb2.Message) -> bool:
if not self.my_topics:
return False
return any(topic in self.my_topics for topic in msg.topicIDs)
async def close(self) -> None:
for task in self._tasks:
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
return any(topic in self.topic_ids for topic in msg.topicIDs)

View File

@ -1,6 +1,7 @@
from typing import TYPE_CHECKING
from multiaddr import Multiaddr
import trio
from libp2p.network.connection.net_connection_interface import INetConn
from libp2p.network.network_interface import INetwork
@ -8,19 +9,18 @@ from libp2p.network.notifee_interface import INotifee
from libp2p.network.stream.net_stream_interface import INetStream
if TYPE_CHECKING:
import asyncio # noqa: F401
from libp2p.peer.id import ID # noqa: F401
class PubsubNotifee(INotifee):
initiator_peers_queue: "asyncio.Queue[ID]"
dead_peers_queue: "asyncio.Queue[ID]"
initiator_peers_queue: "trio.MemorySendChannel[ID]"
dead_peers_queue: "trio.MemorySendChannel[ID]"
def __init__(
self,
initiator_peers_queue: "asyncio.Queue[ID]",
dead_peers_queue: "asyncio.Queue[ID]",
initiator_peers_queue: "trio.MemorySendChannel[ID]",
dead_peers_queue: "trio.MemorySendChannel[ID]",
) -> None:
"""
:param initiator_peers_queue: queue to add new peers to so that pubsub
@ -32,10 +32,10 @@ class PubsubNotifee(INotifee):
self.dead_peers_queue = dead_peers_queue
async def opened_stream(self, network: INetwork, stream: INetStream) -> None:
pass
await trio.hazmat.checkpoint()
async def closed_stream(self, network: INetwork, stream: INetStream) -> None:
pass
await trio.hazmat.checkpoint()
async def connected(self, network: INetwork, conn: INetConn) -> None:
"""
@ -46,7 +46,11 @@ class PubsubNotifee(INotifee):
:param network: network the connection was opened on
:param conn: connection that was opened
"""
await self.initiator_peers_queue.put(conn.muxed_conn.peer_id)
try:
await self.initiator_peers_queue.send(conn.muxed_conn.peer_id)
except trio.BrokenResourceError:
# The receive channel is closed by Pubsub. We should do nothing here.
pass
async def disconnected(self, network: INetwork, conn: INetConn) -> None:
"""
@ -56,10 +60,14 @@ class PubsubNotifee(INotifee):
:param network: network the connection was opened on
:param conn: connection that was opened
"""
await self.dead_peers_queue.put(conn.muxed_conn.peer_id)
try:
await self.dead_peers_queue.send(conn.muxed_conn.peer_id)
except trio.BrokenResourceError:
# The receive channel is closed by Pubsub. We should do nothing here.
pass
async def listen(self, network: INetwork, multiaddr: Multiaddr) -> None:
pass
await trio.hazmat.checkpoint()
async def listen_close(self, network: INetwork, multiaddr: Multiaddr) -> None:
pass
await trio.hazmat.checkpoint()

View File

@ -0,0 +1,46 @@
from types import TracebackType
from typing import AsyncIterator, Optional, Type
import trio
from .abc import ISubscriptionAPI
from .pb import rpc_pb2
from .typing import UnsubscribeFn
class BaseSubscriptionAPI(ISubscriptionAPI):
async def __aenter__(self) -> "BaseSubscriptionAPI":
await trio.hazmat.checkpoint()
return self
async def __aexit__(
self,
exc_type: "Optional[Type[BaseException]]",
exc_value: "Optional[BaseException]",
traceback: "Optional[TracebackType]",
) -> None:
await self.unsubscribe()
class TrioSubscriptionAPI(BaseSubscriptionAPI):
receive_channel: "trio.MemoryReceiveChannel[rpc_pb2.Message]"
unsubscribe_fn: UnsubscribeFn
def __init__(
self,
receive_channel: "trio.MemoryReceiveChannel[rpc_pb2.Message]",
unsubscribe_fn: UnsubscribeFn,
) -> None:
self.receive_channel = receive_channel
# Ignore type here since mypy complains: https://github.com/python/mypy/issues/2427
self.unsubscribe_fn = unsubscribe_fn # type: ignore
async def unsubscribe(self) -> None:
# Ignore type here since mypy complains: https://github.com/python/mypy/issues/2427
await self.unsubscribe_fn() # type: ignore
def __aiter__(self) -> AsyncIterator[rpc_pb2.Message]:
return self.receive_channel.__aiter__()
async def get(self) -> rpc_pb2.Message:
return await self.receive_channel.receive()

11
libp2p/pubsub/typing.py Normal file
View File

@ -0,0 +1,11 @@
from typing import Awaitable, Callable, Union
from libp2p.peer.id import ID
from .pb import rpc_pb2
SyncValidatorFn = Callable[[ID, rpc_pb2.Message], bool]
AsyncValidatorFn = Callable[[ID, rpc_pb2.Message], Awaitable[bool]]
ValidatorFn = Union[SyncValidatorFn, AsyncValidatorFn]
UnsubscribeFn = Callable[[], Awaitable[None]]

View File

@ -39,7 +39,7 @@ class InsecureSession(BaseSession):
await self.conn.write(data)
return len(data)
async def read(self, n: int = -1) -> bytes:
async def read(self, n: int = None) -> bytes:
return await self.conn.read(n)
async def close(self) -> None:

View File

@ -94,7 +94,7 @@ class SecureSession(BaseSession):
data = self.buf.getbuffer()[self.low_watermark : self.high_watermark]
if n < 0:
if n is None:
n = len(data)
result = data[:n].tobytes()
self.low_watermark += len(result)
@ -111,7 +111,7 @@ class SecureSession(BaseSession):
self.low_watermark = 0
self.high_watermark = len(msg)
async def read(self, n: int = -1) -> bytes:
async def read(self, n: int = None) -> bytes:
if n == 0:
return bytes()

View File

@ -1,5 +1,7 @@
from abc import ABC, abstractmethod
import trio
from libp2p.io.abc import ReadWriteCloser
from libp2p.peer.id import ID
from libp2p.security.secure_conn_interface import ISecureConn
@ -11,6 +13,7 @@ class IMuxedConn(ABC):
"""
peer_id: ID
event_started: trio.Event
@abstractmethod
def __init__(self, conn: ISecureConn, peer_id: ID) -> None:
@ -25,12 +28,17 @@ class IMuxedConn(ABC):
@property
@abstractmethod
def is_initiator(self) -> bool:
pass
"""if this connection is the initiator."""
@abstractmethod
async def start(self) -> None:
"""start the multiplexer."""
@abstractmethod
async def close(self) -> None:
"""close connection."""
@property
@abstractmethod
def is_closed(self) -> bool:
"""

View File

@ -1,7 +1,7 @@
import asyncio
import logging
from typing import Any # noqa: F401
from typing import Awaitable, Dict, List, Optional, Tuple
from typing import Dict, Optional, Tuple
import trio
from libp2p.exceptions import ParseError
from libp2p.io.exceptions import IncompleteReadError
@ -23,6 +23,8 @@ from .exceptions import MplexUnavailable
from .mplex_stream import MplexStream
MPLEX_PROTOCOL_ID = TProtocol("/mplex/6.7.0")
# Ref: https://github.com/libp2p/go-mplex/blob/414db61813d9ad3e6f4a7db5c1b1612de343ace9/multiplex.go#L115 # noqa: E501
MPLEX_MESSAGE_CHANNEL_SIZE = 8
logger = logging.getLogger("libp2p.stream_muxer.mplex.mplex")
@ -36,12 +38,14 @@ class Mplex(IMuxedConn):
peer_id: ID
next_channel_id: int
streams: Dict[StreamID, MplexStream]
streams_lock: asyncio.Lock
new_stream_queue: "asyncio.Queue[IMuxedStream]"
event_shutting_down: asyncio.Event
event_closed: asyncio.Event
streams_lock: trio.Lock
streams_msg_channels: Dict[StreamID, "trio.MemorySendChannel[bytes]"]
new_stream_send_channel: "trio.MemorySendChannel[IMuxedStream]"
new_stream_receive_channel: "trio.MemoryReceiveChannel[IMuxedStream]"
_tasks: List["asyncio.Future[Any]"]
event_shutting_down: trio.Event
event_closed: trio.Event
event_started: trio.Event
def __init__(self, secured_conn: ISecureConn, peer_id: ID) -> None:
"""
@ -61,15 +65,16 @@ class Mplex(IMuxedConn):
# Mapping from stream ID -> buffer of messages for that stream
self.streams = {}
self.streams_lock = asyncio.Lock()
self.new_stream_queue = asyncio.Queue()
self.event_shutting_down = asyncio.Event()
self.event_closed = asyncio.Event()
self.streams_lock = trio.Lock()
self.streams_msg_channels = {}
channels = trio.open_memory_channel[IMuxedStream](0)
self.new_stream_send_channel, self.new_stream_receive_channel = channels
self.event_shutting_down = trio.Event()
self.event_closed = trio.Event()
self.event_started = trio.Event()
self._tasks = []
# Kick off reading
self._tasks.append(asyncio.ensure_future(self.handle_incoming()))
async def start(self) -> None:
await self.handle_incoming()
@property
def is_initiator(self) -> bool:
@ -85,6 +90,7 @@ class Mplex(IMuxedConn):
# Blocked until `close` is finally set.
await self.event_closed.wait()
@property
def is_closed(self) -> bool:
"""
check connection is fully closed.
@ -104,9 +110,13 @@ class Mplex(IMuxedConn):
return next_id
async def _initialize_stream(self, stream_id: StreamID, name: str) -> MplexStream:
stream = MplexStream(name, stream_id, self)
send_channel, receive_channel = trio.open_memory_channel[bytes](
MPLEX_MESSAGE_CHANNEL_SIZE
)
stream = MplexStream(name, stream_id, self, receive_channel)
async with self.streams_lock:
self.streams[stream_id] = stream
self.streams_msg_channels[stream_id] = send_channel
return stream
async def open_stream(self) -> IMuxedStream:
@ -123,27 +133,12 @@ class Mplex(IMuxedConn):
await self.send_message(HeaderTags.NewStream, name.encode(), stream_id)
return stream
async def _wait_until_shutting_down_or_closed(self, coro: Awaitable[Any]) -> Any:
task_coro = asyncio.ensure_future(coro)
task_wait_closed = asyncio.ensure_future(self.event_closed.wait())
task_wait_shutting_down = asyncio.ensure_future(self.event_shutting_down.wait())
done, pending = await asyncio.wait(
[task_coro, task_wait_closed, task_wait_shutting_down],
return_when=asyncio.FIRST_COMPLETED,
)
for fut in pending:
fut.cancel()
if task_wait_closed in done:
raise MplexUnavailable("Mplex is closed")
if task_wait_shutting_down in done:
raise MplexUnavailable("Mplex is shutting down")
return task_coro.result()
async def accept_stream(self) -> IMuxedStream:
"""accepts a muxed stream opened by the other end."""
return await self._wait_until_shutting_down_or_closed(
self.new_stream_queue.get()
)
try:
return await self.new_stream_receive_channel.receive()
except trio.EndOfChannel:
raise MplexUnavailable
async def send_message(
self, flag: HeaderTags, data: Optional[bytes], stream_id: StreamID
@ -151,7 +146,7 @@ class Mplex(IMuxedConn):
"""
sends a message over the connection.
:param header: header to use
:param flag: header to use
:param data: data to send in the message
:param stream_id: stream the message is in
"""
@ -163,9 +158,7 @@ class Mplex(IMuxedConn):
_bytes = header + encode_varint_prefixed(data)
return await self._wait_until_shutting_down_or_closed(
self.write_to_stream(_bytes)
)
return await self.write_to_stream(_bytes)
async def write_to_stream(self, _bytes: bytes) -> int:
"""
@ -174,21 +167,25 @@ class Mplex(IMuxedConn):
:param _bytes: byte array to write
:return: length written
"""
await self.secured_conn.write(_bytes)
try:
await self.secured_conn.write(_bytes)
except RawConnError as e:
raise MplexUnavailable(
"failed to write message to the underlying connection"
) from e
return len(_bytes)
async def handle_incoming(self) -> None:
"""Read a message off of the secured connection and add it to the
corresponding message buffer."""
self.event_started.set()
while True:
try:
await self._handle_incoming_message()
except MplexUnavailable as e:
logger.debug("mplex unavailable while waiting for incoming: %s", e)
break
# Force context switch
await asyncio.sleep(0)
# If we enter here, it means this connection is shutting down.
# We should clean things up.
await self._cleanup()
@ -200,20 +197,19 @@ class Mplex(IMuxedConn):
:return: stream_id, flag, message contents
"""
# FIXME: No timeout is used in Go implementation.
try:
header = await decode_uvarint_from_stream(self.secured_conn)
message = await asyncio.wait_for(
read_varint_prefixed_bytes(self.secured_conn), timeout=5
)
except (ParseError, RawConnError, IncompleteReadError) as error:
raise MplexUnavailable(
"failed to read messages correctly from the underlying connection"
) from error
except asyncio.TimeoutError as error:
f"failed to read the header correctly from the underlying connection: {error}"
)
try:
message = await read_varint_prefixed_bytes(self.secured_conn)
except (ParseError, RawConnError, IncompleteReadError) as error:
raise MplexUnavailable(
"failed to read more message body within the timeout"
) from error
"failed to read the message body correctly from the underlying connection: "
f"{error}"
)
flag = header & 0x07
channel_id = header >> 3
@ -226,9 +222,7 @@ class Mplex(IMuxedConn):
:raise MplexUnavailable: `Mplex` encounters fatal error or is shutting down.
"""
channel_id, flag, message = await self._wait_until_shutting_down_or_closed(
self.read_message()
)
channel_id, flag, message = await self.read_message()
stream_id = StreamID(channel_id=channel_id, is_initiator=bool(flag & 1))
if flag == HeaderTags.NewStream.value:
@ -258,9 +252,10 @@ class Mplex(IMuxedConn):
f"received NewStream message for existing stream: {stream_id}"
)
mplex_stream = await self._initialize_stream(stream_id, message.decode())
await self._wait_until_shutting_down_or_closed(
self.new_stream_queue.put(mplex_stream)
)
try:
await self.new_stream_send_channel.send(mplex_stream)
except trio.ClosedResourceError:
raise MplexUnavailable
async def _handle_message(self, stream_id: StreamID, message: bytes) -> None:
async with self.streams_lock:
@ -270,13 +265,21 @@ class Mplex(IMuxedConn):
# TODO: Warn and emit logs about this.
return
stream = self.streams[stream_id]
send_channel = self.streams_msg_channels[stream_id]
async with stream.close_lock:
if stream.event_remote_closed.is_set():
# TODO: Warn "Received data from remote after stream was closed by them. (len = %d)" # noqa: E501
return
await self._wait_until_shutting_down_or_closed(
stream.incoming_data.put(message)
)
try:
send_channel.send_nowait(message)
except (trio.BrokenResourceError, trio.ClosedResourceError):
raise MplexUnavailable
except trio.WouldBlock:
# `send_channel` is full, reset this stream.
logger.warning(
"message channel of stream %s is full: stream is reset", stream_id
)
await stream.reset()
async def _handle_close(self, stream_id: StreamID) -> None:
async with self.streams_lock:
@ -284,6 +287,8 @@ class Mplex(IMuxedConn):
# Ignore unmatched messages for now.
return
stream = self.streams[stream_id]
send_channel = self.streams_msg_channels[stream_id]
await send_channel.aclose()
# NOTE: If remote is already closed, then return: Technically a bug
# on the other side. We should consider killing the connection.
async with stream.close_lock:
@ -305,27 +310,30 @@ class Mplex(IMuxedConn):
# This is *ok*. We forget the stream on reset.
return
stream = self.streams[stream_id]
send_channel = self.streams_msg_channels[stream_id]
await send_channel.aclose()
async with stream.close_lock:
if not stream.event_remote_closed.is_set():
stream.event_reset.set()
stream.event_remote_closed.set()
# If local is not closed, we should close it.
if not stream.event_local_closed.is_set():
stream.event_local_closed.set()
async with self.streams_lock:
self.streams.pop(stream_id, None)
self.streams_msg_channels.pop(stream_id, None)
async def _cleanup(self) -> None:
if not self.event_shutting_down.is_set():
self.event_shutting_down.set()
async with self.streams_lock:
for stream in self.streams.values():
for stream_id, stream in self.streams.items():
async with stream.close_lock:
if not stream.event_remote_closed.is_set():
stream.event_remote_closed.set()
stream.event_reset.set()
stream.event_local_closed.set()
self.streams = None
send_channel = self.streams_msg_channels[stream_id]
await send_channel.aclose()
self.event_closed.set()
await self.new_stream_send_channel.aclose()

View File

@ -1,7 +1,9 @@
import asyncio
from typing import TYPE_CHECKING
import trio
from libp2p.stream_muxer.abc import IMuxedStream
from libp2p.stream_muxer.exceptions import MuxedConnUnavailable
from .constants import HeaderTags
from .datastructures import StreamID
@ -22,18 +24,25 @@ class MplexStream(IMuxedStream):
read_deadline: int
write_deadline: int
close_lock: asyncio.Lock
# TODO: Add lock for read/write to avoid interleaving receiving messages?
close_lock: trio.Lock
# NOTE: `dataIn` is size of 8 in Go implementation.
incoming_data: "asyncio.Queue[bytes]"
incoming_data_channel: "trio.MemoryReceiveChannel[bytes]"
event_local_closed: asyncio.Event
event_remote_closed: asyncio.Event
event_reset: asyncio.Event
event_local_closed: trio.Event
event_remote_closed: trio.Event
event_reset: trio.Event
_buf: bytearray
def __init__(self, name: str, stream_id: StreamID, muxed_conn: "Mplex") -> None:
def __init__(
self,
name: str,
stream_id: StreamID,
muxed_conn: "Mplex",
incoming_data_channel: "trio.MemoryReceiveChannel[bytes]",
) -> None:
"""
create new MuxedStream in muxer.
@ -45,99 +54,82 @@ class MplexStream(IMuxedStream):
self.muxed_conn = muxed_conn
self.read_deadline = None
self.write_deadline = None
self.event_local_closed = asyncio.Event()
self.event_remote_closed = asyncio.Event()
self.event_reset = asyncio.Event()
self.close_lock = asyncio.Lock()
self.incoming_data = asyncio.Queue()
self.event_local_closed = trio.Event()
self.event_remote_closed = trio.Event()
self.event_reset = trio.Event()
self.close_lock = trio.Lock()
self.incoming_data_channel = incoming_data_channel
self._buf = bytearray()
@property
def is_initiator(self) -> bool:
return self.stream_id.is_initiator
async def _wait_for_data(self) -> None:
task_event_reset = asyncio.ensure_future(self.event_reset.wait())
task_incoming_data_get = asyncio.ensure_future(self.incoming_data.get())
task_event_remote_closed = asyncio.ensure_future(
self.event_remote_closed.wait()
)
done, pending = await asyncio.wait( # type: ignore
[ # type: ignore
task_event_reset,
task_incoming_data_get,
task_event_remote_closed,
],
return_when=asyncio.FIRST_COMPLETED,
)
for fut in pending:
fut.cancel()
if task_event_reset in done:
if self.event_reset.is_set():
raise MplexStreamReset()
else:
# However, it is abnormal that `Event.wait` is unblocked without any of the flag
# is set. The task is probably cancelled.
raise Exception(
"Should not enter here. "
f"It is probably because {task_event_remote_closed} is cancelled."
)
if task_incoming_data_get in done:
data = task_incoming_data_get.result()
self._buf.extend(data)
return
if task_event_remote_closed in done:
if self.event_remote_closed.is_set():
raise MplexStreamEOF()
else:
# However, it is abnormal that `Event.wait` is unblocked without any of the flag
# is set. The task is probably cancelled.
raise Exception(
"Should not enter here. "
f"It is probably because {task_event_remote_closed} is cancelled."
)
# TODO: Handle timeout when deadline is used.
async def _read_until_eof(self) -> bytes:
while True:
try:
await self._wait_for_data()
except MplexStreamEOF:
break
async for data in self.incoming_data_channel:
self._buf.extend(data)
payload = self._buf
self._buf = self._buf[len(payload) :]
return bytes(payload)
async def read(self, n: int = -1) -> bytes:
def _read_return_when_blocked(self) -> bytes:
buf = bytearray()
while True:
try:
data = self.incoming_data_channel.receive_nowait()
buf.extend(data)
except (trio.WouldBlock, trio.EndOfChannel):
break
return buf
async def read(self, n: int = None) -> bytes:
"""
Read up to n bytes. Read possibly returns fewer than `n` bytes, if
there are not enough bytes in the Mplex buffer. If `n == -1`, read
there are not enough bytes in the Mplex buffer. If `n is None`, read
until EOF.
:param n: number of bytes to read
:return: bytes actually read
"""
if n < 0 and n != -1:
if n is not None and n < 0:
raise ValueError(
f"the number of bytes to read `n` must be positive or -1 to indicate read until EOF"
f"the number of bytes to read `n` must be non-negative or "
"`None` to indicate read until EOF"
)
if self.event_reset.is_set():
raise MplexStreamReset()
if n == -1:
raise MplexStreamReset
if n is None:
return await self._read_until_eof()
if len(self._buf) == 0 and self.incoming_data.empty():
await self._wait_for_data()
# Now we are sure we have something to read.
# Try to put enough incoming data into `self._buf`.
while len(self._buf) < n:
if len(self._buf) == 0:
data: bytes
# Peek whether there is data available. If yes, we just read until there is no data,
# and then return.
try:
self._buf.extend(self.incoming_data.get_nowait())
except asyncio.QueueEmpty:
break
data = self.incoming_data_channel.receive_nowait()
self._buf.extend(data)
except trio.EndOfChannel:
raise MplexStreamEOF
except trio.WouldBlock:
# We know `receive` will be blocked here. Wait for data here with `receive` and
# catch all kinds of errors here.
try:
data = await self.incoming_data_channel.receive()
self._buf.extend(data)
except trio.EndOfChannel:
if self.event_reset.is_set():
raise MplexStreamReset
if self.event_remote_closed.is_set():
raise MplexStreamEOF
except trio.ClosedResourceError as error:
# Probably `incoming_data_channel` is closed in `reset` when we are waiting
# for `receive`.
if self.event_reset.is_set():
raise MplexStreamReset
raise Exception(
"`incoming_data_channel` is closed but stream is not reset. "
"This should never happen."
) from error
self._buf.extend(self._read_return_when_blocked())
payload = self._buf[:n]
self._buf = self._buf[len(payload) :]
return bytes(payload)
@ -198,14 +190,17 @@ class MplexStream(IMuxedStream):
if self.is_initiator
else HeaderTags.ResetReceiver
)
asyncio.ensure_future(
self.muxed_conn.send_message(flag, None, self.stream_id)
)
await asyncio.sleep(0)
# Try to send reset message to the other side. Ignore if there is anything wrong.
try:
await self.muxed_conn.send_message(flag, None, self.stream_id)
except MuxedConnUnavailable:
pass
self.event_local_closed.set()
self.event_remote_closed.set()
await self.incoming_data_channel.aclose()
async with self.muxed_conn.streams_lock:
if self.muxed_conn.streams is not None:
self.muxed_conn.streams.pop(self.stream_id, None)

View File

@ -7,7 +7,7 @@ from libp2p.pubsub import floodsub, gossipsub
# Just a arbitrary large number.
# It is used when calling `MplexStream.read(MAX_READ_LEN)`,
# to avoid `MplexStream.read()`, which blocking reads until EOF.
MAX_READ_LEN = 2 ** 32 - 1
MAX_READ_LEN = 65535
LISTEN_MADDR = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/0")

View File

@ -1,40 +1,55 @@
import asyncio
from typing import Any, AsyncIterator, Dict, Tuple, cast
from typing import Any, AsyncIterator, Dict, List, Sequence, Tuple, cast
# NOTE: import ``asynccontextmanager`` from ``contextlib`` when support for python 3.6 is dropped.
from async_exit_stack import AsyncExitStack
from async_generator import asynccontextmanager
from async_service import background_trio_service
import factory
from multiaddr import Multiaddr
import trio
from libp2p import generate_new_rsa_identity, generate_peer_id_from
from libp2p.crypto.keys import KeyPair
from libp2p.host.basic_host import BasicHost
from libp2p.host.host_interface import IHost
from libp2p.host.routed_host import RoutedHost
from libp2p.io.abc import ReadWriteCloser
from libp2p.network.connection.raw_connection import RawConnection
from libp2p.network.connection.raw_connection_interface import IRawConnection
from libp2p.network.connection.swarm_connection import SwarmConn
from libp2p.network.stream.net_stream_interface import INetStream
from libp2p.network.swarm import Swarm
from libp2p.peer.id import ID
from libp2p.peer.peerinfo import PeerInfo
from libp2p.peer.peerstore import PeerStore
from libp2p.pubsub.abc import IPubsubRouter
from libp2p.pubsub.floodsub import FloodSub
from libp2p.pubsub.gossipsub import GossipSub
from libp2p.pubsub.pubsub import Pubsub
from libp2p.routing.interfaces import IPeerRouting
from libp2p.security.base_transport import BaseSecureTransport
from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport
import libp2p.security.secio.transport as secio
from libp2p.stream_muxer.mplex.mplex import MPLEX_PROTOCOL_ID, Mplex
from libp2p.stream_muxer.mplex.mplex_stream import MplexStream
from libp2p.tools.constants import GOSSIPSUB_PARAMS
from libp2p.transport.tcp.tcp import TCP
from libp2p.transport.typing import TMuxerOptions
from libp2p.transport.upgrader import TransportUpgrader
from libp2p.typing import TProtocol
from .constants import (
FLOODSUB_PROTOCOL_ID,
GOSSIPSUB_PARAMS,
GOSSIPSUB_PROTOCOL_ID,
LISTEN_MADDR,
)
from .constants import FLOODSUB_PROTOCOL_ID, GOSSIPSUB_PROTOCOL_ID, LISTEN_MADDR
from .utils import connect, connect_swarm
class IDFactory(factory.Factory):
class Meta:
model = ID
peer_id_bytes = factory.LazyFunction(
lambda: generate_peer_id_from(generate_new_rsa_identity())
)
def initialize_peerstore_with_our_keypair(self_id: ID, key_pair: KeyPair) -> PeerStore:
peer_store = PeerStore()
peer_store.add_key_pair(self_id, key_pair)
@ -50,6 +65,29 @@ def security_transport_factory(
return {secio.ID: secio.Transport(key_pair)}
@asynccontextmanager
async def raw_conn_factory(
nursery: trio.Nursery
) -> AsyncIterator[Tuple[IRawConnection, IRawConnection]]:
conn_0 = None
conn_1 = None
event = trio.Event()
async def tcp_stream_handler(stream: ReadWriteCloser) -> None:
nonlocal conn_1
conn_1 = RawConnection(stream, initiator=False)
event.set()
await trio.sleep_forever()
tcp_transport = TCP()
listener = tcp_transport.create_listener(tcp_stream_handler)
await listener.listen(LISTEN_MADDR, nursery)
listening_maddr = listener.get_addrs()[0]
conn_0 = await tcp_transport.dial(listening_maddr)
await event.wait()
yield conn_0, conn_1
class SwarmFactory(factory.Factory):
class Meta:
model = Swarm
@ -71,9 +109,10 @@ class SwarmFactory(factory.Factory):
transport = factory.LazyFunction(TCP)
@classmethod
@asynccontextmanager
async def create_and_listen(
cls, is_secure: bool, key_pair: KeyPair = None, muxer_opt: TMuxerOptions = None
) -> Swarm:
) -> AsyncIterator[Swarm]:
# `factory.Factory.__init__` does *not* prepare a *default value* if we pass
# an argument explicitly with `None`. If an argument is `None`, we don't pass it to
# `factory.Factory.__init__`, in order to let the function initialize it.
@ -83,20 +122,23 @@ class SwarmFactory(factory.Factory):
if muxer_opt is not None:
optional_kwargs["muxer_opt"] = muxer_opt
swarm = cls(is_secure=is_secure, **optional_kwargs)
await swarm.listen(LISTEN_MADDR)
return swarm
async with background_trio_service(swarm):
await swarm.listen(LISTEN_MADDR)
yield swarm
@classmethod
@asynccontextmanager
async def create_batch_and_listen(
cls, is_secure: bool, number: int, muxer_opt: TMuxerOptions = None
) -> Tuple[Swarm, ...]:
# Ignore typing since we are removing asyncio soon
return await asyncio.gather( # type: ignore
*[
cls.create_and_listen(is_secure=is_secure, muxer_opt=muxer_opt)
) -> AsyncIterator[Tuple[Swarm, ...]]:
async with AsyncExitStack() as stack:
ctx_mgrs = [
await stack.enter_async_context(
cls.create_and_listen(is_secure=is_secure, muxer_opt=muxer_opt)
)
for _ in range(number)
]
)
yield tuple(ctx_mgrs)
class HostFactory(factory.Factory):
@ -107,22 +149,57 @@ class HostFactory(factory.Factory):
is_secure = False
key_pair = factory.LazyFunction(generate_new_rsa_identity)
network = factory.LazyAttribute(
lambda o: SwarmFactory(is_secure=o.is_secure, key_pair=o.key_pair)
)
network = factory.LazyAttribute(lambda o: SwarmFactory(is_secure=o.is_secure))
@classmethod
@asynccontextmanager
async def create_batch_and_listen(
cls, is_secure: bool, number: int
) -> Tuple[BasicHost, ...]:
key_pairs = [generate_new_rsa_identity() for _ in range(number)]
swarms = await asyncio.gather(
*[
SwarmFactory.create_and_listen(is_secure, key_pair)
for key_pair in key_pairs
]
)
return tuple(BasicHost(swarm) for swarm in swarms)
) -> AsyncIterator[Tuple[BasicHost, ...]]:
async with SwarmFactory.create_batch_and_listen(is_secure, number) as swarms:
hosts = tuple(BasicHost(swarm) for swarm in swarms)
yield hosts
class DummyRouter(IPeerRouting):
_routing_table: Dict[ID, PeerInfo]
def __init__(self) -> None:
self._routing_table = dict()
def _add_peer(self, peer_id: ID, addrs: List[Multiaddr]) -> None:
self._routing_table[peer_id] = PeerInfo(peer_id, addrs)
async def find_peer(self, peer_id: ID) -> PeerInfo:
await trio.hazmat.checkpoint()
return self._routing_table.get(peer_id, None)
class RoutedHostFactory(factory.Factory):
class Meta:
model = RoutedHost
class Params:
is_secure = False
network = factory.LazyAttribute(
lambda o: HostFactory(is_secure=o.is_secure).get_network()
)
router = factory.LazyFunction(DummyRouter)
@classmethod
@asynccontextmanager
async def create_batch_and_listen(
cls, is_secure: bool, number: int
) -> AsyncIterator[Tuple[RoutedHost, ...]]:
routing_table = DummyRouter()
async with HostFactory.create_batch_and_listen(is_secure, number) as hosts:
for host in hosts:
routing_table._add_peer(host.get_id(), host.get_addrs())
routed_hosts = tuple(
RoutedHost(host.get_network(), routing_table) for host in hosts
)
yield routed_hosts
class FloodsubFactory(factory.Factory):
@ -153,89 +230,192 @@ class PubsubFactory(factory.Factory):
host = factory.SubFactory(HostFactory)
router = None
my_id = factory.LazyAttribute(lambda obj: obj.host.get_id())
cache_size = None
strict_signing = False
@classmethod
@asynccontextmanager
async def create_and_start(
cls, host: IHost, router: IPubsubRouter, cache_size: int, strict_signing: bool
) -> AsyncIterator[Pubsub]:
pubsub = cls(
host=host,
router=router,
cache_size=cache_size,
strict_signing=strict_signing,
)
async with background_trio_service(pubsub):
await pubsub.wait_until_ready()
yield pubsub
@classmethod
@asynccontextmanager
async def _create_batch_with_router(
cls,
number: int,
routers: Sequence[IPubsubRouter],
is_secure: bool = False,
cache_size: int = None,
strict_signing: bool = False,
) -> AsyncIterator[Tuple[Pubsub, ...]]:
async with HostFactory.create_batch_and_listen(is_secure, number) as hosts:
# Pubsubs should exit before hosts
async with AsyncExitStack() as stack:
pubsubs = [
await stack.enter_async_context(
cls.create_and_start(host, router, cache_size, strict_signing)
)
for host, router in zip(hosts, routers)
]
yield tuple(pubsubs)
@classmethod
@asynccontextmanager
async def create_batch_with_floodsub(
cls,
number: int,
is_secure: bool = False,
cache_size: int = None,
strict_signing: bool = False,
protocols: Sequence[TProtocol] = None,
) -> AsyncIterator[Tuple[Pubsub, ...]]:
if protocols is not None:
floodsubs = FloodsubFactory.create_batch(number, protocols=list(protocols))
else:
floodsubs = FloodsubFactory.create_batch(number)
async with cls._create_batch_with_router(
number, floodsubs, is_secure, cache_size, strict_signing
) as pubsubs:
yield pubsubs
@classmethod
@asynccontextmanager
async def create_batch_with_gossipsub(
cls,
number: int,
*,
is_secure: bool = False,
cache_size: int = None,
strict_signing: bool = False,
protocols: Sequence[TProtocol] = None,
degree: int = GOSSIPSUB_PARAMS.degree,
degree_low: int = GOSSIPSUB_PARAMS.degree_low,
degree_high: int = GOSSIPSUB_PARAMS.degree_high,
time_to_live: int = GOSSIPSUB_PARAMS.time_to_live,
gossip_window: int = GOSSIPSUB_PARAMS.gossip_window,
gossip_history: int = GOSSIPSUB_PARAMS.gossip_history,
heartbeat_interval: float = GOSSIPSUB_PARAMS.heartbeat_interval,
heartbeat_initial_delay: float = GOSSIPSUB_PARAMS.heartbeat_initial_delay,
) -> AsyncIterator[Tuple[Pubsub, ...]]:
if protocols is not None:
gossipsubs = GossipsubFactory.create_batch(
number,
protocols=protocols,
degree=degree,
degree_low=degree_low,
degree_high=degree_high,
time_to_live=time_to_live,
gossip_window=gossip_window,
heartbeat_interval=heartbeat_interval,
)
else:
gossipsubs = GossipsubFactory.create_batch(
number,
degree=degree,
degree_low=degree_low,
degree_high=degree_high,
time_to_live=time_to_live,
gossip_window=gossip_window,
heartbeat_interval=heartbeat_interval,
)
async with cls._create_batch_with_router(
number, gossipsubs, is_secure, cache_size, strict_signing
) as pubsubs:
async with AsyncExitStack() as stack:
for router in gossipsubs:
await stack.enter_async_context(background_trio_service(router))
yield pubsubs
@asynccontextmanager
async def swarm_pair_factory(
is_secure: bool, muxer_opt: TMuxerOptions = None
) -> Tuple[Swarm, Swarm]:
swarms = await SwarmFactory.create_batch_and_listen(
) -> AsyncIterator[Tuple[Swarm, Swarm]]:
async with SwarmFactory.create_batch_and_listen(
is_secure, 2, muxer_opt=muxer_opt
)
await connect_swarm(swarms[0], swarms[1])
return swarms[0], swarms[1]
) as swarms:
await connect_swarm(swarms[0], swarms[1])
yield swarms[0], swarms[1]
async def host_pair_factory(is_secure: bool) -> Tuple[BasicHost, BasicHost]:
hosts = await HostFactory.create_batch_and_listen(is_secure, 2)
await connect(hosts[0], hosts[1])
return hosts[0], hosts[1]
@asynccontextmanager # type: ignore
async def pair_of_connected_hosts(
is_secure: bool = True
@asynccontextmanager
async def host_pair_factory(
is_secure: bool
) -> AsyncIterator[Tuple[BasicHost, BasicHost]]:
a, b = await host_pair_factory(is_secure)
yield a, b
close_tasks = (a.close(), b.close())
await asyncio.gather(*close_tasks)
async with HostFactory.create_batch_and_listen(is_secure, 2) as hosts:
await connect(hosts[0], hosts[1])
yield hosts[0], hosts[1]
@asynccontextmanager
async def swarm_conn_pair_factory(
is_secure: bool, muxer_opt: TMuxerOptions = None
) -> Tuple[SwarmConn, Swarm, SwarmConn, Swarm]:
swarms = await swarm_pair_factory(is_secure)
conn_0 = swarms[0].connections[swarms[1].get_peer_id()]
conn_1 = swarms[1].connections[swarms[0].get_peer_id()]
return cast(SwarmConn, conn_0), swarms[0], cast(SwarmConn, conn_1), swarms[1]
) -> AsyncIterator[Tuple[SwarmConn, SwarmConn]]:
async with swarm_pair_factory(is_secure) as swarms:
conn_0 = swarms[0].connections[swarms[1].get_peer_id()]
conn_1 = swarms[1].connections[swarms[0].get_peer_id()]
yield cast(SwarmConn, conn_0), cast(SwarmConn, conn_1)
async def mplex_conn_pair_factory(is_secure: bool) -> Tuple[Mplex, Swarm, Mplex, Swarm]:
@asynccontextmanager
async def mplex_conn_pair_factory(
is_secure: bool
) -> AsyncIterator[Tuple[Mplex, Mplex]]:
muxer_opt = {MPLEX_PROTOCOL_ID: Mplex}
conn_0, swarm_0, conn_1, swarm_1 = await swarm_conn_pair_factory(
is_secure, muxer_opt=muxer_opt
)
return (
cast(Mplex, conn_0.muxed_conn),
swarm_0,
cast(Mplex, conn_1.muxed_conn),
swarm_1,
)
async with swarm_conn_pair_factory(is_secure, muxer_opt=muxer_opt) as swarm_pair:
yield (
cast(Mplex, swarm_pair[0].muxed_conn),
cast(Mplex, swarm_pair[1].muxed_conn),
)
@asynccontextmanager
async def mplex_stream_pair_factory(
is_secure: bool
) -> Tuple[MplexStream, Swarm, MplexStream, Swarm]:
mplex_conn_0, swarm_0, mplex_conn_1, swarm_1 = await mplex_conn_pair_factory(
is_secure
)
stream_0 = await mplex_conn_0.open_stream()
await asyncio.sleep(0.01)
stream_1: MplexStream
async with mplex_conn_1.streams_lock:
if len(mplex_conn_1.streams) != 1:
raise Exception("Mplex should not have any stream upon connection")
stream_1 = tuple(mplex_conn_1.streams.values())[0]
return cast(MplexStream, stream_0), swarm_0, stream_1, swarm_1
) -> AsyncIterator[Tuple[MplexStream, MplexStream]]:
async with mplex_conn_pair_factory(is_secure) as mplex_conn_pair_info:
mplex_conn_0, mplex_conn_1 = mplex_conn_pair_info
stream_0 = cast(MplexStream, await mplex_conn_0.open_stream())
await trio.sleep(0.01)
stream_1: MplexStream
async with mplex_conn_1.streams_lock:
if len(mplex_conn_1.streams) != 1:
raise Exception("Mplex should not have any other stream")
stream_1 = tuple(mplex_conn_1.streams.values())[0]
yield stream_0, stream_1
@asynccontextmanager
async def net_stream_pair_factory(
is_secure: bool
) -> Tuple[INetStream, BasicHost, INetStream, BasicHost]:
) -> AsyncIterator[Tuple[INetStream, INetStream]]:
protocol_id = TProtocol("/example/id/1")
stream_1: INetStream
# Just a proxy, we only care about the stream
def handler(stream: INetStream) -> None:
# Just a proxy, we only care about the stream.
# Add a barrier to avoid stream being removed.
event_handler_finished = trio.Event()
async def handler(stream: INetStream) -> None:
nonlocal stream_1
stream_1 = stream
await event_handler_finished.wait()
host_0, host_1 = await host_pair_factory(is_secure)
host_1.set_stream_handler(protocol_id, handler)
async with host_pair_factory(is_secure) as hosts:
hosts[1].set_stream_handler(protocol_id, handler)
stream_0 = await host_0.new_stream(host_1.get_id(), [protocol_id])
return stream_0, host_0, stream_1, host_1
stream_0 = await hosts[0].new_stream(hosts[1].get_id(), [protocol_id])
yield stream_0, stream_1
event_handler_finished.set()

View File

@ -1,2 +1 @@
LOCALHOST_IP = "127.0.0.1"
PEXPECT_NEW_LINE = "\r\n"

View File

@ -1,52 +1,22 @@
import asyncio
import time
from typing import Any, Awaitable, Callable, List
from typing import AsyncIterator
from async_generator import asynccontextmanager
import multiaddr
from multiaddr import Multiaddr
from p2pclient import Client
import pytest
import trio
from libp2p.peer.id import ID
from libp2p.peer.peerinfo import PeerInfo, info_from_p2p_addr
from .constants import LOCALHOST_IP
from .envs import GO_BIN_PATH
from .process import BaseInteractiveProcess
P2PD_PATH = GO_BIN_PATH / "p2pd"
TIMEOUT_DURATION = 30
async def try_until_success(
coro_func: Callable[[], Awaitable[Any]], timeout: int = TIMEOUT_DURATION
) -> None:
"""
Keep running ``coro_func`` until either it succeed or time is up.
All arguments of ``coro_func`` should be filled, i.e. it should be
called without arguments.
"""
t_start = time.monotonic()
while True:
result = await coro_func()
if result:
break
if (time.monotonic() - t_start) >= timeout:
# timeout
pytest.fail(f"{coro_func} is still failing after `{timeout}` seconds")
await asyncio.sleep(0.01)
class P2PDProcess:
proc: asyncio.subprocess.Process
cmd: str = str(P2PD_PATH)
args: List[Any]
is_proc_running: bool
_tasks: List["asyncio.Future[Any]"]
class P2PDProcess(BaseInteractiveProcess):
def __init__(
self,
control_maddr: Multiaddr,
@ -75,74 +45,21 @@ class P2PDProcess:
# - gossipsubHeartbeatInterval: GossipSubHeartbeatInitialDelay = 100 * time.Millisecond # noqa: E501
# - gossipsubHeartbeatInitialDelay: GossipSubHeartbeatInterval = 1 * time.Second
# Referece: https://github.com/libp2p/go-libp2p-daemon/blob/b95e77dbfcd186ccf817f51e95f73f9fd5982600/p2pd/main.go#L348-L353 # noqa: E501
self.proc = None
self.cmd = str(P2PD_PATH)
self.args = args
self.is_proc_running = False
self._tasks = []
async def wait_until_ready(self) -> None:
lines_head_pattern = (b"Control socket:", b"Peer ID:", b"Peer Addrs:")
lines_head_occurred = {line: False for line in lines_head_pattern}
async def read_from_daemon_and_check() -> bool:
line = await self.proc.stdout.readline()
for head_pattern in lines_head_occurred:
if line.startswith(head_pattern):
lines_head_occurred[head_pattern] = True
return all([value for value in lines_head_occurred.values()])
await try_until_success(read_from_daemon_and_check)
# Sleep a little bit to ensure the listener is up after logs are emitted.
await asyncio.sleep(0.01)
async def start_printing_logs(self) -> None:
async def _print_from_stream(
src_name: str, reader: asyncio.StreamReader
) -> None:
while True:
line = await reader.readline()
if line != b"":
print(f"{src_name}\t: {line.rstrip().decode()}")
await asyncio.sleep(0.01)
self._tasks.append(
asyncio.ensure_future(_print_from_stream("out", self.proc.stdout))
)
self._tasks.append(
asyncio.ensure_future(_print_from_stream("err", self.proc.stderr))
)
await asyncio.sleep(0)
async def start(self) -> None:
if self.is_proc_running:
return
self.proc = await asyncio.subprocess.create_subprocess_exec(
self.cmd,
*self.args,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
bufsize=0,
)
self.is_proc_running = True
await self.wait_until_ready()
await self.start_printing_logs()
async def close(self) -> None:
if self.is_proc_running:
self.proc.terminate()
await self.proc.wait()
self.is_proc_running = False
for task in self._tasks:
task.cancel()
self.patterns = (b"Control socket:", b"Peer ID:", b"Peer Addrs:")
self.bytes_read = bytearray()
self.event_ready = trio.Event()
class Daemon:
p2pd_proc: P2PDProcess
p2pd_proc: BaseInteractiveProcess
control: Client
peer_info: PeerInfo
def __init__(
self, p2pd_proc: P2PDProcess, control: Client, peer_info: PeerInfo
self, p2pd_proc: BaseInteractiveProcess, control: Client, peer_info: PeerInfo
) -> None:
self.p2pd_proc = p2pd_proc
self.control = control
@ -164,6 +81,7 @@ class Daemon:
await self.control.close()
@asynccontextmanager
async def make_p2pd(
daemon_control_port: int,
client_callback_port: int,
@ -172,7 +90,7 @@ async def make_p2pd(
is_gossipsub: bool = True,
is_pubsub_signing: bool = False,
is_pubsub_signing_strict: bool = False,
) -> Daemon:
) -> AsyncIterator[Daemon]:
control_maddr = Multiaddr(f"/ip4/{LOCALHOST_IP}/tcp/{daemon_control_port}")
p2pd_proc = P2PDProcess(
control_maddr,
@ -185,21 +103,22 @@ async def make_p2pd(
await p2pd_proc.start()
client_callback_maddr = Multiaddr(f"/ip4/{LOCALHOST_IP}/tcp/{client_callback_port}")
p2pc = Client(control_maddr, client_callback_maddr)
await p2pc.listen()
peer_id, maddrs = await p2pc.identify()
listen_maddr: Multiaddr = None
for maddr in maddrs:
try:
ip = maddr.value_for_protocol(multiaddr.protocols.P_IP4)
# NOTE: Check if this `maddr` uses `tcp`.
maddr.value_for_protocol(multiaddr.protocols.P_TCP)
except multiaddr.exceptions.ProtocolLookupError:
continue
if ip == LOCALHOST_IP:
listen_maddr = maddr
break
assert listen_maddr is not None, "no loopback maddr is found"
peer_info = info_from_p2p_addr(
listen_maddr.encapsulate(Multiaddr(f"/p2p/{peer_id.to_string()}"))
)
return Daemon(p2pd_proc, p2pc, peer_info)
async with p2pc.listen():
peer_id, maddrs = await p2pc.identify()
listen_maddr: Multiaddr = None
for maddr in maddrs:
try:
ip = maddr.value_for_protocol(multiaddr.protocols.P_IP4)
# NOTE: Check if this `maddr` uses `tcp`.
maddr.value_for_protocol(multiaddr.protocols.P_TCP)
except multiaddr.exceptions.ProtocolLookupError:
continue
if ip == LOCALHOST_IP:
listen_maddr = maddr
break
assert listen_maddr is not None, "no loopback maddr is found"
peer_info = info_from_p2p_addr(
listen_maddr.encapsulate(Multiaddr(f"/p2p/{peer_id.to_string()}"))
)
yield Daemon(p2pd_proc, p2pc, peer_info)

View File

@ -0,0 +1,66 @@
from abc import ABC, abstractmethod
import subprocess
from typing import Iterable, List
import trio
TIMEOUT_DURATION = 30
class AbstractInterativeProcess(ABC):
@abstractmethod
async def start(self) -> None:
...
@abstractmethod
async def close(self) -> None:
...
class BaseInteractiveProcess(AbstractInterativeProcess):
proc: trio.Process = None
cmd: str
args: List[str]
bytes_read: bytearray
patterns: Iterable[bytes] = None
event_ready: trio.Event
async def wait_until_ready(self) -> None:
patterns_occurred = {pat: False for pat in self.patterns}
async def read_from_daemon_and_check() -> None:
async for data in self.proc.stdout:
# TODO: It takes O(n^2), which is quite bad.
# But it should succeed in a few seconds.
self.bytes_read.extend(data)
for pat, occurred in patterns_occurred.items():
if occurred:
continue
if pat in self.bytes_read:
patterns_occurred[pat] = True
if all([value for value in patterns_occurred.values()]):
return
with trio.fail_after(TIMEOUT_DURATION):
await read_from_daemon_and_check()
self.event_ready.set()
# Sleep a little bit to ensure the listener is up after logs are emitted.
await trio.sleep(0.01)
async def start(self) -> None:
if self.proc is not None:
return
# NOTE: Ignore type checks here since mypy complains about bufsize=0
self.proc = await trio.open_process( # type: ignore
[self.cmd] + self.args,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT, # Redirect stderr to stdout, which makes parsing easier
bufsize=0,
)
await self.wait_until_ready()
async def close(self) -> None:
if self.proc is None:
return
self.proc.terminate()
await self.proc.wait()

View File

@ -1,7 +1,7 @@
import asyncio
from typing import Union
from multiaddr import Multiaddr
import trio
from libp2p.host.host_interface import IHost
from libp2p.peer.id import ID
@ -50,7 +50,7 @@ async def connect(a: TDaemonOrHost, b: TDaemonOrHost) -> None:
else: # isinstance(b, IHost)
await a.connect(b_peer_info)
# Allow additional sleep for both side to establish the connection.
await asyncio.sleep(0.1)
await trio.sleep(0.1)
a_peer_info = _get_peer_info(a)

View File

@ -1,12 +1,12 @@
import asyncio
from typing import Dict
import uuid
from typing import AsyncIterator, Dict, Tuple
from async_exit_stack import AsyncExitStack
from async_generator import asynccontextmanager
from async_service import Service, background_trio_service
from libp2p.host.host_interface import IHost
from libp2p.pubsub.floodsub import FloodSub
from libp2p.pubsub.pubsub import Pubsub
from libp2p.tools.constants import LISTEN_MADDR
from libp2p.tools.factories import FloodsubFactory, PubsubFactory
from libp2p.tools.factories import PubsubFactory
CRYPTO_TOPIC = "ethereum"
@ -18,7 +18,7 @@ CRYPTO_TOPIC = "ethereum"
# Determine message type by looking at first item before first comma
class DummyAccountNode:
class DummyAccountNode(Service):
"""
Node which has an internal balance mapping, meant to serve as a dummy
crypto blockchain.
@ -27,19 +27,24 @@ class DummyAccountNode:
crypto each user in the mappings holds
"""
libp2p_node: IHost
pubsub: Pubsub
floodsub: FloodSub
def __init__(self, libp2p_node: IHost, pubsub: Pubsub, floodsub: FloodSub):
self.libp2p_node = libp2p_node
def __init__(self, pubsub: Pubsub) -> None:
self.pubsub = pubsub
self.floodsub = floodsub
self.balances: Dict[str, int] = {}
self.node_id = str(uuid.uuid1())
@property
def host(self) -> IHost:
return self.pubsub.host
async def run(self) -> None:
self.subscription = await self.pubsub.subscribe(CRYPTO_TOPIC)
self.manager.run_daemon_task(self.handle_incoming_msgs)
await self.manager.wait_finished()
@classmethod
async def create(cls) -> "DummyAccountNode":
@asynccontextmanager
async def create(cls, number: int) -> AsyncIterator[Tuple["DummyAccountNode", ...]]:
"""
Create a new DummyAccountNode and attach a libp2p node, a floodsub, and
a pubsub instance to this new node.
@ -47,15 +52,17 @@ class DummyAccountNode:
We use create as this serves as a factory function and allows us
to use async await, unlike the init function
"""
pubsub = PubsubFactory(router=FloodsubFactory())
await pubsub.host.get_network().listen(LISTEN_MADDR)
return cls(libp2p_node=pubsub.host, pubsub=pubsub, floodsub=pubsub.router)
async with PubsubFactory.create_batch_with_floodsub(number) as pubsubs:
async with AsyncExitStack() as stack:
dummy_acount_nodes = tuple(cls(pubsub) for pubsub in pubsubs)
for node in dummy_acount_nodes:
await stack.enter_async_context(background_trio_service(node))
yield dummy_acount_nodes
async def handle_incoming_msgs(self) -> None:
"""Handle all incoming messages on the CRYPTO_TOPIC from peers."""
while True:
incoming = await self.q.get()
incoming = await self.subscription.get()
msg_comps = incoming.data.decode("utf-8").split(",")
if msg_comps[0] == "send":
@ -63,13 +70,6 @@ class DummyAccountNode:
elif msg_comps[0] == "set":
self.handle_set_crypto(msg_comps[1], int(msg_comps[2]))
async def setup_crypto_networking(self) -> None:
"""Subscribe to CRYPTO_TOPIC and perform call to function that handles
all incoming messages on said topic."""
self.q = await self.pubsub.subscribe(CRYPTO_TOPIC)
asyncio.ensure_future(self.handle_incoming_msgs())
async def publish_send_crypto(
self, source_user: str, dest_user: str, amount: int
) -> None:

View File

@ -1,12 +1,10 @@
# type: ignore
# To add typing to this module, it's better to do it after refactoring test cases into classes
import asyncio
import pytest
import trio
from libp2p.tools.constants import FLOODSUB_PROTOCOL_ID, LISTEN_MADDR
from libp2p.tools.factories import PubsubFactory
from libp2p.tools.constants import FLOODSUB_PROTOCOL_ID
from libp2p.tools.utils import connect
SUPPORTED_PROTOCOLS = [FLOODSUB_PROTOCOL_ID]
@ -15,6 +13,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [
{
"name": "simple_two_nodes",
"supported_protocols": SUPPORTED_PROTOCOLS,
"nodes": ["A", "B"],
"adj_list": {"A": ["B"]},
"topic_map": {"topic1": ["B"]},
"messages": [{"topics": ["topic1"], "data": b"foo", "node_id": "A"}],
@ -22,6 +21,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [
{
"name": "three_nodes_two_topics",
"supported_protocols": SUPPORTED_PROTOCOLS,
"nodes": ["A", "B", "C"],
"adj_list": {"A": ["B"], "B": ["C"]},
"topic_map": {"topic1": ["B", "C"], "topic2": ["B", "C"]},
"messages": [
@ -32,6 +32,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [
{
"name": "two_nodes_one_topic_single_subscriber_is_sender",
"supported_protocols": SUPPORTED_PROTOCOLS,
"nodes": ["A", "B"],
"adj_list": {"A": ["B"]},
"topic_map": {"topic1": ["B"]},
"messages": [{"topics": ["topic1"], "data": b"Alex is tall", "node_id": "B"}],
@ -39,6 +40,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [
{
"name": "two_nodes_one_topic_two_msgs",
"supported_protocols": SUPPORTED_PROTOCOLS,
"nodes": ["A", "B"],
"adj_list": {"A": ["B"]},
"topic_map": {"topic1": ["B"]},
"messages": [
@ -49,6 +51,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [
{
"name": "seven_nodes_tree_one_topics",
"supported_protocols": SUPPORTED_PROTOCOLS,
"nodes": ["1", "2", "3", "4", "5", "6", "7"],
"adj_list": {"1": ["2", "3"], "2": ["4", "5"], "3": ["6", "7"]},
"topic_map": {"astrophysics": ["2", "3", "4", "5", "6", "7"]},
"messages": [{"topics": ["astrophysics"], "data": b"e=mc^2", "node_id": "1"}],
@ -56,6 +59,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [
{
"name": "seven_nodes_tree_three_topics",
"supported_protocols": SUPPORTED_PROTOCOLS,
"nodes": ["1", "2", "3", "4", "5", "6", "7"],
"adj_list": {"1": ["2", "3"], "2": ["4", "5"], "3": ["6", "7"]},
"topic_map": {
"astrophysics": ["2", "3", "4", "5", "6", "7"],
@ -71,6 +75,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [
{
"name": "seven_nodes_tree_three_topics_diff_origin",
"supported_protocols": SUPPORTED_PROTOCOLS,
"nodes": ["1", "2", "3", "4", "5", "6", "7"],
"adj_list": {"1": ["2", "3"], "2": ["4", "5"], "3": ["6", "7"]},
"topic_map": {
"astrophysics": ["1", "2", "3", "4", "5", "6", "7"],
@ -86,6 +91,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [
{
"name": "three_nodes_clique_two_topic_diff_origin",
"supported_protocols": SUPPORTED_PROTOCOLS,
"nodes": ["1", "2", "3"],
"adj_list": {"1": ["2", "3"], "2": ["3"]},
"topic_map": {"astrophysics": ["1", "2", "3"], "school": ["1", "2", "3"]},
"messages": [
@ -97,6 +103,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [
{
"name": "four_nodes_clique_two_topic_diff_origin_many_msgs",
"supported_protocols": SUPPORTED_PROTOCOLS,
"nodes": ["1", "2", "3", "4"],
"adj_list": {
"1": ["2", "3", "4"],
"2": ["1", "3", "4"],
@ -120,6 +127,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [
{
"name": "five_nodes_ring_two_topic_diff_origin_many_msgs",
"supported_protocols": SUPPORTED_PROTOCOLS,
"nodes": ["1", "2", "3", "4", "5"],
"adj_list": {"1": ["2"], "2": ["3"], "3": ["4"], "4": ["5"], "5": ["1"]},
"topic_map": {
"astrophysics": ["1", "2", "3", "4", "5"],
@ -143,15 +151,7 @@ floodsub_protocol_pytest_params = [
]
def _collect_node_ids(adj_list):
node_ids = set()
for node, neighbors in adj_list.items():
node_ids.add(node)
node_ids.update(set(neighbors))
return node_ids
async def perform_test_from_obj(obj, router_factory) -> None:
async def perform_test_from_obj(obj, pubsub_factory) -> None:
"""
Perform pubsub tests from a test object, which is composed as follows:
@ -185,68 +185,75 @@ async def perform_test_from_obj(obj, router_factory) -> None:
# Step 1) Create graph
adj_list = obj["adj_list"]
node_list = obj["nodes"]
node_map = {}
pubsub_map = {}
async def add_node(node_id_str: str):
pubsub_router = router_factory(protocols=obj["supported_protocols"])
pubsub = PubsubFactory(router=pubsub_router)
await pubsub.host.get_network().listen(LISTEN_MADDR)
node_map[node_id_str] = pubsub.host
pubsub_map[node_id_str] = pubsub
async with pubsub_factory(
number=len(node_list), protocols=obj["supported_protocols"]
) as pubsubs:
for node_id_str, pubsub in zip(node_list, pubsubs):
node_map[node_id_str] = pubsub.host
pubsub_map[node_id_str] = pubsub
all_node_ids = _collect_node_ids(adj_list)
# Connect nodes and wait at least for 2 seconds
async with trio.open_nursery() as nursery:
for start_node_id in adj_list:
# For each neighbor of start_node, create if does not yet exist,
# then connect start_node to neighbor
for neighbor_id in adj_list[start_node_id]:
nursery.start_soon(
connect, node_map[start_node_id], node_map[neighbor_id]
)
nursery.start_soon(trio.sleep, 2)
for node in all_node_ids:
await add_node(node)
# Step 2) Subscribe to topics
queues_map = {}
topic_map = obj["topic_map"]
for node, neighbors in adj_list.items():
for neighbor_id in neighbors:
await connect(node_map[node], node_map[neighbor_id])
# NOTE: the test using this routine will fail w/o these sleeps...
await asyncio.sleep(1)
# Step 2) Subscribe to topics
queues_map = {}
topic_map = obj["topic_map"]
for topic, node_ids in topic_map.items():
for node_id in node_ids:
queue = await pubsub_map[node_id].subscribe(topic)
async def subscribe_node(node_id, topic):
if node_id not in queues_map:
queues_map[node_id] = {}
# Store queue in topic-queue map for node
queues_map[node_id][topic] = queue
# Avoid repeated works
if topic in queues_map[node_id]:
# Checkpoint
await trio.hazmat.checkpoint()
return
sub = await pubsub_map[node_id].subscribe(topic)
queues_map[node_id][topic] = sub
# NOTE: the test using this routine will fail w/o these sleeps...
await asyncio.sleep(1)
async with trio.open_nursery() as nursery:
for topic, node_ids in topic_map.items():
for node_id in node_ids:
nursery.start_soon(subscribe_node, node_id, topic)
nursery.start_soon(trio.sleep, 2)
# Step 3) Publish messages
topics_in_msgs_ordered = []
messages = obj["messages"]
# Step 3) Publish messages
topics_in_msgs_ordered = []
messages = obj["messages"]
for msg in messages:
topics = msg["topics"]
data = msg["data"]
node_id = msg["node_id"]
for msg in messages:
topics = msg["topics"]
data = msg["data"]
node_id = msg["node_id"]
# Publish message
# TODO: Should be single RPC package with several topics
for topic in topics:
await pubsub_map[node_id].publish(topic, data)
# Publish message
# TODO: Should be single RPC package with several topics
for topic in topics:
await pubsub_map[node_id].publish(topic, data)
# For each topic in topics, add (topic, node_id, data) tuple to ordered test list
topics_in_msgs_ordered.append((topic, node_id, data))
for topic in topics:
topics_in_msgs_ordered.append((topic, node_id, data))
# Allow time for publishing before continuing
await trio.sleep(1)
# Step 4) Check that all messages were received correctly.
for topic, origin_node_id, data in topics_in_msgs_ordered:
# Look at each node in each topic
for node_id in topic_map[topic]:
# Get message from subscription queue
queue = queues_map[node_id][topic]
msg = await queue.get()
assert data == msg.data
# Check the message origin
assert node_map[origin_node_id].get_id().to_bytes() == msg.from_id
# Success, terminate pending tasks.
# Step 4) Check that all messages were received correctly.
for topic, origin_node_id, data in topics_in_msgs_ordered:
# Look at each node in each topic
for node_id in topic_map[topic]:
# Get message from subscription queue
msg = await queues_map[node_id][topic].get()
assert data == msg.data
# Check the message origin
assert node_map[origin_node_id].get_id().to_bytes() == msg.from_id

View File

@ -1,17 +1,10 @@
from typing import Dict, Sequence, Tuple, cast
from typing import Awaitable, Callable
import multiaddr
from libp2p import new_node
from libp2p.host.basic_host import BasicHost
from libp2p.host.host_interface import IHost
from libp2p.host.routed_host import RoutedHost
from libp2p.network.stream.exceptions import StreamError
from libp2p.network.stream.net_stream_interface import INetStream
from libp2p.network.swarm import Swarm
from libp2p.peer.id import ID
from libp2p.peer.peerinfo import PeerInfo, info_from_p2p_addr
from libp2p.routing.interfaces import IPeerRouting
from libp2p.typing import StreamHandlerFn, TProtocol
from libp2p.peer.peerinfo import info_from_p2p_addr
from .constants import MAX_READ_LEN
@ -36,63 +29,20 @@ async def connect(node1: IHost, node2: IHost) -> None:
await node1.connect(info)
async def set_up_nodes_by_transport_opt(
transport_opt_list: Sequence[Sequence[str]]
) -> Tuple[BasicHost, ...]:
nodes_list = []
for transport_opt in transport_opt_list:
node = await new_node(transport_opt=transport_opt)
await node.get_network().listen(multiaddr.Multiaddr(transport_opt[0]))
nodes_list.append(node)
return tuple(nodes_list)
def create_echo_stream_handler(
ack_prefix: str
) -> Callable[[INetStream], Awaitable[None]]:
async def echo_stream_handler(stream: INetStream) -> None:
while True:
try:
read_string = (await stream.read(MAX_READ_LEN)).decode()
except StreamError:
break
resp = ack_prefix + read_string
try:
await stream.write(resp.encode())
except StreamError:
break
async def echo_stream_handler(stream: INetStream) -> None:
while True:
read_string = (await stream.read(MAX_READ_LEN)).decode()
resp = f"ack:{read_string}"
await stream.write(resp.encode())
async def perform_two_host_set_up(
handler: StreamHandlerFn = echo_stream_handler
) -> Tuple[BasicHost, BasicHost]:
transport_opt_list = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]]
(node_a, node_b) = await set_up_nodes_by_transport_opt(transport_opt_list)
node_b.set_stream_handler(TProtocol("/echo/1.0.0"), handler)
# Associate the peer with local ip address (see default parameters of Libp2p())
node_a.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10)
return node_a, node_b
class DummyRouter(IPeerRouting):
_routing_table: Dict[ID, PeerInfo]
def __init__(self) -> None:
self._routing_table = dict()
async def find_peer(self, peer_id: ID) -> PeerInfo:
return self._routing_table.get(peer_id, None)
async def set_up_routed_hosts() -> Tuple[RoutedHost, RoutedHost]:
router_a, router_b = DummyRouter(), DummyRouter()
transport = "/ip4/127.0.0.1/tcp/0"
host_a = await new_node(transport_opt=[transport], disc_opt=router_a)
host_b = await new_node(transport_opt=[transport], disc_opt=router_b)
address = multiaddr.Multiaddr(transport)
await host_a.get_network().listen(address)
await host_b.get_network().listen(address)
mock_routing_table = {
host_a.get_id(): PeerInfo(host_a.get_id(), host_a.get_addrs()),
host_b.get_id(): PeerInfo(host_b.get_id(), host_b.get_addrs()),
}
router_a._routing_table = router_b._routing_table = mock_routing_table
return cast(RoutedHost, host_a), cast(RoutedHost, host_b)
return echo_stream_handler

View File

@ -1,12 +1,13 @@
from abc import ABC, abstractmethod
from typing import List
from typing import Tuple
from multiaddr import Multiaddr
import trio
class IListener(ABC):
@abstractmethod
async def listen(self, maddr: Multiaddr) -> bool:
async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> bool:
"""
put listener in listening mode and wait for incoming connections.
@ -15,7 +16,7 @@ class IListener(ABC):
"""
@abstractmethod
def get_addrs(self) -> List[Multiaddr]:
def get_addrs(self) -> Tuple[Multiaddr, ...]:
"""
retrieve list of addresses the listener is listening on.
@ -24,5 +25,4 @@ class IListener(ABC):
@abstractmethod
async def close(self) -> None:
"""close the listener such that no more connections can be open on this
transport instance."""
...

View File

@ -1,10 +1,11 @@
import asyncio
from socket import socket
import sys
from typing import List
import logging
from typing import Awaitable, Callable, List, Sequence, Tuple
from multiaddr import Multiaddr
import trio
from trio_typing import TaskStatus
from libp2p.io.trio import TrioTCPStream
from libp2p.network.connection.raw_connection import RawConnection
from libp2p.network.connection.raw_connection_interface import IRawConnection
from libp2p.transport.exceptions import OpenConnectionError
@ -12,53 +13,61 @@ from libp2p.transport.listener_interface import IListener
from libp2p.transport.transport_interface import ITransport
from libp2p.transport.typing import THandler
logger = logging.getLogger("libp2p.transport.tcp")
class TCPListener(IListener):
multiaddrs: List[Multiaddr]
server = None
listeners: List[trio.SocketListener]
def __init__(self, handler_function: THandler) -> None:
self.multiaddrs = []
self.server = None
self.listeners = []
self.handler = handler_function
async def listen(self, maddr: Multiaddr) -> bool:
# TODO: Get rid of `nursery`?
async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> None:
"""
put listener in listening mode and wait for incoming connections.
:param maddr: maddr of peer
:return: return True if successful
"""
self.server = await asyncio.start_server(
self.handler,
async def serve_tcp(
handler: Callable[[trio.SocketStream], Awaitable[None]],
port: int,
host: str,
task_status: TaskStatus[Sequence[trio.SocketListener]] = None,
) -> None:
"""Just a proxy function to add logging here."""
logger.debug("serve_tcp %s %s", host, port)
await trio.serve_tcp(handler, port, host=host, task_status=task_status)
async def handler(stream: trio.SocketStream) -> None:
tcp_stream = TrioTCPStream(stream)
await self.handler(tcp_stream)
listeners = await nursery.start(
serve_tcp,
handler,
int(maddr.value_for_protocol("tcp")),
maddr.value_for_protocol("ip4"),
maddr.value_for_protocol("tcp"),
)
socket = self.server.sockets[0]
self.multiaddrs.append(_multiaddr_from_socket(socket))
self.listeners.extend(listeners)
return True
def get_addrs(self) -> List[Multiaddr]:
def get_addrs(self) -> Tuple[Multiaddr, ...]:
"""
retrieve list of addresses the listener is listening on.
:return: return list of addrs
"""
# TODO check if server is listening
return self.multiaddrs
return tuple(
_multiaddr_from_socket(listener.socket) for listener in self.listeners
)
async def close(self) -> None:
"""close the listener such that no more connections can be open on this
transport instance."""
if self.server is None:
return
self.server.close()
server = self.server
self.server = None
if sys.version_info < (3, 7):
return
await server.wait_closed()
async with trio.open_nursery() as nursery:
for listener in self.listeners:
nursery.start_soon(listener.aclose)
class TCP(ITransport):
@ -74,11 +83,12 @@ class TCP(ITransport):
self.port = int(maddr.value_for_protocol("tcp"))
try:
reader, writer = await asyncio.open_connection(self.host, self.port)
except (ConnectionAbortedError, ConnectionRefusedError) as error:
raise OpenConnectionError(error)
stream = await trio.open_tcp_stream(self.host, self.port)
except OSError as error:
raise OpenConnectionError from error
read_write_closer = TrioTCPStream(stream)
return RawConnection(reader, writer, True)
return RawConnection(read_write_closer, True)
def create_listener(self, handler_function: THandler) -> TCPListener:
"""
@ -91,6 +101,6 @@ class TCP(ITransport):
return TCPListener(handler_function)
def _multiaddr_from_socket(socket: socket) -> Multiaddr:
addr, port = socket.getsockname()[:2]
return Multiaddr(f"/ip4/{addr}/tcp/{port}")
def _multiaddr_from_socket(socket: trio.socket.SocketType) -> Multiaddr:
ip, port = socket.getsockname() # type: ignore
return Multiaddr(f"/ip4/{ip}/tcp/{port}")

View File

@ -1,11 +1,11 @@
from asyncio import StreamReader, StreamWriter
from typing import Awaitable, Callable, Mapping, Type
from libp2p.io.abc import ReadWriteCloser
from libp2p.security.secure_transport_interface import ISecureTransport
from libp2p.stream_muxer.abc import IMuxedConn
from libp2p.typing import TProtocol
THandler = Callable[[StreamReader, StreamWriter], Awaitable[None]]
THandler = Callable[[ReadWriteCloser], Awaitable[None]]
TSecurityOptions = Mapping[TProtocol, ISecureTransport]
TMuxerClass = Type[IMuxedConn]
TMuxerOptions = Mapping[TProtocol, TMuxerClass]

View File

@ -7,8 +7,8 @@ from setuptools import find_packages, setup
extras_require = {
"test": [
"pytest>=4.6.3,<5.0.0",
"pytest-xdist>=1.30.0,<2",
"pytest-asyncio>=0.10.0,<1.0.0",
"pytest-xdist>=1.30.0",
"pytest-trio>=0.5.2",
"factory-boy>=2.12.0,<3.0.0",
],
"lint": [
@ -74,6 +74,10 @@ install_requires = [
"pynacl==1.3.0",
"dataclasses>=0.7, <1;python_version<'3.7'",
"async_generator==1.10",
"trio>=0.13.0",
"async-service>=0.1.0a6",
"async-exit-stack==1.0.1",
"trio-typing>=0.3.0,<0.4.0",
]

View File

@ -1,8 +1,5 @@
import asyncio
import pytest
from libp2p.tools.constants import LISTEN_MADDR
from libp2p.tools.factories import HostFactory
@ -17,17 +14,6 @@ def num_hosts():
@pytest.fixture
async def hosts(num_hosts, is_host_secure):
_hosts = HostFactory.create_batch(num_hosts, is_secure=is_host_secure)
await asyncio.gather(
*[_host.get_network().listen(LISTEN_MADDR) for _host in _hosts]
)
try:
async def hosts(num_hosts, is_host_secure, nursery):
async with HostFactory.create_batch_and_listen(is_host_secure, num_hosts) as _hosts:
yield _hosts
finally:
# TODO: It's possible that `close` raises exceptions currently,
# due to the connection reset things. Though we don't care much about that when
# cleaning up the tasks, it is probably better to handle the exceptions properly.
await asyncio.gather(
*[_host.close() for _host in _hosts], return_exceptions=True
)

View File

@ -1,10 +1,10 @@
import asyncio
import pytest
import trio
from libp2p.host.exceptions import StreamFailure
from libp2p.peer.peerinfo import info_from_p2p_addr
from libp2p.tools.utils import set_up_nodes_by_transport_opt
from libp2p.tools.factories import HostFactory
from libp2p.tools.utils import MAX_READ_LEN
PROTOCOL_ID = "/chat/1.0.0"
@ -25,7 +25,7 @@ async def hello_world(host_a, host_b):
# Multiaddress of the destination peer is fetched from the peerstore using 'peerId'.
stream = await host_b.new_stream(host_a.get_id(), [PROTOCOL_ID])
await stream.write(hello_world_from_host_b)
read = await stream.read()
read = await stream.read(MAX_READ_LEN)
assert read == hello_world_from_host_a
await stream.close()
@ -47,7 +47,7 @@ async def connect_write(host_a, host_b):
await stream.write(message.encode())
# Reader needs time due to async reads
await asyncio.sleep(2)
await trio.sleep(2)
await stream.close()
assert received == messages
@ -88,16 +88,14 @@ async def no_common_protocol(host_a, host_b):
await host_b.new_stream(host_a.get_id(), ["/fakeproto/0.0.1"])
@pytest.mark.asyncio
@pytest.mark.parametrize(
"test", [(hello_world), (connect_write), (connect_read), (no_common_protocol)]
)
async def test_chat(test):
transport_opt_list = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]]
(host_a, host_b) = await set_up_nodes_by_transport_opt(transport_opt_list)
@pytest.mark.trio
async def test_chat(test, is_host_secure):
async with HostFactory.create_batch_and_listen(is_host_secure, 2) as hosts:
addr = hosts[0].get_addrs()[0]
info = info_from_p2p_addr(addr)
await hosts[1].connect(info)
addr = host_a.get_addrs()[0]
info = info_from_p2p_addr(addr)
await host_b.connect(info)
await test(host_a, host_b)
await test(hosts[0], hosts[1])

View File

@ -1,4 +1,4 @@
from libp2p import initialize_default_swarm
from libp2p import new_swarm
from libp2p.crypto.rsa import create_new_key_pair
from libp2p.host.basic_host import BasicHost
from libp2p.host.defaults import get_default_protocols
@ -6,7 +6,7 @@ from libp2p.host.defaults import get_default_protocols
def test_default_protocols():
key_pair = create_new_key_pair()
swarm = initialize_default_swarm(key_pair)
swarm = new_swarm(key_pair)
host = BasicHost(swarm)
mux = host.get_mux()

View File

@ -1,18 +1,19 @@
import asyncio
import secrets
import pytest
import trio
from libp2p.host.ping import ID, PING_LENGTH
from libp2p.tools.factories import pair_of_connected_hosts
from libp2p.tools.factories import host_pair_factory
@pytest.mark.asyncio
async def test_ping_once():
async with pair_of_connected_hosts() as (host_a, host_b):
@pytest.mark.trio
async def test_ping_once(is_host_secure):
async with host_pair_factory(is_host_secure) as (host_a, host_b):
stream = await host_b.new_stream(host_a.get_id(), (ID,))
some_ping = secrets.token_bytes(PING_LENGTH)
await stream.write(some_ping)
await trio.sleep(0.01)
some_pong = await stream.read(PING_LENGTH)
assert some_ping == some_pong
await stream.close()
@ -21,9 +22,9 @@ async def test_ping_once():
SOME_PING_COUNT = 3
@pytest.mark.asyncio
async def test_ping_several():
async with pair_of_connected_hosts() as (host_a, host_b):
@pytest.mark.trio
async def test_ping_several(is_host_secure):
async with host_pair_factory(is_host_secure) as (host_a, host_b):
stream = await host_b.new_stream(host_a.get_id(), (ID,))
for _ in range(SOME_PING_COUNT):
some_ping = secrets.token_bytes(PING_LENGTH)
@ -33,5 +34,5 @@ async def test_ping_several():
# NOTE: simulate some time to sleep to mirror a real
# world usage where a peer sends pings on some periodic interval
# NOTE: this interval can be `0` for this test.
await asyncio.sleep(0)
await trio.sleep(0)
await stream.close()

View File

@ -1,33 +1,26 @@
import asyncio
import pytest
from libp2p.host.exceptions import ConnectionFailure
from libp2p.peer.peerinfo import PeerInfo
from libp2p.tools.utils import set_up_nodes_by_transport_opt, set_up_routed_hosts
from libp2p.tools.factories import HostFactory, RoutedHostFactory
@pytest.mark.asyncio
@pytest.mark.trio
async def test_host_routing_success():
host_a, host_b = await set_up_routed_hosts()
# forces to use routing as no addrs are provided
await host_a.connect(PeerInfo(host_b.get_id(), []))
await host_b.connect(PeerInfo(host_a.get_id(), []))
# Clean up
await asyncio.gather(*[host_a.close(), host_b.close()])
async with RoutedHostFactory.create_batch_and_listen(False, 2) as hosts:
# forces to use routing as no addrs are provided
await hosts[0].connect(PeerInfo(hosts[1].get_id(), []))
await hosts[1].connect(PeerInfo(hosts[0].get_id(), []))
@pytest.mark.asyncio
@pytest.mark.trio
async def test_host_routing_fail():
host_a, host_b = await set_up_routed_hosts()
basic_host_c = (await set_up_nodes_by_transport_opt([["/ip4/127.0.0.1/tcp/0"]]))[0]
# routing fails because host_c does not use routing
with pytest.raises(ConnectionFailure):
await host_a.connect(PeerInfo(basic_host_c.get_id(), []))
with pytest.raises(ConnectionFailure):
await host_b.connect(PeerInfo(basic_host_c.get_id(), []))
# Clean up
await asyncio.gather(*[host_a.close(), host_b.close(), basic_host_c.close()])
is_secure = False
async with RoutedHostFactory.create_batch_and_listen(
is_secure, 2
) as routed_hosts, HostFactory.create_batch_and_listen(is_secure, 1) as basic_hosts:
# routing fails because host_c does not use routing
with pytest.raises(ConnectionFailure):
await routed_hosts[0].connect(PeerInfo(basic_hosts[0].get_id(), []))
with pytest.raises(ConnectionFailure):
await routed_hosts[1].connect(PeerInfo(basic_hosts[0].get_id(), []))

View File

@ -2,12 +2,12 @@ import pytest
from libp2p.identity.identify.pb.identify_pb2 import Identify
from libp2p.identity.identify.protocol import ID, _mk_identify_protobuf
from libp2p.tools.factories import pair_of_connected_hosts
from libp2p.tools.factories import host_pair_factory
@pytest.mark.asyncio
async def test_identify_protocol():
async with pair_of_connected_hosts() as (host_a, host_b):
@pytest.mark.trio
async def test_identify_protocol(is_host_secure):
async with host_pair_factory(is_host_secure) as (host_a, host_b):
stream = await host_b.new_stream(host_a.get_id(), (ID,))
response = await stream.read()
await stream.close()

View File

@ -1,350 +1,285 @@
import multiaddr
import pytest
from libp2p.peer.peerinfo import info_from_p2p_addr
from libp2p.network.stream.exceptions import StreamError
from libp2p.tools.constants import MAX_READ_LEN
from libp2p.tools.utils import set_up_nodes_by_transport_opt
from libp2p.tools.factories import HostFactory
from libp2p.tools.utils import connect, create_echo_stream_handler
from libp2p.typing import TProtocol
PROTOCOL_ID_0 = TProtocol("/echo/0")
PROTOCOL_ID_1 = TProtocol("/echo/1")
PROTOCOL_ID_2 = TProtocol("/echo/2")
PROTOCOL_ID_3 = TProtocol("/echo/3")
ACK_STR_0 = "ack_0:"
ACK_STR_1 = "ack_1:"
ACK_STR_2 = "ack_2:"
ACK_STR_3 = "ack_3:"
@pytest.mark.asyncio
async def test_simple_messages():
transport_opt_list = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]]
(node_a, node_b) = await set_up_nodes_by_transport_opt(transport_opt_list)
async def stream_handler(stream):
while True:
read_string = (await stream.read(MAX_READ_LEN)).decode()
response = "ack:" + read_string
await stream.write(response.encode())
node_b.set_stream_handler("/echo/1.0.0", stream_handler)
# Associate the peer with local ip address (see default parameters of Libp2p())
node_a.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10)
stream = await node_a.new_stream(node_b.get_id(), ["/echo/1.0.0"])
messages = ["hello" + str(x) for x in range(10)]
for message in messages:
await stream.write(message.encode())
response = (await stream.read(MAX_READ_LEN)).decode()
assert response == ("ack:" + message)
# Success, terminate pending tasks.
@pytest.mark.asyncio
async def test_double_response():
transport_opt_list = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]]
(node_a, node_b) = await set_up_nodes_by_transport_opt(transport_opt_list)
async def stream_handler(stream):
while True:
read_string = (await stream.read(MAX_READ_LEN)).decode()
response = "ack1:" + read_string
await stream.write(response.encode())
response = "ack2:" + read_string
await stream.write(response.encode())
node_b.set_stream_handler("/echo/1.0.0", stream_handler)
# Associate the peer with local ip address (see default parameters of Libp2p())
node_a.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10)
stream = await node_a.new_stream(node_b.get_id(), ["/echo/1.0.0"])
messages = ["hello" + str(x) for x in range(10)]
for message in messages:
await stream.write(message.encode())
response1 = (await stream.read(MAX_READ_LEN)).decode()
assert response1 == ("ack1:" + message)
response2 = (await stream.read(MAX_READ_LEN)).decode()
assert response2 == ("ack2:" + message)
# Success, terminate pending tasks.
@pytest.mark.asyncio
async def test_multiple_streams():
# Node A should be able to open a stream with node B and then vice versa.
# Stream IDs should be generated uniquely so that the stream state is not overwritten
transport_opt_list = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]]
(node_a, node_b) = await set_up_nodes_by_transport_opt(transport_opt_list)
async def stream_handler_a(stream):
while True:
read_string = (await stream.read(MAX_READ_LEN)).decode()
response = "ack_a:" + read_string
await stream.write(response.encode())
async def stream_handler_b(stream):
while True:
read_string = (await stream.read(MAX_READ_LEN)).decode()
response = "ack_b:" + read_string
await stream.write(response.encode())
node_a.set_stream_handler("/echo_a/1.0.0", stream_handler_a)
node_b.set_stream_handler("/echo_b/1.0.0", stream_handler_b)
# Associate the peer with local ip address (see default parameters of Libp2p())
node_a.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10)
node_b.get_peerstore().add_addrs(node_a.get_id(), node_a.get_addrs(), 10)
stream_a = await node_a.new_stream(node_b.get_id(), ["/echo_b/1.0.0"])
stream_b = await node_b.new_stream(node_a.get_id(), ["/echo_a/1.0.0"])
# A writes to /echo_b via stream_a, and B writes to /echo_a via stream_b
messages = ["hello" + str(x) for x in range(10)]
for message in messages:
a_message = message + "_a"
b_message = message + "_b"
await stream_a.write(a_message.encode())
await stream_b.write(b_message.encode())
response_a = (await stream_a.read(MAX_READ_LEN)).decode()
response_b = (await stream_b.read(MAX_READ_LEN)).decode()
assert response_a == ("ack_b:" + a_message) and response_b == (
"ack_a:" + b_message
@pytest.mark.trio
async def test_simple_messages(is_host_secure):
async with HostFactory.create_batch_and_listen(is_host_secure, 2) as hosts:
hosts[1].set_stream_handler(
PROTOCOL_ID_0, create_echo_stream_handler(ACK_STR_0)
)
# Success, terminate pending tasks.
# Associate the peer with local ip address (see default parameters of Libp2p())
hosts[0].get_peerstore().add_addrs(hosts[1].get_id(), hosts[1].get_addrs(), 10)
stream = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_0])
messages = ["hello" + str(x) for x in range(10)]
for message in messages:
await stream.write(message.encode())
response = (await stream.read(MAX_READ_LEN)).decode()
assert response == (ACK_STR_0 + message)
@pytest.mark.asyncio
async def test_multiple_streams_same_initiator_different_protocols():
transport_opt_list = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]]
(node_a, node_b) = await set_up_nodes_by_transport_opt(transport_opt_list)
@pytest.mark.trio
async def test_double_response(is_host_secure):
async with HostFactory.create_batch_and_listen(is_host_secure, 2) as hosts:
async def stream_handler_a1(stream):
while True:
read_string = (await stream.read(MAX_READ_LEN)).decode()
async def double_response_stream_handler(stream):
while True:
try:
read_string = (await stream.read(MAX_READ_LEN)).decode()
except StreamError:
break
response = "ack_a1:" + read_string
await stream.write(response.encode())
response = ACK_STR_0 + read_string
try:
await stream.write(response.encode())
except StreamError:
break
async def stream_handler_a2(stream):
while True:
read_string = (await stream.read(MAX_READ_LEN)).decode()
response = ACK_STR_1 + read_string
try:
await stream.write(response.encode())
except StreamError:
break
response = "ack_a2:" + read_string
await stream.write(response.encode())
hosts[1].set_stream_handler(PROTOCOL_ID_0, double_response_stream_handler)
async def stream_handler_a3(stream):
while True:
read_string = (await stream.read(MAX_READ_LEN)).decode()
# Associate the peer with local ip address (see default parameters of Libp2p())
hosts[0].get_peerstore().add_addrs(hosts[1].get_id(), hosts[1].get_addrs(), 10)
stream = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_0])
response = "ack_a3:" + read_string
await stream.write(response.encode())
node_b.set_stream_handler("/echo_a1/1.0.0", stream_handler_a1)
node_b.set_stream_handler("/echo_a2/1.0.0", stream_handler_a2)
node_b.set_stream_handler("/echo_a3/1.0.0", stream_handler_a3)
# Associate the peer with local ip address (see default parameters of Libp2p())
node_a.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10)
node_b.get_peerstore().add_addrs(node_a.get_id(), node_a.get_addrs(), 10)
# Open streams to node_b over echo_a1 echo_a2 echo_a3 protocols
stream_a1 = await node_a.new_stream(node_b.get_id(), ["/echo_a1/1.0.0"])
stream_a2 = await node_a.new_stream(node_b.get_id(), ["/echo_a2/1.0.0"])
stream_a3 = await node_a.new_stream(node_b.get_id(), ["/echo_a3/1.0.0"])
messages = ["hello" + str(x) for x in range(10)]
for message in messages:
a1_message = message + "_a1"
a2_message = message + "_a2"
a3_message = message + "_a3"
await stream_a1.write(a1_message.encode())
await stream_a2.write(a2_message.encode())
await stream_a3.write(a3_message.encode())
response_a1 = (await stream_a1.read(MAX_READ_LEN)).decode()
response_a2 = (await stream_a2.read(MAX_READ_LEN)).decode()
response_a3 = (await stream_a3.read(MAX_READ_LEN)).decode()
assert (
response_a1 == ("ack_a1:" + a1_message)
and response_a2 == ("ack_a2:" + a2_message)
and response_a3 == ("ack_a3:" + a3_message)
)
# Success, terminate pending tasks.
@pytest.mark.asyncio
async def test_multiple_streams_two_initiators():
transport_opt_list = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]]
(node_a, node_b) = await set_up_nodes_by_transport_opt(transport_opt_list)
async def stream_handler_a1(stream):
while True:
read_string = (await stream.read(MAX_READ_LEN)).decode()
response = "ack_a1:" + read_string
await stream.write(response.encode())
async def stream_handler_a2(stream):
while True:
read_string = (await stream.read(MAX_READ_LEN)).decode()
response = "ack_a2:" + read_string
await stream.write(response.encode())
async def stream_handler_b1(stream):
while True:
read_string = (await stream.read(MAX_READ_LEN)).decode()
response = "ack_b1:" + read_string
await stream.write(response.encode())
async def stream_handler_b2(stream):
while True:
read_string = (await stream.read(MAX_READ_LEN)).decode()
response = "ack_b2:" + read_string
await stream.write(response.encode())
node_a.set_stream_handler("/echo_b1/1.0.0", stream_handler_b1)
node_a.set_stream_handler("/echo_b2/1.0.0", stream_handler_b2)
node_b.set_stream_handler("/echo_a1/1.0.0", stream_handler_a1)
node_b.set_stream_handler("/echo_a2/1.0.0", stream_handler_a2)
# Associate the peer with local ip address (see default parameters of Libp2p())
node_a.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10)
node_b.get_peerstore().add_addrs(node_a.get_id(), node_a.get_addrs(), 10)
stream_a1 = await node_a.new_stream(node_b.get_id(), ["/echo_a1/1.0.0"])
stream_a2 = await node_a.new_stream(node_b.get_id(), ["/echo_a2/1.0.0"])
stream_b1 = await node_b.new_stream(node_a.get_id(), ["/echo_b1/1.0.0"])
stream_b2 = await node_b.new_stream(node_a.get_id(), ["/echo_b2/1.0.0"])
# A writes to /echo_b via stream_a, and B writes to /echo_a via stream_b
messages = ["hello" + str(x) for x in range(10)]
for message in messages:
a1_message = message + "_a1"
a2_message = message + "_a2"
b1_message = message + "_b1"
b2_message = message + "_b2"
await stream_a1.write(a1_message.encode())
await stream_a2.write(a2_message.encode())
await stream_b1.write(b1_message.encode())
await stream_b2.write(b2_message.encode())
response_a1 = (await stream_a1.read(MAX_READ_LEN)).decode()
response_a2 = (await stream_a2.read(MAX_READ_LEN)).decode()
response_b1 = (await stream_b1.read(MAX_READ_LEN)).decode()
response_b2 = (await stream_b2.read(MAX_READ_LEN)).decode()
assert (
response_a1 == ("ack_a1:" + a1_message)
and response_a2 == ("ack_a2:" + a2_message)
and response_b1 == ("ack_b1:" + b1_message)
and response_b2 == ("ack_b2:" + b2_message)
)
# Success, terminate pending tasks.
@pytest.mark.asyncio
async def test_triangle_nodes_connection():
transport_opt_list = [
["/ip4/127.0.0.1/tcp/0"],
["/ip4/127.0.0.1/tcp/0"],
["/ip4/127.0.0.1/tcp/0"],
]
(node_a, node_b, node_c) = await set_up_nodes_by_transport_opt(transport_opt_list)
async def stream_handler(stream):
while True:
read_string = (await stream.read(MAX_READ_LEN)).decode()
response = "ack:" + read_string
await stream.write(response.encode())
node_a.set_stream_handler("/echo/1.0.0", stream_handler)
node_b.set_stream_handler("/echo/1.0.0", stream_handler)
node_c.set_stream_handler("/echo/1.0.0", stream_handler)
# Associate the peer with local ip address (see default parameters of Libp2p())
# Associate all permutations
node_a.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10)
node_a.get_peerstore().add_addrs(node_c.get_id(), node_c.get_addrs(), 10)
node_b.get_peerstore().add_addrs(node_a.get_id(), node_a.get_addrs(), 10)
node_b.get_peerstore().add_addrs(node_c.get_id(), node_c.get_addrs(), 10)
node_c.get_peerstore().add_addrs(node_a.get_id(), node_a.get_addrs(), 10)
node_c.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10)
stream_a_to_b = await node_a.new_stream(node_b.get_id(), ["/echo/1.0.0"])
stream_a_to_c = await node_a.new_stream(node_c.get_id(), ["/echo/1.0.0"])
stream_b_to_a = await node_b.new_stream(node_a.get_id(), ["/echo/1.0.0"])
stream_b_to_c = await node_b.new_stream(node_c.get_id(), ["/echo/1.0.0"])
stream_c_to_a = await node_c.new_stream(node_a.get_id(), ["/echo/1.0.0"])
stream_c_to_b = await node_c.new_stream(node_b.get_id(), ["/echo/1.0.0"])
messages = ["hello" + str(x) for x in range(5)]
streams = [
stream_a_to_b,
stream_a_to_c,
stream_b_to_a,
stream_b_to_c,
stream_c_to_a,
stream_c_to_b,
]
for message in messages:
for stream in streams:
messages = ["hello" + str(x) for x in range(10)]
for message in messages:
await stream.write(message.encode())
response = (await stream.read(MAX_READ_LEN)).decode()
response1 = (await stream.read(MAX_READ_LEN)).decode()
assert response1 == (ACK_STR_0 + message)
assert response == ("ack:" + message)
# Success, terminate pending tasks.
response2 = (await stream.read(MAX_READ_LEN)).decode()
assert response2 == (ACK_STR_1 + message)
@pytest.mark.asyncio
async def test_host_connect():
transport_opt_list = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]]
(node_a, node_b) = await set_up_nodes_by_transport_opt(transport_opt_list)
@pytest.mark.trio
async def test_multiple_streams(is_host_secure):
# hosts[0] should be able to open a stream with hosts[1] and then vice versa.
# Stream IDs should be generated uniquely so that the stream state is not overwritten
# Only our peer ID is stored in peer store
assert len(node_a.get_peerstore().peer_ids()) == 1
async with HostFactory.create_batch_and_listen(is_host_secure, 2) as hosts:
hosts[0].set_stream_handler(
PROTOCOL_ID_0, create_echo_stream_handler(ACK_STR_0)
)
hosts[1].set_stream_handler(
PROTOCOL_ID_1, create_echo_stream_handler(ACK_STR_1)
)
addr = node_b.get_addrs()[0]
info = info_from_p2p_addr(addr)
await node_a.connect(info)
# Associate the peer with local ip address (see default parameters of Libp2p())
hosts[0].get_peerstore().add_addrs(hosts[1].get_id(), hosts[1].get_addrs(), 10)
hosts[1].get_peerstore().add_addrs(hosts[0].get_id(), hosts[0].get_addrs(), 10)
assert len(node_a.get_peerstore().peer_ids()) == 2
stream_a = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_1])
stream_b = await hosts[1].new_stream(hosts[0].get_id(), [PROTOCOL_ID_0])
await node_a.connect(info)
# A writes to /echo_b via stream_a, and B writes to /echo_a via stream_b
messages = ["hello" + str(x) for x in range(10)]
for message in messages:
a_message = message + "_a"
b_message = message + "_b"
# make sure we don't do double connection
assert len(node_a.get_peerstore().peer_ids()) == 2
await stream_a.write(a_message.encode())
await stream_b.write(b_message.encode())
assert node_b.get_id() in node_a.get_peerstore().peer_ids()
ma_node_b = multiaddr.Multiaddr("/p2p/%s" % node_b.get_id().pretty())
for addr in node_a.get_peerstore().addrs(node_b.get_id()):
assert addr.encapsulate(ma_node_b) in node_b.get_addrs()
response_a = (await stream_a.read(MAX_READ_LEN)).decode()
response_b = (await stream_b.read(MAX_READ_LEN)).decode()
# Success, terminate pending tasks.
assert response_a == (ACK_STR_1 + a_message) and response_b == (
ACK_STR_0 + b_message
)
@pytest.mark.trio
async def test_multiple_streams_same_initiator_different_protocols(is_host_secure):
async with HostFactory.create_batch_and_listen(is_host_secure, 2) as hosts:
hosts[1].set_stream_handler(
PROTOCOL_ID_0, create_echo_stream_handler(ACK_STR_0)
)
hosts[1].set_stream_handler(
PROTOCOL_ID_1, create_echo_stream_handler(ACK_STR_1)
)
hosts[1].set_stream_handler(
PROTOCOL_ID_2, create_echo_stream_handler(ACK_STR_2)
)
# Associate the peer with local ip address (see default parameters of Libp2p())
hosts[0].get_peerstore().add_addrs(hosts[1].get_id(), hosts[1].get_addrs(), 10)
hosts[1].get_peerstore().add_addrs(hosts[0].get_id(), hosts[0].get_addrs(), 10)
# Open streams to hosts[1] over echo_a1 echo_a2 echo_a3 protocols
stream_a1 = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_0])
stream_a2 = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_1])
stream_a3 = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_2])
messages = ["hello" + str(x) for x in range(10)]
for message in messages:
a1_message = message + "_a1"
a2_message = message + "_a2"
a3_message = message + "_a3"
await stream_a1.write(a1_message.encode())
await stream_a2.write(a2_message.encode())
await stream_a3.write(a3_message.encode())
response_a1 = (await stream_a1.read(MAX_READ_LEN)).decode()
response_a2 = (await stream_a2.read(MAX_READ_LEN)).decode()
response_a3 = (await stream_a3.read(MAX_READ_LEN)).decode()
assert (
response_a1 == (ACK_STR_0 + a1_message)
and response_a2 == (ACK_STR_1 + a2_message)
and response_a3 == (ACK_STR_2 + a3_message)
)
# Success, terminate pending tasks.
@pytest.mark.trio
async def test_multiple_streams_two_initiators(is_host_secure):
async with HostFactory.create_batch_and_listen(is_host_secure, 2) as hosts:
hosts[0].set_stream_handler(
PROTOCOL_ID_2, create_echo_stream_handler(ACK_STR_2)
)
hosts[0].set_stream_handler(
PROTOCOL_ID_3, create_echo_stream_handler(ACK_STR_3)
)
hosts[1].set_stream_handler(
PROTOCOL_ID_0, create_echo_stream_handler(ACK_STR_0)
)
hosts[1].set_stream_handler(
PROTOCOL_ID_1, create_echo_stream_handler(ACK_STR_1)
)
# Associate the peer with local ip address (see default parameters of Libp2p())
hosts[0].get_peerstore().add_addrs(hosts[1].get_id(), hosts[1].get_addrs(), 10)
hosts[1].get_peerstore().add_addrs(hosts[0].get_id(), hosts[0].get_addrs(), 10)
stream_a1 = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_0])
stream_a2 = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_1])
stream_b1 = await hosts[1].new_stream(hosts[0].get_id(), [PROTOCOL_ID_2])
stream_b2 = await hosts[1].new_stream(hosts[0].get_id(), [PROTOCOL_ID_3])
# A writes to /echo_b via stream_a, and B writes to /echo_a via stream_b
messages = ["hello" + str(x) for x in range(10)]
for message in messages:
a1_message = message + "_a1"
a2_message = message + "_a2"
b1_message = message + "_b1"
b2_message = message + "_b2"
await stream_a1.write(a1_message.encode())
await stream_a2.write(a2_message.encode())
await stream_b1.write(b1_message.encode())
await stream_b2.write(b2_message.encode())
response_a1 = (await stream_a1.read(MAX_READ_LEN)).decode()
response_a2 = (await stream_a2.read(MAX_READ_LEN)).decode()
response_b1 = (await stream_b1.read(MAX_READ_LEN)).decode()
response_b2 = (await stream_b2.read(MAX_READ_LEN)).decode()
assert (
response_a1 == (ACK_STR_0 + a1_message)
and response_a2 == (ACK_STR_1 + a2_message)
and response_b1 == (ACK_STR_2 + b1_message)
and response_b2 == (ACK_STR_3 + b2_message)
)
@pytest.mark.trio
async def test_triangle_nodes_connection(is_host_secure):
async with HostFactory.create_batch_and_listen(is_host_secure, 3) as hosts:
hosts[0].set_stream_handler(
PROTOCOL_ID_0, create_echo_stream_handler(ACK_STR_0)
)
hosts[1].set_stream_handler(
PROTOCOL_ID_0, create_echo_stream_handler(ACK_STR_0)
)
hosts[2].set_stream_handler(
PROTOCOL_ID_0, create_echo_stream_handler(ACK_STR_0)
)
# Associate the peer with local ip address (see default parameters of Libp2p())
# Associate all permutations
hosts[0].get_peerstore().add_addrs(hosts[1].get_id(), hosts[1].get_addrs(), 10)
hosts[0].get_peerstore().add_addrs(hosts[2].get_id(), hosts[2].get_addrs(), 10)
hosts[1].get_peerstore().add_addrs(hosts[0].get_id(), hosts[0].get_addrs(), 10)
hosts[1].get_peerstore().add_addrs(hosts[2].get_id(), hosts[2].get_addrs(), 10)
hosts[2].get_peerstore().add_addrs(hosts[0].get_id(), hosts[0].get_addrs(), 10)
hosts[2].get_peerstore().add_addrs(hosts[1].get_id(), hosts[1].get_addrs(), 10)
stream_0_to_1 = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_0])
stream_0_to_2 = await hosts[0].new_stream(hosts[2].get_id(), [PROTOCOL_ID_0])
stream_1_to_0 = await hosts[1].new_stream(hosts[0].get_id(), [PROTOCOL_ID_0])
stream_1_to_2 = await hosts[1].new_stream(hosts[2].get_id(), [PROTOCOL_ID_0])
stream_2_to_0 = await hosts[2].new_stream(hosts[0].get_id(), [PROTOCOL_ID_0])
stream_2_to_1 = await hosts[2].new_stream(hosts[1].get_id(), [PROTOCOL_ID_0])
messages = ["hello" + str(x) for x in range(5)]
streams = [
stream_0_to_1,
stream_0_to_2,
stream_1_to_0,
stream_1_to_2,
stream_2_to_0,
stream_2_to_1,
]
for message in messages:
for stream in streams:
await stream.write(message.encode())
response = (await stream.read(MAX_READ_LEN)).decode()
assert response == (ACK_STR_0 + message)
@pytest.mark.trio
async def test_host_connect(is_host_secure):
async with HostFactory.create_batch_and_listen(is_host_secure, 2) as hosts:
assert len(hosts[0].get_peerstore().peer_ids()) == 1
await connect(hosts[0], hosts[1])
assert len(hosts[0].get_peerstore().peer_ids()) == 2
await connect(hosts[0], hosts[1])
# make sure we don't do double connection
assert len(hosts[0].get_peerstore().peer_ids()) == 2
assert hosts[1].get_id() in hosts[0].get_peerstore().peer_ids()
ma_node_b = multiaddr.Multiaddr("/p2p/%s" % hosts[1].get_id().pretty())
for addr in hosts[0].get_peerstore().addrs(hosts[1].get_id()):
assert addr.encapsulate(ma_node_b) in hosts[1].get_addrs()

View File

@ -1,5 +1,3 @@
import asyncio
import pytest
from libp2p.tools.factories import (
@ -11,26 +9,17 @@ from libp2p.tools.factories import (
@pytest.fixture
async def net_stream_pair(is_host_secure):
stream_0, host_0, stream_1, host_1 = await net_stream_pair_factory(is_host_secure)
try:
yield stream_0, stream_1
finally:
await asyncio.gather(*[host_0.close(), host_1.close()])
async with net_stream_pair_factory(is_host_secure) as net_stream_pair:
yield net_stream_pair
@pytest.fixture
async def swarm_pair(is_host_secure):
swarm_0, swarm_1 = await swarm_pair_factory(is_host_secure)
try:
yield swarm_0, swarm_1
finally:
await asyncio.gather(*[swarm_0.close(), swarm_1.close()])
async with swarm_pair_factory(is_host_secure) as swarms:
yield swarms
@pytest.fixture
async def swarm_conn_pair(is_host_secure):
conn_0, swarm_0, conn_1, swarm_1 = await swarm_conn_pair_factory(is_host_secure)
try:
yield conn_0, conn_1
finally:
await asyncio.gather(*[swarm_0.close(), swarm_1.close()])
async with swarm_conn_pair_factory(is_host_secure) as swarm_conn_pair:
yield swarm_conn_pair

View File

@ -1,6 +1,5 @@
import asyncio
import pytest
import trio
from libp2p.network.stream.exceptions import StreamClosed, StreamEOF, StreamReset
from libp2p.tools.constants import MAX_READ_LEN
@ -8,7 +7,7 @@ from libp2p.tools.constants import MAX_READ_LEN
DATA = b"data_123"
@pytest.mark.asyncio
@pytest.mark.trio
async def test_net_stream_read_write(net_stream_pair):
stream_0, stream_1 = net_stream_pair
assert (
@ -19,7 +18,7 @@ async def test_net_stream_read_write(net_stream_pair):
assert (await stream_1.read(MAX_READ_LEN)) == DATA
@pytest.mark.asyncio
@pytest.mark.trio
async def test_net_stream_read_until_eof(net_stream_pair):
read_bytes = bytearray()
stream_0, stream_1 = net_stream_pair
@ -27,41 +26,39 @@ async def test_net_stream_read_until_eof(net_stream_pair):
async def read_until_eof():
read_bytes.extend(await stream_1.read())
task = asyncio.ensure_future(read_until_eof())
async with trio.open_nursery() as nursery:
nursery.start_soon(read_until_eof)
expected_data = bytearray()
expected_data = bytearray()
# Test: `read` doesn't return before `close` is called.
await stream_0.write(DATA)
expected_data.extend(DATA)
await trio.sleep(0.01)
assert len(read_bytes) == 0
# Test: `read` doesn't return before `close` is called.
await stream_0.write(DATA)
expected_data.extend(DATA)
await trio.sleep(0.01)
assert len(read_bytes) == 0
# Test: `read` doesn't return before `close` is called.
await stream_0.write(DATA)
expected_data.extend(DATA)
await asyncio.sleep(0.01)
assert len(read_bytes) == 0
# Test: `read` doesn't return before `close` is called.
await stream_0.write(DATA)
expected_data.extend(DATA)
await asyncio.sleep(0.01)
assert len(read_bytes) == 0
# Test: Close the stream, `read` returns, and receive previous sent data.
await stream_0.close()
await asyncio.sleep(0.01)
assert read_bytes == expected_data
task.cancel()
# Test: Close the stream, `read` returns, and receive previous sent data.
await stream_0.close()
await trio.sleep(0.01)
assert read_bytes == expected_data
@pytest.mark.asyncio
@pytest.mark.trio
async def test_net_stream_read_after_remote_closed(net_stream_pair):
stream_0, stream_1 = net_stream_pair
await stream_0.write(DATA)
await stream_0.close()
await asyncio.sleep(0.01)
await trio.sleep(0.01)
assert (await stream_1.read(MAX_READ_LEN)) == DATA
with pytest.raises(StreamEOF):
await stream_1.read(MAX_READ_LEN)
@pytest.mark.asyncio
@pytest.mark.trio
async def test_net_stream_read_after_local_reset(net_stream_pair):
stream_0, stream_1 = net_stream_pair
await stream_0.reset()
@ -69,29 +66,29 @@ async def test_net_stream_read_after_local_reset(net_stream_pair):
await stream_0.read(MAX_READ_LEN)
@pytest.mark.asyncio
@pytest.mark.trio
async def test_net_stream_read_after_remote_reset(net_stream_pair):
stream_0, stream_1 = net_stream_pair
await stream_0.write(DATA)
await stream_0.reset()
# Sleep to let `stream_1` receive the message.
await asyncio.sleep(0.01)
await trio.sleep(0.01)
with pytest.raises(StreamReset):
await stream_1.read(MAX_READ_LEN)
@pytest.mark.asyncio
@pytest.mark.trio
async def test_net_stream_read_after_remote_closed_and_reset(net_stream_pair):
stream_0, stream_1 = net_stream_pair
await stream_0.write(DATA)
await stream_0.close()
await stream_0.reset()
# Sleep to let `stream_1` receive the message.
await asyncio.sleep(0.01)
await trio.sleep(0.01)
assert (await stream_1.read(MAX_READ_LEN)) == DATA
@pytest.mark.asyncio
@pytest.mark.trio
async def test_net_stream_write_after_local_closed(net_stream_pair):
stream_0, stream_1 = net_stream_pair
await stream_0.write(DATA)
@ -100,7 +97,7 @@ async def test_net_stream_write_after_local_closed(net_stream_pair):
await stream_0.write(DATA)
@pytest.mark.asyncio
@pytest.mark.trio
async def test_net_stream_write_after_local_reset(net_stream_pair):
stream_0, stream_1 = net_stream_pair
await stream_0.reset()
@ -108,10 +105,10 @@ async def test_net_stream_write_after_local_reset(net_stream_pair):
await stream_0.write(DATA)
@pytest.mark.asyncio
@pytest.mark.trio
async def test_net_stream_write_after_remote_reset(net_stream_pair):
stream_0, stream_1 = net_stream_pair
await stream_1.reset()
await asyncio.sleep(0.01)
await trio.sleep(0.01)
with pytest.raises(StreamClosed):
await stream_0.write(DATA)

View File

@ -8,11 +8,11 @@ into network after network has already started listening
TODO: Add tests for closed_stream, listen_close when those
features are implemented in swarm
"""
import asyncio
import enum
from async_service import background_trio_service
import pytest
import trio
from libp2p.network.notifee_interface import INotifee
from libp2p.tools.constants import LISTEN_MADDR
@ -54,59 +54,63 @@ class MyNotifee(INotifee):
pass
@pytest.mark.asyncio
@pytest.mark.trio
async def test_notify(is_host_secure):
swarms = [SwarmFactory(is_secure=is_host_secure) for _ in range(2)]
events_0_0 = []
events_1_0 = []
events_0_without_listen = []
swarms[0].register_notifee(MyNotifee(events_0_0))
swarms[1].register_notifee(MyNotifee(events_1_0))
# Listen
await asyncio.gather(*[swarm.listen(LISTEN_MADDR) for swarm in swarms])
# Run swarms.
async with background_trio_service(swarms[0]), background_trio_service(swarms[1]):
# Register events before listening, to allow `MyNotifee` is notified with the event
# `listen`.
swarms[0].register_notifee(MyNotifee(events_0_0))
swarms[1].register_notifee(MyNotifee(events_1_0))
swarms[0].register_notifee(MyNotifee(events_0_without_listen))
# Listen
async with trio.open_nursery() as nursery:
nursery.start_soon(swarms[0].listen, LISTEN_MADDR)
nursery.start_soon(swarms[1].listen, LISTEN_MADDR)
# Connected
await connect_swarm(swarms[0], swarms[1])
# OpenedStream: first
await swarms[0].new_stream(swarms[1].get_peer_id())
# OpenedStream: second
await swarms[0].new_stream(swarms[1].get_peer_id())
# OpenedStream: third, but different direction.
await swarms[1].new_stream(swarms[0].get_peer_id())
swarms[0].register_notifee(MyNotifee(events_0_without_listen))
await asyncio.sleep(0.01)
# Connected
await connect_swarm(swarms[0], swarms[1])
# OpenedStream: first
await swarms[0].new_stream(swarms[1].get_peer_id())
# OpenedStream: second
await swarms[0].new_stream(swarms[1].get_peer_id())
# OpenedStream: third, but different direction.
await swarms[1].new_stream(swarms[0].get_peer_id())
# TODO: Check `ClosedStream` and `ListenClose` events after they are ready.
await trio.sleep(0.01)
# Disconnected
await swarms[0].close_peer(swarms[1].get_peer_id())
await asyncio.sleep(0.01)
# TODO: Check `ClosedStream` and `ListenClose` events after they are ready.
# Connected again, but different direction.
await connect_swarm(swarms[1], swarms[0])
await asyncio.sleep(0.01)
# Disconnected
await swarms[0].close_peer(swarms[1].get_peer_id())
await trio.sleep(0.01)
# Disconnected again, but different direction.
await swarms[1].close_peer(swarms[0].get_peer_id())
await asyncio.sleep(0.01)
# Connected again, but different direction.
await connect_swarm(swarms[1], swarms[0])
await trio.sleep(0.01)
expected_events_without_listen = [
Event.Connected,
Event.OpenedStream,
Event.OpenedStream,
Event.OpenedStream,
Event.Disconnected,
Event.Connected,
Event.Disconnected,
]
expected_events = [Event.Listen] + expected_events_without_listen
# Disconnected again, but different direction.
await swarms[1].close_peer(swarms[0].get_peer_id())
await trio.sleep(0.01)
assert events_0_0 == expected_events
assert events_1_0 == expected_events
assert events_0_without_listen == expected_events_without_listen
expected_events_without_listen = [
Event.Connected,
Event.OpenedStream,
Event.OpenedStream,
Event.OpenedStream,
Event.Disconnected,
Event.Connected,
Event.Disconnected,
]
expected_events = [Event.Listen] + expected_events_without_listen
# Clean up
await asyncio.gather(*[swarm.close() for swarm in swarms])
assert events_0_0 == expected_events
assert events_1_0 == expected_events
assert events_0_without_listen == expected_events_without_listen

View File

@ -1,89 +1,84 @@
import asyncio
from multiaddr import Multiaddr
import pytest
import trio
from trio.testing import wait_all_tasks_blocked
from libp2p.network.exceptions import SwarmException
from libp2p.tools.factories import SwarmFactory
from libp2p.tools.utils import connect_swarm
@pytest.mark.asyncio
@pytest.mark.trio
async def test_swarm_dial_peer(is_host_secure):
swarms = await SwarmFactory.create_batch_and_listen(is_host_secure, 3)
# Test: No addr found.
with pytest.raises(SwarmException):
async with SwarmFactory.create_batch_and_listen(is_host_secure, 3) as swarms:
# Test: No addr found.
with pytest.raises(SwarmException):
await swarms[0].dial_peer(swarms[1].get_peer_id())
# Test: len(addr) in the peerstore is 0.
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), [], 10000)
with pytest.raises(SwarmException):
await swarms[0].dial_peer(swarms[1].get_peer_id())
# Test: Succeed if addrs of the peer_id are present in the peerstore.
addrs = tuple(
addr
for transport in swarms[1].listeners.values()
for addr in transport.get_addrs()
)
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs, 10000)
await swarms[0].dial_peer(swarms[1].get_peer_id())
assert swarms[0].get_peer_id() in swarms[1].connections
assert swarms[1].get_peer_id() in swarms[0].connections
# Test: len(addr) in the peerstore is 0.
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), [], 10000)
with pytest.raises(SwarmException):
await swarms[0].dial_peer(swarms[1].get_peer_id())
# Test: Succeed if addrs of the peer_id are present in the peerstore.
addrs = tuple(
addr
for transport in swarms[1].listeners.values()
for addr in transport.get_addrs()
)
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs, 10000)
await swarms[0].dial_peer(swarms[1].get_peer_id())
assert swarms[0].get_peer_id() in swarms[1].connections
assert swarms[1].get_peer_id() in swarms[0].connections
# Test: Reuse connections when we already have ones with a peer.
conn_to_1 = swarms[0].connections[swarms[1].get_peer_id()]
conn = await swarms[0].dial_peer(swarms[1].get_peer_id())
assert conn is conn_to_1
# Clean up
await asyncio.gather(*[swarm.close() for swarm in swarms])
# Test: Reuse connections when we already have ones with a peer.
conn_to_1 = swarms[0].connections[swarms[1].get_peer_id()]
conn = await swarms[0].dial_peer(swarms[1].get_peer_id())
assert conn is conn_to_1
@pytest.mark.asyncio
@pytest.mark.trio
async def test_swarm_close_peer(is_host_secure):
swarms = await SwarmFactory.create_batch_and_listen(is_host_secure, 3)
# 0 <> 1 <> 2
await connect_swarm(swarms[0], swarms[1])
await connect_swarm(swarms[1], swarms[2])
async with SwarmFactory.create_batch_and_listen(is_host_secure, 3) as swarms:
# 0 <> 1 <> 2
await connect_swarm(swarms[0], swarms[1])
await connect_swarm(swarms[1], swarms[2])
# peer 1 closes peer 0
await swarms[1].close_peer(swarms[0].get_peer_id())
await asyncio.sleep(0.01)
# 0 1 <> 2
assert len(swarms[0].connections) == 0
assert (
len(swarms[1].connections) == 1
and swarms[2].get_peer_id() in swarms[1].connections
)
# peer 1 closes peer 0
await swarms[1].close_peer(swarms[0].get_peer_id())
await trio.sleep(0.01)
await wait_all_tasks_blocked()
# 0 1 <> 2
assert len(swarms[0].connections) == 0
assert (
len(swarms[1].connections) == 1
and swarms[2].get_peer_id() in swarms[1].connections
)
# peer 1 is closed by peer 2
await swarms[2].close_peer(swarms[1].get_peer_id())
await asyncio.sleep(0.01)
# 0 1 2
assert len(swarms[1].connections) == 0 and len(swarms[2].connections) == 0
# peer 1 is closed by peer 2
await swarms[2].close_peer(swarms[1].get_peer_id())
await trio.sleep(0.01)
# 0 1 2
assert len(swarms[1].connections) == 0 and len(swarms[2].connections) == 0
await connect_swarm(swarms[0], swarms[1])
# 0 <> 1 2
assert (
len(swarms[0].connections) == 1
and swarms[1].get_peer_id() in swarms[0].connections
)
assert (
len(swarms[1].connections) == 1
and swarms[0].get_peer_id() in swarms[1].connections
)
# peer 0 closes peer 1
await swarms[0].close_peer(swarms[1].get_peer_id())
await asyncio.sleep(0.01)
# 0 1 2
assert len(swarms[1].connections) == 0 and len(swarms[2].connections) == 0
# Clean up
await asyncio.gather(*[swarm.close() for swarm in swarms])
await connect_swarm(swarms[0], swarms[1])
# 0 <> 1 2
assert (
len(swarms[0].connections) == 1
and swarms[1].get_peer_id() in swarms[0].connections
)
assert (
len(swarms[1].connections) == 1
and swarms[0].get_peer_id() in swarms[1].connections
)
# peer 0 closes peer 1
await swarms[0].close_peer(swarms[1].get_peer_id())
await trio.sleep(0.01)
# 0 1 2
assert len(swarms[1].connections) == 0 and len(swarms[2].connections) == 0
@pytest.mark.asyncio
@pytest.mark.trio
async def test_swarm_remove_conn(swarm_pair):
swarm_0, swarm_1 = swarm_pair
conn_0 = swarm_0.connections[swarm_1.get_peer_id()]
@ -94,57 +89,54 @@ async def test_swarm_remove_conn(swarm_pair):
assert swarm_1.get_peer_id() not in swarm_0.connections
@pytest.mark.asyncio
@pytest.mark.trio
async def test_swarm_multiaddr(is_host_secure):
swarms = await SwarmFactory.create_batch_and_listen(is_host_secure, 3)
async with SwarmFactory.create_batch_and_listen(is_host_secure, 3) as swarms:
def clear():
swarms[0].peerstore.clear_addrs(swarms[1].get_peer_id())
def clear():
swarms[0].peerstore.clear_addrs(swarms[1].get_peer_id())
clear()
# No addresses
with pytest.raises(SwarmException):
clear()
# No addresses
with pytest.raises(SwarmException):
await swarms[0].dial_peer(swarms[1].get_peer_id())
clear()
# Wrong addresses
swarms[0].peerstore.add_addrs(
swarms[1].get_peer_id(), [Multiaddr("/ip4/0.0.0.0/tcp/9999")], 10000
)
with pytest.raises(SwarmException):
await swarms[0].dial_peer(swarms[1].get_peer_id())
clear()
# Multiple wrong addresses
swarms[0].peerstore.add_addrs(
swarms[1].get_peer_id(),
[Multiaddr("/ip4/0.0.0.0/tcp/9999"), Multiaddr("/ip4/0.0.0.0/tcp/9998")],
10000,
)
with pytest.raises(SwarmException):
await swarms[0].dial_peer(swarms[1].get_peer_id())
# Test one address
addrs = tuple(
addr
for transport in swarms[1].listeners.values()
for addr in transport.get_addrs()
)
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs[:1], 10000)
await swarms[0].dial_peer(swarms[1].get_peer_id())
clear()
# Wrong addresses
swarms[0].peerstore.add_addrs(
swarms[1].get_peer_id(), [Multiaddr("/ip4/0.0.0.0/tcp/9999")], 10000
)
# Test multiple addresses
addrs = tuple(
addr
for transport in swarms[1].listeners.values()
for addr in transport.get_addrs()
)
with pytest.raises(SwarmException):
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs + addrs, 10000)
await swarms[0].dial_peer(swarms[1].get_peer_id())
clear()
# Multiple wrong addresses
swarms[0].peerstore.add_addrs(
swarms[1].get_peer_id(),
[Multiaddr("/ip4/0.0.0.0/tcp/9999"), Multiaddr("/ip4/0.0.0.0/tcp/9998")],
10000,
)
with pytest.raises(SwarmException):
await swarms[0].dial_peer(swarms[1].get_peer_id())
# Test one address
addrs = tuple(
addr
for transport in swarms[1].listeners.values()
for addr in transport.get_addrs()
)
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs[:1], 10000)
await swarms[0].dial_peer(swarms[1].get_peer_id())
# Test multiple addresses
addrs = tuple(
addr
for transport in swarms[1].listeners.values()
for addr in transport.get_addrs()
)
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs + addrs, 10000)
await swarms[0].dial_peer(swarms[1].get_peer_id())
for swarm in swarms:
await swarm.close()

View File

@ -1,45 +1,46 @@
import asyncio
import pytest
import trio
from trio.testing import wait_all_tasks_blocked
@pytest.mark.asyncio
@pytest.mark.trio
async def test_swarm_conn_close(swarm_conn_pair):
conn_0, conn_1 = swarm_conn_pair
assert not conn_0.event_closed.is_set()
assert not conn_1.event_closed.is_set()
assert not conn_0.is_closed
assert not conn_1.is_closed
await conn_0.close()
await asyncio.sleep(0.01)
await trio.sleep(0.1)
await wait_all_tasks_blocked()
assert conn_0.event_closed.is_set()
assert conn_1.event_closed.is_set()
assert conn_0.is_closed
assert conn_1.is_closed
assert conn_0 not in conn_0.swarm.connections.values()
assert conn_1 not in conn_1.swarm.connections.values()
@pytest.mark.asyncio
@pytest.mark.trio
async def test_swarm_conn_streams(swarm_conn_pair):
conn_0, conn_1 = swarm_conn_pair
assert len(await conn_0.get_streams()) == 0
assert len(await conn_1.get_streams()) == 0
assert len(conn_0.get_streams()) == 0
assert len(conn_1.get_streams()) == 0
stream_0_0 = await conn_0.new_stream()
await asyncio.sleep(0.01)
assert len(await conn_0.get_streams()) == 1
assert len(await conn_1.get_streams()) == 1
await trio.sleep(0.01)
assert len(conn_0.get_streams()) == 1
assert len(conn_1.get_streams()) == 1
stream_0_1 = await conn_0.new_stream()
await asyncio.sleep(0.01)
assert len(await conn_0.get_streams()) == 2
assert len(await conn_1.get_streams()) == 2
await trio.sleep(0.01)
assert len(conn_0.get_streams()) == 2
assert len(conn_1.get_streams()) == 2
conn_0.remove_stream(stream_0_0)
assert len(await conn_0.get_streams()) == 1
assert len(conn_0.get_streams()) == 1
conn_0.remove_stream(stream_0_1)
assert len(await conn_0.get_streams()) == 0
assert len(conn_0.get_streams()) == 0
# Nothing happen if `stream_0_1` is not present or already removed.
conn_0.remove_stream(stream_0_1)

View File

@ -25,8 +25,6 @@ def test_init_():
@pytest.mark.parametrize(
"addr",
(
pytest.param(None),
pytest.param(random.randint(0, 255), id="random integer"),
pytest.param(multiaddr.Multiaddr("/"), id="empty multiaddr"),
pytest.param(
multiaddr.Multiaddr("/ip4/127.0.0.1"),

View File

@ -1,83 +1,94 @@
import pytest
from libp2p.host.exceptions import StreamFailure
from libp2p.tools.utils import echo_stream_handler, set_up_nodes_by_transport_opt
from libp2p.tools.factories import HostFactory
from libp2p.tools.utils import create_echo_stream_handler
# TODO: Add tests for multiple streams being opened on different
# protocols through the same connection
PROTOCOL_ECHO = "/echo/1.0.0"
PROTOCOL_POTATO = "/potato/1.0.0"
PROTOCOL_FOO = "/foo/1.0.0"
PROTOCOL_ROCK = "/rock/1.0.0"
# Note: async issues occurred when using the same port
# so that's why I use different ports here.
# TODO: modify tests so that those async issues don't occur
# when using the same ports across tests
ACK_PREFIX = "ack:"
async def perform_simple_test(
expected_selected_protocol, protocols_for_client, protocols_with_handlers
expected_selected_protocol,
protocols_for_client,
protocols_with_handlers,
is_host_secure,
):
transport_opt_list = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]]
(node_a, node_b) = await set_up_nodes_by_transport_opt(transport_opt_list)
async with HostFactory.create_batch_and_listen(is_host_secure, 2) as hosts:
for protocol in protocols_with_handlers:
hosts[1].set_stream_handler(
protocol, create_echo_stream_handler(ACK_PREFIX)
)
for protocol in protocols_with_handlers:
node_b.set_stream_handler(protocol, echo_stream_handler)
# Associate the peer with local ip address (see default parameters of Libp2p())
hosts[0].get_peerstore().add_addrs(hosts[1].get_id(), hosts[1].get_addrs(), 10)
stream = await hosts[0].new_stream(hosts[1].get_id(), protocols_for_client)
messages = ["hello" + str(x) for x in range(10)]
for message in messages:
expected_resp = "ack:" + message
await stream.write(message.encode())
response = (await stream.read(len(expected_resp))).decode()
assert response == expected_resp
# Associate the peer with local ip address (see default parameters of Libp2p())
node_a.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10)
stream = await node_a.new_stream(node_b.get_id(), protocols_for_client)
messages = ["hello" + str(x) for x in range(10)]
for message in messages:
expected_resp = "ack:" + message
await stream.write(message.encode())
response = (await stream.read(len(expected_resp))).decode()
assert response == expected_resp
assert expected_selected_protocol == stream.get_protocol()
# Success, terminate pending tasks.
assert expected_selected_protocol == stream.get_protocol()
@pytest.mark.asyncio
async def test_single_protocol_succeeds():
expected_selected_protocol = "/echo/1.0.0"
@pytest.mark.trio
async def test_single_protocol_succeeds(is_host_secure):
expected_selected_protocol = PROTOCOL_ECHO
await perform_simple_test(
expected_selected_protocol, ["/echo/1.0.0"], ["/echo/1.0.0"]
expected_selected_protocol,
[expected_selected_protocol],
[expected_selected_protocol],
is_host_secure,
)
@pytest.mark.asyncio
async def test_single_protocol_fails():
@pytest.mark.trio
async def test_single_protocol_fails(is_host_secure):
with pytest.raises(StreamFailure):
await perform_simple_test("", ["/echo/1.0.0"], ["/potato/1.0.0"])
await perform_simple_test(
"", [PROTOCOL_ECHO], [PROTOCOL_POTATO], is_host_secure
)
# Cleanup not reached on error
@pytest.mark.asyncio
async def test_multiple_protocol_first_is_valid_succeeds():
expected_selected_protocol = "/echo/1.0.0"
protocols_for_client = ["/echo/1.0.0", "/potato/1.0.0"]
protocols_for_listener = ["/foo/1.0.0", "/echo/1.0.0"]
@pytest.mark.trio
async def test_multiple_protocol_first_is_valid_succeeds(is_host_secure):
expected_selected_protocol = PROTOCOL_ECHO
protocols_for_client = [PROTOCOL_ECHO, PROTOCOL_POTATO]
protocols_for_listener = [PROTOCOL_FOO, PROTOCOL_ECHO]
await perform_simple_test(
expected_selected_protocol, protocols_for_client, protocols_for_listener
expected_selected_protocol,
protocols_for_client,
protocols_for_listener,
is_host_secure,
)
@pytest.mark.asyncio
async def test_multiple_protocol_second_is_valid_succeeds():
expected_selected_protocol = "/foo/1.0.0"
protocols_for_client = ["/rock/1.0.0", "/foo/1.0.0"]
protocols_for_listener = ["/foo/1.0.0", "/echo/1.0.0"]
@pytest.mark.trio
async def test_multiple_protocol_second_is_valid_succeeds(is_host_secure):
expected_selected_protocol = PROTOCOL_FOO
protocols_for_client = [PROTOCOL_ROCK, PROTOCOL_FOO]
protocols_for_listener = [PROTOCOL_FOO, PROTOCOL_ECHO]
await perform_simple_test(
expected_selected_protocol, protocols_for_client, protocols_for_listener
expected_selected_protocol,
protocols_for_client,
protocols_for_listener,
is_host_secure,
)
@pytest.mark.asyncio
async def test_multiple_protocol_fails():
protocols_for_client = ["/rock/1.0.0", "/foo/1.0.0", "/bar/1.0.0"]
@pytest.mark.trio
async def test_multiple_protocol_fails(is_host_secure):
protocols_for_client = [PROTOCOL_ROCK, PROTOCOL_FOO, "/bar/1.0.0"]
protocols_for_listener = ["/aspyn/1.0.0", "/rob/1.0.0", "/zx/1.0.0", "/alex/1.0.0"]
with pytest.raises(StreamFailure):
await perform_simple_test("", protocols_for_client, protocols_for_listener)
# Cleanup not reached on error
await perform_simple_test(
"", protocols_for_client, protocols_for_listener, is_host_secure
)

View File

@ -1,58 +0,0 @@
import pytest
from libp2p.tools.constants import GOSSIPSUB_PARAMS
from libp2p.tools.factories import FloodsubFactory, GossipsubFactory, PubsubFactory
@pytest.fixture
def is_strict_signing():
return False
def _make_pubsubs(hosts, pubsub_routers, cache_size, is_strict_signing):
if len(pubsub_routers) != len(hosts):
raise ValueError(
f"lenght of pubsub_routers={pubsub_routers} should be equaled to the "
f"length of hosts={len(hosts)}"
)
return tuple(
PubsubFactory(
host=host,
router=router,
cache_size=cache_size,
strict_signing=is_strict_signing,
)
for host, router in zip(hosts, pubsub_routers)
)
@pytest.fixture
def pubsub_cache_size():
return None # default
@pytest.fixture
def gossipsub_params():
return GOSSIPSUB_PARAMS
@pytest.fixture
def pubsubs_fsub(num_hosts, hosts, pubsub_cache_size, is_strict_signing):
floodsubs = FloodsubFactory.create_batch(num_hosts)
_pubsubs_fsub = _make_pubsubs(
hosts, floodsubs, pubsub_cache_size, is_strict_signing
)
yield _pubsubs_fsub
# TODO: Clean up
@pytest.fixture
def pubsubs_gsub(
num_hosts, hosts, pubsub_cache_size, gossipsub_params, is_strict_signing
):
gossipsubs = GossipsubFactory.create_batch(num_hosts, **gossipsub_params._asdict())
_pubsubs_gsub = _make_pubsubs(
hosts, gossipsubs, pubsub_cache_size, is_strict_signing
)
yield _pubsubs_gsub
# TODO: Clean up

View File

@ -1,19 +1,10 @@
import asyncio
from threading import Thread
import pytest
import trio
from libp2p.tools.pubsub.dummy_account_node import DummyAccountNode
from libp2p.tools.utils import connect
def create_setup_in_new_thread_func(dummy_node):
def setup_in_new_thread():
asyncio.ensure_future(dummy_node.setup_crypto_networking())
return setup_in_new_thread
async def perform_test(num_nodes, adjacency_map, action_func, assertion_func):
"""
Helper function to allow for easy construction of custom tests for dummy
@ -26,47 +17,35 @@ async def perform_test(num_nodes, adjacency_map, action_func, assertion_func):
:param assertion_func: assertions for testing the results of the actions are correct
"""
# Create nodes
dummy_nodes = []
for _ in range(num_nodes):
dummy_nodes.append(await DummyAccountNode.create())
async with DummyAccountNode.create(num_nodes) as dummy_nodes:
# Create connections between nodes according to `adjacency_map`
async with trio.open_nursery() as nursery:
for source_num in adjacency_map:
target_nums = adjacency_map[source_num]
for target_num in target_nums:
nursery.start_soon(
connect,
dummy_nodes[source_num].host,
dummy_nodes[target_num].host,
)
# Create network
for source_num in adjacency_map:
target_nums = adjacency_map[source_num]
for target_num in target_nums:
await connect(
dummy_nodes[source_num].libp2p_node, dummy_nodes[target_num].libp2p_node
)
# Allow time for network creation to take place
await trio.sleep(0.25)
# Allow time for network creation to take place
await asyncio.sleep(0.25)
# Perform action function
await action_func(dummy_nodes)
# Start a thread for each node so that each node can listen and respond
# to messages on its own thread, which will avoid waiting indefinitely
# on the main thread. On this thread, call the setup func for the node,
# which subscribes the node to the CRYPTO_TOPIC topic
for dummy_node in dummy_nodes:
thread = Thread(target=create_setup_in_new_thread_func(dummy_node))
thread.run()
# Allow time for action function to be performed (i.e. messages to propogate)
await trio.sleep(1)
# Allow time for nodes to subscribe to CRYPTO_TOPIC topic
await asyncio.sleep(0.25)
# Perform action function
await action_func(dummy_nodes)
# Allow time for action function to be performed (i.e. messages to propogate)
await asyncio.sleep(1)
# Perform assertion function
for dummy_node in dummy_nodes:
assertion_func(dummy_node)
# Perform assertion function
for dummy_node in dummy_nodes:
assertion_func(dummy_node)
# Success, terminate pending tasks.
@pytest.mark.asyncio
@pytest.mark.trio
async def test_simple_two_nodes():
num_nodes = 2
adj_map = {0: [1]}
@ -80,7 +59,7 @@ async def test_simple_two_nodes():
await perform_test(num_nodes, adj_map, action_func, assertion_func)
@pytest.mark.asyncio
@pytest.mark.trio
async def test_simple_three_nodes_line_topography():
num_nodes = 3
adj_map = {0: [1], 1: [2]}
@ -94,7 +73,7 @@ async def test_simple_three_nodes_line_topography():
await perform_test(num_nodes, adj_map, action_func, assertion_func)
@pytest.mark.asyncio
@pytest.mark.trio
async def test_simple_three_nodes_triangle_topography():
num_nodes = 3
adj_map = {0: [1, 2], 1: [2]}
@ -108,7 +87,7 @@ async def test_simple_three_nodes_triangle_topography():
await perform_test(num_nodes, adj_map, action_func, assertion_func)
@pytest.mark.asyncio
@pytest.mark.trio
async def test_simple_seven_nodes_tree_topography():
num_nodes = 7
adj_map = {0: [1, 2], 1: [3, 4], 2: [5, 6]}
@ -122,14 +101,14 @@ async def test_simple_seven_nodes_tree_topography():
await perform_test(num_nodes, adj_map, action_func, assertion_func)
@pytest.mark.asyncio
@pytest.mark.trio
async def test_set_then_send_from_root_seven_nodes_tree_topography():
num_nodes = 7
adj_map = {0: [1, 2], 1: [3, 4], 2: [5, 6]}
async def action_func(dummy_nodes):
await dummy_nodes[0].publish_set_crypto("aspyn", 20)
await asyncio.sleep(0.25)
await trio.sleep(0.25)
await dummy_nodes[0].publish_send_crypto("aspyn", "alex", 5)
def assertion_func(dummy_node):
@ -139,14 +118,14 @@ async def test_set_then_send_from_root_seven_nodes_tree_topography():
await perform_test(num_nodes, adj_map, action_func, assertion_func)
@pytest.mark.asyncio
@pytest.mark.trio
async def test_set_then_send_from_different_leafs_seven_nodes_tree_topography():
num_nodes = 7
adj_map = {0: [1, 2], 1: [3, 4], 2: [5, 6]}
async def action_func(dummy_nodes):
await dummy_nodes[6].publish_set_crypto("aspyn", 20)
await asyncio.sleep(0.25)
await trio.sleep(0.25)
await dummy_nodes[4].publish_send_crypto("aspyn", "alex", 5)
def assertion_func(dummy_node):
@ -156,7 +135,7 @@ async def test_set_then_send_from_different_leafs_seven_nodes_tree_topography():
await perform_test(num_nodes, adj_map, action_func, assertion_func)
@pytest.mark.asyncio
@pytest.mark.trio
async def test_simple_five_nodes_ring_topography():
num_nodes = 5
adj_map = {0: [1], 1: [2], 2: [3], 3: [4], 4: [0]}
@ -170,14 +149,14 @@ async def test_simple_five_nodes_ring_topography():
await perform_test(num_nodes, adj_map, action_func, assertion_func)
@pytest.mark.asyncio
@pytest.mark.trio
async def test_set_then_send_from_diff_nodes_five_nodes_ring_topography():
num_nodes = 5
adj_map = {0: [1], 1: [2], 2: [3], 3: [4], 4: [0]}
async def action_func(dummy_nodes):
await dummy_nodes[0].publish_set_crypto("alex", 20)
await asyncio.sleep(0.25)
await trio.sleep(0.25)
await dummy_nodes[3].publish_send_crypto("alex", "rob", 12)
def assertion_func(dummy_node):
@ -187,7 +166,7 @@ async def test_set_then_send_from_diff_nodes_five_nodes_ring_topography():
await perform_test(num_nodes, adj_map, action_func, assertion_func)
@pytest.mark.asyncio
@pytest.mark.trio
@pytest.mark.slow
async def test_set_then_send_from_five_diff_nodes_five_nodes_ring_topography():
num_nodes = 5
@ -195,13 +174,13 @@ async def test_set_then_send_from_five_diff_nodes_five_nodes_ring_topography():
async def action_func(dummy_nodes):
await dummy_nodes[0].publish_set_crypto("alex", 20)
await asyncio.sleep(1)
await trio.sleep(1)
await dummy_nodes[1].publish_send_crypto("alex", "rob", 3)
await asyncio.sleep(1)
await trio.sleep(1)
await dummy_nodes[2].publish_send_crypto("rob", "aspyn", 2)
await asyncio.sleep(1)
await trio.sleep(1)
await dummy_nodes[3].publish_send_crypto("aspyn", "zx", 1)
await asyncio.sleep(1)
await trio.sleep(1)
await dummy_nodes[4].publish_send_crypto("zx", "raul", 1)
def assertion_func(dummy_node):

View File

@ -1,9 +1,10 @@
import asyncio
import functools
import pytest
import trio
from libp2p.peer.id import ID
from libp2p.tools.factories import FloodsubFactory
from libp2p.tools.factories import PubsubFactory
from libp2p.tools.pubsub.floodsub_integration_test_settings import (
floodsub_protocol_pytest_params,
perform_test_from_obj,
@ -11,79 +12,80 @@ from libp2p.tools.pubsub.floodsub_integration_test_settings import (
from libp2p.tools.utils import connect
@pytest.mark.parametrize("num_hosts", (2,))
@pytest.mark.asyncio
async def test_simple_two_nodes(pubsubs_fsub):
topic = "my_topic"
data = b"some data"
@pytest.mark.trio
async def test_simple_two_nodes():
async with PubsubFactory.create_batch_with_floodsub(2) as pubsubs_fsub:
topic = "my_topic"
data = b"some data"
await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host)
await asyncio.sleep(0.25)
await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host)
await trio.sleep(0.25)
sub_b = await pubsubs_fsub[1].subscribe(topic)
# Sleep to let a know of b's subscription
await asyncio.sleep(0.25)
sub_b = await pubsubs_fsub[1].subscribe(topic)
# Sleep to let a know of b's subscription
await trio.sleep(0.25)
await pubsubs_fsub[0].publish(topic, data)
await pubsubs_fsub[0].publish(topic, data)
res_b = await sub_b.get()
# Check that the msg received by node_b is the same
# as the message sent by node_a
assert ID(res_b.from_id) == pubsubs_fsub[0].host.get_id()
assert res_b.data == data
assert res_b.topicIDs == [topic]
# Success, terminate pending tasks.
# Initialize Pubsub with a cache_size of 4
@pytest.mark.parametrize("num_hosts, pubsub_cache_size", ((2, 4),))
@pytest.mark.asyncio
async def test_lru_cache_two_nodes(pubsubs_fsub, monkeypatch):
# two nodes with cache_size of 4
# `node_a` send the following messages to node_b
message_indices = [1, 1, 2, 1, 3, 1, 4, 1, 5, 1]
# `node_b` should only receive the following
expected_received_indices = [1, 2, 3, 4, 5, 1]
topic = "my_topic"
# Mock `get_msg_id` to make us easier to manipulate `msg_id` by `data`.
def get_msg_id(msg):
# Originally it is `(msg.seqno, msg.from_id)`
return (msg.data, msg.from_id)
import libp2p.pubsub.pubsub
monkeypatch.setattr(libp2p.pubsub.pubsub, "get_msg_id", get_msg_id)
await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host)
await asyncio.sleep(0.25)
sub_b = await pubsubs_fsub[1].subscribe(topic)
await asyncio.sleep(0.25)
def _make_testing_data(i: int) -> bytes:
num_int_bytes = 4
if i >= 2 ** (num_int_bytes * 8):
raise ValueError("integer is too large to be serialized")
return b"data" + i.to_bytes(num_int_bytes, "big")
for index in message_indices:
await pubsubs_fsub[0].publish(topic, _make_testing_data(index))
await asyncio.sleep(0.25)
for index in expected_received_indices:
res_b = await sub_b.get()
assert res_b.data == _make_testing_data(index)
assert sub_b.empty()
# Success, terminate pending tasks.
# Check that the msg received by node_b is the same
# as the message sent by node_a
assert ID(res_b.from_id) == pubsubs_fsub[0].host.get_id()
assert res_b.data == data
assert res_b.topicIDs == [topic]
@pytest.mark.trio
async def test_lru_cache_two_nodes(monkeypatch):
# two nodes with cache_size of 4
async with PubsubFactory.create_batch_with_floodsub(
2, cache_size=4
) as pubsubs_fsub:
# `node_a` send the following messages to node_b
message_indices = [1, 1, 2, 1, 3, 1, 4, 1, 5, 1]
# `node_b` should only receive the following
expected_received_indices = [1, 2, 3, 4, 5, 1]
topic = "my_topic"
# Mock `get_msg_id` to make us easier to manipulate `msg_id` by `data`.
def get_msg_id(msg):
# Originally it is `(msg.seqno, msg.from_id)`
return (msg.data, msg.from_id)
import libp2p.pubsub.pubsub
monkeypatch.setattr(libp2p.pubsub.pubsub, "get_msg_id", get_msg_id)
await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host)
await trio.sleep(0.25)
sub_b = await pubsubs_fsub[1].subscribe(topic)
await trio.sleep(0.25)
def _make_testing_data(i: int) -> bytes:
num_int_bytes = 4
if i >= 2 ** (num_int_bytes * 8):
raise ValueError("integer is too large to be serialized")
return b"data" + i.to_bytes(num_int_bytes, "big")
for index in message_indices:
await pubsubs_fsub[0].publish(topic, _make_testing_data(index))
await trio.sleep(0.25)
for index in expected_received_indices:
res_b = await sub_b.get()
assert res_b.data == _make_testing_data(index)
@pytest.mark.parametrize("test_case_obj", floodsub_protocol_pytest_params)
@pytest.mark.asyncio
@pytest.mark.trio
@pytest.mark.slow
async def test_gossipsub_run_with_floodsub_tests(test_case_obj):
await perform_test_from_obj(test_case_obj, FloodsubFactory)
async def test_gossipsub_run_with_floodsub_tests(test_case_obj, is_host_secure):
await perform_test_from_obj(
test_case_obj,
functools.partial(
PubsubFactory.create_batch_with_floodsub, is_secure=is_host_secure
),
)

View File

@ -1,495 +1,478 @@
import asyncio
import random
import pytest
import trio
from libp2p.peer.id import ID
from libp2p.pubsub.gossipsub import PROTOCOL_ID
from libp2p.tools.constants import GOSSIPSUB_PARAMS, GossipsubParams
from libp2p.tools.factories import IDFactory, PubsubFactory
from libp2p.tools.pubsub.utils import dense_connect, one_to_all_connect
from libp2p.tools.utils import connect
@pytest.mark.parametrize(
"num_hosts, gossipsub_params",
((4, GossipsubParams(degree=4, degree_low=3, degree_high=5)),),
)
@pytest.mark.asyncio
async def test_join(num_hosts, hosts, pubsubs_gsub):
gossipsubs = tuple(pubsub.router for pubsub in pubsubs_gsub)
hosts_indices = list(range(num_hosts))
@pytest.mark.trio
async def test_join():
async with PubsubFactory.create_batch_with_gossipsub(
4, degree=4, degree_low=3, degree_high=5
) as pubsubs_gsub:
gossipsubs = [pubsub.router for pubsub in pubsubs_gsub]
hosts = [pubsub.host for pubsub in pubsubs_gsub]
hosts_indices = list(range(len(pubsubs_gsub)))
topic = "test_join"
central_node_index = 0
# Remove index of central host from the indices
hosts_indices.remove(central_node_index)
num_subscribed_peer = 2
subscribed_peer_indices = random.sample(hosts_indices, num_subscribed_peer)
topic = "test_join"
central_node_index = 0
# Remove index of central host from the indices
hosts_indices.remove(central_node_index)
num_subscribed_peer = 2
subscribed_peer_indices = random.sample(hosts_indices, num_subscribed_peer)
# All pubsub except the one of central node subscribe to topic
for i in subscribed_peer_indices:
await pubsubs_gsub[i].subscribe(topic)
# All pubsub except the one of central node subscribe to topic
for i in subscribed_peer_indices:
await pubsubs_gsub[i].subscribe(topic)
# Connect central host to all other hosts
await one_to_all_connect(hosts, central_node_index)
# Connect central host to all other hosts
await one_to_all_connect(hosts, central_node_index)
# Wait 2 seconds for heartbeat to allow mesh to connect
await asyncio.sleep(2)
# Central node publish to the topic so that this topic
# is added to central node's fanout
# publish from the randomly chosen host
await pubsubs_gsub[central_node_index].publish(topic, b"data")
# Check that the gossipsub of central node has fanout for the topic
assert topic in gossipsubs[central_node_index].fanout
# Check that the gossipsub of central node does not have a mesh for the topic
assert topic not in gossipsubs[central_node_index].mesh
# Central node subscribes the topic
await pubsubs_gsub[central_node_index].subscribe(topic)
await asyncio.sleep(2)
# Check that the gossipsub of central node no longer has fanout for the topic
assert topic not in gossipsubs[central_node_index].fanout
for i in hosts_indices:
if i in subscribed_peer_indices:
assert hosts[i].get_id() in gossipsubs[central_node_index].mesh[topic]
assert hosts[central_node_index].get_id() in gossipsubs[i].mesh[topic]
else:
assert hosts[i].get_id() not in gossipsubs[central_node_index].mesh[topic]
assert topic not in gossipsubs[i].mesh
@pytest.mark.parametrize("num_hosts", (1,))
@pytest.mark.asyncio
async def test_leave(pubsubs_gsub):
gossipsub = pubsubs_gsub[0].router
topic = "test_leave"
assert topic not in gossipsub.mesh
await gossipsub.join(topic)
assert topic in gossipsub.mesh
await gossipsub.leave(topic)
assert topic not in gossipsub.mesh
# Test re-leave
await gossipsub.leave(topic)
@pytest.mark.parametrize("num_hosts", (2,))
@pytest.mark.asyncio
async def test_handle_graft(pubsubs_gsub, hosts, event_loop, monkeypatch):
gossipsubs = tuple(pubsub.router for pubsub in pubsubs_gsub)
index_alice = 0
id_alice = hosts[index_alice].get_id()
index_bob = 1
id_bob = hosts[index_bob].get_id()
await connect(hosts[index_alice], hosts[index_bob])
# Wait 2 seconds for heartbeat to allow mesh to connect
await asyncio.sleep(2)
topic = "test_handle_graft"
# Only lice subscribe to the topic
await gossipsubs[index_alice].join(topic)
# Monkey patch bob's `emit_prune` function so we can
# check if it is called in `handle_graft`
event_emit_prune = asyncio.Event()
async def emit_prune(topic, sender_peer_id):
event_emit_prune.set()
monkeypatch.setattr(gossipsubs[index_bob], "emit_prune", emit_prune)
# Check that alice is bob's peer but not his mesh peer
assert gossipsubs[index_bob].peer_protocol[id_alice] == PROTOCOL_ID
assert topic not in gossipsubs[index_bob].mesh
await gossipsubs[index_alice].emit_graft(topic, id_bob)
# Check that `emit_prune` is called
await asyncio.wait_for(event_emit_prune.wait(), timeout=1, loop=event_loop)
assert event_emit_prune.is_set()
# Check that bob is alice's peer but not her mesh peer
assert topic in gossipsubs[index_alice].mesh
assert id_bob not in gossipsubs[index_alice].mesh[topic]
assert gossipsubs[index_alice].peer_protocol[id_bob] == PROTOCOL_ID
await gossipsubs[index_bob].emit_graft(topic, id_alice)
await asyncio.sleep(1)
# Check that bob is now alice's mesh peer
assert id_bob in gossipsubs[index_alice].mesh[topic]
@pytest.mark.parametrize(
"num_hosts, gossipsub_params", ((2, GossipsubParams(heartbeat_interval=3)),)
)
@pytest.mark.asyncio
async def test_handle_prune(pubsubs_gsub, hosts):
gossipsubs = tuple(pubsub.router for pubsub in pubsubs_gsub)
index_alice = 0
id_alice = hosts[index_alice].get_id()
index_bob = 1
id_bob = hosts[index_bob].get_id()
topic = "test_handle_prune"
for pubsub in pubsubs_gsub:
await pubsub.subscribe(topic)
await connect(hosts[index_alice], hosts[index_bob])
# Wait for heartbeat to allow mesh to connect
await asyncio.sleep(1)
# Check that they are each other's mesh peer
assert id_alice in gossipsubs[index_bob].mesh[topic]
assert id_bob in gossipsubs[index_alice].mesh[topic]
# alice emit prune message to bob, alice should be removed
# from bob's mesh peer
await gossipsubs[index_alice].emit_prune(topic, id_bob)
# `emit_prune` does not remove bob from alice's mesh peers
assert id_bob in gossipsubs[index_alice].mesh[topic]
# NOTE: We increase `heartbeat_interval` to 3 seconds so that bob will not
# add alice back to his mesh after heartbeat.
# Wait for bob to `handle_prune`
await asyncio.sleep(0.1)
# Check that alice is no longer bob's mesh peer
assert id_alice not in gossipsubs[index_bob].mesh[topic]
@pytest.mark.parametrize("num_hosts", (10,))
@pytest.mark.asyncio
async def test_dense(num_hosts, pubsubs_gsub, hosts):
num_msgs = 5
# All pubsub subscribe to foobar
queues = []
for pubsub in pubsubs_gsub:
q = await pubsub.subscribe("foobar")
# Add each blocking queue to an array of blocking queues
queues.append(q)
# Densely connect libp2p hosts in a random way
await dense_connect(hosts)
# Wait 2 seconds for heartbeat to allow mesh to connect
await asyncio.sleep(2)
for i in range(num_msgs):
msg_content = b"foo " + i.to_bytes(1, "big")
# randomly pick a message origin
origin_idx = random.randint(0, num_hosts - 1)
# Wait 2 seconds for heartbeat to allow mesh to connect
await trio.sleep(2)
# Central node publish to the topic so that this topic
# is added to central node's fanout
# publish from the randomly chosen host
await pubsubs_gsub[origin_idx].publish("foobar", msg_content)
await pubsubs_gsub[central_node_index].publish(topic, b"data")
await asyncio.sleep(0.5)
# Assert that all blocking queues receive the message
for queue in queues:
msg = await queue.get()
assert msg.data == msg_content
# Check that the gossipsub of central node has fanout for the topic
assert topic in gossipsubs[central_node_index].fanout
# Check that the gossipsub of central node does not have a mesh for the topic
assert topic not in gossipsubs[central_node_index].mesh
# Central node subscribes the topic
await pubsubs_gsub[central_node_index].subscribe(topic)
await trio.sleep(2)
# Check that the gossipsub of central node no longer has fanout for the topic
assert topic not in gossipsubs[central_node_index].fanout
for i in hosts_indices:
if i in subscribed_peer_indices:
assert hosts[i].get_id() in gossipsubs[central_node_index].mesh[topic]
assert hosts[central_node_index].get_id() in gossipsubs[i].mesh[topic]
else:
assert (
hosts[i].get_id() not in gossipsubs[central_node_index].mesh[topic]
)
assert topic not in gossipsubs[i].mesh
@pytest.mark.parametrize("num_hosts", (10,))
@pytest.mark.asyncio
async def test_fanout(hosts, pubsubs_gsub):
num_msgs = 5
@pytest.mark.trio
async def test_leave():
async with PubsubFactory.create_batch_with_gossipsub(1) as pubsubs_gsub:
gossipsub = pubsubs_gsub[0].router
topic = "test_leave"
# All pubsub subscribe to foobar except for `pubsubs_gsub[0]`
queues = []
for i in range(1, len(pubsubs_gsub)):
q = await pubsubs_gsub[i].subscribe("foobar")
assert topic not in gossipsub.mesh
# Add each blocking queue to an array of blocking queues
queues.append(q)
await gossipsub.join(topic)
assert topic in gossipsub.mesh
# Sparsely connect libp2p hosts in random way
await dense_connect(hosts)
await gossipsub.leave(topic)
assert topic not in gossipsub.mesh
# Wait 2 seconds for heartbeat to allow mesh to connect
await asyncio.sleep(2)
topic = "foobar"
# Send messages with origin not subscribed
for i in range(num_msgs):
msg_content = b"foo " + i.to_bytes(1, "big")
# Pick the message origin to the node that is not subscribed to 'foobar'
origin_idx = 0
# publish from the randomly chosen host
await pubsubs_gsub[origin_idx].publish(topic, msg_content)
await asyncio.sleep(0.5)
# Assert that all blocking queues receive the message
for queue in queues:
msg = await queue.get()
assert msg.data == msg_content
# Subscribe message origin
queues.insert(0, await pubsubs_gsub[0].subscribe(topic))
# Send messages again
for i in range(num_msgs):
msg_content = b"bar " + i.to_bytes(1, "big")
# Pick the message origin to the node that is not subscribed to 'foobar'
origin_idx = 0
# publish from the randomly chosen host
await pubsubs_gsub[origin_idx].publish(topic, msg_content)
await asyncio.sleep(0.5)
# Assert that all blocking queues receive the message
for queue in queues:
msg = await queue.get()
assert msg.data == msg_content
# Test re-leave
await gossipsub.leave(topic)
@pytest.mark.parametrize("num_hosts", (10,))
@pytest.mark.asyncio
@pytest.mark.trio
async def test_handle_graft(monkeypatch):
async with PubsubFactory.create_batch_with_gossipsub(2) as pubsubs_gsub:
gossipsubs = tuple(pubsub.router for pubsub in pubsubs_gsub)
index_alice = 0
id_alice = pubsubs_gsub[index_alice].my_id
index_bob = 1
id_bob = pubsubs_gsub[index_bob].my_id
await connect(pubsubs_gsub[index_alice].host, pubsubs_gsub[index_bob].host)
# Wait 2 seconds for heartbeat to allow mesh to connect
await trio.sleep(2)
topic = "test_handle_graft"
# Only lice subscribe to the topic
await gossipsubs[index_alice].join(topic)
# Monkey patch bob's `emit_prune` function so we can
# check if it is called in `handle_graft`
event_emit_prune = trio.Event()
async def emit_prune(topic, sender_peer_id):
event_emit_prune.set()
await trio.hazmat.checkpoint()
monkeypatch.setattr(gossipsubs[index_bob], "emit_prune", emit_prune)
# Check that alice is bob's peer but not his mesh peer
assert gossipsubs[index_bob].peer_protocol[id_alice] == PROTOCOL_ID
assert topic not in gossipsubs[index_bob].mesh
await gossipsubs[index_alice].emit_graft(topic, id_bob)
# Check that `emit_prune` is called
await event_emit_prune.wait()
# Check that bob is alice's peer but not her mesh peer
assert topic in gossipsubs[index_alice].mesh
assert id_bob not in gossipsubs[index_alice].mesh[topic]
assert gossipsubs[index_alice].peer_protocol[id_bob] == PROTOCOL_ID
await gossipsubs[index_bob].emit_graft(topic, id_alice)
await trio.sleep(1)
# Check that bob is now alice's mesh peer
assert id_bob in gossipsubs[index_alice].mesh[topic]
@pytest.mark.trio
async def test_handle_prune():
async with PubsubFactory.create_batch_with_gossipsub(
2, heartbeat_interval=3
) as pubsubs_gsub:
gossipsubs = tuple(pubsub.router for pubsub in pubsubs_gsub)
index_alice = 0
id_alice = pubsubs_gsub[index_alice].my_id
index_bob = 1
id_bob = pubsubs_gsub[index_bob].my_id
topic = "test_handle_prune"
for pubsub in pubsubs_gsub:
await pubsub.subscribe(topic)
await connect(pubsubs_gsub[index_alice].host, pubsubs_gsub[index_bob].host)
# Wait for heartbeat to allow mesh to connect
await trio.sleep(1)
# Check that they are each other's mesh peer
assert id_alice in gossipsubs[index_bob].mesh[topic]
assert id_bob in gossipsubs[index_alice].mesh[topic]
# alice emit prune message to bob, alice should be removed
# from bob's mesh peer
await gossipsubs[index_alice].emit_prune(topic, id_bob)
# `emit_prune` does not remove bob from alice's mesh peers
assert id_bob in gossipsubs[index_alice].mesh[topic]
# NOTE: We increase `heartbeat_interval` to 3 seconds so that bob will not
# add alice back to his mesh after heartbeat.
# Wait for bob to `handle_prune`
await trio.sleep(0.1)
# Check that alice is no longer bob's mesh peer
assert id_alice not in gossipsubs[index_bob].mesh[topic]
@pytest.mark.trio
async def test_dense():
async with PubsubFactory.create_batch_with_gossipsub(10) as pubsubs_gsub:
hosts = [pubsub.host for pubsub in pubsubs_gsub]
num_msgs = 5
# All pubsub subscribe to foobar
queues = [await pubsub.subscribe("foobar") for pubsub in pubsubs_gsub]
# Densely connect libp2p hosts in a random way
await dense_connect(hosts)
# Wait 2 seconds for heartbeat to allow mesh to connect
await trio.sleep(2)
for i in range(num_msgs):
msg_content = b"foo " + i.to_bytes(1, "big")
# randomly pick a message origin
origin_idx = random.randint(0, len(hosts) - 1)
# publish from the randomly chosen host
await pubsubs_gsub[origin_idx].publish("foobar", msg_content)
await trio.sleep(0.5)
# Assert that all blocking queues receive the message
for queue in queues:
msg = await queue.get()
assert msg.data == msg_content
@pytest.mark.trio
async def test_fanout():
async with PubsubFactory.create_batch_with_gossipsub(10) as pubsubs_gsub:
hosts = [pubsub.host for pubsub in pubsubs_gsub]
num_msgs = 5
# All pubsub subscribe to foobar except for `pubsubs_gsub[0]`
subs = [await pubsub.subscribe("foobar") for pubsub in pubsubs_gsub[1:]]
# Sparsely connect libp2p hosts in random way
await dense_connect(hosts)
# Wait 2 seconds for heartbeat to allow mesh to connect
await trio.sleep(2)
topic = "foobar"
# Send messages with origin not subscribed
for i in range(num_msgs):
msg_content = b"foo " + i.to_bytes(1, "big")
# Pick the message origin to the node that is not subscribed to 'foobar'
origin_idx = 0
# publish from the randomly chosen host
await pubsubs_gsub[origin_idx].publish(topic, msg_content)
await trio.sleep(0.5)
# Assert that all blocking queues receive the message
for sub in subs:
msg = await sub.get()
assert msg.data == msg_content
# Subscribe message origin
subs.insert(0, await pubsubs_gsub[0].subscribe(topic))
# Send messages again
for i in range(num_msgs):
msg_content = b"bar " + i.to_bytes(1, "big")
# Pick the message origin to the node that is not subscribed to 'foobar'
origin_idx = 0
# publish from the randomly chosen host
await pubsubs_gsub[origin_idx].publish(topic, msg_content)
await trio.sleep(0.5)
# Assert that all blocking queues receive the message
for sub in subs:
msg = await sub.get()
assert msg.data == msg_content
@pytest.mark.trio
@pytest.mark.slow
async def test_fanout_maintenance(hosts, pubsubs_gsub):
num_msgs = 5
async def test_fanout_maintenance():
async with PubsubFactory.create_batch_with_gossipsub(10) as pubsubs_gsub:
hosts = [pubsub.host for pubsub in pubsubs_gsub]
num_msgs = 5
# All pubsub subscribe to foobar
queues = []
topic = "foobar"
for i in range(1, len(pubsubs_gsub)):
q = await pubsubs_gsub[i].subscribe(topic)
# All pubsub subscribe to foobar
queues = []
topic = "foobar"
for i in range(1, len(pubsubs_gsub)):
q = await pubsubs_gsub[i].subscribe(topic)
# Add each blocking queue to an array of blocking queues
queues.append(q)
# Add each blocking queue to an array of blocking queues
queues.append(q)
# Sparsely connect libp2p hosts in random way
await dense_connect(hosts)
# Sparsely connect libp2p hosts in random way
await dense_connect(hosts)
# Wait 2 seconds for heartbeat to allow mesh to connect
await asyncio.sleep(2)
# Wait 2 seconds for heartbeat to allow mesh to connect
await trio.sleep(2)
# Send messages with origin not subscribed
for i in range(num_msgs):
msg_content = b"foo " + i.to_bytes(1, "big")
# Send messages with origin not subscribed
for i in range(num_msgs):
msg_content = b"foo " + i.to_bytes(1, "big")
# Pick the message origin to the node that is not subscribed to 'foobar'
origin_idx = 0
# Pick the message origin to the node that is not subscribed to 'foobar'
origin_idx = 0
# publish from the randomly chosen host
await pubsubs_gsub[origin_idx].publish(topic, msg_content)
await trio.sleep(0.5)
# Assert that all blocking queues receive the message
for queue in queues:
msg = await queue.get()
assert msg.data == msg_content
for sub in pubsubs_gsub:
await sub.unsubscribe(topic)
queues = []
await trio.sleep(2)
# Resub and repeat
for i in range(1, len(pubsubs_gsub)):
q = await pubsubs_gsub[i].subscribe(topic)
# Add each blocking queue to an array of blocking queues
queues.append(q)
await trio.sleep(2)
# Check messages can still be sent
for i in range(num_msgs):
msg_content = b"bar " + i.to_bytes(1, "big")
# Pick the message origin to the node that is not subscribed to 'foobar'
origin_idx = 0
# publish from the randomly chosen host
await pubsubs_gsub[origin_idx].publish(topic, msg_content)
await trio.sleep(0.5)
# Assert that all blocking queues receive the message
for queue in queues:
msg = await queue.get()
assert msg.data == msg_content
@pytest.mark.trio
async def test_gossip_propagation():
async with PubsubFactory.create_batch_with_gossipsub(
2, degree=1, degree_low=0, degree_high=2, gossip_window=50, gossip_history=100
) as pubsubs_gsub:
topic = "foo"
queue_0 = await pubsubs_gsub[0].subscribe(topic)
# node 0 publish to topic
msg_content = b"foo_msg"
# publish from the randomly chosen host
await pubsubs_gsub[origin_idx].publish(topic, msg_content)
await pubsubs_gsub[0].publish(topic, msg_content)
await asyncio.sleep(0.5)
# Assert that all blocking queues receive the message
for queue in queues:
msg = await queue.get()
assert msg.data == msg_content
for sub in pubsubs_gsub:
await sub.unsubscribe(topic)
queues = []
await asyncio.sleep(2)
# Resub and repeat
for i in range(1, len(pubsubs_gsub)):
q = await pubsubs_gsub[i].subscribe(topic)
# Add each blocking queue to an array of blocking queues
queues.append(q)
await asyncio.sleep(2)
# Check messages can still be sent
for i in range(num_msgs):
msg_content = b"bar " + i.to_bytes(1, "big")
# Pick the message origin to the node that is not subscribed to 'foobar'
origin_idx = 0
# publish from the randomly chosen host
await pubsubs_gsub[origin_idx].publish(topic, msg_content)
await asyncio.sleep(0.5)
# Assert that all blocking queues receive the message
for queue in queues:
msg = await queue.get()
assert msg.data == msg_content
await trio.sleep(0.5)
# Assert that the blocking queues receive the message
msg = await queue_0.get()
assert msg.data == msg_content
@pytest.mark.parametrize(
"num_hosts, gossipsub_params",
(
(
2,
GossipsubParams(
degree=1,
degree_low=0,
degree_high=2,
gossip_window=50,
gossip_history=100,
),
),
),
)
@pytest.mark.asyncio
async def test_gossip_propagation(hosts, pubsubs_gsub):
topic = "foo"
await pubsubs_gsub[0].subscribe(topic)
# node 0 publish to topic
msg_content = b"foo_msg"
# publish from the randomly chosen host
await pubsubs_gsub[0].publish(topic, msg_content)
# now node 1 subscribes
queue_1 = await pubsubs_gsub[1].subscribe(topic)
await connect(hosts[0], hosts[1])
# wait for gossip heartbeat
await asyncio.sleep(2)
# should be able to read message
msg = await queue_1.get()
assert msg.data == msg_content
@pytest.mark.parametrize(
"num_hosts, gossipsub_params", ((1, GossipsubParams(heartbeat_initial_delay=100)),)
)
@pytest.mark.parametrize("initial_mesh_peer_count", (7, 10, 13))
@pytest.mark.asyncio
async def test_mesh_heartbeat(
num_hosts, initial_mesh_peer_count, pubsubs_gsub, hosts, monkeypatch
):
# It's difficult to set up the initial peer subscription condition.
# Ideally I would like to have initial mesh peer count that's below ``GossipSubDegree``
# so I can test if `mesh_heartbeat` return correct peers to GRAFT.
# The problem is that I can not set it up so that we have peers subscribe to the topic
# but not being part of our mesh peers (as these peers are the peers to GRAFT).
# So I monkeypatch the peer subscriptions and our mesh peers.
total_peer_count = 14
topic = "TEST_MESH_HEARTBEAT"
@pytest.mark.trio
async def test_mesh_heartbeat(initial_mesh_peer_count, monkeypatch):
async with PubsubFactory.create_batch_with_gossipsub(
1, heartbeat_initial_delay=100
) as pubsubs_gsub:
# It's difficult to set up the initial peer subscription condition.
# Ideally I would like to have initial mesh peer count that's below ``GossipSubDegree``
# so I can test if `mesh_heartbeat` return correct peers to GRAFT.
# The problem is that I can not set it up so that we have peers subscribe to the topic
# but not being part of our mesh peers (as these peers are the peers to GRAFT).
# So I monkeypatch the peer subscriptions and our mesh peers.
total_peer_count = 14
topic = "TEST_MESH_HEARTBEAT"
fake_peer_ids = [
ID((i).to_bytes(2, byteorder="big")) for i in range(total_peer_count)
]
peer_protocol = {peer_id: PROTOCOL_ID for peer_id in fake_peer_ids}
monkeypatch.setattr(pubsubs_gsub[0].router, "peer_protocol", peer_protocol)
fake_peer_ids = [IDFactory() for _ in range(total_peer_count)]
peer_protocol = {peer_id: PROTOCOL_ID for peer_id in fake_peer_ids}
monkeypatch.setattr(pubsubs_gsub[0].router, "peer_protocol", peer_protocol)
peer_topics = {topic: set(fake_peer_ids)}
# Monkeypatch the peer subscriptions
monkeypatch.setattr(pubsubs_gsub[0], "peer_topics", peer_topics)
peer_topics = {topic: set(fake_peer_ids)}
# Monkeypatch the peer subscriptions
monkeypatch.setattr(pubsubs_gsub[0], "peer_topics", peer_topics)
mesh_peer_indices = random.sample(range(total_peer_count), initial_mesh_peer_count)
mesh_peers = [fake_peer_ids[i] for i in mesh_peer_indices]
router_mesh = {topic: set(mesh_peers)}
# Monkeypatch our mesh peers
monkeypatch.setattr(pubsubs_gsub[0].router, "mesh", router_mesh)
mesh_peer_indices = random.sample(
range(total_peer_count), initial_mesh_peer_count
)
mesh_peers = [fake_peer_ids[i] for i in mesh_peer_indices]
router_mesh = {topic: set(mesh_peers)}
# Monkeypatch our mesh peers
monkeypatch.setattr(pubsubs_gsub[0].router, "mesh", router_mesh)
peers_to_graft, peers_to_prune = pubsubs_gsub[0].router.mesh_heartbeat()
if initial_mesh_peer_count > GOSSIPSUB_PARAMS.degree:
# If number of initial mesh peers is more than `GossipSubDegree`, we should PRUNE mesh peers
assert len(peers_to_graft) == 0
assert len(peers_to_prune) == initial_mesh_peer_count - GOSSIPSUB_PARAMS.degree
for peer in peers_to_prune:
assert peer in mesh_peers
elif initial_mesh_peer_count < GOSSIPSUB_PARAMS.degree:
# If number of initial mesh peers is less than `GossipSubDegree`, we should GRAFT more peers
assert len(peers_to_prune) == 0
assert len(peers_to_graft) == GOSSIPSUB_PARAMS.degree - initial_mesh_peer_count
for peer in peers_to_graft:
assert peer not in mesh_peers
else:
assert len(peers_to_prune) == 0 and len(peers_to_graft) == 0
@pytest.mark.parametrize(
"num_hosts, gossipsub_params", ((1, GossipsubParams(heartbeat_initial_delay=100)),)
)
@pytest.mark.parametrize("initial_peer_count", (1, 4, 7))
@pytest.mark.asyncio
async def test_gossip_heartbeat(
num_hosts, initial_peer_count, pubsubs_gsub, hosts, monkeypatch
):
# The problem is that I can not set it up so that we have peers subscribe to the topic
# but not being part of our mesh peers (as these peers are the peers to GRAFT).
# So I monkeypatch the peer subscriptions and our mesh peers.
total_peer_count = 28
topic_mesh = "TEST_GOSSIP_HEARTBEAT_1"
topic_fanout = "TEST_GOSSIP_HEARTBEAT_2"
fake_peer_ids = [
ID((i).to_bytes(2, byteorder="big")) for i in range(total_peer_count)
]
peer_protocol = {peer_id: PROTOCOL_ID for peer_id in fake_peer_ids}
monkeypatch.setattr(pubsubs_gsub[0].router, "peer_protocol", peer_protocol)
topic_mesh_peer_count = 14
# Split into mesh peers and fanout peers
peer_topics = {
topic_mesh: set(fake_peer_ids[:topic_mesh_peer_count]),
topic_fanout: set(fake_peer_ids[topic_mesh_peer_count:]),
}
# Monkeypatch the peer subscriptions
monkeypatch.setattr(pubsubs_gsub[0], "peer_topics", peer_topics)
mesh_peer_indices = random.sample(range(topic_mesh_peer_count), initial_peer_count)
mesh_peers = [fake_peer_ids[i] for i in mesh_peer_indices]
router_mesh = {topic_mesh: set(mesh_peers)}
# Monkeypatch our mesh peers
monkeypatch.setattr(pubsubs_gsub[0].router, "mesh", router_mesh)
fanout_peer_indices = random.sample(
range(topic_mesh_peer_count, total_peer_count), initial_peer_count
)
fanout_peers = [fake_peer_ids[i] for i in fanout_peer_indices]
router_fanout = {topic_fanout: set(fanout_peers)}
# Monkeypatch our fanout peers
monkeypatch.setattr(pubsubs_gsub[0].router, "fanout", router_fanout)
def window(topic):
if topic == topic_mesh:
return [topic_mesh]
elif topic == topic_fanout:
return [topic_fanout]
peers_to_graft, peers_to_prune = pubsubs_gsub[0].router.mesh_heartbeat()
if initial_mesh_peer_count > pubsubs_gsub[0].router.degree:
# If number of initial mesh peers is more than `GossipSubDegree`,
# we should PRUNE mesh peers
assert len(peers_to_graft) == 0
assert (
len(peers_to_prune)
== initial_mesh_peer_count - pubsubs_gsub[0].router.degree
)
for peer in peers_to_prune:
assert peer in mesh_peers
elif initial_mesh_peer_count < pubsubs_gsub[0].router.degree:
# If number of initial mesh peers is less than `GossipSubDegree`,
# we should GRAFT more peers
assert len(peers_to_prune) == 0
assert (
len(peers_to_graft)
== pubsubs_gsub[0].router.degree - initial_mesh_peer_count
)
for peer in peers_to_graft:
assert peer not in mesh_peers
else:
return []
assert len(peers_to_prune) == 0 and len(peers_to_graft) == 0
# Monkeypatch the memory cache messages
monkeypatch.setattr(pubsubs_gsub[0].router.mcache, "window", window)
peers_to_gossip = pubsubs_gsub[0].router.gossip_heartbeat()
# If our mesh peer count is less than `GossipSubDegree`, we should gossip to up to
# `GossipSubDegree` peers (exclude mesh peers).
if topic_mesh_peer_count - initial_peer_count < GOSSIPSUB_PARAMS.degree:
# The same goes for fanout so it's two times the number of peers to gossip.
assert len(peers_to_gossip) == 2 * (topic_mesh_peer_count - initial_peer_count)
elif topic_mesh_peer_count - initial_peer_count >= GOSSIPSUB_PARAMS.degree:
assert len(peers_to_gossip) == 2 * (GOSSIPSUB_PARAMS.degree)
@pytest.mark.parametrize("initial_peer_count", (1, 4, 7))
@pytest.mark.trio
async def test_gossip_heartbeat(initial_peer_count, monkeypatch):
async with PubsubFactory.create_batch_with_gossipsub(
1, heartbeat_initial_delay=100
) as pubsubs_gsub:
# The problem is that I can not set it up so that we have peers subscribe to the topic
# but not being part of our mesh peers (as these peers are the peers to GRAFT).
# So I monkeypatch the peer subscriptions and our mesh peers.
total_peer_count = 28
topic_mesh = "TEST_GOSSIP_HEARTBEAT_1"
topic_fanout = "TEST_GOSSIP_HEARTBEAT_2"
for peer in peers_to_gossip:
if peer in peer_topics[topic_mesh]:
# Check that the peer to gossip to is not in our mesh peers
assert peer not in mesh_peers
assert topic_mesh in peers_to_gossip[peer]
elif peer in peer_topics[topic_fanout]:
# Check that the peer to gossip to is not in our fanout peers
assert peer not in fanout_peers
assert topic_fanout in peers_to_gossip[peer]
fake_peer_ids = [IDFactory() for _ in range(total_peer_count)]
peer_protocol = {peer_id: PROTOCOL_ID for peer_id in fake_peer_ids}
monkeypatch.setattr(pubsubs_gsub[0].router, "peer_protocol", peer_protocol)
topic_mesh_peer_count = 14
# Split into mesh peers and fanout peers
peer_topics = {
topic_mesh: set(fake_peer_ids[:topic_mesh_peer_count]),
topic_fanout: set(fake_peer_ids[topic_mesh_peer_count:]),
}
# Monkeypatch the peer subscriptions
monkeypatch.setattr(pubsubs_gsub[0], "peer_topics", peer_topics)
mesh_peer_indices = random.sample(
range(topic_mesh_peer_count), initial_peer_count
)
mesh_peers = [fake_peer_ids[i] for i in mesh_peer_indices]
router_mesh = {topic_mesh: set(mesh_peers)}
# Monkeypatch our mesh peers
monkeypatch.setattr(pubsubs_gsub[0].router, "mesh", router_mesh)
fanout_peer_indices = random.sample(
range(topic_mesh_peer_count, total_peer_count), initial_peer_count
)
fanout_peers = [fake_peer_ids[i] for i in fanout_peer_indices]
router_fanout = {topic_fanout: set(fanout_peers)}
# Monkeypatch our fanout peers
monkeypatch.setattr(pubsubs_gsub[0].router, "fanout", router_fanout)
def window(topic):
if topic == topic_mesh:
return [topic_mesh]
elif topic == topic_fanout:
return [topic_fanout]
else:
return []
# Monkeypatch the memory cache messages
monkeypatch.setattr(pubsubs_gsub[0].router.mcache, "window", window)
peers_to_gossip = pubsubs_gsub[0].router.gossip_heartbeat()
# If our mesh peer count is less than `GossipSubDegree`, we should gossip to up to
# `GossipSubDegree` peers (exclude mesh peers).
if topic_mesh_peer_count - initial_peer_count < pubsubs_gsub[0].router.degree:
# The same goes for fanout so it's two times the number of peers to gossip.
assert len(peers_to_gossip) == 2 * (
topic_mesh_peer_count - initial_peer_count
)
elif (
topic_mesh_peer_count - initial_peer_count >= pubsubs_gsub[0].router.degree
):
assert len(peers_to_gossip) == 2 * (pubsubs_gsub[0].router.degree)
for peer in peers_to_gossip:
if peer in peer_topics[topic_mesh]:
# Check that the peer to gossip to is not in our mesh peers
assert peer not in mesh_peers
assert topic_mesh in peers_to_gossip[peer]
elif peer in peer_topics[topic_fanout]:
# Check that the peer to gossip to is not in our fanout peers
assert peer not in fanout_peers
assert topic_fanout in peers_to_gossip[peer]

View File

@ -3,25 +3,25 @@ import functools
import pytest
from libp2p.tools.constants import FLOODSUB_PROTOCOL_ID
from libp2p.tools.factories import GossipsubFactory
from libp2p.tools.factories import PubsubFactory
from libp2p.tools.pubsub.floodsub_integration_test_settings import (
floodsub_protocol_pytest_params,
perform_test_from_obj,
)
@pytest.mark.asyncio
async def test_gossipsub_initialize_with_floodsub_protocol():
GossipsubFactory(protocols=[FLOODSUB_PROTOCOL_ID])
@pytest.mark.parametrize("test_case_obj", floodsub_protocol_pytest_params)
@pytest.mark.asyncio
@pytest.mark.trio
@pytest.mark.slow
async def test_gossipsub_run_with_floodsub_tests(test_case_obj):
await perform_test_from_obj(
test_case_obj,
functools.partial(
GossipsubFactory, degree=3, degree_low=2, degree_high=4, time_to_live=30
PubsubFactory.create_batch_with_gossipsub,
protocols=[FLOODSUB_PROTOCOL_ID],
degree=3,
degree_low=2,
degree_high=4,
time_to_live=30,
),
)

View File

@ -1,5 +1,3 @@
import pytest
from libp2p.pubsub.mcache import MessageCache
@ -12,8 +10,7 @@ class Msg:
self.from_id = from_id
@pytest.mark.asyncio
async def test_mcache():
def test_mcache():
# Ported from:
# https://github.com/libp2p/go-libp2p-pubsub/blob/51b7501433411b5096cac2b4994a36a68515fc03/mcache_test.go
mcache = MessageCache(3, 5)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,84 @@
import math
import pytest
import trio
from libp2p.pubsub.pb import rpc_pb2
from libp2p.pubsub.subscription import TrioSubscriptionAPI
GET_TIMEOUT = 0.001
def make_trio_subscription():
send_channel, receive_channel = trio.open_memory_channel(math.inf)
async def unsubscribe_fn():
await send_channel.aclose()
return (
send_channel,
TrioSubscriptionAPI(receive_channel, unsubscribe_fn=unsubscribe_fn),
)
def make_pubsub_msg():
return rpc_pb2.Message()
async def send_something(send_channel):
msg = make_pubsub_msg()
await send_channel.send(msg)
return msg
@pytest.mark.trio
async def test_trio_subscription_get():
send_channel, sub = make_trio_subscription()
data_0 = await send_something(send_channel)
data_1 = await send_something(send_channel)
assert data_0 == await sub.get()
assert data_1 == await sub.get()
# No more message
with pytest.raises(trio.TooSlowError):
with trio.fail_after(GET_TIMEOUT):
await sub.get()
@pytest.mark.trio
async def test_trio_subscription_iter():
send_channel, sub = make_trio_subscription()
received_data = []
async def iter_subscriptions(subscription):
async for data in sub:
received_data.append(data)
async with trio.open_nursery() as nursery:
nursery.start_soon(iter_subscriptions, sub)
await send_something(send_channel)
await send_something(send_channel)
await send_channel.aclose()
assert len(received_data) == 2
@pytest.mark.trio
async def test_trio_subscription_unsubscribe():
send_channel, sub = make_trio_subscription()
await sub.unsubscribe()
# Test: If the subscription is unsubscribed, `send_channel` should be closed.
with pytest.raises(trio.ClosedResourceError):
await send_something(send_channel)
# Test: No side effect when cancelled twice.
await sub.unsubscribe()
@pytest.mark.trio
async def test_trio_subscription_async_context_manager():
send_channel, sub = make_trio_subscription()
async with sub:
# Test: `sub` is not cancelled yet, so `send_something` works fine.
await send_something(send_channel)
# Test: `sub` is cancelled, `send_something` fails
with pytest.raises(trio.ClosedResourceError):
await send_something(send_channel)

View File

@ -1,70 +1,15 @@
import asyncio
import pytest
import trio
from libp2p.crypto.secp256k1 import create_new_key_pair
from libp2p.network.connection.raw_connection_interface import IRawConnection
from libp2p.peer.id import ID
from libp2p.security.secio.transport import NONCE_SIZE, create_secure_session
from libp2p.tools.constants import MAX_READ_LEN
from libp2p.tools.factories import raw_conn_factory
class InMemoryConnection(IRawConnection):
def __init__(self, peer, is_initiator=False):
self.peer = peer
self.recv_queue = asyncio.Queue()
self.send_queue = asyncio.Queue()
self.is_initiator = is_initiator
self.current_msg = None
self.current_position = 0
self.closed = False
async def write(self, data: bytes) -> int:
if self.closed:
raise Exception("InMemoryConnection is closed for writing")
await self.send_queue.put(data)
return len(data)
async def read(self, n: int = -1) -> bytes:
"""
NOTE: have to buffer the current message and juggle packets
off the recv queue to satisfy the semantics of this function.
"""
if self.closed:
raise Exception("InMemoryConnection is closed for reading")
if not self.current_msg:
self.current_msg = await self.recv_queue.get()
self.current_position = 0
if n < 0:
msg = self.current_msg
self.current_msg = None
return msg
next_msg = self.current_msg[self.current_position : self.current_position + n]
self.current_position += n
if self.current_position == len(self.current_msg):
self.current_msg = None
return next_msg
async def close(self) -> None:
self.closed = True
async def create_pipe(local_conn, remote_conn):
try:
while True:
next_msg = await local_conn.send_queue.get()
await remote_conn.recv_queue.put(next_msg)
except asyncio.CancelledError:
return
@pytest.mark.asyncio
async def test_create_secure_session():
@pytest.mark.trio
async def test_create_secure_session(nursery):
local_nonce = b"\x01" * NONCE_SIZE
local_key_pair = create_new_key_pair(b"a")
local_peer = ID.from_pubkey(local_key_pair.public_key)
@ -73,30 +18,32 @@ async def test_create_secure_session():
remote_key_pair = create_new_key_pair(b"b")
remote_peer = ID.from_pubkey(remote_key_pair.public_key)
local_conn = InMemoryConnection(local_peer, is_initiator=True)
remote_conn = InMemoryConnection(remote_peer)
async with raw_conn_factory(nursery) as conns:
local_conn, remote_conn = conns
local_pipe_task = asyncio.ensure_future(create_pipe(local_conn, remote_conn))
remote_pipe_task = asyncio.ensure_future(create_pipe(remote_conn, local_conn))
local_secure_conn, remote_secure_conn = None, None
local_session_builder = create_secure_session(
local_nonce, local_peer, local_key_pair.private_key, local_conn, remote_peer
)
remote_session_builder = create_secure_session(
remote_nonce, remote_peer, remote_key_pair.private_key, remote_conn
)
local_secure_conn, remote_secure_conn = await asyncio.gather(
local_session_builder, remote_session_builder
)
async def local_create_secure_session():
nonlocal local_secure_conn
local_secure_conn = await create_secure_session(
local_nonce,
local_peer,
local_key_pair.private_key,
local_conn,
remote_peer,
)
msg = b"abc"
await local_secure_conn.write(msg)
received_msg = await remote_secure_conn.read()
assert received_msg == msg
async def remote_create_secure_session():
nonlocal remote_secure_conn
remote_secure_conn = await create_secure_session(
remote_nonce, remote_peer, remote_key_pair.private_key, remote_conn
)
await asyncio.gather(local_secure_conn.close(), remote_secure_conn.close())
async with trio.open_nursery() as nursery_1:
nursery_1.start_soon(local_create_secure_session)
nursery_1.start_soon(remote_create_secure_session)
local_pipe_task.cancel()
remote_pipe_task.cancel()
await local_pipe_task
await remote_pipe_task
msg = b"abc"
await local_secure_conn.write(msg)
received_msg = await remote_secure_conn.read(MAX_READ_LEN)
assert received_msg == msg

View File

@ -1,8 +1,7 @@
import asyncio
import pytest
import trio
from libp2p import new_node
from libp2p import new_host
from libp2p.crypto.rsa import create_new_key_pair
from libp2p.security.insecure.transport import InsecureSession, InsecureTransport
from libp2p.tools.constants import LISTEN_MADDR
@ -24,42 +23,36 @@ noninitiator_key_pair = create_new_key_pair()
async def perform_simple_test(
assertion_func, transports_for_initiator, transports_for_noninitiator
):
# Create libp2p nodes and connect them, then secure the connection, then check
# the proper security was chosen
# TODO: implement -- note we need to introduce the notion of communicating over a raw connection
# for testing, we do NOT want to communicate over a stream so we can't just create two nodes
# and use their conn because our mplex will internally relay messages to a stream
node1 = await new_node(
key_pair=initiator_key_pair, sec_opt=transports_for_initiator
)
node2 = await new_node(
node1 = new_host(key_pair=initiator_key_pair, sec_opt=transports_for_initiator)
node2 = new_host(
key_pair=noninitiator_key_pair, sec_opt=transports_for_noninitiator
)
async with node1.run(listen_addrs=[LISTEN_MADDR]), node2.run(
listen_addrs=[LISTEN_MADDR]
):
await connect(node1, node2)
await node1.get_network().listen(LISTEN_MADDR)
await node2.get_network().listen(LISTEN_MADDR)
# Wait a very short period to allow conns to be stored (since the functions
# storing the conns are async, they may happen at slightly different times
# on each node)
await trio.sleep(0.1)
await connect(node1, node2)
# Get conns
node1_conn = node1.get_network().connections[peer_id_for_node(node2)]
node2_conn = node2.get_network().connections[peer_id_for_node(node1)]
# Wait a very short period to allow conns to be stored (since the functions
# storing the conns are async, they may happen at slightly different times
# on each node)
await asyncio.sleep(0.1)
# Get conns
node1_conn = node1.get_network().connections[peer_id_for_node(node2)]
node2_conn = node2.get_network().connections[peer_id_for_node(node1)]
# Perform assertion
assertion_func(node1_conn.muxed_conn.secured_conn)
assertion_func(node2_conn.muxed_conn.secured_conn)
# Success, terminate pending tasks.
# Perform assertion
assertion_func(node1_conn.muxed_conn.secured_conn)
assertion_func(node2_conn.muxed_conn.secured_conn)
@pytest.mark.asyncio
@pytest.mark.trio
async def test_single_insecure_security_transport_succeeds():
transports_for_initiator = {"foo": InsecureTransport(initiator_key_pair)}
transports_for_noninitiator = {"foo": InsecureTransport(noninitiator_key_pair)}
@ -72,7 +65,7 @@ async def test_single_insecure_security_transport_succeeds():
)
@pytest.mark.asyncio
@pytest.mark.trio
async def test_default_insecure_security():
transports_for_initiator = None
transports_for_noninitiator = None

View File

@ -1,5 +1,3 @@
import asyncio
import pytest
from libp2p.tools.factories import mplex_conn_pair_factory, mplex_stream_pair_factory
@ -7,23 +5,13 @@ from libp2p.tools.factories import mplex_conn_pair_factory, mplex_stream_pair_fa
@pytest.fixture
async def mplex_conn_pair(is_host_secure):
mplex_conn_0, swarm_0, mplex_conn_1, swarm_1 = await mplex_conn_pair_factory(
is_host_secure
)
assert mplex_conn_0.is_initiator
assert not mplex_conn_1.is_initiator
try:
yield mplex_conn_0, mplex_conn_1
finally:
await asyncio.gather(*[swarm_0.close(), swarm_1.close()])
async with mplex_conn_pair_factory(is_host_secure) as mplex_conn_pair:
assert mplex_conn_pair[0].is_initiator
assert not mplex_conn_pair[1].is_initiator
yield mplex_conn_pair[0], mplex_conn_pair[1]
@pytest.fixture
async def mplex_stream_pair(is_host_secure):
mplex_stream_0, swarm_0, mplex_stream_1, swarm_1 = await mplex_stream_pair_factory(
is_host_secure
)
try:
yield mplex_stream_0, mplex_stream_1
finally:
await asyncio.gather(*[swarm_0.close(), swarm_1.close()])
async with mplex_stream_pair_factory(is_host_secure) as mplex_stream_pair:
yield mplex_stream_pair

View File

@ -1,49 +1,40 @@
import asyncio
import pytest
import trio
@pytest.mark.asyncio
@pytest.mark.trio
async def test_mplex_conn(mplex_conn_pair):
conn_0, conn_1 = mplex_conn_pair
assert len(conn_0.streams) == 0
assert len(conn_1.streams) == 0
assert not conn_0.event_shutting_down.is_set()
assert not conn_1.event_shutting_down.is_set()
assert not conn_0.event_closed.is_set()
assert not conn_1.event_closed.is_set()
# Test: Open a stream, and both side get 1 more stream.
stream_0 = await conn_0.open_stream()
await asyncio.sleep(0.01)
await trio.sleep(0.01)
assert len(conn_0.streams) == 1
assert len(conn_1.streams) == 1
# Test: From another side.
stream_1 = await conn_1.open_stream()
await asyncio.sleep(0.01)
await trio.sleep(0.01)
assert len(conn_0.streams) == 2
assert len(conn_1.streams) == 2
# Close from one side.
await conn_0.close()
# Sleep for a while for both side to handle `close`.
await asyncio.sleep(0.01)
await trio.sleep(0.01)
# Test: Both side is closed.
assert conn_0.event_shutting_down.is_set()
assert conn_0.event_closed.is_set()
assert conn_1.event_shutting_down.is_set()
assert conn_1.event_closed.is_set()
assert conn_0.is_closed
assert conn_1.is_closed
# Test: All streams should have been closed.
assert stream_0.event_remote_closed.is_set()
assert stream_0.event_reset.is_set()
assert stream_0.event_local_closed.is_set()
assert conn_0.streams is None
# Test: All streams on the other side are also closed.
assert stream_1.event_remote_closed.is_set()
assert stream_1.event_reset.is_set()
assert stream_1.event_local_closed.is_set()
assert conn_1.streams is None
# Test: No effect to close more than once between two side.
await conn_0.close()

View File

@ -1,25 +1,48 @@
import asyncio
import pytest
import trio
from trio.testing import wait_all_tasks_blocked
from libp2p.stream_muxer.mplex.exceptions import (
MplexStreamClosed,
MplexStreamEOF,
MplexStreamReset,
)
from libp2p.stream_muxer.mplex.mplex import MPLEX_MESSAGE_CHANNEL_SIZE
from libp2p.tools.constants import MAX_READ_LEN
DATA = b"data_123"
@pytest.mark.asyncio
@pytest.mark.trio
async def test_mplex_stream_read_write(mplex_stream_pair):
stream_0, stream_1 = mplex_stream_pair
await stream_0.write(DATA)
assert (await stream_1.read(MAX_READ_LEN)) == DATA
@pytest.mark.asyncio
@pytest.mark.trio
async def test_mplex_stream_full_buffer(mplex_stream_pair):
stream_0, stream_1 = mplex_stream_pair
# Test: The message channel is of size `MPLEX_MESSAGE_CHANNEL_SIZE`.
# It should be fine to read even there are already `MPLEX_MESSAGE_CHANNEL_SIZE`
# messages arriving.
for _ in range(MPLEX_MESSAGE_CHANNEL_SIZE):
await stream_0.write(DATA)
await wait_all_tasks_blocked()
# Sanity check
assert MAX_READ_LEN >= MPLEX_MESSAGE_CHANNEL_SIZE * len(DATA)
assert (await stream_1.read(MAX_READ_LEN)) == MPLEX_MESSAGE_CHANNEL_SIZE * DATA
# Test: Read after `MPLEX_MESSAGE_CHANNEL_SIZE + 1` messages has arrived, which
# exceeds the channel size. The stream should have been reset.
for _ in range(MPLEX_MESSAGE_CHANNEL_SIZE + 1):
await stream_0.write(DATA)
await wait_all_tasks_blocked()
with pytest.raises(MplexStreamReset):
await stream_1.read(MAX_READ_LEN)
@pytest.mark.trio
async def test_mplex_stream_pair_read_until_eof(mplex_stream_pair):
read_bytes = bytearray()
stream_0, stream_1 = mplex_stream_pair
@ -27,43 +50,46 @@ async def test_mplex_stream_pair_read_until_eof(mplex_stream_pair):
async def read_until_eof():
read_bytes.extend(await stream_1.read())
task = asyncio.ensure_future(read_until_eof())
expected_data = bytearray()
# Test: `read` doesn't return before `close` is called.
await stream_0.write(DATA)
expected_data.extend(DATA)
await asyncio.sleep(0.01)
assert len(read_bytes) == 0
# Test: `read` doesn't return before `close` is called.
await stream_0.write(DATA)
expected_data.extend(DATA)
await asyncio.sleep(0.01)
assert len(read_bytes) == 0
async with trio.open_nursery() as nursery:
nursery.start_soon(read_until_eof)
# Test: `read` doesn't return before `close` is called.
await stream_0.write(DATA)
expected_data.extend(DATA)
await trio.sleep(0.01)
assert len(read_bytes) == 0
# Test: `read` doesn't return before `close` is called.
await stream_0.write(DATA)
expected_data.extend(DATA)
await trio.sleep(0.01)
assert len(read_bytes) == 0
# Test: Close the stream, `read` returns, and receive previous sent data.
await stream_0.close()
# Test: Close the stream, `read` returns, and receive previous sent data.
await stream_0.close()
await asyncio.sleep(0.01)
assert read_bytes == expected_data
task.cancel()
@pytest.mark.asyncio
@pytest.mark.trio
async def test_mplex_stream_read_after_remote_closed(mplex_stream_pair):
stream_0, stream_1 = mplex_stream_pair
assert not stream_1.event_remote_closed.is_set()
await stream_0.write(DATA)
assert not stream_0.event_local_closed.is_set()
await trio.sleep(0.01)
await wait_all_tasks_blocked()
await stream_0.close()
await asyncio.sleep(0.01)
assert stream_0.event_local_closed.is_set()
await trio.sleep(0.01)
await wait_all_tasks_blocked()
assert stream_1.event_remote_closed.is_set()
assert (await stream_1.read(MAX_READ_LEN)) == DATA
with pytest.raises(MplexStreamEOF):
await stream_1.read(MAX_READ_LEN)
@pytest.mark.asyncio
@pytest.mark.trio
async def test_mplex_stream_read_after_local_reset(mplex_stream_pair):
stream_0, stream_1 = mplex_stream_pair
await stream_0.reset()
@ -71,29 +97,30 @@ async def test_mplex_stream_read_after_local_reset(mplex_stream_pair):
await stream_0.read(MAX_READ_LEN)
@pytest.mark.asyncio
@pytest.mark.trio
async def test_mplex_stream_read_after_remote_reset(mplex_stream_pair):
stream_0, stream_1 = mplex_stream_pair
await stream_0.write(DATA)
await stream_0.reset()
# Sleep to let `stream_1` receive the message.
await asyncio.sleep(0.01)
await trio.sleep(0.1)
await wait_all_tasks_blocked()
with pytest.raises(MplexStreamReset):
await stream_1.read(MAX_READ_LEN)
@pytest.mark.asyncio
@pytest.mark.trio
async def test_mplex_stream_read_after_remote_closed_and_reset(mplex_stream_pair):
stream_0, stream_1 = mplex_stream_pair
await stream_0.write(DATA)
await stream_0.close()
await stream_0.reset()
# Sleep to let `stream_1` receive the message.
await asyncio.sleep(0.01)
await trio.sleep(0.01)
assert (await stream_1.read(MAX_READ_LEN)) == DATA
@pytest.mark.asyncio
@pytest.mark.trio
async def test_mplex_stream_write_after_local_closed(mplex_stream_pair):
stream_0, stream_1 = mplex_stream_pair
await stream_0.write(DATA)
@ -102,7 +129,7 @@ async def test_mplex_stream_write_after_local_closed(mplex_stream_pair):
await stream_0.write(DATA)
@pytest.mark.asyncio
@pytest.mark.trio
async def test_mplex_stream_write_after_local_reset(mplex_stream_pair):
stream_0, stream_1 = mplex_stream_pair
await stream_0.reset()
@ -110,16 +137,16 @@ async def test_mplex_stream_write_after_local_reset(mplex_stream_pair):
await stream_0.write(DATA)
@pytest.mark.asyncio
@pytest.mark.trio
async def test_mplex_stream_write_after_remote_reset(mplex_stream_pair):
stream_0, stream_1 = mplex_stream_pair
await stream_1.reset()
await asyncio.sleep(0.01)
await trio.sleep(0.01)
with pytest.raises(MplexStreamClosed):
await stream_0.write(DATA)
@pytest.mark.asyncio
@pytest.mark.trio
async def test_mplex_stream_both_close(mplex_stream_pair):
stream_0, stream_1 = mplex_stream_pair
# Flags are not set initially.
@ -133,7 +160,7 @@ async def test_mplex_stream_both_close(mplex_stream_pair):
# Test: Close one side.
await stream_0.close()
await asyncio.sleep(0.01)
await trio.sleep(0.01)
assert stream_0.event_local_closed.is_set()
assert not stream_1.event_local_closed.is_set()
@ -145,7 +172,7 @@ async def test_mplex_stream_both_close(mplex_stream_pair):
# Test: Close the other side.
await stream_1.close()
await asyncio.sleep(0.01)
await trio.sleep(0.01)
# Both sides are closed.
assert stream_0.event_local_closed.is_set()
assert stream_1.event_local_closed.is_set()
@ -159,11 +186,11 @@ async def test_mplex_stream_both_close(mplex_stream_pair):
await stream_0.reset()
@pytest.mark.asyncio
@pytest.mark.trio
async def test_mplex_stream_reset(mplex_stream_pair):
stream_0, stream_1 = mplex_stream_pair
await stream_0.reset()
await asyncio.sleep(0.01)
await trio.sleep(0.01)
# Both sides are closed.
assert stream_0.event_local_closed.is_set()

View File

@ -1,20 +1,53 @@
import asyncio
from multiaddr import Multiaddr
import pytest
import trio
from libp2p.transport.tcp.tcp import _multiaddr_from_socket
from libp2p.network.connection.raw_connection import RawConnection
from libp2p.tools.constants import LISTEN_MADDR
from libp2p.transport.exceptions import OpenConnectionError
from libp2p.transport.tcp.tcp import TCP
@pytest.mark.asyncio
async def test_multiaddr_from_socket():
def handler(r, w):
@pytest.mark.trio
async def test_tcp_listener(nursery):
transport = TCP()
async def handler(tcp_stream):
pass
server = await asyncio.start_server(handler, "127.0.0.1", 8000)
assert str(_multiaddr_from_socket(server.sockets[0])) == "/ip4/127.0.0.1/tcp/8000"
listener = transport.create_listener(handler)
assert len(listener.get_addrs()) == 0
await listener.listen(LISTEN_MADDR, nursery)
assert len(listener.get_addrs()) == 1
await listener.listen(LISTEN_MADDR, nursery)
assert len(listener.get_addrs()) == 2
server = await asyncio.start_server(handler, "127.0.0.1", 0)
addr = _multiaddr_from_socket(server.sockets[0])
assert addr.value_for_protocol("ip4") == "127.0.0.1"
port = addr.value_for_protocol("tcp")
assert int(port) > 0
@pytest.mark.trio
async def test_tcp_dial(nursery):
transport = TCP()
raw_conn_other_side = None
event = trio.Event()
async def handler(tcp_stream):
nonlocal raw_conn_other_side
raw_conn_other_side = RawConnection(tcp_stream, False)
event.set()
await trio.sleep_forever()
# Test: `OpenConnectionError` is raised when trying to dial to a port which
# no one is not listening to.
with pytest.raises(OpenConnectionError):
await transport.dial(Multiaddr("/ip4/127.0.0.1/tcp/1"))
listener = transport.create_listener(handler)
await listener.listen(LISTEN_MADDR, nursery)
addrs = listener.get_addrs()
assert len(addrs) == 1
listen_addr = addrs[0]
raw_conn = await transport.dial(listen_addr)
await event.wait()
data = b"123"
await raw_conn_other_side.write(data)
assert (await raw_conn.read(len(data))) == data

View File

@ -1,20 +1,13 @@
import asyncio
import sys
from typing import Union
import anyio
from async_exit_stack import AsyncExitStack
from p2pclient.datastructures import StreamInfo
import pexpect
from p2pclient.utils import get_unused_tcp_port
import pytest
import trio
from libp2p.io.abc import ReadWriteCloser
from libp2p.tools.constants import GOSSIPSUB_PARAMS, LISTEN_MADDR
from libp2p.tools.factories import (
FloodsubFactory,
GossipsubFactory,
HostFactory,
PubsubFactory,
)
from libp2p.tools.interop.daemon import Daemon, make_p2pd
from libp2p.tools.factories import HostFactory, PubsubFactory
from libp2p.tools.interop.daemon import make_p2pd
from libp2p.tools.interop.utils import connect
@ -23,48 +16,6 @@ def is_host_secure():
return False
@pytest.fixture
def num_hosts():
return 3
@pytest.fixture
async def hosts(num_hosts, is_host_secure):
_hosts = HostFactory.create_batch(num_hosts, is_secure=is_host_secure)
await asyncio.gather(
*[_host.get_network().listen(LISTEN_MADDR) for _host in _hosts]
)
try:
yield _hosts
finally:
# TODO: It's possible that `close` raises exceptions currently,
# due to the connection reset things. Though we don't care much about that when
# cleaning up the tasks, it is probably better to handle the exceptions properly.
await asyncio.gather(
*[_host.close() for _host in _hosts], return_exceptions=True
)
@pytest.fixture
def proc_factory():
procs = []
def call_proc(cmd, args, logfile=None, encoding=None):
if logfile is None:
logfile = sys.stdout
if encoding is None:
encoding = "utf-8"
proc = pexpect.spawn(cmd, args, logfile=logfile, encoding=encoding)
procs.append(proc)
return proc
try:
yield call_proc
finally:
for proc in procs:
proc.close()
@pytest.fixture
def num_p2pds():
return 1
@ -87,79 +38,57 @@ def is_pubsub_signing_strict():
@pytest.fixture
async def p2pds(
num_p2pds,
is_host_secure,
is_gossipsub,
unused_tcp_port_factory,
is_pubsub_signing,
is_pubsub_signing_strict,
num_p2pds, is_host_secure, is_gossipsub, is_pubsub_signing, is_pubsub_signing_strict
):
p2pds: Union[Daemon, Exception] = await asyncio.gather(
*[
make_p2pd(
unused_tcp_port_factory(),
unused_tcp_port_factory(),
is_host_secure,
is_gossipsub=is_gossipsub,
is_pubsub_signing=is_pubsub_signing,
is_pubsub_signing_strict=is_pubsub_signing_strict,
async with AsyncExitStack() as stack:
p2pds = [
await stack.enter_async_context(
make_p2pd(
get_unused_tcp_port(),
get_unused_tcp_port(),
is_host_secure,
is_gossipsub=is_gossipsub,
is_pubsub_signing=is_pubsub_signing,
is_pubsub_signing_strict=is_pubsub_signing_strict,
)
)
for _ in range(num_p2pds)
],
return_exceptions=True,
)
p2pds_succeeded = tuple(p2pd for p2pd in p2pds if isinstance(p2pd, Daemon))
if len(p2pds_succeeded) != len(p2pds):
# Not all succeeded. Close the succeeded ones and print the failed ones(exceptions).
await asyncio.gather(*[p2pd.close() for p2pd in p2pds_succeeded])
exceptions = tuple(p2pd for p2pd in p2pds if isinstance(p2pd, Exception))
raise Exception(f"not all p2pds succeed: first exception={exceptions[0]}")
try:
yield p2pds
finally:
await asyncio.gather(*[p2pd.close() for p2pd in p2pds])
]
try:
yield p2pds
finally:
for p2pd in p2pds:
await p2pd.close()
@pytest.fixture
def pubsubs(num_hosts, hosts, is_gossipsub, is_pubsub_signing_strict):
async def pubsubs(num_hosts, is_host_secure, is_gossipsub, is_pubsub_signing_strict):
if is_gossipsub:
routers = GossipsubFactory.create_batch(num_hosts, **GOSSIPSUB_PARAMS._asdict())
yield PubsubFactory.create_batch_with_gossipsub(
num_hosts, is_secure=is_host_secure, strict_signing=is_pubsub_signing_strict
)
else:
routers = FloodsubFactory.create_batch(num_hosts)
_pubsubs = tuple(
PubsubFactory(host=host, router=router, strict_signing=is_pubsub_signing_strict)
for host, router in zip(hosts, routers)
)
yield _pubsubs
# TODO: Clean up
yield PubsubFactory.create_batch_with_floodsub(
num_hosts, is_host_secure, strict_signing=is_pubsub_signing_strict
)
class DaemonStream(ReadWriteCloser):
stream_info: StreamInfo
reader: asyncio.StreamReader
writer: asyncio.StreamWriter
stream: anyio.abc.SocketStream
def __init__(
self,
stream_info: StreamInfo,
reader: asyncio.StreamReader,
writer: asyncio.StreamWriter,
) -> None:
def __init__(self, stream_info: StreamInfo, stream: anyio.abc.SocketStream) -> None:
self.stream_info = stream_info
self.reader = reader
self.writer = writer
self.stream = stream
async def close(self) -> None:
self.writer.close()
if sys.version_info < (3, 7):
return
await self.writer.wait_closed()
await self.stream.close()
async def read(self, n: int = -1) -> bytes:
return await self.reader.read(n)
async def read(self, n: int = None) -> bytes:
return await self.stream.receive_some(n)
async def write(self, data: bytes) -> int:
return self.writer.write(data)
async def write(self, data: bytes) -> None:
return await self.stream.send_all(data)
@pytest.fixture
@ -168,40 +97,38 @@ async def is_to_fail_daemon_stream():
@pytest.fixture
async def py_to_daemon_stream_pair(hosts, p2pds, is_to_fail_daemon_stream):
assert len(hosts) >= 1
assert len(p2pds) >= 1
host = hosts[0]
p2pd = p2pds[0]
protocol_id = "/protocol/id/123"
stream_py = None
stream_daemon = None
event_stream_handled = asyncio.Event()
await connect(host, p2pd)
async def py_to_daemon_stream_pair(p2pds, is_host_secure, is_to_fail_daemon_stream):
async with HostFactory.create_batch_and_listen(is_host_secure, 1) as hosts:
assert len(p2pds) >= 1
host = hosts[0]
p2pd = p2pds[0]
protocol_id = "/protocol/id/123"
stream_py = None
stream_daemon = None
event_stream_handled = trio.Event()
await connect(host, p2pd)
async def daemon_stream_handler(stream_info, reader, writer):
nonlocal stream_daemon
stream_daemon = DaemonStream(stream_info, reader, writer)
event_stream_handled.set()
async def daemon_stream_handler(stream_info, stream):
nonlocal stream_daemon
stream_daemon = DaemonStream(stream_info, stream)
event_stream_handled.set()
await trio.hazmat.checkpoint()
await p2pd.control.stream_handler(protocol_id, daemon_stream_handler)
# Sleep for a while to wait for the handler being registered.
await asyncio.sleep(0.01)
await p2pd.control.stream_handler(protocol_id, daemon_stream_handler)
# Sleep for a while to wait for the handler being registered.
await trio.sleep(0.01)
if is_to_fail_daemon_stream:
# FIXME: This is a workaround to make daemon reset the stream.
# We intentionally close the listener on the python side, it makes the connection from
# daemon to us fail, and therefore the daemon resets the opened stream on their side.
# Reference: https://github.com/libp2p/go-libp2p-daemon/blob/b95e77dbfcd186ccf817f51e95f73f9fd5982600/stream.go#L47-L50 # noqa: E501
# We need it because we want to test against `stream_py` after the remote side(daemon)
# is reset. This should be removed after the API `stream.reset` is exposed in daemon
# some day.
listener = p2pds[0].control.control.listener
listener.close()
if sys.version_info[0:2] > (3, 6):
await listener.wait_closed()
stream_py = await host.new_stream(p2pd.peer_id, [protocol_id])
if not is_to_fail_daemon_stream:
await event_stream_handled.wait()
# NOTE: If `is_to_fail_daemon_stream == True`, then `stream_daemon == None`.
yield stream_py, stream_daemon
if is_to_fail_daemon_stream:
# FIXME: This is a workaround to make daemon reset the stream.
# We intentionally close the listener on the python side, it makes the connection from
# daemon to us fail, and therefore the daemon resets the opened stream on their side.
# Reference: https://github.com/libp2p/go-libp2p-daemon/blob/b95e77dbfcd186ccf817f51e95f73f9fd5982600/stream.go#L47-L50 # noqa: E501
# We need it because we want to test against `stream_py` after the remote side(daemon)
# is reset. This should be removed after the API `stream.reset` is exposed in daemon
# some day.
await p2pds[0].control.control.close()
stream_py = await host.new_stream(p2pd.peer_id, [protocol_id])
if not is_to_fail_daemon_stream:
await event_stream_handled.wait()
# NOTE: If `is_to_fail_daemon_stream == True`, then `stream_daemon == None`.
yield stream_py, stream_daemon

View File

@ -1,26 +1,26 @@
import asyncio
import pytest
import trio
from libp2p.tools.factories import HostFactory
from libp2p.tools.interop.utils import connect
@pytest.mark.parametrize("num_hosts", (1,))
@pytest.mark.asyncio
async def test_connect(hosts, p2pds):
p2pd = p2pds[0]
host = hosts[0]
assert len(await p2pd.control.list_peers()) == 0
# Test: connect from Py
await connect(host, p2pd)
assert len(await p2pd.control.list_peers()) == 1
# Test: `disconnect` from Py
await host.disconnect(p2pd.peer_id)
assert len(await p2pd.control.list_peers()) == 0
# Test: connect from Go
await connect(p2pd, host)
assert len(host.get_network().connections) == 1
# Test: `disconnect` from Go
await p2pd.control.disconnect(host.get_id())
await asyncio.sleep(0.01)
assert len(host.get_network().connections) == 0
@pytest.mark.trio
async def test_connect(is_host_secure, p2pds):
async with HostFactory.create_batch_and_listen(is_host_secure, 1) as hosts:
p2pd = p2pds[0]
host = hosts[0]
assert len(await p2pd.control.list_peers()) == 0
# Test: connect from Py
await connect(host, p2pd)
assert len(await p2pd.control.list_peers()) == 1
# Test: `disconnect` from Py
await host.disconnect(p2pd.peer_id)
assert len(await p2pd.control.list_peers()) == 0
# Test: connect from Go
await connect(p2pd, host)
assert len(host.get_network().connections) == 1
# Test: `disconnect` from Go
await p2pd.control.disconnect(host.get_id())
await trio.sleep(0.01)
assert len(host.get_network().connections) == 0

View File

@ -1,82 +1,99 @@
import asyncio
import re
from multiaddr import Multiaddr
from p2pclient.utils import get_unused_tcp_port
import pytest
import trio
from libp2p.peer.peerinfo import info_from_p2p_addr
from libp2p.tools.interop.constants import PEXPECT_NEW_LINE
from libp2p.peer.peerinfo import PeerInfo, info_from_p2p_addr
from libp2p.tools.factories import HostFactory
from libp2p.tools.interop.envs import GO_BIN_PATH
from libp2p.tools.interop.process import BaseInteractiveProcess
from libp2p.typing import TProtocol
ECHO_PATH = GO_BIN_PATH / "echo"
ECHO_PROTOCOL_ID = TProtocol("/echo/1.0.0")
async def make_echo_proc(
proc_factory, port: int, is_secure: bool, destination: Multiaddr = None
):
args = [f"-l={port}"]
if not is_secure:
args.append("-insecure")
if destination is not None:
args.append(f"-d={str(destination)}")
echo_proc = proc_factory(str(ECHO_PATH), args)
await echo_proc.expect(r"I am ([\w\./]+)" + PEXPECT_NEW_LINE, async_=True)
maddr_str_ipfs = echo_proc.match.group(1)
maddr_str = maddr_str_ipfs.replace("ipfs", "p2p")
maddr = Multiaddr(maddr_str)
go_pinfo = info_from_p2p_addr(maddr)
if destination is None:
await echo_proc.expect("listening for connections", async_=True)
return echo_proc, go_pinfo
class EchoProcess(BaseInteractiveProcess):
port: int
_peer_info: PeerInfo
def __init__(
self, port: int, is_secure: bool, destination: Multiaddr = None
) -> None:
args = [f"-l={port}"]
if not is_secure:
args.append("-insecure")
if destination is not None:
args.append(f"-d={str(destination)}")
patterns = [b"I am"]
if destination is None:
patterns.append(b"listening for connections")
self.args = args
self.cmd = str(ECHO_PATH)
self.patterns = patterns
self.bytes_read = bytearray()
self.event_ready = trio.Event()
self.port = port
self._peer_info = None
self.regex_pat = re.compile(br"I am ([\w\./]+)")
@property
def peer_info(self) -> None:
if self._peer_info is not None:
return self._peer_info
if not self.event_ready.is_set():
raise Exception("process is not ready yet. failed to parse the peer info")
# Example:
# b"I am /ip4/127.0.0.1/tcp/56171/ipfs/QmU41TRPs34WWqa1brJEojBLYZKrrBcJq9nyNfVvSrbZUJ\n"
m = re.search(br"I am ([\w\./]+)", self.bytes_read)
if m is None:
raise Exception("failed to find the pattern for the listening multiaddr")
maddr_bytes_str_ipfs = m.group(1)
maddr_str = maddr_bytes_str_ipfs.decode().replace("ipfs", "p2p")
maddr = Multiaddr(maddr_str)
self._peer_info = info_from_p2p_addr(maddr)
return self._peer_info
@pytest.mark.parametrize("num_hosts", (1,))
@pytest.mark.asyncio
async def test_insecure_conn_py_to_go(
hosts, proc_factory, is_host_secure, unused_tcp_port
):
go_proc, go_pinfo = await make_echo_proc(
proc_factory, unused_tcp_port, is_host_secure
)
@pytest.mark.trio
async def test_insecure_conn_py_to_go(is_host_secure):
async with HostFactory.create_batch_and_listen(is_host_secure, 1) as hosts:
go_proc = EchoProcess(get_unused_tcp_port(), is_host_secure)
await go_proc.start()
host = hosts[0]
await host.connect(go_pinfo)
await go_proc.expect("swarm listener accepted connection", async_=True)
s = await host.new_stream(go_pinfo.peer_id, [ECHO_PROTOCOL_ID])
await go_proc.expect("Got a new stream!", async_=True)
data = "data321123\n"
await s.write(data.encode())
await go_proc.expect(f"read: {data[:-1]}", async_=True)
echoed_resp = await s.read(len(data))
assert echoed_resp.decode() == data
await s.close()
host = hosts[0]
peer_info = go_proc.peer_info
await host.connect(peer_info)
s = await host.new_stream(peer_info.peer_id, [ECHO_PROTOCOL_ID])
data = "data321123\n"
await s.write(data.encode())
echoed_resp = await s.read(len(data))
assert echoed_resp.decode() == data
await s.close()
@pytest.mark.parametrize("num_hosts", (1,))
@pytest.mark.asyncio
async def test_insecure_conn_go_to_py(
hosts, proc_factory, is_host_secure, unused_tcp_port
):
host = hosts[0]
expected_data = "Hello, world!\n"
reply_data = "Replyooo!\n"
event_handler_finished = asyncio.Event()
@pytest.mark.trio
async def test_insecure_conn_go_to_py(is_host_secure):
async with HostFactory.create_batch_and_listen(is_host_secure, 1) as hosts:
host = hosts[0]
expected_data = "Hello, world!\n"
reply_data = "Replyooo!\n"
event_handler_finished = trio.Event()
async def _handle_echo(stream):
read_data = await stream.read(len(expected_data))
assert read_data == expected_data.encode()
event_handler_finished.set()
await stream.write(reply_data.encode())
await stream.close()
async def _handle_echo(stream):
read_data = await stream.read(len(expected_data))
assert read_data == expected_data.encode()
event_handler_finished.set()
await stream.write(reply_data.encode())
await stream.close()
host.set_stream_handler(ECHO_PROTOCOL_ID, _handle_echo)
py_maddr = host.get_addrs()[0]
go_proc, _ = await make_echo_proc(
proc_factory, unused_tcp_port, is_host_secure, py_maddr
)
await go_proc.expect("connect with peer", async_=True)
await go_proc.expect("opened stream", async_=True)
await event_handler_finished.wait()
await go_proc.expect(f"read reply: .*{reply_data.rstrip()}.*", async_=True)
host.set_stream_handler(ECHO_PROTOCOL_ID, _handle_echo)
py_maddr = host.get_addrs()[0]
go_proc = EchoProcess(get_unused_tcp_port(), is_host_secure, py_maddr)
await go_proc.start()
await event_handler_finished.wait()

View File

@ -1,6 +1,5 @@
import asyncio
import pytest
import trio
from libp2p.network.stream.exceptions import StreamClosed, StreamEOF, StreamReset
from libp2p.tools.constants import MAX_READ_LEN
@ -8,7 +7,7 @@ from libp2p.tools.constants import MAX_READ_LEN
DATA = b"data"
@pytest.mark.asyncio
@pytest.mark.trio
async def test_net_stream_read_write(py_to_daemon_stream_pair, p2pds):
stream_py, stream_daemon = py_to_daemon_stream_pair
assert (
@ -19,19 +18,19 @@ async def test_net_stream_read_write(py_to_daemon_stream_pair, p2pds):
assert (await stream_daemon.read(MAX_READ_LEN)) == DATA
@pytest.mark.asyncio
@pytest.mark.trio
async def test_net_stream_read_after_remote_closed(py_to_daemon_stream_pair, p2pds):
stream_py, stream_daemon = py_to_daemon_stream_pair
await stream_daemon.write(DATA)
await stream_daemon.close()
await asyncio.sleep(0.01)
await trio.sleep(0.01)
assert (await stream_py.read(MAX_READ_LEN)) == DATA
# EOF
with pytest.raises(StreamEOF):
await stream_py.read(MAX_READ_LEN)
@pytest.mark.asyncio
@pytest.mark.trio
async def test_net_stream_read_after_local_reset(py_to_daemon_stream_pair, p2pds):
stream_py, _ = py_to_daemon_stream_pair
await stream_py.reset()
@ -40,15 +39,15 @@ async def test_net_stream_read_after_local_reset(py_to_daemon_stream_pair, p2pds
@pytest.mark.parametrize("is_to_fail_daemon_stream", (True,))
@pytest.mark.asyncio
@pytest.mark.trio
async def test_net_stream_read_after_remote_reset(py_to_daemon_stream_pair, p2pds):
stream_py, _ = py_to_daemon_stream_pair
await asyncio.sleep(0.01)
await trio.sleep(0.01)
with pytest.raises(StreamReset):
await stream_py.read(MAX_READ_LEN)
@pytest.mark.asyncio
@pytest.mark.trio
async def test_net_stream_write_after_local_closed(py_to_daemon_stream_pair, p2pds):
stream_py, _ = py_to_daemon_stream_pair
await stream_py.write(DATA)
@ -57,7 +56,7 @@ async def test_net_stream_write_after_local_closed(py_to_daemon_stream_pair, p2p
await stream_py.write(DATA)
@pytest.mark.asyncio
@pytest.mark.trio
async def test_net_stream_write_after_local_reset(py_to_daemon_stream_pair, p2pds):
stream_py, stream_daemon = py_to_daemon_stream_pair
await stream_py.reset()
@ -66,9 +65,9 @@ async def test_net_stream_write_after_local_reset(py_to_daemon_stream_pair, p2pd
@pytest.mark.parametrize("is_to_fail_daemon_stream", (True,))
@pytest.mark.asyncio
@pytest.mark.trio
async def test_net_stream_write_after_remote_reset(py_to_daemon_stream_pair, p2pds):
stream_py, _ = py_to_daemon_stream_pair
await asyncio.sleep(0.01)
await trio.sleep(0.01)
with pytest.raises(StreamClosed):
await stream_py.write(DATA)

View File

@ -1,11 +1,15 @@
import asyncio
import functools
import math
from p2pclient.pb import p2pd_pb2
import pytest
import trio
from libp2p.io.trio import TrioTCPStream
from libp2p.peer.id import ID
from libp2p.pubsub.pb import rpc_pb2
from libp2p.pubsub.subscription import TrioSubscriptionAPI
from libp2p.tools.factories import PubsubFactory
from libp2p.tools.interop.utils import connect
from libp2p.utils import read_varint_prefixed_bytes
@ -13,26 +17,15 @@ TOPIC_0 = "ABALA"
TOPIC_1 = "YOOOO"
async def p2pd_subscribe(p2pd, topic) -> "asyncio.Queue[rpc_pb2.Message]":
reader, writer = await p2pd.control.pubsub_subscribe(topic)
async def p2pd_subscribe(p2pd, topic, nursery):
stream = TrioTCPStream(await p2pd.control.pubsub_subscribe(topic))
send_channel, receive_channel = trio.open_memory_channel(math.inf)
queue = asyncio.Queue()
sub = TrioSubscriptionAPI(receive_channel, unsubscribe_fn=stream.close)
async def _read_pubsub_msg() -> None:
writer_closed_task = asyncio.ensure_future(writer.wait_closed())
while True:
done, pending = await asyncio.wait(
[read_varint_prefixed_bytes(reader), writer_closed_task],
return_when=asyncio.FIRST_COMPLETED,
)
done_tasks = tuple(done)
if writer.is_closing():
return
read_task = done_tasks[0]
# Sanity check
assert read_task._coro.__name__ == "read_varint_prefixed_bytes"
msg_bytes = read_task.result()
msg_bytes = await read_varint_prefixed_bytes(stream)
ps_msg = p2pd_pb2.PSMessage()
ps_msg.ParseFromString(msg_bytes)
# Fill in the message used in py-libp2p
@ -44,11 +37,10 @@ async def p2pd_subscribe(p2pd, topic) -> "asyncio.Queue[rpc_pb2.Message]":
signature=ps_msg.signature,
key=ps_msg.key,
)
queue.put_nowait(msg)
await send_channel.send(msg)
asyncio.ensure_future(_read_pubsub_msg())
await asyncio.sleep(0)
return queue
nursery.start_soon(_read_pubsub_msg)
return sub
def validate_pubsub_msg(msg: rpc_pb2.Message, data: bytes, from_peer_id: ID) -> None:
@ -59,108 +51,119 @@ def validate_pubsub_msg(msg: rpc_pb2.Message, data: bytes, from_peer_id: ID) ->
"is_pubsub_signing, is_pubsub_signing_strict", ((True, True), (False, False))
)
@pytest.mark.parametrize("is_gossipsub", (True, False))
@pytest.mark.parametrize("num_hosts, num_p2pds", ((1, 2),))
@pytest.mark.asyncio
async def test_pubsub(pubsubs, p2pds):
#
# Test: Recognize pubsub peers on connection.
#
py_pubsub = pubsubs[0]
# go0 <-> py <-> go1
await connect(p2pds[0], py_pubsub.host)
await connect(py_pubsub.host, p2pds[1])
py_peer_id = py_pubsub.host.get_id()
# Check pubsub peers
pubsub_peers_0 = await p2pds[0].control.pubsub_list_peers("")
assert len(pubsub_peers_0) == 1 and pubsub_peers_0[0] == py_peer_id
pubsub_peers_1 = await p2pds[1].control.pubsub_list_peers("")
assert len(pubsub_peers_1) == 1 and pubsub_peers_1[0] == py_peer_id
assert (
len(py_pubsub.peers) == 2
and p2pds[0].peer_id in py_pubsub.peers
and p2pds[1].peer_id in py_pubsub.peers
)
@pytest.mark.parametrize("num_p2pds", (2,))
@pytest.mark.trio
async def test_pubsub(
p2pds, is_gossipsub, is_host_secure, is_pubsub_signing_strict, nursery
):
pubsub_factory = None
if is_gossipsub:
pubsub_factory = PubsubFactory.create_batch_with_gossipsub
else:
pubsub_factory = PubsubFactory.create_batch_with_floodsub
#
# Test: `subscribe`.
#
# (name, topics)
# (go_0, [0, 1]) <-> (py, [0, 1]) <-> (go_1, [1])
sub_py_topic_0 = await py_pubsub.subscribe(TOPIC_0)
sub_py_topic_1 = await py_pubsub.subscribe(TOPIC_1)
sub_go_0_topic_0 = await p2pd_subscribe(p2pds[0], TOPIC_0)
sub_go_0_topic_1 = await p2pd_subscribe(p2pds[0], TOPIC_1)
sub_go_1_topic_1 = await p2pd_subscribe(p2pds[1], TOPIC_1)
# Check topic peers
await asyncio.sleep(0.1)
# go_0
go_0_topic_0_peers = await p2pds[0].control.pubsub_list_peers(TOPIC_0)
assert len(go_0_topic_0_peers) == 1 and py_peer_id == go_0_topic_0_peers[0]
go_0_topic_1_peers = await p2pds[0].control.pubsub_list_peers(TOPIC_1)
assert len(go_0_topic_1_peers) == 1 and py_peer_id == go_0_topic_1_peers[0]
# py
py_topic_0_peers = list(py_pubsub.peer_topics[TOPIC_0])
assert len(py_topic_0_peers) == 1 and p2pds[0].peer_id == py_topic_0_peers[0]
# go_1
go_1_topic_1_peers = await p2pds[1].control.pubsub_list_peers(TOPIC_1)
assert len(go_1_topic_1_peers) == 1 and py_peer_id == go_1_topic_1_peers[0]
async with pubsub_factory(
1, is_secure=is_host_secure, strict_signing=is_pubsub_signing_strict
) as pubsubs:
#
# Test: Recognize pubsub peers on connection.
#
py_pubsub = pubsubs[0]
# go0 <-> py <-> go1
await connect(p2pds[0], py_pubsub.host)
await connect(py_pubsub.host, p2pds[1])
py_peer_id = py_pubsub.host.get_id()
# Check pubsub peers
pubsub_peers_0 = await p2pds[0].control.pubsub_list_peers("")
assert len(pubsub_peers_0) == 1 and pubsub_peers_0[0] == py_peer_id
pubsub_peers_1 = await p2pds[1].control.pubsub_list_peers("")
assert len(pubsub_peers_1) == 1 and pubsub_peers_1[0] == py_peer_id
assert (
len(py_pubsub.peers) == 2
and p2pds[0].peer_id in py_pubsub.peers
and p2pds[1].peer_id in py_pubsub.peers
)
#
# Test: `publish`
#
# 1. py publishes
# - 1.1. py publishes data_11 to topic_0, py and go_0 receives.
# - 1.2. py publishes data_12 to topic_1, all receive.
# 2. go publishes
# - 2.1. go_0 publishes data_21 to topic_0, py and go_0 receive.
# - 2.2. go_1 publishes data_22 to topic_1, all receive.
#
# Test: `subscribe`.
#
# (name, topics)
# (go_0, [0, 1]) <-> (py, [0, 1]) <-> (go_1, [1])
sub_py_topic_0 = await py_pubsub.subscribe(TOPIC_0)
sub_py_topic_1 = await py_pubsub.subscribe(TOPIC_1)
sub_go_0_topic_0 = await p2pd_subscribe(p2pds[0], TOPIC_0, nursery)
sub_go_0_topic_1 = await p2pd_subscribe(p2pds[0], TOPIC_1, nursery)
sub_go_1_topic_1 = await p2pd_subscribe(p2pds[1], TOPIC_1, nursery)
# Check topic peers
await trio.sleep(0.1)
# go_0
go_0_topic_0_peers = await p2pds[0].control.pubsub_list_peers(TOPIC_0)
assert len(go_0_topic_0_peers) == 1 and py_peer_id == go_0_topic_0_peers[0]
go_0_topic_1_peers = await p2pds[0].control.pubsub_list_peers(TOPIC_1)
assert len(go_0_topic_1_peers) == 1 and py_peer_id == go_0_topic_1_peers[0]
# py
py_topic_0_peers = list(py_pubsub.peer_topics[TOPIC_0])
assert len(py_topic_0_peers) == 1 and p2pds[0].peer_id == py_topic_0_peers[0]
# go_1
go_1_topic_1_peers = await p2pds[1].control.pubsub_list_peers(TOPIC_1)
assert len(go_1_topic_1_peers) == 1 and py_peer_id == go_1_topic_1_peers[0]
# 1.1. py publishes data_11 to topic_0, py and go_0 receives.
data_11 = b"data_11"
await py_pubsub.publish(TOPIC_0, data_11)
validate_11 = functools.partial(
validate_pubsub_msg, data=data_11, from_peer_id=py_peer_id
)
validate_11(await sub_py_topic_0.get())
validate_11(await sub_go_0_topic_0.get())
#
# Test: `publish`
#
# 1. py publishes
# - 1.1. py publishes data_11 to topic_0, py and go_0 receives.
# - 1.2. py publishes data_12 to topic_1, all receive.
# 2. go publishes
# - 2.1. go_0 publishes data_21 to topic_0, py and go_0 receive.
# - 2.2. go_1 publishes data_22 to topic_1, all receive.
# 1.2. py publishes data_12 to topic_1, all receive.
data_12 = b"data_12"
validate_12 = functools.partial(
validate_pubsub_msg, data=data_12, from_peer_id=py_peer_id
)
await py_pubsub.publish(TOPIC_1, data_12)
validate_12(await sub_py_topic_1.get())
validate_12(await sub_go_0_topic_1.get())
validate_12(await sub_go_1_topic_1.get())
# 1.1. py publishes data_11 to topic_0, py and go_0 receives.
data_11 = b"data_11"
await py_pubsub.publish(TOPIC_0, data_11)
validate_11 = functools.partial(
validate_pubsub_msg, data=data_11, from_peer_id=py_peer_id
)
validate_11(await sub_py_topic_0.get())
validate_11(await sub_go_0_topic_0.get())
# 2.1. go_0 publishes data_21 to topic_0, py and go_0 receive.
data_21 = b"data_21"
validate_21 = functools.partial(
validate_pubsub_msg, data=data_21, from_peer_id=p2pds[0].peer_id
)
await p2pds[0].control.pubsub_publish(TOPIC_0, data_21)
validate_21(await sub_py_topic_0.get())
validate_21(await sub_go_0_topic_0.get())
# 1.2. py publishes data_12 to topic_1, all receive.
data_12 = b"data_12"
validate_12 = functools.partial(
validate_pubsub_msg, data=data_12, from_peer_id=py_peer_id
)
await py_pubsub.publish(TOPIC_1, data_12)
validate_12(await sub_py_topic_1.get())
validate_12(await sub_go_0_topic_1.get())
validate_12(await sub_go_1_topic_1.get())
# 2.2. go_1 publishes data_22 to topic_1, all receive.
data_22 = b"data_22"
validate_22 = functools.partial(
validate_pubsub_msg, data=data_22, from_peer_id=p2pds[1].peer_id
)
await p2pds[1].control.pubsub_publish(TOPIC_1, data_22)
validate_22(await sub_py_topic_1.get())
validate_22(await sub_go_0_topic_1.get())
validate_22(await sub_go_1_topic_1.get())
# 2.1. go_0 publishes data_21 to topic_0, py and go_0 receive.
data_21 = b"data_21"
validate_21 = functools.partial(
validate_pubsub_msg, data=data_21, from_peer_id=p2pds[0].peer_id
)
await p2pds[0].control.pubsub_publish(TOPIC_0, data_21)
validate_21(await sub_py_topic_0.get())
validate_21(await sub_go_0_topic_0.get())
#
# Test: `unsubscribe` and re`subscribe`
#
await py_pubsub.unsubscribe(TOPIC_0)
await asyncio.sleep(0.1)
assert py_peer_id not in (await p2pds[0].control.pubsub_list_peers(TOPIC_0))
assert py_peer_id not in (await p2pds[1].control.pubsub_list_peers(TOPIC_0))
await py_pubsub.subscribe(TOPIC_0)
await asyncio.sleep(0.1)
assert py_peer_id in (await p2pds[0].control.pubsub_list_peers(TOPIC_0))
assert py_peer_id in (await p2pds[1].control.pubsub_list_peers(TOPIC_0))
# 2.2. go_1 publishes data_22 to topic_1, all receive.
data_22 = b"data_22"
validate_22 = functools.partial(
validate_pubsub_msg, data=data_22, from_peer_id=p2pds[1].peer_id
)
await p2pds[1].control.pubsub_publish(TOPIC_1, data_22)
validate_22(await sub_py_topic_1.get())
validate_22(await sub_go_0_topic_1.get())
validate_22(await sub_go_1_topic_1.get())
#
# Test: `unsubscribe` and re`subscribe`
#
await py_pubsub.unsubscribe(TOPIC_0)
await trio.sleep(0.1)
assert py_peer_id not in (await p2pds[0].control.pubsub_list_peers(TOPIC_0))
assert py_peer_id not in (await p2pds[1].control.pubsub_list_peers(TOPIC_0))
await py_pubsub.subscribe(TOPIC_0)
await trio.sleep(0.1)
assert py_peer_id in (await p2pds[0].control.pubsub_list_peers(TOPIC_0))
assert py_peer_id in (await p2pds[1].control.pubsub_list_peers(TOPIC_0))

View File

@ -12,7 +12,7 @@ envlist =
combine_as_imports=False
force_sort_within_sections=True
include_trailing_comma=True
known_third_party=hypothesis,pytest,p2pclient,pexpect,factory,lru
known_third_party=anyio,factory,lru,p2pclient,pytest
known_first_party=libp2p
line_length=88
multi_line_output=3
@ -58,7 +58,6 @@ commands =
[testenv:py37-interop]
deps =
p2pclient
pexpect
passenv = CI TRAVIS TRAVIS_* GOPATH
extras = test
commands =