From 8bc25c9210b38416a2e2809cbfd842920f18ebbe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Flc=E3=82=9B?= Date: Sun, 17 Mar 2024 17:13:45 +0800 Subject: [PATCH] refactor(contexts): I'm not sure how to write it (#166) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * refactor(provider): rename context to provider Signed-off-by: Flc゛ * refactor(provider): split to contexts and provider Signed-off-by: Flc゛ * refactor(provider): split to contexts and provider Signed-off-by: Flc゛ --------- Signed-off-by: Flc゛ --- contexts/contexts.go | 33 +++++++++++++++++ contexts/contexts_test.go | 78 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 111 insertions(+) create mode 100644 contexts/contexts.go create mode 100644 contexts/contexts_test.go diff --git a/contexts/contexts.go b/contexts/contexts.go new file mode 100644 index 00000000..00d5cf3d --- /dev/null +++ b/contexts/contexts.go @@ -0,0 +1,33 @@ +package contexts + +import ( + "context" +) + +type Func func(ctx context.Context) (context.Context, error) + +// Pipe returns a Provider that chains the provided Providers. +func Pipe(ctx context.Context, fns ...Func) (context.Context, error) { + var err error + for _, fn := range fns { + if fn != nil { + if ctx, err = fn(ctx); err != nil { + return ctx, err + } + } + } + return ctx, nil +} + +// Chain is a reverse Pipe. +func Chain(ctx context.Context, fns ...Func) (context.Context, error) { + var err error + for i := len(fns) - 1; i >= 0; i-- { + if fns[i] != nil { + if ctx, err = fns[i](ctx); err != nil { + return ctx, err + } + } + } + return ctx, nil +} diff --git a/contexts/contexts_test.go b/contexts/contexts_test.go new file mode 100644 index 00000000..06c0b930 --- /dev/null +++ b/contexts/contexts_test.go @@ -0,0 +1,78 @@ +package contexts + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" +) + +type ( + mockProviderStruct1 struct{} + mockProviderStruct2 struct{} + mockProviderStruct3 struct{} +) + +var result chan string + +var ( + mockProvider1 = func(ctx context.Context) (context.Context, error) { + result <- "mockProvider1" + return context.WithValue(ctx, mockProviderStruct1{}, "mockProvider1"), nil + } + + mockProvider2 = func(ctx context.Context) (context.Context, error) { + result <- "mockProvider2" + return context.WithValue(ctx, mockProviderStruct2{}, "mockProvider2"), nil + } + + mockProvider3 = func(ctx context.Context) (context.Context, error) { + result <- "mockProvider3" + return ctx, errors.New("mockProvider3") + } +) + +func TestPipe(t *testing.T) { + result = make(chan string, 2) + ctx1, err1 := Pipe( + context.Background(), + mockProvider1, mockProvider2, + ) + assert.NoError(t, err1) + assert.Equal(t, "mockProvider1", ctx1.Value(mockProviderStruct1{})) + assert.Equal(t, "mockProvider2", ctx1.Value(mockProviderStruct2{})) + assert.Equal(t, "mockProvider1", <-result) + assert.Equal(t, "mockProvider2", <-result) + + ctx2, err2 := Pipe( + context.Background(), + mockProvider1, mockProvider3, + ) + assert.Error(t, err2) + assert.NotNil(t, ctx2) + assert.Equal(t, "mockProvider1", ctx2.Value(mockProviderStruct1{})) + assert.Nil(t, ctx2.Value(mockProviderStruct3{})) + assert.Equal(t, "mockProvider1", <-result) +} + +func TestChain(t *testing.T) { + result = make(chan string, 2) + ctx1, err1 := Chain( + context.Background(), + mockProvider1, mockProvider2, + ) + assert.NoError(t, err1) + assert.Equal(t, "mockProvider1", ctx1.Value(mockProviderStruct1{})) + assert.Equal(t, "mockProvider2", ctx1.Value(mockProviderStruct2{})) + assert.Equal(t, "mockProvider2", <-result) + assert.Equal(t, "mockProvider1", <-result) + + ctx2, err2 := Chain( + context.Background(), + mockProvider3, mockProvider1, + ) + assert.Error(t, err2) + assert.Equal(t, "mockProvider1", ctx2.Value(mockProviderStruct1{})) + assert.Equal(t, "mockProvider1", <-result) +}