Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-1854661 Detect JSON response in Arrow batches mode and return error #1277

Merged
merged 1 commit into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions async.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ func (sr *snowflakeRestful) getAsync(
rows.errChannel <- err
return err
}
rows.format = respd.Data.QueryResultFormat
rows.errChannel <- nil // mark query status complete
}
} else {
Expand Down
39 changes: 39 additions & 0 deletions chunk_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"context"
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
"io"
"math/rand"
Expand Down Expand Up @@ -533,6 +534,44 @@ func TestWithArrowBatchesAsync(t *testing.T) {
})
}

func TestWithArrowBatchesButReturningJSON(t *testing.T) {
testWithArrowBatchesButReturningJSON(t, false)
}

func TestWithArrowBatchesButReturningJSONAsync(t *testing.T) {
testWithArrowBatchesButReturningJSON(t, true)
}

func testWithArrowBatchesButReturningJSON(t *testing.T, async bool) {
runSnowflakeConnTest(t, func(sct *SCTest) {
requestID := NewUUID()
pool := memory.NewCheckedAllocator(memory.DefaultAllocator)
defer pool.AssertSize(t, 0)
ctx := WithArrowAllocator(context.Background(), pool)
ctx = WithArrowBatches(ctx)
ctx = WithRequestID(ctx, requestID)
if async {
ctx = WithAsyncMode(ctx)
}

sct.mustExec(forceJSON, nil)
rows := sct.mustQueryContext(ctx, "SELECT 'hello'", nil)
defer rows.Close()
_, err := rows.(SnowflakeRows).GetArrowBatches()
assertNotNilF(t, err)
var se *SnowflakeError
errors.As(err, &se)
assertEqualE(t, se.Message, errJSONResponseInArrowBatchesMode)

ctx = WithRequestID(context.Background(), requestID)
sfc-gh-pfus marked this conversation as resolved.
Show resolved Hide resolved
rows2 := sct.mustQueryContext(ctx, "SELECT 'hello'", nil)
defer rows2.Close()
scanValues := make([]driver.Value, 1)
assertNilF(t, rows2.Next(scanValues))
assertEqualE(t, scanValues[0], "hello")
})
}

func TestQueryArrowStream(t *testing.T) {
runSnowflakeConnTest(t, func(sct *SCTest) {
numrows := 50000 // approximately 10 ArrowBatch objects
Expand Down
1 change: 1 addition & 0 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,7 @@ func (sc *snowflakeConn) queryContextInternal(
rows.sc = sc
rows.queryID = data.Data.QueryID
rows.ctx = ctx
rows.format = data.Data.QueryResultFormat

if isMultiStmt(&data.Data) {
// handleMultiQuery is responsible to fill rows with childResults
Expand Down
1 change: 1 addition & 0 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,7 @@ const (
errMsgFailedToParseTomlFile = "failed to parse toml file. the params %v occurred error with value %v"
errMsgFailedToFindDSNInTomlFile = "failed to find DSN in toml file."
errMsgInvalidPermissionToTomlFile = "file permissions different than read/write for user. Your Permission: %v"
errJSONResponseInArrowBatchesMode = "arrow batches enabled, but the response is not Arrow based"
)

// Returned if a DNS doesn't include account parameter.
Expand Down
8 changes: 8 additions & 0 deletions rows.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ type snowflakeRows struct {
errChannel chan error
location *time.Location
ctx context.Context
format string
}

func (rows *snowflakeRows) getLocation() *time.Location {
Expand Down Expand Up @@ -168,6 +169,13 @@ func (rows *snowflakeRows) GetArrowBatches() ([]*ArrowBatch, error) {
return nil, err
}

if rows.format != "arrow" {
return nil, (&SnowflakeError{
QueryID: rows.queryID,
Message: errJSONResponseInArrowBatchesMode,
}).exceptionTelemetry(rows.sc)
}

return rows.ChunkDownloader.getArrowBatches(), nil
}

Expand Down
Loading