diff --git a/azure_storage_client.go b/azure_storage_client.go index ae1b72118..92b3b6605 100644 --- a/azure_storage_client.go +++ b/azure_storage_client.go @@ -23,6 +23,7 @@ import ( ) type snowflakeAzureClient struct { + cfg *Config } type azureLocation struct { @@ -85,9 +86,11 @@ func (util *snowflakeAzureClient) getFileHeader(meta *fileMetadata, filename str if meta.mockAzureClient != nil { blobClient = meta.mockAzureClient } - resp, err := blobClient.GetProperties(context.Background(), &blob.GetPropertiesOptions{ - AccessConditions: &blob.AccessConditions{}, - CPKInfo: &blob.CPKInfo{}, + resp, err := withCloudStorageTimeout(util.cfg, func(ctx context.Context) (blob.GetPropertiesResponse, error) { + return blobClient.GetProperties(ctx, &blob.GetPropertiesOptions{ + AccessConditions: &blob.AccessConditions{}, + CPKInfo: &blob.CPKInfo{}, + }) }) if err != nil { var se *azcore.ResponseError @@ -203,9 +206,11 @@ func (util *snowflakeAzureClient) uploadFile( if meta.realSrcStream != nil { uploadSrc = meta.realSrcStream } - _, err = blobClient.UploadStream(context.Background(), uploadSrc, &azblob.UploadStreamOptions{ - BlockSize: int64(uploadSrc.Len()), - Metadata: azureMeta, + _, err = withCloudStorageTimeout(util.cfg, func(ctx context.Context) (azblob.UploadStreamResponse, error) { + return blobClient.UploadStream(ctx, uploadSrc, &azblob.UploadStreamOptions{ + BlockSize: int64(uploadSrc.Len()), + Metadata: azureMeta, + }) }) } else { var f *os.File @@ -228,7 +233,9 @@ func (util *snowflakeAzureClient) uploadFile( if meta.options.putAzureCallback != nil { blobOptions.Progress = meta.options.putAzureCallback.call } - _, err = blobClient.UploadFile(context.Background(), f, blobOptions) + _, err = withCloudStorageTimeout(util.cfg, func(ctx context.Context) (azblob.UploadFileResponse, error) { + return blobClient.UploadFile(ctx, f, blobOptions) + }) } if err != nil { var se *azcore.ResponseError @@ -279,7 +286,9 @@ func (util *snowflakeAzureClient) nativeDownloadFile( blobClient = meta.mockAzureClient } if meta.options.GetFileToStream { - blobDownloadResponse, err := blobClient.DownloadStream(context.Background(), &azblob.DownloadStreamOptions{}) + blobDownloadResponse, err := withCloudStorageTimeout(util.cfg, func(ctx context.Context) (azblob.DownloadStreamResponse, error) { + return blobClient.DownloadStream(context.Background(), &azblob.DownloadStreamOptions{}) + }) if err != nil { return err } @@ -295,9 +304,11 @@ func (util *snowflakeAzureClient) nativeDownloadFile( return err } defer f.Close() - _, err = blobClient.DownloadFile( - context.Background(), f, &azblob.DownloadFileOptions{ - Concurrency: uint16(maxConcurrency)}) + _, err = withCloudStorageTimeout(util.cfg, func(ctx context.Context) (any, error) { + return blobClient.DownloadFile( + ctx, f, &azblob.DownloadFileOptions{ + Concurrency: uint16(maxConcurrency)}) + }) if err != nil { return err } diff --git a/azure_storage_client_test.go b/azure_storage_client_test.go index ddf45b5a3..1490e3a6d 100644 --- a/azure_storage_client_test.go +++ b/azure_storage_client_test.go @@ -177,6 +177,11 @@ func TestUploadFileWithAzureUploadFailedError(t *testing.T) { return azblob.UploadFileResponse{}, errors.New("unexpected error uploading file") }, }, + sfa: &snowflakeFileTransferAgent{ + sc: &snowflakeConn{ + cfg: &Config{}, + }, + }, } uploadMeta.realSrcFileName = uploadMeta.srcFileName @@ -230,6 +235,11 @@ func TestUploadStreamWithAzureUploadFailedError(t *testing.T) { return azblob.UploadStreamResponse{}, errors.New("unexpected error uploading file") }, }, + sfa: &snowflakeFileTransferAgent{ + sc: &snowflakeConn{ + cfg: &Config{}, + }, + }, } uploadMeta.realSrcStream = uploadMeta.srcStream @@ -291,6 +301,11 @@ func TestUploadFileWithAzureUploadTokenExpired(t *testing.T) { } }, }, + sfa: &snowflakeFileTransferAgent{ + sc: &snowflakeConn{ + cfg: &Config{}, + }, + }, } uploadMeta.realSrcFileName = uploadMeta.srcFileName @@ -362,6 +377,11 @@ func TestUploadFileWithAzureUploadNeedsRetry(t *testing.T) { } }, }, + sfa: &snowflakeFileTransferAgent{ + sc: &snowflakeConn{ + cfg: &Config{}, + }, + }, } uploadMeta.realSrcFileName = uploadMeta.srcFileName @@ -418,6 +438,11 @@ func TestDownloadOneFileToAzureFailed(t *testing.T) { return blob.GetPropertiesResponse{}, nil }, }, + sfa: &snowflakeFileTransferAgent{ + sc: &snowflakeConn{ + cfg: &Config{}, + }, + }, } err = new(remoteStorageUtil).downloadOneFile(&downloadMeta) if err == nil { @@ -444,9 +469,14 @@ func TestGetFileHeaderErrorStatus(t *testing.T) { return blob.GetPropertiesResponse{}, errors.New("failed to retrieve headers") }, }, + sfa: &snowflakeFileTransferAgent{ + sc: &snowflakeConn{ + cfg: &Config{}, + }, + }, } - if header, err := new(snowflakeAzureClient).getFileHeader(&meta, "file.txt"); header != nil || err == nil { + if header, err := (&snowflakeAzureClient{cfg: &Config{}}).getFileHeader(&meta, "file.txt"); header != nil || err == nil { t.Fatalf("expected null header, got: %v", header) } if meta.resStatus != errStatus { @@ -477,9 +507,14 @@ func TestGetFileHeaderErrorStatus(t *testing.T) { } }, }, + sfa: &snowflakeFileTransferAgent{ + sc: &snowflakeConn{ + cfg: &Config{}, + }, + }, } - if header, err := new(snowflakeAzureClient).getFileHeader(&meta, "file.txt"); header != nil || err == nil { + if header, err := (&snowflakeAzureClient{cfg: &Config{}}).getFileHeader(&meta, "file.txt"); header != nil || err == nil { t.Fatalf("expected null header, got: %v", header) } if meta.resStatus != notFoundFile { @@ -505,7 +540,7 @@ func TestGetFileHeaderErrorStatus(t *testing.T) { }, } - if header, err := new(snowflakeAzureClient).getFileHeader(&meta, "file.txt"); header != nil || err == nil { + if header, err := (&snowflakeAzureClient{cfg: &Config{}}).getFileHeader(&meta, "file.txt"); header != nil || err == nil { t.Fatalf("expected null header, got: %v", header) } if meta.resStatus != renewToken { @@ -540,6 +575,11 @@ func TestUploadFileToAzureClientCastFail(t *testing.T) { options: &SnowflakeFileTransferOptions{ MultiPartThreshold: dataSizeThreshold, }, + sfa: &snowflakeFileTransferAgent{ + sc: &snowflakeConn{ + cfg: &Config{}, + }, + }, } uploadMeta.realSrcFileName = uploadMeta.srcFileName @@ -573,6 +613,11 @@ func TestAzureGetHeaderClientCastFail(t *testing.T) { return blob.GetPropertiesResponse{}, nil }, }, + sfa: &snowflakeFileTransferAgent{ + sc: &snowflakeConn{ + cfg: &Config{}, + }, + }, } _, err = new(snowflakeAzureClient).getFileHeader(&meta, "file.txt") diff --git a/dsn.go b/dsn.go index 7b7efb569..85d86f1d7 100644 --- a/dsn.go +++ b/dsn.go @@ -25,6 +25,7 @@ const ( defaultRequestTimeout = 0 * time.Second // Timeout for retry for request EXCLUDING clientTimeout defaultJWTTimeout = 60 * time.Second defaultExternalBrowserTimeout = 120 * time.Second // Timeout for external browser login + defaultCloudStorageTimeout = -1 // Timeout for calling cloud storage. defaultMaxRetryCount = 7 // specifies maximum number of subsequent retries defaultDomain = ".snowflakecomputing.com" cnDomain = ".snowflakecomputing.cn" @@ -77,6 +78,7 @@ type Config struct { ClientTimeout time.Duration // Timeout for network round trip + read out http response JWTClientTimeout time.Duration // Timeout for network round trip + read out http response used when JWT token auth is taking place ExternalBrowserTimeout time.Duration // Timeout for external browser login + CloudStorageTimeout time.Duration // Timeout for a single call to a cloud storage provider MaxRetryCount int // Specifies how many times non-periodic HTTP request can be retried Application string // application name. @@ -215,6 +217,9 @@ func DSN(cfg *Config) (dsn string, err error) { if cfg.ExternalBrowserTimeout != defaultExternalBrowserTimeout { params.Add("externalBrowserTimeout", strconv.FormatInt(int64(cfg.ExternalBrowserTimeout/time.Second), 10)) } + if cfg.CloudStorageTimeout != defaultCloudStorageTimeout { + params.Add("cloudStorageTimeout", strconv.FormatInt(int64(cfg.CloudStorageTimeout/time.Second), 10)) + } if cfg.MaxRetryCount != defaultMaxRetryCount { params.Add("maxRetryCount", strconv.Itoa(cfg.MaxRetryCount)) } @@ -498,6 +503,9 @@ func fillMissingConfigParameters(cfg *Config) error { if cfg.ExternalBrowserTimeout == 0 { cfg.ExternalBrowserTimeout = defaultExternalBrowserTimeout } + if cfg.CloudStorageTimeout == 0 { + cfg.CloudStorageTimeout = defaultCloudStorageTimeout + } if cfg.MaxRetryCount == 0 { cfg.MaxRetryCount = defaultMaxRetryCount } @@ -714,6 +722,11 @@ func parseDSNParams(cfg *Config, params string) (err error) { if err != nil { return err } + case "cloudStorageTimeout": + cfg.CloudStorageTimeout, err = parseTimeout(value) + if err != nil { + return err + } case "maxRetryCount": cfg.MaxRetryCount, err = strconv.Atoi(value) if err != nil { diff --git a/dsn_test.go b/dsn_test.go index ba6386d5d..ec4f61303 100644 --- a/dsn_test.go +++ b/dsn_test.go @@ -45,6 +45,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, @@ -61,6 +62,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, @@ -75,6 +77,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, @@ -90,6 +93,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, @@ -105,6 +109,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, @@ -121,6 +126,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, @@ -137,6 +143,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, @@ -152,6 +159,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, @@ -167,6 +175,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, @@ -182,6 +191,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, @@ -197,6 +207,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, @@ -213,6 +224,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, @@ -229,6 +241,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, @@ -245,6 +258,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, @@ -261,6 +275,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, @@ -277,6 +292,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, @@ -293,6 +309,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, @@ -309,6 +326,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, @@ -325,6 +343,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, @@ -341,6 +360,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, @@ -357,6 +377,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, @@ -373,6 +394,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, @@ -389,6 +411,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, @@ -404,6 +427,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, @@ -419,6 +443,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, @@ -434,6 +459,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, @@ -452,17 +478,19 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, }, { - dsn: "u:p@a?database=d&externalBrowserTimeout=20", + dsn: "u:p@a?database=d&externalBrowserTimeout=20&cloudStorageTimeout=7", config: &Config{ Account: "a", User: "u", Password: "p", Protocol: "https", Host: "a.snowflakecomputing.com", Port: 443, Database: "d", Schema: "", ExternalBrowserTimeout: 20 * time.Second, + CloudStorageTimeout: 7 * time.Second, OCSPFailOpen: OCSPFailOpenTrue, ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, @@ -483,6 +511,7 @@ func TestParseDSN(t *testing.T) { ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, + CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: ConfigBoolTrue, MaxRetryCount: 20, }, @@ -500,6 +529,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, @@ -514,6 +544,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, @@ -535,6 +566,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeInsecure, @@ -552,6 +584,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeInsecure, @@ -569,6 +602,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeInsecure, @@ -587,6 +621,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, @@ -604,6 +639,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, @@ -656,6 +692,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, @@ -672,6 +709,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, @@ -693,6 +731,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, @@ -715,6 +754,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, @@ -731,6 +771,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, @@ -746,6 +787,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, @@ -761,6 +803,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailClosed, @@ -776,6 +819,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeInsecure, @@ -790,6 +834,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, @@ -804,6 +849,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, @@ -818,6 +864,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, @@ -832,6 +879,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: 300 * time.Second, JWTClientTimeout: 45 * time.Second, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + CloudStorageTimeout: defaultCloudStorageTimeout, DisableQueryContextCache: false, IncludeRetryReason: ConfigBoolFalse, }, @@ -847,6 +895,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + CloudStorageTimeout: defaultCloudStorageTimeout, TmpDirPath: "/tmp", IncludeRetryReason: ConfigBoolTrue, }, @@ -862,6 +911,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + CloudStorageTimeout: defaultCloudStorageTimeout, DisableQueryContextCache: true, IncludeRetryReason: ConfigBoolTrue, }, @@ -877,6 +927,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, @@ -891,6 +942,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: ConfigBoolTrue, ClientConfigFile: "/Users/user/config.json", }, @@ -906,6 +958,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: ConfigBoolTrue, ClientConfigFile: "c:\\Users\\user\\config.json", }, @@ -927,6 +980,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: ConfigBoolTrue, DisableConsoleLogin: ConfigBoolTrue, }, @@ -944,6 +998,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: ConfigBoolTrue, DisableConsoleLogin: ConfigBoolFalse, }, @@ -961,6 +1016,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: ConfigBoolTrue, DisableSamlURLCheck: ConfigBoolTrue, }, @@ -978,6 +1034,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: ConfigBoolTrue, DisableSamlURLCheck: ConfigBoolFalse, }, @@ -998,6 +1055,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: ConfigBoolTrue, Authenticator: at, }, @@ -1018,6 +1076,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: ConfigBoolTrue, Authenticator: at, }, @@ -1038,6 +1097,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + CloudStorageTimeout: defaultCloudStorageTimeout, IncludeRetryReason: ConfigBoolTrue, Authenticator: at, }, @@ -1134,6 +1194,10 @@ func TestParseDSN(t *testing.T) { t.Fatalf("%d: Failed to match ExternalBrowserTimeout. expected: %v, got: %v", i, test.config.ExternalBrowserTimeout, cfg.ExternalBrowserTimeout) } + if test.config.CloudStorageTimeout != cfg.CloudStorageTimeout { + t.Fatalf("%d: Failed to match CloudStorageTimeout. expected: %v, got: %v", + i, test.config.CloudStorageTimeout, cfg.CloudStorageTimeout) + } if test.config.TmpDirPath != cfg.TmpDirPath { t.Fatalf("%v: Failed to match TmpDirPatch. expected: %v, got: %v", i, test.config.TmpDirPath, cfg.TmpDirPath) } @@ -1293,8 +1357,9 @@ func TestDSN(t *testing.T) { Account: "a", Region: "r", ExternalBrowserTimeout: 20 * time.Second, + CloudStorageTimeout: 7 * time.Second, }, - dsn: "u:p@a.r.snowflakecomputing.com:443?externalBrowserTimeout=20&ocspFailOpen=true®ion=r&validateDefaultParameters=true", + dsn: "u:p@a.r.snowflakecomputing.com:443?cloudStorageTimeout=7&externalBrowserTimeout=20&ocspFailOpen=true®ion=r&validateDefaultParameters=true", }, { cfg: &Config{ diff --git a/file_transfer_agent.go b/file_transfer_agent.go index aa40282c7..c30f9868c 100644 --- a/file_transfer_agent.go +++ b/file_transfer_agent.go @@ -618,8 +618,10 @@ func (sfa *snowflakeFileTransferAgent) transferAccelerateConfigWithUtil(s3Util s Message: errMsgFailedToConvertToS3Client, }).exceptionTelemetry(sfa.sc) } - ret, err := client.GetBucketAccelerateConfiguration(context.Background(), &s3.GetBucketAccelerateConfigurationInput{ - Bucket: &s3Loc.bucketName, + ret, err := withCloudStorageTimeout(sfa.sc.cfg, func(ctx context.Context) (*s3.GetBucketAccelerateConfigurationOutput, error) { + return client.GetBucketAccelerateConfiguration(ctx, &s3.GetBucketAccelerateConfigurationInput{ + Bucket: &s3Loc.bucketName, + }) }) sfa.useAccelerateEndpoint = ret != nil && ret.Status == "Enabled" if err != nil { @@ -628,6 +630,15 @@ func (sfa *snowflakeFileTransferAgent) transferAccelerateConfigWithUtil(s3Util s return nil } +func withCloudStorageTimeout[T any](cfg *Config, f func(ctx context.Context) (T, error)) (T, error) { + if cfg.CloudStorageTimeout > 0 { + ctx, cancelFunc := context.WithTimeout(context.Background(), cfg.CloudStorageTimeout) + defer cancelFunc() + return f(ctx) + } + return f(context.Background()) +} + func (sfa *snowflakeFileTransferAgent) transferAccelerateConfig() error { if sfa.stageLocationType == s3Client { s3Util := new(snowflakeS3Client) @@ -681,7 +692,7 @@ func (sfa *snowflakeFileTransferAgent) upload( largeFileMetadata []*fileMetadata, smallFileMetadata []*fileMetadata) error { client, err := sfa.getStorageClient(sfa.stageLocationType). - createClient(sfa.stageInfo, sfa.useAccelerateEndpoint) + createClient(sfa.stageInfo, sfa.useAccelerateEndpoint, sfa.sc.cfg) if err != nil { return err } @@ -710,7 +721,7 @@ func (sfa *snowflakeFileTransferAgent) upload( func (sfa *snowflakeFileTransferAgent) download( fileMetadata []*fileMetadata) error { client, err := sfa.getStorageClient(sfa.stageLocationType). - createClient(sfa.stageInfo, sfa.useAccelerateEndpoint) + createClient(sfa.stageInfo, sfa.useAccelerateEndpoint, sfa.sc.cfg) if err != nil { return err } @@ -987,7 +998,9 @@ func (sfa *snowflakeFileTransferAgent) getStorageClient(stageLocationType cloudT if stageLocationType == local { return &localUtil{} } else if stageLocationType == s3Client || stageLocationType == azureClient || stageLocationType == gcsClient { - return &remoteStorageUtil{} + return &remoteStorageUtil{ + cfg: sfa.sc.cfg, + } } return nil } @@ -1004,7 +1017,7 @@ func (sfa *snowflakeFileTransferAgent) renewExpiredClient() (cloudClient, error) return nil, err } storageClient := sfa.getStorageClient(sfa.stageLocationType) - return storageClient.createClient(&data.Data.StageInfo, sfa.useAccelerateEndpoint) + return storageClient.createClient(&data.Data.StageInfo, sfa.useAccelerateEndpoint, sfa.sc.cfg) } func (sfa *snowflakeFileTransferAgent) result() (*execResponse, error) { diff --git a/gcs_storage_client.go b/gcs_storage_client.go index 0627f6122..b45c51504 100644 --- a/gcs_storage_client.go +++ b/gcs_storage_client.go @@ -3,6 +3,7 @@ package gosnowflake import ( + "context" "encoding/json" "fmt" "io" @@ -22,6 +23,7 @@ const ( ) type snowflakeGcsClient struct { + cfg *Config } type gcsLocation struct { @@ -62,19 +64,21 @@ func (util *snowflakeGcsClient) getFileHeader(meta *fileMetadata, filename strin "Authorization": "Bearer " + accessToken, } - req, err := http.NewRequest("HEAD", URL.String(), nil) - if err != nil { - return nil, err - } - for k, v := range gcsHeaders { - req.Header.Add(k, v) - } - client := newGcsClient() - // for testing only - if meta.mockGcsClient != nil { - client = meta.mockGcsClient - } - resp, err := client.Do(req) + resp, err := withCloudStorageTimeout(util.cfg, func(ctx context.Context) (*http.Response, error) { + req, err := http.NewRequestWithContext(ctx, "HEAD", URL.String(), nil) + if err != nil { + return nil, err + } + for k, v := range gcsHeaders { + req.Header.Add(k, v) + } + client := newGcsClient() + // for testing only + if meta.mockGcsClient != nil { + client = meta.mockGcsClient + } + return client.Do(req) + }) if err != nil { return nil, err } @@ -208,19 +212,22 @@ func (util *snowflakeGcsClient) uploadFile( } } - req, err := http.NewRequest("PUT", uploadURL.String(), uploadSrc) - if err != nil { - return err - } - for k, v := range gcsHeaders { - req.Header.Add(k, v) - } - client := newGcsClient() - // for testing only - if meta.mockGcsClient != nil { - client = meta.mockGcsClient - } - resp, err := client.Do(req) + resp, err := withCloudStorageTimeout(util.cfg, func(ctx context.Context) (*http.Response, error) { + req, err := http.NewRequestWithContext(ctx, "PUT", uploadURL.String(), uploadSrc) + if err != nil { + return nil, err + } + for k, v := range gcsHeaders { + req.Header.Add(k, v) + } + client := newGcsClient() + // for testing only + if meta.mockGcsClient != nil { + client = meta.mockGcsClient + } + return client.Do(req) + }) + if err != nil { return err } @@ -286,19 +293,22 @@ func (util *snowflakeGcsClient) nativeDownloadFile( } } - req, err := http.NewRequest("GET", downloadURL.String(), nil) - if err != nil { - return err - } - for k, v := range gcsHeaders { - req.Header.Add(k, v) - } - client := newGcsClient() - // for testing only - if meta.mockGcsClient != nil { - client = meta.mockGcsClient - } - resp, err := client.Do(req) + resp, err := withCloudStorageTimeout(util.cfg, func(ctx context.Context) (*http.Response, error) { + req, err := http.NewRequestWithContext(ctx, "GET", downloadURL.String(), nil) + if err != nil { + return nil, err + } + for k, v := range gcsHeaders { + req.Header.Add(k, v) + } + client := newGcsClient() + // for testing only + if meta.mockGcsClient != nil { + client = meta.mockGcsClient + } + return client.Do(req) + }) + if err != nil { return err } diff --git a/gcs_storage_client_test.go b/gcs_storage_client_test.go index 04d3ea67f..3c360f7f7 100644 --- a/gcs_storage_client_test.go +++ b/gcs_storage_client_test.go @@ -163,6 +163,11 @@ func TestUploadFileWithGcsUploadFailedError(t *testing.T) { return nil, errors.New("unexpected error uploading file") }, }, + sfa: &snowflakeFileTransferAgent{ + sc: &snowflakeConn{ + cfg: &Config{}, + }, + }, } uploadMeta.realSrcFileName = uploadMeta.srcFileName @@ -222,6 +227,11 @@ func TestUploadFileWithGcsUploadFailedWithRetry(t *testing.T) { }, nil }, }, + sfa: &snowflakeFileTransferAgent{ + sc: &snowflakeConn{ + cfg: &Config{}, + }, + }, } uploadMeta.realSrcFileName = uploadMeta.srcFileName @@ -282,6 +292,11 @@ func TestUploadFileWithGcsUploadFailedWithTokenExpired(t *testing.T) { }, nil }, }, + sfa: &snowflakeFileTransferAgent{ + sc: &snowflakeConn{ + cfg: &Config{}, + }, + }, } uploadMeta.realSrcFileName = uploadMeta.srcFileName @@ -335,6 +350,11 @@ func TestDownloadOneFileFromGcsFailed(t *testing.T) { return nil, errors.New("unexpected error downloading file") }, }, + sfa: &snowflakeFileTransferAgent{ + sc: &snowflakeConn{ + cfg: &Config{}, + }, + }, resStatus: downloaded, // bypass file header request } err = new(remoteStorageUtil).downloadOneFile(&downloadMeta) @@ -379,6 +399,11 @@ func TestDownloadOneFileFromGcsFailedWithRetry(t *testing.T) { }, nil }, }, + sfa: &snowflakeFileTransferAgent{ + sc: &snowflakeConn{ + cfg: &Config{}, + }, + }, resStatus: downloaded, // bypass file header request } err = new(remoteStorageUtil).downloadOneFile(&downloadMeta) @@ -431,6 +456,11 @@ func TestDownloadOneFileFromGcsFailedWithTokenExpired(t *testing.T) { }, nil }, }, + sfa: &snowflakeFileTransferAgent{ + sc: &snowflakeConn{ + cfg: &Config{}, + }, + }, resStatus: downloaded, // bypass file header request } err = new(remoteStorageUtil).downloadOneFile(&downloadMeta) @@ -483,6 +513,11 @@ func TestDownloadOneFileFromGcsFailedWithFileNotFound(t *testing.T) { }, nil }, }, + sfa: &snowflakeFileTransferAgent{ + sc: &snowflakeConn{ + cfg: &Config{}, + }, + }, resStatus: downloaded, // bypass file header request } err = new(remoteStorageUtil).downloadOneFile(&downloadMeta) @@ -515,8 +550,13 @@ func TestGetHeaderTokenExpiredError(t *testing.T) { }, nil }, }, + sfa: &snowflakeFileTransferAgent{ + sc: &snowflakeConn{ + cfg: &Config{}, + }, + }, } - if header, err := new(snowflakeGcsClient).getFileHeader(&meta, "file.txt"); header != nil || err == nil { + if header, err := (&snowflakeGcsClient{cfg: &Config{}}).getFileHeader(&meta, "file.txt"); header != nil || err == nil { t.Fatalf("expected null header, got: %v", header) } if meta.resStatus != renewToken { @@ -544,8 +584,13 @@ func TestGetHeaderFileNotFound(t *testing.T) { }, nil }, }, + sfa: &snowflakeFileTransferAgent{ + sc: &snowflakeConn{ + cfg: &Config{}, + }, + }, } - if header, err := new(snowflakeGcsClient).getFileHeader(&meta, "file.txt"); header != nil || err == nil { + if header, err := (&snowflakeGcsClient{cfg: &Config{}}).getFileHeader(&meta, "file.txt"); header != nil || err == nil { t.Fatalf("expected null header, got: %v", header) } if meta.resStatus != notFoundFile { @@ -571,7 +616,7 @@ func TestGetHeaderPresignedUrlReturns404(t *testing.T) { stageInfo: &info, presignedURL: presignedURL, } - header, err := new(snowflakeGcsClient).getFileHeader(&meta, "file.txt") + header, err := (&snowflakeGcsClient{cfg: &Config{}}).getFileHeader(&meta, "file.txt") if header != nil { t.Fatalf("expected null header, got: %v", header) } @@ -600,8 +645,13 @@ func TestGetHeaderReturnsError(t *testing.T) { return nil, errors.New("unexpected exception getting file header") }, }, + sfa: &snowflakeFileTransferAgent{ + sc: &snowflakeConn{ + cfg: &Config{}, + }, + }, } - if header, err := new(snowflakeGcsClient).getFileHeader(&meta, "file.txt"); header != nil || err == nil { + if header, err := (&snowflakeGcsClient{cfg: &Config{}}).getFileHeader(&meta, "file.txt"); header != nil || err == nil { t.Fatalf("expected null header, got: %v", header) } } @@ -625,8 +675,13 @@ func TestGetHeaderBadRequest(t *testing.T) { }, nil }, }, + sfa: &snowflakeFileTransferAgent{ + sc: &snowflakeConn{ + cfg: &Config{}, + }, + }, } - if header, err := new(snowflakeGcsClient).getFileHeader(&meta, "file.txt"); header != nil || err == nil { + if header, err := (&snowflakeGcsClient{cfg: &Config{}}).getFileHeader(&meta, "file.txt"); header != nil || err == nil { t.Fatalf("expected null header, got: %v", header) } @@ -655,8 +710,13 @@ func TestGetHeaderRetryableError(t *testing.T) { }, nil }, }, + sfa: &snowflakeFileTransferAgent{ + sc: &snowflakeConn{ + cfg: &Config{}, + }, + }, } - if header, err := new(snowflakeGcsClient).getFileHeader(&meta, "file.txt"); header != nil || err == nil { + if header, err := (&snowflakeGcsClient{cfg: &Config{}}).getFileHeader(&meta, "file.txt"); header != nil || err == nil { t.Fatalf("expected null header, got: %v", header) } if meta.resStatus != needRetry { @@ -697,6 +757,11 @@ func TestUploadStreamFailed(t *testing.T) { return nil, errors.New("unexpected error uploading file") }, }, + sfa: &snowflakeFileTransferAgent{ + sc: &snowflakeConn{ + cfg: &Config{}, + }, + }, } uploadMeta.realSrcStream = uploadMeta.srcStream @@ -744,6 +809,11 @@ func TestUploadFileWithBadRequest(t *testing.T) { }, nil }, }, + sfa: &snowflakeFileTransferAgent{ + sc: &snowflakeConn{ + cfg: &Config{}, + }, + }, } uploadMeta.realSrcFileName = uploadMeta.srcFileName @@ -791,8 +861,13 @@ func TestGetFileHeaderEncryptionData(t *testing.T) { }, nil }, }, + sfa: &snowflakeFileTransferAgent{ + sc: &snowflakeConn{ + cfg: &Config{}, + }, + }, } - header, err := new(snowflakeGcsClient).getFileHeader(&meta, "file.txt") + header, err := (&snowflakeGcsClient{cfg: &Config{}}).getFileHeader(&meta, "file.txt") if err != nil { t.Fatal(err) } @@ -837,8 +912,13 @@ func TestGetFileHeaderEncryptionDataInterfaceConversionError(t *testing.T) { }, nil }, }, + sfa: &snowflakeFileTransferAgent{ + sc: &snowflakeConn{ + cfg: &Config{}, + }, + }, } - _, err := new(snowflakeGcsClient).getFileHeader(&meta, "file.txt") + _, err := (&snowflakeGcsClient{cfg: &Config{}}).getFileHeader(&meta, "file.txt") if err == nil { t.Error("should have raised an error") } @@ -888,6 +968,11 @@ func TestUploadFileToGcsNoStatus(t *testing.T) { }, nil }, }, + sfa: &snowflakeFileTransferAgent{ + sc: &snowflakeConn{ + cfg: &Config{}, + }, + }, } uploadMeta.realSrcFileName = uploadMeta.srcFileName @@ -939,6 +1024,11 @@ func TestDownloadFileFromGcsError(t *testing.T) { }, nil }, }, + sfa: &snowflakeFileTransferAgent{ + sc: &snowflakeConn{ + cfg: &Config{}, + }, + }, resStatus: downloaded, // bypass file header request } err = new(remoteStorageUtil).downloadOneFile(&downloadMeta) @@ -983,6 +1073,11 @@ func TestDownloadFileWithBadRequest(t *testing.T) { }, nil }, }, + sfa: &snowflakeFileTransferAgent{ + sc: &snowflakeConn{ + cfg: &Config{}, + }, + }, resStatus: downloaded, // bypass file header request } err = new(remoteStorageUtil).downloadOneFile(&downloadMeta) diff --git a/local_storage_client.go b/local_storage_client.go index 2ae072b63..95f9a8859 100644 --- a/local_storage_client.go +++ b/local_storage_client.go @@ -15,7 +15,7 @@ import ( type localUtil struct { } -func (util *localUtil) createClient(_ *execResponseStageInfo, _ bool) (cloudClient, error) { +func (util *localUtil) createClient(_ *execResponseStageInfo, _ bool, _ *Config) (cloudClient, error) { return nil, nil } diff --git a/local_storage_client_test.go b/local_storage_client_test.go index 4edfc178f..0a4ea809c 100644 --- a/local_storage_client_test.go +++ b/local_storage_client_test.go @@ -39,7 +39,7 @@ func TestLocalUpload(t *testing.T) { LocationType: "LOCAL_FS", } localUtil := new(localUtil) - localCli, err := localUtil.createClient(&info, false) + localCli, err := localUtil.createClient(&info, false, nil) if err != nil { t.Error(err) } @@ -134,7 +134,7 @@ func TestDownloadLocalFile(t *testing.T) { LocationType: "LOCAL_FS", } localUtil := new(localUtil) - localCli, err := localUtil.createClient(&info, false) + localCli, err := localUtil.createClient(&info, false, nil) if err != nil { t.Error(err) } diff --git a/s3_storage_client.go b/s3_storage_client.go index c389c3127..d02dd6840 100644 --- a/s3_storage_client.go +++ b/s3_storage_client.go @@ -7,17 +7,16 @@ import ( "context" "errors" "fmt" - "github.com/aws/smithy-go/logging" - "io" - "net/http" - "os" - "strings" - "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go-v2/feature/s3/manager" "github.com/aws/aws-sdk-go-v2/service/s3" "github.com/aws/smithy-go" + "github.com/aws/smithy-go/logging" + "io" + "net/http" + "os" + "strings" ) const ( @@ -32,6 +31,7 @@ const ( ) type snowflakeS3Client struct { + cfg *Config } type s3Location struct { @@ -98,7 +98,9 @@ func (util *snowflakeS3Client) getFileHeader(meta *fileMetadata, filename string if meta.mockHeader != nil { s3Cli = meta.mockHeader } - out, err := s3Cli.HeadObject(context.Background(), headObjInput) + out, err := withCloudStorageTimeout(util.cfg, func(ctx context.Context) (*s3.HeadObjectOutput, error) { + return s3Cli.HeadObject(ctx, headObjInput) + }) if err != nil { var ae smithy.APIError if errors.As(err, &ae) { @@ -191,30 +193,32 @@ func (util *snowflakeS3Client) uploadFile( uploader = meta.mockUploader } - if meta.srcStream != nil { - uploadStream := meta.srcStream - if meta.realSrcStream != nil { - uploadStream = meta.realSrcStream + _, err = withCloudStorageTimeout(util.cfg, func(ctx context.Context) (any, error) { + if meta.srcStream != nil { + uploadStream := meta.srcStream + if meta.realSrcStream != nil { + uploadStream = meta.realSrcStream + } + return uploader.Upload(context.Background(), &s3.PutObjectInput{ + Bucket: &s3loc.bucketName, + Key: &s3path, + Body: bytes.NewBuffer(uploadStream.Bytes()), + Metadata: s3Meta, + }) } - _, err = uploader.Upload(context.Background(), &s3.PutObjectInput{ - Bucket: &s3loc.bucketName, - Key: &s3path, - Body: bytes.NewBuffer(uploadStream.Bytes()), - Metadata: s3Meta, - }) - } else { var file *os.File file, err = os.Open(dataFile) if err != nil { - return err + return nil, err } - _, err = uploader.Upload(context.Background(), &s3.PutObjectInput{ + return uploader.Upload(context.Background(), &s3.PutObjectInput{ Bucket: &s3loc.bucketName, Key: &s3path, Body: file, Metadata: s3Meta, }) - } + + }) if err != nil { var ae smithy.APIError @@ -268,19 +272,22 @@ func (util *snowflakeS3Client) nativeDownloadFile( downloader = meta.mockDownloader } - if meta.options.GetFileToStream { - buf := manager.NewWriteAtBuffer([]byte{}) - _, err = downloader.Download(context.Background(), buf, &s3.GetObjectInput{ - Bucket: s3Obj.Bucket, - Key: s3Obj.Key, - }) - meta.dstStream = bytes.NewBuffer(buf.Bytes()) - } else { - _, err = downloader.Download(context.Background(), f, &s3.GetObjectInput{ - Bucket: s3Obj.Bucket, - Key: s3Obj.Key, - }) - } + _, err = withCloudStorageTimeout(util.cfg, func(ctx context.Context) (any, error) { + if meta.options.GetFileToStream { + buf := manager.NewWriteAtBuffer([]byte{}) + _, err = downloader.Download(ctx, buf, &s3.GetObjectInput{ + Bucket: s3Obj.Bucket, + Key: s3Obj.Key, + }) + meta.dstStream = bytes.NewBuffer(buf.Bytes()) + } else { + _, err = downloader.Download(ctx, f, &s3.GetObjectInput{ + Bucket: s3Obj.Bucket, + Key: s3Obj.Key, + }) + } + return nil, err + }) if err != nil { var ae smithy.APIError diff --git a/s3_storage_client_test.go b/s3_storage_client_test.go index e85920cd1..a9c3eb5c9 100644 --- a/s3_storage_client_test.go +++ b/s3_storage_client_test.go @@ -93,7 +93,11 @@ func TestUploadOneFileToS3WSAEConnAborted(t *testing.T) { Message: "mock err, connection aborted", } }), - } + sfa: &snowflakeFileTransferAgent{ + sc: &snowflakeConn{ + cfg: &Config{}, + }, + }} uploadMeta.realSrcFileName = uploadMeta.srcFileName fi, err := os.Stat(uploadMeta.srcFileName) @@ -165,6 +169,11 @@ func TestUploadOneFileToS3ConnReset(t *testing.T) { Message: "mock err, connection aborted", } }), + sfa: &snowflakeFileTransferAgent{ + sc: &snowflakeConn{ + cfg: &Config{}, + }, + }, } uploadMeta.realSrcFileName = uploadMeta.srcFileName @@ -220,6 +229,11 @@ func TestUploadFileWithS3UploadFailedError(t *testing.T) { "operation: The provided token has expired.", } }), + sfa: &snowflakeFileTransferAgent{ + sc: &snowflakeConn{ + cfg: &Config{}, + }, + }, } uploadMeta.realSrcFileName = uploadMeta.srcFileName @@ -257,8 +271,13 @@ func TestGetHeadExpiryError(t *testing.T) { Code: expiredToken, } }), + sfa: &snowflakeFileTransferAgent{ + sc: &snowflakeConn{ + cfg: &Config{}, + }, + }, } - if header, err := new(snowflakeS3Client).getFileHeader(&meta, "file.txt"); header != nil || err == nil { + if header, err := (&snowflakeS3Client{cfg: &Config{}}).getFileHeader(&meta, "file.txt"); header != nil || err == nil { t.Fatalf("expected null header, got: %v", header) } if meta.resStatus != renewToken { @@ -276,8 +295,13 @@ func TestGetHeaderUnexpectedError(t *testing.T) { Code: "-1", } }), + sfa: &snowflakeFileTransferAgent{ + sc: &snowflakeConn{ + cfg: &Config{}, + }, + }, } - if header, err := new(snowflakeS3Client).getFileHeader(&meta, "file.txt"); header != nil || err == nil { + if header, err := (&snowflakeS3Client{cfg: &Config{}}).getFileHeader(&meta, "file.txt"); header != nil || err == nil { t.Fatalf("expected null header, got: %v", header) } if meta.resStatus != errStatus { @@ -292,9 +316,14 @@ func TestGetHeaderNonApiError(t *testing.T) { mockHeader: mockHeaderAPI(func(ctx context.Context, params *s3.HeadObjectInput, optFns ...func(*s3.Options)) (*s3.HeadObjectOutput, error) { return nil, errors.New("something went wrong here") }), + sfa: &snowflakeFileTransferAgent{ + sc: &snowflakeConn{ + cfg: &Config{}, + }, + }, } - header, err := new(snowflakeS3Client).getFileHeader(&meta, "file.txt") + header, err := (&snowflakeS3Client{cfg: &Config{}}).getFileHeader(&meta, "file.txt") assertNilE(t, header, fmt.Sprintf("expected header to be nil, actual: %v", header)) assertNotNilE(t, err, "expected err to not be nil") assertEqualE(t, meta.resStatus, errStatus, fmt.Sprintf("expected %v result status for non-APIerror, got: %v", errStatus, meta.resStatus)) @@ -309,9 +338,14 @@ func TestGetHeaderNotFoundError(t *testing.T) { Code: notFound, } }), + sfa: &snowflakeFileTransferAgent{ + sc: &snowflakeConn{ + cfg: &Config{}, + }, + }, } - _, err := new(snowflakeS3Client).getFileHeader(&meta, "file.txt") + _, err := (&snowflakeS3Client{cfg: &Config{}}).getFileHeader(&meta, "file.txt") if err != nil && err.Error() != "could not find file" { t.Error(err) } @@ -369,6 +403,11 @@ func TestDownloadFileWithS3TokenExpired(t *testing.T) { mockHeader: mockHeaderAPI(func(ctx context.Context, params *s3.HeadObjectInput, optFns ...func(*s3.Options)) (*s3.HeadObjectOutput, error) { return &s3.HeadObjectOutput{}, nil }), + sfa: &snowflakeFileTransferAgent{ + sc: &snowflakeConn{ + cfg: &Config{}, + }, + }, } err = new(remoteStorageUtil).downloadOneFile(&downloadMeta) if err == nil { @@ -417,6 +456,11 @@ func TestDownloadFileWithS3ConnReset(t *testing.T) { mockHeader: mockHeaderAPI(func(ctx context.Context, params *s3.HeadObjectInput, optFns ...func(*s3.Options)) (*s3.HeadObjectOutput, error) { return &s3.HeadObjectOutput{}, nil }), + sfa: &snowflakeFileTransferAgent{ + sc: &snowflakeConn{ + cfg: &Config{}, + }, + }, } err = new(remoteStorageUtil).downloadOneFile(&downloadMeta) if err == nil { @@ -465,6 +509,11 @@ func TestDownloadOneFileToS3WSAEConnAborted(t *testing.T) { mockHeader: mockHeaderAPI(func(ctx context.Context, params *s3.HeadObjectInput, optFns ...func(*s3.Options)) (*s3.HeadObjectOutput, error) { return &s3.HeadObjectOutput{}, nil }), + sfa: &snowflakeFileTransferAgent{ + sc: &snowflakeConn{ + cfg: &Config{}, + }, + }, } err = new(remoteStorageUtil).downloadOneFile(&downloadMeta) if err == nil { @@ -511,6 +560,11 @@ func TestDownloadOneFileToS3Failed(t *testing.T) { mockHeader: mockHeaderAPI(func(ctx context.Context, params *s3.HeadObjectInput, optFns ...func(*s3.Options)) (*s3.HeadObjectOutput, error) { return &s3.HeadObjectOutput{}, nil }), + sfa: &snowflakeFileTransferAgent{ + sc: &snowflakeConn{ + cfg: &Config{}, + }, + }, } err = new(remoteStorageUtil).downloadOneFile(&downloadMeta) if err == nil { @@ -550,6 +604,11 @@ func TestUploadFileToS3ClientCastFail(t *testing.T) { options: &SnowflakeFileTransferOptions{ MultiPartThreshold: dataSizeThreshold, }, + sfa: &snowflakeFileTransferAgent{ + sc: &snowflakeConn{ + cfg: &Config{}, + }, + }, } uploadMeta.realSrcFileName = uploadMeta.srcFileName @@ -583,6 +642,11 @@ func TestGetHeaderClientCastFail(t *testing.T) { Code: notFound, } }), + sfa: &snowflakeFileTransferAgent{ + sc: &snowflakeConn{ + cfg: &Config{}, + }, + }, } _, err = new(snowflakeS3Client).getFileHeader(&meta, "file.txt") @@ -630,6 +694,11 @@ func TestS3UploadRetryWithHeaderNotFound(t *testing.T) { Code: notFound, } }), + sfa: &snowflakeFileTransferAgent{ + sc: &snowflakeConn{ + cfg: &Config{}, + }, + }, } uploadMeta.realSrcFileName = uploadMeta.srcFileName @@ -639,7 +708,7 @@ func TestS3UploadRetryWithHeaderNotFound(t *testing.T) { } uploadMeta.uploadSize = fi.Size() - err = new(remoteStorageUtil).uploadOneFileWithRetry(&uploadMeta) + err = (&remoteStorageUtil{cfg: &Config{}}).uploadOneFileWithRetry(&uploadMeta) if err != nil { t.Error(err) } @@ -679,6 +748,11 @@ func TestS3UploadStreamFailed(t *testing.T) { mockUploader: mockUploadObjectAPI(func(ctx context.Context, params *s3.PutObjectInput, optFns ...func(*manager.Uploader)) (*manager.UploadOutput, error) { return nil, errors.New("unexpected error uploading file") }), + sfa: &snowflakeFileTransferAgent{ + sc: &snowflakeConn{ + cfg: &Config{}, + }, + }, } uploadMeta.realSrcStream = uploadMeta.srcStream diff --git a/storage_client.go b/storage_client.go index e855dba13..316c5ad38 100644 --- a/storage_client.go +++ b/storage_client.go @@ -19,7 +19,7 @@ const ( // implemented by localUtil and remoteStorageUtil type storageUtil interface { - createClient(*execResponseStageInfo, bool) (cloudClient, error) + createClient(*execResponseStageInfo, bool, *Config) (cloudClient, error) uploadOneFileWithRetry(*fileMetadata) error downloadOneFile(*fileMetadata) error } @@ -35,22 +35,29 @@ type cloudUtil interface { type cloudClient interface{} type remoteStorageUtil struct { + cfg *Config } -func (rsu *remoteStorageUtil) getNativeCloudType(cli string) cloudUtil { +func (rsu *remoteStorageUtil) getNativeCloudType(cli string, cfg *Config) cloudUtil { if cloudType(cli) == s3Client { - return &snowflakeS3Client{} + return &snowflakeS3Client{ + cfg, + } } else if cloudType(cli) == azureClient { - return &snowflakeAzureClient{} + return &snowflakeAzureClient{ + cfg, + } } else if cloudType(cli) == gcsClient { - return &snowflakeGcsClient{} + return &snowflakeGcsClient{ + cfg, + } } return nil } // call cloud utils' native create client methods -func (rsu *remoteStorageUtil) createClient(info *execResponseStageInfo, useAccelerateEndpoint bool) (cloudClient, error) { - utilClass := rsu.getNativeCloudType(info.LocationType) +func (rsu *remoteStorageUtil) createClient(info *execResponseStageInfo, useAccelerateEndpoint bool, cfg *Config) (cloudClient, error) { + utilClass := rsu.getNativeCloudType(info.LocationType, cfg) return utilClass.createClient(info, useAccelerateEndpoint) } @@ -81,7 +88,7 @@ func (rsu *remoteStorageUtil) uploadOneFile(meta *fileMetadata) error { dataFile = meta.realSrcFileName } - utilClass := rsu.getNativeCloudType(meta.stageInfo.LocationType) + utilClass := rsu.getNativeCloudType(meta.stageInfo.LocationType, meta.sfa.sc.cfg) maxConcurrency := int(meta.parallel) var lastErr error maxRetry := defaultMaxRetry @@ -134,7 +141,7 @@ func (rsu *remoteStorageUtil) uploadOneFile(meta *fileMetadata) error { } func (rsu *remoteStorageUtil) uploadOneFileWithRetry(meta *fileMetadata) error { - utilClass := rsu.getNativeCloudType(meta.stageInfo.LocationType) + utilClass := rsu.getNativeCloudType(meta.stageInfo.LocationType, rsu.cfg) retryOuter := true for i := 0; i < 10; i++ { // retry @@ -196,7 +203,7 @@ func (rsu *remoteStorageUtil) downloadOneFile(meta *fileMetadata) error { } } - utilClass := rsu.getNativeCloudType(meta.stageInfo.LocationType) + utilClass := rsu.getNativeCloudType(meta.stageInfo.LocationType, meta.sfa.sc.cfg) header, err := utilClass.getFileHeader(meta, meta.srcFileName) if err != nil { return err