diff --git a/oauthproxy.go b/oauthproxy.go index 0c65ed9868..1411c24cf8 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -196,6 +196,7 @@ func buildSessionChain(opts *options.Options, provider providers.Provider, sessi }) chain = chain.Append(loadSession) provider.Data().StoredSession = ss + provider.Data().StoredSession.NeedsVerifier = provider.Data().NeedsVerifier return chain } @@ -385,8 +386,12 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { util.SendError("Invalid authentication via OAuth2: unauthorized", rw, http.StatusForbidden) } } - if _, err := (*p.provider.Data().Verifier.GetKeySet()).VerifySignature(req.Context(), session.IDToken); err != nil { - (*p.provider.Data().Verifier.GetKeySet()).UpdateKeys(p.client, p.provider.Data().VerifierTimeout, updateKeysCallback) + if p.provider.Data().NeedsVerifier { + if _, err := (*p.provider.Data().Verifier.GetKeySet()).VerifySignature(req.Context(), session.IDToken); err != nil { + (*p.provider.Data().Verifier.GetKeySet()).UpdateKeys(p.client, p.provider.Data().VerifierTimeout, updateKeysCallback) + } else { + updateKeysCallback() + } } else { updateKeysCallback() } @@ -407,13 +412,12 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) { case err == nil: rw.WriteHeader(http.StatusOK) if p.passAuthorization { - proxywasm.AddHttpRequestHeader("Authorization", fmt.Sprintf("%s %s", providers.TokenTypeBearer, session.IDToken)) + proxywasm.AddHttpRequestHeader("Authorization", fmt.Sprintf("%s %s", providers.TokenTypeBearer, session.AccessToken)) } if cookies, ok := rw.Header()[SetCookieHeader]; ok && len(cookies) > 0 { newCookieValue := strings.Join(cookies, ",") if p.ctx != nil { p.ctx.SetContext(SetCookieHeader, newCookieValue) - modifyRequestCookie(req, p.CookieOptions.Name, newCookieValue) util.Logger.Info("Authentication and session refresh successfully .") } else { util.Logger.Error("Set Cookie failed cause HttpContext is nil.") @@ -493,7 +497,7 @@ func (p *OAuthProxy) IsAllowedRequest(req *http.Request) bool { } func (p *OAuthProxy) ValidateVerifier() error { - if p.provider.Data().Verifier == nil { + if p.provider.Data().Verifier == nil && p.provider.Data().NeedsVerifier { return errors.New("Failed to obtain OpenID configuration, current OIDC plugin is not working properly.") } return nil @@ -504,7 +508,7 @@ func (p *OAuthProxy) SetContext(ctx wrapper.HttpContext) { } func (p *OAuthProxy) SetVerifier(opts *options.Options) { - if p.provider.Data().Verifier == nil { + if p.provider.Data().Verifier == nil && p.provider.Data().NeedsVerifier { providers.NewVerifierFromConfig(opts.Providers[0], p.provider.Data(), p.client) } } @@ -631,21 +635,3 @@ func redirectToLocation(rw http.ResponseWriter, location string) { } proxywasm.SendHttpResponse(http.StatusFound, headersMap, nil, -1) } - -func modifyRequestCookie(req *http.Request, cookieName, newValue string) { - var cookies []string - found := false - for _, cookie := range req.Cookies() { - // find specify cookie name - if cookie.Name == cookieName { - found = true - cookies = append(cookies, fmt.Sprintf("%s=%s", cookie.Name, newValue)) - } else { - cookies = append(cookies, fmt.Sprintf("%s=%s", cookie.Name, cookie.Value)) - } - } - if !found { - cookies = append(cookies, fmt.Sprintf("%s=%s", cookieName, newValue)) - } - proxywasm.ReplaceHttpRequestHeader("Cookie", strings.Join(cookies, "; ")) -} diff --git a/pkg/apis/options/legacy_options.go b/pkg/apis/options/legacy_options.go index f593053ab0..05fb4fd437 100644 --- a/pkg/apis/options/legacy_options.go +++ b/pkg/apis/options/legacy_options.go @@ -80,7 +80,7 @@ func legacyProviderDefaults() LegacyProvider { ValidateURL: "", Scope: "", Prompt: "", - ApprovalPrompt: "force", + ApprovalPrompt: "", UserIDClaim: OIDCEmailClaim, AllowedGroups: nil, AcrValues: "", @@ -133,10 +133,6 @@ func (l *LegacyProvider) convert() (Providers, error) { urlParams = append(urlParams, LoginURLParameter{Name: "prompt", Default: []string{l.Prompt}}) case l.ApprovalPrompt != "": urlParams = append(urlParams, LoginURLParameter{Name: "approval_prompt", Default: []string{l.ApprovalPrompt}}) - default: - // match legacy behaviour by default - if neither prompt nor approval_prompt - // specified, use approval_prompt=force - urlParams = append(urlParams, LoginURLParameter{Name: "approval_prompt", Default: []string{"force"}}) } provider.LoginURLParameters = urlParams diff --git a/pkg/apis/options/providers.go b/pkg/apis/options/providers.go index ae326188b1..2e763256e1 100644 --- a/pkg/apis/options/providers.go +++ b/pkg/apis/options/providers.go @@ -68,6 +68,8 @@ type ProviderType string const ( // OIDCProvider is the provider type for OIDC OIDCProvider ProviderType = "oidc" + + AliyunProvider ProviderType = "aliyun" ) type OIDCOptions struct { diff --git a/pkg/middleware/stored_session.go b/pkg/middleware/stored_session.go index f4c6452609..b176c1758c 100644 --- a/pkg/middleware/stored_session.go +++ b/pkg/middleware/stored_session.go @@ -69,6 +69,7 @@ type StoredSessionLoader struct { refreshClient wrapper.HttpClient refreshRequestTimeout uint32 RemoteKeySet *oidc.KeySet + NeedsVerifier bool } // loadSession attempts to load a session as identified by the request cookies. @@ -100,7 +101,7 @@ func (s *StoredSessionLoader) loadSession(next http.Handler) http.Handler { } } } - keysNeedsUpdate := (session != nil) + keysNeedsUpdate := (session != nil) && (s.NeedsVerifier) if keysNeedsUpdate { if _, err := (*s.RemoteKeySet).VerifySignature(req.Context(), session.IDToken); err == nil { keysNeedsUpdate = false diff --git a/providers/aliyun.go b/providers/aliyun.go new file mode 100644 index 0000000000..ac64dbc2b6 --- /dev/null +++ b/providers/aliyun.go @@ -0,0 +1,134 @@ +package providers + +import ( + "context" + "fmt" + "net/http" + "net/url" + + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" + "github.com/higress-group/oauth2-proxy/pkg/apis/sessions" + "github.com/higress-group/oauth2-proxy/pkg/util" +) + +type AliyunProvider struct { + *ProviderData +} + +const ( + aliyunProviderName = "Aliyun" + aliyunDefaultScope = "openid" +) + +var ( + aliyunDefaultLoginURL = &url.URL{ + Scheme: "https", + Host: "signin.aliyun.com", + Path: "/oauth2/v1/auth", + RawQuery: "access_type=offline", + } + + aliyunDefaultRedeemURL = &url.URL{ + Scheme: "https", + Host: "oauth.aliyun.com", + Path: "/v1/token", + } +) + +func NewAliyunProvider(p *ProviderData) *AliyunProvider { + p.setProviderDefaults(providerDefaults{ + name: aliyunProviderName, + loginURL: aliyunDefaultLoginURL, + redeemURL: aliyunDefaultRedeemURL, + profileURL: nil, + validateURL: nil, + scope: aliyunDefaultScope, + }) + + provider := &AliyunProvider{ProviderData: p} + + return provider +} + +var _ Provider = (*AliyunProvider)(nil) + +func (p *AliyunProvider) Redeem(ctx context.Context, redirectURL, code, codeVerifier string, client wrapper.HttpClient, callback func(args ...interface{}), timeout uint32) error { + clientSecret, err := p.GetClientSecret() + if err != nil { + return err + } + params := url.Values{} + params.Add("redirect_uri", redirectURL) + params.Add("client_id", p.ClientID) + params.Add("client_secret", clientSecret) + params.Add("code", code) + params.Add("grant_type", "authorization_code") + + headers := [][2]string{{"Content-Type", "application/x-www-form-urlencoded"}} + + client.Post(p.RedeemURL.String(), headers, []byte(params.Encode()), func(statusCode int, responseHeaders http.Header, responseBody []byte) { + token, err := util.UnmarshalToken(responseHeaders, responseBody) + if err != nil { + util.SendError(err.Error(), nil, http.StatusInternalServerError) + return + } + id_token, ok := token.Extra("id_token").(string) + if !ok { + util.SendError("id_token not found", nil, http.StatusInternalServerError) + return + } + session := &sessions.SessionState{ + IDToken: id_token, + AccessToken: token.AccessToken, + RefreshToken: token.RefreshToken, + } + session.CreatedAtNow() + session.SetExpiresOn(token.Expiry) + + callback(session) + }, timeout) + + return nil +} + +func (p *AliyunProvider) RefreshSession(ctx context.Context, s *sessions.SessionState, client wrapper.HttpClient, callback func(args ...interface{}), timeout uint32) (bool, error) { + if s == nil || s.RefreshToken == "" { + return false, fmt.Errorf("refresh token is empty") + } + + err := p.redeemRefreshToken(ctx, s, client, callback, timeout) + if err != nil { + return false, fmt.Errorf("unable to redeem refresh token: %v", err) + } + + return true, nil +} + +func (p *AliyunProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState, client wrapper.HttpClient, callback func(args ...interface{}), timeout uint32) error { + clientSecret, err := p.GetClientSecret() + if err != nil { + return err + } + params := url.Values{} + params.Add("client_id", p.ClientID) + params.Add("client_secret", clientSecret) + params.Add("refresh_token", s.RefreshToken) + params.Add("grant_type", "refresh_token") + + headers := [][2]string{{"Content-Type", "application/x-www-form-urlencoded"}} + + client.Post(p.RedeemURL.String(), headers, []byte(params.Encode()), func(statusCode int, responseHeaders http.Header, responseBody []byte) { + token, err := util.UnmarshalToken(responseHeaders, responseBody) + if err != nil { + util.SendError(err.Error(), nil, http.StatusInternalServerError) + return + } + s.AccessToken = token.AccessToken + s.CreatedAtNow() + s.SetExpiresOn(token.Expiry) + + callback(s, true) + }, timeout) + + return nil +} diff --git a/providers/provider_data.go b/providers/provider_data.go index a4d043767e..3519bbffc4 100644 --- a/providers/provider_data.go +++ b/providers/provider_data.go @@ -44,6 +44,7 @@ type ProviderData struct { EmailClaim string GroupsClaim string Verifier internaloidc.IDTokenVerifier + NeedsVerifier bool SkipClaimsFromProfileURL bool // Universal Group authorization data structure diff --git a/providers/providers.go b/providers/providers.go index 817cf955b7..e413cb9b60 100644 --- a/providers/providers.go +++ b/providers/providers.go @@ -43,18 +43,15 @@ func NewProvider(providerConfig options.Provider) (Provider, error) { switch providerConfig.Type { case options.OIDCProvider: return NewOIDCProvider(providerData, providerConfig.OIDCConfig), nil + case options.AliyunProvider: + return NewAliyunProvider(providerData), nil default: return nil, fmt.Errorf("unknown provider type %q", providerConfig.Type) } } func NewVerifierFromConfig(providerConfig options.Provider, p *ProviderData, client wrapper.HttpClient) error { - - needsVerifier, err := providerRequiresOIDCProviderVerifier(providerConfig.Type) - if err != nil { - return err - } - if needsVerifier { + if p.NeedsVerifier { verifierOptions := internaloidc.ProviderVerifierOptions{ AudienceClaims: providerConfig.OIDCConfig.AudienceClaims, ClientID: providerConfig.ClientID, @@ -104,6 +101,12 @@ func newProviderDataFromConfig(providerConfig options.Provider) (*ProviderData, VerifierTimeout: providerConfig.OIDCConfig.VerifierRequestTimeout, } + needsVerifier, err := providerRequiresOIDCProviderVerifier(providerConfig.Type) + if err != nil { + return nil, err + } + p.NeedsVerifier = needsVerifier + errs := providerConfigInfoCheck(providerConfig, p) // handle LoginURLParameters errs = append(errs, p.compileLoginParams(providerConfig.LoginURLParameters)...) @@ -155,6 +158,8 @@ func providerRequiresOIDCProviderVerifier(providerType options.ProviderType) (bo switch providerType { case options.OIDCProvider: return true, nil + case options.AliyunProvider: + return false, nil default: return false, fmt.Errorf("unknown provider type: %s", providerType) }