Skip to content

Commit

Permalink
JWT cookie fetcher converted to session storage. (#3946)
Browse files Browse the repository at this point in the history
Add initial session storage

This pulls in the scs package for storing sessions.

This means that cookies are stored in-memory for now, with a short cookie being sent to the user's browser, we access the old ID and refresh cookies from the session.

Currently, does not support anything other than in-memory storage, but this is the next step.
  • Loading branch information
bigkevmcd authored Aug 30, 2023
1 parent e8ce330 commit 20749c9
Show file tree
Hide file tree
Showing 17 changed files with 477 additions and 273 deletions.
10 changes: 9 additions & 1 deletion cmd/gitops-server/cmd/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"time"

"github.com/NYTimes/gziphandler"
"github.com/alexedwards/scs/v2"
"github.com/go-logr/logr"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
Expand Down Expand Up @@ -186,12 +187,16 @@ func runCmd(cmd *cobra.Command, args []string) error {
return fmt.Errorf("couldn't get current namespace")
}

sessionManager := scs.New()
// TODO: Make this configurable
sessionManager.Lifetime = 24 * time.Hour
authServer, err := auth.InitAuthServer(cmd.Context(), log, rawClient, auth.AuthParams{
OIDCConfig: options.OIDC,
OIDCSecretName: options.OIDCSecret,
AuthMethodStrings: options.AuthMethods,
NoAuthUser: options.NoAuthUser,
Namespace: namespace,
SessionManager: sessionManager,
})
if err != nil {
return fmt.Errorf("could not initialise authentication server: %w", err)
Expand Down Expand Up @@ -257,6 +262,7 @@ func runCmd(cmd *cobra.Command, args []string) error {
CoreServerConfig: coreConfig,
AuthServer: authServer,
},
sessionManager,
)
if err != nil {
return fmt.Errorf("could not create handler: %w", err)
Expand All @@ -283,11 +289,13 @@ func runCmd(cmd *cobra.Command, args []string) error {
mdlw := httpmiddleware.New(httpmiddleware.Config{
Recorder: metrics.NewRecorder(metrics.Config{}),
})
handler = httpmiddlewarestd.Handler("", mdlw, mux)
handler = httpmiddlewarestd.Handler("", mdlw, handler)
}

handler = middleware.WithLogging(log, handler)

handler = sessionManager.LoadAndSave(handler)

addr := net.JoinHostPort(options.Host, options.Port)
srv := &http.Server{
Addr: addr,
Expand Down
9 changes: 2 additions & 7 deletions cmd/gitops-server/main.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,10 @@
package main

import (
"fmt"
"os"

"github.com/spf13/cobra"
"github.com/weaveworks/weave-gitops/cmd/gitops-server/cmd"
)

func main() {
if err := cmd.NewCommand().Execute(); err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
os.Exit(1)
}
cobra.CheckErr(cmd.NewCommand().Execute())
}
4 changes: 1 addition & 3 deletions core/clustersmngr/factory_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package clustersmngr_test

import (
"fmt"
"testing"

"github.com/go-logr/logr"
Expand Down Expand Up @@ -110,8 +109,7 @@ func TestUseUserClientForNamespaces(t *testing.T) {
g.Expect(userClient.Namespaces()["test"]).To(HaveLen(1))
g.Expect(userClient.Namespaces()["test"][0].GetName()).To(Equal(ns2.Name))

a, b, nss := nsChecker.FilterAccessibleNamespacesArgsForCall(0)
fmt.Println(a, b, nss)
_, _, nss := nsChecker.FilterAccessibleNamespacesArgsForCall(0)
nsFound := 0
for _, n := range nss {
if n.Name == ns1.Name || n.Name == ns2.Name {
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ go 1.20
require (
github.com/Masterminds/semver/v3 v3.2.0
github.com/NYTimes/gziphandler v1.1.1
github.com/alexedwards/scs/v2 v2.5.1
github.com/charmbracelet/bubbles v0.14.0
github.com/charmbracelet/bubbletea v0.22.1
github.com/charmbracelet/lipgloss v0.6.0
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ github.com/alecthomas/colour v0.0.0-20160524082231-60882d9e2721/go.mod h1:QO9JBo
github.com/alecthomas/kong v0.2.4/go.mod h1:kQOmtJgV+Lb4aj+I2LEn40cbtawdWJ9Y8QLq+lElKxE=
github.com/alecthomas/repr v0.0.0-20180818092828-117648cd9897 h1:p9Sln00KOTlrYkxI1zYWl1QLnEqAqEARBEYa8FQnQcY=
github.com/alecthomas/repr v0.0.0-20180818092828-117648cd9897/go.mod h1:xTS7Pm1pD1mvyM075QCDSRqH6qRLXylzS24ZTpRiSzQ=
github.com/alexedwards/scs/v2 v2.5.1 h1:EhAz3Kb3OSQzD8T+Ub23fKsiuvE0GzbF5Lgn0uTwM3Y=
github.com/alexedwards/scs/v2 v2.5.1/go.mod h1:ToaROZxyKukJKT/xLcVQAChi5k6+Pn1Gvmdl7h3RRj8=
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8=
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4=
github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY=
Expand Down
13 changes: 8 additions & 5 deletions pkg/server/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func RegisterAuthServer(mux *http.ServeMux, prefix string, srv *AuthServer, logi
mux.Handle(prefix+"/sign_in", middleware.Handle(srv.SignIn()))
mux.HandleFunc(prefix+"/userinfo", srv.UserInfo)
mux.HandleFunc(prefix+"/refresh", srv.RefreshHandler)
mux.Handle(prefix+"/logout", srv.Logout())
mux.HandleFunc(prefix+"/logout", srv.Logout)

return nil
}
Expand Down Expand Up @@ -156,7 +156,7 @@ func WithPrincipal(ctx context.Context, p *UserPrincipal) context.Context {
// WithAPIAuth middleware adds auth validation to API handlers.
//
// Unauthorized requests will be denied with a 401 status code.
func WithAPIAuth(next http.Handler, srv *AuthServer, publicRoutes []string) http.Handler {
func WithAPIAuth(next http.Handler, srv *AuthServer, publicRoutes []string, sm SessionManager) http.Handler {
multi := MultiAuthPrincipal{Log: srv.Log, Getters: []PrincipalGetter{}}

// FIXME: currently the order must be OIDC last, or it'll "shadow" the other
Expand All @@ -181,15 +181,16 @@ func WithAPIAuth(next http.Handler, srv *AuthServer, publicRoutes []string) http

if srv.oidcPassthroughEnabled() {
srv.Log.V(logger.LogLevelDebug).Info("JWT Token Passthrough Enabled")
multi.Getters = append(multi.Getters, NewJWTPassthroughCookiePrincipalGetter(srv.Log, srv.verifier(), IDTokenCookieName))
multi.Getters = append(multi.Getters, NewJWTPassthroughCookiePrincipalGetter(srv.Log, srv.verifier(), IDTokenCookieName, sm))
} else {
multi.Getters = append(multi.Getters, NewJWTCookiePrincipalGetter(srv.Log, srv.verifier(), IDTokenCookieName, srv.OIDCConfig.ClaimsConfig))
multi.Getters = append(multi.Getters, NewJWTCookiePrincipalGetter(srv.Log, srv.verifier(), srv.OIDCConfig.ClaimsConfig, IDTokenCookieName, sm))
}
}

case UserAccount:
if featureflags.IsSet(FeatureFlagClusterUser) {
adminAuth := NewJWTAdminCookiePrincipalGetter(srv.Log, srv.tokenSignerVerifier, IDTokenCookieName)
adminAuth := NewJWTAdminCookiePrincipalGetter(srv.Log, srv.tokenSignerVerifier, IDTokenCookieName, sm)

multi.Getters = append(multi.Getters, adminAuth)
}

Expand All @@ -207,6 +208,7 @@ func WithAPIAuth(next http.Handler, srv *AuthServer, publicRoutes []string) http
srv: srv,
publicRoutes: publicRoutes,
principalGetter: multi,
sm: sm,
}
}

Expand All @@ -215,6 +217,7 @@ type authenticatedMiddleware struct {
publicRoutes []string
next http.Handler
principalGetter PrincipalGetter
sm SessionManager
}

func (a *authenticatedMiddleware) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
Expand Down
104 changes: 93 additions & 11 deletions pkg/server/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@ import (
"bytes"
"context"
"fmt"

"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"

"github.com/alexedwards/scs/v2"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/go-logr/logr"
"github.com/oauth2-proxy/mockoidc"
Expand All @@ -25,6 +27,7 @@ import (
const testNamespace = "flux-system"

func TestWithAPIAuthReturns401ForUnauthenticatedRequests(t *testing.T) {
sm := scs.New()
g := NewGomegaWithT(t)

m, err := mockoidc.Run()
Expand All @@ -49,7 +52,7 @@ func TestWithAPIAuthReturns401ForUnauthenticatedRequests(t *testing.T) {

authMethods := map[auth.AuthMethod]bool{auth.OIDC: true}

authCfg, err := auth.NewAuthServerConfig(logr.Discard(), oidcCfg, fakeKubernetesClient, tokenSignerVerifier, testNamespace, authMethods, "")
authCfg, err := auth.NewAuthServerConfig(logr.Discard(), oidcCfg, fakeKubernetesClient, tokenSignerVerifier, testNamespace, authMethods, "", sm)
g.Expect(err).NotTo(HaveOccurred())

srv, err := auth.NewAuthServer(context.Background(), authCfg)
Expand All @@ -68,14 +71,16 @@ func TestWithAPIAuthReturns401ForUnauthenticatedRequests(t *testing.T) {

res := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, s.URL, nil)
auth.WithAPIAuth(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {}), srv, nil).ServeHTTP(res, req)
handler := sm.LoadAndSave(auth.WithAPIAuth(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {}), srv, nil, sm))

handler.ServeHTTP(res, req)

g.Expect(res).To(HaveHTTPStatus(http.StatusUnauthorized))

// Test out the publicRoutes
res = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, s.URL+"/v1/featureflags", nil)
auth.WithAPIAuth(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {}), srv, []string{"/v1/featureflags"}).ServeHTTP(res, req)
auth.WithAPIAuth(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {}), srv, []string{"/v1/featureflags"}, sm).ServeHTTP(res, req)

g.Expect(res).To(HaveHTTPStatus(http.StatusOK))
}
Expand All @@ -84,7 +89,7 @@ func TestAnonymousAuth(t *testing.T) {
g := NewGomegaWithT(t)

authMethods := map[auth.AuthMethod]bool{auth.Anonymous: true}
authCfg, err := auth.NewAuthServerConfig(logr.Discard(), auth.OIDCConfig{}, nil, nil, testNamespace, authMethods, "test-user")
authCfg, err := auth.NewAuthServerConfig(logr.Discard(), auth.OIDCConfig{}, nil, nil, testNamespace, authMethods, "test-user", nil)
g.Expect(err).NotTo(HaveOccurred())

srv, err := auth.NewAuthServer(context.Background(), authCfg)
Expand All @@ -95,13 +100,14 @@ func TestAnonymousAuth(t *testing.T) {
auth.WithAPIAuth(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
// no cookie checking etc, principal is just there ready to go
g.Expect(auth.Principal(r.Context()).ID).To(Equal("test-user"))
}), srv, nil).ServeHTTP(res, req)
}), srv, nil, nil).ServeHTTP(res, req)
}

func TestWithAPIAuthOnlyUsesValidMethods(t *testing.T) {
// In theory all attempts to login in this should fail as, despite
// the auth server having access to
g := NewGomegaWithT(t)
sm := scs.New()

m, err := mockoidc.Run()
g.Expect(err).NotTo(HaveOccurred())
Expand Down Expand Up @@ -140,7 +146,7 @@ func TestWithAPIAuthOnlyUsesValidMethods(t *testing.T) {

authMethods := map[auth.AuthMethod]bool{} // This is not a valid AuthMethod

authCfg, err := auth.NewAuthServerConfig(logr.Discard(), oidcCfg, fakeKubernetesClient, tokenSignerVerifier, testNamespace, authMethods, "")
authCfg, err := auth.NewAuthServerConfig(logr.Discard(), oidcCfg, fakeKubernetesClient, tokenSignerVerifier, testNamespace, authMethods, "", sm)
g.Expect(err).NotTo(HaveOccurred())

srv, err := auth.NewAuthServer(context.Background(), authCfg)
Expand All @@ -159,7 +165,7 @@ func TestWithAPIAuthOnlyUsesValidMethods(t *testing.T) {

res := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, s.URL, nil)
auth.WithAPIAuth(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {}), srv, nil).ServeHTTP(res, req)
auth.WithAPIAuth(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {}), srv, nil, scs.New()).ServeHTTP(res, req)

g.Expect(res).To(HaveHTTPStatus(http.StatusUnauthorized))

Expand All @@ -172,13 +178,14 @@ func TestWithAPIAuthOnlyUsesValidMethods(t *testing.T) {
// Test out the publicRoutes
res = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, s.URL+"/v1/featureflags", nil)
auth.WithAPIAuth(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {}), srv, []string{"/v1/featureflags"}).ServeHTTP(res, req)
auth.WithAPIAuth(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {}), srv, []string{"/v1/featureflags"}, scs.New()).ServeHTTP(res, req)

