Source code for gencaster.schema

"""
.. _schema:

Schema
======

Here we define all the endpoints for GraphQL.

For specific details of the types, consult the
`GraphiQL <https://github.com/graphql/graphiql>`_
page available under the `/graphql` endpoint of
the running backend.

Subscription updates are published via Redis and
handled via channels through the abstraction layer
:class:`~gencaster.distributor.GenCasterChannel`.
"""

import json
import logging
import os
import uuid
from typing import Any, AsyncGenerator, Dict, List, Optional

import strawberry
import strawberry.django
from asgiref.sync import sync_to_async
from django.contrib.auth import authenticate, login, logout
from django.contrib.auth.models import AbstractBaseUser
from django.contrib.auth.models import User as UserModel
from django.core.exceptions import PermissionDenied
from django.core.files import File
from django.http.request import HttpRequest
from strawberry import UNSET, auto
from strawberry.types import Info
from strawberry_django.fields.field import StrawberryDjangoField

import story_graph.models as story_graph_models
import stream.models as stream_models
from story_graph.engine import Engine
from story_graph.types import (
    AddGraphInput,
    AudioCellInput,
    Edge,
    EdgeInput,
    Graph,
    GraphFilter,
    InvalidPythonCode,
    Node,
    NodeCreate,
    NodeDoor,
    NodeDoorInputCreate,
    NodeDoorInputUpdate,
    NodeDoorResponse,
    NodeUpdate,
    ScriptCell,
    ScriptCellInputCreate,
    ScriptCellInputUpdate,
    UpdateGraphInput,
    create_python_highlight_string,
)
from stream.exceptions import NoStreamAvailableException
from stream.frontend_types import Dialog
from stream.types import (
    AddAudioFile,
    AudioFile,
    AudioFileUploadResponse,
    GraphDeadEnd,
    InvalidAudioFile,
    NoStreamAvailable,
    Stream,
    StreamInfo,
    StreamInfoResponse,
    StreamLog,
    StreamPoint,
    StreamVariable,
    StreamVariableInput,
    UpdateAudioFile,
)

from . import db_logging
from .distributor import GenCasterChannel, GraphQLWSConsumerInjector

log = logging.getLogger(__name__)


class IsAuthenticated(strawberry.BasePermission):
    """Strawberry permission class which only admits logged in users.

    ``message`` is the text attached to the permission when access is denied.
    """

    message = "User is not authenticated"

    async def has_permission(self, source: Any, info: Info, **kwargs) -> bool:
        """Return ``True`` only if the user of the current request is authenticated.

        The former implementation started with an unconditional ``return True``
        which made the actual check unreachable and therefore granted access
        to every request.

        NOTE(review): fields guarded by this permission (e.g.
        ``Query.is_authenticated``) will now reject anonymous requests instead
        of resolving - confirm that clients handle the permission error.
        """
        # The attribute access is wrapped in sync_to_async, presumably because
        # resolving ``request.user`` may hit the synchronous session/db backend.
        return bool(
            await sync_to_async(lambda: info.context.request.user.is_authenticated)()  # type: ignore
        )
class AuthStrawberryDjangoField(StrawberryDjangoField):
    """Django field which can only be resolved by logged in users."""

    def resolver(self, info: Info, source, **kwargs):
        """Resolve the field via the parent class once the user is verified.

        :raises PermissionDenied: if the user of the request is not authenticated.
        """
        request: HttpRequest = info.context.request
        if request.user.is_authenticated:
            return super().resolver(info, source, **kwargs)
        raise PermissionDenied()
async def graphql_check_authenticated(info: Info):
    """Helper function to determine in an async manner if we are logged in.

    This would be better as a decorator but strawberry is not nice in
    these regards, see
    `Stack Overflow <https://stackoverflow.com/a/72796313/3475778>`_.

    :raises PermissionDenied: if the user of the request is not authenticated.
    """

    def _user_is_authenticated():
        return info.context.request.user.is_authenticated

    # the lookup is executed in a sync context via sync_to_async
    is_authenticated = await sync_to_async(_user_is_authenticated)()  # type: ignore
    if is_authenticated is False:
        raise PermissionDenied()
