From 599fed628b901b2f9bd23d958b32e692fbdea7fa Mon Sep 17 00:00:00 2001 From: Yonghwan SO Date: Sat, 19 Nov 2022 15:13:27 +0900 Subject: [PATCH] fixed empty slice bug on slices.Float --- slices/float.go | 25 ++++++++++++++++++------- slices/float_test.go | 30 ++++++++++++++++++++++++++++++ slices_test.go | 2 ++ 3 files changed, 50 insertions(+), 7 deletions(-) create mode 100644 slices/float_test.go diff --git a/slices/float.go b/slices/float.go index 6c6e41b2f..18dfa1306 100644 --- a/slices/float.go +++ b/slices/float.go @@ -27,8 +27,9 @@ func (f *Float) Scan(src interface{}) error { default: return fmt.Errorf("scan source was not []byte nor string but %T", src) } - *f = strToFloat(str) - return nil + v, err := strToFloat(str) + *f = v + return err } // Value implements the driver.Valuer interface. @@ -56,12 +57,22 @@ func (f *Float) UnmarshalText(text []byte) error { return nil } -func strToFloat(s string) []float64 { +func strToFloat(s string) ([]float64, error) { r := strings.Trim(s, "{}") a := make([]float64, 0, 10) - for _, t := range strings.Split(r, ",") { - i, _ := strconv.ParseFloat(t, 64) - a = append(a, i) + + elems := strings.Split(r, ",") + if len(elems) == 1 && elems[0] == "" { + return a, nil } - return a + + for _, t := range elems { + f, err := strconv.ParseFloat(t, 64) + if err != nil { + return nil, err + } + a = append(a, f) + } + + return a, nil } diff --git a/slices/float_test.go b/slices/float_test.go new file mode 100644 index 000000000..0b64a59d6 --- /dev/null +++ b/slices/float_test.go @@ -0,0 +1,30 @@ +package slices + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func Test_Float_Scan(t *testing.T) { + r := require.New(t) + t.Run("empty slice", func(t *testing.T) { + in := "{}" + v := &Float{} + r.NoError(v.Scan(in)) + r.Len(*v, 0) + }) + + t.Run("non-empty slice", func(t *testing.T) { + in := "{3.14,9.999}" + v := &Float{} + r.NoError(v.Scan(in)) + r.Equal([]float64(*v), []float64{3.14, 9.999}) + }) + + t.Run("invalid entry", func(t *testing.T) { + in := "{44,word}" + v := &Float{} + r.Error(v.Scan(in)) + }) +} diff --git a/slices_test.go b/slices_test.go index 3159ea872..83e63ed09 100644 --- a/slices_test.go +++ b/slices_test.go @@ -44,6 +44,7 @@ func (s *PostgreSQLSuite) Test_Int() { err = tx.Reload(c) r.NoError(err) r.Equal(slices.Int{1, 2, 3}, c.Int) + r.Equal(slices.Float{}, c.Float) }) } @@ -59,6 +60,7 @@ func (s *PostgreSQLSuite) Test_Float() { err = tx.Reload(c) r.NoError(err) + r.Equal(slices.Int{}, c.Int) r.Equal(slices.Float{1.0, 2.1, 3.2}, c.Float) }) }