Add delim_encode and delim_read

- Add `StreamCommunicator` and `RawConnectionCommunicator`, read/write
messages with delim codec, with `IMuxedStream` and `IRawConnection`
respectively.
- Use it in `Multiselect` and `MultiselectClient`.
This commit is contained in:
mhchia 2019-08-15 23:31:26 +08:00 committed by Kevin Mai-Husan Chia
parent 8cd23abfe2
commit 86d4ce1da8
11 changed files with 74 additions and 64 deletions

View File

@ -7,6 +7,7 @@ from libp2p.peer.id import ID
from libp2p.peer.peerstore_interface import IPeerStore from libp2p.peer.peerstore_interface import IPeerStore
from libp2p.protocol_muxer.multiselect import Multiselect from libp2p.protocol_muxer.multiselect import Multiselect
from libp2p.protocol_muxer.multiselect_client import MultiselectClient from libp2p.protocol_muxer.multiselect_client import MultiselectClient
from libp2p.protocol_muxer.multiselect_communicator import StreamCommunicator
from libp2p.routing.interfaces import IPeerRouting from libp2p.routing.interfaces import IPeerRouting
from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream
from libp2p.transport.listener_interface import IListener from libp2p.transport.listener_interface import IListener
@ -148,7 +149,7 @@ class Swarm(INetwork):
# Perform protocol muxing to determine protocol to use # Perform protocol muxing to determine protocol to use
selected_protocol = await self.multiselect_client.select_one_of( selected_protocol = await self.multiselect_client.select_one_of(
list(protocol_ids), muxed_stream list(protocol_ids), StreamCommunicator(muxed_stream)
) )
# Create a net stream with the selected protocol # Create a net stream with the selected protocol
@ -264,7 +265,9 @@ def create_generic_protocol_handler(swarm: Swarm) -> GenericProtocolHandlerFn:
async def generic_protocol_handler(muxed_stream: IMuxedStream) -> None: async def generic_protocol_handler(muxed_stream: IMuxedStream) -> None:
# Perform protocol muxing to determine protocol to use # Perform protocol muxing to determine protocol to use
protocol, handler = await multiselect.negotiate(muxed_stream) protocol, handler = await multiselect.negotiate(
StreamCommunicator(muxed_stream)
)
net_stream = NetStream(muxed_stream) net_stream = NetStream(muxed_stream)
net_stream.set_protocol(protocol) net_stream.set_protocol(protocol)

View File

@ -1,8 +1,7 @@
from typing import Dict, Tuple from typing import Dict, Tuple
from libp2p.typing import NegotiableTransport, StreamHandlerFn, TProtocol from libp2p.typing import StreamHandlerFn, TProtocol
from .multiselect_communicator import MultiselectCommunicator
from .multiselect_communicator_interface import IMultiselectCommunicator from .multiselect_communicator_interface import IMultiselectCommunicator
from .multiselect_muxer_interface import IMultiselectMuxer from .multiselect_muxer_interface import IMultiselectMuxer
@ -31,7 +30,7 @@ class Multiselect(IMultiselectMuxer):
self.handlers[protocol] = handler self.handlers[protocol] = handler
async def negotiate( async def negotiate(
self, stream: NegotiableTransport self, communicator: IMultiselectCommunicator
) -> Tuple[TProtocol, StreamHandlerFn]: ) -> Tuple[TProtocol, StreamHandlerFn]:
""" """
Negotiate performs protocol selection Negotiate performs protocol selection
@ -39,8 +38,6 @@ class Multiselect(IMultiselectMuxer):
:return: selected protocol name, handler function :return: selected protocol name, handler function
:raise Exception: negotiation failed exception :raise Exception: negotiation failed exception
""" """
# Create a communicator to handle all communication across the stream
communicator = MultiselectCommunicator(stream)
# Perform handshake to ensure multiselect protocol IDs match # Perform handshake to ensure multiselect protocol IDs match
await self.handshake(communicator) await self.handshake(communicator)
@ -48,7 +45,7 @@ class Multiselect(IMultiselectMuxer):
# Read and respond to commands until a valid protocol ID is sent # Read and respond to commands until a valid protocol ID is sent
while True: while True:
# Read message # Read message
command = await communicator.read_stream_until_eof() command = await communicator.read()
# Command is ls or a protocol # Command is ls or a protocol
if command == "ls": if command == "ls":
@ -78,7 +75,7 @@ class Multiselect(IMultiselectMuxer):
await communicator.write(MULTISELECT_PROTOCOL_ID) await communicator.write(MULTISELECT_PROTOCOL_ID)
# Read in the protocol ID from other party # Read in the protocol ID from other party
handshake_contents = await communicator.read_stream_until_eof() handshake_contents = await communicator.read()
# Confirm that the protocols are the same # Confirm that the protocols are the same
if not validate_handshake(handshake_contents): if not validate_handshake(handshake_contents):

