feat(tensor): add grad field, create <set|zero>_grad methods
drewxs committed Mar 10, 2024
1 parent bfbc6e3 commit 876bb81
Showing 7 changed files with 73 additions and 7 deletions.
15 changes: 9 additions & 6 deletions src/linalg/tensor.rs
@@ -14,21 +14,24 @@ mod macros;
mod manipulation;
mod ops;
mod sparse;
+mod state;
mod transformation;
mod validation;

-/// A one-dimensional matrix of floating point values.
+/// A one-dimensional tensor of floating point values.
pub type Tensor1D = Vec<f64>;
-/// A two-dimensional matrix of floating point values.
+/// A two-dimensional tensor of floating point values.
pub type Tensor2D = Vec<Tensor1D>;

-/// A matrix of floating point values, represented as a two-dimensional vector.
+/// A tensor of floating point values.
#[derive(Clone, Debug)]
pub struct Tensor {
-/// The number of rows in the matrix.
+/// The number of rows in the tensor.
pub rows: usize,
-/// The number of columns in the matrix.
+/// The number of columns in the tensor.
pub cols: usize,
-/// The data in the matrix, represented as a two-dimensional vector.
+/// The data in the tensor, represented as a two-dimensional vector.
pub data: Tensor2D,
+/// Gradient of the tensor.
+pub grad: Option<Tensor2D>,
}
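With the new field in place, every construction site has to supply `grad`, and it stays `None` until a gradient is attached. A minimal sketch of what this looks like from the caller's side, using the `From<Tensor2D>` conversion shown further down; the values are illustrative and not part of the commit:

```rust
use engram::Tensor;

fn main() {
    // Conversions and constructors now leave the gradient unset.
    let t = Tensor::from(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
    assert_eq!((t.rows, t.cols), (2, 2));
    assert!(t.grad.is_none());
}
```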
1 change: 1 addition & 0 deletions src/linalg/tensor/computation.rs
@@ -49,6 +49,7 @@ impl Tensor {
rows: self.rows,
cols: self.cols,
data,
+grad: self.grad.clone(),
}
}

2 changes: 2 additions & 0 deletions src/linalg/tensor/creation.rs
@@ -16,6 +16,7 @@ impl Tensor {
rows,
cols,
data: vec![vec![0.0; cols]; rows],
+grad: None,
}
}

@@ -64,6 +65,7 @@ impl Tensor {
rows,
cols,
data: vec![vec![1.0; cols]; rows],
+grad: None,
}
}

12 changes: 12 additions & 0 deletions src/linalg/tensor/from.rs
@@ -6,6 +6,18 @@ impl From<Tensor2D> for Tensor {
rows: data.len(),
cols: data[0].len(),
data,
+grad: None,
}
}
}

+impl From<&Tensor2D> for Tensor {
+fn from(data: &Tensor2D) -> Self {
+Tensor {
+rows: data.len(),
+cols: data[0].len(),
+data: data.to_owned(),
+grad: None,
+}
+}
+}
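The borrowing conversion is a convenience so callers don't have to give up ownership of the nested `Vec`. A brief usage sketch; the data values are illustrative:

```rust
use engram::Tensor;

fn main() {
    let data = vec![vec![1.0, 2.0], vec![3.0, 4.0]];

    // Converts by cloning, so `data` stays usable afterwards; grad starts unset.
    let t = Tensor::from(&data);
    assert_eq!(t.data, data);
    assert!(t.grad.is_none());
}
```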
7 changes: 6 additions & 1 deletion src/linalg/tensor/iter.rs
@@ -42,6 +42,11 @@ impl FromIterator<Tensor1D> for Tensor {
}
let rows = data.len();
let cols = data[0].len();
-Tensor { rows, cols, data }
+Tensor {
+rows,
+cols,
+data,
+grad: None,
+}
}
}
1 change: 1 addition & 0 deletions src/linalg/tensor/manipulation.rs
@@ -123,6 +123,7 @@ impl Tensor {
rows: end - start,
cols: self.cols,
data: self.data[start..end].to_vec(),
+grad: None,
}
}
}
42 changes: 42 additions & 0 deletions src/linalg/tensor/state.rs
@@ -0,0 +1,42 @@
use crate::Tensor;

impl Tensor {
/// Sets the gradient of the tensor.
///
/// # Examples
///
/// ```
/// # use engram::*;
/// let mut a = tensor![[1.0, 2.0], [3.0, 4.0]];
/// a.set_grad(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
/// assert_eq!(a.grad, Some(vec![vec![1.0, 2.0], vec![3.0, 4.0]]));
/// ```
pub fn set_grad(&mut self, grad: Vec<Vec<f64>>) {
self.grad = Some(grad);
}

/// Zeros out the gradient of the tensor.
///
/// # Examples
///
/// ```
/// # use engram::*;
/// let mut a = tensor![[1.0, 2.0], [3.0, 4.0]];
/// a.zero_grad();
/// assert_eq!(a.grad, Some(vec![vec![0.0, 0.0], vec![0.0, 0.0]]));
/// ```
pub fn zero_grad(&mut self) {
match &mut self.grad {
Some(grad) => {
for row in grad.iter_mut() {
for val in row.iter_mut() {
*val = 0.0;
}
}
}
None => {
self.grad = Some(vec![vec![0.0; self.cols]; self.rows]);
}
}
}
}
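Taken together, the two methods give a small end-to-end flow: `zero_grad` allocates a zeroed gradient when none exists, and resets the values in place when one does. A short usage sketch reusing the `tensor!` macro from the doc tests above; the values are illustrative:

```rust
use engram::*;

fn main() {
    let mut a = tensor![[1.0, 2.0], [3.0, 4.0]];

    // With no gradient set, zero_grad allocates a zeroed one matching the shape.
    a.zero_grad();
    assert_eq!(a.grad, Some(vec![vec![0.0, 0.0], vec![0.0, 0.0]]));

    // set_grad replaces it; calling zero_grad again resets the values in place.
    a.set_grad(vec![vec![0.5, 0.5], vec![0.5, 0.5]]);
    a.zero_grad();
    assert_eq!(a.grad, Some(vec![vec![0.0, 0.0], vec![0.0, 0.0]]));
}
```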
