This repository has been archived by the owner on Apr 19, 2024. It is now read-only.

Commit f51861d

Don't call OnChange() event from non-owner.
Non-owners shouldn't be persisting rate limit state.
Baliedge committed Mar 12, 2024
1 parent 5f137ad commit f51861d
Showing 3 changed files with 26 additions and 26 deletions.
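
A note on the shape of the change: every hunk below does one or both of two things — the `rs` parameter is renamed to `reqState`, and each `Store.OnChange()` call is guarded with `reqState.IsOwner` so it runs only on the peer that owns the rate-limit key. A minimal sketch of that guard, using illustrative stand-in types rather than the real gubernator declarations:

```go
package sketch

import "context"

// Illustrative stand-ins only; the real RateLimitReq, CacheItem and Store
// declarations live elsewhere in the repository.
type RateLimitReq struct{ UniqueKey string }
type CacheItem struct{ Value any }

// RateLimitReqState carries per-request ownership: IsOwner is true only on
// the peer that owns this rate-limit key.
type RateLimitReqState struct{ IsOwner bool }

type Store interface {
	OnChange(ctx context.Context, r *RateLimitReq, item *CacheItem)
}

// persistIfOwner mirrors the guard added throughout this commit: persist the
// updated bucket only when a store is configured and this peer is the owner.
func persistIfOwner(ctx context.Context, s Store, reqState RateLimitReqState, r *RateLimitReq, item *CacheItem) {
	if s != nil && reqState.IsOwner {
		s.OnChange(ctx, r, item)
	}
}
```

In the token-bucket and leaky-bucket paths the call is wrapped in a `defer`, but the condition is the same.
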
36 changes: 18 additions & 18 deletions algorithms.go
@@ -34,7 +34,7 @@ import (
// with 100 emails and the request will succeed. You can override this default behavior with `DRAIN_OVER_LIMIT`

// Implements token bucket algorithm for rate limiting. https://en.wikipedia.org/wiki/Token_bucket
-func tokenBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq, rs RateLimitReqState) (resp *RateLimitResp, err error) {
+func tokenBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq, reqState RateLimitReqState) (resp *RateLimitResp, err error) {
tokenBucketTimer := prometheus.NewTimer(metricFuncTimeDuration.WithLabelValues("tokenBucket"))
defer tokenBucketTimer.ObserveDuration()

@@ -99,7 +99,7 @@ func tokenBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq, rs Rate
s.Remove(ctx, hashKey)
}

-return tokenBucketNewItem(ctx, s, c, r, rs)
+return tokenBucketNewItem(ctx, s, c, r, reqState)
}

// Update the limit if it changed.
@@ -146,7 +146,7 @@ func tokenBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq, rs Rate
rl.ResetTime = expire
}

-if s != nil {
+if s != nil && reqState.IsOwner {
defer func() {
s.OnChange(ctx, r, item)
}()
@@ -161,7 +161,7 @@ func tokenBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq, rs Rate
// If we are already at the limit.
if rl.Remaining == 0 && r.Hits > 0 {
trace.SpanFromContext(ctx).AddEvent("Already over the limit")
-if rs.IsOwner {
+if reqState.IsOwner {
metricOverLimitCounter.Add(1)
}
rl.Status = Status_OVER_LIMIT
@@ -181,7 +181,7 @@ func tokenBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq, rs Rate
// without updating the cache.
if r.Hits > t.Remaining {
trace.SpanFromContext(ctx).AddEvent("Over the limit")
-if rs.IsOwner {
+if reqState.IsOwner {
metricOverLimitCounter.Add(1)
}
rl.Status = Status_OVER_LIMIT
@@ -199,11 +199,11 @@ func tokenBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq, rs Rate
}

// Item is not found in cache or store, create new.
-return tokenBucketNewItem(ctx, s, c, r, rs)
+return tokenBucketNewItem(ctx, s, c, r, reqState)
}

// Called by tokenBucket() when adding a new item in the store.
-func tokenBucketNewItem(ctx context.Context, s Store, c Cache, r *RateLimitReq, rs RateLimitReqState) (resp *RateLimitResp, err error) {
+func tokenBucketNewItem(ctx context.Context, s Store, c Cache, r *RateLimitReq, reqState RateLimitReqState) (resp *RateLimitResp, err error) {
requestTime := *r.RequestTime
expire := requestTime + r.Duration

@@ -239,7 +239,7 @@ func tokenBucketNewItem(ctx context.Context, s Store, c Cache, r *RateLimitReq,
// Client could be requesting that we always return OVER_LIMIT.
if r.Hits > r.Limit {
trace.SpanFromContext(ctx).AddEvent("Over the limit")
-if rs.IsOwner {
+if reqState.IsOwner {
metricOverLimitCounter.Add(1)
}
rl.Status = Status_OVER_LIMIT
@@ -249,15 +249,15 @@ func tokenBucketNewItem(ctx context.Context, s Store, c Cache, r *RateLimitReq,

c.Add(item)

-if s != nil {
+if s != nil && reqState.IsOwner {
s.OnChange(ctx, r, item)
}

return rl, nil
}

// Implements leaky bucket algorithm for rate limiting https://en.wikipedia.org/wiki/Leaky_bucket
-func leakyBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq, rs RateLimitReqState) (resp *RateLimitResp, err error) {
+func leakyBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq, reqState RateLimitReqState) (resp *RateLimitResp, err error) {
leakyBucketTimer := prometheus.NewTimer(metricFuncTimeDuration.WithLabelValues("V1Instance.getRateLimit_leakyBucket"))
defer leakyBucketTimer.ObserveDuration()

@@ -314,7 +314,7 @@ func leakyBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq, rs Rate
s.Remove(ctx, hashKey)
}

-return leakyBucketNewItem(ctx, s, c, r, rs)
+return leakyBucketNewItem(ctx, s, c, r, reqState)
}

if HasBehavior(r.Behavior, Behavior_RESET_REMAINING) {
@@ -379,15 +379,15 @@ func leakyBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq, rs Rate

// TODO: Feature missing: check for Duration change between item/request.

-if s != nil {
+if s != nil && reqState.IsOwner {
defer func() {
s.OnChange(ctx, r, item)
}()
}

// If we are already at the limit
if int64(b.Remaining) == 0 && r.Hits > 0 {
-if rs.IsOwner {
+if reqState.IsOwner {
metricOverLimitCounter.Add(1)
}
rl.Status = Status_OVER_LIMIT
@@ -405,7 +405,7 @@ func leakyBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq, rs Rate
// If requested is more than available, then return over the limit
// without updating the bucket, unless `DRAIN_OVER_LIMIT` is set.
if r.Hits > int64(b.Remaining) {
-if rs.IsOwner {
+if reqState.IsOwner {
metricOverLimitCounter.Add(1)
}
rl.Status = Status_OVER_LIMIT
@@ -430,11 +430,11 @@ func leakyBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq, rs Rate
return rl, nil
}

-return leakyBucketNewItem(ctx, s, c, r, rs)
+return leakyBucketNewItem(ctx, s, c, r, reqState)
}

// Called by leakyBucket() when adding a new item in the store.
-func leakyBucketNewItem(ctx context.Context, s Store, c Cache, r *RateLimitReq, rs RateLimitReqState) (resp *RateLimitResp, err error) {
+func leakyBucketNewItem(ctx context.Context, s Store, c Cache, r *RateLimitReq, reqState RateLimitReqState) (resp *RateLimitResp, err error) {
requestTime := *r.RequestTime
duration := r.Duration
rate := float64(duration) / float64(r.Limit)
@@ -467,7 +467,7 @@ func leakyBucketNewItem(ctx context.Context, s Store, c Cache, r *RateLimitReq,

// Client could be requesting that we start with the bucket OVER_LIMIT
if r.Hits > r.Burst {
-if rs.IsOwner {
+if reqState.IsOwner {
metricOverLimitCounter.Add(1)
}
rl.Status = Status_OVER_LIMIT
@@ -485,7 +485,7 @@ func leakyBucketNewItem(ctx context.Context, s Store, c Cache, r *RateLimitReq,

c.Add(item)

-if s != nil {
+if s != nil && reqState.IsOwner {
s.OnChange(ctx, r, item)
}

6 changes: 3 additions & 3 deletions gubernator.go
@@ -585,7 +585,7 @@ func (s *V1Instance) HealthCheck(ctx context.Context, r *HealthCheckReq) (health
return health, nil
}

-func (s *V1Instance) getLocalRateLimit(ctx context.Context, r *RateLimitReq, rs RateLimitReqState) (_ *RateLimitResp, err error) {
+func (s *V1Instance) getLocalRateLimit(ctx context.Context, r *RateLimitReq, reqState RateLimitReqState) (_ *RateLimitResp, err error) {
ctx = tracing.StartNamedScope(ctx, "V1Instance.getLocalRateLimit", trace.WithAttributes(
attribute.String("ratelimit.key", r.UniqueKey),
attribute.String("ratelimit.name", r.Name),
@@ -595,7 +595,7 @@ func (s *V1Instance) getLocalRateLimit(ctx context.Context, r *RateLimitReq, rs
defer func() { tracing.EndScope(ctx, err) }()
defer prometheus.NewTimer(metricFuncTimeDuration.WithLabelValues("V1Instance.getLocalRateLimit")).ObserveDuration()

-resp, err := s.workerPool.GetRateLimit(ctx, r, rs)
+resp, err := s.workerPool.GetRateLimit(ctx, r, reqState)
if err != nil {
return nil, errors.Wrap(err, "during workerPool.GetRateLimit")
}
@@ -605,7 +605,7 @@ func (s *V1Instance) getLocalRateLimit(ctx context.Context, r *RateLimitReq, rs
s.global.QueueUpdate(r)
}

-if rs.IsOwner {
+if reqState.IsOwner {
metricGetRateLimitCounter.WithLabelValues("local").Inc()
}
return resp, nil
10 changes: 5 additions & 5 deletions workers.go
@@ -258,7 +258,7 @@ func (p *WorkerPool) dispatch(worker *Worker) {
}

// GetRateLimit sends a GetRateLimit request to worker pool.
-func (p *WorkerPool) GetRateLimit(ctx context.Context, rlRequest *RateLimitReq, rs RateLimitReqState) (*RateLimitResp, error) {
+func (p *WorkerPool) GetRateLimit(ctx context.Context, rlRequest *RateLimitReq, reqState RateLimitReqState) (*RateLimitResp, error) {
// Delegate request to assigned channel based on request key.
worker := p.getWorker(rlRequest.HashKey())
queueGauge := metricWorkerQueue.WithLabelValues("GetRateLimit", worker.name)
@@ -268,7 +268,7 @@ func (p *WorkerPool) GetRateLimit(ctx context.Context, rlRequest *RateLimitReq,
ctx: ctx,
resp: make(chan *response, 1),
request: rlRequest,
-reqState: rs,
+reqState: reqState,
}

// Send request.
@@ -290,14 +290,14 @@ func (p *WorkerPool) GetRateLimit(ctx context.Context, rlRequest *RateLimitReq,
}

// Handle request received by worker.
-func (worker *Worker) handleGetRateLimit(ctx context.Context, req *RateLimitReq, rs RateLimitReqState, cache Cache) (*RateLimitResp, error) {
+func (worker *Worker) handleGetRateLimit(ctx context.Context, req *RateLimitReq, reqState RateLimitReqState, cache Cache) (*RateLimitResp, error) {
defer prometheus.NewTimer(metricFuncTimeDuration.WithLabelValues("Worker.handleGetRateLimit")).ObserveDuration()
var rlResponse *RateLimitResp
var err error

switch req.Algorithm {
case Algorithm_TOKEN_BUCKET:
-rlResponse, err = tokenBucket(ctx, worker.conf.Store, cache, req, rs)
+rlResponse, err = tokenBucket(ctx, worker.conf.Store, cache, req, reqState)
if err != nil {
msg := "Error in tokenBucket"
countError(err, msg)
Expand All @@ -306,7 +306,7 @@ func (worker *Worker) handleGetRateLimit(ctx context.Context, req *RateLimitReq,
}

case Algorithm_LEAKY_BUCKET:
-rlResponse, err = leakyBucket(ctx, worker.conf.Store, cache, req, rs)
+rlResponse, err = leakyBucket(ctx, worker.conf.Store, cache, req, reqState)
if err != nil {
msg := "Error in leakyBucket"
countError(err, msg)
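
For completeness, a hypothetical, self-contained demo of the behavior this commit establishes (all types are stand-ins, not the real API): every peer still updates its cache, but only the owner's path reaches `OnChange()`. The shared counting store here exists only to make that visible.

```go
package main

import (
	"context"
	"fmt"
)

// Stand-in types for the demo; not the real gubernator declarations.
type RateLimitReq struct{ UniqueKey string }
type CacheItem struct{ Remaining int64 }
type RateLimitReqState struct{ IsOwner bool }

// countingStore records how many times OnChange was invoked.
type countingStore struct{ onChangeCalls int }

func (s *countingStore) OnChange(ctx context.Context, r *RateLimitReq, item *CacheItem) {
	s.onChangeCalls++
}

// applyHit mimics the guarded persistence added in this commit: the cache is
// always updated, but only the owner notifies the store.
func applyHit(ctx context.Context, s *countingStore, reqState RateLimitReqState, r *RateLimitReq, item *CacheItem) {
	item.Remaining--
	if s != nil && reqState.IsOwner {
		s.OnChange(ctx, r, item)
	}
}

func main() {
	ctx := context.Background()
	store := &countingStore{}
	req := &RateLimitReq{UniqueKey: "account:42"}
	item := &CacheItem{Remaining: 100}

	applyHit(ctx, store, RateLimitReqState{IsOwner: true}, req, item)  // owner peer
	applyHit(ctx, store, RateLimitReqState{IsOwner: false}, req, item) // non-owner peer

	fmt.Println("OnChange calls:", store.onChangeCalls) // prints 1, not 2
}
```
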