View File

@ -1,10 +1,8 @@
from typing import Sequence from typing import Sequence
from libp2p.stream_muxer.abc import IMuxedStream from libp2p.typing import TProtocol
from libp2p.typing import NegotiableTransport, TProtocol
from .multiselect_client_interface import IMultiselectClient from .multiselect_client_interface import IMultiselectClient
from .multiselect_communicator import MultiselectCommunicator
from .multiselect_communicator_interface import IMultiselectCommunicator from .multiselect_communicator_interface import IMultiselectCommunicator
MULTISELECT_PROTOCOL_ID = "/multistream/1.0.0" MULTISELECT_PROTOCOL_ID = "/multistream/1.0.0"
@ -31,7 +29,7 @@ class MultiselectClient(IMultiselectClient):
await communicator.write(MULTISELECT_PROTOCOL_ID) await communicator.write(MULTISELECT_PROTOCOL_ID)
# Read in the protocol ID from other party # Read in the protocol ID from other party
handshake_contents = await communicator.read_stream_until_eof() handshake_contents = await communicator.read()
# Confirm that the protocols are the same # Confirm that the protocols are the same
if not validate_handshake(handshake_contents): if not validate_handshake(handshake_contents):
@ -40,7 +38,7 @@ class MultiselectClient(IMultiselectClient):
# Handshake succeeded if this point is reached # Handshake succeeded if this point is reached
async def select_protocol_or_fail( async def select_protocol_or_fail(
self, protocol: TProtocol, stream: IMuxedStream self, protocol: TProtocol, communicator: IMultiselectCommunicator
) -> TProtocol: ) -> TProtocol:
""" """
Send message to multiselect selecting protocol Send message to multiselect selecting protocol
@ -49,9 +47,6 @@ class MultiselectClient(IMultiselectClient):
:param stream: stream to communicate with multiselect over :param stream: stream to communicate with multiselect over
:return: selected protocol :return: selected protocol
""" """
# Create a communicator to handle all communication across the stream
communicator = MultiselectCommunicator(stream)
# Perform handshake to ensure multiselect protocol IDs match # Perform handshake to ensure multiselect protocol IDs match
await self.handshake(communicator) await self.handshake(communicator)
@ -61,7 +56,7 @@ class MultiselectClient(IMultiselectClient):
return selected_protocol return selected_protocol
async def select_one_of( async def select_one_of(
self, protocols: Sequence[TProtocol], stream: NegotiableTransport self, protocols: Sequence[TProtocol], communicator: IMultiselectCommunicator
) -> TProtocol: ) -> TProtocol:
""" """
For each protocol, send message to multiselect selecting protocol For each protocol, send message to multiselect selecting protocol
@ -71,10 +66,6 @@ class MultiselectClient(IMultiselectClient):
:param stream: stream to communicate with multiselect over :param stream: stream to communicate with multiselect over
:return: selected protocol :return: selected protocol
""" """
# Create a communicator to handle all communication across the stream
communicator = MultiselectCommunicator(stream)
# Perform handshake to ensure multiselect protocol IDs match # Perform handshake to ensure multiselect protocol IDs match
await self.handshake(communicator) await self.handshake(communicator)
@ -105,7 +96,7 @@ class MultiselectClient(IMultiselectClient):
await communicator.write(protocol) await communicator.write(protocol)
# Get what counterparty says in response # Get what counterparty says in response
response = await communicator.read_stream_until_eof() response = await communicator.read()
# Return protocol if response is equal to protocol or raise error # Return protocol if response is equal to protocol or raise error
if response == protocol: if response == protocol:

View File

