Skip to content
This repository has been archived by the owner on Oct 29, 2024. It is now read-only.

Commit

Permalink
Merge pull request #10 from vamsii777/patch
Browse files Browse the repository at this point in the history
Comprehensive Refactoring of OAuth Components for Concurrency and PKCE Support
  • Loading branch information
vamsii777 authored Jan 10, 2024
2 parents ede89f6 + 711e9e6 commit 2759152
Show file tree
Hide file tree
Showing 16 changed files with 156 additions and 89 deletions.
16 changes: 8 additions & 8 deletions Sources/VaporOAuth/DefaultImplementations/EmptyCodeManager.swift
Original file line number Diff line number Diff line change
@@ -1,31 +1,31 @@
public struct EmptyCodeManager: CodeManager {
public init() {}

public func getCode(_ code: String) -> OAuthCode? {
return nil
}

// Updated to include PKCE parameters

public func generateCode(
userID: String,
clientID: String,
redirectURI: String,
scopes: [String]?,
codeChallenge: String?,
codeChallengeMethod: String?
codeChallengeMethod: String?,
nonce: String?
) async throws -> String {
return ""
}

public func codeUsed(_ code: OAuthCode) {}

public func getDeviceCode(_ deviceCode: String) -> OAuthDeviceCode? {
return nil
}

public func generateDeviceCode(userID: String, clientID: String, scopes: [String]?) async throws -> String {
return ""
}

public func deviceCodeUsed(_ deviceCode: OAuthDeviceCode) {}
}
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
import Vapor

