Restructure mplex and mplex_stream

This commit is contained in:
mhchia 2019-09-05 22:29:33 +08:00
parent 96230758e4
commit eac159c527
No known key found for this signature in database
GPG Key ID: 389EFBEA1362589A
4 changed files with 55 additions and 71 deletions

View File

@ -1,5 +1,5 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING
from libp2p.peer.id import ID from libp2p.peer.id import ID
from libp2p.security.secure_conn_interface import ISecureConn from libp2p.security.secure_conn_interface import ISecureConn
@ -51,20 +51,6 @@ class IMuxedConn(ABC):
:return: true if successful :return: true if successful
""" """
@abstractmethod
async def read_buffer(self, stream_id: StreamID) -> bytes:
"""
Read a message from stream_id's buffer, check raw connection for new messages
:param stream_id: stream id of stream to read from
:return: message read
"""
@abstractmethod
async def read_buffer_nonblocking(self, stream_id: StreamID) -> Optional[bytes]:
"""
Read a message from `stream_id`'s buffer, non-blockingly.
"""
@abstractmethod @abstractmethod
async def open_stream(self) -> "IMuxedStream": async def open_stream(self) -> "IMuxedStream":
""" """
@ -73,7 +59,7 @@ class IMuxedConn(ABC):
""" """
@abstractmethod @abstractmethod
async def accept_stream(self, name: str) -> None: async def accept_stream(self, stream_id: StreamID, name: str) -> None:
""" """
accepts a muxed stream opened by the other end accepts a muxed stream opened by the other end
""" """

View File

@ -1,2 +1,13 @@
class StreamNotFound(Exception): from libp2p.exceptions import BaseLibp2pError
class MplexError(BaseLibp2pError):
pass
class MplexShutdown(MplexError):
pass
class StreamNotFound(MplexError):
pass pass

View File

