diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py index 6c859b8dd..9ed09a615 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py +++ b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py @@ -180,7 +180,7 @@ def init_optimizer_state(workload: spec.Workload, def jax_cosine_warmup(step_hint: int, hyperparameters): # Create learning rate schedule. - warmup_steps = int(hyperparameters['warmup_factor * step_hint']) + warmup_steps = int(hyperparameters['warmup_factor'] * step_hint) warmup_fn = optax.linear_schedule( init_value=0., end_value=hyperparameters['learning_rate'],