diff --git a/algorithmic_efficiency/spec.py b/algorithmic_efficiency/spec.py
index 598715696..6d1743a66 100644
--- a/algorithmic_efficiency/spec.py
+++ b/algorithmic_efficiency/spec.py
@@ -318,20 +318,18 @@ def eval_model(self,
                  imagenet_v2_data_dir: Optional[str],
                  global_step: int) -> Dict[str, float]:
     """Run a full evaluation of the model."""
-    # (nico): skip eval on test
     eval_metrics = {}
-    if False:
-      logging.info('Evaluating on the training split.')
-      train_metrics = self._eval_model_on_split(
-          split='eval_train',
-          num_examples=self.num_eval_train_examples,
-          global_batch_size=global_batch_size,
-          params=params,
-          model_state=model_state,
-          rng=rng,
-          data_dir=data_dir,
-          global_step=global_step)
-      eval_metrics = {'train/' + k: v for k, v in train_metrics.items()}
+    logging.info('Evaluating on the training split.')
+    train_metrics = self._eval_model_on_split(
+        split='eval_train',
+        num_examples=self.num_eval_train_examples,
+        global_batch_size=global_batch_size,
+        params=params,
+        model_state=model_state,
+        rng=rng,
+        data_dir=data_dir,
+        global_step=global_step)
+    eval_metrics = {'train/' + k: v for k, v in train_metrics.items()}
     # We always require a validation set.
     logging.info('Evaluating on the validation split.')
     validation_metrics = self._eval_model_on_split(
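
For context, the restored block prefixes each train-split metric key with 'train/' before it is merged into the combined eval dict. A minimal sketch of that merging pattern, assuming each _eval_model_on_split call returns a flat {metric_name: value} dict; the merge_split_metrics helper and the 'validation/' prefix are illustrative assumptions, not code from this diff:

from typing import Dict


def merge_split_metrics(train_metrics: Dict[str, float],
                        validation_metrics: Dict[str, float]
                       ) -> Dict[str, float]:
  """Namespace each split's metrics so they coexist in one flat dict."""
  # Prefix train-split metrics, mirroring the 'train/' + k pattern above.
  eval_metrics = {'train/' + k: v for k, v in train_metrics.items()}
  # Assumed prefix for the validation split, by analogy.
  eval_metrics.update(
      {'validation/' + k: v for k, v in validation_metrics.items()})
  return eval_metrics

# Example: merge_split_metrics({'accuracy': 0.71}, {'accuracy': 0.68})
# returns {'train/accuracy': 0.71, 'validation/accuracy': 0.68}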