Skip to content

Commit

Permalink
cmd/cue: add refresh tokens support
Browse files Browse the repository at this point in the history
This will update the logins file whenever the registry
token is refreshed.

Also improves logins file writing to be more atomic
by writing to a temp file and then renaming.

Change-Id: Ia2e89cee1002039c64ab08e4ee53493d867427a4
Signed-off-by: Rustam Abdullaev <rustamabd@gmail.com>
Dispatch-Trailer: {"type":"trybot","CL":1184900,"patchset":4,"ref":"refs/changes/00/1184900/4","targetBranch":"master"}
  • Loading branch information
rustyx authored and cueckoo committed Mar 22, 2024
1 parent e4e6d68 commit eae9a6b
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 7 deletions.
10 changes: 9 additions & 1 deletion internal/cueconfig/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,17 @@ func WriteLogins(path string, logins *Logins) error {
return err
}
// Discourage other users from reading this file.
if err := os.WriteFile(path, body, 0o600); err != nil {
// Write to a temp file and then try to atomically rename to avoid races
// with parallel reading/writing.
if err := os.WriteFile(path+".tmp", body, 0o600); err != nil {
return err
}
// TODO: on non-POSIX platforms os.Rename might not be atomic. Might need to
// find another solution. Note that Windows NTFS is also atomic.
if err := os.Rename(path+".tmp", path); err != nil {
return err
}

return nil
}

Expand Down
88 changes: 82 additions & 6 deletions mod/modconfig/modconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,10 @@ type cueLoginsTransport struct {
// initOnce guards initErr, logins, and transport.
initOnce sync.Once
initErr error
// loginsMu guards the logins pointer below.
// Note that an instance of cueconfig.Logins is read-only and
// does not have to be guarded.
loginsMu sync.Mutex
logins *cueconfig.Logins
// transport holds the underlying transport. This wraps
// t.cfg.Transport.
Expand All @@ -211,14 +215,19 @@ func (t *cueLoginsTransport) RoundTrip(req *http.Request) (*http.Response, error
if err := t.init(); err != nil {
return nil, err
}
if t.logins == nil {

t.loginsMu.Lock()
logins := t.logins
t.loginsMu.Unlock()

if logins == nil {
return t.transport.RoundTrip(req)
}
// TODO: note that a CUE registry may include a path prefix,
// so using solely the host will not work with such a path.
// Can we do better here, perhaps keeping the path prefix up to "/v2/"?
host := req.URL.Host
login, ok := t.logins.Registries[host]
login, ok := logins.Registries[host]
if !ok {
return t.transport.RoundTrip(req)
}
Expand All @@ -231,15 +240,21 @@ func (t *cueLoginsTransport) RoundTrip(req *http.Request) (*http.Response, error
Name: host,
Insecure: req.URL.Scheme == "http",
})
// TODO: When this client refreshes an access token,
// we should store the refreshed token on disk.

// Make the oauth client use the transport that was set up
// in init.
ctx := context.WithValue(req.Context(), oauth2.HTTPClient, &http.Client{
Transport: t.transport,
})
transport = oauthCfg.Client(ctx, tok).Transport
transport = oauth2.NewClient(ctx,
&cachingTokenSource{
updateFunc: func(tok *oauth2.Token) error {
return t.updateLogin(host, tok)
},
base: oauthCfg.TokenSource(ctx, tok),
t: tok,
},
).Transport
t.cachedTransports[host] = transport
}
// Unlock immediately so we don't hold the lock for the entire
Expand All @@ -249,6 +264,29 @@ func (t *cueLoginsTransport) RoundTrip(req *http.Request) (*http.Response, error
return transport.RoundTrip(req)
}

func (t *cueLoginsTransport) updateLogin(host string, new *oauth2.Token) error {
// Lock the logins for the entire duration of the update to avoid races
t.loginsMu.Lock()
defer t.loginsMu.Unlock()

// Reload the logins file in case another process changed it in the meantime.
loginsPath, err := cueconfig.LoginConfigPath(t.getenv)
if err != nil {
// TODO: this should never fail. Log a warning.
return nil
}
t.logins, err = cueconfig.ReadLogins(loginsPath)
if err != nil || t.logins == nil {
// TODO: Log a warning. There should be a logins file since we're in the refresh flow.
return nil
}

t.logins.Registries[host] = cueconfig.LoginFromToken(new)

// TODO: lock the logins file properly so that the update is atomic at FS level.
return cueconfig.WriteLogins(loginsPath, t.logins)
}

func (t *cueLoginsTransport) init() error {
t.initOnce.Do(func() {
t.initErr = t._init()
Expand All @@ -270,10 +308,11 @@ func (t *cueLoginsTransport) _init() error {
Transport: t.cfg.Transport,
})

// If we can't locate a logins.json file at all, then we'll
// If we can't locate a logins.json file at all, then we'll continue.
// We only refuse to continue if we find an invalid logins.json file.
loginsPath, err := cueconfig.LoginConfigPath(t.getenv)
if err != nil {
// TODO: this should never fail. Log a warning.
return nil
}
logins, err := cueconfig.ReadLogins(loginsPath)
Expand Down Expand Up @@ -325,3 +364,40 @@ func newRef[T any](x *T) *T {
}
return &x1
}

// cachingTokenSource works similar to oauth2.ReuseTokenSource, except that it
// also exposes a hook to get a hold of the refreshed token, so that it can be
// stored in persistent storage.
type cachingTokenSource struct {
updateFunc func(tok *oauth2.Token) error
base oauth2.TokenSource // called when t is expired

mu sync.Mutex // guards t
t *oauth2.Token
}

func (s *cachingTokenSource) Token() (*oauth2.Token, error) {
s.mu.Lock()
t := s.t

if t.Valid() {
s.mu.Unlock()
return t, nil
}

t, err := s.base.Token()
if err != nil {
s.mu.Unlock()
return nil, err
}

s.t = t
s.mu.Unlock()

err = s.updateFunc(t)
if err != nil {
return nil, err
}

return t, nil
}

0 comments on commit eae9a6b

Please sign in to comment.