diff --git a/libp2p/tools/utils.py b/libp2p/tools/utils.py index a66155c..5a262b3 100644 --- a/libp2p/tools/utils.py +++ b/libp2p/tools/utils.py @@ -1,6 +1,7 @@ from typing import Awaitable, Callable from libp2p.host.host_interface import IHost +from libp2p.network.stream.exceptions import StreamError from libp2p.network.stream.net_stream_interface import INetStream from libp2p.network.swarm import Swarm from libp2p.peer.peerinfo import info_from_p2p_addr @@ -33,9 +34,15 @@ def create_echo_stream_handler( ) -> Callable[[INetStream], Awaitable[None]]: async def echo_stream_handler(stream: INetStream) -> None: while True: - read_string = (await stream.read(MAX_READ_LEN)).decode() + try: + read_string = (await stream.read(MAX_READ_LEN)).decode() + except StreamError: + break resp = ack_prefix + read_string - await stream.write(resp.encode()) + try: + await stream.write(resp.encode()) + except StreamError: + break return echo_stream_handler diff --git a/tests/libp2p/test_libp2p.py b/tests/libp2p/test_libp2p.py index 91fea58..99a60bd 100644 --- a/tests/libp2p/test_libp2p.py +++ b/tests/libp2p/test_libp2p.py @@ -1,6 +1,7 @@ import multiaddr import pytest +from libp2p.network.stream.exceptions import StreamError from libp2p.tools.constants import MAX_READ_LEN from libp2p.tools.factories import HostFactory from libp2p.tools.utils import connect, create_echo_stream_handler @@ -42,13 +43,22 @@ async def test_double_response(is_host_secure): async def double_response_stream_handler(stream): while True: - read_string = (await stream.read(MAX_READ_LEN)).decode() + try: + read_string = (await stream.read(MAX_READ_LEN)).decode() + except StreamError: + break response = ACK_STR_0 + read_string - await stream.write(response.encode()) + try: + await stream.write(response.encode()) + except StreamError: + break response = ACK_STR_1 + read_string - await stream.write(response.encode()) + try: + await stream.write(response.encode()) + except StreamError: + break hosts[1].set_stream_handler(PROTOCOL_ID_0, double_response_stream_handler)