Skip to content

Commit

Permalink
rp: allow to set custom URL parameters (#273)
Browse files Browse the repository at this point in the history
* rp: allow to set prompts in AuthURLHandler

Fixes #241

* rp: configuration for handlers with URL options to call RS

Fixes #265
  • Loading branch information
muhlemmer authored Feb 13, 2023
1 parent ff2729c commit c8d61c0
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 20 deletions.
10 changes: 6 additions & 4 deletions example/client/app/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,12 @@ func main() {
return uuid.New().String()
}

// register the AuthURLHandler at your preferred path
// the AuthURLHandler creates the auth request and redirects the user to the auth server
// including state handling with secure cookie and the possibility to use PKCE
http.Handle("/login", rp.AuthURLHandler(state, provider))
// register the AuthURLHandler at your preferred path.
// the AuthURLHandler creates the auth request and redirects the user to the auth server.
// including state handling with secure cookie and the possibility to use PKCE.
// Prompts can optionally be set to inform the server of
// any messages that need to be prompted back to the user.
http.Handle("/login", rp.AuthURLHandler(state, provider, rp.WithPromptURLParam("Welcome back!")))

// for demonstration purposes the returned userinfo response is written as JSON object onto response
marshalUserinfo := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp rp.RelyingParty, info oidc.UserInfo) {
Expand Down
17 changes: 15 additions & 2 deletions pkg/client/rp/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,10 @@ func TestRelyingPartySession(t *testing.T) {
state := "state-" + strconv.FormatInt(seed.Int63(), 25)
capturedW := httptest.NewRecorder()
get := httptest.NewRequest("GET", localURL.String(), nil)
rp.AuthURLHandler(func() string { return state }, provider)(capturedW, get)
rp.AuthURLHandler(func() string { return state }, provider,
rp.WithPromptURLParam("Hello, World!", "Goodbye, World!"),
rp.WithURLParam("custom", "param"),
)(capturedW, get)

defer func() {
if t.Failed() {
Expand All @@ -84,6 +87,8 @@ func TestRelyingPartySession(t *testing.T) {
}()
require.GreaterOrEqual(t, capturedW.Code, 200, "captured response code")
require.Less(t, capturedW.Code, 400, "captured response code")
require.Contains(t, capturedW.Body.String(), `prompt=Hello%2C+World%21+Goodbye%2C+World%21`)
require.Contains(t, capturedW.Body.String(), `custom=param`)

//nolint:bodyclose
resp := capturedW.Result()
Expand Down Expand Up @@ -140,7 +145,7 @@ func TestRelyingPartySession(t *testing.T) {
email = info.GetEmail()
http.Redirect(w, r, targetURL, 302)
}
rp.CodeExchangeHandler(rp.UserinfoCallback(redirect), provider)(capturedW, get)
rp.CodeExchangeHandler(rp.UserinfoCallback(redirect), provider, rp.WithURLParam("custom", "param"))(capturedW, get)

defer func() {
if t.Failed() {
Expand All @@ -150,6 +155,7 @@ func TestRelyingPartySession(t *testing.T) {
}()
require.Less(t, capturedW.Code, 400, "token exchange response code")
require.Less(t, capturedW.Code, 400, "token exchange response code")
// TODO: how to check the custom header was sent to the server?

//nolint:bodyclose
resp = capturedW.Result()
Expand Down Expand Up @@ -193,6 +199,13 @@ func TestRelyingPartySession(t *testing.T) {
_, err = rp.RefreshAccessToken(provider, newTokens.RefreshToken, "", "")
assert.Errorf(t, err, "refresh with replacement")
}

t.Run("WithPrompt", func(t *testing.T) {
opts := rp.WithPrompt("foo", "bar")()
url := provider.OAuthConfig().AuthCodeURL("some", opts...)

require.Contains(t, url, "prompt=foo+bar")
})
}

type deferredHandler struct {
Expand Down
66 changes: 52 additions & 14 deletions pkg/client/rp/relying_party.go
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ func WithVerifierOpts(opts ...VerifierOption) Option {

// WithClientKey specifies the path to the key.json to be used for the JWT Profile Client Authentication on the token endpoint
//
//deprecated: use WithJWTProfile(SignerFromKeyPath(path)) instead
// deprecated: use WithJWTProfile(SignerFromKeyPath(path)) instead
func WithClientKey(path string) Option {
return WithJWTProfile(SignerFromKeyPath(path))
}
Expand Down Expand Up @@ -304,7 +304,7 @@ func SignerFromKeyAndKeyID(key []byte, keyID string) SignerFromKey {

// Discover calls the discovery endpoint of the provided issuer and returns the found endpoints
//
//deprecated: use client.Discover
// deprecated: use client.Discover
func Discover(issuer string, httpClient *http.Client) (Endpoints, error) {
wellKnown := strings.TrimSuffix(issuer, "/") + oidc.DiscoveryEndpoint
req, err := http.NewRequest("GET", wellKnown, nil)
Expand All @@ -323,7 +323,7 @@ func Discover(issuer string, httpClient *http.Client) (Endpoints, error) {
}

// AuthURL returns the auth request url
//(wrapping the oauth2 `AuthCodeURL`)
// (wrapping the oauth2 `AuthCodeURL`)
func AuthURL(state string, rp RelyingParty, opts ...AuthURLOpt) string {
authOpts := make([]oauth2.AuthCodeOption, 0)
for _, opt := range opts {
Expand All @@ -333,10 +333,15 @@ func AuthURL(state string, rp RelyingParty, opts ...AuthURLOpt) string {
}

// AuthURLHandler extends the `AuthURL` method with a http redirect handler
// including handling setting cookie for secure `state` transfer
func AuthURLHandler(stateFn func() string, rp RelyingParty) http.HandlerFunc {
// including handling setting cookie for secure `state` transfer.
// Custom paramaters can optionally be set to the redirect URL.
func AuthURLHandler(stateFn func() string, rp RelyingParty, urlParam ...URLParamOpt) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
opts := make([]AuthURLOpt, 0)
opts := make([]AuthURLOpt, len(urlParam))
for i, p := range urlParam {
opts[i] = AuthURLOpt(p)
}

state := stateFn()
if err := trySetStateCookie(w, state, rp); err != nil {
http.Error(w, "failed to create state cookie: "+err.Error(), http.StatusUnauthorized)
Expand All @@ -350,6 +355,7 @@ func AuthURLHandler(stateFn func() string, rp RelyingParty) http.HandlerFunc {
}
opts = append(opts, WithCodeChallenge(codeChallenge))
}

http.Redirect(w, r, AuthURL(state, rp, opts...), http.StatusFound)
}
}
Expand Down Expand Up @@ -398,8 +404,9 @@ type CodeExchangeCallback func(w http.ResponseWriter, r *http.Request, tokens *o

// CodeExchangeHandler extends the `CodeExchange` method with a http handler
// including cookie handling for secure `state` transfer
// and optional PKCE code verifier checking
func CodeExchangeHandler(callback CodeExchangeCallback, rp RelyingParty) http.HandlerFunc {
// and optional PKCE code verifier checking.
// Custom paramaters can optionally be set to the token URL.
func CodeExchangeHandler(callback CodeExchangeCallback, rp RelyingParty, urlParam ...URLParamOpt) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
state, err := tryReadStateCookie(w, r, rp)
if err != nil {
Expand All @@ -411,7 +418,11 @@ func CodeExchangeHandler(callback CodeExchangeCallback, rp RelyingParty) http.Ha
rp.ErrorHandler()(w, r, params.Get("error"), params.Get("error_description"), state)
return
}
codeOpts := make([]CodeExchangeOpt, 0)
codeOpts := make([]CodeExchangeOpt, len(urlParam))
for i, p := range urlParam {
codeOpts[i] = CodeExchangeOpt(p)
}

if rp.IsPKCE() {
codeVerifier, err := rp.CookieHandler().CheckCookie(r, pkceCode)
if err != nil {
Expand Down Expand Up @@ -517,6 +528,37 @@ func GetEndpoints(discoveryConfig *oidc.DiscoveryConfiguration) Endpoints {
}
}

// withURLParam sets custom url paramaters.
// This is the generalized, unexported, function used by both
// URLParamOpt and AuthURLOpt.
func withURLParam(key, value string) func() []oauth2.AuthCodeOption {
return func() []oauth2.AuthCodeOption {
return []oauth2.AuthCodeOption{
oauth2.SetAuthURLParam(key, value),
}
}
}

// withPrompt sets the `prompt` params in the auth request
// This is the generalized, unexported, function used by both
// URLParamOpt and AuthURLOpt.
func withPrompt(prompt ...string) func() []oauth2.AuthCodeOption {
return withURLParam("prompt", oidc.SpaceDelimitedArray(prompt).Encode())
}

type URLParamOpt func() []oauth2.AuthCodeOption

// WithURLParam allows setting custom key-vale pairs
// to an OAuth2 URL.
func WithURLParam(key, value string) URLParamOpt {
return withURLParam(key, value)
}

// WithPromptURLParam sets the `prompt` parameter in a URL.
func WithPromptURLParam(prompt ...string) URLParamOpt {
return withPrompt(prompt...)
}

type AuthURLOpt func() []oauth2.AuthCodeOption

// WithCodeChallenge sets the `code_challenge` params in the auth request
Expand All @@ -531,11 +573,7 @@ func WithCodeChallenge(codeChallenge string) AuthURLOpt {

// WithPrompt sets the `prompt` params in the auth request
func WithPrompt(prompt ...string) AuthURLOpt {
return func() []oauth2.AuthCodeOption {
return []oauth2.AuthCodeOption{
oauth2.SetAuthURLParam("prompt", oidc.SpaceDelimitedArray(prompt).Encode()),
}
}
return withPrompt(prompt...)
}

type CodeExchangeOpt func() []oauth2.AuthCodeOption
Expand Down

0 comments on commit c8d61c0

Please sign in to comment.