Skip to content

Commit

Permalink
rename headers, test rate with cost>1
Browse files Browse the repository at this point in the history
  • Loading branch information
klaidliadon committed Jul 11, 2024
1 parent f83fe63 commit e910cd1
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 53 deletions.
6 changes: 3 additions & 3 deletions common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ func newConfig() quotacontrol.Config {
},
RateLimiter: middleware.RLConfig{
Enabled: true,
PublicRPM: 10,
AccountRPM: 100,
ServiceRPM: 1000,
PublicRPM: 100,
AccountRPM: 1000,
ServiceRPM: 10000,
},
}
}
Expand Down
55 changes: 28 additions & 27 deletions handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,13 @@ func TestMiddlewareUseAccessKey(t *testing.T) {
key := proto.GenerateAccessKey(project)
service := proto.Service_Indexer

const _credits = 10
limit := proto.Limit{
RateLimit: 100,
FreeWarn: 5,
FreeMax: 5,
OverWarn: 7,
OverMax: 10,
RateLimit: _credits * 100,
FreeWarn: _credits * 5,
FreeMax: _credits * 5,
OverWarn: _credits * 7,
OverMax: _credits * 10,
}

ctx := context.Background()
Expand All @@ -53,8 +54,8 @@ func TestMiddlewareUseAccessKey(t *testing.T) {
r := chi.NewRouter()
r.Use(
middleware.Session(client, auth),
addCredits(2).Middleware,
addCredits(-1).Middleware,
addCredits(_credits*2).Middleware,
addCredits(_credits*-1).Middleware,
middleware.RateLimit(cfg.RateLimiter, cfg.Redis),
middleware.SpendUsage(client),
)
Expand All @@ -70,15 +71,15 @@ func TestMiddlewareUseAccessKey(t *testing.T) {
server.FlushNotifications()

// Spend Free CU
for i := int64(1); i < limit.FreeWarn; i++ {
for i := int64(_credits); i < limit.FreeWarn; i += _credits {
ok, headers, err := executeRequest(ctx, r, "", key, "")
assert.NoError(t, err)
assert.True(t, ok)
assert.Equal(t, strconv.FormatInt(limit.FreeMax, 10), headers.Get(middleware.HeaderQuotaLimit))
assert.Equal(t, strconv.FormatInt(limit.FreeMax-i, 10), headers.Get(middleware.HeaderQuotaRemaining))
assert.Equal(t, "", headers.Get(middleware.HeaderQuotaOverage))
assert.Empty(t, server.GetEvents(project), i)
expectedUsage.Add(proto.AccessUsage{ValidCompute: 1})
expectedUsage.Add(proto.AccessUsage{ValidCompute: _credits})
}

// Go over free CU
Expand All @@ -89,18 +90,18 @@ func TestMiddlewareUseAccessKey(t *testing.T) {
assert.Equal(t, "0", headers.Get(middleware.HeaderQuotaRemaining))
assert.Equal(t, "", headers.Get(middleware.HeaderQuotaOverage))
assert.Contains(t, server.GetEvents(project), proto.EventType_FreeMax)
expectedUsage.Add(proto.AccessUsage{ValidCompute: 1})
expectedUsage.Add(proto.AccessUsage{ValidCompute: _credits})

// Get close to soft quota
for i := limit.FreeWarn + 1; i < limit.OverWarn; i++ {
for i := limit.FreeWarn + _credits; i < limit.OverWarn; i += _credits {
ok, headers, err := executeRequest(ctx, r, "", key, "")
assert.NoError(t, err)
assert.True(t, ok)
assert.Equal(t, strconv.FormatInt(limit.FreeMax, 10), headers.Get(middleware.HeaderQuotaLimit))
assert.Equal(t, "0", headers.Get(middleware.HeaderQuotaRemaining))
assert.Equal(t, strconv.FormatInt(i-limit.FreeWarn, 10), headers.Get(middleware.HeaderQuotaOverage))
assert.Len(t, server.GetEvents(project), 1)
expectedUsage.Add(proto.AccessUsage{OverCompute: 1})
expectedUsage.Add(proto.AccessUsage{OverCompute: _credits})
}

// Go over soft quota
Expand All @@ -111,18 +112,18 @@ func TestMiddlewareUseAccessKey(t *testing.T) {
assert.Equal(t, "0", headers.Get(middleware.HeaderQuotaRemaining))
assert.Equal(t, strconv.FormatInt(limit.OverWarn-limit.FreeWarn, 10), headers.Get(middleware.HeaderQuotaOverage))
assert.Contains(t, server.GetEvents(project), proto.EventType_OverWarn)
expectedUsage.Add(proto.AccessUsage{OverCompute: 1})
expectedUsage.Add(proto.AccessUsage{OverCompute: _credits})

// Get close to hard quota
for i := limit.OverWarn + 1; i < limit.OverMax; i++ {
for i := limit.OverWarn + _credits; i < limit.OverMax; i += _credits {
ok, headers, err := executeRequest(ctx, r, "", key, "")
assert.NoError(t, err)
assert.True(t, ok)
assert.Equal(t, strconv.FormatInt(limit.FreeMax, 10), headers.Get(middleware.HeaderQuotaLimit))
assert.Equal(t, "0", headers.Get(middleware.HeaderQuotaRemaining))
assert.Equal(t, strconv.FormatInt(i-limit.FreeWarn, 10), headers.Get(middleware.HeaderQuotaOverage))
assert.Len(t, server.GetEvents(project), 2)
expectedUsage.Add(proto.AccessUsage{OverCompute: 1})
expectedUsage.Add(proto.AccessUsage{OverCompute: _credits})
}

// Go over hard quota
Expand All @@ -133,7 +134,7 @@ func TestMiddlewareUseAccessKey(t *testing.T) {
assert.Equal(t, "0", headers.Get(middleware.HeaderQuotaRemaining))
assert.Equal(t, strconv.FormatInt(limit.OverMax-limit.FreeWarn, 10), headers.Get(middleware.HeaderQuotaOverage))
assert.Contains(t, server.GetEvents(project), proto.EventType_OverMax)
expectedUsage.Add(proto.AccessUsage{OverCompute: 1})
expectedUsage.Add(proto.AccessUsage{OverCompute: _credits})

// Denied
for i := 0; i < 10; i++ {
Expand All @@ -143,23 +144,23 @@ func TestMiddlewareUseAccessKey(t *testing.T) {
assert.Equal(t, strconv.FormatInt(limit.FreeMax, 10), headers.Get(middleware.HeaderQuotaLimit))
assert.Equal(t, "0", headers.Get(middleware.HeaderQuotaRemaining))
assert.Equal(t, strconv.FormatInt(limit.OverMax-limit.FreeWarn, 10), headers.Get(middleware.HeaderQuotaOverage))
expectedUsage.Add(proto.AccessUsage{LimitedCompute: 1})
expectedUsage.Add(proto.AccessUsage{LimitedCompute: _credits})
}

// check the usage
client.Stop(context.Background())
usage, err := server.Store.GetAccountUsage(ctx, project, proto.Ptr(service), now.Add(-time.Hour), now.Add(time.Hour))
assert.NoError(t, err)
assert.Equal(t, int64(expectedUsage.GetTotalUsage()), counter.GetValue())
assert.Equal(t, int64(expectedUsage.GetTotalUsage()), _credits*counter.GetValue())
assert.Equal(t, &expectedUsage, &usage)
})

t.Run("ChangeLimits", func(t *testing.T) {
// Increase CreditsOverageLimit which should still allow requests to go through, etc.
err = server.Store.SetAccessLimit(ctx, project, &proto.Limit{
RateLimit: 100,
OverWarn: 5,
OverMax: 110,
RateLimit: _credits * 100,
OverWarn: _credits * 5,
OverMax: _credits * 110,
})
assert.NoError(t, err)
err = client.ClearQuotaCacheByAccessKey(ctx, key)
Expand All @@ -178,8 +179,8 @@ func TestMiddlewareUseAccessKey(t *testing.T) {
client.Stop(context.Background())
usage, err := server.Store.GetAccountUsage(ctx, project, proto.Ptr(service), now.Add(-time.Hour), now.Add(time.Hour))
assert.NoError(t, err)
expectedUsage.Add(proto.AccessUsage{ValidCompute: 0, OverCompute: 1, LimitedCompute: 0})
assert.Equal(t, int64(expectedUsage.GetTotalUsage()), counter.GetValue())
expectedUsage.Add(proto.AccessUsage{ValidCompute: 0, OverCompute: _credits, LimitedCompute: 0})
assert.Equal(t, int64(expectedUsage.GetTotalUsage()), _credits*counter.GetValue())
assert.Equal(t, &expectedUsage, &usage)
})

Expand All @@ -188,7 +189,7 @@ func TestMiddlewareUseAccessKey(t *testing.T) {

ctx := middleware.WithTime(context.Background(), now)

for i, max := 0, cfg.RateLimiter.PublicRPM*2; i < max; i++ {
for i, max := 0, cfg.RateLimiter.PublicRPM*2; i < max; i += _credits {
ok, headers, err := executeRequest(ctx, r, "", "", "")
if i < cfg.RateLimiter.PublicRPM {
assert.NoError(t, err, i)
Expand All @@ -204,7 +205,7 @@ func TestMiddlewareUseAccessKey(t *testing.T) {
client.Stop(context.Background())
usage, err := server.Store.GetAccountUsage(ctx, project, proto.Ptr(service), now.Add(-time.Hour), now.Add(time.Hour))
assert.NoError(t, err)
assert.Equal(t, int64(expectedUsage.GetTotalUsage()), counter.GetValue())
assert.Equal(t, int64(expectedUsage.GetTotalUsage()), _credits*counter.GetValue())
assert.Equal(t, &expectedUsage, &usage)
})

Expand Down Expand Up @@ -244,7 +245,7 @@ func TestMiddlewareUseAccessKey(t *testing.T) {
client.Stop(context.Background())
usage, err := server.Store.GetAccountUsage(ctx, project, proto.Ptr(service), now.Add(-time.Hour), now.Add(time.Hour))
assert.NoError(t, err)
assert.Equal(t, int64(expectedUsage.GetTotalUsage()), counter.GetValue())
assert.Equal(t, int64(expectedUsage.GetTotalUsage()), _credits*counter.GetValue())
assert.Equal(t, &expectedUsage, &usage)
})

Expand All @@ -264,7 +265,7 @@ func TestMiddlewareUseAccessKey(t *testing.T) {
client.Stop(context.Background())
usage, err := server.Store.GetAccountUsage(ctx, project, proto.Ptr(service), now.Add(-time.Hour), now.Add(time.Hour))
assert.NoError(t, err)
assert.Equal(t, int64(expectedUsage.GetTotalUsage()), counter.GetValue())
assert.Equal(t, int64(expectedUsage.GetTotalUsage()), _credits*counter.GetValue())
assert.Equal(t, &expectedUsage, &usage)
})
}
Expand Down
18 changes: 12 additions & 6 deletions middleware/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,19 @@ import (
"time"

"github.com/0xsequence/quotacontrol/proto"
"github.com/go-chi/httprate"
)

const (
HeaderAccessKey = "X-Access-Key"
HeaderOrigin = "Origin"
HeaderQuotaLimit = "Quota-Limit"
HeaderQuotaRemaining = "Quota-Remaining"
HeaderQuotaOverage = "Quota-Overage"
HeaderAccessKey = "X-Access-Key"
HeaderOrigin = "Origin"
HeaderQuotaLimit = "Quota-Limit"
HeaderQuotaRemaining = "Quota-Remaining"
HeaderQuotaOverage = "Quota-Overage"
HeaderQuotaCost = "Quota-Cost"
HeaderQuotaRateRemaining = "Quota-Rate-Remaining"
HeaderQuotaRateLimit = "Quota-Rate-Limit"
HeaderQuotaRateReset = "Quota-Rate-Reset"
)

// Client is the interface that wraps the basic FetchKeyQuota, GetUsage and SpendQuota methods.
Expand Down Expand Up @@ -202,8 +207,9 @@ func getProjectID(ctx context.Context) (uint64, bool) {
return v, ok
}

// WithComputeUnits sets the compute units.
// WithComputeUnits sets the compute units and rate limit increment to the context.
func WithComputeUnits(ctx context.Context, cu int64) context.Context {
ctx = httprate.WithIncrement(ctx, int(cu))
return context.WithValue(ctx, ctxKeyComputeUnits, cu)
}

Expand Down
52 changes: 35 additions & 17 deletions middleware/middleware_ratelimit.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package middleware

import (
"cmp"
"context"
"net/http"
"time"

Expand All @@ -21,16 +22,29 @@ type RLConfig struct {
ErrorMsg string `toml:"error_message"`
}

func (r RLConfig) getRateLimit(ctx context.Context) int {
if _, ok := GetService(ctx); ok {
return r.ServiceRPM
}
if q, ok := GetAccessQuota(ctx); ok {
return int(q.Limit.RateLimit)
}
if _, ok := GetAccount(ctx); ok {
return r.AccountRPM
}
return r.PublicRPM
}

func RateLimit(rlCfg RLConfig, redisCfg redis.Config) func(next http.Handler) http.Handler {
if !rlCfg.Enabled {
return func(next http.Handler) http.Handler {
return next
}
}

defaultRPM := cmp.Or(rlCfg.PublicRPM, 120)
accountRPM := cmp.Or(rlCfg.AccountRPM, 4000)
serviceRPM := cmp.Or(rlCfg.ServiceRPM, 0)
rlCfg.PublicRPM = cmp.Or(rlCfg.PublicRPM, 1000)
rlCfg.AccountRPM = cmp.Or(rlCfg.AccountRPM, 4000)
rlCfg.ServiceRPM = cmp.Or(rlCfg.ServiceRPM, 0)

var limitCounter httprate.LimitCounter
if redisCfg.Enabled {
Expand Down Expand Up @@ -61,25 +75,29 @@ func RateLimit(rlCfg RLConfig, redisCfg redis.Config) func(next http.Handler) ht
httprate.WithLimitHandler(proto.ErrLimitExceeded.WithMessage(rlCfg.ErrorMsg).Handler),
}

limiter := httprate.NewRateLimiter(defaultRPM, _RateLimitWindow, options...)
limiter := httprate.NewRateLimiter(rlCfg.PublicRPM, _RateLimitWindow, options...)

// The rate limiter middleware
return func(next http.Handler) http.Handler {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
h := w.Header()
swapHeader(h, "X-RateLimit-Limit", HeaderQuotaRateLimit)
swapHeader(h, "X-RateLimit-Remaining", HeaderQuotaRateRemaining)
swapHeader(h, "X-RateLimit-Reset", HeaderQuotaRateReset)
next.ServeHTTP(w, r)
})
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()

if _, ok := GetService(ctx); ok {
ctx = httprate.WithRequestLimit(ctx, serviceRPM)
} else if q, ok := GetAccessQuota(ctx); ok {
if cu, ok := getComputeUnits(ctx); ok {
ctx = httprate.WithIncrement(ctx, int(cu))
}
ctx = httprate.WithRequestLimit(ctx, int(q.Limit.RateLimit))
} else if _, ok := GetAccount(ctx); ok {
ctx = httprate.WithRequestLimit(ctx, accountRPM)
}

limiter.Handler(next).ServeHTTP(w, r.WithContext(ctx))
ctx = httprate.WithRequestLimit(ctx, rlCfg.getRateLimit(ctx))
limiter.Handler(handler).ServeHTTP(w, r.WithContext(ctx))
})
}
}

// swapHeader swaps the header from one key to another.
func swapHeader(h http.Header, from, to string) {
if v := h.Get(from); v != "" {
h.Set(to, v)
h.Del(from)
}
}
2 changes: 2 additions & 0 deletions middleware/middleware_usage.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ func EnsureUsage(client Client) func(next http.Handler) http.Handler {
next.ServeHTTP(w, r)
return
}
w.Header().Set(HeaderQuotaCost, strconv.FormatInt(cu, 10))

usage, err := client.FetchUsage(ctx, quota, GetTime(ctx))
if err != nil {
Expand Down Expand Up @@ -73,6 +74,7 @@ func SpendUsage(client Client) func(next http.Handler) http.Handler {
next.ServeHTTP(w, r)
return
}
w.Header().Set(HeaderQuotaCost, strconv.FormatInt(cu, 10))

ok, total, err := client.SpendQuota(ctx, quota, cu, GetTime(ctx))
if err != nil && !errors.Is(err, proto.ErrLimitExceeded) {
Expand Down

0 comments on commit e910cd1

Please sign in to comment.