Skip to content

Commit

Permalink
cmd/cue: add refresh tokens support
Browse files Browse the repository at this point in the history
This will update the logins file whenever the registry
token is refreshed.

Also improves logins file writing to be more atomic
by writing to a temp file and then renaming.

Change-Id: Ia2e89cee1002039c64ab08e4ee53493d867427a4
Signed-off-by: Rustam Abdullaev <rustamabd@gmail.com>
Dispatch-Trailer: {"type":"trybot","CL":1184900,"patchset":2,"ref":"refs/changes/00/1184900/2","targetBranch":"master"}
  • Loading branch information
rustyx authored and cueckoo committed Mar 21, 2024
1 parent 74fb5cf commit 5f3487d
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 8 deletions.
7 changes: 6 additions & 1 deletion internal/cueconfig/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,14 @@ func WriteLogins(path string, logins *Logins) error {
return err
}
// Discourage other users from reading this file.
if err := os.WriteFile(path, body, 0o600); err != nil {
// Write to a temp file and then atomically rename to avoid races with parallel reading/writing.
if err := os.WriteFile(path+".tmp", body, 0o600); err != nil {
return err
}
if err := os.Rename(path+".tmp", path); err != nil {
return err
}

return nil
}

Expand Down
94 changes: 87 additions & 7 deletions mod/modconfig/modconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,13 +189,13 @@ type cueLoginsTransport struct {
// initOnce guards initErr, logins, and transport.
initOnce sync.Once
initErr error
logins *cueconfig.Logins
// transport holds the underlying transport. This wraps
// t.cfg.Transport.
transport http.RoundTripper

// mu guards the fields below.
mu sync.Mutex
mu sync.Mutex
logins *cueconfig.Logins

// cachedTransports holds a transport per host.
// This is needed because the oauth2 API requires a
Expand All @@ -211,35 +211,44 @@ func (t *cueLoginsTransport) RoundTrip(req *http.Request) (*http.Response, error
if err := t.init(); err != nil {
return nil, err
}
// If t.logins is nil then it's not used and there's no need to lock the mutex.
if t.logins == nil {
return t.transport.RoundTrip(req)
}
// TODO: note that a CUE registry may include a path prefix,
// so using solely the host will not work with such a path.
// Can we do better here, perhaps keeping the path prefix up to "/v2/"?
host := req.URL.Host
t.mu.Lock()
login, ok := t.logins.Registries[host]
if !ok {
t.mu.Unlock()
return t.transport.RoundTrip(req)
}

t.mu.Lock()
transport := t.cachedTransports[host]
if transport == nil {
tok := cueconfig.TokenFromLogin(login)
oauthCfg := cueconfig.RegistryOAuthConfig(Host{
Name: host,
Insecure: req.URL.Scheme == "http",
})
// TODO: When this client refreshes an access token,
// we should store the refreshed token on disk.

// Make the oauth client use the transport that was set up
// in init.
ctx := context.WithValue(req.Context(), oauth2.HTTPClient, &http.Client{
Transport: t.transport,
})
transport = oauthCfg.Client(ctx, tok).Transport
refreshTokenFunc := func(old, new *oauth2.Token) error {
return t.updateLogin(host, new)
}
transport = oauth2.NewClient(ctx,
CachingTokenSource(
tok,
oauthCfg.TokenSource(ctx, tok),
refreshTokenFunc,
),
).Transport
t.cachedTransports[host] = transport
}
// Unlock immediately so we don't hold the lock for the entire
Expand All @@ -249,6 +258,33 @@ func (t *cueLoginsTransport) RoundTrip(req *http.Request) (*http.Response, error
return transport.RoundTrip(req)
}

func (t *cueLoginsTransport) updateLogin(host string, new *oauth2.Token) error {
// Reload the logins file, in case it changed in the meanwhile by another process.
loginsPath, err := cueconfig.LoginConfigPath(t.getenv)
if err != nil {
// TODO: this should never fail. Log a warning.
return nil
}
logins, err := cueconfig.ReadLogins(loginsPath)
if err != nil || logins == nil {
// TODO: Log a warning. There should be a logins file since we're in the refresh flow.
return nil
}

logins.Registries[host] = cueconfig.LoginFromToken(new)
t.mu.Lock()
t.logins = logins
t.mu.Unlock()

// TODO: lock the logins file properly so that the update is atomic at FS level.
err = cueconfig.WriteLogins(loginsPath, t.logins)
if err != nil {
return err
}

return nil
}

func (t *cueLoginsTransport) init() error {
t.initOnce.Do(func() {
t.initErr = t._init()
Expand All @@ -270,10 +306,11 @@ func (t *cueLoginsTransport) _init() error {
Transport: t.cfg.Transport,
})

// If we can't locate a logins.json file at all, then we'll
// If we can't locate a logins.json file at all, then we'll continue.
// We only refuse to continue if we find an invalid logins.json file.
loginsPath, err := cueconfig.LoginConfigPath(t.getenv)
if err != nil {
// TODO: this should never fail. Log a warning.
return nil
}
logins, err := cueconfig.ReadLogins(loginsPath)
Expand Down Expand Up @@ -325,3 +362,46 @@ func newRef[T any](x *T) *T {
}
return &x1
}

type cachingTokenSource struct {
refreshHook refreshTokenHookFunc
base oauth2.TokenSource // called when t is expired

mu sync.Mutex // guards t
t *oauth2.Token
}

type refreshTokenHookFunc func(old, base *oauth2.Token) error

// CachingTokenSource works similar to oauth2.ReuseTokenSource, except that it
// also exposes a hook to get a hold of the refreshed token, so that it can be
// stored in persistent storage.
func CachingTokenSource(t *oauth2.Token, src oauth2.TokenSource, refresh refreshTokenHookFunc) oauth2.TokenSource {
return &cachingTokenSource{
refreshHook: refresh,
base: src,
t: t,
}
}

func (s *cachingTokenSource) Token() (*oauth2.Token, error) {
s.mu.Lock()
defer s.mu.Unlock()

if s.t.Valid() {
return s.t, nil
}

t, err := s.base.Token()
if err != nil {
return nil, err
}

err = s.refreshHook(s.t, t)
if err != nil {
return nil, err
}
s.t = t

return t, nil
}

0 comments on commit 5f3487d

Please sign in to comment.