From 8f646e9b1d6f4110c5d3408fa7b868c14c9d0446 Mon Sep 17 00:00:00 2001 From: David Sharnoff Date: Mon, 17 Apr 2023 20:30:01 -0700 Subject: [PATCH] fix: Cluster() did not work as documented (#67) --- api.go | 14 ++++++++---- debug_test.go | 8 +++++-- example_cluster_test.go | 5 +++++ include.go | 50 ++++++++++++++++++++++++----------------- regressions_test.go | 41 +++++++++++++++++++++++++++++++++ 5 files changed, 92 insertions(+), 26 deletions(-) diff --git a/api.go b/api.go index 85575a0..c3eae9c 100644 --- a/api.go +++ b/api.go @@ -73,12 +73,18 @@ var clusterId int32 = 1 // be included at all. It also downgrades providers that are in the // cluster that would normally be considered desired because they don't // return anything and aren't wrappers: they're no longer automatically -// considered desired. +// considered desired because doing so would imply the entire Cluster is +// is desired. +// +// A "Cluster" with only one member is not really a cluster and will +// not be treated as a cluster. func Cluster(name string, providers ...interface{}) *Collection { c := newCollection(name, providers...) - id := atomic.AddInt32(&clusterId, 1) - for _, fm := range c.contents { - fm.cluster = id + if len(providers) > 1 { + id := atomic.AddInt32(&clusterId, 1) + for _, fm := range c.contents { + fm.cluster = id + } } return c } diff --git a/debug_test.go b/debug_test.go index 3a12448..72bf5ac 100644 --- a/debug_test.go +++ b/debug_test.go @@ -32,8 +32,12 @@ func debugOff() { } func wrapTest(t *testing.T, inner func(*testing.T)) { - if !t.Run("1st attempt", func(t *testing.T) { inner(t) }) { - t.Run("2nd attempt", func(t *testing.T) { + namedWrapTest(t, "", inner) +} + +func namedWrapTest(t *testing.T, name string, inner func(*testing.T)) { + if !t.Run("1st attempt"+name, func(t *testing.T) { inner(t) }) { + t.Run("2nd attempt"+name, func(t *testing.T) { debugOn(t) defer debugOff() inner(t) diff --git a/example_cluster_test.go b/example_cluster_test.go index 1274d8b..567735f 100644 --- a/example_cluster_test.go +++ b/example_cluster_test.go @@ -15,6 +15,10 @@ func ExampleCluster() { func(s string) int32 { return int32(len(s)) }, + func() int64 { + fmt.Println("included even though no consumer") + return 0 + }, func(i int32) { fmt.Println("auto-desired in 1st cluster") }, @@ -47,6 +51,7 @@ func ExampleCluster() { }, ) // Output: no need for data from clusters + // included even though no consumer // auto-desired in 1st cluster // auto-desired in 2nd cluster // got value that needed both chains - 28 diff --git a/include.go b/include.go index fd7c410..f466215 100644 --- a/include.go +++ b/include.go @@ -75,18 +75,8 @@ func computeDependenciesAndInclusion(funcs []*provider, initF *provider) ([]*pro fm.chainPosition = i } debugln("initial set of functions") - clusterLeaders := make(map[int32]*provider) for _, fm := range funcs { debugf("\t%s", fm) - if fm.cluster != 0 { - if leader, ok := clusterLeaders[fm.cluster]; ok { - leader.d.clusterMembers = append(leader.d.clusterMembers, fm) - fm.d.clusterMembers = nil - } else { - clusterLeaders[fm.cluster] = fm - fm.d.clusterMembers = []*provider{fm} - } - } fm.d.mustConsumeFlow = [lastFlowType]bool{} if fm.mustConsume { fm.d.mustConsumeFlow[outputParams] = true @@ -127,15 +117,35 @@ func computeDependenciesAndInclusion(funcs []*provider, initF *provider) ([]*pro } } + // We process cluster membership after we've determined + // funcs that cannot be included so that we do not include + // in clusters anything that cannot be included at all. + clusterLeaders := make(map[int32]*provider) // leaders are chosen arbitrarily + for _, fm := range funcs { + if fm.cluster == 0 || fm.d.excluded != nil { + continue + } + if leader, ok := clusterLeaders[fm.cluster]; ok { + leader.d.clusterMembers = append(leader.d.clusterMembers, fm) + fm.d.clusterMembers = nil + } else { + clusterLeaders[fm.cluster] = fm + fm.d.clusterMembers = []*provider{fm} + } + if !fm.required && !fm.desired && fm.wanted { + fm.d.wantedInCluster = true + } + } + debugln("eliminate unused providers") - eliminateUnused(funcs) + eliminateUnused(funcs) // xxx - tryWithout := func(without ...*provider) bool { + tryWithout := func(without ...*provider) { if len(without) == 1 { if without[0].wanted && without[0].d.wantedInCluster { // working around a bug: don't try to eliminate single // wanted functions from clusters - return false + return } debugf("check chain validity, excluding %s", without[0]) } else { @@ -166,7 +176,6 @@ func computeDependenciesAndInclusion(funcs []*provider, initF *provider) ([]*pro } } } - return err == nil } debugln("attempt to eliminate additional providers") @@ -174,12 +183,13 @@ func computeDependenciesAndInclusion(funcs []*provider, initF *provider) ([]*pro if fm.d.excluded != nil { continue } - if fm.d.clusterMembers != nil { - if tryWithout(fm.d.clusterMembers...) { - continue + if fm.cluster != 0 { + if fm.d.clusterMembers != nil { + tryWithout(fm.d.clusterMembers...) } + } else { + tryWithout(fm) } - tryWithout(fm) } debugln("final set of functions") @@ -266,7 +276,7 @@ func checkFlows(funcs []*provider, numFuncs int, canRemoveDesired bool) error { for tc, err := range errors { fm.cannotInclude = err redo = append(redo, fm) - debugf("\t\trequire error on %s %s: %s", param, tc, err) + debugf("\t\trequire error on %d %s: %s", param, tc, err) continue Todo } } @@ -439,7 +449,7 @@ PostCheck: for len(check) > 0 { var fm *provider fm, check = check[0], check[1:] - if fm.required || fm.desired || fm.wanted || !fm.include || fm.d.excluded != nil { + if fm.required || fm.desired || fm.wanted || !fm.include || fm.d.excluded != nil || fm.cluster != 0 { continue } for _, dep := range fm.d.usedBy { diff --git a/regressions_test.go b/regressions_test.go index 03a2054..b5fffb4 100644 --- a/regressions_test.go +++ b/regressions_test.go @@ -912,3 +912,44 @@ func TestRegression9(t *testing.T) { // invoker() }) } + +func TestClusterRegression(t *testing.T) { + genTest := func(clustering bool) func(*testing.T) { + return func(t *testing.T) { + var nn, nd, ndd, tc bool + seq1 := []interface{}{ + Provide("needed", func() string { return "foo" }), + Provide("not-needed", func() int64 { + nn = true + return 0 + }), + } + seq2 := []interface{}{ + Provide("normally-desired", func(string) { + nd = true + }), + Provide("normally-desired too", func(string) {}), + } + seq3 := []interface{}{ + Provide("normally-desired in degenerate cluster", func(string) { + ndd = true + }), + } + test := func(s string) { + assert.Equal(t, "foo", s) + tc = true + } + if clustering { + require.NoError(t, Run(t.Name(), Cluster("s1", seq1...), Cluster("s2", seq2...), Cluster("s3", seq3...), test)) + } else { + require.NoError(t, Run(t.Name(), Sequence("s1", seq1...), Sequence("s2", seq2...), Sequence("s3", seq3...), test)) + } + require.True(t, tc, "tc") + assert.Equal(t, clustering, nn, "nn") + assert.Equal(t, !clustering, nd, "nd") + assert.True(t, ndd, "ndd") + } + } + namedWrapTest(t, " cluster", genTest(true)) + namedWrapTest(t, " sequence", genTest(false)) +}