include batch norm and pooling
nmcardoso committed Aug 23, 2023
1 parent 68330d7 commit 9d179e0
Showing 1 changed file with 40 additions and 16 deletions.
56 changes: 40 additions & 16 deletions mergernet/estimators/parametric.py
@@ -32,7 +32,9 @@ def build(self, freeze_conv: bool = False) -> tf.keras.Model:
     )
     conv_block._name = 'conv_block'
     conv_block.trainable = (not freeze_conv)
-    L.info(f'Trainable weights of convolutional block: {len(conv_block.trainable_weights)}')
+    L.info(f'Conv block weights: {len(conv_block.weights)}')
+    L.info(f'Conv block trainable weights: {len(conv_block.trainable_weights)}')
+    L.info(f'Conv block non-trainable weights: {len(conv_block.non_trainable_weights)}')

     data_aug_block = self.get_dataaug_block(
       flip_horizontal=True,
@@ -41,27 +43,49 @@ def build(self, freeze_conv: bool = False) -> tf.keras.Model:
       zoom=False
     )

+    # Input
     inputs = tf.keras.Input(shape=self.dataset.config.image_shape)
+
+    # Data Augmentation
     x = data_aug_block(inputs)
+
+    # Input pre-processing
     x = preprocess_input(x)
+
+    # Feature extractor
     x = conv_block(x)
-    x = tf.keras.layers.Flatten()(x)
-    if self.hp.get('dense_1_units'):
-      x = tf.keras.layers.Dense(self.hp.get('dense_1_units'), activation='relu')(x)
-    if self.hp.get('dropout_1_rate'):
-      x = tf.keras.layers.Dropout(self.hp.get('dropout_1_rate'))(x)
-    if self.hp.get('dense_2_units'):
-      x = tf.keras.layers.Dense(self.hp.get('dense_2_units'), activation='relu')(x)
-    if self.hp.get('dropout_2_rate'):
-      x = tf.keras.layers.Dropout(self.hp.get('dropout_2_rate'))(x)
-    if self.hp.get('dense_3_units'):
-      x = tf.keras.layers.Dense(self.hp.get('dense_3_units'), activation='relu')(x)
-    if self.hp.get('dropout_3_rate'):
-      x = tf.keras.layers.Dropout(self.hp.get('dropout_3_rate'), activation='softmax')(x)
-    outputs = tf.keras.layers.Dense(self.dataset.config.n_classes)(x)

+    # Representation layer
+    representation_mode = self.hp.get('features_reduction', default='flatten')
+    if representation_mode == 'flatten':
+      x = tf.keras.layers.Flatten()(x)
+    elif representation_mode == 'avg_pooling':
+      x = tf.keras.layers.GlobalAveragePooling2D()(x)
+    elif representation_mode == 'max_pooling':
+      x = tf.keras.layers.GlobalMaxPooling2D()(x)
+    if self.hp.get('batch_norm_0'):
+      x = tf.keras.layers.BatchNormalization()(x)
+    if self.hp.get('dropout_0_rate'):
+      x = tf.keras.layers.Dropout(self.hp.get('dropout_0_rate'))(x)
+
+    # Classifier
+    for i in range(1, 4):
+      if self.hp.get(f'dense_{i}_units'):
+        x = tf.keras.layers.Dense(self.hp.get(f'dense_{i}_units'))(x)
+      if self.hp.get(f'batch_norm_{i}'):
+        x = tf.keras.layers.BatchNormalization()(x)
+      if self.hp.get(f'activation_{i}', default='relu') == 'relu':
+        x = tf.keras.layers.Activation('relu')(x)
+      if self.hp.get(f'dropout_{i}_rate'):
+        x = tf.keras.layers.Dropout(self.hp.get(f'dropout_{i}_rate'))(x)
+
+    # Classifications
+    outputs = tf.keras.layers.Dense(self.dataset.config.n_classes, activation='softmax')(x)

     self._tf_model = tf.keras.Model(inputs, outputs)
-    L.info(f'Trainable weights (TOTAL): {len(self._tf_model.trainable_weights)}')
+    L.info(f'Final model weights: {len(self._tf_model.weights)}')
+    L.info(f'Final model trainable weights: {len(self._tf_model.trainable_weights)}')
+    L.info(f'Final model non-trainable weights: {len(self._tf_model.non_trainable_weights)}')

     return self._tf_model
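For reference, below is a minimal standalone sketch of the head this commit introduces. It is not the mergernet API: a plain dict stands in for the project's hyperparameter object (self.hp), an untrained ResNet50 stands in for the block returned by get_conv_block(), and the input shape and class count are placeholders.

import tensorflow as tf

# Example hyperparameter values; in mergernet these come from self.hp / the tuner.
hp = {
  'features_reduction': 'avg_pooling',  # 'flatten', 'avg_pooling' or 'max_pooling'
  'batch_norm_0': True,
  'dropout_0_rate': 0.3,
  'dense_1_units': 128,
  'batch_norm_1': True,
  'dropout_1_rate': 0.4,
  # dense_2_* / dense_3_* omitted: missing keys are simply skipped below
}

inputs = tf.keras.Input(shape=(128, 128, 3))
backbone = tf.keras.applications.ResNet50(include_top=False, weights=None)
x = backbone(inputs)

# Representation layer: pooling options replace the old unconditional Flatten
reduction = hp.get('features_reduction', 'flatten')
if reduction == 'flatten':
  x = tf.keras.layers.Flatten()(x)
elif reduction == 'avg_pooling':
  x = tf.keras.layers.GlobalAveragePooling2D()(x)
elif reduction == 'max_pooling':
  x = tf.keras.layers.GlobalMaxPooling2D()(x)
if hp.get('batch_norm_0'):
  x = tf.keras.layers.BatchNormalization()(x)
if hp.get('dropout_0_rate'):
  x = tf.keras.layers.Dropout(hp['dropout_0_rate'])(x)

# Classifier: up to three dense blocks; in this sketch the optional batch norm,
# activation and dropout are only added when the corresponding dense block exists
for i in range(1, 4):
  if not hp.get(f'dense_{i}_units'):
    continue
  x = tf.keras.layers.Dense(hp[f'dense_{i}_units'])(x)
  if hp.get(f'batch_norm_{i}'):
    x = tf.keras.layers.BatchNormalization()(x)
  if hp.get(f'activation_{i}', 'relu') == 'relu':
    x = tf.keras.layers.Activation('relu')(x)
  if hp.get(f'dropout_{i}_rate'):
    x = tf.keras.layers.Dropout(hp[f'dropout_{i}_rate'])(x)

# Output layer (3 classes chosen only for illustration)
outputs = tf.keras.layers.Dense(3, activation='softmax')(x)
model = tf.keras.Model(inputs, outputs)
model.summary()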

