From 6c66b77f428f8b8cdaa12e50107d44a6076e2960 Mon Sep 17 00:00:00 2001 From: Garry O'Donnell Date: Tue, 4 Jul 2023 09:46:18 +0000 Subject: [PATCH 01/45] Create chimp chomp crate --- Cargo.lock | 4 ++++ Cargo.toml | 1 + chimp_chomp/Cargo.toml | 6 ++++++ chimp_chomp/src/main.rs | 3 +++ 4 files changed, 14 insertions(+) create mode 100644 chimp_chomp/Cargo.toml create mode 100644 chimp_chomp/src/main.rs diff --git a/Cargo.lock b/Cargo.lock index 7cafbae6..286ddb12 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -428,6 +428,10 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "chimp_chomp" +version = "0.1.0" + [[package]] name = "chrono" version = "0.4.26" diff --git a/Cargo.toml b/Cargo.toml index 19879be0..678dc2e1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,6 @@ [workspace] members = [ + "chimp_chomp", "graphql_endpoints", "graphql_event_broker", "opa_client", diff --git a/chimp_chomp/Cargo.toml b/chimp_chomp/Cargo.toml new file mode 100644 index 00000000..64d7505b --- /dev/null +++ b/chimp_chomp/Cargo.toml @@ -0,0 +1,6 @@ +[package] +name = "chimp_chomp" +version = "0.1.0" +edition = "2021" + +[dependencies] diff --git a/chimp_chomp/src/main.rs b/chimp_chomp/src/main.rs new file mode 100644 index 00000000..4df2bf71 --- /dev/null +++ b/chimp_chomp/src/main.rs @@ -0,0 +1,3 @@ +fn main() { + println!("Hello world") +} From 1798efced7cb5db61976c421a640aa64e1818165 Mon Sep 17 00:00:00 2001 From: Garry O'Donnell Date: Tue, 4 Jul 2023 09:50:53 +0000 Subject: [PATCH 02/45] Setup ONNXRuntime inference session --- Cargo.lock | 253 ++++++++++++++++++++++++++++++++++++++- chimp_chomp/Cargo.toml | 7 ++ chimp_chomp/src/main.rs | 18 ++- chimp_chomp/src/model.rs | 16 +++ 4 files changed, 290 insertions(+), 4 deletions(-) create mode 100644 chimp_chomp/src/model.rs diff --git a/Cargo.lock b/Cargo.lock index 286ddb12..3b4a5e3b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -350,7 +350,7 @@ dependencies = [ "cc", "cfg-if", "libc", - "miniz_oxide", + "miniz_oxide 0.6.2", "object", "rustc-demangle", ] @@ -431,6 +431,11 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "chimp_chomp" version = "0.1.0" +dependencies = [ + "clap 4.3.9", + "dotenvy", + "ort", +] [[package]] name = "chrono" @@ -580,6 +585,15 @@ dependencies = [ "libc", ] +[[package]] +name = "crc32fast" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b540bd8bc810d3885c6ea91e2018302f68baba2129ab3e88f32389ee9370880d" +dependencies = [ + "cfg-if", +] + [[package]] name = "crossbeam-queue" version = "0.3.8" @@ -757,6 +771,28 @@ dependencies = [ "instant", ] +[[package]] +name = "filetime" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5cbc844cecaee9d4443931972e1289c8ff485cb4cc2767cb03ca139ed6885153" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall 0.2.16", + "windows-sys 0.48.0", +] + +[[package]] +name = "flate2" +version = "1.0.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b9429470923de8e8cbd4d2dc513535400b4b3fef0319fb5c4e1f520a7bef743" +dependencies = [ + "crc32fast", + "miniz_oxide 0.7.1", +] + [[package]] name = "flume" version = "0.10.14" @@ -766,7 +802,7 @@ dependencies = [ "futures-core", "futures-sink", "pin-project", - "spin", + "spin 0.9.8", ] [[package]] @@ -1413,6 +1449,16 @@ version = "0.7.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "b87248edafb776e59e6ee64a79086f65890d3510f2c656c000bf2a7e8a0aea40" +[[package]] +name = "matrixmultiply" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "090126dc04f95dc0d1c1c91f61bdd474b3930ca064c1edc8a849da2c6cbe1e77" +dependencies = [ + "autocfg", + "rawpointer", +] + [[package]] name = "md-5" version = "0.10.5" @@ -1449,6 +1495,15 @@ dependencies = [ "adler", ] +[[package]] +name = "miniz_oxide" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7810e0be55b428ada41041c41f32c9f1a42817901b4ccf45fa3d4b6561e74c7" +dependencies = [ + "adler", +] + [[package]] name = "mio" version = "0.8.8" @@ -1474,7 +1529,7 @@ dependencies = [ "log", "memchr", "mime", - "spin", + "spin 0.9.8", "version_check", ] @@ -1496,6 +1551,19 @@ dependencies = [ "tempfile", ] +[[package]] +name = "ndarray" +version = "0.15.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "rawpointer", +] + [[package]] name = "nom" version = "7.1.3" @@ -1516,6 +1584,25 @@ dependencies = [ "winapi", ] +[[package]] +name = "num-complex" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02e0d21255c828d6f128a1e41534206671e8c3ea0c62f32291e808dc82cff17d" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-integer" +version = "0.1.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "225d3389fb3509a24c93f5c29eb6bde2586b98d9f016636dff58d7c6f7569cd9" +dependencies = [ + "autocfg", + "num-traits", +] + [[package]] name = "num-traits" version = "0.2.15" @@ -1606,6 +1693,25 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "ort" +version = "1.14.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "562aedddfc14391badc2e527fd568164bd2fb108e2572cd3fe0b53ed440ec439" +dependencies = [ + "flate2", + "lazy_static", + "libc", + "ndarray", + "tar", + "thiserror", + "tracing", + "ureq", + "vswhom", + "winapi", + "zip", +] + [[package]] name = "os_str_bytes" version = "6.5.1" @@ -1915,6 +2021,12 @@ dependencies = [ "getrandom", ] +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + [[package]] name = "redox_syscall" version = "0.2.16" @@ -2013,6 +2125,21 @@ dependencies = [ "winreg", ] +[[package]] +name = "ring" +version = "0.16.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3053cf52e236a3ed746dfc745aa9cacf1b791d846bdaf412f60a8d7d6e17c8fc" +dependencies = [ + "cc", + "libc", + "once_cell", + "spin 0.5.2", + "untrusted", + "web-sys", + "winapi", +] + [[package]] name = "rustc-demangle" version = "0.1.23" @@ -2042,6 +2169,28 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "rustls" +version = "0.21.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e32ca28af694bc1bbf399c33a516dbdf1c90090b8ab23c2bc24f834aa2247f5f" +dependencies = [ + "log", + "ring", + "rustls-webpki", + "sct", +] + +[[package]] +name = "rustls-webpki" +version = "0.100.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"d6207cd5ed3d8dca7816f8f3725513a34609c0c765bf652b8c3cb4cfd87db46b" +dependencies = [ + "ring", + "untrusted", +] + [[package]] name = "rustversion" version = "1.0.12" @@ -2069,6 +2218,16 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" +[[package]] +name = "sct" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d53dcdb7c9f8158937a7981b48accfd39a43af418591a5d008c7b22b5e1b7ca4" +dependencies = [ + "ring", + "untrusted", +] + [[package]] name = "sea-orm" version = "0.11.3" @@ -2397,6 +2556,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "spin" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" + [[package]] name = "spin" version = "0.9.8" @@ -2567,6 +2732,17 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" +[[package]] +name = "tar" +version = "0.4.38" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b55807c0344e1e6c04d7c965f5289c39a8d94ae23ed5c0b57aabac549f871c6" +dependencies = [ + "filetime", + "libc", + "xattr", +] + [[package]] name = "tempfile" version = "3.6.0" @@ -2900,6 +3076,27 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e" +[[package]] +name = "untrusted" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a" + +[[package]] +name = "ureq" +version = "2.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b11c96ac7ee530603dcdf68ed1557050f374ce55a5a07193ebf8cbc9f8927e9" +dependencies = [ + "base64 0.21.2", + "log", + "once_cell", + "rustls", + "rustls-webpki", + "url", + "webpki-roots", +] + [[package]] name = "url" version = "2.4.0" @@ -2951,6 +3148,26 @@ version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" +[[package]] +name = "vswhom" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be979b7f07507105799e854203b470ff7c78a1639e330a58f183b5fea574608b" +dependencies = [ + "libc", + "vswhom-sys", +] + +[[package]] +name = "vswhom-sys" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3b17ae1f6c8a2b28506cd96d412eebf83b4a0ff2cbefeeb952f2f9dfa44ba18" +dependencies = [ + "cc", + "libc", +] + [[package]] name = "want" version = "0.3.1" @@ -3042,6 +3259,15 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "webpki-roots" +version = "0.23.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b03058f88386e5ff5310d9111d53f48b17d732b401aeb83a8d5190f2ac459338" +dependencies = [ + "rustls-webpki", +] + [[package]] name = "whoami" version = "1.4.1" @@ -3223,3 +3449,24 @@ checksum = "80d0f4e272c85def139476380b12f9ac60926689dd2e01d4923222f40580869d" dependencies = [ "winapi", ] + +[[package]] +name = "xattr" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d1526bbe5aaeb5eb06885f4d987bcdfa5e23187055de9b83fe00156a821fabc" +dependencies = [ + 
"libc", +] + +[[package]] +name = "zip" +version = "0.6.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "760394e246e4c28189f19d488c058bf16f564016aefac5d32bb1f3b51d5e9261" +dependencies = [ + "byteorder", + "crc32fast", + "crossbeam-utils", + "flate2", +] diff --git a/chimp_chomp/Cargo.toml b/chimp_chomp/Cargo.toml index 64d7505b..eb761507 100644 --- a/chimp_chomp/Cargo.toml +++ b/chimp_chomp/Cargo.toml @@ -4,3 +4,10 @@ version = "0.1.0" edition = "2021" [dependencies] +clap = { workspace = true } +dotenvy = { workspace = true } +ort = { version = "1.14.8", default-features = false, features = [ + "download-binaries", + "tensorrt", + "copy-dylibs", +] } diff --git a/chimp_chomp/src/main.rs b/chimp_chomp/src/main.rs index 4df2bf71..dc11e5d1 100644 --- a/chimp_chomp/src/main.rs +++ b/chimp_chomp/src/main.rs @@ -1,3 +1,19 @@ +mod model; + +use crate::model::setup_inference_session; +use clap::Parser; +use std::path::PathBuf; + +#[derive(Debug, Parser)] +#[command(author, version, about, long_about=None)] +struct Cli { + /// The path to the ONNX model file. + model: PathBuf, +} + fn main() { - println!("Hello world") + dotenvy::dotenv().ok(); + let args = Cli::parse(); + + let _session = setup_inference_session(args.model).unwrap(); } diff --git a/chimp_chomp/src/model.rs b/chimp_chomp/src/model.rs new file mode 100644 index 00000000..ada4feb7 --- /dev/null +++ b/chimp_chomp/src/model.rs @@ -0,0 +1,16 @@ +use ort::{ + Environment, ExecutionProvider, GraphOptimizationLevel, OrtError, Session, SessionBuilder, +}; +use std::{path::Path, sync::Arc}; + +pub fn setup_inference_session(model_path: impl AsRef) -> Result { + let environment = Arc::new( + Environment::builder() + .with_name("CHiMP") + .with_execution_providers([ExecutionProvider::cpu()]) + .build()?, + ); + SessionBuilder::new(&environment)? + .with_optimization_level(GraphOptimizationLevel::Level3)? 
+ .with_model_from_file(model_path) +} From d05a5254dbaad73459c1460711c7bf85b701bff2 Mon Sep 17 00:00:00 2001 From: Garry O'Donnell Date: Tue, 4 Jul 2023 11:05:28 +0000 Subject: [PATCH 03/45] Add image loading utils --- Cargo.lock | 258 +++++++++++++++++++++++++++++++ chimp_chomp/Cargo.toml | 6 + chimp_chomp/src/image_loading.rs | 20 +++ chimp_chomp/src/main.rs | 1 + 4 files changed, 285 insertions(+) create mode 100644 chimp_chomp/src/image_loading.rs diff --git a/Cargo.lock b/Cargo.lock index 3b4a5e3b..b510a842 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -380,6 +380,12 @@ version = "0.21.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "604178f6c5c21f02dc555784810edfb88d34ac2c73b2eae109655649ee73ce3d" +[[package]] +name = "bit_field" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc827186963e592360843fb5ba4b973e145841266c1357f7180c43526f2e5b61" + [[package]] name = "bitflags" version = "1.3.2" @@ -401,6 +407,12 @@ version = "3.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a3e2c3daef883ecc1b5d58c15adae93470a91d425f3532ba1695849656af3fc1" +[[package]] +name = "bytemuck" +version = "1.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "17febce684fd15d89027105661fec94afb475cb995fbc59d2865198446ba2eea" + [[package]] name = "byteorder" version = "1.4.3" @@ -434,6 +446,9 @@ version = "0.1.0" dependencies = [ "clap 4.3.9", "dotenvy", + "image", + "ndarray", + "nshare", "ort", ] @@ -548,6 +563,12 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2da6da31387c7e4ef160ffab6d5e7f00c42626fe39aea70a7b0f1773f7dd6c1b" +[[package]] +name = "color_quant" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d7b894f5411737b7867f4827955924d7c254fc9f4d91a6aad6b097804b1018b" + [[package]] name = "colorchoice" version = "1.0.0" @@ -594,6 +615,40 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "crossbeam-channel" +version = "0.5.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a33c2bf77f2df06183c3aa30d1e96c0695a313d4f9c453cc3762a6db39f99200" +dependencies = [ + "cfg-if", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-deque" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce6fd6f855243022dcecf8702fef0c297d4338e226845fe067f6341ad9fa0cef" +dependencies = [ + "cfg-if", + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae211234986c545741a7dc064309f67ee1e5ad243d0e48335adc0484d960bcc7" +dependencies = [ + "autocfg", + "cfg-if", + "crossbeam-utils", + "memoffset", + "scopeguard", +] + [[package]] name = "crossbeam-queue" version = "0.3.8" @@ -613,6 +668,12 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "crunchy" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" + [[package]] name = "crypto-common" version = "0.1.6" @@ -762,6 +823,22 @@ version = "2.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0206175f82b8d6bf6652ff7d71a1e27fd2e4efde587fd368662814d6ec1d9ce0" +[[package]] +name = "exr" +version = "1.6.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"85a7b44a196573e272e0cf0bcf130281c71e9a0c67062954b3323fd364bfdac9" +dependencies = [ + "bit_field", + "flume", + "half", + "lebe", + "miniz_oxide 0.7.1", + "rayon-core", + "smallvec", + "zune-inflate", +] + [[package]] name = "fastrand" version = "1.9.0" @@ -771,6 +848,15 @@ dependencies = [ "instant", ] +[[package]] +name = "fdeflate" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d329bdeac514ee06249dabc27877490f17f5d371ec693360768b838e19f3ae10" +dependencies = [ + "simd-adler32", +] + [[package]] name = "filetime" version = "0.2.21" @@ -801,6 +887,7 @@ checksum = "1657b4441c3403d9f7b3409e47575237dac27b1b5726df654a6ecbf92f0f7577" dependencies = [ "futures-core", "futures-sink", + "nanorand", "pin-project", "spin 0.9.8", ] @@ -952,8 +1039,20 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "be4136b2a15dd319360be1c07d9933517ccf0be8f16bf62a3bee4f0d618df427" dependencies = [ "cfg-if", + "js-sys", "libc", "wasi", + "wasm-bindgen", +] + +[[package]] +name = "gif" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "80792593675e051cf94a4b111980da2ba60d4a83e43e0048c5693baab3977045" +dependencies = [ + "color_quant", + "weezl", ] [[package]] @@ -1002,6 +1101,15 @@ dependencies = [ "tracing", ] +[[package]] +name = "half" +version = "2.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02b4af3693f1b705df946e9fe5631932443781d0aabb423b62fcd4d73f6d2fd0" +dependencies = [ + "crunchy", +] + [[package]] name = "handlebars" version = "4.3.7" @@ -1233,6 +1341,25 @@ dependencies = [ "unicode-normalization", ] +[[package]] +name = "image" +version = "0.24.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "527909aa81e20ac3a44803521443a765550f09b5130c2c2fa1ea59c2f8f50a3a" +dependencies = [ + "bytemuck", + "byteorder", + "color_quant", + "exr", + "gif", + "jpeg-decoder", + "num-rational", + "num-traits", + "png", + "qoi", + "tiff", +] + [[package]] name = "indexmap" version = "1.9.3" @@ -1307,6 +1434,15 @@ version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "453ad9f582a441959e5f0d088b02ce04cfe8d51a8eaf077f12ac6d3e94164ca6" +[[package]] +name = "jpeg-decoder" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc0000e42512c92e31c2252315bda326620a4e034105e900c98ec492fa077b3e" +dependencies = [ + "rayon", +] + [[package]] name = "js-sys" version = "0.3.64" @@ -1322,6 +1458,12 @@ version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" +[[package]] +name = "lebe" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03087c2bad5e1034e8cace5926dec053fb3790248370865f5117a7d0213354c8" + [[package]] name = "lexical" version = "6.1.1" @@ -1474,6 +1616,15 @@ version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" +[[package]] +name = "memoffset" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a634b1c61a95585bd15607c6ab0c4e5b226e695ff2800ba0cdccddf208c406c" +dependencies = [ + "autocfg", +] + [[package]] name = "mime" version = "0.3.17" @@ -1502,6 +1653,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"e7810e0be55b428ada41041c41f32c9f1a42817901b4ccf45fa3d4b6561e74c7" dependencies = [ "adler", + "simd-adler32", ] [[package]] @@ -1533,6 +1685,15 @@ dependencies = [ "version_check", ] +[[package]] +name = "nanorand" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a51313c5820b0b02bd422f4b44776fbf47961755c74ce64afc73bfad10226c3" +dependencies = [ + "getrandom", +] + [[package]] name = "native-tls" version = "0.2.11" @@ -1574,6 +1735,16 @@ dependencies = [ "minimal-lexical", ] +[[package]] +name = "nshare" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4447657cd40e3107416ec4f2ac3e61a18781b00061789e3b8f4bbcbccb26c4c6" +dependencies = [ + "image", + "ndarray", +] + [[package]] name = "nu-ansi-term" version = "0.46.0" @@ -1603,6 +1774,17 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-rational" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0638a1c9d0a3c0914158145bc76cff373a75a627e6ecbfb71cbe6f453a5a19b0" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + [[package]] name = "num-traits" version = "0.2.15" @@ -1933,6 +2115,19 @@ version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "26072860ba924cbfa98ea39c8c19b4dd6a4a25423dbdf219c1eca91aa0cf6964" +[[package]] +name = "png" +version = "0.17.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59871cc5b6cce7eaccca5a802b4173377a1c2ba90654246789a8fa2334426d11" +dependencies = [ + "bitflags", + "crc32fast", + "fdeflate", + "flate2", + "miniz_oxide 0.7.1", +] + [[package]] name = "ppv-lite86" version = "0.2.17" @@ -1982,6 +2177,15 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "qoi" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f6d64c71eb498fe9eae14ce4ec935c555749aef511cca85b5568910d6e48001" +dependencies = [ + "bytemuck", +] + [[package]] name = "quote" version = "1.0.29" @@ -2027,6 +2231,28 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" +[[package]] +name = "rayon" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d2df5196e37bcc87abebc0053e20787d73847bb33134a69841207dd0a47f03b" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b8f95bd6966f5c87776639160a66bd8ab9895d9d4ab01ddba9fc60661aebe8d" +dependencies = [ + "crossbeam-channel", + "crossbeam-deque", + "crossbeam-utils", + "num_cpus", +] + [[package]] name = "redox_syscall" version = "0.2.16" @@ -2491,6 +2717,12 @@ dependencies = [ "lazy_static", ] +[[package]] +name = "simd-adler32" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "238abfbb77c1915110ad968465608b68e869e0772622c9656714e73e5a1a522f" + [[package]] name = "siphasher" version = "0.3.10" @@ -2793,6 +3025,17 @@ dependencies = [ "once_cell", ] +[[package]] +name = "tiff" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7449334f9ff2baf290d55d73983a7d6fa15e01198faef72af07e2a8db851e471" +dependencies = [ + "flate2", + "jpeg-decoder", + "weezl", +] + [[package]] name = "tinyvec" version = "1.6.0" @@ -3268,6 +3511,12 @@ dependencies = [ 
"rustls-webpki", ] +[[package]] +name = "weezl" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9193164d4de03a926d909d3bc7c30543cecb35400c02114792c2cae20d5e2dbb" + [[package]] name = "whoami" version = "1.4.1" @@ -3470,3 +3719,12 @@ dependencies = [ "crossbeam-utils", "flate2", ] + +[[package]] +name = "zune-inflate" +version = "0.2.54" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73ab332fe2f6680068f3582b16a24f90ad7096d5d39b974d1c0aff0125116f02" +dependencies = [ + "simd-adler32", +] diff --git a/chimp_chomp/Cargo.toml b/chimp_chomp/Cargo.toml index eb761507..c93ef407 100644 --- a/chimp_chomp/Cargo.toml +++ b/chimp_chomp/Cargo.toml @@ -6,6 +6,12 @@ edition = "2021" [dependencies] clap = { workspace = true } dotenvy = { workspace = true } +image = { version = "0.24.6" } +ndarray = { version = "0.15.6" } +nshare = { version = "0.9.0", default-features = false, features = [ + "image", + "ndarray", +] } ort = { version = "1.14.8", default-features = false, features = [ "download-binaries", "tensorrt", diff --git a/chimp_chomp/src/image_loading.rs b/chimp_chomp/src/image_loading.rs new file mode 100644 index 00000000..cae7cd21 --- /dev/null +++ b/chimp_chomp/src/image_loading.rs @@ -0,0 +1,20 @@ +use image::{imageops::FilterType, ImageFormat}; +use ndarray::{ArrayBase, Axis, Dim, IxDynImpl, OwnedRepr}; +use nshare::ToNdarray3; +use std::{fs::File, io::BufReader, path::Path}; + +pub fn load_image( + path: impl AsRef, + width: u32, + height: u32, +) -> ArrayBase, Dim> { + let file = File::open(path).unwrap(); + let reader = BufReader::new(file); + image::load(reader, ImageFormat::Jpeg) + .unwrap() + .resize_exact(width, height, FilterType::Triangle) + .into_rgb32f() + .into_ndarray3() + .insert_axis(Axis(0)) + .into_dyn() +} diff --git a/chimp_chomp/src/main.rs b/chimp_chomp/src/main.rs index dc11e5d1..90f84075 100644 --- a/chimp_chomp/src/main.rs +++ b/chimp_chomp/src/main.rs @@ -1,3 +1,4 @@ +mod image_loading; mod model; use crate::model::setup_inference_session; From 835c3be41b08e8ec410d48885558e2ffe07fdfe0 Mon Sep 17 00:00:00 2001 From: Garry O'Donnell Date: Wed, 5 Jul 2023 10:18:37 +0000 Subject: [PATCH 04/45] Create chimp_protocol for common structs --- Cargo.lock | 12 ++++++++++-- Cargo.toml | 1 + chimp_chomp/Cargo.toml | 1 + chimp_protocol/Cargo.toml | 7 +++++++ chimp_protocol/src/lib.rs | 18 ++++++++++++++++++ 5 files changed, 37 insertions(+), 2 deletions(-) create mode 100644 chimp_protocol/Cargo.toml create mode 100644 chimp_protocol/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index b510a842..b6736c17 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -444,6 +444,7 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" name = "chimp_chomp" version = "0.1.0" dependencies = [ + "chimp_protocol", "clap 4.3.9", "dotenvy", "image", @@ -452,6 +453,13 @@ dependencies = [ "ort", ] +[[package]] +name = "chimp_protocol" +version = "0.1.0" +dependencies = [ + "serde", +] + [[package]] name = "chrono" version = "0.4.26" @@ -2656,9 +2664,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.99" +version = "1.0.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46266871c240a00b8f503b877622fe33430b3c7d963bdc0f2adc511e54a1eae3" +checksum = "0f1e14e89be7aa4c4b78bdbdc9eb5bf8517829a600ae8eaa39a6e1d960b5185c" dependencies = [ "itoa", "ryu", diff --git a/Cargo.toml b/Cargo.toml index 678dc2e1..706c8322 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 
+1,7 @@
 [workspace]
 members = [
     "chimp_chomp",
+    "chimp_protocol",
     "graphql_endpoints",
     "graphql_event_broker",
     "opa_client",
diff --git a/chimp_chomp/Cargo.toml b/chimp_chomp/Cargo.toml
index c93ef407..b8674576 100644
--- a/chimp_chomp/Cargo.toml
+++ b/chimp_chomp/Cargo.toml
@@ -4,6 +4,7 @@ version = "0.1.0"
 edition = "2021"
 
 [dependencies]
+chimp_protocol = { path = "../chimp_protocol" }
 clap = { workspace = true }
 dotenvy = { workspace = true }
 image = { version = "0.24.6" }
diff --git a/chimp_protocol/Cargo.toml b/chimp_protocol/Cargo.toml
new file mode 100644
index 00000000..c60b7972
--- /dev/null
+++ b/chimp_protocol/Cargo.toml
@@ -0,0 +1,7 @@
+[package]
+name = "chimp_protocol"
+version = "0.1.0"
+edition = "2021"
+
+[dependencies]
+serde = { version = "1.0.166" }
diff --git a/chimp_protocol/src/lib.rs b/chimp_protocol/src/lib.rs
new file mode 100644
index 00000000..93263532
--- /dev/null
+++ b/chimp_protocol/src/lib.rs
@@ -0,0 +1,18 @@
+use serde::{Deserialize, Serialize};
+use std::path::PathBuf;
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct Job {
+    pub file: PathBuf,
+    pub predictions_channel: String,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct Predictions(pub Vec<Prediction>);
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct Prediction {
+    pub bbox: [f32; 4],
+    pub label: i64,
+    pub score: f32,
+}
From 01c5ed6f9edb17d37868042dc53adc19d77cbb4d Mon Sep 17 00:00:00 2001
From: Garry O'Donnell
Date: Wed, 5 Jul 2023 10:20:10 +0000
Subject: [PATCH 05/45] Create inference worker

---
 Cargo.lock                   |  2 +
 chimp_chomp/Cargo.toml       |  2 +
 chimp_chomp/src/inference.rs | 94 ++++++++++++++++++++++++++++++++++++
 chimp_chomp/src/main.rs      | 26 ++++++++--
 chimp_chomp/src/model.rs     | 16 ------
 5 files changed, 120 insertions(+), 20 deletions(-)
 create mode 100644 chimp_chomp/src/inference.rs
 delete mode 100644 chimp_chomp/src/model.rs

diff --git a/Cargo.lock b/Cargo.lock
index b6736c17..9c41763b 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -448,9 +448,11 @@ dependencies = [
  "chimp_protocol",
  "clap 4.3.9",
  "dotenvy",
  "image",
+ "itertools",
  "ndarray",
  "nshare",
  "ort",
+ "tokio",
 ]
diff --git a/chimp_chomp/Cargo.toml b/chimp_chomp/Cargo.toml
index b8674576..c34c5cc0 100644
--- a/chimp_chomp/Cargo.toml
+++ b/chimp_chomp/Cargo.toml
@@ -8,6 +8,7 @@ chimp_protocol = { path = "../chimp_protocol" }
 clap = { workspace = true }
 dotenvy = { workspace = true }
 image = { version = "0.24.6" }
+itertools = { workspace = true }
 ndarray = { version = "0.15.6" }
 nshare = { version = "0.9.0", default-features = false, features = [
     "image",
@@ -18,3 +19,4 @@ ort = { version = "1.14.8", default-features = false, features = [
     "copy-dylibs",
 ] }
+tokio = { workspace = true, features = ["sync"] }
diff --git a/chimp_chomp/src/inference.rs b/chimp_chomp/src/inference.rs
new file mode 100644
index 00000000..e814d763
--- /dev/null
+++ b/chimp_chomp/src/inference.rs
@@ -0,0 +1,94 @@
+use chimp_protocol::{Prediction, Predictions};
+use itertools::{izip, Itertools};
+use ndarray::{ArrayBase, Axis, Dim, Ix2, Ix3, IxDynImpl, OwnedRepr, ViewRepr};
+use ort::{
+    tensor::{FromArray, InputTensor},
+    Environment, ExecutionProvider, GraphOptimizationLevel, OrtError, Session, SessionBuilder,
+};
+use std::{path::Path, sync::Arc};
+use tokio::sync::mpsc::{Receiver, UnboundedSender};
+
+pub fn setup_inference_session(model_path: impl AsRef<Path>) -> Result<Session, OrtError> {
+    let environment = Arc::new(
+        Environment::builder()
+            .with_name("CHiMP")
+            .with_execution_providers([ExecutionProvider::cpu()])
+            .build()?,
+    );
+    SessionBuilder::new(&environment)?
+        .with_optimization_level(GraphOptimizationLevel::Level3)?
+        .with_model_from_file(model_path)
+}
+
+fn do_inference(
+    session: &Session,
+    images: &[ArrayBase<ViewRepr<&f32>, Dim<IxDynImpl>>],
+) -> Vec<Predictions> {
+    let input = InputTensor::from_array(ndarray::concatenate(Axis(0), images).unwrap());
+    let outputs = session.run(vec![input]).unwrap();
+    let bboxes = outputs[0]
+        .try_extract::<f32>()
+        .unwrap()
+        .view()
+        .to_owned()
+        .into_dimensionality::<Ix3>()
+        .unwrap();
+    let labels = outputs[1]
+        .try_extract::<i64>()
+        .unwrap()
+        .view()
+        .to_owned()
+        .into_dimensionality::<Ix2>()
+        .unwrap();
+    let scores = outputs[2]
+        .try_extract::<f32>()
+        .unwrap()
+        .view()
+        .to_owned()
+        .into_dimensionality::<Ix2>()
+        .unwrap();
+
+    izip!(
+        bboxes.outer_iter(),
+        labels.outer_iter(),
+        scores.outer_iter()
+    )
+    .map(|(bboxes, labels, scores)| {
+        Predictions(
+            izip!(
+                bboxes.outer_iter(),
+                labels.to_vec().iter(),
+                scores.to_vec().iter()
+            )
+            .map(|(bbox, &label, &score)| Prediction {
+                bbox: bbox.to_vec().try_into().unwrap(),
+                label,
+                score,
+            })
+            .collect(),
+        )
+    })
+    .collect()
+}
+
+pub async fn inference_worker(
+    session: Session,
+    batch_size: usize,
+    mut image_rx: Receiver<ArrayBase<OwnedRepr<f32>, Dim<IxDynImpl>>>,
+    prediction_tx: UnboundedSender<Predictions>,
+) {
+    image_rx
+        .recv()
+        .await
+        .iter()
+        .map(ArrayBase::view)
+        .chunks(batch_size)
+        .into_iter()
+        .for_each(|images| {
+            let images = images.collect::<Vec<_>>();
+            let predictions = do_inference(&session, &images);
+            predictions
+                .into_iter()
+                .for_each(|predictions| prediction_tx.send(predictions).unwrap())
+        });
+}
diff --git a/chimp_chomp/src/main.rs b/chimp_chomp/src/main.rs
index 90f84075..af947ea3 100644
--- a/chimp_chomp/src/main.rs
+++ b/chimp_chomp/src/main.rs
@@ -1,9 +1,11 @@
 mod image_loading;
-mod model;
+mod inference;
+mod jobs;
 
-use crate::model::setup_inference_session;
 use clap::Parser;
+use inference::{inference_worker, setup_inference_session};
 use std::path::PathBuf;
+use tokio::task::JoinSet;
 
 #[derive(Debug, Parser)]
 #[command(author, version, about, long_about=None)]
@@ -12,9 +14,25 @@ struct Cli {
     model: PathBuf,
 }
 
-fn main() {
+#[tokio::main]
+async fn main() {
     dotenvy::dotenv().ok();
     let args = Cli::parse();
 
-    let _session = setup_inference_session(args.model).unwrap();
+    let session = setup_inference_session(args.model).unwrap();
+    let batch_size = session.inputs[0].dimensions[0].unwrap().try_into().unwrap();
+
+    let (_image_tx, image_rx) = tokio::sync::mpsc::channel(batch_size);
+    let (prediction_tx, _prediction_rx) = tokio::sync::mpsc::unbounded_channel();
+
+    let mut tasks = JoinSet::new();
+
+    tasks.spawn(inference_worker(
+        session,
+        batch_size,
+        image_rx,
+        prediction_tx,
+    ));
+
+    tasks.join_next().await;
 }
diff --git a/chimp_chomp/src/model.rs b/chimp_chomp/src/model.rs
deleted file mode 100644
index ada4feb7..00000000
--- a/chimp_chomp/src/model.rs
+++ /dev/null
@@ -1,16 +0,0 @@
-use ort::{
-    Environment, ExecutionProvider, GraphOptimizationLevel, OrtError, Session, SessionBuilder,
-};
-use std::{path::Path, sync::Arc};
-
-pub fn setup_inference_session(model_path: impl AsRef<Path>) -> Result<Session, OrtError> {
-    let environment = Arc::new(
-        Environment::builder()
-            .with_name("CHiMP")
-            .with_execution_providers([ExecutionProvider::cpu()])
-            .build()?,
-    );
-    SessionBuilder::new(&environment)?
-        .with_optimization_level(GraphOptimizationLevel::Level3)?
-        .with_model_from_file(model_path)
-}
From ceb5c9c62c77efc4c0c5e4bb0cb22e72a7bf5a39 Mon Sep 17 00:00:00 2001
From: Garry O'Donnell
Date: Wed, 5 Jul 2023 10:47:50 +0000
Subject: [PATCH 06/45] Add (de)serializers to protocol structs

---
 Cargo.lock                |  1 +
 chimp_protocol/Cargo.toml |  1 +
 chimp_protocol/src/lib.rs | 30 ++++++++++++++++++++++++++++++
 3 files changed, 32 insertions(+)

diff --git a/Cargo.lock b/Cargo.lock
index 9c41763b..bc1528e0 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -460,6 +460,7 @@ name = "chimp_protocol"
 version = "0.1.0"
 dependencies = [
  "serde",
+ "serde_json",
 ]
diff --git a/chimp_protocol/Cargo.toml b/chimp_protocol/Cargo.toml
index c60b7972..2cdabc37 100644
--- a/chimp_protocol/Cargo.toml
+++ b/chimp_protocol/Cargo.toml
@@ -5,3 +5,4 @@ edition = "2021"
 
 [dependencies]
 serde = { version = "1.0.166" }
+serde_json = "1.0.100"
diff --git a/chimp_protocol/src/lib.rs b/chimp_protocol/src/lib.rs
index 93263532..a2aad245 100644
--- a/chimp_protocol/src/lib.rs
+++ b/chimp_protocol/src/lib.rs
@@ -7,12 +7,42 @@ pub struct Job {
     pub predictions_channel: String,
 }
 
+impl Job {
+    pub fn from_slice(v: &[u8]) -> Result<Self, serde_json::Error> {
+        serde_json::from_slice(v)
+    }
+
+    pub fn to_vec(&self) -> Result<Vec<u8>, serde_json::Error> {
+        serde_json::to_vec(&self)
+    }
+}
+
 #[derive(Debug, Serialize, Deserialize)]
 pub struct Predictions(pub Vec<Prediction>);
 
+impl Predictions {
+    pub fn from_slice(v: &[u8]) -> Result<Self, serde_json::Error> {
+        serde_json::from_slice(v)
+    }
+
+    pub fn to_vec(&self) -> Result<Vec<u8>, serde_json::Error> {
+        serde_json::to_vec(&self)
+    }
+}
+
 #[derive(Debug, Serialize, Deserialize)]
 pub struct Prediction {
     pub bbox: [f32; 4],
     pub label: i64,
     pub score: f32,
 }
+
+impl Prediction {
+    pub fn from_slice(v: &[u8]) -> Result<Self, serde_json::Error> {
+        serde_json::from_slice(v)
+    }
+
+    pub fn to_vec(&self) -> Result<Vec<u8>, serde_json::Error> {
+        serde_json::to_vec(&self)
+    }
+}
From 346e5d4d2abae147921cb2d5cb52639620dad0b1 Mon Sep 17 00:00:00 2001
From: Garry O'Donnell
Date: Wed, 5 Jul 2023 11:16:28 +0000
Subject: [PATCH 07/45] Setup RabbitMQ job consumer

---
 Cargo.lock              | 336 +++++++++++++++++++++++++++++++++++-
 chimp_chomp/Cargo.toml  |   6 +
 chimp_chomp/src/jobs.rs |  50 ++++++
 chimp_chomp/src/main.rs |  22 ++-
 4 files changed, 411 insertions(+), 3 deletions(-)
 create mode 100644 chimp_chomp/src/jobs.rs

diff --git a/Cargo.lock b/Cargo.lock
index bc1528e0..9cb50e88 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -70,6 +70,54 @@ version = "0.2.15"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "56fc6cf8dc8c4158eed8649f9b8b0ea1518eb62b544fe9490d66fa0b349eafe9"
 
+[[package]]
+name = "amq-protocol"
+version = "7.1.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1d40d8b2465c7959dd40cee32ba6ac334b5de57e9fca0cc756759894a4152a5d"
+dependencies = [
+ "amq-protocol-tcp",
+ "amq-protocol-types",
+ "amq-protocol-uri",
+ "cookie-factory",
+ "nom",
+ "serde",
+]
+
+[[package]]
+name = "amq-protocol-tcp"
+version = "7.1.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9cb2100adae7da61953a2c3a01935d86caae13329fadce3333f524d6d6ce12e2"
+dependencies = [
+ "amq-protocol-uri",
+ "tcp-stream",
+ "tracing",
+]
+
+[[package]]
+name = "amq-protocol-types"
+version = "7.1.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "156ff13c8a3ced600b4e54ed826a2ae6242b6069d00dd98466827cef07d3daff"
+dependencies = [
+ "cookie-factory",
+ "nom",
+ "serde",
+ "serde_json",
+]
+
+[[package]]
+name = "amq-protocol-uri"
+version = "7.1.2"
+source =
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "751bbd7d440576066233e740576f1b31fdc6ab86cfabfbd48c548de77eca73e4" +dependencies = [ + "amq-protocol-types", + "percent-encoding", + "url", +] + [[package]] name = "android-tzdata" version = "0.1.1" @@ -134,6 +182,57 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "async-channel" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf46fee83e5ccffc220104713af3292ff9bc7c64c7de289f66dae8e38d826833" +dependencies = [ + "concurrent-queue", + "event-listener", + "futures-core", +] + +[[package]] +name = "async-executor" +version = "1.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6fa3dc5f2a8564f07759c008b9109dc0d39de92a88d5588b8a5036d286383afb" +dependencies = [ + "async-lock", + "async-task", + "concurrent-queue", + "fastrand", + "futures-lite", + "slab", +] + +[[package]] +name = "async-global-executor" +version = "2.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1b6f5d7df27bd294849f8eec66ecfc63d11814df7a4f5d74168a2394467b776" +dependencies = [ + "async-channel", + "async-executor", + "async-io", + "async-lock", + "blocking", + "futures-lite", + "once_cell", +] + +[[package]] +name = "async-global-executor-trait" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33dd14c5a15affd2abcff50d84efd4009ada28a860f01c14f9d654f3e81b3f75" +dependencies = [ + "async-global-executor", + "async-trait", + "executor-trait", +] + [[package]] name = "async-graphql" version = "5.0.10" @@ -226,6 +325,47 @@ dependencies = [ "serde_json", ] +[[package]] +name = "async-io" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fc5b45d93ef0529756f812ca52e44c221b35341892d3dcc34132ac02f3dd2af" +dependencies = [ + "async-lock", + "autocfg", + "cfg-if", + "concurrent-queue", + "futures-lite", + "log", + "parking", + "polling", + "rustix", + "slab", + "socket2", + "waker-fn", +] + +[[package]] +name = "async-lock" +version = "2.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa24f727524730b077666307f2734b4a1a1c57acb79193127dcc8914d5242dd7" +dependencies = [ + "event-listener", +] + +[[package]] +name = "async-reactor-trait" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a6012d170ad00de56c9ee354aef2e358359deb1ec504254e0e5a3774771de0e" +dependencies = [ + "async-io", + "async-trait", + "futures-core", + "reactor-trait", +] + [[package]] name = "async-stream" version = "0.3.5" @@ -248,6 +388,12 @@ dependencies = [ "syn 2.0.25", ] +[[package]] +name = "async-task" +version = "4.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ecc7ab41815b3c653ccd2978ec3255c81349336702dfdf62ee6f7069b12a3aae" + [[package]] name = "async-trait" version = "0.1.68" @@ -268,6 +414,12 @@ dependencies = [ "num-traits", ] +[[package]] +name = "atomic-waker" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1181e1e0d1fce796a03db1ae795d67167da795f9cf4a39c37589e85ef57f26d3" + [[package]] name = "autocfg" version = "1.1.0" @@ -401,6 +553,21 @@ dependencies = [ "generic-array", ] +[[package]] +name = "blocking" +version = "1.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77231a1c8f801696fc0123ec6150ce92cffb8e164a02afb9c8ddee0e9b65ad65" +dependencies = [ + 
"async-channel", + "async-lock", + "async-task", + "atomic-waker", + "fastrand", + "futures-lite", + "log", +] + [[package]] name = "bumpalo" version = "3.13.0" @@ -447,12 +614,16 @@ dependencies = [ "chimp_protocol", "clap 4.3.9", "dotenvy", + "futures-lite", "image", "itertools", + "lapin", "ndarray", "nshare", "ort", "tokio", + "url", + "uuid", ] [[package]] @@ -586,12 +757,27 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7" +[[package]] +name = "concurrent-queue" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62ec6771ecfa0762d24683ee5a32ad78487a3d3afdc0fb8cae19d2c5deb50b7c" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "convert_case" version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6245d59a3e82a7fc217c5828a6692dbc6dfb63a0c8c90495621f7b9d79704a0e" +[[package]] +name = "cookie-factory" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "396de984970346b0d9e93d1415082923c679e5ae5c3ee3dcbd104f5610af126b" + [[package]] name = "core-foundation" version = "0.9.3" @@ -780,6 +966,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "doc-comment" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10" + [[package]] name = "dotenvy" version = "0.15.7" @@ -834,6 +1026,15 @@ version = "2.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0206175f82b8d6bf6652ff7d71a1e27fd2e4efde587fd368662814d6ec1d9ce0" +[[package]] +name = "executor-trait" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a1052dd43212a7777ec6a69b117da52f5e52f07aec47d00c1a2b33b85d06b08" +dependencies = [ + "async-trait", +] + [[package]] name = "exr" version = "1.6.5" @@ -983,7 +1184,7 @@ checksum = "a604f7a68fbf8103337523b1fadc8ade7361ee3f112f7c680ad179651616aed5" dependencies = [ "futures-core", "lock_api", - "parking_lot", + "parking_lot 0.11.2", ] [[package]] @@ -992,6 +1193,21 @@ version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4fff74096e71ed47f8e023204cfd0aa1289cd54ae5430a9523be060cdb849964" +[[package]] +name = "futures-lite" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49a9d51ce47660b1e808d3c990b4709f2f415d928835a17dfd16991515c46bce" +dependencies = [ + "fastrand", + "futures-core", + "futures-io", + "memchr", + "parking", + "pin-project-lite", + "waker-fn", +] + [[package]] name = "futures-macro" version = "0.3.28" @@ -1463,6 +1679,28 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "lapin" +version = "2.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "acc13beaa09eed710f406201f46b961345b4d061dd90ec3d3ccc70721e70342a" +dependencies = [ + "amq-protocol", + "async-global-executor-trait", + "async-reactor-trait", + "async-trait", + "executor-trait", + "flume", + "futures-core", + "futures-io", + "parking_lot 0.12.1", + "pinky-swear", + "reactor-trait", + "serde", + "tracing", + "waker-fn", +] + [[package]] name = "lazy_static" version = "1.4.0" @@ -1940,6 +2178,12 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" 
+[[package]] +name = "parking" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "14f2252c834a40ed9bb5422029649578e63aa341ac401f74e719dd1afda8394e" + [[package]] name = "parking_lot" version = "0.11.2" @@ -1948,7 +2192,17 @@ checksum = "7d17b78036a60663b797adeaee46f5c9dfebb86948d1255007a1d6be0271ff99" dependencies = [ "instant", "lock_api", - "parking_lot_core", + "parking_lot_core 0.8.6", +] + +[[package]] +name = "parking_lot" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" +dependencies = [ + "lock_api", + "parking_lot_core 0.9.8", ] [[package]] @@ -1965,6 +2219,19 @@ dependencies = [ "winapi", ] +[[package]] +name = "parking_lot_core" +version = "0.9.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93f00c865fe7cabf650081affecd3871070f26767e7b2070a3ffae14c654b447" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall 0.3.5", + "smallvec", + "windows-targets", +] + [[package]] name = "parse-zoneinfo" version = "0.3.0" @@ -2120,6 +2387,18 @@ dependencies = [ "uuid", ] +[[package]] +name = "pinky-swear" +version = "6.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d894b67aa7a4bf295db5e85349078c604edaa6fa5c8721e8eca3c7729a27f2ac" +dependencies = [ + "doc-comment", + "flume", + "parking_lot 0.12.1", + "tracing", +] + [[package]] name = "pkg-config" version = "0.3.27" @@ -2139,6 +2418,22 @@ dependencies = [ "miniz_oxide 0.7.1", ] +[[package]] +name = "polling" +version = "2.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b2d323e8ca7996b3e23126511a523f7e62924d93ecd5ae73b333815b0eb3dce" +dependencies = [ + "autocfg", + "bitflags", + "cfg-if", + "concurrent-queue", + "libc", + "log", + "pin-project-lite", + "windows-sys 0.48.0", +] + [[package]] name = "ppv-lite86" version = "0.2.17" @@ -2264,6 +2559,17 @@ dependencies = [ "num_cpus", ] +[[package]] +name = "reactor-trait" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "438a4293e4d097556730f4711998189416232f009c137389e0f961d2bc0ddc58" +dependencies = [ + "async-trait", + "futures-core", + "futures-io", +] + [[package]] name = "redox_syscall" version = "0.2.16" @@ -2418,6 +2724,15 @@ dependencies = [ "sct", ] +[[package]] +name = "rustls-pemfile" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d3987094b1d07b653b7dfdc3f70ce9a1da9c51ac18c1b06b662e4f9a0e9f4b2" +dependencies = [ + "base64 0.21.2", +] + [[package]] name = "rustls-webpki" version = "0.100.1" @@ -2986,6 +3301,17 @@ dependencies = [ "xattr", ] +[[package]] +name = "tcp-stream" +version = "0.26.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1322b18a9e329ba45e4430b19543045b85cd1dcb2892e77d27ab471ba2039bd1" +dependencies = [ + "cfg-if", + "native-tls", + "rustls-pemfile", +] + [[package]] name = "tempfile" version = "3.6.0" @@ -3422,6 +3748,12 @@ dependencies = [ "libc", ] +[[package]] +name = "waker-fn" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d5b2c62b4012a3e1eca5a7e077d13b3bf498c4073e33ccd58626607748ceeca" + [[package]] name = "want" version = "0.3.1" diff --git a/chimp_chomp/Cargo.toml b/chimp_chomp/Cargo.toml index c34c5cc0..2b4a4f5b 100644 --- a/chimp_chomp/Cargo.toml +++ b/chimp_chomp/Cargo.toml @@ -7,8 +7,12 @@ edition = 
"2021" chimp_protocol = { path = "../chimp_protocol" } clap = { workspace = true } dotenvy = { workspace = true } +futures-lite = { version = "1.13.0" } image = { version = "0.24.6" } itertools = { workspace = true } +lapin = { version = "2.2.1", default-features = false, features = [ + "native-tls", +] } ndarray = { version = "0.15.6" } nshare = { version = "0.9.0", default-features = false, features = [ "image", @@ -20,3 +24,5 @@ ort = { version = "1.14.8", default-features = false, features = [ "copy-dylibs", ] } tokio = { workspace = true, features = ["sync"] } +url = { workspace = true } +uuid = { workspace = true } diff --git a/chimp_chomp/src/jobs.rs b/chimp_chomp/src/jobs.rs new file mode 100644 index 00000000..df7de65f --- /dev/null +++ b/chimp_chomp/src/jobs.rs @@ -0,0 +1,50 @@ +use crate::image_loading::load_image; +use chimp_protocol::Job; +use futures_lite::StreamExt; +use lapin::{ + options::{BasicAckOptions, BasicConsumeOptions}, + types::FieldTable, + Connection, Consumer, +}; +use ndarray::{ArrayBase, Dim, IxDynImpl, OwnedRepr}; +use tokio::sync::mpsc::Sender; +use url::Url; +use uuid::Uuid; + +pub async fn setup_rabbitmq_client(address: Url) -> Result { + lapin::Connection::connect(address.as_str(), lapin::ConnectionProperties::default()).await +} + +pub async fn setup_job_consumer( + rabbitmq_client: Connection, + channel: impl AsRef, +) -> Result { + let worker_id = Uuid::new_v4(); + let worker_tag = format!("chimp_chomp_{worker_id}"); + let job_channel = rabbitmq_client.create_channel().await?; + job_channel + .basic_consume( + channel.as_ref(), + &worker_tag, + BasicConsumeOptions::default(), + FieldTable::default(), + ) + .await +} + +pub async fn job_consumption_worker( + mut job_consumer: Consumer, + input_width: u32, + input_height: u32, + image_tx: Sender, Dim>>, +) { + while let Some(delivery) = job_consumer.next().await { + let delievry = delivery.unwrap(); + delievry.ack(BasicAckOptions::default()).await.unwrap(); + + let job = Job::from_slice(&delievry.data).unwrap(); + let image = load_image(job.file, input_width, input_height); + + image_tx.send(image).await.unwrap(); + } +} diff --git a/chimp_chomp/src/main.rs b/chimp_chomp/src/main.rs index af947ea3..a79d76c1 100644 --- a/chimp_chomp/src/main.rs +++ b/chimp_chomp/src/main.rs @@ -4,14 +4,20 @@ mod jobs; use clap::Parser; use inference::{inference_worker, setup_inference_session}; +use jobs::{job_consumption_worker, setup_job_consumer, setup_rabbitmq_client}; use std::path::PathBuf; use tokio::task::JoinSet; +use url::Url; #[derive(Debug, Parser)] #[command(author, version, about, long_about=None)] struct Cli { /// The path to the ONNX model file. model: PathBuf, + /// The URL of the RabbitMQ server. + rabbitmq_address: Url, + /// The RabbitMQ channel on which jobs are assigned. 
+ rabbitmq_channel: String, } #[tokio::main] @@ -20,9 +26,16 @@ async fn main() { let args = Cli::parse(); let session = setup_inference_session(args.model).unwrap(); + let input_width = session.inputs[0].dimensions[3].unwrap(); + let input_height = session.inputs[0].dimensions[2].unwrap(); let batch_size = session.inputs[0].dimensions[0].unwrap().try_into().unwrap(); - let (_image_tx, image_rx) = tokio::sync::mpsc::channel(batch_size); + let rabbitmq_client = setup_rabbitmq_client(args.rabbitmq_address).await.unwrap(); + let job_consumer = setup_job_consumer(rabbitmq_client, args.rabbitmq_channel) + .await + .unwrap(); + + let (image_tx, image_rx) = tokio::sync::mpsc::channel(batch_size); let (prediction_tx, _prediction_rx) = tokio::sync::mpsc::unbounded_channel(); let mut tasks = JoinSet::new(); @@ -34,5 +47,12 @@ async fn main() { prediction_tx, )); + tasks.spawn(job_consumption_worker( + job_consumer, + input_width, + input_height, + image_tx, + )); + tasks.join_next().await; } From 7089c59d00e47d9defc8758b927b3652074a97cc Mon Sep 17 00:00:00 2001 From: Garry O'Donnell Date: Wed, 5 Jul 2023 14:45:05 +0000 Subject: [PATCH 08/45] Create predictions producer worker --- chimp_chomp/src/image_loading.rs | 8 +++--- chimp_chomp/src/inference.rs | 24 +++++++++++------- chimp_chomp/src/jobs.rs | 43 ++++++++++++++++++++++++-------- chimp_chomp/src/main.rs | 15 ++++++++--- 4 files changed, 62 insertions(+), 28 deletions(-) diff --git a/chimp_chomp/src/image_loading.rs b/chimp_chomp/src/image_loading.rs index cae7cd21..fe420a1d 100644 --- a/chimp_chomp/src/image_loading.rs +++ b/chimp_chomp/src/image_loading.rs @@ -3,11 +3,9 @@ use ndarray::{ArrayBase, Axis, Dim, IxDynImpl, OwnedRepr}; use nshare::ToNdarray3; use std::{fs::File, io::BufReader, path::Path}; -pub fn load_image( - path: impl AsRef, - width: u32, - height: u32, -) -> ArrayBase, Dim> { +pub type Image = ArrayBase, Dim>; + +pub fn load_image(path: impl AsRef, width: u32, height: u32) -> Image { let file = File::open(path).unwrap(); let reader = BufReader::new(file); image::load(reader, ImageFormat::Jpeg) diff --git a/chimp_chomp/src/inference.rs b/chimp_chomp/src/inference.rs index e814d763..ea952463 100644 --- a/chimp_chomp/src/inference.rs +++ b/chimp_chomp/src/inference.rs @@ -1,6 +1,6 @@ use chimp_protocol::{Prediction, Predictions}; use itertools::{izip, Itertools}; -use ndarray::{ArrayBase, Axis, Dim, Ix2, Ix3, IxDynImpl, OwnedRepr, ViewRepr}; +use ndarray::{ArrayBase, Axis, Dim, Ix2, Ix3, IxDynImpl, ViewRepr}; use ort::{ tensor::{FromArray, InputTensor}, Environment, ExecutionProvider, GraphOptimizationLevel, OrtError, Session, SessionBuilder, @@ -8,6 +8,8 @@ use ort::{ use std::{path::Path, sync::Arc}; use tokio::sync::mpsc::{Receiver, UnboundedSender}; +use crate::image_loading::Image; + pub fn setup_inference_session(model_path: impl AsRef) -> Result { let environment = Arc::new( Environment::builder() @@ -74,21 +76,25 @@ fn do_inference( pub async fn inference_worker( session: Session, batch_size: usize, - mut image_rx: Receiver, Dim>>, - prediction_tx: UnboundedSender, + mut image_rx: Receiver<(Image, String)>, + prediction_tx: UnboundedSender<(Predictions, String)>, ) { image_rx .recv() .await .iter() - .map(ArrayBase::view) + .map(|(image, predictions_channel)| (image.view(), predictions_channel)) .chunks(batch_size) .into_iter() - .for_each(|images| { - let images = images.collect::>(); + .for_each(|jobs| { + let (images, prediction_channels) = jobs.into_iter().unzip::<_, _, Vec<_>, Vec<_>>(); let predictions = 
do_inference(&session, &images); - predictions - .into_iter() - .for_each(|predictions| prediction_tx.send(predictions).unwrap()) + izip!(predictions.into_iter(), prediction_channels.into_iter()).for_each( + |(predictions, prediction_channel)| { + prediction_tx + .send((predictions, prediction_channel.clone())) + .unwrap() + }, + ) }); } diff --git a/chimp_chomp/src/jobs.rs b/chimp_chomp/src/jobs.rs index df7de65f..0b2bf6c0 100644 --- a/chimp_chomp/src/jobs.rs +++ b/chimp_chomp/src/jobs.rs @@ -1,13 +1,12 @@ -use crate::image_loading::load_image; -use chimp_protocol::Job; +use crate::image_loading::{load_image, Image}; +use chimp_protocol::{Job, Predictions}; use futures_lite::StreamExt; use lapin::{ - options::{BasicAckOptions, BasicConsumeOptions}, + options::{BasicAckOptions, BasicConsumeOptions, BasicPublishOptions}, types::FieldTable, - Connection, Consumer, + BasicProperties, Channel, Connection, Consumer, }; -use ndarray::{ArrayBase, Dim, IxDynImpl, OwnedRepr}; -use tokio::sync::mpsc::Sender; +use tokio::sync::mpsc::{Sender, UnboundedReceiver}; use url::Url; use uuid::Uuid; @@ -16,13 +15,12 @@ pub async fn setup_rabbitmq_client(address: Url) -> Result, ) -> Result { let worker_id = Uuid::new_v4(); let worker_tag = format!("chimp_chomp_{worker_id}"); - let job_channel = rabbitmq_client.create_channel().await?; - job_channel + rabbitmq_channel .basic_consume( channel.as_ref(), &worker_tag, @@ -36,7 +34,7 @@ pub async fn job_consumption_worker( mut job_consumer: Consumer, input_width: u32, input_height: u32, - image_tx: Sender, Dim>>, + image_tx: Sender<(Image, String)>, ) { while let Some(delivery) = job_consumer.next().await { let delievry = delivery.unwrap(); @@ -45,6 +43,29 @@ pub async fn job_consumption_worker( let job = Job::from_slice(&delievry.data).unwrap(); let image = load_image(job.file, input_width, input_height); - image_tx.send(image).await.unwrap(); + image_tx + .send((image, job.predictions_channel)) + .await + .unwrap(); + } +} + +pub async fn predictions_producer_worker( + mut prediction_rx: UnboundedReceiver<(Predictions, String)>, + rabbitmq_channel: Channel, +) { + while let Some((predictions, predictions_channel)) = prediction_rx.recv().await { + rabbitmq_channel + .basic_publish( + "", + &predictions_channel, + BasicPublishOptions::default(), + &predictions.to_vec().unwrap(), + BasicProperties::default(), + ) + .await + .unwrap() + .await + .unwrap(); } } diff --git a/chimp_chomp/src/main.rs b/chimp_chomp/src/main.rs index a79d76c1..97ff6144 100644 --- a/chimp_chomp/src/main.rs +++ b/chimp_chomp/src/main.rs @@ -4,7 +4,9 @@ mod jobs; use clap::Parser; use inference::{inference_worker, setup_inference_session}; -use jobs::{job_consumption_worker, setup_job_consumer, setup_rabbitmq_client}; +use jobs::{ + job_consumption_worker, predictions_producer_worker, setup_job_consumer, setup_rabbitmq_client, +}; use std::path::PathBuf; use tokio::task::JoinSet; use url::Url; @@ -31,12 +33,14 @@ async fn main() { let batch_size = session.inputs[0].dimensions[0].unwrap().try_into().unwrap(); let rabbitmq_client = setup_rabbitmq_client(args.rabbitmq_address).await.unwrap(); - let job_consumer = setup_job_consumer(rabbitmq_client, args.rabbitmq_channel) + let job_channel = rabbitmq_client.create_channel().await.unwrap(); + let predictions_channel = rabbitmq_client.create_channel().await.unwrap(); + let job_consumer = setup_job_consumer(job_channel, args.rabbitmq_channel) .await .unwrap(); let (image_tx, image_rx) = tokio::sync::mpsc::channel(batch_size); - let (prediction_tx, 
_prediction_rx) = tokio::sync::mpsc::unbounded_channel(); + let (prediction_tx, prediction_rx) = tokio::sync::mpsc::unbounded_channel(); let mut tasks = JoinSet::new(); @@ -54,5 +58,10 @@ async fn main() { image_tx, )); + tasks.spawn(predictions_producer_worker( + prediction_rx, + predictions_channel, + )); + tasks.join_next().await; } From 1d0cd0385b589de5fc86c62d7470a60a4dcf3dc6 Mon Sep 17 00:00:00 2001 From: Garry O'Donnell Date: Wed, 5 Jul 2023 14:45:27 +0000 Subject: [PATCH 09/45] Bump clap from 4.3.9 to 4.3.10 --- Cargo.lock | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.lock b/Cargo.lock index 9cb50e88..222a94e6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -612,7 +612,7 @@ name = "chimp_chomp" version = "0.1.0" dependencies = [ "chimp_protocol", - "clap 4.3.9", + "clap 4.3.10", "dotenvy", "futures-lite", "image", From 10fb145a209c866b284475e90ca1123f02f876c3 Mon Sep 17 00:00:00 2001 From: Garry O'Donnell Date: Wed, 5 Jul 2023 14:55:07 +0000 Subject: [PATCH 10/45] Add chimp chomp & protocol to dockerfile --- .github/workflows/container.yml | 1 + Dockerfile | 12 ++++++++++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/.github/workflows/container.yml b/.github/workflows/container.yml index a2f2d8d3..8a6e936a 100644 --- a/.github/workflows/container.yml +++ b/.github/workflows/container.yml @@ -11,6 +11,7 @@ jobs: strategy: matrix: service: + - chimp_chomp - soakdb_sync - pin_packing runs-on: ubuntu-latest diff --git a/Dockerfile b/Dockerfile index 3ba1c2e9..75d3bfe1 100644 --- a/Dockerfile +++ b/Dockerfile @@ -4,13 +4,19 @@ WORKDIR /app # Build dependencies COPY Cargo.toml Cargo.lock ./ +COPY chimp_chomp/Cargo.toml chimp_chomp/Cargo.toml +COPY chimp_protocol/Cargo.toml chimp_protocol/Cargo.toml COPY graphql_endpoints/Cargo.toml graphql_endpoints/Cargo.toml COPY graphql_event_broker/Cargo.toml graphql_event_broker/Cargo.toml COPY opa_client/Cargo.toml opa_client/Cargo.toml COPY pin_packing/Cargo.toml pin_packing/Cargo.toml COPY soakdb_io/Cargo.toml soakdb_io/Cargo.toml COPY soakdb_sync/Cargo.toml soakdb_sync/Cargo.toml -RUN mkdir graphql_endpoints/src \ +RUN mkdir chimp_chomp/src \ + && touch chimp_chomp/src/lib.rs \ + && mkdir chimp_protocol/src \ + && touch chimp_protocol/src/lib.rs \ + && mkdir graphql_endpoints/src \ && touch graphql_endpoints/src/lib.rs \ && mkdir graphql_event_broker/src \ && touch graphql_event_broker/src/lib.rs \ @@ -26,7 +32,9 @@ RUN mkdir graphql_endpoints/src \ # Build workspace crates COPY . 
/app
-RUN touch graphql_endpoints/src/lib.rs \
+RUN touch chimp_chomp/src/lib.rs \
+ && touch chimp_protocol/src/lib.rs \
+ && touch graphql_endpoints/src/lib.rs \
&& touch graphql_event_broker/src/lib.rs \
&& touch opa_client/src/lib.rs \
&& touch pin_packing/src/main.rs \
From 359173eaeabee3610b5f5fd60e288a697ff4bfef Mon Sep 17 00:00:00 2001
From: Garry O'Donnell
Date: Wed, 5 Jul 2023 15:09:21 +0000
Subject: [PATCH 11/45] Disable parallel container builds in CI
---
.github/workflows/container.yml | 1 +
1 file changed, 1 insertion(+)
diff --git a/.github/workflows/container.yml b/.github/workflows/container.yml
index 8a6e936a..c3c50620 100644
--- a/.github/workflows/container.yml
+++ b/.github/workflows/container.yml
@@ -14,6 +14,7 @@ jobs:
- chimp_chomp
- soakdb_sync
- pin_packing
+ max-parallel: 1
runs-on: ubuntu-latest
steps:
- name: Generate Image Name
From 18e238d1e1b5044cf31f47b95e89099bf0b8c402 Mon Sep 17 00:00:00 2001
From: Garry O'Donnell
Date: Wed, 5 Jul 2023 15:14:49 +0000
Subject: [PATCH 12/45] Add basic docs to chimp crates
---
chimp_chomp/README.md | 3 +++
chimp_chomp/src/main.rs | 4 ++++
chimp_protocol/README.md | 3 +++
chimp_protocol/src/lib.rs | 18 ++++++++++++++++++
4 files changed, 28 insertions(+)
create mode 100644 chimp_chomp/README.md
create mode 100644 chimp_protocol/README.md
diff --git a/chimp_chomp/README.md b/chimp_chomp/README.md
new file mode 100644
index 00000000..37dc5feb
--- /dev/null
+++ b/chimp_chomp/README.md
@@ -0,0 +1,3 @@
+# CHiMP Worker
+
+This worker steals jobs from a RabbitMQ queue, retrieves images, performs batch inference on them using the CHiMP neural network and returns results on another RabbitMQ queue. The worker is intended to be deployed as an autoscaled-to-zero service.
diff --git a/chimp_chomp/src/main.rs b/chimp_chomp/src/main.rs
index 97ff6144..bd656af9 100644
--- a/chimp_chomp/src/main.rs
+++ b/chimp_chomp/src/main.rs
@@ -1,3 +1,7 @@
+#![forbid(unsafe_code)]
+#![warn(missing_docs)]
+#![doc=include_str!("../README.md")]
+
mod image_loading;
mod inference;
mod jobs;
diff --git a/chimp_protocol/README.md b/chimp_protocol/README.md
new file mode 100644
index 00000000..c2237153
--- /dev/null
+++ b/chimp_protocol/README.md
@@ -0,0 +1,3 @@
+# CHiMP Protocol
+
+This library defines a number of data structures common to CHiMP, each of which implements (de)serialization to / from JSON.
diff --git a/chimp_protocol/src/lib.rs b/chimp_protocol/src/lib.rs
index a2aad245..2080e3fa 100644
--- a/chimp_protocol/src/lib.rs
+++ b/chimp_protocol/src/lib.rs
@@ -1,47 +1,65 @@
+#![forbid(unsafe_code)]
+#![warn(missing_docs)]
+#![doc=include_str!("../README.md")]
+
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
+/// A CHiMP job definition.
#[derive(Debug, Serialize, Deserialize)]
pub struct Job {
+ /// The path of a file containing the image to perform inference on.
pub file: PathBuf,
+ /// The channel to send predictions to.
pub predictions_channel: String,
}
impl Job {
+ /// Deserialize an instance [`Job`] from bytes of JSON text.
pub fn from_slice(v: &[u8]) -> Result {
serde_json::from_slice(v)
}
+ /// Serialize the [`Job`] as a JSON byte vector
pub fn to_vec(&self) -> Result, serde_json::Error> {
serde_json::to_vec(&self)
}
}
+/// A set of predictions which apply to a single image.
#[derive(Debug, Serialize, Deserialize)]
pub struct Predictions(pub Vec);
impl Predictions {
+ /// Deserialize an instance [`Predictions`] from bytes of JSON text.
pub fn from_slice(v: &[u8]) -> Result {
serde_json::from_slice(v)
}
+ /// Serialize the [`Predictions`] as a JSON byte vector
pub fn to_vec(&self) -> Result, serde_json::Error> {
serde_json::to_vec(&self)
}
}
+/// A singular predicted region.
#[derive(Debug, Serialize, Deserialize)]
pub struct Prediction {
+ /// The bounding box which encompases the region.
pub bbox: [f32; 4],
+ /// The class label predicted to exist within the region.
pub label: i64,
+ /// The confidence of the prediction.
pub score: f32,
}
impl Prediction {
+ /// Deserialize an instance [`Prediction`] from bytes of JSON text.
pub fn from_slice(v: &[u8]) -> Result {
serde_json::from_slice(v)
}
+ /// Serialize the [`Prediction`] as a JSON byte vector
pub fn to_vec(&self) -> Result, serde_json::Error> {
serde_json::to_vec(&self)
}
From 923e464c48f9f47ce56c16d51aa43cd3a9d93a10 Mon Sep 17 00:00:00 2001
From: Garry O'Donnell
Date: Wed, 5 Jul 2023 15:38:31 +0000
Subject: [PATCH 13/45] Add rabbitmq server to dev docker-compose
---
.devcontainer/docker-compose.yaml | 9 ++++++++-
1 file changed, 8 insertions(+), 1 deletion(-)
diff --git a/.devcontainer/docker-compose.yaml b/.devcontainer/docker-compose.yaml
index 8923c7e0..626ea28d 100644
--- a/.devcontainer/docker-compose.yaml
+++ b/.devcontainer/docker-compose.yaml
@@ -11,6 +11,7 @@ services:
environment:
OPA_URL: http://opa:8181
POSTGRES_URL: postgres://postgres:password@postgres
+ RABBITMQ_URL: amqp://rabbitmq:password@rabbitmq
opa:
image: docker.io/openpolicyagent/opa:0.53.1
@@ -27,4 +28,10 @@ services:
postgres:
image: docker.io/library/postgres:15.3-bookworm
environment:
- POSTGRES_PASSWORD: password
\ No newline at end of file
+ POSTGRES_PASSWORD: password
+
+ rabbitmq:
+ image: docker.io/library/rabbitmq:3.12.1
+ environment:
+ RABBITMQ_DEFAULT_USER: rabbitmq
+ RABBITMQ_DEFAULT_PASS: password
From a9c1f6b4bb1fa29603a5d77181de77e95bfcb927 Mon Sep 17 00:00:00 2001
From: Garry O'Donnell
Date: Wed, 5 Jul 2023 15:39:15 +0000
Subject: [PATCH 14/45] Add v4 feature to uuid dependency
---
Cargo.lock | 2 +-
Cargo.toml | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/Cargo.lock b/Cargo.lock
index 222a94e6..ad5c0338 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -612,7 +612,7 @@ name = "chimp_chomp"
version = "0.1.0"
dependencies = [
"chimp_protocol",
- "clap 4.3.10",
+ "clap 4.3.14",
"dotenvy",
"futures-lite",
"image",
diff --git a/Cargo.toml b/Cargo.toml
index 706c8322..38364352 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -35,4 +35,4 @@ tokio = { version = "1.29.1", features = ["macros", "rt-multi-thread"] }
tracing = { version = "0.1.37" }
tracing-subscriber = { version = "0.3.17", features = ["env-filter"] }
url = { version = "2.4.0" }
-uuid = { version = "1.4.1" }
+uuid = { version = "1.4.1", features = ["v4"] }
From b624e90dd87280bc78f361551b68a344931addac Mon Sep 17 00:00:00 2001
From: Garry O'Donnell
Date: Wed, 5 Jul 2023 15:39:38 +0000
Subject: [PATCH 15/45] Rename chimp chomp rabbitmq url argument
---
chimp_chomp/src/main.rs | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/chimp_chomp/src/main.rs b/chimp_chomp/src/main.rs
index bd656af9..69516ffe 100644
--- a/chimp_chomp/src/main.rs
+++ b/chimp_chomp/src/main.rs
@@ -21,7 +21,7 @@ struct Cli {
/// The path to the ONNX model file.
model: PathBuf,
/// The URL of the RabbitMQ server.
- rabbitmq_address: Url,
+ rabbitmq_url: Url,
/// The RabbitMQ channel on which jobs are assigned.
rabbitmq_channel: String, } @@ -36,7 +36,7 @@ async fn main() { let input_height = session.inputs[0].dimensions[2].unwrap(); let batch_size = session.inputs[0].dimensions[0].unwrap().try_into().unwrap(); - let rabbitmq_client = setup_rabbitmq_client(args.rabbitmq_address).await.unwrap(); + let rabbitmq_client = setup_rabbitmq_client(args.rabbitmq_url).await.unwrap(); let job_channel = rabbitmq_client.create_channel().await.unwrap(); let predictions_channel = rabbitmq_client.create_channel().await.unwrap(); let job_consumer = setup_job_consumer(job_channel, args.rabbitmq_channel) From 4239f2f0294c8964d03d176a0e5015c83074e9f2 Mon Sep 17 00:00:00 2001 From: Garry O'Donnell Date: Thu, 6 Jul 2023 14:34:47 +0000 Subject: [PATCH 16/45] Structure model inputs / outputs for batch inference --- chimp_chomp/src/inference.rs | 92 +++++++++++++++++++----------------- 1 file changed, 48 insertions(+), 44 deletions(-) diff --git a/chimp_chomp/src/inference.rs b/chimp_chomp/src/inference.rs index ea952463..43ca1653 100644 --- a/chimp_chomp/src/inference.rs +++ b/chimp_chomp/src/inference.rs @@ -1,6 +1,6 @@ use chimp_protocol::{Prediction, Predictions}; use itertools::{izip, Itertools}; -use ndarray::{ArrayBase, Axis, Dim, Ix2, Ix3, IxDynImpl, ViewRepr}; +use ndarray::{ArrayBase, Axis, Dim, Ix1, Ix2, IxDynImpl, ViewRepr}; use ort::{ tensor::{FromArray, InputTensor}, Environment, ExecutionProvider, GraphOptimizationLevel, OrtError, Session, SessionBuilder, @@ -25,52 +25,56 @@ pub fn setup_inference_session(model_path: impl AsRef) -> Result, Dim>], + batch_size: usize, ) -> Vec { - let input = InputTensor::from_array(ndarray::concatenate(Axis(0), images).unwrap()); + let images = images + .iter() + .cloned() + .chain(std::iter::repeat(images[0].clone()).take(batch_size - images.len())) + .collect::>(); + let input = InputTensor::from_array(ndarray::concatenate(Axis(0), &images).unwrap()); let outputs = session.run(vec![input]).unwrap(); - let bboxes = outputs[0] - .try_extract::() - .unwrap() - .view() - .to_owned() - .into_dimensionality::() - .unwrap(); - let labels = outputs[1] - .try_extract::() - .unwrap() - .view() - .to_owned() - .into_dimensionality::() - .unwrap(); - let scores = outputs[2] - .try_extract::() - .unwrap() - .view() - .to_owned() - .into_dimensionality::() - .unwrap(); + outputs + .into_iter() + .tuples() + .map(|(bboxes, labels, scores, _)| { + let bboxes = bboxes + .try_extract::() + .unwrap() + .view() + .to_owned() + .into_dimensionality::() + .unwrap(); + let labels = labels + .try_extract::() + .unwrap() + .view() + .to_owned() + .into_dimensionality::() + .unwrap(); + let scores = scores + .try_extract::() + .unwrap() + .view() + .to_owned() + .into_dimensionality::() + .unwrap(); - izip!( - bboxes.outer_iter(), - labels.outer_iter(), - scores.outer_iter() - ) - .map(|(bboxes, labels, scores)| { - Predictions( - izip!( - bboxes.outer_iter(), - labels.to_vec().iter(), - scores.to_vec().iter() + Predictions( + izip!( + bboxes.outer_iter(), + labels.to_vec().iter(), + scores.to_vec().iter() + ) + .map(|(bbox, &label, &score)| Prediction { + bbox: bbox.to_vec().try_into().unwrap(), + label, + score, + }) + .collect(), ) - .map(|(bbox, &label, &score)| Prediction { - bbox: bbox.to_vec().try_into().unwrap(), - label, - score, - }) - .collect(), - ) - }) - .collect() + }) + .collect() } pub async fn inference_worker( @@ -88,7 +92,7 @@ pub async fn inference_worker( .into_iter() .for_each(|jobs| { let (images, prediction_channels) = jobs.into_iter().unzip::<_, _, Vec<_>, Vec<_>>(); - 
let predictions = do_inference(&session, &images); + let predictions = do_inference(&session, &images, batch_size); izip!(predictions.into_iter(), prediction_channels.into_iter()).for_each( |(predictions, prediction_channel)| { prediction_tx From d28feff1a636e47de867dfc2485f8076ab0f115d Mon Sep 17 00:00:00 2001 From: Garry O'Donnell Date: Wed, 12 Jul 2023 15:50:24 +0000 Subject: [PATCH 17/45] Add opencv to development container --- .devcontainer/Dockerfile | 1 + 1 file changed, 1 insertion(+) diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile index abcf83bc..762d6cfd 100644 --- a/.devcontainer/Dockerfile +++ b/.devcontainer/Dockerfile @@ -5,4 +5,5 @@ RUN rustup component add rustfmt clippy RUN apt-get update \ && apt-get install --yes --no-install-recommends \ sqlite3 pre-commit \ + libopencv-dev clang libclang-dev \ && rm -rf /var/lib/apt/lists/* From b86cb8c27e33d08e5ff92344a0cc1b70c443b6b5 Mon Sep 17 00:00:00 2001 From: Garry O'Donnell Date: Wed, 12 Jul 2023 15:51:41 +0000 Subject: [PATCH 18/45] Use OpenCV for image loading & processing --- Cargo.lock | 326 ++++++++++------------------------------- chimp_chomp/Cargo.toml | 8 +- 2 files changed, 81 insertions(+), 253 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ad5c0338..f898218e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -532,12 +532,6 @@ version = "0.21.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "604178f6c5c21f02dc555784810edfb88d34ac2c73b2eae109655649ee73ce3d" -[[package]] -name = "bit_field" -version = "0.10.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc827186963e592360843fb5ba4b973e145841266c1357f7180c43526f2e5b61" - [[package]] name = "bitflags" version = "1.3.2" @@ -574,12 +568,6 @@ version = "3.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a3e2c3daef883ecc1b5d58c15adae93470a91d425f3532ba1695849656af3fc1" -[[package]] -name = "bytemuck" -version = "1.13.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17febce684fd15d89027105661fec94afb475cb995fbc59d2865198446ba2eea" - [[package]] name = "byteorder" version = "1.4.3" @@ -600,6 +588,9 @@ name = "cc" version = "1.0.79" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "50d30906286121d95be3d479533b458f87493b30a4b5f79a607db8f5d11aa91f" +dependencies = [ + "jobserver", +] [[package]] name = "cfg-if" @@ -615,11 +606,10 @@ dependencies = [ "clap 4.3.14", "dotenvy", "futures-lite", - "image", "itertools", "lapin", "ndarray", - "nshare", + "opencv", "ort", "tokio", "url", @@ -632,6 +622,7 @@ version = "0.1.0" dependencies = [ "serde", "serde_json", + "uuid", ] [[package]] @@ -668,6 +659,26 @@ dependencies = [ "phf_codegen", ] +[[package]] +name = "clang" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "84c044c781163c001b913cd018fc95a628c50d0d2dfea8bca77dad71edb16e37" +dependencies = [ + "clang-sys", + "libc", +] + +[[package]] +name = "clang-sys" +version = "1.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c688fc74432808e3eb684cae8830a86be1d66a2bd58e1f248ed0960a590baf6f" +dependencies = [ + "glob", + "libc", +] + [[package]] name = "clap" version = "3.2.25" @@ -745,12 +756,6 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2da6da31387c7e4ef160ffab6d5e7f00c42626fe39aea70a7b0f1773f7dd6c1b" -[[package]] -name = "color_quant" -version = "1.1.0" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d7b894f5411737b7867f4827955924d7c254fc9f4d91a6aad6b097804b1018b" - [[package]] name = "colorchoice" version = "1.0.0" @@ -812,40 +817,6 @@ dependencies = [ "cfg-if", ] -[[package]] -name = "crossbeam-channel" -version = "0.5.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a33c2bf77f2df06183c3aa30d1e96c0695a313d4f9c453cc3762a6db39f99200" -dependencies = [ - "cfg-if", - "crossbeam-utils", -] - -[[package]] -name = "crossbeam-deque" -version = "0.8.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ce6fd6f855243022dcecf8702fef0c297d4338e226845fe067f6341ad9fa0cef" -dependencies = [ - "cfg-if", - "crossbeam-epoch", - "crossbeam-utils", -] - -[[package]] -name = "crossbeam-epoch" -version = "0.9.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae211234986c545741a7dc064309f67ee1e5ad243d0e48335adc0484d960bcc7" -dependencies = [ - "autocfg", - "cfg-if", - "crossbeam-utils", - "memoffset", - "scopeguard", -] - [[package]] name = "crossbeam-queue" version = "0.3.8" @@ -865,12 +836,6 @@ dependencies = [ "cfg-if", ] -[[package]] -name = "crunchy" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" - [[package]] name = "crypto-common" version = "0.1.6" @@ -978,6 +943,12 @@ version = "0.15.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1aaf95b3e5c8f23aa320147307562d361db0ae0d51242340f558153b4eb2439b" +[[package]] +name = "dunce" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56ce8c6da7551ec6c462cbaf3bfbc75131ebbfa1c944aeaa9dab51ca1c5f0c3b" + [[package]] name = "either" version = "1.8.1" @@ -1035,22 +1006,6 @@ dependencies = [ "async-trait", ] -[[package]] -name = "exr" -version = "1.6.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85a7b44a196573e272e0cf0bcf130281c71e9a0c67062954b3323fd364bfdac9" -dependencies = [ - "bit_field", - "flume", - "half", - "lebe", - "miniz_oxide 0.7.1", - "rayon-core", - "smallvec", - "zune-inflate", -] - [[package]] name = "fastrand" version = "1.9.0" @@ -1060,15 +1015,6 @@ dependencies = [ "instant", ] -[[package]] -name = "fdeflate" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d329bdeac514ee06249dabc27877490f17f5d371ec693360768b838e19f3ae10" -dependencies = [ - "simd-adler32", -] - [[package]] name = "filetime" version = "0.2.21" @@ -1099,7 +1045,6 @@ checksum = "1657b4441c3403d9f7b3409e47575237dac27b1b5726df654a6ecbf92f0f7577" dependencies = [ "futures-core", "futures-sink", - "nanorand", "pin-project", "spin 0.9.8", ] @@ -1266,20 +1211,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "be4136b2a15dd319360be1c07d9933517ccf0be8f16bf62a3bee4f0d618df427" dependencies = [ "cfg-if", - "js-sys", "libc", "wasi", - "wasm-bindgen", -] - -[[package]] -name = "gif" -version = "0.12.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "80792593675e051cf94a4b111980da2ba60d4a83e43e0048c5693baab3977045" -dependencies = [ - "color_quant", - "weezl", ] [[package]] @@ -1288,6 +1221,12 @@ version = "0.27.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b6c80984affa11d98d1b88b66ac8853f143217b399d3c74116778ff8fdb4ed2e" +[[package]] +name = "glob" +version = 
"0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" + [[package]] name = "graphql_endpoints" version = "0.1.0" @@ -1328,15 +1267,6 @@ dependencies = [ "tracing", ] -[[package]] -name = "half" -version = "2.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02b4af3693f1b705df946e9fe5631932443781d0aabb423b62fcd4d73f6d2fd0" -dependencies = [ - "crunchy", -] - [[package]] name = "handlebars" version = "4.3.7" @@ -1568,25 +1498,6 @@ dependencies = [ "unicode-normalization", ] -[[package]] -name = "image" -version = "0.24.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "527909aa81e20ac3a44803521443a765550f09b5130c2c2fa1ea59c2f8f50a3a" -dependencies = [ - "bytemuck", - "byteorder", - "color_quant", - "exr", - "gif", - "jpeg-decoder", - "num-rational", - "num-traits", - "png", - "qoi", - "tiff", -] - [[package]] name = "indexmap" version = "1.9.3" @@ -1662,12 +1573,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "453ad9f582a441959e5f0d088b02ce04cfe8d51a8eaf077f12ac6d3e94164ca6" [[package]] -name = "jpeg-decoder" -version = "0.3.0" +name = "jobserver" +version = "0.1.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc0000e42512c92e31c2252315bda326620a4e034105e900c98ec492fa077b3e" +checksum = "936cfd212a0155903bcbc060e316fb6cc7cbf2e1907329391ebadc1fe0ce77c2" dependencies = [ - "rayon", + "libc", ] [[package]] @@ -1707,12 +1618,6 @@ version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" -[[package]] -name = "lebe" -version = "0.5.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03087c2bad5e1034e8cace5926dec053fb3790248370865f5117a7d0213354c8" - [[package]] name = "lexical" version = "6.1.1" @@ -1865,15 +1770,6 @@ version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" -[[package]] -name = "memoffset" -version = "0.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a634b1c61a95585bd15607c6ab0c4e5b226e695ff2800ba0cdccddf208c406c" -dependencies = [ - "autocfg", -] - [[package]] name = "mime" version = "0.3.17" @@ -1902,7 +1798,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e7810e0be55b428ada41041c41f32c9f1a42817901b4ccf45fa3d4b6561e74c7" dependencies = [ "adler", - "simd-adler32", ] [[package]] @@ -1934,15 +1829,6 @@ dependencies = [ "version_check", ] -[[package]] -name = "nanorand" -version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a51313c5820b0b02bd422f4b44776fbf47961755c74ce64afc73bfad10226c3" -dependencies = [ - "getrandom", -] - [[package]] name = "native-tls" version = "0.2.11" @@ -1984,16 +1870,6 @@ dependencies = [ "minimal-lexical", ] -[[package]] -name = "nshare" -version = "0.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4447657cd40e3107416ec4f2ac3e61a18781b00061789e3b8f4bbcbccb26c4c6" -dependencies = [ - "image", - "ndarray", -] - [[package]] name = "nu-ansi-term" version = "0.46.0" @@ -2023,17 +1899,6 @@ dependencies = [ "num-traits", ] -[[package]] -name = "num-rational" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"0638a1c9d0a3c0914158145bc76cff373a75a627e6ecbfb71cbe6f453a5a19b0" -dependencies = [ - "autocfg", - "num-integer", - "num-traits", -] - [[package]] name = "num-traits" version = "0.2.15" @@ -2080,6 +1945,39 @@ dependencies = [ "url", ] +[[package]] +name = "opencv" +version = "0.82.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79290f5f138b26637cae0ae243d77de871a096e334d3fca22f5ddf31ab6f4cc5" +dependencies = [ + "cc", + "dunce", + "jobserver", + "libc", + "num-traits", + "once_cell", + "opencv-binding-generator", + "pkg-config", + "semver", + "shlex", + "vcpkg", +] + +[[package]] +name = "opencv-binding-generator" +version = "0.66.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be5f640bda28b478629f525e8525601586a2a2b9403a4b8f2264fa5fcfebe6be" +dependencies = [ + "clang", + "clang-sys", + "dunce", + "once_cell", + "percent-encoding", + "regex", +] + [[package]] name = "openssl" version = "0.10.55" @@ -2405,19 +2303,6 @@ version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "26072860ba924cbfa98ea39c8c19b4dd6a4a25423dbdf219c1eca91aa0cf6964" -[[package]] -name = "png" -version = "0.17.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59871cc5b6cce7eaccca5a802b4173377a1c2ba90654246789a8fa2334426d11" -dependencies = [ - "bitflags", - "crc32fast", - "fdeflate", - "flate2", - "miniz_oxide 0.7.1", -] - [[package]] name = "polling" version = "2.8.0" @@ -2483,15 +2368,6 @@ dependencies = [ "unicode-ident", ] -[[package]] -name = "qoi" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f6d64c71eb498fe9eae14ce4ec935c555749aef511cca85b5568910d6e48001" -dependencies = [ - "bytemuck", -] - [[package]] name = "quote" version = "1.0.29" @@ -2537,28 +2413,6 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" -[[package]] -name = "rayon" -version = "1.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d2df5196e37bcc87abebc0053e20787d73847bb33134a69841207dd0a47f03b" -dependencies = [ - "either", - "rayon-core", -] - -[[package]] -name = "rayon-core" -version = "1.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4b8f95bd6966f5c87776639160a66bd8ab9895d9d4ab01ddba9fc60661aebe8d" -dependencies = [ - "crossbeam-channel", - "crossbeam-deque", - "crossbeam-utils", - "num_cpus", -] - [[package]] name = "reactor-trait" version = "1.1.0" @@ -3044,10 +2898,10 @@ dependencies = [ ] [[package]] -name = "simd-adler32" -version = "0.3.5" +name = "shlex" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "238abfbb77c1915110ad968465608b68e869e0772622c9656714e73e5a1a522f" +checksum = "43b2853a4d09f215c24cc5489c992ce46052d359b5109343cbafbf26bc62f8a3" [[package]] name = "siphasher" @@ -3362,17 +3216,6 @@ dependencies = [ "once_cell", ] -[[package]] -name = "tiff" -version = "0.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7449334f9ff2baf290d55d73983a7d6fa15e01198faef72af07e2a8db851e471" -dependencies = [ - "flate2", - "jpeg-decoder", - "weezl", -] - [[package]] name = "tinyvec" version = "1.6.0" @@ -3854,12 +3697,6 @@ dependencies = [ "rustls-webpki", ] -[[package]] -name = "weezl" -version = "0.1.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"9193164d4de03a926d909d3bc7c30543cecb35400c02114792c2cae20d5e2dbb" - [[package]] name = "whoami" version = "1.4.1" @@ -4062,12 +3899,3 @@ dependencies = [ "crossbeam-utils", "flate2", ] - -[[package]] -name = "zune-inflate" -version = "0.2.54" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "73ab332fe2f6680068f3582b16a24f90ad7096d5d39b974d1c0aff0125116f02" -dependencies = [ - "simd-adler32", -] diff --git a/chimp_chomp/Cargo.toml b/chimp_chomp/Cargo.toml index 2b4a4f5b..0baa1bc7 100644 --- a/chimp_chomp/Cargo.toml +++ b/chimp_chomp/Cargo.toml @@ -6,17 +6,17 @@ edition = "2021" [dependencies] chimp_protocol = { path = "../chimp_protocol" } clap = { workspace = true } +derive_more = { workspace = true } dotenvy = { workspace = true } futures-lite = { version = "1.13.0" } -image = { version = "0.24.6" } itertools = { workspace = true } lapin = { version = "2.2.1", default-features = false, features = [ "native-tls", ] } ndarray = { version = "0.15.6" } -nshare = { version = "0.9.0", default-features = false, features = [ - "image", - "ndarray", +opencv = { version = "0.82.1", default-features = false, features = [ + "imgproc", + "imgcodecs", ] } ort = { version = "1.14.8", default-features = false, features = [ "download-binaries", From 8681062d8c8ebca84c5c6a52320f08782ca530ff Mon Sep 17 00:00:00 2001 From: Garry O'Donnell Date: Wed, 12 Jul 2023 15:52:35 +0000 Subject: [PATCH 19/45] Refine chimp protocol types --- chimp_protocol/Cargo.toml | 1 + chimp_protocol/src/lib.rs | 70 +++++++++++++++++++++++++-------------- 2 files changed, 47 insertions(+), 24 deletions(-) diff --git a/chimp_protocol/Cargo.toml b/chimp_protocol/Cargo.toml index 2cdabc37..fa03723b 100644 --- a/chimp_protocol/Cargo.toml +++ b/chimp_protocol/Cargo.toml @@ -6,3 +6,4 @@ edition = "2021" [dependencies] serde = { version = "1.0.166" } serde_json = "1.0.100" +uuid = { workspace = true, features = ["serde"] } diff --git a/chimp_protocol/src/lib.rs b/chimp_protocol/src/lib.rs index 2080e3fa..97ae4693 100644 --- a/chimp_protocol/src/lib.rs +++ b/chimp_protocol/src/lib.rs @@ -4,10 +4,13 @@ use serde::{Deserialize, Serialize}; use std::path::PathBuf; +use uuid::Uuid; /// A CHiMP job definition. -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct Job { + /// A unique identifier for the job, to be returned in the [`Response`]. + pub id: Uuid, /// The path of a file containing the image to perform inference on. pub file: PathBuf, /// The channel to send predictions to. @@ -15,22 +18,33 @@ pub struct Job { } impl Job { - /// Deserialize an instance [`Job`] from bytes of JSON text. + /// Deserialize an instance [`Request`] from bytes of JSON text. pub fn from_slice(v: &[u8]) -> Result { serde_json::from_slice(v) } - /// Serialize the [`Job`] as a JSON byte vector + /// Serialize the [`Request`] as a JSON byte vector pub fn to_vec(&self) -> Result, serde_json::Error> { serde_json::to_vec(&self) } } /// A set of predictions which apply to a single image. -#[derive(Debug, Serialize, Deserialize)] -pub struct Predictions(pub Vec); +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Response { + /// The unique identifier of the requesting [`Request`]. + pub job_id: Uuid, + /// The proposed point for solvent insertion. + pub insertion_point: Point, + /// The location of the well centroid and radius. + pub well_location: Circle, + /// A bounding box emcompasing the solvent. + pub drop: BBox, + /// A set of bounding boxes, each emcompasing a crystal. 
+ pub crystals: Vec, +} -impl Predictions { +impl Response { /// Deserialize an instance [`Predictions`] from bytes of JSON text. pub fn from_slice(v: &[u8]) -> Result { serde_json::from_slice(v) @@ -42,25 +56,33 @@ impl Predictions { } } -/// A singular predicted region. -#[derive(Debug, Serialize, Deserialize)] -pub struct Prediction { - /// The bounding box which encompases the region. - pub bbox: [f32; 4], - /// The class label predicted to exist within the region. - pub label: i64, - /// The confidence of the prediction. - pub score: f32, +/// A point in 2D space. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Point { + /// The position of the point in the X axis. + pub x: usize, + /// The position of the point in the Y axis. + pub y: usize, } -impl Prediction { - /// Deserialize an instance [`Prediction`] from bytes of JSON text. - pub fn from_slice(v: &[u8]) -> Result { - serde_json::from_slice(v) - } +/// A circle, defined by the center point and radius. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Circle { + /// The position of the circles center. + pub center: Point, + /// The radius of the circle. + pub radius: f32, +} - /// Serialize the [`Prediction`] as a JSON byte vector - pub fn to_vec(&self) -> Result, serde_json::Error> { - serde_json::to_vec(&self) - } +/// A bounding box which encompasing a region. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct BBox { + /// The position of the upper bound in the Y axis. + pub top: f32, + /// The position of the lower bound in the Y axis. + pub bottom: f32, + /// The position of the upper bound in the X axis. + pub right: f32, + /// The position of the lower bound in the X axis. + pub left: f32, } From f4cc22d42731f3ad8cbef6841f75e546b10f9589 Mon Sep 17 00:00:00 2001 From: Garry O'Donnell Date: Wed, 12 Jul 2023 15:53:02 +0000 Subject: [PATCH 20/45] Load images with OpenCV --- chimp_chomp/src/image_loading.rs | 82 +++++++++++++++++++++++++------- 1 file changed, 65 insertions(+), 17 deletions(-) diff --git a/chimp_chomp/src/image_loading.rs b/chimp_chomp/src/image_loading.rs index fe420a1d..2b3a3175 100644 --- a/chimp_chomp/src/image_loading.rs +++ b/chimp_chomp/src/image_loading.rs @@ -1,18 +1,66 @@ -use image::{imageops::FilterType, ImageFormat}; -use ndarray::{ArrayBase, Axis, Dim, IxDynImpl, OwnedRepr}; -use nshare::ToNdarray3; -use std::{fs::File, io::BufReader, path::Path}; - -pub type Image = ArrayBase, Dim>; - -pub fn load_image(path: impl AsRef, width: u32, height: u32) -> Image { - let file = File::open(path).unwrap(); - let reader = BufReader::new(file); - image::load(reader, ImageFormat::Jpeg) - .unwrap() - .resize_exact(width, height, FilterType::Triangle) - .into_rgb32f() - .into_ndarray3() - .insert_axis(Axis(0)) - .into_dyn() +use derive_more::Deref; +use ndarray::{Array, Ix3}; +use opencv::{ + core::{Size_, Vec3f, CV_32FC3}, + imgcodecs::{imread, IMREAD_COLOR}, + imgproc::{cvt_color, resize, COLOR_BGR2GRAY, COLOR_BGR2RGB, INTER_LINEAR}, + prelude::{Mat, MatTraitConst}, +}; +use std::path::Path; + +#[derive(Debug, Deref)] +pub struct WellImage(Mat); + +#[derive(Debug, Deref)] +pub struct ChimpImage(Array); + +pub fn load_image(path: impl AsRef, width: u32, height: u32) -> (ChimpImage, WellImage) { + let image = imread(path.as_ref().to_str().unwrap(), IMREAD_COLOR).unwrap(); + + let mut resized_image = Mat::default(); + resize( + &image, + &mut resized_image, + Size_ { + width: width as i32, + height: height as i32, + }, + 0.0, + 0.0, + INTER_LINEAR, + ) + .unwrap(); + + let mut 
well_image = Mat::default(); + cvt_color(&resized_image, &mut well_image, COLOR_BGR2GRAY, 0).unwrap(); + + let mut resized_rgb_image = Mat::default(); + cvt_color(&resized_image, &mut resized_rgb_image, COLOR_BGR2RGB, 0).unwrap(); + let mut resized_rgb_f32_image = Mat::default(); + + resized_rgb_image + .convert_to( + &mut resized_rgb_f32_image, + CV_32FC3, + f64::from(std::u8::MAX).recip(), + 0.0, + ) + .unwrap(); + let chimp_image = Array::from_iter( + resized_rgb_f32_image + .iter::() + .unwrap() + .flat_map(|(_, pixel)| pixel), + ) + .into_shape(( + resized_rgb_f32_image.mat_size()[0] as usize, + resized_rgb_f32_image.mat_size()[1] as usize, + resized_rgb_f32_image.channels() as usize, + )) + .unwrap() + .permuted_axes((2, 0, 1)) + .as_standard_layout() + .to_owned(); + + (ChimpImage(chimp_image), WellImage(well_image)) } From faa2f188e2a01944f1360185e841d8930f5a9031 Mon Sep 17 00:00:00 2001 From: Garry O'Donnell Date: Wed, 12 Jul 2023 15:53:36 +0000 Subject: [PATCH 21/45] Fix inference worker --- chimp_chomp/src/inference.rs | 106 ++++++++++++++++++++--------------- 1 file changed, 61 insertions(+), 45 deletions(-) diff --git a/chimp_chomp/src/inference.rs b/chimp_chomp/src/inference.rs index 43ca1653..50911201 100644 --- a/chimp_chomp/src/inference.rs +++ b/chimp_chomp/src/inference.rs @@ -1,14 +1,19 @@ -use chimp_protocol::{Prediction, Predictions}; +use chimp_protocol::Job; use itertools::{izip, Itertools}; -use ndarray::{ArrayBase, Axis, Dim, Ix1, Ix2, IxDynImpl, ViewRepr}; +use ndarray::{Array1, Array2, Array3, Axis, Ix1, Ix2, Ix4}; use ort::{ tensor::{FromArray, InputTensor}, Environment, ExecutionProvider, GraphOptimizationLevel, OrtError, Session, SessionBuilder, }; -use std::{path::Path, sync::Arc}; -use tokio::sync::mpsc::{Receiver, UnboundedSender}; +use std::{ops::Deref, path::Path, sync::Arc}; +use tokio::sync::mpsc::{error::TryRecvError, Receiver, UnboundedSender}; -use crate::image_loading::Image; +use crate::image_loading::ChimpImage; + +pub type BBoxes = Array2; +pub type Labels = Array1; +pub type Scores = Array1; +pub type Masks = Array3; pub fn setup_inference_session(model_path: impl AsRef) -> Result { let environment = Arc::new( @@ -24,20 +29,22 @@ pub fn setup_inference_session(model_path: impl AsRef) -> Result, Dim>], + images: &[ChimpImage], batch_size: usize, -) -> Vec { - let images = images +) -> Vec<(BBoxes, Labels, Scores, Masks)> { + let batch_images = images .iter() - .cloned() - .chain(std::iter::repeat(images[0].clone()).take(batch_size - images.len())) + .map(|image| image.deref().view()) + .cycle() + .take(batch_size) .collect::>(); - let input = InputTensor::from_array(ndarray::concatenate(Axis(0), &images).unwrap()); + let input = InputTensor::from_array(ndarray::stack(Axis(0), &batch_images).unwrap().into_dyn()); let outputs = session.run(vec![input]).unwrap(); outputs .into_iter() + .take(images.len() * 4) .tuples() - .map(|(bboxes, labels, scores, _)| { + .map(|(bboxes, labels, scores, masks)| { let bboxes = bboxes .try_extract::() .unwrap() @@ -59,20 +66,16 @@ fn do_inference( .to_owned() .into_dimensionality::() .unwrap(); + let masks = masks + .try_extract::() + .unwrap() + .view() + .to_owned() + .into_dimensionality::() + .unwrap() + .remove_axis(Axis(1)); - Predictions( - izip!( - bboxes.outer_iter(), - labels.to_vec().iter(), - scores.to_vec().iter() - ) - .map(|(bbox, &label, &score)| Prediction { - bbox: bbox.to_vec().try_into().unwrap(), - label, - score, - }) - .collect(), - ) + (bboxes, labels, scores, masks) }) .collect() } @@ 
-80,25 +83,38 @@ fn do_inference( pub async fn inference_worker( session: Session, batch_size: usize, - mut image_rx: Receiver<(Image, String)>, - prediction_tx: UnboundedSender<(Predictions, String)>, + mut image_rx: Receiver<(ChimpImage, Job)>, + prediction_tx: UnboundedSender<(BBoxes, Labels, Scores, Masks, Job)>, ) { - image_rx - .recv() - .await - .iter() - .map(|(image, predictions_channel)| (image.view(), predictions_channel)) - .chunks(batch_size) - .into_iter() - .for_each(|jobs| { - let (images, prediction_channels) = jobs.into_iter().unzip::<_, _, Vec<_>, Vec<_>>(); - let predictions = do_inference(&session, &images, batch_size); - izip!(predictions.into_iter(), prediction_channels.into_iter()).for_each( - |(predictions, prediction_channel)| { - prediction_tx - .send((predictions, prediction_channel.clone())) - .unwrap() - }, - ) - }); + let mut images = Vec::new(); + let mut jobs = Vec::::new(); + loop { + let (image, job) = image_rx.recv().await.unwrap(); + println!("Got image for job: {job:?}"); + images.push(image); + jobs.push(job); + while images.len() < batch_size { + match image_rx.try_recv() { + Ok((image, job)) => { + images.push(image); + jobs.push(job); + Ok(()) + } + Err(TryRecvError::Empty) => break, + Err(TryRecvError::Disconnected) => Err(TryRecvError::Disconnected), + } + .unwrap(); + } + println!("Performing inference on {} images", images.len()); + let predictions = do_inference(&session, &images, batch_size); + izip!(predictions.into_iter(), jobs.iter()).for_each( + |((bboxes, labels, scores, masks), job)| { + prediction_tx + .send((bboxes, labels, scores, masks, job.clone())) + .unwrap(); + }, + ); + images.clear(); + jobs.clear(); + } } From 1583069f0a0107cd3184ff1ff2901f8a7dd8c8e3 Mon Sep 17 00:00:00 2001 From: Garry O'Donnell Date: Wed, 12 Jul 2023 15:54:06 +0000 Subject: [PATCH 22/45] Determine insertion point from masks --- chimp_chomp/src/postprocessing.rs | 97 +++++++++++++++++++++++++++++++ 1 file changed, 97 insertions(+) create mode 100644 chimp_chomp/src/postprocessing.rs diff --git a/chimp_chomp/src/postprocessing.rs b/chimp_chomp/src/postprocessing.rs new file mode 100644 index 00000000..e8323b27 --- /dev/null +++ b/chimp_chomp/src/postprocessing.rs @@ -0,0 +1,97 @@ +use crate::inference::{BBoxes, Labels, Masks}; +use chimp_protocol::{BBox, Job, Point}; +use itertools::izip; +use ndarray::{Array2, ArrayView, ArrayView2, Ix1}; +use opencv::{ + core::CV_8U, + imgproc::{distance_transform, DIST_L1, DIST_MASK_3}, + prelude::Mat, +}; +use tokio::sync::mpsc::UnboundedSender; + +#[derive(Debug)] +pub struct Contents { + pub insertion_point: Point, + pub drop: BBox, + pub crystals: Vec, +} + +const PREDICTION_THRESHOLD: f32 = 0.5; + +fn insertion_mask( + drop_mask: ArrayView2, + crystal_masks: Vec>, +) -> Array2 { + let mut mask = drop_mask.mapv(|prediction| prediction > PREDICTION_THRESHOLD); + crystal_masks.into_iter().for_each(|crystal_mask| { + mask.zip_mut_with(&crystal_mask, |valid, prediction| { + *valid &= *prediction < PREDICTION_THRESHOLD + }) + }); + mask +} + +fn optimal_insert_position(insertion_mask: Array2) -> Point { + let mask = Mat::from_exact_iter(insertion_mask.mapv(|pixel| pixel as u8).into_iter()).unwrap(); + let mut distances = Mat::default(); + distance_transform(&mask, &mut distances, DIST_L1, DIST_MASK_3, CV_8U).unwrap(); + let (furthest_point, _) = distances + .iter::() + .unwrap() + .max_by(|(_, a), (_, b)| a.cmp(b)) + .unwrap(); + Point { + x: furthest_point.x as usize, + y: furthest_point.y as usize, + } +} + +fn 
bbox_from_array(bbox: ArrayView) -> BBox { + BBox { + left: bbox[0], + top: bbox[1], + right: bbox[2], + bottom: bbox[3], + } +} + +fn find_drop_instance<'a>( + labels: &Labels, + bboxes: &BBoxes, + masks: &'a Masks, +) -> Option<(BBox, ArrayView2<'a, f32>)> { + izip!(labels, bboxes.outer_iter(), masks.outer_iter()) + .find_map(|(label, bbox, mask)| (*label == 1).then_some((bbox_from_array(bbox), mask))) +} + +fn find_crystal_instances<'a>( + labels: &Labels, + bboxes: &BBoxes, + masks: &'a Masks, +) -> Vec<(BBox, ArrayView2<'a, f32>)> { + izip!(labels, bboxes.outer_iter(), masks.outer_iter()) + .filter_map(|(label, bbox, mask)| (*label == 2).then_some((bbox_from_array(bbox), mask))) + .collect() +} + +pub async fn postprocess_inference( + bboxes: BBoxes, + labels: Labels, + masks: Masks, + job: Job, + contents_tx: UnboundedSender<(Contents, Job)>, +) { + println!("Postprocessing: {job:?}"); + let (drop, drop_mask) = find_drop_instance(&labels, &bboxes, &masks).unwrap(); + let (crystals, crystal_masks) = find_crystal_instances(&labels, &bboxes, &masks) + .into_iter() + .unzip(); + let insertion_mask = insertion_mask(drop_mask, crystal_masks); + let insertion_point = optimal_insert_position(insertion_mask); + let contents = Contents { + drop, + crystals, + insertion_point, + }; + contents_tx.send((contents, job)).unwrap(); +} From 8acc21f273c6aee8b28e48eb247ed7aefbe36e97 Mon Sep 17 00:00:00 2001 From: Garry O'Donnell Date: Wed, 12 Jul 2023 15:54:20 +0000 Subject: [PATCH 23/45] Find well centers with Hough transforms --- chimp_chomp/src/well_centering.rs | 44 +++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) create mode 100644 chimp_chomp/src/well_centering.rs diff --git a/chimp_chomp/src/well_centering.rs b/chimp_chomp/src/well_centering.rs new file mode 100644 index 00000000..5737dc87 --- /dev/null +++ b/chimp_chomp/src/well_centering.rs @@ -0,0 +1,44 @@ +use std::ops::Deref; + +use crate::image_loading::WellImage; +use chimp_protocol::{Circle, Job, Point}; +use opencv::{ + core::{Vec4f, Vector}, + imgproc::{hough_circles, HOUGH_GRADIENT}, + prelude::MatTraitConst, +}; +use tokio::sync::mpsc::UnboundedSender; + +pub async fn find_well_center( + image: WellImage, + job: Job, + well_location_tx: UnboundedSender<(Circle, Job)>, +) { + println!("Finding Well Center for {job:?}"); + let max_side = *image.deref().mat_size().iter().max().unwrap(); + let mut circles = Vector::::new(); + hough_circles( + &*image, + &mut circles, + HOUGH_GRADIENT, + 1.0, + 1.0, + 10.0, + 10.0, + max_side / 2, + max_side, + ) + .unwrap(); + let well_location = circles + .into_iter() + .max_by(|&a, &b| a[3].total_cmp(&b[3])) + .unwrap(); + let well_location = Circle { + center: Point { + x: well_location[0] as usize, + y: well_location[1] as usize, + }, + radius: well_location[2], + }; + well_location_tx.send((well_location, job)).unwrap() +} From 6184a701110c1cc1d611aa9e8b07623a1e59e396 Mon Sep 17 00:00:00 2001 From: Garry O'Donnell Date: Wed, 12 Jul 2023 15:54:40 +0000 Subject: [PATCH 24/45] Update job interface --- chimp_chomp/src/jobs.rs | 79 ++++++++++++++++++++++++----------------- 1 file changed, 46 insertions(+), 33 deletions(-) diff --git a/chimp_chomp/src/jobs.rs b/chimp_chomp/src/jobs.rs index 0b2bf6c0..afe755d4 100644 --- a/chimp_chomp/src/jobs.rs +++ b/chimp_chomp/src/jobs.rs @@ -1,12 +1,15 @@ -use crate::image_loading::{load_image, Image}; -use chimp_protocol::{Job, Predictions}; -use futures_lite::StreamExt; +use crate::{ + image_loading::{load_image, ChimpImage, WellImage}, + 
postprocessing::Contents, +}; +use chimp_protocol::{Circle, Job, Response}; use lapin::{ + message::Delivery, options::{BasicAckOptions, BasicConsumeOptions, BasicPublishOptions}, types::FieldTable, BasicProperties, Channel, Connection, Consumer, }; -use tokio::sync::mpsc::{Sender, UnboundedReceiver}; +use tokio::sync::mpsc::Sender; use url::Url; use uuid::Uuid; @@ -30,42 +33,52 @@ pub async fn setup_job_consumer( .await } -pub async fn job_consumption_worker( - mut job_consumer: Consumer, +pub async fn consume_job( + delivery: Result, input_width: u32, input_height: u32, - image_tx: Sender<(Image, String)>, + chimp_image_tx: Sender<(ChimpImage, Job)>, + well_image_tx: Sender<(WellImage, Job)>, ) { - while let Some(delivery) = job_consumer.next().await { - let delievry = delivery.unwrap(); - delievry.ack(BasicAckOptions::default()).await.unwrap(); + let delievry = delivery.unwrap(); + delievry.ack(BasicAckOptions::default()).await.unwrap(); - let job = Job::from_slice(&delievry.data).unwrap(); - let image = load_image(job.file, input_width, input_height); + let job = Job::from_slice(&delievry.data).unwrap(); + println!("Consumed Job: {job:?}"); + let (chimp_image, well_image) = load_image(job.file.clone(), input_width, input_height); - image_tx - .send((image, job.predictions_channel)) - .await - .unwrap(); - } + chimp_image_tx + .send((chimp_image, job.clone())) + .await + .unwrap(); + well_image_tx.send((well_image, job)).await.unwrap(); } -pub async fn predictions_producer_worker( - mut prediction_rx: UnboundedReceiver<(Predictions, String)>, +pub async fn produce_response( + contents: Contents, + well_location: Circle, + job: Job, rabbitmq_channel: Channel, ) { - while let Some((predictions, predictions_channel)) = prediction_rx.recv().await { - rabbitmq_channel - .basic_publish( - "", - &predictions_channel, - BasicPublishOptions::default(), - &predictions.to_vec().unwrap(), - BasicProperties::default(), - ) - .await - .unwrap() - .await - .unwrap(); - } + println!("Producing response for: {job:?}"); + rabbitmq_channel + .basic_publish( + "", + &job.predictions_channel, + BasicPublishOptions::default(), + &Response { + job_id: job.id, + insertion_point: contents.insertion_point, + well_location, + drop: contents.drop, + crystals: contents.crystals, + } + .to_vec() + .unwrap(), + BasicProperties::default(), + ) + .await + .unwrap() + .await + .unwrap(); } From 282a05691cd169b488474ae6f00f11d889f487af Mon Sep 17 00:00:00 2001 From: Garry O'Donnell Date: Wed, 12 Jul 2023 15:55:06 +0000 Subject: [PATCH 25/45] Spawn workers from main thread --- chimp_chomp/src/main.rs | 80 +++++++++++++++++++++++++++++------------ 1 file changed, 57 insertions(+), 23 deletions(-) diff --git a/chimp_chomp/src/main.rs b/chimp_chomp/src/main.rs index 69516ffe..6deba65a 100644 --- a/chimp_chomp/src/main.rs +++ b/chimp_chomp/src/main.rs @@ -5,14 +5,19 @@ mod image_loading; mod inference; mod jobs; +mod postprocessing; +mod well_centering; -use clap::Parser; -use inference::{inference_worker, setup_inference_session}; -use jobs::{ - job_consumption_worker, predictions_producer_worker, setup_job_consumer, setup_rabbitmq_client, +use crate::{ + inference::{inference_worker, setup_inference_session}, + jobs::{consume_job, produce_response, setup_job_consumer, setup_rabbitmq_client}, + postprocessing::postprocess_inference, + well_centering::find_well_center, }; -use std::path::PathBuf; -use tokio::task::JoinSet; +use clap::Parser; +use futures_lite::StreamExt; +use std::{collections::HashMap, path::PathBuf}; +use 
tokio::{select, task::JoinSet}; use url::Url; #[derive(Debug, Parser)] @@ -26,10 +31,11 @@ struct Cli { rabbitmq_channel: String, } -#[tokio::main] +#[tokio::main(flavor = "multi_thread", worker_threads = 4)] async fn main() { dotenvy::dotenv().ok(); let args = Cli::parse(); + opencv::core::set_num_threads(0).unwrap(); let session = setup_inference_session(args.model).unwrap(); let input_width = session.inputs[0].dimensions[3].unwrap(); @@ -38,34 +44,62 @@ async fn main() { let rabbitmq_client = setup_rabbitmq_client(args.rabbitmq_url).await.unwrap(); let job_channel = rabbitmq_client.create_channel().await.unwrap(); - let predictions_channel = rabbitmq_client.create_channel().await.unwrap(); - let job_consumer = setup_job_consumer(job_channel, args.rabbitmq_channel) + let response_channel = rabbitmq_client.create_channel().await.unwrap(); + let mut job_consumer = setup_job_consumer(job_channel, args.rabbitmq_channel) .await .unwrap(); - let (image_tx, image_rx) = tokio::sync::mpsc::channel(batch_size); - let (prediction_tx, prediction_rx) = tokio::sync::mpsc::unbounded_channel(); + let (chimp_image_tx, chimp_image_rx) = tokio::sync::mpsc::channel(batch_size); + let (well_image_tx, mut well_image_rx) = tokio::sync::mpsc::channel(batch_size); + let (well_location_tx, mut well_location_rx) = tokio::sync::mpsc::unbounded_channel(); + let (prediction_tx, mut prediction_rx) = tokio::sync::mpsc::unbounded_channel(); + let (contents_tx, mut contents_rx) = tokio::sync::mpsc::unbounded_channel(); let mut tasks = JoinSet::new(); tasks.spawn(inference_worker( session, batch_size, - image_rx, + chimp_image_rx, prediction_tx, )); - tasks.spawn(job_consumption_worker( - job_consumer, - input_width, - input_height, - image_tx, - )); + let mut well_locations = HashMap::new(); + let mut well_contents = HashMap::new(); - tasks.spawn(predictions_producer_worker( - prediction_rx, - predictions_channel, - )); + loop { + select! 
{ + biased; + + Some(delivery) = job_consumer.next() => { + tasks.spawn(consume_job(delivery, input_width, input_height, chimp_image_tx.clone(), well_image_tx.clone())); + } + + Some((well_image, job)) = well_image_rx.recv() => { + tasks.spawn(find_well_center(well_image, job, well_location_tx.clone())); + } + + Some((bboxes, labels, _, masks, job)) = prediction_rx.recv() => { + tasks.spawn(postprocess_inference(bboxes, labels, masks, job, contents_tx.clone())); + } + + Some((well_location, job)) = well_location_rx.recv() => { + if let Some(contents) = well_contents.remove(&job.id) { + tasks.spawn(produce_response(contents, well_location, job, response_channel.clone())); + } else { + well_locations.insert(job.id, well_location); + } + } + + Some((contents, job)) = contents_rx.recv() => { + if let Some(well_location) = well_locations.remove(&job.id) { + tasks.spawn(produce_response(contents, well_location, job, response_channel.clone())); + } else { + well_contents.insert(job.id, contents); + } + } - tasks.join_next().await; + else => break + } + } } From a8dc2035b7f84e1e1e1f5fdd6457bdda04fffb52 Mon Sep 17 00:00:00 2001 From: Garry O'Donnell Date: Wed, 12 Jul 2023 17:40:55 +0000 Subject: [PATCH 26/45] Improve inference logging --- chimp_chomp/src/inference.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/chimp_chomp/src/inference.rs b/chimp_chomp/src/inference.rs index 50911201..b43a41ca 100644 --- a/chimp_chomp/src/inference.rs +++ b/chimp_chomp/src/inference.rs @@ -90,7 +90,6 @@ pub async fn inference_worker( let mut jobs = Vec::::new(); loop { let (image, job) = image_rx.recv().await.unwrap(); - println!("Got image for job: {job:?}"); images.push(image); jobs.push(job); while images.len() < batch_size { @@ -105,7 +104,7 @@ pub async fn inference_worker( } .unwrap(); } - println!("Performing inference on {} images", images.len()); + println!("CHiMP Inference ({}): {:?}", images.len(), jobs); let predictions = do_inference(&session, &images, batch_size); izip!(predictions.into_iter(), jobs.iter()).for_each( |((bboxes, labels, scores, masks), job)| { From 0e7ec626b44ee5f984a9804e639e10ae3a6ec1dc Mon Sep 17 00:00:00 2001 From: Garry O'Donnell Date: Wed, 12 Jul 2023 17:41:44 +0000 Subject: [PATCH 27/45] Fix well centering --- chimp_chomp/src/image_loading.rs | 33 ++++++++++++++++++++----------- chimp_chomp/src/well_centering.rs | 15 +++++++------- 2 files changed, 29 insertions(+), 19 deletions(-) diff --git a/chimp_chomp/src/image_loading.rs b/chimp_chomp/src/image_loading.rs index 2b3a3175..fa685466 100644 --- a/chimp_chomp/src/image_loading.rs +++ b/chimp_chomp/src/image_loading.rs @@ -14,26 +14,18 @@ pub struct WellImage(Mat); #[derive(Debug, Deref)] pub struct ChimpImage(Array); -pub fn load_image(path: impl AsRef, width: u32, height: u32) -> (ChimpImage, WellImage) { - let image = imread(path.as_ref().to_str().unwrap(), IMREAD_COLOR).unwrap(); - +fn prepare_chimp(image: &Mat, width: i32, height: i32) -> ChimpImage { let mut resized_image = Mat::default(); resize( &image, &mut resized_image, - Size_ { - width: width as i32, - height: height as i32, - }, + Size_ { width, height }, 0.0, 0.0, INTER_LINEAR, ) .unwrap(); - let mut well_image = Mat::default(); - cvt_color(&resized_image, &mut well_image, COLOR_BGR2GRAY, 0).unwrap(); - let mut resized_rgb_image = Mat::default(); cvt_color(&resized_image, &mut resized_rgb_image, COLOR_BGR2RGB, 0).unwrap(); let mut resized_rgb_f32_image = Mat::default(); @@ -62,5 +54,24 @@ pub fn load_image(path: impl AsRef, width: u32, 
height: u32) -> (ChimpImag .as_standard_layout() .to_owned(); - (ChimpImage(chimp_image), WellImage(well_image)) + ChimpImage(chimp_image) +} + +fn prepare_well(image: &Mat) -> WellImage { + let mut well_image = Mat::default(); + cvt_color(&image, &mut well_image, COLOR_BGR2GRAY, 0).unwrap(); + WellImage(well_image) +} + +pub fn load_image( + path: impl AsRef, + chimp_width: u32, + chimp_height: u32, +) -> (ChimpImage, WellImage) { + let image = imread(path.as_ref().to_str().unwrap(), IMREAD_COLOR).unwrap(); + + let well_image = prepare_well(&image); + let chimp_image = prepare_chimp(&image, chimp_width as i32, chimp_height as i32); + + (chimp_image, well_image) } diff --git a/chimp_chomp/src/well_centering.rs b/chimp_chomp/src/well_centering.rs index 5737dc87..af236f18 100644 --- a/chimp_chomp/src/well_centering.rs +++ b/chimp_chomp/src/well_centering.rs @@ -1,5 +1,3 @@ -use std::ops::Deref; - use crate::image_loading::WellImage; use chimp_protocol::{Circle, Job, Point}; use opencv::{ @@ -7,6 +5,7 @@ use opencv::{ imgproc::{hough_circles, HOUGH_GRADIENT}, prelude::MatTraitConst, }; +use std::ops::Deref; use tokio::sync::mpsc::UnboundedSender; pub async fn find_well_center( @@ -15,18 +14,18 @@ pub async fn find_well_center( well_location_tx: UnboundedSender<(Circle, Job)>, ) { println!("Finding Well Center for {job:?}"); - let max_side = *image.deref().mat_size().iter().max().unwrap(); + let min_side = *image.deref().mat_size().iter().min().unwrap(); let mut circles = Vector::::new(); hough_circles( &*image, &mut circles, HOUGH_GRADIENT, + 4.0, 1.0, - 1.0, - 10.0, - 10.0, - max_side / 2, - max_side, + 100.0, + 100.0, + min_side * 3 / 8, + min_side / 2, ) .unwrap(); let well_location = circles From c5f6ca8c5b776769c6e1471bf4076d303892c41b Mon Sep 17 00:00:00 2001 From: Garry O'Donnell Date: Wed, 12 Jul 2023 17:42:23 +0000 Subject: [PATCH 28/45] Remove tokio thread limit --- chimp_chomp/src/main.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chimp_chomp/src/main.rs b/chimp_chomp/src/main.rs index 6deba65a..ec69a822 100644 --- a/chimp_chomp/src/main.rs +++ b/chimp_chomp/src/main.rs @@ -31,7 +31,7 @@ struct Cli { rabbitmq_channel: String, } -#[tokio::main(flavor = "multi_thread", worker_threads = 4)] +#[tokio::main] async fn main() { dotenvy::dotenv().ok(); let args = Cli::parse(); From 688ca46e3fc90fbb60d0ce1b15637eaea9ea88ba Mon Sep 17 00:00:00 2001 From: Garry O'Donnell Date: Thu, 13 Jul 2023 10:58:37 +0000 Subject: [PATCH 29/45] Add timeout option to shutdown when unoccupied --- Cargo.lock | 9 ++++++++- chimp_chomp/Cargo.toml | 3 ++- chimp_chomp/src/main.rs | 27 +++++++++++++++++++++------ 3 files changed, 31 insertions(+), 8 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f898218e..6a3a5fd3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -605,7 +605,8 @@ dependencies = [ "chimp_protocol", "clap 4.3.14", "dotenvy", - "futures-lite", + "futures", + "futures-timer", "itertools", "lapin", "ndarray", @@ -1176,6 +1177,12 @@ version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "76d3d132be6c0e6aa1534069c705a74a5997a356c0dc2f86a47765e5617c5b65" +[[package]] +name = "futures-timer" +version = "3.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e64b03909df88034c26dc1547e8970b91f98bdb65165d6a4e9110d94263dbb2c" + [[package]] name = "futures-util" version = "0.3.28" diff --git a/chimp_chomp/Cargo.toml b/chimp_chomp/Cargo.toml index 0baa1bc7..a2645aa6 100644 --- a/chimp_chomp/Cargo.toml +++ 
b/chimp_chomp/Cargo.toml @@ -8,7 +8,8 @@ chimp_protocol = { path = "../chimp_protocol" } clap = { workspace = true } derive_more = { workspace = true } dotenvy = { workspace = true } -futures-lite = { version = "1.13.0" } +futures = { version = "0.3.28" } +futures-timer = { version = "3.0.2" } itertools = { workspace = true } lapin = { version = "2.2.1", default-features = false, features = [ "native-tls", diff --git a/chimp_chomp/src/main.rs b/chimp_chomp/src/main.rs index ec69a822..d40a510e 100644 --- a/chimp_chomp/src/main.rs +++ b/chimp_chomp/src/main.rs @@ -15,9 +15,10 @@ use crate::{ well_centering::find_well_center, }; use clap::Parser; -use futures_lite::StreamExt; -use std::{collections::HashMap, path::PathBuf}; -use tokio::{select, task::JoinSet}; +use futures::{future::Either, StreamExt}; +use futures_timer::Delay; +use std::{collections::HashMap, path::PathBuf, time::Duration}; +use tokio::{select, spawn, task::JoinSet}; use url::Url; #[derive(Debug, Parser)] @@ -29,6 +30,9 @@ struct Cli { rabbitmq_url: Url, /// The RabbitMQ channel on which jobs are assigned. rabbitmq_channel: String, + /// The duration (in milliseconds) to wait after completing all jobs before shutting down. + #[arg(long)] + timeout: Option, } #[tokio::main] @@ -55,19 +59,25 @@ async fn main() { let (prediction_tx, mut prediction_rx) = tokio::sync::mpsc::unbounded_channel(); let (contents_tx, mut contents_rx) = tokio::sync::mpsc::unbounded_channel(); - let mut tasks = JoinSet::new(); - - tasks.spawn(inference_worker( + spawn(inference_worker( session, batch_size, chimp_image_rx, prediction_tx, )); + let mut tasks = JoinSet::new(); + let mut well_locations = HashMap::new(); let mut well_contents = HashMap::new(); loop { + let timeout = if let Some(timeout) = args.timeout { + Either::Left(Delay::new(Duration::from_millis(timeout))) + } else { + Either::Right(std::future::pending()) + }; + select! 
{ biased; @@ -99,6 +109,11 @@ async fn main() { } } + _ = timeout => { + println!("Stopping: No jobs processed for {}ms", args.timeout.unwrap()); + break; + } + else => break } } From 7dee89ff8d47fa8b919f6f8efaf71a85e7e6f8c1 Mon Sep 17 00:00:00 2001 From: Garry O'Donnell Date: Thu, 13 Jul 2023 11:01:53 +0000 Subject: [PATCH 30/45] Re-bias task scheduler to prefer sending responses --- chimp_chomp/src/main.rs | 31 +++++++++++++++++-------------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/chimp_chomp/src/main.rs b/chimp_chomp/src/main.rs index d40a510e..2a608099 100644 --- a/chimp_chomp/src/main.rs +++ b/chimp_chomp/src/main.rs @@ -14,9 +14,11 @@ use crate::{ postprocessing::postprocess_inference, well_centering::find_well_center, }; +use chimp_protocol::{Circle, Job}; use clap::Parser; use futures::{future::Either, StreamExt}; use futures_timer::Delay; +use postprocessing::Contents; use std::{collections::HashMap, path::PathBuf, time::Duration}; use tokio::{select, spawn, task::JoinSet}; use url::Url; @@ -55,9 +57,10 @@ async fn main() { let (chimp_image_tx, chimp_image_rx) = tokio::sync::mpsc::channel(batch_size); let (well_image_tx, mut well_image_rx) = tokio::sync::mpsc::channel(batch_size); - let (well_location_tx, mut well_location_rx) = tokio::sync::mpsc::unbounded_channel(); + let (well_location_tx, mut well_location_rx) = + tokio::sync::mpsc::unbounded_channel::<(Circle, Job)>(); let (prediction_tx, mut prediction_rx) = tokio::sync::mpsc::unbounded_channel(); - let (contents_tx, mut contents_rx) = tokio::sync::mpsc::unbounded_channel(); + let (contents_tx, mut contents_rx) = tokio::sync::mpsc::unbounded_channel::<(Contents, Job)>(); spawn(inference_worker( session, @@ -81,18 +84,6 @@ async fn main() { select! { biased; - Some(delivery) = job_consumer.next() => { - tasks.spawn(consume_job(delivery, input_width, input_height, chimp_image_tx.clone(), well_image_tx.clone())); - } - - Some((well_image, job)) = well_image_rx.recv() => { - tasks.spawn(find_well_center(well_image, job, well_location_tx.clone())); - } - - Some((bboxes, labels, _, masks, job)) = prediction_rx.recv() => { - tasks.spawn(postprocess_inference(bboxes, labels, masks, job, contents_tx.clone())); - } - Some((well_location, job)) = well_location_rx.recv() => { if let Some(contents) = well_contents.remove(&job.id) { tasks.spawn(produce_response(contents, well_location, job, response_channel.clone())); @@ -109,6 +100,18 @@ async fn main() { } } + Some(delivery) = job_consumer.next() => { + tasks.spawn(consume_job(delivery, input_width, input_height, chimp_image_tx.clone(), well_image_tx.clone())); + } + + Some((well_image, job)) = well_image_rx.recv() => { + tasks.spawn(find_well_center(well_image, job, well_location_tx.clone())); + } + + Some((bboxes, labels, _, masks, job)) = prediction_rx.recv() => { + tasks.spawn(postprocess_inference(bboxes, labels, masks, job, contents_tx.clone())); + } + _ = timeout => { println!("Stopping: No jobs processed for {}ms", args.timeout.unwrap()); break; From 366df4b36caa481d621dd83e482d23f914258fb3 Mon Sep 17 00:00:00 2001 From: Garry O'Donnell Date: Thu, 13 Jul 2023 12:19:17 +0000 Subject: [PATCH 31/45] Add thread option to control worker thread count --- chimp_chomp/src/main.rs | 159 +++++++++++++++++++++------------------- 1 file changed, 84 insertions(+), 75 deletions(-) diff --git a/chimp_chomp/src/main.rs b/chimp_chomp/src/main.rs index 2a608099..a14ca378 100644 --- a/chimp_chomp/src/main.rs +++ b/chimp_chomp/src/main.rs @@ -35,89 +35,98 @@ struct Cli { /// 
The duration (in milliseconds) to wait after completing all jobs before shutting down. #[arg(long)] timeout: Option, + /// The number of worker threads to use + #[arg(long)] + threads: Option, } -#[tokio::main] -async fn main() { +fn main() { dotenvy::dotenv().ok(); let args = Cli::parse(); opencv::core::set_num_threads(0).unwrap(); - let session = setup_inference_session(args.model).unwrap(); - let input_width = session.inputs[0].dimensions[3].unwrap(); - let input_height = session.inputs[0].dimensions[2].unwrap(); - let batch_size = session.inputs[0].dimensions[0].unwrap().try_into().unwrap(); - - let rabbitmq_client = setup_rabbitmq_client(args.rabbitmq_url).await.unwrap(); - let job_channel = rabbitmq_client.create_channel().await.unwrap(); - let response_channel = rabbitmq_client.create_channel().await.unwrap(); - let mut job_consumer = setup_job_consumer(job_channel, args.rabbitmq_channel) - .await - .unwrap(); - - let (chimp_image_tx, chimp_image_rx) = tokio::sync::mpsc::channel(batch_size); - let (well_image_tx, mut well_image_rx) = tokio::sync::mpsc::channel(batch_size); - let (well_location_tx, mut well_location_rx) = - tokio::sync::mpsc::unbounded_channel::<(Circle, Job)>(); - let (prediction_tx, mut prediction_rx) = tokio::sync::mpsc::unbounded_channel(); - let (contents_tx, mut contents_rx) = tokio::sync::mpsc::unbounded_channel::<(Contents, Job)>(); - - spawn(inference_worker( - session, - batch_size, - chimp_image_rx, - prediction_tx, - )); - - let mut tasks = JoinSet::new(); - - let mut well_locations = HashMap::new(); - let mut well_contents = HashMap::new(); - - loop { - let timeout = if let Some(timeout) = args.timeout { - Either::Left(Delay::new(Duration::from_millis(timeout))) - } else { - Either::Right(std::future::pending()) - }; - - select! 
{ - biased; - - Some((well_location, job)) = well_location_rx.recv() => { - if let Some(contents) = well_contents.remove(&job.id) { - tasks.spawn(produce_response(contents, well_location, job, response_channel.clone())); - } else { - well_locations.insert(job.id, well_location); + let mut runtime_builder = tokio::runtime::Builder::new_multi_thread(); + runtime_builder.enable_all(); + if let Some(threads) = args.threads { + runtime_builder.worker_threads(threads); + } + runtime_builder.build().unwrap().block_on(async { + let session = setup_inference_session(args.model).unwrap(); + let input_width = session.inputs[0].dimensions[3].unwrap(); + let input_height = session.inputs[0].dimensions[2].unwrap(); + let batch_size = session.inputs[0].dimensions[0].unwrap().try_into().unwrap(); + + let rabbitmq_client = setup_rabbitmq_client(args.rabbitmq_url).await.unwrap(); + let job_channel = rabbitmq_client.create_channel().await.unwrap(); + let response_channel = rabbitmq_client.create_channel().await.unwrap(); + let mut job_consumer = setup_job_consumer(job_channel, args.rabbitmq_channel) + .await + .unwrap(); + + let (chimp_image_tx, chimp_image_rx) = tokio::sync::mpsc::channel(batch_size); + let (well_image_tx, mut well_image_rx) = tokio::sync::mpsc::channel(batch_size); + let (well_location_tx, mut well_location_rx) = + tokio::sync::mpsc::unbounded_channel::<(Circle, Job)>(); + let (prediction_tx, mut prediction_rx) = tokio::sync::mpsc::unbounded_channel(); + let (contents_tx, mut contents_rx) = tokio::sync::mpsc::unbounded_channel::<(Contents, Job)>(); + + spawn(inference_worker( + session, + batch_size, + chimp_image_rx, + prediction_tx, + )); + + let mut tasks = JoinSet::new(); + + let mut well_locations = HashMap::new(); + let mut well_contents = HashMap::new(); + + loop { + let timeout = if let Some(timeout) = args.timeout { + Either::Left(Delay::new(Duration::from_millis(timeout))) + } else { + Either::Right(std::future::pending()) + }; + + select! 
{ + biased; + + Some((well_location, job)) = well_location_rx.recv() => { + if let Some(contents) = well_contents.remove(&job.id) { + tasks.spawn(produce_response(contents, well_location, job, response_channel.clone())); + } else { + well_locations.insert(job.id, well_location); + } } - } - - Some((contents, job)) = contents_rx.recv() => { - if let Some(well_location) = well_locations.remove(&job.id) { - tasks.spawn(produce_response(contents, well_location, job, response_channel.clone())); - } else { - well_contents.insert(job.id, contents); + + Some((contents, job)) = contents_rx.recv() => { + if let Some(well_location) = well_locations.remove(&job.id) { + tasks.spawn(produce_response(contents, well_location, job, response_channel.clone())); + } else { + well_contents.insert(job.id, contents); + } } + + Some(delivery) = job_consumer.next() => { + tasks.spawn(consume_job(delivery, input_width, input_height, chimp_image_tx.clone(), well_image_tx.clone())); + } + + Some((well_image, job)) = well_image_rx.recv() => { + tasks.spawn(find_well_center(well_image, job, well_location_tx.clone())); + } + + Some((bboxes, labels, _, masks, job)) = prediction_rx.recv() => { + tasks.spawn(postprocess_inference(bboxes, labels, masks, job, contents_tx.clone())); + } + + _ = timeout => { + println!("Stopping: No jobs processed for {}ms", args.timeout.unwrap()); + break; + } + + else => break } - - Some(delivery) = job_consumer.next() => { - tasks.spawn(consume_job(delivery, input_width, input_height, chimp_image_tx.clone(), well_image_tx.clone())); - } - - Some((well_image, job)) = well_image_rx.recv() => { - tasks.spawn(find_well_center(well_image, job, well_location_tx.clone())); - } - - Some((bboxes, labels, _, masks, job)) = prediction_rx.recv() => { - tasks.spawn(postprocess_inference(bboxes, labels, masks, job, contents_tx.clone())); - } - - _ = timeout => { - println!("Stopping: No jobs processed for {}ms", args.timeout.unwrap()); - break; - } - - else => break } - } + }); } From 94b52d854830eecc096c0cb3f93fc55065179f59 Mon Sep 17 00:00:00 2001 From: Garry O'Donnell Date: Thu, 13 Jul 2023 14:19:38 +0000 Subject: [PATCH 32/45] Propagate inference backpressure to RabbitMQ --- chimp_chomp/src/jobs.rs | 19 ++++++++----------- chimp_chomp/src/main.rs | 9 +++++---- 2 files changed, 13 insertions(+), 15 deletions(-) diff --git a/chimp_chomp/src/jobs.rs b/chimp_chomp/src/jobs.rs index afe755d4..711f4f48 100644 --- a/chimp_chomp/src/jobs.rs +++ b/chimp_chomp/src/jobs.rs @@ -3,13 +3,13 @@ use crate::{ postprocessing::Contents, }; use chimp_protocol::{Circle, Job, Response}; +use futures::StreamExt; use lapin::{ - message::Delivery, options::{BasicAckOptions, BasicConsumeOptions, BasicPublishOptions}, types::FieldTable, BasicProperties, Channel, Connection, Consumer, }; -use tokio::sync::mpsc::Sender; +use tokio::sync::mpsc::{OwnedPermit, Sender}; use url::Url; use uuid::Uuid; @@ -34,23 +34,20 @@ pub async fn setup_job_consumer( } pub async fn consume_job( - delivery: Result, + mut consumer: Consumer, input_width: u32, input_height: u32, - chimp_image_tx: Sender<(ChimpImage, Job)>, + chimp_permit: OwnedPermit<(ChimpImage, Job)>, well_image_tx: Sender<(WellImage, Job)>, ) { - let delievry = delivery.unwrap(); - delievry.ack(BasicAckOptions::default()).await.unwrap(); + let delivery = consumer.next().await.unwrap().unwrap(); + delivery.ack(BasicAckOptions::default()).await.unwrap(); - let job = Job::from_slice(&delievry.data).unwrap(); + let job = Job::from_slice(&delivery.data).unwrap(); println!("Consumed Job: 
{job:?}"); let (chimp_image, well_image) = load_image(job.file.clone(), input_width, input_height); - chimp_image_tx - .send((chimp_image, job.clone())) - .await - .unwrap(); + chimp_permit.send((chimp_image, job.clone())); well_image_tx.send((well_image, job)).await.unwrap(); } diff --git a/chimp_chomp/src/main.rs b/chimp_chomp/src/main.rs index a14ca378..fb907866 100644 --- a/chimp_chomp/src/main.rs +++ b/chimp_chomp/src/main.rs @@ -16,7 +16,7 @@ use crate::{ }; use chimp_protocol::{Circle, Job}; use clap::Parser; -use futures::{future::Either, StreamExt}; +use futures::future::Either; use futures_timer::Delay; use postprocessing::Contents; use std::{collections::HashMap, path::PathBuf, time::Duration}; @@ -59,7 +59,7 @@ fn main() { let rabbitmq_client = setup_rabbitmq_client(args.rabbitmq_url).await.unwrap(); let job_channel = rabbitmq_client.create_channel().await.unwrap(); let response_channel = rabbitmq_client.create_channel().await.unwrap(); - let mut job_consumer = setup_job_consumer(job_channel, args.rabbitmq_channel) + let job_consumer = setup_job_consumer(job_channel, args.rabbitmq_channel) .await .unwrap(); @@ -108,8 +108,9 @@ fn main() { } } - Some(delivery) = job_consumer.next() => { - tasks.spawn(consume_job(delivery, input_width, input_height, chimp_image_tx.clone(), well_image_tx.clone())); + chimp_permit = chimp_image_tx.clone().reserve_owned() => { + let chimp_permit = chimp_permit.unwrap(); + tasks.spawn(consume_job(job_consumer.clone(), input_width, input_height, chimp_permit, well_image_tx.clone())); } Some((well_image, job)) = well_image_rx.recv() => { From 70cb0a552c10cd74412dbf4b07f7aff0ec2a54f7 Mon Sep 17 00:00:00 2001 From: Garry O'Donnell Date: Thu, 13 Jul 2023 14:37:41 +0000 Subject: [PATCH 33/45] Refactor main function into setup & runtime --- chimp_chomp/src/main.rs | 167 +++++++++++++++++++++------------------- 1 file changed, 86 insertions(+), 81 deletions(-) diff --git a/chimp_chomp/src/main.rs b/chimp_chomp/src/main.rs index fb907866..77870a6b 100644 --- a/chimp_chomp/src/main.rs +++ b/chimp_chomp/src/main.rs @@ -45,89 +45,94 @@ fn main() { let args = Cli::parse(); opencv::core::set_num_threads(0).unwrap(); - let mut runtime_builder = tokio::runtime::Builder::new_multi_thread(); - runtime_builder.enable_all(); - if let Some(threads) = args.threads { - runtime_builder.worker_threads(threads); - } - runtime_builder.build().unwrap().block_on(async { - let session = setup_inference_session(args.model).unwrap(); - let input_width = session.inputs[0].dimensions[3].unwrap(); - let input_height = session.inputs[0].dimensions[2].unwrap(); - let batch_size = session.inputs[0].dimensions[0].unwrap().try_into().unwrap(); - - let rabbitmq_client = setup_rabbitmq_client(args.rabbitmq_url).await.unwrap(); - let job_channel = rabbitmq_client.create_channel().await.unwrap(); - let response_channel = rabbitmq_client.create_channel().await.unwrap(); - let job_consumer = setup_job_consumer(job_channel, args.rabbitmq_channel) - .await - .unwrap(); - - let (chimp_image_tx, chimp_image_rx) = tokio::sync::mpsc::channel(batch_size); - let (well_image_tx, mut well_image_rx) = tokio::sync::mpsc::channel(batch_size); - let (well_location_tx, mut well_location_rx) = - tokio::sync::mpsc::unbounded_channel::<(Circle, Job)>(); - let (prediction_tx, mut prediction_rx) = tokio::sync::mpsc::unbounded_channel(); - let (contents_tx, mut contents_rx) = tokio::sync::mpsc::unbounded_channel::<(Contents, Job)>(); - - spawn(inference_worker( - session, - batch_size, - chimp_image_rx, - 
prediction_tx, - )); - - let mut tasks = JoinSet::new(); - - let mut well_locations = HashMap::new(); - let mut well_contents = HashMap::new(); - - loop { - let timeout = if let Some(timeout) = args.timeout { - Either::Left(Delay::new(Duration::from_millis(timeout))) - } else { - Either::Right(std::future::pending()) - }; - - select! { - biased; - - Some((well_location, job)) = well_location_rx.recv() => { - if let Some(contents) = well_contents.remove(&job.id) { - tasks.spawn(produce_response(contents, well_location, job, response_channel.clone())); - } else { - well_locations.insert(job.id, well_location); - } - } - - Some((contents, job)) = contents_rx.recv() => { - if let Some(well_location) = well_locations.remove(&job.id) { - tasks.spawn(produce_response(contents, well_location, job, response_channel.clone())); - } else { - well_contents.insert(job.id, contents); - } - } - - chimp_permit = chimp_image_tx.clone().reserve_owned() => { - let chimp_permit = chimp_permit.unwrap(); - tasks.spawn(consume_job(job_consumer.clone(), input_width, input_height, chimp_permit, well_image_tx.clone())); - } - - Some((well_image, job)) = well_image_rx.recv() => { - tasks.spawn(find_well_center(well_image, job, well_location_tx.clone())); - } - - Some((bboxes, labels, _, masks, job)) = prediction_rx.recv() => { - tasks.spawn(postprocess_inference(bboxes, labels, masks, job, contents_tx.clone())); + let runtime = { + let mut builder = tokio::runtime::Builder::new_multi_thread(); + builder.enable_all(); + if let Some(threads) = args.threads { + builder.worker_threads(threads); + } + builder.build().unwrap() + }; + runtime.block_on(run(args)); +} + +async fn run(args: Cli) { + let session = setup_inference_session(args.model).unwrap(); + let input_width = session.inputs[0].dimensions[3].unwrap(); + let input_height = session.inputs[0].dimensions[2].unwrap(); + let batch_size = session.inputs[0].dimensions[0].unwrap().try_into().unwrap(); + + let rabbitmq_client = setup_rabbitmq_client(args.rabbitmq_url).await.unwrap(); + let job_channel = rabbitmq_client.create_channel().await.unwrap(); + let response_channel = rabbitmq_client.create_channel().await.unwrap(); + let job_consumer = setup_job_consumer(job_channel, args.rabbitmq_channel) + .await + .unwrap(); + + let (chimp_image_tx, chimp_image_rx) = tokio::sync::mpsc::channel(batch_size); + let (well_image_tx, mut well_image_rx) = tokio::sync::mpsc::channel(batch_size); + let (well_location_tx, mut well_location_rx) = + tokio::sync::mpsc::unbounded_channel::<(Circle, Job)>(); + let (prediction_tx, mut prediction_rx) = tokio::sync::mpsc::unbounded_channel(); + let (contents_tx, mut contents_rx) = tokio::sync::mpsc::unbounded_channel::<(Contents, Job)>(); + + spawn(inference_worker( + session, + batch_size, + chimp_image_rx, + prediction_tx, + )); + + let mut tasks = JoinSet::new(); + + let mut well_locations = HashMap::new(); + let mut well_contents = HashMap::new(); + + loop { + let timeout = if let Some(timeout) = args.timeout { + Either::Left(Delay::new(Duration::from_millis(timeout))) + } else { + Either::Right(std::future::pending()) + }; + + select! 
{ + biased; + + Some((well_location, job)) = well_location_rx.recv() => { + if let Some(contents) = well_contents.remove(&job.id) { + tasks.spawn(produce_response(contents, well_location, job, response_channel.clone())); + } else { + well_locations.insert(job.id, well_location); } - - _ = timeout => { - println!("Stopping: No jobs processed for {}ms", args.timeout.unwrap()); - break; + } + + Some((contents, job)) = contents_rx.recv() => { + if let Some(well_location) = well_locations.remove(&job.id) { + tasks.spawn(produce_response(contents, well_location, job, response_channel.clone())); + } else { + well_contents.insert(job.id, contents); } - - else => break } + + chimp_permit = chimp_image_tx.clone().reserve_owned() => { + let chimp_permit = chimp_permit.unwrap(); + tasks.spawn(consume_job(job_consumer.clone(), input_width, input_height, chimp_permit, well_image_tx.clone())); + } + + Some((well_image, job)) = well_image_rx.recv() => { + tasks.spawn(find_well_center(well_image, job, well_location_tx.clone())); + } + + Some((bboxes, labels, _, masks, job)) = prediction_rx.recv() => { + tasks.spawn(postprocess_inference(bboxes, labels, masks, job, contents_tx.clone())); + } + + _ = timeout => { + println!("Stopping: No jobs processed for {}ms", args.timeout.unwrap()); + break; + } + + else => break } - }); + } } From 25ca8ddb351a610544ec1acc987788cc08180737 Mon Sep 17 00:00:00 2001 From: Garry O'Donnell Date: Fri, 14 Jul 2023 09:39:28 +0000 Subject: [PATCH 34/45] Install chimp_chomp dependencies in CI --- .github/workflows/code.yml | 10 ++++++++++ .github/workflows/docs.yml | 5 +++++ 2 files changed, 15 insertions(+) diff --git a/.github/workflows/code.yml b/.github/workflows/code.yml index b3237da6..b1984869 100644 --- a/.github/workflows/code.yml +++ b/.github/workflows/code.yml @@ -13,6 +13,11 @@ jobs: - name: Checkout source uses: actions/checkout@v3.5.2 + - name: Install dependencies + uses: awalsh128/cache-apt-pkgs-action@v1.3.0 + with: + packages: libopencv-dev clang libclang-dev + - name: Install stable toolchain uses: actions-rs/toolchain@v1.0.6 with: @@ -50,6 +55,11 @@ jobs: - name: Checkout source uses: actions/checkout@v3.5.2 + - name: Install dependencies + uses: awalsh128/cache-apt-pkgs-action@v1.3.0 + with: + packages: libopencv-dev clang libclang-dev + - name: Install stable toolchain uses: actions-rs/toolchain@v1.0.6 with: diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 48e15ecf..4c51606f 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -13,6 +13,11 @@ jobs: - name: Checkout source uses: actions/checkout@v3.5.2 + - name: Install dependencies + uses: awalsh128/cache-apt-pkgs-action@v1.3.0 + with: + packages: libopencv-dev clang libclang-dev + - name: Install nightly toolchain uses: actions-rs/toolchain@v1.0.6 with: From dc139c091a352f5258edcf841b944a70732e245d Mon Sep 17 00:00:00 2001 From: Garry O'Donnell Date: Fri, 14 Jul 2023 09:45:38 +0000 Subject: [PATCH 35/45] Add chimp chomp onnx model with Git LFS --- .devcontainer/devcontainer.json | 3 ++- .gitattributes | 1 + chimp_chomp/chimp.onnx | 3 +++ 3 files changed, 6 insertions(+), 1 deletion(-) create mode 100644 .gitattributes create mode 100644 chimp_chomp/chimp.onnx diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index b260276b..1b4aa889 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -19,7 +19,8 @@ "ghcr.io/devcontainers/features/common-utils:2": { "username": "none", "upgradePackages": false - } 
+ }, + "ghcr.io/devcontainers/features/git-lfs:1": {} }, // Make sure the files we are mapping into the container exist on the host "initializeCommand": "bash -c 'for i in $HOME/.inputrc; do [ -f $i ] || touch $i; done'", diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 00000000..ae46686a --- /dev/null +++ b/.gitattributes @@ -0,0 +1 @@ +chimp_chomp/chimp.onnx filter=lfs diff=lfs merge=lfs -text diff --git a/chimp_chomp/chimp.onnx b/chimp_chomp/chimp.onnx new file mode 100644 index 00000000..75985761 --- /dev/null +++ b/chimp_chomp/chimp.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:37311eb9e26f488b06c5a644c80e7025c589c383a99bbe91e9a29b6e274a3396 +size 176578499 From ecf86bf6f86264a296a7fb3a2e1ba6070600d2b0 Mon Sep 17 00:00:00 2001 From: Garry O'Donnell Date: Fri, 14 Jul 2023 12:22:27 +0000 Subject: [PATCH 36/45] Export chimp model alongside binary --- chimp_chomp/build.rs | 17 +++++++++++++++++ chimp_chomp/src/inference.rs | 9 ++++----- chimp_chomp/src/main.rs | 6 ++---- 3 files changed, 23 insertions(+), 9 deletions(-) create mode 100644 chimp_chomp/build.rs diff --git a/chimp_chomp/build.rs b/chimp_chomp/build.rs new file mode 100644 index 00000000..ed2f4abc --- /dev/null +++ b/chimp_chomp/build.rs @@ -0,0 +1,17 @@ +use std::{env, fs::copy, path::PathBuf}; + +const MODEL_FILE: &str = "chimp.onnx"; + +fn main() { + println!("Copying chimp.onnx"); + let model_src = PathBuf::from(env::var("CARGO_MANIFEST_DIR").unwrap()).join(MODEL_FILE); + let model_dst = PathBuf::from(env::var("OUT_DIR").unwrap()) + .parent() + .unwrap() + .parent() + .unwrap() + .parent() + .unwrap() + .join(MODEL_FILE); + copy(model_src, model_dst).unwrap(); +} diff --git a/chimp_chomp/src/inference.rs b/chimp_chomp/src/inference.rs index b43a41ca..c2980cf3 100644 --- a/chimp_chomp/src/inference.rs +++ b/chimp_chomp/src/inference.rs @@ -1,3 +1,4 @@ +use crate::image_loading::ChimpImage; use chimp_protocol::Job; use itertools::{izip, Itertools}; use ndarray::{Array1, Array2, Array3, Axis, Ix1, Ix2, Ix4}; @@ -5,17 +6,15 @@ use ort::{ tensor::{FromArray, InputTensor}, Environment, ExecutionProvider, GraphOptimizationLevel, OrtError, Session, SessionBuilder, }; -use std::{ops::Deref, path::Path, sync::Arc}; +use std::{env::current_exe, ops::Deref, sync::Arc}; use tokio::sync::mpsc::{error::TryRecvError, Receiver, UnboundedSender}; -use crate::image_loading::ChimpImage; - pub type BBoxes = Array2; pub type Labels = Array1; pub type Scores = Array1; pub type Masks = Array3; -pub fn setup_inference_session(model_path: impl AsRef) -> Result { +pub fn setup_inference_session() -> Result { let environment = Arc::new( Environment::builder() .with_name("CHiMP") @@ -24,7 +23,7 @@ pub fn setup_inference_session(model_path: impl AsRef) -> Result Date: Fri, 14 Jul 2023 12:23:38 +0000 Subject: [PATCH 37/45] Fix docker build for chimp chomp --- .github/workflows/container.yml | 7 +++---- Dockerfile | 34 ++++++++++++++++++++++++++++----- 2 files changed, 32 insertions(+), 9 deletions(-) diff --git a/.github/workflows/container.yml b/.github/workflows/container.yml index c3c50620..51c1e3bb 100644 --- a/.github/workflows/container.yml +++ b/.github/workflows/container.yml @@ -10,7 +10,7 @@ jobs: if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name != github.repository strategy: matrix: - service: + target: - chimp_chomp - soakdb_sync - pin_packing @@ -18,7 +18,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Generate Image Name - run: echo 
IMAGE_REPOSITORY=ghcr.io/$(tr '[:upper:]' '[:lower:]' <<< "${{ github.repository }}")-${{ matrix.service }} >> $GITHUB_ENV + run: echo IMAGE_REPOSITORY=ghcr.io/$(tr '[:upper:]' '[:lower:]' <<< "${{ github.repository }}")-${{ matrix.target }} >> $GITHUB_ENV - name: Log in to GitHub Docker Registry if: github.event_name != 'pull_request' @@ -43,8 +43,7 @@ jobs: - name: Build Image uses: docker/build-push-action@v4.0.0 with: - build-args: | - SERVICE=${{ matrix.service }} + target: ${{ matrix.target }} push: ${{ github.event_name == 'push' && startsWith(github.ref, 'refs/tags') }} load: ${{ !(github.event_name == 'push' && startsWith(github.ref, 'refs/tags')) }} tags: ${{ steps.meta.outputs.tags }} diff --git a/Dockerfile b/Dockerfile index 75d3bfe1..0b5a536f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -2,6 +2,10 @@ FROM docker.io/library/rust:1.71.0-bullseye AS build WORKDIR /app +RUN apt-get update \ + && apt-get install -y \ + libopencv-dev clang libclang-dev + # Build dependencies COPY Cargo.toml Cargo.lock ./ COPY chimp_chomp/Cargo.toml chimp_chomp/Cargo.toml @@ -13,7 +17,7 @@ COPY pin_packing/Cargo.toml pin_packing/Cargo.toml COPY soakdb_io/Cargo.toml soakdb_io/Cargo.toml COPY soakdb_sync/Cargo.toml soakdb_sync/Cargo.toml RUN mkdir chimp_chomp/src \ - && touch chimp_chomp/src/lib.rs \ + && echo "fn main() {}" > chimp_chomp/src/main.rs \ && mkdir chimp_protocol/src \ && touch chimp_protocol/src/lib.rs \ && mkdir graphql_endpoints/src \ @@ -42,9 +46,29 @@ RUN touch chimp_chomp/src/lib.rs \ && touch soakdb_sync/src/main.rs \ && cargo build --release -FROM gcr.io/distroless/cc -ARG SERVICE +# Collate dynamically linked shared objects for chimp_chomp +RUN mkdir /chimp_chomp_libraries \ + && cp \ + $(ldd /app/target/release/chimp_chomp | grep -o '/.*\.so\S*') \ + /app/target/release/libonnxruntime.so.1.14.1 \ + /chimp_chomp_libraries + +FROM gcr.io/distroless/cc as chimp_chomp + +COPY --from=build /chimp_chomp_libraries/* /lib +COPY --from=build /app/target/release/chimp.onnx /chimp.onnx +COPY --from=build /app/target/release/chimp_chomp /chimp_chomp + +ENTRYPOINT ["./chimp_chomp"] + +FROM gcr.io/distroless/cc as pin_packing + +COPY --from=build /app/target/release/pin_packing /pin_packing + +ENTRYPOINT ["./pin_packing"] + +FROM gcr.io/distroless/cc as soakdb_sync -COPY --from=build /app/target/release/${SERVICE} /service +COPY --from=build /app/target/release/soakdb_sync /soakdb_sync -ENTRYPOINT ["./service"] +ENTRYPOINT ["./soakdb_sync"] From 6f283f3db74722a3c7469dd74760d2519ef7957b Mon Sep 17 00:00:00 2001 From: Garry O'Donnell Date: Fri, 14 Jul 2023 12:24:22 +0000 Subject: [PATCH 38/45] Remove tensorrt dependency --- chimp_chomp/Cargo.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/chimp_chomp/Cargo.toml b/chimp_chomp/Cargo.toml index a2645aa6..4f1f4c90 100644 --- a/chimp_chomp/Cargo.toml +++ b/chimp_chomp/Cargo.toml @@ -21,7 +21,6 @@ opencv = { version = "0.82.1", default-features = false, features = [ ] } ort = { version = "1.14.8", default-features = false, features = [ "download-binaries", - "tensorrt", "copy-dylibs", ] } tokio = { workspace = true, features = ["sync"] } From 447395d924bc2e4d6d398cb99ffb9e906c7a1fd1 Mon Sep 17 00:00:00 2001 From: Garry O'Donnell Date: Mon, 17 Jul 2023 10:38:29 +0000 Subject: [PATCH 39/45] Fix insertion point location --- chimp_chomp/src/postprocessing.rs | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/chimp_chomp/src/postprocessing.rs b/chimp_chomp/src/postprocessing.rs index e8323b27..09ba0624 100644 --- 
a/chimp_chomp/src/postprocessing.rs +++ b/chimp_chomp/src/postprocessing.rs @@ -5,7 +5,7 @@ use ndarray::{Array2, ArrayView, ArrayView2, Ix1}; use opencv::{ core::CV_8U, imgproc::{distance_transform, DIST_L1, DIST_MASK_3}, - prelude::Mat, + prelude::{Mat, MatTraitConst}, }; use tokio::sync::mpsc::UnboundedSender; @@ -32,7 +32,21 @@ fn insertion_mask( } fn optimal_insert_position(insertion_mask: Array2) -> Point { - let mask = Mat::from_exact_iter(insertion_mask.mapv(|pixel| pixel as u8).into_iter()).unwrap(); + let mask = Mat::from_exact_iter( + insertion_mask + .mapv(|pixel| if pixel { std::u8::MAX } else { 0 }) + .into_iter(), + ) + .unwrap() + .reshape_nd( + 1, + &insertion_mask + .shape() + .iter() + .map(|&dim| dim as i32) + .collect::>(), + ) + .unwrap(); let mut distances = Mat::default(); distance_transform(&mask, &mut distances, DIST_L1, DIST_MASK_3, CV_8U).unwrap(); let (furthest_point, _) = distances From 7a819523e5963fb1e859af02762aec553f1bd225 Mon Sep 17 00:00:00 2001 From: Garry O'Donnell Date: Mon, 17 Jul 2023 14:02:41 +0000 Subject: [PATCH 40/45] Silence chimp_chomp build --- chimp_chomp/build.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/chimp_chomp/build.rs b/chimp_chomp/build.rs index ed2f4abc..5fbcfc66 100644 --- a/chimp_chomp/build.rs +++ b/chimp_chomp/build.rs @@ -3,7 +3,6 @@ use std::{env, fs::copy, path::PathBuf}; const MODEL_FILE: &str = "chimp.onnx"; fn main() { - println!("Copying chimp.onnx"); let model_src = PathBuf::from(env::var("CARGO_MANIFEST_DIR").unwrap()).join(MODEL_FILE); let model_dst = PathBuf::from(env::var("OUT_DIR").unwrap()) .parent() From dd6e5943c29a246673a623907f5c25d19fcfcf65 Mon Sep 17 00:00:00 2001 From: Garry O'Donnell Date: Mon, 17 Jul 2023 14:25:54 +0000 Subject: [PATCH 41/45] Cascade processing errors to response --- Cargo.lock | 8 ++++++ chimp_chomp/Cargo.toml | 1 + chimp_chomp/src/image_loading.rs | 15 ++++++++--- chimp_chomp/src/inference.rs | 14 +++++++--- chimp_chomp/src/jobs.rs | 41 ++++++++++++++++++++++++----- chimp_chomp/src/main.rs | 21 ++++++++++----- chimp_chomp/src/postprocessing.rs | 43 ++++++++++++++++++++----------- chimp_chomp/src/well_centering.rs | 28 +++++++++++++------- chimp_protocol/src/lib.rs | 32 +++++++++++++++-------- 9 files changed, 147 insertions(+), 56 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 6a3a5fd3..4ec3406e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -182,6 +182,12 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "anyhow" +version = "1.0.72" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b13c32d80ecc7ab747b80c3784bce54ee8a7a0cc4fbda9bf4cda2cf6fe90854" + [[package]] name = "async-channel" version = "1.8.0" @@ -602,8 +608,10 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" name = "chimp_chomp" version = "0.1.0" dependencies = [ + "anyhow", "chimp_protocol", "clap 4.3.14", + "derive_more", "dotenvy", "futures", "futures-timer", diff --git a/chimp_chomp/Cargo.toml b/chimp_chomp/Cargo.toml index 4f1f4c90..8490b971 100644 --- a/chimp_chomp/Cargo.toml +++ b/chimp_chomp/Cargo.toml @@ -4,6 +4,7 @@ version = "0.1.0" edition = "2021" [dependencies] +anyhow = { version = "1.0.72" } chimp_protocol = { path = "../chimp_protocol" } clap = { workspace = true } derive_more = { workspace = true } diff --git a/chimp_chomp/src/image_loading.rs b/chimp_chomp/src/image_loading.rs index fa685466..63383e78 100644 --- a/chimp_chomp/src/image_loading.rs +++ b/chimp_chomp/src/image_loading.rs @@ -1,3 +1,4 @@ +use anyhow::Context; 
 use derive_more::Deref;
 use ndarray::{Array, Ix3};
 use opencv::{
@@ -67,11 +68,19 @@ pub fn load_image(
     path: impl AsRef<Path>,
     chimp_width: u32,
     chimp_height: u32,
-) -> (ChimpImage, WellImage) {
-    let image = imread(path.as_ref().to_str().unwrap(), IMREAD_COLOR).unwrap();
+) -> Result<(ChimpImage, WellImage), anyhow::Error> {
+    let image = imread(
+        path.as_ref()
+            .to_str()
+            .context("Image path contains non-UTF8 characters")?,
+        IMREAD_COLOR,
+    )?;
+    if image.empty() {
+        return Err(anyhow::Error::msg("No image data was loaded"));
+    }
     let well_image = prepare_well(&image);
     let chimp_image = prepare_chimp(&image, chimp_width as i32, chimp_height as i32);
-    (chimp_image, well_image)
+    Ok((chimp_image, well_image))
 }
diff --git a/chimp_chomp/src/inference.rs b/chimp_chomp/src/inference.rs
index c2980cf3..e2ac5af9 100644
--- a/chimp_chomp/src/inference.rs
+++ b/chimp_chomp/src/inference.rs
@@ -1,10 +1,11 @@
 use crate::image_loading::ChimpImage;
+use anyhow::Context;
 use chimp_protocol::Job;
 use itertools::{izip, Itertools};
 use ndarray::{Array1, Array2, Array3, Axis, Ix1, Ix2, Ix4};
 use ort::{
     tensor::{FromArray, InputTensor},
-    Environment, ExecutionProvider, GraphOptimizationLevel, OrtError, Session, SessionBuilder,
+    Environment, ExecutionProvider, GraphOptimizationLevel, Session, SessionBuilder,
 };
 use std::{env::current_exe, ops::Deref, sync::Arc};
 use tokio::sync::mpsc::{error::TryRecvError, Receiver, UnboundedSender};
@@ -14,16 +15,21 @@ pub type Labels = Array1<i64>;
 pub type Scores = Array1<f32>;
 pub type Masks = Array3<f32>;
-pub fn setup_inference_session() -> Result<Session, OrtError> {
+pub fn setup_inference_session() -> Result<Session, anyhow::Error> {
     let environment = Arc::new(
         Environment::builder()
             .with_name("CHiMP")
             .with_execution_providers([ExecutionProvider::cpu()])
             .build()?,
     );
-    SessionBuilder::new(&environment)?
+    Ok(SessionBuilder::new(&environment)?
         .with_optimization_level(GraphOptimizationLevel::Level3)?
-        .with_model_from_file(current_exe().unwrap().parent().unwrap().join("chimp.onnx"))
+        .with_model_from_file(
+            current_exe()?
+                .parent()
+                .context("Executable has no parent directory")?
+                .join("chimp.onnx"),
+        )?)
} fn do_inference( diff --git a/chimp_chomp/src/jobs.rs b/chimp_chomp/src/jobs.rs index 711f4f48..8162336b 100644 --- a/chimp_chomp/src/jobs.rs +++ b/chimp_chomp/src/jobs.rs @@ -9,7 +9,7 @@ use lapin::{ types::FieldTable, BasicProperties, Channel, Connection, Consumer, }; -use tokio::sync::mpsc::{OwnedPermit, Sender}; +use tokio::sync::mpsc::{OwnedPermit, UnboundedSender}; use url::Url; use uuid::Uuid; @@ -38,17 +38,25 @@ pub async fn consume_job( input_width: u32, input_height: u32, chimp_permit: OwnedPermit<(ChimpImage, Job)>, - well_image_tx: Sender<(WellImage, Job)>, + well_image_tx: UnboundedSender<(WellImage, Job)>, + error_tx: UnboundedSender<(anyhow::Error, Job)>, ) { let delivery = consumer.next().await.unwrap().unwrap(); delivery.ack(BasicAckOptions::default()).await.unwrap(); let job = Job::from_slice(&delivery.data).unwrap(); println!("Consumed Job: {job:?}"); - let (chimp_image, well_image) = load_image(job.file.clone(), input_width, input_height); - chimp_permit.send((chimp_image, job.clone())); - well_image_tx.send((well_image, job)).await.unwrap(); + match load_image(job.file.clone(), input_width, input_height) { + Ok((chimp_image, well_image)) => { + chimp_permit.send((chimp_image, job.clone())); + well_image_tx + .send((well_image, job)) + .map_err(|_| anyhow::Error::msg("Could not send well image")) + .unwrap() + } + Err(err) => error_tx.send((err, job)).unwrap(), + }; } pub async fn produce_response( @@ -63,7 +71,7 @@ pub async fn produce_response( "", &job.predictions_channel, BasicPublishOptions::default(), - &Response { + &Response::Success { job_id: job.id, insertion_point: contents.insertion_point, well_location, @@ -79,3 +87,24 @@ pub async fn produce_response( .await .unwrap(); } + +pub async fn produce_error(error: anyhow::Error, job: Job, rabbitmq_channel: Channel) { + println!("Producing error for: {job:?}"); + rabbitmq_channel + .basic_publish( + "", + &job.predictions_channel, + BasicPublishOptions::default(), + &Response::Failure { + job_id: job.id, + error: error.to_string(), + } + .to_vec() + .unwrap(), + BasicProperties::default(), + ) + .await + .unwrap() + .await + .unwrap(); +} diff --git a/chimp_chomp/src/main.rs b/chimp_chomp/src/main.rs index e10a8c35..b37d7cb9 100644 --- a/chimp_chomp/src/main.rs +++ b/chimp_chomp/src/main.rs @@ -10,9 +10,11 @@ mod well_centering; use crate::{ inference::{inference_worker, setup_inference_session}, - jobs::{consume_job, produce_response, setup_job_consumer, setup_rabbitmq_client}, - postprocessing::postprocess_inference, - well_centering::find_well_center, + jobs::{ + consume_job, produce_error, produce_response, setup_job_consumer, setup_rabbitmq_client, + }, + postprocessing::inference_postprocessing, + well_centering::well_centering, }; use chimp_protocol::{Circle, Job}; use clap::Parser; @@ -68,11 +70,12 @@ async fn run(args: Cli) { .unwrap(); let (chimp_image_tx, chimp_image_rx) = tokio::sync::mpsc::channel(batch_size); - let (well_image_tx, mut well_image_rx) = tokio::sync::mpsc::channel(batch_size); + let (well_image_tx, mut well_image_rx) = tokio::sync::mpsc::unbounded_channel(); let (well_location_tx, mut well_location_rx) = tokio::sync::mpsc::unbounded_channel::<(Circle, Job)>(); let (prediction_tx, mut prediction_rx) = tokio::sync::mpsc::unbounded_channel(); let (contents_tx, mut contents_rx) = tokio::sync::mpsc::unbounded_channel::<(Contents, Job)>(); + let (error_tx, mut error_rx) = tokio::sync::mpsc::unbounded_channel(); spawn(inference_worker( session, @@ -96,6 +99,10 @@ async fn run(args: Cli) { 
select! { biased; + Some((error, job)) = error_rx.recv() => { + tasks.spawn(produce_error(error, job, response_channel.clone())); + } + Some((well_location, job)) = well_location_rx.recv() => { if let Some(contents) = well_contents.remove(&job.id) { tasks.spawn(produce_response(contents, well_location, job, response_channel.clone())); @@ -114,15 +121,15 @@ async fn run(args: Cli) { chimp_permit = chimp_image_tx.clone().reserve_owned() => { let chimp_permit = chimp_permit.unwrap(); - tasks.spawn(consume_job(job_consumer.clone(), input_width, input_height, chimp_permit, well_image_tx.clone())); + tasks.spawn(consume_job(job_consumer.clone(), input_width, input_height, chimp_permit, well_image_tx.clone(), error_tx.clone())); } Some((well_image, job)) = well_image_rx.recv() => { - tasks.spawn(find_well_center(well_image, job, well_location_tx.clone())); + tasks.spawn(well_centering(well_image, job, well_location_tx.clone(), error_tx.clone())); } Some((bboxes, labels, _, masks, job)) = prediction_rx.recv() => { - tasks.spawn(postprocess_inference(bboxes, labels, masks, job, contents_tx.clone())); + tasks.spawn(inference_postprocessing(bboxes, labels, masks, job, contents_tx.clone(), error_tx.clone())); } _ = timeout => { diff --git a/chimp_chomp/src/postprocessing.rs b/chimp_chomp/src/postprocessing.rs index 09ba0624..2dec92ea 100644 --- a/chimp_chomp/src/postprocessing.rs +++ b/chimp_chomp/src/postprocessing.rs @@ -1,4 +1,5 @@ use crate::inference::{BBoxes, Labels, Masks}; +use anyhow::Context; use chimp_protocol::{BBox, Job, Point}; use itertools::izip; use ndarray::{Array2, ArrayView, ArrayView2, Ix1}; @@ -31,7 +32,7 @@ fn insertion_mask( mask } -fn optimal_insert_position(insertion_mask: Array2) -> Point { +fn optimal_insert_position(insertion_mask: Array2) -> Result { let mask = Mat::from_exact_iter( insertion_mask .mapv(|pixel| if pixel { std::u8::MAX } else { 0 }) @@ -53,11 +54,11 @@ fn optimal_insert_position(insertion_mask: Array2) -> Point { .iter::() .unwrap() .max_by(|(_, a), (_, b)| a.cmp(b)) - .unwrap(); - Point { + .context("No valid insertion points")?; + Ok(Point { x: furthest_point.x as usize, y: furthest_point.y as usize, - } + }) } fn bbox_from_array(bbox: ArrayView) -> BBox { @@ -73,9 +74,10 @@ fn find_drop_instance<'a>( labels: &Labels, bboxes: &BBoxes, masks: &'a Masks, -) -> Option<(BBox, ArrayView2<'a, f32>)> { +) -> Result<(BBox, ArrayView2<'a, f32>), anyhow::Error> { izip!(labels, bboxes.outer_iter(), masks.outer_iter()) .find_map(|(label, bbox, mask)| (*label == 1).then_some((bbox_from_array(bbox), mask))) + .context("No drop instances in prediction") } fn find_crystal_instances<'a>( @@ -88,24 +90,35 @@ fn find_crystal_instances<'a>( .collect() } -pub async fn postprocess_inference( +fn postprocess_inference( bboxes: BBoxes, labels: Labels, masks: Masks, - job: Job, - contents_tx: UnboundedSender<(Contents, Job)>, -) { - println!("Postprocessing: {job:?}"); - let (drop, drop_mask) = find_drop_instance(&labels, &bboxes, &masks).unwrap(); +) -> Result { + let (drop, drop_mask) = find_drop_instance(&labels, &bboxes, &masks)?; let (crystals, crystal_masks) = find_crystal_instances(&labels, &bboxes, &masks) .into_iter() .unzip(); let insertion_mask = insertion_mask(drop_mask, crystal_masks); - let insertion_point = optimal_insert_position(insertion_mask); - let contents = Contents { + let insertion_point = optimal_insert_position(insertion_mask)?; + Ok(Contents { drop, crystals, insertion_point, - }; - contents_tx.send((contents, job)).unwrap(); + }) +} + +pub async fn 
inference_postprocessing(
+    bboxes: BBoxes,
+    labels: Labels,
+    masks: Masks,
+    job: Job,
+    contents_tx: UnboundedSender<(Contents, Job)>,
+    error_tx: UnboundedSender<(anyhow::Error, Job)>,
+) {
+    println!("Postprocessing: {job:?}");
+    match postprocess_inference(bboxes, labels, masks) {
+        Ok(contents) => contents_tx.send((contents, job)).unwrap(),
+        Err(err) => error_tx.send((err, job)).unwrap(),
+    }
 }
diff --git a/chimp_chomp/src/well_centering.rs b/chimp_chomp/src/well_centering.rs
index af236f18..3729b92b 100644
--- a/chimp_chomp/src/well_centering.rs
+++ b/chimp_chomp/src/well_centering.rs
@@ -1,4 +1,5 @@
 use crate::image_loading::WellImage;
+use anyhow::Context;
 use chimp_protocol::{Circle, Job, Point};
 use opencv::{
     core::{Vec4f, Vector},
@@ -8,12 +9,7 @@ use opencv::{
 use std::ops::Deref;
 use tokio::sync::mpsc::UnboundedSender;
-pub async fn find_well_center(
-    image: WellImage,
-    job: Job,
-    well_location_tx: UnboundedSender<(Circle, Job)>,
-) {
-    println!("Finding Well Center for {job:?}");
+fn find_well_center(image: WellImage) -> Result<Circle, anyhow::Error> {
     let min_side = *image.deref().mat_size().iter().min().unwrap();
     let mut circles = Vector::<Vec4f>::new();
     hough_circles(
@@ -31,13 +27,25 @@ pub async fn find_well_center(
     let well_location = circles
         .into_iter()
         .max_by(|&a, &b| a[3].total_cmp(&b[3]))
-        .unwrap();
-    let well_location = Circle {
+        .context("No circles found in image")?;
+    Ok(Circle {
         center: Point {
             x: well_location[0] as usize,
             y: well_location[1] as usize,
         },
         radius: well_location[2],
-    };
-    well_location_tx.send((well_location, job)).unwrap()
+    })
+}
+
+pub async fn well_centering(
+    image: WellImage,
+    job: Job,
+    well_location_tx: UnboundedSender<(Circle, Job)>,
+    error_tx: UnboundedSender<(anyhow::Error, Job)>,
+) {
+    println!("Finding Well Center for {job:?}");
+    match find_well_center(image) {
+        Ok(well_center) => well_location_tx.send((well_center, job)).unwrap(),
+        Err(err) => error_tx.send((err, job)).unwrap(),
+    }
 }
diff --git a/chimp_protocol/src/lib.rs b/chimp_protocol/src/lib.rs
index 97ae4693..07d3e83e 100644
--- a/chimp_protocol/src/lib.rs
+++ b/chimp_protocol/src/lib.rs
@@ -31,17 +31,27 @@ impl Job {
 /// A set of predictions which apply to a single image.
 #[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct Response {
-    /// The unique identifier of the requesting [`Request`].
-    pub job_id: Uuid,
-    /// The proposed point for solvent insertion.
-    pub insertion_point: Point,
-    /// The location of the well centroid and radius.
-    pub well_location: Circle,
-    /// A bounding box emcompasing the solvent.
-    pub drop: BBox,
-    /// A set of bounding boxes, each emcompasing a crystal.
-    pub crystals: Vec<BBox>,
+pub enum Response {
+    /// The image was processed successfully, producing the contained predictions.
+    Success {
+        /// The unique identifier of the requesting [`Job`].
+        job_id: Uuid,
+        /// The proposed point for solvent insertion.
+        insertion_point: Point,
+        /// The location of the well centroid and radius.
+        well_location: Circle,
+        /// A bounding box encompassing the solvent.
+        drop: BBox,
+        /// A set of bounding boxes, each encompassing a crystal.
+        crystals: Vec<BBox>,
+    },
+    /// Image processing failed, with the contained error.
+    Failure {
+        /// The unique identifier of the requesting [`Job`].
+        job_id: Uuid,
+        /// A description of the error encountered.
+ error: String, + }, } impl Response { From 26e01fb8641d1a4ce21d8358b216406e7f1bcba0 Mon Sep 17 00:00:00 2001 From: Garry O'Donnell Date: Mon, 17 Jul 2023 14:26:29 +0000 Subject: [PATCH 42/45] Use single frame model with rockmaker image dimensions --- chimp_chomp/chimp.onnx | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/chimp_chomp/chimp.onnx b/chimp_chomp/chimp.onnx index 75985761..82724c1c 100644 --- a/chimp_chomp/chimp.onnx +++ b/chimp_chomp/chimp.onnx @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:37311eb9e26f488b06c5a644c80e7025c589c383a99bbe91e9a29b6e274a3396 -size 176578499 +oid sha256:62905cc5a5a2ed3118467ee5d44c435786eb542be9cb31667441b2c78622c6a3 +size 176151256 From b5c990dad88c689f84f3194244f85a11a52cbba9 Mon Sep 17 00:00:00 2001 From: Garry O'Donnell Date: Mon, 17 Jul 2023 16:05:21 +0000 Subject: [PATCH 43/45] Add well centering tests --- Cargo.lock | 10 +++++ chimp_chomp/Cargo.toml | 3 ++ chimp_chomp/src/image_loading.rs | 2 +- chimp_chomp/src/well_centering.rs | 61 ++++++++++++++++++++++++++++++- 4 files changed, 73 insertions(+), 3 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4ec3406e..753074c5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -188,6 +188,15 @@ version = "1.0.72" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3b13c32d80ecc7ab747b80c3784bce54ee8a7a0cc4fbda9bf4cda2cf6fe90854" +[[package]] +name = "approx" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6" +dependencies = [ + "num-traits", +] + [[package]] name = "async-channel" version = "1.8.0" @@ -609,6 +618,7 @@ name = "chimp_chomp" version = "0.1.0" dependencies = [ "anyhow", + "approx", "chimp_protocol", "clap 4.3.14", "derive_more", diff --git a/chimp_chomp/Cargo.toml b/chimp_chomp/Cargo.toml index 8490b971..fd25a200 100644 --- a/chimp_chomp/Cargo.toml +++ b/chimp_chomp/Cargo.toml @@ -27,3 +27,6 @@ ort = { version = "1.14.8", default-features = false, features = [ tokio = { workspace = true, features = ["sync"] } url = { workspace = true } uuid = { workspace = true } + +[dev-dependencies] +approx = { version = "0.5.1" } diff --git a/chimp_chomp/src/image_loading.rs b/chimp_chomp/src/image_loading.rs index 63383e78..48846aa0 100644 --- a/chimp_chomp/src/image_loading.rs +++ b/chimp_chomp/src/image_loading.rs @@ -10,7 +10,7 @@ use opencv::{ use std::path::Path; #[derive(Debug, Deref)] -pub struct WellImage(Mat); +pub struct WellImage(pub Mat); #[derive(Debug, Deref)] pub struct ChimpImage(Array); diff --git a/chimp_chomp/src/well_centering.rs b/chimp_chomp/src/well_centering.rs index 3729b92b..b51a74b9 100644 --- a/chimp_chomp/src/well_centering.rs +++ b/chimp_chomp/src/well_centering.rs @@ -9,7 +9,7 @@ use opencv::{ use std::ops::Deref; use tokio::sync::mpsc::UnboundedSender; -fn find_well_center(image: WellImage) -> Result { +fn find_well_location(image: WellImage) -> Result { let min_side = *image.deref().mat_size().iter().min().unwrap(); let mut circles = Vector::::new(); hough_circles( @@ -44,8 +44,65 @@ pub async fn well_centering( error_tx: UnboundedSender<(anyhow::Error, Job)>, ) { println!("Finding Well Center for {job:?}"); - match find_well_center(image) { + match find_well_location(image) { Ok(well_center) => well_location_tx.send((well_center, job)).unwrap(), Err(err) => error_tx.send((err, job)).unwrap(), } } + +#[cfg(test)] +mod tests { + use crate::{image_loading::WellImage, 
well_centering::find_well_location}; + use approx::assert_relative_eq; + use opencv::{ + core::{Mat, Point_, Scalar, CV_8UC1}, + imgproc::{circle, LINE_8}, + }; + + #[test] + fn well_center_found() { + const CENTER_X: usize = 654; + const CENTER_Y: usize = 321; + const RADIUS: f32 = 480.0; + const THICKNESS: i32 = 196; + + let mut test_image = Mat::new_nd_with_default( + &[1024, 1224], + CV_8UC1, + Scalar::new( + std::u8::MAX as f64, + std::u8::MAX as f64, + std::u8::MAX as f64, + std::u8::MAX as f64, + ), + ) + .unwrap(); + circle( + &mut test_image, + Point_ { + x: CENTER_X as i32, + y: CENTER_Y as i32, + }, + RADIUS as i32 + THICKNESS / 2, + Scalar::new(0_f64, 0_f64, 0_f64, std::u8::MAX as f64), + THICKNESS, + LINE_8, + 0, + ) + .unwrap(); + + let location = find_well_location(WellImage(test_image)).unwrap(); + + assert_relative_eq!( + CENTER_X as f64, + location.center.x as f64, + max_relative = 8.0 + ); + assert_relative_eq!( + CENTER_Y as f64, + location.center.y as f64, + max_relative = 8.0 + ); + assert_relative_eq!(RADIUS, location.radius, max_relative = 8.0) + } +} From a34bb867b7741f6ba6351d0bf553f85b12174a68 Mon Sep 17 00:00:00 2001 From: Garry O'Donnell Date: Mon, 17 Jul 2023 16:28:49 +0000 Subject: [PATCH 44/45] Add insertion point finding tests --- chimp_chomp/src/postprocessing.rs | 58 ++++++++++++++++++++++++++----- 1 file changed, 50 insertions(+), 8 deletions(-) diff --git a/chimp_chomp/src/postprocessing.rs b/chimp_chomp/src/postprocessing.rs index 2dec92ea..3e39da06 100644 --- a/chimp_chomp/src/postprocessing.rs +++ b/chimp_chomp/src/postprocessing.rs @@ -32,24 +32,26 @@ fn insertion_mask( mask } -fn optimal_insert_position(insertion_mask: Array2) -> Result { - let mask = Mat::from_exact_iter( - insertion_mask - .mapv(|pixel| if pixel { std::u8::MAX } else { 0 }) +fn ndarray_mask_into_opencv_mat(mask: Array2) -> Mat { + Mat::from_exact_iter( + mask.mapv(|pixel| if pixel { std::u8::MAX } else { 0 }) .into_iter(), ) .unwrap() .reshape_nd( 1, - &insertion_mask + &mask .shape() .iter() .map(|&dim| dim as i32) .collect::>(), ) - .unwrap(); + .unwrap() +} + +fn optimal_insert_position(insertion_mask: Mat) -> Result { let mut distances = Mat::default(); - distance_transform(&mask, &mut distances, DIST_L1, DIST_MASK_3, CV_8U).unwrap(); + distance_transform(&insertion_mask, &mut distances, DIST_L1, DIST_MASK_3, CV_8U).unwrap(); let (furthest_point, _) = distances .iter::() .unwrap() @@ -99,7 +101,7 @@ fn postprocess_inference( let (crystals, crystal_masks) = find_crystal_instances(&labels, &bboxes, &masks) .into_iter() .unzip(); - let insertion_mask = insertion_mask(drop_mask, crystal_masks); + let insertion_mask = ndarray_mask_into_opencv_mat(insertion_mask(drop_mask, crystal_masks)); let insertion_point = optimal_insert_position(insertion_mask)?; Ok(Contents { drop, @@ -122,3 +124,43 @@ pub async fn inference_postprocessing( Err(err) => error_tx.send((err, job)).unwrap(), } } + +#[cfg(test)] +mod tests { + use super::optimal_insert_position; + use opencv::{ + core::{Point_, Scalar, CV_8UC1}, + imgproc::{circle, LINE_8}, + prelude::Mat, + }; + + #[test] + fn optimal_insert_found() { + let mut test_image = Mat::new_nd_with_default( + &[1024, 1224], + CV_8UC1, + Scalar::new(0_f64, 0_f64, 0_f64, std::u8::MAX as f64), + ) + .unwrap(); + circle( + &mut test_image, + Point_::new(256, 512), + 128, + Scalar::new( + std::u8::MAX as f64, + std::u8::MAX as f64, + std::u8::MAX as f64, + std::u8::MAX as f64, + ), + -1, + LINE_8, + 0, + ) + .unwrap(); + + let position = 
optimal_insert_position(test_image).unwrap(); + + assert_eq!(256, position.x); + assert_eq!(512, position.y); + } +} From 7813cb747a84f2fd9a135e16c1f6d6c499074020 Mon Sep 17 00:00:00 2001 From: Garry O'Donnell Date: Tue, 18 Jul 2023 09:58:48 +0000 Subject: [PATCH 45/45] Add chimp_chomp docstrings --- chimp_chomp/src/image_loading.rs | 24 +++++++++++++++++++----- chimp_chomp/src/inference.rs | 14 ++++++++++++++ chimp_chomp/src/jobs.rs | 15 +++++++++++++++ chimp_chomp/src/main.rs | 8 ++++++++ chimp_chomp/src/postprocessing.rs | 22 ++++++++++++++++++++++ chimp_chomp/src/well_centering.rs | 10 ++++++++++ 6 files changed, 88 insertions(+), 5 deletions(-) diff --git a/chimp_chomp/src/image_loading.rs b/chimp_chomp/src/image_loading.rs index 48846aa0..86e42c48 100644 --- a/chimp_chomp/src/image_loading.rs +++ b/chimp_chomp/src/image_loading.rs @@ -9,12 +9,15 @@ use opencv::{ }; use std::path::Path; +/// A grayscale image of the well in [W, H, C] format. #[derive(Debug, Deref)] pub struct WellImage(pub Mat); +/// A RGB image of the well in [C, W, H] format. #[derive(Debug, Deref)] pub struct ChimpImage(Array); +/// Converts an image from a [`Mat`] in BGR and ordered in [W, H, C] to a [`Array`] in RGB and ordered in [C, W, H] and resizes it to the input dimensions of the model. fn prepare_chimp(image: &Mat, width: i32, height: i32) -> ChimpImage { let mut resized_image = Mat::default(); resize( @@ -58,17 +61,17 @@ fn prepare_chimp(image: &Mat, width: i32, height: i32) -> ChimpImage { ChimpImage(chimp_image) } +/// Converts an image from BGR to grayscale. fn prepare_well(image: &Mat) -> WellImage { let mut well_image = Mat::default(); cvt_color(&image, &mut well_image, COLOR_BGR2GRAY, 0).unwrap(); WellImage(well_image) } -pub fn load_image( - path: impl AsRef, - chimp_width: u32, - chimp_height: u32, -) -> Result<(ChimpImage, WellImage), anyhow::Error> { +/// Reads an image from file. +/// +/// Returns an [`anyhow::Error`] if the image could not be read or is empty. +fn read_image(path: impl AsRef) -> Result { let image = imread( path.as_ref() .to_str() @@ -78,7 +81,18 @@ pub fn load_image( if image.empty() { return Err(anyhow::Error::msg("No image data was loaded")); } + Ok(image) +} +/// Reads an image from file and prepares both a [`ChimpImage`] and a [`WellImage`]. +/// +/// Returns an [`anyhow::Error`] if the image could not be read or is empty. +pub fn load_image( + path: impl AsRef, + chimp_width: u32, + chimp_height: u32, +) -> Result<(ChimpImage, WellImage), anyhow::Error> { + let image = read_image(path)?; let well_image = prepare_well(&image); let chimp_image = prepare_chimp(&image, chimp_width as i32, chimp_height as i32); diff --git a/chimp_chomp/src/inference.rs b/chimp_chomp/src/inference.rs index e2ac5af9..afd07afc 100644 --- a/chimp_chomp/src/inference.rs +++ b/chimp_chomp/src/inference.rs @@ -10,11 +10,18 @@ use ort::{ use std::{env::current_exe, ops::Deref, sync::Arc}; use tokio::sync::mpsc::{error::TryRecvError, Receiver, UnboundedSender}; +/// The raw box predictor output of a MaskRCNN. pub type BBoxes = Array2; +/// The raw label output of a MaskRCNN. pub type Labels = Array1; +/// The raw scores output of a MaskRCNN. pub type Scores = Array1; +/// The raw masks output of a MaskRCNN. pub type Masks = Array3; +/// Starts an inference session by setting up the ONNX Runtime environment and loading the model. +/// +/// Returns an [`anyhow::Error`] if the environment could not be built or if the model could not be loaded. 
 pub fn setup_inference_session() -> Result<Session, anyhow::Error> {
     let environment = Arc::new(
         Environment::builder()
@@ -32,6 +39,9 @@ pub fn setup_inference_session() -> Result<Session, anyhow::Error> {
         )?)
 }
+/// Performs inference on a batch of images, dummy images are used to pad the tensor if underfull.
+///
+/// Returns a set of predictions, where each instance corresponds to an input image; order is maintained.
 fn do_inference(
     session: &Session,
     images: &[ChimpImage],
@@ -85,6 +95,10 @@ fn do_inference(
         .collect()
 }
+/// Listens to a [`Receiver`] for instances of [`ChimpImage`] and performs batch inference on these.
+///
+/// Each pass, all available images in the [`tokio::sync::mpsc::channel`] - up to the batch size - are taken and passed to the model for inference.
+/// Model predictions are sent over a [`tokio::sync::mpsc::unbounded_channel`].
 pub async fn inference_worker(
     session: Session,
     batch_size: usize,
diff --git a/chimp_chomp/src/jobs.rs b/chimp_chomp/src/jobs.rs
index 8162336b..a82a6380 100644
--- a/chimp_chomp/src/jobs.rs
+++ b/chimp_chomp/src/jobs.rs
@@ -13,10 +13,17 @@ use tokio::sync::mpsc::{OwnedPermit, UnboundedSender};
 use url::Url;
 use uuid::Uuid;
+/// Creates a RabbitMQ [`Connection`] with [`Default`] [`lapin::ConnectionProperties`].
+///
+/// Returns a [`lapin::Error`] if a connection could not be established.
 pub async fn setup_rabbitmq_client(address: Url) -> Result<Connection, lapin::Error> {
     lapin::Connection::connect(address.as_str(), lapin::ConnectionProperties::default()).await
 }
+/// Joins a RabbitMQ channel, creating a [`Consumer`] with [`Default`] [`BasicConsumeOptions`] and [`FieldTable`].
+/// The consumer tag is generated following the format `chimp_chomp_${`[`Uuid::new_v4`]`}`.
+///
+/// Returns a [`lapin::Error`] if the requested channel is not available.
 pub async fn setup_job_consumer(
     rabbitmq_channel: Channel,
     channel: impl AsRef<str>,
@@ -33,6 +40,12 @@ pub async fn setup_job_consumer(
         .await
 }
+/// Reads a message from the [`lapin::Consumer`] then loads and prepares the requested image for downstream processing.
+///
+/// An [`OwnedPermit`] to send to the chimp [`tokio::sync::mpsc::channel`] is required such that backpressure is propagated to message consumption.
+///
+/// The prepared images are sent over a [`tokio::sync::mpsc::channel`] and [`tokio::sync::mpsc::unbounded_channel`] if successful.
+/// An [`anyhow::Error`] is sent if the image could not be read or is empty.
 pub async fn consume_job(
     mut consumer: Consumer,
     input_width: u32,
@@ -59,6 +72,7 @@ pub async fn consume_job(
     };
 }
+/// Takes the results of postprocessing and well centering and publishes a [`Response::Success`] to the RabbitMQ [`Channel`] provided by the [`Job`].
 pub async fn produce_response(
     contents: Contents,
     well_location: Circle,
@@ -88,6 +102,7 @@ pub async fn produce_response(
     .unwrap();
 }
+/// Takes an error generated in one of the prior stages and publishes a [`Response::Failure`] to the RabbitMQ [`Channel`] provided by the [`Job`].
 pub async fn produce_error(error: anyhow::Error, job: Job, rabbitmq_channel: Channel) {
     println!("Producing error for: {job:?}");
     rabbitmq_channel
diff --git a/chimp_chomp/src/main.rs b/chimp_chomp/src/main.rs
index b37d7cb9..3083d86e 100644
--- a/chimp_chomp/src/main.rs
+++ b/chimp_chomp/src/main.rs
@@ -1,11 +1,17 @@
 #![forbid(unsafe_code)]
 #![warn(missing_docs)]
+#![warn(clippy::missing_docs_in_private_items)]
 #![doc=include_str!("../README.md")]
+/// Utilities for loading images.
 mod image_loading;
+/// Neural Network inference with [`ort`].
diff --git a/chimp_chomp/src/main.rs b/chimp_chomp/src/main.rs
index b37d7cb9..3083d86e 100644
--- a/chimp_chomp/src/main.rs
+++ b/chimp_chomp/src/main.rs
@@ -1,11 +1,17 @@
 #![forbid(unsafe_code)]
 #![warn(missing_docs)]
+#![warn(clippy::missing_docs_in_private_items)]
 #![doc=include_str!("../README.md")]
 
+/// Utilities for loading images.
 mod image_loading;
+/// Neural Network inference with [`ort`].
 mod inference;
+/// RabbitMQ [`Job`] queue consumption and [`Response`] publishing.
 mod jobs;
+/// Neural Network inference postprocessing with optimal insertion point finding.
 mod postprocessing;
+/// Well localisation.
 mod well_centering;
 
 use crate::{
@@ -25,6 +31,7 @@ use std::{collections::HashMap, time::Duration};
 use tokio::{select, spawn, task::JoinSet};
 use url::Url;
 
+/// An inference worker for the Crystal Hits in My Plate (CHiMP) neural network.
 #[derive(Debug, Parser)]
 #[command(author, version, about, long_about=None)]
 struct Cli {
@@ -56,6 +63,7 @@ fn main() {
     runtime.block_on(run(args));
 }
 
+#[allow(clippy::missing_docs_in_private_items)]
 async fn run(args: Cli) {
     let session = setup_inference_session().unwrap();
     let input_width = session.inputs[0].dimensions[3].unwrap();
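
The postprocessing docstrings below describe masking the crystals out of the drop and then choosing the point furthest from any invalid pixel. As a rough illustration of that idea, here is a brute-force stand-in for the OpenCV distance transform used by the crate, run on a tiny hand-written mask rather than real model output (the helper name and values are made up):

/// Finds the valid cell furthest from any invalid cell: the same idea as a
/// distance transform followed by a max-location search.
fn furthest_valid_cell(mask: &[Vec<bool>]) -> Option<(usize, usize)> {
    let mut best: Option<((usize, usize), f64)> = None;
    for (y, row) in mask.iter().enumerate() {
        for (x, &valid) in row.iter().enumerate() {
            if !valid {
                continue;
            }
            // Distance from (x, y) to the nearest invalid cell.
            let mut distance = f64::INFINITY;
            for (oy, other_row) in mask.iter().enumerate() {
                for (ox, &other_valid) in other_row.iter().enumerate() {
                    if !other_valid {
                        let (dx, dy) = (x as f64 - ox as f64, y as f64 - oy as f64);
                        distance = distance.min((dx * dx + dy * dy).sqrt());
                    }
                }
            }
            if best.map_or(true, |(_, d)| distance > d) {
                best = Some(((x, y), distance));
            }
        }
    }
    best.map(|(point, _)| point)
}

fn main() {
    // true = inside the drop and not covered by a crystal.
    let mask = vec![
        vec![false, false, false, false, false],
        vec![false, true, true, true, false],
        vec![false, true, true, true, false],
        vec![false, true, true, true, false],
        vec![false, false, false, false, false],
    ];
    assert_eq!(furthest_valid_cell(&mask), Some((2, 2)));
}
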
diff --git a/chimp_chomp/src/postprocessing.rs b/chimp_chomp/src/postprocessing.rs
index 3e39da06..95c04702 100644
--- a/chimp_chomp/src/postprocessing.rs
+++ b/chimp_chomp/src/postprocessing.rs
@@ -10,15 +10,21 @@ use opencv::{
 };
 use tokio::sync::mpsc::UnboundedSender;
 
+/// The predicted contents of a well image.
 #[derive(Debug)]
 pub struct Contents {
+    /// The optimal point at which solvent should be inserted.
     pub insertion_point: Point,
+    /// A bounding box enclosing the drop of solution.
     pub drop: BBox,
+    /// A set of bounding boxes enclosing each crystal in the drop.
     pub crystals: Vec,
 }
 
+/// The threshold to apply to the raw MaskRCNN [`Masks`] to generate a binary mask.
 const PREDICTION_THRESHOLD: f32 = 0.5;
 
+/// Creates a mask of valid insertion positions by adding all pixels in the drop mask and subsequently subtracting those in the crystal masks.
 fn insertion_mask(
     drop_mask: ArrayView2,
     crystal_masks: Vec>,
@@ -32,6 +38,7 @@ fn insertion_mask(
     mask
 }
 
+/// Converts an [`Array2`] into a [`Mat`] of type [`CV_8U`] with the same dimensions.
 fn ndarray_mask_into_opencv_mat(mask: Array2) -> Mat {
     Mat::from_exact_iter(
         mask.mapv(|pixel| if pixel { std::u8::MAX } else { 0 })
@@ -49,6 +56,9 @@ fn ndarray_mask_into_opencv_mat(mask: Array2) -> Mat {
     .unwrap()
 }
 
+/// Performs a distance transform to find the point in the mask which is furthest from any invalid region.
+///
+/// Returns an [`anyhow::Error`] if no valid insertion point was found.
 fn optimal_insert_position(insertion_mask: Mat) -> Result {
     let mut distances = Mat::default();
     distance_transform(&insertion_mask, &mut distances, DIST_L1, DIST_MASK_3, CV_8U).unwrap();
@@ -63,6 +73,7 @@ fn optimal_insert_position(insertion_mask: Mat) -> Result
     })
 }
 
+/// Converts an [`ArrayView`] into a [`BBox`].
 fn bbox_from_array(bbox: ArrayView) -> BBox {
     BBox {
         left: bbox[0],
@@ -72,6 +83,9 @@ fn bbox_from_array(bbox: ArrayView) -> BBox {
     }
 }
 
+/// Finds the first instance which is labelled as a drop.
+///
+/// Returns an [`anyhow::Error`] if no drop instances were found.
 fn find_drop_instance<'a>(
     labels: &Labels,
     bboxes: &BBoxes,
@@ -82,6 +96,7 @@ fn find_drop_instance<'a>(
         .context("No drop instances in prediction")
 }
 
+/// Finds all instances which are labelled as crystals.
 fn find_crystal_instances<'a>(
     labels: &Labels,
     bboxes: &BBoxes,
@@ -92,6 +107,9 @@ fn find_crystal_instances<'a>(
         .collect()
 }
 
+/// Takes the results of inference on an image and uses them to produce useful regional data and an optimal insertion point.
+///
+/// Returns an [`anyhow::Error`] if no drop instances could be found or if no valid insertion point was found.
 fn postprocess_inference(
     bboxes: BBoxes,
     labels: Labels,
@@ -110,6 +128,10 @@ fn postprocess_inference(
     })
 }
 
+/// Takes the results of inference on an image and uses them to produce useful regional data and an optimal insertion point.
+///
+/// The extracted [`Contents`] are sent over a [`tokio::sync::mpsc::unbounded_channel`] if successful.
+/// An [`anyhow::Error`] is sent if no drop instances were found or if no valid insertion point was found.
 pub async fn inference_postprocessing(
     bboxes: BBoxes,
     labels: Labels,
diff --git a/chimp_chomp/src/well_centering.rs b/chimp_chomp/src/well_centering.rs
index b51a74b9..64424e75 100644
--- a/chimp_chomp/src/well_centering.rs
+++ b/chimp_chomp/src/well_centering.rs
@@ -9,6 +9,12 @@ use opencv::{
 };
 use std::ops::Deref;
 use tokio::sync::mpsc::UnboundedSender;
+/// Uses a Canny edge detector and a Hough circle transform to localise a [`Circle`] of high contrast in the image.
+///
+/// The circle is assumed to have a radius in [⅜ l, ½ l), where `l` denotes the shortest edge length of the image.
+/// The circle with the most counts is selected.
+///
+/// Returns an [`anyhow::Error`] if no circles were found.
 fn find_well_location(image: WellImage) -> Result {
     let min_side = *image.deref().mat_size().iter().min().unwrap();
     let mut circles = Vector::::new();
@@ -37,6 +43,10 @@ fn find_well_location(image: WellImage) -> Result {
     })
 }
 
+/// Takes a grayscale image of the well and finds the center and radius.
+///
+/// The extracted [`Circle`] is sent over a [`tokio::sync::mpsc::unbounded_channel`] if successful.
+/// An [`anyhow::Error`] is sent if no circles were found.
 pub async fn well_centering(
     image: WellImage,
     job: Job,