From df937bf0a8b90b8ffbf03a71ea55cf9c85ed2967 Mon Sep 17 00:00:00 2001 From: Piotr Fus Date: Fri, 13 Dec 2024 13:51:52 +0100 Subject: [PATCH] SNOW-1854657 Detect JSON response in Arrow batches mode and return error --- chunk_test.go | 39 +++++++++++++++++++++++++++++++++++++++ connection.go | 1 + errors.go | 1 + rows.go | 7 +++++++ 4 files changed, 48 insertions(+) diff --git a/chunk_test.go b/chunk_test.go index 95619fc1d..d6096900b 100644 --- a/chunk_test.go +++ b/chunk_test.go @@ -7,6 +7,7 @@ import ( "context" "database/sql/driver" "encoding/json" + "errors" "fmt" "io" "math/rand" @@ -533,6 +534,44 @@ func TestWithArrowBatchesAsync(t *testing.T) { }) } +func TestWithArrowBatchesButReturningJSON(t *testing.T) { + testWithArrowBatchesButReturningJSON(t, false) +} + +func TestWithArrowBatchesButReturningJSONAsync(t *testing.T) { + testWithArrowBatchesButReturningJSON(t, true) +} + +func testWithArrowBatchesButReturningJSON(t *testing.T, async bool) { + runSnowflakeConnTest(t, func(sct *SCTest) { + requestID := NewUUID() + pool := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer pool.AssertSize(t, 0) + ctx := WithArrowAllocator(context.Background(), pool) + ctx = WithArrowBatches(ctx) + ctx = WithRequestID(ctx, requestID) + if async { + ctx = WithAsyncMode(ctx) + } + + sct.mustExec(forceJSON, nil) + rows := sct.mustQueryContext(ctx, "SELECT 'hello'", nil) + defer rows.Close() + _, err := rows.(SnowflakeRows).GetArrowBatches() + assertNotNilF(t, err) + var se *SnowflakeError + errors.As(err, &se) + assertEqualE(t, se.Message, errJSONResponseInArrowBatchesMode) + + ctx = WithRequestID(context.Background(), requestID) + rows2 := sct.mustQueryContext(ctx, "SELECT 'hello'", nil) + defer rows2.Close() + scanValues := make([]driver.Value, 1) + assertNilF(t, rows2.Next(scanValues)) + assertEqualE(t, scanValues[0], "hello") + }) +} + func TestQueryArrowStream(t *testing.T) { runSnowflakeConnTest(t, func(sct *SCTest) { numrows := 50000 // approximately 10 ArrowBatch objects diff --git a/connection.go b/connection.go index ad8c0eee5..bd7dc084a 100644 --- a/connection.go +++ b/connection.go @@ -435,6 +435,7 @@ func (sc *snowflakeConn) queryContextInternal( rows.sc = sc rows.queryID = data.Data.QueryID rows.ctx = ctx + rows.format = data.Data.QueryResultFormat if isMultiStmt(&data.Data) { // handleMultiQuery is responsible to fill rows with childResults diff --git a/errors.go b/errors.go index 2e5d902f0..988e6a75b 100644 --- a/errors.go +++ b/errors.go @@ -308,6 +308,7 @@ const ( errMsgFailedToParseTomlFile = "failed to parse toml file. the params %v occurred error with value %v" errMsgFailedToFindDSNInTomlFile = "failed to find DSN in toml file." errMsgInvalidPermissionToTomlFile = "file permissions different than read/write for user. Your Permission: %v" + errJSONResponseInArrowBatchesMode = "arrow batches enabled, but the response is not Arrow based" ) // Returned if a DNS doesn't include account parameter. diff --git a/rows.go b/rows.go index b70e4baeb..f5af1034f 100644 --- a/rows.go +++ b/rows.go @@ -46,6 +46,7 @@ type snowflakeRows struct { errChannel chan error location *time.Location ctx context.Context + format string } func (rows *snowflakeRows) getLocation() *time.Location { @@ -164,6 +165,12 @@ func (rows *snowflakeRows) GetStatus() queryStatus { func (rows *snowflakeRows) GetArrowBatches() ([]*ArrowBatch, error) { // Wait for all arrow batches before fetching. // Otherwise, a panic error "invalid memory address or nil pointer dereference" will be thrown. + if rows.format != "arrow" { + return nil, (&SnowflakeError{ + QueryID: rows.queryID, + Message: errJSONResponseInArrowBatchesMode, + }).exceptionTelemetry(rows.sc) + } if err := rows.waitForAsyncQueryStatus(); err != nil { return nil, err }