Skip to content

Commit

Permalink
feat: SSO MFA - Add SSO MFA ceremony (#46982)
Browse files Browse the repository at this point in the history
* Add SSOMFACeremony to MFA prompt; Add SSOMFACeremonyConstructor to MFA ceremony.

* Set SSO MFA ceremony from client configuration.

* Close sso mfa redirector once mfa ceremony is complete.

* Refactor sso mfa ceremony.

* Add SSOMFACeremonyConstructor.

* Add tests.

* Add --mfa-mode=sso support; Add cli prompt UX changes.

* Remove unused field, fix test.

* Resolve comments.

* Remove convoluted context closing logic for sso redirector.

* Fix test.

* Update lib/client/sso/ceremony.go

Co-authored-by: Alan Parra <alan.parra@goteleport.com>

---------

Co-authored-by: Alan Parra <alan.parra@goteleport.com>
  • Loading branch information
Joerger and codingllama authored Oct 29, 2024
1 parent e22aad3 commit 5e7bdb2
Show file tree
Hide file tree
Showing 14 changed files with 606 additions and 41 deletions.
8 changes: 8 additions & 0 deletions api/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -664,6 +664,9 @@ type Config struct {
// MFAPromptConstructor is used to create MFA prompts when needed.
// If nil, the client will not prompt for MFA.
MFAPromptConstructor mfa.PromptConstructor
// SSOMFACeremonyConstructor is used to handle SSO MFA when needed.
// If nil, the client will not prompt for MFA.
SSOMFACeremonyConstructor mfa.SSOMFACeremonyConstructor
}

// CheckAndSetDefaults checks and sets default config values.
Expand Down Expand Up @@ -730,6 +733,11 @@ func (c *Client) SetMFAPromptConstructor(pc mfa.PromptConstructor) {
c.c.MFAPromptConstructor = pc
}

// SetSSOMFACeremonyConstructor sets the SSO MFA ceremony constructor for this client.
func (c *Client) SetSSOMFACeremonyConstructor(scc mfa.SSOMFACeremonyConstructor) {
c.c.SSOMFACeremonyConstructor = scc
}

// Close closes the Client connection to the auth server.
func (c *Client) Close() error {
if c.setClosed() && c.conn != nil {
Expand Down
1 change: 1 addition & 0 deletions api/client/mfa.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ func (c *Client) PerformMFACeremony(ctx context.Context, challengeRequest *proto
mfaCeremony := &mfa.Ceremony{
CreateAuthenticateChallenge: c.CreateAuthenticateChallenge,
PromptConstructor: c.c.MFAPromptConstructor,
SSOMFACeremonyConstructor: c.c.SSOMFACeremonyConstructor,
}
return mfaCeremony.Run(ctx, challengeRequest, promptOpts...)
}
26 changes: 26 additions & 0 deletions api/mfa/ceremony.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,21 @@ type Ceremony struct {
CreateAuthenticateChallenge CreateAuthenticateChallengeFunc
// PromptConstructor creates a prompt to prompt the user to solve an authentication challenge.
PromptConstructor PromptConstructor
// SSOMFACeremonyConstructor is an optional SSO MFA ceremony constructor. If provided,
// the MFA ceremony will also attempt to retrieve an SSO MFA challenge.
SSOMFACeremonyConstructor SSOMFACeremonyConstructor
}

// SSOMFACeremony is an SSO MFA ceremony.
type SSOMFACeremony interface {
GetClientCallbackURL() string
Run(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error)
Close()
}

// SSOMFACeremonyConstructor constructs a new SSO MFA ceremony.
type SSOMFACeremonyConstructor func(ctx context.Context) (SSOMFACeremony, error)

// CreateAuthenticateChallengeFunc is a function that creates an authentication challenge.
type CreateAuthenticateChallengeFunc func(ctx context.Context, req *proto.CreateAuthenticateChallengeRequest) (*proto.MFAAuthenticateChallenge, error)

Expand All @@ -54,6 +67,19 @@ func (c *Ceremony) Run(ctx context.Context, req *proto.CreateAuthenticateChallen
return nil, trace.BadParameter("mfa challenge scope must be specified")
}

// If available, prepare an SSO MFA ceremony and set the client redirect URL in the challenge
// request to request an SSO challenge in addition to other challenges.
if c.SSOMFACeremonyConstructor != nil {
ssoMFACeremony, err := c.SSOMFACeremonyConstructor(ctx)
if err != nil {
return nil, trace.Wrap(err, "failed to handle SSO MFA ceremony")
}
defer ssoMFACeremony.Close()

req.SSOClientRedirectURL = ssoMFACeremony.GetClientCallbackURL()
promptOpts = append(promptOpts, withSSOMFACeremony(ssoMFACeremony))
}

chal, err := c.CreateAuthenticateChallenge(ctx, req)
if err != nil {
// CreateAuthenticateChallenge returns a bad parameter error when the client
Expand Down
77 changes: 76 additions & 1 deletion api/mfa/ceremony_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,14 @@ import (

"github.com/gravitational/trace"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/gravitational/teleport/api/client/proto"
mfav1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/mfa/v1"
"github.com/gravitational/teleport/api/mfa"
)

func TestPerformMFACeremony(t *testing.T) {
func TestMFACeremony(t *testing.T) {
t.Parallel()
ctx := context.Background()

Expand Down Expand Up @@ -128,3 +129,77 @@ func TestPerformMFACeremony(t *testing.T) {
})
}
}

func TestMFACeremony_SSO(t *testing.T) {
t.Parallel()
ctx := context.Background()

testMFAChallenge := &proto.MFAAuthenticateChallenge{
SSOChallenge: &proto.SSOChallenge{
RedirectUrl: "redirect",
RequestId: "request-id",
},
}
testMFAResponse := &proto.MFAAuthenticateResponse{
Response: &proto.MFAAuthenticateResponse_SSO{
SSO: &proto.SSOResponse{
Token: "token",
RequestId: "request-id",
},
},
}

ssoMFACeremony := &mfa.Ceremony{
CreateAuthenticateChallenge: func(ctx context.Context, req *proto.CreateAuthenticateChallengeRequest) (*proto.MFAAuthenticateChallenge, error) {
return testMFAChallenge, nil
},
PromptConstructor: func(opts ...mfa.PromptOpt) mfa.Prompt {
cfg := new(mfa.PromptConfig)
for _, opt := range opts {
opt(cfg)
}

return mfa.PromptFunc(func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) {
if cfg.SSOMFACeremony == nil {
return nil, trace.BadParameter("expected sso mfa ceremony")
}

return cfg.SSOMFACeremony.Run(ctx, chal)
})
},
SSOMFACeremonyConstructor: func(ctx context.Context) (mfa.SSOMFACeremony, error) {
return &mockSSOMFACeremony{
clientCallbackURL: "client-redirect",
prompt: func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) {
return testMFAResponse, nil
},
}, nil
},
}

resp, err := ssoMFACeremony.Run(ctx, &proto.CreateAuthenticateChallengeRequest{
ChallengeExtensions: &mfav1.ChallengeExtensions{
Scope: mfav1.ChallengeScope_CHALLENGE_SCOPE_ADMIN_ACTION,
},
MFARequiredCheck: &proto.IsMFARequiredRequest{},
})
require.NoError(t, err)
require.Equal(t, testMFAResponse, resp)
}

type mockSSOMFACeremony struct {
clientCallbackURL string
prompt mfa.PromptFunc
}

// GetClientCallbackURL returns the client callback URL.
func (m *mockSSOMFACeremony) GetClientCallbackURL() string {
return m.clientCallbackURL
}

// Run the SSO MFA ceremony.
func (m *mockSSOMFACeremony) Run(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) {
return m.prompt(ctx, chal)
}

func (m *mockSSOMFACeremony) Close() {}
9 changes: 9 additions & 0 deletions api/mfa/prompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ type PromptConfig struct {
// Extensions are the challenge extensions used to create the prompt's challenge.
// Used to enrich certain prompts.
Extensions *mfav1.ChallengeExtensions
// SSOMFACeremony is an SSO MFA ceremony.
SSOMFACeremony SSOMFACeremony
}

// DeviceDescriptor is a descriptor for a device, such as "registered".
Expand Down Expand Up @@ -117,3 +119,10 @@ func WithPromptChallengeExtensions(exts *mfav1.ChallengeExtensions) PromptOpt {
cfg.Extensions = exts
}
}