g.Expect(res).To(HaveHTTPStatus(http.StatusOK))
}

func TestOauth2FlowRedirectsToOIDCIssuerWithCustomScopes(t *testing.T) {
g := NewGomegaWithT(t)
sm := &fakeSessionManager{}

m, err := mockoidc.Run()
g.Expect(err).NotTo(HaveOccurred())
Expand All @@ -204,7 +211,7 @@ func TestOauth2FlowRedirectsToOIDCIssuerWithCustomScopes(t *testing.T) {

authMethods := map[auth.AuthMethod]bool{auth.OIDC: true}

authCfg, err := auth.NewAuthServerConfig(logr.Discard(), oidcCfg, fakeKubernetesClient, tokenSignerVerifier, testNamespace, authMethods, "")
authCfg, err := auth.NewAuthServerConfig(logr.Discard(), oidcCfg, fakeKubernetesClient, tokenSignerVerifier, testNamespace, authMethods, "", sm)
g.Expect(err).NotTo(HaveOccurred())

srv, err := auth.NewAuthServer(context.Background(), authCfg)
Expand Down Expand Up @@ -232,6 +239,7 @@ func TestOauth2FlowRedirectsToOIDCIssuerWithCustomScopes(t *testing.T) {

func TestOauth2FlowRedirectsToOIDCIssuerForUnauthenticatedRequests(t *testing.T) {
g := NewGomegaWithT(t)
sm := &fakeSessionManager{}

m, err := mockoidc.Run()
g.Expect(err).NotTo(HaveOccurred())
Expand All @@ -257,7 +265,7 @@ func TestOauth2FlowRedirectsToOIDCIssuerForUnauthenticatedRequests(t *testing.T)

authMethods := map[auth.AuthMethod]bool{auth.OIDC: true}

authCfg, err := auth.NewAuthServerConfig(logr.Discard(), oidcCfg, fakeKubernetesClient, tokenSignerVerifier, testNamespace, authMethods, "")
authCfg, err := auth.NewAuthServerConfig(logr.Discard(), oidcCfg, fakeKubernetesClient, tokenSignerVerifier, testNamespace, authMethods, "", sm)
g.Expect(err).NotTo(HaveOccurred())

srv, err := auth.NewAuthServer(context.Background(), authCfg)
Expand Down Expand Up @@ -295,6 +303,7 @@ func TestIsPublicRoute(t *testing.T) {

func TestRateLimit(t *testing.T) {
g := NewGomegaWithT(t)
sm := &fakeSessionManager{}

mux := http.NewServeMux()
tokenSignerVerifier, err := auth.NewHMACTokenSignerVerifier(5 * time.Minute)
Expand All @@ -319,7 +328,7 @@ func TestRateLimit(t *testing.T) {

authMethods := map[auth.AuthMethod]bool{auth.UserAccount: true}

authCfg, err := auth.NewAuthServerConfig(logr.Discard(), oidcCfg, fakeKubernetesClient, tokenSignerVerifier, testNamespace, authMethods, "")
authCfg, err := auth.NewAuthServerConfig(logr.Discard(), oidcCfg, fakeKubernetesClient, tokenSignerVerifier, testNamespace, authMethods, "", sm)
g.Expect(err).NotTo(HaveOccurred())

srv, err := auth.NewAuthServer(context.Background(), authCfg)
Expand Down Expand Up @@ -407,3 +416,76 @@ func TestUserPrincipal_String(t *testing.T) {
t.Fatalf("principal.String() got %s, want %s", s, `id="testing" groups=[group1 group2]`)
}
}

type sessionsCtxKey struct{}

// Use the fakeSessionManager for cases where you want to pass in an
// *http.Request to a handler rather than routing through a Mux.
var _ auth.SessionManager = (*fakeSessionManager)(nil)

func contextWithSessionValues(values map[string]any) context.Context {
return contextWithValues(context.TODO(), values)
}

func contextWithValues(ctx context.Context, values map[string]any) context.Context {
return context.WithValue(ctx, sessionsCtxKey{}, values)
}

type fakeSessionManager struct {
// Things that are put into the context are actually stored here
// They would normally be output in the `LoadAndSave` middleware by the
// user's session cookie
PutValues map[string]any
// Record the IDs of destroyed sessions (taken from the sessionid in the
// context).
Destroyed []string
}

func (sm *fakeSessionManager) stringValue(name string) string {
v, ok := sm.PutValues[name]
if !ok {
return ""
}
return v.(string)
}

func (sm *fakeSessionManager) LoadAndSave(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
http.Error(w, "Not Implemented", http.StatusNotImplemented)
})
}

func (sm *fakeSessionManager) GetString(ctx context.Context, key string) string {
values, ok := ctx.Value(sessionsCtxKey{}).(map[string]any)
if !ok {
values = map[string]any{}
}

v, ok := values[key]
if ok {
return v.(string)
}

return ""
}

func (sm *fakeSessionManager) Remove(context.Context, string) {
panic("not implemented")
}

func (sm *fakeSessionManager) Put(ctx context.Context, key string, val interface{}) {
if sm.PutValues == nil {
sm.PutValues = map[string]any{}
}

sm.PutValues[key] = val
}

func (sm *fakeSessionManager) Destroy(ctx context.Context) error {
sid := sm.GetString(ctx, "sessionid")
if sid != "" {
sm.Destroyed = append(sm.Destroyed, sid)
}

return nil
}
Loading

0 comments on commit 20749c9

Please sign in to comment.