Skip to content

Commit

Permalink
Add HBBaseRequestContext protocol (#287)
Browse files Browse the repository at this point in the history
* Add HBBaseRequestContext protocol

That doesn't require an init using a NIO Channel

* Use some

* Use some HBBaseRequestContext in request handler tests
  • Loading branch information
adam-fowler authored Nov 21, 2023
1 parent 221b1d8 commit feaa75b
Show file tree
Hide file tree
Showing 34 changed files with 154 additions and 164 deletions.
41 changes: 16 additions & 25 deletions Sources/Hummingbird/Application.swift
Original file line number Diff line number Diff line change
Expand Up @@ -141,22 +141,14 @@ public struct HBApplication<Responder: HBResponder, ChannelSetup: HBChannelSetup
self.services.append(service)
}

/// Helper function that runs application inside a ServiceGroup which will gracefully
/// shutdown on signals SIGINT, SIGTERM
public func runService(gracefulShutdownSignals: [UnixSignal] = [.sigterm, .sigint]) async throws {
let serviceGroup = ServiceGroup(
configuration: .init(
services: [self],
gracefulShutdownSignals: gracefulShutdownSignals,
logger: self.logger
)
)
try await serviceGroup.run()
public static func loggerWithRequestId(_ logger: Logger) -> Logger {
let requestId = globalRequestID.loadThenWrappingIncrement(by: 1, ordering: .relaxed)
return logger.with(metadataKey: "hb_id", value: .stringConvertible(requestId))
}
}

