diff --git a/libp2p/exceptions.py b/libp2p/exceptions.py index 0ea0078..ce2dc34 100644 --- a/libp2p/exceptions.py +++ b/libp2p/exceptions.py @@ -6,3 +6,7 @@ class ValidationError(BaseLibp2pError): """ Raised when something does not pass a validation check. """ + + +class ParseError(BaseLibp2pError): + pass diff --git a/libp2p/network/swarm.py b/libp2p/network/swarm.py index 07b5e14..52df727 100644 --- a/libp2p/network/swarm.py +++ b/libp2p/network/swarm.py @@ -156,15 +156,10 @@ class Swarm(INetwork): if not addrs: raise SwarmException("No known addresses to peer") - multiaddr = addrs[0] - muxed_conn = await self.dial_peer(peer_id) - # Use muxed conn to open stream, which returns - # a muxed stream - # TODO: Remove protocol id from being passed into muxed_conn - # FIXME: Remove multiaddr from being passed into muxed_conn - muxed_stream = await muxed_conn.open_stream(protocol_ids[0], multiaddr) + # Use muxed conn to open stream, which returns a muxed stream + muxed_stream = await muxed_conn.open_stream() # Perform protocol muxing to determine protocol to use selected_protocol = await self.multiselect_client.select_one_of( diff --git a/libp2p/stream_muxer/abc.py b/libp2p/stream_muxer/abc.py index 245a739..f16b805 100644 --- a/libp2p/stream_muxer/abc.py +++ b/libp2p/stream_muxer/abc.py @@ -1,8 +1,6 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Optional -from multiaddr import Multiaddr - from libp2p.peer.id import ID from libp2p.security.secure_conn_interface import ISecureConn from libp2p.stream_muxer.mplex.constants import HeaderTags @@ -66,20 +64,15 @@ class IMuxedConn(ABC): Read a message from `stream_id`'s buffer, non-blockingly. """ - # FIXME: Remove multiaddr from being passed into muxed_conn @abstractmethod - async def open_stream( - self, protocol_id: str, multi_addr: Multiaddr - ) -> "IMuxedStream": + async def open_stream(self) -> "IMuxedStream": """ creates a new muxed_stream - :param protocol_id: protocol_id of stream - :param multi_addr: multi_addr that stream connects to - :return: a new stream + :return: a new ``IMuxedStream`` stream """ @abstractmethod - async def accept_stream(self) -> None: + async def accept_stream(self, name: str) -> None: """ accepts a muxed stream opened by the other end """ diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index 8c78431..2dd2618 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -1,14 +1,16 @@ import asyncio from typing import Dict, Optional, Tuple -from multiaddr import Multiaddr - from libp2p.network.typing import GenericProtocolHandlerFn from libp2p.peer.id import ID from libp2p.security.secure_conn_interface import ISecureConn from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream from libp2p.typing import TProtocol -from libp2p.utils import decode_uvarint_from_stream, encode_uvarint +from libp2p.utils import ( + decode_uvarint_from_stream, + encode_uvarint, + read_varint_prefixed_bytes, +) from .constants import HeaderTags from .exceptions import StreamNotFound @@ -31,6 +33,7 @@ class Mplex(IMuxedConn): stream_queue: "asyncio.Queue[int]" next_stream_id: int + # TODO: `generic_protocol_handler` should be refactored out of mplex conn. def __init__( self, secured_conn: ISecureConn, @@ -114,28 +117,25 @@ class Mplex(IMuxedConn): self.next_stream_id += 2 return next_id - # FIXME: Remove multiaddr from being passed into muxed_conn - async def open_stream( - self, protocol_id: str, multi_addr: Multiaddr - ) -> IMuxedStream: + async def open_stream(self) -> IMuxedStream: """ creates a new muxed_stream - :param protocol_id: protocol_id of stream - :param multi_addr: multi_addr that stream connects to - :return: a new muxed stream + :return: a new ``MplexStream`` """ stream_id = self._get_next_stream_id() - stream = MplexStream(stream_id, True, self) + name = str(stream_id) + stream = MplexStream(name, stream_id, True, self) self.buffers[stream_id] = asyncio.Queue() - await self.send_message(HeaderTags.NewStream, None, stream_id) + # Default stream name is the `stream_id` + await self.send_message(HeaderTags.NewStream, name.encode(), stream_id) return stream - async def accept_stream(self) -> None: + async def accept_stream(self, name: str) -> None: """ accepts a muxed stream opened by the other end """ stream_id = await self.stream_queue.get() - stream = MplexStream(stream_id, False, self) + stream = MplexStream(name, stream_id, False, self) asyncio.ensure_future(self.generic_protocol_handler(stream)) async def send_message(self, flag: HeaderTags, data: bytes, stream_id: int) -> int: @@ -181,11 +181,14 @@ class Mplex(IMuxedConn): self.buffers[stream_id] = asyncio.Queue() await self.stream_queue.put(stream_id) + # TODO: Handle more tags, and refactor `HeaderTags` if flag == HeaderTags.NewStream.value: # new stream detected on connection - await self.accept_stream() - - if message: + await self.accept_stream(message.decode()) + elif flag in ( + HeaderTags.MessageInitiator.value, + HeaderTags.MessageReceiver.value, + ): await self.buffers[stream_id].put(message) # Force context switch @@ -200,14 +203,14 @@ class Mplex(IMuxedConn): # FIXME: No timeout is used in Go implementation. # Timeout is set to a relatively small value to alleviate wait time to exit # loop in handle_incoming - timeout = 0.1 + header = await decode_uvarint_from_stream(self.secured_conn) + # TODO: Handle the case of EOF and other exceptions? try: - header = await decode_uvarint_from_stream(self.secured_conn, timeout) - length = await decode_uvarint_from_stream(self.secured_conn, timeout) message = await asyncio.wait_for( - self.secured_conn.read(length), timeout=timeout + read_varint_prefixed_bytes(self.secured_conn), timeout=5 ) except asyncio.TimeoutError: + # TODO: Investigate what we should do if time is out. return None, None, None flag = header & 0x07 diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index 2ec23f1..e90a1d4 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -10,6 +10,7 @@ class MplexStream(IMuxedStream): reference: https://github.com/libp2p/go-mplex/blob/master/stream.go """ + name: str stream_id: int initiator: bool mplex_conn: IMuxedConn @@ -21,13 +22,16 @@ class MplexStream(IMuxedStream): _buf: bytearray - def __init__(self, stream_id: int, initiator: bool, mplex_conn: IMuxedConn) -> None: + def __init__( + self, name: str, stream_id: int, initiator: bool, mplex_conn: IMuxedConn + ) -> None: """ create new MuxedStream in muxer :param stream_id: stream stream id :param initiator: boolean if this is an initiator :param mplex_conn: muxed connection of this muxed_stream """ + self.name = name self.stream_id = stream_id self.initiator = initiator self.mplex_conn = mplex_conn diff --git a/libp2p/utils.py b/libp2p/utils.py index 9a1f0cb..a309064 100644 --- a/libp2p/utils.py +++ b/libp2p/utils.py @@ -1,9 +1,20 @@ -import asyncio -import struct -from typing import Tuple +import itertools +import math +from libp2p.exceptions import ParseError from libp2p.typing import StreamReader +# Unsigned LEB128(varint codec) +# Reference: https://github.com/ethereum/py-wasm/blob/master/wasm/parsers/leb128.py + +LOW_MASK = 2 ** 7 - 1 +HIGH_MASK = 2 ** 7 + + +# The maximum shift width for a 64 bit integer. We shouldn't have to decode +# integers larger than this. +SHIFT_64_BIT_MAX = int(math.ceil(64 / 7)) * 7 + def encode_uvarint(number: int) -> bytes: """Pack `number` into varint bytes""" @@ -19,35 +30,29 @@ def encode_uvarint(number: int) -> bytes: return buf -def decode_uvarint(buff: bytes, index: int) -> Tuple[int, int]: - shift = 0 - result = 0 - while True: - i = buff[index] - result |= (i & 0x7F) << shift - shift += 7 - if not i & 0x80: +async def decode_uvarint_from_stream(reader: StreamReader) -> int: + """ + https://en.wikipedia.org/wiki/LEB128 + """ + res = 0 + for shift in itertools.count(0, 7): + if shift > SHIFT_64_BIT_MAX: + raise ParseError("TODO: better exception msg: Integer is too large...") + + byte = await reader.read(1) + + try: + value = byte[0] + except IndexError: + raise ParseError( + "Unexpected end of stream while parsing LEB128 encoded integer" + ) + + res += (value & LOW_MASK) << shift + + if not value & HIGH_MASK: break - index += 1 - - return result, index + 1 - - -async def decode_uvarint_from_stream(reader: StreamReader, timeout: float) -> int: - shift = 0 - result = 0 - while True: - byte = await asyncio.wait_for(reader.read(1), timeout=timeout) - i = struct.unpack(">H", b"\x00" + byte)[0] - result |= (i & 0x7F) << shift - shift += 7 - if not i & 0x80: - break - - return result - - -# Varint-prefixed read/write + return res def encode_varint_prefixed(msg_bytes: bytes) -> bytes: @@ -56,7 +61,7 @@ def encode_varint_prefixed(msg_bytes: bytes) -> bytes: async def read_varint_prefixed_bytes(reader: StreamReader) -> bytes: - len_msg = await decode_uvarint_from_stream(reader, None) + len_msg = await decode_uvarint_from_stream(reader) data = await reader.read(len_msg) if len(data) != len_msg: raise ValueError(