diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 790ef22..2459ece 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -319,6 +319,12 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" +[[package]] +name = "once_cell" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" + [[package]] name = "papergrid" version = "0.11.0" @@ -451,11 +457,12 @@ checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" [[package]] name = "rism" -version = "1.1.1" +version = "1.2.1" dependencies = [ "clap", "console", "indicatif", + "once_cell", "rand", "serde", "tabled", diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 6026ca9..e8692e8 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rism" -version = "1.1.1" +version = "1.2.1" edition = "2021" [lib] @@ -18,6 +18,7 @@ serde = { version = "1.0.197", features = ["derive"] } indicatif = "0.17.8" console = "0.15.8" tabled = "0.15.0" +once_cell = "1.19.0" z3 = { version = "0.12.1", optional = true} [features] diff --git a/rust/src/main.rs b/rust/src/main.rs index 8d4ed49..d49bdb6 100644 --- a/rust/src/main.rs +++ b/rust/src/main.rs @@ -1,3 +1,5 @@ +use std::collections::HashMap; +use std::thread::ThreadId; use clap::Parser; use serde::Serialize; use rism::rism_classic::run; @@ -6,6 +8,8 @@ use rism::io::{import_students, import_seminars}; #[cfg(feature = "model-checking")] use rism::rism_model_checking::run_model_check; use console::style; +use indicatif::{MultiProgress, ProgressBar, ProgressStyle}; +use once_cell::sync::Lazy; use tabled::{Table, Tabled}; use tabled::settings::Style; @@ -53,19 +57,43 @@ struct PrintableStudent { p_seminar: String, } +static mut PROGRESS_BARS: Lazy> = Lazy::new(HashMap::new); +static mut MULTI_PROGRESS: Lazy = Lazy::new(MultiProgress::new); + +fn update_progress(thread: ThreadId, p: u32, total: u32) { + unsafe { + let opt_prog = PROGRESS_BARS.get(&thread); + match opt_prog { + None => { + let loc_prog = ProgressBar::new(total as u64); + MULTI_PROGRESS.add(loc_prog.clone()); + loc_prog.set_position(p as u64); + loc_prog.set_style(ProgressStyle::with_template("[{per_sec:15}] {wide_bar:.cyan/blue} {percent:>2}% ({pos:>8}/{len:>8}) ") + .unwrap() + .progress_chars("#>-")); + PROGRESS_BARS.insert(thread, loc_prog); + } + Some(prog) => { + prog.set_position(p as u64); + } + } + } +} + fn main() { let args = Args::parse(); let seminars = import_seminars(&args.seminars_path); let students = import_students(&args.students_path, &seminars); - let best_iteration = match args.variant { - ExecutionVariants::Classic => Some(run(&students, &seminars, args.iterations, get_default_points(), args.threads)), + ExecutionVariants::Classic => Some(run(&students, &seminars, args.iterations, get_default_points(), args.threads, update_progress)), #[cfg(feature = "model-checking")] ExecutionVariants::ModelChecking => run_model_check(&students, &seminars, get_default_points()) }; - if let Some(bi_unwr) = best_iteration { + unsafe { PROGRESS_BARS.iter().for_each(|(_, p)| p.finish_and_clear()); } + + if let Some(bi_unwr) = best_iteration { let mut students_table_data = Vec::new(); for a in &bi_unwr.assignments { diff --git a/rust/src/rism_classic/mod.rs b/rust/src/rism_classic/mod.rs index 8bea616..d1c83fe 100644 --- a/rust/src/rism_classic/mod.rs +++ b/rust/src/rism_classic/mod.rs @@ -4,6 +4,7 @@ use rand::seq::SliceRandom; use crate::constants::Points; use std::{cmp, thread}; use std::sync::{Arc, Mutex}; +use std::thread::ThreadId; fn find_possible_assignment<'a>(wishes: &'a Vec, points: &Points, iteration: &RismResult) -> (Option<&'a Seminar>, u16) { return if iteration.get_capacity(&wishes[0]) > 0 { @@ -17,7 +18,7 @@ fn find_possible_assignment<'a>(wishes: &'a Vec, points: &Points, itera }; } -pub fn run<'a>(students: &'a Vec, seminars: &'a Vec, iterations: u32, points: Points, threads: u16) -> RismResult<'a> { +pub fn run<'a>(students: &'a Vec, seminars: &'a Vec, iterations: u32, points: Points, threads: u16, progress: fn(ThreadId, u32, u32)) -> RismResult<'a> { let results = Arc::new(Mutex::new(Vec::new())); @@ -26,7 +27,7 @@ pub fn run<'a>(students: &'a Vec, seminars: &'a Vec, iteration for _ in 0..threads { s.spawn(|| { let results_arc = results.clone(); - let res = run_algorithm(students, seminars, thread_iterations, points.clone()); + let res = run_algorithm(students, seminars, thread_iterations, points.clone(), progress); let mut res_unwr = results_arc.lock().unwrap(); (*res_unwr).push(res); }); @@ -41,11 +42,11 @@ pub fn run<'a>(students: &'a Vec, seminars: &'a Vec, iteration } -pub fn run_algorithm<'a>(students: &'a Vec, seminars: &'a Vec, iterations: u32, points: Points) -> RismResult<'a> { +pub fn run_algorithm<'a>(students: &'a Vec, seminars: &'a Vec, iterations: u32, points: Points, progress: fn(ThreadId, u32, u32)) -> RismResult<'a> { let mut best_iteration: Option = None; - // TODO Reintroduce progress - for _ in 0..iterations { + for p in 0..iterations { + progress(thread::current().id(), p, iterations); let mut shuffled_indices: Vec = (0..students.len()).collect(); shuffled_indices.shuffle(&mut thread_rng());