diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index b3b45c0..5b1df77 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -1,5 +1,4 @@ import logging -import math from typing import Dict, Optional, Tuple import trio @@ -24,6 +23,8 @@ from .exceptions import MplexUnavailable from .mplex_stream import MplexStream MPLEX_PROTOCOL_ID = TProtocol("/mplex/6.7.0") +# Ref: https://github.com/libp2p/go-mplex/blob/master/multiplex.go#L115 +MPLEX_MESSAGE_CHANNEL_SIZE = 8 logger = logging.getLogger("libp2p.stream_muxer.mplex.mplex") @@ -109,9 +110,9 @@ class Mplex(IMuxedConn): return next_id async def _initialize_stream(self, stream_id: StreamID, name: str) -> MplexStream: - # Use an unbounded buffer, to avoid `handle_incoming` being blocked when doing - # `send_channel.send`. - send_channel, receive_channel = trio.open_memory_channel[bytes](math.inf) + send_channel, receive_channel = trio.open_memory_channel[bytes]( + MPLEX_MESSAGE_CHANNEL_SIZE + ) stream = MplexStream(name, stream_id, self, receive_channel) async with self.streams_lock: self.streams[stream_id] = stream @@ -145,7 +146,7 @@ class Mplex(IMuxedConn): """ sends a message over the connection. - :param header: header to use + :param flag: header to use :param data: data to send in the message :param stream_id: stream the message is in """ @@ -270,9 +271,15 @@ class Mplex(IMuxedConn): # TODO: Warn "Received data from remote after stream was closed by them. (len = %d)" # noqa: E501 return try: - await send_channel.send(message) + send_channel.send_nowait(message) except (trio.BrokenResourceError, trio.ClosedResourceError): raise MplexUnavailable + except trio.WouldBlock: + # `send_channel` is full, reset this stream. + logger.warning( + "message channel of stream %s is full: stream is reset", stream_id + ) + await stream.reset() async def _handle_close(self, stream_id: StreamID) -> None: async with self.streams_lock: diff --git a/tests/stream_muxer/test_mplex_stream.py b/tests/stream_muxer/test_mplex_stream.py index eeb7653..3bc8bc1 100644 --- a/tests/stream_muxer/test_mplex_stream.py +++ b/tests/stream_muxer/test_mplex_stream.py @@ -7,6 +7,7 @@ from libp2p.stream_muxer.mplex.exceptions import ( MplexStreamEOF, MplexStreamReset, ) +from libp2p.stream_muxer.mplex.mplex import MPLEX_MESSAGE_CHANNEL_SIZE from libp2p.tools.constants import MAX_READ_LEN DATA = b"data_123" @@ -19,6 +20,28 @@ async def test_mplex_stream_read_write(mplex_stream_pair): assert (await stream_1.read(MAX_READ_LEN)) == DATA +@pytest.mark.trio +async def test_mplex_stream_full_buffer(mplex_stream_pair): + stream_0, stream_1 = mplex_stream_pair + # Test: The message channel is of size `MPLEX_MESSAGE_CHANNEL_SIZE`. + # It should be fine to read even there are already `MPLEX_MESSAGE_CHANNEL_SIZE` + # messages arriving. + for _ in range(MPLEX_MESSAGE_CHANNEL_SIZE): + await stream_0.write(DATA) + await wait_all_tasks_blocked() + # Sanity check + assert MAX_READ_LEN >= MPLEX_MESSAGE_CHANNEL_SIZE * len(DATA) + assert (await stream_1.read(MAX_READ_LEN)) == MPLEX_MESSAGE_CHANNEL_SIZE * DATA + + # Test: Read after `MPLEX_MESSAGE_CHANNEL_SIZE + 1` messages has arrived, which + # exceeds the channel size. The stream should have been reset. + for _ in range(MPLEX_MESSAGE_CHANNEL_SIZE + 1): + await stream_0.write(DATA) + await wait_all_tasks_blocked() + with pytest.raises(MplexStreamReset): + await stream_1.read(MAX_READ_LEN) + + @pytest.mark.trio async def test_mplex_stream_pair_read_until_eof(mplex_stream_pair): read_bytes = bytearray()