Skip to content

Commit

Permalink
simplify FileCache by introducing the Cachable Trait
Browse files Browse the repository at this point in the history
this massively reduces the type noise.
the only drawback is that implementations on foreign types require
a wrapper now.
still worth it + the types are actually nameable now.
  • Loading branch information
djugei committed Oct 11, 2024
1 parent 39ca8d1 commit 07de06f
Show file tree
Hide file tree
Showing 3 changed files with 199 additions and 184 deletions.
87 changes: 45 additions & 42 deletions async_file_cache/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,23 @@ use tracing::{debug, trace};

use tokio::io::AsyncSeekExt;
use tokio::{fs::File, sync::Mutex, sync::Semaphore};

/// Describes a computation whose result can be cached in a file by `FileCache`.
///
/// The implementing type is the state handed to `FileCache::new`; it is available
/// to `gen_value` through `&self`, while `key_to_path` gets no access to it.
pub trait Cacheable {
// todo: get rid of the Clone bound, it seems to not be strictly necessary, the entry-api is just a bit unwieldy that way
/// Identifies one cached value; also used as the key of the cache's in-flight map.
type Key: Eq + Hash + Clone;
/// Error produced when generating a value fails.
type Error;

/// Turns the Key into a path where the computation is cached.
fn key_to_path(k: &Self::Key) -> PathBuf;

/// Does the expensive operation for `k`.
///
/// Should not block, either on io or on compute:
/// use async_runtime::spawn for calculation
/// and async io for io.
///
/// Is expected to serialize the computed value to the provided file `f`
/// and to hand a file back on success.
/// Should not panic, but cleanup code is in place in case it does.
fn gen_value(&self, k: &Self::Key, f: File) -> impl Future<Output = Result<File, Self::Error>>;
}
/**
File-Backed Cache:
Execute an expensive operation to generate a Resource and store the results in a file on disk.
Expand All @@ -15,55 +32,31 @@ use tokio::{fs::File, sync::Mutex, sync::Semaphore};
It is possible to set a max parallelism level.
todo: streaming, currently the whole thing is generated at once
*/
pub struct FileCache<Key, S, KF, F, FF, E>
pub struct FileCache<State>
where
Key: Hash + PartialEq + Eq + Clone + Send,
S: Clone + Send,
KF: Send + Fn(&S, &Key) -> PathBuf,
F: Send + Fn(S, Key, File) -> FF,
FF: Future<Output = Result<File, E>> + Send,
E: Send,
State: Cacheable,
{
in_flight: tokio::sync::Mutex<HashMap<Key, tokio::sync::watch::Receiver<()>>>,
state: S,
in_flight: tokio::sync::Mutex<HashMap<State::Key, tokio::sync::watch::Receiver<()>>>,
state: State,
max_para: Option<Semaphore>,
kf: KF,
f: F,
}

