Skip to content

Commit

Permalink
run black and isort
Browse files Browse the repository at this point in the history
  • Loading branch information
knikolaou committed Oct 13, 2023
1 parent ebdf3f1 commit 0c7173f
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 18 deletions.
19 changes: 9 additions & 10 deletions CI/unit_tests/regulaizers/test_trace_regularizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@
import optax
from flax import linen as nn
from flax.training.train_state import TrainState
from neural_tangents import stax
from jax import random
from neural_tangents import stax

from znnl.models.flax_model import FlaxModel
from znnl.models.nt_model import NTModel
Expand Down Expand Up @@ -93,7 +93,6 @@ def create_flax_model(cls, key: int) -> FlaxModel:
)
return flax_model


def test_constructor(self):
"""
Test the constructor of the norm regularizer class.
Expand Down Expand Up @@ -128,10 +127,10 @@ def test_calculate_regularization(self):
model=nt_model, params=nt_model.model_state.params, batch=batch, epoch=1
)
# Calculate norm from NTK
ntk = nt_model.compute_ntk(
batch["inputs"], infinite=False
)["empirical"]
num_parameters = jax.flatten_util.ravel_pytree(nt_model.model_state.params)[0].shape[0]
ntk = nt_model.compute_ntk(batch["inputs"], infinite=False)["empirical"]
num_parameters = jax.flatten_util.ravel_pytree(nt_model.model_state.params)[
0
].shape[0]
normed_ntk = ntk / num_parameters
diag_ntk = np.diagonal(normed_ntk)
mean_trace = np.mean(diag_ntk)
Expand All @@ -142,10 +141,10 @@ def test_calculate_regularization(self):
model=flax_model, params=flax_model.model_state.params, batch=batch, epoch=1
)
# Calculate norm from NTK
ntk = flax_model.compute_ntk(
batch["inputs"], infinite=False
)['empirical']
num_parameters = jax.flatten_util.ravel_pytree(flax_model.model_state.params)[0].shape[0]
ntk = flax_model.compute_ntk(batch["inputs"], infinite=False)["empirical"]
num_parameters = jax.flatten_util.ravel_pytree(flax_model.model_state.params)[
0
].shape[0]
normed_ntk = ntk / num_parameters
diag_ntk = np.diagonal(normed_ntk)
mean_trace = np.mean(diag_ntk)
Expand Down
2 changes: 1 addition & 1 deletion znnl/training_strategies/loss_aware_reservoir.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ def train_model(
train_accuracy = []
for i in loading_bar:
self.epoch = i

# Update the recorder properties
if self.recorders is not None:
for item in self.recorders:
Expand Down
2 changes: 1 addition & 1 deletion znnl/training_strategies/partitioned_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ def train_model(

for i in loading_bar:
self.epoch = i

# Update the recorder properties
if self.recorders is not None:
for item in self.recorders:
Expand Down
9 changes: 3 additions & 6 deletions znnl/training_strategies/simple_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,11 @@
from znnl.accuracy_functions.accuracy_function import AccuracyFunction
from znnl.models.jax_model import JaxModel
from znnl.optimizers.trace_optimizer import TraceOptimizer
from znnl.regularizers import Regularizer
from znnl.training_recording import JaxRecorder
from znnl.training_strategies.recursive_mode import RecursiveMode
from znnl.training_strategies.training_decorator import train_func
from znnl.utils.prng import PRNGKey
from znnl.regularizers import Regularizer

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -217,10 +217,7 @@ def loss_fn(params):
# Add gradient regularization
if self.regularizer:
reg_loss = self.regularizer(
model=self.model,
params=params,
batch=batch,
epoch=self.epoch
model=self.model, params=params, batch=batch, epoch=self.epoch
)
loss += reg_loss
return loss, inner_predictions
Expand Down Expand Up @@ -384,7 +381,7 @@ def train_model(
train_accuracy = []
for i in loading_bar:
self.epoch = i

# Update the recorder properties
if self.recorders is not None:
for item in self.recorders:
Expand Down

0 comments on commit 0c7173f

Please sign in to comment.