from typing import Optional

from libp2p.stream_muxer.abc import IMuxedStream
from libp2p.stream_muxer.exceptions import (
    MuxedStreamClosed,
    MuxedStreamEOF,
    MuxedStreamReset,
)
from libp2p.typing import TProtocol

from .exceptions import StreamClosed, StreamEOF, StreamReset
from .net_stream_interface import INetStream


# TODO: Handle exceptions from `muxed_stream`
# TODO: Add stream state
#   - Reference: https://github.com/libp2p/go-libp2p-swarm/blob/99831444e78c8f23c9335c17d8f7c700ba25ca14/swarm_stream.go  # noqa: E501
class NetStream(INetStream):
    """
    Network stream that delegates all I/O to an underlying muxed stream.

    Muxer-level exceptions raised by ``read`` and ``write`` are re-raised
    as the corresponding stream-level exceptions, with the original
    exception chained as the cause.
    """

    muxed_stream: IMuxedStream
    protocol_id: Optional[TProtocol]

    def __init__(self, muxed_stream: IMuxedStream) -> None:
        """
        :param muxed_stream: muxed stream that this stream wraps
        """
        self.muxed_stream = muxed_stream
        self.muxed_conn = muxed_stream.muxed_conn
        # No protocol is associated until `set_protocol` is called.
        self.protocol_id = None

    def get_protocol(self) -> TProtocol:
        """
        :return: protocol id that stream runs on (``None`` until
            ``set_protocol`` has been called)
        """
        return self.protocol_id

    def set_protocol(self, protocol_id: TProtocol) -> None:
        """
        :param protocol_id: protocol id that stream runs on
        """
        self.protocol_id = protocol_id

    async def read(self, n: Optional[int] = None) -> bytes:
        """
        reads from stream.

        :param n: number of bytes to read; passed through unchanged
            (including ``None``) to the underlying muxed stream
        :return: bytes of input
        :raises StreamEOF: if the muxed stream raises ``MuxedStreamEOF``
        :raises StreamReset: if the muxed stream raises ``MuxedStreamReset``
        """
        try:
            return await self.muxed_stream.read(n)
        except MuxedStreamEOF as error:
            raise StreamEOF() from error
        except MuxedStreamReset as error:
            raise StreamReset() from error

    async def write(self, data: bytes) -> int:
        """
        write to stream.

        :param data: bytes to write
        :return: number of bytes written
        :raises StreamClosed: if the muxed stream raises ``MuxedStreamClosed``
        """
        try:
            return await self.muxed_stream.write(data)
        except MuxedStreamClosed as error:
            raise StreamClosed() from error

    async def close(self) -> None:
        """close stream."""
        await self.muxed_stream.close()

    async def reset(self) -> None:
        """reset stream."""
        await self.muxed_stream.reset()

    # TODO: `remove`: Called by close and write when the stream is in specific states.
    #   It notifies `ClosedStream` after `SwarmConn.remove_stream` is called.
    # Reference: https://github.com/libp2p/go-libp2p-swarm/blob/99831444e78c8f23c9335c17d8f7c700ba25ca14/swarm_stream.go  # noqa: E501