From 649a2307769ab4b266c69d48f8131ff32f39bf9f Mon Sep 17 00:00:00 2001 From: mhchia Date: Fri, 6 Sep 2019 17:26:40 +0800 Subject: [PATCH] Fix `MplexStream.read` --- examples/chat/chat.py | 4 +- examples/echo/echo.py | 5 +- libp2p/stream_muxer/mplex/mplex_stream.py | 42 ++++---- tests/examples/test_chat.py | 28 +++--- tests/libp2p/test_libp2p.py | 51 +++++----- tests/libp2p/test_notify.py | 106 ++++++++------------ tests/protocol_muxer/test_protocol_muxer.py | 17 +--- tests/utils.py | 6 +- 8 files changed, 116 insertions(+), 143 deletions(-) diff --git a/examples/chat/chat.py b/examples/chat/chat.py index 39258b5..24c9269 100755 --- a/examples/chat/chat.py +++ b/examples/chat/chat.py @@ -11,11 +11,12 @@ from libp2p.peer.peerinfo import info_from_p2p_addr from libp2p.typing import TProtocol PROTOCOL_ID = TProtocol("/chat/1.0.0") +MAX_READ_LEN = 2 ** 32 - 1 async def read_data(stream: INetStream) -> None: while True: - read_bytes = await stream.read() + read_bytes = await stream.read(MAX_READ_LEN) if read_bytes is not None: read_string = read_bytes.decode() if read_string != "\n": @@ -24,7 +25,6 @@ async def read_data(stream: INetStream) -> None: print("\x1b[32m %s\x1b[0m " % read_string, end="") -# FIXME(mhchia): Reconsider whether we should use a thread pool here. async def write_data(stream: INetStream) -> None: loop = asyncio.get_event_loop() while True: diff --git a/examples/echo/echo.py b/examples/echo/echo.py index 06e4f17..3f3ed33 100644 --- a/examples/echo/echo.py +++ b/examples/echo/echo.py @@ -14,6 +14,7 @@ PROTOCOL_ID = TProtocol("/echo/1.0.0") async def _echo_stream_handler(stream: INetStream) -> None: + # Wait until EOF msg = await stream.read() await stream.write(msg) await stream.close() @@ -72,13 +73,13 @@ async def run(port: int, destination: str, localhost: bool, seed: int = None) -> msg = b"hi, there!\n" await stream.write(msg) + # Notify the other side about EOF + await stream.close() response = await stream.read() print(f"Sent: {msg}") print(f"Got: {response}") - await stream.close() - def main() -> None: description = """ diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index 4f2e76c..e537dda 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -55,7 +55,6 @@ class MplexStream(IMuxedStream): return self.stream_id.is_initiator async def _wait_for_data(self) -> None: - print("!@# _wait_for_data: 0") done, pending = await asyncio.wait( [ self.event_reset.wait(), @@ -64,16 +63,25 @@ class MplexStream(IMuxedStream): ], return_when=asyncio.FIRST_COMPLETED, ) - print("!@# _wait_for_data: 1") if self.event_reset.is_set(): raise MplexStreamReset if self.event_remote_closed.is_set(): - while not self.incoming_data.empty(): - self._buf.extend(await self.incoming_data.get()) raise MplexStreamEOF + # TODO: Handle timeout when deadline is used. + data = tuple(done)[0].result() self._buf.extend(data) + async def _read_until_eof(self) -> bytes: + while True: + try: + await self._wait_for_data() + except MplexStreamEOF: + break + payload = self._buf + self._buf = self._buf[len(payload) :] + return bytes(payload) + async def read(self, n: int = -1) -> bytes: """ Read up to n bytes. Read possibly returns fewer than `n` bytes, @@ -87,22 +95,18 @@ class MplexStream(IMuxedStream): raise ValueError( f"the number of bytes to read `n` must be positive or -1 to indicate read until EOF" ) - - # FIXME: If `n == -1`, we should blocking read until EOF, instead of returning when - # no message is available. - # If `n >= 0`, read up to `n` bytes. - # Else, read until no message is available. - while len(self._buf) < n or n == -1: - # new_bytes = await self.incoming_data.get() - try: - await self._wait_for_data() - except MplexStreamEOF: - break - payload: bytearray + if self.event_reset.is_set(): + raise MplexStreamReset if n == -1: - payload = self._buf - else: - payload = self._buf[:n] + return await self._read_until_eof() + if len(self._buf) == 0: + await self._wait_for_data() + # Read up to `n` bytes. + while len(self._buf) < n: + if self.incoming_data.empty() or self.event_remote_closed.is_set(): + break + self._buf.extend(await self.incoming_data.get()) + payload = self._buf[:n] self._buf = self._buf[len(payload) :] return bytes(payload) diff --git a/tests/examples/test_chat.py b/tests/examples/test_chat.py index f461d9d..75d8ec7 100644 --- a/tests/examples/test_chat.py +++ b/tests/examples/test_chat.py @@ -10,10 +10,13 @@ PROTOCOL_ID = "/chat/1.0.0" async def hello_world(host_a, host_b): + hello_world_from_host_a = b"hello world from host a" + hello_world_from_host_b = b"hello world from host b" + async def stream_handler(stream): - read = await stream.read() - assert read == b"hello world from host b" - await stream.write(b"hello world from host a") + read = await stream.read(len(hello_world_from_host_b)) + assert read == hello_world_from_host_b + await stream.write(hello_world_from_host_a) await stream.close() host_a.set_stream_handler(PROTOCOL_ID, stream_handler) @@ -21,9 +24,9 @@ async def hello_world(host_a, host_b): # Start a stream with the destination. # Multiaddress of the destination peer is fetched from the peerstore using 'peerId'. stream = await host_b.new_stream(host_a.get_id(), [PROTOCOL_ID]) - await stream.write(b"hello world from host b") + await stream.write(hello_world_from_host_b) read = await stream.read() - assert read == b"hello world from host a" + assert read == hello_world_from_host_a await stream.close() @@ -32,11 +35,8 @@ async def connect_write(host_a, host_b): received = [] async def stream_handler(stream): - while True: - try: - received.append((await stream.read()).decode()) - except Exception: # exception is raised when other side close the stream ? - break + for message in messages: + received.append((await stream.read(len(message))).decode()) host_a.set_stream_handler(PROTOCOL_ID, stream_handler) @@ -67,12 +67,8 @@ async def connect_read(host_a, host_b): # Multiaddress of the destination peer is fetched from the peerstore using 'peerId'. stream = await host_b.new_stream(host_a.get_id(), [PROTOCOL_ID]) received = [] - # while True: Seems the close stream event from the other host is not received - for _ in range(5): - try: - received.append(await stream.read()) - except Exception: # exception is raised when other side close the stream ? - break + for message in messages: + received.append(await stream.read(len(message))) await stream.close() assert received == messages diff --git a/tests/libp2p/test_libp2p.py b/tests/libp2p/test_libp2p.py index bc58a8c..b4a643d 100644 --- a/tests/libp2p/test_libp2p.py +++ b/tests/libp2p/test_libp2p.py @@ -3,6 +3,7 @@ import pytest from libp2p.peer.peerinfo import info_from_p2p_addr from tests.utils import cleanup, set_up_nodes_by_transport_opt +from tests.constants import MAX_READ_LEN @pytest.mark.asyncio @@ -12,7 +13,7 @@ async def test_simple_messages(): async def stream_handler(stream): while True: - read_string = (await stream.read()).decode() + read_string = (await stream.read(MAX_READ_LEN)).decode() response = "ack:" + read_string await stream.write(response.encode()) @@ -28,7 +29,7 @@ async def test_simple_messages(): for message in messages: await stream.write(message.encode()) - response = (await stream.read()).decode() + response = (await stream.read(MAX_READ_LEN)).decode() assert response == ("ack:" + message) @@ -43,7 +44,7 @@ async def test_double_response(): async def stream_handler(stream): while True: - read_string = (await stream.read()).decode() + read_string = (await stream.read(MAX_READ_LEN)).decode() response = "ack1:" + read_string await stream.write(response.encode()) @@ -61,10 +62,10 @@ async def test_double_response(): for message in messages: await stream.write(message.encode()) - response1 = (await stream.read()).decode() + response1 = (await stream.read(MAX_READ_LEN)).decode() assert response1 == ("ack1:" + message) - response2 = (await stream.read()).decode() + response2 = (await stream.read(MAX_READ_LEN)).decode() assert response2 == ("ack2:" + message) # Success, terminate pending tasks. @@ -80,14 +81,14 @@ async def test_multiple_streams(): async def stream_handler_a(stream): while True: - read_string = (await stream.read()).decode() + read_string = (await stream.read(MAX_READ_LEN)).decode() response = "ack_a:" + read_string await stream.write(response.encode()) async def stream_handler_b(stream): while True: - read_string = (await stream.read()).decode() + read_string = (await stream.read(MAX_READ_LEN)).decode() response = "ack_b:" + read_string await stream.write(response.encode()) @@ -111,8 +112,8 @@ async def test_multiple_streams(): await stream_a.write(a_message.encode()) await stream_b.write(b_message.encode()) - response_a = (await stream_a.read()).decode() - response_b = (await stream_b.read()).decode() + response_a = (await stream_a.read(MAX_READ_LEN)).decode() + response_b = (await stream_b.read(MAX_READ_LEN)).decode() assert response_a == ("ack_b:" + a_message) and response_b == ( "ack_a:" + b_message @@ -129,21 +130,21 @@ async def test_multiple_streams_same_initiator_different_protocols(): async def stream_handler_a1(stream): while True: - read_string = (await stream.read()).decode() + read_string = (await stream.read(MAX_READ_LEN)).decode() response = "ack_a1:" + read_string await stream.write(response.encode()) async def stream_handler_a2(stream): while True: - read_string = (await stream.read()).decode() + read_string = (await stream.read(MAX_READ_LEN)).decode() response = "ack_a2:" + read_string await stream.write(response.encode()) async def stream_handler_a3(stream): while True: - read_string = (await stream.read()).decode() + read_string = (await stream.read(MAX_READ_LEN)).decode() response = "ack_a3:" + read_string await stream.write(response.encode()) @@ -171,9 +172,9 @@ async def test_multiple_streams_same_initiator_different_protocols(): await stream_a2.write(a2_message.encode()) await stream_a3.write(a3_message.encode()) - response_a1 = (await stream_a1.read()).decode() - response_a2 = (await stream_a2.read()).decode() - response_a3 = (await stream_a3.read()).decode() + response_a1 = (await stream_a1.read(MAX_READ_LEN)).decode() + response_a2 = (await stream_a2.read(MAX_READ_LEN)).decode() + response_a3 = (await stream_a3.read(MAX_READ_LEN)).decode() assert ( response_a1 == ("ack_a1:" + a1_message) @@ -192,28 +193,28 @@ async def test_multiple_streams_two_initiators(): async def stream_handler_a1(stream): while True: - read_string = (await stream.read()).decode() + read_string = (await stream.read(MAX_READ_LEN)).decode() response = "ack_a1:" + read_string await stream.write(response.encode()) async def stream_handler_a2(stream): while True: - read_string = (await stream.read()).decode() + read_string = (await stream.read(MAX_READ_LEN)).decode() response = "ack_a2:" + read_string await stream.write(response.encode()) async def stream_handler_b1(stream): while True: - read_string = (await stream.read()).decode() + read_string = (await stream.read(MAX_READ_LEN)).decode() response = "ack_b1:" + read_string await stream.write(response.encode()) async def stream_handler_b2(stream): while True: - read_string = (await stream.read()).decode() + read_string = (await stream.read(MAX_READ_LEN)).decode() response = "ack_b2:" + read_string await stream.write(response.encode()) @@ -249,11 +250,11 @@ async def test_multiple_streams_two_initiators(): await stream_b1.write(b1_message.encode()) await stream_b2.write(b2_message.encode()) - response_a1 = (await stream_a1.read()).decode() - response_a2 = (await stream_a2.read()).decode() + response_a1 = (await stream_a1.read(MAX_READ_LEN)).decode() + response_a2 = (await stream_a2.read(MAX_READ_LEN)).decode() - response_b1 = (await stream_b1.read()).decode() - response_b2 = (await stream_b2.read()).decode() + response_b1 = (await stream_b1.read(MAX_READ_LEN)).decode() + response_b2 = (await stream_b2.read(MAX_READ_LEN)).decode() assert ( response_a1 == ("ack_a1:" + a1_message) @@ -277,7 +278,7 @@ async def test_triangle_nodes_connection(): async def stream_handler(stream): while True: - read_string = (await stream.read()).decode() + read_string = (await stream.read(MAX_READ_LEN)).decode() response = "ack:" + read_string await stream.write(response.encode()) @@ -320,7 +321,7 @@ async def test_triangle_nodes_connection(): for stream in streams: await stream.write(message.encode()) - response = (await stream.read()).decode() + response = (await stream.read(MAX_READ_LEN)).decode() assert response == ("ack:" + message) diff --git a/tests/libp2p/test_notify.py b/tests/libp2p/test_notify.py index f4bd2ef..206f3e3 100644 --- a/tests/libp2p/test_notify.py +++ b/tests/libp2p/test_notify.py @@ -16,11 +16,10 @@ from libp2p import initialize_default_swarm, new_node from libp2p.crypto.rsa import create_new_key_pair from libp2p.host.basic_host import BasicHost from libp2p.network.notifee_interface import INotifee -from tests.utils import ( - cleanup, - echo_stream_handler, - perform_two_host_set_up_custom_handler, -) +from tests.utils import cleanup, perform_two_host_set_up +from tests.constants import MAX_READ_LEN + +ACK = "ack:" class MyNotifee(INotifee): @@ -67,38 +66,9 @@ class InvalidNotifee: assert False -async def perform_two_host_simple_set_up(): - node_a = await new_node(transport_opt=["/ip4/127.0.0.1/tcp/0"]) - node_b = await new_node(transport_opt=["/ip4/127.0.0.1/tcp/0"]) - - async def my_stream_handler(stream): - while True: - read_string = (await stream.read()).decode() - - resp = "ack:" + read_string - await stream.write(resp.encode()) - - node_b.set_stream_handler("/echo/1.0.0", my_stream_handler) - - # Associate the peer with local ip address (see default parameters of Libp2p()) - node_a.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10) - return node_a, node_b - - -async def perform_two_host_simple_set_up_custom_handler(handler): - node_a = await new_node(transport_opt=["/ip4/127.0.0.1/tcp/0"]) - node_b = await new_node(transport_opt=["/ip4/127.0.0.1/tcp/0"]) - - node_b.set_stream_handler("/echo/1.0.0", handler) - - # Associate the peer with local ip address (see default parameters of Libp2p()) - node_a.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10) - return node_a, node_b - - @pytest.mark.asyncio async def test_one_notifier(): - node_a, node_b = await perform_two_host_set_up_custom_handler(echo_stream_handler) + node_a, node_b = await perform_two_host_set_up() # Add notifee for node_a events = [] @@ -113,11 +83,12 @@ async def test_one_notifier(): messages = ["hello", "hello"] for message in messages: + expected_resp = ACK + message await stream.write(message.encode()) - response = (await stream.read()).decode() + response = (await stream.read(len(expected_resp))).decode() - assert response == ("ack:" + message) + assert response == expected_resp # Success, terminate pending tasks. await cleanup() @@ -126,6 +97,7 @@ async def test_one_notifier(): @pytest.mark.asyncio async def test_one_notifier_on_two_nodes(): events_b = [] + messages = ["hello", "hello"] async def my_stream_handler(stream): # Ensure the connected and opened_stream events were hit in Notifee obj @@ -135,13 +107,13 @@ async def test_one_notifier_on_two_nodes(): ["connectedb", stream.mplex_conn], ["opened_streamb", stream], ] - while True: - read_string = (await stream.read()).decode() + for message in messages: + read_string = (await stream.read(len(message))).decode() - resp = "ack:" + read_string + resp = ACK + read_string await stream.write(resp.encode()) - node_a, node_b = await perform_two_host_set_up_custom_handler(my_stream_handler) + node_a, node_b = await perform_two_host_set_up(my_stream_handler) # Add notifee for node_a events_a = [] @@ -157,13 +129,13 @@ async def test_one_notifier_on_two_nodes(): # node_a assert events_a == [["connecteda", stream.mplex_conn], ["opened_streama", stream]] - messages = ["hello", "hello"] for message in messages: + expected_resp = ACK + message await stream.write(message.encode()) - response = (await stream.read()).decode() + response = (await stream.read(len(expected_resp))).decode() - assert response == ("ack:" + message) + assert response == expected_resp # Success, terminate pending tasks. await cleanup() @@ -172,6 +144,7 @@ async def test_one_notifier_on_two_nodes(): @pytest.mark.asyncio async def test_one_notifier_on_two_nodes_with_listen(): events_b = [] + messages = ["hello", "hello"] node_a_key_pair = create_new_key_pair() node_a_transport_opt = ["/ip4/127.0.0.1/tcp/0"] @@ -196,10 +169,9 @@ async def test_one_notifier_on_two_nodes_with_listen(): ["connectedb", stream.mplex_conn], ["opened_streamb", stream], ] - while True: - read_string = (await stream.read()).decode() - - resp = "ack:" + read_string + for message in messages: + read_string = (await stream.read(len(message))).decode() + resp = ACK + read_string await stream.write(resp.encode()) # Add notifee for node_a @@ -222,13 +194,13 @@ async def test_one_notifier_on_two_nodes_with_listen(): # node_a assert events_a == [["connecteda", stream.mplex_conn], ["opened_streama", stream]] - messages = ["hello", "hello"] for message in messages: + expected_resp = ACK + message await stream.write(message.encode()) - response = (await stream.read()).decode() + response = (await stream.read(len(expected_resp))).decode() - assert response == ("ack:" + message) + assert response == expected_resp # Success, terminate pending tasks. await cleanup() @@ -236,7 +208,7 @@ async def test_one_notifier_on_two_nodes_with_listen(): @pytest.mark.asyncio async def test_two_notifiers(): - node_a, node_b = await perform_two_host_set_up_custom_handler(echo_stream_handler) + node_a, node_b = await perform_two_host_set_up() # Add notifee for node_a events0 = [] @@ -255,11 +227,12 @@ async def test_two_notifiers(): messages = ["hello", "hello"] for message in messages: + expected_resp = ACK + message await stream.write(message.encode()) - response = (await stream.read()).decode() + response = (await stream.read(len(expected_resp))).decode() - assert response == ("ack:" + message) + assert response == expected_resp # Success, terminate pending tasks. await cleanup() @@ -269,7 +242,7 @@ async def test_two_notifiers(): async def test_ten_notifiers(): num_notifiers = 10 - node_a, node_b = await perform_two_host_set_up_custom_handler(echo_stream_handler) + node_a, node_b = await perform_two_host_set_up() # Add notifee for node_a events_lst = [] @@ -290,11 +263,12 @@ async def test_ten_notifiers(): messages = ["hello", "hello"] for message in messages: + expected_resp = ACK + message await stream.write(message.encode()) - response = (await stream.read()).decode() + response = (await stream.read(len(expected_resp))).decode() - assert response == ("ack:" + message) + assert response == expected_resp # Success, terminate pending tasks. await cleanup() @@ -315,12 +289,12 @@ async def test_ten_notifiers_on_two_nodes(): ["opened_streamb" + str(i), stream], ] while True: - read_string = (await stream.read()).decode() + read_string = (await stream.read(MAX_READ_LEN)).decode() - resp = "ack:" + read_string + resp = ACK + read_string await stream.write(resp.encode()) - node_a, node_b = await perform_two_host_set_up_custom_handler(my_stream_handler) + node_a, node_b = await perform_two_host_set_up(my_stream_handler) # Add notifee for node_a and node_b events_lst_a = [] @@ -343,11 +317,12 @@ async def test_ten_notifiers_on_two_nodes(): messages = ["hello", "hello"] for message in messages: + expected_resp = ACK + message await stream.write(message.encode()) - response = (await stream.read()).decode() + response = (await stream.read(len(expected_resp))).decode() - assert response == ("ack:" + message) + assert response == expected_resp # Success, terminate pending tasks. await cleanup() @@ -357,7 +332,7 @@ async def test_ten_notifiers_on_two_nodes(): async def test_invalid_notifee(): num_notifiers = 10 - node_a, node_b = await perform_two_host_set_up_custom_handler(echo_stream_handler) + node_a, node_b = await perform_two_host_set_up() # Add notifee for node_a events_lst = [] @@ -372,11 +347,12 @@ async def test_invalid_notifee(): # given that InvalidNotifee should not have been added as a notifee) messages = ["hello", "hello"] for message in messages: + expected_resp = ACK + message await stream.write(message.encode()) - response = (await stream.read()).decode() + response = (await stream.read(len(expected_resp))).decode() - assert response == ("ack:" + message) + assert response == expected_resp # Success, terminate pending tasks. await cleanup() diff --git a/tests/protocol_muxer/test_protocol_muxer.py b/tests/protocol_muxer/test_protocol_muxer.py index 02f08bd..8fb1537 100644 --- a/tests/protocol_muxer/test_protocol_muxer.py +++ b/tests/protocol_muxer/test_protocol_muxer.py @@ -1,7 +1,7 @@ import pytest from libp2p.protocol_muxer.exceptions import MultiselectClientError -from tests.utils import cleanup, set_up_nodes_by_transport_opt +from tests.utils import cleanup, set_up_nodes_by_transport_opt, echo_stream_handler # TODO: Add tests for multiple streams being opened on different # protocols through the same connection @@ -18,14 +18,8 @@ async def perform_simple_test( transport_opt_list = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]] (node_a, node_b) = await set_up_nodes_by_transport_opt(transport_opt_list) - async def stream_handler(stream): - while True: - read_string = (await stream.read()).decode() - response = "ack:" + read_string - await stream.write(response.encode()) - for protocol in protocols_with_handlers: - node_b.set_stream_handler(protocol, stream_handler) + node_b.set_stream_handler(protocol, echo_stream_handler) # Associate the peer with local ip address (see default parameters of Libp2p()) node_a.get_peerstore().add_addrs(node_b.get_id(), node_b.get_addrs(), 10) @@ -33,11 +27,10 @@ async def perform_simple_test( stream = await node_a.new_stream(node_b.get_id(), protocols_for_client) messages = ["hello" + str(x) for x in range(10)] for message in messages: + expected_resp = "ack:" + message await stream.write(message.encode()) - - response = (await stream.read()).decode() - - assert response == ("ack:" + message) + response = (await stream.read(len(expected_resp))).decode() + assert response == expected_resp assert expected_selected_protocol == stream.get_protocol() diff --git a/tests/utils.py b/tests/utils.py index 58a0880..1f1cfc4 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -6,6 +6,8 @@ import multiaddr from libp2p import new_node from libp2p.peer.peerinfo import info_from_p2p_addr +from tests.constants import MAX_READ_LEN + async def connect(node1, node2): """ @@ -38,13 +40,13 @@ async def set_up_nodes_by_transport_opt(transport_opt_list): async def echo_stream_handler(stream): while True: - read_string = (await stream.read()).decode() + read_string = (await stream.read(MAX_READ_LEN)).decode() resp = "ack:" + read_string await stream.write(resp.encode()) -async def perform_two_host_set_up_custom_handler(handler): +async def perform_two_host_set_up(handler=echo_stream_handler): transport_opt_list = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]] (node_a, node_b) = await set_up_nodes_by_transport_opt(transport_opt_list)