Fix close behavior
This commit is contained in:
parent
b2146c5268
commit
be2c0f122a
17
libp2p/network/stream/exceptions.py
Normal file
17
libp2p/network/stream/exceptions.py
Normal file
|
@ -0,0 +1,17 @@
|
||||||
|
from libp2p.exceptions import BaseLibp2pError
|
||||||
|
|
||||||
|
|
||||||
|
class StreamError(BaseLibp2pError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class StreamEOF(StreamError, EOFError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class StreamReset(StreamError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class StreamClosed(StreamError):
|
||||||
|
pass
|
|
@ -1,9 +1,18 @@
|
||||||
from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream
|
from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream
|
||||||
|
from libp2p.stream_muxer.exceptions import (
|
||||||
|
MuxedStreamClosed,
|
||||||
|
MuxedStreamEOF,
|
||||||
|
MuxedStreamReset,
|
||||||
|
)
|
||||||
from libp2p.typing import TProtocol
|
from libp2p.typing import TProtocol
|
||||||
|
|
||||||
|
from .exceptions import StreamClosed, StreamEOF, StreamReset
|
||||||
from .net_stream_interface import INetStream
|
from .net_stream_interface import INetStream
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Handle exceptions from `muxed_stream`
|
||||||
|
# TODO: Add stream state
|
||||||
|
# - Reference: https://github.com/libp2p/go-libp2p-swarm/blob/99831444e78c8f23c9335c17d8f7c700ba25ca14/swarm_stream.go # noqa: E501
|
||||||
class NetStream(INetStream):
|
class NetStream(INetStream):
|
||||||
|
|
||||||
muxed_stream: IMuxedStream
|
muxed_stream: IMuxedStream
|
||||||
|
@ -35,14 +44,22 @@ class NetStream(INetStream):
|
||||||
:param n: number of bytes to read
|
:param n: number of bytes to read
|
||||||
:return: bytes of input
|
:return: bytes of input
|
||||||
"""
|
"""
|
||||||
return await self.muxed_stream.read(n)
|
try:
|
||||||
|
return await self.muxed_stream.read(n)
|
||||||
|
except MuxedStreamEOF as error:
|
||||||
|
raise StreamEOF from error
|
||||||
|
except MuxedStreamReset as error:
|
||||||
|
raise StreamReset from error
|
||||||
|
|
||||||
async def write(self, data: bytes) -> int:
|
async def write(self, data: bytes) -> int:
|
||||||
"""
|
"""
|
||||||
write to stream
|
write to stream
|
||||||
:return: number of bytes written
|
:return: number of bytes written
|
||||||
"""
|
"""
|
||||||
return await self.muxed_stream.write(data)
|
try:
|
||||||
|
return await self.muxed_stream.write(data)
|
||||||
|
except MuxedStreamClosed as error:
|
||||||
|
raise StreamClosed from error
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
"""
|
"""
|
||||||
|
@ -51,5 +68,5 @@ class NetStream(INetStream):
|
||||||
"""
|
"""
|
||||||
await self.muxed_stream.close()
|
await self.muxed_stream.close()
|
||||||
|
|
||||||
async def reset(self) -> bool:
|
async def reset(self) -> None:
|
||||||
return await self.muxed_stream.reset()
|
await self.muxed_stream.reset()
|
||||||
|
|
|
@ -23,7 +23,7 @@ class INetStream(ReadWriteCloser):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def reset(self) -> bool:
|
async def reset(self) -> None:
|
||||||
"""
|
"""
|
||||||
Close both ends of the stream.
|
Close both ends of the stream.
|
||||||
"""
|
"""
|
||||||
|
|
25
libp2p/stream_muxer/exceptions.py
Normal file
25
libp2p/stream_muxer/exceptions.py
Normal file
|
@ -0,0 +1,25 @@
|
||||||
|
from libp2p.exceptions import BaseLibp2pError
|
||||||
|
|
||||||
|
|
||||||
|
class MuxedConnError(BaseLibp2pError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class MuxedConnShutdown(MuxedConnError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class MuxedStreamError(BaseLibp2pError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class MuxedStreamReset(MuxedStreamError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class MuxedStreamEOF(MuxedStreamError, EOFError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class MuxedStreamClosed(MuxedStreamError):
|
||||||
|
pass
|
|
@ -1,17 +1,27 @@
|
||||||
from libp2p.exceptions import BaseLibp2pError
|
from libp2p.stream_muxer.exceptions import (
|
||||||
|
MuxedConnError,
|
||||||
|
MuxedConnShutdown,
|
||||||
|
MuxedStreamClosed,
|
||||||
|
MuxedStreamEOF,
|
||||||
|
MuxedStreamReset,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class MplexError(BaseLibp2pError):
|
class MplexError(MuxedConnError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class MplexStreamReset(MplexError):
|
class MplexShutdown(MuxedConnShutdown):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class MplexStreamEOF(MplexError, EOFError):
|
class MplexStreamReset(MuxedStreamReset):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class MplexShutdown(MplexError):
|
class MplexStreamEOF(MuxedStreamEOF):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class MplexStreamClosed(MuxedStreamClosed):
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -188,6 +188,10 @@ class Mplex(IMuxedConn):
|
||||||
# before. It is abnormal. Possibly disconnect?
|
# before. It is abnormal. Possibly disconnect?
|
||||||
# TODO: Warn and emit logs about this.
|
# TODO: Warn and emit logs about this.
|
||||||
continue
|
continue
|
||||||
|
async with stream.close_lock:
|
||||||
|
if stream.event_remote_closed.is_set():
|
||||||
|
# TODO: Warn "Received data from remote after stream was closed by them. (len = %d)" # noqa: E501
|
||||||
|
continue
|
||||||
await stream.incoming_data.put(message)
|
await stream.incoming_data.put(message)
|
||||||
elif flag in (
|
elif flag in (
|
||||||
HeaderTags.CloseInitiator.value,
|
HeaderTags.CloseInitiator.value,
|
||||||
|
|
|
@ -1,11 +1,11 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING, cast
|
||||||
|
|
||||||
from libp2p.stream_muxer.abc import IMuxedStream
|
from libp2p.stream_muxer.abc import IMuxedStream
|
||||||
|
|
||||||
from .constants import HeaderTags
|
from .constants import HeaderTags
|
||||||
from .datastructures import StreamID
|
from .datastructures import StreamID
|
||||||
from .exceptions import MplexStreamEOF, MplexStreamReset
|
from .exceptions import MplexStreamClosed, MplexStreamEOF, MplexStreamReset
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from libp2p.stream_muxer.mplex.mplex import Mplex
|
from libp2p.stream_muxer.mplex.mplex import Mplex
|
||||||
|
@ -58,20 +58,24 @@ class MplexStream(IMuxedStream):
|
||||||
done, pending = await asyncio.wait( # type: ignore
|
done, pending = await asyncio.wait( # type: ignore
|
||||||
[
|
[
|
||||||
self.event_reset.wait(),
|
self.event_reset.wait(),
|
||||||
self.event_remote_closed.wait(),
|
|
||||||
self.incoming_data.get(),
|
self.incoming_data.get(),
|
||||||
|
self.event_remote_closed.wait(),
|
||||||
],
|
],
|
||||||
return_when=asyncio.FIRST_COMPLETED,
|
return_when=asyncio.FIRST_COMPLETED,
|
||||||
)
|
)
|
||||||
|
for fut in pending:
|
||||||
|
fut.cancel()
|
||||||
if self.event_reset.is_set():
|
if self.event_reset.is_set():
|
||||||
raise MplexStreamReset
|
raise MplexStreamReset
|
||||||
|
done_task = tuple(done)[0]
|
||||||
|
if done_task._coro.__qualname__ == "Queue.get":
|
||||||
|
data = done_task.result()
|
||||||
|
self._buf.extend(data)
|
||||||
|
return
|
||||||
if self.event_remote_closed.is_set():
|
if self.event_remote_closed.is_set():
|
||||||
raise MplexStreamEOF
|
raise MplexStreamEOF
|
||||||
# TODO: Handle timeout when deadline is used.
|
# TODO: Handle timeout when deadline is used.
|
||||||
|
|
||||||
data = tuple(done)[0].result()
|
|
||||||
self._buf.extend(data)
|
|
||||||
|
|
||||||
async def _read_until_eof(self) -> bytes:
|
async def _read_until_eof(self) -> bytes:
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
|
@ -99,13 +103,15 @@ class MplexStream(IMuxedStream):
|
||||||
raise MplexStreamReset
|
raise MplexStreamReset
|
||||||
if n == -1:
|
if n == -1:
|
||||||
return await self._read_until_eof()
|
return await self._read_until_eof()
|
||||||
if len(self._buf) == 0:
|
if len(self._buf) == 0 and self.incoming_data.empty():
|
||||||
await self._wait_for_data()
|
await self._wait_for_data()
|
||||||
# Read up to `n` bytes.
|
# Either `buf` is not empty or `incoming_data` is not empty now.
|
||||||
|
# Try to put enough incoming data into `self._buf`.
|
||||||
while len(self._buf) < n:
|
while len(self._buf) < n:
|
||||||
if self.incoming_data.empty() or self.event_remote_closed.is_set():
|
try:
|
||||||
|
self._buf.extend(self.incoming_data.get_nowait())
|
||||||
|
except asyncio.QueueEmpty:
|
||||||
break
|
break
|
||||||
self._buf.extend(await self.incoming_data.get())
|
|
||||||
payload = self._buf[:n]
|
payload = self._buf[:n]
|
||||||
self._buf = self._buf[len(payload) :]
|
self._buf = self._buf[len(payload) :]
|
||||||
return bytes(payload)
|
return bytes(payload)
|
||||||
|
@ -115,6 +121,8 @@ class MplexStream(IMuxedStream):
|
||||||
write to stream
|
write to stream
|
||||||
:return: number of bytes written
|
:return: number of bytes written
|
||||||
"""
|
"""
|
||||||
|
if self.event_local_closed.is_set():
|
||||||
|
raise MplexStreamClosed(f"cannot write to closed stream: data={data}")
|
||||||
flag = (
|
flag = (
|
||||||
HeaderTags.MessageInitiator
|
HeaderTags.MessageInitiator
|
||||||
if self.is_initiator
|
if self.is_initiator
|
||||||
|
|
|
@ -1,22 +1,29 @@
|
||||||
from typing import Dict
|
import asyncio
|
||||||
|
from typing import Dict, Tuple
|
||||||
|
|
||||||
import factory
|
import factory
|
||||||
|
|
||||||
from libp2p import generate_new_rsa_identity, initialize_default_swarm
|
from libp2p import generate_new_rsa_identity, initialize_default_swarm
|
||||||
from libp2p.crypto.keys import KeyPair
|
from libp2p.crypto.keys import KeyPair
|
||||||
from libp2p.host.basic_host import BasicHost
|
from libp2p.host.basic_host import BasicHost
|
||||||
|
from libp2p.host.host_interface import IHost
|
||||||
|
from libp2p.network.stream.net_stream_interface import INetStream
|
||||||
from libp2p.pubsub.floodsub import FloodSub
|
from libp2p.pubsub.floodsub import FloodSub
|
||||||
from libp2p.pubsub.gossipsub import GossipSub
|
from libp2p.pubsub.gossipsub import GossipSub
|
||||||
from libp2p.pubsub.pubsub import Pubsub
|
from libp2p.pubsub.pubsub import Pubsub
|
||||||
from libp2p.security.base_transport import BaseSecureTransport
|
from libp2p.security.base_transport import BaseSecureTransport
|
||||||
from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport
|
from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport
|
||||||
import libp2p.security.secio.transport as secio
|
import libp2p.security.secio.transport as secio
|
||||||
|
from libp2p.stream_muxer.mplex.mplex import Mplex
|
||||||
|
from libp2p.stream_muxer.mplex.mplex_stream import MplexStream
|
||||||
from libp2p.typing import TProtocol
|
from libp2p.typing import TProtocol
|
||||||
|
from tests.configs import LISTEN_MADDR
|
||||||
from tests.pubsub.configs import (
|
from tests.pubsub.configs import (
|
||||||
FLOODSUB_PROTOCOL_ID,
|
FLOODSUB_PROTOCOL_ID,
|
||||||
GOSSIPSUB_PARAMS,
|
GOSSIPSUB_PARAMS,
|
||||||
GOSSIPSUB_PROTOCOL_ID,
|
GOSSIPSUB_PROTOCOL_ID,
|
||||||
)
|
)
|
||||||
|
from tests.utils import connect
|
||||||
|
|
||||||
|
|
||||||
def security_transport_factory(
|
def security_transport_factory(
|
||||||
|
@ -43,6 +50,12 @@ class HostFactory(factory.Factory):
|
||||||
|
|
||||||
network = factory.LazyAttribute(lambda o: swarm_factory(o.is_secure))
|
network = factory.LazyAttribute(lambda o: swarm_factory(o.is_secure))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def create_and_listen(cls) -> IHost:
|
||||||
|
host = cls()
|
||||||
|
await host.get_network().listen(LISTEN_MADDR)
|
||||||
|
return host
|
||||||
|
|
||||||
|
|
||||||
class FloodsubFactory(factory.Factory):
|
class FloodsubFactory(factory.Factory):
|
||||||
class Meta:
|
class Meta:
|
||||||
|
@ -73,3 +86,37 @@ class PubsubFactory(factory.Factory):
|
||||||
router = None
|
router = None
|
||||||
my_id = factory.LazyAttribute(lambda obj: obj.host.get_id())
|
my_id = factory.LazyAttribute(lambda obj: obj.host.get_id())
|
||||||
cache_size = None
|
cache_size = None
|
||||||
|
|
||||||
|
|
||||||
|
async def host_pair_factory() -> Tuple[BasicHost, BasicHost]:
|
||||||
|
hosts = await asyncio.gather(
|
||||||
|
*[HostFactory.create_and_listen(), HostFactory.create_and_listen()]
|
||||||
|
)
|
||||||
|
await connect(hosts[0], hosts[1])
|
||||||
|
return hosts[0], hosts[1]
|
||||||
|
|
||||||
|
|
||||||
|
async def connection_pair_factory() -> Tuple[Mplex, BasicHost, Mplex, BasicHost]:
|
||||||
|
host_0, host_1 = await host_pair_factory()
|
||||||
|
mplex_conn_0 = host_0.get_network().connections[host_1.get_id()]
|
||||||
|
mplex_conn_1 = host_1.get_network().connections[host_0.get_id()]
|
||||||
|
return mplex_conn_0, host_0, mplex_conn_1, host_1
|
||||||
|
|
||||||
|
|
||||||
|
async def net_stream_pair_factory() -> Tuple[
|
||||||
|
INetStream, BasicHost, INetStream, BasicHost
|
||||||
|
]:
|
||||||
|
protocol_id = "/example/id/1"
|
||||||
|
|
||||||
|
stream_1: INetStream
|
||||||
|
|
||||||
|
# Just a proxy, we only care about the stream
|
||||||
|
def handler(stream: INetStream) -> None:
|
||||||
|
nonlocal stream_1
|
||||||
|
stream_1 = stream
|
||||||
|
|
||||||
|
host_0, host_1 = await host_pair_factory()
|
||||||
|
host_1.set_stream_handler(protocol_id, handler)
|
||||||
|
|
||||||
|
stream_0 = await host_0.new_stream(host_1.get_id(), [protocol_id])
|
||||||
|
return stream_0, host_0, stream_1, host_1
|
||||||
|
|
Loading…
Reference in New Issue
Block a user