Skip to content

Commit

Permalink
fix(plugins): use base writer as ServeHTTP parameter (#341) (#347)
Browse files Browse the repository at this point in the history
* fix(plugins): use base writer as ServeHTTP parameter (#341)

* fix(middleware): write the response body from upstream response

* fix(writer): handle request cancel before writing headers

* fix(lint): golangci-lint
  • Loading branch information
darkweak authored May 27, 2023
1 parent ea83658 commit 03e1575
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 6 deletions.
43 changes: 37 additions & 6 deletions pkg/middleware/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package middleware

import (
"bytes"
baseCtx "context"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -279,7 +280,12 @@ func (s *SouinBaseHandler) Upstream(
customWriter.Header().Set("Cache-Control", s.DefaultMatchedUrl.DefaultCacheControl)
}

return s.Store(customWriter, rq, requestCc, cachedKey)
select {
case <-rq.Context().Done():
return baseCtx.Canceled
default:
return s.Store(customWriter, rq, requestCc, cachedKey)
}
}

func (s *SouinBaseHandler) Revalidate(validator *rfc.Revalidator, next handlerFunc, customWriter *CustomWriter, rq *http.Request, requestCc *cacheobject.RequestCacheDirectives, cachedKey string) error {
Expand Down Expand Up @@ -371,6 +377,12 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
bufPool.Reset()
defer s.bufPool.Put(bufPool)
customWriter := NewCustomWriter(rq, rw, bufPool)
go func(req *http.Request, crw *CustomWriter) {
<-req.Context().Done()
crw.mutex.Lock()
crw.headersSent = true
crw.mutex.Unlock()
}(rq, customWriter)
s.Configuration.GetLogger().Sugar().Debugf("Request cache-control %+v", requestCc)
if !requestCc.NoCache {
validator := rfc.ParseRequest(rq)
Expand All @@ -388,6 +400,7 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
if rfc.ValidateMaxAgeCachedResponse(requestCc, response) != nil {
customWriter.Headers = response.Header
customWriter.statusCode = response.StatusCode
s.Configuration.GetLogger().Sugar().Debugf("Serve from cache %+v", rq)
_, _ = io.Copy(customWriter.Buf, response.Body)
_, err := customWriter.Send()

Expand Down Expand Up @@ -466,10 +479,28 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
}
}

if err := s.Upstream(customWriter, rq, next, requestCc, cachedKey); err != nil {
return err
errorCacheCh := make(chan error)
go func() {
errorCacheCh <- s.Upstream(customWriter, rq, next, requestCc, cachedKey)
}()

select {
case <-rq.Context().Done():
switch rq.Context().Err() {
case baseCtx.DeadlineExceeded:
customWriter.WriteHeader(http.StatusGatewayTimeout)
rw.Header().Set("Cache-Status", cacheName+"; fwd=bypass; detail=DEADLINE-EXCEEDED")
_, _ = customWriter.Rw.Write([]byte("Internal server error"))
return baseCtx.DeadlineExceeded
case baseCtx.Canceled:
return baseCtx.Canceled
default:
return nil
}
case v := <-errorCacheCh:
if v == nil {
_, _ = customWriter.Send()
}
return v
}

_, _ = customWriter.Send()
return nil
}
9 changes: 9 additions & 0 deletions pkg/middleware/writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"net/http"
"strings"
"sync"

"github.com/darkweak/go-esi/esi"
"github.com/darkweak/souin/pkg/rfc"
Expand All @@ -24,6 +25,7 @@ func NewCustomWriter(rq *http.Request, rw http.ResponseWriter, b *bytes.Buffer)
Req: rq,
Rw: rw,
Headers: http.Header{},
mutex: &sync.Mutex{},
}
}

Expand All @@ -34,12 +36,15 @@ type CustomWriter struct {
Req *http.Request
Headers http.Header
headersSent bool
mutex *sync.Mutex
statusCode int
// size int
}

// Header will write the response headers
func (r *CustomWriter) Header() http.Header {
r.mutex.Lock()
defer r.mutex.Unlock()
if r.headersSent {
return http.Header{}
}
Expand All @@ -48,6 +53,8 @@ func (r *CustomWriter) Header() http.Header {

// WriteHeader will write the response headers
func (r *CustomWriter) WriteHeader(code int) {
r.mutex.Lock()
defer r.mutex.Unlock()
if r.headersSent {
return
}
Expand Down Expand Up @@ -76,10 +83,12 @@ func (r *CustomWriter) Send() (int, error) {
}
}

r.mutex.Lock()
if !r.headersSent {
r.Rw.Header().Set("Content-Length", fmt.Sprintf("%d", len(b)))
r.Rw.WriteHeader(r.statusCode)
r.headersSent = true
}
r.mutex.Unlock()
return r.Rw.Write(b)
}

0 comments on commit 03e1575

Please sign in to comment.