From 1294d56add9853c5901e75948d34d5a6253bcc87 Mon Sep 17 00:00:00 2001 From: Thayer Date: Mon, 23 Oct 2023 09:42:07 -0400 Subject: [PATCH] remove SAM and Lion and apply scheduler to weight_decay --- src/backend.py | 12 +--- src/callbacks.py | 27 ++++++-- src/lion.py | 106 ------------------------------- src/opticalnet.py | 63 ------------------ src/train.py | 6 -- src/train_fourierspace_models.sh | 32 +++------- 6 files changed, 34 insertions(+), 212 deletions(-) delete mode 100644 src/lion.py diff --git a/src/backend.py b/src/backend.py index 30a1ba14..4cb4eeea 100644 --- a/src/backend.py +++ b/src/backend.py @@ -53,7 +53,6 @@ import opticalresnet import baseline import otfnet -from lion import Lion logging.basicConfig( @@ -1284,7 +1283,6 @@ def train( increase_dropout_depth: bool = False, decrease_dropout_depth: bool = False, stem: bool = False, - sam: bool = True, ): network = network.lower() opt = opt.lower() @@ -1302,16 +1300,13 @@ def train( SGDR - Stochastic Gradient Descent with Warm Restarts: https://arxiv.org/pdf/1608.03983 AdamW - Decoupled weight decay regularization: https://arxiv.org/pdf/1711.05101 SAM - Sharpness-Aware-Minimization (SAM): https://openreview.net/pdf?id=6Tm1mposlrM - Lion - Symbolic Discovery of Optimization Algorithms: https://arxiv.org/pdf/2302.06675 """ - if opt.lower() == 'lion': - opt = Lion(learning_rate=lr, weight_decay=wd) - elif opt.lower() == 'adam': + if opt.lower() == 'adam': opt = Adam(learning_rate=lr) elif opt == 'sgd': opt = SGD(learning_rate=lr, momentum=0.9) elif opt.lower() == 'adamw': - opt = AdamW(learning_rate=lr, weight_decay=wd) + opt = AdamW(learning_rate=lr, weight_decay=wd, beta_1=0.9, beta_2=0.99) elif opt == 'sgdw': opt = SGDW(learning_rate=lr, weight_decay=wd, momentum=0.9) else: @@ -1343,7 +1338,6 @@ def train( name='OpticalNet', roi=roi, stem=stem, - sam=sam, patches=patch_size, modes=pmodes, depth_scalar=depth_scalar, @@ -1452,7 +1446,7 @@ def train( else: lrscheduler = LearningRateScheduler( initial_learning_rate=opt.learning_rate, - weight_decay=opt.weight_decay, + weight_decay=opt.weight_decay if hasattr(opt, 'weight_decay') else None, decay_period=epochs if decay_period is None else decay_period, warmup_epochs=0 if warmup is None or warmup >= epochs else warmup, alpha=.01, diff --git a/src/callbacks.py b/src/callbacks.py index b1b60764..bfa3d990 100644 --- a/src/callbacks.py +++ b/src/callbacks.py @@ -157,6 +157,9 @@ def on_epoch_begin(self, epoch, logs=None): except AttributeError: raise ValueError('Optimizer must have a `learning_rate`') + if hasattr(self.model.optimizer, 'weight_decay'): + wd = backend.get_value(self.model.optimizer.weight_decay) + if not self.fixed: lr = tf.cond( epoch < self.warmup_epochs, @@ -171,18 +174,30 @@ def on_epoch_begin(self, epoch, logs=None): ) backend.set_value(self.model.optimizer.lr, backend.get_value(lr)) + if hasattr(self.model.optimizer, 'weight_decay'): + wd = tf.cond( + epoch < self.warmup_epochs, + lambda: self.linear_warmup( + val=self.weight_decay, + step=epoch, + ), + lambda: self.cosine_decay( + val=self.weight_decay, + step=epoch - self.warmup_epochs, + ) + ) + backend.set_value(self.model.optimizer.weight_decay, backend.get_value(wd)) + if self.verbose > 0: logger.info(f'Scheduler setting learning rate: {lr}') + if hasattr(self.model.optimizer, 'weight_decay'): + logger.info(f'Scheduler setting weight decay: {wd}') + tf.summary.scalar('learning rate', data=lr, step=epoch) if hasattr(self.model.optimizer, 'weight_decay'): - backend.set_value(self.model.optimizer.weight_decay, backend.get_value(self.weight_decay)) - - if self.verbose > 0: - logger.info(f'Scheduler setting weight decay: {backend.get_value(self.model.optimizer.weight_decay)}') - - tf.summary.scalar('weight decay', data=self.weight_decay, step=epoch) + tf.summary.scalar('weight decay', data=wd, step=epoch) def linear_warmup(self, val, step, power=1.0): completed_fraction = step / self.warmup_epochs diff --git a/src/lion.py b/src/lion.py deleted file mode 100644 index 08ed78c2..00000000 --- a/src/lion.py +++ /dev/null @@ -1,106 +0,0 @@ -# Copyright 2023 Google Research. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# https://github.com/google/automl/tree/master/lion -# ============================================================================== -"""TF2 implementation of the Lion optimizer.""" - -import tensorflow as tf - - -class Lion(tf.keras.optimizers.Optimizer): - r"""Optimizer that implements the Lion algorithm.""" - - def __init__(self, - learning_rate=0.0001, - beta_1=0.9, - beta_2=0.99, - weight_decay=0, - name='lion', - **kwargs): - """Construct a new Lion optimizer.""" - - super(Lion, self).__init__(name, **kwargs) - self._set_hyper('learning_rate', kwargs.get('lr', learning_rate)) - self._set_hyper('beta_1', beta_1) - self._set_hyper('beta_2', beta_2) - self._set_hyper('weight_decay', weight_decay) - - def _create_slots(self, var_list): - # Create slots for the first and second moments. - # Separate for-loops to respect the ordering of slot variables from v1. - for var in var_list: - self.add_slot(var, 'm') - - def _prepare_local(self, var_device, var_dtype, apply_state): - super(Lion, self)._prepare_local(var_device, var_dtype, apply_state) - - beta_1_t = tf.identity(self._get_hyper('beta_1', var_dtype)) - beta_2_t = tf.identity(self._get_hyper('beta_2', var_dtype)) - weight_decay_t = tf.identity(self._get_hyper('weight_decay', var_dtype)) - lr = apply_state[(var_device, var_dtype)]['lr_t'] - apply_state[(var_device, var_dtype)].update( - dict( - lr=lr, - beta_1_t=beta_1_t, - one_minus_beta_1_t=1 - beta_1_t, - beta_2_t=beta_2_t, - one_minus_beta_2_t=1 - beta_2_t, - weight_decay_t=weight_decay_t)) - - @tf.function(jit_compile=True) - def _resource_apply_dense(self, grad, var, apply_state=None): - var_device, var_dtype = var.device, var.dtype.base_dtype - coefficients = ((apply_state or {}).get((var_device, var_dtype)) or - self._fallback_apply_state(var_device, var_dtype)) - - m = self.get_slot(var, 'm') - var_t = var.assign_sub( - coefficients['lr_t'] * - (tf.math.sign(m * coefficients['beta_1_t'] + - grad * coefficients['one_minus_beta_1_t']) + - var * coefficients['weight_decay_t'])) - with tf.control_dependencies([var_t]): - m.assign(m * coefficients['beta_2_t'] + - grad * coefficients['one_minus_beta_2_t']) - - @tf.function(jit_compile=True) - def _resource_apply_sparse(self, grad, var, indices, apply_state=None): - var_device, var_dtype = var.device, var.dtype.base_dtype - coefficients = ((apply_state or {}).get((var_device, var_dtype)) or - self._fallback_apply_state(var_device, var_dtype)) - - m = self.get_slot(var, 'm') - m_t = m.assign(m * coefficients['beta_1_t']) - m_scaled_g_values = grad * coefficients['one_minus_beta_1_t'] - m_t = m_t.scatter_add(tf.IndexedSlices(m_scaled_g_values, indices)) - var_t = var.assign_sub(coefficients['lr'] * - (tf.math.sign(m_t) + var * coefficients['weight_decay_t'])) - - with tf.control_dependencies([var_t]): - m_t = m_t.scatter_add(tf.IndexedSlices(-m_scaled_g_values, indices)) - m_t = m_t.assign(m_t * coefficients['beta_2_t'] / - coefficients['beta_1_t']) - m_scaled_g_values = grad * coefficients['one_minus_beta_2_t'] - m_t.scatter_add(tf.IndexedSlices(m_scaled_g_values, indices)) - - def get_config(self): - config = super(Lion, self).get_config() - config.update({ - 'learning_rate': self._serialize_hyperparameter('learning_rate'), - 'beta_1': self._serialize_hyperparameter('beta_1'), - 'beta_2': self._serialize_hyperparameter('beta_2'), - 'weight_decay': self._serialize_hyperparameter('weight_decay'), - }) - return config diff --git a/src/opticalnet.py b/src/opticalnet.py index 42baec25..8a7a0e43 100644 --- a/src/opticalnet.py +++ b/src/opticalnet.py @@ -430,7 +430,6 @@ def __init__( radial_encoding_nth_order=4, decrease_dropout_depth=False, increase_dropout_depth=False, - sam=False, stem=False, **kwargs ): @@ -454,7 +453,6 @@ def __init__( self.radial_encoding_nth_order = radial_encoding_nth_order self.increase_dropout_depth = increase_dropout_depth self.decrease_dropout_depth = decrease_dropout_depth - self.sam = sam def _calc_channels(self, channels, width_scalar): return int(tf.math.ceil(width_scalar * channels)) @@ -462,67 +460,6 @@ def _calc_channels(self, channels, width_scalar): def _calc_repeats(self, repeats, depth_scalar): return int(tf.math.ceil(depth_scalar * repeats)) - def sharpness_aware_minimization(self, x, y, sample_weight=None, rho=0.05, eps=1e-12): - """ - Sharpness-Aware-Minimization (SAM): https://openreview.net/pdf?id=6Tm1mposlrM - https://github.com/Jannoshh/simple-sam - """ - - with tf.GradientTape() as tape: - y_pred = self(x, training=True) - loss = self.compiled_loss(y, y_pred, sample_weight=sample_weight, regularization_losses=self.losses) - - trainable_vars = self.trainable_variables - gradients = tape.gradient(loss, trainable_vars) - - # first step - e_ws = [] - grad_norm = tf.linalg.global_norm(gradients) - ew_multiplier = rho / (grad_norm + eps) - for i in range(len(trainable_vars)): - e_w = tf.math.multiply(gradients[i], ew_multiplier) - trainable_vars[i].assign_add(e_w) - e_ws.append(e_w) - - with tf.GradientTape() as tape: - y_pred = self(x, training=True) - loss = self.compiled_loss(y, y_pred, sample_weight=sample_weight, regularization_losses=self.losses) - - trainable_vars = self.trainable_variables - gradients = tape.gradient(loss, trainable_vars) - - for i in range(len(trainable_vars)): - trainable_vars[i].assign_sub(e_ws[i]) - - self.optimizer.apply_gradients(zip(gradients, trainable_vars)) - self.compiled_metrics.update_state(y, y_pred, sample_weight=sample_weight) - return {m.name: m.result() for m in self.metrics} - - def train_step(self, data): - - # Unpack the data. Its structure depends on your model and - # on what you pass to `fit()`. - if len(data) == 3: - x, y, sample_weight = data - else: - sample_weight = None - x, y = data - - if self.sam: - return self.sharpness_aware_minimization(x=x, y=y, sample_weight=sample_weight) - else: - - with tf.GradientTape() as tape: - y_pred = self(x, training=True) - loss = self.compiled_loss(y, y_pred, sample_weight=sample_weight, regularization_losses=self.losses) - - trainable_vars = self.trainable_variables - gradients = tape.gradient(loss, trainable_vars) - - self.optimizer.apply_gradients(zip(gradients, trainable_vars)) - self.compiled_metrics.update_state(y, y_pred, sample_weight=sample_weight) - return {m.name: m.result() for m in self.metrics} - def transitional_block(self, inputs, img_shape, patch_size, expansion=1, org_patch_size=None): if org_patch_size is not None: inputs = Merge(patch_size=org_patch_size, expansion=expansion)(inputs) diff --git a/src/train.py b/src/train.py index e742b3f8..8c6dc57f 100644 --- a/src/train.py +++ b/src/train.py @@ -138,11 +138,6 @@ def parse_args(args): "--opt", default='AdamW', type=str, help='optimizer to use for training' ) - train_parser.add_argument( - '--sam', action='store_true', - help='toggle to use sharpness aware minimization' - ) - train_parser.add_argument( "--activation", default='gelu', type=str, help='activation function for the model' ) @@ -289,7 +284,6 @@ def main(args=None): radial_encoding_nth_order=args.radial_encoding_nth_order, positional_encoding_scheme=args.positional_encoding_scheme, stem=args.stem, - sam=args.sam, increase_dropout_depth=args.increase_dropout_depth, decrease_dropout_depth=args.decrease_dropout_depth, ) diff --git a/src/train_fourierspace_models.sh b/src/train_fourierspace_models.sh index f690b3aa..7732ce99 100755 --- a/src/train_fourierspace_models.sh +++ b/src/train_fourierspace_models.sh @@ -9,7 +9,6 @@ MAXAMP=1 DZ=200 DY=108 DX=108 -SAM='--sam' RADIAL_ENCODING_PERIOD='--radial_encoding_period 16' RADIAL_ENCODING_ORDER='--radial_encoding_nth_order 4' POSITIONAL_ENCODING_SCHEME='--positional_encoding_scheme rotational_symmetry' @@ -17,7 +16,7 @@ NO_PHASE='--no_phase' DEFOCUS='--lls_defocus' DEFOCUS_ONLY='--defocus_only' EMB="spatial_planes" -BATCH=2048 +BATCH=1024 NETWORK=opticalnet MODES=15 WARMUP=25 @@ -64,27 +63,16 @@ do for DROPOUT in $INCREASE_DROPOUT $DECREASE_DROPOUT do - for OPT in lion adamw #sam + for i in "5e-4 5e-3" "5e-5 5e-4" "5e-5 5e-5" do - for i in "5e-4 5e-5" "5e-5 5e-4" - do - set -- $i - LR=$1 - WD=$2 - - if [ $OPT = 'sam' ];then - python manager.py $CLUSTER train.py --partition gpu_a100 --gpus 4 --cpus 8 \ - --task "$DROPOUT $SAM --opt adamw --lr $LR --wd $WD $POSITIONAL_ENCODING_SCHEME $RADIAL_ENCODING_PERIOD --psf_type $PTYPE --wavelength $LAM --network $NETWORK --embedding $EMB --patch_size '32-16-8-8' --modes $MODES --max_amplitude $MAXAMP --batch_size $BATCH --dataset $DATA --input_shape $SHAPE --depth_scalar $DEPTH --epochs $EPOCHS --warmup $WARMUP" \ - --taskname $NETWORK \ - --name new/$SUBSET/$NETWORK-$MODES-$DIR-$OPT-$LR-$DROPOUT - else - python manager.py $CLUSTER train.py --partition gpu_a100 --gpus 4 --cpus 8 \ - --task "$DROPOUT --opt $OPT --lr $LR --wd $WD $POSITIONAL_ENCODING_SCHEME $RADIAL_ENCODING_PERIOD --psf_type $PTYPE --wavelength $LAM --network $NETWORK --embedding $EMB --patch_size '32-16-8-8' --modes $MODES --max_amplitude $MAXAMP --batch_size $BATCH --dataset $DATA --input_shape $SHAPE --depth_scalar $DEPTH --epochs $EPOCHS --warmup $WARMUP" \ - --taskname $NETWORK \ - --name new/$SUBSET/$NETWORK-$MODES-$DIR-$OPT-$LR-$DROPOUT - fi - - done + set -- $i + LR=$1 + WD=$2 + + python manager.py $CLUSTER train.py --partition gpu_a100 --gpus 4 --cpus 8 \ + --task "$DROPOUT --lr $LR --wd $WD $POSITIONAL_ENCODING_SCHEME $RADIAL_ENCODING_PERIOD --psf_type $PTYPE --wavelength $LAM --network $NETWORK --embedding $EMB --patch_size '32-16-8-8' --modes $MODES --max_amplitude $MAXAMP --batch_size $BATCH --dataset $DATA --input_shape $SHAPE --depth_scalar $DEPTH --epochs $EPOCHS --warmup $WARMUP" \ + --taskname $NETWORK \ + --name new/$SUBSET/$NETWORK-$MODES-$DIR-$LR-$WD-$DROPOUT done done done