diff --git a/dig.go b/dig.go index dc3c2363..18baf2f9 100644 --- a/dig.go +++ b/dig.go @@ -129,6 +129,60 @@ func Group(group string) ProvideOption { }) } +// GroupInvoke runs the given function after instantiating its dependencies within multiple containers +func GroupInvoke(function interface{}, containers ...*Container) error { + arguments, err := resolve(function, containers...) + if err != nil { + return err + } + + reflect.ValueOf(function).Call(arguments) + + return nil +} + +func resolve(function interface{}, containers ...*Container) ([]reflect.Value, error) { + ftype := reflect.TypeOf(function) + arguments := ftype.NumIn() + result := make([]reflect.Value, arguments) + if ftype == nil { + return nil, errors.New("can't invoke an untyped nil") + } + if ftype.Kind() != reflect.Func { + return nil, errf("can't invoke non-function %v (type %v)", function, ftype) + } + + pl, err := newParamList(ftype) + if err != nil { + return nil, err + } + + for _, c := range containers { + if !c.isVerifiedAcyclic { + if err := c.verifyAcyclic(); err != nil { + return nil, err + } + } + + args, err := pl.UnsafeBuildList(c) + if err != nil { + return nil, err + } + + for i, a := range args { + if a.IsValid() { + result[i] = reflect.ValueOf(a.Interface()) + } + } + } + + if len(result) != arguments { + return nil, errors.New("parameters count does not match") + } + + return result, nil +} + // ID is a unique integer representing the constructor node in the dependency graph. type ID int diff --git a/dig_test.go b/dig_test.go index a907f619..277af74b 100644 --- a/dig_test.go +++ b/dig_test.go @@ -3112,3 +3112,47 @@ func TestProvideInfoOption(t *testing.T) { assert.Equal(t, "*dig.type4", info2.Outputs[0].String()) }) } + +func TestGroupInvoke(t *testing.T) { + type TestParam struct { + Name string + Value string + } + + type TestParam1 struct { + AdditionaInfo string + } + + singletonIOC := New() + singletonIOC.Provide(func() *TestParam { + return &TestParam{ + Name: "TestName", + Value: "TestValue", + } + }) + + customIOC := New() + customIOC.Provide(func() *TestParam1 { + return &TestParam1{ + AdditionaInfo: "Some info", + } + }) + + function := func(p *TestParam, p1 *TestParam1) { + res1 := &TestParam{ + Name: "TestName", + Value: "TestValue", + } + + res2 := &TestParam1{ + AdditionaInfo: "Some info", + } + + assert.Equal(t, res1, p) + assert.Equal(t, res2, p1) + } + + if err := GroupInvoke(function, singletonIOC, customIOC); err != nil { + assert.FailNow(t, err.Error()) + } +} diff --git a/param.go b/param.go index 0979228a..e8def84f 100644 --- a/param.go +++ b/param.go @@ -206,6 +206,20 @@ func (pl paramList) BuildList(c containerStore) ([]reflect.Value, error) { return args, nil } +// UnsafeBuildList returns an ordered list of values which may be passed directly +// to the underlying constructor without interruption in case of missing field. +func (pl paramList) UnsafeBuildList(c containerStore) ([]reflect.Value, error) { + args := make([]reflect.Value, len(pl.Params)) + for i, p := range pl.Params { + var err error + args[i], err = p.Build(c) + if err != nil { + continue + } + } + return args, nil +} + // paramSingle is an explicitly requested type, optionally with a name. // // This object must be present in the graph as-is unless it's specified as