From b8dd2be589e6dd200378b7236970cb956adcf907 Mon Sep 17 00:00:00 2001 From: Louis-Marie Baer Date: Fri, 14 Jun 2024 13:59:55 +0200 Subject: [PATCH] feat: initial code for the caching proxy no test nor logs yet but compiles. --- Cargo.toml | 8 +- README.md | 6 +- src/config.rs | 30 ++++-- src/main.rs | 287 +++++++++++++++++++++++++++++++++++++++++++++++++- 4 files changed, 318 insertions(+), 13 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index b723400..4b3d968 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,9 +10,13 @@ serde = { version = "1", features = ["derive"]} anyhow = "1.0" tracing = "0.1" tracing-subscriber = "0.3" -axum = {version="0.7", default-features= false, features= ["tokio", "http2"] } -tokio = {version="1", default-features=false, features= ["rt-multi-thread", "sync"] } +axum = {version="0.7", default-features= false, features= ["tokio", "http2", "macros"] } +tokio = {version="1", default-features=false, features= ["rt-multi-thread", "sync", "macros"] } reqwest = {version="0.12", default-features=false, features=["rustls-tls", "http2"]} url = {version="2.5.0", features=["serde"]} moka = {version="0.12", features=["future"]} ahash = "0.8" +uuid = {version="1.8", features=["v4", "fast-rng"]} +nohash = "0.2" +derive_more = {version="0.99", default-features=false, features=["deref", "deref_mut"]} +enclose = "1.2" diff --git a/README.md b/README.md index a496ab7..4c6d077 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,8 @@ For example, the code will always be safe even if unsafe could bring more perfom - configuration file - multiple endpoints possible - cache invalidation api -- well thought expiration of cache -- use etag and and non-modified headers. Etag takes into account Vary header from server. 
+- well thought expiration of cache (thanks moka) +- add etag header +- return non modified status when client has a valid etag +- takes into account Vary header from server (will save different cache object for every variation of the specified header) - let server decide his own caching controls. diff --git a/src/config.rs b/src/config.rs index 3eecb3d..bf51082 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,4 +1,7 @@ -use std::time::Duration; +use std::{ + net::{Ipv4Addr, SocketAddr, SocketAddrV4}, + time::Duration, +}; use reqwest::Url; use serde::{Deserialize, Serialize}; @@ -6,20 +9,33 @@ use serde::{Deserialize, Serialize}; /// Example: /// listen_port: 9834, /// endpoints: [("/api1", "127.0.0.1:3998")] -#[derive(Serialize, Deserialize, Default)] +/// request /api1/abc +/// will do 127.0.0.1:3998/abc +#[derive(Serialize, Deserialize, Clone)] pub struct Config { - /// port to which Mnemosyne will listen for incoming requests. - pub listen_port: u16, + /// address and port to which Mnemosyne will listen for incoming requests. + pub listen_address: SocketAddr, /// String is the path mnemosyne will accept request and redirect them to Url pub endpoints: Vec<(String, Url)>, + /// if none of the requests contained a recognized uri, or if you want to redirect every request to one backend. 
+ pub fall_back_endpoint: Url, /// cache backend configuration pub cache: CacheConfig, } -#[derive(Serialize, Deserialize, Default)] +impl Default for Config { + fn default() -> Self { + Self { + listen_address: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 9830)), + endpoints: Default::default(), + cache: Default::default(), + fall_back_endpoint: Url::parse("http://127.0.0.1:1000").unwrap(), + } + } +} + +#[derive(Serialize, Deserialize, Default, Clone)] pub struct CacheConfig { /// cache expiration after last request pub expiration: Duration, - /// maximum cache entry - pub max_entry: u64, } diff --git a/src/main.rs b/src/main.rs index e7a11a9..107b469 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,3 +1,286 @@ -fn main() { - println!("Hello, world!"); +use ahash::{HashMap, HashMapExt}; +use anyhow::Result; +use axum::{ + body::{to_bytes, Bytes}, + extract::{Path, Request, State}, + http::{HeaderMap, HeaderValue, Uri}, + response::IntoResponse, + routing::delete, + Router, +}; +use config::Config; +use derive_more::{Deref, DerefMut}; +use enclose::enc; +use moka::future::Cache as MokaCache; +use reqwest::{ + header::{ETAG, VARY}, + Client, Method, StatusCode, +}; +use std::{str::FromStr, sync::Arc}; +use tokio::{spawn, sync::Mutex}; +use url::Url; +use uuid::Uuid; + +/// configuration from file +mod config; +#[derive(Clone)] +struct AppState { + config: Config, + // option HeaderMap is the header request that needs to be present. + // the response will contains a Vary Header in this case. + // one method and uri can contain multiple different response based on headers, so we use a Vec per entry since the id of the entry is based on uri and method. + cache: Cache, + index_cache: Arc>, + client: Client, } +#[derive(Deref, DerefMut, Clone)] +/// IndexCache will store entry for each combination of uri/method with a vec of uuid per HeaderMap. HeaderMap here are request headers that match the headers name in the Vary header value response. 
+struct IndexCache(HashMap<(Method, Uri), Vec<(Uuid, HeaderMap)>>); +#[derive(Deref, DerefMut, Clone)] +struct Cache(MokaCache); + +impl Cache { + fn new(config: &Config) -> Cache { + Self( + MokaCache::builder() + .name("mnemosyne") + .time_to_idle(config.cache.expiration) + .build_with_hasher(ahash::RandomState::new()), + ) + } + fn check_etag(&self, headers: &HeaderMap) -> bool { + if let Some(etag) = headers.get("Etag") { + if let Ok(str) = etag.to_str() { + if let Ok(uuid) = Uuid::from_str(str) { + return self.contains_key(&uuid); + } + } + } + false + } +} + +/// from a request, keep only headers that are present in Vary response header +fn headers_match_vary( + request_headers: &HeaderMap, + vary_header: Option<&HeaderValue>, +) -> Result { + if let Some(vary) = vary_header { + let mut h_vary = vary.to_str()?.split(','); + let mut headers = HeaderMap::new(); + request_headers + .iter() + .filter(|h_req| h_vary.any(|name| name == h_req.0.as_str())) + .for_each(|header| { + headers.insert(header.0, header.1.clone()); + }); + Ok(headers) + } else { + Ok(HeaderMap::new()) + } +} + +impl IndexCache { + fn new() -> Self { + IndexCache(HashMap::new()) + } + fn add_entry( + &mut self, + uuid: Uuid, + req_method: Method, + req_uri: Uri, + req_headers_match_vary: HeaderMap, + ) { + let key = (req_method, req_uri); + let value = (uuid, req_headers_match_vary); + // check if entry exist for method/uri + + if let Some(v) = self.get_mut(&key) { + // if entry exist, push into vec + v.push(value); + } else { + // if no entries, create one. + self.insert(key, vec![value]); + } + } + /// will search for an entry in cache based on a request. Will check that request headers includes the ones associated in this entry if any. + /// Will return the uuid of the entry. 
+ fn request_to_uuid(&self, request: &Request) -> Option { + let method = request.method().to_owned(); + let uri = request.uri().to_owned(); + let headermap = request.headers(); + if let Some(uuids) = self.get(&(method, uri)) { + return uuids + .iter() + .find(|(_, headermap_object)| { + headermap_object + .iter() + .all(|x| headermap.get(x.0).is_some_and(|value| value == x.1)) + }) + .map(|v| v.0); + } + None + } + fn delete_uuid_from_index(&mut self, uuid: &Uuid) { + // remove uuid entry from vec + self.iter_mut().for_each(|v| v.1.retain(|c| &c.0 != uuid)); + // check if the entry for method/uri is now empty and delete it if that's the case. + let key = self.iter().find_map(|(key, value)| { + if value.is_empty() { + Some(key.to_owned()) + } else { + None + } + }); + if let Some(key) = key { + self.remove(&key); + } + } +} + +#[tokio::main] +async fn main() -> Result<()> { + // load config + let config = confy::load_path::("/etc/mnemosyne")?; + // create cache moka + // create state app + let listen = config.listen_address; + let state = AppState { + cache: Cache::new(&config), + config, + index_cache: Arc::new(Mutex::new(IndexCache::new())), + client: Client::new(), + }; + // create route for cache API + let route = Router::new() + .route("/delete/:uuid", delete(delete_entry)) + .route("/delete_all", delete(delete_all)) + .fallback(handler) + .with_state(state); + let listener = tokio::net::TcpListener::bind(listen).await?; + axum::serve(listener, route.into_make_service()).await?; + // create listener for all endpoints + + Ok(()) +} + +// handle delete endpoint +// will also delete from index by iterating over the entries to find the method/path +async fn delete_entry( + Path(path): Path, + State(state): State, +) -> impl IntoResponse { + if let Ok(uuid) = Uuid::from_str(&path) { + state.cache.invalidate(&uuid).await; + state.index_cache.lock().await.delete_uuid_from_index(&uuid); + return StatusCode::OK; + } + StatusCode::NOT_FOUND +} +// handle delete_all 
endpoint +async fn delete_all(State(state): State) -> impl IntoResponse { + state.cache.invalidate_all(); + *state.index_cache.lock().await = IndexCache::new(); + StatusCode::OK +} + +// handle request +#[axum::debug_handler] +async fn handler(State(state): State, request: Request) -> impl IntoResponse { + // check if etag is present in headers + if state.cache.check_etag(request.headers()) { + // respond 304 if etag is present in cache + return StatusCode::NOT_MODIFIED.into_response(); + } + + // if response is in cache with valid header if any, return response from cache + + if let Some(uuid) = state.index_cache.lock().await.request_to_uuid(&request) { + let rep = state + .cache + .get(&uuid) + .await + .expect("a value should be there if index has one"); + // Body can not be saved in Cache so we save Bytes and convert to body when we need it. + return rep.into_response(); + } + + // if not in cache, make the request to backend service + let req_method = request.method().to_owned(); + let req_headers = request.headers().to_owned(); + let req_uri = request.uri().to_owned(); + match state + .client + .request( + request.method().to_owned(), + to_backend_uri(&state.config, request.uri()), + ) + .headers(request.headers().to_owned()) + .body(to_bytes(request.into_body(), usize::MAX).await.unwrap()) + .send() + .await + { + Ok(mut rep) => { + // first send Response and then cache so the client waits as little as possible. 
+ // need to add Etag headers to response + let uuid = Uuid::new_v4(); + + let index = state.index_cache.clone(); + let cache = state.cache.clone(); + rep.headers_mut() + .insert(ETAG, HeaderValue::from_str(&uuid.to_string()).unwrap()); + let headers = rep.headers().to_owned(); + let req_headers_match_vary = match headers_match_vary(&req_headers, headers.get(VARY)) { + Ok(h) => h, + Err(_err) => { + // seems backend service response contains a malformed header value for Vary + HeaderMap::new() + } + }; + + let axum_rep = ( + rep.status(), + rep.headers().to_owned(), + rep.bytes().await.unwrap(), + ); + + spawn(enc!((uuid, axum_rep) async move { + // add entry to index cache + index.lock().await.add_entry(uuid, req_method, req_uri, req_headers_match_vary); + // add response to cache + cache.insert(uuid, axum_rep).await; + + })); + axum_rep.into_response() + } + Err(_err) => { + // the request to the backend failed + + StatusCode::INTERNAL_SERVER_ERROR.into_response() + } + } +} + +fn to_backend_uri(config: &Config, uri_request: &Uri) -> Url { + if let Some((endpoint, url)) = config + .endpoints + .iter() + .find(|b| uri_request.to_string().contains(&format!("^{}", b.0))) + { + let new_uri = uri_request.to_string().replace(endpoint, ""); + Url::parse(&format!("{}{}", url, new_uri).replace("//", "/")) + .expect("could not parse to Url") + } else { + // no uri recognized, using fallback backend + config.fall_back_endpoint.to_owned() + } +} + +// if not, +// check cache availability +// if not in cache, +// send request for each route in config to backend service + +// add caching headers +// send response +// save in cache