diff --git a/libp2p/exceptions.py b/libp2p/exceptions.py index 0ea0078..ce2dc34 100644 --- a/libp2p/exceptions.py +++ b/libp2p/exceptions.py @@ -6,3 +6,7 @@ class ValidationError(BaseLibp2pError): """ Raised when something does not pass a validation check. """ + + +class ParseError(BaseLibp2pError): + pass diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index 4a0ba41..2dd2618 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -6,7 +6,11 @@ from libp2p.peer.id import ID from libp2p.security.secure_conn_interface import ISecureConn from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream from libp2p.typing import TProtocol -from libp2p.utils import decode_uvarint_from_stream, encode_uvarint +from libp2p.utils import ( + decode_uvarint_from_stream, + encode_uvarint, + read_varint_prefixed_bytes, +) from .constants import HeaderTags from .exceptions import StreamNotFound @@ -119,11 +123,11 @@ class Mplex(IMuxedConn): :return: a new ``MplexStream`` """ stream_id = self._get_next_stream_id() - name = str(stream_id).encode() + name = str(stream_id) stream = MplexStream(name, stream_id, True, self) self.buffers[stream_id] = asyncio.Queue() # Default stream name is the `stream_id` - await self.send_message(HeaderTags.NewStream, name, stream_id) + await self.send_message(HeaderTags.NewStream, name.encode(), stream_id) return stream async def accept_stream(self, name: str) -> None: @@ -180,7 +184,7 @@ class Mplex(IMuxedConn): # TODO: Handle more tags, and refactor `HeaderTags` if flag == HeaderTags.NewStream.value: # new stream detected on connection - await self.accept_stream(message) + await self.accept_stream(message.decode()) elif flag in ( HeaderTags.MessageInitiator.value, HeaderTags.MessageReceiver.value, @@ -199,14 +203,14 @@ class Mplex(IMuxedConn): # FIXME: No timeout is used in Go implementation. # Timeout is set to a relatively small value to alleviate wait time to exit # loop in handle_incoming - timeout = 0.1 + header = await decode_uvarint_from_stream(self.secured_conn) + # TODO: Handle the case of EOF and other exceptions? try: - header = await decode_uvarint_from_stream(self.secured_conn, timeout) - length = await decode_uvarint_from_stream(self.secured_conn, timeout) message = await asyncio.wait_for( - self.secured_conn.read(length), timeout=timeout + read_varint_prefixed_bytes(self.secured_conn), timeout=5 ) except asyncio.TimeoutError: + # TODO: Investigate what we should do if time is out. return None, None, None flag = header & 0x07 diff --git a/libp2p/utils.py b/libp2p/utils.py index 9a1f0cb..a309064 100644 --- a/libp2p/utils.py +++ b/libp2p/utils.py @@ -1,9 +1,20 @@ -import asyncio -import struct -from typing import Tuple +import itertools +import math +from libp2p.exceptions import ParseError from libp2p.typing import StreamReader +# Unsigned LEB128(varint codec) +# Reference: https://github.com/ethereum/py-wasm/blob/master/wasm/parsers/leb128.py + +LOW_MASK = 2 ** 7 - 1 +HIGH_MASK = 2 ** 7 + + +# The maximum shift width for a 64 bit integer. We shouldn't have to decode +# integers larger than this. +SHIFT_64_BIT_MAX = int(math.ceil(64 / 7)) * 7 + def encode_uvarint(number: int) -> bytes: """Pack `number` into varint bytes""" @@ -19,35 +30,29 @@ def encode_uvarint(number: int) -> bytes: return buf -def decode_uvarint(buff: bytes, index: int) -> Tuple[int, int]: - shift = 0 - result = 0 - while True: - i = buff[index] - result |= (i & 0x7F) << shift - shift += 7 - if not i & 0x80: +async def decode_uvarint_from_stream(reader: StreamReader) -> int: + """ + https://en.wikipedia.org/wiki/LEB128 + """ + res = 0 + for shift in itertools.count(0, 7): + if shift > SHIFT_64_BIT_MAX: + raise ParseError("TODO: better exception msg: Integer is too large...") + + byte = await reader.read(1) + + try: + value = byte[0] + except IndexError: + raise ParseError( + "Unexpected end of stream while parsing LEB128 encoded integer" + ) + + res += (value & LOW_MASK) << shift + + if not value & HIGH_MASK: break - index += 1 - - return result, index + 1 - - -async def decode_uvarint_from_stream(reader: StreamReader, timeout: float) -> int: - shift = 0 - result = 0 - while True: - byte = await asyncio.wait_for(reader.read(1), timeout=timeout) - i = struct.unpack(">H", b"\x00" + byte)[0] - result |= (i & 0x7F) << shift - shift += 7 - if not i & 0x80: - break - - return result - - -# Varint-prefixed read/write + return res def encode_varint_prefixed(msg_bytes: bytes) -> bytes: @@ -56,7 +61,7 @@ def encode_varint_prefixed(msg_bytes: bytes) -> bytes: async def read_varint_prefixed_bytes(reader: StreamReader) -> bytes: - len_msg = await decode_uvarint_from_stream(reader, None) + len_msg = await decode_uvarint_from_stream(reader) data = await reader.read(len_msg) if len(data) != len_msg: raise ValueError(