Commit

docs: update general
drewxs committed Sep 8, 2023
1 parent fbf378d commit 3cdeb5d
Showing 6 changed files with 21 additions and 9 deletions.
1 change: 1 addition & 0 deletions src/loss/mod.rs
@@ -8,6 +8,7 @@ pub use mse::*;

use crate::Tensor;

+/// Loss functions that can be used to train a neural network.
#[derive(Debug, Clone, Copy, PartialEq, Hash)]
pub enum LossFunction {
    BinaryCrossEntropy,
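For reference, binary cross-entropy (the first variant above) averages −[y·ln(p) + (1−y)·ln(1−p)] over all samples. A minimal sketch over plain slices, assuming the crate's `Tensor`-based implementation follows the same formula:

```rust
/// Sketch only: binary cross-entropy over flat slices; the crate's real
/// implementation in `src/loss` operates on `Tensor`s and may differ.
fn binary_cross_entropy(targets: &[f64], predictions: &[f64]) -> f64 {
    let n = targets.len() as f64;
    targets
        .iter()
        .zip(predictions)
        .map(|(y, p)| {
            let p = p.clamp(1e-12, 1.0 - 1e-12); // avoid ln(0)
            y * p.ln() + (1.0 - y) * (1.0 - p).ln()
        })
        .sum::<f64>()
        / -n
}
```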
11 changes: 7 additions & 4 deletions src/neural_network/initializer.rs
@@ -4,9 +4,16 @@ use rand::Rng;

use crate::Tensor2D;

+/// Initialization methods for weight matrices.
#[derive(Debug)]
pub enum Initializer {
+    /// Xavier/Glorot initialization.
+    /// Use for Sigmoid and TanH activation functions.
+    /// Gaussian, µ = 0, σ = √[2 / (f_in + f_out)]
    Xavier,
+    /// Kaiming initialization.
+    /// Use for ReLU and LeakyReLU activation functions.
+    /// Gaussian, µ = 0, σ = √[2 / f_in]
    Kaiming,
}

@@ -34,15 +41,11 @@ impl Initializer {
}

    /// Xavier/Glorot initialization
-    /// Use for sigmoid and tanh activation functions
-    /// Gaussian, µ = 0, σ = √[2 / (f_in + f_out)]
    fn xavier(f_in: usize, f_out: usize) -> Tensor2D {
        Self::initialize_data(f_in, f_out, (2.0 / (f_in + f_out) as f64).sqrt())
    }

    /// Kaiming initialization
-    /// Use for ReLU activation functions
-    /// Gaussian, µ = 0, σ = √[2 / f_in]
    fn kaiming(f_in: usize, f_out: usize) -> Tensor2D {
        Self::initialize_data(f_in, f_out, (2.0 / f_in as f64).sqrt())
    }
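The two σ formulas above pin down the whole scheme. A hedged sketch of the sampling step using the `rand` and `rand_distr` crates, standing in for the crate's own `initialize_data` helper (whose body isn't shown in this diff):

```rust
use rand_distr::{Distribution, Normal};

/// Illustrative stand-in for `initialize_data`: fills an f_out x f_in
/// row-major matrix with samples from N(0, sigma^2).
fn gaussian_init(f_in: usize, f_out: usize, sigma: f64) -> Vec<Vec<f64>> {
    let normal = Normal::new(0.0, sigma).expect("sigma must be finite and >= 0");
    let mut rng = rand::thread_rng();
    (0..f_out)
        .map(|_| (0..f_in).map(|_| normal.sample(&mut rng)).collect())
        .collect()
}

fn main() {
    let (f_in, f_out) = (256, 128); // arbitrary example dimensions
    // Xavier: sigma = sqrt(2 / (f_in + f_out)), for Sigmoid/TanH layers.
    let w_xavier = gaussian_init(f_in, f_out, (2.0 / (f_in + f_out) as f64).sqrt());
    // Kaiming: sigma = sqrt(2 / f_in), for ReLU/LeakyReLU layers.
    let w_kaiming = gaussian_init(f_in, f_out, (2.0 / f_in as f64).sqrt());
    assert_eq!((w_xavier.len(), w_kaiming[0].len()), (f_out, f_in));
}
```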
6 changes: 6 additions & 0 deletions src/neural_network/mlp.rs
@@ -1,13 +1,19 @@
//! Multi-layer perceptron (MLP) neural network.
//!
+//! A generic feed forward neural network (FNN), also known as a multi-layer perceptron (MLP).
+//! Typically used for classification and regression tasks.

use crate::{Activation, Initializer, Layer, LossFunction, Optimizer, Tensor};

#[derive(Debug)]
pub struct Network {
+    /// The layers in the network.
    pub layers: Vec<Layer>,
+    /// The loss function used to train the network.
    pub loss_function: LossFunction,
+    /// The initializer used to initialize the weights and biases in the network.
    pub initializer: Initializer,
+    /// The optimizer used to optimize the weights and biases in the network.
    pub optimizer: Optimizer,
}

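Putting the four documented fields together, a hypothetical construction might look like the following. Only the field names and the enum variants shown elsewhere in this commit come from the crate; `Layer::new`, its signature, and the layer sizes are assumptions for illustration:

```rust
// Hypothetical usage sketch; `Network`, `Layer`, etc. are the crate-root
// items imported in mlp.rs above.
let network = Network {
    layers: vec![
        Layer::new(784, 128), // assumed signature: (inputs, outputs)
        Layer::new(128, 10),
    ],
    loss_function: LossFunction::BinaryCrossEntropy,
    initializer: Initializer::Kaiming,
    optimizer: Optimizer::SGD { learning_rate: 0.01 },
};
```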
4 changes: 3 additions & 1 deletion src/neural_network/regularization.rs
@@ -20,11 +20,13 @@

use crate::Tensor;

-/// An enum representing the regularization methods that can be applied to a tensor.
+/// Regularization methods that can be applied to tensors.
pub enum Regularization {
    /// L1 regularization with the given lambda value.
+    /// Also known as LASSO, used to encourage sparsity.
    L1(f64),
    /// L2 regularization with the given lambda value.
+    /// Also known as Ridge, used to encourage small weights.
    L2(f64),
}

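Since each variant carries its λ directly, applying a penalty is a single match. A sketch over a flat weight slice, assuming the crate's `Tensor`-based methods compute the same sums:

```rust
/// Sketch of the penalty terms; the real methods operate on `Tensor`.
fn penalty(reg: &Regularization, weights: &[f64]) -> f64 {
    match reg {
        // L1 (LASSO): lambda * sum(|w|), drives weights toward exact zeros.
        Regularization::L1(lambda) => lambda * weights.iter().map(|w| w.abs()).sum::<f64>(),
        // L2 (Ridge): lambda * sum(w^2), shrinks weights smoothly.
        Regularization::L2(lambda) => lambda * weights.iter().map(|w| w * w).sum::<f64>(),
    }
}
```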
6 changes: 3 additions & 3 deletions src/optimizer/mod.rs
@@ -15,9 +15,9 @@ use crate::Tensor;
/// Optimizer enum that allows for different optimizers to be used with neural networks.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Optimizer {
-    SGD {
-        learning_rate: f64,
-    },
+    /// Stochastic Gradient Descent (SGD) optimizer.
+    SGD { learning_rate: f64 },
+    /// An adaptive gradient descent optimizer.
    Adagrad {
        learning_rate: f64,
        epsilon: f64,
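Because the struct-style variants carry their hyperparameters inline, construction is just a struct-variant literal. Only the SGD variant is shown here, since Adagrad's full field list is truncated above:

```rust
// The learning rate is an arbitrary example value.
let opt = Optimizer::SGD { learning_rate: 0.01 };
```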
2 changes: 1 addition & 1 deletion src/optimizer/sgd.rs
@@ -1,4 +1,4 @@
-//! Stochastic Gradient Descent (SGD)
+//! Stochastic Gradient Descent (SGD).
//!
//! A basic optimizer that updates the weights based on the gradients of the loss function
//! with respect to the weights multiplied by a learning rate.
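The rule described above is w ← w − η · ∂L/∂w. A minimal sketch over flat slices, assuming the crate's version performs the same element-wise update on `Tensor`s:

```rust
/// One SGD step: w <- w - learning_rate * grad, element-wise.
fn sgd_step(weights: &mut [f64], grads: &[f64], learning_rate: f64) {
    for (w, g) in weights.iter_mut().zip(grads) {
        *w -= learning_rate * g;
    }
}
```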
