diff --git a/helper/helper.go b/helper/helper.go index 66049695..e42f1160 100644 --- a/helper/helper.go +++ b/helper/helper.go @@ -33,6 +33,18 @@ func Chain[T any](fns ...func(T) T) func(T) T { } } +func ChainWithErr[T any](fns ...func(T) (T, error)) func(T) (T, error) { + var err error + return func(v T) (T, error) { + for _, fn := range fns { + if v, err = fn(v); err != nil { + return v, err + } + } + return v, nil + } +} + func When[T any](value T, condition bool, callbacks ...func(T) T) T { if condition { return With(value, callbacks...) diff --git a/helper/helper_test.go b/helper/helper_test.go index fed23994..9cfa3ded 100644 --- a/helper/helper_test.go +++ b/helper/helper_test.go @@ -1,6 +1,7 @@ package helper import ( + "context" "testing" "github.com/stretchr/testify/assert" @@ -260,3 +261,72 @@ func TestScan_ComplexStruct(t *testing.T) { assert.Equal(t, "A1", b.Companies[0].Name) assert.Equal(t, "A2", b.Companies[1].Name) } + +func TestChainWithErr(t *testing.T) { + // chain functions + chain := ChainWithErr( + func(s string) (string, error) { + return s + "1", nil + }, + func(s string) (string, error) { + return s + "2", nil + }, + func(s string) (string, error) { + return s + "3", nil + }, + ) + + got, err := chain("0") + assert.Nil(t, err) + assert.Equal(t, "0123", got) + + // chain functions + chain2 := ChainWithErr( + func(foo *foo) (*foo, error) { + foo.Name = "bar" + return foo, nil + }, + func(foo *foo) (*foo, error) { + foo.Age = 18 + return foo, nil + }, + ) + + f := &foo{Name: "foo"} + assert.Equal(t, "foo", f.Name) + assert.Equal(t, 0, f.Age) + + got2, err := chain2(f) + assert.Nil(t, err) + assert.Equal(t, "bar", got2.Name) + assert.Equal(t, 18, got2.Age) + + // context + chain3 := ChainWithErr( + func(ctx context.Context) (context.Context, error) { + return context.WithValue(ctx, "foo", "bar"), nil + }, + func(ctx context.Context) (context.Context, error) { + return context.WithValue(ctx, "bar", "baz"), nil + }, + ) + + ctx, err := chain3(context.Background()) + assert.NoError(t, err) + assert.Equal(t, "bar", ctx.Value("foo")) + assert.Equal(t, "baz", ctx.Value("bar")) + + // context with error + chain4 := ChainWithErr( + func(ctx context.Context) (context.Context, error) { + return context.WithValue(ctx, "foo", "bar"), nil + }, + func(ctx context.Context) (context.Context, error) { + return nil, assert.AnError + }, + ) + + ctx, err = chain4(context.Background()) + assert.Error(t, err) + assert.Nil(t, ctx) +}