Add test for Swarm.close_peer
This commit is contained in:
parent
6923f257f6
commit
e7304538da
|
@ -23,6 +23,10 @@ from .host_interface import IHost
|
||||||
|
|
||||||
|
|
||||||
class BasicHost(IHost):
|
class BasicHost(IHost):
|
||||||
|
"""
|
||||||
|
BasicHost is a wrapper of a `INetwork` implementation. It performs protocol negotiation
|
||||||
|
on a stream with multistream-select right after a stream is initialized.
|
||||||
|
"""
|
||||||
|
|
||||||
_network: INetwork
|
_network: INetwork
|
||||||
_router: KadmeliaPeerRouter
|
_router: KadmeliaPeerRouter
|
||||||
|
@ -31,7 +35,6 @@ class BasicHost(IHost):
|
||||||
multiselect: Multiselect
|
multiselect: Multiselect
|
||||||
multiselect_client: MultiselectClient
|
multiselect_client: MultiselectClient
|
||||||
|
|
||||||
# default options constructor
|
|
||||||
def __init__(self, network: INetwork, router: KadmeliaPeerRouter = None) -> None:
|
def __init__(self, network: INetwork, router: KadmeliaPeerRouter = None) -> None:
|
||||||
self._network = network
|
self._network = network
|
||||||
self._network.set_stream_handler(self._swarm_stream_handler)
|
self._network.set_stream_handler(self._swarm_stream_handler)
|
||||||
|
@ -69,6 +72,7 @@ class BasicHost(IHost):
|
||||||
"""
|
"""
|
||||||
:return: all the multiaddr addresses this host is listening to
|
:return: all the multiaddr addresses this host is listening to
|
||||||
"""
|
"""
|
||||||
|
# TODO: We don't need "/p2p/{peer_id}" postfix actually.
|
||||||
p2p_part = multiaddr.Multiaddr("/p2p/{}".format(self.get_id().pretty()))
|
p2p_part = multiaddr.Multiaddr("/p2p/{}".format(self.get_id().pretty()))
|
||||||
|
|
||||||
addrs: List[multiaddr.Multiaddr] = []
|
addrs: List[multiaddr.Multiaddr] = []
|
||||||
|
@ -87,8 +91,6 @@ class BasicHost(IHost):
|
||||||
"""
|
"""
|
||||||
self.multiselect.add_handler(protocol_id, stream_handler)
|
self.multiselect.add_handler(protocol_id, stream_handler)
|
||||||
|
|
||||||
# `protocol_ids` can be a list of `protocol_id`
|
|
||||||
# stream will decide which `protocol_id` to run on
|
|
||||||
async def new_stream(
|
async def new_stream(
|
||||||
self, peer_id: ID, protocol_ids: Sequence[TProtocol]
|
self, peer_id: ID, protocol_ids: Sequence[TProtocol]
|
||||||
) -> INetStream:
|
) -> INetStream:
|
||||||
|
|
|
@ -50,11 +50,12 @@ class SwarmConn(INetConn):
|
||||||
task.cancel()
|
task.cancel()
|
||||||
|
|
||||||
async def _handle_new_streams(self) -> None:
|
async def _handle_new_streams(self) -> None:
|
||||||
# TODO: Break the loop when anything wrong in the connection.
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
stream = await self.conn.accept_stream()
|
stream = await self.conn.accept_stream()
|
||||||
except MuxedConnUnavailable:
|
except MuxedConnUnavailable:
|
||||||
|
# If there is anything wrong in the MuxedConn,
|
||||||
|
# we should break the loop and close the connection.
|
||||||
break
|
break
|
||||||
# Asynchronously handle the accepted stream, to avoid blocking the next stream.
|
# Asynchronously handle the accepted stream, to avoid blocking the next stream.
|
||||||
await self.run_task(self._handle_muxed_stream(stream))
|
await self.run_task(self._handle_muxed_stream(stream))
|
||||||
|
|
|
@ -8,13 +8,13 @@ from libp2p.crypto.keys import KeyPair
|
||||||
from libp2p.host.basic_host import BasicHost
|
from libp2p.host.basic_host import BasicHost
|
||||||
from libp2p.host.host_interface import IHost
|
from libp2p.host.host_interface import IHost
|
||||||
from libp2p.network.stream.net_stream_interface import INetStream
|
from libp2p.network.stream.net_stream_interface import INetStream
|
||||||
|
from libp2p.network.swarm import Swarm
|
||||||
from libp2p.pubsub.floodsub import FloodSub
|
from libp2p.pubsub.floodsub import FloodSub
|
||||||
from libp2p.pubsub.gossipsub import GossipSub
|
from libp2p.pubsub.gossipsub import GossipSub
|
||||||
from libp2p.pubsub.pubsub import Pubsub
|
from libp2p.pubsub.pubsub import Pubsub
|
||||||
from libp2p.security.base_transport import BaseSecureTransport
|
from libp2p.security.base_transport import BaseSecureTransport
|
||||||
from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport
|
from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport
|
||||||
import libp2p.security.secio.transport as secio
|
import libp2p.security.secio.transport as secio
|
||||||
from libp2p.stream_muxer.mplex.mplex import Mplex
|
|
||||||
from libp2p.typing import TProtocol
|
from libp2p.typing import TProtocol
|
||||||
from tests.configs import LISTEN_MADDR
|
from tests.configs import LISTEN_MADDR
|
||||||
from tests.pubsub.configs import (
|
from tests.pubsub.configs import (
|
||||||
|
@ -22,7 +22,7 @@ from tests.pubsub.configs import (
|
||||||
GOSSIPSUB_PARAMS,
|
GOSSIPSUB_PARAMS,
|
||||||
GOSSIPSUB_PROTOCOL_ID,
|
GOSSIPSUB_PROTOCOL_ID,
|
||||||
)
|
)
|
||||||
from tests.utils import connect
|
from tests.utils import connect, connect_swarm
|
||||||
|
|
||||||
|
|
||||||
def security_transport_factory(
|
def security_transport_factory(
|
||||||
|
@ -34,11 +34,30 @@ def security_transport_factory(
|
||||||
return {secio.ID: secio.Transport(key_pair)}
|
return {secio.ID: secio.Transport(key_pair)}
|
||||||
|
|
||||||
|
|
||||||
def swarm_factory(is_secure: bool):
|
class SwarmFactory(factory.Factory):
|
||||||
|
class Meta:
|
||||||
|
model = Swarm
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _create(cls, is_secure=False):
|
||||||
key_pair = generate_new_rsa_identity()
|
key_pair = generate_new_rsa_identity()
|
||||||
sec_opt = security_transport_factory(is_secure, key_pair)
|
sec_opt = security_transport_factory(is_secure, key_pair)
|
||||||
return initialize_default_swarm(key_pair, sec_opt=sec_opt)
|
return initialize_default_swarm(key_pair, sec_opt=sec_opt)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def create_and_listen(cls, is_secure: bool) -> Swarm:
|
||||||
|
swarm = cls._create(is_secure)
|
||||||
|
await swarm.listen(LISTEN_MADDR)
|
||||||
|
return swarm
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def create_batch_and_listen(
|
||||||
|
cls, is_secure: bool, number: int
|
||||||
|
) -> Tuple[Swarm, ...]:
|
||||||
|
return await asyncio.gather(
|
||||||
|
*[cls.create_and_listen(is_secure) for _ in range(number)]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class HostFactory(factory.Factory):
|
class HostFactory(factory.Factory):
|
||||||
class Meta:
|
class Meta:
|
||||||
|
@ -47,13 +66,12 @@ class HostFactory(factory.Factory):
|
||||||
class Params:
|
class Params:
|
||||||
is_secure = False
|
is_secure = False
|
||||||
|
|
||||||
network = factory.LazyAttribute(lambda o: swarm_factory(o.is_secure))
|
network = factory.LazyAttribute(lambda o: SwarmFactory(o.is_secure))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def create_and_listen(cls) -> IHost:
|
async def create_and_listen(cls, is_secure: bool) -> IHost:
|
||||||
host = cls()
|
swarm = await SwarmFactory.create_and_listen(is_secure)
|
||||||
await host.get_network().listen(LISTEN_MADDR)
|
return BasicHost(swarm)
|
||||||
return host
|
|
||||||
|
|
||||||
|
|
||||||
class FloodsubFactory(factory.Factory):
|
class FloodsubFactory(factory.Factory):
|
||||||
|
@ -87,24 +105,33 @@ class PubsubFactory(factory.Factory):
|
||||||
cache_size = None
|
cache_size = None
|
||||||
|
|
||||||
|
|
||||||
async def host_pair_factory() -> Tuple[BasicHost, BasicHost]:
|
async def swarm_pair_factory(is_secure: bool) -> Tuple[Swarm, Swarm]:
|
||||||
|
swarms = await SwarmFactory.create_batch_and_listen(2)
|
||||||
|
await connect_swarm(swarms[0], swarms[1])
|
||||||
|
return swarms[0], swarms[1]
|
||||||
|
|
||||||
|
|
||||||
|
async def host_pair_factory(is_secure) -> Tuple[BasicHost, BasicHost]:
|
||||||
hosts = await asyncio.gather(
|
hosts = await asyncio.gather(
|
||||||
*[HostFactory.create_and_listen(), HostFactory.create_and_listen()]
|
*[
|
||||||
|
HostFactory.create_and_listen(is_secure),
|
||||||
|
HostFactory.create_and_listen(is_secure),
|
||||||
|
]
|
||||||
)
|
)
|
||||||
await connect(hosts[0], hosts[1])
|
await connect(hosts[0], hosts[1])
|
||||||
return hosts[0], hosts[1]
|
return hosts[0], hosts[1]
|
||||||
|
|
||||||
|
|
||||||
async def connection_pair_factory() -> Tuple[Mplex, BasicHost, Mplex, BasicHost]:
|
# async def connection_pair_factory() -> Tuple[Mplex, BasicHost, Mplex, BasicHost]:
|
||||||
host_0, host_1 = await host_pair_factory()
|
# host_0, host_1 = await host_pair_factory()
|
||||||
mplex_conn_0 = host_0.get_network().connections[host_1.get_id()]
|
# mplex_conn_0 = host_0.get_network().connections[host_1.get_id()]
|
||||||
mplex_conn_1 = host_1.get_network().connections[host_0.get_id()]
|
# mplex_conn_1 = host_1.get_network().connections[host_0.get_id()]
|
||||||
return mplex_conn_0, host_0, mplex_conn_1, host_1
|
# return mplex_conn_0, host_0, mplex_conn_1, host_1
|
||||||
|
|
||||||
|
|
||||||
async def net_stream_pair_factory() -> Tuple[
|
async def net_stream_pair_factory(
|
||||||
INetStream, BasicHost, INetStream, BasicHost
|
is_secure: bool
|
||||||
]:
|
) -> Tuple[INetStream, BasicHost, INetStream, BasicHost]:
|
||||||
protocol_id = "/example/id/1"
|
protocol_id = "/example/id/1"
|
||||||
|
|
||||||
stream_1: INetStream
|
stream_1: INetStream
|
||||||
|
@ -114,7 +141,7 @@ async def net_stream_pair_factory() -> Tuple[
|
||||||
nonlocal stream_1
|
nonlocal stream_1
|
||||||
stream_1 = stream
|
stream_1 = stream
|
||||||
|
|
||||||
host_0, host_1 = await host_pair_factory()
|
host_0, host_1 = await host_pair_factory(is_secure)
|
||||||
host_1.set_stream_handler(protocol_id, handler)
|
host_1.set_stream_handler(protocol_id, handler)
|
||||||
|
|
||||||
stream_0 = await host_0.new_stream(host_1.get_id(), [protocol_id])
|
stream_0 = await host_0.new_stream(host_1.get_id(), [protocol_id])
|
||||||
|
|
|
@ -2,13 +2,22 @@ import asyncio
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from tests.factories import net_stream_pair_factory
|
from tests.factories import net_stream_pair_factory, swarm_pair_factory
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def net_stream_pair():
|
async def net_stream_pair(is_host_secure):
|
||||||
stream_0, host_0, stream_1, host_1 = await net_stream_pair_factory()
|
stream_0, host_0, stream_1, host_1 = await net_stream_pair_factory(is_host_secure)
|
||||||
try:
|
try:
|
||||||
yield stream_0, stream_1
|
yield stream_0, stream_1
|
||||||
finally:
|
finally:
|
||||||
await asyncio.gather(*[host_0.close(), host_1.close()])
|
await asyncio.gather(*[host_0.close(), host_1.close()])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def swarm_pair(is_host_secure):
|
||||||
|
swarm_0, swarm_1 = await swarm_pair_factory(is_host_secure)
|
||||||
|
try:
|
||||||
|
yield swarm_0, swarm_1
|
||||||
|
finally:
|
||||||
|
await asyncio.gather(*[swarm_0.close(), swarm_1.close()])
|
||||||
|
|
49
tests/network/test_swarm.py
Normal file
49
tests/network/test_swarm.py
Normal file
|
@ -0,0 +1,49 @@
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from tests.factories import SwarmFactory
|
||||||
|
from tests.utils import connect_swarm
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_swarm_close_peer(is_host_secure):
|
||||||
|
swarms = await SwarmFactory.create_batch_and_listen(is_host_secure, 3)
|
||||||
|
# 0 <> 1 <> 2
|
||||||
|
await connect_swarm(swarms[0], swarms[1])
|
||||||
|
await connect_swarm(swarms[1], swarms[2])
|
||||||
|
|
||||||
|
# peer 1 closes peer 0
|
||||||
|
await swarms[1].close_peer(swarms[0].get_peer_id())
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
# 0 1 <> 2
|
||||||
|
assert len(swarms[0].connections) == 0
|
||||||
|
assert (
|
||||||
|
len(swarms[1].connections) == 1
|
||||||
|
and swarms[2].get_peer_id() in swarms[1].connections
|
||||||
|
)
|
||||||
|
|
||||||
|
# peer 1 is closed by peer 2
|
||||||
|
await swarms[2].close_peer(swarms[1].get_peer_id())
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
# 0 1 2
|
||||||
|
assert len(swarms[1].connections) == 0 and len(swarms[2].connections) == 0
|
||||||
|
|
||||||
|
await connect_swarm(swarms[0], swarms[1])
|
||||||
|
# 0 <> 1 2
|
||||||
|
assert (
|
||||||
|
len(swarms[0].connections) == 1
|
||||||
|
and swarms[1].get_peer_id() in swarms[0].connections
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
len(swarms[1].connections) == 1
|
||||||
|
and swarms[0].get_peer_id() in swarms[1].connections
|
||||||
|
)
|
||||||
|
# peer 0 closes peer 1
|
||||||
|
await swarms[0].close_peer(swarms[1].get_peer_id())
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
# 0 1 2
|
||||||
|
assert len(swarms[1].connections) == 0 and len(swarms[2].connections) == 0
|
||||||
|
|
||||||
|
# Clean up
|
||||||
|
await asyncio.gather(*[swarm.close() for swarm in swarms])
|
|
@ -5,6 +5,19 @@ from libp2p.peer.peerinfo import info_from_p2p_addr
|
||||||
from tests.constants import MAX_READ_LEN
|
from tests.constants import MAX_READ_LEN
|
||||||
|
|
||||||
|
|
||||||
|
async def connect_swarm(swarm_0, swarm_1):
|
||||||
|
peer_id = swarm_1.get_peer_id()
|
||||||
|
addrs = tuple(
|
||||||
|
addr
|
||||||
|
for transport in swarm_1.listeners.values()
|
||||||
|
for addr in transport.get_addrs()
|
||||||
|
)
|
||||||
|
swarm_0.peerstore.add_addrs(peer_id, addrs, 10000)
|
||||||
|
await swarm_0.dial_peer(peer_id)
|
||||||
|
assert swarm_0.get_peer_id() in swarm_1.connections
|
||||||
|
assert swarm_1.get_peer_id() in swarm_0.connections
|
||||||
|
|
||||||
|
|
||||||
async def connect(node1, node2):
|
async def connect(node1, node2):
|
||||||
"""
|
"""
|
||||||
Connect node1 to node2
|
Connect node1 to node2
|
||||||
|
|
Loading…
Reference in New Issue
Block a user