Skip to content

Commit

Permalink
fix(layer): resolve tensor shapes for initialization & feed_forward
Browse files Browse the repository at this point in the history
Broadcast biases to weighted_sum shape to allow element-wise addition.
  • Loading branch information
drewxs committed Jul 23, 2023
1 parent 8cff008 commit ea5f522
Showing 1 changed file with 10 additions and 7 deletions.
17 changes: 10 additions & 7 deletions src/layer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,17 @@ impl Layer {
/// ```
/// # use engram::{Layer, Initializer};
/// let layer = Layer::new(2, 3, &Initializer::Xavier);
/// assert_eq!(layer.weights.shape(), (2, 3));
/// assert_eq!(layer.weights.shape(), (3, 2));
/// assert_eq!(layer.biases.shape(), (3, 1));
/// assert_eq!(layer.d_weights.shape(), (2, 3));
/// assert_eq!(layer.d_biases.shape(), (1, 3));
/// assert_eq!(layer.d_weights.shape(), (3, 2));
/// assert_eq!(layer.d_biases.shape(), (3, 1));
/// assert!(layer.output.is_none());
/// ```
pub fn new(f_in: usize, f_out: usize, initializer: &Initializer) -> Layer {
let weights = Tensor::initialize(f_in, f_out, initializer);
let weights = Tensor::initialize(f_out, f_in, initializer);
let biases = Tensor::initialize(f_out, 1, initializer);
let d_weights = Tensor::zeros(f_in, f_out);
let d_biases = Tensor::zeros(1, f_out);
let d_weights = Tensor::zeros(f_out, f_in);
let d_biases = Tensor::zeros(f_out, 1);
let output = None;

Layer {
Expand All @@ -58,7 +58,10 @@ impl Layer {
/// assert_eq!(output.shape(), (3, 3));
/// ```
pub fn feed_forward(&mut self, inputs: &Tensor, activation: Activation) -> Tensor {
let output = activation.apply_tensor(&(self.weights).matmul(inputs).add(&self.biases));
let weighted_sum = inputs.matmul(&self.weights.transpose());
let biases_broadcasted = self.biases.broadcast_to(weighted_sum.shape());
let output = activation.apply_tensor(&weighted_sum.add(&biases_broadcasted));

self.output = Some(output.clone());
output
}
Expand Down

0 comments on commit ea5f522

Please sign in to comment.