Skip to content

Commit

Permalink
🔨 refactore code improve readablity
Browse files Browse the repository at this point in the history
  • Loading branch information
khalil committed Mar 31, 2024
1 parent d1083f4 commit 8f72f20
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 43 deletions.
5 changes: 3 additions & 2 deletions cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,11 @@ func (c *cache) getUserRequestCap(ctx context.Context, ipAddr string, g *golim,
val, err := c.Get(ctx, key).Result()
if err != nil {
if err != redis.Nil {
go c.setUserRequestCap(ctx, key, role)
log.Printf("Error getting user request capacity: %v", err)
return 0
}
return 0
go c.setUserRequestCap(ctx, key, role)
return initialTokenForTheFirstUSerRequest
}
var res int64
json.Unmarshal([]byte(val), &res)
Expand Down
4 changes: 4 additions & 0 deletions const.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ const (
getRolesOperation = "getRoles"
)

const (
initialTokenForTheFirstUSerRequest = 1
)

const (
limiterCacheMainKey = "GOLIM_KEY"
limiterCacheRegexPatternKey = "*GOLIM_KEY*"
Expand Down
1 change: 1 addition & 0 deletions cron.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"github.com/robfig/cron/v3"
)

// TODO: fix the g.rl is nil
func scheduleIncreaseCap(ctx context.Context, g *golim) {
cr := cron.New()
_, err := cr.AddFunc("@every 1m", func() {
Expand Down
4 changes: 2 additions & 2 deletions golim.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func (g *golim) getRole(ctx context.Context) (role.GetRoleRow, bool, error) {
params := toGetRole(g)
data := g.cache.getLimiter(ctx, params)
if data != nil {
return *data, false, nil
return *data, true, nil
}

row, err := g.db.GetRole(ctx, params)
Expand All @@ -59,7 +59,7 @@ func (g *golim) getRole(ctx context.Context) (role.GetRoleRow, bool, error) {
return role.GetRoleRow{}, false, nil
}

g.cache.setLimiter(ctx, &params, &row)
go g.cache.setLimiter(ctx, &params, &row)

return row, true, nil
}
Expand Down
4 changes: 1 addition & 3 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ func initDB(ctx context.Context) *sql.DB {
return db
}

// everything start from main
func main() {
ctx := context.Background()
db := initDB(ctx)
Expand Down Expand Up @@ -62,15 +63,13 @@ func initFlags(ctx context.Context, db *sql.DB, cache *cache) (*golim, error) {

func createRootCommand(g *golim) *ff.Command {
rootFlags := ff.NewFlagSet("golim")

helpCMD := g.createHelpCMD()
initCMD := g.createInitCMD()
addCMD := g.createAddCMD()
removeCMD := g.createRemoveCMD()
getCMD := g.createGetRolesCMD()
removeLimiterCMD := g.addRemoveLimiterCMD()
runCMD := g.createRunCMD()

rootCmd := &ff.Command{
Name: "golim",
Usage: "golim [COMMANDS] <FLAGS>",
Expand All @@ -80,6 +79,5 @@ func createRootCommand(g *golim) *ff.Command {
return nil
},
}

return rootCmd
}
100 changes: 64 additions & 36 deletions proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"html"
"io"
"log"
"net"
"net/http"
"net/url"
"time"
Expand All @@ -21,6 +22,8 @@ var client = &http.Client{
Transport: customTransport,
}

var proxyClient = client

func runProxy(g *golim) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
method := r.Method
Expand All @@ -29,47 +32,68 @@ func runProxy(g *golim) http.HandlerFunc {
operation: method,
endPoint: path,
}
role, needToCheckRequest, err := g.getRole(r.Context())
currentUserRole, needToCheckRequest, err := g.getRole(r.Context())
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
if needToCheckRequest && !isOkRequest(r, g, role) {
if needToCheckRequest && !isOkRequest(r, g, currentUserRole) {
http.Error(w, slowDownError, http.StatusTooManyRequests)
return
}
newURL := url.URL{
Scheme: "http",
Host: role.Destination.String,
Path: role.Endpoint,
}
targetURL := newURL
proxyReq, err := http.NewRequest(r.Method, targetURL.String(), r.Body)
if err != nil {
http.Error(w, createProxyError, http.StatusInternalServerError)
return
}
for name, values := range r.Header {
for _, value := range values {
proxyReq.Header.Add(name, value)
}
}
resp, err := client.Do(proxyReq)
if err != nil {
http.Error(w, sendingProxyError, http.StatusInternalServerError)
return
proxyRequest(w, r, g, currentUserRole)
}
}

func proxyRequest(w http.ResponseWriter, r *http.Request, g *golim, role role.GetRoleRow) {
newURL := buildURL(role)
proxyReq := createProxyRequest(r, newURL)
copyHeaders(r, proxyReq)
resp := sendProxyRequest(proxyReq)
defer resp.Body.Close()
copyResponseHeaders(resp, w)
writeResponse(w, resp)
}

func buildURL(role role.GetRoleRow) url.URL {
return url.URL{
Scheme: "http",
Host: role.Destination.String,
Path: role.Endpoint,
}
}

func createProxyRequest(r *http.Request, newURL url.URL) *http.Request {
proxyReq, _ := http.NewRequest(r.Method, newURL.String(), r.Body)
return proxyReq
}

func copyHeaders(r *http.Request, proxyReq *http.Request) {
for name, values := range r.Header {
for _, value := range values {
proxyReq.Header.Add(name, value)
}
defer resp.Body.Close()
for name, values := range resp.Header {
for _, value := range values {
w.Header().Add(name, value)
}
}
}

func sendProxyRequest(proxyReq *http.Request) *http.Response {
resp, _ := proxyClient.Do(proxyReq)
return resp
}

func copyResponseHeaders(resp *http.Response, w http.ResponseWriter) {
for name, values := range resp.Header {
for _, value := range values {
w.Header().Add(name, value)
}
w.WriteHeader(resp.StatusCode)
io.Copy(w, resp.Body)
}
}

func writeResponse(w http.ResponseWriter, resp *http.Response) {
w.WriteHeader(resp.StatusCode)
io.Copy(w, resp.Body)
}

func isOkRequest(r *http.Request, g *golim, role role.GetRoleRow) bool {
ctx := r.Context()
userIP := readUserIP(r)
Expand All @@ -85,7 +109,7 @@ func startServer(g *golim) (interface{}, error) {
portStr := fmt.Sprintf(":%d", g.port)
server := http.Server{
Addr: portStr,
Handler: http.HandlerFunc(runProxy(g)),
Handler: runProxy(g),
}

// Start the server and log any errors
Expand All @@ -98,12 +122,16 @@ func startServer(g *golim) (interface{}, error) {
}

func readUserIP(r *http.Request) string {
IPAddress := r.Header.Get("X-Real-Ip")
if IPAddress == "" {
IPAddress = r.Header.Get("X-Forwarded-For")
ipAddress := r.Header.Get("X-Real-Ip")
if ipAddress == "" {
ipAddress = r.Header.Get("X-Forwarded-For")
}
if IPAddress == "" {
IPAddress = r.RemoteAddr
if ipAddress == "" {
ipAddress = r.RemoteAddr
ip, _, err := net.SplitHostPort(ipAddress)
if err == nil {
ipAddress = ip
}
}
return IPAddress
return ipAddress
}

0 comments on commit 8f72f20

Please sign in to comment.