Skip to content

Commit

Permalink
fix: add new config option for email rate limiting
Browse files Browse the repository at this point in the history
  • Loading branch information
hf committed Aug 28, 2024
1 parent c6efec4 commit f498771
Show file tree
Hide file tree
Showing 5 changed files with 117 additions and 10 deletions.
5 changes: 1 addition & 4 deletions internal/api/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,9 @@ func (a *API) limitHandler(lmt *limiter.Limiter) middlewareHandler {

func (a *API) limitEmailOrPhoneSentHandler() middlewareHandler {
// limit per hour
emailFreq := a.config.RateLimitEmailSent / (60 * 60)
smsFreq := a.config.RateLimitSmsSent / (60 * 60)

emailLimiter := tollbooth.NewLimiter(emailFreq, &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(int(a.config.RateLimitEmailSent)).SetMethods([]string{"PUT", "POST"})
emailLimiter := a.config.RateLimitEmailSent.DivideIfDefaultDuration(60 * 60).CreateLimiter().SetBurst(int(a.config.RateLimitEmailSent.Events)).SetMethods([]string{"PUT", "POST"})

phoneLimiter := tollbooth.NewLimiter(smsFreq, &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
Expand Down
10 changes: 5 additions & 5 deletions internal/api/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ func (ts *MiddlewareTestSuite) TestVerifyCaptchaInvalid() {

func (ts *MiddlewareTestSuite) TestLimitEmailOrPhoneSentHandler() {
// Set up rate limit config for this test
ts.Config.RateLimitEmailSent = 5
ts.Config.RateLimitEmailSent = conf.Rate{Events: 5}
ts.Config.RateLimitSmsSent = 5
ts.Config.External.Phone.Enabled = true

Expand Down Expand Up @@ -419,7 +419,7 @@ func (ts *MiddlewareTestSuite) TestLimitHandlerWithSharedLimiter() {
{
desc: "Exceed ip-based rate limit before shared limiter",
sharedLimiterConfig: &conf.GlobalConfiguration{
RateLimitEmailSent: 10,
RateLimitEmailSent: conf.Rate{Events: 10},
RateLimitSmsSent: 10,
},
ipBasedLimiterConfig: 1,
Expand All @@ -431,7 +431,7 @@ func (ts *MiddlewareTestSuite) TestLimitHandlerWithSharedLimiter() {
{
desc: "Exceed email shared limiter",
sharedLimiterConfig: &conf.GlobalConfiguration{
RateLimitEmailSent: 1,
RateLimitEmailSent: conf.Rate{Events: 1},
RateLimitSmsSent: 1,
},
ipBasedLimiterConfig: 10,
Expand All @@ -443,7 +443,7 @@ func (ts *MiddlewareTestSuite) TestLimitHandlerWithSharedLimiter() {
{
desc: "Exceed sms shared limiter",
sharedLimiterConfig: &conf.GlobalConfiguration{
RateLimitEmailSent: 1,
RateLimitEmailSent: conf.Rate{Events: 1},
RateLimitSmsSent: 1,
},
ipBasedLimiterConfig: 10,
Expand All @@ -462,7 +462,7 @@ func (ts *MiddlewareTestSuite) TestLimitHandlerWithSharedLimiter() {
sharedLimiter := ts.API.limitEmailOrPhoneSentHandler()

// get the minimum amount to reach the threshold just before the rate limit is exceeded
threshold := min(c.sharedLimiterConfig.RateLimitEmailSent, c.sharedLimiterConfig.RateLimitSmsSent, c.ipBasedLimiterConfig)
threshold := min(c.sharedLimiterConfig.RateLimitEmailSent.Events, c.sharedLimiterConfig.RateLimitSmsSent, c.ipBasedLimiterConfig)
for i := 0; i < int(threshold); i++ {
var buffer bytes.Buffer
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.body))
Expand Down
2 changes: 1 addition & 1 deletion internal/conf/configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ type GlobalConfiguration struct {
Metrics MetricsConfig
SMTP SMTPConfiguration
RateLimitHeader string `split_words:"true"`
RateLimitEmailSent float64 `split_words:"true" default:"30"`
RateLimitEmailSent Rate `split_words:"true" default:"30"`
RateLimitSmsSent float64 `split_words:"true" default:"30"`
RateLimitVerify float64 `split_words:"true" default:"30"`
RateLimitTokenRefresh float64 `split_words:"true" default:"150"`
Expand Down
81 changes: 81 additions & 0 deletions internal/conf/rate.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package conf

import (
"fmt"
"strconv"
"strings"
"time"

"github.com/didip/tollbooth/v5"
"github.com/didip/tollbooth/v5/limiter"
)

type Rate struct {
Events float64 `json:"events,omitempty"`
OverTime time.Duration `json:"over_time,omitempty"`
}

func (r *Rate) EventsPerSecond() float64 {
if int64(r.OverTime) == 0 {
return r.Events
}

return r.Events / r.OverTime.Seconds()
}

func (r *Rate) DivideIfDefaultDuration(div float64) *Rate {
if r.OverTime == time.Duration(0) {
return &Rate{
Events: r.Events / div,
}
}

return r
}

func (r *Rate) CreateLimiter() *limiter.Limiter {
overTime := r.OverTime
if int64(overTime) == 0 {
// if r.OverTime is not specified, i.e. the configuration specified just a single float64 number, the
overTime = time.Hour
}

return tollbooth.NewLimiter(r.EventsPerSecond(), &limiter.ExpirableOptions{
DefaultExpirationTTL: overTime,
})
}

func (r *Rate) Decode(value string) error {
if f, err := strconv.ParseFloat(value, 64); err == nil {
r.Events = f
// r.OverTime remains 0 in this case
return nil
}
parts := strings.Split(value, "/")
if len(parts) != 2 {
return fmt.Errorf("rate: value does not match rate syntax %q", value)
}

f, err := strconv.ParseFloat(parts[0], 64)
if err != nil {
return fmt.Errorf("rate: events part of rate value %q failed to parse as float64: %w", value, err)
}

d, err := time.ParseDuration(parts[1])
if err != nil {
return fmt.Errorf("rate: over-time part of rate value %q failed to parse as duration: %w", value, err)
}

r.Events = f
r.OverTime = d

return nil
}

func (r *Rate) String() string {
if r.OverTime == 0 {
return fmt.Sprintf("%f", r.Events)
}

return fmt.Sprintf("%f/%s", r.Events, r.OverTime.String())
}
29 changes: 29 additions & 0 deletions internal/conf/rate_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package conf

import (
"testing"
"time"

"github.com/stretchr/testify/require"
)

func TestRateDecode(t *testing.T) {
r := Rate{}

r = Rate{}
require.NoError(t, r.Decode("123.0"))
require.Equal(t, r, Rate{Events: 123.0, OverTime: 0})

r = Rate{}
require.NoError(t, r.Decode("123.0/24h"))
require.Equal(t, r, Rate{Events: 123.0, OverTime: 24 * time.Hour})

r = Rate{}
require.Error(t, r.Decode("not a number"))

r = Rate{}
require.Error(t, r.Decode("123/456/789"))

r = Rate{}
require.Error(t, r.Decode("123/text"))
}

0 comments on commit f498771

Please sign in to comment.