diff --git a/src/neural_network/initializer.rs b/src/neural_network/initializer.rs
index 3d80e9d..b03bf9d 100644
--- a/src/neural_network/initializer.rs
+++ b/src/neural_network/initializer.rs
@@ -15,6 +15,7 @@ pub enum Initializer {
     /// Use for ReLU and LeakyReLU activation functions.
     /// Gaussian, µ = 0, σ = √[2 / f_in]
     Kaiming,
+    Constant(f64),
 }
 
 impl Initializer {
@@ -37,6 +38,7 @@ impl Initializer {
         match *self {
             Initializer::Xavier => Self::xavier(f_in, f_out),
             Initializer::Kaiming => Self::kaiming(f_in, f_out),
+            Initializer::Constant(val) => Self::constant(f_in, f_out, val),
         }
     }
 
@@ -50,14 +52,26 @@ impl Initializer {
         Self::initialize_data(f_in, f_out, 2.0 / ((f_in) as f64).sqrt())
     }
 
+    fn constant(f_in: usize, f_out: usize, val: f64) -> Tensor2D {
+        Self::initialize_with_closure(f_in, f_out, &mut || val)
+    }
+
     /// Initializes a 2D tensor with random values based on the specified standard deviation.
     fn initialize_data(f_in: usize, f_out: usize, std_dev: f64) -> Tensor2D {
         let mut rng = rand::thread_rng();
+        Self::initialize_with_closure(f_in, f_out, &mut || rng.gen::<f64>() * std_dev)
+    }
+
+    fn initialize_with_closure(
+        f_in: usize,
+        f_out: usize,
+        closure: &mut dyn FnMut() -> f64,
+    ) -> Tensor2D {
         let mut data = Vec::with_capacity(f_in);
         for _ in 0..f_in {
             let mut row = Vec::with_capacity(f_out);
             for _ in 0..f_out {
-                row.push(rng.gen::<f64>() * std_dev);
+                row.push(closure());
             }
             data.push(row);
         }
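
The refactor above extracts the element-by-element fill loop into `initialize_with_closure`, so each initializer reduces to a closure that produces one `f64` per weight: `Constant(val)` passes `|| val`, while the random initializers pass `|| rng.gen::<f64>() * std_dev`. Below is a minimal standalone sketch of that pattern. It assumes `Tensor2D` is a `Vec<Vec<f64>>`-style alias (its real definition is not shown in this diff); the `fill` helper and the dependency-free linear congruential generator are hypothetical stand-ins for `initialize_with_closure` and `rand::thread_rng()`, used here so the snippet compiles without external crates.

```rust
/// Assumed shape of the crate's `Tensor2D`; the actual type is not in this diff.
type Tensor2D = Vec<Vec<f64>>;

/// Hypothetical stand-in for `initialize_with_closure`: builds an
/// `f_in` x `f_out` tensor by calling `closure` once per element.
fn fill(f_in: usize, f_out: usize, closure: &mut dyn FnMut() -> f64) -> Tensor2D {
    let mut data = Vec::with_capacity(f_in);
    for _ in 0..f_in {
        let mut row = Vec::with_capacity(f_out);
        for _ in 0..f_out {
            row.push(closure());
        }
        data.push(row);
    }
    data
}

fn main() {
    // Constant initialization: every weight is the same value,
    // mirroring the new `Initializer::Constant(f64)` variant.
    let constant = fill(2, 3, &mut || 0.5);
    assert_eq!(constant, vec![vec![0.5; 3]; 2]);

    // Random initialization: the same loop, driven by a different closure.
    // A simple LCG stands in for `rand::thread_rng()` to keep the sketch
    // free of crate dependencies.
    let mut state: u64 = 42;
    let random = fill(2, 3, &mut || {
        state = state
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        (state >> 11) as f64 / (1u64 << 53) as f64 // uniform in [0, 1)
    });
    assert!(random.iter().flatten().all(|w| (0.0..1.0).contains(w)));
}
```

Taking `closure: &mut dyn FnMut() -> f64` rather than a generic `F: FnMut() -> f64` trades a dynamically dispatched call per element for a single monomorphized copy of the fill loop; either signature works here, and the trait-object form keeps compile time and code size down at negligible runtime cost on an initialization path.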