Implemented PING fully-featured

First draft of PingService, which calculates RTT results
pull/409/merge^2
aratz-lasa 2020-02-27 09:12:09 +01:00
parent 99f505d6d7
commit ef2e2487af
5 changed files with 117 additions and 18 deletions

View File

@ -1,7 +1,13 @@
import logging
import math
import secrets
import time
from typing import Union
import trio
from libp2p.exceptions import ValidationError
from libp2p.host.host_interface import IHost
from libp2p.network.stream.exceptions import StreamClosed, StreamEOF, StreamReset
from libp2p.network.stream.net_stream_interface import INetStream
from libp2p.peer.id import ID as PeerID
@ -14,6 +20,21 @@ RESP_TIMEOUT = 60
logger = logging.getLogger("libp2p.host.ping")
async def handle_ping(stream: INetStream) -> None:
"""``handle_ping`` responds to incoming ping requests until one side errors
or closes the ``stream``."""
peer_id = stream.muxed_conn.peer_id
while True:
try:
should_continue = await _handle_ping(stream, peer_id)
if not should_continue:
return
except Exception:
await stream.reset()
return
async def _handle_ping(stream: INetStream, peer_id: PeerID) -> bool:
"""Return a boolean indicating if we expect more pings from the peer at
``peer_id``."""
@ -45,16 +66,65 @@ async def _handle_ping(stream: INetStream, peer_id: PeerID) -> bool:
return True
async def handle_ping(stream: INetStream) -> None:
"""``handle_ping`` responds to incoming ping requests until one side errors
or closes the ``stream``."""
peer_id = stream.muxed_conn.peer_id
class PingService:
"""PingService executes pings and returns RTT in miliseconds."""
while True:
def __init__(self, host: IHost):
self._host = host
async def ping(self, peer_id: PeerID) -> int:
stream = await self._host.new_stream(peer_id, (ID,))
try:
should_continue = await _handle_ping(stream, peer_id)
if not should_continue:
return
rtt = await _ping(stream)
await _close_stream(stream)
return rtt
except Exception:
await stream.reset()
return
await _close_stream(stream)
raise
async def ping_loop(
self, peer_id: PeerID, ping_amount: Union[int, float] = math.inf
) -> "PingIterator":
stream = await self._host.new_stream(peer_id, (ID,))
ping_iterator = PingIterator(stream, ping_amount)
return ping_iterator
class PingIterator:
def __init__(self, stream: INetStream, ping_amount: Union[int, float]):
self._stream = stream
self._ping_limit = ping_amount
self._ping_counter = 0
def __aiter__(self) -> "PingIterator":
return self
async def __anext__(self) -> int:
if self._ping_counter > self._ping_limit:
await _close_stream(self._stream)
raise StopAsyncIteration
self._ping_counter += 1
try:
return await _ping(self._stream)
except trio.EndOfChannel:
await _close_stream(self._stream)
raise StopAsyncIteration
async def _ping(stream: INetStream) -> int:
ping_bytes = secrets.token_bytes(PING_LENGTH)
before = int(time.time() * 10 ** 6) # convert float of seconds to int miliseconds
await stream.write(ping_bytes)
pong_bytes = await stream.read(PING_LENGTH)
rtt = int(time.time() * 10 ** 6) - before
if ping_bytes != pong_bytes:
raise ValidationError("Invalid PING response")
return rtt
async def _close_stream(stream: INetStream) -> None:
try:
await stream.close()
except Exception:
pass

View File

@ -391,7 +391,7 @@ class GossipSub(IPubsubRouter, Service):
await trio.sleep(self.heartbeat_interval)
def mesh_heartbeat(
self
self,
) -> Tuple[DefaultDict[ID, List[str]], DefaultDict[ID, List[str]]]:
peers_to_graft: DefaultDict[ID, List[str]] = defaultdict(list)
peers_to_prune: DefaultDict[ID, List[str]] = defaultdict(list)

View File

@ -67,7 +67,7 @@ def security_transport_factory(
@asynccontextmanager
async def raw_conn_factory(
nursery: trio.Nursery
nursery: trio.Nursery,
) -> AsyncIterator[Tuple[IRawConnection, IRawConnection]]:
conn_0 = None
conn_1 = None
@ -351,7 +351,7 @@ async def swarm_pair_factory(
@asynccontextmanager
async def host_pair_factory(
is_secure: bool
is_secure: bool,
) -> AsyncIterator[Tuple[BasicHost, BasicHost]]:
async with HostFactory.create_batch_and_listen(is_secure, 2) as hosts:
await connect(hosts[0], hosts[1])
@ -370,7 +370,7 @@ async def swarm_conn_pair_factory(
@asynccontextmanager
async def mplex_conn_pair_factory(
is_secure: bool
is_secure: bool,
) -> AsyncIterator[Tuple[Mplex, Mplex]]:
muxer_opt = {MPLEX_PROTOCOL_ID: Mplex}
async with swarm_conn_pair_factory(is_secure, muxer_opt=muxer_opt) as swarm_pair:
@ -382,7 +382,7 @@ async def mplex_conn_pair_factory(
@asynccontextmanager
async def mplex_stream_pair_factory(
is_secure: bool
is_secure: bool,
) -> AsyncIterator[Tuple[MplexStream, MplexStream]]:
async with mplex_conn_pair_factory(is_secure) as mplex_conn_pair_info:
mplex_conn_0, mplex_conn_1 = mplex_conn_pair_info
@ -398,7 +398,7 @@ async def mplex_stream_pair_factory(
@asynccontextmanager
async def net_stream_pair_factory(
is_secure: bool
is_secure: bool,
) -> AsyncIterator[Tuple[INetStream, INetStream]]:
protocol_id = TProtocol("/example/id/1")

View File

@ -30,7 +30,7 @@ async def connect(node1: IHost, node2: IHost) -> None:
def create_echo_stream_handler(
ack_prefix: str
ack_prefix: str,
) -> Callable[[INetStream], Awaitable[None]]:
async def echo_stream_handler(stream: INetStream) -> None:
while True:

View File

@ -3,7 +3,7 @@ import secrets
import pytest
import trio
from libp2p.host.ping import ID, PING_LENGTH
from libp2p.host.ping import ID, PING_LENGTH, PingService
from libp2p.tools.factories import host_pair_factory
@ -36,3 +36,32 @@ async def test_ping_several(is_host_secure):
# NOTE: this interval can be `0` for this test.
await trio.sleep(0)
await stream.close()
@pytest.mark.trio
async def test_ping_service_once(is_host_secure):
async with host_pair_factory(is_host_secure) as (host_a, host_b):
ping_service = PingService(host_b)
rtt = await ping_service.ping(host_a.get_id())
assert rtt < 10 ** 6 # rtt is in miliseconds
@pytest.mark.trio
async def test_ping_service_loop(is_host_secure):
async with host_pair_factory(is_host_secure) as (host_a, host_b):
ping_service = PingService(host_b)
ping_loop = await ping_service.ping_loop(
host_a.get_id(), ping_amount=SOME_PING_COUNT
)
async for rtt in ping_loop:
assert rtt < 10 ** 6
@pytest.mark.trio
async def test_ping_service_loop_infinite(is_host_secure):
async with host_pair_factory(is_host_secure) as (host_a, host_b):
ping_service = PingService(host_b)
ping_loop = await ping_service.ping_loop(host_a.get_id())
with trio.move_on_after(1): # breaking loop after one second
async for rtt in ping_loop:
assert rtt < 10 ** 6