Skip to content

Commit

Permalink
run isort and black
Browse files Browse the repository at this point in the history
  • Loading branch information
knikolaou committed May 3, 2024
1 parent 5e0335d commit 68505a8
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 27 deletions.
41 changes: 24 additions & 17 deletions CI/unit_tests/training_recording/test_training_recording.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,20 @@
-------
"""

import copy
import tempfile
from pathlib import Path

import h5py as hf
import jax.numpy as np
import numpy as onp
import optax
from neural_tangents import stax
from numpy import testing

from znnl.training_recording import JaxRecorder
from znnl.loss_functions import MeanPowerLoss
from neural_tangents import stax
from znnl import models
import optax
import copy
from znnl.loss_functions import MeanPowerLoss
from znnl.training_recording import JaxRecorder


class TestModelRecording:
Expand Down Expand Up @@ -75,7 +75,6 @@ def setup_class(cls):
stax.Flatten(), stax.Dense(8), stax.Relu(), stax.Dense(5)
)


def test_instantiation(self):
"""
Test the instantiation of the recorder.
Expand Down Expand Up @@ -252,7 +251,7 @@ def test_class_specific_update_fn(self):

# Test trace update
recorder.class_specific_update_fn(
call_fn=recorder._update_trace,
call_fn=recorder._update_trace,
indices=recorder._class_idx[1][0],
parsed_data=self.class_specific_parsed_data,
)
Expand All @@ -262,7 +261,7 @@ def test_class_specific_update_fn(self):

# Test eigenvalues update
recorder.class_specific_update_fn(
call_fn=recorder._update_eigenvalues,
call_fn=recorder._update_eigenvalues,
indices=recorder._class_idx[1][0],
parsed_data=self.class_specific_parsed_data,
)
Expand Down Expand Up @@ -304,7 +303,7 @@ def test_multi_class_update(self):
assert np.shape(recorder._magnitude_variance_array) == (1, 5)
assert np.shape(recorder._eigenvalues_array) == (1, 5, 2)
# Even though NTK is selected, it should not be updated.
assert np.shape(recorder._ntk_array) == ()
assert np.shape(recorder._ntk_array) == ()

# Update the recorder again
recorder.update_recorder(epoch=2, model=model)
Expand All @@ -316,7 +315,7 @@ def test_multi_class_update(self):
assert np.shape(recorder._magnitude_variance_array) == (2, 5)
assert np.shape(recorder._eigenvalues_array) == (2, 5, 2)
# Even though NTK is selected, it should not be updated.
assert np.shape(recorder._ntk_array) == ()
assert np.shape(recorder._ntk_array) == ()

def test_class_combinations(self):
"""
Expand All @@ -335,27 +334,35 @@ def test_class_combinations(self):
recorder._data_set["targets"] = np.concatenate([np.eye(3), np.eye(3)], axis=0)
_, class_combinations = recorder._get_class_combinations()
assert np.all(np.array(class_combinations[:3]) == np.arange(3).reshape(3, 1))
assert np.all(np.array(class_combinations[3:6]) == np.array([[0, 1], [0, 2], [1, 2]]))
assert np.all(
np.array(class_combinations[3:6]) == np.array([[0, 1], [0, 2], [1, 2]])
)
assert np.all(np.array(class_combinations[6:]) == np.array([[0, 1, 2]]))

# Test for non-one-hot encoding
recorder._data_set["targets"] = np.concatenate([np.arange(3), np.arange(3)], axis=0).reshape(6, 1)
recorder._data_set["targets"] = np.concatenate(
[np.arange(3), np.arange(3)], axis=0
).reshape(6, 1)
_, class_combinations = recorder._get_class_combinations()
assert np.all(np.array(class_combinations[:3]) == np.arange(3).reshape(3, 1))
assert np.all(np.array(class_combinations[3:6]) == np.array([[0, 1], [0, 2], [1, 2]]))
assert np.all(
np.array(class_combinations[3:6]) == np.array([[0, 1], [0, 2], [1, 2]])
)
assert np.all(np.array(class_combinations[6:]) == np.array([[0, 1, 2]]))

# Test for non-consecutive classes
idx = np.array([0, 2, 3])
recorder._data_set["targets"] = np.concatenate([idx, idx], axis=0).reshape(6, 1)
_, class_combinations = recorder._get_class_combinations()
assert np.all(np.array(class_combinations[:3]) == np.arange(3).reshape(3, 1))
assert np.all(np.array(class_combinations[3:6]) == np.array([[0, 1], [0, 2], [1, 2]]))
assert np.all(
np.array(class_combinations[3:6]) == np.array([[0, 1], [0, 2], [1, 2]])
)
assert np.all(np.array(class_combinations[6:]) == np.array([[0, 1, 2]]))

def test_entropy_class_correlation(self):
"""
Test the entropy class correlation method.
Test the entropy class correlation method.
"""
recorder = JaxRecorder(
entropy_class_correlations=True,
Expand All @@ -371,11 +378,11 @@ def test_entropy_class_correlation(self):
optimizer=optax.adam(learning_rate=0.01),
input_shape=(1, 2, 3),
)

# Test the correlation
recorder.update_recorder(epoch=1, model=model)
assert np.shape(recorder._entropy_class_correlations_array) == (1, 31)

# Update the recorder again
recorder.update_recorder(epoch=2, model=model)
assert np.shape(recorder._entropy_class_correlations_array) == (2, 31)
assert np.shape(recorder._entropy_class_correlations_array) == (2, 31)
23 changes: 13 additions & 10 deletions znnl/training_recording/jax_recording.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
-------
"""

