Skip to content

Commit

Permalink
add cosine_similarity_2d function
Browse files Browse the repository at this point in the history
  • Loading branch information
georgypv committed Jan 29, 2024
1 parent 057184e commit 041dc67
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 2 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "polars_coord_transforms"
version = "0.5.0"
version = "0.6.0"
edition = "2021"

[lib]
Expand Down
8 changes: 8 additions & 0 deletions polars_coord_transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,14 @@ def euclidean_2d(self, other: pl.Expr) -> pl.Expr:
is_elementwise=True,
args=[other],
)

def cosine_similarity_2d(self, other: pl.Expr) -> pl.Expr:
return self._expr.register_plugin(
lib=lib,
symbol="cosine_similarity_2d",
is_elementwise=True,
args=[other,]
)


class CoordTransformExpr(pl.Expr):
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ classifiers = [
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
]
version = "0.5.0"
version = "0.6.0"
authors = [
{name="Georgy Popov"}
]
Expand Down
14 changes: 14 additions & 0 deletions src/distance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,17 @@ pub fn euclidean_3d_elementwise(x1: f64, y1: f64, z1: f64, x2: f64, y2: f64, z2:
pub fn euclidean_2d_elementwise(x1: f64, y1: f64, x2: f64, y2: f64) -> f64 {
(((x2 - x1).powi(2)) + ((y2 - y1).powi(2))).sqrt()
}


pub fn cosine_similarity_2d_elementwise(x1: f64, y1: f64, x2: f64, y2: f64) -> f64 {
let dot_product = (x1*x2) + (y1*y2);
let magnitude1 = (x1.powi(2) + y1.powi(2)).powf(0.5);
let magnitude2 = (x2.powi(2) + y2.powi(2)).powf(0.5);

let res = if magnitude1 == 0.0 || magnitude2 == 0.0 {
0.0
} else {
dot_product / (magnitude1*magnitude2)
};
res
}
28 changes: 28 additions & 0 deletions src/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,34 @@ fn euclidean_2d(inputs: &[Series]) -> PolarsResult<Series> {
}


#[polars_expr(output_type=Float64)]
fn cosine_similarity_2d(inputs: &[Series]) -> PolarsResult<Series> {
let ca1: &StructChunked = inputs[0].struct_()?;
let ca2: &StructChunked = inputs[1].struct_()?;

let (x1, y1, _z1) = unpack_xyz(ca1, false);
let (x2, y2, _z2) = unpack_xyz(ca2, false);

let iter = izip!(
x1.f64()?,
y1.f64()?,
x2.f64()?,
y2.f64()?
).into_iter().map(
|(x1_op, y1_op, x2_op, y2_op)| {
match (x1_op, y1_op, x2_op, y2_op) {
(Some(x1), Some(y1), Some(x2), Some(y2)) => cosine_similarity_2d_elementwise(x1, y1, x2, y2),
_ => panic!("Unable to find cosine similarity!")
}
});

let out_ca: ChunkedArray<Float64Type> = iter.collect_ca_with_dtype("cosine_similarity", DataType::Float64);
Ok(out_ca.into_series())

}



#[polars_expr(output_type=Float64)]
fn euclidean_3d(inputs: &[Series]) -> PolarsResult<Series> {
let ca1: &StructChunked = inputs[0].struct_()?;
Expand Down

0 comments on commit 041dc67

Please sign in to comment.