Merge ef2e2487af9ccc8260bfcb8a8ecf51120a482847 into 1f881e04648f296e4eb89450ecd8333438c3d2d3
This commit is contained in:
commit
a851a11a4a
@ -1,7 +1,13 @@
|
|||||||
import logging
|
import logging
|
||||||
|
import math
|
||||||
|
import secrets
|
||||||
|
import time
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
import trio
|
import trio
|
||||||
|
|
||||||
|
from libp2p.exceptions import ValidationError
|
||||||
|
from libp2p.host.host_interface import IHost
|
||||||
from libp2p.network.stream.exceptions import StreamClosed, StreamEOF, StreamReset
|
from libp2p.network.stream.exceptions import StreamClosed, StreamEOF, StreamReset
|
||||||
from libp2p.network.stream.net_stream_interface import INetStream
|
from libp2p.network.stream.net_stream_interface import INetStream
|
||||||
from libp2p.peer.id import ID as PeerID
|
from libp2p.peer.id import ID as PeerID
|
||||||
@ -14,6 +20,21 @@ RESP_TIMEOUT = 60
|
|||||||
logger = logging.getLogger("libp2p.host.ping")
|
logger = logging.getLogger("libp2p.host.ping")
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_ping(stream: INetStream) -> None:
|
||||||
|
"""``handle_ping`` responds to incoming ping requests until one side errors
|
||||||
|
or closes the ``stream``."""
|
||||||
|
peer_id = stream.muxed_conn.peer_id
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
should_continue = await _handle_ping(stream, peer_id)
|
||||||
|
if not should_continue:
|
||||||
|
return
|
||||||
|
except Exception:
|
||||||
|
await stream.reset()
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
async def _handle_ping(stream: INetStream, peer_id: PeerID) -> bool:
|
async def _handle_ping(stream: INetStream, peer_id: PeerID) -> bool:
|
||||||
"""Return a boolean indicating if we expect more pings from the peer at
|
"""Return a boolean indicating if we expect more pings from the peer at
|
||||||
``peer_id``."""
|
``peer_id``."""
|
||||||
@ -45,16 +66,65 @@ async def _handle_ping(stream: INetStream, peer_id: PeerID) -> bool:
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
async def handle_ping(stream: INetStream) -> None:
|
class PingService:
|
||||||
"""``handle_ping`` responds to incoming ping requests until one side errors
|
"""PingService executes pings and returns RTT in miliseconds."""
|
||||||
or closes the ``stream``."""
|
|
||||||
peer_id = stream.muxed_conn.peer_id
|
|
||||||
|
|
||||||
while True:
|
def __init__(self, host: IHost):
|
||||||
|
self._host = host
|
||||||
|
|
||||||
|
async def ping(self, peer_id: PeerID) -> int:
|
||||||
|
stream = await self._host.new_stream(peer_id, (ID,))
|
||||||
try:
|
try:
|
||||||
should_continue = await _handle_ping(stream, peer_id)
|
rtt = await _ping(stream)
|
||||||
if not should_continue:
|
await _close_stream(stream)
|
||||||
return
|
return rtt
|
||||||
except Exception:
|
except Exception:
|
||||||
await stream.reset()
|
await _close_stream(stream)
|
||||||
return
|
raise
|
||||||
|
|
||||||
|
async def ping_loop(
|
||||||
|
self, peer_id: PeerID, ping_amount: Union[int, float] = math.inf
|
||||||
|
) -> "PingIterator":
|
||||||
|
stream = await self._host.new_stream(peer_id, (ID,))
|
||||||
|
ping_iterator = PingIterator(stream, ping_amount)
|
||||||
|
return ping_iterator
|
||||||
|
|
||||||
|
|
||||||
|
class PingIterator:
|
||||||
|
def __init__(self, stream: INetStream, ping_amount: Union[int, float]):
|
||||||
|
self._stream = stream
|
||||||
|
self._ping_limit = ping_amount
|
||||||
|
self._ping_counter = 0
|
||||||
|
|
||||||
|
def __aiter__(self) -> "PingIterator":
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __anext__(self) -> int:
|
||||||
|
if self._ping_counter > self._ping_limit:
|
||||||
|
await _close_stream(self._stream)
|
||||||
|
raise StopAsyncIteration
|
||||||
|
|
||||||
|
self._ping_counter += 1
|
||||||
|
try:
|
||||||
|
return await _ping(self._stream)
|
||||||
|
except trio.EndOfChannel:
|
||||||
|
await _close_stream(self._stream)
|
||||||
|
raise StopAsyncIteration
|
||||||
|
|
||||||
|
|
||||||
|
async def _ping(stream: INetStream) -> int:
|
||||||
|
ping_bytes = secrets.token_bytes(PING_LENGTH)
|
||||||
|
before = int(time.time() * 10 ** 6) # convert float of seconds to int miliseconds
|
||||||
|
await stream.write(ping_bytes)
|
||||||
|
pong_bytes = await stream.read(PING_LENGTH)
|
||||||
|
rtt = int(time.time() * 10 ** 6) - before
|
||||||
|
if ping_bytes != pong_bytes:
|
||||||
|
raise ValidationError("Invalid PING response")
|
||||||
|
return rtt
|
||||||
|
|
||||||
|
|
||||||
|
async def _close_stream(stream: INetStream) -> None:
|
||||||
|
try:
|
||||||
|
await stream.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
@ -391,7 +391,7 @@ class GossipSub(IPubsubRouter, Service):
|
|||||||
await trio.sleep(self.heartbeat_interval)
|
await trio.sleep(self.heartbeat_interval)
|
||||||
|
|
||||||
def mesh_heartbeat(
|
def mesh_heartbeat(
|
||||||
self
|
self,
|
||||||
) -> Tuple[DefaultDict[ID, List[str]], DefaultDict[ID, List[str]]]:
|
) -> Tuple[DefaultDict[ID, List[str]], DefaultDict[ID, List[str]]]:
|
||||||
peers_to_graft: DefaultDict[ID, List[str]] = defaultdict(list)
|
peers_to_graft: DefaultDict[ID, List[str]] = defaultdict(list)
|
||||||
peers_to_prune: DefaultDict[ID, List[str]] = defaultdict(list)
|
peers_to_prune: DefaultDict[ID, List[str]] = defaultdict(list)
|
||||||
|
@ -84,7 +84,7 @@ def noise_transport_factory() -> NoiseTransport:
|
|||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def raw_conn_factory(
|
async def raw_conn_factory(
|
||||||
nursery: trio.Nursery
|
nursery: trio.Nursery,
|
||||||
) -> AsyncIterator[Tuple[IRawConnection, IRawConnection]]:
|
) -> AsyncIterator[Tuple[IRawConnection, IRawConnection]]:
|
||||||
conn_0 = None
|
conn_0 = None
|
||||||
conn_1 = None
|
conn_1 = None
|
||||||
@ -401,7 +401,7 @@ async def swarm_pair_factory(
|
|||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def host_pair_factory(
|
async def host_pair_factory(
|
||||||
is_secure: bool
|
is_secure: bool,
|
||||||
) -> AsyncIterator[Tuple[BasicHost, BasicHost]]:
|
) -> AsyncIterator[Tuple[BasicHost, BasicHost]]:
|
||||||
async with HostFactory.create_batch_and_listen(is_secure, 2) as hosts:
|
async with HostFactory.create_batch_and_listen(is_secure, 2) as hosts:
|
||||||
await connect(hosts[0], hosts[1])
|
await connect(hosts[0], hosts[1])
|
||||||
@ -420,7 +420,7 @@ async def swarm_conn_pair_factory(
|
|||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def mplex_conn_pair_factory(
|
async def mplex_conn_pair_factory(
|
||||||
is_secure: bool
|
is_secure: bool,
|
||||||
) -> AsyncIterator[Tuple[Mplex, Mplex]]:
|
) -> AsyncIterator[Tuple[Mplex, Mplex]]:
|
||||||
muxer_opt = {MPLEX_PROTOCOL_ID: Mplex}
|
muxer_opt = {MPLEX_PROTOCOL_ID: Mplex}
|
||||||
async with swarm_conn_pair_factory(is_secure, muxer_opt=muxer_opt) as swarm_pair:
|
async with swarm_conn_pair_factory(is_secure, muxer_opt=muxer_opt) as swarm_pair:
|
||||||
@ -432,7 +432,7 @@ async def mplex_conn_pair_factory(
|
|||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def mplex_stream_pair_factory(
|
async def mplex_stream_pair_factory(
|
||||||
is_secure: bool
|
is_secure: bool,
|
||||||
) -> AsyncIterator[Tuple[MplexStream, MplexStream]]:
|
) -> AsyncIterator[Tuple[MplexStream, MplexStream]]:
|
||||||
async with mplex_conn_pair_factory(is_secure) as mplex_conn_pair_info:
|
async with mplex_conn_pair_factory(is_secure) as mplex_conn_pair_info:
|
||||||
mplex_conn_0, mplex_conn_1 = mplex_conn_pair_info
|
mplex_conn_0, mplex_conn_1 = mplex_conn_pair_info
|
||||||
@ -448,7 +448,7 @@ async def mplex_stream_pair_factory(
|
|||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def net_stream_pair_factory(
|
async def net_stream_pair_factory(
|
||||||
is_secure: bool
|
is_secure: bool,
|
||||||
) -> AsyncIterator[Tuple[INetStream, INetStream]]:
|
) -> AsyncIterator[Tuple[INetStream, INetStream]]:
|
||||||
protocol_id = TProtocol("/example/id/1")
|
protocol_id = TProtocol("/example/id/1")
|
||||||
|
|
||||||
|
@ -30,7 +30,7 @@ async def connect(node1: IHost, node2: IHost) -> None:
|
|||||||
|
|
||||||
|
|
||||||
def create_echo_stream_handler(
|
def create_echo_stream_handler(
|
||||||
ack_prefix: str
|
ack_prefix: str,
|
||||||
) -> Callable[[INetStream], Awaitable[None]]:
|
) -> Callable[[INetStream], Awaitable[None]]:
|
||||||
async def echo_stream_handler(stream: INetStream) -> None:
|
async def echo_stream_handler(stream: INetStream) -> None:
|
||||||
while True:
|
while True:
|
||||||
|
@ -3,7 +3,7 @@ import secrets
|
|||||||
import pytest
|
import pytest
|
||||||
import trio
|
import trio
|
||||||
|
|
||||||
from libp2p.host.ping import ID, PING_LENGTH
|
from libp2p.host.ping import ID, PING_LENGTH, PingService
|
||||||
from libp2p.tools.factories import host_pair_factory
|
from libp2p.tools.factories import host_pair_factory
|
||||||
|
|
||||||
|
|
||||||
@ -36,3 +36,32 @@ async def test_ping_several(is_host_secure):
|
|||||||
# NOTE: this interval can be `0` for this test.
|
# NOTE: this interval can be `0` for this test.
|
||||||
await trio.sleep(0)
|
await trio.sleep(0)
|
||||||
await stream.close()
|
await stream.close()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_ping_service_once(is_host_secure):
|
||||||
|
async with host_pair_factory(is_host_secure) as (host_a, host_b):
|
||||||
|
ping_service = PingService(host_b)
|
||||||
|
rtt = await ping_service.ping(host_a.get_id())
|
||||||
|
assert rtt < 10 ** 6 # rtt is in miliseconds
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_ping_service_loop(is_host_secure):
|
||||||
|
async with host_pair_factory(is_host_secure) as (host_a, host_b):
|
||||||
|
ping_service = PingService(host_b)
|
||||||
|
ping_loop = await ping_service.ping_loop(
|
||||||
|
host_a.get_id(), ping_amount=SOME_PING_COUNT
|
||||||
|
)
|
||||||
|
async for rtt in ping_loop:
|
||||||
|
assert rtt < 10 ** 6
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_ping_service_loop_infinite(is_host_secure):
|
||||||
|
async with host_pair_factory(is_host_secure) as (host_a, host_b):
|
||||||
|
ping_service = PingService(host_b)
|
||||||
|
ping_loop = await ping_service.ping_loop(host_a.get_id())
|
||||||
|
with trio.move_on_after(1): # breaking loop after one second
|
||||||
|
async for rtt in ping_loop:
|
||||||
|
assert rtt < 10 ** 6
|
||||||
|
Loading…
x
Reference in New Issue
Block a user