Merge pull request #267 from NIC619/fix_conn_attr_in_mplex

Small fix on `conn` attribute and docstring in mplex
This commit is contained in:
NIC Lin 2019-08-25 16:53:29 +08:00 committed by GitHub
commit 5b122d04b2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 20 additions and 17 deletions

View File

@ -44,9 +44,9 @@ class Mplex(IMuxedConn):
for new muxed streams for new muxed streams
:param peer_id: peer_id of peer the connection is to :param peer_id: peer_id of peer the connection is to
""" """
self.conn = secured_conn self.secured_conn = secured_conn
if self.conn.initiator: if self.secured_conn.initiator:
self.next_stream_id = 0 self.next_stream_id = 0
else: else:
self.next_stream_id = 1 self.next_stream_id = 1
@ -67,13 +67,13 @@ class Mplex(IMuxedConn):
@property @property
def initiator(self) -> bool: def initiator(self) -> bool:
return self.conn.initiator return self.secured_conn.initiator
async def close(self) -> None: async def close(self) -> None:
""" """
close the stream muxer and underlying raw connection close the stream muxer and underlying secured connection
""" """
await self.conn.close() await self.secured_conn.close()
def is_closed(self) -> bool: def is_closed(self) -> bool:
""" """
@ -84,7 +84,8 @@ class Mplex(IMuxedConn):
async def read_buffer(self, stream_id: int) -> bytes: async def read_buffer(self, stream_id: int) -> bytes:
""" """
Read a message from stream_id's buffer, check raw connection for new messages. Read a message from buffer of the stream specified by `stream_id`,
check secured connection for new messages.
`StreamNotFound` is raised when stream `stream_id` is not found in `Mplex`. `StreamNotFound` is raised when stream `stream_id` is not found in `Mplex`.
:param stream_id: stream id of stream to read from :param stream_id: stream id of stream to read from
:return: message read :return: message read
@ -95,7 +96,7 @@ class Mplex(IMuxedConn):
async def read_buffer_nonblocking(self, stream_id: int) -> Optional[bytes]: async def read_buffer_nonblocking(self, stream_id: int) -> Optional[bytes]:
""" """
Read a message from `stream_id`'s buffer, non-blockingly. Read a message from buffer of the stream specified by `stream_id`, non-blockingly.
`StreamNotFound` is raised when stream `stream_id` is not found in `Mplex`. `StreamNotFound` is raised when stream `stream_id` is not found in `Mplex`.
""" """
if stream_id not in self.buffers: if stream_id not in self.buffers:
@ -121,7 +122,7 @@ class Mplex(IMuxedConn):
creates a new muxed_stream creates a new muxed_stream
:param protocol_id: protocol_id of stream :param protocol_id: protocol_id of stream
:param multi_addr: multi_addr that stream connects to :param multi_addr: multi_addr that stream connects to
:return: a new stream :return: a new muxed stream
""" """
stream_id = self._get_next_stream_id() stream_id = self._get_next_stream_id()
stream = MplexStream(stream_id, True, self) stream = MplexStream(stream_id, True, self)
@ -159,16 +160,16 @@ class Mplex(IMuxedConn):
async def write_to_stream(self, _bytes: bytearray) -> int: async def write_to_stream(self, _bytes: bytearray) -> int:
""" """
writes a byte array to a raw connection writes a byte array to a secured connection
:param _bytes: byte array to write :param _bytes: byte array to write
:return: length written :return: length written
""" """
await self.conn.write(_bytes) await self.secured_conn.write(_bytes)
return len(_bytes) return len(_bytes)
async def handle_incoming(self) -> None: async def handle_incoming(self) -> None:
""" """
Read a message off of the raw connection and add it to the corresponding message buffer Read a message off of the secured connection and add it to the corresponding message buffer
""" """
# TODO Deal with other types of messages using flag (currently _) # TODO Deal with other types of messages using flag (currently _)
@ -192,7 +193,7 @@ class Mplex(IMuxedConn):
async def read_message(self) -> Tuple[int, int, bytes]: async def read_message(self) -> Tuple[int, int, bytes]:
""" """
Read a single message off of the raw connection Read a single message off of the secured connection
:return: stream_id, flag, message contents :return: stream_id, flag, message contents
""" """
@ -201,9 +202,11 @@ class Mplex(IMuxedConn):
# loop in handle_incoming # loop in handle_incoming
timeout = 0.1 timeout = 0.1
try: try:
header = await decode_uvarint_from_stream(self.conn, timeout) header = await decode_uvarint_from_stream(self.secured_conn, timeout)
length = await decode_uvarint_from_stream(self.conn, timeout) length = await decode_uvarint_from_stream(self.secured_conn, timeout)
message = await asyncio.wait_for(self.conn.read(length), timeout=timeout) message = await asyncio.wait_for(
self.secured_conn.read(length), timeout=timeout
)
except asyncio.TimeoutError: except asyncio.TimeoutError:
return None, None, None return None, None, None

View File

@ -53,8 +53,8 @@ async def perform_simple_test(
node2_conn = node2.get_network().connections[peer_id_for_node(node1)] node2_conn = node2.get_network().connections[peer_id_for_node(node1)]
# Perform assertion # Perform assertion
assertion_func(node1_conn.conn) assertion_func(node1_conn.secured_conn)
assertion_func(node2_conn.conn) assertion_func(node2_conn.secured_conn)
# Success, terminate pending tasks. # Success, terminate pending tasks.
await cleanup() await cleanup()