Fix close behavior

This commit is contained in:
mhchia 2019-09-09 15:45:35 +08:00
parent b2146c5268
commit be2c0f122a
No known key found for this signature in database
GPG Key ID: 389EFBEA1362589A
8 changed files with 149 additions and 21 deletions

View File

@ -0,0 +1,17 @@
from libp2p.exceptions import BaseLibp2pError
class StreamError(BaseLibp2pError):
pass
class StreamEOF(StreamError, EOFError):
pass
class StreamReset(StreamError):
pass
class StreamClosed(StreamError):
pass

View File

@ -1,9 +1,18 @@
from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream
from libp2p.stream_muxer.exceptions import (
MuxedStreamClosed,
MuxedStreamEOF,
MuxedStreamReset,
)
from libp2p.typing import TProtocol from libp2p.typing import TProtocol
from .exceptions import StreamClosed, StreamEOF, StreamReset
from .net_stream_interface import INetStream from .net_stream_interface import INetStream
# TODO: Handle exceptions from `muxed_stream`
# TODO: Add stream state
# - Reference: https://github.com/libp2p/go-libp2p-swarm/blob/99831444e78c8f23c9335c17d8f7c700ba25ca14/swarm_stream.go # noqa: E501
class NetStream(INetStream): class NetStream(INetStream):
muxed_stream: IMuxedStream muxed_stream: IMuxedStream
@ -35,14 +44,22 @@ class NetStream(INetStream):
:param n: number of bytes to read :param n: number of bytes to read
:return: bytes of input :return: bytes of input
""" """
return await self.muxed_stream.read(n) try:
return await self.muxed_stream.read(n)
except MuxedStreamEOF as error:
raise StreamEOF from error
except MuxedStreamReset as error:
raise StreamReset from error
async def write(self, data: bytes) -> int: async def write(self, data: bytes) -> int:
""" """
write to stream write to stream
:return: number of bytes written :return: number of bytes written
""" """
return await self.muxed_stream.write(data) try:
return await self.muxed_stream.write(data)
except MuxedStreamClosed as error:
raise StreamClosed from error
async def close(self) -> None: async def close(self) -> None:
""" """
@ -51,5 +68,5 @@ class NetStream(INetStream):
""" """
await self.muxed_stream.close() await self.muxed_stream.close()
async def reset(self) -> bool: async def reset(self) -> None:
return await self.muxed_stream.reset() await self.muxed_stream.reset()

View File

@ -23,7 +23,7 @@ class INetStream(ReadWriteCloser):
""" """
@abstractmethod @abstractmethod
async def reset(self) -> bool: async def reset(self) -> None:
""" """
Close both ends of the stream. Close both ends of the stream.
""" """

View File

@ -0,0 +1,25 @@
from libp2p.exceptions import BaseLibp2pError
class MuxedConnError(BaseLibp2pError):
pass
class MuxedConnShutdown(MuxedConnError):
pass
class MuxedStreamError(BaseLibp2pError):
pass
class MuxedStreamReset(MuxedStreamError):
pass
class MuxedStreamEOF(MuxedStreamError, EOFError):
pass
class MuxedStreamClosed(MuxedStreamError):
pass

View File

@ -1,17 +1,27 @@
from libp2p.exceptions import BaseLibp2pError from libp2p.stream_muxer.exceptions import (
MuxedConnError,
MuxedConnShutdown,
MuxedStreamClosed,
MuxedStreamEOF,
MuxedStreamReset,
)
class MplexError(BaseLibp2pError): class MplexError(MuxedConnError):
pass pass
class MplexStreamReset(MplexError): class MplexShutdown(MuxedConnShutdown):
pass pass
class MplexStreamEOF(MplexError, EOFError): class MplexStreamReset(MuxedStreamReset):
pass pass
class MplexShutdown(MplexError): class MplexStreamEOF(MuxedStreamEOF):
pass
class MplexStreamClosed(MuxedStreamClosed):
pass pass

