Add test for Swarm.close_peer

This commit is contained in:
mhchia 2019-09-14 23:37:01 +08:00
parent 6923f257f6
commit e7304538da
No known key found for this signature in database
GPG Key ID: 389EFBEA1362589A
6 changed files with 130 additions and 29 deletions

View File

@ -23,6 +23,10 @@ from .host_interface import IHost
class BasicHost(IHost): class BasicHost(IHost):
"""
BasicHost is a wrapper of a `INetwork` implementation. It performs protocol negotiation
on a stream with multistream-select right after a stream is initialized.
"""
_network: INetwork _network: INetwork
_router: KadmeliaPeerRouter _router: KadmeliaPeerRouter
@ -31,7 +35,6 @@ class BasicHost(IHost):
multiselect: Multiselect multiselect: Multiselect
multiselect_client: MultiselectClient multiselect_client: MultiselectClient
# default options constructor
def __init__(self, network: INetwork, router: KadmeliaPeerRouter = None) -> None: def __init__(self, network: INetwork, router: KadmeliaPeerRouter = None) -> None:
self._network = network self._network = network
self._network.set_stream_handler(self._swarm_stream_handler) self._network.set_stream_handler(self._swarm_stream_handler)
@ -69,6 +72,7 @@ class BasicHost(IHost):
""" """
:return: all the multiaddr addresses this host is listening to :return: all the multiaddr addresses this host is listening to
""" """
# TODO: We don't need "/p2p/{peer_id}" postfix actually.
p2p_part = multiaddr.Multiaddr("/p2p/{}".format(self.get_id().pretty())) p2p_part = multiaddr.Multiaddr("/p2p/{}".format(self.get_id().pretty()))
addrs: List[multiaddr.Multiaddr] = [] addrs: List[multiaddr.Multiaddr] = []
@ -87,8 +91,6 @@ class BasicHost(IHost):
""" """
self.multiselect.add_handler(protocol_id, stream_handler) self.multiselect.add_handler(protocol_id, stream_handler)
# `protocol_ids` can be a list of `protocol_id`
# stream will decide which `protocol_id` to run on
async def new_stream( async def new_stream(
self, peer_id: ID, protocol_ids: Sequence[TProtocol] self, peer_id: ID, protocol_ids: Sequence[TProtocol]
) -> INetStream: ) -> INetStream:

View File

@ -50,11 +50,12 @@ class SwarmConn(INetConn):
task.cancel() task.cancel()
async def _handle_new_streams(self) -> None: async def _handle_new_streams(self) -> None:
# TODO: Break the loop when anything wrong in the connection.
while True: while True:
try: try:
stream = await self.conn.accept_stream() stream = await self.conn.accept_stream()
except MuxedConnUnavailable: except MuxedConnUnavailable:
# If there is anything wrong in the MuxedConn,
# we should break the loop and close the connection.
break break
# Asynchronously handle the accepted stream, to avoid blocking the next stream. # Asynchronously handle the accepted stream, to avoid blocking the next stream.
await self.run_task(self._handle_muxed_stream(stream)) await self.run_task(self._handle_muxed_stream(stream))

View File

