From 485a0fcc49ba16bfc582ecca85abf59dc79c2893 Mon Sep 17 00:00:00 2001 From: Kevin Burns Date: Thu, 3 Oct 2019 20:43:57 -0700 Subject: [PATCH] Prevent cancellation from propagating to background request context (#17) * Failing test for issue #16 * Cancellation does not propagate to background request context * Updating test description --- background_request.go | 22 +++++++++++++++ microcache.go | 44 +++++++++++++++++++---------- microcache_test.go | 64 +++++++++++++++++++++++++++++++++---------- 3 files changed, 101 insertions(+), 29 deletions(-) create mode 100644 background_request.go diff --git a/background_request.go b/background_request.go new file mode 100644 index 0000000..d46ee49 --- /dev/null +++ b/background_request.go @@ -0,0 +1,22 @@ +package microcache + +import ( + "context" + "net/http" +) + +// newBackgroundRequest clones a request for use in background object revalidation. +// This prevents a closed foreground request context from prematurely cancelling +// the background request context. +func newBackgroundRequest(r *http.Request) *http.Request { + return r.Clone(bgContext{r.Context(), make(chan struct{})}) +} + +type bgContext struct { + context.Context + done chan struct{} +} + +func (c bgContext) Done() <-chan struct{} { + return c.done +} diff --git a/microcache.go b/microcache.go index 6968da9..ab32934 100644 --- a/microcache.go +++ b/microcache.go @@ -40,7 +40,8 @@ type microcache struct { collapseMutex *sync.Mutex // Used to advance time for testing - offset time.Duration + offset time.Duration + offsetMutex *sync.RWMutex } type Config struct { @@ -155,6 +156,7 @@ func New(o Config) *microcache { revalidateMutex: &sync.Mutex{}, collapse: map[string]*sync.Mutex{}, collapseMutex: &sync.Mutex{}, + offsetMutex: &sync.RWMutex{}, } if o.Driver == nil { m.Driver = NewDriverLRU(1e4) // default 10k cache items @@ -180,6 +182,9 @@ func New(o Config) *microcache { // chain.Append(mx.Middleware) // func (m *microcache) Middleware(h http.Handler) http.Handler { + if m.Timeout > 0 { + h = http.TimeoutHandler(h, m.Timeout, "Timed out") + } return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Websocket passthrough upgrade := strings.ToLower(r.Header.Get("connection")) == "upgrade" @@ -259,7 +264,7 @@ func (m *microcache) Middleware(h http.Handler) http.Handler { } // Fresh response object found - if obj.found && obj.expires.After(time.Now().Add(m.offset)) { + if obj.found && obj.expires.After(m.now()) { if m.Monitor != nil { m.Monitor.Hit() } @@ -273,7 +278,7 @@ func (m *microcache) Middleware(h http.Handler) http.Handler { // Stale While Revalidate if obj.found && req.staleWhileRevalidate > 0 && - obj.expires.Add(req.staleWhileRevalidate).After(time.Now().Add(m.offset)) { + obj.expires.Add(req.staleWhileRevalidate).After(m.now()) { if m.Monitor != nil { m.Monitor.Stale() } @@ -291,6 +296,7 @@ func (m *microcache) Middleware(h http.Handler) http.Handler { } m.revalidateMutex.Unlock() if !revalidating { + br := newBackgroundRequest(r) go func() { defer func() { // Clear revalidation lock @@ -298,8 +304,7 @@ func (m *microcache) Middleware(h http.Handler) http.Handler { delete(m.revalidating, objHash) m.revalidateMutex.Unlock() }() - - m.handleBackendResponse(h, w, r, reqHash, req, objHash, obj, true) + m.handleBackendResponse(h, w, br, reqHash, req, objHash, obj, true) }() } @@ -329,12 +334,7 @@ func (m *microcache) handleBackendResponse( beres := Response{header: http.Header{}} // Execute request - if m.Timeout > 0 { - th := http.TimeoutHandler(h, m.Timeout, "Timed out") - th.ServeHTTP(&beres, r) - } else { - h.ServeHTTP(&beres, r) - } + h.ServeHTTP(&beres, r) if !beres.headerWritten { beres.status = http.StatusOK @@ -347,10 +347,10 @@ func (m *microcache) handleBackendResponse( // Serve Stale if beres.status >= 500 && obj.found { - serveStale := obj.expires.Add(req.staleIfError).After(time.Now().Add(m.offset)) + serveStale := obj.expires.Add(req.staleIfError).After(m.now()) // Extend stale response expiration by staleIfError grace period if req.found && serveStale && req.staleRecache { - obj.expires = obj.date.Add(m.offset).Add(req.ttl) + obj.expires = obj.date.Add(m.getOffset()).Add(req.ttl) m.store(objHash, obj) } if !background && serveStale { @@ -376,7 +376,7 @@ func (m *microcache) handleBackendResponse( } // Cache response if !req.nocache { - beres.expires = time.Now().Add(m.offset).Add(req.ttl) + beres.expires = m.now().Add(req.ttl) m.store(objHash, beres) } } @@ -418,7 +418,7 @@ func (m *microcache) Start() { // setAgeHeader sets the age header if not suppressed func (m *microcache) setAgeHeader(w http.ResponseWriter, obj Response) { if !m.SuppressAgeHeader { - age := (time.Now().Add(m.offset).Unix() - obj.date.Unix()) + age := (m.now().Unix() - obj.date.Unix()) w.Header().Set("age", fmt.Sprintf("%d", age)) } } @@ -444,5 +444,19 @@ func (m *microcache) Stop() { // Increments the offset for testing purposes func (m *microcache) offsetIncr(o time.Duration) { + m.offsetMutex.Lock() + defer m.offsetMutex.Unlock() m.offset += o } + +// Get offset +func (m *microcache) getOffset() time.Duration { + m.offsetMutex.RLock() + defer m.offsetMutex.RUnlock() + return m.offset +} + +// Get current time with offset +func (m *microcache) now() time.Time { + return time.Now().Add(m.getOffset()) +} diff --git a/microcache_test.go b/microcache_test.go index c89ffc6..6b8314c 100644 --- a/microcache_test.go +++ b/microcache_test.go @@ -1,9 +1,11 @@ package microcache import ( + "context" "fmt" "net/http" "net/http/httptest" + "strings" "sync" "testing" "time" @@ -173,26 +175,16 @@ func TestCollapsedFowardingStaleWhileRevalidate(t *testing.T) { }) defer cache.Stop() handler := cache.Middleware(http.HandlerFunc(timelySuccessHandler)) - batchGet(handler, []string{ - "/", - }) + batchGet(handler, []string{"/"}) cache.offsetIncr(31 * time.Second) start := time.Now() - parallelGet(handler, []string{ - "/", - "/", - "/", - "/", - "/", - "/", - }) + parallelGet(handler, strings.Split(strings.Repeat(",/", 10)[1:], ",")) end := time.Since(start) // Sleep for a little bit to give the StaleWhileRevalidate goroutines some time to start. time.Sleep(time.Millisecond * 10) - if testMonitor.getMisses() != 1 || testMonitor.getStales() != 6 || + if testMonitor.getMisses() != 1 || testMonitor.getStales() != 10 || testMonitor.getBackends() != 2 || end > 20*time.Millisecond { - t.Logf("%#v", testMonitor) - t.Fatal("CollapsedFowarding and StaleWhileRevalidate not respected - got", testMonitor.getBackends(), "backend") + t.Fatalf("CollapsedFowarding and StaleWhileRevalidate not respected %s", dumpMonitor(testMonitor)) } } @@ -300,6 +292,40 @@ func TestTimeout(t *testing.T) { } } +// Request context cancellation should not cause error from TimeoutHandler +func TestRequestContextCancel(t *testing.T) { + testMonitor := &monitorFunc{interval: 100 * time.Second, logFunc: func(Stats) {}} + cache := New(Config{ + TTL: 30 * time.Second, + StaleWhileRevalidate: 30 * time.Second, + Timeout: 10 * time.Second, + CollapsedForwarding: true, + Monitor: testMonitor, + Driver: NewDriverLRU(10), + }) + defer cache.Stop() + handler := cache.Middleware(http.HandlerFunc(timelySuccessHandler)) + batchGet(handler, []string{"/"}) + cache.offsetIncr(31 * time.Second) + r, _ := http.NewRequest("GET", "/", nil) + ctx, cancel := context.WithCancel(r.Context()) + r = r.WithContext(ctx) + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + cancel() + time.Sleep(1 * time.Millisecond) + if testMonitor.getErrors() > 0 { + t.Fatal("TimeoutHandler returned error") + } + cache.offsetIncr(31 * time.Second) + cache.Timeout = 1 * time.Millisecond + batchGet(cache.Middleware(http.HandlerFunc(slowSuccessHandler)), []string{"/"}) + time.Sleep(2 * time.Millisecond) + if testMonitor.getErrors() != 1 { + t.Fatal("Request did not time out") + } +} + // CollapsedFowarding func TestCollapsedFowarding(t *testing.T) { testMonitor := &monitorFunc{interval: 100 * time.Second, logFunc: func(Stats) {}} @@ -745,3 +771,13 @@ func timelySuccessHandler(w http.ResponseWriter, r *http.Request) { time.Sleep(10 * time.Millisecond) http.Error(w, "done", 200) } + +func dumpMonitor(m *monitorFunc) string { + return fmt.Sprintf("Hits: %d, Misses: %d, Backend: %d, Stales: %d, Errors: %d", + m.getHits(), + m.getMisses(), + m.getBackends(), + m.getStales(), + m.getErrors(), + ) +}