import itertools
import logging
from dataclasses import dataclass, make_dataclass
from os import path
Expand All @@ -48,7 +49,6 @@
flatten_rank_4_tensor,
normalize_gram_matrix,
)
import itertools

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -99,7 +99,7 @@ class JaxRecorder:
Warning, large overhead.
entropy_class_correlations : bool (default=False)
If true, the entropy of every class and class combination will be recorded.
This results in 2^n - 1 entropy computations, where n is the number of
This results in 2^n - 1 entropy computations, where n is the number of
classes.
Warning, large overhead.
magnitude_variance : bool (default=False)
Expand Down Expand Up @@ -279,13 +279,13 @@ def _get_class_indices(self):
)

return class_specific_idx

def _get_class_combinations(self):
"""
Get data indices of all class combinations.
Create sublists of indices for each class itself and all combinations of
classes.
classes.
For n classes, there are 2^n - 1 combinations of classes, for which the
a list of indices is created.
Expand All @@ -302,7 +302,7 @@ def _get_class_combinations(self):
Each index refers to the position of the class in the
class_specific_idx[0] array.
"""
# Get the indices for each class
# Get the indices for each class
class_specific_idx = self._get_class_indices()
class_labels, _ = class_specific_idx

Expand Down Expand Up @@ -401,12 +401,14 @@ def instantiate_recorder(self, data_set: dict = None, overwrite: bool = False):
if "loss_derivative" in self._selected_properties:
self._loss_derivative_fn = LossDerivative(self._loss_fn)

def class_specific_update_fn(self, call_fn: callable, indices: np.ndarray, parsed_data: dict):
def class_specific_update_fn(
self, call_fn: callable, indices: np.ndarray, parsed_data: dict
):
"""
Update the class specific arrays.
TODO: This only works for an NTK of rank 2. We need to generalize this for
NTKs of rank 4.
NTKs of rank 4.
Parameters
----------
Expand Down Expand Up @@ -533,7 +535,9 @@ def update_recorder(self, epoch: int, model: JaxModel):
elif item == "entropy_class_correlations":
for class_combination in self._class_combinations:
# Get the indices for the class combination
indices = np.concatenate([self._class_idx[1][i] for i in class_combination])
indices = np.concatenate(
[self._class_idx[1][i] for i in class_combination]
)
self.class_specific_update_fn(
call_fn=call_fn,
indices=indices,
Expand Down Expand Up @@ -743,7 +747,7 @@ def _update_entropy_class_correlations(self, parsed_data: dict):
"""
Update the entropy entropy of all class correlations using the covariance NTK.
The entropy of the class correlations calculates the entropy all classes and
The entropy of the class correlations calculates the entropy all classes and
class combinations. This results in 2^n - 1 entropy computations, where n is the
number of classes.
Expand Down Expand Up @@ -917,4 +921,3 @@ def _export_in_memory_data(self) -> dataclass:
}

return DataSet(**selected_data)

0 comments on commit 68505a8

Please sign in to comment.