-
Notifications
You must be signed in to change notification settings - Fork 70
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
9e43b58
commit 2f793b7
Showing
6 changed files
with
201 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,3 +4,4 @@ pub mod gsom; | |
pub mod math; | ||
pub mod mdp; | ||
pub mod nsga2; | ||
pub mod rl; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
//! This module contains implementations of some reinforcement learning algorithms.
|
||
mod slot_machine; | ||
pub use self::slot_machine::SlotMachine; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
#[cfg(test)] | ||
#[path = "../../../tests/unit/algorithms/rl/slot_machine_test.rs"] | ||
mod slot_machine_test; | ||
|
||
use crate::utils::DistributionSampler; | ||
use std::fmt::{Display, Formatter}; | ||
|
||
/// Simulates a slot machine.
///
/// Internally tries to estimate the reward probability distribution using one of the methods
/// from Thompson sampling.
///
/// The type parameters carry no bounds on the struct itself: bounds are declared only on the
/// `impl` blocks that need them. The derived `Clone` implementation is still available exactly
/// when both `T` and `S` are `Clone`.
#[derive(Clone)]
pub struct SlotMachine<T, S> {
    /// The number of times this slot machine has been tried.
    n: usize,
    /// Gamma shape parameter.
    alpha: f64,
    /// Gamma rate parameter.
    beta: f64,
    /// Estimated mean.
    mu: f64,
    /// Estimated variance.
    v: f64,
    /// Sampler: used to provide samples from the underlying estimated distribution.
    sampler: S,
    /// Actual slot play function.
    player: T,
}
|
||
impl<T, S> SlotMachine<T, S> | ||
where | ||
T: Fn() -> f64 + Clone, | ||
S: DistributionSampler + Clone, | ||
{ | ||
/// Creates a new instance of `SlotMachine`. | ||
pub fn new(prior_mean: f64, sampler: S, player: T) -> Self { | ||
let alpha = 1.; | ||
let beta = 10.; | ||
let mu_0 = prior_mean; | ||
let v_0 = beta / (alpha + 1.); | ||
|
||
Self { n: 0, alpha, beta, mu: mu_0, v: v_0, player, sampler } | ||
} | ||
|
||
/// Samples from estimated normal distribution. | ||
pub fn sample(&self) -> f64 { | ||
let precision = self.sampler.gamma(self.alpha, 1. / self.beta); | ||
let precision = if precision == 0. || self.n == 0 { 0.001 } else { precision }; | ||
let variance = 1. / precision; | ||
|
||
self.sampler.normal(self.mu, variance.sqrt()) | ||
} | ||
|
||
/// Plays the game and updates slot state. | ||
pub fn play(&mut self) { | ||
let reward = (self.player)(); | ||
self.update(reward); | ||
} | ||
|
||
/// Updates slot machine. | ||
fn update(&mut self, reward: f64) { | ||
let n = 1.; | ||
let v = self.n as f64; | ||
|
||
self.alpha += n / 2.; | ||
self.beta += (n * v / (v + n)) * (reward - self.mu).powi(2) / 2.; | ||
|
||
// estimate the variance: calculate running mean from the gamma hyper-parameters | ||
self.v = self.beta / (self.alpha + 1.); | ||
self.n += 1; | ||
self.mu += (reward - self.mu) / self.n as f64; | ||
} | ||
} | ||
|
||
impl<T, S> Display for SlotMachine<T, S> | ||
where | ||
T: Clone, | ||
S: Clone, | ||
{ | ||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { | ||
write!(f, "n={},alpha={},beta={},mu={},v={}", self.n, self.alpha, self.beta, self.mu, self.v) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
use super::*; | ||
use crate::helpers::utils::create_test_random; | ||
use crate::utils::{random_argmax, DefaultDistributionSampler}; | ||
|
||
/// Runs many independent episodes and checks that, in most of them, every slot's estimated mean
/// and variance end up within `delta` of the true parameters of its reward distribution.
#[test]
fn can_find_proper_estimations() {
    let total_episodes = 100;
    let expected_failures_threshold = 20;

    // true parameters of the normal reward distribution behind each slot
    let slot_means = [5.0_f64, 9., 7., 13., 11.];
    let slot_vars = [2.0_f64, 3., 4., 6., 1.];
    let prior_mean = 1.;
    let attempts = 1000;
    let delta: f64 = 2.;

    let is_off = |estimate: f64, truth: f64| (estimate - truth).abs() > delta;

    let mut failed_slot_estimations = 0_usize;

    for _ in 0..total_episodes {
        let random = create_test_random();
        let sampler = DefaultDistributionSampler::new(random.clone());

        let mut slots = slot_means
            .iter()
            .zip(slot_vars.iter())
            .map(|(&mean, &var)| {
                let estimation_sampler = sampler.clone();
                let reward_sampler = sampler.clone();

                SlotMachine::new(prior_mean, estimation_sampler, move || reward_sampler.normal(mean, var.sqrt()))
            })
            .collect::<Vec<_>>();

        // each attempt plays the slot whose sampled value is the highest
        for _ in 0..attempts {
            let samples = slots.iter().map(|slot| slot.sample());
            let slot_idx = random_argmax(samples, random.as_ref()).unwrap();
            slots[slot_idx].play();
        }

        // count slots whose estimated mean or variance is too far from the truth
        failed_slot_estimations += slots
            .iter()
            .enumerate()
            .filter(|&(idx, slot)| is_off(slot.mu, slot_means[idx]) || is_off(slot.v, slot_vars[idx]))
            .count();
    }

    assert!(failed_slot_estimations < expected_failures_threshold);
}