Skip to content

Commit

Permalink
fix(sessions): sortable last trace start time (#5606)
Browse files Browse the repository at this point in the history
  • Loading branch information
RogerHYang committed Dec 4, 2024
1 parent aeb8a61 commit b324685
Show file tree
Hide file tree
Showing 18 changed files with 48 additions and 63 deletions.
3 changes: 2 additions & 1 deletion app/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -1327,19 +1327,20 @@ type ProjectSession implements Node {
sessionId: String!
sessionUser: String
startTime: DateTime!
lastTraceStartTime: DateTime!
projectId: GlobalID!
numTraces: Int!
numTracesWithError: Int!
firstInput: SpanIOValue
lastOutput: SpanIOValue
lastTraceStartTime: DateTime
tokenUsage: TokenUsage!
traces(first: Int = 50, last: Int, after: String, before: String): TraceConnection!
traceLatencyMsQuantile(probability: Float!): Float
}

enum ProjectSessionColumn {
startTime
lastTraceStartTime
tokenCountTotal
numTraces
}
Expand Down
2 changes: 1 addition & 1 deletion app/src/pages/project/SessionsTable.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ export function SessionsTable(props: SessionsTableProps) {
{
header: "last trace start time",
accessorKey: "lastTraceStartTime",
enableSorting: false,
enableSorting: true,
cell: TimestampCell,
},
{
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions src/phoenix/db/insertion/span.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ async def insert_span(
select(models.ProjectSession).filter_by(session_id=session_id)
)
if project_session:
if project_session.last_trace_start_time < span.start_time:
project_session.last_trace_start_time = span.start_time
if span.start_time < project_session.start_time:
project_session.start_time = span.start_time
if project_session.project_id != project_rowid:
Expand All @@ -62,6 +64,7 @@ async def insert_span(
session_id=session_id,
session_user=session_user if session_user else None,
start_time=span.start_time,
last_trace_start_time=span.start_time,
)
session.add(project_session)
if project_session in session.dirty:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ class ProjectSession(Base):
session_user: Mapped[Optional[str]]
project_id: Mapped[int]
start_time: Mapped[datetime]
last_trace_start_time: Mapped[datetime]


class Trace(Base):
Expand Down Expand Up @@ -96,6 +97,7 @@ def upgrade() -> None:
nullable=False,
),
sa.Column("start_time", sa.TIMESTAMP(timezone=True), index=True, nullable=False),
sa.Column("last_trace_start_time", sa.TIMESTAMP(timezone=True), index=True, nullable=False),
)
with op.batch_alter_table("traces") as batch_op:
batch_op.add_column(
Expand Down Expand Up @@ -123,6 +125,9 @@ def upgrade() -> None:
order_by=[Trace.start_time, Trace.id, Span.id],
)
.label("rank"),
func.max(Trace.start_time)
.over(partition_by=Span.attributes[SESSION_ID])
.label("last_trace_start_time"),
)
.join_from(Span, Trace, Span.trace_rowid == Trace.id)
.where(Span.parent_id.is_(None))
Expand All @@ -136,12 +141,14 @@ def upgrade() -> None:
"session_user",
"project_id",
"start_time",
"last_trace_start_time",
],
select(
sessions_from_span.c.session_id,
sessions_from_span.c.session_user,
sessions_from_span.c.project_id,
sessions_from_span.c.start_time,
sessions_from_span.c.last_trace_start_time,
).where(sessions_from_span.c.rank == 1),
)
)
Expand Down
5 changes: 4 additions & 1 deletion src/phoenix/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,10 @@ class ProjectSession(Base):
nullable=False,
index=True,
)
start_time: Mapped[datetime] = mapped_column(UtcTimeStamp, index=True)
start_time: Mapped[datetime] = mapped_column(UtcTimeStamp, index=True, nullable=False)
last_trace_start_time: Mapped[datetime] = mapped_column(
UtcTimeStamp, index=True, nullable=False
)
traces: Mapped[list["Trace"]] = relationship(
"Trace",
back_populates="project_session",
Expand Down
2 changes: 0 additions & 2 deletions src/phoenix/server/api/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
ProjectByNameDataLoader,
RecordCountDataLoader,
SessionIODataLoader,
SessionLastTraceStartTimeDataLoader,
SessionNumTracesDataLoader,
SessionNumTracesWithErrorDataLoader,
SessionTokenUsagesDataLoader,
Expand Down Expand Up @@ -77,7 +76,6 @@ class DataLoaders:
record_counts: RecordCountDataLoader
session_first_inputs: SessionIODataLoader
session_last_outputs: SessionIODataLoader
session_last_trace_start_times: SessionLastTraceStartTimeDataLoader
session_num_traces: SessionNumTracesDataLoader
session_num_traces_with_error: SessionNumTracesWithErrorDataLoader
session_token_usages: SessionTokenUsagesDataLoader
Expand Down
2 changes: 0 additions & 2 deletions src/phoenix/server/api/dataloaders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from .project_by_name import ProjectByNameDataLoader
from .record_counts import RecordCountCache, RecordCountDataLoader
from .session_io import SessionIODataLoader
from .session_last_start_times import SessionLastTraceStartTimeDataLoader
from .session_num_traces import SessionNumTracesDataLoader
from .session_num_traces_with_error import SessionNumTracesWithErrorDataLoader
from .session_token_usages import SessionTokenUsagesDataLoader
Expand Down Expand Up @@ -53,7 +52,6 @@
"MinStartOrMaxEndTimeDataLoader",
"RecordCountDataLoader",
"SessionIODataLoader",
"SessionLastTraceStartTimeDataLoader",
"SessionNumTracesDataLoader",
"SessionNumTracesWithErrorDataLoader",
"SessionTokenUsagesDataLoader",
Expand Down
42 changes: 0 additions & 42 deletions src/phoenix/server/api/dataloaders/session_last_start_times.py

This file was deleted.

6 changes: 5 additions & 1 deletion src/phoenix/server/api/input_types/ProjectSessionSort.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,18 @@
@strawberry.enum
class ProjectSessionColumn(Enum):
startTime = auto()
lastTraceStartTime = auto()
tokenCountTotal = auto()
numTraces = auto()

@property
def data_type(self) -> CursorSortColumnDataType:
if self is ProjectSessionColumn.tokenCountTotal or self is ProjectSessionColumn.numTraces:
return CursorSortColumnDataType.INT
if self is ProjectSessionColumn.startTime:
if (
self is ProjectSessionColumn.startTime
or self is ProjectSessionColumn.lastTraceStartTime
):
return CursorSortColumnDataType.DATETIME
assert_never(self)

Expand Down
2 changes: 2 additions & 0 deletions src/phoenix/server/api/types/Project.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,8 @@ async def sessions(
key: ColumnElement[Any]
if sort.col is ProjectSessionColumn.startTime:
key = table.start_time.label("key")
elif sort.col is ProjectSessionColumn.lastTraceStartTime:
key = table.last_trace_start_time.label("key")
elif (
sort.col is ProjectSessionColumn.tokenCountTotal
or sort.col is ProjectSessionColumn.numTraces
Expand Down
9 changes: 2 additions & 7 deletions src/phoenix/server/api/types/ProjectSession.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class ProjectSession(Node):
session_id: str
session_user: Optional[str]
start_time: datetime
last_trace_start_time: datetime

@strawberry.field
async def project_id(self) -> GlobalID:
Expand Down Expand Up @@ -73,13 +74,6 @@ async def last_output(
value=record.value,
)

@strawberry.field
async def last_trace_start_time(
self,
info: Info[Context, None],
) -> Optional[datetime]:
return await info.context.data_loaders.session_last_trace_start_times.load(self.id_attr)

@strawberry.field
async def token_usage(
self,
Expand Down Expand Up @@ -137,6 +131,7 @@ def to_gql_project_session(project_session: models.ProjectSession) -> ProjectSes
session_user=project_session.session_user,
start_time=project_session.start_time,
project_rowid=project_session.project_id,
last_trace_start_time=project_session.last_trace_start_time,
)


Expand Down
2 changes: 0 additions & 2 deletions src/phoenix/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@
ProjectByNameDataLoader,
RecordCountDataLoader,
SessionIODataLoader,
SessionLastTraceStartTimeDataLoader,
SessionNumTracesDataLoader,
SessionNumTracesWithErrorDataLoader,
SessionTokenUsagesDataLoader,
Expand Down Expand Up @@ -618,7 +617,6 @@ def get_context() -> Context:
),
session_first_inputs=SessionIODataLoader(db, "first_input"),
session_last_outputs=SessionIODataLoader(db, "last_output"),
session_last_trace_start_times=SessionLastTraceStartTimeDataLoader(db),
session_num_traces=SessionNumTracesDataLoader(db),
session_num_traces_with_error=SessionNumTracesWithErrorDataLoader(db),
session_token_usages=SessionTokenUsagesDataLoader(db),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,11 @@ def get_spans(traces: Iterable[tuple[int, datetime]]) -> Iterator[dict[str, Any]
df_project_sessions_joined_spans.session_id
== df_project_sessions_joined_spans.session_id_span
).all()
assert (
df_project_sessions_joined_spans.groupby("session_id")
.apply(lambda s: s.last_trace_start_time.min() == s.start_time_trace.max()) # type: ignore
.all()
)

is_first = df_project_sessions_joined_spans.groupby("session_id").cumcount() == 0

Expand Down
11 changes: 11 additions & 0 deletions tests/integration/db_migrations/test_up_and_down_migrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,12 @@ def test_up_and_down_migrations(
assert isinstance(column.type, TIMESTAMP)
del column

column = columns.pop("last_trace_start_time", None)
assert column is not None
assert not column.nullable
assert isinstance(column.type, TIMESTAMP)
del column

assert not columns
del columns

Expand All @@ -91,6 +97,11 @@ def test_up_and_down_migrations(
assert not index.unique
del index

index = indexes.pop("ix_project_sessions_last_trace_start_time", None)
assert index is not None
assert not index.unique
del index

index = indexes.pop("ix_project_sessions_session_user", None)
assert index is not None
assert not index.unique
Expand Down
1 change: 1 addition & 0 deletions tests/unit/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ async def _add_project_session(
session_user=session_user,
project_id=project.id,
start_time=start_time,
last_trace_start_time=start_time,
)
session.add(project_session)
await session.flush()
Expand Down
1 change: 1 addition & 0 deletions tests/unit/server/api/dataloaders/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ async def data_for_testing_dataloaders(
session_id=f"proj{i}_sess{l}",
project_id=project_row_id,
start_time=start_time,
last_trace_start_time=start_time,
)
.returning(models.ProjectSession.id)
)
Expand Down

0 comments on commit b324685

Please sign in to comment.