Skip to content

Commit

Permalink
Merge pull request #12 from Nicolas-ggd/change-limiter
Browse files Browse the repository at this point in the history
[UPDATE] configuration
  • Loading branch information
Nicolas-ggd authored Aug 26, 2024
2 parents ec98d4f + 28bdffd commit 9094c4a
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 68 deletions.
21 changes: 13 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,28 +33,33 @@ func main() {
r := gin.Default()

client := redis.NewClient(&redis.Options{
Addr: "localhost:6379",
Addr: "localhost:6379",
})

// Call NewRateLimiter function from rrl package.
// First parameter is the Redis client.
// Second parameter is the rate (tokens per second).
// Third parameter is the maximum number of tokens.
// Fourth parameter is time duration, token refill is depending on x time interval
limiter := rrl.NewRateLimiter(client, 1, 10, 30*time.Second)
limiter, err := rrl.NewRateLimiter(&rrl.RateLimiter{
Rate: 1, // amount of each request as a token
MaxTokens: 5, // maximum token quantity for requests
RefillInterval: 15 * time.Second, // each token fill in 'X' time frame
Client: client, // redis client
HashKey: false, // make true if you want to hash redis key
})
if err != nil {
log.Fatal(err)
}

// Use RateLimiterMiddleware from rrl package and pass limiter.
// This middleware works for all routes in your application,
// including static files served when you open a web browser.
r.Use(rrl.RateLimiterMiddleware(limiter))

r.GET("/", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"message": "Welcome!"})
c.JSON(http.StatusOK, gin.H{"message": "Welcome!"})
})

// Using this way allows the RateLimiterMiddleware to work for only specific routes.
r.GET("/some", rrl.RateLimiterMiddleware(limiter), func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"message": "Some!"})
c.JSON(http.StatusOK, gin.H{"message": "Some!"})
})

r.Run(":8080")
Expand Down
86 changes: 29 additions & 57 deletions limiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,12 @@ import (
"context"
b64 "encoding/base64"
"errors"
"github.com/redis/go-redis/v9"
"log"
"math"
"net/http"
"os"
"sync"
"time"

"github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
)

