From 7c24aac324ed44f829e684fefade206186c10fc1 Mon Sep 17 00:00:00 2001 From: Kang Ming Date: Wed, 28 Aug 2024 12:01:15 -0700 Subject: [PATCH] fix: apply shared limiters before email / sms is sent --- internal/api/context.go | 19 +++++++++++++++++++ internal/api/mail.go | 19 +++++++++++++++++++ internal/api/middleware.go | 36 +++++------------------------------- internal/api/phone.go | 13 ++++++++++++- 4 files changed, 55 insertions(+), 32 deletions(-) diff --git a/internal/api/context.go b/internal/api/context.go index 3047f3dd6..ff01e7120 100644 --- a/internal/api/context.go +++ b/internal/api/context.go @@ -4,6 +4,7 @@ import ( "context" "net/url" + "github.com/didip/tollbooth/v5/limiter" jwt "github.com/golang-jwt/jwt/v5" "github.com/supabase/auth/internal/models" ) @@ -31,6 +32,7 @@ const ( ssoProviderKey = contextKey("sso_provider") externalHostKey = contextKey("external_host") flowStateKey = contextKey("flow_state_id") + sharedLimiterKey = contextKey("shared_limiter") ) // withToken adds the JWT token to the context. @@ -241,3 +243,20 @@ func getExternalHost(ctx context.Context) *url.URL { } return obj.(*url.URL) } + +type SharedLimiter struct { + EmailLimiter *limiter.Limiter + PhoneLimiter *limiter.Limiter +} + +func withLimiter(ctx context.Context, limiter *SharedLimiter) context.Context { + return context.WithValue(ctx, sharedLimiterKey, limiter) +} + +func getLimiter(ctx context.Context) *SharedLimiter { + obj := ctx.Value(sharedLimiterKey) + if obj == nil { + return nil + } + return obj.(*SharedLimiter) +} diff --git a/internal/api/mail.go b/internal/api/mail.go index 35f529e25..7c66f6f45 100644 --- a/internal/api/mail.go +++ b/internal/api/mail.go @@ -5,8 +5,11 @@ import ( "strings" "time" + "github.com/didip/tollbooth/v5" "github.com/supabase/auth/internal/hooks" mail "github.com/supabase/auth/internal/mailer" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/metric" "github.com/badoux/checkmail" "github.com/fatih/structs" @@ -21,6 +24,7 @@ import ( var ( MaxFrequencyLimitError error = errors.New("frequency limit reached") + EmailRateLimitExceeded error = errors.New("email rate limit exceeded") ) type GenerateLinkParams struct { @@ -572,6 +576,21 @@ func (a *API) sendEmail(r *http.Request, tx *storage.Connection, u *models.User, config := a.config referrerURL := utilities.GetReferrer(r, config) externalURL := getExternalHost(ctx) + + // apply rate limiting before the email is sent out + limiter := getLimiter(ctx) + if limiter == nil { + return errors.New("email limiter not found in context") + } + if err := tollbooth.LimitByKeys(limiter.EmailLimiter, []string{"email_functions"}); err != nil { + emailRateLimitCounter.Add( + ctx, + 1, + metric.WithAttributeSet(attribute.NewSet(attribute.String("path", r.URL.Path))), + ) + return EmailRateLimitExceeded + } + if config.Hook.SendEmail.Enabled { emailData := mail.EmailData{ Token: otp, diff --git a/internal/api/middleware.go b/internal/api/middleware.go index 972ac5ab3..aa2c3e9ff 100644 --- a/internal/api/middleware.go +++ b/internal/api/middleware.go @@ -15,8 +15,6 @@ import ( "github.com/supabase/auth/internal/models" "github.com/supabase/auth/internal/observability" "github.com/supabase/auth/internal/security" - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/metric" "github.com/didip/tollbooth/v5" "github.com/didip/tollbooth/v5/limiter" @@ -99,35 +97,11 @@ func (a *API) limitEmailOrPhoneSentHandler() middlewareHandler { if shouldRateLimitEmail || shouldRateLimitPhone { if req.Method == "PUT" || req.Method == "POST" { - var requestBody struct { - Email string `json:"email"` - Phone string `json:"phone"` - } - - if err := retrieveRequestParams(req, &requestBody); err != nil { - return c, err - } - - if shouldRateLimitEmail { - if requestBody.Email != "" { - if err := tollbooth.LimitByKeys(emailLimiter, []string{"email_functions"}); err != nil { - emailRateLimitCounter.Add( - req.Context(), - 1, - metric.WithAttributeSet(attribute.NewSet(attribute.String("path", req.URL.Path))), - ) - return c, tooManyRequestsError(ErrorCodeOverEmailSendRateLimit, "Email rate limit exceeded") - } - } - } - - if shouldRateLimitPhone { - if requestBody.Phone != "" { - if err := tollbooth.LimitByKeys(phoneLimiter, []string{"phone_functions"}); err != nil { - return c, tooManyRequestsError(ErrorCodeOverSMSSendRateLimit, "SMS rate limit exceeded") - } - } - } + // store rate limiter in request context + c = withLimiter(c, &SharedLimiter{ + EmailLimiter: emailLimiter, + PhoneLimiter: phoneLimiter, + }) } } diff --git a/internal/api/phone.go b/internal/api/phone.go index 8e7d39e63..63201b624 100644 --- a/internal/api/phone.go +++ b/internal/api/phone.go @@ -8,6 +8,7 @@ import ( "text/template" "time" + "github.com/didip/tollbooth/v5" "github.com/supabase/auth/internal/hooks" "github.com/pkg/errors" @@ -44,6 +45,7 @@ func formatPhoneNumber(phone string) string { // sendPhoneConfirmation sends an otp to the user's phone number func (a *API) sendPhoneConfirmation(r *http.Request, tx *storage.Connection, user *models.User, phone, otpType string, channel string) (string, error) { + ctx := r.Context() config := a.config var token *string @@ -84,7 +86,16 @@ func (a *API) sendPhoneConfirmation(r *http.Request, tx *storage.Connection, use messageID = "test-otp" } - if otp == "" { // not using test OTPs + // not using test OTPs + if otp == "" { + // apply rate limiting before the sms is sent out + limiter := getLimiter(ctx) + if limiter == nil { + return "", internalServerError("phone limiter not found in context") + } + if err := tollbooth.LimitByKeys(limiter.PhoneLimiter, []string{"phone_functions"}); err != nil { + return "", tooManyRequestsError(ErrorCodeOverSMSSendRateLimit, "SMS rate limit exceeded") + } otp, err = crypto.GenerateOtp(config.Sms.OtpLength) if err != nil { return "", internalServerError("error generating otp").WithInternalError(err)