Change default value of read()

From `n = -1` to `n = None`, to comply with trio API
This commit is contained in:
mhchia 2020-01-26 23:03:38 +08:00
parent 6e01a7da31
commit 5b4b65faa8
No known key found for this signature in database
GPG Key ID: 389EFBEA1362589A
9 changed files with 18 additions and 29 deletions

View File

@ -8,7 +8,7 @@ class Closer(ABC):
class Reader(ABC): class Reader(ABC):
@abstractmethod @abstractmethod
async def read(self, n: int = -1) -> bytes: async def read(self, n: int = None) -> bytes:
... ...

View File

@ -54,7 +54,7 @@ class MsgIOReader(ReadCloser):
self.read_closer = read_closer self.read_closer = read_closer
self.next_length = None self.next_length = None
async def read(self, n: int = -1) -> bytes: async def read(self, n: int = None) -> bytes:
return await self.read_msg() return await self.read_msg()
async def read_msg(self) -> bytes: async def read_msg(self) -> bytes:

View File

@ -26,22 +26,13 @@ class TrioTCPStream(ReadWriteCloser):
await self.stream.send_all(data) await self.stream.send_all(data)
except (trio.ClosedResourceError, trio.BrokenResourceError) as error: except (trio.ClosedResourceError, trio.BrokenResourceError) as error:
raise IOException from error raise IOException from error
except trio.BusyResourceError as error:
# This should never happen, since we already access streams with read/write locks.
raise Exception(
"this should never happen "
"since we already access streams with read/write locks."
) from error
async def read(self, n: int = -1) -> bytes: async def read(self, n: int = None) -> bytes:
async with self.read_lock: async with self.read_lock:
if n == 0: if n is not None and n == 0:
# Checkpoint
await trio.hazmat.checkpoint()
return b"" return b""
max_bytes = n if n != -1 else None
try: try:
return await self.stream.receive_some(max_bytes) return await self.stream.receive_some(n)
except (trio.ClosedResourceError, trio.BrokenResourceError) as error: except (trio.ClosedResourceError, trio.BrokenResourceError) as error:
raise IOException from error raise IOException from error
except trio.BusyResourceError as error: except trio.BusyResourceError as error:

View File

@ -20,7 +20,7 @@ class RawConnection(IRawConnection):
except IOException as error: except IOException as error:
raise RawConnError(error) raise RawConnError(error)
async def read(self, n: int = -1) -> bytes: async def read(self, n: int = None) -> bytes:
""" """
Read up to ``n`` bytes from the underlying stream. This call is Read up to ``n`` bytes from the underlying stream. This call is
delegated directly to the underlying ``self.reader``. delegated directly to the underlying ``self.reader``.

View File

@ -37,7 +37,7 @@ class NetStream(INetStream):
""" """
self.protocol_id = protocol_id self.protocol_id = protocol_id
async def read(self, n: int = -1) -> bytes: async def read(self, n: int = None) -> bytes:
""" """
reads from stream. reads from stream.

View File

@ -39,7 +39,7 @@ class InsecureSession(BaseSession):
await self.conn.write(data) await self.conn.write(data)
return len(data) return len(data)
async def read(self, n: int = -1) -> bytes: async def read(self, n: int = None) -> bytes:
return await self.conn.read(n) return await self.conn.read(n)
async def close(self) -> None: async def close(self) -> None:

View File

@ -94,7 +94,7 @@ class SecureSession(BaseSession):
data = self.buf.getbuffer()[self.low_watermark : self.high_watermark] data = self.buf.getbuffer()[self.low_watermark : self.high_watermark]
if n < 0: if n is None:
n = len(data) n = len(data)
result = data[:n].tobytes() result = data[:n].tobytes()
self.low_watermark += len(result) self.low_watermark += len(result)
@ -111,7 +111,7 @@ class SecureSession(BaseSession):
self.low_watermark = 0 self.low_watermark = 0
self.high_watermark = len(msg) self.high_watermark = len(msg)
async def read(self, n: int = -1) -> bytes: async def read(self, n: int = None) -> bytes:
if n == 0: if n == 0:
return bytes() return bytes()

View File

@ -81,22 +81,23 @@ class MplexStream(IMuxedStream):
break break
return buf return buf
async def read(self, n: int = -1) -> bytes: async def read(self, n: int = None) -> bytes:
""" """
Read up to n bytes. Read possibly returns fewer than `n` bytes, if Read up to n bytes. Read possibly returns fewer than `n` bytes, if
there are not enough bytes in the Mplex buffer. If `n == -1`, read there are not enough bytes in the Mplex buffer. If `n is None`, read
until EOF. until EOF.
:param n: number of bytes to read :param n: number of bytes to read
:return: bytes actually read :return: bytes actually read
""" """
if n < 0 and n != -1: if n is not None and n < 0:
raise ValueError( raise ValueError(
f"the number of bytes to read `n` must be positive or -1 to indicate read until EOF" f"the number of bytes to read `n` must be non-negative or "
"`None` to indicate read until EOF"
) )
if self.event_reset.is_set(): if self.event_reset.is_set():
raise MplexStreamReset raise MplexStreamReset
if n == -1: if n is None:
return await self._read_until_eof() return await self._read_until_eof()
if len(self._buf) == 0: if len(self._buf) == 0:
data: bytes data: bytes

View File

@ -84,11 +84,8 @@ class DaemonStream(ReadWriteCloser):
async def close(self) -> None: async def close(self) -> None:
await self.stream.close() await self.stream.close()
async def read(self, n: int = -1) -> bytes: async def read(self, n: int = None) -> bytes:
if n == -1: return await self.stream.receive_some(n)
return await self.stream.receive_some()
else:
return await self.stream.receive_some(n)
async def write(self, data: bytes) -> None: async def write(self, data: bytes) -> None:
return await self.stream.send_all(data) return await self.stream.send_all(data)