Add test for Swarm.close_peer
This commit is contained in:
parent
6923f257f6
commit
e7304538da
|
@ -23,6 +23,10 @@ from .host_interface import IHost
|
|||
|
||||
|
||||
class BasicHost(IHost):
|
||||
"""
|
||||
BasicHost is a wrapper of a `INetwork` implementation. It performs protocol negotiation
|
||||
on a stream with multistream-select right after a stream is initialized.
|
||||
"""
|
||||
|
||||
_network: INetwork
|
||||
_router: KadmeliaPeerRouter
|
||||
|
@ -31,7 +35,6 @@ class BasicHost(IHost):
|
|||
multiselect: Multiselect
|
||||
multiselect_client: MultiselectClient
|
||||
|
||||
# default options constructor
|
||||
def __init__(self, network: INetwork, router: KadmeliaPeerRouter = None) -> None:
|
||||
self._network = network
|
||||
self._network.set_stream_handler(self._swarm_stream_handler)
|
||||
|
@ -69,6 +72,7 @@ class BasicHost(IHost):
|
|||
"""
|
||||
:return: all the multiaddr addresses this host is listening to
|
||||
"""
|
||||
# TODO: We don't need "/p2p/{peer_id}" postfix actually.
|
||||
p2p_part = multiaddr.Multiaddr("/p2p/{}".format(self.get_id().pretty()))
|
||||
|
||||
addrs: List[multiaddr.Multiaddr] = []
|
||||
|
@ -87,8 +91,6 @@ class BasicHost(IHost):
|
|||
"""
|
||||
self.multiselect.add_handler(protocol_id, stream_handler)
|
||||
|
||||
# `protocol_ids` can be a list of `protocol_id`
|
||||
# stream will decide which `protocol_id` to run on
|
||||
async def new_stream(
|
||||
self, peer_id: ID, protocol_ids: Sequence[TProtocol]
|
||||
) -> INetStream:
|
||||
|
|
|
@ -50,11 +50,12 @@ class SwarmConn(INetConn):
|
|||
task.cancel()
|
||||
|
||||
async def _handle_new_streams(self) -> None:
|
||||
# TODO: Break the loop when anything wrong in the connection.
|
||||
while True:
|
||||
try:
|
||||
stream = await self.conn.accept_stream()
|
||||
except MuxedConnUnavailable:
|
||||
# If there is anything wrong in the MuxedConn,
|
||||
# we should break the loop and close the connection.
|
||||
break
|
||||
# Asynchronously handle the accepted stream, to avoid blocking the next stream.
|
||||
await self.run_task(self._handle_muxed_stream(stream))
|
||||
|
|
|
@ -8,13 +8,13 @@ from libp2p.crypto.keys import KeyPair
|
|||
from libp2p.host.basic_host import BasicHost
|
||||
from libp2p.host.host_interface import IHost
|
||||
from libp2p.network.stream.net_stream_interface import INetStream
|
||||
from libp2p.network.swarm import Swarm
|
||||
from libp2p.pubsub.floodsub import FloodSub
|
||||
from libp2p.pubsub.gossipsub import GossipSub
|
||||
from libp2p.pubsub.pubsub import Pubsub
|
||||
from libp2p.security.base_transport import BaseSecureTransport
|
||||
from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport
|
||||
import libp2p.security.secio.transport as secio
|
||||
from libp2p.stream_muxer.mplex.mplex import Mplex
|
||||
from libp2p.typing import TProtocol
|
||||
from tests.configs import LISTEN_MADDR
|
||||
from tests.pubsub.configs import (
|
||||
|
@ -22,7 +22,7 @@ from tests.pubsub.configs import (
|
|||
GOSSIPSUB_PARAMS,
|
||||
GOSSIPSUB_PROTOCOL_ID,
|
||||
)
|
||||
from tests.utils import connect
|
||||
from tests.utils import connect, connect_swarm
|
||||
|
||||
|
||||
def security_transport_factory(
|
||||
|
@ -34,10 +34,29 @@ def security_transport_factory(
|
|||
return {secio.ID: secio.Transport(key_pair)}
|
||||
|
||||
|
||||
def swarm_factory(is_secure: bool):
|
||||
key_pair = generate_new_rsa_identity()
|
||||
sec_opt = security_transport_factory(is_secure, key_pair)
|
||||
return initialize_default_swarm(key_pair, sec_opt=sec_opt)
|
||||
class SwarmFactory(factory.Factory):
|
||||
class Meta:
|
||||
model = Swarm
|
||||
|
||||
@classmethod
|
||||
def _create(cls, is_secure=False):
|
||||
key_pair = generate_new_rsa_identity()
|
||||
sec_opt = security_transport_factory(is_secure, key_pair)
|
||||
return initialize_default_swarm(key_pair, sec_opt=sec_opt)
|
||||
|
||||
@classmethod
|
||||
async def create_and_listen(cls, is_secure: bool) -> Swarm:
|
||||
swarm = cls._create(is_secure)
|
||||
await swarm.listen(LISTEN_MADDR)
|
||||
return swarm
|
||||
|
||||
@classmethod
|
||||
async def create_batch_and_listen(
|
||||
cls, is_secure: bool, number: int
|
||||
) -> Tuple[Swarm, ...]:
|
||||
return await asyncio.gather(
|
||||
*[cls.create_and_listen(is_secure) for _ in range(number)]
|
||||
)
|
||||
|
||||
|
||||
class HostFactory(factory.Factory):
|
||||
|
@ -47,13 +66,12 @@ class HostFactory(factory.Factory):
|
|||
class Params:
|
||||
is_secure = False
|
||||
|
||||
network = factory.LazyAttribute(lambda o: swarm_factory(o.is_secure))
|
||||
network = factory.LazyAttribute(lambda o: SwarmFactory(o.is_secure))
|
||||
|
||||
@classmethod
|
||||
async def create_and_listen(cls) -> IHost:
|
||||
host = cls()
|
||||
await host.get_network().listen(LISTEN_MADDR)
|
||||
return host
|
||||
async def create_and_listen(cls, is_secure: bool) -> IHost:
|
||||
swarm = await SwarmFactory.create_and_listen(is_secure)
|
||||
return BasicHost(swarm)
|
||||
|
||||
|
||||
class FloodsubFactory(factory.Factory):
|
||||
|
@ -87,24 +105,33 @@ class PubsubFactory(factory.Factory):
|
|||
cache_size = None
|
||||
|
||||
|
||||
async def host_pair_factory() -> Tuple[BasicHost, BasicHost]:
|
||||
async def swarm_pair_factory(is_secure: bool) -> Tuple[Swarm, Swarm]:
|
||||
swarms = await SwarmFactory.create_batch_and_listen(2)
|
||||
await connect_swarm(swarms[0], swarms[1])
|
||||
return swarms[0], swarms[1]
|
||||
|
||||
|
||||
async def host_pair_factory(is_secure) -> Tuple[BasicHost, BasicHost]:
|
||||
hosts = await asyncio.gather(
|
||||
*[HostFactory.create_and_listen(), HostFactory.create_and_listen()]
|
||||
*[
|
||||
HostFactory.create_and_listen(is_secure),
|
||||
HostFactory.create_and_listen(is_secure),
|
||||
]
|
||||
)
|
||||
await connect(hosts[0], hosts[1])
|
||||
return hosts[0], hosts[1]
|
||||
|
||||
|
||||
async def connection_pair_factory() -> Tuple[Mplex, BasicHost, Mplex, BasicHost]:
|
||||
host_0, host_1 = await host_pair_factory()
|
||||
mplex_conn_0 = host_0.get_network().connections[host_1.get_id()]
|
||||
mplex_conn_1 = host_1.get_network().connections[host_0.get_id()]
|
||||
return mplex_conn_0, host_0, mplex_conn_1, host_1
|
||||
# async def connection_pair_factory() -> Tuple[Mplex, BasicHost, Mplex, BasicHost]:
|
||||
# host_0, host_1 = await host_pair_factory()
|
||||
# mplex_conn_0 = host_0.get_network().connections[host_1.get_id()]
|
||||
# mplex_conn_1 = host_1.get_network().connections[host_0.get_id()]
|
||||
# return mplex_conn_0, host_0, mplex_conn_1, host_1
|
||||
|
||||
|
||||
async def net_stream_pair_factory() -> Tuple[
|
||||
INetStream, BasicHost, INetStream, BasicHost
|
||||
]:
|
||||
async def net_stream_pair_factory(
|
||||
is_secure: bool
|
||||
) -> Tuple[INetStream, BasicHost, INetStream, BasicHost]:
|
||||
protocol_id = "/example/id/1"
|
||||
|
||||
stream_1: INetStream
|
||||
|
@ -114,7 +141,7 @@ async def net_stream_pair_factory() -> Tuple[
|
|||
nonlocal stream_1
|
||||
stream_1 = stream
|
||||
|
||||
host_0, host_1 = await host_pair_factory()
|
||||
host_0, host_1 = await host_pair_factory(is_secure)
|
||||
host_1.set_stream_handler(protocol_id, handler)
|
||||
|
||||
stream_0 = await host_0.new_stream(host_1.get_id(), [protocol_id])
|
||||
|
|
|
@ -2,13 +2,22 @@ import asyncio
|
|||
|
||||
import pytest
|
||||
|
||||
from tests.factories import net_stream_pair_factory
|
||||
from tests.factories import net_stream_pair_factory, swarm_pair_factory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def net_stream_pair():
|
||||
stream_0, host_0, stream_1, host_1 = await net_stream_pair_factory()
|
||||
async def net_stream_pair(is_host_secure):
|
||||
stream_0, host_0, stream_1, host_1 = await net_stream_pair_factory(is_host_secure)
|
||||
try:
|
||||
yield stream_0, stream_1
|
||||
finally:
|
||||
await asyncio.gather(*[host_0.close(), host_1.close()])
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def swarm_pair(is_host_secure):
|
||||
swarm_0, swarm_1 = await swarm_pair_factory(is_host_secure)
|
||||
try:
|
||||
yield swarm_0, swarm_1
|
||||
finally:
|
||||
await asyncio.gather(*[swarm_0.close(), swarm_1.close()])
|
||||
|
|
49
tests/network/test_swarm.py
Normal file
49
tests/network/test_swarm.py
Normal file
|
@ -0,0 +1,49 @@
|
|||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.factories import SwarmFactory
|
||||
from tests.utils import connect_swarm
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_swarm_close_peer(is_host_secure):
|
||||
swarms = await SwarmFactory.create_batch_and_listen(is_host_secure, 3)
|
||||
# 0 <> 1 <> 2
|
||||
await connect_swarm(swarms[0], swarms[1])
|
||||
await connect_swarm(swarms[1], swarms[2])
|
||||
|
||||
# peer 1 closes peer 0
|
||||
await swarms[1].close_peer(swarms[0].get_peer_id())
|
||||
await asyncio.sleep(0.01)
|
||||
# 0 1 <> 2
|
||||
assert len(swarms[0].connections) == 0
|
||||
assert (
|
||||
len(swarms[1].connections) == 1
|
||||
and swarms[2].get_peer_id() in swarms[1].connections
|
||||
)
|
||||
|
||||
# peer 1 is closed by peer 2
|
||||
await swarms[2].close_peer(swarms[1].get_peer_id())
|
||||
await asyncio.sleep(0.01)
|
||||
# 0 1 2
|
||||
assert len(swarms[1].connections) == 0 and len(swarms[2].connections) == 0
|
||||
|
||||
await connect_swarm(swarms[0], swarms[1])
|
||||
# 0 <> 1 2
|
||||
assert (
|
||||
len(swarms[0].connections) == 1
|
||||
and swarms[1].get_peer_id() in swarms[0].connections
|
||||
)
|
||||
assert (
|
||||
len(swarms[1].connections) == 1
|
||||
and swarms[0].get_peer_id() in swarms[1].connections
|
||||
)
|
||||
# peer 0 closes peer 1
|
||||
await swarms[0].close_peer(swarms[1].get_peer_id())
|
||||
await asyncio.sleep(0.01)
|
||||
# 0 1 2
|
||||
assert len(swarms[1].connections) == 0 and len(swarms[2].connections) == 0
|
||||
|
||||
# Clean up
|
||||
await asyncio.gather(*[swarm.close() for swarm in swarms])
|
|
@ -5,6 +5,19 @@ from libp2p.peer.peerinfo import info_from_p2p_addr
|
|||
from tests.constants import MAX_READ_LEN
|
||||
|
||||
|
||||
async def connect_swarm(swarm_0, swarm_1):
|
||||
peer_id = swarm_1.get_peer_id()
|
||||
addrs = tuple(
|
||||
addr
|
||||
for transport in swarm_1.listeners.values()
|
||||
for addr in transport.get_addrs()
|
||||
)
|
||||
swarm_0.peerstore.add_addrs(peer_id, addrs, 10000)
|
||||
await swarm_0.dial_peer(peer_id)
|
||||
assert swarm_0.get_peer_id() in swarm_1.connections
|
||||
assert swarm_1.get_peer_id() in swarm_0.connections
|
||||
|
||||
|
||||
async def connect(node1, node2):
|
||||
"""
|
||||
Connect node1 to node2
|
||||
|
|
Loading…
Reference in New Issue
Block a user