better scheduling & lock cleanup for https connections
renbou committed Nov 18, 2023
1 parent 85768c2 commit d789a58
Showing 1 changed file with 143 additions and 107 deletions.
250 changes: 143 additions & 107 deletions cacheproxy/proxy/main.go
@@ -7,7 +7,6 @@ import (
"crypto/subtle"
"io"
"log"
"math/rand"
"net/http"
"net/http/httputil"
"os"
@@ -20,29 +19,63 @@ import (
"github.com/redis/go-redis/v9"
)

const (
envRedisURL = "REDIS_URL"
envAuthKey = "AUTH_KEY"
listenAddr = ":8888"
redisDeadline = time.Second * 30
readLockInterval = time.Millisecond * 25
maxCacheDuration = time.Minute * 5
headerAuthKey = "X-CBSProxy-Auth-Key"
headerCacheDuration = "X-CBSProxy-Cache-Duration"
headerCacheOverride = "X-CBSProxy-Cache-Override"
headerCached = "X-CBSProxy-Cached"
)
const maxCacheDuration = time.Minute * 5

type cachingData struct {
cacheKey string
cacheFor time.Duration
alreadyCached bool
func popHeader(r *http.Request, key string) string {
value := r.Header.Get(key)
r.Header.Del(key)
return value
}

// connDataKey is used for storing connData for a single connection in the case of HTTPS,
// and for a single request in the case of plaintext HTTP.
type connDataKey struct{}

// connData contains the last locked key for the current connection/request, in case
// proxying ends without OnResponse being called and we need to clean up the lock.
type connData struct {
proxyCtx *goproxy.ProxyCtx
cacheKey string
unlockFn func()
}

func (ccd *connData) cleanup() {
if ccd.unlockFn != nil {
ccd.proxyCtx.Logf("unlocking cache lock for %q", ccd.cacheKey)
ccd.unlockFn()
ccd.unlockFn = nil
}
}

func connDataFromContext(ctx context.Context) *connData {
return ctx.Value(connDataKey{}).(*connData)
}

type cachingContext struct {
context.Context
key string
keyUnlock func()
func contextWithConnData(ctx context.Context, connData *connData) context.Context {
return context.WithValue(ctx, connDataKey{}, connData)
}

// cleanupRoundTripper is a special RoundTripper which cleans up the connection data
// if the underlying proxy round tripper returns an error. This is needed because
// RoundTrip failures during HTTPS proxying aren't reported via OnResponse, so the
// locks they hold would otherwise never be released.
type cleanupRoundTripper struct{}

func (cleanupRoundTripper) RoundTrip(req *http.Request, ctx *goproxy.ProxyCtx) (*http.Response, error) {
resp, err := ctx.Proxy.Tr.RoundTrip(req)
if err != nil {
connDataFromContext(req.Context()).cleanup()
}

return resp, err
}

// proxyData stores data about a request's caching configuration and
// state in the proxy context shared between OnRequest and OnResponse.
type proxyData struct {
connData *connData
cacheKey string
cacheDuration time.Duration
cacheHot bool
}

type cachingHandler struct {
@@ -58,16 +91,10 @@ func (c *cachingHandler) getFromCache(ctx context.Context, key string) (string,
rUnlock := c.rLockKey(key)
defer rUnlock()

ctx, cancel := context.WithTimeout(ctx, redisDeadline)
defer cancel()

return c.redis.Get(ctx, key).Result()
}

func (c *cachingHandler) storeInCache(ctx context.Context, key string, value []byte, d time.Duration) error {
ctx, cancel := context.WithTimeout(ctx, redisDeadline)
defer cancel()

_, err := c.redis.SetNX(ctx, key, value, d).Result()
return err
}
@@ -97,40 +124,43 @@ func (c *cachingHandler) rLockKey(key string) (unlock func()) {
func (c *cachingHandler) tryLockKey(key string) (ok bool, unlock func()) {
keyMu := c.getKeyLock(key)

if keyMu.TryLock() {
return true, keyMu.Unlock
// A few retries are needed because, under high contention, TryLock can fail for
// all of the goroutines trying to acquire the lock.
for i := 0; i < 5; i++ {
if keyMu.TryLock() {
return true, keyMu.Unlock
}

// Benchmarks have shown that this works better than runtime.Gosched()
time.Sleep(time.Millisecond)
}

return false, nil
}

func (c *cachingHandler) getCacheDuration(r *http.Request) time.Duration {
defer r.Header.Del(headerCacheDuration)

val := r.Header.Get(headerCacheDuration)
if val == "" {
return 0
}

// Try parse duration first.
if dur, err := time.ParseDuration(val); err == nil {
return dur
}
func (c *cachingHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
connData := new(connData)
ctx := contextWithConnData(r.Context(), connData)

// Parse number of seconds
if durS, err := strconv.Atoi(val); err == nil {
return time.Second * time.Duration(durS)
}
// proxy.ServeHTTP can fail after OnRequest has locked a key,
// in which case that key would stay locked forever without this defer
defer connData.cleanup()

return 0
c.proxy.ServeHTTP(w, r.WithContext(ctx))
}

func (c *cachingHandler) overrideCacheFlag(r *http.Request) bool {
defer r.Header.Del(headerCacheOverride)
func (c *cachingHandler) HandleConnect(host string, ctx *goproxy.ProxyCtx) (*goproxy.ConnectAction, string) {
// UserData needs to be set here because the original CONNECT request only appears here,
// and the OnRequest callback will receive MITM'd requests with their own contexts,
// all of which need to be linked to this one in order to be properly cleaned up.
data := &proxyData{connData: connDataFromContext(ctx.Req.Context())}
data.connData.proxyCtx = ctx
ctx.UserData = data

return r.Header.Get(headerCacheOverride) != ""
return goproxy.MitmConnect, host
}

func (c *cachingHandler) validateDuration(d time.Duration) time.Duration {
func validateDuration(d time.Duration) time.Duration {
if d > 0 && d < time.Second {
return time.Second
}
@@ -139,27 +169,45 @@ func (c *cachingHandler) validateDuration(d time.Duration) time.Duration {
return min(d, maxCacheDuration)
}

func (c *cachingHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := &cachingContext{
Context: r.Context(),
key: "",
keyUnlock: nil,
}

// proxy.ServeHTTP can fail after OnRequest has locked a key,
// in which case that key would stay locked forever without this defer
func requestCacheDuration(r *http.Request) (d time.Duration) {
defer func() {
if ctx.keyUnlock != nil {
log.Printf("unlocking cache lock for %q", ctx.key)
ctx.keyUnlock()
}
d = validateDuration(d)
}()

c.proxy.ServeHTTP(w, r.WithContext(ctx))
val := popHeader(r, "X-Cache-Duration")
if val == "" {
return 0
}

// Try to parse the value as a Go duration string first.
if dur, err := time.ParseDuration(val); err == nil {
return dur
}

// Otherwise, parse it as a number of seconds
if durS, err := strconv.Atoi(val); err == nil {
return time.Second * time.Duration(durS)
}

return 0
}
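// Illustration (not from the original commit): with the parsing above, the header values
// "X-Cache-Duration: 90s" and "X-Cache-Duration: 90" both resolve to 90 seconds, while the
// deferred validateDuration call rounds positive sub-second values up to one second and
// clamps anything larger than maxCacheDuration down to 5 minutes.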

func (c *cachingHandler) OnRequest(r *http.Request, ctx *goproxy.ProxyCtx) (*http.Request, *http.Response) {
if subtle.ConstantTimeCompare([]byte(r.Header.Get(headerAuthKey)), []byte(c.authKey)) != 1 {
ctx.RoundTripper = cleanupRoundTripper{}

// For plaintext HTTP, create and set per-request data; for HTTPS, reuse the UserData
// that was set during HandleConnect, additionally attaching the connection data to the
// current request's context.
var data *proxyData
if ctx.UserData == nil {
data = &proxyData{connData: connDataFromContext(r.Context())}
data.connData.proxyCtx = ctx
ctx.UserData = data
} else {
data = ctx.UserData.(*proxyData)
r = r.WithContext(contextWithConnData(r.Context(), data.connData))
}

if subtle.ConstantTimeCompare([]byte(popHeader(r, "X-Cache-Auth-Key")), []byte(c.authKey)) != 1 {
return nil, &http.Response{
StatusCode: http.StatusUnauthorized,
Header: make(http.Header),
@@ -169,23 +217,18 @@ func (c *cachingHandler) OnRequest(r *http.Request, ctx *goproxy.ProxyCtx) (*htt
}
}

cd := &cachingData{
cacheKey: r.URL.String(),
cacheFor: c.validateDuration(c.getCacheDuration(r)),
alreadyCached: false,
}

ctx.UserData = cd
data.cacheKey = r.URL.String()
data.cacheDuration = requestCacheDuration(r)

// Forcefully override the cache using the response to this request.
if c.overrideCacheFlag(r) {
if popHeader(r, "X-Cache-Override") != "" {
return r, nil
}

// A loop is needed because the cache can be empty, but by the time we try to acquire the write lock,
// someone else could've already acquired it, in which case we need to wait for them to finish and read their result.
for {
cachedResponseString, err := c.getFromCache(r.Context(), cd.cacheKey)
cachedResponseString, err := c.getFromCache(r.Context(), data.cacheKey)
if err == nil && len(cachedResponseString) > 0 {
// Return cached response or proxy the request if the cache contains an invalid entry
cachedResp, err := http.ReadResponse(bufio.NewReader(strings.NewReader(cachedResponseString)), r)
Expand All @@ -194,46 +237,38 @@ func (c *cachingHandler) OnRequest(r *http.Request, ctx *goproxy.ProxyCtx) (*htt
return r, nil
}

ctx.Logf("returning cached response for %q", cd.cacheKey)
cd.alreadyCached = true
ctx.Logf("returning cached response for %q", data.cacheKey)
data.cacheHot = true
return nil, cachedResp
}

ok, unlock := c.tryLockKey(cd.cacheKey)
ok, unlock := c.tryLockKey(data.cacheKey)
if !ok {
// sleep for 25±5ms before next iteration
time.Sleep(time.Duration(float64(readLockInterval) * (1 + rand.Float64()*0.4 - 0.2)))
// someone else should've acquired the lock by now, so RLock will block
continue
}

// Write lock acquired:
// - save the unlock function so connData.cleanup can release the lock later
// - proxy the request and then save the response
ctx.Logf("cache lock acquired for %q, will proxy request", cd.cacheKey)
cachingCtx := r.Context().(*cachingContext)
cachingCtx.keyUnlock = unlock
cachingCtx.key = cd.cacheKey
ctx.Logf("cache lock acquired for %q, will proxy request", data.cacheKey)
data.connData.unlockFn = unlock
data.connData.cacheKey = data.cacheKey
return r, nil
}
}

func (c *cachingHandler) OnResponse(resp *http.Response, ctx *goproxy.ProxyCtx) *http.Response {
if ctx.UserData == nil {
// UserData isn't set when request authorization fails.
return resp
} else if resp == nil {
// resp is nil when an error has occurred
return nil
}
data := ctx.UserData.(*proxyData)
defer data.connData.cleanup()

cd, ok := ctx.UserData.(*cachingData)
if !ok {
ctx.Warnf("proxy context data contained %T instead of *cachingData", ctx.UserData)
return resp
// resp is nil when an error has occurred
if resp == nil {
return nil
}

resp.Header.Set(headerCached, strconv.FormatBool(cd.alreadyCached))
if cd.alreadyCached || cd.cacheFor <= 0 {
resp.Header.Set("X-Cache-Hot", strconv.FormatBool(data.cacheHot))
if data.cacheHot || data.cacheDuration <= 0 {
return resp
}

@@ -243,37 +278,38 @@ func (c *cachingHandler) OnResponse(resp *http.Response, ctx *goproxy.ProxyCtx)
return resp
}

if err := c.storeInCache(ctx.Req.Context(), cd.cacheKey, dump, cd.cacheFor); err != nil {
if err := c.storeInCache(ctx.Req.Context(), data.cacheKey, dump, data.cacheDuration); err != nil {
ctx.Warnf("storing dumped result in cache: %s", err)
}
return resp
}

func main() {
redopts, err := redis.ParseURL(os.Getenv(envRedisURL))
redopts, err := redis.ParseURL(os.Getenv("REDIS_URL"))
if err != nil {
log.Fatalf("Failed to parse %s: %s", envRedisURL, err)
log.Fatalf("Failed to parse REDIS_URL: %s", err)
}

handler := &cachingHandler{
redis: redis.NewClient(redopts),
proxy: goproxy.NewProxyHttpServer(),
keylocks: make(map[string]*sync.RWMutex),
authKey: os.Getenv(envAuthKey),
authKey: os.Getenv("AUTH_KEY"),
}

handler.proxy.Verbose = true
handler.proxy.OnRequest().HandleConnect(goproxy.AlwaysMitm)

handler.proxy.OnRequest().HandleConnect(goproxy.FuncHttpsHandler(handler.HandleConnect))
handler.proxy.OnRequest().DoFunc(handler.OnRequest)
handler.proxy.OnResponse().DoFunc(handler.OnResponse)

log.Printf("Proxy started on %s", listenAddr)
log.Printf("Proxy started on :8888")
srv := &http.Server{
Addr: listenAddr,
Handler: handler,
ReadTimeout: 10 * time.Second,
WriteTimeout: time.Minute,
IdleTimeout: time.Minute * 2,
Addr: ":8888",
// Handler: http.TimeoutHandler(handler, time.Minute, "timed out"),
Handler: handler,
ReadTimeout: time.Second * 10,
IdleTimeout: time.Minute * 2,
}

log.Fatal(srv.ListenAndServe())
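Below is a minimal client sketch, not part of the commit, showing how a request might be sent through the proxy using the caching headers from the diff above; the proxy address, target URL, and auth key value are placeholder assumptions.

package main

import (
	"fmt"
	"io"
	"log"
	"net/http"
	"net/url"
)

func main() {
	// Placeholder proxy address: the server above listens on :8888.
	proxyURL, err := url.Parse("http://127.0.0.1:8888")
	if err != nil {
		log.Fatal(err)
	}

	client := &http.Client{
		Transport: &http.Transport{
			// Route all requests through the caching proxy.
			Proxy: http.ProxyURL(proxyURL),
			// HTTPS targets are MITM'd by the proxy with its own certificate,
			// so a real client would need to trust that CA; this sketch uses a
			// plaintext HTTP target to keep things simple.
		},
	}

	// Placeholder target URL and auth key.
	req, err := http.NewRequest(http.MethodGet, "http://example.com/", nil)
	if err != nil {
		log.Fatal(err)
	}
	req.Header.Set("X-Cache-Auth-Key", "my-auth-key") // must match the proxy's AUTH_KEY
	req.Header.Set("X-Cache-Duration", "90s")         // cache the response for 90 seconds

	resp, err := client.Do(req)
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()

	body, _ := io.ReadAll(resp.Body)
	fmt.Printf("cache hot: %s, body length: %d\n", resp.Header.Get("X-Cache-Hot"), len(body))
}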
