Skip to content

Commit

Permalink
Prevent cancellation from propagating to background request context (#17
Browse files Browse the repository at this point in the history
)

* Failing test for issue #16
* Cancellation does not propagate to background request context
* Updating test description
  • Loading branch information
kevburnsjr committed Oct 4, 2019
1 parent b6ab5e1 commit 485a0fc
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 29 deletions.
22 changes: 22 additions & 0 deletions background_request.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package microcache

import (
"context"
"net/http"
)

// newBackgroundRequest clones a request for use in background object revalidation.
// This prevents a closed foreground request context from prematurely cancelling
// the background request context.
func newBackgroundRequest(r *http.Request) *http.Request {
return r.Clone(bgContext{r.Context(), make(chan struct{})})
}

type bgContext struct {
context.Context
done chan struct{}
}

func (c bgContext) Done() <-chan struct{} {
return c.done
}
44 changes: 29 additions & 15 deletions microcache.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ type microcache struct {
collapseMutex *sync.Mutex

// Used to advance time for testing
offset time.Duration
offset time.Duration
offsetMutex *sync.RWMutex
}

type Config struct {
Expand Down Expand Up @@ -155,6 +156,7 @@ func New(o Config) *microcache {
revalidateMutex: &sync.Mutex{},
collapse: map[string]*sync.Mutex{},
collapseMutex: &sync.Mutex{},
offsetMutex: &sync.RWMutex{},
}
if o.Driver == nil {
m.Driver = NewDriverLRU(1e4) // default 10k cache items
Expand All @@ -180,6 +182,9 @@ func New(o Config) *microcache {
// chain.Append(mx.Middleware)
//
func (m *microcache) Middleware(h http.Handler) http.Handler {
if m.Timeout > 0 {
h = http.TimeoutHandler(h, m.Timeout, "Timed out")
}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Websocket passthrough
upgrade := strings.ToLower(r.Header.Get("connection")) == "upgrade"
Expand Down Expand Up @@ -259,7 +264,7 @@ func (m *microcache) Middleware(h http.Handler) http.Handler {
}

// Fresh response object found
if obj.found && obj.expires.After(time.Now().Add(m.offset)) {
if obj.found && obj.expires.After(m.now()) {
if m.Monitor != nil {
m.Monitor.Hit()
}
Expand All @@ -273,7 +278,7 @@ func (m *microcache) Middleware(h http.Handler) http.Handler {

// Stale While Revalidate
if obj.found && req.staleWhileRevalidate > 0 &&
obj.expires.Add(req.staleWhileRevalidate).After(time.Now().Add(m.offset)) {
obj.expires.Add(req.staleWhileRevalidate).After(m.now()) {
if m.Monitor != nil {
m.Monitor.Stale()
}
Expand All @@ -291,15 +296,15 @@ func (m *microcache) Middleware(h http.Handler) http.Handler {
}
m.revalidateMutex.Unlock()
if !revalidating {
br := newBackgroundRequest(r)
go func() {
defer func() {
// Clear revalidation lock
m.revalidateMutex.Lock()
delete(m.revalidating, objHash)
m.revalidateMutex.Unlock()
}()

m.handleBackendResponse(h, w, r, reqHash, req, objHash, obj, true)
m.handleBackendResponse(h, w, br, reqHash, req, objHash, obj, true)
}()
}

Expand Down Expand Up @@ -329,12 +334,7 @@ func (m *microcache) handleBackendResponse(
beres := Response{header: http.Header{}}

// Execute request
if m.Timeout > 0 {
th := http.TimeoutHandler(h, m.Timeout, "Timed out")
th.ServeHTTP(&beres, r)
} else {
h.ServeHTTP(&beres, r)
}
h.ServeHTTP(&beres, r)

if !beres.headerWritten {
beres.status = http.StatusOK
Expand All @@ -347,10 +347,10 @@ func (m *microcache) handleBackendResponse(

// Serve Stale
if beres.status >= 500 && obj.found {
serveStale := obj.expires.Add(req.staleIfError).After(time.Now().Add(m.offset))
serveStale := obj.expires.Add(req.staleIfError).After(m.now())
// Extend stale response expiration by staleIfError grace period
if req.found && serveStale && req.staleRecache {
obj.expires = obj.date.Add(m.offset).Add(req.ttl)
obj.expires = obj.date.Add(m.getOffset()).Add(req.ttl)
m.store(objHash, obj)
}
if !background && serveStale {
Expand All @@ -376,7 +376,7 @@ func (m *microcache) handleBackendResponse(
}
// Cache response
if !req.nocache {
beres.expires = time.Now().Add(m.offset).Add(req.ttl)
beres.expires = m.now().Add(req.ttl)
m.store(objHash, beres)
}
}
Expand Down Expand Up @@ -418,7 +418,7 @@ func (m *microcache) Start() {
// setAgeHeader sets the age header if not suppressed
func (m *microcache) setAgeHeader(w http.ResponseWriter, obj Response) {
if !m.SuppressAgeHeader {
age := (time.Now().Add(m.offset).Unix() - obj.date.Unix())
age := (m.now().Unix() - obj.date.Unix())
w.Header().Set("age", fmt.Sprintf("%d", age))
}
}
Expand All @@ -444,5 +444,19 @@ func (m *microcache) Stop() {

// Increments the offset for testing purposes
func (m *microcache) offsetIncr(o time.Duration) {
m.offsetMutex.Lock()
defer m.offsetMutex.Unlock()
m.offset += o
}

// Get offset
func (m *microcache) getOffset() time.Duration {
m.offsetMutex.RLock()
defer m.offsetMutex.RUnlock()
return m.offset
}

// Get current time with offset
func (m *microcache) now() time.Time {
return time.Now().Add(m.getOffset())
}
64 changes: 50 additions & 14 deletions microcache_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package microcache

import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"time"
Expand Down Expand Up @@ -173,26 +175,16 @@ func TestCollapsedFowardingStaleWhileRevalidate(t *testing.T) {
})
defer cache.Stop()
handler := cache.Middleware(http.HandlerFunc(timelySuccessHandler))
batchGet(handler, []string{
"/",
})
batchGet(handler, []string{"/"})
cache.offsetIncr(31 * time.Second)
start := time.Now()
parallelGet(handler, []string{
"/",
"/",
"/",
"/",
"/",
"/",
})
parallelGet(handler, strings.Split(strings.Repeat(",/", 10)[1:], ","))
end := time.Since(start)
// Sleep for a little bit to give the StaleWhileRevalidate goroutines some time to start.
time.Sleep(time.Millisecond * 10)
if testMonitor.getMisses() != 1 || testMonitor.getStales() != 6 ||
if testMonitor.getMisses() != 1 || testMonitor.getStales() != 10 ||
testMonitor.getBackends() != 2 || end > 20*time.Millisecond {
t.Logf("%#v", testMonitor)
t.Fatal("CollapsedFowarding and StaleWhileRevalidate not respected - got", testMonitor.getBackends(), "backend")
t.Fatalf("CollapsedFowarding and StaleWhileRevalidate not respected %s", dumpMonitor(testMonitor))
}
}

Expand Down Expand Up @@ -300,6 +292,40 @@ func TestTimeout(t *testing.T) {
}
}

// Request context cancellation should not cause error from TimeoutHandler
func TestRequestContextCancel(t *testing.T) {
testMonitor := &monitorFunc{interval: 100 * time.Second, logFunc: func(Stats) {}}
cache := New(Config{
TTL: 30 * time.Second,
StaleWhileRevalidate: 30 * time.Second,
Timeout: 10 * time.Second,
CollapsedForwarding: true,
Monitor: testMonitor,
Driver: NewDriverLRU(10),
})
defer cache.Stop()
handler := cache.Middleware(http.HandlerFunc(timelySuccessHandler))
batchGet(handler, []string{"/"})
cache.offsetIncr(31 * time.Second)
r, _ := http.NewRequest("GET", "/", nil)
ctx, cancel := context.WithCancel(r.Context())
r = r.WithContext(ctx)
w := httptest.NewRecorder()
handler.ServeHTTP(w, r)
cancel()
time.Sleep(1 * time.Millisecond)
if testMonitor.getErrors() > 0 {
t.Fatal("TimeoutHandler returned error")
}
cache.offsetIncr(31 * time.Second)
cache.Timeout = 1 * time.Millisecond
batchGet(cache.Middleware(http.HandlerFunc(slowSuccessHandler)), []string{"/"})
time.Sleep(2 * time.Millisecond)
if testMonitor.getErrors() != 1 {
t.Fatal("Request did not time out")
}
}

// CollapsedFowarding
func TestCollapsedFowarding(t *testing.T) {
testMonitor := &monitorFunc{interval: 100 * time.Second, logFunc: func(Stats) {}}
Expand Down Expand Up @@ -745,3 +771,13 @@ func timelySuccessHandler(w http.ResponseWriter, r *http.Request) {
time.Sleep(10 * time.Millisecond)
http.Error(w, "done", 200)
}

func dumpMonitor(m *monitorFunc) string {
return fmt.Sprintf("Hits: %d, Misses: %d, Backend: %d, Stales: %d, Errors: %d",
m.getHits(),
m.getMisses(),
m.getBackends(),
m.getStales(),
m.getErrors(),
)
}

0 comments on commit 485a0fc

Please sign in to comment.