From 439c765813dfbe09e154c3fc267229cfa6838781 Mon Sep 17 00:00:00 2001 From: Chan Lee Date: Thu, 18 Apr 2024 04:39:00 +0900 Subject: [PATCH] add flatten layer --- luma/neural/layer.py | 31 +++++++++++++++++++++++++------ 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/luma/neural/layer.py b/luma/neural/layer.py index 91cbf02..5580dc7 100644 --- a/luma/neural/layer.py +++ b/luma/neural/layer.py @@ -7,7 +7,7 @@ from luma.neural.optimizer import SGDOptimizer -__all__ = ("Convolution", "Pooling", "Dense", "Dropout") +__all__ = ("Convolution", "Pooling", "Dense", "Dropout", "Flatten") class Convolution(Layer): @@ -305,14 +305,13 @@ def __init__( def forward(self, X: Tensor | Matrix) -> Tensor: self.input_ = X - X = self._flatten(X) out = np.dot(X, self.weights_) + self.biases_ out = self.act_.func(out) return out def backward(self, d_out: Tensor, reshape: bool = True) -> Tensor: - X = self._flatten(self.input_) + X = self.input_ d_out = self.act_.derivative(d_out) self.dX = np.dot(d_out, self.weights_.T) @@ -326,9 +325,6 @@ def backward(self, d_out: Tensor, reshape: bool = True) -> Tensor: else: return self.dX - def _flatten(self, X: Tensor) -> Matrix: - return X.reshape(X.shape[0], -1) if len(X.shape) > 2 else X class Dropout(Layer): """ @@ -370,3 +366,26 @@ def forward(self, X: Tensor, is_train: bool = False) -> Tensor: def backward(self, d_out: Tensor) -> Tensor: dX = d_out * self.mask_ if self.mask_ is not None else d_out return dX + + + class Flatten(Layer): + """ + A flatten layer reshapes the input tensor into a 2D array (`Matrix`), + collapsing all dimensions except the batch dimension. + + Notes + ----- + - Use this class when using a `Dense` layer. + Flatten the tensor into a matrix in order to feed-forward dense layer(s). 
+ """ + + def __init__(self) -> None: + super().__init__() + + def forward(self, X: Tensor) -> Matrix: + self.input_ = X + return X.reshape(X.shape[0], -1) + + def backward(self, d_out: Matrix) -> Tensor: + dX = d_out.reshape(self.input_.shape) + return dX