public actor StaticClientRetriever: ClientRetriever {
public struct StaticClientRetriever: ClientRetriever {
private let clients: [String: OAuthClient]

public init(clients: [OAuthClient]) {
self.clients = clients.reduce(into: [String: OAuthClient]()) { (dict, client) in
dict[client.clientID] = client
}
}

public func getClient(clientID: String) async throws -> OAuthClient? {
public func getClient(clientID: String) throws -> OAuthClient? {
return clients[clientID]
}
}
14 changes: 7 additions & 7 deletions Sources/VaporOAuth/Helper/OAuthHelper+remote.swift
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,16 @@ import Vapor

actor RemoteTokenResponseActor {
var remoteTokenResponse: RemoteTokenResponse?

func setRemoteTokenResponse(_ response: RemoteTokenResponse) {
self.remoteTokenResponse = response
}

func hasTokenResponse() async -> Bool {
func hasTokenResponse() -> Bool {
return remoteTokenResponse != nil
}

func getRemoteTokenResponse() async throws -> RemoteTokenResponse {
func getRemoteTokenResponse() throws -> RemoteTokenResponse {
guard let response = remoteTokenResponse else {
throw Abort(.internalServerError)
}
Expand Down Expand Up @@ -67,7 +67,7 @@ extension OAuthHelper {
responseActor: responseActor
)
}

let remoteTokenResponse = try await responseActor.getRemoteTokenResponse()

guard let user = remoteTokenResponse.user else {
Expand All @@ -78,7 +78,7 @@ extension OAuthHelper {
}
)
}

private static func setupRemoteTokenResponse(
request: Request,
tokenIntrospectionEndpoint: String,
Expand Down
2 changes: 1 addition & 1 deletion Sources/VaporOAuth/Models/OAuthClient.swift
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import Vapor

public final class OAuthClient: Extendable {
public final class OAuthClient: Extendable, Sendable {
public let clientID: String
public let redirectURIs: [String]?
public let clientSecret: String?
Expand Down
7 changes: 0 additions & 7 deletions Sources/VaporOAuth/Models/OAuthDeviceCode.swift
Original file line number Diff line number Diff line change
@@ -1,10 +1,3 @@
//
// OAuthDeviceCode.swift
//
//
// Created by Vamsi Madduluri on 24/08/23.
//

import Foundation

public final class OAuthDeviceCode {
Expand Down
43 changes: 38 additions & 5 deletions Sources/VaporOAuth/Protocols/CodeManager.swift
Original file line number Diff line number Diff line change
@@ -1,13 +1,46 @@
/// Responsible for generating and managing OAuth Codes
public protocol CodeManager: Sendable {
// Updated to include PKCE parameters
func generateCode(userID: String, clientID: String, redirectURI: String, scopes: [String]?, codeChallenge: String?, codeChallengeMethod: String?) async throws -> String
/// Generates an OAuth code for the specified user, client, redirect URI, scopes, code challenge, and code challenge method.
/// - Parameters:
/// - userID: The ID of the user.
/// - clientID: The ID of the client.
/// - redirectURI: The redirect URI.
/// - scopes: The requested scopes.
/// - codeChallenge: The code challenge.
/// - codeChallengeMethod: The code challenge method.
/// - nonce: The nonce.
/// - Returns: The generated OAuth code.
/// - Throws: An error if the code generation fails.
func generateCode(userID: String, clientID: String, redirectURI: String, scopes: [String]?, codeChallenge: String?, codeChallengeMethod: String?, nonce: String?) async throws -> String

/// Retrieves the OAuth code associated with the specified code.
/// - Parameter code: The OAuth code.
/// - Returns: The associated OAuth code, or `nil` if not found.
/// - Throws: An error if the retrieval fails.
func getCode(_ code: String) async throws -> OAuthCode?

// This is explicit to ensure that the code is marked as used or deleted (it could be implied that this is done when you call
// `getCode` but it is called explicitly to remind developers to ensure that codes can't be reused)

/// Marks the specified OAuth code as used or deleted.
/// - Parameter code: The OAuth code to mark as used or deleted.
/// - Throws: An error if the operation fails.
func codeUsed(_ code: OAuthCode) async throws

/// Generates a device code for the specified user, client, and scopes.
/// - Parameters:
/// - userID: The ID of the user.
/// - clientID: The ID of the client.
/// - scopes: The requested scopes.
/// - Returns: The generated device code.
/// - Throws: An error if the code generation fails.
func generateDeviceCode(userID: String, clientID: String, scopes: [String]?) async throws -> String

/// Retrieves the device code associated with the specified device code.
/// - Parameter deviceCode: The device code.
/// - Returns: The associated device code, or `nil` if not found.
/// - Throws: An error if the retrieval fails.
func getDeviceCode(_ deviceCode: String) async throws -> OAuthDeviceCode?

/// Marks the specified device code as used or deleted.
/// - Parameter deviceCode: The device code to mark as used or deleted.
/// - Throws: An error if the operation fails.
func deviceCodeUsed(_ deviceCode: OAuthDeviceCode) async throws
}
52 changes: 43 additions & 9 deletions Sources/VaporOAuth/Protocols/TokenManager.swift
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
import Vapor

/// A protocol that defines the behavior of a token manager.
public protocol TokenManager: Sendable {
// Generates access, refresh, and ID tokens. Should be called after successful authentication.

/// Generates access, refresh, and ID tokens. Should be called after successful authentication.
/// - Parameters:
/// - clientID: The client ID.
/// - userID: The user ID.
/// - scopes: The scopes.
/// - accessTokenExpiryTime: The expiry time for the access token.
/// - idTokenExpiryTime: The expiry time for the ID token.
/// - nonce: The nonce.
/// - Returns: A tuple containing the generated access token, refresh token, and ID token.
func generateTokens(
clientID: String,
userID: String?,
Expand All @@ -11,32 +19,58 @@ public protocol TokenManager: Sendable {
nonce: String?
) async throws -> (AccessToken, RefreshToken, IDToken)

// Generates only an access token. Should be called after successful authentication.
/// Generates only an access token. Should be called after successful authentication.
/// - Parameters:
/// - clientID: The client ID.
/// - userID: The user ID.
/// - scopes: The scopes.
/// - expiryTime: The expiry time for the access token.
/// - Returns: The generated access token.
func generateAccessToken(
clientID: String,
userID: String?,
scopes: [String]?,
expiryTime: Int
) async throws -> AccessToken

// Generates both access and refresh tokens. Should be called after successful PKCE validation.
/// Generates both access and refresh tokens. Should be called after successful PKCE validation.
/// - Parameters:
/// - clientID: The client ID.
/// - userID: The user ID.
/// - scopes: The scopes.
/// - accessTokenExpiryTime: The expiry time for the access token.
/// - Returns: A tuple containing the generated access token and refresh token.
func generateAccessRefreshTokens(
clientID: String,
userID: String?,
scopes: [String]?,
accessTokenExpiryTime: Int
) async throws -> (AccessToken, RefreshToken)

// Retrieves a refresh token by its string representation.
/// Retrieves a refresh token by its string representation.
/// - Parameter refreshToken: The string representation of the refresh token.
/// - Returns: The refresh token, if found. Otherwise, `nil`.
func getRefreshToken(_ refreshToken: String) async throws -> RefreshToken?

// Retrieves an access token by its string representation.
/// Retrieves an access token by its string representation.
/// - Parameter accessToken: The string representation of the access token.
/// - Returns: The access token, if found. Otherwise, `nil`.
func getAccessToken(_ accessToken: String) async throws -> AccessToken?

// Updates a refresh token, typically to change its scope.
/// Updates a refresh token, typically to change its scope.
/// - Parameters:
/// - refreshToken: The refresh token to update.
/// - scopes: The new scopes for the refresh token.
func updateRefreshToken(_ refreshToken: RefreshToken, scopes: [String]) async throws

// Generates an ID token. Should be called after successful authentication.
/// Generates an ID token. Should be called after successful authentication.
/// - Parameters:
/// - clientID: The client ID.
/// - userID: The user ID.
/// - scopes: The scopes.
/// - expiryTime: The expiry time for the ID token.
/// - nonce: The nonce.
/// - Returns: The generated ID token.
func generateIDToken(
clientID: String,
userID: String,
Expand Down
3 changes: 2 additions & 1 deletion Sources/VaporOAuth/RouteHandlers/AuthorizePostHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ struct AuthorizePostHandler {
redirectURI: requestObject.redirectURIBaseString,
scopes: requestObject.scopes,
codeChallenge: requestObject.codeChallenge,
codeChallengeMethod: requestObject.codeChallengeMethod
codeChallengeMethod: requestObject.codeChallengeMethod,
nonce: requestObject.nonce
)
redirectURI += "?code=\(generatedCode)"
} else if requestObject.responseType == ResponseType.idToken || requestObject.responseType == ResponseType.tokenAndIdToken {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,3 @@
//
// DeviceCodeTokenHandler.swift
//
//
// Created by Vamsi Madduluri on 24/08/23.
//

import Vapor

struct DeviceCodeTokenHandler {
Expand Down
2 changes: 1 addition & 1 deletion Sources/VaporOAuth/Utilities/OAuthFlowType.swift
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
public enum OAuthFlowType: String {
public enum OAuthFlowType: String, Sendable {
case authorization = "authorization_code"
case implicit = "implicit"
case password = "password"
Expand Down
24 changes: 6 additions & 18 deletions Sources/VaporOAuth/Validators/CodeValidator.swift
Original file line number Diff line number Diff line change
@@ -1,36 +1,24 @@
import Foundation
import Crypto // Import SwiftCrypto for SHA-256
import Crypto

struct CodeValidator {
func validateCode(_ code: OAuthCode, clientID: String, redirectURI: String, codeVerifier: String?) -> Bool {
guard code.clientID == clientID else {
return false
}

guard code.expiryDate >= Date() else {
return false
}

guard code.redirectURI == redirectURI else {
return false
}

// Optional PKCE validation

if let codeChallenge = code.codeChallenge, let codeChallengeMethod = code.codeChallengeMethod, let verifier = codeVerifier {
switch codeChallengeMethod {
case "S256":
// Transform the codeVerifier using SHA256 and base64-url-encode it
guard let verifierData = verifier.data(using: .utf8) else { return false }
let verifierHash = SHA256.hash(data: verifierData)
let encodedVerifier = Data(verifierHash).base64URLEncodedString()

return codeChallenge == encodedVerifier
default:
// If the code challenge method is unknown, fail the validation
return false
}
return PKCEValidator.validate(codeChallenge: codeChallenge, verifier: verifier, method: code.codeChallengeMethod)
}

// If no PKCE was used (codeVerifier is nil), skip PKCE validation
return true
}
Expand Down
34 changes: 34 additions & 0 deletions Sources/VaporOAuth/Validators/PKCEValidator.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import Foundation
import Crypto

struct PKCEValidator {

static func validate(codeChallenge: String, verifier: String?, method: String?) -> Bool {
guard let verifier = verifier else {
// Fail validation if codeVerifier is not provided
return false
}

guard let method = method else {
// Default to plain if no method is provided
return codeChallenge == verifier
}

switch method {
case "S256":
return validateS256(codeChallenge: codeChallenge, verifier: verifier)
case "plain":
return codeChallenge == verifier
default:
// Unsupported code challenge method
return false
}
}

private static func validateS256(codeChallenge: String, verifier: String) -> Bool {
guard let verifierData = verifier.data(using: .utf8) else { return false }
let hashedVerifier = SHA256.hash(data: verifierData)
let base64UrlEncodedHash = Data(hashedVerifier).base64URLEncodedString()
return codeChallenge == base64UrlEncodedHash
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ class DefaultImplementationTests: XCTestCase {
redirectURI: "https://api.brokenhands.io/callback",
scopes: nil,
codeChallenge: "dummyChallenge",
codeChallengeMethod: "S256"
codeChallengeMethod: "S256",
nonce: "nonce"
)

// Perform the assertion
Expand Down
Loading

0 comments on commit 2759152

Please sign in to comment.