Refactor HeaderTags
This commit is contained in:
parent
29fbb9e40a
commit
36b7e8ded9
|
@ -1 +1,11 @@
|
|||
HEADER_TAGS = {"NEW_STREAM": 0, "MESSAGE": 2, "CLOSE": 4, "RESET": 6}
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class HeaderTags(Enum):
|
||||
NewStream = 0
|
||||
MessageReceiver = 1
|
||||
MessageInitiator = 2
|
||||
CloseReceiver = 3
|
||||
CloseInitiator = 4
|
||||
ResetReceiver = 5
|
||||
ResetInitiator = 6
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
import asyncio
|
||||
|
||||
from ..muxed_connection_interface import IMuxedConn
|
||||
from .constants import HeaderTags
|
||||
from .utils import encode_uvarint, decode_uvarint_from_stream
|
||||
from .mplex_stream import MplexStream
|
||||
from .utils import decode_uvarint_from_stream, encode_uvarint, get_flag
|
||||
from ..muxed_connection_interface import IMuxedConn
|
||||
|
||||
|
||||
class Mplex(IMuxedConn):
|
||||
|
@ -78,7 +79,7 @@ class Mplex(IMuxedConn):
|
|||
stream_id = self.raw_conn.next_stream_id()
|
||||
stream = MplexStream(stream_id, multi_addr, self)
|
||||
self.buffers[stream_id] = asyncio.Queue()
|
||||
await self.send_message(get_flag(self.initiator, "NEW_STREAM"), None, stream_id)
|
||||
await self.send_message(HeaderTags.NewStream, None, stream_id)
|
||||
return stream
|
||||
|
||||
async def accept_stream(self):
|
||||
|
@ -90,7 +91,7 @@ class Mplex(IMuxedConn):
|
|||
stream = MplexStream(stream_id, False, self)
|
||||
asyncio.ensure_future(self.generic_protocol_handler(stream))
|
||||
|
||||
async def send_message(self, flag, data, stream_id):
|
||||
async def send_message(self, flag: HeaderTags, data, stream_id):
|
||||
"""
|
||||
sends a message over the connection
|
||||
:param header: header to use
|
||||
|
@ -99,7 +100,7 @@ class Mplex(IMuxedConn):
|
|||
:return: True if success
|
||||
"""
|
||||
# << by 3, then or with flag
|
||||
header = (stream_id << 3) | flag
|
||||
header = (stream_id << 3) | flag.value
|
||||
header = encode_uvarint(header)
|
||||
|
||||
if data is None:
|
||||
|
@ -135,7 +136,7 @@ class Mplex(IMuxedConn):
|
|||
self.buffers[stream_id] = asyncio.Queue()
|
||||
await self.stream_queue.put(stream_id)
|
||||
|
||||
if flag is get_flag(True, "NEW_STREAM"):
|
||||
if flag == HeaderTags.NewStream.value:
|
||||
# new stream detected on connection
|
||||
await self.accept_stream()
|
||||
|
||||
|
|
|
@ -2,7 +2,7 @@ import asyncio
|
|||
|
||||
from libp2p.stream_muxer.muxed_stream_interface import IMuxedStream
|
||||
|
||||
from .utils import get_flag
|
||||
from .constants import HeaderTags
|
||||
|
||||
|
||||
class MplexStream(IMuxedStream):
|
||||
|
@ -38,9 +38,12 @@ class MplexStream(IMuxedStream):
|
|||
write to stream
|
||||
:return: number of bytes written
|
||||
"""
|
||||
return await self.mplex_conn.send_message(
|
||||
get_flag(self.initiator, "MESSAGE"), data, self.stream_id
|
||||
flag = (
|
||||
HeaderTags.MessageInitiator
|
||||
if self.initiator
|
||||
else HeaderTags.MessageReceiver
|
||||
)
|
||||
return await self.mplex_conn.send_message(flag, data, self.stream_id)
|
||||
|
||||
async def close(self):
|
||||
"""
|
||||
|
@ -50,7 +53,8 @@ class MplexStream(IMuxedStream):
|
|||
"""
|
||||
# TODO error handling with timeout
|
||||
# TODO understand better how mutexes are used from go repo
|
||||
await self.mplex_conn.send_message(get_flag(self.initiator, "CLOSE"), None, self.stream_id)
|
||||
flag = HeaderTags.CloseInitiator if self.initiator else HeaderTags.CloseReceiver
|
||||
await self.mplex_conn.send_message(flag, None, self.stream_id)
|
||||
|
||||
remote_lock = ""
|
||||
async with self.stream_lock:
|
||||
|
@ -78,9 +82,12 @@ class MplexStream(IMuxedStream):
|
|||
return True
|
||||
|
||||
if not self.remote_closed:
|
||||
await self.mplex_conn.send_message(
|
||||
get_flag(self.initiator, "RESET"), None, self.stream_id
|
||||
flag = (
|
||||
HeaderTags.ResetInitiator
|
||||
if self.initiator
|
||||
else HeaderTags.ResetInitiator
|
||||
)
|
||||
await self.mplex_conn.send_message(flag, None, self.stream_id)
|
||||
|
||||
self.local_closed = True
|
||||
self.remote_closed = True
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
import asyncio
|
||||
import struct
|
||||
|
||||
from .constants import HEADER_TAGS
|
||||
|
||||
|
||||
def encode_uvarint(number):
|
||||
"""Pack `number` into varint bytes"""
|
||||
|
@ -44,15 +42,3 @@ async def decode_uvarint_from_stream(reader, timeout):
|
|||
break
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def get_flag(initiator, action):
|
||||
"""
|
||||
get header flag based on action for mplex
|
||||
:param action: action type in str
|
||||
:return: int flag
|
||||
"""
|
||||
if initiator or HEADER_TAGS[action] == 0:
|
||||
return HEADER_TAGS[action]
|
||||
|
||||
return HEADER_TAGS[action] - 1
|
||||
|
|
Loading…
Reference in New Issue
Block a user