Fix bug in security multistream

This commit is contained in:
Stuckinaboot 2019-04-30 16:07:26 -04:00
parent a0bd6e5eb0
commit e555f17a7b
2 changed files with 20 additions and 9 deletions

View File

@ -40,7 +40,7 @@ class SecurityMultistream(ABC):
""" """
# Select a secure transport # Select a secure transport
transport = await self.select_transport(conn, True) transport = await self.select_transport(conn, False)
# Create secured connection # Create secured connection
secure_conn = await transport.secure_inbound(conn) secure_conn = await transport.secure_inbound(conn)
@ -81,7 +81,6 @@ class SecurityMultistream(ABC):
protocol = await self.multiselect_client.select_one_of(list(self.transports.keys()), conn) protocol = await self.multiselect_client.select_one_of(list(self.transports.keys()), conn)
else: else:
# Select protocol if non-initiator # Select protocol if non-initiator
protocol = await self.multiselect.negotiate(conn) protocol, _ = await self.multiselect.negotiate(conn)
# Return transport from protocol # Return transport from protocol
return self.transports[protocol] return self.transports[protocol]

View File

@ -32,8 +32,8 @@ async def perform_simple_test(assertion_func, transports_for_initiator, transpor
# TODO: implement -- note we need to introduce the notion of communicating over a raw connection # TODO: implement -- note we need to introduce the notion of communicating over a raw connection
# for testing, we do NOT want to communicate over a stream so we can't just create two nodes # for testing, we do NOT want to communicate over a stream so we can't just create two nodes
# and use their conn because our mplex will internally relay messages to a stream # and use their conn because our mplex will internally relay messages to a stream
sec_opt1 = dict((str(i), transport) for i, transport in enumerate(transports_for_initiator)) sec_opt1 = transports_for_initiator
sec_opt2 = dict((str(i), transport) for i, transport in enumerate(transports_for_noninitiator)) sec_opt2 = transports_for_noninitiator
node1 = await new_node(transport_opt=["/ip4/127.0.0.1/tcp/0"], sec_opt=sec_opt1) node1 = await new_node(transport_opt=["/ip4/127.0.0.1/tcp/0"], sec_opt=sec_opt1)
node2 = await new_node(transport_opt=["/ip4/127.0.0.1/tcp/0"], sec_opt=sec_opt2) node2 = await new_node(transport_opt=["/ip4/127.0.0.1/tcp/0"], sec_opt=sec_opt2)
@ -62,8 +62,8 @@ async def perform_simple_test(assertion_func, transports_for_initiator, transpor
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_single_insecure_security_transport_succeeds(): async def test_single_insecure_security_transport_succeeds():
transports_for_initiator = [InsecureTransport("foo")] transports_for_initiator = {"foo": InsecureTransport("foo")}
transports_for_noninitiator = [InsecureTransport("foo")] transports_for_noninitiator = {"foo": InsecureTransport("foo")}
def assertion_func(details): def assertion_func(details):
assert details["id"] == "foo" assert details["id"] == "foo"
@ -73,8 +73,8 @@ async def test_single_insecure_security_transport_succeeds():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_single_simple_test_security_transport_succeeds(): async def test_single_simple_test_security_transport_succeeds():
transports_for_initiator = [SimpleSecurityTransport("tacos")] transports_for_initiator = {"tacos": SimpleSecurityTransport("tacos")}
transports_for_noninitiator = [SimpleSecurityTransport("tacos")] transports_for_noninitiator = {"tacos": SimpleSecurityTransport("tacos")}
def assertion_func(details): def assertion_func(details):
assert details["key_phrase"] == "tacos" assert details["key_phrase"] == "tacos"
@ -82,3 +82,15 @@ async def test_single_simple_test_security_transport_succeeds():
await perform_simple_test(assertion_func, await perform_simple_test(assertion_func,
transports_for_initiator, transports_for_noninitiator) transports_for_initiator, transports_for_noninitiator)
@pytest.mark.asyncio
async def test_two_simple_test_security_transport_for_initiator_succeeds():
transports_for_initiator = {"tacos": SimpleSecurityTransport("tacos"),
"shleep": SimpleSecurityTransport("shleep")}
transports_for_noninitiator = {"shleep": SimpleSecurityTransport("shleep")}
def assertion_func(details):
assert details["key_phrase"] == "shleep"
await perform_simple_test(assertion_func,
transports_for_initiator, transports_for_noninitiator)