impl<Key, S, KF, F, FF, E> FileCache<Key, S, KF, F, FF, E>
where
Key: Hash + PartialEq + Eq + Clone + Send,
S: Clone + Send,
KF: Send + Fn(&S, &Key) -> PathBuf,
F: Send + Fn(S, Key, File) -> FF,
// fixme: why does this need to return a file? why not just a result
FF: Future<Output = Result<File, E>> + Send,
E: Send,
{
impl<State: Cacheable> FileCache<State> {
/**
* # Parameters:
* init_state: the Cacheable value that supplies both halves of the cache logic:
* key_to_path turns the Key into a path where the result is cached,
* gen_value returns the future that does the expensive operation.
* The future should not be blocking either by long calculations or io operations
* use async_runtime::spawn for calculations and async io for io
* gen_value needs to take care of serializing the value to the file on its own
* gen_value should not panic, but measures are taken to not leave inconsistent state
* max_para: maximum number of expensive operations in flight at the same time
*/
pub fn new(init_state: S, kf: KF, f: F, max_para: Option<usize>) -> Self {
pub fn new(init_state: State, max_para: Option<usize>) -> Self {
let max_para = max_para.map(Semaphore::new);
let in_flight = Mutex::new(HashMap::new());
Self {
state: init_state,
in_flight,
max_para,
kf,
f,
}
}

pub async fn get_or_generate(&self, key: Key) -> std::io::Result<Result<File, E>> {
pub async fn get_or_generate(&self, key: State::Key) -> std::io::Result<Result<File, State::Error>> {
let mut oo = tokio::fs::OpenOptions::new();
let read = oo.read(true).write(false).create(false);
let mut oo = tokio::fs::OpenOptions::new();
Expand All @@ -76,7 +69,7 @@ where
Entry::Occupied(mut entry) => {
let mut e = entry.get_mut().clone();
drop(in_flight);
let path = (self.kf)(&self.state, &key);
let path = State::key_to_path(&key);
debug!("waiting {:?}", path);
match e.changed().await {
Ok(()) => {
Expand Down Expand Up @@ -111,7 +104,7 @@ where
}
}
Entry::Vacant(entry) => {
let path = (self.kf)(&self.state, &key);
let path = State::key_to_path(&key);
match read.open(&path).await {
Ok(f) => {
debug!("exists {:?}", path);
Expand Down Expand Up @@ -142,7 +135,7 @@ where
};
// do the expensive operation

let f = (self.f)(self.state.clone(), key.clone(), w);
let f = self.state.gen_value(&key, w);
let f = std::pin::pin!(f);
let f = f.await;
let _f = match f {
Expand Down Expand Up @@ -178,18 +171,28 @@ where

#[test]
fn cache_simple() {
use futures_util::future::join_all;

tracing_subscriber::fmt().with_max_level(tracing::Level::DEBUG).init();

use futures_util::future::join_all;
let kf = |_: &(), s: &String| PathBuf::from(s);
async fn inner_f(_s: (), key: String, mut file: File) -> Result<File, std::io::Error> {
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
use tokio::io::AsyncWriteExt;
file.write_all(key.as_bytes()).await?;
Ok(file)
struct TestState;
impl Cacheable for TestState {
type Key = String;
type Error = std::io::Error;

fn key_to_path(k: &Self::Key) -> PathBuf {
PathBuf::from(k)
}

async fn gen_value(&self, key: &Self::Key, mut file: File) -> Result<File, Self::Error> {
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
use tokio::io::AsyncWriteExt;
file.write_all(key.as_bytes()).await?;
Ok(file)
}
}

let c = FileCache::new((), kf, inner_f, None);
let c = FileCache::new(TestState, None);

let tmpdir = tempfile::TempDir::new().unwrap();
let basepath = tmpdir.path();
Expand Down
129 changes: 129 additions & 0 deletions server/src/caches.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
use std::path::PathBuf;

use async_file_cache::{Cacheable, FileCache};
use parsing::{Delta, Package};
use reqwest::Client;
use thiserror::Error;
use tokio::fs::File;
use tracing::{debug, info, Instrument};

use crate::{FALLBACK_MIRROR, MIRROR};

/// Errors that can occur while downloading a package into the cache.
#[derive(Error, Debug)]
pub enum DownloadError {
/// Writing the downloaded data to the cache file failed.
#[error("could not write to file: {0}")]
Io(#[from] std::io::Error),
/// The HTTP request itself failed (sending the request or reading the body).
#[error("http request failed: {0}")]
Connection(#[from] reqwest::Error),
/// The server answered, but with a non-success status code.
#[error("bad status code: {status} while fetching {url}")]
Status {
/// URL of the response that carried the bad status.
url: reqwest::Url,
/// The status code that was received.
status: reqwest::StatusCode,
},
}

/// Cache state for downloaded packages; wraps the HTTP client used to fetch them.
pub struct PackageCache(pub Client);
/// Fills the package cache by downloading packages from the configured mirrors.
impl Cacheable for PackageCache {
    type Key = Package;
    type Error = DownloadError;

    /// Maps a package to a file inside the package directory (`crate::get_pkg_path()`),
    /// using the package's string form as the file name.
    fn key_to_path(p: &Self::Key) -> PathBuf {
        let mut path = crate::get_pkg_path();
        path.push(p.to_string());
        path
    }

    /// Downloads the package `key` and streams the response body into `file`.
    ///
    /// The primary mirror is tried first; only a `404 Not Found` answer triggers a
    /// second attempt against the fallback mirror.
    ///
    /// # Errors
    /// * [`DownloadError::Connection`] if a request or reading the body fails.
    /// * [`DownloadError::Status`] if the final response has a non-success status.
    /// * [`DownloadError::Io`] if writing to or flushing `file` fails.
    ///
    /// # Panics
    /// Panics if `MIRROR` (or `FALLBACK_MIRROR`, when it is needed) has not been
    /// initialized.
    #[tracing::instrument(level = "info", skip(self, file), "Downloading")]
    async fn gen_value(&self, key: &Self::Key, mut file: File) -> Result<File, Self::Error> {
        use tokio::io::AsyncWriteExt;

        let mirror = MIRROR.get().expect("initialized");
        let uri = format!("{mirror}{key}");
        debug!(key = key.to_string(), uri, "getting from primary");
        let mut response = self.0.get(uri).send().await?;

        if response.status() == reqwest::StatusCode::NOT_FOUND {
            // fall back to archive mirror
            let fallback_mirror = FALLBACK_MIRROR.get().expect("initialized");
            let uri = format!("{fallback_mirror}{key}");
            info!(key = key.to_string(), "using fallback mirror");
            response = self.0.get(uri).send().await?;
        }
        if !response.status().is_success() {
            return Err(DownloadError::Status {
                status: response.status(),
                url: response.url().clone(),
            });
        }

        // Stream the body chunk by chunk so the whole package is never held in memory.
        while let Some(mut chunk) = response.chunk().await? {
            file.write_all_buf(&mut chunk).await?;
        }
        // tokio's `File` performs writes on a background thread, so the last write
        // may still be in flight once `write_all_buf` returns. Flushing waits for it
        // and surfaces a late write error here instead of losing it.
        file.flush().await?;
        Ok(file)
    }
}

/// Errors that can occur while generating a delta between two packages.
#[derive(Error, Debug)]
pub enum DeltaError {
/// Fetching one of the two packages failed.
#[error("could not download file: {0}")]
Download(#[from] DownloadError),
/// An io operation failed.
#[error("io error: {0}")]
Io(#[from] std::io::Error),
/// The delta generator itself reported an error.
#[error("generation error: {0}")]
DeltaGen(#[from] ddelta::DiffError),
}

/// Cache state for generated deltas; wraps the package cache the deltas are computed from.
pub struct DeltaCache(pub FileCache<PackageCache>);

/// Fills the delta cache by diffing two packages obtained from the wrapped package cache.
impl Cacheable for DeltaCache {
    type Key = Delta;
    type Error = DeltaError;

    /// Maps a delta to a file inside the delta directory (`crate::get_delta_path()`),
    /// using the delta's string form as the file name.
    fn key_to_path(d: &Self::Key) -> PathBuf {
        let mut p = crate::get_delta_path();
        p.push(d.to_string());
        p
    }

    /// Generates the delta `key` and writes it, zstd-compressed, into `patch`.
    ///
    /// Both packages of the delta are requested from the package cache concurrently.
    /// The diffing itself runs on tokio's blocking pool.
    ///
    /// # Errors
    /// * [`DeltaError::Io`] if the package cache reports an io error, or if setting up
    ///   or finishing the zstd streams fails.
    /// * [`DeltaError::Download`] if one of the packages could not be downloaded.
    /// * [`DeltaError::DeltaGen`] if `ddelta` fails to generate the patch.
    ///
    /// # Panics
    /// Panics with "threading error" if the blocking task cannot be joined
    /// (for example because it panicked).
    async fn gen_value(&self, key: &Self::Key, patch: File) -> Result<File, Self::Error> {
        let keystring = key.to_string();
        let (old, new) = key.clone().get_both();
        let old = self.0.get_or_generate(old);
        let new = self.0.get_or_generate(new);
        // Fetch (or wait for) both packages at the same time.
        let (old, new) = tokio::join!(old, new);
        // Outer layer: io error of the cache itself; inner layer: download error.
        let (old, new) = (old??, new??);

        // The diffing below is synchronous, so switch everything over to std files.
        let patch = patch.into_std().await;
        let old = old.into_std().await;
        let mut old = zstd::Decoder::new(old)?;
        let new = new.into_std().await;
        let mut new = zstd::Decoder::new(new)?;
        let span = tracing::info_span!("delta request", key = keystring);
        // Instrumenting the JoinHandle only covers polling the handle. The closure runs
        // on a blocking-pool thread, so it needs its own handle to enter the span there.
        let worker_span = span.clone();

        let f: tokio::task::JoinHandle<Result<_, DeltaError>> = tokio::task::spawn_blocking(move || {
            // Keep the span entered for the whole blocking computation so that the
            // progress events below are recorded inside it.
            let _entered = worker_span.enter();

            let mut zpatch = zstd::Encoder::new(patch, 22)?;
            // Multithreaded compression is best effort: carry on single-threaded if the
            // parameter is rejected.
            let e = zpatch.set_parameter(zstd::zstd_safe::CParameter::NbWorkers(4));
            if let Err(e) = e {
                debug!("failed to make zstd multithread: {e:?}");
            }
            // Progress value at the time of the last "working" log line.
            let mut last_report = 0;
            ddelta::generate_chunked(&mut old, &mut new, &mut zpatch, None, |s| match s {
                ddelta::State::Reading => debug!(key = keystring, "reading"),
                ddelta::State::Sorting => debug!(key = keystring, "sorting"),
                ddelta::State::Working(p) => {
                    const MB: u64 = 1024 * 1024;
                    // Only log once progress has advanced by more than 8 MB.
                    if p > last_report + (8 * MB) {
                        debug!(key = keystring, "working: {}MB done", p / MB);
                        last_report = p;
                    }
                }
            })?;
            // Finishing the encoder hands back the underlying file.
            Ok(zpatch.finish()?)
        });
        let f = f.instrument(span).await.expect("threading error")?;

        let f = File::from_std(f);

        Ok(f)
    }
}
Loading

0 comments on commit 07de06f

Please sign in to comment.