Skip to content
This repository has been archived by the owner on Jul 23, 2024. It is now read-only.

Commit

Permalink
Add chimp_chomp docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
garryod committed Jul 18, 2023
1 parent 774a780 commit ed4ef50
Show file tree
Hide file tree
Showing 6 changed files with 88 additions and 5 deletions.
24 changes: 19 additions & 5 deletions chimp_chomp/src/image_loading.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,15 @@ use opencv::{
};
use std::path::Path;

/// A grayscale image of the well in [W, H, C] format.
///
/// Dereferences to the wrapped OpenCV [`Mat`].
#[derive(Debug, Deref)]
pub struct WellImage(pub Mat);

/// An RGB image of the well in [C, W, H] format.
///
/// Dereferences to the wrapped [`Array`].
#[derive(Debug, Deref)]
pub struct ChimpImage(Array<f32, Ix3>);

/// Converts an image from a [`Mat`] in BGR and ordered in [W, H, C] to an [`Array`] in RGB and ordered in [C, W, H] and resizes it to the input dimensions of the model.
fn prepare_chimp(image: &Mat, width: i32, height: i32) -> ChimpImage {
let mut resized_image = Mat::default();
resize(
Expand Down Expand Up @@ -58,17 +61,17 @@ fn prepare_chimp(image: &Mat, width: i32, height: i32) -> ChimpImage {
ChimpImage(chimp_image)
}

/// Produces a [`WellImage`] by converting a BGR image to grayscale.
fn prepare_well(image: &Mat) -> WellImage {
    let mut grayscale = Mat::default();
    cvt_color(&image, &mut grayscale, COLOR_BGR2GRAY, 0).unwrap();
    WellImage(grayscale)
}

pub fn load_image(
path: impl AsRef<Path>,
chimp_width: u32,
chimp_height: u32,
) -> Result<(ChimpImage, WellImage), anyhow::Error> {
/// Reads an image from file.
///
/// Returns an [`anyhow::Error`] if the image could not be read or is empty.
fn read_image(path: impl AsRef<Path>) -> Result<Mat, anyhow::Error> {
let image = imread(
path.as_ref()
.to_str()
Expand All @@ -78,7 +81,18 @@ pub fn load_image(
if image.empty() {
return Err(anyhow::Error::msg("No image data was loaded"));
}
Ok(image)
}

/// Reads an image from file and prepares both a [`ChimpImage`] and a [`WellImage`].
///
/// Returns an [`anyhow::Error`] if the image could not be read or is empty.
pub fn load_image(
path: impl AsRef<Path>,
chimp_width: u32,
chimp_height: u32,
) -> Result<(ChimpImage, WellImage), anyhow::Error> {
let image = read_image(path)?;
let well_image = prepare_well(&image);
let chimp_image = prepare_chimp(&image, chimp_width as i32, chimp_height as i32);

Expand Down
14 changes: 14 additions & 0 deletions chimp_chomp/src/inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,18 @@ use ort::{
use std::{env::current_exe, ops::Deref, sync::Arc};
use tokio::sync::mpsc::{error::TryRecvError, Receiver, UnboundedSender};

/// The raw box predictor output of a MaskRCNN; one row of box coordinates per predicted instance.
pub type BBoxes = Array2<f32>;
/// The raw label output of a MaskRCNN; one class label per predicted instance.
pub type Labels = Array1<i64>;
/// The raw scores output of a MaskRCNN; one confidence score per predicted instance.
pub type Scores = Array1<f32>;
/// The raw masks output of a MaskRCNN; one mask plane per predicted instance.
pub type Masks = Array3<f32>;

/// Starts an inference session by setting up the ONNX Runtime environment and loading the model.
///
/// Returns an [`anyhow::Error`] if the environment could not be built or if the model could not be loaded.
pub fn setup_inference_session() -> Result<Session, anyhow::Error> {
let environment = Arc::new(
Environment::builder()
Expand All @@ -32,6 +39,9 @@ pub fn setup_inference_session() -> Result<Session, anyhow::Error> {
)?)
}

/// Performs inference on a batch of images; dummy images are used to pad the tensor if the batch is underfull.
///
/// Returns a set of predictions, where each instance corresponds to an input image; order is maintained.
fn do_inference(
session: &Session,
images: &[ChimpImage],
Expand Down Expand Up @@ -85,6 +95,10 @@ fn do_inference(
.collect()
}

/// Listens to a [`Receiver`] for instances of [`ChimpImage`] and performs batch inference on these.
///
/// Each pass, all available images in the [`tokio::sync::mpsc::channel`] - up to the batch size - are taken and passed to the model for inference.
/// Model predictions are sent over a [`tokio::sync::mpsc::unbounded_channel`].
pub async fn inference_worker(
session: Session,
batch_size: usize,
Expand Down
15 changes: 15 additions & 0 deletions chimp_chomp/src/jobs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,17 @@ use tokio::sync::mpsc::{OwnedPermit, UnboundedSender};
use url::Url;
use uuid::Uuid;

/// Creates a RabbitMQ [`Connection`] with [`Default`] [`lapin::ConnectionProperties`].
///
/// Returns a [`lapin::Error`] if a connection could not be established.
pub async fn setup_rabbitmq_client(address: Url) -> Result<Connection, lapin::Error> {
    let properties = lapin::ConnectionProperties::default();
    lapin::Connection::connect(address.as_str(), properties).await
}

/// Joins a RabbitMQ channel, creating a [`Consumer`] with [`Default`] [`BasicConsumeOptions`] and [`FieldTable`].
/// The consumer tag is generated following the format `chimp_chomp_${`[`Uuid::new_v4`]`}`.
///
/// Returns a [`lapin::Error`] if the requested channel is not available.
pub async fn setup_job_consumer(
rabbitmq_channel: Channel,
channel: impl AsRef<str>,
Expand All @@ -33,6 +40,12 @@ pub async fn setup_job_consumer(
.await
}

/// Reads a message from the [`lapin::Consumer`] then loads and prepares the requested image for downstream processing.
///
/// An [`OwnedPermit`] to send to the chimp [`tokio::sync::mpsc::channel`] is required such that backpressure is propagated to message consumption.
///
/// The prepared images are sent over a [`tokio::sync::mpsc::channel`] and [`tokio::sync::mpsc::unbounded_channel`] if successful.
/// An [`anyhow::Error`] is sent if the image could not be read or is empty.
pub async fn consume_job(
mut consumer: Consumer,
input_width: u32,
Expand All @@ -59,6 +72,7 @@ pub async fn consume_job(
};
}

/// Takes the results of postprocessing and well centering and publishes a [`Response::Success`] to the RabbitMQ [`Channel`] provided by the [`Job`].
pub async fn produce_response(
contents: Contents,
well_location: Circle,
Expand Down Expand Up @@ -88,6 +102,7 @@ pub async fn produce_response(
.unwrap();
}

/// Takes an error generated in one of the prior stages and publishes a [`Response::Failure`] to the RabbitMQ [`Channel`] provided by the [`Job`].
pub async fn produce_error(error: anyhow::Error, job: Job, rabbitmq_channel: Channel) {
println!("Producing error for: {job:?}");
rabbitmq_channel
Expand Down
8 changes: 8 additions & 0 deletions chimp_chomp/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
#![forbid(unsafe_code)]
#![warn(missing_docs)]
#![warn(clippy::missing_docs_in_private_items)]
#![doc=include_str!("../README.md")]

/// Utilities for loading images.
mod image_loading;
/// Neural Network inference with [`ort`].
mod inference;
/// RabbitMQ [`Job`] queue consumption and [`Response`] publishing.
mod jobs;
/// Neural Network inference postprocessing with optimal insertion point finding.
mod postprocessing;
/// Well localisation.
mod well_centering;

use crate::{
Expand All @@ -25,6 +31,7 @@ use std::{collections::HashMap, time::Duration};
use tokio::{select, spawn, task::JoinSet};
use url::Url;

/// An inference worker for the Crystal Hits in My Plate (CHiMP) neural network.
#[derive(Debug, Parser)]
#[command(author, version, about, long_about=None)]
struct Cli {
Expand Down Expand Up @@ -56,6 +63,7 @@ fn main() {
runtime.block_on(run(args));
}

#[allow(clippy::missing_docs_in_private_items)]
async fn run(args: Cli) {
let session = setup_inference_session().unwrap();
let input_width = session.inputs[0].dimensions[3].unwrap();
Expand Down
22 changes: 22 additions & 0 deletions chimp_chomp/src/postprocessing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,21 @@ use opencv::{
};
use tokio::sync::mpsc::UnboundedSender;

/// The predicted contents of a well image, derived from the raw MaskRCNN outputs.
#[derive(Debug)]
pub struct Contents {
    /// The optimal point at which solvent should be inserted.
    pub insertion_point: Point,
    /// A bounding box enclosing the drop of solution.
    pub drop: BBox,
    /// A set of bounding boxes enclosing each crystal in the drop.
    pub crystals: Vec<BBox>,
}

/// The threshold to apply to the raw MaskRCNN [`Masks`] to generate a binary mask of each instance.
const PREDICTION_THRESHOLD: f32 = 0.5;

/// Creates a mask of valid insertion positions by adding all pixels in the drop mask and subsequently subtracting those in the crystal masks.
fn insertion_mask(
drop_mask: ArrayView2<f32>,
crystal_masks: Vec<ArrayView2<'_, f32>>,
Expand All @@ -32,6 +38,7 @@ fn insertion_mask(
mask
}

/// Converts an [`Array2<bool>`] into an [`Mat`] of type [`CV_8U`] with the same dimensions.
fn ndarray_mask_into_opencv_mat(mask: Array2<bool>) -> Mat {
Mat::from_exact_iter(
mask.mapv(|pixel| if pixel { std::u8::MAX } else { 0 })
Expand All @@ -49,6 +56,9 @@ fn ndarray_mask_into_opencv_mat(mask: Array2<bool>) -> Mat {
.unwrap()
}

/// Performs a distance transform to find the point in the mask which is furthest from any invalid region.
///
/// Returns an [`anyhow::Error`] if no valid insertion point was found.
fn optimal_insert_position(insertion_mask: Mat) -> Result<Point, anyhow::Error> {
let mut distances = Mat::default();
distance_transform(&insertion_mask, &mut distances, DIST_L1, DIST_MASK_3, CV_8U).unwrap();
Expand All @@ -63,6 +73,7 @@ fn optimal_insert_position(insertion_mask: Mat) -> Result<Point, anyhow::Error>
})
}

/// Converts an [`ArrayView<f32, Ix1`] of length 4 into a [`BBox`] according to the layout of a MaskRCNN box prediction.
fn bbox_from_array(bbox: ArrayView<f32, Ix1>) -> BBox {
BBox {
left: bbox[0],
Expand All @@ -72,6 +83,9 @@ fn bbox_from_array(bbox: ArrayView<f32, Ix1>) -> BBox {
}
}

/// Finds the first instance which is labelled as a drop.
///
/// Returns an [`anyhow::Error`] if no drop instances were found.
fn find_drop_instance<'a>(
labels: &Labels,
bboxes: &BBoxes,
Expand All @@ -82,6 +96,7 @@ fn find_drop_instance<'a>(
.context("No drop instances in prediction")
}

/// Finds all instances which are labelled as crystals.
fn find_crystal_instances<'a>(
labels: &Labels,
bboxes: &BBoxes,
Expand All @@ -92,6 +107,9 @@ fn find_crystal_instances<'a>(
.collect()
}

/// Takes the results of inference on an image and uses it to produce useful regional data and an optimal insertion point.
///
/// Returns an [`anyhow::Error`] if no drop instances could be found or if no valid insertion point was found.
fn postprocess_inference(
bboxes: BBoxes,
labels: Labels,
Expand All @@ -110,6 +128,10 @@ fn postprocess_inference(
})
}

/// Takes the results of inference on an image and uses it to produce useful regional data and an optimal insertion point.
///
/// The extracted [`Contents`] are sent over a [`tokio::sync::mpsc::unbounded_channel`] if successful.
/// An [`anyhow::Error`] is sent if no drop instances were found or if no valid insertion point was found.
pub async fn inference_postprocessing(
bboxes: BBoxes,
labels: Labels,
Expand Down
10 changes: 10 additions & 0 deletions chimp_chomp/src/well_centering.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@ use opencv::{
use std::ops::Deref;
use tokio::sync::mpsc::UnboundedSender;

/// Uses a canny edge detector and a hough circle transform to localise a [`Circle`] of high contrast in the image.
///
/// The circle is assumed to have a radius in [⅜ l, ½ l), where `l` denotes the shortest edge length of the image.
/// The circle with the most counts is selected.
///
/// Returns an [`anyhow::Error`] if no circles were found.
fn find_well_location(image: WellImage) -> Result<Circle, anyhow::Error> {
let min_side = *image.deref().mat_size().iter().min().unwrap();
let mut circles = Vector::<Vec4f>::new();
Expand Down Expand Up @@ -37,6 +43,10 @@ fn find_well_location(image: WellImage) -> Result<Circle, anyhow::Error> {
})
}

/// Takes a grayscale image of the well and finds the center and radius.
///
/// The extracted [`Circle`] is sent over a [`tokio::sync::mpsc::unbounded_channel`] if successful.
/// An [`anyhow::Error`] is sent if no circles were found.
pub async fn well_centering(
image: WellImage,
job: Job,
Expand Down

0 comments on commit ed4ef50

Please sign in to comment.