Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove unwraps via '?' with anyhow crate for example-oauth #2069

Merged
merged 5 commits into from
Jul 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/oauth/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ edition = "2021"
publish = false

[dependencies]
anyhow = "1"
async-session = "3.0.0"
axum = { path = "../../axum" }
axum-extra = { path = "../../axum-extra", features = ["typed-header"] }
Expand Down
109 changes: 84 additions & 25 deletions examples/oauth/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
//! CLIENT_ID=REPLACE_ME CLIENT_SECRET=REPLACE_ME cargo run -p example-oauth
//! ```

use anyhow::{Context, Result};
use async_session::{MemoryStore, Session, SessionStore};
use axum::{
async_trait,
Expand All @@ -18,7 +19,7 @@ use axum::{
RequestPartsExt, Router,
};
use axum_extra::{headers, typed_header::TypedHeaderRejectionReason, TypedHeader};
use http::{header, request::Parts};
use http::{header, request::Parts, StatusCode};
use oauth2::{
basic::BasicClient, reqwest::async_http_client, AuthUrl, AuthorizationCode, ClientId,
ClientSecret, CsrfToken, RedirectUrl, Scope, TokenResponse, TokenUrl,
Expand All @@ -41,7 +42,7 @@ async fn main() {

// `MemoryStore` is just used as an example. Don't use this in production.
let store = MemoryStore::new();
let oauth_client = oauth_client();
let oauth_client = oauth_client().unwrap();
let app_state = AppState {
store,
oauth_client,
Expand All @@ -57,9 +58,21 @@ async fn main() {

let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
.await
.context("failed to bind TcpListener")
.unwrap();

tracing::debug!(
"listening on {}",
listener
.local_addr()
.context("failed to return local address")
.unwrap()
);

axum::serve(listener, app)
.await
.context("failed to serve service")
.unwrap();
tracing::debug!("listening on {}", listener.local_addr().unwrap());
axum::serve(listener, app).await.unwrap();
}

#[derive(Clone)]
Expand All @@ -80,16 +93,16 @@ impl FromRef<AppState> for BasicClient {
}
}

fn oauth_client() -> BasicClient {
fn oauth_client() -> Result<BasicClient, AppError> {
// Environment variables (* = required):
// *"CLIENT_ID" "REPLACE_ME";
// *"CLIENT_SECRET" "REPLACE_ME";
// "REDIRECT_URL" "http://127.0.0.1:3000/auth/authorized";
// "AUTH_URL" "https://discord.com/api/oauth2/authorize?response_type=code";
// "TOKEN_URL" "https://discord.com/api/oauth2/token";

let client_id = env::var("CLIENT_ID").expect("Missing CLIENT_ID!");
let client_secret = env::var("CLIENT_SECRET").expect("Missing CLIENT_SECRET!");
let client_id = env::var("CLIENT_ID").context("Missing CLIENT_ID!")?;
let client_secret = env::var("CLIENT_SECRET").context("Missing CLIENT_SECRET!")?;
let redirect_url = env::var("REDIRECT_URL")
.unwrap_or_else(|_| "http://127.0.0.1:3000/auth/authorized".to_string());

Expand All @@ -100,13 +113,15 @@ fn oauth_client() -> BasicClient {
let token_url = env::var("TOKEN_URL")
.unwrap_or_else(|_| "https://discord.com/api/oauth2/token".to_string());

BasicClient::new(
Ok(BasicClient::new(
ClientId::new(client_id),
Some(ClientSecret::new(client_secret)),
AuthUrl::new(auth_url).unwrap(),
Some(TokenUrl::new(token_url).unwrap()),
AuthUrl::new(auth_url).context("failed to create new authorization server URL")?,
Some(TokenUrl::new(token_url).context("failed to create new token endpoint URL")?),
)
.set_redirect_uri(RedirectUrl::new(redirect_url).unwrap())
.set_redirect_uri(
RedirectUrl::new(redirect_url).context("failed to create new redirection URL")?,
))
}

// The user data we'll get back from Discord.
Expand Down Expand Up @@ -151,17 +166,27 @@ async fn protected(user: User) -> impl IntoResponse {
async fn logout(
State(store): State<MemoryStore>,
TypedHeader(cookies): TypedHeader<headers::Cookie>,
) -> impl IntoResponse {
let cookie = cookies.get(COOKIE_NAME).unwrap();
let session = match store.load_session(cookie.to_string()).await.unwrap() {
) -> Result<impl IntoResponse, AppError> {
let cookie = cookies
.get(COOKIE_NAME)
.context("unexpected error getting cookie name")?;

let session = match store
.load_session(cookie.to_string())
.await
.context("failed to load session")?
{
Some(s) => s,
// No session active, just redirect
None => return Redirect::to("/"),
None => return Ok(Redirect::to("/")),
};

store.destroy_session(session).await.unwrap();
store
.destroy_session(session)
.await
.context("failed to destroy session")?;

Redirect::to("/")
Ok(Redirect::to("/"))
}

#[derive(Debug, Deserialize)]
Expand All @@ -175,13 +200,13 @@ async fn login_authorized(
Query(query): Query<AuthRequest>,
State(store): State<MemoryStore>,
State(oauth_client): State<BasicClient>,
) -> impl IntoResponse {
) -> Result<impl IntoResponse, AppError> {
// Get an auth token
let token = oauth_client
.exchange_code(AuthorizationCode::new(query.code.clone()))
.request_async(async_http_client)
.await
.unwrap();
.context("failed in sending request request to authorization server")?;

// Fetch user data from discord
let client = reqwest::Client::new();
Expand All @@ -191,26 +216,35 @@ async fn login_authorized(
.bearer_auth(token.access_token().secret())
.send()
.await
.unwrap()
.context("failed in sending request to target Url")?
.json::<User>()
.await
.unwrap();
.context("failed to deserialize response as JSON")?;

// Create a new session filled with user data
let mut session = Session::new();
session.insert("user", &user_data).unwrap();
session
.insert("user", &user_data)
.context("failed in inserting serialized value into session")?;

// Store session and get corresponding cookie
let cookie = store.store_session(session).await.unwrap().unwrap();
let cookie = store
.store_session(session)
.await
.context("failed to store session")?
.context("unexpected error retrieving cookie value")?;

// Build the cookie
let cookie = format!("{}={}; SameSite=Lax; Path=/", COOKIE_NAME, cookie);

// Set cookie
let mut headers = HeaderMap::new();
headers.insert(SET_COOKIE, cookie.parse().unwrap());
headers.insert(
SET_COOKIE,
cookie.parse().context("failed to parse cookie")?,
);

(headers, Redirect::to("/"))
Ok((headers, Redirect::to("/")))
}

struct AuthRedirect;
Expand Down Expand Up @@ -256,3 +290,28 @@ where
Ok(user)
}
}

// Use anyhow, define error and enable '?'
// For a simplified example of using anyhow in axum check /examples/anyhow-error-response
#[derive(Debug)]
struct AppError(anyhow::Error);

// Tell axum how to convert `AppError` into a response.
impl IntoResponse for AppError {
fn into_response(self) -> Response {
tracing::error!("Application error: {:#}", self.0);

(StatusCode::INTERNAL_SERVER_ERROR, "Something went wrong").into_response()
}
}

// This enables using `?` on functions that return `Result<_, anyhow::Error>` to turn them into
// `Result<_, AppError>`. That way you don't need to do that manually.
impl<E> From<E> for AppError
where
E: Into<anyhow::Error>,
{
fn from(err: E) -> Self {
Self(err.into())
}
}