From cbc72d1259eed02b490089cb314343e36074a8c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Felix=20R=C3=B6hrich?= Date: Sun, 17 Mar 2024 14:30:18 +0100 Subject: [PATCH] make parsing stricter and add corresponding test --- pgtype/time.go | 13 +++++++-- pgtype/time_test.go | 68 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 79 insertions(+), 2 deletions(-) diff --git a/pgtype/time.go b/pgtype/time.go index 2eb6ace28..a3d0ab1af 100644 --- a/pgtype/time.go +++ b/pgtype/time.go @@ -45,7 +45,12 @@ func (t *Time) Scan(src any) error { switch src := src.(type) { case string: - return scanPlanTextAnyToTimeScanner{}.Scan([]byte(src), t) + err := scanPlanTextAnyToTimeScanner{}.Scan([]byte(src), t) + if err != nil { + t.Microseconds = 0 + t.Valid = false + } + return err } return fmt.Errorf("cannot scan %T", src) @@ -176,7 +181,7 @@ func (scanPlanTextAnyToTimeScanner) Scan(src []byte, dst any) error { s := string(src) - if len(s) < 8 { + if len(s) < 8 || s[2] != ':' || s[5] != ':' { return fmt.Errorf("cannot decode %v into Time", s) } @@ -199,6 +204,10 @@ func (scanPlanTextAnyToTimeScanner) Scan(src []byte, dst any) error { usec += seconds * microsecondsPerSecond if len(s) > 9 { + if s[8] != '.' || len(s) > 15 { + return fmt.Errorf("cannot decode %v into Time", s) + } + fraction := s[9:] n, err := strconv.ParseInt(fraction, 10, 64) if err != nil { diff --git a/pgtype/time_test.go b/pgtype/time_test.go index 01bcee0f4..06970bacd 100644 --- a/pgtype/time_test.go +++ b/pgtype/time_test.go @@ -2,11 +2,13 @@ package pgtype_test import ( "context" + "strconv" "testing" "time" "github.com/jackc/pgx/v5/pgtype" "github.com/jackc/pgx/v5/pgxtest" + "github.com/stretchr/testify/assert" ) func TestTimeCodec(t *testing.T) { @@ -45,3 +47,69 @@ func TestTimeCodec(t *testing.T) { {nil, new(pgtype.Time), isExpectedEq(pgtype.Time{})}, }) } + +func TestTimeTextScanner(t *testing.T) { + var pgTime pgtype.Time + + assert.NoError(t, pgTime.Scan("07:37:16")) + assert.Equal(t, true, pgTime.Valid) + assert.Equal(t, int64(7*time.Hour+37*time.Minute+16*time.Second), pgTime.Microseconds*int64(time.Microsecond)) + + assert.NoError(t, pgTime.Scan("15:04:05")) + assert.Equal(t, true, pgTime.Valid) + assert.Equal(t, int64(15*time.Hour+4*time.Minute+5*time.Second), pgTime.Microseconds*int64(time.Microsecond)) + + // parsing of fractional digits + assert.NoError(t, pgTime.Scan("15:04:05.00")) + assert.Equal(t, true, pgTime.Valid) + assert.Equal(t, int64(15*time.Hour+4*time.Minute+5*time.Second), pgTime.Microseconds*int64(time.Microsecond)) + + const mirco = "789123" + const woFraction = int64(4*time.Hour + 5*time.Minute + 6*time.Second) // time without fraction + for i := 0; i <= len(mirco); i++ { + assert.NoError(t, pgTime.Scan("04:05:06."+mirco[:i])) + assert.Equal(t, true, pgTime.Valid) + + frac, _ := strconv.ParseInt(mirco[:i], 10, 64) + for k := i; k < 6; k++ { + frac *= 10 + } + assert.Equal(t, woFraction+frac*int64(time.Microsecond), pgTime.Microseconds*int64(time.Microsecond)) + } + + // parsing of too long fraction errors + assert.Error(t, pgTime.Scan("04:05:06.7891234")) + assert.Equal(t, false, pgTime.Valid) + assert.Equal(t, int64(0), pgTime.Microseconds) + + // parsing of timetz errors + assert.Error(t, pgTime.Scan("04:05:06.789-08")) + assert.Equal(t, false, pgTime.Valid) + assert.Equal(t, int64(0), pgTime.Microseconds) + + assert.Error(t, pgTime.Scan("04:05:06-08:00")) + assert.Equal(t, false, pgTime.Valid) + assert.Equal(t, int64(0), pgTime.Microseconds) + + // parsing of date errors + assert.Error(t, pgTime.Scan("1997-12-17")) + assert.Equal(t, false, pgTime.Valid) + assert.Equal(t, int64(0), pgTime.Microseconds) + + // parsing of text errors + assert.Error(t, pgTime.Scan("12345678")) + assert.Equal(t, false, pgTime.Valid) + assert.Equal(t, int64(0), pgTime.Microseconds) + + assert.Error(t, pgTime.Scan("12-34-56")) + assert.Equal(t, false, pgTime.Valid) + assert.Equal(t, int64(0), pgTime.Microseconds) + + assert.Error(t, pgTime.Scan("12:34-56")) + assert.Equal(t, false, pgTime.Valid) + assert.Equal(t, int64(0), pgTime.Microseconds) + + assert.Error(t, pgTime.Scan("12-34:56")) + assert.Equal(t, false, pgTime.Valid) + assert.Equal(t, int64(0), pgTime.Microseconds) +}