Restructure mplex and mplex_stream
This commit is contained in:
parent
96230758e4
commit
eac159c527
|
@ -1,5 +1,5 @@
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import TYPE_CHECKING, Optional
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from libp2p.peer.id import ID
|
from libp2p.peer.id import ID
|
||||||
from libp2p.security.secure_conn_interface import ISecureConn
|
from libp2p.security.secure_conn_interface import ISecureConn
|
||||||
|
@ -51,20 +51,6 @@ class IMuxedConn(ABC):
|
||||||
:return: true if successful
|
:return: true if successful
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def read_buffer(self, stream_id: StreamID) -> bytes:
|
|
||||||
"""
|
|
||||||
Read a message from stream_id's buffer, check raw connection for new messages
|
|
||||||
:param stream_id: stream id of stream to read from
|
|
||||||
:return: message read
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def read_buffer_nonblocking(self, stream_id: StreamID) -> Optional[bytes]:
|
|
||||||
"""
|
|
||||||
Read a message from `stream_id`'s buffer, non-blockingly.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def open_stream(self) -> "IMuxedStream":
|
async def open_stream(self) -> "IMuxedStream":
|
||||||
"""
|
"""
|
||||||
|
@ -73,7 +59,7 @@ class IMuxedConn(ABC):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def accept_stream(self, name: str) -> None:
|
async def accept_stream(self, stream_id: StreamID, name: str) -> None:
|
||||||
"""
|
"""
|
||||||
accepts a muxed stream opened by the other end
|
accepts a muxed stream opened by the other end
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -1,2 +1,13 @@
|
||||||
class StreamNotFound(Exception):
|
from libp2p.exceptions import BaseLibp2pError
|
||||||
|
|
||||||
|
|
||||||
|
class MplexError(BaseLibp2pError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class MplexShutdown(MplexError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class StreamNotFound(MplexError):
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -16,7 +16,6 @@ from libp2p.utils import (
|
||||||
|
|
||||||
from .constants import HeaderTags
|
from .constants import HeaderTags
|
||||||
from .datastructures import StreamID
|
from .datastructures import StreamID
|
||||||
from .exceptions import StreamNotFound
|
|
||||||
from .mplex_stream import MplexStream
|
from .mplex_stream import MplexStream
|
||||||
|
|
||||||
MPLEX_PROTOCOL_ID = TProtocol("/mplex/6.7.0")
|
MPLEX_PROTOCOL_ID = TProtocol("/mplex/6.7.0")
|
||||||
|
@ -32,10 +31,9 @@ class Mplex(IMuxedConn):
|
||||||
# TODO: `dataIn` in go implementation. Should be size of 8.
|
# TODO: `dataIn` in go implementation. Should be size of 8.
|
||||||
# TODO: Also, `dataIn` is closed indicating EOF in Go. We don't have similar strategies
|
# TODO: Also, `dataIn` is closed indicating EOF in Go. We don't have similar strategies
|
||||||
# to let the `MplexStream`s know that EOF arrived (#235).
|
# to let the `MplexStream`s know that EOF arrived (#235).
|
||||||
buffers: Dict[StreamID, "asyncio.Queue[bytes]"]
|
|
||||||
stream_queue: "asyncio.Queue[StreamID]"
|
|
||||||
next_channel_id: int
|
next_channel_id: int
|
||||||
buffers_lock: asyncio.Lock
|
streams: Dict[StreamID, MplexStream]
|
||||||
|
streams_lock: asyncio.Lock
|
||||||
shutdown: asyncio.Event
|
shutdown: asyncio.Event
|
||||||
|
|
||||||
_tasks: List["asyncio.Future[Any]"]
|
_tasks: List["asyncio.Future[Any]"]
|
||||||
|
@ -65,12 +63,10 @@ class Mplex(IMuxedConn):
|
||||||
self.peer_id = peer_id
|
self.peer_id = peer_id
|
||||||
|
|
||||||
# Mapping from stream ID -> buffer of messages for that stream
|
# Mapping from stream ID -> buffer of messages for that stream
|
||||||
self.buffers = {}
|
self.streams = {}
|
||||||
self.buffers_lock = asyncio.Lock()
|
self.streams_lock = asyncio.Lock()
|
||||||
self.shutdown = asyncio.Event()
|
self.shutdown = asyncio.Event()
|
||||||
|
|
||||||
self.stream_queue = asyncio.Queue()
|
|
||||||
|
|
||||||
self._tasks = []
|
self._tasks = []
|
||||||
|
|
||||||
# Kick off reading
|
# Kick off reading
|
||||||
|
@ -95,29 +91,6 @@ class Mplex(IMuxedConn):
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
async def read_buffer(self, stream_id: StreamID) -> bytes:
|
|
||||||
"""
|
|
||||||
Read a message from buffer of the stream specified by `stream_id`,
|
|
||||||
check secured connection for new messages.
|
|
||||||
`StreamNotFound` is raised when stream `stream_id` is not found in `Mplex`.
|
|
||||||
:param stream_id: stream id of stream to read from
|
|
||||||
:return: message read
|
|
||||||
"""
|
|
||||||
if stream_id not in self.buffers:
|
|
||||||
raise StreamNotFound(f"stream {stream_id} is not found")
|
|
||||||
return await self.buffers[stream_id].get()
|
|
||||||
|
|
||||||
async def read_buffer_nonblocking(self, stream_id: StreamID) -> Optional[bytes]:
|
|
||||||
"""
|
|
||||||
Read a message from buffer of the stream specified by `stream_id`, non-blockingly.
|
|
||||||
`StreamNotFound` is raised when stream `stream_id` is not found in `Mplex`.
|
|
||||||
"""
|
|
||||||
if stream_id not in self.buffers:
|
|
||||||
raise StreamNotFound(f"stream {stream_id} is not found")
|
|
||||||
if self.buffers[stream_id].empty():
|
|
||||||
return None
|
|
||||||
return await self.buffers[stream_id].get()
|
|
||||||
|
|
||||||
def _get_next_channel_id(self) -> int:
|
def _get_next_channel_id(self) -> int:
|
||||||
"""
|
"""
|
||||||
Get next available stream id
|
Get next available stream id
|
||||||
|
@ -127,6 +100,12 @@ class Mplex(IMuxedConn):
|
||||||
self.next_channel_id += 1
|
self.next_channel_id += 1
|
||||||
return next_id
|
return next_id
|
||||||
|
|
||||||
|
async def _initialize_stream(self, stream_id: StreamID, name: str) -> MplexStream:
|
||||||
|
async with self.streams_lock:
|
||||||
|
stream = MplexStream(name, stream_id, self)
|
||||||
|
self.streams[stream_id] = stream
|
||||||
|
return stream
|
||||||
|
|
||||||
async def open_stream(self) -> IMuxedStream:
|
async def open_stream(self) -> IMuxedStream:
|
||||||
"""
|
"""
|
||||||
creates a new muxed_stream
|
creates a new muxed_stream
|
||||||
|
@ -134,19 +113,18 @@ class Mplex(IMuxedConn):
|
||||||
"""
|
"""
|
||||||
channel_id = self._get_next_channel_id()
|
channel_id = self._get_next_channel_id()
|
||||||
stream_id = StreamID(channel_id=channel_id, is_initiator=True)
|
stream_id = StreamID(channel_id=channel_id, is_initiator=True)
|
||||||
name = str(channel_id)
|
|
||||||
stream = MplexStream(name, stream_id, self)
|
|
||||||
self.buffers[stream_id] = asyncio.Queue()
|
|
||||||
# Default stream name is the `channel_id`
|
# Default stream name is the `channel_id`
|
||||||
|
name = str(channel_id)
|
||||||
|
stream = await self._initialize_stream(stream_id, name)
|
||||||
await self.send_message(HeaderTags.NewStream, name.encode(), stream_id)
|
await self.send_message(HeaderTags.NewStream, name.encode(), stream_id)
|
||||||
return stream
|
return stream
|
||||||
|
|
||||||
async def accept_stream(self, name: str) -> None:
|
async def accept_stream(self, stream_id: StreamID, name: str) -> None:
|
||||||
"""
|
"""
|
||||||
accepts a muxed stream opened by the other end
|
accepts a muxed stream opened by the other end
|
||||||
"""
|
"""
|
||||||
stream_id = await self.stream_queue.get()
|
stream = await self._initialize_stream(stream_id, name)
|
||||||
stream = MplexStream(name, stream_id, self)
|
# Perform protocol negotiation for the stream.
|
||||||
self._tasks.append(asyncio.ensure_future(self.generic_protocol_handler(stream)))
|
self._tasks.append(asyncio.ensure_future(self.generic_protocol_handler(stream)))
|
||||||
|
|
||||||
async def send_message(
|
async def send_message(
|
||||||
|
@ -185,22 +163,30 @@ class Mplex(IMuxedConn):
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
channel_id, flag, message = await self.read_message()
|
channel_id, flag, message = await self.read_message()
|
||||||
|
|
||||||
if channel_id is not None and flag is not None and message is not None:
|
if channel_id is not None and flag is not None and message is not None:
|
||||||
stream_id = StreamID(channel_id=channel_id, is_initiator=bool(flag & 1))
|
stream_id = StreamID(channel_id=channel_id, is_initiator=bool(flag & 1))
|
||||||
if stream_id not in self.buffers:
|
is_stream_id_seen: bool
|
||||||
self.buffers[stream_id] = asyncio.Queue()
|
async with self.streams_lock:
|
||||||
await self.stream_queue.put(stream_id)
|
is_stream_id_seen = stream_id in self.streams
|
||||||
|
# Other consequent stream message should wait until the stream get accepted
|
||||||
# TODO: Handle more tags, and refactor `HeaderTags`
|
# TODO: Handle more tags, and refactor `HeaderTags`
|
||||||
if flag == HeaderTags.NewStream.value:
|
if flag == HeaderTags.NewStream.value:
|
||||||
# new stream detected on connection
|
if is_stream_id_seen:
|
||||||
await self.accept_stream(message.decode())
|
# `NewStream` for the same id is received twice...
|
||||||
|
pass
|
||||||
|
await self.accept_stream(stream_id, message.decode())
|
||||||
elif flag in (
|
elif flag in (
|
||||||
HeaderTags.MessageInitiator.value,
|
HeaderTags.MessageInitiator.value,
|
||||||
HeaderTags.MessageReceiver.value,
|
HeaderTags.MessageReceiver.value,
|
||||||
):
|
):
|
||||||
await self.buffers[stream_id].put(message)
|
if not is_stream_id_seen:
|
||||||
|
# We receive a message of the stream `stream_id` which is not accepted
|
||||||
|
# before. It is abnormal. Possibly disconnect?
|
||||||
|
# TODO: Warn and emit logs about this.
|
||||||
|
continue
|
||||||
|
async with self.streams_lock:
|
||||||
|
stream = self.streams[stream_id]
|
||||||
|
await stream.incoming_data.put(message)
|
||||||
# elif flag in (
|
# elif flag in (
|
||||||
# HeaderTags.CloseInitiator.value,
|
# HeaderTags.CloseInitiator.value,
|
||||||
# HeaderTags.CloseReceiver.value
|
# HeaderTags.CloseReceiver.value
|
||||||
|
|
|
@ -23,6 +23,8 @@ class MplexStream(IMuxedStream):
|
||||||
|
|
||||||
close_lock: asyncio.Lock
|
close_lock: asyncio.Lock
|
||||||
|
|
||||||
|
incoming_data: "asyncio.Queue[bytes]"
|
||||||
|
|
||||||
event_local_closed: asyncio.Event
|
event_local_closed: asyncio.Event
|
||||||
event_remote_closed: asyncio.Event
|
event_remote_closed: asyncio.Event
|
||||||
event_reset: asyncio.Event
|
event_reset: asyncio.Event
|
||||||
|
@ -44,6 +46,7 @@ class MplexStream(IMuxedStream):
|
||||||
self.event_remote_closed = asyncio.Event()
|
self.event_remote_closed = asyncio.Event()
|
||||||
self.event_reset = asyncio.Event()
|
self.event_reset = asyncio.Event()
|
||||||
self.close_lock = asyncio.Lock()
|
self.close_lock = asyncio.Lock()
|
||||||
|
self.incoming_data = asyncio.Queue()
|
||||||
self._buf = bytearray()
|
self._buf = bytearray()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -58,7 +61,6 @@ class MplexStream(IMuxedStream):
|
||||||
:param n: number of bytes to read
|
:param n: number of bytes to read
|
||||||
:return: bytes actually read
|
:return: bytes actually read
|
||||||
"""
|
"""
|
||||||
# TODO: Handle `StreamNotFound` raised in `self.mplex_conn.read_buffer`.
|
|
||||||
# TODO: Add exceptions and handle/raise them in this class.
|
# TODO: Add exceptions and handle/raise them in this class.
|
||||||
if n < 0 and n != -1:
|
if n < 0 and n != -1:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
@ -66,17 +68,16 @@ class MplexStream(IMuxedStream):
|
||||||
)
|
)
|
||||||
# If the buffer is empty at first, blocking wait for data.
|
# If the buffer is empty at first, blocking wait for data.
|
||||||
if len(self._buf) == 0:
|
if len(self._buf) == 0:
|
||||||
self._buf.extend(await self.mplex_conn.read_buffer(self.stream_id))
|
self._buf.extend(await self.incoming_data.get())
|
||||||
|
|
||||||
# FIXME: If `n == -1`, we should blocking read until EOF, instead of returning when
|
# FIXME: If `n == -1`, we should blocking read until EOF, instead of returning when
|
||||||
# no message is available.
|
# no message is available.
|
||||||
# If `n >= 0`, read up to `n` bytes.
|
# If `n >= 0`, read up to `n` bytes.
|
||||||
# Else, read until no message is available.
|
# Else, read until no message is available.
|
||||||
while len(self._buf) < n or n == -1:
|
while len(self._buf) < n or n == -1:
|
||||||
new_bytes = await self.mplex_conn.read_buffer_nonblocking(self.stream_id)
|
if self.incoming_data.empty():
|
||||||
if new_bytes is None:
|
|
||||||
# Nothing to read in the `MplexConn` buffer
|
|
||||||
break
|
break
|
||||||
|
new_bytes = await self.incoming_data.get()
|
||||||
self._buf.extend(new_bytes)
|
self._buf.extend(new_bytes)
|
||||||
payload: bytearray
|
payload: bytearray
|
||||||
if n == -1:
|
if n == -1:
|
||||||
|
@ -122,8 +123,8 @@ class MplexStream(IMuxedStream):
|
||||||
|
|
||||||
if _is_remote_closed:
|
if _is_remote_closed:
|
||||||
# Both sides are closed, we can safely remove the buffer from the dict.
|
# Both sides are closed, we can safely remove the buffer from the dict.
|
||||||
async with self.mplex_conn.buffers_lock:
|
async with self.mplex_conn.streams_lock:
|
||||||
del self.mplex_conn.buffers[self.stream_id]
|
del self.mplex_conn.streams[self.stream_id]
|
||||||
|
|
||||||
async def reset(self) -> None:
|
async def reset(self) -> None:
|
||||||
"""
|
"""
|
||||||
|
@ -152,8 +153,8 @@ class MplexStream(IMuxedStream):
|
||||||
self.event_local_closed.set()
|
self.event_local_closed.set()
|
||||||
self.event_remote_closed.set()
|
self.event_remote_closed.set()
|
||||||
|
|
||||||
async with self.mplex_conn.buffers_lock:
|
async with self.mplex_conn.streams_lock:
|
||||||
del self.mplex_conn.buffers[self.stream_id]
|
del self.mplex_conn.streams[self.stream_id]
|
||||||
|
|
||||||
# TODO deadline not in use
|
# TODO deadline not in use
|
||||||
def set_deadline(self, ttl: int) -> bool:
|
def set_deadline(self, ttl: int) -> bool:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user