diff --git a/crates/rayexec_execution/src/arrays/array.rs b/crates/rayexec_execution/src/arrays/array.rs index a07cd6596..bc6cd3e9f 100644 --- a/crates/rayexec_execution/src/arrays/array.rs +++ b/crates/rayexec_execution/src/arrays/array.rs @@ -4,9 +4,9 @@ use iterutil::exact_size::IntoExactSizeIterator; use rayexec_bullet::scalar::ScalarValue; use rayexec_error::{not_implemented, RayexecError, Result}; -use super::buffer::addressable::MutableAddressableStorage; +use super::buffer::addressable::{AddressableStorage, MutableAddressableStorage}; use super::buffer::dictionary::DictionaryBuffer; -use super::buffer::physical_type::{PhysicalDictionary, PhysicalType}; +use super::buffer::physical_type::{MutablePhysicalStorage, PhysicalDictionary, PhysicalType}; use super::buffer::{ArrayBuffer, SecondaryBuffers}; use super::buffer_manager::{BufferManager, NopBufferManager}; use super::datatype::DataType; @@ -62,17 +62,11 @@ where } } - pub fn new_with_validity( - datatype: DataType, - buffer: ArrayBuffer, - validity: Validity, - ) -> Result { + pub fn new_with_validity(datatype: DataType, buffer: ArrayBuffer, validity: Validity) -> Result { if validity.len() != buffer.capacity() { - return Err( - RayexecError::new("Validty length does not match buffer length") - .with_field("validity_len", validity.len()) - .with_field("buffer_len", buffer.capacity()), - ); + return Err(RayexecError::new("Validty length does not match buffer length") + .with_field("validity_len", validity.len()) + .with_field("buffer_len", buffer.capacity())); } Ok(Array { @@ -84,10 +78,11 @@ where pub fn make_managed_from(&mut self, manager: &B, other: &mut Self) -> Result<()> { if self.datatype != other.datatype { - return Err( - RayexecError::new("Attempted to make array managed with data from other array with different data types") - .with_field("own_datatype", self.datatype.clone()) - .with_field("other_datatype", other.datatype.clone())); + return Err(RayexecError::new( + "Attempted to make array 
managed with data from other array with different data types", + ) + .with_field("own_datatype", self.datatype.clone()) + .with_field("other_datatype", other.datatype.clone())); } let managed = other.data.make_managed(manager)?; @@ -116,11 +111,7 @@ where /// Selects indice from the array. /// /// This will convert the underlying array buffer into a dictionary buffer. - pub fn select( - &mut self, - manager: &B, - selection: impl IntoExactSizeIterator, - ) -> Result<()> { + pub fn select(&mut self, manager: &B, selection: impl IntoExactSizeIterator) -> Result<()> { if self.is_dictionary() { // Already dictionary, select the selection. let sel = selection.into_iter(); @@ -161,10 +152,7 @@ where // Now replace the original buffer, and put the original buffer in the // secondary buffer. - let orig_validity = std::mem::replace( - &mut self.validity, - Validity::new_all_valid(new_buf.capacity()), - ); + let orig_validity = std::mem::replace(&mut self.validity, Validity::new_all_valid(new_buf.capacity())); let orig_buffer = std::mem::replace(&mut self.data, ArrayData::owned(new_buf)); // TODO: Should just clone the pointer if managed. *self.data.try_as_mut()?.secondary_buffers_mut() = @@ -220,6 +208,63 @@ where Ok(()) } + + /// Copy rows from self to another array. + /// + /// `mapping` provides a mapping of source indices to destination indices in + /// (source, dest) pairs. + pub fn copy_rows(&self, mapping: impl IntoExactSizeIterator, dest: &mut Self) -> Result<()> { + match self.datatype.physical_type() { + PhysicalType::Int8 => copy_rows::(self, mapping, dest)?, + PhysicalType::Int32 => copy_rows::(self, mapping, dest)?, + PhysicalType::Utf8 => copy_rows::(self, mapping, dest)?, + _ => unimplemented!(), + } + + Ok(()) + } + + /// Helper for copying a single row from self to a destination array. 
+ pub fn copy_row(&self, source_idx: usize, dest: &mut Self, dest_idx: usize) -> Result<()> { + let mapping = [(source_idx, dest_idx)]; + self.copy_rows(mapping, dest) + } +} + +fn copy_rows( + from: &Array, + mapping: impl IntoExactSizeIterator, + to: &mut Array, +) -> Result<()> +where + S: MutablePhysicalStorage, + B: BufferManager, +{ + let from_flat = from.flat_view()?; + let from_storage = S::get_storage(from_flat.array_buffer)?; + + let to_data = to.data.try_as_mut()?; + let mut to_storage = S::get_storage_mut(to_data)?; + + if from_flat.validity.all_valid() && to.validity.all_valid() { + for (from_idx, to_idx) in mapping.into_iter() { + let from_idx = from_flat.selection.get(from_idx).unwrap(); + let v = from_storage.get(from_idx).unwrap(); + to_storage.put(to_idx, v); + } + } else { + for (from_idx, to_idx) in mapping.into_iter() { + let from_idx = from_flat.selection.get(from_idx).unwrap(); + if from_flat.validity.is_valid(from_idx) { + let v = from_storage.get(from_idx).unwrap(); + to_storage.put(to_idx, v); + } else { + to.validity.set_invalid(to_idx); + } + } + } + + Ok(()) } #[derive(Debug)] @@ -329,3 +374,27 @@ where ArrayData::as_ref(&self) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::arrays::buffer::Int32BufferBuilder; + use crate::arrays::executor::scalar::unary::UnaryExecutor; + + #[test] + fn copy_rows_i32() { + let array1 = Array::new_with_buffer(DataType::Int32, Int32BufferBuilder::from_iter([4, 5, 6]).unwrap()); + let mut array2 = Array::new(&NopBufferManager, DataType::Int32, 3).unwrap(); + + // Copies the reverse. 
+ array1.copy_rows([(0, 2), (1, 1), (2, 0)], &mut array2).unwrap(); + + let mut out = vec![0; 3]; + UnaryExecutor::for_each_flat::(array2.flat_view().unwrap(), 0..3, |idx, v| { + out[idx] = v.copied().unwrap(); + }) + .unwrap(); + + assert_eq!(vec![6, 5, 4], out); + } +} diff --git a/crates/rayexec_execution/src/arrays/batch.rs b/crates/rayexec_execution/src/arrays/batch.rs index b538bbf63..ac9b3a75f 100644 --- a/crates/rayexec_execution/src/arrays/batch.rs +++ b/crates/rayexec_execution/src/arrays/batch.rs @@ -11,6 +11,7 @@ use super::flat_array::FlatSelection; pub struct Batch { pub(crate) arrays: Vec>, pub(crate) num_rows: usize, + pub(crate) capacity: usize, } impl Batch @@ -25,6 +26,7 @@ where Batch { arrays: Vec::new(), num_rows, + capacity: 0, } } @@ -37,7 +39,11 @@ where arrays.push(array) } - Ok(Batch { arrays, num_rows: 0 }) + Ok(Batch { + arrays, + num_rows: 0, + capacity, + }) } /// Create a new batch from some number of arrays. @@ -55,6 +61,7 @@ where return Ok(Batch { arrays: Vec::new(), num_rows: 0, + capacity: 0, }) } }; @@ -72,6 +79,7 @@ where Ok(Batch { arrays, num_rows: if rows_eq_cap { capacity } else { 0 }, + capacity, }) } @@ -116,6 +124,10 @@ where Ok(()) } + pub fn capacity(&self) -> usize { + self.capacity + } + pub fn num_rows(&self) -> usize { self.num_rows } diff --git a/crates/rayexec_execution/src/execution/operators_exp/batch_collection.rs b/crates/rayexec_execution/src/execution/operators_exp/batch_collection.rs index 314c9458d..7d5d4ccce 100644 --- a/crates/rayexec_execution/src/execution/operators_exp/batch_collection.rs +++ b/crates/rayexec_execution/src/execution/operators_exp/batch_collection.rs @@ -82,13 +82,7 @@ where for (from, to) in batch.arrays.iter().zip(self.arrays.iter_mut()) { // [0..batch_num_rows) => [self_row_count..) 
let mapping = (0..batch.num_rows()).zip(self.row_count..(self.row_count + batch.num_rows())); - - match to.datatype.physical_type() { - PhysicalType::Int8 => copy_rows::(from, mapping, to)?, - PhysicalType::Int32 => copy_rows::(from, mapping, to)?, - PhysicalType::Utf8 => copy_rows::(from, mapping, to)?, - _ => unimplemented!(), - } + from.copy_rows(mapping, to)?; } self.row_count += batch.num_rows(); @@ -108,13 +102,7 @@ where for (from, to) in source.arrays().iter().zip(self.arrays.iter_mut()) { let mapping = [(source_row, dest_row)]; - - match to.datatype.physical_type() { - PhysicalType::Int8 => copy_rows::(from, mapping, to)?, - PhysicalType::Int32 => copy_rows::(from, mapping, to)?, - PhysicalType::Utf8 => copy_rows::(from, mapping, to)?, - _ => unimplemented!(), - } + from.copy_rows(mapping, to)?; } Ok(()) @@ -130,46 +118,6 @@ where } } -/// Copy rows from `from` to `to`. -/// -/// `mapping` provides a mapping of source to destination rows in the form of -/// pairs (from, to). -fn copy_rows( - from: &Array, - mapping: impl IntoExactSizeIterator, - to: &mut Array, -) -> Result<()> -where - S: MutablePhysicalStorage, - B: BufferManager, -{ - let from_flat = from.flat_view()?; - let from_storage = S::get_storage(from_flat.array_buffer)?; - - let to_data = to.data.try_as_mut()?; - let mut to_storage = S::get_storage_mut(to_data)?; - - if from_flat.validity.all_valid() && to.validity.all_valid() { - for (from_idx, to_idx) in mapping.into_iter() { - let from_idx = from_flat.selection.get(from_idx).unwrap(); - let v = from_storage.get(from_idx).unwrap(); - to_storage.put(to_idx, v); - } - } else { - for (from_idx, to_idx) in mapping.into_iter() { - let from_idx = from_flat.selection.get(from_idx).unwrap(); - if from_flat.validity.is_valid(from_idx) { - let v = from_storage.get(from_idx).unwrap(); - to_storage.put(to_idx, v); - } else { - to.validity.set_invalid(to_idx); - } - } - } - - Ok(()) -} - #[cfg(test)] mod tests { use super::*; diff --git 
a/crates/rayexec_execution/src/execution/operators_exp/mod.rs b/crates/rayexec_execution/src/execution/operators_exp/mod.rs index 78ca822bd..2805fce62 100644 --- a/crates/rayexec_execution/src/execution/operators_exp/mod.rs +++ b/crates/rayexec_execution/src/execution/operators_exp/mod.rs @@ -6,7 +6,7 @@ use std::fmt::Debug; use std::task::Context; use physical_project::ProjectPartitionState; -use physical_sort::partition_state::SortPartitionState; +use physical_sort::{SortOperatorState, SortPartitionState}; use rayexec_error::Result; use crate::arrays::batch::Batch; @@ -47,6 +47,41 @@ pub enum PollFinalize { Pending, } +#[derive(Debug)] +pub enum PartitionAndOperatorStates { + /// Operators that have a single input/output. + Branchless { + /// Global operator state. + operator_state: OperatorState, + /// State per-partition. + partition_states: Vec, + }, + /// Operators that produce 1 or more output branches. + /// + /// Mostly for materializations. + BranchingOutput { + /// Global operator state. + operator_state: OperatorState, + /// Single set of input states. + inputs_states: Vec, + /// Multiple sets of output states. + output_states: Vec>, + }, + /// Operators that have two children, with this operator acting as the + /// "sink" for one child. + /// + /// For joins, the build side is the terminating input, while the probe side + /// is non-terminating. + TerminatingInput { + /// Global operator state. + operator_state: OperatorState, + /// States for the input that is non-terminating. + nonterminating_states: Vec, + /// States for the input that is terminated by this operator. 
+ terminating_states: Vec, + }, +} + #[derive(Debug)] pub enum PartitionState { Project(ProjectPartitionState), @@ -56,6 +91,7 @@ pub enum PartitionState { #[derive(Debug)] pub enum OperatorState { + Sort(SortOperatorState), None, } diff --git a/crates/rayexec_execution/src/execution/operators_exp/physical_sort/merge.rs b/crates/rayexec_execution/src/execution/operators_exp/physical_sort/merge.rs index 80940e711..09a863b5a 100644 --- a/crates/rayexec_execution/src/execution/operators_exp/physical_sort/merge.rs +++ b/crates/rayexec_execution/src/execution/operators_exp/physical_sort/merge.rs @@ -1,12 +1,66 @@ use std::cmp::{Ordering, Reverse}; -use std::collections::{BinaryHeap, VecDeque}; +use std::collections::BinaryHeap; use rayexec_error::Result; -use super::sort_data::{SortBlock, SortData}; -use super::sort_layout::SortLayout; +use super::sort_data::SortBlock; +use crate::arrays::batch::Batch; use crate::arrays::buffer_manager::BufferManager; +/// Trait that allows writing the result of a merge to either a Batch or a +/// SortBlock. +/// +/// Merging happens twice during a physical sort. The first is when sorting the +/// data local to a partition. The results get written to a new SortBlock. +/// +/// The second is when we output the globally sorted data. We can just write +/// that to the output batch since we don't need to keep the encoded data +/// around. +/// +/// This trait just facilitates using either a Batch or SortBlock with the +/// merger. +pub trait MergeOutput { + /// Return the row capacity of self. + fn capacity(&self) -> usize; + + /// Copy the row represented by the heap entry into self at the `dest_idx`. 
+ fn copy_row_from_entry(&mut self, dest_idx: usize, ent: &HeapEntry) -> Result<()>; +} + +impl MergeOutput for Batch +where + B: BufferManager, +{ + fn capacity(&self) -> usize { + self.capacity() + } + + fn copy_row_from_entry(&mut self, dest_idx: usize, ent: &HeapEntry) -> Result<()> { + for (source, dest) in ent.queue.current.block.block.arrays().iter().zip(self.arrays_mut()) { + let source_idx = ent.queue.current.curr_idx; + source.copy_row(source_idx, dest, dest_idx)?; + } + + Ok(()) + } +} + +impl MergeOutput for SortBlock +where + B: BufferManager, +{ + fn capacity(&self) -> usize { + self.block.capacity() + } + + fn copy_row_from_entry(&mut self, dest_idx: usize, ent: &HeapEntry) -> Result<()> { + let source_block = &ent.queue.current.block; + let source_idx = ent.queue.current.curr_idx; + + self.copy_row_from_other(dest_idx, source_block, source_idx) + } +} + /// A block containing sorted rows that's being merged with other blocks. #[derive(Debug)] pub struct MergingSortBlock { @@ -18,15 +72,56 @@ pub struct MergingSortBlock { #[derive(Debug)] pub struct MergeQueue { - pub exhausted: bool, - pub current: MergingSortBlock, - pub remaining: Vec>, // Pop from back to front. + exhausted: bool, + current: MergingSortBlock, + remaining: Vec>, // Pop from back to front. } impl MergeQueue where B: BufferManager, { + /// Create a new queue with a single block. + /// + /// May return None if the sort block contains no rows. + pub fn new_single(sort_block: SortBlock) -> Option { + if sort_block.row_count() == 0 { + return None; + } + + Some(MergeQueue { + exhausted: false, + current: MergingSortBlock { + curr_idx: 0, + block: sort_block, + }, + remaining: Vec::new(), + }) + } + + /// Create a new queue of blocks. + /// + /// May return None if there's no blocks with any data. 
+ pub fn new(mut sort_blocks: Vec>) -> Option { + loop { + match sort_blocks.pop() { + Some(first) => { + if first.block.row_count() > 0 { + return Some(MergeQueue { + exhausted: false, + current: MergingSortBlock { + curr_idx: 0, + block: first, + }, + remaining: sort_blocks, + }); + } + } + None => return None, + } + } + } + fn prepare_next_row(&mut self) { if self.exhausted { return; @@ -58,18 +153,26 @@ where #[derive(Debug)] pub struct Merger { - pub queues: Vec>, - pub layout: SortLayout, - pub out_capacity: usize, + queues: Vec>, } impl Merger where B: BufferManager, { - /// Do a single round of merging. - pub fn merge_round(&mut self, manager: &B) -> Result>> { - let mut out_block = SortBlock::new(manager, &self.layout, self.out_capacity)?; + /// Create a new merger using the given queues. + pub fn new(queues: Vec>) -> Self { + Merger { queues } + } + + /// Do a single round of merging, writing the output to `out`. + /// + /// The number of rows written will be returned. + pub fn merge_round(&mut self, out: &mut impl MergeOutput) -> Result { + // TODO: Optimization, if only a single queue remains, just drain + // instead of building min heap. + + let out_capacity = out.capacity(); // Min heap containing at most one entry from each queue of blocks we're // merging. @@ -90,22 +193,17 @@ where })); } - for row_idx in 0..self.out_capacity { + for row_idx in 0..out_capacity { let ent = match min_heap.pop() { Some(ent) => ent, None => { - // If heap is empty, we exhausted all queues. If out is - // empty, then just return None. - if out_block.row_count() == 0 { - return Ok(None); - } else { - return Ok(Some(out_block)); - } + // If heap is empty, we exhausted all queues. + return Ok(row_idx); } }; // Copy the row to out. - out_block.copy_row_from_other(row_idx, &ent.0.queue.current.block, ent.0.row_idx)?; + out.copy_row_from_entry(row_idx, &ent.0)?; // Get next entry for the queue and put into heap. 
let queue = ent.0.queue; @@ -123,7 +221,8 @@ where min_heap.push(ent); } - Ok(Some(out_block)) + // We wrote the entire capacity of the block. + Ok(out_capacity) } } @@ -131,7 +230,7 @@ where /// /// Eq and Ord comparisons delegate the key buffer this entry represents. #[derive(Debug)] -struct HeapEntry<'a, B: BufferManager> { +pub struct HeapEntry<'a, B: BufferManager> { /// The queue this entry was from. queue: &'a mut MergeQueue, /// Row index within the block this entry is for. diff --git a/crates/rayexec_execution/src/execution/operators_exp/physical_sort/mod.rs b/crates/rayexec_execution/src/execution/operators_exp/physical_sort/mod.rs index 110016d64..d57bd4423 100644 --- a/crates/rayexec_execution/src/execution/operators_exp/physical_sort/mod.rs +++ b/crates/rayexec_execution/src/execution/operators_exp/physical_sort/mod.rs @@ -1,5 +1,4 @@ pub mod encode; -pub mod partition_state; mod merge; mod sort_data; @@ -7,38 +6,175 @@ mod sort_layout; use std::task::Context; -use rayexec_error::{OptionExt, Result}; +use merge::{MergeQueue, Merger}; +use parking_lot::Mutex; +use rayexec_error::{OptionExt, RayexecError, Result}; +use sort_data::{SortBlock, SortData}; use sort_layout::SortLayout; use super::{ExecutableOperator, ExecuteInOutState, OperatorState, PartitionState, PollExecute, PollFinalize}; +use crate::arrays::buffer_manager::NopBufferManager; use crate::explain::explainable::{ExplainConfig, ExplainEntry, Explainable}; -use crate::expr::physical::evaluator::ExpressionEvaluator; -use crate::expr::physical::{PhysicalScalarExpression, PhysicalSortExpression}; + +#[derive(Debug)] +pub enum SortPartitionState { + /// Partition is consuming data, building up locally sorted blocks. + Consume(PartitionStateLocalSort), + /// Partition is producing data from all input partitions. + /// + /// Only a single partition will ever be in this state (to enforce data + /// being globally sorted). 
+ Produce(PartitionStateGlobalSort), + /// Partition is done producing any data. + Finished, +} + +#[derive(Debug)] +pub struct PartitionStateLocalSort { + output_capacity: usize, + /// Partition-local sort data. + sort_data: SortData, +} + +#[derive(Debug)] +pub struct PartitionStateGlobalSort { + /// Merger containing all sort blocks from all partitions. + merger: Merger, +} + +#[derive(Debug)] +pub struct SortOperatorState { + inner: Mutex, +} + +#[derive(Debug)] +struct SortOperatorStateInner { + /// Number of partitions we're still waiting on before we can start + /// producing output. + remaining: usize, + /// Queues from each partition. + /// + /// Each queue contains totally ordered blocks within that queue. These will + /// be the inputs to the final global merge. + /// + /// Pushed to when a partition is finalized. + queues: Vec>, +} #[derive(Debug)] pub struct PhysicalSort { pub(crate) layout: SortLayout, - pub(crate) exprs: Vec, } impl ExecutableOperator for PhysicalSort { fn poll_execute( &self, - cx: &mut Context, + _cx: &mut Context, partition_state: &mut PartitionState, - operator_state: &OperatorState, + _operator_state: &OperatorState, inout: ExecuteInOutState, ) -> Result { - unimplemented!() + let state = match partition_state { + PartitionState::Sort(state) => state, + other => panic!("invalid state: {other:?}"), + }; + + match state { + SortPartitionState::Consume(state) => { + let batch = inout.input.required("input batch required")?; + state.sort_data.push_batch(&NopBufferManager, &self.layout, batch)?; + + // TODO: Threshold begin sort. + + Ok(PollExecute::NeedsMore) + } + SortPartitionState::Produce(state) => { + let out = inout.output.required("output batch required")?; + let count = state.merger.merge_round(out)?; + + // Update output batch with correct number of rows. 
+ out.set_num_rows(count)?; + + if out.num_rows() == 0 { + Ok(PollExecute::Exhausted) + } else { + Ok(PollExecute::HasMore) + } + } + SortPartitionState::Finished => Ok(PollExecute::Exhausted), + } } fn poll_finalize( &self, - cx: &mut Context, + _cx: &mut Context, partition_state: &mut PartitionState, operator_state: &OperatorState, ) -> Result { - unimplemented!() + let state = match partition_state { + PartitionState::Sort(state) => state, + other => panic!("invalid partition state: {other:?}"), + }; + + match state { + SortPartitionState::Consume(consume_state) => { + // Ensure local sort data is completely sorted with blocks. + consume_state + .sort_data + .sort_unsorted_blocks(&NopBufferManager, &self.layout)?; + + // Create a queue per block. + let blocks = consume_state.sort_data.take_sorted_for_merge(); + let queues: Vec<_> = blocks + .into_iter() + .flat_map(|block| MergeQueue::new_single(block)) + .collect(); + + let mut merger = Merger::new(queues); + + // Merge until we run out of mergeable rows. + let mut blocks = Vec::new(); + loop { + let mut block = SortBlock::new(&NopBufferManager, &self.layout, consume_state.output_capacity)?; + let count = merger.merge_round(&mut block)?; + + if count == 0 { + // No rows, we're done merging. + break; + } + + blocks.push(block); + } + + let mut operator_state = match operator_state { + OperatorState::Sort(state) => state.inner.lock(), + other => panic!("invalid operator state: {other:?}"), + }; + + let partition_queue = MergeQueue::new(blocks); + if let Some(queue) = partition_queue { + operator_state.queues.push(queue); + } + + operator_state.remaining -= 1; + + // If we're the last partition, take the queues and start the + // global merge. 
+ if operator_state.remaining == 0 { + let queues = std::mem::take(&mut operator_state.queues); + let global_merger = Merger::new(queues); + + *state = SortPartitionState::Produce(PartitionStateGlobalSort { merger: global_merger }); + + return Ok(PollFinalize::NeedsDrain); + } + + // Otherwise this partition's finished. + Ok(PollFinalize::Finalized) + } + SortPartitionState::Produce(_) => Err(RayexecError::new("cannot finalize partition that's producing")), + SortPartitionState::Finished => Err(RayexecError::new("cannot finalize partition that's finished")), + } } } diff --git a/crates/rayexec_execution/src/execution/operators_exp/physical_sort/partition_state.rs b/crates/rayexec_execution/src/execution/operators_exp/physical_sort/partition_state.rs deleted file mode 100644 index 393ca9aca..000000000 --- a/crates/rayexec_execution/src/execution/operators_exp/physical_sort/partition_state.rs +++ /dev/null @@ -1,20 +0,0 @@ -use rayexec_error::Result; - -use crate::arrays::batch::Batch; -use crate::expr::physical::PhysicalSortExpression; - -#[derive(Debug)] -pub struct SortPartitionState { - /// Selection indices that would produce a sorted batch. - selection: Vec, -} - -impl SortPartitionState { - pub fn push_local_batch( - &mut self, - exprs: &[PhysicalSortExpression], - batch: &mut Batch, - ) -> Result<()> { - unimplemented!() - } -} diff --git a/crates/rayexec_execution/src/execution/operators_exp/physical_sort/sort_data.rs b/crates/rayexec_execution/src/execution/operators_exp/physical_sort/sort_data.rs index 950cd9ee1..7633703a4 100644 --- a/crates/rayexec_execution/src/execution/operators_exp/physical_sort/sort_data.rs +++ b/crates/rayexec_execution/src/execution/operators_exp/physical_sort/sort_data.rs @@ -14,12 +14,8 @@ use crate::expr::physical::PhysicalSortExpression; #[derive(Debug)] pub struct SortData { - /// Buffer manager. - manager: B, /// Capacity per block. max_per_block: usize, - /// Layout indicating how we're performing the sort. 
- layout: SortLayout, /// Blocks not yet sorted. Can continue to be written to. unsorted: Vec>, /// Sorted blocks with each block containing sorted rows. @@ -32,25 +28,16 @@ impl SortData where B: BufferManager, { - pub fn new( - manager: B, - max_per_block: usize, - input_types: Vec, - sort_exprs: &[PhysicalSortExpression], - ) -> Result { - let layout = SortLayout::new(input_types, sort_exprs); - - Ok(SortData { - manager, + pub fn new(max_per_block: usize) -> Self { + SortData { max_per_block, - layout, unsorted: Vec::new(), sorted: Vec::new(), - }) + } } - pub fn push_batch(&mut self, batch: &Batch) -> Result<()> { - let mut block = self.pop_or_allocate_unsorted_block(batch.num_rows())?; + pub fn push_batch(&mut self, manager: &B, layout: &SortLayout, batch: &Batch) -> Result<()> { + let mut block = self.pop_or_allocate_unsorted_block(manager, layout, batch.num_rows())?; let mut add_offset = 0; @@ -58,9 +45,9 @@ where let curr_offset = block.key_encode_offsets[block.block.row_count()]; let buf = &mut block.key_encode_buffer[curr_offset..]; - for (idx, sort_col) in self.layout.key_columns.iter().enumerate() { - let nulls_first = self.layout.key_nulls_first[idx]; - let desc = self.layout.key_desc[idx]; + for (idx, sort_col) in layout.key_columns.iter().enumerate() { + let nulls_first = layout.key_nulls_first[idx]; + let desc = layout.key_desc[idx]; let offsets = &block.key_encode_offsets; let key_array = batch.get_array(*sort_col)?; @@ -71,7 +58,7 @@ where // Update add offet to get to the correct offset for subsequent // keys. 
- add_offset += self.layout.key_sizes[idx]; + add_offset += layout.key_sizes[idx]; } block.block.append_batch_data(batch)?; @@ -81,12 +68,17 @@ where Ok(()) } - pub fn sort_unsorted_blocks(&mut self) -> Result<()> { + pub fn take_sorted_for_merge(&mut self) -> Vec> { + debug_assert_eq!(0, self.unsorted.len()); + std::mem::take(&mut self.sorted) + } + + pub fn sort_unsorted_blocks(&mut self, manager: &B, layout: &SortLayout) -> Result<()> { let mut sort_indices_buf = Vec::new(); for block in self.unsorted.drain(..) { sort_indices_buf.resize(block.block.row_count(), 0); - let sorted = block.sort(&self.manager, &self.layout, &mut sort_indices_buf)?; + let sorted = block.sort(manager, layout, &mut sort_indices_buf)?; self.sorted.push(sorted); } @@ -97,7 +89,12 @@ where /// rows. Otherwise we allocate a new block. /// /// Pops to satisfy lifetimes more easily. - fn pop_or_allocate_unsorted_block(&mut self, count: usize) -> Result> { + fn pop_or_allocate_unsorted_block( + &mut self, + manager: &B, + layout: &SortLayout, + count: usize, + ) -> Result> { debug_assert!(count <= self.max_per_block); if let Some(last) = self.unsorted.last() { @@ -106,7 +103,7 @@ where } } - let block = SortBlock::new(&self.manager, &self.layout, self.max_per_block)?; + let block = SortBlock::new(manager, layout, self.max_per_block)?; Ok(block) } @@ -224,17 +221,16 @@ mod tests { #[test] fn sort_i32_batches() { - let mut sort_data = SortData::new( - NopBufferManager, - 4096, + let layout = SortLayout::new( vec![DataType::Int32], &[PhysicalSortExpression { column: PhysicalColumnExpr { idx: 0 }, desc: false, nulls_first: false, }], - ) - .unwrap(); + ); + + let mut sort_data = SortData::new(4096); let batch1 = Batch::from_arrays( [Array::new_with_buffer( @@ -253,10 +249,10 @@ mod tests { ) .unwrap(); - sort_data.push_batch(&batch1).unwrap(); - sort_data.push_batch(&batch2).unwrap(); + sort_data.push_batch(&NopBufferManager, &layout, &batch1).unwrap(); + sort_data.push_batch(&NopBufferManager, 
&layout, &batch2).unwrap(); - sort_data.sort_unsorted_blocks().unwrap(); + sort_data.sort_unsorted_blocks(&NopBufferManager, &layout).unwrap(); assert_eq!(0, sort_data.unsorted.len()); assert_eq!(1, sort_data.sorted.len());