commit
e63584c387
|
@ -5,16 +5,16 @@ matrix:
|
|||
- python: 3.6-dev
|
||||
dist: xenial
|
||||
env: TOXENV=py36-test
|
||||
- python: 3.7-dev
|
||||
- python: 3.7
|
||||
dist: xenial
|
||||
env: TOXENV=py37-test
|
||||
- python: 3.7-dev
|
||||
- python: 3.7
|
||||
dist: xenial
|
||||
env: TOXENV=lint
|
||||
- python: 3.7-dev
|
||||
- python: 3.7
|
||||
dist: xenial
|
||||
env: TOXENV=docs
|
||||
- python: 3.7-dev
|
||||
- python: 3.7
|
||||
dist: xenial
|
||||
env: TOXENV=py37-interop
|
||||
sudo: true
|
||||
|
|
2
Makefile
2
Makefile
|
@ -51,7 +51,7 @@ lint:
|
|||
black --check $(FILES_TO_LINT)
|
||||
isort --recursive --check-only --diff $(FILES_TO_LINT)
|
||||
docformatter --pre-summary-newline --check --recursive $(FILES_TO_LINT)
|
||||
tox -elint # This is probably redundant, but just in case...
|
||||
tox -e lint # This is probably redundant, but just in case...
|
||||
|
||||
lint-roll:
|
||||
isort --recursive $(FILES_TO_LINT)
|
||||
|
|
|
@ -11,6 +11,22 @@ Subpackages
|
|||
Submodules
|
||||
----------
|
||||
|
||||
libp2p.pubsub.abc module
|
||||
------------------------
|
||||
|
||||
.. automodule:: libp2p.pubsub.abc
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
libp2p.pubsub.exceptions module
|
||||
-------------------------------
|
||||
|
||||
.. automodule:: libp2p.pubsub.exceptions
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
libp2p.pubsub.floodsub module
|
||||
-----------------------------
|
||||
|
||||
|
@ -51,10 +67,10 @@ libp2p.pubsub.pubsub\_notifee module
|
|||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
libp2p.pubsub.pubsub\_router\_interface module
|
||||
----------------------------------------------
|
||||
libp2p.pubsub.subscription module
|
||||
---------------------------------
|
||||
|
||||
.. automodule:: libp2p.pubsub.pubsub_router_interface
|
||||
.. automodule:: libp2p.pubsub.subscription
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
|
|
@ -1,11 +1,10 @@
|
|||
import argparse
|
||||
import asyncio
|
||||
import sys
|
||||
import urllib.request
|
||||
|
||||
import multiaddr
|
||||
import trio
|
||||
|
||||
from libp2p import new_node
|
||||
from libp2p import new_host
|
||||
from libp2p.network.stream.net_stream_interface import INetStream
|
||||
from libp2p.peer.peerinfo import info_from_p2p_addr
|
||||
from libp2p.typing import TProtocol
|
||||
|
@ -26,53 +25,47 @@ async def read_data(stream: INetStream) -> None:
|
|||
|
||||
|
||||
async def write_data(stream: INetStream) -> None:
|
||||
loop = asyncio.get_event_loop()
|
||||
async_f = trio.wrap_file(sys.stdin)
|
||||
while True:
|
||||
line = await loop.run_in_executor(None, sys.stdin.readline)
|
||||
line = await async_f.readline()
|
||||
await stream.write(line.encode())
|
||||
|
||||
|
||||
async def run(port: int, destination: str, localhost: bool) -> None:
|
||||
if localhost:
|
||||
ip = "127.0.0.1"
|
||||
else:
|
||||
ip = urllib.request.urlopen("https://v4.ident.me/").read().decode("utf8")
|
||||
transport_opt = f"/ip4/{ip}/tcp/{port}"
|
||||
host = await new_node(transport_opt=[transport_opt])
|
||||
async def run(port: int, destination: str) -> None:
|
||||
localhost_ip = "127.0.0.1"
|
||||
listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
|
||||
host = new_host()
|
||||
async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery:
|
||||
if not destination: # its the server
|
||||
|
||||
await host.get_network().listen(multiaddr.Multiaddr(transport_opt))
|
||||
async def stream_handler(stream: INetStream) -> None:
|
||||
nursery.start_soon(read_data, stream)
|
||||
nursery.start_soon(write_data, stream)
|
||||
|
||||
if not destination: # its the server
|
||||
host.set_stream_handler(PROTOCOL_ID, stream_handler)
|
||||
|
||||
async def stream_handler(stream: INetStream) -> None:
|
||||
asyncio.ensure_future(read_data(stream))
|
||||
asyncio.ensure_future(write_data(stream))
|
||||
print(
|
||||
f"Run 'python ./examples/chat/chat.py "
|
||||
f"-p {int(port) + 1} "
|
||||
f"-d /ip4/{localhost_ip}/tcp/{port}/p2p/{host.get_id().pretty()}' "
|
||||
"on another console."
|
||||
)
|
||||
print("Waiting for incoming connection...")
|
||||
|
||||
host.set_stream_handler(PROTOCOL_ID, stream_handler)
|
||||
else: # its the client
|
||||
maddr = multiaddr.Multiaddr(destination)
|
||||
info = info_from_p2p_addr(maddr)
|
||||
# Associate the peer with local ip address
|
||||
await host.connect(info)
|
||||
# Start a stream with the destination.
|
||||
# Multiaddress of the destination peer is fetched from the peerstore using 'peerId'.
|
||||
stream = await host.new_stream(info.peer_id, [PROTOCOL_ID])
|
||||
|
||||
localhost_opt = " --localhost" if localhost else ""
|
||||
nursery.start_soon(read_data, stream)
|
||||
nursery.start_soon(write_data, stream)
|
||||
print(f"Connected to peer {info.addrs[0]}")
|
||||
|
||||
print(
|
||||
f"Run 'python ./examples/chat/chat.py"
|
||||
+ localhost_opt
|
||||
+ f" -p {int(port) + 1} -d /ip4/{ip}/tcp/{port}/p2p/{host.get_id().pretty()}'"
|
||||
+ " on another console."
|
||||
)
|
||||
print("Waiting for incoming connection...")
|
||||
|
||||
else: # its the client
|
||||
maddr = multiaddr.Multiaddr(destination)
|
||||
info = info_from_p2p_addr(maddr)
|
||||
# Associate the peer with local ip address
|
||||
await host.connect(info)
|
||||
|
||||
# Start a stream with the destination.
|
||||
# Multiaddress of the destination peer is fetched from the peerstore using 'peerId'.
|
||||
stream = await host.new_stream(info.peer_id, [PROTOCOL_ID])
|
||||
|
||||
asyncio.ensure_future(read_data(stream))
|
||||
asyncio.ensure_future(write_data(stream))
|
||||
print("Connected to peer %s" % info.addrs[0])
|
||||
await trio.sleep_forever()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
|
@ -86,11 +79,6 @@ def main() -> None:
|
|||
"/ip4/127.0.0.1/tcp/8000/p2p/QmQn4SwGkDZKkUEpBRBvTmheQycxAHJUNmVEnjA2v1qe8Q"
|
||||
)
|
||||
parser = argparse.ArgumentParser(description=description)
|
||||
parser.add_argument(
|
||||
"--debug",
|
||||
action="store_true",
|
||||
help="generate the same node ID on every execution",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-p", "--port", default=8000, type=int, help="source port number"
|
||||
)
|
||||
|
@ -100,26 +88,15 @@ def main() -> None:
|
|||
type=str,
|
||||
help=f"destination multiaddr string, e.g. {example_maddr}",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-l",
|
||||
"--localhost",
|
||||
dest="localhost",
|
||||
action="store_true",
|
||||
help="flag indicating if localhost should be used or an external IP",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args.port:
|
||||
raise RuntimeError("was not able to determine a local port")
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
try:
|
||||
asyncio.ensure_future(run(args.port, args.destination, args.localhost))
|
||||
loop.run_forever()
|
||||
trio.run(run, *(args.port, args.destination))
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -1,10 +1,9 @@
|
|||
import argparse
|
||||
import asyncio
|
||||
import urllib.request
|
||||
|
||||
import multiaddr
|
||||
import trio
|
||||
|
||||
from libp2p import new_node
|
||||
from libp2p import new_host
|
||||
from libp2p.crypto.secp256k1 import create_new_key_pair
|
||||
from libp2p.network.stream.net_stream_interface import INetStream
|
||||
from libp2p.peer.peerinfo import info_from_p2p_addr
|
||||
|
@ -20,12 +19,9 @@ async def _echo_stream_handler(stream: INetStream) -> None:
|
|||
await stream.close()
|
||||
|
||||
|
||||
async def run(port: int, destination: str, localhost: bool, seed: int = None) -> None:
|
||||
if localhost:
|
||||
ip = "127.0.0.1"
|
||||
else:
|
||||
ip = urllib.request.urlopen("https://v4.ident.me/").read().decode("utf8")
|
||||
transport_opt = f"/ip4/{ip}/tcp/{port}"
|
||||
async def run(port: int, destination: str, seed: int = None) -> None:
|
||||
localhost_ip = "127.0.0.1"
|
||||
listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
|
||||
|
||||
if seed:
|
||||
import random
|
||||
|
@ -38,47 +34,43 @@ async def run(port: int, destination: str, localhost: bool, seed: int = None) ->
|
|||
|
||||
secret = secrets.token_bytes(32)
|
||||
|
||||
host = await new_node(
|
||||
key_pair=create_new_key_pair(secret), transport_opt=[transport_opt]
|
||||
)
|
||||
host = new_host(key_pair=create_new_key_pair(secret))
|
||||
async with host.run(listen_addrs=[listen_addr]):
|
||||
|
||||
print(f"I am {host.get_id().to_string()}")
|
||||
print(f"I am {host.get_id().to_string()}")
|
||||
|
||||
await host.get_network().listen(multiaddr.Multiaddr(transport_opt))
|
||||
if not destination: # its the server
|
||||
|
||||
if not destination: # its the server
|
||||
host.set_stream_handler(PROTOCOL_ID, _echo_stream_handler)
|
||||
|
||||
host.set_stream_handler(PROTOCOL_ID, _echo_stream_handler)
|
||||
print(
|
||||
f"Run 'python ./examples/echo/echo.py "
|
||||
f"-p {int(port) + 1} "
|
||||
f"-d /ip4/{localhost_ip}/tcp/{port}/p2p/{host.get_id().pretty()}' "
|
||||
"on another console."
|
||||
)
|
||||
print("Waiting for incoming connections...")
|
||||
await trio.sleep_forever()
|
||||
|
||||
localhost_opt = " --localhost" if localhost else ""
|
||||
else: # its the client
|
||||
maddr = multiaddr.Multiaddr(destination)
|
||||
info = info_from_p2p_addr(maddr)
|
||||
# Associate the peer with local ip address
|
||||
await host.connect(info)
|
||||
|
||||
print(
|
||||
f"Run 'python ./examples/echo/echo.py"
|
||||
+ localhost_opt
|
||||
+ f" -p {int(port) + 1} -d /ip4/{ip}/tcp/{port}/p2p/{host.get_id().pretty()}'"
|
||||
+ " on another console."
|
||||
)
|
||||
print("Waiting for incoming connections...")
|
||||
# Start a stream with the destination.
|
||||
# Multiaddress of the destination peer is fetched from the peerstore using 'peerId'.
|
||||
stream = await host.new_stream(info.peer_id, [PROTOCOL_ID])
|
||||
|
||||
else: # its the client
|
||||
maddr = multiaddr.Multiaddr(destination)
|
||||
info = info_from_p2p_addr(maddr)
|
||||
# Associate the peer with local ip address
|
||||
await host.connect(info)
|
||||
msg = b"hi, there!\n"
|
||||
|
||||
# Start a stream with the destination.
|
||||
# Multiaddress of the destination peer is fetched from the peerstore using 'peerId'.
|
||||
stream = await host.new_stream(info.peer_id, [PROTOCOL_ID])
|
||||
await stream.write(msg)
|
||||
# Notify the other side about EOF
|
||||
await stream.close()
|
||||
response = await stream.read()
|
||||
|
||||
msg = b"hi, there!\n"
|
||||
|
||||
await stream.write(msg)
|
||||
# Notify the other side about EOF
|
||||
await stream.close()
|
||||
response = await stream.read()
|
||||
|
||||
print(f"Sent: {msg}")
|
||||
print(f"Got: {response}")
|
||||
print(f"Sent: {msg}")
|
||||
print(f"Got: {response}")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
|
@ -94,11 +86,6 @@ def main() -> None:
|
|||
"/ip4/127.0.0.1/tcp/8000/p2p/QmQn4SwGkDZKkUEpBRBvTmheQycxAHJUNmVEnjA2v1qe8Q"
|
||||
)
|
||||
parser = argparse.ArgumentParser(description=description)
|
||||
parser.add_argument(
|
||||
"--debug",
|
||||
action="store_true",
|
||||
help="generate the same node ID on every execution",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-p", "--port", default=8000, type=int, help="source port number"
|
||||
)
|
||||
|
@ -108,13 +95,6 @@ def main() -> None:
|
|||
type=str,
|
||||
help=f"destination multiaddr string, e.g. {example_maddr}",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-l",
|
||||
"--localhost",
|
||||
dest="localhost",
|
||||
action="store_true",
|
||||
help="flag indicating if localhost should be used or an external IP",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-s",
|
||||
"--seed",
|
||||
|
@ -126,16 +106,10 @@ def main() -> None:
|
|||
if not args.port:
|
||||
raise RuntimeError("was not able to determine a local port")
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
try:
|
||||
asyncio.ensure_future(
|
||||
run(args.port, args.destination, args.localhost, args.seed)
|
||||
)
|
||||
loop.run_forever()
|
||||
trio.run(run, args.port, args.destination, args.seed)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -1,12 +1,9 @@
|
|||
import asyncio
|
||||
from typing import Sequence
|
||||
|
||||
from libp2p.crypto.keys import KeyPair
|
||||
from libp2p.crypto.rsa import create_new_key_pair
|
||||
from libp2p.host.basic_host import BasicHost
|
||||
from libp2p.host.host_interface import IHost
|
||||
from libp2p.host.routed_host import RoutedHost
|
||||
from libp2p.network.network_interface import INetwork
|
||||
from libp2p.network.network_interface import INetworkService
|
||||
from libp2p.network.swarm import Swarm
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.peer.peerstore import PeerStore
|
||||
|
@ -21,18 +18,6 @@ from libp2p.transport.upgrader import TransportUpgrader
|
|||
from libp2p.typing import TProtocol
|
||||
|
||||
|
||||
async def cleanup_done_tasks() -> None:
|
||||
"""clean up asyncio done tasks to free up resources."""
|
||||
while True:
|
||||
for task in asyncio.all_tasks():
|
||||
if task.done():
|
||||
await task
|
||||
|
||||
# Need not run often
|
||||
# Some sleep necessary to context switch
|
||||
await asyncio.sleep(3)
|
||||
|
||||
|
||||
def generate_new_rsa_identity() -> KeyPair:
|
||||
return create_new_key_pair()
|
||||
|
||||
|
@ -42,29 +27,28 @@ def generate_peer_id_from(key_pair: KeyPair) -> ID:
|
|||
return ID.from_pubkey(public_key)
|
||||
|
||||
|
||||
def initialize_default_swarm(
|
||||
key_pair: KeyPair,
|
||||
id_opt: ID = None,
|
||||
transport_opt: Sequence[str] = None,
|
||||
def new_swarm(
|
||||
key_pair: KeyPair = None,
|
||||
muxer_opt: TMuxerOptions = None,
|
||||
sec_opt: TSecurityOptions = None,
|
||||
peerstore_opt: IPeerStore = None,
|
||||
) -> Swarm:
|
||||
) -> INetworkService:
|
||||
"""
|
||||
initialize swarm when no swarm is passed in.
|
||||
Create a swarm instance based on the parameters.
|
||||
|
||||
:param id_opt: optional id for host
|
||||
:param transport_opt: optional choice of transport upgrade
|
||||
:param key_pair: optional choice of the ``KeyPair``
|
||||
:param muxer_opt: optional choice of stream muxer
|
||||
:param sec_opt: optional choice of security upgrade
|
||||
:param peerstore_opt: optional peerstore
|
||||
:return: return a default swarm instance
|
||||
"""
|
||||
|
||||
if not id_opt:
|
||||
id_opt = generate_peer_id_from(key_pair)
|
||||
if key_pair is None:
|
||||
key_pair = generate_new_rsa_identity()
|
||||
|
||||
# TODO: Parse `transport_opt` to determine transport
|
||||
id_opt = generate_peer_id_from(key_pair)
|
||||
|
||||
# TODO: Parse `listen_addrs` to determine transport
|
||||
transport = TCP()
|
||||
|
||||
muxer_transports_by_protocol = muxer_opt or {MPLEX_PROTOCOL_ID: Mplex}
|
||||
|
@ -80,57 +64,35 @@ def initialize_default_swarm(
|
|||
# Store our key pair in peerstore
|
||||
peerstore.add_key_pair(id_opt, key_pair)
|
||||
|
||||
# TODO: Initialize discovery if not presented
|
||||
return Swarm(id_opt, peerstore, upgrader, transport)
|
||||
|
||||
|
||||
async def new_node(
|
||||
def new_host(
|
||||
key_pair: KeyPair = None,
|
||||
swarm_opt: INetwork = None,
|
||||
transport_opt: Sequence[str] = None,
|
||||
muxer_opt: TMuxerOptions = None,
|
||||
sec_opt: TSecurityOptions = None,
|
||||
peerstore_opt: IPeerStore = None,
|
||||
disc_opt: IPeerRouting = None,
|
||||
) -> BasicHost:
|
||||
) -> IHost:
|
||||
"""
|
||||
create new libp2p node.
|
||||
Create a new libp2p host based on the given parameters.
|
||||
|
||||
:param key_pair: key pair for deriving an identity
|
||||
:param swarm_opt: optional swarm
|
||||
:param id_opt: optional id for host
|
||||
:param transport_opt: optional choice of transport upgrade
|
||||
:param key_pair: optional choice of the ``KeyPair``
|
||||
:param muxer_opt: optional choice of stream muxer
|
||||
:param sec_opt: optional choice of security upgrade
|
||||
:param peerstore_opt: optional peerstore
|
||||
:param disc_opt: optional discovery
|
||||
:return: return a host instance
|
||||
"""
|
||||
|
||||
if not key_pair:
|
||||
key_pair = generate_new_rsa_identity()
|
||||
|
||||
id_opt = generate_peer_id_from(key_pair)
|
||||
|
||||
if not swarm_opt:
|
||||
swarm_opt = initialize_default_swarm(
|
||||
key_pair=key_pair,
|
||||
id_opt=id_opt,
|
||||
transport_opt=transport_opt,
|
||||
muxer_opt=muxer_opt,
|
||||
sec_opt=sec_opt,
|
||||
peerstore_opt=peerstore_opt,
|
||||
)
|
||||
|
||||
# TODO enable support for other host type
|
||||
# TODO routing unimplemented
|
||||
host: IHost # If not explicitly typed, MyPy raises error
|
||||
swarm = new_swarm(
|
||||
key_pair=key_pair,
|
||||
muxer_opt=muxer_opt,
|
||||
sec_opt=sec_opt,
|
||||
peerstore_opt=peerstore_opt,
|
||||
)
|
||||
host: IHost
|
||||
if disc_opt:
|
||||
host = RoutedHost(swarm_opt, disc_opt)
|
||||
host = RoutedHost(swarm, disc_opt)
|
||||
else:
|
||||
host = BasicHost(swarm_opt)
|
||||
|
||||
# Kick off cleanup job
|
||||
asyncio.ensure_future(cleanup_done_tasks())
|
||||
|
||||
host = BasicHost(swarm)
|
||||
return host
|
||||
|
|
|
@ -1,12 +1,14 @@
|
|||
import logging
|
||||
from typing import TYPE_CHECKING, List, Sequence
|
||||
from typing import TYPE_CHECKING, AsyncIterator, List, Sequence
|
||||
|
||||
from async_generator import asynccontextmanager
|
||||
from async_service import background_trio_service
|
||||
import multiaddr
|
||||
|
||||
from libp2p.crypto.keys import PrivateKey, PublicKey
|
||||
from libp2p.host.defaults import get_default_protocols
|
||||
from libp2p.host.exceptions import StreamFailure
|
||||
from libp2p.network.network_interface import INetwork
|
||||
from libp2p.network.network_interface import INetworkService
|
||||
from libp2p.network.stream.net_stream_interface import INetStream
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.peer.peerinfo import PeerInfo
|
||||
|
@ -39,7 +41,7 @@ class BasicHost(IHost):
|
|||
right after a stream is initialized.
|
||||
"""
|
||||
|
||||
_network: INetwork
|
||||
_network: INetworkService
|
||||
peerstore: IPeerStore
|
||||
|
||||
multiselect: Multiselect
|
||||
|
@ -47,7 +49,7 @@ class BasicHost(IHost):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
network: INetwork,
|
||||
network: INetworkService,
|
||||
default_protocols: "OrderedDict[TProtocol, StreamHandlerFn]" = None,
|
||||
) -> None:
|
||||
self._network = network
|
||||
|
@ -70,7 +72,7 @@ class BasicHost(IHost):
|
|||
def get_private_key(self) -> PrivateKey:
|
||||
return self.peerstore.privkey(self.get_id())
|
||||
|
||||
def get_network(self) -> INetwork:
|
||||
def get_network(self) -> INetworkService:
|
||||
"""
|
||||
:return: network instance of host
|
||||
"""
|
||||
|
@ -101,6 +103,20 @@ class BasicHost(IHost):
|
|||
addrs.append(addr.encapsulate(p2p_part))
|
||||
return addrs
|
||||
|
||||
@asynccontextmanager
|
||||
async def run(
|
||||
self, listen_addrs: Sequence[multiaddr.Multiaddr]
|
||||
) -> AsyncIterator[None]:
|
||||
"""
|
||||
run the host instance and listen to ``listen_addrs``.
|
||||
|
||||
:param listen_addrs: a sequence of multiaddrs that we want to listen to
|
||||
"""
|
||||
network = self.get_network()
|
||||
async with background_trio_service(network):
|
||||
await network.listen(*listen_addrs)
|
||||
yield
|
||||
|
||||
def set_stream_handler(
|
||||
self, protocol_id: TProtocol, stream_handler: StreamHandlerFn
|
||||
) -> None:
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Any, List, Sequence
|
||||
from typing import Any, AsyncContextManager, List, Sequence
|
||||
|
||||
import multiaddr
|
||||
|
||||
from libp2p.crypto.keys import PrivateKey, PublicKey
|
||||
from libp2p.network.network_interface import INetwork
|
||||
from libp2p.network.network_interface import INetworkService
|
||||
from libp2p.network.stream.net_stream_interface import INetStream
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.peer.peerinfo import PeerInfo
|
||||
|
@ -31,7 +31,7 @@ class IHost(ABC):
|
|||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_network(self) -> INetwork:
|
||||
def get_network(self) -> INetworkService:
|
||||
"""
|
||||
:return: network instance of host
|
||||
"""
|
||||
|
@ -49,6 +49,16 @@ class IHost(ABC):
|
|||
:return: all the multiaddr addresses this host is listening to
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def run(
|
||||
self, listen_addrs: Sequence[multiaddr.Multiaddr]
|
||||
) -> AsyncContextManager[None]:
|
||||
"""
|
||||
run the host instance and listen to ``listen_addrs``.
|
||||
|
||||
:param listen_addrs: a sequence of multiaddrs that we want to listen to
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def set_stream_handler(
|
||||
self, protocol_id: TProtocol, stream_handler: StreamHandlerFn
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import asyncio
|
||||
import logging
|
||||
|
||||
import trio
|
||||
|
||||
from libp2p.network.stream.exceptions import StreamClosed, StreamEOF, StreamReset
|
||||
from libp2p.network.stream.net_stream_interface import INetStream
|
||||
from libp2p.peer.id import ID as PeerID
|
||||
|
@ -17,8 +18,9 @@ async def _handle_ping(stream: INetStream, peer_id: PeerID) -> bool:
|
|||
"""Return a boolean indicating if we expect more pings from the peer at
|
||||
``peer_id``."""
|
||||
try:
|
||||
payload = await asyncio.wait_for(stream.read(PING_LENGTH), RESP_TIMEOUT)
|
||||
except asyncio.TimeoutError as error:
|
||||
with trio.fail_after(RESP_TIMEOUT):
|
||||
payload = await stream.read(PING_LENGTH)
|
||||
except trio.TooSlowError as error:
|
||||
logger.debug("Timed out waiting for ping from %s: %s", peer_id, error)
|
||||
raise
|
||||
except StreamEOF:
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from libp2p.host.basic_host import BasicHost
|
||||
from libp2p.host.exceptions import ConnectionFailure
|
||||
from libp2p.network.network_interface import INetwork
|
||||
from libp2p.network.network_interface import INetworkService
|
||||
from libp2p.peer.peerinfo import PeerInfo
|
||||
from libp2p.routing.interfaces import IPeerRouting
|
||||
|
||||
|
@ -10,7 +10,7 @@ from libp2p.routing.interfaces import IPeerRouting
|
|||
class RoutedHost(BasicHost):
|
||||
_router: IPeerRouting
|
||||
|
||||
def __init__(self, network: INetwork, router: IPeerRouting):
|
||||
def __init__(self, network: INetworkService, router: IPeerRouting):
|
||||
super().__init__(network)
|
||||
self._router = router
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ class Closer(ABC):
|
|||
|
||||
class Reader(ABC):
|
||||
@abstractmethod
|
||||
async def read(self, n: int = -1) -> bytes:
|
||||
async def read(self, n: int = None) -> bytes:
|
||||
...
|
||||
|
||||
|
||||
|
|
|
@ -54,7 +54,7 @@ class MsgIOReader(ReadCloser):
|
|||
self.read_closer = read_closer
|
||||
self.next_length = None
|
||||
|
||||
async def read(self, n: int = -1) -> bytes:
|
||||
async def read(self, n: int = None) -> bytes:
|
||||
return await self.read_msg()
|
||||
|
||||
async def read_msg(self) -> bytes:
|
||||
|
|
|
@ -0,0 +1,40 @@
|
|||
import logging
|
||||
|
||||
import trio
|
||||
|
||||
from libp2p.io.abc import ReadWriteCloser
|
||||
from libp2p.io.exceptions import IOException
|
||||
|
||||
logger = logging.getLogger("libp2p.io.trio")
|
||||
|
||||
|
||||
class TrioTCPStream(ReadWriteCloser):
|
||||
stream: trio.SocketStream
|
||||
# NOTE: Add both read and write lock to avoid `trio.BusyResourceError`
|
||||
read_lock: trio.Lock
|
||||
write_lock: trio.Lock
|
||||
|
||||
def __init__(self, stream: trio.SocketStream) -> None:
|
||||
self.stream = stream
|
||||
self.read_lock = trio.Lock()
|
||||
self.write_lock = trio.Lock()
|
||||
|
||||
async def write(self, data: bytes) -> None:
|
||||
"""Raise `RawConnError` if the underlying connection breaks."""
|
||||
async with self.write_lock:
|
||||
try:
|
||||
await self.stream.send_all(data)
|
||||
except (trio.ClosedResourceError, trio.BrokenResourceError) as error:
|
||||
raise IOException from error
|
||||
|
||||
async def read(self, n: int = None) -> bytes:
|
||||
async with self.read_lock:
|
||||
if n is not None and n == 0:
|
||||
return b""
|
||||
try:
|
||||
return await self.stream.receive_some(n)
|
||||
except (trio.ClosedResourceError, trio.BrokenResourceError) as error:
|
||||
raise IOException from error
|
||||
|
||||
async def close(self) -> None:
|
||||
await self.stream.aclose()
|
|
@ -1,6 +1,8 @@
|
|||
from abc import abstractmethod
|
||||
from typing import Tuple
|
||||
|
||||
import trio
|
||||
|
||||
from libp2p.io.abc import Closer
|
||||
from libp2p.network.stream.net_stream_interface import INetStream
|
||||
from libp2p.stream_muxer.abc import IMuxedConn
|
||||
|
@ -8,11 +10,12 @@ from libp2p.stream_muxer.abc import IMuxedConn
|
|||
|
||||
class INetConn(Closer):
|
||||
muxed_conn: IMuxedConn
|
||||
event_started: trio.Event
|
||||
|
||||
@abstractmethod
|
||||
async def new_stream(self) -> INetStream:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def get_streams(self) -> Tuple[INetStream, ...]:
|
||||
def get_streams(self) -> Tuple[INetStream, ...]:
|
||||
...
|
||||
|
|
|
@ -1,46 +1,26 @@
|
|||
import asyncio
|
||||
import sys
|
||||
from libp2p.io.abc import ReadWriteCloser
|
||||
from libp2p.io.exceptions import IOException
|
||||
|
||||
from .exceptions import RawConnError
|
||||
from .raw_connection_interface import IRawConnection
|
||||
|
||||
|
||||
class RawConnection(IRawConnection):
|
||||
reader: asyncio.StreamReader
|
||||
writer: asyncio.StreamWriter
|
||||
stream: ReadWriteCloser
|
||||
is_initiator: bool
|
||||
|
||||
_drain_lock: asyncio.Lock
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
reader: asyncio.StreamReader,
|
||||
writer: asyncio.StreamWriter,
|
||||
initiator: bool,
|
||||
) -> None:
|
||||
self.reader = reader
|
||||
self.writer = writer
|
||||
def __init__(self, stream: ReadWriteCloser, initiator: bool) -> None:
|
||||
self.stream = stream
|
||||
self.is_initiator = initiator
|
||||
|
||||
self._drain_lock = asyncio.Lock()
|
||||
|
||||
async def write(self, data: bytes) -> None:
|
||||
"""Raise `RawConnError` if the underlying connection breaks."""
|
||||
# Detect if underlying transport is closing before write data to it
|
||||
# ref: https://github.com/ethereum/trinity/pull/614
|
||||
if self.writer.transport.is_closing():
|
||||
raise RawConnError("Transport is closing")
|
||||
self.writer.write(data)
|
||||
# Reference: https://github.com/ethereum/lahja/blob/93610b2eb46969ff1797e0748c7ac2595e130aef/lahja/asyncio/endpoint.py#L99-L102 # noqa: E501
|
||||
# Use a lock to serialize drain() calls. Circumvents this bug:
|
||||
# https://bugs.python.org/issue29930
|
||||
async with self._drain_lock:
|
||||
try:
|
||||
await self.writer.drain()
|
||||
except ConnectionResetError as error:
|
||||
raise RawConnError() from error
|
||||
try:
|
||||
await self.stream.write(data)
|
||||
except IOException as error:
|
||||
raise RawConnError from error
|
||||
|
||||
async def read(self, n: int = -1) -> bytes:
|
||||
async def read(self, n: int = None) -> bytes:
|
||||
"""
|
||||
Read up to ``n`` bytes from the underlying stream. This call is
|
||||
delegated directly to the underlying ``self.reader``.
|
||||
|
@ -48,18 +28,9 @@ class RawConnection(IRawConnection):
|
|||
Raise `RawConnError` if the underlying connection breaks
|
||||
"""
|
||||
try:
|
||||
return await self.reader.read(n)
|
||||
except ConnectionResetError as error:
|
||||
raise RawConnError() from error
|
||||
return await self.stream.read(n)
|
||||
except IOException as error:
|
||||
raise RawConnError from error
|
||||
|
||||
async def close(self) -> None:
|
||||
if self.writer.transport.is_closing():
|
||||
return
|
||||
self.writer.close()
|
||||
if sys.version_info < (3, 7):
|
||||
return
|
||||
try:
|
||||
await self.writer.wait_closed()
|
||||
# In case the connection is already reset.
|
||||
except ConnectionResetError:
|
||||
return
|
||||
await self.stream.close()
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import asyncio
|
||||
from typing import TYPE_CHECKING, Any, Awaitable, List, Set, Tuple
|
||||
from typing import TYPE_CHECKING, Set, Tuple
|
||||
|
||||
import trio
|
||||
|
||||
from libp2p.network.connection.net_connection_interface import INetConn
|
||||
from libp2p.network.stream.net_stream import NetStream
|
||||
|
@ -19,90 +20,78 @@ class SwarmConn(INetConn):
|
|||
muxed_conn: IMuxedConn
|
||||
swarm: "Swarm"
|
||||
streams: Set[NetStream]
|
||||
event_closed: asyncio.Event
|
||||
|
||||
_tasks: List["asyncio.Future[Any]"]
|
||||
event_closed: trio.Event
|
||||
|
||||
def __init__(self, muxed_conn: IMuxedConn, swarm: "Swarm") -> None:
|
||||
self.muxed_conn = muxed_conn
|
||||
self.swarm = swarm
|
||||
self.streams = set()
|
||||
self.event_closed = asyncio.Event()
|
||||
self.event_closed = trio.Event()
|
||||
self.event_started = trio.Event()
|
||||
|
||||
self._tasks = []
|
||||
@property
|
||||
def is_closed(self) -> bool:
|
||||
return self.event_closed.is_set()
|
||||
|
||||
async def close(self) -> None:
|
||||
if self.event_closed.is_set():
|
||||
return
|
||||
self.event_closed.set()
|
||||
await self._cleanup()
|
||||
|
||||
async def _cleanup(self) -> None:
|
||||
self.swarm.remove_conn(self)
|
||||
|
||||
await self.muxed_conn.close()
|
||||
|
||||
# This is just for cleaning up state. The connection has already been closed.
|
||||
# We *could* optimize this but it really isn't worth it.
|
||||
for stream in self.streams:
|
||||
for stream in self.streams.copy():
|
||||
await stream.reset()
|
||||
# Force context switch for stream handlers to process the stream reset event we just emit
|
||||
# before we cancel the stream handler tasks.
|
||||
await asyncio.sleep(0.1)
|
||||
await trio.sleep(0.1)
|
||||
|
||||
for task in self._tasks:
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
# Schedule `self._notify_disconnected` to make it execute after `close` is finished.
|
||||
self._notify_disconnected()
|
||||
await self._notify_disconnected()
|
||||
|
||||
async def _handle_new_streams(self) -> None:
|
||||
while True:
|
||||
try:
|
||||
stream = await self.muxed_conn.accept_stream()
|
||||
except MuxedConnUnavailable:
|
||||
# If there is anything wrong in the MuxedConn,
|
||||
# we should break the loop and close the connection.
|
||||
break
|
||||
# Asynchronously handle the accepted stream, to avoid blocking the next stream.
|
||||
await self.run_task(self._handle_muxed_stream(stream))
|
||||
|
||||
await self.close()
|
||||
|
||||
async def _call_stream_handler(self, net_stream: NetStream) -> None:
|
||||
try:
|
||||
await self.swarm.common_stream_handler(net_stream)
|
||||
# TODO: More exact exceptions
|
||||
except Exception:
|
||||
# TODO: Emit logs.
|
||||
# TODO: Clean up and remove the stream from SwarmConn if there is anything wrong.
|
||||
self.remove_stream(net_stream)
|
||||
self.event_started.set()
|
||||
async with trio.open_nursery() as nursery:
|
||||
while True:
|
||||
try:
|
||||
stream = await self.muxed_conn.accept_stream()
|
||||
except MuxedConnUnavailable:
|
||||
await self.close()
|
||||
break
|
||||
# Asynchronously handle the accepted stream, to avoid blocking the next stream.
|
||||
nursery.start_soon(self._handle_muxed_stream, stream)
|
||||
|
||||
async def _handle_muxed_stream(self, muxed_stream: IMuxedStream) -> None:
|
||||
net_stream = self._add_stream(muxed_stream)
|
||||
if self.swarm.common_stream_handler is not None:
|
||||
await self.run_task(self._call_stream_handler(net_stream))
|
||||
net_stream = await self._add_stream(muxed_stream)
|
||||
try:
|
||||
# Ignore type here since mypy complains: https://github.com/python/mypy/issues/2427
|
||||
await self.swarm.common_stream_handler(net_stream) # type: ignore
|
||||
finally:
|
||||
# As long as `common_stream_handler`, remove the stream.
|
||||
self.remove_stream(net_stream)
|
||||
|
||||
def _add_stream(self, muxed_stream: IMuxedStream) -> NetStream:
|
||||
async def _add_stream(self, muxed_stream: IMuxedStream) -> NetStream:
|
||||
net_stream = NetStream(muxed_stream)
|
||||
self.streams.add(net_stream)
|
||||
self.swarm.notify_opened_stream(net_stream)
|
||||
await self.swarm.notify_opened_stream(net_stream)
|
||||
return net_stream
|
||||
|
||||
def _notify_disconnected(self) -> None:
|
||||
self.swarm.notify_disconnected(self)
|
||||
async def _notify_disconnected(self) -> None:
|
||||
await self.swarm.notify_disconnected(self)
|
||||
|
||||
async def start(self) -> None:
|
||||
await self.run_task(self._handle_new_streams())
|
||||
|
||||
async def run_task(self, coro: Awaitable[Any]) -> None:
|
||||
self._tasks.append(asyncio.ensure_future(coro))
|
||||
await self._handle_new_streams()
|
||||
|
||||
async def new_stream(self) -> NetStream:
|
||||
muxed_stream = await self.muxed_conn.open_stream()
|
||||
return self._add_stream(muxed_stream)
|
||||
return await self._add_stream(muxed_stream)
|
||||
|
||||
async def get_streams(self) -> Tuple[NetStream, ...]:
|
||||
def get_streams(self) -> Tuple[NetStream, ...]:
|
||||
return tuple(self.streams)
|
||||
|
||||
def remove_stream(self, stream: NetStream) -> None:
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Dict, Sequence
|
||||
|
||||
from async_service import ServiceAPI
|
||||
from multiaddr import Multiaddr
|
||||
|
||||
from libp2p.network.connection.net_connection_interface import INetConn
|
||||
|
@ -70,3 +71,7 @@ class INetwork(ABC):
|
|||
@abstractmethod
|
||||
async def close_peer(self, peer_id: ID) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class INetworkService(INetwork, ServiceAPI):
|
||||
pass
|
||||
|
|
|
@ -37,7 +37,7 @@ class NetStream(INetStream):
|
|||
"""
|
||||
self.protocol_id = protocol_id
|
||||
|
||||
async def read(self, n: int = -1) -> bytes:
|
||||
async def read(self, n: int = None) -> bytes:
|
||||
"""
|
||||
reads from stream.
|
||||
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from async_service import Service
|
||||
from multiaddr import Multiaddr
|
||||
import trio
|
||||
|
||||
from libp2p.io.abc import ReadWriteCloser
|
||||
from libp2p.network.connection.net_connection_interface import INetConn
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.peer.peerstore import PeerStoreError
|
||||
|
@ -23,14 +25,21 @@ from ..exceptions import MultiError
|
|||
from .connection.raw_connection import RawConnection
|
||||
from .connection.swarm_connection import SwarmConn
|
||||
from .exceptions import SwarmException
|
||||
from .network_interface import INetwork
|
||||
from .network_interface import INetworkService
|
||||
from .notifee_interface import INotifee
|
||||
from .stream.net_stream_interface import INetStream
|
||||
|
||||
logger = logging.getLogger("libp2p.network.swarm")
|
||||
|
||||
|
||||
class Swarm(INetwork):
|
||||
def create_default_stream_handler(network: INetworkService) -> StreamHandlerFn:
|
||||
async def stream_handler(stream: INetStream) -> None:
|
||||
await network.get_manager().wait_finished()
|
||||
|
||||
return stream_handler
|
||||
|
||||
|
||||
class Swarm(Service, INetworkService):
|
||||
|
||||
self_id: ID
|
||||
peerstore: IPeerStore
|
||||
|
@ -40,7 +49,9 @@ class Swarm(INetwork):
|
|||
# whereas in Go one `peer_id` may point to multiple connections.
|
||||
connections: Dict[ID, INetConn]
|
||||
listeners: Dict[str, IListener]
|
||||
common_stream_handler: Optional[StreamHandlerFn]
|
||||
common_stream_handler: StreamHandlerFn
|
||||
listener_nursery: Optional[trio.Nursery]
|
||||
event_listener_nursery_created: trio.Event
|
||||
|
||||
notifees: List[INotifee]
|
||||
|
||||
|
@ -61,13 +72,31 @@ class Swarm(INetwork):
|
|||
# Create Notifee array
|
||||
self.notifees = []
|
||||
|
||||
self.common_stream_handler = None
|
||||
# Ignore type here since mypy complains: https://github.com/python/mypy/issues/2427
|
||||
self.common_stream_handler = create_default_stream_handler(self) # type: ignore
|
||||
|
||||
self.listener_nursery = None
|
||||
self.event_listener_nursery_created = trio.Event()
|
||||
|
||||
async def run(self) -> None:
|
||||
async with trio.open_nursery() as nursery:
|
||||
# Create a nursery for listener tasks.
|
||||
self.listener_nursery = nursery
|
||||
self.event_listener_nursery_created.set()
|
||||
try:
|
||||
await self.manager.wait_finished()
|
||||
finally:
|
||||
# The service ended. Cancel listener tasks.
|
||||
nursery.cancel_scope.cancel()
|
||||
# Indicate that the nursery has been cancelled.
|
||||
self.listener_nursery = None
|
||||
|
||||
def get_peer_id(self) -> ID:
|
||||
return self.self_id
|
||||
|
||||
def set_stream_handler(self, stream_handler: StreamHandlerFn) -> None:
|
||||
self.common_stream_handler = stream_handler
|
||||
# Ignore type here since mypy complains: https://github.com/python/mypy/issues/2427
|
||||
self.common_stream_handler = stream_handler # type: ignore
|
||||
|
||||
async def dial_peer(self, peer_id: ID) -> INetConn:
|
||||
"""
|
||||
|
@ -195,19 +224,15 @@ class Swarm(INetwork):
|
|||
- Call listener listen with the multiaddr
|
||||
- Map multiaddr to listener
|
||||
"""
|
||||
# We need to wait until `self.listener_nursery` is created.
|
||||
await self.event_listener_nursery_created.wait()
|
||||
|
||||
for maddr in multiaddrs:
|
||||
if str(maddr) in self.listeners:
|
||||
return True
|
||||
|
||||
async def conn_handler(
|
||||
reader: asyncio.StreamReader, writer: asyncio.StreamWriter
|
||||
) -> None:
|
||||
connection_info = writer.get_extra_info("peername")
|
||||
# TODO make a proper multiaddr
|
||||
peer_addr = f"/ip4/{connection_info[0]}/tcp/{connection_info[1]}"
|
||||
logger.debug("inbound connection at %s", peer_addr)
|
||||
# logger.debug("inbound connection request", peer_id)
|
||||
raw_conn = RawConnection(reader, writer, False)
|
||||
async def conn_handler(read_write_closer: ReadWriteCloser) -> None:
|
||||
raw_conn = RawConnection(read_write_closer, False)
|
||||
|
||||
# Per, https://discuss.libp2p.io/t/multistream-security/130, we first secure
|
||||
# the conn and then mux the conn
|
||||
|
@ -217,16 +242,13 @@ class Swarm(INetwork):
|
|||
raw_conn, ID(b""), False
|
||||
)
|
||||
except SecurityUpgradeFailure as error:
|
||||
logger.debug("failed to upgrade security for peer at %s", peer_addr)
|
||||
logger.debug("failed to upgrade security for peer at %s", maddr)
|
||||
await raw_conn.close()
|
||||
raise SwarmException(
|
||||
f"failed to upgrade security for peer at {peer_addr}"
|
||||
f"failed to upgrade security for peer at {maddr}"
|
||||
) from error
|
||||
peer_id = secured_conn.get_remote_peer()
|
||||
|
||||
logger.debug("upgraded security for peer at %s", peer_addr)
|
||||
logger.debug("identified peer at %s as %s", peer_addr, peer_id)
|
||||
|
||||
try:
|
||||
muxed_conn = await self.upgrader.upgrade_connection(
|
||||
secured_conn, peer_id
|
||||
|
@ -240,17 +262,24 @@ class Swarm(INetwork):
|
|||
logger.debug("upgraded mux for peer %s", peer_id)
|
||||
|
||||
await self.add_conn(muxed_conn)
|
||||
|
||||
logger.debug("successfully opened connection to peer %s", peer_id)
|
||||
|
||||
# NOTE: This is a intentional barrier to prevent from the handler exiting and
|
||||
# closing the connection.
|
||||
await self.manager.wait_finished()
|
||||
|
||||
try:
|
||||
# Success
|
||||
listener = self.transport.create_listener(conn_handler)
|
||||
self.listeners[str(maddr)] = listener
|
||||
await listener.listen(maddr)
|
||||
# TODO: `listener.listen` is not bounded with nursery. If we want to be
|
||||
# I/O agnostic, we should change the API.
|
||||
if self.listener_nursery is None:
|
||||
raise SwarmException("swarm instance hasn't been run")
|
||||
await listener.listen(maddr, self.listener_nursery)
|
||||
|
||||
# Call notifiers since event occurred
|
||||
self.notify_listen(maddr)
|
||||
await self.notify_listen(maddr)
|
||||
|
||||
return True
|
||||
except IOError:
|
||||
|
@ -261,26 +290,12 @@ class Swarm(INetwork):
|
|||
return False
|
||||
|
||||
async def close(self) -> None:
|
||||
# TODO: Prevent from new listeners and conns being added.
|
||||
# Reference: https://github.com/libp2p/go-libp2p-swarm/blob/8be680aef8dea0a4497283f2f98470c2aeae6b65/swarm.go#L124-L134 # noqa: E501
|
||||
|
||||
# Close listeners
|
||||
await asyncio.gather(
|
||||
*[listener.close() for listener in self.listeners.values()]
|
||||
)
|
||||
|
||||
# Close connections
|
||||
await asyncio.gather(
|
||||
*[connection.close() for connection in self.connections.values()]
|
||||
)
|
||||
|
||||
await self.manager.stop()
|
||||
logger.debug("swarm successfully closed")
|
||||
|
||||
async def close_peer(self, peer_id: ID) -> None:
|
||||
if peer_id not in self.connections:
|
||||
return
|
||||
# TODO: Should be changed to close multisple connections,
|
||||
# if we have several connections per peer in the future.
|
||||
connection = self.connections[peer_id]
|
||||
# NOTE: `connection.close` will delete `peer_id` from `self.connections`
|
||||
# and `notify_disconnected` for us.
|
||||
|
@ -293,11 +308,14 @@ class Swarm(INetwork):
|
|||
and start to monitor the connection for its new streams and
|
||||
disconnection."""
|
||||
swarm_conn = SwarmConn(muxed_conn, self)
|
||||
self.manager.run_task(muxed_conn.start)
|
||||
await muxed_conn.event_started.wait()
|
||||
self.manager.run_task(swarm_conn.start)
|
||||
await swarm_conn.event_started.wait()
|
||||
# Store muxed_conn with peer id
|
||||
self.connections[muxed_conn.peer_id] = swarm_conn
|
||||
# Call notifiers since event occurred
|
||||
self.notify_connected(swarm_conn)
|
||||
await swarm_conn.start()
|
||||
await self.notify_connected(swarm_conn)
|
||||
return swarm_conn
|
||||
|
||||
def remove_conn(self, swarm_conn: SwarmConn) -> None:
|
||||
|
@ -306,14 +324,10 @@ class Swarm(INetwork):
|
|||
peer_id = swarm_conn.muxed_conn.peer_id
|
||||
if peer_id not in self.connections:
|
||||
return
|
||||
# TODO: Should be changed to remove the exact connection,
|
||||
# if we have several connections per peer in the future.
|
||||
del self.connections[peer_id]
|
||||
|
||||
# Notifee
|
||||
|
||||
# TODO: Remeber the spawn notifying tasks and clean them up when closing.
|
||||
|
||||
def register_notifee(self, notifee: INotifee) -> None:
|
||||
"""
|
||||
:param notifee: object implementing Notifee interface
|
||||
|
@ -321,20 +335,28 @@ class Swarm(INetwork):
|
|||
"""
|
||||
self.notifees.append(notifee)
|
||||
|
||||
def notify_opened_stream(self, stream: INetStream) -> None:
|
||||
asyncio.gather(
|
||||
*[notifee.opened_stream(self, stream) for notifee in self.notifees]
|
||||
)
|
||||
async def notify_opened_stream(self, stream: INetStream) -> None:
|
||||
async with trio.open_nursery() as nursery:
|
||||
for notifee in self.notifees:
|
||||
nursery.start_soon(notifee.opened_stream, self, stream)
|
||||
|
||||
# TODO: `notify_closed_stream`
|
||||
async def notify_connected(self, conn: INetConn) -> None:
|
||||
async with trio.open_nursery() as nursery:
|
||||
for notifee in self.notifees:
|
||||
nursery.start_soon(notifee.connected, self, conn)
|
||||
|
||||
def notify_connected(self, conn: INetConn) -> None:
|
||||
asyncio.gather(*[notifee.connected(self, conn) for notifee in self.notifees])
|
||||
async def notify_disconnected(self, conn: INetConn) -> None:
|
||||
async with trio.open_nursery() as nursery:
|
||||
for notifee in self.notifees:
|
||||
nursery.start_soon(notifee.disconnected, self, conn)
|
||||
|
||||
def notify_disconnected(self, conn: INetConn) -> None:
|
||||
asyncio.gather(*[notifee.disconnected(self, conn) for notifee in self.notifees])
|
||||
async def notify_listen(self, multiaddr: Multiaddr) -> None:
|
||||
async with trio.open_nursery() as nursery:
|
||||
for notifee in self.notifees:
|
||||
nursery.start_soon(notifee.listen, self, multiaddr)
|
||||
|
||||
def notify_listen(self, multiaddr: Multiaddr) -> None:
|
||||
asyncio.gather(*[notifee.listen(self, multiaddr) for notifee in self.notifees])
|
||||
async def notify_closed_stream(self, stream: INetStream) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
# TODO: `notify_listen_close`
|
||||
async def notify_listen_close(self, multiaddr: Multiaddr) -> None:
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -25,9 +25,6 @@ def info_from_p2p_addr(addr: multiaddr.Multiaddr) -> PeerInfo:
|
|||
if not addr:
|
||||
raise InvalidAddrError("`addr` should not be `None`")
|
||||
|
||||
if not isinstance(addr, multiaddr.Multiaddr):
|
||||
raise InvalidAddrError(f"`addr`={addr} should be of type `Multiaddr`")
|
||||
|
||||
parts = addr.split()
|
||||
if not parts:
|
||||
raise InvalidAddrError(
|
||||
|
|
|
@ -1,15 +1,37 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, List
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
AsyncContextManager,
|
||||
AsyncIterable,
|
||||
KeysView,
|
||||
List,
|
||||
Tuple,
|
||||
)
|
||||
|
||||
from async_service import ServiceAPI
|
||||
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.typing import TProtocol
|
||||
|
||||
from .pb import rpc_pb2
|
||||
from .typing import ValidatorFn
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .pubsub import Pubsub # noqa: F401
|
||||
|
||||
|
||||
class ISubscriptionAPI(
|
||||
AsyncContextManager["ISubscriptionAPI"], AsyncIterable[rpc_pb2.Message]
|
||||
):
|
||||
@abstractmethod
|
||||
async def unsubscribe(self) -> None:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def get(self) -> rpc_pb2.Message:
|
||||
...
|
||||
|
||||
|
||||
class IPubsubRouter(ABC):
|
||||
@abstractmethod
|
||||
def get_protocols(self) -> List[TProtocol]:
|
||||
|
@ -53,7 +75,6 @@ class IPubsubRouter(ABC):
|
|||
:param rpc: rpc message
|
||||
"""
|
||||
|
||||
# FIXME: Should be changed to type 'peer.ID'
|
||||
@abstractmethod
|
||||
async def publish(self, msg_forwarder: ID, pubsub_msg: rpc_pb2.Message) -> None:
|
||||
"""
|
||||
|
@ -80,3 +101,46 @@ class IPubsubRouter(ABC):
|
|||
|
||||
:param topic: topic to leave
|
||||
"""
|
||||
|
||||
|
||||
class IPubsub(ServiceAPI):
|
||||
@property
|
||||
@abstractmethod
|
||||
def my_id(self) -> ID:
|
||||
...
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def protocols(self) -> Tuple[TProtocol, ...]:
|
||||
...
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def topic_ids(self) -> KeysView[str]:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def set_topic_validator(
|
||||
self, topic: str, validator: ValidatorFn, is_async_validator: bool
|
||||
) -> None:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def remove_topic_validator(self, topic: str) -> None:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def wait_until_ready(self) -> None:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def subscribe(self, topic_id: str) -> ISubscriptionAPI:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def unsubscribe(self, topic_id: str) -> None:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def publish(self, topic_id: str, data: bytes) -> None:
|
||||
...
|
|
@ -0,0 +1,9 @@
|
|||
from libp2p.exceptions import BaseLibp2pError
|
||||
|
||||
|
||||
class PubsubRouterError(BaseLibp2pError):
|
||||
pass
|
||||
|
||||
|
||||
class NoPubsubAttached(PubsubRouterError):
|
||||
pass
|
|
@ -1,14 +1,16 @@
|
|||
import logging
|
||||
from typing import Iterable, List, Sequence
|
||||
|
||||
import trio
|
||||
|
||||
from libp2p.network.stream.exceptions import StreamClosed
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.typing import TProtocol
|
||||
from libp2p.utils import encode_varint_prefixed
|
||||
|
||||
from .abc import IPubsubRouter
|
||||
from .pb import rpc_pb2
|
||||
from .pubsub import Pubsub
|
||||
from .pubsub_router_interface import IPubsubRouter
|
||||
|
||||
PROTOCOL_ID = TProtocol("/floodsub/1.0.0")
|
||||
|
||||
|
@ -61,6 +63,8 @@ class FloodSub(IPubsubRouter):
|
|||
|
||||
:param rpc: rpc message
|
||||
"""
|
||||
# Checkpoint
|
||||
await trio.hazmat.checkpoint()
|
||||
|
||||
async def publish(self, msg_forwarder: ID, pubsub_msg: rpc_pb2.Message) -> None:
|
||||
"""
|
||||
|
@ -107,6 +111,8 @@ class FloodSub(IPubsubRouter):
|
|||
|
||||
:param topic: topic to join
|
||||
"""
|
||||
# Checkpoint
|
||||
await trio.hazmat.checkpoint()
|
||||
|
||||
async def leave(self, topic: str) -> None:
|
||||
"""
|
||||
|
@ -115,6 +121,8 @@ class FloodSub(IPubsubRouter):
|
|||
|
||||
:param topic: topic to leave
|
||||
"""
|
||||
# Checkpoint
|
||||
await trio.hazmat.checkpoint()
|
||||
|
||||
def _get_peers_to_send(
|
||||
self, topic_ids: Iterable[str], msg_forwarder: ID, origin: ID
|
||||
|
|
|
@ -1,28 +1,30 @@
|
|||
from ast import literal_eval
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
import logging
|
||||
import random
|
||||
from typing import Any, DefaultDict, Dict, Iterable, List, Sequence, Set, Tuple
|
||||
|
||||
from async_service import Service
|
||||
import trio
|
||||
|
||||
from libp2p.network.stream.exceptions import StreamClosed
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.pubsub import floodsub
|
||||
from libp2p.typing import TProtocol
|
||||
from libp2p.utils import encode_varint_prefixed
|
||||
|
||||
from .abc import IPubsubRouter
|
||||
from .exceptions import NoPubsubAttached
|
||||
from .mcache import MessageCache
|
||||
from .pb import rpc_pb2
|
||||
from .pubsub import Pubsub
|
||||
from .pubsub_router_interface import IPubsubRouter
|
||||
|
||||
PROTOCOL_ID = TProtocol("/meshsub/1.0.0")
|
||||
|
||||
logger = logging.getLogger("libp2p.pubsub.gossipsub")
|
||||
|
||||
|
||||
class GossipSub(IPubsubRouter):
|
||||
|
||||
class GossipSub(IPubsubRouter, Service):
|
||||
protocols: List[TProtocol]
|
||||
pubsub: Pubsub
|
||||
|
||||
|
@ -38,7 +40,8 @@ class GossipSub(IPubsubRouter):
|
|||
# The protocol peer supports
|
||||
peer_protocol: Dict[ID, TProtocol]
|
||||
|
||||
time_since_last_publish: Dict[str, int]
|
||||
# TODO: Add `time_since_last_publish`
|
||||
# Create topic --> time since last publish map.
|
||||
|
||||
mcache: MessageCache
|
||||
|
||||
|
@ -75,9 +78,6 @@ class GossipSub(IPubsubRouter):
|
|||
# Create peer --> protocol mapping
|
||||
self.peer_protocol = {}
|
||||
|
||||
# Create topic --> time since last publish map
|
||||
self.time_since_last_publish = {}
|
||||
|
||||
# Create message cache
|
||||
self.mcache = MessageCache(gossip_window, gossip_history)
|
||||
|
||||
|
@ -85,6 +85,12 @@ class GossipSub(IPubsubRouter):
|
|||
self.heartbeat_initial_delay = heartbeat_initial_delay
|
||||
self.heartbeat_interval = heartbeat_interval
|
||||
|
||||
async def run(self) -> None:
|
||||
if self.pubsub is None:
|
||||
raise NoPubsubAttached
|
||||
self.manager.run_daemon_task(self.heartbeat)
|
||||
await self.manager.wait_finished()
|
||||
|
||||
# Interface functions
|
||||
|
||||
def get_protocols(self) -> List[TProtocol]:
|
||||
|
@ -104,9 +110,6 @@ class GossipSub(IPubsubRouter):
|
|||
|
||||
logger.debug("attached to pusub")
|
||||
|
||||
# Start heartbeat now that we have a pubsub instance
|
||||
asyncio.ensure_future(self.heartbeat())
|
||||
|
||||
def add_peer(self, peer_id: ID, protocol_id: TProtocol) -> None:
|
||||
"""
|
||||
Notifies the router that a new peer has been connected.
|
||||
|
@ -370,7 +373,7 @@ class GossipSub(IPubsubRouter):
|
|||
state changes in the preceding heartbeat
|
||||
"""
|
||||
# Start after a delay. Ref: https://github.com/libp2p/go-libp2p-pubsub/blob/01b9825fbee1848751d90a8469e3f5f43bac8466/gossipsub.go#L410 # Noqa: E501
|
||||
await asyncio.sleep(self.heartbeat_initial_delay)
|
||||
await trio.sleep(self.heartbeat_initial_delay)
|
||||
while True:
|
||||
# Maintain mesh and keep track of which peers to send GRAFT or PRUNE to
|
||||
peers_to_graft, peers_to_prune = self.mesh_heartbeat()
|
||||
|
@ -385,7 +388,7 @@ class GossipSub(IPubsubRouter):
|
|||
|
||||
self.mcache.shift()
|
||||
|
||||
await asyncio.sleep(self.heartbeat_interval)
|
||||
await trio.sleep(self.heartbeat_interval)
|
||||
|
||||
def mesh_heartbeat(
|
||||
self
|
||||
|
@ -413,7 +416,7 @@ class GossipSub(IPubsubRouter):
|
|||
|
||||
if num_mesh_peers_in_topic > self.degree_high:
|
||||
# Select |mesh[topic]| - D peers from mesh[topic]
|
||||
selected_peers = GossipSub.select_from_minus(
|
||||
selected_peers = self.select_from_minus(
|
||||
num_mesh_peers_in_topic - self.degree, self.mesh[topic], set()
|
||||
)
|
||||
for peer in selected_peers:
|
||||
|
@ -428,15 +431,10 @@ class GossipSub(IPubsubRouter):
|
|||
# Note: the comments here are the exact pseudocode from the spec
|
||||
for topic in self.fanout:
|
||||
# Delete topic entry if it's not in `pubsub.peer_topics`
|
||||
# or if it's time-since-last-published > ttl
|
||||
# TODO: there's no way time_since_last_publish gets set anywhere yet
|
||||
if (
|
||||
topic not in self.pubsub.peer_topics
|
||||
or self.time_since_last_publish[topic] > self.time_to_live
|
||||
):
|
||||
# or (TODO) if it's time-since-last-published > ttl
|
||||
if topic not in self.pubsub.peer_topics:
|
||||
# Remove topic from fanout
|
||||
del self.fanout[topic]
|
||||
del self.time_since_last_publish[topic]
|
||||
else:
|
||||
# Check if fanout peers are still in the topic and remove the ones that are not
|
||||
# ref: https://github.com/libp2p/go-libp2p-pubsub/blob/01b9825fbee1848751d90a8469e3f5f43bac8466/gossipsub.go#L498-L504 # noqa: E501
|
||||
|
|
|
@ -1,21 +1,12 @@
|
|||
import asyncio
|
||||
import functools
|
||||
import logging
|
||||
import time
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
NamedTuple,
|
||||
Set,
|
||||
Tuple,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
from typing import TYPE_CHECKING, Dict, KeysView, List, NamedTuple, Set, Tuple, cast
|
||||
|
||||
from async_service import Service
|
||||
import base58
|
||||
from lru import LRU
|
||||
import trio
|
||||
|
||||
from libp2p.crypto.keys import PrivateKey
|
||||
from libp2p.exceptions import ParseError, ValidationError
|
||||
|
@ -28,15 +19,21 @@ from libp2p.peer.id import ID
|
|||
from libp2p.typing import TProtocol
|
||||
from libp2p.utils import encode_varint_prefixed, read_varint_prefixed_bytes
|
||||
|
||||
from .abc import IPubsub, ISubscriptionAPI
|
||||
from .pb import rpc_pb2
|
||||
from .pubsub_notifee import PubsubNotifee
|
||||
from .subscription import TrioSubscriptionAPI
|
||||
from .typing import AsyncValidatorFn, SyncValidatorFn, ValidatorFn
|
||||
from .validators import PUBSUB_SIGNING_PREFIX, signature_validator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .pubsub_router_interface import IPubsubRouter # noqa: F401
|
||||
from .abc import IPubsubRouter # noqa: F401
|
||||
from typing import Any # noqa: F401
|
||||
|
||||
|
||||
# Ref: https://github.com/libp2p/go-libp2p-pubsub/blob/40e1c94708658b155f30cf99e4574f384756d83c/topic.go#L97 # noqa: E501
|
||||
SUBSCRIPTION_CHANNEL_SIZE = 32
|
||||
|
||||
logger = logging.getLogger("libp2p.pubsub")
|
||||
|
||||
|
||||
|
@ -45,34 +42,24 @@ def get_msg_id(msg: rpc_pb2.Message) -> Tuple[bytes, bytes]:
|
|||
return (msg.seqno, msg.from_id)
|
||||
|
||||
|
||||
SyncValidatorFn = Callable[[ID, rpc_pb2.Message], bool]
|
||||
AsyncValidatorFn = Callable[[ID, rpc_pb2.Message], Awaitable[bool]]
|
||||
ValidatorFn = Union[SyncValidatorFn, AsyncValidatorFn]
|
||||
|
||||
|
||||
class TopicValidator(NamedTuple):
|
||||
validator: ValidatorFn
|
||||
is_async: bool
|
||||
|
||||
|
||||
class Pubsub:
|
||||
class Pubsub(Service, IPubsub):
|
||||
|
||||
host: IHost
|
||||
my_id: ID
|
||||
|
||||
router: "IPubsubRouter"
|
||||
|
||||
peer_queue: "asyncio.Queue[ID]"
|
||||
dead_peer_queue: "asyncio.Queue[ID]"
|
||||
|
||||
protocols: List[TProtocol]
|
||||
|
||||
incoming_msgs_from_peers: "asyncio.Queue[rpc_pb2.Message]"
|
||||
outgoing_messages: "asyncio.Queue[rpc_pb2.Message]"
|
||||
peer_receive_channel: "trio.MemoryReceiveChannel[ID]"
|
||||
dead_peer_receive_channel: "trio.MemoryReceiveChannel[ID]"
|
||||
|
||||
seen_messages: LRU
|
||||
|
||||
my_topics: Dict[str, "asyncio.Queue[rpc_pb2.Message]"]
|
||||
subscribed_topics_send: Dict[str, "trio.MemorySendChannel[rpc_pb2.Message]"]
|
||||
subscribed_topics_receive: Dict[str, "TrioSubscriptionAPI"]
|
||||
|
||||
peer_topics: Dict[str, Set[ID]]
|
||||
peers: Dict[ID, INetStream]
|
||||
|
@ -81,17 +68,17 @@ class Pubsub:
|
|||
|
||||
counter: int # uint64
|
||||
|
||||
_tasks: List["asyncio.Future[Any]"]
|
||||
|
||||
# Indicate if we should enforce signature verification
|
||||
strict_signing: bool
|
||||
sign_key: PrivateKey
|
||||
|
||||
event_handle_peer_queue_started: trio.Event
|
||||
event_handle_dead_peer_queue_started: trio.Event
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: IHost,
|
||||
router: "IPubsubRouter",
|
||||
my_id: ID,
|
||||
cache_size: int = None,
|
||||
strict_signing: bool = True,
|
||||
) -> None:
|
||||
|
@ -107,39 +94,44 @@ class Pubsub:
|
|||
"""
|
||||
self.host = host
|
||||
self.router = router
|
||||
self.my_id = my_id
|
||||
|
||||
# Attach this new Pubsub object to the router
|
||||
self.router.attach(self)
|
||||
|
||||
peer_send, peer_receive = trio.open_memory_channel[ID](0)
|
||||
dead_peer_send, dead_peer_receive = trio.open_memory_channel[ID](0)
|
||||
# Only keep the receive channels in `Pubsub`.
|
||||
# Therefore, we can only close from the receive side.
|
||||
self.peer_receive_channel = peer_receive
|
||||
self.dead_peer_receive_channel = dead_peer_receive
|
||||
# Register a notifee
|
||||
self.peer_queue = asyncio.Queue()
|
||||
self.dead_peer_queue = asyncio.Queue()
|
||||
self.host.get_network().register_notifee(
|
||||
PubsubNotifee(self.peer_queue, self.dead_peer_queue)
|
||||
PubsubNotifee(peer_send, dead_peer_send)
|
||||
)
|
||||
|
||||
# Register stream handlers for each pubsub router protocol to handle
|
||||
# the pubsub streams opened on those protocols
|
||||
self.protocols = self.router.get_protocols()
|
||||
for protocol in self.protocols:
|
||||
for protocol in router.get_protocols():
|
||||
self.host.set_stream_handler(protocol, self.stream_handler)
|
||||
|
||||
# Use asyncio queues for proper context switching
|
||||
self.incoming_msgs_from_peers = asyncio.Queue()
|
||||
self.outgoing_messages = asyncio.Queue()
|
||||
|
||||
# keeps track of seen messages as LRU cache
|
||||
if cache_size is None:
|
||||
self.cache_size = 128
|
||||
else:
|
||||
self.cache_size = cache_size
|
||||
|
||||
self.strict_signing = strict_signing
|
||||
if strict_signing:
|
||||
self.sign_key = self.host.get_private_key()
|
||||
else:
|
||||
self.sign_key = None
|
||||
|
||||
self.seen_messages = LRU(self.cache_size)
|
||||
|
||||
# Map of topics we are subscribed to blocking queues
|
||||
# for when the given topic receives a message
|
||||
self.my_topics = {}
|
||||
self.subscribed_topics_send = {}
|
||||
self.subscribed_topics_receive = {}
|
||||
|
||||
# Map of topic to peers to keep track of what peers are subscribed to
|
||||
self.peer_topics = {}
|
||||
|
@ -152,22 +144,31 @@ class Pubsub:
|
|||
|
||||
self.counter = int(time.time())
|
||||
|
||||
self._tasks = []
|
||||
# Call handle peer to keep waiting for updates to peer queue
|
||||
self._tasks.append(asyncio.ensure_future(self.handle_peer_queue()))
|
||||
self._tasks.append(asyncio.ensure_future(self.handle_dead_peer_queue()))
|
||||
self.event_handle_peer_queue_started = trio.Event()
|
||||
self.event_handle_dead_peer_queue_started = trio.Event()
|
||||
|
||||
self.strict_signing = strict_signing
|
||||
if strict_signing:
|
||||
self.sign_key = self.host.get_private_key()
|
||||
else:
|
||||
self.sign_key = None
|
||||
async def run(self) -> None:
|
||||
self.manager.run_daemon_task(self.handle_peer_queue)
|
||||
self.manager.run_daemon_task(self.handle_dead_peer_queue)
|
||||
await self.manager.wait_finished()
|
||||
|
||||
@property
|
||||
def my_id(self) -> ID:
|
||||
return self.host.get_id()
|
||||
|
||||
@property
|
||||
def protocols(self) -> Tuple[TProtocol, ...]:
|
||||
return tuple(self.router.get_protocols())
|
||||
|
||||
@property
|
||||
def topic_ids(self) -> KeysView[str]:
|
||||
return self.subscribed_topics_receive.keys()
|
||||
|
||||
def get_hello_packet(self) -> rpc_pb2.RPC:
|
||||
"""Generate subscription message with all topics we are subscribed to
|
||||
only send hello packet if we have subscribed topics."""
|
||||
packet = rpc_pb2.RPC()
|
||||
for topic_id in self.my_topics:
|
||||
for topic_id in self.topic_ids:
|
||||
packet.subscriptions.extend(
|
||||
[rpc_pb2.RPC.SubOpts(subscribe=True, topicid=topic_id)]
|
||||
)
|
||||
|
@ -182,7 +183,7 @@ class Pubsub:
|
|||
"""
|
||||
peer_id = stream.muxed_conn.peer_id
|
||||
|
||||
while True:
|
||||
while self.manager.is_running:
|
||||
incoming: bytes = await read_varint_prefixed_bytes(stream)
|
||||
rpc_incoming: rpc_pb2.RPC = rpc_pb2.RPC()
|
||||
rpc_incoming.ParseFromString(incoming)
|
||||
|
@ -194,11 +195,7 @@ class Pubsub:
|
|||
logger.debug(
|
||||
"received `publish` message %s from peer %s", msg, peer_id
|
||||
)
|
||||
self._tasks.append(
|
||||
asyncio.ensure_future(
|
||||
self.push_msg(msg_forwarder=peer_id, msg=msg)
|
||||
)
|
||||
)
|
||||
self.manager.run_task(self.push_msg, peer_id, msg)
|
||||
|
||||
if rpc_incoming.subscriptions:
|
||||
# deal with RPC.subscriptions
|
||||
|
@ -226,9 +223,6 @@ class Pubsub:
|
|||
)
|
||||
await self.router.handle_rpc(rpc_incoming, peer_id)
|
||||
|
||||
# Force context switch
|
||||
await asyncio.sleep(0)
|
||||
|
||||
def set_topic_validator(
|
||||
self, topic: str, validator: ValidatorFn, is_async_validator: bool
|
||||
) -> None:
|
||||
|
@ -283,6 +277,10 @@ class Pubsub:
|
|||
await stream.reset()
|
||||
self._handle_dead_peer(peer_id)
|
||||
|
||||
async def wait_until_ready(self) -> None:
|
||||
await self.event_handle_peer_queue_started.wait()
|
||||
await self.event_handle_dead_peer_queue_started.wait()
|
||||
|
||||
async def _handle_new_peer(self, peer_id: ID) -> None:
|
||||
try:
|
||||
stream: INetStream = await self.host.new_stream(peer_id, self.protocols)
|
||||
|
@ -325,18 +323,21 @@ class Pubsub:
|
|||
"""Continuously read from peer queue and each time a new peer is found,
|
||||
open a stream to the peer using a supported pubsub protocol pubsub
|
||||
protocols we support."""
|
||||
while True:
|
||||
peer_id: ID = await self.peer_queue.get()
|
||||
# Add Peer
|
||||
self._tasks.append(asyncio.ensure_future(self._handle_new_peer(peer_id)))
|
||||
async with self.peer_receive_channel:
|
||||
self.event_handle_peer_queue_started.set()
|
||||
async for peer_id in self.peer_receive_channel:
|
||||
# Add Peer
|
||||
self.manager.run_task(self._handle_new_peer, peer_id)
|
||||
|
||||
async def handle_dead_peer_queue(self) -> None:
|
||||
"""Continuously read from dead peer queue and close the stream between
|
||||
that peer and remove peer info from pubsub and pubsub router."""
|
||||
while True:
|
||||
peer_id: ID = await self.dead_peer_queue.get()
|
||||
# Remove Peer
|
||||
self._handle_dead_peer(peer_id)
|
||||
"""Continuously read from dead peer channel and close the stream
|
||||
between that peer and remove peer info from pubsub and pubsub
|
||||
router."""
|
||||
async with self.dead_peer_receive_channel:
|
||||
self.event_handle_dead_peer_queue_started.set()
|
||||
async for peer_id in self.dead_peer_receive_channel:
|
||||
# Remove Peer
|
||||
self._handle_dead_peer(peer_id)
|
||||
|
||||
def handle_subscription(
|
||||
self, origin_id: ID, sub_message: rpc_pb2.RPC.SubOpts
|
||||
|
@ -360,8 +361,7 @@ class Pubsub:
|
|||
if origin_id in self.peer_topics[sub_message.topicid]:
|
||||
self.peer_topics[sub_message.topicid].discard(origin_id)
|
||||
|
||||
# FIXME(mhchia): Change the function name?
|
||||
async def handle_talk(self, publish_message: rpc_pb2.Message) -> None:
|
||||
def notify_subscriptions(self, publish_message: rpc_pb2.Message) -> None:
|
||||
"""
|
||||
Put incoming message from a peer onto my blocking queue.
|
||||
|
||||
|
@ -370,13 +370,19 @@ class Pubsub:
|
|||
|
||||
# Check if this message has any topics that we are subscribed to
|
||||
for topic in publish_message.topicIDs:
|
||||
if topic in self.my_topics:
|
||||
if topic in self.topic_ids:
|
||||
# we are subscribed to a topic this message was sent for,
|
||||
# so add message to the subscription output queue
|
||||
# for each topic
|
||||
await self.my_topics[topic].put(publish_message)
|
||||
try:
|
||||
self.subscribed_topics_send[topic].send_nowait(publish_message)
|
||||
except trio.WouldBlock:
|
||||
# Channel is full, ignore this message.
|
||||
logger.warning(
|
||||
"fail to deliver message to subscription for topic %s", topic
|
||||
)
|
||||
|
||||
async def subscribe(self, topic_id: str) -> "asyncio.Queue[rpc_pb2.Message]":
|
||||
async def subscribe(self, topic_id: str) -> ISubscriptionAPI:
|
||||
"""
|
||||
Subscribe ourself to a topic.
|
||||
|
||||
|
@ -386,11 +392,19 @@ class Pubsub:
|
|||
logger.debug("subscribing to topic %s", topic_id)
|
||||
|
||||
# Already subscribed
|
||||
if topic_id in self.my_topics:
|
||||
return self.my_topics[topic_id]
|
||||
if topic_id in self.topic_ids:
|
||||
return self.subscribed_topics_receive[topic_id]
|
||||
|
||||
# Map topic_id to blocking queue
|
||||
self.my_topics[topic_id] = asyncio.Queue()
|
||||
send_channel, receive_channel = trio.open_memory_channel[rpc_pb2.Message](
|
||||
SUBSCRIPTION_CHANNEL_SIZE
|
||||
)
|
||||
|
||||
subscription = TrioSubscriptionAPI(
|
||||
receive_channel,
|
||||
unsubscribe_fn=functools.partial(self.unsubscribe, topic_id),
|
||||
)
|
||||
self.subscribed_topics_send[topic_id] = send_channel
|
||||
self.subscribed_topics_receive[topic_id] = subscription
|
||||
|
||||
# Create subscribe message
|
||||
packet: rpc_pb2.RPC = rpc_pb2.RPC()
|
||||
|
@ -404,8 +418,8 @@ class Pubsub:
|
|||
# Tell router we are joining this topic
|
||||
await self.router.join(topic_id)
|
||||
|
||||
# Return the asyncio queue for messages on this topic
|
||||
return self.my_topics[topic_id]
|
||||
# Return the subscription for messages on this topic
|
||||
return subscription
|
||||
|
||||
async def unsubscribe(self, topic_id: str) -> None:
|
||||
"""
|
||||
|
@ -417,10 +431,14 @@ class Pubsub:
|
|||
logger.debug("unsubscribing from topic %s", topic_id)
|
||||
|
||||
# Return if we already unsubscribed from the topic
|
||||
if topic_id not in self.my_topics:
|
||||
if topic_id not in self.topic_ids:
|
||||
return
|
||||
# Remove topic_id from map if present
|
||||
del self.my_topics[topic_id]
|
||||
# Remove topic_id from the maps before yielding
|
||||
send_channel = self.subscribed_topics_send[topic_id]
|
||||
del self.subscribed_topics_send[topic_id]
|
||||
del self.subscribed_topics_receive[topic_id]
|
||||
# Only close the send side
|
||||
await send_channel.aclose()
|
||||
|
||||
# Create unsubscribe message
|
||||
packet: rpc_pb2.RPC = rpc_pb2.RPC()
|
||||
|
@ -462,7 +480,7 @@ class Pubsub:
|
|||
data=data,
|
||||
topicIDs=[topic_id],
|
||||
# Origin is ourself.
|
||||
from_id=self.host.get_id().to_bytes(),
|
||||
from_id=self.my_id.to_bytes(),
|
||||
seqno=self._next_seqno(),
|
||||
)
|
||||
|
||||
|
@ -474,7 +492,7 @@ class Pubsub:
|
|||
msg.key = self.host.get_public_key().serialize()
|
||||
msg.signature = signature
|
||||
|
||||
await self.push_msg(self.host.get_id(), msg)
|
||||
await self.push_msg(self.my_id, msg)
|
||||
|
||||
logger.debug("successfully published message %s", msg)
|
||||
|
||||
|
@ -485,12 +503,12 @@ class Pubsub:
|
|||
:param msg_forwarder: the peer who forward us the message.
|
||||
:param msg: the message.
|
||||
"""
|
||||
sync_topic_validators = []
|
||||
async_topic_validator_futures: List[Awaitable[bool]] = []
|
||||
sync_topic_validators: List[SyncValidatorFn] = []
|
||||
async_topic_validators: List[AsyncValidatorFn] = []
|
||||
for topic_validator in self.get_msg_validators(msg):
|
||||
if topic_validator.is_async:
|
||||
async_topic_validator_futures.append(
|
||||
cast(Awaitable[bool], topic_validator.validator(msg_forwarder, msg))
|
||||
async_topic_validators.append(
|
||||
cast(AsyncValidatorFn, topic_validator.validator)
|
||||
)
|
||||
else:
|
||||
sync_topic_validators.append(
|
||||
|
@ -503,9 +521,20 @@ class Pubsub:
|
|||
|
||||
# TODO: Implement throttle on async validators
|
||||
|
||||
if len(async_topic_validator_futures) > 0:
|
||||
results = await asyncio.gather(*async_topic_validator_futures)
|
||||
if not all(results):
|
||||
if len(async_topic_validators) > 0:
|
||||
# TODO: Use a better pattern
|
||||
final_result: bool = True
|
||||
|
||||
async def run_async_validator(func: AsyncValidatorFn) -> None:
|
||||
nonlocal final_result
|
||||
result = await func(msg_forwarder, msg)
|
||||
final_result = final_result and result
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
for async_validator in async_topic_validators:
|
||||
nursery.start_soon(run_async_validator, async_validator)
|
||||
|
||||
if not final_result:
|
||||
raise ValidationError(f"Validation failed for msg={msg}")
|
||||
|
||||
async def push_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None:
|
||||
|
@ -548,7 +577,7 @@ class Pubsub:
|
|||
return
|
||||
|
||||
self._mark_msg_seen(msg)
|
||||
await self.handle_talk(msg)
|
||||
self.notify_subscriptions(msg)
|
||||
await self.router.publish(msg_forwarder, msg)
|
||||
|
||||
def _next_seqno(self) -> bytes:
|
||||
|
@ -567,14 +596,4 @@ class Pubsub:
|
|||
self.seen_messages[msg_id] = 1
|
||||
|
||||
def _is_subscribed_to_msg(self, msg: rpc_pb2.Message) -> bool:
|
||||
if not self.my_topics:
|
||||
return False
|
||||
return any(topic in self.my_topics for topic in msg.topicIDs)
|
||||
|
||||
async def close(self) -> None:
|
||||
for task in self._tasks:
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
return any(topic in self.topic_ids for topic in msg.topicIDs)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
from typing import TYPE_CHECKING
|
||||
|
||||
from multiaddr import Multiaddr
|
||||
import trio
|
||||
|
||||
from libp2p.network.connection.net_connection_interface import INetConn
|
||||
from libp2p.network.network_interface import INetwork
|
||||
|
@ -8,19 +9,18 @@ from libp2p.network.notifee_interface import INotifee
|
|||
from libp2p.network.stream.net_stream_interface import INetStream
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import asyncio # noqa: F401
|
||||
from libp2p.peer.id import ID # noqa: F401
|
||||
|
||||
|
||||
class PubsubNotifee(INotifee):
|
||||
|
||||
initiator_peers_queue: "asyncio.Queue[ID]"
|
||||
dead_peers_queue: "asyncio.Queue[ID]"
|
||||
initiator_peers_queue: "trio.MemorySendChannel[ID]"
|
||||
dead_peers_queue: "trio.MemorySendChannel[ID]"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
initiator_peers_queue: "asyncio.Queue[ID]",
|
||||
dead_peers_queue: "asyncio.Queue[ID]",
|
||||
initiator_peers_queue: "trio.MemorySendChannel[ID]",
|
||||
dead_peers_queue: "trio.MemorySendChannel[ID]",
|
||||
) -> None:
|
||||
"""
|
||||
:param initiator_peers_queue: queue to add new peers to so that pubsub
|
||||
|
@ -32,10 +32,10 @@ class PubsubNotifee(INotifee):
|
|||
self.dead_peers_queue = dead_peers_queue
|
||||
|
||||
async def opened_stream(self, network: INetwork, stream: INetStream) -> None:
|
||||
pass
|
||||
await trio.hazmat.checkpoint()
|
||||
|
||||
async def closed_stream(self, network: INetwork, stream: INetStream) -> None:
|
||||
pass
|
||||
await trio.hazmat.checkpoint()
|
||||
|
||||
async def connected(self, network: INetwork, conn: INetConn) -> None:
|
||||
"""
|
||||
|
@ -46,7 +46,11 @@ class PubsubNotifee(INotifee):
|
|||
:param network: network the connection was opened on
|
||||
:param conn: connection that was opened
|
||||
"""
|
||||
await self.initiator_peers_queue.put(conn.muxed_conn.peer_id)
|
||||
try:
|
||||
await self.initiator_peers_queue.send(conn.muxed_conn.peer_id)
|
||||
except trio.BrokenResourceError:
|
||||
# The receive channel is closed by Pubsub. We should do nothing here.
|
||||
pass
|
||||
|
||||
async def disconnected(self, network: INetwork, conn: INetConn) -> None:
|
||||
"""
|
||||
|
@ -56,10 +60,14 @@ class PubsubNotifee(INotifee):
|
|||
:param network: network the connection was opened on
|
||||
:param conn: connection that was opened
|
||||
"""
|
||||
await self.dead_peers_queue.put(conn.muxed_conn.peer_id)
|
||||
try:
|
||||
await self.dead_peers_queue.send(conn.muxed_conn.peer_id)
|
||||
except trio.BrokenResourceError:
|
||||
# The receive channel is closed by Pubsub. We should do nothing here.
|
||||
pass
|
||||
|
||||
async def listen(self, network: INetwork, multiaddr: Multiaddr) -> None:
|
||||
pass
|
||||
await trio.hazmat.checkpoint()
|
||||
|
||||
async def listen_close(self, network: INetwork, multiaddr: Multiaddr) -> None:
|
||||
pass
|
||||
await trio.hazmat.checkpoint()
|
||||
|
|
|
@ -0,0 +1,46 @@
|
|||
from types import TracebackType
|
||||
from typing import AsyncIterator, Optional, Type
|
||||
|
||||
import trio
|
||||
|
||||
from .abc import ISubscriptionAPI
|
||||
from .pb import rpc_pb2
|
||||
from .typing import UnsubscribeFn
|
||||
|
||||
|
||||
class BaseSubscriptionAPI(ISubscriptionAPI):
|
||||
async def __aenter__(self) -> "BaseSubscriptionAPI":
|
||||
await trio.hazmat.checkpoint()
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: "Optional[Type[BaseException]]",
|
||||
exc_value: "Optional[BaseException]",
|
||||
traceback: "Optional[TracebackType]",
|
||||
) -> None:
|
||||
await self.unsubscribe()
|
||||
|
||||
|
||||
class TrioSubscriptionAPI(BaseSubscriptionAPI):
|
||||
receive_channel: "trio.MemoryReceiveChannel[rpc_pb2.Message]"
|
||||
unsubscribe_fn: UnsubscribeFn
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
receive_channel: "trio.MemoryReceiveChannel[rpc_pb2.Message]",
|
||||
unsubscribe_fn: UnsubscribeFn,
|
||||
) -> None:
|
||||
self.receive_channel = receive_channel
|
||||
# Ignore type here since mypy complains: https://github.com/python/mypy/issues/2427
|
||||
self.unsubscribe_fn = unsubscribe_fn # type: ignore
|
||||
|
||||
async def unsubscribe(self) -> None:
|
||||
# Ignore type here since mypy complains: https://github.com/python/mypy/issues/2427
|
||||
await self.unsubscribe_fn() # type: ignore
|
||||
|
||||
def __aiter__(self) -> AsyncIterator[rpc_pb2.Message]:
|
||||
return self.receive_channel.__aiter__()
|
||||
|
||||
async def get(self) -> rpc_pb2.Message:
|
||||
return await self.receive_channel.receive()
|
|
@ -0,0 +1,11 @@
|
|||
from typing import Awaitable, Callable, Union
|
||||
|
||||
from libp2p.peer.id import ID
|
||||
|
||||
from .pb import rpc_pb2
|
||||
|
||||
SyncValidatorFn = Callable[[ID, rpc_pb2.Message], bool]
|
||||
AsyncValidatorFn = Callable[[ID, rpc_pb2.Message], Awaitable[bool]]
|
||||
ValidatorFn = Union[SyncValidatorFn, AsyncValidatorFn]
|
||||
|
||||
UnsubscribeFn = Callable[[], Awaitable[None]]
|
|
@ -39,7 +39,7 @@ class InsecureSession(BaseSession):
|
|||
await self.conn.write(data)
|
||||
return len(data)
|
||||
|
||||
async def read(self, n: int = -1) -> bytes:
|
||||
async def read(self, n: int = None) -> bytes:
|
||||
return await self.conn.read(n)
|
||||
|
||||
async def close(self) -> None:
|
||||
|
|
|
@ -94,7 +94,7 @@ class SecureSession(BaseSession):
|
|||
|
||||
data = self.buf.getbuffer()[self.low_watermark : self.high_watermark]
|
||||
|
||||
if n < 0:
|
||||
if n is None:
|
||||
n = len(data)
|
||||
result = data[:n].tobytes()
|
||||
self.low_watermark += len(result)
|
||||
|
@ -111,7 +111,7 @@ class SecureSession(BaseSession):
|
|||
self.low_watermark = 0
|
||||
self.high_watermark = len(msg)
|
||||
|
||||
async def read(self, n: int = -1) -> bytes:
|
||||
async def read(self, n: int = None) -> bytes:
|
||||
if n == 0:
|
||||
return bytes()
|
||||
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
from abc import ABC, abstractmethod
|
||||
|
||||
import trio
|
||||
|
||||
from libp2p.io.abc import ReadWriteCloser
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.security.secure_conn_interface import ISecureConn
|
||||
|
@ -11,6 +13,7 @@ class IMuxedConn(ABC):
|
|||
"""
|
||||
|
||||
peer_id: ID
|
||||
event_started: trio.Event
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, conn: ISecureConn, peer_id: ID) -> None:
|
||||
|
@ -25,12 +28,17 @@ class IMuxedConn(ABC):
|
|||
@property
|
||||
@abstractmethod
|
||||
def is_initiator(self) -> bool:
|
||||
pass
|
||||
"""if this connection is the initiator."""
|
||||
|
||||
@abstractmethod
|
||||
async def start(self) -> None:
|
||||
"""start the multiplexer."""
|
||||
|
||||
@abstractmethod
|
||||
async def close(self) -> None:
|
||||
"""close connection."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def is_closed(self) -> bool:
|
||||
"""
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from typing import Any # noqa: F401
|
||||
from typing import Awaitable, Dict, List, Optional, Tuple
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
import trio
|
||||
|
||||
from libp2p.exceptions import ParseError
|
||||
from libp2p.io.exceptions import IncompleteReadError
|
||||
|
@ -23,6 +23,8 @@ from .exceptions import MplexUnavailable
|
|||
from .mplex_stream import MplexStream
|
||||
|
||||
MPLEX_PROTOCOL_ID = TProtocol("/mplex/6.7.0")
|
||||
# Ref: https://github.com/libp2p/go-mplex/blob/414db61813d9ad3e6f4a7db5c1b1612de343ace9/multiplex.go#L115 # noqa: E501
|
||||
MPLEX_MESSAGE_CHANNEL_SIZE = 8
|
||||
|
||||
logger = logging.getLogger("libp2p.stream_muxer.mplex.mplex")
|
||||
|
||||
|
@ -36,12 +38,14 @@ class Mplex(IMuxedConn):
|
|||
peer_id: ID
|
||||
next_channel_id: int
|
||||
streams: Dict[StreamID, MplexStream]
|
||||
streams_lock: asyncio.Lock
|
||||
new_stream_queue: "asyncio.Queue[IMuxedStream]"
|
||||
event_shutting_down: asyncio.Event
|
||||
event_closed: asyncio.Event
|
||||
streams_lock: trio.Lock
|
||||
streams_msg_channels: Dict[StreamID, "trio.MemorySendChannel[bytes]"]
|
||||
new_stream_send_channel: "trio.MemorySendChannel[IMuxedStream]"
|
||||
new_stream_receive_channel: "trio.MemoryReceiveChannel[IMuxedStream]"
|
||||
|
||||
_tasks: List["asyncio.Future[Any]"]
|
||||
event_shutting_down: trio.Event
|
||||
event_closed: trio.Event
|
||||
event_started: trio.Event
|
||||
|
||||
def __init__(self, secured_conn: ISecureConn, peer_id: ID) -> None:
|
||||
"""
|
||||
|
@ -61,15 +65,16 @@ class Mplex(IMuxedConn):
|
|||
|
||||
# Mapping from stream ID -> buffer of messages for that stream
|
||||
self.streams = {}
|
||||
self.streams_lock = asyncio.Lock()
|
||||
self.new_stream_queue = asyncio.Queue()
|
||||
self.event_shutting_down = asyncio.Event()
|
||||
self.event_closed = asyncio.Event()
|
||||
self.streams_lock = trio.Lock()
|
||||
self.streams_msg_channels = {}
|
||||
channels = trio.open_memory_channel[IMuxedStream](0)
|
||||
self.new_stream_send_channel, self.new_stream_receive_channel = channels
|
||||
self.event_shutting_down = trio.Event()
|
||||
self.event_closed = trio.Event()
|
||||
self.event_started = trio.Event()
|
||||
|
||||
self._tasks = []
|
||||
|
||||
# Kick off reading
|
||||
self._tasks.append(asyncio.ensure_future(self.handle_incoming()))
|
||||
async def start(self) -> None:
|
||||
await self.handle_incoming()
|
||||
|
||||
@property
|
||||
def is_initiator(self) -> bool:
|
||||
|
@ -85,6 +90,7 @@ class Mplex(IMuxedConn):
|
|||
# Blocked until `close` is finally set.
|
||||
await self.event_closed.wait()
|
||||
|
||||
@property
|
||||
def is_closed(self) -> bool:
|
||||
"""
|
||||
check connection is fully closed.
|
||||
|
@ -104,9 +110,13 @@ class Mplex(IMuxedConn):
|
|||
return next_id
|
||||
|
||||
async def _initialize_stream(self, stream_id: StreamID, name: str) -> MplexStream:
|
||||
stream = MplexStream(name, stream_id, self)
|
||||
send_channel, receive_channel = trio.open_memory_channel[bytes](
|
||||
MPLEX_MESSAGE_CHANNEL_SIZE
|
||||
)
|
||||
stream = MplexStream(name, stream_id, self, receive_channel)
|
||||
async with self.streams_lock:
|
||||
self.streams[stream_id] = stream
|
||||
self.streams_msg_channels[stream_id] = send_channel
|
||||
return stream
|
||||
|
||||
async def open_stream(self) -> IMuxedStream:
|
||||
|
@ -123,27 +133,12 @@ class Mplex(IMuxedConn):
|
|||
await self.send_message(HeaderTags.NewStream, name.encode(), stream_id)
|
||||
return stream
|
||||
|
||||
async def _wait_until_shutting_down_or_closed(self, coro: Awaitable[Any]) -> Any:
|
||||
task_coro = asyncio.ensure_future(coro)
|
||||
task_wait_closed = asyncio.ensure_future(self.event_closed.wait())
|
||||
task_wait_shutting_down = asyncio.ensure_future(self.event_shutting_down.wait())
|
||||
done, pending = await asyncio.wait(
|
||||
[task_coro, task_wait_closed, task_wait_shutting_down],
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
)
|
||||
for fut in pending:
|
||||
fut.cancel()
|
||||
if task_wait_closed in done:
|
||||
raise MplexUnavailable("Mplex is closed")
|
||||
if task_wait_shutting_down in done:
|
||||
raise MplexUnavailable("Mplex is shutting down")
|
||||
return task_coro.result()
|
||||
|
||||
async def accept_stream(self) -> IMuxedStream:
|
||||
"""accepts a muxed stream opened by the other end."""
|
||||
return await self._wait_until_shutting_down_or_closed(
|
||||
self.new_stream_queue.get()
|
||||
)
|
||||
try:
|
||||
return await self.new_stream_receive_channel.receive()
|
||||
except trio.EndOfChannel:
|
||||
raise MplexUnavailable
|
||||
|
||||
async def send_message(
|
||||
self, flag: HeaderTags, data: Optional[bytes], stream_id: StreamID
|
||||
|
@ -151,7 +146,7 @@ class Mplex(IMuxedConn):
|
|||
"""
|
||||
sends a message over the connection.
|
||||
|
||||
:param header: header to use
|
||||
:param flag: header to use
|
||||
:param data: data to send in the message
|
||||
:param stream_id: stream the message is in
|
||||
"""
|
||||
|
@ -163,9 +158,7 @@ class Mplex(IMuxedConn):
|
|||
|
||||
_bytes = header + encode_varint_prefixed(data)
|
||||
|
||||
return await self._wait_until_shutting_down_or_closed(
|
||||
self.write_to_stream(_bytes)
|
||||
)
|
||||
return await self.write_to_stream(_bytes)
|
||||
|
||||
async def write_to_stream(self, _bytes: bytes) -> int:
|
||||
"""
|
||||
|
@ -174,21 +167,25 @@ class Mplex(IMuxedConn):
|
|||
:param _bytes: byte array to write
|
||||
:return: length written
|
||||
"""
|
||||
await self.secured_conn.write(_bytes)
|
||||
try:
|
||||
await self.secured_conn.write(_bytes)
|
||||
except RawConnError as e:
|
||||
raise MplexUnavailable(
|
||||
"failed to write message to the underlying connection"
|
||||
) from e
|
||||
|
||||
return len(_bytes)
|
||||
|
||||
async def handle_incoming(self) -> None:
|
||||
"""Read a message off of the secured connection and add it to the
|
||||
corresponding message buffer."""
|
||||
|
||||
self.event_started.set()
|
||||
while True:
|
||||
try:
|
||||
await self._handle_incoming_message()
|
||||
except MplexUnavailable as e:
|
||||
logger.debug("mplex unavailable while waiting for incoming: %s", e)
|
||||
break
|
||||
# Force context switch
|
||||
await asyncio.sleep(0)
|
||||
# If we enter here, it means this connection is shutting down.
|
||||
# We should clean things up.
|
||||
await self._cleanup()
|
||||
|
@ -200,20 +197,19 @@ class Mplex(IMuxedConn):
|
|||
:return: stream_id, flag, message contents
|
||||
"""
|
||||
|
||||
# FIXME: No timeout is used in Go implementation.
|
||||
try:
|
||||
header = await decode_uvarint_from_stream(self.secured_conn)
|
||||
message = await asyncio.wait_for(
|
||||
read_varint_prefixed_bytes(self.secured_conn), timeout=5
|
||||
)
|
||||
except (ParseError, RawConnError, IncompleteReadError) as error:
|
||||
raise MplexUnavailable(
|
||||
"failed to read messages correctly from the underlying connection"
|
||||
) from error
|
||||
except asyncio.TimeoutError as error:
|
||||
f"failed to read the header correctly from the underlying connection: {error}"
|
||||
)
|
||||
try:
|
||||
message = await read_varint_prefixed_bytes(self.secured_conn)
|
||||
except (ParseError, RawConnError, IncompleteReadError) as error:
|
||||
raise MplexUnavailable(
|
||||
"failed to read more message body within the timeout"
|
||||
) from error
|
||||
"failed to read the message body correctly from the underlying connection: "
|
||||
f"{error}"
|
||||
)
|
||||
|
||||
flag = header & 0x07
|
||||
channel_id = header >> 3
|
||||
|
@ -226,9 +222,7 @@ class Mplex(IMuxedConn):
|
|||
|
||||
:raise MplexUnavailable: `Mplex` encounters fatal error or is shutting down.
|
||||
"""
|
||||
channel_id, flag, message = await self._wait_until_shutting_down_or_closed(
|
||||
self.read_message()
|
||||
)
|
||||
channel_id, flag, message = await self.read_message()
|
||||
stream_id = StreamID(channel_id=channel_id, is_initiator=bool(flag & 1))
|
||||
|
||||
if flag == HeaderTags.NewStream.value:
|
||||
|
@ -258,9 +252,10 @@ class Mplex(IMuxedConn):
|
|||
f"received NewStream message for existing stream: {stream_id}"
|
||||
)
|
||||
mplex_stream = await self._initialize_stream(stream_id, message.decode())
|
||||
await self._wait_until_shutting_down_or_closed(
|
||||
self.new_stream_queue.put(mplex_stream)
|
||||
)
|
||||
try:
|
||||
await self.new_stream_send_channel.send(mplex_stream)
|
||||
except trio.ClosedResourceError:
|
||||
raise MplexUnavailable
|
||||
|
||||
async def _handle_message(self, stream_id: StreamID, message: bytes) -> None:
|
||||
async with self.streams_lock:
|
||||
|
@ -270,13 +265,21 @@ class Mplex(IMuxedConn):
|
|||
# TODO: Warn and emit logs about this.
|
||||
return
|
||||
stream = self.streams[stream_id]
|
||||
send_channel = self.streams_msg_channels[stream_id]
|
||||
async with stream.close_lock:
|
||||
if stream.event_remote_closed.is_set():
|
||||
# TODO: Warn "Received data from remote after stream was closed by them. (len = %d)" # noqa: E501
|
||||
return
|
||||
await self._wait_until_shutting_down_or_closed(
|
||||
stream.incoming_data.put(message)
|
||||
)
|
||||
try:
|
||||
send_channel.send_nowait(message)
|
||||
except (trio.BrokenResourceError, trio.ClosedResourceError):
|
||||
raise MplexUnavailable
|
||||
except trio.WouldBlock:
|
||||
# `send_channel` is full, reset this stream.
|
||||
logger.warning(
|
||||
"message channel of stream %s is full: stream is reset", stream_id
|
||||
)
|
||||
await stream.reset()
|
||||
|
||||
async def _handle_close(self, stream_id: StreamID) -> None:
|
||||
async with self.streams_lock:
|
||||
|
@ -284,6 +287,8 @@ class Mplex(IMuxedConn):
|
|||
# Ignore unmatched messages for now.
|
||||
return
|
||||
stream = self.streams[stream_id]
|
||||
send_channel = self.streams_msg_channels[stream_id]
|
||||
await send_channel.aclose()
|
||||
# NOTE: If remote is already closed, then return: Technically a bug
|
||||
# on the other side. We should consider killing the connection.
|
||||
async with stream.close_lock:
|
||||
|
@ -305,27 +310,30 @@ class Mplex(IMuxedConn):
|
|||
# This is *ok*. We forget the stream on reset.
|
||||
return
|
||||
stream = self.streams[stream_id]
|
||||
|
||||
send_channel = self.streams_msg_channels[stream_id]
|
||||
await send_channel.aclose()
|
||||
async with stream.close_lock:
|
||||
if not stream.event_remote_closed.is_set():
|
||||
stream.event_reset.set()
|
||||
|
||||
stream.event_remote_closed.set()
|
||||
# If local is not closed, we should close it.
|
||||
if not stream.event_local_closed.is_set():
|
||||
stream.event_local_closed.set()
|
||||
async with self.streams_lock:
|
||||
self.streams.pop(stream_id, None)
|
||||
self.streams_msg_channels.pop(stream_id, None)
|
||||
|
||||
async def _cleanup(self) -> None:
|
||||
if not self.event_shutting_down.is_set():
|
||||
self.event_shutting_down.set()
|
||||
async with self.streams_lock:
|
||||
for stream in self.streams.values():
|
||||
for stream_id, stream in self.streams.items():
|
||||
async with stream.close_lock:
|
||||
if not stream.event_remote_closed.is_set():
|
||||
stream.event_remote_closed.set()
|
||||
stream.event_reset.set()
|
||||
stream.event_local_closed.set()
|
||||
self.streams = None
|
||||
send_channel = self.streams_msg_channels[stream_id]
|
||||
await send_channel.aclose()
|
||||
self.event_closed.set()
|
||||
await self.new_stream_send_channel.aclose()
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
import asyncio
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import trio
|
||||
|
||||
from libp2p.stream_muxer.abc import IMuxedStream
|
||||
from libp2p.stream_muxer.exceptions import MuxedConnUnavailable
|
||||
|
||||
from .constants import HeaderTags
|
||||
from .datastructures import StreamID
|
||||
|
@ -22,18 +24,25 @@ class MplexStream(IMuxedStream):
|
|||
read_deadline: int
|
||||
write_deadline: int
|
||||
|
||||
close_lock: asyncio.Lock
|
||||
# TODO: Add lock for read/write to avoid interleaving receiving messages?
|
||||
close_lock: trio.Lock
|
||||
|
||||
# NOTE: `dataIn` is size of 8 in Go implementation.
|
||||
incoming_data: "asyncio.Queue[bytes]"
|
||||
incoming_data_channel: "trio.MemoryReceiveChannel[bytes]"
|
||||
|
||||
event_local_closed: asyncio.Event
|
||||
event_remote_closed: asyncio.Event
|
||||
event_reset: asyncio.Event
|
||||
event_local_closed: trio.Event
|
||||
event_remote_closed: trio.Event
|
||||
event_reset: trio.Event
|
||||
|
||||
_buf: bytearray
|
||||
|
||||
def __init__(self, name: str, stream_id: StreamID, muxed_conn: "Mplex") -> None:
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
stream_id: StreamID,
|
||||
muxed_conn: "Mplex",
|
||||
incoming_data_channel: "trio.MemoryReceiveChannel[bytes]",
|
||||
) -> None:
|
||||
"""
|
||||
create new MuxedStream in muxer.
|
||||
|
||||
|
@ -45,99 +54,82 @@ class MplexStream(IMuxedStream):
|
|||
self.muxed_conn = muxed_conn
|
||||
self.read_deadline = None
|
||||
self.write_deadline = None
|
||||
self.event_local_closed = asyncio.Event()
|
||||
self.event_remote_closed = asyncio.Event()
|
||||
self.event_reset = asyncio.Event()
|
||||
self.close_lock = asyncio.Lock()
|
||||
self.incoming_data = asyncio.Queue()
|
||||
self.event_local_closed = trio.Event()
|
||||
self.event_remote_closed = trio.Event()
|
||||
self.event_reset = trio.Event()
|
||||
self.close_lock = trio.Lock()
|
||||
self.incoming_data_channel = incoming_data_channel
|
||||
self._buf = bytearray()
|
||||
|
||||
@property
|
||||
def is_initiator(self) -> bool:
|
||||
return self.stream_id.is_initiator
|
||||
|
||||
async def _wait_for_data(self) -> None:
|
||||
task_event_reset = asyncio.ensure_future(self.event_reset.wait())
|
||||
task_incoming_data_get = asyncio.ensure_future(self.incoming_data.get())
|
||||
task_event_remote_closed = asyncio.ensure_future(
|
||||
self.event_remote_closed.wait()
|
||||
)
|
||||
done, pending = await asyncio.wait( # type: ignore
|
||||
[ # type: ignore
|
||||
task_event_reset,
|
||||
task_incoming_data_get,
|
||||
task_event_remote_closed,
|
||||
],
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
)
|
||||
for fut in pending:
|
||||
fut.cancel()
|
||||
|
||||
if task_event_reset in done:
|
||||
if self.event_reset.is_set():
|
||||
raise MplexStreamReset()
|
||||
else:
|
||||
# However, it is abnormal that `Event.wait` is unblocked without any of the flag
|
||||
# is set. The task is probably cancelled.
|
||||
raise Exception(
|
||||
"Should not enter here. "
|
||||
f"It is probably because {task_event_remote_closed} is cancelled."
|
||||
)
|
||||
|
||||
if task_incoming_data_get in done:
|
||||
data = task_incoming_data_get.result()
|
||||
self._buf.extend(data)
|
||||
return
|
||||
|
||||
if task_event_remote_closed in done:
|
||||
if self.event_remote_closed.is_set():
|
||||
raise MplexStreamEOF()
|
||||
else:
|
||||
# However, it is abnormal that `Event.wait` is unblocked without any of the flag
|
||||
# is set. The task is probably cancelled.
|
||||
raise Exception(
|
||||
"Should not enter here. "
|
||||
f"It is probably because {task_event_remote_closed} is cancelled."
|
||||
)
|
||||
|
||||
# TODO: Handle timeout when deadline is used.
|
||||
|
||||
async def _read_until_eof(self) -> bytes:
|
||||
while True:
|
||||
try:
|
||||
await self._wait_for_data()
|
||||
except MplexStreamEOF:
|
||||
break
|
||||
async for data in self.incoming_data_channel:
|
||||
self._buf.extend(data)
|
||||
payload = self._buf
|
||||
self._buf = self._buf[len(payload) :]
|
||||
return bytes(payload)
|
||||
|
||||
async def read(self, n: int = -1) -> bytes:
|
||||
def _read_return_when_blocked(self) -> bytes:
|
||||
buf = bytearray()
|
||||
while True:
|
||||
try:
|
||||
data = self.incoming_data_channel.receive_nowait()
|
||||
buf.extend(data)
|
||||
except (trio.WouldBlock, trio.EndOfChannel):
|
||||
break
|
||||
return buf
|
||||
|
||||
async def read(self, n: int = None) -> bytes:
|
||||
"""
|
||||
Read up to n bytes. Read possibly returns fewer than `n` bytes, if
|
||||
there are not enough bytes in the Mplex buffer. If `n == -1`, read
|
||||
there are not enough bytes in the Mplex buffer. If `n is None`, read
|
||||
until EOF.
|
||||
|
||||
:param n: number of bytes to read
|
||||
:return: bytes actually read
|
||||
"""
|
||||
if n < 0 and n != -1:
|
||||
if n is not None and n < 0:
|
||||
raise ValueError(
|
||||
f"the number of bytes to read `n` must be positive or -1 to indicate read until EOF"
|
||||
f"the number of bytes to read `n` must be non-negative or "
|
||||
"`None` to indicate read until EOF"
|
||||
)
|
||||
if self.event_reset.is_set():
|
||||
raise MplexStreamReset()
|
||||
if n == -1:
|
||||
raise MplexStreamReset
|
||||
if n is None:
|
||||
return await self._read_until_eof()
|
||||
if len(self._buf) == 0 and self.incoming_data.empty():
|
||||
await self._wait_for_data()
|
||||
# Now we are sure we have something to read.
|
||||
# Try to put enough incoming data into `self._buf`.
|
||||
while len(self._buf) < n:
|
||||
if len(self._buf) == 0:
|
||||
data: bytes
|
||||
# Peek whether there is data available. If yes, we just read until there is no data,
|
||||
# and then return.
|
||||
try:
|
||||
self._buf.extend(self.incoming_data.get_nowait())
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
data = self.incoming_data_channel.receive_nowait()
|
||||
self._buf.extend(data)
|
||||
except trio.EndOfChannel:
|
||||
raise MplexStreamEOF
|
||||
except trio.WouldBlock:
|
||||
# We know `receive` will be blocked here. Wait for data here with `receive` and
|
||||
# catch all kinds of errors here.
|
||||
try:
|
||||
data = await self.incoming_data_channel.receive()
|
||||
self._buf.extend(data)
|
||||
except trio.EndOfChannel:
|
||||
if self.event_reset.is_set():
|
||||
raise MplexStreamReset
|
||||
if self.event_remote_closed.is_set():
|
||||
raise MplexStreamEOF
|
||||
except trio.ClosedResourceError as error:
|
||||
# Probably `incoming_data_channel` is closed in `reset` when we are waiting
|
||||
# for `receive`.
|
||||
if self.event_reset.is_set():
|
||||
raise MplexStreamReset
|
||||
raise Exception(
|
||||
"`incoming_data_channel` is closed but stream is not reset. "
|
||||
"This should never happen."
|
||||
) from error
|
||||
self._buf.extend(self._read_return_when_blocked())
|
||||
payload = self._buf[:n]
|
||||
self._buf = self._buf[len(payload) :]
|
||||
return bytes(payload)
|
||||
|
@ -198,14 +190,17 @@ class MplexStream(IMuxedStream):
|
|||
if self.is_initiator
|
||||
else HeaderTags.ResetReceiver
|
||||
)
|
||||
asyncio.ensure_future(
|
||||
self.muxed_conn.send_message(flag, None, self.stream_id)
|
||||
)
|
||||
await asyncio.sleep(0)
|
||||
# Try to send reset message to the other side. Ignore if there is anything wrong.
|
||||
try:
|
||||
await self.muxed_conn.send_message(flag, None, self.stream_id)
|
||||
except MuxedConnUnavailable:
|
||||
pass
|
||||
|
||||
self.event_local_closed.set()
|
||||
self.event_remote_closed.set()
|
||||
|
||||
await self.incoming_data_channel.aclose()
|
||||
|
||||
async with self.muxed_conn.streams_lock:
|
||||
if self.muxed_conn.streams is not None:
|
||||
self.muxed_conn.streams.pop(self.stream_id, None)
|
||||
|
|
|
@ -7,7 +7,7 @@ from libp2p.pubsub import floodsub, gossipsub
|
|||
# Just a arbitrary large number.
|
||||
# It is used when calling `MplexStream.read(MAX_READ_LEN)`,
|
||||
# to avoid `MplexStream.read()`, which blocking reads until EOF.
|
||||
MAX_READ_LEN = 2 ** 32 - 1
|
||||
MAX_READ_LEN = 65535
|
||||
|
||||
|
||||
LISTEN_MADDR = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/0")
|
||||
|
|
|
@ -1,40 +1,55 @@
|
|||
import asyncio
|
||||
from typing import Any, AsyncIterator, Dict, Tuple, cast
|
||||
from typing import Any, AsyncIterator, Dict, List, Sequence, Tuple, cast
|
||||
|
||||
# NOTE: import ``asynccontextmanager`` from ``contextlib`` when support for python 3.6 is dropped.
|
||||
from async_exit_stack import AsyncExitStack
|
||||
from async_generator import asynccontextmanager
|
||||
from async_service import background_trio_service
|
||||
import factory
|
||||
from multiaddr import Multiaddr
|
||||
import trio
|
||||
|
||||
from libp2p import generate_new_rsa_identity, generate_peer_id_from
|
||||
from libp2p.crypto.keys import KeyPair
|
||||
from libp2p.host.basic_host import BasicHost
|
||||
from libp2p.host.host_interface import IHost
|
||||
from libp2p.host.routed_host import RoutedHost
|
||||
from libp2p.io.abc import ReadWriteCloser
|
||||
from libp2p.network.connection.raw_connection import RawConnection
|
||||
from libp2p.network.connection.raw_connection_interface import IRawConnection
|
||||
from libp2p.network.connection.swarm_connection import SwarmConn
|
||||
from libp2p.network.stream.net_stream_interface import INetStream
|
||||
from libp2p.network.swarm import Swarm
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.peer.peerinfo import PeerInfo
|
||||
from libp2p.peer.peerstore import PeerStore
|
||||
from libp2p.pubsub.abc import IPubsubRouter
|
||||
from libp2p.pubsub.floodsub import FloodSub
|
||||
from libp2p.pubsub.gossipsub import GossipSub
|
||||
from libp2p.pubsub.pubsub import Pubsub
|
||||
from libp2p.routing.interfaces import IPeerRouting
|
||||
from libp2p.security.base_transport import BaseSecureTransport
|
||||
from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport
|
||||
import libp2p.security.secio.transport as secio
|
||||
from libp2p.stream_muxer.mplex.mplex import MPLEX_PROTOCOL_ID, Mplex
|
||||
from libp2p.stream_muxer.mplex.mplex_stream import MplexStream
|
||||
from libp2p.tools.constants import GOSSIPSUB_PARAMS
|
||||
from libp2p.transport.tcp.tcp import TCP
|
||||
from libp2p.transport.typing import TMuxerOptions
|
||||
from libp2p.transport.upgrader import TransportUpgrader
|
||||
from libp2p.typing import TProtocol
|
||||
|
||||
from .constants import (
|
||||
FLOODSUB_PROTOCOL_ID,
|
||||
GOSSIPSUB_PARAMS,
|
||||
GOSSIPSUB_PROTOCOL_ID,
|
||||
LISTEN_MADDR,
|
||||
)
|
||||
from .constants import FLOODSUB_PROTOCOL_ID, GOSSIPSUB_PROTOCOL_ID, LISTEN_MADDR
|
||||
from .utils import connect, connect_swarm
|
||||
|
||||
|
||||
class IDFactory(factory.Factory):
|
||||
class Meta:
|
||||
model = ID
|
||||
|
||||
peer_id_bytes = factory.LazyFunction(
|
||||
lambda: generate_peer_id_from(generate_new_rsa_identity())
|
||||
)
|
||||
|
||||
|
||||
def initialize_peerstore_with_our_keypair(self_id: ID, key_pair: KeyPair) -> PeerStore:
|
||||
peer_store = PeerStore()
|
||||
peer_store.add_key_pair(self_id, key_pair)
|
||||
|
@ -50,6 +65,29 @@ def security_transport_factory(
|
|||
return {secio.ID: secio.Transport(key_pair)}
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def raw_conn_factory(
|
||||
nursery: trio.Nursery
|
||||
) -> AsyncIterator[Tuple[IRawConnection, IRawConnection]]:
|
||||
conn_0 = None
|
||||
conn_1 = None
|
||||
event = trio.Event()
|
||||
|
||||
async def tcp_stream_handler(stream: ReadWriteCloser) -> None:
|
||||
nonlocal conn_1
|
||||
conn_1 = RawConnection(stream, initiator=False)
|
||||
event.set()
|
||||
await trio.sleep_forever()
|
||||
|
||||
tcp_transport = TCP()
|
||||
listener = tcp_transport.create_listener(tcp_stream_handler)
|
||||
await listener.listen(LISTEN_MADDR, nursery)
|
||||
listening_maddr = listener.get_addrs()[0]
|
||||
conn_0 = await tcp_transport.dial(listening_maddr)
|
||||
await event.wait()
|
||||
yield conn_0, conn_1
|
||||
|
||||
|
||||
class SwarmFactory(factory.Factory):
|
||||
class Meta:
|
||||
model = Swarm
|
||||
|
@ -71,9 +109,10 @@ class SwarmFactory(factory.Factory):
|
|||
transport = factory.LazyFunction(TCP)
|
||||
|
||||
@classmethod
|
||||
@asynccontextmanager
|
||||
async def create_and_listen(
|
||||
cls, is_secure: bool, key_pair: KeyPair = None, muxer_opt: TMuxerOptions = None
|
||||
) -> Swarm:
|
||||
) -> AsyncIterator[Swarm]:
|
||||
# `factory.Factory.__init__` does *not* prepare a *default value* if we pass
|
||||
# an argument explicitly with `None`. If an argument is `None`, we don't pass it to
|
||||
# `factory.Factory.__init__`, in order to let the function initialize it.
|
||||
|
@ -83,20 +122,23 @@ class SwarmFactory(factory.Factory):
|
|||
if muxer_opt is not None:
|
||||
optional_kwargs["muxer_opt"] = muxer_opt
|
||||
swarm = cls(is_secure=is_secure, **optional_kwargs)
|
||||
await swarm.listen(LISTEN_MADDR)
|
||||
return swarm
|
||||
async with background_trio_service(swarm):
|
||||
await swarm.listen(LISTEN_MADDR)
|
||||
yield swarm
|
||||
|
||||
@classmethod
|
||||
@asynccontextmanager
|
||||
async def create_batch_and_listen(
|
||||
cls, is_secure: bool, number: int, muxer_opt: TMuxerOptions = None
|
||||
) -> Tuple[Swarm, ...]:
|
||||
# Ignore typing since we are removing asyncio soon
|
||||
return await asyncio.gather( # type: ignore
|
||||
*[
|
||||
cls.create_and_listen(is_secure=is_secure, muxer_opt=muxer_opt)
|
||||
) -> AsyncIterator[Tuple[Swarm, ...]]:
|
||||
async with AsyncExitStack() as stack:
|
||||
ctx_mgrs = [
|
||||
await stack.enter_async_context(
|
||||
cls.create_and_listen(is_secure=is_secure, muxer_opt=muxer_opt)
|
||||
)
|
||||
for _ in range(number)
|
||||
]
|
||||
)
|
||||
yield tuple(ctx_mgrs)
|
||||
|
||||
|
||||
class HostFactory(factory.Factory):
|
||||
|
@ -107,22 +149,57 @@ class HostFactory(factory.Factory):
|
|||
is_secure = False
|
||||
key_pair = factory.LazyFunction(generate_new_rsa_identity)
|
||||
|
||||
network = factory.LazyAttribute(
|
||||
lambda o: SwarmFactory(is_secure=o.is_secure, key_pair=o.key_pair)
|
||||
)
|
||||
network = factory.LazyAttribute(lambda o: SwarmFactory(is_secure=o.is_secure))
|
||||
|
||||
@classmethod
|
||||
@asynccontextmanager
|
||||
async def create_batch_and_listen(
|
||||
cls, is_secure: bool, number: int
|
||||
) -> Tuple[BasicHost, ...]:
|
||||
key_pairs = [generate_new_rsa_identity() for _ in range(number)]
|
||||
swarms = await asyncio.gather(
|
||||
*[
|
||||
SwarmFactory.create_and_listen(is_secure, key_pair)
|
||||
for key_pair in key_pairs
|
||||
]
|
||||
)
|
||||
return tuple(BasicHost(swarm) for swarm in swarms)
|
||||
) -> AsyncIterator[Tuple[BasicHost, ...]]:
|
||||
async with SwarmFactory.create_batch_and_listen(is_secure, number) as swarms:
|
||||
hosts = tuple(BasicHost(swarm) for swarm in swarms)
|
||||
yield hosts
|
||||
|
||||
|
||||
class DummyRouter(IPeerRouting):
|
||||
_routing_table: Dict[ID, PeerInfo]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._routing_table = dict()
|
||||
|
||||
def _add_peer(self, peer_id: ID, addrs: List[Multiaddr]) -> None:
|
||||
self._routing_table[peer_id] = PeerInfo(peer_id, addrs)
|
||||
|
||||
async def find_peer(self, peer_id: ID) -> PeerInfo:
|
||||
await trio.hazmat.checkpoint()
|
||||
return self._routing_table.get(peer_id, None)
|
||||
|
||||
|
||||
class RoutedHostFactory(factory.Factory):
|
||||
class Meta:
|
||||
model = RoutedHost
|
||||
|
||||
class Params:
|
||||
is_secure = False
|
||||
|
||||
network = factory.LazyAttribute(
|
||||
lambda o: HostFactory(is_secure=o.is_secure).get_network()
|
||||
)
|
||||
router = factory.LazyFunction(DummyRouter)
|
||||
|
||||
@classmethod
|
||||
@asynccontextmanager
|
||||
async def create_batch_and_listen(
|
||||
cls, is_secure: bool, number: int
|
||||
) -> AsyncIterator[Tuple[RoutedHost, ...]]:
|
||||
routing_table = DummyRouter()
|
||||
async with HostFactory.create_batch_and_listen(is_secure, number) as hosts:
|
||||
for host in hosts:
|
||||
routing_table._add_peer(host.get_id(), host.get_addrs())
|
||||
routed_hosts = tuple(
|
||||
RoutedHost(host.get_network(), routing_table) for host in hosts
|
||||
)
|
||||
yield routed_hosts
|
||||
|
||||
|
||||
class FloodsubFactory(factory.Factory):
|
||||
|
@ -153,89 +230,192 @@ class PubsubFactory(factory.Factory):
|
|||
|
||||
host = factory.SubFactory(HostFactory)
|
||||
router = None
|
||||
my_id = factory.LazyAttribute(lambda obj: obj.host.get_id())
|
||||
cache_size = None
|
||||
strict_signing = False
|
||||
|
||||
@classmethod
|
||||
@asynccontextmanager
|
||||
async def create_and_start(
|
||||
cls, host: IHost, router: IPubsubRouter, cache_size: int, strict_signing: bool
|
||||
) -> AsyncIterator[Pubsub]:
|
||||
pubsub = cls(
|
||||
host=host,
|
||||
router=router,
|
||||
cache_size=cache_size,
|
||||
strict_signing=strict_signing,
|
||||
)
|
||||
async with background_trio_service(pubsub):
|
||||
await pubsub.wait_until_ready()
|
||||
yield pubsub
|
||||
|
||||
@classmethod
|
||||
@asynccontextmanager
|
||||
async def _create_batch_with_router(
|
||||
cls,
|
||||
number: int,
|
||||
routers: Sequence[IPubsubRouter],
|
||||
is_secure: bool = False,
|
||||
cache_size: int = None,
|
||||
strict_signing: bool = False,
|
||||
) -> AsyncIterator[Tuple[Pubsub, ...]]:
|
||||
async with HostFactory.create_batch_and_listen(is_secure, number) as hosts:
|
||||
# Pubsubs should exit before hosts
|
||||
async with AsyncExitStack() as stack:
|
||||
pubsubs = [
|
||||
await stack.enter_async_context(
|
||||
cls.create_and_start(host, router, cache_size, strict_signing)
|
||||
)
|
||||
for host, router in zip(hosts, routers)
|
||||
]
|
||||
yield tuple(pubsubs)
|
||||
|
||||
@classmethod
|
||||
@asynccontextmanager
|
||||
async def create_batch_with_floodsub(
|
||||
cls,
|
||||
number: int,
|
||||
is_secure: bool = False,
|
||||
cache_size: int = None,
|
||||
strict_signing: bool = False,
|
||||
protocols: Sequence[TProtocol] = None,
|
||||
) -> AsyncIterator[Tuple[Pubsub, ...]]:
|
||||
if protocols is not None:
|
||||
floodsubs = FloodsubFactory.create_batch(number, protocols=list(protocols))
|
||||
else:
|
||||
floodsubs = FloodsubFactory.create_batch(number)
|
||||
async with cls._create_batch_with_router(
|
||||
number, floodsubs, is_secure, cache_size, strict_signing
|
||||
) as pubsubs:
|
||||
yield pubsubs
|
||||
|
||||
@classmethod
|
||||
@asynccontextmanager
|
||||
async def create_batch_with_gossipsub(
|
||||
cls,
|
||||
number: int,
|
||||
*,
|
||||
is_secure: bool = False,
|
||||
cache_size: int = None,
|
||||
strict_signing: bool = False,
|
||||
protocols: Sequence[TProtocol] = None,
|
||||
degree: int = GOSSIPSUB_PARAMS.degree,
|
||||
degree_low: int = GOSSIPSUB_PARAMS.degree_low,
|
||||
degree_high: int = GOSSIPSUB_PARAMS.degree_high,
|
||||
time_to_live: int = GOSSIPSUB_PARAMS.time_to_live,
|
||||
gossip_window: int = GOSSIPSUB_PARAMS.gossip_window,
|
||||
gossip_history: int = GOSSIPSUB_PARAMS.gossip_history,
|
||||
heartbeat_interval: float = GOSSIPSUB_PARAMS.heartbeat_interval,
|
||||
heartbeat_initial_delay: float = GOSSIPSUB_PARAMS.heartbeat_initial_delay,
|
||||
) -> AsyncIterator[Tuple[Pubsub, ...]]:
|
||||
if protocols is not None:
|
||||
gossipsubs = GossipsubFactory.create_batch(
|
||||
number,
|
||||
protocols=protocols,
|
||||
degree=degree,
|
||||
degree_low=degree_low,
|
||||
degree_high=degree_high,
|
||||
time_to_live=time_to_live,
|
||||
gossip_window=gossip_window,
|
||||
heartbeat_interval=heartbeat_interval,
|
||||
)
|
||||
else:
|
||||
gossipsubs = GossipsubFactory.create_batch(
|
||||
number,
|
||||
degree=degree,
|
||||
degree_low=degree_low,
|
||||
degree_high=degree_high,
|
||||
time_to_live=time_to_live,
|
||||
gossip_window=gossip_window,
|
||||
heartbeat_interval=heartbeat_interval,
|
||||
)
|
||||
|
||||
async with cls._create_batch_with_router(
|
||||
number, gossipsubs, is_secure, cache_size, strict_signing
|
||||
) as pubsubs:
|
||||
async with AsyncExitStack() as stack:
|
||||
for router in gossipsubs:
|
||||
await stack.enter_async_context(background_trio_service(router))
|
||||
yield pubsubs
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def swarm_pair_factory(
|
||||
is_secure: bool, muxer_opt: TMuxerOptions = None
|
||||
) -> Tuple[Swarm, Swarm]:
|
||||
swarms = await SwarmFactory.create_batch_and_listen(
|
||||
) -> AsyncIterator[Tuple[Swarm, Swarm]]:
|
||||
async with SwarmFactory.create_batch_and_listen(
|
||||
is_secure, 2, muxer_opt=muxer_opt
|
||||
)
|
||||
await connect_swarm(swarms[0], swarms[1])
|
||||
return swarms[0], swarms[1]
|
||||
) as swarms:
|
||||
await connect_swarm(swarms[0], swarms[1])
|
||||
yield swarms[0], swarms[1]
|
||||
|
||||
|
||||
async def host_pair_factory(is_secure: bool) -> Tuple[BasicHost, BasicHost]:
|
||||
hosts = await HostFactory.create_batch_and_listen(is_secure, 2)
|
||||
await connect(hosts[0], hosts[1])
|
||||
return hosts[0], hosts[1]
|
||||
|
||||
|
||||
@asynccontextmanager # type: ignore
|
||||
async def pair_of_connected_hosts(
|
||||
is_secure: bool = True
|
||||
@asynccontextmanager
|
||||
async def host_pair_factory(
|
||||
is_secure: bool
|
||||
) -> AsyncIterator[Tuple[BasicHost, BasicHost]]:
|
||||
a, b = await host_pair_factory(is_secure)
|
||||
yield a, b
|
||||
close_tasks = (a.close(), b.close())
|
||||
await asyncio.gather(*close_tasks)
|
||||
async with HostFactory.create_batch_and_listen(is_secure, 2) as hosts:
|
||||
await connect(hosts[0], hosts[1])
|
||||
yield hosts[0], hosts[1]
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def swarm_conn_pair_factory(
|
||||
is_secure: bool, muxer_opt: TMuxerOptions = None
|
||||
) -> Tuple[SwarmConn, Swarm, SwarmConn, Swarm]:
|
||||
swarms = await swarm_pair_factory(is_secure)
|
||||
conn_0 = swarms[0].connections[swarms[1].get_peer_id()]
|
||||
conn_1 = swarms[1].connections[swarms[0].get_peer_id()]
|
||||
return cast(SwarmConn, conn_0), swarms[0], cast(SwarmConn, conn_1), swarms[1]
|
||||
) -> AsyncIterator[Tuple[SwarmConn, SwarmConn]]:
|
||||
async with swarm_pair_factory(is_secure) as swarms:
|
||||
conn_0 = swarms[0].connections[swarms[1].get_peer_id()]
|
||||
conn_1 = swarms[1].connections[swarms[0].get_peer_id()]
|
||||
yield cast(SwarmConn, conn_0), cast(SwarmConn, conn_1)
|
||||
|
||||
|
||||
async def mplex_conn_pair_factory(is_secure: bool) -> Tuple[Mplex, Swarm, Mplex, Swarm]:
|
||||
@asynccontextmanager
|
||||
async def mplex_conn_pair_factory(
|
||||
is_secure: bool
|
||||
) -> AsyncIterator[Tuple[Mplex, Mplex]]:
|
||||
muxer_opt = {MPLEX_PROTOCOL_ID: Mplex}
|
||||
conn_0, swarm_0, conn_1, swarm_1 = await swarm_conn_pair_factory(
|
||||
is_secure, muxer_opt=muxer_opt
|
||||
)
|
||||
return (
|
||||
cast(Mplex, conn_0.muxed_conn),
|
||||
swarm_0,
|
||||
cast(Mplex, conn_1.muxed_conn),
|
||||
swarm_1,
|
||||
)
|
||||
async with swarm_conn_pair_factory(is_secure, muxer_opt=muxer_opt) as swarm_pair:
|
||||
yield (
|
||||
cast(Mplex, swarm_pair[0].muxed_conn),
|
||||
cast(Mplex, swarm_pair[1].muxed_conn),
|
||||
)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def mplex_stream_pair_factory(
|
||||
is_secure: bool
|
||||
) -> Tuple[MplexStream, Swarm, MplexStream, Swarm]:
|
||||
mplex_conn_0, swarm_0, mplex_conn_1, swarm_1 = await mplex_conn_pair_factory(
|
||||
is_secure
|
||||
)
|
||||
stream_0 = await mplex_conn_0.open_stream()
|
||||
await asyncio.sleep(0.01)
|
||||
stream_1: MplexStream
|
||||
async with mplex_conn_1.streams_lock:
|
||||
if len(mplex_conn_1.streams) != 1:
|
||||
raise Exception("Mplex should not have any stream upon connection")
|
||||
stream_1 = tuple(mplex_conn_1.streams.values())[0]
|
||||
return cast(MplexStream, stream_0), swarm_0, stream_1, swarm_1
|
||||
) -> AsyncIterator[Tuple[MplexStream, MplexStream]]:
|
||||
async with mplex_conn_pair_factory(is_secure) as mplex_conn_pair_info:
|
||||
mplex_conn_0, mplex_conn_1 = mplex_conn_pair_info
|
||||
stream_0 = cast(MplexStream, await mplex_conn_0.open_stream())
|
||||
await trio.sleep(0.01)
|
||||
stream_1: MplexStream
|
||||
async with mplex_conn_1.streams_lock:
|
||||
if len(mplex_conn_1.streams) != 1:
|
||||
raise Exception("Mplex should not have any other stream")
|
||||
stream_1 = tuple(mplex_conn_1.streams.values())[0]
|
||||
yield stream_0, stream_1
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def net_stream_pair_factory(
|
||||
is_secure: bool
|
||||
) -> Tuple[INetStream, BasicHost, INetStream, BasicHost]:
|
||||
) -> AsyncIterator[Tuple[INetStream, INetStream]]:
|
||||
protocol_id = TProtocol("/example/id/1")
|
||||
|
||||
stream_1: INetStream
|
||||
|
||||
# Just a proxy, we only care about the stream
|
||||
def handler(stream: INetStream) -> None:
|
||||
# Just a proxy, we only care about the stream.
|
||||
# Add a barrier to avoid stream being removed.
|
||||
event_handler_finished = trio.Event()
|
||||
|
||||
async def handler(stream: INetStream) -> None:
|
||||
nonlocal stream_1
|
||||
stream_1 = stream
|
||||
await event_handler_finished.wait()
|
||||
|
||||
host_0, host_1 = await host_pair_factory(is_secure)
|
||||
host_1.set_stream_handler(protocol_id, handler)
|
||||
async with host_pair_factory(is_secure) as hosts:
|
||||
hosts[1].set_stream_handler(protocol_id, handler)
|
||||
|
||||
stream_0 = await host_0.new_stream(host_1.get_id(), [protocol_id])
|
||||
return stream_0, host_0, stream_1, host_1
|
||||
stream_0 = await hosts[0].new_stream(hosts[1].get_id(), [protocol_id])
|
||||
yield stream_0, stream_1
|
||||
event_handler_finished.set()
|
||||
|
|
|
@ -1,2 +1 @@
|
|||
LOCALHOST_IP = "127.0.0.1"
|
||||
PEXPECT_NEW_LINE = "\r\n"
|
||||
|
|
|
@ -1,52 +1,22 @@
|
|||
import asyncio
|
||||
import time
|
||||
from typing import Any, Awaitable, Callable, List
|
||||
from typing import AsyncIterator
|
||||
|
||||
from async_generator import asynccontextmanager
|
||||
import multiaddr
|
||||
from multiaddr import Multiaddr
|
||||
from p2pclient import Client
|
||||
import pytest
|
||||
import trio
|
||||
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.peer.peerinfo import PeerInfo, info_from_p2p_addr
|
||||
|
||||
from .constants import LOCALHOST_IP
|
||||
from .envs import GO_BIN_PATH
|
||||
from .process import BaseInteractiveProcess
|
||||
|
||||
P2PD_PATH = GO_BIN_PATH / "p2pd"
|
||||
|
||||
|
||||
TIMEOUT_DURATION = 30
|
||||
|
||||
|
||||
async def try_until_success(
|
||||
coro_func: Callable[[], Awaitable[Any]], timeout: int = TIMEOUT_DURATION
|
||||
) -> None:
|
||||
"""
|
||||
Keep running ``coro_func`` until either it succeed or time is up.
|
||||
|
||||
All arguments of ``coro_func`` should be filled, i.e. it should be
|
||||
called without arguments.
|
||||
"""
|
||||
t_start = time.monotonic()
|
||||
while True:
|
||||
result = await coro_func()
|
||||
if result:
|
||||
break
|
||||
if (time.monotonic() - t_start) >= timeout:
|
||||
# timeout
|
||||
pytest.fail(f"{coro_func} is still failing after `{timeout}` seconds")
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
|
||||
class P2PDProcess:
|
||||
proc: asyncio.subprocess.Process
|
||||
cmd: str = str(P2PD_PATH)
|
||||
args: List[Any]
|
||||
is_proc_running: bool
|
||||
|
||||
_tasks: List["asyncio.Future[Any]"]
|
||||
|
||||
class P2PDProcess(BaseInteractiveProcess):
|
||||
def __init__(
|
||||
self,
|
||||
control_maddr: Multiaddr,
|
||||
|
@ -75,74 +45,21 @@ class P2PDProcess:
|
|||
# - gossipsubHeartbeatInterval: GossipSubHeartbeatInitialDelay = 100 * time.Millisecond # noqa: E501
|
||||
# - gossipsubHeartbeatInitialDelay: GossipSubHeartbeatInterval = 1 * time.Second
|
||||
# Referece: https://github.com/libp2p/go-libp2p-daemon/blob/b95e77dbfcd186ccf817f51e95f73f9fd5982600/p2pd/main.go#L348-L353 # noqa: E501
|
||||
self.proc = None
|
||||
self.cmd = str(P2PD_PATH)
|
||||
self.args = args
|
||||
self.is_proc_running = False
|
||||
|
||||
self._tasks = []
|
||||
|
||||
async def wait_until_ready(self) -> None:
|
||||
lines_head_pattern = (b"Control socket:", b"Peer ID:", b"Peer Addrs:")
|
||||
lines_head_occurred = {line: False for line in lines_head_pattern}
|
||||
|
||||
async def read_from_daemon_and_check() -> bool:
|
||||
line = await self.proc.stdout.readline()
|
||||
for head_pattern in lines_head_occurred:
|
||||
if line.startswith(head_pattern):
|
||||
lines_head_occurred[head_pattern] = True
|
||||
return all([value for value in lines_head_occurred.values()])
|
||||
|
||||
await try_until_success(read_from_daemon_and_check)
|
||||
# Sleep a little bit to ensure the listener is up after logs are emitted.
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
async def start_printing_logs(self) -> None:
|
||||
async def _print_from_stream(
|
||||
src_name: str, reader: asyncio.StreamReader
|
||||
) -> None:
|
||||
while True:
|
||||
line = await reader.readline()
|
||||
if line != b"":
|
||||
print(f"{src_name}\t: {line.rstrip().decode()}")
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
self._tasks.append(
|
||||
asyncio.ensure_future(_print_from_stream("out", self.proc.stdout))
|
||||
)
|
||||
self._tasks.append(
|
||||
asyncio.ensure_future(_print_from_stream("err", self.proc.stderr))
|
||||
)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
async def start(self) -> None:
|
||||
if self.is_proc_running:
|
||||
return
|
||||
self.proc = await asyncio.subprocess.create_subprocess_exec(
|
||||
self.cmd,
|
||||
*self.args,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
bufsize=0,
|
||||
)
|
||||
self.is_proc_running = True
|
||||
await self.wait_until_ready()
|
||||
await self.start_printing_logs()
|
||||
|
||||
async def close(self) -> None:
|
||||
if self.is_proc_running:
|
||||
self.proc.terminate()
|
||||
await self.proc.wait()
|
||||
self.is_proc_running = False
|
||||
for task in self._tasks:
|
||||
task.cancel()
|
||||
self.patterns = (b"Control socket:", b"Peer ID:", b"Peer Addrs:")
|
||||
self.bytes_read = bytearray()
|
||||
self.event_ready = trio.Event()
|
||||
|
||||
|
||||
class Daemon:
|
||||
p2pd_proc: P2PDProcess
|
||||
p2pd_proc: BaseInteractiveProcess
|
||||
control: Client
|
||||
peer_info: PeerInfo
|
||||
|
||||
def __init__(
|
||||
self, p2pd_proc: P2PDProcess, control: Client, peer_info: PeerInfo
|
||||
self, p2pd_proc: BaseInteractiveProcess, control: Client, peer_info: PeerInfo
|
||||
) -> None:
|
||||
self.p2pd_proc = p2pd_proc
|
||||
self.control = control
|
||||
|
@ -164,6 +81,7 @@ class Daemon:
|
|||
await self.control.close()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def make_p2pd(
|
||||
daemon_control_port: int,
|
||||
client_callback_port: int,
|
||||
|
@ -172,7 +90,7 @@ async def make_p2pd(
|
|||
is_gossipsub: bool = True,
|
||||
is_pubsub_signing: bool = False,
|
||||
is_pubsub_signing_strict: bool = False,
|
||||
) -> Daemon:
|
||||
) -> AsyncIterator[Daemon]:
|
||||
control_maddr = Multiaddr(f"/ip4/{LOCALHOST_IP}/tcp/{daemon_control_port}")
|
||||
p2pd_proc = P2PDProcess(
|
||||
control_maddr,
|
||||
|
@ -185,21 +103,22 @@ async def make_p2pd(
|
|||
await p2pd_proc.start()
|
||||
client_callback_maddr = Multiaddr(f"/ip4/{LOCALHOST_IP}/tcp/{client_callback_port}")
|
||||
p2pc = Client(control_maddr, client_callback_maddr)
|
||||
await p2pc.listen()
|
||||
peer_id, maddrs = await p2pc.identify()
|
||||
listen_maddr: Multiaddr = None
|
||||
for maddr in maddrs:
|
||||
try:
|
||||
ip = maddr.value_for_protocol(multiaddr.protocols.P_IP4)
|
||||
# NOTE: Check if this `maddr` uses `tcp`.
|
||||
maddr.value_for_protocol(multiaddr.protocols.P_TCP)
|
||||
except multiaddr.exceptions.ProtocolLookupError:
|
||||
continue
|
||||
if ip == LOCALHOST_IP:
|
||||
listen_maddr = maddr
|
||||
break
|
||||
assert listen_maddr is not None, "no loopback maddr is found"
|
||||
peer_info = info_from_p2p_addr(
|
||||
listen_maddr.encapsulate(Multiaddr(f"/p2p/{peer_id.to_string()}"))
|
||||
)
|
||||
return Daemon(p2pd_proc, p2pc, peer_info)
|
||||
|
||||
async with p2pc.listen():
|
||||
peer_id, maddrs = await p2pc.identify()
|
||||
listen_maddr: Multiaddr = None
|
||||
for maddr in maddrs:
|
||||
try:
|
||||
ip = maddr.value_for_protocol(multiaddr.protocols.P_IP4)
|
||||
# NOTE: Check if this `maddr` uses `tcp`.
|
||||
maddr.value_for_protocol(multiaddr.protocols.P_TCP)
|
||||
except multiaddr.exceptions.ProtocolLookupError:
|
||||
continue
|
||||
if ip == LOCALHOST_IP:
|
||||
listen_maddr = maddr
|
||||
break
|
||||
assert listen_maddr is not None, "no loopback maddr is found"
|
||||
peer_info = info_from_p2p_addr(
|
||||
listen_maddr.encapsulate(Multiaddr(f"/p2p/{peer_id.to_string()}"))
|
||||
)
|
||||
yield Daemon(p2pd_proc, p2pc, peer_info)
|
||||
|
|
|
@ -0,0 +1,66 @@
|
|||
from abc import ABC, abstractmethod
|
||||
import subprocess
|
||||
from typing import Iterable, List
|
||||
|
||||
import trio
|
||||
|
||||
TIMEOUT_DURATION = 30
|
||||
|
||||
|
||||
class AbstractInterativeProcess(ABC):
|
||||
@abstractmethod
|
||||
async def start(self) -> None:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def close(self) -> None:
|
||||
...
|
||||
|
||||
|
||||
class BaseInteractiveProcess(AbstractInterativeProcess):
|
||||
proc: trio.Process = None
|
||||
cmd: str
|
||||
args: List[str]
|
||||
bytes_read: bytearray
|
||||
patterns: Iterable[bytes] = None
|
||||
event_ready: trio.Event
|
||||
|
||||
async def wait_until_ready(self) -> None:
|
||||
patterns_occurred = {pat: False for pat in self.patterns}
|
||||
|
||||
async def read_from_daemon_and_check() -> None:
|
||||
async for data in self.proc.stdout:
|
||||
# TODO: It takes O(n^2), which is quite bad.
|
||||
# But it should succeed in a few seconds.
|
||||
self.bytes_read.extend(data)
|
||||
for pat, occurred in patterns_occurred.items():
|
||||
if occurred:
|
||||
continue
|
||||
if pat in self.bytes_read:
|
||||
patterns_occurred[pat] = True
|
||||
if all([value for value in patterns_occurred.values()]):
|
||||
return
|
||||
|
||||
with trio.fail_after(TIMEOUT_DURATION):
|
||||
await read_from_daemon_and_check()
|
||||
self.event_ready.set()
|
||||
# Sleep a little bit to ensure the listener is up after logs are emitted.
|
||||
await trio.sleep(0.01)
|
||||
|
||||
async def start(self) -> None:
|
||||
if self.proc is not None:
|
||||
return
|
||||
# NOTE: Ignore type checks here since mypy complains about bufsize=0
|
||||
self.proc = await trio.open_process( # type: ignore
|
||||
[self.cmd] + self.args,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT, # Redirect stderr to stdout, which makes parsing easier
|
||||
bufsize=0,
|
||||
)
|
||||
await self.wait_until_ready()
|
||||
|
||||
async def close(self) -> None:
|
||||
if self.proc is None:
|
||||
return
|
||||
self.proc.terminate()
|
||||
await self.proc.wait()
|
|
@ -1,7 +1,7 @@
|
|||
import asyncio
|
||||
from typing import Union
|
||||
|
||||
from multiaddr import Multiaddr
|
||||
import trio
|
||||
|
||||
from libp2p.host.host_interface import IHost
|
||||
from libp2p.peer.id import ID
|
||||
|
@ -50,7 +50,7 @@ async def connect(a: TDaemonOrHost, b: TDaemonOrHost) -> None:
|
|||
else: # isinstance(b, IHost)
|
||||
await a.connect(b_peer_info)
|
||||
# Allow additional sleep for both side to establish the connection.
|
||||
await asyncio.sleep(0.1)
|
||||
await trio.sleep(0.1)
|
||||
|
||||
a_peer_info = _get_peer_info(a)
|
||||
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
import asyncio
|
||||
from typing import Dict
|
||||
import uuid
|
||||
from typing import AsyncIterator, Dict, Tuple
|
||||
|
||||
from async_exit_stack import AsyncExitStack
|
||||
from async_generator import asynccontextmanager
|
||||
from async_service import Service, background_trio_service
|
||||
|
||||
from libp2p.host.host_interface import IHost
|
||||
from libp2p.pubsub.floodsub import FloodSub
|
||||
from libp2p.pubsub.pubsub import Pubsub
|
||||
from libp2p.tools.constants import LISTEN_MADDR
|
||||
from libp2p.tools.factories import FloodsubFactory, PubsubFactory
|
||||
from libp2p.tools.factories import PubsubFactory
|
||||
|
||||
CRYPTO_TOPIC = "ethereum"
|
||||
|
||||
|
@ -18,7 +18,7 @@ CRYPTO_TOPIC = "ethereum"
|
|||
# Determine message type by looking at first item before first comma
|
||||
|
||||
|
||||
class DummyAccountNode:
|
||||
class DummyAccountNode(Service):
|
||||
"""
|
||||
Node which has an internal balance mapping, meant to serve as a dummy
|
||||
crypto blockchain.
|
||||
|
@ -27,19 +27,24 @@ class DummyAccountNode:
|
|||
crypto each user in the mappings holds
|
||||
"""
|
||||
|
||||
libp2p_node: IHost
|
||||
pubsub: Pubsub
|
||||
floodsub: FloodSub
|
||||
|
||||
def __init__(self, libp2p_node: IHost, pubsub: Pubsub, floodsub: FloodSub):
|
||||
self.libp2p_node = libp2p_node
|
||||
def __init__(self, pubsub: Pubsub) -> None:
|
||||
self.pubsub = pubsub
|
||||
self.floodsub = floodsub
|
||||
self.balances: Dict[str, int] = {}
|
||||
self.node_id = str(uuid.uuid1())
|
||||
|
||||
@property
|
||||
def host(self) -> IHost:
|
||||
return self.pubsub.host
|
||||
|
||||
async def run(self) -> None:
|
||||
self.subscription = await self.pubsub.subscribe(CRYPTO_TOPIC)
|
||||
self.manager.run_daemon_task(self.handle_incoming_msgs)
|
||||
await self.manager.wait_finished()
|
||||
|
||||
@classmethod
|
||||
async def create(cls) -> "DummyAccountNode":
|
||||
@asynccontextmanager
|
||||
async def create(cls, number: int) -> AsyncIterator[Tuple["DummyAccountNode", ...]]:
|
||||
"""
|
||||
Create a new DummyAccountNode and attach a libp2p node, a floodsub, and
|
||||
a pubsub instance to this new node.
|
||||
|
@ -47,15 +52,17 @@ class DummyAccountNode:
|
|||
We use create as this serves as a factory function and allows us
|
||||
to use async await, unlike the init function
|
||||
"""
|
||||
|
||||
pubsub = PubsubFactory(router=FloodsubFactory())
|
||||
await pubsub.host.get_network().listen(LISTEN_MADDR)
|
||||
return cls(libp2p_node=pubsub.host, pubsub=pubsub, floodsub=pubsub.router)
|
||||
async with PubsubFactory.create_batch_with_floodsub(number) as pubsubs:
|
||||
async with AsyncExitStack() as stack:
|
||||
dummy_acount_nodes = tuple(cls(pubsub) for pubsub in pubsubs)
|
||||
for node in dummy_acount_nodes:
|
||||
await stack.enter_async_context(background_trio_service(node))
|
||||
yield dummy_acount_nodes
|
||||
|
||||
async def handle_incoming_msgs(self) -> None:
|
||||
"""Handle all incoming messages on the CRYPTO_TOPIC from peers."""
|
||||
while True:
|
||||
incoming = await self.q.get()
|
||||
incoming = await self.subscription.get()
|
||||
msg_comps = incoming.data.decode("utf-8").split(",")
|
||||
|
||||
if msg_comps[0] == "send":
|
||||
|
@ -63,13 +70,6 @@ class DummyAccountNode:
|
|||
elif msg_comps[0] == "set":
|
||||
self.handle_set_crypto(msg_comps[1], int(msg_comps[2]))
|
||||
|
||||
async def setup_crypto_networking(self) -> None:
|
||||
"""Subscribe to CRYPTO_TOPIC and perform call to function that handles
|
||||
all incoming messages on said topic."""
|
||||
self.q = await self.pubsub.subscribe(CRYPTO_TOPIC)
|
||||
|
||||
asyncio.ensure_future(self.handle_incoming_msgs())
|
||||
|
||||
async def publish_send_crypto(
|
||||
self, source_user: str, dest_user: str, amount: int
|
||||
) -> None:
|
||||
|
|
|
@ -1,12 +1,10 @@
|
|||
# type: ignore
|
||||
# To add typing to this module, it's better to do it after refactoring test cases into classes
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
import trio
|
||||
|
||||
from libp2p.tools.constants import FLOODSUB_PROTOCOL_ID, LISTEN_MADDR
|
||||
from libp2p.tools.factories import PubsubFactory
|
||||
from libp2p.tools.constants import FLOODSUB_PROTOCOL_ID
|
||||
from libp2p.tools.utils import connect
|
||||
|
||||
SUPPORTED_PROTOCOLS = [FLOODSUB_PROTOCOL_ID]
|
||||
|
@ -15,6 +13,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [
|
|||
{
|
||||
"name": "simple_two_nodes",
|
||||
"supported_protocols": SUPPORTED_PROTOCOLS,
|
||||
"nodes": ["A", "B"],
|
||||
"adj_list": {"A": ["B"]},
|
||||
"topic_map": {"topic1": ["B"]},
|
||||
"messages": [{"topics": ["topic1"], "data": b"foo", "node_id": "A"}],
|
||||
|
@ -22,6 +21,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [
|
|||
{
|
||||
"name": "three_nodes_two_topics",
|
||||
"supported_protocols": SUPPORTED_PROTOCOLS,
|
||||
"nodes": ["A", "B", "C"],
|
||||
"adj_list": {"A": ["B"], "B": ["C"]},
|
||||
"topic_map": {"topic1": ["B", "C"], "topic2": ["B", "C"]},
|
||||
"messages": [
|
||||
|
@ -32,6 +32,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [
|
|||
{
|
||||
"name": "two_nodes_one_topic_single_subscriber_is_sender",
|
||||
"supported_protocols": SUPPORTED_PROTOCOLS,
|
||||
"nodes": ["A", "B"],
|
||||
"adj_list": {"A": ["B"]},
|
||||
"topic_map": {"topic1": ["B"]},
|
||||
"messages": [{"topics": ["topic1"], "data": b"Alex is tall", "node_id": "B"}],
|
||||
|
@ -39,6 +40,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [
|
|||
{
|
||||
"name": "two_nodes_one_topic_two_msgs",
|
||||
"supported_protocols": SUPPORTED_PROTOCOLS,
|
||||
"nodes": ["A", "B"],
|
||||
"adj_list": {"A": ["B"]},
|
||||
"topic_map": {"topic1": ["B"]},
|
||||
"messages": [
|
||||
|
@ -49,6 +51,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [
|
|||
{
|
||||
"name": "seven_nodes_tree_one_topics",
|
||||
"supported_protocols": SUPPORTED_PROTOCOLS,
|
||||
"nodes": ["1", "2", "3", "4", "5", "6", "7"],
|
||||
"adj_list": {"1": ["2", "3"], "2": ["4", "5"], "3": ["6", "7"]},
|
||||
"topic_map": {"astrophysics": ["2", "3", "4", "5", "6", "7"]},
|
||||
"messages": [{"topics": ["astrophysics"], "data": b"e=mc^2", "node_id": "1"}],
|
||||
|
@ -56,6 +59,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [
|
|||
{
|
||||
"name": "seven_nodes_tree_three_topics",
|
||||
"supported_protocols": SUPPORTED_PROTOCOLS,
|
||||
"nodes": ["1", "2", "3", "4", "5", "6", "7"],
|
||||
"adj_list": {"1": ["2", "3"], "2": ["4", "5"], "3": ["6", "7"]},
|
||||
"topic_map": {
|
||||
"astrophysics": ["2", "3", "4", "5", "6", "7"],
|
||||
|
@ -71,6 +75,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [
|
|||
{
|
||||
"name": "seven_nodes_tree_three_topics_diff_origin",
|
||||
"supported_protocols": SUPPORTED_PROTOCOLS,
|
||||
"nodes": ["1", "2", "3", "4", "5", "6", "7"],
|
||||
"adj_list": {"1": ["2", "3"], "2": ["4", "5"], "3": ["6", "7"]},
|
||||
"topic_map": {
|
||||
"astrophysics": ["1", "2", "3", "4", "5", "6", "7"],
|
||||
|
@ -86,6 +91,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [
|
|||
{
|
||||
"name": "three_nodes_clique_two_topic_diff_origin",
|
||||
"supported_protocols": SUPPORTED_PROTOCOLS,
|
||||
"nodes": ["1", "2", "3"],
|
||||
"adj_list": {"1": ["2", "3"], "2": ["3"]},
|
||||
"topic_map": {"astrophysics": ["1", "2", "3"], "school": ["1", "2", "3"]},
|
||||
"messages": [
|
||||
|
@ -97,6 +103,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [
|
|||
{
|
||||
"name": "four_nodes_clique_two_topic_diff_origin_many_msgs",
|
||||
"supported_protocols": SUPPORTED_PROTOCOLS,
|
||||
"nodes": ["1", "2", "3", "4"],
|
||||
"adj_list": {
|
||||
"1": ["2", "3", "4"],
|
||||
"2": ["1", "3", "4"],
|
||||
|
@ -120,6 +127,7 @@ FLOODSUB_PROTOCOL_TEST_CASES = [
|
|||
{
|
||||
"name": "five_nodes_ring_two_topic_diff_origin_many_msgs",
|
||||
"supported_protocols": SUPPORTED_PROTOCOLS,
|
||||
"nodes": ["1", "2", "3", "4", "5"],
|
||||
"adj_list": {"1": ["2"], "2": ["3"], "3": ["4"], "4": ["5"], "5": ["1"]},
|
||||
"topic_map": {
|
||||
"astrophysics": ["1", "2", "3", "4", "5"],
|
||||
|
@ -143,15 +151,7 @@ floodsub_protocol_pytest_params = [
|
|||
]
|
||||
|
||||
|
||||
def _collect_node_ids(adj_list):
|
||||
node_ids = set()
|
||||
for node, neighbors in adj_list.items():
|
||||
node_ids.add(node)
|
||||
node_ids.update(set(neighbors))
|
||||
return node_ids
|
||||
|
||||
|
||||
async def perform_test_from_obj(obj, router_factory) -> None:
|
||||
async def perform_test_from_obj(obj, pubsub_factory) -> None:
|
||||
"""
|
||||
Perform pubsub tests from a test object, which is composed as follows:
|
||||
|
||||
|
@ -185,68 +185,75 @@ async def perform_test_from_obj(obj, router_factory) -> None:
|
|||
|
||||
# Step 1) Create graph
|
||||
adj_list = obj["adj_list"]
|
||||
node_list = obj["nodes"]
|
||||
node_map = {}
|
||||
pubsub_map = {}
|
||||
|
||||
async def add_node(node_id_str: str):
|
||||
pubsub_router = router_factory(protocols=obj["supported_protocols"])
|
||||
pubsub = PubsubFactory(router=pubsub_router)
|
||||
await pubsub.host.get_network().listen(LISTEN_MADDR)
|
||||
node_map[node_id_str] = pubsub.host
|
||||
pubsub_map[node_id_str] = pubsub
|
||||
async with pubsub_factory(
|
||||
number=len(node_list), protocols=obj["supported_protocols"]
|
||||
) as pubsubs:
|
||||
for node_id_str, pubsub in zip(node_list, pubsubs):
|
||||
node_map[node_id_str] = pubsub.host
|
||||
pubsub_map[node_id_str] = pubsub
|
||||
|
||||
all_node_ids = _collect_node_ids(adj_list)
|
||||
# Connect nodes and wait at least for 2 seconds
|
||||
async with trio.open_nursery() as nursery:
|
||||
for start_node_id in adj_list:
|
||||
# For each neighbor of start_node, create if does not yet exist,
|
||||
# then connect start_node to neighbor
|
||||
for neighbor_id in adj_list[start_node_id]:
|
||||
nursery.start_soon(
|
||||
connect, node_map[start_node_id], node_map[neighbor_id]
|
||||
)
|
||||
nursery.start_soon(trio.sleep, 2)
|
||||
|
||||
for node in all_node_ids:
|
||||
await add_node(node)
|
||||
# Step 2) Subscribe to topics
|
||||
queues_map = {}
|
||||
topic_map = obj["topic_map"]
|
||||
|
||||
for node, neighbors in adj_list.items():
|
||||
for neighbor_id in neighbors:
|
||||
await connect(node_map[node], node_map[neighbor_id])
|
||||
|
||||
# NOTE: the test using this routine will fail w/o these sleeps...
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Step 2) Subscribe to topics
|
||||
queues_map = {}
|
||||
topic_map = obj["topic_map"]
|
||||
|
||||
for topic, node_ids in topic_map.items():
|
||||
for node_id in node_ids:
|
||||
queue = await pubsub_map[node_id].subscribe(topic)
|
||||
async def subscribe_node(node_id, topic):
|
||||
if node_id not in queues_map:
|
||||
queues_map[node_id] = {}
|
||||
# Store queue in topic-queue map for node
|
||||
queues_map[node_id][topic] = queue
|
||||
# Avoid repeated works
|
||||
if topic in queues_map[node_id]:
|
||||
# Checkpoint
|
||||
await trio.hazmat.checkpoint()
|
||||
return
|
||||
sub = await pubsub_map[node_id].subscribe(topic)
|
||||
queues_map[node_id][topic] = sub
|
||||
|
||||
# NOTE: the test using this routine will fail w/o these sleeps...
|
||||
await asyncio.sleep(1)
|
||||
async with trio.open_nursery() as nursery:
|
||||
for topic, node_ids in topic_map.items():
|
||||
for node_id in node_ids:
|
||||
nursery.start_soon(subscribe_node, node_id, topic)
|
||||
nursery.start_soon(trio.sleep, 2)
|
||||
|
||||
# Step 3) Publish messages
|
||||
topics_in_msgs_ordered = []
|
||||
messages = obj["messages"]
|
||||
# Step 3) Publish messages
|
||||
topics_in_msgs_ordered = []
|
||||
messages = obj["messages"]
|
||||
|
||||
for msg in messages:
|
||||
topics = msg["topics"]
|
||||
data = msg["data"]
|
||||
node_id = msg["node_id"]
|
||||
for msg in messages:
|
||||
topics = msg["topics"]
|
||||
data = msg["data"]
|
||||
node_id = msg["node_id"]
|
||||
|
||||
# Publish message
|
||||
# TODO: Should be single RPC package with several topics
|
||||
for topic in topics:
|
||||
await pubsub_map[node_id].publish(topic, data)
|
||||
|
||||
# Publish message
|
||||
# TODO: Should be single RPC package with several topics
|
||||
for topic in topics:
|
||||
await pubsub_map[node_id].publish(topic, data)
|
||||
# For each topic in topics, add (topic, node_id, data) tuple to ordered test list
|
||||
topics_in_msgs_ordered.append((topic, node_id, data))
|
||||
for topic in topics:
|
||||
topics_in_msgs_ordered.append((topic, node_id, data))
|
||||
# Allow time for publishing before continuing
|
||||
await trio.sleep(1)
|
||||
|
||||
# Step 4) Check that all messages were received correctly.
|
||||
for topic, origin_node_id, data in topics_in_msgs_ordered:
|
||||
# Look at each node in each topic
|
||||
for node_id in topic_map[topic]:
|
||||
# Get message from subscription queue
|
||||
queue = queues_map[node_id][topic]
|
||||
msg = await queue.get()
|
||||
assert data == msg.data
|
||||
# Check the message origin
|
||||
assert node_map[origin_node_id].get_id().to_bytes() == msg.from_id
|
||||
|
||||
# Success, terminate pending tasks.
|
||||
# Step 4) Check that all messages were received correctly.
|
||||
for topic, origin_node_id, data in topics_in_msgs_ordered:
|
||||
# Look at each node in each topic
|
||||
for node_id in topic_map[topic]:
|
||||
# Get message from subscription queue
|
||||
msg = await queues_map[node_id][topic].get()
|
||||
assert data == msg.data
|
||||
# Check the message origin
|
||||
assert node_map[origin_node_id].get_id().to_bytes() == msg.from_id
|
||||
|
|
|
@ -1,17 +1,10 @@
|
|||
from typing import Dict, Sequence, Tuple, cast
|
||||
from typing import Awaitable, Callable
|
||||
|
||||
import multiaddr
|
||||
|
||||
from libp2p import new_node
|
||||
from libp2p.host.basic_host import BasicHost
|
||||
from libp2p.host.host_interface import IHost
|
||||
from libp2p.host.routed_host import RoutedHost
|
||||
from libp2p.network.stream.exceptions import StreamError
|
||||
from libp2p.network.stream.net_stream_interface import INetStream
|
||||
from libp2p.network.swarm import Swarm
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.peer.peerinfo import PeerInfo, info_from_p2p_addr
|
||||
from libp2p.routing.interfaces import IPeerRouting
|
||||
from libp2p.typing import StreamHandlerFn, TProtocol
|
||||
from libp2p.peer.peerinfo import info_from_p2p_addr
|
||||
|
||||
from .constants import MAX_READ_LEN
|
||||
|
||||
|
@ -36,63 +29,20 @@ async def connect(node1: IHost, node2: IHost) -> None:
|
|||
await node1.connect(info)
|
||||
|
||||
|
||||
async def set_up_nodes_by_transport_opt(
|
||||
transport_opt_list: Sequence[Sequence[str]]
|
||||
) -> Tuple[BasicHost, ...]:
|
||||
nodes_list = []
|
||||
for transport_opt in transport_opt_list:
|
||||
node = await new_node(transport_opt=transport_opt)
|
||||
await node.get_network().listen(multiaddr.Multiaddr(transport_opt[0]))
|
||||
nodes_list.append(node)
|
||||
return tuple(nodes_list)
|
||||
def create_echo_stream_handler(
|
||||
ack_prefix: str
|
||||
) -> Callable[[INetStream], Awaitable[None]]:
|
||||
async def echo_stream_handler(stream: INetStream) -> None:
|
||||
while True:
|
||||
try:
|
||||
read_string = (await stream.read(MAX_READ_LEN)).decode()
|
||||
except StreamError:
|
||||
break
|
||||
|
||||
resp = ack_prefix + read_string
|
||||
try:
|
||||
await stream.write(resp.encode())
|
||||
except StreamError:
|
||||
break
|
||||
|
||||
async def echo_stream_handler(stream: INetStream) -> None:
|
||||
while True:
|
||||
read_string = (await stream.read(MAX_READ_LEN)).decode()
|
||||
|
||||
resp = f"ack:{read_string}"
|
||||
await stream.write(resp.encode())
|
||||
|
||||
|
||||
async def perform_two_host_set_up(
|
||||
handler: StreamHandlerFn = echo_stream_handler
|
||||
) -> Tuple[BasicHost, BasicHost]:
|
||||
transport_opt_list = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]]
|
||||
(node_a, node_b) = await set_up_nodes_by_transport_opt(transport_opt_list)
|
||||
|
||||
node_b.set_stream_handler(TProtocol("/echo/1.0.0"), handler)
|
||||
|
||||
# Associate the peer with local ip address (see default parameters of Libp2p())
|
||||
node_a.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10)
|
||||
return node_a, node_b
|
||||
|
||||
|
||||
class DummyRouter(IPeerRouting):
|
||||
_routing_table: Dict[ID, PeerInfo]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._routing_table = dict()
|
||||
|
||||
async def find_peer(self, peer_id: ID) -> PeerInfo:
|
||||
return self._routing_table.get(peer_id, None)
|
||||
|
||||
|
||||
async def set_up_routed_hosts() -> Tuple[RoutedHost, RoutedHost]:
|
||||
router_a, router_b = DummyRouter(), DummyRouter()
|
||||
transport = "/ip4/127.0.0.1/tcp/0"
|
||||
host_a = await new_node(transport_opt=[transport], disc_opt=router_a)
|
||||
host_b = await new_node(transport_opt=[transport], disc_opt=router_b)
|
||||
|
||||
address = multiaddr.Multiaddr(transport)
|
||||
await host_a.get_network().listen(address)
|
||||
await host_b.get_network().listen(address)
|
||||
|
||||
mock_routing_table = {
|
||||
host_a.get_id(): PeerInfo(host_a.get_id(), host_a.get_addrs()),
|
||||
host_b.get_id(): PeerInfo(host_b.get_id(), host_b.get_addrs()),
|
||||
}
|
||||
|
||||
router_a._routing_table = router_b._routing_table = mock_routing_table
|
||||
|
||||
return cast(RoutedHost, host_a), cast(RoutedHost, host_b)
|
||||
return echo_stream_handler
|
||||
|
|
|
@ -1,12 +1,13 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import List
|
||||
from typing import Tuple
|
||||
|
||||
from multiaddr import Multiaddr
|
||||
import trio
|
||||
|
||||
|
||||
class IListener(ABC):
|
||||
@abstractmethod
|
||||
async def listen(self, maddr: Multiaddr) -> bool:
|
||||
async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> bool:
|
||||
"""
|
||||
put listener in listening mode and wait for incoming connections.
|
||||
|
||||
|
@ -15,7 +16,7 @@ class IListener(ABC):
|
|||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_addrs(self) -> List[Multiaddr]:
|
||||
def get_addrs(self) -> Tuple[Multiaddr, ...]:
|
||||
"""
|
||||
retrieve list of addresses the listener is listening on.
|
||||
|
||||
|
@ -24,5 +25,4 @@ class IListener(ABC):
|
|||
|
||||
@abstractmethod
|
||||
async def close(self) -> None:
|
||||
"""close the listener such that no more connections can be open on this
|
||||
transport instance."""
|
||||
...
|
||||
|
|
|
@ -1,10 +1,11 @@
|
|||
import asyncio
|
||||
from socket import socket
|
||||
import sys
|
||||
from typing import List
|
||||
import logging
|
||||
from typing import Awaitable, Callable, List, Sequence, Tuple
|
||||
|
||||
from multiaddr import Multiaddr
|
||||
import trio
|
||||
from trio_typing import TaskStatus
|
||||
|
||||
from libp2p.io.trio import TrioTCPStream
|
||||
from libp2p.network.connection.raw_connection import RawConnection
|
||||
from libp2p.network.connection.raw_connection_interface import IRawConnection
|
||||
from libp2p.transport.exceptions import OpenConnectionError
|
||||
|
@ -12,53 +13,61 @@ from libp2p.transport.listener_interface import IListener
|
|||
from libp2p.transport.transport_interface import ITransport
|
||||
from libp2p.transport.typing import THandler
|
||||
|
||||
logger = logging.getLogger("libp2p.transport.tcp")
|
||||
|
||||
|
||||
class TCPListener(IListener):
|
||||
multiaddrs: List[Multiaddr]
|
||||
server = None
|
||||
listeners: List[trio.SocketListener]
|
||||
|
||||
def __init__(self, handler_function: THandler) -> None:
|
||||
self.multiaddrs = []
|
||||
self.server = None
|
||||
self.listeners = []
|
||||
self.handler = handler_function
|
||||
|
||||
async def listen(self, maddr: Multiaddr) -> bool:
|
||||
# TODO: Get rid of `nursery`?
|
||||
async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> None:
|
||||
"""
|
||||
put listener in listening mode and wait for incoming connections.
|
||||
|
||||
:param maddr: maddr of peer
|
||||
:return: return True if successful
|
||||
"""
|
||||
self.server = await asyncio.start_server(
|
||||
self.handler,
|
||||
|
||||
async def serve_tcp(
|
||||
handler: Callable[[trio.SocketStream], Awaitable[None]],
|
||||
port: int,
|
||||
host: str,
|
||||
task_status: TaskStatus[Sequence[trio.SocketListener]] = None,
|
||||
) -> None:
|
||||
"""Just a proxy function to add logging here."""
|
||||
logger.debug("serve_tcp %s %s", host, port)
|
||||
await trio.serve_tcp(handler, port, host=host, task_status=task_status)
|
||||
|
||||
async def handler(stream: trio.SocketStream) -> None:
|
||||
tcp_stream = TrioTCPStream(stream)
|
||||
await self.handler(tcp_stream)
|
||||
|
||||
listeners = await nursery.start(
|
||||
serve_tcp,
|
||||
handler,
|
||||
int(maddr.value_for_protocol("tcp")),
|
||||
maddr.value_for_protocol("ip4"),
|
||||
maddr.value_for_protocol("tcp"),
|
||||
)
|
||||
socket = self.server.sockets[0]
|
||||
self.multiaddrs.append(_multiaddr_from_socket(socket))
|
||||
self.listeners.extend(listeners)
|
||||
|
||||
return True
|
||||
|
||||
def get_addrs(self) -> List[Multiaddr]:
|
||||
def get_addrs(self) -> Tuple[Multiaddr, ...]:
|
||||
"""
|
||||
retrieve list of addresses the listener is listening on.
|
||||
|
||||
:return: return list of addrs
|
||||
"""
|
||||
# TODO check if server is listening
|
||||
return self.multiaddrs
|
||||
return tuple(
|
||||
_multiaddr_from_socket(listener.socket) for listener in self.listeners
|
||||
)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""close the listener such that no more connections can be open on this
|
||||
transport instance."""
|
||||
if self.server is None:
|
||||
return
|
||||
self.server.close()
|
||||
server = self.server
|
||||
self.server = None
|
||||
if sys.version_info < (3, 7):
|
||||
return
|
||||
await server.wait_closed()
|
||||
async with trio.open_nursery() as nursery:
|
||||
for listener in self.listeners:
|
||||
nursery.start_soon(listener.aclose)
|
||||
|
||||
|
||||
class TCP(ITransport):
|
||||
|
@ -74,11 +83,12 @@ class TCP(ITransport):
|
|||
self.port = int(maddr.value_for_protocol("tcp"))
|
||||
|
||||
try:
|
||||
reader, writer = await asyncio.open_connection(self.host, self.port)
|
||||
except (ConnectionAbortedError, ConnectionRefusedError) as error:
|
||||
raise OpenConnectionError(error)
|
||||
stream = await trio.open_tcp_stream(self.host, self.port)
|
||||
except OSError as error:
|
||||
raise OpenConnectionError from error
|
||||
read_write_closer = TrioTCPStream(stream)
|
||||
|
||||
return RawConnection(reader, writer, True)
|
||||
return RawConnection(read_write_closer, True)
|
||||
|
||||
def create_listener(self, handler_function: THandler) -> TCPListener:
|
||||
"""
|
||||
|
@ -91,6 +101,6 @@ class TCP(ITransport):
|
|||
return TCPListener(handler_function)
|
||||
|
||||
|
||||
def _multiaddr_from_socket(socket: socket) -> Multiaddr:
|
||||
addr, port = socket.getsockname()[:2]
|
||||
return Multiaddr(f"/ip4/{addr}/tcp/{port}")
|
||||
def _multiaddr_from_socket(socket: trio.socket.SocketType) -> Multiaddr:
|
||||
ip, port = socket.getsockname() # type: ignore
|
||||
return Multiaddr(f"/ip4/{ip}/tcp/{port}")
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
from asyncio import StreamReader, StreamWriter
|
||||
from typing import Awaitable, Callable, Mapping, Type
|
||||
|
||||
from libp2p.io.abc import ReadWriteCloser
|
||||
from libp2p.security.secure_transport_interface import ISecureTransport
|
||||
from libp2p.stream_muxer.abc import IMuxedConn
|
||||
from libp2p.typing import TProtocol
|
||||
|
||||
THandler = Callable[[StreamReader, StreamWriter], Awaitable[None]]
|
||||
THandler = Callable[[ReadWriteCloser], Awaitable[None]]
|
||||
TSecurityOptions = Mapping[TProtocol, ISecureTransport]
|
||||
TMuxerClass = Type[IMuxedConn]
|
||||
TMuxerOptions = Mapping[TProtocol, TMuxerClass]
|
||||
|
|
8
setup.py
8
setup.py
|
@ -7,8 +7,8 @@ from setuptools import find_packages, setup
|
|||
extras_require = {
|
||||
"test": [
|
||||
"pytest>=4.6.3,<5.0.0",
|
||||
"pytest-xdist>=1.30.0,<2",
|
||||
"pytest-asyncio>=0.10.0,<1.0.0",
|
||||
"pytest-xdist>=1.30.0",
|
||||
"pytest-trio>=0.5.2",
|
||||
"factory-boy>=2.12.0,<3.0.0",
|
||||
],
|
||||
"lint": [
|
||||
|
@ -74,6 +74,10 @@ install_requires = [
|
|||
"pynacl==1.3.0",
|
||||
"dataclasses>=0.7, <1;python_version<'3.7'",
|
||||
"async_generator==1.10",
|
||||
"trio>=0.13.0",
|
||||
"async-service>=0.1.0a6",
|
||||
"async-exit-stack==1.0.1",
|
||||
"trio-typing>=0.3.0,<0.4.0",
|
||||
]
|
||||
|
||||
|
||||
|
|
|
@ -1,8 +1,5 @@
|
|||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from libp2p.tools.constants import LISTEN_MADDR
|
||||
from libp2p.tools.factories import HostFactory
|
||||
|
||||
|
||||
|
@ -17,17 +14,6 @@ def num_hosts():
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
async def hosts(num_hosts, is_host_secure):
|
||||
_hosts = HostFactory.create_batch(num_hosts, is_secure=is_host_secure)
|
||||
await asyncio.gather(
|
||||
*[_host.get_network().listen(LISTEN_MADDR) for _host in _hosts]
|
||||
)
|
||||
try:
|
||||
async def hosts(num_hosts, is_host_secure, nursery):
|
||||
async with HostFactory.create_batch_and_listen(is_host_secure, num_hosts) as _hosts:
|
||||
yield _hosts
|
||||
finally:
|
||||
# TODO: It's possible that `close` raises exceptions currently,
|
||||
# due to the connection reset things. Though we don't care much about that when
|
||||
# cleaning up the tasks, it is probably better to handle the exceptions properly.
|
||||
await asyncio.gather(
|
||||
*[_host.close() for _host in _hosts], return_exceptions=True
|
||||
)
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
import asyncio
|
||||
|
||||
import pytest
|
||||
import trio
|
||||
|
||||
from libp2p.host.exceptions import StreamFailure
|
||||
from libp2p.peer.peerinfo import info_from_p2p_addr
|
||||
from libp2p.tools.utils import set_up_nodes_by_transport_opt
|
||||
from libp2p.tools.factories import HostFactory
|
||||
from libp2p.tools.utils import MAX_READ_LEN
|
||||
|
||||
PROTOCOL_ID = "/chat/1.0.0"
|
||||
|
||||
|
@ -25,7 +25,7 @@ async def hello_world(host_a, host_b):
|
|||
# Multiaddress of the destination peer is fetched from the peerstore using 'peerId'.
|
||||
stream = await host_b.new_stream(host_a.get_id(), [PROTOCOL_ID])
|
||||
await stream.write(hello_world_from_host_b)
|
||||
read = await stream.read()
|
||||
read = await stream.read(MAX_READ_LEN)
|
||||
assert read == hello_world_from_host_a
|
||||
await stream.close()
|
||||
|
||||
|
@ -47,7 +47,7 @@ async def connect_write(host_a, host_b):
|
|||
await stream.write(message.encode())
|
||||
|
||||
# Reader needs time due to async reads
|
||||
await asyncio.sleep(2)
|
||||
await trio.sleep(2)
|
||||
|
||||
await stream.close()
|
||||
assert received == messages
|
||||
|
@ -88,16 +88,14 @@ async def no_common_protocol(host_a, host_b):
|
|||
await host_b.new_stream(host_a.get_id(), ["/fakeproto/0.0.1"])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"test", [(hello_world), (connect_write), (connect_read), (no_common_protocol)]
|
||||
)
|
||||
async def test_chat(test):
|
||||
transport_opt_list = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]]
|
||||
(host_a, host_b) = await set_up_nodes_by_transport_opt(transport_opt_list)
|
||||
@pytest.mark.trio
|
||||
async def test_chat(test, is_host_secure):
|
||||
async with HostFactory.create_batch_and_listen(is_host_secure, 2) as hosts:
|
||||
addr = hosts[0].get_addrs()[0]
|
||||
info = info_from_p2p_addr(addr)
|
||||
await hosts[1].connect(info)
|
||||
|
||||
addr = host_a.get_addrs()[0]
|
||||
info = info_from_p2p_addr(addr)
|
||||
await host_b.connect(info)
|
||||
|
||||
await test(host_a, host_b)
|
||||
await test(hosts[0], hosts[1])
|
|
@ -1,4 +1,4 @@
|
|||
from libp2p import initialize_default_swarm
|
||||
from libp2p import new_swarm
|
||||
from libp2p.crypto.rsa import create_new_key_pair
|
||||
from libp2p.host.basic_host import BasicHost
|
||||
from libp2p.host.defaults import get_default_protocols
|
||||
|
@ -6,7 +6,7 @@ from libp2p.host.defaults import get_default_protocols
|
|||
|
||||
def test_default_protocols():
|
||||
key_pair = create_new_key_pair()
|
||||
swarm = initialize_default_swarm(key_pair)
|
||||
swarm = new_swarm(key_pair)
|
||||
host = BasicHost(swarm)
|
||||
|
||||
mux = host.get_mux()
|
||||
|
|
|
@ -1,18 +1,19 @@
|
|||
import asyncio
|
||||
import secrets
|
||||
|
||||
import pytest
|
||||
import trio
|
||||
|
||||
from libp2p.host.ping import ID, PING_LENGTH
|
||||
from libp2p.tools.factories import pair_of_connected_hosts
|
||||
from libp2p.tools.factories import host_pair_factory
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ping_once():
|
||||
async with pair_of_connected_hosts() as (host_a, host_b):
|
||||
@pytest.mark.trio
|
||||
async def test_ping_once(is_host_secure):
|
||||
async with host_pair_factory(is_host_secure) as (host_a, host_b):
|
||||
stream = await host_b.new_stream(host_a.get_id(), (ID,))
|
||||
some_ping = secrets.token_bytes(PING_LENGTH)
|
||||
await stream.write(some_ping)
|
||||
await trio.sleep(0.01)
|
||||
some_pong = await stream.read(PING_LENGTH)
|
||||
assert some_ping == some_pong
|
||||
await stream.close()
|
||||
|
@ -21,9 +22,9 @@ async def test_ping_once():
|
|||
SOME_PING_COUNT = 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ping_several():
|
||||
async with pair_of_connected_hosts() as (host_a, host_b):
|
||||
@pytest.mark.trio
|
||||
async def test_ping_several(is_host_secure):
|
||||
async with host_pair_factory(is_host_secure) as (host_a, host_b):
|
||||
stream = await host_b.new_stream(host_a.get_id(), (ID,))
|
||||
for _ in range(SOME_PING_COUNT):
|
||||
some_ping = secrets.token_bytes(PING_LENGTH)
|
||||
|
@ -33,5 +34,5 @@ async def test_ping_several():
|
|||
# NOTE: simulate some time to sleep to mirror a real
|
||||
# world usage where a peer sends pings on some periodic interval
|
||||
# NOTE: this interval can be `0` for this test.
|
||||
await asyncio.sleep(0)
|
||||
await trio.sleep(0)
|
||||
await stream.close()
|
||||
|
|
|
@ -1,33 +1,26 @@
|
|||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from libp2p.host.exceptions import ConnectionFailure
|
||||
from libp2p.peer.peerinfo import PeerInfo
|
||||
from libp2p.tools.utils import set_up_nodes_by_transport_opt, set_up_routed_hosts
|
||||
from libp2p.tools.factories import HostFactory, RoutedHostFactory
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.trio
|
||||
async def test_host_routing_success():
|
||||
host_a, host_b = await set_up_routed_hosts()
|
||||
# forces to use routing as no addrs are provided
|
||||
await host_a.connect(PeerInfo(host_b.get_id(), []))
|
||||
await host_b.connect(PeerInfo(host_a.get_id(), []))
|
||||
|
||||
# Clean up
|
||||
await asyncio.gather(*[host_a.close(), host_b.close()])
|
||||
async with RoutedHostFactory.create_batch_and_listen(False, 2) as hosts:
|
||||
# forces to use routing as no addrs are provided
|
||||
await hosts[0].connect(PeerInfo(hosts[1].get_id(), []))
|
||||
await hosts[1].connect(PeerInfo(hosts[0].get_id(), []))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.trio
|
||||
async def test_host_routing_fail():
|
||||
host_a, host_b = await set_up_routed_hosts()
|
||||
basic_host_c = (await set_up_nodes_by_transport_opt([["/ip4/127.0.0.1/tcp/0"]]))[0]
|
||||
|
||||
# routing fails because host_c does not use routing
|
||||
with pytest.raises(ConnectionFailure):
|
||||
await host_a.connect(PeerInfo(basic_host_c.get_id(), []))
|
||||
with pytest.raises(ConnectionFailure):
|
||||
await host_b.connect(PeerInfo(basic_host_c.get_id(), []))
|
||||
|
||||
# Clean up
|
||||
await asyncio.gather(*[host_a.close(), host_b.close(), basic_host_c.close()])
|
||||
is_secure = False
|
||||
async with RoutedHostFactory.create_batch_and_listen(
|
||||
is_secure, 2
|
||||
) as routed_hosts, HostFactory.create_batch_and_listen(is_secure, 1) as basic_hosts:
|
||||
# routing fails because host_c does not use routing
|
||||
with pytest.raises(ConnectionFailure):
|
||||
await routed_hosts[0].connect(PeerInfo(basic_hosts[0].get_id(), []))
|
||||
with pytest.raises(ConnectionFailure):
|
||||
await routed_hosts[1].connect(PeerInfo(basic_hosts[0].get_id(), []))
|
||||
|
|
|
@ -2,12 +2,12 @@ import pytest
|
|||
|
||||
from libp2p.identity.identify.pb.identify_pb2 import Identify
|
||||
from libp2p.identity.identify.protocol import ID, _mk_identify_protobuf
|
||||
from libp2p.tools.factories import pair_of_connected_hosts
|
||||
from libp2p.tools.factories import host_pair_factory
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_identify_protocol():
|
||||
async with pair_of_connected_hosts() as (host_a, host_b):
|
||||
@pytest.mark.trio
|
||||
async def test_identify_protocol(is_host_secure):
|
||||
async with host_pair_factory(is_host_secure) as (host_a, host_b):
|
||||
stream = await host_b.new_stream(host_a.get_id(), (ID,))
|
||||
response = await stream.read()
|
||||
await stream.close()
|
||||
|
|
|
@ -1,350 +1,285 @@
|
|||
import multiaddr
|
||||
import pytest
|
||||
|
||||
from libp2p.peer.peerinfo import info_from_p2p_addr
|
||||
from libp2p.network.stream.exceptions import StreamError
|
||||
from libp2p.tools.constants import MAX_READ_LEN
|
||||
from libp2p.tools.utils import set_up_nodes_by_transport_opt
|
||||
from libp2p.tools.factories import HostFactory
|
||||
from libp2p.tools.utils import connect, create_echo_stream_handler
|
||||
from libp2p.typing import TProtocol
|
||||
|
||||
PROTOCOL_ID_0 = TProtocol("/echo/0")
|
||||
PROTOCOL_ID_1 = TProtocol("/echo/1")
|
||||
PROTOCOL_ID_2 = TProtocol("/echo/2")
|
||||
PROTOCOL_ID_3 = TProtocol("/echo/3")
|
||||
|
||||
ACK_STR_0 = "ack_0:"
|
||||
ACK_STR_1 = "ack_1:"
|
||||
ACK_STR_2 = "ack_2:"
|
||||
ACK_STR_3 = "ack_3:"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_simple_messages():
|
||||
transport_opt_list = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]]
|
||||
(node_a, node_b) = await set_up_nodes_by_transport_opt(transport_opt_list)
|
||||
|
||||
async def stream_handler(stream):
|
||||
while True:
|
||||
read_string = (await stream.read(MAX_READ_LEN)).decode()
|
||||
|
||||
response = "ack:" + read_string
|
||||
await stream.write(response.encode())
|
||||
|
||||
node_b.set_stream_handler("/echo/1.0.0", stream_handler)
|
||||
|
||||
# Associate the peer with local ip address (see default parameters of Libp2p())
|
||||
node_a.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10)
|
||||
|
||||
stream = await node_a.new_stream(node_b.get_id(), ["/echo/1.0.0"])
|
||||
|
||||
messages = ["hello" + str(x) for x in range(10)]
|
||||
for message in messages:
|
||||
await stream.write(message.encode())
|
||||
|
||||
response = (await stream.read(MAX_READ_LEN)).decode()
|
||||
|
||||
assert response == ("ack:" + message)
|
||||
|
||||
# Success, terminate pending tasks.
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_double_response():
|
||||
transport_opt_list = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]]
|
||||
(node_a, node_b) = await set_up_nodes_by_transport_opt(transport_opt_list)
|
||||
|
||||
async def stream_handler(stream):
|
||||
while True:
|
||||
read_string = (await stream.read(MAX_READ_LEN)).decode()
|
||||
|
||||
response = "ack1:" + read_string
|
||||
await stream.write(response.encode())
|
||||
|
||||
response = "ack2:" + read_string
|
||||
await stream.write(response.encode())
|
||||
|
||||
node_b.set_stream_handler("/echo/1.0.0", stream_handler)
|
||||
|
||||
# Associate the peer with local ip address (see default parameters of Libp2p())
|
||||
node_a.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10)
|
||||
stream = await node_a.new_stream(node_b.get_id(), ["/echo/1.0.0"])
|
||||
|
||||
messages = ["hello" + str(x) for x in range(10)]
|
||||
for message in messages:
|
||||
await stream.write(message.encode())
|
||||
|
||||
response1 = (await stream.read(MAX_READ_LEN)).decode()
|
||||
assert response1 == ("ack1:" + message)
|
||||
|
||||
response2 = (await stream.read(MAX_READ_LEN)).decode()
|
||||
assert response2 == ("ack2:" + message)
|
||||
|
||||
# Success, terminate pending tasks.
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_streams():
|
||||
# Node A should be able to open a stream with node B and then vice versa.
|
||||
# Stream IDs should be generated uniquely so that the stream state is not overwritten
|
||||
transport_opt_list = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]]
|
||||
(node_a, node_b) = await set_up_nodes_by_transport_opt(transport_opt_list)
|
||||
|
||||
async def stream_handler_a(stream):
|
||||
while True:
|
||||
read_string = (await stream.read(MAX_READ_LEN)).decode()
|
||||
|
||||
response = "ack_a:" + read_string
|
||||
await stream.write(response.encode())
|
||||
|
||||
async def stream_handler_b(stream):
|
||||
while True:
|
||||
read_string = (await stream.read(MAX_READ_LEN)).decode()
|
||||
|
||||
response = "ack_b:" + read_string
|
||||
await stream.write(response.encode())
|
||||
|
||||
node_a.set_stream_handler("/echo_a/1.0.0", stream_handler_a)
|
||||
node_b.set_stream_handler("/echo_b/1.0.0", stream_handler_b)
|
||||
|
||||
# Associate the peer with local ip address (see default parameters of Libp2p())
|
||||
node_a.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10)
|
||||
node_b.get_peerstore().add_addrs(node_a.get_id(), node_a.get_addrs(), 10)
|
||||
|
||||
stream_a = await node_a.new_stream(node_b.get_id(), ["/echo_b/1.0.0"])
|
||||
stream_b = await node_b.new_stream(node_a.get_id(), ["/echo_a/1.0.0"])
|
||||
|
||||
# A writes to /echo_b via stream_a, and B writes to /echo_a via stream_b
|
||||
messages = ["hello" + str(x) for x in range(10)]
|
||||
for message in messages:
|
||||
a_message = message + "_a"
|
||||
b_message = message + "_b"
|
||||
|
||||
await stream_a.write(a_message.encode())
|
||||
await stream_b.write(b_message.encode())
|
||||
|
||||
response_a = (await stream_a.read(MAX_READ_LEN)).decode()
|
||||
response_b = (await stream_b.read(MAX_READ_LEN)).decode()
|
||||
|
||||
assert response_a == ("ack_b:" + a_message) and response_b == (
|
||||
"ack_a:" + b_message
|
||||
@pytest.mark.trio
|
||||
async def test_simple_messages(is_host_secure):
|
||||
async with HostFactory.create_batch_and_listen(is_host_secure, 2) as hosts:
|
||||
hosts[1].set_stream_handler(
|
||||
PROTOCOL_ID_0, create_echo_stream_handler(ACK_STR_0)
|
||||
)
|
||||
|
||||
# Success, terminate pending tasks.
|
||||
# Associate the peer with local ip address (see default parameters of Libp2p())
|
||||
hosts[0].get_peerstore().add_addrs(hosts[1].get_id(), hosts[1].get_addrs(), 10)
|
||||
|
||||
stream = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_0])
|
||||
|
||||
messages = ["hello" + str(x) for x in range(10)]
|
||||
for message in messages:
|
||||
await stream.write(message.encode())
|
||||
response = (await stream.read(MAX_READ_LEN)).decode()
|
||||
assert response == (ACK_STR_0 + message)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_streams_same_initiator_different_protocols():
|
||||
transport_opt_list = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]]
|
||||
(node_a, node_b) = await set_up_nodes_by_transport_opt(transport_opt_list)
|
||||
@pytest.mark.trio
|
||||
async def test_double_response(is_host_secure):
|
||||
async with HostFactory.create_batch_and_listen(is_host_secure, 2) as hosts:
|
||||
|
||||
async def stream_handler_a1(stream):
|
||||
while True:
|
||||
read_string = (await stream.read(MAX_READ_LEN)).decode()
|
||||
async def double_response_stream_handler(stream):
|
||||
while True:
|
||||
try:
|
||||
read_string = (await stream.read(MAX_READ_LEN)).decode()
|
||||
except StreamError:
|
||||
break
|
||||
|
||||
response = "ack_a1:" + read_string
|
||||
await stream.write(response.encode())
|
||||
response = ACK_STR_0 + read_string
|
||||
try:
|
||||
await stream.write(response.encode())
|
||||
except StreamError:
|
||||
break
|
||||
|
||||
async def stream_handler_a2(stream):
|
||||
while True:
|
||||
read_string = (await stream.read(MAX_READ_LEN)).decode()
|
||||
response = ACK_STR_1 + read_string
|
||||
try:
|
||||
await stream.write(response.encode())
|
||||
except StreamError:
|
||||
break
|
||||
|
||||
response = "ack_a2:" + read_string
|
||||
await stream.write(response.encode())
|
||||
hosts[1].set_stream_handler(PROTOCOL_ID_0, double_response_stream_handler)
|
||||
|
||||
async def stream_handler_a3(stream):
|
||||
while True:
|
||||
read_string = (await stream.read(MAX_READ_LEN)).decode()
|
||||
# Associate the peer with local ip address (see default parameters of Libp2p())
|
||||
hosts[0].get_peerstore().add_addrs(hosts[1].get_id(), hosts[1].get_addrs(), 10)
|
||||
stream = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_0])
|
||||
|
||||
response = "ack_a3:" + read_string
|
||||
await stream.write(response.encode())
|
||||
|
||||
node_b.set_stream_handler("/echo_a1/1.0.0", stream_handler_a1)
|
||||
node_b.set_stream_handler("/echo_a2/1.0.0", stream_handler_a2)
|
||||
node_b.set_stream_handler("/echo_a3/1.0.0", stream_handler_a3)
|
||||
|
||||
# Associate the peer with local ip address (see default parameters of Libp2p())
|
||||
node_a.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10)
|
||||
node_b.get_peerstore().add_addrs(node_a.get_id(), node_a.get_addrs(), 10)
|
||||
|
||||
# Open streams to node_b over echo_a1 echo_a2 echo_a3 protocols
|
||||
stream_a1 = await node_a.new_stream(node_b.get_id(), ["/echo_a1/1.0.0"])
|
||||
stream_a2 = await node_a.new_stream(node_b.get_id(), ["/echo_a2/1.0.0"])
|
||||
stream_a3 = await node_a.new_stream(node_b.get_id(), ["/echo_a3/1.0.0"])
|
||||
|
||||
messages = ["hello" + str(x) for x in range(10)]
|
||||
for message in messages:
|
||||
a1_message = message + "_a1"
|
||||
a2_message = message + "_a2"
|
||||
a3_message = message + "_a3"
|
||||
|
||||
await stream_a1.write(a1_message.encode())
|
||||
await stream_a2.write(a2_message.encode())
|
||||
await stream_a3.write(a3_message.encode())
|
||||
|
||||
response_a1 = (await stream_a1.read(MAX_READ_LEN)).decode()
|
||||
response_a2 = (await stream_a2.read(MAX_READ_LEN)).decode()
|
||||
response_a3 = (await stream_a3.read(MAX_READ_LEN)).decode()
|
||||
|
||||
assert (
|
||||
response_a1 == ("ack_a1:" + a1_message)
|
||||
and response_a2 == ("ack_a2:" + a2_message)
|
||||
and response_a3 == ("ack_a3:" + a3_message)
|
||||
)
|
||||
|
||||
# Success, terminate pending tasks.
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_streams_two_initiators():
|
||||
transport_opt_list = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]]
|
||||
(node_a, node_b) = await set_up_nodes_by_transport_opt(transport_opt_list)
|
||||
|
||||
async def stream_handler_a1(stream):
|
||||
while True:
|
||||
read_string = (await stream.read(MAX_READ_LEN)).decode()
|
||||
|
||||
response = "ack_a1:" + read_string
|
||||
await stream.write(response.encode())
|
||||
|
||||
async def stream_handler_a2(stream):
|
||||
while True:
|
||||
read_string = (await stream.read(MAX_READ_LEN)).decode()
|
||||
|
||||
response = "ack_a2:" + read_string
|
||||
await stream.write(response.encode())
|
||||
|
||||
async def stream_handler_b1(stream):
|
||||
while True:
|
||||
read_string = (await stream.read(MAX_READ_LEN)).decode()
|
||||
|
||||
response = "ack_b1:" + read_string
|
||||
await stream.write(response.encode())
|
||||
|
||||
async def stream_handler_b2(stream):
|
||||
while True:
|
||||
read_string = (await stream.read(MAX_READ_LEN)).decode()
|
||||
|
||||
response = "ack_b2:" + read_string
|
||||
await stream.write(response.encode())
|
||||
|
||||
node_a.set_stream_handler("/echo_b1/1.0.0", stream_handler_b1)
|
||||
node_a.set_stream_handler("/echo_b2/1.0.0", stream_handler_b2)
|
||||
|
||||
node_b.set_stream_handler("/echo_a1/1.0.0", stream_handler_a1)
|
||||
node_b.set_stream_handler("/echo_a2/1.0.0", stream_handler_a2)
|
||||
|
||||
# Associate the peer with local ip address (see default parameters of Libp2p())
|
||||
node_a.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10)
|
||||
node_b.get_peerstore().add_addrs(node_a.get_id(), node_a.get_addrs(), 10)
|
||||
|
||||
stream_a1 = await node_a.new_stream(node_b.get_id(), ["/echo_a1/1.0.0"])
|
||||
stream_a2 = await node_a.new_stream(node_b.get_id(), ["/echo_a2/1.0.0"])
|
||||
|
||||
stream_b1 = await node_b.new_stream(node_a.get_id(), ["/echo_b1/1.0.0"])
|
||||
stream_b2 = await node_b.new_stream(node_a.get_id(), ["/echo_b2/1.0.0"])
|
||||
|
||||
# A writes to /echo_b via stream_a, and B writes to /echo_a via stream_b
|
||||
messages = ["hello" + str(x) for x in range(10)]
|
||||
for message in messages:
|
||||
a1_message = message + "_a1"
|
||||
a2_message = message + "_a2"
|
||||
|
||||
b1_message = message + "_b1"
|
||||
b2_message = message + "_b2"
|
||||
|
||||
await stream_a1.write(a1_message.encode())
|
||||
await stream_a2.write(a2_message.encode())
|
||||
|
||||
await stream_b1.write(b1_message.encode())
|
||||
await stream_b2.write(b2_message.encode())
|
||||
|
||||
response_a1 = (await stream_a1.read(MAX_READ_LEN)).decode()
|
||||
response_a2 = (await stream_a2.read(MAX_READ_LEN)).decode()
|
||||
|
||||
response_b1 = (await stream_b1.read(MAX_READ_LEN)).decode()
|
||||
response_b2 = (await stream_b2.read(MAX_READ_LEN)).decode()
|
||||
|
||||
assert (
|
||||
response_a1 == ("ack_a1:" + a1_message)
|
||||
and response_a2 == ("ack_a2:" + a2_message)
|
||||
and response_b1 == ("ack_b1:" + b1_message)
|
||||
and response_b2 == ("ack_b2:" + b2_message)
|
||||
)
|
||||
|
||||
# Success, terminate pending tasks.
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_triangle_nodes_connection():
|
||||
transport_opt_list = [
|
||||
["/ip4/127.0.0.1/tcp/0"],
|
||||
["/ip4/127.0.0.1/tcp/0"],
|
||||
["/ip4/127.0.0.1/tcp/0"],
|
||||
]
|
||||
(node_a, node_b, node_c) = await set_up_nodes_by_transport_opt(transport_opt_list)
|
||||
|
||||
async def stream_handler(stream):
|
||||
while True:
|
||||
read_string = (await stream.read(MAX_READ_LEN)).decode()
|
||||
|
||||
response = "ack:" + read_string
|
||||
await stream.write(response.encode())
|
||||
|
||||
node_a.set_stream_handler("/echo/1.0.0", stream_handler)
|
||||
node_b.set_stream_handler("/echo/1.0.0", stream_handler)
|
||||
node_c.set_stream_handler("/echo/1.0.0", stream_handler)
|
||||
|
||||
# Associate the peer with local ip address (see default parameters of Libp2p())
|
||||
# Associate all permutations
|
||||
node_a.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10)
|
||||
node_a.get_peerstore().add_addrs(node_c.get_id(), node_c.get_addrs(), 10)
|
||||
|
||||
node_b.get_peerstore().add_addrs(node_a.get_id(), node_a.get_addrs(), 10)
|
||||
node_b.get_peerstore().add_addrs(node_c.get_id(), node_c.get_addrs(), 10)
|
||||
|
||||
node_c.get_peerstore().add_addrs(node_a.get_id(), node_a.get_addrs(), 10)
|
||||
node_c.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10)
|
||||
|
||||
stream_a_to_b = await node_a.new_stream(node_b.get_id(), ["/echo/1.0.0"])
|
||||
stream_a_to_c = await node_a.new_stream(node_c.get_id(), ["/echo/1.0.0"])
|
||||
|
||||
stream_b_to_a = await node_b.new_stream(node_a.get_id(), ["/echo/1.0.0"])
|
||||
stream_b_to_c = await node_b.new_stream(node_c.get_id(), ["/echo/1.0.0"])
|
||||
|
||||
stream_c_to_a = await node_c.new_stream(node_a.get_id(), ["/echo/1.0.0"])
|
||||
stream_c_to_b = await node_c.new_stream(node_b.get_id(), ["/echo/1.0.0"])
|
||||
|
||||
messages = ["hello" + str(x) for x in range(5)]
|
||||
streams = [
|
||||
stream_a_to_b,
|
||||
stream_a_to_c,
|
||||
stream_b_to_a,
|
||||
stream_b_to_c,
|
||||
stream_c_to_a,
|
||||
stream_c_to_b,
|
||||
]
|
||||
|
||||
for message in messages:
|
||||
for stream in streams:
|
||||
messages = ["hello" + str(x) for x in range(10)]
|
||||
for message in messages:
|
||||
await stream.write(message.encode())
|
||||
|
||||
response = (await stream.read(MAX_READ_LEN)).decode()
|
||||
response1 = (await stream.read(MAX_READ_LEN)).decode()
|
||||
assert response1 == (ACK_STR_0 + message)
|
||||
|
||||
assert response == ("ack:" + message)
|
||||
|
||||
# Success, terminate pending tasks.
|
||||
response2 = (await stream.read(MAX_READ_LEN)).decode()
|
||||
assert response2 == (ACK_STR_1 + message)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_host_connect():
|
||||
transport_opt_list = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]]
|
||||
(node_a, node_b) = await set_up_nodes_by_transport_opt(transport_opt_list)
|
||||
@pytest.mark.trio
|
||||
async def test_multiple_streams(is_host_secure):
|
||||
# hosts[0] should be able to open a stream with hosts[1] and then vice versa.
|
||||
# Stream IDs should be generated uniquely so that the stream state is not overwritten
|
||||
|
||||
# Only our peer ID is stored in peer store
|
||||
assert len(node_a.get_peerstore().peer_ids()) == 1
|
||||
async with HostFactory.create_batch_and_listen(is_host_secure, 2) as hosts:
|
||||
hosts[0].set_stream_handler(
|
||||
PROTOCOL_ID_0, create_echo_stream_handler(ACK_STR_0)
|
||||
)
|
||||
hosts[1].set_stream_handler(
|
||||
PROTOCOL_ID_1, create_echo_stream_handler(ACK_STR_1)
|
||||
)
|
||||
|
||||
addr = node_b.get_addrs()[0]
|
||||
info = info_from_p2p_addr(addr)
|
||||
await node_a.connect(info)
|
||||
# Associate the peer with local ip address (see default parameters of Libp2p())
|
||||
hosts[0].get_peerstore().add_addrs(hosts[1].get_id(), hosts[1].get_addrs(), 10)
|
||||
hosts[1].get_peerstore().add_addrs(hosts[0].get_id(), hosts[0].get_addrs(), 10)
|
||||
|
||||
assert len(node_a.get_peerstore().peer_ids()) == 2
|
||||
stream_a = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_1])
|
||||
stream_b = await hosts[1].new_stream(hosts[0].get_id(), [PROTOCOL_ID_0])
|
||||
|
||||
await node_a.connect(info)
|
||||
# A writes to /echo_b via stream_a, and B writes to /echo_a via stream_b
|
||||
messages = ["hello" + str(x) for x in range(10)]
|
||||
for message in messages:
|
||||
a_message = message + "_a"
|
||||
b_message = message + "_b"
|
||||
|
||||
# make sure we don't do double connection
|
||||
assert len(node_a.get_peerstore().peer_ids()) == 2
|
||||
await stream_a.write(a_message.encode())
|
||||
await stream_b.write(b_message.encode())
|
||||
|
||||
assert node_b.get_id() in node_a.get_peerstore().peer_ids()
|
||||
ma_node_b = multiaddr.Multiaddr("/p2p/%s" % node_b.get_id().pretty())
|
||||
for addr in node_a.get_peerstore().addrs(node_b.get_id()):
|
||||
assert addr.encapsulate(ma_node_b) in node_b.get_addrs()
|
||||
response_a = (await stream_a.read(MAX_READ_LEN)).decode()
|
||||
response_b = (await stream_b.read(MAX_READ_LEN)).decode()
|
||||
|
||||
# Success, terminate pending tasks.
|
||||
assert response_a == (ACK_STR_1 + a_message) and response_b == (
|
||||
ACK_STR_0 + b_message
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_multiple_streams_same_initiator_different_protocols(is_host_secure):
|
||||
async with HostFactory.create_batch_and_listen(is_host_secure, 2) as hosts:
|
||||
|
||||
hosts[1].set_stream_handler(
|
||||
PROTOCOL_ID_0, create_echo_stream_handler(ACK_STR_0)
|
||||
)
|
||||
hosts[1].set_stream_handler(
|
||||
PROTOCOL_ID_1, create_echo_stream_handler(ACK_STR_1)
|
||||
)
|
||||
hosts[1].set_stream_handler(
|
||||
PROTOCOL_ID_2, create_echo_stream_handler(ACK_STR_2)
|
||||
)
|
||||
|
||||
# Associate the peer with local ip address (see default parameters of Libp2p())
|
||||
hosts[0].get_peerstore().add_addrs(hosts[1].get_id(), hosts[1].get_addrs(), 10)
|
||||
hosts[1].get_peerstore().add_addrs(hosts[0].get_id(), hosts[0].get_addrs(), 10)
|
||||
|
||||
# Open streams to hosts[1] over echo_a1 echo_a2 echo_a3 protocols
|
||||
stream_a1 = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_0])
|
||||
stream_a2 = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_1])
|
||||
stream_a3 = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_2])
|
||||
|
||||
messages = ["hello" + str(x) for x in range(10)]
|
||||
for message in messages:
|
||||
a1_message = message + "_a1"
|
||||
a2_message = message + "_a2"
|
||||
a3_message = message + "_a3"
|
||||
|
||||
await stream_a1.write(a1_message.encode())
|
||||
await stream_a2.write(a2_message.encode())
|
||||
await stream_a3.write(a3_message.encode())
|
||||
|
||||
response_a1 = (await stream_a1.read(MAX_READ_LEN)).decode()
|
||||
response_a2 = (await stream_a2.read(MAX_READ_LEN)).decode()
|
||||
response_a3 = (await stream_a3.read(MAX_READ_LEN)).decode()
|
||||
|
||||
assert (
|
||||
response_a1 == (ACK_STR_0 + a1_message)
|
||||
and response_a2 == (ACK_STR_1 + a2_message)
|
||||
and response_a3 == (ACK_STR_2 + a3_message)
|
||||
)
|
||||
|
||||
# Success, terminate pending tasks.
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_multiple_streams_two_initiators(is_host_secure):
|
||||
async with HostFactory.create_batch_and_listen(is_host_secure, 2) as hosts:
|
||||
hosts[0].set_stream_handler(
|
||||
PROTOCOL_ID_2, create_echo_stream_handler(ACK_STR_2)
|
||||
)
|
||||
hosts[0].set_stream_handler(
|
||||
PROTOCOL_ID_3, create_echo_stream_handler(ACK_STR_3)
|
||||
)
|
||||
|
||||
hosts[1].set_stream_handler(
|
||||
PROTOCOL_ID_0, create_echo_stream_handler(ACK_STR_0)
|
||||
)
|
||||
hosts[1].set_stream_handler(
|
||||
PROTOCOL_ID_1, create_echo_stream_handler(ACK_STR_1)
|
||||
)
|
||||
|
||||
# Associate the peer with local ip address (see default parameters of Libp2p())
|
||||
hosts[0].get_peerstore().add_addrs(hosts[1].get_id(), hosts[1].get_addrs(), 10)
|
||||
hosts[1].get_peerstore().add_addrs(hosts[0].get_id(), hosts[0].get_addrs(), 10)
|
||||
|
||||
stream_a1 = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_0])
|
||||
stream_a2 = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_1])
|
||||
|
||||
stream_b1 = await hosts[1].new_stream(hosts[0].get_id(), [PROTOCOL_ID_2])
|
||||
stream_b2 = await hosts[1].new_stream(hosts[0].get_id(), [PROTOCOL_ID_3])
|
||||
|
||||
# A writes to /echo_b via stream_a, and B writes to /echo_a via stream_b
|
||||
messages = ["hello" + str(x) for x in range(10)]
|
||||
for message in messages:
|
||||
a1_message = message + "_a1"
|
||||
a2_message = message + "_a2"
|
||||
|
||||
b1_message = message + "_b1"
|
||||
b2_message = message + "_b2"
|
||||
|
||||
await stream_a1.write(a1_message.encode())
|
||||
await stream_a2.write(a2_message.encode())
|
||||
|
||||
await stream_b1.write(b1_message.encode())
|
||||
await stream_b2.write(b2_message.encode())
|
||||
|
||||
response_a1 = (await stream_a1.read(MAX_READ_LEN)).decode()
|
||||
response_a2 = (await stream_a2.read(MAX_READ_LEN)).decode()
|
||||
|
||||
response_b1 = (await stream_b1.read(MAX_READ_LEN)).decode()
|
||||
response_b2 = (await stream_b2.read(MAX_READ_LEN)).decode()
|
||||
|
||||
assert (
|
||||
response_a1 == (ACK_STR_0 + a1_message)
|
||||
and response_a2 == (ACK_STR_1 + a2_message)
|
||||
and response_b1 == (ACK_STR_2 + b1_message)
|
||||
and response_b2 == (ACK_STR_3 + b2_message)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_triangle_nodes_connection(is_host_secure):
|
||||
async with HostFactory.create_batch_and_listen(is_host_secure, 3) as hosts:
|
||||
|
||||
hosts[0].set_stream_handler(
|
||||
PROTOCOL_ID_0, create_echo_stream_handler(ACK_STR_0)
|
||||
)
|
||||
hosts[1].set_stream_handler(
|
||||
PROTOCOL_ID_0, create_echo_stream_handler(ACK_STR_0)
|
||||
)
|
||||
hosts[2].set_stream_handler(
|
||||
PROTOCOL_ID_0, create_echo_stream_handler(ACK_STR_0)
|
||||
)
|
||||
|
||||
# Associate the peer with local ip address (see default parameters of Libp2p())
|
||||
# Associate all permutations
|
||||
hosts[0].get_peerstore().add_addrs(hosts[1].get_id(), hosts[1].get_addrs(), 10)
|
||||
hosts[0].get_peerstore().add_addrs(hosts[2].get_id(), hosts[2].get_addrs(), 10)
|
||||
|
||||
hosts[1].get_peerstore().add_addrs(hosts[0].get_id(), hosts[0].get_addrs(), 10)
|
||||
hosts[1].get_peerstore().add_addrs(hosts[2].get_id(), hosts[2].get_addrs(), 10)
|
||||
|
||||
hosts[2].get_peerstore().add_addrs(hosts[0].get_id(), hosts[0].get_addrs(), 10)
|
||||
hosts[2].get_peerstore().add_addrs(hosts[1].get_id(), hosts[1].get_addrs(), 10)
|
||||
|
||||
stream_0_to_1 = await hosts[0].new_stream(hosts[1].get_id(), [PROTOCOL_ID_0])
|
||||
stream_0_to_2 = await hosts[0].new_stream(hosts[2].get_id(), [PROTOCOL_ID_0])
|
||||
|
||||
stream_1_to_0 = await hosts[1].new_stream(hosts[0].get_id(), [PROTOCOL_ID_0])
|
||||
stream_1_to_2 = await hosts[1].new_stream(hosts[2].get_id(), [PROTOCOL_ID_0])
|
||||
|
||||
stream_2_to_0 = await hosts[2].new_stream(hosts[0].get_id(), [PROTOCOL_ID_0])
|
||||
stream_2_to_1 = await hosts[2].new_stream(hosts[1].get_id(), [PROTOCOL_ID_0])
|
||||
|
||||
messages = ["hello" + str(x) for x in range(5)]
|
||||
streams = [
|
||||
stream_0_to_1,
|
||||
stream_0_to_2,
|
||||
stream_1_to_0,
|
||||
stream_1_to_2,
|
||||
stream_2_to_0,
|
||||
stream_2_to_1,
|
||||
]
|
||||
|
||||
for message in messages:
|
||||
for stream in streams:
|
||||
await stream.write(message.encode())
|
||||
response = (await stream.read(MAX_READ_LEN)).decode()
|
||||
assert response == (ACK_STR_0 + message)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_host_connect(is_host_secure):
|
||||
async with HostFactory.create_batch_and_listen(is_host_secure, 2) as hosts:
|
||||
assert len(hosts[0].get_peerstore().peer_ids()) == 1
|
||||
|
||||
await connect(hosts[0], hosts[1])
|
||||
assert len(hosts[0].get_peerstore().peer_ids()) == 2
|
||||
|
||||
await connect(hosts[0], hosts[1])
|
||||
# make sure we don't do double connection
|
||||
assert len(hosts[0].get_peerstore().peer_ids()) == 2
|
||||
|
||||
assert hosts[1].get_id() in hosts[0].get_peerstore().peer_ids()
|
||||
ma_node_b = multiaddr.Multiaddr("/p2p/%s" % hosts[1].get_id().pretty())
|
||||
for addr in hosts[0].get_peerstore().addrs(hosts[1].get_id()):
|
||||
assert addr.encapsulate(ma_node_b) in hosts[1].get_addrs()
|
||||
|
|
|
@ -1,5 +1,3 @@
|
|||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from libp2p.tools.factories import (
|
||||
|
@ -11,26 +9,17 @@ from libp2p.tools.factories import (
|
|||
|
||||
@pytest.fixture
|
||||
async def net_stream_pair(is_host_secure):
|
||||
stream_0, host_0, stream_1, host_1 = await net_stream_pair_factory(is_host_secure)
|
||||
try:
|
||||
yield stream_0, stream_1
|
||||
finally:
|
||||
await asyncio.gather(*[host_0.close(), host_1.close()])
|
||||
async with net_stream_pair_factory(is_host_secure) as net_stream_pair:
|
||||
yield net_stream_pair
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def swarm_pair(is_host_secure):
|
||||
swarm_0, swarm_1 = await swarm_pair_factory(is_host_secure)
|
||||
try:
|
||||
yield swarm_0, swarm_1
|
||||
finally:
|
||||
await asyncio.gather(*[swarm_0.close(), swarm_1.close()])
|
||||
async with swarm_pair_factory(is_host_secure) as swarms:
|
||||
yield swarms
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def swarm_conn_pair(is_host_secure):
|
||||
conn_0, swarm_0, conn_1, swarm_1 = await swarm_conn_pair_factory(is_host_secure)
|
||||
try:
|
||||
yield conn_0, conn_1
|
||||
finally:
|
||||
await asyncio.gather(*[swarm_0.close(), swarm_1.close()])
|
||||
async with swarm_conn_pair_factory(is_host_secure) as swarm_conn_pair:
|
||||
yield swarm_conn_pair
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
import asyncio
|
||||
|
||||
import pytest
|
||||
import trio
|
||||
|
||||
from libp2p.network.stream.exceptions import StreamClosed, StreamEOF, StreamReset
|
||||
from libp2p.tools.constants import MAX_READ_LEN
|
||||
|
@ -8,7 +7,7 @@ from libp2p.tools.constants import MAX_READ_LEN
|
|||
DATA = b"data_123"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.trio
|
||||
async def test_net_stream_read_write(net_stream_pair):
|
||||
stream_0, stream_1 = net_stream_pair
|
||||
assert (
|
||||
|
@ -19,7 +18,7 @@ async def test_net_stream_read_write(net_stream_pair):
|
|||
assert (await stream_1.read(MAX_READ_LEN)) == DATA
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.trio
|
||||
async def test_net_stream_read_until_eof(net_stream_pair):
|
||||
read_bytes = bytearray()
|
||||
stream_0, stream_1 = net_stream_pair
|
||||
|
@ -27,41 +26,39 @@ async def test_net_stream_read_until_eof(net_stream_pair):
|
|||
async def read_until_eof():
|
||||
read_bytes.extend(await stream_1.read())
|
||||
|
||||
task = asyncio.ensure_future(read_until_eof())
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(read_until_eof)
|
||||
expected_data = bytearray()
|
||||
|
||||
expected_data = bytearray()
|
||||
# Test: `read` doesn't return before `close` is called.
|
||||
await stream_0.write(DATA)
|
||||
expected_data.extend(DATA)
|
||||
await trio.sleep(0.01)
|
||||
assert len(read_bytes) == 0
|
||||
# Test: `read` doesn't return before `close` is called.
|
||||
await stream_0.write(DATA)
|
||||
expected_data.extend(DATA)
|
||||
await trio.sleep(0.01)
|
||||
assert len(read_bytes) == 0
|
||||
|
||||
# Test: `read` doesn't return before `close` is called.
|
||||
await stream_0.write(DATA)
|
||||
expected_data.extend(DATA)
|
||||
await asyncio.sleep(0.01)
|
||||
assert len(read_bytes) == 0
|
||||
# Test: `read` doesn't return before `close` is called.
|
||||
await stream_0.write(DATA)
|
||||
expected_data.extend(DATA)
|
||||
await asyncio.sleep(0.01)
|
||||
assert len(read_bytes) == 0
|
||||
|
||||
# Test: Close the stream, `read` returns, and receive previous sent data.
|
||||
await stream_0.close()
|
||||
await asyncio.sleep(0.01)
|
||||
assert read_bytes == expected_data
|
||||
|
||||
task.cancel()
|
||||
# Test: Close the stream, `read` returns, and receive previous sent data.
|
||||
await stream_0.close()
|
||||
await trio.sleep(0.01)
|
||||
assert read_bytes == expected_data
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.trio
|
||||
async def test_net_stream_read_after_remote_closed(net_stream_pair):
|
||||
stream_0, stream_1 = net_stream_pair
|
||||
await stream_0.write(DATA)
|
||||
await stream_0.close()
|
||||
await asyncio.sleep(0.01)
|
||||
await trio.sleep(0.01)
|
||||
assert (await stream_1.read(MAX_READ_LEN)) == DATA
|
||||
with pytest.raises(StreamEOF):
|
||||
await stream_1.read(MAX_READ_LEN)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.trio
|
||||
async def test_net_stream_read_after_local_reset(net_stream_pair):
|
||||
stream_0, stream_1 = net_stream_pair
|
||||
await stream_0.reset()
|
||||
|
@ -69,29 +66,29 @@ async def test_net_stream_read_after_local_reset(net_stream_pair):
|
|||
await stream_0.read(MAX_READ_LEN)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.trio
|
||||
async def test_net_stream_read_after_remote_reset(net_stream_pair):
|
||||
stream_0, stream_1 = net_stream_pair
|
||||
await stream_0.write(DATA)
|
||||
await stream_0.reset()
|
||||
# Sleep to let `stream_1` receive the message.
|
||||
await asyncio.sleep(0.01)
|
||||
await trio.sleep(0.01)
|
||||
with pytest.raises(StreamReset):
|
||||
await stream_1.read(MAX_READ_LEN)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.trio
|
||||
async def test_net_stream_read_after_remote_closed_and_reset(net_stream_pair):
|
||||
stream_0, stream_1 = net_stream_pair
|
||||
await stream_0.write(DATA)
|
||||
await stream_0.close()
|
||||
await stream_0.reset()
|
||||
# Sleep to let `stream_1` receive the message.
|
||||
await asyncio.sleep(0.01)
|
||||
await trio.sleep(0.01)
|
||||
assert (await stream_1.read(MAX_READ_LEN)) == DATA
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.trio
|
||||
async def test_net_stream_write_after_local_closed(net_stream_pair):
|
||||
stream_0, stream_1 = net_stream_pair
|
||||
await stream_0.write(DATA)
|
||||
|
@ -100,7 +97,7 @@ async def test_net_stream_write_after_local_closed(net_stream_pair):
|
|||
await stream_0.write(DATA)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.trio
|
||||
async def test_net_stream_write_after_local_reset(net_stream_pair):
|
||||
stream_0, stream_1 = net_stream_pair
|
||||
await stream_0.reset()
|
||||
|
@ -108,10 +105,10 @@ async def test_net_stream_write_after_local_reset(net_stream_pair):
|
|||
await stream_0.write(DATA)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.trio
|
||||
async def test_net_stream_write_after_remote_reset(net_stream_pair):
|
||||
stream_0, stream_1 = net_stream_pair
|
||||
await stream_1.reset()
|
||||
await asyncio.sleep(0.01)
|
||||
await trio.sleep(0.01)
|
||||
with pytest.raises(StreamClosed):
|
||||
await stream_0.write(DATA)
|
||||
|
|
|
@ -8,11 +8,11 @@ into network after network has already started listening
|
|||
TODO: Add tests for closed_stream, listen_close when those
|
||||
features are implemented in swarm
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import enum
|
||||
|
||||
from async_service import background_trio_service
|
||||
import pytest
|
||||
import trio
|
||||
|
||||
from libp2p.network.notifee_interface import INotifee
|
||||
from libp2p.tools.constants import LISTEN_MADDR
|
||||
|
@ -54,59 +54,63 @@ class MyNotifee(INotifee):
|
|||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.trio
|
||||
async def test_notify(is_host_secure):
|
||||
swarms = [SwarmFactory(is_secure=is_host_secure) for _ in range(2)]
|
||||
|
||||
events_0_0 = []
|
||||
events_1_0 = []
|
||||
events_0_without_listen = []
|
||||
swarms[0].register_notifee(MyNotifee(events_0_0))
|
||||
swarms[1].register_notifee(MyNotifee(events_1_0))
|
||||
# Listen
|
||||
await asyncio.gather(*[swarm.listen(LISTEN_MADDR) for swarm in swarms])
|
||||
# Run swarms.
|
||||
async with background_trio_service(swarms[0]), background_trio_service(swarms[1]):
|
||||
# Register events before listening, to allow `MyNotifee` is notified with the event
|
||||
# `listen`.
|
||||
swarms[0].register_notifee(MyNotifee(events_0_0))
|
||||
swarms[1].register_notifee(MyNotifee(events_1_0))
|
||||
|
||||
swarms[0].register_notifee(MyNotifee(events_0_without_listen))
|
||||
# Listen
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(swarms[0].listen, LISTEN_MADDR)
|
||||
nursery.start_soon(swarms[1].listen, LISTEN_MADDR)
|
||||
|
||||
# Connected
|
||||
await connect_swarm(swarms[0], swarms[1])
|
||||
# OpenedStream: first
|
||||
await swarms[0].new_stream(swarms[1].get_peer_id())
|
||||
# OpenedStream: second
|
||||
await swarms[0].new_stream(swarms[1].get_peer_id())
|
||||
# OpenedStream: third, but different direction.
|
||||
await swarms[1].new_stream(swarms[0].get_peer_id())
|
||||
swarms[0].register_notifee(MyNotifee(events_0_without_listen))
|
||||
|
||||
await asyncio.sleep(0.01)
|
||||
# Connected
|
||||
await connect_swarm(swarms[0], swarms[1])
|
||||
# OpenedStream: first
|
||||
await swarms[0].new_stream(swarms[1].get_peer_id())
|
||||
# OpenedStream: second
|
||||
await swarms[0].new_stream(swarms[1].get_peer_id())
|
||||
# OpenedStream: third, but different direction.
|
||||
await swarms[1].new_stream(swarms[0].get_peer_id())
|
||||
|
||||
# TODO: Check `ClosedStream` and `ListenClose` events after they are ready.
|
||||
await trio.sleep(0.01)
|
||||
|
||||
# Disconnected
|
||||
await swarms[0].close_peer(swarms[1].get_peer_id())
|
||||
await asyncio.sleep(0.01)
|
||||
# TODO: Check `ClosedStream` and `ListenClose` events after they are ready.
|
||||
|
||||
# Connected again, but different direction.
|
||||
await connect_swarm(swarms[1], swarms[0])
|
||||
await asyncio.sleep(0.01)
|
||||
# Disconnected
|
||||
await swarms[0].close_peer(swarms[1].get_peer_id())
|
||||
await trio.sleep(0.01)
|
||||
|
||||
# Disconnected again, but different direction.
|
||||
await swarms[1].close_peer(swarms[0].get_peer_id())
|
||||
await asyncio.sleep(0.01)
|
||||
# Connected again, but different direction.
|
||||
await connect_swarm(swarms[1], swarms[0])
|
||||
await trio.sleep(0.01)
|
||||
|
||||
expected_events_without_listen = [
|
||||
Event.Connected,
|
||||
Event.OpenedStream,
|
||||
Event.OpenedStream,
|
||||
Event.OpenedStream,
|
||||
Event.Disconnected,
|
||||
Event.Connected,
|
||||
Event.Disconnected,
|
||||
]
|
||||
expected_events = [Event.Listen] + expected_events_without_listen
|
||||
# Disconnected again, but different direction.
|
||||
await swarms[1].close_peer(swarms[0].get_peer_id())
|
||||
await trio.sleep(0.01)
|
||||
|
||||
assert events_0_0 == expected_events
|
||||
assert events_1_0 == expected_events
|
||||
assert events_0_without_listen == expected_events_without_listen
|
||||
expected_events_without_listen = [
|
||||
Event.Connected,
|
||||
Event.OpenedStream,
|
||||
Event.OpenedStream,
|
||||
Event.OpenedStream,
|
||||
Event.Disconnected,
|
||||
Event.Connected,
|
||||
Event.Disconnected,
|
||||
]
|
||||
expected_events = [Event.Listen] + expected_events_without_listen
|
||||
|
||||
# Clean up
|
||||
await asyncio.gather(*[swarm.close() for swarm in swarms])
|
||||
assert events_0_0 == expected_events
|
||||
assert events_1_0 == expected_events
|
||||
assert events_0_without_listen == expected_events_without_listen
|
||||
|
|
|
@ -1,89 +1,84 @@
|
|||
import asyncio
|
||||
|
||||
from multiaddr import Multiaddr
|
||||
import pytest
|
||||
import trio
|
||||
from trio.testing import wait_all_tasks_blocked
|
||||
|
||||
from libp2p.network.exceptions import SwarmException
|
||||
from libp2p.tools.factories import SwarmFactory
|
||||
from libp2p.tools.utils import connect_swarm
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.trio
|
||||
async def test_swarm_dial_peer(is_host_secure):
|
||||
swarms = await SwarmFactory.create_batch_and_listen(is_host_secure, 3)
|
||||
# Test: No addr found.
|
||||
with pytest.raises(SwarmException):
|
||||
async with SwarmFactory.create_batch_and_listen(is_host_secure, 3) as swarms:
|
||||
# Test: No addr found.
|
||||
with pytest.raises(SwarmException):
|
||||
await swarms[0].dial_peer(swarms[1].get_peer_id())
|
||||
|
||||
# Test: len(addr) in the peerstore is 0.
|
||||
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), [], 10000)
|
||||
with pytest.raises(SwarmException):
|
||||
await swarms[0].dial_peer(swarms[1].get_peer_id())
|
||||
|
||||
# Test: Succeed if addrs of the peer_id are present in the peerstore.
|
||||
addrs = tuple(
|
||||
addr
|
||||
for transport in swarms[1].listeners.values()
|
||||
for addr in transport.get_addrs()
|
||||
)
|
||||
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs, 10000)
|
||||
await swarms[0].dial_peer(swarms[1].get_peer_id())
|
||||
assert swarms[0].get_peer_id() in swarms[1].connections
|
||||
assert swarms[1].get_peer_id() in swarms[0].connections
|
||||
|
||||
# Test: len(addr) in the peerstore is 0.
|
||||
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), [], 10000)
|
||||
with pytest.raises(SwarmException):
|
||||
await swarms[0].dial_peer(swarms[1].get_peer_id())
|
||||
|
||||
# Test: Succeed if addrs of the peer_id are present in the peerstore.
|
||||
addrs = tuple(
|
||||
addr
|
||||
for transport in swarms[1].listeners.values()
|
||||
for addr in transport.get_addrs()
|
||||
)
|
||||
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs, 10000)
|
||||
await swarms[0].dial_peer(swarms[1].get_peer_id())
|
||||
assert swarms[0].get_peer_id() in swarms[1].connections
|
||||
assert swarms[1].get_peer_id() in swarms[0].connections
|
||||
|
||||
# Test: Reuse connections when we already have ones with a peer.
|
||||
conn_to_1 = swarms[0].connections[swarms[1].get_peer_id()]
|
||||
conn = await swarms[0].dial_peer(swarms[1].get_peer_id())
|
||||
assert conn is conn_to_1
|
||||
|
||||
# Clean up
|
||||
await asyncio.gather(*[swarm.close() for swarm in swarms])
|
||||
# Test: Reuse connections when we already have ones with a peer.
|
||||
conn_to_1 = swarms[0].connections[swarms[1].get_peer_id()]
|
||||
conn = await swarms[0].dial_peer(swarms[1].get_peer_id())
|
||||
assert conn is conn_to_1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.trio
|
||||
async def test_swarm_close_peer(is_host_secure):
|
||||
swarms = await SwarmFactory.create_batch_and_listen(is_host_secure, 3)
|
||||
# 0 <> 1 <> 2
|
||||
await connect_swarm(swarms[0], swarms[1])
|
||||
await connect_swarm(swarms[1], swarms[2])
|
||||
async with SwarmFactory.create_batch_and_listen(is_host_secure, 3) as swarms:
|
||||
# 0 <> 1 <> 2
|
||||
await connect_swarm(swarms[0], swarms[1])
|
||||
await connect_swarm(swarms[1], swarms[2])
|
||||
|
||||
# peer 1 closes peer 0
|
||||
await swarms[1].close_peer(swarms[0].get_peer_id())
|
||||
await asyncio.sleep(0.01)
|
||||
# 0 1 <> 2
|
||||
assert len(swarms[0].connections) == 0
|
||||
assert (
|
||||
len(swarms[1].connections) == 1
|
||||
and swarms[2].get_peer_id() in swarms[1].connections
|
||||
)
|
||||
# peer 1 closes peer 0
|
||||
await swarms[1].close_peer(swarms[0].get_peer_id())
|
||||
await trio.sleep(0.01)
|
||||
await wait_all_tasks_blocked()
|
||||
# 0 1 <> 2
|
||||
assert len(swarms[0].connections) == 0
|
||||
assert (
|
||||
len(swarms[1].connections) == 1
|
||||
and swarms[2].get_peer_id() in swarms[1].connections
|
||||
)
|
||||
|
||||
# peer 1 is closed by peer 2
|
||||
await swarms[2].close_peer(swarms[1].get_peer_id())
|
||||
await asyncio.sleep(0.01)
|
||||
# 0 1 2
|
||||
assert len(swarms[1].connections) == 0 and len(swarms[2].connections) == 0
|
||||
# peer 1 is closed by peer 2
|
||||
await swarms[2].close_peer(swarms[1].get_peer_id())
|
||||
await trio.sleep(0.01)
|
||||
# 0 1 2
|
||||
assert len(swarms[1].connections) == 0 and len(swarms[2].connections) == 0
|
||||
|
||||
await connect_swarm(swarms[0], swarms[1])
|
||||
# 0 <> 1 2
|
||||
assert (
|
||||
len(swarms[0].connections) == 1
|
||||
and swarms[1].get_peer_id() in swarms[0].connections
|
||||
)
|
||||
assert (
|
||||
len(swarms[1].connections) == 1
|
||||
and swarms[0].get_peer_id() in swarms[1].connections
|
||||
)
|
||||
# peer 0 closes peer 1
|
||||
await swarms[0].close_peer(swarms[1].get_peer_id())
|
||||
await asyncio.sleep(0.01)
|
||||
# 0 1 2
|
||||
assert len(swarms[1].connections) == 0 and len(swarms[2].connections) == 0
|
||||
|
||||
# Clean up
|
||||
await asyncio.gather(*[swarm.close() for swarm in swarms])
|
||||
await connect_swarm(swarms[0], swarms[1])
|
||||
# 0 <> 1 2
|
||||
assert (
|
||||
len(swarms[0].connections) == 1
|
||||
and swarms[1].get_peer_id() in swarms[0].connections
|
||||
)
|
||||
assert (
|
||||
len(swarms[1].connections) == 1
|
||||
and swarms[0].get_peer_id() in swarms[1].connections
|
||||
)
|
||||
# peer 0 closes peer 1
|
||||
await swarms[0].close_peer(swarms[1].get_peer_id())
|
||||
await trio.sleep(0.01)
|
||||
# 0 1 2
|
||||
assert len(swarms[1].connections) == 0 and len(swarms[2].connections) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.trio
|
||||
async def test_swarm_remove_conn(swarm_pair):
|
||||
swarm_0, swarm_1 = swarm_pair
|
||||
conn_0 = swarm_0.connections[swarm_1.get_peer_id()]
|
||||
|
@ -94,57 +89,54 @@ async def test_swarm_remove_conn(swarm_pair):
|
|||
assert swarm_1.get_peer_id() not in swarm_0.connections
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.trio
|
||||
async def test_swarm_multiaddr(is_host_secure):
|
||||
swarms = await SwarmFactory.create_batch_and_listen(is_host_secure, 3)
|
||||
async with SwarmFactory.create_batch_and_listen(is_host_secure, 3) as swarms:
|
||||
|
||||
def clear():
|
||||
swarms[0].peerstore.clear_addrs(swarms[1].get_peer_id())
|
||||
def clear():
|
||||
swarms[0].peerstore.clear_addrs(swarms[1].get_peer_id())
|
||||
|
||||
clear()
|
||||
# No addresses
|
||||
with pytest.raises(SwarmException):
|
||||
clear()
|
||||
# No addresses
|
||||
with pytest.raises(SwarmException):
|
||||
await swarms[0].dial_peer(swarms[1].get_peer_id())
|
||||
|
||||
clear()
|
||||
# Wrong addresses
|
||||
swarms[0].peerstore.add_addrs(
|
||||
swarms[1].get_peer_id(), [Multiaddr("/ip4/0.0.0.0/tcp/9999")], 10000
|
||||
)
|
||||
|
||||
with pytest.raises(SwarmException):
|
||||
await swarms[0].dial_peer(swarms[1].get_peer_id())
|
||||
|
||||
clear()
|
||||
# Multiple wrong addresses
|
||||
swarms[0].peerstore.add_addrs(
|
||||
swarms[1].get_peer_id(),
|
||||
[Multiaddr("/ip4/0.0.0.0/tcp/9999"), Multiaddr("/ip4/0.0.0.0/tcp/9998")],
|
||||
10000,
|
||||
)
|
||||
|
||||
with pytest.raises(SwarmException):
|
||||
await swarms[0].dial_peer(swarms[1].get_peer_id())
|
||||
|
||||
# Test one address
|
||||
addrs = tuple(
|
||||
addr
|
||||
for transport in swarms[1].listeners.values()
|
||||
for addr in transport.get_addrs()
|
||||
)
|
||||
|
||||
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs[:1], 10000)
|
||||
await swarms[0].dial_peer(swarms[1].get_peer_id())
|
||||
|
||||
clear()
|
||||
# Wrong addresses
|
||||
swarms[0].peerstore.add_addrs(
|
||||
swarms[1].get_peer_id(), [Multiaddr("/ip4/0.0.0.0/tcp/9999")], 10000
|
||||
)
|
||||
# Test multiple addresses
|
||||
addrs = tuple(
|
||||
addr
|
||||
for transport in swarms[1].listeners.values()
|
||||
for addr in transport.get_addrs()
|
||||
)
|
||||
|
||||
with pytest.raises(SwarmException):
|
||||
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs + addrs, 10000)
|
||||
await swarms[0].dial_peer(swarms[1].get_peer_id())
|
||||
|
||||
clear()
|
||||
# Multiple wrong addresses
|
||||
swarms[0].peerstore.add_addrs(
|
||||
swarms[1].get_peer_id(),
|
||||
[Multiaddr("/ip4/0.0.0.0/tcp/9999"), Multiaddr("/ip4/0.0.0.0/tcp/9998")],
|
||||
10000,
|
||||
)
|
||||
|
||||
with pytest.raises(SwarmException):
|
||||
await swarms[0].dial_peer(swarms[1].get_peer_id())
|
||||
|
||||
# Test one address
|
||||
addrs = tuple(
|
||||
addr
|
||||
for transport in swarms[1].listeners.values()
|
||||
for addr in transport.get_addrs()
|
||||
)
|
||||
|
||||
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs[:1], 10000)
|
||||
await swarms[0].dial_peer(swarms[1].get_peer_id())
|
||||
|
||||
# Test multiple addresses
|
||||
addrs = tuple(
|
||||
addr
|
||||
for transport in swarms[1].listeners.values()
|
||||
for addr in transport.get_addrs()
|
||||
)
|
||||
|
||||
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs + addrs, 10000)
|
||||
await swarms[0].dial_peer(swarms[1].get_peer_id())
|
||||
|
||||
for swarm in swarms:
|
||||
await swarm.close()
|
||||
|
|
|
@ -1,45 +1,46 @@
|
|||
import asyncio
|
||||
|
||||
import pytest
|
||||
import trio
|
||||
from trio.testing import wait_all_tasks_blocked
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.trio
|
||||
async def test_swarm_conn_close(swarm_conn_pair):
|
||||
conn_0, conn_1 = swarm_conn_pair
|
||||
|
||||
assert not conn_0.event_closed.is_set()
|
||||
assert not conn_1.event_closed.is_set()
|
||||
assert not conn_0.is_closed
|
||||
assert not conn_1.is_closed
|
||||
|
||||
await conn_0.close()
|
||||
|
||||
await asyncio.sleep(0.01)
|
||||
await trio.sleep(0.1)
|
||||
await wait_all_tasks_blocked()
|
||||
|
||||
assert conn_0.event_closed.is_set()
|
||||
assert conn_1.event_closed.is_set()
|
||||
assert conn_0.is_closed
|
||||
assert conn_1.is_closed
|
||||
assert conn_0 not in conn_0.swarm.connections.values()
|
||||
assert conn_1 not in conn_1.swarm.connections.values()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.trio
|
||||
async def test_swarm_conn_streams(swarm_conn_pair):
|
||||
conn_0, conn_1 = swarm_conn_pair
|
||||
|
||||
assert len(await conn_0.get_streams()) == 0
|
||||
assert len(await conn_1.get_streams()) == 0
|
||||
assert len(conn_0.get_streams()) == 0
|
||||
assert len(conn_1.get_streams()) == 0
|
||||
|
||||
stream_0_0 = await conn_0.new_stream()
|
||||
await asyncio.sleep(0.01)
|
||||
assert len(await conn_0.get_streams()) == 1
|
||||
assert len(await conn_1.get_streams()) == 1
|
||||
await trio.sleep(0.01)
|
||||
assert len(conn_0.get_streams()) == 1
|
||||
assert len(conn_1.get_streams()) == 1
|
||||
|
||||
stream_0_1 = await conn_0.new_stream()
|
||||
await asyncio.sleep(0.01)
|
||||
assert len(await conn_0.get_streams()) == 2
|
||||
assert len(await conn_1.get_streams()) == 2
|
||||
await trio.sleep(0.01)
|
||||
assert len(conn_0.get_streams()) == 2
|
||||
assert len(conn_1.get_streams()) == 2
|
||||
|
||||
conn_0.remove_stream(stream_0_0)
|
||||
assert len(await conn_0.get_streams()) == 1
|
||||
assert len(conn_0.get_streams()) == 1
|
||||
conn_0.remove_stream(stream_0_1)
|
||||
assert len(await conn_0.get_streams()) == 0
|
||||
assert len(conn_0.get_streams()) == 0
|
||||
# Nothing happen if `stream_0_1` is not present or already removed.
|
||||
conn_0.remove_stream(stream_0_1)
|
||||
|
|
|
@ -25,8 +25,6 @@ def test_init_():
|
|||
@pytest.mark.parametrize(
|
||||
"addr",
|
||||
(
|
||||
pytest.param(None),
|
||||
pytest.param(random.randint(0, 255), id="random integer"),
|
||||
pytest.param(multiaddr.Multiaddr("/"), id="empty multiaddr"),
|
||||
pytest.param(
|
||||
multiaddr.Multiaddr("/ip4/127.0.0.1"),
|
||||
|
|
|
@ -1,83 +1,94 @@
|
|||
import pytest
|
||||
|
||||
from libp2p.host.exceptions import StreamFailure
|
||||
from libp2p.tools.utils import echo_stream_handler, set_up_nodes_by_transport_opt
|
||||
from libp2p.tools.factories import HostFactory
|
||||
from libp2p.tools.utils import create_echo_stream_handler
|
||||
|
||||
# TODO: Add tests for multiple streams being opened on different
|
||||
# protocols through the same connection
|
||||
PROTOCOL_ECHO = "/echo/1.0.0"
|
||||
PROTOCOL_POTATO = "/potato/1.0.0"
|
||||
PROTOCOL_FOO = "/foo/1.0.0"
|
||||
PROTOCOL_ROCK = "/rock/1.0.0"
|
||||
|
||||
# Note: async issues occurred when using the same port
|
||||
# so that's why I use different ports here.
|
||||
# TODO: modify tests so that those async issues don't occur
|
||||
# when using the same ports across tests
|
||||
ACK_PREFIX = "ack:"
|
||||
|
||||
|
||||
async def perform_simple_test(
|
||||
expected_selected_protocol, protocols_for_client, protocols_with_handlers
|
||||
expected_selected_protocol,
|
||||
protocols_for_client,
|
||||
protocols_with_handlers,
|
||||
is_host_secure,
|
||||
):
|
||||
transport_opt_list = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]]
|
||||
(node_a, node_b) = await set_up_nodes_by_transport_opt(transport_opt_list)
|
||||
async with HostFactory.create_batch_and_listen(is_host_secure, 2) as hosts:
|
||||
for protocol in protocols_with_handlers:
|
||||
hosts[1].set_stream_handler(
|
||||
protocol, create_echo_stream_handler(ACK_PREFIX)
|
||||
)
|
||||
|
||||
for protocol in protocols_with_handlers:
|
||||
node_b.set_stream_handler(protocol, echo_stream_handler)
|
||||
# Associate the peer with local ip address (see default parameters of Libp2p())
|
||||
hosts[0].get_peerstore().add_addrs(hosts[1].get_id(), hosts[1].get_addrs(), 10)
|
||||
stream = await hosts[0].new_stream(hosts[1].get_id(), protocols_for_client)
|
||||
messages = ["hello" + str(x) for x in range(10)]
|
||||
for message in messages:
|
||||
expected_resp = "ack:" + message
|
||||
await stream.write(message.encode())
|
||||
response = (await stream.read(len(expected_resp))).decode()
|
||||
assert response == expected_resp
|
||||
|
||||
# Associate the peer with local ip address (see default parameters of Libp2p())
|
||||
node_a.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10)
|
||||
|
||||
stream = await node_a.new_stream(node_b.get_id(), protocols_for_client)
|
||||
messages = ["hello" + str(x) for x in range(10)]
|
||||
for message in messages:
|
||||
expected_resp = "ack:" + message
|
||||
await stream.write(message.encode())
|
||||
response = (await stream.read(len(expected_resp))).decode()
|
||||
assert response == expected_resp
|
||||
|
||||
assert expected_selected_protocol == stream.get_protocol()
|
||||
|
||||
# Success, terminate pending tasks.
|
||||
assert expected_selected_protocol == stream.get_protocol()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_protocol_succeeds():
|
||||
expected_selected_protocol = "/echo/1.0.0"
|
||||
@pytest.mark.trio
|
||||
async def test_single_protocol_succeeds(is_host_secure):
|
||||
expected_selected_protocol = PROTOCOL_ECHO
|
||||
await perform_simple_test(
|
||||
expected_selected_protocol, ["/echo/1.0.0"], ["/echo/1.0.0"]
|
||||
expected_selected_protocol,
|
||||
[expected_selected_protocol],
|
||||
[expected_selected_protocol],
|
||||
is_host_secure,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_protocol_fails():
|
||||
@pytest.mark.trio
|
||||
async def test_single_protocol_fails(is_host_secure):
|
||||
with pytest.raises(StreamFailure):
|
||||
await perform_simple_test("", ["/echo/1.0.0"], ["/potato/1.0.0"])
|
||||
await perform_simple_test(
|
||||
"", [PROTOCOL_ECHO], [PROTOCOL_POTATO], is_host_secure
|
||||
)
|
||||
|
||||
# Cleanup not reached on error
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_protocol_first_is_valid_succeeds():
|
||||
expected_selected_protocol = "/echo/1.0.0"
|
||||
protocols_for_client = ["/echo/1.0.0", "/potato/1.0.0"]
|
||||
protocols_for_listener = ["/foo/1.0.0", "/echo/1.0.0"]
|
||||
@pytest.mark.trio
|
||||
async def test_multiple_protocol_first_is_valid_succeeds(is_host_secure):
|
||||
expected_selected_protocol = PROTOCOL_ECHO
|
||||
protocols_for_client = [PROTOCOL_ECHO, PROTOCOL_POTATO]
|
||||
protocols_for_listener = [PROTOCOL_FOO, PROTOCOL_ECHO]
|
||||
await perform_simple_test(
|
||||
expected_selected_protocol, protocols_for_client, protocols_for_listener
|
||||
expected_selected_protocol,
|
||||
protocols_for_client,
|
||||
protocols_for_listener,
|
||||
is_host_secure,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_protocol_second_is_valid_succeeds():
|
||||
expected_selected_protocol = "/foo/1.0.0"
|
||||
protocols_for_client = ["/rock/1.0.0", "/foo/1.0.0"]
|
||||
protocols_for_listener = ["/foo/1.0.0", "/echo/1.0.0"]
|
||||
@pytest.mark.trio
|
||||
async def test_multiple_protocol_second_is_valid_succeeds(is_host_secure):
|
||||
expected_selected_protocol = PROTOCOL_FOO
|
||||
protocols_for_client = [PROTOCOL_ROCK, PROTOCOL_FOO]
|
||||
protocols_for_listener = [PROTOCOL_FOO, PROTOCOL_ECHO]
|
||||
await perform_simple_test(
|
||||
expected_selected_protocol, protocols_for_client, protocols_for_listener
|
||||
expected_selected_protocol,
|
||||
protocols_for_client,
|
||||
protocols_for_listener,
|
||||
is_host_secure,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_protocol_fails():
|
||||
protocols_for_client = ["/rock/1.0.0", "/foo/1.0.0", "/bar/1.0.0"]
|
||||
@pytest.mark.trio
|
||||
async def test_multiple_protocol_fails(is_host_secure):
|
||||
protocols_for_client = [PROTOCOL_ROCK, PROTOCOL_FOO, "/bar/1.0.0"]
|
||||
protocols_for_listener = ["/aspyn/1.0.0", "/rob/1.0.0", "/zx/1.0.0", "/alex/1.0.0"]
|
||||
with pytest.raises(StreamFailure):
|
||||
await perform_simple_test("", protocols_for_client, protocols_for_listener)
|
||||
|
||||
# Cleanup not reached on error
|
||||
await perform_simple_test(
|
||||
"", protocols_for_client, protocols_for_listener, is_host_secure
|
||||
)
|
||||
|
|
|
@ -1,58 +0,0 @@
|
|||
import pytest
|
||||
|
||||
from libp2p.tools.constants import GOSSIPSUB_PARAMS
|
||||
from libp2p.tools.factories import FloodsubFactory, GossipsubFactory, PubsubFactory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def is_strict_signing():
|
||||
return False
|
||||
|
||||
|
||||
def _make_pubsubs(hosts, pubsub_routers, cache_size, is_strict_signing):
|
||||
if len(pubsub_routers) != len(hosts):
|
||||
raise ValueError(
|
||||
f"lenght of pubsub_routers={pubsub_routers} should be equaled to the "
|
||||
f"length of hosts={len(hosts)}"
|
||||
)
|
||||
return tuple(
|
||||
PubsubFactory(
|
||||
host=host,
|
||||
router=router,
|
||||
cache_size=cache_size,
|
||||
strict_signing=is_strict_signing,
|
||||
)
|
||||
for host, router in zip(hosts, pubsub_routers)
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pubsub_cache_size():
|
||||
return None # default
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def gossipsub_params():
|
||||
return GOSSIPSUB_PARAMS
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pubsubs_fsub(num_hosts, hosts, pubsub_cache_size, is_strict_signing):
|
||||
floodsubs = FloodsubFactory.create_batch(num_hosts)
|
||||
_pubsubs_fsub = _make_pubsubs(
|
||||
hosts, floodsubs, pubsub_cache_size, is_strict_signing
|
||||
)
|
||||
yield _pubsubs_fsub
|
||||
# TODO: Clean up
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pubsubs_gsub(
|
||||
num_hosts, hosts, pubsub_cache_size, gossipsub_params, is_strict_signing
|
||||
):
|
||||
gossipsubs = GossipsubFactory.create_batch(num_hosts, **gossipsub_params._asdict())
|
||||
_pubsubs_gsub = _make_pubsubs(
|
||||
hosts, gossipsubs, pubsub_cache_size, is_strict_signing
|
||||
)
|
||||
yield _pubsubs_gsub
|
||||
# TODO: Clean up
|
|
@ -1,19 +1,10 @@
|
|||
import asyncio
|
||||
from threading import Thread
|
||||
|
||||
import pytest
|
||||
import trio
|
||||
|
||||
from libp2p.tools.pubsub.dummy_account_node import DummyAccountNode
|
||||
from libp2p.tools.utils import connect
|
||||
|
||||
|
||||
def create_setup_in_new_thread_func(dummy_node):
|
||||
def setup_in_new_thread():
|
||||
asyncio.ensure_future(dummy_node.setup_crypto_networking())
|
||||
|
||||
return setup_in_new_thread
|
||||
|
||||
|
||||
async def perform_test(num_nodes, adjacency_map, action_func, assertion_func):
|
||||
"""
|
||||
Helper function to allow for easy construction of custom tests for dummy
|
||||
|
@ -26,47 +17,35 @@ async def perform_test(num_nodes, adjacency_map, action_func, assertion_func):
|
|||
:param assertion_func: assertions for testing the results of the actions are correct
|
||||
"""
|
||||
|
||||
# Create nodes
|
||||
dummy_nodes = []
|
||||
for _ in range(num_nodes):
|
||||
dummy_nodes.append(await DummyAccountNode.create())
|
||||
async with DummyAccountNode.create(num_nodes) as dummy_nodes:
|
||||
# Create connections between nodes according to `adjacency_map`
|
||||
async with trio.open_nursery() as nursery:
|
||||
for source_num in adjacency_map:
|
||||
target_nums = adjacency_map[source_num]
|
||||
for target_num in target_nums:
|
||||
nursery.start_soon(
|
||||
connect,
|
||||
dummy_nodes[source_num].host,
|
||||
dummy_nodes[target_num].host,
|
||||
)
|
||||
|
||||
# Create network
|
||||
for source_num in adjacency_map:
|
||||
target_nums = adjacency_map[source_num]
|
||||
for target_num in target_nums:
|
||||
await connect(
|
||||
dummy_nodes[source_num].libp2p_node, dummy_nodes[target_num].libp2p_node
|
||||
)
|
||||
# Allow time for network creation to take place
|
||||
await trio.sleep(0.25)
|
||||
|
||||
# Allow time for network creation to take place
|
||||
await asyncio.sleep(0.25)
|
||||
# Perform action function
|
||||
await action_func(dummy_nodes)
|
||||
|
||||
# Start a thread for each node so that each node can listen and respond
|
||||
# to messages on its own thread, which will avoid waiting indefinitely
|
||||
# on the main thread. On this thread, call the setup func for the node,
|
||||
# which subscribes the node to the CRYPTO_TOPIC topic
|
||||
for dummy_node in dummy_nodes:
|
||||
thread = Thread(target=create_setup_in_new_thread_func(dummy_node))
|
||||
thread.run()
|
||||
# Allow time for action function to be performed (i.e. messages to propogate)
|
||||
await trio.sleep(1)
|
||||
|
||||
# Allow time for nodes to subscribe to CRYPTO_TOPIC topic
|
||||
await asyncio.sleep(0.25)
|
||||
|
||||
# Perform action function
|
||||
await action_func(dummy_nodes)
|
||||
|
||||
# Allow time for action function to be performed (i.e. messages to propogate)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Perform assertion function
|
||||
for dummy_node in dummy_nodes:
|
||||
assertion_func(dummy_node)
|
||||
# Perform assertion function
|
||||
for dummy_node in dummy_nodes:
|
||||
assertion_func(dummy_node)
|
||||
|
||||
# Success, terminate pending tasks.
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.trio
|
||||
async def test_simple_two_nodes():
|
||||
num_nodes = 2
|
||||
adj_map = {0: [1]}
|
||||
|
@ -80,7 +59,7 @@ async def test_simple_two_nodes():
|
|||
await perform_test(num_nodes, adj_map, action_func, assertion_func)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.trio
|
||||
async def test_simple_three_nodes_line_topography():
|
||||
num_nodes = 3
|
||||
adj_map = {0: [1], 1: [2]}
|
||||
|
@ -94,7 +73,7 @@ async def test_simple_three_nodes_line_topography():
|
|||
await perform_test(num_nodes, adj_map, action_func, assertion_func)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.trio
|
||||
async def test_simple_three_nodes_triangle_topography():
|
||||
num_nodes = 3
|
||||
adj_map = {0: [1, 2], 1: [2]}
|
||||
|
@ -108,7 +87,7 @@ async def test_simple_three_nodes_triangle_topography():
|
|||
await perform_test(num_nodes, adj_map, action_func, assertion_func)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.trio
|
||||
async def test_simple_seven_nodes_tree_topography():
|
||||
num_nodes = 7
|
||||
adj_map = {0: [1, 2], 1: [3, 4], 2: [5, 6]}
|
||||
|
@ -122,14 +101,14 @@ async def test_simple_seven_nodes_tree_topography():
|
|||
await perform_test(num_nodes, adj_map, action_func, assertion_func)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.trio
|
||||
async def test_set_then_send_from_root_seven_nodes_tree_topography():
|
||||
num_nodes = 7
|
||||
adj_map = {0: [1, 2], 1: [3, 4], 2: [5, 6]}
|
||||
|
||||
async def action_func(dummy_nodes):
|
||||
await dummy_nodes[0].publish_set_crypto("aspyn", 20)
|
||||
await asyncio.sleep(0.25)
|
||||
await trio.sleep(0.25)
|
||||
await dummy_nodes[0].publish_send_crypto("aspyn", "alex", 5)
|
||||
|
||||
def assertion_func(dummy_node):
|
||||
|
@ -139,14 +118,14 @@ async def test_set_then_send_from_root_seven_nodes_tree_topography():
|
|||
await perform_test(num_nodes, adj_map, action_func, assertion_func)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.trio
|
||||
async def test_set_then_send_from_different_leafs_seven_nodes_tree_topography():
|
||||
num_nodes = 7
|
||||
adj_map = {0: [1, 2], 1: [3, 4], 2: [5, 6]}
|
||||
|
||||
async def action_func(dummy_nodes):
|
||||
await dummy_nodes[6].publish_set_crypto("aspyn", 20)
|
||||
await asyncio.sleep(0.25)
|
||||
await trio.sleep(0.25)
|
||||
await dummy_nodes[4].publish_send_crypto("aspyn", "alex", 5)
|
||||
|
||||
def assertion_func(dummy_node):
|
||||
|
@ -156,7 +135,7 @@ async def test_set_then_send_from_different_leafs_seven_nodes_tree_topography():
|
|||
await perform_test(num_nodes, adj_map, action_func, assertion_func)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.trio
|
||||
async def test_simple_five_nodes_ring_topography():
|
||||
num_nodes = 5
|
||||
adj_map = {0: [1], 1: [2], 2: [3], 3: [4], 4: [0]}
|
||||
|
@ -170,14 +149,14 @@ async def test_simple_five_nodes_ring_topography():
|
|||
await perform_test(num_nodes, adj_map, action_func, assertion_func)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.trio
|
||||
async def test_set_then_send_from_diff_nodes_five_nodes_ring_topography():
|
||||
num_nodes = 5
|
||||
adj_map = {0: [1], 1: [2], 2: [3], 3: [4], 4: [0]}
|
||||
|
||||
async def action_func(dummy_nodes):
|
||||
await dummy_nodes[0].publish_set_crypto("alex", 20)
|
||||
await asyncio.sleep(0.25)
|
||||
await trio.sleep(0.25)
|
||||
await dummy_nodes[3].publish_send_crypto("alex", "rob", 12)
|
||||
|
||||
def assertion_func(dummy_node):
|
||||
|
@ -187,7 +166,7 @@ async def test_set_then_send_from_diff_nodes_five_nodes_ring_topography():
|
|||
await perform_test(num_nodes, adj_map, action_func, assertion_func)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.trio
|
||||
@pytest.mark.slow
|
||||
async def test_set_then_send_from_five_diff_nodes_five_nodes_ring_topography():
|
||||
num_nodes = 5
|
||||
|
@ -195,13 +174,13 @@ async def test_set_then_send_from_five_diff_nodes_five_nodes_ring_topography():
|
|||
|
||||
async def action_func(dummy_nodes):
|
||||
await dummy_nodes[0].publish_set_crypto("alex", 20)
|
||||
await asyncio.sleep(1)
|
||||
await trio.sleep(1)
|
||||
await dummy_nodes[1].publish_send_crypto("alex", "rob", 3)
|
||||
await asyncio.sleep(1)
|
||||
await trio.sleep(1)
|
||||
await dummy_nodes[2].publish_send_crypto("rob", "aspyn", 2)
|
||||
await asyncio.sleep(1)
|
||||
await trio.sleep(1)
|
||||
await dummy_nodes[3].publish_send_crypto("aspyn", "zx", 1)
|
||||
await asyncio.sleep(1)
|
||||
await trio.sleep(1)
|
||||
await dummy_nodes[4].publish_send_crypto("zx", "raul", 1)
|
||||
|
||||
def assertion_func(dummy_node):
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
import asyncio
|
||||
import functools
|
||||
|
||||
import pytest
|
||||
import trio
|
||||
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.tools.factories import FloodsubFactory
|
||||
from libp2p.tools.factories import PubsubFactory
|
||||
from libp2p.tools.pubsub.floodsub_integration_test_settings import (
|
||||
floodsub_protocol_pytest_params,
|
||||
perform_test_from_obj,
|
||||
|
@ -11,79 +12,80 @@ from libp2p.tools.pubsub.floodsub_integration_test_settings import (
|
|||
from libp2p.tools.utils import connect
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_hosts", (2,))
|
||||
@pytest.mark.asyncio
|
||||
async def test_simple_two_nodes(pubsubs_fsub):
|
||||
topic = "my_topic"
|
||||
data = b"some data"
|
||||
@pytest.mark.trio
|
||||
async def test_simple_two_nodes():
|
||||
async with PubsubFactory.create_batch_with_floodsub(2) as pubsubs_fsub:
|
||||
topic = "my_topic"
|
||||
data = b"some data"
|
||||
|
||||
await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host)
|
||||
await asyncio.sleep(0.25)
|
||||
await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host)
|
||||
await trio.sleep(0.25)
|
||||
|
||||
sub_b = await pubsubs_fsub[1].subscribe(topic)
|
||||
# Sleep to let a know of b's subscription
|
||||
await asyncio.sleep(0.25)
|
||||
sub_b = await pubsubs_fsub[1].subscribe(topic)
|
||||
# Sleep to let a know of b's subscription
|
||||
await trio.sleep(0.25)
|
||||
|
||||
await pubsubs_fsub[0].publish(topic, data)
|
||||
await pubsubs_fsub[0].publish(topic, data)
|
||||
|
||||
res_b = await sub_b.get()
|
||||
|
||||
# Check that the msg received by node_b is the same
|
||||
# as the message sent by node_a
|
||||
assert ID(res_b.from_id) == pubsubs_fsub[0].host.get_id()
|
||||
assert res_b.data == data
|
||||
assert res_b.topicIDs == [topic]
|
||||
|
||||
# Success, terminate pending tasks.
|
||||
|
||||
|
||||
# Initialize Pubsub with a cache_size of 4
|
||||
@pytest.mark.parametrize("num_hosts, pubsub_cache_size", ((2, 4),))
|
||||
@pytest.mark.asyncio
|
||||
async def test_lru_cache_two_nodes(pubsubs_fsub, monkeypatch):
|
||||
# two nodes with cache_size of 4
|
||||
# `node_a` send the following messages to node_b
|
||||
message_indices = [1, 1, 2, 1, 3, 1, 4, 1, 5, 1]
|
||||
# `node_b` should only receive the following
|
||||
expected_received_indices = [1, 2, 3, 4, 5, 1]
|
||||
|
||||
topic = "my_topic"
|
||||
|
||||
# Mock `get_msg_id` to make us easier to manipulate `msg_id` by `data`.
|
||||
def get_msg_id(msg):
|
||||
# Originally it is `(msg.seqno, msg.from_id)`
|
||||
return (msg.data, msg.from_id)
|
||||
|
||||
import libp2p.pubsub.pubsub
|
||||
|
||||
monkeypatch.setattr(libp2p.pubsub.pubsub, "get_msg_id", get_msg_id)
|
||||
|
||||
await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host)
|
||||
await asyncio.sleep(0.25)
|
||||
|
||||
sub_b = await pubsubs_fsub[1].subscribe(topic)
|
||||
await asyncio.sleep(0.25)
|
||||
|
||||
def _make_testing_data(i: int) -> bytes:
|
||||
num_int_bytes = 4
|
||||
if i >= 2 ** (num_int_bytes * 8):
|
||||
raise ValueError("integer is too large to be serialized")
|
||||
return b"data" + i.to_bytes(num_int_bytes, "big")
|
||||
|
||||
for index in message_indices:
|
||||
await pubsubs_fsub[0].publish(topic, _make_testing_data(index))
|
||||
await asyncio.sleep(0.25)
|
||||
|
||||
for index in expected_received_indices:
|
||||
res_b = await sub_b.get()
|
||||
assert res_b.data == _make_testing_data(index)
|
||||
assert sub_b.empty()
|
||||
|
||||
# Success, terminate pending tasks.
|
||||
# Check that the msg received by node_b is the same
|
||||
# as the message sent by node_a
|
||||
assert ID(res_b.from_id) == pubsubs_fsub[0].host.get_id()
|
||||
assert res_b.data == data
|
||||
assert res_b.topicIDs == [topic]
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_lru_cache_two_nodes(monkeypatch):
|
||||
# two nodes with cache_size of 4
|
||||
async with PubsubFactory.create_batch_with_floodsub(
|
||||
2, cache_size=4
|
||||
) as pubsubs_fsub:
|
||||
# `node_a` send the following messages to node_b
|
||||
message_indices = [1, 1, 2, 1, 3, 1, 4, 1, 5, 1]
|
||||
# `node_b` should only receive the following
|
||||
expected_received_indices = [1, 2, 3, 4, 5, 1]
|
||||
|
||||
topic = "my_topic"
|
||||
|
||||
# Mock `get_msg_id` to make us easier to manipulate `msg_id` by `data`.
|
||||
def get_msg_id(msg):
|
||||
# Originally it is `(msg.seqno, msg.from_id)`
|
||||
return (msg.data, msg.from_id)
|
||||
|
||||
import libp2p.pubsub.pubsub
|
||||
|
||||
monkeypatch.setattr(libp2p.pubsub.pubsub, "get_msg_id", get_msg_id)
|
||||
|
||||
await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host)
|
||||
await trio.sleep(0.25)
|
||||
|
||||
sub_b = await pubsubs_fsub[1].subscribe(topic)
|
||||
await trio.sleep(0.25)
|
||||
|
||||
def _make_testing_data(i: int) -> bytes:
|
||||
num_int_bytes = 4
|
||||
if i >= 2 ** (num_int_bytes * 8):
|
||||
raise ValueError("integer is too large to be serialized")
|
||||
return b"data" + i.to_bytes(num_int_bytes, "big")
|
||||
|
||||
for index in message_indices:
|
||||
await pubsubs_fsub[0].publish(topic, _make_testing_data(index))
|
||||
await trio.sleep(0.25)
|
||||
|
||||
for index in expected_received_indices:
|
||||
res_b = await sub_b.get()
|
||||
assert res_b.data == _make_testing_data(index)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("test_case_obj", floodsub_protocol_pytest_params)
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.trio
|
||||
@pytest.mark.slow
|
||||
async def test_gossipsub_run_with_floodsub_tests(test_case_obj):
|
||||
await perform_test_from_obj(test_case_obj, FloodsubFactory)
|
||||
async def test_gossipsub_run_with_floodsub_tests(test_case_obj, is_host_secure):
|
||||
await perform_test_from_obj(
|
||||
test_case_obj,
|
||||
functools.partial(
|
||||
PubsubFactory.create_batch_with_floodsub, is_secure=is_host_secure
|
||||
),
|
||||
)
|
||||
|
|
|
@ -1,495 +1,478 @@
|
|||
import asyncio
|
||||
import random
|
||||
|
||||
import pytest
|
||||
import trio
|
||||
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.pubsub.gossipsub import PROTOCOL_ID
|
||||
from libp2p.tools.constants import GOSSIPSUB_PARAMS, GossipsubParams
|
||||
from libp2p.tools.factories import IDFactory, PubsubFactory
|
||||
from libp2p.tools.pubsub.utils import dense_connect, one_to_all_connect
|
||||
from libp2p.tools.utils import connect
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"num_hosts, gossipsub_params",
|
||||
((4, GossipsubParams(degree=4, degree_low=3, degree_high=5)),),
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_join(num_hosts, hosts, pubsubs_gsub):
|
||||
gossipsubs = tuple(pubsub.router for pubsub in pubsubs_gsub)
|
||||
hosts_indices = list(range(num_hosts))
|
||||
@pytest.mark.trio
|
||||
async def test_join():
|
||||
async with PubsubFactory.create_batch_with_gossipsub(
|
||||
4, degree=4, degree_low=3, degree_high=5
|
||||
) as pubsubs_gsub:
|
||||
gossipsubs = [pubsub.router for pubsub in pubsubs_gsub]
|
||||
hosts = [pubsub.host for pubsub in pubsubs_gsub]
|
||||
hosts_indices = list(range(len(pubsubs_gsub)))
|
||||
|
||||
topic = "test_join"
|
||||
central_node_index = 0
|
||||
# Remove index of central host from the indices
|
||||
hosts_indices.remove(central_node_index)
|
||||
num_subscribed_peer = 2
|
||||
subscribed_peer_indices = random.sample(hosts_indices, num_subscribed_peer)
|
||||
topic = "test_join"
|
||||
central_node_index = 0
|
||||
# Remove index of central host from the indices
|
||||
hosts_indices.remove(central_node_index)
|
||||
num_subscribed_peer = 2
|
||||
subscribed_peer_indices = random.sample(hosts_indices, num_subscribed_peer)
|
||||
|
||||
# All pubsub except the one of central node subscribe to topic
|
||||
for i in subscribed_peer_indices:
|
||||
await pubsubs_gsub[i].subscribe(topic)
|
||||
# All pubsub except the one of central node subscribe to topic
|
||||
for i in subscribed_peer_indices:
|
||||
await pubsubs_gsub[i].subscribe(topic)
|
||||
|
||||
# Connect central host to all other hosts
|
||||
await one_to_all_connect(hosts, central_node_index)
|
||||
# Connect central host to all other hosts
|
||||
await one_to_all_connect(hosts, central_node_index)
|
||||
|
||||
# Wait 2 seconds for heartbeat to allow mesh to connect
|
||||
await asyncio.sleep(2)
|
||||
|
||||
# Central node publish to the topic so that this topic
|
||||
# is added to central node's fanout
|
||||
# publish from the randomly chosen host
|
||||
await pubsubs_gsub[central_node_index].publish(topic, b"data")
|
||||
|
||||
# Check that the gossipsub of central node has fanout for the topic
|
||||
assert topic in gossipsubs[central_node_index].fanout
|
||||
# Check that the gossipsub of central node does not have a mesh for the topic
|
||||
assert topic not in gossipsubs[central_node_index].mesh
|
||||
|
||||
# Central node subscribes the topic
|
||||
await pubsubs_gsub[central_node_index].subscribe(topic)
|
||||
|
||||
await asyncio.sleep(2)
|
||||
|
||||
# Check that the gossipsub of central node no longer has fanout for the topic
|
||||
assert topic not in gossipsubs[central_node_index].fanout
|
||||
|
||||
for i in hosts_indices:
|
||||
if i in subscribed_peer_indices:
|
||||
assert hosts[i].get_id() in gossipsubs[central_node_index].mesh[topic]
|
||||
assert hosts[central_node_index].get_id() in gossipsubs[i].mesh[topic]
|
||||
else:
|
||||
assert hosts[i].get_id() not in gossipsubs[central_node_index].mesh[topic]
|
||||
assert topic not in gossipsubs[i].mesh
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_hosts", (1,))
|
||||
@pytest.mark.asyncio
|
||||
async def test_leave(pubsubs_gsub):
|
||||
gossipsub = pubsubs_gsub[0].router
|
||||
topic = "test_leave"
|
||||
|
||||
assert topic not in gossipsub.mesh
|
||||
|
||||
await gossipsub.join(topic)
|
||||
assert topic in gossipsub.mesh
|
||||
|
||||
await gossipsub.leave(topic)
|
||||
assert topic not in gossipsub.mesh
|
||||
|
||||
# Test re-leave
|
||||
await gossipsub.leave(topic)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_hosts", (2,))
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_graft(pubsubs_gsub, hosts, event_loop, monkeypatch):
|
||||
gossipsubs = tuple(pubsub.router for pubsub in pubsubs_gsub)
|
||||
|
||||
index_alice = 0
|
||||
id_alice = hosts[index_alice].get_id()
|
||||
index_bob = 1
|
||||
id_bob = hosts[index_bob].get_id()
|
||||
await connect(hosts[index_alice], hosts[index_bob])
|
||||
|
||||
# Wait 2 seconds for heartbeat to allow mesh to connect
|
||||
await asyncio.sleep(2)
|
||||
|
||||
topic = "test_handle_graft"
|
||||
# Only lice subscribe to the topic
|
||||
await gossipsubs[index_alice].join(topic)
|
||||
|
||||
# Monkey patch bob's `emit_prune` function so we can
|
||||
# check if it is called in `handle_graft`
|
||||
event_emit_prune = asyncio.Event()
|
||||
|
||||
async def emit_prune(topic, sender_peer_id):
|
||||
event_emit_prune.set()
|
||||
|
||||
monkeypatch.setattr(gossipsubs[index_bob], "emit_prune", emit_prune)
|
||||
|
||||
# Check that alice is bob's peer but not his mesh peer
|
||||
assert gossipsubs[index_bob].peer_protocol[id_alice] == PROTOCOL_ID
|
||||
assert topic not in gossipsubs[index_bob].mesh
|
||||
|
||||
await gossipsubs[index_alice].emit_graft(topic, id_bob)
|
||||
|
||||
# Check that `emit_prune` is called
|
||||
await asyncio.wait_for(event_emit_prune.wait(), timeout=1, loop=event_loop)
|
||||
assert event_emit_prune.is_set()
|
||||
|
||||
# Check that bob is alice's peer but not her mesh peer
|
||||
assert topic in gossipsubs[index_alice].mesh
|
||||
assert id_bob not in gossipsubs[index_alice].mesh[topic]
|
||||
assert gossipsubs[index_alice].peer_protocol[id_bob] == PROTOCOL_ID
|
||||
|
||||
await gossipsubs[index_bob].emit_graft(topic, id_alice)
|
||||
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Check that bob is now alice's mesh peer
|
||||
assert id_bob in gossipsubs[index_alice].mesh[topic]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"num_hosts, gossipsub_params", ((2, GossipsubParams(heartbeat_interval=3)),)
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_prune(pubsubs_gsub, hosts):
|
||||
gossipsubs = tuple(pubsub.router for pubsub in pubsubs_gsub)
|
||||
|
||||
index_alice = 0
|
||||
id_alice = hosts[index_alice].get_id()
|
||||
index_bob = 1
|
||||
id_bob = hosts[index_bob].get_id()
|
||||
|
||||
topic = "test_handle_prune"
|
||||
for pubsub in pubsubs_gsub:
|
||||
await pubsub.subscribe(topic)
|
||||
|
||||
await connect(hosts[index_alice], hosts[index_bob])
|
||||
|
||||
# Wait for heartbeat to allow mesh to connect
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Check that they are each other's mesh peer
|
||||
assert id_alice in gossipsubs[index_bob].mesh[topic]
|
||||
assert id_bob in gossipsubs[index_alice].mesh[topic]
|
||||
|
||||
# alice emit prune message to bob, alice should be removed
|
||||
# from bob's mesh peer
|
||||
await gossipsubs[index_alice].emit_prune(topic, id_bob)
|
||||
# `emit_prune` does not remove bob from alice's mesh peers
|
||||
assert id_bob in gossipsubs[index_alice].mesh[topic]
|
||||
|
||||
# NOTE: We increase `heartbeat_interval` to 3 seconds so that bob will not
|
||||
# add alice back to his mesh after heartbeat.
|
||||
# Wait for bob to `handle_prune`
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Check that alice is no longer bob's mesh peer
|
||||
assert id_alice not in gossipsubs[index_bob].mesh[topic]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_hosts", (10,))
|
||||
@pytest.mark.asyncio
|
||||
async def test_dense(num_hosts, pubsubs_gsub, hosts):
|
||||
num_msgs = 5
|
||||
|
||||
# All pubsub subscribe to foobar
|
||||
queues = []
|
||||
for pubsub in pubsubs_gsub:
|
||||
q = await pubsub.subscribe("foobar")
|
||||
|
||||
# Add each blocking queue to an array of blocking queues
|
||||
queues.append(q)
|
||||
|
||||
# Densely connect libp2p hosts in a random way
|
||||
await dense_connect(hosts)
|
||||
|
||||
# Wait 2 seconds for heartbeat to allow mesh to connect
|
||||
await asyncio.sleep(2)
|
||||
|
||||
for i in range(num_msgs):
|
||||
msg_content = b"foo " + i.to_bytes(1, "big")
|
||||
|
||||
# randomly pick a message origin
|
||||
origin_idx = random.randint(0, num_hosts - 1)
|
||||
# Wait 2 seconds for heartbeat to allow mesh to connect
|
||||
await trio.sleep(2)
|
||||
|
||||
# Central node publish to the topic so that this topic
|
||||
# is added to central node's fanout
|
||||
# publish from the randomly chosen host
|
||||
await pubsubs_gsub[origin_idx].publish("foobar", msg_content)
|
||||
await pubsubs_gsub[central_node_index].publish(topic, b"data")
|
||||
|
||||
await asyncio.sleep(0.5)
|
||||
# Assert that all blocking queues receive the message
|
||||
for queue in queues:
|
||||
msg = await queue.get()
|
||||
assert msg.data == msg_content
|
||||
# Check that the gossipsub of central node has fanout for the topic
|
||||
assert topic in gossipsubs[central_node_index].fanout
|
||||
# Check that the gossipsub of central node does not have a mesh for the topic
|
||||
assert topic not in gossipsubs[central_node_index].mesh
|
||||
|
||||
# Central node subscribes the topic
|
||||
await pubsubs_gsub[central_node_index].subscribe(topic)
|
||||
|
||||
await trio.sleep(2)
|
||||
|
||||
# Check that the gossipsub of central node no longer has fanout for the topic
|
||||
assert topic not in gossipsubs[central_node_index].fanout
|
||||
|
||||
for i in hosts_indices:
|
||||
if i in subscribed_peer_indices:
|
||||
assert hosts[i].get_id() in gossipsubs[central_node_index].mesh[topic]
|
||||
assert hosts[central_node_index].get_id() in gossipsubs[i].mesh[topic]
|
||||
else:
|
||||
assert (
|
||||
hosts[i].get_id() not in gossipsubs[central_node_index].mesh[topic]
|
||||
)
|
||||
assert topic not in gossipsubs[i].mesh
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_hosts", (10,))
|
||||
@pytest.mark.asyncio
|
||||
async def test_fanout(hosts, pubsubs_gsub):
|
||||
num_msgs = 5
|
||||
@pytest.mark.trio
|
||||
async def test_leave():
|
||||
async with PubsubFactory.create_batch_with_gossipsub(1) as pubsubs_gsub:
|
||||
gossipsub = pubsubs_gsub[0].router
|
||||
topic = "test_leave"
|
||||
|
||||
# All pubsub subscribe to foobar except for `pubsubs_gsub[0]`
|
||||
queues = []
|
||||
for i in range(1, len(pubsubs_gsub)):
|
||||
q = await pubsubs_gsub[i].subscribe("foobar")
|
||||
assert topic not in gossipsub.mesh
|
||||
|
||||
# Add each blocking queue to an array of blocking queues
|
||||
queues.append(q)
|
||||
await gossipsub.join(topic)
|
||||
assert topic in gossipsub.mesh
|
||||
|
||||
# Sparsely connect libp2p hosts in random way
|
||||
await dense_connect(hosts)
|
||||
await gossipsub.leave(topic)
|
||||
assert topic not in gossipsub.mesh
|
||||
|
||||
# Wait 2 seconds for heartbeat to allow mesh to connect
|
||||
await asyncio.sleep(2)
|
||||
|
||||
topic = "foobar"
|
||||
# Send messages with origin not subscribed
|
||||
for i in range(num_msgs):
|
||||
msg_content = b"foo " + i.to_bytes(1, "big")
|
||||
|
||||
# Pick the message origin to the node that is not subscribed to 'foobar'
|
||||
origin_idx = 0
|
||||
|
||||
# publish from the randomly chosen host
|
||||
await pubsubs_gsub[origin_idx].publish(topic, msg_content)
|
||||
|
||||
await asyncio.sleep(0.5)
|
||||
# Assert that all blocking queues receive the message
|
||||
for queue in queues:
|
||||
msg = await queue.get()
|
||||
assert msg.data == msg_content
|
||||
|
||||
# Subscribe message origin
|
||||
queues.insert(0, await pubsubs_gsub[0].subscribe(topic))
|
||||
|
||||
# Send messages again
|
||||
for i in range(num_msgs):
|
||||
msg_content = b"bar " + i.to_bytes(1, "big")
|
||||
|
||||
# Pick the message origin to the node that is not subscribed to 'foobar'
|
||||
origin_idx = 0
|
||||
|
||||
# publish from the randomly chosen host
|
||||
await pubsubs_gsub[origin_idx].publish(topic, msg_content)
|
||||
|
||||
await asyncio.sleep(0.5)
|
||||
# Assert that all blocking queues receive the message
|
||||
for queue in queues:
|
||||
msg = await queue.get()
|
||||
assert msg.data == msg_content
|
||||
# Test re-leave
|
||||
await gossipsub.leave(topic)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_hosts", (10,))
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.trio
|
||||
async def test_handle_graft(monkeypatch):
|
||||
async with PubsubFactory.create_batch_with_gossipsub(2) as pubsubs_gsub:
|
||||
gossipsubs = tuple(pubsub.router for pubsub in pubsubs_gsub)
|
||||
|
||||
index_alice = 0
|
||||
id_alice = pubsubs_gsub[index_alice].my_id
|
||||
index_bob = 1
|
||||
id_bob = pubsubs_gsub[index_bob].my_id
|
||||
await connect(pubsubs_gsub[index_alice].host, pubsubs_gsub[index_bob].host)
|
||||
|
||||
# Wait 2 seconds for heartbeat to allow mesh to connect
|
||||
await trio.sleep(2)
|
||||
|
||||
topic = "test_handle_graft"
|
||||
# Only lice subscribe to the topic
|
||||
await gossipsubs[index_alice].join(topic)
|
||||
|
||||
# Monkey patch bob's `emit_prune` function so we can
|
||||
# check if it is called in `handle_graft`
|
||||
event_emit_prune = trio.Event()
|
||||
|
||||
async def emit_prune(topic, sender_peer_id):
|
||||
event_emit_prune.set()
|
||||
await trio.hazmat.checkpoint()
|
||||
|
||||
monkeypatch.setattr(gossipsubs[index_bob], "emit_prune", emit_prune)
|
||||
|
||||
# Check that alice is bob's peer but not his mesh peer
|
||||
assert gossipsubs[index_bob].peer_protocol[id_alice] == PROTOCOL_ID
|
||||
assert topic not in gossipsubs[index_bob].mesh
|
||||
|
||||
await gossipsubs[index_alice].emit_graft(topic, id_bob)
|
||||
|
||||
# Check that `emit_prune` is called
|
||||
await event_emit_prune.wait()
|
||||
|
||||
# Check that bob is alice's peer but not her mesh peer
|
||||
assert topic in gossipsubs[index_alice].mesh
|
||||
assert id_bob not in gossipsubs[index_alice].mesh[topic]
|
||||
assert gossipsubs[index_alice].peer_protocol[id_bob] == PROTOCOL_ID
|
||||
|
||||
await gossipsubs[index_bob].emit_graft(topic, id_alice)
|
||||
|
||||
await trio.sleep(1)
|
||||
|
||||
# Check that bob is now alice's mesh peer
|
||||
assert id_bob in gossipsubs[index_alice].mesh[topic]
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_handle_prune():
|
||||
async with PubsubFactory.create_batch_with_gossipsub(
|
||||
2, heartbeat_interval=3
|
||||
) as pubsubs_gsub:
|
||||
gossipsubs = tuple(pubsub.router for pubsub in pubsubs_gsub)
|
||||
|
||||
index_alice = 0
|
||||
id_alice = pubsubs_gsub[index_alice].my_id
|
||||
index_bob = 1
|
||||
id_bob = pubsubs_gsub[index_bob].my_id
|
||||
|
||||
topic = "test_handle_prune"
|
||||
for pubsub in pubsubs_gsub:
|
||||
await pubsub.subscribe(topic)
|
||||
|
||||
await connect(pubsubs_gsub[index_alice].host, pubsubs_gsub[index_bob].host)
|
||||
|
||||
# Wait for heartbeat to allow mesh to connect
|
||||
await trio.sleep(1)
|
||||
|
||||
# Check that they are each other's mesh peer
|
||||
assert id_alice in gossipsubs[index_bob].mesh[topic]
|
||||
assert id_bob in gossipsubs[index_alice].mesh[topic]
|
||||
|
||||
# alice emit prune message to bob, alice should be removed
|
||||
# from bob's mesh peer
|
||||
await gossipsubs[index_alice].emit_prune(topic, id_bob)
|
||||
# `emit_prune` does not remove bob from alice's mesh peers
|
||||
assert id_bob in gossipsubs[index_alice].mesh[topic]
|
||||
|
||||
# NOTE: We increase `heartbeat_interval` to 3 seconds so that bob will not
|
||||
# add alice back to his mesh after heartbeat.
|
||||
# Wait for bob to `handle_prune`
|
||||
await trio.sleep(0.1)
|
||||
|
||||
# Check that alice is no longer bob's mesh peer
|
||||
assert id_alice not in gossipsubs[index_bob].mesh[topic]
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_dense():
|
||||
async with PubsubFactory.create_batch_with_gossipsub(10) as pubsubs_gsub:
|
||||
hosts = [pubsub.host for pubsub in pubsubs_gsub]
|
||||
num_msgs = 5
|
||||
|
||||
# All pubsub subscribe to foobar
|
||||
queues = [await pubsub.subscribe("foobar") for pubsub in pubsubs_gsub]
|
||||
|
||||
# Densely connect libp2p hosts in a random way
|
||||
await dense_connect(hosts)
|
||||
|
||||
# Wait 2 seconds for heartbeat to allow mesh to connect
|
||||
await trio.sleep(2)
|
||||
|
||||
for i in range(num_msgs):
|
||||
msg_content = b"foo " + i.to_bytes(1, "big")
|
||||
|
||||
# randomly pick a message origin
|
||||
origin_idx = random.randint(0, len(hosts) - 1)
|
||||
|
||||
# publish from the randomly chosen host
|
||||
await pubsubs_gsub[origin_idx].publish("foobar", msg_content)
|
||||
|
||||
await trio.sleep(0.5)
|
||||
# Assert that all blocking queues receive the message
|
||||
for queue in queues:
|
||||
msg = await queue.get()
|
||||
assert msg.data == msg_content
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_fanout():
|
||||
async with PubsubFactory.create_batch_with_gossipsub(10) as pubsubs_gsub:
|
||||
hosts = [pubsub.host for pubsub in pubsubs_gsub]
|
||||
num_msgs = 5
|
||||
|
||||
# All pubsub subscribe to foobar except for `pubsubs_gsub[0]`
|
||||
subs = [await pubsub.subscribe("foobar") for pubsub in pubsubs_gsub[1:]]
|
||||
|
||||
# Sparsely connect libp2p hosts in random way
|
||||
await dense_connect(hosts)
|
||||
|
||||
# Wait 2 seconds for heartbeat to allow mesh to connect
|
||||
await trio.sleep(2)
|
||||
|
||||
topic = "foobar"
|
||||
# Send messages with origin not subscribed
|
||||
for i in range(num_msgs):
|
||||
msg_content = b"foo " + i.to_bytes(1, "big")
|
||||
|
||||
# Pick the message origin to the node that is not subscribed to 'foobar'
|
||||
origin_idx = 0
|
||||
|
||||
# publish from the randomly chosen host
|
||||
await pubsubs_gsub[origin_idx].publish(topic, msg_content)
|
||||
|
||||
await trio.sleep(0.5)
|
||||
# Assert that all blocking queues receive the message
|
||||
for sub in subs:
|
||||
msg = await sub.get()
|
||||
assert msg.data == msg_content
|
||||
|
||||
# Subscribe message origin
|
||||
subs.insert(0, await pubsubs_gsub[0].subscribe(topic))
|
||||
|
||||
# Send messages again
|
||||
for i in range(num_msgs):
|
||||
msg_content = b"bar " + i.to_bytes(1, "big")
|
||||
|
||||
# Pick the message origin to the node that is not subscribed to 'foobar'
|
||||
origin_idx = 0
|
||||
|
||||
# publish from the randomly chosen host
|
||||
await pubsubs_gsub[origin_idx].publish(topic, msg_content)
|
||||
|
||||
await trio.sleep(0.5)
|
||||
# Assert that all blocking queues receive the message
|
||||
for sub in subs:
|
||||
msg = await sub.get()
|
||||
assert msg.data == msg_content
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
@pytest.mark.slow
|
||||
async def test_fanout_maintenance(hosts, pubsubs_gsub):
|
||||
num_msgs = 5
|
||||
async def test_fanout_maintenance():
|
||||
async with PubsubFactory.create_batch_with_gossipsub(10) as pubsubs_gsub:
|
||||
hosts = [pubsub.host for pubsub in pubsubs_gsub]
|
||||
num_msgs = 5
|
||||
|
||||
# All pubsub subscribe to foobar
|
||||
queues = []
|
||||
topic = "foobar"
|
||||
for i in range(1, len(pubsubs_gsub)):
|
||||
q = await pubsubs_gsub[i].subscribe(topic)
|
||||
# All pubsub subscribe to foobar
|
||||
queues = []
|
||||
topic = "foobar"
|
||||
for i in range(1, len(pubsubs_gsub)):
|
||||
q = await pubsubs_gsub[i].subscribe(topic)
|
||||
|
||||
# Add each blocking queue to an array of blocking queues
|
||||
queues.append(q)
|
||||
# Add each blocking queue to an array of blocking queues
|
||||
queues.append(q)
|
||||
|
||||
# Sparsely connect libp2p hosts in random way
|
||||
await dense_connect(hosts)
|
||||
# Sparsely connect libp2p hosts in random way
|
||||
await dense_connect(hosts)
|
||||
|
||||
# Wait 2 seconds for heartbeat to allow mesh to connect
|
||||
await asyncio.sleep(2)
|
||||
# Wait 2 seconds for heartbeat to allow mesh to connect
|
||||
await trio.sleep(2)
|
||||
|
||||
# Send messages with origin not subscribed
|
||||
for i in range(num_msgs):
|
||||
msg_content = b"foo " + i.to_bytes(1, "big")
|
||||
# Send messages with origin not subscribed
|
||||
for i in range(num_msgs):
|
||||
msg_content = b"foo " + i.to_bytes(1, "big")
|
||||
|
||||
# Pick the message origin to the node that is not subscribed to 'foobar'
|
||||
origin_idx = 0
|
||||
# Pick the message origin to the node that is not subscribed to 'foobar'
|
||||
origin_idx = 0
|
||||
|
||||
# publish from the randomly chosen host
|
||||
await pubsubs_gsub[origin_idx].publish(topic, msg_content)
|
||||
|
||||
await trio.sleep(0.5)
|
||||
# Assert that all blocking queues receive the message
|
||||
for queue in queues:
|
||||
msg = await queue.get()
|
||||
assert msg.data == msg_content
|
||||
|
||||
for sub in pubsubs_gsub:
|
||||
await sub.unsubscribe(topic)
|
||||
|
||||
queues = []
|
||||
|
||||
await trio.sleep(2)
|
||||
|
||||
# Resub and repeat
|
||||
for i in range(1, len(pubsubs_gsub)):
|
||||
q = await pubsubs_gsub[i].subscribe(topic)
|
||||
|
||||
# Add each blocking queue to an array of blocking queues
|
||||
queues.append(q)
|
||||
|
||||
await trio.sleep(2)
|
||||
|
||||
# Check messages can still be sent
|
||||
for i in range(num_msgs):
|
||||
msg_content = b"bar " + i.to_bytes(1, "big")
|
||||
|
||||
# Pick the message origin to the node that is not subscribed to 'foobar'
|
||||
origin_idx = 0
|
||||
|
||||
# publish from the randomly chosen host
|
||||
await pubsubs_gsub[origin_idx].publish(topic, msg_content)
|
||||
|
||||
await trio.sleep(0.5)
|
||||
# Assert that all blocking queues receive the message
|
||||
for queue in queues:
|
||||
msg = await queue.get()
|
||||
assert msg.data == msg_content
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_gossip_propagation():
|
||||
async with PubsubFactory.create_batch_with_gossipsub(
|
||||
2, degree=1, degree_low=0, degree_high=2, gossip_window=50, gossip_history=100
|
||||
) as pubsubs_gsub:
|
||||
topic = "foo"
|
||||
queue_0 = await pubsubs_gsub[0].subscribe(topic)
|
||||
|
||||
# node 0 publish to topic
|
||||
msg_content = b"foo_msg"
|
||||
|
||||
# publish from the randomly chosen host
|
||||
await pubsubs_gsub[origin_idx].publish(topic, msg_content)
|
||||
await pubsubs_gsub[0].publish(topic, msg_content)
|
||||
|
||||
await asyncio.sleep(0.5)
|
||||
# Assert that all blocking queues receive the message
|
||||
for queue in queues:
|
||||
msg = await queue.get()
|
||||
assert msg.data == msg_content
|
||||
|
||||
for sub in pubsubs_gsub:
|
||||
await sub.unsubscribe(topic)
|
||||
|
||||
queues = []
|
||||
|
||||
await asyncio.sleep(2)
|
||||
|
||||
# Resub and repeat
|
||||
for i in range(1, len(pubsubs_gsub)):
|
||||
q = await pubsubs_gsub[i].subscribe(topic)
|
||||
|
||||
# Add each blocking queue to an array of blocking queues
|
||||
queues.append(q)
|
||||
|
||||
await asyncio.sleep(2)
|
||||
|
||||
# Check messages can still be sent
|
||||
for i in range(num_msgs):
|
||||
msg_content = b"bar " + i.to_bytes(1, "big")
|
||||
|
||||
# Pick the message origin to the node that is not subscribed to 'foobar'
|
||||
origin_idx = 0
|
||||
|
||||
# publish from the randomly chosen host
|
||||
await pubsubs_gsub[origin_idx].publish(topic, msg_content)
|
||||
|
||||
await asyncio.sleep(0.5)
|
||||
# Assert that all blocking queues receive the message
|
||||
for queue in queues:
|
||||
msg = await queue.get()
|
||||
assert msg.data == msg_content
|
||||
await trio.sleep(0.5)
|
||||
# Assert that the blocking queues receive the message
|
||||
msg = await queue_0.get()
|
||||
assert msg.data == msg_content
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"num_hosts, gossipsub_params",
|
||||
(
|
||||
(
|
||||
2,
|
||||
GossipsubParams(
|
||||
degree=1,
|
||||
degree_low=0,
|
||||
degree_high=2,
|
||||
gossip_window=50,
|
||||
gossip_history=100,
|
||||
),
|
||||
),
|
||||
),
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_gossip_propagation(hosts, pubsubs_gsub):
|
||||
topic = "foo"
|
||||
await pubsubs_gsub[0].subscribe(topic)
|
||||
|
||||
# node 0 publish to topic
|
||||
msg_content = b"foo_msg"
|
||||
|
||||
# publish from the randomly chosen host
|
||||
await pubsubs_gsub[0].publish(topic, msg_content)
|
||||
|
||||
# now node 1 subscribes
|
||||
queue_1 = await pubsubs_gsub[1].subscribe(topic)
|
||||
|
||||
await connect(hosts[0], hosts[1])
|
||||
|
||||
# wait for gossip heartbeat
|
||||
await asyncio.sleep(2)
|
||||
|
||||
# should be able to read message
|
||||
msg = await queue_1.get()
|
||||
assert msg.data == msg_content
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"num_hosts, gossipsub_params", ((1, GossipsubParams(heartbeat_initial_delay=100)),)
|
||||
)
|
||||
@pytest.mark.parametrize("initial_mesh_peer_count", (7, 10, 13))
|
||||
@pytest.mark.asyncio
|
||||
async def test_mesh_heartbeat(
|
||||
num_hosts, initial_mesh_peer_count, pubsubs_gsub, hosts, monkeypatch
|
||||
):
|
||||
# It's difficult to set up the initial peer subscription condition.
|
||||
# Ideally I would like to have initial mesh peer count that's below ``GossipSubDegree``
|
||||
# so I can test if `mesh_heartbeat` return correct peers to GRAFT.
|
||||
# The problem is that I can not set it up so that we have peers subscribe to the topic
|
||||
# but not being part of our mesh peers (as these peers are the peers to GRAFT).
|
||||
# So I monkeypatch the peer subscriptions and our mesh peers.
|
||||
total_peer_count = 14
|
||||
topic = "TEST_MESH_HEARTBEAT"
|
||||
@pytest.mark.trio
|
||||
async def test_mesh_heartbeat(initial_mesh_peer_count, monkeypatch):
|
||||
async with PubsubFactory.create_batch_with_gossipsub(
|
||||
1, heartbeat_initial_delay=100
|
||||
) as pubsubs_gsub:
|
||||
# It's difficult to set up the initial peer subscription condition.
|
||||
# Ideally I would like to have initial mesh peer count that's below ``GossipSubDegree``
|
||||
# so I can test if `mesh_heartbeat` return correct peers to GRAFT.
|
||||
# The problem is that I can not set it up so that we have peers subscribe to the topic
|
||||
# but not being part of our mesh peers (as these peers are the peers to GRAFT).
|
||||
# So I monkeypatch the peer subscriptions and our mesh peers.
|
||||
total_peer_count = 14
|
||||
topic = "TEST_MESH_HEARTBEAT"
|
||||
|
||||
fake_peer_ids = [
|
||||
ID((i).to_bytes(2, byteorder="big")) for i in range(total_peer_count)
|
||||
]
|
||||
peer_protocol = {peer_id: PROTOCOL_ID for peer_id in fake_peer_ids}
|
||||
monkeypatch.setattr(pubsubs_gsub[0].router, "peer_protocol", peer_protocol)
|
||||
fake_peer_ids = [IDFactory() for _ in range(total_peer_count)]
|
||||
peer_protocol = {peer_id: PROTOCOL_ID for peer_id in fake_peer_ids}
|
||||
monkeypatch.setattr(pubsubs_gsub[0].router, "peer_protocol", peer_protocol)
|
||||
|
||||
peer_topics = {topic: set(fake_peer_ids)}
|
||||
# Monkeypatch the peer subscriptions
|
||||
monkeypatch.setattr(pubsubs_gsub[0], "peer_topics", peer_topics)
|
||||
peer_topics = {topic: set(fake_peer_ids)}
|
||||
# Monkeypatch the peer subscriptions
|
||||
monkeypatch.setattr(pubsubs_gsub[0], "peer_topics", peer_topics)
|
||||
|
||||
mesh_peer_indices = random.sample(range(total_peer_count), initial_mesh_peer_count)
|
||||
mesh_peers = [fake_peer_ids[i] for i in mesh_peer_indices]
|
||||
router_mesh = {topic: set(mesh_peers)}
|
||||
# Monkeypatch our mesh peers
|
||||
monkeypatch.setattr(pubsubs_gsub[0].router, "mesh", router_mesh)
|
||||
mesh_peer_indices = random.sample(
|
||||
range(total_peer_count), initial_mesh_peer_count
|
||||
)
|
||||
mesh_peers = [fake_peer_ids[i] for i in mesh_peer_indices]
|
||||
router_mesh = {topic: set(mesh_peers)}
|
||||
# Monkeypatch our mesh peers
|
||||
monkeypatch.setattr(pubsubs_gsub[0].router, "mesh", router_mesh)
|
||||
|
||||
peers_to_graft, peers_to_prune = pubsubs_gsub[0].router.mesh_heartbeat()
|
||||
if initial_mesh_peer_count > GOSSIPSUB_PARAMS.degree:
|
||||
# If number of initial mesh peers is more than `GossipSubDegree`, we should PRUNE mesh peers
|
||||
assert len(peers_to_graft) == 0
|
||||
assert len(peers_to_prune) == initial_mesh_peer_count - GOSSIPSUB_PARAMS.degree
|
||||
for peer in peers_to_prune:
|
||||
assert peer in mesh_peers
|
||||
elif initial_mesh_peer_count < GOSSIPSUB_PARAMS.degree:
|
||||
# If number of initial mesh peers is less than `GossipSubDegree`, we should GRAFT more peers
|
||||
assert len(peers_to_prune) == 0
|
||||
assert len(peers_to_graft) == GOSSIPSUB_PARAMS.degree - initial_mesh_peer_count
|
||||
for peer in peers_to_graft:
|
||||
assert peer not in mesh_peers
|
||||
else:
|
||||
assert len(peers_to_prune) == 0 and len(peers_to_graft) == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"num_hosts, gossipsub_params", ((1, GossipsubParams(heartbeat_initial_delay=100)),)
|
||||
)
|
||||
@pytest.mark.parametrize("initial_peer_count", (1, 4, 7))
|
||||
@pytest.mark.asyncio
|
||||
async def test_gossip_heartbeat(
|
||||
num_hosts, initial_peer_count, pubsubs_gsub, hosts, monkeypatch
|
||||
):
|
||||
# The problem is that I can not set it up so that we have peers subscribe to the topic
|
||||
# but not being part of our mesh peers (as these peers are the peers to GRAFT).
|
||||
# So I monkeypatch the peer subscriptions and our mesh peers.
|
||||
total_peer_count = 28
|
||||
topic_mesh = "TEST_GOSSIP_HEARTBEAT_1"
|
||||
topic_fanout = "TEST_GOSSIP_HEARTBEAT_2"
|
||||
|
||||
fake_peer_ids = [
|
||||
ID((i).to_bytes(2, byteorder="big")) for i in range(total_peer_count)
|
||||
]
|
||||
peer_protocol = {peer_id: PROTOCOL_ID for peer_id in fake_peer_ids}
|
||||
monkeypatch.setattr(pubsubs_gsub[0].router, "peer_protocol", peer_protocol)
|
||||
|
||||
topic_mesh_peer_count = 14
|
||||
# Split into mesh peers and fanout peers
|
||||
peer_topics = {
|
||||
topic_mesh: set(fake_peer_ids[:topic_mesh_peer_count]),
|
||||
topic_fanout: set(fake_peer_ids[topic_mesh_peer_count:]),
|
||||
}
|
||||
# Monkeypatch the peer subscriptions
|
||||
monkeypatch.setattr(pubsubs_gsub[0], "peer_topics", peer_topics)
|
||||
|
||||
mesh_peer_indices = random.sample(range(topic_mesh_peer_count), initial_peer_count)
|
||||
mesh_peers = [fake_peer_ids[i] for i in mesh_peer_indices]
|
||||
router_mesh = {topic_mesh: set(mesh_peers)}
|
||||
# Monkeypatch our mesh peers
|
||||
monkeypatch.setattr(pubsubs_gsub[0].router, "mesh", router_mesh)
|
||||
fanout_peer_indices = random.sample(
|
||||
range(topic_mesh_peer_count, total_peer_count), initial_peer_count
|
||||
)
|
||||
fanout_peers = [fake_peer_ids[i] for i in fanout_peer_indices]
|
||||
router_fanout = {topic_fanout: set(fanout_peers)}
|
||||
# Monkeypatch our fanout peers
|
||||
monkeypatch.setattr(pubsubs_gsub[0].router, "fanout", router_fanout)
|
||||
|
||||
def window(topic):
|
||||
if topic == topic_mesh:
|
||||
return [topic_mesh]
|
||||
elif topic == topic_fanout:
|
||||
return [topic_fanout]
|
||||
peers_to_graft, peers_to_prune = pubsubs_gsub[0].router.mesh_heartbeat()
|
||||
if initial_mesh_peer_count > pubsubs_gsub[0].router.degree:
|
||||
# If number of initial mesh peers is more than `GossipSubDegree`,
|
||||
# we should PRUNE mesh peers
|
||||
assert len(peers_to_graft) == 0
|
||||
assert (
|
||||
len(peers_to_prune)
|
||||
== initial_mesh_peer_count - pubsubs_gsub[0].router.degree
|
||||
)
|
||||
for peer in peers_to_prune:
|
||||
assert peer in mesh_peers
|
||||
elif initial_mesh_peer_count < pubsubs_gsub[0].router.degree:
|
||||
# If number of initial mesh peers is less than `GossipSubDegree`,
|
||||
# we should GRAFT more peers
|
||||
assert len(peers_to_prune) == 0
|
||||
assert (
|
||||
len(peers_to_graft)
|
||||
== pubsubs_gsub[0].router.degree - initial_mesh_peer_count
|
||||
)
|
||||
for peer in peers_to_graft:
|
||||
assert peer not in mesh_peers
|
||||
else:
|
||||
return []
|
||||
assert len(peers_to_prune) == 0 and len(peers_to_graft) == 0
|
||||
|
||||
# Monkeypatch the memory cache messages
|
||||
monkeypatch.setattr(pubsubs_gsub[0].router.mcache, "window", window)
|
||||
|
||||
peers_to_gossip = pubsubs_gsub[0].router.gossip_heartbeat()
|
||||
# If our mesh peer count is less than `GossipSubDegree`, we should gossip to up to
|
||||
# `GossipSubDegree` peers (exclude mesh peers).
|
||||
if topic_mesh_peer_count - initial_peer_count < GOSSIPSUB_PARAMS.degree:
|
||||
# The same goes for fanout so it's two times the number of peers to gossip.
|
||||
assert len(peers_to_gossip) == 2 * (topic_mesh_peer_count - initial_peer_count)
|
||||
elif topic_mesh_peer_count - initial_peer_count >= GOSSIPSUB_PARAMS.degree:
|
||||
assert len(peers_to_gossip) == 2 * (GOSSIPSUB_PARAMS.degree)
|
||||
@pytest.mark.parametrize("initial_peer_count", (1, 4, 7))
|
||||
@pytest.mark.trio
|
||||
async def test_gossip_heartbeat(initial_peer_count, monkeypatch):
|
||||
async with PubsubFactory.create_batch_with_gossipsub(
|
||||
1, heartbeat_initial_delay=100
|
||||
) as pubsubs_gsub:
|
||||
# The problem is that I can not set it up so that we have peers subscribe to the topic
|
||||
# but not being part of our mesh peers (as these peers are the peers to GRAFT).
|
||||
# So I monkeypatch the peer subscriptions and our mesh peers.
|
||||
total_peer_count = 28
|
||||
topic_mesh = "TEST_GOSSIP_HEARTBEAT_1"
|
||||
topic_fanout = "TEST_GOSSIP_HEARTBEAT_2"
|
||||
|
||||
for peer in peers_to_gossip:
|
||||
if peer in peer_topics[topic_mesh]:
|
||||
# Check that the peer to gossip to is not in our mesh peers
|
||||
assert peer not in mesh_peers
|
||||
assert topic_mesh in peers_to_gossip[peer]
|
||||
elif peer in peer_topics[topic_fanout]:
|
||||
# Check that the peer to gossip to is not in our fanout peers
|
||||
assert peer not in fanout_peers
|
||||
assert topic_fanout in peers_to_gossip[peer]
|
||||
fake_peer_ids = [IDFactory() for _ in range(total_peer_count)]
|
||||
peer_protocol = {peer_id: PROTOCOL_ID for peer_id in fake_peer_ids}
|
||||
monkeypatch.setattr(pubsubs_gsub[0].router, "peer_protocol", peer_protocol)
|
||||
|
||||
topic_mesh_peer_count = 14
|
||||
# Split into mesh peers and fanout peers
|
||||
peer_topics = {
|
||||
topic_mesh: set(fake_peer_ids[:topic_mesh_peer_count]),
|
||||
topic_fanout: set(fake_peer_ids[topic_mesh_peer_count:]),
|
||||
}
|
||||
# Monkeypatch the peer subscriptions
|
||||
monkeypatch.setattr(pubsubs_gsub[0], "peer_topics", peer_topics)
|
||||
|
||||
mesh_peer_indices = random.sample(
|
||||
range(topic_mesh_peer_count), initial_peer_count
|
||||
)
|
||||
mesh_peers = [fake_peer_ids[i] for i in mesh_peer_indices]
|
||||
router_mesh = {topic_mesh: set(mesh_peers)}
|
||||
# Monkeypatch our mesh peers
|
||||
monkeypatch.setattr(pubsubs_gsub[0].router, "mesh", router_mesh)
|
||||
fanout_peer_indices = random.sample(
|
||||
range(topic_mesh_peer_count, total_peer_count), initial_peer_count
|
||||
)
|
||||
fanout_peers = [fake_peer_ids[i] for i in fanout_peer_indices]
|
||||
router_fanout = {topic_fanout: set(fanout_peers)}
|
||||
# Monkeypatch our fanout peers
|
||||
monkeypatch.setattr(pubsubs_gsub[0].router, "fanout", router_fanout)
|
||||
|
||||
def window(topic):
|
||||
if topic == topic_mesh:
|
||||
return [topic_mesh]
|
||||
elif topic == topic_fanout:
|
||||
return [topic_fanout]
|
||||
else:
|
||||
return []
|
||||
|
||||
# Monkeypatch the memory cache messages
|
||||
monkeypatch.setattr(pubsubs_gsub[0].router.mcache, "window", window)
|
||||
|
||||
peers_to_gossip = pubsubs_gsub[0].router.gossip_heartbeat()
|
||||
# If our mesh peer count is less than `GossipSubDegree`, we should gossip to up to
|
||||
# `GossipSubDegree` peers (exclude mesh peers).
|
||||
if topic_mesh_peer_count - initial_peer_count < pubsubs_gsub[0].router.degree:
|
||||
# The same goes for fanout so it's two times the number of peers to gossip.
|
||||
assert len(peers_to_gossip) == 2 * (
|
||||
topic_mesh_peer_count - initial_peer_count
|
||||
)
|
||||
elif (
|
||||
topic_mesh_peer_count - initial_peer_count >= pubsubs_gsub[0].router.degree
|
||||
):
|
||||
assert len(peers_to_gossip) == 2 * (pubsubs_gsub[0].router.degree)
|
||||
|
||||
for peer in peers_to_gossip:
|
||||
if peer in peer_topics[topic_mesh]:
|
||||
# Check that the peer to gossip to is not in our mesh peers
|
||||
assert peer not in mesh_peers
|
||||
assert topic_mesh in peers_to_gossip[peer]
|
||||
elif peer in peer_topics[topic_fanout]:
|
||||
# Check that the peer to gossip to is not in our fanout peers
|
||||
assert peer not in fanout_peers
|
||||
assert topic_fanout in peers_to_gossip[peer]
|
||||
|
|
|
@ -3,25 +3,25 @@ import functools
|
|||
import pytest
|
||||
|
||||
from libp2p.tools.constants import FLOODSUB_PROTOCOL_ID
|
||||
from libp2p.tools.factories import GossipsubFactory
|
||||
from libp2p.tools.factories import PubsubFactory
|
||||
from libp2p.tools.pubsub.floodsub_integration_test_settings import (
|
||||
floodsub_protocol_pytest_params,
|
||||
perform_test_from_obj,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gossipsub_initialize_with_floodsub_protocol():
|
||||
GossipsubFactory(protocols=[FLOODSUB_PROTOCOL_ID])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("test_case_obj", floodsub_protocol_pytest_params)
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.trio
|
||||
@pytest.mark.slow
|
||||
async def test_gossipsub_run_with_floodsub_tests(test_case_obj):
|
||||
await perform_test_from_obj(
|
||||
test_case_obj,
|
||||
functools.partial(
|
||||
GossipsubFactory, degree=3, degree_low=2, degree_high=4, time_to_live=30
|
||||
PubsubFactory.create_batch_with_gossipsub,
|
||||
protocols=[FLOODSUB_PROTOCOL_ID],
|
||||
degree=3,
|
||||
degree_low=2,
|
||||
degree_high=4,
|
||||
time_to_live=30,
|
||||
),
|
||||
)
|
||||
|
|
|
@ -1,5 +1,3 @@
|
|||
import pytest
|
||||
|
||||
from libp2p.pubsub.mcache import MessageCache
|
||||
|
||||
|
||||
|
@ -12,8 +10,7 @@ class Msg:
|
|||
self.from_id = from_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcache():
|
||||
def test_mcache():
|
||||
# Ported from:
|
||||
# https://github.com/libp2p/go-libp2p-pubsub/blob/51b7501433411b5096cac2b4994a36a68515fc03/mcache_test.go
|
||||
mcache = MessageCache(3, 5)
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,84 @@
|
|||
import math
|
||||
|
||||
import pytest
|
||||
import trio
|
||||
|
||||
from libp2p.pubsub.pb import rpc_pb2
|
||||
from libp2p.pubsub.subscription import TrioSubscriptionAPI
|
||||
|
||||
GET_TIMEOUT = 0.001
|
||||
|
||||
|
||||
def make_trio_subscription():
|
||||
send_channel, receive_channel = trio.open_memory_channel(math.inf)
|
||||
|
||||
async def unsubscribe_fn():
|
||||
await send_channel.aclose()
|
||||
|
||||
return (
|
||||
send_channel,
|
||||
TrioSubscriptionAPI(receive_channel, unsubscribe_fn=unsubscribe_fn),
|
||||
)
|
||||
|
||||
|
||||
def make_pubsub_msg():
|
||||
return rpc_pb2.Message()
|
||||
|
||||
|
||||
async def send_something(send_channel):
|
||||
msg = make_pubsub_msg()
|
||||
await send_channel.send(msg)
|
||||
return msg
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_trio_subscription_get():
|
||||
send_channel, sub = make_trio_subscription()
|
||||
data_0 = await send_something(send_channel)
|
||||
data_1 = await send_something(send_channel)
|
||||
assert data_0 == await sub.get()
|
||||
assert data_1 == await sub.get()
|
||||
# No more message
|
||||
with pytest.raises(trio.TooSlowError):
|
||||
with trio.fail_after(GET_TIMEOUT):
|
||||
await sub.get()
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_trio_subscription_iter():
|
||||
send_channel, sub = make_trio_subscription()
|
||||
received_data = []
|
||||
|
||||
async def iter_subscriptions(subscription):
|
||||
async for data in sub:
|
||||
received_data.append(data)
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(iter_subscriptions, sub)
|
||||
await send_something(send_channel)
|
||||
await send_something(send_channel)
|
||||
await send_channel.aclose()
|
||||
|
||||
assert len(received_data) == 2
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_trio_subscription_unsubscribe():
|
||||
send_channel, sub = make_trio_subscription()
|
||||
await sub.unsubscribe()
|
||||
# Test: If the subscription is unsubscribed, `send_channel` should be closed.
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
await send_something(send_channel)
|
||||
# Test: No side effect when cancelled twice.
|
||||
await sub.unsubscribe()
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_trio_subscription_async_context_manager():
|
||||
send_channel, sub = make_trio_subscription()
|
||||
async with sub:
|
||||
# Test: `sub` is not cancelled yet, so `send_something` works fine.
|
||||
await send_something(send_channel)
|
||||
# Test: `sub` is cancelled, `send_something` fails
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
await send_something(send_channel)
|
|
@ -1,70 +1,15 @@
|
|||
import asyncio
|
||||
|
||||
import pytest
|
||||
import trio
|
||||
|
||||
from libp2p.crypto.secp256k1 import create_new_key_pair
|
||||
from libp2p.network.connection.raw_connection_interface import IRawConnection
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.security.secio.transport import NONCE_SIZE, create_secure_session
|
||||
from libp2p.tools.constants import MAX_READ_LEN
|
||||
from libp2p.tools.factories import raw_conn_factory
|
||||
|
||||
|
||||
class InMemoryConnection(IRawConnection):
|
||||
def __init__(self, peer, is_initiator=False):
|
||||
self.peer = peer
|
||||
self.recv_queue = asyncio.Queue()
|
||||
self.send_queue = asyncio.Queue()
|
||||
self.is_initiator = is_initiator
|
||||
|
||||
self.current_msg = None
|
||||
self.current_position = 0
|
||||
|
||||
self.closed = False
|
||||
|
||||
async def write(self, data: bytes) -> int:
|
||||
if self.closed:
|
||||
raise Exception("InMemoryConnection is closed for writing")
|
||||
|
||||
await self.send_queue.put(data)
|
||||
return len(data)
|
||||
|
||||
async def read(self, n: int = -1) -> bytes:
|
||||
"""
|
||||
NOTE: have to buffer the current message and juggle packets
|
||||
off the recv queue to satisfy the semantics of this function.
|
||||
"""
|
||||
if self.closed:
|
||||
raise Exception("InMemoryConnection is closed for reading")
|
||||
|
||||
if not self.current_msg:
|
||||
self.current_msg = await self.recv_queue.get()
|
||||
self.current_position = 0
|
||||
|
||||
if n < 0:
|
||||
msg = self.current_msg
|
||||
self.current_msg = None
|
||||
return msg
|
||||
|
||||
next_msg = self.current_msg[self.current_position : self.current_position + n]
|
||||
self.current_position += n
|
||||
if self.current_position == len(self.current_msg):
|
||||
self.current_msg = None
|
||||
return next_msg
|
||||
|
||||
async def close(self) -> None:
|
||||
self.closed = True
|
||||
|
||||
|
||||
async def create_pipe(local_conn, remote_conn):
|
||||
try:
|
||||
while True:
|
||||
next_msg = await local_conn.send_queue.get()
|
||||
await remote_conn.recv_queue.put(next_msg)
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_secure_session():
|
||||
@pytest.mark.trio
|
||||
async def test_create_secure_session(nursery):
|
||||
local_nonce = b"\x01" * NONCE_SIZE
|
||||
local_key_pair = create_new_key_pair(b"a")
|
||||
local_peer = ID.from_pubkey(local_key_pair.public_key)
|
||||
|
@ -73,30 +18,32 @@ async def test_create_secure_session():
|
|||
remote_key_pair = create_new_key_pair(b"b")
|
||||
remote_peer = ID.from_pubkey(remote_key_pair.public_key)
|
||||
|
||||
local_conn = InMemoryConnection(local_peer, is_initiator=True)
|
||||
remote_conn = InMemoryConnection(remote_peer)
|
||||
async with raw_conn_factory(nursery) as conns:
|
||||
local_conn, remote_conn = conns
|
||||
|
||||
local_pipe_task = asyncio.ensure_future(create_pipe(local_conn, remote_conn))
|
||||
remote_pipe_task = asyncio.ensure_future(create_pipe(remote_conn, local_conn))
|
||||
local_secure_conn, remote_secure_conn = None, None
|
||||
|
||||
local_session_builder = create_secure_session(
|
||||
local_nonce, local_peer, local_key_pair.private_key, local_conn, remote_peer
|
||||
)
|
||||
remote_session_builder = create_secure_session(
|
||||
remote_nonce, remote_peer, remote_key_pair.private_key, remote_conn
|
||||
)
|
||||
local_secure_conn, remote_secure_conn = await asyncio.gather(
|
||||
local_session_builder, remote_session_builder
|
||||
)
|
||||
async def local_create_secure_session():
|
||||
nonlocal local_secure_conn
|
||||
local_secure_conn = await create_secure_session(
|
||||
local_nonce,
|
||||
local_peer,
|
||||
local_key_pair.private_key,
|
||||
local_conn,
|
||||
remote_peer,
|
||||
)
|
||||
|
||||
msg = b"abc"
|
||||
await local_secure_conn.write(msg)
|
||||
received_msg = await remote_secure_conn.read()
|
||||
assert received_msg == msg
|
||||
async def remote_create_secure_session():
|
||||
nonlocal remote_secure_conn
|
||||
remote_secure_conn = await create_secure_session(
|
||||
remote_nonce, remote_peer, remote_key_pair.private_key, remote_conn
|
||||
)
|
||||
|
||||
await asyncio.gather(local_secure_conn.close(), remote_secure_conn.close())
|
||||
async with trio.open_nursery() as nursery_1:
|
||||
nursery_1.start_soon(local_create_secure_session)
|
||||
nursery_1.start_soon(remote_create_secure_session)
|
||||
|
||||
local_pipe_task.cancel()
|
||||
remote_pipe_task.cancel()
|
||||
await local_pipe_task
|
||||
await remote_pipe_task
|
||||
msg = b"abc"
|
||||
await local_secure_conn.write(msg)
|
||||
received_msg = await remote_secure_conn.read(MAX_READ_LEN)
|
||||
assert received_msg == msg
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
import asyncio
|
||||
|
||||
import pytest
|
||||
import trio
|
||||
|
||||
from libp2p import new_node
|
||||
from libp2p import new_host
|
||||
from libp2p.crypto.rsa import create_new_key_pair
|
||||
from libp2p.security.insecure.transport import InsecureSession, InsecureTransport
|
||||
from libp2p.tools.constants import LISTEN_MADDR
|
||||
|
@ -24,42 +23,36 @@ noninitiator_key_pair = create_new_key_pair()
|
|||
async def perform_simple_test(
|
||||
assertion_func, transports_for_initiator, transports_for_noninitiator
|
||||
):
|
||||
|
||||
# Create libp2p nodes and connect them, then secure the connection, then check
|
||||
# the proper security was chosen
|
||||
# TODO: implement -- note we need to introduce the notion of communicating over a raw connection
|
||||
# for testing, we do NOT want to communicate over a stream so we can't just create two nodes
|
||||
# and use their conn because our mplex will internally relay messages to a stream
|
||||
|
||||
node1 = await new_node(
|
||||
key_pair=initiator_key_pair, sec_opt=transports_for_initiator
|
||||
)
|
||||
node2 = await new_node(
|
||||
node1 = new_host(key_pair=initiator_key_pair, sec_opt=transports_for_initiator)
|
||||
node2 = new_host(
|
||||
key_pair=noninitiator_key_pair, sec_opt=transports_for_noninitiator
|
||||
)
|
||||
async with node1.run(listen_addrs=[LISTEN_MADDR]), node2.run(
|
||||
listen_addrs=[LISTEN_MADDR]
|
||||
):
|
||||
await connect(node1, node2)
|
||||
|
||||
await node1.get_network().listen(LISTEN_MADDR)
|
||||
await node2.get_network().listen(LISTEN_MADDR)
|
||||
# Wait a very short period to allow conns to be stored (since the functions
|
||||
# storing the conns are async, they may happen at slightly different times
|
||||
# on each node)
|
||||
await trio.sleep(0.1)
|
||||
|
||||
await connect(node1, node2)
|
||||
# Get conns
|
||||
node1_conn = node1.get_network().connections[peer_id_for_node(node2)]
|
||||
node2_conn = node2.get_network().connections[peer_id_for_node(node1)]
|
||||
|
||||
# Wait a very short period to allow conns to be stored (since the functions
|
||||
# storing the conns are async, they may happen at slightly different times
|
||||
# on each node)
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Get conns
|
||||
node1_conn = node1.get_network().connections[peer_id_for_node(node2)]
|
||||
node2_conn = node2.get_network().connections[peer_id_for_node(node1)]
|
||||
|
||||
# Perform assertion
|
||||
assertion_func(node1_conn.muxed_conn.secured_conn)
|
||||
assertion_func(node2_conn.muxed_conn.secured_conn)
|
||||
|
||||
# Success, terminate pending tasks.
|
||||
# Perform assertion
|
||||
assertion_func(node1_conn.muxed_conn.secured_conn)
|
||||
assertion_func(node2_conn.muxed_conn.secured_conn)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.trio
|
||||
async def test_single_insecure_security_transport_succeeds():
|
||||
transports_for_initiator = {"foo": InsecureTransport(initiator_key_pair)}
|
||||
transports_for_noninitiator = {"foo": InsecureTransport(noninitiator_key_pair)}
|
||||
|
@ -72,7 +65,7 @@ async def test_single_insecure_security_transport_succeeds():
|
|||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.trio
|
||||
async def test_default_insecure_security():
|
||||
transports_for_initiator = None
|
||||
transports_for_noninitiator = None
|
||||
|
|
|
@ -1,5 +1,3 @@
|
|||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from libp2p.tools.factories import mplex_conn_pair_factory, mplex_stream_pair_factory
|
||||
|
@ -7,23 +5,13 @@ from libp2p.tools.factories import mplex_conn_pair_factory, mplex_stream_pair_fa
|
|||
|
||||
@pytest.fixture
|
||||
async def mplex_conn_pair(is_host_secure):
|
||||
mplex_conn_0, swarm_0, mplex_conn_1, swarm_1 = await mplex_conn_pair_factory(
|
||||
is_host_secure
|
||||
)
|
||||
assert mplex_conn_0.is_initiator
|
||||
assert not mplex_conn_1.is_initiator
|
||||
try:
|
||||
yield mplex_conn_0, mplex_conn_1
|
||||
finally:
|
||||
await asyncio.gather(*[swarm_0.close(), swarm_1.close()])
|
||||
async with mplex_conn_pair_factory(is_host_secure) as mplex_conn_pair:
|
||||
assert mplex_conn_pair[0].is_initiator
|
||||
assert not mplex_conn_pair[1].is_initiator
|
||||
yield mplex_conn_pair[0], mplex_conn_pair[1]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def mplex_stream_pair(is_host_secure):
|
||||
mplex_stream_0, swarm_0, mplex_stream_1, swarm_1 = await mplex_stream_pair_factory(
|
||||
is_host_secure
|
||||
)
|
||||
try:
|
||||
yield mplex_stream_0, mplex_stream_1
|
||||
finally:
|
||||
await asyncio.gather(*[swarm_0.close(), swarm_1.close()])
|
||||
async with mplex_stream_pair_factory(is_host_secure) as mplex_stream_pair:
|
||||
yield mplex_stream_pair
|
||||
|
|
|
@ -1,49 +1,40 @@
|
|||
import asyncio
|
||||
|
||||
import pytest
|
||||
import trio
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.trio
|
||||
async def test_mplex_conn(mplex_conn_pair):
|
||||
conn_0, conn_1 = mplex_conn_pair
|
||||
|
||||
assert len(conn_0.streams) == 0
|
||||
assert len(conn_1.streams) == 0
|
||||
assert not conn_0.event_shutting_down.is_set()
|
||||
assert not conn_1.event_shutting_down.is_set()
|
||||
assert not conn_0.event_closed.is_set()
|
||||
assert not conn_1.event_closed.is_set()
|
||||
|
||||
# Test: Open a stream, and both side get 1 more stream.
|
||||
stream_0 = await conn_0.open_stream()
|
||||
await asyncio.sleep(0.01)
|
||||
await trio.sleep(0.01)
|
||||
assert len(conn_0.streams) == 1
|
||||
assert len(conn_1.streams) == 1
|
||||
# Test: From another side.
|
||||
stream_1 = await conn_1.open_stream()
|
||||
await asyncio.sleep(0.01)
|
||||
await trio.sleep(0.01)
|
||||
assert len(conn_0.streams) == 2
|
||||
assert len(conn_1.streams) == 2
|
||||
|
||||
# Close from one side.
|
||||
await conn_0.close()
|
||||
# Sleep for a while for both side to handle `close`.
|
||||
await asyncio.sleep(0.01)
|
||||
await trio.sleep(0.01)
|
||||
# Test: Both side is closed.
|
||||
assert conn_0.event_shutting_down.is_set()
|
||||
assert conn_0.event_closed.is_set()
|
||||
assert conn_1.event_shutting_down.is_set()
|
||||
assert conn_1.event_closed.is_set()
|
||||
assert conn_0.is_closed
|
||||
assert conn_1.is_closed
|
||||
# Test: All streams should have been closed.
|
||||
assert stream_0.event_remote_closed.is_set()
|
||||
assert stream_0.event_reset.is_set()
|
||||
assert stream_0.event_local_closed.is_set()
|
||||
assert conn_0.streams is None
|
||||
# Test: All streams on the other side are also closed.
|
||||
assert stream_1.event_remote_closed.is_set()
|
||||
assert stream_1.event_reset.is_set()
|
||||
assert stream_1.event_local_closed.is_set()
|
||||
assert conn_1.streams is None
|
||||
|
||||
# Test: No effect to close more than once between two side.
|
||||
await conn_0.close()
|
||||
|
|
|
@ -1,25 +1,48 @@
|
|||
import asyncio
|
||||
|
||||
import pytest
|
||||
import trio
|
||||
from trio.testing import wait_all_tasks_blocked
|
||||
|
||||
from libp2p.stream_muxer.mplex.exceptions import (
|
||||
MplexStreamClosed,
|
||||
MplexStreamEOF,
|
||||
MplexStreamReset,
|
||||
)
|
||||
from libp2p.stream_muxer.mplex.mplex import MPLEX_MESSAGE_CHANNEL_SIZE
|
||||
from libp2p.tools.constants import MAX_READ_LEN
|
||||
|
||||
DATA = b"data_123"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.trio
|
||||
async def test_mplex_stream_read_write(mplex_stream_pair):
|
||||
stream_0, stream_1 = mplex_stream_pair
|
||||
await stream_0.write(DATA)
|
||||
assert (await stream_1.read(MAX_READ_LEN)) == DATA
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.trio
|
||||
async def test_mplex_stream_full_buffer(mplex_stream_pair):
|
||||
stream_0, stream_1 = mplex_stream_pair
|
||||
# Test: The message channel is of size `MPLEX_MESSAGE_CHANNEL_SIZE`.
|
||||
# It should be fine to read even there are already `MPLEX_MESSAGE_CHANNEL_SIZE`
|
||||
# messages arriving.
|
||||
for _ in range(MPLEX_MESSAGE_CHANNEL_SIZE):
|
||||
await stream_0.write(DATA)
|
||||
await wait_all_tasks_blocked()
|
||||
# Sanity check
|
||||
assert MAX_READ_LEN >= MPLEX_MESSAGE_CHANNEL_SIZE * len(DATA)
|
||||
assert (await stream_1.read(MAX_READ_LEN)) == MPLEX_MESSAGE_CHANNEL_SIZE * DATA
|
||||
|
||||
# Test: Read after `MPLEX_MESSAGE_CHANNEL_SIZE + 1` messages has arrived, which
|
||||
# exceeds the channel size. The stream should have been reset.
|
||||
for _ in range(MPLEX_MESSAGE_CHANNEL_SIZE + 1):
|
||||
await stream_0.write(DATA)
|
||||
await wait_all_tasks_blocked()
|
||||
with pytest.raises(MplexStreamReset):
|
||||
await stream_1.read(MAX_READ_LEN)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_mplex_stream_pair_read_until_eof(mplex_stream_pair):
|
||||
read_bytes = bytearray()
|
||||
stream_0, stream_1 = mplex_stream_pair
|
||||
|
@ -27,43 +50,46 @@ async def test_mplex_stream_pair_read_until_eof(mplex_stream_pair):
|
|||
async def read_until_eof():
|
||||
read_bytes.extend(await stream_1.read())
|
||||
|
||||
task = asyncio.ensure_future(read_until_eof())
|
||||
|
||||
expected_data = bytearray()
|
||||
|
||||
# Test: `read` doesn't return before `close` is called.
|
||||
await stream_0.write(DATA)
|
||||
expected_data.extend(DATA)
|
||||
await asyncio.sleep(0.01)
|
||||
assert len(read_bytes) == 0
|
||||
# Test: `read` doesn't return before `close` is called.
|
||||
await stream_0.write(DATA)
|
||||
expected_data.extend(DATA)
|
||||
await asyncio.sleep(0.01)
|
||||
assert len(read_bytes) == 0
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(read_until_eof)
|
||||
# Test: `read` doesn't return before `close` is called.
|
||||
await stream_0.write(DATA)
|
||||
expected_data.extend(DATA)
|
||||
await trio.sleep(0.01)
|
||||
assert len(read_bytes) == 0
|
||||
# Test: `read` doesn't return before `close` is called.
|
||||
await stream_0.write(DATA)
|
||||
expected_data.extend(DATA)
|
||||
await trio.sleep(0.01)
|
||||
assert len(read_bytes) == 0
|
||||
|
||||
# Test: Close the stream, `read` returns, and receive previous sent data.
|
||||
await stream_0.close()
|
||||
|
||||
# Test: Close the stream, `read` returns, and receive previous sent data.
|
||||
await stream_0.close()
|
||||
await asyncio.sleep(0.01)
|
||||
assert read_bytes == expected_data
|
||||
|
||||
task.cancel()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.trio
|
||||
async def test_mplex_stream_read_after_remote_closed(mplex_stream_pair):
|
||||
stream_0, stream_1 = mplex_stream_pair
|
||||
assert not stream_1.event_remote_closed.is_set()
|
||||
await stream_0.write(DATA)
|
||||
assert not stream_0.event_local_closed.is_set()
|
||||
await trio.sleep(0.01)
|
||||
await wait_all_tasks_blocked()
|
||||
await stream_0.close()
|
||||
await asyncio.sleep(0.01)
|
||||
assert stream_0.event_local_closed.is_set()
|
||||
await trio.sleep(0.01)
|
||||
await wait_all_tasks_blocked()
|
||||
assert stream_1.event_remote_closed.is_set()
|
||||
assert (await stream_1.read(MAX_READ_LEN)) == DATA
|
||||
with pytest.raises(MplexStreamEOF):
|
||||
await stream_1.read(MAX_READ_LEN)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.trio
|
||||
async def test_mplex_stream_read_after_local_reset(mplex_stream_pair):
|
||||
stream_0, stream_1 = mplex_stream_pair
|
||||
await stream_0.reset()
|
||||
|
@ -71,29 +97,30 @@ async def test_mplex_stream_read_after_local_reset(mplex_stream_pair):
|
|||
await stream_0.read(MAX_READ_LEN)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.trio
|
||||
async def test_mplex_stream_read_after_remote_reset(mplex_stream_pair):
|
||||
stream_0, stream_1 = mplex_stream_pair
|
||||
await stream_0.write(DATA)
|
||||
await stream_0.reset()
|
||||
# Sleep to let `stream_1` receive the message.
|
||||
await asyncio.sleep(0.01)
|
||||
await trio.sleep(0.1)
|
||||
await wait_all_tasks_blocked()
|
||||
with pytest.raises(MplexStreamReset):
|
||||
await stream_1.read(MAX_READ_LEN)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.trio
|
||||
async def test_mplex_stream_read_after_remote_closed_and_reset(mplex_stream_pair):
|
||||
stream_0, stream_1 = mplex_stream_pair
|
||||
await stream_0.write(DATA)
|
||||
await stream_0.close()
|
||||
await stream_0.reset()
|
||||
# Sleep to let `stream_1` receive the message.
|
||||
await asyncio.sleep(0.01)
|
||||
await trio.sleep(0.01)
|
||||
assert (await stream_1.read(MAX_READ_LEN)) == DATA
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.trio
|
||||
async def test_mplex_stream_write_after_local_closed(mplex_stream_pair):
|
||||
stream_0, stream_1 = mplex_stream_pair
|
||||
await stream_0.write(DATA)
|
||||
|
@ -102,7 +129,7 @@ async def test_mplex_stream_write_after_local_closed(mplex_stream_pair):
|
|||
await stream_0.write(DATA)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.trio
|
||||
async def test_mplex_stream_write_after_local_reset(mplex_stream_pair):
|
||||
stream_0, stream_1 = mplex_stream_pair
|
||||
await stream_0.reset()
|
||||
|
@ -110,16 +137,16 @@ async def test_mplex_stream_write_after_local_reset(mplex_stream_pair):
|
|||
await stream_0.write(DATA)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.trio
|
||||
async def test_mplex_stream_write_after_remote_reset(mplex_stream_pair):
|
||||
stream_0, stream_1 = mplex_stream_pair
|
||||
await stream_1.reset()
|
||||
await asyncio.sleep(0.01)
|
||||
await trio.sleep(0.01)
|
||||
with pytest.raises(MplexStreamClosed):
|
||||
await stream_0.write(DATA)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.trio
|
||||
async def test_mplex_stream_both_close(mplex_stream_pair):
|
||||
stream_0, stream_1 = mplex_stream_pair
|
||||
# Flags are not set initially.
|
||||
|
@ -133,7 +160,7 @@ async def test_mplex_stream_both_close(mplex_stream_pair):
|
|||
|
||||
# Test: Close one side.
|
||||
await stream_0.close()
|
||||
await asyncio.sleep(0.01)
|
||||
await trio.sleep(0.01)
|
||||
|
||||
assert stream_0.event_local_closed.is_set()
|
||||
assert not stream_1.event_local_closed.is_set()
|
||||
|
@ -145,7 +172,7 @@ async def test_mplex_stream_both_close(mplex_stream_pair):
|
|||
|
||||
# Test: Close the other side.
|
||||
await stream_1.close()
|
||||
await asyncio.sleep(0.01)
|
||||
await trio.sleep(0.01)
|
||||
# Both sides are closed.
|
||||
assert stream_0.event_local_closed.is_set()
|
||||
assert stream_1.event_local_closed.is_set()
|
||||
|
@ -159,11 +186,11 @@ async def test_mplex_stream_both_close(mplex_stream_pair):
|
|||
await stream_0.reset()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.trio
|
||||
async def test_mplex_stream_reset(mplex_stream_pair):
|
||||
stream_0, stream_1 = mplex_stream_pair
|
||||
await stream_0.reset()
|
||||
await asyncio.sleep(0.01)
|
||||
await trio.sleep(0.01)
|
||||
|
||||
# Both sides are closed.
|
||||
assert stream_0.event_local_closed.is_set()
|
||||
|
|
|
@ -1,20 +1,53 @@
|
|||
import asyncio
|
||||
|
||||
from multiaddr import Multiaddr
|
||||
import pytest
|
||||
import trio
|
||||
|
||||
from libp2p.transport.tcp.tcp import _multiaddr_from_socket
|
||||
from libp2p.network.connection.raw_connection import RawConnection
|
||||
from libp2p.tools.constants import LISTEN_MADDR
|
||||
from libp2p.transport.exceptions import OpenConnectionError
|
||||
from libp2p.transport.tcp.tcp import TCP
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiaddr_from_socket():
|
||||
def handler(r, w):
|
||||
@pytest.mark.trio
|
||||
async def test_tcp_listener(nursery):
|
||||
transport = TCP()
|
||||
|
||||
async def handler(tcp_stream):
|
||||
pass
|
||||
|
||||
server = await asyncio.start_server(handler, "127.0.0.1", 8000)
|
||||
assert str(_multiaddr_from_socket(server.sockets[0])) == "/ip4/127.0.0.1/tcp/8000"
|
||||
listener = transport.create_listener(handler)
|
||||
assert len(listener.get_addrs()) == 0
|
||||
await listener.listen(LISTEN_MADDR, nursery)
|
||||
assert len(listener.get_addrs()) == 1
|
||||
await listener.listen(LISTEN_MADDR, nursery)
|
||||
assert len(listener.get_addrs()) == 2
|
||||
|
||||
server = await asyncio.start_server(handler, "127.0.0.1", 0)
|
||||
addr = _multiaddr_from_socket(server.sockets[0])
|
||||
assert addr.value_for_protocol("ip4") == "127.0.0.1"
|
||||
port = addr.value_for_protocol("tcp")
|
||||
assert int(port) > 0
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_tcp_dial(nursery):
|
||||
transport = TCP()
|
||||
raw_conn_other_side = None
|
||||
event = trio.Event()
|
||||
|
||||
async def handler(tcp_stream):
|
||||
nonlocal raw_conn_other_side
|
||||
raw_conn_other_side = RawConnection(tcp_stream, False)
|
||||
event.set()
|
||||
await trio.sleep_forever()
|
||||
|
||||
# Test: `OpenConnectionError` is raised when trying to dial to a port which
|
||||
# no one is not listening to.
|
||||
with pytest.raises(OpenConnectionError):
|
||||
await transport.dial(Multiaddr("/ip4/127.0.0.1/tcp/1"))
|
||||
|
||||
listener = transport.create_listener(handler)
|
||||
await listener.listen(LISTEN_MADDR, nursery)
|
||||
addrs = listener.get_addrs()
|
||||
assert len(addrs) == 1
|
||||
listen_addr = addrs[0]
|
||||
raw_conn = await transport.dial(listen_addr)
|
||||
await event.wait()
|
||||
|
||||
data = b"123"
|
||||
await raw_conn_other_side.write(data)
|
||||
assert (await raw_conn.read(len(data))) == data
|
||||
|
|
|
@ -1,20 +1,13 @@
|
|||
import asyncio
|
||||
import sys
|
||||
from typing import Union
|
||||
|
||||
import anyio
|
||||
from async_exit_stack import AsyncExitStack
|
||||
from p2pclient.datastructures import StreamInfo
|
||||
import pexpect
|
||||
from p2pclient.utils import get_unused_tcp_port
|
||||
import pytest
|
||||
import trio
|
||||
|
||||
from libp2p.io.abc import ReadWriteCloser
|
||||
from libp2p.tools.constants import GOSSIPSUB_PARAMS, LISTEN_MADDR
|
||||
from libp2p.tools.factories import (
|
||||
FloodsubFactory,
|
||||
GossipsubFactory,
|
||||
HostFactory,
|
||||
PubsubFactory,
|
||||
)
|
||||
from libp2p.tools.interop.daemon import Daemon, make_p2pd
|
||||
from libp2p.tools.factories import HostFactory, PubsubFactory
|
||||
from libp2p.tools.interop.daemon import make_p2pd
|
||||
from libp2p.tools.interop.utils import connect
|
||||
|
||||
|
||||
|
@ -23,48 +16,6 @@ def is_host_secure():
|
|||
return False
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def num_hosts():
|
||||
return 3
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def hosts(num_hosts, is_host_secure):
|
||||
_hosts = HostFactory.create_batch(num_hosts, is_secure=is_host_secure)
|
||||
await asyncio.gather(
|
||||
*[_host.get_network().listen(LISTEN_MADDR) for _host in _hosts]
|
||||
)
|
||||
try:
|
||||
yield _hosts
|
||||
finally:
|
||||
# TODO: It's possible that `close` raises exceptions currently,
|
||||
# due to the connection reset things. Though we don't care much about that when
|
||||
# cleaning up the tasks, it is probably better to handle the exceptions properly.
|
||||
await asyncio.gather(
|
||||
*[_host.close() for _host in _hosts], return_exceptions=True
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def proc_factory():
|
||||
procs = []
|
||||
|
||||
def call_proc(cmd, args, logfile=None, encoding=None):
|
||||
if logfile is None:
|
||||
logfile = sys.stdout
|
||||
if encoding is None:
|
||||
encoding = "utf-8"
|
||||
proc = pexpect.spawn(cmd, args, logfile=logfile, encoding=encoding)
|
||||
procs.append(proc)
|
||||
return proc
|
||||
|
||||
try:
|
||||
yield call_proc
|
||||
finally:
|
||||
for proc in procs:
|
||||
proc.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def num_p2pds():
|
||||
return 1
|
||||
|
@ -87,79 +38,57 @@ def is_pubsub_signing_strict():
|
|||
|
||||
@pytest.fixture
|
||||
async def p2pds(
|
||||
num_p2pds,
|
||||
is_host_secure,
|
||||
is_gossipsub,
|
||||
unused_tcp_port_factory,
|
||||
is_pubsub_signing,
|
||||
is_pubsub_signing_strict,
|
||||
num_p2pds, is_host_secure, is_gossipsub, is_pubsub_signing, is_pubsub_signing_strict
|
||||
):
|
||||
p2pds: Union[Daemon, Exception] = await asyncio.gather(
|
||||
*[
|
||||
make_p2pd(
|
||||
unused_tcp_port_factory(),
|
||||
unused_tcp_port_factory(),
|
||||
is_host_secure,
|
||||
is_gossipsub=is_gossipsub,
|
||||
is_pubsub_signing=is_pubsub_signing,
|
||||
is_pubsub_signing_strict=is_pubsub_signing_strict,
|
||||
async with AsyncExitStack() as stack:
|
||||
p2pds = [
|
||||
await stack.enter_async_context(
|
||||
make_p2pd(
|
||||
get_unused_tcp_port(),
|
||||
get_unused_tcp_port(),
|
||||
is_host_secure,
|
||||
is_gossipsub=is_gossipsub,
|
||||
is_pubsub_signing=is_pubsub_signing,
|
||||
is_pubsub_signing_strict=is_pubsub_signing_strict,
|
||||
)
|
||||
)
|
||||
for _ in range(num_p2pds)
|
||||
],
|
||||
return_exceptions=True,
|
||||
)
|
||||
p2pds_succeeded = tuple(p2pd for p2pd in p2pds if isinstance(p2pd, Daemon))
|
||||
if len(p2pds_succeeded) != len(p2pds):
|
||||
# Not all succeeded. Close the succeeded ones and print the failed ones(exceptions).
|
||||
await asyncio.gather(*[p2pd.close() for p2pd in p2pds_succeeded])
|
||||
exceptions = tuple(p2pd for p2pd in p2pds if isinstance(p2pd, Exception))
|
||||
raise Exception(f"not all p2pds succeed: first exception={exceptions[0]}")
|
||||
try:
|
||||
yield p2pds
|
||||
finally:
|
||||
await asyncio.gather(*[p2pd.close() for p2pd in p2pds])
|
||||
]
|
||||
try:
|
||||
yield p2pds
|
||||
finally:
|
||||
for p2pd in p2pds:
|
||||
await p2pd.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pubsubs(num_hosts, hosts, is_gossipsub, is_pubsub_signing_strict):
|
||||
async def pubsubs(num_hosts, is_host_secure, is_gossipsub, is_pubsub_signing_strict):
|
||||
if is_gossipsub:
|
||||
routers = GossipsubFactory.create_batch(num_hosts, **GOSSIPSUB_PARAMS._asdict())
|
||||
yield PubsubFactory.create_batch_with_gossipsub(
|
||||
num_hosts, is_secure=is_host_secure, strict_signing=is_pubsub_signing_strict
|
||||
)
|
||||
else:
|
||||
routers = FloodsubFactory.create_batch(num_hosts)
|
||||
_pubsubs = tuple(
|
||||
PubsubFactory(host=host, router=router, strict_signing=is_pubsub_signing_strict)
|
||||
for host, router in zip(hosts, routers)
|
||||
)
|
||||
yield _pubsubs
|
||||
# TODO: Clean up
|
||||
yield PubsubFactory.create_batch_with_floodsub(
|
||||
num_hosts, is_host_secure, strict_signing=is_pubsub_signing_strict
|
||||
)
|
||||
|
||||
|
||||
class DaemonStream(ReadWriteCloser):
|
||||
stream_info: StreamInfo
|
||||
reader: asyncio.StreamReader
|
||||
writer: asyncio.StreamWriter
|
||||
stream: anyio.abc.SocketStream
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
stream_info: StreamInfo,
|
||||
reader: asyncio.StreamReader,
|
||||
writer: asyncio.StreamWriter,
|
||||
) -> None:
|
||||
def __init__(self, stream_info: StreamInfo, stream: anyio.abc.SocketStream) -> None:
|
||||
self.stream_info = stream_info
|
||||
self.reader = reader
|
||||
self.writer = writer
|
||||
self.stream = stream
|
||||
|
||||
async def close(self) -> None:
|
||||
self.writer.close()
|
||||
if sys.version_info < (3, 7):
|
||||
return
|
||||
await self.writer.wait_closed()
|
||||
await self.stream.close()
|
||||
|
||||
async def read(self, n: int = -1) -> bytes:
|
||||
return await self.reader.read(n)
|
||||
async def read(self, n: int = None) -> bytes:
|
||||
return await self.stream.receive_some(n)
|
||||
|
||||
async def write(self, data: bytes) -> int:
|
||||
return self.writer.write(data)
|
||||
async def write(self, data: bytes) -> None:
|
||||
return await self.stream.send_all(data)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -168,40 +97,38 @@ async def is_to_fail_daemon_stream():
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
async def py_to_daemon_stream_pair(hosts, p2pds, is_to_fail_daemon_stream):
|
||||
assert len(hosts) >= 1
|
||||
assert len(p2pds) >= 1
|
||||
host = hosts[0]
|
||||
p2pd = p2pds[0]
|
||||
protocol_id = "/protocol/id/123"
|
||||
stream_py = None
|
||||
stream_daemon = None
|
||||
event_stream_handled = asyncio.Event()
|
||||
await connect(host, p2pd)
|
||||
async def py_to_daemon_stream_pair(p2pds, is_host_secure, is_to_fail_daemon_stream):
|
||||
async with HostFactory.create_batch_and_listen(is_host_secure, 1) as hosts:
|
||||
assert len(p2pds) >= 1
|
||||
host = hosts[0]
|
||||
p2pd = p2pds[0]
|
||||
protocol_id = "/protocol/id/123"
|
||||
stream_py = None
|
||||
stream_daemon = None
|
||||
event_stream_handled = trio.Event()
|
||||
await connect(host, p2pd)
|
||||
|
||||
async def daemon_stream_handler(stream_info, reader, writer):
|
||||
nonlocal stream_daemon
|
||||
stream_daemon = DaemonStream(stream_info, reader, writer)
|
||||
event_stream_handled.set()
|
||||
async def daemon_stream_handler(stream_info, stream):
|
||||
nonlocal stream_daemon
|
||||
stream_daemon = DaemonStream(stream_info, stream)
|
||||
event_stream_handled.set()
|
||||
await trio.hazmat.checkpoint()
|
||||
|
||||
await p2pd.control.stream_handler(protocol_id, daemon_stream_handler)
|
||||
# Sleep for a while to wait for the handler being registered.
|
||||
await asyncio.sleep(0.01)
|
||||
await p2pd.control.stream_handler(protocol_id, daemon_stream_handler)
|
||||
# Sleep for a while to wait for the handler being registered.
|
||||
await trio.sleep(0.01)
|
||||
|
||||
if is_to_fail_daemon_stream:
|
||||
# FIXME: This is a workaround to make daemon reset the stream.
|
||||
# We intentionally close the listener on the python side, it makes the connection from
|
||||
# daemon to us fail, and therefore the daemon resets the opened stream on their side.
|
||||
# Reference: https://github.com/libp2p/go-libp2p-daemon/blob/b95e77dbfcd186ccf817f51e95f73f9fd5982600/stream.go#L47-L50 # noqa: E501
|
||||
# We need it because we want to test against `stream_py` after the remote side(daemon)
|
||||
# is reset. This should be removed after the API `stream.reset` is exposed in daemon
|
||||
# some day.
|
||||
listener = p2pds[0].control.control.listener
|
||||
listener.close()
|
||||
if sys.version_info[0:2] > (3, 6):
|
||||
await listener.wait_closed()
|
||||
stream_py = await host.new_stream(p2pd.peer_id, [protocol_id])
|
||||
if not is_to_fail_daemon_stream:
|
||||
await event_stream_handled.wait()
|
||||
# NOTE: If `is_to_fail_daemon_stream == True`, then `stream_daemon == None`.
|
||||
yield stream_py, stream_daemon
|
||||
if is_to_fail_daemon_stream:
|
||||
# FIXME: This is a workaround to make daemon reset the stream.
|
||||
# We intentionally close the listener on the python side, it makes the connection from
|
||||
# daemon to us fail, and therefore the daemon resets the opened stream on their side.
|
||||
# Reference: https://github.com/libp2p/go-libp2p-daemon/blob/b95e77dbfcd186ccf817f51e95f73f9fd5982600/stream.go#L47-L50 # noqa: E501
|
||||
# We need it because we want to test against `stream_py` after the remote side(daemon)
|
||||
# is reset. This should be removed after the API `stream.reset` is exposed in daemon
|
||||
# some day.
|
||||
await p2pds[0].control.control.close()
|
||||
stream_py = await host.new_stream(p2pd.peer_id, [protocol_id])
|
||||
if not is_to_fail_daemon_stream:
|
||||
await event_stream_handled.wait()
|
||||
# NOTE: If `is_to_fail_daemon_stream == True`, then `stream_daemon == None`.
|
||||
yield stream_py, stream_daemon
|
||||
|
|
|
@ -1,26 +1,26 @@
|
|||
import asyncio
|
||||
|
||||
import pytest
|
||||
import trio
|
||||
|
||||
from libp2p.tools.factories import HostFactory
|
||||
from libp2p.tools.interop.utils import connect
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_hosts", (1,))
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect(hosts, p2pds):
|
||||
p2pd = p2pds[0]
|
||||
host = hosts[0]
|
||||
assert len(await p2pd.control.list_peers()) == 0
|
||||
# Test: connect from Py
|
||||
await connect(host, p2pd)
|
||||
assert len(await p2pd.control.list_peers()) == 1
|
||||
# Test: `disconnect` from Py
|
||||
await host.disconnect(p2pd.peer_id)
|
||||
assert len(await p2pd.control.list_peers()) == 0
|
||||
# Test: connect from Go
|
||||
await connect(p2pd, host)
|
||||
assert len(host.get_network().connections) == 1
|
||||
# Test: `disconnect` from Go
|
||||
await p2pd.control.disconnect(host.get_id())
|
||||
await asyncio.sleep(0.01)
|
||||
assert len(host.get_network().connections) == 0
|
||||
@pytest.mark.trio
|
||||
async def test_connect(is_host_secure, p2pds):
|
||||
async with HostFactory.create_batch_and_listen(is_host_secure, 1) as hosts:
|
||||
p2pd = p2pds[0]
|
||||
host = hosts[0]
|
||||
assert len(await p2pd.control.list_peers()) == 0
|
||||
# Test: connect from Py
|
||||
await connect(host, p2pd)
|
||||
assert len(await p2pd.control.list_peers()) == 1
|
||||
# Test: `disconnect` from Py
|
||||
await host.disconnect(p2pd.peer_id)
|
||||
assert len(await p2pd.control.list_peers()) == 0
|
||||
# Test: connect from Go
|
||||
await connect(p2pd, host)
|
||||
assert len(host.get_network().connections) == 1
|
||||
# Test: `disconnect` from Go
|
||||
await p2pd.control.disconnect(host.get_id())
|
||||
await trio.sleep(0.01)
|
||||
assert len(host.get_network().connections) == 0
|
||||
|
|
|
@ -1,82 +1,99 @@
|
|||
import asyncio
|
||||
import re
|
||||
|
||||
from multiaddr import Multiaddr
|
||||
from p2pclient.utils import get_unused_tcp_port
|
||||
import pytest
|
||||
import trio
|
||||
|
||||
from libp2p.peer.peerinfo import info_from_p2p_addr
|
||||
from libp2p.tools.interop.constants import PEXPECT_NEW_LINE
|
||||
from libp2p.peer.peerinfo import PeerInfo, info_from_p2p_addr
|
||||
from libp2p.tools.factories import HostFactory
|
||||
from libp2p.tools.interop.envs import GO_BIN_PATH
|
||||
from libp2p.tools.interop.process import BaseInteractiveProcess
|
||||
from libp2p.typing import TProtocol
|
||||
|
||||
ECHO_PATH = GO_BIN_PATH / "echo"
|
||||
ECHO_PROTOCOL_ID = TProtocol("/echo/1.0.0")
|
||||
|
||||
|
||||
async def make_echo_proc(
|
||||
proc_factory, port: int, is_secure: bool, destination: Multiaddr = None
|
||||
):
|
||||
args = [f"-l={port}"]
|
||||
if not is_secure:
|
||||
args.append("-insecure")
|
||||
if destination is not None:
|
||||
args.append(f"-d={str(destination)}")
|
||||
echo_proc = proc_factory(str(ECHO_PATH), args)
|
||||
await echo_proc.expect(r"I am ([\w\./]+)" + PEXPECT_NEW_LINE, async_=True)
|
||||
maddr_str_ipfs = echo_proc.match.group(1)
|
||||
maddr_str = maddr_str_ipfs.replace("ipfs", "p2p")
|
||||
maddr = Multiaddr(maddr_str)
|
||||
go_pinfo = info_from_p2p_addr(maddr)
|
||||
if destination is None:
|
||||
await echo_proc.expect("listening for connections", async_=True)
|
||||
return echo_proc, go_pinfo
|
||||
class EchoProcess(BaseInteractiveProcess):
|
||||
port: int
|
||||
_peer_info: PeerInfo
|
||||
|
||||
def __init__(
|
||||
self, port: int, is_secure: bool, destination: Multiaddr = None
|
||||
) -> None:
|
||||
args = [f"-l={port}"]
|
||||
if not is_secure:
|
||||
args.append("-insecure")
|
||||
if destination is not None:
|
||||
args.append(f"-d={str(destination)}")
|
||||
|
||||
patterns = [b"I am"]
|
||||
if destination is None:
|
||||
patterns.append(b"listening for connections")
|
||||
|
||||
self.args = args
|
||||
self.cmd = str(ECHO_PATH)
|
||||
self.patterns = patterns
|
||||
self.bytes_read = bytearray()
|
||||
self.event_ready = trio.Event()
|
||||
|
||||
self.port = port
|
||||
self._peer_info = None
|
||||
self.regex_pat = re.compile(br"I am ([\w\./]+)")
|
||||
|
||||
@property
|
||||
def peer_info(self) -> None:
|
||||
if self._peer_info is not None:
|
||||
return self._peer_info
|
||||
if not self.event_ready.is_set():
|
||||
raise Exception("process is not ready yet. failed to parse the peer info")
|
||||
# Example:
|
||||
# b"I am /ip4/127.0.0.1/tcp/56171/ipfs/QmU41TRPs34WWqa1brJEojBLYZKrrBcJq9nyNfVvSrbZUJ\n"
|
||||
m = re.search(br"I am ([\w\./]+)", self.bytes_read)
|
||||
if m is None:
|
||||
raise Exception("failed to find the pattern for the listening multiaddr")
|
||||
maddr_bytes_str_ipfs = m.group(1)
|
||||
maddr_str = maddr_bytes_str_ipfs.decode().replace("ipfs", "p2p")
|
||||
maddr = Multiaddr(maddr_str)
|
||||
self._peer_info = info_from_p2p_addr(maddr)
|
||||
return self._peer_info
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_hosts", (1,))
|
||||
@pytest.mark.asyncio
|
||||
async def test_insecure_conn_py_to_go(
|
||||
hosts, proc_factory, is_host_secure, unused_tcp_port
|
||||
):
|
||||
go_proc, go_pinfo = await make_echo_proc(
|
||||
proc_factory, unused_tcp_port, is_host_secure
|
||||
)
|
||||
@pytest.mark.trio
|
||||
async def test_insecure_conn_py_to_go(is_host_secure):
|
||||
async with HostFactory.create_batch_and_listen(is_host_secure, 1) as hosts:
|
||||
go_proc = EchoProcess(get_unused_tcp_port(), is_host_secure)
|
||||
await go_proc.start()
|
||||
|
||||
host = hosts[0]
|
||||
await host.connect(go_pinfo)
|
||||
await go_proc.expect("swarm listener accepted connection", async_=True)
|
||||
s = await host.new_stream(go_pinfo.peer_id, [ECHO_PROTOCOL_ID])
|
||||
|
||||
await go_proc.expect("Got a new stream!", async_=True)
|
||||
data = "data321123\n"
|
||||
await s.write(data.encode())
|
||||
await go_proc.expect(f"read: {data[:-1]}", async_=True)
|
||||
echoed_resp = await s.read(len(data))
|
||||
assert echoed_resp.decode() == data
|
||||
await s.close()
|
||||
host = hosts[0]
|
||||
peer_info = go_proc.peer_info
|
||||
await host.connect(peer_info)
|
||||
s = await host.new_stream(peer_info.peer_id, [ECHO_PROTOCOL_ID])
|
||||
data = "data321123\n"
|
||||
await s.write(data.encode())
|
||||
echoed_resp = await s.read(len(data))
|
||||
assert echoed_resp.decode() == data
|
||||
await s.close()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_hosts", (1,))
|
||||
@pytest.mark.asyncio
|
||||
async def test_insecure_conn_go_to_py(
|
||||
hosts, proc_factory, is_host_secure, unused_tcp_port
|
||||
):
|
||||
host = hosts[0]
|
||||
expected_data = "Hello, world!\n"
|
||||
reply_data = "Replyooo!\n"
|
||||
event_handler_finished = asyncio.Event()
|
||||
@pytest.mark.trio
|
||||
async def test_insecure_conn_go_to_py(is_host_secure):
|
||||
async with HostFactory.create_batch_and_listen(is_host_secure, 1) as hosts:
|
||||
host = hosts[0]
|
||||
expected_data = "Hello, world!\n"
|
||||
reply_data = "Replyooo!\n"
|
||||
event_handler_finished = trio.Event()
|
||||
|
||||
async def _handle_echo(stream):
|
||||
read_data = await stream.read(len(expected_data))
|
||||
assert read_data == expected_data.encode()
|
||||
event_handler_finished.set()
|
||||
await stream.write(reply_data.encode())
|
||||
await stream.close()
|
||||
async def _handle_echo(stream):
|
||||
read_data = await stream.read(len(expected_data))
|
||||
assert read_data == expected_data.encode()
|
||||
event_handler_finished.set()
|
||||
await stream.write(reply_data.encode())
|
||||
await stream.close()
|
||||
|
||||
host.set_stream_handler(ECHO_PROTOCOL_ID, _handle_echo)
|
||||
py_maddr = host.get_addrs()[0]
|
||||
go_proc, _ = await make_echo_proc(
|
||||
proc_factory, unused_tcp_port, is_host_secure, py_maddr
|
||||
)
|
||||
await go_proc.expect("connect with peer", async_=True)
|
||||
await go_proc.expect("opened stream", async_=True)
|
||||
await event_handler_finished.wait()
|
||||
await go_proc.expect(f"read reply: .*{reply_data.rstrip()}.*", async_=True)
|
||||
host.set_stream_handler(ECHO_PROTOCOL_ID, _handle_echo)
|
||||
py_maddr = host.get_addrs()[0]
|
||||
go_proc = EchoProcess(get_unused_tcp_port(), is_host_secure, py_maddr)
|
||||
await go_proc.start()
|
||||
await event_handler_finished.wait()
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
import asyncio
|
||||
|
||||
import pytest
|
||||
import trio
|
||||
|
||||
from libp2p.network.stream.exceptions import StreamClosed, StreamEOF, StreamReset
|
||||
from libp2p.tools.constants import MAX_READ_LEN
|
||||
|
@ -8,7 +7,7 @@ from libp2p.tools.constants import MAX_READ_LEN
|
|||
DATA = b"data"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.trio
|
||||
async def test_net_stream_read_write(py_to_daemon_stream_pair, p2pds):
|
||||
stream_py, stream_daemon = py_to_daemon_stream_pair
|
||||
assert (
|
||||
|
@ -19,19 +18,19 @@ async def test_net_stream_read_write(py_to_daemon_stream_pair, p2pds):
|
|||
assert (await stream_daemon.read(MAX_READ_LEN)) == DATA
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.trio
|
||||
async def test_net_stream_read_after_remote_closed(py_to_daemon_stream_pair, p2pds):
|
||||
stream_py, stream_daemon = py_to_daemon_stream_pair
|
||||
await stream_daemon.write(DATA)
|
||||
await stream_daemon.close()
|
||||
await asyncio.sleep(0.01)
|
||||
await trio.sleep(0.01)
|
||||
assert (await stream_py.read(MAX_READ_LEN)) == DATA
|
||||
# EOF
|
||||
with pytest.raises(StreamEOF):
|
||||
await stream_py.read(MAX_READ_LEN)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.trio
|
||||
async def test_net_stream_read_after_local_reset(py_to_daemon_stream_pair, p2pds):
|
||||
stream_py, _ = py_to_daemon_stream_pair
|
||||
await stream_py.reset()
|
||||
|
@ -40,15 +39,15 @@ async def test_net_stream_read_after_local_reset(py_to_daemon_stream_pair, p2pds
|
|||
|
||||
|
||||
@pytest.mark.parametrize("is_to_fail_daemon_stream", (True,))
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.trio
|
||||
async def test_net_stream_read_after_remote_reset(py_to_daemon_stream_pair, p2pds):
|
||||
stream_py, _ = py_to_daemon_stream_pair
|
||||
await asyncio.sleep(0.01)
|
||||
await trio.sleep(0.01)
|
||||
with pytest.raises(StreamReset):
|
||||
await stream_py.read(MAX_READ_LEN)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.trio
|
||||
async def test_net_stream_write_after_local_closed(py_to_daemon_stream_pair, p2pds):
|
||||
stream_py, _ = py_to_daemon_stream_pair
|
||||
await stream_py.write(DATA)
|
||||
|
@ -57,7 +56,7 @@ async def test_net_stream_write_after_local_closed(py_to_daemon_stream_pair, p2p
|
|||
await stream_py.write(DATA)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.trio
|
||||
async def test_net_stream_write_after_local_reset(py_to_daemon_stream_pair, p2pds):
|
||||
stream_py, stream_daemon = py_to_daemon_stream_pair
|
||||
await stream_py.reset()
|
||||
|
@ -66,9 +65,9 @@ async def test_net_stream_write_after_local_reset(py_to_daemon_stream_pair, p2pd
|
|||
|
||||
|
||||
@pytest.mark.parametrize("is_to_fail_daemon_stream", (True,))
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.trio
|
||||
async def test_net_stream_write_after_remote_reset(py_to_daemon_stream_pair, p2pds):
|
||||
stream_py, _ = py_to_daemon_stream_pair
|
||||
await asyncio.sleep(0.01)
|
||||
await trio.sleep(0.01)
|
||||
with pytest.raises(StreamClosed):
|
||||
await stream_py.write(DATA)
|
||||
|
|
|
@ -1,11 +1,15 @@
|
|||
import asyncio
|
||||
import functools
|
||||
import math
|
||||
|
||||
from p2pclient.pb import p2pd_pb2
|
||||
import pytest
|
||||
import trio
|
||||
|
||||
from libp2p.io.trio import TrioTCPStream
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.pubsub.pb import rpc_pb2
|
||||
from libp2p.pubsub.subscription import TrioSubscriptionAPI
|
||||
from libp2p.tools.factories import PubsubFactory
|
||||
from libp2p.tools.interop.utils import connect
|
||||
from libp2p.utils import read_varint_prefixed_bytes
|
||||
|
||||
|
@ -13,26 +17,15 @@ TOPIC_0 = "ABALA"
|
|||
TOPIC_1 = "YOOOO"
|
||||
|
||||
|
||||
async def p2pd_subscribe(p2pd, topic) -> "asyncio.Queue[rpc_pb2.Message]":
|
||||
reader, writer = await p2pd.control.pubsub_subscribe(topic)
|
||||
async def p2pd_subscribe(p2pd, topic, nursery):
|
||||
stream = TrioTCPStream(await p2pd.control.pubsub_subscribe(topic))
|
||||
send_channel, receive_channel = trio.open_memory_channel(math.inf)
|
||||
|
||||
queue = asyncio.Queue()
|
||||
sub = TrioSubscriptionAPI(receive_channel, unsubscribe_fn=stream.close)
|
||||
|
||||
async def _read_pubsub_msg() -> None:
|
||||
writer_closed_task = asyncio.ensure_future(writer.wait_closed())
|
||||
|
||||
while True:
|
||||
done, pending = await asyncio.wait(
|
||||
[read_varint_prefixed_bytes(reader), writer_closed_task],
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
)
|
||||
done_tasks = tuple(done)
|
||||
if writer.is_closing():
|
||||
return
|
||||
read_task = done_tasks[0]
|
||||
# Sanity check
|
||||
assert read_task._coro.__name__ == "read_varint_prefixed_bytes"
|
||||
msg_bytes = read_task.result()
|
||||
msg_bytes = await read_varint_prefixed_bytes(stream)
|
||||
ps_msg = p2pd_pb2.PSMessage()
|
||||
ps_msg.ParseFromString(msg_bytes)
|
||||
# Fill in the message used in py-libp2p
|
||||
|
@ -44,11 +37,10 @@ async def p2pd_subscribe(p2pd, topic) -> "asyncio.Queue[rpc_pb2.Message]":
|
|||
signature=ps_msg.signature,
|
||||
key=ps_msg.key,
|
||||
)
|
||||
queue.put_nowait(msg)
|
||||
await send_channel.send(msg)
|
||||
|
||||
asyncio.ensure_future(_read_pubsub_msg())
|
||||
await asyncio.sleep(0)
|
||||
return queue
|
||||
nursery.start_soon(_read_pubsub_msg)
|
||||
return sub
|
||||
|
||||
|
||||
def validate_pubsub_msg(msg: rpc_pb2.Message, data: bytes, from_peer_id: ID) -> None:
|
||||
|
@ -59,108 +51,119 @@ def validate_pubsub_msg(msg: rpc_pb2.Message, data: bytes, from_peer_id: ID) ->
|
|||
"is_pubsub_signing, is_pubsub_signing_strict", ((True, True), (False, False))
|
||||
)
|
||||
@pytest.mark.parametrize("is_gossipsub", (True, False))
|
||||
@pytest.mark.parametrize("num_hosts, num_p2pds", ((1, 2),))
|
||||
@pytest.mark.asyncio
|
||||
async def test_pubsub(pubsubs, p2pds):
|
||||
#
|
||||
# Test: Recognize pubsub peers on connection.
|
||||
#
|
||||
py_pubsub = pubsubs[0]
|
||||
# go0 <-> py <-> go1
|
||||
await connect(p2pds[0], py_pubsub.host)
|
||||
await connect(py_pubsub.host, p2pds[1])
|
||||
py_peer_id = py_pubsub.host.get_id()
|
||||
# Check pubsub peers
|
||||
pubsub_peers_0 = await p2pds[0].control.pubsub_list_peers("")
|
||||
assert len(pubsub_peers_0) == 1 and pubsub_peers_0[0] == py_peer_id
|
||||
pubsub_peers_1 = await p2pds[1].control.pubsub_list_peers("")
|
||||
assert len(pubsub_peers_1) == 1 and pubsub_peers_1[0] == py_peer_id
|
||||
assert (
|
||||
len(py_pubsub.peers) == 2
|
||||
and p2pds[0].peer_id in py_pubsub.peers
|
||||
and p2pds[1].peer_id in py_pubsub.peers
|
||||
)
|
||||
@pytest.mark.parametrize("num_p2pds", (2,))
|
||||
@pytest.mark.trio
|
||||
async def test_pubsub(
|
||||
p2pds, is_gossipsub, is_host_secure, is_pubsub_signing_strict, nursery
|
||||
):
|
||||
pubsub_factory = None
|
||||
if is_gossipsub:
|
||||
pubsub_factory = PubsubFactory.create_batch_with_gossipsub
|
||||
else:
|
||||
pubsub_factory = PubsubFactory.create_batch_with_floodsub
|
||||
|
||||
#
|
||||
# Test: `subscribe`.
|
||||
#
|
||||
# (name, topics)
|
||||
# (go_0, [0, 1]) <-> (py, [0, 1]) <-> (go_1, [1])
|
||||
sub_py_topic_0 = await py_pubsub.subscribe(TOPIC_0)
|
||||
sub_py_topic_1 = await py_pubsub.subscribe(TOPIC_1)
|
||||
sub_go_0_topic_0 = await p2pd_subscribe(p2pds[0], TOPIC_0)
|
||||
sub_go_0_topic_1 = await p2pd_subscribe(p2pds[0], TOPIC_1)
|
||||
sub_go_1_topic_1 = await p2pd_subscribe(p2pds[1], TOPIC_1)
|
||||
# Check topic peers
|
||||
await asyncio.sleep(0.1)
|
||||
# go_0
|
||||
go_0_topic_0_peers = await p2pds[0].control.pubsub_list_peers(TOPIC_0)
|
||||
assert len(go_0_topic_0_peers) == 1 and py_peer_id == go_0_topic_0_peers[0]
|
||||
go_0_topic_1_peers = await p2pds[0].control.pubsub_list_peers(TOPIC_1)
|
||||
assert len(go_0_topic_1_peers) == 1 and py_peer_id == go_0_topic_1_peers[0]
|
||||
# py
|
||||
py_topic_0_peers = list(py_pubsub.peer_topics[TOPIC_0])
|
||||
assert len(py_topic_0_peers) == 1 and p2pds[0].peer_id == py_topic_0_peers[0]
|
||||
# go_1
|
||||
go_1_topic_1_peers = await p2pds[1].control.pubsub_list_peers(TOPIC_1)
|
||||
assert len(go_1_topic_1_peers) == 1 and py_peer_id == go_1_topic_1_peers[0]
|
||||
async with pubsub_factory(
|
||||
1, is_secure=is_host_secure, strict_signing=is_pubsub_signing_strict
|
||||
) as pubsubs:
|
||||
#
|
||||
# Test: Recognize pubsub peers on connection.
|
||||
#
|
||||
py_pubsub = pubsubs[0]
|
||||
# go0 <-> py <-> go1
|
||||
await connect(p2pds[0], py_pubsub.host)
|
||||
await connect(py_pubsub.host, p2pds[1])
|
||||
py_peer_id = py_pubsub.host.get_id()
|
||||
# Check pubsub peers
|
||||
pubsub_peers_0 = await p2pds[0].control.pubsub_list_peers("")
|
||||
assert len(pubsub_peers_0) == 1 and pubsub_peers_0[0] == py_peer_id
|
||||
pubsub_peers_1 = await p2pds[1].control.pubsub_list_peers("")
|
||||
assert len(pubsub_peers_1) == 1 and pubsub_peers_1[0] == py_peer_id
|
||||
assert (
|
||||
len(py_pubsub.peers) == 2
|
||||
and p2pds[0].peer_id in py_pubsub.peers
|
||||
and p2pds[1].peer_id in py_pubsub.peers
|
||||
)
|
||||
|
||||
#
|
||||
# Test: `publish`
|
||||
#
|
||||
# 1. py publishes
|
||||
# - 1.1. py publishes data_11 to topic_0, py and go_0 receives.
|
||||
# - 1.2. py publishes data_12 to topic_1, all receive.
|
||||
# 2. go publishes
|
||||
# - 2.1. go_0 publishes data_21 to topic_0, py and go_0 receive.
|
||||
# - 2.2. go_1 publishes data_22 to topic_1, all receive.
|
||||
#
|
||||
# Test: `subscribe`.
|
||||
#
|
||||
# (name, topics)
|
||||
# (go_0, [0, 1]) <-> (py, [0, 1]) <-> (go_1, [1])
|
||||
sub_py_topic_0 = await py_pubsub.subscribe(TOPIC_0)
|
||||
sub_py_topic_1 = await py_pubsub.subscribe(TOPIC_1)
|
||||
sub_go_0_topic_0 = await p2pd_subscribe(p2pds[0], TOPIC_0, nursery)
|
||||
sub_go_0_topic_1 = await p2pd_subscribe(p2pds[0], TOPIC_1, nursery)
|
||||
sub_go_1_topic_1 = await p2pd_subscribe(p2pds[1], TOPIC_1, nursery)
|
||||
# Check topic peers
|
||||
await trio.sleep(0.1)
|
||||
# go_0
|
||||
go_0_topic_0_peers = await p2pds[0].control.pubsub_list_peers(TOPIC_0)
|
||||
assert len(go_0_topic_0_peers) == 1 and py_peer_id == go_0_topic_0_peers[0]
|
||||
go_0_topic_1_peers = await p2pds[0].control.pubsub_list_peers(TOPIC_1)
|
||||
assert len(go_0_topic_1_peers) == 1 and py_peer_id == go_0_topic_1_peers[0]
|
||||
# py
|
||||
py_topic_0_peers = list(py_pubsub.peer_topics[TOPIC_0])
|
||||
assert len(py_topic_0_peers) == 1 and p2pds[0].peer_id == py_topic_0_peers[0]
|
||||
# go_1
|
||||
go_1_topic_1_peers = await p2pds[1].control.pubsub_list_peers(TOPIC_1)
|
||||
assert len(go_1_topic_1_peers) == 1 and py_peer_id == go_1_topic_1_peers[0]
|
||||
|
||||
# 1.1. py publishes data_11 to topic_0, py and go_0 receives.
|
||||
data_11 = b"data_11"
|
||||
await py_pubsub.publish(TOPIC_0, data_11)
|
||||
validate_11 = functools.partial(
|
||||
validate_pubsub_msg, data=data_11, from_peer_id=py_peer_id
|
||||
)
|
||||
validate_11(await sub_py_topic_0.get())
|
||||
validate_11(await sub_go_0_topic_0.get())
|
||||
#
|
||||
# Test: `publish`
|
||||
#
|
||||
# 1. py publishes
|
||||
# - 1.1. py publishes data_11 to topic_0, py and go_0 receives.
|
||||
# - 1.2. py publishes data_12 to topic_1, all receive.
|
||||
# 2. go publishes
|
||||
# - 2.1. go_0 publishes data_21 to topic_0, py and go_0 receive.
|
||||
# - 2.2. go_1 publishes data_22 to topic_1, all receive.
|
||||
|
||||
# 1.2. py publishes data_12 to topic_1, all receive.
|
||||
data_12 = b"data_12"
|
||||
validate_12 = functools.partial(
|
||||
validate_pubsub_msg, data=data_12, from_peer_id=py_peer_id
|
||||
)
|
||||
await py_pubsub.publish(TOPIC_1, data_12)
|
||||
validate_12(await sub_py_topic_1.get())
|
||||
validate_12(await sub_go_0_topic_1.get())
|
||||
validate_12(await sub_go_1_topic_1.get())
|
||||
# 1.1. py publishes data_11 to topic_0, py and go_0 receives.
|
||||
data_11 = b"data_11"
|
||||
await py_pubsub.publish(TOPIC_0, data_11)
|
||||
validate_11 = functools.partial(
|
||||
validate_pubsub_msg, data=data_11, from_peer_id=py_peer_id
|
||||
)
|
||||
validate_11(await sub_py_topic_0.get())
|
||||
validate_11(await sub_go_0_topic_0.get())
|
||||
|
||||
# 2.1. go_0 publishes data_21 to topic_0, py and go_0 receive.
|
||||
data_21 = b"data_21"
|
||||
validate_21 = functools.partial(
|
||||
validate_pubsub_msg, data=data_21, from_peer_id=p2pds[0].peer_id
|
||||
)
|
||||
await p2pds[0].control.pubsub_publish(TOPIC_0, data_21)
|
||||
validate_21(await sub_py_topic_0.get())
|
||||
validate_21(await sub_go_0_topic_0.get())
|
||||
# 1.2. py publishes data_12 to topic_1, all receive.
|
||||
data_12 = b"data_12"
|
||||
validate_12 = functools.partial(
|
||||
validate_pubsub_msg, data=data_12, from_peer_id=py_peer_id
|
||||
)
|
||||
await py_pubsub.publish(TOPIC_1, data_12)
|
||||
validate_12(await sub_py_topic_1.get())
|
||||
validate_12(await sub_go_0_topic_1.get())
|
||||
validate_12(await sub_go_1_topic_1.get())
|
||||
|
||||
# 2.2. go_1 publishes data_22 to topic_1, all receive.
|
||||
data_22 = b"data_22"
|
||||
validate_22 = functools.partial(
|
||||
validate_pubsub_msg, data=data_22, from_peer_id=p2pds[1].peer_id
|
||||
)
|
||||
await p2pds[1].control.pubsub_publish(TOPIC_1, data_22)
|
||||
validate_22(await sub_py_topic_1.get())
|
||||
validate_22(await sub_go_0_topic_1.get())
|
||||
validate_22(await sub_go_1_topic_1.get())
|
||||
# 2.1. go_0 publishes data_21 to topic_0, py and go_0 receive.
|
||||
data_21 = b"data_21"
|
||||
validate_21 = functools.partial(
|
||||
validate_pubsub_msg, data=data_21, from_peer_id=p2pds[0].peer_id
|
||||
)
|
||||
await p2pds[0].control.pubsub_publish(TOPIC_0, data_21)
|
||||
validate_21(await sub_py_topic_0.get())
|
||||
validate_21(await sub_go_0_topic_0.get())
|
||||
|
||||
#
|
||||
# Test: `unsubscribe` and re`subscribe`
|
||||
#
|
||||
await py_pubsub.unsubscribe(TOPIC_0)
|
||||
await asyncio.sleep(0.1)
|
||||
assert py_peer_id not in (await p2pds[0].control.pubsub_list_peers(TOPIC_0))
|
||||
assert py_peer_id not in (await p2pds[1].control.pubsub_list_peers(TOPIC_0))
|
||||
await py_pubsub.subscribe(TOPIC_0)
|
||||
await asyncio.sleep(0.1)
|
||||
assert py_peer_id in (await p2pds[0].control.pubsub_list_peers(TOPIC_0))
|
||||
assert py_peer_id in (await p2pds[1].control.pubsub_list_peers(TOPIC_0))
|
||||
# 2.2. go_1 publishes data_22 to topic_1, all receive.
|
||||
data_22 = b"data_22"
|
||||
validate_22 = functools.partial(
|
||||
validate_pubsub_msg, data=data_22, from_peer_id=p2pds[1].peer_id
|
||||
)
|
||||
await p2pds[1].control.pubsub_publish(TOPIC_1, data_22)
|
||||
validate_22(await sub_py_topic_1.get())
|
||||
validate_22(await sub_go_0_topic_1.get())
|
||||
validate_22(await sub_go_1_topic_1.get())
|
||||
|
||||
#
|
||||
# Test: `unsubscribe` and re`subscribe`
|
||||
#
|
||||
await py_pubsub.unsubscribe(TOPIC_0)
|
||||
await trio.sleep(0.1)
|
||||
assert py_peer_id not in (await p2pds[0].control.pubsub_list_peers(TOPIC_0))
|
||||
assert py_peer_id not in (await p2pds[1].control.pubsub_list_peers(TOPIC_0))
|
||||
await py_pubsub.subscribe(TOPIC_0)
|
||||
await trio.sleep(0.1)
|
||||
assert py_peer_id in (await p2pds[0].control.pubsub_list_peers(TOPIC_0))
|
||||
assert py_peer_id in (await p2pds[1].control.pubsub_list_peers(TOPIC_0))
|
||||
|
|
3
tox.ini
3
tox.ini
|
@ -12,7 +12,7 @@ envlist =
|
|||
combine_as_imports=False
|
||||
force_sort_within_sections=True
|
||||
include_trailing_comma=True
|
||||
known_third_party=hypothesis,pytest,p2pclient,pexpect,factory,lru
|
||||
known_third_party=anyio,factory,lru,p2pclient,pytest
|
||||
known_first_party=libp2p
|
||||
line_length=88
|
||||
multi_line_output=3
|
||||
|
@ -58,7 +58,6 @@ commands =
|
|||
[testenv:py37-interop]
|
||||
deps =
|
||||
p2pclient
|
||||
pexpect
|
||||
passenv = CI TRAVIS TRAVIS_* GOPATH
|
||||
extras = test
|
||||
commands =
|
||||
|
|
Loading…
Reference in New Issue