diff --git a/src/cmd/main/main.go b/src/cmd/main/main.go index b7714ae..ec681f2 100644 --- a/src/cmd/main/main.go +++ b/src/cmd/main/main.go @@ -47,7 +47,7 @@ func main() { return outboxFactory(repository) }) deleteTopicProcess := del.NewProcess(logger, db, confluentClient, func(repository del.OutboxRepository) del.Outbox { return outboxFactory(repository) }) - addSchemaProcess := schema.NewProcess(logger, db, confluentClient, func(repository schema.OutboxRepository) schema.Outbox { return outboxFactory(repository) }) + addSchemaProcess := schema.NewProcess(logger, db, confluentClient, awsClient, func(repository schema.OutboxRepository) schema.Outbox { return outboxFactory(repository) }) consumer := Must(messaging.ConfigureConsumer(logger, config.KafkaBroker, config.KafkaGroupId, messaging.WithCredentials(config.CreateConsumerCredentials()), messaging.RegisterMessageHandler(config.TopicNameSelfService, "topic_requested", create.NewTopicRequestedHandler(createTopicProcess), &create.TopicRequested{}), diff --git a/src/functional_tests/create_schema_test.go b/src/functional_tests/create_schema_test.go index 53f6b34..2d29550 100644 --- a/src/functional_tests/create_schema_test.go +++ b/src/functional_tests/create_schema_test.go @@ -9,6 +9,7 @@ import ( schema "github.com/dfds/confluent-gateway/internal/schema" "github.com/dfds/confluent-gateway/messaging" "github.com/h2non/gock" + uuid "github.com/satori/go.uuid" "github.com/stretchr/testify/require" "testing" "time" @@ -53,7 +54,43 @@ func TestCreateSchemaProcess(t *testing.T) { ) require.NoError(t, err) - process := schema.NewProcess(testerApp.logger, testerApp.db, testerApp.confluentClient, func(repository schema.OutboxRepository) schema.Outbox { + err = testerApp.db.CreateTopic(&models.Topic{ + Id: createSchemaVariables.TopicId, + CapabilityId: createSchemaVariables.CapabilityId, + ClusterId: testerApp.dbSeedVariables.DevelopmentClusterId, + Name: createSchemaVariables.TopicName, + CreatedAt: time.Now(), + }) + require.NoError(t, err) + + someUserId := models.MakeUserAccountId(123) + someServiceAccountID := models.ServiceAccountId(uuid.NewV4().String()) + clusterAccess := models.ClusterAccess{ + Id: uuid.NewV4(), + ClusterId: testerApp.dbSeedVariables.DevelopmentClusterId, + ServiceAccountId: someServiceAccountID, + UserAccountId: someUserId, + Acl: nil, + CreatedAt: time.Time{}, + } + err = testerApp.db.CreateServiceAccount(&models.ServiceAccount{ + Id: someServiceAccountID, + UserAccountId: someUserId, + CapabilityId: createSchemaVariables.CapabilityId, + ClusterAccesses: []models.ClusterAccess{ + clusterAccess, + }, + CreatedAt: time.Time{}, + }) + + err = testerApp.db.CreateClusterAccess(&clusterAccess) + + //ensureServiceAccountSchemaRegistryAccessStep + setupListKeysHTTPMock(string(testerApp.dbSeedVariables.DevelopmentSchemaRegistryId), someServiceAccountID, 0) // Check if the api key has already been created + setupCreateApiKeyMock(string(testerApp.dbSeedVariables.DevelopmentSchemaRegistryId), someServiceAccountID, "username", "p4ssword") // Then we create an API key for the schema registry + setupRoleBindingHTTPMock(string(someServiceAccountID), testerApp.dbSeedVariables.GetDevelopmentClusterValues()) // Then we create a role binding for the service account + + process := schema.NewProcess(testerApp.logger, testerApp.db, testerApp.confluentClient, *testerApp.vaultClient, func(repository schema.OutboxRepository) schema.Outbox { return outboxFactory(repository) }) @@ -65,20 +102,6 @@ func TestCreateSchemaProcess(t *testing.T) { Schema: "test-schema", } - err = process.Process(context.Background(), input) - // TODO: Although the topic does not exist, the process ignores that and continues. - //require.ErrorIs(t, err, storage.ErrTopicNotFound) - require.NoError(t, err) - - err = testerApp.db.CreateTopic(&models.Topic{ - Id: createSchemaVariables.TopicId, - CapabilityId: createSchemaVariables.CapabilityId, - ClusterId: testerApp.dbSeedVariables.DevelopmentClusterId, - Name: createSchemaVariables.TopicName, - CreatedAt: time.Now(), - }) - require.NoError(t, err) - setupCreateSchemaHttpMock(input, createSchemaVariables.TopicName, testerApp.dbSeedVariables) err = process.Process(context.Background(), input) require.NoError(t, err) diff --git a/src/internal/schema/account.go b/src/internal/schema/account.go new file mode 100644 index 0000000..d16ed63 --- /dev/null +++ b/src/internal/schema/account.go @@ -0,0 +1,71 @@ +package schema + +import ( + "context" + "fmt" + "github.com/dfds/confluent-gateway/internal/models" +) + +type accountService struct { + context context.Context + confluent Confluent + repo serviceAccountRepository +} + +type serviceAccountRepository interface { + GetServiceAccount(capabilityId models.CapabilityId) (*models.ServiceAccount, error) +} + +func NewSchemaAccountService(ctx context.Context, confluent Confluent, repo serviceAccountRepository) *accountService { + return &accountService{ + context: ctx, + confluent: confluent, + repo: repo, + } +} + +func (h *accountService) GetServiceAccount(capabilityId models.CapabilityId) (*models.ServiceAccount, error) { + return h.repo.GetServiceAccount(capabilityId) +} + +func (h *accountService) DeleteSchemaRegistryApiKey(clusterAccess *models.ClusterAccess) error { + return h.confluent.DeleteSchemaRegistryApiKey(h.context, clusterAccess.ClusterId, clusterAccess.ServiceAccountId) +} + +func (h *accountService) GetClusterAccess(capabilityId models.CapabilityId, clusterId models.ClusterId) (*models.ClusterAccess, error) { + serviceAccount, err := h.repo.GetServiceAccount(capabilityId) + if err != nil { + return nil, err + } + if serviceAccount == nil { + return nil, fmt.Errorf("no service account for capability '%s' found", capabilityId) + } + + clusterAccess, hasClusterAccess := serviceAccount.TryGetClusterAccess(clusterId) + + if !hasClusterAccess { + return nil, fmt.Errorf("no cluster access for service account '%s' found", serviceAccount.Id) + } + return clusterAccess, nil +} + +func (h *accountService) CountSchemaRegistryApiKeys(clusterAccess *models.ClusterAccess) (int, error) { + keyCount, err := h.confluent.CountSchemaRegistryApiKeys(h.context, clusterAccess.ServiceAccountId, clusterAccess.ClusterId) + if err != nil { + return 0, err + } + return keyCount, nil +} + +func (h *accountService) CreateSchemaRegistryApiKey(clusterAccess *models.ClusterAccess) (models.ApiKey, error) { + return h.confluent.CreateSchemaRegistryApiKey(h.context, clusterAccess.ClusterId, clusterAccess.ServiceAccountId) + +} + +func (h *accountService) CreateServiceAccountRoleBinding(clusterAccess *models.ClusterAccess) error { + err := h.confluent.CreateServiceAccountRoleBinding(h.context, clusterAccess.ServiceAccountId, clusterAccess.ClusterId) + if err != nil { + return err + } + return nil +} diff --git a/src/internal/schema/confluent.go b/src/internal/schema/confluent.go new file mode 100644 index 0000000..47092b4 --- /dev/null +++ b/src/internal/schema/confluent.go @@ -0,0 +1,16 @@ +package schema + +import ( + "context" + + "github.com/dfds/confluent-gateway/internal/models" +) + +type Confluent interface { + CreateServiceAccount(ctx context.Context, name string, description string) (models.ServiceAccountId, error) + CreateSchemaRegistryApiKey(ctx context.Context, clusterId models.ClusterId, serviceAccountId models.ServiceAccountId) (models.ApiKey, error) + CreateServiceAccountRoleBinding(ctx context.Context, serviceAccount models.ServiceAccountId, clusterId models.ClusterId) error + CountSchemaRegistryApiKeys(ctx context.Context, clusterAccess models.ServiceAccountId, clusterId models.ClusterId) (int, error) + DeleteSchemaRegistryApiKey(ctx context.Context, clusterId models.ClusterId, serviceAccountId models.ServiceAccountId) error + RegisterSchema(ctx context.Context, clusterId models.ClusterId, subject string, schema string) error +} diff --git a/src/internal/schema/context.go b/src/internal/schema/context.go index ba4d61d..04c5f23 100644 --- a/src/internal/schema/context.go +++ b/src/internal/schema/context.go @@ -2,6 +2,7 @@ package schema import ( "context" + "fmt" "github.com/dfds/confluent-gateway/internal/models" "github.com/dfds/confluent-gateway/logging" "github.com/dfds/confluent-gateway/messaging" @@ -10,13 +11,29 @@ import ( type StepContext struct { logger logging.Logger ctx context.Context + account AccountService + vault VaultService state *models.SchemaProcess registry SchemaRegistry + input ProcessInput outbox Outbox + topic TopicService +} +type AccountService interface { + GetServiceAccount(models.CapabilityId) (*models.ServiceAccount, error) + GetClusterAccess(models.CapabilityId, models.ClusterId) (*models.ClusterAccess, error) + CreateSchemaRegistryApiKey(clusterAccess *models.ClusterAccess) (models.ApiKey, error) + CreateServiceAccountRoleBinding(clusterAccess *models.ClusterAccess) error + CountSchemaRegistryApiKeys(clusterAccess *models.ClusterAccess) (int, error) + DeleteSchemaRegistryApiKey(clusterAccess *models.ClusterAccess) error +} + +func NewStepContext(logger logging.Logger, ctx context.Context, schema *models.SchemaProcess, registry SchemaRegistry, outbox Outbox, account AccountService, vault VaultService, topic TopicService) *StepContext { + return &StepContext{logger: logger, ctx: ctx, state: schema, registry: registry, outbox: outbox, account: account, vault: vault, topic: topic} } -func NewStepContext(logger logging.Logger, ctx context.Context, schema *models.SchemaProcess, registry SchemaRegistry, outbox Outbox) *StepContext { - return &StepContext{logger: logger, ctx: ctx, state: schema, registry: registry, outbox: outbox} +type TopicService interface { + GetTopic(string) (*models.Topic, error) } type Outbox interface { @@ -55,3 +72,65 @@ func (c *StepContext) RaiseSchemaRegistrationFailed(reason string) error { } return c.outbox.Produce(event) } + +func (c *StepContext) LogDebug(format string, args ...string) { + c.logger.Debug(format, args...) +} + +func (c *StepContext) LogError(err error, format string, args ...string) { + c.logger.Error(err, format, args...) +} + +func (c *StepContext) LogWarning(format string, args ...string) { + c.logger.Warning(format, args...) +} + +func (c *StepContext) HasServiceAccount(capabilityId models.CapabilityId) bool { + account, err := c.account.GetServiceAccount(capabilityId) + if err != nil { + c.LogError(err, fmt.Sprintf("encountered error when checking if ServiceAccount exists for CapabilityId %s", capabilityId)) + return false + } + return account != nil +} + +func (c *StepContext) GetClusterAccess() (*models.ClusterAccess, error) { + topic, err := c.topic.GetTopic(c.state.TopicId) + if err != nil { + return nil, err + } + return c.account.GetClusterAccess(topic.CapabilityId, topic.ClusterId) +} + +func (c *StepContext) HasSchemaRegistryApiKey(clusterAccess *models.ClusterAccess) (bool, error) { + count, err := c.account.CountSchemaRegistryApiKeys(clusterAccess) + return count > 0, err +} + +func (c *StepContext) HasSchemaRegistryApiKeyInVault(clusterAccess *models.ClusterAccess) (bool, error) { + topic, err := c.topic.GetTopic(c.state.TopicId) + if err != nil { + return false, err + } + return c.vault.QuerySchemaRegistryApiKey(topic.CapabilityId, clusterAccess.ClusterId) +} + +func (c *StepContext) CreateSchemaRegistryApiKeyAndStoreInVault(clusterAccess *models.ClusterAccess, shouldOverwriteKey bool) error { + newKey, err := c.account.CreateSchemaRegistryApiKey(clusterAccess) + if err != nil { + return err + } + topic, err := c.topic.GetTopic(c.state.TopicId) + if err != nil { + return err + } + return c.vault.StoreSchemaRegistryApiKey(topic.CapabilityId, clusterAccess.ClusterId, newKey, shouldOverwriteKey) +} + +func (c *StepContext) DeleteSchemaRegistryApiKey(clusterAccess *models.ClusterAccess) error { + return c.account.DeleteSchemaRegistryApiKey(clusterAccess) +} + +func (c *StepContext) CreateServiceAccountRoleBinding(clusterAccess *models.ClusterAccess) error { + return c.account.CreateServiceAccountRoleBinding(clusterAccess) +} diff --git a/src/internal/schema/process.go b/src/internal/schema/process.go index f6bf0d6..9fb4f02 100644 --- a/src/internal/schema/process.go +++ b/src/internal/schema/process.go @@ -8,22 +8,31 @@ import ( "github.com/dfds/confluent-gateway/internal/models" . "github.com/dfds/confluent-gateway/internal/process" "github.com/dfds/confluent-gateway/internal/storage" + "github.com/dfds/confluent-gateway/internal/vault" "github.com/dfds/confluent-gateway/logging" ) +type logger interface { + LogDebug(string, ...string) + LogWarning(string, ...string) + LogError(error, string, ...string) +} + type process struct { - logger logging.Logger - database models.Database - registry SchemaRegistry - factory OutboxFactory + logger logging.Logger + database models.Database + confluent Confluent + factory OutboxFactory + vault vault.Vault } -func NewProcess(logger logging.Logger, database models.Database, registry SchemaRegistry, factory OutboxFactory) Process { +func NewProcess(logger logging.Logger, database models.Database, confluent Confluent, vault vault.Vault, factory OutboxFactory) Process { return &process{ - logger: logger, - database: database, - registry: registry, - factory: factory, + logger: logger, + database: database, + factory: factory, + confluent: confluent, + vault: vault, } } @@ -55,6 +64,7 @@ func (p *process) Process(ctx context.Context, input ProcessInput) error { } return PrepareSteps[*StepContext](). + Step(ensureServiceAccountSchemaRegistryAccessStep). Step(ensureSchemaIsRegistered). Run(func(step func(*StepContext) error) error { return session.Transaction(func(tx models.Transaction) error { @@ -123,7 +133,10 @@ func getOrCreateProcessState(repo schemaRepository, input ProcessInput, topic *m } func (p *process) getStepContext(ctx context.Context, tx models.Transaction, schema *models.SchemaProcess) *StepContext { - return NewStepContext(p.logger, ctx, schema, p.registry, p.factory(tx)) + newAccountService := NewSchemaAccountService(ctx, p.confluent, tx) + vaultService := NewVaultService(ctx, p.vault) + topicService := NewTopicService(tx) + return NewStepContext(p.logger, ctx, schema, p.confluent, p.factory(tx), newAccountService, vaultService, topicService) } // region Steps @@ -168,4 +181,70 @@ func ensureSchemaIsRegisteredStep(step EnsureSchemaIsRegisteredStep) error { return step.RaiseSchemaRegisteredEvent() } +type EnsureServiceAccountSchemaRegistryAccessStep interface { + logger + GetClusterAccess() (*models.ClusterAccess, error) + HasSchemaRegistryApiKey(clusterAccess *models.ClusterAccess) (bool, error) + HasSchemaRegistryApiKeyInVault(clusterAccess *models.ClusterAccess) (bool, error) + CreateServiceAccountRoleBinding(*models.ClusterAccess) error + CreateSchemaRegistryApiKeyAndStoreInVault(clusterAccess *models.ClusterAccess, shouldOverwriteKey bool) error + DeleteSchemaRegistryApiKey(clusterAccess *models.ClusterAccess) error +} + +func ensureServiceAccountSchemaRegistryAccessStep(step *StepContext) error { + inner := func(step EnsureServiceAccountSchemaRegistryAccessStep) error { + step.LogDebug("Running {Step}", "EnsureServiceAccountSchemaRegistryAccessStep") + + clusterAccess, err := step.GetClusterAccess() + if err != nil { + if !errors.Is(err, storage.ErrTopicNotFound) { + step.LogError(err, "unable to get cluster access") + } + return nil + } + + err = step.CreateServiceAccountRoleBinding(clusterAccess) + if err != nil { + if errors.Is(err, confluent.ErrMissingSchemaRegistryIds) { + step.LogError(err, "unable to setup schema registry access") + return nil // fail silently to not take down whole service + } + return err + } + + hasKeyInConfluent, err := step.HasSchemaRegistryApiKey(clusterAccess) + if err != nil { + return err + } + hasKeyInVault, err := step.HasSchemaRegistryApiKeyInVault(clusterAccess) + if err != nil { + return err + } + + if hasKeyInVault && hasKeyInConfluent { + return nil + } + + recreateKey := false + if hasKeyInConfluent && !hasKeyInVault { + step.LogWarning("found existing api key in Confluent, but not in Parameter Store. Deleting key and creating again.") + err = step.DeleteSchemaRegistryApiKey(clusterAccess) + if err != nil { + return err + } + } else if !hasKeyInConfluent && hasKeyInVault { // not sure if this can happen + step.LogWarning("found existing key in Parameter Store, but not in Confluent. Creating new key and updating Parameter Store.") + recreateKey = true + } + + err = step.CreateSchemaRegistryApiKeyAndStoreInVault(clusterAccess, recreateKey) + if err != nil { + return err + } + step.LogWarning("granted schema registry access through schema registration") + return nil + } + return inner(step) +} + // endregion diff --git a/src/internal/schema/topic.go b/src/internal/schema/topic.go new file mode 100644 index 0000000..ea04bd2 --- /dev/null +++ b/src/internal/schema/topic.go @@ -0,0 +1,21 @@ +package schema + +import ( + "github.com/dfds/confluent-gateway/internal/models" +) + +type topicService struct { + repo topicRepository +} + +type topicRepository interface { + GetTopic(topicId string) (*models.Topic, error) +} + +func NewTopicService(repo topicRepository) *topicService { + return &topicService{repo: repo} +} + +func (p *topicService) GetTopic(topicId string) (*models.Topic, error) { + return p.repo.GetTopic(topicId) +} diff --git a/src/internal/schema/vault.go b/src/internal/schema/vault.go new file mode 100644 index 0000000..e4d5961 --- /dev/null +++ b/src/internal/schema/vault.go @@ -0,0 +1,50 @@ +package schema + +import ( + "context" + "github.com/dfds/confluent-gateway/internal/vault" + + "github.com/dfds/confluent-gateway/internal/models" +) + +type VaultService interface { + StoreSchemaRegistryApiKey(capabilityId models.CapabilityId, clusterId models.ClusterId, apiKey models.ApiKey, shouldOverwrite bool) error + QuerySchemaRegistryApiKey(capabilityId models.CapabilityId, clusterId models.ClusterId) (bool, error) + DeleteSchemaRegistryApiKey(capabilityId models.CapabilityId, clusterId models.ClusterId) error +} + +type vaultService struct { + context context.Context + vault vault.Vault +} + +func NewVaultService(context context.Context, vault vault.Vault) *vaultService { + return &vaultService{context: context, vault: vault} +} + +func (v *vaultService) StoreSchemaRegistryApiKey(capabilityId models.CapabilityId, clusterId models.ClusterId, apiKey models.ApiKey, shouldOverwrite bool) error { + return v.vault.StoreApiKey(v.context, vault.Input{ + OperationDestination: vault.OperationDestinationSchemaRegistry, + CapabilityId: capabilityId, + ClusterId: clusterId, + StoringInput: &vault.StoringInput{ApiKey: apiKey, Overwrite: shouldOverwrite}, + }) + +} + +func (v *vaultService) QuerySchemaRegistryApiKey(capabilityId models.CapabilityId, clusterId models.ClusterId) (bool, error) { + return v.vault.QueryApiKey(v.context, vault.Input{ + OperationDestination: vault.OperationDestinationSchemaRegistry, + CapabilityId: capabilityId, + ClusterId: clusterId, + }) + +} + +func (v *vaultService) DeleteSchemaRegistryApiKey(capabilityId models.CapabilityId, clusterId models.ClusterId) error { + return v.vault.DeleteApiKey(v.context, vault.Input{ + OperationDestination: vault.OperationDestinationSchemaRegistry, + CapabilityId: capabilityId, + ClusterId: clusterId, + }) +}