From d789a58140cc8ef96c4f0efe505ae3183008e4c8 Mon Sep 17 00:00:00 2001
From: Artem Mikheev <30644072+renbou@users.noreply.github.com>
Date: Sat, 18 Nov 2023 03:00:01 +0300
Subject: [PATCH] better scheduling & lock cleanup for https connections

---
 cacheproxy/proxy/main.go | 250 ++++++++++++++++++++++-----------------
 1 file changed, 143 insertions(+), 107 deletions(-)

diff --git a/cacheproxy/proxy/main.go b/cacheproxy/proxy/main.go
index 70871f0..62866c1 100644
--- a/cacheproxy/proxy/main.go
+++ b/cacheproxy/proxy/main.go
@@ -7,7 +7,6 @@ import (
 	"crypto/subtle"
 	"io"
 	"log"
-	"math/rand"
 	"net/http"
 	"net/http/httputil"
 	"os"
@@ -20,29 +19,63 @@ import (
 	"github.com/redis/go-redis/v9"
 )
 
-const (
-	envRedisURL         = "REDIS_URL"
-	envAuthKey          = "AUTH_KEY"
-	listenAddr          = ":8888"
-	redisDeadline       = time.Second * 30
-	readLockInterval    = time.Millisecond * 25
-	maxCacheDuration    = time.Minute * 5
-	headerAuthKey       = "X-CBSProxy-Auth-Key"
-	headerCacheDuration = "X-CBSProxy-Cache-Duration"
-	headerCacheOverride = "X-CBSProxy-Cache-Override"
-	headerCached        = "X-CBSProxy-Cached"
-)
+const maxCacheDuration = time.Minute * 5
 
-type cachingData struct {
-	cacheKey      string
-	cacheFor      time.Duration
-	alreadyCached bool
+func popHeader(r *http.Request, key string) string {
+	value := r.Header.Get(key)
+	r.Header.Del(key)
+	return value
+}
+
+// connDataKey is used for storing connData for a single connection in case of HTTPS,
+// and for a single request in case of plaintext HTTP.
+type connDataKey struct{}
+
+// connData contains the last locked key for the current connection/request,
+// in case the proxy ends without calling OnResponse and we need to clean up the lock.
+type connData struct {
+	proxyCtx *goproxy.ProxyCtx
+	cacheKey string
+	unlockFn func()
+}
+
+func (ccd *connData) cleanup() {
+	if ccd.unlockFn != nil {
+		ccd.proxyCtx.Logf("unlocking cache lock for %q", ccd.cacheKey)
+		ccd.unlockFn()
+		ccd.unlockFn = nil
+	}
+}
+
+func connDataFromContext(ctx context.Context) *connData {
+	return ctx.Value(connDataKey{}).(*connData)
 }
 
-type cachingContext struct {
-	context.Context
-	key       string
-	keyUnlock func()
+func contextWithConnData(ctx context.Context, connData *connData) context.Context {
+	return context.WithValue(ctx, connDataKey{}, connData)
+}
+
+// cleanupRoundTripper is a special RoundTripper which cleans up the connection data
+// if the underlying proxy round tripper returns an error. This is needed because
+// RoundTrip failures during HTTPS proxying aren't reported via OnResponse, so the
+// locks held by such requests would otherwise never be released.
+type cleanupRoundTripper struct{}
+
+func (cleanupRoundTripper) RoundTrip(req *http.Request, ctx *goproxy.ProxyCtx) (*http.Response, error) {
+	resp, err := ctx.Proxy.Tr.RoundTrip(req)
+	if err != nil {
+		connDataFromContext(req.Context()).cleanup()
+	}
+
+	return resp, err
+}
+
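A note on the context plumbing above: it is the standard unexported-key pattern from the context package, and storing a *pointer* is what lets later code mutate connData in place. A minimal self-contained sketch of the pattern (names here are illustrative, not part of the patch):

    package main

    import (
    	"context"
    	"fmt"
    )

    // An unexported key type guarantees no other package can collide
    // with this context entry.
    type ctxKey struct{}

    type payload struct{ n int }

    func main() {
    	ctx := context.WithValue(context.Background(), ctxKey{}, &payload{n: 1})

    	// The stored pointer is shared: mutating it is visible to every
    	// holder of the context, which is what connData relies on to
    	// publish unlockFn from OnRequest to the cleanup paths.
    	ctx.Value(ctxKey{}).(*payload).n++
    	fmt.Println(ctx.Value(ctxKey{}).(*payload).n) // 2
    }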
+// proxyData stores data about a request's caching configuration and
+// state in the proxy context shared between OnRequest and OnResponse.
+type proxyData struct {
+	connData      *connData
+	cacheKey      string
+	cacheDuration time.Duration
+	cacheHot      bool
 }
 
 type cachingHandler struct {
@@ -58,16 +91,10 @@ func (c *cachingHandler) getFromCache(ctx context.Context, key string) (string,
 	rUnlock := c.rLockKey(key)
 	defer rUnlock()
 
-	ctx, cancel := context.WithTimeout(ctx, redisDeadline)
-	defer cancel()
-
 	return c.redis.Get(ctx, key).Result()
 }
 
 func (c *cachingHandler) storeInCache(ctx context.Context, key string, value []byte, d time.Duration) error {
-	ctx, cancel := context.WithTimeout(ctx, redisDeadline)
-	defer cancel()
-
 	_, err := c.redis.SetNX(ctx, key, value, d).Result()
 	return err
 }
@@ -97,40 +124,43 @@ func (c *cachingHandler) rLockKey(key string) (unlock func()) {
 func (c *cachingHandler) tryLockKey(key string) (ok bool, unlock func()) {
 	keyMu := c.getKeyLock(key)
 
-	if keyMu.TryLock() {
-		return true, keyMu.Unlock
+	// A few retries are needed because under high contention TryLock can fail
+	// for all the goroutines trying to acquire the lock, even as it changes hands.
+	for i := 0; i < 5; i++ {
+		if keyMu.TryLock() {
+			return true, keyMu.Unlock
+		}
+
+		// Benchmarks have shown that this works better than runtime.Gosched().
+		time.Sleep(time.Millisecond)
 	}
+
 	return false, nil
 }
 
-func (c *cachingHandler) getCacheDuration(r *http.Request) time.Duration {
-	defer r.Header.Del(headerCacheDuration)
-
-	val := r.Header.Get(headerCacheDuration)
-	if val == "" {
-		return 0
-	}
-
-	// Try parse duration first.
-	if dur, err := time.ParseDuration(val); err == nil {
-		return dur
-	}
+func (c *cachingHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+	connData := new(connData)
+	ctx := contextWithConnData(r.Context(), connData)
 
-	// Parse number of seconds
-	if durS, err := strconv.Atoi(val); err == nil {
-		return time.Second * time.Duration(durS)
-	}
+	// proxy.ServeHTTP can fail after OnRequest has locked a key,
+	// in which case that key would stay locked forever without this defer.
+	defer connData.cleanup()
 
-	return 0
+	c.proxy.ServeHTTP(w, r.WithContext(ctx))
 }
 
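The retry loop relies on sync.Mutex.TryLock, available since Go 1.18. Under heavy contention TryLock can keep losing the race even though the lock is released frequently, which is why several attempts are made before giving up. A rough standalone sketch of the pattern (goroutine count and timings are made up):

    package main

    import (
    	"fmt"
    	"sync"
    	"sync/atomic"
    	"time"
    )

    // tryLock attempts to take mu a few times, sleeping briefly between
    // attempts, mirroring the shape of tryLockKey in this patch.
    func tryLock(mu *sync.Mutex) bool {
    	for i := 0; i < 5; i++ {
    		if mu.TryLock() {
    			return true
    		}
    		time.Sleep(time.Millisecond)
    	}
    	return false
    }

    func main() {
    	var mu sync.Mutex
    	var won atomic.Int32

    	var wg sync.WaitGroup
    	for g := 0; g < 8; g++ {
    		wg.Add(1)
    		go func() {
    			defer wg.Done()
    			if tryLock(&mu) {
    				won.Add(1)
    				time.Sleep(2 * time.Millisecond) // hold the lock for a while
    				mu.Unlock()
    			}
    		}()
    	}
    	wg.Wait()

    	// Under contention only some goroutines win within 5 attempts;
    	// in the proxy the losers fall back to blocking on the read lock.
    	fmt.Println("winners:", won.Load())
    }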
-func (c *cachingHandler) overrideCacheFlag(r *http.Request) bool {
-	defer r.Header.Del(headerCacheOverride)
+func (c *cachingHandler) HandleConnect(host string, ctx *goproxy.ProxyCtx) (*goproxy.ConnectAction, string) {
+	// UserData needs to be set here because the original CONNECT request only appears here,
+	// and the OnRequest callback will receive MITM'd requests with their own contexts,
+	// all of which need to be linked back to this one so that they can be properly cleaned up.
+	data := &proxyData{connData: connDataFromContext(ctx.Req.Context())}
+	data.connData.proxyCtx = ctx
+	ctx.UserData = data
 
-	return r.Header.Get(headerCacheOverride) != ""
+	return goproxy.MitmConnect, host
 }
 
-func (c *cachingHandler) validateDuration(d time.Duration) time.Duration {
+func validateDuration(d time.Duration) time.Duration {
 	if d > 0 && d < time.Second {
 		return time.Second
 	}
@@ -139,27 +169,45 @@
 	return min(d, maxCacheDuration)
 }
 
-func (c *cachingHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
-	ctx := &cachingContext{
-		Context:   r.Context(),
-		key:       "",
-		keyUnlock: nil,
-	}
-
-	// proxy.ServeHTTP can fail after OnRequest has locked a key,
-	// in which case that key would stay locked forever without this defer
+func requestCacheDuration(r *http.Request) (d time.Duration) {
 	defer func() {
-		if ctx.keyUnlock != nil {
-			log.Printf("unlocking cache lock for %q", ctx.key)
-			ctx.keyUnlock()
-		}
+		d = validateDuration(d)
 	}()
 
-	c.proxy.ServeHTTP(w, r.WithContext(ctx))
+	val := popHeader(r, "X-Cache-Duration")
+	if val == "" {
+		return 0
+	}
+
+	// Try to parse a Go duration string first.
+	if dur, err := time.ParseDuration(val); err == nil {
+		return dur
+	}
+
+	// Fall back to a plain number of seconds.
+	if durS, err := strconv.Atoi(val); err == nil {
+		return time.Second * time.Duration(durS)
+	}
+
+	return 0
 }
 
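With this change, X-Cache-Duration accepts either a Go duration string or a bare number of seconds, and validateDuration clamps the result to [1s, maxCacheDuration]. A quick standalone illustration of the two-step parsing (inputs are made up):

    package main

    import (
    	"fmt"
    	"strconv"
    	"time"
    )

    // parse mirrors requestCacheDuration's two-step parsing, minus the clamping.
    func parse(val string) time.Duration {
    	if dur, err := time.ParseDuration(val); err == nil {
    		return dur
    	}
    	if secs, err := strconv.Atoi(val); err == nil {
    		return time.Second * time.Duration(secs)
    	}
    	return 0
    }

    func main() {
    	for _, v := range []string{"1m30s", "90s", "90", "oops"} {
    		fmt.Printf("%q -> %v\n", v, parse(v))
    	}
    	// "1m30s", "90s" and "90" all yield 1m30s; "oops" yields 0,
    	// which the proxy treats as "do not cache".
    }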
 func (c *cachingHandler) OnRequest(r *http.Request, ctx *goproxy.ProxyCtx) (*http.Request, *http.Response) {
-	if subtle.ConstantTimeCompare([]byte(r.Header.Get(headerAuthKey)), []byte(c.authKey)) != 1 {
+	ctx.RoundTripper = cleanupRoundTripper{}
+
+	// Either reuse the UserData set during HandleConnect (HTTPS proxying), additionally
+	// saving the connection data into the current request's context, or create and
+	// set per-request data (plaintext HTTP).
+	var data *proxyData
+	if ctx.UserData == nil {
+		data = &proxyData{connData: connDataFromContext(r.Context())}
+		data.connData.proxyCtx = ctx
+		ctx.UserData = data
+	} else {
+		data = ctx.UserData.(*proxyData)
+		r = r.WithContext(contextWithConnData(r.Context(), data.connData))
+	}
+
+	if subtle.ConstantTimeCompare([]byte(popHeader(r, "X-Cache-Auth-Key")), []byte(c.authKey)) != 1 {
 		return nil, &http.Response{
 			StatusCode: http.StatusUnauthorized,
 			Header:     make(http.Header),
@@ -169,23 +217,18 @@ func (c *cachingHandler) OnRequest(r *http.Request, ctx *goproxy.ProxyCtx) (*htt
 		}
 	}
 
-	cd := &cachingData{
-		cacheKey:      r.URL.String(),
-		cacheFor:      c.validateDuration(c.getCacheDuration(r)),
-		alreadyCached: false,
-	}
-
-	ctx.UserData = cd
+	data.cacheKey = r.URL.String()
+	data.cacheDuration = requestCacheDuration(r)
 
 	// Forcefully override the cache using the response to this request.
-	if c.overrideCacheFlag(r) {
+	if popHeader(r, "X-Cache-Override") != "" {
 		return r, nil
 	}
 
 	// A loop is needed because the cache can be empty, yet by the time we try to acquire the write lock,
 	// someone else may have gotten it first, in which case we wait for them to finish and read their result.
 	for {
-		cachedResponseString, err := c.getFromCache(r.Context(), cd.cacheKey)
+		cachedResponseString, err := c.getFromCache(r.Context(), data.cacheKey)
 		if err == nil && len(cachedResponseString) > 0 {
 			// Return the cached response, or proxy the request if the cache contains an invalid entry.
 			cachedResp, err := http.ReadResponse(bufio.NewReader(strings.NewReader(cachedResponseString)), r)
@@ -194,46 +237,38 @@ func (c *cachingHandler) OnRequest(r *http.Request, ctx *goproxy.ProxyCtx) (*htt
 				return r, nil
 			}
 
-			ctx.Logf("returning cached response for %q", cd.cacheKey)
-			cd.alreadyCached = true
+			ctx.Logf("returning cached response for %q", data.cacheKey)
+			data.cacheHot = true
 			return nil, cachedResp
 		}
 
-		ok, unlock := c.tryLockKey(cd.cacheKey)
+		ok, unlock := c.tryLockKey(data.cacheKey)
 		if !ok {
-			// sleep for 25±5ms before next iteration
-			time.Sleep(time.Duration(float64(readLockInterval) * (1 + rand.Float64()*0.4 - 0.2)))
+			// someone else should've acquired the lock by now, so the RLock in getFromCache will block
 			continue
 		}
 
 		// Write lock acquired:
 		// - save the unlock function to be called in the ServeHTTP defer
 		// - proxy the request and then save the response
-		ctx.Logf("cache lock acquired for %q, will proxy request", cd.cacheKey)
-		cachingCtx := r.Context().(*cachingContext)
-		cachingCtx.keyUnlock = unlock
-		cachingCtx.key = cd.cacheKey
+		ctx.Logf("cache lock acquired for %q, will proxy request", data.cacheKey)
+		data.connData.unlockFn = unlock
+		data.connData.cacheKey = data.cacheKey
 		return r, nil
 	}
 }
 
 func (c *cachingHandler) OnResponse(resp *http.Response, ctx *goproxy.ProxyCtx) *http.Response {
-	if ctx.UserData == nil {
-		// UserData isn't set when request authorization fails.
-		return resp
-	} else if resp == nil {
-		// resp is nil when an error has occurred
-		return nil
-	}
+	data := ctx.UserData.(*proxyData)
+	defer data.connData.cleanup()
 
-	cd, ok := ctx.UserData.(*cachingData)
-	if !ok {
-		ctx.Warnf("proxy context data contained %T instead of *cachingData", ctx.UserData)
-		return resp
+	// resp is nil when an error has occurred.
+	if resp == nil {
+		return nil
 	}
 
-	resp.Header.Set(headerCached, strconv.FormatBool(cd.alreadyCached))
-	if cd.alreadyCached || cd.cacheFor <= 0 {
+	resp.Header.Set("X-Cache-Hot", strconv.FormatBool(data.cacheHot))
+	if data.cacheHot || data.cacheDuration <= 0 {
 		return resp
 	}
 
@@ -243,37 +278,38 @@ func (c *cachingHandler) OnResponse(resp *http.Response, ctx *goproxy.ProxyCtx)
 		return resp
 	}
 
-	if err := c.storeInCache(ctx.Req.Context(), cd.cacheKey, dump, cd.cacheFor); err != nil {
+	if err := c.storeInCache(ctx.Req.Context(), data.cacheKey, dump, data.cacheDuration); err != nil {
 		ctx.Warnf("storing dumped result in cache: %s", err)
 	}
 
 	return resp
 }
 
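Cache entries are whole HTTP responses in wire format: OnResponse writes the output of httputil.DumpResponse to Redis, and OnRequest replays it through http.ReadResponse. A self-contained sketch of that round trip, with httptest standing in for the upstream:

    package main

    import (
    	"bufio"
    	"bytes"
    	"fmt"
    	"io"
    	"net/http"
    	"net/http/httptest"
    	"net/http/httputil"
    )

    func main() {
    	// Fabricate a response; httptest.NewRecorder stands in for the upstream.
    	rec := httptest.NewRecorder()
    	rec.Header().Set("Content-Type", "text/plain")
    	fmt.Fprint(rec, "hello from upstream")
    	resp := rec.Result()

    	// Serialize the whole response, body included, as OnResponse does.
    	dump, err := httputil.DumpResponse(resp, true)
    	if err != nil {
    		panic(err)
    	}

    	// Deserialize it again, as OnRequest does on a cache hit.
    	cached, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(dump)), nil)
    	if err != nil {
    		panic(err)
    	}
    	defer cached.Body.Close()

    	body, _ := io.ReadAll(cached.Body)
    	fmt.Printf("%d %s: %s\n", cached.StatusCode, cached.Header.Get("Content-Type"), body)
    }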
 func main() {
-	redopts, err := redis.ParseURL(os.Getenv(envRedisURL))
+	redopts, err := redis.ParseURL(os.Getenv("REDIS_URL"))
 	if err != nil {
-		log.Fatalf("Failed to parse %s: %s", envRedisURL, err)
+		log.Fatalf("Failed to parse REDIS_URL: %s", err)
 	}
 
 	handler := &cachingHandler{
 		redis:    redis.NewClient(redopts),
 		proxy:    goproxy.NewProxyHttpServer(),
 		keylocks: make(map[string]*sync.RWMutex),
-		authKey:  os.Getenv(envAuthKey),
+		authKey:  os.Getenv("AUTH_KEY"),
 	}
 
 	handler.proxy.Verbose = true
-	handler.proxy.OnRequest().HandleConnect(goproxy.AlwaysMitm)
+
+	handler.proxy.OnRequest().HandleConnect(goproxy.FuncHttpsHandler(handler.HandleConnect))
 	handler.proxy.OnRequest().DoFunc(handler.OnRequest)
 	handler.proxy.OnResponse().DoFunc(handler.OnResponse)
 
-	log.Printf("Proxy started on %s", listenAddr)
+	log.Printf("Proxy started on :8888")
 	srv := &http.Server{
-		Addr:         listenAddr,
-		Handler:      handler,
-		ReadTimeout:  10 * time.Second,
-		WriteTimeout: time.Minute,
-		IdleTimeout:  time.Minute * 2,
+		Addr: ":8888",
+		// Handler: http.TimeoutHandler(handler, time.Minute, "timed out"),
+		Handler:     handler,
+		ReadTimeout: time.Second * 10,
+		IdleTimeout: time.Minute * 2,
 	}
 
 	log.Fatal(srv.ListenAndServe())
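A plaintext-HTTP smoke test of the rebuilt proxy could look roughly like the following; the address and auth key are assumed values for illustration, and HTTPS URLs would additionally require the client to trust goproxy's MITM certificate:

    package main

    import (
    	"fmt"
    	"io"
    	"net/http"
    	"net/url"
    )

    func main() {
    	// Assumed: the proxy runs locally on :8888 with AUTH_KEY=secret.
    	proxyURL, err := url.Parse("http://localhost:8888")
    	if err != nil {
    		panic(err)
    	}
    	client := &http.Client{Transport: &http.Transport{Proxy: http.ProxyURL(proxyURL)}}

    	req, err := http.NewRequest(http.MethodGet, "http://example.com/", nil)
    	if err != nil {
    		panic(err)
    	}
    	req.Header.Set("X-Cache-Auth-Key", "secret") // verified, then stripped by popHeader
    	req.Header.Set("X-Cache-Duration", "30s")    // ask the proxy to cache this response

    	resp, err := client.Do(req)
    	if err != nil {
    		panic(err)
    	}
    	defer resp.Body.Close()
    	io.Copy(io.Discard, resp.Body)

    	// "false" on the first (cold) request, "true" once served from cache.
    	fmt.Println("X-Cache-Hot:", resp.Header.Get("X-Cache-Hot"))
    }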