diff --git a/libp2p/network/connection/swarm_connection.py b/libp2p/network/connection/swarm_connection.py index 15816fc..10b83a1 100644 --- a/libp2p/network/connection/swarm_connection.py +++ b/libp2p/network/connection/swarm_connection.py @@ -77,7 +77,7 @@ class SwarmConn(INetConn): async def _notify_disconnected(self) -> None: for notifee in self.swarm.notifees: - await notifee.disconnected(self.swarm, self.conn) + await notifee.disconnected(self.swarm, self) async def start(self) -> None: await self.run_task(self._handle_new_streams()) diff --git a/libp2p/network/notifee_interface.py b/libp2p/network/notifee_interface.py index ef996bf..c31f473 100644 --- a/libp2p/network/notifee_interface.py +++ b/libp2p/network/notifee_interface.py @@ -3,8 +3,8 @@ from typing import TYPE_CHECKING from multiaddr import Multiaddr +from libp2p.network.connection.net_connection_interface import INetConn from libp2p.network.stream.net_stream_interface import INetStream -from libp2p.stream_muxer.abc import IMuxedConn if TYPE_CHECKING: from .network_interface import INetwork # noqa: F401 @@ -26,14 +26,14 @@ class INotifee(ABC): """ @abstractmethod - async def connected(self, network: "INetwork", conn: IMuxedConn) -> None: + async def connected(self, network: "INetwork", conn: INetConn) -> None: """ :param network: network the connection was opened on :param conn: connection that was opened """ @abstractmethod - async def disconnected(self, network: "INetwork", conn: IMuxedConn) -> None: + async def disconnected(self, network: "INetwork", conn: INetConn) -> None: """ :param network: network the connection was closed on :param conn: connection that was closed diff --git a/libp2p/network/stream/net_stream.py b/libp2p/network/stream/net_stream.py index d500c08..3ae7c9c 100644 --- a/libp2p/network/stream/net_stream.py +++ b/libp2p/network/stream/net_stream.py @@ -1,4 +1,4 @@ -from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream +from libp2p.stream_muxer.abc import IMuxedStream from libp2p.stream_muxer.exceptions import ( MuxedStreamClosed, MuxedStreamEOF, @@ -16,13 +16,11 @@ from .net_stream_interface import INetStream class NetStream(INetStream): muxed_stream: IMuxedStream - # TODO: Why we expose `mplex_conn` here? - mplex_conn: IMuxedConn protocol_id: TProtocol def __init__(self, muxed_stream: IMuxedStream) -> None: self.muxed_stream = muxed_stream - self.mplex_conn = muxed_stream.mplex_conn + self.muxed_conn = muxed_stream.muxed_conn self.protocol_id = None def get_protocol(self) -> TProtocol: diff --git a/libp2p/network/stream/net_stream_interface.py b/libp2p/network/stream/net_stream_interface.py index d054789..41bf423 100644 --- a/libp2p/network/stream/net_stream_interface.py +++ b/libp2p/network/stream/net_stream_interface.py @@ -7,7 +7,7 @@ from libp2p.typing import TProtocol class INetStream(ReadWriteCloser): - mplex_conn: IMuxedConn + muxed_conn: IMuxedConn @abstractmethod def get_protocol(self) -> TProtocol: diff --git a/libp2p/network/swarm.py b/libp2p/network/swarm.py index 272a8a9..a0d9d77 100644 --- a/libp2p/network/swarm.py +++ b/libp2p/network/swarm.py @@ -278,8 +278,7 @@ class Swarm(INetwork): self.connections[muxed_conn.peer_id] = swarm_conn # Call notifiers since event occurred for notifee in self.notifees: - # TODO: Call with other type of conn? - await notifee.connected(self, muxed_conn) + await notifee.connected(self, swarm_conn) await swarm_conn.start() return swarm_conn diff --git a/libp2p/pubsub/pubsub.py b/libp2p/pubsub/pubsub.py index b162b89..17cf034 100644 --- a/libp2p/pubsub/pubsub.py +++ b/libp2p/pubsub/pubsub.py @@ -151,7 +151,7 @@ class Pubsub: messages from other nodes :param stream: stream to continously read from """ - peer_id = stream.mplex_conn.peer_id + peer_id = stream.muxed_conn.peer_id while True: incoming: bytes = await read_varint_prefixed_bytes(stream) diff --git a/libp2p/pubsub/pubsub_notifee.py b/libp2p/pubsub/pubsub_notifee.py index 6ecab1a..627152e 100644 --- a/libp2p/pubsub/pubsub_notifee.py +++ b/libp2p/pubsub/pubsub_notifee.py @@ -2,10 +2,10 @@ from typing import TYPE_CHECKING from multiaddr import Multiaddr +from libp2p.network.connection.net_connection_interface import INetConn from libp2p.network.network_interface import INetwork from libp2p.network.notifee_interface import INotifee from libp2p.network.stream.net_stream_interface import INetStream -from libp2p.stream_muxer.abc import IMuxedConn if TYPE_CHECKING: import asyncio # noqa: F401 @@ -29,16 +29,16 @@ class PubsubNotifee(INotifee): async def closed_stream(self, network: INetwork, stream: INetStream) -> None: pass - async def connected(self, network: INetwork, conn: IMuxedConn) -> None: + async def connected(self, network: INetwork, conn: INetConn) -> None: """ Add peer_id to initiator_peers_queue, so that this peer_id can be used to create a stream and we only want to have one pubsub stream with each peer. :param network: network the connection was opened on :param conn: connection that was opened """ - await self.initiator_peers_queue.put(conn.peer_id) + await self.initiator_peers_queue.put(conn.conn.peer_id) - async def disconnected(self, network: INetwork, conn: IMuxedConn) -> None: + async def disconnected(self, network: INetwork, conn: INetConn) -> None: pass async def listen(self, network: INetwork, multiaddr: Multiaddr) -> None: diff --git a/libp2p/stream_muxer/abc.py b/libp2p/stream_muxer/abc.py index 78438f2..4af110b 100644 --- a/libp2p/stream_muxer/abc.py +++ b/libp2p/stream_muxer/abc.py @@ -55,7 +55,7 @@ class IMuxedConn(ABC): class IMuxedStream(ReadWriteCloser): - mplex_conn: IMuxedConn + muxed_conn: IMuxedConn @abstractmethod async def reset(self) -> None: diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index 8cabccc..06b90fa 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -18,7 +18,7 @@ class MplexStream(IMuxedStream): name: str stream_id: StreamID - mplex_conn: "Mplex" + muxed_conn: "Mplex" read_deadline: int write_deadline: int @@ -32,15 +32,15 @@ class MplexStream(IMuxedStream): _buf: bytearray - def __init__(self, name: str, stream_id: StreamID, mplex_conn: "Mplex") -> None: + def __init__(self, name: str, stream_id: StreamID, muxed_conn: "Mplex") -> None: """ create new MuxedStream in muxer :param stream_id: stream id of this stream - :param mplex_conn: muxed connection of this muxed_stream + :param muxed_conn: muxed connection of this muxed_stream """ self.name = name self.stream_id = stream_id - self.mplex_conn = mplex_conn + self.muxed_conn = muxed_conn self.read_deadline = None self.write_deadline = None self.event_local_closed = asyncio.Event() @@ -147,7 +147,7 @@ class MplexStream(IMuxedStream): if self.is_initiator else HeaderTags.MessageReceiver ) - return await self.mplex_conn.send_message(flag, data, self.stream_id) + return await self.muxed_conn.send_message(flag, data, self.stream_id) async def close(self) -> None: """ @@ -163,8 +163,8 @@ class MplexStream(IMuxedStream): flag = ( HeaderTags.CloseInitiator if self.is_initiator else HeaderTags.CloseReceiver ) - # TODO: Raise when `mplex_conn.send_message` fails and `Mplex` isn't shutdown. - await self.mplex_conn.send_message(flag, None, self.stream_id) + # TODO: Raise when `muxed_conn.send_message` fails and `Mplex` isn't shutdown. + await self.muxed_conn.send_message(flag, None, self.stream_id) _is_remote_closed: bool async with self.close_lock: @@ -173,8 +173,8 @@ class MplexStream(IMuxedStream): if _is_remote_closed: # Both sides are closed, we can safely remove the buffer from the dict. - async with self.mplex_conn.streams_lock: - del self.mplex_conn.streams[self.stream_id] + async with self.muxed_conn.streams_lock: + del self.muxed_conn.streams[self.stream_id] async def reset(self) -> None: """ @@ -196,19 +196,19 @@ class MplexStream(IMuxedStream): else HeaderTags.ResetReceiver ) asyncio.ensure_future( - self.mplex_conn.send_message(flag, None, self.stream_id) + self.muxed_conn.send_message(flag, None, self.stream_id) ) await asyncio.sleep(0) self.event_local_closed.set() self.event_remote_closed.set() - async with self.mplex_conn.streams_lock: + async with self.muxed_conn.streams_lock: if ( - self.mplex_conn.streams is not None - and self.stream_id in self.mplex_conn.streams + self.muxed_conn.streams is not None + and self.stream_id in self.muxed_conn.streams ): - del self.mplex_conn.streams[self.stream_id] + del self.muxed_conn.streams[self.stream_id] # TODO deadline not in use def set_deadline(self, ttl: int) -> bool: diff --git a/tests/libp2p/test_notify.py b/tests/libp2p/test_notify.py index b9a8707..e47a044 100644 --- a/tests/libp2p/test_notify.py +++ b/tests/libp2p/test_notify.py @@ -34,7 +34,7 @@ class MyNotifee(INotifee): pass async def connected(self, network, conn): - self.events.append(["connected" + self.val_to_append_to_event, conn]) + self.events.append(["connected" + self.val_to_append_to_event, conn.conn]) async def disconnected(self, network, conn): pass @@ -79,7 +79,7 @@ async def test_one_notifier(): # Ensure the connected and opened_stream events were hit in MyNotifee obj # and that stream passed into opened_stream matches the stream created on # node_a - assert events == [["connected0", stream.mplex_conn], ["opened_stream0", stream]] + assert events == [["connected0", stream.muxed_conn], ["opened_stream0", stream]] messages = ["hello", "hello"] for message in messages: @@ -103,7 +103,7 @@ async def test_one_notifier_on_two_nodes(): # and that the stream passed into opened_stream matches the stream created on # node_b assert events_b == [ - ["connectedb", stream.mplex_conn], + ["connectedb", stream.muxed_conn], ["opened_streamb", stream], ] for message in messages: @@ -126,7 +126,7 @@ async def test_one_notifier_on_two_nodes(): # Ensure the connected and opened_stream events were hit in MyNotifee obj # and that stream passed into opened_stream matches the stream created on # node_a - assert events_a == [["connecteda", stream.mplex_conn], ["opened_streama", stream]] + assert events_a == [["connecteda", stream.muxed_conn], ["opened_streama", stream]] for message in messages: expected_resp = ACK + message @@ -164,7 +164,7 @@ async def test_one_notifier_on_two_nodes_with_listen(): # node_b assert events_b == [ ["listenedb", node_b_multiaddr], - ["connectedb", stream.mplex_conn], + ["connectedb", stream.muxed_conn], ["opened_streamb", stream], ] for message in messages: @@ -190,7 +190,7 @@ async def test_one_notifier_on_two_nodes_with_listen(): # Ensure the connected and opened_stream events were hit in MyNotifee obj # and that stream passed into opened_stream matches the stream created on # node_a - assert events_a == [["connecteda", stream.mplex_conn], ["opened_streama", stream]] + assert events_a == [["connecteda", stream.muxed_conn], ["opened_streama", stream]] for message in messages: expected_resp = ACK + message @@ -219,8 +219,8 @@ async def test_two_notifiers(): # Ensure the connected and opened_stream events were hit in both Notifee objs # and that the stream passed into opened_stream matches the stream created on # node_a - assert events0 == [["connected0", stream.mplex_conn], ["opened_stream0", stream]] - assert events1 == [["connected1", stream.mplex_conn], ["opened_stream1", stream]] + assert events0 == [["connected0", stream.muxed_conn], ["opened_stream0", stream]] + assert events1 == [["connected1", stream.muxed_conn], ["opened_stream1", stream]] messages = ["hello", "hello"] for message in messages: @@ -253,7 +253,7 @@ async def test_ten_notifiers(): # node_a for i in range(num_notifiers): assert events_lst[i] == [ - ["connected" + str(i), stream.mplex_conn], + ["connected" + str(i), stream.muxed_conn], ["opened_stream" + str(i), stream], ] @@ -280,7 +280,7 @@ async def test_ten_notifiers_on_two_nodes(): # node_b for i in range(num_notifiers): assert events_lst_b[i] == [ - ["connectedb" + str(i), stream.mplex_conn], + ["connectedb" + str(i), stream.muxed_conn], ["opened_streamb" + str(i), stream], ] while True: @@ -306,7 +306,7 @@ async def test_ten_notifiers_on_two_nodes(): # node_a for i in range(num_notifiers): assert events_lst_a[i] == [ - ["connecteda" + str(i), stream.mplex_conn], + ["connecteda" + str(i), stream.muxed_conn], ["opened_streama" + str(i), stream], ] diff --git a/tests/pubsub/test_pubsub.py b/tests/pubsub/test_pubsub.py index 3413949..29fdf36 100644 --- a/tests/pubsub/test_pubsub.py +++ b/tests/pubsub/test_pubsub.py @@ -233,7 +233,7 @@ class FakeNetStream: class FakeMplexConn(NamedTuple): peer_id: ID = ID(b"\x12\x20" + b"\x00" * 32) - mplex_conn = FakeMplexConn() + muxed_conn = FakeMplexConn() def __init__(self) -> None: self._queue = asyncio.Queue()