-
Notifications
You must be signed in to change notification settings - Fork 42
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
36 changed files
with
624 additions
and
89 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
[workspace] | ||
resolver = "2" | ||
members = [ | ||
"example/derive_expression/expression_lib", | ||
"example/extend_polars_python_dispatch/extend_polars", | ||
"pyo3-polars", | ||
"pyo3-polars-derive", | ||
] | ||
|
||
[workspace.dependencies] | ||
polars = {version = "0.33.2", default-features=false} | ||
polars-core = {version = "0.33.2", default-features=false} | ||
polars-ffi = {ersion = "0.33.2", default-features=false} | ||
polars-plan = {version = "0.33.2", default-feautres=false} | ||
polars-lazy = {version = "0.33.2", default-features=false} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
|
||
SHELL=/bin/bash | ||
|
||
venv: ## Set up virtual environment | ||
python3 -m venv venv | ||
venv/bin/pip install -r requirements.txt | ||
|
||
install: venv | ||
unset CONDA_PREFIX && \ | ||
source venv/bin/activate && maturin develop -m expression_lib/Cargo.toml | ||
|
||
install-release: venv | ||
unset CONDA_PREFIX && \ | ||
source venv/bin/activate && maturin develop --release -m expression_lib/Cargo.toml | ||
|
||
clean: | ||
-@rm -r venv | ||
-@cd experssion_lib && cargo clean | ||
|
||
|
||
run: install | ||
source venv/bin/activate && python run.py | ||
|
||
run-release: install-release | ||
source venv/bin/activate && python run.py |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
[package] | ||
name = "expression_lib" | ||
version = "0.1.0" | ||
edition = "2021" | ||
|
||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html | ||
[lib] | ||
name = "expression_lib" | ||
crate-type = ["cdylib"] | ||
|
||
[dependencies] | ||
pyo3 = { version = "0.19.0", features = ["extension-module"] } | ||
pyo3-polars = { version = "*", path = "../../../pyo3-polars", features=["derive"] } | ||
polars = { workspace = true, features = ["fmt"], default-features=false } | ||
polars-plan = { workspace = true, default-features=false } |
48 changes: 48 additions & 0 deletions
48
example/derive_expression/expression_lib/expression_lib/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import polars as pl | ||
from polars.type_aliases import IntoExpr | ||
from polars.utils.udfs import _get_shared_lib_location | ||
|
||
lib = _get_shared_lib_location(__file__) | ||
|
||
|
||
@pl.api.register_expr_namespace("language") | ||
class Language: | ||
def __init__(self, expr: pl.Expr): | ||
self._expr = expr | ||
|
||
def pig_latinnify(self) -> pl.Expr: | ||
return self._expr._register_plugin( | ||
lib=lib, | ||
symbol="pig_latinnify", | ||
is_elementwise=True, | ||
) | ||
|
||
@pl.api.register_expr_namespace("dist") | ||
class Distance: | ||
def __init__(self, expr: pl.Expr): | ||
self._expr = expr | ||
|
||
def hamming_distance(self, other: IntoExpr) -> pl.Expr: | ||
return self._expr._register_plugin( | ||
lib=lib, | ||
args=[other], | ||
symbol="hamming_distance", | ||
is_elementwise=True, | ||
) | ||
|
||
def jaccard_similarity(self, other: IntoExpr) -> pl.Expr: | ||
return self._expr._register_plugin( | ||
lib=lib, | ||
args=[other], | ||
symbol="jaccard_similarity", | ||
is_elementwise=True, | ||
) | ||
|
||
def haversine(self, start_lat: IntoExpr, start_long: IntoExpr, end_lat: IntoExpr, end_long: IntoExpr) -> pl.Expr: | ||
return self._expr._register_plugin( | ||
lib=lib, | ||
args=[start_lat, start_long, end_lat, end_long], | ||
symbol="haversine", | ||
is_elementwise=True, | ||
cast_to_supertypes=True | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
[build-system] | ||
requires = ["maturin>=1.0,<2.0"] | ||
build-backend = "maturin" | ||
|
||
[project] | ||
name = "expression_lib" | ||
requires-python = ">=3.8" | ||
classifiers = [ | ||
"Programming Language :: Rust", | ||
"Programming Language :: Python :: Implementation :: CPython", | ||
"Programming Language :: Python :: Implementation :: PyPy", | ||
] | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
use polars::datatypes::PlHashSet; | ||
use polars::export::arrow::array::PrimitiveArray; | ||
use polars::export::num::Float; | ||
use polars::prelude::*; | ||
use pyo3_polars::export::polars_core::utils::arrow::types::NativeType; | ||
use pyo3_polars::export::polars_core::with_match_physical_integer_type; | ||
use std::hash::Hash; | ||
|
||
#[allow(clippy::all)] | ||
pub(super) fn naive_hamming_dist(a: &str, b: &str) -> u32 { | ||
let x = a.as_bytes(); | ||
let y = b.as_bytes(); | ||
x.iter() | ||
.zip(y) | ||
.fold(0, |a, (b, c)| a + (*b ^ *c).count_ones() as u32) | ||
} | ||
|
||
fn jacc_helper<T: NativeType + Hash + Eq>(a: &PrimitiveArray<T>, b: &PrimitiveArray<T>) -> f64 { | ||
// convert to hashsets over Option<T> | ||
let s1 = a.into_iter().collect::<PlHashSet<_>>(); | ||
let s2 = b.into_iter().collect::<PlHashSet<_>>(); | ||
|
||
// count the number of intersections | ||
let s3_len = s1.intersection(&s2).count(); | ||
// return similarity | ||
s3_len as f64 / (s1.len() + s2.len() - s3_len) as f64 | ||
} | ||
|
||
pub(super) fn naive_jaccard_sim(a: &ListChunked, b: &ListChunked) -> PolarsResult<Float64Chunked> { | ||
polars_ensure!( | ||
a.inner_dtype() == b.inner_dtype(), | ||
ComputeError: "inner data types don't match" | ||
); | ||
polars_ensure!( | ||
a.inner_dtype().is_integer(), | ||
ComputeError: "inner data types must be integer" | ||
); | ||
Ok(with_match_physical_integer_type!(a.inner_dtype(), |$T| { | ||
polars::prelude::arity::binary_elementwise(a, b, |a, b| { | ||
match (a, b) { | ||
(Some(a), Some(b)) => { | ||
let a = a.as_any().downcast_ref::<PrimitiveArray<$T>>().unwrap(); | ||
let b = b.as_any().downcast_ref::<PrimitiveArray<$T>>().unwrap(); | ||
Some(jacc_helper(a, b)) | ||
}, | ||
_ => None | ||
} | ||
}) | ||
})) | ||
} | ||
|
||
fn haversine_elementwise<T: Float>(start_lat: T, start_long: T, end_lat: T, end_long: T) -> T { | ||
let r_in_km = T::from(6371.0).unwrap(); | ||
let two = T::from(2.0).unwrap(); | ||
let one = T::one(); | ||
|
||
let d_lat = (end_lat - start_lat).to_radians(); | ||
let d_lon = (end_long - start_long).to_radians(); | ||
let lat1 = (start_lat).to_radians(); | ||
let lat2 = (end_lat).to_radians(); | ||
|
||
let a = ((d_lat / two).sin()) * ((d_lat / two).sin()) | ||
+ ((d_lon / two).sin()) * ((d_lon / two).sin()) * (lat1.cos()) * (lat2.cos()); | ||
let c = two * ((a.sqrt()).atan2((one - a).sqrt())); | ||
r_in_km * c | ||
} | ||
|
||
pub(super) fn naive_haversine<T>( | ||
start_lat: &ChunkedArray<T>, | ||
start_long: &ChunkedArray<T>, | ||
end_lat: &ChunkedArray<T>, | ||
end_long: &ChunkedArray<T>, | ||
) -> PolarsResult<ChunkedArray<T>> | ||
where | ||
T: PolarsFloatType, | ||
T::Native: Float, | ||
{ | ||
let out: ChunkedArray<T> = start_lat | ||
.into_iter() | ||
.zip(start_long.into_iter()) | ||
.zip(end_lat.into_iter()) | ||
.zip(end_long.into_iter()) | ||
.map(|(((start_lat, start_long), end_lat), end_long)| { | ||
let start_lat = start_lat?; | ||
let start_long = start_long?; | ||
let end_lat = end_lat?; | ||
let end_long = end_long?; | ||
Some(haversine_elementwise( | ||
start_lat, start_long, end_lat, end_long, | ||
)) | ||
}) | ||
.collect(); | ||
|
||
Ok(out.with_name(start_lat.name())) | ||
} |
61 changes: 61 additions & 0 deletions
61
example/derive_expression/expression_lib/src/expressions.rs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
use polars::prelude::*; | ||
use polars_plan::dsl::FieldsMapper; | ||
use pyo3_polars::derive::polars_expr; | ||
use std::fmt::Write; | ||
|
||
fn pig_latin_str(value: &str, output: &mut String) { | ||
if let Some(first_char) = value.chars().next() { | ||
write!(output, "{}{}ay", &value[1..], first_char).unwrap() | ||
} | ||
} | ||
|
||
#[polars_expr(output_type=Utf8)] | ||
fn pig_latinnify(inputs: &[Series]) -> PolarsResult<Series> { | ||
let ca = inputs[0].utf8()?; | ||
let out: Utf8Chunked = ca.apply_to_buffer(pig_latin_str); | ||
Ok(out.into_series()) | ||
} | ||
|
||
#[polars_expr(output_type=Float64)] | ||
fn jaccard_similarity(inputs: &[Series]) -> PolarsResult<Series> { | ||
let a = inputs[0].list()?; | ||
let b = inputs[1].list()?; | ||
crate::distances::naive_jaccard_sim(a, b).map(|ca| ca.into_series()) | ||
} | ||
|
||
#[polars_expr(output_type=Float64)] | ||
fn hamming_distance(inputs: &[Series]) -> PolarsResult<Series> { | ||
let a = inputs[0].utf8()?; | ||
let b = inputs[1].utf8()?; | ||
let out: UInt32Chunked = | ||
arity::binary_elementwise_values(a, b, crate::distances::naive_hamming_dist); | ||
Ok(out.into_series()) | ||
} | ||
|
||
fn haversine_output(input_fields: &[Field]) -> PolarsResult<Field> { | ||
FieldsMapper::new(input_fields).map_to_float_dtype() | ||
} | ||
|
||
#[polars_expr(type_func=haversine_output)] | ||
fn haversine(inputs: &[Series]) -> PolarsResult<Series> { | ||
let out = match inputs[0].dtype() { | ||
DataType::Float32 => { | ||
let start_lat = inputs[0].f32().unwrap(); | ||
let start_long = inputs[1].f32().unwrap(); | ||
let end_lat = inputs[2].f32().unwrap(); | ||
let end_long = inputs[3].f32().unwrap(); | ||
crate::distances::naive_haversine(start_lat, start_long, end_lat, end_long)? | ||
.into_series() | ||
} | ||
DataType::Float64 => { | ||
let start_lat = inputs[0].f64().unwrap(); | ||
let start_long = inputs[1].f64().unwrap(); | ||
let end_lat = inputs[2].f64().unwrap(); | ||
let end_long = inputs[3].f64().unwrap(); | ||
crate::distances::naive_haversine(start_lat, start_long, end_lat, end_long)? | ||
.into_series() | ||
} | ||
_ => unimplemented!(), | ||
}; | ||
Ok(out) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
mod distances; | ||
mod expressions; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
maturin |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
import polars as pl | ||
from expression_lib import Language, Distance | ||
|
||
df = pl.DataFrame({ | ||
"names": ["Richard", "Alice", "Bob"], | ||
"moons": ["full", "half", "red"], | ||
"dist_a": [[12, 32, 1], [], [1, -2]], | ||
"dist_b": [[-12, 1], [43], [876, -45, 9]] | ||
}) | ||
|
||
|
||
out = df.with_columns( | ||
pig_latin = pl.col("names").language.pig_latinnify() | ||
).with_columns( | ||
hamming_dist = pl.col("names").dist.hamming_distance("pig_latin"), | ||
jaccard_sim = pl.col("dist_a").dist.jaccard_similarity("dist_b") | ||
) | ||
|
||
print(out) |
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.