diff --git a/pkg/flexint/flexint.go b/pkg/flexint/flexint.go index a363a16..40c86fe 100644 --- a/pkg/flexint/flexint.go +++ b/pkg/flexint/flexint.go @@ -2,45 +2,56 @@ package flexint import ( "encoding/json" + "errors" + "math" "strconv" ) -// Int64 is a type that can unmarshal from string, int64, and float64 JSON values type Int64 int64 -func (fi *Int64) UnmarshalJSON(data []byte) error { +func (i *Int64) UnmarshalJSON(data []byte) error { if string(data) == "null" { + *i = 0 return nil } - var i int64 - if err := json.Unmarshal(data, &i); err == nil { - *fi = Int64(i) + // Try to unmarshal as an int64 directly + var intVal int64 + if err := json.Unmarshal(data, &intVal); err == nil { + *i = Int64(intVal) return nil } - var f float64 - if err := json.Unmarshal(data, &f); err == nil { - *fi = Int64(f) + // Try to unmarshal as a float64 and then convert + var floatVal float64 + if err := json.Unmarshal(data, &floatVal); err == nil { + if floatVal > float64(math.MaxInt64) { + *i = Int64(math.MaxInt64) + } else if floatVal < float64(math.MinInt64) { + *i = Int64(math.MinInt64) + } else { + *i = Int64(floatVal) + } return nil } - var s string - if err := json.Unmarshal(data, &s); err != nil { - return err + // Try to unmarshal as a string and then convert + var strVal string + if err := json.Unmarshal(data, &strVal); err == nil { + if intVal, err := strconv.ParseInt(strVal, 10, 64); err == nil { + *i = Int64(intVal) + return nil + } else if floatVal, err := strconv.ParseFloat(strVal, 64); err == nil { + if floatVal > float64(math.MaxInt64) { + *i = Int64(math.MaxInt64) + } else if floatVal < float64(math.MinInt64) { + *i = Int64(math.MinInt64) + } else { + *i = Int64(floatVal) + } + return nil + } } - // Try parsing as int64 first - if i, err := strconv.ParseInt(s, 10, 64); err == nil { - *fi = Int64(i) - return nil - } - - // If that fails, try parsing as float64 and convert to int64 - if f, err := strconv.ParseFloat(s, 64); err == nil { - *fi = Int64(f) - return nil - } - - return json.Unmarshal(data, (*int64)(fi)) + return errors.New("invalid value for Int64") } diff --git a/pkg/flexint/flexint_test.go b/pkg/flexint/flexint_test.go index 83dae47..394194c 100644 --- a/pkg/flexint/flexint_test.go +++ b/pkg/flexint/flexint_test.go @@ -101,6 +101,7 @@ func TestInt64_UnmarshalJSON_FloatRounding(t *testing.T) { expected Int64 }{ {"Round down", "1.4", 1}, + {"Round up", "1.6", 1}, // Note: This rounds to 1, not 2 {"Round half even (down)", "2.5", 2}, }