Skip to content

Commit

Permalink
chore: add monty policy
Browse files Browse the repository at this point in the history
╔══════════════════════════════════════════════════════════╗
║    Name               Elo Error   Wins Loss Draw   Total ║
╠══════════════════════════════════════════════════════════╣
║  1. Mexx              -102   83     23   42    0      65 ║
║  2. MexxMontyPolicy   +102   83     42   23    0      65 ║
╚══════════════════════════════════════════════════════════╝
  • Loading branch information
raklaptudirm committed Jun 5, 2024
1 parent 2e87021 commit 4578e25
Show file tree
Hide file tree
Showing 7 changed files with 119 additions and 51 deletions.
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ edition = "2021"
ataxx = { path = "ataxx" }
uxi = { path = "uxi" }
rand = "0.8.5"
goober = { git = 'https://github.com/jw1912/goober.git' }


[profile.release]
opt-level = 3
Expand Down
Binary file modified resources/net.network
Binary file not shown.
62 changes: 62 additions & 0 deletions src/mcts/features.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
use goober::SparseVector;

use super::policy::PolicyNetwork;
use super::simulate::ValueNetwork;

#[repr(C)]
pub struct Nets(pub ValueNetwork<2916, 256>, pub PolicyNetwork);

pub const NETS: Nets =
unsafe { std::mem::transmute(*include_bytes!("../../resources/net.network")) };

pub fn value_feature_map<F: FnMut(usize)>(position: &ataxx::Position, mut f: F) {
const PER_TUPLE: usize = 3usize.pow(4);
const POWERS: [usize; 4] = [1, 3, 9, 27];
const MASK: u64 = 0b0001_1000_0011;

let friends = position.bitboard(position.side_to_move).0;
let enemies = position.bitboard(!position.side_to_move).0;

for i in 0..6 {
for j in 0..6 {
let tuple = 6 * i + j;
let mut feat = PER_TUPLE * tuple;

let offset = 7 * i + j;
let mut b = (friends >> offset) & MASK;
let mut o = (enemies >> offset) & MASK;

while b > 0 {
let mut sq = b.trailing_zeros() as usize;
if sq > 6 {
sq -= 5;
}

feat += POWERS[sq];

b &= b - 1;
}

while o > 0 {
let mut sq = o.trailing_zeros() as usize;
if sq > 6 {
sq -= 5;
}

feat += 2 * POWERS[sq];

o &= o - 1;
}

f(feat);
}
}
}

pub fn get_features(position: &ataxx::Position) -> SparseVector {
let mut feats = SparseVector::with_capacity(36);

value_feature_map(position, |feat| feats.push(feat));

feats
}
4 changes: 2 additions & 2 deletions src/mcts/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ pub struct Tree {

impl Tree {
pub fn new(position: ataxx::Position) -> Tree {
let policy = policy::handcrafted;
let simulator = simulate::monty_network;
let policy = policy::monty;
let simulator = simulate::material_count;

let mut root = Node::new(position, -1);
root.expand(policy);
Expand Down
2 changes: 2 additions & 0 deletions src/mcts/mod.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
pub mod policy;

mod features;
mod graph;
mod node;
mod params;
mod simulate;

pub use self::features::*;
pub use self::graph::*;
pub use self::node::*;
pub use self::params::*;
49 changes: 48 additions & 1 deletion src/mcts/policy.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use super::Node;
use super::{get_features, Node, NETS};
use ataxx::{BitBoard, Move};

use goober::{activation, layer, FeedForwardNetwork, Matrix, SparseVector, Vector};

pub type Fn = fn(node: &Node, mov: Move) -> f64;

pub fn handcrafted(node: &Node, mov: Move) -> f64 {
Expand All @@ -26,3 +28,48 @@ pub fn handcrafted(node: &Node, mov: Move) -> f64 {

score.max(0.1)
}

pub fn monty(node: &Node, mov: Move) -> f64 {
NETS.1.get(&mov, &get_features(&node.position)) as f64
}

#[repr(C)]
#[derive(Clone, Copy, FeedForwardNetwork)]
pub struct SubNet {
ft: layer::SparseConnected<activation::ReLU, 2916, 8>,
}

impl SubNet {
pub const fn zeroed() -> Self {
Self {
ft: layer::SparseConnected::zeroed(),
}
}

pub fn from_fn<F: FnMut() -> f32>(mut f: F) -> Self {
let matrix = Matrix::from_fn(|_, _| f());
let vector = Vector::from_fn(|_| f());

Self {
ft: layer::SparseConnected::from_raw(matrix, vector),
}
}
}

#[repr(C)]
#[derive(Clone, Copy)]
pub struct PolicyNetwork {
pub subnets: [SubNet; 99],
}

impl PolicyNetwork {
pub fn get(&self, mov: &Move, feats: &SparseVector) -> f32 {
let from_subnet = &self.subnets[(mov.source() as usize).min(49)];
let from_vec = from_subnet.out(feats);

let to_subnet = &self.subnets[50 + (mov.target() as usize).min(48)];
let to_vec = to_subnet.out(feats);

from_vec.dot(&to_vec)
}
}
51 changes: 3 additions & 48 deletions src/mcts/simulate.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use super::{value_feature_map, NETS};

pub type Fn = fn(position: &ataxx::Position) -> f64;

pub fn material_count(position: &ataxx::Position) -> f64 {
Expand All @@ -11,62 +13,15 @@ pub fn material_count(position: &ataxx::Position) -> f64 {
1.0 / (1.0 + f64::exp(-eval / 400.0))
}

const VALUE_NETWORK: ValueNetwork<2916, 256> =
unsafe { std::mem::transmute(*include_bytes!("../../resources/net.network")) };

pub fn monty_network(position: &ataxx::Position) -> f64 {
1.0 / (1.0 + (-(VALUE_NETWORK.eval(position) as f64) / 400.0).exp())
1.0 / (1.0 + (-(NETS.0.eval(position) as f64) / 400.0).exp())
}

const SCALE: i32 = 400;
const QA: i32 = 255;
const QB: i32 = 64;
const QAB: i32 = QA * QB;

pub fn value_feature_map<F: FnMut(usize)>(position: &ataxx::Position, mut f: F) {
const PER_TUPLE: usize = 3usize.pow(4);
const POWERS: [usize; 4] = [1, 3, 9, 27];
const MASK: u64 = 0b0001_1000_0011;

let friends = position.bitboard(position.side_to_move).0;
let enemies = position.bitboard(!position.side_to_move).0;

for i in 0..6 {
for j in 0..6 {
let tuple = 6 * i + j;
let mut feat = PER_TUPLE * tuple;

let offset = 7 * i + j;
let mut b = (friends >> offset) & MASK;
let mut o = (enemies >> offset) & MASK;

while b > 0 {
let mut sq = b.trailing_zeros() as usize;
if sq > 6 {
sq -= 5;
}

feat += POWERS[sq];

b &= b - 1;
}

while o > 0 {
let mut sq = o.trailing_zeros() as usize;
if sq > 6 {
sq -= 5;
}

feat += 2 * POWERS[sq];

o &= o - 1;
}

f(feat);
}
}
}

#[repr(C, align(64))]
pub struct ValueNetwork<const INPUT: usize, const HIDDEN: usize> {
l1_weights: [Accumulator<HIDDEN>; INPUT],
Expand Down

0 comments on commit 4578e25

Please sign in to comment.