Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow for LogRequestMiddleware to filter the headers it outputs #433

Merged
merged 9 commits into from
May 6, 2024
101 changes: 93 additions & 8 deletions Sources/Hummingbird/Middleware/LogRequestMiddleware.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,32 +12,117 @@
//
//===----------------------------------------------------------------------===//

import HTTPTypes
import Logging

/// Middleware outputting to log for every call to server
public struct LogRequestsMiddleware<Context: BaseRequestContext>: RouterMiddleware {
/// Header filter
public struct HeaderFilter: Sendable, ExpressibleByArrayLiteral {
fileprivate enum _Internal: Sendable {
case none
case all(except: [HTTPField.Name])
case some([HTTPField.Name])
}

fileprivate let value: _Internal
fileprivate init(_ value: _Internal) {
self.value = value
}

/// Don't output any headers
public static var none: Self { .init(.none) }
/// Output all headers, except the ones indicated
public static func all(except: [HTTPField.Name] = []) -> Self { .init(.all(except: except)) }
/// Output only these headers
public static func some(_ headers: [HTTPField.Name]) -> Self { .init(.some(headers)) }

public typealias ArrayLiteralElement = HTTPField.Name

/// ExpressibleByArrayLiteral requirement
public init(arrayLiteral elements: ArrayLiteralElement...) {
self.value = .some(elements)
}
}

let logLevel: Logger.Level
let includeHeaders: Bool
let includeHeaders: HeaderFilter
let redactHeaders: [HTTPField.Name]

public init(_ logLevel: Logger.Level, includeHeaders: Bool = false) {
public init(_ logLevel: Logger.Level, includeHeaders: HeaderFilter = .none, redactHeaders: [HTTPField.Name] = []) {
self.logLevel = logLevel
self.includeHeaders = includeHeaders
// only include headers in the redaction list if we are outputting them
self.redactHeaders = switch includeHeaders.value {
case .all(let exceptions):
// don't include headers in the except list
redactHeaders.filter { header in !exceptions.contains(header) }
case .some(let included):
// only include headers in the included list
redactHeaders.filter { header in included.contains(header) }
case .none:
[]
}
}

public func handle(_ request: Request, context: Context, next: (Request, Context) async throws -> Response) async throws -> Response {
if self.includeHeaders {
switch self.includeHeaders.value {
case .none:
context.logger.log(
level: self.logLevel,
"Request",
metadata: [
"hb_uri": .stringConvertible(request.uri),
"hb_method": .string(request.method.rawValue),
]
)
case .all(let except):
context.logger.log(
level: self.logLevel,
"\(request.headers)",
metadata: ["hb_uri": .stringConvertible(request.uri), "hb_method": .string(request.method.rawValue)]
"Request",
metadata: [
"hb_uri": .stringConvertible(request.uri),
"hb_method": .string(request.method.rawValue),
"hb_headers": .stringConvertible(self.allHeaders(headers: request.headers, except: except)),
]
)
} else {
case .some(let filter):
context.logger.log(
level: self.logLevel,
"",
metadata: ["hb_uri": .stringConvertible(request.uri), "hb_method": .string(request.method.rawValue)]
"Request",
metadata: [
"hb_uri": .stringConvertible(request.uri),
"hb_method": .string(request.method.rawValue),
"hb_headers": .stringConvertible(self.filterHeaders(headers: request.headers, filter: filter)),
]
)
}
return try await next(request, context)
}

func filterHeaders(headers: HTTPFields, filter: [HTTPField.Name]) -> [String: String] {
let headers = filter
.compactMap { entry -> (key: String, value: String)? in
guard let value = headers[entry] else { return nil }
if self.redactHeaders.contains(entry) {
return (key: entry.canonicalName, value: "***")
} else {
return (key: entry.canonicalName, value: value)
}
}
return .init(headers) { "\($0), \($1)" }
}

func allHeaders(headers: HTTPFields, except: [HTTPField.Name]) -> [String: String] {
let headers = headers
.compactMap { entry -> (key: String, value: String)? in
if except.contains(where: { entry.name == $0 }) { return nil }
if self.redactHeaders.contains(entry.name) {
return (key: entry.name.canonicalName, value: "***")
} else {
return (key: entry.name.canonicalName, value: entry.value)
}
}
return .init(headers) { "\($0), \($1)" }
}
}
221 changes: 215 additions & 6 deletions Tests/HummingbirdTests/MiddlewareTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

@testable import Hummingbird
import HummingbirdTesting
import Logging
import NIOConcurrencyHelpers
import XCTest

final class MiddlewareTests: XCTestCase {
Expand Down Expand Up @@ -231,16 +233,223 @@ final class MiddlewareTests: XCTestCase {
}
}

func testRouteLoggingMiddleware() async throws {
func testLogRequestMiddleware() async throws {
let logAccumalator = TestLogHandler.LogAccumalator()
let router = Router()
router.middlewares.add(LogRequestsMiddleware(.debug))
router.put("/hello") { _, _ -> String in
throw HTTPError(.badRequest)
router.middlewares.add(LogRequestsMiddleware(.info))
router.get("test") { _, _ in
return HTTPResponse.Status.ok
}
let app = Application(responder: router.buildResponder())
let app = Application(
responder: router.buildResponder(),
logger: Logger(label: "TestLogging") { label in
TestLogHandler(label, accumalator: logAccumalator)
}
)
try await app.test(.router) { client in
try await client.execute(uri: "/hello", method: .put) { _ in
try await client.execute(
uri: "/test",
method: .get,
headers: [.contentType: "application/json"],
body: .init(string: "{}")
) { _ in
let logs = logAccumalator.filter { $0.metadata?["hb_uri"]?.description == "/test" }
let firstLog = try XCTUnwrap(logs.first)
XCTAssertEqual(firstLog.metadata?["hb_method"]?.description, "GET")
XCTAssertNotNil(firstLog.metadata?["hb_id"])
}
}
}

func testLogRequestMiddlewareHeaderFiltering() async throws {
let logAccumalator = TestLogHandler.LogAccumalator()
let router = Router()
router.group()
.add(middleware: LogRequestsMiddleware(.info, includeHeaders: .all(except: [.connection])))
.get("all") { _, _ in return HTTPResponse.Status.ok }
router.group()
.add(middleware: LogRequestsMiddleware(.info, includeHeaders: .none))
.get("none") { _, _ in return HTTPResponse.Status.ok }
router.group()
.add(middleware: LogRequestsMiddleware(.info, includeHeaders: [.contentType]))
.get("some") { _, _ in return HTTPResponse.Status.ok }
let app = Application(
responder: router.buildResponder(),
logger: Logger(label: "TestLogging") { label in
TestLogHandler(label, accumalator: logAccumalator)
}
)
try await app.test(.live) { client in
try await client.execute(
uri: "/some",
method: .get,
headers: [.contentType: "application/json"],
body: .init(string: "{}")
) { _ in
let logEntries = logAccumalator.filter { $0.metadata?["hb_uri"]?.description == "/some" }
XCTAssertEqual(logEntries.first?.metadata?["hb_headers"], .stringConvertible(["content-type": "application/json"]))
}
try await client.execute(
uri: "/none",
method: .get,
headers: [.contentType: "application/json"],
body: .init(string: "{}")
) { _ in
let logEntries = logAccumalator.filter { $0.metadata?["hb_uri"]?.description == "/none" }
XCTAssertNil(logEntries.first?.metadata?["hb_headers"])
}
try await client.execute(
uri: "/all",
method: .get,
headers: [.contentType: "application/json"],
body: .init(string: "{}")
) { _ in
let logEntries = logAccumalator.filter { $0.metadata?["hb_uri"]?.description == "/all" }
guard case .stringConvertible(let headers) = logEntries.first?.metadata?["hb_headers"] else {
fatalError("Should never get here")
}
let reportedHeaders = try XCTUnwrap(headers as? [String: String])
XCTAssertEqual(reportedHeaders["content-type"], "application/json")
XCTAssertEqual(reportedHeaders["content-length"], "2")
XCTAssertNil(reportedHeaders["connection"])
}
}
}

func testLogRequestMiddlewareHeaderRedaction() async throws {
let logAccumalator = TestLogHandler.LogAccumalator()
let router = Router()
router.group()
.add(middleware: LogRequestsMiddleware(.info, includeHeaders: .all(), redactHeaders: [.authorization]))
.get("all") { _, _ in return HTTPResponse.Status.ok }
router.group()
.add(middleware: LogRequestsMiddleware(.info, includeHeaders: [.authorization], redactHeaders: [.authorization]))
.get("some") { _, _ in return HTTPResponse.Status.ok }
let app = Application(
responder: router.buildResponder(),
logger: Logger(label: "TestLogging") { label in
TestLogHandler(label, accumalator: logAccumalator)
}
)
try await app.test(.live) { client in
try await client.execute(
uri: "/some",
method: .get,
headers: [.authorization: "basic okhasdf87654"],
body: .init(string: "{}")
) { _ in
let logEntries = logAccumalator.filter { $0.metadata?["hb_uri"]?.description == "/some" }
XCTAssertEqual(logEntries.first?.metadata?["hb_headers"], .stringConvertible(["authorization": "***"]))
}
try await client.execute(
uri: "/all",
method: .get,
headers: [.authorization: "basic kjhdfi7udsfkhj"],
body: .init(string: "{}")
) { _ in
let logEntries = logAccumalator.filter { $0.metadata?["hb_uri"]?.description == "/all" }
guard case .stringConvertible(let headers) = logEntries.first?.metadata?["hb_headers"] else {
fatalError("Should never get here")
}
let reportedHeaders = try XCTUnwrap(headers as? [String: String])
XCTAssertEqual(reportedHeaders["authorization"], "***")
XCTAssertEqual(reportedHeaders["content-length"], "2")
}
}
}

func testLogRequestMiddlewareMultipleHeaders() async throws {
let logAccumalator = TestLogHandler.LogAccumalator()
let router = Router()
router.middlewares.add(LogRequestsMiddleware(.info, includeHeaders: [.test]))
router.get("test") { _, _ in
return HTTPResponse.Status.ok
}
let app = Application(
responder: router.buildResponder(),
logger: Logger(label: "TestLogging") { label in
TestLogHandler(label, accumalator: logAccumalator)
}
)
try await app.test(.router) { client in
var headers = HTTPFields()
headers[.test] = "One"
headers.append(.init(name: .test, value: "Two"))
try await client.execute(
uri: "/test",
method: .get,
headers: headers,
body: .init(string: "{}")
) { _ in
let logs = logAccumalator.filter { $0.metadata?["hb_uri"]?.description == "/test" }
let firstLog = try XCTUnwrap(logs.first)
XCTAssertEqual(firstLog.metadata?["hb_headers"], .stringConvertible(["hbtest": "One, Two"]))
}
}
}
}

/// LogHandler used in tests. Stores all log entries in provided `LogAccumalator``
struct TestLogHandler: LogHandler {
struct LogEntry {
let level: Logger.Level
let message: Logger.Message
let metadata: Logger.Metadata?
}

/// Used to store Logs
final class LogAccumalator {
var logEntries: NIOLockedValueBox<[LogEntry]>

init() {
self.logEntries = .init([])
}

func addEntry(_ entry: LogEntry) {
self.logEntries.withLockedValue { value in
value.append(entry)
}
}

func filter(_ isIncluded: (LogEntry) -> Bool) -> [LogEntry] {
self.logEntries.withLockedValue { logs in
logs.filter(isIncluded)
}
}
}

subscript(metadataKey key: String) -> Logger.Metadata.Value? {
get {
self.metadata[key]
}
set(newValue) {
self.metadata[key] = newValue
}
}

init(_: String, accumalator: LogAccumalator) {
self.logLevel = .info
self.metadata = [:]
self.accumalator = accumalator
}

public func log(
level: Logger.Level,
message: Logger.Message,
metadata explicitMetadata: Logger.Metadata?,
source: String,
file: String,
function: String,
line: UInt
) {
var metadata = self.metadata
if let explicitMetadata, !explicitMetadata.isEmpty {
metadata.merge(explicitMetadata, uniquingKeysWith: { _, explicit in explicit })
}
self.accumalator.addEntry(.init(level: level, message: message, metadata: metadata))
}

var logLevel: Logger.Level
var metadata: Logger.Metadata
let accumalator: LogAccumalator
}
Loading