diff --git a/lib/auth/accesspoint/accesspoint.go b/lib/auth/accesspoint/accesspoint.go
index d078d25d87b92..66bf51223990f 100644
--- a/lib/auth/accesspoint/accesspoint.go
+++ b/lib/auth/accesspoint/accesspoint.go
@@ -103,6 +103,7 @@ type Config struct {
Users services.UsersService
WebSession types.WebSessionInterface
WebToken types.WebTokenInterface
+ WorkloadIdentity cache.WorkloadIdentityReader
DynamicWindowsDesktops services.DynamicWindowsDesktops
WindowsDesktops services.WindowsDesktops
AutoUpdateService services.AutoUpdateServiceGetter
@@ -203,6 +204,7 @@ func NewCache(cfg Config) (*cache.Cache, error) {
Users: cfg.Users,
WebSession: cfg.WebSession,
WebToken: cfg.WebToken,
+ WorkloadIdentity: cfg.WorkloadIdentity,
WindowsDesktops: cfg.WindowsDesktops,
DynamicWindowsDesktops: cfg.DynamicWindowsDesktops,
ProvisioningStates: cfg.ProvisioningStates,
diff --git a/lib/auth/auth.go b/lib/auth/auth.go
index dedda21f57012..342e52e821194 100644
--- a/lib/auth/auth.go
+++ b/lib/auth/auth.go
@@ -401,6 +401,13 @@ func NewServer(cfg *InitConfig, opts ...ServerOption) (*Server, error) {
return nil, trace.Wrap(err, "creating GitServer service")
}
}
+ if cfg.WorkloadIdentity == nil {
+ workloadIdentity, err := local.NewWorkloadIdentityService(cfg.Backend)
+ if err != nil {
+ return nil, trace.Wrap(err, "creating WorkloadIdentity service")
+ }
+ cfg.WorkloadIdentity = workloadIdentity
+ }
if cfg.Logger == nil {
cfg.Logger = slog.With(teleport.ComponentKey, teleport.ComponentAuth)
}
@@ -499,6 +506,7 @@ func NewServer(cfg *InitConfig, opts ...ServerOption) (*Server, error) {
IdentityCenter: cfg.IdentityCenter,
PluginStaticCredentials: cfg.PluginStaticCredentials,
GitServers: cfg.GitServers,
+ WorkloadIdentities: cfg.WorkloadIdentity,
}
as := Server{
@@ -718,6 +726,7 @@ type Services struct {
services.IdentityCenter
services.PluginStaticCredentials
services.GitServers
+ services.WorkloadIdentities
}
// GetWebSession returns existing web session described by req.
diff --git a/lib/auth/authclient/api.go b/lib/auth/authclient/api.go
index 2a9d3095b4137..409e4850e8a97 100644
--- a/lib/auth/authclient/api.go
+++ b/lib/auth/authclient/api.go
@@ -38,6 +38,7 @@ import (
userprovisioningpb "github.com/gravitational/teleport/api/gen/proto/go/teleport/userprovisioning/v2"
userspb "github.com/gravitational/teleport/api/gen/proto/go/teleport/users/v1"
usertasksv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/usertasks/v1"
+ workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/types/accesslist"
"github.com/gravitational/teleport/api/types/discoveryconfig"
@@ -1229,6 +1230,12 @@ type Cache interface {
// pagination.
ListSPIFFEFederations(ctx context.Context, pageSize int, lastToken string) ([]*machineidv1.SPIFFEFederation, string, error)
+ // GetWorkloadIdentity gets a WorkloadIdentity by name.
+ GetWorkloadIdentity(ctx context.Context, name string) (*workloadidentityv1pb.WorkloadIdentity, error)
+ // ListWorkloadIdentities lists all SPIFFE Federations using Google style
+ // pagination.
+ ListWorkloadIdentities(ctx context.Context, pageSize int, lastToken string) ([]*workloadidentityv1pb.WorkloadIdentity, string, error)
+
// ListStaticHostUsers lists static host users.
ListStaticHostUsers(ctx context.Context, pageSize int, startKey string) ([]*userprovisioningpb.StaticHostUser, string, error)
// GetStaticHostUser returns a static host user by name.
diff --git a/lib/auth/helpers.go b/lib/auth/helpers.go
index 71804b4ca0049..58079d4745374 100644
--- a/lib/auth/helpers.go
+++ b/lib/auth/helpers.go
@@ -360,6 +360,7 @@ func NewTestAuthServer(cfg TestAuthServerConfig) (*TestAuthServer, error) {
SecReports: svces.SecReports,
SnowflakeSession: svces.Identity,
SPIFFEFederations: svces.SPIFFEFederations,
+ WorkloadIdentity: svces.WorkloadIdentities,
StaticHostUsers: svces.StaticHostUser,
Trust: svces.TrustInternal,
UserGroups: svces.UserGroups,
diff --git a/lib/auth/init.go b/lib/auth/init.go
index 61bb8cba0e447..9d86d21c1106f 100644
--- a/lib/auth/init.go
+++ b/lib/auth/init.go
@@ -322,6 +322,10 @@ type InitConfig struct {
// SPIFFEFederations is a service that manages storing SPIFFE federations.
SPIFFEFederations services.SPIFFEFederations
+ // WorkloadIdentity is the service for storing and retrieving
+ // WorkloadIdentity resources.
+ WorkloadIdentity services.WorkloadIdentities
+
// StaticHostUsers is a service that manages host users that should be
// created on SSH nodes.
StaticHostUsers services.StaticHostUser
diff --git a/lib/cache/cache.go b/lib/cache/cache.go
index fcb2a3bf7da5f..b29d9cbd07054 100644
--- a/lib/cache/cache.go
+++ b/lib/cache/cache.go
@@ -201,6 +201,7 @@ func ForAuth(cfg Config) Config {
{Kind: types.KindIdentityCenterAccountAssignment},
{Kind: types.KindPluginStaticCredentials},
{Kind: types.KindGitServer},
+ {Kind: types.KindWorkloadIdentity},
}
cfg.QueueSize = defaults.AuthQueueSize
// We don't want to enable partial health for auth cache because auth uses an event stream
@@ -556,6 +557,7 @@ type Cache struct {
identityCenterCache *local.IdentityCenterService
pluginStaticCredentialsCache *local.PluginStaticCredentialsService
gitServersCache *local.GitServerService
+ workloadIdentityCache workloadIdentityCacher
// closed indicates that the cache has been closed
closed atomic.Bool
@@ -738,6 +740,9 @@ type Config struct {
SPIFFEFederations SPIFFEFederationReader
// StaticHostUsers is the static host user service.
StaticHostUsers services.StaticHostUser
+ // WorkloadIdentity is the upstream Workload Identities service that we're
+ // caching
+ WorkloadIdentity WorkloadIdentityReader
// Backend is a backend for local cache
Backend backend.Backend
// MaxRetryPeriod is the maximum period between cache retries on failures
@@ -1008,6 +1013,12 @@ func New(config Config) (*Cache, error) {
return nil, trace.Wrap(err)
}
+ workloadIdentityCache, err := local.NewWorkloadIdentityService(config.Backend)
+ if err != nil {
+ cancel()
+ return nil, trace.Wrap(err)
+ }
+
staticHostUserCache, err := local.NewStaticHostUserService(config.Backend)
if err != nil {
cancel()
@@ -1094,6 +1105,7 @@ func New(config Config) (*Cache, error) {
identityCenterCache: identityCenterCache,
pluginStaticCredentialsCache: pluginStaticCredentialsCache,
gitServersCache: gitServersCache,
+ workloadIdentityCache: workloadIdentityCache,
Logger: log.WithFields(log.Fields{
teleport.ComponentKey: config.Component,
}),
diff --git a/lib/cache/cache_test.go b/lib/cache/cache_test.go
index af4b4d195bce4..e60f0acac0174 100644
--- a/lib/cache/cache_test.go
+++ b/lib/cache/cache_test.go
@@ -143,6 +143,7 @@ type testPack struct {
identityCenter services.IdentityCenter
pluginStaticCredentials *local.PluginStaticCredentialsService
gitServers services.GitServers
+ workloadIdentity *local.WorkloadIdentityService
}
// testFuncs are functions to support testing an object in a cache.
@@ -365,6 +366,12 @@ func newPackWithoutCache(dir string, opts ...packOption) (*testPack, error) {
}
p.spiffeFederations = spiffeFederationsSvc
+ workloadIdentitySvc, err := local.NewWorkloadIdentityService(p.backend)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ p.workloadIdentity = workloadIdentitySvc
+
databaseObjectsSvc, err := local.NewDatabaseObjectService(p.backend)
if err != nil {
return nil, trace.Wrap(err)
@@ -470,6 +477,7 @@ func newPack(dir string, setupConfig func(c Config) Config, opts ...packOption)
IdentityCenter: p.identityCenter,
PluginStaticCredentials: p.pluginStaticCredentials,
GitServers: p.gitServers,
+ WorkloadIdentity: p.workloadIdentity,
MaxRetryPeriod: 200 * time.Millisecond,
EventsC: p.eventsC,
}))
@@ -881,6 +889,7 @@ func TestCompletenessInit(t *testing.T) {
StaticHostUsers: p.staticHostUsers,
AutoUpdateService: p.autoUpdateService,
ProvisioningStates: p.provisioningStates,
+ WorkloadIdentity: p.workloadIdentity,
MaxRetryPeriod: 200 * time.Millisecond,
IdentityCenter: p.identityCenter,
PluginStaticCredentials: p.pluginStaticCredentials,
@@ -969,6 +978,7 @@ func TestCompletenessReset(t *testing.T) {
ProvisioningStates: p.provisioningStates,
IdentityCenter: p.identityCenter,
PluginStaticCredentials: p.pluginStaticCredentials,
+ WorkloadIdentity: p.workloadIdentity,
MaxRetryPeriod: 200 * time.Millisecond,
EventsC: p.eventsC,
GitServers: p.gitServers,
@@ -1181,6 +1191,7 @@ func TestListResources_NodesTTLVariant(t *testing.T) {
ProvisioningStates: p.provisioningStates,
IdentityCenter: p.identityCenter,
PluginStaticCredentials: p.pluginStaticCredentials,
+ WorkloadIdentity: p.workloadIdentity,
MaxRetryPeriod: 200 * time.Millisecond,
EventsC: p.eventsC,
neverOK: true, // ensure reads are never healthy
@@ -1278,6 +1289,7 @@ func initStrategy(t *testing.T) {
ProvisioningStates: p.provisioningStates,
IdentityCenter: p.identityCenter,
PluginStaticCredentials: p.pluginStaticCredentials,
+ WorkloadIdentity: p.workloadIdentity,
MaxRetryPeriod: 200 * time.Millisecond,
EventsC: p.eventsC,
GitServers: p.gitServers,
@@ -3556,6 +3568,7 @@ func TestCacheWatchKindExistsInEvents(t *testing.T) {
types.KindIdentityCenterPrincipalAssignment: types.Resource153ToLegacy(newIdentityCenterPrincipalAssignment("some_principal_assignment")),
types.KindPluginStaticCredentials: &types.PluginStaticCredentialsV1{},
types.KindGitServer: &types.ServerV2{},
+ types.KindWorkloadIdentity: types.Resource153ToLegacy(newWorkloadIdentity("some_identifier")),
}
for name, cfg := range cases {
diff --git a/lib/cache/collections.go b/lib/cache/collections.go
index 2635a0d71ea04..f73f83fcddb83 100644
--- a/lib/cache/collections.go
+++ b/lib/cache/collections.go
@@ -41,6 +41,7 @@ import (
userprovisioningpb "github.com/gravitational/teleport/api/gen/proto/go/teleport/userprovisioning/v2"
userspb "github.com/gravitational/teleport/api/gen/proto/go/teleport/users/v1"
usertasksv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/usertasks/v1"
+ workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/types/accesslist"
"github.com/gravitational/teleport/api/types/discoveryconfig"
@@ -178,6 +179,7 @@ type cacheCollections struct {
identityCenterAccountAssignments collectionReader[identityCenterAccountAssignmentGetter]
pluginStaticCredentials collectionReader[pluginStaticCredentialsGetter]
gitServers collectionReader[services.GitServerGetter]
+ workloadIdentity collectionReader[WorkloadIdentityReader]
}
// setupCollections returns a registry of collections.
@@ -706,6 +708,15 @@ func setupCollections(c *Cache, watches []types.WatchKind) (*cacheCollections, e
watch: watch,
}
collections.byKind[resourceKind] = collections.spiffeFederations
+ case types.KindWorkloadIdentity:
+ if c.Config.WorkloadIdentity == nil {
+ return nil, trace.BadParameter("missing parameter WorkloadIdentity")
+ }
+ collections.workloadIdentity = &genericCollection[*workloadidentityv1pb.WorkloadIdentity, WorkloadIdentityReader, workloadIdentityExecutor]{
+ cache: c,
+ watch: watch,
+ }
+ collections.byKind[resourceKind] = collections.workloadIdentity
case types.KindAutoUpdateConfig:
if c.AutoUpdateService == nil {
return nil, trace.BadParameter("missing parameter AutoUpdateService")
diff --git a/lib/cache/resource_workload_identity.go b/lib/cache/resource_workload_identity.go
new file mode 100644
index 0000000000000..75efb50fedbd5
--- /dev/null
+++ b/lib/cache/resource_workload_identity.go
@@ -0,0 +1,119 @@
+// Teleport
+// Copyright (C) 2024 Gravitational, Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+//nolint:unused // Because the executors generate a large amount of false positives.
+package cache
+
+import (
+ "context"
+
+ "github.com/gravitational/trace"
+
+ workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1"
+ "github.com/gravitational/teleport/api/types"
+)
+
+// WorkloadIdentityReader is an interface that defines the methods for getting
+// WorkloadIdentity. This is returned as the reader for the WorkloadIdentity
+// collection but is also used by the executor to read the full list of
+// WorkloadIdentity on initialization.
+type WorkloadIdentityReader interface {
+ ListWorkloadIdentities(ctx context.Context, pageSize int, nextToken string) ([]*workloadidentityv1pb.WorkloadIdentity, string, error)
+ GetWorkloadIdentity(ctx context.Context, name string) (*workloadidentityv1pb.WorkloadIdentity, error)
+}
+
+// workloadIdentityCacher is used for storing and retrieving WorkloadIdentity
+// from the cache's local backend.
+type workloadIdentityCacher interface {
+ WorkloadIdentityReader
+ UpsertWorkloadIdentity(ctx context.Context, resource *workloadidentityv1pb.WorkloadIdentity) (*workloadidentityv1pb.WorkloadIdentity, error)
+ DeleteWorkloadIdentity(ctx context.Context, name string) error
+ DeleteAllWorkloadIdentities(ctx context.Context) error
+}
+
+type workloadIdentityExecutor struct{}
+
+var _ executor[*workloadidentityv1pb.WorkloadIdentity, WorkloadIdentityReader] = workloadIdentityExecutor{}
+
+func (workloadIdentityExecutor) getAll(ctx context.Context, cache *Cache, loadSecrets bool) ([]*workloadidentityv1pb.WorkloadIdentity, error) {
+ var out []*workloadidentityv1pb.WorkloadIdentity
+ var nextToken string
+ for {
+ var page []*workloadidentityv1pb.WorkloadIdentity
+ var err error
+
+ const defaultPageSize = 0
+ page, nextToken, err = cache.Config.WorkloadIdentity.ListWorkloadIdentities(ctx, defaultPageSize, nextToken)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ out = append(out, page...)
+ if nextToken == "" {
+ break
+ }
+ }
+ return out, nil
+}
+
+func (workloadIdentityExecutor) upsert(ctx context.Context, cache *Cache, resource *workloadidentityv1pb.WorkloadIdentity) error {
+ _, err := cache.workloadIdentityCache.UpsertWorkloadIdentity(ctx, resource)
+ return trace.Wrap(err)
+}
+
+func (workloadIdentityExecutor) deleteAll(ctx context.Context, cache *Cache) error {
+ return trace.Wrap(cache.workloadIdentityCache.DeleteAllWorkloadIdentities(ctx))
+}
+
+func (workloadIdentityExecutor) delete(ctx context.Context, cache *Cache, resource types.Resource) error {
+ return trace.Wrap(cache.workloadIdentityCache.DeleteWorkloadIdentity(ctx, resource.GetName()))
+}
+
+func (workloadIdentityExecutor) isSingleton() bool { return false }
+
+func (workloadIdentityExecutor) getReader(cache *Cache, cacheOK bool) WorkloadIdentityReader {
+ if cacheOK {
+ return cache.workloadIdentityCache
+ }
+ return cache.Config.WorkloadIdentity
+}
+
+// ListWorkloadIdentities returns a paginated list of WorkloadIdentity resources.
+func (c *Cache) ListWorkloadIdentities(ctx context.Context, pageSize int, nextToken string) ([]*workloadidentityv1pb.WorkloadIdentity, string, error) {
+ ctx, span := c.Tracer.Start(ctx, "cache/ListWorkloadIdentities")
+ defer span.End()
+
+ rg, err := readCollectionCache(c, c.collections.workloadIdentity)
+ if err != nil {
+ return nil, "", trace.Wrap(err)
+ }
+ defer rg.Release()
+ out, nextKey, err := rg.reader.ListWorkloadIdentities(ctx, pageSize, nextToken)
+ return out, nextKey, trace.Wrap(err)
+}
+
+// GetWorkloadIdentity returns a single WorkloadIdentity by name
+func (c *Cache) GetWorkloadIdentity(ctx context.Context, name string) (*workloadidentityv1pb.WorkloadIdentity, error) {
+ ctx, span := c.Tracer.Start(ctx, "cache/GetWorkloadIdentity")
+ defer span.End()
+
+ rg, err := readCollectionCache(c, c.collections.workloadIdentity)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ defer rg.Release()
+ out, err := rg.reader.GetWorkloadIdentity(ctx, name)
+ return out, trace.Wrap(err)
+}
diff --git a/lib/cache/resource_workload_identity_test.go b/lib/cache/resource_workload_identity_test.go
new file mode 100644
index 0000000000000..da82d64fec27c
--- /dev/null
+++ b/lib/cache/resource_workload_identity_test.go
@@ -0,0 +1,74 @@
+// Teleport
+// Copyright (C) 2024 Gravitational, Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package cache
+
+import (
+ "context"
+ "testing"
+
+ "github.com/gravitational/trace"
+
+ headerv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/header/v1"
+ workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1"
+ "github.com/gravitational/teleport/api/types"
+)
+
+func newWorkloadIdentity(name string) *workloadidentityv1pb.WorkloadIdentity {
+ return &workloadidentityv1pb.WorkloadIdentity{
+ Kind: types.KindWorkloadIdentity,
+ Version: types.V1,
+ Metadata: &headerv1.Metadata{
+ Name: name,
+ },
+ Spec: &workloadidentityv1pb.WorkloadIdentitySpec{
+ Spiffe: &workloadidentityv1pb.WorkloadIdentitySPIFFE{
+ Id: "/example",
+ },
+ },
+ }
+}
+
+func TestWorkloadIdentity(t *testing.T) {
+ t.Parallel()
+
+ p := newTestPack(t, ForAuth)
+ t.Cleanup(p.Close)
+
+ testResources153(t, p, testFuncs153[*workloadidentityv1pb.WorkloadIdentity]{
+ newResource: func(s string) (*workloadidentityv1pb.WorkloadIdentity, error) {
+ return newWorkloadIdentity(s), nil
+ },
+
+ create: func(ctx context.Context, item *workloadidentityv1pb.WorkloadIdentity) error {
+ _, err := p.workloadIdentity.CreateWorkloadIdentity(ctx, item)
+ return trace.Wrap(err)
+ },
+ list: func(ctx context.Context) ([]*workloadidentityv1pb.WorkloadIdentity, error) {
+ items, _, err := p.workloadIdentity.ListWorkloadIdentities(ctx, 0, "")
+ return items, trace.Wrap(err)
+ },
+ deleteAll: func(ctx context.Context) error {
+ return p.workloadIdentity.DeleteAllWorkloadIdentities(ctx)
+ },
+
+ cacheList: func(ctx context.Context) ([]*workloadidentityv1pb.WorkloadIdentity, error) {
+ items, _, err := p.cache.ListWorkloadIdentities(ctx, 0, "")
+ return items, trace.Wrap(err)
+ },
+ cacheGet: p.cache.GetWorkloadIdentity,
+ })
+}
diff --git a/lib/service/service.go b/lib/service/service.go
index a60a926a7c486..16276a69827fa 100644
--- a/lib/service/service.go
+++ b/lib/service/service.go
@@ -2548,6 +2548,7 @@ func (process *TeleportProcess) newAccessCacheForServices(cfg accesspoint.Config
cfg.WebSession = services.Identity.WebSessions()
cfg.WebToken = services.Identity.WebTokens()
cfg.WindowsDesktops = services.WindowsDesktops
+ cfg.WorkloadIdentity = services.WorkloadIdentities
cfg.DynamicWindowsDesktops = services.DynamicWindowsDesktops
cfg.AutoUpdateService = services.AutoUpdateService
cfg.ProvisioningStates = services.ProvisioningStates
diff --git a/lib/services/local/events.go b/lib/services/local/events.go
index d0522bf7bd5f2..9931b80857500 100644
--- a/lib/services/local/events.go
+++ b/lib/services/local/events.go
@@ -252,6 +252,8 @@ func (e *EventsService) NewWatcher(ctx context.Context, watch types.Watch) (type
parser = newPluginStaticCredentialsParser()
case types.KindGitServer:
parser = newGitServerParser()
+ case types.KindWorkloadIdentity:
+ parser = newWorkloadIdentityParser()
default:
if watch.AllowPartialSuccess {
continue
@@ -3179,6 +3181,46 @@ func (p *spiffeFederationParser) parse(event backend.Event) (types.Resource, err
}
}
+func newWorkloadIdentityParser() *workloadIdentityParser {
+ return &workloadIdentityParser{
+ baseParser: newBaseParser(backend.NewKey(workloadIdentityPrefix)),
+ }
+}
+
+type workloadIdentityParser struct {
+ baseParser
+}
+
+func (p *workloadIdentityParser) parse(event backend.Event) (types.Resource, error) {
+ switch event.Type {
+ case types.OpDelete:
+ name := event.Item.Key.TrimPrefix(backend.NewKey(workloadIdentityPrefix)).String()
+ if name == "" {
+ return nil, trace.NotFound("failed parsing %v", event.Item.Key.String())
+ }
+
+ return &types.ResourceHeader{
+ Kind: types.KindWorkloadIdentity,
+ Version: types.V1,
+ Metadata: types.Metadata{
+ Name: strings.TrimPrefix(name, backend.SeparatorString),
+ Namespace: apidefaults.Namespace,
+ },
+ }, nil
+ case types.OpPut:
+ resource, err := services.UnmarshalWorkloadIdentity(
+ event.Item.Value,
+ services.WithExpires(event.Item.Expires),
+ services.WithRevision(event.Item.Revision))
+ if err != nil {
+ return nil, trace.Wrap(err, "unmarshalling resource from event")
+ }
+ return types.Resource153ToLegacy(resource), nil
+ default:
+ return nil, trace.BadParameter("event %v is not supported", event.Type)
+ }
+}
+
func newProvisioningStateParser() *provisioningStateParser {
return &provisioningStateParser{
baseParser: newBaseParser(backend.NewKey(provisioningStatePrefix)),
diff --git a/lib/services/local/workload_identity.go b/lib/services/local/workload_identity.go
new file mode 100644
index 0000000000000..e0504e989cbe8
--- /dev/null
+++ b/lib/services/local/workload_identity.go
@@ -0,0 +1,118 @@
+// Teleport
+// Copyright (C) 2024 Gravitational, Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package local
+
+import (
+ "context"
+
+ "github.com/gravitational/trace"
+
+ workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1"
+ "github.com/gravitational/teleport/api/types"
+ "github.com/gravitational/teleport/lib/backend"
+ "github.com/gravitational/teleport/lib/services"
+ "github.com/gravitational/teleport/lib/services/local/generic"
+)
+
+const (
+ workloadIdentityPrefix = "workload_identity"
+)
+
+// WorkloadIdentityService exposes backend functionality for storing
+// WorkloadIdentity resources
+type WorkloadIdentityService struct {
+ service *generic.ServiceWrapper[*workloadidentityv1pb.WorkloadIdentity]
+}
+
+// NewWorkloadIdentityService creates a new WorkloadIdentityService
+func NewWorkloadIdentityService(b backend.Backend) (*WorkloadIdentityService, error) {
+ service, err := generic.NewServiceWrapper(
+ generic.ServiceWrapperConfig[*workloadidentityv1pb.WorkloadIdentity]{
+ Backend: b,
+ ResourceKind: types.KindWorkloadIdentity,
+ BackendPrefix: backend.NewKey(workloadIdentityPrefix),
+ MarshalFunc: services.MarshalWorkloadIdentity,
+ UnmarshalFunc: services.UnmarshalWorkloadIdentity,
+ ValidateFunc: services.ValidateWorkloadIdentity,
+ })
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ return &WorkloadIdentityService{
+ service: service,
+ }, nil
+}
+
+// CreateWorkloadIdentity inserts a new WorkloadIdentity into the backend.
+func (b *WorkloadIdentityService) CreateWorkloadIdentity(
+ ctx context.Context, resource *workloadidentityv1pb.WorkloadIdentity,
+) (*workloadidentityv1pb.WorkloadIdentity, error) {
+ created, err := b.service.CreateResource(ctx, resource)
+ return created, trace.Wrap(err)
+}
+
+// GetWorkloadIdentity retrieves a specific WorkloadIdentity given a name
+func (b *WorkloadIdentityService) GetWorkloadIdentity(
+ ctx context.Context, name string,
+) (*workloadidentityv1pb.WorkloadIdentity, error) {
+ resource, err := b.service.GetResource(ctx, name)
+ return resource, trace.Wrap(err)
+}
+
+// ListWorkloadIdentities lists all WorkloadIdentities using a given page size
+// and last key.
+func (b *WorkloadIdentityService) ListWorkloadIdentities(
+ ctx context.Context, pageSize int, currentToken string,
+) ([]*workloadidentityv1pb.WorkloadIdentity, string, error) {
+ r, nextToken, err := b.service.ListResources(ctx, pageSize, currentToken)
+ return r, nextToken, trace.Wrap(err)
+}
+
+// DeleteWorkloadIdentity deletes a specific WorkloadIdentity.
+func (b *WorkloadIdentityService) DeleteWorkloadIdentity(
+ ctx context.Context, name string,
+) error {
+ return trace.Wrap(b.service.DeleteResource(ctx, name))
+}
+
+// DeleteAllWorkloadIdentities deletes all SPIFFE resources, this is typically
+// only meant to be used by the cache.
+func (b *WorkloadIdentityService) DeleteAllWorkloadIdentities(
+ ctx context.Context,
+) error {
+ return trace.Wrap(b.service.DeleteAllResources(ctx))
+}
+
+// UpsertWorkloadIdentity upserts a WorkloadIdentitys. Prefer using
+// CreateWorkloadIdentity. This is only designed for usage by the cache.
+func (b *WorkloadIdentityService) UpsertWorkloadIdentity(
+ ctx context.Context, resource *workloadidentityv1pb.WorkloadIdentity,
+) (*workloadidentityv1pb.WorkloadIdentity, error) {
+ upserted, err := b.service.UpsertResource(ctx, resource)
+ return upserted, trace.Wrap(err)
+}
+
+// UpdateWorkloadIdentity updates a specific WorkloadIdentity. The resource must
+// already exist, and, condition update semantics are used - e.g the submitted
+// resource must have a revision matching the revision of the resource in the
+// backend.
+func (b *WorkloadIdentityService) UpdateWorkloadIdentity(
+ ctx context.Context, resource *workloadidentityv1pb.WorkloadIdentity,
+) (*workloadidentityv1pb.WorkloadIdentity, error) {
+ updated, err := b.service.ConditionalUpdateResource(ctx, resource)
+ return updated, trace.Wrap(err)
+}
diff --git a/lib/services/local/workload_identity_test.go b/lib/services/local/workload_identity_test.go
new file mode 100644
index 0000000000000..acba05d9c8e4a
--- /dev/null
+++ b/lib/services/local/workload_identity_test.go
@@ -0,0 +1,355 @@
+// Teleport
+// Copyright (C) 2024 Gravitational, Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package local
+
+import (
+ "context"
+ "fmt"
+ "slices"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "github.com/gravitational/trace"
+ "github.com/jonboulle/clockwork"
+ "github.com/stretchr/testify/require"
+ "google.golang.org/protobuf/proto"
+ "google.golang.org/protobuf/testing/protocmp"
+
+ headerv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/header/v1"
+ workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1"
+ "github.com/gravitational/teleport/api/types"
+ "github.com/gravitational/teleport/lib/backend"
+ "github.com/gravitational/teleport/lib/backend/memory"
+)
+
+func setupWorkloadIdentityServiceTest(
+ t *testing.T,
+) (context.Context, *WorkloadIdentityService) {
+ t.Parallel()
+ ctx := context.Background()
+ clock := clockwork.NewFakeClock()
+ mem, err := memory.New(memory.Config{
+ Context: ctx,
+ Clock: clock,
+ })
+ require.NoError(t, err)
+ service, err := NewWorkloadIdentityService(backend.NewSanitizer(mem))
+ require.NoError(t, err)
+ return ctx, service
+}
+
+func newValidWorkloadIdentity(name string) *workloadidentityv1pb.WorkloadIdentity {
+ return &workloadidentityv1pb.WorkloadIdentity{
+ Kind: types.KindWorkloadIdentity,
+ Version: types.V1,
+ Metadata: &headerv1.Metadata{
+ Name: name,
+ },
+ Spec: &workloadidentityv1pb.WorkloadIdentitySpec{
+ Spiffe: &workloadidentityv1pb.WorkloadIdentitySPIFFE{
+ Id: "/test",
+ },
+ },
+ }
+}
+
+func TestWorkloadIdentityService_CreateWorkloadIdentity(t *testing.T) {
+ ctx, service := setupWorkloadIdentityServiceTest(t)
+
+ t.Run("ok", func(t *testing.T) {
+ want := newValidWorkloadIdentity("example")
+ got, err := service.CreateWorkloadIdentity(
+ ctx,
+ // Clone to avoid Marshaling modifying want
+ proto.Clone(want).(*workloadidentityv1pb.WorkloadIdentity),
+ )
+ require.NoError(t, err)
+ require.NotEmpty(t, got.Metadata.Revision)
+ require.Empty(t, cmp.Diff(
+ want,
+ got,
+ protocmp.Transform(),
+ protocmp.IgnoreFields(&headerv1.Metadata{}, "revision"),
+ ))
+ })
+ t.Run("validation occurs", func(t *testing.T) {
+ out, err := service.CreateWorkloadIdentity(ctx, newValidWorkloadIdentity(""))
+ require.ErrorContains(t, err, "metadata.name: is required")
+ require.Nil(t, out)
+ })
+ t.Run("no upsert", func(t *testing.T) {
+ res := newValidWorkloadIdentity("duplicate")
+ _, err := service.CreateWorkloadIdentity(
+ ctx,
+ // Clone to avoid Marshaling modifying want
+ proto.Clone(res).(*workloadidentityv1pb.WorkloadIdentity),
+ )
+ require.NoError(t, err)
+ _, err = service.CreateWorkloadIdentity(
+ ctx,
+ // Clone to avoid Marshaling modifying want
+ proto.Clone(res).(*workloadidentityv1pb.WorkloadIdentity),
+ )
+ require.Error(t, err)
+ require.True(t, trace.IsAlreadyExists(err))
+ })
+}
+
+func TestWorkloadIdentityService_UpsertWorkloadIdentity(t *testing.T) {
+ ctx, service := setupWorkloadIdentityServiceTest(t)
+
+ t.Run("ok", func(t *testing.T) {
+ want := newValidWorkloadIdentity("example")
+ got, err := service.UpsertWorkloadIdentity(
+ ctx,
+ // Clone to avoid Marshaling modifying want
+ proto.Clone(want).(*workloadidentityv1pb.WorkloadIdentity),
+ )
+ require.NoError(t, err)
+ require.NotEmpty(t, got.Metadata.Revision)
+ require.Empty(t, cmp.Diff(
+ want,
+ got,
+ protocmp.Transform(),
+ protocmp.IgnoreFields(&headerv1.Metadata{}, "revision"),
+ ))
+
+ // Ensure we can upsert over an existing resource
+ _, err = service.UpsertWorkloadIdentity(
+ ctx,
+ // Clone to avoid Marshaling modifying want
+ proto.Clone(want).(*workloadidentityv1pb.WorkloadIdentity),
+ )
+ require.NoError(t, err)
+ })
+ t.Run("validation occurs", func(t *testing.T) {
+ out, err := service.UpdateWorkloadIdentity(ctx, newValidWorkloadIdentity(""))
+ require.ErrorContains(t, err, "metadata.name: is required")
+ require.Nil(t, out)
+ })
+}
+
+func TestWorkloadIdentityService_ListWorkloadIdentities(t *testing.T) {
+ ctx, service := setupWorkloadIdentityServiceTest(t)
+ // Create entities to list
+ createdObjects := []*workloadidentityv1pb.WorkloadIdentity{}
+ // Create 49 entities to test an incomplete page at the end.
+ for i := 0; i < 49; i++ {
+ created, err := service.CreateWorkloadIdentity(
+ ctx,
+ newValidWorkloadIdentity(fmt.Sprintf("%d", i)),
+ )
+ require.NoError(t, err)
+ createdObjects = append(createdObjects, created)
+ }
+ t.Run("default page size", func(t *testing.T) {
+ page, nextToken, err := service.ListWorkloadIdentities(ctx, 0, "")
+ require.NoError(t, err)
+ require.Len(t, page, 49)
+ require.Empty(t, nextToken)
+
+ // Expect that we get all the things we have created
+ for _, created := range createdObjects {
+ slices.ContainsFunc(page, func(resource *workloadidentityv1pb.WorkloadIdentity) bool {
+ return proto.Equal(created, resource)
+ })
+ }
+ })
+ t.Run("pagination", func(t *testing.T) {
+ fetched := []*workloadidentityv1pb.WorkloadIdentity{}
+ token := ""
+ iterations := 0
+ for {
+ iterations++
+ page, nextToken, err := service.ListWorkloadIdentities(ctx, 10, token)
+ require.NoError(t, err)
+ fetched = append(fetched, page...)
+ if nextToken == "" {
+ break
+ }
+ token = nextToken
+ }
+ require.Equal(t, 5, iterations)
+
+ require.Len(t, fetched, 49)
+ // Expect that we get all the things we have created
+ for _, created := range createdObjects {
+ slices.ContainsFunc(fetched, func(resource *workloadidentityv1pb.WorkloadIdentity) bool {
+ return proto.Equal(created, resource)
+ })
+ }
+ })
+}
+
+func TestWorkloadIdentityService_GetWorkloadIdentity(t *testing.T) {
+ ctx, service := setupWorkloadIdentityServiceTest(t)
+
+ t.Run("ok", func(t *testing.T) {
+ want := newValidWorkloadIdentity("example")
+ _, err := service.CreateWorkloadIdentity(
+ ctx,
+ // Clone to avoid Marshaling modifying want
+ proto.Clone(want).(*workloadidentityv1pb.WorkloadIdentity),
+ )
+ require.NoError(t, err)
+ got, err := service.GetWorkloadIdentity(ctx, "example")
+ require.NoError(t, err)
+ require.NotEmpty(t, got.Metadata.Revision)
+ require.Empty(t, cmp.Diff(
+ want,
+ got,
+ protocmp.Transform(),
+ protocmp.IgnoreFields(&headerv1.Metadata{}, "revision"),
+ ))
+ })
+ t.Run("not found", func(t *testing.T) {
+ _, err := service.GetWorkloadIdentity(ctx, "not-found")
+ require.Error(t, err)
+ require.True(t, trace.IsNotFound(err))
+ })
+}
+
+func TestWorkloadIdentityService_DeleteWorkloadIdentity(t *testing.T) {
+ ctx, service := setupWorkloadIdentityServiceTest(t)
+
+ t.Run("ok", func(t *testing.T) {
+ _, err := service.CreateWorkloadIdentity(
+ ctx,
+ newValidWorkloadIdentity("example"),
+ )
+ require.NoError(t, err)
+
+ _, err = service.GetWorkloadIdentity(ctx, "example")
+ require.NoError(t, err)
+
+ err = service.DeleteWorkloadIdentity(ctx, "example")
+ require.NoError(t, err)
+
+ _, err = service.GetWorkloadIdentity(ctx, "example")
+ require.Error(t, err)
+ require.True(t, trace.IsNotFound(err))
+ })
+ t.Run("not found", func(t *testing.T) {
+ err := service.DeleteWorkloadIdentity(ctx, "foo.example.com")
+ require.Error(t, err)
+ require.True(t, trace.IsNotFound(err))
+ })
+}
+
+func TestWorkloadIdentityService_DeleteAllWorkloadIdentities(t *testing.T) {
+ ctx, service := setupWorkloadIdentityServiceTest(t)
+ _, err := service.CreateWorkloadIdentity(
+ ctx,
+ newValidWorkloadIdentity("1"),
+ )
+ require.NoError(t, err)
+ _, err = service.CreateWorkloadIdentity(
+ ctx,
+ newValidWorkloadIdentity("2"),
+ )
+ require.NoError(t, err)
+
+ page, _, err := service.ListWorkloadIdentities(ctx, 0, "")
+ require.NoError(t, err)
+ require.Len(t, page, 2)
+
+ err = service.DeleteAllWorkloadIdentities(ctx)
+ require.NoError(t, err)
+
+ page, _, err = service.ListWorkloadIdentities(ctx, 0, "")
+ require.NoError(t, err)
+ require.Empty(t, page)
+}
+
+func TestWorkloadIdentityService_UpdateWorkloadIdentity(t *testing.T) {
+ ctx, service := setupWorkloadIdentityServiceTest(t)
+
+ t.Run("ok", func(t *testing.T) {
+ // Create first to support updating
+ toCreate := newValidWorkloadIdentity("example")
+ got, err := service.CreateWorkloadIdentity(
+ ctx,
+ // Clone to avoid Marshaling modifying want
+ proto.Clone(toCreate).(*workloadidentityv1pb.WorkloadIdentity),
+ )
+ require.NoError(t, err)
+ require.NotEmpty(t, got.Metadata.Revision)
+ got.Spec.Spiffe.Id = "/changed"
+ got2, err := service.UpdateWorkloadIdentity(
+ ctx,
+ // Clone to avoid Marshaling modifying want
+ proto.Clone(got).(*workloadidentityv1pb.WorkloadIdentity),
+ )
+ require.NoError(t, err)
+ require.NotEmpty(t, got2.Metadata.Revision)
+ require.Empty(t, cmp.Diff(
+ got,
+ got2,
+ protocmp.Transform(),
+ protocmp.IgnoreFields(&headerv1.Metadata{}, "revision"),
+ ))
+ })
+ t.Run("validation occurs", func(t *testing.T) {
+ // Create first to support updating
+ toCreate := newValidWorkloadIdentity("example2")
+ got, err := service.CreateWorkloadIdentity(
+ ctx,
+ // Clone to avoid Marshaling modifying want
+ proto.Clone(toCreate).(*workloadidentityv1pb.WorkloadIdentity),
+ )
+ require.NoError(t, err)
+ require.NotEmpty(t, got.Metadata.Revision)
+ got.Spec.Spiffe.Id = ""
+ got2, err := service.UpdateWorkloadIdentity(
+ ctx,
+ // Clone to avoid Marshaling modifying want
+ proto.Clone(got).(*workloadidentityv1pb.WorkloadIdentity),
+ )
+ require.ErrorContains(t, err, "spec.spiffe.id: is required")
+ require.Nil(t, got2)
+ })
+ t.Run("cond update blocks", func(t *testing.T) {
+ toCreate := newValidWorkloadIdentity("example4")
+ got, err := service.CreateWorkloadIdentity(
+ ctx,
+ // Clone to avoid Marshaling modifying want
+ proto.Clone(toCreate).(*workloadidentityv1pb.WorkloadIdentity),
+ )
+ require.NoError(t, err)
+ // We'll now update it twice, but on the second update, we will use the
+ // revision from the creation not the second update.
+ _, err = service.UpdateWorkloadIdentity(
+ ctx,
+ proto.Clone(got).(*workloadidentityv1pb.WorkloadIdentity),
+ )
+ require.NoError(t, err)
+ _, err = service.UpdateWorkloadIdentity(
+ ctx,
+ proto.Clone(got).(*workloadidentityv1pb.WorkloadIdentity),
+ )
+ require.ErrorIs(t, err, backend.ErrIncorrectRevision)
+ })
+ t.Run("no upsert", func(t *testing.T) {
+ toUpdate := newValidWorkloadIdentity("example3")
+ _, err := service.UpdateWorkloadIdentity(
+ ctx,
+ // Clone to avoid Marshaling modifying want
+ proto.Clone(toUpdate).(*workloadidentityv1pb.WorkloadIdentity),
+ )
+ require.Error(t, err)
+ })
+}
diff --git a/lib/services/workload_identity.go b/lib/services/workload_identity.go
new file mode 100644
index 0000000000000..89b87ba0d2473
--- /dev/null
+++ b/lib/services/workload_identity.go
@@ -0,0 +1,122 @@
+// Teleport
+// Copyright (C) 2024 Gravitational, Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package services
+
+import (
+ "context"
+ "strings"
+
+ "github.com/gravitational/trace"
+
+ workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1"
+ "github.com/gravitational/teleport/api/types"
+)
+
+// WorkloadIdentities is an interface over the WorkloadIdentities service. This
+// interface may also be implemented by a client to allow remote and local
+// consumers to access the resource in a similar way.
+type WorkloadIdentities interface {
+ // GetWorkloadIdentity gets a SPIFFE Federation by name.
+ GetWorkloadIdentity(
+ ctx context.Context, name string,
+ ) (*workloadidentityv1pb.WorkloadIdentity, error)
+ // ListWorkloadIdentities lists all WorkloadIdentities using Google style
+ // pagination.
+ ListWorkloadIdentities(
+ ctx context.Context, pageSize int, lastToken string,
+ ) ([]*workloadidentityv1pb.WorkloadIdentity, string, error)
+ // CreateWorkloadIdentity creates a new WorkloadIdentity.
+ CreateWorkloadIdentity(
+ ctx context.Context, workloadIdentity *workloadidentityv1pb.WorkloadIdentity,
+ ) (*workloadidentityv1pb.WorkloadIdentity, error)
+ // DeleteWorkloadIdentity deletes a SPIFFE Federation by name.
+ DeleteWorkloadIdentity(ctx context.Context, name string) error
+ // UpdateWorkloadIdentity updates a specific WorkloadIdentity. The resource must
+ // already exist, and, condition update semantics are used - e.g the submitted
+ // resource must have a revision matching the revision of the resource in the
+ // backend.
+ UpdateWorkloadIdentity(
+ ctx context.Context, workloadIdentity *workloadidentityv1pb.WorkloadIdentity,
+ ) (*workloadidentityv1pb.WorkloadIdentity, error)
+ // UpsertWorkloadIdentity creates or updates a WorkloadIdentity.
+ UpsertWorkloadIdentity(
+ ctx context.Context, workloadIdentity *workloadidentityv1pb.WorkloadIdentity,
+ ) (*workloadidentityv1pb.WorkloadIdentity, error)
+}
+
+// MarshalWorkloadIdentity marshals the WorkloadIdentity object into a JSON byte
+// array.
+func MarshalWorkloadIdentity(
+ object *workloadidentityv1pb.WorkloadIdentity, opts ...MarshalOption,
+) ([]byte, error) {
+ return MarshalProtoResource(object, opts...)
+}
+
+// UnmarshalWorkloadIdentity unmarshals the WorkloadIdentity object from a
+// JSON byte array.
+func UnmarshalWorkloadIdentity(
+ data []byte, opts ...MarshalOption,
+) (*workloadidentityv1pb.WorkloadIdentity, error) {
+ return UnmarshalProtoResource[*workloadidentityv1pb.WorkloadIdentity](data, opts...)
+}
+
+// ValidateWorkloadIdentity validates the WorkloadIdentity object. This is
+// performed prior to writing to the backend.
+func ValidateWorkloadIdentity(s *workloadidentityv1pb.WorkloadIdentity) error {
+ switch {
+ case s == nil:
+ return trace.BadParameter("object cannot be nil")
+ case s.Version != types.V1:
+ return trace.BadParameter("version: only %q is supported", types.V1)
+ case s.Kind != types.KindWorkloadIdentity:
+ return trace.BadParameter("kind: must be %q", types.KindWorkloadIdentity)
+ case s.Metadata == nil:
+ return trace.BadParameter("metadata: is required")
+ case s.Metadata.Name == "":
+ return trace.BadParameter("metadata.name: is required")
+ case s.Spec == nil:
+ return trace.BadParameter("spec: is required")
+ case s.Spec.Spiffe.Id == "":
+ return trace.BadParameter("spec.spiffe.id: is required")
+ case !strings.HasPrefix(s.Spec.Spiffe.Id, "/"):
+ return trace.BadParameter("spec.spiffe.id: must start with a /")
+ }
+
+ for i, rule := range s.GetSpec().GetRules().GetAllow() {
+ if len(rule.Conditions) == 0 {
+ return trace.BadParameter("spec.rules.allow[%d].conditions: must be non-empty", i)
+ }
+ for j, condition := range rule.Conditions {
+ if condition.Attribute == "" {
+ return trace.BadParameter("spec.rules.allow[%d].conditions[%d].attribute: must be non-empty", i, j)
+ }
+ // Ensure exactly one operator is set.
+ operatorsSet := 0
+ if condition.Equals != "" {
+ operatorsSet++
+ }
+ if operatorsSet == 0 || operatorsSet > 1 {
+ return trace.BadParameter(
+ "spec.rules.allow[%d].conditions[%d]: exactly one operator must be specified, found %d",
+ i, j, operatorsSet,
+ )
+ }
+ }
+ }
+
+ return nil
+}
diff --git a/lib/services/workload_identity_test.go b/lib/services/workload_identity_test.go
new file mode 100644
index 0000000000000..429612ed48555
--- /dev/null
+++ b/lib/services/workload_identity_test.go
@@ -0,0 +1,231 @@
+// Teleport
+// Copyright (C) 2024 Gravitational, Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package services
+
+import (
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "github.com/stretchr/testify/require"
+ "google.golang.org/protobuf/testing/protocmp"
+
+ headerv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/header/v1"
+ workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1"
+ "github.com/gravitational/teleport/api/types"
+)
+
+func TestWorkloadIdentityMarshaling(t *testing.T) {
+ t.Parallel()
+
+ testCases := []struct {
+ name string
+ in *workloadidentityv1pb.WorkloadIdentity
+ }{
+ {
+ name: "normal",
+ in: &workloadidentityv1pb.WorkloadIdentity{
+ Kind: types.KindWorkloadIdentity,
+ Version: types.V1,
+ Metadata: &headerv1.Metadata{
+ Name: "example",
+ },
+ Spec: &workloadidentityv1pb.WorkloadIdentitySpec{
+ Spiffe: &workloadidentityv1pb.WorkloadIdentitySPIFFE{
+ Id: "/example",
+ },
+ },
+ },
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ gotBytes, err := MarshalWorkloadIdentity(tc.in)
+ require.NoError(t, err)
+ // Test that unmarshaling gives us the same object
+ got, err := UnmarshalWorkloadIdentity(gotBytes)
+ require.NoError(t, err)
+ require.Empty(t, cmp.Diff(tc.in, got, protocmp.Transform()))
+ })
+ }
+}
+
+func TestValidateWorkloadIdentity(t *testing.T) {
+ t.Parallel()
+
+ var errContains = func(contains string) require.ErrorAssertionFunc {
+ return func(t require.TestingT, err error, msgAndArgs ...interface{}) {
+ require.ErrorContains(t, err, contains, msgAndArgs...)
+ }
+ }
+
+ testCases := []struct {
+ name string
+ in *workloadidentityv1pb.WorkloadIdentity
+ requireErr require.ErrorAssertionFunc
+ }{
+ {
+ name: "success - full",
+ in: &workloadidentityv1pb.WorkloadIdentity{
+ Kind: types.KindWorkloadIdentity,
+ Version: types.V1,
+ Metadata: &headerv1.Metadata{
+ Name: "example",
+ },
+ Spec: &workloadidentityv1pb.WorkloadIdentitySpec{
+ Rules: &workloadidentityv1pb.WorkloadIdentityRules{
+ Allow: []*workloadidentityv1pb.WorkloadIdentityRule{
+ {
+ Conditions: []*workloadidentityv1pb.WorkloadIdentityCondition{
+ {
+ Attribute: "example",
+ Equals: "foo",
+ },
+ },
+ },
+ },
+ },
+ Spiffe: &workloadidentityv1pb.WorkloadIdentitySPIFFE{
+ Id: "/example",
+ },
+ },
+ },
+ requireErr: require.NoError,
+ },
+ {
+ name: "success - minimal",
+ in: &workloadidentityv1pb.WorkloadIdentity{
+ Kind: types.KindWorkloadIdentity,
+ Version: types.V1,
+ Metadata: &headerv1.Metadata{
+ Name: "example",
+ },
+ Spec: &workloadidentityv1pb.WorkloadIdentitySpec{
+ Spiffe: &workloadidentityv1pb.WorkloadIdentitySPIFFE{
+ Id: "/example",
+ },
+ },
+ },
+ requireErr: require.NoError,
+ },
+ {
+ name: "missing name",
+ in: &workloadidentityv1pb.WorkloadIdentity{
+ Kind: types.KindWorkloadIdentity,
+ Version: types.V1,
+ Metadata: &headerv1.Metadata{},
+ Spec: &workloadidentityv1pb.WorkloadIdentitySpec{
+ Spiffe: &workloadidentityv1pb.WorkloadIdentitySPIFFE{
+ Id: "/example",
+ },
+ },
+ },
+ requireErr: errContains("metadata.name: is required"),
+ },
+ {
+ name: "missing spiffe id",
+ in: &workloadidentityv1pb.WorkloadIdentity{
+ Kind: types.KindWorkloadIdentity,
+ Version: types.V1,
+ Metadata: &headerv1.Metadata{
+ Name: "example",
+ },
+ Spec: &workloadidentityv1pb.WorkloadIdentitySpec{
+ Spiffe: &workloadidentityv1pb.WorkloadIdentitySPIFFE{},
+ },
+ },
+ requireErr: errContains("spec.spiffe.id: is required"),
+ },
+ {
+ name: "spiffe id must have leading /",
+ in: &workloadidentityv1pb.WorkloadIdentity{
+ Kind: types.KindWorkloadIdentity,
+ Version: types.V1,
+ Metadata: &headerv1.Metadata{
+ Name: "example",
+ },
+ Spec: &workloadidentityv1pb.WorkloadIdentitySpec{
+ Spiffe: &workloadidentityv1pb.WorkloadIdentitySPIFFE{
+ Id: "example",
+ },
+ },
+ },
+ requireErr: errContains("spec.spiffe.id: must start with a /"),
+ },
+ {
+ name: "missing attribute",
+ in: &workloadidentityv1pb.WorkloadIdentity{
+ Kind: types.KindWorkloadIdentity,
+ Version: types.V1,
+ Metadata: &headerv1.Metadata{
+ Name: "example",
+ },
+ Spec: &workloadidentityv1pb.WorkloadIdentitySpec{
+ Rules: &workloadidentityv1pb.WorkloadIdentityRules{
+ Allow: []*workloadidentityv1pb.WorkloadIdentityRule{
+ {
+ Conditions: []*workloadidentityv1pb.WorkloadIdentityCondition{
+ {
+ Attribute: "",
+ Equals: "foo",
+ },
+ },
+ },
+ },
+ },
+ Spiffe: &workloadidentityv1pb.WorkloadIdentitySPIFFE{
+ Id: "/example",
+ },
+ },
+ },
+ requireErr: errContains("spec.rules.allow[0].conditions[0].attribute: must be non-empty"),
+ },
+ {
+ name: "missing operator",
+ in: &workloadidentityv1pb.WorkloadIdentity{
+ Kind: types.KindWorkloadIdentity,
+ Version: types.V1,
+ Metadata: &headerv1.Metadata{
+ Name: "example",
+ },
+ Spec: &workloadidentityv1pb.WorkloadIdentitySpec{
+ Rules: &workloadidentityv1pb.WorkloadIdentityRules{
+ Allow: []*workloadidentityv1pb.WorkloadIdentityRule{
+ {
+ Conditions: []*workloadidentityv1pb.WorkloadIdentityCondition{
+ {
+ Attribute: "example",
+ },
+ },
+ },
+ },
+ },
+ Spiffe: &workloadidentityv1pb.WorkloadIdentitySPIFFE{
+ Id: "/example",
+ },
+ },
+ },
+ requireErr: errContains("spec.rules.allow[0].conditions[0]: exactly one operator must be specified, found 0"),
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ err := ValidateWorkloadIdentity(tc.in)
+ tc.requireErr(t, err)
+ })
+ }
+}