Skip to content

Commit

Permalink
Add typealias AWSMiddlewareNextHandler (#580)
Browse files Browse the repository at this point in the history
  • Loading branch information
adam-fowler authored Aug 10, 2023
1 parent c6a79b8 commit 704e805
Show file tree
Hide file tree
Showing 10 changed files with 18 additions and 16 deletions.
16 changes: 9 additions & 7 deletions Sources/SotoCore/Middleware/Middleware.swift
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,14 @@ public struct AWSMiddlewareContext {
public var logger: Logger
}

/// Function to call next middleware in the chain
public typealias AWSMiddlewareNextHandler = (AWSHTTPRequest, AWSMiddlewareContext) async throws -> AWSHTTPResponse
/// Middleware handler, function that takes a request, context and the next function to call
public typealias AWSMiddlewareHandler = @Sendable (AWSHTTPRequest, AWSMiddlewareContext, _ next: (AWSHTTPRequest, AWSMiddlewareContext) async throws -> AWSHTTPResponse) async throws -> AWSHTTPResponse
public typealias AWSMiddlewareHandler = @Sendable (AWSHTTPRequest, AWSMiddlewareContext, _ next: AWSMiddlewareNextHandler) async throws -> AWSHTTPResponse

/// Middleware protocol, with function that takes a request, context and the next function to call
public protocol AWSMiddlewareProtocol: Sendable {
func handle(_ request: AWSHTTPRequest, context: AWSMiddlewareContext, next: (AWSHTTPRequest, AWSMiddlewareContext) async throws -> AWSHTTPResponse) async throws -> AWSHTTPResponse
func handle(_ request: AWSHTTPRequest, context: AWSMiddlewareContext, next: AWSMiddlewareNextHandler) async throws -> AWSHTTPResponse
}

/// Middleware initialized with a middleware handle
Expand All @@ -39,7 +41,7 @@ public struct AWSMiddleware: AWSMiddlewareProtocol {
}

@inlinable
public func handle(_ request: AWSHTTPRequest, context: AWSMiddlewareContext, next: (AWSHTTPRequest, AWSMiddlewareContext) async throws -> AWSHTTPResponse) async throws -> AWSHTTPResponse {
public func handle(_ request: AWSHTTPRequest, context: AWSMiddlewareContext, next: AWSMiddlewareNextHandler) async throws -> AWSHTTPResponse {
try await self.middleware(request, context, next)
}
}
Expand All @@ -56,7 +58,7 @@ public struct AWSMiddleware2<M0: AWSMiddlewareProtocol, M1: AWSMiddlewareProtoco
}

