Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
moukoublen committed Dec 30, 2024
1 parent bcf3fbd commit 5a50d4a
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 1 deletion.
9 changes: 9 additions & 0 deletions pgtype/json.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,10 @@ func (c *JSONCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanP
//
// https://github.com/jackc/pgx/issues/2146
func isSQLScanner(v any) bool {
if _, is := v.(sql.Scanner); is {
return true
}

val := reflect.ValueOf(v)
for val.Kind() == reflect.Ptr {
if _, ok := val.Interface().(sql.Scanner); ok {
Expand Down Expand Up @@ -212,6 +216,11 @@ func (s *scanPlanJSONToJSONUnmarshal) Scan(src []byte, dst any) error {
return fmt.Errorf("cannot scan NULL into %T", dst)
}

v := reflect.ValueOf(dst)
if v.Kind() != reflect.Pointer || v.IsNil() {
return fmt.Errorf("cannot scan into non-pointer or nil destinations %T", dst)
}

elem := reflect.ValueOf(dst).Elem()
elem.Set(reflect.Zero(elem.Type()))

Expand Down
20 changes: 19 additions & 1 deletion pgtype/json_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,8 @@ func TestJSONCodecCustomMarshal(t *testing.T) {
Unmarshal: func(data []byte, v any) error {
return json.Unmarshal([]byte(`{"custom":"value"}`), v)
},
}})
},
})
}

pgxtest.RunValueRoundTripTests(context.Background(), t, connTestRunner, pgxtest.KnownOIDQueryExecModes, "json", []pgxtest.ValueRoundTripTest{
Expand All @@ -278,3 +279,20 @@ func TestJSONCodecCustomMarshal(t *testing.T) {
}},
})
}

func TestJSONCodecScanToNonPointerValues(t *testing.T) {
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
n := 44
err := conn.QueryRow(ctx, "select '42'::jsonb").Scan(n)
require.Error(t, err)

var i *int
err = conn.QueryRow(ctx, "select '42'::jsonb").Scan(i)
require.Error(t, err)

m := 0
err = conn.QueryRow(ctx, "select '42'::jsonb").Scan(&m)
require.NoError(t, err)
require.Equal(t, 42, m)
})
}
4 changes: 4 additions & 0 deletions pgtype/pgtype.go
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,10 @@ func (plan *scanPlanSQLScanner) Scan(src []byte, dst any) error {

// we don't know if the target is a sql.Scanner or a pointer on a sql.Scanner, so we need to check recursively
func getSQLScanner(target any) sql.Scanner {
if sc, is := target.(sql.Scanner); is {
return sc
}

val := reflect.ValueOf(target)
for val.Kind() == reflect.Ptr {
if _, ok := val.Interface().(sql.Scanner); ok {
Expand Down

0 comments on commit 5a50d4a

Please sign in to comment.