diff --git a/claim/claimstore.go b/claim/claimstore.go index 6e36703a..a71dff39 100644 --- a/claim/claimstore.go +++ b/claim/claimstore.go @@ -107,6 +107,7 @@ func (s Store) ListResults(claimID string) ([]string, error) { return nil, err } + sort.Strings(results) return results, nil } @@ -125,11 +126,17 @@ func (s Store) ListOutputs(resultID string) ([]string, error) { for i, fullName := range outputNames { outputNames[i] = strings.TrimLeft(fullName, resultID+"-") } - + sort.Strings(outputNames) return outputNames, nil } func (s Store) ReadInstallation(installation string) (Installation, error) { + handleClose, err := s.backingStore.HandleConnect() + defer handleClose() + if err != nil { + return Installation{}, err + } + claims, err := s.ReadAllClaims(installation) if err != nil { return Installation{}, err @@ -153,6 +160,12 @@ func (s Store) ReadInstallation(installation string) (Installation, error) { } func (s Store) ReadInstallationStatus(installation string) (Installation, error) { + handleClose, err := s.backingStore.HandleConnect() + defer handleClose() + if err != nil { + return Installation{}, err + } + claimIds, err := s.ListClaims(installation) if err != nil { return Installation{}, err @@ -191,6 +204,12 @@ func (s Store) ReadInstallationStatus(installation string) (Installation, error) } func (s Store) ReadAllInstallationStatus() ([]Installation, error) { + handleClose, err := s.backingStore.HandleConnect() + defer handleClose() + if err != nil { + return nil, err + } + names, err := s.ListInstallations() if err != nil { return nil, err @@ -254,6 +273,12 @@ func (s Store) ReadAllClaims(installation string) ([]Claim, error) { } func (s Store) ReadLastClaim(installation string) (Claim, error) { + handleClose, err := s.backingStore.HandleConnect() + defer handleClose() + if err != nil { + return Claim{}, err + } + claimIds, err := s.backingStore.List(ItemTypeClaims, installation) if err != nil { return Claim{}, s.handleNotExistsError(err, ErrInstallationNotFound) @@ -306,12 +331,24 @@ func (s Store) ReadAllResults(claimID string) ([]Result, error) { // ReadLastOutputs returns the most recent (last) value of each output associated // with the installation. func (s Store) ReadLastOutputs(installation string) (Outputs, error) { + handleClose, err := s.backingStore.HandleConnect() + defer handleClose() + if err != nil { + return Outputs{}, err + } + return s.readLastOutputs(installation, "") } // ReadLastOutput returns the most recent value (last) of the specified Output associated // with the installation. func (s Store) ReadLastOutput(installation string, name string) (Output, error) { + handleClose, err := s.backingStore.HandleConnect() + defer handleClose() + if err != nil { + return Output{}, err + } + outputs, err := s.readLastOutputs(installation, name) if err != nil { return Output{}, err @@ -380,6 +417,12 @@ func (s Store) readLastOutputs(installation string, filterOutput string) (Output } func (s Store) ReadLastResult(claimID string) (Result, error) { + handleClose, err := s.backingStore.HandleConnect() + defer handleClose() + if err != nil { + return Result{}, err + } + resultIDs, err := s.backingStore.List(ItemTypeResults, claimID) if err != nil { return Result{}, s.handleNotExistsError(err, ErrClaimNotFound) @@ -417,6 +460,12 @@ func (s Store) ReadOutput(c Claim, r Result, outputName string) (Output, error) } func (s Store) SaveClaim(c Claim) error { + handleClose, err := s.backingStore.HandleConnect() + defer handleClose() + if err != nil { + return err + } + bytes, err := json.MarshalIndent(c, "", " ") if err != nil { return err @@ -466,6 +515,12 @@ func (s Store) SaveOutput(o Output) error { } func (s Store) DeleteInstallation(installation string) error { + handleClose, err := s.backingStore.HandleConnect() + defer handleClose() + if err != nil { + return err + } + claimIds, err := s.ListClaims(installation) if err != nil { return err @@ -483,6 +538,12 @@ func (s Store) DeleteInstallation(installation string) error { } func (s Store) DeleteClaim(claimID string) error { + handleClose, err := s.backingStore.HandleConnect() + defer handleClose() + if err != nil { + return err + } + resultIds, err := s.ListResults(claimID) if err != nil { return err @@ -500,6 +561,12 @@ func (s Store) DeleteClaim(claimID string) error { } func (s Store) DeleteResult(resultID string) error { + handleClose, err := s.backingStore.HandleConnect() + defer handleClose() + if err != nil { + return err + } + outputNames, err := s.ListOutputs(resultID) if err != nil { return err diff --git a/claim/claimstore_test.go b/claim/claimstore_test.go index 8e090e04..504d3248 100644 --- a/claim/claimstore_test.go +++ b/claim/claimstore_test.go @@ -66,13 +66,8 @@ var b64decode = func(src []byte) ([]byte, error) { // RESULT_ID_2/ // RESULT_ID_2_OUTPUT_1 // RESULT_ID_2_OUTPUT_2 -func generateClaimData(t *testing.T) (Provider, func() error) { - tempDir, err := ioutil.TempDir("", "cnabtest") - require.NoError(t, err, "Failed to create temp dir") - cleanup := func() error { return os.RemoveAll(tempDir) } - - storeDir := filepath.Join(tempDir, "claimstore") - backingStore := crud.NewFileSystemStore(storeDir, NewClaimStoreFileExtensions()) +func generateClaimData(t *testing.T) (Provider, crud.MockStore) { + backingStore := crud.NewMockStore() cp := NewClaimStore(backingStore, nil, nil) bun := bundle.Bundle{ @@ -117,7 +112,7 @@ func generateClaimData(t *testing.T) (Provider, func() error) { createOutput := func(c Claim, r Result, name string) Output { o := NewOutput(c, r, name, []byte(c.Action+" "+name)) - err = cp.SaveOutput(o) + err := cp.SaveOutput(o) require.NoError(t, err, "SaveOutput failed") return o @@ -153,7 +148,20 @@ func generateClaimData(t *testing.T) (Provider, func() error) { createClaim(baz, ActionInstall) - return cp, cleanup + backingStore.ResetCounts() + return cp, backingStore +} + +func assertSingleConnection(t *testing.T, datastore crud.MockStore) { + t.Helper() + + connects, err := datastore.GetConnectCount() + require.NoError(t, err, "GetConnectCount failed") + assert.Equal(t, 1, connects, "expected a single connect") + + closes, err := datastore.GetCloseCount() + require.NoError(t, err, "GetCloseCount failed") + assert.Equal(t, 1, closes, "expected a single close") } func TestCanSaveReadAndDelete(t *testing.T) { @@ -230,18 +238,21 @@ func TestCanUpdate(t *testing.T) { } func TestClaimStore_Installations(t *testing.T) { - cp, cleanup := generateClaimData(t) - defer cleanup() + cp, datastore := generateClaimData(t) t.Run("ListInstallations", func(t *testing.T) { + datastore.ResetCounts() installations, err := cp.ListInstallations() require.NoError(t, err, "ListInstallations failed") require.Len(t, installations, 3, "Expected 3 installations") assert.Equal(t, []string{"bar", "baz", "foo"}, installations) + + assertSingleConnection(t, datastore) }) t.Run("ReadAllInstallationStatus", func(t *testing.T) { + datastore.ResetCounts() installations, err := cp.ReadAllInstallationStatus() require.NoError(t, err, "ReadAllInstallationStatus failed") @@ -254,9 +265,12 @@ func TestClaimStore_Installations(t *testing.T) { assert.Equal(t, "bar", bar.Name) assert.Equal(t, "baz", baz.Name) assert.Equal(t, "foo", foo.Name) + + assertSingleConnection(t, datastore) }) t.Run("ReadInstallationStatus", func(t *testing.T) { + datastore.ResetCounts() foo, err := cp.ReadInstallationStatus("foo") require.NoError(t, err, "ReadInstallationStatus failed") @@ -267,6 +281,8 @@ func TestClaimStore_Installations(t *testing.T) { lastClaim, err := foo.GetLastClaim() require.NoError(t, err, "GetLastClaim failed") assert.Equal(t, ActionUninstall, lastClaim.Action) + + assertSingleConnection(t, datastore) }) t.Run("ReadInstallationStatus - invalid installation", func(t *testing.T) { @@ -276,6 +292,7 @@ func TestClaimStore_Installations(t *testing.T) { }) t.Run("ReadInstallation", func(t *testing.T) { + datastore.ResetCounts() foo, err := cp.ReadInstallation("foo") require.NoError(t, err, "ReadInstallation failed") @@ -287,6 +304,8 @@ func TestClaimStore_Installations(t *testing.T) { assert.Equal(t, ActionUpgrade, foo.Claims[1].Action) assert.Equal(t, "test", foo.Claims[2].Action) assert.Equal(t, ActionUninstall, foo.Claims[3].Action) + + assertSingleConnection(t, datastore) }) t.Run("ReadInstallation - invalid installation", func(t *testing.T) { @@ -297,12 +316,13 @@ func TestClaimStore_Installations(t *testing.T) { } func TestClaimStore_DeleteInstallation(t *testing.T) { - cp, cleanup := generateClaimData(t) - defer cleanup() + cp, datastore := generateClaimData(t) err := cp.DeleteInstallation("foo") require.NoError(t, err, "DeleteInstallation failed") + assertSingleConnection(t, datastore) + names, err := cp.ListInstallations() require.NoError(t, err, "ListInstallations failed") assert.Equal(t, []string{"bar", "baz"}, names, "expected foo to be deleted completely") @@ -312,10 +332,10 @@ func TestClaimStore_DeleteInstallation(t *testing.T) { } func TestClaimStore_Claims(t *testing.T) { - cp, cleanup := generateClaimData(t) - defer cleanup() + cp, datastore := generateClaimData(t) t.Run("ReadAllClaims", func(t *testing.T) { + datastore.ResetCounts() claims, err := cp.ReadAllClaims("foo") require.NoError(t, err, "Failed to read claims: %s", err) @@ -324,6 +344,8 @@ func TestClaimStore_Claims(t *testing.T) { assert.Equal(t, ActionUpgrade, claims[1].Action) assert.Equal(t, "test", claims[2].Action) assert.Equal(t, ActionUninstall, claims[3].Action) + + assertSingleConnection(t, datastore) }) t.Run("ReadAllClaims - invalid installation", func(t *testing.T) { @@ -333,10 +355,13 @@ func TestClaimStore_Claims(t *testing.T) { }) t.Run("ListClaims", func(t *testing.T) { + datastore.ResetCounts() claims, err := cp.ListClaims("foo") require.NoError(t, err, "Failed to read claims: %s", err) require.Len(t, claims, 4, "Expected 4 claims") + + assertSingleConnection(t, datastore) }) t.Run("ListClaims - invalid installation", func(t *testing.T) { @@ -352,11 +377,14 @@ func TestClaimStore_Claims(t *testing.T) { assert.NotEmpty(t, claims, "no claims were found") claimID := claims[0] + datastore.ResetCounts() c, err := cp.ReadClaim(claimID) require.NoError(t, err, "ReadClaim failed") assert.Equal(t, "foo", c.Installation) assert.Equal(t, ActionInstall, c.Action) + + assertSingleConnection(t, datastore) }) t.Run("ReadClaim - invalid claim", func(t *testing.T) { @@ -365,11 +393,14 @@ func TestClaimStore_Claims(t *testing.T) { }) t.Run("ReadLastClaim", func(t *testing.T) { + datastore.ResetCounts() c, err := cp.ReadLastClaim("bar") require.NoError(t, err, "ReadLastClaim failed") assert.Equal(t, "bar", c.Installation) assert.Equal(t, ActionInstall, c.Action) + + assertSingleConnection(t, datastore) }) t.Run("ReadLastClaim - invalid installation", func(t *testing.T) { @@ -380,8 +411,7 @@ func TestClaimStore_Claims(t *testing.T) { } func TestClaimStore_Results(t *testing.T) { - cp, cleanup := generateClaimData(t) - defer cleanup() + cp, datastore := generateClaimData(t) barClaims, err := cp.ListClaims("bar") require.NoError(t, err, "ListClaims failed") @@ -394,9 +424,13 @@ func TestClaimStore_Results(t *testing.T) { unfinishedClaimID := bazClaims[1] // this claim doesn't have any results yet t.Run("ListResults", func(t *testing.T) { + datastore.ResetCounts() + results, err := cp.ListResults(claimID) require.NoError(t, err, "ListResults failed") assert.Len(t, results, 2, "expected 2 results") + + assertSingleConnection(t, datastore) }) t.Run("ListResults - unfinished claim", func(t *testing.T) { @@ -406,12 +440,16 @@ func TestClaimStore_Results(t *testing.T) { }) t.Run("ReadAllResults", func(t *testing.T) { + datastore.ResetCounts() + results, err := cp.ReadAllResults(claimID) require.NoError(t, err, "ReadAllResults failed") assert.Len(t, results, 2, "expected 2 results") assert.Equal(t, StatusRunning, results[0].Status) assert.Equal(t, StatusSucceeded, results[1].Status) + + assertSingleConnection(t, datastore) }) t.Run("ReadAllResults - unfinished claim", func(t *testing.T) { @@ -421,10 +459,14 @@ func TestClaimStore_Results(t *testing.T) { }) t.Run("ReadLastResult", func(t *testing.T) { + datastore.ResetCounts() + r, err := cp.ReadLastResult(claimID) require.NoError(t, err, "ReadLastResult failed") assert.Equal(t, StatusSucceeded, r.Status) + + assertSingleConnection(t, datastore) }) t.Run("ReadLastResult - unfinished claim", func(t *testing.T) { @@ -439,10 +481,13 @@ func TestClaimStore_Results(t *testing.T) { resultID := results[0] + datastore.ResetCounts() r, err := cp.ReadResult(resultID) require.NoError(t, err, "ReadResult failed") assert.Equal(t, StatusRunning, r.Status) + + assertSingleConnection(t, datastore) }) t.Run("ReadResult - invalid result", func(t *testing.T) { @@ -453,8 +498,7 @@ func TestClaimStore_Results(t *testing.T) { } func TestClaimStore_Outputs(t *testing.T) { - cp, cleanup := generateClaimData(t) - defer cleanup() + cp, datastore := generateClaimData(t) fooClaims, err := cp.ReadAllClaims("foo") require.NoError(t, err, "ReadAllClaims failed") @@ -477,12 +521,15 @@ func TestClaimStore_Outputs(t *testing.T) { resultIDWithoutOutputs := barResult.ID t.Run("ListOutputs", func(t *testing.T) { + datastore.ResetCounts() outputs, err := cp.ListOutputs(resultID) require.NoError(t, err, "ListResults failed") assert.Len(t, outputs, 2, "expected 2 outputs") assert.Equal(t, "output1", outputs[0]) assert.Equal(t, "output2", outputs[1]) + + assertSingleConnection(t, datastore) }) t.Run("ListOutputs - no outputs", func(t *testing.T) { @@ -492,6 +539,7 @@ func TestClaimStore_Outputs(t *testing.T) { }) t.Run("ReadLastOutputs", func(t *testing.T) { + datastore.ResetCounts() outputs, err := cp.ReadLastOutputs("foo") require.NoError(t, err, "GetLastOutputs failed") @@ -504,6 +552,8 @@ func TestClaimStore_Outputs(t *testing.T) { gotOutput2, hasOutput2 := outputs.GetByName("output2") assert.True(t, hasOutput2, "should have found output2") assert.Equal(t, "upgrade output2", string(gotOutput2.Value), "did not find the most recent value for output2") + + assertSingleConnection(t, datastore) }) t.Run("ReadLastOutputs - invalid installation", func(t *testing.T) { @@ -513,10 +563,13 @@ func TestClaimStore_Outputs(t *testing.T) { }) t.Run("ReadLastOutput", func(t *testing.T) { + datastore.ResetCounts() o, err := cp.ReadLastOutput("foo", "output1") require.NoError(t, err, "GetLastOutputs failed") assert.Equal(t, "upgrade output1", string(o.Value), "did not find the most recent value for output1") + + assertSingleConnection(t, datastore) }) t.Run("ReadLastOutput - invalid installation", func(t *testing.T) { @@ -531,6 +584,8 @@ func TestClaimStore_Outputs(t *testing.T) { installResult, err := cp.ReadLastResult(installClaim.ID) require.NoError(t, err, "ReadLastResult failed") + datastore.ResetCounts() + o, err := cp.ReadOutput(installClaim, installResult, "output1") require.NoError(t, err, "ReadOutput failed") @@ -538,6 +593,8 @@ func TestClaimStore_Outputs(t *testing.T) { assert.Equal(t, installResult.ID, o.result.ID, "output.Result is not set") assert.Equal(t, installClaim.ID, o.result.claim.ID, "output.Result.Claim is not set") assert.Equal(t, "install output1", string(o.Value)) + + assertSingleConnection(t, datastore) }) t.Run("ReadOutput - no outputs", func(t *testing.T) { diff --git a/utils/crud/backingstore.go b/utils/crud/backingstore.go index f2562902..4845f8a4 100644 --- a/utils/crud/backingstore.go +++ b/utils/crud/backingstore.go @@ -70,33 +70,30 @@ func (s *BackingStore) autoClose() error { } func (s *BackingStore) List(itemType string, group string) ([]string, error) { - if s.shouldAutoConnect() { - defer s.autoClose() - if err := s.Connect(); err != nil { - return nil, err - } + handleClose, err := s.HandleConnect() + defer handleClose() + if err != nil { + return nil, err } return s.backingStore.List(itemType, group) } func (s *BackingStore) Save(itemType string, group string, name string, data []byte) error { - if s.shouldAutoConnect() { - defer s.autoClose() - if err := s.Connect(); err != nil { - return err - } + handleClose, err := s.HandleConnect() + defer handleClose() + if err != nil { + return err } return s.backingStore.Save(itemType, group, name, data) } func (s *BackingStore) Read(itemType string, name string) ([]byte, error) { - if s.shouldAutoConnect() { - defer s.autoClose() - if err := s.Connect(); err != nil { - return nil, err - } + handleClose, err := s.HandleConnect() + defer handleClose() + if err != nil { + return nil, err } return s.backingStore.Read(itemType, name) @@ -104,11 +101,10 @@ func (s *BackingStore) Read(itemType string, name string) ([]byte, error) { // ReadAll retrieves all the items with the specified prefix func (s *BackingStore) ReadAll(itemType string, group string) ([][]byte, error) { - if s.shouldAutoConnect() { - defer s.autoClose() - if err := s.Connect(); err != nil { - return nil, err - } + handleClose, err := s.HandleConnect() + defer handleClose() + if err != nil { + return nil, err } results := make([][]byte, 0) @@ -129,11 +125,10 @@ func (s *BackingStore) ReadAll(itemType string, group string) ([][]byte, error) } func (s *BackingStore) Delete(itemType string, name string) error { - if s.shouldAutoConnect() { - defer s.autoClose() - if err := s.Connect(); err != nil { - return err - } + handleClose, err := s.HandleConnect() + defer handleClose() + if err != nil { + return err } return s.backingStore.Delete(itemType, name) @@ -144,3 +139,13 @@ func (s *BackingStore) shouldAutoConnect() bool { // caller manage the connection. return !s.opened && s.connect != nil } + +func (s *BackingStore) HandleConnect() (func() error, error) { + if s.shouldAutoConnect() { + err := s.Connect() + return s.autoClose, err + } + + // Return a no-op close function + return func() error { return nil }, nil +} diff --git a/utils/crud/backingstore_test.go b/utils/crud/backingstore_test.go index dcabd7a9..c4f0ac5f 100644 --- a/utils/crud/backingstore_test.go +++ b/utils/crud/backingstore_test.go @@ -19,10 +19,8 @@ func TestBackingStore_Read(t *testing.T) { for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { s := NewMockStore() - s.groups[testItemType] = map[string][]string{ - testGroup: {"key1"}, - } - s.data[testItemType] = map[string][]byte{"key1": []byte("value1")} + s.Save(testItemType, testGroup, "key1", []byte("value1")) + s.ResetCounts() bs := NewBackingStore(s) bs.AutoClose = tc.autoclose @@ -110,13 +108,9 @@ func TestBackingStore_List(t *testing.T) { for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { s := NewMockStore() - s.groups[testItemType] = map[string][]string{ - testGroup: {"key1"}, - } - s.data[testItemType] = map[string][]byte{ - "key1": []byte("value1"), - "key2": []byte("value2"), - } + s.Save(testItemType, testGroup, "key1", []byte("value1")) + s.Save(testItemType, "", "key2", []byte("value2")) + s.ResetCounts() bs := NewBackingStore(s) bs.AutoClose = tc.autoclose @@ -152,10 +146,8 @@ func TestBackingStore_Delete(t *testing.T) { for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { s := NewMockStore() - s.groups[testItemType] = map[string][]string{ - testGroup: {"key1"}, - } - s.data[testItemType] = map[string][]byte{"key1": []byte("value1")} + s.Save(testItemType, testGroup, "key1", []byte("value1")) + s.ResetCounts() bs := NewBackingStore(s) bs.AutoClose = tc.autoclose @@ -192,13 +184,9 @@ func TestBackingStore_ReadAll(t *testing.T) { for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { s := NewMockStore() - s.groups[testItemType] = map[string][]string{ - testGroup: {"key1"}, - } - s.data[testItemType] = map[string][]byte{ - "key1": []byte("value1"), - "key2": []byte("value2"), - } + s.Save(testItemType, testGroup, "key1", []byte("value1")) + s.Save(testItemType, "", "key2", []byte("value2")) + bs := NewBackingStore(s) bs.AutoClose = tc.autoclose diff --git a/utils/crud/mock_store.go b/utils/crud/mock_store.go index c5f76ac1..e816792e 100644 --- a/utils/crud/mock_store.go +++ b/utils/crud/mock_store.go @@ -2,6 +2,7 @@ package crud import ( "fmt" + "path" "strconv" ) @@ -9,6 +10,16 @@ import ( // changes. But we also provide a mock for testing. var _ Store = MockStore{} +type item struct { + itemType, group, name string + data []byte +} + +type itemGroup struct { + itemType, group string + items map[string]struct{} +} + const ( connectCount = "connect-count" closeCount = "close-count" @@ -22,11 +33,11 @@ const ( type MockStore struct { // data stores the mocked data // itemType -> name -> data - data map[string]map[string][]byte + data map[string]*item // groups stores the groupings applied to the mocked data - // itemType -> group -> list of names - groups map[string]map[string][]string + // itemType -> group -> list of keys + groups map[string]*itemGroup // DeleteMock replaces the default Delete implementation with the specified function. // This allows for simulating failures. @@ -47,69 +58,49 @@ type MockStore struct { func NewMockStore() MockStore { return MockStore{ - groups: map[string]map[string][]string{}, - data: map[string]map[string][]byte{}, + groups: map[string]*itemGroup{}, + data: map[string]*item{}, } } func (s MockStore) Connect() error { - _, ok := s.data[mockStoreType] - if !ok { - s.data[mockStoreType] = make(map[string][]byte, 1) - } - // Keep track of Connect calls for test asserts later count, err := s.GetConnectCount() if err != nil { return err } - - s.data[mockStoreType][connectCount] = []byte(strconv.Itoa(count + 1)) + s.setCount(connectCount, count+1) return nil } func (s MockStore) Close() error { - _, ok := s.data[mockStoreType] - if !ok { - s.data[mockStoreType] = make(map[string][]byte, 1) - } - // Keep track of Close calls for test asserts later count, err := s.GetCloseCount() if err != nil { return err } - - s.data[mockStoreType][closeCount] = []byte(strconv.Itoa(count + 1)) + s.setCount(closeCount, count+1) return nil } +func (s MockStore) key(itemType string, id string) string { + return path.Join(itemType, id) +} + func (s MockStore) List(itemType string, group string) ([]string, error) { if s.ListMock != nil { return s.ListMock(itemType, group) } - if groups, ok := s.groups[itemType]; ok { - if names, ok := groups[group]; ok { - buf := make([]string, len(names)) - i := 0 - for _, name := range names { - buf[i] = name - i++ - } - return buf, nil - } - - if group == "" { - // List all the groups, e.g. if we were listing claims, this would list the installation names - names := make([]string, 0, len(groups)) - for groupName := range groups { - names = append(names, groupName) - } - return names, nil + // List all items in a group, e.g. claims in an installation + if g, ok := s.groups[s.key(itemType, group)]; ok { + names := make([]string, 0, len(g.items)) + for name := range g.items { + names = append(names, name) } + return names, nil } return nil, nil @@ -120,22 +111,25 @@ func (s MockStore) Save(itemType string, group string, name string, data []byte) return s.SaveMock(itemType, name, data) } - groupNames, ok := s.groups[itemType] + g, ok := s.groups[s.key(itemType, group)] if !ok { - groupNames = map[string][]string{ - group: make([]string, 0, 1), + g = &itemGroup{ + group: group, + itemType: itemType, + items: make(map[string]struct{}, 1), } - s.groups[itemType] = groupNames + s.groups[s.key(itemType, group)] = g } - groupNames[group] = append(groupNames[group], name) + g.items[name] = struct{}{} - itemData, ok := s.data[itemType] - if !ok { - itemData = make(map[string][]byte, 1) - s.data[itemType] = itemData + i := &item{ + itemType: itemType, + group: group, + name: name, + data: data, } + s.data[s.key(itemType, name)] = i - itemData[name] = data return nil } @@ -144,10 +138,8 @@ func (s MockStore) Read(itemType string, name string) ([]byte, error) { return s.ReadMock(itemType, name) } - if itemData, ok := s.data[itemType]; ok { - if data, ok := itemData[name]; ok { - return data, nil - } + if i, ok := s.data[s.key(itemType, name)]; ok { + return i.data, nil } return nil, ErrRecordDoesNotExist @@ -158,11 +150,16 @@ func (s MockStore) Delete(itemType string, name string) error { return s.DeleteMock(itemType, name) } - if itemData, ok := s.data[itemType]; ok { - if _, ok := itemData[name]; ok { - delete(itemData, name) - return nil + if i, ok := s.data[s.key(itemType, name)]; ok { + delete(s.data, s.key(itemType, name)) + + if g, ok := s.groups[s.key(itemType, i.group)]; ok { + delete(g.items, i.name) + if len(g.items) == 0 { + delete(s.groups, s.key(itemType, i.group)) + } } + return nil } return ErrRecordDoesNotExist @@ -171,14 +168,14 @@ func (s MockStore) Delete(itemType string, name string) error { // GetConnectCount is for tests to safely read the Connect call count // without accidentally triggering it by using Read. func (s MockStore) GetConnectCount() (int, error) { - countB, ok := s.data[mockStoreType][connectCount] + countB, ok := s.data[s.key(mockStoreType, connectCount)] if !ok { - countB = []byte("0") + return 0, nil } - count, err := strconv.Atoi(string(countB)) + count, err := strconv.Atoi(string(countB.data)) if err != nil { - return 0, fmt.Errorf("could not convert connect-count %s to int: %v", string(countB), err) + return 0, fmt.Errorf("could not convert connect-count %s to int: %v", string(countB.data), err) } return count, nil @@ -187,15 +184,28 @@ func (s MockStore) GetConnectCount() (int, error) { // GetCloseCount is for tests to safely read the Close call count // without accidentally triggering it by using Read. func (s MockStore) GetCloseCount() (int, error) { - countB, ok := s.data[mockStoreType][closeCount] + countB, ok := s.data[s.key(mockStoreType, closeCount)] if !ok { - countB = []byte("0") + return 0, nil } - count, err := strconv.Atoi(string(countB)) + count, err := strconv.Atoi(string(countB.data)) if err != nil { - return 0, fmt.Errorf("could not convert close-count %s to int: %v", string(countB), err) + return 0, fmt.Errorf("could not convert close-count %s to int: %v", string(countB.data), err) } return count, nil } + +func (s MockStore) ResetCounts() { + s.setCount(connectCount, 0) + s.setCount(closeCount, 0) +} + +func (s MockStore) setCount(count string, value int) { + s.data[path.Join(mockStoreType, count)] = &item{ + itemType: mockStoreType, + name: count, + data: []byte(strconv.Itoa(value)), + } +}