Skip to content

Commit

Permalink
auth-server: server: Build Server type as dependency container
Browse files Browse the repository at this point in the history
  • Loading branch information
joeykraut committed Oct 21, 2024
1 parent 5857bac commit 47ac156
Show file tree
Hide file tree
Showing 5 changed files with 186 additions and 50 deletions.
5 changes: 5 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions auth-server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@ tokio = { version = "1", features = ["full"] }
warp = "0.3"

# === Database === #
bb8 = "0.8"
diesel = { version = "2", features = ["postgres"] }
diesel-async = { version = "0.4", features = ["postgres", "bb8"] }
tokio-postgres = "0.7"
postgres-native-tls = "0.5"
native-tls = "0.2"

# === Misc Dependencies === #
bytes = "1.0"
Expand Down
82 changes: 32 additions & 50 deletions auth-server/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,36 +13,45 @@

#[allow(missing_docs, clippy::missing_docs_in_private_items)]
pub(crate) mod schema;
mod server;

use bytes::Bytes;
use clap::Parser;
use reqwest::{Client, Method, StatusCode};
use reqwest::StatusCode;
use serde_json::json;
use std::net::SocketAddr;
use std::sync::Arc;
use thiserror::Error;
use tracing::{error, info};
use warp::{Filter, Rejection, Reply};

use server::Server;

// -------
// | CLI |
// -------

/// The command line arguments for the auth server
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
pub struct Cli {
/// The database url
#[arg(long, env = "DATABASE_URL")]
pub database_url: String,
/// The encryption key used to encrypt/decrypt database values
#[arg(long, env = "ENCRYPTION_KEY")]
pub encryption_key: String,
/// The URL of the relayer
#[arg(long, env = "RELAYER_URL")]
relayer_url: String,
pub relayer_url: String,
/// The admin key for the relayer
#[arg(long, env = "RELAYER_ADMIN_KEY")]
relayer_admin_key: String,
pub relayer_admin_key: String,
/// The port to run the server on
#[arg(long, env = "PORT", default_value = "3030")]
port: u16,
#[arg(long, env = "PORT", default_value = "3000")]
pub port: u16,
/// Whether to enable datadog logging
#[arg(long)]
datadog_logging: bool,
pub datadog_logging: bool,
}

