diff --git a/tests/stream_muxer/test_mplex_stream.py b/tests/stream_muxer/test_mplex_stream.py index 181daa0..eeb7653 100644 --- a/tests/stream_muxer/test_mplex_stream.py +++ b/tests/stream_muxer/test_mplex_stream.py @@ -7,26 +7,16 @@ from libp2p.stream_muxer.mplex.exceptions import ( MplexStreamEOF, MplexStreamReset, ) -from libp2p.tools.constants import LISTEN_MADDR, MAX_READ_LEN -from libp2p.tools.factories import SwarmFactory -from libp2p.tools.utils import connect_swarm +from libp2p.tools.constants import MAX_READ_LEN DATA = b"data_123" @pytest.mark.trio -async def test_mplex_stream_read_write(): - async with SwarmFactory.create_batch_and_listen(False, 2) as swarms: - await swarms[0].listen(LISTEN_MADDR) - await swarms[1].listen(LISTEN_MADDR) - await connect_swarm(swarms[0], swarms[1]) - conn_0 = swarms[0].connections[swarms[1].get_peer_id()] - conn_1 = swarms[1].connections[swarms[0].get_peer_id()] - stream_0 = await conn_0.muxed_conn.open_stream() - await trio.sleep(1) - stream_1 = tuple(conn_1.muxed_conn.streams.values())[0] - await stream_0.write(DATA) - assert (await stream_1.read(MAX_READ_LEN)) == DATA +async def test_mplex_stream_read_write(mplex_stream_pair): + stream_0, stream_1 = mplex_stream_pair + await stream_0.write(DATA) + assert (await stream_1.read(MAX_READ_LEN)) == DATA @pytest.mark.trio