From f87c1d8fc2da0aef1ce6a2b539591d9c922ca22d Mon Sep 17 00:00:00 2001 From: Nicolas-ggd Date: Sat, 29 Jun 2024 11:14:23 +0400 Subject: [PATCH 1/2] [FIX] token bucket fill (#9) --- limiter.go | 69 ++++++++++++++++++++++++------------------------------ 1 file changed, 31 insertions(+), 38 deletions(-) diff --git a/limiter.go b/limiter.go index f242a77..1be2e41 100644 --- a/limiter.go +++ b/limiter.go @@ -5,6 +5,7 @@ import ( b64 "encoding/base64" "errors" "log" + "math" "net/http" "os" "sync" @@ -15,7 +16,10 @@ import ( ) // define constant variable of keyPrefix to avoid duplicate key in Redis -const keyPrefix = "ls_prefix:" +const ( + keyPrefix = "ls_prefix:" + lastRefillPrefix = "_lastRefillTime" +) // RateLimiter is struct based on Redis type RateLimiter struct { @@ -29,7 +33,7 @@ type RateLimiter struct { currentToken int64 // lastRefillTime represents time that this bucket fill operation was tried - lastRefillTime time.Time + refillInterval time.Duration mutex sync.Mutex @@ -46,12 +50,12 @@ func encodeKey(value string) string { } // NewRateLimiter to received and define new RateLimiter struct -func NewRateLimiter(client *redis.Client, rate, maxToken int64) *RateLimiter { +func NewRateLimiter(client *redis.Client, rate, maxToken int64, refillInterval time.Duration) *RateLimiter { return &RateLimiter{ client: client, rate: rate, maxTokens: maxToken, - lastRefillTime: time.Now(), + refillInterval: refillInterval, currentToken: maxToken, logger: log.New(os.Stdout, "RateLimiter: ", log.Lmicroseconds), } @@ -67,23 +71,24 @@ func NewRateLimiter(client *redis.Client, rate, maxToken int64) *RateLimiter { // Returns: // // bool: Returns true if the request is allowed, false otherwise. -func (rl *RateLimiter) IsRequestAllowed(key string, tokens int64) bool { +func (rl *RateLimiter) IsRequestAllowed(key string, token int64) bool { // use mutex to avoid race condition rl.mutex.Lock() defer rl.mutex.Unlock() - // encode key sEnc := keyPrefix + encodeKey(key) - // get current token count from Redis tokenCount, err := rl.client.Get(context.Background(), sEnc).Int64() if err != nil && !errors.Is(err, redis.Nil) { rl.logger.Printf("Error getting token count from Redis: %v", err) return false } - // get last refill time from Redis - lastRefillTimeStr, err := rl.client.Get(context.Background(), sEnc+"_lastRefillTime").Result() + if errors.Is(err, redis.Nil) { + tokenCount = rl.maxTokens + } + + lastRefillTimeStr, err := rl.client.Get(context.Background(), sEnc+lastRefillPrefix).Result() var lastRefillTime time.Time if err == nil { lastRefillTime, err = time.Parse(time.RFC3339, lastRefillTimeStr) @@ -94,27 +99,20 @@ func (rl *RateLimiter) IsRequestAllowed(key string, tokens int64) bool { } else if !errors.Is(err, redis.Nil) { rl.logger.Printf("Error getting last refill time from Redis: %v", err) return false + } else { + lastRefillTime = time.Now() } - // refill tokens - tokenCount, lastRefillTime = rl.refillBucket(lastRefillTime, tokenCount) - - // update last refill time in Redis - rl.client.Set(context.Background(), sEnc+"_lastRefillTime", lastRefillTime.Format(time.RFC3339), 0) + tokenCount = rl.refill(tokenCount, lastRefillTime) - // check if enough tokens are available - if tokenCount > 0 { - // decrement token count - tokenCount-- - // update token count in Redis - err = rl.client.Set(context.Background(), sEnc, tokenCount, 0).Err() - if err != nil { - rl.logger.Printf("Error setting token count in Redis: %v", err) - return false - } + if tokenCount >= token { + tokenCount -= token + rl.client.Set(context.Background(), sEnc, tokenCount, 0) + rl.client.Set(context.Background(), sEnc+lastRefillPrefix, time.Now().Format(time.RFC3339), 0) return true } + rl.client.Set(context.Background(), sEnc+lastRefillPrefix, time.Now().Format(time.RFC3339), 0) return false } @@ -128,11 +126,11 @@ func (rl *RateLimiter) IsRequestAllowed(key string, tokens int64) bool { // Returns: // // gin.HandlerFunc: A Gin handler function that can be used as middleware in the Gin router. -func RateLimiterMiddleware(limiter *RateLimiter, tokens int64) gin.HandlerFunc { +func RateLimiterMiddleware(limiter *RateLimiter) gin.HandlerFunc { return func(c *gin.Context) { ip := c.ClientIP() - - if !limiter.IsRequestAllowed(ip, tokens) { + token := int64(1) + if !limiter.IsRequestAllowed(ip, token) { limiter.logger.Printf("Rate limit exceeded for IP: %s", ip) c.Header("X-RateLimit-Remaining", "0") c.JSON(http.StatusTooManyRequests, gin.H{"error": "too many requests"}) @@ -144,18 +142,13 @@ func RateLimiterMiddleware(limiter *RateLimiter, tokens int64) gin.HandlerFunc { } } -// refillBucket function calculate time, when token bucket can refill -func (rl *RateLimiter) refillBucket(lastRefillTime time.Time, tokenCount int64) (int64, time.Time) { +func (rl *RateLimiter) refill(currentTokens int64, lastRefillTime time.Time) int64 { now := time.Now() - duration := now.Sub(lastRefillTime) + elapsed := now.Sub(lastRefillTime) - // Calculate tokens to add based on elapsed time and rate - tokensToAdd := (duration.Nanoseconds() * rl.rate) / 1e9 // maybe this calculation isn't correct, but i try to avoid float64, because sometimes it not accuracy - - tokenCount = tokenCount + tokensToAdd - if tokenCount > rl.maxTokens { - tokenCount = rl.maxTokens - } + // calculate time which each token needs to refill in token bucket + tokensToAdd := elapsed.Nanoseconds() / rl.refillInterval.Nanoseconds() + newTokens := int64(math.Min(float64(currentTokens+tokensToAdd), float64(rl.maxTokens))) - return tokenCount, now + return newTokens } From 0b1c7e253e66c8fc556c212583b1ee5be67a1672 Mon Sep 17 00:00:00 2001 From: Nicolas-ggd Date: Sat, 29 Jun 2024 11:16:52 +0400 Subject: [PATCH 2/2] [FIX] test typo error (#9) --- README.md | 13 +++++++------ limiter_test.go | 2 +- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index fc7d107..21df7c4 100644 --- a/README.md +++ b/README.md @@ -33,27 +33,28 @@ func main() { r := gin.Default() client := redis.NewClient(&redis.Options{ - Addr: "localhost:6379", + Addr: "localhost:6379", }) // Call NewRateLimiter function from rrl package. // First parameter is the Redis client. // Second parameter is the rate (tokens per second). // Third parameter is the maximum number of tokens. - limiter := rrl.NewRateLimiter(client, 1, 10) + // Fourth parameter is time duration, token refill is depending on x time interval + limiter := rrl.NewRateLimiter(client, 1, 10, 30*time.Second) // Use RateLimiterMiddleware from rrl package and pass limiter. // This middleware works for all routes in your application, // including static files served when you open a web browser. - r.Use(rrl.RateLimiterMiddleware(limiter, 1)) + r.Use(rrl.RateLimiterMiddleware(limiter)) r.GET("/", func(c *gin.Context) { - c.JSON(http.StatusOK, gin.H{"message": "Welcome!"}) + c.JSON(http.StatusOK, gin.H{"message": "Welcome!"}) }) // Using this way allows the RateLimiterMiddleware to work for only specific routes. - r.GET("/some", rrl.RateLimiterMiddleware(limiter, 1), func(c *gin.Context) { - c.JSON(http.StatusOK, gin.H{"message": "Some!"}) + r.GET("/some", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"message": "Some!"}) }) r.Run(":8080") diff --git a/limiter_test.go b/limiter_test.go index cd1a3d5..4bd8d3b 100644 --- a/limiter_test.go +++ b/limiter_test.go @@ -20,7 +20,7 @@ func setupRedisClient() *redis.Client { func TestRateLimiter_Allow(t *testing.T) { client := setupRedisClient() - limiter := NewRateLimiter(client, 1, 5) + limiter := NewRateLimiter(client, 1, 5, time.Second) tests := []struct { name string