From dea0679d2b7d7a1223efd5c2d4d342f003b36948 Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Tue, 30 Jul 2024 13:57:55 +0200 Subject: [PATCH] feat(python): IO plugins (#17939) --- Cargo.lock | 1 + crates/polars-lazy/src/frame/python.rs | 7 +- crates/polars-lazy/src/tests/io.rs | 15 -- crates/polars-mem-engine/src/executors/mod.rs | 4 - .../src/executors/python_scan.rs | 53 ------- .../src/executors/scan/mod.rs | 4 + .../src/executors/scan/python_scan.rs | 134 ++++++++++++++++++ crates/polars-mem-engine/src/planner/lp.rs | 47 +++++- crates/polars-plan/Cargo.toml | 1 + .../polars-plan/src/plans/anonymous_scan.rs | 25 ++-- .../src/plans/conversion/dsl_to_ir.rs | 5 +- crates/polars-plan/src/plans/ir/dot.rs | 9 +- crates/polars-plan/src/plans/ir/format.rs | 8 +- crates/polars-plan/src/plans/ir/inputs.rs | 3 +- crates/polars-plan/src/plans/ir/mod.rs | 1 - crates/polars-plan/src/plans/ir/schema.rs | 4 +- crates/polars-plan/src/plans/mod.rs | 2 +- .../polars-plan/src/plans/optimizer/fused.rs | 10 ++ .../plans/optimizer/predicate_pushdown/mod.rs | 74 ++-------- .../optimizer/projection_pushdown/mod.rs | 7 +- .../src/plans/optimizer/slice_pushdown_lp.rs | 4 +- crates/polars-plan/src/plans/options.rs | 36 ++++- crates/polars-plan/src/plans/python/mod.rs | 2 + .../polars-plan/src/plans/python/predicate.rs | 69 +++++++++ .../src/plans/{ => python}/pyarrow.rs | 51 +++---- py-polars/polars/expr/expr.py | 7 +- py-polars/polars/io/plugins.py | 75 ++++++++++ py-polars/src/lazyframe/visit.rs | 9 +- py-polars/src/lazyframe/visitor/nodes.rs | 60 ++++---- py-polars/tests/unit/io/test_plugins.py | 54 +++++++ .../tests/unit/io/test_pyarrow_dataset.py | 2 +- 31 files changed, 541 insertions(+), 242 deletions(-) delete mode 100644 crates/polars-mem-engine/src/executors/python_scan.rs create mode 100644 crates/polars-mem-engine/src/executors/scan/python_scan.rs create mode 100644 crates/polars-plan/src/plans/python/mod.rs create mode 100644 crates/polars-plan/src/plans/python/predicate.rs 
rename crates/polars-plan/src/plans/{ => python}/pyarrow.rs (87%) create mode 100644 py-polars/polars/io/plugins.py create mode 100644 py-polars/tests/unit/io/test_plugins.py diff --git a/Cargo.lock b/Cargo.lock index 626ab52988d3..dd5534716945 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3323,6 +3323,7 @@ dependencies = [ "recursive", "regex", "serde", + "serde_json", "smartstring", "strum_macros 0.26.4", "version_check", diff --git a/crates/polars-lazy/src/frame/python.rs b/crates/polars-lazy/src/frame/python.rs index 75363c7194dc..e436bec8a727 100644 --- a/crates/polars-lazy/src/frame/python.rs +++ b/crates/polars-lazy/src/frame/python.rs @@ -7,9 +7,14 @@ impl LazyFrame { pub fn scan_from_python_function(schema: Schema, scan_fn: PyObject, pyarrow: bool) -> Self { DslPlan::PythonScan { options: PythonOptions { + // Should be a python function that returns a generator scan_fn: Some(scan_fn.into()), schema: Arc::new(schema), - pyarrow, + python_source: if pyarrow { + PythonScanSource::Pyarrow + } else { + PythonScanSource::IOPlugin + }, ..Default::default() }, } diff --git a/crates/polars-lazy/src/tests/io.rs b/crates/polars-lazy/src/tests/io.rs index 4493c705fbe0..ff8a93f0112b 100644 --- a/crates/polars-lazy/src/tests/io.rs +++ b/crates/polars-lazy/src/tests/io.rs @@ -649,21 +649,6 @@ fn scan_predicate_on_set_null_values() -> PolarsResult<()> { Ok(()) } -#[test] -fn scan_anonymous_fn() -> PolarsResult<()> { - let function = Arc::new(|_scan_opts: AnonymousScanArgs| Ok(fruits_cars())); - - let args = ScanArgsAnonymous { - schema: Some(Arc::new(fruits_cars().schema())), - ..ScanArgsAnonymous::default() - }; - - let df = LazyFrame::anonymous_scan(function, args)?.collect()?; - - assert_eq!(df.shape(), (5, 4)); - Ok(()) -} - #[test] fn scan_anonymous_fn_with_options() -> PolarsResult<()> { struct MyScan {} diff --git a/crates/polars-mem-engine/src/executors/mod.rs b/crates/polars-mem-engine/src/executors/mod.rs index 0300c832805c..5c9d093d986a 100644 --- 
a/crates/polars-mem-engine/src/executors/mod.rs +++ b/crates/polars-mem-engine/src/executors/mod.rs @@ -11,8 +11,6 @@ mod join; mod projection; mod projection_simple; mod projection_utils; -#[cfg(feature = "python")] -mod python_scan; mod scan; mod slice; mod sort; @@ -43,8 +41,6 @@ pub(super) use self::hconcat::*; pub(super) use self::join::*; pub(super) use self::projection::*; pub(super) use self::projection_simple::*; -#[cfg(feature = "python")] -pub(super) use self::python_scan::*; pub(super) use self::scan::*; pub(super) use self::slice::*; pub(super) use self::sort::*; diff --git a/crates/polars-mem-engine/src/executors/python_scan.rs b/crates/polars-mem-engine/src/executors/python_scan.rs deleted file mode 100644 index b8dd3e469fc7..000000000000 --- a/crates/polars-mem-engine/src/executors/python_scan.rs +++ /dev/null @@ -1,53 +0,0 @@ -use polars_core::error::to_compute_err; -use pyo3::prelude::*; - -use super::*; - -pub(crate) struct PythonScanExec { - pub(crate) options: PythonOptions, -} - -impl Executor for PythonScanExec { - fn execute(&mut self, state: &mut ExecutionState) -> PolarsResult { - state.should_stop()?; - #[cfg(debug_assertions)] - { - if state.verbose() { - eprintln!("run PythonScanExec") - } - } - let with_columns = self.options.with_columns.take(); - let pyarrow_predicate = self.options.predicate.take(); - let n_rows = self.options.n_rows.take(); - Python::with_gil(|py| { - let pl = PyModule::import_bound(py, "polars").unwrap(); - let utils = pl.getattr("_utils").unwrap(); - let callable = utils.getattr("_execute_from_rust").unwrap(); - - let python_scan_function = self.options.scan_fn.take().unwrap().0; - - let with_columns = with_columns.map(|cols| cols.iter().cloned().collect::>()); - - let out = callable - .call1(( - python_scan_function, - with_columns, - pyarrow_predicate, - n_rows, - )) - .map_err(to_compute_err)?; - let pydf = out.getattr("_df").unwrap(); - let raw_parts = pydf.call_method0("into_raw_parts").unwrap(); - let 
raw_parts = raw_parts.extract::<(usize, usize, usize)>().unwrap(); - - let (ptr, len, cap) = raw_parts; - unsafe { - Ok(DataFrame::new_no_checks(Vec::from_raw_parts( - ptr as *mut Series, - len, - cap, - ))) - } - }) - } -} diff --git a/crates/polars-mem-engine/src/executors/scan/mod.rs b/crates/polars-mem-engine/src/executors/scan/mod.rs index 93d86da8cb57..8dbb9750de7f 100644 --- a/crates/polars-mem-engine/src/executors/scan/mod.rs +++ b/crates/polars-mem-engine/src/executors/scan/mod.rs @@ -6,6 +6,8 @@ mod ipc; mod ndjson; #[cfg(feature = "parquet")] mod parquet; +#[cfg(feature = "python")] +mod python_scan; use std::mem; @@ -23,6 +25,8 @@ use polars_io::predicates::PhysicalIoExpr; use polars_io::prelude::*; use polars_plan::global::_set_n_rows_for_scan; +#[cfg(feature = "python")] +pub(crate) use self::python_scan::*; use super::*; use crate::prelude::*; diff --git a/crates/polars-mem-engine/src/executors/scan/python_scan.rs b/crates/polars-mem-engine/src/executors/scan/python_scan.rs new file mode 100644 index 000000000000..1b44453b088d --- /dev/null +++ b/crates/polars-mem-engine/src/executors/scan/python_scan.rs @@ -0,0 +1,134 @@ +use polars_core::error::to_compute_err; +use polars_core::utils::accumulate_dataframes_vertical; +use pyo3::exceptions::PyStopIteration; +use pyo3::prelude::*; +use pyo3::types::PyBytes; +use pyo3::{intern, PyTypeInfo}; + +use super::*; + +pub(crate) struct PythonScanExec { + pub(crate) options: PythonOptions, + pub(crate) predicate: Option>, + pub(crate) predicate_serialized: Option>, +} + +fn python_df_to_rust(py: Python, df: Bound) -> PolarsResult { + let err = |_| polars_err!(ComputeError: "expected a polars.DataFrame; got {}", df); + let pydf = df.getattr(intern!(py, "_df")).map_err(err)?; + let raw_parts = pydf + .call_method0(intern!(py, "into_raw_parts")) + .map_err(err)?; + let raw_parts = raw_parts.extract::<(usize, usize, usize)>().unwrap(); + + let (ptr, len, cap) = raw_parts; + unsafe { + 
Ok(DataFrame::new_no_checks(Vec::from_raw_parts( + ptr as *mut Series, + len, + cap, + ))) + } +} + +impl Executor for PythonScanExec { + fn execute(&mut self, state: &mut ExecutionState) -> PolarsResult { + state.should_stop()?; + #[cfg(debug_assertions)] + { + if state.verbose() { + eprintln!("run PythonScanExec") + } + } + let with_columns = self.options.with_columns.take(); + let n_rows = self.options.n_rows.take(); + Python::with_gil(|py| { + let pl = PyModule::import_bound(py, intern!(py, "polars")).unwrap(); + let utils = pl.getattr(intern!(py, "_utils")).unwrap(); + let callable = utils.getattr(intern!(py, "_execute_from_rust")).unwrap(); + + let python_scan_function = self.options.scan_fn.take().unwrap().0; + + let with_columns = with_columns.map(|cols| cols.iter().cloned().collect::>()); + + let predicate = match &self.options.predicate { + PythonPredicate::PyArrow(s) => s.into_py(py), + PythonPredicate::None => None::<()>.into_py(py), + PythonPredicate::Polars(_) => { + assert!(self.predicate.is_some(), "should be set"); + + match &self.predicate_serialized { + None => None::<()>.into_py(py), + Some(buf) => PyBytes::new_bound(py, buf).to_object(py), + } + }, + }; + + let generator_init = if matches!( + self.options.python_source, + PythonScanSource::Pyarrow | PythonScanSource::Cuda + ) { + let args = (python_scan_function, with_columns, predicate, n_rows); + callable.call1(args).map_err(to_compute_err) + } else { + // If there are filters, take smaller chunks to ensure we can keep memory + // pressure low. + let batch_size = if self.predicate.is_some() { + Some(100_000usize) + } else { + None + }; + let args = ( + python_scan_function, + with_columns, + predicate, + n_rows, + batch_size, + ); + callable.call1(args).map_err(to_compute_err) + }?; + + // This isn't a generator, but a `DataFrame`. + // This is the pyarrow and the CuDF path. 
+ if generator_init.getattr(intern!(py, "_df")).is_ok() { + let df = python_df_to_rust(py, generator_init)?; + return if let Some(pred) = &self.predicate { + let mask = pred.evaluate(&df, state)?; + df.filter(mask.bool()?) + } else { + Ok(df) + }; + } + + // This is the IO plugin path. + let generator = generator_init + .get_item(0) + .map_err(|_| polars_err!(ComputeError: "expected tuple got {}", generator_init))?; + let can_parse_predicate = generator_init + .get_item(1) + .map_err(|_| polars_err!(ComputeError: "expected tuple got {}", generator))?; + let can_parse_predicate = can_parse_predicate.extract::().map_err( + |_| polars_err!(ComputeError: "expected bool got {}", can_parse_predicate), + )?; + + let mut chunks = vec![]; + loop { + match generator.call_method0(intern!(py, "__next__")) { + Ok(out) => { + let mut df = python_df_to_rust(py, out)?; + if let (Some(pred), false) = (&self.predicate, can_parse_predicate) { + let mask = pred.evaluate(&df, state)?; + df = df.filter(mask.bool()?)?; + } + chunks.push(df) + }, + Err(err) if err.matches(py, PyStopIteration::type_object_bound(py)) => break, + Err(err) => { + polars_bail!(ComputeError: "caught exception during execution of a Python source, exception: {}", err) + }, + } + } + accumulate_dataframes_vertical(chunks) + }) + } +} diff --git a/crates/polars-mem-engine/src/planner/lp.rs b/crates/polars-mem-engine/src/planner/lp.rs index 693c6e92ffc0..643aa2f4c995 100644 --- a/crates/polars-mem-engine/src/planner/lp.rs +++ b/crates/polars-mem-engine/src/planner/lp.rs @@ -158,7 +158,52 @@ fn create_physical_plan_impl( let logical_plan = lp_arena.take(root); match logical_plan { #[cfg(feature = "python")] - PythonScan { options, .. 
} => Ok(Box::new(executors::PythonScanExec { options })), + PythonScan { mut options } => { + let mut predicate_serialized = None; + + let predicate = if let PythonPredicate::Polars(e) = &options.predicate { + let phys_expr = || { + let mut state = ExpressionConversionState::new(true, state.expr_depth); + create_physical_expr( + e, + Context::Default, + expr_arena, + Some(&options.schema), + &mut state, + ) + }; + + // Convert to a pyarrow eval string. + if matches!(options.python_source, PythonScanSource::Pyarrow) { + if let Some(eval_str) = polars_plan::plans::python::pyarrow::predicate_to_pa( + e.node(), + expr_arena, + Default::default(), + ) { + options.predicate = PythonPredicate::PyArrow(eval_str); + // We don't have to use a physical expression as pyarrow deals with the filter. + None + } else { + Some(phys_expr()?) + } + } + // Convert to physical expression for the case the reader cannot consume the predicate. + else { + let dsl_expr = e.to_expr(expr_arena); + predicate_serialized = + polars_plan::plans::python::predicate::serialize(&dsl_expr)?; + + Some(phys_expr()?) + } + } else { + None + }; + Ok(Box::new(executors::PythonScanExec { + options, + predicate, + predicate_serialized, + })) + }, Sink { payload, .. 
} => match payload { SinkType::Memory => { polars_bail!(InvalidOperation: "memory sink not supported in the standard engine") diff --git a/crates/polars-plan/Cargo.toml b/crates/polars-plan/Cargo.toml index a05b19720be6..3ad50ace7fd8 100644 --- a/crates/polars-plan/Cargo.toml +++ b/crates/polars-plan/Cargo.toml @@ -40,6 +40,7 @@ rayon = { workspace = true } recursive = { workspace = true } regex = { workspace = true, optional = true } serde = { workspace = true, features = ["rc"], optional = true } +serde_json = { workspace = true, optional = true } smartstring = { workspace = true } strum_macros = { workspace = true } diff --git a/crates/polars-plan/src/plans/anonymous_scan.rs b/crates/polars-plan/src/plans/anonymous_scan.rs index 203a17e21403..d426b12f9af4 100644 --- a/crates/polars-plan/src/plans/anonymous_scan.rs +++ b/crates/polars-plan/src/plans/anonymous_scan.rs @@ -19,24 +19,30 @@ pub trait AnonymousScan: Send + Sync { /// Creates a DataFrame from the supplied function & scan options. fn scan(&self, scan_opts: AnonymousScanArgs) -> PolarsResult; + /// Produce the next batch Polars can consume. Implement this method to get proper + /// streaming support. + fn next_batch(&self, scan_opts: AnonymousScanArgs) -> PolarsResult> { + self.scan(scan_opts).map(Some) + } + /// function to supply the schema. /// Allows for an optional infer schema argument for data sources with dynamic schemas fn schema(&self, _infer_schema_length: Option) -> PolarsResult { polars_bail!(ComputeError: "must supply either a schema or a schema function"); } - /// specify if the scan provider should allow predicate pushdowns + /// Specify if the scan provider should allow predicate pushdowns. /// /// Defaults to `false` fn allows_predicate_pushdown(&self) -> bool { false } - /// specify if the scan provider should allow projection pushdowns + /// Specify if the scan provider should allow projection pushdowns. 
/// /// Defaults to `false` fn allows_projection_pushdown(&self) -> bool { false } - /// specify if the scan provider should allow slice pushdowns + /// Specify if the scan provider should allow slice pushdowns. /// /// Defaults to `false` fn allows_slice_pushdown(&self) -> bool { @@ -44,19 +50,6 @@ pub trait AnonymousScan: Send + Sync { } } -impl AnonymousScan for F -where - F: Fn(AnonymousScanArgs) -> PolarsResult + Send + Sync, -{ - fn as_any(&self) -> &dyn Any { - unimplemented!() - } - - fn scan(&self, scan_opts: AnonymousScanArgs) -> PolarsResult { - self(scan_opts) - } -} - impl Debug for dyn AnonymousScan { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "anonymous_scan") diff --git a/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs b/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs index 4ed934d1423b..3ee71042b78d 100644 --- a/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs +++ b/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs @@ -259,10 +259,7 @@ pub fn to_alp_impl( } }, #[cfg(feature = "python")] - DslPlan::PythonScan { options } => IR::PythonScan { - options, - predicate: None, - }, + DslPlan::PythonScan { options } => IR::PythonScan { options }, DslPlan::Union { inputs, args } => { let mut inputs = inputs .into_iter() diff --git a/crates/polars-plan/src/plans/ir/dot.rs b/crates/polars-plan/src/plans/ir/dot.rs index 45bdd0bb281c..49e9bef1a3dc 100644 --- a/crates/polars-plan/src/plans/ir/dot.rs +++ b/crates/polars-plan/src/plans/ir/dot.rs @@ -153,11 +153,14 @@ impl<'a> IRDotDisplay<'a> { write_label(f, id, |f| write!(f, "FILTER BY {pred}"))?; }, #[cfg(feature = "python")] - PythonScan { predicate, options } => { - let predicate = predicate.as_ref().map(|e| self.display_expr(e)); + PythonScan { options } => { + let predicate = match &options.predicate { + PythonPredicate::Polars(e) => format!("{}", self.display_expr(e)), + PythonPredicate::PyArrow(s) => s.clone(), + PythonPredicate::None => "none".to_string(), + }; 
let with_columns = NumColumns(options.with_columns.as_ref().map(|s| s.as_ref())); let total_columns = options.schema.len(); - let predicate = OptionExprIRDisplay(predicate); write_label(f, id, |f| { write!( diff --git a/crates/polars-plan/src/plans/ir/format.rs b/crates/polars-plan/src/plans/ir/format.rs index dc5a6072f4c4..2ad3050c3d0a 100644 --- a/crates/polars-plan/src/plans/ir/format.rs +++ b/crates/polars-plan/src/plans/ir/format.rs @@ -154,7 +154,7 @@ impl<'a> IRDisplay<'a> { match self.root() { #[cfg(feature = "python")] - PythonScan { options, predicate } => { + PythonScan { options } => { let total_columns = options.schema.len(); let n_columns = options .with_columns @@ -162,7 +162,11 @@ impl<'a> IRDisplay<'a> { .map(|s| s.len() as i64) .unwrap_or(-1); - let predicate = predicate.as_ref().map(|p| self.display_expr(p)); + let predicate = match &options.predicate { + PythonPredicate::Polars(e) => Some(self.display_expr(e)), + PythonPredicate::PyArrow(_) => None, + PythonPredicate::None => None, + }; write_scan( f, diff --git a/crates/polars-plan/src/plans/ir/inputs.rs b/crates/polars-plan/src/plans/ir/inputs.rs index 2f941bc9282d..b00c91cddae4 100644 --- a/crates/polars-plan/src/plans/ir/inputs.rs +++ b/crates/polars-plan/src/plans/ir/inputs.rs @@ -7,9 +7,8 @@ impl IR { match self { #[cfg(feature = "python")] - PythonScan { options, predicate } => PythonScan { + PythonScan { options } => PythonScan { options: options.clone(), - predicate: predicate.clone(), }, Union { options, .. 
} => Union { inputs, diff --git a/crates/polars-plan/src/plans/ir/mod.rs b/crates/polars-plan/src/plans/ir/mod.rs index 5440fdb80686..095501995f55 100644 --- a/crates/polars-plan/src/plans/ir/mod.rs +++ b/crates/polars-plan/src/plans/ir/mod.rs @@ -37,7 +37,6 @@ pub enum IR { #[cfg(feature = "python")] PythonScan { options: PythonOptions, - predicate: Option, }, Slice { input: Node, diff --git a/crates/polars-plan/src/plans/ir/schema.rs b/crates/polars-plan/src/plans/ir/schema.rs index 46d793687345..5b5042e50377 100644 --- a/crates/polars-plan/src/plans/ir/schema.rs +++ b/crates/polars-plan/src/plans/ir/schema.rs @@ -51,7 +51,7 @@ impl IR { use IR::*; let schema = match self { #[cfg(feature = "python")] - PythonScan { options, .. } => &options.schema, + PythonScan { options } => &options.schema, DataFrameScan { schema, .. } => schema, Scan { file_info, .. } => &file_info.schema, node => { @@ -68,7 +68,7 @@ impl IR { use IR::*; let schema = match self { #[cfg(feature = "python")] - PythonScan { options, .. } => options.output_schema.as_ref().unwrap_or(&options.schema), + PythonScan { options } => options.output_schema.as_ref().unwrap_or(&options.schema), Union { inputs, .. } => return arena.get(inputs[0]).schema(arena), HConcat { schema, .. } => schema, Cache { input, .. 
} => return arena.get(*input).schema(arena), diff --git a/crates/polars-plan/src/plans/mod.rs b/crates/polars-plan/src/plans/mod.rs index 4f0ec8cafafb..8688521edeaf 100644 --- a/crates/polars-plan/src/plans/mod.rs +++ b/crates/polars-plan/src/plans/mod.rs @@ -29,7 +29,7 @@ mod lit; pub(crate) mod optimizer; pub(crate) mod options; #[cfg(feature = "python")] -mod pyarrow; +pub mod python; mod schema; pub mod visitor; diff --git a/crates/polars-plan/src/plans/optimizer/fused.rs b/crates/polars-plan/src/plans/optimizer/fused.rs index e6840f992fbc..d548147f65ce 100644 --- a/crates/polars-plan/src/plans/optimizer/fused.rs +++ b/crates/polars-plan/src/plans/optimizer/fused.rs @@ -65,6 +65,16 @@ impl OptimizationRule for FusedArithmetic { lp_arena: &Arena, lp_node: Node, ) -> PolarsResult> { + // We don't want to fuse arithmetic that we send to pyarrow. + #[cfg(feature = "python")] + if let IR::PythonScan { options } = lp_arena.get(lp_node) { + if matches!( + options.python_source, + PythonScanSource::Pyarrow | PythonScanSource::IOPlugin + ) { + return Ok(None); + } + }; let expr = expr_arena.get(expr_node); use AExpr::*; diff --git a/crates/polars-plan/src/plans/optimizer/predicate_pushdown/mod.rs b/crates/polars-plan/src/plans/optimizer/predicate_pushdown/mod.rs index 8e5ad9f9bb27..b410954fe551 100644 --- a/crates/polars-plan/src/plans/optimizer/predicate_pushdown/mod.rs +++ b/crates/polars-plan/src/plans/optimizer/predicate_pushdown/mod.rs @@ -657,69 +657,23 @@ impl<'a> PredicatePushDown<'a> { } }, #[cfg(feature = "python")] - PythonScan { - mut options, - predicate, - } => { - if options.pyarrow { - let predicate = predicate_at_scan(acc_predicates, predicate, expr_arena); - - if let Some(predicate) = predicate.clone() { - // simplify expressions before we translate them to pyarrow - let lp = PythonScan { - options: options.clone(), - predicate: Some(predicate), - }; - let lp_top = lp_arena.add(lp); - let stack_opt = StackOptimizer {}; - let lp_top = stack_opt - 
.optimize_loop( - &mut [Box::new(SimplifyExprRule {})], - expr_arena, - lp_arena, - lp_top, - ) - .unwrap(); - let PythonScan { - options: _, - predicate: Some(predicate), - } = lp_arena.take(lp_top) - else { - unreachable!() - }; - - match super::super::pyarrow::predicate_to_pa( - predicate.node(), + PythonScan { mut options } => { + let predicate = predicate_at_scan(acc_predicates, None, expr_arena); + if let Some(predicate) = predicate { + // Only accept streamable expressions as we want to apply the predicates to the batches. + if !is_streamable(predicate.node(), expr_arena, Context::Default) { + let lp = PythonScan { options }; + return Ok(self.optional_apply_predicate( + lp, + vec![predicate], + lp_arena, expr_arena, - Default::default(), - ) { - // we we able to create a pyarrow string, mutate the options - Some(eval_str) => options.predicate = Some(eval_str), - // we were not able to translate the predicate - // apply here - None => { - let lp = PythonScan { - options, - predicate: None, - }; - return Ok(self.optional_apply_predicate( - lp, - vec![predicate], - lp_arena, - expr_arena, - )); - }, - } + )); } - Ok(PythonScan { options, predicate }) - } else { - self.no_pushdown_restart_opt( - PythonScan { options, predicate }, - acc_predicates, - lp_arena, - expr_arena, - ) + + options.predicate = PythonPredicate::Polars(predicate); } + Ok(PythonScan { options }) }, Invalid => unreachable!(), } diff --git a/crates/polars-plan/src/plans/optimizer/projection_pushdown/mod.rs b/crates/polars-plan/src/plans/optimizer/projection_pushdown/mod.rs index 092292540d17..ed89ae336287 100644 --- a/crates/polars-plan/src/plans/optimizer/projection_pushdown/mod.rs +++ b/crates/polars-plan/src/plans/optimizer/projection_pushdown/mod.rs @@ -380,10 +380,7 @@ impl ProjectionPushDown { Ok(lp) }, #[cfg(feature = "python")] - PythonScan { - mut options, - predicate, - } => { + PythonScan { mut options } => { options.with_columns = get_scan_columns(&acc_projections, expr_arena, 
None, None); options.output_schema = if options.with_columns.is_none() { @@ -396,7 +393,7 @@ impl ProjectionPushDown { true, )?)) }; - Ok(PythonScan { options, predicate }) + Ok(PythonScan { options }) }, Scan { paths, diff --git a/crates/polars-plan/src/plans/optimizer/slice_pushdown_lp.rs b/crates/polars-plan/src/plans/optimizer/slice_pushdown_lp.rs index 14ad0e03a9be..949b4288af08 100644 --- a/crates/polars-plan/src/plans/optimizer/slice_pushdown_lp.rs +++ b/crates/polars-plan/src/plans/optimizer/slice_pushdown_lp.rs @@ -153,15 +153,13 @@ impl SlicePushDown { #[cfg(feature = "python")] (PythonScan { mut options, - predicate, }, // TODO! we currently skip slice pushdown if there is a predicate. // we can modify the readers to only limit after predicates have been applied - Some(state)) if state.offset == 0 && predicate.is_none() => { + Some(state)) if state.offset == 0 && matches!(options.predicate, PythonPredicate::None) => { options.n_rows = Some(state.len as usize); let lp = PythonScan { options, - predicate }; Ok(lp) } diff --git a/crates/polars-plan/src/plans/options.rs b/crates/polars-plan/src/plans/options.rs index dfae2c5917ec..e1232020711c 100644 --- a/crates/polars-plan/src/plans/options.rs +++ b/crates/polars-plan/src/plans/options.rs @@ -19,6 +19,7 @@ use polars_time::{DynamicGroupOptions, RollingGroupOptions}; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; +use crate::plans::ExprIR; #[cfg(feature = "python")] use crate::prelude::python_udf::PythonFunction; @@ -226,18 +227,43 @@ pub struct LogicalPlanUdfOptions { #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg(feature = "python")] pub struct PythonOptions { + /// A function that returns a Python Generator. + /// The generator should produce Polars DataFrame's. pub scan_fn: Option, + /// Schema of the file. pub schema: SchemaRef, + /// Schema the reader will produce when the file is read. pub output_schema: Option, + // Projected column names. 
pub with_columns: Option>, - pub pyarrow: bool, - // a pyarrow predicate python expression - // can be evaluated with python.eval - pub predicate: Option, - // a `head` call passed to pyarrow + // Which interface is the python function. + pub python_source: PythonScanSource, + /// Optional predicate the reader must apply. + #[cfg_attr(feature = "serde", serde(skip))] + pub predicate: PythonPredicate, + /// A `head` call passed to the reader. pub n_rows: Option, } +#[derive(Clone, PartialEq, Eq, Debug, Default)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum PythonScanSource { + Pyarrow, + Cuda, + #[default] + IOPlugin, +} + +#[derive(Clone, PartialEq, Eq, Debug, Default)] +pub enum PythonPredicate { + // A pyarrow predicate python expression + // can be evaluated with python.eval + PyArrow(String), + Polars(ExprIR), + #[default] + None, +} + #[derive(Clone, PartialEq, Eq, Debug, Default, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct AnonymousScanOptions { diff --git a/crates/polars-plan/src/plans/python/mod.rs b/crates/polars-plan/src/plans/python/mod.rs new file mode 100644 index 000000000000..e82c95f09e78 --- /dev/null +++ b/crates/polars-plan/src/plans/python/mod.rs @@ -0,0 +1,2 @@ +pub mod predicate; +pub mod pyarrow; diff --git a/crates/polars-plan/src/plans/python/predicate.rs b/crates/polars-plan/src/plans/python/predicate.rs new file mode 100644 index 000000000000..2e4a21af2749 --- /dev/null +++ b/crates/polars-plan/src/plans/python/predicate.rs @@ -0,0 +1,69 @@ +use polars_core::error::polars_err; +use polars_core::prelude::PolarsResult; + +use crate::prelude::*; + +fn accept_as_io_predicate(e: &Expr) -> bool { + const LIMIT: usize = 1 << 16; + match e { + Expr::Literal(lv) => match lv { + LiteralValue::Binary(v) => v.len() <= LIMIT, + LiteralValue::String(v) => v.len() <= LIMIT, + LiteralValue::Series(s) => s.estimated_size() < LIMIT, + // Don't accept dynamic types + LiteralValue::Int(_) => 
false, + LiteralValue::Float(_) => false, + _ => true, + }, + Expr::Wildcard | Expr::Column(_) => true, + Expr::BinaryExpr { left, right, .. } => { + accept_as_io_predicate(left) && accept_as_io_predicate(right) + }, + Expr::Ternary { + truthy, + falsy, + predicate, + } => { + accept_as_io_predicate(truthy) + && accept_as_io_predicate(falsy) + && accept_as_io_predicate(predicate) + }, + Expr::Alias(_, _) => true, + Expr::Function { + function, input, .. + } => { + match function { + // we already checked if streaming, so we can all functions + FunctionExpr::Boolean(_) | FunctionExpr::BinaryExpr(_) | FunctionExpr::Coalesce => { + }, + #[cfg(feature = "log")] + FunctionExpr::Entropy { .. } + | FunctionExpr::Log { .. } + | FunctionExpr::Log1p { .. } + | FunctionExpr::Exp { .. } => {}, + #[cfg(feature = "abs")] + FunctionExpr::Abs => {}, + #[cfg(feature = "trigonometry")] + FunctionExpr::Atan2 => {}, + #[cfg(feature = "round_series")] + FunctionExpr::Clip { .. } => {}, + #[cfg(feature = "fused")] + FunctionExpr::Fused(_) => {}, + _ => return false, + } + input.iter().all(accept_as_io_predicate) + }, + _ => false, + } +} + +pub fn serialize(expr: &Expr) -> PolarsResult>> { + if !accept_as_io_predicate(expr) { + return Ok(None); + } + let mut buf = vec![]; + ciborium::into_writer(expr, &mut buf) + .map_err(|_| polars_err!(ComputeError: "could not serialize: {}", expr))?; + + Ok(Some(buf)) +} diff --git a/crates/polars-plan/src/plans/pyarrow.rs b/crates/polars-plan/src/plans/python/pyarrow.rs similarity index 87% rename from crates/polars-plan/src/plans/pyarrow.rs rename to crates/polars-plan/src/plans/python/pyarrow.rs index 019f8d074b39..1232fcfde673 100644 --- a/crates/polars-plan/src/plans/pyarrow.rs +++ b/crates/polars-plan/src/plans/python/pyarrow.rs @@ -6,7 +6,7 @@ use polars_core::prelude::{TimeUnit, TimeZone}; use crate::prelude::*; #[derive(Default, Copy, Clone)] -pub(super) struct Args { +pub struct PyarrowArgs { // pyarrow doesn't allow `filter([True, False])` 
// but does allow `filter(field("a").isin([True, False]))` allow_literal_series: bool, @@ -22,10 +22,10 @@ fn to_py_datetime(v: i64, tu: &TimeUnit, tz: Option<&TimeZone>) -> String { } // convert to a pyarrow expression that can be evaluated with pythons eval -pub(super) fn predicate_to_pa( +pub fn predicate_to_pa( predicate: Node, expr_arena: &Arena<AExpr>, - args: Args, + args: PyarrowArgs, ) -> Option<String> { match expr_arena.get(predicate) { AExpr::BinaryExpr { left, right, op } => { @@ -38,7 +38,6 @@ pub(super) fn predicate_to_pa( } }, AExpr::Column(name) => Some(format!("pa.compute.field('{}')", name.as_ref())), - AExpr::Alias(input, _) => predicate_to_pa(*input, expr_arena, args), AExpr::Literal(LiteralValue::Series(s)) => { if !args.allow_literal_series || s.is_empty() || s.len() > 100 { None @@ -115,33 +114,6 @@ pub(super) fn predicate_to_pa( }, } }, - AExpr::Function { - function: FunctionExpr::Boolean(BooleanFunction::Not), - input, - .. - } => { - let input = input.first().unwrap().node(); - let input = predicate_to_pa(input, expr_arena, args)?; - Some(format!("~({input})")) - }, - AExpr::Function { - function: FunctionExpr::Boolean(BooleanFunction::IsNull), - input, - .. - } => { - let input = input.first().unwrap().node(); - let input = predicate_to_pa(input, expr_arena, args)?; - Some(format!("({input}).is_null()")) - }, - AExpr::Function { - function: FunctionExpr::Boolean(BooleanFunction::IsNotNull), - input, - .. - } => { - let input = input.first().unwrap().node(); - let input = predicate_to_pa(input, expr_arena, args)?; - Some(format!("~({input}).is_null()")) - }, #[cfg(feature = "is_in")] AExpr::Function { function: FunctionExpr::Boolean(BooleanFunction::IsIn), @@ -182,6 +154,23 @@ pub(super) fn predicate_to_pa( )) } }, + AExpr::Function { + function, input, .. 
+ } => { + let input = input.first().unwrap().node(); + let input = predicate_to_pa(input, expr_arena, args)?; + + match function { + FunctionExpr::Boolean(BooleanFunction::Not) => Some(format!("~({input})")), + FunctionExpr::Boolean(BooleanFunction::IsNull) => { + Some(format!("({input}).is_null()")) + }, + FunctionExpr::Boolean(BooleanFunction::IsNotNull) => { + Some(format!("~({input}).is_null()")) + }, + _ => None, + } + }, _ => None, } } diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py index ec65f40af0d7..92baadbaddee 100644 --- a/py-polars/polars/expr/expr.py +++ b/py-polars/polars/expr/expr.py @@ -341,7 +341,10 @@ def function(s: Series) -> Series: # pragma: no cover @classmethod def deserialize( - cls, source: str | Path | IOBase, *, format: SerializationFormat = "binary" + cls, + source: str | Path | IOBase | bytes, + *, + format: SerializationFormat = "binary", ) -> Expr: """ Read a serialized expression from a file. @@ -385,6 +388,8 @@ def deserialize( source = BytesIO(source.getvalue().encode()) elif isinstance(source, (str, Path)): source = normalize_filepath(source) + elif isinstance(source, bytes): + source = BytesIO(source) if format == "binary": deserializer = PyExpr.deserialize_binary diff --git a/py-polars/polars/io/plugins.py b/py-polars/polars/io/plugins.py new file mode 100644 index 000000000000..02f598515c1e --- /dev/null +++ b/py-polars/polars/io/plugins.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +import os +import sys +from typing import TYPE_CHECKING, Callable, Iterator + +import polars._reexport as pl +from polars._utils.unstable import unstable + +if TYPE_CHECKING: + from typing import Callable, Iterator + + from polars import DataFrame, Expr, LazyFrame + from polars._typing import SchemaDict + + +@unstable() +def register_io_source( + callable: Callable[ + [list[str] | None, Expr | None, int | None, int | None], Iterator[DataFrame] + ], + schema: SchemaDict, +) -> LazyFrame: + """ + Register your 
IO plugin and initialize a LazyFrame. + + Parameters + ---------- + callable + Function that accepts the following arguments: + with_columns + Columns that are projected. The reader must + project these columns if applied + predicate + Polars expression. The reader must filter + their rows accordingly. + n_rows: + Materialize only n rows from the source. + The reader can stop when `n_rows` are read. + batch_size + A hint of the ideal batch size the reader's + generator must produce. + The function should return a DataFrame batch + (an iterator over individual DataFrames). + schema + Schema that the reader will produce before projection pushdown. + + """ + + def wrap( + with_columns: list[str] | None, + predicate: bytes | None, + n_rows: int | None, + batch_size: int | None, + ) -> tuple[Iterator[DataFrame], bool]: + parsed_predicate_success = True + parsed_predicate = None + if predicate: + try: + parsed_predicate = pl.Expr.deserialize(predicate) + except Exception as e: + if os.environ.get("POLARS_VERBOSE"): + print( + f"failed parsing IO plugin expression\n\nfilter will be handled on Polars' side: {e}", + file=sys.stderr, + ) + parsed_predicate_success = False + + return callable( + with_columns, parsed_predicate, n_rows, batch_size + ), parsed_predicate_success + + return pl.LazyFrame._scan_python_function( + schema=schema, scan_fn=wrap, pyarrow=False + ) diff --git a/py-polars/src/lazyframe/visit.rs b/py-polars/src/lazyframe/visit.rs index 3beec16c72be..fad2e25fc7ee 100644 --- a/py-polars/src/lazyframe/visit.rs +++ b/py-polars/src/lazyframe/visit.rs @@ -2,7 +2,7 @@ use std::sync::Mutex; use polars_plan::plans::{to_aexpr, Context, IR}; use polars_plan::prelude::expr_ir::ExprIR; -use polars_plan::prelude::{AExpr, PythonOptions}; +use polars_plan::prelude::{AExpr, PythonOptions, PythonScanSource}; use polars_utils::arena::{Arena, Node}; use pyo3::prelude::*; use visitor::{expr_nodes, nodes}; @@ -54,7 +54,7 @@ impl NodeTraverser { // Incremement major on breaking 
changes to the IR (e.g. renaming // fields, reordering tuples), minor on backwards compatible // changes (e.g. exposing a new expression node). - const VERSION: Version = (0, 0); + const VERSION: Version = (1, 0); pub(crate) fn new(root: Node, lp_arena: Arena, expr_arena: Arena) -> Self { Self { @@ -164,11 +164,10 @@ impl NodeTraverser { schema, output_schema: None, with_columns: None, - pyarrow: false, - predicate: None, + python_source: PythonScanSource::Cuda, + predicate: Default::default(), n_rows: None, }, - predicate: None, }; lp_arena.replace(self.root, ir); } diff --git a/py-polars/src/lazyframe/visitor/nodes.rs b/py-polars/src/lazyframe/visitor/nodes.rs index ca9fbd067a92..833cdf76b239 100644 --- a/py-polars/src/lazyframe/visitor/nodes.rs +++ b/py-polars/src/lazyframe/visitor/nodes.rs @@ -1,7 +1,9 @@ use polars_core::prelude::{IdxSize, UniqueKeepStrategy}; use polars_ops::prelude::JoinType; use polars_plan::plans::IR; -use polars_plan::prelude::{FileCount, FileScan, FileScanOptions, FunctionNode}; +use polars_plan::prelude::{ + FileCount, FileScan, FileScanOptions, FunctionNode, PythonPredicate, PythonScanSource, +}; use pyo3::exceptions::{PyNotImplementedError, PyValueError}; use pyo3::prelude::*; @@ -14,8 +16,6 @@ use crate::PyDataFrame; pub struct PythonScan { #[pyo3(get)] options: PyObject, - #[pyo3(get)] - predicate: Option, } #[pyclass] @@ -257,29 +257,37 @@ pub struct Sink { pub(crate) fn into_py(py: Python<'_>, plan: &IR) -> PyResult { let result = match plan { - IR::PythonScan { options, predicate } => PythonScan { - options: ( - options - .scan_fn - .as_ref() - .map_or_else(|| py.None(), |s| s.0.clone()), - options - .with_columns - .as_ref() - .map_or_else(|| py.None(), |cols| cols.to_object(py)), - options.pyarrow, - options - .predicate - .as_ref() - .map_or_else(|| py.None(), |s| s.to_object(py)), - options - .n_rows - .map_or_else(|| py.None(), |s| s.to_object(py)), - ) - .to_object(py), - predicate: predicate.as_ref().map(|e| e.into()), - } 
- .into_py(py), + IR::PythonScan { options } => { + let python_src = match options.python_source { + PythonScanSource::Pyarrow => "pyarrow", + PythonScanSource::Cuda => "cuda", + PythonScanSource::IOPlugin => "io_plugin", + }; + + PythonScan { + options: ( + options + .scan_fn + .as_ref() + .map_or_else(|| py.None(), |s| s.0.clone()), + options + .with_columns + .as_ref() + .map_or_else(|| py.None(), |cols| cols.to_object(py)), + python_src, + match &options.predicate { + PythonPredicate::None => py.None(), + PythonPredicate::PyArrow(s) => ("pyarrow", s).to_object(py), + PythonPredicate::Polars(e) => ("polars", e.node().0).to_object(py), + }, + options + .n_rows + .map_or_else(|| py.None(), |s| s.to_object(py)), + ) + .to_object(py), + } + .into_py(py) + }, IR::Slice { input, offset, len } => Slice { input: input.0, offset: *offset, diff --git a/py-polars/tests/unit/io/test_plugins.py b/py-polars/tests/unit/io/test_plugins.py new file mode 100644 index 000000000000..98c25edc3f4a --- /dev/null +++ b/py-polars/tests/unit/io/test_plugins.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import polars as pl +from polars.io.plugins import register_io_source +from polars.testing import assert_frame_equal + +if TYPE_CHECKING: + from typing import Iterator + + +# A simple python source. But this can dispatch into a rust IO source as well. 
+def my_source( + with_columns: list[str] | None, + predicate: pl.Expr | None, + _n_rows: int | None, + _batch_size: int | None, +) -> Iterator[pl.DataFrame]: + for i in [1, 2, 3]: + df = pl.DataFrame({"a": [i], "b": [i]}) + + if predicate is not None: + df = df.filter(predicate) + + if with_columns is not None: + df = df.select(with_columns) + + yield df + + +def scan_my_source() -> pl.LazyFrame: + # schema inference logic + # TODO: make lazy via callable + schema = pl.Schema({"a": pl.Int64(), "b": pl.Int64()}) + + return register_io_source(my_source, schema=schema) + + +def test_my_source() -> None: + assert_frame_equal( + scan_my_source().collect(), pl.DataFrame({"a": [1, 2, 3], "b": [1, 2, 3]}) + ) + assert_frame_equal( + scan_my_source().filter(pl.col("b") > 1).collect(), + pl.DataFrame({"a": [2, 3], "b": [2, 3]}), + ) + assert_frame_equal( + scan_my_source().filter(pl.col("b") > 1).select("a").collect(), + pl.DataFrame({"a": [2, 3]}), + ) + assert_frame_equal( + scan_my_source().select("a").collect(), pl.DataFrame({"a": [1, 2, 3]}) + ) diff --git a/py-polars/tests/unit/io/test_pyarrow_dataset.py b/py-polars/tests/unit/io/test_pyarrow_dataset.py index cb579331c8f4..aa4bccb14717 100644 --- a/py-polars/tests/unit/io/test_pyarrow_dataset.py +++ b/py-polars/tests/unit/io/test_pyarrow_dataset.py @@ -28,7 +28,7 @@ def helper_dataset_test( @pytest.mark.write_disk() -def test_dataset_foo(df: pl.DataFrame, tmp_path: Path) -> None: +def test_pyarrow_dataset_source(df: pl.DataFrame, tmp_path: Path) -> None: file_path = tmp_path / "small.ipc" df.write_ipc(file_path)