Add reset and close

This commit is contained in:
mhchia 2019-09-05 23:44:22 +08:00
parent 10415cb956
commit 207fa75d8f
No known key found for this signature in database
GPG Key ID: 389EFBEA1362589A
2 changed files with 45 additions and 6 deletions

View File

@ -82,11 +82,10 @@ class IMuxedStream(ReadWriteCloser):
mplex_conn: IMuxedConn
@abstractmethod
async def reset(self) -> bool:
async def reset(self) -> None:
"""
closes both ends of the stream
tells this remote side to hang up
:return: true if successful
"""
@abstractmethod

View File

@ -173,6 +173,7 @@ class Mplex(IMuxedConn):
if flag == HeaderTags.NewStream.value:
if is_stream_id_seen:
# `NewStream` for the same id is received twice...
# TODO: Shutdown
pass
await self.accept_stream(stream_id, message.decode())
elif flag in (
@ -187,10 +188,49 @@ class Mplex(IMuxedConn):
async with self.streams_lock:
stream = self.streams[stream_id]
await stream.incoming_data.put(message)
# elif flag in (
# HeaderTags.CloseInitiator.value,
# HeaderTags.CloseReceiver.value
# ):
elif flag in (
HeaderTags.CloseInitiator.value,
HeaderTags.CloseReceiver.value,
):
if not is_stream_id_seen:
continue
stream: MplexStream
async with self.streams_lock:
stream = self.streams[stream_id]
is_local_closed: bool
async with stream.close_lock:
stream.event_remote_closed.set()
is_local_closed = stream.event_local_closed.is_set()
# If local is also closed, both sides are closed. Then, we should clean up
# this stream.
if is_local_closed:
async with self.streams_lock:
del self.streams[stream_id]
elif flag in (
HeaderTags.ResetInitiator.value,
HeaderTags.ResetReceiver.value,
):
if not is_stream_id_seen:
# This is *ok*. We forget the stream on reset.
continue
stream: MplexStream
async with self.streams_lock:
stream = self.streams[stream_id]
async with stream.close_lock:
if not stream.event_remote_closed.is_set():
stream.event_reset.set()
stream.event_remote_closed.set()
if not stream.event_local_closed.is_set():
stream.event_local_closed.close()
async with self.streams_lock:
del self.streams[stream_id]
else:
# TODO: logging
print(f"message with unknown header on stream {stream_id}")
if is_stream_id_seen:
async with self.streams_lock:
stream = self.streams[stream_id]
await stream.reset()
# Force context switch
await asyncio.sleep(0)