Merge pull request #304 from mhchia/fix/detection-of-close
Detect closed `Mplex`
This commit is contained in:
commit
4a838033ff
@ -30,6 +30,10 @@ logger.setLevel(logging.DEBUG)
|
|||||||
|
|
||||||
|
|
||||||
class BasicHost(IHost):
|
class BasicHost(IHost):
|
||||||
|
"""
|
||||||
|
BasicHost is a wrapper of a `INetwork` implementation. It performs protocol negotiation
|
||||||
|
on a stream with multistream-select right after a stream is initialized.
|
||||||
|
"""
|
||||||
|
|
||||||
_network: INetwork
|
_network: INetwork
|
||||||
_router: KadmeliaPeerRouter
|
_router: KadmeliaPeerRouter
|
||||||
@ -38,7 +42,6 @@ class BasicHost(IHost):
|
|||||||
multiselect: Multiselect
|
multiselect: Multiselect
|
||||||
multiselect_client: MultiselectClient
|
multiselect_client: MultiselectClient
|
||||||
|
|
||||||
# default options constructor
|
|
||||||
def __init__(self, network: INetwork, router: KadmeliaPeerRouter = None) -> None:
|
def __init__(self, network: INetwork, router: KadmeliaPeerRouter = None) -> None:
|
||||||
self._network = network
|
self._network = network
|
||||||
self._network.set_stream_handler(self._swarm_stream_handler)
|
self._network.set_stream_handler(self._swarm_stream_handler)
|
||||||
@ -76,6 +79,7 @@ class BasicHost(IHost):
|
|||||||
"""
|
"""
|
||||||
:return: all the multiaddr addresses this host is listening to
|
:return: all the multiaddr addresses this host is listening to
|
||||||
"""
|
"""
|
||||||
|
# TODO: We don't need "/p2p/{peer_id}" postfix actually.
|
||||||
p2p_part = multiaddr.Multiaddr("/p2p/{}".format(self.get_id().pretty()))
|
p2p_part = multiaddr.Multiaddr("/p2p/{}".format(self.get_id().pretty()))
|
||||||
|
|
||||||
addrs: List[multiaddr.Multiaddr] = []
|
addrs: List[multiaddr.Multiaddr] = []
|
||||||
@ -94,8 +98,6 @@ class BasicHost(IHost):
|
|||||||
"""
|
"""
|
||||||
self.multiselect.add_handler(protocol_id, stream_handler)
|
self.multiselect.add_handler(protocol_id, stream_handler)
|
||||||
|
|
||||||
# `protocol_ids` can be a list of `protocol_id`
|
|
||||||
# stream will decide which `protocol_id` to run on
|
|
||||||
async def new_stream(
|
async def new_stream(
|
||||||
self, peer_id: ID, protocol_ids: Sequence[TProtocol]
|
self, peer_id: ID, protocol_ids: Sequence[TProtocol]
|
||||||
) -> INetStream:
|
) -> INetStream:
|
||||||
|
@ -37,8 +37,8 @@ class RawConnection(IRawConnection):
|
|||||||
async with self._drain_lock:
|
async with self._drain_lock:
|
||||||
try:
|
try:
|
||||||
await self.writer.drain()
|
await self.writer.drain()
|
||||||
except ConnectionResetError:
|
except ConnectionResetError as error:
|
||||||
raise RawConnError()
|
raise RawConnError(error)
|
||||||
|
|
||||||
async def read(self, n: int = -1) -> bytes:
|
async def read(self, n: int = -1) -> bytes:
|
||||||
"""
|
"""
|
||||||
|
@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any, Awaitable, List, Set, Tuple
|
|||||||
from libp2p.network.connection.net_connection_interface import INetConn
|
from libp2p.network.connection.net_connection_interface import INetConn
|
||||||
from libp2p.network.stream.net_stream import NetStream
|
from libp2p.network.stream.net_stream import NetStream
|
||||||
from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream
|
from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream
|
||||||
|
from libp2p.stream_muxer.exceptions import MuxedConnUnavailable
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from libp2p.network.swarm import Swarm # noqa: F401
|
from libp2p.network.swarm import Swarm # noqa: F401
|
||||||
@ -34,17 +35,28 @@ class SwarmConn(INetConn):
|
|||||||
if self.event_closed.is_set():
|
if self.event_closed.is_set():
|
||||||
return
|
return
|
||||||
self.event_closed.set()
|
self.event_closed.set()
|
||||||
|
self.swarm.remove_conn(self)
|
||||||
|
|
||||||
await self.conn.close()
|
await self.conn.close()
|
||||||
|
|
||||||
|
# This is just for cleaning up state. The connection has already been closed.
|
||||||
|
# We *could* optimize this but it really isn't worth it.
|
||||||
|
for stream in self.streams:
|
||||||
|
await stream.reset()
|
||||||
|
# Schedule `self._notify_disconnected` to make it execute after `close` is finished.
|
||||||
|
asyncio.ensure_future(self._notify_disconnected())
|
||||||
|
|
||||||
for task in self._tasks:
|
for task in self._tasks:
|
||||||
task.cancel()
|
task.cancel()
|
||||||
|
|
||||||
# TODO: Reset streams for local.
|
|
||||||
# TODO: Notify closed.
|
|
||||||
|
|
||||||
async def _handle_new_streams(self) -> None:
|
async def _handle_new_streams(self) -> None:
|
||||||
# TODO: Break the loop when anything wrong in the connection.
|
|
||||||
while True:
|
while True:
|
||||||
|
try:
|
||||||
stream = await self.conn.accept_stream()
|
stream = await self.conn.accept_stream()
|
||||||
|
except MuxedConnUnavailable:
|
||||||
|
# If there is anything wrong in the MuxedConn,
|
||||||
|
# we should break the loop and close the connection.
|
||||||
|
break
|
||||||
# Asynchronously handle the accepted stream, to avoid blocking the next stream.
|
# Asynchronously handle the accepted stream, to avoid blocking the next stream.
|
||||||
await self.run_task(self._handle_muxed_stream(stream))
|
await self.run_task(self._handle_muxed_stream(stream))
|
||||||
|
|
||||||
@ -57,11 +69,16 @@ class SwarmConn(INetConn):
|
|||||||
|
|
||||||
async def _add_stream(self, muxed_stream: IMuxedStream) -> NetStream:
|
async def _add_stream(self, muxed_stream: IMuxedStream) -> NetStream:
|
||||||
net_stream = NetStream(muxed_stream)
|
net_stream = NetStream(muxed_stream)
|
||||||
|
self.streams.add(net_stream)
|
||||||
# Call notifiers since event occurred
|
# Call notifiers since event occurred
|
||||||
for notifee in self.swarm.notifees:
|
for notifee in self.swarm.notifees:
|
||||||
await notifee.opened_stream(self.swarm, net_stream)
|
await notifee.opened_stream(self.swarm, net_stream)
|
||||||
return net_stream
|
return net_stream
|
||||||
|
|
||||||
|
async def _notify_disconnected(self) -> None:
|
||||||
|
for notifee in self.swarm.notifees:
|
||||||
|
await notifee.disconnected(self.swarm, self.conn)
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
await self.run_task(self._handle_new_streams())
|
await self.run_task(self._handle_new_streams())
|
||||||
|
|
||||||
|
@ -272,13 +272,18 @@ class Swarm(INetwork):
|
|||||||
async def close_peer(self, peer_id: ID) -> None:
|
async def close_peer(self, peer_id: ID) -> None:
|
||||||
if peer_id not in self.connections:
|
if peer_id not in self.connections:
|
||||||
return
|
return
|
||||||
|
# TODO: Should be changed to close multisple connections,
|
||||||
|
# if we have several connections per peer in the future.
|
||||||
connection = self.connections[peer_id]
|
connection = self.connections[peer_id]
|
||||||
del self.connections[peer_id]
|
|
||||||
await connection.close()
|
await connection.close()
|
||||||
|
|
||||||
logger.debug("successfully close the connection to peer %s", peer_id)
|
logger.debug("successfully close the connection to peer %s", peer_id)
|
||||||
|
|
||||||
async def add_conn(self, muxed_conn: IMuxedConn) -> SwarmConn:
|
async def add_conn(self, muxed_conn: IMuxedConn) -> SwarmConn:
|
||||||
|
"""
|
||||||
|
Add a `IMuxedConn` to `Swarm` as a `SwarmConn`, notify "connected",
|
||||||
|
and start to monitor the connection for its new streams and disconnection.
|
||||||
|
"""
|
||||||
swarm_conn = SwarmConn(muxed_conn, self)
|
swarm_conn = SwarmConn(muxed_conn, self)
|
||||||
# Store muxed_conn with peer id
|
# Store muxed_conn with peer id
|
||||||
self.connections[muxed_conn.peer_id] = swarm_conn
|
self.connections[muxed_conn.peer_id] = swarm_conn
|
||||||
@ -288,3 +293,14 @@ class Swarm(INetwork):
|
|||||||
await notifee.connected(self, muxed_conn)
|
await notifee.connected(self, muxed_conn)
|
||||||
await swarm_conn.start()
|
await swarm_conn.start()
|
||||||
return swarm_conn
|
return swarm_conn
|
||||||
|
|
||||||
|
def remove_conn(self, swarm_conn: SwarmConn) -> None:
|
||||||
|
"""
|
||||||
|
Simply remove the connection from Swarm's records, without closing the connection.
|
||||||
|
"""
|
||||||
|
peer_id = swarm_conn.conn.peer_id
|
||||||
|
if peer_id not in self.connections:
|
||||||
|
return
|
||||||
|
# TODO: Should be changed to remove the exact connection,
|
||||||
|
# if we have several connections per peer in the future.
|
||||||
|
del self.connections[peer_id]
|
||||||
|
@ -5,7 +5,7 @@ class MuxedConnError(BaseLibp2pError):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class MuxedConnShutdown(MuxedConnError):
|
class MuxedConnUnavailable(MuxedConnError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from libp2p.stream_muxer.exceptions import (
|
from libp2p.stream_muxer.exceptions import (
|
||||||
MuxedConnError,
|
MuxedConnError,
|
||||||
MuxedConnShutdown,
|
MuxedConnUnavailable,
|
||||||
MuxedStreamClosed,
|
MuxedStreamClosed,
|
||||||
MuxedStreamEOF,
|
MuxedStreamEOF,
|
||||||
MuxedStreamReset,
|
MuxedStreamReset,
|
||||||
@ -11,7 +11,7 @@ class MplexError(MuxedConnError):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class MplexShutdown(MuxedConnShutdown):
|
class MplexUnavailable(MuxedConnUnavailable):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,9 +1,10 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from typing import Any # noqa: F401
|
from typing import Any # noqa: F401
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Awaitable, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
from libp2p.exceptions import ParseError
|
from libp2p.exceptions import ParseError
|
||||||
from libp2p.io.exceptions import IncompleteReadError
|
from libp2p.io.exceptions import IncompleteReadError
|
||||||
|
from libp2p.network.connection.exceptions import RawConnError
|
||||||
from libp2p.peer.id import ID
|
from libp2p.peer.id import ID
|
||||||
from libp2p.security.secure_conn_interface import ISecureConn
|
from libp2p.security.secure_conn_interface import ISecureConn
|
||||||
from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream
|
from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream
|
||||||
@ -17,6 +18,7 @@ from libp2p.utils import (
|
|||||||
|
|
||||||
from .constants import HeaderTags
|
from .constants import HeaderTags
|
||||||
from .datastructures import StreamID
|
from .datastructures import StreamID
|
||||||
|
from .exceptions import MplexUnavailable
|
||||||
from .mplex_stream import MplexStream
|
from .mplex_stream import MplexStream
|
||||||
|
|
||||||
MPLEX_PROTOCOL_ID = TProtocol("/mplex/6.7.0")
|
MPLEX_PROTOCOL_ID = TProtocol("/mplex/6.7.0")
|
||||||
@ -36,7 +38,8 @@ class Mplex(IMuxedConn):
|
|||||||
streams: Dict[StreamID, MplexStream]
|
streams: Dict[StreamID, MplexStream]
|
||||||
streams_lock: asyncio.Lock
|
streams_lock: asyncio.Lock
|
||||||
new_stream_queue: "asyncio.Queue[IMuxedStream]"
|
new_stream_queue: "asyncio.Queue[IMuxedStream]"
|
||||||
shutdown: asyncio.Event
|
event_shutting_down: asyncio.Event
|
||||||
|
event_closed: asyncio.Event
|
||||||
|
|
||||||
_tasks: List["asyncio.Future[Any]"]
|
_tasks: List["asyncio.Future[Any]"]
|
||||||
|
|
||||||
@ -60,7 +63,8 @@ class Mplex(IMuxedConn):
|
|||||||
self.streams = {}
|
self.streams = {}
|
||||||
self.streams_lock = asyncio.Lock()
|
self.streams_lock = asyncio.Lock()
|
||||||
self.new_stream_queue = asyncio.Queue()
|
self.new_stream_queue = asyncio.Queue()
|
||||||
self.shutdown = asyncio.Event()
|
self.event_shutting_down = asyncio.Event()
|
||||||
|
self.event_closed = asyncio.Event()
|
||||||
|
|
||||||
self._tasks = []
|
self._tasks = []
|
||||||
|
|
||||||
@ -75,16 +79,20 @@ class Mplex(IMuxedConn):
|
|||||||
"""
|
"""
|
||||||
close the stream muxer and underlying secured connection
|
close the stream muxer and underlying secured connection
|
||||||
"""
|
"""
|
||||||
for task in self._tasks:
|
if self.event_shutting_down.is_set():
|
||||||
task.cancel()
|
return
|
||||||
|
# Set the `event_shutting_down`, to allow graceful shutdown.
|
||||||
|
self.event_shutting_down.set()
|
||||||
await self.secured_conn.close()
|
await self.secured_conn.close()
|
||||||
|
# Blocked until `close` is finally set.
|
||||||
|
await self.event_closed.wait()
|
||||||
|
|
||||||
def is_closed(self) -> bool:
|
def is_closed(self) -> bool:
|
||||||
"""
|
"""
|
||||||
check connection is fully closed
|
check connection is fully closed
|
||||||
:return: true if successful
|
:return: true if successful
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
return self.event_closed.is_set()
|
||||||
|
|
||||||
def _get_next_channel_id(self) -> int:
|
def _get_next_channel_id(self) -> int:
|
||||||
"""
|
"""
|
||||||
@ -114,11 +122,29 @@ class Mplex(IMuxedConn):
|
|||||||
await self.send_message(HeaderTags.NewStream, name.encode(), stream_id)
|
await self.send_message(HeaderTags.NewStream, name.encode(), stream_id)
|
||||||
return stream
|
return stream
|
||||||
|
|
||||||
|
async def _wait_until_shutting_down_or_closed(self, coro: Awaitable[Any]) -> Any:
|
||||||
|
task_coro = asyncio.ensure_future(coro)
|
||||||
|
task_wait_closed = asyncio.ensure_future(self.event_closed.wait())
|
||||||
|
task_wait_shutting_down = asyncio.ensure_future(self.event_shutting_down.wait())
|
||||||
|
done, pending = await asyncio.wait(
|
||||||
|
[task_coro, task_wait_closed, task_wait_shutting_down],
|
||||||
|
return_when=asyncio.FIRST_COMPLETED,
|
||||||
|
)
|
||||||
|
for fut in pending:
|
||||||
|
fut.cancel()
|
||||||
|
if task_wait_closed in done:
|
||||||
|
raise MplexUnavailable("Mplex is closed")
|
||||||
|
if task_wait_shutting_down in done:
|
||||||
|
raise MplexUnavailable("Mplex is shutting down")
|
||||||
|
return task_coro.result()
|
||||||
|
|
||||||
async def accept_stream(self) -> IMuxedStream:
|
async def accept_stream(self) -> IMuxedStream:
|
||||||
"""
|
"""
|
||||||
accepts a muxed stream opened by the other end
|
accepts a muxed stream opened by the other end
|
||||||
"""
|
"""
|
||||||
return await self.new_stream_queue.get()
|
return await self._wait_until_shutting_down_or_closed(
|
||||||
|
self.new_stream_queue.get()
|
||||||
|
)
|
||||||
|
|
||||||
async def send_message(
|
async def send_message(
|
||||||
self, flag: HeaderTags, data: Optional[bytes], stream_id: StreamID
|
self, flag: HeaderTags, data: Optional[bytes], stream_id: StreamID
|
||||||
@ -137,7 +163,9 @@ class Mplex(IMuxedConn):
|
|||||||
|
|
||||||
_bytes = header + encode_varint_prefixed(data)
|
_bytes = header + encode_varint_prefixed(data)
|
||||||
|
|
||||||
return await self.write_to_stream(_bytes)
|
return await self._wait_until_shutting_down_or_closed(
|
||||||
|
self.write_to_stream(_bytes)
|
||||||
|
)
|
||||||
|
|
||||||
async def write_to_stream(self, _bytes: bytes) -> int:
|
async def write_to_stream(self, _bytes: bytes) -> int:
|
||||||
"""
|
"""
|
||||||
@ -152,55 +180,112 @@ class Mplex(IMuxedConn):
|
|||||||
"""
|
"""
|
||||||
Read a message off of the secured connection and add it to the corresponding message buffer
|
Read a message off of the secured connection and add it to the corresponding message buffer
|
||||||
"""
|
"""
|
||||||
# TODO Deal with other types of messages using flag (currently _)
|
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
channel_id, flag, message = await self.read_message()
|
try:
|
||||||
if channel_id is not None and flag is not None and message is not None:
|
await self._handle_incoming_message()
|
||||||
stream_id = StreamID(channel_id=channel_id, is_initiator=bool(flag & 1))
|
except MplexUnavailable:
|
||||||
is_stream_id_seen: bool
|
break
|
||||||
stream: MplexStream
|
# Force context switch
|
||||||
async with self.streams_lock:
|
await asyncio.sleep(0)
|
||||||
is_stream_id_seen = stream_id in self.streams
|
# If we enter here, it means this connection is shutting down.
|
||||||
if is_stream_id_seen:
|
# We should clean things up.
|
||||||
stream = self.streams[stream_id]
|
await self._cleanup()
|
||||||
# Other consequent stream message should wait until the stream get accepted
|
|
||||||
# TODO: Handle more tags, and refactor `HeaderTags`
|
async def read_message(self) -> Tuple[int, int, bytes]:
|
||||||
if flag == HeaderTags.NewStream.value:
|
"""
|
||||||
if is_stream_id_seen:
|
Read a single message off of the secured connection
|
||||||
# `NewStream` for the same id is received twice...
|
:return: stream_id, flag, message contents
|
||||||
# TODO: Shutdown
|
"""
|
||||||
pass
|
|
||||||
mplex_stream = await self._initialize_stream(
|
# FIXME: No timeout is used in Go implementation.
|
||||||
stream_id, message.decode()
|
try:
|
||||||
|
header = await decode_uvarint_from_stream(self.secured_conn)
|
||||||
|
message = await asyncio.wait_for(
|
||||||
|
read_varint_prefixed_bytes(self.secured_conn), timeout=5
|
||||||
)
|
)
|
||||||
# TODO: Check if `self` is shutdown.
|
except (ParseError, RawConnError, IncompleteReadError) as error:
|
||||||
await self.new_stream_queue.put(mplex_stream)
|
raise MplexUnavailable(
|
||||||
|
"failed to read messages correctly from the underlying connection"
|
||||||
|
) from error
|
||||||
|
except asyncio.TimeoutError as error:
|
||||||
|
raise MplexUnavailable(
|
||||||
|
"failed to read more message body within the timeout"
|
||||||
|
) from error
|
||||||
|
|
||||||
|
flag = header & 0x07
|
||||||
|
channel_id = header >> 3
|
||||||
|
|
||||||
|
return channel_id, flag, message
|
||||||
|
|
||||||
|
async def _handle_incoming_message(self) -> None:
|
||||||
|
"""
|
||||||
|
Read and handle a new incoming message.
|
||||||
|
:raise MplexUnavailable: `Mplex` encounters fatal error or is shutting down.
|
||||||
|
"""
|
||||||
|
channel_id, flag, message = await self._wait_until_shutting_down_or_closed(
|
||||||
|
self.read_message()
|
||||||
|
)
|
||||||
|
stream_id = StreamID(channel_id=channel_id, is_initiator=bool(flag & 1))
|
||||||
|
|
||||||
|
if flag == HeaderTags.NewStream.value:
|
||||||
|
await self._handle_new_stream(stream_id, message)
|
||||||
elif flag in (
|
elif flag in (
|
||||||
HeaderTags.MessageInitiator.value,
|
HeaderTags.MessageInitiator.value,
|
||||||
HeaderTags.MessageReceiver.value,
|
HeaderTags.MessageReceiver.value,
|
||||||
):
|
):
|
||||||
if not is_stream_id_seen:
|
await self._handle_message(stream_id, message)
|
||||||
|
elif flag in (HeaderTags.CloseInitiator.value, HeaderTags.CloseReceiver.value):
|
||||||
|
await self._handle_close(stream_id)
|
||||||
|
elif flag in (HeaderTags.ResetInitiator.value, HeaderTags.ResetReceiver.value):
|
||||||
|
await self._handle_reset(stream_id)
|
||||||
|
else:
|
||||||
|
# Receives messages with an unknown flag
|
||||||
|
# TODO: logging
|
||||||
|
async with self.streams_lock:
|
||||||
|
if stream_id in self.streams:
|
||||||
|
stream = self.streams[stream_id]
|
||||||
|
await stream.reset()
|
||||||
|
|
||||||
|
async def _handle_new_stream(self, stream_id: StreamID, message: bytes) -> None:
|
||||||
|
async with self.streams_lock:
|
||||||
|
if stream_id in self.streams:
|
||||||
|
# `NewStream` for the same id is received twice...
|
||||||
|
raise MplexUnavailable(
|
||||||
|
f"received NewStream message for existing stream: {stream_id}"
|
||||||
|
)
|
||||||
|
mplex_stream = await self._initialize_stream(stream_id, message.decode())
|
||||||
|
await self._wait_until_shutting_down_or_closed(
|
||||||
|
self.new_stream_queue.put(mplex_stream)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _handle_message(self, stream_id: StreamID, message: bytes) -> None:
|
||||||
|
async with self.streams_lock:
|
||||||
|
if stream_id not in self.streams:
|
||||||
# We receive a message of the stream `stream_id` which is not accepted
|
# We receive a message of the stream `stream_id` which is not accepted
|
||||||
# before. It is abnormal. Possibly disconnect?
|
# before. It is abnormal. Possibly disconnect?
|
||||||
# TODO: Warn and emit logs about this.
|
# TODO: Warn and emit logs about this.
|
||||||
continue
|
return
|
||||||
|
stream = self.streams[stream_id]
|
||||||
async with stream.close_lock:
|
async with stream.close_lock:
|
||||||
if stream.event_remote_closed.is_set():
|
if stream.event_remote_closed.is_set():
|
||||||
# TODO: Warn "Received data from remote after stream was closed by them. (len = %d)" # noqa: E501
|
# TODO: Warn "Received data from remote after stream was closed by them. (len = %d)" # noqa: E501
|
||||||
continue
|
return
|
||||||
await stream.incoming_data.put(message)
|
await self._wait_until_shutting_down_or_closed(
|
||||||
elif flag in (
|
stream.incoming_data.put(message)
|
||||||
HeaderTags.CloseInitiator.value,
|
)
|
||||||
HeaderTags.CloseReceiver.value,
|
|
||||||
):
|
async def _handle_close(self, stream_id: StreamID) -> None:
|
||||||
if not is_stream_id_seen:
|
async with self.streams_lock:
|
||||||
continue
|
if stream_id not in self.streams:
|
||||||
|
# Ignore unmatched messages for now.
|
||||||
|
return
|
||||||
|
stream = self.streams[stream_id]
|
||||||
# NOTE: If remote is already closed, then return: Technically a bug
|
# NOTE: If remote is already closed, then return: Technically a bug
|
||||||
# on the other side. We should consider killing the connection.
|
# on the other side. We should consider killing the connection.
|
||||||
async with stream.close_lock:
|
async with stream.close_lock:
|
||||||
if stream.event_remote_closed.is_set():
|
if stream.event_remote_closed.is_set():
|
||||||
continue
|
return
|
||||||
is_local_closed: bool
|
is_local_closed: bool
|
||||||
async with stream.close_lock:
|
async with stream.close_lock:
|
||||||
stream.event_remote_closed.set()
|
stream.event_remote_closed.set()
|
||||||
@ -210,16 +295,16 @@ class Mplex(IMuxedConn):
|
|||||||
if is_local_closed:
|
if is_local_closed:
|
||||||
async with self.streams_lock:
|
async with self.streams_lock:
|
||||||
del self.streams[stream_id]
|
del self.streams[stream_id]
|
||||||
elif flag in (
|
|
||||||
HeaderTags.ResetInitiator.value,
|
async def _handle_reset(self, stream_id: StreamID) -> None:
|
||||||
HeaderTags.ResetReceiver.value,
|
async with self.streams_lock:
|
||||||
):
|
if stream_id not in self.streams:
|
||||||
if not is_stream_id_seen:
|
|
||||||
# This is *ok*. We forget the stream on reset.
|
# This is *ok*. We forget the stream on reset.
|
||||||
continue
|
return
|
||||||
|
stream = self.streams[stream_id]
|
||||||
|
|
||||||
async with stream.close_lock:
|
async with stream.close_lock:
|
||||||
if not stream.event_remote_closed.is_set():
|
if not stream.event_remote_closed.is_set():
|
||||||
# TODO: Why? Only if remote is not closed before then reset.
|
|
||||||
stream.event_reset.set()
|
stream.event_reset.set()
|
||||||
|
|
||||||
stream.event_remote_closed.set()
|
stream.event_remote_closed.set()
|
||||||
@ -228,36 +313,16 @@ class Mplex(IMuxedConn):
|
|||||||
stream.event_local_closed.set()
|
stream.event_local_closed.set()
|
||||||
async with self.streams_lock:
|
async with self.streams_lock:
|
||||||
del self.streams[stream_id]
|
del self.streams[stream_id]
|
||||||
else:
|
|
||||||
# TODO: logging
|
|
||||||
if is_stream_id_seen:
|
|
||||||
await stream.reset()
|
|
||||||
|
|
||||||
# Force context switch
|
async def _cleanup(self) -> None:
|
||||||
await asyncio.sleep(0)
|
if not self.event_shutting_down.is_set():
|
||||||
|
self.event_shutting_down.set()
|
||||||
async def read_message(self) -> Tuple[int, int, bytes]:
|
async with self.streams_lock:
|
||||||
"""
|
for stream in self.streams.values():
|
||||||
Read a single message off of the secured connection
|
async with stream.close_lock:
|
||||||
:return: stream_id, flag, message contents
|
if not stream.event_remote_closed.is_set():
|
||||||
"""
|
stream.event_remote_closed.set()
|
||||||
|
stream.event_reset.set()
|
||||||
# FIXME: No timeout is used in Go implementation.
|
stream.event_local_closed.set()
|
||||||
# Timeout is set to a relatively small value to alleviate wait time to exit
|
self.streams = None
|
||||||
# loop in handle_incoming
|
self.event_closed.set()
|
||||||
try:
|
|
||||||
header = await decode_uvarint_from_stream(self.secured_conn)
|
|
||||||
except ParseError:
|
|
||||||
return None, None, None
|
|
||||||
try:
|
|
||||||
message = await asyncio.wait_for(
|
|
||||||
read_varint_prefixed_bytes(self.secured_conn), timeout=5
|
|
||||||
)
|
|
||||||
except (ParseError, IncompleteReadError, asyncio.TimeoutError):
|
|
||||||
# TODO: Investigate what we should do if time is out.
|
|
||||||
return None, None, None
|
|
||||||
|
|
||||||
flag = header & 0x07
|
|
||||||
channel_id = header >> 3
|
|
||||||
|
|
||||||
return channel_id, flag, message
|
|
||||||
|
@ -204,6 +204,10 @@ class MplexStream(IMuxedStream):
|
|||||||
self.event_remote_closed.set()
|
self.event_remote_closed.set()
|
||||||
|
|
||||||
async with self.mplex_conn.streams_lock:
|
async with self.mplex_conn.streams_lock:
|
||||||
|
if (
|
||||||
|
self.mplex_conn.streams is not None
|
||||||
|
and self.stream_id in self.mplex_conn.streams
|
||||||
|
):
|
||||||
del self.mplex_conn.streams[self.stream_id]
|
del self.mplex_conn.streams[self.stream_id]
|
||||||
|
|
||||||
# TODO deadline not in use
|
# TODO deadline not in use
|
||||||
|
@ -41,14 +41,8 @@ async def decode_uvarint_from_stream(reader: Reader) -> int:
|
|||||||
if shift > SHIFT_64_BIT_MAX:
|
if shift > SHIFT_64_BIT_MAX:
|
||||||
raise ParseError("TODO: better exception msg: Integer is too large...")
|
raise ParseError("TODO: better exception msg: Integer is too large...")
|
||||||
|
|
||||||
byte = await reader.read(1)
|
byte = await read_exactly(reader, 1)
|
||||||
|
|
||||||
try:
|
|
||||||
value = byte[0]
|
value = byte[0]
|
||||||
except IndexError:
|
|
||||||
raise ParseError(
|
|
||||||
"Unexpected end of stream while parsing LEB128 encoded integer"
|
|
||||||
)
|
|
||||||
|
|
||||||
res += (value & LOW_MASK) << shift
|
res += (value & LOW_MASK) << shift
|
||||||
|
|
||||||
|
@ -6,15 +6,14 @@ import factory
|
|||||||
from libp2p import generate_new_rsa_identity, initialize_default_swarm
|
from libp2p import generate_new_rsa_identity, initialize_default_swarm
|
||||||
from libp2p.crypto.keys import KeyPair
|
from libp2p.crypto.keys import KeyPair
|
||||||
from libp2p.host.basic_host import BasicHost
|
from libp2p.host.basic_host import BasicHost
|
||||||
from libp2p.host.host_interface import IHost
|
|
||||||
from libp2p.network.stream.net_stream_interface import INetStream
|
from libp2p.network.stream.net_stream_interface import INetStream
|
||||||
|
from libp2p.network.swarm import Swarm
|
||||||
from libp2p.pubsub.floodsub import FloodSub
|
from libp2p.pubsub.floodsub import FloodSub
|
||||||
from libp2p.pubsub.gossipsub import GossipSub
|
from libp2p.pubsub.gossipsub import GossipSub
|
||||||
from libp2p.pubsub.pubsub import Pubsub
|
from libp2p.pubsub.pubsub import Pubsub
|
||||||
from libp2p.security.base_transport import BaseSecureTransport
|
from libp2p.security.base_transport import BaseSecureTransport
|
||||||
from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport
|
from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport
|
||||||
import libp2p.security.secio.transport as secio
|
import libp2p.security.secio.transport as secio
|
||||||
from libp2p.stream_muxer.mplex.mplex import Mplex
|
|
||||||
from libp2p.typing import TProtocol
|
from libp2p.typing import TProtocol
|
||||||
from tests.configs import LISTEN_MADDR
|
from tests.configs import LISTEN_MADDR
|
||||||
from tests.pubsub.configs import (
|
from tests.pubsub.configs import (
|
||||||
@ -22,7 +21,7 @@ from tests.pubsub.configs import (
|
|||||||
GOSSIPSUB_PARAMS,
|
GOSSIPSUB_PARAMS,
|
||||||
GOSSIPSUB_PROTOCOL_ID,
|
GOSSIPSUB_PROTOCOL_ID,
|
||||||
)
|
)
|
||||||
from tests.utils import connect
|
from tests.utils import connect, connect_swarm
|
||||||
|
|
||||||
|
|
||||||
def security_transport_factory(
|
def security_transport_factory(
|
||||||
@ -34,12 +33,31 @@ def security_transport_factory(
|
|||||||
return {secio.ID: secio.Transport(key_pair)}
|
return {secio.ID: secio.Transport(key_pair)}
|
||||||
|
|
||||||
|
|
||||||
def swarm_factory(is_secure: bool):
|
def SwarmFactory(is_secure: bool) -> Swarm:
|
||||||
key_pair = generate_new_rsa_identity()
|
key_pair = generate_new_rsa_identity()
|
||||||
sec_opt = security_transport_factory(is_secure, key_pair)
|
sec_opt = security_transport_factory(False, key_pair)
|
||||||
return initialize_default_swarm(key_pair, sec_opt=sec_opt)
|
return initialize_default_swarm(key_pair, sec_opt=sec_opt)
|
||||||
|
|
||||||
|
|
||||||
|
class ListeningSwarmFactory(factory.Factory):
|
||||||
|
class Meta:
|
||||||
|
model = Swarm
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def create_and_listen(cls, is_secure: bool) -> Swarm:
|
||||||
|
swarm = SwarmFactory(is_secure)
|
||||||
|
await swarm.listen(LISTEN_MADDR)
|
||||||
|
return swarm
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def create_batch_and_listen(
|
||||||
|
cls, is_secure: bool, number: int
|
||||||
|
) -> Tuple[Swarm, ...]:
|
||||||
|
return await asyncio.gather(
|
||||||
|
*[cls.create_and_listen(is_secure) for _ in range(number)]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class HostFactory(factory.Factory):
|
class HostFactory(factory.Factory):
|
||||||
class Meta:
|
class Meta:
|
||||||
model = BasicHost
|
model = BasicHost
|
||||||
@ -47,13 +65,19 @@ class HostFactory(factory.Factory):
|
|||||||
class Params:
|
class Params:
|
||||||
is_secure = False
|
is_secure = False
|
||||||
|
|
||||||
network = factory.LazyAttribute(lambda o: swarm_factory(o.is_secure))
|
network = factory.LazyAttribute(lambda o: SwarmFactory(o.is_secure))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def create_and_listen(cls) -> IHost:
|
async def create_and_listen(cls, is_secure: bool) -> BasicHost:
|
||||||
host = cls()
|
swarms = await ListeningSwarmFactory.create_batch_and_listen(is_secure, 1)
|
||||||
await host.get_network().listen(LISTEN_MADDR)
|
return BasicHost(swarms[0])
|
||||||
return host
|
|
||||||
|
@classmethod
|
||||||
|
async def create_batch_and_listen(
|
||||||
|
cls, is_secure: bool, number: int
|
||||||
|
) -> Tuple[BasicHost, ...]:
|
||||||
|
swarms = await ListeningSwarmFactory.create_batch_and_listen(is_secure, number)
|
||||||
|
return tuple(BasicHost(swarm) for swarm in range(swarms))
|
||||||
|
|
||||||
|
|
||||||
class FloodsubFactory(factory.Factory):
|
class FloodsubFactory(factory.Factory):
|
||||||
@ -87,24 +111,33 @@ class PubsubFactory(factory.Factory):
|
|||||||
cache_size = None
|
cache_size = None
|
||||||
|
|
||||||
|
|
||||||
async def host_pair_factory() -> Tuple[BasicHost, BasicHost]:
|
async def swarm_pair_factory(is_secure: bool) -> Tuple[Swarm, Swarm]:
|
||||||
|
swarms = await ListeningSwarmFactory.create_batch_and_listen(is_secure, 2)
|
||||||
|
await connect_swarm(swarms[0], swarms[1])
|
||||||
|
return swarms[0], swarms[1]
|
||||||
|
|
||||||
|
|
||||||
|
async def host_pair_factory(is_secure) -> Tuple[BasicHost, BasicHost]:
|
||||||
hosts = await asyncio.gather(
|
hosts = await asyncio.gather(
|
||||||
*[HostFactory.create_and_listen(), HostFactory.create_and_listen()]
|
*[
|
||||||
|
HostFactory.create_and_listen(is_secure),
|
||||||
|
HostFactory.create_and_listen(is_secure),
|
||||||
|
]
|
||||||
)
|
)
|
||||||
await connect(hosts[0], hosts[1])
|
await connect(hosts[0], hosts[1])
|
||||||
return hosts[0], hosts[1]
|
return hosts[0], hosts[1]
|
||||||
|
|
||||||
|
|
||||||
async def connection_pair_factory() -> Tuple[Mplex, BasicHost, Mplex, BasicHost]:
|
# async def connection_pair_factory() -> Tuple[Mplex, BasicHost, Mplex, BasicHost]:
|
||||||
host_0, host_1 = await host_pair_factory()
|
# host_0, host_1 = await host_pair_factory()
|
||||||
mplex_conn_0 = host_0.get_network().connections[host_1.get_id()]
|
# mplex_conn_0 = host_0.get_network().connections[host_1.get_id()]
|
||||||
mplex_conn_1 = host_1.get_network().connections[host_0.get_id()]
|
# mplex_conn_1 = host_1.get_network().connections[host_0.get_id()]
|
||||||
return mplex_conn_0, host_0, mplex_conn_1, host_1
|
# return mplex_conn_0, host_0, mplex_conn_1, host_1
|
||||||
|
|
||||||
|
|
||||||
async def net_stream_pair_factory() -> Tuple[
|
async def net_stream_pair_factory(
|
||||||
INetStream, BasicHost, INetStream, BasicHost
|
is_secure: bool
|
||||||
]:
|
) -> Tuple[INetStream, BasicHost, INetStream, BasicHost]:
|
||||||
protocol_id = "/example/id/1"
|
protocol_id = "/example/id/1"
|
||||||
|
|
||||||
stream_1: INetStream
|
stream_1: INetStream
|
||||||
@ -114,7 +147,7 @@ async def net_stream_pair_factory() -> Tuple[
|
|||||||
nonlocal stream_1
|
nonlocal stream_1
|
||||||
stream_1 = stream
|
stream_1 = stream
|
||||||
|
|
||||||
host_0, host_1 = await host_pair_factory()
|
host_0, host_1 = await host_pair_factory(is_secure)
|
||||||
host_1.set_stream_handler(protocol_id, handler)
|
host_1.set_stream_handler(protocol_id, handler)
|
||||||
|
|
||||||
stream_0 = await host_0.new_stream(host_1.get_id(), [protocol_id])
|
stream_0 = await host_0.new_stream(host_1.get_id(), [protocol_id])
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
import asyncio
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from .utils import connect
|
from .utils import connect
|
||||||
@ -21,4 +23,5 @@ async def test_connect(hosts, p2pds):
|
|||||||
# Test: `disconnect` from Go
|
# Test: `disconnect` from Go
|
||||||
await p2pd.control.disconnect(host.get_id())
|
await p2pd.control.disconnect(host.get_id())
|
||||||
# FIXME: Failed to handle disconnect
|
# FIXME: Failed to handle disconnect
|
||||||
# assert len(host.get_network().connections) == 0
|
await asyncio.sleep(0.01)
|
||||||
|
assert len(host.get_network().connections) == 0
|
||||||
|
@ -2,13 +2,22 @@ import asyncio
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from tests.factories import net_stream_pair_factory
|
from tests.factories import net_stream_pair_factory, swarm_pair_factory
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def net_stream_pair():
|
async def net_stream_pair(is_host_secure):
|
||||||
stream_0, host_0, stream_1, host_1 = await net_stream_pair_factory()
|
stream_0, host_0, stream_1, host_1 = await net_stream_pair_factory(is_host_secure)
|
||||||
try:
|
try:
|
||||||
yield stream_0, stream_1
|
yield stream_0, stream_1
|
||||||
finally:
|
finally:
|
||||||
await asyncio.gather(*[host_0.close(), host_1.close()])
|
await asyncio.gather(*[host_0.close(), host_1.close()])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def swarm_pair(is_host_secure):
|
||||||
|
swarm_0, swarm_1 = await swarm_pair_factory(is_host_secure)
|
||||||
|
try:
|
||||||
|
yield swarm_0, swarm_1
|
||||||
|
finally:
|
||||||
|
await asyncio.gather(*[swarm_0.close(), swarm_1.close()])
|
||||||
|
93
tests/network/test_swarm.py
Normal file
93
tests/network/test_swarm.py
Normal file
@ -0,0 +1,93 @@
|
|||||||
|
import asyncio
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from libp2p.network.exceptions import SwarmException
|
||||||
|
from tests.factories import ListeningSwarmFactory
|
||||||
|
from tests.utils import connect_swarm
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_swarm_dial_peer(is_host_secure):
|
||||||
|
swarms = await ListeningSwarmFactory.create_batch_and_listen(is_host_secure, 3)
|
||||||
|
# Test: No addr found.
|
||||||
|
with pytest.raises(SwarmException):
|
||||||
|
await swarms[0].dial_peer(swarms[1].get_peer_id())
|
||||||
|
|
||||||
|
# Test: len(addr) in the peerstore is 0.
|
||||||
|
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), [], 10000)
|
||||||
|
with pytest.raises(SwarmException):
|
||||||
|
await swarms[0].dial_peer(swarms[1].get_peer_id())
|
||||||
|
|
||||||
|
# Test: Succeed if addrs of the peer_id are present in the peerstore.
|
||||||
|
addrs = tuple(
|
||||||
|
addr
|
||||||
|
for transport in swarms[1].listeners.values()
|
||||||
|
for addr in transport.get_addrs()
|
||||||
|
)
|
||||||
|
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs, 10000)
|
||||||
|
await swarms[0].dial_peer(swarms[1].get_peer_id())
|
||||||
|
assert swarms[0].get_peer_id() in swarms[1].connections
|
||||||
|
assert swarms[1].get_peer_id() in swarms[0].connections
|
||||||
|
|
||||||
|
# Test: Reuse connections when we already have ones with a peer.
|
||||||
|
conn_to_1 = swarms[0].connections[swarms[1].get_peer_id()]
|
||||||
|
conn = await swarms[0].dial_peer(swarms[1].get_peer_id())
|
||||||
|
assert conn is conn_to_1
|
||||||
|
|
||||||
|
# Clean up
|
||||||
|
await asyncio.gather(*[swarm.close() for swarm in swarms])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_swarm_close_peer(is_host_secure):
|
||||||
|
swarms = await ListeningSwarmFactory.create_batch_and_listen(is_host_secure, 3)
|
||||||
|
# 0 <> 1 <> 2
|
||||||
|
await connect_swarm(swarms[0], swarms[1])
|
||||||
|
await connect_swarm(swarms[1], swarms[2])
|
||||||
|
|
||||||
|
# peer 1 closes peer 0
|
||||||
|
await swarms[1].close_peer(swarms[0].get_peer_id())
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
# 0 1 <> 2
|
||||||
|
assert len(swarms[0].connections) == 0
|
||||||
|
assert (
|
||||||
|
len(swarms[1].connections) == 1
|
||||||
|
and swarms[2].get_peer_id() in swarms[1].connections
|
||||||
|
)
|
||||||
|
|
||||||
|
# peer 1 is closed by peer 2
|
||||||
|
await swarms[2].close_peer(swarms[1].get_peer_id())
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
# 0 1 2
|
||||||
|
assert len(swarms[1].connections) == 0 and len(swarms[2].connections) == 0
|
||||||
|
|
||||||
|
await connect_swarm(swarms[0], swarms[1])
|
||||||
|
# 0 <> 1 2
|
||||||
|
assert (
|
||||||
|
len(swarms[0].connections) == 1
|
||||||
|
and swarms[1].get_peer_id() in swarms[0].connections
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
len(swarms[1].connections) == 1
|
||||||
|
and swarms[0].get_peer_id() in swarms[1].connections
|
||||||
|
)
|
||||||
|
# peer 0 closes peer 1
|
||||||
|
await swarms[0].close_peer(swarms[1].get_peer_id())
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
# 0 1 2
|
||||||
|
assert len(swarms[1].connections) == 0 and len(swarms[2].connections) == 0
|
||||||
|
|
||||||
|
# Clean up
|
||||||
|
await asyncio.gather(*[swarm.close() for swarm in swarms])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_swarm_remove_conn(swarm_pair):
|
||||||
|
swarm_0, swarm_1 = swarm_pair
|
||||||
|
conn_0 = swarm_0.connections[swarm_1.get_peer_id()]
|
||||||
|
swarm_0.remove_conn(conn_0)
|
||||||
|
assert swarm_1.get_peer_id() not in swarm_0.connections
|
||||||
|
# Test: Remove twice. There should not be errors.
|
||||||
|
swarm_0.remove_conn(conn_0)
|
||||||
|
assert swarm_1.get_peer_id() not in swarm_0.connections
|
@ -5,6 +5,19 @@ from libp2p.peer.peerinfo import info_from_p2p_addr
|
|||||||
from tests.constants import MAX_READ_LEN
|
from tests.constants import MAX_READ_LEN
|
||||||
|
|
||||||
|
|
||||||
|
async def connect_swarm(swarm_0, swarm_1):
|
||||||
|
peer_id = swarm_1.get_peer_id()
|
||||||
|
addrs = tuple(
|
||||||
|
addr
|
||||||
|
for transport in swarm_1.listeners.values()
|
||||||
|
for addr in transport.get_addrs()
|
||||||
|
)
|
||||||
|
swarm_0.peerstore.add_addrs(peer_id, addrs, 10000)
|
||||||
|
await swarm_0.dial_peer(peer_id)
|
||||||
|
assert swarm_0.get_peer_id() in swarm_1.connections
|
||||||
|
assert swarm_1.get_peer_id() in swarm_0.connections
|
||||||
|
|
||||||
|
|
||||||
async def connect(node1, node2):
|
async def connect(node1, node2):
|
||||||
"""
|
"""
|
||||||
Connect node1 to node2
|
Connect node1 to node2
|
||||||
|
Loading…
x
Reference in New Issue
Block a user