View File

@ -188,6 +188,10 @@ class Mplex(IMuxedConn):
# before. It is abnormal. Possibly disconnect? # before. It is abnormal. Possibly disconnect?
# TODO: Warn and emit logs about this. # TODO: Warn and emit logs about this.
continue continue
async with stream.close_lock:
if stream.event_remote_closed.is_set():
# TODO: Warn "Received data from remote after stream was closed by them. (len = %d)" # noqa: E501
continue
await stream.incoming_data.put(message) await stream.incoming_data.put(message)
elif flag in ( elif flag in (
HeaderTags.CloseInitiator.value, HeaderTags.CloseInitiator.value,

View File

@ -1,11 +1,11 @@
import asyncio import asyncio
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, cast
from libp2p.stream_muxer.abc import IMuxedStream from libp2p.stream_muxer.abc import IMuxedStream
from .constants import HeaderTags from .constants import HeaderTags
from .datastructures import StreamID from .datastructures import StreamID
from .exceptions import MplexStreamEOF, MplexStreamReset from .exceptions import MplexStreamClosed, MplexStreamEOF, MplexStreamReset
if TYPE_CHECKING: if TYPE_CHECKING:
from libp2p.stream_muxer.mplex.mplex import Mplex from libp2p.stream_muxer.mplex.mplex import Mplex
@ -58,20 +58,24 @@ class MplexStream(IMuxedStream):
done, pending = await asyncio.wait( # type: ignore done, pending = await asyncio.wait( # type: ignore
[ [
self.event_reset.wait(), self.event_reset.wait(),
self.event_remote_closed.wait(),
self.incoming_data.get(), self.incoming_data.get(),
self.event_remote_closed.wait(),
], ],
return_when=asyncio.FIRST_COMPLETED, return_when=asyncio.FIRST_COMPLETED,
) )
for fut in pending:
fut.cancel()
if self.event_reset.is_set(): if self.event_reset.is_set():
raise MplexStreamReset raise MplexStreamReset
done_task = tuple(done)[0]
if done_task._coro.__qualname__ == "Queue.get":
data = done_task.result()
self._buf.extend(data)
return
if self.event_remote_closed.is_set(): if self.event_remote_closed.is_set():
raise MplexStreamEOF raise MplexStreamEOF
# TODO: Handle timeout when deadline is used. # TODO: Handle timeout when deadline is used.
data = tuple(done)[0].result()
self._buf.extend(data)
async def _read_until_eof(self) -> bytes: async def _read_until_eof(self) -> bytes:
while True: while True:
try: try:
@ -99,13 +103,15 @@ class MplexStream(IMuxedStream):
raise MplexStreamReset raise MplexStreamReset
if n == -1: if n == -1:
return await self._read_until_eof() return await self._read_until_eof()
if len(self._buf) == 0: if len(self._buf) == 0 and self.incoming_data.empty():
await self._wait_for_data() await self._wait_for_data()
# Read up to `n` bytes. # Either `buf` is not empty or `incoming_data` is not empty now.
# Try to put enough incoming data into `self._buf`.
while len(self._buf) < n: while len(self._buf) < n:
if self.incoming_data.empty() or self.event_remote_closed.is_set(): try:
self._buf.extend(self.incoming_data.get_nowait())
except asyncio.QueueEmpty:
break break
self._buf.extend(await self.incoming_data.get())
payload = self._buf[:n] payload = self._buf[:n]
self._buf = self._buf[len(payload) :] self._buf = self._buf[len(payload) :]
return bytes(payload) return bytes(payload)
@ -115,6 +121,8 @@ class MplexStream(IMuxedStream):
write to stream write to stream
:return: number of bytes written :return: number of bytes written
""" """
if self.event_local_closed.is_set():
raise MplexStreamClosed(f"cannot write to closed stream: data={data}")
flag = ( flag = (
HeaderTags.MessageInitiator HeaderTags.MessageInitiator
if self.is_initiator if self.is_initiator

View File

@ -1,22 +1,29 @@
from typing import Dict import asyncio
from typing import Dict, Tuple
import factory import factory
from libp2p import generate_new_rsa_identity, initialize_default_swarm from libp2p import generate_new_rsa_identity, initialize_default_swarm
from libp2p.crypto.keys import KeyPair from libp2p.crypto.keys import KeyPair
from libp2p.host.basic_host import BasicHost from libp2p.host.basic_host import BasicHost
from libp2p.host.host_interface import IHost
from libp2p.network.stream.net_stream_interface import INetStream
from libp2p.pubsub.floodsub import FloodSub from libp2p.pubsub.floodsub import FloodSub
from libp2p.pubsub.gossipsub import GossipSub from libp2p.pubsub.gossipsub import GossipSub
from libp2p.pubsub.pubsub import Pubsub from libp2p.pubsub.pubsub import Pubsub
from libp2p.security.base_transport import BaseSecureTransport from libp2p.security.base_transport import BaseSecureTransport
from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport
import libp2p.security.secio.transport as secio import libp2p.security.secio.transport as secio
from libp2p.stream_muxer.mplex.mplex import Mplex
from libp2p.stream_muxer.mplex.mplex_stream import MplexStream
from libp2p.typing import TProtocol from libp2p.typing import TProtocol
from tests.configs import LISTEN_MADDR
from tests.pubsub.configs import ( from tests.pubsub.configs import (
FLOODSUB_PROTOCOL_ID, FLOODSUB_PROTOCOL_ID,
GOSSIPSUB_PARAMS, GOSSIPSUB_PARAMS,
GOSSIPSUB_PROTOCOL_ID, GOSSIPSUB_PROTOCOL_ID,
) )
from tests.utils import connect
def security_transport_factory( def security_transport_factory(
@ -43,6 +50,12 @@ class HostFactory(factory.Factory):
network = factory.LazyAttribute(lambda o: swarm_factory(o.is_secure)) network = factory.LazyAttribute(lambda o: swarm_factory(o.is_secure))
@classmethod
async def create_and_listen(cls) -> IHost:
host = cls()
await host.get_network().listen(LISTEN_MADDR)
return host
class FloodsubFactory(factory.Factory): class FloodsubFactory(factory.Factory):
class Meta: class Meta:
@ -73,3 +86,37 @@ class PubsubFactory(factory.Factory):
router = None router = None
my_id = factory.LazyAttribute(lambda obj: obj.host.get_id()) my_id = factory.LazyAttribute(lambda obj: obj.host.get_id())
cache_size = None cache_size = None
async def host_pair_factory() -> Tuple[BasicHost, BasicHost]:
hosts = await asyncio.gather(
*[HostFactory.create_and_listen(), HostFactory.create_and_listen()]
)
await connect(hosts[0], hosts[1])
return hosts[0], hosts[1]
async def connection_pair_factory() -> Tuple[Mplex, BasicHost, Mplex, BasicHost]:
host_0, host_1 = await host_pair_factory()
mplex_conn_0 = host_0.get_network().connections[host_1.get_id()]
mplex_conn_1 = host_1.get_network().connections[host_0.get_id()]
return mplex_conn_0, host_0, mplex_conn_1, host_1
async def net_stream_pair_factory() -> Tuple[
INetStream, BasicHost, INetStream, BasicHost
]:
protocol_id = "/example/id/1"
stream_1: INetStream
# Just a proxy, we only care about the stream
def handler(stream: INetStream) -> None:
nonlocal stream_1
stream_1 = stream
host_0, host_1 = await host_pair_factory()
host_1.set_stream_handler(protocol_id, handler)
stream_0 = await host_0.new_stream(host_1.get_id(), [protocol_id])
return stream_0, host_0, stream_1, host_1