From 96230758e42a8c5cc357cd68f558842da7839815 Mon Sep 17 00:00:00 2001 From: mhchia Date: Thu, 5 Sep 2019 18:18:08 +0800 Subject: [PATCH] Add events in MplexStream And modify a little bit of `close` and `reset` --- libp2p/stream_muxer/mplex/mplex.py | 24 +++--- libp2p/stream_muxer/mplex/mplex_stream.py | 91 ++++++++++++----------- 2 files changed, 64 insertions(+), 51 deletions(-) diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index 5f55a66..cf1ec91 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -10,6 +10,7 @@ from libp2p.typing import TProtocol from libp2p.utils import ( decode_uvarint_from_stream, encode_uvarint, + encode_varint_prefixed, read_varint_prefixed_bytes, ) @@ -34,6 +35,8 @@ class Mplex(IMuxedConn): buffers: Dict[StreamID, "asyncio.Queue[bytes]"] stream_queue: "asyncio.Queue[StreamID]" next_channel_id: int + buffers_lock: asyncio.Lock + shutdown: asyncio.Event _tasks: List["asyncio.Future[Any]"] @@ -63,6 +66,8 @@ class Mplex(IMuxedConn): # Mapping from stream ID -> buffer of messages for that stream self.buffers = {} + self.buffers_lock = asyncio.Lock() + self.shutdown = asyncio.Event() self.stream_queue = asyncio.Queue() @@ -145,7 +150,7 @@ class Mplex(IMuxedConn): self._tasks.append(asyncio.ensure_future(self.generic_protocol_handler(stream))) async def send_message( - self, flag: HeaderTags, data: bytes, stream_id: StreamID + self, flag: HeaderTags, data: Optional[bytes], stream_id: StreamID ) -> int: """ sends a message over the connection @@ -154,19 +159,16 @@ class Mplex(IMuxedConn): :param stream_id: stream the message is in """ # << by 3, then or with flag - header = (stream_id.channel_id << 3) | flag.value - header = encode_uvarint(header) + header = encode_uvarint((stream_id.channel_id << 3) | flag.value) if data is None: - data_length = encode_uvarint(0) - _bytes = header + data_length - else: - data_length = encode_uvarint(len(data)) - _bytes = header + data_length + data + data = b"" + + _bytes = header + encode_varint_prefixed(data) return await self.write_to_stream(_bytes) - async def write_to_stream(self, _bytes: bytearray) -> int: + async def write_to_stream(self, _bytes: bytes) -> int: """ writes a byte array to a secured connection :param _bytes: byte array to write @@ -199,6 +201,10 @@ class Mplex(IMuxedConn): HeaderTags.MessageReceiver.value, ): await self.buffers[stream_id].put(message) + # elif flag in ( + # HeaderTags.CloseInitiator.value, + # HeaderTags.CloseReceiver.value + # ): # Force context switch await asyncio.sleep(0) diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index fe0261b..d0f0801 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -1,10 +1,14 @@ import asyncio +from typing import TYPE_CHECKING -from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream +from libp2p.stream_muxer.abc import IMuxedStream from .constants import HeaderTags from .datastructures import StreamID +if TYPE_CHECKING: + from libp2p.stream_muxer.mplex.mplex import Mplex + class MplexStream(IMuxedStream): """ @@ -13,16 +17,19 @@ class MplexStream(IMuxedStream): name: str stream_id: StreamID - mplex_conn: IMuxedConn + mplex_conn: "Mplex" read_deadline: int write_deadline: int - local_closed: bool - remote_closed: bool - stream_lock: asyncio.Lock + + close_lock: asyncio.Lock + + event_local_closed: asyncio.Event + event_remote_closed: asyncio.Event + event_reset: asyncio.Event _buf: bytearray - def __init__(self, name: str, stream_id: StreamID, mplex_conn: IMuxedConn) -> None: + def __init__(self, name: str, stream_id: StreamID, mplex_conn: "Mplex") -> None: """ create new MuxedStream in muxer :param stream_id: stream id of this stream @@ -33,9 +40,10 @@ class MplexStream(IMuxedStream): self.mplex_conn = mplex_conn self.read_deadline = None self.write_deadline = None - self.local_closed = False - self.remote_closed = False - self.stream_lock = asyncio.Lock() + self.event_local_closed = asyncio.Event() + self.event_remote_closed = asyncio.Event() + self.event_reset = asyncio.Event() + self.close_lock = asyncio.Lock() self._buf = bytearray() @property @@ -90,63 +98,62 @@ class MplexStream(IMuxedStream): ) return await self.mplex_conn.send_message(flag, data, self.stream_id) - async def close(self) -> bool: + async def close(self) -> None: """ Closing a stream closes it for writing and closes the remote end for reading but allows writing in the other direction. - :return: true if successful """ # TODO error handling with timeout - # TODO understand better how mutexes are used from go repo + + async with self.close_lock: + if self.event_local_closed.is_set(): + return + flag = ( HeaderTags.CloseInitiator if self.is_initiator else HeaderTags.CloseReceiver ) + # TODO: Raise when `mplex_conn.send_message` fails and `Mplex` isn't shutdown. await self.mplex_conn.send_message(flag, None, self.stream_id) - remote_lock = False - async with self.stream_lock: - if self.local_closed: - return True - self.local_closed = True - remote_lock = self.remote_closed + _is_remote_closed: bool + async with self.close_lock: + self.event_local_closed.set() + _is_remote_closed = self.event_remote_closed.is_set() - if remote_lock: - # FIXME: mplex_conn has no conn_lock! - async with self.mplex_conn.conn_lock: # type: ignore - # FIXME: Don't access to buffers directly - self.mplex_conn.buffers.pop(self.stream_id) # type: ignore + if _is_remote_closed: + # Both sides are closed, we can safely remove the buffer from the dict. + async with self.mplex_conn.buffers_lock: + del self.mplex_conn.buffers[self.stream_id] - return True - - async def reset(self) -> bool: + async def reset(self) -> None: """ closes both ends of the stream tells this remote side to hang up - :return: true if successful """ - # TODO understand better how mutexes are used here - # TODO understand the difference between close and reset - async with self.stream_lock: - if self.remote_closed and self.local_closed: - return True + async with self.close_lock: + # Both sides have been closed. No need to event_reset. + if self.event_remote_closed.is_set() and self.event_local_closed.is_set(): + return + if self.event_reset.is_set(): + return + self.event_reset.set() - if not self.remote_closed: + if not self.event_remote_closed.is_set(): flag = ( HeaderTags.ResetInitiator if self.is_initiator else HeaderTags.ResetReceiver ) - await self.mplex_conn.send_message(flag, None, self.stream_id) + asyncio.ensure_future( + self.mplex_conn.send_message(flag, None, self.stream_id) + ) + await asyncio.sleep(0) - self.local_closed = True - self.remote_closed = True + self.event_local_closed.set() + self.event_remote_closed.set() - # FIXME: mplex_conn has no conn_lock! - async with self.mplex_conn.conn_lock: # type: ignore - # FIXME: Don't access to buffers directly - self.mplex_conn.buffers.pop(self.stream_id, None) # type: ignore - - return True + async with self.mplex_conn.buffers_lock: + del self.mplex_conn.buffers[self.stream_id] # TODO deadline not in use def set_deadline(self, ttl: int) -> bool: