diff --git a/src/cache.rs b/src/cache.rs index 0e058bf..93ec7d2 100644 --- a/src/cache.rs +++ b/src/cache.rs @@ -1,29 +1,47 @@ -use crate::db::{list_origins, Origin}; +use crate::db::Origin; use parking_lot::RwLock; -use sqlx::SqlitePool; use std::collections::HashMap; use std::sync::Arc; use crate::error::AppError; #[derive(Debug)] -pub struct OriginCache { +pub struct OriginCache(pub(crate) Arc); + +impl OriginCache { + pub fn new() -> Self { + let inner = OriginCacheInner::new(); + Self(Arc::new(inner)) + } + + pub fn refresh(&self, new_origins: Vec) -> Result<(), AppError> { + self.0.refresh(new_origins) + } + + pub fn get(&self, domain: &str) -> Option { + self.0.get(domain) + } +} + +impl Clone for OriginCache { + fn clone(&self) -> Self { + Self(Arc::clone(&self.0)) + } +} + +#[derive(Debug)] +pub struct OriginCacheInner { origins: Arc>>, - pool: Arc, } -impl OriginCache { - pub fn new(pool: Arc) -> Self { - OriginCache { +impl OriginCacheInner { + pub fn new() -> Self { + Self { origins: Arc::new(RwLock::new(HashMap::new())), - pool, } } - pub async fn refresh(&self) -> Result<(), AppError> { - // Fetch the latest origin data from the database using the provided SqlitePool - let new_origins = list_origins(&self.pool).await?; - + pub fn refresh(&self, new_origins: Vec) -> Result<(), AppError> { // Create a new HashMap to store the updated origin data let mut map = HashMap::new(); @@ -37,8 +55,8 @@ impl OriginCache { Ok(()) } - pub async fn get(&self, domain: &str) -> Option { - tracing::info!("Get called on cache for domain: {}", domain); + pub fn get(&self, domain: &str) -> Option { + tracing::info!("Got called on cache for domain: {}", domain); let origins = self.origins.read(); // Look up domain in the cache and clone if found diff --git a/src/lib.rs b/src/lib.rs index d4ef4cf..d351e92 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -16,9 +16,11 @@ use axum::http::StatusCode; use axum::response::IntoResponse; use axum::{routing::post, Router}; use hyper::HeaderMap; +use queue::RetryQueue; use serde::Deserialize; use sqlx::sqlite::SqlitePool; +use crate::cache::OriginCache; use crate::db::{ensure_schema, insert_request, mark_complete, mark_error}; use crate::error::AppError; use crate::ingest::HttpRequest; @@ -44,28 +46,32 @@ impl Default for Config { } } -pub async fn app(config: &Config) -> Result<(Router, Router, SqlitePool)> { +pub async fn app(config: Config) -> Result<(Router, Router, RetryQueue)> { let pool = SqlitePool::connect(&config.database_url).await?; - let pool2 = pool.clone(); - ensure_schema(&pool).await?; - let mgmt_router = mgmt::router(pool.clone()); + let origin_cache = OriginCache::new(); + + let mgmt_router = mgmt::router(pool.clone(), origin_cache.clone()); let client = Client::new(); let router = Router::new() .route("/", post(handler)) .route("/*path", post(handler)) - .layer(Extension(pool)) + .layer(Extension(pool.clone())) + .layer(Extension(origin_cache.clone())) .with_state(client); - Ok((router, mgmt_router, pool2)) + let retry_queue = RetryQueue::new(pool, origin_cache); + + Ok((router, mgmt_router, retry_queue)) } #[tracing::instrument(level = "trace", "ingest", skip_all)] async fn handler( State(client): State, Extension(pool): Extension, + Extension(origin_cache): Extension, req: Request, ) -> StdResult { let method = req.method().to_string(); @@ -85,7 +91,7 @@ async fn handler( let queued_req = insert_request(&pool, r).await?; let req_id = queued_req.id; - let is_success = proxy(&pool, &client, queued_req).await?; + let is_success = proxy(&pool, &origin_cache, &client, queued_req).await?; if is_success { mark_complete(&pool, req_id).await?; diff --git a/src/main.rs b/src/main.rs index 7a46136..b9cb66d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,8 +1,5 @@ -use std::time::Duration; - use anyhow::Result; use clap::Parser; -use tokio::time; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; #[derive(Parser, Debug)] @@ -28,7 +25,7 @@ async fn main() -> Result<()> { None => soldr::Config::default(), }; - let (ingest, mgmt, pool) = soldr::app(&config).await?; + let (ingest, mgmt, retry_queue) = soldr::app(config).await?; let mgmt_listener = config.management_listener.parse()?; let ingest_listener = config.ingest_listener.parse()?; @@ -45,14 +42,7 @@ async fn main() -> Result<()> { tokio::spawn(async move { tracing::info!("starting retry queue"); - let mut interval = time::interval(Duration::from_secs(60)); - - loop { - interval.tick().await; - let pool2 = pool.clone(); - tracing::trace!("retrying failed requests"); - soldr::queue::tick(pool2).await; - } + retry_queue.start().await; }); tracing::info!("ingest listening on {}", ingest_listener); diff --git a/src/mgmt.rs b/src/mgmt.rs index bb5448b..23d0174 100644 --- a/src/mgmt.rs +++ b/src/mgmt.rs @@ -9,15 +9,17 @@ use serde::{Deserialize, Serialize}; use sqlx::sqlite::SqlitePool; use tracing::Level; +use crate::cache::OriginCache; use crate::db; use crate::error::AppError; -pub fn router(pool: SqlitePool) -> Router { +pub fn router(pool: SqlitePool, origin_cache: OriginCache) -> Router { Router::new() .route("/origins", post(create_origin)) .route("/requests", get(list_requests)) .route("/attempts", get(list_attempts)) .layer(Extension(pool)) + .layer(Extension(origin_cache)) } async fn list_requests( @@ -52,6 +54,7 @@ pub struct CreateOrigin { async fn create_origin( Extension(pool): Extension, + Extension(origin_cache): Extension, Json(payload): Json, ) -> StdResult, AppError> { let span = tracing::span!(Level::TRACE, "create_origin"); @@ -61,5 +64,8 @@ async fn create_origin( let origin = db::insert_origin(&pool, &payload.domain, &payload.origin_uri).await?; tracing::debug!("response = {:?}", &origin); + let origins = db::list_origins(&pool).await?; + origin_cache.refresh(origins).unwrap(); + Ok(Json(origin)) } diff --git a/src/proxy.rs b/src/proxy.rs index 8084257..e34926e 100644 --- a/src/proxy.rs +++ b/src/proxy.rs @@ -5,7 +5,6 @@ use hyper::client::HttpConnector; use hyper::Body; use hyper::Response; use sqlx::SqlitePool; -use std::sync::Arc; use crate::cache::OriginCache; use crate::db::insert_attempt; @@ -13,8 +12,13 @@ use crate::db::QueuedRequest; pub type Client = hyper::client::Client; -pub async fn proxy(pool: &SqlitePool, client: &Client, mut req: QueuedRequest) -> Result { - let uri = map_origin(pool, &req).await?; +pub async fn proxy( + pool: &SqlitePool, + origin_cache: &OriginCache, + client: &Client, + mut req: QueuedRequest, +) -> Result { + let uri = map_origin(origin_cache, &req).await?; if uri.is_none() { // no origin found, so mark as complete and move on @@ -46,7 +50,7 @@ pub async fn proxy(pool: &SqlitePool, client: &Client, mut req: QueuedRequest) - Ok(is_success) } -async fn map_origin(pool: &SqlitePool, req: &QueuedRequest) -> Result> { +async fn map_origin(origin_cache: &OriginCache, req: &QueuedRequest) -> Result> { let uri = Uri::try_from(&req.uri)?; let parts = uri.into_parts(); @@ -73,15 +77,10 @@ async fn map_origin(pool: &SqlitePool, req: &QueuedRequest) -> Result origin.origin_uri.clone(), + Some(origin) => origin.origin_uri, None => { tracing::trace!("no match found"); return Ok(None); diff --git a/src/queue.rs b/src/queue.rs index c5efa64..3cf840d 100644 --- a/src/queue.rs +++ b/src/queue.rs @@ -1,25 +1,52 @@ +use std::time::Duration; + use anyhow::Result; use sqlx::sqlite::SqlitePool; +use tokio::time; + +use crate::cache::OriginCache; use crate::{ db::{list_failed_requests, mark_complete, mark_error, QueuedRequest}, proxy::{self, Client}, }; -pub async fn tick(pool: SqlitePool) { - if let Err(err) = do_tick(pool).await { - // TODO flow through the request id - tracing::error!("tick error {:?}", err); +pub struct RetryQueue { + pool: SqlitePool, + origin_cache: OriginCache, +} + +impl RetryQueue { + pub fn new(pool: SqlitePool, origin_cache: OriginCache) -> Self { + Self { pool, origin_cache } + } + + pub async fn start(&self) { + let mut interval = time::interval(Duration::from_secs(60)); + + loop { + interval.tick().await; + tracing::trace!("retrying failed requests"); + self.tick().await; + } + } + + pub async fn tick(&self) { + if let Err(err) = do_tick(&self.pool, &self.origin_cache).await { + // TODO flow through the request id + tracing::error!("tick error {:?}", err); + } } } -async fn do_tick(pool: SqlitePool) -> Result<()> { - let requests = list_failed_requests(&pool).await?; +async fn do_tick(pool: &SqlitePool, origin_cache: &OriginCache) -> Result<()> { + let requests = list_failed_requests(pool).await?; let mut tasks = Vec::with_capacity(requests.len()); for request in requests { let pool2 = pool.clone(); - tasks.push(tokio::spawn(retry_request(pool2, request))); + let origin_cache2 = origin_cache.clone(); + tasks.push(tokio::spawn(retry_request(pool2, origin_cache2, request))); } for task in tasks { @@ -32,12 +59,16 @@ async fn do_tick(pool: SqlitePool) -> Result<()> { Ok(()) } -async fn retry_request(pool: SqlitePool, request: QueuedRequest) -> Result<()> { +async fn retry_request( + pool: SqlitePool, + origin_cache: OriginCache, + request: QueuedRequest, +) -> Result<()> { tracing::trace!("retrying {:?}", &request); let req_id = request.id; let client = Client::new(); - let is_success = proxy::proxy(&pool, &client, request).await?; + let is_success = proxy::proxy(&pool, &origin_cache, &client, request).await?; if is_success { mark_complete(&pool, req_id).await?; diff --git a/tests/queue.rs b/tests/queue.rs index 7d2dc28..a51b3ae 100644 --- a/tests/queue.rs +++ b/tests/queue.rs @@ -36,7 +36,7 @@ async fn queue_retry_request() { .unwrap(); }); - let (ingest, mgmt, pool) = app(&common::config()).await.unwrap(); + let (ingest, mgmt, retry_queue) = app(common::config()).await.unwrap(); // create an origin mapping let domain = "example.wh.soldr.dev"; @@ -120,7 +120,7 @@ async fn queue_retry_request() { assert_eq!(attempts[0].response_status, 500); assert_eq!(attempts[0].response_body, b"unexpected error"); - soldr::queue::tick(pool).await; + retry_queue.tick().await; // use management API to verify an attempt was made let response = mgmt