test(mlp/new,default): add doc tests
drewxs committed Sep 7, 2023
1 parent faa5290 commit 9b7a23b
Showing 1 changed file with 43 additions and 0 deletions.
43 changes: 43 additions & 0 deletions src/neural_network/mlp.rs
@@ -13,6 +13,35 @@ pub struct Network {

impl Network {
    /// Creates a new `Network` with the specified layer sizes, initializer,
    /// activation function, loss function, and optimizer.
    ///
    /// # Examples
    ///
    /// ```
    /// # use engram::*;
    /// let network = Network::new(
    ///     &[6, 4, 1],
    ///     Initializer::Xavier,
    ///     Activation::ReLU,
    ///     LossFunction::MeanSquaredError,
    ///     Optimizer::Adagrad {
    ///         learning_rate: 0.1,
    ///         epsilon: 1e-8,
    ///         weight_decay: Some(0.01),
    ///         shape: (3, 1),
    ///     },
    /// );
    /// assert_eq!(network.layers.len(), 2);
    /// assert_eq!(network.layers[0].weights.shape(), (6, 4));
    /// assert_eq!(network.layers[1].weights.shape(), (4, 1));
    /// assert_eq!(network.layers[0].activation, Activation::ReLU);
    /// assert_eq!(network.loss_function, LossFunction::MeanSquaredError);
    /// assert_eq!(network.optimizer, Optimizer::Adagrad {
    ///     learning_rate: 0.1,
    ///     epsilon: 1e-8,
    ///     weight_decay: Some(0.01),
    ///     shape: (3, 1),
    /// });
    /// ```
    pub fn new(
        layer_sizes: &[usize],
        initializer: Initializer,
@@ -42,6 +71,20 @@ impl Network {

    /// Creates a new `Network` with defaults: Xavier initialization, sigmoid
    /// activation, binary cross-entropy loss, and a stochastic gradient descent
    /// optimizer with a learning rate of 0.1.
    ///
    /// # Examples
    ///
    /// ```
    /// # use engram::*;
    /// let network = Network::default(&[6, 4, 2, 3]);
    /// assert_eq!(network.layers.len(), 3);
    /// assert_eq!(network.layers[0].weights.shape(), (6, 4));
    /// assert_eq!(network.layers[1].weights.shape(), (4, 2));
    /// assert_eq!(network.layers[2].weights.shape(), (2, 3));
    /// assert_eq!(network.layers[0].activation, Activation::Sigmoid);
    /// assert_eq!(network.loss_function, LossFunction::BinaryCrossEntropy);
    /// assert_eq!(network.optimizer, Optimizer::SGD { learning_rate: 0.1 });
    /// ```
    pub fn default(layer_sizes: &[usize]) -> Network {
        Network::new(
            layer_sizes,
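The assertions in these doc tests also pin down the shape convention: an `n`-element `layer_sizes` slice produces `n - 1` layers, and each layer's weight matrix is shaped by the adjacent pair of sizes. A minimal standalone sketch of that mapping (`weight_shapes` is a hypothetical helper for illustration, not part of engram):

```rust
/// Maps consecutive layer sizes to weight-matrix shapes, mirroring the
/// doc-test assertions above.
fn weight_shapes(layer_sizes: &[usize]) -> Vec<(usize, usize)> {
    layer_sizes
        .windows(2) // each adjacent (in_size, out_size) pair
        .map(|pair| (pair[0], pair[1]))
        .collect()
}

fn main() {
    // &[6, 4, 2, 3] yields 3 layers: (6, 4), (4, 2), (2, 3),
    // matching the `Network::default` example above.
    assert_eq!(weight_shapes(&[6, 4, 2, 3]), vec![(6, 4), (4, 2), (2, 3)]);
}
```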

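For readers new to Rust doc tests: lines starting with `# ` inside a doc example, like `# use engram::*;` above, are compiled and executed by the test harness but hidden from the rendered documentation, and `cargo test --doc` runs each fenced block as its own test. A generic sketch of the convention (hypothetical `double` function in a hypothetical `my_crate`, not engram's API):

````rust
/// Doubles a value.
///
/// # Examples
///
/// ```
/// # // Hidden setup: compiled into the test, omitted from rendered docs.
/// # use my_crate::double;
/// assert_eq!(double(21), 42);
/// ```
pub fn double(x: i32) -> i32 {
    x * 2
}
````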