Add lock to RawConnection
To avoid `self.writer.drain()` is called in parallel. Reference: https://bugs.python.org/issue29930
This commit is contained in:
parent
5768daa9bf
commit
0b466ddc86
|
@ -9,9 +9,11 @@ class RawConnection(IRawConnection):
|
|||
conn_port: str
|
||||
reader: asyncio.StreamReader
|
||||
writer: asyncio.StreamWriter
|
||||
_next_id: int
|
||||
initiator: bool
|
||||
|
||||
_drain_lock: asyncio.Lock
|
||||
_next_id: int
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ip: str,
|
||||
|
@ -24,12 +26,17 @@ class RawConnection(IRawConnection):
|
|||
self.conn_port = port
|
||||
self.reader = reader
|
||||
self.writer = writer
|
||||
self._next_id = 0 if initiator else 1
|
||||
self.initiator = initiator
|
||||
|
||||
self._drain_lock = asyncio.Lock()
|
||||
self._next_id = 0 if initiator else 1
|
||||
|
||||
async def write(self, data: bytes) -> None:
|
||||
self.writer.write(data)
|
||||
self.writer.write("\n".encode())
|
||||
# Reference: https://github.com/ethereum/lahja/blob/93610b2eb46969ff1797e0748c7ac2595e130aef/lahja/asyncio/endpoint.py#L99-L102 # noqa: E501
|
||||
# Use a lock to serialize drain() calls. Circumvents this bug:
|
||||
# https://bugs.python.org/issue29930
|
||||
async with self._drain_lock:
|
||||
await self.writer.drain()
|
||||
|
||||
async def read(self) -> bytes:
|
||||
|
|
|
@ -12,12 +12,12 @@ class RawConnectionCommunicator(IMultiselectCommunicator):
|
|||
self.conn = conn
|
||||
|
||||
async def write(self, msg_str: str) -> None:
|
||||
msg_bytes = encode_delim(msg_str)
|
||||
self.conn.writer.write(msg_bytes)
|
||||
await self.conn.writer.drain()
|
||||
msg_bytes = encode_delim(msg_str.encode())
|
||||
await self.conn.write(msg_bytes)
|
||||
|
||||
async def read(self) -> str:
|
||||
return await read_delim(self.conn.reader)
|
||||
data = await read_delim(self.conn.reader)
|
||||
return data.decode()
|
||||
|
||||
|
||||
class StreamCommunicator(IMultiselectCommunicator):
|
||||
|
@ -27,8 +27,9 @@ class StreamCommunicator(IMultiselectCommunicator):
|
|||
self.stream = stream
|
||||
|
||||
async def write(self, msg_str: str) -> None:
|
||||
msg_bytes = encode_delim(msg_str)
|
||||
msg_bytes = encode_delim(msg_str.encode())
|
||||
await self.stream.write(msg_bytes)
|
||||
|
||||
async def read(self) -> str:
|
||||
return await read_delim(self.stream)
|
||||
data = await read_delim(self.stream)
|
||||
return data.decode()
|
||||
|
|
|
@ -21,8 +21,7 @@ class InsecureSession(BaseSession):
|
|||
msg = make_exchange_message(self.local_private_key.get_public_key())
|
||||
msg_bytes = msg.SerializeToString()
|
||||
encoded_msg_bytes = encode_fixedint_prefixed(msg_bytes)
|
||||
self.writer.write(encoded_msg_bytes)
|
||||
await self.writer.drain()
|
||||
await self.write(encoded_msg_bytes)
|
||||
|
||||
msg_bytes_other_side = await read_fixedint_prefixed(self.reader)
|
||||
msg_other_side = plaintext_pb2.Exchange()
|
||||
|
|
|
@ -150,8 +150,7 @@ class Mplex(IMuxedConn):
|
|||
:param _bytes: byte array to write
|
||||
:return: length written
|
||||
"""
|
||||
self.conn.writer.write(_bytes)
|
||||
await self.conn.writer.drain()
|
||||
await self.conn.write(_bytes)
|
||||
return len(_bytes)
|
||||
|
||||
async def handle_incoming(self) -> None:
|
||||
|
|
|
@ -4,8 +4,6 @@ from typing import Tuple
|
|||
|
||||
from libp2p.typing import StreamReader
|
||||
|
||||
TIMEOUT = 10
|
||||
|
||||
|
||||
def encode_uvarint(number: int) -> bytes:
|
||||
"""Pack `number` into varint bytes"""
|
||||
|
@ -57,25 +55,31 @@ def encode_varint_prefixed(msg_bytes: bytes) -> bytes:
|
|||
return varint_len + msg_bytes
|
||||
|
||||
|
||||
async def read_varint_prefixed_bytes(
|
||||
reader: StreamReader, timeout: int = TIMEOUT
|
||||
) -> bytes:
|
||||
len_msg = await decode_uvarint_from_stream(reader, timeout)
|
||||
return await reader.read(len_msg)
|
||||
async def read_varint_prefixed_bytes(reader: StreamReader) -> bytes:
|
||||
len_msg = await decode_uvarint_from_stream(reader, None)
|
||||
data = await reader.read(len_msg)
|
||||
if len(data) != len_msg:
|
||||
raise ValueError(
|
||||
f"failed to read enough bytes: len_msg={len_msg}, data={data!r}"
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
# Delimited read/write, used by multistream-select.
|
||||
# Reference: https://github.com/gogo/protobuf/blob/07eab6a8298cf32fac45cceaac59424f98421bbc/io/varint.go#L109-L126 # noqa: E501
|
||||
|
||||
|
||||
def encode_delim(msg_str: str) -> bytes:
|
||||
delimited_msg = msg_str + "\n"
|
||||
return encode_varint_prefixed(delimited_msg.encode())
|
||||
def encode_delim(msg: bytes) -> bytes:
|
||||
delimited_msg = msg + b"\n"
|
||||
return encode_varint_prefixed(delimited_msg)
|
||||
|
||||
|
||||
async def read_delim(reader: StreamReader, timeout: int = TIMEOUT) -> str:
|
||||
msg_bytes = await read_varint_prefixed_bytes(reader, timeout)
|
||||
return msg_bytes.decode().rstrip()
|
||||
async def read_delim(reader: StreamReader) -> bytes:
|
||||
msg_bytes = await read_varint_prefixed_bytes(reader)
|
||||
# TODO: Investigate if it is possible to have empty `msg_bytes`
|
||||
if len(msg_bytes) != 0 and msg_bytes[-1:] != b"\n":
|
||||
raise ValueError(f'msg_bytes is not delimited by b"\\n": msg_bytes={msg_bytes}')
|
||||
return msg_bytes[:-1]
|
||||
|
||||
|
||||
SIZE_LEN_BYTES = 4
|
||||
|
|
Loading…
Reference in New Issue
Block a user