add flatten layer
ChanLumerico committed Apr 17, 2024
1 parent 73f60c6 commit 439c765
Showing 1 changed file with 25 additions and 6 deletions.
31 changes: 25 additions & 6 deletions luma/neural/layer.py
@@ -7,7 +7,7 @@
 from luma.neural.optimizer import SGDOptimizer
 
 
-__all__ = ("Convolution", "Pooling", "Dense", "Dropout")
+__all__ = ("Convolution", "Pooling", "Dense", "Dropout", "Flatten")
 
 
 class Convolution(Layer):
@@ -305,14 +305,13 @@ def __init__(
 
     def forward(self, X: Tensor | Matrix) -> Tensor:
         self.input_ = X
-        X = self._flatten(X)
 
         out = np.dot(X, self.weights_) + self.biases_
         out = self.act_.func(out)
         return out
 
     def backward(self, d_out: Tensor, reshape: bool = True) -> Tensor:
-        X = self._flatten(self.input_)
+        X = self.input_
         d_out = self.act_.derivative(d_out)
 
         self.dX = np.dot(d_out, self.weights_.T)
@@ -326,9 +325,6 @@ def backward(self, d_out: Tensor, reshape: bool = True) -> Tensor:
         else:
             return self.dX
 
-    def _flatten(self, X: Tensor) -> Matrix:
-        return X.reshape(X.shape[0], -1) if len(X.shape) > 2 else X
-
 
 class Dropout(Layer):
     """
@@ -370,3 +366,26 @@ def forward(self, X: Tensor, is_train: bool = False) -> Tensor:
     def backward(self, d_out: Tensor) -> Tensor:
         dX = d_out * self.mask_ if self.mask_ is not None else d_out
         return dX
+
+
+class Flatten(Layer):
+    """
+    A flatten layer reshapes the input tensor into a 2D array (`Matrix`),
+    collapsing all dimensions except the batch dimension.
+
+    Notes
+    -----
+    - Use this layer before a `Dense` layer;
+      flatten the tensor into a matrix so it can be fed forward to dense layer(s).
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def forward(self, X: Tensor) -> Matrix:
+        self.input_ = X
+        return X.reshape(X.shape[0], -1)
+
+    def backward(self, d_out: Matrix) -> Tensor:
+        dX = d_out.reshape(self.input_.shape)
+        return dX
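
For reference, a minimal standalone sketch of the reshape round-trip the new Flatten layer performs. It uses plain NumPy arrays instead of luma's Layer/Tensor/Matrix types, and the input shapes are illustrative assumptions, not part of this commit:

import numpy as np

# Example input: a batch of 32 feature maps of shape (8, 6, 6).
X = np.random.rand(32, 8, 6, 6)

# Flatten.forward: collapse every dimension except the batch dimension.
flat = X.reshape(X.shape[0], -1)
print(flat.shape)   # (32, 288) -- a Matrix ready for a Dense layer

# Flatten.backward: restore the incoming gradient to the original tensor shape.
d_out = np.ones_like(flat)
dX = d_out.reshape(X.shape)
print(dX.shape)     # (32, 8, 6, 6)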