@ -1,7 +1,9 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Sequence from typing import Sequence
from libp2p.stream_muxer.abc import IMuxedStream from libp2p.protocol_muxer.multiselect_communicator_interface import (
IMultiselectCommunicator,
)
from libp2p.typing import TProtocol from libp2p.typing import TProtocol
@ -13,7 +15,7 @@ class IMultiselectClient(ABC):
@abstractmethod @abstractmethod
async def select_protocol_or_fail( async def select_protocol_or_fail(
self, protocol: TProtocol, stream: IMuxedStream self, protocol: TProtocol, communicator: IMultiselectCommunicator
) -> TProtocol: ) -> TProtocol:
""" """
Send message to multiselect selecting protocol Send message to multiselect selecting protocol
@ -25,7 +27,7 @@ class IMultiselectClient(ABC):
@abstractmethod @abstractmethod
async def select_one_of( async def select_one_of(
self, protocols: Sequence[TProtocol], stream: IMuxedStream self, protocols: Sequence[TProtocol], communicator: IMultiselectCommunicator
) -> TProtocol: ) -> TProtocol:
""" """
For each protocol, send message to multiselect selecting protocol For each protocol, send message to multiselect selecting protocol

View File

@ -1,35 +1,47 @@
from libp2p.typing import NegotiableTransport from libp2p.network.connection.raw_connection_interface import IRawConnection
from libp2p.stream_muxer.abc import IMuxedStream
from libp2p.stream_muxer.mplex.utils import decode_uvarint_from_stream, encode_uvarint
from libp2p.typing import StreamReader
from .multiselect_communicator_interface import IMultiselectCommunicator from .multiselect_communicator_interface import IMultiselectCommunicator
class MultiselectCommunicator(IMultiselectCommunicator): def delim_encode(msg_str: str) -> bytes:
""" msg_bytes = msg_str.encode()
Communicator helper class that ensures both the client varint_len_msg = encode_uvarint(len(msg_bytes) + 1)
and multistream module will follow the same multistream protocol, return varint_len_msg + msg_bytes + b"\n"
which is necessary for them to work
"""
reader_writer: NegotiableTransport
def __init__(self, reader_writer: NegotiableTransport) -> None: async def delim_read(reader: StreamReader, timeout: int = 10) -> str:
""" len_msg = await decode_uvarint_from_stream(reader, timeout)
MultistreamCommunicator expects a reader_writer object that has msg_bytes = await reader.read(len_msg)
an async read and an async write function (this could be a stream, return msg_bytes.decode().rstrip()
raw connection, or other object implementing those functions)
"""
self.reader_writer = reader_writer class RawConnectionCommunicator(IMultiselectCommunicator):
conn: IRawConnection
def __init__(self, conn: IRawConnection) -> None:
self.conn = conn
async def write(self, msg_str: str) -> None: async def write(self, msg_str: str) -> None:
""" msg_bytes = delim_encode(msg_str)
Write message to reader_writer self.conn.writer.write(msg_bytes)
:param msg_str: message to write await self.conn.writer.drain()
"""
await self.reader_writer.write(msg_str.encode())
async def read_stream_until_eof(self) -> str: async def read(self) -> str:
""" return await delim_read(self.conn.reader)
Reads message from reader_writer until EOF
"""
read_str = (await self.reader_writer.read()).decode() class StreamCommunicator(IMultiselectCommunicator):
return read_str stream: IMuxedStream
def __init__(self, stream: IMuxedStream) -> None:
self.stream = stream
async def write(self, msg_str: str) -> None:
msg_bytes = delim_encode(msg_str)
await self.stream.write(msg_bytes)
async def read(self) -> str:
return await delim_read(self.stream)

View File

@ -16,7 +16,7 @@ class IMultiselectCommunicator(ABC):
""" """
@abstractmethod @abstractmethod
async def read_stream_until_eof(self) -> str: async def read(self) -> str:
""" """
Reads message from stream until EOF Reads message from stream until EOF
""" """

View File

