Skip to content

Commit

Permalink
Ensure tracing span is correct length (#435)
Browse files Browse the repository at this point in the history
  • Loading branch information
adam-fowler authored May 6, 2024
1 parent 91dbc92 commit 41b5ff6
Show file tree
Hide file tree
Showing 2 changed files with 149 additions and 39 deletions.
118 changes: 79 additions & 39 deletions Sources/Hummingbird/Middleware/TracingMiddleware.swift
Original file line number Diff line number Diff line change
Expand Up @@ -50,56 +50,62 @@ public struct TracingMiddleware<Context: BaseRequestContext>: RouterMiddleware {
return endpointPath
}()

return try await InstrumentationSystem.tracer.withSpan(operationName, context: serviceContext, ofKind: .server) { span in
span.updateAttributes { attributes in
if let staticAttributes = self.attributes {
attributes.merge(staticAttributes)
}
attributes["http.method"] = request.method.rawValue
attributes["http.target"] = request.uri.path
// TODO: Get HTTP version and scheme
// attributes["http.flavor"] = "\(request.version.major).\(request.version.minor)"
// attributes["http.scheme"] = request.uri.scheme?.rawValue
attributes["http.user_agent"] = request.headers[.userAgent]
attributes["http.request_content_length"] = request.headers[.contentLength].map { Int($0) } ?? nil

if let remoteAddress = (context as? any RemoteAddressRequestContext)?.remoteAddress {
attributes["net.sock.peer.port"] = remoteAddress.port

switch remoteAddress.protocol {
case .inet:
attributes["net.sock.peer.addr"] = remoteAddress.ipAddress
case .inet6:
attributes["net.sock.family"] = "inet6"
attributes["net.sock.peer.addr"] = remoteAddress.ipAddress
case .unix:
attributes["net.sock.family"] = "unix"
attributes["net.sock.peer.addr"] = remoteAddress.pathname
default:
break
}
let span = InstrumentationSystem.tracer.startSpan(operationName, context: serviceContext, ofKind: .server)
span.updateAttributes { attributes in
if let staticAttributes = self.attributes {
attributes.merge(staticAttributes)
}
attributes["http.method"] = request.method.rawValue
attributes["http.target"] = request.uri.path
// TODO: Get HTTP version and scheme
// attributes["http.flavor"] = "\(request.version.major).\(request.version.minor)"
// attributes["http.scheme"] = request.uri.scheme?.rawValue
attributes["http.user_agent"] = request.headers[.userAgent]
attributes["http.request_content_length"] = request.headers[.contentLength].map { Int($0) } ?? nil

if let remoteAddress = (context as? any RemoteAddressRequestContext)?.remoteAddress {
attributes["net.sock.peer.port"] = remoteAddress.port

switch remoteAddress.protocol {
case .inet:
attributes["net.sock.peer.addr"] = remoteAddress.ipAddress
case .inet6:
attributes["net.sock.family"] = "inet6"
attributes["net.sock.peer.addr"] = remoteAddress.ipAddress
case .unix:
attributes["net.sock.family"] = "unix"
attributes["net.sock.peer.addr"] = remoteAddress.pathname
default:
break
}
attributes = self.recordHeaders(request.headers, toSpanAttributes: attributes, withPrefix: "http.request.header.")
}
attributes = self.recordHeaders(request.headers, toSpanAttributes: attributes, withPrefix: "http.request.header.")
}

do {
let response = try await next(request, context)
do {
return try await ServiceContext.$current.withValue(span.context) {
var response = try await next(request, context)
span.updateAttributes { attributes in
attributes = self.recordHeaders(response.headers, toSpanAttributes: attributes, withPrefix: "http.response.header.")

attributes["http.status_code"] = Int(response.status.code)
attributes["http.response_content_length"] = response.body.contentLength
}
return response
} catch let error as HTTPResponseError {
span.attributes["http.status_code"] = Int(error.status.code)

if 500..<600 ~= error.status.code {
span.setStatus(.init(code: .error))
let spanWrapper = UnsafeTransfer(SpanWrapper(span))
response.body = response.body.withPostWriteClosure {
spanWrapper.wrappedValue.end()
}

throw error
return response
}
} catch {
let statusCode = (error as? HTTPResponseError)?.status.code ?? 500
span.attributes["http.status_code"] = statusCode
if 500..<600 ~= statusCode {
span.setStatus(.init(code: .error))
}
span.recordError(error)
span.end()
throw error
}
}

Expand All @@ -120,6 +126,40 @@ public struct TracingMiddleware<Context: BaseRequestContext>: RouterMiddleware {
}
}

/// Stores a reference to a span and on release ends the span
private class SpanWrapper {
var span: (any Span)?

init(_ span: any Span) {
self.span = span
}

func end() {
self.span?.end()
self.span = nil
}

deinit {
self.span?.end()
}
}

/// ``UnsafeTransfer`` can be used to make non-`Sendable` values `Sendable`.
/// As the name implies, the usage of this is unsafe because it disables the sendable checking of the compiler.
/// It can be used similar to `@unsafe Sendable` but for values instead of types.
@usableFromInline
struct UnsafeTransfer<Wrapped> {
@usableFromInline
var wrappedValue: Wrapped

@inlinable
init(_ wrappedValue: Wrapped) {
self.wrappedValue = wrappedValue
}
}

extension UnsafeTransfer: @unchecked Sendable {}

/// Protocol for request context that stores the remote address of connected client.
///
/// If you want the TracingMiddleware to record the remote address of requests
Expand Down
70 changes: 70 additions & 0 deletions Tests/HummingbirdTests/TracingTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,76 @@ final class TracingTests: XCTestCase {
])
}

/// Test span is ended even if the response body with the span end is not run
func testTracingMiddlewareDropResponse() async throws {
let expectation = expectation(description: "Expected span to be ended.")
struct ErrorMiddleware<Context: BaseRequestContext>: RouterMiddleware {
public func handle(_ request: Request, context: Context, next: (Request, Context) async throws -> Response) async throws -> Response {
_ = try await next(request, context)
throw HTTPError(.badRequest)
}
}

let tracer = TestTracer()
tracer.onEndSpan = { _ in
expectation.fulfill()
}
InstrumentationSystem.bootstrapInternal(tracer)

let router = Router()
router.middlewares.add(ErrorMiddleware())
router.middlewares.add(TracingMiddleware())
router.get("users/:id") { _, _ -> String in
return "42"
}
let app = Application(responder: router.buildResponder())
try await app.test(.router) { client in
try await client.execute(uri: "/users/42", method: .get) { response in
XCTAssertEqual(response.status, .badRequest)
}
}

await self.wait(for: [expectation], timeout: 1)

let span = try XCTUnwrap(tracer.spans.first)

XCTAssertEqual(span.operationName, "/users/:id")
XCTAssertEqual(span.kind, .server)
XCTAssertNil(span.status)
XCTAssertTrue(span.recordedErrors.isEmpty)
}

// Test span length is the time it takes to write the response
func testTracingSpanLength() async throws {
let expectation = expectation(description: "Expected span to be ended.")
let tracer = TestTracer()
tracer.onEndSpan = { _ in
expectation.fulfill()
}
InstrumentationSystem.bootstrapInternal(tracer)

let router = Router()
router.middlewares.add(TracingMiddleware())
router.get("users/:id") { _, _ -> Response in
return Response(
status: .ok,
body: .init { _ in try await Task.sleep(for: .milliseconds(100)) }
)
}
let app = Application(responder: router.buildResponder())
try await app.test(.router) { client in
try await client.execute(uri: "/users/42", method: .get) { response in
XCTAssertEqual(response.status, .ok)
}
}

await self.wait(for: [expectation], timeout: 1)

let span = try XCTUnwrap(tracer.spans.first)
// Test tracer records span times in milliseconds
XCTAssertGreaterThanOrEqual(span.endTime! - span.startTime, 100)
}

/// Test tracing serviceContext is attached to request when route handler is called
func testServiceContextPropagation() async throws {
let expectation = expectation(description: "Expected span to be ended.")
Expand Down

0 comments on commit 41b5ff6

Please sign in to comment.