Skip to content

Commit

Permalink
Make it thread safe
Browse files Browse the repository at this point in the history
  • Loading branch information
olbrichattila committed Aug 25, 2024
1 parent f281d31 commit 307f091
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 0 deletions.
20 changes: 20 additions & 0 deletions internal/container/container.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"reflect"
"strings"
"sync"
)

var (
Expand All @@ -26,22 +27,29 @@ func New() *Cont {

// Cont is the container returned by New.
type Cont struct {
mu sync.Mutex
callStack map[string]bool
dependencies map[string]interface{}
}

// Build entire dependency tree
func (t *Cont) Build(dependencies map[string]interface{}) {
t.mu.Lock()
defer t.mu.Unlock()
t.dependencies = dependencies
}

// Set registers a new dependency. Provide a "packagePath.InterfaceName" as a string and your dependency, which should be an interface or struct.
func (t *Cont) Set(paramName string, dependency interface{}) {
t.mu.Lock()
defer t.mu.Unlock()
t.dependencies[paramName] = dependency
}

// GetDependency retrieve the dependency, or returns error
func (t *Cont) GetDependency(paramName string) (interface{}, error) {
t.mu.Lock()
defer t.mu.Unlock()
if dep, ok := t.dependencies[paramName]; ok {
return dep, nil
}
Expand All @@ -51,26 +59,36 @@ func (t *Cont) GetDependency(paramName string) (interface{}, error) {

// GetDependencies returns the entire dependency map
func (t *Cont) GetDependencies() map[string]interface{} {
t.mu.Lock()
defer t.mu.Unlock()
return t.dependencies
}

// Flush dependencies
func (t *Cont) Flush() {
t.mu.Lock()
defer t.mu.Unlock()
t.dependencies = make(map[string]interface{})
}

// Delete one dependency
func (t *Cont) Delete(paramName string) {
t.mu.Lock()
defer t.mu.Unlock()
delete(t.dependencies, paramName)
}

// Count returns how any dependencies provided
func (t *Cont) Count() int {
t.mu.Lock()
defer t.mu.Unlock()
return len(t.dependencies)
}

// Get resolves dependencies. Use a construct function with your dependency interface type hints. They will be resolved recursively.
func (t *Cont) Get(obj interface{}) (interface{}, error) {
t.mu.Lock()
defer t.mu.Unlock()
t.callStack = make(map[string]bool)
return t.getRecursive(obj)
}
Expand Down Expand Up @@ -185,6 +203,8 @@ func (t *Cont) resolveFunctionParam(paramType reflect.Type) (interface{}, string

// Call can invoke a function auto resolving dependencies and passing optional extra parameters at the beginning
func (t *Cont) Call(fn interface{}, params ...interface{}) ([]reflect.Value, error) {
t.mu.Lock()
defer t.mu.Unlock()
t.callStack = make(map[string]bool)
method := reflect.ValueOf(fn)
fnType := reflect.TypeOf(fn)
Expand Down
56 changes: 56 additions & 0 deletions internal/test/container/multiple-nesting_in_concurrency_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package containertest

import (
"sync"
"testing"

godi "github.com/olbrichattila/godi"
"github.com/stretchr/testify/suite"
)

type ConcurrentNestingTestSuite struct {
suite.Suite
}

func TestRunnerConcurrentNesting(t *testing.T) {
suite.Run(t, new(ConcurrentNestingTestSuite))
}

func (t *ConcurrentNestingTestSuite) TestMultipleNesting() {
container := godi.New()

nestedFirstMock := newNestedFirstMock()
nestedSecondMock := newNestedSecondMock()
nestedThirdMock := newNestedThirdMock()

container.Set("olbrichattila.godi.internal.test.container.nestedSecondInterface", nestedSecondMock)
container.Set("olbrichattila.godi.internal.test.container.nestedThirdInterface", nestedThirdMock)

var wg sync.WaitGroup
callCount := 15
wg.Add(callCount)
for i := 0; i < callCount; i++ {
go func() {
defer wg.Done()
_, err := container.Get(nestedFirstMock)
t.Nil(err)
}()
}

wg.Wait()

t.Equal(15, nestedFirstMock.ConstructorCallCount())
t.Equal(15, nestedSecondMock.ConstructorCallCount())
t.Equal(15, nestedThirdMock.ConstructorCallCount())

t.Equal(0, nestedFirstMock.MockFuncCallCount())
t.Equal(0, nestedSecondMock.MockFuncCallCount())
t.Equal(0, nestedThirdMock.MockFuncCallCount())

// Assert functions are relying on their dependencies and calling each other
nestedFirstMock.MockFunc()

t.Equal(1, nestedFirstMock.MockFuncCallCount())
t.Equal(1, nestedSecondMock.MockFuncCallCount())
t.Equal(1, nestedThirdMock.MockFuncCallCount())
}

0 comments on commit 307f091

Please sign in to comment.