diff --git a/libp2p/protocol_muxer/multiselect.py b/libp2p/protocol_muxer/multiselect.py index c76e07f..13e6a0d 100644 --- a/libp2p/protocol_muxer/multiselect.py +++ b/libp2p/protocol_muxer/multiselect.py @@ -59,12 +59,18 @@ class Multiselect(IMultiselectMuxer): protocol = TProtocol(command) if protocol in self.handlers: # Tell counterparty we have decided on a protocol - await communicator.write(protocol) + try: + await communicator.write(protocol) + except MultiselectCommunicatorError as error: + raise MultiselectError(error) # Return the decided on protocol return protocol, self.handlers[protocol] # Tell counterparty this protocol was not found - await communicator.write(PROTOCOL_NOT_FOUND_MSG) + try: + await communicator.write(PROTOCOL_NOT_FOUND_MSG) + except MultiselectCommunicatorError as error: + raise MultiselectError(error) async def handshake(self, communicator: IMultiselectCommunicator) -> None: """ @@ -76,7 +82,10 @@ class Multiselect(IMultiselectMuxer): # TODO: Use format used by go repo for messages # Send our MULTISELECT_PROTOCOL_ID to other party - await communicator.write(MULTISELECT_PROTOCOL_ID) + try: + await communicator.write(MULTISELECT_PROTOCOL_ID) + except MultiselectCommunicatorError as error: + raise MultiselectError(error) # Read in the protocol ID from other party try: diff --git a/libp2p/protocol_muxer/multiselect_client.py b/libp2p/protocol_muxer/multiselect_client.py index 51af025..361100b 100644 --- a/libp2p/protocol_muxer/multiselect_client.py +++ b/libp2p/protocol_muxer/multiselect_client.py @@ -27,7 +27,10 @@ class MultiselectClient(IMultiselectClient): # TODO: Use format used by go repo for messages # Send our MULTISELECT_PROTOCOL_ID to counterparty - await communicator.write(MULTISELECT_PROTOCOL_ID) + try: + await communicator.write(MULTISELECT_PROTOCOL_ID) + except MultiselectCommunicatorError as error: + raise MultiselectClientError(error) # Read in the protocol ID from other party try: @@ -79,7 +82,10 @@ class MultiselectClient(IMultiselectClient): """ # Tell counterparty we want to use protocol - await communicator.write(protocol) + try: + await communicator.write(protocol) + except MultiselectCommunicatorError as error: + raise MultiselectClientError(error) # Get what counterparty says in response try: diff --git a/libp2p/protocol_muxer/multiselect_communicator.py b/libp2p/protocol_muxer/multiselect_communicator.py index d946850..6dbc50f 100644 --- a/libp2p/protocol_muxer/multiselect_communicator.py +++ b/libp2p/protocol_muxer/multiselect_communicator.py @@ -15,7 +15,12 @@ class MultiselectCommunicator(IMultiselectCommunicator): async def write(self, msg_str: str) -> None: msg_bytes = encode_delim(msg_str.encode()) - await self.read_writer.write(msg_bytes) + try: + await self.read_writer.write(msg_bytes) + except IOException: + raise MultiselectCommunicatorError( + "fail to write to multiselect communicator" + ) async def read(self) -> str: try: