improve origin cache implementation
- hide away Arc details
- use dependency injection
- create a retry queue implementation instead of returning pool
  directly to main
hjr3 authored and QuentinMoss committed Jul 2, 2023
1 parent 94c6757 commit ac50b5f
Showing 7 changed files with 106 additions and 56 deletions.
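In caller terms, the reworked API composes like this minimal sketch, assembled from the signatures in the diffs below (the variable names and the `origins` vector are illustrative, not part of the commit):

    let origin_cache = OriginCache::new();
    origin_cache.refresh(origins)?;                     // origins are injected, e.g. from db::list_origins
    let hit = origin_cache.get("example.wh.soldr.dev"); // lookup by domain, returns Option<Origin>

    let retry_queue = RetryQueue::new(pool, origin_cache.clone());
    tokio::spawn(async move { retry_queue.start().await }); // owns the 60-second retry loop

`OriginCache::clone` copies only the inner `Arc`, so one cache handle can be shared by the management router, the ingest handler, and the retry queue.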
46 changes: 32 additions & 14 deletions src/cache.rs
@@ -1,29 +1,47 @@
-use crate::db::{list_origins, Origin};
+use crate::db::Origin;
 use parking_lot::RwLock;
-use sqlx::SqlitePool;
 use std::collections::HashMap;
 use std::sync::Arc;
 
 use crate::error::AppError;
 
 #[derive(Debug)]
-pub struct OriginCache {
+pub struct OriginCache(pub(crate) Arc<OriginCacheInner>);
+
+impl OriginCache {
+    pub fn new() -> Self {
+        let inner = OriginCacheInner::new();
+        Self(Arc::new(inner))
+    }
+
+    pub fn refresh(&self, new_origins: Vec<Origin>) -> Result<(), AppError> {
+        self.0.refresh(new_origins)
+    }
+
+    pub fn get(&self, domain: &str) -> Option<Origin> {
+        self.0.get(domain)
+    }
+}
+
+impl Clone for OriginCache {
+    fn clone(&self) -> Self {
+        Self(Arc::clone(&self.0))
+    }
+}
+
+#[derive(Debug)]
+pub struct OriginCacheInner {
     origins: Arc<RwLock<HashMap<String, Origin>>>,
-    pool: Arc<SqlitePool>,
 }
 
-impl OriginCache {
-    pub fn new(pool: Arc<SqlitePool>) -> Self {
-        OriginCache {
+impl OriginCacheInner {
+    pub fn new() -> Self {
+        Self {
            origins: Arc::new(RwLock::new(HashMap::new())),
-            pool,
        }
    }
 
-    pub async fn refresh(&self) -> Result<(), AppError> {
-        // Fetch the latest origin data from the database using the provided SqlitePool
-        let new_origins = list_origins(&self.pool).await?;
-
+    pub fn refresh(&self, new_origins: Vec<Origin>) -> Result<(), AppError> {
         // Create a new HashMap to store the updated origin data
         let mut map = HashMap::new();
 
@@ -37,8 +55,8 @@ impl OriginCache {
         Ok(())
     }
 
-    pub async fn get(&self, domain: &str) -> Option<Origin> {
-        tracing::info!("Get called on cache for domain: {}", domain);
+    pub fn get(&self, domain: &str) -> Option<Origin> {
+        tracing::info!("Got called on cache for domain: {}", domain);
         let origins = self.origins.read();
 
         // Look up domain in the cache and clone if found
20 changes: 13 additions & 7 deletions src/lib.rs
@@ -16,9 +16,11 @@ use axum::http::StatusCode;
 use axum::response::IntoResponse;
 use axum::{routing::post, Router};
 use hyper::HeaderMap;
+use queue::RetryQueue;
 use serde::Deserialize;
 use sqlx::sqlite::SqlitePool;
 
+use crate::cache::OriginCache;
 use crate::db::{ensure_schema, insert_request, mark_complete, mark_error};
 use crate::error::AppError;
 use crate::ingest::HttpRequest;
@@ -44,28 +46,32 @@ impl Default for Config {
     }
 }
 
-pub async fn app(config: &Config) -> Result<(Router, Router, SqlitePool)> {
+pub async fn app(config: Config) -> Result<(Router, Router, RetryQueue)> {
     let pool = SqlitePool::connect(&config.database_url).await?;
-    let pool2 = pool.clone();
 
     ensure_schema(&pool).await?;
 
-    let mgmt_router = mgmt::router(pool.clone());
+    let origin_cache = OriginCache::new();
+
+    let mgmt_router = mgmt::router(pool.clone(), origin_cache.clone());
 
     let client = Client::new();
     let router = Router::new()
         .route("/", post(handler))
         .route("/*path", post(handler))
-        .layer(Extension(pool))
+        .layer(Extension(pool.clone()))
+        .layer(Extension(origin_cache.clone()))
         .with_state(client);
 
-    Ok((router, mgmt_router, pool2))
+    let retry_queue = RetryQueue::new(pool, origin_cache);
+
+    Ok((router, mgmt_router, retry_queue))
 }
 
 #[tracing::instrument(level = "trace", "ingest", skip_all)]
 async fn handler(
     State(client): State<Client>,
     Extension(pool): Extension<SqlitePool>,
+    Extension(origin_cache): Extension<OriginCache>,
     req: Request<Body>,
 ) -> StdResult<impl IntoResponse, AppError> {
     let method = req.method().to_string();
@@ -85,7 +91,7 @@ async fn handler(
     let queued_req = insert_request(&pool, r).await?;
     let req_id = queued_req.id;
 
-    let is_success = proxy(&pool, &client, queued_req).await?;
+    let is_success = proxy(&pool, &origin_cache, &client, queued_req).await?;
 
     if is_success {
         mark_complete(&pool, req_id).await?;
14 changes: 2 additions & 12 deletions src/main.rs
@@ -1,8 +1,5 @@
-use std::time::Duration;
-
 use anyhow::Result;
 use clap::Parser;
-use tokio::time;
 use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
 
 #[derive(Parser, Debug)]
@@ -28,7 +25,7 @@ async fn main() -> Result<()> {
         None => soldr::Config::default(),
     };
 
-    let (ingest, mgmt, pool) = soldr::app(&config).await?;
+    let (ingest, mgmt, retry_queue) = soldr::app(config).await?;
 
     let mgmt_listener = config.management_listener.parse()?;
     let ingest_listener = config.ingest_listener.parse()?;
@@ -45,14 +42,7 @@ async fn main() -> Result<()> {
 
     tokio::spawn(async move {
         tracing::info!("starting retry queue");
-        let mut interval = time::interval(Duration::from_secs(60));
-
-        loop {
-            interval.tick().await;
-            let pool2 = pool.clone();
-            tracing::trace!("retrying failed requests");
-            soldr::queue::tick(pool2).await;
-        }
+        retry_queue.start().await;
     });
 
     tracing::info!("ingest listening on {}", ingest_listener);
8 changes: 7 additions & 1 deletion src/mgmt.rs
@@ -9,15 +9,17 @@ use serde::{Deserialize, Serialize};
 use sqlx::sqlite::SqlitePool;
 use tracing::Level;
 
+use crate::cache::OriginCache;
 use crate::db;
 use crate::error::AppError;
 
-pub fn router(pool: SqlitePool) -> Router {
+pub fn router(pool: SqlitePool, origin_cache: OriginCache) -> Router {
     Router::new()
         .route("/origins", post(create_origin))
         .route("/requests", get(list_requests))
         .route("/attempts", get(list_attempts))
         .layer(Extension(pool))
+        .layer(Extension(origin_cache))
 }
 
 async fn list_requests(
@@ -52,6 +54,7 @@ pub struct CreateOrigin {
 
 async fn create_origin(
     Extension(pool): Extension<SqlitePool>,
+    Extension(origin_cache): Extension<OriginCache>,
     Json(payload): Json<CreateOrigin>,
 ) -> StdResult<Json<db::Origin>, AppError> {
     let span = tracing::span!(Level::TRACE, "create_origin");
@@ -61,5 +64,8 @@ async fn create_origin(
     let origin = db::insert_origin(&pool, &payload.domain, &payload.origin_uri).await?;
     tracing::debug!("response = {:?}", &origin);
 
+    let origins = db::list_origins(&pool).await?;
+    origin_cache.refresh(origins).unwrap();
+
     Ok(Json(origin))
 }
21 changes: 10 additions & 11 deletions src/proxy.rs
@@ -5,16 +5,20 @@ use hyper::client::HttpConnector;
 use hyper::Body;
 use hyper::Response;
 use sqlx::SqlitePool;
-use std::sync::Arc;
 
 use crate::cache::OriginCache;
 use crate::db::insert_attempt;
 use crate::db::QueuedRequest;
 
 pub type Client = hyper::client::Client<HttpConnector, Body>;
 
-pub async fn proxy(pool: &SqlitePool, client: &Client, mut req: QueuedRequest) -> Result<bool> {
-    let uri = map_origin(pool, &req).await?;
+pub async fn proxy(
+    pool: &SqlitePool,
+    origin_cache: &OriginCache,
+    client: &Client,
+    mut req: QueuedRequest,
+) -> Result<bool> {
+    let uri = map_origin(origin_cache, &req).await?;
 
     if uri.is_none() {
         // no origin found, so mark as complete and move on
@@ -46,7 +50,7 @@ pub async fn proxy(pool: &SqlitePool, client: &Client, mut req: QueuedRequest) -> Result<bool> {
     Ok(is_success)
 }
 
-async fn map_origin(pool: &SqlitePool, req: &QueuedRequest) -> Result<Option<Uri>> {
+async fn map_origin(origin_cache: &OriginCache, req: &QueuedRequest) -> Result<Option<Uri>> {
     let uri = Uri::try_from(&req.uri)?;
     let parts = uri.into_parts();
 
@@ -73,15 +77,10 @@ async fn map_origin(pool: &SqlitePool, req: &QueuedRequest) -> Result<Option<Uri>> {
     };
     tracing::debug!("authority = {}", &authority);
 
-    let origins_cache = Arc::new(OriginCache::new(pool.clone().into()));
-    tracing::debug!("origins = {:?}", &origins_cache);
-
-    origins_cache.refresh().await.unwrap();
-
-    let matching_origin = origins_cache.get(&authority.as_str()).await;
+    let matching_origin = origin_cache.get(authority.as_str());
 
     let origin_uri = match matching_origin {
-        Some(origin) => origin.origin_uri.clone(),
+        Some(origin) => origin.origin_uri,
         None => {
             tracing::trace!("no match found");
             return Ok(None);
49 changes: 40 additions & 9 deletions src/queue.rs
@@ -1,25 +1,52 @@
+use std::time::Duration;
+
 use anyhow::Result;
 use sqlx::sqlite::SqlitePool;
+use tokio::time;
 
+use crate::cache::OriginCache;
+
 use crate::{
     db::{list_failed_requests, mark_complete, mark_error, QueuedRequest},
     proxy::{self, Client},
 };
 
-pub async fn tick(pool: SqlitePool) {
-    if let Err(err) = do_tick(pool).await {
-        // TODO flow through the request id
-        tracing::error!("tick error {:?}", err);
+pub struct RetryQueue {
+    pool: SqlitePool,
+    origin_cache: OriginCache,
+}
+
+impl RetryQueue {
+    pub fn new(pool: SqlitePool, origin_cache: OriginCache) -> Self {
+        Self { pool, origin_cache }
+    }
+
+    pub async fn start(&self) {
+        let mut interval = time::interval(Duration::from_secs(60));
+
+        loop {
+            interval.tick().await;
+            tracing::trace!("retrying failed requests");
+            self.tick().await;
+        }
+    }
+
+    pub async fn tick(&self) {
+        if let Err(err) = do_tick(&self.pool, &self.origin_cache).await {
+            // TODO flow through the request id
+            tracing::error!("tick error {:?}", err);
+        }
     }
 }
 
-async fn do_tick(pool: SqlitePool) -> Result<()> {
-    let requests = list_failed_requests(&pool).await?;
+async fn do_tick(pool: &SqlitePool, origin_cache: &OriginCache) -> Result<()> {
+    let requests = list_failed_requests(pool).await?;
 
     let mut tasks = Vec::with_capacity(requests.len());
     for request in requests {
         let pool2 = pool.clone();
-        tasks.push(tokio::spawn(retry_request(pool2, request)));
+        let origin_cache2 = origin_cache.clone();
+        tasks.push(tokio::spawn(retry_request(pool2, origin_cache2, request)));
     }
 
     for task in tasks {
@@ -32,12 +59,16 @@ async fn do_tick(pool: SqlitePool) -> Result<()> {
     Ok(())
 }
 
-async fn retry_request(pool: SqlitePool, request: QueuedRequest) -> Result<()> {
+async fn retry_request(
+    pool: SqlitePool,
+    origin_cache: OriginCache,
+    request: QueuedRequest,
+) -> Result<()> {
     tracing::trace!("retrying {:?}", &request);
 
     let req_id = request.id;
     let client = Client::new();
-    let is_success = proxy::proxy(&pool, &client, request).await?;
+    let is_success = proxy::proxy(&pool, &origin_cache, &client, request).await?;
 
     if is_success {
         mark_complete(&pool, req_id).await?;
4 changes: 2 additions & 2 deletions tests/queue.rs
@@ -36,7 +36,7 @@ async fn queue_retry_request() {
         .unwrap();
     });
 
-    let (ingest, mgmt, pool) = app(&common::config()).await.unwrap();
+    let (ingest, mgmt, retry_queue) = app(common::config()).await.unwrap();
 
     // create an origin mapping
     let domain = "example.wh.soldr.dev";
@@ -120,7 +120,7 @@ async fn queue_retry_request() {
     assert_eq!(attempts[0].response_status, 500);
    assert_eq!(attempts[0].response_body, b"unexpected error");
 
-    soldr::queue::tick(pool).await;
+    retry_queue.tick().await;
 
     // use management API to verify an attempt was made
     let response = mgmt
