Commit

remove SAM and Lion and apply scheduler to weight_decay
thayeral committed Oct 23, 2023
1 parent 8961566 commit 1294d56
Showing 6 changed files with 34 additions and 212 deletions.
12 changes: 3 additions & 9 deletions src/backend.py
@@ -53,7 +53,6 @@
 import opticalresnet
 import baseline
 import otfnet
-from lion import Lion


 logging.basicConfig(
@@ -1284,7 +1283,6 @@ def train(
     increase_dropout_depth: bool = False,
     decrease_dropout_depth: bool = False,
     stem: bool = False,
-    sam: bool = True,
 ):
     network = network.lower()
     opt = opt.lower()
@@ -1302,16 +1300,13 @@
     SGDR - Stochastic Gradient Descent with Warm Restarts: https://arxiv.org/pdf/1608.03983
     AdamW - Decoupled weight decay regularization: https://arxiv.org/pdf/1711.05101
     SAM - Sharpness-Aware-Minimization (SAM): https://openreview.net/pdf?id=6Tm1mposlrM
-    Lion - Symbolic Discovery of Optimization Algorithms: https://arxiv.org/pdf/2302.06675
     """
-    if opt.lower() == 'lion':
-        opt = Lion(learning_rate=lr, weight_decay=wd)
-    elif opt.lower() == 'adam':
+    if opt.lower() == 'adam':
         opt = Adam(learning_rate=lr)
     elif opt == 'sgd':
         opt = SGD(learning_rate=lr, momentum=0.9)
     elif opt.lower() == 'adamw':
-        opt = AdamW(learning_rate=lr, weight_decay=wd)
+        opt = AdamW(learning_rate=lr, weight_decay=wd, beta_1=0.9, beta_2=0.99)
     elif opt == 'sgdw':
         opt = SGDW(learning_rate=lr, weight_decay=wd, momentum=0.9)
     else:
@@ -1343,7 +1338,6 @@ def train(
             name='OpticalNet',
             roi=roi,
             stem=stem,
-            sam=sam,
             patches=patch_size,
             modes=pmodes,
             depth_scalar=depth_scalar,
@@ -1452,7 +1446,7 @@ def train(
     else:
         lrscheduler = LearningRateScheduler(
             initial_learning_rate=opt.learning_rate,
-            weight_decay=opt.weight_decay,
+            weight_decay=opt.weight_decay if hasattr(opt, 'weight_decay') else None,
             decay_period=epochs if decay_period is None else decay_period,
             warmup_epochs=0 if warmup is None or warmup >= epochs else warmup,
             alpha=.01,
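For context on the hunk above: a minimal standalone sketch of the hasattr guard (illustrative only, not the project's code). The scheduler is handed the optimizer's weight_decay only when the optimizer actually defines one, so plain Adam/SGD keep working while AdamW/SGDW get their decay scheduled. The helper name make_scheduler_kwargs and the SimpleNamespace stand-ins are hypothetical.

    from types import SimpleNamespace

    def make_scheduler_kwargs(opt, epochs, decay_period=None, warmup=None):
        # Hand weight_decay to the scheduler only if the optimizer defines one (mirrors the hasattr guard above).
        return dict(
            initial_learning_rate=opt.learning_rate,
            weight_decay=opt.weight_decay if hasattr(opt, 'weight_decay') else None,
            decay_period=epochs if decay_period is None else decay_period,
            warmup_epochs=0 if warmup is None or warmup >= epochs else warmup,
            alpha=.01,
        )

    adamw = SimpleNamespace(learning_rate=5e-4, weight_decay=5e-3)  # decoupled decay -> gets scheduled
    sgd = SimpleNamespace(learning_rate=5e-4)                       # no weight_decay attribute -> stays None

    print(make_scheduler_kwargs(adamw, epochs=500, warmup=25))  # weight_decay=0.005, warmup_epochs=25
    print(make_scheduler_kwargs(sgd, epochs=500))               # weight_decay=None, warmup_epochs=0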
27 changes: 21 additions & 6 deletions src/callbacks.py
@@ -157,6 +157,9 @@ def on_epoch_begin(self, epoch, logs=None):
         except AttributeError:
             raise ValueError('Optimizer must have a `learning_rate`')

+        if hasattr(self.model.optimizer, 'weight_decay'):
+            wd = backend.get_value(self.model.optimizer.weight_decay)
+
         if not self.fixed:
             lr = tf.cond(
                 epoch < self.warmup_epochs,
@@ -171,18 +174,30 @@
             )
             backend.set_value(self.model.optimizer.lr, backend.get_value(lr))

+            if hasattr(self.model.optimizer, 'weight_decay'):
+                wd = tf.cond(
+                    epoch < self.warmup_epochs,
+                    lambda: self.linear_warmup(
+                        val=self.weight_decay,
+                        step=epoch,
+                    ),
+                    lambda: self.cosine_decay(
+                        val=self.weight_decay,
+                        step=epoch - self.warmup_epochs,
+                    )
+                )
+                backend.set_value(self.model.optimizer.weight_decay, backend.get_value(wd))
+
         if self.verbose > 0:
             logger.info(f'Scheduler setting learning rate: {lr}')

+            if hasattr(self.model.optimizer, 'weight_decay'):
+                logger.info(f'Scheduler setting weight decay: {wd}')
+
         tf.summary.scalar('learning rate', data=lr, step=epoch)

-        if hasattr(self.model.optimizer, 'weight_decay'):
-            backend.set_value(self.model.optimizer.weight_decay, backend.get_value(self.weight_decay))
-
-            if self.verbose > 0:
-                logger.info(f'Scheduler setting weight decay: {backend.get_value(self.model.optimizer.weight_decay)}')
-
-            tf.summary.scalar('weight decay', data=self.weight_decay, step=epoch)
+        tf.summary.scalar('weight decay', data=wd, step=epoch)

     def linear_warmup(self, val, step, power=1.0):
         completed_fraction = step / self.warmup_epochs
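The change above makes weight decay follow the same linear-warmup-then-cosine-decay schedule as the learning rate, so their ratio stays roughly constant over training (the behavior decoupled-weight-decay optimizers such as AdamW generally expect when the learning rate is annealed). Below is a standalone sketch of that schedule, assuming the callback's cosine_decay matches the standard Keras CosineDecay formula with alpha as the floor fraction; the scheduled helper and the sample hyperparameters are hypothetical.

    import math

    def linear_warmup(val, step, warmup_epochs, power=1.0):
        # Ramp linearly from 0 up to val over warmup_epochs.
        return val * (step / warmup_epochs) ** power

    def cosine_decay(val, step, decay_period, alpha=.01):
        # Cosine-anneal from val down to alpha * val over decay_period epochs.
        step = min(step, decay_period)
        cosine = 0.5 * (1.0 + math.cos(math.pi * step / decay_period))
        return val * ((1.0 - alpha) * cosine + alpha)

    def scheduled(val, epoch, warmup_epochs, decay_period, alpha=.01):
        # Same branching the callback applies each epoch to both lr and weight_decay.
        if epoch < warmup_epochs:
            return linear_warmup(val, epoch, warmup_epochs)
        return cosine_decay(val, epoch - warmup_epochs, decay_period, alpha)

    base_lr, base_wd = 5e-4, 5e-3
    for epoch in (0, 10, 25, 100, 500):
        lr = scheduled(base_lr, epoch, warmup_epochs=25, decay_period=500)
        wd = scheduled(base_wd, epoch, warmup_epochs=25, decay_period=500)
        print(f'epoch {epoch:3d}: lr={lr:.2e}  wd={wd:.2e}')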
106 changes: 0 additions & 106 deletions src/lion.py

This file was deleted.
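For reference, the deleted module implemented the Lion optimizer from the paper cited in the backend.py docstring (arXiv 2302.06675). Lion keeps a single momentum buffer and applies sign-based updates with decoupled weight decay; the scalar sketch below illustrates the update rule only and is not the deleted implementation.

    def lion_step(param, grad, momentum, lr=1e-4, wd=0.0, beta1=0.9, beta2=0.99):
        # Interpolate gradient and momentum, step by the sign, add decoupled weight decay,
        # then update the momentum buffer with the second coefficient.
        update = beta1 * momentum + (1.0 - beta1) * grad
        sign = (update > 0) - (update < 0)
        param = param - lr * (sign + wd * param)
        momentum = beta2 * momentum + (1.0 - beta2) * grad
        return param, momentum

    p, m = 0.5, 0.0
    for g in (0.2, -0.1, 0.3):
        p, m = lion_step(p, g, m, lr=1e-3, wd=1e-2)
    print(p, m)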

63 changes: 0 additions & 63 deletions src/opticalnet.py
@@ -430,7 +430,6 @@ def __init__(
         radial_encoding_nth_order=4,
         decrease_dropout_depth=False,
         increase_dropout_depth=False,
-        sam=False,
         stem=False,
         **kwargs
     ):
@@ -454,75 +453,13 @@
         self.radial_encoding_nth_order = radial_encoding_nth_order
         self.increase_dropout_depth = increase_dropout_depth
         self.decrease_dropout_depth = decrease_dropout_depth
-        self.sam = sam

     def _calc_channels(self, channels, width_scalar):
         return int(tf.math.ceil(width_scalar * channels))

     def _calc_repeats(self, repeats, depth_scalar):
         return int(tf.math.ceil(depth_scalar * repeats))

-    def sharpness_aware_minimization(self, x, y, sample_weight=None, rho=0.05, eps=1e-12):
-        """
-        Sharpness-Aware-Minimization (SAM): https://openreview.net/pdf?id=6Tm1mposlrM
-        https://github.com/Jannoshh/simple-sam
-        """
-
-        with tf.GradientTape() as tape:
-            y_pred = self(x, training=True)
-            loss = self.compiled_loss(y, y_pred, sample_weight=sample_weight, regularization_losses=self.losses)
-
-        trainable_vars = self.trainable_variables
-        gradients = tape.gradient(loss, trainable_vars)
-
-        # first step
-        e_ws = []
-        grad_norm = tf.linalg.global_norm(gradients)
-        ew_multiplier = rho / (grad_norm + eps)
-        for i in range(len(trainable_vars)):
-            e_w = tf.math.multiply(gradients[i], ew_multiplier)
-            trainable_vars[i].assign_add(e_w)
-            e_ws.append(e_w)
-
-        with tf.GradientTape() as tape:
-            y_pred = self(x, training=True)
-            loss = self.compiled_loss(y, y_pred, sample_weight=sample_weight, regularization_losses=self.losses)
-
-        trainable_vars = self.trainable_variables
-        gradients = tape.gradient(loss, trainable_vars)
-
-        for i in range(len(trainable_vars)):
-            trainable_vars[i].assign_sub(e_ws[i])
-
-        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
-        self.compiled_metrics.update_state(y, y_pred, sample_weight=sample_weight)
-        return {m.name: m.result() for m in self.metrics}
-
-    def train_step(self, data):
-
-        # Unpack the data. Its structure depends on your model and
-        # on what you pass to `fit()`.
-        if len(data) == 3:
-            x, y, sample_weight = data
-        else:
-            sample_weight = None
-            x, y = data
-
-        if self.sam:
-            return self.sharpness_aware_minimization(x=x, y=y, sample_weight=sample_weight)
-        else:
-
-            with tf.GradientTape() as tape:
-                y_pred = self(x, training=True)
-                loss = self.compiled_loss(y, y_pred, sample_weight=sample_weight, regularization_losses=self.losses)
-
-            trainable_vars = self.trainable_variables
-            gradients = tape.gradient(loss, trainable_vars)
-
-            self.optimizer.apply_gradients(zip(gradients, trainable_vars))
-            self.compiled_metrics.update_state(y, y_pred, sample_weight=sample_weight)
-            return {m.name: m.result() for m in self.metrics}
-
     def transitional_block(self, inputs, img_shape, patch_size, expansion=1, org_patch_size=None):
         if org_patch_size is not None:
             inputs = Merge(patch_size=org_patch_size, expansion=expansion)(inputs)
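With sharpness_aware_minimization and the custom train_step gone, OpticalTransformer falls back to Keras's built-in Model.train_step, which does essentially what the deleted else-branch did: one forward pass, one backward pass, apply_gradients, then metric updates. A rough standalone sketch of that default loop (TinyModel is a hypothetical stand-in, not OpticalTransformer, and the real Keras implementation handles more cases such as sample weights):

    import numpy as np
    import tensorflow as tf

    class TinyModel(tf.keras.Model):
        # Spells out what the default Keras train_step amounts to once the SAM override is removed.
        def __init__(self):
            super().__init__()
            self.dense = tf.keras.layers.Dense(1)

        def call(self, inputs, training=False):
            return self.dense(inputs)

        def train_step(self, data):
            x, y = data
            with tf.GradientTape() as tape:
                y_pred = self(x, training=True)
                loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
            gradients = tape.gradient(loss, self.trainable_variables)
            self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
            self.compiled_metrics.update_state(y, y_pred)
            return {m.name: m.result() for m in self.metrics}

    model = TinyModel()
    model.compile(optimizer='adam', loss='mse', metrics=['mae'])
    model.fit(np.random.rand(64, 4).astype('float32'), np.random.rand(64, 1).astype('float32'), epochs=1, verbose=0)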
6 changes: 0 additions & 6 deletions src/train.py
@@ -138,11 +138,6 @@ def parse_args(args):
         "--opt", default='AdamW', type=str, help='optimizer to use for training'
     )

-    train_parser.add_argument(
-        '--sam', action='store_true',
-        help='toggle to use sharpness aware minimization'
-    )
-
     train_parser.add_argument(
         "--activation", default='gelu', type=str, help='activation function for the model'
     )
@@ -289,7 +284,6 @@ def main(args=None):
             radial_encoding_nth_order=args.radial_encoding_nth_order,
             positional_encoding_scheme=args.positional_encoding_scheme,
             stem=args.stem,
-            sam=args.sam,
             increase_dropout_depth=args.increase_dropout_depth,
             decrease_dropout_depth=args.decrease_dropout_depth,
         )
32 changes: 10 additions & 22 deletions src/train_fourierspace_models.sh
@@ -9,15 +9,14 @@ MAXAMP=1
 DZ=200
 DY=108
 DX=108
-SAM='--sam'
 RADIAL_ENCODING_PERIOD='--radial_encoding_period 16'
 RADIAL_ENCODING_ORDER='--radial_encoding_nth_order 4'
 POSITIONAL_ENCODING_SCHEME='--positional_encoding_scheme rotational_symmetry'
 NO_PHASE='--no_phase'
 DEFOCUS='--lls_defocus'
 DEFOCUS_ONLY='--defocus_only'
 EMB="spatial_planes"
-BATCH=2048
+BATCH=1024
 NETWORK=opticalnet
 MODES=15
 WARMUP=25
@@ -64,27 +63,16 @@ do

 for DROPOUT in $INCREASE_DROPOUT $DECREASE_DROPOUT
 do
-  for OPT in lion adamw #sam
+  for i in "5e-4 5e-3" "5e-5 5e-4" "5e-5 5e-5"
   do
-    for i in "5e-4 5e-5" "5e-5 5e-4"
-    do
-      set -- $i
-      LR=$1
-      WD=$2
-
-      if [ $OPT = 'sam' ];then
-        python manager.py $CLUSTER train.py --partition gpu_a100 --gpus 4 --cpus 8 \
-        --task "$DROPOUT $SAM --opt adamw --lr $LR --wd $WD $POSITIONAL_ENCODING_SCHEME $RADIAL_ENCODING_PERIOD --psf_type $PTYPE --wavelength $LAM --network $NETWORK --embedding $EMB --patch_size '32-16-8-8' --modes $MODES --max_amplitude $MAXAMP --batch_size $BATCH --dataset $DATA --input_shape $SHAPE --depth_scalar $DEPTH --epochs $EPOCHS --warmup $WARMUP" \
-        --taskname $NETWORK \
-        --name new/$SUBSET/$NETWORK-$MODES-$DIR-$OPT-$LR-$DROPOUT
-      else
-        python manager.py $CLUSTER train.py --partition gpu_a100 --gpus 4 --cpus 8 \
-        --task "$DROPOUT --opt $OPT --lr $LR --wd $WD $POSITIONAL_ENCODING_SCHEME $RADIAL_ENCODING_PERIOD --psf_type $PTYPE --wavelength $LAM --network $NETWORK --embedding $EMB --patch_size '32-16-8-8' --modes $MODES --max_amplitude $MAXAMP --batch_size $BATCH --dataset $DATA --input_shape $SHAPE --depth_scalar $DEPTH --epochs $EPOCHS --warmup $WARMUP" \
-        --taskname $NETWORK \
-        --name new/$SUBSET/$NETWORK-$MODES-$DIR-$OPT-$LR-$DROPOUT
-      fi
-
-    done
+    set -- $i
+    LR=$1
+    WD=$2
+
+    python manager.py $CLUSTER train.py --partition gpu_a100 --gpus 4 --cpus 8 \
+    --task "$DROPOUT --lr $LR --wd $WD $POSITIONAL_ENCODING_SCHEME $RADIAL_ENCODING_PERIOD --psf_type $PTYPE --wavelength $LAM --network $NETWORK --embedding $EMB --patch_size '32-16-8-8' --modes $MODES --max_amplitude $MAXAMP --batch_size $BATCH --dataset $DATA --input_shape $SHAPE --depth_scalar $DEPTH --epochs $EPOCHS --warmup $WARMUP" \
+    --taskname $NETWORK \
+    --name new/$SUBSET/$NETWORK-$MODES-$DIR-$LR-$WD-$DROPOUT
   done
 done
 done
