From 8711b2168e8d947f02b8e65a6c51d8bb7a42c912 Mon Sep 17 00:00:00 2001 From: Anthony Powell Date: Thu, 9 Jan 2025 17:40:23 -0500 Subject: [PATCH] fix: Coerce incoming span token counts to int (#5976) * fix: Coerce incoming span token counts to int * Attempt to coerce span attribute types when decoding otlp span If coercion fails, just keep old attribute value * Yield coerces span attributes * Update type --- src/phoenix/db/insertion/span.py | 37 ++++++++++++++++++---------- src/phoenix/trace/otel.py | 19 ++++++++++++++- tests/unit/trace/test_otel.py | 42 +++++++++++++++++++++++++++++--- 3 files changed, 81 insertions(+), 17 deletions(-) diff --git a/src/phoenix/db/insertion/span.py b/src/phoenix/db/insertion/span.py index cc07e0b6dd..9b70ca1288 100644 --- a/src/phoenix/db/insertion/span.py +++ b/src/phoenix/db/insertion/span.py @@ -100,19 +100,30 @@ async def insert_span( ) cumulative_error_count = int(span.status_code is SpanStatusCode.ERROR) - cumulative_llm_token_count_prompt = cast( - int, get_attribute_value(span.attributes, SpanAttributes.LLM_TOKEN_COUNT_PROMPT) or 0 - ) - cumulative_llm_token_count_completion = cast( - int, get_attribute_value(span.attributes, SpanAttributes.LLM_TOKEN_COUNT_COMPLETION) or 0 - ) - llm_token_count_prompt = cast( - Optional[int], get_attribute_value(span.attributes, SpanAttributes.LLM_TOKEN_COUNT_PROMPT) - ) - llm_token_count_completion = cast( - Optional[int], - get_attribute_value(span.attributes, SpanAttributes.LLM_TOKEN_COUNT_COMPLETION), - ) + try: + cumulative_llm_token_count_prompt = int( + get_attribute_value(span.attributes, SpanAttributes.LLM_TOKEN_COUNT_PROMPT) or 0 + ) + except BaseException: + cumulative_llm_token_count_prompt = 0 + try: + cumulative_llm_token_count_completion = int( + get_attribute_value(span.attributes, SpanAttributes.LLM_TOKEN_COUNT_COMPLETION) or 0 + ) + except BaseException: + cumulative_llm_token_count_completion = 0 + try: + llm_token_count_prompt = int( + get_attribute_value(span.attributes, SpanAttributes.LLM_TOKEN_COUNT_PROMPT) or 0 + ) + except BaseException: + llm_token_count_prompt = 0 + try: + llm_token_count_completion = int( + get_attribute_value(span.attributes, SpanAttributes.LLM_TOKEN_COUNT_COMPLETION) or 0 + ) + except BaseException: + llm_token_count_completion = 0 if accumulation := ( await session.execute( select( diff --git a/src/phoenix/trace/otel.py b/src/phoenix/trace/otel.py index 34382bc21e..0f6de7c468 100644 --- a/src/phoenix/trace/otel.py +++ b/src/phoenix/trace/otel.py @@ -49,6 +49,21 @@ OUTPUT_VALUE = SpanAttributes.OUTPUT_VALUE TOOL_PARAMETERS = SpanAttributes.TOOL_PARAMETERS LLM_PROMPT_TEMPLATE_VARIABLES = SpanAttributes.LLM_PROMPT_TEMPLATE_VARIABLES +LLM_TOKEN_COUNT_PROMPT = SpanAttributes.LLM_TOKEN_COUNT_PROMPT +LLM_TOKEN_COUNT_COMPLETION = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION +LLM_TOKEN_COUNT_TOTAL = SpanAttributes.LLM_TOKEN_COUNT_TOTAL + + +def coerce_otlp_span_attributes( + decoded_attributes: Iterable[tuple[str, Any]], +) -> Iterator[tuple[str, Any]]: + for key, value in decoded_attributes: + if key in (LLM_TOKEN_COUNT_PROMPT, LLM_TOKEN_COUNT_COMPLETION, LLM_TOKEN_COUNT_TOTAL): + try: + value = int(value) + except BaseException: + pass + yield key, value def decode_otlp_span(otlp_span: otlp.Span) -> Span: @@ -59,7 +74,9 @@ def decode_otlp_span(otlp_span: otlp.Span) -> Span: start_time = _decode_unix_nano(otlp_span.start_time_unix_nano) end_time = _decode_unix_nano(otlp_span.end_time_unix_nano) - attributes = unflatten(load_json_strings(_decode_key_values(otlp_span.attributes))) + attributes = unflatten( + load_json_strings(coerce_otlp_span_attributes(_decode_key_values(otlp_span.attributes))) + ) span_kind = SpanKind(get_attribute_value(attributes, OPENINFERENCE_SPAN_KIND)) status_code, status_message = _decode_status(otlp_span.status) diff --git a/tests/unit/trace/test_otel.py b/tests/unit/trace/test_otel.py index b20d76a224..5d767e8662 100644 --- a/tests/unit/trace/test_otel.py +++ b/tests/unit/trace/test_otel.py @@ -8,15 +8,14 @@ import opentelemetry.proto.trace.v1.trace_pb2 as otlp import pytest from google.protobuf.json_format import MessageToJson # type: ignore[import-untyped] -from openinference.semconv.trace import ( - SpanAttributes, -) +from openinference.semconv.trace import SpanAttributes from opentelemetry.proto.common.v1.common_pb2 import AnyValue, ArrayValue, KeyValue from pytest import approx from phoenix.trace.otel import ( _decode_identifier, _encode_identifier, + coerce_otlp_span_attributes, decode_otlp_span, encode_span_to_otlp, ) @@ -463,6 +462,43 @@ def test_decode_encode_tool_parameters(span: Span) -> None: assert decoded_span.attributes["tool"]["parameters"] == span.attributes["tool"]["parameters"] +def test_coerce_otlp_span_attributes() -> None: + # Test attributes that should be coerced + input_attrs = [ + ("llm.token_count.prompt", "123"), + ("llm.token_count.completion", "456"), + ("llm.token_count.total", "579"), + # Test attributes that should not be modified + ("other.number", "789"), + ("some.string", "hello"), + ("llm.other.field", "world"), + ] + + result = list(coerce_otlp_span_attributes(input_attrs)) + + expected = [ + ("llm.token_count.prompt", 123), + ("llm.token_count.completion", 456), + ("llm.token_count.total", 579), + ("other.number", "789"), + ("some.string", "hello"), + ("llm.other.field", "world"), + ] + + assert result == expected + + # Test that invalid number strings remain as strings + invalid_attrs = [ + ("llm.token_count.prompt", "not_a_number"), + ("llm.token_count.completion", "invalid"), + ("llm.token_count.total", ""), + ] + + result = list(coerce_otlp_span_attributes(invalid_attrs)) + + assert result == invalid_attrs + + @pytest.fixture def span() -> Span: trace_id = "f096b681-b8d4-44eb-bc4a-1db0b5a8d556"