Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

upgrade http ecosystem to hyper 1 #21

Merged
merged 1 commit into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,17 @@ license = "Apache-2.0"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
axum = "0.6"
axum-macros = "0.4.1"
base64 = { version = "0.21.0", default_features = false }
http = "0.2"
reqwest = { version = "0.11.20", default_features = false, features = ["json", "default-tls", "stream"] }
axum = "0.7"
base64 = "0.22.1"
reqwest = { version = "0.12", default_features = false, features = ["json", "default-tls", "stream"] }
serde = { version = "1.0.152", features = ["serde_derive"] }
serde_json = "1.0.96"
thiserror = "1.0.38"
rand = { default-features = false, version = "0.8.5" }
chrono = "0.4.31"
tokio = { version = "1.25.0", default_features = false, features = ["signal", "rt-multi-thread", "macros"] }
beam-lib = { git = "https://github.com/samply/beam", branch = "develop", features = ["http-util"] }
tower-http = { version = "0.4.4", features = ["cors"] }
tower-http = { version = "0.5", features = ["cors"] }
async-sse = "5.1.0"
anyhow = "1"
futures-util = { version = "0.3", features = ["io"] }
Expand Down
2 changes: 1 addition & 1 deletion src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ fn get_query_unencoded() -> String {
)
}

fn parse_cors(v: &str) -> Result<AllowOrigin, http::header::InvalidHeaderValue> {
fn parse_cors(v: &str) -> Result<AllowOrigin, reqwest::header::InvalidHeaderValue> {
if v == "*" || v.to_lowercase() == "any" {
Ok(AllowOrigin::any())
} else {
Expand Down
55 changes: 31 additions & 24 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,35 +7,33 @@ mod mr;

use crate::errors::PrismError;
use crate::{config::CONFIG, mr::MeasureReport};
use base64::engine::general_purpose::STANDARD as BASE64;
use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
use futures_util::{StreamExt as _, TryStreamExt};
use http::HeaderValue;
use std::collections::HashSet;
use std::io;
use std::process::exit;
use std::sync::Arc;
use std::time::SystemTime;
use tokio::sync::Mutex;
use tokio::{net::TcpListener, sync::Mutex};

use axum::{
extract::State,
http::{header, StatusCode},
extract::{Json, State},
http::StatusCode,
response::{IntoResponse, Response},
routing::post,
Json, Router,
Router,
};
use reqwest::{header, header::HeaderValue, Method};

use base64::Engine as _;
use once_cell::sync::Lazy;
use reqwest::Method;
use serde::{Deserialize, Serialize};

use beam::create_beam_task;
use beam_lib::{AppId, BeamClient, MsgId};
use criteria::{combine_groups_of_criteria_groups, CriteriaGroups};
use std::{collections::HashMap, time::Duration};
use tower_http::cors::CorsLayer;
use tracing::{error, info, warn, debug};
use tracing::{debug, error, info, warn};

use beam_lib::{RawString, TaskResult};

Expand All @@ -49,7 +47,7 @@ static BEAM_CLIENT: Lazy<BeamClient> = Lazy::new(|| {

#[derive(Serialize, Deserialize, Clone, Debug)]
struct LensQuery {
sites: Vec<String>,
sites: Vec<String>,
}

type Site = String;
Expand Down Expand Up @@ -111,7 +109,7 @@ pub async fn main() {
spawn_site_querying(shared_state.clone());

let cors = CorsLayer::new()
.allow_methods([http::Method::GET, http::Method::POST])
.allow_methods([Method::GET, Method::POST])
.allow_origin(CONFIG.cors_origin.clone())
.allow_headers([header::CONTENT_TYPE]);

Expand All @@ -120,10 +118,12 @@ pub async fn main() {
.with_state(shared_state)
.layer(cors);

axum::Server::bind(&CONFIG.bind_addr)
.serve(app.into_make_service())
.await
.unwrap()
axum::serve(
TcpListener::bind(CONFIG.bind_addr).await.unwrap(),
app.into_make_service(),
)
.await
.unwrap()
}

fn spawn_site_querying(shared_state: SharedState) {
Expand Down Expand Up @@ -164,9 +164,11 @@ async fn handle_get_criteria(
if SystemTime::now().duration_since(cached.1).unwrap() < CRITERIACACHE_TTL {
Some(cached.0.clone())
} else {
debug!("Results for site {} in cache sadly expired, will query again", &site);
debug!(
"Results for site {} in cache sadly expired, will query again",
&site
);
None

}
}
None => {
Expand Down Expand Up @@ -246,14 +248,15 @@ async fn query_sites(
Ok(())
}

async fn get_results(shared_state: SharedState, task_id: MsgId, wait_count: usize) -> Result<(), PrismError> {
async fn get_results(
shared_state: SharedState,
task_id: MsgId,
wait_count: usize,
) -> Result<(), PrismError> {
let resp = BEAM_CLIENT
.raw_beam_request(
Method::GET,
&format!(
"v1/tasks/{}/results?wait_count={}",
task_id, wait_count
),
&format!("v1/tasks/{}/results?wait_count={}", task_id, wait_count),
)
.header(
header::ACCEPT,
Expand Down Expand Up @@ -304,7 +307,11 @@ async fn get_results(shared_state: SharedState, task_id: MsgId, wait_count: usiz
from.as_ref().split('.').nth(1).unwrap().to_string(), // extracting site name from app long name
(criteria, std::time::SystemTime::now()),
);
info!("Cached results from site {} for task {}", from.as_ref().split('.').nth(1).unwrap().to_string(), task_id);
info!(
"Cached results from site {} for task {}",
from.as_ref().split('.').nth(1).unwrap().to_string(),
task_id
);
}
Ok(())
}
Expand Down Expand Up @@ -334,7 +341,7 @@ async fn wait_for_beam_proxy() -> beam_lib::Result<()> {
loop {
match reqwest::get(format!("{}v1/health", CONFIG.beam_proxy_url)).await {
//FIXME why doesn't it work with url from config
Ok(res) if res.status() == StatusCode::OK => return Ok(()),
Ok(res) if res.status() == reqwest::StatusCode::OK => return Ok(()),
_ if tries <= MAX_RETRIES => tries += 1,
Err(e) => return Err(e.into()),
Ok(res) => {
Expand Down
Loading