async def update_or_create_audio_cell(
    audio_cell_input: Optional[AudioCellInput],
) -> Optional[story_graph_models.AudioCell]:
    """Async function to update an audio cell or create it if it does not exist.

    The audio cell is looked up via the UUID of the input. Returns ``None``
    if no (or a falsy) input was handed over.
    """
    if not audio_cell_input:
        return None

    field_values = {
        "playback": audio_cell_input.playback,
        "audio_file_id": audio_cell_input.audio_file.uuid,
        "volume": audio_cell_input.volume,
    }
    result = await story_graph_models.AudioCell.objects.aupdate_or_create(
        uuid=audio_cell_input.uuid,
        defaults=field_values,
    )
    audio_cell, was_created = result
    if was_created:
        # @todo access .uuid directly to avoid fk access
        # in async mode
        log.debug(f"Created audio cell {audio_cell.uuid}")
    return audio_cell
@strawberry.django.type(UserModel)
class User:
    """GraphQL type built from Django's ``User`` model.

    Only the fields listed below are declared on the type; their
    GraphQL types are derived from the model via ``auto``.
    """

    username: auto
    is_staff: auto
    is_active: auto
    first_name: auto
    last_name: auto
    email: auto
@strawberry.type
class Query:
    """Queries for Gencaster.

    Fields declared via :class:`AuthStrawberryDjangoField` are restricted
    to logged in users.
    """

    stream_point: StreamPoint = strawberry.django.field()
    stream_points: List[StreamPoint] = strawberry.django.field()
    graphs: List[Graph] = strawberry.django.field(filters=GraphFilter)
    graph: Graph = AuthStrawberryDjangoField()
    nodes: List[Node] = AuthStrawberryDjangoField()
    node: Node = AuthStrawberryDjangoField()
    audio_files: List[AudioFile] = AuthStrawberryDjangoField()
    audio_file: AudioFile = AuthStrawberryDjangoField()
    stream_variable: StreamVariable = AuthStrawberryDjangoField()

    @strawberry.field(permission_classes=[IsAuthenticated])
    async def is_authenticated(self, info) -> Optional[User]:
        """Return the user of the current request or ``None`` if it is anonymous."""
        request = info.context.request
        # type issue https://github.com/python/mypy/issues/9590
        is_anonymous = await sync_to_async(lambda: request.user.is_anonymous)()  # type: ignore
        if is_anonymous:
            return None
        return request.user  # type: ignore
@strawberry.type
class LoginError:
    """GraphQL type returned by ``Mutation.auth_login`` when the login did not succeed."""

    # human readable reason for the failure
    error_message: Optional[str] = None
# Result of a login attempt: either the logged in user or an error.
LoginRequest = strawberry.union(
    "LoginRequestResponse",
    types=[LoginError, User],
)
def _is_provided(value: Any) -> bool:
    """Return ``True`` if an optional input value was explicitly supplied.

    In contrast to a truthiness check this keeps legitimate falsy values
    such as ``0`` and only rejects ``None`` and ``UNSET``.
    """
    return value is not None and value is not UNSET


