diff --git a/libp2p/host/basic_host.py b/libp2p/host/basic_host.py index cdee9a0..e6fc515 100644 --- a/libp2p/host/basic_host.py +++ b/libp2p/host/basic_host.py @@ -107,3 +107,9 @@ class BasicHost(IHost): return await self._network.dial_peer(peer_info.peer_id) + + async def disconnect(self, peer_id: ID) -> None: + await self._network.close_peer(peer_id) + + async def close(self) -> None: + await self._network.close() diff --git a/libp2p/host/host_interface.py b/libp2p/host/host_interface.py index bcaefad..6b1ef03 100644 --- a/libp2p/host/host_interface.py +++ b/libp2p/host/host_interface.py @@ -71,3 +71,11 @@ class IHost(ABC): :param peer_info: peer_info of the host we want to connect to :type peer_info: peer.peerinfo.PeerInfo """ + + @abstractmethod + async def disconnect(self, peer_id: ID) -> None: + pass + + @abstractmethod + async def close(self) -> None: + pass diff --git a/libp2p/network/network_interface.py b/libp2p/network/network_interface.py index d9cdf48..9ed2d16 100644 --- a/libp2p/network/network_interface.py +++ b/libp2p/network/network_interface.py @@ -70,3 +70,11 @@ class INetwork(ABC): :param notifee: object implementing Notifee interface :return: true if notifee registered successfully, false otherwise """ + + @abstractmethod + async def close(self) -> None: + pass + + @abstractmethod + async def close_peer(self, peer_id: ID) -> None: + pass diff --git a/libp2p/network/swarm.py b/libp2p/network/swarm.py index 52df727..d0f2fe8 100644 --- a/libp2p/network/swarm.py +++ b/libp2p/network/swarm.py @@ -264,12 +264,24 @@ class Swarm(INetwork): def add_router(self, router: IPeerRouting) -> None: self.router = router - # TODO: `tear_down` - async def tear_down(self) -> None: - # Reference: https://github.com/libp2p/go-libp2p-swarm/blob/8be680aef8dea0a4497283f2f98470c2aeae6b65/swarm.go#L118 # noqa: E501 - pass + async def close(self) -> None: + # TODO: Prevent from new listeners and conns being added. + # Reference: https://github.com/libp2p/go-libp2p-swarm/blob/8be680aef8dea0a4497283f2f98470c2aeae6b65/swarm.go#L124-L134 # noqa: E501 - # TODO: `disconnect`? + # Close listeners + await asyncio.gather( + *[listener.close() for listener in self.listeners.values()] + ) + + # Close connections + await asyncio.gather( + *[connection.close() for connection in self.connections.values()] + ) + + async def close_peer(self, peer_id: ID) -> None: + connection = self.connections[peer_id] + del self.connections[peer_id] + await connection.close() def create_generic_protocol_handler(swarm: Swarm) -> GenericProtocolHandlerFn: diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index 40fa67b..5f55a66 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -1,5 +1,6 @@ import asyncio -from typing import Dict, Optional, Tuple +from typing import Any # noqa: F401 +from typing import Dict, List, Optional, Tuple from libp2p.network.typing import GenericProtocolHandlerFn from libp2p.peer.id import ID @@ -34,6 +35,8 @@ class Mplex(IMuxedConn): stream_queue: "asyncio.Queue[StreamID]" next_channel_id: int + _tasks: List["asyncio.Future[Any]"] + # TODO: `generic_protocol_handler` should be refactored out of mplex conn. def __init__( self, @@ -63,8 +66,10 @@ class Mplex(IMuxedConn): self.stream_queue = asyncio.Queue() + self._tasks = [] + # Kick off reading - asyncio.ensure_future(self.handle_incoming()) + self._tasks.append(asyncio.ensure_future(self.handle_incoming())) @property def initiator(self) -> bool: @@ -74,6 +79,8 @@ class Mplex(IMuxedConn): """ close the stream muxer and underlying secured connection """ + for task in self._tasks: + task.cancel() await self.secured_conn.close() def is_closed(self) -> bool: @@ -135,7 +142,7 @@ class Mplex(IMuxedConn): """ stream_id = await self.stream_queue.get() stream = MplexStream(name, stream_id, self) - asyncio.ensure_future(self.generic_protocol_handler(stream)) + self._tasks.append(asyncio.ensure_future(self.generic_protocol_handler(stream))) async def send_message( self, flag: HeaderTags, data: bytes, stream_id: StreamID diff --git a/libp2p/transport/listener_interface.py b/libp2p/transport/listener_interface.py index fecc3b9..9664f06 100644 --- a/libp2p/transport/listener_interface.py +++ b/libp2p/transport/listener_interface.py @@ -21,9 +21,8 @@ class IListener(ABC): """ @abstractmethod - def close(self) -> bool: + async def close(self) -> None: """ close the listener such that no more connections can be open on this transport instance - :return: return True if successful """ diff --git a/libp2p/transport/tcp/tcp.py b/libp2p/transport/tcp/tcp.py index 8e29f9b..49c7d9b 100644 --- a/libp2p/transport/tcp/tcp.py +++ b/libp2p/transport/tcp/tcp.py @@ -45,20 +45,16 @@ class TCPListener(IListener): # TODO check if server is listening return self.multiaddrs - def close(self) -> bool: + async def close(self) -> None: """ close the listener such that no more connections can be open on this transport instance - :return: return True if successful """ if self.server is None: - return False + return self.server.close() - _loop = asyncio.get_event_loop() - _loop.run_until_complete(self.server.wait_closed()) - _loop.close() + await self.server.wait_closed() self.server = None - return True class TCP(ITransport): diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..9101fa6 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,33 @@ +import asyncio + +import pytest + +from .configs import LISTEN_MADDR +from .factories import HostFactory + + +@pytest.fixture +def is_host_secure(): + return False + + +@pytest.fixture +def num_hosts(): + return 3 + + +@pytest.fixture +async def hosts(num_hosts, is_host_secure): + _hosts = HostFactory.create_batch(num_hosts, is_secure=is_host_secure) + await asyncio.gather( + *[_host.get_network().listen(LISTEN_MADDR) for _host in _hosts] + ) + try: + yield _hosts + finally: + # TODO: It's possible that `close` raises exceptions currently, + # due to the connection reset things. Though we are not so careful about that when + # cleaning up the tasks, it is probably better to handle the exceptions properly. + await asyncio.gather( + *[_host.close() for _host in _hosts], return_exceptions=True + ) diff --git a/tests/pubsub/factories.py b/tests/factories.py similarity index 51% rename from tests/pubsub/factories.py rename to tests/factories.py index b57c29b..0604094 100644 --- a/tests/pubsub/factories.py +++ b/tests/factories.py @@ -1,12 +1,17 @@ +from typing import Dict + import factory -from libp2p import initialize_default_swarm -from libp2p.crypto.rsa import create_new_key_pair +from libp2p import generate_new_rsa_identity, initialize_default_swarm +from libp2p.crypto.keys import KeyPair from libp2p.host.basic_host import BasicHost from libp2p.pubsub.floodsub import FloodSub from libp2p.pubsub.gossipsub import GossipSub from libp2p.pubsub.pubsub import Pubsub -from tests.configs import LISTEN_MADDR +from libp2p.security.base_transport import BaseSecureTransport +from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport +from libp2p.security.secio.transport import ID, Transport +from libp2p.typing import TProtocol from tests.pubsub.configs import ( FLOODSUB_PROTOCOL_ID, GOSSIPSUB_PARAMS, @@ -14,16 +19,34 @@ from tests.pubsub.configs import ( ) -def swarm_factory(): - private_key = create_new_key_pair() - return initialize_default_swarm(private_key, transport_opt=[str(LISTEN_MADDR)]) +def security_transport_factory( + is_secure: bool, key_pair: KeyPair +) -> Dict[TProtocol, BaseSecureTransport]: + protocol_id: TProtocol + security_transport: BaseSecureTransport + if not is_secure: + protocol_id = PLAINTEXT_PROTOCOL_ID + security_transport = InsecureTransport(key_pair) + else: + protocol_id = ID + security_transport = Transport(key_pair) + return {protocol_id: security_transport} + + +def swarm_factory(is_secure: bool): + key_pair = generate_new_rsa_identity() + sec_opt = security_transport_factory(is_secure, key_pair) + return initialize_default_swarm(key_pair, sec_opt=sec_opt) class HostFactory(factory.Factory): class Meta: model = BasicHost - network = factory.LazyFunction(swarm_factory) + class Params: + is_secure = False + + network = factory.LazyAttribute(lambda o: swarm_factory(o.is_secure)) class FloodsubFactory(factory.Factory): diff --git a/tests/interop/test_echo.py b/tests/interop/test_echo.py index 57f7bd3..513ea86 100644 --- a/tests/interop/test_echo.py +++ b/tests/interop/test_echo.py @@ -7,11 +7,8 @@ from multiaddr import Multiaddr import pexpect import pytest -from libp2p import generate_new_rsa_identity, new_node from libp2p.peer.peerinfo import info_from_p2p_addr -from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport from libp2p.typing import TProtocol -from tests.configs import LISTEN_MADDR GOPATH = pathlib.Path(os.environ["GOPATH"]) ECHO_PATH = GOPATH / "bin" / "echo" @@ -19,50 +16,68 @@ ECHO_PROTOCOL_ID = TProtocol("/echo/1.0.0") NEW_LINE = "\r\n" -@pytest.mark.asyncio -async def test_insecure_conn_py_to_go(unused_tcp_port): +@pytest.fixture +def proc_factory(): + procs = [] + + def call_proc(cmd, args, logfile=None, encoding=None): + if logfile is None: + logfile = sys.stdout + if encoding is None: + encoding = "utf-8" + proc = pexpect.spawn(cmd, args, logfile=logfile, encoding=encoding) + procs.append(proc) + return proc + try: - go_proc = pexpect.spawn( - str(ECHO_PATH), - [f"-l={unused_tcp_port}", "-insecure"], - logfile=sys.stdout, - encoding="utf-8", - ) - await go_proc.expect(r"I am ([\w\./]+)" + NEW_LINE, async_=True) - maddr_str = go_proc.match.group(1) - maddr_str = maddr_str.replace("ipfs", "p2p") - maddr = Multiaddr(maddr_str) - go_pinfo = info_from_p2p_addr(maddr) - await go_proc.expect("listening for connections", async_=True) - - key_pair = generate_new_rsa_identity() - insecure_tpt = InsecureTransport(key_pair) - host = await new_node( - key_pair=key_pair, sec_opt={PLAINTEXT_PROTOCOL_ID: insecure_tpt} - ) - await host.connect(go_pinfo) - await go_proc.expect("swarm listener accepted connection", async_=True) - s = await host.new_stream(go_pinfo.peer_id, [ECHO_PROTOCOL_ID]) - - await go_proc.expect("Got a new stream!", async_=True) - data = "data321123\n" - await s.write(data.encode()) - await go_proc.expect(f"read: {data[:-1]}", async_=True) - echoed_resp = await s.read(len(data)) - assert echoed_resp.decode() == data - await s.close() + yield call_proc finally: - go_proc.close() + for proc in procs: + proc.close() +async def make_echo_proc( + proc_factory, port: int, is_secure: bool, destination: Multiaddr = None +): + args = [f"-l={port}"] + if not is_secure: + args.append("-insecure") + if destination is not None: + args.append(f"-d={str(destination)}") + echo_proc = proc_factory(str(ECHO_PATH), args, logfile=sys.stdout, encoding="utf-8") + await echo_proc.expect(r"I am ([\w\./]+)" + NEW_LINE, async_=True) + maddr_str_ipfs = echo_proc.match.group(1) + maddr_str = maddr_str_ipfs.replace("ipfs", "p2p") + maddr = Multiaddr(maddr_str) + go_pinfo = info_from_p2p_addr(maddr) + if destination is None: + await echo_proc.expect("listening for connections", async_=True) + return echo_proc, go_pinfo + + +@pytest.mark.parametrize("num_hosts", (1,)) @pytest.mark.asyncio -async def test_insecure_conn_go_to_py(unused_tcp_port): - key_pair = generate_new_rsa_identity() - insecure_tpt = InsecureTransport(key_pair) - host = await new_node( - key_pair=key_pair, sec_opt={PLAINTEXT_PROTOCOL_ID: insecure_tpt} - ) - await host.get_network().listen(LISTEN_MADDR) +async def test_insecure_conn_py_to_go(hosts, proc_factory, unused_tcp_port): + go_proc, go_pinfo = await make_echo_proc(proc_factory, unused_tcp_port, False) + + host = hosts[0] + await host.connect(go_pinfo) + await go_proc.expect("swarm listener accepted connection", async_=True) + s = await host.new_stream(go_pinfo.peer_id, [ECHO_PROTOCOL_ID]) + + await go_proc.expect("Got a new stream!", async_=True) + data = "data321123\n" + await s.write(data.encode()) + await go_proc.expect(f"read: {data[:-1]}", async_=True) + echoed_resp = await s.read(len(data)) + assert echoed_resp.decode() == data + await s.close() + + +@pytest.mark.parametrize("num_hosts", (1,)) +@pytest.mark.asyncio +async def test_insecure_conn_go_to_py(hosts, proc_factory, unused_tcp_port): + host = hosts[0] expected_data = "Hello, world!\n" reply_data = "Replyooo!\n" event_handler_finished = asyncio.Event() @@ -76,17 +91,8 @@ async def test_insecure_conn_go_to_py(unused_tcp_port): host.set_stream_handler(ECHO_PROTOCOL_ID, _handle_echo) py_maddr = host.get_addrs()[0] - go_proc = pexpect.spawn( - str(ECHO_PATH), - [f"-l={unused_tcp_port}", "-insecure", f"-d={str(py_maddr)}"], - logfile=sys.stdout, - encoding="utf-8", - ) - try: - await go_proc.expect(r"I am ([\w\./]+)" + NEW_LINE, async_=True) - await go_proc.expect("connect with peer", async_=True) - await go_proc.expect("opened stream", async_=True) - await event_handler_finished.wait() - await go_proc.expect(f"read reply: .*{reply_data.rstrip()}.*", async_=True) - finally: - go_proc.close() + go_proc, _ = await make_echo_proc(proc_factory, unused_tcp_port, False, py_maddr) + await go_proc.expect("connect with peer", async_=True) + await go_proc.expect("opened stream", async_=True) + await event_handler_finished.wait() + await go_proc.expect(f"read reply: .*{reply_data.rstrip()}.*", async_=True) diff --git a/tests/pubsub/conftest.py b/tests/pubsub/conftest.py index 12faff3..1755ee5 100644 --- a/tests/pubsub/conftest.py +++ b/tests/pubsub/conftest.py @@ -1,38 +1,7 @@ -import asyncio - import pytest -from tests.configs import LISTEN_MADDR +from tests.factories import FloodsubFactory, GossipsubFactory, PubsubFactory from tests.pubsub.configs import GOSSIPSUB_PARAMS -from tests.pubsub.factories import ( - FloodsubFactory, - GossipsubFactory, - HostFactory, - PubsubFactory, -) - - -@pytest.fixture -def num_hosts(): - return 3 - - -@pytest.fixture -async def hosts(num_hosts): - _hosts = HostFactory.create_batch(num_hosts) - await asyncio.gather( - *[_host.get_network().listen(LISTEN_MADDR) for _host in _hosts] - ) - try: - yield _hosts - finally: - # Clean up - listeners = [] - for _host in _hosts: - for listener in _host.get_network().listeners.values(): - listener.server.close() - listeners.append(listener) - await asyncio.gather(*[listener.server.wait_closed() for listener in listeners]) @pytest.fixture diff --git a/tests/pubsub/dummy_account_node.py b/tests/pubsub/dummy_account_node.py index c2ca8bf..98a224f 100644 --- a/tests/pubsub/dummy_account_node.py +++ b/tests/pubsub/dummy_account_node.py @@ -5,8 +5,8 @@ from libp2p.host.host_interface import IHost from libp2p.pubsub.floodsub import FloodSub from libp2p.pubsub.pubsub import Pubsub from tests.configs import LISTEN_MADDR +from tests.factories import FloodsubFactory, PubsubFactory -from .factories import FloodsubFactory, PubsubFactory from .utils import message_id_generator CRYPTO_TOPIC = "ethereum" diff --git a/tests/pubsub/floodsub_integration_test_settings.py b/tests/pubsub/floodsub_integration_test_settings.py index 09e5c30..d96fc2b 100644 --- a/tests/pubsub/floodsub_integration_test_settings.py +++ b/tests/pubsub/floodsub_integration_test_settings.py @@ -3,10 +3,10 @@ import asyncio import pytest from tests.configs import LISTEN_MADDR +from tests.factories import PubsubFactory from tests.utils import cleanup, connect from .configs import FLOODSUB_PROTOCOL_ID -from .factories import PubsubFactory SUPPORTED_PROTOCOLS = [FLOODSUB_PROTOCOL_ID] diff --git a/tests/pubsub/test_floodsub.py b/tests/pubsub/test_floodsub.py index 0a67c65..7e079d1 100644 --- a/tests/pubsub/test_floodsub.py +++ b/tests/pubsub/test_floodsub.py @@ -3,9 +3,9 @@ import asyncio import pytest from libp2p.peer.id import ID +from tests.factories import FloodsubFactory from tests.utils import cleanup, connect -from .factories import FloodsubFactory from .floodsub_integration_test_settings import ( floodsub_protocol_pytest_params, perform_test_from_obj, diff --git a/tests/pubsub/test_gossipsub_backward_compatibility.py b/tests/pubsub/test_gossipsub_backward_compatibility.py index e76ce04..3f2224f 100644 --- a/tests/pubsub/test_gossipsub_backward_compatibility.py +++ b/tests/pubsub/test_gossipsub_backward_compatibility.py @@ -2,8 +2,9 @@ import functools import pytest +from tests.factories import GossipsubFactory + from .configs import FLOODSUB_PROTOCOL_ID -from .factories import GossipsubFactory from .floodsub_integration_test_settings import ( floodsub_protocol_pytest_params, perform_test_from_obj,