From cf9836abad8d16d58e9884ecbab8616a4f652e34 Mon Sep 17 00:00:00 2001 From: Maxime Lagresle Date: Thu, 7 Nov 2024 14:06:01 +0100 Subject: [PATCH] Fix concurrency issues (#184) * add test and fix race in test * test for race * listen on local interface to avoid firewall warnings * introduce mutex around library operations * remove unnecessary lock * skip high concurrency test with legacy client --- .github/workflows/tests.yml | 2 +- internal/bitwarden/embedded/models.go | 7 +- .../embedded/password_manager_base.go | 48 ++- .../embedded/password_manager_webapi.go | 291 ++++++++++-------- internal/bitwarden/models/password_manager.go | 1 + internal/provider/resource_item_login_test.go | 27 ++ .../testhelper_secrets_manager_test.go | 40 ++- 7 files changed, 283 insertions(+), 133 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 12fe372..1051bc8 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -97,7 +97,7 @@ jobs: run: go build -v ./... - name: Test with Embedded Client - run: go test -coverprofile=profile.cov -v -coverpkg=./... ./... + run: go test -coverprofile=profile.cov -v -race -coverpkg=./... ./... env: VAULTWARDEN_HOST: "127.0.0.1" VAULTWARDEN_PORT: "8080" diff --git a/internal/bitwarden/embedded/models.go b/internal/bitwarden/embedded/models.go index 49b2ce7..b7dcc35 100644 --- a/internal/bitwarden/embedded/models.go +++ b/internal/bitwarden/embedded/models.go @@ -20,7 +20,11 @@ type Account struct { Secrets AccountSecrets `json:"-"` } -func (a *Account) PrivateKeyDecrypted() bool { +func (a *Account) LoggedIn() bool { + return len(a.ProtectedRSAPrivateKey) > 0 +} + +func (a *Account) SecretsLoaded() bool { return len(a.Secrets.MainKey.Key) > 0 } @@ -47,6 +51,7 @@ func (s *AccountSecrets) GetOrganizationKey(orgId string) (*symmetrickey.Key, er type OrganizationSecret struct { Key symmetrickey.Key OrganizationUUID string + Name string } type MachineAccountClaims struct { diff --git a/internal/bitwarden/embedded/password_manager_base.go b/internal/bitwarden/embedded/password_manager_base.go index cd0a64c..9c9b874 100644 --- a/internal/bitwarden/embedded/password_manager_base.go +++ b/internal/bitwarden/embedded/password_manager_base.go @@ -6,6 +6,7 @@ import ( "encoding/json" "fmt" "strings" + "sync" "time" "github.com/hashicorp/terraform-plugin-log/tflog" @@ -27,14 +28,30 @@ type BaseVault interface { } type baseVault struct { - locked bool - loginAccount Account - objectStore map[string]models.Object + loginAccount Account + objectStore map[string]models.Object + + // vaultOperationMutex protects the objectStore and loginAccount fields + // from concurrent access. Read operations are allowed to run concurrently, + // but write operations are serialized. In theory we could protect the two + // fields individually, but it's just much more easier to have a single + // mutex for both. + vaultOperationMutex sync.RWMutex + + // verifyObjectEncryption is a flag that can be set to true to verify that + // every object that is encrypted can be decrypted back to its original. verifyObjectEncryption bool } -func (v *baseVault) GetObject(_ context.Context, obj models.Object) (*models.Object, error) { - if v.locked { +func (v *baseVault) GetObject(ctx context.Context, obj models.Object) (*models.Object, error) { + v.vaultOperationMutex.RLock() + defer v.vaultOperationMutex.RUnlock() + + return v.getObject(ctx, obj) +} + +func (v *baseVault) getObject(_ context.Context, obj models.Object) (*models.Object, error) { + if v.objectStore == nil { return nil, models.ErrVaultLocked } @@ -47,7 +64,10 @@ func (v *baseVault) GetObject(_ context.Context, obj models.Object) (*models.Obj } func (v *baseVault) ListObjects(ctx context.Context, objType models.ObjectType, options ...bitwarden.ListObjectsOption) ([]models.Object, error) { - if v.locked { + v.vaultOperationMutex.RLock() + defer v.vaultOperationMutex.RUnlock() + + if v.objectStore == nil { return nil, models.ErrVaultLocked } @@ -117,12 +137,28 @@ func (v *baseVault) encryptFolder(_ context.Context, obj models.Object, secret A return &encFolder, nil } +func (v *baseVault) objectsLoaded() bool { + return v.objectStore != nil +} + func (v *baseVault) storeObject(ctx context.Context, obj models.Object) { tflog.Trace(ctx, "Storing new object", map[string]interface{}{"object_id": obj.ID, "object_name": obj.Name, "object_folder_id": obj.FolderID}) v.objectStore[objKey(obj)] = obj } +func (v *baseVault) storeObjects(ctx context.Context, objs []models.Object) { + for _, obj := range objs { + v.storeObject(ctx, obj) + } +} + func decryptAccountSecrets(account Account, password string) (*AccountSecrets, error) { + if len(account.Email) == 0 { + // A common mistake is trying to decrypt account secrets without an + // email, the content of an Account comes from two different API calls. + return nil, fmt.Errorf("BUG: email required to decrypt account secrets") + } + masterKey, err := keybuilder.BuildPreloginKey(password, account.Email, account.KdfConfig) if err != nil { return nil, fmt.Errorf("error building prelogin key: %w", err) diff --git a/internal/bitwarden/embedded/password_manager_webapi.go b/internal/bitwarden/embedded/password_manager_webapi.go index b7bb72e..1a519c4 100644 --- a/internal/bitwarden/embedded/password_manager_webapi.go +++ b/internal/bitwarden/embedded/password_manager_webapi.go @@ -8,7 +8,6 @@ import ( "strings" "github.com/google/uuid" - "github.com/hashicorp/terraform-plugin-log/tflog" "github.com/maxlaverse/terraform-provider-bitwarden/internal/bitwarden" "github.com/maxlaverse/terraform-provider-bitwarden/internal/bitwarden/crypto" "github.com/maxlaverse/terraform-provider-bitwarden/internal/bitwarden/crypto/encryptedstring" @@ -79,7 +78,6 @@ func NewPasswordManagerClient(serverURL, deviceIdentifier, providerVersion strin c := &webAPIVault{ baseVault: baseVault{ objectStore: make(map[string]models.Object), - locked: true, verifyObjectEncryption: true, }, serverURL: serverURL, @@ -107,13 +105,15 @@ type webAPIVault struct { client webapi.Client clientOpts []webapi.Options - ciphersMap webapi.SyncResponse syncAfterWrite bool serverURL string } func (v *webAPIVault) CreateAttachment(ctx context.Context, itemId, filePath string) (*models.Object, error) { - if v.locked { + v.vaultOperationMutex.Lock() + defer v.vaultOperationMutex.Unlock() + + if !v.objectsLoaded() { return nil, models.ErrVaultLocked } @@ -140,12 +140,12 @@ func (v *webAPIVault) CreateAttachment(ctx context.Context, itemId, filePath str v.storeObject(ctx, *resObj) if v.syncAfterWrite { - err = v.Sync(ctx) + err = v.sync(ctx) if err != nil { return nil, fmt.Errorf("sync-after-write error: %w", err) } - remoteObj, err := v.GetObject(ctx, *resObj) + remoteObj, err := v.getObject(ctx, *resObj) if err != nil { return nil, fmt.Errorf("error getting object after attachment upload (sync-after-write): %w", err) } @@ -162,7 +162,10 @@ func (v *webAPIVault) CreateAttachment(ctx context.Context, itemId, filePath str } func (v *webAPIVault) CreateObject(ctx context.Context, obj models.Object) (*models.Object, error) { - if v.locked { + v.vaultOperationMutex.Lock() + defer v.vaultOperationMutex.Unlock() + + if !v.objectsLoaded() { return nil, models.ErrVaultLocked } @@ -218,13 +221,14 @@ func (v *webAPIVault) CreateObject(ctx context.Context, obj models.Object) (*mod } v.storeObject(ctx, *resObj) + if v.syncAfterWrite { - err := v.Sync(ctx) + err := v.sync(ctx) if err != nil { return nil, fmt.Errorf("sync-after-write error: %w", err) } - remoteObj, err := v.GetObject(ctx, *resObj) + remoteObj, err := v.getObject(ctx, *resObj) if err != nil { return nil, fmt.Errorf("error getting object after creation (sync-after-write): %w", err) } @@ -241,7 +245,10 @@ func (v *webAPIVault) CreateObject(ctx context.Context, obj models.Object) (*mod } func (v *webAPIVault) CreateOrganization(ctx context.Context, organizationName, organizationLabel, billingEmail string) (string, error) { - if v.locked { + v.vaultOperationMutex.Lock() + defer v.vaultOperationMutex.Unlock() + + if !v.objectsLoaded() { return "", models.ErrVaultLocked } @@ -279,13 +286,20 @@ func (v *webAPIVault) CreateOrganization(ctx context.Context, organizationName, } func (v *webAPIVault) DeleteAttachment(ctx context.Context, itemId, attachmentId string) error { + v.vaultOperationMutex.Lock() + defer v.vaultOperationMutex.Unlock() + + if !v.objectsLoaded() { + return models.ErrVaultLocked + } + // TODO: Don't fail if attachment is already gone err := v.client.DeleteObjectAttachment(ctx, itemId, attachmentId) if err != nil { return fmt.Errorf("error deleting attachment: %w", err) } - resObj, err := v.GetObject(ctx, models.Object{ID: itemId, Object: models.ObjectTypeItem}) + resObj, err := v.getObject(ctx, models.Object{ID: itemId, Object: models.ObjectTypeItem}) if err != nil { return fmt.Errorf("error getting object after attachment deletion: %w", err) } @@ -300,12 +314,12 @@ func (v *webAPIVault) DeleteAttachment(ctx context.Context, itemId, attachmentId v.storeObject(ctx, *resObj) if v.syncAfterWrite { - err := v.Sync(ctx) + err := v.sync(ctx) if err != nil { return fmt.Errorf("sync-after-write error: %w", err) } - remoteObj, err := v.GetObject(ctx, *resObj) + remoteObj, err := v.getObject(ctx, *resObj) if err != nil { return fmt.Errorf("error getting object after attachment deletion (syncAfterWrite): %w", err) } @@ -317,6 +331,13 @@ func (v *webAPIVault) DeleteAttachment(ctx context.Context, itemId, attachmentId } func (v *webAPIVault) DeleteObject(ctx context.Context, obj models.Object) error { + v.vaultOperationMutex.Lock() + defer v.vaultOperationMutex.Unlock() + + if !v.objectsLoaded() { + return models.ErrVaultLocked + } + // TODO: Don't fail if object is already gone var err error if obj.Object == models.ObjectTypeFolder { @@ -334,13 +355,16 @@ func (v *webAPIVault) DeleteObject(ctx context.Context, obj models.Object) error v.deleteObjectFromStore(ctx, obj) if v.syncAfterWrite { - return v.Sync(ctx) + return v.sync(ctx) } return nil } func (v *webAPIVault) EditObject(ctx context.Context, obj models.Object) (*models.Object, error) { - if v.locked { + v.vaultOperationMutex.Lock() + defer v.vaultOperationMutex.Unlock() + + if !v.objectsLoaded() { return nil, models.ErrVaultLocked } @@ -395,12 +419,12 @@ func (v *webAPIVault) EditObject(ctx context.Context, obj models.Object) (*model v.storeObject(ctx, *resObj) if v.syncAfterWrite { - err := v.Sync(ctx) + err := v.sync(ctx) if err != nil { return nil, fmt.Errorf("sync-after-write error: %w", err) } - remoteObj, err := v.GetObject(ctx, *resObj) + remoteObj, err := v.getObject(ctx, *resObj) if err != nil { return nil, fmt.Errorf("error getting object after edition (sync-after-write): %w", err) } @@ -430,7 +454,10 @@ func (v *webAPIVault) GetAPIKey(ctx context.Context, username, password string) } func (v *webAPIVault) GetAttachment(ctx context.Context, itemId, attachmentId string) ([]byte, error) { - if v.locked { + v.vaultOperationMutex.Lock() + defer v.vaultOperationMutex.Unlock() + + if !v.objectsLoaded() { return nil, models.ErrVaultLocked } @@ -450,7 +477,7 @@ func (v *webAPIVault) GetAttachment(ctx context.Context, itemId, attachmentId st return nil, fmt.Errorf("error fetching attachment body: %w", err) } - originalObj, err := v.GetObject(ctx, models.Object{ID: itemId, Object: models.ObjectTypeItem}) + originalObj, err := v.getObject(ctx, models.Object{ID: itemId, Object: models.ObjectTypeItem}) if err != nil { return nil, fmt.Errorf("error getting original object: %w", err) } @@ -479,6 +506,13 @@ func (v *webAPIVault) GetAttachment(ctx context.Context, itemId, attachmentId st } func (v *webAPIVault) LoginWithAPIKey(ctx context.Context, password, clientId, clientSecret string) error { + v.vaultOperationMutex.Lock() + defer v.vaultOperationMutex.Unlock() + + if v.loginAccount.LoggedIn() { + return models.ErrAlreadyLoggedIn + } + tokenResp, err := v.client.LoginWithAPIKey(ctx, clientId, clientSecret) if err != nil { return fmt.Errorf("error login with api key: %w", err) @@ -488,6 +522,13 @@ func (v *webAPIVault) LoginWithAPIKey(ctx context.Context, password, clientId, c } func (v *webAPIVault) LoginWithPassword(ctx context.Context, username, password string) error { + v.vaultOperationMutex.Lock() + defer v.vaultOperationMutex.Unlock() + + if v.loginAccount.LoggedIn() { + return models.ErrAlreadyLoggedIn + } + preResp, err := v.client.PreLogin(ctx, username) if err != nil { return fmt.Errorf("error prelogin with username/password: %w", err) @@ -545,47 +586,61 @@ func (v *webAPIVault) RegisterUser(ctx context.Context, name, username, password } func (v *webAPIVault) Sync(ctx context.Context) error { + v.vaultOperationMutex.Lock() + defer v.vaultOperationMutex.Unlock() + + return v.sync(ctx) +} + +func (v *webAPIVault) sync(ctx context.Context) error { + if !v.loginAccount.LoggedIn() { + return models.ErrLoggedOut + } else if !v.loginAccount.SecretsLoaded() { + return models.ErrVaultLocked + } + ciphersRaw, err := v.client.Sync(ctx) if err != nil { return fmt.Errorf("error syncing: %w", err) } - if len(v.loginAccount.Email) > 0 && v.loginAccount.Email != ciphersRaw.Profile.Email || len(v.loginAccount.AccountUUID) > 0 && v.loginAccount.AccountUUID != ciphersRaw.Profile.Id { + + if v.loginAccount.Email != ciphersRaw.Profile.Email || v.loginAccount.AccountUUID != ciphersRaw.Profile.Id { return fmt.Errorf("BUG: account UUID or email changed during sync") } - v.ciphersMap = *ciphersRaw - v.loginAccount.Email = v.ciphersMap.Profile.Email - v.loginAccount.AccountUUID = v.ciphersMap.Profile.Id - - if !v.loginAccount.PrivateKeyDecrypted() { - return nil + err = loadOrganizationSecrets(v.loginAccount.Secrets, ciphersRaw.Profile.Organizations) + if err != nil { + return fmt.Errorf("error loading organization secrets: %w", err) } - return v.loadObjectMap(ctx) + return v.loadObjectMap(ctx, *ciphersRaw) } func (v *webAPIVault) Unlock(ctx context.Context, password string) error { - if len(v.loginAccount.Email) == 0 { - return fmt.Errorf("please login first") - } + v.vaultOperationMutex.Lock() + defer v.vaultOperationMutex.Unlock() - accountSecrets, err := decryptAccountSecrets(v.loginAccount, password) - if err != nil { - return fmt.Errorf("error decrypting account secrets: %w", err) + return v.unlock(ctx, password) +} + +func (v *webAPIVault) unlock(ctx context.Context, password string) error { + if !v.loginAccount.LoggedIn() { + return models.ErrLoggedOut } - v.loginAccount.Secrets = *accountSecrets profile, err := v.client.GetProfile(ctx) if err != nil { return fmt.Errorf("error loading profile: %w", err) } - v.ciphersMap.Profile = *profile + v.loginAccount.Email = profile.Email + v.loginAccount.AccountUUID = profile.Id - err = v.loadObjectMap(ctx) + accountSecrets, err := decryptAccountSecrets(v.loginAccount, password) if err != nil { - return fmt.Errorf("error loading cipher map: %w", err) + return fmt.Errorf("error decrypting account secrets: %w", err) } + v.loginAccount.Secrets = *accountSecrets return nil } @@ -603,101 +658,18 @@ func (v *webAPIVault) continueLoginWithTokens(ctx context.Context, tokenResp web ProtectedSymmetricKey: tokenResp.Key, } - err := v.Sync(ctx) + err := v.unlock(ctx, password) if err != nil { - return fmt.Errorf("error syncing after login: %w", err) - } - - return v.Unlock(ctx, password) -} - -func (v *webAPIVault) loadCollectionsFromObjectMap(ctx context.Context) error { - for _, collection := range v.ciphersMap.Collections { - obj, err := decryptCollection(collection, v.loginAccount.Secrets) - if err != nil { - return fmt.Errorf("error decrypting collection: %w", err) - } - v.storeObject(ctx, *obj) - } - return nil -} - -func (v *webAPIVault) loadFoldersFromObjectMap(ctx context.Context) error { - for _, folder := range v.ciphersMap.Folders { - obj, err := decryptFolder(folder, v.loginAccount.Secrets) - if err != nil { - return fmt.Errorf("error decrypting folder: %w", err) - } - v.storeObject(ctx, *obj) + return fmt.Errorf("error unlocking after login: %w", err) } - return nil -} -func (v *webAPIVault) loadObjectsFromObjectMap(ctx context.Context) error { - for _, value := range v.ciphersMap.Ciphers { - obj, err := decryptItem(value, v.loginAccount.Secrets) - if err != nil { - return fmt.Errorf("error decrypting object: %w", err) - } - v.storeObject(ctx, *obj) - } - return nil -} - -func (v *webAPIVault) loadObjectMap(ctx context.Context) error { - v.clearObjectStore(ctx) - - err := v.loadOrganizationSecretsFromObjectMap(ctx) - if err != nil { - return fmt.Errorf("error updating organization secrets: %w", err) - } - err = v.loadObjectsFromObjectMap(ctx) - if err != nil { - return fmt.Errorf("error updating object in store: %w", err) - } - - err = v.loadFoldersFromObjectMap(ctx) - if err != nil { - return fmt.Errorf("error updating folder in store: %w", err) - } - - err = v.loadCollectionsFromObjectMap(ctx) - if err != nil { - return fmt.Errorf("error updating collections in store: %w", err) - } - - tflog.Debug(ctx, "Vault is unlocked") - v.locked = false - return nil -} - -func (v *webAPIVault) loadOrganizationSecretsFromObjectMap(ctx context.Context) error { - for _, organization := range v.ciphersMap.Profile.Organizations { - key, err := decryptOrganizationKey(organization.Key, v.loginAccount.Secrets.RSAPrivateKey) - if err != nil { - return fmt.Errorf("error loading organization key: %w", err) - } - - orgSecret := OrganizationSecret{ - OrganizationUUID: organization.Id, - Key: *key, - } - v.loginAccount.Secrets.OrganizationSecrets[orgSecret.OrganizationUUID] = orgSecret - - obj := models.Object{ - ID: organization.Id, - Object: models.ObjectTypeOrganization, - Name: organization.Name, - } - v.storeObject(ctx, obj) - } - return nil + return v.sync(ctx) } func (v *webAPIVault) prepareAttachmentCreationRequest(ctx context.Context, itemId, filePath string) (*webapi.AttachmentRequestData, []byte, error) { // NOTE: We don't Sync() to get the latest version of Object before adding an attachment to it, because we // assume the Object's key can't change. - originalObj, err := v.GetObject(ctx, models.Object{ID: itemId, Object: models.ObjectTypeItem}) + originalObj, err := v.getObject(ctx, models.Object{ID: itemId, Object: models.ObjectTypeItem}) if err != nil { return nil, nil, fmt.Errorf("error getting original object: %w", err) } @@ -744,3 +716,74 @@ func (v *webAPIVault) prepareAttachmentCreationRequest(ctx context.Context, item } return &req, encDataBuffer, nil } + +func (v *webAPIVault) loadObjectMap(ctx context.Context, cipherMap webapi.SyncResponse) error { + v.clearObjectStore(ctx) + + for _, orgSecret := range v.loginAccount.Secrets.OrganizationSecrets { + v.storeObject(ctx, models.Object{ + ID: orgSecret.OrganizationUUID, + Object: models.ObjectTypeOrganization, + Name: orgSecret.Name, + }) + } + + res, err := ciphersToObjects(v.loginAccount.Secrets, cipherMap.Ciphers) + if err != nil { + return fmt.Errorf("error updating object in store: %w", err) + } + v.storeObjects(ctx, res) + + res, err = ciphersToObjects(v.loginAccount.Secrets, cipherMap.Folders) + if err != nil { + return fmt.Errorf("error updating folder in store: %w", err) + } + v.storeObjects(ctx, res) + + res, err = ciphersToObjects(v.loginAccount.Secrets, cipherMap.Collections) + if err != nil { + return fmt.Errorf("error updating collections in store: %w", err) + } + v.storeObjects(ctx, res) + + return nil +} + +func ciphersToObjects[T any](accountSecrets AccountSecrets, ciphers []T) ([]models.Object, error) { + objects := make([]models.Object, len(ciphers)) + for k, value := range ciphers { + var obj *models.Object + var err error + switch secret := any(value).(type) { + case models.Object: + obj, err = decryptItem(secret, accountSecrets) + case webapi.Folder: + obj, err = decryptFolder(secret, accountSecrets) + case webapi.Collection: + obj, err = decryptCollection(secret, accountSecrets) + } + if err != nil { + return nil, fmt.Errorf("error decrypting cipher: %w", err) + } + objects[k] = *obj + } + return objects, nil +} + +func loadOrganizationSecrets(accountSecrets AccountSecrets, organizations []webapi.Organization) error { + for _, organization := range organizations { + key, err := decryptOrganizationKey(organization.Key, accountSecrets.RSAPrivateKey) + if err != nil { + return fmt.Errorf("error loading organization key: %w", err) + } + + orgSecret := OrganizationSecret{ + OrganizationUUID: organization.Id, + Key: *key, + Name: organization.Name, + } + accountSecrets.OrganizationSecrets[orgSecret.OrganizationUUID] = orgSecret + + } + return nil +} diff --git a/internal/bitwarden/models/password_manager.go b/internal/bitwarden/models/password_manager.go index 9e8fd47..d682b50 100644 --- a/internal/bitwarden/models/password_manager.go +++ b/internal/bitwarden/models/password_manager.go @@ -9,6 +9,7 @@ var ( ErrObjectNotFound = errors.New("object not found") ErrAttachmentNotFound = errors.New("attachment not found") ErrVaultLocked = errors.New("vault is locked") + ErrAlreadyLoggedIn = errors.New("you are already logged in") ErrWrongMasterPassword = errors.New("invalid master password") ErrLoggedOut = errors.New("please login first") ) diff --git a/internal/provider/resource_item_login_test.go b/internal/provider/resource_item_login_test.go index bddefb5..3d05c47 100644 --- a/internal/provider/resource_item_login_test.go +++ b/internal/provider/resource_item_login_test.go @@ -50,6 +50,22 @@ func TestAccResourceItemLoginAttributes(t *testing.T) { }) } +func TestAccResourceItemLoginMany(t *testing.T) { + if !useEmbeddedClient { + t.Skip("Skipping test because using the official client to create many items is too slow") + } + ensureVaultwardenConfigured(t) + + resource.Test(t, resource.TestCase{ + ProviderFactories: providerFactories, + Steps: []resource.TestStep{ + { + Config: tfConfigPasswordManagerProvider() + tfConfigResourceItemManyLogins(), + }, + }, + }) +} + func TestAccMissingResourceItemLoginIsRecreated(t *testing.T) { ensureVaultwardenConfigured(t) @@ -101,6 +117,17 @@ func tfConfigResourceItemLoginSmall() string { ` } +func tfConfigResourceItemManyLogins() string { + return ` + resource "bitwarden_item_login" "foo" { + provider = bitwarden + + count = 100 + name = "Login Item ${count.index + 1}" + } +` +} + func tfConfigResourceItemLogin(source string) string { return fmt.Sprintf(` resource "bitwarden_item_login" "foo" { diff --git a/internal/provider/testhelper_secrets_manager_test.go b/internal/provider/testhelper_secrets_manager_test.go index a8c5305..898a344 100644 --- a/internal/provider/testhelper_secrets_manager_test.go +++ b/internal/provider/testhelper_secrets_manager_test.go @@ -8,6 +8,7 @@ import ( "fmt" "io" "net/http" + "sync" "time" "github.com/golang-jwt/jwt/v5" @@ -45,6 +46,7 @@ type testSecretsManager struct { knownOrganizations map[string]struct{} projectsStore map[string]models.Project secretsStore map[string]webapi.Secret + mutex sync.RWMutex } type Clients struct { @@ -82,7 +84,7 @@ func (tsm *testSecretsManager) Run(ctx context.Context, serverPort int) { server := &http.Server{ Handler: handler, - Addr: fmt.Sprintf(":%d", serverPort), + Addr: fmt.Sprintf("127.0.0.1:%d", serverPort), } go func() { @@ -97,6 +99,9 @@ func (tsm *testSecretsManager) Run(ctx context.Context, serverPort int) { } func (tsm *testSecretsManager) ClientCreateNewOrganization() (string, error) { + tsm.mutex.Lock() + defer tsm.mutex.Unlock() + encryptionKey, err := generateOrganizationKey() if err != nil { return "", err @@ -110,6 +115,9 @@ func (tsm *testSecretsManager) ClientCreateNewOrganization() (string, error) { } func (tsm *testSecretsManager) ClientCreateAccessToken(orgId string) (string, error) { + tsm.mutex.Lock() + defer tsm.mutex.Unlock() + orgKey, v := tsm.clientSideInformation.orgEncryptionKeys[orgId] if !v { return "", fmt.Errorf("organization not found") @@ -158,6 +166,9 @@ func (tsm *testSecretsManager) createAccessToken(orgId string, request CreateAcc } func (tsm *testSecretsManager) handlerLogin(w http.ResponseWriter, r *http.Request) { + tsm.mutex.Lock() + defer tsm.mutex.Unlock() + if err := r.ParseForm(); err != nil { http.Error(w, "Failed to parse form", http.StatusBadRequest) return @@ -215,6 +226,9 @@ func (tsm *testSecretsManager) handlerCreateGetSecret(w http.ResponseWriter, r * } func (tsm *testSecretsManager) handlerCreateProject(w http.ResponseWriter, r *http.Request) { + tsm.mutex.Lock() + defer tsm.mutex.Unlock() + orgId := mux.Vars(r)["orgId"] err := tsm.checkAuthentication(r.Header.Get("Authorization")) @@ -253,6 +267,9 @@ func (tsm *testSecretsManager) handlerCreateProject(w http.ResponseWriter, r *ht } func (tsm *testSecretsManager) handlerCreateSecret(w http.ResponseWriter, r *http.Request) { + tsm.mutex.Lock() + defer tsm.mutex.Unlock() + orgId := mux.Vars(r)["orgId"] _, v := tsm.knownOrganizations[orgId] if !v { @@ -313,6 +330,9 @@ func (tsm *testSecretsManager) handlerCreateSecret(w http.ResponseWriter, r *htt } func (tsm *testSecretsManager) handlerGetSecrets(w http.ResponseWriter, r *http.Request) { + tsm.mutex.RLock() + defer tsm.mutex.RUnlock() + orgId := mux.Vars(r)["orgId"] secretList := webapi.SecretsWithProjectsList{} @@ -341,6 +361,9 @@ func (tsm *testSecretsManager) handlerGetSecrets(w http.ResponseWriter, r *http. } func (tsm *testSecretsManager) handlerGetProject(w http.ResponseWriter, r *http.Request) { + tsm.mutex.RLock() + defer tsm.mutex.RUnlock() + err := tsm.checkAuthentication(r.Header.Get("Authorization")) if err != nil { http.Error(w, err.Error(), http.StatusUnauthorized) @@ -360,6 +383,9 @@ func (tsm *testSecretsManager) handlerGetProject(w http.ResponseWriter, r *http. } func (tsm *testSecretsManager) handlerGetSecret(w http.ResponseWriter, r *http.Request) { + tsm.mutex.RLock() + defer tsm.mutex.RUnlock() + err := tsm.checkAuthentication(r.Header.Get("Authorization")) if err != nil { http.Error(w, err.Error(), http.StatusUnauthorized) @@ -379,6 +405,9 @@ func (tsm *testSecretsManager) handlerGetSecret(w http.ResponseWriter, r *http.R } func (tsm *testSecretsManager) handlerEditProject(w http.ResponseWriter, r *http.Request) { + tsm.mutex.Lock() + defer tsm.mutex.Unlock() + err := tsm.checkAuthentication(r.Header.Get("Authorization")) if err != nil { http.Error(w, err.Error(), http.StatusUnauthorized) @@ -416,6 +445,9 @@ func (tsm *testSecretsManager) handlerEditProject(w http.ResponseWriter, r *http } func (tsm *testSecretsManager) handlerEditSecret(w http.ResponseWriter, r *http.Request) { + tsm.mutex.Lock() + defer tsm.mutex.Unlock() + err := tsm.checkAuthentication(r.Header.Get("Authorization")) if err != nil { http.Error(w, err.Error(), http.StatusUnauthorized) @@ -455,6 +487,9 @@ func (tsm *testSecretsManager) handlerEditSecret(w http.ResponseWriter, r *http. } func (tsm *testSecretsManager) handlerDeleteProject(w http.ResponseWriter, r *http.Request) { + tsm.mutex.Lock() + defer tsm.mutex.Unlock() + err := tsm.checkAuthentication(r.Header.Get("Authorization")) if err != nil { http.Error(w, err.Error(), http.StatusUnauthorized) @@ -481,6 +516,9 @@ func (tsm *testSecretsManager) handlerDeleteProject(w http.ResponseWriter, r *ht } func (tsm *testSecretsManager) handlerDeleteSecret(w http.ResponseWriter, r *http.Request) { + tsm.mutex.Lock() + defer tsm.mutex.Unlock() + err := tsm.checkAuthentication(r.Header.Get("Authorization")) if err != nil { http.Error(w, err.Error(), http.StatusUnauthorized)