From 876bb81cb78159b1dc003d0f5b0090db96e849c7 Mon Sep 17 00:00:00 2001
From: "Andrew X. Shah"
Date: Sun, 10 Mar 2024 07:58:05 -0600
Subject: [PATCH] feat(tensor): add grad field, create _grad methods

---
 src/linalg/tensor.rs              | 15 ++++++-----
 src/linalg/tensor/computation.rs  |  1 +
 src/linalg/tensor/creation.rs     |  2 ++
 src/linalg/tensor/from.rs         | 12 +++++++++
 src/linalg/tensor/iter.rs         |  7 +++++-
 src/linalg/tensor/manipulation.rs |  1 +
 src/linalg/tensor/state.rs        | 42 +++++++++++++++++++++++++++++++
 7 files changed, 73 insertions(+), 7 deletions(-)
 create mode 100644 src/linalg/tensor/state.rs

diff --git a/src/linalg/tensor.rs b/src/linalg/tensor.rs
index 1b41318..814c223 100644
--- a/src/linalg/tensor.rs
+++ b/src/linalg/tensor.rs
@@ -14,21 +14,24 @@ mod macros;
 mod manipulation;
 mod ops;
 mod sparse;
+mod state;
 mod transformation;
 mod validation;
 
-/// A one-dimensional matrix of floating point values.
+/// A one-dimensional tensor of floating point values.
 pub type Tensor1D = Vec<f64>;
-/// A two-dimensional matrix of floating point values.
+/// A two-dimensional tensor of floating point values.
 pub type Tensor2D = Vec<Tensor1D>;
 
-/// A matrix of floating point values, represented as a two-dimensional vector.
+/// A tensor of floating point values.
 #[derive(Clone, Debug)]
 pub struct Tensor {
-    /// The number of rows in the matrix.
+    /// The number of rows in the tensor.
     pub rows: usize,
-    /// The number of columns in the matrix.
+    /// The number of columns in the tensor.
     pub cols: usize,
-    /// The data in the matrix, represented as a two-dimensional vector.
+    /// The data in the tensor, represented as a two-dimensional vector.
     pub data: Tensor2D,
+    /// Gradient of the tensor.
+    pub grad: Option<Tensor2D>,
 }
diff --git a/src/linalg/tensor/computation.rs b/src/linalg/tensor/computation.rs
index 335743a..f903b2e 100644
--- a/src/linalg/tensor/computation.rs
+++ b/src/linalg/tensor/computation.rs
@@ -49,6 +49,7 @@ impl Tensor {
             rows: self.rows,
             cols: self.cols,
             data,
+            grad: self.grad.clone(),
         }
     }
 
diff --git a/src/linalg/tensor/creation.rs b/src/linalg/tensor/creation.rs
index 3213bba..0846dac 100644
--- a/src/linalg/tensor/creation.rs
+++ b/src/linalg/tensor/creation.rs
@@ -16,6 +16,7 @@ impl Tensor {
             rows,
             cols,
             data: vec![vec![0.0; cols]; rows],
+            grad: None,
         }
     }
 
@@ -64,6 +65,7 @@ impl Tensor {
             rows,
             cols,
             data: vec![vec![1.0; cols]; rows],
+            grad: None,
         }
     }
 
diff --git a/src/linalg/tensor/from.rs b/src/linalg/tensor/from.rs
index 7629d41..278446b 100644
--- a/src/linalg/tensor/from.rs
+++ b/src/linalg/tensor/from.rs
@@ -6,6 +6,18 @@ impl From<Tensor2D> for Tensor {
             rows: data.len(),
             cols: data[0].len(),
             data,
+            grad: None,
         }
     }
 }
+
+impl From<&Tensor2D> for Tensor {
+    fn from(data: &Tensor2D) -> Self {
+        Tensor {
+            rows: data.len(),
+            cols: data[0].len(),
+            data: data.to_owned(),
+            grad: None,
+        }
+    }
+}
diff --git a/src/linalg/tensor/iter.rs b/src/linalg/tensor/iter.rs
index 67f8703..b1ca7a6 100644
--- a/src/linalg/tensor/iter.rs
+++ b/src/linalg/tensor/iter.rs
@@ -42,6 +42,11 @@ impl FromIterator<Tensor1D> for Tensor {
         }
         let rows = data.len();
         let cols = data[0].len();
-        Tensor { rows, cols, data }
+        Tensor {
+            rows,
+            cols,
+            data,
+            grad: None,
+        }
     }
 }
diff --git a/src/linalg/tensor/manipulation.rs b/src/linalg/tensor/manipulation.rs
index 77c6e60..81a27c3 100644
--- a/src/linalg/tensor/manipulation.rs
+++ b/src/linalg/tensor/manipulation.rs
@@ -123,6 +123,7 @@ impl Tensor {
             rows: end - start,
             cols: self.cols,
             data: self.data[start..end].to_vec(),
+            grad: None,
         }
     }
 
diff --git a/src/linalg/tensor/state.rs b/src/linalg/tensor/state.rs
new file mode 100644
index 0000000..3d911fb
--- /dev/null
+++ b/src/linalg/tensor/state.rs
@@ -0,0 +1,42 @@
+use crate::Tensor;
+
+impl Tensor {
+    /// Sets the gradient of the tensor.
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// # use engram::*;
+    /// let mut a = tensor![[1.0, 2.0], [3.0, 4.0]];
+    /// a.set_grad(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
+    /// assert_eq!(a.grad, Some(vec![vec![1.0, 2.0], vec![3.0, 4.0]]));
+    /// ```
+    pub fn set_grad(&mut self, grad: Vec<Vec<f64>>) {
+        self.grad = Some(grad);
+    }
+
+    /// Zeros out the gradient of the tensor.
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// # use engram::*;
+    /// let mut a = tensor![[1.0, 2.0], [3.0, 4.0]];
+    /// a.zero_grad();
+    /// assert_eq!(a.grad, Some(vec![vec![0.0, 0.0], vec![0.0, 0.0]]));
+    /// ```
+    pub fn zero_grad(&mut self) {
+        match &mut self.grad {
+            Some(grad) => {
+                for row in grad.iter_mut() {
+                    for val in row.iter_mut() {
+                        *val = 0.0;
+                    }
+                }
+            }
+            None => {
+                self.grad = Some(vec![vec![0.0; self.cols]; self.rows]);
+            }
+        }
+    }
+}
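
A quick usage sketch of the gradient API this patch introduces, pieced together from the doctests above. It assumes the crate's `tensor!` macro and the `Tensor` type are exported at the crate root, as the doctests' `# use engram::*;` lines suggest; it is not part of the patch itself.

```
use engram::*;

fn main() {
    // Build a 2x2 tensor; per the new struct definition, grad starts as None.
    let mut a = tensor![[1.0, 2.0], [3.0, 4.0]];
    assert!(a.grad.is_none());

    // zero_grad on a tensor with no gradient allocates a zero-filled
    // gradient matching the tensor's rows x cols shape.
    a.zero_grad();
    assert_eq!(a.grad, Some(vec![vec![0.0, 0.0], vec![0.0, 0.0]]));

    // set_grad replaces the gradient wholesale...
    a.set_grad(vec![vec![0.5, 0.5], vec![0.5, 0.5]]);
    assert_eq!(a.grad, Some(vec![vec![0.5, 0.5], vec![0.5, 0.5]]));

    // ...and a later zero_grad clears the existing buffer in place.
    a.zero_grad();
    assert_eq!(a.grad, Some(vec![vec![0.0, 0.0], vec![0.0, 0.0]]));
}
```

Note the propagation choices visible in the diff: the computation path clones `grad` into its result, while the constructors, `From` conversions, iterator collection, and slicing all start the new tensor with `grad: None`.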