diff --git a/examples/oauth/Cargo.toml b/examples/oauth/Cargo.toml index bb0aafc92a..fb0c2f6df1 100644 --- a/examples/oauth/Cargo.toml +++ b/examples/oauth/Cargo.toml @@ -5,6 +5,7 @@ edition = "2021" publish = false [dependencies] +anyhow = "1" async-session = "3.0.0" axum = { path = "../../axum" } axum-extra = { path = "../../axum-extra", features = ["typed-header"] } diff --git a/examples/oauth/src/main.rs b/examples/oauth/src/main.rs index ba879e7be8..dadfba8f7c 100644 --- a/examples/oauth/src/main.rs +++ b/examples/oauth/src/main.rs @@ -8,6 +8,7 @@ //! CLIENT_ID=REPLACE_ME CLIENT_SECRET=REPLACE_ME cargo run -p example-oauth //! ``` +use anyhow::{Context, Result}; use async_session::{MemoryStore, Session, SessionStore}; use axum::{ async_trait, @@ -18,7 +19,7 @@ use axum::{ RequestPartsExt, Router, }; use axum_extra::{headers, typed_header::TypedHeaderRejectionReason, TypedHeader}; -use http::{header, request::Parts}; +use http::{header, request::Parts, StatusCode}; use oauth2::{ basic::BasicClient, reqwest::async_http_client, AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, RedirectUrl, Scope, TokenResponse, TokenUrl, @@ -41,7 +42,7 @@ async fn main() { // `MemoryStore` is just used as an example. Don't use this in production. let store = MemoryStore::new(); - let oauth_client = oauth_client(); + let oauth_client = oauth_client().unwrap(); let app_state = AppState { store, oauth_client, @@ -57,9 +58,21 @@ async fn main() { let listener = tokio::net::TcpListener::bind("127.0.0.1:3000") .await + .context("failed to bind TcpListener") + .unwrap(); + + tracing::debug!( + "listening on {}", + listener + .local_addr() + .context("failed to return local address") + .unwrap() + ); + + axum::serve(listener, app) + .await + .context("failed to serve service") .unwrap(); - tracing::debug!("listening on {}", listener.local_addr().unwrap()); - axum::serve(listener, app).await.unwrap(); } #[derive(Clone)] @@ -80,7 +93,7 @@ impl FromRef for BasicClient { } } -fn oauth_client() -> BasicClient { +fn oauth_client() -> Result { // Environment variables (* = required): // *"CLIENT_ID" "REPLACE_ME"; // *"CLIENT_SECRET" "REPLACE_ME"; @@ -88,8 +101,8 @@ fn oauth_client() -> BasicClient { // "AUTH_URL" "https://discord.com/api/oauth2/authorize?response_type=code"; // "TOKEN_URL" "https://discord.com/api/oauth2/token"; - let client_id = env::var("CLIENT_ID").expect("Missing CLIENT_ID!"); - let client_secret = env::var("CLIENT_SECRET").expect("Missing CLIENT_SECRET!"); + let client_id = env::var("CLIENT_ID").context("Missing CLIENT_ID!")?; + let client_secret = env::var("CLIENT_SECRET").context("Missing CLIENT_SECRET!")?; let redirect_url = env::var("REDIRECT_URL") .unwrap_or_else(|_| "http://127.0.0.1:3000/auth/authorized".to_string()); @@ -100,13 +113,15 @@ fn oauth_client() -> BasicClient { let token_url = env::var("TOKEN_URL") .unwrap_or_else(|_| "https://discord.com/api/oauth2/token".to_string()); - BasicClient::new( + Ok(BasicClient::new( ClientId::new(client_id), Some(ClientSecret::new(client_secret)), - AuthUrl::new(auth_url).unwrap(), - Some(TokenUrl::new(token_url).unwrap()), + AuthUrl::new(auth_url).context("failed to create new authorization server URL")?, + Some(TokenUrl::new(token_url).context("failed to create new token endpoint URL")?), ) - .set_redirect_uri(RedirectUrl::new(redirect_url).unwrap()) + .set_redirect_uri( + RedirectUrl::new(redirect_url).context("failed to create new redirection URL")?, + )) } // The user data we'll get back from Discord. @@ -151,17 +166,27 @@ async fn protected(user: User) -> impl IntoResponse { async fn logout( State(store): State, TypedHeader(cookies): TypedHeader, -) -> impl IntoResponse { - let cookie = cookies.get(COOKIE_NAME).unwrap(); - let session = match store.load_session(cookie.to_string()).await.unwrap() { +) -> Result { + let cookie = cookies + .get(COOKIE_NAME) + .context("unexpected error getting cookie name")?; + + let session = match store + .load_session(cookie.to_string()) + .await + .context("failed to load session")? + { Some(s) => s, // No session active, just redirect - None => return Redirect::to("/"), + None => return Ok(Redirect::to("/")), }; - store.destroy_session(session).await.unwrap(); + store + .destroy_session(session) + .await + .context("failed to destroy session")?; - Redirect::to("/") + Ok(Redirect::to("/")) } #[derive(Debug, Deserialize)] @@ -175,13 +200,13 @@ async fn login_authorized( Query(query): Query, State(store): State, State(oauth_client): State, -) -> impl IntoResponse { +) -> Result { // Get an auth token let token = oauth_client .exchange_code(AuthorizationCode::new(query.code.clone())) .request_async(async_http_client) .await - .unwrap(); + .context("failed in sending request request to authorization server")?; // Fetch user data from discord let client = reqwest::Client::new(); @@ -191,26 +216,35 @@ async fn login_authorized( .bearer_auth(token.access_token().secret()) .send() .await - .unwrap() + .context("failed in sending request to target Url")? .json::() .await - .unwrap(); + .context("failed to deserialize response as JSON")?; // Create a new session filled with user data let mut session = Session::new(); - session.insert("user", &user_data).unwrap(); + session + .insert("user", &user_data) + .context("failed in inserting serialized value into session")?; // Store session and get corresponding cookie - let cookie = store.store_session(session).await.unwrap().unwrap(); + let cookie = store + .store_session(session) + .await + .context("failed to store session")? + .context("unexpected error retrieving cookie value")?; // Build the cookie let cookie = format!("{}={}; SameSite=Lax; Path=/", COOKIE_NAME, cookie); // Set cookie let mut headers = HeaderMap::new(); - headers.insert(SET_COOKIE, cookie.parse().unwrap()); + headers.insert( + SET_COOKIE, + cookie.parse().context("failed to parse cookie")?, + ); - (headers, Redirect::to("/")) + Ok((headers, Redirect::to("/"))) } struct AuthRedirect; @@ -256,3 +290,28 @@ where Ok(user) } } + +// Use anyhow, define error and enable '?' +// For a simplified example of using anyhow in axum check /examples/anyhow-error-response +#[derive(Debug)] +struct AppError(anyhow::Error); + +// Tell axum how to convert `AppError` into a response. +impl IntoResponse for AppError { + fn into_response(self) -> Response { + tracing::error!("Application error: {:#}", self.0); + + (StatusCode::INTERNAL_SERVER_ERROR, "Something went wrong").into_response() + } +} + +// This enables using `?` on functions that return `Result<_, anyhow::Error>` to turn them into +// `Result<_, AppError>`. That way you don't need to do that manually. +impl From for AppError +where + E: Into, +{ + fn from(err: E) -> Self { + Self(err.into()) + } +}