diff --git a/chimp_chomp/src/image_loading.rs b/chimp_chomp/src/image_loading.rs index 48846aa0..86e42c48 100644 --- a/chimp_chomp/src/image_loading.rs +++ b/chimp_chomp/src/image_loading.rs @@ -9,12 +9,15 @@ use opencv::{ }; use std::path::Path; +/// A grayscale image of the well in [W, H, C] format. #[derive(Debug, Deref)] pub struct WellImage(pub Mat); +/// An RGB image of the well in [C, W, H] format. #[derive(Debug, Deref)] pub struct ChimpImage(Array); +/// Converts an image from a [`Mat`] in BGR and ordered in [W, H, C] to an [`Array`] in RGB and ordered in [C, W, H] and resizes it to the input dimensions of the model. fn prepare_chimp(image: &Mat, width: i32, height: i32) -> ChimpImage { let mut resized_image = Mat::default(); resize( @@ -58,17 +61,17 @@ fn prepare_chimp(image: &Mat, width: i32, height: i32) -> ChimpImage { ChimpImage(chimp_image) } +/// Converts an image from BGR to grayscale. fn prepare_well(image: &Mat) -> WellImage { let mut well_image = Mat::default(); cvt_color(&image, &mut well_image, COLOR_BGR2GRAY, 0).unwrap(); WellImage(well_image) } -pub fn load_image( - path: impl AsRef, - chimp_width: u32, - chimp_height: u32, -) -> Result<(ChimpImage, WellImage), anyhow::Error> { +/// Reads an image from file. +/// +/// Returns an [`anyhow::Error`] if the image could not be read or is empty. +fn read_image(path: impl AsRef) -> Result { let image = imread( path.as_ref() .to_str() @@ -78,7 +81,18 @@ pub fn load_image( if image.empty() { return Err(anyhow::Error::msg("No image data was loaded")); } + Ok(image) +} +/// Reads an image from file and prepares both a [`ChimpImage`] and a [`WellImage`]. +/// +/// Returns an [`anyhow::Error`] if the image could not be read or is empty. 
+pub fn load_image( + path: impl AsRef, + chimp_width: u32, + chimp_height: u32, +) -> Result<(ChimpImage, WellImage), anyhow::Error> { + let image = read_image(path)?; let well_image = prepare_well(&image); let chimp_image = prepare_chimp(&image, chimp_width as i32, chimp_height as i32); diff --git a/chimp_chomp/src/inference.rs b/chimp_chomp/src/inference.rs index e2ac5af9..afd07afc 100644 --- a/chimp_chomp/src/inference.rs +++ b/chimp_chomp/src/inference.rs @@ -10,11 +10,18 @@ use ort::{ use std::{env::current_exe, ops::Deref, sync::Arc}; use tokio::sync::mpsc::{error::TryRecvError, Receiver, UnboundedSender}; +/// The raw box predictor output of a MaskRCNN. pub type BBoxes = Array2; +/// The raw label output of a MaskRCNN. pub type Labels = Array1; +/// The raw scores output of a MaskRCNN. pub type Scores = Array1; +/// The raw masks output of a MaskRCNN. pub type Masks = Array3; +/// Starts an inference session by setting up the ONNX Runtime environment and loading the model. +/// +/// Returns an [`anyhow::Error`] if the environment could not be built or if the model could not be loaded. pub fn setup_inference_session() -> Result { let environment = Arc::new( Environment::builder() @@ -32,6 +39,9 @@ pub fn setup_inference_session() -> Result { )?) } +/// Performs inference on a batch of images; dummy images are used to pad the tensor if underfull. +/// +/// Returns a set of predictions, where each instance corresponds to an input image; order is maintained. fn do_inference( session: &Session, images: &[ChimpImage], @@ -85,6 +95,10 @@ fn do_inference( .collect() } +/// Listens to a [`Receiver`] for instances of [`ChimpImage`] and performs batch inference on these. +/// +/// Each pass, all available images in the [`tokio::sync::mpsc::channel`] - up to the batch size - are taken and passed to the model for inference. +/// Model predictions are sent over a [`tokio::sync::mpsc::unbounded_channel`]. 
pub async fn inference_worker( session: Session, batch_size: usize, diff --git a/chimp_chomp/src/jobs.rs b/chimp_chomp/src/jobs.rs index 8162336b..a82a6380 100644 --- a/chimp_chomp/src/jobs.rs +++ b/chimp_chomp/src/jobs.rs @@ -13,10 +13,17 @@ use tokio::sync::mpsc::{OwnedPermit, UnboundedSender}; use url::Url; use uuid::Uuid; +/// Creates a RabbitMQ [`Connection`] with [`Default`] [`lapin::ConnectionProperties`]. +/// +/// Returns a [`lapin::Error`] if a connection could not be established. pub async fn setup_rabbitmq_client(address: Url) -> Result { lapin::Connection::connect(address.as_str(), lapin::ConnectionProperties::default()).await } +/// Joins a RabbitMQ channel, creating a [`Consumer`] with [`Default`] [`BasicConsumeOptions`] and [`FieldTable`]. +/// The consumer tag is generated following the format `chimp_chomp_${`[`Uuid::new_v4`]`}`. +/// +/// Returns a [`lapin::Error`] if the requested channel is not available. pub async fn setup_job_consumer( rabbitmq_channel: Channel, channel: impl AsRef, @@ -33,6 +40,12 @@ pub async fn setup_job_consumer( .await } +/// Reads a message from the [`lapin::Consumer`] then loads and prepares the requested image for downstream processing. +/// +/// An [`OwnedPermit`] to send to the chimp [`tokio::sync::mpsc::channel`] is required such that backpressure is propagated to message consumption. +/// +/// The prepared images are sent over a [`tokio::sync::mpsc::channel`] and [`tokio::sync::mpsc::unbounded_channel`] if successful. +/// An [`anyhow::Error`] is sent if the image could not be read or is empty. pub async fn consume_job( mut consumer: Consumer, input_width: u32, @@ -59,6 +72,7 @@ pub async fn consume_job( }; } +/// Takes the results of postprocessing and well centering and publishes a [`Response::Success`] to the RabbitMQ [`Channel`] provided by the [`Job`]. 
pub async fn produce_response( contents: Contents, well_location: Circle, @@ -88,6 +102,7 @@ pub async fn produce_response( .unwrap(); } +/// Takes an error generated in one of the prior stages and publishes a [`Response::Failure`] to the RabbitMQ [`Channel`] provided by the [`Job`]. pub async fn produce_error(error: anyhow::Error, job: Job, rabbitmq_channel: Channel) { println!("Producing error for: {job:?}"); rabbitmq_channel diff --git a/chimp_chomp/src/main.rs b/chimp_chomp/src/main.rs index b37d7cb9..3083d86e 100644 --- a/chimp_chomp/src/main.rs +++ b/chimp_chomp/src/main.rs @@ -1,11 +1,17 @@ #![forbid(unsafe_code)] #![warn(missing_docs)] +#![warn(clippy::missing_docs_in_private_items)] #![doc=include_str!("../README.md")] +/// Utilities for loading images. mod image_loading; +/// Neural Network inference with [`ort`]. mod inference; +/// RabbitMQ [`Job`] queue consumption and [`Response`] publishing. mod jobs; +/// Neural Network inference postprocessing with optimal insertion point finding. mod postprocessing; +/// Well localisation. mod well_centering; use crate::{ @@ -25,6 +31,7 @@ use std::{collections::HashMap, time::Duration}; use tokio::{select, spawn, task::JoinSet}; use url::Url; +/// An inference worker for the Crystal Hits in My Plate (CHiMP) neural network. #[derive(Debug, Parser)] #[command(author, version, about, long_about=None)] struct Cli { @@ -56,6 +63,7 @@ fn main() { runtime.block_on(run(args)); } +#[allow(clippy::missing_docs_in_private_items)] async fn run(args: Cli) { let session = setup_inference_session().unwrap(); let input_width = session.inputs[0].dimensions[3].unwrap(); diff --git a/chimp_chomp/src/postprocessing.rs b/chimp_chomp/src/postprocessing.rs index 3e39da06..95c04702 100644 --- a/chimp_chomp/src/postprocessing.rs +++ b/chimp_chomp/src/postprocessing.rs @@ -10,15 +10,21 @@ use opencv::{ }; use tokio::sync::mpsc::UnboundedSender; +/// The predicted contents of a well image. 
#[derive(Debug)] pub struct Contents { + /// The optimal point at which solvent should be inserted. pub insertion_point: Point, + /// A bounding box enclosing the drop of solution. pub drop: BBox, + /// A set of bounding boxes enclosing each crystal in the drop. pub crystals: Vec, } +/// The threshold to apply to the raw MaskRCNN [`Masks`] to generate a binary mask. const PREDICTION_THRESHOLD: f32 = 0.5; +/// Creates a mask of valid insertion positions by adding all pixels in the drop mask and subsequently subtracting those in the crystal masks. fn insertion_mask( drop_mask: ArrayView2, crystal_masks: Vec>, @@ -32,6 +38,7 @@ fn insertion_mask( mask } +/// Converts an [`Array2`] into a [`Mat`] of type [`CV_8U`] with the same dimensions. fn ndarray_mask_into_opencv_mat(mask: Array2) -> Mat { Mat::from_exact_iter( mask.mapv(|pixel| if pixel { std::u8::MAX } else { 0 }) .unwrap() } +/// Performs a distance transform to find the point in the mask which is furthest from any invalid region. +/// +/// Returns an [`anyhow::Error`] if no valid insertion point was found. fn optimal_insert_position(insertion_mask: Mat) -> Result { let mut distances = Mat::default(); distance_transform(&insertion_mask, &mut distances, DIST_L1, DIST_MASK_3, CV_8U).unwrap(); @@ -63,6 +73,7 @@ fn optimal_insert_position(insertion_mask: Mat) -> Result }) } +/// Converts an [`ArrayView`] into a [`BBox`]. fn bbox_from_array(bbox: ArrayView) -> BBox { BBox { left: bbox[0], @@ -72,6 +83,9 @@ fn bbox_from_array(bbox: ArrayView) -> BBox { } } +/// Finds the first instance which is labelled as a drop. +/// +/// Returns an [`anyhow::Error`] if no drop instances were found. fn find_drop_instance<'a>( labels: &Labels, bboxes: &BBoxes, @@ -82,6 +96,7 @@ fn find_drop_instance<'a>( .context("No drop instances in prediction") } +/// Finds all instances which are labelled as crystals. 
fn find_crystal_instances<'a>( labels: &Labels, bboxes: &BBoxes, @@ -92,6 +107,9 @@ fn find_crystal_instances<'a>( .collect() } +/// Takes the results of inference on an image and uses it to produce useful regional data and an optimal insertion point. +/// +/// Returns an [`anyhow::Error`] if no drop instances could be found or if no valid insertion point was found. fn postprocess_inference( bboxes: BBoxes, labels: Labels, @@ -110,6 +128,10 @@ fn postprocess_inference( }) } +/// Takes the results of inference on an image and uses it to produce useful regional data and an optimal insertion point. +/// +/// The extracted [`Contents`] are sent over a [`tokio::sync::mpsc::unbounded_channel`] if successful. +/// An [`anyhow::Error`] is sent if no drop instances were found or if no valid insertion point was found. pub async fn inference_postprocessing( bboxes: BBoxes, labels: Labels, diff --git a/chimp_chomp/src/well_centering.rs b/chimp_chomp/src/well_centering.rs index b51a74b9..64424e75 100644 --- a/chimp_chomp/src/well_centering.rs +++ b/chimp_chomp/src/well_centering.rs @@ -9,6 +9,12 @@ use opencv::{ }; use std::ops::Deref; use tokio::sync::mpsc::UnboundedSender; +/// Uses a Canny edge detector and a Hough circle transform to localise a [`Circle`] of high contrast in the image. +/// +/// The circle is assumed to have a radius in [⅜ l, ½ l), where `l` denotes the shortest edge length of the image. +/// The circle with the most counts is selected. +/// +/// Returns an [`anyhow::Error`] if no circles were found. fn find_well_location(image: WellImage) -> Result { let min_side = *image.deref().mat_size().iter().min().unwrap(); let mut circles = Vector::::new(); @@ -37,6 +43,10 @@ fn find_well_location(image: WellImage) -> Result { }) } +/// Takes a grayscale image of the well and finds the center and radius. +/// +/// The extracted [`Circle`] is sent over a [`tokio::sync::mpsc::unbounded_channel`] if successful. +/// An [`anyhow::Error`] is sent if no circles were found. 
pub async fn well_centering( image: WellImage, job: Job,