Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

auth-server: server: Build Server type as dependency container #49

Merged
merged 1 commit into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions auth-server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@ tokio = { version = "1", features = ["full"] }
warp = "0.3"

# === Database === #
bb8 = "0.8"
diesel = { version = "2", features = ["postgres"] }
diesel-async = { version = "0.4", features = ["postgres", "bb8"] }
tokio-postgres = "0.7"
postgres-native-tls = "0.5"
native-tls = "0.2"

# === Misc Dependencies === #
bytes = "1.0"
Expand Down
82 changes: 32 additions & 50 deletions auth-server/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,36 +13,45 @@

#[allow(missing_docs, clippy::missing_docs_in_private_items)]
pub(crate) mod schema;
mod server;

use bytes::Bytes;
use clap::Parser;
use reqwest::{Client, Method, StatusCode};
use reqwest::StatusCode;
use serde_json::json;
use std::net::SocketAddr;
use std::sync::Arc;
use thiserror::Error;
use tracing::{error, info};
use warp::{Filter, Rejection, Reply};

use server::Server;

// -------
// | CLI |
// -------

/// The command line arguments for the auth server
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
pub struct Cli {
/// The database url
#[arg(long, env = "DATABASE_URL")]
pub database_url: String,
/// The encryption key used to encrypt/decrypt database values
#[arg(long, env = "ENCRYPTION_KEY")]
pub encryption_key: String,
/// The URL of the relayer
#[arg(long, env = "RELAYER_URL")]
relayer_url: String,
pub relayer_url: String,
/// The admin key for the relayer
#[arg(long, env = "RELAYER_ADMIN_KEY")]
relayer_admin_key: String,
pub relayer_admin_key: String,
/// The port to run the server on
#[arg(long, env = "PORT", default_value = "3030")]
port: u16,
#[arg(long, env = "PORT", default_value = "3000")]
pub port: u16,
/// Whether to enable datadog logging
#[arg(long)]
datadog_logging: bool,
pub datadog_logging: bool,
}

