Refactor interop tests and factories

- Add `close` and `disconnect` in `Host`
- Add `close` and `close_peer` in `Network`
- Change `IListener.close` to async, to await for server's closing
- Add factories for security transports, and modify `HostFactory`
This commit is contained in:
mhchia 2019-08-29 21:38:06 +08:00
parent 64c0dab3af
commit c61a06706a
No known key found for this signature in database
GPG Key ID: 389EFBEA1362589A
15 changed files with 184 additions and 116 deletions

View File

@ -107,3 +107,9 @@ class BasicHost(IHost):
return return
await self._network.dial_peer(peer_info.peer_id) await self._network.dial_peer(peer_info.peer_id)
async def disconnect(self, peer_id: ID) -> None:
await self._network.close_peer(peer_id)
async def close(self) -> None:
await self._network.close()

View File

@ -71,3 +71,11 @@ class IHost(ABC):
:param peer_info: peer_info of the host we want to connect to :param peer_info: peer_info of the host we want to connect to
:type peer_info: peer.peerinfo.PeerInfo :type peer_info: peer.peerinfo.PeerInfo
""" """
@abstractmethod
async def disconnect(self, peer_id: ID) -> None:
pass
@abstractmethod
async def close(self) -> None:
pass

View File

@ -70,3 +70,11 @@ class INetwork(ABC):
:param notifee: object implementing Notifee interface :param notifee: object implementing Notifee interface
:return: true if notifee registered successfully, false otherwise :return: true if notifee registered successfully, false otherwise
""" """
@abstractmethod
async def close(self) -> None:
pass
@abstractmethod
async def close_peer(self, peer_id: ID) -> None:
pass

View File

@ -264,12 +264,24 @@ class Swarm(INetwork):
def add_router(self, router: IPeerRouting) -> None: def add_router(self, router: IPeerRouting) -> None:
self.router = router self.router = router
# TODO: `tear_down` async def close(self) -> None:
async def tear_down(self) -> None: # TODO: Prevent from new listeners and conns being added.
# Reference: https://github.com/libp2p/go-libp2p-swarm/blob/8be680aef8dea0a4497283f2f98470c2aeae6b65/swarm.go#L118 # noqa: E501 # Reference: https://github.com/libp2p/go-libp2p-swarm/blob/8be680aef8dea0a4497283f2f98470c2aeae6b65/swarm.go#L124-L134 # noqa: E501
pass
# TODO: `disconnect`? # Close listeners
await asyncio.gather(
*[listener.close() for listener in self.listeners.values()]
)
# Close connections
await asyncio.gather(
*[connection.close() for connection in self.connections.values()]
)
async def close_peer(self, peer_id: ID) -> None:
connection = self.connections[peer_id]
del self.connections[peer_id]
await connection.close()
def create_generic_protocol_handler(swarm: Swarm) -> GenericProtocolHandlerFn: def create_generic_protocol_handler(swarm: Swarm) -> GenericProtocolHandlerFn:

View File

