Skip to content

Commit

Permalink
feat(initializer): support constant activation functions
Browse files Browse the repository at this point in the history
  • Loading branch information
krishnachittur committed Sep 9, 2023
1 parent a36bbed commit 85214ac
Showing 1 changed file with 15 additions and 1 deletion.
16 changes: 15 additions & 1 deletion src/neural_network/initializer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ pub enum Initializer {
/// Use for ReLU and LeakyReLU activation functions.
/// Gaussian, µ = 0, σ = √[2 / f_in]
Kaiming,
Constant(f64),
}

impl Initializer {
Expand All @@ -37,6 +38,7 @@ impl Initializer {
match *self {
Initializer::Xavier => Self::xavier(f_in, f_out),
Initializer::Kaiming => Self::kaiming(f_in, f_out),
Initializer::Constant(val) => Self::constant(f_in, f_out, val),
}
}

Expand All @@ -50,14 +52,26 @@ impl Initializer {
Self::initialize_data(f_in, f_out, 2.0 / ((f_in) as f64).sqrt())
}

fn constant(f_in: usize, f_out: usize, val: f64) -> Tensor2D {
Self::initialize_with_closure(f_in, f_out, &mut || val)
}

/// Initializes a 2D tensor with random values based on the specified standard deviation.
fn initialize_data(f_in: usize, f_out: usize, std_dev: f64) -> Tensor2D {
let mut rng = rand::thread_rng();
Self::initialize_with_closure(f_in, f_out, &mut || rng.gen::<f64>() * std_dev)
}

fn initialize_with_closure(
f_in: usize,
f_out: usize,
closure: &mut dyn FnMut() -> f64,
) -> Tensor2D {
let mut data = Vec::with_capacity(f_in);
for _ in 0..f_in {
let mut row = Vec::with_capacity(f_out);
for _ in 0..f_out {
row.push(rng.gen::<f64>() * std_dev);
row.push(closure());
}
data.push(row);
}
Expand Down

0 comments on commit 85214ac

Please sign in to comment.