Skip to content

Commit

Permalink
fix: Coerce incoming span token counts to int (#5976)
Browse files Browse the repository at this point in the history
* fix: Coerce incoming span token counts to int

* Attempt to coerce span attribute types when decoding otlp span

If coercion fails, just keep old attribute value

* Yield coerces span attributes

* Update type
  • Loading branch information
cephalization authored Jan 9, 2025
1 parent 0bf194d commit 8711b21
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 17 deletions.
37 changes: 24 additions & 13 deletions src/phoenix/db/insertion/span.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,19 +100,30 @@ async def insert_span(
)

cumulative_error_count = int(span.status_code is SpanStatusCode.ERROR)
cumulative_llm_token_count_prompt = cast(
int, get_attribute_value(span.attributes, SpanAttributes.LLM_TOKEN_COUNT_PROMPT) or 0
)
cumulative_llm_token_count_completion = cast(
int, get_attribute_value(span.attributes, SpanAttributes.LLM_TOKEN_COUNT_COMPLETION) or 0
)
llm_token_count_prompt = cast(
Optional[int], get_attribute_value(span.attributes, SpanAttributes.LLM_TOKEN_COUNT_PROMPT)
)
llm_token_count_completion = cast(
Optional[int],
get_attribute_value(span.attributes, SpanAttributes.LLM_TOKEN_COUNT_COMPLETION),
)
try:
cumulative_llm_token_count_prompt = int(
get_attribute_value(span.attributes, SpanAttributes.LLM_TOKEN_COUNT_PROMPT) or 0
)
except BaseException:
cumulative_llm_token_count_prompt = 0
try:
cumulative_llm_token_count_completion = int(
get_attribute_value(span.attributes, SpanAttributes.LLM_TOKEN_COUNT_COMPLETION) or 0
)
except BaseException:
cumulative_llm_token_count_completion = 0
try:
llm_token_count_prompt = int(
get_attribute_value(span.attributes, SpanAttributes.LLM_TOKEN_COUNT_PROMPT) or 0
)
except BaseException:
llm_token_count_prompt = 0
try:
llm_token_count_completion = int(
get_attribute_value(span.attributes, SpanAttributes.LLM_TOKEN_COUNT_COMPLETION) or 0
)
except BaseException:
llm_token_count_completion = 0
if accumulation := (
await session.execute(
select(
Expand Down
19 changes: 18 additions & 1 deletion src/phoenix/trace/otel.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,21 @@
OUTPUT_VALUE = SpanAttributes.OUTPUT_VALUE
TOOL_PARAMETERS = SpanAttributes.TOOL_PARAMETERS
LLM_PROMPT_TEMPLATE_VARIABLES = SpanAttributes.LLM_PROMPT_TEMPLATE_VARIABLES
LLM_TOKEN_COUNT_PROMPT = SpanAttributes.LLM_TOKEN_COUNT_PROMPT
LLM_TOKEN_COUNT_COMPLETION = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION
LLM_TOKEN_COUNT_TOTAL = SpanAttributes.LLM_TOKEN_COUNT_TOTAL


def coerce_otlp_span_attributes(
decoded_attributes: Iterable[tuple[str, Any]],
) -> Iterator[tuple[str, Any]]:
for key, value in decoded_attributes:
if key in (LLM_TOKEN_COUNT_PROMPT, LLM_TOKEN_COUNT_COMPLETION, LLM_TOKEN_COUNT_TOTAL):
try:
value = int(value)
except BaseException:
pass
yield key, value


def decode_otlp_span(otlp_span: otlp.Span) -> Span:
Expand All @@ -59,7 +74,9 @@ def decode_otlp_span(otlp_span: otlp.Span) -> Span:
start_time = _decode_unix_nano(otlp_span.start_time_unix_nano)
end_time = _decode_unix_nano(otlp_span.end_time_unix_nano)

attributes = unflatten(load_json_strings(_decode_key_values(otlp_span.attributes)))
attributes = unflatten(
load_json_strings(coerce_otlp_span_attributes(_decode_key_values(otlp_span.attributes)))
)
span_kind = SpanKind(get_attribute_value(attributes, OPENINFERENCE_SPAN_KIND))

status_code, status_message = _decode_status(otlp_span.status)
Expand Down
42 changes: 39 additions & 3 deletions tests/unit/trace/test_otel.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,14 @@
import opentelemetry.proto.trace.v1.trace_pb2 as otlp
import pytest
from google.protobuf.json_format import MessageToJson # type: ignore[import-untyped]
from openinference.semconv.trace import (
SpanAttributes,
)
from openinference.semconv.trace import SpanAttributes
from opentelemetry.proto.common.v1.common_pb2 import AnyValue, ArrayValue, KeyValue
from pytest import approx

from phoenix.trace.otel import (
_decode_identifier,
_encode_identifier,
coerce_otlp_span_attributes,
decode_otlp_span,
encode_span_to_otlp,
)
Expand Down Expand Up @@ -463,6 +462,43 @@ def test_decode_encode_tool_parameters(span: Span) -> None:
assert decoded_span.attributes["tool"]["parameters"] == span.attributes["tool"]["parameters"]


def test_coerce_otlp_span_attributes() -> None:
# Test attributes that should be coerced
input_attrs = [
("llm.token_count.prompt", "123"),
("llm.token_count.completion", "456"),
("llm.token_count.total", "579"),
# Test attributes that should not be modified
("other.number", "789"),
("some.string", "hello"),
("llm.other.field", "world"),
]

result = list(coerce_otlp_span_attributes(input_attrs))

expected = [
("llm.token_count.prompt", 123),
("llm.token_count.completion", 456),
("llm.token_count.total", 579),
("other.number", "789"),
("some.string", "hello"),
("llm.other.field", "world"),
]

assert result == expected

# Test that invalid number strings remain as strings
invalid_attrs = [
("llm.token_count.prompt", "not_a_number"),
("llm.token_count.completion", "invalid"),
("llm.token_count.total", ""),
]

result = list(coerce_otlp_span_attributes(invalid_attrs))

assert result == invalid_attrs


@pytest.fixture
def span() -> Span:
trace_id = "f096b681-b8d4-44eb-bc4a-1db0b5a8d556"
Expand Down

0 comments on commit 8711b21

Please sign in to comment.