From 4557f23538d28ce1c3b08943c1246afd421a1df6 Mon Sep 17 00:00:00 2001
From: "Andrew X. Shah"
Date: Fri, 8 Sep 2023 14:57:26 -0600
Subject: [PATCH] feat(mlp): add eval mode and predict, track loss from backprop

---
 src/neural_network/layer.rs | 23 +++++++++++++++++------
 src/neural_network/mlp.rs   | 60 ++++++++++++++++++++++++++++++++++++++++++++++++++++----------
 2 files changed, 67 insertions(+), 16 deletions(-)

diff --git a/src/neural_network/layer.rs b/src/neural_network/layer.rs
index 02bb004..09fac82 100644
--- a/src/neural_network/layer.rs
+++ b/src/neural_network/layer.rs
@@ -15,6 +15,7 @@ pub struct Layer {
     pub inputs: Option<Tensor>,
     pub output: Option<Tensor>,
     pub activation: Activation,
+    pub evaluation_mode: bool,
 }
 
 impl Layer {
@@ -46,6 +47,7 @@ impl Layer {
             inputs: None,
             output: None,
             activation,
+            evaluation_mode: false,
         }
     }
 
@@ -65,8 +67,11 @@ impl Layer {
         let biases = self.biases.broadcast_to(&weighted_sum);
         let output = weighted_sum.add(&biases).activate(&self.activation);
 
-        self.inputs = Some(inputs.clone());
-        self.output = Some(output.clone());
+        if !self.evaluation_mode {
+            self.inputs = Some(inputs.clone());
+            self.output = Some(output.clone());
+        }
+
         output
     }
 
@@ -89,16 +94,16 @@
         targets: &Tensor,
         loss_function: &LossFunction,
         optimizer: &mut Optimizer,
-    ) {
+    ) -> f64 {
         let output = match &self.output {
             Some(output) => output,
             None => panic!("Call to back_propagate without calling feed_forward first!"),
         };
 
         // Compute loss
-        let d_loss = loss_function
-            .loss(&output, &targets.resize_to(&output))
-            .activate(&self.activation);
+        let loss = loss_function.loss(&output, &targets.resize_to(&output));
+        let d_loss = loss.activate(&self.activation);
+        let mean_loss = loss.mean();
 
         let inputs = self.inputs.as_ref().unwrap();
         let num_samples = inputs.rows as f64;
@@ -116,5 +121,11 @@
         // Update weights and biases based on gradients
         self.weights.sub_assign(&self.d_weights);
         self.biases.sub_assign(&self.d_biases);
+
+        mean_loss
+    }
+
+    pub fn set_evaluation_mode(&mut self, evaluation_mode: bool) {
+        self.evaluation_mode = evaluation_mode;
     }
 }
diff --git a/src/neural_network/mlp.rs b/src/neural_network/mlp.rs
index fbac8c6..ca65f76 100644
--- a/src/neural_network/mlp.rs
+++ b/src/neural_network/mlp.rs
@@ -161,7 +161,7 @@
     /// assert_ne!(network.layers[0].inputs, None);
     /// assert_ne!(network.layers[0].output, None);
     /// ```
-    pub fn back_propagate(&mut self, targets: &Tensor) {
+    pub fn back_propagate(&mut self, targets: &Tensor) -> f64 {
         let final_layer_output_shape = self.layers.last().unwrap().output.as_ref().unwrap().shape();
         if targets.shape() != final_layer_output_shape {
             panic!(
@@ -170,17 +170,44 @@
                 final_layer_output_shape
             );
         }
+        let mut loss = 0.0;
         for layer in self.layers.iter_mut().rev() {
-            layer.back_propagate(targets, &self.loss_function, &mut self.optimizer);
+            loss += layer.back_propagate(targets, &self.loss_function, &mut self.optimizer);
         }
+        loss
     }
 
     /// Trains the network on the specified inputs and targets for the specified number of epochs.
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// # use engram::*;
+    /// let mut network = Network::default(&[2, 2, 1]);
+    /// let inputs = tensor![[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]];
+    /// let targets = tensor![[0.0], [1.0], [1.0], [0.0]];
+    /// network.train(&inputs, &targets, 1, 10);
+    /// let output = network.predict(&[1.0, 0.0]);
+    /// let expected = 1.0;
+    /// let prediction = output.data[0][0];
+    /// println!("Predicted: {:.2}, Expected: {:.2}", prediction, expected);
+    /// assert!((expected - prediction).abs() < 0.1);
+    /// ```
     pub fn train(&mut self, inputs: &Tensor, targets: &Tensor, batch_size: usize, epochs: usize) {
+        if targets.cols != self.layers.last().unwrap().weights.cols {
+            panic!(
+                "Target cols ({}) does not match the final layer's output cols ({})",
+                targets.cols,
+                self.layers.last().unwrap().weights.cols
+            );
+        }
+
+        self.set_evaluation_mode(false);
+
         let num_batches = (inputs.rows as f64 / batch_size as f64).ceil() as usize;
 
         for epoch in 0..epochs {
-            let mut error_sum = 0.0;
+            let mut total_loss = 0.0;
 
             for batch in 0..num_batches {
                 let batch_start = batch * batch_size;
@@ -188,18 +215,31 @@
                 let inputs_batch = &inputs.slice(batch_start, batch_end);
                 let targets_batch = &targets.slice(batch_start, batch_end);
 
-                let outputs = self.feed_forward(&inputs_batch);
-                let error = targets_batch.sub(&outputs);
-
-                self.back_propagate(&targets_batch);
+                self.feed_forward(&inputs_batch);
+                let loss = self.back_propagate(&targets_batch);
 
-                error_sum += error.sum();
+                total_loss += loss;
             }
 
             if epoch % 10 == 0 {
-                let mse = error_sum / (inputs.rows as f64);
-                println!("Epoch: {}, MSE: {}", epoch, mse);
+                let avg_loss = total_loss / num_batches as f64;
+                println!("Epoch: {}, Avg loss: {}", epoch, avg_loss);
             }
         }
     }
+
+    pub fn predict(&mut self, inputs: &[f64]) -> Tensor {
+        self.set_evaluation_mode(true);
+
+        let inputs = Tensor::from(vec![inputs.to_vec()]);
+        let output = self.feed_forward(&inputs);
+
+        output
+    }
+
+    fn set_evaluation_mode(&mut self, evaluation_mode: bool) {
+        for layer in &mut self.layers {
+            layer.set_evaluation_mode(evaluation_mode);
+        }
+    }
 }
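For reviewers, a minimal end-to-end sketch of how the API reads after this patch. It uses only names visible in the diff (`Network::default`, the `tensor!` macro, `train`, `predict`, `Tensor::data`); the epoch count and the `main` wrapper are illustrative assumptions, not values the patch guarantees:

```rust
use engram::*;

fn main() {
    // Build a 2-2-1 network, as in the doc test added to `train`.
    let mut network = Network::default(&[2, 2, 1]);

    // XOR training set, also taken from the doc test.
    let inputs = tensor![[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]];
    let targets = tensor![[0.0], [1.0], [1.0], [0.0]];

    // `train` clears evaluation mode, so each Layer caches its
    // inputs/outputs for back_propagate; the summed per-layer mean
    // loss now flows back up and is logged every 10th epoch.
    // The epoch count here is an assumption for illustration.
    network.train(&inputs, &targets, 1, 1000);

    // `predict` switches the layers into evaluation mode, so
    // feed_forward skips the caching that only backprop needs.
    let output = network.predict(&[1.0, 0.0]);
    println!("Prediction for [1, 0]: {:.3}", output.data[0][0]);
}
```

One behavior worth noting from the diff: `predict` leaves `evaluation_mode` set to `true`, so a direct `back_propagate` call afterwards would read stale cached activations, or panic if nothing was ever cached. `train` resets the flag on entry, so the two public entry points stay consistent with each other.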