From 68505a834c3dc3c46b7a940f4bc4aaec2afd133f Mon Sep 17 00:00:00 2001 From: knikolaou <> Date: Fri, 3 May 2024 14:33:16 +0200 Subject: [PATCH] run isort and black --- .../test_training_recording.py | 41 +++++++++++-------- znnl/training_recording/jax_recording.py | 23 ++++++----- 2 files changed, 37 insertions(+), 27 deletions(-) diff --git a/CI/unit_tests/training_recording/test_training_recording.py b/CI/unit_tests/training_recording/test_training_recording.py index 6ccea27..fef4d6f 100644 --- a/CI/unit_tests/training_recording/test_training_recording.py +++ b/CI/unit_tests/training_recording/test_training_recording.py @@ -25,20 +25,20 @@ ------- """ +import copy import tempfile from pathlib import Path import h5py as hf import jax.numpy as np import numpy as onp +import optax +from neural_tangents import stax from numpy import testing -from znnl.training_recording import JaxRecorder -from znnl.loss_functions import MeanPowerLoss -from neural_tangents import stax from znnl import models -import optax -import copy +from znnl.loss_functions import MeanPowerLoss +from znnl.training_recording import JaxRecorder class TestModelRecording: @@ -75,7 +75,6 @@ def setup_class(cls): stax.Flatten(), stax.Dense(8), stax.Relu(), stax.Dense(5) ) - def test_instantiation(self): """ Test the instantiation of the recorder. @@ -252,7 +251,7 @@ def test_class_specific_update_fn(self): # Test trace update recorder.class_specific_update_fn( - call_fn=recorder._update_trace, + call_fn=recorder._update_trace, indices=recorder._class_idx[1][0], parsed_data=self.class_specific_parsed_data, ) @@ -262,7 +261,7 @@ def test_class_specific_update_fn(self): # Test eigenvalues update recorder.class_specific_update_fn( - call_fn=recorder._update_eigenvalues, + call_fn=recorder._update_eigenvalues, indices=recorder._class_idx[1][0], parsed_data=self.class_specific_parsed_data, ) @@ -304,7 +303,7 @@ def test_multi_class_update(self): assert np.shape(recorder._magnitude_variance_array) == (1, 5) assert np.shape(recorder._eigenvalues_array) == (1, 5, 2) # Even though NTK is selected, it should not be updated. - assert np.shape(recorder._ntk_array) == () + assert np.shape(recorder._ntk_array) == () # Update the recorder again recorder.update_recorder(epoch=2, model=model) @@ -316,7 +315,7 @@ def test_multi_class_update(self): assert np.shape(recorder._magnitude_variance_array) == (2, 5) assert np.shape(recorder._eigenvalues_array) == (2, 5, 2) # Even though NTK is selected, it should not be updated. - assert np.shape(recorder._ntk_array) == () + assert np.shape(recorder._ntk_array) == () def test_class_combinations(self): """ @@ -335,14 +334,20 @@ def test_class_combinations(self): recorder._data_set["targets"] = np.concatenate([np.eye(3), np.eye(3)], axis=0) _, class_combinations = recorder._get_class_combinations() assert np.all(np.array(class_combinations[:3]) == np.arange(3).reshape(3, 1)) - assert np.all(np.array(class_combinations[3:6]) == np.array([[0, 1], [0, 2], [1, 2]])) + assert np.all( + np.array(class_combinations[3:6]) == np.array([[0, 1], [0, 2], [1, 2]]) + ) assert np.all(np.array(class_combinations[6:]) == np.array([[0, 1, 2]])) # Test for non-one-hot encoding - recorder._data_set["targets"] = np.concatenate([np.arange(3), np.arange(3)], axis=0).reshape(6, 1) + recorder._data_set["targets"] = np.concatenate( + [np.arange(3), np.arange(3)], axis=0 + ).reshape(6, 1) _, class_combinations = recorder._get_class_combinations() assert np.all(np.array(class_combinations[:3]) == np.arange(3).reshape(3, 1)) - assert np.all(np.array(class_combinations[3:6]) == np.array([[0, 1], [0, 2], [1, 2]])) + assert np.all( + np.array(class_combinations[3:6]) == np.array([[0, 1], [0, 2], [1, 2]]) + ) assert np.all(np.array(class_combinations[6:]) == np.array([[0, 1, 2]])) # Test for non-consecutive classes @@ -350,12 +355,14 @@ def test_class_combinations(self): recorder._data_set["targets"] = np.concatenate([idx, idx], axis=0).reshape(6, 1) _, class_combinations = recorder._get_class_combinations() assert np.all(np.array(class_combinations[:3]) == np.arange(3).reshape(3, 1)) - assert np.all(np.array(class_combinations[3:6]) == np.array([[0, 1], [0, 2], [1, 2]])) + assert np.all( + np.array(class_combinations[3:6]) == np.array([[0, 1], [0, 2], [1, 2]]) + ) assert np.all(np.array(class_combinations[6:]) == np.array([[0, 1, 2]])) def test_entropy_class_correlation(self): """ - Test the entropy class correlation method. + Test the entropy class correlation method. """ recorder = JaxRecorder( entropy_class_correlations=True, @@ -371,11 +378,11 @@ def test_entropy_class_correlation(self): optimizer=optax.adam(learning_rate=0.01), input_shape=(1, 2, 3), ) - + # Test the correlation recorder.update_recorder(epoch=1, model=model) assert np.shape(recorder._entropy_class_correlations_array) == (1, 31) # Update the recorder again recorder.update_recorder(epoch=2, model=model) - assert np.shape(recorder._entropy_class_correlations_array) == (2, 31) \ No newline at end of file + assert np.shape(recorder._entropy_class_correlations_array) == (2, 31) diff --git a/znnl/training_recording/jax_recording.py b/znnl/training_recording/jax_recording.py index 978e2be..455d149 100644 --- a/znnl/training_recording/jax_recording.py +++ b/znnl/training_recording/jax_recording.py @@ -25,6 +25,7 @@ ------- """ +import itertools import logging from dataclasses import dataclass, make_dataclass from os import path @@ -48,7 +49,6 @@ flatten_rank_4_tensor, normalize_gram_matrix, ) -import itertools logger = logging.getLogger(__name__) @@ -99,7 +99,7 @@ class JaxRecorder: Warning, large overhead. entropy_class_correlations : bool (default=False) If true, the entropy of every class and class combination will be recorded. - This results in 2^n - 1 entropy computations, where n is the number of + This results in 2^n - 1 entropy computations, where n is the number of classes. Warning, large overhead. magnitude_variance : bool (default=False) @@ -279,13 +279,13 @@ def _get_class_indices(self): ) return class_specific_idx - + def _get_class_combinations(self): """ Get data indices of all class combinations. Create sublists of indices for each class itself and all combinations of - classes. + classes. For n classes, there are 2^n - 1 combinations of classes, for which the a list of indices is created. @@ -302,7 +302,7 @@ def _get_class_combinations(self): Each index refers to the position of the class in the class_specific_idx[0] array. """ - # Get the indices for each class + # Get the indices for each class class_specific_idx = self._get_class_indices() class_labels, _ = class_specific_idx @@ -401,12 +401,14 @@ def instantiate_recorder(self, data_set: dict = None, overwrite: bool = False): if "loss_derivative" in self._selected_properties: self._loss_derivative_fn = LossDerivative(self._loss_fn) - def class_specific_update_fn(self, call_fn: callable, indices: np.ndarray, parsed_data: dict): + def class_specific_update_fn( + self, call_fn: callable, indices: np.ndarray, parsed_data: dict + ): """ Update the class specific arrays. TODO: This only works for an NTK of rank 2. We need to generalize this for - NTKs of rank 4. + NTKs of rank 4. Parameters ---------- @@ -533,7 +535,9 @@ def update_recorder(self, epoch: int, model: JaxModel): elif item == "entropy_class_correlations": for class_combination in self._class_combinations: # Get the indices for the class combination - indices = np.concatenate([self._class_idx[1][i] for i in class_combination]) + indices = np.concatenate( + [self._class_idx[1][i] for i in class_combination] + ) self.class_specific_update_fn( call_fn=call_fn, indices=indices, @@ -743,7 +747,7 @@ def _update_entropy_class_correlations(self, parsed_data: dict): """ Update the entropy entropy of all class correlations using the covariance NTK. - The entropy of the class correlations calculates the entropy all classes and + The entropy of the class correlations calculates the entropy all classes and class combinations. This results in 2^n - 1 entropy computations, where n is the number of classes. @@ -917,4 +921,3 @@ def _export_in_memory_data(self) -> dataclass: } return DataSet(**selected_data) -