diff --git a/tests/factories.py b/tests/factories.py index dcc9a85..cecc656 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -16,6 +16,7 @@ from libp2p.security.base_transport import BaseSecureTransport from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport import libp2p.security.secio.transport as secio from libp2p.stream_muxer.mplex.mplex import MPLEX_PROTOCOL_ID, Mplex +from libp2p.stream_muxer.mplex.mplex_stream import MplexStream from libp2p.transport.typing import TMuxerOptions from libp2p.typing import TProtocol from tests.configs import LISTEN_MADDR @@ -149,7 +150,7 @@ async def swarm_conn_pair_factory( return conn_0, swarms[0], conn_1, swarms[1] -async def mplex_conn_pair_factory(is_secure): +async def mplex_conn_pair_factory(is_secure: bool) -> Tuple[Mplex, Swarm, Mplex, Swarm]: muxer_opt = {MPLEX_PROTOCOL_ID: Mplex} conn_0, swarm_0, conn_1, swarm_1 = await swarm_conn_pair_factory( is_secure, muxer_opt=muxer_opt @@ -157,6 +158,22 @@ async def mplex_conn_pair_factory(is_secure): return conn_0.conn, swarm_0, conn_1.conn, swarm_1 +async def mplex_stream_pair_factory( + is_secure: bool +) -> Tuple[MplexStream, Swarm, MplexStream, Swarm]: + mplex_conn_0, swarm_0, mplex_conn_1, swarm_1 = await mplex_conn_pair_factory( + is_secure + ) + stream_0 = await mplex_conn_0.open_stream() + await asyncio.sleep(0.01) + stream_1: MplexStream + async with mplex_conn_1.streams_lock: + if len(mplex_conn_1.streams) != 1: + raise Exception("Mplex should not have any stream upon connection") + stream_1 = tuple(mplex_conn_1.streams.values())[0] + return stream_0, swarm_0, stream_1, swarm_1 + + async def net_stream_pair_factory( is_secure: bool ) -> Tuple[INetStream, BasicHost, INetStream, BasicHost]: diff --git a/tests/network/test_net_stream.py b/tests/network/test_net_stream.py index c748837..9229069 100644 --- a/tests/network/test_net_stream.py +++ b/tests/network/test_net_stream.py @@ -53,11 +53,9 @@ async def test_net_stream_read_until_eof(net_stream_pair): @pytest.mark.asyncio async def test_net_stream_read_after_remote_closed(net_stream_pair): stream_0, stream_1 = net_stream_pair - assert not stream_1.muxed_stream.event_remote_closed.is_set() await stream_0.write(DATA) await stream_0.close() await asyncio.sleep(0.01) - assert stream_1.muxed_stream.event_remote_closed.is_set() assert (await stream_1.read(MAX_READ_LEN)) == DATA with pytest.raises(StreamEOF): await stream_1.read(MAX_READ_LEN) diff --git a/tests/stream_muxer/conftest.py b/tests/stream_muxer/conftest.py index b05a016..b1d6c11 100644 --- a/tests/stream_muxer/conftest.py +++ b/tests/stream_muxer/conftest.py @@ -2,7 +2,7 @@ import asyncio import pytest -from tests.factories import mplex_conn_pair_factory +from tests.factories import mplex_conn_pair_factory, mplex_stream_pair_factory @pytest.fixture @@ -16,3 +16,14 @@ async def mplex_conn_pair(is_host_secure): yield mplex_conn_0, mplex_conn_1 finally: await asyncio.gather(*[swarm_0.close(), swarm_1.close()]) + + +@pytest.fixture +async def mplex_stream_pair(is_host_secure): + mplex_stream_0, swarm_0, mplex_stream_1, swarm_1 = await mplex_stream_pair_factory( + is_host_secure + ) + try: + yield mplex_stream_0, mplex_stream_1 + finally: + await asyncio.gather(*[swarm_0.close(), swarm_1.close()]) diff --git a/tests/stream_muxer/test_mplex_stream.py b/tests/stream_muxer/test_mplex_stream.py new file mode 100644 index 0000000..e2bcb24 --- /dev/null +++ b/tests/stream_muxer/test_mplex_stream.py @@ -0,0 +1,182 @@ +import asyncio + +import pytest + +from libp2p.stream_muxer.mplex.exceptions import ( + MplexStreamClosed, + MplexStreamEOF, + MplexStreamReset, +) +from tests.constants import MAX_READ_LEN + +DATA = b"data_123" + + +@pytest.mark.asyncio +async def test_mplex_stream_read_write(mplex_stream_pair): + stream_0, stream_1 = mplex_stream_pair + await stream_0.write(DATA) + assert (await stream_1.read(MAX_READ_LEN)) == DATA + + +@pytest.mark.asyncio +async def test_mplex_stream_pair_read_until_eof(mplex_stream_pair): + read_bytes = bytearray() + stream_0, stream_1 = mplex_stream_pair + + async def read_until_eof(): + read_bytes.extend(await stream_1.read()) + + task = asyncio.ensure_future(read_until_eof()) + + expected_data = bytearray() + + # Test: `read` doesn't return before `close` is called. + await stream_0.write(DATA) + expected_data.extend(DATA) + await asyncio.sleep(0.01) + assert len(read_bytes) == 0 + # Test: `read` doesn't return before `close` is called. + await stream_0.write(DATA) + expected_data.extend(DATA) + await asyncio.sleep(0.01) + assert len(read_bytes) == 0 + + # Test: Close the stream, `read` returns, and receive previous sent data. + await stream_0.close() + await asyncio.sleep(0.01) + assert read_bytes == expected_data + + task.cancel() + + +@pytest.mark.asyncio +async def test_mplex_stream_read_after_remote_closed(mplex_stream_pair): + stream_0, stream_1 = mplex_stream_pair + assert not stream_1.event_remote_closed.is_set() + await stream_0.write(DATA) + await stream_0.close() + await asyncio.sleep(0.01) + assert stream_1.event_remote_closed.is_set() + assert (await stream_1.read(MAX_READ_LEN)) == DATA + with pytest.raises(MplexStreamEOF): + await stream_1.read(MAX_READ_LEN) + + +@pytest.mark.asyncio +async def test_mplex_stream_read_after_local_reset(mplex_stream_pair): + stream_0, stream_1 = mplex_stream_pair + await stream_0.reset() + with pytest.raises(MplexStreamReset): + await stream_0.read(MAX_READ_LEN) + + +@pytest.mark.asyncio +async def test_mplex_stream_read_after_remote_reset(mplex_stream_pair): + stream_0, stream_1 = mplex_stream_pair + await stream_0.write(DATA) + await stream_0.reset() + # Sleep to let `stream_1` receive the message. + await asyncio.sleep(0.01) + with pytest.raises(MplexStreamReset): + await stream_1.read(MAX_READ_LEN) + + +@pytest.mark.asyncio +async def test_mplex_stream_read_after_remote_closed_and_reset(mplex_stream_pair): + stream_0, stream_1 = mplex_stream_pair + await stream_0.write(DATA) + await stream_0.close() + await stream_0.reset() + # Sleep to let `stream_1` receive the message. + await asyncio.sleep(0.01) + assert (await stream_1.read(MAX_READ_LEN)) == DATA + + +@pytest.mark.asyncio +async def test_mplex_stream_write_after_local_closed(mplex_stream_pair): + stream_0, stream_1 = mplex_stream_pair + await stream_0.write(DATA) + await stream_0.close() + with pytest.raises(MplexStreamClosed): + await stream_0.write(DATA) + + +@pytest.mark.asyncio +async def test_mplex_stream_write_after_local_reset(mplex_stream_pair): + stream_0, stream_1 = mplex_stream_pair + await stream_0.reset() + with pytest.raises(MplexStreamClosed): + await stream_0.write(DATA) + + +@pytest.mark.asyncio +async def test_mplex_stream_write_after_remote_reset(mplex_stream_pair): + stream_0, stream_1 = mplex_stream_pair + await stream_1.reset() + await asyncio.sleep(0.01) + with pytest.raises(MplexStreamClosed): + await stream_0.write(DATA) + + +@pytest.mark.asyncio +async def test_mplex_stream_both_close(mplex_stream_pair): + stream_0, stream_1 = mplex_stream_pair + # Flags are not set initially. + assert not stream_0.event_local_closed.is_set() + assert not stream_1.event_local_closed.is_set() + assert not stream_0.event_remote_closed.is_set() + assert not stream_1.event_remote_closed.is_set() + # Streams are present in their `mplex_conn`. + assert stream_0 in stream_0.muxed_conn.streams.values() + assert stream_1 in stream_1.muxed_conn.streams.values() + + # Test: Close one side. + await stream_0.close() + await asyncio.sleep(0.01) + + assert stream_0.event_local_closed.is_set() + assert not stream_1.event_local_closed.is_set() + assert not stream_0.event_remote_closed.is_set() + assert stream_1.event_remote_closed.is_set() + # Streams are still present in their `mplex_conn`. + assert stream_0 in stream_0.muxed_conn.streams.values() + assert stream_1 in stream_1.muxed_conn.streams.values() + + # Test: Close the other side. + await stream_1.close() + await asyncio.sleep(0.01) + # Both sides are closed. + assert stream_0.event_local_closed.is_set() + assert stream_1.event_local_closed.is_set() + assert stream_0.event_remote_closed.is_set() + assert stream_1.event_remote_closed.is_set() + # Streams are removed from their `mplex_conn`. + assert stream_0 not in stream_0.muxed_conn.streams.values() + assert stream_1 not in stream_1.muxed_conn.streams.values() + + # Test: Reset after both close. + await stream_0.reset() + + +@pytest.mark.asyncio +async def test_mplex_stream_reset(mplex_stream_pair): + stream_0, stream_1 = mplex_stream_pair + await stream_0.reset() + await asyncio.sleep(0.01) + + # Both sides are closed. + assert stream_0.event_local_closed.is_set() + assert stream_1.event_local_closed.is_set() + assert stream_0.event_remote_closed.is_set() + assert stream_1.event_remote_closed.is_set() + # Streams are removed from their `mplex_conn`. + assert stream_0 not in stream_0.muxed_conn.streams.values() + assert stream_1 not in stream_1.muxed_conn.streams.values() + + # `close` should do nothing. + await stream_0.close() + await stream_1.close() + # `reset` should do nothing as well. + await stream_0.reset() + await stream_1.reset()