Skip to content

Commit

Permalink
SNOW-1006312: Cancel context not propagated to snowflakeFileTransferA…
Browse files Browse the repository at this point in the history
…gent on PUT/GET command (#1108)

Added context to snowflakeFileTransferAgent to support cancel for file transfer process
  • Loading branch information
sfc-gh-ext-simba-jl authored Aug 27, 2024
1 parent f434413 commit f20a464
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 3 deletions.
17 changes: 14 additions & 3 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,10 +169,21 @@ func (sc *snowflakeConn) exec(
}

// handle PUT/GET commands
fileTransferChan := make(chan error, 1)
if isFileTransfer(query) {
data, err = sc.processFileTransfer(ctx, data, query, isInternal)
if err != nil {
return nil, err
go func() {
data, err = sc.processFileTransfer(ctx, data, query, isInternal)
fileTransferChan <- err
}()

select {
case <-ctx.Done():
logger.WithContext(ctx).Info("File transfer has been cancelled")
return nil, ctx.Err()
case err := <-fileTransferChan:
if err != nil {
return nil, err
}
}
}

Expand Down
1 change: 1 addition & 0 deletions connection_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ func (sc *snowflakeConn) processFileTransfer(
isInternal bool) (
*execResponse, error) {
sfa := snowflakeFileTransferAgent{
ctx: ctx,
sc: sc,
data: &data.Data,
command: query,
Expand Down
1 change: 1 addition & 0 deletions file_transfer_agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ type SnowflakeFileTransferOptions struct {
}

type snowflakeFileTransferAgent struct {
ctx context.Context
sc *snowflakeConn
data *execResponseData
command string
Expand Down
14 changes: 14 additions & 0 deletions file_transfer_agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ func TestGetBucketAccelerateConfiguration(t *testing.T) {
}
runSnowflakeConnTest(t, func(sct *SCTest) {
sfa := &snowflakeFileTransferAgent{
ctx: context.Background(),
sc: sct.sc,
commandType: uploadCommand,
srcFiles: make([]string, 0),
Expand Down Expand Up @@ -81,6 +82,7 @@ func TestUnitDownloadWithInvalidLocalPath(t *testing.T) {
func TestUnitGetLocalFilePathFromCommand(t *testing.T) {
runSnowflakeConnTest(t, func(sct *SCTest) {
sfa := &snowflakeFileTransferAgent{
ctx: context.Background(),
sc: sct.sc,
commandType: uploadCommand,
srcFiles: make([]string, 0),
Expand Down Expand Up @@ -110,6 +112,7 @@ func TestUnitGetLocalFilePathFromCommand(t *testing.T) {
func TestUnitProcessFileCompressionType(t *testing.T) {
runSnowflakeConnTest(t, func(sct *SCTest) {
sfa := &snowflakeFileTransferAgent{
ctx: context.Background(),
sc: sct.sc,
commandType: uploadCommand,
srcFiles: make([]string, 0),
Expand Down Expand Up @@ -156,6 +159,7 @@ func TestUnitProcessFileCompressionType(t *testing.T) {
func TestParseCommandWithInvalidStageLocation(t *testing.T) {
runSnowflakeConnTest(t, func(sct *SCTest) {
sfa := &snowflakeFileTransferAgent{
ctx: context.Background(),
sc: sct.sc,
commandType: uploadCommand,
srcFiles: make([]string, 0),
Expand Down Expand Up @@ -190,6 +194,7 @@ func TestParseCommandEncryptionMaterialMismatchError(t *testing.T) {
}

sfa := &snowflakeFileTransferAgent{
ctx: context.Background(),
sc: sct.sc,
commandType: uploadCommand,
srcFiles: make([]string, 0),
Expand Down Expand Up @@ -226,6 +231,7 @@ func TestParseCommandInvalidStorageClientException(t *testing.T) {
}

sfa := &snowflakeFileTransferAgent{
ctx: context.Background(),
sc: sct.sc,
commandType: uploadCommand,
srcFiles: make([]string, 0),
Expand Down Expand Up @@ -253,6 +259,7 @@ func TestParseCommandInvalidStorageClientException(t *testing.T) {
func TestInitFileMetadataError(t *testing.T) {
runSnowflakeConnTest(t, func(sct *SCTest) {
sfa := &snowflakeFileTransferAgent{
ctx: context.Background(),
sc: sct.sc,
commandType: uploadCommand,
srcFiles: []string{"fileDoesNotExist.txt"},
Expand Down Expand Up @@ -352,6 +359,7 @@ func TestUpdateMetadataWithPresignedUrl(t *testing.T) {

sct.sc.rest.FuncPostQuery = presignedURLMock
sfa := &snowflakeFileTransferAgent{
ctx: context.Background(),
sc: sct.sc,
commandType: uploadCommand,
command: "put file:///tmp/test_data/data1.txt @~",
Expand Down Expand Up @@ -400,6 +408,7 @@ func TestUpdateMetadataWithPresignedUrlForDownload(t *testing.T) {
}

sfa := &snowflakeFileTransferAgent{
ctx: context.Background(),
sc: sct.sc,
commandType: downloadCommand,
command: "get @~/data1.txt.gz file:///tmp/testData",
Expand All @@ -421,6 +430,7 @@ func TestUpdateMetadataWithPresignedUrlForDownload(t *testing.T) {
func TestUpdateMetadataWithPresignedUrlError(t *testing.T) {
runSnowflakeConnTest(t, func(sct *SCTest) {
sfa := &snowflakeFileTransferAgent{
ctx: context.Background(),
sc: sct.sc,
command: "get @~/data1.txt.gz file:///tmp/testData",
stageLocationType: gcsClient,
Expand Down Expand Up @@ -486,6 +496,7 @@ func TestUploadWhenFilesystemReadOnlyError(t *testing.T) {
}

sfa := &snowflakeFileTransferAgent{
ctx: context.Background(),
sc: &snowflakeConn{
cfg: &Config{},
},
Expand Down Expand Up @@ -585,6 +596,7 @@ func TestCustomTmpDirPath(t *testing.T) {
}

sfa := snowflakeFileTransferAgent{
ctx: context.Background(),
sc: &snowflakeConn{
cfg: &Config{
TmpDirPath: tmpDir,
Expand Down Expand Up @@ -646,6 +658,7 @@ func TestReadonlyTmpDirPathShouldFail(t *testing.T) {
}

sfa := snowflakeFileTransferAgent{
ctx: context.Background(),
sc: &snowflakeConn{
cfg: &Config{
TmpDirPath: tmpDir,
Expand Down Expand Up @@ -720,6 +733,7 @@ func testUploadDownloadOneFile(t *testing.T, isStream bool) {
}

sfa := snowflakeFileTransferAgent{
ctx: context.Background(),
sc: &snowflakeConn{
cfg: &Config{
TmpDirPath: tmpDir,
Expand Down
35 changes: 35 additions & 0 deletions put_get_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@ import (
"math/rand"
"os"
"os/user"
"path"
"path/filepath"
"strconv"
"strings"
"testing"
"time"
)

const createStageStmt = "CREATE OR REPLACE STAGE %v URL = '%v' CREDENTIALS = (%v)"
Expand Down Expand Up @@ -48,6 +50,7 @@ func TestPutError(t *testing.T) {
}

fta := &snowflakeFileTransferAgent{
ctx: context.Background(),
data: data,
options: &SnowflakeFileTransferOptions{
RaisePutGetError: false,
Expand All @@ -64,6 +67,7 @@ func TestPutError(t *testing.T) {
}

fta = &snowflakeFileTransferAgent{
ctx: context.Background(),
data: data,
options: &SnowflakeFileTransferOptions{
RaisePutGetError: true,
Expand Down Expand Up @@ -829,3 +833,34 @@ func TestPutGetMaxLOBSize(t *testing.T) {
}
})
}

func TestPutCancel(t *testing.T) {
sourceDir, err := os.Getwd()
assertNilF(t, err)
testData := path.Join(sourceDir, "/test_data/largefile.txt")

runDBTest(t, func(dbt *DBTest) {
c := make(chan error)
ctx, cancel := context.WithCancel(context.Background())
go func() {
// attempt to upload a large file, but it should be canceled in 3 seconds
_, err := dbt.conn.ExecContext(
ctx,
fmt.Sprintf("put 'file://%v' @~/test_put_cancel overwrite=true",
strings.ReplaceAll(testData, "\\", "/")))
if err != nil {
c <- err
return
}
c <- nil
}()
// cancel after 3 seconds
time.Sleep(3 * time.Second)
fmt.Println("Canceled")
cancel()
ret := <-c
assertNotNilF(t, ret)
assertStringContainsF(t, ret.Error(), "context canceled", "failed to cancel.")
close(c)
})
}
1 change: 1 addition & 0 deletions telemetry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ func TestTelemetrySQLException(t *testing.T) {
flushSize: defaultFlushSize,
}
sfa := &snowflakeFileTransferAgent{
ctx: context.Background(),
sc: sct.sc,
commandType: uploadCommand,
srcFiles: make([]string, 0),
Expand Down

0 comments on commit f20a464

Please sign in to comment.