Skip to content

Commit

Permalink
fix: apply shared limiters before email / sms is sent
Browse files Browse the repository at this point in the history
  • Loading branch information
kangmingtay committed Aug 28, 2024
1 parent 7e38f4c commit 7c24aac
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 32 deletions.
19 changes: 19 additions & 0 deletions internal/api/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"net/url"

"github.com/didip/tollbooth/v5/limiter"
jwt "github.com/golang-jwt/jwt/v5"
"github.com/supabase/auth/internal/models"
)
Expand Down Expand Up @@ -31,6 +32,7 @@ const (
ssoProviderKey = contextKey("sso_provider")
externalHostKey = contextKey("external_host")
flowStateKey = contextKey("flow_state_id")
sharedLimiterKey = contextKey("shared_limiter")
)

// withToken adds the JWT token to the context.
Expand Down Expand Up @@ -241,3 +243,20 @@ func getExternalHost(ctx context.Context) *url.URL {
}
return obj.(*url.URL)
}

type SharedLimiter struct {
EmailLimiter *limiter.Limiter
PhoneLimiter *limiter.Limiter
}

func withLimiter(ctx context.Context, limiter *SharedLimiter) context.Context {
return context.WithValue(ctx, sharedLimiterKey, limiter)
}

func getLimiter(ctx context.Context) *SharedLimiter {
obj := ctx.Value(sharedLimiterKey)
if obj == nil {
return nil
}
return obj.(*SharedLimiter)
}
19 changes: 19 additions & 0 deletions internal/api/mail.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@ import (
"strings"
"time"

"github.com/didip/tollbooth/v5"
"github.com/supabase/auth/internal/hooks"
mail "github.com/supabase/auth/internal/mailer"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/metric"

"github.com/badoux/checkmail"
"github.com/fatih/structs"
Expand All @@ -21,6 +24,7 @@ import (

var (
MaxFrequencyLimitError error = errors.New("frequency limit reached")
EmailRateLimitExceeded error = errors.New("email rate limit exceeded")
)

type GenerateLinkParams struct {
Expand Down Expand Up @@ -572,6 +576,21 @@ func (a *API) sendEmail(r *http.Request, tx *storage.Connection, u *models.User,
config := a.config
referrerURL := utilities.GetReferrer(r, config)
externalURL := getExternalHost(ctx)

// apply rate limiting before the email is sent out
limiter := getLimiter(ctx)
if limiter == nil {
return errors.New("email limiter not found in context")
}
if err := tollbooth.LimitByKeys(limiter.EmailLimiter, []string{"email_functions"}); err != nil {
emailRateLimitCounter.Add(
ctx,
1,
metric.WithAttributeSet(attribute.NewSet(attribute.String("path", r.URL.Path))),
)
return EmailRateLimitExceeded
}

if config.Hook.SendEmail.Enabled {
emailData := mail.EmailData{
Token: otp,
Expand Down
36 changes: 5 additions & 31 deletions internal/api/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@ import (
"github.com/supabase/auth/internal/models"
"github.com/supabase/auth/internal/observability"
"github.com/supabase/auth/internal/security"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/metric"

"github.com/didip/tollbooth/v5"
"github.com/didip/tollbooth/v5/limiter"
Expand Down Expand Up @@ -99,35 +97,11 @@ func (a *API) limitEmailOrPhoneSentHandler() middlewareHandler {

if shouldRateLimitEmail || shouldRateLimitPhone {
if req.Method == "PUT" || req.Method == "POST" {
var requestBody struct {
Email string `json:"email"`
Phone string `json:"phone"`
}

if err := retrieveRequestParams(req, &requestBody); err != nil {
return c, err
}

if shouldRateLimitEmail {
if requestBody.Email != "" {
if err := tollbooth.LimitByKeys(emailLimiter, []string{"email_functions"}); err != nil {
emailRateLimitCounter.Add(
req.Context(),
1,
metric.WithAttributeSet(attribute.NewSet(attribute.String("path", req.URL.Path))),
)
return c, tooManyRequestsError(ErrorCodeOverEmailSendRateLimit, "Email rate limit exceeded")
}
}
}

if shouldRateLimitPhone {
if requestBody.Phone != "" {
if err := tollbooth.LimitByKeys(phoneLimiter, []string{"phone_functions"}); err != nil {
return c, tooManyRequestsError(ErrorCodeOverSMSSendRateLimit, "SMS rate limit exceeded")
}
}
}
// store rate limiter in request context
c = withLimiter(c, &SharedLimiter{
EmailLimiter: emailLimiter,
PhoneLimiter: phoneLimiter,
})
}
}

Expand Down
13 changes: 12 additions & 1 deletion internal/api/phone.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"text/template"
"time"

"github.com/didip/tollbooth/v5"
"github.com/supabase/auth/internal/hooks"

"github.com/pkg/errors"
Expand Down Expand Up @@ -44,6 +45,7 @@ func formatPhoneNumber(phone string) string {

// sendPhoneConfirmation sends an otp to the user's phone number
func (a *API) sendPhoneConfirmation(r *http.Request, tx *storage.Connection, user *models.User, phone, otpType string, channel string) (string, error) {
ctx := r.Context()
config := a.config

var token *string
Expand Down Expand Up @@ -84,7 +86,16 @@ func (a *API) sendPhoneConfirmation(r *http.Request, tx *storage.Connection, use
messageID = "test-otp"
}

if otp == "" { // not using test OTPs
// not using test OTPs
if otp == "" {
// apply rate limiting before the sms is sent out
limiter := getLimiter(ctx)
if limiter == nil {
return "", internalServerError("phone limiter not found in context")
}
if err := tollbooth.LimitByKeys(limiter.PhoneLimiter, []string{"phone_functions"}); err != nil {
return "", tooManyRequestsError(ErrorCodeOverSMSSendRateLimit, "SMS rate limit exceeded")
}
otp, err = crypto.GenerateOtp(config.Sms.OtpLength)
if err != nil {
return "", internalServerError("error generating otp").WithInternalError(err)
Expand Down

0 comments on commit 7c24aac

Please sign in to comment.