Skip to content

Commit

Permalink
Merge pull request #23 from samply/catalogue-extender
Browse files Browse the repository at this point in the history
Catalogue extender
  • Loading branch information
enola-dkfz committed Apr 4, 2024
2 parents bdc96ee + 66b52e4 commit b5bea21
Show file tree
Hide file tree
Showing 5 changed files with 178 additions and 8 deletions.
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "spot"
version = "0.1.0"
version = "0.2.0"
edition = "2021"
license = "Apache-2.0"
documentation = "https://github.com/samply/spot"
Expand All @@ -20,7 +20,7 @@ once_cell = "1"
# Logging
tracing = "0.1.37"
tracing-subscriber = { version = "0.3.17", features = ["env-filter"] }
reqwest = { version = "0.11.20", default-features = false, features = ["stream"] }
reqwest = { version = "0.11.20", default-features = false, features = ["stream", "default-tls"] }
tower-http = { version = "0.4.4", features = ["cors"] }

[build-dependencies]
Expand Down
2 changes: 1 addition & 1 deletion src/banner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ pub(crate) fn print_banner() {
_ => "SNAPSHOT",
};
info!(
"🌈 Samply.Spot v{} (built {} {}, {}) starting up ...",
"🌈 Samply.Spot v{} (built {} {}, {}) ready to take requests.",
env!("CARGO_PKG_VERSION"),
env!("BUILD_DATE"),
env!("BUILD_TIME"),
Expand Down
139 changes: 139 additions & 0 deletions src/catalogue.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
use std::{collections::BTreeMap, time::Duration};

use reqwest::Url;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use tracing::{debug, info};

/// Counts for the strata of one stratifier: stratum key (e.g. `male`) -> count.
pub type Criteria = BTreeMap<String, u64>;

/// All stratifiers of one group: stratifier key (e.g. `admin_gender`) -> [`Criteria`].
pub type CriteriaGroup = BTreeMap<String, Criteria>;

/// The complete set of counts: group key (e.g. `patients`) -> [`CriteriaGroup`].
pub type CriteriaGroups = BTreeMap<String, CriteriaGroup>;

/// Looks up the count stored under `key1` -> `key2` -> `key3` in `count`.
///
/// `key1` selects the group, `key2` the stratifier inside that group and `key3`
/// the stratum inside that stratifier. Returns `None` as soon as one of the
/// three levels has no entry for its key.
///
/// The returned reference borrows from `count` only; the keys are needed just
/// for the duration of the call and may be dropped afterwards.
fn get_element<'a>(
    count: &'a CriteriaGroups,
    key1: &str,
    key2: &str,
    key3: &str,
) -> Option<&'a u64> {
    count.get(key1)?.get(key2)?.get(key3)
}

/// Fetches the catalogue and enriches it with the criteria counts from Prism.
///
/// Downloads the catalogue JSON from `catalogue_url`, posts an empty site list
/// to the `criteria` endpoint below `prism_url`, and hands both results to
/// `recurse`, which adds a `count` field to every criterion Prism reported a
/// count for. The enriched catalogue is returned.
///
/// # Panics
///
/// Panics if either upstream cannot be reached, answers with an HTTP error
/// status (4xx/5xx), or sends a body that cannot be parsed.
pub async fn get_extended_json(catalogue_url: Url, prism_url: Url) -> Value {
    debug!("Fetching catalogue from {catalogue_url} ...");

    // Both requests share a single client instead of building one each.
    let client = reqwest::Client::new();

    let resp = client
        .get(catalogue_url)
        .timeout(Duration::from_secs(30))
        .send()
        .await
        // The catalogue is parsed as an arbitrary `Value`, so without this check a
        // JSON error body of a 4xx/5xx response would be accepted as the catalogue.
        .and_then(|resp| resp.error_for_status())
        .expect("Unable to fetch catalogue from upstream; please check URL specified in config.");

    let mut json: Value = resp
        .json()
        .await
        .expect("Unable to parse catalogue from upstream; please check URL specified in config.");

    // Accept a Prism base URL with or without a trailing slash; plain concatenation
    // would turn `http://host/api` into `http://host/apicriteria`.
    let criteria_url = format!("{}/criteria", prism_url.as_str().trim_end_matches('/'));

    let prism_resp = client
        .post(criteria_url)
        .header("Content-Type", "application/json")
        .body("{\"sites\": []}")
        .timeout(Duration::from_secs(300))
        .send()
        .await
        .and_then(|resp| resp.error_for_status())
        .expect("Unable to fetch response from Prism; please check it's running.");

    let mut counts: CriteriaGroups = prism_resp
        .json()
        .await
        .expect("Unable to parse response from Prism into CriteriaGroups");

    recurse(&mut json, &mut counts); //TODO remove from counts once copied into catalogue to make it O(n log n)

    info!("Catalogue built successfully.");

    json
}

