feat(mlp): create eval mode, predict, track loss from backprop
drewxs committed Sep 8, 2023
1 parent 3cdeb5d commit 4557f23
Showing 2 changed files with 68 additions and 16 deletions.
src/neural_network/layer.rs (23 changes: 17 additions & 6 deletions)
@@ -15,6 +15,7 @@ pub struct Layer {
    pub inputs: Option<Tensor>,
    pub output: Option<Tensor>,
    pub activation: Activation,
    pub evaluation_mode: bool,
}

impl Layer {
@@ -46,6 +47,7 @@ impl Layer {
            inputs: None,
            output: None,
            activation,
            evaluation_mode: false,
        }
    }

@@ -65,8 +67,11 @@ impl Layer {
        let biases = self.biases.broadcast_to(&weighted_sum);
        let output = weighted_sum.add(&biases).activate(&self.activation);

        self.inputs = Some(inputs.clone());
        self.output = Some(output.clone());
        if !self.evaluation_mode {
            self.inputs = Some(inputs.clone());
            self.output = Some(output.clone());
        }

        output
    }

@@ -89,16 +94,16 @@ impl Layer {
        targets: &Tensor,
        loss_function: &LossFunction,
        optimizer: &mut Optimizer,
    ) {
    ) -> f64 {
        let output = match &self.output {
            Some(output) => output,
            None => panic!("Call to back_propagate without calling feed_forward first!"),
        };

        // Compute loss
        let d_loss = loss_function
            .loss(&output, &targets.resize_to(&output))
            .activate(&self.activation);
        let loss = loss_function.loss(&output, &targets.resize_to(&output));
        let d_loss = loss.activate(&self.activation);
        let mean_loss = loss.mean();

        let inputs = self.inputs.as_ref().unwrap();
        let num_samples = inputs.rows as f64;
@@ -116,5 +121,11 @@ impl Layer {
        // Update weights and biases based on gradients
        self.weights.sub_assign(&self.d_weights);
        self.biases.sub_assign(&self.d_biases);

        mean_loss
    }

    pub fn set_evaluation_mode(&mut self, evaluation_mode: bool) {
        self.evaluation_mode = evaluation_mode;
    }
}
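
For context, a minimal sketch of how these Layer changes fit together: evaluation mode suppresses the input/output caching, and back_propagate now reports the layer's mean loss. The Layer::new arguments, Optimizer variant, and LossFunction variant below are assumptions for illustration, not taken from this diff.

use engram::*;

// Assumed constructor and enum variants; the real signatures may differ.
let mut layer = Layer::new(2, 1, Activation::Sigmoid);
let mut optimizer = Optimizer::SGD { learning_rate: 0.1 };

// Training step: feed_forward caches inputs/output, which back_propagate
// consumes; back_propagate now returns the mean loss for this layer.
layer.feed_forward(&tensor![[0.0, 1.0]]);
let loss = layer.back_propagate(
    &tensor![[1.0]],
    &LossFunction::MeanSquaredError, // assumed variant name
    &mut optimizer,
);
println!("layer mean loss: {}", loss);

// Evaluation: with evaluation_mode on, feed_forward leaves the cached
// training state untouched, so later back_propagate calls would reuse
// stale state (or panic if nothing was ever cached).
layer.set_evaluation_mode(true);
let output = layer.feed_forward(&tensor![[0.0, 1.0]]);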
src/neural_network/mlp.rs (61 changes: 51 additions & 10 deletions)
@@ -141,6 +141,7 @@ impl Network {
        let mut output = inputs.clone();
        for layer in &mut self.layers {
            output = layer.feed_forward(&output);
            println!("output: {:?}", output.data[0][0]);
        }
        output
    }
@@ -161,7 +162,7 @@ impl Network {
    /// assert_ne!(network.layers[0].inputs, None);
    /// assert_ne!(network.layers[0].output, None);
    /// ```
    pub fn back_propagate(&mut self, targets: &Tensor) {
    pub fn back_propagate(&mut self, targets: &Tensor) -> f64 {
        let final_layer_output_shape = self.layers.last().unwrap().output.as_ref().unwrap().shape();
        if targets.shape() != final_layer_output_shape {
            panic!(
@@ -170,36 +171,76 @@ impl Network {
                final_layer_output_shape
            );
        }
        let mut loss = 0.0;
        for layer in self.layers.iter_mut().rev() {
            layer.back_propagate(targets, &self.loss_function, &mut self.optimizer);
            loss += layer.back_propagate(targets, &self.loss_function, &mut self.optimizer);
        }
        loss
    }

    /// Trains the network on the specified inputs and targets for the specified number of epochs.
    ///
    /// # Examples
    ///
    /// ```
    /// # use engram::*;
    /// let mut network = Network::default(&[2, 2, 1]);
    /// let inputs = tensor![[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]];
    /// let targets = tensor![[0.0], [1.0], [1.0], [0.0]];
    /// network.train(&inputs, &targets, 1, 10);
    /// let output = network.predict(&[1.0, 0.0]);
    /// let expected = 1.0;
    /// let prediction = output.data[0][0];
    /// println!("Predicted: {:.2}, Expected: {:.2}", prediction, expected);
    /// assert!((expected - prediction).abs() < 0.1);
    /// ```
    pub fn train(&mut self, inputs: &Tensor, targets: &Tensor, batch_size: usize, epochs: usize) {
        if targets.cols != self.layers.last().unwrap().weights.cols {
            panic!(
                "Target cols {:?} does not match the final layer's output cols {:?}",
                targets.shape(),
                self.layers.last().unwrap().weights.shape()
            );
        }

        self.set_evaluation_mode(false);

        let num_batches = (inputs.rows as f64 / batch_size as f64).ceil() as usize;

        for epoch in 0..epochs {
            let mut error_sum = 0.0;
            let mut total_loss = 0.0;

            for batch in 0..num_batches {
                let batch_start = batch * batch_size;
                let batch_end = (batch + 1) * batch_size;
                let inputs_batch = &inputs.slice(batch_start, batch_end);
                let targets_batch = &targets.slice(batch_start, batch_end);

                let outputs = self.feed_forward(&inputs_batch);
                let error = targets_batch.sub(&outputs);

                self.back_propagate(&targets_batch);
                self.feed_forward(&inputs_batch);
                let loss = self.back_propagate(&targets_batch);

                error_sum += error.sum();
                total_loss += loss;
            }

            if epoch % 10 == 0 {
                let mse = error_sum / (inputs.rows as f64);
                println!("Epoch: {}, MSE: {}", epoch, mse);
                let avg_loss = total_loss / (inputs.rows as f64);
                println!("Epoch: {}, Avg loss: {}", epoch, avg_loss);
            }
        }
    }

    pub fn predict(&mut self, inputs: &[f64]) -> Tensor {
        self.set_evaluation_mode(true);

        let inputs = Tensor::from(vec![inputs.to_vec()]);
        let output = self.feed_forward(&inputs);

        output
    }

    fn set_evaluation_mode(&mut self, evaluation_mode: bool) {
        for layer in &mut self.layers {
            layer.set_evaluation_mode(evaluation_mode);
        }
    }
}
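
End to end, the new Network API reads as follows. This sketch uses only calls that appear in this diff (Network::default, feed_forward, back_propagate, train, predict) plus the tensor! macro from the doc example; the epoch count is arbitrary.

use engram::*;

let mut network = Network::default(&[2, 2, 1]);
let inputs = tensor![[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]];
let targets = tensor![[0.0], [1.0], [1.0], [0.0]];

// train() puts every layer in training mode, runs feed_forward and
// back_propagate per batch, and logs the average loss every 10 epochs.
network.train(&inputs, &targets, 1, 100);

// back_propagate now returns the accumulated loss, so it can also be
// tracked manually outside train():
let _outputs = network.feed_forward(&inputs);
let loss = network.back_propagate(&targets);

// predict() flips the layers into evaluation mode so the forward pass
// no longer overwrites the cached training state.
let prediction = network.predict(&[1.0, 0.0]);
println!("loss: {}, prediction: {:?}", loss, prediction.data[0][0]);

Note that the value returned by back_propagate is the sum over layers of each layer's mean loss against the same targets, so it is best read as a relative training signal rather than a single network-level loss.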
