diff --git a/client.go b/client.go index efee53c..1423605 100644 --- a/client.go +++ b/client.go @@ -687,6 +687,14 @@ func (c *Client) Do(req *Request) (*http.Response, error) { } } + // First attempt was already signed + if attempt > 1 && c.PrepareRetry != nil { + if err := c.PrepareRetry(req.Request); err != nil { + prepareErr = err + break + } + } + if c.RequestLogHook != nil { switch v := logger.(type) { case LeveledLogger: @@ -778,12 +786,6 @@ func (c *Client) Do(req *Request) (*http.Response, error) { httpreq := *req.Request req.Request = &httpreq - if c.PrepareRetry != nil { - if err := c.PrepareRetry(req.Request); err != nil { - prepareErr = err - break - } - } } // this is the closest we have to success criteria diff --git a/client_test.go b/client_test.go index cc05e91..1c320fe 100644 --- a/client_test.go +++ b/client_test.go @@ -6,6 +6,8 @@ package retryablehttp import ( "bytes" "context" + "crypto/sha256" + "encoding/base64" "errors" "fmt" "io" @@ -372,6 +374,21 @@ func TestClient_Do_WithPrepareRetry(t *testing.T) { client.PrepareRetry = func(req *http.Request) error { prepareChecks++ req.Header.Set("foo", strconv.Itoa(prepareChecks)) + + // if the method is POST or PUT, set a header based on request body content + if req.Method == "POST" || req.Method == "PUT" { + bodyBytes, err := io.ReadAll(req.Body) + if err != nil { + t.Fatalf("could not read request body: %s", err) + } + preparedBody := string(bodyBytes) + + if len(preparedBody) > 0 { + sum := sha256.Sum256([]byte(preparedBody)) + contentHash := base64.StdEncoding.EncodeToString(sum[:]) + req.Header.Set("content_hash", contentHash) + } + } return nil } @@ -384,6 +401,8 @@ func TestClient_Do_WithPrepareRetry(t *testing.T) { var shouldSucceed bool tests := []struct { name string + method string + requestBody string handler ResponseHandlerFunc expectedChecks int // often 2x number of attempts since we check twice expectedPrepareChecks int @@ -391,12 +410,14 @@ func TestClient_Do_WithPrepareRetry(t *testing.T) { }{ { name: "nil handler", + method: http.MethodGet, handler: nil, expectedChecks: 1, expectedPrepareChecks: 0, }, { - name: "handler always succeeds", + name: "handler always succeeds", + method: http.MethodGet, handler: func(*http.Response) error { return nil }, @@ -404,7 +425,8 @@ func TestClient_Do_WithPrepareRetry(t *testing.T) { expectedPrepareChecks: 0, }, { - name: "handler always fails in a retryable way", + name: "handler always fails in a retryable way", + method: http.MethodGet, handler: func(*http.Response) error { return errors.New("retryable failure") }, @@ -412,7 +434,8 @@ func TestClient_Do_WithPrepareRetry(t *testing.T) { expectedPrepareChecks: 2, }, { - name: "handler always fails in a nonretryable way", + name: "handler always fails in a nonretryable way", + method: http.MethodGet, handler: func(*http.Response) error { return errors.New("nonretryable failure") }, @@ -420,7 +443,8 @@ func TestClient_Do_WithPrepareRetry(t *testing.T) { expectedPrepareChecks: 0, }, { - name: "handler succeeds on second attempt", + name: "handler succeeds on second attempt", + method: http.MethodGet, handler: func(*http.Response) error { if shouldSucceed { return nil @@ -431,6 +455,34 @@ func TestClient_Do_WithPrepareRetry(t *testing.T) { expectedChecks: 4, expectedPrepareChecks: 1, }, + { + name: "POST - handler succeeds on second attempt, using body for PrepareRetry", + method: http.MethodPost, + requestBody: "dummy data", + handler: func(response *http.Response) error { + if shouldSucceed { + return nil + } + shouldSucceed = true + return errors.New("retryable failure") + }, + expectedChecks: 4, + expectedPrepareChecks: 1, + }, + { + name: "PUT - handler succeeds on second attempt, using body for PrepareRetry", + method: http.MethodPut, + requestBody: "dummy data", + handler: func(response *http.Response) error { + if shouldSucceed { + return nil + } + shouldSucceed = true + return errors.New("retryable failure") + }, + expectedChecks: 4, + expectedPrepareChecks: 1, + }, } for _, tt := range tests { @@ -438,8 +490,16 @@ func TestClient_Do_WithPrepareRetry(t *testing.T) { checks = 0 prepareChecks = 0 shouldSucceed = false + var req *Request + var err error + // Create the request - req, err := NewRequest("GET", ts.URL, nil) + if tt.requestBody != "" { + req, err = NewRequest(tt.method, ts.URL, strings.NewReader(tt.requestBody)) + } else { + req, err = NewRequest(tt.method, ts.URL, nil) + } + if err != nil { t.Fatalf("err: %v", err) } @@ -470,6 +530,12 @@ func TestClient_Do_WithPrepareRetry(t *testing.T) { t.Fatalf("expected changes in request header 'foo' '%s', but got '%s'", expectedHeader, header) } + if tt.method == "POST" || tt.method == "PUT" { + headerFromContent := req.Request.Header.Get("content_hash") + if headerFromContent == "" { + t.Fatalf("expected 'content_hash' header to exist, but it does not") + } + } }) } }