From 4234b10b17b9b6e2c6cde813de304de6001a15bd Mon Sep 17 00:00:00 2001 From: janskiba Date: Wed, 14 Aug 2024 07:14:12 +0000 Subject: [PATCH] feat: Change health connection to SSE --- broker/src/serve_health.rs | 36 +++++++++++++-------- proxy/src/main.rs | 65 +++++++++++++++++++++++--------------- 2 files changed, 61 insertions(+), 40 deletions(-) diff --git a/broker/src/serve_health.rs b/broker/src/serve_health.rs index 81ab706d..d6ea2438 100644 --- a/broker/src/serve_health.rs +++ b/broker/src/serve_health.rs @@ -1,8 +1,9 @@ -use std::{sync::Arc, time::{Duration, SystemTime}}; +use std::{convert::Infallible, marker::PhantomData, sync::Arc, time::{Duration, SystemTime}}; -use axum::{extract::{State, Path}, http::StatusCode, routing::get, Json, Router, response::Response}; +use axum::{extract::{Path, State}, http::StatusCode, response::{sse::{Event, KeepAlive}, Response, Sse}, routing::get, Json, Router}; use axum_extra::{headers::{authorization::Basic, Authorization}, TypedHeader}; use beam_lib::ProxyId; +use futures_core::Stream; use serde::{Serialize, Deserialize}; use shared::{crypto_jwt::Authorized, Msg, config::CONFIG_CENTRAL}; use tokio::sync::RwLock; @@ -46,7 +47,7 @@ async fn handler( } async fn get_all_proxies(State(state): State>>) -> Json> { - Json(state.read().await.proxies.keys().cloned().collect()) + Json(state.read().await.proxies.iter().filter(|(_, v)| v.online()).map(|(k, _)| k).cloned().collect()) } async fn proxy_health( @@ -76,25 +77,32 @@ async fn proxy_health( async fn get_control_tasks( State(state): State>>, proxy_auth: Authorized, -) -> StatusCode { +) -> Sse { let proxy_id = proxy_auth.get_from().proxy_id(); // Once this is freed the connection will be removed from the map of connected proxies again // This ensures that when the connection is dropped and therefore this response future the status of this proxy will be updated - let _connection_remover = ConnectedGuard::connect(&proxy_id, &state).await; + let connect_guard = ConnectedGuard::connect(proxy_id, state).await; - // In the future, this will wait for control tasks for the given proxy - tokio::time::sleep(Duration::from_secs(60 * 60)).await; + Sse::new(ForeverStream(connect_guard)).keep_alive(KeepAlive::new()) +} + +struct ForeverStream(#[allow(dead_code)] ConnectedGuard); - StatusCode::OK +impl Stream for ForeverStream { + type Item = Result; + + fn poll_next(self: std::pin::Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> std::task::Poll> { + std::task::Poll::Pending + } } -struct ConnectedGuard<'a> { - proxy: &'a ProxyId, - state: &'a Arc> +struct ConnectedGuard { + proxy: ProxyId, + state: Arc> } -impl<'a> ConnectedGuard<'a> { - async fn connect(proxy: &'a ProxyId, state: &'a Arc>) -> ConnectedGuard<'a> { +impl ConnectedGuard { + async fn connect(proxy: ProxyId, state: Arc>) -> ConnectedGuard { { state.write().await.proxies .entry(proxy.clone()) @@ -105,7 +113,7 @@ impl<'a> ConnectedGuard<'a> { } } -impl<'a> Drop for ConnectedGuard<'a> { +impl Drop for ConnectedGuard { fn drop(&mut self) { let proxy_id = self.proxy.clone(); let map = self.state.clone(); diff --git a/proxy/src/main.rs b/proxy/src/main.rs index 8adc6bd6..a5a63ce6 100644 --- a/proxy/src/main.rs +++ b/proxy/src/main.rs @@ -6,6 +6,7 @@ use std::time::Duration; use axum::http::{header, HeaderValue, StatusCode}; use beam_lib::AppOrProxyId; use futures::future::Ready; +use futures::{StreamExt, TryStreamExt}; use shared::{reqwest, EncryptedMessage, MsgEmpty, PlainMessage}; use shared::crypto::CryptoPublicPortion; use shared::errors::SamplyBeamError; @@ -133,8 +134,12 @@ fn spawn_controller_polling(client: SamplyHttpClient, config: Config) { const RETRY_INTERVAL: Duration = Duration::from_secs(60); tokio::spawn(async move { let mut retries_this_min = 0; - let mut reset_interval = std::pin::pin!(tokio::time::sleep(Duration::from_secs(60))); + let mut reset_interval = Instant::now() + RETRY_INTERVAL; loop { + if reset_interval < Instant::now() { + retries_this_min = 0; + reset_interval = Instant::now() + RETRY_INTERVAL; + } let body = EncryptedMessage::MsgEmpty(MsgEmpty { from: AppOrProxyId::Proxy(config.proxy_id.clone()), }); @@ -146,39 +151,47 @@ fn spawn_controller_polling(client: SamplyHttpClient, config: Config) { let req = sign_request(body, parts, &config, None).await.expect("Unable to sign request; this should always work"); // In the future this will poll actual control related tasks - match client.execute(req).await { - Ok(res) => { - match res.status() { - StatusCode::OK => { - // Process control task - }, - status @ (StatusCode::GATEWAY_TIMEOUT | StatusCode::BAD_GATEWAY) => { - if retries_this_min < 10 { - retries_this_min += 1; - debug!("Connection to broker timed out; retrying."); - } else { - warn!("Retried more then 10 times in one minute getting status code: {status}"); - tokio::time::sleep(RETRY_INTERVAL).await; - continue; - } - }, - other => { - warn!("Got unexpected status getting control tasks from broker: {other}"); - tokio::time::sleep(RETRY_INTERVAL).await; - } - }; - }, + let res = match client.execute(req).await { + Ok(res) if res.status() != StatusCode::OK => { + if retries_this_min < 10 { + retries_this_min += 1; + debug!("Unexpected status code getting control tasks from broker: {}", res.status()); + } else { + warn!("Retried more then 10 times in one minute getting status code: {}", res.status()); + tokio::time::sleep(RETRY_INTERVAL).await; + } + continue; + } + Ok(res) => res, Err(e) if e.is_timeout() => { debug!("Connection to broker timed out; retrying: {e}"); + continue; }, Err(e) => { warn!("Error getting control tasks from broker; retrying in {}s: {e}", RETRY_INTERVAL.as_secs()); tokio::time::sleep(RETRY_INTERVAL).await; + continue; } }; - if reset_interval.is_elapsed() { - retries_this_min = 0; - reset_interval.as_mut().reset(Instant::now() + Duration::from_secs(60)); + let incoming = res + .bytes_stream() + .map(|result| result.map_err(|error| { + let kind = error.is_timeout().then_some(std::io::ErrorKind::TimedOut).unwrap_or(std::io::ErrorKind::Other); + std::io::Error::new(kind, format!("IO Error: {error}")) + })) + .into_async_read(); + let mut reader = async_sse::decode(incoming); + while let Some(ev) = reader.next().await { + match ev { + Ok(_)=> (), + Err(e) if e.downcast_ref::().unwrap().kind() == std::io::ErrorKind::TimedOut => { + debug!("SSE connection timed out"); + break; + }, + Err(err) => { + error!("Got error reading SSE stream: {err}"); + } + }; } } });