Add lock to RawConnection

To avoid `self.writer.drain()` is called in parallel.
Reference: https://bugs.python.org/issue29930
This commit is contained in:
mhchia 2019-08-20 17:09:38 +08:00
parent 5768daa9bf
commit 0b466ddc86
No known key found for this signature in database
GPG Key ID: 389EFBEA1362589A
5 changed files with 37 additions and 27 deletions

View File

@ -9,9 +9,11 @@ class RawConnection(IRawConnection):
conn_port: str conn_port: str
reader: asyncio.StreamReader reader: asyncio.StreamReader
writer: asyncio.StreamWriter writer: asyncio.StreamWriter
_next_id: int
initiator: bool initiator: bool
_drain_lock: asyncio.Lock
_next_id: int
def __init__( def __init__(
self, self,
ip: str, ip: str,
@ -24,13 +26,18 @@ class RawConnection(IRawConnection):
self.conn_port = port self.conn_port = port
self.reader = reader self.reader = reader
self.writer = writer self.writer = writer
self._next_id = 0 if initiator else 1
self.initiator = initiator self.initiator = initiator
self._drain_lock = asyncio.Lock()
self._next_id = 0 if initiator else 1
async def write(self, data: bytes) -> None: async def write(self, data: bytes) -> None:
self.writer.write(data) self.writer.write(data)
self.writer.write("\n".encode()) # Reference: https://github.com/ethereum/lahja/blob/93610b2eb46969ff1797e0748c7ac2595e130aef/lahja/asyncio/endpoint.py#L99-L102 # noqa: E501
await self.writer.drain() # Use a lock to serialize drain() calls. Circumvents this bug:
# https://bugs.python.org/issue29930
async with self._drain_lock:
await self.writer.drain()
async def read(self) -> bytes: async def read(self) -> bytes:
line = await self.reader.readline() line = await self.reader.readline()

View File

@ -12,12 +12,12 @@ class RawConnectionCommunicator(IMultiselectCommunicator):
self.conn = conn self.conn = conn
async def write(self, msg_str: str) -> None: async def write(self, msg_str: str) -> None:
msg_bytes = encode_delim(msg_str) msg_bytes = encode_delim(msg_str.encode())
self.conn.writer.write(msg_bytes) await self.conn.write(msg_bytes)
await self.conn.writer.drain()
async def read(self) -> str: async def read(self) -> str:
return await read_delim(self.conn.reader) data = await read_delim(self.conn.reader)
return data.decode()
class StreamCommunicator(IMultiselectCommunicator): class StreamCommunicator(IMultiselectCommunicator):
@ -27,8 +27,9 @@ class StreamCommunicator(IMultiselectCommunicator):
self.stream = stream self.stream = stream
async def write(self, msg_str: str) -> None: async def write(self, msg_str: str) -> None:
msg_bytes = encode_delim(msg_str) msg_bytes = encode_delim(msg_str.encode())
await self.stream.write(msg_bytes) await self.stream.write(msg_bytes)
async def read(self) -> str: async def read(self) -> str:
return await read_delim(self.stream) data = await read_delim(self.stream)
return data.decode()

View File

@ -21,8 +21,7 @@ class InsecureSession(BaseSession):
msg = make_exchange_message(self.local_private_key.get_public_key()) msg = make_exchange_message(self.local_private_key.get_public_key())
msg_bytes = msg.SerializeToString() msg_bytes = msg.SerializeToString()
encoded_msg_bytes = encode_fixedint_prefixed(msg_bytes) encoded_msg_bytes = encode_fixedint_prefixed(msg_bytes)
self.writer.write(encoded_msg_bytes) await self.write(encoded_msg_bytes)
await self.writer.drain()
msg_bytes_other_side = await read_fixedint_prefixed(self.reader) msg_bytes_other_side = await read_fixedint_prefixed(self.reader)
msg_other_side = plaintext_pb2.Exchange() msg_other_side = plaintext_pb2.Exchange()

View File

@ -150,8 +150,7 @@ class Mplex(IMuxedConn):
:param _bytes: byte array to write :param _bytes: byte array to write
:return: length written :return: length written
""" """
self.conn.writer.write(_bytes) await self.conn.write(_bytes)
await self.conn.writer.drain()
return len(_bytes) return len(_bytes)
async def handle_incoming(self) -> None: async def handle_incoming(self) -> None:

View File

@ -4,8 +4,6 @@ from typing import Tuple
from libp2p.typing import StreamReader from libp2p.typing import StreamReader
TIMEOUT = 10
def encode_uvarint(number: int) -> bytes: def encode_uvarint(number: int) -> bytes:
"""Pack `number` into varint bytes""" """Pack `number` into varint bytes"""
@ -57,25 +55,31 @@ def encode_varint_prefixed(msg_bytes: bytes) -> bytes:
return varint_len + msg_bytes return varint_len + msg_bytes
async def read_varint_prefixed_bytes( async def read_varint_prefixed_bytes(reader: StreamReader) -> bytes:
reader: StreamReader, timeout: int = TIMEOUT len_msg = await decode_uvarint_from_stream(reader, None)
) -> bytes: data = await reader.read(len_msg)
len_msg = await decode_uvarint_from_stream(reader, timeout) if len(data) != len_msg:
return await reader.read(len_msg) raise ValueError(
f"failed to read enough bytes: len_msg={len_msg}, data={data!r}"
)
return data
# Delimited read/write, used by multistream-select. # Delimited read/write, used by multistream-select.
# Reference: https://github.com/gogo/protobuf/blob/07eab6a8298cf32fac45cceaac59424f98421bbc/io/varint.go#L109-L126 # noqa: E501 # Reference: https://github.com/gogo/protobuf/blob/07eab6a8298cf32fac45cceaac59424f98421bbc/io/varint.go#L109-L126 # noqa: E501
def encode_delim(msg_str: str) -> bytes: def encode_delim(msg: bytes) -> bytes:
delimited_msg = msg_str + "\n" delimited_msg = msg + b"\n"
return encode_varint_prefixed(delimited_msg.encode()) return encode_varint_prefixed(delimited_msg)
async def read_delim(reader: StreamReader, timeout: int = TIMEOUT) -> str: async def read_delim(reader: StreamReader) -> bytes:
msg_bytes = await read_varint_prefixed_bytes(reader, timeout) msg_bytes = await read_varint_prefixed_bytes(reader)
return msg_bytes.decode().rstrip() # TODO: Investigate if it is possible to have empty `msg_bytes`
if len(msg_bytes) != 0 and msg_bytes[-1:] != b"\n":
raise ValueError(f'msg_bytes is not delimited by b"\\n": msg_bytes={msg_bytes}')
return msg_bytes[:-1]
SIZE_LEN_BYTES = 4 SIZE_LEN_BYTES = 4