Skip to content

Commit

Permalink
fix: few fixes to the existing email rate limiter changes by Stojan
Browse files Browse the repository at this point in the history
Summary of changes:

 - I replaced the existing rate limitier made by Stojan with the
   `x/rate/limit` package from golang.org while trying to preserve
   the same behavior.
 - Fixed the tests that are failing with a small change in the
   helper function `setupAPIForTestWithCallback`.
 - Updated the call sites using limiters (mail.go, phone.go)
 - Added some basic test cases along with an example to help
   visualize rate limits.

Some small notes:

 - Setting the "Burst" value a little higher could be a consideration if
   the default of 1 is too restrictive. Adding Burst to conf.Rate for
   better control of the Burst is another option.
 - Using a value such as 100/24h is equivelant in functionality to the
   expression 1/14m, though slightly less clear. If the intent is to
   not limit the rate, but impose a _quota_ of 100 per 24 hours we may
   want to add some additional changes.
  • Loading branch information
Chris Stockton committed Sep 24, 2024
1 parent f05a4b7 commit 92f869c
Show file tree
Hide file tree
Showing 8 changed files with 158 additions and 81 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ require (
golang.org/x/sync v0.7.0 // indirect
golang.org/x/sys v0.21.0 // indirect
golang.org/x/text v0.16.0 // indirect
golang.org/x/time v0.0.0-20220411224347-583f2d630306 // indirect
golang.org/x/time v0.6.0 // indirect
google.golang.org/appengine v1.6.8 // indirect
google.golang.org/grpc v1.63.2 // indirect
google.golang.org/protobuf v1.33.0 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,8 @@ golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
golang.org/x/time v0.0.0-20160926182426-711ca1cb8763/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20220411224347-583f2d630306 h1:+gHMid33q6pen7kv9xvT+JRinntgeXO2AeZVd0AWD3w=
golang.org/x/time v0.0.0-20220411224347-583f2d630306/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.6.0 h1:eTDhh4ZXt5Qf0augr54TN6suAUudPcawVZeIAPU7D4U=
golang.org/x/time v0.6.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
Expand Down
12 changes: 5 additions & 7 deletions internal/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"github.com/supabase/auth/internal/storage"
"github.com/supabase/auth/internal/utilities"
"github.com/supabase/hibp"
"golang.org/x/time/rate"
)

