diff --git a/mergernet/estimators/parametric.py b/mergernet/estimators/parametric.py index 31680ee5..32cbacfe 100644 --- a/mergernet/estimators/parametric.py +++ b/mergernet/estimators/parametric.py @@ -149,14 +149,12 @@ def train( restore_best_weights=True ) - t1_epochs = self.hp.get('tl_epochs', default=10) - t = Timming() L.info('Start of training loop with frozen CNN') h = model.fit( ds_train, batch_size=self.hp.get('batch_size'), - epochs=t1_epochs, + epochs=self.hp.get('tl_epochs', default=10), validation_data=ds_test, class_weight=class_weights, callbacks=[early_stop_cb, wandb_metrics, wandb_graphics]