2019-12-17 18:17:28 +08:00
|
|
|
from types import TracebackType
|
|
|
|
from typing import AsyncIterator, Optional, Type
|
|
|
|
|
|
|
|
import trio
|
|
|
|
|
|
|
|
from .abc import ISubscriptionAPI
|
|
|
|
from .pb import rpc_pb2
|
2020-01-28 00:29:05 +08:00
|
|
|
from .typing import UnsubscribeFn
|
2019-12-17 18:17:28 +08:00
|
|
|
|
|
|
|
|
|
|
|
class BaseSubscriptionAPI(ISubscriptionAPI):
    """Async-context-manager support shared by subscription implementations.

    ``async with subscription:`` yields the subscription itself on entry and
    awaits ``self.unsubscribe()`` on exit. ``unsubscribe`` is not defined in
    this class; it is expected to be supplied by a subclass or declared on
    ``ISubscriptionAPI`` -- NOTE(review): confirm against ``.abc``.
    """

    async def __aenter__(self) -> "BaseSubscriptionAPI":
        """Enter the context and return this subscription.

        A trio checkpoint is executed first, so entering the context is a
        cancellation point and gives other tasks a chance to run.
        """
        # ``trio.sleep(0)`` is documented as equivalent to a pure checkpoint.
        # The previously used ``trio.hazmat`` namespace was renamed to
        # ``trio.lowlevel`` in trio 0.15 and later removed, whereas
        # ``trio.sleep`` is public API on every trio release.
        await trio.sleep(0)
        return self

    async def __aexit__(
        self,
        exc_type: "Optional[Type[BaseException]]",
        exc_value: "Optional[BaseException]",
        traceback: "Optional[TracebackType]",
    ) -> None:
        """Unsubscribe when leaving the ``async with`` block.

        Runs whether the block exited normally or by raising. Returns
        ``None`` (falsy), so an in-flight exception propagates rather than
        being suppressed.
        """
        await self.unsubscribe()
|
2019-12-17 18:17:28 +08:00
|
|
|
|
|
|
|
|
|
|
|
class TrioSubscriptionAPI(BaseSubscriptionAPI):
    """Subscription whose messages arrive on a trio memory receive channel.

    Messages are read from ``receive_channel``; unsubscribing is delegated to
    the ``unsubscribe_fn`` callback supplied by whoever created the object.
    """

    receive_channel: "trio.MemoryReceiveChannel[rpc_pb2.Message]"
    unsubscribe_fn: UnsubscribeFn

    def __init__(
        self,
        receive_channel: "trio.MemoryReceiveChannel[rpc_pb2.Message]",
        unsubscribe_fn: UnsubscribeFn,
    ) -> None:
        # mypy mishandles assigning a callable to an instance attribute
        # (https://github.com/python/mypy/issues/2427), hence the ignore.
        self.unsubscribe_fn = unsubscribe_fn  # type: ignore
        self.receive_channel = receive_channel

    async def get(self) -> rpc_pb2.Message:
        """Wait for and return the next message from the receive channel."""
        channel = self.receive_channel
        message = await channel.receive()
        return message

    def __aiter__(self) -> AsyncIterator[rpc_pb2.Message]:
        """Return the receive channel's own async iterator over messages."""
        channel = self.receive_channel
        return channel.__aiter__()

    async def unsubscribe(self) -> None:
        """Await the stored ``unsubscribe_fn`` callback."""
        # Same mypy limitation as in ``__init__`` (python/mypy#2427).
        await self.unsubscribe_fn()  # type: ignore
|