diff --git a/callback.go b/callback.go index dfe47ea6..106a4591 100644 --- a/callback.go +++ b/callback.go @@ -20,6 +20,8 @@ package dig +import "reflect" + // CallbackInfo contains information about a provided function or decorator // called by Dig, and is passed to a [Callback] registered with // [WithProviderCallback] or [WithDecoratorCallback]. @@ -32,6 +34,11 @@ type CallbackInfo struct { // function, if any. When used in conjunction with [RecoverFromPanics], // this will be set to a [PanicError] when the function panics. Error error + + // Values contains all values constructed by the [Callback]'s + // associated function. These are the actual values inside the container: + // modifying them may result in undefined behaviour. + Values []reflect.Value } // Callback is a function that can be registered with a provided function diff --git a/constructor.go b/constructor.go index 034c41c2..7e75f8bf 100644 --- a/constructor.go +++ b/constructor.go @@ -160,12 +160,15 @@ func (n *constructorNode) Call(c containerStore) (err error) { } } + var values []reflect.Value + if n.callback != nil { // Wrap in separate func to include PanicErrors defer func() { n.callback(CallbackInfo{ - Name: fmt.Sprintf("%v.%v", n.location.Package, n.location.Name), - Error: err, + Name: fmt.Sprintf("%v.%v", n.location.Package, n.location.Name), + Error: err, + Values: values, }) }() } @@ -192,6 +195,9 @@ func (n *constructorNode) Call(c containerStore) (err error) { // the rest of the graph to instantiate the dependencies of this // container. receiver.Commit(n.s) + + values = n.resultList.Values(results) + n.called = true return nil } diff --git a/decorate.go b/decorate.go index df362e98..80d37421 100644 --- a/decorate.go +++ b/decorate.go @@ -121,12 +121,15 @@ func (n *decoratorNode) Call(s containerStore) (err error) { } } + var values []reflect.Value + if n.callback != nil { // Wrap in separate func to include PanicErrors defer func() { n.callback(CallbackInfo{ - Name: fmt.Sprintf("%v.%v", n.location.Package, n.location.Name), - Error: err, + Name: fmt.Sprintf("%v.%v", n.location.Package, n.location.Name), + Error: err, + Values: values, }) }() } @@ -146,6 +149,8 @@ func (n *decoratorNode) Call(s containerStore) (err error) { if err = n.results.ExtractList(n.s, true /* decorated */, results); err != nil { return err } + values = n.results.Values(results) + n.state = decoratorCalled return nil } diff --git a/dig_test.go b/dig_test.go index 5cbf4ee8..f74a7cf4 100644 --- a/dig_test.go +++ b/dig_test.go @@ -1685,6 +1685,14 @@ func TestRecoverFromPanic(t *testing.T) { func giveInt() int { return 5 } +type providedInterface interface { + Start() +} + +type providedStruct struct{} + +func (providedStruct) Start() {} + func TestCallback(t *testing.T) { t.Run("no errors", func(t *testing.T) { var ( @@ -1698,6 +1706,9 @@ func TestCallback(t *testing.T) { dig.WithProviderCallback(func(ci dig.CallbackInfo) { assert.Equal(t, "go.uber.org/dig_test.giveInt", ci.Name) assert.NoError(t, ci.Error) + assert.Len(t, ci.Values, 1) + assert.True(t, ci.Values[0].CanInt()) + assert.EqualValues(t, 5, ci.Values[0].Int()) provideCallbackCalled = true }), ) @@ -1706,6 +1717,9 @@ func TestCallback(t *testing.T) { dig.WithDecoratorCallback(func(ci dig.CallbackInfo) { assert.Equal(t, "go.uber.org/dig_test.TestCallback.func1.2", ci.Name) assert.NoError(t, ci.Error) + assert.Len(t, ci.Values, 1) + assert.True(t, ci.Values[0].CanInt()) + assert.EqualValues(t, 10, ci.Values[0].Int()) decorateCallbackCalled = true }), ) @@ -1727,6 +1741,7 @@ func TestCallback(t *testing.T) { dig.WithProviderCallback(func(ci dig.CallbackInfo) { assert.Equal(t, "go.uber.org/dig_test.TestCallback.func2.1", ci.Name) assert.ErrorContains(t, ci.Error, "terrible callback sadness") + assert.Nil(t, ci.Values) called = true }), ) @@ -1747,6 +1762,7 @@ func TestCallback(t *testing.T) { dig.WithDecoratorCallback(func(ci dig.CallbackInfo) { assert.Equal(t, "go.uber.org/dig_test.TestCallback.func3.1", ci.Name) assert.ErrorContains(t, ci.Error, "terrible callback sadness") + assert.Nil(t, ci.Values) called = true }), ) @@ -1766,6 +1782,7 @@ func TestCallback(t *testing.T) { var pe dig.PanicError assert.True(t, errors.As(ci.Error, &pe)) assert.ErrorContains(t, ci.Error, "panic: \"unreal misfortune\"") + assert.Nil(t, ci.Values) called = true }), ) @@ -1786,6 +1803,7 @@ func TestCallback(t *testing.T) { var pe dig.PanicError assert.True(t, errors.As(ci.Error, &pe)) assert.ErrorContains(t, ci.Error, "panic: \"unreal misfortune\"") + assert.Nil(t, ci.Values) called = true }), @@ -1794,6 +1812,179 @@ func TestCallback(t *testing.T) { c.Invoke(func(int) {}) assert.True(t, called) }) + t.Run("callback receives primitives", func(t *testing.T) { + var providerCallbackCalled bool + + c := digtest.New(t) + + c.RequireProvide(giveInt, dig.WithProviderCallback(func(ci dig.CallbackInfo) { + assert.Equal(t, "go.uber.org/dig_test.giveInt", ci.Name) + assert.Nil(t, ci.Error) + assert.Len(t, ci.Values, 1) + assert.True(t, ci.Values[0].CanInt()) + assert.EqualValues(t, 5, ci.Values[0].Int()) + providerCallbackCalled = true + })) + + c.RequireInvoke(func(a int) { + assert.Equal(t, 5, a) + }) + + assert.True(t, providerCallbackCalled) + }) + + t.Run("callback works with value groups", func(t *testing.T) { + var providerCallbackCalledTimes int + + c := digtest.New(t) + + c.RequireProvide(giveInt, dig.Group("test"), dig.WithProviderCallback(func(ci dig.CallbackInfo) { + assert.NotEmpty(t, ci.Name) + assert.Nil(t, ci.Error) + assert.Len(t, ci.Values, 1) + assert.True(t, ci.Values[0].CanInt()) + assert.EqualValues(t, 5, ci.Values[0].Int()) + providerCallbackCalledTimes++ + })) + c.RequireProvide(func() int { return 6 }, dig.Group("test"), dig.WithProviderCallback(func(ci dig.CallbackInfo) { + assert.NotEmpty(t, ci.Name) + assert.Nil(t, ci.Error) + assert.Len(t, ci.Values, 1) + assert.True(t, ci.Values[0].CanInt()) + assert.EqualValues(t, 6, ci.Values[0].Int()) + providerCallbackCalledTimes++ + })) + + type params struct { + dig.In + + Value []int `group:"test"` + } + + c.RequireInvoke(func(a params) { + assert.ElementsMatch(t, []int{5, 6}, a.Value) + }) + + assert.Equal(t, 2, providerCallbackCalledTimes) + }) + + t.Run("callback works with interfaces", func(t *testing.T) { + var providerCallbackCalled bool + + var gave providedInterface + + c := digtest.New(t) + + c.RequireProvide( + func() providedInterface { + gave = &providedStruct{} + return gave + }, + dig.WithProviderCallback(func(ci dig.CallbackInfo) { + assert.NotEmpty(t, ci.Name) + assert.Nil(t, ci.Error) + assert.Len(t, ci.Values, 1) + assert.True(t, ci.Values[0].CanInterface()) + _, ok := ci.Values[0].Interface().(providedInterface) + assert.True(t, ok) + providerCallbackCalled = true + }), + ) + + c.RequireInvoke(func(got providedInterface) { + assert.Equal(t, gave, got) + }) + + assert.True(t, providerCallbackCalled) + }) + + t.Run("callback works with provider returning multiple values", func(t *testing.T) { + var providerCallbackCalled bool + + c := digtest.New(t) + + c.RequireProvide( + func() (string, int) { + return "five", 5 + }, + dig.WithProviderCallback(func(ci dig.CallbackInfo) { + assert.NotEmpty(t, ci.Name) + assert.Nil(t, ci.Error) + assert.Len(t, ci.Values, 2) + assert.EqualValues(t, "five", ci.Values[0].String()) + assert.EqualValues(t, 5, ci.Values[1].Int()) + providerCallbackCalled = true + }), + ) + + c.RequireInvoke(func(s string, i int) { + assert.Equal(t, "five", s) + assert.Equal(t, 5, i) + }) + + assert.True(t, providerCallbackCalled) + }) + + t.Run("callback does not receive nil error value with providers that can fail", func(t *testing.T) { + var providerCallbackCalled bool + + c := digtest.New(t) + + c.RequireProvide( + func() (string, int, error) { + return "five", 5, nil + }, + dig.WithProviderCallback(func(ci dig.CallbackInfo) { + assert.NotEmpty(t, ci.Name) + assert.Nil(t, ci.Error) + assert.Len(t, ci.Values, 2) + assert.EqualValues(t, "five", ci.Values[0].String()) + assert.EqualValues(t, 5, ci.Values[1].Int()) + providerCallbackCalled = true + }), + ) + + c.RequireInvoke(func(s string, i int) { + assert.Equal(t, "five", s) + assert.Equal(t, 5, i) + }) + + assert.True(t, providerCallbackCalled) + }) + + t.Run("object values are exploded into single values", func(t *testing.T) { + var containerCallbackCalled bool + + c := digtest.New(t) + + type out struct { + dig.Out + + String string + Int int + } + + c.RequireProvide( + func() (out, error) { + return out{String: "five", Int: 5}, nil + }, + dig.WithProviderCallback(func(ci dig.CallbackInfo) { + assert.NotEmpty(t, ci.Name) + assert.Nil(t, ci.Error) + assert.Len(t, ci.Values, 2) + assert.EqualValues(t, "five", ci.Values[0].String()) + assert.EqualValues(t, 5, ci.Values[1].Int()) + containerCallbackCalled = true + }), + ) + + c.RequireInvoke(func(s string, i int) { + assert.Equal(t, "five", s) + assert.Equal(t, 5, i) + }) + + assert.True(t, containerCallbackCalled) + }) } func TestProvideConstructorErrors(t *testing.T) { diff --git a/result.go b/result.go index 369cd218..6be24e26 100644 --- a/result.go +++ b/result.go @@ -45,6 +45,9 @@ type result interface { // This MAY panic if the result does not consume a single value. Extract(containerWriter, bool, reflect.Value) + // GetValues returns all values contained in a result. + GetValues(reflect.Value) []reflect.Value + // DotResult returns a slice of dot.Result(s). DotResult() []*dot.Result } @@ -259,6 +262,21 @@ func (rl resultList) ExtractList(cw containerWriter, decorated bool, values []re return nil } +func (rl resultList) GetValues(values reflect.Value) []reflect.Value { + digerror.BugPanicf("resultList.GetValues() must never be called") + panic("") // Unreachable, as BugPanicf above will panic. +} + +func (rl resultList) Values(values []reflect.Value) []reflect.Value { + result := make([]reflect.Value, 0) + for i, v := range values { + if resultIdx := rl.resultIndexes[i]; resultIdx >= 0 { + result = append(result, rl.Results[resultIdx].GetValues(v)...) + } + } + return result +} + // resultSingle is an explicit value produced by a constructor, optionally // with a name. // @@ -336,6 +354,10 @@ func (rs resultSingle) Extract(cw containerWriter, decorated bool, v reflect.Val } } +func (rs resultSingle) GetValues(v reflect.Value) []reflect.Value { + return []reflect.Value{v} +} + // resultObject is a dig.Out struct where each field is another result. // // This object is not added to the graph. Its fields are interpreted as @@ -388,6 +410,14 @@ func (ro resultObject) Extract(cw containerWriter, decorated bool, v reflect.Val } } +func (ro resultObject) GetValues(v reflect.Value) []reflect.Value { + res := make([]reflect.Value, len(ro.Fields)) + for i, f := range ro.Fields { + res[i] = v.Field(f.FieldIndex) + } + return res +} + // resultObjectField is a single field inside a dig.Out struct. type resultObjectField struct { // Name of the field in the struct. @@ -533,3 +563,7 @@ func (rt resultGrouped) Extract(cw containerWriter, decorated bool, v reflect.Va cw.submitGroupedValue(rt.Group, rt.Type, v.Index(i)) } } + +func (rt resultGrouped) GetValues(v reflect.Value) []reflect.Value { + return []reflect.Value{v} +}