Skip to content

Commit

Permalink
finish the RP and example
Browse files Browse the repository at this point in the history
  • Loading branch information
muhlemmer committed Aug 28, 2023
1 parent dd9d8f2 commit 4005cb9
Show file tree
Hide file tree
Showing 9 changed files with 114 additions and 11 deletions.
44 changes: 40 additions & 4 deletions example/client/app/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,14 @@ import (
"net/http"
"os"
"strings"
"sync/atomic"
"time"

"github.com/google/uuid"
"github.com/sirupsen/logrus"
"golang.org/x/exp/slog"

"github.com/zitadel/logging"
"github.com/zitadel/oidc/v3/pkg/client/rp"
httphelper "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v3/pkg/oidc"
Expand All @@ -33,9 +36,25 @@ func main() {
redirectURI := fmt.Sprintf("http://localhost:%v%v", port, callbackPath)
cookieHandler := httphelper.NewCookieHandler(key, key, httphelper.WithUnsecure())

logger := slog.New(
slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
AddSource: true,
Level: slog.LevelDebug,
}),
)
client := &http.Client{
Timeout: time.Minute,
}
// enable outgoing request logging
logging.EnableHTTPClient(client,
logging.WithClientGroup("client"),
)

options := []rp.Option{
rp.WithCookieHandler(cookieHandler),
rp.WithVerifierOpts(rp.WithIssuedAtOffset(5 * time.Second)),
rp.WithHTTPClient(client),
rp.WithLogger(logger),
}
if clientSecret == "" {
options = append(options, rp.WithPKCE(cookieHandler))
Expand All @@ -44,7 +63,10 @@ func main() {
options = append(options, rp.WithJWTProfile(rp.SignerFromKeyPath(keyPath)))
}

provider, err := rp.NewRelyingPartyOIDC(context.TODO(), issuer, clientID, clientSecret, redirectURI, scopes, options...)
// One can add a logger to the context,
// pre-defining log attributes as required.
ctx := logging.ToContext(context.TODO(), logger)
provider, err := rp.NewRelyingPartyOIDC(ctx, issuer, clientID, clientSecret, redirectURI, scopes, options...)
if err != nil {
logrus.Fatalf("error creating provider %s", err.Error())
}
Expand Down Expand Up @@ -119,8 +141,22 @@ func main() {
//
// http.Handle(callbackPath, rp.CodeExchangeHandler(marshalToken, provider))

// simple counter for request IDs
var counter atomic.Int64
// enable incomming request logging
mw := logging.Middleware(
logging.WithLogger(logger),
logging.WithGroup("server"),
logging.WithIDFunc(func() slog.Attr {
return slog.Int64("id", counter.Add(1))
}),
)

lis := fmt.Sprintf("127.0.0.1:%s", port)
logrus.Infof("listening on http://%s/", lis)
logrus.Info("press ctrl+c to stop")
logrus.Fatal(http.ListenAndServe(lis, nil))
logger.Info("server listening, press ctrl+c to stop", "addr", lis)
err = http.ListenAndServe(lis, mw(http.DefaultServeMux))
if err != http.ErrServerClosed {
logger.Error("server terminated", "error", err)
os.Exit(1)
}
}
11 changes: 11 additions & 0 deletions example/server/exampleop/op.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@ import (
"crypto/sha256"
"log"
"net/http"
"sync/atomic"
"time"

"github.com/go-chi/chi"
"github.com/zitadel/logging"
"golang.org/x/exp/slog"
"golang.org/x/text/language"

Expand All @@ -32,6 +34,9 @@ type Storage interface {
deviceAuthenticate
}

// simple counter for request IDs
var counter atomic.Int64

// SetupServer creates an OIDC server with Issuer=http://localhost:<port>
//
// Use one of the pre-made clients in storage/clients.go or register a new one.
Expand All @@ -41,6 +46,12 @@ func SetupServer(issuer string, storage Storage, logger *slog.Logger) chi.Router
key := sha256.Sum256([]byte("test"))

router := chi.NewRouter()
router.Use(logging.Middleware(
logging.WithLogger(logger),
logging.WithIDFunc(func() slog.Attr {
return slog.Int64("id", counter.Add(1))
}),
))

// for simplicity, we provide a very small default page for users who have signed out
router.HandleFunc(pathLoggedOut, func(w http.ResponseWriter, req *http.Request) {
Expand Down
10 changes: 3 additions & 7 deletions example/server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package main

import (
"fmt"
"log"
"net/http"
"os"

Expand All @@ -22,10 +21,6 @@ func main() {
// in this example it will be handled in-memory
storage := storage.NewStorage(storage.NewUserStore(issuer))

// Using our wrapped logging handler,
// data set to the context gets printed
// as part of the log output.
// This helps us tie log output to requests.
logger := slog.New(
slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
AddSource: true,
Expand All @@ -40,7 +35,8 @@ func main() {
}
logger.Info("server listening, press ctrl+c to stop", "addr", fmt.Sprintf("http://localhost:%s/", port))
err := server.ListenAndServe()
if err != nil {
log.Fatal(err)
if err != http.ErrServerClosed {
logger.Error("server terminated", "error", err)
os.Exit(1)
}
}
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ require (
github.com/rs/cors v1.9.0
github.com/sirupsen/logrus v1.9.3
github.com/stretchr/testify v1.8.2
github.com/zitadel/logging v0.3.5-0.20230828081740-9d6abec32b43
github.com/zitadel/schema v1.3.0
golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63
golang.org/x/oauth2 v0.7.0
Expand All @@ -21,6 +22,7 @@ require (
)

require (
github.com/benbjohnson/clock v1.3.5 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/golang/protobuf v1.5.3 // indirect
github.com/google/go-cmp v0.5.9 // indirect
Expand Down
5 changes: 5 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
github.com/benbjohnson/clock v1.3.5 h1:VvXlSJBzZpA/zum6Sj74hxwYI2DIxRWuNIoXAzHZz5o=
github.com/benbjohnson/clock v1.3.5/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
Expand Down Expand Up @@ -47,6 +49,8 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO
github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8=
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
github.com/zitadel/logging v0.3.5-0.20230828081740-9d6abec32b43 h1:m89ASp88slzDla9BDfAFQZRFKFI2ywhJkZRqTEBPIuI=
github.com/zitadel/logging v0.3.5-0.20230828081740-9d6abec32b43/go.mod h1:WHfGs2W60PHHXOGRZDgAKMqLmtQgIgW5w6nCyR6tX5U=
github.com/zitadel/schema v1.3.0 h1:kQ9W9tvIwZICCKWcMvCEweXET1OcOyGEuFbHs4o5kg0=
github.com/zitadel/schema v1.3.0/go.mod h1:NptN6mkBDFvERUCvZHlvWmmME+gmZ44xzwRXwhzsbtc=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
Expand Down Expand Up @@ -102,6 +106,7 @@ gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntN
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/square/go-jose.v2 v2.6.0 h1:NGk74WTnPKBNUhNzQX7PYcTLUjoq7mzKk2OKbvwk2iI=
gopkg.in/square/go-jose.v2 v2.6.0/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI=
gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
5 changes: 5 additions & 0 deletions pkg/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"golang.org/x/oauth2"
"gopkg.in/square/go-jose.v2"

"github.com/zitadel/logging"
"github.com/zitadel/oidc/v3/pkg/crypto"
httphelper "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v3/pkg/oidc"
Expand All @@ -37,6 +38,10 @@ func Discover(ctx context.Context, issuer string, httpClient *http.Client, wellK
if err != nil {
return nil, err
}
if logger, ok := logging.FromContext(ctx); ok {
logger.Debug("discover", "config", discoveryConfig)
}

Check warning on line 43 in pkg/client/client.go

View check run for this annotation

Codecov / codecov/patch

pkg/client/client.go#L42-L43

Added lines #L42 - L43 were not covered by tests

if discoveryConfig.Issuer != issuer {
return nil, oidc.ErrIssuerInvalid
}
Expand Down
2 changes: 2 additions & 0 deletions pkg/client/rp/device.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ func newDeviceClientCredentialsRequest(scopes []string, rp RelyingParty) (*oidc.
// in RFC 8628, section 3.1 and 3.2:
// https://www.rfc-editor.org/rfc/rfc8628#section-3.1
func DeviceAuthorization(ctx context.Context, scopes []string, rp RelyingParty, authFn any) (*oidc.DeviceAuthorizationResponse, error) {
ctx = logCtxWithRPData(ctx, rp, "function", "DeviceAuthorization")

Check warning on line 36 in pkg/client/rp/device.go

View check run for this annotation

Codecov / codecov/patch

pkg/client/rp/device.go#L36

Added line #L36 was not covered by tests
req, err := newDeviceClientCredentialsRequest(scopes, rp)
if err != nil {
return nil, err
Expand All @@ -45,6 +46,7 @@ func DeviceAuthorization(ctx context.Context, scopes []string, rp RelyingParty,
// by means of polling as defined in RFC, section 3.3 and 3.4:
// https://www.rfc-editor.org/rfc/rfc8628#section-3.4
func DeviceAccessToken(ctx context.Context, deviceCode string, interval time.Duration, rp RelyingParty) (resp *oidc.AccessTokenResponse, err error) {
ctx = logCtxWithRPData(ctx, rp, "function", "DeviceAccessToken")

Check warning on line 49 in pkg/client/rp/device.go

View check run for this annotation

Codecov / codecov/patch

pkg/client/rp/device.go#L49

Added line #L49 was not covered by tests
req := &client.DeviceAccessTokenRequest{
DeviceAccessTokenRequest: oidc.DeviceAccessTokenRequest{
GrantType: oidc.GrantTypeDeviceCode,
Expand Down
17 changes: 17 additions & 0 deletions pkg/client/rp/log.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package rp

import (
"context"

"github.com/zitadel/logging"
"golang.org/x/exp/slog"
)

func logCtxWithRPData(ctx context.Context, rp RelyingParty, attrs ...any) context.Context {
logger, ok := rp.Logger(ctx)
if !ok {
return ctx
}
logger = logger.With(slog.Group("rp", attrs...))
return logging.ToContext(ctx, logger)

Check warning on line 16 in pkg/client/rp/log.go

View check run for this annotation

Codecov / codecov/patch

pkg/client/rp/log.go#L15-L16

Added lines #L15 - L16 were not covered by tests
}
29 changes: 29 additions & 0 deletions pkg/client/rp/relying_party.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import (
"time"

"github.com/google/uuid"
"github.com/zitadel/logging"
"golang.org/x/exp/slog"
"golang.org/x/oauth2"
"gopkg.in/square/go-jose.v2"

Expand Down Expand Up @@ -67,6 +69,9 @@ type RelyingParty interface {
// ErrorHandler returns the handler used for callback errors

ErrorHandler() func(http.ResponseWriter, *http.Request, string, string, string)

// Logger from the context, or a fallback if set.
Logger(context.Context) (logger *slog.Logger, ok bool)
}

type ErrorHandler func(w http.ResponseWriter, r *http.Request, errorType string, errorDesc string, state string)
Expand All @@ -90,6 +95,7 @@ type relyingParty struct {
idTokenVerifier *IDTokenVerifier
verifierOpts []VerifierOption
signer jose.Signer
logger *slog.Logger
}

func (rp *relyingParty) OAuthConfig() *oauth2.Config {
Expand Down Expand Up @@ -150,6 +156,14 @@ func (rp *relyingParty) ErrorHandler() func(http.ResponseWriter, *http.Request,
return rp.errorHandler
}

func (rp *relyingParty) Logger(ctx context.Context) (logger *slog.Logger, ok bool) {
logger, ok = logging.FromContext(ctx)
if ok {
return logger, ok
}

Check warning on line 163 in pkg/client/rp/relying_party.go

View check run for this annotation

Codecov / codecov/patch

pkg/client/rp/relying_party.go#L162-L163

Added lines #L162 - L163 were not covered by tests
return rp.logger, rp.logger != nil
}

// NewRelyingPartyOAuth creates an (OAuth2) RelyingParty with the given
// OAuth2 Config and possible configOptions
// it will use the AuthURL and TokenURL set in config
Expand Down Expand Up @@ -194,6 +208,7 @@ func NewRelyingPartyOIDC(ctx context.Context, issuer, clientID, clientSecret, re
return nil, err
}
}
ctx = logCtxWithRPData(ctx, rp, "function", "NewRelyingPartyOIDC")
discoveryConfiguration, err := client.Discover(ctx, rp.issuer, rp.httpClient, rp.DiscoveryEndpoint)
if err != nil {
return nil, err
Expand Down Expand Up @@ -281,6 +296,15 @@ func WithJWTProfile(signerFromKey SignerFromKey) Option {
}
}

// WithLogger sets a logger that is used
// in case the request context does not contain a logger.
func WithLogger(logger *slog.Logger) Option {
return func(rp *relyingParty) error {
rp.logger = logger
return nil
}

Check warning on line 305 in pkg/client/rp/relying_party.go

View check run for this annotation

Codecov / codecov/patch

pkg/client/rp/relying_party.go#L301-L305

Added lines #L301 - L305 were not covered by tests
}

type SignerFromKey func() (jose.Signer, error)

func SignerFromKeyPath(path string) SignerFromKey {
Expand Down Expand Up @@ -378,6 +402,7 @@ func verifyTokenResponse[C oidc.IDClaims](ctx context.Context, token *oauth2.Tok
// CodeExchange handles the oauth2 code exchange, extracting and validating the id_token
// returning it parsed together with the oauth2 tokens (access, refresh)
func CodeExchange[C oidc.IDClaims](ctx context.Context, code string, rp RelyingParty, opts ...CodeExchangeOpt) (tokens *oidc.Tokens[C], err error) {
ctx = logCtxWithRPData(ctx, rp, "function", "CodeExchange")
ctx = context.WithValue(ctx, oauth2.HTTPClient, rp.HttpClient())
codeOpts := make([]oauth2.AuthCodeOption, 0)
for _, opt := range opts {
Expand Down Expand Up @@ -467,6 +492,7 @@ func UserinfoCallback[C oidc.IDClaims, U SubjectGetter](f CodeExchangeUserinfoCa
// [UserInfo]: https://openid.net/specs/openid-connect-core-1_0.html#UserInfo
func Userinfo[U SubjectGetter](ctx context.Context, token, tokenType, subject string, rp RelyingParty) (userinfo U, err error) {
var nilU U
ctx = logCtxWithRPData(ctx, rp, "function", "Userinfo")

req, err := http.NewRequestWithContext(ctx, http.MethodGet, rp.UserinfoEndpoint(), nil)
if err != nil {
Expand Down Expand Up @@ -621,6 +647,7 @@ type RefreshTokenRequest struct {
// the IDToken and AccessToken will be verfied
// and the IDToken and IDTokenClaims fields will be populated in the returned object.
func RefreshTokens[C oidc.IDClaims](ctx context.Context, rp RelyingParty, refreshToken, clientAssertion, clientAssertionType string) (*oidc.Tokens[C], error) {
ctx = logCtxWithRPData(ctx, rp, "function", "RefreshTokens")
request := RefreshTokenRequest{
RefreshToken: refreshToken,
Scopes: rp.OAuthConfig().Scopes,
Expand All @@ -644,6 +671,7 @@ func RefreshTokens[C oidc.IDClaims](ctx context.Context, rp RelyingParty, refres
}

func EndSession(ctx context.Context, rp RelyingParty, idToken, optionalRedirectURI, optionalState string) (*url.URL, error) {
ctx = logCtxWithRPData(ctx, rp, "function", "EndSession")
request := oidc.EndSessionRequest{
IdTokenHint: idToken,
ClientID: rp.OAuthConfig().ClientID,
Expand All @@ -659,6 +687,7 @@ func EndSession(ctx context.Context, rp RelyingParty, idToken, optionalRedirectU
//
// tokenTypeHint should be either "id_token" or "refresh_token".
func RevokeToken(ctx context.Context, rp RelyingParty, token string, tokenTypeHint string) error {
ctx = logCtxWithRPData(ctx, rp, "function", "RevokeToken")

Check warning on line 690 in pkg/client/rp/relying_party.go

View check run for this annotation

Codecov / codecov/patch

pkg/client/rp/relying_party.go#L690

Added line #L690 was not covered by tests
request := client.RevokeRequest{
Token: token,
TokenTypeHint: tokenTypeHint,
Expand Down

0 comments on commit 4005cb9

Please sign in to comment.