From dd3a9e9d8670fc8ac2b85dbb0ec62f75c02e21a5 Mon Sep 17 00:00:00 2001 From: Joey Kraut Date: Mon, 21 Oct 2024 14:59:42 -0700 Subject: [PATCH] auth-server: server: Build `Server` type as dependency container --- Cargo.lock | 5 ++ auth-server/Cargo.toml | 5 ++ auth-server/src/main.rs | 82 +++++++++-------------- auth-server/src/server/handle_proxy.rs | 52 +++++++++++++++ auth-server/src/server/mod.rs | 92 ++++++++++++++++++++++++++ 5 files changed, 186 insertions(+), 50 deletions(-) create mode 100644 auth-server/src/server/handle_proxy.rs create mode 100644 auth-server/src/server/mod.rs diff --git a/Cargo.lock b/Cargo.lock index 1997869..665acf5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -702,17 +702,22 @@ checksum = "628d228f918ac3b82fe590352cc719d30664a0c13ca3a60266fe02c7132d480a" name = "auth-server" version = "0.1.0" dependencies = [ + "bb8", "bytes", "clap", "diesel", + "diesel-async", "futures-util", "http 0.2.12", "hyper 0.14.30", + "native-tls", + "postgres-native-tls", "reqwest 0.11.27", "serde", "serde_json", "thiserror", "tokio", + "tokio-postgres", "tracing", "warp", ] diff --git a/auth-server/Cargo.toml b/auth-server/Cargo.toml index 2eac6d2..9206473 100644 --- a/auth-server/Cargo.toml +++ b/auth-server/Cargo.toml @@ -13,7 +13,12 @@ tokio = { version = "1", features = ["full"] } warp = "0.3" # === Database === # +bb8 = "0.8" diesel = { version = "2", features = ["postgres"] } +diesel-async = { version = "0.4", features = ["postgres", "bb8"] } +tokio-postgres = "0.7" +postgres-native-tls = "0.5" +native-tls = "0.2" # === Misc Dependencies === # bytes = "1.0" diff --git a/auth-server/src/main.rs b/auth-server/src/main.rs index 5a8c2ce..927a906 100644 --- a/auth-server/src/main.rs +++ b/auth-server/src/main.rs @@ -13,16 +13,19 @@ #[allow(missing_docs, clippy::missing_docs_in_private_items)] pub(crate) mod schema; +mod server; -use bytes::Bytes; use clap::Parser; -use reqwest::{Client, Method, StatusCode}; +use reqwest::StatusCode; use serde_json::json; use std::net::SocketAddr; +use std::sync::Arc; use thiserror::Error; use tracing::{error, info}; use warp::{Filter, Rejection, Reply}; +use server::Server; + // ------- // | CLI | // ------- @@ -30,19 +33,25 @@ use warp::{Filter, Rejection, Reply}; /// The command line arguments for the auth server #[derive(Parser, Debug)] #[command(author, version, about, long_about = None)] -struct Args { +pub struct Cli { + /// The database url + #[arg(long, env = "DATABASE_URL")] + pub database_url: String, + /// The encryption key used to encrypt/decrypt database values + #[arg(long, env = "ENCRYPTION_KEY")] + pub encryption_key: String, /// The URL of the relayer #[arg(long, env = "RELAYER_URL")] - relayer_url: String, + pub relayer_url: String, /// The admin key for the relayer #[arg(long, env = "RELAYER_ADMIN_KEY")] - relayer_admin_key: String, + pub relayer_admin_key: String, /// The port to run the server on - #[arg(long, env = "PORT", default_value = "3030")] - port: u16, + #[arg(long, env = "PORT", default_value = "3000")] + pub port: u16, /// Whether to enable datadog logging #[arg(long)] - datadog_logging: bool, + pub datadog_logging: bool, } // ------------- @@ -70,9 +79,13 @@ impl warp::reject::Reject for ApiError {} /// The main function for the auth server #[tokio::main] async fn main() { - let args = Args::parse(); + let args = Cli::parse(); let listen_addr: SocketAddr = ([0, 0, 0, 0], args.port).into(); + // Create the server + let server = Server::new(args).await.expect("Failed to create server"); + let server = Arc::new(server); + // TODO: Setup logging // --- Routes --- // @@ -87,53 +100,22 @@ async fn main() { .and(warp::method()) .and(warp::header::headers_cloned()) .and(warp::body::bytes()) - .and(warp::any().map(move || args.relayer_url.clone())) - .and(warp::any().map(move || args.relayer_admin_key.clone())) - .and_then(handle_request); + .and(with_server(server.clone())) + .and_then(|path, method, headers, body, server: Arc| async move { + server.handle_proxy_request(path, method, headers, body).await + }); // Bind the server and listen - info!("Starting auth server on port {}", args.port); + info!("Starting auth server on port {}", listen_addr.port()); let routes = ping.or(proxy).recover(handle_rejection); warp::serve(routes).bind(listen_addr).await; } -/// Handle a request to the relayer -async fn handle_request( - path: warp::path::FullPath, - method: Method, - headers: warp::hyper::HeaderMap, - body: Bytes, - relayer_url: String, - relayer_admin_key: String, -) -> Result { - let client = Client::new(); - let url = format!("{}{}", relayer_url, path.as_str()); - - let mut req = client.request(method, &url).headers(headers).body(body); - req = req.header("X-Admin-Key", &relayer_admin_key); - - match req.send().await { - Ok(resp) => { - let status = resp.status(); - let headers = resp.headers().clone(); - let body = resp.bytes().await.map_err(|e| { - warp::reject::custom(ApiError::InternalError(format!( - "Failed to read response body: {}", - e - ))) - })?; - - let mut response = warp::http::Response::new(body); - *response.status_mut() = status; - *response.headers_mut() = headers; - - Ok(response) - }, - Err(e) => { - error!("Error proxying request: {}", e); - Err(warp::reject::custom(ApiError::InternalError(e.to_string()))) - }, - } +/// Helper function to pass the server to filters +fn with_server( + server: Arc, +) -> impl Filter,), Error = std::convert::Infallible> + Clone { + warp::any().map(move || server.clone()) } /// Handle a rejection from an endpoint handler diff --git a/auth-server/src/server/handle_proxy.rs b/auth-server/src/server/handle_proxy.rs new file mode 100644 index 0000000..2965911 --- /dev/null +++ b/auth-server/src/server/handle_proxy.rs @@ -0,0 +1,52 @@ +//! Handler code for proxied relayer requests +//! +//! At a high level the server must first authenticate the request, then forward +//! it to the relayer with admin authentication + +use bytes::Bytes; +use http::Method; +use tracing::error; +use warp::{reject::Rejection, reply::Reply}; + +use crate::ApiError; + +use super::Server; + +/// Handle a proxied request +impl Server { + /// Handle a request meant to be authenticated and proxied to the relayer + pub async fn handle_proxy_request( + &self, + path: warp::path::FullPath, + method: Method, + headers: warp::hyper::HeaderMap, + body: Bytes, + ) -> Result { + let url = format!("{}{}", self.relayer_url, path.as_str()); + let req = self.client.request(method, &url).headers(headers).body(body); + + // TODO: Add admin auth here + match req.send().await { + Ok(resp) => { + let status = resp.status(); + let headers = resp.headers().clone(); + let body = resp.bytes().await.map_err(|e| { + warp::reject::custom(ApiError::InternalError(format!( + "Failed to read response body: {}", + e + ))) + })?; + + let mut response = warp::http::Response::new(body); + *response.status_mut() = status; + *response.headers_mut() = headers; + + Ok(response) + }, + Err(e) => { + error!("Error proxying request: {}", e); + Err(warp::reject::custom(ApiError::InternalError(e.to_string()))) + }, + } + } +} diff --git a/auth-server/src/server/mod.rs b/auth-server/src/server/mod.rs new file mode 100644 index 0000000..c399911 --- /dev/null +++ b/auth-server/src/server/mod.rs @@ -0,0 +1,92 @@ +//! Defines the server struct and associated functions +//! +//! The server is a dependency injection container for the authentication server +use crate::Cli; +use bb8::{Pool, PooledConnection}; +use diesel::ConnectionError; +use diesel_async::{ + pooled_connection::{AsyncDieselConnectionManager, ManagerConfig}, + AsyncPgConnection, +}; +use native_tls::TlsConnector; +use postgres_native_tls::MakeTlsConnector; +use reqwest::Client; +use std::sync::Arc; +use thiserror::Error; +use tracing::error; + +mod handle_proxy; + +/// The DB connection type +pub type DbConn<'a> = PooledConnection<'a, AsyncDieselConnectionManager>; +/// The DB pool type +pub type DbPool = Pool>; + +/// Custom error type for server errors +#[derive(Error, Debug)] +pub enum ServerError { + /// Database connection error + #[error("Database connection error: {0}")] + DatabaseConnectionError(String), +} + +/// The server struct that holds all the necessary components +pub struct Server { + /// The database connection pool + pub db_pool: Arc, + /// The URL of the relayer + pub relayer_url: String, + /// The admin key for the relayer + pub relayer_admin_key: String, + /// The HTTP client + pub client: Client, +} + +impl Server { + /// Create a new server instance + pub async fn new(args: Cli) -> Result { + let db_pool = create_db_pool(&args.database_url).await?; + Ok(Self { + db_pool: Arc::new(db_pool), + relayer_url: args.relayer_url, + relayer_admin_key: args.relayer_admin_key, + client: Client::new(), + }) + } +} + +/// Create a database pool +pub async fn create_db_pool(db_url: &str) -> Result { + let mut conf = ManagerConfig::default(); + conf.custom_setup = Box::new(move |url| Box::pin(establish_connection(url))); + + let manager = AsyncDieselConnectionManager::new_with_config(db_url, conf); + Pool::builder() + .build(manager) + .await + .map_err(|e| ServerError::DatabaseConnectionError(e.to_string())) +} + +/// Establish a connection to the database +pub async fn establish_connection(db_url: &str) -> Result { + // Build a TLS connector, we don't validate certificates for simplicity. + // Practically this is unnecessary because we will be limiting our traffic to + // within a siloed environment when deployed + let connector = TlsConnector::builder() + .danger_accept_invalid_certs(true) + .build() + .expect("failed to build tls connector"); + let connector = MakeTlsConnector::new(connector); + let (client, conn) = tokio_postgres::connect(db_url, connector) + .await + .map_err(|e| ConnectionError::BadConnection(e.to_string()))?; + + // Spawn the connection handle in a separate task + tokio::spawn(async move { + if let Err(e) = conn.await { + error!("Connection error: {}", e); + } + }); + + AsyncPgConnection::try_from(client).await +}