/// Conform to `Service` from `ServiceLifecycle`.
extension HBApplication: Service {
extension HBApplication: Service where Responder.Context: HBRequestContext {
public func run() async throws {
let context = HBApplicationContext(
threadPool: self.threadPool,
Expand All @@ -173,7 +165,7 @@ extension HBApplication: Service {
)
let context = Responder.Context(
applicationContext: context,
source: ChannelContextSource(channel: channel),
channel: channel,
logger: HBApplication.loggerWithRequestId(context.logger)
)
// respond to request
Expand Down Expand Up @@ -209,18 +201,17 @@ extension HBApplication: Service {
}
}

public static func loggerWithRequestId(_ logger: Logger) -> Logger {
let requestId = globalRequestID.loadThenWrappingIncrement(by: 1, ordering: .relaxed)
return logger.with(metadataKey: "hb_id", value: .stringConvertible(requestId))
}

/// Request Context Source from NIO Channel
public struct ChannelContextSource: RequestContextSource {
let channel: Channel

public var eventLoop: EventLoop { self.channel.eventLoop }
public var allocator: ByteBufferAllocator { self.channel.allocator }
public var remoteAddress: SocketAddress? { self.channel.remoteAddress }
/// Helper function that runs application inside a ServiceGroup which will gracefully
/// shutdown on signals SIGINT, SIGTERM
public func runService(gracefulShutdownSignals: [UnixSignal] = [.sigterm, .sigint]) async throws {
let serviceGroup = ServiceGroup(
configuration: .init(
services: [self],
gracefulShutdownSignals: gracefulShutdownSignals,
logger: self.logger
)
)
try await serviceGroup.run()
}
}

Expand Down
8 changes: 4 additions & 4 deletions Sources/Hummingbird/Codable/CodableProtocols.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ public protocol HBResponseEncoder: Sendable {
/// - Parameters:
/// - value: value to encode
/// - request: request that generated this value
func encode<T: Encodable>(_ value: T, from request: HBRequest, context: HBRequestContext) throws -> HBResponse
func encode(_ value: some Encodable, from request: HBRequest, context: some HBBaseRequestContext) throws -> HBResponse
}

/// protocol for decoder deserializing from a Request body
Expand All @@ -28,12 +28,12 @@ public protocol HBRequestDecoder: Sendable {
/// - Parameters:
/// - type: type to decode to
/// - request: request
func decode<T: Decodable>(_ type: T.Type, from request: HBRequest, context: HBRequestContext) throws -> T
func decode<T: Decodable>(_ type: T.Type, from request: HBRequest, context: some HBBaseRequestContext) throws -> T
}

/// Default encoder. Outputs request with the swift string description of object
struct NullEncoder: HBResponseEncoder {
func encode<T: Encodable>(_ value: T, from request: HBRequest, context: HBRequestContext) throws -> HBResponse {
func encode(_ value: some Encodable, from request: HBRequest, context: some HBBaseRequestContext) throws -> HBResponse {
return HBResponse(
status: .ok,
headers: ["content-type": "text/plain; charset=utf-8"],
Expand All @@ -44,7 +44,7 @@ struct NullEncoder: HBResponseEncoder {

/// Default decoder. there is no default decoder path so this generates an error
struct NullDecoder: HBRequestDecoder {
func decode<T: Decodable>(_ type: T.Type, from request: HBRequest, context: HBRequestContext) throws -> T {
func decode<T: Decodable>(_ type: T.Type, from request: HBRequest, context: some HBBaseRequestContext) throws -> T {
preconditionFailure("HBApplication.decoder has not been set")
}
}
2 changes: 1 addition & 1 deletion Sources/Hummingbird/Codable/RequestDecodable.swift
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ extension HBRequestDecodable {
/// Create using `Codable` interfaces
/// - Parameter request: request
/// - Throws: HBHTTPError
public init(from request: HBRequest, context: HBRequestContext) throws {
public init(from request: HBRequest, context: some HBBaseRequestContext) throws {
self = try request.decode(as: Self.self, using: context)
}
}
6 changes: 3 additions & 3 deletions Sources/Hummingbird/Codable/ResponseEncodable.swift
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ public protocol HBResponseCodable: HBResponseEncodable, Decodable {}

/// Extend ResponseEncodable to conform to ResponseGenerator
extension HBResponseEncodable {
public func response(from request: HBRequest, context: HBRequestContext) throws -> HBResponse {
public func response(from request: HBRequest, context: some HBBaseRequestContext) throws -> HBResponse {
return try context.applicationContext.encoder.encode(self, from: request, context: context)
}
}
Expand All @@ -33,7 +33,7 @@ extension Array: HBResponseGenerator where Element: Encodable {}

/// Extend Array to conform to HBResponseEncodable
extension Array: HBResponseEncodable where Element: Encodable {
public func response(from request: HBRequest, context: HBRequestContext) throws -> HBResponse {
public func response(from request: HBRequest, context: some HBBaseRequestContext) throws -> HBResponse {
return try context.applicationContext.encoder.encode(self, from: request, context: context)
}
}
Expand All @@ -43,7 +43,7 @@ extension Dictionary: HBResponseGenerator where Key: Encodable, Value: Encodable

/// Extend Array to conform to HBResponseEncodable
extension Dictionary: HBResponseEncodable where Key: Encodable, Value: Encodable {
public func response(from request: HBRequest, context: HBRequestContext) throws -> HBResponse {
public func response(from request: HBRequest, context: some HBBaseRequestContext) throws -> HBResponse {
return try context.applicationContext.encoder.encode(self, from: request, context: context)
}
}
2 changes: 1 addition & 1 deletion Sources/Hummingbird/Middleware/CORSMiddleware.swift
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import NIOCore
/// then return an empty body with all the standard CORS headers otherwise send
/// request onto the next handler and when you receive the response add a
/// "access-control-allow-origin" header
public struct HBCORSMiddleware<Context: HBRequestContext>: HBMiddleware {
public struct HBCORSMiddleware<Context: HBBaseRequestContext>: HBMiddleware {
/// Defines what origins are allowed
public enum AllowOrigin: Sendable {
case none
Expand Down
2 changes: 1 addition & 1 deletion Sources/Hummingbird/Middleware/LogRequestMiddleware.swift
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import Logging

/// Middleware outputting to log for every call to server
public struct HBLogRequestsMiddleware<Context: HBRequestContext>: HBMiddleware {
public struct HBLogRequestsMiddleware<Context: HBBaseRequestContext>: HBMiddleware {
let logLevel: Logger.Level
let includeHeaders: Bool

Expand Down
2 changes: 1 addition & 1 deletion Sources/Hummingbird/Middleware/MetricsMiddleware.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import Metrics
///
/// Records the number of requests, the request duration and how many errors were thrown. Each metric has additional
/// dimensions URI and method.
public struct HBMetricsMiddleware<Context: HBRequestContext>: HBMiddleware {
public struct HBMetricsMiddleware<Context: HBBaseRequestContext>: HBMiddleware {
public init() {}

public func apply(to request: HBRequest, context: Context, next: any HBResponder<Context>) async throws -> HBResponse {
Expand Down
4 changes: 2 additions & 2 deletions Sources/Hummingbird/Middleware/Middleware.swift
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,11 @@ import NIOCore
/// }
/// ```
public protocol HBMiddleware<Context>: Sendable {
associatedtype Context: HBRequestContext
associatedtype Context
func apply(to request: HBRequest, context: Context, next: any HBResponder<Context>) async throws -> HBResponse
}

struct MiddlewareResponder<Context: HBRequestContext>: HBResponder {
struct MiddlewareResponder<Context>: HBResponder {
let middleware: any HBMiddleware<Context>
let next: any HBResponder<Context>

Expand Down
2 changes: 1 addition & 1 deletion Sources/Hummingbird/Middleware/MiddlewareGroup.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
//===----------------------------------------------------------------------===//

/// Group of middleware that can be used to create a responder chain. Each middleware calls the next one
public final class HBMiddlewareGroup<Context: HBRequestContext> {
public final class HBMiddlewareGroup<Context> {
var middlewares: [any HBMiddleware<Context>]

/// Initialize `HBMiddlewareGroup`
Expand Down
4 changes: 2 additions & 2 deletions Sources/Hummingbird/Middleware/TracingMiddleware.swift
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import Tracing
/// You may opt in to recording a specific subset of HTTP request/response header values by passing
/// a set of header names to ``init(recordingHeaders:)``.
@available(macOS 10.15, iOS 13.0, tvOS 13.0, watchOS 6.0, *)
public struct HBTracingMiddleware<Context: HBRequestContext>: HBMiddleware {
public struct HBTracingMiddleware<Context: HBBaseRequestContext>: HBMiddleware {
private let headerNamesToRecord: Set<RecordingHeader>

/// Intialize a new HBTracingMiddleware.
Expand Down Expand Up @@ -122,7 +122,7 @@ public struct HBTracingMiddleware<Context: HBRequestContext>: HBMiddleware {
///
/// If you want the HBTracingMiddleware to record the remote address of requests
/// then your request context will need to conform to this protocol
public protocol HBRemoteAddressRequestContext: HBRequestContext {
public protocol HBRemoteAddressRequestContext: HBBaseRequestContext {
/// Connected host address
var remoteAddress: SocketAddress? { get }
}
Expand Down
2 changes: 1 addition & 1 deletion Sources/Hummingbird/Router/EndpointResponder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import NIOCore
import NIOHTTP1

/// Stores endpoint responders for each HTTP method
struct HBEndpointResponders<Context: HBRequestContext>: Sendable {
struct HBEndpointResponders<Context: HBBaseRequestContext>: Sendable {
init(path: String) {
self.path = path
self.methods = [:]
Expand Down
16 changes: 8 additions & 8 deletions Sources/Hummingbird/Router/ResponseGenerator.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,19 @@ import NIOHTTP1
/// This is used by `Router` to convert handler return values into a `HBResponse`.
public protocol HBResponseGenerator {
/// Generate response based on the request this object came from
func response(from request: HBRequest, context: HBRequestContext) throws -> HBResponse
func response(from request: HBRequest, context: some HBBaseRequestContext) throws -> HBResponse
}

/// Extend Response to conform to ResponseGenerator
extension HBResponse: HBResponseGenerator {
/// Return self as the response
public func response(from request: HBRequest, context: HBRequestContext) -> HBResponse { self }
public func response(from request: HBRequest, context: some HBBaseRequestContext) -> HBResponse { self }
}

/// Extend String to conform to ResponseGenerator
extension String: HBResponseGenerator {
/// Generate response holding string
public func response(from request: HBRequest, context: HBRequestContext) -> HBResponse {
public func response(from request: HBRequest, context: some HBBaseRequestContext) -> HBResponse {
let buffer = context.allocator.buffer(string: self)
return HBResponse(status: .ok, headers: ["content-type": "text/plain; charset=utf-8"], body: .init(byteBuffer: buffer))
}
Expand All @@ -40,7 +40,7 @@ extension String: HBResponseGenerator {
/// Extend String to conform to ResponseGenerator
extension Substring: HBResponseGenerator {
/// Generate response holding string
public func response(from request: HBRequest, context: HBRequestContext) -> HBResponse {
public func response(from request: HBRequest, context: some HBBaseRequestContext) -> HBResponse {
let buffer = context.allocator.buffer(substring: self)
return HBResponse(status: .ok, headers: ["content-type": "text/plain; charset=utf-8"], body: .init(byteBuffer: buffer))
}
Expand All @@ -49,22 +49,22 @@ extension Substring: HBResponseGenerator {
/// Extend ByteBuffer to conform to ResponseGenerator
extension ByteBuffer: HBResponseGenerator {
/// Generate response holding bytebuffer
public func response(from request: HBRequest, context: HBRequestContext) -> HBResponse {
public func response(from request: HBRequest, context: some HBBaseRequestContext) -> HBResponse {
HBResponse(status: .ok, headers: ["content-type": "application/octet-stream"], body: .init(byteBuffer: self))
}
}

/// Extend HTTPResponseStatus to conform to ResponseGenerator
extension HTTPResponseStatus: HBResponseGenerator {
/// Generate response with this response status code
public func response(from request: HBRequest, context: HBRequestContext) -> HBResponse {
public func response(from request: HBRequest, context: some HBBaseRequestContext) -> HBResponse {
HBResponse(status: self, headers: [:], body: .init())
}
}

/// Extend Optional to conform to HBResponseGenerator
extension Optional: HBResponseGenerator where Wrapped: HBResponseGenerator {
public func response(from request: HBRequest, context: HBRequestContext) throws -> HBResponse {
public func response(from request: HBRequest, context: some HBBaseRequestContext) throws -> HBResponse {
switch self {
case .some(let wrapped):
return try wrapped.response(from: request, context: context)
Expand All @@ -89,7 +89,7 @@ public struct HBEditedResponse<Generator: HBResponseGenerator>: HBResponseGenera
self.responseGenerator = response
}

public func response(from request: HBRequest, context: HBRequestContext) throws -> HBResponse {
public func response(from request: HBRequest, context: some HBBaseRequestContext) throws -> HBResponse {
var response = try responseGenerator.response(from: request, context: context)
if let status = self.status {
response.status = status
Expand Down
4 changes: 2 additions & 2 deletions Sources/Hummingbird/Router/RouteHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@
/// ```
public protocol HBRouteHandler {
associatedtype _Output
init(from: HBRequest, context: HBRequestContext) throws
func handle(request: HBRequest, context: HBRequestContext) async throws -> _Output
init(from: HBRequest, context: some HBBaseRequestContext) throws
func handle(request: HBRequest, context: some HBBaseRequestContext) async throws -> _Output
}

extension HBRouterMethods {
Expand Down
2 changes: 1 addition & 1 deletion Sources/Hummingbird/Router/Router.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
/// Conforms to `HBResponder` so need to provide its own implementation of
/// `func apply(to request: Request) -> EventLoopFuture<Response>`.
///
struct HBRouter<Context: HBRequestContext>: HBResponder {
struct HBRouter<Context: HBBaseRequestContext>: HBResponder {
let trie: RouterPathTrie<HBEndpointResponders<Context>>
let notFoundResponder: any HBResponder<Context>

Expand Down
6 changes: 3 additions & 3 deletions Sources/Hummingbird/Router/RouterBuilder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ import NIOHTTP1
/// Both of these match routes which start with "/user" and the next path segment being anything.
/// The second version extracts the path segment out and adds it to `HBRequest.parameters` with the
/// key "id".
public final class HBRouterBuilder<Context: HBRequestContext>: HBRouterMethods {
public final class HBRouterBuilder<Context: HBBaseRequestContext>: HBRouterMethods {
var trie: RouterPathTrieBuilder<HBEndpointResponders<Context>>
public let middlewares: HBMiddlewareGroup<Context>

Expand Down Expand Up @@ -83,7 +83,7 @@ public final class HBRouterBuilder<Context: HBRequestContext>: HBRouterMethods {
self.add(path, method: method, responder: responder)
return self
}

/// return new `RouterGroup`
/// - Parameter path: prefix to add to paths inside the group
public func group(_ path: String = "") -> HBRouterGroup<Context> {
Expand All @@ -92,7 +92,7 @@ public final class HBRouterBuilder<Context: HBRequestContext>: HBRouterMethods {
}

/// Responder that return a not found error
struct NotFoundResponder<Context: HBRequestContext>: HBResponder {
struct NotFoundResponder<Context: HBBaseRequestContext>: HBResponder {
func respond(to request: HBRequest, context: Context) throws -> HBResponse {
throw HBHTTPError(.notFound)
}
Expand Down
2 changes: 1 addition & 1 deletion Sources/Hummingbird/Router/RouterGroup.swift
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import NIOHTTP1
/// .put(":id", use: todoController.update)
/// .delete(":id", use: todoController.delete)
/// ```
public struct HBRouterGroup<Context: HBRequestContext>: HBRouterMethods {
public struct HBRouterGroup<Context: HBBaseRequestContext>: HBRouterMethods {
let path: String
let router: HBRouterBuilder<Context>
let middlewares: HBMiddlewareGroup<Context>
Expand Down
2 changes: 1 addition & 1 deletion Sources/Hummingbird/Router/RouterMethods.swift
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public struct HBRouterMethodOptions: OptionSet, Sendable {

/// Conform to `HBRouterMethods` to add standard router verb (get, post ...) methods
public protocol HBRouterMethods {
associatedtype Context: HBRequestContext
associatedtype Context: HBBaseRequestContext

/// Add path for async closure
@available(macOS 10.15, iOS 13.0, tvOS 13.0, watchOS 6.0, *)
Expand Down
2 changes: 1 addition & 1 deletion Sources/Hummingbird/Server/Request.swift
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ public struct HBRequest: Sendable {

/// Decode request using decoder stored at `HBApplication.decoder`.
/// - Parameter type: Type you want to decode to
public func decode<Type: Decodable>(as type: Type.Type, using context: HBRequestContext) throws -> Type {
public func decode<Type: Decodable>(as type: Type.Type, using context: some HBBaseRequestContext) throws -> Type {
do {
return try context.applicationContext.decoder.decode(type, from: self, context: context)
} catch DecodingError.dataCorrupted(_) {
Expand Down
Loading

0 comments on commit feaa75b

Please sign in to comment.