Refactor HeaderTags
This commit is contained in:
parent
29fbb9e40a
commit
36b7e8ded9
|
@ -1 +1,11 @@
|
||||||
HEADER_TAGS = {"NEW_STREAM": 0, "MESSAGE": 2, "CLOSE": 4, "RESET": 6}
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class HeaderTags(Enum):
|
||||||
|
NewStream = 0
|
||||||
|
MessageReceiver = 1
|
||||||
|
MessageInitiator = 2
|
||||||
|
CloseReceiver = 3
|
||||||
|
CloseInitiator = 4
|
||||||
|
ResetReceiver = 5
|
||||||
|
ResetInitiator = 6
|
||||||
|
|
|
@ -1,8 +1,9 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
from ..muxed_connection_interface import IMuxedConn
|
from .constants import HeaderTags
|
||||||
|
from .utils import encode_uvarint, decode_uvarint_from_stream
|
||||||
from .mplex_stream import MplexStream
|
from .mplex_stream import MplexStream
|
||||||
from .utils import decode_uvarint_from_stream, encode_uvarint, get_flag
|
from ..muxed_connection_interface import IMuxedConn
|
||||||
|
|
||||||
|
|
||||||
class Mplex(IMuxedConn):
|
class Mplex(IMuxedConn):
|
||||||
|
@ -78,7 +79,7 @@ class Mplex(IMuxedConn):
|
||||||
stream_id = self.raw_conn.next_stream_id()
|
stream_id = self.raw_conn.next_stream_id()
|
||||||
stream = MplexStream(stream_id, multi_addr, self)
|
stream = MplexStream(stream_id, multi_addr, self)
|
||||||
self.buffers[stream_id] = asyncio.Queue()
|
self.buffers[stream_id] = asyncio.Queue()
|
||||||
await self.send_message(get_flag(self.initiator, "NEW_STREAM"), None, stream_id)
|
await self.send_message(HeaderTags.NewStream, None, stream_id)
|
||||||
return stream
|
return stream
|
||||||
|
|
||||||
async def accept_stream(self):
|
async def accept_stream(self):
|
||||||
|
@ -90,7 +91,7 @@ class Mplex(IMuxedConn):
|
||||||
stream = MplexStream(stream_id, False, self)
|
stream = MplexStream(stream_id, False, self)
|
||||||
asyncio.ensure_future(self.generic_protocol_handler(stream))
|
asyncio.ensure_future(self.generic_protocol_handler(stream))
|
||||||
|
|
||||||
async def send_message(self, flag, data, stream_id):
|
async def send_message(self, flag: HeaderTags, data, stream_id):
|
||||||
"""
|
"""
|
||||||
sends a message over the connection
|
sends a message over the connection
|
||||||
:param header: header to use
|
:param header: header to use
|
||||||
|
@ -99,7 +100,7 @@ class Mplex(IMuxedConn):
|
||||||
:return: True if success
|
:return: True if success
|
||||||
"""
|
"""
|
||||||
# << by 3, then or with flag
|
# << by 3, then or with flag
|
||||||
header = (stream_id << 3) | flag
|
header = (stream_id << 3) | flag.value
|
||||||
header = encode_uvarint(header)
|
header = encode_uvarint(header)
|
||||||
|
|
||||||
if data is None:
|
if data is None:
|
||||||
|
@ -135,7 +136,7 @@ class Mplex(IMuxedConn):
|
||||||
self.buffers[stream_id] = asyncio.Queue()
|
self.buffers[stream_id] = asyncio.Queue()
|
||||||
await self.stream_queue.put(stream_id)
|
await self.stream_queue.put(stream_id)
|
||||||
|
|
||||||
if flag is get_flag(True, "NEW_STREAM"):
|
if flag == HeaderTags.NewStream.value:
|
||||||
# new stream detected on connection
|
# new stream detected on connection
|
||||||
await self.accept_stream()
|
await self.accept_stream()
|
||||||
|
|
||||||
|
|
|
@ -2,7 +2,7 @@ import asyncio
|
||||||
|
|
||||||
from libp2p.stream_muxer.muxed_stream_interface import IMuxedStream
|
from libp2p.stream_muxer.muxed_stream_interface import IMuxedStream
|
||||||
|
|
||||||
from .utils import get_flag
|
from .constants import HeaderTags
|
||||||
|
|
||||||
|
|
||||||
class MplexStream(IMuxedStream):
|
class MplexStream(IMuxedStream):
|
||||||
|
@ -38,9 +38,12 @@ class MplexStream(IMuxedStream):
|
||||||
write to stream
|
write to stream
|
||||||
:return: number of bytes written
|
:return: number of bytes written
|
||||||
"""
|
"""
|
||||||
return await self.mplex_conn.send_message(
|
flag = (
|
||||||
get_flag(self.initiator, "MESSAGE"), data, self.stream_id
|
HeaderTags.MessageInitiator
|
||||||
|
if self.initiator
|
||||||
|
else HeaderTags.MessageReceiver
|
||||||
)
|
)
|
||||||
|
return await self.mplex_conn.send_message(flag, data, self.stream_id)
|
||||||
|
|
||||||
async def close(self):
|
async def close(self):
|
||||||
"""
|
"""
|
||||||
|
@ -50,7 +53,8 @@ class MplexStream(IMuxedStream):
|
||||||
"""
|
"""
|
||||||
# TODO error handling with timeout
|
# TODO error handling with timeout
|
||||||
# TODO understand better how mutexes are used from go repo
|
# TODO understand better how mutexes are used from go repo
|
||||||
await self.mplex_conn.send_message(get_flag(self.initiator, "CLOSE"), None, self.stream_id)
|
flag = HeaderTags.CloseInitiator if self.initiator else HeaderTags.CloseReceiver
|
||||||
|
await self.mplex_conn.send_message(flag, None, self.stream_id)
|
||||||
|
|
||||||
remote_lock = ""
|
remote_lock = ""
|
||||||
async with self.stream_lock:
|
async with self.stream_lock:
|
||||||
|
@ -78,9 +82,12 @@ class MplexStream(IMuxedStream):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if not self.remote_closed:
|
if not self.remote_closed:
|
||||||
await self.mplex_conn.send_message(
|
flag = (
|
||||||
get_flag(self.initiator, "RESET"), None, self.stream_id
|
HeaderTags.ResetInitiator
|
||||||
|
if self.initiator
|
||||||
|
else HeaderTags.ResetInitiator
|
||||||
)
|
)
|
||||||
|
await self.mplex_conn.send_message(flag, None, self.stream_id)
|
||||||
|
|
||||||
self.local_closed = True
|
self.local_closed = True
|
||||||
self.remote_closed = True
|
self.remote_closed = True
|
||||||
|
|
|
@ -1,8 +1,6 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import struct
|
import struct
|
||||||
|
|
||||||
from .constants import HEADER_TAGS
|
|
||||||
|
|
||||||
|
|
||||||
def encode_uvarint(number):
|
def encode_uvarint(number):
|
||||||
"""Pack `number` into varint bytes"""
|
"""Pack `number` into varint bytes"""
|
||||||
|
@ -44,15 +42,3 @@ async def decode_uvarint_from_stream(reader, timeout):
|
||||||
break
|
break
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def get_flag(initiator, action):
|
|
||||||
"""
|
|
||||||
get header flag based on action for mplex
|
|
||||||
:param action: action type in str
|
|
||||||
:return: int flag
|
|
||||||
"""
|
|
||||||
if initiator or HEADER_TAGS[action] == 0:
|
|
||||||
return HEADER_TAGS[action]
|
|
||||||
|
|
||||||
return HEADER_TAGS[action] - 1
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user