Skip to content

Commit

Permalink
SNOW-878073 Add MaxRetryCount as configurable parameter (#948)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-dheyman authored Oct 27, 2023
1 parent 7eca43f commit 8479309
Show file tree
Hide file tree
Showing 10 changed files with 121 additions and 96 deletions.
2 changes: 1 addition & 1 deletion auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ func postAuth(

fullURL := sr.getFullURL(loginRequestPath, params)
logger.Infof("full URL: %v", fullURL)
resp, err := sr.FuncAuthPost(ctx, client, fullURL, headers, bodyCreator, timeout)
resp, err := sr.FuncAuthPost(ctx, client, fullURL, headers, bodyCreator, timeout, sr.MaxRetryCount)
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions chunk_downloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ func getChunk(
if err != nil {
return nil, err
}
return newRetryHTTP(ctx, sc.rest.Client, http.NewRequest, u, headers, timeout, sc.currentTimeProvider, sc.cfg).execute()
return newRetryHTTP(ctx, sc.rest.Client, http.NewRequest, u, headers, timeout, sc.rest.MaxRetryCount, sc.currentTimeProvider, sc.cfg).execute()
}

func (scd *snowflakeChunkDownloader) startArrowBatches() error {
Expand Down Expand Up @@ -638,7 +638,7 @@ func (f *httpStreamChunkFetcher) fetch(URL string, rows chan<- []*string) error
if err != nil {
return err
}
res, err := newRetryHTTP(context.Background(), f.client, http.NewRequest, fullURL, f.headers, 0, defaultTimeProvider, nil).execute()
res, err := newRetryHTTP(context.Background(), f.client, http.NewRequest, fullURL, f.headers, 0, 0, defaultTimeProvider, nil).execute()
if err != nil {
return err
}
Expand Down
1 change: 1 addition & 0 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -789,6 +789,7 @@ func buildSnowflakeConn(ctx context.Context, config Config) (*snowflakeConn, err
TokenAccessor: tokenAccessor,
LoginTimeout: sc.cfg.LoginTimeout,
RequestTimeout: sc.cfg.RequestTimeout,
MaxRetryCount: sc.cfg.MaxRetryCount,
FuncPost: postRestful,
FuncGet: getRestful,
FuncAuthPost: postAuthRestful,
Expand Down
17 changes: 15 additions & 2 deletions dsn.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,13 @@ import (
)

const (
defaultClientTimeout = 300 * time.Second // Timeout for network round trip + read out http response
defaultClientTimeout = 900 * time.Second // Timeout for network round trip + read out http response
defaultJWTClientTimeout = 10 * time.Second // Timeout for network round trip + read out http response but used for JWT auth
defaultLoginTimeout = 60 * time.Second // Timeout for retry for login EXCLUDING clientTimeout
defaultLoginTimeout = 300 * time.Second // Timeout for retry for login EXCLUDING clientTimeout
defaultRequestTimeout = 0 * time.Second // Timeout for retry for request EXCLUDING clientTimeout
defaultJWTTimeout = 60 * time.Second
defaultExternalBrowserTimeout = 120 * time.Second // Timeout for external browser login
defaultMaxRetryCount = 7 // specifies maximum number of subsequent retries
defaultDomain = ".snowflakecomputing.com"
)

Expand Down Expand Up @@ -74,6 +75,7 @@ type Config struct {
ClientTimeout time.Duration // Timeout for network round trip + read out http response
JWTClientTimeout time.Duration // Timeout for network round trip + read out http response used when JWT token auth is taking place
ExternalBrowserTimeout time.Duration // Timeout for external browser login
MaxRetryCount int // Specifies how many times non-periodic HTTP request can be retried

Application string // application name.
InsecureMode bool // driver doesn't check certificate revocation status
Expand Down Expand Up @@ -205,6 +207,9 @@ func DSN(cfg *Config) (dsn string, err error) {
if cfg.ExternalBrowserTimeout != defaultExternalBrowserTimeout {
params.Add("externalBrowserTimeout", strconv.FormatInt(int64(cfg.ExternalBrowserTimeout/time.Second), 10))
}
if cfg.MaxRetryCount != defaultMaxRetryCount {
params.Add("maxRetryCount", strconv.Itoa(cfg.MaxRetryCount))
}
if cfg.Application != clientType {
params.Add("application", cfg.Application)
}
Expand Down Expand Up @@ -471,6 +476,9 @@ func fillMissingConfigParameters(cfg *Config) error {
if cfg.ExternalBrowserTimeout == 0 {
cfg.ExternalBrowserTimeout = defaultExternalBrowserTimeout
}
if cfg.MaxRetryCount == 0 {
cfg.MaxRetryCount = defaultMaxRetryCount
}
if strings.Trim(cfg.Application, " ") == "" {
cfg.Application = clientType
}
Expand Down Expand Up @@ -642,6 +650,11 @@ func parseDSNParams(cfg *Config, params string) (err error) {
if err != nil {
return err
}
case "maxRetryCount":
cfg.MaxRetryCount, err = strconv.Atoi(value)
if err != nil {
return err
}
case "application":
cfg.Application = value
case "authenticator":
Expand Down
27 changes: 27 additions & 0 deletions dsn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,23 @@ func TestParseDSN(t *testing.T) {
ClientTimeout: defaultClientTimeout,
JWTClientTimeout: defaultJWTClientTimeout,
IncludeRetryReason: ConfigBoolTrue,
MaxRetryCount: defaultMaxRetryCount,
},
ocspMode: ocspModeFailOpen,
},
{
dsn: "u:p@a?database=d&maxRetryCount=20",
config: &Config{
Account: "a", User: "u", Password: "p",
Protocol: "https", Host: "a.snowflakecomputing.com", Port: 443,
Database: "d", Schema: "",
ExternalBrowserTimeout: defaultExternalBrowserTimeout,
OCSPFailOpen: OCSPFailOpenTrue,
ValidateDefaultParameters: ConfigBoolTrue,
ClientTimeout: defaultClientTimeout,
JWTClientTimeout: defaultJWTClientTimeout,
IncludeRetryReason: ConfigBoolTrue,
MaxRetryCount: 20,
},
ocspMode: ocspModeFailOpen,
},
Expand Down Expand Up @@ -1239,6 +1256,16 @@ func TestDSN(t *testing.T) {
},
dsn: "u:p@a.b.c.snowflakecomputing.com:443?ocspFailOpen=true&region=b.c&tmpDirPath=%2Ftmp&validateDefaultParameters=true",
},
{
cfg: &Config{
User: "u",
Password: "p",
Account: "a.b.c",
IncludeRetryReason: ConfigBoolFalse,
MaxRetryCount: 30,
},
dsn: "u:p@a.b.c.snowflakecomputing.com:443?includeRetryReason=false&maxRetryCount=30&ocspFailOpen=true&region=b.c&validateDefaultParameters=true",
},
{
cfg: &Config{
User: "u",
Expand Down
6 changes: 3 additions & 3 deletions ocsp.go
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ func checkOCSPCacheServer(
ocspS *ocspStatus) {
var respd map[string][]interface{}
headers := make(map[string]string)
res, err := newRetryHTTP(ctx, client, req, ocspServerHost, headers, totalTimeout, defaultTimeProvider, nil).execute()
res, err := newRetryHTTP(ctx, client, req, ocspServerHost, headers, totalTimeout, defaultMaxRetryCount, defaultTimeProvider, nil).execute()
if err != nil {
logger.Errorf("failed to get OCSP cache from OCSP Cache Server. %v", err)
return nil, &ocspStatus{
Expand Down Expand Up @@ -413,7 +413,7 @@ func retryOCSP(
}
res, err := newRetryHTTP(
ctx, client, req, ocspHost, headers,
totalTimeout*time.Duration(multiplier), defaultTimeProvider, nil).doPost().setBody(reqBody).execute()
totalTimeout*time.Duration(multiplier), defaultMaxRetryCount, defaultTimeProvider, nil).doPost().setBody(reqBody).execute()
if err != nil {
return ocspRes, ocspResBytes, &ocspStatus{
code: ocspFailedSubmit,
Expand Down Expand Up @@ -466,7 +466,7 @@ func fallbackRetryOCSPToGETRequest(
multiplier = 3 // up to 3 times for Fail Close mode
}
res, err := newRetryHTTP(ctx, client, req, ocspHost, headers,
totalTimeout*time.Duration(multiplier), defaultTimeProvider, nil).execute()
totalTimeout*time.Duration(multiplier), defaultMaxRetryCount, defaultTimeProvider, nil).execute()
if err != nil {
return ocspRes, ocspResBytes, &ocspStatus{
code: ocspFailedSubmit,
Expand Down
12 changes: 7 additions & 5 deletions restful.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ const (
type (
funcGetType func(context.Context, *snowflakeRestful, *url.URL, map[string]string, time.Duration) (*http.Response, error)
funcPostType func(context.Context, *snowflakeRestful, *url.URL, map[string]string, []byte, time.Duration, currentTimeProvider, *Config) (*http.Response, error)
funcAuthPostType func(context.Context, *http.Client, *url.URL, map[string]string, bodyCreatorType, time.Duration) (*http.Response, error)
funcAuthPostType func(context.Context, *http.Client, *url.URL, map[string]string, bodyCreatorType, time.Duration, int) (*http.Response, error)
bodyCreatorType func() ([]byte, error)
)

Expand All @@ -58,6 +58,7 @@ type snowflakeRestful struct {
Protocol string
LoginTimeout time.Duration // Login timeout
RequestTimeout time.Duration // request timeout
MaxRetryCount int

Client *http.Client
JWTClient *http.Client
Expand Down Expand Up @@ -165,7 +166,7 @@ func postRestful(
currentTimeProvider currentTimeProvider,
cfg *Config) (
*http.Response, error) {
return newRetryHTTP(ctx, sr.Client, http.NewRequest, fullURL, headers, timeout, currentTimeProvider, cfg).
return newRetryHTTP(ctx, sr.Client, http.NewRequest, fullURL, headers, timeout, sr.MaxRetryCount, currentTimeProvider, cfg).
doPost().
setBody(body).
execute()
Expand All @@ -178,7 +179,7 @@ func getRestful(
headers map[string]string,
timeout time.Duration) (
*http.Response, error) {
return newRetryHTTP(ctx, sr.Client, http.NewRequest, fullURL, headers, timeout, defaultTimeProvider, nil).execute()
return newRetryHTTP(ctx, sr.Client, http.NewRequest, fullURL, headers, timeout, sr.MaxRetryCount, defaultTimeProvider, nil).execute()
}

func postAuthRestful(
Expand All @@ -187,9 +188,10 @@ func postAuthRestful(
fullURL *url.URL,
headers map[string]string,
bodyCreator bodyCreatorType,
timeout time.Duration) (
timeout time.Duration,
maxRetryCount int) (
*http.Response, error) {
return newRetryHTTP(ctx, client, http.NewRequest, fullURL, headers, timeout, defaultTimeProvider, nil).
return newRetryHTTP(ctx, client, http.NewRequest, fullURL, headers, timeout, maxRetryCount, defaultTimeProvider, nil).
doPost().
setBodyCreator(bodyCreator).
execute()
Expand Down
10 changes: 5 additions & 5 deletions restful_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ func postTestError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[str
}, errors.New("failed to run post method")
}

func postAuthTestError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*http.Response, error) {
func postAuthTestError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration, _ int) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusOK,
Body: &fakeResponseBody{body: []byte{0x12, 0x34}},
Expand All @@ -43,7 +43,7 @@ func postTestAppBadGatewayError(_ context.Context, _ *snowflakeRestful, _ *url.U
}, nil
}

func postAuthTestAppBadGatewayError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*http.Response, error) {
func postAuthTestAppBadGatewayError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration, _ int) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusBadGateway,
Body: &fakeResponseBody{body: []byte{0x12, 0x34}},
Expand All @@ -57,14 +57,14 @@ func postTestAppForbiddenError(_ context.Context, _ *snowflakeRestful, _ *url.UR
}, nil
}

func postAuthTestAppForbiddenError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*http.Response, error) {
func postAuthTestAppForbiddenError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration, _ int) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusForbidden,
Body: &fakeResponseBody{body: []byte{0x12, 0x34}},
}, nil
}

func postAuthTestAppUnexpectedError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*http.Response, error) {
func postAuthTestAppUnexpectedError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration, _ int) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusInsufficientStorage,
Body: &fakeResponseBody{body: []byte{0x12, 0x34}},
Expand Down Expand Up @@ -110,7 +110,7 @@ func postTestRenew(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[str
}, nil
}

func postAuthTestAfterRenew(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*http.Response, error) {
func postAuthTestAfterRenew(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration, _ int) (*http.Response, error) {
dd := &execResponseData{}
er := &execResponse{
Data: *dd,
Expand Down
10 changes: 4 additions & 6 deletions retry.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,6 @@ import (
"time"
)

const (
// defaultMaxRetryCount specifies maximum number of subsequent retries
defaultMaxRetryCount = 7
)

type waitAlgo struct {
mutex *sync.Mutex // required for *rand.Rand usage
random *rand.Rand
Expand Down Expand Up @@ -248,6 +243,7 @@ type retryHTTP struct {
headers map[string]string
bodyCreator bodyCreatorType
timeout time.Duration
maxRetryCount int
currentTimeProvider currentTimeProvider
cfg *Config
}
Expand All @@ -258,6 +254,7 @@ func newRetryHTTP(ctx context.Context,
fullURL *url.URL,
headers map[string]string,
timeout time.Duration,
maxRetryCount int,
currentTimeProvider currentTimeProvider,
cfg *Config) *retryHTTP {
instance := retryHTTP{}
Expand All @@ -268,6 +265,7 @@ func newRetryHTTP(ctx context.Context,
instance.fullURL = fullURL
instance.headers = headers
instance.timeout = timeout
instance.maxRetryCount = maxRetryCount
instance.bodyCreator = emptyBodyCreator
instance.currentTimeProvider = currentTimeProvider
instance.cfg = cfg
Expand Down Expand Up @@ -341,7 +339,7 @@ func (r *retryHTTP) execute() (res *http.Response, err error) {
logger.WithContext(r.ctx).Infof("to timeout: %v", totalTimeout)
// if any timeout is set
totalTimeout -= time.Duration(sleepTime * float64(time.Second))
if totalTimeout <= 0 || retryCounter >= defaultMaxRetryCount {
if totalTimeout <= 0 || retryCounter > r.maxRetryCount {
if err != nil {
return nil, err
}
Expand Down
Loading

0 comments on commit 8479309

Please sign in to comment.