diff --git a/libp2p/network/connection/swarm_connection.py b/libp2p/network/connection/swarm_connection.py index 46b3f6f..48774ec 100644 --- a/libp2p/network/connection/swarm_connection.py +++ b/libp2p/network/connection/swarm_connection.py @@ -1,7 +1,7 @@ from typing import TYPE_CHECKING, Any, Awaitable, List, Set, Tuple -import trio from async_service import Service +import trio from libp2p.network.connection.net_connection_interface import INetConn from libp2p.network.stream.net_stream import NetStream diff --git a/libp2p/network/swarm.py b/libp2p/network/swarm.py index 78fb7fd..37614bc 100644 --- a/libp2p/network/swarm.py +++ b/libp2p/network/swarm.py @@ -2,7 +2,6 @@ import logging from typing import Dict, List, Optional from async_service import Service - from multiaddr import Multiaddr import trio @@ -205,7 +204,7 @@ class Swarm(INetwork, Service): logger.debug("successfully opened connection to peer %s", peer_id) # FIXME: This is a intentional barrier to prevent from the handler exiting and - # closing the connection. + # closing the connection. Probably change to `Service.manager.wait_finished`? await trio.sleep_forever() try: @@ -229,16 +228,6 @@ class Swarm(INetwork, Service): async def close(self) -> None: # TODO: Prevent from new listeners and conns being added. # Reference: https://github.com/libp2p/go-libp2p-swarm/blob/8be680aef8dea0a4497283f2f98470c2aeae6b65/swarm.go#L124-L134 # noqa: E501 - - # Close listeners - # await asyncio.gather( - # *[listener.close() for listener in self.listeners.values()] - # ) - - # # Close connections - # await asyncio.gather( - # *[connection.close() for connection in self.connections.values()] - # ) await self.manager.stop() await self.manager.wait_finished() logger.debug("swarm successfully closed") diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index 53d855b..f93acea 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -1,11 +1,11 @@ -import math import asyncio import logging +import math from typing import Any # noqa: F401 from typing import Awaitable, Dict, List, Optional, Tuple -import trio from async_service import Service +import trio from libp2p.exceptions import ParseError from libp2p.io.exceptions import IncompleteReadError diff --git a/libp2p/tools/factories.py b/libp2p/tools/factories.py index 448ec90..470cbc3 100644 --- a/libp2p/tools/factories.py +++ b/libp2p/tools/factories.py @@ -1,9 +1,9 @@ -import trio -from contextlib import asynccontextmanager, AsyncExitStack +from contextlib import AsyncExitStack, asynccontextmanager from typing import Any, AsyncIterator, Dict, Tuple, cast -import factory from async_service import background_trio_service +import factory +import trio from libp2p import generate_new_rsa_identity, generate_peer_id_from from libp2p.crypto.keys import KeyPair diff --git a/libp2p/tools/utils.py b/libp2p/tools/utils.py index 0d39156..db1e8ab 100644 --- a/libp2p/tools/utils.py +++ b/libp2p/tools/utils.py @@ -1,7 +1,7 @@ -import trio from typing import List, Sequence, Tuple import multiaddr +import trio from libp2p import new_node from libp2p.host.basic_host import BasicHost diff --git a/libp2p/transport/tcp/tcp.py b/libp2p/transport/tcp/tcp.py index 1636598..745bafe 100644 --- a/libp2p/transport/tcp/tcp.py +++ b/libp2p/transport/tcp/tcp.py @@ -25,7 +25,8 @@ class TCPListener(IListener): self.server = None self.handler = handler_function - async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> bool: + # TODO: Fix handling? + async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> None: """ put listener in listening mode and wait for incoming connections. @@ -50,16 +51,13 @@ class TCPListener(IListener): socket = listeners[0].socket self.multiaddrs.append(_multiaddr_from_socket(socket)) - return True - def get_addrs(self) -> List[Multiaddr]: """ retrieve list of addresses the listener is listening on. :return: return list of addrs """ - # TODO check if server is listening - return self.multiaddrs + return tuple(self.multiaddrs) async def close(self) -> None: """close the listener such that no more connections can be open on this diff --git a/tests/network/test_swarm.py b/tests/network/test_swarm.py index de08635..1492441 100644 --- a/tests/network/test_swarm.py +++ b/tests/network/test_swarm.py @@ -1,5 +1,5 @@ -import trio import pytest +import trio from trio.testing import wait_all_tasks_blocked from libp2p.network.exceptions import SwarmException diff --git a/tests/stream_muxer/test_mplex_conn.py b/tests/stream_muxer/test_mplex_conn.py index d48432d..4cedc36 100644 --- a/tests/stream_muxer/test_mplex_conn.py +++ b/tests/stream_muxer/test_mplex_conn.py @@ -1,6 +1,5 @@ -import trio - import pytest +import trio @pytest.mark.trio diff --git a/tests/stream_muxer/test_mplex_stream.py b/tests/stream_muxer/test_mplex_stream.py index e3d19a5..e47af49 100644 --- a/tests/stream_muxer/test_mplex_stream.py +++ b/tests/stream_muxer/test_mplex_stream.py @@ -7,11 +7,10 @@ from libp2p.stream_muxer.mplex.exceptions import ( MplexStreamEOF, MplexStreamReset, ) -from libp2p.tools.constants import MAX_READ_LEN, LISTEN_MADDR +from libp2p.tools.constants import LISTEN_MADDR, MAX_READ_LEN from libp2p.tools.factories import SwarmFactory from libp2p.tools.utils import connect_swarm - DATA = b"data_123" diff --git a/tests/transport/test_tcp.py b/tests/transport/test_tcp.py index 7231a06..c8fe6f2 100644 --- a/tests/transport/test_tcp.py +++ b/tests/transport/test_tcp.py @@ -1,20 +1,47 @@ -import asyncio - +from multiaddr import Multiaddr import pytest +import trio -from libp2p.transport.tcp.tcp import _multiaddr_from_socket +from libp2p.network.connection.raw_connection import RawConnection +from libp2p.tools.constants import LISTEN_MADDR, MAX_READ_LEN +from libp2p.transport.tcp.tcp import TCP -@pytest.mark.asyncio -async def test_multiaddr_from_socket(): - def handler(r, w): - pass +@pytest.mark.trio +async def test_tcp_listener(nursery): + transport = TCP() - server = await asyncio.start_server(handler, "127.0.0.1", 8000) - assert str(_multiaddr_from_socket(server.sockets[0])) == "/ip4/127.0.0.1/tcp/8000" + async def handler(tcp_stream): + ... - server = await asyncio.start_server(handler, "127.0.0.1", 0) - addr = _multiaddr_from_socket(server.sockets[0]) - assert addr.value_for_protocol("ip4") == "127.0.0.1" - port = addr.value_for_protocol("tcp") - assert int(port) > 0 + listener = transport.create_listener(handler) + assert len(listener.get_addrs()) == 0 + await listener.listen(LISTEN_MADDR, nursery) + assert len(listener.get_addrs()) == 1 + await listener.listen(LISTEN_MADDR, nursery) + assert len(listener.get_addrs()) == 2 + + +@pytest.mark.trio +async def test_tcp_dial(nursery): + transport = TCP() + raw_conn_other_side = None + + async def handler(tcp_stream): + nonlocal raw_conn_other_side + raw_conn_other_side = RawConnection(tcp_stream, False) + await trio.sleep_forever() + + # Test: OSError is raised when trying to dial to a port which no one is not listening to. + with pytest.raises(OSError): + await transport.dial(Multiaddr("/ip4/127.0.0.1/tcp/1")) + + listener = transport.create_listener(handler) + await listener.listen(LISTEN_MADDR, nursery) + assert len(listener.multiaddrs) == 1 + listen_addr = listener.multiaddrs[0] + raw_conn = await transport.dial(listen_addr) + + data = b"123" + await raw_conn_other_side.write(data) + assert (await raw_conn.read(len(data))) == data