Refactor HeaderTags

This commit is contained in:
Chih Cheng Liang 2019-08-02 17:14:43 +08:00
parent 29fbb9e40a
commit 36b7e8ded9
No known key found for this signature in database
GPG Key ID: C86B5E6612B1487A
4 changed files with 31 additions and 27 deletions

View File

@ -1 +1,11 @@
HEADER_TAGS = {"NEW_STREAM": 0, "MESSAGE": 2, "CLOSE": 4, "RESET": 6}
from enum import Enum
class HeaderTags(Enum):
NewStream = 0
MessageReceiver = 1
MessageInitiator = 2
CloseReceiver = 3
CloseInitiator = 4
ResetReceiver = 5
ResetInitiator = 6

View File

@ -1,8 +1,9 @@
import asyncio
from ..muxed_connection_interface import IMuxedConn
from .constants import HeaderTags
from .utils import encode_uvarint, decode_uvarint_from_stream
from .mplex_stream import MplexStream
from .utils import decode_uvarint_from_stream, encode_uvarint, get_flag
from ..muxed_connection_interface import IMuxedConn
class Mplex(IMuxedConn):
@ -78,7 +79,7 @@ class Mplex(IMuxedConn):
stream_id = self.raw_conn.next_stream_id()
stream = MplexStream(stream_id, multi_addr, self)
self.buffers[stream_id] = asyncio.Queue()
await self.send_message(get_flag(self.initiator, "NEW_STREAM"), None, stream_id)
await self.send_message(HeaderTags.NewStream, None, stream_id)
return stream
async def accept_stream(self):
@ -90,7 +91,7 @@ class Mplex(IMuxedConn):
stream = MplexStream(stream_id, False, self)
asyncio.ensure_future(self.generic_protocol_handler(stream))
async def send_message(self, flag, data, stream_id):
async def send_message(self, flag: HeaderTags, data, stream_id):
"""
sends a message over the connection
:param header: header to use
@ -99,7 +100,7 @@ class Mplex(IMuxedConn):
:return: True if success
"""
# << by 3, then or with flag
header = (stream_id << 3) | flag
header = (stream_id << 3) | flag.value
header = encode_uvarint(header)
if data is None:
@ -135,7 +136,7 @@ class Mplex(IMuxedConn):
self.buffers[stream_id] = asyncio.Queue()
await self.stream_queue.put(stream_id)
if flag is get_flag(True, "NEW_STREAM"):
if flag == HeaderTags.NewStream.value:
# new stream detected on connection
await self.accept_stream()

View File

@ -2,7 +2,7 @@ import asyncio
from libp2p.stream_muxer.muxed_stream_interface import IMuxedStream
from .utils import get_flag
from .constants import HeaderTags
class MplexStream(IMuxedStream):
@ -38,9 +38,12 @@ class MplexStream(IMuxedStream):
write to stream
:return: number of bytes written
"""
return await self.mplex_conn.send_message(
get_flag(self.initiator, "MESSAGE"), data, self.stream_id
flag = (
HeaderTags.MessageInitiator
if self.initiator
else HeaderTags.MessageReceiver
)
return await self.mplex_conn.send_message(flag, data, self.stream_id)
async def close(self):
"""
@ -50,7 +53,8 @@ class MplexStream(IMuxedStream):
"""
# TODO error handling with timeout
# TODO understand better how mutexes are used from go repo
await self.mplex_conn.send_message(get_flag(self.initiator, "CLOSE"), None, self.stream_id)
flag = HeaderTags.CloseInitiator if self.initiator else HeaderTags.CloseReceiver
await self.mplex_conn.send_message(flag, None, self.stream_id)
remote_lock = ""
async with self.stream_lock:
@ -78,9 +82,12 @@ class MplexStream(IMuxedStream):
return True
if not self.remote_closed:
await self.mplex_conn.send_message(
get_flag(self.initiator, "RESET"), None, self.stream_id
flag = (
HeaderTags.ResetInitiator
if self.initiator
else HeaderTags.ResetInitiator
)
await self.mplex_conn.send_message(flag, None, self.stream_id)
self.local_closed = True
self.remote_closed = True

View File

@ -1,8 +1,6 @@
import asyncio
import struct
from .constants import HEADER_TAGS
def encode_uvarint(number):
"""Pack `number` into varint bytes"""
@ -44,15 +42,3 @@ async def decode_uvarint_from_stream(reader, timeout):
break
return result
def get_flag(initiator, action):
"""
get header flag based on action for mplex
:param action: action type in str
:return: int flag
"""
if initiator or HEADER_TAGS[action] == 0:
return HEADER_TAGS[action]
return HEADER_TAGS[action] - 1