@inlinable
public func handle(_ request: AWSHTTPRequest, context: AWSMiddlewareContext, next: (AWSHTTPRequest, AWSMiddlewareContext) async throws -> AWSHTTPResponse) async throws -> AWSHTTPResponse {
public func handle(_ request: AWSHTTPRequest, context: AWSMiddlewareContext, next: AWSMiddlewareNextHandler) async throws -> AWSHTTPResponse {
try await self.m0.handle(request, context: context) { request, context in
try await self.m1.handle(request, context: context, next: next)
}
Expand All @@ -80,7 +82,7 @@ public struct AWSDynamicMiddlewareStack: AWSMiddlewareProtocol {
}

@inlinable
public func handle(_ request: AWSHTTPRequest, context: AWSMiddlewareContext, next: (AWSHTTPRequest, AWSMiddlewareContext) async throws -> AWSHTTPResponse) async throws -> AWSHTTPResponse {
public func handle(_ request: AWSHTTPRequest, context: AWSMiddlewareContext, next: AWSMiddlewareNextHandler) async throws -> AWSHTTPResponse {
let iterator = self.stack.makeIterator()
return try await self.run(request, context: context, iterator: iterator, finally: next)
}
Expand All @@ -90,7 +92,7 @@ public struct AWSDynamicMiddlewareStack: AWSMiddlewareProtocol {
_ request: AWSHTTPRequest,
context: AWSMiddlewareContext,
iterator: Stack.Iterator,
finally: (AWSHTTPRequest, AWSMiddlewareContext) async throws -> AWSHTTPResponse
finally: AWSMiddlewareNextHandler
) async throws -> AWSHTTPResponse {
var iterator = iterator
switch iterator.next() {
Expand All @@ -107,7 +109,7 @@ public struct AWSDynamicMiddlewareStack: AWSMiddlewareProtocol {
public struct PassThruMiddleware: AWSMiddlewareProtocol {
public init() {}

public func handle(_ request: AWSHTTPRequest, context: AWSMiddlewareContext, next: (AWSHTTPRequest, AWSMiddlewareContext) async throws -> AWSHTTPResponse) async throws -> AWSHTTPResponse {
public func handle(_ request: AWSHTTPRequest, context: AWSMiddlewareContext, next: AWSMiddlewareNextHandler) async throws -> AWSHTTPResponse {
try await next(request, context)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public struct AWSEditHeadersMiddleware: AWSMiddlewareProtocol {
}

@inlinable
public func handle(_ request: AWSHTTPRequest, context: AWSMiddlewareContext, next: (AWSHTTPRequest, AWSMiddlewareContext) async throws -> AWSHTTPResponse) async throws -> AWSHTTPResponse {
public func handle(_ request: AWSHTTPRequest, context: AWSMiddlewareContext, next: AWSMiddlewareNextHandler) async throws -> AWSHTTPResponse {
var request = request
for edit in self.edits {
switch edit {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public struct EndpointDiscoveryMiddleware: AWSMiddlewareProtocol {
self.isRequired = required
}

public func handle(_ request: AWSHTTPRequest, context: AWSMiddlewareContext, next: (AWSHTTPRequest, AWSMiddlewareContext) async throws -> AWSHTTPResponse) async throws -> AWSHTTPResponse {
public func handle(_ request: AWSHTTPRequest, context: AWSMiddlewareContext, next: AWSMiddlewareNextHandler) async throws -> AWSHTTPResponse {
let isEnabled = context.serviceConfig.options.contains(.enableEndpointDiscovery)
guard isEnabled || self.isRequired else {
return try await next(request, context)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import SotoSignerV4
struct ErrorHandlingMiddleware: AWSMiddlewareProtocol {
let options: AWSClient.Options

func handle(_ request: AWSHTTPRequest, context: AWSMiddlewareContext, next: (AWSHTTPRequest, AWSMiddlewareContext) async throws -> AWSHTTPResponse) async throws -> AWSHTTPResponse {
func handle(_ request: AWSHTTPRequest, context: AWSMiddlewareContext, next: AWSMiddlewareNextHandler) async throws -> AWSHTTPResponse {
let response = try await next(request, context)

// if response has an HTTP status code outside 2xx then throw an error
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ public struct AWSLoggingMiddleware: AWSMiddlewareProtocol {
}

@inlinable
public func handle(_ request: AWSHTTPRequest, context: AWSMiddlewareContext, next: (AWSHTTPRequest, AWSMiddlewareContext) async throws -> AWSHTTPResponse) async throws -> AWSHTTPResponse {
public func handle(_ request: AWSHTTPRequest, context: AWSMiddlewareContext, next: AWSMiddlewareNextHandler) async throws -> AWSHTTPResponse {
self.log(
"Request:\n" +
" \(context.operation)\n" +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ struct RetryMiddleware: AWSMiddlewareProtocol {
let retryPolicy: RetryPolicy

@inlinable
func handle(_ request: AWSHTTPRequest, context: AWSMiddlewareContext, next: (AWSHTTPRequest, AWSMiddlewareContext) async throws -> AWSHTTPResponse) async throws -> AWSHTTPResponse {
func handle(_ request: AWSHTTPRequest, context: AWSMiddlewareContext, next: AWSMiddlewareNextHandler) async throws -> AWSHTTPResponse {
var attempt = 0
while true {
do {
Expand Down
2 changes: 1 addition & 1 deletion Sources/SotoCore/Middleware/Middleware/S3Middleware.swift
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import Foundation
/// - Fixes up the GetBucketLocation response, so it can be decoded correctly
/// - Creates error body for notFound responses to HEAD requests
public struct S3Middleware: AWSMiddlewareProtocol {
public func handle(_ request: AWSHTTPRequest, context: AWSMiddlewareContext, next: (AWSHTTPRequest, AWSMiddlewareContext) async throws -> AWSHTTPResponse) async throws -> AWSHTTPResponse {
public func handle(_ request: AWSHTTPRequest, context: AWSMiddlewareContext, next: AWSMiddlewareNextHandler) async throws -> AWSHTTPResponse {
var request = request

self.virtualAddressFixup(request: &request, context: context)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ struct SigningMiddleware: AWSMiddlewareProtocol {
let credentialProvider: any CredentialProvider

@inlinable
func handle(_ request: AWSHTTPRequest, context: AWSMiddlewareContext, next: (AWSHTTPRequest, AWSMiddlewareContext) async throws -> AWSHTTPResponse) async throws -> AWSHTTPResponse {
func handle(_ request: AWSHTTPRequest, context: AWSMiddlewareContext, next: AWSMiddlewareNextHandler) async throws -> AWSHTTPResponse {
var request = request
// get credentials
let credential = try await self.credentialProvider.getCredential(logger: context.logger)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public struct AWSTracingMiddleware: AWSMiddlewareProtocol {
public func handle(
_ request: AWSHTTPRequest,
context: AWSMiddlewareContext,
next: (AWSHTTPRequest, AWSMiddlewareContext) async throws -> AWSHTTPResponse
next: AWSMiddlewareNextHandler
) async throws -> AWSHTTPResponse {
try await InstrumentationSystem.tracer.withSpan(
"\(context.serviceConfig.serviceName).\(context.operation)",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ let MEGA_BYTE = 1024 * 1024
/// Calculates a tree hash calculated from the SHA256 of each 1MB section of the request body
/// and adds it to the request as a header value
public struct TreeHashMiddleware: AWSMiddlewareProtocol {
public func handle(_ request: AWSHTTPRequest, context: AWSMiddlewareContext, next: (AWSHTTPRequest, AWSMiddlewareContext) async throws -> AWSHTTPResponse) async throws -> AWSHTTPResponse {
public func handle(_ request: AWSHTTPRequest, context: AWSMiddlewareContext, next: AWSMiddlewareNextHandler) async throws -> AWSHTTPResponse {
var request = request
if request.headers[self.treeHashHeader].first == nil {
if case .byteBuffer(let buffer) = request.body.storage {
Expand Down

0 comments on commit 704e805

Please sign in to comment.