// withSSOMFACeremony sets the SSO MFA ceremony for the MFA prompt.
func withSSOMFACeremony(ssoMFACeremony SSOMFACeremony) PromptOpt {
return func(cfg *PromptConfig) {
cfg.SSOMFACeremony = ssoMFACeremony
}
}
12 changes: 9 additions & 3 deletions lib/client/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,9 @@ type Config struct {
// authenticators, such as remote hosts or virtual machines.
PreferOTP bool

// PreferSSO prefers SSO in favor of other MFA methods.
PreferSSO bool

// CheckVersions will check that client version is compatible
// with auth server version when connecting.
CheckVersions bool
Expand Down Expand Up @@ -3043,6 +3046,8 @@ func (tc *TeleportClient) ConnectToCluster(ctx context.Context) (_ *ClusterClien
return nil, trace.NewAggregate(err, pclt.Close())
}
authClientCfg.MFAPromptConstructor = tc.NewMFAPrompt
authClientCfg.SSOMFACeremonyConstructor = tc.NewSSOMFACeremony

authClient, err := authclient.NewClient(authClientCfg)
if err != nil {
return nil, trace.NewAggregate(err, pclt.Close())
Expand Down Expand Up @@ -5062,9 +5067,10 @@ func (tc *TeleportClient) NewKubernetesServiceClient(ctx context.Context, cluste
Credentials: []client.Credentials{
client.LoadTLS(tlsConfig),
},
ALPNConnUpgradeRequired: tc.TLSRoutingConnUpgradeRequired,
InsecureAddressDiscovery: tc.InsecureSkipVerify,
MFAPromptConstructor: tc.NewMFAPrompt,
ALPNConnUpgradeRequired: tc.TLSRoutingConnUpgradeRequired,
InsecureAddressDiscovery: tc.InsecureSkipVerify,
MFAPromptConstructor: tc.NewMFAPrompt,
SSOMFACeremonyConstructor: tc.NewSSOMFACeremony,
})
if err != nil {
return nil, trace.Wrap(err)
Expand Down
5 changes: 3 additions & 2 deletions lib/client/cluster_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -361,8 +361,9 @@ func TestIssueUserCertsWithMFA(t *testing.T) {
tc: &TeleportClient{
localAgent: agent,
Config: Config{
SiteName: "test",
Tracer: tracing.NoopTracer("test"),
WebProxyAddr: "proxy.example.com",
SiteName: "test",
Tracer: tracing.NoopTracer("test"),
MFAPromptConstructor: func(cfg *libmfa.PromptConfig) mfa.Prompt {
return test.prompt
},
Expand Down
19 changes: 19 additions & 0 deletions lib/client/mfa.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,15 @@ import (
"github.com/gravitational/teleport/api/client/proto"
"github.com/gravitational/teleport/api/mfa"
libmfa "github.com/gravitational/teleport/lib/client/mfa"
"github.com/gravitational/teleport/lib/client/sso"
)

// NewMFACeremony returns a new MFA ceremony configured for this client.
func (tc *TeleportClient) NewMFACeremony() *mfa.Ceremony {
return &mfa.Ceremony{
CreateAuthenticateChallenge: tc.createAuthenticateChallenge,
PromptConstructor: tc.NewMFAPrompt,
SSOMFACeremonyConstructor: tc.NewSSOMFACeremony,
}
}

Expand Down Expand Up @@ -61,6 +63,7 @@ func (tc *TeleportClient) NewMFAPrompt(opts ...mfa.PromptOpt) mfa.Prompt {
PromptConfig: *cfg,
Writer: tc.Stderr,
PreferOTP: tc.PreferOTP,
PreferSSO: tc.PreferSSO,
AllowStdinHijack: tc.AllowStdinHijack,
StdinFunc: tc.StdinFunc,
})
Expand All @@ -79,5 +82,21 @@ func (tc *TeleportClient) newPromptConfig(opts ...mfa.PromptOpt) *libmfa.PromptC
cfg.WebauthnLoginFunc = tc.WebauthnLogin
cfg.WebauthnSupported = true
}

return cfg
}

// NewSSOMFACeremony creates a new SSO MFA ceremony.
func (tc *TeleportClient) NewSSOMFACeremony(ctx context.Context) (mfa.SSOMFACeremony, error) {
rdConfig, err := tc.ssoRedirectorConfig(ctx, "" /*connectorDisplayName*/)
if err != nil {
return nil, trace.Wrap(err)
}

rd, err := sso.NewRedirector(rdConfig)
if err != nil {
return nil, trace.Wrap(err)
}

return sso.NewCLIMFACeremony(rd), nil
}
Loading

0 comments on commit 5e7bdb2

Please sign in to comment.