// -------------
Expand Down Expand Up @@ -70,9 +79,13 @@ impl warp::reject::Reject for ApiError {}
/// The main function for the auth server
#[tokio::main]
async fn main() {
let args = Args::parse();
let args = Cli::parse();
let listen_addr: SocketAddr = ([0, 0, 0, 0], args.port).into();

// Create the server
let server = Server::new(args).await.expect("Failed to create server");
let server = Arc::new(server);

// TODO: Setup logging

// --- Routes --- //
Expand All @@ -87,53 +100,22 @@ async fn main() {
.and(warp::method())
.and(warp::header::headers_cloned())
.and(warp::body::bytes())
.and(warp::any().map(move || args.relayer_url.clone()))
.and(warp::any().map(move || args.relayer_admin_key.clone()))
.and_then(handle_request);
.and(with_server(server.clone()))
.and_then(|path, method, headers, body, server: Arc<Server>| async move {
server.handle_proxy_request(path, method, headers, body).await
});

// Bind the server and listen
info!("Starting auth server on port {}", args.port);
info!("Starting auth server on port {}", listen_addr.port());
let routes = ping.or(proxy).recover(handle_rejection);
warp::serve(routes).bind(listen_addr).await;
}

/// Handle a request to the relayer
async fn handle_request(
path: warp::path::FullPath,
method: Method,
headers: warp::hyper::HeaderMap,
body: Bytes,
relayer_url: String,
relayer_admin_key: String,
) -> Result<impl Reply, Rejection> {
let client = Client::new();
let url = format!("{}{}", relayer_url, path.as_str());

let mut req = client.request(method, &url).headers(headers).body(body);
req = req.header("X-Admin-Key", &relayer_admin_key);

match req.send().await {
Ok(resp) => {
let status = resp.status();
let headers = resp.headers().clone();
let body = resp.bytes().await.map_err(|e| {
warp::reject::custom(ApiError::InternalError(format!(
"Failed to read response body: {}",
e
)))
})?;

let mut response = warp::http::Response::new(body);
*response.status_mut() = status;
*response.headers_mut() = headers;

Ok(response)
},
Err(e) => {
error!("Error proxying request: {}", e);
Err(warp::reject::custom(ApiError::InternalError(e.to_string())))
},
}
/// Helper function to pass the server to filters
fn with_server(
server: Arc<Server>,
) -> impl Filter<Extract = (Arc<Server>,), Error = std::convert::Infallible> + Clone {
warp::any().map(move || server.clone())
}

/// Handle a rejection from an endpoint handler
Expand Down
52 changes: 52 additions & 0 deletions auth-server/src/server/handle_proxy.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
//! Handler code for proxied relayer requests
//!
//! At a high level the server must first authenticate the request, then forward
//! it to the relayer with admin authentication

use bytes::Bytes;
use http::Method;
use tracing::error;
use warp::{reject::Rejection, reply::Reply};

use crate::ApiError;

use super::Server;

/// Handle a proxied request
impl Server {
/// Handle a request meant to be authenticated and proxied to the relayer
pub async fn handle_proxy_request(
&self,
path: warp::path::FullPath,
method: Method,
headers: warp::hyper::HeaderMap,
body: Bytes,
) -> Result<impl Reply, Rejection> {
let url = format!("{}{}", self.relayer_url, path.as_str());
let req = self.client.request(method, &url).headers(headers).body(body);

// TODO: Add admin auth here
match req.send().await {
Ok(resp) => {
let status = resp.status();
let headers = resp.headers().clone();
let body = resp.bytes().await.map_err(|e| {
warp::reject::custom(ApiError::InternalError(format!(
"Failed to read response body: {}",
e
)))
})?;

let mut response = warp::http::Response::new(body);
*response.status_mut() = status;
*response.headers_mut() = headers;

Ok(response)
},
Err(e) => {
error!("Error proxying request: {}", e);
Err(warp::reject::custom(ApiError::InternalError(e.to_string())))
},
}
}
}
92 changes: 92 additions & 0 deletions auth-server/src/server/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
//! Defines the server struct and associated functions
//!
//! The server is a dependency injection container for the authentication server
use crate::Cli;
use bb8::{Pool, PooledConnection};
use diesel::ConnectionError;
use diesel_async::{
pooled_connection::{AsyncDieselConnectionManager, ManagerConfig},
AsyncPgConnection,
};
use native_tls::TlsConnector;
use postgres_native_tls::MakeTlsConnector;
use reqwest::Client;
use std::sync::Arc;
use thiserror::Error;
use tracing::error;

mod handle_proxy;

/// The DB connection type
pub type DbConn<'a> = PooledConnection<'a, AsyncDieselConnectionManager<AsyncPgConnection>>;

Check failure on line 21 in auth-server/src/server/mod.rs

View workflow job for this annotation

GitHub Actions / clippy

type alias `DbConn` is never used

error: type alias `DbConn` is never used --> auth-server/src/server/mod.rs:21:10 | 21 | pub type DbConn<'a> = PooledConnection<'a, AsyncDieselConnectionManager<AsyncPgConnection>>; | ^^^^^^ | = note: `-D dead-code` implied by `-D warnings` = help: to override `-D warnings` add `#[allow(dead_code)]`
/// The DB pool type
pub type DbPool = Pool<AsyncDieselConnectionManager<AsyncPgConnection>>;

/// Custom error type for server errors
#[derive(Error, Debug)]
pub enum ServerError {
/// Database connection error
#[error("Database connection error: {0}")]
DatabaseConnectionError(String),
}

/// The server struct that holds all the necessary components
pub struct Server {
/// The database connection pool
pub db_pool: Arc<DbPool>,
/// The URL of the relayer
pub relayer_url: String,
/// The admin key for the relayer
pub relayer_admin_key: String,
/// The HTTP client
pub client: Client,
}

impl Server {
/// Create a new server instance
pub async fn new(args: Cli) -> Result<Self, ServerError> {
let db_pool = create_db_pool(&args.database_url).await?;
Ok(Self {
db_pool: Arc::new(db_pool),
relayer_url: args.relayer_url,
relayer_admin_key: args.relayer_admin_key,
client: Client::new(),
})
}
}

/// Create a database pool
pub async fn create_db_pool(db_url: &str) -> Result<DbPool, ServerError> {
let mut conf = ManagerConfig::default();
conf.custom_setup = Box::new(move |url| Box::pin(establish_connection(url)));

let manager = AsyncDieselConnectionManager::new_with_config(db_url, conf);
Pool::builder()
.build(manager)
.await
.map_err(|e| ServerError::DatabaseConnectionError(e.to_string()))
}

/// Establish a connection to the database
pub async fn establish_connection(db_url: &str) -> Result<AsyncPgConnection, ConnectionError> {
// Build a TLS connector, we don't validate certificates for simplicity.
// Practically this is unnecessary because we will be limiting our traffic to
// within a siloed environment when deployed
let connector = TlsConnector::builder()
.danger_accept_invalid_certs(true)
.build()
.expect("failed to build tls connector");
let connector = MakeTlsConnector::new(connector);
let (client, conn) = tokio_postgres::connect(db_url, connector)
.await
.map_err(|e| ConnectionError::BadConnection(e.to_string()))?;

// Spawn the connection handle in a separate task
tokio::spawn(async move {
if let Err(e) = conn.await {
error!("Connection error: {}", e);
}
});

AsyncPgConnection::try_from(client).await
}
Loading