Skip to content

Commit

Permalink
feat: Cache origin
Browse files Browse the repository at this point in the history
Add comments because i'm definitely going to forget how this works

fix formatting
  • Loading branch information
QuentinMoss committed Jul 2, 2023
1 parent ed67a48 commit db5dc31
Show file tree
Hide file tree
Showing 7 changed files with 78 additions and 8 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@ tower = { version = "0.4", features = ["util"] }
tower-http = { version = "0.4.0", features = ["trace"] }
clap = { version = "4.3.8", features = ["derive"] }
toml = "0.7.5"
parking_lot = "0.12.1"
57 changes: 57 additions & 0 deletions src/cache.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
use crate::db::{list_origins, Origin};
use parking_lot::RwLock;
use sqlx::SqlitePool;
use std::collections::HashMap;
use std::sync::Arc;

use crate::error::AppError;

#[derive(Debug)]
pub struct OriginCache {
origins: Arc<RwLock<HashMap<String, Origin>>>,
pool: Arc<SqlitePool>,
}

impl OriginCache {
pub fn new(pool: Arc<SqlitePool>) -> Self {
OriginCache {
origins: Arc::new(RwLock::new(HashMap::new())),
pool,
}
}

pub async fn refresh(&self) -> Result<(), AppError> {
// Fetch the latest origin data from the database using the provided SqlitePool
let new_origins = list_origins(&self.pool).await?;

// Create a new HashMap to store the updated origin data
let mut map = HashMap::new();

// Iterate over the fetched origins and insert them into the map
for origin in new_origins {
map.insert(origin.domain.clone(), origin);
}

// Update the cache by acquiring a write lock and replacing the HashMap
*self.origins.write() = map;
Ok(())
}

pub async fn get(&self, domain: &str) -> Option<Origin> {
tracing::info!("Get called on cache for domain: {}", domain);
let origins = self.origins.read();

// Look up domain in the cache and clone if found
let result = origins.get(domain).cloned();

// Mostly for development, but also useful if you want to see how often the cache is hit
if result.is_some() {
tracing::info!("Found origin in cache");
} else {
tracing::info!("Origin not found in cache");
}

// Return the result if found, otherwise None
result
}
}
2 changes: 1 addition & 1 deletion src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use sqlx::Executor;

use crate::ingest::HttpRequest;

#[derive(Debug, Deserialize, Serialize, sqlx::FromRow)]
#[derive(Debug, Deserialize, Serialize, sqlx::FromRow, Clone)]
pub struct Origin {
pub id: i64,
pub domain: String,
Expand Down
8 changes: 8 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use std::fmt;

#[derive(Debug)]
pub struct AppError(anyhow::Error);

impl IntoResponse for AppError {
Expand All @@ -14,6 +16,12 @@ impl IntoResponse for AppError {
}
}

impl fmt::Display for AppError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt::Display::fmt(&self.0, f)
}
}

impl<E> From<E> for AppError
where
E: Into<anyhow::Error>,
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pub mod cache;
pub mod db;
pub mod error;
pub mod ingest;
Expand Down
16 changes: 9 additions & 7 deletions src/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@ use hyper::client::HttpConnector;
use hyper::Body;
use hyper::Response;
use sqlx::SqlitePool;
use std::sync::Arc;

use crate::cache::OriginCache;
use crate::db::insert_attempt;
use crate::db::list_origins;
use crate::db::QueuedRequest;

pub type Client = hyper::client::Client<HttpConnector, Body>;
Expand Down Expand Up @@ -72,14 +73,15 @@ async fn map_origin(pool: &SqlitePool, req: &QueuedRequest) -> Result<Option<Uri
};
tracing::debug!("authority = {}", &authority);

let origins = list_origins(pool).await?;
tracing::debug!("origins = {:?}", &origins);
let matching_origin = origins
.iter()
.find(|origin| origin.domain == authority.as_str());
let origins_cache = Arc::new(OriginCache::new(pool.clone().into()));
tracing::debug!("origins = {:?}", &origins_cache);

origins_cache.refresh().await.unwrap();

let matching_origin = origins_cache.get(&authority.as_str()).await;

let origin_uri = match matching_origin {
Some(origin) => &origin.origin_uri,
Some(origin) => origin.origin_uri.clone(),
None => {
tracing::trace!("no match found");
return Ok(None);
Expand Down

0 comments on commit db5dc31

Please sign in to comment.