diff --git a/libp2p/security/noise/io.py b/libp2p/security/noise/io.py index 3dd024d..c6d48e4 100644 --- a/libp2p/security/noise/io.py +++ b/libp2p/security/noise/io.py @@ -32,23 +32,26 @@ class BaseNoiseMsgReadWriter(EncryptedMsgReadWriter): read_writer: MsgReadWriteCloser noise_state: NoiseState + # FIXME: This prefix is added in msg#3 in Go. Check whether it's a desired behavior. + prefix: bytes = b"\x00" * 32 + def __init__(self, conn: IRawConnection, noise_state: NoiseState) -> None: self.read_writer = NoisePacketReadWriter(cast(ReadWriteCloser, conn)) self.noise_state = noise_state - async def write_msg(self, data: bytes) -> None: + async def write_msg(self, data: bytes, prefix_encoded: bool = False) -> None: data_encrypted = self.encrypt(data) - # FIXME: Decide whether this prefix should be added or not. - # if not first: - # data_encrypted = b"\x00" * 32 + data_encrypted - await self.read_writer.write_msg(data_encrypted) + if prefix_encoded: + await self.read_writer.write_msg(self.prefix + data_encrypted) + else: + await self.read_writer.write_msg(data_encrypted) - async def read_msg(self) -> bytes: + async def read_msg(self, prefix_encoded: bool = False) -> bytes: noise_msg_encrypted = await self.read_writer.read_msg() - # FIXME: Decide whether this prefix should be added or not. - # if not first: - # noise_msg_encrypted = noise_msg_encrypted[32:] - return self.decrypt(noise_msg_encrypted) + if prefix_encoded: + return self.decrypt(noise_msg_encrypted[len(self.prefix) :]) + else: + return self.decrypt(noise_msg_encrypted) async def close(self) -> None: await self.read_writer.close() diff --git a/libp2p/security/noise/patterns.py b/libp2p/security/noise/patterns.py index fd6efc5..1e537d0 100644 --- a/libp2p/security/noise/patterns.py +++ b/libp2p/security/noise/patterns.py @@ -88,7 +88,7 @@ class PatternXX(BasePattern): await read_writer.write_msg(msg_2) # Receive and consume msg#3. - msg_3 = await read_writer.read_msg() + msg_3 = await read_writer.read_msg(prefix_encoded=True) peer_handshake_payload = NoiseHandshakePayload.deserialize(msg_3) if handshake_state.rs is None: @@ -156,7 +156,7 @@ class PatternXX(BasePattern): # Send msg#3, which includes our encrypted payload and our noise static key. our_payload = self.make_handshake_payload() msg_3 = our_payload.serialize() - await read_writer.write_msg(msg_3) + await read_writer.write_msg(msg_3, prefix_encoded=True) if not noise_state.handshake_finished: raise HandshakeHasNotFinished( diff --git a/tests_interop/conftest.py b/tests_interop/conftest.py index b14f91c..a59af28 100644 --- a/tests_interop/conftest.py +++ b/tests_interop/conftest.py @@ -7,14 +7,15 @@ import trio from libp2p.io.abc import ReadWriteCloser from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID +from libp2p.security.noise.transport import PROTOCOL_ID as NOISE_PROTOCOL_ID from libp2p.tools.factories import HostFactory, PubsubFactory from libp2p.tools.interop.daemon import make_p2pd from libp2p.tools.interop.utils import connect -@pytest.fixture -def security_protocol(): - return PLAINTEXT_PROTOCOL_ID +@pytest.fixture(params=[PLAINTEXT_PROTOCOL_ID, NOISE_PROTOCOL_ID]) +def security_protocol(request): + return request.param @pytest.fixture