Skip to content

Commit

Permalink
Move Redis tests to e2e (#664)
Browse files Browse the repository at this point in the history
  • Loading branch information
douglascamata authored Mar 1, 2024
1 parent f2995f3 commit 8b041be
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 53 deletions.
24 changes: 12 additions & 12 deletions ratelimit/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@ import (
"github.com/observatorium/api/ratelimit/gubernator"
)

var errOverLimit = errors.New("over limit")
var ErrOverLimit = errors.New("over limit")

type request struct {
type Request struct {
name string
key string
limit int64
// duration is the duration of the rate limit window in milliseconds.
duration int64
Key string
Limit int64
// Duration is the Duration of the rate limit window in milliseconds.
Duration int64
failOpen bool
retryAfterMin time.Duration
retryAfterMax time.Duration
Expand All @@ -38,7 +38,7 @@ type SharedRateLimiter interface {
// GetRateLimits retrieves the rate limits for a given request.
// It returns the remaining requests, the reset time as Unix time (millisecond from epoch), and any error that occurred.
// When a rate limit is exceeded, the error errOverLimit is returned.
GetRateLimits(ctx context.Context, req *request) (remaining, resetTime int64, err error)
GetRateLimits(ctx context.Context, req *Request) (remaining, resetTime int64, err error)
}

// NewClient creates a new gubernator client with default configuration.
Expand Down Expand Up @@ -78,14 +78,14 @@ func (c *Client) Dial(ctx context.Context, address string) error {

// GetRateLimits gets the rate limits corresponding to a request.
// Note: Dial must be called before calling this method, otherwise the client will panic.
func (c *Client) GetRateLimits(ctx context.Context, req *request) (remaining, resetTime int64, err error) {
func (c *Client) GetRateLimits(ctx context.Context, req *Request) (remaining, resetTime int64, err error) {
resp, err := c.client.GetRateLimits(ctx, &gubernator.GetRateLimitsReq{
Requests: []*gubernator.RateLimitReq{{
Name: req.name,
UniqueKey: req.key,
UniqueKey: req.Key,
Hits: 1,
Limit: req.limit,
Duration: req.duration,
Limit: req.Limit,
Duration: req.Duration,
Algorithm: gubernator.Algorithm_LEAKY_BUCKET,
Behavior: gubernator.Behavior_GLOBAL,
}},
Expand All @@ -96,7 +96,7 @@ func (c *Client) GetRateLimits(ctx context.Context, req *request) (remaining, re

response := resp.Responses[0]
if response.Status == gubernator.Status_OVER_LIMIT {
return 0, 0, errOverLimit
return 0, 0, ErrOverLimit
}

return response.GetRemaining(), response.GetResetTime(), nil
Expand Down
22 changes: 11 additions & 11 deletions ratelimit/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,11 @@ func WithSharedRateLimiter(logger log.Logger, client SharedRateLimiter, configs
middleware{
c.Matcher,
rateLimiter{logger, client,
&request{
&Request{
name: requestName,
key: fmt.Sprintf("%s:%s", c.Tenant, c.Matcher.String()),
limit: int64(c.Limit),
duration: c.Window.Milliseconds(),
Key: fmt.Sprintf("%s:%s", c.Tenant, c.Matcher.String()),
Limit: int64(c.Limit),
Duration: c.Window.Milliseconds(),
failOpen: c.FailOpen,
retryAfterMin: c.RetryAfterMin,
retryAfterMax: c.RetryAfterMax,
Expand Down Expand Up @@ -122,7 +122,7 @@ func combine(middlewares map[string][]middleware) func(next http.Handler) http.H
type rateLimiter struct {
logger log.Logger
limiterClient SharedRateLimiter
req *request
req *Request
mut *sync.RWMutex
limitTracker map[string]time.Duration
}
Expand All @@ -133,7 +133,7 @@ func (l rateLimiter) Handler(next http.Handler) http.Handler {
defer cancel()

remaining, resetTime, err := l.limiterClient.GetRateLimits(ctx, l.req)
w.Header().Set(headerKeyLimit, strconv.FormatInt(l.req.limit, 10))
w.Header().Set(headerKeyLimit, strconv.FormatInt(l.req.Limit, 10))
w.Header().Set(headerKeyRemaining, strconv.FormatInt(remaining, 10))
w.Header().Set(headerKeyReset, strconv.FormatInt(resetTime, 10))

Expand All @@ -144,7 +144,7 @@ func (l rateLimiter) Handler(next http.Handler) http.Handler {
w.Header().Set(headerRetryAfter, retryAfter)
}

if errors.Is(err, errOverLimit) {
if errors.Is(err, ErrOverLimit) {
httperr.PrometheusAPIError(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests)
return
}
Expand Down Expand Up @@ -173,10 +173,10 @@ func (l rateLimiter) getAndSetNextRetryAfterValue() (string, bool) {
l.mut.Lock()
defer l.mut.Unlock()

current, ok := l.limitTracker[l.req.key]
current, ok := l.limitTracker[l.req.Key]
if !ok {
nextValue := l.req.retryAfterMin.Seconds()
l.limitTracker[l.req.key] = l.req.retryAfterMin * 2
l.limitTracker[l.req.Key] = l.req.retryAfterMin * 2
return fmt.Sprintf("%d", int(nextValue)), true
}

Expand All @@ -188,7 +188,7 @@ func (l rateLimiter) getAndSetNextRetryAfterValue() (string, bool) {
nextValue = l.req.retryAfterMax
}

l.limitTracker[l.req.key] = nextValue
l.limitTracker[l.req.Key] = nextValue
next := strconv.Itoa(int(nextValue.Seconds()))
return next, true
}
Expand All @@ -200,5 +200,5 @@ func (l rateLimiter) resetRetryAfterValue() {

l.mut.Lock()
defer l.mut.Unlock()
delete(l.limitTracker, l.req.key)
delete(l.limitTracker, l.req.Key)
}
7 changes: 4 additions & 3 deletions ratelimit/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/prometheus/client_golang/prometheus"

"github.com/go-chi/chi"

"github.com/observatorium/api/authentication"
"github.com/observatorium/api/logger"
"github.com/observatorium/api/server"
Expand Down Expand Up @@ -202,16 +203,16 @@ type mockSharedLimiter struct {
received int64
}

func (m *mockSharedLimiter) GetRateLimits(ctx context.Context, req *request) (remaining, resetTime int64, err error) {
func (m *mockSharedLimiter) GetRateLimits(ctx context.Context, req *Request) (remaining, resetTime int64, err error) {
m.mtx.Lock()
defer m.mtx.Unlock()

if req.limit > m.received {
if req.Limit > m.received {
m.received++
return m.received, mockResetTime, nil
}

return m.received, mockResetTime, errOverLimit
return m.received, mockResetTime, ErrOverLimit
}

func (m *mockSharedLimiter) reset() {
Expand Down
12 changes: 6 additions & 6 deletions ratelimit/redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,23 +33,23 @@ func NewRedisRateLimiter(addresses []string) (*RedisRateLimiter, error) {

// GetRateLimits retrieves the rate limits for a given request using a Redis Rate Limiter.
// It returns the amount of remaining requests, the reset time in milliseconds, and any error that occurred.
func (r *RedisRateLimiter) GetRateLimits(ctx context.Context, req *request) (remaining, resetTime int64, err error) {
func (r *RedisRateLimiter) GetRateLimits(ctx context.Context, req *Request) (remaining, resetTime int64, err error) {
inspectScript := rueidis.NewLuaScript(gcraRateLimitScript)
rateLimitParameters := []string{
strconv.FormatInt(time.Now().UnixMilli(), 10), // now
strconv.FormatInt(req.limit, 10), // burst
strconv.FormatInt(req.limit, 10), // rate
strconv.FormatInt(req.duration, 10), // period
strconv.FormatInt(req.Limit, 10), // burst
strconv.FormatInt(req.Limit, 10), // rate
strconv.FormatInt(req.Duration, 10), // period
"1", // cost
}
result := inspectScript.Exec(ctx, r.client, []string{req.key}, rateLimitParameters)
result := inspectScript.Exec(ctx, r.client, []string{req.Key}, rateLimitParameters)
limited, remaining, resetIn, err := r.parseRateLimitResult(&result)
if err != nil {
return 0, 0, err
}
resetTime = time.Now().Add(time.Duration(resetIn) * time.Millisecond).UnixMilli()
if limited {
return remaining, resetTime, errOverLimit
return remaining, resetTime, ErrOverLimit
}
return remaining, resetTime, nil
}
Expand Down
44 changes: 23 additions & 21 deletions ratelimit/redis_test.go → test/e2e/redis_rate_limiter_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package ratelimit
package e2e

import (
"context"
Expand All @@ -8,6 +8,8 @@ import (
"github.com/efficientgo/core/backoff"
"github.com/efficientgo/core/testutil"
"github.com/efficientgo/e2e"

"github.com/observatorium/api/ratelimit"
)

func TestRedisRateLimiter_GetRateLimits(t *testing.T) {
Expand All @@ -24,7 +26,7 @@ func TestRedisRateLimiter_GetRateLimits(t *testing.T) {

type args struct {
ctx context.Context
req *request
req *ratelimit.Request
}
tests := []struct {
name string
Expand All @@ -41,10 +43,10 @@ func TestRedisRateLimiter_GetRateLimits(t *testing.T) {
name: "Single hit, far from limit",
args: args{
ctx: context.Background(),
req: &request{
key: "single-hit",
limit: 10,
duration: (10 * time.Second).Milliseconds(),
req: &ratelimit.Request{
Key: "single-hit",
Limit: 10,
Duration: (10 * time.Second).Milliseconds(),
},
},
totalHits: 1,
Expand All @@ -57,10 +59,10 @@ func TestRedisRateLimiter_GetRateLimits(t *testing.T) {
name: "At the edge of the limit",
args: args{
ctx: context.Background(),
req: &request{
key: "edge-hit",
limit: 10,
duration: (10 * time.Second).Milliseconds(),
req: &ratelimit.Request{
Key: "edge-hit",
Limit: 10,
Duration: (10 * time.Second).Milliseconds(),
},
},
totalHits: 10,
Expand All @@ -73,15 +75,15 @@ func TestRedisRateLimiter_GetRateLimits(t *testing.T) {
name: "Beyond the limit",
args: args{
ctx: context.Background(),
req: &request{
key: "beyond-limit",
limit: 10,
duration: (10 * time.Second).Milliseconds(),
req: &ratelimit.Request{
Key: "beyond-limit",
Limit: 10,
Duration: (10 * time.Second).Milliseconds(),
},
},
totalHits: 11,
wantRemaining: 0,
wantErr: errOverLimit,
wantErr: ratelimit.ErrOverLimit,
wantResetTimeFunc: func() time.Time {
return time.Now().Add(10 * time.Second)
},
Expand All @@ -97,10 +99,10 @@ func TestRedisRateLimiter_GetRateLimits(t *testing.T) {
name: "Wait for 1 leak",
args: args{
ctx: context.Background(),
req: &request{
key: "wait-for-leak",
limit: 10,
duration: (10 * time.Second).Milliseconds(),
req: &ratelimit.Request{
Key: "wait-for-leak",
Limit: 10,
Duration: (10 * time.Second).Milliseconds(),
},
},
totalHits: 2,
Expand All @@ -125,10 +127,10 @@ func TestRedisRateLimiter_GetRateLimits(t *testing.T) {

var (
err error
r *RedisRateLimiter
r *ratelimit.RedisRateLimiter
)
for b.Reset(); b.Ongoing(); b.Wait() {
r, err = NewRedisRateLimiter([]string{redis.Endpoint("http")})
r, err = ratelimit.NewRedisRateLimiter([]string{redis.Endpoint("http")})
}
testutil.Ok(t, err)
testutil.Assert(t, r != nil)
Expand Down

0 comments on commit 8b041be

Please sign in to comment.