Skip to content

Commit

Permalink
Bump pyo3 version
Browse files Browse the repository at this point in the history
switch to pseudonormal containment check is now n_rays=0 rather than
None, due to PyO3/pyo3#3735
  • Loading branch information
clbarnes committed Jan 10, 2024
1 parent 6138656 commit f76b0ae
Show file tree
Hide file tree
Showing 8 changed files with 49 additions and 63 deletions.
50 changes: 30 additions & 20 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@ version = "0.19.0"
edition = "2018"

[dependencies]
pyo3 = { version = "0.18", features = ["extension-module", "abi3-py38"] }
pyo3 = { version = "0.20", features = ["extension-module", "abi3-py38"] }
parry3d-f64 = { version = "0.13", features = ["dim3", "f64", "enhanced-determinism"] }
rayon = "1.7"
rand = "0.8"
rand_pcg = "0.3"
numpy = "0.18"
numpy = "0.20"
log = "0.4.20"

# ndarray is a dependency of numpy.
Expand Down
2 changes: 1 addition & 1 deletion python/ncollpyde/_ncollpyde.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class TriMeshWrapper:
self, points: Points, indices: Indices, n_rays: int, ray_seed: int
): ...
def contains(
self, points: Points, n_rays: Optional[int], consensus: int, parallel: bool
self, points: Points, n_rays: int, consensus: int, parallel: bool
) -> npt.NDArray[np.bool_]: ...
def distance(
self, points: Points, signed: bool, parallel: bool
Expand Down
5 changes: 3 additions & 2 deletions python/ncollpyde/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def contains(
:param coords:
:param n_rays: Optional[int]
If None, use the maximum rays defined on construction.
If < 0, use signed distance strategy
If < 1, use signed distance strategy
(more robust, but slower for many meshes).
Otherwise, use this many meshes, up to the maximum defined on construction.
:param consensus: Optional[int]
Expand All @@ -235,7 +235,7 @@ def contains(
coords = self._as_points(coords)
if n_rays is None:
n_rays = self.n_rays
elif n_rays < 0:
elif n_rays < 1:
n_rays = None
elif n_rays > self.n_rays:
logger.warning(
Expand All @@ -245,6 +245,7 @@ def contains(

if n_rays is None:
consensus = 1
n_rays = 0
else:
if consensus is None:
consensus = n_rays // 2 + 1
Expand Down
41 changes: 11 additions & 30 deletions src/interface.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use std::fmt::Debug;
use std::iter::repeat_with;

use ndarray::{Array2, ArrayView1};
Expand All @@ -17,13 +16,13 @@ use crate::utils::{
points_cross_mesh, random_dir, sdf_inner, Precision, FLAGS,
};

fn vec_to_point<T: 'static + Debug + PartialEq + Copy>(v: Vec<T>) -> Point<T> {
Point::new(v[0], v[1], v[2])
}
// fn vec_to_point<T: 'static + Debug + PartialEq + Copy>(v: Vec<T>) -> Point<T> {
// Point::new(v[0], v[1], v[2])
// }

fn point_to_vec<T: 'static + Debug + PartialEq + Copy>(p: &Point<T>) -> Vec<T> {
vec![p.x, p.y, p.z]
}
// fn point_to_vec<T: 'static + Debug + PartialEq + Copy>(p: &Point<T>) -> Vec<T> {
// vec![p.x, p.y, p.z]
// }

#[pyclass]
pub struct TriMeshWrapper {
Expand Down Expand Up @@ -96,20 +95,21 @@ impl TriMeshWrapper {
collected.into_pyarray(py)
}

// #[pyo3(signature = (points, n_rays=None, consensus=None, parallel=None))]
pub fn contains<'py>(
&self,
py: Python<'py>,
points: PyReadonlyArray2<Precision>,
n_rays: Option<usize>,
n_rays: usize,
consensus: usize,
parallel: bool,
) -> &'py PyArray1<bool> {
let p_arr = points.as_array();
let zipped = Zip::from(p_arr.rows());

let collected = if let Some(n) = n_rays {
let collected = if n_rays >= 1 {
// use ray casting
let rays = &self.ray_directions[..n];
let rays = &self.ray_directions[..n_rays];
let clos = |r: ArrayView1<f64>| {
mesh_contains_point(&self.mesh, &Point::new(r[0], r[1], r[2]), rays, consensus)
};
Expand Down Expand Up @@ -242,7 +242,7 @@ impl TriMeshWrapper {
.as_array()
.rows()
.into_iter()
.zip(tgt_points.as_array().rows().into_iter())
.zip(tgt_points.as_array().rows())
.zip(0_u64..)
.for_each(|((src, tgt), i)| {
if let Some((pt, is_bf)) = points_cross_mesh(
Expand All @@ -265,25 +265,6 @@ impl TriMeshWrapper {
)
}

#[deprecated]
pub fn intersections_many_threaded(
&self,
src_points: Vec<Vec<Precision>>,
tgt_points: Vec<Vec<Precision>>,
) -> (Vec<u64>, Vec<Vec<Precision>>, Vec<bool>) {
let (idxs, (intersections, is_backface)) = src_points
.into_par_iter()
.zip(tgt_points.into_par_iter())
.enumerate()
.filter_map(|(i, (src, tgt))| {
points_cross_mesh(&self.mesh, &vec_to_point(src), &vec_to_point(tgt))
.map(|o| (i as u64, (point_to_vec(&o.0), o.1)))
})
.unzip();

(idxs, intersections, is_backface)
}

pub fn intersections_many_threaded2<'py>(
&self,
py: Python<'py>,
Expand Down
2 changes: 1 addition & 1 deletion src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ pub fn mesh_contains_point(
return false;
}
}
return false;
false
}

pub fn mesh_contains_point_oriented(mesh: &TriMesh, point: &Point<f64>) -> bool {
Expand Down
2 changes: 1 addition & 1 deletion tests/test_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def test_trimesh_contains(trimesh_volume, sample_points, expected, benchmark):


@pytest.mark.benchmark(group=CONTAINS_SERIAL)
@pytest.mark.parametrize("n_rays", [0, 1, 2, 4, 8, 16])
@pytest.mark.parametrize("n_rays", [1, 2, 4, 8, 16])
def test_ncollpyde_contains(mesh, n_rays, sample_points, expected, benchmark):
ncollpyde_volume = Volume.from_meshio(mesh, n_rays=n_rays)
ncollpyde_volume.threads = False
Expand Down
6 changes: 0 additions & 6 deletions tests/test_ncollpyde.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,6 @@ def test_contains_results(volume: Volume):
assert np.allclose(ray, psnorms)


def test_0_rays(mesh):
vol = Volume.from_meshio(mesh, n_rays=0)
points = [p for p, _ in points_expected]
assert np.array_equal(vol.contains(points), [False] * len(points))


def test_no_validation(mesh):
triangles = mesh.cells_dict["triangle"]
Volume(mesh.points, triangles, True)
Expand Down

0 comments on commit f76b0ae

Please sign in to comment.