Fix msg encoding

- Change varint-prefix encode to fixedint-prefix(4 bytes) encode.
This commit is contained in:
mhchia 2019-08-17 21:41:17 +08:00
parent 22b1a5395d
commit bb7d37fd4f
No known key found for this signature in database
GPG Key ID: 389EFBEA1362589A
4 changed files with 38 additions and 15 deletions

View File

@ -177,8 +177,8 @@ class Swarm(INetwork):
Call listener listen with the multiaddr
Map multiaddr to listener
"""
for multiaddr in multiaddrs:
if str(multiaddr) in self.listeners:
for maddr in multiaddrs:
if str(maddr) in self.listeners:
return True
async def conn_handler(
@ -187,8 +187,8 @@ class Swarm(INetwork):
# Upgrade reader/write to a net_stream and pass \
# to appropriate stream handler (using multiaddr)
raw_conn = RawConnection(
multiaddr.value_for_protocol("ip4"),
multiaddr.value_for_protocol("tcp"),
maddr.value_for_protocol("ip4"),
maddr.value_for_protocol("tcp"),
reader,
writer,
False,
@ -215,19 +215,19 @@ class Swarm(INetwork):
try:
# Success
listener = self.transport.create_listener(conn_handler)
self.listeners[str(multiaddr)] = listener
await listener.listen(multiaddr)
self.listeners[str(maddr)] = listener
await listener.listen(maddr)
# Call notifiers since event occurred
for notifee in self.notifees:
await notifee.listen(self, multiaddr)
await notifee.listen(self, maddr)
return True
except IOError:
# Failed. Continue looping.
print("Failed to connect to: " + str(multiaddr))
print("Failed to connect to: " + str(maddr))
# No multiaddr succeeded
# No maddr succeeded
return False
def notify(self, notifee: INotifee) -> bool:

View File

@ -79,7 +79,9 @@ class Multiselect(IMultiselectMuxer):
# Confirm that the protocols are the same
if not validate_handshake(handshake_contents):
raise MultiselectError("multiselect protocol ID mismatch")
raise MultiselectError(
f"multiselect protocol ID mismatch: handshake_contents={handshake_contents}"
)
# Handshake succeeded if this point is reached

View File

@ -5,7 +5,7 @@ from libp2p.security.base_session import BaseSession
from libp2p.security.base_transport import BaseSecureTransport
from libp2p.security.secure_conn_interface import ISecureConn
from libp2p.typing import TProtocol
from libp2p.utils import encode_varint_prefixed, read_varint_prefixed_bytes
from libp2p.utils import encode_fixedint_prefixed, read_fixedint_prefixed
from .exceptions import UpgradeFailure
from .pb import plaintext_pb2
@ -20,12 +20,15 @@ class InsecureSession(BaseSession):
# FIXME: Update the read/write to `BaseSession`
async def run_handshake(self):
msg = make_exchange_message(self.local_private_key.get_public_key())
self.writer.write(encode_varint_prefixed(msg.SerializeToString()))
msg_bytes = msg.SerializeToString()
encoded_msg_bytes = encode_fixedint_prefixed(msg_bytes)
self.writer.write(encoded_msg_bytes)
await self.writer.drain()
msg_bytes_other_side = await read_varint_prefixed_bytes(self.reader)
msg_bytes_other_side = await read_fixedint_prefixed(self.reader)
msg_other_side = plaintext_pb2.Exchange()
msg_other_side.ParseFromString(msg_bytes_other_side)
# TODO: Verify public key with peer id
# TODO: Store public key
self.remote_peer_id = ID(msg_other_side.id)

View File

@ -4,7 +4,7 @@ from typing import Tuple
from libp2p.typing import StreamReader
TIMEOUT = 1
TIMEOUT = 10
def encode_uvarint(number: int) -> bytes:
@ -64,7 +64,8 @@ async def read_varint_prefixed_bytes(
return await reader.read(len_msg)
# Delimited read/write
# Delimited read/write, used by multistream-select.
# Reference: https://github.com/gogo/protobuf/blob/07eab6a8298cf32fac45cceaac59424f98421bbc/io/varint.go#L109-L126 # noqa: E501
def encode_delim(msg_str: str) -> bytes:
@ -75,3 +76,20 @@ def encode_delim(msg_str: str) -> bytes:
async def read_delim(reader: StreamReader, timeout: int = TIMEOUT) -> str:
msg_bytes = await read_varint_prefixed_bytes(reader, timeout)
return msg_bytes.decode().rstrip()
SIZE_LEN_BYTES = 4
# Fixed-prefixed read/write, used by "/plaintext/2.0.0".
# Reference: https://github.com/libp2p/go-msgio/blob/d5bbf59d3c4240266b1d2e5df9dc993454c42011/num.go#L11-L33 # noqa: E501 # noqa: E501
def encode_fixedint_prefixed(msg_bytes: bytes) -> bytes:
len_prefix = len(msg_bytes).to_bytes(SIZE_LEN_BYTES, "big")
return len_prefix + msg_bytes
async def read_fixedint_prefixed(reader: StreamReader) -> bytes:
len_bytes = await reader.read(SIZE_LEN_BYTES)
len_int = int.from_bytes(len_bytes, "big")
return await reader.read(len_int)