Skip to content

Commit

Permalink
SNOW-1854661 Detect JSON response in Arrow batches mode and return er…
Browse files Browse the repository at this point in the history
…ror (#1277)
  • Loading branch information
sfc-gh-pfus authored Dec 16, 2024
1 parent f8baf23 commit 4f89e5b
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 0 deletions.
1 change: 1 addition & 0 deletions async.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ func (sr *snowflakeRestful) getAsync(
rows.errChannel <- err
return err
}
rows.format = respd.Data.QueryResultFormat
rows.errChannel <- nil // mark query status complete
}
} else {
Expand Down
39 changes: 39 additions & 0 deletions chunk_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"context"
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
"io"
"math/rand"
Expand Down Expand Up @@ -533,6 +534,44 @@ func TestWithArrowBatchesAsync(t *testing.T) {
})
}

func TestWithArrowBatchesButReturningJSON(t *testing.T) {
testWithArrowBatchesButReturningJSON(t, false)
}

func TestWithArrowBatchesButReturningJSONAsync(t *testing.T) {
testWithArrowBatchesButReturningJSON(t, true)
}

func testWithArrowBatchesButReturningJSON(t *testing.T, async bool) {
runSnowflakeConnTest(t, func(sct *SCTest) {
requestID := NewUUID()
pool := memory.NewCheckedAllocator(memory.DefaultAllocator)
defer pool.AssertSize(t, 0)
ctx := WithArrowAllocator(context.Background(), pool)
ctx = WithArrowBatches(ctx)
ctx = WithRequestID(ctx, requestID)
if async {
ctx = WithAsyncMode(ctx)
}

sct.mustExec(forceJSON, nil)
rows := sct.mustQueryContext(ctx, "SELECT 'hello'", nil)
defer rows.Close()
_, err := rows.(SnowflakeRows).GetArrowBatches()
assertNotNilF(t, err)
var se *SnowflakeError
errors.As(err, &se)
assertEqualE(t, se.Message, errJSONResponseInArrowBatchesMode)

ctx = WithRequestID(context.Background(), requestID)
rows2 := sct.mustQueryContext(ctx, "SELECT 'hello'", nil)
defer rows2.Close()
scanValues := make([]driver.Value, 1)
assertNilF(t, rows2.Next(scanValues))
assertEqualE(t, scanValues[0], "hello")
})
}

func TestQueryArrowStream(t *testing.T) {
runSnowflakeConnTest(t, func(sct *SCTest) {
numrows := 50000 // approximately 10 ArrowBatch objects
Expand Down
1 change: 1 addition & 0 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,7 @@ func (sc *snowflakeConn) queryContextInternal(
rows.sc = sc
rows.queryID = data.Data.QueryID
rows.ctx = ctx
rows.format = data.Data.QueryResultFormat

if isMultiStmt(&data.Data) {
// handleMultiQuery is responsible to fill rows with childResults
Expand Down
1 change: 1 addition & 0 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,7 @@ const (
errMsgFailedToParseTomlFile = "failed to parse toml file. the params %v occurred error with value %v"
errMsgFailedToFindDSNInTomlFile = "failed to find DSN in toml file."
errMsgInvalidPermissionToTomlFile = "file permissions different than read/write for user. Your Permission: %v"
errJSONResponseInArrowBatchesMode = "arrow batches enabled, but the response is not Arrow based"
)

// Returned if a DNS doesn't include account parameter.
Expand Down
8 changes: 8 additions & 0 deletions rows.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ type snowflakeRows struct {
errChannel chan error
location *time.Location
ctx context.Context
format string
}

func (rows *snowflakeRows) getLocation() *time.Location {
Expand Down Expand Up @@ -168,6 +169,13 @@ func (rows *snowflakeRows) GetArrowBatches() ([]*ArrowBatch, error) {
return nil, err
}

if rows.format != "arrow" {
return nil, (&SnowflakeError{
QueryID: rows.queryID,
Message: errJSONResponseInArrowBatchesMode,
}).exceptionTelemetry(rows.sc)
}

return rows.ChunkDownloader.getArrowBatches(), nil
}

Expand Down

0 comments on commit 4f89e5b

Please sign in to comment.