diff --git a/crates/rayexec_execution/src/arrays/buffer/physical_type.rs b/crates/rayexec_execution/src/arrays/buffer/physical_type.rs index 19b5f9c69..8e50eebaa 100644 --- a/crates/rayexec_execution/src/arrays/buffer/physical_type.rs +++ b/crates/rayexec_execution/src/arrays/buffer/physical_type.rs @@ -131,6 +131,36 @@ impl PhysicalStorage for PhysicalUntypedNull { } } +#[derive(Debug, Clone, Copy)] +pub struct PhysicalBoolean; + +impl PhysicalStorage for PhysicalBoolean { + const PHYSICAL_TYPE: PhysicalType = PhysicalType::Boolean; + + type PrimaryBufferType = bool; + type StorageType = Self::PrimaryBufferType; + + type Storage<'a> = &'a [Self::StorageType]; + + fn get_storage(buffer: &ArrayBuffer) -> Result> + where + B: BufferManager, + { + buffer.try_as_slice::() + } +} + +impl MutablePhysicalStorage for PhysicalBoolean { + type MutableStorage<'a> = &'a mut [Self::StorageType]; + + fn get_storage_mut(buffer: &mut ArrayBuffer) -> Result> + where + B: BufferManager, + { + buffer.try_as_slice_mut::() + } +} + #[derive(Debug, Clone, Copy)] pub struct PhysicalI8; diff --git a/crates/rayexec_execution/src/arrays/flat_array.rs b/crates/rayexec_execution/src/arrays/flat_array.rs index ac9365097..143babb29 100644 --- a/crates/rayexec_execution/src/arrays/flat_array.rs +++ b/crates/rayexec_execution/src/arrays/flat_array.rs @@ -37,6 +37,10 @@ where }) } } + + pub fn logical_len(&self) -> usize { + self.selection.len() + } } #[derive(Debug, Clone, Copy)] diff --git a/crates/rayexec_execution/src/arrays/mod.rs b/crates/rayexec_execution/src/arrays/mod.rs index f6c8dce3d..ee74dd13a 100644 --- a/crates/rayexec_execution/src/arrays/mod.rs +++ b/crates/rayexec_execution/src/arrays/mod.rs @@ -10,3 +10,6 @@ pub mod flat_array; pub mod scalar; pub mod schema; pub mod validity; + +#[cfg(test)] +pub(crate) mod testutil; diff --git a/crates/rayexec_execution/src/arrays/testutil.rs b/crates/rayexec_execution/src/arrays/testutil.rs new file mode 100644 index 000000000..a4a08e5c4 --- 
/dev/null +++ b/crates/rayexec_execution/src/arrays/testutil.rs @@ -0,0 +1,170 @@ +use std::collections::BTreeMap; +use std::fmt::Debug; + +use iterutil::exact_size::IntoExactSizeIterator; + +use super::array::Array; +use super::batch::Batch; +use super::buffer::{Int32BufferBuilder, StringBufferBuilder}; +use super::buffer_manager::NopBufferManager; +use super::datatype::DataType; +use crate::arrays::buffer::physical_type::{PhysicalBoolean, PhysicalI32, PhysicalStorage, PhysicalType, PhysicalUtf8}; +use crate::arrays::buffer::ArrayBuffer; +use crate::arrays::executor::scalar::binary::BinaryExecutor; +use crate::arrays::executor::scalar::unary::UnaryExecutor; +use crate::arrays::executor::OutBuffer; +use crate::arrays::flat_array::FlatArrayView; +use crate::arrays::validity::Validity; + +pub fn new_i32_array(vals: impl IntoExactSizeIterator) -> Array { + Array::new_with_buffer(DataType::Int32, Int32BufferBuilder::from_iter(vals).unwrap()) +} + +pub fn new_string_array<'a>(vals: impl IntoExactSizeIterator) -> Array { + Array::new_with_buffer(DataType::Utf8, StringBufferBuilder::from_iter(vals).unwrap()) +} + +pub fn new_batch_from_arrays(arrays: impl IntoIterator) -> Batch { + Batch::from_arrays(arrays, true).unwrap() +} + +/// Assert two arrays are logically equal. +/// +/// This will assume that the array's capacity is the array's logical length. +pub fn assert_arrays_eq(array1: &Array, array2: &Array) { + assert_eq!(array1.capacity(), array2.capacity(), "array capacities differ"); + assert_arrays_eq_count(array1, array2, array1.capacity()) +} + +/// Asserts that two arrays are logically equal for the first `count` rows. +/// +/// This will check valid and invalid values. Assertion error messages will +/// print out Some/None to represent valid/invalid. 
+pub fn assert_arrays_eq_count(array1: &Array, array2: &Array, count: usize) { + assert_eq!(array1.datatype, array2.datatype); + + let flat1 = array1.flat_view().unwrap(); + let flat2 = array2.flat_view().unwrap(); + + fn assert_eq_inner(flat1: FlatArrayView, flat2: FlatArrayView, count: usize) + where + S: PhysicalStorage, + S::StorageType: ToOwned, + { + let mut out = BTreeMap::new(); + let sel = 0..count; + + UnaryExecutor::for_each_flat::(flat1, sel.clone(), |idx, v| { + out.insert(idx, v.map(|v| v.to_owned())); + }) + .unwrap(); + + UnaryExecutor::for_each_flat::(flat2, sel, |idx, v| match out.remove(&idx) { + Some(existing) => { + let v = v.map(|v| v.to_owned()); + assert_eq!(existing, v, "values differ at index {idx}"); + } + None => panic!("missing value for index in array 1 {idx}"), + }) + .unwrap(); + + if !out.is_empty() { + panic!("extra entries in array 1: {:?}", out); + } + } + + match array1.datatype.physical_type() { + PhysicalType::Int32 => assert_eq_inner::(flat1, flat2, count), + PhysicalType::Utf8 => assert_eq_inner::(flat1, flat2, count), + other => unimplemented!("{other:?}"), + } +} + +/// Asserts two batches are logically equal. 
+pub fn assert_batches_eq(batch1: &Batch, batch2: &Batch) { + let arrays1 = batch1.arrays(); + let arrays2 = batch2.arrays(); + + assert_eq!(arrays1.len(), arrays2.len(), "batches have different number of arrays"); + assert_eq!( + batch1.num_rows(), + batch2.num_rows(), + "batches have different number of rows" + ); + + for (array1, array2) in arrays1.iter().zip(arrays2) { + assert_arrays_eq_count(array1, array2, batch1.num_rows()); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn assert_i32_arrays_eq_simple() { + let array1 = new_i32_array([4, 5, 6]); + let array2 = new_i32_array([4, 5, 6]); + + assert_arrays_eq(&array1, &array2); + } + + #[test] + fn assert_i32_arrays_eq_with_dictionary() { + let array1 = new_i32_array([5, 4, 4]); + let mut array2 = new_i32_array([4, 5]); + array2.select(&NopBufferManager, [1, 0, 0]).unwrap(); + + assert_arrays_eq(&array1, &array2); + } + + #[test] + fn assert_i32_arrays_eq_with_invalid() { + let mut array1 = new_i32_array([4, 5, 6]); + array1.validity.set_invalid(1); + + let mut array2 = new_i32_array([4, 8, 6]); + array2.validity.set_invalid(1); + + assert_arrays_eq(&array1, &array2); + } + + #[test] + fn assert_batches_eq_simple() { + let batch1 = new_batch_from_arrays([new_i32_array([4, 5, 6]), new_string_array(["a", "b", "c"])]); + let batch2 = new_batch_from_arrays([new_i32_array([4, 5, 6]), new_string_array(["a", "b", "c"])]); + + assert_batches_eq(&batch1, &batch2); + } + + #[test] + fn assert_batches_eq_logical_row_count() { + let mut batch1 = new_batch_from_arrays([ + new_i32_array([4, 5, 6, 7, 8]), + new_string_array(["a", "b", "c", "d", "e"]), + ]); + batch1.set_num_rows(3).unwrap(); + + let batch2 = new_batch_from_arrays([new_i32_array([4, 5, 6]), new_string_array(["a", "b", "c"])]); + + assert_batches_eq(&batch1, &batch2); + } + + #[test] + #[should_panic] + fn assert_i32_arrays_eq_not_eq() { + let array1 = new_i32_array([4, 5, 6]); + let array2 = new_i32_array([4, 5, 7]); + + 
assert_arrays_eq(&array1, &array2); + } + + #[test] + #[should_panic] + fn assert_i32_arrays_different_lengths() { + let array1 = new_i32_array([4, 5, 6]); + let array2 = new_i32_array([4, 5]); + + assert_arrays_eq(&array1, &array2); + } +} diff --git a/crates/rayexec_execution/src/arrays/validity.rs b/crates/rayexec_execution/src/arrays/validity.rs index fb1e9501b..4ded5895b 100644 --- a/crates/rayexec_execution/src/arrays/validity.rs +++ b/crates/rayexec_execution/src/arrays/validity.rs @@ -62,4 +62,28 @@ impl Validity { ValidityInner::Mask { bitmap } => bitmap.set(idx, false), } } + + pub fn iter(&self) -> ValidityIter { + ValidityIter { idx: 0, validity: self } + } +} + +#[derive(Debug)] +pub struct ValidityIter<'a> { + idx: usize, + validity: &'a Validity, +} + +impl<'a> Iterator for ValidityIter<'a> { + type Item = bool; + + fn next(&mut self) -> Option { + if self.idx >= self.validity.len() { + return None; + } + + let val = self.validity.is_valid(self.idx); + self.idx += 1; + Some(val) + } } diff --git a/crates/rayexec_execution/src/execution/operators_exp/batch_collection.rs b/crates/rayexec_execution/src/execution/operators_exp/batch_collection.rs index 7d5d4ccce..e9cff40b3 100644 --- a/crates/rayexec_execution/src/execution/operators_exp/batch_collection.rs +++ b/crates/rayexec_execution/src/execution/operators_exp/batch_collection.rs @@ -55,6 +55,14 @@ where self.capacity } + pub fn set_row_count(&mut self, count: usize) -> Result<()> { + if count > self.capacity { + return Err(RayexecError::new("Row count would exceed capacity")); + } + self.row_count = count; + Ok(()) + } + pub fn row_count(&self) -> usize { self.row_count } @@ -67,6 +75,7 @@ where self.row_count + additional < self.capacity } + /// Appends a batch to this block. 
pub fn append_batch_data(&mut self, batch: &Batch) -> Result<()> { let total_num_rows = self.row_count + batch.num_rows(); if total_num_rows > self.capacity { diff --git a/crates/rayexec_execution/src/execution/operators_exp/mod.rs b/crates/rayexec_execution/src/execution/operators_exp/mod.rs index 2805fce62..061c67d01 100644 --- a/crates/rayexec_execution/src/execution/operators_exp/mod.rs +++ b/crates/rayexec_execution/src/execution/operators_exp/mod.rs @@ -2,15 +2,19 @@ pub mod batch_collection; pub mod physical_project; pub mod physical_sort; +#[cfg(test)] +mod testutil; + use std::fmt::Debug; use std::task::Context; use physical_project::ProjectPartitionState; use physical_sort::{SortOperatorState, SortPartitionState}; -use rayexec_error::Result; +use rayexec_error::{RayexecError, Result}; use crate::arrays::batch::Batch; use crate::arrays::buffer_manager::BufferManager; +use crate::database::DatabaseContext; use crate::explain::explainable::Explainable; #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -82,6 +86,21 @@ pub enum PartitionAndOperatorStates { }, } +impl PartitionAndOperatorStates { + pub fn branchless_into_states(self) -> Result<(OperatorState, Vec)> { + match self { + Self::Branchless { + operator_state, + partition_states, + } => Ok((operator_state, partition_states)), + Self::BranchingOutput { .. } => Err(RayexecError::new("Expected branchless states, got branching output")), + Self::TerminatingInput { .. 
} => { + Err(RayexecError::new("Expected branchless states, got terminating input")) + } + } + } +} + #[derive(Debug)] pub enum PartitionState { Project(ProjectPartitionState), @@ -108,6 +127,13 @@ pub struct ExecuteInOutState<'a> { } pub trait ExecutableOperator: Sync + Send + Debug + Explainable { + fn create_states( + &self, + context: &DatabaseContext, + batch_size: usize, + partitions: usize, + ) -> Result; + fn poll_execute( &self, cx: &mut Context, diff --git a/crates/rayexec_execution/src/execution/operators_exp/physical_project.rs b/crates/rayexec_execution/src/execution/operators_exp/physical_project.rs index 3438df412..4804ab395 100644 --- a/crates/rayexec_execution/src/execution/operators_exp/physical_project.rs +++ b/crates/rayexec_execution/src/execution/operators_exp/physical_project.rs @@ -6,10 +6,12 @@ use super::{ ExecutableOperator, ExecuteInOutState, OperatorState, + PartitionAndOperatorStates, PartitionState, PollExecute, PollFinalize, }; +use crate::database::DatabaseContext; use crate::explain::explainable::{ExplainConfig, ExplainEntry, Explainable}; use crate::expr::physical::evaluator::ExpressionEvaluator; use crate::expr::physical::PhysicalScalarExpression; @@ -25,6 +27,26 @@ pub struct ProjectPartitionState { } impl ExecutableOperator for PhysicalProject { + fn create_states( + &self, + _context: &DatabaseContext, + batch_size: usize, + partitions: usize, + ) -> Result { + let partition_states = (0..partitions) + .map(|_| { + PartitionState::Project(ProjectPartitionState { + evaluator: ExpressionEvaluator::new(self.projections.clone(), batch_size), + }) + }) + .collect(); + + Ok(PartitionAndOperatorStates::Branchless { + operator_state: OperatorState::None, + partition_states, + }) + } + fn poll_execute( &self, _cx: &mut Context, diff --git a/crates/rayexec_execution/src/execution/operators_exp/physical_sort/merge.rs b/crates/rayexec_execution/src/execution/operators_exp/physical_sort/merge.rs index 09a863b5a..50bbeabc6 100644 --- 
a/crates/rayexec_execution/src/execution/operators_exp/physical_sort/merge.rs +++ b/crates/rayexec_execution/src/execution/operators_exp/physical_sort/merge.rs @@ -1,5 +1,5 @@ use std::cmp::{Ordering, Reverse}; -use std::collections::BinaryHeap; +use std::collections::{BinaryHeap, VecDeque}; use rayexec_error::Result; @@ -65,16 +65,16 @@ where #[derive(Debug)] pub struct MergingSortBlock { /// The current index in the block that we're comparing. - pub curr_idx: usize, + curr_idx: usize, /// The block we're merging. - pub block: SortBlock, + block: SortBlock, } #[derive(Debug)] pub struct MergeQueue { exhausted: bool, current: MergingSortBlock, - remaining: Vec>, // Pop from back to front. + remaining: VecDeque>, } impl MergeQueue @@ -95,16 +95,20 @@ where curr_idx: 0, block: sort_block, }, - remaining: Vec::new(), + remaining: VecDeque::new(), }) } /// Create a new queue of blocks. /// + /// Blocks should be totally ordered from least to greatest. + /// /// May return None if there's no blocks with any data. - pub fn new(mut sort_blocks: Vec>) -> Option { + pub fn new(blocks: impl IntoIterator>) -> Option { + let mut blocks: VecDeque<_> = blocks.into_iter().collect(); + loop { - match sort_blocks.pop() { + match blocks.pop_front() { Some(first) => { if first.block.row_count() > 0 { return Some(MergeQueue { @@ -113,7 +117,7 @@ where curr_idx: 0, block: first, }, - remaining: sort_blocks, + remaining: blocks, }); } } @@ -131,7 +135,7 @@ where if self.current.curr_idx >= self.current.block.row_count() { // Get next block in queue. loop { - match self.remaining.pop() { + match self.remaining.pop_front() { Some(block) => { if block.block.row_count() == 0 { // Skip empty blocks. 
diff --git a/crates/rayexec_execution/src/execution/operators_exp/physical_sort/mod.rs b/crates/rayexec_execution/src/execution/operators_exp/physical_sort/mod.rs index d57bd4423..41602fd9d 100644 --- a/crates/rayexec_execution/src/execution/operators_exp/physical_sort/mod.rs +++ b/crates/rayexec_execution/src/execution/operators_exp/physical_sort/mod.rs @@ -12,8 +12,17 @@ use rayexec_error::{OptionExt, RayexecError, Result}; use sort_data::{SortBlock, SortData}; use sort_layout::SortLayout; -use super::{ExecutableOperator, ExecuteInOutState, OperatorState, PartitionState, PollExecute, PollFinalize}; +use super::{ + ExecutableOperator, + ExecuteInOutState, + OperatorState, + PartitionAndOperatorStates, + PartitionState, + PollExecute, + PollFinalize, +}; use crate::arrays::buffer_manager::NopBufferManager; +use crate::database::DatabaseContext; use crate::explain::explainable::{ExplainConfig, ExplainEntry, Explainable}; #[derive(Debug)] @@ -67,6 +76,34 @@ pub struct PhysicalSort { } impl ExecutableOperator for PhysicalSort { + fn create_states( + &self, + _context: &DatabaseContext, + batch_size: usize, + partitions: usize, + ) -> Result { + let operator_state = OperatorState::Sort(SortOperatorState { + inner: Mutex::new(SortOperatorStateInner { + remaining: partitions, + queues: Vec::new(), + }), + }); + + let partition_states = (0..partitions) + .map(|_| { + PartitionState::Sort(SortPartitionState::Consume(PartitionStateLocalSort { + output_capacity: batch_size, + sort_data: SortData::new(batch_size), + })) + }) + .collect(); + + Ok(PartitionAndOperatorStates::Branchless { + operator_state, + partition_states, + }) + } + fn poll_execute( &self, _cx: &mut Context, @@ -137,6 +174,7 @@ impl ExecutableOperator for PhysicalSort { loop { let mut block = SortBlock::new(&NopBufferManager, &self.layout, consume_state.output_capacity)?; let count = merger.merge_round(&mut block)?; + block.block.set_row_count(count)?; if count == 0 { // No rows, we're done merging. 
@@ -183,3 +221,104 @@ impl Explainable for PhysicalSort { unimplemented!() } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::arrays::batch::Batch; + use crate::arrays::datatype::DataType; + use crate::arrays::testutil::{assert_batches_eq, new_batch_from_arrays, new_i32_array, new_string_array}; + use crate::execution::operators_exp::testutil::context::test_database_context; + use crate::execution::operators_exp::testutil::wrapper::OperatorWrapper; + use crate::expr::physical::column_expr::PhysicalColumnExpr; + use crate::expr::physical::PhysicalSortExpression; + + #[test] + fn sort_single_partition_sort_on_i32() { + let operator = PhysicalSort { + layout: SortLayout::new( + vec![DataType::Int32, DataType::Utf8], + &[PhysicalSortExpression { + column: PhysicalColumnExpr { idx: 0 }, + desc: false, + nulls_first: false, + }], + ), + }; + + let context = test_database_context(); + let (op_state, mut part_states) = operator + .create_states(&context, 4, 1) + .unwrap() + .branchless_into_states() + .unwrap(); + + let operator = OperatorWrapper::new(operator); + + let inputs = [ + new_batch_from_arrays([new_i32_array([4, 5, 6]), new_string_array(["a", "b", "c"])]), + new_batch_from_arrays([new_i32_array([1, 2, 7]), new_string_array(["d", "e", "f"])]), + ]; + + for mut input in inputs { + let poll = operator + .poll_execute( + &mut part_states[0], + &op_state, + ExecuteInOutState { + input: Some(&mut input), + output: None, + }, + ) + .unwrap(); + + assert_eq!(PollExecute::NeedsMore, poll); + } + + let poll = operator.poll_finalize(&mut part_states[0], &op_state).unwrap(); + assert_eq!(PollFinalize::NeedsDrain, poll); + + let mut out = Batch::new(&NopBufferManager, [DataType::Int32, DataType::Utf8], 4).unwrap(); + let poll = operator + .poll_execute( + &mut part_states[0], + &op_state, + ExecuteInOutState { + input: None, + output: Some(&mut out), + }, + ) + .unwrap(); + assert_eq!(PollExecute::HasMore, poll); + + let expect1 = 
new_batch_from_arrays([new_i32_array([1, 2, 4, 5]), new_string_array(["d", "e", "a", "b"])]); + assert_batches_eq(&expect1, &out); + + let poll = operator + .poll_execute( + &mut part_states[0], + &op_state, + ExecuteInOutState { + input: None, + output: Some(&mut out), + }, + ) + .unwrap(); + assert_eq!(PollExecute::HasMore, poll); + + let expect2 = new_batch_from_arrays([new_i32_array([6, 7]), new_string_array(["c", "f"])]); + assert_batches_eq(&expect2, &out); + + let poll = operator + .poll_execute( + &mut part_states[0], + &op_state, + ExecuteInOutState { + input: None, + output: Some(&mut out), + }, + ) + .unwrap(); + assert_eq!(PollExecute::Exhausted, poll); + } +} diff --git a/crates/rayexec_execution/src/execution/operators_exp/physical_sort/sort_data.rs b/crates/rayexec_execution/src/execution/operators_exp/physical_sort/sort_data.rs index 7633703a4..20192446d 100644 --- a/crates/rayexec_execution/src/execution/operators_exp/physical_sort/sort_data.rs +++ b/crates/rayexec_execution/src/execution/operators_exp/physical_sort/sort_data.rs @@ -2,12 +2,9 @@ use rayexec_error::Result; use super::sort_layout::SortLayout; use crate::arrays::batch::Batch; -use crate::arrays::buffer::physical_type::{PhysicalI8, PhysicalType}; use crate::arrays::buffer_manager::BufferManager; -use crate::arrays::datatype::DataType; use crate::execution::operators_exp::batch_collection::BatchCollectionBlock; use crate::execution::operators_exp::physical_sort::encode::prefix_encode; -use crate::expr::physical::PhysicalSortExpression; // TODO: // - varlen tiebreaks @@ -142,13 +139,21 @@ where /// Get a buffer slice representing the encoded sort keys for a row. 
pub fn get_sort_key_buf(&self, row_idx: usize) -> &[u8] { let start = self.key_encode_offsets[row_idx]; - let end = self.key_encode_offsets[row_idx + 1]; + let end = if row_idx + 1 == self.key_encode_offsets.len() { + self.key_encode_buffer.len() + } else { + self.key_encode_offsets[row_idx + 1] + }; &self.key_encode_buffer[start..end] } pub fn get_sort_key_buf_mut(&mut self, row_idx: usize) -> &mut [u8] { let start = self.key_encode_offsets[row_idx]; - let end = self.key_encode_offsets[row_idx + 1]; + let end = if row_idx + 1 == self.key_encode_offsets.len() { + self.key_encode_buffer.len() + } else { + self.key_encode_offsets[row_idx + 1] + }; &mut self.key_encode_buffer[start..end] } @@ -212,12 +217,11 @@ where #[cfg(test)] mod tests { use super::*; - use crate::arrays::array::Array; - use crate::arrays::buffer::physical_type::PhysicalI32; - use crate::arrays::buffer::Int32BufferBuilder; use crate::arrays::buffer_manager::NopBufferManager; - use crate::arrays::executor::scalar::unary::UnaryExecutor; + use crate::arrays::datatype::DataType; + use crate::arrays::testutil::{assert_arrays_eq, new_batch_from_arrays, new_i32_array}; use crate::expr::physical::column_expr::PhysicalColumnExpr; + use crate::expr::physical::PhysicalSortExpression; #[test] fn sort_i32_batches() { @@ -232,22 +236,8 @@ mod tests { let mut sort_data = SortData::new(4096); - let batch1 = Batch::from_arrays( - [Array::new_with_buffer( - DataType::Int32, - Int32BufferBuilder::from_iter([4, 7, 6]).unwrap(), - )], - true, - ) - .unwrap(); - let batch2 = Batch::from_arrays( - [Array::new_with_buffer( - DataType::Int32, - Int32BufferBuilder::from_iter([2, 8]).unwrap(), - )], - true, - ) - .unwrap(); + let batch1 = new_batch_from_arrays([new_i32_array([4, 7, 6])]); + let batch2 = new_batch_from_arrays([new_i32_array([2, 8])]); sort_data.push_batch(&NopBufferManager, &layout, &batch1).unwrap(); sort_data.push_batch(&NopBufferManager, &layout, &batch2).unwrap(); @@ -258,13 +248,7 @@ mod tests { 
assert_eq!(1, sort_data.sorted.len()); assert_eq!(5, sort_data.sorted[0].block.row_count()); - let mut out = vec![0; 5]; - let flat = sort_data.sorted[0].block.arrays()[0].flat_view().unwrap(); - UnaryExecutor::for_each_flat::(flat, 0..5, |idx, v| { - out[idx] = v.copied().unwrap(); - }) - .unwrap(); - - assert_eq!(vec![2, 4, 6, 7, 8], out); + let expected = new_i32_array([2, 4, 6, 7, 8]); + assert_arrays_eq(&expected, &sort_data.sorted[0].block.arrays()[0]); } } diff --git a/crates/rayexec_execution/src/execution/operators_exp/testutil/context.rs b/crates/rayexec_execution/src/execution/operators_exp/testutil/context.rs new file mode 100644 index 000000000..32b433734 --- /dev/null +++ b/crates/rayexec_execution/src/execution/operators_exp/testutil/context.rs @@ -0,0 +1,9 @@ +use std::sync::Arc; + +use crate::database::system::new_system_catalog; +use crate::database::DatabaseContext; +use crate::datasource::DataSourceRegistry; + +pub fn test_database_context() -> DatabaseContext { + DatabaseContext::new(Arc::new(new_system_catalog(&DataSourceRegistry::default()).unwrap())).unwrap() +} diff --git a/crates/rayexec_execution/src/execution/operators_exp/testutil/mod.rs b/crates/rayexec_execution/src/execution/operators_exp/testutil/mod.rs new file mode 100644 index 000000000..030694e81 --- /dev/null +++ b/crates/rayexec_execution/src/execution/operators_exp/testutil/mod.rs @@ -0,0 +1,3 @@ +//! Utilities for testing individual operators. 
+pub mod context; +pub mod wrapper; diff --git a/crates/rayexec_execution/src/execution/operators_exp/testutil/wrapper.rs b/crates/rayexec_execution/src/execution/operators_exp/testutil/wrapper.rs new file mode 100644 index 000000000..6beb5c260 --- /dev/null +++ b/crates/rayexec_execution/src/execution/operators_exp/testutil/wrapper.rs @@ -0,0 +1,73 @@ +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; +use std::task::{Context, Wake, Waker}; + +use rayexec_error::Result; + +use crate::execution::operators_exp::{ + ExecutableOperator, + ExecuteInOutState, + OperatorState, + PartitionState, + PollExecute, + PollFinalize, +}; + +#[derive(Debug, Default)] +pub struct CountingWaker { + count: AtomicUsize, +} + +impl CountingWaker { + pub fn wake_count(&self) -> usize { + self.count.load(Ordering::SeqCst) + } +} + +impl Wake for CountingWaker { + fn wake(self: Arc) { + self.count.fetch_add(1, Ordering::SeqCst); + } +} + +/// Wrapper around an operator that uses a stub waker that tracks how many times +/// it's woken. 
+#[derive(Debug)] +pub struct OperatorWrapper { + pub waker: Arc, + pub operator: O, +} + +impl OperatorWrapper +where + O: ExecutableOperator, +{ + pub fn new(operator: O) -> Self { + OperatorWrapper { + waker: Arc::new(CountingWaker::default()), + operator, + } + } + + pub fn poll_execute( + &self, + partition_state: &mut PartitionState, + operator_state: &OperatorState, + inout: ExecuteInOutState, + ) -> Result { + let waker = Waker::from(self.waker.clone()); + let mut cx = Context::from_waker(&waker); + self.operator + .poll_execute(&mut cx, partition_state, operator_state, inout) + } + + pub fn poll_finalize( + &self, + partition_state: &mut PartitionState, + operator_state: &OperatorState, + ) -> Result { + let waker = Waker::from(self.waker.clone()); + let mut cx = Context::from_waker(&waker); + self.operator.poll_finalize(&mut cx, partition_state, operator_state) + } +} diff --git a/crates/rayexec_execution/src/expr/physical/evaluator.rs b/crates/rayexec_execution/src/expr/physical/evaluator.rs index 500ee43f1..11436fb82 100644 --- a/crates/rayexec_execution/src/expr/physical/evaluator.rs +++ b/crates/rayexec_execution/src/expr/physical/evaluator.rs @@ -31,7 +31,7 @@ impl ExpressionState { } impl ExpressionEvaluator { - pub fn new(expressions: Vec, batch_size: usize) -> Self { + pub fn new(expressions: Vec, batch_size: usize) -> Self { unimplemented!() } @@ -47,12 +47,7 @@ impl ExpressionEvaluator { /// /// `input` is mutable only to allow converting arrays from owned to /// managed. - pub fn eval_batch( - &mut self, - input: &mut Batch, - sel: FlatSelection, - output: &mut Batch, - ) -> Result<()> { + pub fn eval_batch(&mut self, input: &mut Batch, sel: FlatSelection, output: &mut Batch) -> Result<()> { for (idx, expr) in self.expressions.iter().enumerate() { let output = output.get_array_mut(idx)?; let state = &mut self.states[idx];