Skip to content

Commit

Permalink
Refactor registration code into lib/auth/join (#41679)
Browse files Browse the repository at this point in the history
* Refactor joining code into lib/auth/join

auth.Register and auth.RegisterParams were moved into the new package
to prevent clients(tbot) from depending on lib/auth. For the moment
only iam joining has been moved into the new package. All other
join methods have been left behind in their lib/auth/join_foo.go
file and will be addressed in future PRs if need be.  This removes
lib/auth entirely from tbot and reduces its binary size by ~10-20MB.

* fix: disabling second factor auth in tests

* fix license
  • Loading branch information
rosstimothy authored May 18, 2024
1 parent 7929ab0 commit 8e59e36
Show file tree
Hide file tree
Showing 15 changed files with 1,405 additions and 1,079 deletions.
4 changes: 2 additions & 2 deletions integration/proxy/proxy_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ import (
"github.com/gravitational/teleport/api/utils/retryutils"
"github.com/gravitational/teleport/integration/helpers"
"github.com/gravitational/teleport/lib"
"github.com/gravitational/teleport/lib/auth"
"github.com/gravitational/teleport/lib/auth/join"
"github.com/gravitational/teleport/lib/auth/state"
"github.com/gravitational/teleport/lib/client"
"github.com/gravitational/teleport/lib/defaults"
Expand Down Expand Up @@ -664,7 +664,7 @@ func mustRegisterUsingIAMMethod(t *testing.T, proxyAddr utils.NetAddr, token str
require.NoError(t, err)

node := uuid.NewString()
_, err = auth.Register(context.TODO(), auth.RegisterParams{
_, err = join.Register(context.TODO(), join.RegisterParams{
Token: token,
ID: state.IdentityID{
Role: types.RoleNode,
Expand Down
3 changes: 0 additions & 3 deletions lib/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ import (
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/prometheus/client_golang/prometheus"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/exporters/otlp/otlptrace"
"golang.org/x/crypto/bcrypt"
"golang.org/x/crypto/ssh"
Expand Down Expand Up @@ -155,8 +154,6 @@ const (
"(hint: use 'tctl get roles' to find roles that need updating)"
)

var tracer = otel.Tracer("github.com/gravitational/teleport/lib/auth")

var ErrRequiresEnterprise = services.ErrRequiresEnterprise

// ServerOption allows setting options as functional arguments to Server
Expand Down
7 changes: 4 additions & 3 deletions lib/auth/bot_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ import (
machineidv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/machineid/v1"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/auth/authclient"
"github.com/gravitational/teleport/lib/auth/join"
"github.com/gravitational/teleport/lib/auth/machineid/machineidv1"
"github.com/gravitational/teleport/lib/auth/state"
"github.com/gravitational/teleport/lib/auth/testauthority"
Expand Down Expand Up @@ -118,7 +119,7 @@ func TestRegisterBotCertificateGenerationCheck(t *testing.T) {
tlsPublicKey, err := tlsca.MarshalPublicKeyFromPrivateKeyPEM(sshPrivateKey)
require.NoError(t, err)

certs, err := Register(ctx, RegisterParams{
certs, err := join.Register(ctx, join.RegisterParams{
Token: token.GetName(),
ID: state.IdentityID{
Role: types.RoleBot,
Expand Down Expand Up @@ -191,7 +192,7 @@ func TestRegisterBotCertificateGenerationStolen(t *testing.T) {
tlsPublicKey, err := tlsca.MarshalPublicKeyFromPrivateKeyPEM(sshPrivateKey)
require.NoError(t, err)

certs, err := Register(ctx, RegisterParams{
certs, err := join.Register(ctx, join.RegisterParams{
Token: token.GetName(),
ID: state.IdentityID{
Role: types.RoleBot,
Expand Down Expand Up @@ -267,7 +268,7 @@ func TestRegisterBotCertificateExtensions(t *testing.T) {
tlsPublicKey, err := tlsca.MarshalPublicKeyFromPrivateKeyPEM(sshPrivateKey)
require.NoError(t, err)

certs, err := Register(ctx, RegisterParams{
certs, err := join.Register(ctx, join.RegisterParams{
Token: token.GetName(),
ID: state.IdentityID{
Role: types.RoleBot,
Expand Down
91 changes: 91 additions & 0 deletions lib/auth/join/iam/endpoints.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
// Teleport
// Copyright (C) 2024 Gravitational, Inc.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

package iam

import "sync"

var (
// ValidSTSEndpoints holds a sorted list of all known valid public endpoints for
// the AWS STS service. You can generate this list by running
// $ go run github.com/nklaassen/sts-endpoints@latest --go-list
// Update aws-sdk-go in that package to learn about new endpoints.
ValidSTSEndpoints = sync.OnceValue(func() []string {
return []string{
"sts-fips.us-east-1.amazonaws.com",
"sts-fips.us-east-2.amazonaws.com",
"sts-fips.us-west-1.amazonaws.com",
"sts-fips.us-west-2.amazonaws.com",
"sts.af-south-1.amazonaws.com",
"sts.amazonaws.com",
"sts.ap-east-1.amazonaws.com",
"sts.ap-northeast-1.amazonaws.com",
"sts.ap-northeast-2.amazonaws.com",
"sts.ap-northeast-3.amazonaws.com",
"sts.ap-south-1.amazonaws.com",
"sts.ap-south-2.amazonaws.com",
"sts.ap-southeast-1.amazonaws.com",
"sts.ap-southeast-2.amazonaws.com",
"sts.ap-southeast-3.amazonaws.com",
"sts.ap-southeast-4.amazonaws.com",
"sts.ca-central-1.amazonaws.com",
"sts.ca-west-1.amazonaws.com",
"sts.cn-north-1.amazonaws.com.cn",
"sts.cn-northwest-1.amazonaws.com.cn",
"sts.eu-central-1.amazonaws.com",
"sts.eu-central-2.amazonaws.com",
"sts.eu-north-1.amazonaws.com",
"sts.eu-south-1.amazonaws.com",
"sts.eu-south-2.amazonaws.com",
"sts.eu-west-1.amazonaws.com",
"sts.eu-west-2.amazonaws.com",
"sts.eu-west-3.amazonaws.com",
"sts.il-central-1.amazonaws.com",
"sts.me-central-1.amazonaws.com",
"sts.me-south-1.amazonaws.com",
"sts.sa-east-1.amazonaws.com",
"sts.us-east-1.amazonaws.com",
"sts.us-east-2.amazonaws.com",
"sts.us-gov-east-1.amazonaws.com",
"sts.us-gov-west-1.amazonaws.com",
"sts.us-iso-east-1.c2s.ic.gov",
"sts.us-iso-west-1.c2s.ic.gov",
"sts.us-isob-east-1.sc2s.sgov.gov",
"sts.us-west-1.amazonaws.com",
"sts.us-west-2.amazonaws.com",
}
})

GlobalSTSEndpoints = sync.OnceValue(func() []string {
return []string{
"sts.amazonaws.com",
// This is not a real endpoint, but the SDK will select it if
// AWS_USE_FIPS_ENDPOINT is set and a region is not.
"sts-fips.aws-global.amazonaws.com",
}
})

FIPSSTSEndpoints = sync.OnceValue(func() []string {
return []string{
"sts-fips.us-east-1.amazonaws.com",
"sts-fips.us-east-2.amazonaws.com",
"sts-fips.us-west-1.amazonaws.com",
"sts-fips.us-west-2.amazonaws.com",
"sts.us-gov-east-1.amazonaws.com",
"sts.us-gov-west-1.amazonaws.com",
}
})
)
161 changes: 161 additions & 0 deletions lib/auth/join/iam/iam.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
// Teleport
// Copyright (C) 2024 Gravitational, Inc.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

package iam

import (
"bytes"
"context"
"log/slog"
"slices"
"strings"

awssdk "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/gravitational/trace"

cloudaws "github.com/gravitational/teleport/lib/cloud/imds/aws"
)

const (
// AWS SignedHeaders will always be lowercase
// https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-auth-using-authorization-header.html#sigv4-auth-header-overview
challengeHeaderKey = "x-teleport-challenge"
)

type stsIdentityRequestConfig struct {
regionalEndpointOption endpoints.STSRegionalEndpoint
fipsEndpointOption endpoints.FIPSEndpointState
}

type stsIdentityRequestOption func(cfg *stsIdentityRequestConfig)

func WithRegionalEndpoint(useRegionalEndpoint bool) stsIdentityRequestOption {
return func(cfg *stsIdentityRequestConfig) {
if useRegionalEndpoint {
cfg.regionalEndpointOption = endpoints.RegionalSTSEndpoint
} else {
cfg.regionalEndpointOption = endpoints.LegacySTSEndpoint
}
}
}

func WithFIPSEndpoint(useFIPS bool) stsIdentityRequestOption {
return func(cfg *stsIdentityRequestConfig) {
if useFIPS {
cfg.fipsEndpointOption = endpoints.FIPSEndpointStateEnabled
} else {
cfg.fipsEndpointOption = endpoints.FIPSEndpointStateDisabled
}
}
}

// getEC2LocalRegion returns the AWS region this EC2 instance is running in, or
// a NotFound error if the EC2 IMDS is unavailable.
func getEC2LocalRegion(ctx context.Context) (string, error) {
imdsClient, err := cloudaws.NewInstanceMetadataClient(ctx)
if err != nil {
return "", trace.Wrap(err)
}

if !imdsClient.IsAvailable(ctx) {
return "", trace.NotFound("IMDS is unavailable")
}

region, err := imdsClient.GetRegion(ctx)
return region, trace.Wrap(err)
}

func newSTSClient(ctx context.Context, cfg *stsIdentityRequestConfig) (*sts.STS, error) {
awsConfig := awssdk.Config{
UseFIPSEndpoint: cfg.fipsEndpointOption,
STSRegionalEndpoint: cfg.regionalEndpointOption,
}
sess, err := session.NewSessionWithOptions(session.Options{
SharedConfigState: session.SharedConfigEnable,
Config: awsConfig,
})
if err != nil {
return nil, trace.Wrap(err)
}

stsClient := sts.New(sess)

if slices.Contains(GlobalSTSEndpoints(), strings.TrimPrefix(stsClient.Endpoint, "https://")) {
// If the caller wants to use the regional endpoint but it was not resolved
// from the environment, attempt to find the region from the EC2 IMDS
if cfg.regionalEndpointOption == endpoints.RegionalSTSEndpoint {
region, err := getEC2LocalRegion(ctx)
if err != nil {
return nil, trace.Wrap(err, "failed to resolve local AWS region from environment or IMDS")
}
stsClient = sts.New(sess, awssdk.NewConfig().WithRegion(region))
} else {
const msg = "Attempting to use the global STS endpoint for the IAM join method. " +
"This will probably fail in non-default AWS partitions such as China or GovCloud, or if FIPS mode is enabled. " +
"Consider setting the AWS_REGION environment variable, setting the region in ~/.aws/config, or enabling the IMDSv2."
slog.InfoContext(ctx, msg)
}
}

if cfg.fipsEndpointOption == endpoints.FIPSEndpointStateEnabled &&
!slices.Contains(ValidSTSEndpoints(), strings.TrimPrefix(stsClient.Endpoint, "https://")) {
// The AWS SDK will generate invalid endpoints when attempting to
// resolve the FIPS endpoint for a region that does not have one.
// In this case, try to use the FIPS endpoint in us-east-1. This should
// work for all regions in the standard partition. In GovCloud, we should
// not hit this because all regional endpoints support FIPS. In China or
// other partitions, this will fail, and FIPS mode will not be supported.
const msg = "AWS SDK resolved invalid FIPS STS endpoint. " +
"Attempting to use the FIPS STS endpoint for us-east-1."
slog.InfoContext(ctx, msg, "resolved", stsClient.Endpoint)
stsClient = sts.New(sess, awssdk.NewConfig().WithRegion("us-east-1"))
}

return stsClient, nil
}

// CreateSignedSTSIdentityRequest is called on the client side and returns an
// sts:GetCallerIdentity request signed with the local AWS credentials
func CreateSignedSTSIdentityRequest(ctx context.Context, challenge string, opts ...stsIdentityRequestOption) ([]byte, error) {
cfg := &stsIdentityRequestConfig{}
for _, opt := range opts {
opt(cfg)
}

stsClient, err := newSTSClient(ctx, cfg)
if err != nil {
return nil, trace.Wrap(err)
}

req, _ := stsClient.GetCallerIdentityRequest(&sts.GetCallerIdentityInput{})
// set challenge header
req.HTTPRequest.Header.Set(challengeHeaderKey, challenge)
// request json for simpler parsing
req.HTTPRequest.Header.Set("Accept", "application/json")
// sign the request, including headers
if err := req.Sign(); err != nil {
return nil, trace.Wrap(err)
}
// write the signed HTTP request to a buffer
var signedRequest bytes.Buffer
if err := req.HTTPRequest.Write(&signedRequest); err != nil {
return nil, trace.Wrap(err)
}
return signedRequest.Bytes(), nil
}
Loading

0 comments on commit 8e59e36

Please sign in to comment.