Skip to content

Commit

Permalink
Merge pull request #27 from marcpinet/refactor-dropout-seed
Browse files Browse the repository at this point in the history
refactor: add seed to dropout too
  • Loading branch information
marcpinet authored Apr 21, 2024
2 parents d645d63 + 7514a3c commit d64b130
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 56 deletions.
3 changes: 3 additions & 0 deletions examples/cnn-classification/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
matplotlib
numpy
tensorflow
98 changes: 47 additions & 51 deletions examples/cnn-classification/simple_cnn_classification_mnist.ipynb

Large diffs are not rendered by default.

11 changes: 7 additions & 4 deletions neuralnetlib/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,16 +180,18 @@ def from_config(config: dict):


class Dropout(Layer):
def __init__(self, rate: float):
def __init__(self, rate: float, seed: int = None):
self.rate = rate
self.mask = None
self.seed = seed

def __str__(self):
return f'Dropout(rate={self.rate})'

def forward_pass(self, input_data: np.ndarray, training: bool = True) -> np.ndarray:
if training:
self.mask = np.random.binomial(1, 1 - self.rate, size=input_data.shape) / (1 - self.rate)
rng = np.random.default_rng(self.seed)
self.mask = rng.binomial(1, 1 - self.rate, size=input_data.shape) / (1 - self.rate)
return input_data * self.mask
else:
return input_data
Expand All @@ -200,12 +202,13 @@ def backward_pass(self, output_error: np.ndarray) -> np.ndarray:
def get_config(self) -> dict:
return {
'name': self.__class__.__name__,
'rate': self.rate
'rate': self.rate,
'seed': self.seed
}

@staticmethod
def from_config(config: dict):
return Dropout(config['rate'])
return Dropout(config['rate'], config['seed'])


class Conv2D(Layer):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

setup(
name='neuralnetlib',
version='2.2.0',
version='2.2.1',
author='Marc Pinet',
description='A simple convolutional neural network library with only numpy as dependency',
long_description=open('README.md', encoding="utf-8").read(),
Expand Down

0 comments on commit d64b130

Please sign in to comment.