const (
Expand All @@ -38,8 +39,8 @@ type API struct {
// overrideTime can be used to override the clock used by handlers. Should only be used in tests!
overrideTime func() time.Time

emailRateLimiter *SimpleRateLimiter
smsRateLimiter *SimpleRateLimiter
emailRateLimiter *rate.Limiter
smsRateLimiter *rate.Limiter
}

func (a *API) Now() time.Time {
Expand Down Expand Up @@ -72,11 +73,8 @@ func (a *API) deprecationNotices() {
// NewAPIWithVersion creates a new REST API using the specified version
func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Connection, version string) *API {
api := &API{config: globalConfig, db: db, version: version}

now := time.Now()

api.emailRateLimiter = NewSimpleRateLimiter(now, globalConfig.RateLimitEmailSent.DefaultOverTime(time.Hour))
api.smsRateLimiter = NewSimpleRateLimiter(now, globalConfig.RateLimitSmsSent.DefaultOverTime(time.Hour))
api.emailRateLimiter = newRateLimiter(globalConfig.RateLimitEmailSent)
api.smsRateLimiter = newRateLimiter(globalConfig.RateLimitSmsSent)

if api.config.Password.HIBP.Enabled {
httpClient := &http.Client{
Expand Down
5 changes: 4 additions & 1 deletion internal/api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,10 @@ func setupAPIForTestWithCallback(cb func(*conf.GlobalConfiguration, *storage.Con
cb(nil, conn)
}

return NewAPIWithVersion(config, conn, apiTestVersion), config, nil
a := NewAPIWithVersion(config, conn, apiTestVersion)
a.smsRateLimiter = newUnlimitedLimiter()
a.emailRateLimiter = newUnlimitedLimiter()
return a, config, nil
}

func TestEmailEnabledByDefault(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion internal/api/mail.go
Original file line number Diff line number Diff line change
Expand Up @@ -594,7 +594,7 @@ func (a *API) sendEmail(r *http.Request, tx *storage.Connection, u *models.User,
referrerURL := utilities.GetReferrer(r, config)
externalURL := getExternalHost(ctx)

if ok := a.emailRateLimiter.Increment(1); !ok {
if ok := a.emailRateLimiter.Allow(); !ok {
emailRateLimitCounter.Add(
ctx,
1,
Expand Down
2 changes: 1 addition & 1 deletion internal/api/phone.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ func (a *API) sendPhoneConfirmation(r *http.Request, tx *storage.Connection, use

// not using test OTPs
if otp == "" {
if ok := a.smsRateLimiter.Increment(1); !ok {
if ok := a.smsRateLimiter.Allow(); !ok {
return "", tooManyRequestsError(ErrorCodeOverSMSSendRateLimit, "SMS rate limit exceeded")
}

Expand Down
86 changes: 16 additions & 70 deletions internal/api/ratelimits.go
Original file line number Diff line number Diff line change
@@ -1,78 +1,24 @@
package api

import (
"fmt"
"sync/atomic"
"time"

"github.com/supabase/auth/internal/conf"
"golang.org/x/time/rate"
)

// SimpleRateLimiter holds a rate limiter that implements a token-bucket
// algorithm. Rate.OverTime is the duration at which the bucket is filled, and
// Rate.Events is the number of tokens in the bucket.
//
// Internally it uses an atomically increasing counter that resets to 0 on
// every OverTime tick.
//
// You should always use NewSimpleRateLimiter to create a new one.
type SimpleRateLimiter struct {
Rate conf.Rate

ticker *time.Ticker
counter uint64
}

// NewSimpleRateLimiter creates a new rate limiter starting at the specified
// time and with the specified Rate.
// newRateLimiter returns a rate limiter configured using the given conf.Rate.
//
// Initially the bucket is filled with a proprotion of the Rate.Events
// depending on how close to the Rate.OverTime tick it has been crated. This is
// one way of making sure that server restarts do not give out a too big of a
// rate limit, as the counter is reset.
func NewSimpleRateLimiter(now time.Time, rate conf.Rate) *SimpleRateLimiter {
r := &SimpleRateLimiter{
Rate: rate,
}

counterStartedAt := now.Truncate(rate.OverTime)
counterResetsAt := counterStartedAt.Add(rate.OverTime)

proRate := float64(counterStartedAt.Sub(now).Milliseconds()) / float64(rate.OverTime.Milliseconds())

r.counter = uint64(rate.Events * proRate)
r.ticker = time.NewTicker(counterResetsAt.Sub(now))

go r.fillBucket()

return r
}

func (r *SimpleRateLimiter) Increment(events uint64) bool {
fmt.Printf("@@@@@@@@@@@@@@@@@@@@@@@ %d %f\n", r.counter, r.Rate.Events)
return atomic.AddUint64(&r.counter, events) < uint64(r.Rate.Events)
}

func (r *SimpleRateLimiter) fillBucket() {
if _, ok := <-r.ticker.C; !ok {
return
}

// reset ticker to start ticking at the OverTime rate, as it was
// initially set up to tick at the next aligned OverTime event
r.ticker.Reset(r.Rate.OverTime)

// reset counter
atomic.StoreUint64(&r.counter, 0)

// then keep resetting at regular OverTime intervals
for range r.ticker.C {
atomic.StoreUint64(&r.counter, 0)
}
}

func (r *SimpleRateLimiter) Close() {
if r.ticker != nil {
r.ticker.Stop()
}
// The returned *rate.Limiter will be configured with a token bucket containing
// a single token, which will fill up at a rate of r. For example to allow 100
// events every 24 hours. This will fill a token bucket approximately once every
// 864 seconds (14.4 minutes). See Example_newRateLimiter for a visualization.
func newRateLimiter(r conf.Rate) *rate.Limiter {
// The rate limiter deals in events per second.
eps := r.EventsPerSecond()
const burst = 1

// NewLimiter will have an initial token bucket of size `burst`. It will
// be refilled at a rate of `eps` indefinitely. Note that the expression
// 100 / 24h is roughly equivelant to the expression 1 / 15m. The 100 is
// a rate, not a quota.
return rate.NewLimiter(rate.Limit(eps), burst)
}
128 changes: 128 additions & 0 deletions internal/api/ratelimits_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
package api

import (
"fmt"
"testing"
"time"

"github.com/supabase/auth/internal/conf"
"golang.org/x/time/rate"
)

func newUnlimitedLimiter() *rate.Limiter {
return rate.NewLimiter(rate.Inf, 0)
}

func Example_newRateLimiter() {
now, _ := time.Parse(time.RFC3339, "2024-09-24T10:00:00.00Z")
{
cfg := conf.Rate{Events: 100, OverTime: time.Hour * 24}
rl := newRateLimiter(cfg)
cur := now
for i := 0; i < 61; i++ {
fmt.Printf("%-5v @ %v\n", rl.AllowN(cur, 1), cur)
cur = cur.Add(time.Minute)
}
}

// Output:
// true @ 2024-09-24 10:00:00 +0000 UTC
// false @ 2024-09-24 10:01:00 +0000 UTC
// false @ 2024-09-24 10:02:00 +0000 UTC
// false @ 2024-09-24 10:03:00 +0000 UTC
// false @ 2024-09-24 10:04:00 +0000 UTC
// false @ 2024-09-24 10:05:00 +0000 UTC
// false @ 2024-09-24 10:06:00 +0000 UTC
// false @ 2024-09-24 10:07:00 +0000 UTC
// false @ 2024-09-24 10:08:00 +0000 UTC
// false @ 2024-09-24 10:09:00 +0000 UTC
// false @ 2024-09-24 10:10:00 +0000 UTC
// false @ 2024-09-24 10:11:00 +0000 UTC
// false @ 2024-09-24 10:12:00 +0000 UTC
// false @ 2024-09-24 10:13:00 +0000 UTC
// false @ 2024-09-24 10:14:00 +0000 UTC
// true @ 2024-09-24 10:15:00 +0000 UTC
// false @ 2024-09-24 10:16:00 +0000 UTC
// false @ 2024-09-24 10:17:00 +0000 UTC
// false @ 2024-09-24 10:18:00 +0000 UTC
// false @ 2024-09-24 10:19:00 +0000 UTC
// false @ 2024-09-24 10:20:00 +0000 UTC
// false @ 2024-09-24 10:21:00 +0000 UTC
// false @ 2024-09-24 10:22:00 +0000 UTC
// false @ 2024-09-24 10:23:00 +0000 UTC
// false @ 2024-09-24 10:24:00 +0000 UTC
// false @ 2024-09-24 10:25:00 +0000 UTC
// false @ 2024-09-24 10:26:00 +0000 UTC
// false @ 2024-09-24 10:27:00 +0000 UTC
// false @ 2024-09-24 10:28:00 +0000 UTC
// false @ 2024-09-24 10:29:00 +0000 UTC
// true @ 2024-09-24 10:30:00 +0000 UTC
// false @ 2024-09-24 10:31:00 +0000 UTC
// false @ 2024-09-24 10:32:00 +0000 UTC
// false @ 2024-09-24 10:33:00 +0000 UTC
// false @ 2024-09-24 10:34:00 +0000 UTC
// false @ 2024-09-24 10:35:00 +0000 UTC
// false @ 2024-09-24 10:36:00 +0000 UTC
// false @ 2024-09-24 10:37:00 +0000 UTC
// false @ 2024-09-24 10:38:00 +0000 UTC
// false @ 2024-09-24 10:39:00 +0000 UTC
// false @ 2024-09-24 10:40:00 +0000 UTC
// false @ 2024-09-24 10:41:00 +0000 UTC
// false @ 2024-09-24 10:42:00 +0000 UTC
// false @ 2024-09-24 10:43:00 +0000 UTC
// false @ 2024-09-24 10:44:00 +0000 UTC
// true @ 2024-09-24 10:45:00 +0000 UTC
// false @ 2024-09-24 10:46:00 +0000 UTC
// false @ 2024-09-24 10:47:00 +0000 UTC
// false @ 2024-09-24 10:48:00 +0000 UTC
// false @ 2024-09-24 10:49:00 +0000 UTC
// false @ 2024-09-24 10:50:00 +0000 UTC
// false @ 2024-09-24 10:51:00 +0000 UTC
// false @ 2024-09-24 10:52:00 +0000 UTC
// false @ 2024-09-24 10:53:00 +0000 UTC
// false @ 2024-09-24 10:54:00 +0000 UTC
// false @ 2024-09-24 10:55:00 +0000 UTC
// false @ 2024-09-24 10:56:00 +0000 UTC
// false @ 2024-09-24 10:57:00 +0000 UTC
// false @ 2024-09-24 10:58:00 +0000 UTC
// false @ 2024-09-24 10:59:00 +0000 UTC
// true @ 2024-09-24 11:00:00 +0000 UTC

}

func TestNewRateLimiter(t *testing.T) {
now, _ := time.Parse(time.RFC3339, "2024-09-24T10:00:00.00Z")

type event struct {
ok bool
at time.Time
}
cases := []struct {
cfg conf.Rate
now time.Time
evts []event
}{
{
cfg: conf.Rate{Events: 100, OverTime: time.Hour * 24},
now: now,
evts: []event{
{true, now},
{false, now.Add(time.Minute)},
{false, now.Add(time.Minute)},
{false, now.Add(time.Minute * 14)},
{true, now.Add(time.Minute * 15)},
{false, now.Add(time.Minute * 16)},
{false, now.Add(time.Minute * 17)},
{true, now.Add(time.Minute * 30)},
},
},
}
for _, tc := range cases {
rl := newRateLimiter(tc.cfg)
for _, evt := range tc.evts {
if exp, got := evt.ok, rl.AllowN(evt.at, 1); exp != got {
t.Fatalf("exp AllowN(%v, 1) to be %v; got %v", evt.at, exp, got)
}
}
}
}

0 comments on commit 92f869c

Please sign in to comment.