diff --git a/libp2p/network/stream/net_stream.py b/libp2p/network/stream/net_stream.py index 8cfb635..f708962 100644 --- a/libp2p/network/stream/net_stream.py +++ b/libp2p/network/stream/net_stream.py @@ -5,6 +5,7 @@ class NetStream(INetStream): def __init__(self, muxed_stream): self.muxed_stream = muxed_stream + self.mplex_conn = muxed_stream.mplex_conn self.protocol_id = None def get_protocol(self): diff --git a/tests/libp2p/test_notify.py b/tests/libp2p/test_notify.py index e9a62cd..3021cf1 100644 --- a/tests/libp2p/test_notify.py +++ b/tests/libp2p/test_notify.py @@ -33,7 +33,8 @@ class MyNotifee(INotifee): pass async def connected(self, network, conn): - self.events.append("connected" + self.val_to_append_to_event) + self.events.append(["connected" + self.val_to_append_to_event,\ + conn]) async def disconnected(self, network, conn): pass @@ -70,7 +71,8 @@ async def test_one_notifier(): # Ensure the connected and opened_stream events were hit in MyNotifee obj # and that stream passed into opened_stream matches the stream created on # node_a - assert events == ["connected0", ["opened_stream0", stream]] + assert events == [["connected0", stream.mplex_conn], \ + ["opened_stream0", stream]] messages = ["hello", "hello"] for message in messages: @@ -112,8 +114,8 @@ async def test_two_notifiers(): # Ensure the connected and opened_stream events were hit in both Notifee objs # and that the stream passed into opened_stream matches the stream created on # node_a - assert events0 == ["connected0", ["opened_stream0", stream]] - assert events1 == ["connected1", ["opened_stream1", stream]] + assert events0 == [["connected0", stream.mplex_conn], ["opened_stream0", stream]] + assert events1 == [["connected1", stream.mplex_conn], ["opened_stream1", stream]] messages = ["hello", "hello"] @@ -158,7 +160,7 @@ async def test_ten_notifiers(): # and that the stream passed into opened_stream matches the stream created on # node_a for i in range(num_notifiers): - assert events_lst[i] == ["connected" + str(i), \ + assert events_lst[i] == [["connected" + str(i), stream.mplex_conn], \ ["opened_stream" + str(i), stream]] messages = ["hello", "hello"]