// define constant variable of keyPrefix to avoid duplicate key in Redis
Expand All @@ -24,24 +21,27 @@ const (
// RateLimiter is struct based on Redis
type RateLimiter struct {
// represents the rate at which the bucket should be filled
rate int64
Rate int64

// represents the max tokens capacity that the bucket can hold
maxTokens int64
MaxTokens int64

// tokens currently present in the bucket at any time
currentToken int64

// lastRefillTime represents time that this bucket fill operation was tried
refillInterval time.Duration

mutex sync.Mutex
RefillInterval time.Duration

// client is redis Client
client *redis.Client
Client *redis.Client

// decide to hash redis key
HashKey bool

// logger for logging rate limit events
logger *log.Logger

mutex sync.Mutex
}

// encodeKey function encodes received value parameter with base64
Expand All @@ -50,15 +50,8 @@ func encodeKey(value string) string {
}

// NewRateLimiter to received and define new RateLimiter struct
func NewRateLimiter(client *redis.Client, rate, maxToken int64, refillInterval time.Duration) *RateLimiter {
return &RateLimiter{
client: client,
rate: rate,
maxTokens: maxToken,
refillInterval: refillInterval,
currentToken: maxToken,
logger: log.New(os.Stdout, "RateLimiter: ", log.Lmicroseconds),
}
func NewRateLimiter(config *RateLimiter) (*RateLimiter, error) {
return config, nil
}

// IsRequestAllowed function is a method of the RateLimiter struct. It is responsible for determining whether a specific request should be allowed based on the rate limiting rules.
Expand All @@ -71,24 +64,29 @@ func NewRateLimiter(client *redis.Client, rate, maxToken int64, refillInterval t
// Returns:
//
// bool: Returns true if the request is allowed, false otherwise.
func (rl *RateLimiter) IsRequestAllowed(key string, token int64) bool {
func (rl *RateLimiter) IsRequestAllowed(key string) bool {
// use mutex to avoid race condition
rl.mutex.Lock()
defer rl.mutex.Unlock()

sEnc := keyPrefix + encodeKey(key)
var sEnc string
if rl.HashKey {
sEnc = keyPrefix + encodeKey(key)
} else {
sEnc = keyPrefix + key
}

tokenCount, err := rl.client.Get(context.Background(), sEnc).Int64()
tokenCount, err := rl.Client.Get(context.Background(), sEnc).Int64()
if err != nil && !errors.Is(err, redis.Nil) {
rl.logger.Printf("Error getting token count from Redis: %v", err)
return false
}

if errors.Is(err, redis.Nil) {
tokenCount = rl.maxTokens
tokenCount = rl.MaxTokens
}

lastRefillTimeStr, err := rl.client.Get(context.Background(), sEnc+lastRefillPrefix).Result()
lastRefillTimeStr, err := rl.Client.Get(context.Background(), sEnc+lastRefillPrefix).Result()
var lastRefillTime time.Time
if err == nil {
lastRefillTime, err = time.Parse(time.RFC3339, lastRefillTimeStr)
Expand All @@ -105,50 +103,24 @@ func (rl *RateLimiter) IsRequestAllowed(key string, token int64) bool {

tokenCount = rl.refill(tokenCount, lastRefillTime)

if tokenCount >= token {
tokenCount -= token
rl.client.Set(context.Background(), sEnc, tokenCount, 0)
rl.client.Set(context.Background(), sEnc+lastRefillPrefix, time.Now().Format(time.RFC3339), 0)
if tokenCount >= rl.Rate {
tokenCount -= rl.Rate
rl.Client.Set(context.Background(), sEnc, tokenCount, 0)
rl.Client.Set(context.Background(), sEnc+lastRefillPrefix, time.Now().Format(time.RFC3339), 0)
return true
}

rl.client.Set(context.Background(), sEnc+lastRefillPrefix, time.Now().Format(time.RFC3339), 0)
rl.Client.Set(context.Background(), sEnc+lastRefillPrefix, time.Now().Format(time.RFC3339), 0)
return false
}

// RateLimiterMiddleware function is a middleware for the Gin web framework that enforces rate limiting on incoming requests.
// This middleware uses a RateLimiter instance to track and limit the number of requests a client can make within a specified time interval.
//
// Parameters:
//
// limiter (*RateLimiter): An instance of the RateLimiter struct that defines the rate limiting rules and interacts with Redis to enforce them.
//
// Returns:
//
// gin.HandlerFunc: A Gin handler function that can be used as middleware in the Gin router.
func RateLimiterMiddleware(limiter *RateLimiter) gin.HandlerFunc {
return func(c *gin.Context) {
ip := c.ClientIP()
token := int64(1)
if !limiter.IsRequestAllowed(ip, token) {
limiter.logger.Printf("Rate limit exceeded for IP: %s", ip)
c.Header("X-RateLimit-Remaining", "0")
c.JSON(http.StatusTooManyRequests, gin.H{"error": "too many requests"})
c.Abort()
return
}

c.Next()
}
}

func (rl *RateLimiter) refill(currentTokens int64, lastRefillTime time.Time) int64 {
now := time.Now()
elapsed := now.Sub(lastRefillTime)

// calculate time which each token needs to refill in token bucket
tokensToAdd := elapsed.Nanoseconds() / rl.refillInterval.Nanoseconds()
newTokens := int64(math.Min(float64(currentTokens+tokensToAdd), float64(rl.maxTokens)))
tokensToAdd := elapsed.Nanoseconds() / rl.RefillInterval.Nanoseconds()
newTokens := int64(math.Min(float64(currentTokens+tokensToAdd), float64(rl.MaxTokens)))

return newTokens
}
15 changes: 12 additions & 3 deletions limiter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,16 @@ func setupRedisClient() *redis.Client {

func TestRateLimiter_Allow(t *testing.T) {
client := setupRedisClient()
limiter := NewRateLimiter(client, 1, 5, time.Second)
limiter, err := NewRateLimiter(&RateLimiter{
Rate: 1,
MaxTokens: 5,
RefillInterval: 1 * time.Second,
Client: client,
HashKey: false,
})
if err != nil {
t.Fatal(err)
}

tests := []struct {
name string
Expand All @@ -34,15 +43,15 @@ func TestRateLimiter_Allow(t *testing.T) {
{"Fourth Request", "user1", true, 0},
{"Fifth Request", "user1", true, 0},
{"Sixth Request", "user1", false, 0},
{"Wait for Refill", "user1", true, 5 * time.Second},
{"Wait for Refill", "user1", true, 1 * time.Second},
}

for _, tt := range tests {
if tt.delay > 0 {
time.Sleep(tt.delay)
}
t.Run(tt.name, func(t *testing.T) {
result := limiter.IsRequestAllowed(tt.key, 1)
result := limiter.IsRequestAllowed(tt.key)
assert.Equal(t, tt.expected, result)
})
}
Expand Down
39 changes: 39 additions & 0 deletions middleware.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package rrl

import (
"github.com/gin-gonic/gin"
"log"
"net/http"
"strconv"
)

const (
HeaderRateLimit = "X-RateLimit-Limit"
HeaderRateLimitRemaining = "X-RateLimit-Remaining"
)

// RateLimiterMiddleware function is a middleware for the Gin web framework that enforces rate limiting on incoming requests.
// This middleware uses a RateLimiter instance to track and limit the number of requests a client can make within a specified time interval.
//
// Parameters:
//
// limiter (*RateLimiter): An instance of the RateLimiter struct that defines the rate limiting rules and interacts with Redis to enforce them.
//
// Returns:
//
// gin.HandlerFunc: A Gin handler function that can be used as middleware in the Gin router.
func RateLimiterMiddleware(limiter *RateLimiter) gin.HandlerFunc {
return func(c *gin.Context) {
ip := c.ClientIP()
if !limiter.IsRequestAllowed(ip) {
log.Printf("Rate limit exceeded for IP: %s", ip)
c.Header(HeaderRateLimitRemaining, strconv.Itoa(int(limiter.currentToken)))
c.Header(HeaderRateLimit, strconv.Itoa(int(limiter.MaxTokens)))
c.JSON(http.StatusTooManyRequests, gin.H{"error": "too many requests"})
c.Abort()
return
}

c.Next()
}
}

0 comments on commit 9094c4a

Please sign in to comment.