Skip to content

Commit

Permalink
fix: Cluster() did not work as documented (#67)
Browse files Browse the repository at this point in the history
  • Loading branch information
muir authored Apr 18, 2023
1 parent 2586205 commit 8f646e9
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 26 deletions.
14 changes: 10 additions & 4 deletions api.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,18 @@ var clusterId int32 = 1
// be included at all. It also downgrades providers that are in the
// cluster that would normally be considered desired because they don't
// return anything and aren't wrappers: they're no longer automatically
// considered desired.
// considered desired because doing so would imply the entire Cluster is
// is desired.
//
// A "Cluster" with only one member is not really a cluster and will
// not be treated as a cluster.
func Cluster(name string, providers ...interface{}) *Collection {
c := newCollection(name, providers...)
id := atomic.AddInt32(&clusterId, 1)
for _, fm := range c.contents {
fm.cluster = id
if len(providers) > 1 {
id := atomic.AddInt32(&clusterId, 1)
for _, fm := range c.contents {
fm.cluster = id
}
}
return c
}
Expand Down
8 changes: 6 additions & 2 deletions debug_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,12 @@ func debugOff() {
}

func wrapTest(t *testing.T, inner func(*testing.T)) {
if !t.Run("1st attempt", func(t *testing.T) { inner(t) }) {
t.Run("2nd attempt", func(t *testing.T) {
namedWrapTest(t, "", inner)
}

func namedWrapTest(t *testing.T, name string, inner func(*testing.T)) {
if !t.Run("1st attempt"+name, func(t *testing.T) { inner(t) }) {
t.Run("2nd attempt"+name, func(t *testing.T) {
debugOn(t)
defer debugOff()
inner(t)
Expand Down
5 changes: 5 additions & 0 deletions example_cluster_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ func ExampleCluster() {
func(s string) int32 {
return int32(len(s))
},
func() int64 {
fmt.Println("included even though no consumer")
return 0
},
func(i int32) {
fmt.Println("auto-desired in 1st cluster")
},
Expand Down Expand Up @@ -47,6 +51,7 @@ func ExampleCluster() {
},
)
// Output: no need for data from clusters
// included even though no consumer
// auto-desired in 1st cluster
// auto-desired in 2nd cluster
// got value that needed both chains - 28
Expand Down
50 changes: 30 additions & 20 deletions include.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,18 +75,8 @@ func computeDependenciesAndInclusion(funcs []*provider, initF *provider) ([]*pro
fm.chainPosition = i
}
debugln("initial set of functions")
clusterLeaders := make(map[int32]*provider)
for _, fm := range funcs {
debugf("\t%s", fm)
if fm.cluster != 0 {
if leader, ok := clusterLeaders[fm.cluster]; ok {
leader.d.clusterMembers = append(leader.d.clusterMembers, fm)
fm.d.clusterMembers = nil
} else {
clusterLeaders[fm.cluster] = fm
fm.d.clusterMembers = []*provider{fm}
}
}
fm.d.mustConsumeFlow = [lastFlowType]bool{}
if fm.mustConsume {
fm.d.mustConsumeFlow[outputParams] = true
Expand Down Expand Up @@ -127,15 +117,35 @@ func computeDependenciesAndInclusion(funcs []*provider, initF *provider) ([]*pro
}
}

// We process cluster membership after we've determined
// funcs that cannot be included so that we do not include
// in clusters anything that cannot be included at all.
clusterLeaders := make(map[int32]*provider) // leaders are chosen arbitrarily
for _, fm := range funcs {
if fm.cluster == 0 || fm.d.excluded != nil {
continue
}
if leader, ok := clusterLeaders[fm.cluster]; ok {
leader.d.clusterMembers = append(leader.d.clusterMembers, fm)
fm.d.clusterMembers = nil
} else {
clusterLeaders[fm.cluster] = fm
fm.d.clusterMembers = []*provider{fm}
}
if !fm.required && !fm.desired && fm.wanted {
fm.d.wantedInCluster = true
}
}

debugln("eliminate unused providers")
eliminateUnused(funcs)
eliminateUnused(funcs) // xxx

tryWithout := func(without ...*provider) bool {
tryWithout := func(without ...*provider) {
if len(without) == 1 {
if without[0].wanted && without[0].d.wantedInCluster {
// working around a bug: don't try to eliminate single
// wanted functions from clusters
return false
return
}
debugf("check chain validity, excluding %s", without[0])
} else {
Expand Down Expand Up @@ -166,20 +176,20 @@ func computeDependenciesAndInclusion(funcs []*provider, initF *provider) ([]*pro
}
}
}
return err == nil
}

debugln("attempt to eliminate additional providers")
for _, fm := range proposeEliminations(funcs) {
if fm.d.excluded != nil {
continue
}
if fm.d.clusterMembers != nil {
if tryWithout(fm.d.clusterMembers...) {
continue
if fm.cluster != 0 {
if fm.d.clusterMembers != nil {
tryWithout(fm.d.clusterMembers...)
}
} else {
tryWithout(fm)
}
tryWithout(fm)
}

debugln("final set of functions")
Expand Down Expand Up @@ -266,7 +276,7 @@ func checkFlows(funcs []*provider, numFuncs int, canRemoveDesired bool) error {
for tc, err := range errors {
fm.cannotInclude = err
redo = append(redo, fm)
debugf("\t\trequire error on %s %s: %s", param, tc, err)
debugf("\t\trequire error on %d %s: %s", param, tc, err)
continue Todo
}
}
Expand Down Expand Up @@ -439,7 +449,7 @@ PostCheck:
for len(check) > 0 {
var fm *provider
fm, check = check[0], check[1:]
if fm.required || fm.desired || fm.wanted || !fm.include || fm.d.excluded != nil {
if fm.required || fm.desired || fm.wanted || !fm.include || fm.d.excluded != nil || fm.cluster != 0 {
continue
}
for _, dep := range fm.d.usedBy {
Expand Down
41 changes: 41 additions & 0 deletions regressions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -912,3 +912,44 @@ func TestRegression9(t *testing.T) {
// invoker()
})
}

func TestClusterRegression(t *testing.T) {
genTest := func(clustering bool) func(*testing.T) {
return func(t *testing.T) {
var nn, nd, ndd, tc bool
seq1 := []interface{}{
Provide("needed", func() string { return "foo" }),
Provide("not-needed", func() int64 {
nn = true
return 0
}),
}
seq2 := []interface{}{
Provide("normally-desired", func(string) {
nd = true
}),
Provide("normally-desired too", func(string) {}),
}
seq3 := []interface{}{
Provide("normally-desired in degenerate cluster", func(string) {
ndd = true
}),
}
test := func(s string) {
assert.Equal(t, "foo", s)
tc = true
}
if clustering {
require.NoError(t, Run(t.Name(), Cluster("s1", seq1...), Cluster("s2", seq2...), Cluster("s3", seq3...), test))
} else {
require.NoError(t, Run(t.Name(), Sequence("s1", seq1...), Sequence("s2", seq2...), Sequence("s3", seq3...), test))
}
require.True(t, tc, "tc")
assert.Equal(t, clustering, nn, "nn")
assert.Equal(t, !clustering, nd, "nd")
assert.True(t, ndd, "ndd")
}
}
namedWrapTest(t, " cluster", genTest(true))
namedWrapTest(t, " sequence", genTest(false))
}

0 comments on commit 8f646e9

Please sign in to comment.