@ -1,5 +1,6 @@
import asyncio import asyncio
from typing import Dict, Optional, Tuple from typing import Any # noqa: F401
from typing import Dict, List, Optional, Tuple
from libp2p.network.typing import GenericProtocolHandlerFn from libp2p.network.typing import GenericProtocolHandlerFn
from libp2p.peer.id import ID from libp2p.peer.id import ID
@ -34,6 +35,8 @@ class Mplex(IMuxedConn):
stream_queue: "asyncio.Queue[StreamID]" stream_queue: "asyncio.Queue[StreamID]"
next_channel_id: int next_channel_id: int
_tasks: List["asyncio.Future[Any]"]
# TODO: `generic_protocol_handler` should be refactored out of mplex conn. # TODO: `generic_protocol_handler` should be refactored out of mplex conn.
def __init__( def __init__(
self, self,
@ -63,8 +66,10 @@ class Mplex(IMuxedConn):
self.stream_queue = asyncio.Queue() self.stream_queue = asyncio.Queue()
self._tasks = []
# Kick off reading # Kick off reading
asyncio.ensure_future(self.handle_incoming()) self._tasks.append(asyncio.ensure_future(self.handle_incoming()))
@property @property
def initiator(self) -> bool: def initiator(self) -> bool:
@ -74,6 +79,8 @@ class Mplex(IMuxedConn):
""" """
close the stream muxer and underlying secured connection close the stream muxer and underlying secured connection
""" """
for task in self._tasks:
task.cancel()
await self.secured_conn.close() await self.secured_conn.close()
def is_closed(self) -> bool: def is_closed(self) -> bool:
@ -135,7 +142,7 @@ class Mplex(IMuxedConn):
""" """
stream_id = await self.stream_queue.get() stream_id = await self.stream_queue.get()
stream = MplexStream(name, stream_id, self) stream = MplexStream(name, stream_id, self)
asyncio.ensure_future(self.generic_protocol_handler(stream)) self._tasks.append(asyncio.ensure_future(self.generic_protocol_handler(stream)))
async def send_message( async def send_message(
self, flag: HeaderTags, data: bytes, stream_id: StreamID self, flag: HeaderTags, data: bytes, stream_id: StreamID

View File

@ -21,9 +21,8 @@ class IListener(ABC):
""" """
@abstractmethod @abstractmethod
def close(self) -> bool: async def close(self) -> None:
""" """
close the listener such that no more connections close the listener such that no more connections
can be open on this transport instance can be open on this transport instance
:return: return True if successful
""" """

View File

@ -45,20 +45,16 @@ class TCPListener(IListener):
# TODO check if server is listening # TODO check if server is listening
return self.multiaddrs return self.multiaddrs
def close(self) -> bool: async def close(self) -> None:
""" """
close the listener such that no more connections close the listener such that no more connections
can be open on this transport instance can be open on this transport instance
:return: return True if successful
""" """
if self.server is None: if self.server is None:
return False return
self.server.close() self.server.close()
_loop = asyncio.get_event_loop() await self.server.wait_closed()
_loop.run_until_complete(self.server.wait_closed())
_loop.close()
self.server = None self.server = None
return True
class TCP(ITransport): class TCP(ITransport):

33
tests/conftest.py Normal file
View File

@ -0,0 +1,33 @@
import asyncio
import pytest
from .configs import LISTEN_MADDR
from .factories import HostFactory
@pytest.fixture
def is_host_secure():
return False
@pytest.fixture
def num_hosts():
return 3
@pytest.fixture
async def hosts(num_hosts, is_host_secure):
_hosts = HostFactory.create_batch(num_hosts, is_secure=is_host_secure)
await asyncio.gather(
*[_host.get_network().listen(LISTEN_MADDR) for _host in _hosts]
)
try:
yield _hosts
finally:
# TODO: It's possible that `close` raises exceptions currently,
# due to the connection reset things. Though we are not so careful about that when
# cleaning up the tasks, it is probably better to handle the exceptions properly.
await asyncio.gather(
*[_host.close() for _host in _hosts], return_exceptions=True
)

View File

@ -1,12 +1,17 @@
from typing import Dict
import factory import factory
from libp2p import initialize_default_swarm from libp2p import generate_new_rsa_identity, initialize_default_swarm
from libp2p.crypto.rsa import create_new_key_pair from libp2p.crypto.keys import KeyPair
from libp2p.host.basic_host import BasicHost from libp2p.host.basic_host import BasicHost
from libp2p.pubsub.floodsub import FloodSub from libp2p.pubsub.floodsub import FloodSub
from libp2p.pubsub.gossipsub import GossipSub from libp2p.pubsub.gossipsub import GossipSub
from libp2p.pubsub.pubsub import Pubsub from libp2p.pubsub.pubsub import Pubsub
from tests.configs import LISTEN_MADDR from libp2p.security.base_transport import BaseSecureTransport
from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport
from libp2p.security.secio.transport import ID, Transport
from libp2p.typing import TProtocol
from tests.pubsub.configs import ( from tests.pubsub.configs import (
FLOODSUB_PROTOCOL_ID, FLOODSUB_PROTOCOL_ID,
GOSSIPSUB_PARAMS, GOSSIPSUB_PARAMS,
@ -14,16 +19,34 @@ from tests.pubsub.configs import (
) )
def swarm_factory(): def security_transport_factory(
private_key = create_new_key_pair() is_secure: bool, key_pair: KeyPair
return initialize_default_swarm(private_key, transport_opt=[str(LISTEN_MADDR)]) ) -> Dict[TProtocol, BaseSecureTransport]:
protocol_id: TProtocol
security_transport: BaseSecureTransport
if not is_secure:
protocol_id = PLAINTEXT_PROTOCOL_ID
security_transport = InsecureTransport(key_pair)
else:
protocol_id = ID
security_transport = Transport(key_pair)
return {protocol_id: security_transport}
def swarm_factory(is_secure: bool):
key_pair = generate_new_rsa_identity()
sec_opt = security_transport_factory(is_secure, key_pair)
return initialize_default_swarm(key_pair, sec_opt=sec_opt)
class HostFactory(factory.Factory): class HostFactory(factory.Factory):
class Meta: class Meta:
model = BasicHost model = BasicHost
network = factory.LazyFunction(swarm_factory) class Params:
is_secure = False
network = factory.LazyAttribute(lambda o: swarm_factory(o.is_secure))
class FloodsubFactory(factory.Factory): class FloodsubFactory(factory.Factory):

View File

@ -7,11 +7,8 @@ from multiaddr import Multiaddr
import pexpect import pexpect
import pytest import pytest
from libp2p import generate_new_rsa_identity, new_node
from libp2p.peer.peerinfo import info_from_p2p_addr from libp2p.peer.peerinfo import info_from_p2p_addr
from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport
from libp2p.typing import TProtocol from libp2p.typing import TProtocol
from tests.configs import LISTEN_MADDR
GOPATH = pathlib.Path(os.environ["GOPATH"]) GOPATH = pathlib.Path(os.environ["GOPATH"])
ECHO_PATH = GOPATH / "bin" / "echo" ECHO_PATH = GOPATH / "bin" / "echo"
@ -19,50 +16,68 @@ ECHO_PROTOCOL_ID = TProtocol("/echo/1.0.0")
NEW_LINE = "\r\n" NEW_LINE = "\r\n"
@pytest.mark.asyncio @pytest.fixture
async def test_insecure_conn_py_to_go(unused_tcp_port): def proc_factory():
procs = []
def call_proc(cmd, args, logfile=None, encoding=None):
if logfile is None:
logfile = sys.stdout
if encoding is None:
encoding = "utf-8"
proc = pexpect.spawn(cmd, args, logfile=logfile, encoding=encoding)
procs.append(proc)
return proc
try: try:
go_proc = pexpect.spawn( yield call_proc
str(ECHO_PATH),
[f"-l={unused_tcp_port}", "-insecure"],
logfile=sys.stdout,
encoding="utf-8",
)
await go_proc.expect(r"I am ([\w\./]+)" + NEW_LINE, async_=True)
maddr_str = go_proc.match.group(1)
maddr_str = maddr_str.replace("ipfs", "p2p")
maddr = Multiaddr(maddr_str)
go_pinfo = info_from_p2p_addr(maddr)
await go_proc.expect("listening for connections", async_=True)
key_pair = generate_new_rsa_identity()
insecure_tpt = InsecureTransport(key_pair)
host = await new_node(
key_pair=key_pair, sec_opt={PLAINTEXT_PROTOCOL_ID: insecure_tpt}
)
await host.connect(go_pinfo)
await go_proc.expect("swarm listener accepted connection", async_=True)
s = await host.new_stream(go_pinfo.peer_id, [ECHO_PROTOCOL_ID])
await go_proc.expect("Got a new stream!", async_=True)
data = "data321123\n"
await s.write(data.encode())
await go_proc.expect(f"read: {data[:-1]}", async_=True)
echoed_resp = await s.read(len(data))
assert echoed_resp.decode() == data
await s.close()
finally: finally:
go_proc.close() for proc in procs:
proc.close()
async def make_echo_proc(
proc_factory, port: int, is_secure: bool, destination: Multiaddr = None
):
args = [f"-l={port}"]
if not is_secure:
args.append("-insecure")
if destination is not None:
args.append(f"-d={str(destination)}")
echo_proc = proc_factory(str(ECHO_PATH), args, logfile=sys.stdout, encoding="utf-8")
await echo_proc.expect(r"I am ([\w\./]+)" + NEW_LINE, async_=True)
maddr_str_ipfs = echo_proc.match.group(1)
maddr_str = maddr_str_ipfs.replace("ipfs", "p2p")
maddr = Multiaddr(maddr_str)
go_pinfo = info_from_p2p_addr(maddr)
if destination is None:
await echo_proc.expect("listening for connections", async_=True)
return echo_proc, go_pinfo
@pytest.mark.parametrize("num_hosts", (1,))
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_insecure_conn_go_to_py(unused_tcp_port): async def test_insecure_conn_py_to_go(hosts, proc_factory, unused_tcp_port):
key_pair = generate_new_rsa_identity() go_proc, go_pinfo = await make_echo_proc(proc_factory, unused_tcp_port, False)
insecure_tpt = InsecureTransport(key_pair)
host = await new_node( host = hosts[0]
key_pair=key_pair, sec_opt={PLAINTEXT_PROTOCOL_ID: insecure_tpt} await host.connect(go_pinfo)
) await go_proc.expect("swarm listener accepted connection", async_=True)
await host.get_network().listen(LISTEN_MADDR) s = await host.new_stream(go_pinfo.peer_id, [ECHO_PROTOCOL_ID])
await go_proc.expect("Got a new stream!", async_=True)
data = "data321123\n"
await s.write(data.encode())
await go_proc.expect(f"read: {data[:-1]}", async_=True)
echoed_resp = await s.read(len(data))
assert echoed_resp.decode() == data
await s.close()
@pytest.mark.parametrize("num_hosts", (1,))
@pytest.mark.asyncio
async def test_insecure_conn_go_to_py(hosts, proc_factory, unused_tcp_port):
host = hosts[0]
expected_data = "Hello, world!\n" expected_data = "Hello, world!\n"
reply_data = "Replyooo!\n" reply_data = "Replyooo!\n"
event_handler_finished = asyncio.Event() event_handler_finished = asyncio.Event()
@ -76,17 +91,8 @@ async def test_insecure_conn_go_to_py(unused_tcp_port):
host.set_stream_handler(ECHO_PROTOCOL_ID, _handle_echo) host.set_stream_handler(ECHO_PROTOCOL_ID, _handle_echo)
py_maddr = host.get_addrs()[0] py_maddr = host.get_addrs()[0]
go_proc = pexpect.spawn( go_proc, _ = await make_echo_proc(proc_factory, unused_tcp_port, False, py_maddr)
str(ECHO_PATH), await go_proc.expect("connect with peer", async_=True)
[f"-l={unused_tcp_port}", "-insecure", f"-d={str(py_maddr)}"], await go_proc.expect("opened stream", async_=True)
logfile=sys.stdout, await event_handler_finished.wait()
encoding="utf-8", await go_proc.expect(f"read reply: .*{reply_data.rstrip()}.*", async_=True)
)
try:
await go_proc.expect(r"I am ([\w\./]+)" + NEW_LINE, async_=True)
await go_proc.expect("connect with peer", async_=True)
await go_proc.expect("opened stream", async_=True)
await event_handler_finished.wait()
await go_proc.expect(f"read reply: .*{reply_data.rstrip()}.*", async_=True)
finally:
go_proc.close()

View File

@ -1,38 +1,7 @@
import asyncio
import pytest import pytest
from tests.configs import LISTEN_MADDR from tests.factories import FloodsubFactory, GossipsubFactory, PubsubFactory
from tests.pubsub.configs import GOSSIPSUB_PARAMS from tests.pubsub.configs import GOSSIPSUB_PARAMS
from tests.pubsub.factories import (
FloodsubFactory,
GossipsubFactory,
HostFactory,
PubsubFactory,
)
@pytest.fixture
def num_hosts():
return 3
@pytest.fixture
async def hosts(num_hosts):
_hosts = HostFactory.create_batch(num_hosts)
await asyncio.gather(
*[_host.get_network().listen(LISTEN_MADDR) for _host in _hosts]
)
try:
yield _hosts
finally:
# Clean up
listeners = []
for _host in _hosts:
for listener in _host.get_network().listeners.values():
listener.server.close()
listeners.append(listener)
await asyncio.gather(*[listener.server.wait_closed() for listener in listeners])
@pytest.fixture @pytest.fixture

View File

@ -5,8 +5,8 @@ from libp2p.host.host_interface import IHost
from libp2p.pubsub.floodsub import FloodSub from libp2p.pubsub.floodsub import FloodSub
from libp2p.pubsub.pubsub import Pubsub from libp2p.pubsub.pubsub import Pubsub
from tests.configs import LISTEN_MADDR from tests.configs import LISTEN_MADDR
from tests.factories import FloodsubFactory, PubsubFactory
from .factories import FloodsubFactory, PubsubFactory
from .utils import message_id_generator from .utils import message_id_generator
CRYPTO_TOPIC = "ethereum" CRYPTO_TOPIC = "ethereum"

View File

@ -3,10 +3,10 @@ import asyncio
import pytest import pytest
from tests.configs import LISTEN_MADDR from tests.configs import LISTEN_MADDR
from tests.factories import PubsubFactory
from tests.utils import cleanup, connect from tests.utils import cleanup, connect
from .configs import FLOODSUB_PROTOCOL_ID from .configs import FLOODSUB_PROTOCOL_ID
from .factories import PubsubFactory
SUPPORTED_PROTOCOLS = [FLOODSUB_PROTOCOL_ID] SUPPORTED_PROTOCOLS = [FLOODSUB_PROTOCOL_ID]

View File

@ -3,9 +3,9 @@ import asyncio
import pytest import pytest
from libp2p.peer.id import ID from libp2p.peer.id import ID
from tests.factories import FloodsubFactory
from tests.utils import cleanup, connect from tests.utils import cleanup, connect
from .factories import FloodsubFactory
from .floodsub_integration_test_settings import ( from .floodsub_integration_test_settings import (
floodsub_protocol_pytest_params, floodsub_protocol_pytest_params,
perform_test_from_obj, perform_test_from_obj,

View File

@ -2,8 +2,9 @@ import functools
import pytest import pytest
from tests.factories import GossipsubFactory
from .configs import FLOODSUB_PROTOCOL_ID from .configs import FLOODSUB_PROTOCOL_ID
from .factories import GossipsubFactory
from .floodsub_integration_test_settings import ( from .floodsub_integration_test_settings import (
floodsub_protocol_pytest_params, floodsub_protocol_pytest_params,
perform_test_from_obj, perform_test_from_obj,