diff --git a/Cargo.lock b/Cargo.lock index 96498c7..b7fde0f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2673,7 +2673,7 @@ checksum = "b1141d4d61095b28419e22cb0bbf02755f5e54e0526f97f1e3d1d160e60885fb" [[package]] name = "tfreg" -version = "0.1.0" +version = "0.1.1" dependencies = [ "anyhow", "axum", diff --git a/Cargo.toml b/Cargo.toml index 6a4f849..187afe7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "tfreg" description = "Registry serving terraform providers from github releases" -version = "0.1.0" +version = "0.1.1" edition = "2021" license = "MIT" repository = "https://github.com/mattclement/tfreg" diff --git a/src/middleware.rs b/src/middleware.rs index 255f329..81b24c3 100644 --- a/src/middleware.rs +++ b/src/middleware.rs @@ -100,6 +100,8 @@ async fn check_repo_permissions(token: String, repo: &Repo) -> Result<(), Status Ok(()) } +/// Extract the repo specified in the given URL path. This is designed to handle paths that point +/// at either the downloads API or the provider API. fn repo_from_path(path: &str) -> Option { let repo_components_in_url_path = path .trim_start_matches('/') @@ -115,3 +117,23 @@ fn repo_from_path(path: &str) -> Option { repo_components_in_url_path.last()?.to_string(), )) } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_repo_from_path() { + let expected = Repo::new("org".to_string(), "name".to_string()); + + assert_eq!( + expected, + repo_from_path("/downloads/org/terraform-provider-name/2.3.4/SHA256SUMS").unwrap() + ); + + assert_eq!( + expected, + repo_from_path("/org/terraform-provider-name/2.3.4/SHA256SUMS").unwrap() + ) + } +} diff --git a/src/oauth/mod.rs b/src/oauth/mod.rs index dc99b5a..716a340 100644 --- a/src/oauth/mod.rs +++ b/src/oauth/mod.rs @@ -8,7 +8,7 @@ use std::{ use oauth2::{ AuthorizationCode, CsrfToken, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, TokenResponse, }; -use orion::{aead, util::secure_rand_bytes}; +use orion::aead; use tokio::sync::RwLock; use crate::app_config::AppConfig; @@ -107,8 +107,7 @@ impl Authenticator { // a bit. If we want to support horizontal scaling of this server (lol) we will have to // write the actual token out so any other instance that has the secret key can use the // token. - let mut key_bytes = [0u8; 64]; - secure_rand_bytes(&mut key_bytes).map_err(OAuth2Error::Encryption)?; + let key_bytes = utils::generate_random_key()?; let key = utils::base64url_encode(key_bytes); let expires_at = diff --git a/src/oauth/utils.rs b/src/oauth/utils.rs index 171c42e..e9589d6 100644 --- a/src/oauth/utils.rs +++ b/src/oauth/utils.rs @@ -11,6 +11,7 @@ use crate::app_config::AppConfig; use super::{OAuth2Error, Result, AUTH_URL, TOKEN_URL}; const BASE64_FORMAT: base64::Config = base64::URL_SAFE; + pub fn base64url_encode>(key_bytes: T) -> String { base64::encode_config(key_bytes, BASE64_FORMAT) } @@ -42,3 +43,22 @@ pub fn current_epoch() -> u64 { .unwrap() .as_secs() } + +pub fn generate_random_key() -> Result<[u8; 64]> { + let mut key_bytes = [0u8; 64]; + orion::util::secure_rand_bytes(&mut key_bytes).map_err(OAuth2Error::Encryption)?; + Ok(key_bytes) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn base64() { + let key = generate_random_key().unwrap(); + let encoded = base64url_encode(key); + let decoded = base64url_decode(encoded).unwrap(); + assert_eq!(key.to_vec(), decoded); + } +}