Skip to content

Commit

Permalink
🐛 fix get roles in cron tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
khalil-farashiani committed Apr 3, 2024
1 parent 8f72f20 commit 5ba7e99
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 17 deletions.
10 changes: 5 additions & 5 deletions cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ func initRedis() *cache {
}

// getAllUserLimitersKeys retrieves all user limiter keys from cache
func (c *cache) getAllUserLimitersKeys(ctx context.Context) []string {
res, err := c.Keys(ctx, limiterCacheRegexPatternKey).Result()
func (c *cache) getAllUserLimitersKeys(ctx context.Context, pattern string) []string {
res, err := c.Keys(ctx, pattern).Result()
if err != nil {
log.Printf("Error retrieving keys: %v", err)
return nil
Expand All @@ -40,15 +40,15 @@ func (c *cache) getAllUserLimitersKeys(ctx context.Context) []string {
}

// increaseCap increases the capacity in cache for a given key
func (c *cache) increaseCap(ctx context.Context, key string, rl *limiterRole) {
if err := c.IncrBy(ctx, key, rl.addToken).Err(); err != nil {
func (c *cache) increaseCap(ctx context.Context, key string, tokenAmount int64) {
if err := c.IncrBy(ctx, key, tokenAmount).Err(); err != nil {
log.Printf("Error increasing capacity: %v", err)
}
}

// decreaseCap decreases the capacity in cache for a given key
func (c *cache) decreaseCap(ctx context.Context, userIP string, rl *limiterRole) {
key := fmt.Sprintf("%s%s%s %s", userIP, limiterCacheMainKey, rl.operation, rl.endPoint)
key := fmt.Sprintf("%s%s%s%s", userIP, limiterCacheMainKey, rl.operation, rl.endPoint)
if err := c.Decr(ctx, key).Err(); err != nil {
log.Printf("Error decreasing capacity: %v", err)
}
Expand Down
2 changes: 1 addition & 1 deletion const.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ const (

const (
limiterCacheMainKey = "GOLIM_KEY"
limiterCacheRegexPatternKey = "*GOLIM_KEY*"
limiterCacheRegexPatternKey = "*GOLIM_KEY"
)

const (
Expand Down
35 changes: 26 additions & 9 deletions cron.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,40 @@ package main
import (
"context"
"fmt"
"log"
"sync"

"github.com/robfig/cron/v3"
)

// TODO: fix the g.rl is nil
func scheduleIncreaseCap(ctx context.Context, g *golim) {
cr := cron.New()
var cr = cron.New()

func runCronTasks(ctx context.Context, g *golim) {
_, err := cr.AddFunc("@every 1m", func() {
userKeys := g.cache.getAllUserLimitersKeys(ctx)
fmt.Println("Running tasks")
for _, key := range userKeys {
g.cache.increaseCap(ctx, key, g.limiterRole)
}
scheduleIncreaseCap(ctx, g)
})
if err != nil {
fmt.Println("Error scheduling task:", err)
return
}
cr.Start()
}

func scheduleIncreaseCap(ctx context.Context, g *golim) {
roles, err := g.getRoles(ctx)
if err != nil {
log.Println("Error getting roles:", err)
return
}
var wg sync.WaitGroup
for _, role := range roles {
userKeys := g.cache.getAllUserLimitersKeys(ctx, limiterCacheRegexPatternKey+role.Operation+role.Endpoint)
for _, key := range userKeys {
wg.Add(1)
go func(ctx context.Context, key string, tokenAmount int64) {
defer wg.Done()
g.cache.increaseCap(ctx, key, tokenAmount)
}(context.Background(), key, g.limiterRole.addToken)
}
}
wg.Wait()
}
7 changes: 5 additions & 2 deletions golim.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ type limiterRole struct {
operation string
limiterID int
endPoint string
method string
bucketSize int
initialToken int
addToken int64
Expand Down Expand Up @@ -91,7 +92,7 @@ func (g *golim) removeRateLimiter(ctx context.Context) error {
func (g *golim) ExecCMD(ctx context.Context) (interface{}, error) {

if g.port != 0 {
go scheduleIncreaseCap(ctx, g)
go runCronTasks(ctx, g)
return startServer(g)
}
if g.limiter != nil {
Expand Down Expand Up @@ -223,6 +224,7 @@ func (g *golim) createAddCMD() *ff.Command {
addFlags := ff.NewFlagSet("add")
limiterID := addFlags.Int('l', "limiter", 0, "The limiter id")
endpoint := addFlags.String('e', "endpoint", "", "The endpoint address")
method := addFlags.String('m', "method", "GET", "The endpoint method")
bucketSize := addFlags.Int('b', "bsize", 100, "The initial bucket size")
addToken := addFlags.Int('a', "add_token", 60, "The number of tokens to add per minute")
initialToken := addFlags.Int('i', "initial_token", 100, "The number of tokens to add per minute")
Expand All @@ -236,14 +238,15 @@ func (g *golim) createAddCMD() *ff.Command {
if g.skip {
return nil
}
if *limiterID == 0 || *endpoint == "" {
if *limiterID == 0 && *endpoint != "" {
g.limiterRole = &limiterRole{
operation: addRoleOperation,
limiterID: *limiterID,
endPoint: *endpoint,
bucketSize: *bucketSize,
addToken: int64(*addToken),
initialToken: *initialToken,
method: *method,
}
}
g.skip = true
Expand Down

0 comments on commit 5ba7e99

Please sign in to comment.