Skip to content

Commit

Permalink
Introduce improved rate limiting
Browse files Browse the repository at this point in the history
  • Loading branch information
lucemans committed Dec 19, 2024
1 parent 0bccdb5 commit 076b256
Show file tree
Hide file tree
Showing 5 changed files with 180 additions and 6 deletions.
15 changes: 15 additions & 0 deletions server/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,4 @@ cid = "0.11.1"
url = "2.5.2"
prometheus = "0.13.4"
time = "0.3.36"
dashmap = "6.1.0"
147 changes: 142 additions & 5 deletions server/src/http.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
use axum::extract::MatchedPath;
use axum::http::Request;
use axum::response::{Html, Redirect};
use axum::extract::{MatchedPath, State};
use axum::http::{Request, StatusCode};
use axum::middleware::{self, Next};
use axum::response::{Html, Redirect, Response};
use std::env;
use std::{net::SocketAddr, sync::Arc};

use axum::{routing::get, Router};
use dashmap::DashMap;
use std::time::{Duration, Instant};
use tokio::net::TcpListener;
use tokio_util::sync::CancellationToken;
use tower_http::cors::CorsLayer;
Expand All @@ -14,6 +18,133 @@ use crate::routes;
use crate::state::AppState;
use crate::telemetry::metrics::{self};

// Add this struct to hold rate limit configuration
#[derive(Clone)]
struct RateLimit {
requests: u32,
window: Duration,
}

// Add this struct to track rate limiting state
struct RateLimitState {
last_reset: Instant,
count: u32,
}

// Add this to your AppState
pub struct RateLimiter {
limits: DashMap<String, RateLimit>,
states: DashMap<(String, String), RateLimitState>, // (path, ip) -> state
}

impl RateLimiter {
pub fn new() -> Self {
let limits = DashMap::new();

if env::var("RATE_LIMIT_ENABLED").unwrap_or_else(|_| "false".to_owned()) == "true" {
limits.insert(
"/n/:name".to_string(),
RateLimit {
requests: 60,
window: Duration::from_secs(60),
},
);
limits.insert(
"/a/:address".to_string(),
RateLimit {
requests: 60,
window: Duration::from_secs(60),
},
);
limits.insert(
"/bulk/a".to_string(),
RateLimit {
requests: 10,
window: Duration::from_secs(60),
},
);
limits.insert(
"/bulk/n".to_string(),
RateLimit {
requests: 10,
window: Duration::from_secs(60),
},
);
}

Self {
limits,
states: DashMap::new(),
}
}
}

// Add rate limiting middleware
async fn rate_limit_middleware(
State(state): State<Arc<AppState>>,
req: Request<axum::body::Body>,
next: Next,
) -> Result<Response, StatusCode> {
let ip = req
.headers()
.get("x-forwarded-for")
.and_then(|hv| hv.to_str().ok())
.unwrap_or("unknown")
.to_string();

let path = req
.extensions()
.get::<MatchedPath>()
.map(MatchedPath::as_str)
.unwrap_or("/")
.to_string();

let rate_limiter = &state.rate_limiter;

if let Some(limit) = rate_limiter.limits.get(&path) {
let key = (path.clone(), ip.clone());
let now = Instant::now();

let mut exceeded = false;

if env::var("RATE_LIMIT_ENABLED").unwrap_or_else(|_| "false".to_owned()) == "true" {
info!("Rate limit for {} is {}", path, ip);
}

rate_limiter
.states
.entry(key)
.and_modify(|state| {
if now.duration_since(state.last_reset) >= limit.window {
state.count = 1;
state.last_reset = now;
info!("Rate limit reset for {}", path);
} else if state.count >= limit.requests {
info!("Rate limit exceeded for {}", path);
exceeded = true;
} else {
state.count += 1;
}
})
.or_insert(RateLimitState {
last_reset: now,
count: 1,
});

if exceeded {
state
.metrics
.rate_limit_infringements
.with_label_values(&[&ip])
.inc();

return Err(StatusCode::TOO_MANY_REQUESTS);
}
}

Ok(next.run(req).await)
}

pub struct App {
router: Router,
}
Expand Down Expand Up @@ -45,13 +176,15 @@ impl App {
}
}

pub fn setup(state: AppState) -> App {
pub fn setup(mut state: AppState) -> App {
let docs = Router::new()
.route("/openapi.json", get(crate::docs::openapi))
.route("/", get(scalar_handler))
.route("/favicon.png", get(scalar_favicon_handler))
.route("/opengraph.png", get(scalar_opengraph_handler));

let state = Arc::new(state);

let router = Router::new()
.route("/", get(|| async { Redirect::temporary("/docs") }))
.nest("/docs", docs)
Expand All @@ -78,6 +211,10 @@ pub fn setup(state: AppState) -> App {
)
.route("/metrics", get(metrics::handle))
.fallback(routes::four_oh_four::handler)
.layer(middleware::from_fn_with_state(
state.clone(),
rate_limit_middleware,
))
.layer(CorsLayer::permissive())
.layer(
TraceLayer::new_for_http().make_span_with(|request: &Request<_>| {
Expand All @@ -96,7 +233,7 @@ pub fn setup(state: AppState) -> App {
)
}),
)
.with_state(Arc::new(state));
.with_state(state);

App { router }
}
Expand Down
3 changes: 3 additions & 0 deletions server/src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use ethers_core::types::H160;
use tracing::{info, warn};
use url::Url;

use crate::http::RateLimiter;
use crate::provider::RoundRobin;
use crate::telemetry::metrics::Metrics;
use crate::{cache, database};
Expand All @@ -19,6 +20,7 @@ use crate::{cache, database};
pub struct AppState {
pub service: ENSService,
pub metrics: Metrics,
pub rate_limiter: RateLimiter,
}

impl AppState {
Expand Down Expand Up @@ -88,6 +90,7 @@ impl AppState {
env::var("PROFILE_CACHE_TTL").map_or(Some(600), |cache_ttl| cache_ttl.parse().ok());

Self {
rate_limiter: RateLimiter::new(),
service: ENSService {
cache,
rpc: Box::new(provider),
Expand Down
20 changes: 19 additions & 1 deletion server/src/telemetry/metrics.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
use std::sync::Arc;

use axum::{extract::State, response::IntoResponse};
use prometheus::{Counter, Encoder, Histogram, Registry, TextEncoder};
use prometheus::{Counter, CounterVec, Encoder, Histogram, Registry, TextEncoder};

#[derive(Clone)]
pub struct Metrics {
pub registry: Registry,

pub name_lookup_total: Counter,
pub name_lookup_latency: Histogram,

pub rate_limit_infringements: CounterVec,
}

impl Metrics {
Expand All @@ -34,10 +36,26 @@ impl Metrics {
.register(Box::new(name_lookup_latency.clone()))
.unwrap();

let rate_limit_infringements_opts = prometheus::Opts::new(
"rate_limit_infringements",
"Total number of rate limit infringements",
);

let rate_limit_infringements = CounterVec::new(rate_limit_infringements_opts, &["ip"]).unwrap();
registry
.register(Box::new(rate_limit_infringements.clone()))
.unwrap();

// let rate_limit_infringements = Counter::with_opts(rate_limit_infringements_opts).unwrap();
// registry
// .register(Box::new(rate_limit_infringements.clone()))
// .unwrap();

Self {
registry,
name_lookup_total,
name_lookup_latency,
rate_limit_infringements,
}
}
}
Expand Down

0 comments on commit 076b256

Please sign in to comment.