diff --git a/libp2p/network/stream/exceptions.py b/libp2p/network/stream/exceptions.py new file mode 100644 index 0000000..58f3ddf --- /dev/null +++ b/libp2p/network/stream/exceptions.py @@ -0,0 +1,17 @@ +from libp2p.exceptions import BaseLibp2pError + + +class StreamError(BaseLibp2pError): + pass + + +class StreamEOF(StreamError, EOFError): + pass + + +class StreamReset(StreamError): + pass + + +class StreamClosed(StreamError): + pass diff --git a/libp2p/network/stream/net_stream.py b/libp2p/network/stream/net_stream.py index 7383f73..4dedab7 100644 --- a/libp2p/network/stream/net_stream.py +++ b/libp2p/network/stream/net_stream.py @@ -1,9 +1,18 @@ from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream +from libp2p.stream_muxer.exceptions import ( + MuxedStreamClosed, + MuxedStreamEOF, + MuxedStreamReset, +) from libp2p.typing import TProtocol +from .exceptions import StreamClosed, StreamEOF, StreamReset from .net_stream_interface import INetStream +# TODO: Handle exceptions from `muxed_stream` +# TODO: Add stream state +# - Reference: https://github.com/libp2p/go-libp2p-swarm/blob/99831444e78c8f23c9335c17d8f7c700ba25ca14/swarm_stream.go # noqa: E501 class NetStream(INetStream): muxed_stream: IMuxedStream @@ -35,14 +44,22 @@ class NetStream(INetStream): :param n: number of bytes to read :return: bytes of input """ - return await self.muxed_stream.read(n) + try: + return await self.muxed_stream.read(n) + except MuxedStreamEOF as error: + raise StreamEOF from error + except MuxedStreamReset as error: + raise StreamReset from error async def write(self, data: bytes) -> int: """ write to stream :return: number of bytes written """ - return await self.muxed_stream.write(data) + try: + return await self.muxed_stream.write(data) + except MuxedStreamClosed as error: + raise StreamClosed from error async def close(self) -> None: """ @@ -51,5 +68,5 @@ class NetStream(INetStream): """ await self.muxed_stream.close() - async def reset(self) -> bool: - return await self.muxed_stream.reset() + async def reset(self) -> None: + await self.muxed_stream.reset() diff --git a/libp2p/network/stream/net_stream_interface.py b/libp2p/network/stream/net_stream_interface.py index aaa775a..53ce038 100644 --- a/libp2p/network/stream/net_stream_interface.py +++ b/libp2p/network/stream/net_stream_interface.py @@ -23,7 +23,7 @@ class INetStream(ReadWriteCloser): """ @abstractmethod - async def reset(self) -> bool: + async def reset(self) -> None: """ Close both ends of the stream. """ diff --git a/libp2p/stream_muxer/exceptions.py b/libp2p/stream_muxer/exceptions.py new file mode 100644 index 0000000..861319a --- /dev/null +++ b/libp2p/stream_muxer/exceptions.py @@ -0,0 +1,25 @@ +from libp2p.exceptions import BaseLibp2pError + + +class MuxedConnError(BaseLibp2pError): + pass + + +class MuxedConnShutdown(MuxedConnError): + pass + + +class MuxedStreamError(BaseLibp2pError): + pass + + +class MuxedStreamReset(MuxedStreamError): + pass + + +class MuxedStreamEOF(MuxedStreamError, EOFError): + pass + + +class MuxedStreamClosed(MuxedStreamError): + pass diff --git a/libp2p/stream_muxer/mplex/exceptions.py b/libp2p/stream_muxer/mplex/exceptions.py index 11663e2..154c371 100644 --- a/libp2p/stream_muxer/mplex/exceptions.py +++ b/libp2p/stream_muxer/mplex/exceptions.py @@ -1,17 +1,27 @@ -from libp2p.exceptions import BaseLibp2pError +from libp2p.stream_muxer.exceptions import ( + MuxedConnError, + MuxedConnShutdown, + MuxedStreamClosed, + MuxedStreamEOF, + MuxedStreamReset, +) -class MplexError(BaseLibp2pError): +class MplexError(MuxedConnError): pass -class MplexStreamReset(MplexError): +class MplexShutdown(MuxedConnShutdown): pass -class MplexStreamEOF(MplexError, EOFError): +class MplexStreamReset(MuxedStreamReset): pass -class MplexShutdown(MplexError): +class MplexStreamEOF(MuxedStreamEOF): + pass + + +class MplexStreamClosed(MuxedStreamClosed): pass diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index 1e8823a..c75000d 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -188,6 +188,10 @@ class Mplex(IMuxedConn): # before. It is abnormal. Possibly disconnect? # TODO: Warn and emit logs about this. continue + async with stream.close_lock: + if stream.event_remote_closed.is_set(): + # TODO: Warn "Received data from remote after stream was closed by them. (len = %d)" # noqa: E501 + continue await stream.incoming_data.put(message) elif flag in ( HeaderTags.CloseInitiator.value, diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index 18c8ff0..87b039f 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -5,7 +5,7 @@ from libp2p.stream_muxer.abc import IMuxedStream from .constants import HeaderTags from .datastructures import StreamID -from .exceptions import MplexStreamEOF, MplexStreamReset +from .exceptions import MplexStreamClosed, MplexStreamEOF, MplexStreamReset if TYPE_CHECKING: from libp2p.stream_muxer.mplex.mplex import Mplex @@ -55,22 +55,46 @@ class MplexStream(IMuxedStream): return self.stream_id.is_initiator async def _wait_for_data(self) -> None: + task_event_reset = asyncio.ensure_future(self.event_reset.wait()) + task_incoming_data_get = asyncio.ensure_future(self.incoming_data.get()) + task_event_remote_closed = asyncio.ensure_future( + self.event_remote_closed.wait() + ) done, pending = await asyncio.wait( # type: ignore - [ - self.event_reset.wait(), - self.event_remote_closed.wait(), - self.incoming_data.get(), - ], + [task_event_reset, task_incoming_data_get, task_event_remote_closed], return_when=asyncio.FIRST_COMPLETED, ) - if self.event_reset.is_set(): - raise MplexStreamReset - if self.event_remote_closed.is_set(): - raise MplexStreamEOF - # TODO: Handle timeout when deadline is used. + for fut in pending: + fut.cancel() - data = tuple(done)[0].result() - self._buf.extend(data) + if task_event_reset in done: + if self.event_reset.is_set(): + raise MplexStreamReset + else: + # However, it is abnormal that `Event.wait` is unblocked without any of the flag + # is set. The task is probably cancelled. + raise Exception( + "Should not enter here. " + f"It is probably because {task_event_remote_closed} is cancelled." + ) + + if task_incoming_data_get in done: + data = task_incoming_data_get.result() + self._buf.extend(data) + return + + if task_event_remote_closed in done: + if self.event_remote_closed.is_set(): + raise MplexStreamEOF + else: + # However, it is abnormal that `Event.wait` is unblocked without any of the flag + # is set. The task is probably cancelled. + raise Exception( + "Should not enter here. " + f"It is probably because {task_event_remote_closed} is cancelled." + ) + + # TODO: Handle timeout when deadline is used. async def _read_until_eof(self) -> bytes: while True: @@ -90,7 +114,6 @@ class MplexStream(IMuxedStream): :param n: number of bytes to read :return: bytes actually read """ - # TODO: Add exceptions and handle/raise them in this class. if n < 0 and n != -1: raise ValueError( f"the number of bytes to read `n` must be positive or -1 to indicate read until EOF" @@ -99,13 +122,15 @@ class MplexStream(IMuxedStream): raise MplexStreamReset if n == -1: return await self._read_until_eof() - if len(self._buf) == 0: + if len(self._buf) == 0 and self.incoming_data.empty(): await self._wait_for_data() - # Read up to `n` bytes. + # Now we are sure we have something to read. + # Try to put enough incoming data into `self._buf`. while len(self._buf) < n: - if self.incoming_data.empty() or self.event_remote_closed.is_set(): + try: + self._buf.extend(self.incoming_data.get_nowait()) + except asyncio.QueueEmpty: break - self._buf.extend(await self.incoming_data.get()) payload = self._buf[:n] self._buf = self._buf[len(payload) :] return bytes(payload) @@ -115,6 +140,8 @@ class MplexStream(IMuxedStream): write to stream :return: number of bytes written """ + if self.event_local_closed.is_set(): + raise MplexStreamClosed(f"cannot write to closed stream: data={data}") flag = ( HeaderTags.MessageInitiator if self.is_initiator diff --git a/tests/examples/test_chat.py b/tests/examples/test_chat.py index 75d8ec7..18a172c 100644 --- a/tests/examples/test_chat.py +++ b/tests/examples/test_chat.py @@ -4,7 +4,7 @@ import pytest from libp2p.peer.peerinfo import info_from_p2p_addr from libp2p.protocol_muxer.exceptions import MultiselectClientError -from tests.utils import cleanup, set_up_nodes_by_transport_opt +from tests.utils import set_up_nodes_by_transport_opt PROTOCOL_ID = "/chat/1.0.0" @@ -101,5 +101,3 @@ async def test_chat(test): await host_b.connect(info) await test(host_a, host_b) - - await cleanup() diff --git a/tests/factories.py b/tests/factories.py index 240bdb8..0f69707 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -1,22 +1,28 @@ -from typing import Dict +import asyncio +from typing import Dict, Tuple import factory from libp2p import generate_new_rsa_identity, initialize_default_swarm from libp2p.crypto.keys import KeyPair from libp2p.host.basic_host import BasicHost +from libp2p.host.host_interface import IHost +from libp2p.network.stream.net_stream_interface import INetStream from libp2p.pubsub.floodsub import FloodSub from libp2p.pubsub.gossipsub import GossipSub from libp2p.pubsub.pubsub import Pubsub from libp2p.security.base_transport import BaseSecureTransport from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport import libp2p.security.secio.transport as secio +from libp2p.stream_muxer.mplex.mplex import Mplex from libp2p.typing import TProtocol +from tests.configs import LISTEN_MADDR from tests.pubsub.configs import ( FLOODSUB_PROTOCOL_ID, GOSSIPSUB_PARAMS, GOSSIPSUB_PROTOCOL_ID, ) +from tests.utils import connect def security_transport_factory( @@ -43,6 +49,12 @@ class HostFactory(factory.Factory): network = factory.LazyAttribute(lambda o: swarm_factory(o.is_secure)) + @classmethod + async def create_and_listen(cls) -> IHost: + host = cls() + await host.get_network().listen(LISTEN_MADDR) + return host + class FloodsubFactory(factory.Factory): class Meta: @@ -73,3 +85,37 @@ class PubsubFactory(factory.Factory): router = None my_id = factory.LazyAttribute(lambda obj: obj.host.get_id()) cache_size = None + + +async def host_pair_factory() -> Tuple[BasicHost, BasicHost]: + hosts = await asyncio.gather( + *[HostFactory.create_and_listen(), HostFactory.create_and_listen()] + ) + await connect(hosts[0], hosts[1]) + return hosts[0], hosts[1] + + +async def connection_pair_factory() -> Tuple[Mplex, BasicHost, Mplex, BasicHost]: + host_0, host_1 = await host_pair_factory() + mplex_conn_0 = host_0.get_network().connections[host_1.get_id()] + mplex_conn_1 = host_1.get_network().connections[host_0.get_id()] + return mplex_conn_0, host_0, mplex_conn_1, host_1 + + +async def net_stream_pair_factory() -> Tuple[ + INetStream, BasicHost, INetStream, BasicHost +]: + protocol_id = "/example/id/1" + + stream_1: INetStream + + # Just a proxy, we only care about the stream + def handler(stream: INetStream) -> None: + nonlocal stream_1 + stream_1 = stream + + host_0, host_1 = await host_pair_factory() + host_1.set_stream_handler(protocol_id, handler) + + stream_0 = await host_0.new_stream(host_1.get_id(), [protocol_id]) + return stream_0, host_0, stream_1, host_1 diff --git a/tests/interop/conftest.py b/tests/interop/conftest.py index 7261ee7..c280a4c 100644 --- a/tests/interop/conftest.py +++ b/tests/interop/conftest.py @@ -2,13 +2,16 @@ import asyncio import sys from typing import Union +from p2pclient.datastructures import StreamInfo import pexpect import pytest +from libp2p.io.abc import ReadWriteCloser from tests.factories import FloodsubFactory, GossipsubFactory, PubsubFactory from tests.pubsub.configs import GOSSIPSUB_PARAMS from .daemon import Daemon, make_p2pd +from .utils import connect @pytest.fixture @@ -78,3 +81,71 @@ def pubsubs(num_hosts, hosts, is_gossipsub): ) yield _pubsubs # TODO: Clean up + + +class DaemonStream(ReadWriteCloser): + stream_info: StreamInfo + reader: asyncio.StreamReader + writer: asyncio.StreamWriter + + def __init__( + self, + stream_info: StreamInfo, + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter, + ) -> None: + self.stream_info = stream_info + self.reader = reader + self.writer = writer + + async def close(self) -> None: + self.writer.close() + await self.writer.wait_closed() + + async def read(self, n: int = -1) -> bytes: + return await self.reader.read(n) + + async def write(self, data: bytes) -> int: + return self.writer.write(data) + + +@pytest.fixture +async def is_to_fail_daemon_stream(): + return False + + +@pytest.fixture +async def py_to_daemon_stream_pair(hosts, p2pds, is_to_fail_daemon_stream): + assert len(hosts) >= 1 + assert len(p2pds) >= 1 + host = hosts[0] + p2pd = p2pds[0] + protocol_id = "/protocol/id/123" + stream_py = None + stream_daemon = None + event_stream_handled = asyncio.Event() + await connect(host, p2pd) + + async def daemon_stream_handler(stream_info, reader, writer): + nonlocal stream_daemon + stream_daemon = DaemonStream(stream_info, reader, writer) + event_stream_handled.set() + + await p2pd.control.stream_handler(protocol_id, daemon_stream_handler) + + if is_to_fail_daemon_stream: + # FIXME: This is a workaround to make daemon reset the stream. + # We intentionally close the listener on the python side, it makes the connection from + # daemon to us fail, and therefore the daemon resets the opened stream on their side. + # Reference: https://github.com/libp2p/go-libp2p-daemon/blob/b95e77dbfcd186ccf817f51e95f73f9fd5982600/stream.go#L47-L50 # noqa: E501 + # We need it because we want to test against `stream_py` after the remote side(daemon) + # is reset. This should be removed after the API `stream.reset` is exposed in daemon + # some day. + listener = p2pds[0].control.control.listener + listener.close() + await listener.wait_closed() + stream_py = await host.new_stream(p2pd.peer_id, [protocol_id]) + if not is_to_fail_daemon_stream: + await event_stream_handled.wait() + # NOTE: If `is_to_fail_daemon_stream == True`, `stream_daemon == None`. + yield stream_py, stream_daemon diff --git a/tests/interop/test_net_stream.py b/tests/interop/test_net_stream.py new file mode 100644 index 0000000..0171339 --- /dev/null +++ b/tests/interop/test_net_stream.py @@ -0,0 +1,74 @@ +import asyncio + +import pytest + +from libp2p.network.stream.exceptions import StreamClosed, StreamEOF, StreamReset +from tests.constants import MAX_READ_LEN + +DATA = b"data" + + +@pytest.mark.asyncio +async def test_net_stream_read_write(py_to_daemon_stream_pair, p2pds): + stream_py, stream_daemon = py_to_daemon_stream_pair + assert ( + stream_py.protocol_id is not None + and stream_py.protocol_id == stream_daemon.stream_info.proto + ) + await stream_py.write(DATA) + assert (await stream_daemon.read(MAX_READ_LEN)) == DATA + + +@pytest.mark.asyncio +async def test_net_stream_read_after_remote_closed(py_to_daemon_stream_pair, p2pds): + stream_py, stream_daemon = py_to_daemon_stream_pair + await stream_daemon.write(DATA) + await stream_daemon.close() + await asyncio.sleep(0.01) + assert (await stream_py.read(MAX_READ_LEN)) == DATA + # EOF + with pytest.raises(StreamEOF): + await stream_py.read(MAX_READ_LEN) + + +@pytest.mark.asyncio +async def test_net_stream_read_after_local_reset(py_to_daemon_stream_pair, p2pds): + stream_py, _ = py_to_daemon_stream_pair + await stream_py.reset() + with pytest.raises(StreamReset): + await stream_py.read(MAX_READ_LEN) + + +@pytest.mark.parametrize("is_to_fail_daemon_stream", (True,)) +@pytest.mark.asyncio +async def test_net_stream_read_after_remote_reset(py_to_daemon_stream_pair, p2pds): + stream_py, _ = py_to_daemon_stream_pair + await asyncio.sleep(0.01) + with pytest.raises(StreamReset): + await stream_py.read(MAX_READ_LEN) + + +@pytest.mark.asyncio +async def test_net_stream_write_after_local_closed(py_to_daemon_stream_pair, p2pds): + stream_py, _ = py_to_daemon_stream_pair + await stream_py.write(DATA) + await stream_py.close() + with pytest.raises(StreamClosed): + await stream_py.write(DATA) + + +@pytest.mark.asyncio +async def test_net_stream_write_after_local_reset(py_to_daemon_stream_pair, p2pds): + stream_py, stream_daemon = py_to_daemon_stream_pair + await stream_py.reset() + with pytest.raises(StreamClosed): + await stream_py.write(DATA) + + +@pytest.mark.parametrize("is_to_fail_daemon_stream", (True,)) +@pytest.mark.asyncio +async def test_net_stream_write_after_remote_reset(py_to_daemon_stream_pair, p2pds): + stream_py, _ = py_to_daemon_stream_pair + await asyncio.sleep(0.01) + with pytest.raises(StreamClosed): + await stream_py.write(DATA) diff --git a/tests/libp2p/test_libp2p.py b/tests/libp2p/test_libp2p.py index 8090f5e..793444c 100644 --- a/tests/libp2p/test_libp2p.py +++ b/tests/libp2p/test_libp2p.py @@ -3,7 +3,7 @@ import pytest from libp2p.peer.peerinfo import info_from_p2p_addr from tests.constants import MAX_READ_LEN -from tests.utils import cleanup, set_up_nodes_by_transport_opt +from tests.utils import set_up_nodes_by_transport_opt @pytest.mark.asyncio @@ -34,7 +34,6 @@ async def test_simple_messages(): assert response == ("ack:" + message) # Success, terminate pending tasks. - await cleanup() @pytest.mark.asyncio @@ -69,7 +68,6 @@ async def test_double_response(): assert response2 == ("ack2:" + message) # Success, terminate pending tasks. - await cleanup() @pytest.mark.asyncio @@ -120,7 +118,6 @@ async def test_multiple_streams(): ) # Success, terminate pending tasks. - await cleanup() @pytest.mark.asyncio @@ -183,7 +180,6 @@ async def test_multiple_streams_same_initiator_different_protocols(): ) # Success, terminate pending tasks. - await cleanup() @pytest.mark.asyncio @@ -264,7 +260,6 @@ async def test_multiple_streams_two_initiators(): ) # Success, terminate pending tasks. - await cleanup() @pytest.mark.asyncio @@ -326,7 +321,6 @@ async def test_triangle_nodes_connection(): assert response == ("ack:" + message) # Success, terminate pending tasks. - await cleanup() @pytest.mark.asyncio @@ -353,4 +347,3 @@ async def test_host_connect(): assert addr.encapsulate(ma_node_b) in node_b.get_addrs() # Success, terminate pending tasks. - await cleanup() diff --git a/tests/libp2p/test_notify.py b/tests/libp2p/test_notify.py index e21030a..b9a8707 100644 --- a/tests/libp2p/test_notify.py +++ b/tests/libp2p/test_notify.py @@ -17,7 +17,7 @@ from libp2p.crypto.rsa import create_new_key_pair from libp2p.host.basic_host import BasicHost from libp2p.network.notifee_interface import INotifee from tests.constants import MAX_READ_LEN -from tests.utils import cleanup, perform_two_host_set_up +from tests.utils import perform_two_host_set_up ACK = "ack:" @@ -91,7 +91,6 @@ async def test_one_notifier(): assert response == expected_resp # Success, terminate pending tasks. - await cleanup() @pytest.mark.asyncio @@ -138,7 +137,6 @@ async def test_one_notifier_on_two_nodes(): assert response == expected_resp # Success, terminate pending tasks. - await cleanup() @pytest.mark.asyncio @@ -203,7 +201,6 @@ async def test_one_notifier_on_two_nodes_with_listen(): assert response == expected_resp # Success, terminate pending tasks. - await cleanup() @pytest.mark.asyncio @@ -235,7 +232,6 @@ async def test_two_notifiers(): assert response == expected_resp # Success, terminate pending tasks. - await cleanup() @pytest.mark.asyncio @@ -271,7 +267,6 @@ async def test_ten_notifiers(): assert response == expected_resp # Success, terminate pending tasks. - await cleanup() @pytest.mark.asyncio @@ -325,7 +320,6 @@ async def test_ten_notifiers_on_two_nodes(): assert response == expected_resp # Success, terminate pending tasks. - await cleanup() @pytest.mark.asyncio @@ -355,4 +349,3 @@ async def test_invalid_notifee(): assert response == expected_resp # Success, terminate pending tasks. - await cleanup() diff --git a/tests/network/__init__.py b/tests/network/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/network/conftest.py b/tests/network/conftest.py new file mode 100644 index 0000000..10f7791 --- /dev/null +++ b/tests/network/conftest.py @@ -0,0 +1,14 @@ +import asyncio + +import pytest + +from tests.factories import net_stream_pair_factory + + +@pytest.fixture +async def net_stream_pair(): + stream_0, host_0, stream_1, host_1 = await net_stream_pair_factory() + try: + yield stream_0, stream_1 + finally: + await asyncio.gather(*[host_0.close(), host_1.close()]) diff --git a/tests/network/test_net_stream.py b/tests/network/test_net_stream.py new file mode 100644 index 0000000..80bed6c --- /dev/null +++ b/tests/network/test_net_stream.py @@ -0,0 +1,122 @@ +import asyncio + +import pytest + +from libp2p.network.stream.exceptions import StreamClosed, StreamEOF, StreamReset +from tests.constants import MAX_READ_LEN + +DATA = b"data_123" + +# TODO: Move `muxed_stream` specific(currently we are using `MplexStream`) tests to its +# own file, after `generic_protocol_handler` is refactored out of `Mplex`. + + +@pytest.mark.asyncio +async def test_net_stream_read_write(net_stream_pair): + stream_0, stream_1 = net_stream_pair + assert ( + stream_0.protocol_id is not None + and stream_0.protocol_id == stream_1.protocol_id + ) + await stream_0.write(DATA) + assert (await stream_1.read(MAX_READ_LEN)) == DATA + + +@pytest.mark.asyncio +async def test_net_stream_read_until_eof(net_stream_pair): + read_bytes = bytearray() + stream_0, stream_1 = net_stream_pair + + async def read_until_eof(): + read_bytes.extend(await stream_1.read()) + + task = asyncio.ensure_future(read_until_eof()) + + expected_data = bytearray() + + # Test: `read` doesn't return before `close` is called. + await stream_0.write(DATA) + expected_data.extend(DATA) + await asyncio.sleep(0.01) + assert len(read_bytes) == 0 + # Test: `read` doesn't return before `close` is called. + await stream_0.write(DATA) + expected_data.extend(DATA) + await asyncio.sleep(0.01) + assert len(read_bytes) == 0 + + # Test: Close the stream, `read` returns, and receive previous sent data. + await stream_0.close() + await asyncio.sleep(0.01) + assert read_bytes == expected_data + + task.cancel() + + +@pytest.mark.asyncio +async def test_net_stream_read_after_remote_closed(net_stream_pair): + stream_0, stream_1 = net_stream_pair + assert not stream_1.muxed_stream.event_remote_closed.is_set() + await stream_0.write(DATA) + await stream_0.close() + await asyncio.sleep(0.01) + assert stream_1.muxed_stream.event_remote_closed.is_set() + assert (await stream_1.read(MAX_READ_LEN)) == DATA + with pytest.raises(StreamEOF): + await stream_1.read(MAX_READ_LEN) + + +@pytest.mark.asyncio +async def test_net_stream_read_after_local_reset(net_stream_pair): + stream_0, stream_1 = net_stream_pair + await stream_0.reset() + with pytest.raises(StreamReset): + await stream_0.read(MAX_READ_LEN) + + +@pytest.mark.asyncio +async def test_net_stream_read_after_remote_reset(net_stream_pair): + stream_0, stream_1 = net_stream_pair + await stream_0.write(DATA) + await stream_0.reset() + # Sleep to let `stream_1` receive the message. + await asyncio.sleep(0.01) + with pytest.raises(StreamReset): + await stream_1.read(MAX_READ_LEN) + + +@pytest.mark.asyncio +async def test_net_stream_read_after_remote_closed_and_reset(net_stream_pair): + stream_0, stream_1 = net_stream_pair + await stream_0.write(DATA) + await stream_0.close() + await stream_0.reset() + # Sleep to let `stream_1` receive the message. + await asyncio.sleep(0.01) + assert (await stream_1.read(MAX_READ_LEN)) == DATA + + +@pytest.mark.asyncio +async def test_net_stream_write_after_local_closed(net_stream_pair): + stream_0, stream_1 = net_stream_pair + await stream_0.write(DATA) + await stream_0.close() + with pytest.raises(StreamClosed): + await stream_0.write(DATA) + + +@pytest.mark.asyncio +async def test_net_stream_write_after_local_reset(net_stream_pair): + stream_0, stream_1 = net_stream_pair + await stream_0.reset() + with pytest.raises(StreamClosed): + await stream_0.write(DATA) + + +@pytest.mark.asyncio +async def test_net_stream_write_after_remote_reset(net_stream_pair): + stream_0, stream_1 = net_stream_pair + await stream_1.reset() + await asyncio.sleep(0.01) + with pytest.raises(StreamClosed): + await stream_0.write(DATA) diff --git a/tests/protocol_muxer/test_protocol_muxer.py b/tests/protocol_muxer/test_protocol_muxer.py index 7830aaa..d7523ac 100644 --- a/tests/protocol_muxer/test_protocol_muxer.py +++ b/tests/protocol_muxer/test_protocol_muxer.py @@ -1,7 +1,7 @@ import pytest from libp2p.protocol_muxer.exceptions import MultiselectClientError -from tests.utils import cleanup, echo_stream_handler, set_up_nodes_by_transport_opt +from tests.utils import echo_stream_handler, set_up_nodes_by_transport_opt # TODO: Add tests for multiple streams being opened on different # protocols through the same connection @@ -35,7 +35,6 @@ async def perform_simple_test( assert expected_selected_protocol == stream.get_protocol() # Success, terminate pending tasks. - await cleanup() @pytest.mark.asyncio @@ -52,7 +51,6 @@ async def test_single_protocol_fails(): await perform_simple_test("", ["/echo/1.0.0"], ["/potato/1.0.0"]) # Cleanup not reached on error - await cleanup() @pytest.mark.asyncio @@ -83,4 +81,3 @@ async def test_multiple_protocol_fails(): await perform_simple_test("", protocols_for_client, protocols_for_listener) # Cleanup not reached on error - await cleanup() diff --git a/tests/pubsub/floodsub_integration_test_settings.py b/tests/pubsub/floodsub_integration_test_settings.py index d96fc2b..0a533e2 100644 --- a/tests/pubsub/floodsub_integration_test_settings.py +++ b/tests/pubsub/floodsub_integration_test_settings.py @@ -4,7 +4,7 @@ import pytest from tests.configs import LISTEN_MADDR from tests.factories import PubsubFactory -from tests.utils import cleanup, connect +from tests.utils import connect from .configs import FLOODSUB_PROTOCOL_ID @@ -258,4 +258,3 @@ async def perform_test_from_obj(obj, router_factory): assert node_map[origin_node_id].get_id().to_bytes() == msg.from_id # Success, terminate pending tasks. - await cleanup() diff --git a/tests/pubsub/test_dummyaccount_demo.py b/tests/pubsub/test_dummyaccount_demo.py index b365134..edc2f51 100644 --- a/tests/pubsub/test_dummyaccount_demo.py +++ b/tests/pubsub/test_dummyaccount_demo.py @@ -3,7 +3,7 @@ from threading import Thread import pytest -from tests.utils import cleanup, connect +from tests.utils import connect from .dummy_account_node import DummyAccountNode @@ -64,7 +64,6 @@ async def perform_test(num_nodes, adjacency_map, action_func, assertion_func): assertion_func(dummy_node) # Success, terminate pending tasks. - await cleanup() @pytest.mark.asyncio diff --git a/tests/pubsub/test_floodsub.py b/tests/pubsub/test_floodsub.py index 7e079d1..c6d28bf 100644 --- a/tests/pubsub/test_floodsub.py +++ b/tests/pubsub/test_floodsub.py @@ -4,7 +4,7 @@ import pytest from libp2p.peer.id import ID from tests.factories import FloodsubFactory -from tests.utils import cleanup, connect +from tests.utils import connect from .floodsub_integration_test_settings import ( floodsub_protocol_pytest_params, @@ -36,7 +36,6 @@ async def test_simple_two_nodes(pubsubs_fsub): assert res_b.topicIDs == [topic] # Success, terminate pending tasks. - await cleanup() # Initialize Pubsub with a cache_size of 4 @@ -82,7 +81,6 @@ async def test_lru_cache_two_nodes(pubsubs_fsub, monkeypatch): assert sub_b.empty() # Success, terminate pending tasks. - await cleanup() @pytest.mark.parametrize("test_case_obj", floodsub_protocol_pytest_params) diff --git a/tests/pubsub/test_gossipsub.py b/tests/pubsub/test_gossipsub.py index 7a0efc2..95775be 100644 --- a/tests/pubsub/test_gossipsub.py +++ b/tests/pubsub/test_gossipsub.py @@ -3,7 +3,7 @@ import random import pytest -from tests.utils import cleanup, connect +from tests.utils import connect from .configs import GossipsubParams from .utils import dense_connect, one_to_all_connect @@ -61,8 +61,6 @@ async def test_join(num_hosts, hosts, pubsubs_gsub): assert hosts[i].get_id() not in gossipsubs[central_node_index].mesh[topic] assert topic not in gossipsubs[i].mesh - await cleanup() - @pytest.mark.parametrize("num_hosts", (1,)) @pytest.mark.asyncio @@ -81,8 +79,6 @@ async def test_leave(pubsubs_gsub): # Test re-leave await gossipsub.leave(topic) - await cleanup() - @pytest.mark.parametrize("num_hosts", (2,)) @pytest.mark.asyncio @@ -133,8 +129,6 @@ async def test_handle_graft(pubsubs_gsub, hosts, event_loop, monkeypatch): # Check that bob is now alice's mesh peer assert id_bob in gossipsubs[index_alice].mesh[topic] - await cleanup() - @pytest.mark.parametrize( "num_hosts, gossipsub_params", ((2, GossipsubParams(heartbeat_interval=3)),) @@ -174,8 +168,6 @@ async def test_handle_prune(pubsubs_gsub, hosts): assert id_alice not in gossipsubs[index_bob].mesh[topic] assert id_bob in gossipsubs[index_alice].mesh[topic] - await cleanup() - @pytest.mark.parametrize("num_hosts", (10,)) @pytest.mark.asyncio @@ -210,7 +202,6 @@ async def test_dense(num_hosts, pubsubs_gsub, hosts): for queue in queues: msg = await queue.get() assert msg.data == msg_content - await cleanup() @pytest.mark.parametrize("num_hosts", (10,)) @@ -268,8 +259,6 @@ async def test_fanout(hosts, pubsubs_gsub): msg = await queue.get() assert msg.data == msg_content - await cleanup() - @pytest.mark.parametrize("num_hosts", (10,)) @pytest.mark.asyncio @@ -340,8 +329,6 @@ async def test_fanout_maintenance(hosts, pubsubs_gsub): msg = await queue.get() assert msg.data == msg_content - await cleanup() - @pytest.mark.parametrize( "num_hosts, gossipsub_params", @@ -380,5 +367,3 @@ async def test_gossip_propagation(hosts, pubsubs_gsub): # should be able to read message msg = await queue_1.get() assert msg.data == msg_content - - await cleanup() diff --git a/tests/security/test_security_multistream.py b/tests/security/test_security_multistream.py index 1d87e7b..c8e83c1 100644 --- a/tests/security/test_security_multistream.py +++ b/tests/security/test_security_multistream.py @@ -6,7 +6,7 @@ from libp2p import new_node from libp2p.crypto.rsa import create_new_key_pair from libp2p.security.insecure.transport import InsecureSession, InsecureTransport from tests.configs import LISTEN_MADDR -from tests.utils import cleanup, connect +from tests.utils import connect # TODO: Add tests for multiple streams being opened on different # protocols through the same connection @@ -57,7 +57,6 @@ async def perform_simple_test( assertion_func(node2_conn.secured_conn) # Success, terminate pending tasks. - await cleanup() @pytest.mark.asyncio diff --git a/tests/utils.py b/tests/utils.py index a26ebc5..e9d6c09 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,6 +1,3 @@ -import asyncio -from contextlib import suppress - import multiaddr from libp2p import new_node @@ -17,17 +14,6 @@ async def connect(node1, node2): await node1.connect(info) -async def cleanup(): - pending = asyncio.all_tasks() - for task in pending: - task.cancel() - - # Now we should await task to execute it's cancellation. - # Cancelled task raises asyncio.CancelledError that we can suppress: - with suppress(asyncio.CancelledError): - await task - - async def set_up_nodes_by_transport_opt(transport_opt_list): nodes_list = [] for transport_opt in transport_opt_list: