diff --git a/chimp_chomp/src/postprocessing.rs b/chimp_chomp/src/postprocessing.rs index 2dec92ea..3e39da06 100644 --- a/chimp_chomp/src/postprocessing.rs +++ b/chimp_chomp/src/postprocessing.rs @@ -32,24 +32,26 @@ fn insertion_mask( mask } -fn optimal_insert_position(insertion_mask: Array2) -> Result { - let mask = Mat::from_exact_iter( - insertion_mask - .mapv(|pixel| if pixel { std::u8::MAX } else { 0 }) +fn ndarray_mask_into_opencv_mat(mask: Array2) -> Mat { + Mat::from_exact_iter( + mask.mapv(|pixel| if pixel { std::u8::MAX } else { 0 }) .into_iter(), ) .unwrap() .reshape_nd( 1, - &insertion_mask + &mask .shape() .iter() .map(|&dim| dim as i32) .collect::>(), ) - .unwrap(); + .unwrap() +} + +fn optimal_insert_position(insertion_mask: Mat) -> Result { let mut distances = Mat::default(); - distance_transform(&mask, &mut distances, DIST_L1, DIST_MASK_3, CV_8U).unwrap(); + distance_transform(&insertion_mask, &mut distances, DIST_L1, DIST_MASK_3, CV_8U).unwrap(); let (furthest_point, _) = distances .iter::() .unwrap() @@ -99,7 +101,7 @@ fn postprocess_inference( let (crystals, crystal_masks) = find_crystal_instances(&labels, &bboxes, &masks) .into_iter() .unzip(); - let insertion_mask = insertion_mask(drop_mask, crystal_masks); + let insertion_mask = ndarray_mask_into_opencv_mat(insertion_mask(drop_mask, crystal_masks)); let insertion_point = optimal_insert_position(insertion_mask)?; Ok(Contents { drop, @@ -122,3 +124,43 @@ pub async fn inference_postprocessing( Err(err) => error_tx.send((err, job)).unwrap(), } } + +#[cfg(test)] +mod tests { + use super::optimal_insert_position; + use opencv::{ + core::{Point_, Scalar, CV_8UC1}, + imgproc::{circle, LINE_8}, + prelude::Mat, + }; + + #[test] + fn optimal_insert_found() { + let mut test_image = Mat::new_nd_with_default( + &[1024, 1224], + CV_8UC1, + Scalar::new(0_f64, 0_f64, 0_f64, std::u8::MAX as f64), + ) + .unwrap(); + circle( + &mut test_image, + Point_::new(256, 512), + 128, + Scalar::new( + std::u8::MAX as f64, + std::u8::MAX as f64, + std::u8::MAX as f64, + std::u8::MAX as f64, + ), + -1, + LINE_8, + 0, + ) + .unwrap(); + + let position = optimal_insert_position(test_image).unwrap(); + + assert_eq!(256, position.x); + assert_eq!(512, position.y); + } +}