diff --git a/internal/internalpipe/any.go b/internal/internalpipe/any.go index da16eff..1892c05 100644 --- a/internal/internalpipe/any.go +++ b/internal/internalpipe/any.go @@ -2,12 +2,18 @@ package internalpipe import "sync" -const infiniteLenStep = 1 << 15 +const hugeLenStep = 1 << 15 func anySingleThread[T any](lenSet bool, limit int, fn GeneratorFn[T]) *T { - var obj *T - var skipped bool - for i := 0; (!lenSet && i >= 0) || (i < limit); i++ { + var ( + obj *T + skipped bool + ) + check := func(i int) bool { return i < limit } + if !lenSet { + check = func(i int) bool { return i > -1 && i < limit } + } + for i := 0; check(i); i++ { if obj, skipped = fn(i); !skipped { return obj } @@ -23,30 +29,27 @@ func (p Pipe[T]) Any() *T { return anySingleThread(lenSet, limit, p.Fn) } - step := infiniteLenStep + step := hugeLenStep if lenSet { step = max(divUp(limit, p.GoroutinesCnt), 1) } + var ( - res = make(chan *T) - // if p.len is not set, we need tickets to control the amount of goroutines - tickets = genTickets(p.GoroutinesCnt) + resSet bool + resCh = make(chan *T, 1) + mx sync.Mutex - done = make(chan struct{}) - wg sync.WaitGroup + tickets = genTickets(p.GoroutinesCnt) + wg sync.WaitGroup ) - if !lenSet { - step = infiniteLenStep - } - + defer close(resCh) setObj := func(obj *T) { - select { - case <-done: - return - default: - close(done) - res <- obj + mx.Lock() + if !resSet { + resSet = true + resCh <- obj } + mx.Unlock() } go func() { @@ -54,22 +57,23 @@ func (p Pipe[T]) Any() *T { for i := 0; i >= 0 && (!lenSet || i < limit); i += step { wg.Add(1) <-tickets + go func(lf, rg int) { defer func() { - wg.Done() tickets <- struct{}{} + wg.Done() }() - // accounting int owerflow case with max(rg, 0) + // int owerflow case rg = max(rg, 0) if lenSet { rg = min(rg, limit) } for j := lf; j < rg; j++ { - select { - case <-done: - return - default: + mx.Lock() + rs := resSet + mx.Unlock() + if !rs { obj, skipped := p.Fn(j) if !skipped { setObj(obj) @@ -83,8 +87,9 @@ func (p Pipe[T]) Any() *T { go func() { wg.Wait() setObj(nil) + defer close(tickets) }() }() - return <-res + return <-resCh } diff --git a/internal/internalpipe/any_test.go b/internal/internalpipe/any_test.go index 4a0a5b1..c2a5dac 100644 --- a/internal/internalpipe/any_test.go +++ b/internal/internalpipe/any_test.go @@ -5,8 +5,6 @@ import ( "testing" "github.com/stretchr/testify/require" - - "github.com/koss-null/funcfrog/internal/primitive/pointer" ) var ( @@ -28,48 +26,60 @@ func TestAny(t *testing.T) { t.Parallel() t.Run("Single thread no limit", func(t *testing.T) { + t.Parallel() + p := Func(func(i int) (float64, bool) { - return a100k[i], a100k[i] <= 90_000.0 + return a100k[i], a100k[i] > 90_000.0 }) s := p.Any() require.NotNil(t, s) - require.Greater(t, 90_000.0, *s) + require.Greater(t, *s, 90_000.0) }) - t.Run("Seven thread no limit", func(t *testing.T) { + t.Run("Single thread limit", func(t *testing.T) { + t.Parallel() + p := Func(func(i int) (float64, bool) { - if i >= len(a100k) { - return 0., false - } - return a100k[i], a100k[i] <= 90_000.0 - }).Parallel(7) + return a100k[i], a100k[i] > 90_000.0 + }).Gen(len(a100k)) s := p.Any() require.NotNil(t, s) - require.Greater(t, 90_000.0, *s) + require.Greater(t, *s, 90_000.0) }) - t.Run("Single thread limit", func(t *testing.T) { + t.Run("Seven thread no limit", func(t *testing.T) { + t.Parallel() + p := Func(func(i int) (float64, bool) { - return a100k[i], a100k[i] <= 90_000.0 - }).Gen(len(a100k)) + if i >= len(a100k) { + return 0., false + } + return a100k[i], true + }). + Filter(func(x *float64) bool { return *x > 90_000. }). + Parallel(7) s := p.Any() require.NotNil(t, s) - require.Greater(t, 90_000.0, pointer.Deref(s)) + require.Greater(t, *s, 90_000.0) }) t.Run("Seven thread limit", func(t *testing.T) { + t.Parallel() + p := Func(func(i int) (float64, bool) { if i >= len(a100k) { return 0., false } - return a100k[i], a100k[i] <= 90_000.0 + return a100k[i], a100k[i] > 90_000.0 }).Gen(len(a100k)).Parallel(7) s := p.Any() require.NotNil(t, s) - require.Greater(t, 90_000.0, pointer.Deref(s)) + require.Greater(t, *s, 90_000.0) }) t.Run("Single thread NF limit", func(t *testing.T) { + t.Parallel() + p := Func(func(i int) (float64, bool) { return a100k[i], false }).Gen(len(a100k)) @@ -78,6 +88,8 @@ func TestAny(t *testing.T) { }) t.Run("Seven thread NF limit", func(t *testing.T) { + t.Parallel() + p := Func(func(i int) (float64, bool) { if i >= len(a100k) { return 0., false @@ -89,6 +101,8 @@ func TestAny(t *testing.T) { }) t.Run("Single thread bounded limit", func(t *testing.T) { + t.Parallel() + p := Func(func(i int) (float64, bool) { return a100k[i], false }).Gen(len(a100k)) @@ -97,6 +111,8 @@ func TestAny(t *testing.T) { }) t.Run("Seven thread bounded limit", func(t *testing.T) { + t.Parallel() + p := Func(func(i int) (float64, bool) { if i >= len(a100k) { return 0., false @@ -109,6 +125,8 @@ func TestAny(t *testing.T) { }) t.Run("Single thread bounded no limit", func(t *testing.T) { + t.Parallel() + p := Func(func(i int) (float64, bool) { if i >= len(a100k) { return 0., false @@ -121,6 +139,8 @@ func TestAny(t *testing.T) { }) t.Run("Seven thread bounded no limit", func(t *testing.T) { + t.Parallel() + p := Func(func(i int) (float64, bool) { if i >= len(a100k) { return 0., false @@ -133,6 +153,8 @@ func TestAny(t *testing.T) { }) t.Run("Ten thread bounded no limit filter", func(t *testing.T) { + t.Parallel() + p := Func(func(i int) (float64, bool) { if i >= len(a100k) { return 0., false