Merge pull request #346 from ralexstokes/add-tests-for-identify
Add tests for identify
This commit is contained in:
commit
285bb2ed19
|
@ -168,7 +168,11 @@ class BasicHost(IHost):
|
||||||
protocol, handler = await self.multiselect.negotiate(
|
protocol, handler = await self.multiselect.negotiate(
|
||||||
MultiselectCommunicator(net_stream)
|
MultiselectCommunicator(net_stream)
|
||||||
)
|
)
|
||||||
except MultiselectError:
|
except MultiselectError as error:
|
||||||
|
peer_id = net_stream.muxed_conn.peer_id
|
||||||
|
logger.debug(
|
||||||
|
"failed to accept a stream from peer %s, error=%s", peer_id, error
|
||||||
|
)
|
||||||
await net_stream.reset()
|
await net_stream.reset()
|
||||||
return
|
return
|
||||||
net_stream.set_protocol(protocol)
|
net_stream.set_protocol(protocol)
|
||||||
|
|
|
@ -19,24 +19,28 @@ def _multiaddr_to_bytes(maddr: Multiaddr) -> bytes:
|
||||||
return maddr.to_bytes()
|
return maddr.to_bytes()
|
||||||
|
|
||||||
|
|
||||||
|
def _mk_identify_protobuf(host: IHost) -> Identify:
|
||||||
|
public_key = host.get_public_key()
|
||||||
|
laddrs = host.get_addrs()
|
||||||
|
protocols = host.get_mux().get_protocols()
|
||||||
|
|
||||||
|
return Identify(
|
||||||
|
protocol_version=PROTOCOL_VERSION,
|
||||||
|
agent_version=AGENT_VERSION,
|
||||||
|
public_key=public_key.serialize(),
|
||||||
|
listen_addrs=map(_multiaddr_to_bytes, laddrs),
|
||||||
|
# TODO send observed address from ``stream``
|
||||||
|
observed_addr=b"",
|
||||||
|
protocols=protocols,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def identify_handler_for(host: IHost) -> StreamHandlerFn:
|
def identify_handler_for(host: IHost) -> StreamHandlerFn:
|
||||||
async def handle_identify(stream: INetStream) -> None:
|
async def handle_identify(stream: INetStream) -> None:
|
||||||
peer_id = stream.muxed_conn.peer_id
|
peer_id = stream.muxed_conn.peer_id
|
||||||
logger.debug("received a request for %s from %s", ID, peer_id)
|
logger.debug("received a request for %s from %s", ID, peer_id)
|
||||||
|
|
||||||
public_key = host.get_public_key()
|
protobuf = _mk_identify_protobuf(host)
|
||||||
laddrs = host.get_addrs()
|
|
||||||
protocols = host.get_mux().get_protocols()
|
|
||||||
|
|
||||||
protobuf = Identify(
|
|
||||||
protocol_version=PROTOCOL_VERSION,
|
|
||||||
agent_version=AGENT_VERSION,
|
|
||||||
public_key=public_key.serialize(),
|
|
||||||
listen_addrs=map(_multiaddr_to_bytes, laddrs),
|
|
||||||
# TODO send observed address from ``stream``
|
|
||||||
observed_addr=b"",
|
|
||||||
protocols=protocols,
|
|
||||||
)
|
|
||||||
response = protobuf.SerializeToString()
|
response = protobuf.SerializeToString()
|
||||||
|
|
||||||
await stream.write(response)
|
await stream.write(response)
|
||||||
|
|
|
@ -111,6 +111,9 @@ class SecureSession(BaseSession):
|
||||||
self.high_watermark = len(msg)
|
self.high_watermark = len(msg)
|
||||||
|
|
||||||
async def read(self, n: int = -1) -> bytes:
|
async def read(self, n: int = -1) -> bytes:
|
||||||
|
if n == 0:
|
||||||
|
return bytes()
|
||||||
|
|
||||||
data_from_buffer = self._drain(n)
|
data_from_buffer = self._drain(n)
|
||||||
if len(data_from_buffer) > 0:
|
if len(data_from_buffer) > 0:
|
||||||
return data_from_buffer
|
return data_from_buffer
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import logging
|
||||||
from typing import Any # noqa: F401
|
from typing import Any # noqa: F401
|
||||||
from typing import Awaitable, Dict, List, Optional, Tuple
|
from typing import Awaitable, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
@ -23,6 +24,8 @@ from .mplex_stream import MplexStream
|
||||||
|
|
||||||
MPLEX_PROTOCOL_ID = TProtocol("/mplex/6.7.0")
|
MPLEX_PROTOCOL_ID = TProtocol("/mplex/6.7.0")
|
||||||
|
|
||||||
|
logger = logging.getLogger("libp2p.stream_muxer.mplex.mplex")
|
||||||
|
|
||||||
|
|
||||||
class Mplex(IMuxedConn):
|
class Mplex(IMuxedConn):
|
||||||
"""
|
"""
|
||||||
|
@ -181,7 +184,8 @@ class Mplex(IMuxedConn):
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
await self._handle_incoming_message()
|
await self._handle_incoming_message()
|
||||||
except MplexUnavailable:
|
except MplexUnavailable as e:
|
||||||
|
logger.debug("mplex unavailable while waiting for incoming: %s", e)
|
||||||
break
|
break
|
||||||
# Force context switch
|
# Force context switch
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
from typing import Dict, Tuple
|
from typing import Dict, Tuple
|
||||||
|
|
||||||
import factory
|
import factory
|
||||||
|
@ -163,6 +164,14 @@ async def host_pair_factory(is_secure) -> Tuple[BasicHost, BasicHost]:
|
||||||
return hosts[0], hosts[1]
|
return hosts[0], hosts[1]
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def pair_of_connected_hosts(is_secure=True):
|
||||||
|
a, b = await host_pair_factory(is_secure)
|
||||||
|
yield a, b
|
||||||
|
close_tasks = (a.close(), b.close())
|
||||||
|
await asyncio.gather(*close_tasks)
|
||||||
|
|
||||||
|
|
||||||
async def swarm_conn_pair_factory(
|
async def swarm_conn_pair_factory(
|
||||||
is_secure: bool, muxer_opt: TMuxerOptions = None
|
is_secure: bool, muxer_opt: TMuxerOptions = None
|
||||||
) -> Tuple[SwarmConn, Swarm, SwarmConn, Swarm]:
|
) -> Tuple[SwarmConn, Swarm, SwarmConn, Swarm]:
|
||||||
|
|
|
@ -4,25 +4,18 @@ import secrets
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from libp2p.host.ping import ID, PING_LENGTH
|
from libp2p.host.ping import ID, PING_LENGTH
|
||||||
from libp2p.peer.peerinfo import info_from_p2p_addr
|
from tests.factories import pair_of_connected_hosts
|
||||||
from tests.utils import set_up_nodes_by_transport_opt
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_ping_once():
|
async def test_ping_once():
|
||||||
transport_opt_list = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]]
|
async with pair_of_connected_hosts() as (host_a, host_b):
|
||||||
(host_a, host_b) = await set_up_nodes_by_transport_opt(transport_opt_list)
|
stream = await host_b.new_stream(host_a.get_id(), (ID,))
|
||||||
|
some_ping = secrets.token_bytes(PING_LENGTH)
|
||||||
addr = host_a.get_addrs()[0]
|
await stream.write(some_ping)
|
||||||
info = info_from_p2p_addr(addr)
|
some_pong = await stream.read(PING_LENGTH)
|
||||||
await host_b.connect(info)
|
assert some_ping == some_pong
|
||||||
|
await stream.close()
|
||||||
stream = await host_b.new_stream(host_a.get_id(), (ID,))
|
|
||||||
some_ping = secrets.token_bytes(PING_LENGTH)
|
|
||||||
await stream.write(some_ping)
|
|
||||||
some_pong = await stream.read(PING_LENGTH)
|
|
||||||
assert some_ping == some_pong
|
|
||||||
await stream.close()
|
|
||||||
|
|
||||||
|
|
||||||
SOME_PING_COUNT = 3
|
SOME_PING_COUNT = 3
|
||||||
|
@ -30,21 +23,15 @@ SOME_PING_COUNT = 3
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_ping_several():
|
async def test_ping_several():
|
||||||
transport_opt_list = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]]
|
async with pair_of_connected_hosts() as (host_a, host_b):
|
||||||
(host_a, host_b) = await set_up_nodes_by_transport_opt(transport_opt_list)
|
stream = await host_b.new_stream(host_a.get_id(), (ID,))
|
||||||
|
for _ in range(SOME_PING_COUNT):
|
||||||
addr = host_a.get_addrs()[0]
|
some_ping = secrets.token_bytes(PING_LENGTH)
|
||||||
info = info_from_p2p_addr(addr)
|
await stream.write(some_ping)
|
||||||
await host_b.connect(info)
|
some_pong = await stream.read(PING_LENGTH)
|
||||||
|
assert some_ping == some_pong
|
||||||
stream = await host_b.new_stream(host_a.get_id(), (ID,))
|
# NOTE: simulate some time to sleep to mirror a real
|
||||||
for _ in range(SOME_PING_COUNT):
|
# world usage where a peer sends pings on some periodic interval
|
||||||
some_ping = secrets.token_bytes(PING_LENGTH)
|
# NOTE: this interval can be `0` for this test.
|
||||||
await stream.write(some_ping)
|
await asyncio.sleep(0)
|
||||||
some_pong = await stream.read(PING_LENGTH)
|
await stream.close()
|
||||||
assert some_ping == some_pong
|
|
||||||
# NOTE: simulate some time to sleep to mirror a real
|
|
||||||
# world usage where a peer sends pings on some periodic interval
|
|
||||||
# NOTE: this interval can be `0` for this test.
|
|
||||||
await asyncio.sleep(0)
|
|
||||||
await stream.close()
|
|
||||||
|
|
17
tests/identity/identify/test_protocol.py
Normal file
17
tests/identity/identify/test_protocol.py
Normal file
|
@ -0,0 +1,17 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from libp2p.identity.identify.pb.identify_pb2 import Identify
|
||||||
|
from libp2p.identity.identify.protocol import ID, _mk_identify_protobuf
|
||||||
|
from tests.factories import pair_of_connected_hosts
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_identify_protocol():
|
||||||
|
async with pair_of_connected_hosts() as (host_a, host_b):
|
||||||
|
stream = await host_b.new_stream(host_a.get_id(), (ID,))
|
||||||
|
response = await stream.read()
|
||||||
|
await stream.close()
|
||||||
|
|
||||||
|
identify_response = Identify()
|
||||||
|
identify_response.ParseFromString(response)
|
||||||
|
assert identify_response == _mk_identify_protobuf(host_a)
|
Loading…
Reference in New Issue
Block a user