diff --git a/algorithm/all/all.go b/algorithm/all/all.go new file mode 100644 index 0000000..bb068f8 --- /dev/null +++ b/algorithm/all/all.go @@ -0,0 +1,7 @@ +package all + +import ( + _ "github.com/bringg/go_redis_ratelimit/algorithm/cloudflare" + _ "github.com/bringg/go_redis_ratelimit/algorithm/gcra" + _ "github.com/bringg/go_redis_ratelimit/algorithm/sliding_window" +) diff --git a/algorithm/cloudflare/cloudflare.go b/algorithm/cloudflare/cloudflare.go index 47adcf2..2b8c934 100644 --- a/algorithm/cloudflare/cloudflare.go +++ b/algorithm/cloudflare/cloudflare.go @@ -11,51 +11,58 @@ import ( const AlgorithmName = "cloudflare" type Cloudflare struct { - Limit algorithm.Limit - RDB algorithm.Rediser + RDB algorithm.Rediser +} + +func init() { + algorithm.Register(&algorithm.RegInfo{ + Name: AlgorithmName, + NewAlgorithm: NewAlgorithm, + }) +} - key string +func NewAlgorithm(rdb algorithm.Rediser) (algorithm.Algorithm, error) { + return &Cloudflare{ + RDB: rdb, + }, nil } -func (c *Cloudflare) Allow() (*algorithm.Result, error) { - rate := c.Limit.GetRate() - 1 +func (c *Cloudflare) Allow(key string, limit algorithm.Limit) (*algorithm.Result, error) { + rate := limit.GetRate() - 1 rateLimiter := ratelimiter.New(&RedisDataStore{ - RDB: c.RDB, - }, rate, c.Limit.GetPeriod()) + RDB: c.RDB, + ExpirationTime: 2 * limit.GetPeriod(), + }, rate, limit.GetPeriod()) - limitStatus, err := rateLimiter.Check(c.key) + limitStatus, err := rateLimiter.Check(key) if err != nil { return nil, err } - rateKey := mapKey(c.key, time.Now().UTC().Truncate(c.Limit.GetPeriod())) + rateKey := mapKey(key, time.Now().UTC().Truncate(limit.GetPeriod())) currentRate := int64(limitStatus.CurrentRate) if limitStatus.IsLimited { return &algorithm.Result{ - Limit: c.Limit, + Limit: limit, Key: rateKey, Allowed: false, Remaining: 0, RetryAfter: *limitStatus.LimitDuration, - ResetAfter: c.Limit.GetPeriod(), + ResetAfter: limit.GetPeriod(), }, nil } - if err := rateLimiter.Inc(c.key); err != nil { + if err := rateLimiter.Inc(key); err != nil { return nil, err } return &algorithm.Result{ - Limit: c.Limit, + Limit: limit, Key: rateKey, Allowed: true, Remaining: rate - currentRate, RetryAfter: 0, - ResetAfter: c.Limit.GetPeriod(), + ResetAfter: limit.GetPeriod(), }, nil } - -func (c *Cloudflare) SetKey(key string) { - c.key = key -} diff --git a/algorithm/cloudflare/redis_store.go b/algorithm/cloudflare/redis_store.go index 8df4603..ce6e2f3 100644 --- a/algorithm/cloudflare/redis_store.go +++ b/algorithm/cloudflare/redis_store.go @@ -11,7 +11,8 @@ import ( ) type RedisDataStore struct { - RDB algorithm.Rediser + RDB algorithm.Rediser + ExpirationTime time.Duration } func (s *RedisDataStore) Inc(key string, window time.Time) error { @@ -20,7 +21,7 @@ func (s *RedisDataStore) Inc(key string, window time.Time) error { if _, err := s.RDB.TxPipelined(ctx, func(pipe redis.Pipeliner) error { pipe.Incr(ctx, key) - pipe.Expire(ctx, key, time.Since(window)+time.Second) + pipe.Expire(ctx, key, s.ExpirationTime) return nil }); err != nil { diff --git a/algorithm/gcra/gcra.go b/algorithm/gcra/gcra.go index da952aa..93bae7b 100644 --- a/algorithm/gcra/gcra.go +++ b/algorithm/gcra/gcra.go @@ -12,33 +12,38 @@ import ( const AlgorithmName = "gcra" type GCRA struct { - Limit algorithm.Limit - RDB algorithm.Rediser + limiter *redis_rate.Limiter +} - key string +func init() { + algorithm.Register(&algorithm.RegInfo{ + Name: AlgorithmName, + NewAlgorithm: NewAlgorithm, + }) } -func (c *GCRA) Allow() (*algorithm.Result, error) { - res, err := redis_rate.NewLimiter(c.RDB).Allow(context.Background(), c.key, redis_rate.Limit{ - Rate: int(c.Limit.GetRate()), - Period: c.Limit.GetPeriod(), - Burst: int(c.Limit.GetBurst()), +func NewAlgorithm(rdb algorithm.Rediser) (algorithm.Algorithm, error) { + return &GCRA{ + limiter: redis_rate.NewLimiter(rdb), + }, nil +} + +func (c *GCRA) Allow(key string, limit algorithm.Limit) (*algorithm.Result, error) { + res, err := c.limiter.Allow(context.Background(), key, redis_rate.Limit{ + Rate: int(limit.GetRate()), + Period: limit.GetPeriod(), + Burst: int(limit.GetBurst()), }) if err != nil { return nil, err } return &algorithm.Result{ - Limit: c.Limit, - Key: c.key, + Limit: limit, + Key: key, Allowed: res.Allowed == 1, Remaining: int64(res.Remaining), RetryAfter: res.RetryAfter, ResetAfter: res.ResetAfter, }, nil } - -// SetKey _ -func (c *GCRA) SetKey(key string) { - c.key = key -} diff --git a/algorithm/gcra/gcra_lua.go b/algorithm/gcra/gcra_lua.go deleted file mode 100644 index 7214468..0000000 --- a/algorithm/gcra/gcra_lua.go +++ /dev/null @@ -1,64 +0,0 @@ -package gcra - -import "github.com/go-redis/redis/v8" - -// Copyright (c) 2017 Pavel Pravosud -// https://github.com/rwz/redis-gcra/blob/master/vendor/perform_gcra_ratelimit.lua -var script = redis.NewScript(` --- this script has side-effects, so it requires replicate commands mode -redis.replicate_commands() - -local rate_limit_key = KEYS[1] -local burst = ARGV[1] -local rate = ARGV[2] -local period = ARGV[3] -local cost = ARGV[4] - -local emission_interval = period / rate -local increment = emission_interval * cost -local burst_offset = emission_interval * burst -local now = redis.call("TIME") - --- redis returns time as an array containing two integers: seconds of the epoch --- time (10 digits) and microseconds (6 digits). for convenience we need to --- convert them to a floating point number. the resulting number is 16 digits, --- bordering on the limits of a 64-bit double-precision floating point number. --- adjust the epoch to be relative to Jan 1, 2017 00:00:00 GMT to avoid floating --- point problems. this approach is good until "now" is 2,483,228,799 (Wed, 09 --- Sep 2048 01:46:39 GMT), when the adjusted value is 16 digits. -local jan_1_2017 = 1483228800 -now = (now[1] - jan_1_2017) + (now[2] / 1000000) - -local tat = redis.call("GET", rate_limit_key) - -if not tat then - tat = now -else - tat = tonumber(tat) -end - -local new_tat = math.max(tat, now) + increment - -local allow_at = new_tat - burst_offset -local diff = now - allow_at - -local limited -local retry_after -local reset_after - -local remaining = math.floor(diff / emission_interval + 0.5) -- poor man's round - -if remaining < 0 then - limited = 1 - remaining = 0 - reset_after = tat - now - retry_after = diff * -1 -else - limited = 0 - reset_after = new_tat - now - redis.call("SET", rate_limit_key, new_tat, "EX", math.ceil(reset_after)) - retry_after = -1 -end - -return {limited, remaining, tostring(retry_after), tostring(reset_after)} -`) diff --git a/algorithm/model.go b/algorithm/model.go index 9ee660c..39c90c3 100644 --- a/algorithm/model.go +++ b/algorithm/model.go @@ -7,29 +7,35 @@ import ( "github.com/go-redis/redis/v8" ) +var Registry []*RegInfo + type ( Limit interface { - GetAlgorithm() string - GetBurst() int64 GetRate() int64 + GetBurst() int64 + GetAlgorithm() string GetPeriod() time.Duration } Rediser interface { TxPipeline() redis.Pipeliner - TxPipelined(ctx context.Context, fn func(pipe redis.Pipeliner) error) ([]redis.Cmder, error) - Del(ctx context.Context, keys ...string) *redis.IntCmd - Get(ctx context.Context, key string) *redis.StringCmd Incr(ctx context.Context, key string) *redis.IntCmd - Eval(ctx context.Context, script string, keys []string, args ...interface{}) *redis.Cmd - EvalSha(ctx context.Context, sha1 string, keys []string, args ...interface{}) *redis.Cmd - ScriptExists(ctx context.Context, hashes ...string) *redis.BoolSliceCmd - ScriptLoad(ctx context.Context, script string) *redis.StringCmd - ZRangeByScoreWithScores(ctx context.Context, key string, opt *redis.ZRangeBy) *redis.ZSliceCmd - ZRemRangeByScore(ctx context.Context, key string, min string, max string) *redis.IntCmd ZCard(ctx context.Context, key string) *redis.IntCmd + Get(ctx context.Context, key string) *redis.StringCmd + Del(ctx context.Context, keys ...string) *redis.IntCmd + ScriptLoad(ctx context.Context, script string) *redis.StringCmd + ScriptExists(ctx context.Context, hashes ...string) *redis.BoolSliceCmd ZAdd(ctx context.Context, key string, members ...*redis.Z) *redis.IntCmd Expire(ctx context.Context, key string, expiration time.Duration) *redis.BoolCmd + ZRemRangeByScore(ctx context.Context, key string, min string, max string) *redis.IntCmd + Eval(ctx context.Context, script string, keys []string, args ...interface{}) *redis.Cmd + EvalSha(ctx context.Context, sha1 string, keys []string, args ...interface{}) *redis.Cmd + TxPipelined(ctx context.Context, fn func(pipe redis.Pipeliner) error) ([]redis.Cmder, error) + ZRangeByScoreWithScores(ctx context.Context, key string, opt *redis.ZRangeBy) *redis.ZSliceCmd + } + + Algorithm interface { + Allow(key string, limit Limit) (*Result, error) } Result struct { @@ -60,4 +66,13 @@ type ( // until Limit and Remaining will be equal. ResetAfter time.Duration } + + RegInfo struct { + Name string + NewAlgorithm func(rdb Rediser) (Algorithm, error) + } ) + +func Register(info *RegInfo) { + Registry = append(Registry, info) +} diff --git a/algorithm/sliding_window/sliding_window.go b/algorithm/sliding_window/sliding_window.go index 373cea4..56574c5 100644 --- a/algorithm/sliding_window/sliding_window.go +++ b/algorithm/sliding_window/sliding_window.go @@ -11,21 +11,26 @@ import ( const AlgorithmName = "sliding_window" type SlidingWindow struct { - Limit algorithm.Limit - RDB algorithm.Rediser + RDB algorithm.Rediser +} - key string +func init() { + algorithm.Register(&algorithm.RegInfo{ + Name: AlgorithmName, + NewAlgorithm: NewAlgorithm, + }) } -func (c *SlidingWindow) SetKey(key string) { - c.key = key +func NewAlgorithm(rdb algorithm.Rediser) (algorithm.Algorithm, error) { + return &SlidingWindow{ + RDB: rdb, + }, nil } -func (c *SlidingWindow) Allow() (r *algorithm.Result, err error) { - limit := c.Limit +func (c *SlidingWindow) Allow(key string, limit algorithm.Limit) (r *algorithm.Result, err error) { values := []interface{}{limit.GetRate(), limit.GetPeriod().Seconds()} - v, err := script2.Run(context.Background(), c.RDB, []string{c.key}, values...).Result() + v, err := script.Run(context.Background(), c.RDB, []string{key}, values...).Result() if err != nil { return nil, err } @@ -39,7 +44,7 @@ func (c *SlidingWindow) Allow() (r *algorithm.Result, err error) { return &algorithm.Result{ Limit: limit, - Key: c.key, + Key: key, Allowed: values[0].(int64) == 1, Remaining: values[1].(int64), RetryAfter: dur(retryAfter), diff --git a/algorithm/sliding_window/sliding_window_lua.go b/algorithm/sliding_window/sliding_window_lua.go index 4c4d191..8cc1e44 100644 --- a/algorithm/sliding_window/sliding_window_lua.go +++ b/algorithm/sliding_window/sliding_window_lua.go @@ -2,7 +2,7 @@ package sliding_window import "github.com/go-redis/redis/v8" -var script2 = redis.NewScript(` +var script = redis.NewScript(` -- this script has side-effects, so it requires replicate commands mode redis.replicate_commands() diff --git a/examples/cloudflare/main.go b/examples/cloudflare/main.go index 1db855e..ebbe66b 100644 --- a/examples/cloudflare/main.go +++ b/examples/cloudflare/main.go @@ -6,7 +6,7 @@ import ( "github.com/go-redis/redis/v8" - "github.com/bringg/go_redis_ratelimit" + limiter "github.com/bringg/go_redis_ratelimit" "github.com/bringg/go_redis_ratelimit/algorithm/cloudflare" ) @@ -15,10 +15,14 @@ func main() { if err != nil { log.Fatal(err) } + client := redis.NewClient(option) + l, err := limiter.NewLimiter(client) + if err != nil { + log.Fatal(err) + } - limiter := go_redis_ratelimit.NewLimiter(client) - res, err := limiter.Allow("api_gateway:klu4ik", &go_redis_ratelimit.Limit{ + res, err := l.Allow("api_gateway:klu4ik", &limiter.Limit{ Algorithm: cloudflare.AlgorithmName, Rate: 10, Period: 10 * time.Second, diff --git a/examples/gcra/main.go b/examples/gcra/main.go index 2eb03dc..fa93d6f 100644 --- a/examples/gcra/main.go +++ b/examples/gcra/main.go @@ -17,7 +17,11 @@ func main() { } client := redis.NewClient(option) - limiter := go_redis_ratelimit.NewLimiter(client) + limiter, err := go_redis_ratelimit.NewLimiter(client) + if err != nil { + log.Fatal(err) + } + res, err := limiter.Allow("api_gateway:klu4ik", &go_redis_ratelimit.Limit{ Algorithm: gcra.AlgorithmName, Rate: 10, diff --git a/examples/sliding_window/main.go b/examples/sliding_window/main.go index 4a221a0..57b07c7 100644 --- a/examples/sliding_window/main.go +++ b/examples/sliding_window/main.go @@ -17,7 +17,11 @@ func main() { } client := redis.NewClient(option) - limiter := go_redis_ratelimit.NewLimiter(client) + limiter, err := go_redis_ratelimit.NewLimiter(client) + if err != nil { + log.Fatal(err) + } + res, err := limiter.Allow("api_gateway:klu4ik", &go_redis_ratelimit.Limit{ Algorithm: sliding_window.AlgorithmName, Rate: 10, diff --git a/go.mod b/go.mod index 45cea42..0ed29b9 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/bringg/go_redis_ratelimit -go 1.15 +go 1.16 require ( github.com/alicebob/miniredis/v2 v2.14.3 @@ -9,6 +9,5 @@ require ( github.com/go-redis/redis/v8 v8.8.0 github.com/go-redis/redis_rate/v9 v9.1.1 github.com/stretchr/testify v1.7.0 - google.golang.org/protobuf v1.26.0 // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect ) diff --git a/go.sum b/go.sum index c983116..71e7d2d 100644 --- a/go.sum +++ b/go.sum @@ -29,8 +29,6 @@ github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrU github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= -github.com/golang/protobuf v1.5.0 h1:LUVKkCeviFUMKqHa4tXIIij/lbhnMbP7Fn5wKdKkRh4= -github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= @@ -111,7 +109,6 @@ golang.org/x/tools v0.0.0-20201224043029-2b0845dc783e/go.mod h1:emZCQorbCU4vsT4f golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= @@ -119,9 +116,6 @@ google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQ google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= -google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= -google.golang.org/protobuf v1.26.0 h1:bxAC2xTBsZGibn2RTntX0oH50xLsqy1OxA9tTL3p/lk= -google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= diff --git a/limit.go b/limit.go new file mode 100644 index 0000000..2743d9b --- /dev/null +++ b/limit.go @@ -0,0 +1,26 @@ +package go_redis_ratelimit + +import "time" + +type Limit struct { + Algorithm string + Burst int64 + Rate int64 + Period time.Duration +} + +func (l *Limit) GetAlgorithm() string { + return l.Algorithm +} + +func (l *Limit) GetBurst() int64 { + return l.Burst +} + +func (l *Limit) GetRate() int64 { + return l.Rate +} + +func (l *Limit) GetPeriod() time.Duration { + return l.Period +} diff --git a/limiter.go b/limiter.go new file mode 100644 index 0000000..baff76f --- /dev/null +++ b/limiter.go @@ -0,0 +1,55 @@ +package go_redis_ratelimit + +import ( + "errors" + + "github.com/go-redis/redis/v8" + + "github.com/bringg/go_redis_ratelimit/algorithm" + _ "github.com/bringg/go_redis_ratelimit/algorithm/all" +) + +const DefaultPrefix = "limiter" + +// Limiter controls how frequently events are allowed to happen. +type Limiter struct { + Prefix string + algorithms map[string]algorithm.Algorithm +} + +// NewLimiter returns a new Limiter. +func NewLimiter(rdb *redis.Client) (*Limiter, error) { + algorithms := make(map[string]algorithm.Algorithm, len(algorithm.Registry)) + + var err error + for _, info := range algorithm.Registry { + if algorithms[info.Name], err = info.NewAlgorithm(rdb); err != nil { + return nil, err + } + } + + return &Limiter{ + Prefix: DefaultPrefix, + algorithms: algorithms, + }, nil +} + +func (l *Limiter) Allow(key string, limit *Limit) (*algorithm.Result, error) { + algo, err := l.findAlgorithm(limit.Algorithm) + if err != nil { + return nil, err + } + + return algo.Allow( + l.Prefix+":"+limit.Algorithm+":"+key, + limit, + ) +} + +func (l *Limiter) findAlgorithm(name string) (algorithm.Algorithm, error) { + if algo, ok := l.algorithms[name]; ok { + return algo, nil + } + + return nil, errors.New("algorithm is not supported") +} diff --git a/rate_test.go b/limiter_test.go similarity index 79% rename from rate_test.go rename to limiter_test.go index a5a9020..7e11aed 100644 --- a/rate_test.go +++ b/limiter_test.go @@ -12,9 +12,10 @@ import ( "github.com/bringg/go_redis_ratelimit/algorithm/cloudflare" "github.com/bringg/go_redis_ratelimit/algorithm/gcra" "github.com/bringg/go_redis_ratelimit/algorithm/sliding_window" - swv2 "github.com/bringg/go_redis_ratelimit/algorithm/sliding_window/v2" ) +var limiter = rateLimiter() + func rateLimiter() *Limiter { mr, err := miniredis.Run() if err != nil { @@ -29,7 +30,12 @@ func rateLimiter() *Limiter { panic(err) } - return NewLimiter(client) + l, err := NewLimiter(client) + if err != nil { + panic(err) + } + + return l } func TestLimiter_Allow(t *testing.T) { @@ -74,53 +80,32 @@ func TestLimiter_Allow(t *testing.T) { } func Benchmark_CloudflareAlgorithm(b *testing.B) { - limiter := rateLimiter() - for i := 0; i < b.N; i++ { limiter.Allow("cloudflare", &Limit{ Algorithm: cloudflare.AlgorithmName, Rate: 2000, - Burst: 2000, Period: 60 * time.Second, }) } } func Benchmark_GcraAlgorithm(b *testing.B) { - limiter := rateLimiter() - for i := 0; i < b.N; i++ { limiter.Allow("gcra", &Limit{ Algorithm: gcra.AlgorithmName, Rate: 2000, Burst: 2000, - Period: 2 * time.Second, + Period: 10 * time.Second, }) } } func Benchmark_SlidingWindowAlgorithm(b *testing.B) { - limiter := rateLimiter() - for i := 0; i < b.N; i++ { limiter.Allow("sliding_window", &Limit{ Algorithm: sliding_window.AlgorithmName, Rate: 2000, - Burst: 2000, - Period: 2 * time.Second, - }) - } -} - -func Benchmark_SlidingWindowV2Algorithm(b *testing.B) { - limiter := rateLimiter() - - for i := 0; i < b.N; i++ { - limiter.Allow("sliding_window_v2", &Limit{ - Algorithm: swv2.AlgorithmName, - Rate: 2000, - Burst: 2000, - Period: 2 * time.Second, + Period: 10 * time.Second, }) } } diff --git a/rate.go b/rate.go deleted file mode 100644 index be3dda4..0000000 --- a/rate.go +++ /dev/null @@ -1,83 +0,0 @@ -package go_redis_ratelimit - -import ( - "errors" - "time" - - "github.com/go-redis/redis/v8" - - "github.com/bringg/go_redis_ratelimit/algorithm" - "github.com/bringg/go_redis_ratelimit/algorithm/cloudflare" - "github.com/bringg/go_redis_ratelimit/algorithm/gcra" - "github.com/bringg/go_redis_ratelimit/algorithm/sliding_window" - sliding_windowV2 "github.com/bringg/go_redis_ratelimit/algorithm/sliding_window/v2" -) - -const ( - DefaultPrefix = "limiter" -) - -type ( - Algorithm interface { - Allow() (*algorithm.Result, error) - SetKey(string) - } - - Limit struct { - Algorithm string - Burst int64 - Rate int64 - Period time.Duration - } - - // Limiter controls how frequently events are allowed to happen. - Limiter struct { - rdb *redis.Client - Prefix string - } -) - -// NewLimiter returns a new Limiter. -func NewLimiter(rdb *redis.Client) *Limiter { - return &Limiter{ - rdb: rdb, - Prefix: DefaultPrefix, - } -} - -func (l *Limiter) Allow(key string, limit *Limit) (*algorithm.Result, error) { - var algo Algorithm - - switch limit.Algorithm { - case sliding_windowV2.AlgorithmName: - algo = &sliding_windowV2.SlidingWindow{Limit: limit, RDB: l.rdb} - case sliding_window.AlgorithmName: - algo = &sliding_window.SlidingWindow{Limit: limit, RDB: l.rdb} - case cloudflare.AlgorithmName: - algo = &cloudflare.Cloudflare{Limit: limit, RDB: l.rdb} - case gcra.AlgorithmName: - algo = &gcra.GCRA{Limit: limit, RDB: l.rdb} - default: - return nil, errors.New("algorithm is not supported") - } - - algo.SetKey(l.Prefix + ":" + limit.Algorithm + ":" + key) - - return algo.Allow() -} - -func (l *Limit) GetAlgorithm() string { - return l.Algorithm -} - -func (l *Limit) GetBurst() int64 { - return l.Burst -} - -func (l *Limit) GetRate() int64 { - return l.Rate -} - -func (l *Limit) GetPeriod() time.Duration { - return l.Period -}