From 076b25671b93de5c6d952ed4f4d9b6192316cdde Mon Sep 17 00:00:00 2001 From: Luc Date: Thu, 19 Dec 2024 10:35:21 +0100 Subject: [PATCH] Introduce improved rate limiting --- server/Cargo.lock | 15 ++++ server/Cargo.toml | 1 + server/src/http.rs | 147 ++++++++++++++++++++++++++++++-- server/src/state.rs | 3 + server/src/telemetry/metrics.rs | 20 ++++- 5 files changed, 180 insertions(+), 6 deletions(-) diff --git a/server/Cargo.lock b/server/Cargo.lock index b27c397..46b3a1d 100644 --- a/server/Cargo.lock +++ b/server/Cargo.lock @@ -972,6 +972,20 @@ dependencies = [ "syn 2.0.72", ] +[[package]] +name = "dashmap" +version = "6.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" +dependencies = [ + "cfg-if", + "crossbeam-utils", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core", +] + [[package]] name = "data-encoding" version = "2.6.0" @@ -1200,6 +1214,7 @@ dependencies = [ "cid", "crc16", "crc32fast", + "dashmap", "digest", "dotenvy", "enstate_shared", diff --git a/server/Cargo.toml b/server/Cargo.toml index c038dbc..f5f7ab5 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -73,3 +73,4 @@ cid = "0.11.1" url = "2.5.2" prometheus = "0.13.4" time = "0.3.36" +dashmap = "6.1.0" diff --git a/server/src/http.rs b/server/src/http.rs index 4a8bed2..4d563a7 100644 --- a/server/src/http.rs +++ b/server/src/http.rs @@ -1,9 +1,13 @@ -use axum::extract::MatchedPath; -use axum::http::Request; -use axum::response::{Html, Redirect}; +use axum::extract::{MatchedPath, State}; +use axum::http::{Request, StatusCode}; +use axum::middleware::{self, Next}; +use axum::response::{Html, Redirect, Response}; +use std::env; use std::{net::SocketAddr, sync::Arc}; use axum::{routing::get, Router}; +use dashmap::DashMap; +use std::time::{Duration, Instant}; use tokio::net::TcpListener; use tokio_util::sync::CancellationToken; use tower_http::cors::CorsLayer; @@ -14,6 +18,133 @@ use crate::routes; use crate::state::AppState; use crate::telemetry::metrics::{self}; +// Add this struct to hold rate limit configuration +#[derive(Clone)] +struct RateLimit { + requests: u32, + window: Duration, +} + +// Add this struct to track rate limiting state +struct RateLimitState { + last_reset: Instant, + count: u32, +} + +// Add this to your AppState +pub struct RateLimiter { + limits: DashMap, + states: DashMap<(String, String), RateLimitState>, // (path, ip) -> state +} + +impl RateLimiter { + pub fn new() -> Self { + let limits = DashMap::new(); + + if env::var("RATE_LIMIT_ENABLED").unwrap_or_else(|_| "false".to_owned()) == "true" { + limits.insert( + "/n/:name".to_string(), + RateLimit { + requests: 60, + window: Duration::from_secs(60), + }, + ); + limits.insert( + "/a/:address".to_string(), + RateLimit { + requests: 60, + window: Duration::from_secs(60), + }, + ); + limits.insert( + "/bulk/a".to_string(), + RateLimit { + requests: 10, + window: Duration::from_secs(60), + }, + ); + limits.insert( + "/bulk/n".to_string(), + RateLimit { + requests: 10, + window: Duration::from_secs(60), + }, + ); + } + + Self { + limits, + states: DashMap::new(), + } + } +} + +// Add rate limiting middleware +async fn rate_limit_middleware( + State(state): State>, + req: Request, + next: Next, +) -> Result { + let ip = req + .headers() + .get("x-forwarded-for") + .and_then(|hv| hv.to_str().ok()) + .unwrap_or("unknown") + .to_string(); + + let path = req + .extensions() + .get::() + .map(MatchedPath::as_str) + .unwrap_or("/") + .to_string(); + + let rate_limiter = &state.rate_limiter; + + if let Some(limit) = rate_limiter.limits.get(&path) { + let key = (path.clone(), ip.clone()); + let now = Instant::now(); + + let mut exceeded = false; + + if env::var("RATE_LIMIT_ENABLED").unwrap_or_else(|_| "false".to_owned()) == "true" { + info!("Rate limit for {} is {}", path, ip); + } + + rate_limiter + .states + .entry(key) + .and_modify(|state| { + if now.duration_since(state.last_reset) >= limit.window { + state.count = 1; + state.last_reset = now; + info!("Rate limit reset for {}", path); + } else if state.count >= limit.requests { + info!("Rate limit exceeded for {}", path); + exceeded = true; + } else { + state.count += 1; + } + }) + .or_insert(RateLimitState { + last_reset: now, + count: 1, + }); + + if exceeded { + state + .metrics + .rate_limit_infringements + .with_label_values(&[&ip]) + .inc(); + + return Err(StatusCode::TOO_MANY_REQUESTS); + } + } + + Ok(next.run(req).await) +} + pub struct App { router: Router, } @@ -45,13 +176,15 @@ impl App { } } -pub fn setup(state: AppState) -> App { +pub fn setup(mut state: AppState) -> App { let docs = Router::new() .route("/openapi.json", get(crate::docs::openapi)) .route("/", get(scalar_handler)) .route("/favicon.png", get(scalar_favicon_handler)) .route("/opengraph.png", get(scalar_opengraph_handler)); + let state = Arc::new(state); + let router = Router::new() .route("/", get(|| async { Redirect::temporary("/docs") })) .nest("/docs", docs) @@ -78,6 +211,10 @@ pub fn setup(state: AppState) -> App { ) .route("/metrics", get(metrics::handle)) .fallback(routes::four_oh_four::handler) + .layer(middleware::from_fn_with_state( + state.clone(), + rate_limit_middleware, + )) .layer(CorsLayer::permissive()) .layer( TraceLayer::new_for_http().make_span_with(|request: &Request<_>| { @@ -96,7 +233,7 @@ pub fn setup(state: AppState) -> App { ) }), ) - .with_state(Arc::new(state)); + .with_state(state); App { router } } diff --git a/server/src/state.rs b/server/src/state.rs index 77bd6c0..653b145 100644 --- a/server/src/state.rs +++ b/server/src/state.rs @@ -11,6 +11,7 @@ use ethers_core::types::H160; use tracing::{info, warn}; use url::Url; +use crate::http::RateLimiter; use crate::provider::RoundRobin; use crate::telemetry::metrics::Metrics; use crate::{cache, database}; @@ -19,6 +20,7 @@ use crate::{cache, database}; pub struct AppState { pub service: ENSService, pub metrics: Metrics, + pub rate_limiter: RateLimiter, } impl AppState { @@ -88,6 +90,7 @@ impl AppState { env::var("PROFILE_CACHE_TTL").map_or(Some(600), |cache_ttl| cache_ttl.parse().ok()); Self { + rate_limiter: RateLimiter::new(), service: ENSService { cache, rpc: Box::new(provider), diff --git a/server/src/telemetry/metrics.rs b/server/src/telemetry/metrics.rs index c526fed..5c0df74 100644 --- a/server/src/telemetry/metrics.rs +++ b/server/src/telemetry/metrics.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use axum::{extract::State, response::IntoResponse}; -use prometheus::{Counter, Encoder, Histogram, Registry, TextEncoder}; +use prometheus::{Counter, CounterVec, Encoder, Histogram, Registry, TextEncoder}; #[derive(Clone)] pub struct Metrics { @@ -9,6 +9,8 @@ pub struct Metrics { pub name_lookup_total: Counter, pub name_lookup_latency: Histogram, + + pub rate_limit_infringements: CounterVec, } impl Metrics { @@ -34,10 +36,26 @@ impl Metrics { .register(Box::new(name_lookup_latency.clone())) .unwrap(); + let rate_limit_infringements_opts = prometheus::Opts::new( + "rate_limit_infringements", + "Total number of rate limit infringements", + ); + + let rate_limit_infringements = CounterVec::new(rate_limit_infringements_opts, &["ip"]).unwrap(); + registry + .register(Box::new(rate_limit_infringements.clone())) + .unwrap(); + + // let rate_limit_infringements = Counter::with_opts(rate_limit_infringements_opts).unwrap(); + // registry + // .register(Box::new(rate_limit_infringements.clone())) + // .unwrap(); + Self { registry, name_lookup_total, name_lookup_latency, + rate_limit_infringements, } } }