Add lock to RawConnection
To avoid `self.writer.drain()` is called in parallel. Reference: https://bugs.python.org/issue29930
This commit is contained in:
parent
5768daa9bf
commit
0b466ddc86
|
@ -9,9 +9,11 @@ class RawConnection(IRawConnection):
|
||||||
conn_port: str
|
conn_port: str
|
||||||
reader: asyncio.StreamReader
|
reader: asyncio.StreamReader
|
||||||
writer: asyncio.StreamWriter
|
writer: asyncio.StreamWriter
|
||||||
_next_id: int
|
|
||||||
initiator: bool
|
initiator: bool
|
||||||
|
|
||||||
|
_drain_lock: asyncio.Lock
|
||||||
|
_next_id: int
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
ip: str,
|
ip: str,
|
||||||
|
@ -24,13 +26,18 @@ class RawConnection(IRawConnection):
|
||||||
self.conn_port = port
|
self.conn_port = port
|
||||||
self.reader = reader
|
self.reader = reader
|
||||||
self.writer = writer
|
self.writer = writer
|
||||||
self._next_id = 0 if initiator else 1
|
|
||||||
self.initiator = initiator
|
self.initiator = initiator
|
||||||
|
|
||||||
|
self._drain_lock = asyncio.Lock()
|
||||||
|
self._next_id = 0 if initiator else 1
|
||||||
|
|
||||||
async def write(self, data: bytes) -> None:
|
async def write(self, data: bytes) -> None:
|
||||||
self.writer.write(data)
|
self.writer.write(data)
|
||||||
self.writer.write("\n".encode())
|
# Reference: https://github.com/ethereum/lahja/blob/93610b2eb46969ff1797e0748c7ac2595e130aef/lahja/asyncio/endpoint.py#L99-L102 # noqa: E501
|
||||||
await self.writer.drain()
|
# Use a lock to serialize drain() calls. Circumvents this bug:
|
||||||
|
# https://bugs.python.org/issue29930
|
||||||
|
async with self._drain_lock:
|
||||||
|
await self.writer.drain()
|
||||||
|
|
||||||
async def read(self) -> bytes:
|
async def read(self) -> bytes:
|
||||||
line = await self.reader.readline()
|
line = await self.reader.readline()
|
||||||
|
|
|
@ -12,12 +12,12 @@ class RawConnectionCommunicator(IMultiselectCommunicator):
|
||||||
self.conn = conn
|
self.conn = conn
|
||||||
|
|
||||||
async def write(self, msg_str: str) -> None:
|
async def write(self, msg_str: str) -> None:
|
||||||
msg_bytes = encode_delim(msg_str)
|
msg_bytes = encode_delim(msg_str.encode())
|
||||||
self.conn.writer.write(msg_bytes)
|
await self.conn.write(msg_bytes)
|
||||||
await self.conn.writer.drain()
|
|
||||||
|
|
||||||
async def read(self) -> str:
|
async def read(self) -> str:
|
||||||
return await read_delim(self.conn.reader)
|
data = await read_delim(self.conn.reader)
|
||||||
|
return data.decode()
|
||||||
|
|
||||||
|
|
||||||
class StreamCommunicator(IMultiselectCommunicator):
|
class StreamCommunicator(IMultiselectCommunicator):
|
||||||
|
@ -27,8 +27,9 @@ class StreamCommunicator(IMultiselectCommunicator):
|
||||||
self.stream = stream
|
self.stream = stream
|
||||||
|
|
||||||
async def write(self, msg_str: str) -> None:
|
async def write(self, msg_str: str) -> None:
|
||||||
msg_bytes = encode_delim(msg_str)
|
msg_bytes = encode_delim(msg_str.encode())
|
||||||
await self.stream.write(msg_bytes)
|
await self.stream.write(msg_bytes)
|
||||||
|
|
||||||
async def read(self) -> str:
|
async def read(self) -> str:
|
||||||
return await read_delim(self.stream)
|
data = await read_delim(self.stream)
|
||||||
|
return data.decode()
|
||||||
|
|
|
@ -21,8 +21,7 @@ class InsecureSession(BaseSession):
|
||||||
msg = make_exchange_message(self.local_private_key.get_public_key())
|
msg = make_exchange_message(self.local_private_key.get_public_key())
|
||||||
msg_bytes = msg.SerializeToString()
|
msg_bytes = msg.SerializeToString()
|
||||||
encoded_msg_bytes = encode_fixedint_prefixed(msg_bytes)
|
encoded_msg_bytes = encode_fixedint_prefixed(msg_bytes)
|
||||||
self.writer.write(encoded_msg_bytes)
|
await self.write(encoded_msg_bytes)
|
||||||
await self.writer.drain()
|
|
||||||
|
|
||||||
msg_bytes_other_side = await read_fixedint_prefixed(self.reader)
|
msg_bytes_other_side = await read_fixedint_prefixed(self.reader)
|
||||||
msg_other_side = plaintext_pb2.Exchange()
|
msg_other_side = plaintext_pb2.Exchange()
|
||||||
|
|
|
@ -150,8 +150,7 @@ class Mplex(IMuxedConn):
|
||||||
:param _bytes: byte array to write
|
:param _bytes: byte array to write
|
||||||
:return: length written
|
:return: length written
|
||||||
"""
|
"""
|
||||||
self.conn.writer.write(_bytes)
|
await self.conn.write(_bytes)
|
||||||
await self.conn.writer.drain()
|
|
||||||
return len(_bytes)
|
return len(_bytes)
|
||||||
|
|
||||||
async def handle_incoming(self) -> None:
|
async def handle_incoming(self) -> None:
|
||||||
|
|
|
@ -4,8 +4,6 @@ from typing import Tuple
|
||||||
|
|
||||||
from libp2p.typing import StreamReader
|
from libp2p.typing import StreamReader
|
||||||
|
|
||||||
TIMEOUT = 10
|
|
||||||
|
|
||||||
|
|
||||||
def encode_uvarint(number: int) -> bytes:
|
def encode_uvarint(number: int) -> bytes:
|
||||||
"""Pack `number` into varint bytes"""
|
"""Pack `number` into varint bytes"""
|
||||||
|
@ -57,25 +55,31 @@ def encode_varint_prefixed(msg_bytes: bytes) -> bytes:
|
||||||
return varint_len + msg_bytes
|
return varint_len + msg_bytes
|
||||||
|
|
||||||
|
|
||||||
async def read_varint_prefixed_bytes(
|
async def read_varint_prefixed_bytes(reader: StreamReader) -> bytes:
|
||||||
reader: StreamReader, timeout: int = TIMEOUT
|
len_msg = await decode_uvarint_from_stream(reader, None)
|
||||||
) -> bytes:
|
data = await reader.read(len_msg)
|
||||||
len_msg = await decode_uvarint_from_stream(reader, timeout)
|
if len(data) != len_msg:
|
||||||
return await reader.read(len_msg)
|
raise ValueError(
|
||||||
|
f"failed to read enough bytes: len_msg={len_msg}, data={data!r}"
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
# Delimited read/write, used by multistream-select.
|
# Delimited read/write, used by multistream-select.
|
||||||
# Reference: https://github.com/gogo/protobuf/blob/07eab6a8298cf32fac45cceaac59424f98421bbc/io/varint.go#L109-L126 # noqa: E501
|
# Reference: https://github.com/gogo/protobuf/blob/07eab6a8298cf32fac45cceaac59424f98421bbc/io/varint.go#L109-L126 # noqa: E501
|
||||||
|
|
||||||
|
|
||||||
def encode_delim(msg_str: str) -> bytes:
|
def encode_delim(msg: bytes) -> bytes:
|
||||||
delimited_msg = msg_str + "\n"
|
delimited_msg = msg + b"\n"
|
||||||
return encode_varint_prefixed(delimited_msg.encode())
|
return encode_varint_prefixed(delimited_msg)
|
||||||
|
|
||||||
|
|
||||||
async def read_delim(reader: StreamReader, timeout: int = TIMEOUT) -> str:
|
async def read_delim(reader: StreamReader) -> bytes:
|
||||||
msg_bytes = await read_varint_prefixed_bytes(reader, timeout)
|
msg_bytes = await read_varint_prefixed_bytes(reader)
|
||||||
return msg_bytes.decode().rstrip()
|
# TODO: Investigate if it is possible to have empty `msg_bytes`
|
||||||
|
if len(msg_bytes) != 0 and msg_bytes[-1:] != b"\n":
|
||||||
|
raise ValueError(f'msg_bytes is not delimited by b"\\n": msg_bytes={msg_bytes}')
|
||||||
|
return msg_bytes[:-1]
|
||||||
|
|
||||||
|
|
||||||
SIZE_LEN_BYTES = 4
|
SIZE_LEN_BYTES = 4
|
||||||
|
|
Loading…
Reference in New Issue
Block a user