Skip to content

Commit

Permalink
Store extra information about heuristic
Browse files Browse the repository at this point in the history
  • Loading branch information
reinterpretcat committed Aug 7, 2023
1 parent 812250e commit f468ed5
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 28 deletions.
88 changes: 60 additions & 28 deletions experiments/heuristic-research/src/solver/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,15 +123,21 @@ fn create_rosomaxa_state(network_state: NetworkState, fitness_values: Vec<f64>)
#[derive(Default, Serialize, Deserialize)]
pub struct SearchResult(pub usize, pub f64, pub (usize, usize), pub usize);

/// Heuristic state result represented as (state idx, alpha, beta, mu, v, n).
///
/// Tuple fields, in order:
/// - `0`: index of the state name, as stored in `HyperHeuristicState::states`;
/// - `1`..`4`: the `alpha`, `beta`, `mu` and `v` values parsed from the
///   `heuristic:` telemetry section;
/// - `5`: the `n` value parsed from the same section.
#[derive(Default, Serialize, Deserialize)]
pub struct HeuristicResult(pub usize, pub f64, pub f64, pub f64, pub f64, pub usize);

/// Keeps track of dynamic selective hyper heuristic state.
///
/// Heuristic and state names are replaced by numeric indices; the per-generation
/// records below refer to those indices rather than to the original strings.
#[derive(Default, Serialize, Deserialize)]
pub struct HyperHeuristicState {
/// Unique heuristic names mapped to their numeric index.
pub names: HashMap<String, usize>,
/// Unique state names mapped to their numeric index.
pub states: HashMap<String, usize>,
/// Search states at specific generations: generation -> results, sorted by heuristic name index.
pub search_states: HashMap<usize, Vec<SearchResult>>,
/// Heuristic states at specific generations: generation -> results, sorted by state index.
pub heuristic_states: HashMap<usize, Vec<HeuristicResult>>,
}

impl HyperHeuristicState {
Expand All @@ -146,37 +152,63 @@ impl HyperHeuristicState {
map.entry(key).or_insert_with(|| length);
};

let mut search_states = data.lines().skip(2).fold(HashMap::new(), |mut data, line| {
let fields: Vec<String> = line.split(',').map(|s| s.to_string()).collect();
let name = fields[0].clone();
let generation = fields[1].parse().unwrap();
let reward = fields[2].parse().unwrap();
let from = fields[3].clone();
let to = fields[4].clone();
let duration = fields[5].parse().unwrap();

insert_to_map(&mut names, name.clone());
insert_to_map(&mut states, from.clone());
insert_to_map(&mut states, to.clone());

let name = names.get(&name).copied().unwrap();
let from = states.get(&from).copied().unwrap();
let to = states.get(&to).copied().unwrap();

data.entry(generation).or_insert_with(Vec::default).push(SearchResult(
name,
reward,
(from, to),
duration,
));

data
});
let mut search_states =
data.lines().skip(3).take_while(|line| *line != "heuristic:").fold(HashMap::new(), |mut data, line| {
let fields: Vec<String> = line.split(',').map(|s| s.to_string()).collect();
let name = fields[0].clone();
let generation = fields[1].parse().unwrap();
let reward = fields[2].parse().unwrap();
let from = fields[3].clone();
let to = fields[4].clone();
let duration = fields[5].parse().unwrap();

insert_to_map(&mut names, name.clone());
insert_to_map(&mut states, from.clone());
insert_to_map(&mut states, to.clone());

let name = names.get(&name).copied().unwrap();
let from = states.get(&from).copied().unwrap();
let to = states.get(&to).copied().unwrap();

data.entry(generation).or_insert_with(Vec::default).push(SearchResult(
name,
reward,
(from, to),
duration,
));

data
});
search_states
.values_mut()
.for_each(|states| states.sort_by(|SearchResult(a, ..), SearchResult(b, ..)| a.cmp(b)));

Some(Self { names, states, search_states })
let mut heuristic_states =
data.lines().skip_while(|line| *line != "heuristic:").skip(2).fold(HashMap::new(), |mut data, line| {
let fields: Vec<String> = line.split(',').map(|s| s.to_string()).collect();

let generation: usize = fields[0].parse().unwrap();
let state = fields[1].clone();
let alpha = fields[2].parse().unwrap();
let beta = fields[3].parse().unwrap();
let mu = fields[4].parse().unwrap();
let v = fields[5].parse().unwrap();
let n = fields[6].parse().unwrap();

insert_to_map(&mut states, state.clone());
let state = states.get(&state).copied().unwrap();

data.entry(generation)
.or_insert_with(Vec::default)
.push(HeuristicResult(state, alpha, beta, mu, v, n));

data
});
heuristic_states
.values_mut()
.for_each(|states| states.sort_by(|HeuristicResult(a, ..), HeuristicResult(b, ..)| a.cmp(b)));

Some(Self { names, states, search_states, heuristic_states })
} else {
None
}
Expand Down
5 changes: 5 additions & 0 deletions rosomaxa/src/algorithms/rl/slot_machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,11 @@ where
self.n += 1;
self.mu += (reward - self.mu) / self.n as f64;
}

