diff --git a/integration/proxy/proxy_helpers.go b/integration/proxy/proxy_helpers.go index 726568e34d46e..7fff7e9a54100 100644 --- a/integration/proxy/proxy_helpers.go +++ b/integration/proxy/proxy_helpers.go @@ -53,7 +53,7 @@ import ( "github.com/gravitational/teleport/api/utils/retryutils" "github.com/gravitational/teleport/integration/helpers" "github.com/gravitational/teleport/lib" - "github.com/gravitational/teleport/lib/auth" + "github.com/gravitational/teleport/lib/auth/join" "github.com/gravitational/teleport/lib/auth/state" "github.com/gravitational/teleport/lib/client" "github.com/gravitational/teleport/lib/defaults" @@ -664,7 +664,7 @@ func mustRegisterUsingIAMMethod(t *testing.T, proxyAddr utils.NetAddr, token str require.NoError(t, err) node := uuid.NewString() - _, err = auth.Register(context.TODO(), auth.RegisterParams{ + _, err = join.Register(context.TODO(), join.RegisterParams{ Token: token, ID: state.IdentityID{ Role: types.RoleNode, diff --git a/lib/auth/auth.go b/lib/auth/auth.go index aedc4938cffa4..0ec085c293db6 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -52,7 +52,6 @@ import ( "github.com/gravitational/trace" "github.com/jonboulle/clockwork" "github.com/prometheus/client_golang/prometheus" - "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/exporters/otlp/otlptrace" "golang.org/x/crypto/bcrypt" "golang.org/x/crypto/ssh" @@ -155,8 +154,6 @@ const ( "(hint: use 'tctl get roles' to find roles that need updating)" ) -var tracer = otel.Tracer("github.com/gravitational/teleport/lib/auth") - var ErrRequiresEnterprise = services.ErrRequiresEnterprise // ServerOption allows setting options as functional arguments to Server diff --git a/lib/auth/bot_test.go b/lib/auth/bot_test.go index e53e445ceafaf..4c54114a96522 100644 --- a/lib/auth/bot_test.go +++ b/lib/auth/bot_test.go @@ -42,6 +42,7 @@ import ( machineidv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/machineid/v1" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/auth/authclient" + "github.com/gravitational/teleport/lib/auth/join" "github.com/gravitational/teleport/lib/auth/machineid/machineidv1" "github.com/gravitational/teleport/lib/auth/state" "github.com/gravitational/teleport/lib/auth/testauthority" @@ -118,7 +119,7 @@ func TestRegisterBotCertificateGenerationCheck(t *testing.T) { tlsPublicKey, err := tlsca.MarshalPublicKeyFromPrivateKeyPEM(sshPrivateKey) require.NoError(t, err) - certs, err := Register(ctx, RegisterParams{ + certs, err := join.Register(ctx, join.RegisterParams{ Token: token.GetName(), ID: state.IdentityID{ Role: types.RoleBot, @@ -191,7 +192,7 @@ func TestRegisterBotCertificateGenerationStolen(t *testing.T) { tlsPublicKey, err := tlsca.MarshalPublicKeyFromPrivateKeyPEM(sshPrivateKey) require.NoError(t, err) - certs, err := Register(ctx, RegisterParams{ + certs, err := join.Register(ctx, join.RegisterParams{ Token: token.GetName(), ID: state.IdentityID{ Role: types.RoleBot, @@ -267,7 +268,7 @@ func TestRegisterBotCertificateExtensions(t *testing.T) { tlsPublicKey, err := tlsca.MarshalPublicKeyFromPrivateKeyPEM(sshPrivateKey) require.NoError(t, err) - certs, err := Register(ctx, RegisterParams{ + certs, err := join.Register(ctx, join.RegisterParams{ Token: token.GetName(), ID: state.IdentityID{ Role: types.RoleBot, diff --git a/lib/auth/join/iam/endpoints.go b/lib/auth/join/iam/endpoints.go new file mode 100644 index 0000000000000..3434ce30228b1 --- /dev/null +++ b/lib/auth/join/iam/endpoints.go @@ -0,0 +1,91 @@ +// Teleport +// Copyright (C) 2024 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package iam + +import "sync" + +var ( + // ValidSTSEndpoints holds a sorted list of all known valid public endpoints for + // the AWS STS service. You can generate this list by running + // $ go run github.com/nklaassen/sts-endpoints@latest --go-list + // Update aws-sdk-go in that package to learn about new endpoints. + ValidSTSEndpoints = sync.OnceValue(func() []string { + return []string{ + "sts-fips.us-east-1.amazonaws.com", + "sts-fips.us-east-2.amazonaws.com", + "sts-fips.us-west-1.amazonaws.com", + "sts-fips.us-west-2.amazonaws.com", + "sts.af-south-1.amazonaws.com", + "sts.amazonaws.com", + "sts.ap-east-1.amazonaws.com", + "sts.ap-northeast-1.amazonaws.com", + "sts.ap-northeast-2.amazonaws.com", + "sts.ap-northeast-3.amazonaws.com", + "sts.ap-south-1.amazonaws.com", + "sts.ap-south-2.amazonaws.com", + "sts.ap-southeast-1.amazonaws.com", + "sts.ap-southeast-2.amazonaws.com", + "sts.ap-southeast-3.amazonaws.com", + "sts.ap-southeast-4.amazonaws.com", + "sts.ca-central-1.amazonaws.com", + "sts.ca-west-1.amazonaws.com", + "sts.cn-north-1.amazonaws.com.cn", + "sts.cn-northwest-1.amazonaws.com.cn", + "sts.eu-central-1.amazonaws.com", + "sts.eu-central-2.amazonaws.com", + "sts.eu-north-1.amazonaws.com", + "sts.eu-south-1.amazonaws.com", + "sts.eu-south-2.amazonaws.com", + "sts.eu-west-1.amazonaws.com", + "sts.eu-west-2.amazonaws.com", + "sts.eu-west-3.amazonaws.com", + "sts.il-central-1.amazonaws.com", + "sts.me-central-1.amazonaws.com", + "sts.me-south-1.amazonaws.com", + "sts.sa-east-1.amazonaws.com", + "sts.us-east-1.amazonaws.com", + "sts.us-east-2.amazonaws.com", + "sts.us-gov-east-1.amazonaws.com", + "sts.us-gov-west-1.amazonaws.com", + "sts.us-iso-east-1.c2s.ic.gov", + "sts.us-iso-west-1.c2s.ic.gov", + "sts.us-isob-east-1.sc2s.sgov.gov", + "sts.us-west-1.amazonaws.com", + "sts.us-west-2.amazonaws.com", + } + }) + + GlobalSTSEndpoints = sync.OnceValue(func() []string { + return []string{ + "sts.amazonaws.com", + // This is not a real endpoint, but the SDK will select it if + // AWS_USE_FIPS_ENDPOINT is set and a region is not. + "sts-fips.aws-global.amazonaws.com", + } + }) + + FIPSSTSEndpoints = sync.OnceValue(func() []string { + return []string{ + "sts-fips.us-east-1.amazonaws.com", + "sts-fips.us-east-2.amazonaws.com", + "sts-fips.us-west-1.amazonaws.com", + "sts-fips.us-west-2.amazonaws.com", + "sts.us-gov-east-1.amazonaws.com", + "sts.us-gov-west-1.amazonaws.com", + } + }) +) diff --git a/lib/auth/join/iam/iam.go b/lib/auth/join/iam/iam.go new file mode 100644 index 0000000000000..241d4ec7800c3 --- /dev/null +++ b/lib/auth/join/iam/iam.go @@ -0,0 +1,161 @@ +// Teleport +// Copyright (C) 2024 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package iam + +import ( + "bytes" + "context" + "log/slog" + "slices" + "strings" + + awssdk "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/endpoints" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/sts" + "github.com/gravitational/trace" + + cloudaws "github.com/gravitational/teleport/lib/cloud/imds/aws" +) + +const ( + // AWS SignedHeaders will always be lowercase + // https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-auth-using-authorization-header.html#sigv4-auth-header-overview + challengeHeaderKey = "x-teleport-challenge" +) + +type stsIdentityRequestConfig struct { + regionalEndpointOption endpoints.STSRegionalEndpoint + fipsEndpointOption endpoints.FIPSEndpointState +} + +type stsIdentityRequestOption func(cfg *stsIdentityRequestConfig) + +func WithRegionalEndpoint(useRegionalEndpoint bool) stsIdentityRequestOption { + return func(cfg *stsIdentityRequestConfig) { + if useRegionalEndpoint { + cfg.regionalEndpointOption = endpoints.RegionalSTSEndpoint + } else { + cfg.regionalEndpointOption = endpoints.LegacySTSEndpoint + } + } +} + +func WithFIPSEndpoint(useFIPS bool) stsIdentityRequestOption { + return func(cfg *stsIdentityRequestConfig) { + if useFIPS { + cfg.fipsEndpointOption = endpoints.FIPSEndpointStateEnabled + } else { + cfg.fipsEndpointOption = endpoints.FIPSEndpointStateDisabled + } + } +} + +// getEC2LocalRegion returns the AWS region this EC2 instance is running in, or +// a NotFound error if the EC2 IMDS is unavailable. +func getEC2LocalRegion(ctx context.Context) (string, error) { + imdsClient, err := cloudaws.NewInstanceMetadataClient(ctx) + if err != nil { + return "", trace.Wrap(err) + } + + if !imdsClient.IsAvailable(ctx) { + return "", trace.NotFound("IMDS is unavailable") + } + + region, err := imdsClient.GetRegion(ctx) + return region, trace.Wrap(err) +} + +func newSTSClient(ctx context.Context, cfg *stsIdentityRequestConfig) (*sts.STS, error) { + awsConfig := awssdk.Config{ + UseFIPSEndpoint: cfg.fipsEndpointOption, + STSRegionalEndpoint: cfg.regionalEndpointOption, + } + sess, err := session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + Config: awsConfig, + }) + if err != nil { + return nil, trace.Wrap(err) + } + + stsClient := sts.New(sess) + + if slices.Contains(GlobalSTSEndpoints(), strings.TrimPrefix(stsClient.Endpoint, "https://")) { + // If the caller wants to use the regional endpoint but it was not resolved + // from the environment, attempt to find the region from the EC2 IMDS + if cfg.regionalEndpointOption == endpoints.RegionalSTSEndpoint { + region, err := getEC2LocalRegion(ctx) + if err != nil { + return nil, trace.Wrap(err, "failed to resolve local AWS region from environment or IMDS") + } + stsClient = sts.New(sess, awssdk.NewConfig().WithRegion(region)) + } else { + const msg = "Attempting to use the global STS endpoint for the IAM join method. " + + "This will probably fail in non-default AWS partitions such as China or GovCloud, or if FIPS mode is enabled. " + + "Consider setting the AWS_REGION environment variable, setting the region in ~/.aws/config, or enabling the IMDSv2." + slog.InfoContext(ctx, msg) + } + } + + if cfg.fipsEndpointOption == endpoints.FIPSEndpointStateEnabled && + !slices.Contains(ValidSTSEndpoints(), strings.TrimPrefix(stsClient.Endpoint, "https://")) { + // The AWS SDK will generate invalid endpoints when attempting to + // resolve the FIPS endpoint for a region that does not have one. + // In this case, try to use the FIPS endpoint in us-east-1. This should + // work for all regions in the standard partition. In GovCloud, we should + // not hit this because all regional endpoints support FIPS. In China or + // other partitions, this will fail, and FIPS mode will not be supported. + const msg = "AWS SDK resolved invalid FIPS STS endpoint. " + + "Attempting to use the FIPS STS endpoint for us-east-1." + slog.InfoContext(ctx, msg, "resolved", stsClient.Endpoint) + stsClient = sts.New(sess, awssdk.NewConfig().WithRegion("us-east-1")) + } + + return stsClient, nil +} + +// CreateSignedSTSIdentityRequest is called on the client side and returns an +// sts:GetCallerIdentity request signed with the local AWS credentials +func CreateSignedSTSIdentityRequest(ctx context.Context, challenge string, opts ...stsIdentityRequestOption) ([]byte, error) { + cfg := &stsIdentityRequestConfig{} + for _, opt := range opts { + opt(cfg) + } + + stsClient, err := newSTSClient(ctx, cfg) + if err != nil { + return nil, trace.Wrap(err) + } + + req, _ := stsClient.GetCallerIdentityRequest(&sts.GetCallerIdentityInput{}) + // set challenge header + req.HTTPRequest.Header.Set(challengeHeaderKey, challenge) + // request json for simpler parsing + req.HTTPRequest.Header.Set("Accept", "application/json") + // sign the request, including headers + if err := req.Sign(); err != nil { + return nil, trace.Wrap(err) + } + // write the signed HTTP request to a buffer + var signedRequest bytes.Buffer + if err := req.HTTPRequest.Write(&signedRequest); err != nil { + return nil, trace.Wrap(err) + } + return signedRequest.Bytes(), nil +} diff --git a/lib/auth/join/join.go b/lib/auth/join/join.go new file mode 100644 index 0000000000000..367b091612ac5 --- /dev/null +++ b/lib/auth/join/join.go @@ -0,0 +1,796 @@ +// Teleport +// Copyright (C) 2024 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package join + +import ( + "context" + "crypto/tls" + "crypto/x509" + "log/slog" + "os" + "slices" + "time" + + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + log "github.com/sirupsen/logrus" + "go.opentelemetry.io/otel" + "golang.org/x/net/http2" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" + + "github.com/gravitational/teleport" + "github.com/gravitational/teleport/api/breaker" + "github.com/gravitational/teleport/api/client" + "github.com/gravitational/teleport/api/client/proto" + "github.com/gravitational/teleport/api/constants" + apidefaults "github.com/gravitational/teleport/api/defaults" + "github.com/gravitational/teleport/api/metadata" + "github.com/gravitational/teleport/api/observability/tracing" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/api/utils/aws" + "github.com/gravitational/teleport/lib/auth/authclient" + "github.com/gravitational/teleport/lib/auth/join/iam" + "github.com/gravitational/teleport/lib/auth/state" + "github.com/gravitational/teleport/lib/circleci" + "github.com/gravitational/teleport/lib/cloud/imds/azure" + "github.com/gravitational/teleport/lib/cloud/imds/gcp" + "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/githubactions" + "github.com/gravitational/teleport/lib/gitlab" + "github.com/gravitational/teleport/lib/kubernetestoken" + "github.com/gravitational/teleport/lib/spacelift" + "github.com/gravitational/teleport/lib/srv/alpnproxy/common" + "github.com/gravitational/teleport/lib/tlsca" + "github.com/gravitational/teleport/lib/tpm" + "github.com/gravitational/teleport/lib/utils" +) + +var tracer = otel.Tracer("github.com/gravitational/teleport/lib/auth/join") + +// HostCredentials is an interface for a client that can be used to get host +// credentials. This interface is needed because lib/client cannot be imported +// in lib/auth due to circular imports. +type HostCredentials func(context.Context, string, bool, types.RegisterUsingTokenRequest) (*proto.Certs, error) + +// AzureParams is the parameters specific to the azure join method. +type AzureParams struct { + // ClientID is the client ID of the managed identity for Teleport to assume + // when authenticating a node. + ClientID string +} + +// RegisterParams specifies parameters +// for first time register operation with auth server +type RegisterParams struct { + // Token is a secure token to join the cluster + Token string + // ID is identity ID + ID state.IdentityID + // AuthServers is a list of auth servers to dial + AuthServers []utils.NetAddr + // ProxyServer is a proxy server to dial + ProxyServer utils.NetAddr + // AdditionalPrincipals is a list of additional principals to dial + AdditionalPrincipals []string + // DNSNames is a list of DNS names to add to x509 certificate + DNSNames []string + // PublicTLSKey is a server's public key to sign + PublicTLSKey []byte + // PublicSSHKey is a server's public SSH key to sign + PublicSSHKey []byte + // CipherSuites is a list of cipher suites to use for TLS client connection + CipherSuites []uint16 + // CAPins are the SKPI hashes of the CAs used to verify the Auth Server. + CAPins []string + // CAPath is the path to the CA file. + CAPath string + // GetHostCredentials is a client that can fetch host credentials. + GetHostCredentials HostCredentials + // Clock specifies the time provider. Will be used to override the time anchor + // for TLS certificate verification. + // Defaults to real clock if unspecified + Clock clockwork.Clock + // JoinMethod is the joining method used for this register request. + JoinMethod types.JoinMethod + // ec2IdentityDocument is used for Simplified Node Joining to prove the + // identity of a joining EC2 instance. + ec2IdentityDocument []byte + // AzureParams is the parameters specific to the azure join method. + AzureParams AzureParams + // CircuitBreakerConfig defines how the circuit breaker should behave. + CircuitBreakerConfig breaker.Config + // FIPS means FedRAMP/FIPS 140-2 compliant configuration was requested. + FIPS bool + // IDToken is a token retrieved from a workload identity provider for + // certain join types e.g GitHub, Google. + IDToken string + // Expires is an optional field for bots that specifies a time that the + // certificates that are returned by registering should expire at. + // It should not be specified for non-bot registrations. + Expires *time.Time + // Insecure trusts the certificates from the Auth Server or Proxy during registration without verification. + Insecure bool +} + +func (r *RegisterParams) checkAndSetDefaults() error { + if r.Clock == nil { + r.Clock = clockwork.NewRealClock() + } + + if err := r.verifyAuthOrProxyAddress(); err != nil { + return trace.BadParameter("no auth or proxy servers set") + } + + return nil +} + +func (r *RegisterParams) verifyAuthOrProxyAddress() error { + haveAuthServers := len(r.AuthServers) > 0 + haveProxyServer := !r.ProxyServer.IsEmpty() + + if !haveAuthServers && !haveProxyServer { + return trace.BadParameter("no auth or proxy servers set") + } + + if haveAuthServers && haveProxyServer { + return trace.BadParameter("only one of auth or proxy server should be set") + } + + return nil +} + +// Register is used to generate host keys when a node or proxy are running on +// different hosts than the auth server. This method requires provisioning +// tokens to prove a valid auth server was used to issue the joining request +// as well as a method for the node to validate the auth server. +func Register(ctx context.Context, params RegisterParams) (certs *proto.Certs, err error) { + ctx, span := tracer.Start(ctx, "Register") + defer func() { tracing.EndSpan(span, err) }() + + if err := params.checkAndSetDefaults(); err != nil { + return nil, trace.Wrap(err) + } + // Read in the token. The token can either be passed in or come from a file + // on disk. + token, err := utils.TryReadValueAsFile(params.Token) + if err != nil { + return nil, trace.Wrap(err) + } + + // add EC2 Identity Document to params if required for given join method + switch params.JoinMethod { + case types.JoinMethodEC2: + if !aws.IsEC2NodeID(params.ID.HostUUID) { + return nil, trace.BadParameter( + `Host ID %q is not valid when using the EC2 join method, `+ + `try removing the "host_uuid" file in your teleport data dir `+ + `(e.g. /var/lib/teleport/host_uuid)`, + params.ID.HostUUID) + } + params.ec2IdentityDocument, err = utils.GetRawEC2IdentityDocument(ctx) + if err != nil { + return nil, trace.Wrap(err) + } + case types.JoinMethodGitHub: + params.IDToken, err = githubactions.NewIDTokenSource().GetIDToken(ctx) + if err != nil { + return nil, trace.Wrap(err) + } + case types.JoinMethodGitLab: + params.IDToken, err = gitlab.NewIDTokenSource(os.Getenv).GetIDToken() + if err != nil { + return nil, trace.Wrap(err) + } + case types.JoinMethodCircleCI: + params.IDToken, err = circleci.GetIDToken(os.Getenv) + if err != nil { + return nil, trace.Wrap(err) + } + case types.JoinMethodKubernetes: + params.IDToken, err = kubernetestoken.GetIDToken(os.Getenv, os.ReadFile) + if err != nil { + return nil, trace.Wrap(err) + } + case types.JoinMethodGCP: + params.IDToken, err = gcp.GetIDToken(ctx) + if err != nil { + return nil, trace.Wrap(err) + } + case types.JoinMethodSpacelift: + params.IDToken, err = spacelift.NewIDTokenSource(os.Getenv).GetIDToken() + if err != nil { + return nil, trace.Wrap(err) + } + } + + type registerMethod struct { + call func(ctx context.Context, token string, params RegisterParams) (*proto.Certs, error) + desc string + } + + registerThroughAuth := registerMethod{registerThroughAuth, "with auth server"} + registerThroughProxy := registerMethod{registerThroughProxy, "via proxy server"} + + registerMethods := []registerMethod{registerThroughAuth, registerThroughProxy} + + if !params.ProxyServer.IsEmpty() { + log.WithField("proxy-server", params.ProxyServer).Debugf("Registering node to the cluster.") + + registerMethods = []registerMethod{registerThroughProxy} + + if proxyServerIsAuth(params.ProxyServer) { + log.Debugf("The specified proxy server appears to be an auth server.") + } + } else { + log.WithField("auth-servers", params.AuthServers).Debugf("Registering node to the cluster.") + + if params.GetHostCredentials == nil { + log.Debugf("Missing client, it is not possible to register through proxy.") + registerMethods = []registerMethod{registerThroughAuth} + } else if authServerIsProxy(params.AuthServers) { + log.Debugf("The first specified auth server appears to be a proxy.") + registerMethods = []registerMethod{registerThroughProxy, registerThroughAuth} + } + } + + var collectedErrs []error + for _, method := range registerMethods { + log.Infof("Attempting registration %s.", method.desc) + certs, err := method.call(ctx, token, params) + if err != nil { + collectedErrs = append(collectedErrs, err) + log.WithError(err).Debugf("Registration %s failed.", method.desc) + continue + } + log.Infof("Successfully registered %s.", method.desc) + return certs, nil + } + return nil, trace.NewAggregate(collectedErrs...) +} + +// authServerIsProxy returns true if the first specified auth server +// to register with appears to be a proxy. +func authServerIsProxy(servers []utils.NetAddr) bool { + if len(servers) == 0 { + return false + } + port := servers[0].Port(0) + return port == defaults.HTTPListenPort || port == teleport.StandardHTTPSPort +} + +// proxyServerIsAuth returns true if the address given to register with +// appears to be an auth server. +func proxyServerIsAuth(server utils.NetAddr) bool { + port := server.Port(0) + return port == defaults.AuthListenPort +} + +// registerThroughProxy is used to register through the proxy server. +func registerThroughProxy( + ctx context.Context, + token string, + params RegisterParams, +) (certs *proto.Certs, err error) { + ctx, span := tracer.Start(ctx, "registerThroughProxy") + defer func() { tracing.EndSpan(span, err) }() + + switch params.JoinMethod { + case types.JoinMethodIAM, types.JoinMethodAzure, types.JoinMethodTPM: + // IAM and Azure join methods require gRPC client + conn, err := proxyJoinServiceConn(ctx, params, params.Insecure) + if err != nil { + return nil, trace.Wrap(err) + } + defer conn.Close() + + joinServiceClient := client.NewJoinServiceClient(proto.NewJoinServiceClient(conn)) + switch params.JoinMethod { + case types.JoinMethodIAM: + certs, err = registerUsingIAMMethod(ctx, joinServiceClient, token, params) + case types.JoinMethodAzure: + certs, err = registerUsingAzureMethod(ctx, joinServiceClient, token, params) + case types.JoinMethodTPM: + certs, err = registerUsingTPMMethod(ctx, joinServiceClient, token, params) + default: + return nil, trace.BadParameter("unhandled join method %q", params.JoinMethod) + } + + if err != nil { + return nil, trace.Wrap(err) + } + default: + // The rest of the join methods use GetHostCredentials function passed through + // params to call proxy HTTP endpoint + var err error + certs, err = params.GetHostCredentials(ctx, + getHostAddresses(params)[0], + params.Insecure, + types.RegisterUsingTokenRequest{ + Token: token, + HostID: params.ID.HostUUID, + NodeName: params.ID.NodeName, + Role: params.ID.Role, + AdditionalPrincipals: params.AdditionalPrincipals, + DNSNames: params.DNSNames, + PublicTLSKey: params.PublicTLSKey, + PublicSSHKey: params.PublicSSHKey, + EC2IdentityDocument: params.ec2IdentityDocument, + IDToken: params.IDToken, + Expires: params.Expires, + }) + if err != nil { + return nil, trace.Wrap(err) + } + } + return certs, nil +} + +// registerThroughAuth is used to register through the auth server. +func registerThroughAuth( + ctx context.Context, token string, params RegisterParams, +) (certs *proto.Certs, err error) { + ctx, span := tracer.Start(ctx, "registerThroughAuth") + defer func() { tracing.EndSpan(span, err) }() + + var client *authclient.Client + // Build a client for the Auth Server with different certificate validation + // depending on the configured values for Insecure, CAPins and CAPath. + switch { + case params.Insecure: + log.Warnf("Insecure mode enabled. Auth Server cert will not be validated and CAPins and CAPath value will be ignored.") + client, err = insecureRegisterClient(params) + case len(params.CAPins) != 0: + // CAPins takes precedence over CAPath + client, err = pinRegisterClient(ctx, params) + case params.CAPath != "": + client, err = caPathRegisterClient(params) + default: + // We fall back to insecure mode here - this is a little odd but is + // necessary to preserve the behavior of registration. At a later date, + // we may consider making this an error asking the user to provide + // Insecure, CAPins or CAPath. + client, err = insecureRegisterClient(params) + } + if err != nil { + return nil, trace.Wrap(err) + } + defer client.Close() + + switch params.JoinMethod { + // IAM and Azure methods use unique gRPC endpoints + case types.JoinMethodIAM: + certs, err = registerUsingIAMMethod(ctx, client, token, params) + case types.JoinMethodAzure: + certs, err = registerUsingAzureMethod(ctx, client, token, params) + case types.JoinMethodTPM: + certs, err = registerUsingTPMMethod(ctx, client, token, params) + default: + // non-IAM join methods use HTTP endpoint + // Get the SSH and X509 certificates for a node. + certs, err = client.RegisterUsingToken( + ctx, + &types.RegisterUsingTokenRequest{ + Token: token, + HostID: params.ID.HostUUID, + NodeName: params.ID.NodeName, + Role: params.ID.Role, + AdditionalPrincipals: params.AdditionalPrincipals, + DNSNames: params.DNSNames, + PublicTLSKey: params.PublicTLSKey, + PublicSSHKey: params.PublicSSHKey, + EC2IdentityDocument: params.ec2IdentityDocument, + IDToken: params.IDToken, + Expires: params.Expires, + }) + } + return certs, trace.Wrap(err) +} + +// proxyJoinServiceConn attempts to connect to the join service running on the +// proxy. The Proxy's TLS cert will be verified using the host's root CA pool +// (PKI) unless the --insecure flag was passed. +func proxyJoinServiceConn( + ctx context.Context, params RegisterParams, insecure bool, +) (*grpc.ClientConn, error) { + tlsConfig := utils.TLSConfig(params.CipherSuites) + tlsConfig.Time = params.Clock.Now + // set NextProtos for TLS routing, the actual protocol will be h2 + tlsConfig.NextProtos = []string{string(common.ProtocolProxyGRPCInsecure), http2.NextProtoTLS} + + if insecure { + tlsConfig.InsecureSkipVerify = true + log.Warnf("Joining cluster without validating the identity of the Proxy Server.") + } + + // Check if proxy is behind a load balancer. If so, the connection upgrade + // will verify the load balancer's cert using system cert pool. This + // provides the same level of security as the client only verifies Proxy's + // web cert against system cert pool when connection upgrade is not + // required. + // + // With the ALPN connection upgrade, the tunneled TLS Routing request will + // skip verify as the Proxy server will present its host cert which is not + // fully verifiable at this point since the client does not have the host + // CAs yet before completing registration. + alpnConnUpgrade := client.IsALPNConnUpgradeRequired(ctx, getHostAddresses(params)[0], insecure) + if alpnConnUpgrade && !insecure { + tlsConfig.InsecureSkipVerify = true + tlsConfig.VerifyConnection = verifyALPNUpgradedConn(params.Clock) + } + + dialer := client.NewDialer( + ctx, + apidefaults.DefaultIdleTimeout, + apidefaults.DefaultIOTimeout, + client.WithInsecureSkipVerify(insecure), + client.WithALPNConnUpgrade(alpnConnUpgrade), + ) + + conn, err := grpc.Dial( + getHostAddresses(params)[0], + grpc.WithContextDialer(client.GRPCContextDialer(dialer)), + grpc.WithUnaryInterceptor(metadata.UnaryClientInterceptor), + grpc.WithStreamInterceptor(metadata.StreamClientInterceptor), + grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)), + ) + return conn, trace.Wrap(err) +} + +func getHostAddresses(params RegisterParams) []string { + if !params.ProxyServer.IsEmpty() { + return []string{params.ProxyServer.String()} + } + + return utils.NetAddrsToStrings(params.AuthServers) +} + +// verifyALPNUpgradedConn is a tls.Config.VerifyConnection callback function +// used by the tunneled TLS Routing request to verify the host cert of a Proxy +// behind a L7 load balancer. +// +// Since the client has not obtained the cluster CAs at this point, the +// presented cert cannot be fully verified yet. For now, this function only +// checks if "teleport.cluster.local" is present as one of the DNS names and +// verifies the cert is not expired. +func verifyALPNUpgradedConn(clock clockwork.Clock) func(tls.ConnectionState) error { + return func(server tls.ConnectionState) error { + for _, cert := range server.PeerCertificates { + if slices.Contains(cert.DNSNames, constants.APIDomain) && clock.Now().Before(cert.NotAfter) { + return nil + } + } + return trace.AccessDenied("server is not a Teleport proxy or server certificate is expired") + } +} + +// insecureRegisterClient attempts to connects to the Auth Server using the +// CA on disk. If no CA is found on disk, Teleport will not verify the Auth +// Server it is connecting to. +func insecureRegisterClient(params RegisterParams) (*authclient.Client, error) { + log.Warnf("Joining cluster without validating the identity of the Auth " + + "Server. This may open you up to a Man-In-The-Middle (MITM) attack if an " + + "attacker can gain privileged network access. To remedy this, use the CA pin " + + "value provided when join token was generated to validate the identity of " + + "the Auth Server or point to a valid Certificate via the CA Path option.") + + tlsConfig := utils.TLSConfig(params.CipherSuites) + tlsConfig.Time = params.Clock.Now + tlsConfig.InsecureSkipVerify = true + + client, err := authclient.NewClient(client.Config{ + Addrs: getHostAddresses(params), + Credentials: []client.Credentials{ + client.LoadTLS(tlsConfig), + }, + CircuitBreakerConfig: params.CircuitBreakerConfig, + }) + if err != nil { + return nil, trace.Wrap(err) + } + + return client, nil +} + +// pinRegisterClient first connects to the Auth Server using a insecure +// connection to fetch the root CA. If the root CA matches the provided CA +// pin, a connection will be re-established and the root CA will be used to +// validate the certificate presented. If both conditions hold true, then we +// know we are connecting to the expected Auth Server. +func pinRegisterClient( + ctx context.Context, params RegisterParams, +) (*authclient.Client, error) { + // Build a insecure client to the Auth Server. This is safe because even if + // an attacker were to MITM this connection the CA pin will not match below. + tlsConfig := utils.TLSConfig(params.CipherSuites) + tlsConfig.InsecureSkipVerify = true + tlsConfig.Time = params.Clock.Now + authClient, err := authclient.NewClient(client.Config{ + Addrs: getHostAddresses(params), + Credentials: []client.Credentials{ + client.LoadTLS(tlsConfig), + }, + CircuitBreakerConfig: params.CircuitBreakerConfig, + }) + if err != nil { + return nil, trace.Wrap(err) + } + defer authClient.Close() + + // Fetch the root CA from the Auth Server. The NOP role has access to the + // GetClusterCACert endpoint. + localCA, err := authClient.GetClusterCACert(ctx) + if err != nil { + return nil, trace.Wrap(err) + } + certs, err := tlsca.ParseCertificatePEMs(localCA.TLSCA) + if err != nil { + return nil, trace.Wrap(err) + } + + // Check that the SPKI pin matches the CA we fetched over a insecure + // connection. This makes sure the CA fetched over a insecure connection is + // in-fact the expected CA. + err = utils.CheckSPKI(params.CAPins, certs) + if err != nil { + return nil, trace.Wrap(err) + } + + for _, cert := range certs { + // Check that the fetched CA is valid at the current time. + err = utils.VerifyCertificateExpiry(cert, params.Clock) + if err != nil { + return nil, trace.Wrap(err) + } + + } + log.Infof("Joining remote cluster %v with CA pin.", certs[0].Subject.CommonName) + + // Create another client, but this time with the CA provided to validate + // that the Auth Server was issued a certificate by the same CA. + tlsConfig = utils.TLSConfig(params.CipherSuites) + tlsConfig.Time = params.Clock.Now + certPool := x509.NewCertPool() + for _, cert := range certs { + certPool.AddCert(cert) + } + tlsConfig.RootCAs = certPool + + authClient, err = authclient.NewClient(client.Config{ + Addrs: getHostAddresses(params), + Credentials: []client.Credentials{ + client.LoadTLS(tlsConfig), + }, + CircuitBreakerConfig: params.CircuitBreakerConfig, + }) + if err != nil { + return nil, trace.Wrap(err) + } + + return authClient, nil +} + +func caPathRegisterClient(params RegisterParams) (*authclient.Client, error) { + tlsConfig := utils.TLSConfig(params.CipherSuites) + tlsConfig.Time = params.Clock.Now + + cert, err := readCA(params.CAPath) + if err != nil && !trace.IsNotFound(err) { + return nil, trace.Wrap(err) + } + + // If we're unable to read the file at CAPath, we fall back to insecure + // registration. This preserves the existing behavior. At a later date, + // we may wish to consider changing this to return an error - but this is a + // breaking change. + if trace.IsNotFound(err) { + log.Warnf("Falling back to insecurely joining because a missing or empty CA Path was provided.") + return insecureRegisterClient(params) + } + + certPool := x509.NewCertPool() + certPool.AddCert(cert) + tlsConfig.RootCAs = certPool + + log.Infof("Joining remote cluster %v, validating connection with certificate on disk.", cert.Subject.CommonName) + + client, err := authclient.NewClient(client.Config{ + Addrs: getHostAddresses(params), + Credentials: []client.Credentials{ + client.LoadTLS(tlsConfig), + }, + CircuitBreakerConfig: params.CircuitBreakerConfig, + }) + if err != nil { + return nil, trace.Wrap(err) + } + + return client, nil +} + +type joinServiceClient interface { + RegisterUsingIAMMethod(ctx context.Context, challengeResponse client.RegisterIAMChallengeResponseFunc) (*proto.Certs, error) + RegisterUsingAzureMethod(ctx context.Context, challengeResponse client.RegisterAzureChallengeResponseFunc) (*proto.Certs, error) + RegisterUsingTPMMethod( + ctx context.Context, + initReq *proto.RegisterUsingTPMMethodInitialRequest, + solveChallenge client.RegisterTPMChallengeResponseFunc, + ) (*proto.Certs, error) +} + +func registerUsingTokenRequestForParams(token string, params RegisterParams) *types.RegisterUsingTokenRequest { + return &types.RegisterUsingTokenRequest{ + Token: token, + HostID: params.ID.HostUUID, + NodeName: params.ID.NodeName, + Role: params.ID.Role, + AdditionalPrincipals: params.AdditionalPrincipals, + DNSNames: params.DNSNames, + PublicTLSKey: params.PublicTLSKey, + PublicSSHKey: params.PublicSSHKey, + Expires: params.Expires, + } +} + +// registerUsingIAMMethod is used to register using the IAM join method. It is +// able to register through a proxy or through the auth server directly. +func registerUsingIAMMethod( + ctx context.Context, joinServiceClient joinServiceClient, token string, params RegisterParams, +) (*proto.Certs, error) { + log.Infof("Attempting to register %s with IAM method using regional STS endpoint", params.ID.Role) + // Call RegisterUsingIAMMethod and pass a callback to respond to the challenge with a signed join request. + certs, err := joinServiceClient.RegisterUsingIAMMethod(ctx, func(challenge string) (*proto.RegisterUsingIAMMethodRequest, error) { + // create the signed sts:GetCallerIdentity request and include the challenge + signedRequest, err := iam.CreateSignedSTSIdentityRequest(ctx, challenge, + iam.WithFIPSEndpoint(params.FIPS), + iam.WithRegionalEndpoint(true), + ) + if err != nil { + return nil, trace.Wrap(err) + } + + // send the register request including the challenge response + return &proto.RegisterUsingIAMMethodRequest{ + RegisterUsingTokenRequest: registerUsingTokenRequestForParams(token, params), + StsIdentityRequest: signedRequest, + }, nil + }) + if err != nil { + log.WithError(err).Infof("Failed to register %s using regional STS endpoint", params.ID.Role) + return nil, trace.Wrap(err) + } + + log.Infof("Successfully registered %s with IAM method using regional STS endpoint", params.ID.Role) + return certs, nil +} + +// registerUsingAzureMethod is used to register using the Azure join method. It +// is able to register through a proxy or through the auth server directly. +func registerUsingAzureMethod( + ctx context.Context, client joinServiceClient, token string, params RegisterParams, +) (*proto.Certs, error) { + certs, err := client.RegisterUsingAzureMethod(ctx, func(challenge string) (*proto.RegisterUsingAzureMethodRequest, error) { + imds := azure.NewInstanceMetadataClient() + if !imds.IsAvailable(ctx) { + return nil, trace.AccessDenied("could not reach instance metadata. Is Teleport running on an Azure VM?") + } + ad, err := imds.GetAttestedData(ctx, challenge) + if err != nil { + return nil, trace.Wrap(err) + } + accessToken, err := imds.GetAccessToken(ctx, params.AzureParams.ClientID) + if err != nil { + return nil, trace.Wrap(err) + } + + return &proto.RegisterUsingAzureMethodRequest{ + RegisterUsingTokenRequest: registerUsingTokenRequestForParams(token, params), + AttestedData: ad, + AccessToken: accessToken, + }, nil + }) + return certs, trace.Wrap(err) +} + +// registerUsingTPMMethod is used to register using the TPM join method. It +// is able to register through a proxy or through the auth server directly. +func registerUsingTPMMethod( + ctx context.Context, + client joinServiceClient, + token string, + params RegisterParams, +) (*proto.Certs, error) { + log := slog.Default() + + initReq := &proto.RegisterUsingTPMMethodInitialRequest{ + JoinRequest: registerUsingTokenRequestForParams(token, params), + } + + attestation, close, err := tpm.Attest(ctx, log) + if err != nil { + return nil, trace.Wrap(err) + } + defer func() { + if err := close(); err != nil { + log.WarnContext(ctx, "Failed to close TPM", "error", err) + } + }() + + initReq.AttestationParams = tpm.AttestationParametersToProto( + attestation.AttestParams, + ) + // Get the EKKey or EKCert. We want to prefer the EKCert if it is available + // as this is signed by the manufacturer. + switch { + case attestation.Data.EKCert != nil: + log.DebugContext( + ctx, + "Using EKCert for TPM registration", + "ekcert_serial", attestation.Data.EKCert.SerialNumber, + ) + initReq.Ek = &proto.RegisterUsingTPMMethodInitialRequest_EkCert{ + EkCert: attestation.Data.EKCert.Raw, + } + case attestation.Data.EKPub != nil: + log.DebugContext( + ctx, + "Using EKKey for TPM registration", + "ekpub_hash", attestation.Data.EKPubHash, + ) + initReq.Ek = &proto.RegisterUsingTPMMethodInitialRequest_EkKey{ + EkKey: attestation.Data.EKPub, + } + default: + return nil, trace.BadParameter("tpm has neither ekkey or ekcert") + } + + // Submit initial request to the Auth Server. + certs, err := client.RegisterUsingTPMMethod( + ctx, + initReq, + func( + challenge *proto.TPMEncryptedCredential, + ) (*proto.RegisterUsingTPMMethodChallengeResponse, error) { + // Solve the encrypted credential with our AK to prove possession + // and obtain the solution we need to complete the ceremony. + solution, err := attestation.Solve(tpm.EncryptedCredentialFromProto( + challenge, + )) + if err != nil { + return nil, trace.Wrap(err, "activating credential") + } + return &proto.RegisterUsingTPMMethodChallengeResponse{ + Solution: solution, + }, nil + }, + ) + return certs, trace.Wrap(err) +} + +// readCA will read in CA that will be used to validate the certificate that +// the Auth Server presents. +func readCA(path string) (*x509.Certificate, error) { + certBytes, err := utils.ReadPath(path) + if err != nil { + return nil, trace.Wrap(err) + } + cert, err := tlsca.ParseCertificatePEM(certBytes) + if err != nil { + return nil, trace.Wrap(err, "failed to parse certificate at %v", path) + } + return cert, nil +} diff --git a/lib/auth/join/join_test.go b/lib/auth/join/join_test.go new file mode 100644 index 0000000000000..7e32eb55438c2 --- /dev/null +++ b/lib/auth/join/join_test.go @@ -0,0 +1,328 @@ +// Teleport +// Copyright (C) 2024 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package join + +import ( + "context" + "crypto/tls" + "errors" + "net" + "os" + "testing" + "time" + + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" + + headerv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/header/v1" + machineidv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/machineid/v1" + "github.com/gravitational/teleport/api/types" + apievents "github.com/gravitational/teleport/api/types/events" + "github.com/gravitational/teleport/lib/auth" + "github.com/gravitational/teleport/lib/auth/machineid/machineidv1" + "github.com/gravitational/teleport/lib/auth/state" + "github.com/gravitational/teleport/lib/auth/testauthority" + "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/events" + "github.com/gravitational/teleport/lib/fixtures" + "github.com/gravitational/teleport/lib/modules" + "github.com/gravitational/teleport/lib/tlsca" + "github.com/gravitational/teleport/lib/utils" +) + +func TestMain(m *testing.M) { + modules.SetInsecureTestMode(true) + os.Exit(m.Run()) +} + +func newTestTLSServer(t testing.TB) *auth.TestTLSServer { + as, err := auth.NewTestAuthServer(auth.TestAuthServerConfig{ + Dir: t.TempDir(), + Clock: clockwork.NewFakeClockAt(time.Now().Round(time.Second).UTC()), + }) + require.NoError(t, err) + + srv, err := as.NewTestTLSServer() + require.NoError(t, err) + + t.Cleanup(func() { + err := srv.Close() + if errors.Is(err, net.ErrClosed) { + return + } + require.NoError(t, err) + }) + + return srv +} + +// TestRegister_Bot tests that a provision token can be used to generate +// renewable certificates for a non-interactive user. +func TestRegister_Bot(t *testing.T) { + t.Parallel() + ctx := context.Background() + + srv := newTestTLSServer(t) + + bot, err := machineidv1.UpsertBot(ctx, srv.Auth(), &machineidv1pb.Bot{ + Metadata: &headerv1.Metadata{ + Name: "test", + }, + Spec: &machineidv1pb.BotSpec{ + Roles: []string{}, + }, + }, srv.Clock().Now(), "") + require.NoError(t, err) + + later := srv.Clock().Now().Add(4 * time.Hour) + + goodToken := newBotToken(t, "good-token", bot.Metadata.Name, types.RoleBot, later) + expiredToken := newBotToken(t, "expired", bot.Metadata.Name, types.RoleBot, srv.Clock().Now().Add(-1*time.Hour)) + wrongKind := newBotToken(t, "wrong-kind", "", types.RoleNode, later) + wrongUser := newBotToken(t, "wrong-user", "llama", types.RoleBot, later) + invalidToken := newBotToken(t, "this-token-does-not-exist", bot.Metadata.Name, types.RoleBot, later) + + err = srv.Auth().UpsertToken(ctx, goodToken) + require.NoError(t, err) + err = srv.Auth().UpsertToken(ctx, expiredToken) + require.NoError(t, err) + err = srv.Auth().UpsertToken(ctx, wrongKind) + require.NoError(t, err) + err = srv.Auth().UpsertToken(ctx, wrongUser) + require.NoError(t, err) + + privateKey, publicKey, err := testauthority.New().GenerateKeyPair() + require.NoError(t, err) + sshPrivateKey, err := ssh.ParseRawPrivateKey(privateKey) + require.NoError(t, err) + tlsPublicKey, err := tlsca.MarshalPublicKeyFromPrivateKeyPEM(sshPrivateKey) + require.NoError(t, err) + + for _, test := range []struct { + desc string + token types.ProvisionToken + assertErr require.ErrorAssertionFunc + }{ + { + desc: "OK good token", + token: goodToken, + assertErr: require.NoError, + }, + { + desc: "NOK expired token", + token: expiredToken, + assertErr: require.Error, + }, + { + desc: "NOK wrong token kind", + token: wrongKind, + assertErr: require.Error, + }, + { + desc: "NOK token for wrong user", + token: wrongUser, + assertErr: require.Error, + }, + { + desc: "NOK invalid token", + token: invalidToken, + assertErr: require.Error, + }, + } { + t.Run(test.desc, func(t *testing.T) { + start := srv.Clock().Now() + certs, err := Register(ctx, RegisterParams{ + Token: test.token.GetName(), + ID: state.IdentityID{ + Role: types.RoleBot, + }, + AuthServers: []utils.NetAddr{*utils.MustParseAddr(srv.Addr().String())}, + PublicTLSKey: tlsPublicKey, + PublicSSHKey: publicKey, + }) + test.assertErr(t, err) + + if err == nil { + require.NotEmpty(t, certs.SSH) + require.NotEmpty(t, certs.TLS) + + // ensure token was removed + _, err = srv.Auth().GetToken(ctx, test.token.GetName()) + require.True(t, trace.IsNotFound(err), "expected not found error, got %v", err) + + // ensure cert is renewable + x509, err := tlsca.ParseCertificatePEM(certs.TLS) + require.NoError(t, err) + id, err := tlsca.FromSubject(x509.Subject, later) + require.NoError(t, err) + require.True(t, id.Renewable) + + // Check audit event + evts, _, err := srv.Auth().SearchEvents(ctx, events.SearchEventsRequest{ + From: start, + To: srv.Clock().Now(), + EventTypes: []string{events.BotJoinEvent}, + Limit: 1, + Order: types.EventOrderDescending, + }) + require.NoError(t, err) + require.Len(t, evts, 1) + evt, ok := evts[0].(*apievents.BotJoin) + require.True(t, ok) + require.Equal(t, events.BotJoinEvent, evt.Type) + require.Equal(t, events.BotJoinCode, evt.Code) + require.EqualValues(t, types.JoinMethodToken, evt.Method) + } + }) + } +} + +// TestRegister_Bot_Expiry checks that bot certificate expiry can be set, and +// does not exceed the limit. +func TestRegister_Bot_Expiry(t *testing.T) { + t.Parallel() + ctx := context.Background() + + srv := newTestTLSServer(t) + privateKey, publicKey, err := testauthority.New().GenerateKeyPair() + require.NoError(t, err) + sshPrivateKey, err := ssh.ParseRawPrivateKey(privateKey) + require.NoError(t, err) + tlsPublicKey, err := tlsca.MarshalPublicKeyFromPrivateKeyPEM(sshPrivateKey) + require.NoError(t, err) + + validExpires := srv.Clock().Now().Add(time.Hour * 6) + tooGreatExpires := srv.Clock().Now().Add(time.Hour * 24 * 365) + tests := []struct { + name string + requestExpires *time.Time + expectTTL time.Duration + }{ + { + name: "unspecified defaults", + requestExpires: nil, + expectTTL: defaults.DefaultRenewableCertTTL, + }, + { + name: "valid value specified", + requestExpires: &validExpires, + expectTTL: time.Hour * 6, + }, + { + name: "value exceeding limit specified", + requestExpires: &tooGreatExpires, + // MaxSessionTTL set in createBotRole is 12 hours, so this cap will + // apply instead of the defaults.MaxRenewableCertTTL specified + // in generateInitialBotCerts. + expectTTL: 12 * time.Hour, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + botName := t.Name() + _, err = machineidv1.UpsertBot(ctx, srv.Auth(), &machineidv1pb.Bot{ + Metadata: &headerv1.Metadata{ + Name: botName, + }, + Spec: &machineidv1pb.BotSpec{ + Roles: []string{}, + Traits: []*machineidv1pb.Trait{}, + }, + }, srv.Clock().Now(), "") + require.NoError(t, err) + tok := newBotToken(t, t.Name(), botName, types.RoleBot, srv.Clock().Now().Add(time.Hour)) + require.NoError(t, srv.Auth().UpsertToken(ctx, tok)) + + certs, err := Register(ctx, RegisterParams{ + Token: tok.GetName(), + ID: state.IdentityID{ + Role: types.RoleBot, + }, + AuthServers: []utils.NetAddr{*utils.MustParseAddr(srv.Addr().String())}, + PublicTLSKey: tlsPublicKey, + PublicSSHKey: publicKey, + Expires: tt.requestExpires, + }) + require.NoError(t, err) + x509, err := tlsca.ParseCertificatePEM(certs.TLS) + require.NoError(t, err) + id, err := tlsca.FromSubject(x509.Subject, x509.NotAfter) + require.NoError(t, err) + + ttl := id.Expires.Sub(srv.Clock().Now()) + require.Equal(t, tt.expectTTL, ttl) + }) + } +} + +func newBotToken(t *testing.T, tokenName, botName string, role types.SystemRole, expiry time.Time) types.ProvisionToken { + t.Helper() + token, err := types.NewProvisionTokenFromSpec(tokenName, expiry, types.ProvisionTokenSpecV2{ + Roles: []types.SystemRole{role}, + BotName: botName, + }) + require.NoError(t, err, "could not create bot token") + return token +} + +func TestVerifyALPNUpgradedConn(t *testing.T) { + t.Parallel() + + srv := newTestTLSServer(t) + proxy, err := auth.NewServerIdentity(srv.Auth(), "test-proxy", types.RoleProxy) + require.NoError(t, err) + + tests := []struct { + name string + serverCert []byte + clock clockwork.Clock + checkError require.ErrorAssertionFunc + }{ + { + name: "proxy verified", + serverCert: proxy.TLSCertBytes, + clock: srv.Clock(), + checkError: require.NoError, + }, + { + name: "proxy expired", + serverCert: proxy.TLSCertBytes, + clock: clockwork.NewFakeClockAt(srv.Clock().Now().Add(defaults.CATTL + time.Hour)), + checkError: require.Error, + }, + { + name: "not proxy", + serverCert: []byte(fixtures.TLSCACertPEM), + clock: srv.Clock(), + checkError: require.Error, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + serverCert, err := utils.ReadCertificates(test.serverCert) + require.NoError(t, err) + + test.checkError(t, verifyALPNUpgradedConn(test.clock)(tls.ConnectionState{ + PeerCertificates: serverCert, + })) + }) + } +} diff --git a/lib/auth/join_iam.go b/lib/auth/join_iam.go index b52a8b6b35b81..6668a3449b8b5 100644 --- a/lib/auth/join_iam.go +++ b/lib/auth/join_iam.go @@ -27,12 +27,7 @@ import ( "net/http" "net/url" "slices" - "strings" - awssdk "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/endpoints" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/sts" "github.com/coreos/go-semver/semver" "github.com/gravitational/trace" @@ -40,7 +35,7 @@ import ( "github.com/gravitational/teleport/api/client" "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/types" - cloudaws "github.com/gravitational/teleport/lib/cloud/imds/aws" + "github.com/gravitational/teleport/lib/auth/join/iam" "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/lib/utils/aws" ) @@ -71,7 +66,7 @@ const ( // against a static list of known valid endpoints. We will need to update this // list as AWS adds new regions. func validateSTSHost(stsHost string, cfg *iamRegisterConfig) error { - valid := slices.Contains(validSTSEndpoints, stsHost) + valid := slices.Contains(iam.ValidSTSEndpoints(), stsHost) if !valid { return trace.AccessDenied("IAM join request uses unknown STS host %q. "+ "This could mean that the Teleport Node attempting to join the cluster is "+ @@ -81,10 +76,10 @@ func validateSTSHost(stsHost string, cfg *iamRegisterConfig) error { "Following is the list of valid STS endpoints known to this auth server. "+ "If a legitimate STS endpoint is not included, please file an issue at "+ "https://github.com/gravitational/teleport. %v", - stsHost, validSTSEndpoints) + stsHost, iam.ValidSTSEndpoints()) } - if cfg.fips && !slices.Contains(fipsSTSEndpoints, stsHost) { + if cfg.fips && !slices.Contains(iam.FIPSSTSEndpoints(), stsHost) { return trace.AccessDenied("node selected non-FIPS STS endpoint (%s) for the IAM join method", stsHost) } @@ -393,124 +388,3 @@ func (a *Server) RegisterUsingIAMMethod( certs, err = a.generateCerts(ctx, provisionToken, req.RegisterUsingTokenRequest, nil) return certs, trace.Wrap(err) } - -type stsIdentityRequestConfig struct { - regionalEndpointOption endpoints.STSRegionalEndpoint - fipsEndpointOption endpoints.FIPSEndpointState -} - -type stsIdentityRequestOption func(cfg *stsIdentityRequestConfig) - -func withRegionalEndpoint(useRegionalEndpoint bool) stsIdentityRequestOption { - return func(cfg *stsIdentityRequestConfig) { - if useRegionalEndpoint { - cfg.regionalEndpointOption = endpoints.RegionalSTSEndpoint - } else { - cfg.regionalEndpointOption = endpoints.LegacySTSEndpoint - } - } -} - -func withFIPSEndpoint(useFIPS bool) stsIdentityRequestOption { - return func(cfg *stsIdentityRequestConfig) { - if useFIPS { - cfg.fipsEndpointOption = endpoints.FIPSEndpointStateEnabled - } else { - cfg.fipsEndpointOption = endpoints.FIPSEndpointStateDisabled - } - } -} - -// createSignedSTSIdentityRequest is called on the client side and returns an -// sts:GetCallerIdentity request signed with the local AWS credentials -func createSignedSTSIdentityRequest(ctx context.Context, challenge string, opts ...stsIdentityRequestOption) ([]byte, error) { - cfg := &stsIdentityRequestConfig{} - for _, opt := range opts { - opt(cfg) - } - - stsClient, err := newSTSClient(ctx, cfg) - if err != nil { - return nil, trace.Wrap(err) - } - - req, _ := stsClient.GetCallerIdentityRequest(&sts.GetCallerIdentityInput{}) - // set challenge header - req.HTTPRequest.Header.Set(challengeHeaderKey, challenge) - // request json for simpler parsing - req.HTTPRequest.Header.Set("Accept", "application/json") - // sign the request, including headers - if err := req.Sign(); err != nil { - return nil, trace.Wrap(err) - } - // write the signed HTTP request to a buffer - var signedRequest bytes.Buffer - if err := req.HTTPRequest.Write(&signedRequest); err != nil { - return nil, trace.Wrap(err) - } - return signedRequest.Bytes(), nil -} - -func newSTSClient(ctx context.Context, cfg *stsIdentityRequestConfig) (*sts.STS, error) { - awsConfig := awssdk.Config{ - UseFIPSEndpoint: cfg.fipsEndpointOption, - STSRegionalEndpoint: cfg.regionalEndpointOption, - } - sess, err := session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - Config: awsConfig, - }) - if err != nil { - return nil, trace.Wrap(err) - } - - stsClient := sts.New(sess) - - if slices.Contains(globalSTSEndpoints, strings.TrimPrefix(stsClient.Endpoint, "https://")) { - // If the caller wants to use the regional endpoint but it was not resolved - // from the environment, attempt to find the region from the EC2 IMDS - if cfg.regionalEndpointOption == endpoints.RegionalSTSEndpoint { - region, err := getEC2LocalRegion(ctx) - if err != nil { - return nil, trace.Wrap(err, "failed to resolve local AWS region from environment or IMDS") - } - stsClient = sts.New(sess, awssdk.NewConfig().WithRegion(region)) - } else { - log.Info("Attempting to use the global STS endpoint for the IAM join method. " + - "This will probably fail in non-default AWS partitions such as China or GovCloud, or if FIPS mode is enabled. " + - "Consider setting the AWS_REGION environment variable, setting the region in ~/.aws/config, or enabling the IMDSv2.") - } - } - - if cfg.fipsEndpointOption == endpoints.FIPSEndpointStateEnabled && - !slices.Contains(validSTSEndpoints, strings.TrimPrefix(stsClient.Endpoint, "https://")) { - // The AWS SDK will generate invalid endpoints when attempting to - // resolve the FIPS endpoint for a region that does not have one. - // In this case, try to use the FIPS endpoint in us-east-1. This should - // work for all regions in the standard partition. In GovCloud, we should - // not hit this because all regional endpoints support FIPS. In China or - // other partitions, this will fail, and FIPS mode will not be supported. - log.Infof("AWS SDK resolved FIPS STS endpoint %s, which does not appear to be valid. "+ - "Attempting to use the FIPS STS endpoint for us-east-1.", - stsClient.Endpoint) - stsClient = sts.New(sess, awssdk.NewConfig().WithRegion("us-east-1")) - } - - return stsClient, nil -} - -// getEC2LocalRegion returns the AWS region this EC2 instance is running in, or -// a NotFound error if the EC2 IMDS is unavailable. -func getEC2LocalRegion(ctx context.Context) (string, error) { - imdsClient, err := cloudaws.NewInstanceMetadataClient(ctx) - if err != nil { - return "", trace.Wrap(err) - } - - if !imdsClient.IsAvailable(ctx) { - return "", trace.NotFound("IMDS is unavailable") - } - - region, err := imdsClient.GetRegion(ctx) - return region, trace.Wrap(err) -} diff --git a/lib/auth/join_test.go b/lib/auth/join_test.go index 9e207f9667a3f..6ce94edcc0ab6 100644 --- a/lib/auth/join_test.go +++ b/lib/auth/join_test.go @@ -34,6 +34,7 @@ import ( "github.com/gravitational/teleport/api/types" apievents "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/api/utils/sshutils" + "github.com/gravitational/teleport/lib/auth/join" "github.com/gravitational/teleport/lib/auth/machineid/machineidv1" "github.com/gravitational/teleport/lib/auth/state" "github.com/gravitational/teleport/lib/auth/testauthority" @@ -371,7 +372,7 @@ func TestRegister_Bot(t *testing.T) { } { t.Run(test.desc, func(t *testing.T) { start := srv.Clock().Now() - certs, err := Register(ctx, RegisterParams{ + certs, err := join.Register(ctx, join.RegisterParams{ Token: test.token.GetName(), ID: state.IdentityID{ Role: types.RoleBot, @@ -474,7 +475,7 @@ func TestRegister_Bot_Expiry(t *testing.T) { tok := newBotToken(t, t.Name(), botName, types.RoleBot, srv.Clock().Now().Add(time.Hour)) require.NoError(t, srv.Auth().UpsertToken(ctx, tok)) - certs, err := Register(ctx, RegisterParams{ + certs, err := join.Register(ctx, join.RegisterParams{ Token: tok.GetName(), ID: state.IdentityID{ Role: types.RoleBot, diff --git a/lib/auth/register.go b/lib/auth/register.go index 24e48594d6bc4..b425bdf8c3ff2 100644 --- a/lib/auth/register.go +++ b/lib/auth/register.go @@ -20,44 +20,15 @@ package auth import ( "context" - "crypto/tls" - "crypto/x509" - "log/slog" - "os" - "slices" - "time" "github.com/gravitational/trace" - "github.com/jonboulle/clockwork" - "golang.org/x/net/http2" - "google.golang.org/grpc" - "google.golang.org/grpc/credentials" - "github.com/gravitational/teleport" - "github.com/gravitational/teleport/api/breaker" - "github.com/gravitational/teleport/api/client" "github.com/gravitational/teleport/api/client/proto" - "github.com/gravitational/teleport/api/constants" - apidefaults "github.com/gravitational/teleport/api/defaults" - "github.com/gravitational/teleport/api/metadata" - "github.com/gravitational/teleport/api/observability/tracing" "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/api/utils/aws" "github.com/gravitational/teleport/lib/auth/authclient" "github.com/gravitational/teleport/lib/auth/native" "github.com/gravitational/teleport/lib/auth/state" - "github.com/gravitational/teleport/lib/circleci" - "github.com/gravitational/teleport/lib/cloud/imds/azure" - "github.com/gravitational/teleport/lib/cloud/imds/gcp" "github.com/gravitational/teleport/lib/defaults" - "github.com/gravitational/teleport/lib/githubactions" - "github.com/gravitational/teleport/lib/gitlab" - "github.com/gravitational/teleport/lib/kubernetestoken" - "github.com/gravitational/teleport/lib/spacelift" - "github.com/gravitational/teleport/lib/srv/alpnproxy/common" - "github.com/gravitational/teleport/lib/tlsca" - "github.com/gravitational/teleport/lib/tpm" - "github.com/gravitational/teleport/lib/utils" ) // LocalRegister is used to generate host keys when a node or proxy is running @@ -104,739 +75,6 @@ func LocalRegister(id state.IdentityID, authServer *Server, additionalPrincipals return identity, nil } -// AzureParams is the parameters specific to the azure join method. -type AzureParams struct { - // ClientID is the client ID of the managed identity for Teleport to assume - // when authenticating a node. - ClientID string -} - -// RegisterParams specifies parameters -// for first time register operation with auth server -type RegisterParams struct { - // Token is a secure token to join the cluster - Token string - // ID is identity ID - ID state.IdentityID - // AuthServers is a list of auth servers to dial - AuthServers []utils.NetAddr - // ProxyServer is a proxy server to dial - ProxyServer utils.NetAddr - // AdditionalPrincipals is a list of additional principals to dial - AdditionalPrincipals []string - // DNSNames is a list of DNS names to add to x509 certificate - DNSNames []string - // PublicTLSKey is a server's public key to sign - PublicTLSKey []byte - // PublicSSHKey is a server's public SSH key to sign - PublicSSHKey []byte - // CipherSuites is a list of cipher suites to use for TLS client connection - CipherSuites []uint16 - // CAPins are the SKPI hashes of the CAs used to verify the Auth Server. - CAPins []string - // CAPath is the path to the CA file. - CAPath string - // GetHostCredentials is a client that can fetch host credentials. - GetHostCredentials HostCredentials - // Clock specifies the time provider. Will be used to override the time anchor - // for TLS certificate verification. - // Defaults to real clock if unspecified - Clock clockwork.Clock - // JoinMethod is the joining method used for this register request. - JoinMethod types.JoinMethod - // ec2IdentityDocument is used for Simplified Node Joining to prove the - // identity of a joining EC2 instance. - ec2IdentityDocument []byte - // AzureParams is the parameters specific to the azure join method. - AzureParams AzureParams - // CircuitBreakerConfig defines how the circuit breaker should behave. - CircuitBreakerConfig breaker.Config - // FIPS means FedRAMP/FIPS 140-2 compliant configuration was requested. - FIPS bool - // IDToken is a token retrieved from a workload identity provider for - // certain join types e.g GitHub, Google. - IDToken string - // Expires is an optional field for bots that specifies a time that the - // certificates that are returned by registering should expire at. - // It should not be specified for non-bot registrations. - Expires *time.Time - // Insecure trusts the certificates from the Auth Server or Proxy during registration without verification. - Insecure bool -} - -func (r *RegisterParams) checkAndSetDefaults() error { - if r.Clock == nil { - r.Clock = clockwork.NewRealClock() - } - - if err := r.verifyAuthOrProxyAddress(); err != nil { - return trace.BadParameter("no auth or proxy servers set") - } - - return nil -} - -func (r *RegisterParams) verifyAuthOrProxyAddress() error { - haveAuthServers := len(r.AuthServers) > 0 - haveProxyServer := !r.ProxyServer.IsEmpty() - - if !haveAuthServers && !haveProxyServer { - return trace.BadParameter("no auth or proxy servers set") - } - - if haveAuthServers && haveProxyServer { - return trace.BadParameter("only one of auth or proxy server should be set") - } - - return nil -} - -// CredGetter is an interface for a client that can be used to get host -// credentials. This interface is needed because lib/client can not be imported -// in lib/auth due to circular imports. -type HostCredentials func(context.Context, string, bool, types.RegisterUsingTokenRequest) (*proto.Certs, error) - -// Register is used to generate host keys when a node or proxy are running on -// different hosts than the auth server. This method requires provisioning -// tokens to prove a valid auth server was used to issue the joining request -// as well as a method for the node to validate the auth server. -func Register(ctx context.Context, params RegisterParams) (certs *proto.Certs, err error) { - ctx, span := tracer.Start(ctx, "Register") - defer func() { tracing.EndSpan(span, err) }() - - if err := params.checkAndSetDefaults(); err != nil { - return nil, trace.Wrap(err) - } - // Read in the token. The token can either be passed in or come from a file - // on disk. - token, err := utils.TryReadValueAsFile(params.Token) - if err != nil { - return nil, trace.Wrap(err) - } - - // add EC2 Identity Document to params if required for given join method - switch params.JoinMethod { - case types.JoinMethodEC2: - if !aws.IsEC2NodeID(params.ID.HostUUID) { - return nil, trace.BadParameter( - `Host ID %q is not valid when using the EC2 join method, `+ - `try removing the "host_uuid" file in your teleport data dir `+ - `(e.g. /var/lib/teleport/host_uuid)`, - params.ID.HostUUID) - } - params.ec2IdentityDocument, err = utils.GetRawEC2IdentityDocument(ctx) - if err != nil { - return nil, trace.Wrap(err) - } - case types.JoinMethodGitHub: - params.IDToken, err = githubactions.NewIDTokenSource().GetIDToken(ctx) - if err != nil { - return nil, trace.Wrap(err) - } - case types.JoinMethodGitLab: - params.IDToken, err = gitlab.NewIDTokenSource(os.Getenv).GetIDToken() - if err != nil { - return nil, trace.Wrap(err) - } - case types.JoinMethodCircleCI: - params.IDToken, err = circleci.GetIDToken(os.Getenv) - if err != nil { - return nil, trace.Wrap(err) - } - case types.JoinMethodKubernetes: - params.IDToken, err = kubernetestoken.GetIDToken(os.Getenv, os.ReadFile) - if err != nil { - return nil, trace.Wrap(err) - } - case types.JoinMethodGCP: - params.IDToken, err = gcp.GetIDToken(ctx) - if err != nil { - return nil, trace.Wrap(err) - } - case types.JoinMethodSpacelift: - params.IDToken, err = spacelift.NewIDTokenSource(os.Getenv).GetIDToken() - if err != nil { - return nil, trace.Wrap(err) - } - } - - type registerMethod struct { - call func(ctx context.Context, token string, params RegisterParams) (*proto.Certs, error) - desc string - } - - registerThroughAuth := registerMethod{registerThroughAuth, "with auth server"} - registerThroughProxy := registerMethod{registerThroughProxy, "via proxy server"} - - registerMethods := []registerMethod{registerThroughAuth, registerThroughProxy} - - if !params.ProxyServer.IsEmpty() { - log.WithField("proxy-server", params.ProxyServer).Debugf("Registering node to the cluster.") - - registerMethods = []registerMethod{registerThroughProxy} - - if proxyServerIsAuth(params.ProxyServer) { - log.Debugf("The specified proxy server appears to be an auth server.") - } - } else { - log.WithField("auth-servers", params.AuthServers).Debugf("Registering node to the cluster.") - - if params.GetHostCredentials == nil { - log.Debugf("Missing client, it is not possible to register through proxy.") - registerMethods = []registerMethod{registerThroughAuth} - } else if authServerIsProxy(params.AuthServers) { - log.Debugf("The first specified auth server appears to be a proxy.") - registerMethods = []registerMethod{registerThroughProxy, registerThroughAuth} - } - } - - var collectedErrs []error - for _, method := range registerMethods { - log.Infof("Attempting registration %s.", method.desc) - certs, err := method.call(ctx, token, params) - if err != nil { - collectedErrs = append(collectedErrs, err) - log.WithError(err).Debugf("Registration %s failed.", method.desc) - continue - } - log.Infof("Successfully registered %s.", method.desc) - return certs, nil - } - return nil, trace.NewAggregate(collectedErrs...) -} - -// authServerIsProxy returns true if the first specified auth server -// to register with appears to be a proxy. -func authServerIsProxy(servers []utils.NetAddr) bool { - if len(servers) == 0 { - return false - } - port := servers[0].Port(0) - return port == defaults.HTTPListenPort || port == teleport.StandardHTTPSPort -} - -// proxyServerIsAuth returns true if the address given to register with -// appears to be an auth server. -func proxyServerIsAuth(server utils.NetAddr) bool { - port := server.Port(0) - return port == defaults.AuthListenPort -} - -// registerThroughProxy is used to register through the proxy server. -func registerThroughProxy( - ctx context.Context, - token string, - params RegisterParams, -) (certs *proto.Certs, err error) { - ctx, span := tracer.Start(ctx, "registerThroughProxy") - defer func() { tracing.EndSpan(span, err) }() - - switch params.JoinMethod { - case types.JoinMethodIAM, types.JoinMethodAzure, types.JoinMethodTPM: - // IAM and Azure join methods require gRPC client - conn, err := proxyJoinServiceConn(ctx, params, params.Insecure) - if err != nil { - return nil, trace.Wrap(err) - } - defer conn.Close() - - joinServiceClient := client.NewJoinServiceClient(proto.NewJoinServiceClient(conn)) - switch params.JoinMethod { - case types.JoinMethodIAM: - certs, err = registerUsingIAMMethod(ctx, joinServiceClient, token, params) - case types.JoinMethodAzure: - certs, err = registerUsingAzureMethod(ctx, joinServiceClient, token, params) - case types.JoinMethodTPM: - certs, err = registerUsingTPMMethod(ctx, joinServiceClient, token, params) - default: - return nil, trace.BadParameter("unhandled join method %q", params.JoinMethod) - } - - if err != nil { - return nil, trace.Wrap(err) - } - default: - // The rest of the join methods use GetHostCredentials function passed through - // params to call proxy HTTP endpoint - var err error - certs, err = params.GetHostCredentials(ctx, - getHostAddresses(params)[0], - params.Insecure, - types.RegisterUsingTokenRequest{ - Token: token, - HostID: params.ID.HostUUID, - NodeName: params.ID.NodeName, - Role: params.ID.Role, - AdditionalPrincipals: params.AdditionalPrincipals, - DNSNames: params.DNSNames, - PublicTLSKey: params.PublicTLSKey, - PublicSSHKey: params.PublicSSHKey, - EC2IdentityDocument: params.ec2IdentityDocument, - IDToken: params.IDToken, - Expires: params.Expires, - }) - if err != nil { - return nil, trace.Wrap(err) - } - } - return certs, nil -} - -func getHostAddresses(params RegisterParams) []string { - if !params.ProxyServer.IsEmpty() { - return []string{params.ProxyServer.String()} - } - - return utils.NetAddrsToStrings(params.AuthServers) -} - -// registerThroughAuth is used to register through the auth server. -func registerThroughAuth( - ctx context.Context, token string, params RegisterParams, -) (certs *proto.Certs, err error) { - ctx, span := tracer.Start(ctx, "registerThroughAuth") - defer func() { tracing.EndSpan(span, err) }() - - var client *authclient.Client - // Build a client for the Auth Server with different certificate validation - // depending on the configured values for Insecure, CAPins and CAPath. - switch { - case params.Insecure: - log.Warnf("Insecure mode enabled. Auth Server cert will not be validated and CAPins and CAPath value will be ignored.") - client, err = insecureRegisterClient(params) - case len(params.CAPins) != 0: - // CAPins takes precedence over CAPath - client, err = pinRegisterClient(ctx, params) - case params.CAPath != "": - client, err = caPathRegisterClient(params) - default: - // We fall back to insecure mode here - this is a little odd but is - // necessary to preserve the behavior of registration. At a later date, - // we may consider making this an error asking the user to provide - // Insecure, CAPins or CAPath. - client, err = insecureRegisterClient(params) - } - if err != nil { - return nil, trace.Wrap(err) - } - defer client.Close() - - switch params.JoinMethod { - // IAM and Azure methods use unique gRPC endpoints - case types.JoinMethodIAM: - certs, err = registerUsingIAMMethod(ctx, client, token, params) - case types.JoinMethodAzure: - certs, err = registerUsingAzureMethod(ctx, client, token, params) - case types.JoinMethodTPM: - certs, err = registerUsingTPMMethod(ctx, client, token, params) - default: - // non-IAM join methods use HTTP endpoint - // Get the SSH and X509 certificates for a node. - certs, err = client.RegisterUsingToken( - ctx, - &types.RegisterUsingTokenRequest{ - Token: token, - HostID: params.ID.HostUUID, - NodeName: params.ID.NodeName, - Role: params.ID.Role, - AdditionalPrincipals: params.AdditionalPrincipals, - DNSNames: params.DNSNames, - PublicTLSKey: params.PublicTLSKey, - PublicSSHKey: params.PublicSSHKey, - EC2IdentityDocument: params.ec2IdentityDocument, - IDToken: params.IDToken, - Expires: params.Expires, - }) - } - return certs, trace.Wrap(err) -} - -// proxyJoinServiceConn attempts to connect to the join service running on the -// proxy. The Proxy's TLS cert will be verified using the host's root CA pool -// (PKI) unless the --insecure flag was passed. -func proxyJoinServiceConn( - ctx context.Context, params RegisterParams, insecure bool, -) (*grpc.ClientConn, error) { - tlsConfig := utils.TLSConfig(params.CipherSuites) - tlsConfig.Time = params.Clock.Now - // set NextProtos for TLS routing, the actual protocol will be h2 - tlsConfig.NextProtos = []string{string(common.ProtocolProxyGRPCInsecure), http2.NextProtoTLS} - - if insecure { - tlsConfig.InsecureSkipVerify = true - log.Warnf("Joining cluster without validating the identity of the Proxy Server.") - } - - // Check if proxy is behind a load balancer. If so, the connection upgrade - // will verify the load balancer's cert using system cert pool. This - // provides the same level of security as the client only verifies Proxy's - // web cert against system cert pool when connection upgrade is not - // required. - // - // With the ALPN connection upgrade, the tunneled TLS Routing request will - // skip verify as the Proxy server will present its host cert which is not - // fully verifiable at this point since the client does not have the host - // CAs yet before completing registration. - alpnConnUpgrade := client.IsALPNConnUpgradeRequired(ctx, getHostAddresses(params)[0], insecure) - if alpnConnUpgrade && !insecure { - tlsConfig.InsecureSkipVerify = true - tlsConfig.VerifyConnection = verifyALPNUpgradedConn(params.Clock) - } - - dialer := client.NewDialer( - ctx, - apidefaults.DefaultIdleTimeout, - apidefaults.DefaultIOTimeout, - client.WithInsecureSkipVerify(insecure), - client.WithALPNConnUpgrade(alpnConnUpgrade), - ) - - conn, err := grpc.Dial( - getHostAddresses(params)[0], - grpc.WithContextDialer(client.GRPCContextDialer(dialer)), - grpc.WithUnaryInterceptor(metadata.UnaryClientInterceptor), - grpc.WithStreamInterceptor(metadata.StreamClientInterceptor), - grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)), - ) - return conn, trace.Wrap(err) -} - -// verifyALPNUpgradedConn is a tls.Config.VerifyConnection callback function -// used by the tunneled TLS Routing request to verify the host cert of a Proxy -// behind a L7 load balancer. -// -// Since the client has not obtained the cluster CAs at this point, the -// presented cert cannot be fully verified yet. For now, this function only -// checks if "teleport.cluster.local" is present as one of the DNS names and -// verifies the cert is not expired. -func verifyALPNUpgradedConn(clock clockwork.Clock) func(tls.ConnectionState) error { - return func(server tls.ConnectionState) error { - for _, cert := range server.PeerCertificates { - if slices.Contains(cert.DNSNames, constants.APIDomain) && clock.Now().Before(cert.NotAfter) { - return nil - } - } - return trace.AccessDenied("server is not a Teleport proxy or server certificate is expired") - } -} - -// insecureRegisterClient attempts to connects to the Auth Server using the -// CA on disk. If no CA is found on disk, Teleport will not verify the Auth -// Server it is connecting to. -func insecureRegisterClient(params RegisterParams) (*authclient.Client, error) { - log.Warnf("Joining cluster without validating the identity of the Auth " + - "Server. This may open you up to a Man-In-The-Middle (MITM) attack if an " + - "attacker can gain privileged network access. To remedy this, use the CA pin " + - "value provided when join token was generated to validate the identity of " + - "the Auth Server or point to a valid Certificate via the CA Path option.") - - tlsConfig := utils.TLSConfig(params.CipherSuites) - tlsConfig.Time = params.Clock.Now - tlsConfig.InsecureSkipVerify = true - - client, err := authclient.NewClient(client.Config{ - Addrs: getHostAddresses(params), - Credentials: []client.Credentials{ - client.LoadTLS(tlsConfig), - }, - CircuitBreakerConfig: params.CircuitBreakerConfig, - }) - if err != nil { - return nil, trace.Wrap(err) - } - - return client, nil -} - -// readCA will read in CA that will be used to validate the certificate that -// the Auth Server presents. -func readCA(path string) (*x509.Certificate, error) { - certBytes, err := utils.ReadPath(path) - if err != nil { - return nil, trace.Wrap(err) - } - cert, err := tlsca.ParseCertificatePEM(certBytes) - if err != nil { - return nil, trace.Wrap(err, "failed to parse certificate at %v", path) - } - return cert, nil -} - -// pinRegisterClient first connects to the Auth Server using a insecure -// connection to fetch the root CA. If the root CA matches the provided CA -// pin, a connection will be re-established and the root CA will be used to -// validate the certificate presented. If both conditions hold true, then we -// know we are connecting to the expected Auth Server. -func pinRegisterClient( - ctx context.Context, params RegisterParams, -) (*authclient.Client, error) { - // Build a insecure client to the Auth Server. This is safe because even if - // an attacker were to MITM this connection the CA pin will not match below. - tlsConfig := utils.TLSConfig(params.CipherSuites) - tlsConfig.InsecureSkipVerify = true - tlsConfig.Time = params.Clock.Now - authClient, err := authclient.NewClient(client.Config{ - Addrs: getHostAddresses(params), - Credentials: []client.Credentials{ - client.LoadTLS(tlsConfig), - }, - CircuitBreakerConfig: params.CircuitBreakerConfig, - }) - if err != nil { - return nil, trace.Wrap(err) - } - defer authClient.Close() - - // Fetch the root CA from the Auth Server. The NOP role has access to the - // GetClusterCACert endpoint. - localCA, err := authClient.GetClusterCACert(ctx) - if err != nil { - return nil, trace.Wrap(err) - } - certs, err := tlsca.ParseCertificatePEMs(localCA.TLSCA) - if err != nil { - return nil, trace.Wrap(err) - } - - // Check that the SPKI pin matches the CA we fetched over a insecure - // connection. This makes sure the CA fetched over a insecure connection is - // in-fact the expected CA. - err = utils.CheckSPKI(params.CAPins, certs) - if err != nil { - return nil, trace.Wrap(err) - } - - for _, cert := range certs { - // Check that the fetched CA is valid at the current time. - err = utils.VerifyCertificateExpiry(cert, params.Clock) - if err != nil { - return nil, trace.Wrap(err) - } - - } - log.Infof("Joining remote cluster %v with CA pin.", certs[0].Subject.CommonName) - - // Create another client, but this time with the CA provided to validate - // that the Auth Server was issued a certificate by the same CA. - tlsConfig = utils.TLSConfig(params.CipherSuites) - tlsConfig.Time = params.Clock.Now - certPool := x509.NewCertPool() - for _, cert := range certs { - certPool.AddCert(cert) - } - tlsConfig.RootCAs = certPool - - authClient, err = authclient.NewClient(client.Config{ - Addrs: getHostAddresses(params), - Credentials: []client.Credentials{ - client.LoadTLS(tlsConfig), - }, - CircuitBreakerConfig: params.CircuitBreakerConfig, - }) - if err != nil { - return nil, trace.Wrap(err) - } - - return authClient, nil -} - -func caPathRegisterClient(params RegisterParams) (*authclient.Client, error) { - tlsConfig := utils.TLSConfig(params.CipherSuites) - tlsConfig.Time = params.Clock.Now - - cert, err := readCA(params.CAPath) - if err != nil && !trace.IsNotFound(err) { - return nil, trace.Wrap(err) - } - - // If we're unable to read the file at CAPath, we fall back to insecure - // registration. This preserves the existing behavior. At a later date, - // we may wish to consider changing this to return an error - but this is a - // breaking change. - if trace.IsNotFound(err) { - log.Warnf("Falling back to insecurely joining because a missing or empty CA Path was provided.") - return insecureRegisterClient(params) - } - - certPool := x509.NewCertPool() - certPool.AddCert(cert) - tlsConfig.RootCAs = certPool - - log.Infof("Joining remote cluster %v, validating connection with certificate on disk.", cert.Subject.CommonName) - - client, err := authclient.NewClient(client.Config{ - Addrs: getHostAddresses(params), - Credentials: []client.Credentials{ - client.LoadTLS(tlsConfig), - }, - CircuitBreakerConfig: params.CircuitBreakerConfig, - }) - if err != nil { - return nil, trace.Wrap(err) - } - - return client, nil -} - -type joinServiceClient interface { - RegisterUsingIAMMethod(ctx context.Context, challengeResponse client.RegisterIAMChallengeResponseFunc) (*proto.Certs, error) - RegisterUsingAzureMethod(ctx context.Context, challengeResponse client.RegisterAzureChallengeResponseFunc) (*proto.Certs, error) - RegisterUsingTPMMethod( - ctx context.Context, - initReq *proto.RegisterUsingTPMMethodInitialRequest, - solveChallenge client.RegisterTPMChallengeResponseFunc, - ) (*proto.Certs, error) -} - -func registerUsingTokenRequestForParams(token string, params RegisterParams) *types.RegisterUsingTokenRequest { - return &types.RegisterUsingTokenRequest{ - Token: token, - HostID: params.ID.HostUUID, - NodeName: params.ID.NodeName, - Role: params.ID.Role, - AdditionalPrincipals: params.AdditionalPrincipals, - DNSNames: params.DNSNames, - PublicTLSKey: params.PublicTLSKey, - PublicSSHKey: params.PublicSSHKey, - Expires: params.Expires, - } -} - -// registerUsingIAMMethod is used to register using the IAM join method. It is -// able to register through a proxy or through the auth server directly. -func registerUsingIAMMethod( - ctx context.Context, joinServiceClient joinServiceClient, token string, params RegisterParams, -) (*proto.Certs, error) { - log.Infof("Attempting to register %s with IAM method using regional STS endpoint", params.ID.Role) - // Call RegisterUsingIAMMethod and pass a callback to respond to the challenge with a signed join request. - certs, err := joinServiceClient.RegisterUsingIAMMethod(ctx, func(challenge string) (*proto.RegisterUsingIAMMethodRequest, error) { - // create the signed sts:GetCallerIdentity request and include the challenge - signedRequest, err := createSignedSTSIdentityRequest(ctx, challenge, - withFIPSEndpoint(params.FIPS), - withRegionalEndpoint(true), - ) - if err != nil { - return nil, trace.Wrap(err) - } - - // send the register request including the challenge response - return &proto.RegisterUsingIAMMethodRequest{ - RegisterUsingTokenRequest: registerUsingTokenRequestForParams(token, params), - StsIdentityRequest: signedRequest, - }, nil - }) - if err != nil { - log.WithError(err).Infof("Failed to register %s using regional STS endpoint", params.ID.Role) - return nil, trace.Wrap(err) - } - - log.Infof("Successfully registered %s with IAM method using regional STS endpoint", params.ID.Role) - return certs, nil -} - -// registerUsingAzureMethod is used to register using the Azure join method. It -// is able to register through a proxy or through the auth server directly. -func registerUsingAzureMethod( - ctx context.Context, client joinServiceClient, token string, params RegisterParams, -) (*proto.Certs, error) { - certs, err := client.RegisterUsingAzureMethod(ctx, func(challenge string) (*proto.RegisterUsingAzureMethodRequest, error) { - imds := azure.NewInstanceMetadataClient() - if !imds.IsAvailable(ctx) { - return nil, trace.AccessDenied("could not reach instance metadata. Is Teleport running on an Azure VM?") - } - ad, err := imds.GetAttestedData(ctx, challenge) - if err != nil { - return nil, trace.Wrap(err) - } - accessToken, err := imds.GetAccessToken(ctx, params.AzureParams.ClientID) - if err != nil { - return nil, trace.Wrap(err) - } - - return &proto.RegisterUsingAzureMethodRequest{ - RegisterUsingTokenRequest: registerUsingTokenRequestForParams(token, params), - AttestedData: ad, - AccessToken: accessToken, - }, nil - }) - return certs, trace.Wrap(err) -} - -// registerUsingTPMMethod is used to register using the TPM join method. It -// is able to register through a proxy or through the auth server directly. -func registerUsingTPMMethod( - ctx context.Context, - client joinServiceClient, - token string, - params RegisterParams, -) (*proto.Certs, error) { - log := slog.Default() - - initReq := &proto.RegisterUsingTPMMethodInitialRequest{ - JoinRequest: registerUsingTokenRequestForParams(token, params), - } - - attestation, close, err := tpm.Attest(ctx, log) - if err != nil { - return nil, trace.Wrap(err) - } - defer func() { - if err := close(); err != nil { - log.WarnContext(ctx, "Failed to close TPM", "error", err) - } - }() - - initReq.AttestationParams = tpm.AttestationParametersToProto( - attestation.AttestParams, - ) - // Get the EKKey or EKCert. We want to prefer the EKCert if it is available - // as this is signed by the manufacturer. - switch { - case attestation.Data.EKCert != nil: - log.DebugContext( - ctx, - "Using EKCert for TPM registration", - "ekcert_serial", attestation.Data.EKCert.SerialNumber, - ) - initReq.Ek = &proto.RegisterUsingTPMMethodInitialRequest_EkCert{ - EkCert: attestation.Data.EKCert.Raw, - } - case attestation.Data.EKPub != nil: - log.DebugContext( - ctx, - "Using EKKey for TPM registration", - "ekpub_hash", attestation.Data.EKPubHash, - ) - initReq.Ek = &proto.RegisterUsingTPMMethodInitialRequest_EkKey{ - EkKey: attestation.Data.EKPub, - } - default: - return nil, trace.BadParameter("tpm has neither ekkey or ekcert") - } - - // Submit initial request to the Auth Server. - certs, err := client.RegisterUsingTPMMethod( - ctx, - initReq, - func( - challenge *proto.TPMEncryptedCredential, - ) (*proto.RegisterUsingTPMMethodChallengeResponse, error) { - // Solve the encrypted credential with our AK to prove possession - // and obtain the solution we need to complete the ceremony. - solution, err := attestation.Solve(tpm.EncryptedCredentialFromProto( - challenge, - )) - if err != nil { - return nil, trace.Wrap(err, "activating credential") - } - return &proto.RegisterUsingTPMMethodChallengeResponse{ - Solution: solution, - }, nil - }, - ) - return certs, trace.Wrap(err) -} - // ReRegisterParams specifies parameters for re-registering // in the cluster (rotating certificates for existing members) type ReRegisterParams struct { diff --git a/lib/auth/register_test.go b/lib/auth/register_test.go deleted file mode 100644 index 5f1d248ace8f1..0000000000000 --- a/lib/auth/register_test.go +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Teleport - * Copyright (C) 2023 Gravitational, Inc. - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -package auth - -import ( - "crypto/tls" - "testing" - "time" - - "github.com/jonboulle/clockwork" - "github.com/stretchr/testify/require" - - "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/lib/defaults" - "github.com/gravitational/teleport/lib/fixtures" - "github.com/gravitational/teleport/lib/utils" -) - -func TestVerifyALPNUpgradedConn(t *testing.T) { - t.Parallel() - - auth := newTestTLSServer(t) - proxy, err := NewServerIdentity(auth.Auth(), "test-proxy", types.RoleProxy) - require.NoError(t, err) - - tests := []struct { - name string - serverCert []byte - clock clockwork.Clock - checkError require.ErrorAssertionFunc - }{ - { - name: "proxy verified", - serverCert: proxy.TLSCertBytes, - clock: auth.Clock(), - checkError: require.NoError, - }, - { - name: "proxy expired", - serverCert: proxy.TLSCertBytes, - clock: clockwork.NewFakeClockAt(auth.Clock().Now().Add(defaults.CATTL + time.Hour)), - checkError: require.Error, - }, - { - name: "not proxy", - serverCert: []byte(fixtures.TLSCACertPEM), - clock: auth.Clock(), - checkError: require.Error, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - serverCert, err := utils.ReadCertificates(test.serverCert) - require.NoError(t, err) - - test.checkError(t, verifyALPNUpgradedConn(test.clock)(tls.ConnectionState{ - PeerCertificates: serverCert, - })) - }) - } -} diff --git a/lib/auth/sts_endpoints.go b/lib/auth/sts_endpoints.go deleted file mode 100644 index c0884fb1a2bd4..0000000000000 --- a/lib/auth/sts_endpoints.go +++ /dev/null @@ -1,85 +0,0 @@ -/* - * Teleport - * Copyright (C) 2023 Gravitational, Inc. - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -package auth - -var ( - // validSTSEndpoints holds a sorted list of all known valid public endpoints for - // the AWS STS service. You can generate this list by running - // $ go run github.com/nklaassen/sts-endpoints@latest --go-list - // Update aws-sdk-go in that package to learn about new endpoints. - validSTSEndpoints = []string{ - "sts-fips.us-east-1.amazonaws.com", - "sts-fips.us-east-2.amazonaws.com", - "sts-fips.us-west-1.amazonaws.com", - "sts-fips.us-west-2.amazonaws.com", - "sts.af-south-1.amazonaws.com", - "sts.amazonaws.com", - "sts.ap-east-1.amazonaws.com", - "sts.ap-northeast-1.amazonaws.com", - "sts.ap-northeast-2.amazonaws.com", - "sts.ap-northeast-3.amazonaws.com", - "sts.ap-south-1.amazonaws.com", - "sts.ap-south-2.amazonaws.com", - "sts.ap-southeast-1.amazonaws.com", - "sts.ap-southeast-2.amazonaws.com", - "sts.ap-southeast-3.amazonaws.com", - "sts.ap-southeast-4.amazonaws.com", - "sts.ca-central-1.amazonaws.com", - "sts.ca-west-1.amazonaws.com", - "sts.cn-north-1.amazonaws.com.cn", - "sts.cn-northwest-1.amazonaws.com.cn", - "sts.eu-central-1.amazonaws.com", - "sts.eu-central-2.amazonaws.com", - "sts.eu-north-1.amazonaws.com", - "sts.eu-south-1.amazonaws.com", - "sts.eu-south-2.amazonaws.com", - "sts.eu-west-1.amazonaws.com", - "sts.eu-west-2.amazonaws.com", - "sts.eu-west-3.amazonaws.com", - "sts.il-central-1.amazonaws.com", - "sts.me-central-1.amazonaws.com", - "sts.me-south-1.amazonaws.com", - "sts.sa-east-1.amazonaws.com", - "sts.us-east-1.amazonaws.com", - "sts.us-east-2.amazonaws.com", - "sts.us-gov-east-1.amazonaws.com", - "sts.us-gov-west-1.amazonaws.com", - "sts.us-iso-east-1.c2s.ic.gov", - "sts.us-iso-west-1.c2s.ic.gov", - "sts.us-isob-east-1.sc2s.sgov.gov", - "sts.us-west-1.amazonaws.com", - "sts.us-west-2.amazonaws.com", - } - - globalSTSEndpoints = []string{ - "sts.amazonaws.com", - // This is not a real endpoint, but the SDK will select it if - // AWS_USE_FIPS_ENDPOINT is set and a region is not. - "sts-fips.aws-global.amazonaws.com", - } - - fipsSTSEndpoints = []string{ - "sts-fips.us-east-1.amazonaws.com", - "sts-fips.us-east-2.amazonaws.com", - "sts-fips.us-west-1.amazonaws.com", - "sts-fips.us-west-2.amazonaws.com", - "sts.us-gov-east-1.amazonaws.com", - "sts.us-gov-west-1.amazonaws.com", - } -) diff --git a/lib/auth/tls_test.go b/lib/auth/tls_test.go index 972a2e34532ad..df0618c8d7a0a 100644 --- a/lib/auth/tls_test.go +++ b/lib/auth/tls_test.go @@ -57,6 +57,7 @@ import ( apiutils "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/api/utils/sshutils" "github.com/gravitational/teleport/lib/auth/authclient" + "github.com/gravitational/teleport/lib/auth/join" "github.com/gravitational/teleport/lib/auth/state" "github.com/gravitational/teleport/lib/auth/testauthority" "github.com/gravitational/teleport/lib/authz" @@ -3394,7 +3395,7 @@ func TestRegisterCAPin(t *testing.T) { caPin := caPins[0] // Attempt to register with valid CA pin, should work. - _, err = Register(ctx, RegisterParams{ + _, err = join.Register(ctx, join.RegisterParams{ AuthServers: []utils.NetAddr{utils.FromAddr(testSrv.Addr())}, Token: token, ID: state.IdentityID{ @@ -3412,7 +3413,7 @@ func TestRegisterCAPin(t *testing.T) { // Attempt to register with multiple CA pins where the auth server only // matches one, should work. - _, err = Register(ctx, RegisterParams{ + _, err = join.Register(ctx, join.RegisterParams{ AuthServers: []utils.NetAddr{utils.FromAddr(testSrv.Addr())}, Token: token, ID: state.IdentityID{ @@ -3429,7 +3430,7 @@ func TestRegisterCAPin(t *testing.T) { require.NoError(t, err) // Attempt to register with invalid CA pin, should fail. - _, err = Register(ctx, RegisterParams{ + _, err = join.Register(ctx, join.RegisterParams{ AuthServers: []utils.NetAddr{utils.FromAddr(testSrv.Addr())}, Token: token, ID: state.IdentityID{ @@ -3446,7 +3447,7 @@ func TestRegisterCAPin(t *testing.T) { require.Error(t, err) // Attempt to register with multiple invalid CA pins, should fail. - _, err = Register(ctx, RegisterParams{ + _, err = join.Register(ctx, join.RegisterParams{ AuthServers: []utils.NetAddr{utils.FromAddr(testSrv.Addr())}, Token: token, ID: state.IdentityID{ @@ -3482,7 +3483,7 @@ func TestRegisterCAPin(t *testing.T) { require.Len(t, caPins, 2) // Attempt to register with multiple CA pins, should work - _, err = Register(ctx, RegisterParams{ + _, err = join.Register(ctx, join.RegisterParams{ AuthServers: []utils.NetAddr{utils.FromAddr(testSrv.Addr())}, Token: token, ID: state.IdentityID{ @@ -3527,7 +3528,7 @@ func TestRegisterCAPath(t *testing.T) { require.NoError(t, err) // Attempt to register with nothing at the CA path, should work. - _, err = Register(ctx, RegisterParams{ + _, err = join.Register(ctx, join.RegisterParams{ AuthServers: []utils.NetAddr{utils.FromAddr(testSrv.Addr())}, Token: token, ID: state.IdentityID{ @@ -3556,7 +3557,7 @@ func TestRegisterCAPath(t *testing.T) { require.NoError(t, err) // Attempt to register with valid CA path, should work. - _, err = Register(ctx, RegisterParams{ + _, err = join.Register(ctx, join.RegisterParams{ AuthServers: []utils.NetAddr{utils.FromAddr(testSrv.Addr())}, Token: token, ID: state.IdentityID{ diff --git a/lib/service/connect.go b/lib/service/connect.go index bc61b4b1a5679..eeee5ebb9fb92 100644 --- a/lib/service/connect.go +++ b/lib/service/connect.go @@ -45,6 +45,7 @@ import ( "github.com/gravitational/teleport/lib" "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/auth/authclient" + "github.com/gravitational/teleport/lib/auth/join" "github.com/gravitational/teleport/lib/auth/native" "github.com/gravitational/teleport/lib/auth/state" "github.com/gravitational/teleport/lib/client" @@ -442,7 +443,7 @@ func (process *TeleportProcess) firstTimeConnect(role types.SystemRole) (*Connec dataDir = process.Config.DataDir } - registerParams := auth.RegisterParams{ + registerParams := join.RegisterParams{ Token: token, ID: id, AuthServers: process.Config.AuthServerAddresses(), @@ -464,12 +465,12 @@ func (process *TeleportProcess) firstTimeConnect(role types.SystemRole) (*Connec Insecure: lib.IsInsecureDevMode(), } if registerParams.JoinMethod == types.JoinMethodAzure { - registerParams.AzureParams = auth.AzureParams{ + registerParams.AzureParams = join.AzureParams{ ClientID: process.Config.JoinParams.Azure.ClientID, } } - certs, err := auth.Register(process.ExitContext(), registerParams) + certs, err := join.Register(process.ExitContext(), registerParams) if err != nil { if utils.IsUntrustedCertErr(err) { return nil, trace.WrapWithMessage(err, utils.SelfSignedCertsMsg) diff --git a/lib/tbot/service_bot_identity.go b/lib/tbot/service_bot_identity.go index f193b2542bc12..11418705398d6 100644 --- a/lib/tbot/service_bot_identity.go +++ b/lib/tbot/service_bot_identity.go @@ -33,8 +33,8 @@ import ( "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/utils/retryutils" - "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/auth/authclient" + "github.com/gravitational/teleport/lib/auth/join" "github.com/gravitational/teleport/lib/auth/state" "github.com/gravitational/teleport/lib/client" "github.com/gravitational/teleport/lib/reversetunnelclient" @@ -403,7 +403,7 @@ func botIdentityFromToken(ctx context.Context, log *slog.Logger, cfg *config.Bot } expires := time.Now().Add(cfg.CertificateTTL) - params := auth.RegisterParams{ + params := join.RegisterParams{ Token: token, ID: state.IdentityID{ Role: types.RoleBot, @@ -439,12 +439,12 @@ func botIdentityFromToken(ctx context.Context, log *slog.Logger, cfg *config.Bot } if params.JoinMethod == types.JoinMethodAzure { - params.AzureParams = auth.AzureParams{ + params.AzureParams = join.AzureParams{ ClientID: cfg.Onboarding.Azure.ClientID, } } - certs, err := auth.Register(ctx, params) + certs, err := join.Register(ctx, params) if err != nil { return nil, trace.Wrap(err) }