Skip to content

Commit

Permalink
Merge pull request #31 from marcpinet/refactor-better-compat-handling
Browse files Browse the repository at this point in the history
Refactor better compat handling
  • Loading branch information
marcpinet authored Apr 24, 2024
2 parents 521d3c7 + d77dfdc commit 4df5d60
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 35 deletions.
20 changes: 20 additions & 0 deletions neuralnetlib/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -963,3 +963,23 @@ def _pool_backward(output_error: np.ndarray, input_data: np.ndarray, pool_size:
d_input = d_input[:, pad_steps:-pad_steps, :]

return d_input


# --------------------------------------------------------------------------------------------------------------


# Maps each layer class to the list of layer classes that may legally follow it
# when layers are appended to a model. Consumed by Model.add (model.py), which
# checks `type(layer) in compatibility_dict[type(previous_layer)]` and raises
# ValueError on an invalid sequence. Replaces a long isinstance if/elif chain.
# NOTE(review): no key means "no layer may follow" — Model.add would raise
# KeyError for a previous layer type missing from this table; confirm every
# addable layer class is listed as a key.
compatibility_dict = {
    Input: [Dense, Conv2D, Conv1D, Embedding],
    Dense: [Dense, Activation, Dropout, BatchNormalization],
    Activation: [Dense, Conv2D, Conv1D, MaxPooling2D, AveragePooling2D, MaxPooling1D, AveragePooling1D, Flatten, Dropout],
    Conv2D: [Conv2D, MaxPooling2D, AveragePooling2D, Activation, Dropout, Flatten, BatchNormalization],
    MaxPooling2D: [Conv2D, MaxPooling2D, AveragePooling2D, Flatten],
    AveragePooling2D: [Conv2D, MaxPooling2D, AveragePooling2D, Flatten],
    Conv1D: [Conv1D, MaxPooling1D, AveragePooling1D, Activation, Dropout, Flatten, BatchNormalization],
    MaxPooling1D: [Conv1D, MaxPooling1D, AveragePooling1D, Flatten],
    AveragePooling1D: [Conv1D, MaxPooling1D, AveragePooling1D, Flatten],
    Flatten: [Dense, Dropout],
    Dropout: [Dense, Conv2D, Conv1D, Activation],
    Embedding: [Conv1D, Flatten, Dense],
    BatchNormalization: [Dense, Conv2D, Conv1D, Activation]
}
37 changes: 3 additions & 34 deletions neuralnetlib/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import numpy as np

from neuralnetlib.layers import Layer, Input, Dense, Activation, Conv2D, MaxPooling2D, Conv1D, MaxPooling1D, AveragePooling2D, AveragePooling1D, Flatten, Dropout, Embedding, BatchNormalization
from neuralnetlib.layers import Layer, Input, Activation, Dropout, compatibility_dict
from neuralnetlib.losses import LossFunction, CategoricalCrossentropy
from neuralnetlib.optimizers import Optimizer
from neuralnetlib.utils import shuffle, progress_bar
Expand Down Expand Up @@ -38,39 +38,8 @@ def add(self, layer: Layer):
raise ValueError("The first layer must be an Input layer.")
else:
previous_layer = self.layers[-1]
if isinstance(previous_layer, Input):
if not isinstance(layer, (Dense, Conv2D, Conv1D, Embedding)):
raise ValueError("Input layer can only be followed by Dense, Conv2D, Conv1D, or Embedding.")
elif isinstance(previous_layer, Dense):
if not isinstance(layer, (Dense, Activation, Dropout, BatchNormalization)):
raise ValueError("Dense layer can only be followed by Dense, Activation, BatchNormalization, or Dropout.")
elif isinstance(previous_layer, Activation):
if not isinstance(layer, (Dense, Conv2D, Conv1D, MaxPooling2D, AveragePooling2D, MaxPooling1D, AveragePooling1D, Flatten, Dropout)):
raise ValueError("Activation layer can only be followed by Dense, Conv2D, Conv1D, MaxPooling2D, AveragePooling2D, MaxPooling1D, AveragePooling1D, Flatten, or Dropout.")
elif isinstance(previous_layer, Conv2D):
if not isinstance(layer, (Conv2D, MaxPooling2D, AveragePooling2D, Activation, Dropout, Flatten, BatchNormalization)):
raise ValueError("Conv2D layer can only be followed by Conv2D, MaxPooling2D, AveragePooling2D, Activation, Dropout, BatchNormalization, or Flatten.")
elif isinstance(previous_layer, MaxPooling2D) or isinstance(previous_layer, AveragePooling2D):
if not isinstance(layer, (Conv2D, MaxPooling2D, AveragePooling2D, Flatten)):
raise ValueError("MaxPooling2D or AveragePooling2D layer can only be followed by Conv2D, MaxPooling2D, AveragePooling2D, or Flatten.")
elif isinstance(previous_layer, Conv1D):
if not isinstance(layer, (Conv1D, MaxPooling1D, AveragePooling1D, Activation, Dropout, Flatten, BatchNormalization)):
raise ValueError("Conv1D layer can only be followed by Conv1D, MaxPooling1D, AveragePooling1D, Activation, Dropout, BatchNormalization, or Flatten.")
elif isinstance(previous_layer, MaxPooling1D) or isinstance(previous_layer, AveragePooling1D):
if not isinstance(layer, (Conv1D, MaxPooling1D, AveragePooling1D, Flatten)):
raise ValueError("MaxPooling1D or AveragePooling1D layer can only be followed by Conv1D, MaxPooling1D, AveragePooling1D, or Flatten.")
elif isinstance(previous_layer, Flatten):
if not isinstance(layer, (Dense, Dropout)):
raise ValueError("Flatten layer can only be followed by Dense or Dropout.")
elif isinstance(previous_layer, Dropout):
if not isinstance(layer, (Dense, Conv2D, Conv1D, Activation)):
raise ValueError("Dropout layer can only be followed by Dense, Conv2D, Conv1D, or Activation.")
elif isinstance(previous_layer, Embedding):
if not isinstance(layer, (Conv1D, Flatten, Dense)):
raise ValueError("Embedding layer can only be followed by Conv1D, Flatten, or Dense.")
elif isinstance(previous_layer, BatchNormalization):
if not isinstance(layer, (Dense, Conv2D, Conv1D, Activation)):
raise ValueError("BatchNormalization layer can only be followed by Dense, Conv2D, Conv1D, or Activation.")
if type(layer) not in compatibility_dict[type(previous_layer)]:
raise ValueError(f"{type(layer).__name__} layer cannot follow {type(previous_layer).__name__} layer.")

self.layers.append(layer)

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

setup(
name='neuralnetlib',
version='2.3.1',
version='2.3.2',
author='Marc Pinet',
description='A simple convolutional neural network library with only numpy as dependency',
long_description=open('README.md', encoding="utf-8").read(),
Expand Down

0 comments on commit 4df5d60

Please sign in to comment.