Mplex: change message channel size to 8
To avoid infinity sized channel, and to conform to the go implementation.
This commit is contained in:
parent
64c9c48dac
commit
1fff6ad6b4
|
@ -1,5 +1,4 @@
|
||||||
import logging
|
import logging
|
||||||
import math
|
|
||||||
from typing import Dict, Optional, Tuple
|
from typing import Dict, Optional, Tuple
|
||||||
|
|
||||||
import trio
|
import trio
|
||||||
|
@ -24,6 +23,8 @@ from .exceptions import MplexUnavailable
|
||||||
from .mplex_stream import MplexStream
|
from .mplex_stream import MplexStream
|
||||||
|
|
||||||
MPLEX_PROTOCOL_ID = TProtocol("/mplex/6.7.0")
|
MPLEX_PROTOCOL_ID = TProtocol("/mplex/6.7.0")
|
||||||
|
# Ref: https://github.com/libp2p/go-mplex/blob/master/multiplex.go#L115
|
||||||
|
MPLEX_MESSAGE_CHANNEL_SIZE = 8
|
||||||
|
|
||||||
logger = logging.getLogger("libp2p.stream_muxer.mplex.mplex")
|
logger = logging.getLogger("libp2p.stream_muxer.mplex.mplex")
|
||||||
|
|
||||||
|
@ -109,9 +110,9 @@ class Mplex(IMuxedConn):
|
||||||
return next_id
|
return next_id
|
||||||
|
|
||||||
async def _initialize_stream(self, stream_id: StreamID, name: str) -> MplexStream:
|
async def _initialize_stream(self, stream_id: StreamID, name: str) -> MplexStream:
|
||||||
# Use an unbounded buffer, to avoid `handle_incoming` being blocked when doing
|
send_channel, receive_channel = trio.open_memory_channel[bytes](
|
||||||
# `send_channel.send`.
|
MPLEX_MESSAGE_CHANNEL_SIZE
|
||||||
send_channel, receive_channel = trio.open_memory_channel[bytes](math.inf)
|
)
|
||||||
stream = MplexStream(name, stream_id, self, receive_channel)
|
stream = MplexStream(name, stream_id, self, receive_channel)
|
||||||
async with self.streams_lock:
|
async with self.streams_lock:
|
||||||
self.streams[stream_id] = stream
|
self.streams[stream_id] = stream
|
||||||
|
@ -145,7 +146,7 @@ class Mplex(IMuxedConn):
|
||||||
"""
|
"""
|
||||||
sends a message over the connection.
|
sends a message over the connection.
|
||||||
|
|
||||||
:param header: header to use
|
:param flag: header to use
|
||||||
:param data: data to send in the message
|
:param data: data to send in the message
|
||||||
:param stream_id: stream the message is in
|
:param stream_id: stream the message is in
|
||||||
"""
|
"""
|
||||||
|
@ -270,9 +271,15 @@ class Mplex(IMuxedConn):
|
||||||
# TODO: Warn "Received data from remote after stream was closed by them. (len = %d)" # noqa: E501
|
# TODO: Warn "Received data from remote after stream was closed by them. (len = %d)" # noqa: E501
|
||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
await send_channel.send(message)
|
send_channel.send_nowait(message)
|
||||||
except (trio.BrokenResourceError, trio.ClosedResourceError):
|
except (trio.BrokenResourceError, trio.ClosedResourceError):
|
||||||
raise MplexUnavailable
|
raise MplexUnavailable
|
||||||
|
except trio.WouldBlock:
|
||||||
|
# `send_channel` is full, reset this stream.
|
||||||
|
logger.warning(
|
||||||
|
"message channel of stream %s is full: stream is reset", stream_id
|
||||||
|
)
|
||||||
|
await stream.reset()
|
||||||
|
|
||||||
async def _handle_close(self, stream_id: StreamID) -> None:
|
async def _handle_close(self, stream_id: StreamID) -> None:
|
||||||
async with self.streams_lock:
|
async with self.streams_lock:
|
||||||
|
|
|
@ -7,6 +7,7 @@ from libp2p.stream_muxer.mplex.exceptions import (
|
||||||
MplexStreamEOF,
|
MplexStreamEOF,
|
||||||
MplexStreamReset,
|
MplexStreamReset,
|
||||||
)
|
)
|
||||||
|
from libp2p.stream_muxer.mplex.mplex import MPLEX_MESSAGE_CHANNEL_SIZE
|
||||||
from libp2p.tools.constants import MAX_READ_LEN
|
from libp2p.tools.constants import MAX_READ_LEN
|
||||||
|
|
||||||
DATA = b"data_123"
|
DATA = b"data_123"
|
||||||
|
@ -19,6 +20,28 @@ async def test_mplex_stream_read_write(mplex_stream_pair):
|
||||||
assert (await stream_1.read(MAX_READ_LEN)) == DATA
|
assert (await stream_1.read(MAX_READ_LEN)) == DATA
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_mplex_stream_full_buffer(mplex_stream_pair):
|
||||||
|
stream_0, stream_1 = mplex_stream_pair
|
||||||
|
# Test: The message channel is of size `MPLEX_MESSAGE_CHANNEL_SIZE`.
|
||||||
|
# It should be fine to read even there are already `MPLEX_MESSAGE_CHANNEL_SIZE`
|
||||||
|
# messages arriving.
|
||||||
|
for _ in range(MPLEX_MESSAGE_CHANNEL_SIZE):
|
||||||
|
await stream_0.write(DATA)
|
||||||
|
await wait_all_tasks_blocked()
|
||||||
|
# Sanity check
|
||||||
|
assert MAX_READ_LEN >= MPLEX_MESSAGE_CHANNEL_SIZE * len(DATA)
|
||||||
|
assert (await stream_1.read(MAX_READ_LEN)) == MPLEX_MESSAGE_CHANNEL_SIZE * DATA
|
||||||
|
|
||||||
|
# Test: Read after `MPLEX_MESSAGE_CHANNEL_SIZE + 1` messages has arrived, which
|
||||||
|
# exceeds the channel size. The stream should have been reset.
|
||||||
|
for _ in range(MPLEX_MESSAGE_CHANNEL_SIZE + 1):
|
||||||
|
await stream_0.write(DATA)
|
||||||
|
await wait_all_tasks_blocked()
|
||||||
|
with pytest.raises(MplexStreamReset):
|
||||||
|
await stream_1.read(MAX_READ_LEN)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.trio
|
@pytest.mark.trio
|
||||||
async def test_mplex_stream_pair_read_until_eof(mplex_stream_pair):
|
async def test_mplex_stream_pair_read_until_eof(mplex_stream_pair):
|
||||||
read_bytes = bytearray()
|
read_bytes = bytearray()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user