Skip to content
This repository has been archived by the owner on Oct 29, 2024. It is now read-only.

Commit

Permalink
Merge pull request #15 from vamsii777/fixes/scopes
Browse files Browse the repository at this point in the history
Refactor Scope Handling to String Type in Compliance with OAuth RFC Standards
  • Loading branch information
vamsii777 authored Feb 1, 2024
2 parents 3bc4d51 + cbe572d commit e84c8ea
Show file tree
Hide file tree
Showing 42 changed files with 491 additions and 349 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ public struct EmptyCodeManager: CodeManager {
userID: String,
clientID: String,
redirectURI: String,
scopes: [String]?,
scopes: String?,
codeChallenge: String?,
codeChallengeMethod: String?,
nonce: String?
Expand All @@ -24,7 +24,7 @@ public struct EmptyCodeManager: CodeManager {
return nil
}

public func generateDeviceCode(userID: String, clientID: String, scopes: [String]?) async throws -> String {
public func generateDeviceCode(userID: String, clientID: String, scopes: String?) async throws -> String {
return ""
}

Expand Down
21 changes: 11 additions & 10 deletions Sources/VaporOAuth/Helper/OAuthHelper+remote.swift
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@ import Vapor

actor RemoteTokenResponseActor {
var remoteTokenResponse: RemoteTokenResponse?

func setRemoteTokenResponse(_ response: RemoteTokenResponse) {
self.remoteTokenResponse = response
}

func hasTokenResponse() async -> Bool {
return remoteTokenResponse != nil
}

func getRemoteTokenResponse() async throws -> RemoteTokenResponse {
guard let response = remoteTokenResponse else {
throw Abort(.internalServerError)
Expand Down Expand Up @@ -43,8 +43,8 @@ extension OAuthHelper {

let remoteTokenResponse = try await responseActor.getRemoteTokenResponse()

if let requiredScopes = scopes {
guard let tokenScopes = remoteTokenResponse.scopes else {
if let requiredScopes = scopes?.components(separatedBy: " ") {
guard let tokenScopes = remoteTokenResponse.scopes?.components(separatedBy: " ") else {
throw Abort(.unauthorized)
}

Expand All @@ -67,7 +67,7 @@ extension OAuthHelper {
responseActor: responseActor
)
}

let remoteTokenResponse = try await responseActor.getRemoteTokenResponse()

guard let user = remoteTokenResponse.user else {
Expand All @@ -78,7 +78,7 @@ extension OAuthHelper {
}
)
}

private static func setupRemoteTokenResponse(
request: Request,
tokenIntrospectionEndpoint: String,
Expand Down Expand Up @@ -110,11 +110,11 @@ extension OAuthHelper {
throw Abort(.unauthorized)
}

var scopes: [String]?
var scopes: String?
var oauthUser: OAuthUser?

if let tokenScopes: String = tokenInfoJSON[OAuthResponseParameters.scope] {
scopes = tokenScopes.components(separatedBy: " ")
scopes = tokenScopes
}

if let userID: String = tokenInfoJSON[OAuthResponseParameters.userID] {
Expand All @@ -133,6 +133,7 @@ extension OAuthHelper {
}

struct RemoteTokenResponse {
let scopes: [String]?
let scopes: String?
let user: OAuthUser?
}

6 changes: 3 additions & 3 deletions Sources/VaporOAuth/Helper/OAuthHelper.swift
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import Vapor

public struct OAuthHelper: Sendable {
public var assertScopes: @Sendable ([String]?, Request) async throws -> Void
public var assertScopes: @Sendable (String?, Request) async throws -> Void
public var user: @Sendable (Request) async throws -> OAuthUser

public init(
assertScopes: @escaping @Sendable ([String]?, Request) async throws -> Void,
user: @escaping @Sendable (Request) async throws -> OAuthUser
assertScopes: @escaping @Sendable (String?, Request) async throws -> Void,
user: @escaping @Sendable (Request) async throws -> OAuthUser
) {
self.assertScopes = assertScopes
self.user = user
Expand Down
5 changes: 3 additions & 2 deletions Sources/VaporOAuth/Middleware/OAuth2ScopeMiddleware.swift
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import Vapor

public struct OAuth2ScopeMiddleware: AsyncMiddleware {
let requiredScopes: [String]?
let requiredScopes: String?

public init(requiredScopes: [String]?) {
public init(requiredScopes: String?) {
self.requiredScopes = requiredScopes
}

public func respond(to request: Request, chainingTo next: AsyncResponder) async throws -> Response {
// Pass the scopes as a string directly
try await request.oAuthHelper.assertScopes(requiredScopes, request)

return try await next.respond(to: request)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import Vapor

public struct OAuth2TokenIntrospectionMiddleware: AsyncMiddleware {
let requiredScopes: [String]?
let requiredScopes: String?

public init(requiredScopes: [String]?) {
public init(requiredScopes: String?) {
self.requiredScopes = requiredScopes
}

Expand Down
4 changes: 2 additions & 2 deletions Sources/VaporOAuth/Models/OAuthClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ public struct OAuthClient {
public let clientID: String
public let redirectURIs: [String]?
public let clientSecret: String?
public let validScopes: [String]?
public let validScopes: String?
public let confidentialClient: Bool?
public let firstParty: Bool
public let allowedGrantType: OAuthFlowType
Expand All @@ -15,7 +15,7 @@ public struct OAuthClient {

public var extend: Vapor.Extend = .init()

public init(clientID: String, redirectURIs: [String]?, clientSecret: String? = nil, validScopes: [String]? = nil,
public init(clientID: String, redirectURIs: [String]?, clientSecret: String? = nil, validScopes: String? = nil,
confidential: Bool? = nil, firstParty: Bool = false, allowedGrantType: OAuthFlowType,
postLogoutRedirectURIs: [String]? = nil, idTokenSignedResponseAlg: String? = "RS256") {
self.clientID = clientID
Expand Down
4 changes: 2 additions & 2 deletions Sources/VaporOAuth/Models/OAuthCode.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ public struct OAuthCode {
public let redirectURI: String
public let userID: String
public let expiryDate: Date
public let scopes: [String]?
public let scopes: String?

// PKCE parameters
public let codeChallenge: String?
Expand All @@ -23,7 +23,7 @@ public struct OAuthCode {
redirectURI: String,
userID: String,
expiryDate: Date,
scopes: [String]?,
scopes: String?,
codeChallenge: String?, // Add PKCE parameters
codeChallengeMethod: String?,
nonce: String? = nil
Expand Down
4 changes: 2 additions & 2 deletions Sources/VaporOAuth/Models/OAuthDeviceCode.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ public struct OAuthDeviceCode {
public let clientID: String
public let userID: String?
public let expiryDate: Date
public let scopes: [String]?
public let scopes: String? // Updated to String?

public var extend: [String: Any] = [:]

Expand All @@ -16,7 +16,7 @@ public struct OAuthDeviceCode {
clientID: String,
userID: String?,
expiryDate: Date,
scopes: [String]?
scopes: String? // Updated to String?
) {
self.deviceCodeID = deviceCodeID
self.userCode = userCode
Expand Down
2 changes: 1 addition & 1 deletion Sources/VaporOAuth/Models/Tokens/AccessToken.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ public protocol AccessToken: JWTPayload {
var jti: String { get }
var clientID: String { get }
var userID: String? { get }
var scopes: [String]? { get }
var scopes: String? { get }
var expiryTime: Date { get }
}

Expand Down
2 changes: 1 addition & 1 deletion Sources/VaporOAuth/Models/Tokens/RefreshToken.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ public protocol RefreshToken: JWTPayload {
var jti: String { get set }
var clientID: String { get set }
var userID: String? { get set }
var scopes: [String]? { get set }
var scopes: String? { get set }
var exp: Date { get }
}

Expand Down
4 changes: 2 additions & 2 deletions Sources/VaporOAuth/Protocols/AuthorizeHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ public struct AuthorizationRequestObject {
public let responseType: String
public let clientID: String
public let redirectURI: URI
public let scope: [String]
public let scope: String
public let state: String?
public let csrfToken: String
// PKCE parameters
Expand All @@ -50,7 +50,7 @@ public struct AuthorizationRequestObject {
// OpenID Connect specific parameters
public let nonce: String?

public init(responseType: String, clientID: String, redirectURI: URI, scope: [String], state: String?, csrfToken: String, codeChallenge: String?, codeChallengeMethod: String?, nonce: String?) {
public init(responseType: String, clientID: String, redirectURI: URI, scope: String, state: String?, csrfToken: String, codeChallenge: String?, codeChallengeMethod: String?, nonce: String?) {
self.responseType = responseType
self.clientID = clientID
self.redirectURI = redirectURI
Expand Down
4 changes: 2 additions & 2 deletions Sources/VaporOAuth/Protocols/CodeManager.swift
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
/// Responsible for generating and managing OAuth Codes
public protocol CodeManager: Sendable {
// Updated to include PKCE parameters
func generateCode(userID: String, clientID: String, redirectURI: String, scopes: [String]?, codeChallenge: String?, codeChallengeMethod: String?, nonce: String?) async throws -> String
func generateCode(userID: String, clientID: String, redirectURI: String, scopes: String?, codeChallenge: String?, codeChallengeMethod: String?, nonce: String?) async throws -> String
func getCode(_ code: String) async throws -> OAuthCode?

// This is explicit to ensure that the code is marked as used or deleted (it could be implied that this is done when you call
// `getCode` but it is called explicitly to remind developers to ensure that codes can't be reused)
func codeUsed(_ code: OAuthCode) async throws
func generateDeviceCode(userID: String, clientID: String, scopes: [String]?) async throws -> String
func generateDeviceCode(userID: String, clientID: String, scopes: String?) async throws -> String
func getDeviceCode(_ deviceCode: String) async throws -> OAuthDeviceCode?
func deviceCodeUsed(_ deviceCode: OAuthDeviceCode) async throws
}
10 changes: 5 additions & 5 deletions Sources/VaporOAuth/Protocols/TokenManager.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ public protocol TokenManager: Sendable {
func generateTokens(
clientID: String,
userID: String?,
scopes: [String]?,
scopes: String?,
accessTokenExpiryTime: Int,
idTokenExpiryTime: Int,
nonce: String?
Expand All @@ -15,15 +15,15 @@ public protocol TokenManager: Sendable {
func generateAccessToken(
clientID: String,
userID: String?,
scopes: [String]?,
scopes: String?,
expiryTime: Int
) async throws -> AccessToken

// Generates both access and refresh tokens. Should be called after successful PKCE validation.
func generateAccessRefreshTokens(
clientID: String,
userID: String?,
scopes: [String]?,
scopes: String?,
accessTokenExpiryTime: Int
) async throws -> (AccessToken, RefreshToken)

Expand All @@ -34,13 +34,13 @@ public protocol TokenManager: Sendable {
func getAccessToken(_ accessToken: String) async throws -> AccessToken?

// Updates a refresh token, typically to change its scope.
func updateRefreshToken(_ refreshToken: RefreshToken, scopes: [String]) async throws
func updateRefreshToken(_ refreshToken: RefreshToken, scopes: String) async throws

// Generates an ID token. Should be called after successful authentication.
func generateIDToken(
clientID: String,
userID: String,
scopes: [String]?,
scopes: String?,
expiryTime: Int,
nonce: String?
) async throws -> IDToken
Expand Down
12 changes: 4 additions & 8 deletions Sources/VaporOAuth/RouteHandlers/AuthorizeGetHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,9 @@ struct AuthorizeGetHandler {
return (try await authorizeHandler.handleAuthorizationError(.invalidRedirectURI), nil)
}

let scopes: [String]

if let scopeQuery: String = request.query[OAuthRequestParameters.scope] {
scopes = scopeQuery.components(separatedBy: " ")
} else {
scopes = []
}
// Extract scopes as a single string
let scopes: String = request.query[OAuthRequestParameters.scope] ?? ""


let state: String? = request.query[OAuthRequestParameters.state]

Expand Down Expand Up @@ -188,7 +184,7 @@ struct AuthorizeGetHandler {
struct AuthorizationGetRequestObject {
let clientID: String
let redirectURIString: String
let scopes: [String]
let scopes: String
let state: String?
let responseType: String
let codeChallenge: String?
Expand Down
17 changes: 6 additions & 11 deletions Sources/VaporOAuth/RouteHandlers/AuthorizePostHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ struct AuthorizePostRequest {
let clientID: String
let responseType: String
let csrfToken: String
let scopes: [String]?
let scopes: String?
let codeChallenge: String?
let codeChallengeMethod: String?
let nonce: String? // OpenID Connect specific
Expand Down Expand Up @@ -38,6 +38,7 @@ struct AuthorizePostHandler {
}

if requestObject.approveApplication {

if requestObject.responseType == ResponseType.token {
let accessToken = try await tokenManager.generateAccessToken(
clientID: requestObject.clientID,
Expand Down Expand Up @@ -115,7 +116,7 @@ struct AuthorizePostHandler {

if let requestedScopes = requestObject.scopes {
if !requestedScopes.isEmpty {
redirectURI += "&scope=\(requestedScopes.joined(separator: "+"))"
redirectURI += "&scope=\(requestedScopes)"
}
}

Expand Down Expand Up @@ -153,13 +154,8 @@ struct AuthorizePostHandler {
throw Abort(.badRequest)
}

let scopes: [String]?

if let scopeQuery: String = request.query[OAuthRequestParameters.scope] {
scopes = scopeQuery.components(separatedBy: " ")
} else {
scopes = nil
}
// Extract scopes as a single string
let scopesString: String? = request.query[OAuthRequestParameters.scope]

// Extract PKCE parameters
let codeChallenge: String? = request.content[OAuthRequestParameters.codeChallenge]
Expand All @@ -168,7 +164,6 @@ struct AuthorizePostHandler {
// Extract nonce for OpenID Connect from the request content
let nonce: String? = request.content[OAuthRequestParameters.nonce]


return AuthorizePostRequest(
user: user,
userID: userID,
Expand All @@ -177,7 +172,7 @@ struct AuthorizePostHandler {
clientID: clientID,
responseType: responseType,
csrfToken: csrfToken,
scopes: scopes,
scopes: scopesString, // Pass the scope string directly
codeChallenge: codeChallenge,
codeChallengeMethod: codeChallengeMethod,
nonce: nonce
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ struct AuthCodeTokenHandler {
try await codeManager.codeUsed(code)

let scopes = code.scopes
// Check for 'openid' scope to determine if it's an OpenID Connect request
let isOpenIDConnectFlow = scopes?.contains("openid") ?? false
let expiryTime = 3600

Expand All @@ -60,7 +59,7 @@ struct AuthCodeTokenHandler {
nonce: code.nonce
)

return try await tokenResponseGenerator.createOpenIDConnectResponse(accessToken: access, refreshToken: refresh, idToken: idToken, expires: Int(expiryTime), scope: scopes?.joined(separator: " "))
return try await tokenResponseGenerator.createOpenIDConnectResponse(accessToken: access, refreshToken: refresh, idToken: idToken, expires: Int(expiryTime), scope: scopes)

} else {
let (access, refresh) = try await tokenManager.generateAccessRefreshTokens(
Expand All @@ -70,7 +69,7 @@ struct AuthCodeTokenHandler {
accessTokenExpiryTime: expiryTime
)

return try await tokenResponseGenerator.createResponse(accessToken: access, refreshToken: refresh, expires: Int(expiryTime), scope: scopes?.joined(separator: " "))
return try await tokenResponseGenerator.createResponse(accessToken: access, refreshToken: refresh, expires: Int(expiryTime), scope: scopes)
}
}
}
Loading

0 comments on commit e84c8ea

Please sign in to comment.