From 0f3fe31a02a3744bec59cd636256de50805d2c69 Mon Sep 17 00:00:00 2001 From: khalil Date: Sat, 30 Mar 2024 22:48:59 +0330 Subject: [PATCH] :hammer: refactor the golim logics --- cache.go | 70 +++++++++++++++++++++++++++++++--------------- const.go | 26 +++++++++++------ cron.go | 2 +- golim.go | 71 ++++++++++++++++++++++++----------------------- main.go | 19 ++++++------- proxy.go | 37 ++++++++++++------------ role/query.sql.go | 6 ++++ 7 files changed, 137 insertions(+), 94 deletions(-) diff --git a/cache.go b/cache.go index f12d686..4a6c165 100644 --- a/cache.go +++ b/cache.go @@ -3,73 +3,99 @@ package main import ( "context" "encoding/json" + "fmt" + "log" "os" "time" "github.com/khalil-farashiani/golim/role" - "github.com/pingcap/log" "github.com/redis/go-redis/v9" ) +// cache struct to handle redis operations type cache struct { *redis.Client } +// initRedis initializes redis connection func initRedis() *cache { url := os.Getenv("REDIS_URI") opts, err := redis.ParseURL(url) if err != nil { - panic(err) + log.Fatalf("Error parsing Redis URL: %v", err) } return &cache{ redis.NewClient(opts), } } +// getAllUserLimitersKeys retrieves all user limiter keys from cache func (c *cache) getAllUserLimitersKeys(ctx context.Context) []string { - res, err := c.Keys(ctx, "*GLOLIM_KEY*").Result() + res, err := c.Keys(ctx, limiterCacheRegexPatternKey).Result() if err != nil { - log.Fatal(err.Error()) + log.Printf("Error retrieving keys: %v", err) + return nil } return res } +// increaseCap increases the capacity in cache for a given key func (c *cache) increaseCap(ctx context.Context, key string, rl *limiterRole) { - c.IncrBy(ctx, key, rl.addToken) + if err := c.IncrBy(ctx, key, rl.addToken).Err(); err != nil { + log.Printf("Error increasing capacity: %v", err) + } } +// decreaseCap decreases the capacity in cache for a given key func (c *cache) decreaseCap(ctx context.Context, userIP string, rl *limiterRole) { - key := userIP + "GLOLIM_KEY" + rl.operation + " " + rl.endPoint - c.Decr(ctx, key) + key := fmt.Sprintf("%s%s%s %s", userIP, limiterCacheMainKey, rl.operation, rl.endPoint) + if err := c.Decr(ctx, key).Err(); err != nil { + log.Printf("Error decreasing capacity: %v", err) + } } +// setLimiter sets a limiter in cache based on parameters func (c *cache) setLimiter(ctx context.Context, params *role.GetRoleParams, val *role.GetRoleRow) { - key := params.Operation + " " + params.Endpoint - err := c.Set(ctx, key, val, time.Minute*60).Err() - if err != nil { - panic(err) + key := fmt.Sprintf("%s %s", params.Operation, params.Endpoint) + if err := c.Set(ctx, key, *val, time.Minute*60).Err(); err != nil { + log.Printf("Error setting limiter: %v", err) } } +// getLimiter retrieves a limiter from cache based on parameters func (c *cache) getLimiter(ctx context.Context, params role.GetRoleParams) *role.GetRoleRow { var res role.GetRoleRow - var key = params.Operation + " " + params.Endpoint + key := fmt.Sprintf("%s %s", params.Operation, params.Endpoint) val, err := c.Get(ctx, key).Result() - if err != nil && err != redis.Nil { - panic(err) + if err != nil { + if err != redis.Nil { + log.Printf("Error getting limiter: %v", err) + } + return nil } json.Unmarshal([]byte(val), &res) return &res } -func (c *cache) getUserRequestCap(ctx context.Context, ipAddr string, rl *limiterRole) int64 { - key := ipAddr + rl.endPoint + rl.endPoint - var res = new(int64) - val, err := c.Get(ctx, key).Result() +// setUserRequestCap sets user request capacity in cache +func (c *cache) setUserRequestCap(ctx context.Context, key string, role role.GetRoleRow) { + if err := c.Set(ctx, key, role.InitialTokens, time.Hour).Err(); err != nil { + log.Printf("Error setting user request capacity: %v", err) + } +} - if err != nil && err != redis.Nil { - panic(err) +// getUserRequestCap retrieves user request capacity from cache +func (c *cache) getUserRequestCap(ctx context.Context, ipAddr string, g *golim, role role.GetRoleRow) int64 { + key := fmt.Sprintf("%s%s%s %s", ipAddr, limiterCacheMainKey, g.limiterRole.operation, g.limiterRole.endPoint) + val, err := c.Get(ctx, key).Result() + if err != nil { + if err != redis.Nil { + go c.setUserRequestCap(ctx, key, role) + log.Printf("Error getting user request capacity: %v", err) + } + return 0 } - json.Unmarshal([]byte(val), res) - return *res + var res int64 + json.Unmarshal([]byte(val), &res) + return res } diff --git a/const.go b/const.go index 5128c56..1d4442d 100644 --- a/const.go +++ b/const.go @@ -6,26 +6,34 @@ const ( ) const ( - addRoleOperation = "add" - removeRoleOperationID = "remove" - getRolesOperationID = "getRoles" + addRoleOperation = "add" + removeRoleOperation = "remove" + getRolesOperation = "getRoles" ) const ( - OperationGet = "GET" - OperationPost = "POST" - OperationPut = "PUT" - OperationPatch = "PATCH" - OperationDelete = "DELETE" + limiterCacheMainKey = "GOLIM_KEY" + limiterCacheRegexPatternKey = "*GOLIM_KEY*" ) const ( unknownLimiterRoleError = "unknown limiter role operation" unknownLimiterError = "unknown limiter operation" - unsupportedOperationError = "unsupported operation" requiredNameDestinationError = "name and destination is required" requiredLimiterIDError = "limiter id is required" createProxyError = "Error creating proxy request" sendingProxyError = "Error sending proxy request" slowDownError = "slow down" + notFoundSqlError = "sql: no rows in result set" +) + +const ( + helpMessageUsage = ` +Golim help: + - golim run -p{--port} [run in the specific port default is 8080] + - golim get -l{--limiter} [get roles of a rate limiter] + - golim init -n{--name} foo -d{--destination} 8.8.8.8 [initial new rate limiter] + - golim add -l{--limiter} -e{--endpoint} -b{--bsize} -a{--add_token} -i{--initial_token} [add specific role to limiter] + - golim remove -i{--id} [remove specific role] + - golim remove-limiter -l{--limiter} [remove specific limiter]` ) diff --git a/cron.go b/cron.go index d65f078..109c273 100644 --- a/cron.go +++ b/cron.go @@ -8,9 +8,9 @@ import ( ) func scheduleIncreaseCap(ctx context.Context, g *golim) { - userKeys := g.cache.getAllUserLimitersKeys(ctx) cr := cron.New() _, err := cr.AddFunc("@every 1m", func() { + userKeys := g.cache.getAllUserLimitersKeys(ctx) fmt.Println("Running tasks") for _, key := range userKeys { g.cache.increaseCap(ctx, key, g.limiterRole) diff --git a/golim.go b/golim.go index a328a2d..bd0861f 100644 --- a/golim.go +++ b/golim.go @@ -4,6 +4,8 @@ import ( "context" "database/sql" "errors" + "fmt" + "strings" "github.com/khalil-farashiani/golim/role" "github.com/peterbourgon/ff/v4" @@ -38,20 +40,28 @@ type golim struct { Store } -func (g *golim) getRole(ctx context.Context) (role.GetRoleRow, error) { +func (g *golim) getRole(ctx context.Context) (role.GetRoleRow, bool, error) { params := toGetRole(g) data := g.cache.getLimiter(ctx, params) if data != nil { - return *data, nil + return *data, false, nil } + row, err := g.db.GetRole(ctx, params) if err != nil { - return role.GetRoleRow{}, err + if strings.Contains(err.Error(), notFoundSqlError) { + return role.GetRoleRow{}, false, nil + } + return role.GetRoleRow{}, false, err + } + + if row.Endpoint == "" { + return role.GetRoleRow{}, false, nil } - go func() { - g.cache.setLimiter(ctx, ¶ms, &row) - }() - return row, nil + + g.cache.setLimiter(ctx, ¶ms, &row) + + return row, true, nil } func (g *golim) getRoles(ctx context.Context) ([]role.GetRolesRow, error) { @@ -75,12 +85,13 @@ func (g *golim) createRateLimiter(ctx context.Context) error { } func (g *golim) removeRateLimiter(ctx context.Context) error { - return g.db.DeleteRateLimiter(ctx, g.limiter.id.(int64)) + return g.db.DeleteRateLimiter(ctx, int64(g.limiter.id.(int))) } func (g *golim) ExecCMD(ctx context.Context) (interface{}, error) { if g.port != 0 { + go scheduleIncreaseCap(ctx, g) return startServer(g) } if g.limiter != nil { @@ -89,7 +100,7 @@ func (g *golim) ExecCMD(ctx context.Context) (interface{}, error) { if g.limiterRole != nil { return handleLimiterRoleOperation(g, ctx) } - return nil, errors.New(unsupportedOperationError) + return nil, nil } func handleLimiterOperation(g *golim, ctx context.Context) (interface{}, error) { @@ -106,9 +117,9 @@ func handleLimiterRoleOperation(g *golim, ctx context.Context) (interface{}, err switch g.limiterRole.operation { case addRoleOperation: return nil, g.addRole(ctx) - case removeRoleOperationID: + case removeRoleOperation: return nil, g.removeRole(ctx) - case getRolesOperationID: + case getRolesOperation: return g.getRoles(ctx) } return nil, errors.New(unknownLimiterRoleError) @@ -124,12 +135,14 @@ func newLimiter(db *sql.DB, cache *cache) *golim { } func (g *golim) createHelpCMD() *ff.Command { + helpFlags := ff.NewFlagSet("help") return &ff.Command{ Name: "help", Usage: "golim help", ShortHelp: "Displays help information for golim", - Flags: ff.NewFlagSet("help"), + Flags: helpFlags, Exec: func(ctx context.Context, args []string) error { + fmt.Println(helpMessageUsage) return nil }, } @@ -183,13 +196,13 @@ func (g *golim) createInitCMD() *ff.Command { } func (g *golim) addRemoveLimiterCMD() *ff.Command { - initFlags := ff.NewFlagSet("remove-limiter") - limiterID := initFlags.Int('l', "limiter", 0, "The name of the golim to initialize") + removeFlags := ff.NewFlagSet("removel") + limiterID := removeFlags.Int('l', "limiter", 0, "The name of the golim to initialize") return &ff.Command{ - Name: "init", - Usage: "golim init -n ", + Name: "removel", + Usage: "golim removel -l ", ShortHelp: "Initializes a standalone rate golim", - Flags: initFlags, + Flags: removeFlags, Exec: func(ctx context.Context, args []string) error { if g.skip { return nil @@ -241,21 +254,21 @@ func (g *golim) createAddCMD() *ff.Command { func (g *golim) createRemoveCMD() *ff.Command { removeFlags := ff.NewFlagSet("remove") - limiterID := removeFlags.Int('l', "limiter", 0, "The limiter id") + roleID := removeFlags.Int('i', "role_id", 0, "the role id") return &ff.Command{ - Name: "add", - Usage: "golim add -e -b -a ", + Name: "remove", + Usage: "golim remove -i ", ShortHelp: "Adds a new golim with the specified configuration", Flags: removeFlags, Exec: func(ctx context.Context, args []string) error { if g.skip { return nil } - if *limiterID == 0 { + if *roleID != 0 { g.limiterRole = &limiterRole{ - operation: removeRoleOperationID, - limiterID: *limiterID, + operation: removeRoleOperation, + limiterID: *roleID, } } g.skip = true @@ -279,7 +292,7 @@ func (g *golim) createGetRolesCMD() *ff.Command { } if *limiterID != 0 { g.limiterRole = &limiterRole{ - operation: getRolesOperationID, + operation: getRolesOperation, limiterID: *limiterID, } } else { @@ -315,13 +328,3 @@ func toGetRole(g *golim) role.GetRoleParams { Operation: g.limiterRole.operation, } } - -func toRole(row role.GetRoleRow) role.Role { - return role.Role{ - Endpoint: row.Endpoint, - Operation: row.Operation, - BucketSize: row.BucketSize, - AddTokenPerMin: row.AddTokenPerMin, - InitialTokens: row.InitialTokens, - } -} diff --git a/main.go b/main.go index def08d6..48ee299 100644 --- a/main.go +++ b/main.go @@ -5,6 +5,7 @@ import ( "database/sql" _ "embed" "fmt" + "log" "os" "strings" @@ -19,36 +20,32 @@ var ddl string func initDB(ctx context.Context) *sql.DB { db, err := sql.Open("sqlite3", "golim.sqlite") if err != nil { - panic(err) + log.Fatal(err) } // create tables if _, err := db.ExecContext(ctx, ddl); err != nil && !strings.Contains(err.Error(), "already exists") { - panic(err) + log.Fatal(err) } return db } func main() { ctx := context.Background() - db := initDB(ctx) cache := initRedis() limiter, err := initFlags(ctx, db, cache) if err != nil { - fmt.Fprintf(os.Stderr, "error: %v\n", err) - os.Exit(1) + log.Fatalf("Error initializing limiter: %v", err) } + data, err := limiter.ExecCMD(ctx) if err != nil { - fmt.Fprintf(os.Stderr, "error: %v\n", err) - os.Exit(1) + log.Fatalf("Error executing command: %v", err) } if data != nil { makeTable(toSlice(data)) - fmt.Fprintf(os.Stdout, "DONE") - return } - fmt.Printf("DONE") + fmt.Println("DONE") } // initFlags get command and flags from std input to create a golim or role @@ -71,7 +68,7 @@ func createRootCommand(g *golim) *ff.Command { addCMD := g.createAddCMD() removeCMD := g.createRemoveCMD() getCMD := g.createGetRolesCMD() - removeLimiterCMD := g.createRemoveCMD() + removeLimiterCMD := g.addRemoveLimiterCMD() runCMD := g.createRunCMD() rootCmd := &ff.Command{ diff --git a/proxy.go b/proxy.go index 8a7ce68..c89998e 100644 --- a/proxy.go +++ b/proxy.go @@ -2,35 +2,42 @@ package main import ( "fmt" + "github.com/khalil-farashiani/golim/role" "html" "io" "log" "net/http" "net/url" + "time" ) -var customTransport = http.DefaultTransport +var customTransport = &http.Transport{ + MaxIdleConns: 100, + MaxIdleConnsPerHost: 100, +} + +var client = &http.Client{ + Timeout: time.Second * 10, + Transport: customTransport, +} -func runProxy(g *golim) func(w http.ResponseWriter, r *http.Request) { +func runProxy(g *golim) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { method := r.Method path := html.EscapeString(r.URL.Path) - g.limiterRole = &limiterRole{ operation: method, endPoint: path, } - if !isOkRequest(r, g) { - http.Error(w, slowDownError, http.StatusTooManyRequests) - return - } - - role, err := g.getRole(r.Context()) + role, needToCheckRequest, err := g.getRole(r.Context()) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } - + if needToCheckRequest && !isOkRequest(r, g, role) { + http.Error(w, slowDownError, http.StatusTooManyRequests) + return + } newURL := url.URL{ Scheme: "http", Host: role.Destination.String, @@ -47,30 +54,26 @@ func runProxy(g *golim) func(w http.ResponseWriter, r *http.Request) { proxyReq.Header.Add(name, value) } } - - resp, err := customTransport.RoundTrip(proxyReq) + resp, err := client.Do(proxyReq) if err != nil { http.Error(w, sendingProxyError, http.StatusInternalServerError) return } defer resp.Body.Close() - for name, values := range resp.Header { for _, value := range values { w.Header().Add(name, value) } } - w.WriteHeader(resp.StatusCode) - io.Copy(w, resp.Body) } } -func isOkRequest(r *http.Request, g *golim) bool { +func isOkRequest(r *http.Request, g *golim, role role.GetRoleRow) bool { ctx := r.Context() userIP := readUserIP(r) - capacity := g.cache.getUserRequestCap(ctx, userIP, g.limiterRole) + capacity := g.cache.getUserRequestCap(ctx, userIP, g, role) if capacity > 0 { go g.cache.decreaseCap(r.Context(), userIP, g.limiterRole) return true diff --git a/role/query.sql.go b/role/query.sql.go index 0bf66da..fc48330 100644 --- a/role/query.sql.go +++ b/role/query.sql.go @@ -8,6 +8,7 @@ package role import ( "context" "database/sql" + "encoding/json" ) const crateRateLimiter = `-- name: CrateRateLimiter :one @@ -159,6 +160,11 @@ type GetRoleRow struct { Destination sql.NullString } + +func (g GetRoleRow) MarshalBinary() ([]byte, error) { + return json.Marshal(g) +} + func (q *Queries) GetRole(ctx context.Context, arg GetRoleParams) (GetRoleRow, error) { row := q.db.QueryRowContext(ctx, getRole, arg.Endpoint, arg.Operation) var i GetRoleRow