Skip to content

Commit

Permalink
feat: Add Prompt to setPromptVersionTag mutation (#6010)
Browse files Browse the repository at this point in the history
* Add prompt to version tag payload

* Build gql schema

* Separate delete and set payloads

* Add Prompt to delete payload as well
  • Loading branch information
anticorrelator authored Jan 14, 2025
1 parent 4609752 commit 676fe1d
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 7 deletions.
1 change: 1 addition & 0 deletions app/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -1545,6 +1545,7 @@ type PromptVersionTag implements Node {

type PromptVersionTagMutationPayload {
promptVersionTag: PromptVersionTag
prompt: Prompt!
query: Query!
}

Expand Down
41 changes: 34 additions & 7 deletions src/phoenix/server/api/mutations/prompt_version_tag_mutations.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Optional

import strawberry
from sqlalchemy import delete, select
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError as PostgreSQLIntegrityError
from sqlean.dbapi2 import IntegrityError as SQLiteIntegrityError # type: ignore[import-untyped]
from strawberry.relay import GlobalID
Expand All @@ -12,6 +12,7 @@
from phoenix.server.api.exceptions import BadRequest, Conflict, NotFound
from phoenix.server.api.queries import Query
from phoenix.server.api.types.node import from_global_id_with_expected_type
from phoenix.server.api.types.Prompt import Prompt, to_gql_prompt_from_orm
from phoenix.server.api.types.PromptVersion import PromptVersion
from phoenix.server.api.types.PromptVersionTag import PromptVersionTag, to_gql_prompt_version_tag

Expand All @@ -31,6 +32,7 @@ class SetPromptVersionTagInput:
@strawberry.type
class PromptVersionTagMutationPayload:
prompt_version_tag: Optional[PromptVersionTag]
prompt: Prompt
query: Query


Expand All @@ -44,15 +46,33 @@ async def delete_prompt_version_tag(
prompt_version_tag_id = from_global_id_with_expected_type(
input.prompt_version_tag_id, PromptVersionTag.__name__
)
stmt = delete(models.PromptVersionTag).where(
models.PromptVersionTag.id == prompt_version_tag_id
stmt = (
select(models.PromptVersionTag, models.Prompt)
.join(
models.PromptVersion,
models.PromptVersion.id == models.PromptVersionTag.prompt_version_id,
)
.join(models.Prompt, models.Prompt.id == models.PromptVersion.prompt_id)
.where(models.PromptVersionTag.id == prompt_version_tag_id)
)
result = await session.execute(stmt)
if result.rowcount == 0:
if results := result.one_or_none():
prompt_version_tag, prompt = results

if not prompt_version_tag:
raise NotFound(f"PromptVersionTag with ID {input.prompt_version_tag_id} not found")

if not prompt:
raise BadRequest(
f"PromptVersionTag with ID {input.prompt_version_tag_id} "
"does not belong to a prompt"
)

await session.delete(prompt_version_tag)
await session.commit()
return PromptVersionTagMutationPayload(prompt_version_tag=None, query=Query())
return PromptVersionTagMutationPayload(
prompt_version_tag=None, query=Query(), prompt=to_gql_prompt_from_orm(prompt)
)

@strawberry.mutation
async def set_prompt_version_tag(
Expand All @@ -66,9 +86,14 @@ async def set_prompt_version_tag(
select(models.PromptVersion).where(models.PromptVersion.id == prompt_version_id)
)
if not prompt_version:
raise BadRequest("PromptVersion with ID {input.prompt_version_id} not found.")
raise BadRequest(f"PromptVersion with ID {input.prompt_version_id} not found.")

prompt_id = prompt_version.prompt_id
prompt = await session.scalar(
select(models.Prompt).where(models.Prompt.id == prompt_id)
)
if not prompt:
raise BadRequest("All prompt version tags must belong to a prompt")

existing_tag = await session.scalar(
select(models.PromptVersionTag).where(
Expand Down Expand Up @@ -101,4 +126,6 @@ async def set_prompt_version_tag(
raise Conflict("Failed to update PromptVersionTag.")

version_tag = to_gql_prompt_version_tag(updated_tag)
return PromptVersionTagMutationPayload(prompt_version_tag=version_tag, query=Query())
return PromptVersionTagMutationPayload(
prompt_version_tag=version_tag, prompt=to_gql_prompt_from_orm(prompt), query=Query()
)

0 comments on commit 676fe1d

Please sign in to comment.