Skip to content

Commit

Permalink
Add sequence number fields (#6011)
Browse files Browse the repository at this point in the history
  • Loading branch information
anticorrelator authored Jan 11, 2025
1 parent d82d609 commit 8869cbd
Show file tree
Hide file tree
Showing 7 changed files with 73 additions and 6 deletions.
5 changes: 5 additions & 0 deletions app/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -1503,6 +1503,11 @@ type PromptVersion implements Node {
tags: [PromptVersionTag!]!
user: User
previousVersion: PromptVersion

"""
Sequence number (1-based) of prompt versions belonging to the same prompt
"""
sequenceNumber: Int!
}

"""A connection to a list of items."""
Expand Down
2 changes: 2 additions & 0 deletions src/phoenix/server/api/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
LatencyMsQuantileDataLoader,
MinStartOrMaxEndTimeDataLoader,
ProjectByNameDataLoader,
PromptVersionSequenceNumberDataLoader,
RecordCountDataLoader,
SessionIODataLoader,
SessionNumTracesDataLoader,
Expand Down Expand Up @@ -73,6 +74,7 @@ class DataLoaders:
experiment_sequence_number: ExperimentSequenceNumberDataLoader
latency_ms_quantile: LatencyMsQuantileDataLoader
min_start_or_max_end_times: MinStartOrMaxEndTimeDataLoader
prompt_version_sequence_number: PromptVersionSequenceNumberDataLoader
record_counts: RecordCountDataLoader
session_first_inputs: SessionIODataLoader
session_last_outputs: SessionIODataLoader
Expand Down
2 changes: 2 additions & 0 deletions src/phoenix/server/api/dataloaders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from .latency_ms_quantile import LatencyMsQuantileCache, LatencyMsQuantileDataLoader
from .min_start_or_max_end_times import MinStartOrMaxEndTimeCache, MinStartOrMaxEndTimeDataLoader
from .project_by_name import ProjectByNameDataLoader
from .prompt_version_sequence_number import PromptVersionSequenceNumberDataLoader
from .record_counts import RecordCountCache, RecordCountDataLoader
from .session_io import SessionIODataLoader
from .session_num_traces import SessionNumTracesDataLoader
Expand Down Expand Up @@ -50,6 +51,7 @@
"ExperimentSequenceNumberDataLoader",
"LatencyMsQuantileDataLoader",
"MinStartOrMaxEndTimeDataLoader",
"PromptVersionSequenceNumberDataLoader",
"RecordCountDataLoader",
"SessionIODataLoader",
"SessionNumTracesDataLoader",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from typing import Optional

from sqlalchemy import func, select
from strawberry.dataloader import DataLoader
from typing_extensions import TypeAlias

from phoenix.db import models
from phoenix.server.types import DbSessionFactory

PromptVersionId: TypeAlias = int
Key: TypeAlias = PromptVersionId
Result: TypeAlias = Optional[int]


class PromptVersionSequenceNumberDataLoader(DataLoader[Key, Result]):
def __init__(self, db: DbSessionFactory) -> None:
super().__init__(load_fn=self._load_fn)
self._db = db

async def _load_fn(self, keys: list[Key]) -> list[Result]:
prompt_version_ids = keys
row_number = (
func.row_number().over(
partition_by=models.PromptVersion.prompt_id,
order_by=models.PromptVersion.id,
)
).label("sequence_number")
subq = select(models.PromptVersion.id.label("prompt_version_id"), row_number).subquery()
stmt = select(subq).where(subq.c.prompt_version_id.in_(prompt_version_ids))
async with self._db() as session:
result = {
prompt_version_id: seq_number
async for prompt_version_id, seq_number in await session.stream(stmt)
}
return [result.get(prompt_version_id) for prompt_version_id in keys]
10 changes: 5 additions & 5 deletions src/phoenix/server/api/types/Prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Optional

import strawberry
from sqlalchemy import select
from sqlalchemy import func, select
from strawberry import UNSET
from strawberry.relay import Connection, Node, NodeID
from strawberry.types import Info
Expand Down Expand Up @@ -55,16 +55,16 @@ async def prompt_versions(
last=last,
before=before if isinstance(before, CursorString) else None,
)
row_number = func.row_number().over(order_by=models.PromptVersion.id).label("row_number")
stmt = (
select(models.PromptVersion)
select(models.PromptVersion, row_number)
.where(models.PromptVersion.prompt_id == self.id_attr)
.order_by(models.PromptVersion.id.desc())
)
async with info.context.db() as session:
orm_prompt_versions = await session.stream_scalars(stmt)
data = [
to_gql_prompt_version(prompt_version)
async for prompt_version in orm_prompt_versions
to_gql_prompt_version(prompt_version, sequence_number)
async for prompt_version, sequence_number in await session.stream(stmt)
]
return connection_from_list(data=data, args=args)

Expand Down
23 changes: 22 additions & 1 deletion src/phoenix/server/api/types/PromptVersion.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import strawberry
from sqlalchemy import select
from strawberry import Private
from strawberry.relay import Node, NodeID
from strawberry.scalars import JSON
from strawberry.types import Info
Expand Down Expand Up @@ -40,6 +41,7 @@ class PromptVersion(Node):
model_name: str
model_provider: str
created_at: datetime
cached_sequence_number: Private[Optional[int]] = None

@strawberry.field
async def tags(self, info: Info[Context, None]) -> list[PromptVersionTag]:
Expand Down Expand Up @@ -81,8 +83,26 @@ async def previous_version(self, info: Info[Context, None]) -> Optional["PromptV
return to_gql_prompt_version(prompt_version=previous_version)
return None

@strawberry.field(
description="Sequence number (1-based) of prompt versions belonging to the same prompt"
) # type: ignore
async def sequence_number(
self,
info: Info[Context, None],
) -> int:
if self.cached_sequence_number is None:
seq_num = await info.context.data_loaders.prompt_version_sequence_number.load(
self.id_attr
)
if seq_num is None:
raise ValueError(f"invalid prompt version: id={self.id_attr}")
self.cached_sequence_number = seq_num
return self.cached_sequence_number


def to_gql_prompt_version(prompt_version: models.PromptVersion) -> PromptVersion:
def to_gql_prompt_version(
prompt_version: models.PromptVersion, sequence_number: Optional[int] = None
) -> PromptVersion:
prompt_template_type = PromptTemplateType(prompt_version.template_type)
prompt_template = to_gql_template_from_orm(prompt_version)
prompt_template_format = PromptTemplateFormat(prompt_version.template_format)
Expand Down Expand Up @@ -113,4 +133,5 @@ def to_gql_prompt_version(prompt_version: models.PromptVersion) -> PromptVersion
model_name=prompt_version.model_name,
model_provider=prompt_version.model_provider,
created_at=prompt_version.created_at,
cached_sequence_number=sequence_number,
)
2 changes: 2 additions & 0 deletions src/phoenix/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
LatencyMsQuantileDataLoader,
MinStartOrMaxEndTimeDataLoader,
ProjectByNameDataLoader,
PromptVersionSequenceNumberDataLoader,
RecordCountDataLoader,
SessionIODataLoader,
SessionNumTracesDataLoader,
Expand Down Expand Up @@ -611,6 +612,7 @@ def get_context() -> Context:
else None
),
),
prompt_version_sequence_number=PromptVersionSequenceNumberDataLoader(db),
record_counts=RecordCountDataLoader(
db,
cache_map=cache_for_dataloaders.record_count if cache_for_dataloaders else None,
Expand Down

0 comments on commit 8869cbd

Please sign in to comment.