Skip to content

Commit

Permalink
Resolves #111
Browse files Browse the repository at this point in the history
  • Loading branch information
punmechanic committed Jul 11, 2024
1 parent d81dbce commit c6e7338
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 18 deletions.
57 changes: 39 additions & 18 deletions cli/oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,11 @@ var (
GrantTypeTokenExchange GrantType = "urn:ietf:params:oauth:grant-type:token-exchange"
)

var ErrNoSAMLAssertion = errors.New("no saml assertion")
var (
ErrNoSAMLAssertion = errors.New("no saml assertion")
// ErrUnauthorized indicates that the Okta server rejected the request.
ErrUnauthorized = errors.New("unauthorized")
)

// stateBufSize is the size of the buffer used to generate the state parameter.
// 43 is a magic number - It generates states that are not too short or long for Okta's validation.
Expand Down Expand Up @@ -304,6 +308,38 @@ func (r TokenExchange) NewRequest(ctx context.Context, config *oauth2.Config) (*
return req, nil
}

func (r TokenExchange) ProcessResponse(resp *http.Response) (*oauth2.Token, error) {
if resp.StatusCode != http.StatusOK {
// We've not been able to replicate this, but it does happen that sometimes Okta returns a non-200
// response code and a body that does not contain an OAuth2 token. This causes KeyConjurer to submit
// a blank oauth2 token to the next endpoint in the chain, resulting in a cryptic
// "unable to parse SAML assertion" error.
//
// So, we will just assume that in any instance this returns a non-200 code that the user is unauthorized
return nil, ErrUnauthorized
}
var tok oauth2.Token
return &tok, json.NewDecoder(resp.Body).Decode(&tok)
}

func (r TokenExchange) Execute(ctx context.Context, client *http.Client, cfg *oauth2.Config) (*oauth2.Token, error) {
if client == nil {
client = http.DefaultClient
}

req, err := r.NewRequest(ctx, cfg)
if err != nil {
return nil, err
}

resp, err := client.Do(req)
if err != nil {
return nil, err
}

return r.ProcessResponse(resp)
}

func makeOktaApplicationURN(applicationID string) string {
return fmt.Sprintf("urn:okta:apps:%s", applicationID)
}
Expand All @@ -319,23 +355,8 @@ func ExchangeAccessTokenForWebSSOToken(ctx context.Context, client *http.Client,
RequestedTokenType: TokenTypeOktaWebSSOToken,
Audience: makeOktaApplicationURN(applicationID),
}
req, err := tex.NewRequest(ctx, oauthCfg)
if err != nil {
return nil, err
}

if client == nil {
client = http.DefaultClient
}

// TODO: The response can indicate a failure, we should check that for this function
resp, err := client.Do(req)
if err != nil {
return nil, err
}

var tok oauth2.Token
return &tok, json.NewDecoder(resp.Body).Decode(&tok)
return tex.Execute(ctx, client, oauthCfg)
}

// TODO: This is actually an Okta-specific API
Expand Down Expand Up @@ -379,7 +400,7 @@ func DiscoverConfigAndExchangeTokenForAssertion(ctx context.Context, client *htt

tok, err := ExchangeAccessTokenForWebSSOToken(ctx, client, oauthCfg, toks, applicationID)
if err != nil {
return nil, "", OktaError{Message: "error exchanging token", InnerError: err}
return nil, "", OktaError{Message: "error exchanging token - try logging in again by deleting ~/.keyconjurerrc", InnerError: err}
}

assertionBytes, err := ExchangeWebSSOTokenForSAMLAssertion(ctx, client, oidcDomain, tok)
Expand Down
20 changes: 20 additions & 0 deletions cli/oauth2_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
package main

import (
"bytes"
"context"
"encoding/json"
"io"
"net"
"net/http"
"net/http/httptest"
Expand Down Expand Up @@ -147,3 +150,20 @@ func Test_ListenAnyPort_RejectsIfAllProvidedPortsExhausted(t *testing.T) {
_, err := ListenAnyPort("127.0.0.1", activePorts)(context.Background())
assert.ErrorIs(t, err, ErrNoPortsAvailable)
}

func Test_TokenExchange_ProcessProcess_Non200ErrorCodeReturnsError(t *testing.T) {
var req TokenExchange
bodyProps := map[string]string{
"error": "unauthorized",
"error_description": "...",
}

blob, _ := json.Marshal(bodyProps)
tokens, err := req.ProcessResponse(&http.Response{
StatusCode: http.StatusBadRequest,
Body: io.NopCloser(bytes.NewReader(blob)),
})

assert.Nil(t, tokens)
assert.ErrorIs(t, ErrUnauthorized, err)
}

0 comments on commit c6e7338

Please sign in to comment.