diff --git a/cache.go b/cache.go index ad8e3da..df10c9e 100644 --- a/cache.go +++ b/cache.go @@ -30,8 +30,8 @@ func initRedis() *cache { } // getAllUserLimitersKeys retrieves all user limiter keys from cache -func (c *cache) getAllUserLimitersKeys(ctx context.Context) []string { - res, err := c.Keys(ctx, limiterCacheRegexPatternKey).Result() +func (c *cache) getAllUserLimitersKeys(ctx context.Context, pattern string) []string { + res, err := c.Keys(ctx, pattern).Result() if err != nil { log.Printf("Error retrieving keys: %v", err) return nil @@ -40,15 +40,15 @@ func (c *cache) getAllUserLimitersKeys(ctx context.Context) []string { } // increaseCap increases the capacity in cache for a given key -func (c *cache) increaseCap(ctx context.Context, key string, rl *limiterRole) { - if err := c.IncrBy(ctx, key, rl.addToken).Err(); err != nil { +func (c *cache) increaseCap(ctx context.Context, key string, tokenAmount int64) { + if err := c.IncrBy(ctx, key, tokenAmount).Err(); err != nil { log.Printf("Error increasing capacity: %v", err) } } // decreaseCap decreases the capacity in cache for a given key func (c *cache) decreaseCap(ctx context.Context, userIP string, rl *limiterRole) { - key := fmt.Sprintf("%s%s%s %s", userIP, limiterCacheMainKey, rl.operation, rl.endPoint) + key := fmt.Sprintf("%s%s%s%s", userIP, limiterCacheMainKey, rl.operation, rl.endPoint) if err := c.Decr(ctx, key).Err(); err != nil { log.Printf("Error decreasing capacity: %v", err) } diff --git a/const.go b/const.go index 0c14592..581eb9f 100644 --- a/const.go +++ b/const.go @@ -17,7 +17,7 @@ const ( const ( limiterCacheMainKey = "GOLIM_KEY" - limiterCacheRegexPatternKey = "*GOLIM_KEY*" + limiterCacheRegexPatternKey = "*GOLIM_KEY" ) const ( diff --git a/cron.go b/cron.go index ea9b02d..3588f47 100644 --- a/cron.go +++ b/cron.go @@ -3,23 +3,40 @@ package main import ( "context" "fmt" + "log" + "sync" "github.com/robfig/cron/v3" ) -// TODO: fix the g.rl is nil -func scheduleIncreaseCap(ctx context.Context, g *golim) { - cr := cron.New() +var cr = cron.New() + +func runCronTasks(ctx context.Context, g *golim) { _, err := cr.AddFunc("@every 1m", func() { - userKeys := g.cache.getAllUserLimitersKeys(ctx) - fmt.Println("Running tasks") - for _, key := range userKeys { - g.cache.increaseCap(ctx, key, g.limiterRole) - } + scheduleIncreaseCap(ctx, g) }) if err != nil { fmt.Println("Error scheduling task:", err) - return } cr.Start() } + +func scheduleIncreaseCap(ctx context.Context, g *golim) { + roles, err := g.getRoles(ctx) + if err != nil { + log.Println("Error getting roles:", err) + return + } + var wg sync.WaitGroup + for _, role := range roles { + userKeys := g.cache.getAllUserLimitersKeys(ctx, limiterCacheRegexPatternKey+role.Operation+role.Endpoint) + for _, key := range userKeys { + wg.Add(1) + go func(ctx context.Context, key string, tokenAmount int64) { + defer wg.Done() + g.cache.increaseCap(ctx, key, tokenAmount) + }(context.Background(), key, g.limiterRole.addToken) + } + } + wg.Wait() +} diff --git a/golim.go b/golim.go index 5e6fc77..c93973f 100644 --- a/golim.go +++ b/golim.go @@ -15,6 +15,7 @@ type limiterRole struct { operation string limiterID int endPoint string + method string bucketSize int initialToken int addToken int64 @@ -91,7 +92,7 @@ func (g *golim) removeRateLimiter(ctx context.Context) error { func (g *golim) ExecCMD(ctx context.Context) (interface{}, error) { if g.port != 0 { - go scheduleIncreaseCap(ctx, g) + go runCronTasks(ctx, g) return startServer(g) } if g.limiter != nil { @@ -223,6 +224,7 @@ func (g *golim) createAddCMD() *ff.Command { addFlags := ff.NewFlagSet("add") limiterID := addFlags.Int('l', "limiter", 0, "The limiter id") endpoint := addFlags.String('e', "endpoint", "", "The endpoint address") + method := addFlags.String('m', "method", "GET", "The endpoint method") bucketSize := addFlags.Int('b', "bsize", 100, "The initial bucket size") addToken := addFlags.Int('a', "add_token", 60, "The number of tokens to add per minute") initialToken := addFlags.Int('i', "initial_token", 100, "The number of tokens to add per minute") @@ -236,7 +238,7 @@ func (g *golim) createAddCMD() *ff.Command { if g.skip { return nil } - if *limiterID == 0 || *endpoint == "" { + if *limiterID == 0 && *endpoint != "" { g.limiterRole = &limiterRole{ operation: addRoleOperation, limiterID: *limiterID, @@ -244,6 +246,7 @@ func (g *golim) createAddCMD() *ff.Command { bucketSize: *bucketSize, addToken: int64(*addToken), initialToken: *initialToken, + method: *method, } } g.skip = true