diff --git a/libp2p/security/noise/connection.py b/libp2p/security/noise/connection.py index 940d807..8f94a0f 100644 --- a/libp2p/security/noise/connection.py +++ b/libp2p/security/noise/connection.py @@ -1,11 +1,15 @@ +from noise.connection import NoiseConnection as NoiseState + from libp2p.crypto.keys import PrivateKey from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.peer.id import ID from libp2p.security.base_session import BaseSession +from libp2p.security.noise.io import MsgReadWriter, NoiseTransportReadWriter class NoiseConnection(BaseSession): - conn: IRawConnection + read_writer: IRawConnection + noise_state: NoiseState def __init__( self, @@ -14,17 +18,21 @@ class NoiseConnection(BaseSession): remote_peer: ID, conn: IRawConnection, is_initiator: bool, + noise_state: NoiseState, ) -> None: super().__init__(local_peer, local_private_key, is_initiator, remote_peer) self.conn = conn + self.noise_state = noise_state + + def get_msg_read_writer(self) -> MsgReadWriter: + return NoiseTransportReadWriter(self.conn, self.noise_state) async def read(self, n: int = None) -> bytes: - # TODO: Add decryption logic here - return await self.conn.read(n) + # TODO: Use a buffer to handle buffered messages. + return await self.get_msg_read_writer().read_msg() async def write(self, data: bytes) -> None: - # TODO: Add encryption logic here - await self.conn.write(data) + await self.get_msg_read_writer().write_msg(data) async def close(self) -> None: await self.conn.close() diff --git a/libp2p/security/noise/io.py b/libp2p/security/noise/io.py index 6a6f7be..2533530 100644 --- a/libp2p/security/noise/io.py +++ b/libp2p/security/noise/io.py @@ -88,7 +88,6 @@ def decode_msg_body(noise_msg: bytes) -> bytes: ] -# TODO: Add comments class NoiseHandshakeReadWriter(MsgReadWriter): read_writer: MsgReadWriter noise_state: NoiseState @@ -106,3 +105,22 @@ class NoiseHandshakeReadWriter(MsgReadWriter): noise_msg_encrypted = await self.read_writer.read_msg() noise_msg = self.noise_state.read_message(noise_msg_encrypted) return decode_msg_body(noise_msg) + + +class NoiseTransportReadWriter(MsgReadWriter): + read_writer: MsgReadWriter + noise_state: NoiseState + + def __init__(self, conn: IRawConnection, noise_state: NoiseState) -> None: + self.read_writer = NoisePacketReadWriter(cast(ReadWriter, conn)) + self.noise_state = noise_state + + async def write_msg(self, data: bytes) -> None: + noise_msg = encode_msg_body(data) + data_encrypted = self.noise_state.encrypt(noise_msg) + await self.read_writer.write_msg(data_encrypted) + + async def read_msg(self) -> bytes: + noise_msg_encrypted = await self.read_writer.read_msg() + noise_msg = self.noise_state.decrypt(noise_msg_encrypted) + return decode_msg_body(noise_msg) diff --git a/libp2p/security/noise/patterns.py b/libp2p/security/noise/patterns.py index 783c101..91ce350 100644 --- a/libp2p/security/noise/patterns.py +++ b/libp2p/security/noise/patterns.py @@ -113,6 +113,7 @@ class PatternXX(BasePattern): remote_peer_id_from_pubkey, conn, False, + noise_state, ) async def handshake_outbound( @@ -162,5 +163,5 @@ class PatternXX(BasePattern): ) return NoiseConnection( - self.local_peer, self.libp2p_privkey, remote_peer, conn, False + self.local_peer, self.libp2p_privkey, remote_peer, conn, False, noise_state )