
Commit

test(mlp): add simple constant unit tests
krishnachittur committed Sep 9, 2023
1 parent 3980caf commit c8bc9f6
Showing 1 changed file with 52 additions and 0 deletions.
52 changes: 52 additions & 0 deletions src/neural_network/mlp.rs
@@ -141,6 +141,7 @@ impl Network {
        let mut output = inputs.clone();
        for layer in &mut self.layers {
            output = layer.feed_forward(&output);
            println!("Output: {:?}", output.data);
        }
        output
    }
@@ -291,3 +292,54 @@ impl Network {
        }
    }
}

#[cfg(test)]
mod tests {
    use crate::*;

    #[test]
    fn test_correct_constant_network() {
        // Test a simple network with a 1x1 layer and a 1x1 output.
        let mut network = Network::new(
            &[1, 1],
            Initializer::Constant(1.),
            Activation::ReLU,
            LossFunction::MeanSquaredError,
            Optimizer::SGD { learning_rate: 0.1 },
        );
        // The outputs are just 1 times the input plus 1, so the goal is for the
        // network to learn the weights [[1.0]] and bias [1.0].
        let inputs = tensor![[0.0], [1.0], [2.0], [3.0]];
        let targets = tensor![[1.0], [2.0], [3.0], [4.0]];

        network.train(&inputs, &targets, 4, 10);
        let output = network.predict(&[4.0]);
        let expected = 5.0;
        let prediction = output.data[0][0];
        println!("Predicted: {:.2}, Expected: {:.2}", prediction, expected);
        assert!((expected - prediction).abs() < 0.1);
    }

    #[test]
    fn test_constant_network() {
        // Again test a network with a 1x1 layer and a 1x1 output, but this time
        // we want the weight to stay at 0 and the bias to increase to 1.
        let mut network = Network::new(
            &[1, 1],
            Initializer::Constant(0.),
            Activation::ReLU,
            LossFunction::MeanSquaredError,
            Optimizer::SGD { learning_rate: 0.1 },
        );
        // The target outputs are all 1, i.e. 0 times the input plus 1.
        let inputs = tensor![[0.0], [1.0], [2.0], [3.0]];
        let targets = tensor![[1.], [1.], [1.], [1.]];

        network.train(&inputs, &targets, 4, 2);
        let output = network.predict(&[4.0]);
        let expected = 1.0;
        let prediction = output.data[0][0];
        println!("Predicted: {:.2}, Expected: {:.2}", prediction, expected);
        assert!((expected - prediction).abs() < 0.1);
    }
}
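
A quick sanity check of the first test's expectation, sketched below (illustrative only, not part of the commit): assuming Initializer::Constant(1.) sets both the 1x1 weight and the bias to 1.0, the layer already computes x + 1 for every training input, so the mean-squared-error gradients are zero and SGD leaves the parameters where they started; the prediction for 4.0 should therefore stay at 5.0. The sketch works that arithmetic out in plain Rust under the usual relu(w * x + b) forward pass; the names relu, w, b, dw, and db are local to the sketch and do not come from the repository's Network API.

// Standalone sanity check (assumed model: y = relu(w * x + b), MSE loss).
// With w = 1.0 and b = 1.0 every target in the first test is matched
// exactly, so the SGD gradients are zero and training changes nothing;
// predicting on 4.0 then yields 5.0.
fn relu(x: f64) -> f64 {
    if x > 0.0 { x } else { 0.0 }
}

fn main() {
    let inputs = [0.0, 1.0, 2.0, 3.0];
    let targets = [1.0, 2.0, 3.0, 4.0];
    let (w, b) = (1.0_f64, 1.0_f64); // what Initializer::Constant(1.) is assumed to produce

    // Accumulate the MSE gradients over the batch (ReLU is linear here,
    // since every pre-activation w * x + b is positive).
    let (mut dw, mut db) = (0.0, 0.0);
    for (&x, &t) in inputs.iter().zip(targets.iter()) {
        let y = relu(w * x + b);
        let err = y - t; // d(MSE)/dy up to a constant factor
        dw += err * x;
        db += err;
    }
    println!("dw = {dw}, db = {db}"); // both 0.0: (w, b) = (1, 1) is already optimal
    println!("prediction for 4.0: {}", relu(w * 4.0 + b)); // 5.0
}

The same arithmetic applied to the second test's all-ones targets gives a least-squares fit of weight 0.0 and bias 1.0, so there the prediction should settle near 1.0 for any input.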
