diff --git a/libp2p/stream_muxer/mplex/constants.py b/libp2p/stream_muxer/mplex/constants.py index a0537b2..8989e76 100644 --- a/libp2p/stream_muxer/mplex/constants.py +++ b/libp2p/stream_muxer/mplex/constants.py @@ -1 +1,11 @@ -HEADER_TAGS = {"NEW_STREAM": 0, "MESSAGE": 2, "CLOSE": 4, "RESET": 6} +from enum import Enum + + +class HeaderTags(Enum): + NewStream = 0 + MessageReceiver = 1 + MessageInitiator = 2 + CloseReceiver = 3 + CloseInitiator = 4 + ResetReceiver = 5 + ResetInitiator = 6 diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index f00588b..08970ff 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -1,8 +1,9 @@ import asyncio -from ..muxed_connection_interface import IMuxedConn +from .constants import HeaderTags +from .utils import encode_uvarint, decode_uvarint_from_stream from .mplex_stream import MplexStream -from .utils import decode_uvarint_from_stream, encode_uvarint, get_flag +from ..muxed_connection_interface import IMuxedConn class Mplex(IMuxedConn): @@ -78,7 +79,7 @@ class Mplex(IMuxedConn): stream_id = self.raw_conn.next_stream_id() stream = MplexStream(stream_id, multi_addr, self) self.buffers[stream_id] = asyncio.Queue() - await self.send_message(get_flag(self.initiator, "NEW_STREAM"), None, stream_id) + await self.send_message(HeaderTags.NewStream, None, stream_id) return stream async def accept_stream(self): @@ -90,7 +91,7 @@ class Mplex(IMuxedConn): stream = MplexStream(stream_id, False, self) asyncio.ensure_future(self.generic_protocol_handler(stream)) - async def send_message(self, flag, data, stream_id): + async def send_message(self, flag: HeaderTags, data, stream_id): """ sends a message over the connection :param header: header to use @@ -99,7 +100,7 @@ class Mplex(IMuxedConn): :return: True if success """ # << by 3, then or with flag - header = (stream_id << 3) | flag + header = (stream_id << 3) | flag.value header = encode_uvarint(header) if data is None: @@ -135,7 +136,7 @@ class Mplex(IMuxedConn): self.buffers[stream_id] = asyncio.Queue() await self.stream_queue.put(stream_id) - if flag is get_flag(True, "NEW_STREAM"): + if flag == HeaderTags.NewStream.value: # new stream detected on connection await self.accept_stream() diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index a1a25b7..e452fbd 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -2,7 +2,7 @@ import asyncio from libp2p.stream_muxer.muxed_stream_interface import IMuxedStream -from .utils import get_flag +from .constants import HeaderTags class MplexStream(IMuxedStream): @@ -38,9 +38,12 @@ class MplexStream(IMuxedStream): write to stream :return: number of bytes written """ - return await self.mplex_conn.send_message( - get_flag(self.initiator, "MESSAGE"), data, self.stream_id + flag = ( + HeaderTags.MessageInitiator + if self.initiator + else HeaderTags.MessageReceiver ) + return await self.mplex_conn.send_message(flag, data, self.stream_id) async def close(self): """ @@ -50,7 +53,8 @@ class MplexStream(IMuxedStream): """ # TODO error handling with timeout # TODO understand better how mutexes are used from go repo - await self.mplex_conn.send_message(get_flag(self.initiator, "CLOSE"), None, self.stream_id) + flag = HeaderTags.CloseInitiator if self.initiator else HeaderTags.CloseReceiver + await self.mplex_conn.send_message(flag, None, self.stream_id) remote_lock = "" async with self.stream_lock: @@ -78,9 +82,12 @@ class MplexStream(IMuxedStream): return True if not self.remote_closed: - await self.mplex_conn.send_message( - get_flag(self.initiator, "RESET"), None, self.stream_id + flag = ( + HeaderTags.ResetInitiator + if self.initiator + else HeaderTags.ResetInitiator ) + await self.mplex_conn.send_message(flag, None, self.stream_id) self.local_closed = True self.remote_closed = True diff --git a/libp2p/stream_muxer/mplex/utils.py b/libp2p/stream_muxer/mplex/utils.py index 70dcd12..7bd5772 100644 --- a/libp2p/stream_muxer/mplex/utils.py +++ b/libp2p/stream_muxer/mplex/utils.py @@ -1,8 +1,6 @@ import asyncio import struct -from .constants import HEADER_TAGS - def encode_uvarint(number): """Pack `number` into varint bytes""" @@ -44,15 +42,3 @@ async def decode_uvarint_from_stream(reader, timeout): break return result - - -def get_flag(initiator, action): - """ - get header flag based on action for mplex - :param action: action type in str - :return: int flag - """ - if initiator or HEADER_TAGS[action] == 0: - return HEADER_TAGS[action] - - return HEADER_TAGS[action] - 1