Skip to content

Commit

Permalink
Update session logic
Browse files Browse the repository at this point in the history
Does the following:

- Clearing a session now clears traces only, instead of also deleting the session
- Requesting traces for a esssion will return a 404 Session Not found if no traces or session match a non-None session token
- Sessions now return session-less requests only up to the latest session marker
  • Loading branch information
tabgok committed Feb 12, 2024
1 parent 97d78ff commit 25161a8
Show file tree
Hide file tree
Showing 6 changed files with 162 additions and 108 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ var/
# Common virtual environments
.venv/
venv/
.python-version

# Editors
.idea/
Expand Down
241 changes: 137 additions & 104 deletions ddapm_test_agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@
from .tracestats import v06StatsPayload


class NoSuchSessionException(Exception):
pass


_Handler = Callable[[Request], Awaitable[web.Response]]


Expand Down Expand Up @@ -345,18 +349,34 @@ def _requests_by_session(self, token: Optional[str]) -> List[Request]:
"""
# Go backwards in the requests received gathering requests until
# the /session-start request for the token is found.
reqs: List[Request] = []
# Note that this may not return all associated traces, because some
# may be generated before the session-start call
session_reqs: List[Request] = []
sessionless_reqs: List[Request] = []
matched = token is None

for req in reversed(self._requests):
if req.match_info.handler == self.handle_session_start:
if token is None or _session_token(req) == token:
if token is None:
# If no token is specified, then we match the latest session
break
else:
# The requests made were from a different manual session
# so continue.
elif _session_token(req) == token:
# If a token is specified and it matches, we've hit the start of our session
matched = True
break
elif _session_token(req) != token:
# If a token is specified and it doesn't match, we've hit the start of a different session
# So we reset the list of requests
sessionless_reqs: List[Request] = []
continue
if _session_token(req) in [token, None]:
reqs.append(req)
return reqs
if _session_token(req) == token:
session_reqs.append(req)
elif _session_token(req) is None:
sessionless_reqs.append(req)

if not matched and not session_reqs:
raise NoSuchSessionException(f"No session found for token '{token}'")
return session_reqs + sessionless_reqs

async def _traces_from_request(self, req: Request) -> List[List[Span]]:
"""Return the trace from a trace request."""
Expand Down Expand Up @@ -699,109 +719,117 @@ async def handle_session_start(self, request: Request) -> web.Response:

async def handle_snapshot(self, request: Request) -> web.Response:
"""Generate a snapshot or perform a snapshot test."""
token = request["session_token"]
snap_dir = request.url.query.get("dir", request.app["snapshot_dir"])
snap_ci_mode = request.app["snapshot_ci_mode"]
log.info(
"performing snapshot with token=%r, ci_mode=%r and snapshot directory=%r",
token,
snap_ci_mode,
snap_dir,
)
try:
token = request["session_token"]
snap_dir = request.url.query.get("dir", request.app["snapshot_dir"])
snap_ci_mode = request.app["snapshot_ci_mode"]
log.info(
"performing snapshot with token=%r, ci_mode=%r and snapshot directory=%r",
token,
snap_ci_mode,
snap_dir,
)

# Get the span attributes that are to be ignored for this snapshot.
default_span_ignores: Set[str] = request.app["snapshot_ignored_attrs"]
overrides = set(_parse_csv(request.url.query.get("ignores", "")))
span_ignores = list(default_span_ignores | overrides)
log.info("using ignores %r", span_ignores)

# Get the span attributes that are to be removed for this snapshot.
default_span_removes: Set[str] = request.app["snapshot_removed_attrs"]
overrides = set(_parse_csv(request.url.query.get("removes", "")))
span_removes = list(default_span_removes | overrides)
log.info("using removes %r", span_removes)

if "span_id" in span_removes:
raise AssertionError("Cannot remove 'span_id' from spans")

with CheckTrace.add_frame(f"snapshot (token='{token}')") as frame:
frame.add_item(f"Directory: {snap_dir}")
frame.add_item(f"CI mode: {snap_ci_mode}")

if "X-Datadog-Test-Snapshot-Filename" in request.headers:
snap_file = request.headers["X-Datadog-Test-Snapshot-Filename"]
elif "file" in request.url.query:
snap_file = request.url.query["file"]
else:
snap_file = os.path.join(snap_dir, token)
# Get the span attributes that are to be ignored for this snapshot.
default_span_ignores: Set[str] = request.app["snapshot_ignored_attrs"]
overrides = set(_parse_csv(request.url.query.get("ignores", "")))
span_ignores = list(default_span_ignores | overrides)
log.info("using ignores %r", span_ignores)

# Get the span attributes that are to be removed for this snapshot.
default_span_removes: Set[str] = request.app["snapshot_removed_attrs"]
overrides = set(_parse_csv(request.url.query.get("removes", "")))
span_removes = list(default_span_removes | overrides)
log.info("using removes %r", span_removes)

if "span_id" in span_removes:
raise AssertionError("Cannot remove 'span_id' from spans")

with CheckTrace.add_frame(f"snapshot (token='{token}')") as frame:
frame.add_item(f"Directory: {snap_dir}")
frame.add_item(f"CI mode: {snap_ci_mode}")

if "X-Datadog-Test-Snapshot-Filename" in request.headers:
snap_file = request.headers["X-Datadog-Test-Snapshot-Filename"]
elif "file" in request.url.query:
snap_file = request.url.query["file"]
else:
snap_file = os.path.join(snap_dir, token)

# The logic from here is mostly duplicated for traces and trace stats.
# If another data type is to be snapshotted then it probably makes sense to abstract away
# the required pieces of snapshotting (loading, generating and comparing).
# The logic from here is mostly duplicated for traces and trace stats.
# If another data type is to be snapshotted then it probably makes sense to abstract away
# the required pieces of snapshotting (loading, generating and comparing).

# For backwards compatibility traces don't have a postfix of `_trace.json`
trace_snap_file = f"{snap_file}.json"
tracestats_snap_file = f"{snap_file}_tracestats.json"
# For backwards compatibility traces don't have a postfix of `_trace.json`
trace_snap_file = f"{snap_file}.json"
tracestats_snap_file = f"{snap_file}_tracestats.json"

frame.add_item(f"Trace File: {trace_snap_file}")
frame.add_item(f"Stats File: {tracestats_snap_file}")
log.info("using snapshot files %r and %r", trace_snap_file, tracestats_snap_file)
frame.add_item(f"Trace File: {trace_snap_file}")
frame.add_item(f"Stats File: {tracestats_snap_file}")
log.info("using snapshot files %r and %r", trace_snap_file, tracestats_snap_file)

trace_snap_path_exists = os.path.exists(trace_snap_file)
trace_snap_path_exists = os.path.exists(trace_snap_file)

received_traces = await self._traces_by_session(token)
if snap_ci_mode and received_traces and not trace_snap_path_exists:
raise AssertionError(
f"Trace snapshot file '{trace_snap_file}' not found. "
"Perhaps the file was not checked into source control? "
"The snapshot file is automatically generated when the test agent is not in CI mode."
)
elif trace_snap_path_exists:
# Do the snapshot comparison
with open(trace_snap_file, mode="r") as f:
raw_snapshot = json.load(f)
trace_snapshot.snapshot(
expected_traces=raw_snapshot,
received_traces=received_traces,
ignored=span_ignores,
)
elif received_traces:
# Create a new snapshot for the data received
with open(trace_snap_file, mode="w") as f:
f.write(trace_snapshot.generate_snapshot(received_traces=received_traces, removed=span_removes))
log.info("wrote new trace snapshot to %r", os.path.abspath(trace_snap_file))

# Get all stats buckets from the payloads since we don't care about the other fields (hostname, env, etc)
# in the payload.
received_stats = [bucket for p in (await self._tracestats_by_session(token)) for bucket in p["Stats"]]
tracestats_snap_path_exists = os.path.exists(tracestats_snap_file)
if snap_ci_mode and received_stats and not tracestats_snap_path_exists:
raise AssertionError(
f"Trace stats snapshot file '{tracestats_snap_file}' not found. "
"Perhaps the file was not checked into source control? "
"The snapshot file is automatically generated when the test case is run when not in CI mode."
)
elif tracestats_snap_path_exists:
# Do the snapshot comparison
with open(tracestats_snap_file, mode="r") as f:
raw_snapshot = json.load(f)
tracestats_snapshot.snapshot(
expected_stats=raw_snapshot,
received_stats=received_stats,
)
elif received_stats:
# Create a new snapshot for the data received
with open(tracestats_snap_file, mode="w") as f:
f.write(tracestats_snapshot.generate(received_stats))
log.info(
"wrote new tracestats snapshot to %r",
os.path.abspath(tracestats_snap_file),
)
return web.HTTPOk()
received_traces = await self._traces_by_session(token)
if snap_ci_mode and received_traces and not trace_snap_path_exists:
raise AssertionError(
f"Trace snapshot file '{trace_snap_file}' not found. "
"Perhaps the file was not checked into source control? "
"The snapshot file is automatically generated when the test agent is not in CI mode."
)
elif trace_snap_path_exists:
# Do the snapshot comparison
with open(trace_snap_file, mode="r") as f:
raw_snapshot = json.load(f)
trace_snapshot.snapshot(
expected_traces=raw_snapshot,
received_traces=received_traces,
ignored=span_ignores,
)
elif received_traces:
# Create a new snapshot for the data received
with open(trace_snap_file, mode="w") as f:
f.write(trace_snapshot.generate_snapshot(received_traces=received_traces, removed=span_removes))
log.info("wrote new trace snapshot to %r", os.path.abspath(trace_snap_file))

# Get all stats buckets from the payloads since we don't care about the other fields (hostname, env, etc)
# in the payload.
received_stats = [bucket for p in (await self._tracestats_by_session(token)) for bucket in p["Stats"]]
tracestats_snap_path_exists = os.path.exists(tracestats_snap_file)
if snap_ci_mode and received_stats and not tracestats_snap_path_exists:
raise AssertionError(
f"Trace stats snapshot file '{tracestats_snap_file}' not found. "
"Perhaps the file was not checked into source control? "
"The snapshot file is automatically generated when the test case is run when not in CI mode."
)
elif tracestats_snap_path_exists:
# Do the snapshot comparison
with open(tracestats_snap_file, mode="r") as f:
raw_snapshot = json.load(f)
tracestats_snapshot.snapshot(
expected_stats=raw_snapshot,
received_stats=received_stats,
)
elif received_stats:
# Create a new snapshot for the data received
with open(tracestats_snap_file, mode="w") as f:
f.write(tracestats_snapshot.generate(received_stats))
log.info(
"wrote new tracestats snapshot to %r",
os.path.abspath(tracestats_snap_file),
)
return web.HTTPOk()
except Exception as e:
return web.HTTPBadRequest(reason=str(e))

async def handle_session_traces(self, request: Request) -> web.Response:
token = request["session_token"]
traces = await self._traces_by_session(token)
traces = []
try:
traces = await self._traces_by_session(token)
except NoSuchSessionException as e:
return web.HTTPNotFound(reason=str(e))

return web.json_response(traces)

async def handle_session_apmtelemetry(self, request: Request) -> web.Response:
Expand Down Expand Up @@ -895,14 +923,19 @@ async def handle_session_clear(self, request: Request) -> web.Response:
if req.match_info.handler == self.handle_session_start:
if _session_token(req) == session_token:
in_token_sync_session = True
continue # Don't clear the session start
else:
in_token_sync_session = False
if in_token_sync_session:
setattr(req, "__delete", True)

# Filter out all the requests.
# Filter out all requests marked for deletion.
# Keep session starts.
self._requests = [
r for r in self._requests if _session_token(r) != session_token and not hasattr(r, "__delete")
r
for r in self._requests
if (_session_token(r) != session_token or r.match_info.handler == self.handle_session_start)
and not hasattr(r, "__delete")
]
else:
self._requests = []
Expand Down
2 changes: 1 addition & 1 deletion ddapm_test_agent/trace_snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
log = logging.getLogger(__name__)


DEFAULT_SNAPSHOT_IGNORES = "span_id,trace_id,parent_id,duration,start,metrics.system.pid,metrics.system.process_id,metrics.process_id,meta.runtime-id,span_links.trace_id_high,meta.pathway.hash"
DEFAULT_SNAPSHOT_IGNORES = "span_id,trace_id,parent_id,duration,start,metrics.system.pid,metrics.system.process_id,metrics.process_id,metrics._dd.tracer_kr,meta.runtime-id,span_links.trace_id_high,meta.pathway.hash"


def _key_match(d1: Dict[str, Any], d2: Dict[str, Any], key: str) -> bool:
Expand Down
6 changes: 3 additions & 3 deletions tests/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ async def test_concurrent_session(
for token in ["test_case", "test_case2"]:
resp = await agent.get("/test/session/traces", params={"test_session_token": token})
assert resp.status == 200
assert await resp.json() == v04_reference_http_trace_payload_data_raw
result = await resp.json()
assert result == v04_reference_http_trace_payload_data_raw, result

resp = await agent.get("/test/session/traces")
assert resp.status == 200
Expand All @@ -69,8 +70,7 @@ async def test_concurrent_session(
assert resp.status == 200
for token in ["test_case", "test_case2"]:
resp = await agent.get("/test/session/traces", params={"test_session_token": token})
assert resp.status == 200
assert await resp.json() == []
assert resp.status == 404


async def test_two_sessions(
Expand Down
17 changes: 17 additions & 0 deletions tests/test_snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ async def test_snapshot_single_trace(
When the same trace is sent again
The snapshot should pass
""" # noqa: RST301
# Start a session
resp = await agent.get("/test/session/start", params={"test_session_token": "test_case"})
assert resp.status == 200
# Send a trace
resp = await do_reference_v04_http_trace(token="test_case")
assert resp.status == 200
Expand Down Expand Up @@ -177,6 +180,8 @@ async def test_snapshot_single_trace(
],
)
async def test_snapshot_trace_differences(agent, expected_traces, actual_traces, error):
resp = await agent.get("/test/session/start", params={"test_session_token": "test"})
assert resp.status == 200
resp = await v04_trace(agent, expected_traces, token="test")
assert resp.status == 200, await resp.text()

