From 6e9fa42fef85262b2de6975efa716e212a1012a8 Mon Sep 17 00:00:00 2001 From: Kostas Stamatakis Date: Mon, 30 Dec 2024 22:43:04 +0200 Subject: [PATCH] fix #2204 --- .vscode/launch.json | 15 ++++++++ .vscode/settings.json | 5 +++ pgtype/json.go | 11 +++++- pgtype/json_test.go | 20 ++++++++++- pgtype/pgtype.go | 4 +++ tete/main.go | 80 +++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 133 insertions(+), 2 deletions(-) create mode 100644 .vscode/launch.json create mode 100644 .vscode/settings.json create mode 100644 tete/main.go diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 000000000..8655150da --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,15 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "name": "Launch Package", + "type": "go", + "request": "launch", + "mode": "debug", + "program": "${fileDirname}" + } + ] +} \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 000000000..890fee03d --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,5 @@ +{ + "go.testEnvVars": { + "PGX_TEST_DATABASE":"host=127.0.0.1 user=gamerhound password=gamerhound dbname=gamerhound" + } +} \ No newline at end of file diff --git a/pgtype/json.go b/pgtype/json.go index 48b9f9771..76cec51b8 100644 --- a/pgtype/json.go +++ b/pgtype/json.go @@ -161,6 +161,10 @@ func (c *JSONCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanP // // https://github.com/jackc/pgx/issues/2146 func isSQLScanner(v any) bool { + if _, is := v.(sql.Scanner); is { + return true + } + val := reflect.ValueOf(v) for val.Kind() == reflect.Ptr { if _, ok := val.Interface().(sql.Scanner); ok { @@ -212,7 +216,12 @@ func (s *scanPlanJSONToJSONUnmarshal) Scan(src []byte, dst any) error { return fmt.Errorf("cannot scan NULL into %T", dst) } - elem := reflect.ValueOf(dst).Elem() + v := reflect.ValueOf(dst) + if v.Kind() != reflect.Pointer || v.IsNil() { + return fmt.Errorf("cannot scan into non-pointer or nil destinations %T", dst) + } + + elem := v.Elem() elem.Set(reflect.Zero(elem.Type())) return s.unmarshal(src, dst) diff --git a/pgtype/json_test.go b/pgtype/json_test.go index 18ca5a8e4..1f286b9da 100644 --- a/pgtype/json_test.go +++ b/pgtype/json_test.go @@ -267,7 +267,8 @@ func TestJSONCodecCustomMarshal(t *testing.T) { Unmarshal: func(data []byte, v any) error { return json.Unmarshal([]byte(`{"custom":"value"}`), v) }, - }}) + }, + }) } pgxtest.RunValueRoundTripTests(context.Background(), t, connTestRunner, pgxtest.KnownOIDQueryExecModes, "json", []pgxtest.ValueRoundTripTest{ @@ -278,3 +279,20 @@ func TestJSONCodecCustomMarshal(t *testing.T) { }}, }) } + +func TestJSONCodecScanToNonPointerValues(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + n := 44 + err := conn.QueryRow(ctx, "select '42'::jsonb").Scan(n) + require.Error(t, err) + + var i *int + err = conn.QueryRow(ctx, "select '42'::jsonb").Scan(i) + require.Error(t, err) + + m := 0 + err = conn.QueryRow(ctx, "select '42'::jsonb").Scan(&m) + require.NoError(t, err) + require.Equal(t, 42, m) + }) +} diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index f9d43edd7..20645d694 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -415,6 +415,10 @@ func (plan *scanPlanSQLScanner) Scan(src []byte, dst any) error { // we don't know if the target is a sql.Scanner or a pointer on a sql.Scanner, so we need to check recursively func getSQLScanner(target any) sql.Scanner { + if sc, is := target.(sql.Scanner); is { + return sc + } + val := reflect.ValueOf(target) for val.Kind() == reflect.Ptr { if _, ok := val.Interface().(sql.Scanner); ok { diff --git a/tete/main.go b/tete/main.go new file mode 100644 index 000000000..855636da0 --- /dev/null +++ b/tete/main.go @@ -0,0 +1,80 @@ +package main + +import ( + "context" + "fmt" + "log" + "reflect" + + "github.com/jackc/pgx/v5/pgxpool" +) + +func main() { + pool, err := pgxpool.New(context.Background(), "postgres://gamerhound:gamerhound@localhost:5432/gamerhound") + if err != nil { + log.Fatal(err) + } + defer pool.Close() + + // Create the enum type. + _, err = pool.Exec(context.Background(), `DROP TYPE IF EXISTS test_enum_type`) + if err != nil { + log.Print(err) + return + } + _, err = pool.Exec(context.Background(), `CREATE TYPE test_enum_type AS ENUM ('a', 'b')`) + if err != nil { + log.Print(err) + return + } + + err = testQuery(pool, "SELECT 'a'", "a") + if err != nil { + log.Printf("test TEXT error: %s\n", err) + } + + err = testQuery(pool, "SELECT 'a'::test_enum_type", "a") + if err != nil { + log.Printf("test ENUM error: %s\n", err) + } + + err = testQuery(pool, "SELECT '{}'::jsonb", "{}") + if err != nil { + log.Printf("test JSONB error: %s\n", err) + } +} + +// T implements the sql.Scanner interface. +type T struct { + v *any +} + +func (t T) Scan(v any) error { + *t.v = v + return nil +} + +// testQuery executes the query and checks if the scanned value matches +// the expected result. +func testQuery(pool *pgxpool.Pool, query string, expected any) error { + rows, err := pool.Query(context.Background(), query) + if err != nil { + return err + } + // defer rows.Close() + + var got any + t := T{v: &got} + for rows.Next() { + if err := rows.Scan(t); err != nil { + return err + } + } + if err = rows.Err(); err != nil { + return err + } + if !reflect.DeepEqual(got, expected) { + return fmt.Errorf("expected %#v; got %#v", expected, got) + } + return nil +}