diff --git a/Cargo.toml b/Cargo.toml index 6639818..9887ff5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "spot" -version = "0.1.0" +version = "0.2.0" edition = "2021" license = "Apache-2.0" documentation = "https://github.com/samply/spot" @@ -20,7 +20,7 @@ once_cell = "1" # Logging tracing = "0.1.37" tracing-subscriber = { version = "0.3.17", features = ["env-filter"] } -reqwest = { version = "0.11.20", default-features = false, features = ["stream"] } +reqwest = { version = "0.11.20", default-features = false, features = ["stream", "default-tls"] } tower-http = { version = "0.4.4", features = ["cors"] } [build-dependencies] diff --git a/src/banner.rs b/src/banner.rs index 3068a92..da63103 100644 --- a/src/banner.rs +++ b/src/banner.rs @@ -10,7 +10,7 @@ pub(crate) fn print_banner() { _ => "SNAPSHOT", }; info!( - "🌈 Samply.Spot v{} (built {} {}, {}) starting up ...", + "🌈 Samply.Spot v{} (built {} {}, {}) ready to take requests.", env!("CARGO_PKG_VERSION"), env!("BUILD_DATE"), env!("BUILD_TIME"), diff --git a/src/catalogue.rs b/src/catalogue.rs new file mode 100644 index 0000000..67ee22d --- /dev/null +++ b/src/catalogue.rs @@ -0,0 +1,139 @@ +use std::{collections::BTreeMap, time::Duration}; + +use reqwest::Url; +use serde::{Deserialize, Serialize}; +use serde_json::{json, Value}; +use tracing::{debug, info}; + +pub type Criteria = BTreeMap; + +pub type CriteriaGroup = BTreeMap; + +pub type CriteriaGroups = BTreeMap; + +fn get_element<'a>( + count: &'a CriteriaGroups, + key1: &'a str, + key2: &'a str, + key3: &'a str, +) -> Option<&'a u64> { + count + .get(key1) + .and_then(|group| group.get(key2)) + .and_then(|criteria| criteria.get(key3)) +} + +pub async fn get_extended_json(catalogue_url: Url, prism_url: Url) -> Value { + debug!("Fetching catalogue from {catalogue_url} ..."); + + let resp = reqwest::Client::new() + .get(catalogue_url) + .timeout(Duration::from_secs(30)) + .send() + .await + .expect("Unable to fetch catalogue from upstream; please check URL specified in config."); + + let mut json: Value = resp + .json() + .await + .expect("Unable to parse catalogue from upstream; please check URL specified in config."); + + let prism_resp = reqwest::Client::new() + .post(format!("{}criteria", prism_url)) + .header("Content-Type", "application/json") + .body("{\"sites\": []}") + .timeout(Duration::from_secs(300)) + .send() + .await + .expect("Unable to fetch response from Prism; please check it's running."); + + let mut counts: CriteriaGroups = prism_resp + .json() + .await + .expect("Unable to parse response from Prism into CriteriaGroups"); + + recurse(&mut json, &mut counts); //TODO remove from counts once copied into catalogue to make it O(n log n) + + info!("Catalogue built successfully."); + + json +} + +/// Key order: group key (e.g. patient) +/// \-- stratifier key (e.g. admin_gender) +/// \-- stratum key (e.g. male, other) +fn recurse(json: &mut Value, counts: &mut CriteriaGroups) { + match json { + Value::Array(arr) => { + for ele in arr { + recurse(ele, counts); + } + } + Value::Object(obj) => { + if !obj.contains_key("childCategories") { + for (_key, child_val) in obj.iter_mut() { + recurse(child_val, counts); + } + } else { + let group_key = obj.get("key").expect("Got JSON element with childCategories but without (group) key. Please check json.").as_str() + .expect("Got JSON where a criterion key was not a string. Please check json.").to_owned(); + + //TODO consolidate catalogue and MeasureReport group names + let group_key = if group_key == "patient" { + "patients" + } else if group_key == "tumor_classification" { + "diagnosis" + } else if group_key == "biosamples" { + "specimen" + } else { + &group_key + }; + + let children_cats = obj + .get_mut("childCategories") + .unwrap() + .as_array_mut() + .unwrap() + .iter_mut() + .filter(|item| item.get("type").unwrap_or(&Value::Null) == "EQUALS"); + + for child_cat in children_cats { + let stratifier_key = child_cat.get("key").expect("Got JSON element with childCategory that does not contain a (stratifier) key. Please check json.").as_str() + .expect("Got JSON where a criterion key was not a string. Please check json.").to_owned(); + + let criteria = child_cat + .get_mut("criteria") + .expect("Got JSON element with childCategory that does not contain a criteria array. Please check json.") + .as_array_mut() + .expect("Got JSON element with childCategory with criteria that are not an array. Please check json."); + + for criterion in criteria { + let criterion = criterion.as_object_mut().expect( + "Got JSON where a criterion was not an object. Please check json.", + ); + let stratum_key = criterion.get("key") + .expect("Got JSON where a criterion did not have a key. Please check json.") + .as_str() + .expect("Got JSON where a criterion key was not a string. Please check json."); + + let count_from_prism = + get_element(counts, &group_key, &stratifier_key, stratum_key); + + match count_from_prism { + Some(count) => { + criterion.insert("count".into(), json!(count)); + } + None => { + debug!( + "No count from Prism for {}, {}, {}", + group_key, stratifier_key, stratum_key + ); + } + } + } + } + } + } + _ => {} + } +} diff --git a/src/config.rs b/src/config.rs index 5e4d6cb..c14e3db 100644 --- a/src/config.rs +++ b/src/config.rs @@ -29,8 +29,16 @@ pub struct Config { pub project: Option, /// The socket address this server will bind to - #[clap(long, env, default_value = "0.0.0.0:8080")] + #[clap(long, env, default_value = "0.0.0.0:8055")] pub bind_addr: SocketAddr, + + /// URL to catalogue.json file + #[clap(long, env)] + pub catalogue_url: Url, + + /// URL to prism + #[clap(long, env, default_value= "http://localhost:8066")] + pub prism_url: Url } fn parse_cors(v: &str) -> Result { diff --git a/src/main.rs b/src/main.rs index a588d2e..ae20f9d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,5 +1,5 @@ use axum::{ - extract::{Json, Path, Query}, + extract::{Json, Path, Query, State}, http::HeaderValue, response::{IntoResponse, Response}, routing::{get, post}, @@ -12,12 +12,14 @@ use config::Config; use once_cell::sync::Lazy; use reqwest::{header, Method, StatusCode}; use serde::{Deserialize, Serialize}; +use serde_json::Value; use tower_http::cors::CorsLayer; use tracing::{info, warn, Level}; -use tracing_subscriber::{EnvFilter, util::SubscriberInitExt}; +use tracing_subscriber::{util::SubscriberInitExt, EnvFilter}; mod banner; mod beam; +mod catalogue; mod config; static CONFIG: Lazy = Lazy::new(Config::parse); @@ -30,6 +32,11 @@ static BEAM_CLIENT: Lazy = Lazy::new(|| { ) }); +#[derive(Clone)] +struct SharedState { + extended_json: Value, +} + #[tokio::main] async fn main() { tracing_subscriber::FmtSubscriber::builder() @@ -37,8 +44,13 @@ async fn main() { .with_env_filter(EnvFilter::from_default_env()) .finish() .init(); - banner::print_banner(); + info!("{:#?}", Lazy::force(&CONFIG)); + + let extended_json = + catalogue::get_extended_json(CONFIG.catalogue_url.clone(), CONFIG.prism_url.clone()).await; + let state = SharedState { extended_json }; + // TODO: Add check for reachability of beam-proxy let cors = CorsLayer::new() @@ -49,9 +61,13 @@ async fn main() { let app = Router::new() .route("/beam", post(handle_create_beam_task)) .route("/beam/:task_id", get(handle_listen_to_beam_tasks)) + .route("/catalogue", get(handle_get_catalogue)) + .with_state(state) .layer(axum::middleware::map_response(banner::set_server_header)) .layer(cors); + banner::print_banner(); + axum::Server::bind(&CONFIG.bind_addr) .serve(app.into_make_service()) .await @@ -102,7 +118,10 @@ async fn handle_listen_to_beam_tasks( .send() .await .map_err(|err| { - println!("Failed request to {} with error: {}", CONFIG.beam_proxy_url, err); + println!( + "Failed request to {} with error: {}", + CONFIG.beam_proxy_url, err + ); ( StatusCode::BAD_GATEWAY, format!("Error calling beam, check the server logs."), @@ -129,3 +148,7 @@ fn convert_response(response: reqwest::Response) -> axum::response::Response { .unwrap() .into_response() } + +async fn handle_get_catalogue(State(state): State) -> Json { + Json(state.extended_json) +}