Skip to content

Commit

Permalink
Add slot machine implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
reinterpretcat committed Aug 5, 2023
1 parent 9e43b58 commit 2f793b7
Show file tree
Hide file tree
Showing 6 changed files with 201 additions and 0 deletions.
1 change: 1 addition & 0 deletions rosomaxa/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ heuristic-telemetry = []
rand.workspace = true
rayon.workspace = true
rustc-hash.workspace = true
rand_distr = "0.4.3"

[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
num_cpus = "1.16.0"
Expand Down
1 change: 1 addition & 0 deletions rosomaxa/src/algorithms/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ pub mod gsom;
pub mod math;
pub mod mdp;
pub mod nsga2;
pub mod rl;
4 changes: 4 additions & 0 deletions rosomaxa/src/algorithms/rl/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
//! This module contains implementations of some reinforcement learning algorithms.

mod slot_machine;
pub use self::slot_machine::SlotMachine;
85 changes: 85 additions & 0 deletions rosomaxa/src/algorithms/rl/slot_machine.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
#[cfg(test)]
#[path = "../../../tests/unit/algorithms/rl/slot_machine_test.rs"]
mod slot_machine_test;

use crate::utils::DistributionSampler;
use std::fmt::{Display, Formatter};

/// Simulates a slot machine (a single arm of a multi-armed bandit).
///
/// Internally it tries to estimate the reward probability distribution using one of the
/// methods from Thompson sampling.
///
/// * `T` is the type of the function which actually plays the slot.
/// * `S` is the type of the sampler used to draw values from the estimated distribution.
#[derive(Clone)]
pub struct SlotMachine<T: Clone, S: Clone> {
    /// How many times this slot machine has been tried so far.
    n: usize,
    /// Shape parameter of the gamma distribution.
    alpha: f64,
    /// Rate parameter of the gamma distribution.
    beta: f64,
    /// Current estimate of the mean.
    mu: f64,
    /// Current estimate of the variance.
    v: f64,
    /// Provides samples from the underlying estimated distribution.
    sampler: S,
    /// The function which actually plays the slot.
    player: T,
}

impl<T, S> SlotMachine<T, S>
where
    T: Fn() -> f64 + Clone,
    S: DistributionSampler + Clone,
{
    /// Creates a new instance of `SlotMachine`.
    ///
    /// The machine starts untried (`n = 0`) with `alpha = 1`, `beta = 10`, the mean estimate
    /// set to `prior_mean` and the variance estimate set to `beta / (alpha + 1)`.
    pub fn new(prior_mean: f64, sampler: S, player: T) -> Self {
        const INITIAL_ALPHA: f64 = 1.;
        const INITIAL_BETA: f64 = 10.;

        Self {
            n: 0,
            alpha: INITIAL_ALPHA,
            beta: INITIAL_BETA,
            mu: prior_mean,
            v: INITIAL_BETA / (INITIAL_ALPHA + 1.),
            player,
            sampler,
        }
    }

    /// Samples from the estimated normal distribution.
    ///
    /// A precision is drawn from the gamma distribution (shape `alpha`, scale `1 / beta`) and
    /// converted into a standard deviation for a normal distribution centered at the current
    /// mean estimate.
    pub fn sample(&self) -> f64 {
        // NOTE: the gamma sample is always drawn, even when the fallback below is used.
        let drawn = self.sampler.gamma(self.alpha, 1. / self.beta);

        // an untried slot or a zero draw gets a small fixed precision (i.e. a wide distribution)
        let precision = match drawn {
            value if self.n > 0 && value != 0. => value,
            _ => 0.001,
        };
        let std_dev = (1. / precision).sqrt();

        self.sampler.normal(self.mu, std_dev)
    }

    /// Plays the game once and updates the slot state with the reward received.
    pub fn play(&mut self) {
        let observed = (self.player)();
        self.update(observed);
    }

    /// Updates slot machine estimates with a newly observed `reward`.
    fn update(&mut self, reward: f64) {
        // a single observation is added on top of `self.n` previous ones
        let observations = 1.;
        let prior_count = self.n as f64;
        let deviation = reward - self.mu;

        self.alpha += observations / 2.;
        self.beta += (observations * prior_count / (prior_count + observations)) * deviation.powi(2) / 2.;

        // variance estimate is derived from the gamma hyper-parameters
        self.v = self.beta / (self.alpha + 1.);

        // running mean over all rewards seen so far
        self.n += 1;
        let total_count = self.n as f64;
        self.mu += deviation / total_count;
    }
}

/// Formats the slot statistics as a comma separated list of `name=value` pairs;
/// the sampler and the player are not included.
impl<T, S> Display for SlotMachine<T, S>
where
    T: Clone,
    S: Clone,
{
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let Self { n, alpha, beta, mu, v, .. } = self;

        write!(f, "n={n},alpha={alpha},beta={beta},mu={mu},v={v}")
    }
}
65 changes: 65 additions & 0 deletions rosomaxa/src/utils/random.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,19 @@ mod random_test;

use rand::prelude::*;
use rand::Error;
use rand_distr::{Gamma, Normal};
use std::cell::RefCell;
use std::cmp::Ordering;
use std::sync::Arc;

