Skip to content

Commit

Permalink
Load appropriate TLS config for clusters (#40293)
Browse files Browse the repository at this point in the history
Updates the proxy.Client to allow loading specific tls.Config for
individual clusters. This prevents issues when trying to access
leaf resources via the root cluster if WithAllCAs is not set.
  • Loading branch information
rosstimothy authored Apr 9, 2024
1 parent 67e74d9 commit ad26913
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 36 deletions.
62 changes: 40 additions & 22 deletions api/client/proxy/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ type ClientConfig struct {
ProxyAddress string
// TLSRoutingEnabled indicates if the cluster is using TLS Routing.
TLSRoutingEnabled bool
// TLSConfig contains the tls.Config required for mTLS connections.
TLSConfig *tls.Config
// TLSConfigFunc produces the [tls.Config] required for mTLS connections to a specific cluster.
TLSConfigFunc func(cluster string) (*tls.Config, error)
// UnaryInterceptors are optional [grpc.UnaryClientInterceptor] to apply
// to the gRPC client.
UnaryInterceptors []grpc.UnaryClientInterceptor
Expand All @@ -77,9 +77,9 @@ type ClientConfig struct {

// The below items are intended to be used by tests to connect without mTLS.
// The gRPC transport credentials to use when establishing the connection to proxy.
creds func() credentials.TransportCredentials
creds func(cluster string) (credentials.TransportCredentials, error)
// The client credentials to use when establishing the connection to auth.
clientCreds func() client.Credentials
clientCreds func(cluster string) (client.Credentials, error)
}

// CheckAndSetDefaults ensures required options are present and
Expand All @@ -94,13 +94,21 @@ func (c *ClientConfig) CheckAndSetDefaults() error {
if c.DialTimeout <= 0 {
c.DialTimeout = defaults.DefaultIOTimeout
}
if c.TLSConfig != nil {
c.clientCreds = func() client.Credentials {
return client.LoadTLS(c.TLSConfig.Clone())
if c.TLSConfigFunc != nil {
c.clientCreds = func(cluster string) (client.Credentials, error) {
cfg, err := c.TLSConfigFunc(cluster)
if err != nil {
return nil, trace.Wrap(err)
}

return client.LoadTLS(cfg), nil
}
c.creds = func() credentials.TransportCredentials {
tlsCfg := c.TLSConfig.Clone()
if !slices.Contains(c.TLSConfig.NextProtos, protocolProxySSHGRPC) {
c.creds = func(cluster string) (credentials.TransportCredentials, error) {
tlsCfg, err := c.TLSConfigFunc(cluster)
if err != nil {
return nil, trace.Wrap(err)
}
if !slices.Contains(tlsCfg.NextProtos, protocolProxySSHGRPC) {
tlsCfg.NextProtos = append(tlsCfg.NextProtos, protocolProxySSHGRPC)
}

Expand All @@ -115,14 +123,14 @@ func (c *ClientConfig) CheckAndSetDefaults() error {
}
}

return credentials.NewTLS(tlsCfg)
return credentials.NewTLS(tlsCfg), nil
}
} else {
c.clientCreds = func() client.Credentials {
return insecureCredentials{}
c.clientCreds = func(cluster string) (client.Credentials, error) {
return insecureCredentials{}, nil
}
c.creds = func() credentials.TransportCredentials {
return insecure.NewCredentials()
c.creds = func(cluster string) (credentials.TransportCredentials, error) {
return insecure.NewCredentials(), nil
}
}

Expand Down Expand Up @@ -265,12 +273,18 @@ func newGRPCClient(ctx context.Context, cfg *ClientConfig) (_ *Client, err error
defer cancel()

c := &clusterName{}

creds, err := cfg.creds("")
if err != nil {
return nil, trace.Wrap(err)
}

conn, err := grpc.DialContext(
dialCtx,
cfg.ProxyAddress,
append([]grpc.DialOption{
grpc.WithContextDialer(newDialerForGRPCClient(ctx, cfg)),
grpc.WithTransportCredentials(&clusterCredentials{TransportCredentials: cfg.creds(), clusterName: c}),
grpc.WithTransportCredentials(&clusterCredentials{TransportCredentials: creds, clusterName: c}),
grpc.WithChainUnaryInterceptor(
append(cfg.UnaryInterceptors,
//nolint:staticcheck // SA1019. There is a data race in the stats.Handler that is replacing
Expand Down Expand Up @@ -361,22 +375,27 @@ type ClusterDetails struct {
// Auth server in the provided cluster via [client.New] or similar. The [client.Config]
// returned will have the correct credentials and dialer set based on the ClientConfig
// that was provided to create this Client.
func (c *Client) ClientConfig(ctx context.Context, cluster string) client.Config {
func (c *Client) ClientConfig(ctx context.Context, cluster string) (client.Config, error) {
creds, err := c.cfg.clientCreds(cluster)
if err != nil {
return client.Config{}, trace.Wrap(err)
}

if c.cfg.TLSRoutingEnabled {
return client.Config{
Context: ctx,
Addrs: []string{c.cfg.ProxyAddress},
Credentials: []client.Credentials{c.cfg.clientCreds()},
Credentials: []client.Credentials{creds},
ALPNSNIAuthDialClusterName: cluster,
CircuitBreakerConfig: breaker.NoopBreakerConfig(),
ALPNConnUpgradeRequired: c.cfg.ALPNConnUpgradeRequired,
DialOpts: c.cfg.DialOpts,
}
}, nil
}

return client.Config{
Context: ctx,
Credentials: []client.Credentials{c.cfg.clientCreds()},
Credentials: []client.Credentials{creds},
CircuitBreakerConfig: breaker.NoopBreakerConfig(),
DialInBackground: true,
Dialer: client.ContextDialerFunc(func(dialCtx context.Context, _ string, _ string) (net.Conn, error) {
Expand All @@ -395,8 +414,7 @@ func (c *Client) ClientConfig(ctx context.Context, cluster string) client.Config
return conn, trace.Wrap(err)
}),
DialOpts: c.cfg.DialOpts,
}

}, nil
}

// DialHost establishes a connection to the `target` in cluster named `cluster`. If a keyring
Expand Down
4 changes: 3 additions & 1 deletion api/client/proxy/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,9 @@ func TestClient_DialCluster(t *testing.T) {
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, clt.Close()) })

authCfg := clt.ClientConfig(ctx, "cluster")
authCfg, err := clt.ClientConfig(ctx, "cluster")
require.NoError(t, err)

authCfg.DialOpts = []grpc.DialOption{
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithReturnConnectionError(),
Expand Down
3 changes: 2 additions & 1 deletion integration/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1661,7 +1661,8 @@ func testIPPropagation(t *testing.T, suite *integrationTestSuite) {
// The above dialer does not work clt.AuthClient as it requires a
// custom transport from ProxyClient when TLS routing is disabled.
// Recreating the authClient without the above dialer.
authClientCfg := clt.ProxyClient.ClientConfig(ctx, clusterName)
authClientCfg, err := clt.ProxyClient.ClientConfig(ctx, clusterName)
require.NoError(t, err)
authClientCfg.DialOpts = nil
authClient, err := auth.NewClient(authClientCfg)
require.NoError(t, err)
Expand Down
31 changes: 21 additions & 10 deletions lib/client/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -1573,7 +1573,12 @@ func (tc *TeleportClient) NewTracingClient(ctx context.Context) (*apitracing.Cli
return nil, trace.Wrap(err)
}

tracingClient, err := client.NewTracingClient(ctx, clusterClient.ProxyClient.ClientConfig(ctx, clusterClient.ClusterName()))
cfg, err := clusterClient.ProxyClient.ClientConfig(ctx, clusterClient.ClusterName())
if err != nil {
return nil, trace.Wrap(err)
}

tracingClient, err := client.NewTracingClient(ctx, cfg)
return tracingClient, trace.Wrap(err)
}

Expand Down Expand Up @@ -3061,15 +3066,18 @@ func (tc *TeleportClient) ConnectToCluster(ctx context.Context) (_ *ClusterClien
return nil, trace.Wrap(err)
}

tlsConfig, err := tc.LoadTLSConfig()
if err != nil {
return nil, trace.Wrap(err)
}

pclt, err := proxyclient.NewClient(ctx, proxyclient.ClientConfig{
ProxyAddress: cfg.proxyAddress,
TLSRoutingEnabled: tc.TLSRoutingEnabled,
TLSConfig: tlsConfig,
ProxyAddress: cfg.proxyAddress,
TLSRoutingEnabled: tc.TLSRoutingEnabled,
TLSConfigFunc: func(cluster string) (*tls.Config, error) {
if cluster == "" {
tlsCfg, err := tc.LoadTLSConfig()
return tlsCfg, trace.Wrap(err)
}

tlsCfg, err := tc.LoadTLSConfigForClusters([]string{cluster})
return tlsCfg, trace.Wrap(err)
},
DialOpts: tc.Config.DialOpts,
UnaryInterceptors: []grpc.UnaryClientInterceptor{interceptors.GRPCClientUnaryErrorInterceptor},
StreamInterceptors: []grpc.StreamClientInterceptor{interceptors.GRPCClientStreamErrorInterceptor},
Expand All @@ -3092,7 +3100,10 @@ func (tc *TeleportClient) ConnectToCluster(ctx context.Context) (_ *ClusterClien
cluster = connected
}

authClientCfg := pclt.ClientConfig(ctx, cluster)
authClientCfg, err := pclt.ClientConfig(ctx, cluster)
if err != nil {
return nil, trace.NewAggregate(err, pclt.Close())
}
authClientCfg.MFAPromptConstructor = tc.NewMFAPrompt
authClient, err := auth.NewClient(authClientCfg)
if err != nil {
Expand Down
13 changes: 11 additions & 2 deletions lib/client/cluster_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,11 @@ func (c *ClusterClient) ConnectToCluster(ctx context.Context, clusterName string
return c.CurrentCluster(), nil
}

clientConfig := c.ProxyClient.ClientConfig(ctx, clusterName)
clientConfig, err := c.ProxyClient.ClientConfig(ctx, clusterName)
if err != nil {
return nil, trace.Wrap(err)
}

authClient, err := auth.NewClient(clientConfig)
return authClient, trace.Wrap(err)
}
Expand Down Expand Up @@ -271,7 +275,12 @@ func (c *ClusterClient) SessionSSHConfig(ctx context.Context, user string, targe

mfaClt := c
if target.Cluster != rootClusterName {
authClient, err := auth.NewClient(c.ProxyClient.ClientConfig(ctx, rootClusterName))
cfg, err := c.ProxyClient.ClientConfig(ctx, rootClusterName)
if err != nil {
return nil, trace.Wrap(err)
}

authClient, err := auth.NewClient(cfg)
if err != nil {
return nil, trace.Wrap(MFARequiredUnknown(err))
}
Expand Down

0 comments on commit ad26913

Please sign in to comment.