Skip to content

Commit

Permalink
Merge pull request #242 from NREL/gb/optm_state
Browse files Browse the repository at this point in the history
Gb/optm state
  • Loading branch information
grantbuster authored Nov 14, 2024
2 parents e43e4bb + a74ed5c commit 9c34e65
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 15 deletions.
25 changes: 24 additions & 1 deletion sup3r/models/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -1037,7 +1037,7 @@ def get_optimizer_config(optimizer):
Parameters
----------
optimizer : tf.keras.optimizers.Optimizer
TF-Keras optimizer object
TF-Keras optimizer object (e.g., Adam)
Returns
-------
Expand All @@ -1053,6 +1053,29 @@ def get_optimizer_config(optimizer):
conf[k] = int(v)
return conf

@classmethod
def get_optimizer_state(cls, optimizer):
    """Summarize an optimizer's state as a flat dict of scalars.

    Each optimizer state tensor (e.g., Adam moment estimates) is
    collapsed to its mean absolute value so the state can be logged
    as simple floats.

    Parameters
    ----------
    optimizer : tf.keras.optimizers.Optimizer
        TF-Keras optimizer object (e.g., Adam)

    Returns
    -------
    state : dict
        Optimizer state variables: the learning rate plus one
        mean-absolute scalar per optimizer state tensor, keyed by
        the tensor's variable name.
    """
    config = cls.get_optimizer_config(optimizer)
    state = {'learning_rate': config['learning_rate']}
    for variable in optimizer.variables:
        # collapse the ndarray into a single mean-absolute scalar
        flat = variable.numpy().flatten()
        state[variable.name] = float(np.abs(flat).mean())
    return state

@staticmethod
def update_loss_details(loss_details, new_data, batch_len, prefix=None):
"""Update a dictionary of loss_details with loss information from a new
Expand Down
22 changes: 12 additions & 10 deletions sup3r/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,8 +746,10 @@ def train_epoch(

b_loss_details['gen_trained_frac'] = float(trained_gen)
b_loss_details['disc_trained_frac'] = float(trained_disc)

self.dict_to_tensorboard(b_loss_details)
self.dict_to_tensorboard(self.timer.log)

loss_details = self.update_loss_details(
loss_details,
b_loss_details,
Expand Down Expand Up @@ -1000,10 +1002,9 @@ def train(
loss_details['train_loss_gen'], loss_details['train_loss_disc']
)

if all(
loss in loss_details
for loss in ('val_loss_gen', 'val_loss_disc')
):
check1 = 'val_loss_gen' in loss_details
check2 = 'val_loss_disc' in loss_details
if check1 and check2:
msg += 'gen/disc val loss: {:.2e}/{:.2e} '.format(
loss_details['val_loss_gen'], loss_details['val_loss_disc']
)
Expand All @@ -1016,14 +1017,15 @@ def train(
'weight_gen_advers': weight_gen_advers,
'disc_loss_bound_0': disc_loss_bounds[0],
'disc_loss_bound_1': disc_loss_bounds[1],
'learning_rate_gen': self.get_optimizer_config(self.optimizer)[
'learning_rate'
],
'learning_rate_disc': self.get_optimizer_config(
self.optimizer_disc
)['learning_rate'],
}

opt_g = self.get_optimizer_state(self.optimizer)
opt_d = self.get_optimizer_state(self.optimizer_disc)
opt_g = {f'OptmGen/{key}': val for key, val in opt_g.items()}
opt_d = {f'OptmDisc/{key}': val for key, val in opt_d.items()}
extras.update(opt_g)
extras.update(opt_d)

weight_gen_advers = self.update_adversarial_weights(
loss_details,
adaptive_update_fraction,
Expand Down
13 changes: 11 additions & 2 deletions sup3r/preprocessing/cachers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from sup3r.preprocessing.base import Container
from sup3r.preprocessing.names import Dimension
from sup3r.preprocessing.utilities import _mem_check, log_args, _lowered
from sup3r.utilities.utilities import safe_cast
from sup3r.utilities.utilities import safe_cast, safe_serialize
from rex.utilities.utilities import to_records_array

from .utilities import _check_for_cache
Expand Down Expand Up @@ -475,7 +475,16 @@ def write_netcdf(
ncfile.variables[dset][:] = np.asarray(data_var.data)

for attr_name, attr_value in attrs.items():
ncfile.setncattr(attr_name, safe_cast(attr_value))
attr_value = safe_cast(attr_value)
try:
ncfile.setncattr(attr_name, attr_value)
except Exception as e:
msg = (f'Could not write {attr_name} as attribute, '
f'serializing with json dumps, '
f'received error: "{e}"')
logger.warning(msg)
warn(msg)
ncfile.setncattr(attr_name, safe_serialize(attr_value))

for feature in features:
cls.write_netcdf_chunks(
Expand Down
10 changes: 8 additions & 2 deletions tests/training/test_train_gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,14 @@ def test_train(fp_gen, fp_disc, s_enhance, t_enhance, sample_shape, n_epoch=8):

assert np.allclose(model_params['optimizer']['learning_rate'], lr)
assert np.allclose(model_params['optimizer_disc']['learning_rate'], lr)
assert 'learning_rate_gen' in model.history
assert 'learning_rate_disc' in model.history
assert 'OptmGen/learning_rate' in model.history
assert 'OptmDisc/learning_rate' in model.history

msg = ('Could not find OptmGen states in columns: '
f'{sorted(model.history.columns)}')
check = [col.startswith('OptmGen/Adam/v')
for col in model.history.columns]
assert any(check), msg

assert 'config_generator' in loaded.meta
assert 'config_discriminator' in loaded.meta
Expand Down

0 comments on commit 9c34e65

Please sign in to comment.