diff --git a/Cargo.toml b/Cargo.toml index 7b04223..9f6e00c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,7 +4,45 @@ members = [ "core", "example", "geo", - "lettre", "queue", + "smtp", ] resolver = "2" + +[workspace.dependencies] +async-session = "3" +async-trait = "0.1" +axum = "0.6" +base64 = "0.21" +bcrypt = "0.14" +bytes = "1.0" +derivative = "2.2" +derive-getters = "0.2.0" +derive_more = { version = "0.99.0", default-features = false } +env_logger = "0.8" +futures = "0.3" +http = "0.2.8" +http-body = "0.4" +hyper = { version = "0.14", features = ["full"] } +lettre = { version = "0.10.0-rc.6", default-features = false } +log = "0.4" +lru_time_cache = "0.11" +mime = "0.3" +paste = "1.0" +quoted_printable = "0.5" +rand = "0.8" +regex = "1" +reqwest = { version = "0.11", default-features = false, features = ["rustls-tls"] } +serde_json = "1" +serde_test = "1" +serde_urlencoded = "0.7" +serde = { version = "1", features = ["derive"] } +sqlx = "0.7" +temp-env = "0.3.2" +thiserror = "1.0" +time = "0.3" +tokio = "1" +tower = "0.4" +tower-http = { version = "0.3", features = ["cors"] } +url = "2.3" +uuid = { version = "1.0", default-features = false, features = ["serde", "std", "v4"] } diff --git a/authn/Cargo.toml b/authn/Cargo.toml index f4a9781..4c6131c 100644 --- a/authn/Cargo.toml +++ b/authn/Cargo.toml @@ -10,39 +10,39 @@ publish = false default = ["postgres"] postgres = ["iii-iv-core/postgres", "sqlx/postgres"] sqlite = ["iii-iv-core/sqlite", "sqlx/sqlite"] -testutils = ["dep:url", "iii-iv-core/sqlite", "iii-iv-core/testutils", "iii-iv-lettre/testutils"] +testutils = ["dep:url", "iii-iv-core/sqlite", "iii-iv-core/testutils", "iii-iv-smtp/testutils"] [dependencies] -async-trait = "0.1" -axum = "0.6" -base64 = "0.21" -bcrypt = "0.14" -derivative = "2.2" -futures = "0.3" -log = "0.4" -http = "0.2.8" +async-trait = { workspace = true } +axum = { workspace = true } +base64 = { workspace = true } +bcrypt = { workspace = true } +derivative = { workspace = true } +futures = { workspace = true } +http = { workspace = true } iii-iv-core = { path = "../core" } -iii-iv-lettre = { path = "../lettre" } -lru_time_cache = "0.11" -rand = "0.8" -serde = { version = "1", features = ["derive"] } -serde_urlencoded = "0.7" -time = "0.3" -url = { version = "2.3", optional = true } +iii-iv-smtp = { path = "../smtp" } +log = { workspace = true } +lru_time_cache = { workspace = true } +rand = { workspace = true } +serde_urlencoded = { workspace = true } +serde = { workspace = true } +time = { workspace = true } +url = { workspace = true, optional = true } [dependencies.sqlx] -version = "0.7" +workspace = true optional = true features = ["runtime-tokio-rustls", "time"] [dev-dependencies] -futures = "0.3" +futures = { workspace = true } iii-iv-core = { path = "../core", features = ["sqlite", "testutils"] } -iii-iv-lettre = { path = "../lettre", features = ["testutils"] } -temp-env = "0.3.2" -tokio = { version = "1", features = ["rt", "macros", "rt-multi-thread"] } -url = "2.3" +iii-iv-smtp = { path = "../smtp", features = ["testutils"] } +temp-env = { workspace = true } +tokio = { workspace = true, features = ["rt", "macros", "rt-multi-thread"] } +url = { workspace = true } [dev-dependencies.sqlx] -version = "0.7" +workspace = true features = ["runtime-tokio-rustls", "sqlite", "time"] diff --git a/authn/src/driver/email.rs b/authn/src/driver/email.rs index 2ac03c5..1f74a06 100644 --- a/authn/src/driver/email.rs +++ b/authn/src/driver/email.rs @@ -18,7 +18,8 @@ use crate::driver::DriverResult; use iii_iv_core::model::{EmailAddress, Username}; use iii_iv_core::rest::BaseUrls; -use iii_iv_lettre::{EmailTemplate, SmtpMailer}; +use iii_iv_smtp::driver::SmtpMailer; +use iii_iv_smtp::model::EmailTemplate; /// Sends the activation code `code` for `username` to the given `email` address. /// @@ -51,7 +52,8 @@ pub(crate) mod testutils { //! Utilities to help testing services that integrate with the `authn` features. use super::*; - use iii_iv_lettre::testutils::{parse_message, RecorderSmtpMailer}; + use iii_iv_smtp::driver::testutils::RecorderSmtpMailer; + use iii_iv_smtp::model::testutils::parse_message; use url::Url; /// Creates an email activation template to capture activation codes during tests. @@ -105,7 +107,8 @@ pub(crate) mod testutils { mod tests { use super::testutils::*; use super::*; - use iii_iv_lettre::testutils::*; + use iii_iv_smtp::driver::testutils::RecorderSmtpMailer; + use iii_iv_smtp::model::testutils::parse_message; #[tokio::test] async fn test_send_activation_code() { diff --git a/authn/src/driver/mod.rs b/authn/src/driver/mod.rs index 1158313..4d88ad3 100644 --- a/authn/src/driver/mod.rs +++ b/authn/src/driver/mod.rs @@ -24,7 +24,8 @@ use iii_iv_core::db::{Db, DbError, TxExecutor}; use iii_iv_core::driver::{DriverError, DriverResult}; use iii_iv_core::env::get_optional_var; use iii_iv_core::rest::BaseUrls; -use iii_iv_lettre::{EmailTemplate, SmtpMailer}; +use iii_iv_smtp::driver::SmtpMailer; +use iii_iv_smtp::model::EmailTemplate; use log::warn; use lru_time_cache::LruCache; use std::sync::Arc; diff --git a/authn/src/driver/testutils.rs b/authn/src/driver/testutils.rs index 9d9e20f..d5e87f8 100644 --- a/authn/src/driver/testutils.rs +++ b/authn/src/driver/testutils.rs @@ -26,7 +26,7 @@ use iii_iv_core::db::Executor; use iii_iv_core::model::EmailAddress; use iii_iv_core::model::Username; use iii_iv_core::rest::BaseUrls; -use iii_iv_lettre::testutils::RecorderSmtpMailer; +use iii_iv_smtp::driver::testutils::RecorderSmtpMailer; use std::sync::Arc; /// State of a running test. diff --git a/authn/src/rest/testutils.rs b/authn/src/rest/testutils.rs index 2823ae3..5a5cf7a 100644 --- a/authn/src/rest/testutils.rs +++ b/authn/src/rest/testutils.rs @@ -32,7 +32,7 @@ use { iii_iv_core::clocks::testutils::MonotonicClock, iii_iv_core::db::{Db, DbError}, iii_iv_core::rest::BaseUrls, - iii_iv_lettre::testutils::RecorderSmtpMailer, + iii_iv_smtp::driver::testutils::RecorderSmtpMailer, std::sync::Arc, }; diff --git a/core/Cargo.toml b/core/Cargo.toml index a6a9fc8..a2f18aa 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -38,35 +38,35 @@ testutils = [ ] [dependencies] -async-trait = "0.1" -axum = "0.6" -base64 = { version = "0.21", optional = true } -bytes = { version = "1.0", optional = true } -derivative = "2.2" -env_logger = { version = "0.8", optional = true } -futures = { version = "0.3", optional = true } -http = "0.2.8" -http-body = { version = "0.4", optional = true } -hyper = { version = "0.14", optional = true, features = ["full"] } -log = { version = "0.4", optional = true } -mime = { version = "0.3", optional = true } -paste = { version = "1.0", optional = true } -rand = { version = "0.8", optional = true } -regex = { version = "1", optional = true } -serde = { version = "1", features = ["derive"] } -serde_json = "1" -serde_urlencoded = { version = "0.7", optional = true } -sqlx = "0.7" -thiserror = "1.0" -time = "0.3" -tokio = { version = "1", optional = true } -tower = { version = "0.4", optional = true } -url = "2.3" +async-trait = { workspace = true } +axum = { workspace = true } +base64 = { workspace = true, optional = true } +bytes = { workspace = true, optional = true } +derivative = { workspace = true } +env_logger = { workspace = true, optional = true } +futures = { workspace = true, optional = true } +http-body = { workspace = true, optional = true } +http = { workspace = true } +hyper = { workspace = true, optional = true } +log = { workspace = true, optional = true } +mime = { workspace = true, optional = true } +paste = { workspace = true, optional = true } +rand = { workspace = true, optional = true } +regex = { workspace = true, optional = true } +serde_json = { workspace = true } +serde_urlencoded = { workspace = true, optional = true } +serde = { workspace = true } +sqlx = { workspace = true } +thiserror = { workspace = true } +time = { workspace = true } +tokio = { workspace = true, optional = true } +tower = { workspace = true, optional = true } +url = { workspace = true } [dev-dependencies] -env_logger = "0.8" -paste = "1.0" -rand = "0.8" -serde_test = "1" -temp-env = "0.3.2" -tokio = { version = "1", features = ["macros"] } +env_logger = { workspace = true } +paste = { workspace = true } +rand = { workspace = true } +serde_test = { workspace = true } +temp-env = { workspace = true } +tokio = { workspace = true, features = ["macros"] } diff --git a/core/src/db.rs b/core/src/db.rs index c696f61..ca4d2cf 100644 --- a/core/src/db.rs +++ b/core/src/db.rs @@ -119,6 +119,25 @@ pub trait Db { async fn begin(&self) -> DbResult; } +/// Parses a `COUNT` result as a `usize`. +pub fn count_as_usize(count: i64) -> DbResult { + match usize::try_from(count) { + Ok(count) => Ok(count), + Err(_) => Err(DbError::BackendError( + "COUNT should have returned a positive value that fits in usize".to_owned(), + )), + } +} + +/// Helper to verify that an insert and/or update opeeration affected just one row. +pub fn ensure_one_upsert(rows_affected: u64) -> DbResult<()> { + if rows_affected != 1 { + Err(DbError::BackendError(format!("Expected 1 new/modified row but got {}", rows_affected))) + } else { + Ok(()) + } +} + /// Macros to help instantiate tests for multiple database systems. #[cfg(any(test, feature = "testutils"))] pub mod testutils { diff --git a/example/Cargo.toml b/example/Cargo.toml index e32e519..2a62e42 100644 --- a/example/Cargo.toml +++ b/example/Cargo.toml @@ -7,29 +7,28 @@ edition = "2021" publish = false [dependencies] -async-session = "3" -async-trait = "0.1" -axum = "0.6" -derive-getters = "0.2.0" -env_logger = "0.8" -futures = "0.3" -hyper = { version = "0.14", features = ["full"] } -log = "0.4" -serde = { version = "1", features = ["derive"] } -serde_json = "1" -sqlx = { version = "0.7", features = ["runtime-tokio-rustls", "sqlite"] } -thiserror = "1.0" -tokio = { version = "1", features = ["rt", "macros", "rt-multi-thread"] } -tower-http = { version = "0.3", features = ["cors"] } -url = { version = "2.3", features = ["serde"] } +async-session = { workspace = true } +async-trait = { workspace = true } +axum = { workspace = true } +derive-getters = { workspace = true } +env_logger = { workspace = true } +futures = { workspace = true } +hyper = { workspace = true } +log = { workspace = true } +serde_json = { workspace = true } +serde = { workspace = true } +sqlx = { workspace = true, features = ["runtime-tokio-rustls", "sqlite"] } +thiserror = { workspace = true } +tokio = { workspace = true, features = ["rt", "macros", "rt-multi-thread"] } +tower-http = { workspace = true } +url = { workspace = true, features = ["serde"] } [dependencies.iii-iv-core] path = "../core" features = ["postgres"] [dependencies.derive_more] -version = "0.99.0" -default-features = false +workspace = true features = ["as_ref", "constructor"] [dev-dependencies.iii-iv-core] diff --git a/geo/Cargo.toml b/geo/Cargo.toml index 9d78635..3372f4f 100644 --- a/geo/Cargo.toml +++ b/geo/Cargo.toml @@ -11,21 +11,21 @@ default = [] testutils = [] [dependencies] -async-trait = "0.1" -bytes = "1.0" -derivative = "2.2" -futures = "0.3" +async-trait = { workspace = true } +bytes = { workspace = true } +derivative = { workspace = true } +futures = { workspace = true } iii-iv-core = { path = "../core" } -log = "0.4" -lru_time_cache = "0.11" -reqwest = { version = "0.11", default-features = false, features = ["rustls-tls"] } -serde = "1" -serde_json = "1" -time = "0.3" +log = { workspace = true } +lru_time_cache = { workspace = true } +reqwest = { workspace = true } +serde_json = { workspace = true } +serde = { workspace = true } +time = { workspace = true } [dev-dependencies] iii-iv-core = { path = "../core", features = ["testutils"] } -serde_test = "1" -temp-env = "0.3.2" -time = { version = "0.3", features = ["macros"] } -tokio = { version = "1", features = ["rt", "macros", "rt-multi-thread"] } +serde_test = { workspace = true } +temp-env = { workspace = true } +time = { workspace = true, features = ["macros"] } +tokio = { workspace = true, features = ["rt", "macros", "rt-multi-thread"] } diff --git a/lettre/Cargo.toml b/lettre/Cargo.toml deleted file mode 100644 index 2fcd020..0000000 --- a/lettre/Cargo.toml +++ /dev/null @@ -1,35 +0,0 @@ -[package] -name = "iii-iv-lettre" -version = "0.0.0" -description = "III-IV: SMTP support" -authors = ["Julio Merino "] -edition = "2021" -publish = false - -[features] -default = [] -testutils = ["dep:futures", "dep:quoted_printable"] - -[dependencies] -async-trait = "0.1" -axum = "0.6" -derivative = "2.2" -futures = { version = "0.3", optional = true } -http = "0.2.8" -quoted_printable = { version = "0.5", optional = true } -serde_json = "1" -thiserror = "1.0" -time = "0.3" -iii-iv-core = { path = "../core" } - -[dependencies.lettre] -version = "0.10.0-rc.6" -default-features = false -features = ["builder", "hostname", "pool", "rustls-tls", "smtp-transport", "tokio1-rustls-tls"] - -[dev-dependencies] -futures = "0.3" -iii-iv-core = { path = "../core", features = ["testutils"] } -quoted_printable = "0.5" -temp-env = "0.3.2" -tokio = { version = "1", features = ["macros"] } diff --git a/lettre/src/lib.rs b/lettre/src/lib.rs deleted file mode 100644 index 3580066..0000000 --- a/lettre/src/lib.rs +++ /dev/null @@ -1,453 +0,0 @@ -// III-IV -// Copyright 2023 Julio Merino -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not -// use this file except in compliance with the License. You may obtain a copy -// of the License at: -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -// License for the specific language governing permissions and limitations -// under the License. - -//! Utilities to send messages over email. - -// Keep these in sync with other top-level files. -#![warn(anonymous_parameters, bad_style, clippy::missing_docs_in_private_items, missing_docs)] -#![warn(unused, unused_extern_crates, unused_import_braces, unused_qualifications)] -#![warn(unsafe_code)] - -use async_trait::async_trait; -use derivative::Derivative; -use iii_iv_core::driver::{DriverError, DriverResult}; -use iii_iv_core::env::get_required_var; -use iii_iv_core::model::EmailAddress; -use iii_iv_core::template; -use lettre::message::header::ContentTransferEncoding; -use lettre::message::Body; -pub use lettre::message::{Mailbox, Message}; -use lettre::transport::smtp::authentication::Credentials; -use lettre::{AsyncSmtpTransport, AsyncTransport, Tokio1Executor}; - -/// Options to establish an SMTP connection. -#[derive(Derivative)] -#[derivative(Debug)] -#[cfg_attr(test, derivative(PartialEq))] -pub struct SmtpOptions { - /// SMTP server to use. - pub relay: String, - - /// Username for logging into the SMTP server. - pub username: String, - - /// Password for logging into the SMTP server. - #[derivative(Debug = "ignore")] - pub password: String, -} - -impl SmtpOptions { - /// Initializes a set of options from environment variables whose name is prefixed with the - /// given `prefix`. - /// - /// This will use variables such as `_RELAY`, `_USERNAME` and - /// `_PASSWORD`. - pub fn from_env(prefix: &str) -> Result { - Ok(Self { - relay: get_required_var::(prefix, "RELAY")?, - username: get_required_var::(prefix, "USERNAME")?, - password: get_required_var::(prefix, "PASSWORD")?, - }) - } -} - -/// Trait to abstract the integration with the mailer. -#[async_trait] -pub trait SmtpMailer { - /// Sends a message over SMTP. - async fn send(&self, message: Message) -> DriverResult<()>; -} - -/// Mailer backed by a real SMTP connection using `lettre`. -#[derive(Clone)] -pub struct LettreSmtpMailer(AsyncSmtpTransport); - -impl LettreSmtpMailer { - /// Establishes a connection to the SMTP server. - pub fn connect(opts: SmtpOptions) -> Result { - let creds = Credentials::new(opts.username, opts.password); - let mailer = AsyncSmtpTransport::::relay(&opts.relay) - .map_err(|e| format!("{}", e))? - .credentials(creds) - .build(); - Ok(LettreSmtpMailer(mailer)) - } -} - -#[async_trait] -impl SmtpMailer for LettreSmtpMailer { - async fn send(&self, message: Message) -> DriverResult<()> { - self.0 - .send(message) - .await - .map_err(|e| DriverError::BackendError(format!("SMTP communication failed: {}", e)))?; - Ok(()) - } -} - -/// A template for an email message. -pub struct EmailTemplate { - /// Who the message comes from. - pub from: Mailbox, - - /// Subject of the message. - pub subject_template: &'static str, - - /// Body of the message. - pub body_template: &'static str, -} - -impl EmailTemplate { - /// Creates a message sent to `to` based on the template by applying the collection of - /// `replacements` to it. - /// - /// The subject and body of the template are subject to string replacements per the rules - /// described in `iii_iv_core::template::apply`. - pub fn apply( - &self, - to: &EmailAddress, - replacements: &[(&'static str, &str)], - ) -> DriverResult { - let to = to.as_str().parse().map_err(|e| { - // TODO(jmmv): This should never happen... but there is no guarantee right now that we can - // convert III-IV's `EmailAddress` into whatever Lettre expects. It'd be nice if we didn't - // need this though. - DriverError::InvalidInput(format!("Cannot parse email address {}: {}", to.as_str(), e)) - })?; - - let subject = template::apply(self.subject_template, replacements); - - let body = Body::new_with_encoding( - template::apply(self.body_template, replacements), - ContentTransferEncoding::QuotedPrintable, - ) - .map_err(|e| DriverError::BackendError(format!("Failed to encode message: {:?}", e)))?; - - let message = - Message::builder().from(self.from.clone()).to(to).subject(subject).body(body).map_err( - |e| DriverError::BackendError(format!("Failed to encode message: {:?}", e)), - )?; - Ok(message) - } -} - -/// Test utilities for email handling. -#[cfg(any(test, feature = "testutils"))] -pub mod testutils { - use super::*; - use futures::lock::Mutex; - use std::collections::{HashMap, HashSet}; - use std::sync::Arc; - - /// Given an SMTP `message`, parses it and extracts its headers and body. - pub fn parse_message(message: &Message) -> (HashMap, String) { - let text = String::from_utf8(message.formatted()).unwrap(); - let (raw_headers, encoded_body) = text - .split_once("\r\n\r\n") - .unwrap_or_else(|| panic!("Message seems to have the wrong format: {}", text)); - - let mut headers = HashMap::default(); - for raw_header in raw_headers.split("\r\n") { - let (key, value) = raw_header - .split_once(": ") - .unwrap_or_else(|| panic!("Header seems to have the wrong format: {}", raw_header)); - let previous = headers.insert(key.to_owned(), value.to_owned()); - assert!(previous.is_none(), "Duplicate header {}", raw_header); - } - - let decoded_body = - quoted_printable::decode(encoded_body, quoted_printable::ParseMode::Strict).unwrap(); - let body = String::from_utf8(decoded_body).unwrap().replace("\r\n", "\n"); - - (headers, body) - } - - /// Mailer that captures outgoing messages. - #[derive(Clone, Default)] - pub struct RecorderSmtpMailer { - /// Storage for captured messages. - pub inboxes: Arc>>>, - - /// Addresses for which to fail sending a message to. - errors: Arc>>, - } - - impl RecorderSmtpMailer { - /// Makes trying to send errors to `email` fail with an error. - pub async fn inject_error_for>(&self, email: E) { - let mut errors = self.errors.lock().await; - errors.insert(email.into()); - } - - /// Expects that messages were sent to `exp_to` and nobody else, and returns the list of - /// messages to that recipient. - pub async fn expect_one_inbox(&self, exp_to: &EmailAddress) -> Vec { - let inboxes = self.inboxes.lock().await; - assert_eq!(1, inboxes.len(), "Expected to find just one message in one inbox"); - let (to, messages) = inboxes.iter().next().unwrap(); - assert_eq!(exp_to, to); - messages.clone() - } - - /// Expects that only one message was sent to `exp_to` and nobody else, and returns the - /// message. - pub async fn expect_one_message(&self, exp_to: &EmailAddress) -> Message { - let mut messages = self.expect_one_inbox(exp_to).await; - assert_eq!( - 1, - messages.len(), - "Expected to find just one message for {}", - exp_to.as_str() - ); - messages.pop().unwrap() - } - } - - #[async_trait] - impl SmtpMailer for RecorderSmtpMailer { - async fn send(&self, message: Message) -> DriverResult<()> { - let to = EmailAddress::from( - message.headers().get_raw("To").expect("To header must have been present"), - ); - - { - let errors = self.errors.lock().await; - if errors.contains(&to) { - return Err(DriverError::BackendError(format!( - "Sending email to {} failed", - to.as_str() - ))); - } - } - - let mut inboxes = self.inboxes.lock().await; - inboxes.entry(to).or_insert_with(Vec::default).push(message); - Ok(()) - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::testutils::*; - use iii_iv_core::model::EmailAddress; - use std::{env, panic::catch_unwind}; - - #[test] - pub fn test_smtp_options_from_env_all_present() { - let overrides = [ - ("SMTP_RELAY", Some("the-relay")), - ("SMTP_USERNAME", Some("the-username")), - ("SMTP_PASSWORD", Some("the-password")), - ]; - temp_env::with_vars(overrides, || { - let opts = SmtpOptions::from_env("SMTP").unwrap(); - assert_eq!( - SmtpOptions { - relay: "the-relay".to_owned(), - username: "the-username".to_owned(), - password: "the-password".to_owned() - }, - opts - ); - }); - } - - #[test] - pub fn test_smtp_options_from_env_missing() { - let overrides = [ - ("MISSING_RELAY", Some("the-relay")), - ("MISSING_USERNAME", Some("the-username")), - ("MISSING_PASSWORD", Some("the-password")), - ]; - for (var, _) in overrides { - temp_env::with_vars(overrides, || { - env::remove_var(var); - let err = SmtpOptions::from_env("MISSING").unwrap_err(); - assert!(err.contains(&format!("{} not present", var))); - }); - } - } - - #[tokio::test] - pub async fn test_parse_message() { - let exp_body = " -This is a sample message with a line that should be longer than 72 characters to test line wraps. - -There is also a second paragraph with = quoted printable characters. -"; - let message = Message::builder() - .from("From someone ".parse().unwrap()) - .to("to@example.com".parse().unwrap()) - .subject("This: is the: subject line") - .body(exp_body.to_owned()) - .unwrap(); - - // Make sure the encoding of the message is quoted-printable. This isn't strictly required - // because I suppose `parse_message` might succeed anyway, but it's good to encode our - // assumption in a test. - let text = String::from_utf8(message.formatted()).unwrap(); - assert!(text.contains("=3D")); - - let (headers, body) = parse_message(&message); - - assert!(headers.len() >= 3); - assert_eq!("\"From someone\" ", headers.get("From").unwrap()); - assert_eq!("to@example.com", headers.get("To").unwrap()); - assert_eq!("This: is the: subject line", headers.get("Subject").unwrap()); - - assert_eq!(exp_body, body); - } - - /// Creates a new message where the only thing that matters is toe `to` field. - fn new_message(to: &EmailAddress) -> Message { - Message::builder() - .from("from@example.com".parse().unwrap()) - .to(to.as_str().parse().unwrap()) - .subject("Test") - .body("Body".to_owned()) - .unwrap() - } - - #[tokio::test] - pub async fn test_recorder_inject_error() { - let to1 = EmailAddress::from("to1@example.com"); - let to2 = EmailAddress::from("to2@example.com"); - let to3 = EmailAddress::from("to3@example.com"); - - let mailer = RecorderSmtpMailer::default(); - mailer.inject_error_for(to2.clone()).await; - - mailer.send(new_message(&to1)).await.unwrap(); - mailer.send(new_message(&to2)).await.unwrap_err(); - mailer.send(new_message(&to3)).await.unwrap(); - - let inboxes = mailer.inboxes.lock().await; - assert!(inboxes.contains_key(&to1)); - assert!(!inboxes.contains_key(&to2)); - assert!(inboxes.contains_key(&to3)); - } - - #[tokio::test] - pub async fn test_recorder_expect_one_inbox_ok() { - let to = EmailAddress::from("to@example.com"); - let message = new_message(&to); - let exp_formatted = message.formatted(); - - let mailer = RecorderSmtpMailer::default(); - mailer.send(message.clone()).await.unwrap(); - mailer.send(message).await.unwrap(); - - let messages = mailer.expect_one_inbox(&to).await; - assert_eq!( - vec![exp_formatted.clone(), exp_formatted], - messages.iter().map(Message::formatted).collect::>>(), - ); - } - - #[test] - pub fn test_recorder_expect_one_inbox_too_many_recipients() { - #[tokio::main(flavor = "current_thread")] - async fn do_test() { - let to1 = EmailAddress::from("to1@example.com"); - let to2 = EmailAddress::from("to2@example.com"); - - let mailer = RecorderSmtpMailer::default(); - mailer.send(new_message(&to1)).await.unwrap(); - mailer.send(new_message(&to2)).await.unwrap(); - - let _ = mailer.expect_one_inbox(&to1).await; // Will panic. - } - assert!(catch_unwind(do_test).is_err()); - } - - #[tokio::test] - pub async fn test_recorder_expect_one_message_ok() { - let to = EmailAddress::from("to@example.com"); - let message = new_message(&to); - let exp_formatted = message.formatted(); - - let mailer = RecorderSmtpMailer::default(); - mailer.send(message).await.unwrap(); - - assert_eq!(exp_formatted, mailer.expect_one_message(&to).await.formatted()); - } - - #[test] - pub fn test_recorder_expect_one_message_too_many_recipients() { - #[tokio::main(flavor = "current_thread")] - async fn do_test() { - let to1 = EmailAddress::from("to1@example.com"); - let to2 = EmailAddress::from("to2@example.com"); - - let mailer = RecorderSmtpMailer::default(); - mailer.send(new_message(&to1)).await.unwrap(); - mailer.send(new_message(&to2)).await.unwrap(); - - let _ = mailer.expect_one_message(&to1).await; // Will panic. - } - assert!(catch_unwind(do_test).is_err()); - } - - #[test] - pub fn test_recorder_expect_one_message_too_many_messages() { - #[tokio::main(flavor = "current_thread")] - async fn do_test() { - let to = EmailAddress::from("to@example.com"); - - let mailer = RecorderSmtpMailer::default(); - mailer.send(new_message(&to)).await.unwrap(); - mailer.send(new_message(&to)).await.unwrap(); - - let _ = mailer.expect_one_message(&to).await; // Will panic. - } - assert!(catch_unwind(do_test).is_err()); - } - - #[test] - fn test_email_template() { - let template = EmailTemplate { - from: "Sender ".parse().unwrap(), - subject_template: "The %s%", - body_template: "The %b% with quoted printable =50 characters", - }; - - let message = template - .apply( - &EmailAddress::from("recipient@example.com"), - &[("s", "replaced subject"), ("b", "replaced body")], - ) - .unwrap(); - let (headers, body) = parse_message(&message); - - let exp_message = Message::builder() - .from(template.from) - .to("recipient@example.com".parse().unwrap()) - .subject("The replaced subject") - .body( - Body::new_with_encoding( - "The replaced body with quoted printable =50 characters".to_owned(), - ContentTransferEncoding::QuotedPrintable, - ) - .unwrap(), - ) - .unwrap(); - let (exp_headers, exp_body) = parse_message(&exp_message); - - assert_eq!(exp_headers, headers); - assert_eq!(exp_body, body); - } -} diff --git a/queue/Cargo.toml b/queue/Cargo.toml index 6f6369e..dceca24 100644 --- a/queue/Cargo.toml +++ b/queue/Cargo.toml @@ -13,30 +13,30 @@ sqlite = ["iii-iv-core/sqlite", "sqlx/sqlite", "sqlx/time", "sqlx/uuid"] testutils = [] [dependencies] -async-trait = "0.1" -axum = "0.6" -derivative = "2.2" -futures = "0.3" +async-trait = { workspace = true } +axum = { workspace = true } +derivative = { workspace = true } +futures = { workspace = true } iii-iv-core = { path = "../core" } -log = "0.4" -serde = "1" -serde_json = "1" -time = "0.3" -tokio = "1" -uuid = { version = "1.0", default-features = false, features = ["serde", "std", "v4"] } +log = { workspace = true } +serde_json = { workspace = true } +serde = { workspace = true } +time = { workspace = true } +tokio = { workspace = true } +uuid = { workspace = true } [dependencies.sqlx] -version = "0.7" +workspace = true optional = true features = ["runtime-tokio-rustls", "time"] [dev-dependencies] iii-iv-core = { path = "../core", features = ["sqlite", "testutils"] } -rand = "0.8" -serde = { version = "1", features = ["derive"] } -time = { version = "0.3", features = ["formatting"] } -tokio = { version = "1", features = ["rt", "macros", "rt-multi-thread"] } +rand = { workspace = true } +serde = { workspace = true, features = ["derive"] } +time = { workspace = true, features = ["formatting"] } +tokio = { workspace = true, features = ["rt", "macros", "rt-multi-thread"] } [dev-dependencies.sqlx] -version = "0.7" +workspace = true features = ["runtime-tokio-rustls", "sqlite", "time", "uuid"] diff --git a/queue/src/db/mod.rs b/queue/src/db/mod.rs index 223ff30..ad1567d 100644 --- a/queue/src/db/mod.rs +++ b/queue/src/db/mod.rs @@ -161,9 +161,15 @@ pub(crate) async fn get_result(ex: &mut Executor, id: Uuid) -> DbResult { let query_str = " - SELECT status_code, status_reason + SELECT status_code, status_reason, runs, only_after FROM tasks - WHERE id = $1 AND status_code != $2 + WHERE id = $1 AND ( + status_code != $2 + OR ( + status_code = $2 AND status_reason IS NOT NULL + AND runs > 0 AND only_after IS NOT NULL + ) + ) "; match sqlx::query(query_str) .bind(id) @@ -176,6 +182,9 @@ pub(crate) async fn get_result(ex: &mut Executor, id: Uuid) -> DbResult = row.try_get("status_reason").map_err(postgres::map_sqlx_error)?; + let runs: i16 = row.try_get("runs").map_err(postgres::map_sqlx_error)?; + let only_after: Option = + row.try_get("only_after").map_err(postgres::map_sqlx_error)?; let code = match i8::try_from(code) { Ok(code) => code, @@ -187,7 +196,7 @@ pub(crate) async fn get_result(ex: &mut Executor, id: Uuid) -> DbResult DbResult { let query_str = " - SELECT status_code, status_reason + SELECT status_code, status_reason, runs, only_after_sec, only_after_nsec FROM tasks - WHERE id = ? AND status_code != ? + WHERE id = ? AND ( + status_code != ? + OR ( + status_code = ? AND status_reason IS NOT NULL + AND runs > 0 AND only_after_sec IS NOT NULL + ) + ) "; match sqlx::query(query_str) .bind(id) .bind(TaskStatus::Runnable as i8) + .bind(TaskStatus::Runnable as i8) .fetch_optional(ex) .await .map_err(sqlite::map_sqlx_error)? @@ -213,8 +229,24 @@ pub(crate) async fn get_result(ex: &mut Executor, id: Uuid) -> DbResult = row.try_get("status_reason").map_err(sqlite::map_sqlx_error)?; + let runs: i16 = row.try_get("runs").map_err(postgres::map_sqlx_error)?; + let only_after_sec: Option = + row.try_get("only_after_sec").map_err(sqlite::map_sqlx_error)?; + let only_after_nsec: Option = + row.try_get("only_after_nsec").map_err(sqlite::map_sqlx_error)?; + + let only_after = match (only_after_sec, only_after_nsec) { + (Some(sec), Some(nsec)) => Some(sqlite::build_timestamp(sec, nsec)?), + (None, None) => None, + (_, _) => { + return Err(DbError::DataIntegrityError(format!( + "Inconsistent only_after sec ({:?}) and nsec ({:?}) values", + only_after_sec, only_after_nsec + ))); + } + }; - let result = status_to_result(id, code, reason)? + let result = status_to_result(id, code, reason, runs, only_after)? .expect("Must not have queried runnable tasks"); Ok(Some(result)) } @@ -239,9 +271,12 @@ pub(crate) async fn get_results_since( #[cfg(feature = "postgres")] Executor::Postgres(ref mut ex) => { let query_str = " - SELECT id, status_code, status_reason + SELECT id, status_code, status_reason, runs, only_after FROM tasks - WHERE status_code != $1 AND updated >= $2 + WHERE ( + status_code != $1 + OR (status_code = $1 AND runs > 0 AND only_after IS NOT NULL) + ) AND updated >= $2 ORDER BY updated ASC "; let mut rows = @@ -252,6 +287,9 @@ pub(crate) async fn get_results_since( let code: i16 = row.try_get("status_code").map_err(postgres::map_sqlx_error)?; let reason: Option = row.try_get("status_reason").map_err(postgres::map_sqlx_error)?; + let runs: i16 = row.try_get("runs").map_err(postgres::map_sqlx_error)?; + let only_after: Option = + row.try_get("only_after").map_err(postgres::map_sqlx_error)?; let code = match i8::try_from(code) { Ok(code) => code, @@ -263,7 +301,7 @@ pub(crate) async fn get_results_since( } }; - let result = status_to_result(id, code, reason)? + let result = status_to_result(id, code, reason, runs, only_after)? .expect("Must not have queried runnable tasks"); results.push((id, result)); } @@ -274,14 +312,16 @@ pub(crate) async fn get_results_since( let (since_sec, since_nsec) = sqlite::unpack_timestamp(since); let query_str = " - SELECT id, status_code, status_reason + SELECT id, status_code, status_reason, runs, only_after_sec, only_after_nsec FROM tasks - WHERE + WHERE ( status_code != ? - AND (updated_sec >= ? OR (updated_sec = ? AND updated_nsec >= ?)) + OR (status_code = ? AND runs > 0 AND only_after_sec IS NOT NULL) + ) AND (updated_sec >= ? OR (updated_sec = ? AND updated_nsec >= ?)) ORDER BY updated_sec ASC, updated_nsec ASC "; let mut rows = sqlx::query(query_str) + .bind(TaskStatus::Runnable as i8) .bind(TaskStatus::Runnable as i8) .bind(since_sec) .bind(since_sec) @@ -293,8 +333,24 @@ pub(crate) async fn get_results_since( let code: i8 = row.try_get("status_code").map_err(sqlite::map_sqlx_error)?; let reason: Option = row.try_get("status_reason").map_err(sqlite::map_sqlx_error)?; + let runs: i16 = row.try_get("runs").map_err(postgres::map_sqlx_error)?; + let only_after_sec: Option = + row.try_get("only_after_sec").map_err(sqlite::map_sqlx_error)?; + let only_after_nsec: Option = + row.try_get("only_after_nsec").map_err(sqlite::map_sqlx_error)?; + + let only_after = match (only_after_sec, only_after_nsec) { + (Some(sec), Some(nsec)) => Some(sqlite::build_timestamp(sec, nsec)?), + (None, None) => None, + (_, _) => { + return Err(DbError::DataIntegrityError(format!( + "Inconsistent only_after sec ({:?}) and msec ({:?}) values", + only_after_sec, only_after_nsec + ))); + } + }; - let result = status_to_result(id, code, reason)? + let result = status_to_result(id, code, reason, runs, only_after)? .expect("Must not have queried runnable tasks"); results.push((id, result)); } diff --git a/queue/src/db/status.rs b/queue/src/db/status.rs index e854662..72e87a7 100644 --- a/queue/src/db/status.rs +++ b/queue/src/db/status.rs @@ -58,16 +58,29 @@ pub(super) fn result_to_status( /// Parses a status `code`/`reason` pair as extracted from the database into a `TaskResult`. /// -/// If the task is still running, there is no result yet. +/// If the task is still running, there is no result yet, unless the task has been deferred after +/// a retry, in which case there will be a result. /// /// The `id` is used for error reporting reasons only. pub(super) fn status_to_result( id: Uuid, code: i8, reason: Option, + runs: i16, + only_after: Option, ) -> DbResult> { match code { - x if x == (TaskStatus::Runnable as i8) => Ok(None), + x if x == (TaskStatus::Runnable as i8) => match (runs, only_after) { + (0, _) => Ok(None), + (runs, Some(only_after)) => match reason { + Some(reason) => Ok(Some(TaskResult::Retry(only_after, reason))), + None => Err(DbError::DataIntegrityError(format!( + "Task {} is Retry with runs={} but status_reason is missing", + id, runs + ))), + }, + (_, None) => Ok(None), + }, x if x == (TaskStatus::Done as i8) => Ok(Some(TaskResult::Done(reason))), @@ -123,27 +136,81 @@ mod tests { #[test] fn test_status_to_result_runnable_is_none() { - match status_to_result(Uuid::new_v4(), TaskStatus::Runnable as i8, None) { + match status_to_result(Uuid::new_v4(), TaskStatus::Runnable as i8, None, 3, None) { + Ok(None) => (), + r => panic!("Unexpected result: {:?}", r), + } + + match status_to_result( + Uuid::new_v4(), + TaskStatus::Runnable as i8, + Some("foo".to_owned()), + 0, + None, + ) { + Ok(None) => (), + r => panic!("Unexpected result: {:?}", r), + } + } + + #[test] + fn test_status_to_result_runnable_in_the_future_is_none() { + let now = datetime!(2023-10-19 15:50:00 UTC); + + match status_to_result(Uuid::new_v4(), TaskStatus::Runnable as i8, None, 0, Some(now)) { Ok(None) => (), r => panic!("Unexpected result: {:?}", r), } - match status_to_result(Uuid::new_v4(), TaskStatus::Runnable as i8, Some("foo".to_owned())) { + match status_to_result( + Uuid::new_v4(), + TaskStatus::Runnable as i8, + Some("foo".to_owned()), + 0, + Some(now), + ) { Ok(None) => (), r => panic!("Unexpected result: {:?}", r), } } + #[test] + fn test_status_to_result_retry_after_failure() { + let now = datetime!(2023-10-19 15:50:00 UTC); + + match status_to_result(Uuid::new_v4(), TaskStatus::Runnable as i8, None, 1, Some(now)) { + Err(DbError::DataIntegrityError(_)) => (), + r => panic!("Unexpected result: {:?}", r), + } + + assert_eq!( + Ok(Some(TaskResult::Retry(now, "foo".to_owned()))), + status_to_result( + Uuid::new_v4(), + TaskStatus::Runnable as i8, + Some("foo".to_owned()), + 1, + Some(now), + ) + ); + } + #[test] fn test_status_to_result_done_may_have_reason() { assert_eq!( Ok(Some(TaskResult::Done(None))), - status_to_result(Uuid::new_v4(), TaskStatus::Done as i8, None) + status_to_result(Uuid::new_v4(), TaskStatus::Done as i8, None, 123, None) ); assert_eq!( Ok(Some(TaskResult::Done(Some("msg".to_owned())))), - status_to_result(Uuid::new_v4(), TaskStatus::Done as i8, Some("msg".to_owned())) + status_to_result( + Uuid::new_v4(), + TaskStatus::Done as i8, + Some("msg".to_owned()), + 0, + None + ) ); } @@ -151,10 +218,16 @@ mod tests { fn test_status_to_result_failed_must_have_reason() { assert_eq!( Ok(Some(TaskResult::Failed("msg".to_owned()))), - status_to_result(Uuid::new_v4(), TaskStatus::Failed as i8, Some("msg".to_owned())) + status_to_result( + Uuid::new_v4(), + TaskStatus::Failed as i8, + Some("msg".to_owned()), + 0, + None + ) ); - match status_to_result(Uuid::new_v4(), TaskStatus::Failed as i8, None) { + match status_to_result(Uuid::new_v4(), TaskStatus::Failed as i8, None, 1, None) { Err(DbError::DataIntegrityError(_)) => (), r => panic!("Unexpected result: {:?}", r), } @@ -164,10 +237,16 @@ mod tests { fn test_status_to_result_abandoned_must_have_reason() { assert_eq!( Ok(Some(TaskResult::Abandoned("msg".to_owned()))), - status_to_result(Uuid::new_v4(), TaskStatus::Abandoned as i8, Some("msg".to_owned())) + status_to_result( + Uuid::new_v4(), + TaskStatus::Abandoned as i8, + Some("msg".to_owned()), + 1, + None + ) ); - match status_to_result(Uuid::new_v4(), TaskStatus::Abandoned as i8, None) { + match status_to_result(Uuid::new_v4(), TaskStatus::Abandoned as i8, None, 0, None) { Err(DbError::DataIntegrityError(_)) => (), r => panic!("Unexpected result: {:?}", r), } @@ -175,12 +254,12 @@ mod tests { #[test] fn test_status_to_result_unknown_code() { - match status_to_result(Uuid::new_v4(), 123, None) { + match status_to_result(Uuid::new_v4(), 123, None, 0, None) { Err(DbError::DataIntegrityError(e)) => assert!(e.contains("unknown")), r => panic!("Unexpected result: {:?}", r), } - match status_to_result(Uuid::new_v4(), 123, Some("foo".to_owned())) { + match status_to_result(Uuid::new_v4(), 123, Some("foo".to_owned()), 0, None) { Err(DbError::DataIntegrityError(e)) => assert!(e.contains("unknown")), r => panic!("Unexpected result: {:?}", r), } diff --git a/queue/src/driver/client.rs b/queue/src/driver/client.rs index 455a235..60d3539 100644 --- a/queue/src/driver/client.rs +++ b/queue/src/driver/client.rs @@ -137,11 +137,35 @@ where } /// Waits for task `id` until it has completed execution by polling its state every `period`. + /// + /// In other words: wait for tasks to be fully done, and in particular, if a task asks to be + /// retried, wait for all necessary retries to happen until the task is done. Use `wait_once` + /// instead to return on the first retry attempt. pub async fn wait( &mut self, db: Arc, id: Uuid, period: Duration, + ) -> DriverResult { + loop { + match self.poll(&mut db.ex().await?, id).await? { + None | Some(TaskResult::Retry(_, _)) => (), + Some(result) => break Ok(result), + } + + self.maybe_notify_worker().await; + + tokio::time::sleep(period).await; + } + } + + /// Waits for task `id` until it has completed execution or until it has attempted to run but + /// has decided to retry by polling its state every `period`. + pub async fn wait_once( + &mut self, + db: Arc, + id: Uuid, + period: Duration, ) -> DriverResult { loop { if let Some(result) = self.poll(&mut db.ex().await?, id).await? { diff --git a/queue/src/driver/mod.rs b/queue/src/driver/mod.rs index 21edfa1..547fa73 100644 --- a/queue/src/driver/mod.rs +++ b/queue/src/driver/mod.rs @@ -518,7 +518,10 @@ mod tests { context.notify_workers(1).await; tokio::time::sleep(Duration::from_millis(1)).await; } - assert_eq!(None, context.client.poll(&mut context.ex().await, id).await.unwrap()); + match context.client.poll(&mut context.ex().await, id).await { + Ok(Some(TaskResult::Retry(_, _))) => (), + e => panic!("{:?}", e), + } context.advance_clock(delay * 2); @@ -556,7 +559,10 @@ mod tests { context.notify_workers(1).await; tokio::time::sleep(Duration::from_millis(1)).await; } - assert_eq!(None, context.client.poll(&mut context.ex().await, id).await.unwrap()); + match context.client.poll(&mut context.ex().await, id).await { + Ok(Some(TaskResult::Retry(_, _))) => (), + e => panic!("{:?}", e), + } context.advance_clock(Duration::from_secs(1)); @@ -601,4 +607,48 @@ mod tests { assert_eq!(4, state.deferred); assert!(!state.done); } + + #[tokio::test] + async fn test_wait_once_returns_retries() { + let opts = WorkerOptions { max_runs: 5, ..Default::default() }; + let mut context = TestContext::setup_one_connected(opts.clone()).await; + + let delay = Duration::from_secs(60); + let task = MockTask { id: 123, defer: Some((2, delay)), ..Default::default() }; + let id = context.client.enqueue(&mut context.ex().await, &task).await.unwrap(); + + // Wait until we know the task has asked to retry the `defer` times we configured. + loop { + { + let state = context.state.lock().await; + assert!(state.len() <= 1); + if let Some(state) = state.get(&123) { + assert!(!state.done); + if state.deferred == task.defer.unwrap().0 { + break; + } + } + } + context.advance_clock(delay); + context.notify_workers(1).await; + tokio::time::sleep(Duration::from_millis(1)).await; + } + + for _ in 0..2 { + let result = + context.client.wait_once(context.db.clone(), id, Duration::from_millis(1)).await; + match result { + Ok(TaskResult::Retry(_, _)) => (), + e => panic!("{:?}", e), + } + } + let result = context.client.wait(context.db.clone(), id, Duration::from_millis(1)).await; + assert_eq!(Ok(TaskResult::Done(None)), result); + + let state = context.state.lock().await; + assert_eq!(1, state.len()); + let state = state.get(&123).unwrap(); + assert_eq!(2, state.deferred); + assert!(state.done); + } } diff --git a/smtp/Cargo.toml b/smtp/Cargo.toml new file mode 100644 index 0000000..ac698fd --- /dev/null +++ b/smtp/Cargo.toml @@ -0,0 +1,48 @@ +[package] +name = "iii-iv-smtp" +version = "0.0.0" +description = "III-IV: SMTP support" +authors = ["Julio Merino "] +edition = "2021" +publish = false + +[features] +default = ["postgres"] +postgres = ["dep:sqlx", "iii-iv-core/postgres", "sqlx/postgres"] +sqlite = ["dep:sqlx", "iii-iv-core/sqlite", "sqlx/sqlite"] +testutils = ["dep:futures", "dep:env_logger", "dep:quoted_printable", "iii-iv-core/sqlite"] + +[dependencies] +async-trait = { workspace = true } +axum = { workspace = true } +derivative = { workspace = true } +env_logger = { workspace = true, optional = true } +futures = { workspace = true, optional = true } +http = { workspace = true } +iii-iv-core = { path = "../core" } +quoted_printable = { workspace = true, optional = true } +serde_json = { workspace = true } +thiserror = { workspace = true } +time = { workspace = true } + +[dependencies.lettre] +workspace = true +features = ["builder", "hostname", "pool", "rustls-tls", "smtp-transport", "tokio1-rustls-tls"] + +[dependencies.sqlx] +version = "0.7" +optional = true +features = ["runtime-tokio-rustls", "time"] + +[dev-dependencies] +env_logger = { workspace = true } +futures = { workspace = true } +iii-iv-core = { path = "../core", features = ["sqlite", "testutils"] } +quoted_printable = { workspace = true } +temp-env = { workspace = true } +time = { workspace = true, features = ["macros"] } +tokio = { workspace = true, features = ["macros", "rt-multi-thread"] } + +[dev-dependencies.sqlx] +workspace = true +features = ["runtime-tokio-rustls", "sqlite", "time"] diff --git a/smtp/src/db/mod.rs b/smtp/src/db/mod.rs new file mode 100644 index 0000000..7e5961e --- /dev/null +++ b/smtp/src/db/mod.rs @@ -0,0 +1,217 @@ +// III-IV +// Copyright 2023 Julio Merino +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy +// of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +//! Database abstraction to track email submissions. + +#[cfg(test)] +use futures::TryStreamExt; +#[cfg(feature = "postgres")] +use iii_iv_core::db::postgres; +#[cfg(test)] +use iii_iv_core::db::sqlite::build_timestamp; +#[cfg(any(feature = "sqlite", test))] +use iii_iv_core::db::sqlite::{self, unpack_timestamp}; +use iii_iv_core::db::{count_as_usize, ensure_one_upsert, DbResult, Executor}; +use lettre::Message; +use sqlx::Row; +use time::{Date, OffsetDateTime}; + +#[cfg(test)] +mod tests; + +/// Initializes the database schema. +pub async fn init_schema(ex: &mut Executor) -> DbResult<()> { + match ex { + #[cfg(feature = "postgres")] + Executor::Postgres(ref mut ex) => { + postgres::run_schema(ex, include_str!("postgres.sql")).await + } + + #[cfg(any(feature = "sqlite", test))] + Executor::Sqlite(ref mut ex) => sqlite::run_schema(ex, include_str!("sqlite.sql")).await, + + #[allow(unused)] + _ => unreachable!(), + } +} + +/// Counts how many emails were sent on `day`. +pub(crate) async fn count_email_log(ex: &mut Executor, day: Date) -> DbResult { + let total: i64 = match ex { + Executor::Postgres(ref mut ex) => { + let from = day.midnight().assume_utc(); + let to = from + time::Duration::DAY; + + let query_str = + "SELECT COUNT(*) AS total FROM email_log WHERE sent >= $1 AND sent < $2"; + let row = sqlx::query(query_str) + .bind(from) + .bind(to) + .fetch_one(ex) + .await + .map_err(postgres::map_sqlx_error)?; + row.try_get("total").map_err(postgres::map_sqlx_error)? + } + + #[cfg(any(test, feature = "sqlite"))] + Executor::Sqlite(ref mut ex) => { + let from = day.midnight().assume_utc(); + let to = from + time::Duration::DAY; + + let (from_sec, from_nsec) = unpack_timestamp(from); + let (to_sec, to_nsec) = unpack_timestamp(to); + + let query_str = " + SELECT COUNT(*) AS total + FROM email_log + WHERE + (sent_sec >= ? OR (sent_sec = ? AND sent_nsec >= ?)) + AND (sent_sec < ? OR (sent_sec = ? AND sent_nsec < ?)) + "; + let row = sqlx::query(query_str) + .bind(from_sec) + .bind(from_sec) + .bind(from_nsec) + .bind(to_sec) + .bind(to_sec) + .bind(to_nsec) + .fetch_one(ex) + .await + .map_err(sqlite::map_sqlx_error)?; + row.try_get("total").map_err(sqlite::map_sqlx_error)? + } + + #[allow(unused)] + _ => unreachable!(), + }; + count_as_usize(total) +} + +/// En entry in the email log. +#[cfg(test)] +type EmailLogEntry = (OffsetDateTime, Vec, Option); + +/// Gets all entries in the email log. +#[cfg(test)] +pub(crate) async fn get_email_log(ex: &mut Executor) -> DbResult> { + let mut entries = vec![]; + match ex { + Executor::Postgres(ref mut ex) => { + let query_str = "SELECT sent, message, result FROM email_log"; + let mut rows = sqlx::query(query_str).fetch(ex); + while let Some(row) = rows.try_next().await.map_err(postgres::map_sqlx_error)? { + let sent: OffsetDateTime = row.try_get("sent").map_err(postgres::map_sqlx_error)?; + let message: Vec = row.try_get("message").map_err(postgres::map_sqlx_error)?; + let result: Option = + row.try_get("result").map_err(postgres::map_sqlx_error)?; + + entries.push((sent, message, result)); + } + } + + #[cfg(any(test, feature = "sqlite"))] + Executor::Sqlite(ref mut ex) => { + let query_str = "SELECT sent_sec, sent_nsec, message, result FROM email_log"; + let mut rows = sqlx::query(query_str).fetch(ex); + while let Some(row) = rows.try_next().await.map_err(sqlite::map_sqlx_error)? { + let sent_sec: i64 = row.try_get("sent_sec").map_err(sqlite::map_sqlx_error)?; + let sent_nsec: i64 = row.try_get("sent_nsec").map_err(sqlite::map_sqlx_error)?; + let message: Vec = row.try_get("message").map_err(sqlite::map_sqlx_error)?; + let result: Option = + row.try_get("result").map_err(sqlite::map_sqlx_error)?; + + let sent = build_timestamp(sent_sec, sent_nsec)?; + + entries.push((sent, message, result)) + } + } + + #[allow(unused)] + _ => unreachable!(), + } + Ok(entries) +} + +/// Records that an email was sent to `email` at time `now`. +pub(crate) async fn put_email_log( + ex: &mut Executor, + message: &Message, + now: OffsetDateTime, +) -> DbResult { + match ex { + Executor::Postgres(ref mut ex) => { + let query_str = "INSERT INTO email_log (sent, message) VALUES ($1, $2) RETURNING id"; + let row = sqlx::query(query_str) + .bind(now) + .bind(message.formatted()) + .fetch_one(ex) + .await + .map_err(postgres::map_sqlx_error)?; + let last_insert_id: i32 = row.try_get("id").map_err(postgres::map_sqlx_error)?; + Ok(i64::from(last_insert_id)) + } + + #[cfg(any(test, feature = "sqlite"))] + Executor::Sqlite(ref mut ex) => { + let (now_sec, now_nsec) = unpack_timestamp(now); + + let query_str = "INSERT INTO email_log (sent_sec, sent_nsec, message) VALUES (?, ?, ?)"; + let done = sqlx::query(query_str) + .bind(now_sec) + .bind(now_nsec) + .bind(message.formatted()) + .execute(ex) + .await + .map_err(sqlite::map_sqlx_error)?; + Ok(done.last_insert_rowid()) + } + + #[allow(unused)] + _ => unreachable!(), + } +} + +/// Records the result of sending an email. +pub(crate) async fn update_email_log(ex: &mut Executor, id: i64, result: &str) -> DbResult<()> { + match ex { + Executor::Postgres(ref mut ex) => { + let query_str = "UPDATE email_log SET result = $1 WHERE id = $2"; + let done = sqlx::query(query_str) + .bind(result) + .bind(id) + .execute(ex) + .await + .map_err(postgres::map_sqlx_error)?; + ensure_one_upsert(done.rows_affected())?; + Ok(()) + } + + #[cfg(any(test, feature = "sqlite"))] + Executor::Sqlite(ref mut ex) => { + let query_str = "UPDATE email_log SET result = ? WHERE id = ?"; + let done = sqlx::query(query_str) + .bind(result) + .bind(id) + .execute(ex) + .await + .map_err(sqlite::map_sqlx_error)?; + ensure_one_upsert(done.rows_affected())?; + Ok(()) + } + + #[allow(unused)] + _ => unreachable!(), + } +} diff --git a/smtp/src/db/postgres.sql b/smtp/src/db/postgres.sql new file mode 100644 index 0000000..84774df --- /dev/null +++ b/smtp/src/db/postgres.sql @@ -0,0 +1,24 @@ +-- III-IV +-- Copyright 2023 Julio Merino +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); you may not +-- use this file except in compliance with the License. You may obtain a copy +-- of the License at: +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +-- WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +-- License for the specific language governing permissions and limitations +-- under the License. + +CREATE TABLE IF NOT EXISTS email_log ( + id SERIAL PRIMARY KEY, + + sent TIMESTAMPTZ NOT NULL, + message BYTEA NOT NULL, + result TEXT +); + +CREATE INDEX email_log_by_sent ON email_log (sent); diff --git a/smtp/src/db/sqlite.sql b/smtp/src/db/sqlite.sql new file mode 100644 index 0000000..f2fba70 --- /dev/null +++ b/smtp/src/db/sqlite.sql @@ -0,0 +1,27 @@ +-- III-IV +-- Copyright 2023 Julio Merino +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); you may not +-- use this file except in compliance with the License. You may obtain a copy +-- of the License at: +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +-- WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +-- License for the specific language governing permissions and limitations +-- under the License. + +PRAGMA foreign_keys = ON; + +CREATE TABLE IF NOT EXISTS email_log ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + + sent_sec INTEGER NOT NULL, + sent_nsec INTEGER NOT NULL, + message BYTEA NOT NULL, + result TEXT +); + +CREATE INDEX email_log_by_sent ON email_log (sent_sec, sent_nsec); diff --git a/smtp/src/db/tests.rs b/smtp/src/db/tests.rs new file mode 100644 index 0000000..39dd650 --- /dev/null +++ b/smtp/src/db/tests.rs @@ -0,0 +1,87 @@ +// III-IV +// Copyright 2023 Julio Merino +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy +// of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +//! Common tests for any database implementation. + +use crate::db::*; +use iii_iv_core::db::Executor; +use time::macros::{date, datetime}; + +async fn test_email_log(ex: &mut Executor) { + // The message contents should be completely irrelevant for counting purposes, so keeping + // them all identical helps assert that. + let message = Message::builder() + .from("from@example.com".parse().unwrap()) + .to("to@example.com".parse().unwrap()) + .subject("Foo") + .body("Bar".to_owned()) + .unwrap(); + + put_email_log(ex, &message, datetime!(2023-06-11 00:00:00.000000 UTC)).await.unwrap(); + put_email_log(ex, &message, datetime!(2023-06-12 06:20:00.000001 UTC)).await.unwrap(); + put_email_log(ex, &message, datetime!(2023-06-12 06:20:00.000002 UTC)).await.unwrap(); + put_email_log(ex, &message, datetime!(2023-06-12 23:59:59.999999 UTC)).await.unwrap(); + + assert_eq!(0, count_email_log(ex, date!(2023 - 06 - 10)).await.unwrap()); + assert_eq!(1, count_email_log(ex, date!(2023 - 06 - 11)).await.unwrap()); + assert_eq!(3, count_email_log(ex, date!(2023 - 06 - 12)).await.unwrap()); + assert_eq!(0, count_email_log(ex, date!(2023 - 06 - 13)).await.unwrap()); +} + +macro_rules! generate_db_tests [ + ( $setup:expr $(, #[$extra:meta] )? ) => { + iii_iv_core::db::testutils::generate_tests!( + $(#[$extra],)? + $setup, + $crate::db::tests, + test_email_log + ); + } +]; + +use generate_db_tests; + +mod postgres { + use super::*; + use crate::db::init_schema; + use iii_iv_core::db::postgres::PostgresDb; + use iii_iv_core::db::Db; + + async fn setup() -> PostgresDb { + let db = iii_iv_core::db::postgres::testutils::setup().await; + init_schema(&mut db.ex().await.unwrap()).await.unwrap(); + db + } + + generate_db_tests!( + &mut setup().await.ex().await.unwrap(), + #[ignore = "Requires environment configuration and is expensive"] + ); +} + +mod sqlite { + use super::*; + use crate::db::init_schema; + use iii_iv_core::db::sqlite::SqliteDb; + use iii_iv_core::db::Db; + + async fn setup() -> SqliteDb { + let db = iii_iv_core::db::sqlite::testutils::setup().await; + init_schema(&mut db.ex().await.unwrap()).await.unwrap(); + db + } + + generate_db_tests!(&mut setup().await.ex().await.unwrap()); +} diff --git a/smtp/src/driver/mod.rs b/smtp/src/driver/mod.rs new file mode 100644 index 0000000..aab39ac --- /dev/null +++ b/smtp/src/driver/mod.rs @@ -0,0 +1,362 @@ +// III-IV +// Copyright 2023 Julio Merino +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy +// of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +//! Utilities to send messages over email. + +use crate::db::{count_email_log, put_email_log, update_email_log}; +use async_trait::async_trait; +use derivative::Derivative; +use iii_iv_core::clocks::Clock; +use iii_iv_core::db::Db; +use iii_iv_core::driver::{DriverError, DriverResult}; +use iii_iv_core::env::{get_optional_var, get_required_var}; +use lettre::message::Message; +use lettre::transport::smtp::authentication::Credentials; +use lettre::{AsyncSmtpTransport, AsyncTransport, Tokio1Executor}; +use std::sync::Arc; + +#[cfg(any(test, feature = "testutils"))] +pub mod testutils; + +/// Options to establish an SMTP connection. +#[derive(Derivative)] +#[derivative(Debug)] +#[cfg_attr(test, derivative(PartialEq))] +pub struct SmtpOptions { + /// SMTP server to use. + pub relay: String, + + /// Username for logging into the SMTP server. + pub username: String, + + /// Password for logging into the SMTP server. + #[derivative(Debug = "ignore")] + pub password: String, + + /// Maximum number of messages to send per day, if any. + pub max_daily_emails: Option, +} + +impl SmtpOptions { + /// Initializes a set of options from environment variables whose name is prefixed with the + /// given `prefix`. + /// + /// This will use variables such as `_RELAY`, `_USERNAME`, `_PASSWORD` + /// and `_MAX_DAILY_EMAILS`. + pub fn from_env(prefix: &str) -> Result { + Ok(Self { + relay: get_required_var::(prefix, "RELAY")?, + username: get_required_var::(prefix, "USERNAME")?, + password: get_required_var::(prefix, "PASSWORD")?, + max_daily_emails: get_optional_var::(prefix, "MAX_DAILY_EMAILS")?, + }) + } +} + +/// Trait to abstract the integration with the mailer. +#[async_trait] +pub trait SmtpMailer { + /// Sends a message over SMTP. + async fn send(&self, message: Message) -> DriverResult<()>; +} + +/// Mailer backed by a real SMTP connection using `lettre`. +#[derive(Clone)] +pub struct LettreSmtpMailer(AsyncSmtpTransport); + +impl LettreSmtpMailer { + /// Establishes a connection to the SMTP server. + fn connect(relay: &str, username: String, password: String) -> Result { + let creds = Credentials::new(username, password); + let mailer = AsyncSmtpTransport::::relay(relay) + .map_err(|e| format!("{}", e))? + .credentials(creds) + .build(); + Ok(LettreSmtpMailer(mailer)) + } +} + +#[async_trait] +impl SmtpMailer for LettreSmtpMailer { + async fn send(&self, message: Message) -> DriverResult<()> { + self.0 + .send(message) + .await + .map_err(|e| DriverError::BackendError(format!("SMTP communication failed: {}", e)))?; + Ok(()) + } +} + +/// Encapsulates logic to send email messages while respecting quotas. +#[derive(Clone)] +pub struct SmtpDriver { + /// The SMTP transport with which to send email messages. + transport: T, + + /// The database with which to track sent messages. + db: Arc, + + /// The clock from which to obtain the current time. + clock: Arc, + + /// Maximum number of messages to send per day, if any. + max_daily_emails: Option, +} + +impl SmtpDriver { + /// Creates a new driver with the given values. + pub fn new( + transport: T, + db: Arc, + clock: Arc, + max_daily_emails: Option, + ) -> Self { + Self { transport, db, clock, max_daily_emails } + } + + /// Obtains a reference to the wrapped SMTP transport. + pub fn get_transport(&self) -> &T { + &self.transport + } +} + +#[async_trait] +impl SmtpMailer for SmtpDriver +where + T: SmtpMailer + Send + Sync, +{ + /// Sends an email message after recording it and accounting for it for quota purposes. + async fn send(&self, message: Message) -> DriverResult<()> { + let mut tx = self.db.begin().await?; + let now = self.clock.now_utc(); + + // We must insert into the table first, before counting, to grab an exclusive transaction + // lock. Otherwise the count will be stale by the time we use it. + let id = put_email_log(tx.ex(), &message, now).await?; + + if let Some(max_daily_emails) = self.max_daily_emails { + let daily_emails = count_email_log(tx.ex(), now.date()).await? - 1; + if daily_emails >= max_daily_emails { + let msg = format!( + "Too many emails sent today ({} >= {})", + daily_emails, max_daily_emails, + ); + update_email_log(tx.ex(), id, &msg).await?; + return Err(DriverError::NoSpace(msg)); + } + } + + // Commit the transaction _before_ trying to send the email. This is intentional to ignore + // errors from the server because we don't know if errors are counted towards the daily + // quota. Furthermore, this avoids sequencing email submissions if the server is slow. + tx.commit().await?; + + let result = self.transport.send(message).await; + + match result { + Ok(()) => update_email_log(&mut self.db.ex().await?, id, "OK").await?, + Err(ref e) => update_email_log(&mut self.db.ex().await?, id, &format!("{}", e)).await?, + } + + result + } +} + +/// Creates a new SMTP driver that sends email messages via the service configured in `opts`. +/// +/// `db` and `clock` are used to keep track of the messages that have been sent for quota +/// accounting purposes. +pub fn new_prod_driver( + opts: SmtpOptions, + db: Arc, + clock: Arc, +) -> Result, String> { + let transport = LettreSmtpMailer::connect(&opts.relay, opts.username, opts.password)?; + Ok(SmtpDriver::new(transport, db, clock, opts.max_daily_emails)) +} + +#[cfg(test)] +mod tests { + use super::testutils::*; + use super::*; + use crate::db::get_email_log; + use futures::future; + use std::env; + use std::time::Duration; + + #[test] + fn test_smtp_options_from_env_all_required_present() { + let overrides = [ + ("SMTP_RELAY", Some("the-relay")), + ("SMTP_USERNAME", Some("the-username")), + ("SMTP_PASSWORD", Some("the-password")), + ]; + temp_env::with_vars(overrides, || { + let opts = SmtpOptions::from_env("SMTP").unwrap(); + assert_eq!( + SmtpOptions { + relay: "the-relay".to_owned(), + username: "the-username".to_owned(), + password: "the-password".to_owned(), + max_daily_emails: None, + }, + opts + ); + }); + } + + #[test] + fn test_smtp_options_from_env_all_required_and_optional_present() { + let overrides = [ + ("SMTP_RELAY", Some("the-relay")), + ("SMTP_USERNAME", Some("the-username")), + ("SMTP_PASSWORD", Some("the-password")), + ("SMTP_MAX_DAILY_EMAILS", Some("123")), + ]; + temp_env::with_vars(overrides, || { + let opts = SmtpOptions::from_env("SMTP").unwrap(); + assert_eq!( + SmtpOptions { + relay: "the-relay".to_owned(), + username: "the-username".to_owned(), + password: "the-password".to_owned(), + max_daily_emails: Some(123), + }, + opts + ); + }); + } + + #[test] + fn test_smtp_options_from_env_missing() { + let overrides = [ + ("MISSING_RELAY", Some("the-relay")), + ("MISSING_USERNAME", Some("the-username")), + ("MISSING_PASSWORD", Some("the-password")), + ]; + for (var, _) in overrides { + temp_env::with_vars(overrides, || { + env::remove_var(var); + let err = SmtpOptions::from_env("MISSING").unwrap_err(); + assert!(err.contains(&format!("{} not present", var))); + }); + } + } + + /// Creates a new email message with hardcoded values. + fn new_message() -> Message { + Message::builder() + .from("from@example.com".parse().unwrap()) + .to("to@example.com".parse().unwrap()) + .subject("Foo") + .body("Bar".to_owned()) + .unwrap() + } + + #[tokio::test] + async fn test_send_ok() { + let mut context = TestContext::setup(None).await; + let exp_message = new_message(); + + context.driver.send(exp_message.clone()).await.unwrap(); + + let message = context.mailer.expect_one_message(&"to@example.com".into()).await; + assert_eq!(exp_message.formatted(), message.formatted()); + + let log = get_email_log(&mut context.ex().await).await.unwrap(); + assert_eq!(1, log.len()); + assert_eq!(exp_message.formatted(), log[0].1); + assert_eq!(Some("OK"), log[0].2.as_deref()); + } + + #[tokio::test] + async fn test_send_error() { + let mut context = TestContext::setup(None).await; + let exp_message = new_message(); + + context.mailer.inject_error_for("to@example.com").await; + let err = context.driver.send(exp_message.clone()).await.unwrap_err(); + assert_eq!("Sending email to to@example.com failed", &format!("{}", err)); + + context.mailer.expect_no_messages().await; + + let log = get_email_log(&mut context.ex().await).await.unwrap(); + assert_eq!(1, log.len()); + assert_eq!(exp_message.formatted(), log[0].1); + assert_eq!(Some("Sending email to to@example.com failed"), log[0].2.as_deref()); + } + + #[tokio::test] + async fn test_daily_limit_enforced_and_clears_every_day() { + let mut context = TestContext::setup(Some(50)).await; + let exp_message = new_message(); + + for _ in 0..50 { + put_email_log(&mut context.ex().await, &exp_message, context.clock.now_utc()) + .await + .unwrap(); + } + + let err = context.driver.send(exp_message.clone()).await.unwrap_err(); + assert_eq!("Too many emails sent today (50 >= 50)", &format!("{}", err)); + context.mailer.expect_no_messages().await; + + // Advance the clock to reach just the 23rd hour of the same day. + let current_hour = u64::from(context.clock.now_utc().hour()); + context.clock.advance(Duration::from_secs((23 - current_hour) * 60 * 60)); + + let err = context.driver.send(exp_message.clone()).await.unwrap_err(); + assert_eq!("Too many emails sent today (50 >= 50)", &format!("{}", err)); + context.mailer.expect_no_messages().await; + + // Push the clock into the next day. + context.clock.advance(Duration::from_secs(60 * 60)); + + context.driver.send(exp_message.clone()).await.unwrap(); + let message = context.mailer.expect_one_message(&"to@example.com".into()).await; + assert_eq!(exp_message.formatted(), message.formatted()); + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_daily_limit_concurrency() { + let context = TestContext::setup(Some(10)).await; + let exp_message = new_message(); + + let mut futures = Vec::with_capacity(1000); + for _ in 0..1000 { + futures.push(async { + match context.driver.send(exp_message.clone()).await { + Ok(()) => true, + Err(_) => false, + } + }); + } + + let mut count_ok = 0; + let mut count_err = 0; + for ok in future::join_all(futures.into_iter()).await { + if ok { + count_ok += 1; + } else { + count_err += 1; + } + } + assert_eq!(10, count_ok); + assert_eq!(990, count_err); + + let inbox = context.mailer.expect_one_inbox(&"to@example.com".into()).await; + assert_eq!(10, inbox.len()); + } +} diff --git a/smtp/src/driver/testutils.rs b/smtp/src/driver/testutils.rs new file mode 100644 index 0000000..9b20a3f --- /dev/null +++ b/smtp/src/driver/testutils.rs @@ -0,0 +1,266 @@ +// III-IV +// Copyright 2023 Julio Merino +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy +// of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +//! Test utilities for email handling. + +use crate::driver::SmtpMailer; +use async_trait::async_trait; +use futures::lock::Mutex; +use iii_iv_core::driver::{DriverError, DriverResult}; +use iii_iv_core::model::EmailAddress; +use lettre::Message; +use std::collections::{HashMap, HashSet}; +use std::sync::Arc; + +#[cfg(test)] +use { + super::SmtpDriver, + crate::db::init_schema, + iii_iv_core::clocks::testutils::SettableClock, + iii_iv_core::db::{sqlite, Db, Executor}, + time::macros::datetime, +}; + +/// Mailer that captures outgoing messages. +#[derive(Clone, Default)] +pub struct RecorderSmtpMailer { + /// Storage for captured messages. + pub inboxes: Arc>>>, + + /// Addresses for which to fail sending a message to. + errors: Arc>>, +} + +impl RecorderSmtpMailer { + /// Makes trying to send errors to `email` fail with an error. + pub async fn inject_error_for>(&self, email: E) { + let mut errors = self.errors.lock().await; + errors.insert(email.into()); + } + + /// Expects that no messages were sent. + pub async fn expect_no_messages(&self) { + let inboxes = self.inboxes.lock().await; + assert_eq!(0, inboxes.len(), "Expected to find no messages"); + } + + /// Expects that messages were sent to `exp_to` and nobody else, and returns the list of + /// messages to that recipient. + pub async fn expect_one_inbox(&self, exp_to: &EmailAddress) -> Vec { + let inboxes = self.inboxes.lock().await; + assert_eq!(1, inboxes.len(), "Expected to find just one message in one inbox"); + let (to, messages) = inboxes.iter().next().unwrap(); + assert_eq!(exp_to, to); + messages.clone() + } + + /// Expects that only one message was sent to `exp_to` and nobody else, and returns the + /// message. + pub async fn expect_one_message(&self, exp_to: &EmailAddress) -> Message { + let mut messages = self.expect_one_inbox(exp_to).await; + assert_eq!(1, messages.len(), "Expected to find just one message for {}", exp_to.as_str()); + messages.pop().unwrap() + } +} + +#[async_trait] +impl SmtpMailer for RecorderSmtpMailer { + async fn send(&self, message: Message) -> DriverResult<()> { + let to = EmailAddress::from( + message.headers().get_raw("To").expect("To header must have been present"), + ); + + { + let errors = self.errors.lock().await; + if errors.contains(&to) { + return Err(DriverError::BackendError(format!( + "Sending email to {} failed", + to.as_str() + ))); + } + } + + let mut inboxes = self.inboxes.lock().await; + inboxes.entry(to).or_insert_with(Vec::default).push(message); + Ok(()) + } +} + +/// Container for the state required to run a driver test. +#[cfg(test)] +pub(crate) struct TestContext { + pub(crate) driver: SmtpDriver, + pub(crate) db: Arc, + pub(crate) clock: Arc, + pub(crate) mailer: RecorderSmtpMailer, +} + +#[cfg(test)] +impl TestContext { + pub(crate) async fn setup(max_daily_emails: Option) -> Self { + let _can_fail = env_logger::builder().is_test(true).try_init(); + + let db = Arc::from(sqlite::testutils::setup().await); + let mut ex = db.ex().await.unwrap(); + init_schema(&mut ex).await.unwrap(); + + let clock = Arc::from(SettableClock::new(datetime!(2023-10-17 06:00:00 UTC))); + + let mailer = RecorderSmtpMailer::default(); + + let driver = SmtpDriver { + transport: mailer.clone(), + db: db.clone(), + clock: clock.clone(), + max_daily_emails, + }; + + Self { driver, db, clock, mailer } + } + + pub(crate) async fn ex(&mut self) -> Executor { + self.db.ex().await.unwrap() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use iii_iv_core::model::EmailAddress; + use std::panic::catch_unwind; + + /// Creates a new message where the only thing that matters is toe `to` field. + fn new_message(to: &EmailAddress) -> Message { + Message::builder() + .from("from@example.com".parse().unwrap()) + .to(to.as_str().parse().unwrap()) + .subject("Test") + .body("Body".to_owned()) + .unwrap() + } + + #[tokio::test] + async fn test_recorder_inject_error() { + let to1 = EmailAddress::from("to1@example.com"); + let to2 = EmailAddress::from("to2@example.com"); + let to3 = EmailAddress::from("to3@example.com"); + + let mailer = RecorderSmtpMailer::default(); + mailer.inject_error_for(to2.clone()).await; + + mailer.send(new_message(&to1)).await.unwrap(); + mailer.send(new_message(&to2)).await.unwrap_err(); + mailer.send(new_message(&to3)).await.unwrap(); + + let inboxes = mailer.inboxes.lock().await; + assert!(inboxes.contains_key(&to1)); + assert!(!inboxes.contains_key(&to2)); + assert!(inboxes.contains_key(&to3)); + } + + #[tokio::test] + async fn test_recorder_expect_no_messages_ok() { + let mailer = RecorderSmtpMailer::default(); + mailer.expect_no_messages().await; + } + + #[tokio::test] + async fn test_recorder_expect_no_messages_fail() { + #[tokio::main(flavor = "current_thread")] + async fn do_test() { + let to1 = EmailAddress::from("to1@example.com"); + let mailer = RecorderSmtpMailer::default(); + mailer.send(new_message(&to1)).await.unwrap(); + mailer.expect_no_messages().await; // Will panic. + } + assert!(catch_unwind(do_test).is_err()); + } + + #[tokio::test] + async fn test_recorder_expect_one_inbox_ok() { + let to = EmailAddress::from("to@example.com"); + let message = new_message(&to); + let exp_formatted = message.formatted(); + + let mailer = RecorderSmtpMailer::default(); + mailer.send(message.clone()).await.unwrap(); + mailer.send(message).await.unwrap(); + + let messages = mailer.expect_one_inbox(&to).await; + assert_eq!( + vec![exp_formatted.clone(), exp_formatted], + messages.iter().map(Message::formatted).collect::>>(), + ); + } + + #[test] + fn test_recorder_expect_one_inbox_too_many_recipients() { + #[tokio::main(flavor = "current_thread")] + async fn do_test() { + let to1 = EmailAddress::from("to1@example.com"); + let to2 = EmailAddress::from("to2@example.com"); + + let mailer = RecorderSmtpMailer::default(); + mailer.send(new_message(&to1)).await.unwrap(); + mailer.send(new_message(&to2)).await.unwrap(); + + let _ = mailer.expect_one_inbox(&to1).await; // Will panic. + } + assert!(catch_unwind(do_test).is_err()); + } + + #[tokio::test] + async fn test_recorder_expect_one_message_ok() { + let to = EmailAddress::from("to@example.com"); + let message = new_message(&to); + let exp_formatted = message.formatted(); + + let mailer = RecorderSmtpMailer::default(); + mailer.send(message).await.unwrap(); + + assert_eq!(exp_formatted, mailer.expect_one_message(&to).await.formatted()); + } + + #[test] + fn test_recorder_expect_one_message_too_many_recipients() { + #[tokio::main(flavor = "current_thread")] + async fn do_test() { + let to1 = EmailAddress::from("to1@example.com"); + let to2 = EmailAddress::from("to2@example.com"); + + let mailer = RecorderSmtpMailer::default(); + mailer.send(new_message(&to1)).await.unwrap(); + mailer.send(new_message(&to2)).await.unwrap(); + + let _ = mailer.expect_one_message(&to1).await; // Will panic. + } + assert!(catch_unwind(do_test).is_err()); + } + + #[test] + fn test_recorder_expect_one_message_too_many_messages() { + #[tokio::main(flavor = "current_thread")] + async fn do_test() { + let to = EmailAddress::from("to@example.com"); + + let mailer = RecorderSmtpMailer::default(); + mailer.send(new_message(&to)).await.unwrap(); + mailer.send(new_message(&to)).await.unwrap(); + + let _ = mailer.expect_one_message(&to).await; // Will panic. + } + assert!(catch_unwind(do_test).is_err()); + } +} diff --git a/smtp/src/lib.rs b/smtp/src/lib.rs new file mode 100644 index 0000000..00cd6e5 --- /dev/null +++ b/smtp/src/lib.rs @@ -0,0 +1,25 @@ +// III-IV +// Copyright 2023 Julio Merino +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy +// of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +//! Utilities to send messages over email. + +// Keep these in sync with other top-level files. +#![warn(anonymous_parameters, bad_style, clippy::missing_docs_in_private_items, missing_docs)] +#![warn(unused, unused_extern_crates, unused_import_braces, unused_qualifications)] +#![warn(unsafe_code)] + +pub mod db; +pub mod driver; +pub mod model; diff --git a/smtp/src/model.rs b/smtp/src/model.rs new file mode 100644 index 0000000..49c54b4 --- /dev/null +++ b/smtp/src/model.rs @@ -0,0 +1,170 @@ +// III-IV +// Copyright 2023 Julio Merino +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy +// of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +//! Data types to interact with email messages. + +use iii_iv_core::model::{EmailAddress, ModelError, ModelResult}; +use iii_iv_core::template; +use lettre::message::header::ContentTransferEncoding; +use lettre::message::Body; +pub use lettre::message::{Mailbox, Message}; + +/// A template for an email message. +pub struct EmailTemplate { + /// Who the message comes from. + pub from: Mailbox, + + /// Subject of the message. + pub subject_template: &'static str, + + /// Body of the message. + pub body_template: &'static str, +} + +impl EmailTemplate { + /// Creates a message sent to `to` based on the template by applying the collection of + /// `replacements` to it. + /// + /// The subject and body of the template are subject to string replacements per the rules + /// described in `iii_iv_core::template::apply`. + pub fn apply( + &self, + to: &EmailAddress, + replacements: &[(&'static str, &str)], + ) -> ModelResult { + let to = to.as_str().parse().map_err(|e| { + // TODO(jmmv): This should never happen... but there is no guarantee right now that we can + // convert III-IV's `EmailAddress` into whatever Lettre expects. It'd be nice if we didn't + // need this though. + ModelError(format!("Cannot parse email address {}: {}", to.as_str(), e)) + })?; + + let subject = template::apply(self.subject_template, replacements); + + let body = Body::new_with_encoding( + template::apply(self.body_template, replacements), + ContentTransferEncoding::QuotedPrintable, + ) + .map_err(|e| ModelError(format!("Failed to encode message: {:?}", e)))?; + + let message = Message::builder() + .from(self.from.clone()) + .to(to) + .subject(subject) + .body(body) + .map_err(|e| ModelError(format!("Failed to encode message: {:?}", e)))?; + Ok(message) + } +} + +/// Utilities to help testing email messages. +#[cfg(any(test, feature = "testutils"))] +pub mod testutils { + use super::*; + use std::collections::HashMap; + + /// Given an SMTP `message`, parses it and extracts its headers and body. + pub fn parse_message(message: &Message) -> (HashMap, String) { + let text = String::from_utf8(message.formatted()).unwrap(); + let (raw_headers, encoded_body) = text + .split_once("\r\n\r\n") + .unwrap_or_else(|| panic!("Message seems to have the wrong format: {}", text)); + + let mut headers = HashMap::default(); + for raw_header in raw_headers.split("\r\n") { + let (key, value) = raw_header + .split_once(": ") + .unwrap_or_else(|| panic!("Header seems to have the wrong format: {}", raw_header)); + let previous = headers.insert(key.to_owned(), value.to_owned()); + assert!(previous.is_none(), "Duplicate header {}", raw_header); + } + + let decoded_body = + quoted_printable::decode(encoded_body, quoted_printable::ParseMode::Strict).unwrap(); + let body = String::from_utf8(decoded_body).unwrap().replace("\r\n", "\n"); + + (headers, body) + } +} + +#[cfg(test)] +mod tests { + use super::testutils::*; + use super::*; + + #[test] + fn test_email_template() { + let template = EmailTemplate { + from: "Sender ".parse().unwrap(), + subject_template: "The %s%", + body_template: "The %b% with quoted printable =50 characters", + }; + + let message = template + .apply( + &EmailAddress::from("recipient@example.com"), + &[("s", "replaced subject"), ("b", "replaced body")], + ) + .unwrap(); + let (headers, body) = parse_message(&message); + + let exp_message = Message::builder() + .from(template.from) + .to("recipient@example.com".parse().unwrap()) + .subject("The replaced subject") + .body( + Body::new_with_encoding( + "The replaced body with quoted printable =50 characters".to_owned(), + ContentTransferEncoding::QuotedPrintable, + ) + .unwrap(), + ) + .unwrap(); + let (exp_headers, exp_body) = parse_message(&exp_message); + + assert_eq!(exp_headers, headers); + assert_eq!(exp_body, body); + } + + #[test] + fn test_parse_message() { + let exp_body = " +This is a sample message with a line that should be longer than 72 characters to test line wraps. + +There is also a second paragraph with = quoted printable characters. +"; + let message = Message::builder() + .from("From someone ".parse().unwrap()) + .to("to@example.com".parse().unwrap()) + .subject("This: is the: subject line") + .body(exp_body.to_owned()) + .unwrap(); + + // Make sure the encoding of the message is quoted-printable. This isn't strictly required + // because I suppose `parse_message` might succeed anyway, but it's good to encode our + // assumption in a test. + let text = String::from_utf8(message.formatted()).unwrap(); + assert!(text.contains("=3D")); + + let (headers, body) = parse_message(&message); + + assert!(headers.len() >= 3); + assert_eq!("\"From someone\" ", headers.get("From").unwrap()); + assert_eq!("to@example.com", headers.get("To").unwrap()); + assert_eq!("This: is the: subject line", headers.get("Subject").unwrap()); + + assert_eq!(exp_body, body); + } +}