Add test for Swarm.close_peer

This commit is contained in:
mhchia 2019-09-14 23:37:01 +08:00
parent 6923f257f6
commit e7304538da
No known key found for this signature in database
GPG Key ID: 389EFBEA1362589A
6 changed files with 130 additions and 29 deletions

View File

@ -23,6 +23,10 @@ from .host_interface import IHost
class BasicHost(IHost):
"""
BasicHost is a wrapper of a `INetwork` implementation. It performs protocol negotiation
on a stream with multistream-select right after a stream is initialized.
"""
_network: INetwork
_router: KadmeliaPeerRouter
@ -31,7 +35,6 @@ class BasicHost(IHost):
multiselect: Multiselect
multiselect_client: MultiselectClient
# default options constructor
def __init__(self, network: INetwork, router: KadmeliaPeerRouter = None) -> None:
self._network = network
self._network.set_stream_handler(self._swarm_stream_handler)
@ -69,6 +72,7 @@ class BasicHost(IHost):
"""
:return: all the multiaddr addresses this host is listening to
"""
# TODO: We don't need "/p2p/{peer_id}" postfix actually.
p2p_part = multiaddr.Multiaddr("/p2p/{}".format(self.get_id().pretty()))
addrs: List[multiaddr.Multiaddr] = []
@ -87,8 +91,6 @@ class BasicHost(IHost):
"""
self.multiselect.add_handler(protocol_id, stream_handler)
# `protocol_ids` can be a list of `protocol_id`
# stream will decide which `protocol_id` to run on
async def new_stream(
self, peer_id: ID, protocol_ids: Sequence[TProtocol]
) -> INetStream:

View File

@ -50,11 +50,12 @@ class SwarmConn(INetConn):
task.cancel()
async def _handle_new_streams(self) -> None:
# TODO: Break the loop when anything wrong in the connection.
while True:
try:
stream = await self.conn.accept_stream()
except MuxedConnUnavailable:
# If there is anything wrong in the MuxedConn,
# we should break the loop and close the connection.
break
# Asynchronously handle the accepted stream, to avoid blocking the next stream.
await self.run_task(self._handle_muxed_stream(stream))

View File

@ -8,13 +8,13 @@ from libp2p.crypto.keys import KeyPair
from libp2p.host.basic_host import BasicHost
from libp2p.host.host_interface import IHost
from libp2p.network.stream.net_stream_interface import INetStream
from libp2p.network.swarm import Swarm
from libp2p.pubsub.floodsub import FloodSub
from libp2p.pubsub.gossipsub import GossipSub
from libp2p.pubsub.pubsub import Pubsub
from libp2p.security.base_transport import BaseSecureTransport
from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport
import libp2p.security.secio.transport as secio
from libp2p.stream_muxer.mplex.mplex import Mplex
from libp2p.typing import TProtocol
from tests.configs import LISTEN_MADDR
from tests.pubsub.configs import (
@ -22,7 +22,7 @@ from tests.pubsub.configs import (
GOSSIPSUB_PARAMS,
GOSSIPSUB_PROTOCOL_ID,
)
from tests.utils import connect
from tests.utils import connect, connect_swarm
def security_transport_factory(
@ -34,10 +34,29 @@ def security_transport_factory(
return {secio.ID: secio.Transport(key_pair)}
def swarm_factory(is_secure: bool):
key_pair = generate_new_rsa_identity()
sec_opt = security_transport_factory(is_secure, key_pair)
return initialize_default_swarm(key_pair, sec_opt=sec_opt)
class SwarmFactory(factory.Factory):
class Meta:
model = Swarm
@classmethod
def _create(cls, is_secure=False):
key_pair = generate_new_rsa_identity()
sec_opt = security_transport_factory(is_secure, key_pair)
return initialize_default_swarm(key_pair, sec_opt=sec_opt)
@classmethod
async def create_and_listen(cls, is_secure: bool) -> Swarm:
swarm = cls._create(is_secure)
await swarm.listen(LISTEN_MADDR)
return swarm
@classmethod
async def create_batch_and_listen(
cls, is_secure: bool, number: int
) -> Tuple[Swarm, ...]:
return await asyncio.gather(
*[cls.create_and_listen(is_secure) for _ in range(number)]
)
class HostFactory(factory.Factory):
@ -47,13 +66,12 @@ class HostFactory(factory.Factory):
class Params:
is_secure = False
network = factory.LazyAttribute(lambda o: swarm_factory(o.is_secure))
network = factory.LazyAttribute(lambda o: SwarmFactory(o.is_secure))
@classmethod
async def create_and_listen(cls) -> IHost:
host = cls()
await host.get_network().listen(LISTEN_MADDR)
return host
async def create_and_listen(cls, is_secure: bool) -> IHost:
swarm = await SwarmFactory.create_and_listen(is_secure)
return BasicHost(swarm)
class FloodsubFactory(factory.Factory):
@ -87,24 +105,33 @@ class PubsubFactory(factory.Factory):
cache_size = None
async def host_pair_factory() -> Tuple[BasicHost, BasicHost]:
async def swarm_pair_factory(is_secure: bool) -> Tuple[Swarm, Swarm]:
swarms = await SwarmFactory.create_batch_and_listen(2)
await connect_swarm(swarms[0], swarms[1])
return swarms[0], swarms[1]
async def host_pair_factory(is_secure) -> Tuple[BasicHost, BasicHost]:
hosts = await asyncio.gather(
*[HostFactory.create_and_listen(), HostFactory.create_and_listen()]
*[
HostFactory.create_and_listen(is_secure),
HostFactory.create_and_listen(is_secure),
]
)
await connect(hosts[0], hosts[1])
return hosts[0], hosts[1]
async def connection_pair_factory() -> Tuple[Mplex, BasicHost, Mplex, BasicHost]:
host_0, host_1 = await host_pair_factory()
mplex_conn_0 = host_0.get_network().connections[host_1.get_id()]
mplex_conn_1 = host_1.get_network().connections[host_0.get_id()]
return mplex_conn_0, host_0, mplex_conn_1, host_1
# async def connection_pair_factory() -> Tuple[Mplex, BasicHost, Mplex, BasicHost]:
# host_0, host_1 = await host_pair_factory()
# mplex_conn_0 = host_0.get_network().connections[host_1.get_id()]
# mplex_conn_1 = host_1.get_network().connections[host_0.get_id()]
# return mplex_conn_0, host_0, mplex_conn_1, host_1
async def net_stream_pair_factory() -> Tuple[
INetStream, BasicHost, INetStream, BasicHost
]:
async def net_stream_pair_factory(
is_secure: bool
) -> Tuple[INetStream, BasicHost, INetStream, BasicHost]:
protocol_id = "/example/id/1"
stream_1: INetStream
@ -114,7 +141,7 @@ async def net_stream_pair_factory() -> Tuple[
nonlocal stream_1
stream_1 = stream
host_0, host_1 = await host_pair_factory()
host_0, host_1 = await host_pair_factory(is_secure)
host_1.set_stream_handler(protocol_id, handler)
stream_0 = await host_0.new_stream(host_1.get_id(), [protocol_id])

View File

@ -2,13 +2,22 @@ import asyncio
import pytest
from tests.factories import net_stream_pair_factory
from tests.factories import net_stream_pair_factory, swarm_pair_factory
@pytest.fixture
async def net_stream_pair():
stream_0, host_0, stream_1, host_1 = await net_stream_pair_factory()
async def net_stream_pair(is_host_secure):
stream_0, host_0, stream_1, host_1 = await net_stream_pair_factory(is_host_secure)
try:
yield stream_0, stream_1
finally:
await asyncio.gather(*[host_0.close(), host_1.close()])
@pytest.fixture
async def swarm_pair(is_host_secure):
swarm_0, swarm_1 = await swarm_pair_factory(is_host_secure)
try:
yield swarm_0, swarm_1
finally:
await asyncio.gather(*[swarm_0.close(), swarm_1.close()])

View File

@ -0,0 +1,49 @@
import asyncio
import pytest
from tests.factories import SwarmFactory
from tests.utils import connect_swarm
@pytest.mark.asyncio
async def test_swarm_close_peer(is_host_secure):
swarms = await SwarmFactory.create_batch_and_listen(is_host_secure, 3)
# 0 <> 1 <> 2
await connect_swarm(swarms[0], swarms[1])
await connect_swarm(swarms[1], swarms[2])
# peer 1 closes peer 0
await swarms[1].close_peer(swarms[0].get_peer_id())
await asyncio.sleep(0.01)
# 0 1 <> 2
assert len(swarms[0].connections) == 0
assert (
len(swarms[1].connections) == 1
and swarms[2].get_peer_id() in swarms[1].connections
)
# peer 1 is closed by peer 2
await swarms[2].close_peer(swarms[1].get_peer_id())
await asyncio.sleep(0.01)
# 0 1 2
assert len(swarms[1].connections) == 0 and len(swarms[2].connections) == 0
await connect_swarm(swarms[0], swarms[1])
# 0 <> 1 2
assert (
len(swarms[0].connections) == 1
and swarms[1].get_peer_id() in swarms[0].connections
)
assert (
len(swarms[1].connections) == 1
and swarms[0].get_peer_id() in swarms[1].connections
)
# peer 0 closes peer 1
await swarms[0].close_peer(swarms[1].get_peer_id())
await asyncio.sleep(0.01)
# 0 1 2
assert len(swarms[1].connections) == 0 and len(swarms[2].connections) == 0
# Clean up
await asyncio.gather(*[swarm.close() for swarm in swarms])

View File

@ -5,6 +5,19 @@ from libp2p.peer.peerinfo import info_from_p2p_addr
from tests.constants import MAX_READ_LEN
async def connect_swarm(swarm_0, swarm_1):
peer_id = swarm_1.get_peer_id()
addrs = tuple(
addr
for transport in swarm_1.listeners.values()
for addr in transport.get_addrs()
)
swarm_0.peerstore.add_addrs(peer_id, addrs, 10000)
await swarm_0.dial_peer(peer_id)
assert swarm_0.get_peer_id() in swarm_1.connections
assert swarm_1.get_peer_id() in swarm_0.connections
async def connect(node1, node2):
"""
Connect node1 to node2