/// Copies the counts reported by Prism into the matching criteria of the catalogue.
///
/// Arrays, and objects without a `childCategories` entry, are only traversed.
/// For an object that has `childCategories`, every child category whose `type`
/// is `EQUALS` is visited, and each of its criteria gets a `count` field when
/// `counts` holds a value for it. Criteria without a count are left untouched
/// and logged at debug level.
///
/// Key order: group key (e.g. patient)
///  \-- stratifier key (e.g. admin_gender)
///       \-- stratum key (e.g. male, other)
///
/// # Panics
///
/// Panics with a descriptive message when the catalogue does not have the
/// expected shape (missing or non-string keys, `childCategories` or `criteria`
/// that are not arrays, criteria that are not objects).
fn recurse(json: &mut Value, counts: &mut CriteriaGroups) {
    match json {
        Value::Array(arr) => {
            for ele in arr {
                recurse(ele, counts);
            }
        }
        Value::Object(obj) => {
            if !obj.contains_key("childCategories") {
                for (_key, child_val) in obj.iter_mut() {
                    recurse(child_val, counts);
                }
            } else {
                // Owned copy: `obj` is borrowed mutably further down.
                let group_key = obj
                    .get("key")
                    .expect("Got JSON element with childCategories but without (group) key. Please check json.")
                    .as_str()
                    .expect("Got JSON where a criterion key was not a string. Please check json.")
                    .to_owned();

                //TODO consolidate catalogue and MeasureReport group names
                let group_key = match group_key.as_str() {
                    "patient" => "patients",
                    "tumor_classification" => "diagnosis",
                    "biosamples" => "specimen",
                    other => other,
                };

                let children_cats = obj
                    .get_mut("childCategories")
                    .expect("childCategories is present; checked by contains_key above")
                    .as_array_mut()
                    .expect("Got JSON element with childCategories that are not an array. Please check json.")
                    .iter_mut()
                    .filter(|item| item.get("type").unwrap_or(&Value::Null) == "EQUALS");

                for child_cat in children_cats {
                    // Owned copy: `child_cat` is borrowed mutably for its criteria below.
                    let stratifier_key = child_cat
                        .get("key")
                        .expect("Got JSON element with childCategory that does not contain a (stratifier) key. Please check json.")
                        .as_str()
                        .expect("Got JSON where a criterion key was not a string. Please check json.")
                        .to_owned();

                    let criteria = child_cat
                        .get_mut("criteria")
                        .expect("Got JSON element with childCategory that does not contain a criteria array. Please check json.")
                        .as_array_mut()
                        .expect("Got JSON element with childCategory with criteria that are not an array. Please check json.");

                    for criterion in criteria {
                        let criterion = criterion.as_object_mut().expect(
                            "Got JSON where a criterion was not an object. Please check json.",
                        );
                        let stratum_key = criterion
                            .get("key")
                            .expect("Got JSON where a criterion did not have a key. Please check json.")
                            .as_str()
                            .expect("Got JSON where a criterion key was not a string. Please check json.");

                        let count_from_prism =
                            get_element(counts, &group_key, &stratifier_key, stratum_key);

                        match count_from_prism {
                            Some(count) => {
                                criterion.insert("count".into(), json!(count));
                            }
                            None => {
                                debug!(
                                    "No count from Prism for {}, {}, {}",
                                    group_key, stratifier_key, stratum_key
                                );
                            }
                        }
                    }
                }
            }
        }
        _ => {}
    }
}
10 changes: 9 additions & 1 deletion src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,16 @@ pub struct Config {
pub project: Option<String>,

/// The socket address this server will bind to
#[clap(long, env, default_value = "0.0.0.0:8080")]
#[clap(long, env, default_value = "0.0.0.0:8055")]
pub bind_addr: SocketAddr,

/// URL to catalogue.json file
#[clap(long, env)]
pub catalogue_url: Url,

/// URL to prism
#[clap(long, env, default_value= "http://localhost:8066")]
pub prism_url: Url
}

