From 76afc495feccd060663dc0e502670eb949a8793a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Flc=E3=82=9B?= Date: Thu, 14 Mar 2024 22:47:19 +0800 Subject: [PATCH] refactor(context): Rename `Chain` to `Pipe` and add `Chain` (#153) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * refactor(context): Rename `Chain` to `Pipe` and add `Chain` Signed-off-by: Flc゛ * refactor(context): Rename `Chain` to `Pipe` and add `Chain` Signed-off-by: Flc゛ --------- Signed-off-by: Flc゛ --- context/context.go | 16 +++++++++- context/context_test.go | 68 +++++++++++++++++++++++++++++------------ 2 files changed, 64 insertions(+), 20 deletions(-) diff --git a/context/context.go b/context/context.go index 088fc33d..08c67934 100644 --- a/context/context.go +++ b/context/context.go @@ -6,7 +6,8 @@ import ( type Provider func(ctx context.Context) (context.Context, error) -func Chain(ctx context.Context, providers ...Provider) (context.Context, error) { +// Pipe returns a Provider that chains the provided Providers. +func Pipe(ctx context.Context, providers ...Provider) (context.Context, error) { var err error for _, provider := range providers { if provider != nil { @@ -17,3 +18,16 @@ func Chain(ctx context.Context, providers ...Provider) (context.Context, error) } return ctx, nil } + +// Chain is a reverse Pipe. +func Chain(ctx context.Context, providers ...Provider) (context.Context, error) { + var err error + for i := len(providers) - 1; i >= 0; i-- { + if providers[i] != nil { + if ctx, err = providers[i](ctx); err != nil { + return ctx, err + } + } + } + return ctx, nil +} diff --git a/context/context_test.go b/context/context_test.go index 733debc5..8df367e8 100644 --- a/context/context_test.go +++ b/context/context_test.go @@ -8,36 +8,44 @@ import ( "github.com/stretchr/testify/assert" ) -func TestContext(t *testing.T) { - type ( - mockProviderStruct1 struct{} - mockProviderStruct2 struct{} - mockProviderStruct3 struct{} - ) +type ( + mockProviderStruct1 struct{} + mockProviderStruct2 struct{} + mockProviderStruct3 struct{} +) - var ( - mockProvider1 = func(ctx context.Context) (context.Context, error) { - return context.WithValue(ctx, mockProviderStruct1{}, "mockProvider1"), nil - } +var result chan string - mockProvider2 = func(ctx context.Context) (context.Context, error) { - return context.WithValue(ctx, mockProviderStruct2{}, "mockProvider2"), nil - } +var ( + mockProvider1 = func(ctx context.Context) (context.Context, error) { + result <- "mockProvider1" + return context.WithValue(ctx, mockProviderStruct1{}, "mockProvider1"), nil + } - mockProvider3 = func(ctx context.Context) (context.Context, error) { - return ctx, errors.New("mockProvider3") - } - ) + mockProvider2 = func(ctx context.Context) (context.Context, error) { + result <- "mockProvider2" + return context.WithValue(ctx, mockProviderStruct2{}, "mockProvider2"), nil + } - ctx1, err1 := Chain( + mockProvider3 = func(ctx context.Context) (context.Context, error) { + result <- "mockProvider3" + return ctx, errors.New("mockProvider3") + } +) + +func TestPipe(t *testing.T) { + result = make(chan string, 2) + ctx1, err1 := Pipe( context.Background(), mockProvider1, mockProvider2, ) assert.NoError(t, err1) assert.Equal(t, "mockProvider1", ctx1.Value(mockProviderStruct1{})) assert.Equal(t, "mockProvider2", ctx1.Value(mockProviderStruct2{})) + assert.Equal(t, "mockProvider1", <-result) + assert.Equal(t, "mockProvider2", <-result) - ctx2, err2 := Chain( + ctx2, err2 := Pipe( context.Background(), mockProvider1, mockProvider3, ) @@ -45,4 +53,26 @@ func TestContext(t *testing.T) { assert.NotNil(t, ctx2) assert.Equal(t, "mockProvider1", ctx2.Value(mockProviderStruct1{})) assert.Nil(t, ctx2.Value(mockProviderStruct3{})) + assert.Equal(t, "mockProvider1", <-result) +} + +func TestChain(t *testing.T) { + result = make(chan string, 2) + ctx1, err1 := Chain( + context.Background(), + mockProvider1, mockProvider2, + ) + assert.NoError(t, err1) + assert.Equal(t, "mockProvider1", ctx1.Value(mockProviderStruct1{})) + assert.Equal(t, "mockProvider2", ctx1.Value(mockProviderStruct2{})) + assert.Equal(t, "mockProvider2", <-result) + assert.Equal(t, "mockProvider1", <-result) + + ctx2, err2 := Chain( + context.Background(), + mockProvider3, mockProvider1, + ) + assert.Error(t, err2) + assert.Equal(t, "mockProvider1", ctx2.Value(mockProviderStruct1{})) + assert.Equal(t, "mockProvider1", <-result) }