From b3246851efa6f783b0f4f033bf16bf2310bc1478 Mon Sep 17 00:00:00 2001 From: Roger Yang <80478925+RogerHYang@users.noreply.github.com> Date: Wed, 4 Dec 2024 11:22:37 -0800 Subject: [PATCH] fix(sessions): sortable last trace start time (#5606) --- app/schema.graphql | 3 +- app/src/pages/project/SessionsTable.tsx | 2 +- .../SessionsTableQuery.graphql.ts | 4 +- .../SessionsTable_sessions.graphql.ts | 4 +- src/phoenix/db/insertion/span.py | 3 ++ ...d9e43755f_create_project_sessions_table.py | 7 ++++ src/phoenix/db/models.py | 5 ++- src/phoenix/server/api/context.py | 2 - .../server/api/dataloaders/__init__.py | 2 - .../dataloaders/session_last_start_times.py | 42 ------------------- .../api/input_types/ProjectSessionSort.py | 6 ++- src/phoenix/server/api/types/Project.py | 2 + .../server/api/types/ProjectSession.py | 9 +--- src/phoenix/server/app.py | 2 - ...d9e43755f_create_project_sessions_table.py | 5 +++ .../test_up_and_down_migrations.py | 11 +++++ tests/unit/_helpers.py | 1 + tests/unit/server/api/dataloaders/conftest.py | 1 + 18 files changed, 48 insertions(+), 63 deletions(-) delete mode 100644 src/phoenix/server/api/dataloaders/session_last_start_times.py diff --git a/app/schema.graphql b/app/schema.graphql index 11ddd75ef0..92dd33ecec 100644 --- a/app/schema.graphql +++ b/app/schema.graphql @@ -1327,12 +1327,12 @@ type ProjectSession implements Node { sessionId: String! sessionUser: String startTime: DateTime! + lastTraceStartTime: DateTime! projectId: GlobalID! numTraces: Int! numTracesWithError: Int! firstInput: SpanIOValue lastOutput: SpanIOValue - lastTraceStartTime: DateTime tokenUsage: TokenUsage! traces(first: Int = 50, last: Int, after: String, before: String): TraceConnection! traceLatencyMsQuantile(probability: Float!): Float @@ -1340,6 +1340,7 @@ type ProjectSession implements Node { enum ProjectSessionColumn { startTime + lastTraceStartTime tokenCountTotal numTraces } diff --git a/app/src/pages/project/SessionsTable.tsx b/app/src/pages/project/SessionsTable.tsx index 984b6a13af..66d6765aa5 100644 --- a/app/src/pages/project/SessionsTable.tsx +++ b/app/src/pages/project/SessionsTable.tsx @@ -138,7 +138,7 @@ export function SessionsTable(props: SessionsTableProps) { { header: "last trace start time", accessorKey: "lastTraceStartTime", - enableSorting: false, + enableSorting: true, cell: TimestampCell, }, { diff --git a/app/src/pages/project/__generated__/SessionsTableQuery.graphql.ts b/app/src/pages/project/__generated__/SessionsTableQuery.graphql.ts index 61abb6d381..d614196e22 100644 --- a/app/src/pages/project/__generated__/SessionsTableQuery.graphql.ts +++ b/app/src/pages/project/__generated__/SessionsTableQuery.graphql.ts @@ -1,5 +1,5 @@ /** - * @generated SignedSource<<15f426fc91f1e8e4e251f32ed446f684>> + * @generated SignedSource<<893092928cd943f6560b50f2039f1446>> * @lightSyntaxTransform * @nogrep */ @@ -10,7 +10,7 @@ import { ConcreteRequest, Query } from 'relay-runtime'; import { FragmentRefs } from "relay-runtime"; -export type ProjectSessionColumn = "numTraces" | "startTime" | "tokenCountTotal"; +export type ProjectSessionColumn = "lastTraceStartTime" | "numTraces" | "startTime" | "tokenCountTotal"; export type SortDir = "asc" | "desc"; export type ProjectSessionSort = { col: ProjectSessionColumn; diff --git a/app/src/pages/project/__generated__/SessionsTable_sessions.graphql.ts b/app/src/pages/project/__generated__/SessionsTable_sessions.graphql.ts index d6a408194a..519b7bd387 100644 --- a/app/src/pages/project/__generated__/SessionsTable_sessions.graphql.ts +++ b/app/src/pages/project/__generated__/SessionsTable_sessions.graphql.ts @@ -1,5 +1,5 @@ /** - * @generated SignedSource<<67ff1a198341dfb472d4d063fdc8a590>> + * @generated SignedSource<<842f5fcd489299dcfa79e97cf4f743c2>> * @lightSyntaxTransform * @nogrep */ @@ -23,7 +23,7 @@ export type SessionsTable_sessions$data = { readonly lastOutput: { readonly value: string; } | null; - readonly lastTraceStartTime: string | null; + readonly lastTraceStartTime: string; readonly numTraces: number; readonly sessionId: string; readonly startTime: string; diff --git a/src/phoenix/db/insertion/span.py b/src/phoenix/db/insertion/span.py index 02c6e91d2b..1853c0626d 100644 --- a/src/phoenix/db/insertion/span.py +++ b/src/phoenix/db/insertion/span.py @@ -50,6 +50,8 @@ async def insert_span( select(models.ProjectSession).filter_by(session_id=session_id) ) if project_session: + if project_session.last_trace_start_time < span.start_time: + project_session.last_trace_start_time = span.start_time if span.start_time < project_session.start_time: project_session.start_time = span.start_time if project_session.project_id != project_rowid: @@ -62,6 +64,7 @@ async def insert_span( session_id=session_id, session_user=session_user if session_user else None, start_time=span.start_time, + last_trace_start_time=span.start_time, ) session.add(project_session) if project_session in session.dirty: diff --git a/src/phoenix/db/migrations/versions/4ded9e43755f_create_project_sessions_table.py b/src/phoenix/db/migrations/versions/4ded9e43755f_create_project_sessions_table.py index d9dabaad58..8a88355e96 100644 --- a/src/phoenix/db/migrations/versions/4ded9e43755f_create_project_sessions_table.py +++ b/src/phoenix/db/migrations/versions/4ded9e43755f_create_project_sessions_table.py @@ -58,6 +58,7 @@ class ProjectSession(Base): session_user: Mapped[Optional[str]] project_id: Mapped[int] start_time: Mapped[datetime] + last_trace_start_time: Mapped[datetime] class Trace(Base): @@ -96,6 +97,7 @@ def upgrade() -> None: nullable=False, ), sa.Column("start_time", sa.TIMESTAMP(timezone=True), index=True, nullable=False), + sa.Column("last_trace_start_time", sa.TIMESTAMP(timezone=True), index=True, nullable=False), ) with op.batch_alter_table("traces") as batch_op: batch_op.add_column( @@ -123,6 +125,9 @@ def upgrade() -> None: order_by=[Trace.start_time, Trace.id, Span.id], ) .label("rank"), + func.max(Trace.start_time) + .over(partition_by=Span.attributes[SESSION_ID]) + .label("last_trace_start_time"), ) .join_from(Span, Trace, Span.trace_rowid == Trace.id) .where(Span.parent_id.is_(None)) @@ -136,12 +141,14 @@ def upgrade() -> None: "session_user", "project_id", "start_time", + "last_trace_start_time", ], select( sessions_from_span.c.session_id, sessions_from_span.c.session_user, sessions_from_span.c.project_id, sessions_from_span.c.start_time, + sessions_from_span.c.last_trace_start_time, ).where(sessions_from_span.c.rank == 1), ) ) diff --git a/src/phoenix/db/models.py b/src/phoenix/db/models.py index 6b0a9d467e..dc4cac938c 100644 --- a/src/phoenix/db/models.py +++ b/src/phoenix/db/models.py @@ -166,7 +166,10 @@ class ProjectSession(Base): nullable=False, index=True, ) - start_time: Mapped[datetime] = mapped_column(UtcTimeStamp, index=True) + start_time: Mapped[datetime] = mapped_column(UtcTimeStamp, index=True, nullable=False) + last_trace_start_time: Mapped[datetime] = mapped_column( + UtcTimeStamp, index=True, nullable=False + ) traces: Mapped[list["Trace"]] = relationship( "Trace", back_populates="project_session", diff --git a/src/phoenix/server/api/context.py b/src/phoenix/server/api/context.py index 4e7cdb4659..75e7544b13 100644 --- a/src/phoenix/server/api/context.py +++ b/src/phoenix/server/api/context.py @@ -32,7 +32,6 @@ ProjectByNameDataLoader, RecordCountDataLoader, SessionIODataLoader, - SessionLastTraceStartTimeDataLoader, SessionNumTracesDataLoader, SessionNumTracesWithErrorDataLoader, SessionTokenUsagesDataLoader, @@ -77,7 +76,6 @@ class DataLoaders: record_counts: RecordCountDataLoader session_first_inputs: SessionIODataLoader session_last_outputs: SessionIODataLoader - session_last_trace_start_times: SessionLastTraceStartTimeDataLoader session_num_traces: SessionNumTracesDataLoader session_num_traces_with_error: SessionNumTracesWithErrorDataLoader session_token_usages: SessionTokenUsagesDataLoader diff --git a/src/phoenix/server/api/dataloaders/__init__.py b/src/phoenix/server/api/dataloaders/__init__.py index cc26df0f37..9cea67eaca 100644 --- a/src/phoenix/server/api/dataloaders/__init__.py +++ b/src/phoenix/server/api/dataloaders/__init__.py @@ -20,7 +20,6 @@ from .project_by_name import ProjectByNameDataLoader from .record_counts import RecordCountCache, RecordCountDataLoader from .session_io import SessionIODataLoader -from .session_last_start_times import SessionLastTraceStartTimeDataLoader from .session_num_traces import SessionNumTracesDataLoader from .session_num_traces_with_error import SessionNumTracesWithErrorDataLoader from .session_token_usages import SessionTokenUsagesDataLoader @@ -53,7 +52,6 @@ "MinStartOrMaxEndTimeDataLoader", "RecordCountDataLoader", "SessionIODataLoader", - "SessionLastTraceStartTimeDataLoader", "SessionNumTracesDataLoader", "SessionNumTracesWithErrorDataLoader", "SessionTokenUsagesDataLoader", diff --git a/src/phoenix/server/api/dataloaders/session_last_start_times.py b/src/phoenix/server/api/dataloaders/session_last_start_times.py deleted file mode 100644 index c7ddb9de13..0000000000 --- a/src/phoenix/server/api/dataloaders/session_last_start_times.py +++ /dev/null @@ -1,42 +0,0 @@ -from datetime import datetime -from typing import Optional - -from openinference.semconv.trace import SpanAttributes -from sqlalchemy import func, select -from strawberry.dataloader import DataLoader -from typing_extensions import TypeAlias - -from phoenix.db import models -from phoenix.server.types import DbSessionFactory - -Key: TypeAlias = int -Result: TypeAlias = Optional[datetime] - - -class SessionLastTraceStartTimeDataLoader(DataLoader[Key, Result]): - def __init__(self, db: DbSessionFactory) -> None: - super().__init__(load_fn=self._load_fn) - self._db = db - - async def _load_fn(self, keys: list[Key]) -> list[Result]: - stmt = ( - select( - models.Trace.project_session_rowid, - func.max(models.Trace.start_time).label("last_start_time"), - ) - .where(models.Trace.project_session_rowid.in_(set(keys))) - .group_by(models.Trace.project_session_rowid) - ) - async with self._db() as session: - result: dict[Key, Result] = { - id_: last_start_time - async for id_, last_start_time in await session.stream(stmt) - if id_ is not None - } - return [result.get(key) for key in keys] - - -INPUT_VALUE = SpanAttributes.INPUT_VALUE.split(".") -INPUT_MIME_TYPE = SpanAttributes.INPUT_MIME_TYPE.split(".") -OUTPUT_VALUE = SpanAttributes.OUTPUT_VALUE.split(".") -OUTPUT_MIME_TYPE = SpanAttributes.OUTPUT_MIME_TYPE.split(".") diff --git a/src/phoenix/server/api/input_types/ProjectSessionSort.py b/src/phoenix/server/api/input_types/ProjectSessionSort.py index d6e1dd095b..4688763cfe 100644 --- a/src/phoenix/server/api/input_types/ProjectSessionSort.py +++ b/src/phoenix/server/api/input_types/ProjectSessionSort.py @@ -10,6 +10,7 @@ @strawberry.enum class ProjectSessionColumn(Enum): startTime = auto() + lastTraceStartTime = auto() tokenCountTotal = auto() numTraces = auto() @@ -17,7 +18,10 @@ class ProjectSessionColumn(Enum): def data_type(self) -> CursorSortColumnDataType: if self is ProjectSessionColumn.tokenCountTotal or self is ProjectSessionColumn.numTraces: return CursorSortColumnDataType.INT - if self is ProjectSessionColumn.startTime: + if ( + self is ProjectSessionColumn.startTime + or self is ProjectSessionColumn.lastTraceStartTime + ): return CursorSortColumnDataType.DATETIME assert_never(self) diff --git a/src/phoenix/server/api/types/Project.py b/src/phoenix/server/api/types/Project.py index 95424d7a2e..c849013f63 100644 --- a/src/phoenix/server/api/types/Project.py +++ b/src/phoenix/server/api/types/Project.py @@ -298,6 +298,8 @@ async def sessions( key: ColumnElement[Any] if sort.col is ProjectSessionColumn.startTime: key = table.start_time.label("key") + elif sort.col is ProjectSessionColumn.lastTraceStartTime: + key = table.last_trace_start_time.label("key") elif ( sort.col is ProjectSessionColumn.tokenCountTotal or sort.col is ProjectSessionColumn.numTraces diff --git a/src/phoenix/server/api/types/ProjectSession.py b/src/phoenix/server/api/types/ProjectSession.py index 0ba6f8a8bb..79e9c4efea 100644 --- a/src/phoenix/server/api/types/ProjectSession.py +++ b/src/phoenix/server/api/types/ProjectSession.py @@ -26,6 +26,7 @@ class ProjectSession(Node): session_id: str session_user: Optional[str] start_time: datetime + last_trace_start_time: datetime @strawberry.field async def project_id(self) -> GlobalID: @@ -73,13 +74,6 @@ async def last_output( value=record.value, ) - @strawberry.field - async def last_trace_start_time( - self, - info: Info[Context, None], - ) -> Optional[datetime]: - return await info.context.data_loaders.session_last_trace_start_times.load(self.id_attr) - @strawberry.field async def token_usage( self, @@ -137,6 +131,7 @@ def to_gql_project_session(project_session: models.ProjectSession) -> ProjectSes session_user=project_session.session_user, start_time=project_session.start_time, project_rowid=project_session.project_id, + last_trace_start_time=project_session.last_trace_start_time, ) diff --git a/src/phoenix/server/app.py b/src/phoenix/server/app.py index 4228e8d1f6..cffb1f3b4f 100644 --- a/src/phoenix/server/app.py +++ b/src/phoenix/server/app.py @@ -88,7 +88,6 @@ ProjectByNameDataLoader, RecordCountDataLoader, SessionIODataLoader, - SessionLastTraceStartTimeDataLoader, SessionNumTracesDataLoader, SessionNumTracesWithErrorDataLoader, SessionTokenUsagesDataLoader, @@ -618,7 +617,6 @@ def get_context() -> Context: ), session_first_inputs=SessionIODataLoader(db, "first_input"), session_last_outputs=SessionIODataLoader(db, "last_output"), - session_last_trace_start_times=SessionLastTraceStartTimeDataLoader(db), session_num_traces=SessionNumTracesDataLoader(db), session_num_traces_with_error=SessionNumTracesWithErrorDataLoader(db), session_token_usages=SessionTokenUsagesDataLoader(db), diff --git a/tests/integration/db_migrations/test_data_migration_4ded9e43755f_create_project_sessions_table.py b/tests/integration/db_migrations/test_data_migration_4ded9e43755f_create_project_sessions_table.py index 05a929329f..0c648a16cd 100644 --- a/tests/integration/db_migrations/test_data_migration_4ded9e43755f_create_project_sessions_table.py +++ b/tests/integration/db_migrations/test_data_migration_4ded9e43755f_create_project_sessions_table.py @@ -186,6 +186,11 @@ def get_spans(traces: Iterable[tuple[int, datetime]]) -> Iterator[dict[str, Any] df_project_sessions_joined_spans.session_id == df_project_sessions_joined_spans.session_id_span ).all() + assert ( + df_project_sessions_joined_spans.groupby("session_id") + .apply(lambda s: s.last_trace_start_time.min() == s.start_time_trace.max()) # type: ignore + .all() + ) is_first = df_project_sessions_joined_spans.groupby("session_id").cumcount() == 0 diff --git a/tests/integration/db_migrations/test_up_and_down_migrations.py b/tests/integration/db_migrations/test_up_and_down_migrations.py index a4b291e0b4..c323feee07 100644 --- a/tests/integration/db_migrations/test_up_and_down_migrations.py +++ b/tests/integration/db_migrations/test_up_and_down_migrations.py @@ -81,6 +81,12 @@ def test_up_and_down_migrations( assert isinstance(column.type, TIMESTAMP) del column + column = columns.pop("last_trace_start_time", None) + assert column is not None + assert not column.nullable + assert isinstance(column.type, TIMESTAMP) + del column + assert not columns del columns @@ -91,6 +97,11 @@ def test_up_and_down_migrations( assert not index.unique del index + index = indexes.pop("ix_project_sessions_last_trace_start_time", None) + assert index is not None + assert not index.unique + del index + index = indexes.pop("ix_project_sessions_session_user", None) assert index is not None assert not index.unique diff --git a/tests/unit/_helpers.py b/tests/unit/_helpers.py index a413f8c6e1..87a10778c0 100644 --- a/tests/unit/_helpers.py +++ b/tests/unit/_helpers.py @@ -142,6 +142,7 @@ async def _add_project_session( session_user=session_user, project_id=project.id, start_time=start_time, + last_trace_start_time=start_time, ) session.add(project_session) await session.flush() diff --git a/tests/unit/server/api/dataloaders/conftest.py b/tests/unit/server/api/dataloaders/conftest.py index 2a89129c68..6b2411bb61 100644 --- a/tests/unit/server/api/dataloaders/conftest.py +++ b/tests/unit/server/api/dataloaders/conftest.py @@ -32,6 +32,7 @@ async def data_for_testing_dataloaders( session_id=f"proj{i}_sess{l}", project_id=project_row_id, start_time=start_time, + last_trace_start_time=start_time, ) .returning(models.ProjectSession.id) )