From 8e59e36e90e7b579a375f114fc4575e59d531204 Mon Sep 17 00:00:00 2001
From: rosstimothy <39066650+rosstimothy@users.noreply.github.com>
Date: Sat, 18 May 2024 12:56:28 -0400
Subject: [PATCH] Refactor registration code into lib/auth/join (#41679)
* Refactor joining code into lib/auth/join
auth.Register and auth.RegisterParams were moved into the new package
to prevent clients(tbot) from depending on lib/auth. For the moment
only iam joining has been moved into the new package. All other
join methods have been left behind in their lib/auth/join_foo.go
file and will be addressed in future PRs if need be. This removes
lib/auth entirely from tbot and reduces its binary size by ~10-20MB.
* fix: disabling second factor auth in tests
* fix license
---
integration/proxy/proxy_helpers.go | 4 +-
lib/auth/auth.go | 3 -
lib/auth/bot_test.go | 7 +-
lib/auth/join/iam/endpoints.go | 91 ++++
lib/auth/join/iam/iam.go | 161 ++++++
lib/auth/join/join.go | 796 +++++++++++++++++++++++++++++
lib/auth/join/join_test.go | 328 ++++++++++++
lib/auth/join_iam.go | 134 +----
lib/auth/join_test.go | 5 +-
lib/auth/register.go | 762 ---------------------------
lib/auth/register_test.go | 78 ---
lib/auth/sts_endpoints.go | 85 ---
lib/auth/tls_test.go | 15 +-
lib/service/connect.go | 7 +-
lib/tbot/service_bot_identity.go | 8 +-
15 files changed, 1405 insertions(+), 1079 deletions(-)
create mode 100644 lib/auth/join/iam/endpoints.go
create mode 100644 lib/auth/join/iam/iam.go
create mode 100644 lib/auth/join/join.go
create mode 100644 lib/auth/join/join_test.go
delete mode 100644 lib/auth/register_test.go
delete mode 100644 lib/auth/sts_endpoints.go
diff --git a/integration/proxy/proxy_helpers.go b/integration/proxy/proxy_helpers.go
index 726568e34d46e..7fff7e9a54100 100644
--- a/integration/proxy/proxy_helpers.go
+++ b/integration/proxy/proxy_helpers.go
@@ -53,7 +53,7 @@ import (
"github.com/gravitational/teleport/api/utils/retryutils"
"github.com/gravitational/teleport/integration/helpers"
"github.com/gravitational/teleport/lib"
- "github.com/gravitational/teleport/lib/auth"
+ "github.com/gravitational/teleport/lib/auth/join"
"github.com/gravitational/teleport/lib/auth/state"
"github.com/gravitational/teleport/lib/client"
"github.com/gravitational/teleport/lib/defaults"
@@ -664,7 +664,7 @@ func mustRegisterUsingIAMMethod(t *testing.T, proxyAddr utils.NetAddr, token str
require.NoError(t, err)
node := uuid.NewString()
- _, err = auth.Register(context.TODO(), auth.RegisterParams{
+ _, err = join.Register(context.TODO(), join.RegisterParams{
Token: token,
ID: state.IdentityID{
Role: types.RoleNode,
diff --git a/lib/auth/auth.go b/lib/auth/auth.go
index aedc4938cffa4..0ec085c293db6 100644
--- a/lib/auth/auth.go
+++ b/lib/auth/auth.go
@@ -52,7 +52,6 @@ import (
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/prometheus/client_golang/prometheus"
- "go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/exporters/otlp/otlptrace"
"golang.org/x/crypto/bcrypt"
"golang.org/x/crypto/ssh"
@@ -155,8 +154,6 @@ const (
"(hint: use 'tctl get roles' to find roles that need updating)"
)
-var tracer = otel.Tracer("github.com/gravitational/teleport/lib/auth")
-
var ErrRequiresEnterprise = services.ErrRequiresEnterprise
// ServerOption allows setting options as functional arguments to Server
diff --git a/lib/auth/bot_test.go b/lib/auth/bot_test.go
index e53e445ceafaf..4c54114a96522 100644
--- a/lib/auth/bot_test.go
+++ b/lib/auth/bot_test.go
@@ -42,6 +42,7 @@ import (
machineidv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/machineid/v1"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/auth/authclient"
+ "github.com/gravitational/teleport/lib/auth/join"
"github.com/gravitational/teleport/lib/auth/machineid/machineidv1"
"github.com/gravitational/teleport/lib/auth/state"
"github.com/gravitational/teleport/lib/auth/testauthority"
@@ -118,7 +119,7 @@ func TestRegisterBotCertificateGenerationCheck(t *testing.T) {
tlsPublicKey, err := tlsca.MarshalPublicKeyFromPrivateKeyPEM(sshPrivateKey)
require.NoError(t, err)
- certs, err := Register(ctx, RegisterParams{
+ certs, err := join.Register(ctx, join.RegisterParams{
Token: token.GetName(),
ID: state.IdentityID{
Role: types.RoleBot,
@@ -191,7 +192,7 @@ func TestRegisterBotCertificateGenerationStolen(t *testing.T) {
tlsPublicKey, err := tlsca.MarshalPublicKeyFromPrivateKeyPEM(sshPrivateKey)
require.NoError(t, err)
- certs, err := Register(ctx, RegisterParams{
+ certs, err := join.Register(ctx, join.RegisterParams{
Token: token.GetName(),
ID: state.IdentityID{
Role: types.RoleBot,
@@ -267,7 +268,7 @@ func TestRegisterBotCertificateExtensions(t *testing.T) {
tlsPublicKey, err := tlsca.MarshalPublicKeyFromPrivateKeyPEM(sshPrivateKey)
require.NoError(t, err)
- certs, err := Register(ctx, RegisterParams{
+ certs, err := join.Register(ctx, join.RegisterParams{
Token: token.GetName(),
ID: state.IdentityID{
Role: types.RoleBot,
diff --git a/lib/auth/join/iam/endpoints.go b/lib/auth/join/iam/endpoints.go
new file mode 100644
index 0000000000000..3434ce30228b1
--- /dev/null
+++ b/lib/auth/join/iam/endpoints.go
@@ -0,0 +1,91 @@
+// Teleport
+// Copyright (C) 2024 Gravitational, Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package iam
+
+import "sync"
+
+var (
+ // ValidSTSEndpoints holds a sorted list of all known valid public endpoints for
+ // the AWS STS service. You can generate this list by running
+ // $ go run github.com/nklaassen/sts-endpoints@latest --go-list
+ // Update aws-sdk-go in that package to learn about new endpoints.
+ ValidSTSEndpoints = sync.OnceValue(func() []string {
+ return []string{
+ "sts-fips.us-east-1.amazonaws.com",
+ "sts-fips.us-east-2.amazonaws.com",
+ "sts-fips.us-west-1.amazonaws.com",
+ "sts-fips.us-west-2.amazonaws.com",
+ "sts.af-south-1.amazonaws.com",
+ "sts.amazonaws.com",
+ "sts.ap-east-1.amazonaws.com",
+ "sts.ap-northeast-1.amazonaws.com",
+ "sts.ap-northeast-2.amazonaws.com",
+ "sts.ap-northeast-3.amazonaws.com",
+ "sts.ap-south-1.amazonaws.com",
+ "sts.ap-south-2.amazonaws.com",
+ "sts.ap-southeast-1.amazonaws.com",
+ "sts.ap-southeast-2.amazonaws.com",
+ "sts.ap-southeast-3.amazonaws.com",
+ "sts.ap-southeast-4.amazonaws.com",
+ "sts.ca-central-1.amazonaws.com",
+ "sts.ca-west-1.amazonaws.com",
+ "sts.cn-north-1.amazonaws.com.cn",
+ "sts.cn-northwest-1.amazonaws.com.cn",
+ "sts.eu-central-1.amazonaws.com",
+ "sts.eu-central-2.amazonaws.com",
+ "sts.eu-north-1.amazonaws.com",
+ "sts.eu-south-1.amazonaws.com",
+ "sts.eu-south-2.amazonaws.com",
+ "sts.eu-west-1.amazonaws.com",
+ "sts.eu-west-2.amazonaws.com",
+ "sts.eu-west-3.amazonaws.com",
+ "sts.il-central-1.amazonaws.com",
+ "sts.me-central-1.amazonaws.com",
+ "sts.me-south-1.amazonaws.com",
+ "sts.sa-east-1.amazonaws.com",
+ "sts.us-east-1.amazonaws.com",
+ "sts.us-east-2.amazonaws.com",
+ "sts.us-gov-east-1.amazonaws.com",
+ "sts.us-gov-west-1.amazonaws.com",
+ "sts.us-iso-east-1.c2s.ic.gov",
+ "sts.us-iso-west-1.c2s.ic.gov",
+ "sts.us-isob-east-1.sc2s.sgov.gov",
+ "sts.us-west-1.amazonaws.com",
+ "sts.us-west-2.amazonaws.com",
+ }
+ })
+
+ GlobalSTSEndpoints = sync.OnceValue(func() []string {
+ return []string{
+ "sts.amazonaws.com",
+ // This is not a real endpoint, but the SDK will select it if
+ // AWS_USE_FIPS_ENDPOINT is set and a region is not.
+ "sts-fips.aws-global.amazonaws.com",
+ }
+ })
+
+ FIPSSTSEndpoints = sync.OnceValue(func() []string {
+ return []string{
+ "sts-fips.us-east-1.amazonaws.com",
+ "sts-fips.us-east-2.amazonaws.com",
+ "sts-fips.us-west-1.amazonaws.com",
+ "sts-fips.us-west-2.amazonaws.com",
+ "sts.us-gov-east-1.amazonaws.com",
+ "sts.us-gov-west-1.amazonaws.com",
+ }
+ })
+)
diff --git a/lib/auth/join/iam/iam.go b/lib/auth/join/iam/iam.go
new file mode 100644
index 0000000000000..241d4ec7800c3
--- /dev/null
+++ b/lib/auth/join/iam/iam.go
@@ -0,0 +1,161 @@
+// Teleport
+// Copyright (C) 2024 Gravitational, Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package iam
+
+import (
+ "bytes"
+ "context"
+ "log/slog"
+ "slices"
+ "strings"
+
+ awssdk "github.com/aws/aws-sdk-go/aws"
+ "github.com/aws/aws-sdk-go/aws/endpoints"
+ "github.com/aws/aws-sdk-go/aws/session"
+ "github.com/aws/aws-sdk-go/service/sts"
+ "github.com/gravitational/trace"
+
+ cloudaws "github.com/gravitational/teleport/lib/cloud/imds/aws"
+)
+
+const (
+ // AWS SignedHeaders will always be lowercase
+ // https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-auth-using-authorization-header.html#sigv4-auth-header-overview
+ challengeHeaderKey = "x-teleport-challenge"
+)
+
+type stsIdentityRequestConfig struct {
+ regionalEndpointOption endpoints.STSRegionalEndpoint
+ fipsEndpointOption endpoints.FIPSEndpointState
+}
+
+type stsIdentityRequestOption func(cfg *stsIdentityRequestConfig)
+
+func WithRegionalEndpoint(useRegionalEndpoint bool) stsIdentityRequestOption {
+ return func(cfg *stsIdentityRequestConfig) {
+ if useRegionalEndpoint {
+ cfg.regionalEndpointOption = endpoints.RegionalSTSEndpoint
+ } else {
+ cfg.regionalEndpointOption = endpoints.LegacySTSEndpoint
+ }
+ }
+}
+
+func WithFIPSEndpoint(useFIPS bool) stsIdentityRequestOption {
+ return func(cfg *stsIdentityRequestConfig) {
+ if useFIPS {
+ cfg.fipsEndpointOption = endpoints.FIPSEndpointStateEnabled
+ } else {
+ cfg.fipsEndpointOption = endpoints.FIPSEndpointStateDisabled
+ }
+ }
+}
+
+// getEC2LocalRegion returns the AWS region this EC2 instance is running in, or
+// a NotFound error if the EC2 IMDS is unavailable.
+func getEC2LocalRegion(ctx context.Context) (string, error) {
+ imdsClient, err := cloudaws.NewInstanceMetadataClient(ctx)
+ if err != nil {
+ return "", trace.Wrap(err)
+ }
+
+ if !imdsClient.IsAvailable(ctx) {
+ return "", trace.NotFound("IMDS is unavailable")
+ }
+
+ region, err := imdsClient.GetRegion(ctx)
+ return region, trace.Wrap(err)
+}
+
+func newSTSClient(ctx context.Context, cfg *stsIdentityRequestConfig) (*sts.STS, error) {
+ awsConfig := awssdk.Config{
+ UseFIPSEndpoint: cfg.fipsEndpointOption,
+ STSRegionalEndpoint: cfg.regionalEndpointOption,
+ }
+ sess, err := session.NewSessionWithOptions(session.Options{
+ SharedConfigState: session.SharedConfigEnable,
+ Config: awsConfig,
+ })
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ stsClient := sts.New(sess)
+
+ if slices.Contains(GlobalSTSEndpoints(), strings.TrimPrefix(stsClient.Endpoint, "https://")) {
+ // If the caller wants to use the regional endpoint but it was not resolved
+ // from the environment, attempt to find the region from the EC2 IMDS
+ if cfg.regionalEndpointOption == endpoints.RegionalSTSEndpoint {
+ region, err := getEC2LocalRegion(ctx)
+ if err != nil {
+ return nil, trace.Wrap(err, "failed to resolve local AWS region from environment or IMDS")
+ }
+ stsClient = sts.New(sess, awssdk.NewConfig().WithRegion(region))
+ } else {
+ const msg = "Attempting to use the global STS endpoint for the IAM join method. " +
+ "This will probably fail in non-default AWS partitions such as China or GovCloud, or if FIPS mode is enabled. " +
+ "Consider setting the AWS_REGION environment variable, setting the region in ~/.aws/config, or enabling the IMDSv2."
+ slog.InfoContext(ctx, msg)
+ }
+ }
+
+ if cfg.fipsEndpointOption == endpoints.FIPSEndpointStateEnabled &&
+ !slices.Contains(ValidSTSEndpoints(), strings.TrimPrefix(stsClient.Endpoint, "https://")) {
+ // The AWS SDK will generate invalid endpoints when attempting to
+ // resolve the FIPS endpoint for a region that does not have one.
+ // In this case, try to use the FIPS endpoint in us-east-1. This should
+ // work for all regions in the standard partition. In GovCloud, we should
+ // not hit this because all regional endpoints support FIPS. In China or
+ // other partitions, this will fail, and FIPS mode will not be supported.
+ const msg = "AWS SDK resolved invalid FIPS STS endpoint. " +
+ "Attempting to use the FIPS STS endpoint for us-east-1."
+ slog.InfoContext(ctx, msg, "resolved", stsClient.Endpoint)
+ stsClient = sts.New(sess, awssdk.NewConfig().WithRegion("us-east-1"))
+ }
+
+ return stsClient, nil
+}
+
+// CreateSignedSTSIdentityRequest is called on the client side and returns an
+// sts:GetCallerIdentity request signed with the local AWS credentials
+func CreateSignedSTSIdentityRequest(ctx context.Context, challenge string, opts ...stsIdentityRequestOption) ([]byte, error) {
+ cfg := &stsIdentityRequestConfig{}
+ for _, opt := range opts {
+ opt(cfg)
+ }
+
+ stsClient, err := newSTSClient(ctx, cfg)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ req, _ := stsClient.GetCallerIdentityRequest(&sts.GetCallerIdentityInput{})
+ // set challenge header
+ req.HTTPRequest.Header.Set(challengeHeaderKey, challenge)
+ // request json for simpler parsing
+ req.HTTPRequest.Header.Set("Accept", "application/json")
+ // sign the request, including headers
+ if err := req.Sign(); err != nil {
+ return nil, trace.Wrap(err)
+ }
+ // write the signed HTTP request to a buffer
+ var signedRequest bytes.Buffer
+ if err := req.HTTPRequest.Write(&signedRequest); err != nil {
+ return nil, trace.Wrap(err)
+ }
+ return signedRequest.Bytes(), nil
+}
diff --git a/lib/auth/join/join.go b/lib/auth/join/join.go
new file mode 100644
index 0000000000000..367b091612ac5
--- /dev/null
+++ b/lib/auth/join/join.go
@@ -0,0 +1,796 @@
+// Teleport
+// Copyright (C) 2024 Gravitational, Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package join
+
+import (
+ "context"
+ "crypto/tls"
+ "crypto/x509"
+ "log/slog"
+ "os"
+ "slices"
+ "time"
+
+ "github.com/gravitational/trace"
+ "github.com/jonboulle/clockwork"
+ log "github.com/sirupsen/logrus"
+ "go.opentelemetry.io/otel"
+ "golang.org/x/net/http2"
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/credentials"
+
+ "github.com/gravitational/teleport"
+ "github.com/gravitational/teleport/api/breaker"
+ "github.com/gravitational/teleport/api/client"
+ "github.com/gravitational/teleport/api/client/proto"
+ "github.com/gravitational/teleport/api/constants"
+ apidefaults "github.com/gravitational/teleport/api/defaults"
+ "github.com/gravitational/teleport/api/metadata"
+ "github.com/gravitational/teleport/api/observability/tracing"
+ "github.com/gravitational/teleport/api/types"
+ "github.com/gravitational/teleport/api/utils/aws"
+ "github.com/gravitational/teleport/lib/auth/authclient"
+ "github.com/gravitational/teleport/lib/auth/join/iam"
+ "github.com/gravitational/teleport/lib/auth/state"
+ "github.com/gravitational/teleport/lib/circleci"
+ "github.com/gravitational/teleport/lib/cloud/imds/azure"
+ "github.com/gravitational/teleport/lib/cloud/imds/gcp"
+ "github.com/gravitational/teleport/lib/defaults"
+ "github.com/gravitational/teleport/lib/githubactions"
+ "github.com/gravitational/teleport/lib/gitlab"
+ "github.com/gravitational/teleport/lib/kubernetestoken"
+ "github.com/gravitational/teleport/lib/spacelift"
+ "github.com/gravitational/teleport/lib/srv/alpnproxy/common"
+ "github.com/gravitational/teleport/lib/tlsca"
+ "github.com/gravitational/teleport/lib/tpm"
+ "github.com/gravitational/teleport/lib/utils"
+)
+
+var tracer = otel.Tracer("github.com/gravitational/teleport/lib/auth/join")
+
+// HostCredentials is an interface for a client that can be used to get host
+// credentials. This interface is needed because lib/client cannot be imported
+// in lib/auth due to circular imports.
+type HostCredentials func(context.Context, string, bool, types.RegisterUsingTokenRequest) (*proto.Certs, error)
+
+// AzureParams is the parameters specific to the azure join method.
+type AzureParams struct {
+ // ClientID is the client ID of the managed identity for Teleport to assume
+ // when authenticating a node.
+ ClientID string
+}
+
+// RegisterParams specifies parameters
+// for first time register operation with auth server
+type RegisterParams struct {
+ // Token is a secure token to join the cluster
+ Token string
+ // ID is identity ID
+ ID state.IdentityID
+ // AuthServers is a list of auth servers to dial
+ AuthServers []utils.NetAddr
+ // ProxyServer is a proxy server to dial
+ ProxyServer utils.NetAddr
+ // AdditionalPrincipals is a list of additional principals to dial
+ AdditionalPrincipals []string
+ // DNSNames is a list of DNS names to add to x509 certificate
+ DNSNames []string
+ // PublicTLSKey is a server's public key to sign
+ PublicTLSKey []byte
+ // PublicSSHKey is a server's public SSH key to sign
+ PublicSSHKey []byte
+ // CipherSuites is a list of cipher suites to use for TLS client connection
+ CipherSuites []uint16
+ // CAPins are the SKPI hashes of the CAs used to verify the Auth Server.
+ CAPins []string
+ // CAPath is the path to the CA file.
+ CAPath string
+ // GetHostCredentials is a client that can fetch host credentials.
+ GetHostCredentials HostCredentials
+ // Clock specifies the time provider. Will be used to override the time anchor
+ // for TLS certificate verification.
+ // Defaults to real clock if unspecified
+ Clock clockwork.Clock
+ // JoinMethod is the joining method used for this register request.
+ JoinMethod types.JoinMethod
+ // ec2IdentityDocument is used for Simplified Node Joining to prove the
+ // identity of a joining EC2 instance.
+ ec2IdentityDocument []byte
+ // AzureParams is the parameters specific to the azure join method.
+ AzureParams AzureParams
+ // CircuitBreakerConfig defines how the circuit breaker should behave.
+ CircuitBreakerConfig breaker.Config
+ // FIPS means FedRAMP/FIPS 140-2 compliant configuration was requested.
+ FIPS bool
+ // IDToken is a token retrieved from a workload identity provider for
+ // certain join types e.g GitHub, Google.
+ IDToken string
+ // Expires is an optional field for bots that specifies a time that the
+ // certificates that are returned by registering should expire at.
+ // It should not be specified for non-bot registrations.
+ Expires *time.Time
+ // Insecure trusts the certificates from the Auth Server or Proxy during registration without verification.
+ Insecure bool
+}
+
+func (r *RegisterParams) checkAndSetDefaults() error {
+ if r.Clock == nil {
+ r.Clock = clockwork.NewRealClock()
+ }
+
+ if err := r.verifyAuthOrProxyAddress(); err != nil {
+ return trace.BadParameter("no auth or proxy servers set")
+ }
+
+ return nil
+}
+
+func (r *RegisterParams) verifyAuthOrProxyAddress() error {
+ haveAuthServers := len(r.AuthServers) > 0
+ haveProxyServer := !r.ProxyServer.IsEmpty()
+
+ if !haveAuthServers && !haveProxyServer {
+ return trace.BadParameter("no auth or proxy servers set")
+ }
+
+ if haveAuthServers && haveProxyServer {
+ return trace.BadParameter("only one of auth or proxy server should be set")
+ }
+
+ return nil
+}
+
+// Register is used to generate host keys when a node or proxy are running on
+// different hosts than the auth server. This method requires provisioning
+// tokens to prove a valid auth server was used to issue the joining request
+// as well as a method for the node to validate the auth server.
+func Register(ctx context.Context, params RegisterParams) (certs *proto.Certs, err error) {
+ ctx, span := tracer.Start(ctx, "Register")
+ defer func() { tracing.EndSpan(span, err) }()
+
+ if err := params.checkAndSetDefaults(); err != nil {
+ return nil, trace.Wrap(err)
+ }
+ // Read in the token. The token can either be passed in or come from a file
+ // on disk.
+ token, err := utils.TryReadValueAsFile(params.Token)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ // add EC2 Identity Document to params if required for given join method
+ switch params.JoinMethod {
+ case types.JoinMethodEC2:
+ if !aws.IsEC2NodeID(params.ID.HostUUID) {
+ return nil, trace.BadParameter(
+ `Host ID %q is not valid when using the EC2 join method, `+
+ `try removing the "host_uuid" file in your teleport data dir `+
+ `(e.g. /var/lib/teleport/host_uuid)`,
+ params.ID.HostUUID)
+ }
+ params.ec2IdentityDocument, err = utils.GetRawEC2IdentityDocument(ctx)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ case types.JoinMethodGitHub:
+ params.IDToken, err = githubactions.NewIDTokenSource().GetIDToken(ctx)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ case types.JoinMethodGitLab:
+ params.IDToken, err = gitlab.NewIDTokenSource(os.Getenv).GetIDToken()
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ case types.JoinMethodCircleCI:
+ params.IDToken, err = circleci.GetIDToken(os.Getenv)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ case types.JoinMethodKubernetes:
+ params.IDToken, err = kubernetestoken.GetIDToken(os.Getenv, os.ReadFile)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ case types.JoinMethodGCP:
+ params.IDToken, err = gcp.GetIDToken(ctx)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ case types.JoinMethodSpacelift:
+ params.IDToken, err = spacelift.NewIDTokenSource(os.Getenv).GetIDToken()
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ }
+
+ type registerMethod struct {
+ call func(ctx context.Context, token string, params RegisterParams) (*proto.Certs, error)
+ desc string
+ }
+
+ registerThroughAuth := registerMethod{registerThroughAuth, "with auth server"}
+ registerThroughProxy := registerMethod{registerThroughProxy, "via proxy server"}
+
+ registerMethods := []registerMethod{registerThroughAuth, registerThroughProxy}
+
+ if !params.ProxyServer.IsEmpty() {
+ log.WithField("proxy-server", params.ProxyServer).Debugf("Registering node to the cluster.")
+
+ registerMethods = []registerMethod{registerThroughProxy}
+
+ if proxyServerIsAuth(params.ProxyServer) {
+ log.Debugf("The specified proxy server appears to be an auth server.")
+ }
+ } else {
+ log.WithField("auth-servers", params.AuthServers).Debugf("Registering node to the cluster.")
+
+ if params.GetHostCredentials == nil {
+ log.Debugf("Missing client, it is not possible to register through proxy.")
+ registerMethods = []registerMethod{registerThroughAuth}
+ } else if authServerIsProxy(params.AuthServers) {
+ log.Debugf("The first specified auth server appears to be a proxy.")
+ registerMethods = []registerMethod{registerThroughProxy, registerThroughAuth}
+ }
+ }
+
+ var collectedErrs []error
+ for _, method := range registerMethods {
+ log.Infof("Attempting registration %s.", method.desc)
+ certs, err := method.call(ctx, token, params)
+ if err != nil {
+ collectedErrs = append(collectedErrs, err)
+ log.WithError(err).Debugf("Registration %s failed.", method.desc)
+ continue
+ }
+ log.Infof("Successfully registered %s.", method.desc)
+ return certs, nil
+ }
+ return nil, trace.NewAggregate(collectedErrs...)
+}
+
+// authServerIsProxy returns true if the first specified auth server
+// to register with appears to be a proxy.
+func authServerIsProxy(servers []utils.NetAddr) bool {
+ if len(servers) == 0 {
+ return false
+ }
+ port := servers[0].Port(0)
+ return port == defaults.HTTPListenPort || port == teleport.StandardHTTPSPort
+}
+
+// proxyServerIsAuth returns true if the address given to register with
+// appears to be an auth server.
+func proxyServerIsAuth(server utils.NetAddr) bool {
+ port := server.Port(0)
+ return port == defaults.AuthListenPort
+}
+
+// registerThroughProxy is used to register through the proxy server.
+func registerThroughProxy(
+ ctx context.Context,
+ token string,
+ params RegisterParams,
+) (certs *proto.Certs, err error) {
+ ctx, span := tracer.Start(ctx, "registerThroughProxy")
+ defer func() { tracing.EndSpan(span, err) }()
+
+ switch params.JoinMethod {
+ case types.JoinMethodIAM, types.JoinMethodAzure, types.JoinMethodTPM:
+ // IAM and Azure join methods require gRPC client
+ conn, err := proxyJoinServiceConn(ctx, params, params.Insecure)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ defer conn.Close()
+
+ joinServiceClient := client.NewJoinServiceClient(proto.NewJoinServiceClient(conn))
+ switch params.JoinMethod {
+ case types.JoinMethodIAM:
+ certs, err = registerUsingIAMMethod(ctx, joinServiceClient, token, params)
+ case types.JoinMethodAzure:
+ certs, err = registerUsingAzureMethod(ctx, joinServiceClient, token, params)
+ case types.JoinMethodTPM:
+ certs, err = registerUsingTPMMethod(ctx, joinServiceClient, token, params)
+ default:
+ return nil, trace.BadParameter("unhandled join method %q", params.JoinMethod)
+ }
+
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ default:
+ // The rest of the join methods use GetHostCredentials function passed through
+ // params to call proxy HTTP endpoint
+ var err error
+ certs, err = params.GetHostCredentials(ctx,
+ getHostAddresses(params)[0],
+ params.Insecure,
+ types.RegisterUsingTokenRequest{
+ Token: token,
+ HostID: params.ID.HostUUID,
+ NodeName: params.ID.NodeName,
+ Role: params.ID.Role,
+ AdditionalPrincipals: params.AdditionalPrincipals,
+ DNSNames: params.DNSNames,
+ PublicTLSKey: params.PublicTLSKey,
+ PublicSSHKey: params.PublicSSHKey,
+ EC2IdentityDocument: params.ec2IdentityDocument,
+ IDToken: params.IDToken,
+ Expires: params.Expires,
+ })
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ }
+ return certs, nil
+}
+
+// registerThroughAuth is used to register through the auth server.
+func registerThroughAuth(
+ ctx context.Context, token string, params RegisterParams,
+) (certs *proto.Certs, err error) {
+ ctx, span := tracer.Start(ctx, "registerThroughAuth")
+ defer func() { tracing.EndSpan(span, err) }()
+
+ var client *authclient.Client
+ // Build a client for the Auth Server with different certificate validation
+ // depending on the configured values for Insecure, CAPins and CAPath.
+ switch {
+ case params.Insecure:
+ log.Warnf("Insecure mode enabled. Auth Server cert will not be validated and CAPins and CAPath value will be ignored.")
+ client, err = insecureRegisterClient(params)
+ case len(params.CAPins) != 0:
+ // CAPins takes precedence over CAPath
+ client, err = pinRegisterClient(ctx, params)
+ case params.CAPath != "":
+ client, err = caPathRegisterClient(params)
+ default:
+ // We fall back to insecure mode here - this is a little odd but is
+ // necessary to preserve the behavior of registration. At a later date,
+ // we may consider making this an error asking the user to provide
+ // Insecure, CAPins or CAPath.
+ client, err = insecureRegisterClient(params)
+ }
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ defer client.Close()
+
+ switch params.JoinMethod {
+ // IAM and Azure methods use unique gRPC endpoints
+ case types.JoinMethodIAM:
+ certs, err = registerUsingIAMMethod(ctx, client, token, params)
+ case types.JoinMethodAzure:
+ certs, err = registerUsingAzureMethod(ctx, client, token, params)
+ case types.JoinMethodTPM:
+ certs, err = registerUsingTPMMethod(ctx, client, token, params)
+ default:
+ // non-IAM join methods use HTTP endpoint
+ // Get the SSH and X509 certificates for a node.
+ certs, err = client.RegisterUsingToken(
+ ctx,
+ &types.RegisterUsingTokenRequest{
+ Token: token,
+ HostID: params.ID.HostUUID,
+ NodeName: params.ID.NodeName,
+ Role: params.ID.Role,
+ AdditionalPrincipals: params.AdditionalPrincipals,
+ DNSNames: params.DNSNames,
+ PublicTLSKey: params.PublicTLSKey,
+ PublicSSHKey: params.PublicSSHKey,
+ EC2IdentityDocument: params.ec2IdentityDocument,
+ IDToken: params.IDToken,
+ Expires: params.Expires,
+ })
+ }
+ return certs, trace.Wrap(err)
+}
+
+// proxyJoinServiceConn attempts to connect to the join service running on the
+// proxy. The Proxy's TLS cert will be verified using the host's root CA pool
+// (PKI) unless the --insecure flag was passed.
+func proxyJoinServiceConn(
+ ctx context.Context, params RegisterParams, insecure bool,
+) (*grpc.ClientConn, error) {
+ tlsConfig := utils.TLSConfig(params.CipherSuites)
+ tlsConfig.Time = params.Clock.Now
+ // set NextProtos for TLS routing, the actual protocol will be h2
+ tlsConfig.NextProtos = []string{string(common.ProtocolProxyGRPCInsecure), http2.NextProtoTLS}
+
+ if insecure {
+ tlsConfig.InsecureSkipVerify = true
+ log.Warnf("Joining cluster without validating the identity of the Proxy Server.")
+ }
+
+ // Check if proxy is behind a load balancer. If so, the connection upgrade
+ // will verify the load balancer's cert using system cert pool. This
+ // provides the same level of security as the client only verifies Proxy's
+ // web cert against system cert pool when connection upgrade is not
+ // required.
+ //
+ // With the ALPN connection upgrade, the tunneled TLS Routing request will
+ // skip verify as the Proxy server will present its host cert which is not
+ // fully verifiable at this point since the client does not have the host
+ // CAs yet before completing registration.
+ alpnConnUpgrade := client.IsALPNConnUpgradeRequired(ctx, getHostAddresses(params)[0], insecure)
+ if alpnConnUpgrade && !insecure {
+ tlsConfig.InsecureSkipVerify = true
+ tlsConfig.VerifyConnection = verifyALPNUpgradedConn(params.Clock)
+ }
+
+ dialer := client.NewDialer(
+ ctx,
+ apidefaults.DefaultIdleTimeout,
+ apidefaults.DefaultIOTimeout,
+ client.WithInsecureSkipVerify(insecure),
+ client.WithALPNConnUpgrade(alpnConnUpgrade),
+ )
+
+ conn, err := grpc.Dial(
+ getHostAddresses(params)[0],
+ grpc.WithContextDialer(client.GRPCContextDialer(dialer)),
+ grpc.WithUnaryInterceptor(metadata.UnaryClientInterceptor),
+ grpc.WithStreamInterceptor(metadata.StreamClientInterceptor),
+ grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)),
+ )
+ return conn, trace.Wrap(err)
+}
+
+func getHostAddresses(params RegisterParams) []string {
+ if !params.ProxyServer.IsEmpty() {
+ return []string{params.ProxyServer.String()}
+ }
+
+ return utils.NetAddrsToStrings(params.AuthServers)
+}
+
+// verifyALPNUpgradedConn is a tls.Config.VerifyConnection callback function
+// used by the tunneled TLS Routing request to verify the host cert of a Proxy
+// behind a L7 load balancer.
+//
+// Since the client has not obtained the cluster CAs at this point, the
+// presented cert cannot be fully verified yet. For now, this function only
+// checks if "teleport.cluster.local" is present as one of the DNS names and
+// verifies the cert is not expired.
+func verifyALPNUpgradedConn(clock clockwork.Clock) func(tls.ConnectionState) error {
+ return func(server tls.ConnectionState) error {
+ for _, cert := range server.PeerCertificates {
+ if slices.Contains(cert.DNSNames, constants.APIDomain) && clock.Now().Before(cert.NotAfter) {
+ return nil
+ }
+ }
+ return trace.AccessDenied("server is not a Teleport proxy or server certificate is expired")
+ }
+}
+
+// insecureRegisterClient attempts to connects to the Auth Server using the
+// CA on disk. If no CA is found on disk, Teleport will not verify the Auth
+// Server it is connecting to.
+func insecureRegisterClient(params RegisterParams) (*authclient.Client, error) {
+ log.Warnf("Joining cluster without validating the identity of the Auth " +
+ "Server. This may open you up to a Man-In-The-Middle (MITM) attack if an " +
+ "attacker can gain privileged network access. To remedy this, use the CA pin " +
+ "value provided when join token was generated to validate the identity of " +
+ "the Auth Server or point to a valid Certificate via the CA Path option.")
+
+ tlsConfig := utils.TLSConfig(params.CipherSuites)
+ tlsConfig.Time = params.Clock.Now
+ tlsConfig.InsecureSkipVerify = true
+
+ client, err := authclient.NewClient(client.Config{
+ Addrs: getHostAddresses(params),
+ Credentials: []client.Credentials{
+ client.LoadTLS(tlsConfig),
+ },
+ CircuitBreakerConfig: params.CircuitBreakerConfig,
+ })
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ return client, nil
+}
+
+// pinRegisterClient first connects to the Auth Server using a insecure
+// connection to fetch the root CA. If the root CA matches the provided CA
+// pin, a connection will be re-established and the root CA will be used to
+// validate the certificate presented. If both conditions hold true, then we
+// know we are connecting to the expected Auth Server.
+func pinRegisterClient(
+ ctx context.Context, params RegisterParams,
+) (*authclient.Client, error) {
+ // Build a insecure client to the Auth Server. This is safe because even if
+ // an attacker were to MITM this connection the CA pin will not match below.
+ tlsConfig := utils.TLSConfig(params.CipherSuites)
+ tlsConfig.InsecureSkipVerify = true
+ tlsConfig.Time = params.Clock.Now
+ authClient, err := authclient.NewClient(client.Config{
+ Addrs: getHostAddresses(params),
+ Credentials: []client.Credentials{
+ client.LoadTLS(tlsConfig),
+ },
+ CircuitBreakerConfig: params.CircuitBreakerConfig,
+ })
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ defer authClient.Close()
+
+ // Fetch the root CA from the Auth Server. The NOP role has access to the
+ // GetClusterCACert endpoint.
+ localCA, err := authClient.GetClusterCACert(ctx)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ certs, err := tlsca.ParseCertificatePEMs(localCA.TLSCA)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ // Check that the SPKI pin matches the CA we fetched over a insecure
+ // connection. This makes sure the CA fetched over a insecure connection is
+ // in-fact the expected CA.
+ err = utils.CheckSPKI(params.CAPins, certs)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ for _, cert := range certs {
+ // Check that the fetched CA is valid at the current time.
+ err = utils.VerifyCertificateExpiry(cert, params.Clock)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ }
+ log.Infof("Joining remote cluster %v with CA pin.", certs[0].Subject.CommonName)
+
+ // Create another client, but this time with the CA provided to validate
+ // that the Auth Server was issued a certificate by the same CA.
+ tlsConfig = utils.TLSConfig(params.CipherSuites)
+ tlsConfig.Time = params.Clock.Now
+ certPool := x509.NewCertPool()
+ for _, cert := range certs {
+ certPool.AddCert(cert)
+ }
+ tlsConfig.RootCAs = certPool
+
+ authClient, err = authclient.NewClient(client.Config{
+ Addrs: getHostAddresses(params),
+ Credentials: []client.Credentials{
+ client.LoadTLS(tlsConfig),
+ },
+ CircuitBreakerConfig: params.CircuitBreakerConfig,
+ })
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ return authClient, nil
+}
+
+func caPathRegisterClient(params RegisterParams) (*authclient.Client, error) {
+ tlsConfig := utils.TLSConfig(params.CipherSuites)
+ tlsConfig.Time = params.Clock.Now
+
+ cert, err := readCA(params.CAPath)
+ if err != nil && !trace.IsNotFound(err) {
+ return nil, trace.Wrap(err)
+ }
+
+ // If we're unable to read the file at CAPath, we fall back to insecure
+ // registration. This preserves the existing behavior. At a later date,
+ // we may wish to consider changing this to return an error - but this is a
+ // breaking change.
+ if trace.IsNotFound(err) {
+ log.Warnf("Falling back to insecurely joining because a missing or empty CA Path was provided.")
+ return insecureRegisterClient(params)
+ }
+
+ certPool := x509.NewCertPool()
+ certPool.AddCert(cert)
+ tlsConfig.RootCAs = certPool
+
+ log.Infof("Joining remote cluster %v, validating connection with certificate on disk.", cert.Subject.CommonName)
+
+ client, err := authclient.NewClient(client.Config{
+ Addrs: getHostAddresses(params),
+ Credentials: []client.Credentials{
+ client.LoadTLS(tlsConfig),
+ },
+ CircuitBreakerConfig: params.CircuitBreakerConfig,
+ })
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ return client, nil
+}
+
+type joinServiceClient interface {
+ RegisterUsingIAMMethod(ctx context.Context, challengeResponse client.RegisterIAMChallengeResponseFunc) (*proto.Certs, error)
+ RegisterUsingAzureMethod(ctx context.Context, challengeResponse client.RegisterAzureChallengeResponseFunc) (*proto.Certs, error)
+ RegisterUsingTPMMethod(
+ ctx context.Context,
+ initReq *proto.RegisterUsingTPMMethodInitialRequest,
+ solveChallenge client.RegisterTPMChallengeResponseFunc,
+ ) (*proto.Certs, error)
+}
+
+func registerUsingTokenRequestForParams(token string, params RegisterParams) *types.RegisterUsingTokenRequest {
+ return &types.RegisterUsingTokenRequest{
+ Token: token,
+ HostID: params.ID.HostUUID,
+ NodeName: params.ID.NodeName,
+ Role: params.ID.Role,
+ AdditionalPrincipals: params.AdditionalPrincipals,
+ DNSNames: params.DNSNames,
+ PublicTLSKey: params.PublicTLSKey,
+ PublicSSHKey: params.PublicSSHKey,
+ Expires: params.Expires,
+ }
+}
+
+// registerUsingIAMMethod is used to register using the IAM join method. It is
+// able to register through a proxy or through the auth server directly.
+func registerUsingIAMMethod(
+ ctx context.Context, joinServiceClient joinServiceClient, token string, params RegisterParams,
+) (*proto.Certs, error) {
+ log.Infof("Attempting to register %s with IAM method using regional STS endpoint", params.ID.Role)
+ // Call RegisterUsingIAMMethod and pass a callback to respond to the challenge with a signed join request.
+ certs, err := joinServiceClient.RegisterUsingIAMMethod(ctx, func(challenge string) (*proto.RegisterUsingIAMMethodRequest, error) {
+ // create the signed sts:GetCallerIdentity request and include the challenge
+ signedRequest, err := iam.CreateSignedSTSIdentityRequest(ctx, challenge,
+ iam.WithFIPSEndpoint(params.FIPS),
+ iam.WithRegionalEndpoint(true),
+ )
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ // send the register request including the challenge response
+ return &proto.RegisterUsingIAMMethodRequest{
+ RegisterUsingTokenRequest: registerUsingTokenRequestForParams(token, params),
+ StsIdentityRequest: signedRequest,
+ }, nil
+ })
+ if err != nil {
+ log.WithError(err).Infof("Failed to register %s using regional STS endpoint", params.ID.Role)
+ return nil, trace.Wrap(err)
+ }
+
+ log.Infof("Successfully registered %s with IAM method using regional STS endpoint", params.ID.Role)
+ return certs, nil
+}
+
+// registerUsingAzureMethod is used to register using the Azure join method. It
+// is able to register through a proxy or through the auth server directly.
+func registerUsingAzureMethod(
+ ctx context.Context, client joinServiceClient, token string, params RegisterParams,
+) (*proto.Certs, error) {
+ certs, err := client.RegisterUsingAzureMethod(ctx, func(challenge string) (*proto.RegisterUsingAzureMethodRequest, error) {
+ imds := azure.NewInstanceMetadataClient()
+ if !imds.IsAvailable(ctx) {
+ return nil, trace.AccessDenied("could not reach instance metadata. Is Teleport running on an Azure VM?")
+ }
+ ad, err := imds.GetAttestedData(ctx, challenge)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ accessToken, err := imds.GetAccessToken(ctx, params.AzureParams.ClientID)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ return &proto.RegisterUsingAzureMethodRequest{
+ RegisterUsingTokenRequest: registerUsingTokenRequestForParams(token, params),
+ AttestedData: ad,
+ AccessToken: accessToken,
+ }, nil
+ })
+ return certs, trace.Wrap(err)
+}
+
+// registerUsingTPMMethod is used to register using the TPM join method. It
+// is able to register through a proxy or through the auth server directly.
+func registerUsingTPMMethod(
+ ctx context.Context,
+ client joinServiceClient,
+ token string,
+ params RegisterParams,
+) (*proto.Certs, error) {
+ log := slog.Default()
+
+ initReq := &proto.RegisterUsingTPMMethodInitialRequest{
+ JoinRequest: registerUsingTokenRequestForParams(token, params),
+ }
+
+ attestation, close, err := tpm.Attest(ctx, log)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ defer func() {
+ if err := close(); err != nil {
+ log.WarnContext(ctx, "Failed to close TPM", "error", err)
+ }
+ }()
+
+ initReq.AttestationParams = tpm.AttestationParametersToProto(
+ attestation.AttestParams,
+ )
+ // Get the EKKey or EKCert. We want to prefer the EKCert if it is available
+ // as this is signed by the manufacturer.
+ switch {
+ case attestation.Data.EKCert != nil:
+ log.DebugContext(
+ ctx,
+ "Using EKCert for TPM registration",
+ "ekcert_serial", attestation.Data.EKCert.SerialNumber,
+ )
+ initReq.Ek = &proto.RegisterUsingTPMMethodInitialRequest_EkCert{
+ EkCert: attestation.Data.EKCert.Raw,
+ }
+ case attestation.Data.EKPub != nil:
+ log.DebugContext(
+ ctx,
+ "Using EKKey for TPM registration",
+ "ekpub_hash", attestation.Data.EKPubHash,
+ )
+ initReq.Ek = &proto.RegisterUsingTPMMethodInitialRequest_EkKey{
+ EkKey: attestation.Data.EKPub,
+ }
+ default:
+ return nil, trace.BadParameter("tpm has neither ekkey or ekcert")
+ }
+
+ // Submit initial request to the Auth Server.
+ certs, err := client.RegisterUsingTPMMethod(
+ ctx,
+ initReq,
+ func(
+ challenge *proto.TPMEncryptedCredential,
+ ) (*proto.RegisterUsingTPMMethodChallengeResponse, error) {
+ // Solve the encrypted credential with our AK to prove possession
+ // and obtain the solution we need to complete the ceremony.
+ solution, err := attestation.Solve(tpm.EncryptedCredentialFromProto(
+ challenge,
+ ))
+ if err != nil {
+ return nil, trace.Wrap(err, "activating credential")
+ }
+ return &proto.RegisterUsingTPMMethodChallengeResponse{
+ Solution: solution,
+ }, nil
+ },
+ )
+ return certs, trace.Wrap(err)
+}
+
+// readCA will read in CA that will be used to validate the certificate that
+// the Auth Server presents.
+func readCA(path string) (*x509.Certificate, error) {
+ certBytes, err := utils.ReadPath(path)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ cert, err := tlsca.ParseCertificatePEM(certBytes)
+ if err != nil {
+ return nil, trace.Wrap(err, "failed to parse certificate at %v", path)
+ }
+ return cert, nil
+}
diff --git a/lib/auth/join/join_test.go b/lib/auth/join/join_test.go
new file mode 100644
index 0000000000000..7e32eb55438c2
--- /dev/null
+++ b/lib/auth/join/join_test.go
@@ -0,0 +1,328 @@
+// Teleport
+// Copyright (C) 2024 Gravitational, Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package join
+
+import (
+ "context"
+ "crypto/tls"
+ "errors"
+ "net"
+ "os"
+ "testing"
+ "time"
+
+ "github.com/gravitational/trace"
+ "github.com/jonboulle/clockwork"
+ "github.com/stretchr/testify/require"
+ "golang.org/x/crypto/ssh"
+
+ headerv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/header/v1"
+ machineidv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/machineid/v1"
+ "github.com/gravitational/teleport/api/types"
+ apievents "github.com/gravitational/teleport/api/types/events"
+ "github.com/gravitational/teleport/lib/auth"
+ "github.com/gravitational/teleport/lib/auth/machineid/machineidv1"
+ "github.com/gravitational/teleport/lib/auth/state"
+ "github.com/gravitational/teleport/lib/auth/testauthority"
+ "github.com/gravitational/teleport/lib/defaults"
+ "github.com/gravitational/teleport/lib/events"
+ "github.com/gravitational/teleport/lib/fixtures"
+ "github.com/gravitational/teleport/lib/modules"
+ "github.com/gravitational/teleport/lib/tlsca"
+ "github.com/gravitational/teleport/lib/utils"
+)
+
+func TestMain(m *testing.M) {
+ modules.SetInsecureTestMode(true)
+ os.Exit(m.Run())
+}
+
+func newTestTLSServer(t testing.TB) *auth.TestTLSServer {
+ as, err := auth.NewTestAuthServer(auth.TestAuthServerConfig{
+ Dir: t.TempDir(),
+ Clock: clockwork.NewFakeClockAt(time.Now().Round(time.Second).UTC()),
+ })
+ require.NoError(t, err)
+
+ srv, err := as.NewTestTLSServer()
+ require.NoError(t, err)
+
+ t.Cleanup(func() {
+ err := srv.Close()
+ if errors.Is(err, net.ErrClosed) {
+ return
+ }
+ require.NoError(t, err)
+ })
+
+ return srv
+}
+
+// TestRegister_Bot tests that a provision token can be used to generate
+// renewable certificates for a non-interactive user.
+func TestRegister_Bot(t *testing.T) {
+ t.Parallel()
+ ctx := context.Background()
+
+ srv := newTestTLSServer(t)
+
+ bot, err := machineidv1.UpsertBot(ctx, srv.Auth(), &machineidv1pb.Bot{
+ Metadata: &headerv1.Metadata{
+ Name: "test",
+ },
+ Spec: &machineidv1pb.BotSpec{
+ Roles: []string{},
+ },
+ }, srv.Clock().Now(), "")
+ require.NoError(t, err)
+
+ later := srv.Clock().Now().Add(4 * time.Hour)
+
+ goodToken := newBotToken(t, "good-token", bot.Metadata.Name, types.RoleBot, later)
+ expiredToken := newBotToken(t, "expired", bot.Metadata.Name, types.RoleBot, srv.Clock().Now().Add(-1*time.Hour))
+ wrongKind := newBotToken(t, "wrong-kind", "", types.RoleNode, later)
+ wrongUser := newBotToken(t, "wrong-user", "llama", types.RoleBot, later)
+ invalidToken := newBotToken(t, "this-token-does-not-exist", bot.Metadata.Name, types.RoleBot, later)
+
+ err = srv.Auth().UpsertToken(ctx, goodToken)
+ require.NoError(t, err)
+ err = srv.Auth().UpsertToken(ctx, expiredToken)
+ require.NoError(t, err)
+ err = srv.Auth().UpsertToken(ctx, wrongKind)
+ require.NoError(t, err)
+ err = srv.Auth().UpsertToken(ctx, wrongUser)
+ require.NoError(t, err)
+
+ privateKey, publicKey, err := testauthority.New().GenerateKeyPair()
+ require.NoError(t, err)
+ sshPrivateKey, err := ssh.ParseRawPrivateKey(privateKey)
+ require.NoError(t, err)
+ tlsPublicKey, err := tlsca.MarshalPublicKeyFromPrivateKeyPEM(sshPrivateKey)
+ require.NoError(t, err)
+
+ for _, test := range []struct {
+ desc string
+ token types.ProvisionToken
+ assertErr require.ErrorAssertionFunc
+ }{
+ {
+ desc: "OK good token",
+ token: goodToken,
+ assertErr: require.NoError,
+ },
+ {
+ desc: "NOK expired token",
+ token: expiredToken,
+ assertErr: require.Error,
+ },
+ {
+ desc: "NOK wrong token kind",
+ token: wrongKind,
+ assertErr: require.Error,
+ },
+ {
+ desc: "NOK token for wrong user",
+ token: wrongUser,
+ assertErr: require.Error,
+ },
+ {
+ desc: "NOK invalid token",
+ token: invalidToken,
+ assertErr: require.Error,
+ },
+ } {
+ t.Run(test.desc, func(t *testing.T) {
+ start := srv.Clock().Now()
+ certs, err := Register(ctx, RegisterParams{
+ Token: test.token.GetName(),
+ ID: state.IdentityID{
+ Role: types.RoleBot,
+ },
+ AuthServers: []utils.NetAddr{*utils.MustParseAddr(srv.Addr().String())},
+ PublicTLSKey: tlsPublicKey,
+ PublicSSHKey: publicKey,
+ })
+ test.assertErr(t, err)
+
+ if err == nil {
+ require.NotEmpty(t, certs.SSH)
+ require.NotEmpty(t, certs.TLS)
+
+ // ensure token was removed
+ _, err = srv.Auth().GetToken(ctx, test.token.GetName())
+ require.True(t, trace.IsNotFound(err), "expected not found error, got %v", err)
+
+ // ensure cert is renewable
+ x509, err := tlsca.ParseCertificatePEM(certs.TLS)
+ require.NoError(t, err)
+ id, err := tlsca.FromSubject(x509.Subject, later)
+ require.NoError(t, err)
+ require.True(t, id.Renewable)
+
+ // Check audit event
+ evts, _, err := srv.Auth().SearchEvents(ctx, events.SearchEventsRequest{
+ From: start,
+ To: srv.Clock().Now(),
+ EventTypes: []string{events.BotJoinEvent},
+ Limit: 1,
+ Order: types.EventOrderDescending,
+ })
+ require.NoError(t, err)
+ require.Len(t, evts, 1)
+ evt, ok := evts[0].(*apievents.BotJoin)
+ require.True(t, ok)
+ require.Equal(t, events.BotJoinEvent, evt.Type)
+ require.Equal(t, events.BotJoinCode, evt.Code)
+ require.EqualValues(t, types.JoinMethodToken, evt.Method)
+ }
+ })
+ }
+}
+
+// TestRegister_Bot_Expiry checks that bot certificate expiry can be set, and
+// does not exceed the limit.
+func TestRegister_Bot_Expiry(t *testing.T) {
+ t.Parallel()
+ ctx := context.Background()
+
+ srv := newTestTLSServer(t)
+ privateKey, publicKey, err := testauthority.New().GenerateKeyPair()
+ require.NoError(t, err)
+ sshPrivateKey, err := ssh.ParseRawPrivateKey(privateKey)
+ require.NoError(t, err)
+ tlsPublicKey, err := tlsca.MarshalPublicKeyFromPrivateKeyPEM(sshPrivateKey)
+ require.NoError(t, err)
+
+ validExpires := srv.Clock().Now().Add(time.Hour * 6)
+ tooGreatExpires := srv.Clock().Now().Add(time.Hour * 24 * 365)
+ tests := []struct {
+ name string
+ requestExpires *time.Time
+ expectTTL time.Duration
+ }{
+ {
+ name: "unspecified defaults",
+ requestExpires: nil,
+ expectTTL: defaults.DefaultRenewableCertTTL,
+ },
+ {
+ name: "valid value specified",
+ requestExpires: &validExpires,
+ expectTTL: time.Hour * 6,
+ },
+ {
+ name: "value exceeding limit specified",
+ requestExpires: &tooGreatExpires,
+ // MaxSessionTTL set in createBotRole is 12 hours, so this cap will
+ // apply instead of the defaults.MaxRenewableCertTTL specified
+ // in generateInitialBotCerts.
+ expectTTL: 12 * time.Hour,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ botName := t.Name()
+ _, err = machineidv1.UpsertBot(ctx, srv.Auth(), &machineidv1pb.Bot{
+ Metadata: &headerv1.Metadata{
+ Name: botName,
+ },
+ Spec: &machineidv1pb.BotSpec{
+ Roles: []string{},
+ Traits: []*machineidv1pb.Trait{},
+ },
+ }, srv.Clock().Now(), "")
+ require.NoError(t, err)
+ tok := newBotToken(t, t.Name(), botName, types.RoleBot, srv.Clock().Now().Add(time.Hour))
+ require.NoError(t, srv.Auth().UpsertToken(ctx, tok))
+
+ certs, err := Register(ctx, RegisterParams{
+ Token: tok.GetName(),
+ ID: state.IdentityID{
+ Role: types.RoleBot,
+ },
+ AuthServers: []utils.NetAddr{*utils.MustParseAddr(srv.Addr().String())},
+ PublicTLSKey: tlsPublicKey,
+ PublicSSHKey: publicKey,
+ Expires: tt.requestExpires,
+ })
+ require.NoError(t, err)
+ x509, err := tlsca.ParseCertificatePEM(certs.TLS)
+ require.NoError(t, err)
+ id, err := tlsca.FromSubject(x509.Subject, x509.NotAfter)
+ require.NoError(t, err)
+
+ ttl := id.Expires.Sub(srv.Clock().Now())
+ require.Equal(t, tt.expectTTL, ttl)
+ })
+ }
+}
+
+func newBotToken(t *testing.T, tokenName, botName string, role types.SystemRole, expiry time.Time) types.ProvisionToken {
+ t.Helper()
+ token, err := types.NewProvisionTokenFromSpec(tokenName, expiry, types.ProvisionTokenSpecV2{
+ Roles: []types.SystemRole{role},
+ BotName: botName,
+ })
+ require.NoError(t, err, "could not create bot token")
+ return token
+}
+
+func TestVerifyALPNUpgradedConn(t *testing.T) {
+ t.Parallel()
+
+ srv := newTestTLSServer(t)
+ proxy, err := auth.NewServerIdentity(srv.Auth(), "test-proxy", types.RoleProxy)
+ require.NoError(t, err)
+
+ tests := []struct {
+ name string
+ serverCert []byte
+ clock clockwork.Clock
+ checkError require.ErrorAssertionFunc
+ }{
+ {
+ name: "proxy verified",
+ serverCert: proxy.TLSCertBytes,
+ clock: srv.Clock(),
+ checkError: require.NoError,
+ },
+ {
+ name: "proxy expired",
+ serverCert: proxy.TLSCertBytes,
+ clock: clockwork.NewFakeClockAt(srv.Clock().Now().Add(defaults.CATTL + time.Hour)),
+ checkError: require.Error,
+ },
+ {
+ name: "not proxy",
+ serverCert: []byte(fixtures.TLSCACertPEM),
+ clock: srv.Clock(),
+ checkError: require.Error,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ serverCert, err := utils.ReadCertificates(test.serverCert)
+ require.NoError(t, err)
+
+ test.checkError(t, verifyALPNUpgradedConn(test.clock)(tls.ConnectionState{
+ PeerCertificates: serverCert,
+ }))
+ })
+ }
+}
diff --git a/lib/auth/join_iam.go b/lib/auth/join_iam.go
index b52a8b6b35b81..6668a3449b8b5 100644
--- a/lib/auth/join_iam.go
+++ b/lib/auth/join_iam.go
@@ -27,12 +27,7 @@ import (
"net/http"
"net/url"
"slices"
- "strings"
- awssdk "github.com/aws/aws-sdk-go/aws"
- "github.com/aws/aws-sdk-go/aws/endpoints"
- "github.com/aws/aws-sdk-go/aws/session"
- "github.com/aws/aws-sdk-go/service/sts"
"github.com/coreos/go-semver/semver"
"github.com/gravitational/trace"
@@ -40,7 +35,7 @@ import (
"github.com/gravitational/teleport/api/client"
"github.com/gravitational/teleport/api/client/proto"
"github.com/gravitational/teleport/api/types"
- cloudaws "github.com/gravitational/teleport/lib/cloud/imds/aws"
+ "github.com/gravitational/teleport/lib/auth/join/iam"
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/teleport/lib/utils/aws"
)
@@ -71,7 +66,7 @@ const (
// against a static list of known valid endpoints. We will need to update this
// list as AWS adds new regions.
func validateSTSHost(stsHost string, cfg *iamRegisterConfig) error {
- valid := slices.Contains(validSTSEndpoints, stsHost)
+ valid := slices.Contains(iam.ValidSTSEndpoints(), stsHost)
if !valid {
return trace.AccessDenied("IAM join request uses unknown STS host %q. "+
"This could mean that the Teleport Node attempting to join the cluster is "+
@@ -81,10 +76,10 @@ func validateSTSHost(stsHost string, cfg *iamRegisterConfig) error {
"Following is the list of valid STS endpoints known to this auth server. "+
"If a legitimate STS endpoint is not included, please file an issue at "+
"https://github.com/gravitational/teleport. %v",
- stsHost, validSTSEndpoints)
+ stsHost, iam.ValidSTSEndpoints())
}
- if cfg.fips && !slices.Contains(fipsSTSEndpoints, stsHost) {
+ if cfg.fips && !slices.Contains(iam.FIPSSTSEndpoints(), stsHost) {
return trace.AccessDenied("node selected non-FIPS STS endpoint (%s) for the IAM join method", stsHost)
}
@@ -393,124 +388,3 @@ func (a *Server) RegisterUsingIAMMethod(
certs, err = a.generateCerts(ctx, provisionToken, req.RegisterUsingTokenRequest, nil)
return certs, trace.Wrap(err)
}
-
-type stsIdentityRequestConfig struct {
- regionalEndpointOption endpoints.STSRegionalEndpoint
- fipsEndpointOption endpoints.FIPSEndpointState
-}
-
-type stsIdentityRequestOption func(cfg *stsIdentityRequestConfig)
-
-func withRegionalEndpoint(useRegionalEndpoint bool) stsIdentityRequestOption {
- return func(cfg *stsIdentityRequestConfig) {
- if useRegionalEndpoint {
- cfg.regionalEndpointOption = endpoints.RegionalSTSEndpoint
- } else {
- cfg.regionalEndpointOption = endpoints.LegacySTSEndpoint
- }
- }
-}
-
-func withFIPSEndpoint(useFIPS bool) stsIdentityRequestOption {
- return func(cfg *stsIdentityRequestConfig) {
- if useFIPS {
- cfg.fipsEndpointOption = endpoints.FIPSEndpointStateEnabled
- } else {
- cfg.fipsEndpointOption = endpoints.FIPSEndpointStateDisabled
- }
- }
-}
-
-// createSignedSTSIdentityRequest is called on the client side and returns an
-// sts:GetCallerIdentity request signed with the local AWS credentials
-func createSignedSTSIdentityRequest(ctx context.Context, challenge string, opts ...stsIdentityRequestOption) ([]byte, error) {
- cfg := &stsIdentityRequestConfig{}
- for _, opt := range opts {
- opt(cfg)
- }
-
- stsClient, err := newSTSClient(ctx, cfg)
- if err != nil {
- return nil, trace.Wrap(err)
- }
-
- req, _ := stsClient.GetCallerIdentityRequest(&sts.GetCallerIdentityInput{})
- // set challenge header
- req.HTTPRequest.Header.Set(challengeHeaderKey, challenge)
- // request json for simpler parsing
- req.HTTPRequest.Header.Set("Accept", "application/json")
- // sign the request, including headers
- if err := req.Sign(); err != nil {
- return nil, trace.Wrap(err)
- }
- // write the signed HTTP request to a buffer
- var signedRequest bytes.Buffer
- if err := req.HTTPRequest.Write(&signedRequest); err != nil {
- return nil, trace.Wrap(err)
- }
- return signedRequest.Bytes(), nil
-}
-
-func newSTSClient(ctx context.Context, cfg *stsIdentityRequestConfig) (*sts.STS, error) {
- awsConfig := awssdk.Config{
- UseFIPSEndpoint: cfg.fipsEndpointOption,
- STSRegionalEndpoint: cfg.regionalEndpointOption,
- }
- sess, err := session.NewSessionWithOptions(session.Options{
- SharedConfigState: session.SharedConfigEnable,
- Config: awsConfig,
- })
- if err != nil {
- return nil, trace.Wrap(err)
- }
-
- stsClient := sts.New(sess)
-
- if slices.Contains(globalSTSEndpoints, strings.TrimPrefix(stsClient.Endpoint, "https://")) {
- // If the caller wants to use the regional endpoint but it was not resolved
- // from the environment, attempt to find the region from the EC2 IMDS
- if cfg.regionalEndpointOption == endpoints.RegionalSTSEndpoint {
- region, err := getEC2LocalRegion(ctx)
- if err != nil {
- return nil, trace.Wrap(err, "failed to resolve local AWS region from environment or IMDS")
- }
- stsClient = sts.New(sess, awssdk.NewConfig().WithRegion(region))
- } else {
- log.Info("Attempting to use the global STS endpoint for the IAM join method. " +
- "This will probably fail in non-default AWS partitions such as China or GovCloud, or if FIPS mode is enabled. " +
- "Consider setting the AWS_REGION environment variable, setting the region in ~/.aws/config, or enabling the IMDSv2.")
- }
- }
-
- if cfg.fipsEndpointOption == endpoints.FIPSEndpointStateEnabled &&
- !slices.Contains(validSTSEndpoints, strings.TrimPrefix(stsClient.Endpoint, "https://")) {
- // The AWS SDK will generate invalid endpoints when attempting to
- // resolve the FIPS endpoint for a region that does not have one.
- // In this case, try to use the FIPS endpoint in us-east-1. This should
- // work for all regions in the standard partition. In GovCloud, we should
- // not hit this because all regional endpoints support FIPS. In China or
- // other partitions, this will fail, and FIPS mode will not be supported.
- log.Infof("AWS SDK resolved FIPS STS endpoint %s, which does not appear to be valid. "+
- "Attempting to use the FIPS STS endpoint for us-east-1.",
- stsClient.Endpoint)
- stsClient = sts.New(sess, awssdk.NewConfig().WithRegion("us-east-1"))
- }
-
- return stsClient, nil
-}
-
-// getEC2LocalRegion returns the AWS region this EC2 instance is running in, or
-// a NotFound error if the EC2 IMDS is unavailable.
-func getEC2LocalRegion(ctx context.Context) (string, error) {
- imdsClient, err := cloudaws.NewInstanceMetadataClient(ctx)
- if err != nil {
- return "", trace.Wrap(err)
- }
-
- if !imdsClient.IsAvailable(ctx) {
- return "", trace.NotFound("IMDS is unavailable")
- }
-
- region, err := imdsClient.GetRegion(ctx)
- return region, trace.Wrap(err)
-}
diff --git a/lib/auth/join_test.go b/lib/auth/join_test.go
index 9e207f9667a3f..6ce94edcc0ab6 100644
--- a/lib/auth/join_test.go
+++ b/lib/auth/join_test.go
@@ -34,6 +34,7 @@ import (
"github.com/gravitational/teleport/api/types"
apievents "github.com/gravitational/teleport/api/types/events"
"github.com/gravitational/teleport/api/utils/sshutils"
+ "github.com/gravitational/teleport/lib/auth/join"
"github.com/gravitational/teleport/lib/auth/machineid/machineidv1"
"github.com/gravitational/teleport/lib/auth/state"
"github.com/gravitational/teleport/lib/auth/testauthority"
@@ -371,7 +372,7 @@ func TestRegister_Bot(t *testing.T) {
} {
t.Run(test.desc, func(t *testing.T) {
start := srv.Clock().Now()
- certs, err := Register(ctx, RegisterParams{
+ certs, err := join.Register(ctx, join.RegisterParams{
Token: test.token.GetName(),
ID: state.IdentityID{
Role: types.RoleBot,
@@ -474,7 +475,7 @@ func TestRegister_Bot_Expiry(t *testing.T) {
tok := newBotToken(t, t.Name(), botName, types.RoleBot, srv.Clock().Now().Add(time.Hour))
require.NoError(t, srv.Auth().UpsertToken(ctx, tok))
- certs, err := Register(ctx, RegisterParams{
+ certs, err := join.Register(ctx, join.RegisterParams{
Token: tok.GetName(),
ID: state.IdentityID{
Role: types.RoleBot,
diff --git a/lib/auth/register.go b/lib/auth/register.go
index 24e48594d6bc4..b425bdf8c3ff2 100644
--- a/lib/auth/register.go
+++ b/lib/auth/register.go
@@ -20,44 +20,15 @@ package auth
import (
"context"
- "crypto/tls"
- "crypto/x509"
- "log/slog"
- "os"
- "slices"
- "time"
"github.com/gravitational/trace"
- "github.com/jonboulle/clockwork"
- "golang.org/x/net/http2"
- "google.golang.org/grpc"
- "google.golang.org/grpc/credentials"
- "github.com/gravitational/teleport"
- "github.com/gravitational/teleport/api/breaker"
- "github.com/gravitational/teleport/api/client"
"github.com/gravitational/teleport/api/client/proto"
- "github.com/gravitational/teleport/api/constants"
- apidefaults "github.com/gravitational/teleport/api/defaults"
- "github.com/gravitational/teleport/api/metadata"
- "github.com/gravitational/teleport/api/observability/tracing"
"github.com/gravitational/teleport/api/types"
- "github.com/gravitational/teleport/api/utils/aws"
"github.com/gravitational/teleport/lib/auth/authclient"
"github.com/gravitational/teleport/lib/auth/native"
"github.com/gravitational/teleport/lib/auth/state"
- "github.com/gravitational/teleport/lib/circleci"
- "github.com/gravitational/teleport/lib/cloud/imds/azure"
- "github.com/gravitational/teleport/lib/cloud/imds/gcp"
"github.com/gravitational/teleport/lib/defaults"
- "github.com/gravitational/teleport/lib/githubactions"
- "github.com/gravitational/teleport/lib/gitlab"
- "github.com/gravitational/teleport/lib/kubernetestoken"
- "github.com/gravitational/teleport/lib/spacelift"
- "github.com/gravitational/teleport/lib/srv/alpnproxy/common"
- "github.com/gravitational/teleport/lib/tlsca"
- "github.com/gravitational/teleport/lib/tpm"
- "github.com/gravitational/teleport/lib/utils"
)
// LocalRegister is used to generate host keys when a node or proxy is running
@@ -104,739 +75,6 @@ func LocalRegister(id state.IdentityID, authServer *Server, additionalPrincipals
return identity, nil
}
-// AzureParams is the parameters specific to the azure join method.
-type AzureParams struct {
- // ClientID is the client ID of the managed identity for Teleport to assume
- // when authenticating a node.
- ClientID string
-}
-
-// RegisterParams specifies parameters
-// for first time register operation with auth server
-type RegisterParams struct {
- // Token is a secure token to join the cluster
- Token string
- // ID is identity ID
- ID state.IdentityID
- // AuthServers is a list of auth servers to dial
- AuthServers []utils.NetAddr
- // ProxyServer is a proxy server to dial
- ProxyServer utils.NetAddr
- // AdditionalPrincipals is a list of additional principals to dial
- AdditionalPrincipals []string
- // DNSNames is a list of DNS names to add to x509 certificate
- DNSNames []string
- // PublicTLSKey is a server's public key to sign
- PublicTLSKey []byte
- // PublicSSHKey is a server's public SSH key to sign
- PublicSSHKey []byte
- // CipherSuites is a list of cipher suites to use for TLS client connection
- CipherSuites []uint16
- // CAPins are the SKPI hashes of the CAs used to verify the Auth Server.
- CAPins []string
- // CAPath is the path to the CA file.
- CAPath string
- // GetHostCredentials is a client that can fetch host credentials.
- GetHostCredentials HostCredentials
- // Clock specifies the time provider. Will be used to override the time anchor
- // for TLS certificate verification.
- // Defaults to real clock if unspecified
- Clock clockwork.Clock
- // JoinMethod is the joining method used for this register request.
- JoinMethod types.JoinMethod
- // ec2IdentityDocument is used for Simplified Node Joining to prove the
- // identity of a joining EC2 instance.
- ec2IdentityDocument []byte
- // AzureParams is the parameters specific to the azure join method.
- AzureParams AzureParams
- // CircuitBreakerConfig defines how the circuit breaker should behave.
- CircuitBreakerConfig breaker.Config
- // FIPS means FedRAMP/FIPS 140-2 compliant configuration was requested.
- FIPS bool
- // IDToken is a token retrieved from a workload identity provider for
- // certain join types e.g GitHub, Google.
- IDToken string
- // Expires is an optional field for bots that specifies a time that the
- // certificates that are returned by registering should expire at.
- // It should not be specified for non-bot registrations.
- Expires *time.Time
- // Insecure trusts the certificates from the Auth Server or Proxy during registration without verification.
- Insecure bool
-}
-
-func (r *RegisterParams) checkAndSetDefaults() error {
- if r.Clock == nil {
- r.Clock = clockwork.NewRealClock()
- }
-
- if err := r.verifyAuthOrProxyAddress(); err != nil {
- return trace.BadParameter("no auth or proxy servers set")
- }
-
- return nil
-}
-
-func (r *RegisterParams) verifyAuthOrProxyAddress() error {
- haveAuthServers := len(r.AuthServers) > 0
- haveProxyServer := !r.ProxyServer.IsEmpty()
-
- if !haveAuthServers && !haveProxyServer {
- return trace.BadParameter("no auth or proxy servers set")
- }
-
- if haveAuthServers && haveProxyServer {
- return trace.BadParameter("only one of auth or proxy server should be set")
- }
-
- return nil
-}
-
-// CredGetter is an interface for a client that can be used to get host
-// credentials. This interface is needed because lib/client can not be imported
-// in lib/auth due to circular imports.
-type HostCredentials func(context.Context, string, bool, types.RegisterUsingTokenRequest) (*proto.Certs, error)
-
-// Register is used to generate host keys when a node or proxy are running on
-// different hosts than the auth server. This method requires provisioning
-// tokens to prove a valid auth server was used to issue the joining request
-// as well as a method for the node to validate the auth server.
-func Register(ctx context.Context, params RegisterParams) (certs *proto.Certs, err error) {
- ctx, span := tracer.Start(ctx, "Register")
- defer func() { tracing.EndSpan(span, err) }()
-
- if err := params.checkAndSetDefaults(); err != nil {
- return nil, trace.Wrap(err)
- }
- // Read in the token. The token can either be passed in or come from a file
- // on disk.
- token, err := utils.TryReadValueAsFile(params.Token)
- if err != nil {
- return nil, trace.Wrap(err)
- }
-
- // add EC2 Identity Document to params if required for given join method
- switch params.JoinMethod {
- case types.JoinMethodEC2:
- if !aws.IsEC2NodeID(params.ID.HostUUID) {
- return nil, trace.BadParameter(
- `Host ID %q is not valid when using the EC2 join method, `+
- `try removing the "host_uuid" file in your teleport data dir `+
- `(e.g. /var/lib/teleport/host_uuid)`,
- params.ID.HostUUID)
- }
- params.ec2IdentityDocument, err = utils.GetRawEC2IdentityDocument(ctx)
- if err != nil {
- return nil, trace.Wrap(err)
- }
- case types.JoinMethodGitHub:
- params.IDToken, err = githubactions.NewIDTokenSource().GetIDToken(ctx)
- if err != nil {
- return nil, trace.Wrap(err)
- }
- case types.JoinMethodGitLab:
- params.IDToken, err = gitlab.NewIDTokenSource(os.Getenv).GetIDToken()
- if err != nil {
- return nil, trace.Wrap(err)
- }
- case types.JoinMethodCircleCI:
- params.IDToken, err = circleci.GetIDToken(os.Getenv)
- if err != nil {
- return nil, trace.Wrap(err)
- }
- case types.JoinMethodKubernetes:
- params.IDToken, err = kubernetestoken.GetIDToken(os.Getenv, os.ReadFile)
- if err != nil {
- return nil, trace.Wrap(err)
- }
- case types.JoinMethodGCP:
- params.IDToken, err = gcp.GetIDToken(ctx)
- if err != nil {
- return nil, trace.Wrap(err)
- }
- case types.JoinMethodSpacelift:
- params.IDToken, err = spacelift.NewIDTokenSource(os.Getenv).GetIDToken()
- if err != nil {
- return nil, trace.Wrap(err)
- }
- }
-
- type registerMethod struct {
- call func(ctx context.Context, token string, params RegisterParams) (*proto.Certs, error)
- desc string
- }
-
- registerThroughAuth := registerMethod{registerThroughAuth, "with auth server"}
- registerThroughProxy := registerMethod{registerThroughProxy, "via proxy server"}
-
- registerMethods := []registerMethod{registerThroughAuth, registerThroughProxy}
-
- if !params.ProxyServer.IsEmpty() {
- log.WithField("proxy-server", params.ProxyServer).Debugf("Registering node to the cluster.")
-
- registerMethods = []registerMethod{registerThroughProxy}
-
- if proxyServerIsAuth(params.ProxyServer) {
- log.Debugf("The specified proxy server appears to be an auth server.")
- }
- } else {
- log.WithField("auth-servers", params.AuthServers).Debugf("Registering node to the cluster.")
-
- if params.GetHostCredentials == nil {
- log.Debugf("Missing client, it is not possible to register through proxy.")
- registerMethods = []registerMethod{registerThroughAuth}
- } else if authServerIsProxy(params.AuthServers) {
- log.Debugf("The first specified auth server appears to be a proxy.")
- registerMethods = []registerMethod{registerThroughProxy, registerThroughAuth}
- }
- }
-
- var collectedErrs []error
- for _, method := range registerMethods {
- log.Infof("Attempting registration %s.", method.desc)
- certs, err := method.call(ctx, token, params)
- if err != nil {
- collectedErrs = append(collectedErrs, err)
- log.WithError(err).Debugf("Registration %s failed.", method.desc)
- continue
- }
- log.Infof("Successfully registered %s.", method.desc)
- return certs, nil
- }
- return nil, trace.NewAggregate(collectedErrs...)
-}
-
-// authServerIsProxy returns true if the first specified auth server
-// to register with appears to be a proxy.
-func authServerIsProxy(servers []utils.NetAddr) bool {
- if len(servers) == 0 {
- return false
- }
- port := servers[0].Port(0)
- return port == defaults.HTTPListenPort || port == teleport.StandardHTTPSPort
-}
-
-// proxyServerIsAuth returns true if the address given to register with
-// appears to be an auth server.
-func proxyServerIsAuth(server utils.NetAddr) bool {
- port := server.Port(0)
- return port == defaults.AuthListenPort
-}
-
-// registerThroughProxy is used to register through the proxy server.
-func registerThroughProxy(
- ctx context.Context,
- token string,
- params RegisterParams,
-) (certs *proto.Certs, err error) {
- ctx, span := tracer.Start(ctx, "registerThroughProxy")
- defer func() { tracing.EndSpan(span, err) }()
-
- switch params.JoinMethod {
- case types.JoinMethodIAM, types.JoinMethodAzure, types.JoinMethodTPM:
- // IAM and Azure join methods require gRPC client
- conn, err := proxyJoinServiceConn(ctx, params, params.Insecure)
- if err != nil {
- return nil, trace.Wrap(err)
- }
- defer conn.Close()
-
- joinServiceClient := client.NewJoinServiceClient(proto.NewJoinServiceClient(conn))
- switch params.JoinMethod {
- case types.JoinMethodIAM:
- certs, err = registerUsingIAMMethod(ctx, joinServiceClient, token, params)
- case types.JoinMethodAzure:
- certs, err = registerUsingAzureMethod(ctx, joinServiceClient, token, params)
- case types.JoinMethodTPM:
- certs, err = registerUsingTPMMethod(ctx, joinServiceClient, token, params)
- default:
- return nil, trace.BadParameter("unhandled join method %q", params.JoinMethod)
- }
-
- if err != nil {
- return nil, trace.Wrap(err)
- }
- default:
- // The rest of the join methods use GetHostCredentials function passed through
- // params to call proxy HTTP endpoint
- var err error
- certs, err = params.GetHostCredentials(ctx,
- getHostAddresses(params)[0],
- params.Insecure,
- types.RegisterUsingTokenRequest{
- Token: token,
- HostID: params.ID.HostUUID,
- NodeName: params.ID.NodeName,
- Role: params.ID.Role,
- AdditionalPrincipals: params.AdditionalPrincipals,
- DNSNames: params.DNSNames,
- PublicTLSKey: params.PublicTLSKey,
- PublicSSHKey: params.PublicSSHKey,
- EC2IdentityDocument: params.ec2IdentityDocument,
- IDToken: params.IDToken,
- Expires: params.Expires,
- })
- if err != nil {
- return nil, trace.Wrap(err)
- }
- }
- return certs, nil
-}
-
-func getHostAddresses(params RegisterParams) []string {
- if !params.ProxyServer.IsEmpty() {
- return []string{params.ProxyServer.String()}
- }
-
- return utils.NetAddrsToStrings(params.AuthServers)
-}
-
-// registerThroughAuth is used to register through the auth server.
-func registerThroughAuth(
- ctx context.Context, token string, params RegisterParams,
-) (certs *proto.Certs, err error) {
- ctx, span := tracer.Start(ctx, "registerThroughAuth")
- defer func() { tracing.EndSpan(span, err) }()
-
- var client *authclient.Client
- // Build a client for the Auth Server with different certificate validation
- // depending on the configured values for Insecure, CAPins and CAPath.
- switch {
- case params.Insecure:
- log.Warnf("Insecure mode enabled. Auth Server cert will not be validated and CAPins and CAPath value will be ignored.")
- client, err = insecureRegisterClient(params)
- case len(params.CAPins) != 0:
- // CAPins takes precedence over CAPath
- client, err = pinRegisterClient(ctx, params)
- case params.CAPath != "":
- client, err = caPathRegisterClient(params)
- default:
- // We fall back to insecure mode here - this is a little odd but is
- // necessary to preserve the behavior of registration. At a later date,
- // we may consider making this an error asking the user to provide
- // Insecure, CAPins or CAPath.
- client, err = insecureRegisterClient(params)
- }
- if err != nil {
- return nil, trace.Wrap(err)
- }
- defer client.Close()
-
- switch params.JoinMethod {
- // IAM and Azure methods use unique gRPC endpoints
- case types.JoinMethodIAM:
- certs, err = registerUsingIAMMethod(ctx, client, token, params)
- case types.JoinMethodAzure:
- certs, err = registerUsingAzureMethod(ctx, client, token, params)
- case types.JoinMethodTPM:
- certs, err = registerUsingTPMMethod(ctx, client, token, params)
- default:
- // non-IAM join methods use HTTP endpoint
- // Get the SSH and X509 certificates for a node.
- certs, err = client.RegisterUsingToken(
- ctx,
- &types.RegisterUsingTokenRequest{
- Token: token,
- HostID: params.ID.HostUUID,
- NodeName: params.ID.NodeName,
- Role: params.ID.Role,
- AdditionalPrincipals: params.AdditionalPrincipals,
- DNSNames: params.DNSNames,
- PublicTLSKey: params.PublicTLSKey,
- PublicSSHKey: params.PublicSSHKey,
- EC2IdentityDocument: params.ec2IdentityDocument,
- IDToken: params.IDToken,
- Expires: params.Expires,
- })
- }
- return certs, trace.Wrap(err)
-}
-
-// proxyJoinServiceConn attempts to connect to the join service running on the
-// proxy. The Proxy's TLS cert will be verified using the host's root CA pool
-// (PKI) unless the --insecure flag was passed.
-func proxyJoinServiceConn(
- ctx context.Context, params RegisterParams, insecure bool,
-) (*grpc.ClientConn, error) {
- tlsConfig := utils.TLSConfig(params.CipherSuites)
- tlsConfig.Time = params.Clock.Now
- // set NextProtos for TLS routing, the actual protocol will be h2
- tlsConfig.NextProtos = []string{string(common.ProtocolProxyGRPCInsecure), http2.NextProtoTLS}
-
- if insecure {
- tlsConfig.InsecureSkipVerify = true
- log.Warnf("Joining cluster without validating the identity of the Proxy Server.")
- }
-
- // Check if proxy is behind a load balancer. If so, the connection upgrade
- // will verify the load balancer's cert using system cert pool. This
- // provides the same level of security as the client only verifies Proxy's
- // web cert against system cert pool when connection upgrade is not
- // required.
- //
- // With the ALPN connection upgrade, the tunneled TLS Routing request will
- // skip verify as the Proxy server will present its host cert which is not
- // fully verifiable at this point since the client does not have the host
- // CAs yet before completing registration.
- alpnConnUpgrade := client.IsALPNConnUpgradeRequired(ctx, getHostAddresses(params)[0], insecure)
- if alpnConnUpgrade && !insecure {
- tlsConfig.InsecureSkipVerify = true
- tlsConfig.VerifyConnection = verifyALPNUpgradedConn(params.Clock)
- }
-
- dialer := client.NewDialer(
- ctx,
- apidefaults.DefaultIdleTimeout,
- apidefaults.DefaultIOTimeout,
- client.WithInsecureSkipVerify(insecure),
- client.WithALPNConnUpgrade(alpnConnUpgrade),
- )
-
- conn, err := grpc.Dial(
- getHostAddresses(params)[0],
- grpc.WithContextDialer(client.GRPCContextDialer(dialer)),
- grpc.WithUnaryInterceptor(metadata.UnaryClientInterceptor),
- grpc.WithStreamInterceptor(metadata.StreamClientInterceptor),
- grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)),
- )
- return conn, trace.Wrap(err)
-}
-
-// verifyALPNUpgradedConn is a tls.Config.VerifyConnection callback function
-// used by the tunneled TLS Routing request to verify the host cert of a Proxy
-// behind a L7 load balancer.
-//
-// Since the client has not obtained the cluster CAs at this point, the
-// presented cert cannot be fully verified yet. For now, this function only
-// checks if "teleport.cluster.local" is present as one of the DNS names and
-// verifies the cert is not expired.
-func verifyALPNUpgradedConn(clock clockwork.Clock) func(tls.ConnectionState) error {
- return func(server tls.ConnectionState) error {
- for _, cert := range server.PeerCertificates {
- if slices.Contains(cert.DNSNames, constants.APIDomain) && clock.Now().Before(cert.NotAfter) {
- return nil
- }
- }
- return trace.AccessDenied("server is not a Teleport proxy or server certificate is expired")
- }
-}
-
-// insecureRegisterClient attempts to connects to the Auth Server using the
-// CA on disk. If no CA is found on disk, Teleport will not verify the Auth
-// Server it is connecting to.
-func insecureRegisterClient(params RegisterParams) (*authclient.Client, error) {
- log.Warnf("Joining cluster without validating the identity of the Auth " +
- "Server. This may open you up to a Man-In-The-Middle (MITM) attack if an " +
- "attacker can gain privileged network access. To remedy this, use the CA pin " +
- "value provided when join token was generated to validate the identity of " +
- "the Auth Server or point to a valid Certificate via the CA Path option.")
-
- tlsConfig := utils.TLSConfig(params.CipherSuites)
- tlsConfig.Time = params.Clock.Now
- tlsConfig.InsecureSkipVerify = true
-
- client, err := authclient.NewClient(client.Config{
- Addrs: getHostAddresses(params),
- Credentials: []client.Credentials{
- client.LoadTLS(tlsConfig),
- },
- CircuitBreakerConfig: params.CircuitBreakerConfig,
- })
- if err != nil {
- return nil, trace.Wrap(err)
- }
-
- return client, nil
-}
-
-// readCA will read in CA that will be used to validate the certificate that
-// the Auth Server presents.
-func readCA(path string) (*x509.Certificate, error) {
- certBytes, err := utils.ReadPath(path)
- if err != nil {
- return nil, trace.Wrap(err)
- }
- cert, err := tlsca.ParseCertificatePEM(certBytes)
- if err != nil {
- return nil, trace.Wrap(err, "failed to parse certificate at %v", path)
- }
- return cert, nil
-}
-
-// pinRegisterClient first connects to the Auth Server using a insecure
-// connection to fetch the root CA. If the root CA matches the provided CA
-// pin, a connection will be re-established and the root CA will be used to
-// validate the certificate presented. If both conditions hold true, then we
-// know we are connecting to the expected Auth Server.
-func pinRegisterClient(
- ctx context.Context, params RegisterParams,
-) (*authclient.Client, error) {
- // Build a insecure client to the Auth Server. This is safe because even if
- // an attacker were to MITM this connection the CA pin will not match below.
- tlsConfig := utils.TLSConfig(params.CipherSuites)
- tlsConfig.InsecureSkipVerify = true
- tlsConfig.Time = params.Clock.Now
- authClient, err := authclient.NewClient(client.Config{
- Addrs: getHostAddresses(params),
- Credentials: []client.Credentials{
- client.LoadTLS(tlsConfig),
- },
- CircuitBreakerConfig: params.CircuitBreakerConfig,
- })
- if err != nil {
- return nil, trace.Wrap(err)
- }
- defer authClient.Close()
-
- // Fetch the root CA from the Auth Server. The NOP role has access to the
- // GetClusterCACert endpoint.
- localCA, err := authClient.GetClusterCACert(ctx)
- if err != nil {
- return nil, trace.Wrap(err)
- }
- certs, err := tlsca.ParseCertificatePEMs(localCA.TLSCA)
- if err != nil {
- return nil, trace.Wrap(err)
- }
-
- // Check that the SPKI pin matches the CA we fetched over a insecure
- // connection. This makes sure the CA fetched over a insecure connection is
- // in-fact the expected CA.
- err = utils.CheckSPKI(params.CAPins, certs)
- if err != nil {
- return nil, trace.Wrap(err)
- }
-
- for _, cert := range certs {
- // Check that the fetched CA is valid at the current time.
- err = utils.VerifyCertificateExpiry(cert, params.Clock)
- if err != nil {
- return nil, trace.Wrap(err)
- }
-
- }
- log.Infof("Joining remote cluster %v with CA pin.", certs[0].Subject.CommonName)
-
- // Create another client, but this time with the CA provided to validate
- // that the Auth Server was issued a certificate by the same CA.
- tlsConfig = utils.TLSConfig(params.CipherSuites)
- tlsConfig.Time = params.Clock.Now
- certPool := x509.NewCertPool()
- for _, cert := range certs {
- certPool.AddCert(cert)
- }
- tlsConfig.RootCAs = certPool
-
- authClient, err = authclient.NewClient(client.Config{
- Addrs: getHostAddresses(params),
- Credentials: []client.Credentials{
- client.LoadTLS(tlsConfig),
- },
- CircuitBreakerConfig: params.CircuitBreakerConfig,
- })
- if err != nil {
- return nil, trace.Wrap(err)
- }
-
- return authClient, nil
-}
-
-func caPathRegisterClient(params RegisterParams) (*authclient.Client, error) {
- tlsConfig := utils.TLSConfig(params.CipherSuites)
- tlsConfig.Time = params.Clock.Now
-
- cert, err := readCA(params.CAPath)
- if err != nil && !trace.IsNotFound(err) {
- return nil, trace.Wrap(err)
- }
-
- // If we're unable to read the file at CAPath, we fall back to insecure
- // registration. This preserves the existing behavior. At a later date,
- // we may wish to consider changing this to return an error - but this is a
- // breaking change.
- if trace.IsNotFound(err) {
- log.Warnf("Falling back to insecurely joining because a missing or empty CA Path was provided.")
- return insecureRegisterClient(params)
- }
-
- certPool := x509.NewCertPool()
- certPool.AddCert(cert)
- tlsConfig.RootCAs = certPool
-
- log.Infof("Joining remote cluster %v, validating connection with certificate on disk.", cert.Subject.CommonName)
-
- client, err := authclient.NewClient(client.Config{
- Addrs: getHostAddresses(params),
- Credentials: []client.Credentials{
- client.LoadTLS(tlsConfig),
- },
- CircuitBreakerConfig: params.CircuitBreakerConfig,
- })
- if err != nil {
- return nil, trace.Wrap(err)
- }
-
- return client, nil
-}
-
-type joinServiceClient interface {
- RegisterUsingIAMMethod(ctx context.Context, challengeResponse client.RegisterIAMChallengeResponseFunc) (*proto.Certs, error)
- RegisterUsingAzureMethod(ctx context.Context, challengeResponse client.RegisterAzureChallengeResponseFunc) (*proto.Certs, error)
- RegisterUsingTPMMethod(
- ctx context.Context,
- initReq *proto.RegisterUsingTPMMethodInitialRequest,
- solveChallenge client.RegisterTPMChallengeResponseFunc,
- ) (*proto.Certs, error)
-}
-
-func registerUsingTokenRequestForParams(token string, params RegisterParams) *types.RegisterUsingTokenRequest {
- return &types.RegisterUsingTokenRequest{
- Token: token,
- HostID: params.ID.HostUUID,
- NodeName: params.ID.NodeName,
- Role: params.ID.Role,
- AdditionalPrincipals: params.AdditionalPrincipals,
- DNSNames: params.DNSNames,
- PublicTLSKey: params.PublicTLSKey,
- PublicSSHKey: params.PublicSSHKey,
- Expires: params.Expires,
- }
-}
-
-// registerUsingIAMMethod is used to register using the IAM join method. It is
-// able to register through a proxy or through the auth server directly.
-func registerUsingIAMMethod(
- ctx context.Context, joinServiceClient joinServiceClient, token string, params RegisterParams,
-) (*proto.Certs, error) {
- log.Infof("Attempting to register %s with IAM method using regional STS endpoint", params.ID.Role)
- // Call RegisterUsingIAMMethod and pass a callback to respond to the challenge with a signed join request.
- certs, err := joinServiceClient.RegisterUsingIAMMethod(ctx, func(challenge string) (*proto.RegisterUsingIAMMethodRequest, error) {
- // create the signed sts:GetCallerIdentity request and include the challenge
- signedRequest, err := createSignedSTSIdentityRequest(ctx, challenge,
- withFIPSEndpoint(params.FIPS),
- withRegionalEndpoint(true),
- )
- if err != nil {
- return nil, trace.Wrap(err)
- }
-
- // send the register request including the challenge response
- return &proto.RegisterUsingIAMMethodRequest{
- RegisterUsingTokenRequest: registerUsingTokenRequestForParams(token, params),
- StsIdentityRequest: signedRequest,
- }, nil
- })
- if err != nil {
- log.WithError(err).Infof("Failed to register %s using regional STS endpoint", params.ID.Role)
- return nil, trace.Wrap(err)
- }
-
- log.Infof("Successfully registered %s with IAM method using regional STS endpoint", params.ID.Role)
- return certs, nil
-}
-
-// registerUsingAzureMethod is used to register using the Azure join method. It
-// is able to register through a proxy or through the auth server directly.
-func registerUsingAzureMethod(
- ctx context.Context, client joinServiceClient, token string, params RegisterParams,
-) (*proto.Certs, error) {
- certs, err := client.RegisterUsingAzureMethod(ctx, func(challenge string) (*proto.RegisterUsingAzureMethodRequest, error) {
- imds := azure.NewInstanceMetadataClient()
- if !imds.IsAvailable(ctx) {
- return nil, trace.AccessDenied("could not reach instance metadata. Is Teleport running on an Azure VM?")
- }
- ad, err := imds.GetAttestedData(ctx, challenge)
- if err != nil {
- return nil, trace.Wrap(err)
- }
- accessToken, err := imds.GetAccessToken(ctx, params.AzureParams.ClientID)
- if err != nil {
- return nil, trace.Wrap(err)
- }
-
- return &proto.RegisterUsingAzureMethodRequest{
- RegisterUsingTokenRequest: registerUsingTokenRequestForParams(token, params),
- AttestedData: ad,
- AccessToken: accessToken,
- }, nil
- })
- return certs, trace.Wrap(err)
-}
-
-// registerUsingTPMMethod is used to register using the TPM join method. It
-// is able to register through a proxy or through the auth server directly.
-func registerUsingTPMMethod(
- ctx context.Context,
- client joinServiceClient,
- token string,
- params RegisterParams,
-) (*proto.Certs, error) {
- log := slog.Default()
-
- initReq := &proto.RegisterUsingTPMMethodInitialRequest{
- JoinRequest: registerUsingTokenRequestForParams(token, params),
- }
-
- attestation, close, err := tpm.Attest(ctx, log)
- if err != nil {
- return nil, trace.Wrap(err)
- }
- defer func() {
- if err := close(); err != nil {
- log.WarnContext(ctx, "Failed to close TPM", "error", err)
- }
- }()
-
- initReq.AttestationParams = tpm.AttestationParametersToProto(
- attestation.AttestParams,
- )
- // Get the EKKey or EKCert. We want to prefer the EKCert if it is available
- // as this is signed by the manufacturer.
- switch {
- case attestation.Data.EKCert != nil:
- log.DebugContext(
- ctx,
- "Using EKCert for TPM registration",
- "ekcert_serial", attestation.Data.EKCert.SerialNumber,
- )
- initReq.Ek = &proto.RegisterUsingTPMMethodInitialRequest_EkCert{
- EkCert: attestation.Data.EKCert.Raw,
- }
- case attestation.Data.EKPub != nil:
- log.DebugContext(
- ctx,
- "Using EKKey for TPM registration",
- "ekpub_hash", attestation.Data.EKPubHash,
- )
- initReq.Ek = &proto.RegisterUsingTPMMethodInitialRequest_EkKey{
- EkKey: attestation.Data.EKPub,
- }
- default:
- return nil, trace.BadParameter("tpm has neither ekkey or ekcert")
- }
-
- // Submit initial request to the Auth Server.
- certs, err := client.RegisterUsingTPMMethod(
- ctx,
- initReq,
- func(
- challenge *proto.TPMEncryptedCredential,
- ) (*proto.RegisterUsingTPMMethodChallengeResponse, error) {
- // Solve the encrypted credential with our AK to prove possession
- // and obtain the solution we need to complete the ceremony.
- solution, err := attestation.Solve(tpm.EncryptedCredentialFromProto(
- challenge,
- ))
- if err != nil {
- return nil, trace.Wrap(err, "activating credential")
- }
- return &proto.RegisterUsingTPMMethodChallengeResponse{
- Solution: solution,
- }, nil
- },
- )
- return certs, trace.Wrap(err)
-}
-
// ReRegisterParams specifies parameters for re-registering
// in the cluster (rotating certificates for existing members)
type ReRegisterParams struct {
diff --git a/lib/auth/register_test.go b/lib/auth/register_test.go
deleted file mode 100644
index 5f1d248ace8f1..0000000000000
--- a/lib/auth/register_test.go
+++ /dev/null
@@ -1,78 +0,0 @@
-/*
- * Teleport
- * Copyright (C) 2023 Gravitational, Inc.
- *
- * This program is free software: you can redistribute it and/or modify
- * it under the terms of the GNU Affero General Public License as published by
- * the Free Software Foundation, either version 3 of the License, or
- * (at your option) any later version.
- *
- * This program is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- * GNU Affero General Public License for more details.
- *
- * You should have received a copy of the GNU Affero General Public License
- * along with this program. If not, see .
- */
-
-package auth
-
-import (
- "crypto/tls"
- "testing"
- "time"
-
- "github.com/jonboulle/clockwork"
- "github.com/stretchr/testify/require"
-
- "github.com/gravitational/teleport/api/types"
- "github.com/gravitational/teleport/lib/defaults"
- "github.com/gravitational/teleport/lib/fixtures"
- "github.com/gravitational/teleport/lib/utils"
-)
-
-func TestVerifyALPNUpgradedConn(t *testing.T) {
- t.Parallel()
-
- auth := newTestTLSServer(t)
- proxy, err := NewServerIdentity(auth.Auth(), "test-proxy", types.RoleProxy)
- require.NoError(t, err)
-
- tests := []struct {
- name string
- serverCert []byte
- clock clockwork.Clock
- checkError require.ErrorAssertionFunc
- }{
- {
- name: "proxy verified",
- serverCert: proxy.TLSCertBytes,
- clock: auth.Clock(),
- checkError: require.NoError,
- },
- {
- name: "proxy expired",
- serverCert: proxy.TLSCertBytes,
- clock: clockwork.NewFakeClockAt(auth.Clock().Now().Add(defaults.CATTL + time.Hour)),
- checkError: require.Error,
- },
- {
- name: "not proxy",
- serverCert: []byte(fixtures.TLSCACertPEM),
- clock: auth.Clock(),
- checkError: require.Error,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- serverCert, err := utils.ReadCertificates(test.serverCert)
- require.NoError(t, err)
-
- test.checkError(t, verifyALPNUpgradedConn(test.clock)(tls.ConnectionState{
- PeerCertificates: serverCert,
- }))
- })
- }
-}
diff --git a/lib/auth/sts_endpoints.go b/lib/auth/sts_endpoints.go
deleted file mode 100644
index c0884fb1a2bd4..0000000000000
--- a/lib/auth/sts_endpoints.go
+++ /dev/null
@@ -1,85 +0,0 @@
-/*
- * Teleport
- * Copyright (C) 2023 Gravitational, Inc.
- *
- * This program is free software: you can redistribute it and/or modify
- * it under the terms of the GNU Affero General Public License as published by
- * the Free Software Foundation, either version 3 of the License, or
- * (at your option) any later version.
- *
- * This program is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- * GNU Affero General Public License for more details.
- *
- * You should have received a copy of the GNU Affero General Public License
- * along with this program. If not, see .
- */
-
-package auth
-
-var (
- // validSTSEndpoints holds a sorted list of all known valid public endpoints for
- // the AWS STS service. You can generate this list by running
- // $ go run github.com/nklaassen/sts-endpoints@latest --go-list
- // Update aws-sdk-go in that package to learn about new endpoints.
- validSTSEndpoints = []string{
- "sts-fips.us-east-1.amazonaws.com",
- "sts-fips.us-east-2.amazonaws.com",
- "sts-fips.us-west-1.amazonaws.com",
- "sts-fips.us-west-2.amazonaws.com",
- "sts.af-south-1.amazonaws.com",
- "sts.amazonaws.com",
- "sts.ap-east-1.amazonaws.com",
- "sts.ap-northeast-1.amazonaws.com",
- "sts.ap-northeast-2.amazonaws.com",
- "sts.ap-northeast-3.amazonaws.com",
- "sts.ap-south-1.amazonaws.com",
- "sts.ap-south-2.amazonaws.com",
- "sts.ap-southeast-1.amazonaws.com",
- "sts.ap-southeast-2.amazonaws.com",
- "sts.ap-southeast-3.amazonaws.com",
- "sts.ap-southeast-4.amazonaws.com",
- "sts.ca-central-1.amazonaws.com",
- "sts.ca-west-1.amazonaws.com",
- "sts.cn-north-1.amazonaws.com.cn",
- "sts.cn-northwest-1.amazonaws.com.cn",
- "sts.eu-central-1.amazonaws.com",
- "sts.eu-central-2.amazonaws.com",
- "sts.eu-north-1.amazonaws.com",
- "sts.eu-south-1.amazonaws.com",
- "sts.eu-south-2.amazonaws.com",
- "sts.eu-west-1.amazonaws.com",
- "sts.eu-west-2.amazonaws.com",
- "sts.eu-west-3.amazonaws.com",
- "sts.il-central-1.amazonaws.com",
- "sts.me-central-1.amazonaws.com",
- "sts.me-south-1.amazonaws.com",
- "sts.sa-east-1.amazonaws.com",
- "sts.us-east-1.amazonaws.com",
- "sts.us-east-2.amazonaws.com",
- "sts.us-gov-east-1.amazonaws.com",
- "sts.us-gov-west-1.amazonaws.com",
- "sts.us-iso-east-1.c2s.ic.gov",
- "sts.us-iso-west-1.c2s.ic.gov",
- "sts.us-isob-east-1.sc2s.sgov.gov",
- "sts.us-west-1.amazonaws.com",
- "sts.us-west-2.amazonaws.com",
- }
-
- globalSTSEndpoints = []string{
- "sts.amazonaws.com",
- // This is not a real endpoint, but the SDK will select it if
- // AWS_USE_FIPS_ENDPOINT is set and a region is not.
- "sts-fips.aws-global.amazonaws.com",
- }
-
- fipsSTSEndpoints = []string{
- "sts-fips.us-east-1.amazonaws.com",
- "sts-fips.us-east-2.amazonaws.com",
- "sts-fips.us-west-1.amazonaws.com",
- "sts-fips.us-west-2.amazonaws.com",
- "sts.us-gov-east-1.amazonaws.com",
- "sts.us-gov-west-1.amazonaws.com",
- }
-)
diff --git a/lib/auth/tls_test.go b/lib/auth/tls_test.go
index 972a2e34532ad..df0618c8d7a0a 100644
--- a/lib/auth/tls_test.go
+++ b/lib/auth/tls_test.go
@@ -57,6 +57,7 @@ import (
apiutils "github.com/gravitational/teleport/api/utils"
"github.com/gravitational/teleport/api/utils/sshutils"
"github.com/gravitational/teleport/lib/auth/authclient"
+ "github.com/gravitational/teleport/lib/auth/join"
"github.com/gravitational/teleport/lib/auth/state"
"github.com/gravitational/teleport/lib/auth/testauthority"
"github.com/gravitational/teleport/lib/authz"
@@ -3394,7 +3395,7 @@ func TestRegisterCAPin(t *testing.T) {
caPin := caPins[0]
// Attempt to register with valid CA pin, should work.
- _, err = Register(ctx, RegisterParams{
+ _, err = join.Register(ctx, join.RegisterParams{
AuthServers: []utils.NetAddr{utils.FromAddr(testSrv.Addr())},
Token: token,
ID: state.IdentityID{
@@ -3412,7 +3413,7 @@ func TestRegisterCAPin(t *testing.T) {
// Attempt to register with multiple CA pins where the auth server only
// matches one, should work.
- _, err = Register(ctx, RegisterParams{
+ _, err = join.Register(ctx, join.RegisterParams{
AuthServers: []utils.NetAddr{utils.FromAddr(testSrv.Addr())},
Token: token,
ID: state.IdentityID{
@@ -3429,7 +3430,7 @@ func TestRegisterCAPin(t *testing.T) {
require.NoError(t, err)
// Attempt to register with invalid CA pin, should fail.
- _, err = Register(ctx, RegisterParams{
+ _, err = join.Register(ctx, join.RegisterParams{
AuthServers: []utils.NetAddr{utils.FromAddr(testSrv.Addr())},
Token: token,
ID: state.IdentityID{
@@ -3446,7 +3447,7 @@ func TestRegisterCAPin(t *testing.T) {
require.Error(t, err)
// Attempt to register with multiple invalid CA pins, should fail.
- _, err = Register(ctx, RegisterParams{
+ _, err = join.Register(ctx, join.RegisterParams{
AuthServers: []utils.NetAddr{utils.FromAddr(testSrv.Addr())},
Token: token,
ID: state.IdentityID{
@@ -3482,7 +3483,7 @@ func TestRegisterCAPin(t *testing.T) {
require.Len(t, caPins, 2)
// Attempt to register with multiple CA pins, should work
- _, err = Register(ctx, RegisterParams{
+ _, err = join.Register(ctx, join.RegisterParams{
AuthServers: []utils.NetAddr{utils.FromAddr(testSrv.Addr())},
Token: token,
ID: state.IdentityID{
@@ -3527,7 +3528,7 @@ func TestRegisterCAPath(t *testing.T) {
require.NoError(t, err)
// Attempt to register with nothing at the CA path, should work.
- _, err = Register(ctx, RegisterParams{
+ _, err = join.Register(ctx, join.RegisterParams{
AuthServers: []utils.NetAddr{utils.FromAddr(testSrv.Addr())},
Token: token,
ID: state.IdentityID{
@@ -3556,7 +3557,7 @@ func TestRegisterCAPath(t *testing.T) {
require.NoError(t, err)
// Attempt to register with valid CA path, should work.
- _, err = Register(ctx, RegisterParams{
+ _, err = join.Register(ctx, join.RegisterParams{
AuthServers: []utils.NetAddr{utils.FromAddr(testSrv.Addr())},
Token: token,
ID: state.IdentityID{
diff --git a/lib/service/connect.go b/lib/service/connect.go
index bc61b4b1a5679..eeee5ebb9fb92 100644
--- a/lib/service/connect.go
+++ b/lib/service/connect.go
@@ -45,6 +45,7 @@ import (
"github.com/gravitational/teleport/lib"
"github.com/gravitational/teleport/lib/auth"
"github.com/gravitational/teleport/lib/auth/authclient"
+ "github.com/gravitational/teleport/lib/auth/join"
"github.com/gravitational/teleport/lib/auth/native"
"github.com/gravitational/teleport/lib/auth/state"
"github.com/gravitational/teleport/lib/client"
@@ -442,7 +443,7 @@ func (process *TeleportProcess) firstTimeConnect(role types.SystemRole) (*Connec
dataDir = process.Config.DataDir
}
- registerParams := auth.RegisterParams{
+ registerParams := join.RegisterParams{
Token: token,
ID: id,
AuthServers: process.Config.AuthServerAddresses(),
@@ -464,12 +465,12 @@ func (process *TeleportProcess) firstTimeConnect(role types.SystemRole) (*Connec
Insecure: lib.IsInsecureDevMode(),
}
if registerParams.JoinMethod == types.JoinMethodAzure {
- registerParams.AzureParams = auth.AzureParams{
+ registerParams.AzureParams = join.AzureParams{
ClientID: process.Config.JoinParams.Azure.ClientID,
}
}
- certs, err := auth.Register(process.ExitContext(), registerParams)
+ certs, err := join.Register(process.ExitContext(), registerParams)
if err != nil {
if utils.IsUntrustedCertErr(err) {
return nil, trace.WrapWithMessage(err, utils.SelfSignedCertsMsg)
diff --git a/lib/tbot/service_bot_identity.go b/lib/tbot/service_bot_identity.go
index f193b2542bc12..11418705398d6 100644
--- a/lib/tbot/service_bot_identity.go
+++ b/lib/tbot/service_bot_identity.go
@@ -33,8 +33,8 @@ import (
"github.com/gravitational/teleport/api/client/proto"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/utils/retryutils"
- "github.com/gravitational/teleport/lib/auth"
"github.com/gravitational/teleport/lib/auth/authclient"
+ "github.com/gravitational/teleport/lib/auth/join"
"github.com/gravitational/teleport/lib/auth/state"
"github.com/gravitational/teleport/lib/client"
"github.com/gravitational/teleport/lib/reversetunnelclient"
@@ -403,7 +403,7 @@ func botIdentityFromToken(ctx context.Context, log *slog.Logger, cfg *config.Bot
}
expires := time.Now().Add(cfg.CertificateTTL)
- params := auth.RegisterParams{
+ params := join.RegisterParams{
Token: token,
ID: state.IdentityID{
Role: types.RoleBot,
@@ -439,12 +439,12 @@ func botIdentityFromToken(ctx context.Context, log *slog.Logger, cfg *config.Bot
}
if params.JoinMethod == types.JoinMethodAzure {
- params.AzureParams = auth.AzureParams{
+ params.AzureParams = join.AzureParams{
ClientID: cfg.Onboarding.Azure.ClientID,
}
}
- certs, err := auth.Register(ctx, params)
+ certs, err := join.Register(ctx, params)
if err != nil {
return nil, trace.Wrap(err)
}