From 0548d285682d23e7ff2e003ab11ffd1b946d391e Mon Sep 17 00:00:00 2001 From: mhchia Date: Tue, 4 Feb 2020 20:46:40 +0800 Subject: [PATCH] Fix: `StreamReset` in the stream handlers Since we don't catch `Exception` in the stream handlers, catch them in the stream handlers in the tests. --- libp2p/tools/utils.py | 11 +++++++++-- tests/libp2p/test_libp2p.py | 16 +++++++++++++--- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/libp2p/tools/utils.py b/libp2p/tools/utils.py index a66155c..5a262b3 100644 --- a/libp2p/tools/utils.py +++ b/libp2p/tools/utils.py @@ -1,6 +1,7 @@ from typing import Awaitable, Callable from libp2p.host.host_interface import IHost +from libp2p.network.stream.exceptions import StreamError from libp2p.network.stream.net_stream_interface import INetStream from libp2p.network.swarm import Swarm from libp2p.peer.peerinfo import info_from_p2p_addr @@ -33,9 +34,15 @@ def create_echo_stream_handler( ) -> Callable[[INetStream], Awaitable[None]]: async def echo_stream_handler(stream: INetStream) -> None: while True: - read_string = (await stream.read(MAX_READ_LEN)).decode() + try: + read_string = (await stream.read(MAX_READ_LEN)).decode() + except StreamError: + break resp = ack_prefix + read_string - await stream.write(resp.encode()) + try: + await stream.write(resp.encode()) + except StreamError: + break return echo_stream_handler diff --git a/tests/libp2p/test_libp2p.py b/tests/libp2p/test_libp2p.py index 91fea58..99a60bd 100644 --- a/tests/libp2p/test_libp2p.py +++ b/tests/libp2p/test_libp2p.py @@ -1,6 +1,7 @@ import multiaddr import pytest +from libp2p.network.stream.exceptions import StreamError from libp2p.tools.constants import MAX_READ_LEN from libp2p.tools.factories import HostFactory from libp2p.tools.utils import connect, create_echo_stream_handler @@ -42,13 +43,22 @@ async def test_double_response(is_host_secure): async def double_response_stream_handler(stream): while True: - read_string = (await stream.read(MAX_READ_LEN)).decode() + try: + read_string = (await stream.read(MAX_READ_LEN)).decode() + except StreamError: + break response = ACK_STR_0 + read_string - await stream.write(response.encode()) + try: + await stream.write(response.encode()) + except StreamError: + break response = ACK_STR_1 + read_string - await stream.write(response.encode()) + try: + await stream.write(response.encode()) + except StreamError: + break hosts[1].set_stream_handler(PROTOCOL_ID_0, double_response_stream_handler)