diff --git a/api/client/proxy/client.go b/api/client/proxy/client.go index b6b76c8a2f248..358e8b9a5993b 100644 --- a/api/client/proxy/client.go +++ b/api/client/proxy/client.go @@ -49,8 +49,8 @@ type ClientConfig struct { ProxyAddress string // TLSRoutingEnabled indicates if the cluster is using TLS Routing. TLSRoutingEnabled bool - // TLSConfig contains the tls.Config required for mTLS connections. - TLSConfig *tls.Config + // TLSConfigFunc produces the [tls.Config] required for mTLS connections to a specific cluster. + TLSConfigFunc func(cluster string) (*tls.Config, error) // UnaryInterceptors are optional [grpc.UnaryClientInterceptor] to apply // to the gRPC client. UnaryInterceptors []grpc.UnaryClientInterceptor @@ -77,9 +77,9 @@ type ClientConfig struct { // The below items are intended to be used by tests to connect without mTLS. // The gRPC transport credentials to use when establishing the connection to proxy. - creds func() credentials.TransportCredentials + creds func(cluster string) (credentials.TransportCredentials, error) // The client credentials to use when establishing the connection to auth. - clientCreds func() client.Credentials + clientCreds func(cluster string) (client.Credentials, error) } // CheckAndSetDefaults ensures required options are present and @@ -94,13 +94,21 @@ func (c *ClientConfig) CheckAndSetDefaults() error { if c.DialTimeout <= 0 { c.DialTimeout = defaults.DefaultIOTimeout } - if c.TLSConfig != nil { - c.clientCreds = func() client.Credentials { - return client.LoadTLS(c.TLSConfig.Clone()) + if c.TLSConfigFunc != nil { + c.clientCreds = func(cluster string) (client.Credentials, error) { + cfg, err := c.TLSConfigFunc(cluster) + if err != nil { + return nil, trace.Wrap(err) + } + + return client.LoadTLS(cfg), nil } - c.creds = func() credentials.TransportCredentials { - tlsCfg := c.TLSConfig.Clone() - if !slices.Contains(c.TLSConfig.NextProtos, protocolProxySSHGRPC) { + c.creds = func(cluster string) (credentials.TransportCredentials, error) { + tlsCfg, err := c.TLSConfigFunc(cluster) + if err != nil { + return nil, trace.Wrap(err) + } + if !slices.Contains(tlsCfg.NextProtos, protocolProxySSHGRPC) { tlsCfg.NextProtos = append(tlsCfg.NextProtos, protocolProxySSHGRPC) } @@ -115,14 +123,14 @@ func (c *ClientConfig) CheckAndSetDefaults() error { } } - return credentials.NewTLS(tlsCfg) + return credentials.NewTLS(tlsCfg), nil } } else { - c.clientCreds = func() client.Credentials { - return insecureCredentials{} + c.clientCreds = func(cluster string) (client.Credentials, error) { + return insecureCredentials{}, nil } - c.creds = func() credentials.TransportCredentials { - return insecure.NewCredentials() + c.creds = func(cluster string) (credentials.TransportCredentials, error) { + return insecure.NewCredentials(), nil } } @@ -265,12 +273,18 @@ func newGRPCClient(ctx context.Context, cfg *ClientConfig) (_ *Client, err error defer cancel() c := &clusterName{} + + creds, err := cfg.creds("") + if err != nil { + return nil, trace.Wrap(err) + } + conn, err := grpc.DialContext( dialCtx, cfg.ProxyAddress, append([]grpc.DialOption{ grpc.WithContextDialer(newDialerForGRPCClient(ctx, cfg)), - grpc.WithTransportCredentials(&clusterCredentials{TransportCredentials: cfg.creds(), clusterName: c}), + grpc.WithTransportCredentials(&clusterCredentials{TransportCredentials: creds, clusterName: c}), grpc.WithChainUnaryInterceptor( append(cfg.UnaryInterceptors, //nolint:staticcheck // SA1019. There is a data race in the stats.Handler that is replacing @@ -361,22 +375,27 @@ type ClusterDetails struct { // Auth server in the provided cluster via [client.New] or similar. The [client.Config] // returned will have the correct credentials and dialer set based on the ClientConfig // that was provided to create this Client. -func (c *Client) ClientConfig(ctx context.Context, cluster string) client.Config { +func (c *Client) ClientConfig(ctx context.Context, cluster string) (client.Config, error) { + creds, err := c.cfg.clientCreds(cluster) + if err != nil { + return client.Config{}, trace.Wrap(err) + } + if c.cfg.TLSRoutingEnabled { return client.Config{ Context: ctx, Addrs: []string{c.cfg.ProxyAddress}, - Credentials: []client.Credentials{c.cfg.clientCreds()}, + Credentials: []client.Credentials{creds}, ALPNSNIAuthDialClusterName: cluster, CircuitBreakerConfig: breaker.NoopBreakerConfig(), ALPNConnUpgradeRequired: c.cfg.ALPNConnUpgradeRequired, DialOpts: c.cfg.DialOpts, - } + }, nil } return client.Config{ Context: ctx, - Credentials: []client.Credentials{c.cfg.clientCreds()}, + Credentials: []client.Credentials{creds}, CircuitBreakerConfig: breaker.NoopBreakerConfig(), DialInBackground: true, Dialer: client.ContextDialerFunc(func(dialCtx context.Context, _ string, _ string) (net.Conn, error) { @@ -395,8 +414,7 @@ func (c *Client) ClientConfig(ctx context.Context, cluster string) client.Config return conn, trace.Wrap(err) }), DialOpts: c.cfg.DialOpts, - } - + }, nil } // DialHost establishes a connection to the `target` in cluster named `cluster`. If a keyring diff --git a/api/client/proxy/client_test.go b/api/client/proxy/client_test.go index e79421ad8dea1..7cb788e5e76af 100644 --- a/api/client/proxy/client_test.go +++ b/api/client/proxy/client_test.go @@ -508,7 +508,9 @@ func TestClient_DialCluster(t *testing.T) { require.NoError(t, err) t.Cleanup(func() { require.NoError(t, clt.Close()) }) - authCfg := clt.ClientConfig(ctx, "cluster") + authCfg, err := clt.ClientConfig(ctx, "cluster") + require.NoError(t, err) + authCfg.DialOpts = []grpc.DialOption{ grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithReturnConnectionError(), diff --git a/integration/integration_test.go b/integration/integration_test.go index f37432abaaf3f..1d0b11f614cc2 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -1661,7 +1661,8 @@ func testIPPropagation(t *testing.T, suite *integrationTestSuite) { // The above dialer does not work clt.AuthClient as it requires a // custom transport from ProxyClient when TLS routing is disabled. // Recreating the authClient without the above dialer. - authClientCfg := clt.ProxyClient.ClientConfig(ctx, clusterName) + authClientCfg, err := clt.ProxyClient.ClientConfig(ctx, clusterName) + require.NoError(t, err) authClientCfg.DialOpts = nil authClient, err := auth.NewClient(authClientCfg) require.NoError(t, err) diff --git a/lib/client/api.go b/lib/client/api.go index 62731693de639..2de6156070af0 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -1573,7 +1573,12 @@ func (tc *TeleportClient) NewTracingClient(ctx context.Context) (*apitracing.Cli return nil, trace.Wrap(err) } - tracingClient, err := client.NewTracingClient(ctx, clusterClient.ProxyClient.ClientConfig(ctx, clusterClient.ClusterName())) + cfg, err := clusterClient.ProxyClient.ClientConfig(ctx, clusterClient.ClusterName()) + if err != nil { + return nil, trace.Wrap(err) + } + + tracingClient, err := client.NewTracingClient(ctx, cfg) return tracingClient, trace.Wrap(err) } @@ -3061,15 +3066,18 @@ func (tc *TeleportClient) ConnectToCluster(ctx context.Context) (_ *ClusterClien return nil, trace.Wrap(err) } - tlsConfig, err := tc.LoadTLSConfig() - if err != nil { - return nil, trace.Wrap(err) - } - pclt, err := proxyclient.NewClient(ctx, proxyclient.ClientConfig{ - ProxyAddress: cfg.proxyAddress, - TLSRoutingEnabled: tc.TLSRoutingEnabled, - TLSConfig: tlsConfig, + ProxyAddress: cfg.proxyAddress, + TLSRoutingEnabled: tc.TLSRoutingEnabled, + TLSConfigFunc: func(cluster string) (*tls.Config, error) { + if cluster == "" { + tlsCfg, err := tc.LoadTLSConfig() + return tlsCfg, trace.Wrap(err) + } + + tlsCfg, err := tc.LoadTLSConfigForClusters([]string{cluster}) + return tlsCfg, trace.Wrap(err) + }, DialOpts: tc.Config.DialOpts, UnaryInterceptors: []grpc.UnaryClientInterceptor{interceptors.GRPCClientUnaryErrorInterceptor}, StreamInterceptors: []grpc.StreamClientInterceptor{interceptors.GRPCClientStreamErrorInterceptor}, @@ -3092,7 +3100,10 @@ func (tc *TeleportClient) ConnectToCluster(ctx context.Context) (_ *ClusterClien cluster = connected } - authClientCfg := pclt.ClientConfig(ctx, cluster) + authClientCfg, err := pclt.ClientConfig(ctx, cluster) + if err != nil { + return nil, trace.NewAggregate(err, pclt.Close()) + } authClientCfg.MFAPromptConstructor = tc.NewMFAPrompt authClient, err := auth.NewClient(authClientCfg) if err != nil { diff --git a/lib/client/cluster_client.go b/lib/client/cluster_client.go index fbea09d994f88..e9d7714b1cd8d 100644 --- a/lib/client/cluster_client.go +++ b/lib/client/cluster_client.go @@ -86,7 +86,11 @@ func (c *ClusterClient) ConnectToCluster(ctx context.Context, clusterName string return c.CurrentCluster(), nil } - clientConfig := c.ProxyClient.ClientConfig(ctx, clusterName) + clientConfig, err := c.ProxyClient.ClientConfig(ctx, clusterName) + if err != nil { + return nil, trace.Wrap(err) + } + authClient, err := auth.NewClient(clientConfig) return authClient, trace.Wrap(err) } @@ -271,7 +275,12 @@ func (c *ClusterClient) SessionSSHConfig(ctx context.Context, user string, targe mfaClt := c if target.Cluster != rootClusterName { - authClient, err := auth.NewClient(c.ProxyClient.ClientConfig(ctx, rootClusterName)) + cfg, err := c.ProxyClient.ClientConfig(ctx, rootClusterName) + if err != nil { + return nil, trace.Wrap(err) + } + + authClient, err := auth.NewClient(cfg) if err != nil { return nil, trace.Wrap(MFARequiredUnknown(err)) }