From 676fe1d2954808cacd016e9ab0ed4f0561e662a1 Mon Sep 17 00:00:00 2001 From: Dustin Ngo Date: Tue, 14 Jan 2025 10:07:38 +0900 Subject: [PATCH] feat: Add Prompt to setPromptVersionTag mutation (#6010) * Add prompt to version tag payload * Build gql schema * Separate delete and set payloads * Add Prompt to delete payload as well --- app/schema.graphql | 1 + .../mutations/prompt_version_tag_mutations.py | 41 +++++++++++++++---- 2 files changed, 35 insertions(+), 7 deletions(-) diff --git a/app/schema.graphql b/app/schema.graphql index 5e7f09782c..0cc5490033 100644 --- a/app/schema.graphql +++ b/app/schema.graphql @@ -1545,6 +1545,7 @@ type PromptVersionTag implements Node { type PromptVersionTagMutationPayload { promptVersionTag: PromptVersionTag + prompt: Prompt! query: Query! } diff --git a/src/phoenix/server/api/mutations/prompt_version_tag_mutations.py b/src/phoenix/server/api/mutations/prompt_version_tag_mutations.py index 05e266b505..e8540c8d47 100644 --- a/src/phoenix/server/api/mutations/prompt_version_tag_mutations.py +++ b/src/phoenix/server/api/mutations/prompt_version_tag_mutations.py @@ -1,7 +1,7 @@ from typing import Optional import strawberry -from sqlalchemy import delete, select +from sqlalchemy import select from sqlalchemy.exc import IntegrityError as PostgreSQLIntegrityError from sqlean.dbapi2 import IntegrityError as SQLiteIntegrityError # type: ignore[import-untyped] from strawberry.relay import GlobalID @@ -12,6 +12,7 @@ from phoenix.server.api.exceptions import BadRequest, Conflict, NotFound from phoenix.server.api.queries import Query from phoenix.server.api.types.node import from_global_id_with_expected_type +from phoenix.server.api.types.Prompt import Prompt, to_gql_prompt_from_orm from phoenix.server.api.types.PromptVersion import PromptVersion from phoenix.server.api.types.PromptVersionTag import PromptVersionTag, to_gql_prompt_version_tag @@ -31,6 +32,7 @@ class SetPromptVersionTagInput: @strawberry.type class PromptVersionTagMutationPayload: prompt_version_tag: Optional[PromptVersionTag] + prompt: Prompt query: Query @@ -44,15 +46,33 @@ async def delete_prompt_version_tag( prompt_version_tag_id = from_global_id_with_expected_type( input.prompt_version_tag_id, PromptVersionTag.__name__ ) - stmt = delete(models.PromptVersionTag).where( - models.PromptVersionTag.id == prompt_version_tag_id + stmt = ( + select(models.PromptVersionTag, models.Prompt) + .join( + models.PromptVersion, + models.PromptVersion.id == models.PromptVersionTag.prompt_version_id, + ) + .join(models.Prompt, models.Prompt.id == models.PromptVersion.prompt_id) + .where(models.PromptVersionTag.id == prompt_version_tag_id) ) result = await session.execute(stmt) - if result.rowcount == 0: + if results := result.one_or_none(): + prompt_version_tag, prompt = results + + if not prompt_version_tag: raise NotFound(f"PromptVersionTag with ID {input.prompt_version_tag_id} not found") + if not prompt: + raise BadRequest( + f"PromptVersionTag with ID {input.prompt_version_tag_id} " + "does not belong to a prompt" + ) + + await session.delete(prompt_version_tag) await session.commit() - return PromptVersionTagMutationPayload(prompt_version_tag=None, query=Query()) + return PromptVersionTagMutationPayload( + prompt_version_tag=None, query=Query(), prompt=to_gql_prompt_from_orm(prompt) + ) @strawberry.mutation async def set_prompt_version_tag( @@ -66,9 +86,14 @@ async def set_prompt_version_tag( select(models.PromptVersion).where(models.PromptVersion.id == prompt_version_id) ) if not prompt_version: - raise BadRequest("PromptVersion with ID {input.prompt_version_id} not found.") + raise BadRequest(f"PromptVersion with ID {input.prompt_version_id} not found.") prompt_id = prompt_version.prompt_id + prompt = await session.scalar( + select(models.Prompt).where(models.Prompt.id == prompt_id) + ) + if not prompt: + raise BadRequest("All prompt version tags must belong to a prompt") existing_tag = await session.scalar( select(models.PromptVersionTag).where( @@ -101,4 +126,6 @@ async def set_prompt_version_tag( raise Conflict("Failed to update PromptVersionTag.") version_tag = to_gql_prompt_version_tag(updated_tag) - return PromptVersionTagMutationPayload(prompt_version_tag=version_tag, query=Query()) + return PromptVersionTagMutationPayload( + prompt_version_tag=version_tag, prompt=to_gql_prompt_from_orm(prompt), query=Query() + )