diff --git a/libp2p/network/stream/exceptions.py b/libp2p/network/stream/exceptions.py new file mode 100644 index 0000000..58f3ddf --- /dev/null +++ b/libp2p/network/stream/exceptions.py @@ -0,0 +1,17 @@ +from libp2p.exceptions import BaseLibp2pError + + +class StreamError(BaseLibp2pError): + pass + + +class StreamEOF(StreamError, EOFError): + pass + + +class StreamReset(StreamError): + pass + + +class StreamClosed(StreamError): + pass diff --git a/libp2p/network/stream/net_stream.py b/libp2p/network/stream/net_stream.py index 7383f73..4dedab7 100644 --- a/libp2p/network/stream/net_stream.py +++ b/libp2p/network/stream/net_stream.py @@ -1,9 +1,18 @@ from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream +from libp2p.stream_muxer.exceptions import ( + MuxedStreamClosed, + MuxedStreamEOF, + MuxedStreamReset, +) from libp2p.typing import TProtocol +from .exceptions import StreamClosed, StreamEOF, StreamReset from .net_stream_interface import INetStream +# TODO: Handle exceptions from `muxed_stream` +# TODO: Add stream state +# - Reference: https://github.com/libp2p/go-libp2p-swarm/blob/99831444e78c8f23c9335c17d8f7c700ba25ca14/swarm_stream.go # noqa: E501 class NetStream(INetStream): muxed_stream: IMuxedStream @@ -35,14 +44,22 @@ class NetStream(INetStream): :param n: number of bytes to read :return: bytes of input """ - return await self.muxed_stream.read(n) + try: + return await self.muxed_stream.read(n) + except MuxedStreamEOF as error: + raise StreamEOF from error + except MuxedStreamReset as error: + raise StreamReset from error async def write(self, data: bytes) -> int: """ write to stream :return: number of bytes written """ - return await self.muxed_stream.write(data) + try: + return await self.muxed_stream.write(data) + except MuxedStreamClosed as error: + raise StreamClosed from error async def close(self) -> None: """ @@ -51,5 +68,5 @@ class NetStream(INetStream): """ await self.muxed_stream.close() - async def reset(self) -> bool: - return await self.muxed_stream.reset() + async def reset(self) -> None: + await self.muxed_stream.reset() diff --git a/libp2p/network/stream/net_stream_interface.py b/libp2p/network/stream/net_stream_interface.py index aaa775a..53ce038 100644 --- a/libp2p/network/stream/net_stream_interface.py +++ b/libp2p/network/stream/net_stream_interface.py @@ -23,7 +23,7 @@ class INetStream(ReadWriteCloser): """ @abstractmethod - async def reset(self) -> bool: + async def reset(self) -> None: """ Close both ends of the stream. """ diff --git a/libp2p/stream_muxer/exceptions.py b/libp2p/stream_muxer/exceptions.py new file mode 100644 index 0000000..861319a --- /dev/null +++ b/libp2p/stream_muxer/exceptions.py @@ -0,0 +1,25 @@ +from libp2p.exceptions import BaseLibp2pError + + +class MuxedConnError(BaseLibp2pError): + pass + + +class MuxedConnShutdown(MuxedConnError): + pass + + +class MuxedStreamError(BaseLibp2pError): + pass + + +class MuxedStreamReset(MuxedStreamError): + pass + + +class MuxedStreamEOF(MuxedStreamError, EOFError): + pass + + +class MuxedStreamClosed(MuxedStreamError): + pass diff --git a/libp2p/stream_muxer/mplex/exceptions.py b/libp2p/stream_muxer/mplex/exceptions.py index 11663e2..154c371 100644 --- a/libp2p/stream_muxer/mplex/exceptions.py +++ b/libp2p/stream_muxer/mplex/exceptions.py @@ -1,17 +1,27 @@ -from libp2p.exceptions import BaseLibp2pError +from libp2p.stream_muxer.exceptions import ( + MuxedConnError, + MuxedConnShutdown, + MuxedStreamClosed, + MuxedStreamEOF, + MuxedStreamReset, +) -class MplexError(BaseLibp2pError): +class MplexError(MuxedConnError): pass -class MplexStreamReset(MplexError): +class MplexShutdown(MuxedConnShutdown): pass -class MplexStreamEOF(MplexError, EOFError): +class MplexStreamReset(MuxedStreamReset): pass -class MplexShutdown(MplexError): +class MplexStreamEOF(MuxedStreamEOF): + pass + + +class MplexStreamClosed(MuxedStreamClosed): pass diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index 1e8823a..c75000d 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -188,6 +188,10 @@ class Mplex(IMuxedConn): # before. It is abnormal. Possibly disconnect? # TODO: Warn and emit logs about this. continue + async with stream.close_lock: + if stream.event_remote_closed.is_set(): + # TODO: Warn "Received data from remote after stream was closed by them. (len = %d)" # noqa: E501 + continue await stream.incoming_data.put(message) elif flag in ( HeaderTags.CloseInitiator.value, diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index 18c8ff0..547d7b8 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -1,11 +1,11 @@ import asyncio -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast from libp2p.stream_muxer.abc import IMuxedStream from .constants import HeaderTags from .datastructures import StreamID -from .exceptions import MplexStreamEOF, MplexStreamReset +from .exceptions import MplexStreamClosed, MplexStreamEOF, MplexStreamReset if TYPE_CHECKING: from libp2p.stream_muxer.mplex.mplex import Mplex @@ -58,20 +58,24 @@ class MplexStream(IMuxedStream): done, pending = await asyncio.wait( # type: ignore [ self.event_reset.wait(), - self.event_remote_closed.wait(), self.incoming_data.get(), + self.event_remote_closed.wait(), ], return_when=asyncio.FIRST_COMPLETED, ) + for fut in pending: + fut.cancel() if self.event_reset.is_set(): raise MplexStreamReset + done_task = tuple(done)[0] + if done_task._coro.__qualname__ == "Queue.get": + data = done_task.result() + self._buf.extend(data) + return if self.event_remote_closed.is_set(): raise MplexStreamEOF # TODO: Handle timeout when deadline is used. - data = tuple(done)[0].result() - self._buf.extend(data) - async def _read_until_eof(self) -> bytes: while True: try: @@ -99,13 +103,15 @@ class MplexStream(IMuxedStream): raise MplexStreamReset if n == -1: return await self._read_until_eof() - if len(self._buf) == 0: + if len(self._buf) == 0 and self.incoming_data.empty(): await self._wait_for_data() - # Read up to `n` bytes. + # Either `buf` is not empty or `incoming_data` is not empty now. + # Try to put enough incoming data into `self._buf`. while len(self._buf) < n: - if self.incoming_data.empty() or self.event_remote_closed.is_set(): + try: + self._buf.extend(self.incoming_data.get_nowait()) + except asyncio.QueueEmpty: break - self._buf.extend(await self.incoming_data.get()) payload = self._buf[:n] self._buf = self._buf[len(payload) :] return bytes(payload) @@ -115,6 +121,8 @@ class MplexStream(IMuxedStream): write to stream :return: number of bytes written """ + if self.event_local_closed.is_set(): + raise MplexStreamClosed(f"cannot write to closed stream: data={data}") flag = ( HeaderTags.MessageInitiator if self.is_initiator diff --git a/tests/factories.py b/tests/factories.py index 240bdb8..e161e25 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -1,22 +1,29 @@ -from typing import Dict +import asyncio +from typing import Dict, Tuple import factory from libp2p import generate_new_rsa_identity, initialize_default_swarm from libp2p.crypto.keys import KeyPair from libp2p.host.basic_host import BasicHost +from libp2p.host.host_interface import IHost +from libp2p.network.stream.net_stream_interface import INetStream from libp2p.pubsub.floodsub import FloodSub from libp2p.pubsub.gossipsub import GossipSub from libp2p.pubsub.pubsub import Pubsub from libp2p.security.base_transport import BaseSecureTransport from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport import libp2p.security.secio.transport as secio +from libp2p.stream_muxer.mplex.mplex import Mplex +from libp2p.stream_muxer.mplex.mplex_stream import MplexStream from libp2p.typing import TProtocol +from tests.configs import LISTEN_MADDR from tests.pubsub.configs import ( FLOODSUB_PROTOCOL_ID, GOSSIPSUB_PARAMS, GOSSIPSUB_PROTOCOL_ID, ) +from tests.utils import connect def security_transport_factory( @@ -43,6 +50,12 @@ class HostFactory(factory.Factory): network = factory.LazyAttribute(lambda o: swarm_factory(o.is_secure)) + @classmethod + async def create_and_listen(cls) -> IHost: + host = cls() + await host.get_network().listen(LISTEN_MADDR) + return host + class FloodsubFactory(factory.Factory): class Meta: @@ -73,3 +86,37 @@ class PubsubFactory(factory.Factory): router = None my_id = factory.LazyAttribute(lambda obj: obj.host.get_id()) cache_size = None + + +async def host_pair_factory() -> Tuple[BasicHost, BasicHost]: + hosts = await asyncio.gather( + *[HostFactory.create_and_listen(), HostFactory.create_and_listen()] + ) + await connect(hosts[0], hosts[1]) + return hosts[0], hosts[1] + + +async def connection_pair_factory() -> Tuple[Mplex, BasicHost, Mplex, BasicHost]: + host_0, host_1 = await host_pair_factory() + mplex_conn_0 = host_0.get_network().connections[host_1.get_id()] + mplex_conn_1 = host_1.get_network().connections[host_0.get_id()] + return mplex_conn_0, host_0, mplex_conn_1, host_1 + + +async def net_stream_pair_factory() -> Tuple[ + INetStream, BasicHost, INetStream, BasicHost +]: + protocol_id = "/example/id/1" + + stream_1: INetStream + + # Just a proxy, we only care about the stream + def handler(stream: INetStream) -> None: + nonlocal stream_1 + stream_1 = stream + + host_0, host_1 = await host_pair_factory() + host_1.set_stream_handler(protocol_id, handler) + + stream_0 = await host_0.new_stream(host_1.get_id(), [protocol_id]) + return stream_0, host_0, stream_1, host_1