Skip to content

Commit

Permalink
fix(chore): POST request rewrite body
Browse files Browse the repository at this point in the history
  • Loading branch information
darkweak committed Dec 30, 2023
1 parent b36e5f3 commit fe0da0f
Show file tree
Hide file tree
Showing 11 changed files with 90 additions and 38 deletions.
4 changes: 4 additions & 0 deletions context/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ type cacheContext struct {
cacheName string
}

func (*cacheContext) SetContextWithBaseRequest(req *http.Request, _ *http.Request) *http.Request {
return req
}

func (cc *cacheContext) SetupContext(c configurationtypes.AbstractConfigurationInterface) {
cc.cacheName = defaultCacheName
if c.GetDefaultCache().GetCacheName() != "" {
Expand Down
27 changes: 27 additions & 0 deletions context/graphql.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,33 @@ type graphQLContext struct {
custom bool
}

func (g *graphQLContext) SetContextWithBaseRequest(req *http.Request, baseRq *http.Request) *http.Request {
ctx := req.Context()
ctx = context.WithValue(ctx, GraphQL, g.custom)
ctx = context.WithValue(ctx, HashBody, "")
ctx = context.WithValue(ctx, IsMutationRequest, false)

if g.custom && req.Body != nil {
// `{"text": "Holla, world,"source": "eng_Latn","target": "spa_Latn"}`
b := bytes.NewBuffer([]byte{})
_, _ = io.Copy(b, req.Body)
req.Body = io.NopCloser(b)
baseRq.Body = io.NopCloser(b)

if b.Len() > 0 {
if isMutation(b.Bytes()) {
ctx = context.WithValue(ctx, IsMutationRequest, true)
} else {
h := sha256.New()
h.Write(b.Bytes())
ctx = context.WithValue(ctx, HashBody, fmt.Sprintf("-%x", h.Sum(nil)))
}
}
}

return req.WithContext(ctx)
}

func (g *graphQLContext) SetupContext(c configurationtypes.AbstractConfigurationInterface) {
if len(c.GetDefaultCache().GetAllowedHTTPVerbs()) != 0 {
g.custom = true
Expand Down
4 changes: 4 additions & 0 deletions context/key.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ type keyContext struct {
overrides []map[*regexp.Regexp]keyContext
}

func (*keyContext) SetContextWithBaseRequest(req *http.Request, _ *http.Request) *http.Request {
return req
}

func (g *keyContext) SetupContext(c configurationtypes.AbstractConfigurationInterface) {
k := c.GetDefaultCache().GetKey()
g.disable_body = k.DisableBody
Expand Down
4 changes: 4 additions & 0 deletions context/method.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ type methodContext struct {
custom bool
}

func (*methodContext) SetContextWithBaseRequest(req *http.Request, _ *http.Request) *http.Request {
return req
}

func (m *methodContext) SetupContext(c configurationtypes.AbstractConfigurationInterface) {
m.allowedVerbs = defaultVerbs
if len(c.GetDefaultCache().GetAllowedHTTPVerbs()) != 0 {
Expand Down
6 changes: 5 additions & 1 deletion context/mode.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ type ModeContext struct {
Strict, Bypass_request, Bypass_response bool
}

func (*ModeContext) SetContextWithBaseRequest(req *http.Request, _ *http.Request) *http.Request {
return req
}

func (mc *ModeContext) SetupContext(c configurationtypes.AbstractConfigurationInterface) {
mode := c.GetDefaultCache().GetMode()
mc.Bypass_request = mode == "bypass" || mode == "bypass_request"
Expand All @@ -25,4 +29,4 @@ func (mc *ModeContext) SetContext(req *http.Request) *http.Request {
return req.WithContext(context.WithValue(req.Context(), Mode, mc))
}

var _ ctx = (*cacheContext)(nil)
var _ ctx = (*ModeContext)(nil)
4 changes: 4 additions & 0 deletions context/now.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ const Now ctxKey = "souin_ctx.NOW"

type nowContext struct{}

func (*nowContext) SetContextWithBaseRequest(req *http.Request, _ *http.Request) *http.Request {
return req
}

func (cc *nowContext) SetupContext(_ configurationtypes.AbstractConfigurationInterface) {}

func (cc *nowContext) SetContext(req *http.Request) *http.Request {
Expand Down
6 changes: 5 additions & 1 deletion context/timeout.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ type timeoutContext struct {
timeoutCache, timeoutBackend time.Duration
}

func (*timeoutContext) SetContextWithBaseRequest(req *http.Request, _ *http.Request) *http.Request {
return req
}

func (t *timeoutContext) SetupContext(c configurationtypes.AbstractConfigurationInterface) {
t.timeoutBackend = defaultTimeoutBackend
t.timeoutCache = defaultTimeoutCache
Expand All @@ -40,4 +44,4 @@ func (t *timeoutContext) SetContext(req *http.Request) *http.Request {
return req.WithContext(context.WithValue(context.WithValue(ctx, TimeoutCancel, cancel), TimeoutCache, t.timeoutCache))
}

var _ ctx = (*cacheContext)(nil)
var _ ctx = (*timeoutContext)(nil)
5 changes: 3 additions & 2 deletions context/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ type (
ctx interface {
SetupContext(c configurationtypes.AbstractConfigurationInterface)
SetContext(req *http.Request) *http.Request
SetContextWithBaseRequest(req *http.Request, baseRq *http.Request) *http.Request
}

Context struct {
Expand Down Expand Up @@ -53,6 +54,6 @@ func (c *Context) SetBaseContext(req *http.Request) *http.Request {
return c.Mode.SetContext(c.Timeout.SetContext(c.Method.SetContext(c.CacheName.SetContext(c.Now.SetContext(req)))))
}

func (c *Context) SetContext(req *http.Request) *http.Request {
return c.Key.SetContext(c.GraphQL.SetContext(req))
func (c *Context) SetContext(req *http.Request, baseRq *http.Request) *http.Request {
return c.Key.SetContext(c.GraphQL.SetContextWithBaseRequest(req, baseRq))
}
2 changes: 1 addition & 1 deletion context/types_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func Test_Context_SetContext(t *testing.T) {

co.Init(&c)
req := httptest.NewRequest(http.MethodGet, "http://domain.com", nil)
req = co.SetContext(req)
req = co.SetContext(req, req)
if req.Context().Value(Key) != "GET-http-domain.com-" {
t.Errorf("The Key context must be equal to GET-http-domain.com-, %s given.", req.Context().Value(Key))
}
Expand Down
65 changes: 33 additions & 32 deletions pkg/middleware/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,7 @@ func (s *SouinBaseHandler) Upstream(
sfValue, err, _ := s.singleflightPool.Do(cachedKey, func() (interface{}, error) {
shared = false
if e := next(customWriter, rq); e != nil {
s.Configuration.GetLogger().Sugar().Warnf("%#v", e)
customWriter.Header().Set("Cache-Status", fmt.Sprintf("%s; fwd=uri-miss; key=%s; detail=SERVE-HTTP-ERROR", rq.Context().Value(context.CacheName), rfc.GetCacheKeyFromCtx(rq.Context())))
return nil, e
}
Expand Down Expand Up @@ -427,61 +428,61 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
return nil
}

rq = s.context.SetBaseContext(rq)
cacheName := rq.Context().Value(context.CacheName).(string)
req := s.context.SetBaseContext(rq)
cacheName := req.Context().Value(context.CacheName).(string)
if rq.Header.Get("Upgrade") == "websocket" || (s.ExcludeRegex != nil && s.ExcludeRegex.MatchString(rq.RequestURI)) {
rw.Header().Set("Cache-Status", cacheName+"; fwd=bypass; detail=EXCLUDED-REQUEST-URI")
return next(rw, rq)
return next(rw, req)
}

if !rq.Context().Value(context.SupportedMethod).(bool) {
if !req.Context().Value(context.SupportedMethod).(bool) {
rw.Header().Set("Cache-Status", cacheName+"; fwd=bypass; detail=UNSUPPORTED-METHOD")

err := next(rw, rq)
s.SurrogateKeyStorer.Invalidate(rq.Method, rw.Header())
err := next(rw, req)
s.SurrogateKeyStorer.Invalidate(req.Method, rw.Header())

return err
}

requestCc, coErr := cacheobject.ParseRequestCacheControl(rq.Header.Get("Cache-Control"))
requestCc, coErr := cacheobject.ParseRequestCacheControl(req.Header.Get("Cache-Control"))

modeContext := rq.Context().Value(context.Mode).(*context.ModeContext)
modeContext := req.Context().Value(context.Mode).(*context.ModeContext)
if !modeContext.Bypass_request && (coErr != nil || requestCc == nil) {
rw.Header().Set("Cache-Status", cacheName+"; fwd=bypass; detail=CACHE-CONTROL-EXTRACTION-ERROR")

err := next(rw, rq)
s.SurrogateKeyStorer.Invalidate(rq.Method, rw.Header())
err := next(rw, req)
s.SurrogateKeyStorer.Invalidate(req.Method, rw.Header())

return err
}

rq = s.context.SetContext(rq)
if rq.Context().Value(context.IsMutationRequest).(bool) {
req = s.context.SetContext(req, rq)
if req.Context().Value(context.IsMutationRequest).(bool) {
rw.Header().Set("Cache-Status", cacheName+"; fwd=bypass; detail=IS-MUTATION-REQUEST")

err := next(rw, rq)
s.SurrogateKeyStorer.Invalidate(rq.Method, rw.Header())
err := next(rw, req)
s.SurrogateKeyStorer.Invalidate(req.Method, rw.Header())

return err
}
cachedKey := rq.Context().Value(context.Key).(string)
cachedKey := req.Context().Value(context.Key).(string)

bufPool := s.bufPool.Get().(*bytes.Buffer)
bufPool.Reset()
defer s.bufPool.Put(bufPool)
customWriter := NewCustomWriter(rq, rw, bufPool)
customWriter := NewCustomWriter(req, rw, bufPool)
go func(req *http.Request, crw *CustomWriter) {
<-req.Context().Done()
crw.mutex.Lock()
crw.headersSent = true
crw.mutex.Unlock()
}(rq, customWriter)
}(req, customWriter)
s.Configuration.GetLogger().Sugar().Debugf("Request cache-control %+v", requestCc)
if modeContext.Bypass_request || !requestCc.NoCache {
validator := rfc.ParseRequest(rq)
validator := rfc.ParseRequest(req)
var response *http.Response
for _, currentStorer := range s.Storers {
response = currentStorer.Prefix(cachedKey, rq, validator)
response = currentStorer.Prefix(cachedKey, req, validator)
if response != nil {
s.Configuration.GetLogger().Sugar().Debugf("Found response in the %s storage", currentStorer.Name())
break
Expand All @@ -508,14 +509,14 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
}

if validator.NeedRevalidation {
err := s.Revalidate(validator, next, customWriter, rq, requestCc, cachedKey)
err := s.Revalidate(validator, next, customWriter, req, requestCc, cachedKey)
_, _ = customWriter.Send()

return err
}
if resCc, _ := cacheobject.ParseResponseCacheControl(response.Header.Get("Cache-Control")); resCc.NoCachePresent {
prometheus.Increment(prometheus.NoCachedResponseCounter)
err := s.Revalidate(validator, next, customWriter, rq, requestCc, cachedKey)
err := s.Revalidate(validator, next, customWriter, req, requestCc, cachedKey)
_, _ = customWriter.Send()

return err
Expand All @@ -524,7 +525,7 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
if !modeContext.Strict || rfc.ValidateMaxAgeCachedResponse(requestCc, response) != nil {
customWriter.Headers = response.Header
customWriter.statusCode = response.StatusCode
s.Configuration.GetLogger().Sugar().Debugf("Serve from cache %+v", rq)
s.Configuration.GetLogger().Sugar().Debugf("Serve from cache %+v", req)
_, _ = io.Copy(customWriter.Buf, response.Body)
_, err := customWriter.Send()
prometheus.Increment(prometheus.CachedResponseCounter)
Expand All @@ -533,7 +534,7 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
}
} else if response == nil && !requestCc.OnlyIfCached && (requestCc.MaxStaleSet || requestCc.MaxStale > -1) {
for _, currentStorer := range s.Storers {
response = currentStorer.Prefix(storage.StalePrefix+cachedKey, rq, validator)
response = currentStorer.Prefix(storage.StalePrefix+cachedKey, req, validator)
if response != nil {
break
}
Expand All @@ -549,10 +550,10 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
rfc.HitStaleCache(&response.Header)
_, _ = io.Copy(customWriter.Buf, response.Body)
_, err := customWriter.Send()
customWriter = NewCustomWriter(rq, rw, bufPool)
customWriter = NewCustomWriter(req, rw, bufPool)
go func(v *rfc.Revalidator, goCw *CustomWriter, goRq *http.Request, goNext func(http.ResponseWriter, *http.Request) error, goCc *cacheobject.RequestCacheDirectives, goCk string) {
_ = s.Revalidate(v, goNext, goCw, goRq, goCc, goCk)
}(validator, customWriter, rq, next, requestCc, cachedKey)
}(validator, customWriter, req, next, requestCc, cachedKey)
buf := s.bufPool.Get().(*bytes.Buffer)
buf.Reset()
defer s.bufPool.Put(buf)
Expand All @@ -561,8 +562,8 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
}

if responseCc.MustRevalidate || responseCc.NoCachePresent || validator.NeedRevalidation {
rq.Header["If-None-Match"] = append(rq.Header["If-None-Match"], validator.ResponseETag)
err := s.Revalidate(validator, next, customWriter, rq, requestCc, cachedKey)
req.Header["If-None-Match"] = append(req.Header["If-None-Match"], validator.ResponseETag)
err := s.Revalidate(validator, next, customWriter, req, requestCc, cachedKey)
if err != nil {
if responseCc.StaleIfError > -1 || requestCc.StaleIfError > 0 {
code := fmt.Sprintf("; fwd-status=%d", customWriter.statusCode)
Expand Down Expand Up @@ -623,13 +624,13 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
}

errorCacheCh := make(chan error)
go func() {
errorCacheCh <- s.Upstream(customWriter, rq, next, requestCc, cachedKey)
}()
go func(vr *http.Request) {
errorCacheCh <- s.Upstream(customWriter, vr, next, requestCc, cachedKey)
}(req)

select {
case <-rq.Context().Done():
switch rq.Context().Err() {
case <-req.Context().Done():
switch req.Context().Err() {
case baseCtx.DeadlineExceeded:
customWriter.WriteHeader(http.StatusGatewayTimeout)
rw.Header().Set("Cache-Status", cacheName+"; fwd=bypass; detail=DEADLINE-EXCEEDED")
Expand Down
1 change: 0 additions & 1 deletion pkg/middleware/writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@ func (r *CustomWriter) Send() (int, error) {
r.Header().Del(rfc.StoredTTLHeader)

if !r.headersSent {

// r.Rw.Header().Set("Content-Length", fmt.Sprintf("%d", len(b)))
r.Rw.WriteHeader(r.statusCode)
r.headersSent = true
Expand Down

0 comments on commit fe0da0f

Please sign in to comment.