Merge pull request #268 from mhchia/fix/mplex-interop

Fix: name of a `MplexStream` is not handled in `Mplex`
This commit is contained in:
Alex Stokes 2019-08-26 19:51:57 +02:00 committed by GitHub
commit 98a0e76dda
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 75 additions and 71 deletions

View File

@ -6,3 +6,7 @@ class ValidationError(BaseLibp2pError):
""" """
Raised when something does not pass a validation check. Raised when something does not pass a validation check.
""" """
class ParseError(BaseLibp2pError):
pass

View File

@ -156,15 +156,10 @@ class Swarm(INetwork):
if not addrs: if not addrs:
raise SwarmException("No known addresses to peer") raise SwarmException("No known addresses to peer")
multiaddr = addrs[0]
muxed_conn = await self.dial_peer(peer_id) muxed_conn = await self.dial_peer(peer_id)
# Use muxed conn to open stream, which returns # Use muxed conn to open stream, which returns a muxed stream
# a muxed stream muxed_stream = await muxed_conn.open_stream()
# TODO: Remove protocol id from being passed into muxed_conn
# FIXME: Remove multiaddr from being passed into muxed_conn
muxed_stream = await muxed_conn.open_stream(protocol_ids[0], multiaddr)
# Perform protocol muxing to determine protocol to use # Perform protocol muxing to determine protocol to use
selected_protocol = await self.multiselect_client.select_one_of( selected_protocol = await self.multiselect_client.select_one_of(

View File

@ -1,8 +1,6 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
from multiaddr import Multiaddr
from libp2p.peer.id import ID from libp2p.peer.id import ID
from libp2p.security.secure_conn_interface import ISecureConn from libp2p.security.secure_conn_interface import ISecureConn
from libp2p.stream_muxer.mplex.constants import HeaderTags from libp2p.stream_muxer.mplex.constants import HeaderTags
@ -66,20 +64,15 @@ class IMuxedConn(ABC):
Read a message from `stream_id`'s buffer, non-blockingly. Read a message from `stream_id`'s buffer, non-blockingly.
""" """
# FIXME: Remove multiaddr from being passed into muxed_conn
@abstractmethod @abstractmethod
async def open_stream( async def open_stream(self) -> "IMuxedStream":
self, protocol_id: str, multi_addr: Multiaddr
) -> "IMuxedStream":
""" """
creates a new muxed_stream creates a new muxed_stream
:param protocol_id: protocol_id of stream :return: a new ``IMuxedStream`` stream
:param multi_addr: multi_addr that stream connects to
:return: a new stream
""" """
@abstractmethod @abstractmethod
async def accept_stream(self) -> None: async def accept_stream(self, name: str) -> None:
""" """
accepts a muxed stream opened by the other end accepts a muxed stream opened by the other end
""" """

View File

@ -1,14 +1,16 @@
import asyncio import asyncio
from typing import Dict, Optional, Tuple from typing import Dict, Optional, Tuple
from multiaddr import Multiaddr
from libp2p.network.typing import GenericProtocolHandlerFn from libp2p.network.typing import GenericProtocolHandlerFn
from libp2p.peer.id import ID from libp2p.peer.id import ID
from libp2p.security.secure_conn_interface import ISecureConn from libp2p.security.secure_conn_interface import ISecureConn
from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream
from libp2p.typing import TProtocol from libp2p.typing import TProtocol
from libp2p.utils import decode_uvarint_from_stream, encode_uvarint from libp2p.utils import (
decode_uvarint_from_stream,
encode_uvarint,
read_varint_prefixed_bytes,
)
from .constants import HeaderTags from .constants import HeaderTags
from .exceptions import StreamNotFound from .exceptions import StreamNotFound
@ -31,6 +33,7 @@ class Mplex(IMuxedConn):
stream_queue: "asyncio.Queue[int]" stream_queue: "asyncio.Queue[int]"
next_stream_id: int next_stream_id: int
# TODO: `generic_protocol_handler` should be refactored out of mplex conn.
def __init__( def __init__(
self, self,
secured_conn: ISecureConn, secured_conn: ISecureConn,
@ -114,28 +117,25 @@ class Mplex(IMuxedConn):
self.next_stream_id += 2 self.next_stream_id += 2
return next_id return next_id
# FIXME: Remove multiaddr from being passed into muxed_conn async def open_stream(self) -> IMuxedStream:
async def open_stream(
self, protocol_id: str, multi_addr: Multiaddr
) -> IMuxedStream:
""" """
creates a new muxed_stream creates a new muxed_stream
:param protocol_id: protocol_id of stream :return: a new ``MplexStream``
:param multi_addr: multi_addr that stream connects to
:return: a new muxed stream
""" """
stream_id = self._get_next_stream_id() stream_id = self._get_next_stream_id()
stream = MplexStream(stream_id, True, self) name = str(stream_id)
stream = MplexStream(name, stream_id, True, self)
self.buffers[stream_id] = asyncio.Queue() self.buffers[stream_id] = asyncio.Queue()
await self.send_message(HeaderTags.NewStream, None, stream_id) # Default stream name is the `stream_id`
await self.send_message(HeaderTags.NewStream, name.encode(), stream_id)
return stream return stream
async def accept_stream(self) -> None: async def accept_stream(self, name: str) -> None:
""" """
accepts a muxed stream opened by the other end accepts a muxed stream opened by the other end
""" """
stream_id = await self.stream_queue.get() stream_id = await self.stream_queue.get()
stream = MplexStream(stream_id, False, self) stream = MplexStream(name, stream_id, False, self)
asyncio.ensure_future(self.generic_protocol_handler(stream)) asyncio.ensure_future(self.generic_protocol_handler(stream))
async def send_message(self, flag: HeaderTags, data: bytes, stream_id: int) -> int: async def send_message(self, flag: HeaderTags, data: bytes, stream_id: int) -> int:
@ -181,11 +181,14 @@ class Mplex(IMuxedConn):
self.buffers[stream_id] = asyncio.Queue() self.buffers[stream_id] = asyncio.Queue()
await self.stream_queue.put(stream_id) await self.stream_queue.put(stream_id)
# TODO: Handle more tags, and refactor `HeaderTags`
if flag == HeaderTags.NewStream.value: if flag == HeaderTags.NewStream.value:
# new stream detected on connection # new stream detected on connection
await self.accept_stream() await self.accept_stream(message.decode())
elif flag in (
if message: HeaderTags.MessageInitiator.value,
HeaderTags.MessageReceiver.value,
):
await self.buffers[stream_id].put(message) await self.buffers[stream_id].put(message)
# Force context switch # Force context switch
@ -200,14 +203,14 @@ class Mplex(IMuxedConn):
# FIXME: No timeout is used in Go implementation. # FIXME: No timeout is used in Go implementation.
# Timeout is set to a relatively small value to alleviate wait time to exit # Timeout is set to a relatively small value to alleviate wait time to exit
# loop in handle_incoming # loop in handle_incoming
timeout = 0.1 header = await decode_uvarint_from_stream(self.secured_conn)
# TODO: Handle the case of EOF and other exceptions?
try: try:
header = await decode_uvarint_from_stream(self.secured_conn, timeout)
length = await decode_uvarint_from_stream(self.secured_conn, timeout)
message = await asyncio.wait_for( message = await asyncio.wait_for(
self.secured_conn.read(length), timeout=timeout read_varint_prefixed_bytes(self.secured_conn), timeout=5
) )
except asyncio.TimeoutError: except asyncio.TimeoutError:
# TODO: Investigate what we should do if time is out.
return None, None, None return None, None, None
flag = header & 0x07 flag = header & 0x07

View File

@ -10,6 +10,7 @@ class MplexStream(IMuxedStream):
reference: https://github.com/libp2p/go-mplex/blob/master/stream.go reference: https://github.com/libp2p/go-mplex/blob/master/stream.go
""" """
name: str
stream_id: int stream_id: int
initiator: bool initiator: bool
mplex_conn: IMuxedConn mplex_conn: IMuxedConn
@ -21,13 +22,16 @@ class MplexStream(IMuxedStream):
_buf: bytearray _buf: bytearray
def __init__(self, stream_id: int, initiator: bool, mplex_conn: IMuxedConn) -> None: def __init__(
self, name: str, stream_id: int, initiator: bool, mplex_conn: IMuxedConn
) -> None:
""" """
create new MuxedStream in muxer create new MuxedStream in muxer
:param stream_id: stream stream id :param stream_id: stream stream id
:param initiator: boolean if this is an initiator :param initiator: boolean if this is an initiator
:param mplex_conn: muxed connection of this muxed_stream :param mplex_conn: muxed connection of this muxed_stream
""" """
self.name = name
self.stream_id = stream_id self.stream_id = stream_id
self.initiator = initiator self.initiator = initiator
self.mplex_conn = mplex_conn self.mplex_conn = mplex_conn

View File

@ -1,9 +1,20 @@
import asyncio import itertools
import struct import math
from typing import Tuple
from libp2p.exceptions import ParseError
from libp2p.typing import StreamReader from libp2p.typing import StreamReader
# Unsigned LEB128(varint codec)
# Reference: https://github.com/ethereum/py-wasm/blob/master/wasm/parsers/leb128.py
LOW_MASK = 2 ** 7 - 1
HIGH_MASK = 2 ** 7
# The maximum shift width for a 64 bit integer. We shouldn't have to decode
# integers larger than this.
SHIFT_64_BIT_MAX = int(math.ceil(64 / 7)) * 7
def encode_uvarint(number: int) -> bytes: def encode_uvarint(number: int) -> bytes:
"""Pack `number` into varint bytes""" """Pack `number` into varint bytes"""
@ -19,35 +30,29 @@ def encode_uvarint(number: int) -> bytes:
return buf return buf
def decode_uvarint(buff: bytes, index: int) -> Tuple[int, int]: async def decode_uvarint_from_stream(reader: StreamReader) -> int:
shift = 0 """
result = 0 https://en.wikipedia.org/wiki/LEB128
while True: """
i = buff[index] res = 0
result |= (i & 0x7F) << shift for shift in itertools.count(0, 7):
shift += 7 if shift > SHIFT_64_BIT_MAX:
if not i & 0x80: raise ParseError("TODO: better exception msg: Integer is too large...")
byte = await reader.read(1)
try:
value = byte[0]
except IndexError:
raise ParseError(
"Unexpected end of stream while parsing LEB128 encoded integer"
)
res += (value & LOW_MASK) << shift
if not value & HIGH_MASK:
break break
index += 1 return res
return result, index + 1
async def decode_uvarint_from_stream(reader: StreamReader, timeout: float) -> int:
shift = 0
result = 0
while True:
byte = await asyncio.wait_for(reader.read(1), timeout=timeout)
i = struct.unpack(">H", b"\x00" + byte)[0]
result |= (i & 0x7F) << shift
shift += 7
if not i & 0x80:
break
return result
# Varint-prefixed read/write
def encode_varint_prefixed(msg_bytes: bytes) -> bytes: def encode_varint_prefixed(msg_bytes: bytes) -> bytes:
@ -56,7 +61,7 @@ def encode_varint_prefixed(msg_bytes: bytes) -> bytes:
async def read_varint_prefixed_bytes(reader: StreamReader) -> bytes: async def read_varint_prefixed_bytes(reader: StreamReader) -> bytes:
len_msg = await decode_uvarint_from_stream(reader, None) len_msg = await decode_uvarint_from_stream(reader)
data = await reader.read(len_msg) data = await reader.read(len_msg)
if len(data) != len_msg: if len(data) != len_msg:
raise ValueError( raise ValueError(