diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index 81c1aa46c5f10..eb4cdbb9707d0 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -7790,6 +7790,10 @@ type authProviderMock struct { server types.ServerV2 } +func (mock authProviderMock) ListUnifiedResources(ctx context.Context, req *authproto.ListUnifiedResourcesRequest) (*authproto.ListUnifiedResourcesResponse, error) { + return nil, nil +} + func (mock authProviderMock) GetNode(ctx context.Context, namespace, name string) (types.Server, error) { return &mock.server, nil } diff --git a/lib/web/terminal.go b/lib/web/terminal.go index 20f7b16c61aa4..d2e6d6b7c71c5 100644 --- a/lib/web/terminal.go +++ b/lib/web/terminal.go @@ -105,6 +105,7 @@ type UserAuthClient interface { CreateAuthenticateChallenge(ctx context.Context, req *authproto.CreateAuthenticateChallengeRequest) (*authproto.MFAAuthenticateChallenge, error) GenerateUserCerts(ctx context.Context, req authproto.UserCertsRequest) (*authproto.Certs, error) MaintainSessionPresence(ctx context.Context) (authproto.AuthService_MaintainSessionPresenceClient, error) + ListUnifiedResources(ctx context.Context, req *authproto.ListUnifiedResourcesRequest) (*authproto.ListUnifiedResourcesResponse, error) } // NewTerminal creates a web-based terminal based on WebSockets and returns a @@ -912,6 +913,21 @@ func (t *sshBaseHandler) connectToNode(ctx context.Context, ws terminal.WSConn, // The close error is ignored instead of using [trace.NewAggregate] because // aggregate errors do not allow error inspection with things like [trace.IsAccessDenied]. _ = conn.Close() + + // Since connection attempts are made via UUID and not hostname, any access denied errors + // will not contain the resolved host address. To provide an easier troubleshooting experience + // for users, attempt to resolve the hostname of the server and augment the error message with it. + if trace.IsAccessDenied(err) { + if resp, err := t.userAuthClient.ListUnifiedResources(ctx, &authproto.ListUnifiedResourcesRequest{ + SortBy: types.SortBy{Field: types.ResourceKind}, + Kinds: []string{types.KindNode}, + Limit: 1, + PredicateExpression: fmt.Sprintf(`resource.metadata.name == "%s"`, t.sessionData.ServerID), + }); err == nil && len(resp.Resources) > 0 { + return nil, trace.AccessDenied("access denied to %q connecting to %v", sshConfig.User, resp.Resources[0].GetNode().GetHostname()) + } + } + return nil, trace.Wrap(err) }