Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove double record exception #712

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 11 additions & 37 deletions logfire/_internal/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -1948,13 +1948,13 @@ def __init__(self, span: trace_api.Span) -> None:
self._token = context_api.attach(trace_api.set_span_in_context(self._span))

def __enter__(self) -> FastLogfireSpan:
self._span.__enter__()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is maybe unnecessary but seems reasonable to add, and I hope this class will be refactored/simplified soon anyway.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given how optimised this class is, I'd rather not add it here, but maybe it can stay in LogfireSpan.

return self

@handle_internal_errors()
def __exit__(self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: Any) -> None:
context_api.detach(self._token)
_exit_span(self._span, exc_value)
self._span.end()
self._span.__exit__(exc_type, exc_value, traceback)


# Changes to this class may need to be reflected in `FastLogfireSpan` and `NoopSpan` as well.
Expand Down Expand Up @@ -1990,6 +1990,7 @@ def __enter__(self) -> LogfireSpan:
attributes=self._otlp_attributes,
links=self._links,
)
self._span.__enter__()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same note as above, this might not be necessary but seems fine to add and arguably more correct(?)

if self._token is None: # pragma: no branch
self._token = context_api.attach(trace_api.set_span_in_context(self._span))

Expand All @@ -1999,14 +2000,17 @@ def __enter__(self) -> LogfireSpan:
def __exit__(self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: Any) -> None:
if self._token is None: # pragma: no cover
return
assert self._span is not None

context_api.detach(self._token)
self._token = None

assert self._span is not None
_exit_span(self._span, exc_value)

self.end()
if self._span.is_recording():
with handle_internal_errors():
if self._added_attributes:
self._span.set_attribute(
ATTRIBUTES_JSON_SCHEMA_KEY, attributes_json_schema(self._json_schema_properties)
)
Comment on lines +2007 to +2012
Copy link
Contributor Author

@dmontagu dmontagu Dec 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just moved the logic from end() here, noting we can drop a check since we already know self._span is not None, and we don't currently rely on the LogfireSpan methods being called at the same time that the otel methods would be called.

self._span.__exit__(exc_type, exc_value, traceback)

@property
def message_template(self) -> str | None: # pragma: no cover
Expand All @@ -2032,26 +2036,6 @@ def message(self) -> str:
def message(self, message: str):
self._set_attribute(ATTRIBUTES_MESSAGE_KEY, message)

def end(self, end_time: int | None = None) -> None:
"""Sets the current time as the span's end time.

The span's end time is the wall time at which the operation finished.

Only the first call to this method is recorded, further calls are ignored so you
can call this within the span's context manager to end it before the context manager
exits.
"""
if self._span is None: # pragma: no cover
raise RuntimeError('Span has not been started')
if self._span.is_recording():
with handle_internal_errors():
if self._added_attributes:
self._span.set_attribute(
ATTRIBUTES_JSON_SCHEMA_KEY, attributes_json_schema(self._json_schema_properties)
)

self._span.end(end_time)
Copy link
Contributor Author

@dmontagu dmontagu Dec 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We used to only call self._span.end here if self._span was recording. Is that actually the right behavior though? The changes I made ensure that end() always gets called, even on a non-recording span, and that's why the timestamps changed below (because it generates another timestamp in the unconditional call to end()).

I can go back to only calling span.end() if the span is recording but I'm not sure if that's correct..

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

import opentelemetry.trace

import logfire

logfire.configure()

tracer = opentelemetry.trace.get_tracer(__name__)
span = tracer.start_span('foo')
assert span.is_recording()
span.end()
assert not span.is_recording()
span.end()  # Logs a warning: Calling end() on an ended span.

So keeping the current behaviour seems best.


@handle_internal_errors()
def set_attribute(self, key: str, value: Any) -> None:
"""Sets an attribute on the span.
Expand Down Expand Up @@ -2183,16 +2167,6 @@ def is_recording(self) -> bool:
return False


def _exit_span(span: trace_api.Span, exception: BaseException | None) -> None:
if not span.is_recording():
return

# record exception if present
# isinstance is to ignore BaseException
if isinstance(exception, Exception):
record_exception(span, exception, escaped=True)


AttributesValueType = TypeVar('AttributesValueType', bound=Union[Any, otel_types.AttributeValue])


Expand Down
6 changes: 6 additions & 0 deletions logfire/_internal/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,12 @@ def record_exception(
timestamp = timestamp or self.ns_timestamp_generator()
record_exception(self.span, exception, attributes=attributes, timestamp=timestamp, escaped=escaped)

def __exit__(self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: Any) -> None:
if self.is_recording():
if isinstance(exc_value, BaseException):
self.record_exception(exc_value, escaped=True)
self.end()

if not TYPE_CHECKING: # pragma: no branch
# for ReadableSpan
def __getattr__(self, name: str) -> Any:
Expand Down
21 changes: 21 additions & 0 deletions tests/test_logfire.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from functools import partial
from logging import getLogger
from typing import Any, Callable
from unittest.mock import patch

import pytest
from dirty_equals import IsInt, IsJson, IsStr
Expand Down Expand Up @@ -36,6 +37,7 @@
)
from logfire._internal.formatter import FormattingFailedWarning, InspectArgumentsFailedWarning
from logfire._internal.main import NoopSpan
from logfire._internal.tracer import record_exception
from logfire._internal.utils import is_instrumentation_suppressed
from logfire.integrations.logging import LogfireLoggingHandler
from logfire.testing import TestExporter
Expand Down Expand Up @@ -3171,3 +3173,22 @@ def test_suppress_scopes(exporter: TestExporter, metrics_reader: InMemoryMetricR
}
]
)


def test_logfire_span_records_exceptions_once():
n_calls_to_record_exception = 0

def patched_record_exception(*args: Any, **kwargs: Any) -> Any:
nonlocal n_calls_to_record_exception
n_calls_to_record_exception += 1

return record_exception(*args, **kwargs)

with patch('logfire._internal.tracer.record_exception', patched_record_exception), patch(
'logfire._internal.main.record_exception', patched_record_exception
):
with pytest.raises(RuntimeError):
with logfire.span('foo'):
raise RuntimeError('error')

assert n_calls_to_record_exception == 1
Loading