diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 7483b2f7..329428d7 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -16,7 +16,7 @@ jobs: runs-on: ubuntu-20.04 strategy: matrix: - go: ['1.18', '1.19', '1.20'] + go: ['1.19', '1.20', '1.21'] name: Go ${{ matrix.go }} test steps: - uses: actions/checkout@v3 diff --git a/README.md b/README.md index b7993e69..91a2f39d 100644 --- a/README.md +++ b/README.md @@ -115,10 +115,10 @@ Versions that also build are marked with :warning:. | Version | Supported | | ------- | ------------------ | -| <1.18 | :x: | -| 1.18 | :warning: | -| 1.19 | :white_check_mark: | +| <1.19 | :x: | +| 1.19 | :warning: | | 1.20 | :white_check_mark: | +| 1.21 | :white_check_mark: | ## Why another library diff --git a/example/client/app/app.go b/example/client/app/app.go index 2cb5dfa7..0e339f40 100644 --- a/example/client/app/app.go +++ b/example/client/app/app.go @@ -7,11 +7,14 @@ import ( "net/http" "os" "strings" + "sync/atomic" "time" "github.com/google/uuid" "github.com/sirupsen/logrus" + "golang.org/x/exp/slog" + "github.com/zitadel/logging" "github.com/zitadel/oidc/v3/pkg/client/rp" httphelper "github.com/zitadel/oidc/v3/pkg/http" "github.com/zitadel/oidc/v3/pkg/oidc" @@ -33,9 +36,25 @@ func main() { redirectURI := fmt.Sprintf("http://localhost:%v%v", port, callbackPath) cookieHandler := httphelper.NewCookieHandler(key, key, httphelper.WithUnsecure()) + logger := slog.New( + slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ + AddSource: true, + Level: slog.LevelDebug, + }), + ) + client := &http.Client{ + Timeout: time.Minute, + } + // enable outgoing request logging + logging.EnableHTTPClient(client, + logging.WithClientGroup("client"), + ) + options := []rp.Option{ rp.WithCookieHandler(cookieHandler), rp.WithVerifierOpts(rp.WithIssuedAtOffset(5 * time.Second)), + rp.WithHTTPClient(client), + rp.WithLogger(logger), } if clientSecret == "" { options = append(options, rp.WithPKCE(cookieHandler)) @@ -44,7 +63,10 @@ func main() { options = append(options, rp.WithJWTProfile(rp.SignerFromKeyPath(keyPath))) } - provider, err := rp.NewRelyingPartyOIDC(context.TODO(), issuer, clientID, clientSecret, redirectURI, scopes, options...) + // One can add a logger to the context, + // pre-defining log attributes as required. + ctx := logging.ToContext(context.TODO(), logger) + provider, err := rp.NewRelyingPartyOIDC(ctx, issuer, clientID, clientSecret, redirectURI, scopes, options...) if err != nil { logrus.Fatalf("error creating provider %s", err.Error()) } @@ -119,8 +141,22 @@ func main() { // // http.Handle(callbackPath, rp.CodeExchangeHandler(marshalToken, provider)) + // simple counter for request IDs + var counter atomic.Int64 + // enable incomming request logging + mw := logging.Middleware( + logging.WithLogger(logger), + logging.WithGroup("server"), + logging.WithIDFunc(func() slog.Attr { + return slog.Int64("id", counter.Add(1)) + }), + ) + lis := fmt.Sprintf("127.0.0.1:%s", port) - logrus.Infof("listening on http://%s/", lis) - logrus.Info("press ctrl+c to stop") - logrus.Fatal(http.ListenAndServe(lis, nil)) + logger.Info("server listening, press ctrl+c to stop", "addr", lis) + err = http.ListenAndServe(lis, mw(http.DefaultServeMux)) + if err != http.ErrServerClosed { + logger.Error("server terminated", "error", err) + os.Exit(1) + } } diff --git a/example/server/exampleop/op.go b/example/server/exampleop/op.go index 298bff69..b5ee7b37 100644 --- a/example/server/exampleop/op.go +++ b/example/server/exampleop/op.go @@ -4,9 +4,12 @@ import ( "crypto/sha256" "log" "net/http" + "sync/atomic" "time" "github.com/go-chi/chi" + "github.com/zitadel/logging" + "golang.org/x/exp/slog" "golang.org/x/text/language" "github.com/zitadel/oidc/v3/example/server/storage" @@ -31,26 +34,33 @@ type Storage interface { deviceAuthenticate } +// simple counter for request IDs +var counter atomic.Int64 + // SetupServer creates an OIDC server with Issuer=http://localhost: // // Use one of the pre-made clients in storage/clients.go or register a new one. -func SetupServer(issuer string, storage Storage) chi.Router { +func SetupServer(issuer string, storage Storage, logger *slog.Logger) chi.Router { // the OpenID Provider requires a 32-byte key for (token) encryption // be sure to create a proper crypto random key and manage it securely! key := sha256.Sum256([]byte("test")) router := chi.NewRouter() + router.Use(logging.Middleware( + logging.WithLogger(logger), + logging.WithIDFunc(func() slog.Attr { + return slog.Int64("id", counter.Add(1)) + }), + )) // for simplicity, we provide a very small default page for users who have signed out router.HandleFunc(pathLoggedOut, func(w http.ResponseWriter, req *http.Request) { - _, err := w.Write([]byte("signed out successfully")) - if err != nil { - log.Printf("error serving logged out page: %v", err) - } + w.Write([]byte("signed out successfully")) + // no need to check/log error, this will be handeled by the middleware. }) // creation of the OpenIDProvider with the just created in-memory Storage - provider, err := newOP(storage, issuer, key) + provider, err := newOP(storage, issuer, key, logger) if err != nil { log.Fatal(err) } @@ -80,7 +90,7 @@ func SetupServer(issuer string, storage Storage) chi.Router { // newOP will create an OpenID Provider for localhost on a specified port with a given encryption key // and a predefined default logout uri // it will enable all options (see descriptions) -func newOP(storage op.Storage, issuer string, key [32]byte) (op.OpenIDProvider, error) { +func newOP(storage op.Storage, issuer string, key [32]byte, logger *slog.Logger) (op.OpenIDProvider, error) { config := &op.Config{ CryptoKey: key, @@ -117,6 +127,8 @@ func newOP(storage op.Storage, issuer string, key [32]byte) (op.OpenIDProvider, op.WithAllowInsecure(), // as an example on how to customize an endpoint this will change the authorization_endpoint from /authorize to /auth op.WithCustomAuthEndpoint(op.NewEndpoint("auth")), + // Pass our logger to the OP + op.WithLogger(logger.WithGroup("op")), ) if err != nil { return nil, err diff --git a/example/server/main.go b/example/server/main.go index ee27bbab..a1cc4618 100644 --- a/example/server/main.go +++ b/example/server/main.go @@ -2,11 +2,12 @@ package main import ( "fmt" - "log" "net/http" + "os" "github.com/zitadel/oidc/v3/example/server/exampleop" "github.com/zitadel/oidc/v3/example/server/storage" + "golang.org/x/exp/slog" ) func main() { @@ -20,16 +21,22 @@ func main() { // in this example it will be handled in-memory storage := storage.NewStorage(storage.NewUserStore(issuer)) - router := exampleop.SetupServer(issuer, storage) + logger := slog.New( + slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ + AddSource: true, + Level: slog.LevelDebug, + }), + ) + router := exampleop.SetupServer(issuer, storage, logger) server := &http.Server{ Addr: ":" + port, Handler: router, } - log.Printf("server listening on http://localhost:%s/", port) - log.Println("press ctrl+c to stop") + logger.Info("server listening, press ctrl+c to stop", "addr", fmt.Sprintf("http://localhost:%s/", port)) err := server.ListenAndServe() - if err != nil { - log.Fatal(err) + if err != http.ErrServerClosed { + logger.Error("server terminated", "error", err) + os.Exit(1) } } diff --git a/example/server/storage/oidc.go b/example/server/storage/oidc.go index b56ad090..63afcf93 100644 --- a/example/server/storage/oidc.go +++ b/example/server/storage/oidc.go @@ -3,6 +3,7 @@ package storage import ( "time" + "golang.org/x/exp/slog" "golang.org/x/text/language" "github.com/zitadel/oidc/v3/pkg/oidc" @@ -41,6 +42,19 @@ type AuthRequest struct { authTime time.Time } +// LogValue allows you to define which fields will be logged. +// Implements the [slog.LogValuer] +func (a *AuthRequest) LogValue() slog.Value { + return slog.GroupValue( + slog.String("id", a.ID), + slog.Time("creation_date", a.CreationDate), + slog.Any("scopes", a.Scopes), + slog.String("response_type", string(a.ResponseType)), + slog.String("app_id", a.ApplicationID), + slog.String("callback_uri", a.CallbackURI), + ) +} + func (a *AuthRequest) GetID() string { return a.ID } diff --git a/go.mod b/go.mod index 610d2a10..62aa39be 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/zitadel/oidc/v3 -go 1.18 +go 1.19 require ( github.com/go-chi/chi v1.5.4 @@ -11,9 +11,11 @@ require ( github.com/jeremija/gosubmit v0.2.7 github.com/muhlemmer/gu v0.3.1 github.com/rs/cors v1.9.0 - github.com/sirupsen/logrus v1.9.0 + github.com/sirupsen/logrus v1.9.3 github.com/stretchr/testify v1.8.2 + github.com/zitadel/logging v0.4.0 github.com/zitadel/schema v1.3.0 + golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 golang.org/x/oauth2 v0.7.0 golang.org/x/text v0.9.0 gopkg.in/square/go-jose.v2 v2.6.0 @@ -27,7 +29,7 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect golang.org/x/crypto v0.7.0 // indirect golang.org/x/net v0.9.0 // indirect - golang.org/x/sys v0.7.0 // indirect + golang.org/x/sys v0.11.0 // indirect google.golang.org/appengine v1.6.7 // indirect google.golang.org/protobuf v1.29.1 // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect diff --git a/go.sum b/go.sum index c9c85626..9c44f0f6 100644 --- a/go.sum +++ b/go.sum @@ -36,8 +36,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rs/cors v1.9.0 h1:l9HGsTsHJcvW14Nk7J9KFz8bzeAWXn3CG6bgt7LsrAE= github.com/rs/cors v1.9.0/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU= -github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0= -github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= +github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= @@ -47,12 +47,16 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= +github.com/zitadel/logging v0.4.0 h1:lRAIFgaRoJpLNbsL7jtIYHcMDoEJP9QZB4GqMfl4xaA= +github.com/zitadel/logging v0.4.0/go.mod h1:6uALRJawpkkuUPCkgzfgcPR3c2N908wqnOnIrRelUFc= github.com/zitadel/schema v1.3.0 h1:kQ9W9tvIwZICCKWcMvCEweXET1OcOyGEuFbHs4o5kg0= github.com/zitadel/schema v1.3.0/go.mod h1:NptN6mkBDFvERUCvZHlvWmmME+gmZ44xzwRXwhzsbtc= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.7.0 h1:AvwMYaRytfdeVt3u6mLaxYtErKYjxA2OXjJ1HHq6t3A= golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU= +golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 h1:m64FZMko/V45gv0bNmrNYoDEq8U5YUhetc9cBWKS1TQ= +golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63/go.mod h1:0v4NqG35kSWCMzLaMeX+IQrlSnVE/bqGSyC2cz/9Le8= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= @@ -73,8 +77,8 @@ golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.7.0 h1:3jlCCIQZPdOYu1h8BkNvLz8Kgwtae2cagcG/VamtZRU= -golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.11.0 h1:eG7RXZHdqOJ1i+0lgLgCpSXAp6M3LYlAo6osgSi0xOM= +golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= @@ -100,6 +104,7 @@ gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntN gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/square/go-jose.v2 v2.6.0 h1:NGk74WTnPKBNUhNzQX7PYcTLUjoq7mzKk2OKbvwk2iI= gopkg.in/square/go-jose.v2 v2.6.0/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI= +gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/pkg/client/client.go b/pkg/client/client.go index e3efd611..7b76dfd6 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -14,6 +14,7 @@ import ( "golang.org/x/oauth2" "gopkg.in/square/go-jose.v2" + "github.com/zitadel/logging" "github.com/zitadel/oidc/v3/pkg/crypto" httphelper "github.com/zitadel/oidc/v3/pkg/http" "github.com/zitadel/oidc/v3/pkg/oidc" @@ -37,6 +38,10 @@ func Discover(ctx context.Context, issuer string, httpClient *http.Client, wellK if err != nil { return nil, err } + if logger, ok := logging.FromContext(ctx); ok { + logger.Debug("discover", "config", discoveryConfig) + } + if discoveryConfig.Issuer != issuer { return nil, oidc.ErrIssuerInvalid } diff --git a/pkg/client/integration_test.go b/pkg/client/integration_test.go index 073efef7..7cbb62e6 100644 --- a/pkg/client/integration_test.go +++ b/pkg/client/integration_test.go @@ -19,6 +19,7 @@ import ( "github.com/jeremija/gosubmit" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/exp/slog" "github.com/zitadel/oidc/v3/example/server/exampleop" "github.com/zitadel/oidc/v3/example/server/storage" @@ -29,6 +30,13 @@ import ( "github.com/zitadel/oidc/v3/pkg/oidc" ) +var Logger = slog.New( + slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ + AddSource: true, + Level: slog.LevelDebug, + }), +) + var CTX context.Context func TestMain(m *testing.M) { @@ -49,7 +57,7 @@ func TestRelyingPartySession(t *testing.T) { opServer := httptest.NewServer(&dh) defer opServer.Close() t.Logf("auth server at %s", opServer.URL) - dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage) + dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger) seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano())) clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25) @@ -100,7 +108,7 @@ func TestResourceServerTokenExchange(t *testing.T) { opServer := httptest.NewServer(&dh) defer opServer.Close() t.Logf("auth server at %s", opServer.URL) - dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage) + dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger) seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano())) clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25) diff --git a/pkg/client/rp/device.go b/pkg/client/rp/device.go index 390c8cf4..02c647e3 100644 --- a/pkg/client/rp/device.go +++ b/pkg/client/rp/device.go @@ -33,6 +33,7 @@ func newDeviceClientCredentialsRequest(scopes []string, rp RelyingParty) (*oidc. // in RFC 8628, section 3.1 and 3.2: // https://www.rfc-editor.org/rfc/rfc8628#section-3.1 func DeviceAuthorization(ctx context.Context, scopes []string, rp RelyingParty, authFn any) (*oidc.DeviceAuthorizationResponse, error) { + ctx = logCtxWithRPData(ctx, rp, "function", "DeviceAuthorization") req, err := newDeviceClientCredentialsRequest(scopes, rp) if err != nil { return nil, err @@ -45,6 +46,7 @@ func DeviceAuthorization(ctx context.Context, scopes []string, rp RelyingParty, // by means of polling as defined in RFC, section 3.3 and 3.4: // https://www.rfc-editor.org/rfc/rfc8628#section-3.4 func DeviceAccessToken(ctx context.Context, deviceCode string, interval time.Duration, rp RelyingParty) (resp *oidc.AccessTokenResponse, err error) { + ctx = logCtxWithRPData(ctx, rp, "function", "DeviceAccessToken") req := &client.DeviceAccessTokenRequest{ DeviceAccessTokenRequest: oidc.DeviceAccessTokenRequest{ GrantType: oidc.GrantTypeDeviceCode, diff --git a/pkg/client/rp/log.go b/pkg/client/rp/log.go new file mode 100644 index 00000000..6056fa2e --- /dev/null +++ b/pkg/client/rp/log.go @@ -0,0 +1,17 @@ +package rp + +import ( + "context" + + "github.com/zitadel/logging" + "golang.org/x/exp/slog" +) + +func logCtxWithRPData(ctx context.Context, rp RelyingParty, attrs ...any) context.Context { + logger, ok := rp.Logger(ctx) + if !ok { + return ctx + } + logger = logger.With(slog.Group("rp", attrs...)) + return logging.ToContext(ctx, logger) +} diff --git a/pkg/client/rp/relying_party.go b/pkg/client/rp/relying_party.go index 5597c9d9..34cdb397 100644 --- a/pkg/client/rp/relying_party.go +++ b/pkg/client/rp/relying_party.go @@ -10,6 +10,8 @@ import ( "time" "github.com/google/uuid" + "github.com/zitadel/logging" + "golang.org/x/exp/slog" "golang.org/x/oauth2" "gopkg.in/square/go-jose.v2" @@ -67,6 +69,9 @@ type RelyingParty interface { // ErrorHandler returns the handler used for callback errors ErrorHandler() func(http.ResponseWriter, *http.Request, string, string, string) + + // Logger from the context, or a fallback if set. + Logger(context.Context) (logger *slog.Logger, ok bool) } type ErrorHandler func(w http.ResponseWriter, r *http.Request, errorType string, errorDesc string, state string) @@ -90,6 +95,7 @@ type relyingParty struct { idTokenVerifier *IDTokenVerifier verifierOpts []VerifierOption signer jose.Signer + logger *slog.Logger } func (rp *relyingParty) OAuthConfig() *oauth2.Config { @@ -150,6 +156,14 @@ func (rp *relyingParty) ErrorHandler() func(http.ResponseWriter, *http.Request, return rp.errorHandler } +func (rp *relyingParty) Logger(ctx context.Context) (logger *slog.Logger, ok bool) { + logger, ok = logging.FromContext(ctx) + if ok { + return logger, ok + } + return rp.logger, rp.logger != nil +} + // NewRelyingPartyOAuth creates an (OAuth2) RelyingParty with the given // OAuth2 Config and possible configOptions // it will use the AuthURL and TokenURL set in config @@ -194,6 +208,7 @@ func NewRelyingPartyOIDC(ctx context.Context, issuer, clientID, clientSecret, re return nil, err } } + ctx = logCtxWithRPData(ctx, rp, "function", "NewRelyingPartyOIDC") discoveryConfiguration, err := client.Discover(ctx, rp.issuer, rp.httpClient, rp.DiscoveryEndpoint) if err != nil { return nil, err @@ -281,6 +296,15 @@ func WithJWTProfile(signerFromKey SignerFromKey) Option { } } +// WithLogger sets a logger that is used +// in case the request context does not contain a logger. +func WithLogger(logger *slog.Logger) Option { + return func(rp *relyingParty) error { + rp.logger = logger + return nil + } +} + type SignerFromKey func() (jose.Signer, error) func SignerFromKeyPath(path string) SignerFromKey { @@ -378,6 +402,7 @@ func verifyTokenResponse[C oidc.IDClaims](ctx context.Context, token *oauth2.Tok // CodeExchange handles the oauth2 code exchange, extracting and validating the id_token // returning it parsed together with the oauth2 tokens (access, refresh) func CodeExchange[C oidc.IDClaims](ctx context.Context, code string, rp RelyingParty, opts ...CodeExchangeOpt) (tokens *oidc.Tokens[C], err error) { + ctx = logCtxWithRPData(ctx, rp, "function", "CodeExchange") ctx = context.WithValue(ctx, oauth2.HTTPClient, rp.HttpClient()) codeOpts := make([]oauth2.AuthCodeOption, 0) for _, opt := range opts { @@ -467,6 +492,7 @@ func UserinfoCallback[C oidc.IDClaims, U SubjectGetter](f CodeExchangeUserinfoCa // [UserInfo]: https://openid.net/specs/openid-connect-core-1_0.html#UserInfo func Userinfo[U SubjectGetter](ctx context.Context, token, tokenType, subject string, rp RelyingParty) (userinfo U, err error) { var nilU U + ctx = logCtxWithRPData(ctx, rp, "function", "Userinfo") req, err := http.NewRequestWithContext(ctx, http.MethodGet, rp.UserinfoEndpoint(), nil) if err != nil { @@ -546,7 +572,7 @@ func withURLParam(key, value string) func() []oauth2.AuthCodeOption { // This is the generalized, unexported, function used by both // URLParamOpt and AuthURLOpt. func withPrompt(prompt ...string) func() []oauth2.AuthCodeOption { - return withURLParam("prompt", oidc.SpaceDelimitedArray(prompt).Encode()) + return withURLParam("prompt", oidc.SpaceDelimitedArray(prompt).String()) } type URLParamOpt func() []oauth2.AuthCodeOption @@ -621,6 +647,7 @@ type RefreshTokenRequest struct { // the IDToken and AccessToken will be verfied // and the IDToken and IDTokenClaims fields will be populated in the returned object. func RefreshTokens[C oidc.IDClaims](ctx context.Context, rp RelyingParty, refreshToken, clientAssertion, clientAssertionType string) (*oidc.Tokens[C], error) { + ctx = logCtxWithRPData(ctx, rp, "function", "RefreshTokens") request := RefreshTokenRequest{ RefreshToken: refreshToken, Scopes: rp.OAuthConfig().Scopes, @@ -644,6 +671,7 @@ func RefreshTokens[C oidc.IDClaims](ctx context.Context, rp RelyingParty, refres } func EndSession(ctx context.Context, rp RelyingParty, idToken, optionalRedirectURI, optionalState string) (*url.URL, error) { + ctx = logCtxWithRPData(ctx, rp, "function", "EndSession") request := oidc.EndSessionRequest{ IdTokenHint: idToken, ClientID: rp.OAuthConfig().ClientID, @@ -659,6 +687,7 @@ func EndSession(ctx context.Context, rp RelyingParty, idToken, optionalRedirectU // // tokenTypeHint should be either "id_token" or "refresh_token". func RevokeToken(ctx context.Context, rp RelyingParty, token string, tokenTypeHint string) error { + ctx = logCtxWithRPData(ctx, rp, "function", "RevokeToken") request := client.RevokeRequest{ Token: token, TokenTypeHint: tokenTypeHint, diff --git a/pkg/oidc/authorization.go b/pkg/oidc/authorization.go index d8bf3364..7e7c30cc 100644 --- a/pkg/oidc/authorization.go +++ b/pkg/oidc/authorization.go @@ -1,5 +1,9 @@ package oidc +import ( + "golang.org/x/exp/slog" +) + const ( // ScopeOpenID defines the scope `openid` // OpenID Connect requests MUST contain the `openid` scope value @@ -86,6 +90,15 @@ type AuthRequest struct { RequestParam string `schema:"request"` } +func (a *AuthRequest) LogValue() slog.Value { + return slog.GroupValue( + slog.Any("scopes", a.Scopes), + slog.String("response_type", string(a.ResponseType)), + slog.String("client_id", a.ClientID), + slog.String("redirect_uri", a.RedirectURI), + ) +} + // GetRedirectURI returns the redirect_uri value for the ErrAuthRequest interface func (a *AuthRequest) GetRedirectURI() string { return a.RedirectURI diff --git a/pkg/oidc/authorization_test.go b/pkg/oidc/authorization_test.go new file mode 100644 index 00000000..573d65c3 --- /dev/null +++ b/pkg/oidc/authorization_test.go @@ -0,0 +1,27 @@ +//go:build go1.20 + +package oidc + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "golang.org/x/exp/slog" +) + +func TestAuthRequest_LogValue(t *testing.T) { + a := &AuthRequest{ + Scopes: SpaceDelimitedArray{"a", "b"}, + ResponseType: "respType", + ClientID: "123", + RedirectURI: "http://example.com/callback", + } + want := slog.GroupValue( + slog.Any("scopes", SpaceDelimitedArray{"a", "b"}), + slog.String("response_type", "respType"), + slog.String("client_id", "123"), + slog.String("redirect_uri", "http://example.com/callback"), + ) + got := a.LogValue() + assert.Equal(t, want, got) +} diff --git a/pkg/oidc/error.go b/pkg/oidc/error.go index 79acecd9..07a90697 100644 --- a/pkg/oidc/error.go +++ b/pkg/oidc/error.go @@ -3,6 +3,8 @@ package oidc import ( "errors" "fmt" + + "golang.org/x/exp/slog" ) type errorType string @@ -171,3 +173,34 @@ func DefaultToServerError(err error, description string) *Error { } return oauth } + +func (e *Error) LogLevel() slog.Level { + level := slog.LevelWarn + if e.ErrorType == ServerError { + level = slog.LevelError + } + if e.ErrorType == AuthorizationPending { + level = slog.LevelInfo + } + return level +} + +func (e *Error) LogValue() slog.Value { + attrs := make([]slog.Attr, 0, 5) + if e.Parent != nil { + attrs = append(attrs, slog.Any("parent", e.Parent)) + } + if e.Description != "" { + attrs = append(attrs, slog.String("description", e.Description)) + } + if e.ErrorType != "" { + attrs = append(attrs, slog.String("type", string(e.ErrorType))) + } + if e.State != "" { + attrs = append(attrs, slog.String("state", e.State)) + } + if e.redirectDisabled { + attrs = append(attrs, slog.Bool("redirect_disabled", e.redirectDisabled)) + } + return slog.GroupValue(attrs...) +} diff --git a/pkg/oidc/error_go120_test.go b/pkg/oidc/error_go120_test.go new file mode 100644 index 00000000..399d7f71 --- /dev/null +++ b/pkg/oidc/error_go120_test.go @@ -0,0 +1,83 @@ +//go:build go1.20 + +package oidc + +import ( + "io" + "testing" + + "github.com/stretchr/testify/assert" + "golang.org/x/exp/slog" +) + +func TestError_LogValue(t *testing.T) { + type fields struct { + Parent error + ErrorType errorType + Description string + State string + redirectDisabled bool + } + tests := []struct { + name string + fields fields + want slog.Value + }{ + { + name: "parent", + fields: fields{ + Parent: io.EOF, + }, + want: slog.GroupValue(slog.Any("parent", io.EOF)), + }, + { + name: "description", + fields: fields{ + Description: "oops", + }, + want: slog.GroupValue(slog.String("description", "oops")), + }, + { + name: "errorType", + fields: fields{ + ErrorType: ExpiredToken, + }, + want: slog.GroupValue(slog.String("type", string(ExpiredToken))), + }, + { + name: "state", + fields: fields{ + State: "123", + }, + want: slog.GroupValue(slog.String("state", "123")), + }, + { + name: "all fields", + fields: fields{ + Parent: io.EOF, + Description: "oops", + ErrorType: ExpiredToken, + State: "123", + }, + want: slog.GroupValue( + slog.Any("parent", io.EOF), + slog.String("description", "oops"), + slog.String("type", string(ExpiredToken)), + slog.String("state", "123"), + ), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := &Error{ + Parent: tt.fields.Parent, + ErrorType: tt.fields.ErrorType, + Description: tt.fields.Description, + State: tt.fields.State, + redirectDisabled: tt.fields.redirectDisabled, + } + got := e.LogValue() + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/pkg/oidc/error_test.go b/pkg/oidc/error_test.go new file mode 100644 index 00000000..0554c8fb --- /dev/null +++ b/pkg/oidc/error_test.go @@ -0,0 +1,81 @@ +package oidc + +import ( + "io" + "testing" + + "github.com/stretchr/testify/assert" + "golang.org/x/exp/slog" +) + +func TestDefaultToServerError(t *testing.T) { + type args struct { + err error + description string + } + tests := []struct { + name string + args args + want *Error + }{ + { + name: "default", + args: args{ + err: io.ErrClosedPipe, + description: "oops", + }, + want: &Error{ + ErrorType: ServerError, + Description: "oops", + Parent: io.ErrClosedPipe, + }, + }, + { + name: "our Error", + args: args{ + err: ErrAccessDenied(), + description: "oops", + }, + want: &Error{ + ErrorType: AccessDenied, + Description: "The authorization request was denied.", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := DefaultToServerError(tt.args.err, tt.args.description) + assert.ErrorIs(t, got, tt.want) + }) + } +} + +func TestError_LogLevel(t *testing.T) { + tests := []struct { + name string + err *Error + want slog.Level + }{ + { + name: "server error", + err: ErrServerError(), + want: slog.LevelError, + }, + { + name: "authorization pending", + err: ErrAuthorizationPending(), + want: slog.LevelInfo, + }, + { + name: "some other error", + err: ErrAccessDenied(), + want: slog.LevelWarn, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.err.LogLevel() + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/pkg/oidc/types.go b/pkg/oidc/types.go index 86ee1e0f..5db8badc 100644 --- a/pkg/oidc/types.go +++ b/pkg/oidc/types.go @@ -106,7 +106,7 @@ type ResponseType string type ResponseMode string -func (s SpaceDelimitedArray) Encode() string { +func (s SpaceDelimitedArray) String() string { return strings.Join(s, " ") } @@ -116,11 +116,11 @@ func (s *SpaceDelimitedArray) UnmarshalText(text []byte) error { } func (s SpaceDelimitedArray) MarshalText() ([]byte, error) { - return []byte(s.Encode()), nil + return []byte(s.String()), nil } func (s SpaceDelimitedArray) MarshalJSON() ([]byte, error) { - return json.Marshal((s).Encode()) + return json.Marshal((s).String()) } func (s *SpaceDelimitedArray) UnmarshalJSON(data []byte) error { @@ -165,7 +165,7 @@ func (s SpaceDelimitedArray) Value() (driver.Value, error) { func NewEncoder() *schema.Encoder { e := schema.NewEncoder() e.RegisterEncoder(SpaceDelimitedArray{}, func(value reflect.Value) string { - return value.Interface().(SpaceDelimitedArray).Encode() + return value.Interface().(SpaceDelimitedArray).String() }) return e } diff --git a/pkg/op/auth_request.go b/pkg/op/auth_request.go index 7af3779e..7610248e 100644 --- a/pkg/op/auth_request.go +++ b/pkg/op/auth_request.go @@ -14,6 +14,7 @@ import ( httphelper "github.com/zitadel/oidc/v3/pkg/http" "github.com/zitadel/oidc/v3/pkg/oidc" str "github.com/zitadel/oidc/v3/pkg/strings" + "golang.org/x/exp/slog" ) type AuthRequest interface { @@ -41,6 +42,7 @@ type Authorizer interface { IDTokenHintVerifier(context.Context) *IDTokenHintVerifier Crypto() Crypto RequestObjectSupported() bool + Logger() *slog.Logger } // AuthorizeValidator is an extension of Authorizer interface @@ -67,23 +69,23 @@ func authorizeCallbackHandler(authorizer Authorizer) func(http.ResponseWriter, * func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) { authReq, err := ParseAuthorizeRequest(r, authorizer.Decoder()) if err != nil { - AuthRequestError(w, r, nil, err, authorizer.Encoder()) + AuthRequestError(w, r, nil, err, authorizer) return } ctx := r.Context() if authReq.RequestParam != "" && authorizer.RequestObjectSupported() { authReq, err = ParseRequestObject(ctx, authReq, authorizer.Storage(), IssuerFromContext(ctx)) if err != nil { - AuthRequestError(w, r, authReq, err, authorizer.Encoder()) + AuthRequestError(w, r, authReq, err, authorizer) return } } if authReq.ClientID == "" { - AuthRequestError(w, r, authReq, fmt.Errorf("auth request is missing client_id"), authorizer.Encoder()) + AuthRequestError(w, r, authReq, fmt.Errorf("auth request is missing client_id"), authorizer) return } if authReq.RedirectURI == "" { - AuthRequestError(w, r, authReq, fmt.Errorf("auth request is missing redirect_uri"), authorizer.Encoder()) + AuthRequestError(w, r, authReq, fmt.Errorf("auth request is missing redirect_uri"), authorizer) return } validation := ValidateAuthRequest @@ -92,21 +94,21 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) { } userID, err := validation(ctx, authReq, authorizer.Storage(), authorizer.IDTokenHintVerifier(ctx)) if err != nil { - AuthRequestError(w, r, authReq, err, authorizer.Encoder()) + AuthRequestError(w, r, authReq, err, authorizer) return } if authReq.RequestParam != "" { - AuthRequestError(w, r, authReq, oidc.ErrRequestNotSupported(), authorizer.Encoder()) + AuthRequestError(w, r, authReq, oidc.ErrRequestNotSupported(), authorizer) return } req, err := authorizer.Storage().CreateAuthRequest(ctx, authReq, userID) if err != nil { - AuthRequestError(w, r, authReq, oidc.DefaultToServerError(err, "unable to save auth request"), authorizer.Encoder()) + AuthRequestError(w, r, authReq, oidc.DefaultToServerError(err, "unable to save auth request"), authorizer) return } client, err := authorizer.Storage().GetClientByClientID(ctx, req.GetClientID()) if err != nil { - AuthRequestError(w, r, req, oidc.DefaultToServerError(err, "unable to retrieve client by id"), authorizer.Encoder()) + AuthRequestError(w, r, req, oidc.DefaultToServerError(err, "unable to retrieve client by id"), authorizer) return } RedirectToLogin(req.GetID(), client, w, r) @@ -406,18 +408,18 @@ func RedirectToLogin(authReqID string, client Client, w http.ResponseWriter, r * func AuthorizeCallback(w http.ResponseWriter, r *http.Request, authorizer Authorizer) { id, err := ParseAuthorizeCallbackRequest(r) if err != nil { - AuthRequestError(w, r, nil, err, authorizer.Encoder()) + AuthRequestError(w, r, nil, err, authorizer) return } authReq, err := authorizer.Storage().AuthRequestByID(r.Context(), id) if err != nil { - AuthRequestError(w, r, nil, err, authorizer.Encoder()) + AuthRequestError(w, r, nil, err, authorizer) return } if !authReq.Done() { AuthRequestError(w, r, authReq, oidc.ErrInteractionRequired().WithDescription("Unfortunately, the user may be not logged in and/or additional interaction is required."), - authorizer.Encoder()) + authorizer) return } AuthResponse(authReq, authorizer, w, r) @@ -438,7 +440,7 @@ func ParseAuthorizeCallbackRequest(r *http.Request) (id string, err error) { func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWriter, r *http.Request) { client, err := authorizer.Storage().GetClientByClientID(r.Context(), authReq.GetClientID()) if err != nil { - AuthRequestError(w, r, authReq, err, authorizer.Encoder()) + AuthRequestError(w, r, authReq, err, authorizer) return } if authReq.GetResponseType() == oidc.ResponseTypeCode { @@ -452,7 +454,7 @@ func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWri func AuthResponseCode(w http.ResponseWriter, r *http.Request, authReq AuthRequest, authorizer Authorizer) { code, err := CreateAuthRequestCode(r.Context(), authReq, authorizer.Storage(), authorizer.Crypto()) if err != nil { - AuthRequestError(w, r, authReq, err, authorizer.Encoder()) + AuthRequestError(w, r, authReq, err, authorizer) return } codeResponse := struct { @@ -464,7 +466,7 @@ func AuthResponseCode(w http.ResponseWriter, r *http.Request, authReq AuthReques } callback, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), authReq.GetResponseMode(), &codeResponse, authorizer.Encoder()) if err != nil { - AuthRequestError(w, r, authReq, err, authorizer.Encoder()) + AuthRequestError(w, r, authReq, err, authorizer) return } http.Redirect(w, r, callback, http.StatusFound) @@ -475,12 +477,12 @@ func AuthResponseToken(w http.ResponseWriter, r *http.Request, authReq AuthReque createAccessToken := authReq.GetResponseType() != oidc.ResponseTypeIDTokenOnly resp, err := CreateTokenResponse(r.Context(), authReq, client, authorizer, createAccessToken, "", "") if err != nil { - AuthRequestError(w, r, authReq, err, authorizer.Encoder()) + AuthRequestError(w, r, authReq, err, authorizer) return } callback, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), authReq.GetResponseMode(), resp, authorizer.Encoder()) if err != nil { - AuthRequestError(w, r, authReq, err, authorizer.Encoder()) + AuthRequestError(w, r, authReq, err, authorizer) return } http.Redirect(w, r, callback, http.StatusFound) diff --git a/pkg/op/auth_request_test.go b/pkg/op/auth_request_test.go index df340b6b..42fd0aa0 100644 --- a/pkg/op/auth_request_test.go +++ b/pkg/op/auth_request_test.go @@ -18,6 +18,7 @@ import ( "github.com/zitadel/oidc/v3/pkg/op" "github.com/zitadel/oidc/v3/pkg/op/mock" "github.com/zitadel/schema" + "golang.org/x/exp/slog" ) func TestAuthorize(t *testing.T) { @@ -38,7 +39,7 @@ func TestAuthorize(t *testing.T) { expect := authorizer.EXPECT() expect.Decoder().Return(schema.NewDecoder()) - expect.Encoder().Return(schema.NewEncoder()) + expect.Logger().Return(slog.Default()) if tt.expect != nil { tt.expect(expect) diff --git a/pkg/op/device.go b/pkg/op/device.go index 09c7fca1..029bed8a 100644 --- a/pkg/op/device.go +++ b/pkg/op/device.go @@ -57,7 +57,7 @@ var ( func DeviceAuthorizationHandler(o OpenIDProvider) func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { if err := DeviceAuthorization(w, r, o); err != nil { - RequestError(w, r, err) + RequestError(w, r, err, o.Logger()) } } } @@ -190,7 +190,7 @@ func (r *deviceAccessTokenRequest) GetScopes() []string { func DeviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { if err := deviceAccessToken(w, r, exchanger); err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) } } diff --git a/pkg/op/error.go b/pkg/op/error.go index b2d84ae1..9981fecc 100644 --- a/pkg/op/error.go +++ b/pkg/op/error.go @@ -5,6 +5,7 @@ import ( httphelper "github.com/zitadel/oidc/v3/pkg/http" "github.com/zitadel/oidc/v3/pkg/oidc" + "golang.org/x/exp/slog" ) type ErrAuthRequest interface { @@ -13,13 +14,31 @@ type ErrAuthRequest interface { GetState() string } -func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthRequest, err error, encoder httphelper.Encoder) { +// LogAuthRequest is an optional interface, +// that allows logging AuthRequest fields. +// If the AuthRequest does not implement this interface, +// no details shall be printed to the logs. +type LogAuthRequest interface { + ErrAuthRequest + slog.LogValuer +} + +func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthRequest, err error, authorizer Authorizer) { + e := oidc.DefaultToServerError(err, err.Error()) + logger := authorizer.Logger().With("oidc_error", e) + if authReq == nil { + logger.Log(r.Context(), e.LogLevel(), "auth request") http.Error(w, err.Error(), http.StatusBadRequest) return } - e := oidc.DefaultToServerError(err, err.Error()) + + if logAuthReq, ok := authReq.(LogAuthRequest); ok { + logger = logger.With("auth_request", logAuthReq) + } + if authReq.GetRedirectURI() == "" || e.IsRedirectDisabled() { + logger.Log(r.Context(), e.LogLevel(), "auth request: not redirecting") http.Error(w, e.Description, http.StatusBadRequest) return } @@ -28,19 +47,22 @@ func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthReq if rm, ok := authReq.(interface{ GetResponseMode() oidc.ResponseMode }); ok { responseMode = rm.GetResponseMode() } - url, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), responseMode, e, encoder) + url, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), responseMode, e, authorizer.Encoder()) if err != nil { + logger.ErrorContext(r.Context(), "auth response URL", "error", err) http.Error(w, err.Error(), http.StatusBadRequest) return } + logger.Log(r.Context(), e.LogLevel(), "auth request") http.Redirect(w, r, url, http.StatusFound) } -func RequestError(w http.ResponseWriter, r *http.Request, err error) { +func RequestError(w http.ResponseWriter, r *http.Request, err error, logger *slog.Logger) { e := oidc.DefaultToServerError(err, err.Error()) status := http.StatusBadRequest if e.ErrorType == oidc.InvalidClient { - status = 401 + status = http.StatusUnauthorized } + logger.Log(r.Context(), e.LogLevel(), "request error", "oidc_error", e) httphelper.MarshalJSONWithStatus(w, e, status) } diff --git a/pkg/op/error_test.go b/pkg/op/error_test.go new file mode 100644 index 00000000..dc5ef110 --- /dev/null +++ b/pkg/op/error_test.go @@ -0,0 +1,277 @@ +package op + +import ( + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/schema" + "golang.org/x/exp/slog" +) + +func TestAuthRequestError(t *testing.T) { + type args struct { + authReq ErrAuthRequest + err error + } + tests := []struct { + name string + args args + wantCode int + wantHeaders map[string]string + wantBody string + wantLog string + }{ + { + name: "nil auth request", + args: args{ + authReq: nil, + err: io.ErrClosedPipe, + }, + wantCode: http.StatusBadRequest, + wantBody: "io: read/write on closed pipe\n", + wantLog: `{ + "level":"ERROR", + "msg":"auth request", + "time":"not", + "oidc_error":{ + "description":"io: read/write on closed pipe", + "parent":"io: read/write on closed pipe", + "type":"server_error" + } + }`, + }, + { + name: "auth request, no redirect URI", + args: args{ + authReq: &oidc.AuthRequest{ + Scopes: oidc.SpaceDelimitedArray{"a", "b"}, + ResponseType: "responseType", + ClientID: "123", + State: "state1", + ResponseMode: oidc.ResponseModeQuery, + }, + err: oidc.ErrInteractionRequired().WithDescription("sign in"), + }, + wantCode: http.StatusBadRequest, + wantBody: "sign in\n", + wantLog: `{ + "level":"WARN", + "msg":"auth request: not redirecting", + "time":"not", + "auth_request":{ + "client_id":"123", + "redirect_uri":"", + "response_type":"responseType", + "scopes":"a b" + }, + "oidc_error":{ + "description":"sign in", + "type":"interaction_required" + } + }`, + }, + { + name: "auth request, redirect disabled", + args: args{ + authReq: &oidc.AuthRequest{ + Scopes: oidc.SpaceDelimitedArray{"a", "b"}, + ResponseType: "responseType", + ClientID: "123", + RedirectURI: "http://example.com/callback", + State: "state1", + ResponseMode: oidc.ResponseModeQuery, + }, + err: oidc.ErrInvalidRequestRedirectURI().WithDescription("oops"), + }, + wantCode: http.StatusBadRequest, + wantBody: "oops\n", + wantLog: `{ + "level":"WARN", + "msg":"auth request: not redirecting", + "time":"not", + "auth_request":{ + "client_id":"123", + "redirect_uri":"http://example.com/callback", + "response_type":"responseType", + "scopes":"a b" + }, + "oidc_error":{ + "description":"oops", + "type":"invalid_request", + "redirect_disabled":true + } + }`, + }, + { + name: "auth request, url parse error", + args: args{ + authReq: &oidc.AuthRequest{ + Scopes: oidc.SpaceDelimitedArray{"a", "b"}, + ResponseType: "responseType", + ClientID: "123", + RedirectURI: "can't parse this!\n", + State: "state1", + ResponseMode: oidc.ResponseModeQuery, + }, + err: oidc.ErrInteractionRequired().WithDescription("sign in"), + }, + wantCode: http.StatusBadRequest, + wantBody: "ErrorType=server_error Parent=parse \"can't parse this!\\n\": net/url: invalid control character in URL\n", + wantLog: `{ + "level":"ERROR", + "msg":"auth response URL", + "time":"not", + "auth_request":{ + "client_id":"123", + "redirect_uri":"can't parse this!\n", + "response_type":"responseType", + "scopes":"a b" + }, + "error":{ + "type":"server_error", + "parent":"parse \"can't parse this!\\n\": net/url: invalid control character in URL" + }, + "oidc_error":{ + "description":"sign in", + "type":"interaction_required" + } + }`, + }, + { + name: "auth request redirect", + args: args{ + authReq: &oidc.AuthRequest{ + Scopes: oidc.SpaceDelimitedArray{"a", "b"}, + ResponseType: "responseType", + ClientID: "123", + RedirectURI: "http://example.com/callback", + State: "state1", + ResponseMode: oidc.ResponseModeQuery, + }, + err: oidc.ErrInteractionRequired().WithDescription("sign in"), + }, + wantCode: http.StatusFound, + wantHeaders: map[string]string{"Location": "http://example.com/callback?error=interaction_required&error_description=sign+in&state=state1"}, + wantLog: `{ + "level":"WARN", + "msg":"auth request", + "time":"not", + "auth_request":{ + "client_id":"123", + "redirect_uri":"http://example.com/callback", + "response_type":"responseType", + "scopes":"a b" + }, + "oidc_error":{ + "description":"sign in", + "type":"interaction_required" + } + }`, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + logOut := new(strings.Builder) + authorizer := &Provider{ + encoder: schema.NewEncoder(), + logger: slog.New( + slog.NewJSONHandler(logOut, &slog.HandlerOptions{ + Level: slog.LevelInfo, + }).WithAttrs([]slog.Attr{slog.String("time", "not")}), + ), + } + + w := httptest.NewRecorder() + r := httptest.NewRequest("POST", "/path", nil) + AuthRequestError(w, r, tt.args.authReq, tt.args.err, authorizer) + + res := w.Result() + defer res.Body.Close() + + assert.Equal(t, tt.wantCode, res.StatusCode) + for key, wantHeader := range tt.wantHeaders { + gotHeader := res.Header.Get(key) + assert.Equalf(t, wantHeader, gotHeader, "header %q", key) + } + gotBody, err := io.ReadAll(res.Body) + require.NoError(t, err, "read result body") + assert.Equal(t, tt.wantBody, string(gotBody), "result body") + + gotLog := logOut.String() + t.Log(gotLog) + assert.JSONEq(t, tt.wantLog, gotLog, "log output") + }) + } +} + +func TestRequestError(t *testing.T) { + tests := []struct { + name string + err error + wantCode int + wantBody string + wantLog string + }{ + { + name: "server error", + err: io.ErrClosedPipe, + wantCode: http.StatusBadRequest, + wantBody: `{"error":"server_error", "error_description":"io: read/write on closed pipe"}`, + wantLog: `{ + "level":"ERROR", + "msg":"request error", + "time":"not", + "oidc_error":{ + "parent":"io: read/write on closed pipe", + "description":"io: read/write on closed pipe", + "type":"server_error"} + }`, + }, + { + name: "invalid client", + err: oidc.ErrInvalidClient().WithDescription("not good"), + wantCode: http.StatusUnauthorized, + wantBody: `{"error":"invalid_client", "error_description":"not good"}`, + wantLog: `{ + "level":"WARN", + "msg":"request error", + "time":"not", + "oidc_error":{ + "description":"not good", + "type":"invalid_client"} + }`, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + logOut := new(strings.Builder) + logger := slog.New( + slog.NewJSONHandler(logOut, &slog.HandlerOptions{ + Level: slog.LevelInfo, + }).WithAttrs([]slog.Attr{slog.String("time", "not")}), + ) + w := httptest.NewRecorder() + r := httptest.NewRequest("POST", "/path", nil) + RequestError(w, r, tt.err, logger) + + res := w.Result() + defer res.Body.Close() + + assert.Equal(t, tt.wantCode, res.StatusCode, "status code") + + gotBody, err := io.ReadAll(res.Body) + require.NoError(t, err, "read result body") + assert.JSONEq(t, tt.wantBody, string(gotBody), "result body") + + gotLog := logOut.String() + t.Log(gotLog) + assert.JSONEq(t, tt.wantLog, gotLog, "log output") + }) + } +} diff --git a/pkg/op/mock/authorizer.mock.go b/pkg/op/mock/authorizer.mock.go index a0c67e3d..e4297cb8 100644 --- a/pkg/op/mock/authorizer.mock.go +++ b/pkg/op/mock/authorizer.mock.go @@ -11,6 +11,7 @@ import ( gomock "github.com/golang/mock/gomock" http "github.com/zitadel/oidc/v3/pkg/http" op "github.com/zitadel/oidc/v3/pkg/op" + slog "golang.org/x/exp/slog" ) // MockAuthorizer is a mock of Authorizer interface. @@ -92,6 +93,20 @@ func (mr *MockAuthorizerMockRecorder) IDTokenHintVerifier(arg0 interface{}) *gom return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IDTokenHintVerifier", reflect.TypeOf((*MockAuthorizer)(nil).IDTokenHintVerifier), arg0) } +// Logger mocks base method. +func (m *MockAuthorizer) Logger() *slog.Logger { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Logger") + ret0, _ := ret[0].(*slog.Logger) + return ret0 +} + +// Logger indicates an expected call of Logger. +func (mr *MockAuthorizerMockRecorder) Logger() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Logger", reflect.TypeOf((*MockAuthorizer)(nil).Logger)) +} + // RequestObjectSupported mocks base method. func (m *MockAuthorizer) RequestObjectSupported() bool { m.ctrl.T.Helper() diff --git a/pkg/op/op.go b/pkg/op/op.go index 1fbe7801..d8ae570b 100644 --- a/pkg/op/op.go +++ b/pkg/op/op.go @@ -9,6 +9,7 @@ import ( "github.com/go-chi/chi" "github.com/rs/cors" "github.com/zitadel/schema" + "golang.org/x/exp/slog" "golang.org/x/text/language" "gopkg.in/square/go-jose.v2" @@ -79,6 +80,9 @@ type OpenIDProvider interface { DefaultLogoutRedirectURI() string Probes() []ProbesFn + // EXPERIMENTAL: Will change to log/slog import after we drop support for Go 1.20 + Logger() *slog.Logger + // Deprecated: Provider now implements http.Handler directly. HttpHandler() http.Handler } @@ -174,6 +178,7 @@ func newProvider(config *Config, storage Storage, issuer func(bool) (IssuerFromR storage: storage, endpoints: DefaultEndpoints, timer: make(<-chan time.Time), + logger: slog.Default(), } for _, optFunc := range opOpts { @@ -217,6 +222,7 @@ type Provider struct { timer <-chan time.Time accessTokenVerifierOpts []AccessTokenVerifierOpt idTokenHintVerifierOpts []IDTokenHintVerifierOpt + logger *slog.Logger } func (o *Provider) IssuerFromRequest(r *http.Request) string { @@ -375,6 +381,10 @@ func (o *Provider) Probes() []ProbesFn { } } +func (o *Provider) Logger() *slog.Logger { + return o.logger +} + // Deprecated: Provider now implements http.Handler directly. func (o *Provider) HttpHandler() http.Handler { return o @@ -523,6 +533,16 @@ func WithIDTokenHintVerifierOpts(opts ...IDTokenHintVerifierOpt) Option { } } +// WithLogger lets a logger other than slog.Default(). +// +// EXPERIMENTAL: Will change to log/slog import after we drop support for Go 1.20 +func WithLogger(logger *slog.Logger) Option { + return func(o *Provider) error { + o.logger = logger + return nil + } +} + func intercept(i IssuerFromRequest, interceptors ...HttpInterceptor) func(handler http.Handler) http.Handler { issuerInterceptor := NewIssuerInterceptor(i) return func(handler http.Handler) http.Handler { diff --git a/pkg/op/op_test.go b/pkg/op/op_test.go index d347d048..d33b39d5 100644 --- a/pkg/op/op_test.go +++ b/pkg/op/op_test.go @@ -156,7 +156,7 @@ func TestRoutes(t *testing.T) { values: map[string]string{ "client_id": client.GetID(), "redirect_uri": "https://example.com", - "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(), + "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(), "response_type": string(oidc.ResponseTypeCode), }, wantCode: http.StatusFound, @@ -193,7 +193,7 @@ func TestRoutes(t *testing.T) { path: testProvider.TokenEndpoint().Relative(), values: map[string]string{ "grant_type": string(oidc.GrantTypeBearer), - "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(), + "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(), "assertion": jwtToken, }, wantCode: http.StatusBadRequest, @@ -206,7 +206,7 @@ func TestRoutes(t *testing.T) { basicAuth: &basicAuth{"web", "secret"}, values: map[string]string{ "grant_type": string(oidc.GrantTypeTokenExchange), - "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(), + "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(), "subject_token": jwtToken, "subject_token_type": string(oidc.AccessTokenType), }, @@ -223,7 +223,7 @@ func TestRoutes(t *testing.T) { basicAuth: &basicAuth{"sid1", "verysecret"}, values: map[string]string{ "grant_type": string(oidc.GrantTypeClientCredentials), - "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(), + "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(), }, wantCode: http.StatusOK, contains: []string{`{"access_token":"`, `","token_type":"Bearer","expires_in":299}`}, @@ -338,7 +338,7 @@ func TestRoutes(t *testing.T) { path: testProvider.DeviceAuthorizationEndpoint().Relative(), basicAuth: &basicAuth{"web", "secret"}, values: map[string]string{ - "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(), + "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(), }, wantCode: http.StatusOK, contains: []string{ diff --git a/pkg/op/session.go b/pkg/op/session.go index fd914d11..2467b20f 100644 --- a/pkg/op/session.go +++ b/pkg/op/session.go @@ -8,6 +8,7 @@ import ( httphelper "github.com/zitadel/oidc/v3/pkg/http" "github.com/zitadel/oidc/v3/pkg/oidc" + "golang.org/x/exp/slog" ) type SessionEnder interface { @@ -15,6 +16,7 @@ type SessionEnder interface { Storage() Storage IDTokenHintVerifier(context.Context) *IDTokenHintVerifier DefaultLogoutRedirectURI() string + Logger() *slog.Logger } func endSessionHandler(ender SessionEnder) func(http.ResponseWriter, *http.Request) { @@ -31,12 +33,12 @@ func EndSession(w http.ResponseWriter, r *http.Request, ender SessionEnder) { } session, err := ValidateEndSessionRequest(r.Context(), req, ender) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, ender.Logger()) return } err = ender.Storage().TerminateSession(r.Context(), session.UserID, session.ClientID) if err != nil { - RequestError(w, r, oidc.DefaultToServerError(err, "error terminating session")) + RequestError(w, r, oidc.DefaultToServerError(err, "error terminating session"), ender.Logger()) return } http.Redirect(w, r, session.RedirectURI, http.StatusFound) diff --git a/pkg/op/token_client_credentials.go b/pkg/op/token_client_credentials.go index 0cf77961..043bb072 100644 --- a/pkg/op/token_client_credentials.go +++ b/pkg/op/token_client_credentials.go @@ -14,18 +14,18 @@ import ( func ClientCredentialsExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { request, err := ParseClientCredentialsRequest(r, exchanger.Decoder()) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) } validatedRequest, client, err := ValidateClientCredentialsRequest(r.Context(), request, exchanger) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) return } resp, err := CreateClientCredentialsTokenResponse(r.Context(), validatedRequest, exchanger, client) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) return } diff --git a/pkg/op/token_code.go b/pkg/op/token_code.go index b5e892af..baf377bc 100644 --- a/pkg/op/token_code.go +++ b/pkg/op/token_code.go @@ -13,20 +13,20 @@ import ( func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { tokenReq, err := ParseAccessTokenRequest(r, exchanger.Decoder()) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) } if tokenReq.Code == "" { - RequestError(w, r, oidc.ErrInvalidRequest().WithDescription("code missing")) + RequestError(w, r, oidc.ErrInvalidRequest().WithDescription("code missing"), exchanger.Logger()) return } authReq, client, err := ValidateAccessTokenRequest(r.Context(), tokenReq, exchanger) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) return } resp, err := CreateTokenResponse(r.Context(), authReq, client, exchanger, true, tokenReq.Code, "") if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) return } httphelper.MarshalJSON(w, resp) diff --git a/pkg/op/token_exchange.go b/pkg/op/token_exchange.go index 93aa9b24..21db1347 100644 --- a/pkg/op/token_exchange.go +++ b/pkg/op/token_exchange.go @@ -136,17 +136,17 @@ func (r *tokenExchangeRequest) SetSubject(subject string) { func TokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { tokenExchangeReq, clientID, clientSecret, err := ParseTokenExchangeRequest(r, exchanger.Decoder()) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) } tokenExchangeRequest, client, err := ValidateTokenExchangeRequest(r.Context(), tokenExchangeReq, clientID, clientSecret, exchanger) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) return } resp, err := CreateTokenExchangeResponse(r.Context(), tokenExchangeRequest, client, exchanger) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) return } httphelper.MarshalJSON(w, resp) diff --git a/pkg/op/token_jwt_profile.go b/pkg/op/token_jwt_profile.go index 4cd7b1e4..357200ee 100644 --- a/pkg/op/token_jwt_profile.go +++ b/pkg/op/token_jwt_profile.go @@ -18,23 +18,23 @@ type JWTAuthorizationGrantExchanger interface { func JWTProfile(w http.ResponseWriter, r *http.Request, exchanger JWTAuthorizationGrantExchanger) { profileRequest, err := ParseJWTProfileGrantRequest(r, exchanger.Decoder()) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) } tokenRequest, err := VerifyJWTAssertion(r.Context(), profileRequest.Assertion, exchanger.JWTProfileVerifier(r.Context())) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) return } tokenRequest.Scopes, err = exchanger.Storage().ValidateJWTProfileScopes(r.Context(), tokenRequest.Issuer, profileRequest.Scope) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) return } resp, err := CreateJWTTokenResponse(r.Context(), tokenRequest, exchanger) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) return } httphelper.MarshalJSON(w, resp) diff --git a/pkg/op/token_refresh.go b/pkg/op/token_refresh.go index aeaa5b4b..9421033f 100644 --- a/pkg/op/token_refresh.go +++ b/pkg/op/token_refresh.go @@ -26,16 +26,16 @@ type RefreshTokenRequest interface { func RefreshTokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { tokenReq, err := ParseRefreshTokenRequest(r, exchanger.Decoder()) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) } validatedRequest, client, err := ValidateRefreshTokenRequest(r.Context(), tokenReq, exchanger) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) return } resp, err := CreateTokenResponse(r.Context(), validatedRequest, client, exchanger, true, "", tokenReq.RefreshToken) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) return } httphelper.MarshalJSON(w, resp) diff --git a/pkg/op/token_request.go b/pkg/op/token_request.go index c06a51bc..0df2fcee 100644 --- a/pkg/op/token_request.go +++ b/pkg/op/token_request.go @@ -7,6 +7,7 @@ import ( httphelper "github.com/zitadel/oidc/v3/pkg/http" "github.com/zitadel/oidc/v3/pkg/oidc" + "golang.org/x/exp/slog" ) type Exchanger interface { @@ -22,6 +23,7 @@ type Exchanger interface { GrantTypeDeviceCodeSupported() bool AccessTokenVerifier(context.Context) *AccessTokenVerifier IDTokenHintVerifier(context.Context) *IDTokenHintVerifier + Logger() *slog.Logger } func tokenHandler(exchanger Exchanger) func(w http.ResponseWriter, r *http.Request) { @@ -63,10 +65,10 @@ func Exchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { return } case "": - RequestError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing")) + RequestError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing"), exchanger.Logger()) return } - RequestError(w, r, oidc.ErrUnsupportedGrantType().WithDescription("%s not supported", grantType)) + RequestError(w, r, oidc.ErrUnsupportedGrantType().WithDescription("%s not supported", grantType), exchanger.Logger()) } // AuthenticatedTokenRequest is a helper interface for ParseAuthenticatedTokenRequest