From e86cde96c6f39f6c321fee0f45dd4469168a4f48 Mon Sep 17 00:00:00 2001
From: Niccolo-Ajroldi
Date: Fri, 20 Dec 2024 11:46:23 +0100
Subject: [PATCH] restore eval on training split

---
 algorithmic_efficiency/spec.py | 24 +++++++++++-------------
 1 file changed, 11 insertions(+), 13 deletions(-)

diff --git a/algorithmic_efficiency/spec.py b/algorithmic_efficiency/spec.py
index 598715696..6d1743a66 100644
--- a/algorithmic_efficiency/spec.py
+++ b/algorithmic_efficiency/spec.py
@@ -318,20 +318,18 @@ def eval_model(self,
                  imagenet_v2_data_dir: Optional[str],
                  global_step: int) -> Dict[str, float]:
     """Run a full evaluation of the model."""
-    # (nico): skip eval on test
     eval_metrics = {}
-    if False:
-      logging.info('Evaluating on the training split.')
-      train_metrics = self._eval_model_on_split(
-          split='eval_train',
-          num_examples=self.num_eval_train_examples,
-          global_batch_size=global_batch_size,
-          params=params,
-          model_state=model_state,
-          rng=rng,
-          data_dir=data_dir,
-          global_step=global_step)
-      eval_metrics = {'train/' + k: v for k, v in train_metrics.items()}
+    logging.info('Evaluating on the training split.')
+    train_metrics = self._eval_model_on_split(
+        split='eval_train',
+        num_examples=self.num_eval_train_examples,
+        global_batch_size=global_batch_size,
+        params=params,
+        model_state=model_state,
+        rng=rng,
+        data_dir=data_dir,
+        global_step=global_step)
+    eval_metrics = {'train/' + k: v for k, v in train_metrics.items()}
     # We always require a validation set.
     logging.info('Evaluating on the validation split.')
     validation_metrics = self._eval_model_on_split(