Modify the behavior of MplexStream.read
This commit is contained in:
parent
2485a00e24
commit
9cb6ec1c48
|
@ -1,5 +1,5 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import Dict, Tuple
|
from typing import Dict, Tuple, Optional
|
||||||
|
|
||||||
from multiaddr import Multiaddr
|
from multiaddr import Multiaddr
|
||||||
|
|
||||||
|
@ -10,6 +10,7 @@ from libp2p.security.secure_conn_interface import ISecureConn
|
||||||
from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream
|
from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream
|
||||||
|
|
||||||
from .constants import HeaderTags
|
from .constants import HeaderTags
|
||||||
|
from .exceptions import StreamNotFound
|
||||||
from .mplex_stream import MplexStream
|
from .mplex_stream import MplexStream
|
||||||
from .utils import decode_uvarint_from_stream, encode_uvarint
|
from .utils import decode_uvarint_from_stream, encode_uvarint
|
||||||
|
|
||||||
|
@ -23,6 +24,9 @@ class Mplex(IMuxedConn):
|
||||||
raw_conn: IRawConnection
|
raw_conn: IRawConnection
|
||||||
initiator: bool
|
initiator: bool
|
||||||
peer_id: ID
|
peer_id: ID
|
||||||
|
# TODO: `dataIn` in go implementation. Should be size of 8.
|
||||||
|
# TODO: Also, `dataIn` is closed indicating EOF in Go. We don't have similar strategies
|
||||||
|
# to let the `MplexStream`s know that EOF arrived (#235).
|
||||||
buffers: Dict[int, "asyncio.Queue[bytes]"]
|
buffers: Dict[int, "asyncio.Queue[bytes]"]
|
||||||
stream_queue: "asyncio.Queue[int]"
|
stream_queue: "asyncio.Queue[int]"
|
||||||
|
|
||||||
|
@ -75,17 +79,19 @@ class Mplex(IMuxedConn):
|
||||||
:param stream_id: stream id of stream to read from
|
:param stream_id: stream id of stream to read from
|
||||||
:return: message read
|
:return: message read
|
||||||
"""
|
"""
|
||||||
# TODO: propagate up timeout exception and catch
|
if stream_id not in self.buffers:
|
||||||
# TODO: pass down timeout from user and use that
|
raise StreamNotFound(f"stream {stream_id} is not found")
|
||||||
if stream_id in self.buffers:
|
return await self.buffers[stream_id].get()
|
||||||
try:
|
|
||||||
data = await asyncio.wait_for(self.buffers[stream_id].get(), timeout=8)
|
|
||||||
return data
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Stream not created yet
|
async def read_buffer_nonblocking(self, stream_id: int) -> Optional[bytes]:
|
||||||
return None
|
"""
|
||||||
|
Read a message from `stream_id`'s buffer, non-blockingly.
|
||||||
|
"""
|
||||||
|
if stream_id not in self.buffers:
|
||||||
|
raise StreamNotFound(f"stream {stream_id} is not found")
|
||||||
|
if self.buffers[stream_id].empty():
|
||||||
|
return None
|
||||||
|
return await self.buffers[stream_id].get()
|
||||||
|
|
||||||
async def open_stream(
|
async def open_stream(
|
||||||
self, protocol_id: str, multi_addr: Multiaddr
|
self, protocol_id: str, multi_addr: Multiaddr
|
||||||
|
@ -170,6 +176,7 @@ class Mplex(IMuxedConn):
|
||||||
:return: stream_id, flag, message contents
|
:return: stream_id, flag, message contents
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# FIXME: No timeout is used in Go implementation.
|
||||||
# Timeout is set to a relatively small value to alleviate wait time to exit
|
# Timeout is set to a relatively small value to alleviate wait time to exit
|
||||||
# loop in handle_incoming
|
# loop in handle_incoming
|
||||||
timeout = 0.1
|
timeout = 0.1
|
||||||
|
|
|
@ -37,33 +37,31 @@ class MplexStream(IMuxedStream):
|
||||||
self.local_closed = False
|
self.local_closed = False
|
||||||
self.remote_closed = False
|
self.remote_closed = False
|
||||||
self.stream_lock = asyncio.Lock()
|
self.stream_lock = asyncio.Lock()
|
||||||
self._buf = None
|
self._buf = b""
|
||||||
|
|
||||||
async def read(self, n: int = -1) -> bytes:
|
async def read(self, n) -> bytes:
|
||||||
"""
|
"""
|
||||||
read messages associated with stream from buffer til end of file
|
Read up to n bytes. Read possibly returns fewer than `n` bytes,
|
||||||
|
if there are not enough bytes in the Mplex buffer.
|
||||||
:param n: number of bytes to read
|
:param n: number of bytes to read
|
||||||
:return: bytes of input
|
:return: bytes actually read
|
||||||
"""
|
"""
|
||||||
if n == -1:
|
# If the buffer is empty at first, blocking wait for data.
|
||||||
return await self.mplex_conn.read_buffer(self.stream_id)
|
if len(self._buf) == 0:
|
||||||
return await self.read_bytes(n)
|
|
||||||
|
|
||||||
async def read_bytes(self, n: int) -> bytes:
|
|
||||||
if self._buf is None:
|
|
||||||
self._buf = await self.mplex_conn.read_buffer(self.stream_id)
|
self._buf = await self.mplex_conn.read_buffer(self.stream_id)
|
||||||
n_read = 0
|
# Here, `self._buf` should never be `None`.
|
||||||
bytes_buf = BytesIO()
|
if self._buf is None or len(self._buf) == 0:
|
||||||
while self._buf is not None and n_read < n:
|
raise Exception("start to `read_buffer_nonblocking` only when there are bytes read.")
|
||||||
n_to_read = min(n - n_read, len(self._buf))
|
|
||||||
bytes_buf.write(self._buf[:n_to_read])
|
while len(self._buf) < n:
|
||||||
if n_to_read == n - n_read:
|
new_bytes = await self.mplex_conn.read_buffer_nonblocking(self.stream_id)
|
||||||
self._buf = self._buf[n_to_read:]
|
if new_bytes is None:
|
||||||
else:
|
# Nothing to read in the `MplexConn` buffer
|
||||||
self._buf = None
|
break
|
||||||
self._buf = await self.mplex_conn.read_buffer(self.stream_id)
|
self._buf += new_bytes
|
||||||
n_read += n_to_read
|
payload = self._buf[:n]
|
||||||
return bytes_buf.getvalue()
|
self._buf = self._buf[n:]
|
||||||
|
return payload
|
||||||
|
|
||||||
async def write(self, data: bytes) -> int:
|
async def write(self, data: bytes) -> int:
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Reference in New Issue
Block a user