Merge pull request #346 from ralexstokes/add-tests-for-identify
Add tests for identify
This commit is contained in:
commit
285bb2ed19
|
@ -168,7 +168,11 @@ class BasicHost(IHost):
|
|||
protocol, handler = await self.multiselect.negotiate(
|
||||
MultiselectCommunicator(net_stream)
|
||||
)
|
||||
except MultiselectError:
|
||||
except MultiselectError as error:
|
||||
peer_id = net_stream.muxed_conn.peer_id
|
||||
logger.debug(
|
||||
"failed to accept a stream from peer %s, error=%s", peer_id, error
|
||||
)
|
||||
await net_stream.reset()
|
||||
return
|
||||
net_stream.set_protocol(protocol)
|
||||
|
|
|
@ -19,16 +19,12 @@ def _multiaddr_to_bytes(maddr: Multiaddr) -> bytes:
|
|||
return maddr.to_bytes()
|
||||
|
||||
|
||||
def identify_handler_for(host: IHost) -> StreamHandlerFn:
|
||||
async def handle_identify(stream: INetStream) -> None:
|
||||
peer_id = stream.muxed_conn.peer_id
|
||||
logger.debug("received a request for %s from %s", ID, peer_id)
|
||||
|
||||
def _mk_identify_protobuf(host: IHost) -> Identify:
|
||||
public_key = host.get_public_key()
|
||||
laddrs = host.get_addrs()
|
||||
protocols = host.get_mux().get_protocols()
|
||||
|
||||
protobuf = Identify(
|
||||
return Identify(
|
||||
protocol_version=PROTOCOL_VERSION,
|
||||
agent_version=AGENT_VERSION,
|
||||
public_key=public_key.serialize(),
|
||||
|
@ -37,6 +33,14 @@ def identify_handler_for(host: IHost) -> StreamHandlerFn:
|
|||
observed_addr=b"",
|
||||
protocols=protocols,
|
||||
)
|
||||
|
||||
|
||||
def identify_handler_for(host: IHost) -> StreamHandlerFn:
|
||||
async def handle_identify(stream: INetStream) -> None:
|
||||
peer_id = stream.muxed_conn.peer_id
|
||||
logger.debug("received a request for %s from %s", ID, peer_id)
|
||||
|
||||
protobuf = _mk_identify_protobuf(host)
|
||||
response = protobuf.SerializeToString()
|
||||
|
||||
await stream.write(response)
|
||||
|
|
|
@ -111,6 +111,9 @@ class SecureSession(BaseSession):
|
|||
self.high_watermark = len(msg)
|
||||
|
||||
async def read(self, n: int = -1) -> bytes:
|
||||
if n == 0:
|
||||
return bytes()
|
||||
|
||||
data_from_buffer = self._drain(n)
|
||||
if len(data_from_buffer) > 0:
|
||||
return data_from_buffer
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from typing import Any # noqa: F401
|
||||
from typing import Awaitable, Dict, List, Optional, Tuple
|
||||
|
||||
|
@ -23,6 +24,8 @@ from .mplex_stream import MplexStream
|
|||
|
||||
MPLEX_PROTOCOL_ID = TProtocol("/mplex/6.7.0")
|
||||
|
||||
logger = logging.getLogger("libp2p.stream_muxer.mplex.mplex")
|
||||
|
||||
|
||||
class Mplex(IMuxedConn):
|
||||
"""
|
||||
|
@ -181,7 +184,8 @@ class Mplex(IMuxedConn):
|
|||
while True:
|
||||
try:
|
||||
await self._handle_incoming_message()
|
||||
except MplexUnavailable:
|
||||
except MplexUnavailable as e:
|
||||
logger.debug("mplex unavailable while waiting for incoming: %s", e)
|
||||
break
|
||||
# Force context switch
|
||||
await asyncio.sleep(0)
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import factory
|
||||
|
@ -163,6 +164,14 @@ async def host_pair_factory(is_secure) -> Tuple[BasicHost, BasicHost]:
|
|||
return hosts[0], hosts[1]
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def pair_of_connected_hosts(is_secure=True):
|
||||
a, b = await host_pair_factory(is_secure)
|
||||
yield a, b
|
||||
close_tasks = (a.close(), b.close())
|
||||
await asyncio.gather(*close_tasks)
|
||||
|
||||
|
||||
async def swarm_conn_pair_factory(
|
||||
is_secure: bool, muxer_opt: TMuxerOptions = None
|
||||
) -> Tuple[SwarmConn, Swarm, SwarmConn, Swarm]:
|
||||
|
|
|
@ -4,19 +4,12 @@ import secrets
|
|||
import pytest
|
||||
|
||||
from libp2p.host.ping import ID, PING_LENGTH
|
||||
from libp2p.peer.peerinfo import info_from_p2p_addr
|
||||
from tests.utils import set_up_nodes_by_transport_opt
|
||||
from tests.factories import pair_of_connected_hosts
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ping_once():
|
||||
transport_opt_list = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]]
|
||||
(host_a, host_b) = await set_up_nodes_by_transport_opt(transport_opt_list)
|
||||
|
||||
addr = host_a.get_addrs()[0]
|
||||
info = info_from_p2p_addr(addr)
|
||||
await host_b.connect(info)
|
||||
|
||||
async with pair_of_connected_hosts() as (host_a, host_b):
|
||||
stream = await host_b.new_stream(host_a.get_id(), (ID,))
|
||||
some_ping = secrets.token_bytes(PING_LENGTH)
|
||||
await stream.write(some_ping)
|
||||
|
@ -30,13 +23,7 @@ SOME_PING_COUNT = 3
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ping_several():
|
||||
transport_opt_list = [["/ip4/127.0.0.1/tcp/0"], ["/ip4/127.0.0.1/tcp/0"]]
|
||||
(host_a, host_b) = await set_up_nodes_by_transport_opt(transport_opt_list)
|
||||
|
||||
addr = host_a.get_addrs()[0]
|
||||
info = info_from_p2p_addr(addr)
|
||||
await host_b.connect(info)
|
||||
|
||||
async with pair_of_connected_hosts() as (host_a, host_b):
|
||||
stream = await host_b.new_stream(host_a.get_id(), (ID,))
|
||||
for _ in range(SOME_PING_COUNT):
|
||||
some_ping = secrets.token_bytes(PING_LENGTH)
|
||||
|
|
17
tests/identity/identify/test_protocol.py
Normal file
17
tests/identity/identify/test_protocol.py
Normal file
|
@ -0,0 +1,17 @@
|
|||
import pytest
|
||||
|
||||
from libp2p.identity.identify.pb.identify_pb2 import Identify
|
||||
from libp2p.identity.identify.protocol import ID, _mk_identify_protobuf
|
||||
from tests.factories import pair_of_connected_hosts
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_identify_protocol():
|
||||
async with pair_of_connected_hosts() as (host_a, host_b):
|
||||
stream = await host_b.new_stream(host_a.get_id(), (ID,))
|
||||
response = await stream.read()
|
||||
await stream.close()
|
||||
|
||||
identify_response = Identify()
|
||||
identify_response.ParseFromString(response)
|
||||
assert identify_response == _mk_identify_protobuf(host_a)
|
Loading…
Reference in New Issue
Block a user