@ -8,13 +8,13 @@ from libp2p.crypto.keys import KeyPair
from libp2p.host.basic_host import BasicHost from libp2p.host.basic_host import BasicHost
from libp2p.host.host_interface import IHost from libp2p.host.host_interface import IHost
from libp2p.network.stream.net_stream_interface import INetStream from libp2p.network.stream.net_stream_interface import INetStream
from libp2p.network.swarm import Swarm
from libp2p.pubsub.floodsub import FloodSub from libp2p.pubsub.floodsub import FloodSub
from libp2p.pubsub.gossipsub import GossipSub from libp2p.pubsub.gossipsub import GossipSub
from libp2p.pubsub.pubsub import Pubsub from libp2p.pubsub.pubsub import Pubsub
from libp2p.security.base_transport import BaseSecureTransport from libp2p.security.base_transport import BaseSecureTransport
from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport
import libp2p.security.secio.transport as secio import libp2p.security.secio.transport as secio
from libp2p.stream_muxer.mplex.mplex import Mplex
from libp2p.typing import TProtocol from libp2p.typing import TProtocol
from tests.configs import LISTEN_MADDR from tests.configs import LISTEN_MADDR
from tests.pubsub.configs import ( from tests.pubsub.configs import (
@ -22,7 +22,7 @@ from tests.pubsub.configs import (
GOSSIPSUB_PARAMS, GOSSIPSUB_PARAMS,
GOSSIPSUB_PROTOCOL_ID, GOSSIPSUB_PROTOCOL_ID,
) )
from tests.utils import connect from tests.utils import connect, connect_swarm
def security_transport_factory( def security_transport_factory(
@ -34,10 +34,29 @@ def security_transport_factory(
return {secio.ID: secio.Transport(key_pair)} return {secio.ID: secio.Transport(key_pair)}
def swarm_factory(is_secure: bool): class SwarmFactory(factory.Factory):
key_pair = generate_new_rsa_identity() class Meta:
sec_opt = security_transport_factory(is_secure, key_pair) model = Swarm
return initialize_default_swarm(key_pair, sec_opt=sec_opt)
@classmethod
def _create(cls, is_secure=False):
key_pair = generate_new_rsa_identity()
sec_opt = security_transport_factory(is_secure, key_pair)
return initialize_default_swarm(key_pair, sec_opt=sec_opt)
@classmethod
async def create_and_listen(cls, is_secure: bool) -> Swarm:
swarm = cls._create(is_secure)
await swarm.listen(LISTEN_MADDR)
return swarm
@classmethod
async def create_batch_and_listen(
cls, is_secure: bool, number: int
) -> Tuple[Swarm, ...]:
return await asyncio.gather(
*[cls.create_and_listen(is_secure) for _ in range(number)]
)
class HostFactory(factory.Factory): class HostFactory(factory.Factory):
@ -47,13 +66,12 @@ class HostFactory(factory.Factory):
class Params: class Params:
is_secure = False is_secure = False
network = factory.LazyAttribute(lambda o: swarm_factory(o.is_secure)) network = factory.LazyAttribute(lambda o: SwarmFactory(o.is_secure))
@classmethod @classmethod
async def create_and_listen(cls) -> IHost: async def create_and_listen(cls, is_secure: bool) -> IHost:
host = cls() swarm = await SwarmFactory.create_and_listen(is_secure)
await host.get_network().listen(LISTEN_MADDR) return BasicHost(swarm)
return host
class FloodsubFactory(factory.Factory): class FloodsubFactory(factory.Factory):
@ -87,24 +105,33 @@ class PubsubFactory(factory.Factory):
cache_size = None cache_size = None
async def host_pair_factory() -> Tuple[BasicHost, BasicHost]: async def swarm_pair_factory(is_secure: bool) -> Tuple[Swarm, Swarm]:
swarms = await SwarmFactory.create_batch_and_listen(2)
await connect_swarm(swarms[0], swarms[1])
return swarms[0], swarms[1]
async def host_pair_factory(is_secure) -> Tuple[BasicHost, BasicHost]:
hosts = await asyncio.gather( hosts = await asyncio.gather(
*[HostFactory.create_and_listen(), HostFactory.create_and_listen()] *[
HostFactory.create_and_listen(is_secure),
HostFactory.create_and_listen(is_secure),
]
) )
await connect(hosts[0], hosts[1]) await connect(hosts[0], hosts[1])
return hosts[0], hosts[1] return hosts[0], hosts[1]
async def connection_pair_factory() -> Tuple[Mplex, BasicHost, Mplex, BasicHost]: # async def connection_pair_factory() -> Tuple[Mplex, BasicHost, Mplex, BasicHost]:
host_0, host_1 = await host_pair_factory() # host_0, host_1 = await host_pair_factory()
mplex_conn_0 = host_0.get_network().connections[host_1.get_id()] # mplex_conn_0 = host_0.get_network().connections[host_1.get_id()]
mplex_conn_1 = host_1.get_network().connections[host_0.get_id()] # mplex_conn_1 = host_1.get_network().connections[host_0.get_id()]
return mplex_conn_0, host_0, mplex_conn_1, host_1 # return mplex_conn_0, host_0, mplex_conn_1, host_1
async def net_stream_pair_factory() -> Tuple[ async def net_stream_pair_factory(
INetStream, BasicHost, INetStream, BasicHost is_secure: bool
]: ) -> Tuple[INetStream, BasicHost, INetStream, BasicHost]:
protocol_id = "/example/id/1" protocol_id = "/example/id/1"
stream_1: INetStream stream_1: INetStream
@ -114,7 +141,7 @@ async def net_stream_pair_factory() -> Tuple[
nonlocal stream_1 nonlocal stream_1
stream_1 = stream stream_1 = stream
host_0, host_1 = await host_pair_factory() host_0, host_1 = await host_pair_factory(is_secure)
host_1.set_stream_handler(protocol_id, handler) host_1.set_stream_handler(protocol_id, handler)
stream_0 = await host_0.new_stream(host_1.get_id(), [protocol_id]) stream_0 = await host_0.new_stream(host_1.get_id(), [protocol_id])

View File

@ -2,13 +2,22 @@ import asyncio
import pytest import pytest
from tests.factories import net_stream_pair_factory from tests.factories import net_stream_pair_factory, swarm_pair_factory
@pytest.fixture @pytest.fixture
async def net_stream_pair(): async def net_stream_pair(is_host_secure):
stream_0, host_0, stream_1, host_1 = await net_stream_pair_factory() stream_0, host_0, stream_1, host_1 = await net_stream_pair_factory(is_host_secure)
try: try:
yield stream_0, stream_1 yield stream_0, stream_1
finally: finally:
await asyncio.gather(*[host_0.close(), host_1.close()]) await asyncio.gather(*[host_0.close(), host_1.close()])
@pytest.fixture
async def swarm_pair(is_host_secure):
swarm_0, swarm_1 = await swarm_pair_factory(is_host_secure)
try:
yield swarm_0, swarm_1
finally:
await asyncio.gather(*[swarm_0.close(), swarm_1.close()])

View File

@ -0,0 +1,49 @@
import asyncio
import pytest
from tests.factories import SwarmFactory
from tests.utils import connect_swarm
@pytest.mark.asyncio
async def test_swarm_close_peer(is_host_secure):
swarms = await SwarmFactory.create_batch_and_listen(is_host_secure, 3)
# 0 <> 1 <> 2
await connect_swarm(swarms[0], swarms[1])
await connect_swarm(swarms[1], swarms[2])
# peer 1 closes peer 0
await swarms[1].close_peer(swarms[0].get_peer_id())
await asyncio.sleep(0.01)
# 0 1 <> 2
assert len(swarms[0].connections) == 0
assert (
len(swarms[1].connections) == 1
and swarms[2].get_peer_id() in swarms[1].connections
)
# peer 1 is closed by peer 2
await swarms[2].close_peer(swarms[1].get_peer_id())
await asyncio.sleep(0.01)
# 0 1 2
assert len(swarms[1].connections) == 0 and len(swarms[2].connections) == 0
await connect_swarm(swarms[0], swarms[1])
# 0 <> 1 2
assert (
len(swarms[0].connections) == 1
and swarms[1].get_peer_id() in swarms[0].connections
)
assert (
len(swarms[1].connections) == 1
and swarms[0].get_peer_id() in swarms[1].connections
)
# peer 0 closes peer 1
await swarms[0].close_peer(swarms[1].get_peer_id())
await asyncio.sleep(0.01)
# 0 1 2
assert len(swarms[1].connections) == 0 and len(swarms[2].connections) == 0
# Clean up
await asyncio.gather(*[swarm.close() for swarm in swarms])

View File

@ -5,6 +5,19 @@ from libp2p.peer.peerinfo import info_from_p2p_addr
from tests.constants import MAX_READ_LEN from tests.constants import MAX_READ_LEN
async def connect_swarm(swarm_0, swarm_1):
peer_id = swarm_1.get_peer_id()
addrs = tuple(
addr
for transport in swarm_1.listeners.values()
for addr in transport.get_addrs()
)
swarm_0.peerstore.add_addrs(peer_id, addrs, 10000)
await swarm_0.dial_peer(peer_id)
assert swarm_0.get_peer_id() in swarm_1.connections
assert swarm_1.get_peer_id() in swarm_0.connections
async def connect(node1, node2): async def connect(node1, node2):
""" """
Connect node1 to node2 Connect node1 to node2