From 41b5ff6a63b79206967ddc45376bb8026e989944 Mon Sep 17 00:00:00 2001 From: Adam Fowler Date: Mon, 6 May 2024 08:37:48 +0100 Subject: [PATCH] Ensure tracing span is correct length (#435) --- .../Middleware/TracingMiddleware.swift | 118 ++++++++++++------ Tests/HummingbirdTests/TracingTests.swift | 70 +++++++++++ 2 files changed, 149 insertions(+), 39 deletions(-) diff --git a/Sources/Hummingbird/Middleware/TracingMiddleware.swift b/Sources/Hummingbird/Middleware/TracingMiddleware.swift index e742b7a9b..9a39ded5d 100644 --- a/Sources/Hummingbird/Middleware/TracingMiddleware.swift +++ b/Sources/Hummingbird/Middleware/TracingMiddleware.swift @@ -50,56 +50,62 @@ public struct TracingMiddleware: RouterMiddleware { return endpointPath }() - return try await InstrumentationSystem.tracer.withSpan(operationName, context: serviceContext, ofKind: .server) { span in - span.updateAttributes { attributes in - if let staticAttributes = self.attributes { - attributes.merge(staticAttributes) - } - attributes["http.method"] = request.method.rawValue - attributes["http.target"] = request.uri.path - // TODO: Get HTTP version and scheme - // attributes["http.flavor"] = "\(request.version.major).\(request.version.minor)" - // attributes["http.scheme"] = request.uri.scheme?.rawValue - attributes["http.user_agent"] = request.headers[.userAgent] - attributes["http.request_content_length"] = request.headers[.contentLength].map { Int($0) } ?? nil - - if let remoteAddress = (context as? any RemoteAddressRequestContext)?.remoteAddress { - attributes["net.sock.peer.port"] = remoteAddress.port - - switch remoteAddress.protocol { - case .inet: - attributes["net.sock.peer.addr"] = remoteAddress.ipAddress - case .inet6: - attributes["net.sock.family"] = "inet6" - attributes["net.sock.peer.addr"] = remoteAddress.ipAddress - case .unix: - attributes["net.sock.family"] = "unix" - attributes["net.sock.peer.addr"] = remoteAddress.pathname - default: - break - } + let span = InstrumentationSystem.tracer.startSpan(operationName, context: serviceContext, ofKind: .server) + span.updateAttributes { attributes in + if let staticAttributes = self.attributes { + attributes.merge(staticAttributes) + } + attributes["http.method"] = request.method.rawValue + attributes["http.target"] = request.uri.path + // TODO: Get HTTP version and scheme + // attributes["http.flavor"] = "\(request.version.major).\(request.version.minor)" + // attributes["http.scheme"] = request.uri.scheme?.rawValue + attributes["http.user_agent"] = request.headers[.userAgent] + attributes["http.request_content_length"] = request.headers[.contentLength].map { Int($0) } ?? nil + + if let remoteAddress = (context as? any RemoteAddressRequestContext)?.remoteAddress { + attributes["net.sock.peer.port"] = remoteAddress.port + + switch remoteAddress.protocol { + case .inet: + attributes["net.sock.peer.addr"] = remoteAddress.ipAddress + case .inet6: + attributes["net.sock.family"] = "inet6" + attributes["net.sock.peer.addr"] = remoteAddress.ipAddress + case .unix: + attributes["net.sock.family"] = "unix" + attributes["net.sock.peer.addr"] = remoteAddress.pathname + default: + break } - attributes = self.recordHeaders(request.headers, toSpanAttributes: attributes, withPrefix: "http.request.header.") } + attributes = self.recordHeaders(request.headers, toSpanAttributes: attributes, withPrefix: "http.request.header.") + } - do { - let response = try await next(request, context) + do { + return try await ServiceContext.$current.withValue(span.context) { + var response = try await next(request, context) span.updateAttributes { attributes in attributes = self.recordHeaders(response.headers, toSpanAttributes: attributes, withPrefix: "http.response.header.") attributes["http.status_code"] = Int(response.status.code) attributes["http.response_content_length"] = response.body.contentLength } - return response - } catch let error as HTTPResponseError { - span.attributes["http.status_code"] = Int(error.status.code) - - if 500..<600 ~= error.status.code { - span.setStatus(.init(code: .error)) + let spanWrapper = UnsafeTransfer(SpanWrapper(span)) + response.body = response.body.withPostWriteClosure { + spanWrapper.wrappedValue.end() } - - throw error + return response + } + } catch { + let statusCode = (error as? HTTPResponseError)?.status.code ?? 500 + span.attributes["http.status_code"] = statusCode + if 500..<600 ~= statusCode { + span.setStatus(.init(code: .error)) } + span.recordError(error) + span.end() + throw error } } @@ -120,6 +126,40 @@ public struct TracingMiddleware: RouterMiddleware { } } +/// Stores a reference to a span and on release ends the span +private class SpanWrapper { + var span: (any Span)? + + init(_ span: any Span) { + self.span = span + } + + func end() { + self.span?.end() + self.span = nil + } + + deinit { + self.span?.end() + } +} + +/// ``UnsafeTransfer`` can be used to make non-`Sendable` values `Sendable`. +/// As the name implies, the usage of this is unsafe because it disables the sendable checking of the compiler. +/// It can be used similar to `@unsafe Sendable` but for values instead of types. +@usableFromInline +struct UnsafeTransfer { + @usableFromInline + var wrappedValue: Wrapped + + @inlinable + init(_ wrappedValue: Wrapped) { + self.wrappedValue = wrappedValue + } +} + +extension UnsafeTransfer: @unchecked Sendable {} + /// Protocol for request context that stores the remote address of connected client. /// /// If you want the TracingMiddleware to record the remote address of requests diff --git a/Tests/HummingbirdTests/TracingTests.swift b/Tests/HummingbirdTests/TracingTests.swift index 13357ca3b..d951bd2ad 100644 --- a/Tests/HummingbirdTests/TracingTests.swift +++ b/Tests/HummingbirdTests/TracingTests.swift @@ -265,6 +265,76 @@ final class TracingTests: XCTestCase { ]) } + /// Test span is ended even if the response body with the span end is not run + func testTracingMiddlewareDropResponse() async throws { + let expectation = expectation(description: "Expected span to be ended.") + struct ErrorMiddleware: RouterMiddleware { + public func handle(_ request: Request, context: Context, next: (Request, Context) async throws -> Response) async throws -> Response { + _ = try await next(request, context) + throw HTTPError(.badRequest) + } + } + + let tracer = TestTracer() + tracer.onEndSpan = { _ in + expectation.fulfill() + } + InstrumentationSystem.bootstrapInternal(tracer) + + let router = Router() + router.middlewares.add(ErrorMiddleware()) + router.middlewares.add(TracingMiddleware()) + router.get("users/:id") { _, _ -> String in + return "42" + } + let app = Application(responder: router.buildResponder()) + try await app.test(.router) { client in + try await client.execute(uri: "/users/42", method: .get) { response in + XCTAssertEqual(response.status, .badRequest) + } + } + + await self.wait(for: [expectation], timeout: 1) + + let span = try XCTUnwrap(tracer.spans.first) + + XCTAssertEqual(span.operationName, "/users/:id") + XCTAssertEqual(span.kind, .server) + XCTAssertNil(span.status) + XCTAssertTrue(span.recordedErrors.isEmpty) + } + + // Test span length is the time it takes to write the response + func testTracingSpanLength() async throws { + let expectation = expectation(description: "Expected span to be ended.") + let tracer = TestTracer() + tracer.onEndSpan = { _ in + expectation.fulfill() + } + InstrumentationSystem.bootstrapInternal(tracer) + + let router = Router() + router.middlewares.add(TracingMiddleware()) + router.get("users/:id") { _, _ -> Response in + return Response( + status: .ok, + body: .init { _ in try await Task.sleep(for: .milliseconds(100)) } + ) + } + let app = Application(responder: router.buildResponder()) + try await app.test(.router) { client in + try await client.execute(uri: "/users/42", method: .get) { response in + XCTAssertEqual(response.status, .ok) + } + } + + await self.wait(for: [expectation], timeout: 1) + + let span = try XCTUnwrap(tracer.spans.first) + // Test tracer records span times in milliseconds + XCTAssertGreaterThanOrEqual(span.endTime! - span.startTime, 100) + } + /// Test tracing serviceContext is attached to request when route handler is called func testServiceContextPropagation() async throws { let expectation = expectation(description: "Expected span to be ended.")