diff --git a/constructor.go b/constructor.go index cf58bec8..2b69006e 100644 --- a/constructor.go +++ b/constructor.go @@ -27,6 +27,7 @@ import ( "go.uber.org/dig/internal/digerror" "go.uber.org/dig/internal/digreflect" "go.uber.org/dig/internal/dot" + "go.uber.org/dig/internal/promise" ) // constructorNode is a node in the dependency graph that represents @@ -45,12 +46,15 @@ type constructorNode struct { // id uniquely identifies the constructor that produces a node. id dot.CtorID - // Whether the constructor owned by this node was already called. - called bool + // State of the underlying constructor function + state functionState // Type information about constructor parameters. paramList paramList + // The result of calling the constructor + deferred promise.Deferred + // Type information about constructor results. resultList resultList @@ -123,47 +127,72 @@ func (n *constructorNode) ID() dot.CtorID { return n.id } func (n *constructorNode) CType() reflect.Type { return n.ctype } func (n *constructorNode) Order(s *Scope) int { return n.orders[s] } func (n *constructorNode) OrigScope() *Scope { return n.origS } +func (n *constructorNode) State() functionState { return n.state } func (n *constructorNode) String() string { return fmt.Sprintf("deps: %v, ctor: %v", n.paramList, n.ctype) } -// Call calls this constructor if it hasn't already been called and -// injects any values produced by it into the provided container. -func (n *constructorNode) Call(c containerStore) error { - if n.called { - return nil +// Call calls this constructor if it hasn't already been called and injects any values produced by it into the container +// passed to newConstructorNode. +// +// If constructorNode has a unresolved deferred already in the process of building, it will return that one. If it has +// already been called, it will return an already-resolved deferred. errMissingDependencies is non-fatal; any other +// errors means this node is permanently in an error state. +// +// Don't store the returned pointer; it points into a field that may be reused on non-fatal errors. +func (n *constructorNode) Call(c containerStore) *promise.Deferred { + if n.State() == functionCalled || n.State() == functionOnStack { + return &n.deferred } + n.state = functionVisited + n.deferred = promise.Deferred{} + if err := shallowCheckDependencies(c, n.paramList); err != nil { - return errMissingDependencies{ + n.deferred.Resolve(errMissingDependencies{ Func: n.location, Reason: err, - } + }) + return &n.deferred } - args, err := n.paramList.BuildList(c) - if err != nil { + var args []reflect.Value + var results []reflect.Value + + d := n.paramList.BuildList(c, &args) + + n.state = functionOnStack + + d.Catch(func(err error) error { return errArgumentsFailed{ Func: n.location, Reason: err, } - } - - receiver := newStagingContainerWriter() - results := c.invoker()(reflect.ValueOf(n.ctor), args) - if err := n.resultList.ExtractList(receiver, false /* decorating */, results); err != nil { - return errConstructorFailed{Func: n.location, Reason: err} - } - - // Commit the result to the original container that this constructor - // was supplied to. The provided constructor is only used for a view of - // the rest of the graph to instantiate the dependencies of this - // container. - receiver.Commit(n.s) - n.called = true + }).Then(func() *promise.Deferred { + return c.scheduler().Schedule(func() { + results = c.invoker()(reflect.ValueOf(n.ctor), args) + }) + }).Then(func() *promise.Deferred { + receiver := newStagingContainerWriter() + if err := n.resultList.ExtractList(receiver, false /* decorating */, results); err != nil { + return promise.Fail(errConstructorFailed{Func: n.location, Reason: err}) + } - return nil + // Commit the result to the original container that this constructor + // was supplied to. The provided container is only used for a view of + // the rest of the graph to instantiate the dependencies of this + // container. + receiver.Commit(n.s) + n.state = functionCalled + n.deferred.Resolve(nil) + return promise.Done + }).Catch(func(err error) error { + n.state = functionCalled + n.deferred.Resolve(err) + return nil + }) + return &n.deferred } // stagingContainerWriter is a containerWriter that records the changes that diff --git a/constructor_test.go b/constructor_test.go index e17a5433..5c531f9f 100644 --- a/constructor_test.go +++ b/constructor_test.go @@ -58,10 +58,21 @@ func TestNodeAlreadyCalled(t *testing.T) { s := newScope() n, err := newConstructorNode(f, s, s, constructorOptions{}) require.NoError(t, err, "failed to build node") - require.False(t, n.called, "node must not have been called") + require.False(t, n.State() == functionCalled, "node must not have been called") c := New() - require.NoError(t, n.Call(c.scope), "invoke failed") - require.True(t, n.called, "node must be called") - require.NoError(t, n.Call(c.scope), "calling again should be okay") + d := n.Call(c.scope) + c.scope.sched.Flush() + + ok, err := d.Resolved() + require.True(t, ok, "deferred must be resolved") + require.NoError(t, err, "invoke failed") + + require.True(t, n.State() == functionCalled, "node must be called") + d = n.Call(c.scope) + c.scope.sched.Flush() + + ok, err = d.Resolved() + require.True(t, ok, "deferred must be resolved") + require.NoError(t, err, "calling again should be okay") } diff --git a/container.go b/container.go index 57d11cbf..42acb18f 100644 --- a/container.go +++ b/container.go @@ -26,6 +26,7 @@ import ( "reflect" "go.uber.org/dig/internal/dot" + "go.uber.org/dig/internal/scheduler" ) const ( @@ -142,6 +143,9 @@ type containerStore interface { // Returns invokerFn function to use when calling arguments. invoker() invokerFn + + // Returns the scheduler to use for this scope. + scheduler() scheduler.Scheduler } // New constructs a Container. @@ -231,6 +235,35 @@ func dryInvoker(fn reflect.Value, _ []reflect.Value) []reflect.Value { return results } +type maxConcurrencyOption int + +// MaxConcurrency run constructors in this container with the given level of +// concurrency: +// +// - max = 0 or 1: run one constructor at a time (this is the default) +// +// - max > 1: run at most 'max' constructors at a time +// +// - max < 0: run an unlimited number of constructors at a time +// +// Concurrency is limited by how many constructors' dependencies are satisfied at +// once and Go's own allocation of OS threads to Goroutines. This is useful for +// applications that have many slow, independent constructors. +func MaxConcurrency(max int) Option { + return maxConcurrencyOption(max) +} + +func (m maxConcurrencyOption) applyOption(container *Container) { + switch { + case m == 0, m == 1: + container.scope.sched = scheduler.Synchronous + case m > 1: + container.scope.sched = scheduler.NewParallel(int(m)) + case m < 0: + container.scope.sched = new(scheduler.Unbounded) + } +} + // String representation of the entire Container func (c *Container) String() string { return c.scope.String() diff --git a/decorate.go b/decorate.go index eb9c2413..7362b03c 100644 --- a/decorate.go +++ b/decorate.go @@ -27,20 +27,23 @@ import ( "go.uber.org/dig/internal/digreflect" "go.uber.org/dig/internal/dot" + "go.uber.org/dig/internal/promise" ) -type decoratorState int +type functionState int const ( - decoratorReady decoratorState = iota - decoratorOnStack - decoratorCalled + functionReady functionState = iota + functionVisited // For avoiding cycles + functionOnStack // For telling that this function is already scheduled + functionCalled ) type decorator interface { - Call(c containerStore) error + Call(c containerStore) *promise.Deferred ID() dot.CtorID - State() decoratorState + State() functionState + OrigScope() *Scope } type decoratorNode struct { @@ -53,11 +56,14 @@ type decoratorNode struct { location *digreflect.Func // Current state of this decorator - state decoratorState + state functionState // Parameters of the decorator. params paramList + // The result of calling the constructor + deferred promise.Deferred + // Results of the decorator. results resultList @@ -96,39 +102,69 @@ func newDecoratorNode(dcor interface{}, s *Scope) (*decoratorNode, error) { return n, nil } -func (n *decoratorNode) Call(s containerStore) error { - if n.state == decoratorCalled { - return nil +// Call calls this decorator if it hasn't already been called and injects any values produced by it into the container +// passed to newConstructorNode. +// +// If constructorNode has a unresolved deferred already in the process of building, it will return that one. If it has +// already been successfully called, it will return an already-resolved deferred. Together these mean it will try the +// call again if it failed last time. +// +// On failure, the returned pointer is not guaranteed to stay in a failed state; another call will reset it back to its +// zero value; don't store the returned pointer. (It will still call each observer only once.) +func (n *decoratorNode) Call(s containerStore) *promise.Deferred { + if n.state == functionOnStack || n.state == functionCalled { + return &n.deferred } - n.state = decoratorOnStack + // We mark it as "visited" to avoid cycles + n.state = functionVisited + n.deferred = promise.Deferred{} if err := shallowCheckDependencies(s, n.params); err != nil { - return errMissingDependencies{ + n.deferred.Resolve(errMissingDependencies{ Func: n.location, Reason: err, - } + }) } - args, err := n.params.BuildList(n.s) - if err != nil { - return errArgumentsFailed{ - Func: n.location, - Reason: err, + var args []reflect.Value + d := n.params.BuildList(s, &args) + + n.state = functionOnStack + + d.Observe(func(err error) { + if err != nil { + n.state = functionCalled + n.deferred.Resolve(errArgumentsFailed{ + Func: n.location, + Reason: err, + }) + return } - } - results := reflect.ValueOf(n.dcor).Call(args) - if err := n.results.ExtractList(n.s, true /* decorated */, results); err != nil { - return err - } - n.state = decoratorCalled - return nil + var results []reflect.Value + + s.scheduler().Schedule(func() { + results = s.invoker()(reflect.ValueOf(n.dcor), args) + }).Observe(func(_ error) { + if err := n.results.ExtractList(n.s, true /* decorated */, results); err != nil { + n.deferred.Resolve(err) + return + } + + n.state = functionCalled + n.deferred.Resolve(nil) + }) + }) + + return &n.deferred } func (n *decoratorNode) ID() dot.CtorID { return n.id } -func (n *decoratorNode) State() decoratorState { return n.state } +func (n *decoratorNode) State() functionState { return n.state } + +func (n *decoratorNode) OrigScope() *Scope { return n.s } // DecorateOption modifies the default behavior of Decorate. type DecorateOption interface { diff --git a/dig_test.go b/dig_test.go index 5b78b853..2523215f 100644 --- a/dig_test.go +++ b/dig_test.go @@ -33,6 +33,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.uber.org/atomic" "go.uber.org/dig" "go.uber.org/dig/internal/digtest" ) @@ -3653,3 +3654,93 @@ func TestEndToEndSuccessWithAliases(t *testing.T) { }) }) } + +func TestConcurrency(t *testing.T) { + // Ensures providers will run at the same time + t.Run("TestMaxConcurrency", func(t *testing.T) { + t.Parallel() + + type ( + A int + B int + C int + ) + + const max = 3 + + var ( + timer = time.NewTimer(10 * time.Second) + done = make(chan struct{}) + running = atomic.Int32{} + waitForUs = func() error { + if running.Inc() == max { + close(done) + } + select { + case <-timer.C: + return errors.New("timeout expired") + case <-done: + return nil + } + } + c = digtest.New(t, dig.MaxConcurrency(max)) + ) + + c.RequireProvide(func() (A, error) { return 0, waitForUs() }) + c.RequireProvide(func() (B, error) { return 1, waitForUs() }) + c.RequireProvide(func() (C, error) { return 2, waitForUs() }) + + c.RequireInvoke(func(a A, b B, c C) { + require.Equal(t, a, A(0)) + require.Equal(t, b, B(1)) + require.Equal(t, c, C(2)) + require.Equal(t, running.Load(), int32(max)) + }) + }) + + t.Run("TestUnboundConcurrency", func(t *testing.T) { + t.Parallel() + + const max = 20 + + var ( + timer = time.NewTimer(10 * time.Second) + done = make(chan struct{}) + running = atomic.NewInt32(0) + waitForUs = func() error { + if running.Inc() == max { + close(done) + } + select { + case <-timer.C: + return errors.New("timeout expired") + case <-done: + return nil + } + } + c = digtest.New(t, dig.MaxConcurrency(-1)) + expected []int + ) + + for i := 0; i < max; i++ { + i := i + expected = append(expected, i) + type out struct { + dig.Out + + Value int `group:"a"` + } + c.RequireProvide(func() (out, error) { return out{Value: i}, waitForUs() }) + } + + type in struct { + dig.In + + Values []int `group:"a"` + } + + c.RequireInvoke(func(i in) { + require.ElementsMatch(t, expected, i.Values) + }) + }) +} diff --git a/doc.go b/doc.go index b8268eb4..3f9ab85c 100644 --- a/doc.go +++ b/doc.go @@ -98,7 +98,7 @@ // // # Invoke // -// Types added to to the container may be consumed by using the Invoke method. +// Types added to the container may be consumed by using the Invoke method. // Invoke accepts any function that accepts one or more parameters and // optionally, returns an error. Dig calls the function with the requested // type, instantiating only those types that were requested by the function. diff --git a/go.mod b/go.mod index 1b23da24..d629189f 100644 --- a/go.mod +++ b/go.mod @@ -2,7 +2,10 @@ module go.uber.org/dig go 1.17 -require github.com/stretchr/testify v1.7.1 +require ( + github.com/stretchr/testify v1.7.1 + go.uber.org/atomic v1.9.0 +) require ( github.com/davecgh/go-spew v1.1.1 // indirect diff --git a/go.sum b/go.sum index 8d61fd53..6ef4bd51 100644 --- a/go.sum +++ b/go.sum @@ -4,8 +4,11 @@ github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= +go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/promise/deferred.go b/internal/promise/deferred.go new file mode 100644 index 00000000..41c8b367 --- /dev/null +++ b/internal/promise/deferred.go @@ -0,0 +1,152 @@ +// Copyright (c) 2022 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package promise + +// Observer is a function that gets called when a Deferred resolves. +type Observer func(error) + +// Deferred is an observable future result that may fail. +// Its zero value is unresolved and has no observers. +// It can be resolved once, at which point every observer will be called. +type Deferred struct { + observers []Observer + settled bool + err error +} + +// Resolved reports whether this Deferred has resolved, +// and if so, with what error. +// +// err is undefined if the Deferred has not yet resolved. +func (d *Deferred) Resolved() (resolved bool, err error) { + return d.settled, d.err +} + +// Done is a Deferred that has already been resolved with a nil error. +var Done = &Deferred{settled: true} + +// Fail returns a Deferred that has resolved with the given error. +func Fail(err error) *Deferred { + return &Deferred{settled: true, err: err} +} + +// Observe registers an observer to receive a callback when this deferred is +// resolved. +// It will be called at most one time. +// If this deferred is already resolved, the observer is called immediately, +// before Observe returns. +func (d *Deferred) Observe(obs Observer) { + if d.settled { + obs(d.err) + } else { + d.observers = append(d.observers, obs) + } +} + +// Resolve sets the status of this deferred and notifies all observers. +// This is a no-op if the Deferred has already resolved. +func (d *Deferred) Resolve(err error) { + if d.settled { + return + } + + d.settled = true + d.err = err + for _, obs := range d.observers { + obs(err) + } + d.observers = nil +} + +// Then returns a new Deferred that resolves with the same error as this +// Deferred or the eventual result of the Deferred returned by res. +func (d *Deferred) Then(res func() *Deferred) *Deferred { + // Shortcut: if we're settled... + if d.settled { + if d.err == nil { + // ...successfully, then return the other deferred + return res() + } + + // ...with an error, then return us + return d + } + + d2 := new(Deferred) + d.Observe(func(err error) { + if err != nil { + d2.Resolve(err) + } else { + res().Observe(d2.Resolve) + } + }) + return d2 +} + +// Catch maps any error from this deferred using the supplied function. +// The supplied function is only called if this deferred is resolved with an +// error. +// If the supplied function returns a nil error, the new deferred will resolve +// successfully. +func (d *Deferred) Catch(rej func(error) error) *Deferred { + d2 := new(Deferred) + d.Observe(func(err error) { + if err != nil { + err = rej(err) + } + d2.Resolve(err) + }) + return d2 +} + +// WhenAll returns a new Deferred that resolves when all the supplied deferreds +// resolve. +// It resolves with the first error reported by any deferred, or nil if they +// all succeed. +func WhenAll(others ...*Deferred) *Deferred { + if len(others) == 0 { + return Done + } + + d := new(Deferred) + count := len(others) + + onResolved := func(err error) { + if d.settled { + return + } + + if err != nil { + d.Resolve(err) + } + + count-- + if count == 0 { + d.Resolve(nil) + } + } + + for _, other := range others { + other.Observe(onResolved) + } + + return d +} diff --git a/internal/scheduler/parallel.go b/internal/scheduler/parallel.go new file mode 100644 index 00000000..548389ae --- /dev/null +++ b/internal/scheduler/parallel.go @@ -0,0 +1,96 @@ +// Copyright (c) 2022 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package scheduler + +import "go.uber.org/dig/internal/promise" + +// task is used by parallelScheduler to remember which function to +// call and which deferred to notify afterwards. +type task struct { + fn func() + d *promise.Deferred +} + +// Parallel processes enqueued work using a fixed-size worker pool. +// The pool is started and stopped during the call to flush. +type Parallel struct { + concurrency int + tasks []task +} + +var _ Scheduler = (*Parallel)(nil) + +// NewParallel builds a new parallel scheduler that will use the specified +// number of goroutines to run tasks. +func NewParallel(concurrency int) *Parallel { + return &Parallel{concurrency: concurrency} +} + +// Schedule enqueues a task and returns an unresolved deferred. +// It will be resolved during flush. +func (p *Parallel) Schedule(fn func()) *promise.Deferred { + d := new(promise.Deferred) + p.tasks = append(p.tasks, task{fn, d}) + return d +} + +// Flush processes enqueued work. +// concurrency controls how many executor goroutines are started and thus the +// maximum number of calls that may proceed in parallel. +// The real level of concurrency may be lower for CPU-heavy workloads if Go +// doesn't assign these goroutines to OS threads. +func (p *Parallel) Flush() { + inFlight := 0 + taskChan := make(chan task) + resultChan := make(chan *promise.Deferred) + + for n := 0; n < p.concurrency; n++ { + go func() { + for t := range taskChan { + t.fn() + resultChan <- t.d + } + }() + } + + for inFlight > 0 || len(p.tasks) > 0 { + var t task + var outChan chan<- task + + if len(p.tasks) > 0 { + t = p.tasks[len(p.tasks)-1] + outChan = taskChan + } + + select { + case outChan <- t: + inFlight++ + p.tasks = p.tasks[:len(p.tasks)-1] + case d := <-resultChan: + inFlight-- + d.Resolve(nil) + } + } + + close(taskChan) + close(resultChan) + p.tasks = nil +} diff --git a/internal/scheduler/scheduler.go b/internal/scheduler/scheduler.go new file mode 100644 index 00000000..8c6025c2 --- /dev/null +++ b/internal/scheduler/scheduler.go @@ -0,0 +1,70 @@ +// Copyright (c) 2022 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package scheduler + +import "go.uber.org/dig/internal/promise" + +// Scheduler queues work during resolution of params. +// +// constructorNode uses it to call its constructor function. +// This may happen in parallel with other calls (parallelScheduler) or +// synchronously, right when enqueued. +// +// Work is enqueued when building a paramList, but the user of scheduler +// must call flush() for asynchronous calls to proceed after the top-level +// paramList.BuildList() is called. +type Scheduler interface { + // schedule will call a the supplied func. The deferred will resolve + // after the func is called. The func may be called before schedule + // returns. The deferred will be resolved on the "main" goroutine, so + // it's safe to mutate containerStore during its resolution. It will + // always be resolved with a nil error. + Schedule(func()) *promise.Deferred + + // flush processes enqueued work. This may in turn enqueue more work; + // flush will keep processing the work until it's empty. After flush is + // called, every deferred returned from schedule will have been resolved. + // Asynchronous deferred values returned from schedule are resolved on the + // same goroutine as the one calling this method. + // + // The scheduler is ready for re-use after flush is called. + Flush() +} + +// Synchronous is a stateless synchronous scheduler. +// It invokes functions as soon as they are scheduled. +// This is equivalent to not using a concurrent scheduler at all. +var Synchronous = synchronous{} + +// synchronous is stateless and calls funcs as soon as they are schedule. It produces +// the exact same results as the code before deferred was introduced. +type synchronous struct{} + +var _ Scheduler = synchronous{} + +// schedule calls func and returns an already-resolved deferred. +func (s synchronous) Schedule(fn func()) *promise.Deferred { + fn() + return promise.Done +} + +// flush does nothing. All returned deferred values are already resolved. +func (s synchronous) Flush() {} diff --git a/internal/scheduler/unbounded.go b/internal/scheduler/unbounded.go new file mode 100644 index 00000000..af001a08 --- /dev/null +++ b/internal/scheduler/unbounded.go @@ -0,0 +1,65 @@ +// Copyright (c) 2022 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package scheduler + +import "go.uber.org/dig/internal/promise" + +// Unbounded starts a goroutine per task. +// Maximum concurrency is controlled by Go's allocation of OS threads to +// goroutines. +type Unbounded struct { + tasks []task +} + +var _ Scheduler = (*Unbounded)(nil) + +// Schedule enqueues a task and returns an unresolved deferred. +// It will be resolved during flush. +func (p *Unbounded) Schedule(fn func()) *promise.Deferred { + d := new(promise.Deferred) + p.tasks = append(p.tasks, task{fn, d}) + return d +} + +// Flush processes enqueued work with unlimited concurrency. +// The actual limit is up to Go's allocation of OS resources to goroutines. +func (p *Unbounded) Flush() { + inFlight := 0 + resultChan := make(chan *promise.Deferred) + + for inFlight > 0 || len(p.tasks) > 0 { + if len(p.tasks) > 0 { + t := p.tasks[len(p.tasks)-1] + p.tasks = p.tasks[:len(p.tasks)-1] + go func() { + t.fn() + resultChan <- t.d + }() + inFlight++ + continue + } + d := <-resultChan + inFlight-- + d.Resolve(nil) + } + close(resultChan) + p.tasks = nil +} diff --git a/invoke.go b/invoke.go index e154bdde..7054018e 100644 --- a/invoke.go +++ b/invoke.go @@ -82,7 +82,14 @@ func (s *Scope) Invoke(function interface{}, opts ...InvokeOption) error { s.isVerifiedAcyclic = true } - args, err := pl.BuildList(s) + var args []reflect.Value + + d := pl.BuildList(s, &args) + d.Observe(func(err2 error) { + err = err2 + }) + s.sched.Flush() + if err != nil { return errArgumentsFailed{ Func: digreflect.InspectFunc(function), diff --git a/param.go b/param.go index e7d0a8ea..4d4f9507 100644 --- a/param.go +++ b/param.go @@ -29,6 +29,7 @@ import ( "go.uber.org/dig/internal/digerror" "go.uber.org/dig/internal/dot" + "go.uber.org/dig/internal/promise" ) // The param interface represents a dependency for a constructor. @@ -47,10 +48,13 @@ type param interface { fmt.Stringer // Build this dependency and any of its dependencies from the provided - // Container. + // Container. It stores the result in the pointed-to reflect.Value, allocating + // it first if it points to an invalid reflect.Value. + // + // Build returns a deferred that resolves once the reflect.Value is filled in. // // This MAY panic if the param does not produce a single value. - Build(store containerStore) (reflect.Value, error) + Build(store containerStore, target *reflect.Value) *promise.Deferred // DotParam returns a slice of dot.Param(s). DotParam() []*dot.Param @@ -138,23 +142,21 @@ func newParamList(ctype reflect.Type, c containerStore) (paramList, error) { return pl, nil } -func (pl paramList) Build(containerStore) (reflect.Value, error) { +func (pl paramList) Build(containerStore, *reflect.Value) *promise.Deferred { digerror.BugPanicf("paramList.Build() must never be called") panic("") // Unreachable, as BugPanicf above will panic. } -// BuildList returns an ordered list of values which may be passed directly -// to the underlying constructor. -func (pl paramList) BuildList(c containerStore) ([]reflect.Value, error) { - args := make([]reflect.Value, len(pl.Params)) +// BuildList builds an ordered list of values which may be passed directly +// to the underlying constructor and stores them in the pointed-to slice. +// It returns a deferred that resolves when the slice is filled out. +func (pl paramList) BuildList(c containerStore, targets *[]reflect.Value) *promise.Deferred { + children := make([]*promise.Deferred, len(pl.Params)) + *targets = make([]reflect.Value, len(pl.Params)) for i, p := range pl.Params { - var err error - args[i], err = p.Build(c) - if err != nil { - return nil, err - } + children[i] = p.Build(c, &(*targets)[i]) } - return args, nil + return promise.WhenAll(children...) } // paramSingle is an explicitly requested type, optionally with a name. @@ -198,34 +200,22 @@ func (ps paramSingle) String() string { return fmt.Sprintf("%v[%v]", ps.Type, strings.Join(opts, ", ")) } -// search the given container and its ancestors for a decorated value. -func (ps paramSingle) getDecoratedValue(c containerStore) (reflect.Value, bool) { - for _, c := range c.storesToRoot() { - if v, ok := c.getDecoratedValue(ps.Name, ps.Type); ok { - return v, ok - } - } - return _noValue, false -} - -// builds the parameter using decorators in all scopes that affect the -// current scope, if there are any. If there are multiple Scopes that decorates -// this parameter, the closest one to the Scope that invoked this will be used. -// If there are no decorators associated with this parameter, _noValue is returned. -func (ps paramSingle) buildWithDecorators(c containerStore) (v reflect.Value, found bool, err error) { +func (ps paramSingle) buildWithDecorators(c containerStore, target *reflect.Value) (*promise.Deferred, bool) { var ( d decorator decoratingScope containerStore + found bool ) - stores := c.storesToRoot() - for _, s := range stores { + def := new(promise.Deferred) + for _, s := range c.storesToRoot() { if d, found = s.getValueDecorator(ps.Name, ps.Type); !found { continue } - if d.State() == decoratorOnStack { - // This decorator is already being run. - // Avoid a cycle and look further. + // This is for avoiding cycles i.e decorator -> function + // ^ | + // \ ------- / + if d.State() == functionVisited { d = nil continue } @@ -233,81 +223,96 @@ func (ps paramSingle) buildWithDecorators(c containerStore) (v reflect.Value, fo break } if !found || d == nil { - return _noValue, false, nil + return promise.Done, false } - if err = d.Call(decoratingScope); err != nil { - v, err = _noValue, errParamSingleFailed{ - CtorID: 1, - Key: key{t: ps.Type, name: ps.Name}, - Reason: err, + d.Call(decoratingScope).Observe(func(err error) { + if err != nil { + def.Resolve(errParamSingleFailed{ + CtorID: d.ID(), + Key: key{t: ps.Type, name: ps.Name}, + Reason: err, + }) + return } - return v, found, err - } - v, _ = decoratingScope.getDecoratedValue(ps.Name, ps.Type) - return + v, _ := decoratingScope.getDecoratedValue(ps.Name, ps.Type) + if v.IsValid() { + target.Set(v) + } + def.Resolve(nil) + }) + return def, found } -func (ps paramSingle) Build(c containerStore) (reflect.Value, error) { - v, found, err := ps.buildWithDecorators(c) - if found { - return v, err - } - - // Check whether the value is a decorated value first. - if v, ok := ps.getDecoratedValue(c); ok { - return v, nil - } - - // Starting at the given container and working our way up its parents, - // find one that provides this dependency. - // - // Once found, we'll use that container for the rest of the invocation. - // Dependencies of this type will begin searching at that container, - // rather than starting at base. - var providers []provider +func (ps paramSingle) build(c containerStore, target *reflect.Value) *promise.Deferred { var providingContainer containerStore + var providers []provider for _, container := range c.storesToRoot() { - // first check if the scope already has cached a value for the type. + // First we check if the value it's stored in the current store if v, ok := container.getValue(ps.Name, ps.Type); ok { - return v, nil + if v.IsValid() { + target.Set(v) + } + return promise.Done } + providers = container.getValueProviders(ps.Name, ps.Type) if len(providers) > 0 { providingContainer = container break } } - if len(providers) == 0 { if ps.Optional { - return reflect.Zero(ps.Type), nil + target.Set(reflect.Zero(ps.Type)) + return promise.Done } - return _noValue, newErrMissingTypes(c, key{name: ps.Name, t: ps.Type}) + return promise.Fail(newErrMissingTypes(c, key{name: ps.Name, t: ps.Type})) } - + var children []*promise.Deferred + def := new(promise.Deferred) for _, n := range providers { - err := n.Call(n.OrigScope()) - if err == nil { - continue + child := n.Call(n.OrigScope()).Catch(func(err error) error { + // If we're missing dependencies but the parameter itself is optional, + // we can just move on. + if _, ok := err.(errMissingDependencies); ok && ps.Optional { + return nil + } + return errParamSingleFailed{ + CtorID: n.ID(), + Key: key{t: ps.Type, name: ps.Name}, + Reason: err, + } + }) + children = append(children, child) + } + return promise.WhenAll(children...).Then(func() *promise.Deferred { + // If we get here, it's impossible for the value to be absent from the + // container. + v, _ := providingContainer.getValue(ps.Name, ps.Type) + if v.IsValid() { + target.Set(v) } + def.Resolve(nil) + return def + }) +} - // If we're missing dependencies but the parameter itself is optional, - // we can just move on. - if _, ok := err.(errMissingDependencies); ok && ps.Optional { - return reflect.Zero(ps.Type), nil - } +func (ps paramSingle) Build(c containerStore, target *reflect.Value) *promise.Deferred { + if !target.IsValid() { + *target = reflect.New(ps.Type).Elem() + } + + // try building with decorators first, in case this parameter has decorators. + d, found := ps.buildWithDecorators(c, target) - return _noValue, errParamSingleFailed{ - CtorID: n.ID(), - Key: key{t: ps.Type, name: ps.Name}, - Reason: err, + return d.Then(func() *promise.Deferred { + // Check whether the value is a decorated value first. + if found { + return promise.Done } - } - // If we get here, it's impossible for the value to be absent from the - // container. - v, _ = providingContainer.getValue(ps.Name, ps.Type) - return v, nil + return ps.build(c, target) + }) } // paramObject is a dig.In struct where each field is another param. @@ -394,8 +399,10 @@ func newParamObject(t reflect.Type, c containerStore) (paramObject, error) { return po, nil } -func (po paramObject) Build(c containerStore) (reflect.Value, error) { - dest := reflect.New(po.Type).Elem() +func (po paramObject) Build(c containerStore, target *reflect.Value) *promise.Deferred { + if !target.IsValid() { + *target = reflect.New(po.Type).Elem() + } // We have to build soft groups after all other fields, to avoid cases // when a field calls a provider for a soft value group, but the value is // not provided to it because the value group is declared before the field @@ -408,15 +415,21 @@ func (po paramObject) Build(c containerStore) (reflect.Value, error) { } fields = append(fields, f) } - fields = append(fields, softGroupsQueue...) - for _, f := range fields { - v, err := f.Build(c) - if err != nil { - return dest, err + + buildFields := func(fields []paramObjectField) *promise.Deferred { + children := make([]*promise.Deferred, len(fields)) + + for i, f := range fields { + field := target.Field(f.FieldIndex) + children[i] = f.Build(c, &field) } - dest.Field(f.FieldIndex).Set(v) + + return promise.WhenAll(children...) } - return dest, nil + + return buildFields(fields).Then(func() *promise.Deferred { + return buildFields(softGroupsQueue) + }) } // paramObjectField is a single field of a dig.In struct. @@ -482,12 +495,8 @@ func newParamObjectField(idx int, f reflect.StructField, c containerStore) (para return pof, nil } -func (pof paramObjectField) Build(c containerStore) (reflect.Value, error) { - v, err := pof.Param.Build(c) - if err != nil { - return v, err - } - return v, nil +func (pof paramObjectField) Build(c containerStore, target *reflect.Value) *promise.Deferred { + return pof.Param.Build(c, target) } // paramGroupedSlice is a param which produces a slice of values with the same @@ -574,84 +583,90 @@ func (pt paramGroupedSlice) getDecoratedValues(c containerStore) (reflect.Value, } // search the given container and its parents for matching group decorators -// and call them to commit values. If any decorators return an error, -// that error is returned immediately. If all decorators succeeds, nil is returned. // The order in which the decorators are invoked is from the top level scope to // the current scope, to account for decorators that decorate values that were // already decorated. -func (pt paramGroupedSlice) callGroupDecorators(c containerStore) error { +func (pt paramGroupedSlice) callGroupDecorators(c containerStore) *promise.Deferred { + var children []*promise.Deferred stores := c.storesToRoot() for i := len(stores) - 1; i >= 0; i-- { c := stores[i] - if d, found := c.getGroupDecorator(pt.Group, pt.Type.Elem()); found { - if d.State() == decoratorOnStack { + if d, ok := c.getGroupDecorator(pt.Group, pt.Type.Elem()); ok { + + if d.State() == functionVisited { // This decorator is already being run. Avoid cycle // and look further. continue } - if err := d.Call(c); err != nil { + + child := d.Call(c) + children = append(children, child.Catch(func(err error) error { return errParamGroupFailed{ CtorID: d.ID(), Key: key{group: pt.Group, t: pt.Type.Elem()}, Reason: err, } - } + })) } } - return nil + return promise.WhenAll(children...) } // search the given container and its parent for matching group providers and // call them to commit values. If an error is encountered, return the number // of providers called and a non-nil error from the first provided. -func (pt paramGroupedSlice) callGroupProviders(c containerStore) (int, error) { - itemCount := 0 +func (pt paramGroupedSlice) callGroupProviders(c containerStore) *promise.Deferred { + var children []*promise.Deferred for _, c := range c.storesToRoot() { providers := c.getGroupProviders(pt.Group, pt.Type.Elem()) - itemCount += len(providers) for _, n := range providers { - if err := n.Call(c); err != nil { - return 0, errParamGroupFailed{ + n := n + child := n.Call(c) + children = append(children, child.Catch(func(err error) error { + return errParamGroupFailed{ CtorID: n.ID(), Key: key{group: pt.Group, t: pt.Type.Elem()}, Reason: err, } - } + })) } } - return itemCount, nil + return promise.WhenAll(children...) } -func (pt paramGroupedSlice) Build(c containerStore) (reflect.Value, error) { +func (pt paramGroupedSlice) Build(c containerStore, target *reflect.Value) *promise.Deferred { // do not call this if we are already inside a decorator since // it will result in an infinite recursion. (i.e. decorate -> params.BuildList() -> Decorate -> params.BuildList...) // this is safe since a value can be decorated at most once in a given scope. - if err := pt.callGroupDecorators(c); err != nil { - return _noValue, err - } - - // Check if we have decorated values - if decoratedItems, ok := pt.getDecoratedValues(c); ok { - return decoratedItems, nil - } + d := pt.callGroupDecorators(c) + + return d.Then(func() *promise.Deferred { + // Check if we have decorated values + if decoratedItems, ok := pt.getDecoratedValues(c); ok { + if !target.IsValid() { + newCap := 0 + if decoratedItems.Kind() == reflect.Slice { + newCap = decoratedItems.Len() + } + *target = reflect.MakeSlice(pt.Type, 0, newCap) + } - // If we do not have any decorated values and the group isn't soft, - // find the providers and call them. - itemCount := 0 - if !pt.Soft { - var err error - itemCount, err = pt.callGroupProviders(c) - if err != nil { - return _noValue, err + target.Set(decoratedItems) + return promise.Done } - } - - stores := c.storesToRoot() - result := reflect.MakeSlice(pt.Type, 0, itemCount) - for _, c := range stores { - result = reflect.Append(result, c.getValueGroup(pt.Group, pt.Type.Elem())...) - } - return result, nil + setValues := func() *promise.Deferred { + for _, c := range c.storesToRoot() { + target.Set(reflect.Append(*target, c.getValueGroup(pt.Group, pt.Type.Elem())...)) + } + return promise.Done + } + if pt.Soft { + return setValues() + } + // If we do not have any decorated values and the group isn't soft, + // find the providers and call them. + return pt.callGroupProviders(c).Then(setValues) + }) } // Checks if ignoring unexported files in an In struct is allowed. diff --git a/param_test.go b/param_test.go index 7a1f41ed..8f38c3bd 100644 --- a/param_test.go +++ b/param_test.go @@ -33,7 +33,8 @@ func TestParamListBuild(t *testing.T) { p, err := newParamList(reflect.TypeOf(func() io.Writer { return nil }), newScope()) require.NoError(t, err) assert.Panics(t, func() { - p.Build(newScope()) + var target reflect.Value + p.Build(newScope(), &target) }) } diff --git a/provide.go b/provide.go index 1cf808c5..f18600df 100644 --- a/provide.go +++ b/provide.go @@ -30,6 +30,7 @@ import ( "go.uber.org/dig/internal/digreflect" "go.uber.org/dig/internal/dot" "go.uber.org/dig/internal/graph" + "go.uber.org/dig/internal/promise" ) // A ProvideOption modifies the default behavior of Provide. @@ -358,7 +359,7 @@ type provider interface { // // The values produced by this provider should be submitted into the // containerStore. - Call(containerStore) error + Call(store containerStore) *promise.Deferred CType() reflect.Type diff --git a/scope.go b/scope.go index 0c6498eb..d84ab6e4 100644 --- a/scope.go +++ b/scope.go @@ -27,6 +27,8 @@ import ( "reflect" "sort" "time" + + "go.uber.org/dig/internal/scheduler" ) // A ScopeOption modifies the default behavior of Scope; currently, @@ -78,6 +80,8 @@ type Scope struct { // invokerFn calls a function with arguments provided to Provide or Invoke. invokerFn invokerFn + sched scheduler.Scheduler + // graph of this Scope. Note that this holds the dependency graph of all the // nodes that affect this Scope, not just the ones provided directly to this Scope. gh *graphHolder @@ -99,6 +103,7 @@ func newScope() *Scope { decoratedGroups: make(map[key]reflect.Value), invokerFn: defaultInvoker, rand: rand.New(rand.NewSource(time.Now().UnixNano())), + sched: scheduler.Synchronous, } s.gh = newGraphHolder(s) return s @@ -115,6 +120,7 @@ func (s *Scope) Scope(name string, opts ...ScopeOption) *Scope { child.parentScope = s child.invokerFn = s.invokerFn child.deferAcyclicVerification = s.deferAcyclicVerification + child.sched = s.sched // child copies the parent's graph nodes. child.gh.nodes = append(child.gh.nodes, s.gh.nodes...) @@ -258,6 +264,10 @@ func (s *Scope) invoker() invokerFn { return s.invokerFn } +func (s *Scope) scheduler() scheduler.Scheduler { + return s.sched +} + // adds a new graphNode to this Scope and all of its descendent // scope. func (s *Scope) newGraphNode(wrapped interface{}, orders map[*Scope]int) {