diff --git a/connection.go b/connection.go index 6508f09a8..e386a631a 100644 --- a/connection.go +++ b/connection.go @@ -169,10 +169,21 @@ func (sc *snowflakeConn) exec( } // handle PUT/GET commands + fileTransferChan := make(chan error, 1) if isFileTransfer(query) { - data, err = sc.processFileTransfer(ctx, data, query, isInternal) - if err != nil { - return nil, err + go func() { + data, err = sc.processFileTransfer(ctx, data, query, isInternal) + fileTransferChan <- err + }() + + select { + case <-ctx.Done(): + logger.WithContext(ctx).Info("File transfer has been cancelled") + return nil, ctx.Err() + case err := <-fileTransferChan: + if err != nil { + return nil, err + } } } diff --git a/connection_util.go b/connection_util.go index 31b97c91d..78ea9d1c8 100644 --- a/connection_util.go +++ b/connection_util.go @@ -89,6 +89,7 @@ func (sc *snowflakeConn) processFileTransfer( isInternal bool) ( *execResponse, error) { sfa := snowflakeFileTransferAgent{ + ctx: ctx, sc: sc, data: &data.Data, command: query, diff --git a/file_transfer_agent.go b/file_transfer_agent.go index 03cd1dc40..c22d7489b 100644 --- a/file_transfer_agent.go +++ b/file_transfer_agent.go @@ -106,6 +106,7 @@ type SnowflakeFileTransferOptions struct { } type snowflakeFileTransferAgent struct { + ctx context.Context sc *snowflakeConn data *execResponseData command string diff --git a/file_transfer_agent_test.go b/file_transfer_agent_test.go index b94261ce3..3ffd6a8a9 100644 --- a/file_transfer_agent_test.go +++ b/file_transfer_agent_test.go @@ -30,6 +30,7 @@ func TestGetBucketAccelerateConfiguration(t *testing.T) { } runSnowflakeConnTest(t, func(sct *SCTest) { sfa := &snowflakeFileTransferAgent{ + ctx: context.Background(), sc: sct.sc, commandType: uploadCommand, srcFiles: make([]string, 0), @@ -81,6 +82,7 @@ func TestUnitDownloadWithInvalidLocalPath(t *testing.T) { func TestUnitGetLocalFilePathFromCommand(t *testing.T) { runSnowflakeConnTest(t, func(sct *SCTest) { sfa := &snowflakeFileTransferAgent{ + ctx: context.Background(), sc: sct.sc, commandType: uploadCommand, srcFiles: make([]string, 0), @@ -110,6 +112,7 @@ func TestUnitGetLocalFilePathFromCommand(t *testing.T) { func TestUnitProcessFileCompressionType(t *testing.T) { runSnowflakeConnTest(t, func(sct *SCTest) { sfa := &snowflakeFileTransferAgent{ + ctx: context.Background(), sc: sct.sc, commandType: uploadCommand, srcFiles: make([]string, 0), @@ -156,6 +159,7 @@ func TestUnitProcessFileCompressionType(t *testing.T) { func TestParseCommandWithInvalidStageLocation(t *testing.T) { runSnowflakeConnTest(t, func(sct *SCTest) { sfa := &snowflakeFileTransferAgent{ + ctx: context.Background(), sc: sct.sc, commandType: uploadCommand, srcFiles: make([]string, 0), @@ -190,6 +194,7 @@ func TestParseCommandEncryptionMaterialMismatchError(t *testing.T) { } sfa := &snowflakeFileTransferAgent{ + ctx: context.Background(), sc: sct.sc, commandType: uploadCommand, srcFiles: make([]string, 0), @@ -226,6 +231,7 @@ func TestParseCommandInvalidStorageClientException(t *testing.T) { } sfa := &snowflakeFileTransferAgent{ + ctx: context.Background(), sc: sct.sc, commandType: uploadCommand, srcFiles: make([]string, 0), @@ -253,6 +259,7 @@ func TestParseCommandInvalidStorageClientException(t *testing.T) { func TestInitFileMetadataError(t *testing.T) { runSnowflakeConnTest(t, func(sct *SCTest) { sfa := &snowflakeFileTransferAgent{ + ctx: context.Background(), sc: sct.sc, commandType: uploadCommand, srcFiles: []string{"fileDoesNotExist.txt"}, @@ -352,6 +359,7 @@ func TestUpdateMetadataWithPresignedUrl(t *testing.T) { sct.sc.rest.FuncPostQuery = presignedURLMock sfa := &snowflakeFileTransferAgent{ + ctx: context.Background(), sc: sct.sc, commandType: uploadCommand, command: "put file:///tmp/test_data/data1.txt @~", @@ -400,6 +408,7 @@ func TestUpdateMetadataWithPresignedUrlForDownload(t *testing.T) { } sfa := &snowflakeFileTransferAgent{ + ctx: context.Background(), sc: sct.sc, commandType: downloadCommand, command: "get @~/data1.txt.gz file:///tmp/testData", @@ -421,6 +430,7 @@ func TestUpdateMetadataWithPresignedUrlForDownload(t *testing.T) { func TestUpdateMetadataWithPresignedUrlError(t *testing.T) { runSnowflakeConnTest(t, func(sct *SCTest) { sfa := &snowflakeFileTransferAgent{ + ctx: context.Background(), sc: sct.sc, command: "get @~/data1.txt.gz file:///tmp/testData", stageLocationType: gcsClient, @@ -486,6 +496,7 @@ func TestUploadWhenFilesystemReadOnlyError(t *testing.T) { } sfa := &snowflakeFileTransferAgent{ + ctx: context.Background(), sc: &snowflakeConn{ cfg: &Config{}, }, @@ -585,6 +596,7 @@ func TestCustomTmpDirPath(t *testing.T) { } sfa := snowflakeFileTransferAgent{ + ctx: context.Background(), sc: &snowflakeConn{ cfg: &Config{ TmpDirPath: tmpDir, @@ -646,6 +658,7 @@ func TestReadonlyTmpDirPathShouldFail(t *testing.T) { } sfa := snowflakeFileTransferAgent{ + ctx: context.Background(), sc: &snowflakeConn{ cfg: &Config{ TmpDirPath: tmpDir, @@ -720,6 +733,7 @@ func testUploadDownloadOneFile(t *testing.T, isStream bool) { } sfa := snowflakeFileTransferAgent{ + ctx: context.Background(), sc: &snowflakeConn{ cfg: &Config{ TmpDirPath: tmpDir, diff --git a/put_get_test.go b/put_get_test.go index 2c0382294..8a4322a1b 100644 --- a/put_get_test.go +++ b/put_get_test.go @@ -11,10 +11,12 @@ import ( "math/rand" "os" "os/user" + "path" "path/filepath" "strconv" "strings" "testing" + "time" ) const createStageStmt = "CREATE OR REPLACE STAGE %v URL = '%v' CREDENTIALS = (%v)" @@ -48,6 +50,7 @@ func TestPutError(t *testing.T) { } fta := &snowflakeFileTransferAgent{ + ctx: context.Background(), data: data, options: &SnowflakeFileTransferOptions{ RaisePutGetError: false, @@ -64,6 +67,7 @@ func TestPutError(t *testing.T) { } fta = &snowflakeFileTransferAgent{ + ctx: context.Background(), data: data, options: &SnowflakeFileTransferOptions{ RaisePutGetError: true, @@ -829,3 +833,34 @@ func TestPutGetMaxLOBSize(t *testing.T) { } }) } + +func TestPutCancel(t *testing.T) { + sourceDir, err := os.Getwd() + assertNilF(t, err) + testData := path.Join(sourceDir, "/test_data/largefile.txt") + + runDBTest(t, func(dbt *DBTest) { + c := make(chan error) + ctx, cancel := context.WithCancel(context.Background()) + go func() { + // attempt to upload a large file, but it should be canceled in 3 seconds + _, err := dbt.conn.ExecContext( + ctx, + fmt.Sprintf("put 'file://%v' @~/test_put_cancel overwrite=true", + strings.ReplaceAll(testData, "\\", "/"))) + if err != nil { + c <- err + return + } + c <- nil + }() + // cancel after 3 seconds + time.Sleep(3 * time.Second) + fmt.Println("Canceled") + cancel() + ret := <-c + assertNotNilF(t, ret) + assertStringContainsF(t, ret.Error(), "context canceled", "failed to cancel.") + close(c) + }) +} diff --git a/telemetry_test.go b/telemetry_test.go index 4ff94471d..6c6835f6f 100644 --- a/telemetry_test.go +++ b/telemetry_test.go @@ -53,6 +53,7 @@ func TestTelemetrySQLException(t *testing.T) { flushSize: defaultFlushSize, } sfa := &snowflakeFileTransferAgent{ + ctx: context.Background(), sc: sct.sc, commandType: uploadCommand, srcFiles: make([]string, 0),