From 0b466ddc86762b08448c8ab7329e10613e8f4772 Mon Sep 17 00:00:00 2001 From: mhchia Date: Tue, 20 Aug 2019 17:09:38 +0800 Subject: [PATCH] Add lock to `RawConnection` To avoid `self.writer.drain()` is called in parallel. Reference: https://bugs.python.org/issue29930 --- libp2p/network/connection/raw_connection.py | 15 +++++++--- .../multiselect_communicator.py | 13 ++++---- libp2p/security/insecure/transport.py | 3 +- libp2p/stream_muxer/mplex/mplex.py | 3 +- libp2p/utils.py | 30 +++++++++++-------- 5 files changed, 37 insertions(+), 27 deletions(-) diff --git a/libp2p/network/connection/raw_connection.py b/libp2p/network/connection/raw_connection.py index 3d12f0d..3277901 100644 --- a/libp2p/network/connection/raw_connection.py +++ b/libp2p/network/connection/raw_connection.py @@ -9,9 +9,11 @@ class RawConnection(IRawConnection): conn_port: str reader: asyncio.StreamReader writer: asyncio.StreamWriter - _next_id: int initiator: bool + _drain_lock: asyncio.Lock + _next_id: int + def __init__( self, ip: str, @@ -24,13 +26,18 @@ class RawConnection(IRawConnection): self.conn_port = port self.reader = reader self.writer = writer - self._next_id = 0 if initiator else 1 self.initiator = initiator + self._drain_lock = asyncio.Lock() + self._next_id = 0 if initiator else 1 + async def write(self, data: bytes) -> None: self.writer.write(data) - self.writer.write("\n".encode()) - await self.writer.drain() + # Reference: https://github.com/ethereum/lahja/blob/93610b2eb46969ff1797e0748c7ac2595e130aef/lahja/asyncio/endpoint.py#L99-L102 # noqa: E501 + # Use a lock to serialize drain() calls. Circumvents this bug: + # https://bugs.python.org/issue29930 + async with self._drain_lock: + await self.writer.drain() async def read(self) -> bytes: line = await self.reader.readline() diff --git a/libp2p/protocol_muxer/multiselect_communicator.py b/libp2p/protocol_muxer/multiselect_communicator.py index ebfcc23..e01e9cc 100644 --- a/libp2p/protocol_muxer/multiselect_communicator.py +++ b/libp2p/protocol_muxer/multiselect_communicator.py @@ -12,12 +12,12 @@ class RawConnectionCommunicator(IMultiselectCommunicator): self.conn = conn async def write(self, msg_str: str) -> None: - msg_bytes = encode_delim(msg_str) - self.conn.writer.write(msg_bytes) - await self.conn.writer.drain() + msg_bytes = encode_delim(msg_str.encode()) + await self.conn.write(msg_bytes) async def read(self) -> str: - return await read_delim(self.conn.reader) + data = await read_delim(self.conn.reader) + return data.decode() class StreamCommunicator(IMultiselectCommunicator): @@ -27,8 +27,9 @@ class StreamCommunicator(IMultiselectCommunicator): self.stream = stream async def write(self, msg_str: str) -> None: - msg_bytes = encode_delim(msg_str) + msg_bytes = encode_delim(msg_str.encode()) await self.stream.write(msg_bytes) async def read(self) -> str: - return await read_delim(self.stream) + data = await read_delim(self.stream) + return data.decode() diff --git a/libp2p/security/insecure/transport.py b/libp2p/security/insecure/transport.py index 337d20f..fa4a1a8 100644 --- a/libp2p/security/insecure/transport.py +++ b/libp2p/security/insecure/transport.py @@ -21,8 +21,7 @@ class InsecureSession(BaseSession): msg = make_exchange_message(self.local_private_key.get_public_key()) msg_bytes = msg.SerializeToString() encoded_msg_bytes = encode_fixedint_prefixed(msg_bytes) - self.writer.write(encoded_msg_bytes) - await self.writer.drain() + await self.write(encoded_msg_bytes) msg_bytes_other_side = await read_fixedint_prefixed(self.reader) msg_other_side = plaintext_pb2.Exchange() diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index 16d1019..765dd56 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -150,8 +150,7 @@ class Mplex(IMuxedConn): :param _bytes: byte array to write :return: length written """ - self.conn.writer.write(_bytes) - await self.conn.writer.drain() + await self.conn.write(_bytes) return len(_bytes) async def handle_incoming(self) -> None: diff --git a/libp2p/utils.py b/libp2p/utils.py index 5fbc8ac..9a1f0cb 100644 --- a/libp2p/utils.py +++ b/libp2p/utils.py @@ -4,8 +4,6 @@ from typing import Tuple from libp2p.typing import StreamReader -TIMEOUT = 10 - def encode_uvarint(number: int) -> bytes: """Pack `number` into varint bytes""" @@ -57,25 +55,31 @@ def encode_varint_prefixed(msg_bytes: bytes) -> bytes: return varint_len + msg_bytes -async def read_varint_prefixed_bytes( - reader: StreamReader, timeout: int = TIMEOUT -) -> bytes: - len_msg = await decode_uvarint_from_stream(reader, timeout) - return await reader.read(len_msg) +async def read_varint_prefixed_bytes(reader: StreamReader) -> bytes: + len_msg = await decode_uvarint_from_stream(reader, None) + data = await reader.read(len_msg) + if len(data) != len_msg: + raise ValueError( + f"failed to read enough bytes: len_msg={len_msg}, data={data!r}" + ) + return data # Delimited read/write, used by multistream-select. # Reference: https://github.com/gogo/protobuf/blob/07eab6a8298cf32fac45cceaac59424f98421bbc/io/varint.go#L109-L126 # noqa: E501 -def encode_delim(msg_str: str) -> bytes: - delimited_msg = msg_str + "\n" - return encode_varint_prefixed(delimited_msg.encode()) +def encode_delim(msg: bytes) -> bytes: + delimited_msg = msg + b"\n" + return encode_varint_prefixed(delimited_msg) -async def read_delim(reader: StreamReader, timeout: int = TIMEOUT) -> str: - msg_bytes = await read_varint_prefixed_bytes(reader, timeout) - return msg_bytes.decode().rstrip() +async def read_delim(reader: StreamReader) -> bytes: + msg_bytes = await read_varint_prefixed_bytes(reader) + # TODO: Investigate if it is possible to have empty `msg_bytes` + if len(msg_bytes) != 0 and msg_bytes[-1:] != b"\n": + raise ValueError(f'msg_bytes is not delimited by b"\\n": msg_bytes={msg_bytes}') + return msg_bytes[:-1] SIZE_LEN_BYTES = 4