Implemented PING fully-featured
First draft of PingService, which calculates RTT resultspull/409/merge^2
parent
99f505d6d7
commit
ef2e2487af
|
@ -1,7 +1,13 @@
|
|||
import logging
|
||||
import math
|
||||
import secrets
|
||||
import time
|
||||
from typing import Union
|
||||
|
||||
import trio
|
||||
|
||||
from libp2p.exceptions import ValidationError
|
||||
from libp2p.host.host_interface import IHost
|
||||
from libp2p.network.stream.exceptions import StreamClosed, StreamEOF, StreamReset
|
||||
from libp2p.network.stream.net_stream_interface import INetStream
|
||||
from libp2p.peer.id import ID as PeerID
|
||||
|
@ -14,6 +20,21 @@ RESP_TIMEOUT = 60
|
|||
logger = logging.getLogger("libp2p.host.ping")
|
||||
|
||||
|
||||
async def handle_ping(stream: INetStream) -> None:
|
||||
"""``handle_ping`` responds to incoming ping requests until one side errors
|
||||
or closes the ``stream``."""
|
||||
peer_id = stream.muxed_conn.peer_id
|
||||
|
||||
while True:
|
||||
try:
|
||||
should_continue = await _handle_ping(stream, peer_id)
|
||||
if not should_continue:
|
||||
return
|
||||
except Exception:
|
||||
await stream.reset()
|
||||
return
|
||||
|
||||
|
||||
async def _handle_ping(stream: INetStream, peer_id: PeerID) -> bool:
|
||||
"""Return a boolean indicating if we expect more pings from the peer at
|
||||
``peer_id``."""
|
||||
|
@ -45,16 +66,65 @@ async def _handle_ping(stream: INetStream, peer_id: PeerID) -> bool:
|
|||
return True
|
||||
|
||||
|
||||
async def handle_ping(stream: INetStream) -> None:
|
||||
"""``handle_ping`` responds to incoming ping requests until one side errors
|
||||
or closes the ``stream``."""
|
||||
peer_id = stream.muxed_conn.peer_id
|
||||
class PingService:
|
||||
"""PingService executes pings and returns RTT in miliseconds."""
|
||||
|
||||
while True:
|
||||
def __init__(self, host: IHost):
|
||||
self._host = host
|
||||
|
||||
async def ping(self, peer_id: PeerID) -> int:
|
||||
stream = await self._host.new_stream(peer_id, (ID,))
|
||||
try:
|
||||
should_continue = await _handle_ping(stream, peer_id)
|
||||
if not should_continue:
|
||||
return
|
||||
rtt = await _ping(stream)
|
||||
await _close_stream(stream)
|
||||
return rtt
|
||||
except Exception:
|
||||
await stream.reset()
|
||||
return
|
||||
await _close_stream(stream)
|
||||
raise
|
||||
|
||||
async def ping_loop(
|
||||
self, peer_id: PeerID, ping_amount: Union[int, float] = math.inf
|
||||
) -> "PingIterator":
|
||||
stream = await self._host.new_stream(peer_id, (ID,))
|
||||
ping_iterator = PingIterator(stream, ping_amount)
|
||||
return ping_iterator
|
||||
|
||||
|
||||
class PingIterator:
|
||||
def __init__(self, stream: INetStream, ping_amount: Union[int, float]):
|
||||
self._stream = stream
|
||||
self._ping_limit = ping_amount
|
||||
self._ping_counter = 0
|
||||
|
||||
def __aiter__(self) -> "PingIterator":
|
||||
return self
|
||||
|
||||
async def __anext__(self) -> int:
|
||||
if self._ping_counter > self._ping_limit:
|
||||
await _close_stream(self._stream)
|
||||
raise StopAsyncIteration
|
||||
|
||||
self._ping_counter += 1
|
||||
try:
|
||||
return await _ping(self._stream)
|
||||
except trio.EndOfChannel:
|
||||
await _close_stream(self._stream)
|
||||
raise StopAsyncIteration
|
||||
|
||||
|
||||
async def _ping(stream: INetStream) -> int:
|
||||
ping_bytes = secrets.token_bytes(PING_LENGTH)
|
||||
before = int(time.time() * 10 ** 6) # convert float of seconds to int miliseconds
|
||||
await stream.write(ping_bytes)
|
||||
pong_bytes = await stream.read(PING_LENGTH)
|
||||
rtt = int(time.time() * 10 ** 6) - before
|
||||
if ping_bytes != pong_bytes:
|
||||
raise ValidationError("Invalid PING response")
|
||||
return rtt
|
||||
|
||||
|
||||
async def _close_stream(stream: INetStream) -> None:
|
||||
try:
|
||||
await stream.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
|
|
@ -391,7 +391,7 @@ class GossipSub(IPubsubRouter, Service):
|
|||
await trio.sleep(self.heartbeat_interval)
|
||||
|
||||
def mesh_heartbeat(
|
||||
self
|
||||
self,
|
||||
) -> Tuple[DefaultDict[ID, List[str]], DefaultDict[ID, List[str]]]:
|
||||
peers_to_graft: DefaultDict[ID, List[str]] = defaultdict(list)
|
||||
peers_to_prune: DefaultDict[ID, List[str]] = defaultdict(list)
|
||||
|
|
|
@ -67,7 +67,7 @@ def security_transport_factory(
|
|||
|
||||
@asynccontextmanager
|
||||
async def raw_conn_factory(
|
||||
nursery: trio.Nursery
|
||||
nursery: trio.Nursery,
|
||||
) -> AsyncIterator[Tuple[IRawConnection, IRawConnection]]:
|
||||
conn_0 = None
|
||||
conn_1 = None
|
||||
|
@ -351,7 +351,7 @@ async def swarm_pair_factory(
|
|||
|
||||
@asynccontextmanager
|
||||
async def host_pair_factory(
|
||||
is_secure: bool
|
||||
is_secure: bool,
|
||||
) -> AsyncIterator[Tuple[BasicHost, BasicHost]]:
|
||||
async with HostFactory.create_batch_and_listen(is_secure, 2) as hosts:
|
||||
await connect(hosts[0], hosts[1])
|
||||
|
@ -370,7 +370,7 @@ async def swarm_conn_pair_factory(
|
|||
|
||||
@asynccontextmanager
|
||||
async def mplex_conn_pair_factory(
|
||||
is_secure: bool
|
||||
is_secure: bool,
|
||||
) -> AsyncIterator[Tuple[Mplex, Mplex]]:
|
||||
muxer_opt = {MPLEX_PROTOCOL_ID: Mplex}
|
||||
async with swarm_conn_pair_factory(is_secure, muxer_opt=muxer_opt) as swarm_pair:
|
||||
|
@ -382,7 +382,7 @@ async def mplex_conn_pair_factory(
|
|||
|
||||
@asynccontextmanager
|
||||
async def mplex_stream_pair_factory(
|
||||
is_secure: bool
|
||||
is_secure: bool,
|
||||
) -> AsyncIterator[Tuple[MplexStream, MplexStream]]:
|
||||
async with mplex_conn_pair_factory(is_secure) as mplex_conn_pair_info:
|
||||
mplex_conn_0, mplex_conn_1 = mplex_conn_pair_info
|
||||
|
@ -398,7 +398,7 @@ async def mplex_stream_pair_factory(
|
|||
|
||||
@asynccontextmanager
|
||||
async def net_stream_pair_factory(
|
||||
is_secure: bool
|
||||
is_secure: bool,
|
||||
) -> AsyncIterator[Tuple[INetStream, INetStream]]:
|
||||
protocol_id = TProtocol("/example/id/1")
|
||||
|
||||
|
|
|
@ -30,7 +30,7 @@ async def connect(node1: IHost, node2: IHost) -> None:
|
|||
|
||||
|
||||
def create_echo_stream_handler(
|
||||
ack_prefix: str
|
||||
ack_prefix: str,
|
||||
) -> Callable[[INetStream], Awaitable[None]]:
|
||||
async def echo_stream_handler(stream: INetStream) -> None:
|
||||
while True:
|
||||
|
|
|
@ -3,7 +3,7 @@ import secrets
|
|||
import pytest
|
||||
import trio
|
||||
|
||||
from libp2p.host.ping import ID, PING_LENGTH
|
||||
from libp2p.host.ping import ID, PING_LENGTH, PingService
|
||||
from libp2p.tools.factories import host_pair_factory
|
||||
|
||||
|
||||
|
@ -36,3 +36,32 @@ async def test_ping_several(is_host_secure):
|
|||
# NOTE: this interval can be `0` for this test.
|
||||
await trio.sleep(0)
|
||||
await stream.close()
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_ping_service_once(is_host_secure):
|
||||
async with host_pair_factory(is_host_secure) as (host_a, host_b):
|
||||
ping_service = PingService(host_b)
|
||||
rtt = await ping_service.ping(host_a.get_id())
|
||||
assert rtt < 10 ** 6 # rtt is in miliseconds
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_ping_service_loop(is_host_secure):
|
||||
async with host_pair_factory(is_host_secure) as (host_a, host_b):
|
||||
ping_service = PingService(host_b)
|
||||
ping_loop = await ping_service.ping_loop(
|
||||
host_a.get_id(), ping_amount=SOME_PING_COUNT
|
||||
)
|
||||
async for rtt in ping_loop:
|
||||
assert rtt < 10 ** 6
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_ping_service_loop_infinite(is_host_secure):
|
||||
async with host_pair_factory(is_host_secure) as (host_a, host_b):
|
||||
ping_service = PingService(host_b)
|
||||
ping_loop = await ping_service.ping_loop(host_a.get_id())
|
||||
with trio.move_on_after(1): # breaking loop after one second
|
||||
async for rtt in ping_loop:
|
||||
assert rtt < 10 ** 6
|
||||
|
|
Loading…
Reference in New Issue