Skip to content

Commit

Permalink
Add request context usage for oauth token renewals
Browse files Browse the repository at this point in the history
- Fix issue 325
  • Loading branch information
sneal committed Oct 27, 2023
1 parent 87ceb0c commit 9042551
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 30 deletions.
6 changes: 3 additions & 3 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,8 @@ func New(config *config.Config) (*Client, error) {
}

// AccessToken returns the raw encoded OAuth access token without the bearer prefix
func (c *Client) AccessToken(ignoredCtx context.Context) (string, error) {
token, err := c.authenticatedClientProvider.AccessToken()
func (c *Client) AccessToken(ctx context.Context) (string, error) {
token, err := c.authenticatedClientProvider.AccessToken(ctx)
if err != nil {
return "", err
}
Expand All @@ -164,7 +164,7 @@ func (c *Client) SSHCode(ctx context.Context) (string, error) {
values.Set("response_type", "code")
values.Set("client_id", r.Links.AppSSH.Meta.OauthClient) // client_id,used by cf server

token, err := c.authenticatedClientProvider.AccessToken()
token, err := c.authenticatedClientProvider.AccessToken(ctx)
if err != nil {
return "", err
}
Expand Down
13 changes: 8 additions & 5 deletions internal/http/client_provider.go
Original file line number Diff line number Diff line change
@@ -1,28 +1,31 @@
package http

import "net/http"
import (
"context"
"net/http"
)

type ClientProvider interface {
// Client returns a *http.Client
Client(followRedirects bool) (*http.Client, error)
Client(ctx context.Context, followRedirects bool) (*http.Client, error)

// ReAuthenticate tells the provider to re-initialize the auth context
ReAuthenticate() error
ReAuthenticate(ctx context.Context) error
}

type UnauthenticatedClientProvider struct {
httpClient *http.Client
httpClientNonRedirecting *http.Client
}

func (c *UnauthenticatedClientProvider) Client(followRedirects bool) (*http.Client, error) {
func (c *UnauthenticatedClientProvider) Client(ctx context.Context, followRedirects bool) (*http.Client, error) {
if followRedirects {
return c.httpClient, nil
}
return c.httpClientNonRedirecting, nil
}

func (c *UnauthenticatedClientProvider) ReAuthenticate() error {
func (c *UnauthenticatedClientProvider) ReAuthenticate(ctx context.Context) error {
return nil
}

Expand Down
9 changes: 5 additions & 4 deletions internal/http/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package http

import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
Expand Down Expand Up @@ -61,7 +62,7 @@ func (c *Executor) ExecuteRequest(request *Request) (*http.Response, error) {
// refresh token is expired or revoked. Attempt to get a new refresh and access token and retry the request.
var authErr *unauthorizedError
if errors.As(err, &authErr) {
err = c.reAuthenticate()
err = c.reAuthenticate(req.Context())
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -111,7 +112,7 @@ func (c *Executor) newHTTPRequest(request *Request) (*http.Request, error) {

// do will get the proper http.Client and calls Do on it using the specified http.Request
func (c *Executor) do(request *http.Request, followRedirects bool) (*http.Response, error) {
client, err := c.clientProvider.Client(followRedirects)
client, err := c.clientProvider.Client(request.Context(), followRedirects)
if err != nil {
return nil, fmt.Errorf("error executing request, failed to get the underlying HTTP client: %w", err)
}
Expand Down Expand Up @@ -148,8 +149,8 @@ func (c *Executor) do(request *http.Request, followRedirects bool) (*http.Respon
}

// reAuthenticate tells the client provider to restart authentication anew because we received a 401
func (c *Executor) reAuthenticate() error {
err := c.clientProvider.ReAuthenticate()
func (c *Executor) reAuthenticate(ctx context.Context) error {
err := c.clientProvider.ReAuthenticate(ctx)
if err != nil {
return fmt.Errorf("an error occurred attempting to reauthenticate "+
"after initially receiving a 401 executing a request: %w", err)
Expand Down
12 changes: 6 additions & 6 deletions internal/http/oauth_session_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ func NewOAuthSessionManager(config *config.Config) *OAuthSessionManager {
}

// Client returns an authenticated OAuth http client
func (m *OAuthSessionManager) Client(followRedirects bool) (*http.Client, error) {
err := m.init(context.Background())
func (m *OAuthSessionManager) Client(ctx context.Context, followRedirects bool) (*http.Client, error) {
err := m.init(ctx)
if err != nil {
return nil, err
}
Expand All @@ -55,7 +55,7 @@ func (m *OAuthSessionManager) Client(followRedirects bool) (*http.Client, error)
// likely in response to a 401
//
// This won't work for userTokenAuth since we have no credentials to exchange for a new token.
func (m *OAuthSessionManager) ReAuthenticate() error {
func (m *OAuthSessionManager) ReAuthenticate(ctx context.Context) error {
m.mutex.Lock()
defer m.mutex.Unlock()

Expand All @@ -64,12 +64,12 @@ func (m *OAuthSessionManager) ReAuthenticate() error {
}

// attempt to create a new token source
return m.newTokenSource(context.Background())
return m.newTokenSource(ctx)
}

// AccessToken returns the raw OAuth access token
func (m *OAuthSessionManager) AccessToken() (string, error) {
err := m.init(context.Background())
func (m *OAuthSessionManager) AccessToken(ctx context.Context) (string, error) {
err := m.init(ctx)
if err != nil {
return "", err
}
Expand Down
25 changes: 13 additions & 12 deletions internal/http/oauth_session_manager_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package http_test

import (
"context"
"github.com/cloudfoundry-community/go-cfclient/v3/config"
"github.com/cloudfoundry-community/go-cfclient/v3/internal/http"
"github.com/cloudfoundry-community/go-cfclient/v3/testutil"
Expand All @@ -22,7 +23,7 @@ func TestOAuthSessionManager(t *testing.T) {
require.Empty(t, c.UAAEndpointURL)
m := http.NewOAuthSessionManager(c)

_, err = m.Client(true)
_, err = m.Client(context.Background(), true)
require.Error(t, err, "expected an error when UAA or Login endpoint is empty")
require.Equal(t, "login and UAA endpoints must not be empty", err.Error())

Expand All @@ -31,35 +32,35 @@ func TestOAuthSessionManager(t *testing.T) {
c.UAAEndpointURL = uaaURL

// we can create a client that utilizes oauth
client1, err := m.Client(true)
client1, err := m.Client(context.Background(), true)
require.NoError(t, err)
require.NotNil(t, client1)

// the same access token is returned as long as it's not expired (which it's not - 300s)
token, err := m.AccessToken()
token, err := m.AccessToken(context.Background())
require.NoError(t, err)
require.Equal(t, "foobar1", token)
require.NoError(t, err)
token, err = m.AccessToken()
token, err = m.AccessToken(context.Background())
require.NoError(t, err)
require.Equal(t, "foobar1", token)

// the same client is returned
client2, err := m.Client(true)
client2, err := m.Client(context.Background(), true)
require.NoError(t, err)
require.Same(t, client1, client2)

// we force new auth context
err = m.ReAuthenticate()
err = m.ReAuthenticate(context.Background())
require.NoError(t, err)

// a different client is now returned
client3, err := m.Client(true)
client3, err := m.Client(context.Background(), true)
require.NoError(t, err)
require.NotSame(t, client2, client3)

// a new token is also returned
token, err = m.AccessToken()
token, err = m.AccessToken(context.Background())
require.NoError(t, err)
require.Equal(t, "foobar2", token)

Expand All @@ -84,22 +85,22 @@ func TestOAuthSessionManagerRefreshToken(t *testing.T) {
m := http.NewOAuthSessionManager(c)

// we can create a client that utilizes oauth
client1, err := m.Client(true)
client1, err := m.Client(context.Background(), true)
require.NoError(t, err)
require.NotNil(t, client1)

// get the access token, it should have been auto-refreshed because the one we gave in the config was expired
token, err := m.AccessToken()
token, err := m.AccessToken(context.Background())
require.NoError(t, err)
require.NotEqual(t, accessToken, token)
require.Equal(t, "foobar1", token)

// get the access token, should be the same as before
token, err = m.AccessToken()
token, err = m.AccessToken(context.Background())
require.NoError(t, err)
require.Equal(t, "foobar1", token)

// we cannot re-auth with only a refresh token (no credentials)
err = m.ReAuthenticate()
err = m.ReAuthenticate(context.Background())
require.EqualError(t, err, "cannot reauthenticate user token auth type, check your access and/or refresh token expiration date")
}

0 comments on commit 9042551

Please sign in to comment.