diff --git a/gzhttp/compress.go b/gzhttp/compress.go index 28fe33195..7b697dd87 100644 --- a/gzhttp/compress.go +++ b/gzhttp/compress.go @@ -324,6 +324,20 @@ func (w *GzipResponseWriter) init() { w.gw = w.gwFactory.New(w.ResponseWriter, w.level) } +// bodyAllowedForStatus reports whether a given response status code +// permits a body. See RFC 7230, section 3.3. +func bodyAllowedForStatus(status int) bool { + switch { + case status >= 100 && status <= 199: + return false + case status == 204: + return false + case status == 304: + return false + } + return true +} + // Close will close the gzip.Writer and will put it back in the gzipWriterPool. func (w *GzipResponseWriter) Close() error { if w.ignore { @@ -340,7 +354,7 @@ func (w *GzipResponseWriter) Close() error { // Handles the intended case of setting a nil Content-Type (as for http/server or http/fs) // Set the header only if the key does not exist - if _, ok := w.Header()[contentType]; w.setContentType && !ok { + if _, ok := w.Header()[contentType]; bodyAllowedForStatus(w.code) && w.setContentType && !ok { w.Header().Set(contentType, ct) } } diff --git a/gzhttp/compress_test.go b/gzhttp/compress_test.go index 03c220d67..94ff0f332 100644 --- a/gzhttp/compress_test.go +++ b/gzhttp/compress_test.go @@ -662,7 +662,7 @@ func TestFlushAfterWrite3(t *testing.T) { } handler := gz(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.WriteHeader(http.StatusOK) - //rw.Write(nil) + // rw.Write(nil) rw.(http.Flusher).Flush() })) r := httptest.NewRequest(http.MethodGet, "/", nil) @@ -1598,6 +1598,25 @@ var sniffTests = []struct { {"Incorrect RAR v5+", []byte("Rar \x1A\x07\x01\x00"), "application/octet-stream"}, } +func TestNoContentTypeWhenNoContent(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNoContent) + }) + + wrapper, err := NewWrapper() + assertNil(t, err) + + req, _ := http.NewRequest("GET", "/", nil) + req.Header.Set("Accept-Encoding", "gzip") + resp := httptest.NewRecorder() + wrapper(handler).ServeHTTP(resp, req) + res := resp.Result() + + assertEqual(t, http.StatusNoContent, res.StatusCode) + assertEqual(t, "", res.Header.Get("Content-Type")) + +} + func TestContentTypeDetect(t *testing.T) { for _, tt := range sniffTests { t.Run(tt.desc, func(t *testing.T) {