diff --git a/smtp/Cargo.toml b/smtp/Cargo.toml index 6dce4fb..ac698fd 100644 --- a/smtp/Cargo.toml +++ b/smtp/Cargo.toml @@ -7,13 +7,16 @@ edition = "2021" publish = false [features] -default = [] -testutils = ["dep:futures", "dep:quoted_printable"] +default = ["postgres"] +postgres = ["dep:sqlx", "iii-iv-core/postgres", "sqlx/postgres"] +sqlite = ["dep:sqlx", "iii-iv-core/sqlite", "sqlx/sqlite"] +testutils = ["dep:futures", "dep:env_logger", "dep:quoted_printable", "iii-iv-core/sqlite"] [dependencies] async-trait = { workspace = true } axum = { workspace = true } derivative = { workspace = true } +env_logger = { workspace = true, optional = true } futures = { workspace = true, optional = true } http = { workspace = true } iii-iv-core = { path = "../core" } @@ -26,9 +29,20 @@ time = { workspace = true } workspace = true features = ["builder", "hostname", "pool", "rustls-tls", "smtp-transport", "tokio1-rustls-tls"] +[dependencies.sqlx] +version = "0.7" +optional = true +features = ["runtime-tokio-rustls", "time"] + [dev-dependencies] +env_logger = { workspace = true } futures = { workspace = true } -iii-iv-core = { path = "../core", features = ["testutils"] } +iii-iv-core = { path = "../core", features = ["sqlite", "testutils"] } quoted_printable = { workspace = true } temp-env = { workspace = true } -tokio = { workspace = true, features = ["macros"] } +time = { workspace = true, features = ["macros"] } +tokio = { workspace = true, features = ["macros", "rt-multi-thread"] } + +[dev-dependencies.sqlx] +workspace = true +features = ["runtime-tokio-rustls", "sqlite", "time"] diff --git a/smtp/src/db/mod.rs b/smtp/src/db/mod.rs new file mode 100644 index 0000000..9b7d322 --- /dev/null +++ b/smtp/src/db/mod.rs @@ -0,0 +1,217 @@ +// III-IV +// Copyright 2023 Julio Merino +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy +// of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +//! Database abstraction to track email submissions. + +#[cfg(test)] +use futures::TryStreamExt; +#[cfg(feature = "postgres")] +use iii_iv_core::db::postgres; +#[cfg(test)] +use iii_iv_core::db::sqlite::build_timestamp; +#[cfg(any(feature = "sqlite", test))] +use iii_iv_core::db::sqlite::{self, unpack_timestamp}; +use iii_iv_core::db::{count_as_usize, ensure_one_upsert, DbResult, Executor}; +use lettre::Message; +use sqlx::Row; +use time::{Date, OffsetDateTime}; + +#[cfg(test)] +mod tests; + +/// Initializes the database schema. +pub async fn init_schema(ex: &mut Executor) -> DbResult<()> { + match ex { + #[cfg(feature = "postgres")] + Executor::Postgres(ref mut ex) => { + postgres::run_schema(ex, include_str!("postgres.sql")).await + } + + #[cfg(any(feature = "sqlite", test))] + Executor::Sqlite(ref mut ex) => sqlite::run_schema(ex, include_str!("sqlite.sql")).await, + + #[allow(unused)] + _ => unreachable!(), + } +} + +/// Counts how many emails were sent on `day`. +pub(crate) async fn count_email_log(ex: &mut Executor, day: Date) -> DbResult { + let total: i64 = match ex { + Executor::Postgres(ref mut ex) => { + let from = day.midnight().assume_utc(); + let to = from + time::Duration::DAY; + + let query_str = + "SELECT COUNT(*) AS total FROM email_log WHERE sent >= $1 AND sent < $2"; + let row = sqlx::query(query_str) + .bind(from) + .bind(to) + .fetch_one(ex) + .await + .map_err(postgres::map_sqlx_error)?; + row.try_get("total").map_err(postgres::map_sqlx_error)? + } + + #[cfg(any(test, feature = "sqlite"))] + Executor::Sqlite(ref mut ex) => { + let from = day.midnight().assume_utc(); + let to = from + time::Duration::DAY; + + let (from_sec, from_nsec) = unpack_timestamp(from); + let (to_sec, to_nsec) = unpack_timestamp(to); + + let query_str = " + SELECT COUNT(*) AS total + FROM email_log + WHERE + (sent_sec >= ? OR (sent_sec = ? AND sent_nsec >= ?)) + AND (sent_sec < ? OR (sent_sec = ? AND sent_nsec < ?)) + "; + let row = sqlx::query(query_str) + .bind(from_sec) + .bind(from_sec) + .bind(from_nsec) + .bind(to_sec) + .bind(to_sec) + .bind(to_nsec) + .fetch_one(ex) + .await + .map_err(sqlite::map_sqlx_error)?; + row.try_get("total").map_err(sqlite::map_sqlx_error)? + } + + #[allow(unused)] + _ => unreachable!(), + }; + count_as_usize(total) +} + +/// En entry in the email log. +#[cfg(test)] +type EmailLogEntry = (OffsetDateTime, Vec, Option); + +/// Gets all entries in the email log. +#[cfg(test)] +pub(crate) async fn get_email_log(ex: &mut Executor) -> DbResult> { + let mut entries = vec![]; + match ex { + Executor::Postgres(ref mut ex) => { + let query_str = "SELECT sent, message, result FROM email_log"; + let mut rows = sqlx::query(query_str).fetch(ex); + while let Some(row) = rows.try_next().await.map_err(postgres::map_sqlx_error)? { + let sent: OffsetDateTime = row.try_get("sent").map_err(postgres::map_sqlx_error)?; + let message: Vec = row.try_get("message").map_err(postgres::map_sqlx_error)?; + let result: Option = + row.try_get("result").map_err(postgres::map_sqlx_error)?; + + entries.push((sent, message, result)); + } + } + + #[cfg(any(test, feature = "sqlite"))] + Executor::Sqlite(ref mut ex) => { + let query_str = "SELECT sent_sec, sent_nsec, message, result FROM email_log"; + let mut rows = sqlx::query(query_str).fetch(ex); + while let Some(row) = rows.try_next().await.map_err(sqlite::map_sqlx_error)? { + let sent_sec: i64 = row.try_get("sent_sec").map_err(sqlite::map_sqlx_error)?; + let sent_nsec: i64 = row.try_get("sent_nsec").map_err(sqlite::map_sqlx_error)?; + let message: Vec = row.try_get("message").map_err(sqlite::map_sqlx_error)?; + let result: Option = + row.try_get("result").map_err(sqlite::map_sqlx_error)?; + + let sent = build_timestamp(sent_sec, sent_nsec)?; + + entries.push((sent, message, result)) + } + } + + #[allow(unused)] + _ => unreachable!(), + } + Ok(entries) +} + +/// Records that an email was sent to `email` at time `now`. +pub(crate) async fn put_email_log( + ex: &mut Executor, + message: &Message, + now: OffsetDateTime, +) -> DbResult { + match ex { + Executor::Postgres(ref mut ex) => { + let query_str = "INSERT INTO email_log (sent, message) VALUES ($1, $2) RETURNING id"; + let row = sqlx::query(query_str) + .bind(now) + .bind(message.formatted()) + .fetch_one(ex) + .await + .map_err(postgres::map_sqlx_error)?; + let last_insert_id: i64 = row.try_get("id").map_err(postgres::map_sqlx_error)?; + Ok(last_insert_id) + } + + #[cfg(any(test, feature = "sqlite"))] + Executor::Sqlite(ref mut ex) => { + let (now_sec, now_nsec) = unpack_timestamp(now); + + let query_str = "INSERT INTO email_log (sent_sec, sent_nsec, message) VALUES (?, ?, ?)"; + let done = sqlx::query(query_str) + .bind(now_sec) + .bind(now_nsec) + .bind(message.formatted()) + .execute(ex) + .await + .map_err(sqlite::map_sqlx_error)?; + Ok(done.last_insert_rowid()) + } + + #[allow(unused)] + _ => unreachable!(), + } +} + +/// Records the result of sending an email. +pub(crate) async fn update_email_log(ex: &mut Executor, id: i64, result: &str) -> DbResult<()> { + match ex { + Executor::Postgres(ref mut ex) => { + let query_str = "UPDATE email_log SET result = $1 WHERE id = $2"; + let done = sqlx::query(query_str) + .bind(result) + .bind(id) + .execute(ex) + .await + .map_err(postgres::map_sqlx_error)?; + ensure_one_upsert(done.rows_affected())?; + Ok(()) + } + + #[cfg(any(test, feature = "sqlite"))] + Executor::Sqlite(ref mut ex) => { + let query_str = "UPDATE email_log SET result = ? WHERE id = ?"; + let done = sqlx::query(query_str) + .bind(result) + .bind(id) + .execute(ex) + .await + .map_err(sqlite::map_sqlx_error)?; + ensure_one_upsert(done.rows_affected())?; + Ok(()) + } + + #[allow(unused)] + _ => unreachable!(), + } +} diff --git a/smtp/src/db/postgres.sql b/smtp/src/db/postgres.sql new file mode 100644 index 0000000..84774df --- /dev/null +++ b/smtp/src/db/postgres.sql @@ -0,0 +1,24 @@ +-- III-IV +-- Copyright 2023 Julio Merino +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); you may not +-- use this file except in compliance with the License. You may obtain a copy +-- of the License at: +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +-- WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +-- License for the specific language governing permissions and limitations +-- under the License. + +CREATE TABLE IF NOT EXISTS email_log ( + id SERIAL PRIMARY KEY, + + sent TIMESTAMPTZ NOT NULL, + message BYTEA NOT NULL, + result TEXT +); + +CREATE INDEX email_log_by_sent ON email_log (sent); diff --git a/smtp/src/db/sqlite.sql b/smtp/src/db/sqlite.sql new file mode 100644 index 0000000..f2fba70 --- /dev/null +++ b/smtp/src/db/sqlite.sql @@ -0,0 +1,27 @@ +-- III-IV +-- Copyright 2023 Julio Merino +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); you may not +-- use this file except in compliance with the License. You may obtain a copy +-- of the License at: +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +-- WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +-- License for the specific language governing permissions and limitations +-- under the License. + +PRAGMA foreign_keys = ON; + +CREATE TABLE IF NOT EXISTS email_log ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + + sent_sec INTEGER NOT NULL, + sent_nsec INTEGER NOT NULL, + message BYTEA NOT NULL, + result TEXT +); + +CREATE INDEX email_log_by_sent ON email_log (sent_sec, sent_nsec); diff --git a/smtp/src/db/tests.rs b/smtp/src/db/tests.rs new file mode 100644 index 0000000..39dd650 --- /dev/null +++ b/smtp/src/db/tests.rs @@ -0,0 +1,87 @@ +// III-IV +// Copyright 2023 Julio Merino +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy +// of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +//! Common tests for any database implementation. + +use crate::db::*; +use iii_iv_core::db::Executor; +use time::macros::{date, datetime}; + +async fn test_email_log(ex: &mut Executor) { + // The message contents should be completely irrelevant for counting purposes, so keeping + // them all identical helps assert that. + let message = Message::builder() + .from("from@example.com".parse().unwrap()) + .to("to@example.com".parse().unwrap()) + .subject("Foo") + .body("Bar".to_owned()) + .unwrap(); + + put_email_log(ex, &message, datetime!(2023-06-11 00:00:00.000000 UTC)).await.unwrap(); + put_email_log(ex, &message, datetime!(2023-06-12 06:20:00.000001 UTC)).await.unwrap(); + put_email_log(ex, &message, datetime!(2023-06-12 06:20:00.000002 UTC)).await.unwrap(); + put_email_log(ex, &message, datetime!(2023-06-12 23:59:59.999999 UTC)).await.unwrap(); + + assert_eq!(0, count_email_log(ex, date!(2023 - 06 - 10)).await.unwrap()); + assert_eq!(1, count_email_log(ex, date!(2023 - 06 - 11)).await.unwrap()); + assert_eq!(3, count_email_log(ex, date!(2023 - 06 - 12)).await.unwrap()); + assert_eq!(0, count_email_log(ex, date!(2023 - 06 - 13)).await.unwrap()); +} + +macro_rules! generate_db_tests [ + ( $setup:expr $(, #[$extra:meta] )? ) => { + iii_iv_core::db::testutils::generate_tests!( + $(#[$extra],)? + $setup, + $crate::db::tests, + test_email_log + ); + } +]; + +use generate_db_tests; + +mod postgres { + use super::*; + use crate::db::init_schema; + use iii_iv_core::db::postgres::PostgresDb; + use iii_iv_core::db::Db; + + async fn setup() -> PostgresDb { + let db = iii_iv_core::db::postgres::testutils::setup().await; + init_schema(&mut db.ex().await.unwrap()).await.unwrap(); + db + } + + generate_db_tests!( + &mut setup().await.ex().await.unwrap(), + #[ignore = "Requires environment configuration and is expensive"] + ); +} + +mod sqlite { + use super::*; + use crate::db::init_schema; + use iii_iv_core::db::sqlite::SqliteDb; + use iii_iv_core::db::Db; + + async fn setup() -> SqliteDb { + let db = iii_iv_core::db::sqlite::testutils::setup().await; + init_schema(&mut db.ex().await.unwrap()).await.unwrap(); + db + } + + generate_db_tests!(&mut setup().await.ex().await.unwrap()); +} diff --git a/smtp/src/driver/mod.rs b/smtp/src/driver/mod.rs index 5906551..aab39ac 100644 --- a/smtp/src/driver/mod.rs +++ b/smtp/src/driver/mod.rs @@ -15,13 +15,17 @@ //! Utilities to send messages over email. +use crate::db::{count_email_log, put_email_log, update_email_log}; use async_trait::async_trait; use derivative::Derivative; +use iii_iv_core::clocks::Clock; +use iii_iv_core::db::Db; use iii_iv_core::driver::{DriverError, DriverResult}; -use iii_iv_core::env::get_required_var; -pub use lettre::message::{Mailbox, Message}; +use iii_iv_core::env::{get_optional_var, get_required_var}; +use lettre::message::Message; use lettre::transport::smtp::authentication::Credentials; use lettre::{AsyncSmtpTransport, AsyncTransport, Tokio1Executor}; +use std::sync::Arc; #[cfg(any(test, feature = "testutils"))] pub mod testutils; @@ -40,19 +44,23 @@ pub struct SmtpOptions { /// Password for logging into the SMTP server. #[derivative(Debug = "ignore")] pub password: String, + + /// Maximum number of messages to send per day, if any. + pub max_daily_emails: Option, } impl SmtpOptions { /// Initializes a set of options from environment variables whose name is prefixed with the /// given `prefix`. /// - /// This will use variables such as `_RELAY`, `_USERNAME` and - /// `_PASSWORD`. + /// This will use variables such as `_RELAY`, `_USERNAME`, `_PASSWORD` + /// and `_MAX_DAILY_EMAILS`. pub fn from_env(prefix: &str) -> Result { Ok(Self { relay: get_required_var::(prefix, "RELAY")?, username: get_required_var::(prefix, "USERNAME")?, password: get_required_var::(prefix, "PASSWORD")?, + max_daily_emails: get_optional_var::(prefix, "MAX_DAILY_EMAILS")?, }) } } @@ -70,9 +78,9 @@ pub struct LettreSmtpMailer(AsyncSmtpTransport); impl LettreSmtpMailer { /// Establishes a connection to the SMTP server. - pub fn connect(opts: SmtpOptions) -> Result { - let creds = Credentials::new(opts.username, opts.password); - let mailer = AsyncSmtpTransport::::relay(&opts.relay) + fn connect(relay: &str, username: String, password: String) -> Result { + let creds = Credentials::new(username, password); + let mailer = AsyncSmtpTransport::::relay(relay) .map_err(|e| format!("{}", e))? .credentials(creds) .build(); @@ -91,17 +99,131 @@ impl SmtpMailer for LettreSmtpMailer { } } +/// Encapsulates logic to send email messages while respecting quotas. +#[derive(Clone)] +pub struct SmtpDriver { + /// The SMTP transport with which to send email messages. + transport: T, + + /// The database with which to track sent messages. + db: Arc, + + /// The clock from which to obtain the current time. + clock: Arc, + + /// Maximum number of messages to send per day, if any. + max_daily_emails: Option, +} + +impl SmtpDriver { + /// Creates a new driver with the given values. + pub fn new( + transport: T, + db: Arc, + clock: Arc, + max_daily_emails: Option, + ) -> Self { + Self { transport, db, clock, max_daily_emails } + } + + /// Obtains a reference to the wrapped SMTP transport. + pub fn get_transport(&self) -> &T { + &self.transport + } +} + +#[async_trait] +impl SmtpMailer for SmtpDriver +where + T: SmtpMailer + Send + Sync, +{ + /// Sends an email message after recording it and accounting for it for quota purposes. + async fn send(&self, message: Message) -> DriverResult<()> { + let mut tx = self.db.begin().await?; + let now = self.clock.now_utc(); + + // We must insert into the table first, before counting, to grab an exclusive transaction + // lock. Otherwise the count will be stale by the time we use it. + let id = put_email_log(tx.ex(), &message, now).await?; + + if let Some(max_daily_emails) = self.max_daily_emails { + let daily_emails = count_email_log(tx.ex(), now.date()).await? - 1; + if daily_emails >= max_daily_emails { + let msg = format!( + "Too many emails sent today ({} >= {})", + daily_emails, max_daily_emails, + ); + update_email_log(tx.ex(), id, &msg).await?; + return Err(DriverError::NoSpace(msg)); + } + } + + // Commit the transaction _before_ trying to send the email. This is intentional to ignore + // errors from the server because we don't know if errors are counted towards the daily + // quota. Furthermore, this avoids sequencing email submissions if the server is slow. + tx.commit().await?; + + let result = self.transport.send(message).await; + + match result { + Ok(()) => update_email_log(&mut self.db.ex().await?, id, "OK").await?, + Err(ref e) => update_email_log(&mut self.db.ex().await?, id, &format!("{}", e)).await?, + } + + result + } +} + +/// Creates a new SMTP driver that sends email messages via the service configured in `opts`. +/// +/// `db` and `clock` are used to keep track of the messages that have been sent for quota +/// accounting purposes. +pub fn new_prod_driver( + opts: SmtpOptions, + db: Arc, + clock: Arc, +) -> Result, String> { + let transport = LettreSmtpMailer::connect(&opts.relay, opts.username, opts.password)?; + Ok(SmtpDriver::new(transport, db, clock, opts.max_daily_emails)) +} + #[cfg(test)] mod tests { + use super::testutils::*; use super::*; + use crate::db::get_email_log; + use futures::future; use std::env; + use std::time::Duration; + + #[test] + fn test_smtp_options_from_env_all_required_present() { + let overrides = [ + ("SMTP_RELAY", Some("the-relay")), + ("SMTP_USERNAME", Some("the-username")), + ("SMTP_PASSWORD", Some("the-password")), + ]; + temp_env::with_vars(overrides, || { + let opts = SmtpOptions::from_env("SMTP").unwrap(); + assert_eq!( + SmtpOptions { + relay: "the-relay".to_owned(), + username: "the-username".to_owned(), + password: "the-password".to_owned(), + max_daily_emails: None, + }, + opts + ); + }); + } #[test] - fn test_smtp_options_from_env_all_present() { + fn test_smtp_options_from_env_all_required_and_optional_present() { let overrides = [ ("SMTP_RELAY", Some("the-relay")), ("SMTP_USERNAME", Some("the-username")), ("SMTP_PASSWORD", Some("the-password")), + ("SMTP_MAX_DAILY_EMAILS", Some("123")), ]; temp_env::with_vars(overrides, || { let opts = SmtpOptions::from_env("SMTP").unwrap(); @@ -109,7 +231,8 @@ mod tests { SmtpOptions { relay: "the-relay".to_owned(), username: "the-username".to_owned(), - password: "the-password".to_owned() + password: "the-password".to_owned(), + max_daily_emails: Some(123), }, opts ); @@ -131,4 +254,109 @@ mod tests { }); } } + + /// Creates a new email message with hardcoded values. + fn new_message() -> Message { + Message::builder() + .from("from@example.com".parse().unwrap()) + .to("to@example.com".parse().unwrap()) + .subject("Foo") + .body("Bar".to_owned()) + .unwrap() + } + + #[tokio::test] + async fn test_send_ok() { + let mut context = TestContext::setup(None).await; + let exp_message = new_message(); + + context.driver.send(exp_message.clone()).await.unwrap(); + + let message = context.mailer.expect_one_message(&"to@example.com".into()).await; + assert_eq!(exp_message.formatted(), message.formatted()); + + let log = get_email_log(&mut context.ex().await).await.unwrap(); + assert_eq!(1, log.len()); + assert_eq!(exp_message.formatted(), log[0].1); + assert_eq!(Some("OK"), log[0].2.as_deref()); + } + + #[tokio::test] + async fn test_send_error() { + let mut context = TestContext::setup(None).await; + let exp_message = new_message(); + + context.mailer.inject_error_for("to@example.com").await; + let err = context.driver.send(exp_message.clone()).await.unwrap_err(); + assert_eq!("Sending email to to@example.com failed", &format!("{}", err)); + + context.mailer.expect_no_messages().await; + + let log = get_email_log(&mut context.ex().await).await.unwrap(); + assert_eq!(1, log.len()); + assert_eq!(exp_message.formatted(), log[0].1); + assert_eq!(Some("Sending email to to@example.com failed"), log[0].2.as_deref()); + } + + #[tokio::test] + async fn test_daily_limit_enforced_and_clears_every_day() { + let mut context = TestContext::setup(Some(50)).await; + let exp_message = new_message(); + + for _ in 0..50 { + put_email_log(&mut context.ex().await, &exp_message, context.clock.now_utc()) + .await + .unwrap(); + } + + let err = context.driver.send(exp_message.clone()).await.unwrap_err(); + assert_eq!("Too many emails sent today (50 >= 50)", &format!("{}", err)); + context.mailer.expect_no_messages().await; + + // Advance the clock to reach just the 23rd hour of the same day. + let current_hour = u64::from(context.clock.now_utc().hour()); + context.clock.advance(Duration::from_secs((23 - current_hour) * 60 * 60)); + + let err = context.driver.send(exp_message.clone()).await.unwrap_err(); + assert_eq!("Too many emails sent today (50 >= 50)", &format!("{}", err)); + context.mailer.expect_no_messages().await; + + // Push the clock into the next day. + context.clock.advance(Duration::from_secs(60 * 60)); + + context.driver.send(exp_message.clone()).await.unwrap(); + let message = context.mailer.expect_one_message(&"to@example.com".into()).await; + assert_eq!(exp_message.formatted(), message.formatted()); + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_daily_limit_concurrency() { + let context = TestContext::setup(Some(10)).await; + let exp_message = new_message(); + + let mut futures = Vec::with_capacity(1000); + for _ in 0..1000 { + futures.push(async { + match context.driver.send(exp_message.clone()).await { + Ok(()) => true, + Err(_) => false, + } + }); + } + + let mut count_ok = 0; + let mut count_err = 0; + for ok in future::join_all(futures.into_iter()).await { + if ok { + count_ok += 1; + } else { + count_err += 1; + } + } + assert_eq!(10, count_ok); + assert_eq!(990, count_err); + + let inbox = context.mailer.expect_one_inbox(&"to@example.com".into()).await; + assert_eq!(10, inbox.len()); + } } diff --git a/smtp/src/driver/testutils.rs b/smtp/src/driver/testutils.rs index 5ce05d4..9b20a3f 100644 --- a/smtp/src/driver/testutils.rs +++ b/smtp/src/driver/testutils.rs @@ -24,6 +24,15 @@ use lettre::Message; use std::collections::{HashMap, HashSet}; use std::sync::Arc; +#[cfg(test)] +use { + super::SmtpDriver, + crate::db::init_schema, + iii_iv_core::clocks::testutils::SettableClock, + iii_iv_core::db::{sqlite, Db, Executor}, + time::macros::datetime, +}; + /// Mailer that captures outgoing messages. #[derive(Clone, Default)] pub struct RecorderSmtpMailer { @@ -41,6 +50,12 @@ impl RecorderSmtpMailer { errors.insert(email.into()); } + /// Expects that no messages were sent. + pub async fn expect_no_messages(&self) { + let inboxes = self.inboxes.lock().await; + assert_eq!(0, inboxes.len(), "Expected to find no messages"); + } + /// Expects that messages were sent to `exp_to` and nobody else, and returns the list of /// messages to that recipient. pub async fn expect_one_inbox(&self, exp_to: &EmailAddress) -> Vec { @@ -83,6 +98,43 @@ impl SmtpMailer for RecorderSmtpMailer { } } +/// Container for the state required to run a driver test. +#[cfg(test)] +pub(crate) struct TestContext { + pub(crate) driver: SmtpDriver, + pub(crate) db: Arc, + pub(crate) clock: Arc, + pub(crate) mailer: RecorderSmtpMailer, +} + +#[cfg(test)] +impl TestContext { + pub(crate) async fn setup(max_daily_emails: Option) -> Self { + let _can_fail = env_logger::builder().is_test(true).try_init(); + + let db = Arc::from(sqlite::testutils::setup().await); + let mut ex = db.ex().await.unwrap(); + init_schema(&mut ex).await.unwrap(); + + let clock = Arc::from(SettableClock::new(datetime!(2023-10-17 06:00:00 UTC))); + + let mailer = RecorderSmtpMailer::default(); + + let driver = SmtpDriver { + transport: mailer.clone(), + db: db.clone(), + clock: clock.clone(), + max_daily_emails, + }; + + Self { driver, db, clock, mailer } + } + + pub(crate) async fn ex(&mut self) -> Executor { + self.db.ex().await.unwrap() + } +} + #[cfg(test)] mod tests { use super::*; @@ -118,6 +170,24 @@ mod tests { assert!(inboxes.contains_key(&to3)); } + #[tokio::test] + async fn test_recorder_expect_no_messages_ok() { + let mailer = RecorderSmtpMailer::default(); + mailer.expect_no_messages().await; + } + + #[tokio::test] + async fn test_recorder_expect_no_messages_fail() { + #[tokio::main(flavor = "current_thread")] + async fn do_test() { + let to1 = EmailAddress::from("to1@example.com"); + let mailer = RecorderSmtpMailer::default(); + mailer.send(new_message(&to1)).await.unwrap(); + mailer.expect_no_messages().await; // Will panic. + } + assert!(catch_unwind(do_test).is_err()); + } + #[tokio::test] async fn test_recorder_expect_one_inbox_ok() { let to = EmailAddress::from("to@example.com"); diff --git a/smtp/src/lib.rs b/smtp/src/lib.rs index 4813e8b..00cd6e5 100644 --- a/smtp/src/lib.rs +++ b/smtp/src/lib.rs @@ -20,5 +20,6 @@ #![warn(unused, unused_extern_crates, unused_import_braces, unused_qualifications)] #![warn(unsafe_code)] +pub mod db; pub mod driver; pub mod model;