Skip to content

Commit

Permalink
feat(rust): better progress handling
Browse files Browse the repository at this point in the history
  • Loading branch information
neferin12 committed Apr 28, 2024
1 parent 16ef2da commit 38c10e0
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 8 deletions.
7 changes: 7 additions & 0 deletions rust/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions rust/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ serde = { version = "1.0.197", features = ["derive"] }
indicatif = "0.17.8"
console = "0.15.8"
tabled = "0.15.0"
once_cell = "1.19.0"
z3 = { version = "0.12.1", optional = true}

[features]
Expand Down
34 changes: 31 additions & 3 deletions rust/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::collections::HashMap;
use std::thread::ThreadId;
use clap::Parser;
use serde::Serialize;
use rism::rism_classic::run;
Expand All @@ -6,6 +8,8 @@ use rism::io::{import_students, import_seminars};
#[cfg(feature = "model-checking")]
use rism::rism_model_checking::run_model_check;
use console::style;
use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
use once_cell::sync::Lazy;
use tabled::{Table, Tabled};
use tabled::settings::Style;

Expand Down Expand Up @@ -53,19 +57,43 @@ struct PrintableStudent {
p_seminar: String,
}

static mut PROGRESS_BARS: Lazy<HashMap<ThreadId, ProgressBar>> = Lazy::new(HashMap::new);
static mut MULTI_PROGRESS: Lazy<MultiProgress> = Lazy::new(MultiProgress::new);

fn update_progress(thread: ThreadId, p: u32, total: u32) {
unsafe {
let opt_prog = PROGRESS_BARS.get(&thread);
match opt_prog {
None => {
let loc_prog = ProgressBar::new(total as u64);
MULTI_PROGRESS.add(loc_prog.clone());
loc_prog.set_position(p as u64);
loc_prog.set_style(ProgressStyle::with_template("[{per_sec:15}] {wide_bar:.cyan/blue} {percent:>2}% ({pos:>8}/{len:>8}) ")
.unwrap()
.progress_chars("#>-"));
PROGRESS_BARS.insert(thread, loc_prog);
}
Some(prog) => {
prog.set_position(p as u64);
}
}
}
}

fn main() {
let args = Args::parse();
let seminars = import_seminars(&args.seminars_path);
let students = import_students(&args.students_path, &seminars);


let best_iteration = match args.variant {
ExecutionVariants::Classic => Some(run(&students, &seminars, args.iterations, get_default_points(), args.threads)),
ExecutionVariants::Classic => Some(run(&students, &seminars, args.iterations, get_default_points(), args.threads, update_progress)),
#[cfg(feature = "model-checking")]
ExecutionVariants::ModelChecking => run_model_check(&students, &seminars, get_default_points())
};
if let Some(bi_unwr) = best_iteration {

unsafe { PROGRESS_BARS.iter().for_each(|(_, p)| p.finish_and_clear()); }

if let Some(bi_unwr) = best_iteration {
let mut students_table_data = Vec::new();

for a in &bi_unwr.assignments {
Expand Down
11 changes: 6 additions & 5 deletions rust/src/rism_classic/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use rand::seq::SliceRandom;
use crate::constants::Points;
use std::{cmp, thread};
use std::sync::{Arc, Mutex};
use std::thread::ThreadId;

fn find_possible_assignment<'a>(wishes: &'a Vec<Seminar>, points: &Points, iteration: &RismResult) -> (Option<&'a Seminar>, u16) {
return if iteration.get_capacity(&wishes[0]) > 0 {
Expand All @@ -17,7 +18,7 @@ fn find_possible_assignment<'a>(wishes: &'a Vec<Seminar>, points: &Points, itera
};
}

pub fn run<'a>(students: &'a Vec<Student>, seminars: &'a Vec<Seminar>, iterations: u32, points: Points, threads: u16) -> RismResult<'a> {
pub fn run<'a>(students: &'a Vec<Student>, seminars: &'a Vec<Seminar>, iterations: u32, points: Points, threads: u16, progress: fn(ThreadId, u32, u32)) -> RismResult<'a> {

let results = Arc::new(Mutex::new(Vec::new()));

Expand All @@ -26,7 +27,7 @@ pub fn run<'a>(students: &'a Vec<Student>, seminars: &'a Vec<Seminar>, iteration
for _ in 0..threads {
s.spawn(|| {
let results_arc = results.clone();
let res = run_algorithm(students, seminars, thread_iterations, points.clone());
let res = run_algorithm(students, seminars, thread_iterations, points.clone(), progress);
let mut res_unwr = results_arc.lock().unwrap();
(*res_unwr).push(res);
});
Expand All @@ -41,11 +42,11 @@ pub fn run<'a>(students: &'a Vec<Student>, seminars: &'a Vec<Seminar>, iteration

}

pub fn run_algorithm<'a>(students: &'a Vec<Student>, seminars: &'a Vec<Seminar>, iterations: u32, points: Points) -> RismResult<'a> {
pub fn run_algorithm<'a>(students: &'a Vec<Student>, seminars: &'a Vec<Seminar>, iterations: u32, points: Points, progress: fn(ThreadId, u32, u32)) -> RismResult<'a> {
let mut best_iteration: Option<RismResult> = None;

// TODO Reintroduce progress
for _ in 0..iterations {
for p in 0..iterations {
progress(thread::current().id(), p, iterations);
let mut shuffled_indices: Vec<usize> = (0..students.len()).collect();
shuffled_indices.shuffle(&mut thread_rng());

Expand Down

0 comments on commit 38c10e0

Please sign in to comment.