"""
Distributor
===========
A collection of async messaging tools which is used by our GraphQL schema.
"""
import logging
import uuid
from dataclasses import asdict, dataclass, field
from typing import AsyncGenerator, Awaitable, Callable, List, Optional, Union
from channels.layers import get_channel_layer
from channels_redis.core import RedisChannelLayer
from strawberry.channels import GraphQLWSConsumer
log = logging.getLogger(__name__)
[docs]def uuid_to_group(u: Union[uuid.UUID, str]) -> str:
"""Channel group names are not allow
to have ``-``, so we replace them with ``_``.
"""
return str(u).replace("-", "_")
[docs]class MissingChannelLayer(Exception):
pass
[docs]class GraphQLWSConsumerInjector(GraphQLWSConsumer):
"""Allows us to inject callbacks on e.g. a disconnect.
.. todo::
This can be made obsolete via
https://github.com/strawberry-graphql/strawberry/pull/2430
"""
[docs] def __init__(self, *args, **kwargs):
self.disconnect_callback: Optional[Callable[[], Awaitable[None]]] = None
self.receive_callback: Optional[Callable] = None
super().__init__(*args, **kwargs)
[docs] async def websocket_disconnect(self, message):
if self.disconnect_callback:
await self.disconnect_callback()
return await super().websocket_disconnect(message)
[docs] async def receive(self, *args, **kwargs) -> None:
if self.receive_callback:
await self.receive_callback(*args, **kwargs)
await super().receive(*args, **kwargs)
[docs]class GenCasterChannel:
"""
Abstraction layer for channels.
Publish and subscribe to specific updates or more general ones as well.
"""
GRAPH_UPDATE_TYPE = "graph.update"
NODE_UPDATE_TYPE = "node.update"
STREAM_LOG_UPDATE_TYPE = "stream_log.update"
STREAMS_UPDATE_TYPE = "streams.update"
[docs] def __init__(self) -> None:
pass
@staticmethod
def _get_layer() -> RedisChannelLayer:
if layer := get_channel_layer():
return layer
raise Exception("Could not obtain redis channel layer")
@staticmethod
async def send_graph_update(graph_uuid: uuid.UUID):
return await GenCasterChannel.send_message(
layer=GenCasterChannel._get_layer(),
message=GraphUpdateMessage(uuid=str(graph_uuid)),
)
@staticmethod
async def send_node_update(node_uuid: uuid.UUID):
return await GenCasterChannel.send_message(
layer=GenCasterChannel._get_layer(),
message=NodeUpdateMessage(uuid=str(node_uuid)),
)
@staticmethod
async def send_log_update(stream_log_message: "StreamLogUpdateMessage"):
return await GenCasterChannel.send_message(
layer=GenCasterChannel._get_layer(), message=stream_log_message
)
@staticmethod
async def send_streams_update(stream_uuid: str):
return await GenCasterChannel.send_message(
layer=GenCasterChannel._get_layer(),
message=StreamsUpdateMessage(uuid=str(stream_uuid)),
)
@staticmethod
async def send_message(
layer: RedisChannelLayer,
message: Union[
"GraphUpdateMessage",
"NodeUpdateMessage",
"StreamLogUpdateMessage",
"StreamsUpdateMessage",
],
):
for channel in message.channels:
await layer.group_send(channel, asdict(message))
@staticmethod
async def receive_graph_updates(
consumer: GraphQLWSConsumer, graph_uuid: uuid.UUID
) -> AsyncGenerator["GraphUpdateMessage", None]:
group_name = uuid_to_group(graph_uuid)
if not consumer.channel_layer:
raise MissingChannelLayer()
await consumer.channel_layer.group_add(group_name, consumer.channel_name)
async for message in consumer.channel_listen(
GenCasterChannel.GRAPH_UPDATE_TYPE, groups=[group_name]
):
yield GraphUpdateMessage(**message)
@staticmethod
async def receive_node_updates(
consumer: GraphQLWSConsumer,
node_uuid: uuid.UUID,
) -> AsyncGenerator["NodeUpdateMessage", None]:
group_name = uuid_to_group(node_uuid)
if not consumer.channel_layer:
raise MissingChannelLayer()
await consumer.channel_layer.group_add(group_name, consumer.channel_name)
async for message in consumer.channel_listen(
GenCasterChannel.NODE_UPDATE_TYPE,
groups=[group_name],
):
yield NodeUpdateMessage(**message)
@staticmethod
async def receive_stream_log_updates(
consumer: GraphQLWSConsumer,
) -> AsyncGenerator["StreamLogUpdateMessage", None]:
group_name = GenCasterChannel.STREAM_LOG_UPDATE_TYPE
if not consumer.channel_layer:
raise MissingChannelLayer()
await consumer.channel_layer.group_add(group_name, consumer.channel_name)
async for message in consumer.channel_listen(
GenCasterChannel.STREAM_LOG_UPDATE_TYPE,
groups=[group_name],
):
yield StreamLogUpdateMessage(**message)
@staticmethod
async def receive_streams_updates(
consumer: GraphQLWSConsumer,
) -> AsyncGenerator["StreamsUpdateMessage", None]:
group_name = GenCasterChannel.STREAMS_UPDATE_TYPE
if not consumer.channel_layer:
raise MissingChannelLayer()
await consumer.channel_layer.group_add(group_name, consumer.channel_name)
async for message in consumer.channel_listen(
GenCasterChannel.STREAMS_UPDATE_TYPE,
groups=[group_name],
):
yield StreamsUpdateMessage(**message)
[docs]@dataclass
class GraphUpdateMessage:
# we can not transfer an UUID via redis so we encode it as string early
uuid: str
type: str = GenCasterChannel.GRAPH_UPDATE_TYPE
additional_channels: List[str] = field(default_factory=list)
@property
def channels(self) -> List[str]:
return [uuid_to_group(self.uuid)] + self.additional_channels
[docs]@dataclass
class NodeUpdateMessage:
uuid: str
type: str = GenCasterChannel.NODE_UPDATE_TYPE
additional_channels: List[str] = field(default_factory=list)
@property
def channels(self) -> List[str]:
return [uuid_to_group(self.uuid)] + self.additional_channels
[docs]@dataclass
class StreamLogUpdateMessage:
# todo if a str is inserted here it will fail
uuid: str
stream_point_uuid: Optional[str]
stream_uuid: Optional[str]
type: str = GenCasterChannel.STREAM_LOG_UPDATE_TYPE
additional_channels: List[str] = field(default_factory=list)
@property
def channels(self) -> List[str]:
return [GenCasterChannel.STREAM_LOG_UPDATE_TYPE] + self.additional_channels
[docs]@dataclass
class StreamsUpdateMessage:
uuid: str
type: str = GenCasterChannel.STREAMS_UPDATE_TYPE
additional_channels: List[str] = field(default_factory=list)
@property
def channels(self) -> List[str]:
return [GenCasterChannel.STREAMS_UPDATE_TYPE] + self.additional_channels