diff --git a/pkg/lib/api/build_policies.go b/pkg/lib/api/build_policies.go index 5097a30..271a626 100644 --- a/pkg/lib/api/build_policies.go +++ b/pkg/lib/api/build_policies.go @@ -13,7 +13,7 @@ import ( "github.com/programmfabrik/apitest/pkg/lib/util" ) -func buildMultipart(request Request) (additionalHeaders map[string]string, body io.Reader, bodyCloser io.Closer, err error) { +func buildMultipart(request Request) (additionalHeaders map[string]string, body io.Reader, err error) { additionalHeaders = make(map[string]string, 0) var buf = bytes.NewBuffer([]byte{}) @@ -24,7 +24,7 @@ func buildMultipart(request Request) (additionalHeaders map[string]string, body if ok { f, ok := val.(util.JsonString) if !ok { - return additionalHeaders, body, nil, fmt.Errorf("file:filename should be a string") + return nil, nil, fmt.Errorf("file:filename should be a string") } replaceFilename = &f } @@ -70,7 +70,7 @@ func buildMultipart(request Request) (additionalHeaders map[string]string, body for key, val := range request.Body.(map[string]any) { err = createPart(key, val) if err != nil { - return additionalHeaders, body, bodyCloser, err + return nil, nil, err } } @@ -80,7 +80,7 @@ func buildMultipart(request Request) (additionalHeaders map[string]string, body return } -func buildUrlencoded(request Request) (additionalHeaders map[string]string, body io.Reader, bodyCloser io.Closer, err error) { +func buildUrlencoded(request Request) (additionalHeaders map[string]string, body io.Reader, err error) { additionalHeaders = make(map[string]string, 0) additionalHeaders["Content-Type"] = "application/x-www-form-urlencoded" formParams := url.Values{} @@ -95,11 +95,11 @@ func buildUrlencoded(request Request) (additionalHeaders map[string]string, body } } body = strings.NewReader(formParams.Encode()) - return additionalHeaders, body, nil, nil + return additionalHeaders, body, nil } -func buildRegular(request Request) (additionalHeaders map[string]string, body io.Reader, bodyCloser io.Closer, err error) { +func buildRegular(request Request) (additionalHeaders map[string]string, body io.Reader, err error) { additionalHeaders = make(map[string]string, 0) additionalHeaders["Content-Type"] = "application/json" @@ -108,18 +108,20 @@ func buildRegular(request Request) (additionalHeaders map[string]string, body io } else { bodyBytes, err := json.Marshal(request.Body) if err != nil { - return additionalHeaders, body, nil, fmt.Errorf("error marshaling request body: %s", err) + return nil, nil, fmt.Errorf("error marshaling request body: %s", err) } body = bytes.NewBuffer(bodyBytes) } - return additionalHeaders, body, nil, nil + return additionalHeaders, body, nil } -func buildFile(req Request) (map[string]string, io.Reader, io.Closer, error) { +// buildFile opens a file for use with buildPolicy. +// WARNING: This returns a file handle that must be closed! +func buildFile(req Request) (map[string]string, io.Reader, error) { headers := map[string]string{} if req.BodyFile == "" { - return nil, nil, nil, errors.New(`Request.buildFile: Missing "body_file"`) + return nil, nil, errors.New(`Request.buildFile: Missing "body_file"`) } path := req.BodyFile @@ -130,7 +132,7 @@ func buildFile(req Request) (map[string]string, io.Reader, io.Closer, error) { file, err := util.OpenFileOrUrl(path, req.ManifestDir) if err != nil { - return nil, nil, nil, err + return nil, nil, err } - return headers, file, file, err + return headers, file, err } diff --git a/pkg/lib/api/build_policies_test.go b/pkg/lib/api/build_policies_test.go index 9a30f35..be94ef7 100644 --- a/pkg/lib/api/build_policies_test.go +++ b/pkg/lib/api/build_policies_test.go @@ -26,7 +26,6 @@ func TestBuildMultipart(t *testing.T) { ManifestDir: "test/", BodyType: "multipart", } - defer testRequest.Close() httpRequest, err := testRequest.buildHttpRequest() go_test_utils.ExpectNoError(t, err, "error building multipart request") @@ -50,7 +49,7 @@ func TestBuildMultipart_ErrPathSpec(t *testing.T) { ManifestDir: "test/path/", } - _, _, _, err := buildMultipart(testRequest) + _, _, err := buildMultipart(testRequest) if err == nil { t.Fatal("expected error") } @@ -67,7 +66,7 @@ func TestBuildMultipart_ErrPathSpecNoString(t *testing.T) { ManifestDir: "test/path/", } - _, _, _, err := buildMultipart(testRequest) + _, _, err := buildMultipart(testRequest) if err == nil { t.Fatal("expected error") } @@ -84,7 +83,7 @@ func TestBuildMultipart_FileDoesNotExist(t *testing.T) { ManifestDir: "test/path/", } - _, _, _, err := buildMultipart(testRequest) + _, _, err := buildMultipart(testRequest) if err == nil { t.Fatal("expected error") } diff --git a/pkg/lib/api/request.go b/pkg/lib/api/request.go index 0c2f72f..5d3e1d8 100755 --- a/pkg/lib/api/request.go +++ b/pkg/lib/api/request.go @@ -57,8 +57,7 @@ type Request struct { BodyFile string `yaml:"body_file" json:"body_file"` Body any `yaml:"body" json:"body"` - buildPolicy func(Request) (additionalHeaders map[string]string, body io.Reader, bodyCloser io.Closer, err error) - bodyCloser io.Closer + buildPolicy func(Request) (additionalHeaders map[string]string, body io.Reader, err error) ManifestDir string DataStore *datastore.Datastore } @@ -89,16 +88,14 @@ func (request Request) buildHttpRequest() (req *http.Request, err error) { return nil, errors.Wrapf(err, "Unable to buildHttpRequest with URL %q", requestUrl) } - if request.bodyCloser != nil { - // in case bodyCloser is already set, close the old body first - request.bodyCloser.Close() - } - - additionalHeaders, body, bodyCloser, err := request.buildPolicy(request) + // Note that buildPolicy may return a file handle that needs to be + // closed. According to standard library documentation, the NewRequest + // call below will take into account if body also happens to implement + // io.Closer. + additionalHeaders, body, err := request.buildPolicy(request) if err != nil { return req, fmt.Errorf("error executing buildpolicy: %s", err) } - request.bodyCloser = bodyCloser req, err = http.NewRequest(request.Method, requestUrl, body) if err != nil { @@ -273,14 +270,6 @@ func (request Request) buildHttpRequest() (req *http.Request, err error) { return req, nil } -func (request Request) Close() error { - if request.bodyCloser != nil { - return request.bodyCloser.Close() - } - - return nil -} - func (request Request) ToString(curl bool) (res string) { httpRequest, err := request.buildHttpRequest() if err != nil { diff --git a/pkg/lib/api/request_test.go b/pkg/lib/api/request_test.go index 54c276b..d761ac0 100755 --- a/pkg/lib/api/request_test.go +++ b/pkg/lib/api/request_test.go @@ -21,12 +21,11 @@ func TestRequestBuildHttp(t *testing.T) { }, ServerURL: "serverUrl", } - defer request.Close() - request.buildPolicy = func(request Request) (ah map[string]string, b io.Reader, c io.Closer, err error) { + request.buildPolicy = func(request Request) (ah map[string]string, b io.Reader, err error) { ah = make(map[string]string) ah["mock-header"] = "application/mock" b = strings.NewReader("mock_body") - return ah, b, c, nil + return ah, b, nil } httpRequest, err := request.buildHttpRequest() go_test_utils.ExpectNoError(t, err, fmt.Sprintf("error building http-request: %s", err)) @@ -88,7 +87,6 @@ func TestRequestBuildHttpWithCookie(t *testing.T) { Method: "GET", Cookies: reqCookies, } - defer request.Close() request.buildPolicy = buildRegular ds := datastore.NewStore(false) for key, val := range storeCookies {