/// Provides a way to sample from different probability distributions.
pub trait DistributionSampler {
/// Returns a sample from the gamma distribution with the given `shape` and `scale` parameters.
fn gamma(&self, shape: f64, scale: f64) -> f64;

/// Returns a sample from the normal distribution with the given `mean` and standard
/// deviation `std_dev`.
fn normal(&self, mean: f64, std_dev: f64) -> f64;
}

/// Provides the way to use randomized values in generic way.
pub trait Random {
Expand All @@ -29,6 +41,31 @@ pub trait Random {
fn get_rng(&self) -> RandomGen;
}

/// A distribution sampler which wraps a shared [`Random`] instance.
///
/// Cloning the sampler only clones the `Arc`, so all clones refer to the same `Random`.
#[derive(Clone)]
pub struct DefaultDistributionSampler(Arc<dyn Random>);

impl DefaultDistributionSampler {
    /// Creates a new instance of `DefaultDistributionSampler` backed by the given `random`.
    pub fn new(random: Arc<dyn Random>) -> Self {
        DefaultDistributionSampler(random)
    }
}

impl DistributionSampler for DefaultDistributionSampler {
    /// Draws a value from `Gamma(shape, scale)` using a generator obtained from the wrapped random.
    ///
    /// # Panics
    ///
    /// Panics when the gamma distribution cannot be created from the given parameters.
    fn gamma(&self, shape: f64, scale: f64) -> f64 {
        match Gamma::new(shape, scale) {
            Ok(distribution) => distribution.sample(&mut self.0.get_rng()),
            Err(_) => panic!("cannot create gamma dist: shape={shape}, scale={scale}"),
        }
    }

    /// Draws a value from `Normal(mean, std_dev)` using a generator obtained from the wrapped random.
    ///
    /// # Panics
    ///
    /// Panics when the normal distribution cannot be created from the given parameters.
    fn normal(&self, mean: f64, std_dev: f64) -> f64 {
        match Normal::new(mean, std_dev) {
            Ok(distribution) => distribution.sample(&mut self.0.get_rng()),
            Err(_) => panic!("cannot create normal dist: mean={mean}, std_dev={std_dev}"),
        }
    }
}

/// A default random implementation.
#[derive(Default)]
pub struct DefaultRandom {}
Expand Down Expand Up @@ -137,3 +174,31 @@ impl RngCore for RandomGen {
}

impl CryptoRng for RandomGen {}

/// Returns the index of the maximum element in `values`, or `None` when `values` is empty.
///
/// Values are compared with `f64::total_cmp`. When several elements share the maximum value,
/// one of them is selected at random using the generator obtained from `random`.
pub fn random_argmax<I>(values: I, random: &dyn Random) -> Option<usize>
where
    I: Iterator<Item = f64>,
{
    let mut rng = random.get_rng();
    // how many times the current best value has been matched by a later element
    let mut ties = 0;

    values
        .enumerate()
        .reduce(|best, candidate| {
            let keep_best = match best.1.total_cmp(&candidate.1) {
                Ordering::Greater => true,
                Ordering::Less => {
                    // a strictly bigger value starts a new group of ties
                    ties = 0;
                    false
                }
                Ordering::Equal => {
                    // the k-th tie replaces the current choice with probability 1 / (k + 1)
                    ties += 1;
                    rng.gen_range(0..=ties) != 0
                }
            };

            if keep_best {
                best
            } else {
                candidate
            }
        })
        .map(|(idx, _)| idx)
}
45 changes: 45 additions & 0 deletions rosomaxa/tests/unit/algorithms/rl/slot_machine_test.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
use super::*;
use crate::helpers::utils::create_test_random;
use crate::utils::{random_argmax, DefaultDistributionSampler};

/// Checks that, over many episodes, the slots' estimated mean and variance mostly end up
/// within `delta` of the parameters used to generate their rewards.
#[test]
fn can_find_proper_estimations() {
    let total_episodes = 100;
    let expected_failures_threshold = 20;

    // plays one full episode and returns the number of slots with estimates outside of tolerance
    let run_episode = || -> usize {
        let slot_means = &[5.0_f64, 9., 7., 13., 11.];
        let slot_vars = &[2.0_f64, 3., 4., 6., 1.];
        let prior_mean = 1.;
        let attempts = 1000;
        let delta = 2.;

        let random = create_test_random();
        let sampler = DefaultDistributionSampler::new(random.clone());

        // each slot draws its reward from a normal distribution with its own mean and variance
        let mut slots: Vec<_> = (0..5)
            .map(|idx| {
                let reward_sampler = sampler.clone();
                let player = move || reward_sampler.normal(slot_means[idx], slot_vars[idx].sqrt());

                SlotMachine::new(prior_mean, sampler.clone(), player)
            })
            .collect();

        // always play the slot whose sampled value is the best one
        for _ in 0..attempts {
            let samples = slots.iter().map(|slot| slot.sample());
            let slot_idx = random_argmax(samples, random.as_ref()).unwrap();
            slots[slot_idx].play();
        }

        slots
            .iter()
            .enumerate()
            .filter(|(idx, slot)| {
                let mean_is_off = (slot.mu - slot_means[*idx]).abs() > delta;
                let variance_is_off = (slot.v - slot_vars[*idx]).abs() > delta;

                mean_is_off || variance_is_off
            })
            .count()
    };

    let failed_slot_estimations: usize = (0..total_episodes).map(|_| run_episode()).sum();

    assert!(failed_slot_estimations < expected_failures_threshold);
}

0 comments on commit 2f793b7

Please sign in to comment.