import asyncio
from typing import TYPE_CHECKING, Optional, cast  # noqa: F401

from libp2p.stream_muxer.abc import IMuxedStream

from .constants import HeaderTags
from .datastructures import StreamID
from .exceptions import MplexStreamClosed, MplexStreamEOF, MplexStreamReset

if TYPE_CHECKING:
    from typing import Any  # noqa: F401

    from libp2p.stream_muxer.mplex.mplex import Mplex


class MplexStream(IMuxedStream):
    """
    A single logical stream multiplexed over an ``Mplex`` connection.

    reference: https://github.com/libp2p/go-mplex/blob/master/stream.go
    """

    name: str
    stream_id: StreamID
    mplex_conn: "Mplex"
    # Deadlines are stored but not enforced yet (see the TODOs below).
    read_deadline: Optional[int]
    write_deadline: Optional[int]

    # Serialises the state transitions in `close` and `reset`.
    close_lock: asyncio.Lock

    # Incoming payloads for this stream; drained by `read` / `_wait_for_data`.
    incoming_data: "asyncio.Queue[bytes]"

    event_local_closed: asyncio.Event
    event_remote_closed: asyncio.Event
    event_reset: asyncio.Event

    # Bytes already taken off `incoming_data` but not yet returned by `read`.
    _buf: bytearray

    def __init__(self, name: str, stream_id: StreamID, mplex_conn: "Mplex") -> None:
        """
        create new MuxedStream in muxer

        :param name: name of this stream
        :param stream_id: stream id of this stream
        :param mplex_conn: muxed connection of this muxed_stream
        """
        self.name = name
        self.stream_id = stream_id
        self.mplex_conn = mplex_conn
        self.read_deadline = None
        self.write_deadline = None
        self.event_local_closed = asyncio.Event()
        self.event_remote_closed = asyncio.Event()
        self.event_reset = asyncio.Event()
        self.close_lock = asyncio.Lock()
        self.incoming_data = asyncio.Queue()
        self._buf = bytearray()

    @property
    def is_initiator(self) -> bool:
        """Whether this side is the initiator, as recorded in ``stream_id``."""
        return self.stream_id.is_initiator

    async def _wait_for_data(self) -> None:
        """
        Wait until a payload arrives, the stream is reset, or the remote end closes.

        A received payload is appended to ``self._buf``.

        :raise MplexStreamReset: if the stream has been reset
        :raise MplexStreamEOF: if the remote end is closed and no payload was
            received
        """
        # Wrap everything in tasks explicitly: `asyncio.wait` no longer accepts
        # bare coroutines (deprecated in 3.8, rejected in 3.11), and holding the
        # task objects lets us tell which one finished without private attributes.
        task_event_reset = asyncio.ensure_future(self.event_reset.wait())
        task_incoming_data_get = asyncio.ensure_future(self.incoming_data.get())
        task_event_remote_closed = asyncio.ensure_future(
            self.event_remote_closed.wait()
        )
        tasks = (task_event_reset, task_incoming_data_get, task_event_remote_closed)

        got_data = False
        try:
            await asyncio.wait(  # type: ignore
                tasks, return_when=asyncio.FIRST_COMPLETED
            )
        finally:
            # Cancel unfinished waiters, including when this coroutine itself is
            # cancelled, so that no leftover `Queue.get` task consumes and drops
            # a payload later.
            for task in tasks:
                if not task.done():
                    task.cancel()
            # Several waiters can finish in the same loop iteration. A payload
            # already taken off the queue must never be discarded.
            if task_incoming_data_get.done() and not task_incoming_data_get.cancelled():
                self._buf.extend(task_incoming_data_get.result())
                got_data = True

        if self.event_reset.is_set():
            raise MplexStreamReset
        # Data takes priority over EOF: report EOF only when nothing was received.
        if got_data:
            return
        if self.event_remote_closed.is_set():
            raise MplexStreamEOF

    # TODO: Handle timeout when deadline is used.

    async def _read_until_eof(self) -> bytes:
        """
        Buffer incoming payloads until the remote end closes, then return them all.

        :raise MplexStreamReset: if the stream is reset while waiting
        """
        while True:
            try:
                await self._wait_for_data()
            except MplexStreamEOF:
                break
        payload = bytes(self._buf)
        self._buf = bytearray()
        return payload

    async def read(self, n: int = -1) -> bytes:
        """
        Read up to n bytes.

        Read possibly returns fewer than `n` bytes, if there are not enough bytes
        in the Mplex buffer. If `n == -1`, read until EOF. If `n == 0`, return
        ``b""`` without waiting.

        :param n: number of bytes to read
        :return: bytes actually read
        :raise ValueError: if ``n`` is negative and not ``-1``
        :raise MplexStreamReset: if the stream has been reset
        :raise MplexStreamEOF: if nothing is buffered and the remote end is closed
        """
        # TODO: Add exceptions and handle/raise them in this class.
        if n < 0 and n != -1:
            raise ValueError(
                "the number of bytes to read `n` must be positive or -1 to indicate read until EOF"
            )
        if self.event_reset.is_set():
            raise MplexStreamReset
        if n == -1:
            return await self._read_until_eof()
        if n == 0:
            # Nothing requested: do not block waiting for a payload.
            return b""
        if not self._buf and self.incoming_data.empty():
            await self._wait_for_data()
        # Either `self._buf` or `incoming_data` is non-empty now.
        # Pull already-queued payloads (without waiting) until `n` bytes are buffered.
        while len(self._buf) < n:
            try:
                self._buf.extend(self.incoming_data.get_nowait())
            except asyncio.QueueEmpty:
                break
        payload = bytes(self._buf[:n])
        del self._buf[:n]
        return payload

    async def write(self, data: bytes) -> int:
        """
        write to stream

        :return: number of bytes written
        :raise MplexStreamClosed: if the local end of the stream is closed
        """
        if self.event_local_closed.is_set():
            raise MplexStreamClosed(f"cannot write to closed stream: data={data}")
        flag = (
            HeaderTags.MessageInitiator
            if self.is_initiator
            else HeaderTags.MessageReceiver
        )
        return await self.mplex_conn.send_message(flag, data, self.stream_id)

    async def close(self) -> None:
        """
        Closing a stream closes it for writing and closes the remote end for
        reading but allows writing in the other direction.

        Sends a close frame unless the local end is already closed, in which
        case the call returns immediately. Once both the local and the remote
        end are closed, the stream is removed from ``mplex_conn.streams``.
        """
        # TODO error handling with timeout

        async with self.close_lock:
            if self.event_local_closed.is_set():
                return

        flag = (
            HeaderTags.CloseInitiator if self.is_initiator else HeaderTags.CloseReceiver
        )
        # TODO: Raise when `mplex_conn.send_message` fails and `Mplex` isn't shutdown.
        await self.mplex_conn.send_message(flag, None, self.stream_id)

        _is_remote_closed: bool
        async with self.close_lock:
            self.event_local_closed.set()
            _is_remote_closed = self.event_remote_closed.is_set()

        if _is_remote_closed:
            # Both sides are closed, we can safely remove the buffer from the dict.
            # It may already be gone (e.g. removed by a concurrent `reset`), so
            # guard the deletion instead of raising `KeyError`.
            async with self.mplex_conn.streams_lock:
                if self.stream_id in self.mplex_conn.streams:
                    del self.mplex_conn.streams[self.stream_id]

    async def reset(self) -> None:
        """
        Close both ends of the stream and tell the remote side to hang up.

        Does nothing if the stream was already reset or both ends are already
        closed. The reset frame is only sent when the remote end is not closed
        yet. It is sent from a background task, so this call does not wait for
        the send to finish. Afterwards the stream is removed from
        ``mplex_conn.streams``.
        """
        async with self.close_lock:
            # Both sides have been closed. No need to event_reset.
            if self.event_remote_closed.is_set() and self.event_local_closed.is_set():
                return
            if self.event_reset.is_set():
                return
            self.event_reset.set()

            if not self.event_remote_closed.is_set():
                flag = (
                    HeaderTags.ResetInitiator
                    if self.is_initiator
                    else HeaderTags.ResetReceiver
                )
                asyncio.ensure_future(
                    self.mplex_conn.send_message(flag, None, self.stream_id)
                )
                # Yield once so the scheduled send gets a chance to start.
                await asyncio.sleep(0)

            self.event_local_closed.set()
            self.event_remote_closed.set()

        async with self.mplex_conn.streams_lock:
            # The stream may already have been removed (e.g. by a concurrent
            # `close`), so do not assume the key is still present.
            if self.stream_id in self.mplex_conn.streams:
                del self.mplex_conn.streams[self.stream_id]

    # TODO deadline not in use
    def set_deadline(self, ttl: int) -> bool:
        """
        set deadline for muxed stream (stored only; not enforced yet)

        :return: True if successful
        """
        self.read_deadline = ttl
        self.write_deadline = ttl
        return True

    def set_read_deadline(self, ttl: int) -> bool:
        """
        set read deadline for muxed stream (stored only; not enforced yet)

        :return: True if successful
        """
        self.read_deadline = ttl
        return True

    def set_write_deadline(self, ttl: int) -> bool:
        """
        set write deadline for muxed stream (stored only; not enforced yet)

        :return: True if successful
        """
        self.write_deadline = ttl
        return True