feat(loss): impl gradient loss
drewxs committed Nov 28, 2023
1 parent 21a4b82 commit 4709fdf
Showing 6 changed files with 60 additions and 31 deletions.
24 changes: 22 additions & 2 deletions src/loss/bce.rs
@@ -10,11 +10,11 @@ use crate::Tensor;
/// # use engram::*;
/// let predictions = tensor![[0.9, 0.1, 0.2], [0.1, 0.9, 0.8]];
/// let targets = tensor![[1.0, 0.0, 0.0], [0.0, 1.0, 1.0]];
- /// let loss = binary_cross_entropy(&predictions, &targets);
+ /// let loss = bce(&predictions, &targets);
/// assert_eq!(loss, tensor![[0.10536051565782628, 0.10536051565782628, 0.2231435513142097],
/// [0.10536051565782628, 0.10536051565782628, 0.2231435513142097]]);
/// ```
- pub fn binary_cross_entropy(predictions: &Tensor, targets: &Tensor) -> Tensor {
+ pub fn bce(predictions: &Tensor, targets: &Tensor) -> Tensor {
assert_eq!(
predictions.shape(),
targets.shape(),
@@ -34,3 +34,23 @@ pub fn binary_cross_entropy(predictions: &Tensor, targets: &Tensor) -> Tensor {

loss
}

+ pub fn d_bce(predictions: &Tensor, targets: &Tensor) -> Tensor {
+ assert_eq!(
+ predictions.shape(),
+ targets.shape(),
+ "Shapes of predictions and targets must match."
+ );
+
+ let epsilon = 1e-15; // Small constant to avoid log(0)
+ let predictions = predictions.clip(epsilon, 1.0 - epsilon);
+ let ones = Tensor::ones_like(&predictions);
+ let predictions_complement = ones.sub(&predictions);
+ let targets_complement = ones.sub(&targets);
+
+ let gradient = targets
+ .div(&predictions)
+ .sub(&targets_complement.div(&predictions_complement));
+
+ gradient
+ }
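
For reference, a sketch of the calculus behind `d_bce` (not part of the commit), using the same element-wise convention as the `bce` doc example above:

```latex
\ell(p, t) = -\bigl(t \ln p + (1 - t) \ln(1 - p)\bigr),
\qquad
\frac{\partial \ell}{\partial p} = \frac{1 - t}{1 - p} - \frac{t}{p} = \frac{p - t}{p\,(1 - p)}
```

The expression `d_bce` builds from the clipped predictions, `t/p - (1 - t)/(1 - p)`, is the negative of this derivative, so the sign is assumed to be handled by whichever caller descends the loss.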
15 changes: 11 additions & 4 deletions src/loss/mod.rs
@@ -10,16 +10,23 @@ use crate::Tensor;

/// Loss functions that can be used to train a neural network.
#[derive(Debug, Clone, Copy, PartialEq, Hash)]
- pub enum LossFunction {
+ pub enum Loss {
BinaryCrossEntropy,
MeanSquaredError,
}

- impl LossFunction {
+ impl Loss {
pub fn loss(&self, predictions: &Tensor, targets: &Tensor) -> Tensor {
match self {
- LossFunction::BinaryCrossEntropy => binary_cross_entropy(predictions, targets),
- LossFunction::MeanSquaredError => mean_squared_error(predictions, targets),
+ Loss::BinaryCrossEntropy => bce(predictions, targets),
+ Loss::MeanSquaredError => mse(predictions, targets),
}
}

+ pub fn gradient(&self, predictions: &Tensor, targets: &Tensor) -> Tensor {
+ match self {
+ Loss::BinaryCrossEntropy => d_bce(predictions, targets),
+ Loss::MeanSquaredError => d_mse(predictions, targets),
+ }
+ }
}
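
A minimal usage sketch for the new `gradient` hook alongside the existing `loss` call. This is illustrative only: the `main` wrapper and the shape assertion are assumptions, while `tensor!`, `Loss::loss`, and `Loss::gradient` come from the diff itself.

```rust
use engram::*;

fn main() {
    // Values borrowed from the mse doc example further down.
    let predictions = tensor![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
    let targets = tensor![[1.2, 2.4, 3.1], [4.4, 4.7, 5.9]];

    let loss_fn = Loss::MeanSquaredError;
    let loss = loss_fn.loss(&predictions, &targets);         // element-wise (p - t)^2
    let gradient = loss_fn.gradient(&predictions, &targets); // element-wise 2 * (p - t)

    // Both calls are element-wise, so the outputs share the inputs' shape.
    assert_eq!(loss.shape(), gradient.shape());
}
```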
8 changes: 6 additions & 2 deletions src/loss/mse.rs
@@ -10,10 +10,14 @@ use crate::Tensor;
/// # use engram::*;
/// let predictions = tensor![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
/// let targets = tensor![[1.2, 2.4, 3.1], [4.4, 4.7, 5.9]];
- /// let loss = mean_squared_error(&predictions, &targets);
+ /// let loss = mse(&predictions, &targets);
/// assert_eq!(loss, tensor![[0.03999999999999998, 0.15999999999999992, 0.010000000000000018],
/// [0.16000000000000028, 0.0899999999999999, 0.009999999999999929]]);
/// ```
- pub fn mean_squared_error(predictions: &Tensor, targets: &Tensor) -> Tensor {
+ pub fn mse(predictions: &Tensor, targets: &Tensor) -> Tensor {
predictions.sub(targets).square()
}

+ pub fn d_mse(predictions: &Tensor, targets: &Tensor) -> Tensor {
+ predictions.sub(targets).mul_scalar(2.0)
+ }
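
The corresponding one-line derivation for `d_mse` (same element-wise convention):

```latex
\frac{\partial}{\partial p} (p - t)^2 = 2\,(p - t)
```

which is exactly the `sub` followed by `mul_scalar(2.0)` above.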
4 changes: 2 additions & 2 deletions src/neural_network/layer.rs
@@ -3,7 +3,7 @@
//! This module provides a Layer struct for representing a single layer in a neural network,
//! along with methods for feeding inputs through the layer and performing backpropagation.

- use crate::{Activation, Initializer, LossFunction, Optimize, Optimizer, Tensor};
+ use crate::{Activation, Initializer, Loss, Optimize, Optimizer, Tensor};

/// A single layer in a neural network.
#[derive(Debug)]
@@ -116,7 +116,7 @@ impl Layer {
pub fn back_propagate(
&mut self,
targets: &Tensor,
- loss_function: &LossFunction,
+ loss_function: &Loss,
optimizer: &mut Optimizer,
) -> f64 {
let output = match &self.output {
20 changes: 10 additions & 10 deletions src/neural_network/mlp.rs
@@ -3,14 +3,14 @@
//! A generic feed forward neural network (FNN), also known as a multi-layer perceptron (MLP).
//! Typically used for classification and regression tasks.

- use crate::{Activation, Initializer, Layer, LossFunction, Optimizer, Tensor};
+ use crate::{Activation, Initializer, Layer, Loss, Optimizer, Tensor};

#[derive(Debug)]
pub struct Network {
/// The layers in the network.
pub layers: Vec<Layer>,
/// The loss function used to train the network.
- pub loss_function: LossFunction,
+ pub loss_function: Loss,
/// The initializer used to initialize the weights and biases in the network.
pub initializer: Initializer,
/// The optimizer used to optimize the weights and biases in the network.
@@ -29,7 +29,7 @@ impl Network {
/// &[6, 4, 1],
/// Initializer::Xavier,
/// Activation::ReLU,
- /// LossFunction::MeanSquaredError,
+ /// Loss::MeanSquaredError,
/// Optimizer::Adagrad {
/// learning_rate: 0.1,
/// shape: (3, 1),
@@ -42,7 +42,7 @@ impl Network {
/// assert_eq!(network.layers[0].weights.shape(), (6, 4));
/// assert_eq!(network.layers[1].weights.shape(), (4, 1));
/// assert_eq!(network.layers[0].activation, Activation::ReLU);
- /// assert_eq!(network.loss_function, LossFunction::MeanSquaredError);
+ /// assert_eq!(network.loss_function, Loss::MeanSquaredError);
/// assert_eq!(network.optimizer, Optimizer::Adagrad {
/// learning_rate: 0.1,
/// shape: (3, 1),
@@ -54,7 +54,7 @@ impl Network {
layer_sizes: &[usize],
initializer: Initializer,
activation: Activation,
- loss_function: LossFunction,
+ loss_function: Loss,
optimizer: Optimizer,
) -> Network {
let mut layers = Vec::new();
@@ -92,15 +92,15 @@ impl Network {
/// assert_eq!(network.layers[1].weights.shape(), (4, 2));
/// assert_eq!(network.layers[2].weights.shape(), (2, 3));
/// assert_eq!(network.layers[0].activation, Activation::Sigmoid);
- /// assert_eq!(network.loss_function, LossFunction::BinaryCrossEntropy);
+ /// assert_eq!(network.loss_function, Loss::MeanSquaredError);
/// assert_eq!(network.optimizer, Optimizer::SGD { learning_rate: 0.1 })
/// ```
pub fn default(layer_sizes: &[usize]) -> Network {
Network::new(
layer_sizes,
Initializer::Xavier,
Activation::Sigmoid,
- LossFunction::BinaryCrossEntropy,
+ Loss::MeanSquaredError,
Optimizer::SGD { learning_rate: 0.1 },
)
}
@@ -228,7 +228,7 @@ impl Network {
/// &[2, 3, 1],
/// Initializer::Xavier,
/// Activation::ReLU,
- /// LossFunction::MeanSquaredError,
+ /// Loss::MeanSquaredError,
/// Optimizer::SGD { learning_rate: 0.1 },
/// );
///
@@ -323,7 +323,7 @@ mod tests {
&[1, 1],
Initializer::Constant(1.),
Activation::ReLU,
- LossFunction::MeanSquaredError,
+ Loss::MeanSquaredError,
Optimizer::SGD { learning_rate: 0.1 },
);
// The outputs are just 1 times the input plus 1, so the goal is for the
@@ -348,7 +348,7 @@ mod tests {
&[1, 1],
Initializer::Constant(0.),
Activation::ReLU,
- LossFunction::MeanSquaredError,
+ Loss::MeanSquaredError,
Optimizer::SGD { learning_rate: 0.1 },
);
// The outputs are just 1 times the input plus 0.
20 changes: 9 additions & 11 deletions src/neural_network/sequential.rs
@@ -2,7 +2,7 @@
//!
//! Allows creating a network using a builder pattern.

- use crate::{Activation, Initializer, Layer, LossFunction, Network, Optimizer};
+ use crate::{Activation, Initializer, Layer, Loss, Network, Optimizer};

/// A builder for creating a `Network`.
///
@@ -12,20 +12,20 @@ use crate::{Activation, Initializer, Layer, LossFunction, Network, Optimizer};
/// # use engram::*;
/// let network = Sequential::new(&[2, 3, 1])
/// .activation(Activation::ReLU)
- /// .loss_function(LossFunction::MeanSquaredError)
+ /// .loss_function(Loss::MeanSquaredError)
/// .optimizer(Optimizer::SGD { learning_rate: 0.1 })
/// .build();
/// assert_eq!(network.layers[0].weights.shape(), (2, 3));
/// assert_eq!(network.layers[1].weights.shape(), (3, 1));
/// assert_eq!(network.layers[0].activation, Activation::ReLU);
- /// assert_eq!(network.loss_function, LossFunction::MeanSquaredError);
+ /// assert_eq!(network.loss_function, Loss::MeanSquaredError);
/// assert_eq!(network.optimizer, Optimizer::SGD { learning_rate: 0.1 })
/// ```
pub struct Sequential {
layer_sizes: Vec<usize>,
initializer: Option<Initializer>,
activation: Option<Activation>,
- loss_function: Option<LossFunction>,
+ loss_function: Option<Loss>,
optimizer: Option<Optimizer>,
}

@@ -84,10 +84,10 @@ impl Sequential {
///
/// ```
/// # use engram::*;
- /// let network = Sequential::new(&[2, 3, 1]).loss_function(LossFunction::MeanSquaredError).build();
- /// assert_eq!(network.loss_function, LossFunction::MeanSquaredError);
+ /// let network = Sequential::new(&[2, 3, 1]).loss_function(Loss::MeanSquaredError).build();
+ /// assert_eq!(network.loss_function, Loss::MeanSquaredError);
/// ```
- pub fn loss_function(mut self, loss_function: LossFunction) -> Self {
+ pub fn loss_function(mut self, loss_function: Loss) -> Self {
self.loss_function = Some(loss_function);
self
}
@@ -116,15 +116,13 @@ impl Sequential {
/// assert_eq!(network.layers[0].weights.shape(), (2, 3));
/// assert_eq!(network.layers[1].weights.shape(), (3, 1));
/// assert_eq!(network.layers[0].activation, Activation::Sigmoid);
- /// assert_eq!(network.loss_function, LossFunction::BinaryCrossEntropy);
+ /// assert_eq!(network.loss_function, Loss::MeanSquaredError);
/// assert_eq!(network.optimizer, Optimizer::SGD { learning_rate: 0.1 })
/// ```
pub fn build(self) -> Network {
let initializer = self.initializer.unwrap_or(Initializer::Xavier);
let activation = self.activation.unwrap_or(Activation::Sigmoid);
- let loss_function = self
- .loss_function
- .unwrap_or(LossFunction::BinaryCrossEntropy);
+ let loss_function = self.loss_function.unwrap_or(Loss::MeanSquaredError);
let optimizer = self
.optimizer
.unwrap_or(Optimizer::SGD { learning_rate: 0.1 });
