From 96b7a9a092d45705d5e1786be722ef6352db55cc Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Wed, 25 Dec 2024 10:08:48 +0100 Subject: [PATCH] perf: Dedup binviews up front (#20449) --- .../src/chunked_array/gather/chunked.rs | 95 +++++++++---------- 1 file changed, 43 insertions(+), 52 deletions(-) diff --git a/crates/polars-ops/src/chunked_array/gather/chunked.rs b/crates/polars-ops/src/chunked_array/gather/chunked.rs index 58ceafea92c9..eb02399a2f3e 100644 --- a/crates/polars-ops/src/chunked_array/gather/chunked.rs +++ b/crates/polars-ops/src/chunked_array/gather/chunked.rs @@ -4,7 +4,6 @@ use arrow::array::{Array, BinaryViewArrayGeneric, View, ViewType}; use arrow::bitmap::BitmapBuilder; use arrow::buffer::Buffer; use arrow::legacy::trusted_len::TrustedLenPush; -use hashbrown::hash_map::Entry; use polars_core::prelude::gather::_update_gather_sorted_flag; use polars_core::prelude::*; use polars_core::series::IsSorted; @@ -431,29 +430,6 @@ unsafe fn take_opt_unchecked_object(s: &Series, by: &[ChunkId]) builder.to_series() } -unsafe fn update_view( - mut view: View, - orig_buffers: &[Buffer], - buffer_idxs: &mut PlHashMap<(*const u8, usize), u32>, - buffers: &mut Vec>, -) -> View { - if view.length > 12 { - // Dedup on pointer + length. - let orig_buffer = orig_buffers.get_unchecked(view.buffer_idx as usize); - view.buffer_idx = - match buffer_idxs.entry((orig_buffer.as_slice().as_ptr(), orig_buffer.len())) { - Entry::Occupied(o) => *o.get(), - Entry::Vacant(v) => { - let buffer_idx = buffers.len() as u32; - buffers.push(orig_buffer.clone()); - v.insert(buffer_idx); - buffer_idx - }, - }; - } - view -} - unsafe fn take_unchecked_binview( ca: &ChunkedArray, by: &[ChunkId], @@ -497,8 +473,7 @@ where arc_data_buffers = arr.data_buffers().clone(); } else { - let mut buffer_idxs = PlHashMap::with_capacity(8); - let mut buffers = Vec::with_capacity(8); + let (buffers, buffer_offsets) = dedup_buffers(ca); validity = if ca.has_nulls() { let mut validity = BitmapBuilder::with_capacity(by.len()); @@ -511,12 +486,8 @@ where validity.push_unchecked(false); } else { let view = *arr.views().get_unchecked(array_idx as usize); - views.push_unchecked(update_view( - view, - arr.data_buffers(), - &mut buffer_idxs, - &mut buffers, - )); + let view = rewrite_view(view, chunk_idx, &buffer_offsets); + views.push_unchecked(view); validity.push_unchecked(true); } } @@ -527,12 +498,8 @@ where let arr = ca.downcast_get_unchecked(chunk_idx as usize); let view = *arr.views().get_unchecked(array_idx as usize); - views.push_unchecked(update_view( - view, - arr.data_buffers(), - &mut buffer_idxs, - &mut buffers, - )); + let view = rewrite_view(view, chunk_idx, &buffer_offsets); + views.push_unchecked(view); } None }; @@ -554,6 +521,39 @@ where out } +#[allow(clippy::unnecessary_cast)] +#[inline(always)] +unsafe fn rewrite_view(mut view: View, chunk_idx: IdxSize, buffer_offsets: &[u32]) -> View { + if view.length > 12 { + let base_offset = *buffer_offsets.get_unchecked(chunk_idx as usize); + view.buffer_idx += base_offset; + } + view +} + +fn dedup_buffers(ca: &ChunkedArray) -> (Vec>, Vec) +where + T: PolarsDataType>, + V: ViewType + ?Sized, +{ + // Dedup buffers up front. Note: don't do this during view update, as this is much more + // costly. + let mut buffers = Vec::with_capacity(ca.chunks().len()); + // Dont need to include the length, as we look at the arc pointers, which are immutable. + let mut buffers_dedup = PlHashSet::with_capacity(ca.chunks().len()); + let mut buffer_offsets = Vec::with_capacity(ca.chunks().len() + 1); + + for arr in ca.downcast_iter() { + let data_buffers = arr.data_buffers(); + let arc_ptr = data_buffers.as_ptr(); + buffer_offsets.push(buffers.len() as u32); + if buffers_dedup.insert(arc_ptr) { + buffers.extend(data_buffers.iter().cloned()) + } + } + (buffers, buffer_offsets) +} + unsafe fn take_unchecked_binview_opt( ca: &ChunkedArray, by: &[ChunkId], @@ -599,8 +599,7 @@ where arr.data_buffers().clone() } else { - let mut buffer_idxs = PlHashMap::with_capacity(8); - let mut buffers = Vec::with_capacity(8); + let (buffers, buffer_offsets) = dedup_buffers(ca); if ca.has_nulls() { for id in by.iter() { @@ -616,12 +615,8 @@ where validity.push_unchecked(false); } else { let view = *arr.views().get_unchecked(array_idx as usize); - views.push_unchecked(update_view( - view, - arr.data_buffers(), - &mut buffer_idxs, - &mut buffers, - )); + let view = rewrite_view(view, chunk_idx, &buffer_offsets); + views.push_unchecked(view); validity.push_unchecked(true); } } @@ -636,12 +631,8 @@ where } else { let arr = ca.downcast_get_unchecked(chunk_idx as usize); let view = *arr.views().get_unchecked(array_idx as usize); - views.push_unchecked(update_view( - view, - arr.data_buffers(), - &mut buffer_idxs, - &mut buffers, - )); + let view = rewrite_view(view, chunk_idx, &buffer_offsets); + views.push_unchecked(view); validity.push_unchecked(true); } }