diff --git a/Cargo.lock b/Cargo.lock index 27d4c3088..78c1f44c1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -283,6 +283,7 @@ dependencies = [ "tokio", "tokio-stream", "tonic", + "tower", "tracing", "tracing-core", "tracing-subscriber", diff --git a/console-subscriber/Cargo.toml b/console-subscriber/Cargo.toml index d3c414e3b..b79e6e3d5 100644 --- a/console-subscriber/Cargo.toml +++ b/console-subscriber/Cargo.toml @@ -55,6 +55,7 @@ crossbeam-channel = "0.5" [dev-dependencies] tokio = { version = "^1.21", features = ["full", "rt-multi-thread"] } +tower = "0.4" futures = "0.3" [package.metadata.docs.rs] diff --git a/console-subscriber/src/aggregator/id_data.rs b/console-subscriber/src/aggregator/id_data.rs index b9010b445..2ad2c74b0 100644 --- a/console-subscriber/src/aggregator/id_data.rs +++ b/console-subscriber/src/aggregator/id_data.rs @@ -104,18 +104,18 @@ impl IdData { if let Some(dropped_at) = stats.dropped_at() { let dropped_for = now.checked_duration_since(dropped_at).unwrap_or_default(); let dirty = stats.is_unsent(); - let should_drop = + let should_retain = // if there are any clients watching, retain all dirty tasks regardless of age (dirty && has_watchers) - || dropped_for > retention; + || dropped_for <= retention; tracing::trace!( stats.id = ?id, stats.dropped_at = ?dropped_at, stats.dropped_for = ?dropped_for, stats.dirty = dirty, - should_drop, + should_retain, ); - return !should_drop; + return should_retain; } true diff --git a/console-subscriber/src/lib.rs b/console-subscriber/src/lib.rs index 6b0c1a75e..7b6daecf2 100644 --- a/console-subscriber/src/lib.rs +++ b/console-subscriber/src/lib.rs @@ -1,6 +1,6 @@ #![doc = include_str!("../README.md")] use console_api as proto; -use proto::resources::resource; +use proto::{instrument::instrument_server::InstrumentServer, resources::resource}; use serde::Serialize; use std::{ cell::RefCell, @@ -15,7 +15,10 @@ use std::{ use thread_local::ThreadLocal; #[cfg(unix)] use tokio::net::UnixListener; -use tokio::sync::{mpsc, oneshot}; +use tokio::{ + sync::{mpsc, oneshot}, + task::JoinHandle, +}; #[cfg(unix)] use tokio_stream::wrappers::UnixListenerStream; use tracing_core::{ @@ -933,18 +936,15 @@ impl Server { /// /// [`tonic`]: https://docs.rs/tonic/ pub async fn serve_with( - mut self, + self, mut builder: tonic::transport::Server, ) -> Result<(), Box> { - let aggregate = self - .aggregator - .take() - .expect("cannot start server multiple times"); - let aggregate = spawn_named(aggregate.run(), "console::aggregate"); let addr = self.addr.clone(); - let router = builder.add_service( - proto::instrument::instrument_server::InstrumentServer::new(self), - ); + let ServerParts { + instrument_server: service, + aggregator_handle: aggregate, + } = self.into_parts(); + let router = builder.add_service(service); let res = match addr { ServerAddr::Tcp(addr) => { let serve = router.serve(addr); @@ -957,9 +957,110 @@ impl Server { spawn_named(serve, "console::serve").await } }; - aggregate.abort(); + drop(aggregate); res?.map_err(Into::into) } + + /// Returns the parts needed to spawn a gRPC server and keep the aggregation + /// worker running. + /// + /// Note that a server spawned in this way will overwrite any value set by + /// [`Builder::server_addr`] as the user becomes responsible for defining + /// the address when calling [`Router::serve`]. + /// + /// # Examples + /// + /// The parts can be used to serve the instrument server together with + /// other endpoints from the same gRPC server. + /// + /// ``` + /// use console_subscriber::{ConsoleLayer, ServerParts}; + /// + /// # let runtime = tokio::runtime::Builder::new_current_thread() + /// # .enable_all() + /// # .build() + /// # .unwrap(); + /// # runtime.block_on(async { + /// let (console_layer, server) = ConsoleLayer::builder().build(); + /// let ServerParts { + /// instrument_server, + /// aggregator_handle, + /// .. + /// } = server.into_parts(); + /// + /// let router = tonic::transport::Server::builder() + /// //.add_service(some_other_service) + /// .add_service(instrument_server); + /// let serve = router.serve(std::net::SocketAddr::new( + /// std::net::IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 1)), + /// 6669, + /// )); + /// + /// // Finally, spawn the server. + /// tokio::spawn(serve); + /// # // Avoid a warning that `console_layer` and `aggregator_handle` are unused. + /// # drop(console_layer); + /// # drop(aggregator_handle); + /// # }); + /// ``` + /// + /// [`Router::serve`]: fn@tonic::transport::server::Router::serve + pub fn into_parts(mut self) -> ServerParts { + let aggregate = self + .aggregator + .take() + .expect("cannot start server multiple times"); + let aggregate = spawn_named(aggregate.run(), "console::aggregate"); + + let service = proto::instrument::instrument_server::InstrumentServer::new(self); + + ServerParts { + instrument_server: service, + aggregator_handle: AggregatorHandle { + join_handle: aggregate, + }, + } + } +} + +/// Server Parts +/// +/// This struct contains the parts returned by [`Server::into_parts`]. It may contain +/// further parts in the future, an as such is marked as `non_exhaustive`. +/// +/// The `InstrumentServer` can be used to construct a router which +/// can be added to a [`tonic`] gRPC server. +/// +/// The [`AggregatorHandle`] must be kept until after the server has been +/// shut down. +/// +/// See the [`Server::into_parts`] documentation for usage. +#[non_exhaustive] +pub struct ServerParts { + /// The instrument server. + /// + /// See the documentation for [`InstrumentServer`] for details. + pub instrument_server: InstrumentServer, + + /// The aggregate handle. + /// + /// See the documentation for [`AggregatorHandle`] for details. + pub aggregator_handle: AggregatorHandle, +} + +/// Aggregator handle. +/// +/// This object is returned from [`Server::into_parts`] and must be +/// kept as long as the `InstrumentServer` - which is also +/// returned - is in use. +pub struct AggregatorHandle { + join_handle: JoinHandle<()>, +} + +impl Drop for AggregatorHandle { + fn drop(&mut self) { + self.join_handle.abort(); + } } #[tonic::async_trait] diff --git a/console-subscriber/tests/framework.rs b/console-subscriber/tests/framework.rs new file mode 100644 index 000000000..f027e29c7 --- /dev/null +++ b/console-subscriber/tests/framework.rs @@ -0,0 +1,184 @@ +//! Framework tests +//! +//! The tests in this module are here to verify the testing framework itself. +//! As such, some of these tests may be repeated elsewhere (where we wish to +//! actually test the functionality of `console-subscriber`) and others are +//! negative tests that should panic. + +use std::time::Duration; + +use tokio::{task, time::sleep}; + +mod support; +use support::{assert_task, assert_tasks, ExpectedTask, MAIN_TASK_NAME}; + +#[test] +fn expect_present() { + let expected_task = ExpectedTask::default() + .match_default_name() + .expect_present(); + + let future = async { + sleep(Duration::ZERO).await; + }; + + assert_task(expected_task, future); +} + +#[test] +#[should_panic(expected = "Test failed: Task validation failed: + - Task: no expectations set, if you want to just expect that a matching task is present, use `expect_present()` +")] +fn fail_no_expectations() { + let expected_task = ExpectedTask::default().match_default_name(); + + let future = async { + sleep(Duration::ZERO).await; + }; + + assert_task(expected_task, future); +} + +#[test] +fn wakes() { + let expected_task = ExpectedTask::default().match_default_name().expect_wakes(1); + + let future = async { + sleep(Duration::ZERO).await; + }; + + assert_task(expected_task, future); +} + +#[test] +#[should_panic(expected = "Test failed: Task validation failed: + - Task: expected `wakes` to be 5, but actual was 1 +")] +fn fail_wakes() { + let expected_task = ExpectedTask::default().match_default_name().expect_wakes(5); + + let future = async { + sleep(Duration::ZERO).await; + }; + + assert_task(expected_task, future); +} + +#[test] +fn self_wakes() { + let expected_task = ExpectedTask::default() + .match_default_name() + .expect_self_wakes(1); + + let future = async { task::yield_now().await }; + + assert_task(expected_task, future); +} + +#[test] +#[should_panic(expected = "Test failed: Task validation failed: + - Task: expected `self_wakes` to be 1, but actual was 0 +")] +fn fail_self_wake() { + let expected_task = ExpectedTask::default() + .match_default_name() + .expect_self_wakes(1); + + let future = async { + sleep(Duration::ZERO).await; + }; + + assert_task(expected_task, future); +} + +#[test] +fn test_spawned_task() { + let expected_task = ExpectedTask::default() + .match_name("another-name".into()) + .expect_present(); + + let future = async { + task::Builder::new() + .name("another-name") + .spawn(async { task::yield_now().await }) + }; + + assert_task(expected_task, future); +} + +#[test] +#[should_panic(expected = "Test failed: Task validation failed: + - Task: no matching actual task was found +")] +fn fail_wrong_task_name() { + let expected_task = ExpectedTask::default().match_name("wrong-name".into()); + + let future = async { task::yield_now().await }; + + assert_task(expected_task, future); +} + +#[test] +fn multiple_tasks() { + let expected_tasks = vec![ + ExpectedTask::default() + .match_name("task-1".into()) + .expect_wakes(1), + ExpectedTask::default() + .match_name("task-2".into()) + .expect_wakes(1), + ]; + + let future = async { + let task1 = task::Builder::new() + .name("task-1") + .spawn(async { task::yield_now().await }) + .unwrap(); + let task2 = task::Builder::new() + .name("task-2") + .spawn(async { task::yield_now().await }) + .unwrap(); + + tokio::try_join! { + task1, + task2, + } + .unwrap(); + }; + + assert_tasks(expected_tasks, future); +} + +#[test] +#[should_panic(expected = "Test failed: Task validation failed: + - Task: expected `wakes` to be 2, but actual was 1 +")] +fn fail_1_of_2_expected_tasks() { + let expected_tasks = vec![ + ExpectedTask::default() + .match_name("task-1".into()) + .expect_wakes(1), + ExpectedTask::default() + .match_name("task-2".into()) + .expect_wakes(2), + ]; + + let future = async { + let task1 = task::Builder::new() + .name("task-1") + .spawn(async { task::yield_now().await }) + .unwrap(); + let task2 = task::Builder::new() + .name("task-2") + .spawn(async { task::yield_now().await }) + .unwrap(); + + tokio::try_join! { + task1, + task2, + } + .unwrap(); + }; + + assert_tasks(expected_tasks, future); +} diff --git a/console-subscriber/tests/support/mod.rs b/console-subscriber/tests/support/mod.rs new file mode 100644 index 000000000..4937aff6a --- /dev/null +++ b/console-subscriber/tests/support/mod.rs @@ -0,0 +1,47 @@ +use futures::Future; + +mod state; +mod subscriber; +mod task; + +use subscriber::run_test; + +pub(crate) use subscriber::MAIN_TASK_NAME; +pub(crate) use task::ExpectedTask; + +/// Assert that an `expected_task` is recorded by a console-subscriber +/// when driving the provided `future` to completion. +/// +/// This function is equivalent to calling [`assert_tasks`] with a vector +/// containing a single task. +/// +/// # Panics +/// +/// This function will panic if the expectations on the expected task are not +/// met or if a matching task is not recorded. +#[track_caller] +#[allow(dead_code)] +pub(crate) fn assert_task(expected_task: ExpectedTask, future: Fut) +where + Fut: Future + Send + 'static, + Fut::Output: Send + 'static, +{ + run_test(vec![expected_task], future) +} + +/// Assert that the `expected_tasks` are recorded by a console-subscriber +/// when driving the provided `future` to completion. +/// +/// # Panics +/// +/// This function will panic if the expectations on any of the expected tasks +/// are not met or if matching tasks are not recorded for all expected tasks. +#[track_caller] +#[allow(dead_code)] +pub(crate) fn assert_tasks(expected_tasks: Vec, future: Fut) +where + Fut: Future + Send + 'static, + Fut::Output: Send + 'static, +{ + run_test(expected_tasks, future) +} diff --git a/console-subscriber/tests/support/state.rs b/console-subscriber/tests/support/state.rs new file mode 100644 index 000000000..ebdc75b5d --- /dev/null +++ b/console-subscriber/tests/support/state.rs @@ -0,0 +1,112 @@ +use std::fmt; + +use tokio::sync::broadcast::{ + self, + error::{RecvError, TryRecvError}, +}; + +#[derive(Clone, Debug, PartialEq, PartialOrd)] +pub(super) enum TestStep { + Start, + ServerStarted, + ClientConnected, + TestFinished, + UpdatesRecorded, +} + +impl fmt::Display for TestStep { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + (self as &dyn fmt::Debug).fmt(f) + } +} + +pub(super) struct TestState { + receiver: broadcast::Receiver, + sender: broadcast::Sender, + step: TestStep, +} + +impl TestState { + pub(super) fn new() -> Self { + let (sender, receiver) = broadcast::channel(1); + Self { + receiver, + sender, + step: TestStep::Start, + } + } + + pub(super) async fn wait_for_step(&mut self, desired_step: TestStep) { + loop { + if self.step >= desired_step { + break; + } + + match self.receiver.recv().await { + Ok(step) => self.step = step, + Err(RecvError::Lagged(_)) => { + // we don't mind being lagged, we'll just get the latest state + } + Err(RecvError::Closed) => { + panic!("failed to receive current step, waiting for step: {desired_step}, did the test abort?"); + } + } + } + } + + pub(super) fn try_wait_for_step(&mut self, desired_step: TestStep) -> bool { + self.update_step(); + + self.step == desired_step + } + + #[track_caller] + pub(super) fn advance_to_step(&mut self, next_step: TestStep) { + self.update_step(); + + if self.step >= next_step { + panic!( + "cannot advance to previous or current step! current step: {current}, next step: {next_step}", + current = self.step); + } + + match (&self.step, &next_step) { + (TestStep::Start, TestStep::ServerStarted) | + (TestStep::ServerStarted, TestStep::ClientConnected) | + (TestStep::ClientConnected, TestStep::TestFinished) | + (TestStep::TestFinished, TestStep::UpdatesRecorded) => {}, + (_, _) => panic!( + "cannot advance more than one step! current step: {current}, next step: {next_step}", + current = self.step), + } + + self.sender + .send(next_step) + .expect("failed to send the next test step, did the test abort?"); + } + + fn update_step(&mut self) { + loop { + match self.receiver.try_recv() { + Ok(step) => self.step = step, + Err(TryRecvError::Lagged(_)) => { + // we don't mind being lagged, we'll just get the latest state + } + Err(TryRecvError::Closed) => { + panic!("failed to update current step, did the test abort?") + } + Err(TryRecvError::Empty) => break, + } + } + } +} + +impl Clone for TestState { + fn clone(&self) -> Self { + Self { + receiver: self.receiver.resubscribe(), + sender: self.sender.clone(), + step: self.step.clone(), + } + } +} diff --git a/console-subscriber/tests/support/subscriber.rs b/console-subscriber/tests/support/subscriber.rs new file mode 100644 index 000000000..aa510e6fc --- /dev/null +++ b/console-subscriber/tests/support/subscriber.rs @@ -0,0 +1,311 @@ +use std::{collections::HashMap, fmt, future::Future, thread}; + +use console_api::{ + field::Value, + instrument::{instrument_client::InstrumentClient, InstrumentRequest}, +}; +use console_subscriber::ServerParts; +use futures::stream::StreamExt; +use tokio::{io::DuplexStream, task}; +use tonic::transport::{Channel, Endpoint, Server, Uri}; +use tower::service_fn; + +use super::state::{TestState, TestStep}; +use super::task::{ActualTask, ExpectedTask, TaskValidationFailure}; + +pub(crate) const MAIN_TASK_NAME: &str = "main"; + +#[derive(Debug)] +struct TestFailure { + failures: Vec, +} + +impl fmt::Display for TestFailure { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Task validation failed:\n")?; + for failure in &self.failures { + write!(f, " - {failure}\n")?; + } + Ok(()) + } +} + +/// Runs the test +/// +/// This function runs the whole test. It sets up a `console-subscriber` layer +/// together with the gRPC server and connects a client to it. The subscriber +/// is then used to record traces as the provided future is driven to +/// completion on a current thread tokio runtime. +#[track_caller] +pub(super) fn run_test(expected_tasks: Vec, future: Fut) +where + Fut: Future + Send + 'static, + Fut::Output: Send + 'static, +{ + use tracing_subscriber::prelude::*; + + let (client_stream, server_stream) = tokio::io::duplex(1024); + let (console_layer, server) = console_subscriber::ConsoleLayer::builder().build(); + let registry = tracing_subscriber::registry().with(console_layer); + + let mut test_state = TestState::new(); + let mut test_state_test = test_state.clone(); + + let join_handle = thread::Builder::new() + .name("console-subscriber".into()) + .spawn(move || { + let _subscriber_guard = + tracing::subscriber::set_default(tracing_core::subscriber::NoSubscriber::default()); + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_io() + .enable_time() + .build() + .expect("console-test error: failed to initialize console subscriber runtime"); + + runtime.block_on(async move { + task::Builder::new() + .name("console-server") + .spawn(console_server(server, server_stream, test_state.clone())) + .expect("console-test error: could not spawn 'console-server' task"); + + let actual_tasks = task::Builder::new() + .name("console-client") + .spawn(console_client(client_stream, test_state.clone())) + .expect("console-test error: could not spawn 'console-client' task") + .await + .expect("console-test error: failed to await 'console-client' task"); + + test_state.advance_to_step(TestStep::UpdatesRecorded); + actual_tasks + }) + }) + .expect("console subscriber could not spawn thread"); + + tracing::subscriber::with_default(registry, || { + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + + runtime.block_on(async move { + test_state_test + .wait_for_step(TestStep::ClientConnected) + .await; + + // Run the future that we are testing. + _ = tokio::task::Builder::new() + .name(MAIN_TASK_NAME) + .spawn(future) + .expect("console-test error: couldn't spawn test task") + .await; + test_state_test.advance_to_step(TestStep::TestFinished); + + test_state_test + .wait_for_step(TestStep::UpdatesRecorded) + .await; + }); + }); + + let actual_tasks = join_handle + .join() + .expect("console-test error: failed to join 'console-subscriber' thread"); + + if let Err(test_failure) = validate_expected_tasks(expected_tasks, actual_tasks) { + panic!("Test failed: {test_failure}") + } +} + +/// Starts the console server. +/// +/// The server will start serving over its side of the duplex stream. +/// +/// Once the server gets spawned into its task, the test state is advanced +/// to the `ServerStarted` step. This function will then wait until the test +/// state reaches the `UpdatesRecorded` step (indicating that all validation of the +/// received updates has been completed) before dropping the aggregator. +/// +/// # Test State +/// +/// 1. Advances to: `ServerStarted` +/// 2. Waits for: `UpdatesRecorded` +async fn console_server( + server: console_subscriber::Server, + server_stream: DuplexStream, + mut test_state: TestState, +) { + let ServerParts { + instrument_server: service, + aggregator_handle: aggregate, + .. + } = server.into_parts(); + Server::builder() + .add_service(service) + .serve_with_incoming(futures::stream::iter(vec![Ok::<_, std::io::Error>( + server_stream, + )])) + .await + .expect("console subscriber failed."); + test_state.advance_to_step(TestStep::ServerStarted); + + test_state.wait_for_step(TestStep::UpdatesRecorded).await; + drop(aggregate); +} + +/// Starts the console client and validates the expected tasks. +/// +/// First we wait until the server has started (test step `ServerStarted`), then +/// the client is connected to its half of the duplex stream and we start recording +/// the actual tasks. +/// +/// Once recording finishes (see [`record_actual_tasks()`] for details on the test +/// state condition), the actual tasks returned. +/// +/// # Test State +/// +/// 1. Waits for: `ServerStarted` +/// 2. Advances to: `ClientConnected` +async fn console_client(client_stream: DuplexStream, mut test_state: TestState) -> Vec { + test_state.wait_for_step(TestStep::ServerStarted).await; + + let mut client_stream = Some(client_stream); + let channel = Endpoint::try_from("http://[::]:6669") + .expect("Could not create endpoint") + .connect_with_connector(service_fn(move |_: Uri| { + let client = client_stream.take(); + + async move { + if let Some(client) = client { + Ok(client) + } else { + Err(std::io::Error::new( + std::io::ErrorKind::Other, + "Client already taken", + )) + } + } + })) + .await + .expect("client-console error: couldn't create client"); + test_state.advance_to_step(TestStep::ClientConnected); + + record_actual_tasks(channel, test_state.clone()).await +} + +/// Records the actual tasks which are received by the client channel. +/// +/// Updates will be received until the test state reaches the `TestFinished` step +/// (indicating that the test itself has finished running), at which point we wait +/// for a final update before returning all the actual tasks which were recorded. +/// +/// # Test State +/// +/// 1. Waits for: `TestFinished` +async fn record_actual_tasks( + client_channel: Channel, + mut test_state: TestState, +) -> Vec { + let mut client = InstrumentClient::new(client_channel); + + let mut stream = loop { + let request = tonic::Request::new(InstrumentRequest {}); + match client.watch_updates(request).await { + Ok(stream) => break stream.into_inner(), + Err(err) => panic!("Client cannot connect to watch updates: {err}"), + } + }; + + let mut tasks = HashMap::new(); + + let mut last_update = false; + while let Some(update) = stream.next().await { + match update { + Ok(update) => { + if let Some(task_update) = &update.task_update { + for new_task in &task_update.new_tasks { + if let Some(id) = &new_task.id { + let mut actual_task = ActualTask::new(id.id); + for field in &new_task.fields { + if let Some(console_api::field::Name::StrName(field_name)) = + &field.name + { + if field_name == "task.name" { + actual_task.name = match &field.value { + Some(Value::DebugVal(value)) => Some(value.clone()), + Some(Value::StrVal(value)) => Some(value.clone()), + _ => None, // Anything that isn't string-like shouldn't be used as a name. + }; + } + } + } + tasks.insert(actual_task.id, actual_task); + } + } + + for (id, stats) in &task_update.stats_update { + if let Some(mut task) = tasks.get_mut(id) { + task.wakes = stats.wakes; + task.self_wakes = stats.self_wakes; + } + } + } + } + Err(e) => { + panic!("update stream error: {}", e); + } + } + + if last_update { + break; + } + + if test_state.try_wait_for_step(TestStep::TestFinished) { + // Once the test finishes running, we will get one further update and finish. + last_update = true; + } + } + + tasks.into_values().collect() +} + +/// Validate the expected tasks against the actual tasks. +/// +/// Each expected task is checked in turn. +/// +/// A matching actual task is searched for. If one is found it, the +/// expected task is validated against the actual task. +/// +/// Any validation errors result in failure. If no matches +fn validate_expected_tasks( + expected_tasks: Vec, + actual_tasks: Vec, +) -> Result<(), TestFailure> { + let failures: Vec<_> = expected_tasks + .iter() + .map(|expected| validate_expected_task(expected, &actual_tasks)) + .filter_map(|r| match r { + Ok(_) => None, + Err(validation_error) => Some(validation_error), + }) + .collect(); + + if failures.is_empty() { + Ok(()) + } else { + Err(TestFailure { failures: failures }) + } +} + +fn validate_expected_task( + expected: &ExpectedTask, + actual_tasks: &Vec, +) -> Result<(), TaskValidationFailure> { + for actual in actual_tasks { + if expected.matches_actual_task(actual) { + // We only match a single task. + // FIXME(hds): We should probably create an error or a warning if multiple tasks match. + return expected.validate_actual_task(actual); + } + } + + expected.no_match_error() +} diff --git a/console-subscriber/tests/support/task.rs b/console-subscriber/tests/support/task.rs new file mode 100644 index 000000000..6df878b1b --- /dev/null +++ b/console-subscriber/tests/support/task.rs @@ -0,0 +1,228 @@ +use std::{error, fmt}; + +use super::MAIN_TASK_NAME; + +/// An actual task +/// +/// This struct contains the values recorded from the console subscriber +/// client and represents what is known about an actual task running on +/// the test's runtime. +#[derive(Clone, Debug)] +pub(super) struct ActualTask { + pub(super) id: u64, + pub(super) name: Option, + pub(super) wakes: u64, + pub(super) self_wakes: u64, +} + +impl ActualTask { + pub(super) fn new(id: u64) -> Self { + Self { + id, + name: None, + wakes: 0, + self_wakes: 0, + } + } +} + +/// An error in task validation. +pub(super) struct TaskValidationFailure { + /// The expected task whose expectations were not met. + expected: ExpectedTask, + /// The actual task which failed the validation + actual: Option, + /// A textual description of the validation failure + failure: String, +} + +impl error::Error for TaskValidationFailure {} + +impl fmt::Display for TaskValidationFailure { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.failure) + } +} + +impl fmt::Debug for TaskValidationFailure { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.actual { + Some(actual) => write!( + f, + "Task Validation Failed!\n Expected Task: {expected:?}\n Actual Task: {actual:?}\n Failure: {failure}", + expected = self.expected, failure = self.failure), + None => write!( + f, + "Task Validation Failed!\n Expected Task: {expected:?}\n Failure: {failure}", + expected = self.expected, failure = self.failure), + } + } +} + +/// An expected task. +/// +/// This struct contains the fields that an expected task will attempt to match +/// actual tasks on, as well as the expectations that will be used to validate +/// which the actual task is as expected. +#[derive(Clone, Debug)] +pub(crate) struct ExpectedTask { + match_name: Option, + + expect_present: Option, + expect_wakes: Option, + expect_self_wakes: Option, +} + +impl Default for ExpectedTask { + fn default() -> Self { + Self { + match_name: None, + expect_present: None, + expect_wakes: None, + expect_self_wakes: None, + } + } +} + +impl ExpectedTask { + /// Returns whether or not an actual task matches this expected task. + /// + /// All matching rules will be run, if they all succeed, then `true` will + /// be returned, otherwise `false`. + pub(super) fn matches_actual_task(&self, actual_task: &ActualTask) -> bool { + if let Some(match_name) = &self.match_name { + if Some(match_name) == actual_task.name.as_ref() { + return true; + } + } + + false + } + + /// Returns an error specifying that no match was found for this expected + /// task. + pub(super) fn no_match_error(&self) -> Result<(), TaskValidationFailure> { + Err(TaskValidationFailure { + expected: self.clone(), + actual: None, + failure: format!("{self}: no matching actual task was found"), + }) + } + + /// Validates all expectations against the provided actual task. + /// + /// No check that the actual task matches is performed. That must have been + /// done prior. + /// + /// If all expections are met, this method returns `Ok(())`. If any + /// expectations are not met, then the first incorrect expectation will + /// be returned as an `Err`. + pub(super) fn validate_actual_task( + &self, + actual_task: &ActualTask, + ) -> Result<(), TaskValidationFailure> { + let mut no_expectations = true; + if let Some(_expected) = self.expect_present { + no_expectations = false; + } + + if let Some(expected_wakes) = self.expect_wakes { + no_expectations = false; + if expected_wakes != actual_task.wakes { + return Err(TaskValidationFailure { + expected: self.clone(), + actual: Some(actual_task.clone()), + failure: format!( + "{self}: expected `wakes` to be {expected_wakes}, but actual was {actual_wakes}", + actual_wakes = actual_task.wakes), + }); + } + } + + if let Some(expected_self_wakes) = self.expect_self_wakes { + no_expectations = false; + if expected_self_wakes != actual_task.self_wakes { + return Err(TaskValidationFailure { + expected: self.clone(), + actual: Some(actual_task.clone()), + failure: format!( + "{self}: expected `self_wakes` to be {expected_self_wakes}, but actual was {actual_self_wakes}", + actual_self_wakes = actual_task.self_wakes), + }); + } + } + + if no_expectations { + return Err(TaskValidationFailure { + expected: self.clone(), + actual: Some(actual_task.clone()), + failure: format!( + "{self}: no expectations set, if you want to just expect that a matching task is present, use `expect_present()`") + }); + } + + Ok(()) + } + + /// Matches tasks by name. + /// + /// To match this expected task, an actual task must have the name `name`. + #[allow(dead_code)] + pub(crate) fn match_name(mut self, name: String) -> Self { + self.match_name = Some(name); + self + } + + /// Matches tasks by the default task name. + /// + /// To match this expected task, an actual task must have the default name + /// assigned to the task which runs the future provided to [`assert_task`] + /// or [`assert_tasks`]. + /// + /// [`assert_task`]: fn@support::assert_task + /// [`assert_tasks`]: fn@support::assert_tasks + #[allow(dead_code)] + pub(crate) fn match_default_name(mut self) -> Self { + self.match_name = Some(MAIN_TASK_NAME.into()); + self + } + + /// Expects that a task is present. + /// + /// To validate, an actual task matching this expected task must be found. + #[allow(dead_code)] + pub(crate) fn expect_present(mut self) -> Self { + self.expect_present = Some(true); + self + } + + /// Expects that a task has a specific value for `wakes`. + /// + /// To validate, the actual task matching this expected task must have + /// a count of wakes equal to `wakes`. + #[allow(dead_code)] + pub(crate) fn expect_wakes(mut self, wakes: u64) -> Self { + self.expect_wakes = Some(wakes); + self + } + + /// Expects that a task has a specific value for `self_wakes`. + /// + /// To validate, the actual task matching this expected task must have + /// a count of self wakes equal to `self_wakes`. + #[allow(dead_code)] + pub(crate) fn expect_self_wakes(mut self, self_wakes: u64) -> Self { + self.expect_self_wakes = Some(self_wakes); + self + } +} + +impl fmt::Display for ExpectedTask { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let fields = match &self.match_name { + Some(name) => format!("name={name}"), + None => "(no fields to match on)".into(), + }; + write!(f, "Task<{fields}>") + } +} diff --git a/console-subscriber/tests/wake.rs b/console-subscriber/tests/wake.rs new file mode 100644 index 000000000..e64e87a6e --- /dev/null +++ b/console-subscriber/tests/wake.rs @@ -0,0 +1,48 @@ +mod support; +use std::time::Duration; + +use support::{assert_task, ExpectedTask}; +use tokio::{task, time::sleep}; + +#[test] +fn sleep_wakes() { + let expected_task = ExpectedTask::default() + .match_default_name() + .expect_wakes(1) + .expect_self_wakes(0); + + let future = async { + sleep(Duration::ZERO).await; + }; + + assert_task(expected_task, future); +} + +#[test] +fn double_sleep_wakes() { + let expected_task = ExpectedTask::default() + .match_default_name() + .expect_wakes(2) + .expect_self_wakes(0); + + let future = async { + sleep(Duration::ZERO).await; + sleep(Duration::ZERO).await; + }; + + assert_task(expected_task, future); +} + +#[test] +fn self_wake() { + let expected_task = ExpectedTask::default() + .match_default_name() + .expect_wakes(1) + .expect_self_wakes(1); + + let future = async { + task::yield_now().await; + }; + + assert_task(expected_task, future); +}