fn parse_cors(v: &str) -> Result<AllowOrigin, InvalidHeaderValue> {
Expand Down
31 changes: 27 additions & 4 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use axum::{
extract::{Json, Path, Query},
extract::{Json, Path, Query, State},
http::HeaderValue,
response::{IntoResponse, Response},
routing::{get, post},
Expand All @@ -12,12 +12,14 @@ use config::Config;
use once_cell::sync::Lazy;
use reqwest::{header, Method, StatusCode};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tower_http::cors::CorsLayer;
use tracing::{info, warn, Level};
use tracing_subscriber::{EnvFilter, util::SubscriberInitExt};
use tracing_subscriber::{util::SubscriberInitExt, EnvFilter};

mod banner;
mod beam;
mod catalogue;
mod config;

static CONFIG: Lazy<Config> = Lazy::new(Config::parse);
Expand All @@ -30,15 +32,25 @@ static BEAM_CLIENT: Lazy<BeamClient> = Lazy::new(|| {
)
});

/// State attached to the router with `with_state` and extracted by the
/// `/catalogue` handler.
// NOTE(review): the state is `Clone` and the handler takes it by value, so the
// whole JSON tree is presumably copied per request — consider `Arc<Value>`;
// confirm against axum's `State` extractor documentation.
#[derive(Clone)]
struct SharedState {
/// Catalogue JSON produced by `catalogue::get_extended_json` in `main`.
extended_json: Value,
}

#[tokio::main]
async fn main() {
tracing_subscriber::FmtSubscriber::builder()
.with_max_level(Level::DEBUG)
.with_env_filter(EnvFilter::from_default_env())
.finish()
.init();
banner::print_banner();

info!("{:#?}", Lazy::force(&CONFIG));

let extended_json =
catalogue::get_extended_json(CONFIG.catalogue_url.clone(), CONFIG.prism_url.clone()).await;
let state = SharedState { extended_json };

// TODO: Add check for reachability of beam-proxy

let cors = CorsLayer::new()
Expand All @@ -49,9 +61,13 @@ async fn main() {
let app = Router::new()
.route("/beam", post(handle_create_beam_task))
.route("/beam/:task_id", get(handle_listen_to_beam_tasks))
.route("/catalogue", get(handle_get_catalogue))
.with_state(state)
.layer(axum::middleware::map_response(banner::set_server_header))
.layer(cors);

banner::print_banner();

axum::Server::bind(&CONFIG.bind_addr)
.serve(app.into_make_service())
.await
Expand Down Expand Up @@ -102,7 +118,10 @@ async fn handle_listen_to_beam_tasks(
.send()
.await
.map_err(|err| {
println!("Failed request to {} with error: {}", CONFIG.beam_proxy_url, err);
println!(
"Failed request to {} with error: {}",
CONFIG.beam_proxy_url, err
);
(
StatusCode::BAD_GATEWAY,
format!("Error calling beam, check the server logs."),
Expand All @@ -129,3 +148,7 @@ fn convert_response(response: reqwest::Response) -> axum::response::Response {
.unwrap()
.into_response()
}

async fn handle_get_catalogue(State(state): State<SharedState>) -> Json<Value> {
Json(state.extended_json)
}

0 comments on commit b5bea21

Please sign in to comment.