// -------------
Expand Down Expand Up @@ -70,9 +79,13 @@ impl warp::reject::Reject for ApiError {}
/// The main function for the auth server
#[tokio::main]
async fn main() {
let args = Args::parse();
let args = Cli::parse();
let listen_addr: SocketAddr = ([0, 0, 0, 0], args.port).into();

// Create the server
let server = Server::new(args).await.expect("Failed to create server");
let server = Arc::new(server);

// TODO: Setup logging

// --- Routes --- //
Expand All @@ -87,53 +100,22 @@ async fn main() {
.and(warp::method())
.and(warp::header::headers_cloned())
.and(warp::body::bytes())
.and(warp::any().map(move || args.relayer_url.clone()))
.and(warp::any().map(move || args.relayer_admin_key.clone()))
.and_then(handle_request);
.and(with_server(server.clone()))
.and_then(|path, method, headers, body, server: Arc<Server>| async move {
server.handle_proxy_request(path, method, headers, body).await
});

// Bind the server and listen
info!("Starting auth server on port {}", args.port);
info!("Starting auth server on port {}", listen_addr.port());
let routes = ping.or(proxy).recover(handle_rejection);
warp::serve(routes).bind(listen_addr).await;
}

/// Handle a request to the relayer
async fn handle_request(
path: warp::path::FullPath,
method: Method,
headers: warp::hyper::HeaderMap,
body: Bytes,
relayer_url: String,
relayer_admin_key: String,
) -> Result<impl Reply, Rejection> {
let client = Client::new();
let url = format!("{}{}", relayer_url, path.as_str());

let mut req = client.request(method, &url).headers(headers).body(body);
req = req.header("X-Admin-Key", &relayer_admin_key);

match req.send().await {
Ok(resp) => {
let status = resp.status();
let headers = resp.headers().clone();
let body = resp.bytes().await.map_err(|e| {
warp::reject::custom(ApiError::InternalError(format!(
"Failed to read response body: {}",
e
)))
})?;

let mut response = warp::http::Response::new(body);
*response.status_mut() = status;
*response.headers_mut() = headers;

Ok(response)
},
Err(e) => {
error!("Error proxying request: {}", e);
Err(warp::reject::custom(ApiError::InternalError(e.to_string())))
},
}
/// Helper function to pass the server to filters
fn with_server(
server: Arc<Server>,
) -> impl Filter<Extract = (Arc<Server>,), Error = std::convert::Infallible> + Clone {
warp::any().map(move || server.clone())
}

/// Handle a rejection from an endpoint handler
Expand Down
52 changes: 52 additions & 0 deletions auth-server/src/server/handle_proxy.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
//! Handler code for proxied relayer requests
//!
//! At a high level the server must first authenticate the request, then forward
//! it to the relayer with admin authentication

use bytes::Bytes;
use http::Method;
use tracing::error;
use warp::{reject::Rejection, reply::Reply};

use crate::ApiError;

use super::Server;

/// Handle a proxied request
impl Server {
/// Handle a request meant to be authenticated and proxied to the relayer
pub async fn handle_proxy_request(
&self,
path: warp::path::FullPath,
method: Method,
headers: warp::hyper::HeaderMap,
body: Bytes,
) -> Result<impl Reply, Rejection> {
let url = format!("{}{}", self.relayer_url, path.as_str());
let req = self.client.request(method, &url).headers(headers).body(body);

// TODO: Add admin auth here
match req.send().await {
Ok(resp) => {
let status = resp.status();
let headers = resp.headers().clone();
let body = resp.bytes().await.map_err(|e| {
warp::reject::custom(ApiError::InternalError(format!(
"Failed to read response body: {}",
e
)))
})?;

let mut response = warp::http::Response::new(body);
*response.status_mut() = status;
*response.headers_mut() = headers;

Ok(response)
},
Err(e) => {
error!("Error proxying request: {}", e);
Err(warp::reject::custom(ApiError::InternalError(e.to_string())))
},
}
}
}
92 changes: 92 additions & 0 deletions auth-server/src/server/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
//! Defines the server struct and associated functions
//!
//! The server is a dependency injection container for the authentication server
use crate::Cli;
use bb8::{Pool, PooledConnection};
use diesel::ConnectionError;
use diesel_async::{
pooled_connection::{AsyncDieselConnectionManager, ManagerConfig},
AsyncPgConnection,
};
use native_tls::TlsConnector;
use postgres_native_tls::MakeTlsConnector;
use reqwest::Client;
use std::sync::Arc;
use thiserror::Error;
use tracing::error;

mod handle_proxy;

/// The DB connection type
pub type DbConn<'a> = PooledConnection<'a, AsyncDieselConnectionManager<AsyncPgConnection>>;
/// The DB pool type
pub type DbPool = Pool<AsyncDieselConnectionManager<AsyncPgConnection>>;

/// Custom error type for server errors
#[derive(Error, Debug)]
pub enum ServerError {
/// Database connection error
#[error("Database connection error: {0}")]
DatabaseConnectionError(String),
}

/// The server struct that holds all the necessary components
pub struct Server {
/// The database connection pool
pub db_pool: Arc<DbPool>,
/// The URL of the relayer
pub relayer_url: String,
/// The admin key for the relayer
pub relayer_admin_key: String,
/// The HTTP client
pub client: Client,
}

impl Server {
/// Create a new server instance
pub async fn new(args: Cli) -> Result<Self, ServerError> {
let db_pool = create_db_pool(&args.database_url).await?;
Ok(Self {
db_pool: Arc::new(db_pool),
relayer_url: args.relayer_url,
relayer_admin_key: args.relayer_admin_key,
client: Client::new(),
})
}
}

/// Create a database pool
pub async fn create_db_pool(db_url: &str) -> Result<DbPool, ServerError> {
let mut conf = ManagerConfig::default();
conf.custom_setup = Box::new(move |url| Box::pin(establish_connection(url)));

let manager = AsyncDieselConnectionManager::new_with_config(db_url, conf);
Pool::builder()
.build(manager)
.await
.map_err(|e| ServerError::DatabaseConnectionError(e.to_string()))
}

/// Establish a connection to the database
pub async fn establish_connection(db_url: &str) -> Result<AsyncPgConnection, ConnectionError> {
// Build a TLS connector, we don't validate certificates for simplicity.
// Practically this is unnecessary because we will be limiting our traffic to
// within a siloed environment when deployed
let connector = TlsConnector::builder()
.danger_accept_invalid_certs(true)
.build()
.expect("failed to build tls connector");
let connector = MakeTlsConnector::new(connector);
let (client, conn) = tokio_postgres::connect(db_url, connector)
.await
.map_err(|e| ConnectionError::BadConnection(e.to_string()))?;

// Spawn the connection handle in a separate task
tokio::spawn(async move {
if let Err(e) = conn.await {
error!("Connection error: {}", e);
}
});

AsyncPgConnection::try_from(client).await
}

0 comments on commit 47ac156

Please sign in to comment.