Expand Down Expand Up @@ -364,6 +369,8 @@ def test_generate_tracestats_snapshot(buckets, expected):


async def test_snapshot_custom_dir(agent, tmp_path, do_reference_v04_http_trace):
resp = await agent.get("/test/session/start", params={"test_session_token": "test_case"})
assert resp.status == 200
resp = await do_reference_v04_http_trace(token="test_case")
assert resp.status == 200

Expand All @@ -382,6 +389,8 @@ async def test_snapshot_custom_dir(agent, tmp_path, do_reference_v04_http_trace)


async def test_snapshot_custom_file(agent, tmp_path, do_reference_v04_http_trace):
resp = await agent.get("/test/session/start", params={"test_session_token": "test_case"})
assert resp.status == 200
resp = await do_reference_v04_http_trace(token="test_case")
assert resp.status == 200

Expand All @@ -402,6 +411,8 @@ async def test_snapshot_custom_file(agent, tmp_path, do_reference_v04_http_trace

@pytest.mark.parametrize("snapshot_ci_mode", [False, True])
async def test_snapshot_tracestats(agent, tmp_path, snapshot_ci_mode, do_reference_v06_http_stats, snapshot_dir):
resp = await agent.get("/test/session/start", params={"test_session_token": "test_case"})
assert resp.status == 200
resp = await do_reference_v06_http_stats(token="test_case")
assert resp.status == 200

Expand Down Expand Up @@ -431,6 +442,8 @@ async def test_snapshot_tracestats(agent, tmp_path, snapshot_ci_mode, do_referen

@pytest.mark.parametrize("snapshot_removed_attrs", [{"start", "duration"}])
async def test_removed_attributes(agent, tmp_path, snapshot_removed_attrs, do_reference_v04_http_trace):
resp = await agent.get("/test/session/start", params={"test_session_token": "test_case"})
assert resp.status == 200
resp = await do_reference_v04_http_trace(token="test_case")
assert resp.status == 200

Expand All @@ -455,6 +468,8 @@ async def test_removed_attributes(agent, tmp_path, snapshot_removed_attrs, do_re

@pytest.mark.parametrize("snapshot_removed_attrs", [{"metrics.process_id"}])
async def test_removed_attributes_metrics(agent, tmp_path, snapshot_removed_attrs, do_reference_v04_http_trace):
resp = await agent.get("/test/session/start", params={"test_session_token": "test_case"})
assert resp.status == 200
resp = await do_reference_v04_http_trace(token="test_case")
assert resp.status == 200

Expand Down Expand Up @@ -707,6 +722,8 @@ async def test_removed_attributes_metrics(agent, tmp_path, snapshot_removed_attr
],
)
async def test_snapshot_trace_differences_removed_start(agent, expected_traces, actual_traces, error):
resp = await agent.get("/test/session/start", params={"test_session_token": "test"})
assert resp.status == 200
resp = await v04_trace(agent, expected_traces, token="test")
assert resp.status == 200, await resp.text()

Expand Down
Loading

0 comments on commit 25161a8

Please sign in to comment.