Commit

fix(dropout): input shape for gans
marcpinet committed Dec 10, 2024
1 parent cb9db3a commit 3b457ad
Showing 1 changed file with 10 additions and 0 deletions.
neuralnetlib/layers.py (10 additions, 0 deletions)
@@ -275,6 +275,7 @@ def __init__(self,
        self.adaptive = adaptive
        self.random_state = random_state
        self.mask = None
+       self.input = None

        if adaptive:
            self.dropout_impl = AdaptiveDropout(
@@ -296,6 +297,8 @@ def __str__(self) -> str:
        return f'Dropout(rate={self.rate})'

    def forward_pass(self, input_data: np.ndarray, training: bool = True) -> np.ndarray:
+       self.input = input_data
+
        if not training:
            return input_data

@@ -311,6 +314,13 @@ def forward_pass(self, input_data: np.ndarray, training: bool = True) -> np.ndarray:
    def backward_pass(self, output_error: np.ndarray) -> np.ndarray:
        if self.adaptive:
            return self.dropout_impl.gradient(output_error)
+
+       if output_error.shape[0] != self.mask.shape[0]:
+           rng = np.random.default_rng(
+               self.random_state if self.random_state is not None else int(time.time_ns()))
+           self.mask = rng.binomial(1, 1 - self.rate,
+                                    size=(output_error.shape[0], self.mask.shape[1])) / (1 - self.rate)
+
        return output_error * self.mask

    def get_config(self) -> dict:
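In summary, the change caches the layer's input and, in backward_pass, regenerates the dropout mask whenever the incoming gradient's batch dimension no longer matches the cached mask's, which can happen in GAN training loops where the dropout layer is run forward and backward with different batch sizes. Below is a minimal, self-contained sketch of that logic, assuming plain inverted dropout on 2-D (batch, features) inputs; TinyDropout and the toy shapes are hypothetical illustrations, not neuralnetlib's actual Dropout class (which also has an adaptive path and other configuration).

# Hypothetical TinyDropout: a stripped-down sketch of the mask-regeneration
# logic added in this commit, not the library's real Dropout layer.
import time

import numpy as np


class TinyDropout:
    def __init__(self, rate=0.5, random_state=None):
        self.rate = rate
        self.random_state = random_state
        self.mask = None
        self.input = None

    def forward_pass(self, input_data, training=True):
        self.input = input_data
        if not training:
            return input_data
        rng = np.random.default_rng(
            self.random_state if self.random_state is not None else int(time.time_ns()))
        # Inverted dropout: keep units with probability (1 - rate) and rescale
        # so the expected activation is unchanged.
        self.mask = rng.binomial(1, 1 - self.rate, size=input_data.shape) / (1 - self.rate)
        return input_data * self.mask

    def backward_pass(self, output_error):
        # If the gradient arrives with a different batch size than the cached mask
        # (e.g. a GAN step where forward and backward batches differ), regenerate
        # the mask with a matching leading dimension instead of failing to broadcast.
        if output_error.shape[0] != self.mask.shape[0]:
            rng = np.random.default_rng(
                self.random_state if self.random_state is not None else int(time.time_ns()))
            self.mask = rng.binomial(
                1, 1 - self.rate,
                size=(output_error.shape[0], self.mask.shape[1])) / (1 - self.rate)
        return output_error * self.mask


layer = TinyDropout(rate=0.3, random_state=42)
out = layer.forward_pass(np.ones((64, 16)))    # forward with batch size 64
grad = layer.backward_pass(np.ones((32, 16)))  # backward with batch size 32
print(grad.shape)                              # (32, 16); in the sketch, this line raised a broadcast error before the check

Note that when the batch sizes differ, the regenerated mask is not the one applied in the forward pass, so the backward signal is only an approximation in that case; the fix trades exactness for keeping mismatched-batch GAN updates from erroring out.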
