From fd0a51248cefeb1ca0dd9bd51993f7148c6c6b4c Mon Sep 17 00:00:00 2001 From: wangjin Date: Sun, 18 Aug 2024 18:21:44 +0800 Subject: [PATCH] add op.Result function, add op tests --- op/op.go | 8 ++++ op/op_test.go | 103 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 111 insertions(+) diff --git a/op/op.go b/op/op.go index d19aa65..8f9c2b9 100644 --- a/op/op.go +++ b/op/op.go @@ -162,3 +162,11 @@ func Identity[T any](v T) func() T { return v } } + +// Result returns err if it is not nil, otherwise it returns value. +func Result(err error, value any) any { + if err != nil { + return err + } + return value +} diff --git a/op/op_test.go b/op/op_test.go index 7eef9d3..eadfb03 100644 --- a/op/op_test.go +++ b/op/op_test.go @@ -240,6 +240,109 @@ func TestAddr(t *testing.T) { } } +func TestMust(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Errorf("Must did not panic") + } + }() + op.Must(fmt.Errorf("error")) +} + +func TestMustValue(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Errorf("Must did not panic") + } + }() + op.MustValue(42, fmt.Errorf("error")) +} + +func TestMustValue_no_panic(t *testing.T) { + if got := op.MustValue(42, nil); got != 42 { + t.Errorf("MustValue(42, nil) = %v, want 42", got) + } +} + +func TestReverseCompare(t *testing.T) { + cmp := func(x, y int) int { + if x < y { + return -1 + } + if x > y { + return 1 + } + return 0 + } + revCmp := op.ReverseCompare(cmp) + + if got := revCmp(1, 2); got != 1 { + t.Errorf("ReverseCompare(cmp)(1, 2) = %v, want 1", got) + } + if got := revCmp(2, 1); got != -1 { + t.Errorf("ReverseCompare(cmp)(2, 1) = %v, want -1", got) + } + if got := revCmp(1, 1); got != 0 { + t.Errorf("ReverseCompare(cmp)(1, 1) = %v, want 0", got) + } +} + +func TestZero(t *testing.T) { + if zero := op.Zero[bool](); zero != false { + t.Errorf("Zero[bool]() = %v, want false", zero) + } + if zero := op.Zero[int](); zero != 0 { + t.Errorf("Zero[int]() = %v, want 0", zero) + } + if zero := op.Zero[string](); zero != "" { + t.Errorf("Zero[string]() = %v, want \"\"", zero) + } + if zero := op.Zero[float64](); zero != 0.0 { + t.Errorf("Zero[float64]() = %v, want 0.0", zero) + } + if zero := op.Zero[error](); zero != nil { + t.Errorf("Zero[error]() = %v, want nil", zero) + } +} + +func TestIdentity(t *testing.T) { + x := 42 + if got := op.Identity(x)(); got != x { + t.Errorf("Identity(%v)() = %v, want %v", x, got, x) + } + e := fmt.Errorf("error") + if got := op.Identity(e)(); got != e { + t.Errorf("Identity(%v)() = %v, want %v", e, got, e) + } +} + +func TestResult(t *testing.T) { + var err = fmt.Errorf("error") + tests := []struct { + name string + a error + b any + want any + }{ + {"nil", nil, 1, 1}, + {"error", err, 1, err}, + {"nil_nil", nil, nil, nil}, + {"error_nil", err, nil, err}, + {"nil_error", nil, err, err}, + {"nil_string", nil, "string", "string"}, + {"error_string", err, "string", err}, + {"nil_float", nil, 1.0, 1.0}, + {"error_float", err, 1.0, err}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := op.Result(tt.a, tt.b); got != tt.want { + t.Errorf("Result(%v, %v) = %v, want %v", tt.a, tt.b, got, tt.want) + } + }) + } +} + func ExampleOr() { fmt.Println(op.Or(0, 1)) fmt.Println(op.Or(2, 3))