@strawberry.type
class Mutation:
    """Mutations for Gencaster via GraphQL."""

    @strawberry.mutation
    async def auth_login(self, info, username: str, password: str) -> LoginRequest:  # type: ignore
        """Log a user in via username and password.

        Returns the user on success, otherwise a :class:`LoginError`.
        """
        # user type is Optional[AbstractBaseUser] but we return user which is similar
        user: Optional[AbstractBaseUser]
        try:
            user = await sync_to_async(authenticate)(
                request=info.context.request,
                username=username,
                password=password,
            )
        except PermissionDenied as e:
            return LoginError(
                error_message=str(e),
            )
        if user is None:
            return LoginError(
                error_message="Wrong credentials",
            )
        await sync_to_async(login)(info.context.request, user)
        return user  # type: ignore

    @strawberry.mutation
    async def auth_logout(self, info) -> bool:
        """Log out the user of the current request. Always returns ``True``."""
        await sync_to_async(logout)(info.context.request)
        return True

    @strawberry.mutation
    async def update_audio_file(
        self, info, uuid: uuid.UUID, update_audio_file: UpdateAudioFile
    ) -> AudioFile:
        """Update metadata of an :class:`~stream.models.AudioFile` via a UUID.

        Only non-empty values of ``name`` and ``description`` are applied.
        """
        await graphql_check_authenticated(info)
        audio_file = await stream_models.AudioFile.objects.aget(uuid=uuid)

        if update_audio_file.name:
            audio_file.name = update_audio_file.name
        if update_audio_file.description:
            audio_file.description = update_audio_file.description

        await audio_file.asave()
        return audio_file  # type: ignore

    @strawberry.mutation
    async def add_node(self, info: Info, new_node: NodeCreate) -> None:
        """Creates a new :class:`~story_graph.models.Node` in a given
        :class:`~story_graph.models.Graph`.
        Although it creates a new node with UUID we don't hand it back yet.
        """
        await graphql_check_authenticated(info)

        graph = await story_graph_models.Graph.objects.aget(uuid=new_node.graph_uuid)
        node = story_graph_models.Node(
            name=new_node.name,
            graph=graph,
        )

        # transfer from new_node to node model if attribute is given.
        # Positions are checked explicitly so that a coordinate of 0 is kept.
        for field in ("position_x", "position_y"):
            if _is_provided(new_value := getattr(new_node, field)):
                setattr(node, field, new_value)
        if new_node.color:
            node.color = new_node.color

        await node.asave()
        return None

    @strawberry.mutation
    async def update_node(self, info: Info, node_update: NodeUpdate) -> None:
        """Updates a given :class:`~story_graph.models.Node` which can be used
        for renaming or moving it across the canvas.
        """
        await graphql_check_authenticated(info)

        node = await story_graph_models.Node.objects.select_related("graph").aget(
            uuid=node_update.uuid
        )

        # a truthiness check would ignore moving a node to coordinate 0
        for field in ("position_x", "position_y"):
            if _is_provided(new_value := getattr(node_update, field)):
                setattr(node, field, new_value)
        # empty strings are still ignored for the textual attributes
        for field in ("color", "name"):
            if new_value := getattr(node_update, field):
                setattr(node, field, new_value)

        await node.asave()
        return None

    @strawberry.mutation
    async def add_edge(self, info: Info, new_edge: EdgeInput) -> Edge:
        """Creates a :class:`~story_graph.models.Edge` for a given
        :class:`~story_graph.models.Graph`.
        It returns the created edge.
        """
        await graphql_check_authenticated(info)

        in_node_door = await story_graph_models.NodeDoor.objects.select_related(
            "node__graph"
        ).aget(uuid=new_edge.node_door_in_uuid)
        out_node_door = await story_graph_models.NodeDoor.objects.aget(
            uuid=new_edge.node_door_out_uuid
        )
        edge = await story_graph_models.Edge.objects.acreate(
            in_node_door=in_node_door,
            out_node_door=out_node_door,
        )
        return edge  # type: ignore

    @strawberry.mutation
    async def delete_edge(self, info, edge_uuid: uuid.UUID) -> None:
        """Deletes a given :class:`~story_graph.models.Edge`."""
        await graphql_check_authenticated(info)
        await story_graph_models.Edge.objects.filter(uuid=edge_uuid).adelete()

    @strawberry.mutation
    async def delete_node(self, info, node_uuid: uuid.UUID) -> None:
        """Deletes a given :class:`~story_graph.models.Node`.

        :raises Exception: if the node does not exist or if it is an
            entry node.
        """
        await graphql_check_authenticated(info)

        # ``aget`` never returns ``None`` but raises ``DoesNotExist``, so the
        # missing node needs to be handled via the exception.
        try:
            node = await story_graph_models.Node.objects.aget(uuid=node_uuid)
        except story_graph_models.Node.DoesNotExist as e:
            raise Exception(f"Could not find node {node_uuid}") from e

        if node.is_entry_node:
            raise Exception(
                f"Node {node_uuid} is an entry node which can not be deleted"
            )
        await node.adelete()
        return None

    @strawberry.mutation
    async def create_script_cells(
        self,
        info,
        script_cell_inputs: List[ScriptCellInputCreate],
        node_uuid: uuid.UUID,
    ) -> List[ScriptCell]:
        """Creates new :class:`~story_graph.models.ScriptCell` objects on the
        given node and returns them.

        If an input carries no cell order the cell is added to the end of the node.
        """
        await graphql_check_authenticated(info)

        try:
            node: story_graph_models.Node = await story_graph_models.Node.objects.aget(
                uuid=node_uuid
            )
        except story_graph_models.Node.DoesNotExist:
            log.error(f"Received update on unknown node {node_uuid}")
            raise

        script_cells: List[story_graph_models.ScriptCell] = []
        for script_cell_input in script_cell_inputs:
            audio_cell = await update_or_create_audio_cell(script_cell_input.audio_cell)

            # if no cell order is given we add it to the end of the current node.
            # An explicit order of 0 counts as given.
            cell_order = script_cell_input.cell_order
            if not _is_provided(cell_order):
                last_cell = (
                    await story_graph_models.ScriptCell.objects.filter(node=node)
                    .order_by("-cell_order")
                    .afirst()
                )
                cell_order = last_cell.cell_order + 1 if last_cell else 0

            script_cell = await story_graph_models.ScriptCell.objects.acreate(
                cell_order=cell_order,
                cell_type=script_cell_input.cell_type,
                cell_code=script_cell_input.cell_code,
                node=node,
                audio_cell=audio_cell,
            )
            log.debug(f"Created script cell {script_cell.uuid}")
            script_cells.append(script_cell)

        return script_cells  # type: ignore

    @strawberry.mutation
    async def update_script_cells(
        self, info, script_cell_inputs: List[ScriptCellInputUpdate]
    ) -> List[ScriptCell]:
        """Updates existing :class:`~story_graph.models.ScriptCell` objects
        and returns the updated cells.

        Cells whose input contains nothing to update are skipped and are
        not part of the returned list.
        """
        # consistent with creating and deleting script cells
        await graphql_check_authenticated(info)

        script_cells: List[story_graph_models.ScriptCell] = []
        for script_cell_input in script_cell_inputs:
            # the async orm is still strange sometimes, therefore the code is not written in a clean
            # and concise manner
            script_cell: story_graph_models.ScriptCell = (
                await story_graph_models.ScriptCell.objects.aget(
                    uuid=script_cell_input.uuid
                )
            )

            audio_cell = await update_or_create_audio_cell(script_cell_input.audio_cell)

            # **{k: v for (k, v) in updates.items() if v is not None}
            # did not work
            updates: Dict[str, Any] = {}
            if (order := script_cell_input.cell_order) is not UNSET:
                updates["cell_order"] = order
            if audio_cell:
                updates["audio_cell"] = audio_cell
            if cell_code := script_cell_input.cell_code:
                updates["cell_code"] = cell_code
            if cell_type := script_cell_input.cell_type:
                updates["cell_type"] = cell_type

            if not updates:
                # maybe
                continue

            await story_graph_models.ScriptCell.objects.filter(
                uuid=script_cell_input.uuid
            ).aupdate(**updates)

            # the instance was fetched before the update, mirror the new
            # values so that the response does not contain stale data
            for attribute, value in updates.items():
                setattr(script_cell, attribute, value)
            script_cells.append(script_cell)

        return script_cells  # type: ignore

    @strawberry.mutation
    async def delete_script_cell(self, info, script_cell_uuid: uuid.UUID) -> None:
        """Deletes a given :class:`~story_graph.models.ScriptCell`."""
        await graphql_check_authenticated(info)

        await story_graph_models.ScriptCell.objects.filter(
            uuid=script_cell_uuid
        ).adelete()

    @strawberry.mutation
    async def add_graph(self, info, graph_input: AddGraphInput) -> Graph:
        """Creates a new :class:`~story_graph.models.Graph` including its entry node."""
        await graphql_check_authenticated(info)

        graph = await story_graph_models.Graph.objects.acreate(
            name=graph_input.name,
            display_name=graph_input.display_name,
            slug_name=graph_input.slug_name,
            start_text=graph_input.start_text,
            about_text=graph_input.about_text,
            end_text=graph_input.end_text,
            public_visible=graph_input.public_visible,
            stream_assignment_policy=graph_input.stream_assignment_policy,
        )
        await graph.acreate_entry_node()
        # need a refresh - in django 4.2 this will be available, see
        # https://docs.djangoproject.com/en/4.2/ref/models/instances/#django.db.models.Model.arefresh_from_db
        return await story_graph_models.Graph.objects.aget(uuid=graph.uuid)  # type: ignore

    @strawberry.mutation
    async def update_graph(
        self, info, graph_input: UpdateGraphInput, graph_uuid: uuid.UUID
    ) -> Graph:
        """Updates a :class:`~story_graph.models.Graph` with every attribute
        of the input which is not ``UNSET``.
        """
        await graphql_check_authenticated(info)

        graph = await story_graph_models.Graph.objects.aget(uuid=graph_uuid)

        for key, value in vars(graph_input).items():
            if value is UNSET:
                continue
            setattr(graph, key, value)

        await graph.asave()
        return graph  # type: ignore

    @strawberry.mutation
    async def add_audio_file(self, info, new_audio_file: AddAudioFile) -> AudioFileUploadResponse:  # type: ignore
        """Stores an uploaded flac or wav file as :class:`~stream.models.AudioFile`.

        Returns the created audio file or an ``InvalidAudioFile`` with a
        description of the problem.

        NOTE(review): unlike most other mutations this one performs no
        authentication check - confirm that this is intended.
        """
        if new_audio_file.file is None or len(new_audio_file.file) == 0:
            return InvalidAudioFile(error="Received empty audio file")

        extension = os.path.splitext(new_audio_file.file_name)[-1].lower()
        if extension not in (".flac", ".wav"):
            return InvalidAudioFile(error="Only support flac and wav files")

        try:
            audio_file = await stream_models.AudioFile.objects.acreate(
                name=new_audio_file.name,
                file=File(new_audio_file.file, name=new_audio_file.file_name),
                description=new_audio_file.description,
                auto_generated=False,
            )
        except Exception as e:
            # the error is reported to the client on purpose, but keep the
            # traceback in the backend log as well
            log.exception("Could not save uploaded audio file")
            return InvalidAudioFile(
                error=f"Unexpected error, could not save audio file: {e}"
            )
        return audio_file

    @strawberry.mutation
    async def create_update_stream_variable(
        self, info, stream_variables: List[StreamVariableInput]
    ) -> List[StreamVariable]:
        """Creates or updates :class:`~stream.models.StreamVariable` objects
        which are identified by their stream and their key.

        Variables flagged with ``stream_to_sc`` are additionally sent via
        ``send_to_sc``.

        NOTE(review): unlike most other mutations this one performs no
        authentication check - confirm that this is intended.
        """
        stream_vars: List[stream_models.StreamVariable] = []

        for stream_variable in stream_variables:
            stream = await stream_models.Stream.objects.aget(
                uuid=stream_variable.stream_uuid
            )
            (
                stream_var,
                _,
            ) = await stream_models.StreamVariable.objects.aupdate_or_create(
                stream=stream,
                key=stream_variable.key,
                defaults={
                    "value": stream_variable.value,
                    "stream_to_sc": stream_variable.stream_to_sc,
                },
            )

            if stream_variable.stream_to_sc:
                await sync_to_async(stream_var.send_to_sc)()

            stream_vars.append(stream_var)
        return stream_vars  # type: ignore

    @strawberry.mutation
    async def create_node_door(
        self,
        info,
        node_door_input: NodeDoorInputCreate,
        node_uuid: uuid.UUID,
    ) -> NodeDoor:
        """Creates a :class:`~story_graph.models.NodeDoor` on the given node."""
        await graphql_check_authenticated(info)
        node = await story_graph_models.Node.objects.aget(uuid=node_uuid)
        return await story_graph_models.NodeDoor.objects.acreate(
            door_type=node_door_input.door_type,
            node=node,
            name=node_door_input.name,
            order=node_door_input.order,
            code=node_door_input.code,
        )  # type: ignore

    @strawberry.mutation
    async def update_node_door(
        self,
        info,
        node_door_input: NodeDoorInputUpdate,
    ) -> NodeDoorResponse:  # type: ignore
        """Updates a :class:`~story_graph.models.NodeDoor`.

        Returns the node door or an ``InvalidPythonCode`` if saving the
        door raised a ``SyntaxError``.
        """
        await graphql_check_authenticated(info)
        node_door = await story_graph_models.NodeDoor.objects.aget(
            uuid=node_door_input.uuid
        )

        # the door type is always taken over from the input
        node_door.door_type = node_door_input.door_type
        if node_door_input.code:
            node_door.code = node_door_input.code
        if node_door_input.name:
            node_door.name = node_door_input.name
        # explicit check as an order of 0 would be ignored by a truthiness check
        if _is_provided(node_door_input.order):
            node_door.order = node_door_input.order

        try:
            await node_door.asave()
        except SyntaxError as e:
            return InvalidPythonCode(
                error_type=e.msg,
                error_code=e.text if e.text else "",
                error_message=create_python_highlight_string(e),
            )
        return node_door  # type: ignore

    @strawberry.mutation
    async def delete_node_door(self, info, node_door_uuid: uuid.UUID) -> bool:
        """Allows to delete a non-default NodeDoor.
        If a node door was deleted it will return ``True``, otherwise ``False``.
        """
        await graphql_check_authenticated(info)
        deleted_objects, _ = await story_graph_models.NodeDoor.objects.filter(
            is_default=False,
            uuid=node_door_uuid,
        ).adelete()
        return deleted_objects >= 1
