From 977e88749291d5422f0ed37cdb9df8123bc7e1d8 Mon Sep 17 00:00:00 2001 From: Greg Date: Mon, 12 Feb 2024 03:20:21 +0900 Subject: [PATCH 1/5] add missing int types from database/sql --- README.md | 56 ++++-- byte.go | 107 ++++++++++ go.mod | 3 + int.go | 58 ++---- int16.go | 107 ++++++++++ int32.go | 107 ++++++++++ int_test.go | 480 +++++++++++++++++++++++++++----------------- internal/int.go | 58 ++++++ internal/type.go | 7 + zero/byte.go | 119 +++++++++++ zero/int.go | 54 ++--- zero/int16.go | 119 +++++++++++ zero/int32.go | 119 +++++++++++ zero/int_test.go | 509 +++++++++++++++++++++++++++++------------------ 14 files changed, 1435 insertions(+), 468 deletions(-) create mode 100644 byte.go create mode 100644 go.mod create mode 100644 int16.go create mode 100644 int32.go create mode 100644 internal/int.go create mode 100644 internal/type.go create mode 100644 zero/byte.go create mode 100644 zero/int16.go create mode 100644 zero/int32.go diff --git a/README.md b/README.md index 233c71f..3ce6406 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@ -## null [![GoDoc](https://godoc.org/github.com/guregu/null?status.svg)](https://godoc.org/github.com/guregu/null) [![CircleCI](https://circleci.com/gh/guregu/null.svg?style=svg)](https://circleci.com/gh/guregu/null) -`import "gopkg.in/guregu/null.v4"` +## null [![GoDoc](https://godoc.org/github.com/guregu/null/v5?status.svg)](https://godoc.org/github.com/guregu/null/v5) +`import "github.com/guregu/null/v5"` null is a library with reasonable options for dealing with nullable SQL and JSON values @@ -9,20 +9,23 @@ Types in `null` will only be considered null on null input, and will JSON encode Types in `zero` are treated like zero values in Go: blank string input will produce a null `zero.String`, and null Strings will JSON encode to `""`. Zero values of these types will be considered null to SQL. If you need zero and null treated the same, use these. -All types implement `sql.Scanner` and `driver.Valuer`, so you can use this library in place of `sql.NullXXX`. -All types also implement: `encoding.TextMarshaler`, `encoding.TextUnmarshaler`, `json.Marshaler`, and `json.Unmarshaler`. A null object's `MarshalText` will return a blank string. +#### Interfaces -### null package +- All types implement `sql.Scanner` and `driver.Valuer`, so you can use this library in place of `sql.NullXXX`. +- All types also implement `json.Marshaler` and `json.Unmarshaler`, so you can marshal them to their native JSON representation. +- All non-generic types implement `encoding.TextMarshaler`, `encoding.TextUnmarshaler`. A null object's `MarshalText` will return a blank string. -`import "gopkg.in/guregu/null.v4"` +## null package + +`import "github.com/guregu/null/v5"` #### null.String Nullable string. Marshals to JSON null if SQL source data is null. Zero (blank) input will not produce a null String. -#### null.Int -Nullable int64. +#### null.Int, null.Int32, null.Int16, null.Byte +Nullable int64/int32/int16/byte. Marshals to JSON null if SQL source data is null. Zero input will not produce a null Int. @@ -40,17 +43,22 @@ Marshals to JSON null if SQL source data is null. False input will not produce a Marshals to JSON null if SQL source data is null. Zero input will not produce a null Time. -### zero package +#### null.Value +Generic nullable value. + +Will marshal to JSON null if SQL source data is null. Does not implement `encoding.TextMarshaler`. + +## zero package -`import "gopkg.in/guregu/null.v4/zero"` +`import "github.com/guregu/null/v5/zero"` #### zero.String Nullable string. Will marshal to a blank string if null. Blank string input produces a null String. Null values and zero values are considered equivalent. -#### zero.Int -Nullable int64. +#### zero.Int, zero.Int32, zero.Int16, zero.Byte +Nullable int64/int32/int16/byte. Will marshal to 0 if null. 0 produces a null Int. Null values and zero values are considered equivalent. @@ -65,17 +73,35 @@ Nullable bool. Will marshal to false if null. `false` produces a null Float. Null values and zero values are considered equivalent. #### zero.Time +Nullable time. Will marshal to the zero time if null. Uses `time.Time`'s marshaler. -### Can you add support for other types? +#### zero.Value[`T`] +Generic nullable value. + +Will marshal to zero value if null. `T` is required to be a comparable type. Does not implement `encoding.TextMarshaler`. + +## About + +### Q&A + +#### Can you add support for other types? This package is intentionally limited in scope. It will only support the types that [`driver.Value`](https://godoc.org/database/sql/driver#Value) supports. Feel free to fork this and add more types if you want. -### Can you add a feature that ____? +#### Can you add a feature that ____? This package isn't intended to be a catch-all data-wrangling package. It is essentially finished. If you have an idea for a new feature, feel free to open an issue to talk about it or fork this package, but don't expect this to do everything. ### Package history -*As of v4*, unmarshaling from JSON `sql.NullXXX` JSON objects (ex. `{"Int64": 123, "Valid": true}`) is no longer supported. It's unlikely many people used this, but if you need it, use v3. + +#### v5 +- Now a Go module under the path `github.com/guregu/null/v5` +- Added missing types from `database/sql`: `Int32, Int16, Byte` +- Added generic `Value[T]` embedding `sql.Null[T]` + +#### v4 +- Available at `gopkg.in/guregu/null.v4` +- Unmarshaling from JSON `sql.NullXXX` JSON objects (e.g. `{"Int64": 123, "Valid": true}`) is no longer supported. It's unlikely many people used this, but if you need it, use v3. ### Bugs `json`'s `",omitempty"` struct tag does not work correctly right now. It will never omit a null or empty String. This might be [fixed eventually](https://github.com/golang/go/issues/11939). diff --git a/byte.go b/byte.go new file mode 100644 index 0000000..c463b98 --- /dev/null +++ b/byte.go @@ -0,0 +1,107 @@ +package null + +import ( + "database/sql" + "strconv" + + "github.com/guregu/null/v5/internal" +) + +// Byte is an nullable byte. +// It does not consider zero values to be null. +// It will decode to null, not zero, if null. +type Byte struct { + sql.NullByte +} + +// NewByte creates a new Byte. +func NewByte(b byte, valid bool) Byte { + return Byte{ + NullByte: sql.NullByte{ + Byte: b, + Valid: valid, + }, + } +} + +// ByteFrom creates a new Byte that will always be valid. +func ByteFrom(b byte) Byte { + return NewByte(b, true) +} + +// ByteFromPtr creates a new Byte that be null if i is nil. +func ByteFromPtr(b *byte) Byte { + if b == nil { + return NewByte(0, false) + } + return NewByte(*b, true) +} + +// ValueOrZero returns the inner value if valid, otherwise zero. +func (b Byte) ValueOrZero() byte { + if !b.Valid { + return 0 + } + return b.Byte +} + +// UnmarshalJSON implements json.Unmarshaler. +// It supports number, string, and null input. +// 0 will not be considered a null Byte. +func (b *Byte) UnmarshalJSON(data []byte) error { + return internal.UnmarshalIntJSON(data, &b.Byte, &b.Valid, 8, strconv.ParseUint) +} + +// UnmarshalText implements encoding.TextUnmarshaler. +// It will unmarshal to a null Byte if the input is blank. +// It will return an error if the input is not an integer, blank, or "null". +func (b *Byte) UnmarshalText(text []byte) error { + return internal.UnmarshalIntText(text, &b.Byte, &b.Valid, 8, strconv.ParseUint) +} + +// MarshalJSON implements json.Marshaler. +// It will encode null if this Byte is null. +func (b Byte) MarshalJSON() ([]byte, error) { + if !b.Valid { + return []byte("null"), nil + } + return []byte(strconv.FormatInt(int64(b.Byte), 10)), nil +} + +// MarshalText implements encoding.TextMarshaler. +// It will encode a blank string if this Byte is null. +func (b Byte) MarshalText() ([]byte, error) { + if !b.Valid { + return []byte{}, nil + } + return []byte(strconv.FormatInt(int64(b.Byte), 10)), nil +} + +// SetValid changes this Byte's value and also sets it to be non-null. +func (b *Byte) SetValid(n byte) { + b.Byte = n + b.Valid = true +} + +// Ptr returns a pointer to this Byte's value, or a nil pointer if this Byte is null. +func (b Byte) Ptr() *byte { + if !b.Valid { + return nil + } + return &b.Byte +} + +// IsZero returns true for invalid Bytes, for future omitempty support (Go 1.4?) +// A non-null Byte with a 0 value will not be considered zero. +func (b Byte) IsZero() bool { + return !b.Valid +} + +// Equal returns true if both ints have the same value or are both null. +func (b Byte) Equal(other Byte) bool { + return b.Valid == other.Valid && (!b.Valid || b.Byte == other.Byte) +} + +func (b Byte) value() (int64, bool) { + return int64(b.Byte), b.Valid +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..a93d89d --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module github.com/guregu/null/v5 + +go 1.21.4 diff --git a/int.go b/int.go index adc36ae..fd91c77 100644 --- a/int.go +++ b/int.go @@ -1,12 +1,10 @@ package null import ( - "bytes" "database/sql" - "encoding/json" - "errors" - "fmt" "strconv" + + "github.com/guregu/null/v5/internal" ) // Int is an nullable int64. @@ -16,7 +14,10 @@ type Int struct { sql.NullInt64 } -// NewInt creates a new Int +// Int64 is an alias for Int. +type Int64 = Int + +// NewInt creates a new Int. func NewInt(i int64, valid bool) Int { return Int{ NullInt64: sql.NullInt64{ @@ -51,53 +52,14 @@ func (i Int) ValueOrZero() int64 { // It supports number, string, and null input. // 0 will not be considered a null Int. func (i *Int) UnmarshalJSON(data []byte) error { - if bytes.Equal(data, nullBytes) { - i.Valid = false - return nil - } - - if err := json.Unmarshal(data, &i.Int64); err != nil { - var typeError *json.UnmarshalTypeError - if errors.As(err, &typeError) { - // special case: accept string input - if typeError.Value != "string" { - return fmt.Errorf("null: JSON input is invalid type (need int or string): %w", err) - } - var str string - if err := json.Unmarshal(data, &str); err != nil { - return fmt.Errorf("null: couldn't unmarshal number string: %w", err) - } - n, err := strconv.ParseInt(str, 10, 64) - if err != nil { - return fmt.Errorf("null: couldn't convert string to int: %w", err) - } - i.Int64 = n - i.Valid = true - return nil - } - return fmt.Errorf("null: couldn't unmarshal JSON: %w", err) - } - - i.Valid = true - return nil + return internal.UnmarshalIntJSON(data, &i.Int64, &i.Valid, 64, strconv.ParseInt) } // UnmarshalText implements encoding.TextUnmarshaler. // It will unmarshal to a null Int if the input is blank. // It will return an error if the input is not an integer, blank, or "null". func (i *Int) UnmarshalText(text []byte) error { - str := string(text) - if str == "" || str == "null" { - i.Valid = false - return nil - } - var err error - i.Int64, err = strconv.ParseInt(string(text), 10, 64) - if err != nil { - return fmt.Errorf("null: couldn't unmarshal text: %w", err) - } - i.Valid = true - return nil + return internal.UnmarshalIntText(text, &i.Int64, &i.Valid, 64, strconv.ParseInt) } // MarshalJSON implements json.Marshaler. @@ -142,3 +104,7 @@ func (i Int) IsZero() bool { func (i Int) Equal(other Int) bool { return i.Valid == other.Valid && (!i.Valid || i.Int64 == other.Int64) } + +func (i Int) value() (int64, bool) { + return i.Int64, i.Valid +} diff --git a/int16.go b/int16.go new file mode 100644 index 0000000..75df88c --- /dev/null +++ b/int16.go @@ -0,0 +1,107 @@ +package null + +import ( + "database/sql" + "strconv" + + "github.com/guregu/null/v5/internal" +) + +// Int16 is an nullable int16. +// It does not consider zero values to be null. +// It will decode to null, not zero, if null. +type Int16 struct { + sql.NullInt16 +} + +// NewInt16 creates a new Int16. +func NewInt16(i int16, valid bool) Int16 { + return Int16{ + NullInt16: sql.NullInt16{ + Int16: i, + Valid: valid, + }, + } +} + +// Int16From creates a new Int16 that will always be valid. +func Int16From(i int16) Int16 { + return NewInt16(i, true) +} + +// Int16FromPtr creates a new Int16 that be null if i is nil. +func Int16FromPtr(i *int16) Int16 { + if i == nil { + return NewInt16(0, false) + } + return NewInt16(*i, true) +} + +// ValueOrZero returns the inner value if valid, otherwise zero. +func (i Int16) ValueOrZero() int16 { + if !i.Valid { + return 0 + } + return i.Int16 +} + +// UnmarshalJSON implements json.Unmarshaler. +// It supports number, string, and null input. +// 0 will not be considered a null Int16. +func (i *Int16) UnmarshalJSON(data []byte) error { + return internal.UnmarshalIntJSON(data, &i.Int16, &i.Valid, 16, strconv.ParseInt) +} + +// UnmarshalText implements encoding.TextUnmarshaler. +// It will unmarshal to a null Int16 if the input is blank. +// It will return an error if the input is not an integer, blank, or "null". +func (i *Int16) UnmarshalText(text []byte) error { + return internal.UnmarshalIntText(text, &i.Int16, &i.Valid, 16, strconv.ParseInt) +} + +// MarshalJSON implements json.Marshaler. +// It will encode null if this Int16 is null. +func (i Int16) MarshalJSON() ([]byte, error) { + if !i.Valid { + return []byte("null"), nil + } + return []byte(strconv.FormatInt(int64(i.Int16), 10)), nil +} + +// MarshalText implements encoding.TextMarshaler. +// It will encode a blank string if this Int16 is null. +func (i Int16) MarshalText() ([]byte, error) { + if !i.Valid { + return []byte{}, nil + } + return []byte(strconv.FormatInt(int64(i.Int16), 10)), nil +} + +// SetValid changes this Int16's value and also sets it to be non-null. +func (i *Int16) SetValid(n int16) { + i.Int16 = n + i.Valid = true +} + +// Ptr returns a pointer to this Int16's value, or a nil pointer if this Int16 is null. +func (i Int16) Ptr() *int16 { + if !i.Valid { + return nil + } + return &i.Int16 +} + +// IsZero returns true for invalid Int16s, for future omitempty support (Go 1.4?) +// A non-null Int16 with a 0 value will not be considered zero. +func (i Int16) IsZero() bool { + return !i.Valid +} + +// Equal returns true if both ints have the same value or are both null. +func (i Int16) Equal(other Int16) bool { + return i.Valid == other.Valid && (!i.Valid || i.Int16 == other.Int16) +} + +func (i Int16) value() (int64, bool) { + return int64(i.Int16), i.Valid +} diff --git a/int32.go b/int32.go new file mode 100644 index 0000000..26195ca --- /dev/null +++ b/int32.go @@ -0,0 +1,107 @@ +package null + +import ( + "database/sql" + "strconv" + + "github.com/guregu/null/v5/internal" +) + +// Int32 is an nullable int32. +// It does not consider zero values to be null. +// It will decode to null, not zero, if null. +type Int32 struct { + sql.NullInt32 +} + +// NewInt32 creates a new Int32. +func NewInt32(i int32, valid bool) Int32 { + return Int32{ + NullInt32: sql.NullInt32{ + Int32: i, + Valid: valid, + }, + } +} + +// Int32From creates a new Int32 that will always be valid. +func Int32From(i int32) Int32 { + return NewInt32(i, true) +} + +// Int32FromPtr creates a new Int32 that be null if i is nil. +func Int32FromPtr(i *int32) Int32 { + if i == nil { + return NewInt32(0, false) + } + return NewInt32(*i, true) +} + +// ValueOrZero returns the inner value if valid, otherwise zero. +func (i Int32) ValueOrZero() int32 { + if !i.Valid { + return 0 + } + return i.Int32 +} + +// UnmarshalJSON implements json.Unmarshaler. +// It supports number, string, and null input. +// 0 will not be considered a null Int32. +func (i *Int32) UnmarshalJSON(data []byte) error { + return internal.UnmarshalIntJSON(data, &i.Int32, &i.Valid, 32, strconv.ParseInt) +} + +// UnmarshalText implements encoding.TextUnmarshaler. +// It will unmarshal to a null Int32 if the input is blank. +// It will return an error if the input is not an integer, blank, or "null". +func (i *Int32) UnmarshalText(text []byte) error { + return internal.UnmarshalIntText(text, &i.Int32, &i.Valid, 32, strconv.ParseInt) +} + +// MarshalJSON implements json.Marshaler. +// It will encode null if this Int32 is null. +func (i Int32) MarshalJSON() ([]byte, error) { + if !i.Valid { + return []byte("null"), nil + } + return []byte(strconv.FormatInt(int64(i.Int32), 10)), nil +} + +// MarshalText implements encoding.TextMarshaler. +// It will encode a blank string if this Int32 is null. +func (i Int32) MarshalText() ([]byte, error) { + if !i.Valid { + return []byte{}, nil + } + return []byte(strconv.FormatInt(int64(i.Int32), 10)), nil +} + +// SetValid changes this Int32's value and also sets it to be non-null. +func (i *Int32) SetValid(n int32) { + i.Int32 = n + i.Valid = true +} + +// Ptr returns a pointer to this Int32's value, or a nil pointer if this Int32 is null. +func (i Int32) Ptr() *int32 { + if !i.Valid { + return nil + } + return &i.Int32 +} + +// IsZero returns true for invalid Int32s, for future omitempty support (Go 1.4?) +// A non-null Int32 with a 0 value will not be considered zero. +func (i Int32) IsZero() bool { + return !i.Valid +} + +// Equal returns true if both ints have the same value or are both null. +func (i Int32) Equal(other Int32) bool { + return i.Valid == other.Valid && (!i.Valid || i.Int32 == other.Int32) +} + +func (i Int32) value() (int64, bool) { + return int64(i.Int32), i.Valid +} diff --git a/int_test.go b/int_test.go index e545687..8b0e897 100644 --- a/int_test.go +++ b/int_test.go @@ -1,269 +1,393 @@ package null import ( + "encoding" "encoding/json" "errors" "math" "strconv" "testing" + + "github.com/guregu/null/v5/internal" ) var ( - intJSON = []byte(`12345`) - intStringJSON = []byte(`"12345"`) - nullIntJSON = []byte(`{"Int64":12345,"Valid":true}`) + intJSON = []byte(`123`) + intStringJSON = []byte(`"123"`) ) +type nullint interface { + Int | Int32 | Int16 | Byte + IsZero() bool + value() (int64, bool) +} + func TestIntFrom(t *testing.T) { - i := IntFrom(12345) - assertInt(t, i, "IntFrom()") + testIntFrom(t, IntFrom) + testIntFrom(t, Int32From) + testIntFrom(t, Int16From) + testIntFrom(t, ByteFrom) +} - zero := IntFrom(0) - if !zero.Valid { - t.Error("IntFrom(0)", "is invalid, but should be valid") - } +func testIntFrom[N nullint, V internal.Integer](t *testing.T, from func(V) N) { + t.Run(internal.TypeName[N](), func(t *testing.T) { + i := from(123) + assertInt(t, i, "from(123)") + + zero := from(0) + _, valid := zero.value() + if !valid { + t.Error("from(0)", "is invalid, but should be valid") + } + }) } func TestIntFromPtr(t *testing.T) { - n := int64(12345) - iptr := &n - i := IntFromPtr(iptr) - assertInt(t, i, "IntFromPtr()") - - null := IntFromPtr(nil) - assertNullInt(t, null, "IntFromPtr(nil)") + testIntFromPtr(t, IntFromPtr) + testIntFromPtr(t, Int32FromPtr) + testIntFromPtr(t, Int16FromPtr) + testIntFromPtr(t, ByteFromPtr) } -func TestUnmarshalInt(t *testing.T) { - var i Int - err := json.Unmarshal(intJSON, &i) - maybePanic(err) - assertInt(t, i, "int json") - - var si Int - err = json.Unmarshal(intStringJSON, &si) - maybePanic(err) - assertInt(t, si, "int string json") - - var ni Int - err = json.Unmarshal(nullIntJSON, &ni) - if err == nil { - panic("err should not be nill") - } - - var bi Int - err = json.Unmarshal(floatBlankJSON, &bi) - if err == nil { - panic("err should not be nill") - } +func testIntFromPtr[N nullint, V internal.Integer](t *testing.T, fromPtr func(*V) N) { + t.Run(internal.TypeName[N](), func(t *testing.T) { + n := V(123) + iptr := &n + i := fromPtr(iptr) + assertInt(t, i, "fromPtr()") - var null Int - err = json.Unmarshal(nullJSON, &null) - maybePanic(err) - assertNullInt(t, null, "null json") + null := fromPtr(nil) + assertNullInt(t, null, "fromPtr(nil)") + }) +} - var badType Int - err = json.Unmarshal(boolJSON, &badType) - if err == nil { - panic("err should not be nil") - } - assertNullInt(t, badType, "wrong type json") +func TestUnmarshalInt(t *testing.T) { + testUnmarshalInt[Int](t) + testUnmarshalInt[Int32](t) + testUnmarshalInt[Int16](t) + testUnmarshalInt[Byte](t) +} - var invalid Int - err = invalid.UnmarshalJSON(invalidJSON) - var syntaxError *json.SyntaxError - if !errors.As(err, &syntaxError) { - t.Errorf("expected wrapped json.SyntaxError, not %T", err) - } - assertNullInt(t, invalid, "invalid json") +func testUnmarshalInt[N nullint](t *testing.T) { + t.Run(internal.TypeName[N](), func(t *testing.T) { + var i N + err := json.Unmarshal(intJSON, &i) + maybePanic(err) + assertInt(t, i, "int json") + + var si N + err = json.Unmarshal(intStringJSON, &si) + maybePanic(err) + assertInt(t, si, "int string json") + + var bi N + err = json.Unmarshal(floatBlankJSON, &bi) + if err == nil { + panic("err should not be nill") + } + + var null N + err = json.Unmarshal(nullJSON, &null) + maybePanic(err) + assertNullInt(t, null, "null json") + + var badType N + err = json.Unmarshal(boolJSON, &badType) + if err == nil { + panic("err should not be nil") + } + assertNullInt(t, badType, "wrong type json") + + var invalid N + err = json.Unmarshal(invalidJSON, &invalid) + var syntaxError *json.SyntaxError + if !errors.As(err, &syntaxError) { + t.Errorf("expected wrapped json.SyntaxError, not %T", err) + } + assertNullInt(t, invalid, "invalid json") + }) } func TestUnmarshalNonIntegerNumber(t *testing.T) { var i Int err := json.Unmarshal(floatJSON, &i) if err == nil { - panic("err should be present; non-integer number coerced to int") + panic("err should be present; non-internal.Integer number coerced to int") } } -func TestUnmarshalInt64Overflow(t *testing.T) { - int64Overflow := uint64(math.MaxInt64) - - // Max int64 should decode successfully - var i Int - err := json.Unmarshal([]byte(strconv.FormatUint(int64Overflow, 10)), &i) - maybePanic(err) +func TestUnmarshalIntOverflow(t *testing.T) { + testUnmarshalIntOverflow[Int, int64](t, math.MaxInt64) + testUnmarshalIntOverflow[Int32, int32](t, math.MaxInt32) + testUnmarshalIntOverflow[Int16, int16](t, math.MaxInt16) + testUnmarshalIntOverflow[Byte, byte](t, math.MaxUint8) +} - // Attempt to overflow - int64Overflow++ - err = json.Unmarshal([]byte(strconv.FormatUint(int64Overflow, 10)), &i) - if err == nil { - panic("err should be present; decoded value overflows int64") - } +func testUnmarshalIntOverflow[N nullint, V internal.Integer](t *testing.T, max V) { + t.Run(internal.TypeName[N](), func(t *testing.T) { + overflow := uint64(max) + + // Max int64 should decode successfully + var i N + err := json.Unmarshal([]byte(strconv.FormatUint(overflow, 10)), &i) + maybePanic(err) + + // Attempt to overflow + overflow++ + err = json.Unmarshal([]byte(strconv.FormatUint(overflow, 10)), &i) + if err == nil { + t.Error("err should be present but isn't; decoded value overflows") + } + }) } func TestTextUnmarshalInt(t *testing.T) { - var i Int - err := i.UnmarshalText([]byte("12345")) - maybePanic(err) - assertInt(t, i, "UnmarshalText() int") - - var blank Int - err = blank.UnmarshalText([]byte("")) - maybePanic(err) - assertNullInt(t, blank, "UnmarshalText() empty int") - - var null Int - err = null.UnmarshalText([]byte("null")) - maybePanic(err) - assertNullInt(t, null, `UnmarshalText() "null"`) - - var invalid Int - err = invalid.UnmarshalText([]byte("hello world")) - if err == nil { - panic("expected error") - } + testTextUnmarshalInt(t, (*Int).UnmarshalText) + testTextUnmarshalInt(t, (*Int32).UnmarshalText) + testTextUnmarshalInt(t, (*Int16).UnmarshalText) + testTextUnmarshalInt(t, (*Byte).UnmarshalText) +} + +func testTextUnmarshalInt[N nullint](t *testing.T, unmarshal func(*N, []byte) error) { + t.Run(internal.TypeName[N](), func(t *testing.T) { + var i N + err := unmarshal(&i, []byte("123")) + maybePanic(err) + assertInt(t, i, "unmarshal int") + + var blank N + err = unmarshal(&blank, []byte("")) + maybePanic(err) + assertNullInt(t, blank, "unmarshal empty int") + + var null N + err = unmarshal(&null, []byte("null")) + maybePanic(err) + assertNullInt(t, null, `unmarshal "null"`) + + var invalid N + err = unmarshal(&invalid, []byte("hello world")) + if err == nil { + panic("expected error") + } + }) } func TestMarshalInt(t *testing.T) { - i := IntFrom(12345) - data, err := json.Marshal(i) - maybePanic(err) - assertJSONEquals(t, data, "12345", "non-empty json marshal") + testMarshalInt(t, NewInt) + testMarshalInt(t, NewInt32) + testMarshalInt(t, NewInt16) + testMarshalInt(t, NewByte) +} - // invalid values should be encoded as null - null := NewInt(0, false) - data, err = json.Marshal(null) - maybePanic(err) - assertJSONEquals(t, data, "null", "null json marshal") +func testMarshalInt[N interface{ ValueOrZero() V }, V internal.Integer](t *testing.T, newInt func(V, bool) N) { + t.Run(internal.TypeName[N](), func(t *testing.T) { + i := newInt(123, true) + data, err := json.Marshal(i) + maybePanic(err) + assertJSONEquals(t, data, "123", "non-empty json marshal") + + // invalid values should be encoded as null + null := newInt(0, false) + data, err = json.Marshal(null) + maybePanic(err) + assertJSONEquals(t, data, "null", "null json marshal") + }) } func TestMarshalIntText(t *testing.T) { - i := IntFrom(12345) - data, err := i.MarshalText() - maybePanic(err) - assertJSONEquals(t, data, "12345", "non-empty text marshal") + testMarshalIntText(t, NewInt) + testMarshalIntText(t, NewInt32) + testMarshalIntText(t, NewInt16) + testMarshalIntText(t, NewByte) +} - // invalid values should be encoded as null - null := NewInt(0, false) - data, err = null.MarshalText() - maybePanic(err) - assertJSONEquals(t, data, "", "null text marshal") +func testMarshalIntText[N encoding.TextMarshaler, V internal.Integer](t *testing.T, newInt func(V, bool) N) { + t.Run(internal.TypeName[N](), func(t *testing.T) { + i := newInt(123, true) + data, err := i.MarshalText() + maybePanic(err) + assertJSONEquals(t, data, "123", "non-empty text marshal") + + // invalid values should be encoded as null + null := newInt(0, false) + data, err = null.MarshalText() + maybePanic(err) + assertJSONEquals(t, data, "", "null text marshal") + }) } func TestIntPointer(t *testing.T) { - i := IntFrom(12345) - ptr := i.Ptr() - if *ptr != 12345 { - t.Errorf("bad %s int: %#v ≠ %d\n", "pointer", ptr, 12345) - } + testIntPointer(t, NewInt) + testIntPointer(t, NewInt32) + testIntPointer(t, NewInt16) + testIntPointer(t, NewByte) +} - null := NewInt(0, false) - ptr = null.Ptr() - if ptr != nil { - t.Errorf("bad %s int: %#v ≠ %s\n", "nil pointer", ptr, "nil") - } +func testIntPointer[N interface{ Ptr() *V }, V internal.Integer](t *testing.T, newInt func(V, bool) N) { + t.Run(internal.TypeName[N](), func(t *testing.T) { + i := newInt(123, true) + ptr := i.Ptr() + if *ptr != 123 { + t.Errorf("bad %s int: %#v ≠ %d\n", "pointer", ptr, 123) + } + + null := newInt(0, false) + ptr = null.Ptr() + if ptr != nil { + t.Errorf("bad %s int: %#v ≠ %s\n", "nil pointer", ptr, "nil") + } + }) } func TestIntIsZero(t *testing.T) { - i := IntFrom(12345) - if i.IsZero() { - t.Errorf("IsZero() should be false") - } - - null := NewInt(0, false) - if !null.IsZero() { - t.Errorf("IsZero() should be true") - } + testIntIsZero(t, NewInt) + testIntIsZero(t, NewInt32) + testIntIsZero(t, NewInt16) + testIntIsZero(t, NewByte) +} - zero := NewInt(0, true) - if zero.IsZero() { - t.Errorf("IsZero() should be false") - } +func testIntIsZero[N nullint, V internal.Integer](t *testing.T, newInt func(V, bool) N) { + t.Run(internal.TypeName[N](), func(t *testing.T) { + i := newInt(123, true) + if i.IsZero() { + t.Errorf("IsZero() should be false") + } + + null := newInt(0, false) + if !null.IsZero() { + t.Errorf("IsZero() should be true") + } + + zero := newInt(0, true) + if zero.IsZero() { + t.Errorf("IsZero() should be false") + } + }) } func TestIntSetValid(t *testing.T) { - change := NewInt(0, false) - assertNullInt(t, change, "SetValid()") - change.SetValid(12345) - assertInt(t, change, "SetValid()") + testIntSetValid(t, NewInt, (*Int).SetValid) + testIntSetValid(t, NewInt32, (*Int32).SetValid) + testIntSetValid(t, NewInt16, (*Int16).SetValid) + testIntSetValid(t, NewByte, (*Byte).SetValid) +} + +func testIntSetValid[N nullint, V internal.Integer](t *testing.T, newInt func(V, bool) N, setValid func(*N, V)) { + t.Run(internal.TypeName[N](), func(t *testing.T) { + change := newInt(0, false) + assertNullInt(t, change, "SetValid()") + setValid(&change, 123) + assertInt(t, change, "SetValid()") + }) } func TestIntScan(t *testing.T) { - var i Int - err := i.Scan(12345) - maybePanic(err) - assertInt(t, i, "scanned int") + testIntScan(t, (*Int).Scan) + testIntScan(t, (*Int32).Scan) + testIntScan(t, (*Int16).Scan) + testIntScan(t, (*Byte).Scan) +} - var null Int - err = null.Scan(nil) - maybePanic(err) - assertNullInt(t, null, "scanned null") +func testIntScan[N nullint](t *testing.T, scan func(*N, any) error) { + t.Run(internal.TypeName[N](), func(t *testing.T) { + var i N + err := scan(&i, 123) + maybePanic(err) + assertInt(t, i, "scanned int") + + var null N + err = scan(&null, nil) + maybePanic(err) + assertNullInt(t, null, "scanned null") + }) } func TestIntValueOrZero(t *testing.T) { - valid := NewInt(12345, true) - if valid.ValueOrZero() != 12345 { - t.Error("unexpected ValueOrZero", valid.ValueOrZero()) - } + testIntValueOrZero(t, NewInt) + testIntValueOrZero(t, NewInt32) + testIntValueOrZero(t, NewInt16) + testIntValueOrZero(t, NewByte) +} - invalid := NewInt(12345, false) - if invalid.ValueOrZero() != 0 { - t.Error("unexpected ValueOrZero", invalid.ValueOrZero()) - } +func testIntValueOrZero[N interface{ ValueOrZero() V }, V internal.Integer](t *testing.T, newInt func(V, bool) N) { + t.Run(internal.TypeName[N](), func(t *testing.T) { + valid := newInt(123, true) + if valid.ValueOrZero() != 123 { + t.Error("unexpected ValueOrZero", valid.ValueOrZero()) + } + + invalid := newInt(123, false) + if invalid.ValueOrZero() != 0 { + t.Error("unexpected ValueOrZero", invalid.ValueOrZero()) + } + }) } func TestIntEqual(t *testing.T) { - int1 := NewInt(10, false) - int2 := NewInt(10, false) - assertIntEqualIsTrue(t, int1, int2) + testIntEqual(t, NewInt) + testIntEqual(t, NewInt32) + testIntEqual(t, NewInt16) + testIntEqual(t, NewByte) +} - int1 = NewInt(10, false) - int2 = NewInt(20, false) - assertIntEqualIsTrue(t, int1, int2) +func testIntEqual[N interface{ Equal(N) bool }, V internal.Integer](t *testing.T, newInt func(V, bool) N) { + t.Run(internal.TypeName[N](), func(t *testing.T) { + int1 := newInt(10, false) + int2 := newInt(10, false) + assertIntEqualIsTrue(t, int1, int2) - int1 = NewInt(10, true) - int2 = NewInt(10, true) - assertIntEqualIsTrue(t, int1, int2) + int1 = newInt(10, false) + int2 = newInt(20, false) + assertIntEqualIsTrue(t, int1, int2) - int1 = NewInt(10, true) - int2 = NewInt(10, false) - assertIntEqualIsFalse(t, int1, int2) + int1 = newInt(10, true) + int2 = newInt(10, true) + assertIntEqualIsTrue(t, int1, int2) - int1 = NewInt(10, false) - int2 = NewInt(10, true) - assertIntEqualIsFalse(t, int1, int2) + int1 = newInt(10, true) + int2 = newInt(10, false) + assertIntEqualIsFalse(t, int1, int2) - int1 = NewInt(10, true) - int2 = NewInt(20, true) - assertIntEqualIsFalse(t, int1, int2) + int1 = newInt(10, false) + int2 = newInt(10, true) + assertIntEqualIsFalse(t, int1, int2) + + int1 = newInt(10, true) + int2 = newInt(20, true) + assertIntEqualIsFalse(t, int1, int2) + }) } -func assertInt(t *testing.T, i Int, from string) { - if i.Int64 != 12345 { - t.Errorf("bad %s int: %d ≠ %d\n", from, i.Int64, 12345) +func assertInt(t *testing.T, i interface{ value() (int64, bool) }, from string) { + t.Helper() + n, valid := i.value() + if n != 123 { + t.Errorf("bad %s int: %d ≠ %d\n", from, n, 123) } - if !i.Valid { + if !valid { t.Error(from, "is invalid, but should be valid") } } -func assertNullInt(t *testing.T, i Int, from string) { - if i.Valid { +func assertNullInt(t *testing.T, i interface{ value() (int64, bool) }, from string) { + t.Helper() + _, valid := i.value() + if valid { t.Error(from, "is valid, but should be invalid") } } -func assertIntEqualIsTrue(t *testing.T, a, b Int) { +func assertIntEqualIsTrue[N interface{ Equal(N) bool }](t *testing.T, a, b N) { t.Helper() if !a.Equal(b) { - t.Errorf("Equal() of Int{%v, Valid:%t} and Int{%v, Valid:%t} should return true", a.Int64, a.Valid, b.Int64, b.Valid) + t.Errorf("Equal() of %#v and %#v should return true", a, b) } } -func assertIntEqualIsFalse(t *testing.T, a, b Int) { +func assertIntEqualIsFalse[N interface{ Equal(N) bool }](t *testing.T, a, b N) { t.Helper() if a.Equal(b) { - t.Errorf("Equal() of Int{%v, Valid:%t} and Int{%v, Valid:%t} should return false", a.Int64, a.Valid, b.Int64, b.Valid) + t.Errorf("Equal() of %#v and %#v should return false", a, b) } } diff --git a/internal/int.go b/internal/int.go new file mode 100644 index 0000000..5267b58 --- /dev/null +++ b/internal/int.go @@ -0,0 +1,58 @@ +package internal + +import ( + "encoding/json" + "fmt" +) + +type Integer interface { + int64 | int32 | int16 | byte +} + +func UnmarshalIntJSON[T Integer, U int64 | uint64](data []byte, value *T, valid *bool, bits int, parse func(string, int, int) (U, error)) error { + if len(data) == 0 { + return fmt.Errorf("UnmarshalJSON: no data") + } + + switch data[0] { + case 'n': + *value = 0 + *valid = false + return nil + + case '"': + var str string + if err := json.Unmarshal(data, &str); err != nil { + return fmt.Errorf("null: couldn't unmarshal number string: %w", err) + } + n, err := parse(str, 10, bits) + if err != nil { + return fmt.Errorf("null: couldn't convert string to int: %w", err) + } + *value = T(n) + *valid = true + return nil + + default: + err := json.Unmarshal(data, value) + *valid = err == nil + return err + } +} + +func UnmarshalIntText[T Integer, U int64 | uint64](text []byte, value *T, valid *bool, bits int, parse func(string, int, int) (U, error)) error { + str := string(text) + if str == "" || str == "null" { + *value = 0 + *valid = false + return nil + } + n, err := parse(str, 10, bits) + *value = T(n) + if err != nil { + *valid = false + return fmt.Errorf("null: couldn't unmarshal text: %w", err) + } + *valid = true + return nil +} diff --git a/internal/type.go b/internal/type.go new file mode 100644 index 0000000..fde06b8 --- /dev/null +++ b/internal/type.go @@ -0,0 +1,7 @@ +package internal + +import "fmt" + +func TypeName[T any]() string { + return fmt.Sprintf("%T", *(new(T))) +} diff --git a/zero/byte.go b/zero/byte.go new file mode 100644 index 0000000..f911c6d --- /dev/null +++ b/zero/byte.go @@ -0,0 +1,119 @@ +package zero + +import ( + "database/sql" + "strconv" + + "github.com/guregu/null/v5/internal" +) + +// Byte is a nullable byte. +// JSON marshals to zero if null. +// Considered null to SQL if zero. +type Byte struct { + sql.NullByte +} + +// NewByte creates a new Byte +func NewByte(i byte, valid bool) Byte { + return Byte{ + NullByte: sql.NullByte{ + Byte: i, + Valid: valid, + }, + } +} + +// ByteFrom creates a new Byte that will be null if zero. +func ByteFrom(i byte) Byte { + return NewByte(i, i != 0) +} + +// ByteFromPtr creates a new Byte that be null if i is nil. +func ByteFromPtr(i *byte) Byte { + if i == nil { + return NewByte(0, false) + } + n := NewByte(*i, true) + return n +} + +// ValueOrZero returns the inner value if valid, otherwise zero. +func (b Byte) ValueOrZero() byte { + if !b.Valid { + return 0 + } + return b.Byte +} + +// UnmarshalJSON implements json.Unmarshaler. +// It supports number and null input. +// 0 will be considered a null Byte. +func (b *Byte) UnmarshalJSON(data []byte) error { + err := internal.UnmarshalIntJSON(data, &b.Byte, &b.Valid, 8, strconv.ParseUint) + if err != nil { + return err + } + b.Valid = b.Byte != 0 + return nil +} + +// UnmarshalText implements encoding.TextUnmarshaler. +// It will unmarshal to a null Byte if the input is a blank, or zero. +// It will return an error if the input is not an integer, blank, or "null". +func (b *Byte) UnmarshalText(text []byte) error { + err := internal.UnmarshalIntText(text, &b.Byte, &b.Valid, 8, strconv.ParseUint) + if err != nil { + return err + } + b.Valid = b.Byte != 0 + return nil +} + +// MarshalJSON implements json.Marshaler. +// It will encode 0 if this Byte is null. +func (b Byte) MarshalJSON() ([]byte, error) { + n := b.Byte + if !b.Valid { + n = 0 + } + return []byte(strconv.FormatInt(int64(n), 10)), nil +} + +// MarshalText implements encoding.TextMarshaler. +// It will encode a zero if this Byte is null. +func (b Byte) MarshalText() ([]byte, error) { + n := b.Byte + if !b.Valid { + n = 0 + } + return []byte(strconv.FormatInt(int64(n), 10)), nil +} + +// SetValid changes this Byte's value and also sets it to be non-null. +func (b *Byte) SetValid(n byte) { + b.Byte = n + b.Valid = true +} + +// Ptr returns a pointer to this Byte's value, or a nil pointer if this Byte is null. +func (b Byte) Ptr() *byte { + if !b.Valid { + return nil + } + return &b.Byte +} + +// IsZero returns true for null or zero Bytes, for future omitempty support (Go 1.4?) +func (b Byte) IsZero() bool { + return !b.Valid || b.Byte == 0 +} + +// Equal returns true if both ints have the same value or are both either null or zero. +func (b Byte) Equal(other Byte) bool { + return b.ValueOrZero() == other.ValueOrZero() +} + +func (b Byte) value() (int64, bool) { + return int64(b.Byte), b.Valid +} diff --git a/zero/int.go b/zero/int.go index f7092ff..1673255 100644 --- a/zero/int.go +++ b/zero/int.go @@ -1,12 +1,10 @@ package zero import ( - "bytes" "database/sql" - "encoding/json" - "errors" - "fmt" "strconv" + + "github.com/guregu/null/v5/internal" ) // Int is a nullable int64. @@ -16,6 +14,9 @@ type Int struct { sql.NullInt64 } +// Int64 is an alias for Int. +type Int64 = Int + // NewInt creates a new Int func NewInt(i int64, valid bool) Int { return Int{ @@ -52,33 +53,10 @@ func (i Int) ValueOrZero() int64 { // It supports number and null input. // 0 will be considered a null Int. func (i *Int) UnmarshalJSON(data []byte) error { - if bytes.Equal(data, nullBytes) { - i.Valid = false - return nil - } - - if err := json.Unmarshal(data, &i.Int64); err != nil { - var typeError *json.UnmarshalTypeError - if errors.As(err, &typeError) { - // special case: accept string input - if typeError.Value != "string" { - return fmt.Errorf("zero: JSON input is invalid type (need int or string): %w", err) - } - var str string - if err := json.Unmarshal(data, &str); err != nil { - return fmt.Errorf("zero: couldn't unmarshal number string: %w", err) - } - n, err := strconv.ParseInt(str, 10, 64) - if err != nil { - return fmt.Errorf("zero: couldn't convert string to int: %w", err) - } - i.Int64 = n - i.Valid = n != 0 - return nil - } - return fmt.Errorf("zero: couldn't unmarshal JSON: %w", err) + err := internal.UnmarshalIntJSON(data, &i.Int64, &i.Valid, 64, strconv.ParseInt) + if err != nil { + return err } - i.Valid = i.Int64 != 0 return nil } @@ -87,18 +65,12 @@ func (i *Int) UnmarshalJSON(data []byte) error { // It will unmarshal to a null Int if the input is a blank, or zero. // It will return an error if the input is not an integer, blank, or "null". func (i *Int) UnmarshalText(text []byte) error { - str := string(text) - if str == "" || str == "null" { - i.Valid = false - return nil - } - var err error - i.Int64, err = strconv.ParseInt(string(text), 10, 64) + err := internal.UnmarshalIntText(text, &i.Int64, &i.Valid, 64, strconv.ParseInt) if err != nil { - return fmt.Errorf("zero: couldn't unmarshal text: %w", err) + return err } i.Valid = i.Int64 != 0 - return err + return nil } // MarshalJSON implements json.Marshaler. @@ -144,3 +116,7 @@ func (i Int) IsZero() bool { func (i Int) Equal(other Int) bool { return i.ValueOrZero() == other.ValueOrZero() } + +func (i Int) value() (int64, bool) { + return i.Int64, i.Valid +} diff --git a/zero/int16.go b/zero/int16.go new file mode 100644 index 0000000..ba0e7ad --- /dev/null +++ b/zero/int16.go @@ -0,0 +1,119 @@ +package zero + +import ( + "database/sql" + "strconv" + + "github.com/guregu/null/v5/internal" +) + +// Int16 is a nullable int16. +// JSON marshals to zero if null. +// Considered null to SQL if zero. +type Int16 struct { + sql.NullInt16 +} + +// NewInt16 creates a new Int16 +func NewInt16(i int16, valid bool) Int16 { + return Int16{ + NullInt16: sql.NullInt16{ + Int16: i, + Valid: valid, + }, + } +} + +// Int16From creates a new Int16 that will be null if zero. +func Int16From(i int16) Int16 { + return NewInt16(i, i != 0) +} + +// Int16FromPtr creates a new Int16 that be null if i is nil. +func Int16FromPtr(i *int16) Int16 { + if i == nil { + return NewInt16(0, false) + } + n := NewInt16(*i, true) + return n +} + +// ValueOrZero returns the inner value if valid, otherwise zero. +func (i Int16) ValueOrZero() int16 { + if !i.Valid { + return 0 + } + return i.Int16 +} + +// UnmarshalJSON implements json.Unmarshaler. +// It supports number and null input. +// 0 will be considered a null Int16. +func (i *Int16) UnmarshalJSON(data []byte) error { + err := internal.UnmarshalIntJSON(data, &i.Int16, &i.Valid, 16, strconv.ParseInt) + if err != nil { + return err + } + i.Valid = i.Int16 != 0 + return nil +} + +// UnmarshalText implements encoding.TextUnmarshaler. +// It will unmarshal to a null Int16 if the input is a blank, or zero. +// It will return an error if the input is not an integer, blank, or "null". +func (i *Int16) UnmarshalText(text []byte) error { + err := internal.UnmarshalIntText(text, &i.Int16, &i.Valid, 16, strconv.ParseInt) + if err != nil { + return err + } + i.Valid = i.Int16 != 0 + return nil +} + +// MarshalJSON implements json.Marshaler. +// It will encode 0 if this Int16 is null. +func (i Int16) MarshalJSON() ([]byte, error) { + n := i.Int16 + if !i.Valid { + n = 0 + } + return []byte(strconv.FormatInt(int64(n), 10)), nil +} + +// MarshalText implements encoding.TextMarshaler. +// It will encode a zero if this Int16 is null. +func (i Int16) MarshalText() ([]byte, error) { + n := i.Int16 + if !i.Valid { + n = 0 + } + return []byte(strconv.FormatInt(int64(n), 10)), nil +} + +// SetValid changes this Int16's value and also sets it to be non-null. +func (i *Int16) SetValid(n int16) { + i.Int16 = n + i.Valid = true +} + +// Ptr returns a pointer to this Int16's value, or a nil pointer if this Int16 is null. +func (i Int16) Ptr() *int16 { + if !i.Valid { + return nil + } + return &i.Int16 +} + +// IsZero returns true for null or zero Int16s, for future omitempty support (Go 1.4?) +func (i Int16) IsZero() bool { + return !i.Valid || i.Int16 == 0 +} + +// Equal returns true if both ints have the same value or are both either null or zero. +func (i Int16) Equal(other Int16) bool { + return i.ValueOrZero() == other.ValueOrZero() +} + +func (i Int16) value() (int64, bool) { + return int64(i.Int16), i.Valid +} diff --git a/zero/int32.go b/zero/int32.go new file mode 100644 index 0000000..e8227ff --- /dev/null +++ b/zero/int32.go @@ -0,0 +1,119 @@ +package zero + +import ( + "database/sql" + "strconv" + + "github.com/guregu/null/v5/internal" +) + +// Int32 is a nullable int32. +// JSON marshals to zero if null. +// Considered null to SQL if zero. +type Int32 struct { + sql.NullInt32 +} + +// NewInt32 creates a new Int32 +func NewInt32(i int32, valid bool) Int32 { + return Int32{ + NullInt32: sql.NullInt32{ + Int32: i, + Valid: valid, + }, + } +} + +// Int32From creates a new Int32 that will be null if zero. +func Int32From(i int32) Int32 { + return NewInt32(i, i != 0) +} + +// Int32FromPtr creates a new Int32 that be null if i is nil. +func Int32FromPtr(i *int32) Int32 { + if i == nil { + return NewInt32(0, false) + } + n := NewInt32(*i, true) + return n +} + +// ValueOrZero returns the inner value if valid, otherwise zero. +func (i Int32) ValueOrZero() int32 { + if !i.Valid { + return 0 + } + return i.Int32 +} + +// UnmarshalJSON implements json.Unmarshaler. +// It supports number and null input. +// 0 will be considered a null Int32. +func (i *Int32) UnmarshalJSON(data []byte) error { + err := internal.UnmarshalIntJSON(data, &i.Int32, &i.Valid, 32, strconv.ParseInt) + if err != nil { + return err + } + i.Valid = i.Int32 != 0 + return nil +} + +// UnmarshalText implements encoding.TextUnmarshaler. +// It will unmarshal to a null Int32 if the input is a blank, or zero. +// It will return an error if the input is not an integer, blank, or "null". +func (i *Int32) UnmarshalText(text []byte) error { + err := internal.UnmarshalIntText(text, &i.Int32, &i.Valid, 32, strconv.ParseInt) + if err != nil { + return err + } + i.Valid = i.Int32 != 0 + return nil +} + +// MarshalJSON implements json.Marshaler. +// It will encode 0 if this Int32 is null. +func (i Int32) MarshalJSON() ([]byte, error) { + n := i.Int32 + if !i.Valid { + n = 0 + } + return []byte(strconv.FormatInt(int64(n), 10)), nil +} + +// MarshalText implements encoding.TextMarshaler. +// It will encode a zero if this Int32 is null. +func (i Int32) MarshalText() ([]byte, error) { + n := i.Int32 + if !i.Valid { + n = 0 + } + return []byte(strconv.FormatInt(int64(n), 10)), nil +} + +// SetValid changes this Int32's value and also sets it to be non-null. +func (i *Int32) SetValid(n int32) { + i.Int32 = n + i.Valid = true +} + +// Ptr returns a pointer to this Int32's value, or a nil pointer if this Int32 is null. +func (i Int32) Ptr() *int32 { + if !i.Valid { + return nil + } + return &i.Int32 +} + +// IsZero returns true for null or zero Int32s, for future omitempty support (Go 1.4?) +func (i Int32) IsZero() bool { + return !i.Valid || i.Int32 == 0 +} + +// Equal returns true if both ints have the same value or are both either null or zero. +func (i Int32) Equal(other Int32) bool { + return i.ValueOrZero() == other.ValueOrZero() +} + +func (i Int32) value() (int64, bool) { + return int64(i.Int32), i.Valid +} diff --git a/zero/int_test.go b/zero/int_test.go index 6f71125..c1396ec 100644 --- a/zero/int_test.go +++ b/zero/int_test.go @@ -1,87 +1,117 @@ package zero import ( + "encoding" "encoding/json" "errors" "math" "strconv" "testing" + + "github.com/guregu/null/v5/internal" ) var ( - intJSON = []byte(`12345`) - intStringJSON = []byte(`"12345"`) - nullIntJSON = []byte(`{"Int64":12345,"Valid":true}`) + intJSON = []byte(`123`) + intStringJSON = []byte(`"123"`) zeroJSON = []byte(`0`) ) +type nullint interface { + Int | Int32 | Int16 | Byte + IsZero() bool + value() (int64, bool) +} + func TestIntFrom(t *testing.T) { - i := IntFrom(12345) - assertInt(t, i, "IntFrom()") + testIntFrom(t, IntFrom) + testIntFrom(t, Int32From) + testIntFrom(t, Int16From) + testIntFrom(t, ByteFrom) +} - zero := IntFrom(0) - if zero.Valid { - t.Error("IntFrom(0)", "is valid, but should be invalid") - } +func testIntFrom[N nullint, V internal.Integer](t *testing.T, from func(V) N) { + t.Run(internal.TypeName[N](), func(t *testing.T) { + i := from(123) + assertInt(t, i, "from(123)") + + zero := from(0) + _, valid := zero.value() + if valid { + t.Error("from(0)", "is valid, but should be invalid") + } + }) } func TestIntFromPtr(t *testing.T) { - n := int64(12345) - iptr := &n - i := IntFromPtr(iptr) - assertInt(t, i, "IntFromPtr()") - - null := IntFromPtr(nil) - assertNullInt(t, null, "IntFromPtr(nil)") + testIntFromPtr(t, IntFromPtr) + testIntFromPtr(t, Int32FromPtr) + testIntFromPtr(t, Int16FromPtr) + testIntFromPtr(t, ByteFromPtr) } -func TestUnmarshalInt(t *testing.T) { - var i Int - err := json.Unmarshal(intJSON, &i) - maybePanic(err) - assertInt(t, i, "int json") - - var si Int - err = json.Unmarshal(intStringJSON, &si) - maybePanic(err) - assertInt(t, si, "int string json") +func testIntFromPtr[N nullint, V internal.Integer](t *testing.T, fromPtr func(*V) N) { + t.Run(internal.TypeName[N](), func(t *testing.T) { + n := V(123) + iptr := &n + i := fromPtr(iptr) + assertInt(t, i, "fromPtr()") - var ni Int - err = json.Unmarshal(nullIntJSON, &ni) - if err == nil { - panic("expected error") - } - - var bi Int - err = json.Unmarshal(floatBlankJSON, &bi) - if err == nil { - panic("expected error") - } - - var zero Int - err = json.Unmarshal(zeroJSON, &zero) - maybePanic(err) - assertNullInt(t, zero, "zero json") - - var null Int - err = json.Unmarshal(nullJSON, &null) - maybePanic(err) - assertNullInt(t, null, "null json") + null := fromPtr(nil) + assertNullInt(t, null, "fromPtr(nil)") + }) +} - var badType Int - err = json.Unmarshal(boolJSON, &badType) - if err == nil { - panic("err should not be nil") - } - assertNullInt(t, badType, "wrong type json") +func TestUnmarshalInt(t *testing.T) { + testUnmarshalInt[Int](t) + testUnmarshalInt[Int32](t) + testUnmarshalInt[Int16](t) + testUnmarshalInt[Byte](t) +} - var invalid Int - err = invalid.UnmarshalJSON(invalidJSON) - var syntaxError *json.SyntaxError - if !errors.As(err, &syntaxError) { - t.Errorf("expected wrapped json.SyntaxError, not %T", err) - } - assertNullInt(t, invalid, "invalid json") +func testUnmarshalInt[N nullint](t *testing.T) { + t.Run(internal.TypeName[N](), func(t *testing.T) { + var i N + err := json.Unmarshal(intJSON, &i) + maybePanic(err) + assertInt(t, i, "int json") + + var si N + err = json.Unmarshal(intStringJSON, &si) + maybePanic(err) + assertInt(t, si, "int string json") + + var bi N + err = json.Unmarshal(floatBlankJSON, &bi) + if err == nil { + panic("expected error") + } + + var zero N + err = json.Unmarshal(zeroJSON, &zero) + maybePanic(err) + assertNullInt(t, zero, "zero json") + + var null N + err = json.Unmarshal(nullJSON, &null) + maybePanic(err) + assertNullInt(t, null, "null json") + + var badType N + err = json.Unmarshal(boolJSON, &badType) + if err == nil { + panic("err should not be nil") + } + assertNullInt(t, badType, "wrong type json") + + var invalid N + err = json.Unmarshal(invalidJSON, &invalid) + var syntaxError *json.SyntaxError + if !errors.As(err, &syntaxError) { + t.Errorf("expected wrapped json.SyntaxError, not %T", err) + } + assertNullInt(t, invalid, "invalid json") + }) } func TestUnmarshalNonIntegerNumber(t *testing.T) { @@ -92,193 +122,292 @@ func TestUnmarshalNonIntegerNumber(t *testing.T) { } } -func TestUnmarshalInt64Overflow(t *testing.T) { - int64Overflow := uint64(math.MaxInt64) - - // Max int64 should decode successfully - var i Int - err := json.Unmarshal([]byte(strconv.FormatUint(int64Overflow, 10)), &i) - maybePanic(err) +func TestUnmarshalIntOverflow(t *testing.T) { + testUnmarshalIntOverflow[Int, int64](t, math.MaxInt64) + testUnmarshalIntOverflow[Int32, int32](t, math.MaxInt32) + testUnmarshalIntOverflow[Int16, int16](t, math.MaxInt16) + testUnmarshalIntOverflow[Byte, byte](t, math.MaxUint8) +} - // Attempt to overflow - int64Overflow++ - err = json.Unmarshal([]byte(strconv.FormatUint(int64Overflow, 10)), &i) - if err == nil { - panic("err should be present; decoded value overflows int64") - } +func testUnmarshalIntOverflow[N nullint, V internal.Integer](t *testing.T, max V) { + t.Run(internal.TypeName[N](), func(t *testing.T) { + overflow := uint64(max) + + // Max int64 should decode successfully + var i N + err := json.Unmarshal([]byte(strconv.FormatUint(overflow, 10)), &i) + maybePanic(err) + + // Attempt to overflow + overflow++ + err = json.Unmarshal([]byte(strconv.FormatUint(overflow, 10)), &i) + if err == nil { + t.Error("err should be present but isn't; decoded value overflows") + } + }) } func TestTextUnmarshalInt(t *testing.T) { - var i Int - err := i.UnmarshalText([]byte("12345")) - maybePanic(err) - assertInt(t, i, "UnmarshalText() int") - - var zero Int - err = zero.UnmarshalText([]byte("0")) - maybePanic(err) - assertNullInt(t, zero, "UnmarshalText() zero int") - - var blank Int - err = blank.UnmarshalText([]byte("")) - maybePanic(err) - assertNullInt(t, blank, "UnmarshalText() empty int") - - var null Int - err = null.UnmarshalText([]byte("null")) - maybePanic(err) - assertNullInt(t, null, `UnmarshalText() "null"`) - - var invalid Int - err = invalid.UnmarshalText([]byte("hello world")) - if err == nil { - panic("expected error") - } + testTextUnmarshalInt(t, (*Int).UnmarshalText) + testTextUnmarshalInt(t, (*Int32).UnmarshalText) + testTextUnmarshalInt(t, (*Int16).UnmarshalText) + testTextUnmarshalInt(t, (*Byte).UnmarshalText) +} + +func testTextUnmarshalInt[N nullint](t *testing.T, unmarshal func(*N, []byte) error) { + t.Run(internal.TypeName[N](), func(t *testing.T) { + var i N + err := unmarshal(&i, []byte("123")) + maybePanic(err) + assertInt(t, i, "unmarshal int") + + var zero Int + err = zero.UnmarshalText([]byte("0")) + maybePanic(err) + assertNullInt(t, zero, "UnmarshalText() zero int") + + var blank N + err = unmarshal(&blank, []byte("")) + maybePanic(err) + assertNullInt(t, blank, "unmarshal empty int") + + var null N + err = unmarshal(&null, []byte("null")) + maybePanic(err) + assertNullInt(t, null, `unmarshal "null"`) + + var invalid N + err = unmarshal(&invalid, []byte("hello world")) + if err == nil { + panic("expected error") + } + }) } func TestMarshalInt(t *testing.T) { - i := IntFrom(12345) - data, err := json.Marshal(i) - maybePanic(err) - assertJSONEquals(t, data, "12345", "non-empty json marshal") + testMarshalInt(t, NewInt) + testMarshalInt(t, NewInt32) + testMarshalInt(t, NewInt16) + testMarshalInt(t, NewByte) +} - // invalid values should be encoded as 0 - null := NewInt(0, false) - data, err = json.Marshal(null) - maybePanic(err) - assertJSONEquals(t, data, "0", "null json marshal") +func testMarshalInt[N interface{ ValueOrZero() V }, V internal.Integer](t *testing.T, newInt func(V, bool) N) { + t.Run(internal.TypeName[N](), func(t *testing.T) { + i := newInt(123, true) + data, err := json.Marshal(i) + maybePanic(err) + assertJSONEquals(t, data, "123", "non-empty json marshal") + + // invalid values should be encoded as 0 + null := NewInt(0, false) + data, err = json.Marshal(null) + maybePanic(err) + assertJSONEquals(t, data, "0", "null json marshal") + }) } func TestMarshalIntText(t *testing.T) { - i := IntFrom(12345) - data, err := i.MarshalText() - maybePanic(err) - assertJSONEquals(t, data, "12345", "non-empty text marshal") + testMarshalIntText(t, NewInt) + testMarshalIntText(t, NewInt32) + testMarshalIntText(t, NewInt16) + testMarshalIntText(t, NewByte) +} - // invalid values should be encoded as zero - null := NewInt(0, false) - data, err = null.MarshalText() - maybePanic(err) - assertJSONEquals(t, data, "0", "null text marshal") +func testMarshalIntText[N encoding.TextMarshaler, V internal.Integer](t *testing.T, newInt func(V, bool) N) { + t.Run(internal.TypeName[N](), func(t *testing.T) { + i := newInt(123, true) + data, err := i.MarshalText() + maybePanic(err) + assertJSONEquals(t, data, "123", "non-empty text marshal") + + // invalid values should be encoded as zero + null := newInt(0, false) + data, err = null.MarshalText() + maybePanic(err) + assertJSONEquals(t, data, "0", "null text marshal") + }) } func TestIntPointer(t *testing.T) { - i := IntFrom(12345) - ptr := i.Ptr() - if *ptr != 12345 { - t.Errorf("bad %s int: %#v ≠ %d\n", "pointer", ptr, 12345) - } + testIntPointer(t, NewInt) + testIntPointer(t, NewInt32) + testIntPointer(t, NewInt16) + testIntPointer(t, NewByte) +} - null := NewInt(0, false) - ptr = null.Ptr() - if ptr != nil { - t.Errorf("bad %s int: %#v ≠ %s\n", "nil pointer", ptr, "nil") - } +func testIntPointer[N interface{ Ptr() *V }, V internal.Integer](t *testing.T, newInt func(V, bool) N) { + t.Run(internal.TypeName[N](), func(t *testing.T) { + i := newInt(123, true) + ptr := i.Ptr() + if *ptr != 123 { + t.Errorf("bad %s int: %#v ≠ %d\n", "pointer", ptr, 123) + } + + null := newInt(0, false) + ptr = null.Ptr() + if ptr != nil { + t.Errorf("bad %s int: %#v ≠ %s\n", "nil pointer", ptr, "nil") + } + }) } func TestIntIsZero(t *testing.T) { - i := IntFrom(12345) - if i.IsZero() { - t.Errorf("IsZero() should be false") - } - - null := NewInt(0, false) - if !null.IsZero() { - t.Errorf("IsZero() should be true") - } + testIntIsZero(t, NewInt) + testIntIsZero(t, NewInt32) + testIntIsZero(t, NewInt16) + testIntIsZero(t, NewByte) +} - zero := NewInt(0, true) - if !zero.IsZero() { - t.Errorf("IsZero() should be true") - } +func testIntIsZero[N nullint, V internal.Integer](t *testing.T, newInt func(V, bool) N) { + t.Run(internal.TypeName[N](), func(t *testing.T) { + i := newInt(123, true) + if i.IsZero() { + t.Errorf("IsZero() should be false") + } + + null := newInt(0, false) + if !null.IsZero() { + t.Errorf("IsZero() should be true") + } + + zero := newInt(0, true) + if !zero.IsZero() { + t.Errorf("IsZero() should be true") + } + }) } func TestIntScan(t *testing.T) { - var i Int - err := i.Scan(12345) - maybePanic(err) - assertInt(t, i, "scanned int") + testIntScan(t, (*Int).Scan) + testIntScan(t, (*Int32).Scan) + testIntScan(t, (*Int16).Scan) + testIntScan(t, (*Byte).Scan) +} - var null Int - err = null.Scan(nil) - maybePanic(err) - assertNullInt(t, null, "scanned null") +func testIntScan[N nullint](t *testing.T, scan func(*N, any) error) { + t.Run(internal.TypeName[N](), func(t *testing.T) { + var i N + err := scan(&i, 123) + maybePanic(err) + assertInt(t, i, "scanned int") + + var zero N + err = scan(&i, 0) + maybePanic(err) + assertNullInt(t, zero, "scanned zero int") + + var null N + err = scan(&null, nil) + maybePanic(err) + assertNullInt(t, null, "scanned null") + }) } func TestIntSetValid(t *testing.T) { - change := NewInt(0, false) - assertNullInt(t, change, "SetValid()") - change.SetValid(12345) - assertInt(t, change, "SetValid()") + testIntSetValid(t, NewInt, (*Int).SetValid) + testIntSetValid(t, NewInt32, (*Int32).SetValid) + testIntSetValid(t, NewInt16, (*Int16).SetValid) + testIntSetValid(t, NewByte, (*Byte).SetValid) +} + +func testIntSetValid[N nullint, V internal.Integer](t *testing.T, newInt func(V, bool) N, setValid func(*N, V)) { + t.Run(internal.TypeName[N](), func(t *testing.T) { + change := newInt(0, false) + assertNullInt(t, change, "SetValid()") + setValid(&change, 123) + assertInt(t, change, "SetValid()") + }) } func TestIntValueOrZero(t *testing.T) { - valid := NewInt(12345, true) - if valid.ValueOrZero() != 12345 { - t.Error("unexpected ValueOrZero", valid.ValueOrZero()) - } + testIntValueOrZero(t, NewInt) + testIntValueOrZero(t, NewInt32) + testIntValueOrZero(t, NewInt16) + testIntValueOrZero(t, NewByte) +} - invalid := NewInt(12345, false) - if invalid.ValueOrZero() != 0 { - t.Error("unexpected ValueOrZero", invalid.ValueOrZero()) - } +func testIntValueOrZero[N interface{ ValueOrZero() V }, V internal.Integer](t *testing.T, newInt func(V, bool) N) { + t.Run(internal.TypeName[N](), func(t *testing.T) { + valid := newInt(123, true) + if valid.ValueOrZero() != 123 { + t.Error("unexpected ValueOrZero", valid.ValueOrZero()) + } + + invalid := newInt(123, false) + if invalid.ValueOrZero() != 0 { + t.Error("unexpected ValueOrZero", invalid.ValueOrZero()) + } + }) } func TestIntEqual(t *testing.T) { - int1 := NewInt(10, false) - int2 := NewInt(10, false) - assertIntEqualIsTrue(t, int1, int2) + testIntEqual(t, NewInt) + testIntEqual(t, NewInt32) + testIntEqual(t, NewInt16) + testIntEqual(t, NewByte) +} - int1 = NewInt(10, false) - int2 = NewInt(20, false) - assertIntEqualIsTrue(t, int1, int2) +func testIntEqual[N interface{ Equal(N) bool }, V internal.Integer](t *testing.T, newInt func(V, bool) N) { + t.Run(internal.TypeName[N](), func(t *testing.T) { + int1 := newInt(10, false) + int2 := newInt(10, false) + assertIntEqualIsTrue(t, int1, int2) - int1 = NewInt(10, true) - int2 = NewInt(10, true) - assertIntEqualIsTrue(t, int1, int2) + int1 = newInt(10, false) + int2 = newInt(20, false) + assertIntEqualIsTrue(t, int1, int2) - int1 = NewInt(0, true) - int2 = NewInt(10, false) - assertIntEqualIsTrue(t, int1, int2) + int1 = newInt(10, true) + int2 = newInt(10, true) + assertIntEqualIsTrue(t, int1, int2) - int1 = NewInt(10, true) - int2 = NewInt(10, false) - assertIntEqualIsFalse(t, int1, int2) + int1 = newInt(0, true) + int2 = newInt(10, false) + assertIntEqualIsTrue(t, int1, int2) - int1 = NewInt(10, false) - int2 = NewInt(10, true) - assertIntEqualIsFalse(t, int1, int2) + int1 = newInt(10, true) + int2 = newInt(10, false) + assertIntEqualIsFalse(t, int1, int2) - int1 = NewInt(10, true) - int2 = NewInt(20, true) - assertIntEqualIsFalse(t, int1, int2) + int1 = newInt(10, false) + int2 = newInt(10, true) + assertIntEqualIsFalse(t, int1, int2) + + int1 = newInt(10, true) + int2 = newInt(20, true) + assertIntEqualIsFalse(t, int1, int2) + }) } -func assertInt(t *testing.T, i Int, from string) { - if i.Int64 != 12345 { - t.Errorf("bad %s int: %d ≠ %d\n", from, i.Int64, 12345) +func assertInt(t *testing.T, i interface{ value() (int64, bool) }, from string) { + t.Helper() + n, valid := i.value() + if n != 123 { + t.Errorf("bad %s int: %d ≠ %d\n", from, n, 123) } - if !i.Valid { + if !valid { t.Error(from, "is invalid, but should be valid") } } -func assertNullInt(t *testing.T, i Int, from string) { - if i.Valid { +func assertNullInt(t *testing.T, i interface{ value() (int64, bool) }, from string) { + t.Helper() + _, valid := i.value() + if valid { t.Error(from, "is valid, but should be invalid") } } -func assertIntEqualIsTrue(t *testing.T, a, b Int) { +func assertIntEqualIsTrue[N interface{ Equal(N) bool }](t *testing.T, a, b N) { t.Helper() if !a.Equal(b) { - t.Errorf("Equal() of Int{%v, Valid:%t} and Int{%v, Valid:%t} should return true", a.Int64, a.Valid, b.Int64, b.Valid) + t.Errorf("Equal() of %#v and %#v should return true", a, b) } } -func assertIntEqualIsFalse(t *testing.T, a, b Int) { +func assertIntEqualIsFalse[N interface{ Equal(N) bool }](t *testing.T, a, b N) { t.Helper() if a.Equal(b) { - t.Errorf("Equal() of Int{%v, Valid:%t} and Int{%v, Valid:%t} should return false", a.Int64, a.Valid, b.Int64, b.Valid) + t.Errorf("Equal() of %#v and %#v should return false", a, b) } } From 4511f5bc2e5a3557d35db7d7bda771958e7a4ad4 Mon Sep 17 00:00:00 2001 From: Greg Date: Mon, 12 Feb 2024 03:22:10 +0900 Subject: [PATCH 2/5] add Value[T] for sql.Null[T] --- value.go | 147 +++++++++++++++++++++++++++++++++ value_test.go | 178 ++++++++++++++++++++++++++++++++++++++++ zero/value.go | 101 +++++++++++++++++++++++ zero/value_test.go | 197 +++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 623 insertions(+) create mode 100644 value.go create mode 100644 value_test.go create mode 100644 zero/value.go create mode 100644 zero/value_test.go diff --git a/value.go b/value.go new file mode 100644 index 0000000..a2a9401 --- /dev/null +++ b/value.go @@ -0,0 +1,147 @@ +//go:build go1.22 + +package null + +import ( + "bytes" + "database/sql" + "encoding/json" + "fmt" +) + +// Value represents a value that may be null. +type Value[T any] struct { + sql.Null[T] +} + +// NewValue creates a new Value. +func NewValue[T any](t T, valid bool) Value[T] { + return Value[T]{ + Null: sql.Null[T]{ + V: t, + Valid: valid, + }, + } +} + +// ValueFrom creates a new Value that will always be valid. +func ValueFrom[T any](t T) Value[T] { + return NewValue(t, true) +} + +// ValueFromPtr creates a new Value that will be null if t is nil. +func ValueFromPtr[T any](t *T) Value[T] { + if t == nil { + var zero T + return NewValue(zero, false) + } + return NewValue(*t, true) +} + +// ValueOrZero returns the inner value if valid, otherwise zero. +func (t Value[T]) ValueOrZero() T { + if !t.Valid { + var zero T + return zero + } + return t.V +} + +// MarshalJSON implements json.Marshaler. +// It will encode null if this value is null. +func (t Value[T]) MarshalJSON() ([]byte, error) { + if !t.Valid { + return []byte("null"), nil + } + return json.Marshal(t.V) +} + +// UnmarshalJSON implements json.Unmarshaler. +// It supports string and null input. +func (t *Value[T]) UnmarshalJSON(data []byte) error { + if bytes.Equal(data, nullBytes) { + t.Valid = false + return nil + } + + if err := json.Unmarshal(data, &t.V); err != nil { + return fmt.Errorf("null: couldn't unmarshal JSON: %w", err) + } + + t.Valid = true + return nil +} + +/* +// MarshalText implements encoding.TextMarshaler. +// It returns an empty string if invalid, otherwise T's MarshalText. +func (t Value[T]) MarshalText() ([]byte, error) { + if !t.Valid { + return []byte{}, nil + } + if tm, ok := any(t.V).(encoding.TextMarshaler); ok { + return tm.MarshalText() + } + + rv := reflect.ValueOf(t.V) + if !rv.IsValid() { + return []byte{}, nil + } + +try: + switch rv.Kind() { + case reflect.Pointer: + if rv.IsNil() { + return []byte{}, nil + } + rv = rv.Elem() + goto try + case reflect.String: + return []byte(rv.String()), nil + case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8: + return []byte(strconv.FormatInt(rv.Int(), 10)), nil + case reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8: + return []byte(strconv.FormatUint(rv.Uint(), 10)), nil + case reflect.Float32, reflect.Float64: + return []byte(strconv.FormatFloat(rv.Float(), 'f', -1, rv.Type().Bits())), nil + + // case reflect.Slice: + // if rv.IsNil() { + // return []byte{}, nil + // } + // if rv.Type().Elem().Kind() == reflect.Uint8 { + // return rv.Bytes(), nil + // } + // + } + + return t.Value.MarshalText() +} +*/ + +// SetValid changes this Value's value and sets it to be non-null. +func (t *Value[T]) SetValid(v T) { + t.V = v + t.Valid = true +} + +// Ptr returns a pointer to this Value's value, or a nil pointer if this Value is null. +func (t Value[T]) Ptr() *T { + if !t.Valid { + return nil + } + return &t.V +} + +// IsZero returns true for invalid Values, hopefully for future omitempty support. +// A non-null Value with a zero value will not be considered zero. +func (t Value[T]) IsZero() bool { + return !t.Valid +} + +/* +// Equal returns true if both Value objects encode the same value or are both null. +func (t Value[T]) Equal(other Value[T]) bool { + return t.Valid == other.Valid && (t.V == other.V) +} +*/ diff --git a/value_test.go b/value_test.go new file mode 100644 index 0000000..8f9ab85 --- /dev/null +++ b/value_test.go @@ -0,0 +1,178 @@ +package null + +import ( + "bytes" + "database/sql" + "encoding/json" + "fmt" + "reflect" + "testing" + + "github.com/guregu/null/v5/internal" +) + +func TestValue(t *testing.T) { + testValue[string](t, "hello") + testValue[uint32](t, 1337) + testValue[uint64](t, 42) + + type myint int + testValue[myint](t, 2) +} + +func testValue[T any](t *testing.T, good T) { + t.Run(internal.TypeName[Value[T]](), func(t *testing.T) { + var zero T + var nilv *T + + // valid Value[T] + testValueValid[T](t, good) + testValueValid[T](t, zero) + + // invalid Value[T] + t.Run("null", func(t *testing.T) { + null := NewValue(zero, false) + if !null.IsZero() { + t.Errorf("%v IsZero() should be true", null) + } + nullVFP := ValueFromPtr(nilv) + if !reflect.DeepEqual(null, nullVFP) { + t.Errorf("%#v != %#v", null, nullVFP) + } + + nullp := null.Ptr() + if nullp != nil { + t.Errorf("%#v Ptr() should be nil", null) + } + + nullVOZ := null.ValueOrZero() + if !reflect.DeepEqual(nullVOZ, zero) { + t.Error("ValueOrZero() want:", zero, "got:", nullVOZ) + } + + t.Run("MarshalJSON", func(t *testing.T) { + wantJSON, err := json.Marshal(nilv) + if err != nil { + t.Fatal(err) + } + got, err := json.Marshal(null) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(wantJSON, got) { + t.Error("unexpected json. want:", string(wantJSON), "got:", string(got)) + } + + t.Run("UnmarshalJSON", func(t *testing.T) { + var want T + if err := json.Unmarshal(wantJSON, &want); err != nil { + t.Fatal(err) + } + var got Value[T] + if err := json.Unmarshal(wantJSON, &got); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(want, got.V) { + t.Error("bad unmarshal. want:", want, "got:", got) + } + if !got.IsZero() { + t.Errorf("%#v IsZero() should be true", got) + } + }) + }) + + t.Run("Scan(nil)", func(t *testing.T) { + var want sql.Null[T] + if err := want.Scan(nil); err != nil { + t.Fatal(err) + } + var got Value[T] + if err := got.Scan(nil); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(want, got.Null) { + t.Error("bad scan. want:", want, "got:", got) + } + }) + + t.Run(fmt.Sprintf("SetValid(%v)", zero), func(t *testing.T) { + valid2 := null + valid2.SetValid(zero) + if valid2.IsZero() { + t.Errorf("%#v IsZero() should be false", valid2) + } + }) + }) + + }) +} + +func testValueValid[T any](t *testing.T, value T) { + valid := NewValue(value, true) + if valid.IsZero() { + t.Errorf("%#v IsZero() should be false", valid) + } + validVF := ValueFrom(value) + if !reflect.DeepEqual(valid, validVF) { + t.Errorf("%#v != %#v", valid, validVF) + } + validVFP := ValueFromPtr(&value) + if !reflect.DeepEqual(valid, validVFP) { + t.Errorf("%#v != %#v", valid, validVFP) + } + + validp := valid.Ptr() + if validp == nil { + t.Errorf("%#v Ptr() shouldn't be nil", valid) + } + + validVOZ := valid.ValueOrZero() + if !reflect.DeepEqual(validVOZ, value) { + t.Error("ValueOrZero() want:", value, "got:", validVOZ) + } + + t.Run("MarshalJSON", func(t *testing.T) { + wantJSON, err := json.Marshal(value) + if err != nil { + t.Fatal(err) + } + got, err := json.Marshal(valid) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(wantJSON, got) { + t.Error("unexpected json. want:", string(wantJSON), "got:", string(got)) + } + + t.Run("UnmarshalJSON", func(t *testing.T) { + var want T + if err := json.Unmarshal(wantJSON, &want); err != nil { + t.Fatal(err) + } + var got Value[T] + if err := json.Unmarshal(wantJSON, &got); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(want, got.V) { + t.Error("bad unmarshal. want:", want, "got:", got) + } + if got.IsZero() { + t.Errorf("%#v IsZero() should be false", got) + } + }) + }) + + t.Run(fmt.Sprintf("Scan(%v)", value), func(t *testing.T) { + var want sql.Null[T] + if err := want.Scan(value); err != nil { + t.Fatal(err) + } + var got Value[T] + if err := got.Scan(value); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(want, got.Null) { + t.Error("bad scan. want:", want, "got:", got) + } + }) +} diff --git a/zero/value.go b/zero/value.go new file mode 100644 index 0000000..8218879 --- /dev/null +++ b/zero/value.go @@ -0,0 +1,101 @@ +//go:build go1.22 + +package zero + +import ( + "bytes" + "database/sql" + "encoding/json" + "fmt" +) + +type Value[T comparable] struct { + sql.Null[T] +} + +// NewValue creates a new Value. +func NewValue[T comparable](t T, valid bool) Value[T] { + return Value[T]{ + Null: sql.Null[T]{ + V: t, + Valid: valid, + }, + } +} + +// ValueFrom creates a new Value that will always be valid. +func ValueFrom[T comparable](t T) Value[T] { + var zero T + return NewValue(t, t != zero) +} + +// ValueFromPtr creates a new Value that will be null if t is nil. +func ValueFromPtr[T comparable](t *T) Value[T] { + var zero T + if t == nil { + return NewValue(zero, false) + } + return NewValue(*t, *t != zero) +} + +// ValueOrZero returns the inner value if valid, otherwise zero. +func (t Value[T]) ValueOrZero() T { + if !t.Valid { + var zero T + return zero + } + return t.V +} + +// MarshalJSON implements json.Marshaler. +// It will encode null if this value is null or zero. +func (t Value[T]) MarshalJSON() ([]byte, error) { + var zero T + if !t.Valid || t.V == zero { + return []byte("null"), nil + } + return json.Marshal(t.V) +} + +// UnmarshalJSON implements json.Unmarshaler. +// It supports string and null input. +func (t *Value[T]) UnmarshalJSON(data []byte) error { + var zero T + if bytes.Equal(data, nullBytes) { + t.Valid = false + t.V = zero + return nil + } + + if err := json.Unmarshal(data, &t.V); err != nil { + return fmt.Errorf("null: couldn't unmarshal JSON: %w", err) + } + + t.Valid = t.V != zero + return nil +} + +// SetValid changes this Value's value and sets it to be non-null. +func (t *Value[T]) SetValid(v T) { + t.V = v + t.Valid = true +} + +// Ptr returns a pointer to this Value's value, or a nil pointer if this Value is null. +func (t Value[T]) Ptr() *T { + if !t.Valid { + return nil + } + return &t.V +} + +// IsZero returns true for invalid or zero Values, hopefully for future omitempty support. +func (t Value[T]) IsZero() bool { + var zero T + return !t.Valid || t.V == zero +} + +// Equal returns true if both Value objects encode the same value or are both null. +func (t Value[T]) Equal(other Value[T]) bool { + return t.ValueOrZero() == other.ValueOrZero() +} diff --git a/zero/value_test.go b/zero/value_test.go new file mode 100644 index 0000000..9817c87 --- /dev/null +++ b/zero/value_test.go @@ -0,0 +1,197 @@ +package zero + +import ( + "bytes" + "database/sql" + "encoding/json" + "fmt" + "reflect" + "testing" + + "github.com/guregu/null/v5/internal" +) + +func TestValue(t *testing.T) { + testValue[string](t, "hello") + testValue[uint32](t, 1337) + testValue[uint64](t, 42) + + type myint int + testValue[myint](t, 2) +} + +func testValue[T comparable](t *testing.T, good T) { + t.Run(internal.TypeName[Value[T]](), func(t *testing.T) { + var zero T + + // valid Value[T] + testValueValid[T](t, good) + + // invalid Value[T] + testValueNull[T](t, zero, good) + }) +} + +func testValueValid[T comparable](t *testing.T, value T) { + valid := NewValue(value, true) + if valid.IsZero() { + t.Errorf("%#v IsZero() should be false", valid) + } + validVF := ValueFrom(value) + if !reflect.DeepEqual(valid, validVF) { + t.Errorf("%#v != %#v", valid, validVF) + } + validVFP := ValueFromPtr(&value) + if !reflect.DeepEqual(valid, validVFP) { + t.Errorf("%#v != %#v", valid, validVFP) + } + + validp := valid.Ptr() + if validp == nil { + t.Errorf("%#v Ptr() shouldn't be nil", valid) + } + + validVOZ := valid.ValueOrZero() + if !reflect.DeepEqual(validVOZ, value) { + t.Error("ValueOrZero() want:", value, "got:", validVOZ) + } + + t.Run("MarshalJSON", func(t *testing.T) { + wantJSON, err := json.Marshal(value) + if err != nil { + t.Fatal(err) + } + got, err := json.Marshal(valid) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(wantJSON, got) { + t.Error("unexpected json. want:", string(wantJSON), "got:", string(got)) + } + + t.Run("UnmarshalJSON", func(t *testing.T) { + var want T + if err := json.Unmarshal(wantJSON, &want); err != nil { + t.Fatal(err) + } + var got Value[T] + if err := json.Unmarshal(wantJSON, &got); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(want, got.V) { + t.Error("bad unmarshal. want:", want, "got:", got) + } + if got.IsZero() { + t.Errorf("%#v IsZero() should be false", got) + } + }) + }) + + t.Run(fmt.Sprintf("Scan(%v)", value), func(t *testing.T) { + var want sql.Null[T] + if err := want.Scan(value); err != nil { + t.Fatal(err) + } + var got Value[T] + if err := got.Scan(value); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(want, got.Null) { + t.Error("bad scan. want:", want, "got:", got) + } + }) +} + +func testValueNull[T comparable](t *testing.T, value T, good T) { + var zero T + var nilv *T + + null := NewValue(zero, false) + if !null.IsZero() { + t.Errorf("%v IsZero() should be true", null) + } + nullVFP := ValueFromPtr(nilv) + if !reflect.DeepEqual(null, nullVFP) { + t.Errorf("%#v != %#v", null, nullVFP) + } + if !null.Equal(nullVFP) { + t.Errorf("!%#v.Equal(%#v)", null, nullVFP) + } + + nullVFPZ := ValueFromPtr(new(T)) + if !reflect.DeepEqual(null, nullVFPZ) { + t.Errorf("%#v != %#v", null, nullVFPZ) + } + if !null.Equal(nullVFPZ) { + t.Errorf("!%#v.Equal(%#v)", null, nullVFPZ) + } + + nullp := null.Ptr() + if nullp != nil { + t.Errorf("%#v Ptr() should be nil", null) + } + + nullVOZ := null.ValueOrZero() + if !reflect.DeepEqual(nullVOZ, zero) { + t.Error("ValueOrZero() want:", zero, "got:", nullVOZ) + } + + t.Run("MarshalJSON", func(t *testing.T) { + wantJSON, err := json.Marshal(nilv) + if err != nil { + t.Fatal(err) + } + got, err := json.Marshal(null) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(wantJSON, got) { + t.Error("unexpected json. want:", string(wantJSON), "got:", string(got)) + } + + t.Run("UnmarshalJSON", func(t *testing.T) { + var want T + if err := json.Unmarshal(wantJSON, &want); err != nil { + t.Fatal(err) + } + var got Value[T] + if err := json.Unmarshal(wantJSON, &got); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(want, got.V) { + t.Error("bad unmarshal. want:", want, "got:", got) + } + if !got.IsZero() { + t.Errorf("%#v IsZero() should be true", got) + } + }) + }) + + t.Run("Scan(nil)", func(t *testing.T) { + var want sql.Null[T] + if err := want.Scan(nil); err != nil { + t.Fatal(err) + } + var got Value[T] + if err := got.Scan(nil); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(want, got.Null) { + t.Error("bad scan. want:", want, "got:", got) + } + }) + + t.Run(fmt.Sprintf("SetValid(%v)", zero), func(t *testing.T) { + valid2 := null + valid2.SetValid(zero) + if !valid2.IsZero() { + t.Errorf("%#v IsZero() should be true", valid2) + } + + valid3 := null + valid3.SetValid(good) + if valid3.IsZero() { + t.Errorf("%#v IsZero() should be false", valid2) + } + }) +} From 329942be0524fe37b033249fb59bda730f5d0c13 Mon Sep 17 00:00:00 2001 From: Greg Date: Mon, 12 Feb 2024 03:36:23 +0900 Subject: [PATCH 3/5] optimize a bit --- bool.go | 3 +-- float.go | 36 ++++-------------------------------- internal/float.go | 38 ++++++++++++++++++++++++++++++++++++++ string.go | 6 +----- string_test.go | 9 ++++----- time.go | 3 +-- value.go | 3 +-- zero/bool.go | 3 +-- zero/float.go | 36 +++++------------------------------- zero/string.go | 3 +-- zero/string_test.go | 9 ++++----- zero/value.go | 3 +-- 12 files changed, 62 insertions(+), 90 deletions(-) create mode 100644 internal/float.go diff --git a/bool.go b/bool.go index a27c749..73ca475 100644 --- a/bool.go +++ b/bool.go @@ -1,7 +1,6 @@ package null import ( - "bytes" "database/sql" "encoding/json" "errors" @@ -47,7 +46,7 @@ func (b Bool) ValueOrZero() bool { // It supports number and null input. // 0 will not be considered a null Bool. func (b *Bool) UnmarshalJSON(data []byte) error { - if bytes.Equal(data, nullBytes) { + if len(data) > 0 && data[0] == 'n' { b.Valid = false return nil } diff --git a/float.go b/float.go index 6ec31eb..9542838 100644 --- a/float.go +++ b/float.go @@ -1,14 +1,14 @@ package null import ( - "bytes" "database/sql" "encoding/json" - "errors" "fmt" "math" "reflect" "strconv" + + "github.com/guregu/null/v5/internal" ) // Float is a nullable float64. @@ -53,35 +53,7 @@ func (f Float) ValueOrZero() float64 { // It supports number and null input. // 0 will not be considered a null Float. func (f *Float) UnmarshalJSON(data []byte) error { - if bytes.Equal(data, nullBytes) { - f.Valid = false - return nil - } - - if err := json.Unmarshal(data, &f.Float64); err != nil { - var typeError *json.UnmarshalTypeError - if errors.As(err, &typeError) { - // special case: accept string input - if typeError.Value != "string" { - return fmt.Errorf("null: JSON input is invalid type (need float or string): %w", err) - } - var str string - if err := json.Unmarshal(data, &str); err != nil { - return fmt.Errorf("null: couldn't unmarshal number string: %w", err) - } - n, err := strconv.ParseFloat(str, 64) - if err != nil { - return fmt.Errorf("null: couldn't convert string to float: %w", err) - } - f.Float64 = n - f.Valid = true - return nil - } - return fmt.Errorf("null: couldn't unmarshal JSON: %w", err) - } - - f.Valid = true - return nil + return internal.UnmarshalFloatJSON(data, &f.Float64, &f.Valid) } // UnmarshalText implements encoding.TextUnmarshaler. @@ -94,7 +66,7 @@ func (f *Float) UnmarshalText(text []byte) error { return nil } var err error - f.Float64, err = strconv.ParseFloat(string(text), 64) + f.Float64, err = strconv.ParseFloat(str, 64) if err != nil { return fmt.Errorf("null: couldn't unmarshal text: %w", err) } diff --git a/internal/float.go b/internal/float.go new file mode 100644 index 0000000..167fd62 --- /dev/null +++ b/internal/float.go @@ -0,0 +1,38 @@ +package internal + +import ( + "encoding/json" + "fmt" + "strconv" +) + +func UnmarshalFloatJSON(data []byte, value *float64, valid *bool) error { + if len(data) == 0 { + return fmt.Errorf("UnmarshalJSON: no data") + } + + switch data[0] { + case 'n': + *value = 0 + *valid = false + return nil + + case '"': + var str string + if err := json.Unmarshal(data, &str); err != nil { + return fmt.Errorf("null: couldn't unmarshal number string: %w", err) + } + n, err := strconv.ParseFloat(str, 64) + if err != nil { + return fmt.Errorf("null: couldn't convert string to int: %w", err) + } + *value = n + *valid = true + return nil + + default: + err := json.Unmarshal(data, value) + *valid = err == nil + return err + } +} diff --git a/string.go b/string.go index 67f6aaf..2e2becb 100644 --- a/string.go +++ b/string.go @@ -5,15 +5,11 @@ package null import ( - "bytes" "database/sql" "encoding/json" "fmt" ) -// nullBytes is a JSON null literal -var nullBytes = []byte("null") - // String is a nullable string. It supports SQL and JSON serialization. // It will marshal to null if null. Blank string input will be considered null. type String struct { @@ -54,7 +50,7 @@ func NewString(s string, valid bool) String { // UnmarshalJSON implements json.Unmarshaler. // It supports string and null input. Blank string input does not produce a null String. func (s *String) UnmarshalJSON(data []byte) error { - if bytes.Equal(data, nullBytes) { + if len(data) > 0 && data[0] == 'n' { s.Valid = false return nil } diff --git a/string_test.go b/string_test.go index ce470a4..eb71217 100644 --- a/string_test.go +++ b/string_test.go @@ -15,10 +15,6 @@ var ( invalidJSON = []byte(`:)`) ) -type stringInStruct struct { - Test String `json:"test,omitempty"` -} - func TestStringFrom(t *testing.T) { str := StringFrom("test") assertStr(t, str, "StringFrom() string") @@ -118,7 +114,10 @@ func TestMarshalString(t *testing.T) { assertJSONEquals(t, data, "", "string marshal text") } -// Tests omitempty... broken until Go 1.4 +// Tests omitempty... broken until json/v2? +// type stringInStruct struct { +// Test String `json:"test,omitempty"` +// } // func TestMarshalStringInStruct(t *testing.T) { // obj := stringInStruct{Test: StringFrom("")} // data, err := json.Marshal(obj) diff --git a/time.go b/time.go index 15c16cf..1bce583 100644 --- a/time.go +++ b/time.go @@ -1,7 +1,6 @@ package null import ( - "bytes" "database/sql" "database/sql/driver" "encoding/json" @@ -66,7 +65,7 @@ func (t Time) MarshalJSON() ([]byte, error) { // UnmarshalJSON implements json.Unmarshaler. // It supports string and null input. func (t *Time) UnmarshalJSON(data []byte) error { - if bytes.Equal(data, nullBytes) { + if len(data) > 0 && data[0] == 'n' { t.Valid = false return nil } diff --git a/value.go b/value.go index a2a9401..6f351ee 100644 --- a/value.go +++ b/value.go @@ -3,7 +3,6 @@ package null import ( - "bytes" "database/sql" "encoding/json" "fmt" @@ -59,7 +58,7 @@ func (t Value[T]) MarshalJSON() ([]byte, error) { // UnmarshalJSON implements json.Unmarshaler. // It supports string and null input. func (t *Value[T]) UnmarshalJSON(data []byte) error { - if bytes.Equal(data, nullBytes) { + if len(data) > 0 && data[0] == 'n' { t.Valid = false return nil } diff --git a/zero/bool.go b/zero/bool.go index 0612287..d4c4722 100644 --- a/zero/bool.go +++ b/zero/bool.go @@ -1,7 +1,6 @@ package zero import ( - "bytes" "database/sql" "encoding/json" "errors" @@ -46,7 +45,7 @@ func (b Bool) ValueOrZero() bool { // UnmarshalJSON implements json.Unmarshaler. // "false" will be considered a null Bool. func (b *Bool) UnmarshalJSON(data []byte) error { - if bytes.Equal(data, nullBytes) { + if len(data) > 0 && data[0] == 'n' { b.Valid = false return nil } diff --git a/zero/float.go b/zero/float.go index d0a79ed..783d681 100644 --- a/zero/float.go +++ b/zero/float.go @@ -1,14 +1,14 @@ package zero import ( - "bytes" "database/sql" "encoding/json" - "errors" "fmt" "math" "reflect" "strconv" + + "github.com/guregu/null/v5/internal" ) // Float is a nullable float64. Zero input will be considered null. @@ -53,35 +53,9 @@ func (f Float) ValueOrZero() float64 { // It supports number and null input. // 0 will be considered a null Float. func (f *Float) UnmarshalJSON(data []byte) error { - if bytes.Equal(data, nullBytes) { - f.Valid = false - return nil - } - - if err := json.Unmarshal(data, &f.Float64); err != nil { - var typeError *json.UnmarshalTypeError - if errors.As(err, &typeError) { - // special case: accept string input - if typeError.Value != "string" { - return fmt.Errorf("zero: JSON input is invalid type (need float or string): %w", err) - } - var str string - if err := json.Unmarshal(data, &str); err != nil { - return fmt.Errorf("zero: couldn't unmarshal number string: %w", err) - } - n, err := strconv.ParseFloat(str, 64) - if err != nil { - return fmt.Errorf("zero: couldn't convert string to float: %w", err) - } - f.Float64 = n - f.Valid = n != 0 - return nil - } - return fmt.Errorf("zero: couldn't unmarshal JSON: %w", err) - } - + err := internal.UnmarshalFloatJSON(data, &f.Float64, &f.Valid) f.Valid = f.Float64 != 0 - return nil + return err } // UnmarshalText implements encoding.TextUnmarshaler. @@ -94,7 +68,7 @@ func (f *Float) UnmarshalText(text []byte) error { return nil } var err error - f.Float64, err = strconv.ParseFloat(string(text), 64) + f.Float64, err = strconv.ParseFloat(str, 64) if err != nil { return fmt.Errorf("zero: couldn't unmarshal text: %w", err) } diff --git a/zero/string.go b/zero/string.go index db1c832..4790b31 100644 --- a/zero/string.go +++ b/zero/string.go @@ -5,7 +5,6 @@ package zero import ( - "bytes" "database/sql" "encoding/json" "fmt" @@ -56,7 +55,7 @@ func (s String) ValueOrZero() string { // UnmarshalJSON implements json.Unmarshaler. // It supports string and null input. Blank string input produces a null String. func (s *String) UnmarshalJSON(data []byte) error { - if bytes.Equal(data, nullBytes) { + if len(data) > 0 && data[0] == 'n' { s.Valid = false return nil } diff --git a/zero/string_test.go b/zero/string_test.go index 1b691e6..36e435c 100644 --- a/zero/string_test.go +++ b/zero/string_test.go @@ -15,10 +15,6 @@ var ( invalidJSON = []byte(`:)`) ) -type stringInStruct struct { - Test String `json:"test,omitempty"` -} - func TestStringFrom(t *testing.T) { str := StringFrom("test") assertStr(t, str, "StringFrom() string") @@ -90,7 +86,10 @@ func TestMarshalString(t *testing.T) { assertJSONEquals(t, data, `""`, "empty json marshal") } -// Tests omitempty... broken until Go 1.4 +// Tests omitempty... broken until json/v2? +// type stringInStruct struct { +// Test String `json:"test,omitempty"` +// } // func TestMarshalStringInStruct(t *testing.T) { // obj := stringInStruct{Test: StringFrom("")} // data, err := json.Marshal(obj) diff --git a/zero/value.go b/zero/value.go index 8218879..5836d61 100644 --- a/zero/value.go +++ b/zero/value.go @@ -3,7 +3,6 @@ package zero import ( - "bytes" "database/sql" "encoding/json" "fmt" @@ -61,7 +60,7 @@ func (t Value[T]) MarshalJSON() ([]byte, error) { // It supports string and null input. func (t *Value[T]) UnmarshalJSON(data []byte) error { var zero T - if bytes.Equal(data, nullBytes) { + if len(data) > 0 && data[0] == 'n' { t.Valid = false t.V = zero return nil From 7d9fc9f53550586a4e039394b69a2a5acfb43c42 Mon Sep 17 00:00:00 2001 From: Greg Date: Mon, 12 Feb 2024 03:47:31 +0900 Subject: [PATCH 4/5] add Github Actions CI --- .github/workflows/test.yml | 14 ++++++++++++++ .gitignore | 3 ++- 2 files changed, 16 insertions(+), 1 deletion(-) create mode 100644 .github/workflows/test.yml diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..f2f6e41 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,14 @@ +name: Deploy + +on: [push, pull_request] + +jobs: + spin: + runs-on: ubuntu-latest + name: Test + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 + with: + go-version: 'stable' + - run: go test -v -race -coverpkg=./... ./... diff --git a/.gitignore b/.gitignore index e9eb644..3c1b123 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ -coverage.out +cover*.out .idea/ +.DS_Store From 782c7fe8147ff3a286623c1e3e63a18c21ba2e5a Mon Sep 17 00:00:00 2001 From: Greg Date: Mon, 12 Feb 2024 03:52:24 +0900 Subject: [PATCH 5/5] tidy up --- zero/string.go | 3 --- zero/time_test.go | 4 ++-- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/zero/string.go b/zero/string.go index 4790b31..5eb45ff 100644 --- a/zero/string.go +++ b/zero/string.go @@ -10,9 +10,6 @@ import ( "fmt" ) -// nullBytes is a JSON null literal -var nullBytes = []byte("null") - // String is a nullable string. // JSON marshals to a blank string if null. // Considered null to SQL if zero. diff --git a/zero/time_test.go b/zero/time_test.go index 577a2a1..330ab48 100644 --- a/zero/time_test.go +++ b/zero/time_test.go @@ -179,8 +179,8 @@ func TestTimeValue(t *testing.T) { ti := TimeFrom(timeValue1) v, err := ti.Value() maybePanic(err) - if ti.Time != timeValue1 { - t.Errorf("bad time.Time value: %v ≠ %v", ti.Time, timeValue1) + if v != timeValue1 { + t.Errorf("bad time.Time value: %v ≠ %v", v, timeValue1) } var nt time.Time