From f468ed505510618f16d031d6b978f42f264c2a16 Mon Sep 17 00:00:00 2001 From: reinterpretcat Date: Mon, 7 Aug 2023 20:53:58 +0200 Subject: [PATCH] Store extra information about heuristic --- .../heuristic-research/src/solver/state.rs | 88 +++++++++++++------ rosomaxa/src/algorithms/rl/slot_machine.rs | 5 ++ rosomaxa/src/hyper/dynamic_selective.rs | 43 +++++++++ 3 files changed, 108 insertions(+), 28 deletions(-) diff --git a/experiments/heuristic-research/src/solver/state.rs b/experiments/heuristic-research/src/solver/state.rs index 6ae85ec0f..4693b600b 100644 --- a/experiments/heuristic-research/src/solver/state.rs +++ b/experiments/heuristic-research/src/solver/state.rs @@ -123,6 +123,10 @@ fn create_rosomaxa_state(network_state: NetworkState, fitness_values: Vec) #[derive(Default, Serialize, Deserialize)] pub struct SearchResult(pub usize, pub f64, pub (usize, usize), pub usize); +/// Heuristic state result represented as (state idx, alpha, beta, mu, v, n). +#[derive(Default, Serialize, Deserialize)] +pub struct HeuristicResult(pub usize, pub f64, pub f64, pub f64, pub f64, pub usize); + /// Keeps track of dynamic selective hyper heuristic state. #[derive(Default, Serialize, Deserialize)] pub struct HyperHeuristicState { @@ -130,8 +134,10 @@ pub struct HyperHeuristicState { pub names: HashMap, /// Unique state names. pub states: HashMap, - /// Heuristic states at specific generations. + /// Search states at specific generations. pub search_states: HashMap>, + /// Heuristic states at specific generations. 
+ pub heuristic_states: HashMap>, } impl HyperHeuristicState { @@ -146,37 +152,63 @@ impl HyperHeuristicState { map.entry(key).or_insert_with(|| length); }; - let mut search_states = data.lines().skip(2).fold(HashMap::new(), |mut data, line| { - let fields: Vec = line.split(',').map(|s| s.to_string()).collect(); - let name = fields[0].clone(); - let generation = fields[1].parse().unwrap(); - let reward = fields[2].parse().unwrap(); - let from = fields[3].clone(); - let to = fields[4].clone(); - let duration = fields[5].parse().unwrap(); - - insert_to_map(&mut names, name.clone()); - insert_to_map(&mut states, from.clone()); - insert_to_map(&mut states, to.clone()); - - let name = names.get(&name).copied().unwrap(); - let from = states.get(&from).copied().unwrap(); - let to = states.get(&to).copied().unwrap(); - - data.entry(generation).or_insert_with(Vec::default).push(SearchResult( - name, - reward, - (from, to), - duration, - )); - - data - }); + let mut search_states = + data.lines().skip(3).take_while(|line| *line != "heuristic:").fold(HashMap::new(), |mut data, line| { + let fields: Vec = line.split(',').map(|s| s.to_string()).collect(); + let name = fields[0].clone(); + let generation = fields[1].parse().unwrap(); + let reward = fields[2].parse().unwrap(); + let from = fields[3].clone(); + let to = fields[4].clone(); + let duration = fields[5].parse().unwrap(); + + insert_to_map(&mut names, name.clone()); + insert_to_map(&mut states, from.clone()); + insert_to_map(&mut states, to.clone()); + + let name = names.get(&name).copied().unwrap(); + let from = states.get(&from).copied().unwrap(); + let to = states.get(&to).copied().unwrap(); + + data.entry(generation).or_insert_with(Vec::default).push(SearchResult( + name, + reward, + (from, to), + duration, + )); + + data + }); search_states .values_mut() .for_each(|states| states.sort_by(|SearchResult(a, ..), SearchResult(b, ..)| a.cmp(b))); - Some(Self { names, states, search_states }) + let mut heuristic_states = 
+ data.lines().skip_while(|line| *line != "heuristic:").skip(2).fold(HashMap::new(), |mut data, line| { + let fields: Vec = line.split(',').map(|s| s.to_string()).collect(); + + let generation: usize = fields[0].parse().unwrap(); + let state = fields[1].clone(); + let alpha = fields[2].parse().unwrap(); + let beta = fields[3].parse().unwrap(); + let mu = fields[4].parse().unwrap(); + let v = fields[5].parse().unwrap(); + let n = fields[6].parse().unwrap(); + + insert_to_map(&mut states, state.clone()); + let state = states.get(&state).copied().unwrap(); + + data.entry(generation) + .or_insert_with(Vec::default) + .push(HeuristicResult(state, alpha, beta, mu, v, n)); + + data + }); + heuristic_states + .values_mut() + .for_each(|states| states.sort_by(|HeuristicResult(a, ..), HeuristicResult(b, ..)| a.cmp(b))); + + Some(Self { names, states, search_states, heuristic_states }) } else { None } diff --git a/rosomaxa/src/algorithms/rl/slot_machine.rs b/rosomaxa/src/algorithms/rl/slot_machine.rs index ae7bdf325..dac754fd5 100644 --- a/rosomaxa/src/algorithms/rl/slot_machine.rs +++ b/rosomaxa/src/algorithms/rl/slot_machine.rs @@ -86,6 +86,11 @@ where self.n += 1; self.mu += (reward - self.mu) / self.n as f64; } + + /// Gets learned params (alpha, beta, mean and variance) and usage amount. 
+ pub fn get_params(&self) -> (f64, f64, f64, f64, usize) { + (self.alpha, self.beta, self.mu, self.v, self.n) + } } impl Display for SlotMachine diff --git a/rosomaxa/src/hyper/dynamic_selective.rs b/rosomaxa/src/hyper/dynamic_selective.rs index c869e59bc..b7e374f22 100644 --- a/rosomaxa/src/hyper/dynamic_selective.rs +++ b/rosomaxa/src/hyper/dynamic_selective.rs @@ -67,6 +67,8 @@ where self.agent.update(generation, feedback); }); + self.agent.save_params(generation); + feedbacks.into_iter().map(|feedback| feedback.solution).collect() } @@ -215,6 +217,7 @@ where tracker: HeuristicTracker { total_median: RemedianUsize::new(11, |a, b| a.cmp(b)), search_telemetry: Default::default(), + heuristic_telemetry: Default::default(), is_experimental: environment.is_experimental, }, random: environment.random.clone(), @@ -249,6 +252,20 @@ where self.tracker.observe_sample(generation, feedback.sample.clone()) } + + /// Updates statistics about heuristic internal parameters. + pub fn save_params(&mut self, generation: usize) { + if !self.tracker.telemetry_enabled() { + return; + } + + self.slot_machines.iter().for_each(|(state, slots)| { + slots.iter().map(|slot| slot.get_params()).for_each(|(alpha, beta, mu, v, n)| { + self.tracker + .observe_params(generation, HeuristicSample { state: state.clone(), alpha, beta, mu, v, n }); + }) + }); + } } impl Display for DynamicSelective @@ -263,6 +280,7 @@ where } f.write_fmt(format_args!("TELEMETRY\n"))?; + f.write_fmt(format_args!("search:\n"))?; f.write_fmt(format_args!("name,generation,reward,from,to,duration\n"))?; for (generation, sample) in self.agent.tracker.search_telemetry.iter() { f.write_fmt(format_args!( @@ -271,6 +289,15 @@ where ))?; } + f.write_fmt(format_args!("heuristic:\n"))?; + f.write_fmt(format_args!("generation,state,alpha,beta,mu,v,n\n"))?; + for (generation, sample) in self.agent.tracker.heuristic_telemetry.iter() { + f.write_fmt(format_args!( + "{},{},{},{},{},{},{}\n", + generation, sample.state, sample.alpha, 
sample.beta, sample.mu, sample.v, sample.n + ))?; + } + Ok(()) } } @@ -408,10 +435,20 @@ struct SearchSample { transition: (SearchState, SearchState), } +struct HeuristicSample { + state: SearchState, + alpha: f64, + beta: f64, + mu: f64, + v: f64, + n: usize, +} + /// Provides way to track heuristic's telemetry and duration median estimation. struct HeuristicTracker { total_median: RemedianUsize, search_telemetry: Vec<(usize, SearchSample)>, + heuristic_telemetry: Vec<(usize, HeuristicSample)>, is_experimental: bool, } @@ -433,4 +470,10 @@ impl HeuristicTracker { self.search_telemetry.push((generation, sample)); } } + + pub fn observe_params(&mut self, generation: usize, sample: HeuristicSample) { + if self.telemetry_enabled() { + self.heuristic_telemetry.push((generation, sample)); + } + } }