/// Gets learned params (alpha, beta, mean and variance) and usage amount.
///
/// The values are returned as the tuple `(alpha, beta, mu, v, n)`, copied from
/// the corresponding fields of this slot machine.
pub fn get_params(&self) -> (f64, f64, f64, f64, usize) {
(self.alpha, self.beta, self.mu, self.v, self.n)
}
}

impl<T, S> Display for SlotMachine<T, S>
Expand Down
43 changes: 43 additions & 0 deletions rosomaxa/src/hyper/dynamic_selective.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ where
self.agent.update(generation, feedback);
});

self.agent.save_params(generation);

feedbacks.into_iter().map(|feedback| feedback.solution).collect()
}

Expand Down Expand Up @@ -215,6 +217,7 @@ where
tracker: HeuristicTracker {
total_median: RemedianUsize::new(11, |a, b| a.cmp(b)),
search_telemetry: Default::default(),
heuristic_telemetry: Default::default(),
is_experimental: environment.is_experimental,
},
random: environment.random.clone(),
Expand Down Expand Up @@ -249,6 +252,20 @@ where

self.tracker.observe_sample(generation, feedback.sample.clone())
}

/// Records the current internal parameters of every slot machine as telemetry.
///
/// For each search state and each slot machine registered under it, a
/// `HeuristicSample` tagged with `generation` is passed to the tracker.
/// Nothing is recorded when the tracker reports telemetry as disabled.
pub fn save_params(&mut self, generation: usize) {
    if self.tracker.telemetry_enabled() {
        for (state, slots) in self.slot_machines.iter() {
            for slot in slots.iter() {
                let (alpha, beta, mu, v, n) = slot.get_params();
                let sample = HeuristicSample { state: state.clone(), alpha, beta, mu, v, n };

                self.tracker.observe_params(generation, sample);
            }
        }
    }
}
}

impl<C, O, S> Display for DynamicSelective<C, O, S>
Expand All @@ -263,6 +280,7 @@ where
}

f.write_fmt(format_args!("TELEMETRY\n"))?;
f.write_fmt(format_args!("search:\n"))?;
f.write_fmt(format_args!("name,generation,reward,from,to,duration\n"))?;
for (generation, sample) in self.agent.tracker.search_telemetry.iter() {
f.write_fmt(format_args!(
Expand All @@ -271,6 +289,15 @@ where
))?;
}

f.write_fmt(format_args!("heuristic:\n"))?;
f.write_fmt(format_args!("generation,state,alpha,beta,mu,v,n\n"))?;
for (generation, sample) in self.agent.tracker.heuristic_telemetry.iter() {
f.write_fmt(format_args!(
"{},{},{},{},{},{},{}\n",
generation, sample.state, sample.alpha, sample.beta, sample.mu, sample.v, sample.n
))?;
}

Ok(())
}
}
Expand Down Expand Up @@ -408,10 +435,20 @@ struct SearchSample {
transition: (SearchState, SearchState),
}

/// A snapshot of a single slot machine's learned parameters, kept as telemetry.
struct HeuristicSample {
/// Search state the sampled slot machine is registered under.
state: SearchState,
/// The `alpha` parameter as reported by the slot machine.
alpha: f64,
/// The `beta` parameter as reported by the slot machine.
beta: f64,
/// The mean (`mu`) as reported by the slot machine.
mu: f64,
/// The `v` parameter as reported by the slot machine (presumably the variance estimate; confirm against `SlotMachine`).
v: f64,
/// Usage amount (`n`) as reported by the slot machine.
n: usize,
}

/// Provides a way to track heuristic telemetry and to estimate the duration median.
struct HeuristicTracker {
/// Running median estimator (presumably fed with search durations; confirm against `observe_sample`).
total_median: RemedianUsize,
/// Recorded search samples as `(generation, sample)` pairs, in insertion order.
search_telemetry: Vec<(usize, SearchSample)>,
/// Recorded heuristic parameter samples as `(generation, sample)` pairs, in insertion order.
heuristic_telemetry: Vec<(usize, HeuristicSample)>,
/// Experimental mode flag, taken from the environment when the tracker is constructed.
is_experimental: bool,
}

Expand All @@ -433,4 +470,10 @@ impl HeuristicTracker {
self.search_telemetry.push((generation, sample));
}
}

/// Stores a heuristic parameters sample together with the generation it was taken at.
///
/// The sample is dropped when telemetry is not enabled.
pub fn observe_params(&mut self, generation: usize, sample: HeuristicSample) {
    if !self.telemetry_enabled() {
        return;
    }

    self.heuristic_telemetry.push((generation, sample));
}
}

0 comments on commit f468ed5

Please sign in to comment.