From 0c7173f7213f80b2c32c63bfe345ebaa9e0426ff Mon Sep 17 00:00:00 2001 From: knikolaou <> Date: Fri, 13 Oct 2023 15:56:43 +0200 Subject: [PATCH] run black and isort --- .../regulaizers/test_trace_regularizer.py | 19 +++++++++---------- .../loss_aware_reservoir.py | 2 +- .../partitioned_training.py | 2 +- znnl/training_strategies/simple_training.py | 9 +++------ 4 files changed, 14 insertions(+), 18 deletions(-) diff --git a/CI/unit_tests/regulaizers/test_trace_regularizer.py b/CI/unit_tests/regulaizers/test_trace_regularizer.py index 6dc3d79..b2ec9b9 100644 --- a/CI/unit_tests/regulaizers/test_trace_regularizer.py +++ b/CI/unit_tests/regulaizers/test_trace_regularizer.py @@ -36,8 +36,8 @@ import optax from flax import linen as nn from flax.training.train_state import TrainState -from neural_tangents import stax from jax import random +from neural_tangents import stax from znnl.models.flax_model import FlaxModel from znnl.models.nt_model import NTModel @@ -93,7 +93,6 @@ def create_flax_model(cls, key: int) -> FlaxModel: ) return flax_model - def test_constructor(self): """ Test the constructor of the norm regularizer class. @@ -128,10 +127,10 @@ def test_calculate_regularization(self): model=nt_model, params=nt_model.model_state.params, batch=batch, epoch=1 ) # Calculate norm from NTK - ntk = nt_model.compute_ntk( - batch["inputs"], infinite=False - )["empirical"] - num_parameters = jax.flatten_util.ravel_pytree(nt_model.model_state.params)[0].shape[0] + ntk = nt_model.compute_ntk(batch["inputs"], infinite=False)["empirical"] + num_parameters = jax.flatten_util.ravel_pytree(nt_model.model_state.params)[ + 0 + ].shape[0] normed_ntk = ntk / num_parameters diag_ntk = np.diagonal(normed_ntk) mean_trace = np.mean(diag_ntk) @@ -142,10 +141,10 @@ def test_calculate_regularization(self): model=flax_model, params=flax_model.model_state.params, batch=batch, epoch=1 ) # Calculate norm from NTK - ntk = flax_model.compute_ntk( - batch["inputs"], infinite=False - )['empirical'] - num_parameters = jax.flatten_util.ravel_pytree(flax_model.model_state.params)[0].shape[0] + ntk = flax_model.compute_ntk(batch["inputs"], infinite=False)["empirical"] + num_parameters = jax.flatten_util.ravel_pytree(flax_model.model_state.params)[ + 0 + ].shape[0] normed_ntk = ntk / num_parameters diag_ntk = np.diagonal(normed_ntk) mean_trace = np.mean(diag_ntk) diff --git a/znnl/training_strategies/loss_aware_reservoir.py b/znnl/training_strategies/loss_aware_reservoir.py index 4fc6cd4..34da1db 100644 --- a/znnl/training_strategies/loss_aware_reservoir.py +++ b/znnl/training_strategies/loss_aware_reservoir.py @@ -419,7 +419,7 @@ def train_model( train_accuracy = [] for i in loading_bar: self.epoch = i - + # Update the recorder properties if self.recorders is not None: for item in self.recorders: diff --git a/znnl/training_strategies/partitioned_training.py b/znnl/training_strategies/partitioned_training.py index f1b6b39..c52a98b 100644 --- a/znnl/training_strategies/partitioned_training.py +++ b/znnl/training_strategies/partitioned_training.py @@ -279,7 +279,7 @@ def train_model( for i in loading_bar: self.epoch = i - + # Update the recorder properties if self.recorders is not None: for item in self.recorders: diff --git a/znnl/training_strategies/simple_training.py b/znnl/training_strategies/simple_training.py index 9676f04..2a59bad 100644 --- a/znnl/training_strategies/simple_training.py +++ b/znnl/training_strategies/simple_training.py @@ -37,11 +37,11 @@ from znnl.accuracy_functions.accuracy_function import AccuracyFunction from znnl.models.jax_model import JaxModel from znnl.optimizers.trace_optimizer import TraceOptimizer +from znnl.regularizers import Regularizer from znnl.training_recording import JaxRecorder from znnl.training_strategies.recursive_mode import RecursiveMode from znnl.training_strategies.training_decorator import train_func from znnl.utils.prng import PRNGKey -from znnl.regularizers import Regularizer logger = logging.getLogger(__name__) @@ -217,10 +217,7 @@ def loss_fn(params): # Add gradient regularization if self.regularizer: reg_loss = self.regularizer( - model=self.model, - params=params, - batch=batch, - epoch=self.epoch + model=self.model, params=params, batch=batch, epoch=self.epoch ) loss += reg_loss return loss, inner_predictions @@ -384,7 +381,7 @@ def train_model( train_accuracy = [] for i in loading_bar: self.epoch = i - + # Update the recorder properties if self.recorders is not None: for item in self.recorders: