Skip to content

Commit

Permalink
ported submission_runner.py from a100, remove eval on train/test add …
Browse files Browse the repository at this point in the history
…script
  • Loading branch information
Niccolo-Ajroldi committed Nov 17, 2024
1 parent bf3ffb0 commit 76ae00f
Showing 1 changed file with 121 additions and 131 deletions.
252 changes: 121 additions & 131 deletions submission_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,11 @@
import datetime
import gc
import importlib
from inspect import signature
import itertools
import json
import os
import struct
import time
from types import MappingProxyType
from typing import Any, Dict, Optional, Tuple

from absl import app
Expand Down Expand Up @@ -135,6 +133,9 @@
flags.DEFINE_integer('max_global_steps',
None,
'Maximum number of update steps.')
flags.DEFINE_float('max_pct_of_global_steps',
0.0,
'Maximum number of update steps.')
flags.DEFINE_boolean(
'overwrite',
False,
Expand Down Expand Up @@ -162,6 +163,14 @@
'Number of workers for ImageNet PyTorch evaluation data loaders.'
'WARNING: Setting pytorch_eval_num_workers != 0, will result '
'in incorrect evals currently, see issues/732.')
flags.DEFINE_boolean(
'halve_CUDA_mem',
False,
'Halve the available VRAM.')
flags.DEFINE_boolean(
'allow_tf32',
False,
'Allow TF32 on Ampere.')
FLAGS = flags.FLAGS
USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup()

Expand Down Expand Up @@ -202,12 +211,12 @@ def train_once(
init_optimizer_state: spec.InitOptimizerFn,
update_params: spec.UpdateParamsFn,
data_selection: spec.DataSelectionFn,
prepare_for_eval: Optional[spec.PrepareForEvalFn],
hyperparameters: Optional[spec.Hyperparameters],
rng_seed: int,
rng: spec.RandomState,
profiler: Profiler,
max_global_steps: int = None,
max_pct_of_global_steps: float = None,
log_dir: Optional[str] = None,
save_checkpoints: Optional[bool] = True
) -> Tuple[spec.Timing, Dict[str, Any]]:
Expand Down Expand Up @@ -276,10 +285,6 @@ def train_once(
hyperparameters,
opt_init_rng)
logging.info('Initializing metrics bundle.')

# Check if 'train_state' is in the function signature
needs_train_state = 'train_state' in signature(update_params).parameters

# Bookkeeping.
train_state = {
'validation_goal_reached': False,
Expand Down Expand Up @@ -342,8 +347,7 @@ def train_once(
not train_state['training_complete']:

step_rng = prng.fold_in(rng, global_step)
data_select_rng, update_rng, prep_eval_rng, eval_rng = \
prng.split(step_rng, 4)
data_select_rng, update_rng, eval_rng = prng.split(step_rng, 3)

with profiler.profile('Data selection'):
batch = data_selection(workload,
Expand All @@ -367,143 +371,115 @@ def train_once(
optimizer_state=optimizer_state,
eval_results=eval_results,
global_step=global_step,
rng=update_rng,
**({'train_state': MappingProxyType(train_state)}
if needs_train_state else {}))
rng=update_rng)
except spec.TrainingCompleteError:
train_state['training_complete'] = True
global_step += 1
if (max_global_steps is not None) and (global_step == max_global_steps):
train_state['training_complete'] = True
# (nico): train for a fixed pct of step_hint
if (max_pct_of_global_steps is not None) and \
(global_step / workload.step_hint >= max_pct_of_global_steps):
train_state['training_complete'] = True

train_step_end_time = get_time()

train_state['accumulated_submission_time'] += (
train_step_end_time - train_state['last_step_end_time'])

# Use 3x the runtime budget for the self-tuning ruleset.
max_allowed_runtime_sec = (
workload.max_allowed_runtime_sec if FLAGS.tuning_ruleset == 'external'
else 3 * workload.max_allowed_runtime_sec)
train_state['is_time_remaining'] = (
train_state['accumulated_submission_time'] < max_allowed_runtime_sec)
# Check if submission is eligible for an untimed eval.
if ((train_step_end_time - train_state['last_eval_time']) >=
workload.eval_period_time_sec or train_state['training_complete']):

# Prepare for evaluation (timed).
if prepare_for_eval is not None:

with profiler.profile('Prepare for eval'):
del batch
prepare_for_eval_start_time = get_time()
optimizer_state, model_params, model_state = prepare_for_eval(
workload=workload,
current_param_container=model_params,
current_params_types=workload.model_params_types,
model_state=model_state,
hyperparameters=hyperparameters,
loss_type=workload.loss_type,
optimizer_state=optimizer_state,
eval_results=eval_results,
global_step=global_step,
rng=prep_eval_rng)
prepare_for_eval_end_time = get_time()

# Update sumbission time.
train_state['accumulated_submission_time'] += (
prepare_for_eval_end_time - prepare_for_eval_start_time)

# Check if time is remaining,
# use 3x the runtime budget for the self-tuning ruleset.
max_allowed_runtime_sec = (
workload.max_allowed_runtime_sec if FLAGS.tuning_ruleset == 'external'
else 3 * workload.max_allowed_runtime_sec)
train_state['is_time_remaining'] = (
train_state['accumulated_submission_time'] < max_allowed_runtime_sec)

# Eval if time is remaining (untimed).
if train_state['is_time_remaining']:

with profiler.profile('Evaluation'):
_reset_cuda_mem()

try:
eval_start_time = get_time()
latest_eval_result = workload.eval_model(global_eval_batch_size,
model_params,
model_state,
eval_rng,
data_dir,
imagenet_v2_data_dir,
global_step)
# Check if targets reached.
# Note that this is one of the stopping conditions for the length of
# a training run. To score the run we only consider the time
# to validation target retrospectively.
train_state['validation_goal_reached'] = (
workload.has_reached_validation_target(latest_eval_result) or
train_state['validation_goal_reached'])
train_state['test_goal_reached'] = (
workload.has_reached_test_target(latest_eval_result) or
train_state['test_goal_reached'])
goals_reached = (
train_state['validation_goal_reached'] and
train_state['test_goal_reached'])
# Save last eval time.
eval_end_time = get_time()
train_state['last_eval_time'] = eval_end_time

# Accumulate eval time.
train_state[
'accumulated_eval_time'] += eval_end_time - eval_start_time

# Add times to eval results for logging.
latest_eval_result['score'] = (
train_state['accumulated_submission_time'])
latest_eval_result[
'total_duration'] = eval_end_time - global_start_time
latest_eval_result['accumulated_submission_time'] = train_state[
'accumulated_submission_time']
latest_eval_result['accumulated_eval_time'] = train_state[
'accumulated_eval_time']
latest_eval_result['accumulated_logging_time'] = train_state[
'accumulated_logging_time']
time_since_start = latest_eval_result['total_duration']
logging.info(f'Time since start: {time_since_start:.2f}s, '
f'\tStep: {global_step}, \t{latest_eval_result}')
eval_results.append((global_step, latest_eval_result))

logging_start_time = get_time()

if log_dir is not None and RANK == 0:
metrics_logger.append_scalar_metrics(
latest_eval_result,
with profiler.profile('Evaluation'):
del batch
_reset_cuda_mem()

try:
eval_start_time = get_time()
latest_eval_result = workload.eval_model(global_eval_batch_size,
model_params,
model_state,
eval_rng,
data_dir,
imagenet_v2_data_dir,
global_step)
# Check if targets reached.
# Note that this is one of the stopping conditions for the length of
# a training run. To score the run we only consider the time
# to validation target retrospectively.
train_state['validation_goal_reached'] = (
workload.has_reached_validation_target(latest_eval_result) or
train_state['validation_goal_reached'])
train_state['test_goal_reached'] = (
workload.has_reached_test_target(latest_eval_result) or
train_state['test_goal_reached'])
goals_reached = (
train_state['validation_goal_reached'] and
train_state['test_goal_reached'])
# Save last eval time.
eval_end_time = get_time()
train_state['last_eval_time'] = eval_end_time

# Accumulate eval time.
train_state[
'accumulated_eval_time'] += eval_end_time - eval_start_time

# Add times to eval results for logging.
latest_eval_result['score'] = (
train_state['accumulated_submission_time'])
latest_eval_result[
'total_duration'] = eval_end_time - global_start_time
latest_eval_result['accumulated_submission_time'] = train_state[
'accumulated_submission_time']
latest_eval_result['accumulated_eval_time'] = train_state[
'accumulated_eval_time']
latest_eval_result['accumulated_logging_time'] = train_state[
'accumulated_logging_time']
time_since_start = latest_eval_result['total_duration']
logging.info(f'Time since start: {time_since_start:.2f}s, '
f'\tStep: {global_step}, \t{latest_eval_result}')
eval_results.append((global_step, latest_eval_result))

logging_start_time = get_time()

if log_dir is not None and RANK == 0:
metrics_logger.append_scalar_metrics(
latest_eval_result,
global_step=global_step,
preemption_count=preemption_count,
is_eval=True,
)
if save_checkpoints:
checkpoint_utils.save_checkpoint(
framework=FLAGS.framework,
optimizer_state=optimizer_state,
model_params=model_params,
model_state=model_state,
train_state=train_state,
eval_results=eval_results,
global_step=global_step,
preemption_count=preemption_count,
is_eval=True,
)
if save_checkpoints:
checkpoint_utils.save_checkpoint(
framework=FLAGS.framework,
optimizer_state=optimizer_state,
model_params=model_params,
model_state=model_state,
train_state=train_state,
eval_results=eval_results,
global_step=global_step,
preemption_count=preemption_count,
checkpoint_dir=log_dir,
save_intermediate_checkpoints=FLAGS
.save_intermediate_checkpoints)

logging_end_time = get_time()
train_state['accumulated_logging_time'] += (
logging_end_time - logging_start_time)
checkpoint_dir=log_dir,
save_intermediate_checkpoints=FLAGS
.save_intermediate_checkpoints)

_reset_cuda_mem()
logging_end_time = get_time()
train_state['accumulated_logging_time'] += (
logging_end_time - logging_start_time)

except RuntimeError as e:
logging.exception(f'Eval step {global_step} error.\n')
if 'out of memory' in str(e):
logging.warning(
'Error: GPU out of memory during eval during step '
f'{global_step}, error : {str(e)}.')
_reset_cuda_mem()
_reset_cuda_mem()

except RuntimeError as e:
logging.exception(f'Eval step {global_step} error.\n')
if 'out of memory' in str(e):
logging.warning('Error: GPU out of memory during eval during step '
f'{global_step}, error : {str(e)}.')
_reset_cuda_mem()

train_state['last_step_end_time'] = get_time()

Expand Down Expand Up @@ -538,6 +514,7 @@ def score_submission_on_workload(workload: spec.Workload,
tuning_ruleset: str,
profiler: Optional[Profiler] = None,
max_global_steps: Optional[int] = None,
max_pct_of_global_steps: Optional[float] = None,
imagenet_v2_data_dir: Optional[str] = None,
tuning_search_space: Optional[str] = None,
num_tuning_trials: Optional[int] = None,
Expand All @@ -558,7 +535,6 @@ def score_submission_on_workload(workload: spec.Workload,
init_optimizer_state = submission_module.init_optimizer_state
update_params = submission_module.update_params
data_selection = submission_module.data_selection
prepare_for_eval = getattr(submission_module, 'prepare_for_eval', None)
try:
global_batch_size = submission_module.get_batch_size(workload_name)
except ValueError:
Expand Down Expand Up @@ -631,12 +607,12 @@ def score_submission_on_workload(workload: spec.Workload,
data_dir, imagenet_v2_data_dir,
init_optimizer_state,
update_params, data_selection,
prepare_for_eval,
hyperparameters,
rng_seed,
rng,
profiler,
max_global_steps,
max_pct_of_global_steps,
tuning_dir_name,
save_checkpoints=save_checkpoints,)
all_timings[hi] = timing
Expand Down Expand Up @@ -673,6 +649,19 @@ def score_submission_on_workload(workload: spec.Workload,


def main(_):

if FLAGS.framework == 'pytorch':

if FLAGS.halve_CUDA_mem:
torch.cuda.set_per_process_memory_fraction(0.5, device=DEVICE)

if FLAGS.allow_tf32:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
else:
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

if FLAGS.profile:
profiler = Profiler()
else:
Expand Down Expand Up @@ -729,6 +718,7 @@ def main(_):
tuning_ruleset=FLAGS.tuning_ruleset,
profiler=profiler,
max_global_steps=FLAGS.max_global_steps,
max_pct_of_global_steps=FLAGS.max_pct_of_global_steps,
imagenet_v2_data_dir=FLAGS.imagenet_v2_data_dir,
tuning_search_space=FLAGS.tuning_search_space,
num_tuning_trials=FLAGS.num_tuning_trials,
Expand Down

0 comments on commit 76ae00f

Please sign in to comment.