fix: normalize distribution normalization w/ mcv (#163)
- As title
- Contains refactoring for HLL & TDigest
- Add support for 12 JOB queries

Next steps: get intermediate q-errors

---------

Co-authored-by: Gun9niR <gun9nir.guo@gmail.com>
AlSchlo and Gun9niR authored Apr 18, 2024
1 parent e282624 commit 831df6f
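
For context on the "intermediate q-errors" next step: q-error is the usual accuracy metric for cardinality estimation, the symmetric ratio between an estimate and the true value (1.0 is perfect). A minimal sketch of the metric, with a hypothetical helper name that is not part of this commit:

```rust
/// q-error of a cardinality estimate: max(est/actual, actual/est) >= 1.0.
/// q = 2.0 means the estimate is off by 2x in either direction.
fn q_error(estimate: f64, actual: f64) -> f64 {
    (estimate / actual).max(actual / estimate)
}
```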
Showing 7 changed files with 168 additions and 130 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default.

83 changes: 33 additions & 50 deletions optd-datafusion-repr/src/cost/base_cost/stats.rs
@@ -7,22 +7,19 @@ use datafusion::arrow::array::{
 };
 use itertools::Itertools;
 use optd_core::rel_node::{SerializableOrderedF64, Value};
-use optd_gungnir::{
-    stats::{
-        counter::Counter,
-        hyperloglog::{self, HyperLogLog},
-        misragries::{self, MisraGries},
-        tdigest::{self, TDigest},
-    },
-    utils::arith_encoder,
+use optd_gungnir::stats::{
+    counter::Counter,
+    hyperloglog::{self, HyperLogLog},
+    misragries::{self, MisraGries},
+    tdigest::{self, TDigest},
 };
 use ordered_float::OrderedFloat;
 use serde::{de::DeserializeOwned, Deserialize, Serialize};
 
 // The "standard" concrete types that optd currently uses.
 // All of optd (except unit tests) must use the same types.
 pub type DataFusionMostCommonValues = Counter<Vec<Option<Value>>>;
-pub type DataFusionDistribution = TDigest;
+pub type DataFusionDistribution = TDigest<Value>;
 
 pub type DataFusionBaseTableStats =
     BaseTableStats<DataFusionMostCommonValues, DataFusionDistribution>;
@@ -40,27 +37,14 @@ pub trait Distribution: 'static + Send + Sync {
     fn cdf(&self, value: &Value) -> f64;
 }
 
-fn value_to_float(val: &Value) -> f64 {
-    match val {
-        Value::UInt8(v) => *v as f64,
-        Value::UInt16(v) => *v as f64,
-        Value::UInt32(v) => *v as f64,
-        Value::UInt64(v) => *v as f64,
-        Value::Int8(v) => *v as f64,
-        Value::Int16(v) => *v as f64,
-        Value::Int32(v) => *v as f64,
-        Value::Int64(v) => *v as f64,
-        Value::Float(v) => *v.0,
-        Value::Bool(v) => *v as i64 as f64,
-        Value::String(v) => arith_encoder::encode(v),
-        Value::Date32(v) => *v as f64,
-        _ => unreachable!(),
-    }
-}
-
-impl Distribution for TDigest {
+impl Distribution for TDigest<Value> {
     fn cdf(&self, value: &Value) -> f64 {
-        self.cdf(value_to_float(value))
+        let nb_rows = self.norm_weight;
+        if nb_rows == 0 {
+            self.cdf(value)
+        } else {
+            self.centroids.len() as f64 * self.cdf(value) / nb_rows as f64
+        }
     }
 }
 
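This rescaling is the heart of the fix: most-common values (MCVs) are tracked exactly by the `Counter` and kept out of the `TDigest`, so the digest's raw CDF covers only the non-MCV rows and must be scaled down by the digest's share of the table (`norm_weight` is populated during stats generation below). A hedged numeric illustration with a hypothetical helper, assuming the scale factor is (values in digest) / (total rows):

```rust
/// Hypothetical helper, not optd's API. If 40 of 100 rows are MCVs excluded
/// from the digest, a raw digest CDF of 0.5 over the remaining 60 values
/// accounts for 0.5 * 60 / 100 = 0.3 of all rows; the MCV counter then
/// contributes the exact frequencies of the other 40 on top.
fn rescale_cdf(raw_cdf: f64, values_in_digest: f64, total_rows: f64) -> f64 {
    raw_cdf * values_in_digest / total_rows
}
// rescale_cdf(0.5, 60.0, 100.0) == 0.3
```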
@@ -161,7 +145,7 @@ impl<
 
 pub type BaseTableStats<M, D> = HashMap<String, TableStats<M, D>>;
 
-impl TableStats<Counter<ColumnCombValue>, TDigest> {
+impl TableStats<Counter<ColumnCombValue>, TDigest<Value>> {
     fn is_type_supported(data_type: &DataType) -> bool {
         matches!(
             data_type,
@@ -288,9 +272,8 @@ impl TableStats<Counter<ColumnCombValue>, TDigest> {
     fn generate_partial_stats(
         column_combs: &[Vec<ColumnCombValue>],
         mgs: &mut [MisraGries<ColumnCombValue>],
-        hlls: &mut [HyperLogLog],
+        hlls: &mut [HyperLogLog<ColumnCombValue>],
         null_counts: &mut [i32],
-        row_counts: &mut [i32],
     ) {
         for (idx, column_comb) in column_combs.iter().enumerate() {
             // TODO(Alexis): Redundant copy.
@@ -302,7 +285,6 @@ impl TableStats<Counter<ColumnCombValue>, TDigest> {
             let nb_rows: i32 = column_comb.len() as i32;
 
             null_counts[idx] += nb_rows - filtered_nulls.len() as i32;
-            row_counts[idx] += nb_rows;
 
             mgs[idx].aggregate(&filtered_nulls);
             hlls[idx].aggregate(&filtered_nulls);
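
For readers unfamiliar with the first-pass sketches: Misra-Gries maintains at most k counters over a stream and guarantees that every value occurring more than n/(k+1) times survives, which is what makes it a safe way to pick MCV candidates. A compact standalone sketch of the algorithm (the HashMap representation and threshold are illustrative, not optd's `misragries` module):

```rust
use std::collections::HashMap;

/// Misra-Gries heavy hitters: keep at most k counters; when a new value
/// arrives with no free slot, decrement every counter and drop the zeros.
/// Survivors are a superset of all values with frequency > n / (k + 1).
fn misra_gries<T: Eq + std::hash::Hash + Clone>(stream: &[T], k: usize) -> HashMap<T, usize> {
    let mut counters: HashMap<T, usize> = HashMap::new();
    for x in stream {
        if let Some(c) = counters.get_mut(x) {
            *c += 1;
        } else if counters.len() < k {
            counters.insert(x.clone(), 1);
        } else {
            counters.retain(|_, c| {
                *c -= 1;
                *c > 0
            });
        }
    }
    counters
}
```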
@@ -312,25 +294,26 @@ impl TableStats<Counter<ColumnCombValue>, TDigest> {
     fn generate_full_stats(
         column_combs: &[Vec<ColumnCombValue>],
         cnts: &mut [Counter<ColumnCombValue>],
-        distrs: &mut [Option<TDigest>],
+        distrs: &mut [Option<TDigest<Value>>],
+        row_counts: &mut [i32],
     ) {
         for (idx, column_comb) in column_combs.iter().enumerate() {
-            // TODO(Alexis): Redundant copy.
-            // Here, we filter out mfks, so it's guaranteed to never be null.
-            let filtered_mfks: Vec<ColumnCombValue> = column_comb
-                .iter()
-                .filter(|row| cnts[idx].is_tracking(row))
-                .cloned()
-                .collect();
+            let nb_rows: i32 = column_comb.len() as i32;
+            row_counts[idx] += nb_rows;
 
-            cnts[idx].aggregate(&filtered_mfks);
-            if let Some(distr) = distrs[idx].take() {
+            cnts[idx].aggregate(column_comb);
+            if let Some(distr) = &mut distrs[idx] {
                 // TODO(Alexis): Redundant copy.
                 // We project it down to 1D, as we do not support nD TDigests.
-                let mut single_col_f64 = filtered_mfks
+                let single_col_filtered = column_comb
                     .iter()
-                    .map(|row| value_to_float(row[0].as_ref().unwrap()))
+                    .filter(|row| !cnts[idx].is_tracking(row))
+                    .filter_map(|row| row[0].as_ref())
+                    .cloned()
                     .collect_vec();
-                distrs[idx] = Some(distr.merge_values(&mut single_col_f64));
+
+                distr.norm_weight += nb_rows as usize;
+                distr.merge_values(&single_col_filtered);
             }
         }
     }
@@ -356,11 +339,10 @@ impl TableStats<Counter<ColumnCombValue>, TDigest> {
             });
         }
 
-        // 1. FIRST PASS: hlls + mgs + null_cnts + row_cnts.
+        // 1. FIRST PASS: hlls + mgs + null_cnts.
         let mut hlls = vec![HyperLogLog::new(hyperloglog::DEFAULT_PRECISION); nb_stats];
         let mut mgs = vec![MisraGries::new(misragries::DEFAULT_K_TO_TRACK); nb_stats];
         let mut null_cnts = vec![0; nb_stats];
-        let mut row_cnts = vec![0; nb_stats]; // All the same, but more convenient like this.
 
         for batch in batch_iter {
             let batch = batch?;
@@ -369,11 +351,10 @@ impl TableStats<Counter<ColumnCombValue>, TDigest> {
                 &mut mgs,
                 &mut hlls,
                 &mut null_cnts,
-                &mut row_cnts,
             );
         }
 
-        // 2. SECOND PASS: MCV + TDigest.
+        // 2. SECOND PASS: MCV + TDigest + row_cnts.
        let batch_iter = batch_iter_builder()?;
        let mut distrs = comb_stat_types
            .iter()
@@ -389,13 +370,15 @@ impl TableStats<Counter<ColumnCombValue>, TDigest> {
                 Counter::new(&mfk)
             })
             .collect_vec();
+        let mut row_cnts = vec![0; nb_stats]; // All the same, but more convenient like this.
 
         for batch in batch_iter {
             let batch = batch?;
             Self::generate_full_stats(
                 &Self::get_column_combs(&batch, &comb_stat_types),
                 &mut cnts,
                 &mut distrs,
+                &mut row_cnts,
             );
         }
 
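Taken together, the two passes now split cleanly: pass 1 builds the sketches that need no prior knowledge (Misra-Gries candidates, HyperLogLog, null counts), and pass 2 re-streams the data to count the pinned MCVs exactly while routing every other value, plus the row counts, into the TDigest. A self-contained sketch of that flow under hypothetical names (plain `Vec`/`HashMap` stand in for Arrow batches, Misra-Gries, and the digest):

```rust
use std::collections::HashMap;

/// Two-pass statistics collection over a re-scannable column of i64s.
fn two_pass_stats(scan: impl Fn() -> Vec<i64>) -> (HashMap<i64, usize>, Vec<i64>) {
    // Pass 1: find heavy-hitter candidates (exact counting stands in for MG).
    let mut approx: HashMap<i64, usize> = HashMap::new();
    for v in scan() {
        *approx.entry(v).or_default() += 1;
    }
    let mcvs: Vec<i64> = approx
        .into_iter()
        .filter(|&(_, c)| c >= 2) // stand-in for the MG survivor set
        .map(|(v, _)| v)
        .collect();

    // Pass 2: exact frequencies for the pinned MCVs; everything else goes to
    // the distribution sketch (a sorted Vec stands in for the TDigest).
    let mut mcv_counts: HashMap<i64, usize> = mcvs.into_iter().map(|v| (v, 0)).collect();
    let mut distribution = Vec::new();
    for v in scan() {
        match mcv_counts.get_mut(&v) {
            Some(c) => *c += 1,
            None => distribution.push(v),
        }
    }
    distribution.sort_unstable();
    (mcv_counts, distribution)
}
```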
1 change: 1 addition & 0 deletions optd-gungnir/Cargo.toml
@@ -12,4 +12,5 @@ crossbeam = "0.8"
 lazy_static = "1.4"
 serde = {version = "1.0", features = ["derive"]}
 serde_with = {version = "3.7.0", features = ["json"]}
+ordered-float = "4"
 optd-core = { path = "../optd-core" }
39 changes: 19 additions & 20 deletions optd-gungnir/src/stats/hyperloglog.rs
@@ -8,7 +8,7 @@
 use optd_core::rel_node::Value;
 
 use crate::stats::murmur2::murmur_hash;
-use std::cmp::max;
+use std::{cmp::max, marker::PhantomData};
 
 pub const DEFAULT_PRECISION: u8 = 12;
 
@@ -20,11 +20,13 @@ pub trait ByteSerializable {
 /// The HyperLogLog (HLL) structure to provide a statistical estimate of NDistinct.
 /// For safety reasons, HLLs can only count elements of the same ByteSerializable type.
 #[derive(Clone)]
-pub struct HyperLogLog {
+pub struct HyperLogLog<T: ByteSerializable> {
     registers: Vec<u8>, // The buckets to estimate HLL on (i.e. upper p bits).
     precision: u8,      // The precision (p) of our HLL; 4 <= p <= 16.
     m: usize,           // The number of HLL buckets; 2^p.
     alpha: f64,         // The normal HLL multiplier factor.
+
+    data_type: PhantomData<T>, // For type checker.
 }
 
 // Serialize optd's Value.
@@ -86,7 +88,10 @@ impl_byte_serializable_for_numeric!(usize, isize);
 impl_byte_serializable_for_numeric!(f64, f32);
 
 // Self-contained implementation of the HyperLogLog data structure.
-impl HyperLogLog {
+impl<T> HyperLogLog<T>
+where
+    T: ByteSerializable,
+{
     /// Creates and initializes a new empty HyperLogLog.
     pub fn new(precision: u8) -> Self {
         assert!((4..=16).contains(&precision));
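The new generic parameter exists purely for compile-time safety: `PhantomData<T>` is zero-sized and stores nothing, but it brands each HLL with its element type so that sketches built over different types cannot be mixed. A minimal standalone sketch of the pattern (`Sketch` is hypothetical, not optd's type):

```rust
use std::marker::PhantomData;

struct Sketch<T> {
    registers: Vec<u8>,
    _ty: PhantomData<T>, // zero-sized; only constrains the type
}

impl<T> Sketch<T> {
    fn new(m: usize) -> Self {
        Sketch { registers: vec![0; m], _ty: PhantomData }
    }
    // Sketch<u64> can merge another Sketch<u64>, never a Sketch<String>;
    // the mismatch is a compile error instead of a silent estimation bug.
    fn merge(&mut self, other: &Sketch<T>) {
        for (a, b) in self.registers.iter_mut().zip(&other.registers) {
            *a = (*a).max(*b);
        }
    }
}
```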
@@ -99,11 +104,13 @@ impl HyperLogLog {
             precision,
             m,
             alpha,
+
+            data_type: PhantomData,
         }
     }
 
     /// Digests an array of ByteSerializable data into the HLL.
-    pub fn aggregate<T>(&mut self, data: &[T])
+    pub fn aggregate(&mut self, data: &[T])
     where
         T: ByteSerializable,
     {
@@ -117,23 +124,15 @@
 
     /// Merges two HLLs together and returns a new one.
     /// Particularly useful for parallel execution.
-    /// NOTE: Takes ownership of self and other.
-    pub fn merge(self, other: HyperLogLog) -> Self {
+    pub fn merge(&mut self, other: &HyperLogLog<T>) {
         assert!(self.precision == other.precision);
 
-        let merged_registers = self
+        self.registers = self
             .registers
-            .into_iter()
-            .zip(other.registers)
-            .map(|(x, y)| x.max(y))
+            .iter()
+            .zip(other.registers.iter())
+            .map(|(&x, &y)| x.max(y))
             .collect();
-
-        HyperLogLog {
-            registers: merged_registers,
-            precision: self.precision,
-            m: self.m,
-            alpha: self.alpha,
-        }
     }
 
     /// Returns an estimation of the n_distinct seen so far by the HLL.
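Register-wise max is a lossless merge for HLLs: each register keeps the largest leading-zero rank observed among the hashes routed to it, and the max over two runs equals the max over their union, so merging is commutative, associative, and idempotent. A hedged usage sketch of the new in-place API (mirrors the test below; assumes the elided numeric `ByteSerializable` impls cover `i32`):

```rust
fn parallel_merge_demo() {
    let mut global: HyperLogLog<i32> = HyperLogLog::new(DEFAULT_PRECISION);
    let mut local: HyperLogLog<i32> = HyperLogLog::new(DEFAULT_PRECISION);
    local.aggregate(&[1, 2, 3]);
    global.merge(&local); // fold worker results in any order
    assert!(global.n_distinct() > 0);
}
```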
@@ -256,7 +255,7 @@ mod tests {
         let n_jobs = 16;
         let relative_error = 0.05; // We allow a 5% relatative error rate.
 
-        let result_hll = Arc::new(Mutex::new(Option::Some(HyperLogLog::new(precision))));
+        let result_hll = Arc::new(Mutex::new(HyperLogLog::new(precision)));
         let job_id = AtomicUsize::new(0);
         thread::scope(|s| {
             for _ in 0..n_jobs {
@@ -274,13 +273,13 @@
                     ));
 
                     let mut result = result_hll.lock().unwrap();
-                    *result = Option::Some(result.take().unwrap().merge(local_hll));
+                    result.merge(&local_hll);
                 });
             }
         })
         .unwrap();
 
-        let hll = result_hll.lock().unwrap().take().unwrap();
+        let hll = result_hll.lock().unwrap();
         assert!(is_close(
             hll.n_distinct() as f64,
             (n_distinct * n_jobs) as f64,
(Diffs for the remaining changed files did not load.)