@ -1,7 +1,9 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Dict, Tuple from typing import Dict, Tuple
from libp2p.typing import NegotiableTransport, StreamHandlerFn, TProtocol from libp2p.typing import StreamHandlerFn, TProtocol
from .multiselect_communicator_interface import IMultiselectCommunicator
class IMultiselectMuxer(ABC): class IMultiselectMuxer(ABC):
@ -23,7 +25,7 @@ class IMultiselectMuxer(ABC):
@abstractmethod @abstractmethod
async def negotiate( async def negotiate(
self, stream: NegotiableTransport self, communicator: IMultiselectCommunicator
) -> Tuple[TProtocol, StreamHandlerFn]: ) -> Tuple[TProtocol, StreamHandlerFn]:
""" """
Negotiate performs protocol selection Negotiate performs protocol selection

View File

@ -14,6 +14,8 @@ Relevant go repo: https://github.com/libp2p/go-conn-security/blob/master/interfa
class AbstractSecureConn(ABC): class AbstractSecureConn(ABC):
conn: IRawConnection
@abstractmethod @abstractmethod
def get_local_peer(self) -> ID: def get_local_peer(self) -> ID:
pass pass

View File

@ -7,6 +7,7 @@ from libp2p.protocol_muxer.multiselect import Multiselect
from libp2p.protocol_muxer.multiselect_client import MultiselectClient from libp2p.protocol_muxer.multiselect_client import MultiselectClient
from libp2p.security.secure_conn_interface import ISecureConn from libp2p.security.secure_conn_interface import ISecureConn
from libp2p.security.secure_transport_interface import ISecureTransport from libp2p.security.secure_transport_interface import ISecureTransport
from libp2p.protocol_muxer.multiselect_communicator import RawConnectionCommunicator
from libp2p.typing import TProtocol from libp2p.typing import TProtocol
@ -74,14 +75,15 @@ class SecurityMultistream(ABC):
# instead of stream? In go repo, they pass in a raw conn # instead of stream? In go repo, they pass in a raw conn
# (https://raw.githubusercontent.com/libp2p/go-conn-security-multistream/master/ssms.go) # (https://raw.githubusercontent.com/libp2p/go-conn-security-multistream/master/ssms.go)
protocol = None protocol: TProtocol
communicator = RawConnectionCommunicator(conn)
if initiator: if initiator:
# Select protocol if initiator # Select protocol if initiator
protocol = await self.multiselect_client.select_one_of( protocol = await self.multiselect_client.select_one_of(
list(self.transports.keys()), conn list(self.transports.keys()), communicator
) )
else: else:
# Select protocol if non-initiator # Select protocol if non-initiator
protocol, _ = await self.multiselect.negotiate(conn) protocol, _ = await self.multiselect.negotiate(communicator)
# Return transport from protocol # Return transport from protocol
return self.transports[protocol] return self.transports[protocol]

View File

@ -2,6 +2,8 @@ import asyncio
import struct import struct
from typing import Tuple from typing import Tuple
from libp2p.typing import StreamReader
def encode_uvarint(number: int) -> bytes: def encode_uvarint(number: int) -> bytes:
"""Pack `number` into varint bytes""" """Pack `number` into varint bytes"""
@ -31,9 +33,7 @@ def decode_uvarint(buff: bytes, index: int) -> Tuple[int, int]:
return result, index + 1 return result, index + 1
async def decode_uvarint_from_stream( async def decode_uvarint_from_stream(reader: StreamReader, timeout: float) -> int:
reader: asyncio.StreamReader, timeout: float
) -> int:
shift = 0 shift = 0
result = 0 result = 0
while True: while True:

View File

@ -1,7 +1,6 @@
import asyncio
from typing import TYPE_CHECKING, Awaitable, Callable, NewType, Union from typing import TYPE_CHECKING, Awaitable, Callable, NewType, Union
from libp2p.network.connection.raw_connection_interface import IRawConnection
if TYPE_CHECKING: if TYPE_CHECKING:
from libp2p.network.stream.net_stream_interface import INetStream # noqa: F401 from libp2p.network.stream.net_stream_interface import INetStream # noqa: F401
from libp2p.stream_muxer.abc import IMuxedStream # noqa: F401 from libp2p.stream_muxer.abc import IMuxedStream # noqa: F401
@ -10,4 +9,4 @@ TProtocol = NewType("TProtocol", str)
StreamHandlerFn = Callable[["INetStream"], Awaitable[None]] StreamHandlerFn = Callable[["INetStream"], Awaitable[None]]
NegotiableTransport = Union["IMuxedStream", IRawConnection] StreamReader = Union["IMuxedStream", asyncio.StreamReader]