diff --git a/Cargo.lock b/Cargo.lock index 27d4c3088..2b6a44e0e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -257,6 +257,7 @@ dependencies = [ name = "console-api" version = "0.5.0" dependencies = [ + "futures-core", "prost", "prost-build", "prost-types", @@ -283,6 +284,7 @@ dependencies = [ "tokio", "tokio-stream", "tonic", + "tower", "tracing", "tracing-core", "tracing-subscriber", diff --git a/console-subscriber/Cargo.toml b/console-subscriber/Cargo.toml index d3c414e3b..b79e6e3d5 100644 --- a/console-subscriber/Cargo.toml +++ b/console-subscriber/Cargo.toml @@ -55,6 +55,7 @@ crossbeam-channel = "0.5" [dev-dependencies] tokio = { version = "^1.21", features = ["full", "rt-multi-thread"] } +tower = "0.4" futures = "0.3" [package.metadata.docs.rs] diff --git a/console-subscriber/src/aggregator/id_data.rs b/console-subscriber/src/aggregator/id_data.rs index b9010b445..2ad2c74b0 100644 --- a/console-subscriber/src/aggregator/id_data.rs +++ b/console-subscriber/src/aggregator/id_data.rs @@ -104,18 +104,18 @@ impl IdData { if let Some(dropped_at) = stats.dropped_at() { let dropped_for = now.checked_duration_since(dropped_at).unwrap_or_default(); let dirty = stats.is_unsent(); - let should_drop = + let should_retain = // if there are any clients watching, retain all dirty tasks regardless of age (dirty && has_watchers) - || dropped_for > retention; + || dropped_for <= retention; tracing::trace!( stats.id = ?id, stats.dropped_at = ?dropped_at, stats.dropped_for = ?dropped_for, stats.dirty = dirty, - should_drop, + should_retain, ); - return !should_drop; + return should_retain; } true diff --git a/console-subscriber/src/aggregator/mod.rs b/console-subscriber/src/aggregator/mod.rs index 4496cba28..2d08a4fa4 100644 --- a/console-subscriber/src/aggregator/mod.rs +++ b/console-subscriber/src/aggregator/mod.rs @@ -221,7 +221,7 @@ impl Aggregator { // to be woken when the flush interval has elapsed, or when the // channel is almost full. let mut drained = false; - while let Some(event) = self.events.recv().now_or_never() { + while let Some(event) = tokio::task::unconstrained(self.events.recv()).now_or_never() { match event { Some(event) => { self.update_state(event); diff --git a/console-subscriber/tests/framework.rs b/console-subscriber/tests/framework.rs new file mode 100644 index 000000000..68bf2a0ce --- /dev/null +++ b/console-subscriber/tests/framework.rs @@ -0,0 +1,184 @@ +//! Framework tests +//! +//! The tests in this module are here to verify the testing framework itself. +//! As such, some of these tests may be repeated elsewhere (where we wish to +//! actually test the functionality of `console-subscriber`) and others are +//! negative tests that should panic. + +use std::time::Duration; + +use tokio::{task, time::sleep}; + +mod support; +use support::{assert_task, assert_tasks, ExpectedTask}; + +#[test] +fn expect_present() { + let expected_task = ExpectedTask::default() + .match_default_name() + .expect_present(); + + let future = async { + sleep(Duration::ZERO).await; + }; + + assert_task(expected_task, future); +} + +#[test] +#[should_panic(expected = "Test failed: Task validation failed: + - Task: no expectations set, if you want to just expect that a matching task is present, use `expect_present()` +")] +fn fail_no_expectations() { + let expected_task = ExpectedTask::default().match_default_name(); + + let future = async { + sleep(Duration::ZERO).await; + }; + + assert_task(expected_task, future); +} + +#[test] +fn wakes() { + let expected_task = ExpectedTask::default().match_default_name().expect_wakes(1); + + let future = async { + sleep(Duration::ZERO).await; + }; + + assert_task(expected_task, future); +} + +#[test] +#[should_panic(expected = "Test failed: Task validation failed: + - Task: expected `wakes` to be 5, but actual was 1 +")] +fn fail_wakes() { + let expected_task = ExpectedTask::default().match_default_name().expect_wakes(5); + + let future = async { + sleep(Duration::ZERO).await; + }; + + assert_task(expected_task, future); +} + +#[test] +fn self_wakes() { + let expected_task = ExpectedTask::default() + .match_default_name() + .expect_self_wakes(1); + + let future = async { task::yield_now().await }; + + assert_task(expected_task, future); +} + +#[test] +#[should_panic(expected = "Test failed: Task validation failed: + - Task: expected `self_wakes` to be 1, but actual was 0 +")] +fn fail_self_wake() { + let expected_task = ExpectedTask::default() + .match_default_name() + .expect_self_wakes(1); + + let future = async { + sleep(Duration::ZERO).await; + }; + + assert_task(expected_task, future); +} + +#[test] +fn test_spawned_task() { + let expected_task = ExpectedTask::default() + .match_name("another-name".into()) + .expect_present(); + + let future = async { + task::Builder::new() + .name("another-name") + .spawn(async { task::yield_now().await }) + }; + + assert_task(expected_task, future); +} + +#[test] +#[should_panic(expected = "Test failed: Task validation failed: + - Task: no matching actual task was found +")] +fn fail_wrong_task_name() { + let expected_task = ExpectedTask::default().match_name("wrong-name".into()); + + let future = async { task::yield_now().await }; + + assert_task(expected_task, future); +} + +#[test] +fn multiple_tasks() { + let expected_tasks = vec![ + ExpectedTask::default() + .match_name("task-1".into()) + .expect_wakes(1), + ExpectedTask::default() + .match_name("task-2".into()) + .expect_wakes(1), + ]; + + let future = async { + let task1 = task::Builder::new() + .name("task-1") + .spawn(async { task::yield_now().await }) + .unwrap(); + let task2 = task::Builder::new() + .name("task-2") + .spawn(async { task::yield_now().await }) + .unwrap(); + + tokio::try_join! { + task1, + task2, + } + .unwrap(); + }; + + assert_tasks(expected_tasks, future); +} + +#[test] +#[should_panic(expected = "Test failed: Task validation failed: + - Task: expected `wakes` to be 2, but actual was 1 +")] +fn fail_1_of_2_expected_tasks() { + let expected_tasks = vec![ + ExpectedTask::default() + .match_name("task-1".into()) + .expect_wakes(1), + ExpectedTask::default() + .match_name("task-2".into()) + .expect_wakes(2), + ]; + + let future = async { + let task1 = task::Builder::new() + .name("task-1") + .spawn(async { task::yield_now().await }) + .unwrap(); + let task2 = task::Builder::new() + .name("task-2") + .spawn(async { task::yield_now().await }) + .unwrap(); + + tokio::try_join! { + task1, + task2, + } + .unwrap(); + }; + + assert_tasks(expected_tasks, future); +} diff --git a/console-subscriber/tests/support/mod.rs b/console-subscriber/tests/support/mod.rs new file mode 100644 index 000000000..4937aff6a --- /dev/null +++ b/console-subscriber/tests/support/mod.rs @@ -0,0 +1,47 @@ +use futures::Future; + +mod state; +mod subscriber; +mod task; + +use subscriber::run_test; + +pub(crate) use subscriber::MAIN_TASK_NAME; +pub(crate) use task::ExpectedTask; + +/// Assert that an `expected_task` is recorded by a console-subscriber +/// when driving the provided `future` to completion. +/// +/// This function is equivalent to calling [`assert_tasks`] with a vector +/// containing a single task. +/// +/// # Panics +/// +/// This function will panic if the expectations on the expected task are not +/// met or if a matching task is not recorded. +#[track_caller] +#[allow(dead_code)] +pub(crate) fn assert_task(expected_task: ExpectedTask, future: Fut) +where + Fut: Future + Send + 'static, + Fut::Output: Send + 'static, +{ + run_test(vec![expected_task], future) +} + +/// Assert that the `expected_tasks` are recorded by a console-subscriber +/// when driving the provided `future` to completion. +/// +/// # Panics +/// +/// This function will panic if the expectations on any of the expected tasks +/// are not met or if matching tasks are not recorded for all expected tasks. +#[track_caller] +#[allow(dead_code)] +pub(crate) fn assert_tasks(expected_tasks: Vec, future: Fut) +where + Fut: Future + Send + 'static, + Fut::Output: Send + 'static, +{ + run_test(expected_tasks, future) +} diff --git a/console-subscriber/tests/support/state.rs b/console-subscriber/tests/support/state.rs new file mode 100644 index 000000000..6fc663808 --- /dev/null +++ b/console-subscriber/tests/support/state.rs @@ -0,0 +1,139 @@ +use std::fmt; + +use tokio::sync::broadcast::{ + self, + error::{RecvError, TryRecvError}, +}; + +/// A step in the running of the test +#[derive(Clone, Debug, PartialEq, PartialOrd)] +pub(super) enum TestStep { + /// The overall test has begun + Start, + /// The instrument server has been started + ServerStarted, + /// The client has connected to the instrument server + ClientConnected, + /// The future being driven has completed + TestFinished, + /// The client has finished recording updates + UpdatesRecorded, +} + +impl fmt::Display for TestStep { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + (self as &dyn fmt::Debug).fmt(f) + } +} + +/// The state of the test. +/// +/// This struct is used by various parts of the test framework to wait until +/// a specific test step has been reached and advance the test state to a new +/// step. +pub(super) struct TestState { + receiver: broadcast::Receiver, + sender: broadcast::Sender, + step: TestStep, +} + +impl TestState { + pub(super) fn new() -> Self { + let (sender, receiver) = broadcast::channel(1); + Self { + receiver, + sender, + step: TestStep::Start, + } + } + + /// Block asynchronously until the desired step has been reached. + /// + /// # Panics + /// + /// This function will panic if the underlying channel gets closed. + pub(super) async fn wait_for_step(&mut self, desired_step: TestStep) { + loop { + if self.step >= desired_step { + break; + } + + match self.receiver.recv().await { + Ok(step) => self.step = step, + Err(RecvError::Lagged(_)) => { + // we don't mind being lagged, we'll just get the latest state + } + Err(RecvError::Closed) => { + panic!("failed to receive current step, waiting for step: {desired_step}, did the test abort?"); + } + } + } + } + + /// Check whether the desired step has been reached without blocking. + pub(super) fn try_wait_for_step(&mut self, desired_step: TestStep) -> bool { + self.update_step(); + + self.step == desired_step + } + + /// Advance to the next step. + /// + /// The test must be at the step prior to the next step before starting. + /// Being in a different step is likely to indicate a logic error in the + /// test framework. + /// + /// # Panics + /// + /// This method will panic if the test state is not at the step prior to + /// `next_step` or if the underlying channel is closed. + #[track_caller] + pub(super) fn advance_to_step(&mut self, next_step: TestStep) { + self.update_step(); + + if self.step >= next_step { + panic!( + "cannot advance to previous or current step! current step: {current}, next step: {next_step}", + current = self.step); + } + + match (&self.step, &next_step) { + (TestStep::Start, TestStep::ServerStarted) | + (TestStep::ServerStarted, TestStep::ClientConnected) | + (TestStep::ClientConnected, TestStep::TestFinished) | + (TestStep::TestFinished, TestStep::UpdatesRecorded) => {}, + (_, _) => panic!( + "cannot advance more than one step! current step: {current}, next step: {next_step}", + current = self.step), + } + + self.sender + .send(next_step) + .expect("failed to send the next test step, did the test abort?"); + } + + fn update_step(&mut self) { + loop { + match self.receiver.try_recv() { + Ok(step) => self.step = step, + Err(TryRecvError::Lagged(_)) => { + // we don't mind being lagged, we'll just get the latest state + } + Err(TryRecvError::Closed) => { + panic!("failed to update current step, did the test abort?") + } + Err(TryRecvError::Empty) => break, + } + } + } +} + +impl Clone for TestState { + fn clone(&self) -> Self { + Self { + receiver: self.receiver.resubscribe(), + sender: self.sender.clone(), + step: self.step.clone(), + } + } +} diff --git a/console-subscriber/tests/support/subscriber.rs b/console-subscriber/tests/support/subscriber.rs new file mode 100644 index 000000000..36888ad5a --- /dev/null +++ b/console-subscriber/tests/support/subscriber.rs @@ -0,0 +1,318 @@ +use std::{collections::HashMap, fmt, future::Future, thread}; + +use console_api::{ + field::Value, + instrument::{instrument_client::InstrumentClient, InstrumentRequest}, +}; +use console_subscriber::ServerParts; +use futures::stream::StreamExt; +use tokio::{io::DuplexStream, task}; +use tonic::transport::{Channel, Endpoint, Server, Uri}; +use tower::service_fn; + +use super::state::{TestState, TestStep}; +use super::task::{ActualTask, ExpectedTask, TaskValidationFailure}; + +pub(crate) const MAIN_TASK_NAME: &str = "main"; + +#[derive(Debug)] +struct TestFailure { + failures: Vec, +} + +impl fmt::Display for TestFailure { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Task validation failed:\n")?; + for failure in &self.failures { + write!(f, " - {failure}\n")?; + } + Ok(()) + } +} + +/// Runs the test +/// +/// This function runs the whole test. It sets up a `console-subscriber` layer +/// together with the gRPC server and connects a client to it. The subscriber +/// is then used to record traces as the provided future is driven to +/// completion on a current thread tokio runtime. +/// +/// This function will panic if the expectations on any of the expected tasks +/// are not met or if matching tasks are not recorded for all expected tasks. +#[track_caller] +pub(super) fn run_test(expected_tasks: Vec, future: Fut) +where + Fut: Future + Send + 'static, + Fut::Output: Send + 'static, +{ + use tracing_subscriber::prelude::*; + + let (client_stream, server_stream) = tokio::io::duplex(1024); + let (console_layer, server) = console_subscriber::ConsoleLayer::builder().build(); + let registry = tracing_subscriber::registry().with(console_layer); + + let mut test_state = TestState::new(); + let mut test_state_test = test_state.clone(); + + let join_handle = thread::Builder::new() + .name("console::subscriber".into()) + .spawn(move || { + let _subscriber_guard = + tracing::subscriber::set_default(tracing_core::subscriber::NoSubscriber::default()); + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_io() + .enable_time() + .build() + .expect("console-test error: failed to initialize console subscriber runtime"); + + runtime.block_on(async move { + task::Builder::new() + .name("console::serve") + .spawn(console_server(server, server_stream, test_state.clone())) + .expect("console-test error: could not spawn 'console-server' task"); + + let actual_tasks = task::Builder::new() + .name("console::client") + .spawn(console_client(client_stream, test_state.clone())) + .expect("console-test error: could not spawn 'console-client' task") + .await + .expect("console-test error: failed to await 'console-client' task"); + + test_state.advance_to_step(TestStep::UpdatesRecorded); + actual_tasks + }) + }) + .expect("console subscriber could not spawn thread"); + + tracing::subscriber::with_default(registry, || { + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + + runtime.block_on(async move { + test_state_test + .wait_for_step(TestStep::ClientConnected) + .await; + + // Run the future that we are testing. + _ = tokio::task::Builder::new() + .name(MAIN_TASK_NAME) + .spawn(future) + .expect("console-test error: couldn't spawn test task") + .await; + test_state_test.advance_to_step(TestStep::TestFinished); + + test_state_test + .wait_for_step(TestStep::UpdatesRecorded) + .await; + }); + }); + + let actual_tasks = join_handle + .join() + .expect("console-test error: failed to join 'console-subscriber' thread"); + + if let Err(test_failure) = validate_expected_tasks(expected_tasks, actual_tasks) { + panic!("Test failed: {test_failure}") + } +} + +/// Starts the console server. +/// +/// The server will start serving over its side of the duplex stream. +/// +/// Once the server gets spawned into its task, the test state is advanced +/// to the `ServerStarted` step. This function will then wait until the test +/// state reaches the `UpdatesRecorded` step (indicating that all validation of the +/// received updates has been completed) before dropping the aggregator. +/// +/// # Test State +/// +/// 1. Advances to: `ServerStarted` +/// 2. Waits for: `UpdatesRecorded` +async fn console_server( + server: console_subscriber::Server, + server_stream: DuplexStream, + mut test_state: TestState, +) { + let ServerParts { + instrument_server: service, + aggregator, + .. + } = server.into_parts(); + let aggregate = task::Builder::new() + .name("console::aggregate") + .spawn(aggregator.run()) + .expect("client-console error: couldn't spawn aggregator"); + Server::builder() + .add_service(service) + .serve_with_incoming(futures::stream::iter(vec![Ok::<_, std::io::Error>( + server_stream, + )])) + .await + .expect("client-console error: couldn't start instrument server."); + test_state.advance_to_step(TestStep::ServerStarted); + + test_state.wait_for_step(TestStep::UpdatesRecorded).await; + aggregate.abort(); +} + +/// Starts the console client and validates the expected tasks. +/// +/// First we wait until the server has started (test step `ServerStarted`), then +/// the client is connected to its half of the duplex stream and we start recording +/// the actual tasks. +/// +/// Once recording finishes (see [`record_actual_tasks()`] for details on the test +/// state condition), the actual tasks returned. +/// +/// # Test State +/// +/// 1. Waits for: `ServerStarted` +/// 2. Advances to: `ClientConnected` +async fn console_client(client_stream: DuplexStream, mut test_state: TestState) -> Vec { + test_state.wait_for_step(TestStep::ServerStarted).await; + + let mut client_stream = Some(client_stream); + let channel = Endpoint::try_from("http://[::]:6669") + .expect("Could not create endpoint") + .connect_with_connector(service_fn(move |_: Uri| { + let client = client_stream.take(); + + async move { + if let Some(client) = client { + Ok(client) + } else { + Err(std::io::Error::new( + std::io::ErrorKind::Other, + "Client already taken", + )) + } + } + })) + .await + .expect("client-console error: couldn't create client"); + test_state.advance_to_step(TestStep::ClientConnected); + + record_actual_tasks(channel, test_state.clone()).await +} + +/// Records the actual tasks which are received by the client channel. +/// +/// Updates will be received until the test state reaches the `TestFinished` step +/// (indicating that the test itself has finished running), at which point we wait +/// for a final update before returning all the actual tasks which were recorded. +/// +/// # Test State +/// +/// 1. Waits for: `TestFinished` +async fn record_actual_tasks( + client_channel: Channel, + mut test_state: TestState, +) -> Vec { + let mut client = InstrumentClient::new(client_channel); + + let mut stream = loop { + let request = tonic::Request::new(InstrumentRequest {}); + match client.watch_updates(request).await { + Ok(stream) => break stream.into_inner(), + Err(err) => panic!("Client cannot connect to watch updates: {err}"), + } + }; + + let mut tasks = HashMap::new(); + + let mut last_update = false; + while let Some(update) = stream.next().await { + match update { + Ok(update) => { + if let Some(task_update) = &update.task_update { + for new_task in &task_update.new_tasks { + if let Some(id) = &new_task.id { + let mut actual_task = ActualTask::new(id.id); + for field in &new_task.fields { + if let Some(console_api::field::Name::StrName(field_name)) = + &field.name + { + if field_name == "task.name" { + actual_task.name = match &field.value { + Some(Value::DebugVal(value)) => Some(value.clone()), + Some(Value::StrVal(value)) => Some(value.clone()), + _ => None, // Anything that isn't string-like shouldn't be used as a name. + }; + } + } + } + tasks.insert(actual_task.id, actual_task); + } + } + + for (id, stats) in &task_update.stats_update { + if let Some(mut task) = tasks.get_mut(id) { + task.wakes = stats.wakes; + task.self_wakes = stats.self_wakes; + } + } + } + } + Err(e) => { + panic!("update stream error: {}", e); + } + } + + if last_update { + break; + } + + if test_state.try_wait_for_step(TestStep::TestFinished) { + // Once the test finishes running, we will get one further update and finish. + last_update = true; + } + } + + tasks.into_values().collect() +} + +/// Validate the expected tasks against the actual tasks. +/// +/// Each expected task is checked in turn. +/// +/// A matching actual task is searched for. If one is found it, the +/// expected task is validated against the actual task. +/// +/// Any validation errors result in failure. If no matches +fn validate_expected_tasks( + expected_tasks: Vec, + actual_tasks: Vec, +) -> Result<(), TestFailure> { + let failures: Vec<_> = expected_tasks + .iter() + .map(|expected| validate_expected_task(expected, &actual_tasks)) + .filter_map(|r| match r { + Ok(_) => None, + Err(validation_error) => Some(validation_error), + }) + .collect(); + + if failures.is_empty() { + Ok(()) + } else { + Err(TestFailure { failures: failures }) + } +} + +fn validate_expected_task( + expected: &ExpectedTask, + actual_tasks: &Vec, +) -> Result<(), TaskValidationFailure> { + for actual in actual_tasks { + if expected.matches_actual_task(actual) { + // We only match a single task. + // FIXME(hds): We should probably create an error or a warning if multiple tasks match. + return expected.validate_actual_task(actual); + } + } + + expected.no_match_error() +} diff --git a/console-subscriber/tests/support/task.rs b/console-subscriber/tests/support/task.rs new file mode 100644 index 000000000..6df878b1b --- /dev/null +++ b/console-subscriber/tests/support/task.rs @@ -0,0 +1,228 @@ +use std::{error, fmt}; + +use super::MAIN_TASK_NAME; + +/// An actual task +/// +/// This struct contains the values recorded from the console subscriber +/// client and represents what is known about an actual task running on +/// the test's runtime. +#[derive(Clone, Debug)] +pub(super) struct ActualTask { + pub(super) id: u64, + pub(super) name: Option, + pub(super) wakes: u64, + pub(super) self_wakes: u64, +} + +impl ActualTask { + pub(super) fn new(id: u64) -> Self { + Self { + id, + name: None, + wakes: 0, + self_wakes: 0, + } + } +} + +/// An error in task validation. +pub(super) struct TaskValidationFailure { + /// The expected task whose expectations were not met. + expected: ExpectedTask, + /// The actual task which failed the validation + actual: Option, + /// A textual description of the validation failure + failure: String, +} + +impl error::Error for TaskValidationFailure {} + +impl fmt::Display for TaskValidationFailure { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.failure) + } +} + +impl fmt::Debug for TaskValidationFailure { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.actual { + Some(actual) => write!( + f, + "Task Validation Failed!\n Expected Task: {expected:?}\n Actual Task: {actual:?}\n Failure: {failure}", + expected = self.expected, failure = self.failure), + None => write!( + f, + "Task Validation Failed!\n Expected Task: {expected:?}\n Failure: {failure}", + expected = self.expected, failure = self.failure), + } + } +} + +/// An expected task. +/// +/// This struct contains the fields that an expected task will attempt to match +/// actual tasks on, as well as the expectations that will be used to validate +/// which the actual task is as expected. +#[derive(Clone, Debug)] +pub(crate) struct ExpectedTask { + match_name: Option, + + expect_present: Option, + expect_wakes: Option, + expect_self_wakes: Option, +} + +impl Default for ExpectedTask { + fn default() -> Self { + Self { + match_name: None, + expect_present: None, + expect_wakes: None, + expect_self_wakes: None, + } + } +} + +impl ExpectedTask { + /// Returns whether or not an actual task matches this expected task. + /// + /// All matching rules will be run, if they all succeed, then `true` will + /// be returned, otherwise `false`. + pub(super) fn matches_actual_task(&self, actual_task: &ActualTask) -> bool { + if let Some(match_name) = &self.match_name { + if Some(match_name) == actual_task.name.as_ref() { + return true; + } + } + + false + } + + /// Returns an error specifying that no match was found for this expected + /// task. + pub(super) fn no_match_error(&self) -> Result<(), TaskValidationFailure> { + Err(TaskValidationFailure { + expected: self.clone(), + actual: None, + failure: format!("{self}: no matching actual task was found"), + }) + } + + /// Validates all expectations against the provided actual task. + /// + /// No check that the actual task matches is performed. That must have been + /// done prior. + /// + /// If all expections are met, this method returns `Ok(())`. If any + /// expectations are not met, then the first incorrect expectation will + /// be returned as an `Err`. + pub(super) fn validate_actual_task( + &self, + actual_task: &ActualTask, + ) -> Result<(), TaskValidationFailure> { + let mut no_expectations = true; + if let Some(_expected) = self.expect_present { + no_expectations = false; + } + + if let Some(expected_wakes) = self.expect_wakes { + no_expectations = false; + if expected_wakes != actual_task.wakes { + return Err(TaskValidationFailure { + expected: self.clone(), + actual: Some(actual_task.clone()), + failure: format!( + "{self}: expected `wakes` to be {expected_wakes}, but actual was {actual_wakes}", + actual_wakes = actual_task.wakes), + }); + } + } + + if let Some(expected_self_wakes) = self.expect_self_wakes { + no_expectations = false; + if expected_self_wakes != actual_task.self_wakes { + return Err(TaskValidationFailure { + expected: self.clone(), + actual: Some(actual_task.clone()), + failure: format!( + "{self}: expected `self_wakes` to be {expected_self_wakes}, but actual was {actual_self_wakes}", + actual_self_wakes = actual_task.self_wakes), + }); + } + } + + if no_expectations { + return Err(TaskValidationFailure { + expected: self.clone(), + actual: Some(actual_task.clone()), + failure: format!( + "{self}: no expectations set, if you want to just expect that a matching task is present, use `expect_present()`") + }); + } + + Ok(()) + } + + /// Matches tasks by name. + /// + /// To match this expected task, an actual task must have the name `name`. + #[allow(dead_code)] + pub(crate) fn match_name(mut self, name: String) -> Self { + self.match_name = Some(name); + self + } + + /// Matches tasks by the default task name. + /// + /// To match this expected task, an actual task must have the default name + /// assigned to the task which runs the future provided to [`assert_task`] + /// or [`assert_tasks`]. + /// + /// [`assert_task`]: fn@support::assert_task + /// [`assert_tasks`]: fn@support::assert_tasks + #[allow(dead_code)] + pub(crate) fn match_default_name(mut self) -> Self { + self.match_name = Some(MAIN_TASK_NAME.into()); + self + } + + /// Expects that a task is present. + /// + /// To validate, an actual task matching this expected task must be found. + #[allow(dead_code)] + pub(crate) fn expect_present(mut self) -> Self { + self.expect_present = Some(true); + self + } + + /// Expects that a task has a specific value for `wakes`. + /// + /// To validate, the actual task matching this expected task must have + /// a count of wakes equal to `wakes`. + #[allow(dead_code)] + pub(crate) fn expect_wakes(mut self, wakes: u64) -> Self { + self.expect_wakes = Some(wakes); + self + } + + /// Expects that a task has a specific value for `self_wakes`. + /// + /// To validate, the actual task matching this expected task must have + /// a count of self wakes equal to `self_wakes`. + #[allow(dead_code)] + pub(crate) fn expect_self_wakes(mut self, self_wakes: u64) -> Self { + self.expect_self_wakes = Some(self_wakes); + self + } +} + +impl fmt::Display for ExpectedTask { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let fields = match &self.match_name { + Some(name) => format!("name={name}"), + None => "(no fields to match on)".into(), + }; + write!(f, "Task<{fields}>") + } +} diff --git a/console-subscriber/tests/wake.rs b/console-subscriber/tests/wake.rs new file mode 100644 index 000000000..e64e87a6e --- /dev/null +++ b/console-subscriber/tests/wake.rs @@ -0,0 +1,48 @@ +mod support; +use std::time::Duration; + +use support::{assert_task, ExpectedTask}; +use tokio::{task, time::sleep}; + +#[test] +fn sleep_wakes() { + let expected_task = ExpectedTask::default() + .match_default_name() + .expect_wakes(1) + .expect_self_wakes(0); + + let future = async { + sleep(Duration::ZERO).await; + }; + + assert_task(expected_task, future); +} + +#[test] +fn double_sleep_wakes() { + let expected_task = ExpectedTask::default() + .match_default_name() + .expect_wakes(2) + .expect_self_wakes(0); + + let future = async { + sleep(Duration::ZERO).await; + sleep(Duration::ZERO).await; + }; + + assert_task(expected_task, future); +} + +#[test] +fn self_wake() { + let expected_task = ExpectedTask::default() + .match_default_name() + .expect_wakes(1) + .expect_self_wakes(1); + + let future = async { + task::yield_now().await; + }; + + assert_task(expected_task, future); +}