diff --git a/example/client/app/app.go b/example/client/app/app.go index 2cb5dfa7..0e339f40 100644 --- a/example/client/app/app.go +++ b/example/client/app/app.go @@ -7,11 +7,14 @@ import ( "net/http" "os" "strings" + "sync/atomic" "time" "github.com/google/uuid" "github.com/sirupsen/logrus" + "golang.org/x/exp/slog" + "github.com/zitadel/logging" "github.com/zitadel/oidc/v3/pkg/client/rp" httphelper "github.com/zitadel/oidc/v3/pkg/http" "github.com/zitadel/oidc/v3/pkg/oidc" @@ -33,9 +36,25 @@ func main() { redirectURI := fmt.Sprintf("http://localhost:%v%v", port, callbackPath) cookieHandler := httphelper.NewCookieHandler(key, key, httphelper.WithUnsecure()) + logger := slog.New( + slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ + AddSource: true, + Level: slog.LevelDebug, + }), + ) + client := &http.Client{ + Timeout: time.Minute, + } + // enable outgoing request logging + logging.EnableHTTPClient(client, + logging.WithClientGroup("client"), + ) + options := []rp.Option{ rp.WithCookieHandler(cookieHandler), rp.WithVerifierOpts(rp.WithIssuedAtOffset(5 * time.Second)), + rp.WithHTTPClient(client), + rp.WithLogger(logger), } if clientSecret == "" { options = append(options, rp.WithPKCE(cookieHandler)) @@ -44,7 +63,10 @@ func main() { options = append(options, rp.WithJWTProfile(rp.SignerFromKeyPath(keyPath))) } - provider, err := rp.NewRelyingPartyOIDC(context.TODO(), issuer, clientID, clientSecret, redirectURI, scopes, options...) + // One can add a logger to the context, + // pre-defining log attributes as required. + ctx := logging.ToContext(context.TODO(), logger) + provider, err := rp.NewRelyingPartyOIDC(ctx, issuer, clientID, clientSecret, redirectURI, scopes, options...) if err != nil { logrus.Fatalf("error creating provider %s", err.Error()) } @@ -119,8 +141,22 @@ func main() { // // http.Handle(callbackPath, rp.CodeExchangeHandler(marshalToken, provider)) + // simple counter for request IDs + var counter atomic.Int64 + // enable incomming request logging + mw := logging.Middleware( + logging.WithLogger(logger), + logging.WithGroup("server"), + logging.WithIDFunc(func() slog.Attr { + return slog.Int64("id", counter.Add(1)) + }), + ) + lis := fmt.Sprintf("127.0.0.1:%s", port) - logrus.Infof("listening on http://%s/", lis) - logrus.Info("press ctrl+c to stop") - logrus.Fatal(http.ListenAndServe(lis, nil)) + logger.Info("server listening, press ctrl+c to stop", "addr", lis) + err = http.ListenAndServe(lis, mw(http.DefaultServeMux)) + if err != http.ErrServerClosed { + logger.Error("server terminated", "error", err) + os.Exit(1) + } } diff --git a/example/server/exampleop/op.go b/example/server/exampleop/op.go index 57db020e..b5ee7b37 100644 --- a/example/server/exampleop/op.go +++ b/example/server/exampleop/op.go @@ -4,9 +4,11 @@ import ( "crypto/sha256" "log" "net/http" + "sync/atomic" "time" "github.com/go-chi/chi" + "github.com/zitadel/logging" "golang.org/x/exp/slog" "golang.org/x/text/language" @@ -32,6 +34,9 @@ type Storage interface { deviceAuthenticate } +// simple counter for request IDs +var counter atomic.Int64 + // SetupServer creates an OIDC server with Issuer=http://localhost: // // Use one of the pre-made clients in storage/clients.go or register a new one. @@ -41,6 +46,12 @@ func SetupServer(issuer string, storage Storage, logger *slog.Logger) chi.Router key := sha256.Sum256([]byte("test")) router := chi.NewRouter() + router.Use(logging.Middleware( + logging.WithLogger(logger), + logging.WithIDFunc(func() slog.Attr { + return slog.Int64("id", counter.Add(1)) + }), + )) // for simplicity, we provide a very small default page for users who have signed out router.HandleFunc(pathLoggedOut, func(w http.ResponseWriter, req *http.Request) { diff --git a/example/server/main.go b/example/server/main.go index ee8422bb..a1cc4618 100644 --- a/example/server/main.go +++ b/example/server/main.go @@ -2,7 +2,6 @@ package main import ( "fmt" - "log" "net/http" "os" @@ -22,10 +21,6 @@ func main() { // in this example it will be handled in-memory storage := storage.NewStorage(storage.NewUserStore(issuer)) - // Using our wrapped logging handler, - // data set to the context gets printed - // as part of the log output. - // This helps us tie log output to requests. logger := slog.New( slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ AddSource: true, @@ -40,7 +35,8 @@ func main() { } logger.Info("server listening, press ctrl+c to stop", "addr", fmt.Sprintf("http://localhost:%s/", port)) err := server.ListenAndServe() - if err != nil { - log.Fatal(err) + if err != http.ErrServerClosed { + logger.Error("server terminated", "error", err) + os.Exit(1) } } diff --git a/go.mod b/go.mod index 646829ef..c07f7f78 100644 --- a/go.mod +++ b/go.mod @@ -13,6 +13,7 @@ require ( github.com/rs/cors v1.9.0 github.com/sirupsen/logrus v1.9.3 github.com/stretchr/testify v1.8.2 + github.com/zitadel/logging v0.3.5-0.20230828081740-9d6abec32b43 github.com/zitadel/schema v1.3.0 golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 golang.org/x/oauth2 v0.7.0 @@ -21,6 +22,7 @@ require ( ) require ( + github.com/benbjohnson/clock v1.3.5 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/golang/protobuf v1.5.3 // indirect github.com/google/go-cmp v0.5.9 // indirect diff --git a/go.sum b/go.sum index 4b6e60df..e804b67e 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/benbjohnson/clock v1.3.5 h1:VvXlSJBzZpA/zum6Sj74hxwYI2DIxRWuNIoXAzHZz5o= +github.com/benbjohnson/clock v1.3.5/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -47,6 +49,8 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= +github.com/zitadel/logging v0.3.5-0.20230828081740-9d6abec32b43 h1:m89ASp88slzDla9BDfAFQZRFKFI2ywhJkZRqTEBPIuI= +github.com/zitadel/logging v0.3.5-0.20230828081740-9d6abec32b43/go.mod h1:WHfGs2W60PHHXOGRZDgAKMqLmtQgIgW5w6nCyR6tX5U= github.com/zitadel/schema v1.3.0 h1:kQ9W9tvIwZICCKWcMvCEweXET1OcOyGEuFbHs4o5kg0= github.com/zitadel/schema v1.3.0/go.mod h1:NptN6mkBDFvERUCvZHlvWmmME+gmZ44xzwRXwhzsbtc= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= @@ -102,6 +106,7 @@ gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntN gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/square/go-jose.v2 v2.6.0 h1:NGk74WTnPKBNUhNzQX7PYcTLUjoq7mzKk2OKbvwk2iI= gopkg.in/square/go-jose.v2 v2.6.0/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI= +gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/pkg/client/client.go b/pkg/client/client.go index e3efd611..7b76dfd6 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -14,6 +14,7 @@ import ( "golang.org/x/oauth2" "gopkg.in/square/go-jose.v2" + "github.com/zitadel/logging" "github.com/zitadel/oidc/v3/pkg/crypto" httphelper "github.com/zitadel/oidc/v3/pkg/http" "github.com/zitadel/oidc/v3/pkg/oidc" @@ -37,6 +38,10 @@ func Discover(ctx context.Context, issuer string, httpClient *http.Client, wellK if err != nil { return nil, err } + if logger, ok := logging.FromContext(ctx); ok { + logger.Debug("discover", "config", discoveryConfig) + } + if discoveryConfig.Issuer != issuer { return nil, oidc.ErrIssuerInvalid } diff --git a/pkg/client/rp/device.go b/pkg/client/rp/device.go index 390c8cf4..02c647e3 100644 --- a/pkg/client/rp/device.go +++ b/pkg/client/rp/device.go @@ -33,6 +33,7 @@ func newDeviceClientCredentialsRequest(scopes []string, rp RelyingParty) (*oidc. // in RFC 8628, section 3.1 and 3.2: // https://www.rfc-editor.org/rfc/rfc8628#section-3.1 func DeviceAuthorization(ctx context.Context, scopes []string, rp RelyingParty, authFn any) (*oidc.DeviceAuthorizationResponse, error) { + ctx = logCtxWithRPData(ctx, rp, "function", "DeviceAuthorization") req, err := newDeviceClientCredentialsRequest(scopes, rp) if err != nil { return nil, err @@ -45,6 +46,7 @@ func DeviceAuthorization(ctx context.Context, scopes []string, rp RelyingParty, // by means of polling as defined in RFC, section 3.3 and 3.4: // https://www.rfc-editor.org/rfc/rfc8628#section-3.4 func DeviceAccessToken(ctx context.Context, deviceCode string, interval time.Duration, rp RelyingParty) (resp *oidc.AccessTokenResponse, err error) { + ctx = logCtxWithRPData(ctx, rp, "function", "DeviceAccessToken") req := &client.DeviceAccessTokenRequest{ DeviceAccessTokenRequest: oidc.DeviceAccessTokenRequest{ GrantType: oidc.GrantTypeDeviceCode, diff --git a/pkg/client/rp/log.go b/pkg/client/rp/log.go new file mode 100644 index 00000000..6056fa2e --- /dev/null +++ b/pkg/client/rp/log.go @@ -0,0 +1,17 @@ +package rp + +import ( + "context" + + "github.com/zitadel/logging" + "golang.org/x/exp/slog" +) + +func logCtxWithRPData(ctx context.Context, rp RelyingParty, attrs ...any) context.Context { + logger, ok := rp.Logger(ctx) + if !ok { + return ctx + } + logger = logger.With(slog.Group("rp", attrs...)) + return logging.ToContext(ctx, logger) +} diff --git a/pkg/client/rp/relying_party.go b/pkg/client/rp/relying_party.go index 29215a1a..34cdb397 100644 --- a/pkg/client/rp/relying_party.go +++ b/pkg/client/rp/relying_party.go @@ -10,6 +10,8 @@ import ( "time" "github.com/google/uuid" + "github.com/zitadel/logging" + "golang.org/x/exp/slog" "golang.org/x/oauth2" "gopkg.in/square/go-jose.v2" @@ -67,6 +69,9 @@ type RelyingParty interface { // ErrorHandler returns the handler used for callback errors ErrorHandler() func(http.ResponseWriter, *http.Request, string, string, string) + + // Logger from the context, or a fallback if set. + Logger(context.Context) (logger *slog.Logger, ok bool) } type ErrorHandler func(w http.ResponseWriter, r *http.Request, errorType string, errorDesc string, state string) @@ -90,6 +95,7 @@ type relyingParty struct { idTokenVerifier *IDTokenVerifier verifierOpts []VerifierOption signer jose.Signer + logger *slog.Logger } func (rp *relyingParty) OAuthConfig() *oauth2.Config { @@ -150,6 +156,14 @@ func (rp *relyingParty) ErrorHandler() func(http.ResponseWriter, *http.Request, return rp.errorHandler } +func (rp *relyingParty) Logger(ctx context.Context) (logger *slog.Logger, ok bool) { + logger, ok = logging.FromContext(ctx) + if ok { + return logger, ok + } + return rp.logger, rp.logger != nil +} + // NewRelyingPartyOAuth creates an (OAuth2) RelyingParty with the given // OAuth2 Config and possible configOptions // it will use the AuthURL and TokenURL set in config @@ -194,6 +208,7 @@ func NewRelyingPartyOIDC(ctx context.Context, issuer, clientID, clientSecret, re return nil, err } } + ctx = logCtxWithRPData(ctx, rp, "function", "NewRelyingPartyOIDC") discoveryConfiguration, err := client.Discover(ctx, rp.issuer, rp.httpClient, rp.DiscoveryEndpoint) if err != nil { return nil, err @@ -281,6 +296,15 @@ func WithJWTProfile(signerFromKey SignerFromKey) Option { } } +// WithLogger sets a logger that is used +// in case the request context does not contain a logger. +func WithLogger(logger *slog.Logger) Option { + return func(rp *relyingParty) error { + rp.logger = logger + return nil + } +} + type SignerFromKey func() (jose.Signer, error) func SignerFromKeyPath(path string) SignerFromKey { @@ -378,6 +402,7 @@ func verifyTokenResponse[C oidc.IDClaims](ctx context.Context, token *oauth2.Tok // CodeExchange handles the oauth2 code exchange, extracting and validating the id_token // returning it parsed together with the oauth2 tokens (access, refresh) func CodeExchange[C oidc.IDClaims](ctx context.Context, code string, rp RelyingParty, opts ...CodeExchangeOpt) (tokens *oidc.Tokens[C], err error) { + ctx = logCtxWithRPData(ctx, rp, "function", "CodeExchange") ctx = context.WithValue(ctx, oauth2.HTTPClient, rp.HttpClient()) codeOpts := make([]oauth2.AuthCodeOption, 0) for _, opt := range opts { @@ -467,6 +492,7 @@ func UserinfoCallback[C oidc.IDClaims, U SubjectGetter](f CodeExchangeUserinfoCa // [UserInfo]: https://openid.net/specs/openid-connect-core-1_0.html#UserInfo func Userinfo[U SubjectGetter](ctx context.Context, token, tokenType, subject string, rp RelyingParty) (userinfo U, err error) { var nilU U + ctx = logCtxWithRPData(ctx, rp, "function", "Userinfo") req, err := http.NewRequestWithContext(ctx, http.MethodGet, rp.UserinfoEndpoint(), nil) if err != nil { @@ -621,6 +647,7 @@ type RefreshTokenRequest struct { // the IDToken and AccessToken will be verfied // and the IDToken and IDTokenClaims fields will be populated in the returned object. func RefreshTokens[C oidc.IDClaims](ctx context.Context, rp RelyingParty, refreshToken, clientAssertion, clientAssertionType string) (*oidc.Tokens[C], error) { + ctx = logCtxWithRPData(ctx, rp, "function", "RefreshTokens") request := RefreshTokenRequest{ RefreshToken: refreshToken, Scopes: rp.OAuthConfig().Scopes, @@ -644,6 +671,7 @@ func RefreshTokens[C oidc.IDClaims](ctx context.Context, rp RelyingParty, refres } func EndSession(ctx context.Context, rp RelyingParty, idToken, optionalRedirectURI, optionalState string) (*url.URL, error) { + ctx = logCtxWithRPData(ctx, rp, "function", "EndSession") request := oidc.EndSessionRequest{ IdTokenHint: idToken, ClientID: rp.OAuthConfig().ClientID, @@ -659,6 +687,7 @@ func EndSession(ctx context.Context, rp RelyingParty, idToken, optionalRedirectU // // tokenTypeHint should be either "id_token" or "refresh_token". func RevokeToken(ctx context.Context, rp RelyingParty, token string, tokenTypeHint string) error { + ctx = logCtxWithRPData(ctx, rp, "function", "RevokeToken") request := client.RevokeRequest{ Token: token, TokenTypeHint: tokenTypeHint,