diff --git a/cache.go b/cache.go index 4a6c165..ad8e3da 100644 --- a/cache.go +++ b/cache.go @@ -90,10 +90,11 @@ func (c *cache) getUserRequestCap(ctx context.Context, ipAddr string, g *golim, val, err := c.Get(ctx, key).Result() if err != nil { if err != redis.Nil { - go c.setUserRequestCap(ctx, key, role) log.Printf("Error getting user request capacity: %v", err) + return 0 } - return 0 + go c.setUserRequestCap(ctx, key, role) + return initialTokenForTheFirstUSerRequest } var res int64 json.Unmarshal([]byte(val), &res) diff --git a/const.go b/const.go index 1d4442d..0c14592 100644 --- a/const.go +++ b/const.go @@ -11,6 +11,10 @@ const ( getRolesOperation = "getRoles" ) +const ( + initialTokenForTheFirstUSerRequest = 1 +) + const ( limiterCacheMainKey = "GOLIM_KEY" limiterCacheRegexPatternKey = "*GOLIM_KEY*" diff --git a/cron.go b/cron.go index 109c273..ea9b02d 100644 --- a/cron.go +++ b/cron.go @@ -7,6 +7,7 @@ import ( "github.com/robfig/cron/v3" ) +// TODO: fix the g.rl is nil func scheduleIncreaseCap(ctx context.Context, g *golim) { cr := cron.New() _, err := cr.AddFunc("@every 1m", func() { diff --git a/golim.go b/golim.go index bd0861f..5e6fc77 100644 --- a/golim.go +++ b/golim.go @@ -44,7 +44,7 @@ func (g *golim) getRole(ctx context.Context) (role.GetRoleRow, bool, error) { params := toGetRole(g) data := g.cache.getLimiter(ctx, params) if data != nil { - return *data, false, nil + return *data, true, nil } row, err := g.db.GetRole(ctx, params) @@ -59,7 +59,7 @@ func (g *golim) getRole(ctx context.Context) (role.GetRoleRow, bool, error) { return role.GetRoleRow{}, false, nil } - g.cache.setLimiter(ctx, ¶ms, &row) + go g.cache.setLimiter(ctx, ¶ms, &row) return row, true, nil } diff --git a/main.go b/main.go index 48ee299..f8f419c 100644 --- a/main.go +++ b/main.go @@ -29,6 +29,7 @@ func initDB(ctx context.Context) *sql.DB { return db } +// everything start from main func main() { ctx := context.Background() db := initDB(ctx) @@ -62,7 +63,6 @@ func initFlags(ctx context.Context, db *sql.DB, cache *cache) (*golim, error) { func createRootCommand(g *golim) *ff.Command { rootFlags := ff.NewFlagSet("golim") - helpCMD := g.createHelpCMD() initCMD := g.createInitCMD() addCMD := g.createAddCMD() @@ -70,7 +70,6 @@ func createRootCommand(g *golim) *ff.Command { getCMD := g.createGetRolesCMD() removeLimiterCMD := g.addRemoveLimiterCMD() runCMD := g.createRunCMD() - rootCmd := &ff.Command{ Name: "golim", Usage: "golim [COMMANDS] ", @@ -80,6 +79,5 @@ func createRootCommand(g *golim) *ff.Command { return nil }, } - return rootCmd } diff --git a/proxy.go b/proxy.go index c89998e..09a1701 100644 --- a/proxy.go +++ b/proxy.go @@ -6,6 +6,7 @@ import ( "html" "io" "log" + "net" "net/http" "net/url" "time" @@ -21,6 +22,8 @@ var client = &http.Client{ Transport: customTransport, } +var proxyClient = client + func runProxy(g *golim) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { method := r.Method @@ -29,47 +32,68 @@ func runProxy(g *golim) http.HandlerFunc { operation: method, endPoint: path, } - role, needToCheckRequest, err := g.getRole(r.Context()) + currentUserRole, needToCheckRequest, err := g.getRole(r.Context()) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } - if needToCheckRequest && !isOkRequest(r, g, role) { + if needToCheckRequest && !isOkRequest(r, g, currentUserRole) { http.Error(w, slowDownError, http.StatusTooManyRequests) return } - newURL := url.URL{ - Scheme: "http", - Host: role.Destination.String, - Path: role.Endpoint, - } - targetURL := newURL - proxyReq, err := http.NewRequest(r.Method, targetURL.String(), r.Body) - if err != nil { - http.Error(w, createProxyError, http.StatusInternalServerError) - return - } - for name, values := range r.Header { - for _, value := range values { - proxyReq.Header.Add(name, value) - } - } - resp, err := client.Do(proxyReq) - if err != nil { - http.Error(w, sendingProxyError, http.StatusInternalServerError) - return + proxyRequest(w, r, g, currentUserRole) + } +} + +func proxyRequest(w http.ResponseWriter, r *http.Request, g *golim, role role.GetRoleRow) { + newURL := buildURL(role) + proxyReq := createProxyRequest(r, newURL) + copyHeaders(r, proxyReq) + resp := sendProxyRequest(proxyReq) + defer resp.Body.Close() + copyResponseHeaders(resp, w) + writeResponse(w, resp) +} + +func buildURL(role role.GetRoleRow) url.URL { + return url.URL{ + Scheme: "http", + Host: role.Destination.String, + Path: role.Endpoint, + } +} + +func createProxyRequest(r *http.Request, newURL url.URL) *http.Request { + proxyReq, _ := http.NewRequest(r.Method, newURL.String(), r.Body) + return proxyReq +} + +func copyHeaders(r *http.Request, proxyReq *http.Request) { + for name, values := range r.Header { + for _, value := range values { + proxyReq.Header.Add(name, value) } - defer resp.Body.Close() - for name, values := range resp.Header { - for _, value := range values { - w.Header().Add(name, value) - } + } +} + +func sendProxyRequest(proxyReq *http.Request) *http.Response { + resp, _ := proxyClient.Do(proxyReq) + return resp +} + +func copyResponseHeaders(resp *http.Response, w http.ResponseWriter) { + for name, values := range resp.Header { + for _, value := range values { + w.Header().Add(name, value) } - w.WriteHeader(resp.StatusCode) - io.Copy(w, resp.Body) } } +func writeResponse(w http.ResponseWriter, resp *http.Response) { + w.WriteHeader(resp.StatusCode) + io.Copy(w, resp.Body) +} + func isOkRequest(r *http.Request, g *golim, role role.GetRoleRow) bool { ctx := r.Context() userIP := readUserIP(r) @@ -85,7 +109,7 @@ func startServer(g *golim) (interface{}, error) { portStr := fmt.Sprintf(":%d", g.port) server := http.Server{ Addr: portStr, - Handler: http.HandlerFunc(runProxy(g)), + Handler: runProxy(g), } // Start the server and log any errors @@ -98,12 +122,16 @@ func startServer(g *golim) (interface{}, error) { } func readUserIP(r *http.Request) string { - IPAddress := r.Header.Get("X-Real-Ip") - if IPAddress == "" { - IPAddress = r.Header.Get("X-Forwarded-For") + ipAddress := r.Header.Get("X-Real-Ip") + if ipAddress == "" { + ipAddress = r.Header.Get("X-Forwarded-For") } - if IPAddress == "" { - IPAddress = r.RemoteAddr + if ipAddress == "" { + ipAddress = r.RemoteAddr + ip, _, err := net.SplitHostPort(ipAddress) + if err == nil { + ipAddress = ip + } } - return IPAddress + return ipAddress }