@ -16,7 +16,6 @@ from libp2p.utils import (
from .constants import HeaderTags from .constants import HeaderTags
from .datastructures import StreamID from .datastructures import StreamID
from .exceptions import StreamNotFound
from .mplex_stream import MplexStream from .mplex_stream import MplexStream
MPLEX_PROTOCOL_ID = TProtocol("/mplex/6.7.0") MPLEX_PROTOCOL_ID = TProtocol("/mplex/6.7.0")
@ -32,10 +31,9 @@ class Mplex(IMuxedConn):
# TODO: `dataIn` in go implementation. Should be size of 8. # TODO: `dataIn` in go implementation. Should be size of 8.
# TODO: Also, `dataIn` is closed indicating EOF in Go. We don't have similar strategies # TODO: Also, `dataIn` is closed indicating EOF in Go. We don't have similar strategies
# to let the `MplexStream`s know that EOF arrived (#235). # to let the `MplexStream`s know that EOF arrived (#235).
buffers: Dict[StreamID, "asyncio.Queue[bytes]"]
stream_queue: "asyncio.Queue[StreamID]"
next_channel_id: int next_channel_id: int
buffers_lock: asyncio.Lock streams: Dict[StreamID, MplexStream]
streams_lock: asyncio.Lock
shutdown: asyncio.Event shutdown: asyncio.Event
_tasks: List["asyncio.Future[Any]"] _tasks: List["asyncio.Future[Any]"]
@ -65,12 +63,10 @@ class Mplex(IMuxedConn):
self.peer_id = peer_id self.peer_id = peer_id
# Mapping from stream ID -> buffer of messages for that stream # Mapping from stream ID -> buffer of messages for that stream
self.buffers = {} self.streams = {}
self.buffers_lock = asyncio.Lock() self.streams_lock = asyncio.Lock()
self.shutdown = asyncio.Event() self.shutdown = asyncio.Event()
self.stream_queue = asyncio.Queue()
self._tasks = [] self._tasks = []
# Kick off reading # Kick off reading
@ -95,29 +91,6 @@ class Mplex(IMuxedConn):
""" """
raise NotImplementedError() raise NotImplementedError()
async def read_buffer(self, stream_id: StreamID) -> bytes:
"""
Read a message from buffer of the stream specified by `stream_id`,
check secured connection for new messages.
`StreamNotFound` is raised when stream `stream_id` is not found in `Mplex`.
:param stream_id: stream id of stream to read from
:return: message read
"""
if stream_id not in self.buffers:
raise StreamNotFound(f"stream {stream_id} is not found")
return await self.buffers[stream_id].get()
async def read_buffer_nonblocking(self, stream_id: StreamID) -> Optional[bytes]:
"""
Read a message from buffer of the stream specified by `stream_id`, non-blockingly.
`StreamNotFound` is raised when stream `stream_id` is not found in `Mplex`.
"""
if stream_id not in self.buffers:
raise StreamNotFound(f"stream {stream_id} is not found")
if self.buffers[stream_id].empty():
return None
return await self.buffers[stream_id].get()
def _get_next_channel_id(self) -> int: def _get_next_channel_id(self) -> int:
""" """
Get next available stream id Get next available stream id
@ -127,6 +100,12 @@ class Mplex(IMuxedConn):
self.next_channel_id += 1 self.next_channel_id += 1
return next_id return next_id
async def _initialize_stream(self, stream_id: StreamID, name: str) -> MplexStream:
async with self.streams_lock:
stream = MplexStream(name, stream_id, self)
self.streams[stream_id] = stream
return stream
async def open_stream(self) -> IMuxedStream: async def open_stream(self) -> IMuxedStream:
""" """
creates a new muxed_stream creates a new muxed_stream
@ -134,19 +113,18 @@ class Mplex(IMuxedConn):
""" """
channel_id = self._get_next_channel_id() channel_id = self._get_next_channel_id()
stream_id = StreamID(channel_id=channel_id, is_initiator=True) stream_id = StreamID(channel_id=channel_id, is_initiator=True)
name = str(channel_id)
stream = MplexStream(name, stream_id, self)
self.buffers[stream_id] = asyncio.Queue()
# Default stream name is the `channel_id` # Default stream name is the `channel_id`
name = str(channel_id)
stream = await self._initialize_stream(stream_id, name)
await self.send_message(HeaderTags.NewStream, name.encode(), stream_id) await self.send_message(HeaderTags.NewStream, name.encode(), stream_id)
return stream return stream
async def accept_stream(self, name: str) -> None: async def accept_stream(self, stream_id: StreamID, name: str) -> None:
""" """
accepts a muxed stream opened by the other end accepts a muxed stream opened by the other end
""" """
stream_id = await self.stream_queue.get() stream = await self._initialize_stream(stream_id, name)
stream = MplexStream(name, stream_id, self) # Perform protocol negotiation for the stream.
self._tasks.append(asyncio.ensure_future(self.generic_protocol_handler(stream))) self._tasks.append(asyncio.ensure_future(self.generic_protocol_handler(stream)))
async def send_message( async def send_message(
@ -185,22 +163,30 @@ class Mplex(IMuxedConn):
while True: while True:
channel_id, flag, message = await self.read_message() channel_id, flag, message = await self.read_message()
if channel_id is not None and flag is not None and message is not None: if channel_id is not None and flag is not None and message is not None:
stream_id = StreamID(channel_id=channel_id, is_initiator=bool(flag & 1)) stream_id = StreamID(channel_id=channel_id, is_initiator=bool(flag & 1))
if stream_id not in self.buffers: is_stream_id_seen: bool
self.buffers[stream_id] = asyncio.Queue() async with self.streams_lock:
await self.stream_queue.put(stream_id) is_stream_id_seen = stream_id in self.streams
# Other consequent stream message should wait until the stream get accepted
# TODO: Handle more tags, and refactor `HeaderTags` # TODO: Handle more tags, and refactor `HeaderTags`
if flag == HeaderTags.NewStream.value: if flag == HeaderTags.NewStream.value:
# new stream detected on connection if is_stream_id_seen:
await self.accept_stream(message.decode()) # `NewStream` for the same id is received twice...
pass
await self.accept_stream(stream_id, message.decode())
elif flag in ( elif flag in (
HeaderTags.MessageInitiator.value, HeaderTags.MessageInitiator.value,
HeaderTags.MessageReceiver.value, HeaderTags.MessageReceiver.value,
): ):
await self.buffers[stream_id].put(message) if not is_stream_id_seen:
# We receive a message of the stream `stream_id` which is not accepted
# before. It is abnormal. Possibly disconnect?
# TODO: Warn and emit logs about this.
continue
async with self.streams_lock:
stream = self.streams[stream_id]
await stream.incoming_data.put(message)
# elif flag in ( # elif flag in (
# HeaderTags.CloseInitiator.value, # HeaderTags.CloseInitiator.value,
# HeaderTags.CloseReceiver.value # HeaderTags.CloseReceiver.value

View File

@ -23,6 +23,8 @@ class MplexStream(IMuxedStream):
close_lock: asyncio.Lock close_lock: asyncio.Lock
incoming_data: "asyncio.Queue[bytes]"
event_local_closed: asyncio.Event event_local_closed: asyncio.Event
event_remote_closed: asyncio.Event event_remote_closed: asyncio.Event
event_reset: asyncio.Event event_reset: asyncio.Event
@ -44,6 +46,7 @@ class MplexStream(IMuxedStream):
self.event_remote_closed = asyncio.Event() self.event_remote_closed = asyncio.Event()
self.event_reset = asyncio.Event() self.event_reset = asyncio.Event()
self.close_lock = asyncio.Lock() self.close_lock = asyncio.Lock()
self.incoming_data = asyncio.Queue()
self._buf = bytearray() self._buf = bytearray()
@property @property
@ -58,7 +61,6 @@ class MplexStream(IMuxedStream):
:param n: number of bytes to read :param n: number of bytes to read
:return: bytes actually read :return: bytes actually read
""" """
# TODO: Handle `StreamNotFound` raised in `self.mplex_conn.read_buffer`.
# TODO: Add exceptions and handle/raise them in this class. # TODO: Add exceptions and handle/raise them in this class.
if n < 0 and n != -1: if n < 0 and n != -1:
raise ValueError( raise ValueError(
@ -66,17 +68,16 @@ class MplexStream(IMuxedStream):
) )
# If the buffer is empty at first, blocking wait for data. # If the buffer is empty at first, blocking wait for data.
if len(self._buf) == 0: if len(self._buf) == 0:
self._buf.extend(await self.mplex_conn.read_buffer(self.stream_id)) self._buf.extend(await self.incoming_data.get())
# FIXME: If `n == -1`, we should blocking read until EOF, instead of returning when # FIXME: If `n == -1`, we should blocking read until EOF, instead of returning when
# no message is available. # no message is available.
# If `n >= 0`, read up to `n` bytes. # If `n >= 0`, read up to `n` bytes.
# Else, read until no message is available. # Else, read until no message is available.
while len(self._buf) < n or n == -1: while len(self._buf) < n or n == -1:
new_bytes = await self.mplex_conn.read_buffer_nonblocking(self.stream_id) if self.incoming_data.empty():
if new_bytes is None:
# Nothing to read in the `MplexConn` buffer
break break
new_bytes = await self.incoming_data.get()
self._buf.extend(new_bytes) self._buf.extend(new_bytes)
payload: bytearray payload: bytearray
if n == -1: if n == -1:
@ -122,8 +123,8 @@ class MplexStream(IMuxedStream):
if _is_remote_closed: if _is_remote_closed:
# Both sides are closed, we can safely remove the buffer from the dict. # Both sides are closed, we can safely remove the buffer from the dict.
async with self.mplex_conn.buffers_lock: async with self.mplex_conn.streams_lock:
del self.mplex_conn.buffers[self.stream_id] del self.mplex_conn.streams[self.stream_id]
async def reset(self) -> None: async def reset(self) -> None:
""" """
@ -152,8 +153,8 @@ class MplexStream(IMuxedStream):
self.event_local_closed.set() self.event_local_closed.set()
self.event_remote_closed.set() self.event_remote_closed.set()
async with self.mplex_conn.buffers_lock: async with self.mplex_conn.streams_lock:
del self.mplex_conn.buffers[self.stream_id] del self.mplex_conn.streams[self.stream_id]
# TODO deadline not in use # TODO deadline not in use
def set_deadline(self, ttl: int) -> bool: def set_deadline(self, ttl: int) -> bool: