Skip to content

Commit

Permalink
refactor: Keep scalar in more places (#18775)
Browse files Browse the repository at this point in the history
  • Loading branch information
coastalwhite authored Sep 23, 2024
1 parent 58265f6 commit ea7953e
Show file tree
Hide file tree
Showing 25 changed files with 390 additions and 126 deletions.
16 changes: 15 additions & 1 deletion crates/polars-arrow/src/array/struct_/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ use crate::compute::utils::combine_validities_and;
#[derive(Clone)]
pub struct StructArray {
dtype: ArrowDataType,
// invariant: each array has the same length
values: Vec<Box<dyn Array>>,
validity: Option<Bitmap>,
}
Expand Down Expand Up @@ -226,6 +227,17 @@ impl StructArray {
impl StructArray {
#[inline]
fn len(&self) -> usize {
#[cfg(debug_assertions)]
if let Some(fst) = self.values.first() {
for arr in self.values.iter().skip(1) {
assert_eq!(
arr.len(),
fst.len(),
"StructArray invariant: each array has same length"
);
}
}

self.values.first().map(|arr| arr.len()).unwrap_or(0)
}

Expand All @@ -242,7 +254,9 @@ impl StructArray {

/// Returns the fields of this [`StructArray`].
pub fn fields(&self) -> &[Field] {
Self::get_fields(&self.dtype)
let fields = Self::get_fields(&self.dtype);
debug_assert_eq!(self.values().len(), fields.len());
fields
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ where
pub fn upcast(&'a self) -> &'a RwLock<dyn MetadataTrait + 'a> {
&self.0 as &RwLock<dyn MetadataTrait + 'a>
}

/// Cast the [`IMMetadata`] to a boxed trait object of [`MetadataTrait`]
pub fn boxed_upcast(&'a self) -> Box<dyn MetadataTrait + 'a> {
Box::new(self.0.read().unwrap().clone()) as Box<dyn MetadataTrait + 'a>
}
}

impl<T: PolarsDataType> IMMetadata<T> {
Expand Down
7 changes: 7 additions & 0 deletions crates/polars-core/src/chunked_array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,13 @@ where
pub fn metadata_dyn(&self) -> Option<RwLockReadGuard<dyn MetadataTrait>> {
self.md.as_ref().upcast().try_read().ok()
}

/// Attempt to get a reference to the trait object containing the [`ChunkedArray`]'s [`Metadata`]
///
/// This fails if there is a need to block.
pub fn boxed_metadata_dyn<'a>(&'a self) -> Box<dyn MetadataTrait + 'a> {
self.md.as_ref().boxed_upcast()
}
}

impl<T: PolarsDataType> ChunkedArray<T> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ impl PolarsExtension {
.get(0)
.unwrap()
.into_static()
.unwrap()
}

pub(crate) unsafe fn new(array: FixedSizeBinaryArray) -> Self {
Expand Down
12 changes: 6 additions & 6 deletions crates/polars-core/src/chunked_array/ops/aggregate/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -483,11 +483,11 @@ impl StringChunked {
impl ChunkAggSeries for StringChunked {
fn max_reduce(&self) -> Scalar {
let av: AnyValue = self.max_str().into();
Scalar::new(DataType::String, av.into_static().unwrap())
Scalar::new(DataType::String, av.into_static())
}
fn min_reduce(&self) -> Scalar {
let av: AnyValue = self.min_str().into();
Scalar::new(DataType::String, av.into_static().unwrap())
Scalar::new(DataType::String, av.into_static())
}
}

Expand Down Expand Up @@ -554,11 +554,11 @@ impl CategoricalChunked {
impl ChunkAggSeries for CategoricalChunked {
fn min_reduce(&self) -> Scalar {
let av: AnyValue = self.min_categorical().into();
Scalar::new(DataType::String, av.into_static().unwrap())
Scalar::new(DataType::String, av.into_static())
}
fn max_reduce(&self) -> Scalar {
let av: AnyValue = self.max_categorical().into();
Scalar::new(DataType::String, av.into_static().unwrap())
Scalar::new(DataType::String, av.into_static())
}
}

Expand Down Expand Up @@ -618,11 +618,11 @@ impl ChunkAggSeries for BinaryChunked {
}
fn max_reduce(&self) -> Scalar {
let av: AnyValue = self.max_binary().into();
Scalar::new(self.dtype().clone(), av.into_static().unwrap())
Scalar::new(self.dtype().clone(), av.into_static())
}
fn min_reduce(&self) -> Scalar {
let av: AnyValue = self.min_binary().into();
Scalar::new(self.dtype().clone(), av.into_static().unwrap())
Scalar::new(self.dtype().clone(), av.into_static())
}
}

Expand Down
183 changes: 124 additions & 59 deletions crates/polars-core/src/datatypes/any_value.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
use std::borrow::Cow;

#[cfg(feature = "dtype-struct")]
use arrow::legacy::trusted_len::TrustedLenPush;
use arrow::types::PrimitiveType;
use polars_utils::format_pl_smallstr;
use polars_utils::itertools::Itertools;
#[cfg(feature = "dtype-struct")]
use polars_utils::slice::GetSaferUnchecked;
#[cfg(feature = "dtype-categorical")]
use polars_utils::sync::SyncPtr;
use polars_utils::total_ord::ToTotalOrd;
Expand Down Expand Up @@ -907,12 +902,34 @@ impl<'a> AnyValue<'a> {
}
}

pub(crate) fn to_i128(&self) -> Option<i128> {
match self {
AnyValue::UInt8(v) => Some((*v).into()),
AnyValue::UInt16(v) => Some((*v).into()),
AnyValue::UInt32(v) => Some((*v).into()),
AnyValue::UInt64(v) => Some((*v).into()),
AnyValue::Int8(v) => Some((*v).into()),
AnyValue::Int16(v) => Some((*v).into()),
AnyValue::Int32(v) => Some((*v).into()),
AnyValue::Int64(v) => Some((*v).into()),
_ => None,
}
}

pub(crate) fn to_f64(&self) -> Option<f64> {
match self {
AnyValue::Float32(v) => Some((*v).into()),
AnyValue::Float64(v) => Some(*v),
_ => None,
}
}

#[must_use]
pub fn add(&self, rhs: &AnyValue) -> AnyValue<'static> {
use AnyValue::*;
match (self, rhs) {
(Null, r) => r.clone().into_static().unwrap(),
(l, Null) => l.clone().into_static().unwrap(),
(Null, r) => r.clone().into_static(),
(l, Null) => l.clone().into_static(),
(Int32(l), Int32(r)) => Int32(l + r),
(Int64(l), Int64(r)) => Int64(l + r),
(UInt32(l), UInt32(r)) => UInt32(l + r),
Expand Down Expand Up @@ -961,9 +978,9 @@ impl<'a> AnyValue<'a> {
/// Try to coerce to an AnyValue with static lifetime.
/// This can be done if it does not borrow any values.
#[inline]
pub fn into_static(self) -> PolarsResult<AnyValue<'static>> {
pub fn into_static(self) -> AnyValue<'static> {
use AnyValue::*;
let av = match self {
match self {
Null => Null,
Int8(v) => Int8(v),
Int16(v) => Int16(v),
Expand Down Expand Up @@ -997,7 +1014,7 @@ impl<'a> AnyValue<'a> {
Object(v) => ObjectOwned(OwnedObject(v.to_boxed())),
#[cfg(feature = "dtype-struct")]
Struct(idx, arr, fields) => {
let avs = struct_to_avs_static(idx, arr, fields)?;
let avs = struct_to_avs_static(idx, arr, fields);
StructOwned(Box::new((avs, fields.to_vec())))
},
#[cfg(feature = "dtype-struct")]
Expand All @@ -1022,8 +1039,7 @@ impl<'a> AnyValue<'a> {
Enum(v, rev, arr) => EnumOwned(v, Arc::new(rev.clone()), arr),
#[cfg(feature = "dtype-categorical")]
EnumOwned(v, rev, arr) => EnumOwned(v, rev, arr),
};
Ok(av)
}
}

/// Get a reference to the `&str` contained within [`AnyValue`].
Expand Down Expand Up @@ -1070,6 +1086,37 @@ impl<'a> From<AnyValue<'a>> for Option<i64> {
impl AnyValue<'_> {
#[inline]
pub fn eq_missing(&self, other: &Self, null_equal: bool) -> bool {
fn struct_owned_value_iter<'a>(
v: &'a (Vec<AnyValue<'_>>, Vec<Field>),
) -> impl ExactSizeIterator<Item = AnyValue<'a>> {
v.0.iter().map(|v| v.as_borrowed())
}
fn struct_value_iter(
idx: usize,
arr: &StructArray,
) -> impl ExactSizeIterator<Item = AnyValue<'_>> {
assert!(idx < arr.len());

arr.values().iter().map(move |field_arr| unsafe {
// SAFETY: We asserted before that idx is smaller than the array length. Since it
// is an invariant of StructArray that all fields have the same length this is fine
// to do.
field_arr.get_unchecked(idx)
})
}

fn struct_eq_missing<'a>(
l: impl ExactSizeIterator<Item = AnyValue<'a>>,
r: impl ExactSizeIterator<Item = AnyValue<'a>>,
null_equal: bool,
) -> bool {
if l.len() != r.len() {
return false;
}

l.zip(r).all(|(lv, rv)| lv.eq_missing(&rv, null_equal))
}

use AnyValue::*;
match (self, other) {
// Map to borrowed.
Expand Down Expand Up @@ -1150,25 +1197,31 @@ impl AnyValue<'_> {
},
#[cfg(feature = "dtype-duration")]
(Duration(l, tu_l), Duration(r, tu_r)) => l == r && tu_l == tu_r,

#[cfg(feature = "dtype-struct")]
(StructOwned(l), StructOwned(r)) => {
let l_av = &*l.0;
let r_av = &*r.0;
l_av == r_av
},
(StructOwned(l), StructOwned(r)) => struct_eq_missing(
struct_owned_value_iter(l.as_ref()),
struct_owned_value_iter(r.as_ref()),
null_equal,
),
#[cfg(feature = "dtype-struct")]
(StructOwned(l), Struct(idx, arr, fields)) => {
l.0.iter()
.eq_by_(struct_av_iter(*idx, arr, fields), |lv, rv| *lv == rv)
},
(StructOwned(l), Struct(idx, arr, _)) => struct_eq_missing(
struct_owned_value_iter(l.as_ref()),
struct_value_iter(*idx, arr),
null_equal,
),
#[cfg(feature = "dtype-struct")]
(Struct(idx, arr, fields), StructOwned(r)) => {
struct_av_iter(*idx, arr, fields).eq_by_(r.0.iter(), |lv, rv| lv == *rv)
},
(Struct(idx, arr, _), StructOwned(r)) => struct_eq_missing(
struct_value_iter(*idx, arr),
struct_owned_value_iter(r.as_ref()),
null_equal,
),
#[cfg(feature = "dtype-struct")]
(Struct(l_idx, l_arr, l_fields), Struct(r_idx, r_arr, r_fields)) => {
struct_av_iter(*l_idx, l_arr, l_fields).eq(struct_av_iter(*r_idx, r_arr, r_fields))
},
(Struct(l_idx, l_arr, _), Struct(r_idx, r_arr, _)) => struct_eq_missing(
struct_value_iter(*l_idx, l_arr),
struct_value_iter(*r_idx, r_arr),
null_equal,
),
#[cfg(feature = "dtype-decimal")]
(Decimal(l_v, l_s), Decimal(r_v, r_s)) => {
// l_v / 10**l_s == r_v / 10**r_s
Expand Down Expand Up @@ -1198,9 +1251,34 @@ impl AnyValue<'_> {
},
#[cfg(feature = "object")]
(Object(l), Object(r)) => l == r,
#[cfg(feature = "dtype-array")]
(Array(l_values, l_size), Array(r_values, r_size)) => {
if l_size != r_size {
return false;
}

debug_assert_eq!(l_values.len(), *l_size);
debug_assert_eq!(r_values.len(), *r_size);

let mut is_equal = true;
for i in 0..*l_size {
let l = unsafe { l_values.get_unchecked(i) };
let r = unsafe { r_values.get_unchecked(i) };

is_equal &= l.eq_missing(&r, null_equal);
}
is_equal
},

(l, r) if l.to_i128().is_some() && r.to_i128().is_some() => l.to_i128() == r.to_i128(),
(l, r) if l.to_f64().is_some() && r.to_f64().is_some() => {
l.to_f64().unwrap().to_total_ord() == r.to_f64().unwrap().to_total_ord()
},

(_, _) => {
unimplemented!("ordering for mixed dtypes is not supported")
unimplemented!(
"scalar eq_missing for mixed dtypes {self:?} and {other:?} is not supported"
)
},
}
}
Expand Down Expand Up @@ -1346,7 +1424,9 @@ impl PartialOrd for AnyValue<'_> {
},

(_, _) => {
unimplemented!("ordering for mixed dtypes is not supported")
unimplemented!(
"scalar ordering for mixed dtypes {self:?} and {other:?} is not supported"
)
},
}
}
Expand All @@ -1360,23 +1440,22 @@ impl TotalEq for AnyValue<'_> {
}

#[cfg(feature = "dtype-struct")]
fn struct_to_avs_static(
idx: usize,
arr: &StructArray,
fields: &[Field],
) -> PolarsResult<Vec<AnyValue<'static>>> {
fn struct_to_avs_static(idx: usize, arr: &StructArray, fields: &[Field]) -> Vec<AnyValue<'static>> {
assert!(idx < arr.len());

let arrs = arr.values();
let mut avs = Vec::with_capacity(arrs.len());
// amortize loop counter
for i in 0..arrs.len() {
unsafe {
let arr = &**arrs.get_unchecked_release(i);
let field = fields.get_unchecked_release(i);
let av = arr_to_any_value(arr, idx, &field.dtype);
avs.push_unchecked(av.into_static()?);
}
}
Ok(avs)

debug_assert_eq!(arrs.len(), fields.len());

arrs.iter()
.zip(fields)
.map(|(arr, field)| {
// SAFETY: We asserted above that the length of StructArray is larger than `idx`. Since
// StructArray has the invariant that each array is the same length. This is okay to do
// now.
unsafe { arr_to_any_value(arr.as_ref(), idx, &field.dtype) }.into_static()
})
.collect()
}

#[cfg(feature = "dtype-categorical")]
Expand All @@ -1397,20 +1476,6 @@ fn same_revmap(
}
}

#[cfg(feature = "dtype-struct")]
fn struct_av_iter<'a>(
idx: usize,
arr: &'a StructArray,
fields: &'a [Field],
) -> impl Iterator<Item = AnyValue<'a>> {
let arrs = arr.values();
(0..arrs.len()).map(move |i| unsafe {
let arr = &**arrs.get_unchecked_release(i);
let field = fields.get_unchecked_release(i);
arr_to_any_value(arr, idx, &field.dtype)
})
}

pub trait GetAnyValue {
/// # Safety
///
Expand Down
Loading

0 comments on commit ea7953e

Please sign in to comment.