@strawberry.type
class Subscription:
    """Subscriptions for Gencaster via GraphQL."""

    @strawberry.subscription
    async def graph(
        self,
        info: Info,
        graph_uuid: uuid.UUID,
    ) -> AsyncGenerator[Graph, None]:
        """Used within the editor to synchronize any updates of the graph such
        as movement of a :class:`~story_graph.models.Node`.

        Yields the current graph first and a freshly fetched graph on every
        received update.
        """
        graph = await story_graph_models.Graph.objects.aget(uuid=graph_uuid)
        yield graph  # type: ignore

        async for graph_update in GenCasterChannel.receive_graph_updates(
            info.context["ws"], graph_uuid
        ):
            yield await story_graph_models.Graph.objects.aget(uuid=graph_update.uuid)  # type: ignore

    @strawberry.subscription
    async def node(
        self,
        info: Info,
        node_uuid: uuid.UUID,
    ) -> AsyncGenerator[Node, None]:
        """Used within the editor to synchronize any updates on a node such
        as updates on a :class:`~story_graph.models.ScriptCell`.

        Yields the current node first and a freshly fetched node on every
        received update.
        """
        node = await story_graph_models.Node.objects.aget(uuid=node_uuid)
        yield node  # type: ignore

        async for node_update in GenCasterChannel.receive_node_updates(
            info.context["ws"], node_uuid
        ):
            yield await story_graph_models.Node.objects.aget(uuid=node_update.uuid)  # type: ignore

    @strawberry.subscription
    async def stream_info(
        self,
        info: Info,
        graph_uuid: uuid.UUID,
    ) -> AsyncGenerator[StreamInfoResponse, None]:  # type: ignore
        """Used within the frontend to attach a user to a stream.
        :class:`~story_graph.engine.Engine` contains the specifics of how
        the iteration over a graph is handled.

        Upon visit the ``num_of_listeners`` of the associated
        :class:`~stream.models.Stream` will be incremented which indicates if
        a given stream is free or used.
        Upon connection stop this will be decremented again.
        """
        consumer: GraphQLWSConsumerInjector = info.context["ws"]

        graph = await story_graph_models.Graph.objects.aget(uuid=graph_uuid)

        try:
            stream = await stream_models.Stream.objects.aget_free_stream(graph)
            log.info(f"Attached to stream {stream.uuid}")
        except NoStreamAvailableException:
            log.error(f"No stream is available for graph {graph.name}")
            yield NoStreamAvailable()
            return

        listener_released = False

        async def cleanup():
            """Decrement the listener count of the stream, at most once.

            A stop signal followed by a disconnect would otherwise decrement
            the counter twice although it was only incremented once.
            """
            nonlocal listener_released
            if listener_released:
                return
            listener_released = True
            await stream.decrement_num_listeners()

        async def cleanup_on_stop(**kwargs: Dict[str, str]):
            """
            A helper function which scans for a "stop" signal send via the
            websocket connection of our graphql subscription as this is the
            indication from urql that we paused the subscription.
            """
            if text_data := kwargs.get("text_data"):
                d = json.loads(text_data)  # type: ignore
                if d.get("type") == "stop":
                    log.info("Stop a stream due to a stop signal")
                    await cleanup()

        with db_logging.LogContext(db_logging.LogKeyEnum.STREAM, stream):
            engine = Engine(
                graph=graph,
                stream=stream,
            )

            await stream.increment_num_listeners()

            consumer.disconnect_callback = cleanup
            consumer.receive_callback = cleanup_on_stop

            # send a first stream info response so the front-end has
            # received information that streaming has/can be started,
            # see https://github.com/Gencaster/gencaster/issues/483
            # otherwise this can result in a dead end if we await
            # a stream variable which is set from the frontend
            yield StreamInfo(stream=stream, stream_instruction=None)  # type: ignore

            async for instruction in engine.start(max_steps=int(10e10)):
                if isinstance(instruction, Dialog):
                    # dialogs are handed to the frontend as they are
                    yield instruction
                else:
                    yield StreamInfo(
                        stream=stream,  # type: ignore
                        stream_instruction=instruction,  # type: ignore
                    )

            # the engine stopped yielding instructions
            yield GraphDeadEnd()

    @strawberry.subscription
    async def stream_logs(self, info: Info, stream_uuid: Optional[uuid.UUID] = None, stream_point_uuid: Optional[uuid.UUID] = None) -> AsyncGenerator[StreamLog, None]:  # type: ignore
        """Yields all existing :class:`~stream.models.StreamLog` objects
        ordered by creation date, followed by any new log as it is received.

        The logs can be restricted to a stream and/or a stream point.
        """
        log_query = stream_models.StreamLog.objects.order_by("created_date")
        if stream_uuid:
            log_query = log_query.filter(stream__uuid=stream_uuid)
        if stream_point_uuid:
            log_query = log_query.filter(stream_point__uuid=stream_point_uuid)

        async for stream_log in log_query.all():
            yield stream_log  # type: ignore

        async for log_update in GenCasterChannel.receive_stream_log_updates(
            info.context["ws"],
        ):
            # apply the same restrictions to the live updates
            if stream_uuid and str(log_update.stream_uuid) != str(stream_uuid):
                continue
            if stream_point_uuid and str(log_update.stream_point_uuid) != str(
                stream_point_uuid
            ):
                continue
            yield await stream_models.StreamLog.objects.aget(uuid=log_update.uuid)  # type: ignore

    @strawberry.subscription
    async def streams(self, info: Info, limit: int = 20) -> AsyncGenerator[List[Stream], None]:  # type: ignore
        """Yields the ``limit`` most recently created streams, once at the
        start and again on every received streams update.
        """

        async def get_streams() -> List[Stream]:
            # as slicing operation is not implemented in async mode we need this
            # helper function
            newest_streams = stream_models.Stream.objects.order_by("-created_date")[
                0:limit
            ]
            return [stream async for stream in newest_streams]  # type: ignore

        yield await get_streams()

        async for _ in GenCasterChannel.receive_streams_updates(info.context["ws"]):
            yield await get_streams()
# The GraphQL schema assembled from the query, mutation and
# subscription types defined in this module.
schema = strawberry.Schema(
    query=Query,
    mutation=Mutation,
    subscription=Subscription,
)