Fix class wise recording for rank 4 ntk
knikolaou committed May 10, 2024
1 parent 290dd28 commit fe1382f
Showing 5 changed files with 155 additions and 22 deletions.
42 changes: 39 additions & 3 deletions CI/unit_tests/training_recording/test_training_recording.py
@@ -39,6 +39,7 @@
from znnl import models
from znnl.loss_functions import MeanPowerLoss
from znnl.training_recording import JaxRecorder
from znnl.utils import flatten_rank_4_tensor, unflatten_rank_2_tensor


class TestModelRecording:
@@ -244,6 +245,7 @@ def test_class_specific_update_fn(self):
trace=True,
eigenvalues=True,
)
recorder._ntk_rank = 2

# Instantiate the recorder
recorder.instantiate_recorder(data_set=self.class_specific_data)
@@ -401,18 +403,52 @@ def test_read_class_specific_data(self):
# Instantiate the recorder
recorder.instantiate_recorder(data_set=self.class_specific_data)

# Test the case of having one entry per sample, e.g. for recording the
# eigenvalues.
test_record = np.arange(20).reshape(2, 10)
class_specific_dict = recorder.read_class_specific_data(test_record)
for i, (key, val) in enumerate(class_specific_dict.items()):
assert key == i
print(val)
assert np.all(val == np.array([[i, i + 5], [i + 10, i + 15]]))

# Test the case of having one entry per class, e.g. for recording the trace.
test_record = np.arange(10).reshape(2, 5)
class_specific_dict = recorder.read_class_specific_data(test_record)
for i, (key, val) in enumerate(class_specific_dict.items()):
assert key == i
assert np.all(val == np.array([i, i + 5]))

def test_select_class_specific_data(self):
"""
Test the selection of class specific data.
"""
recorder = JaxRecorder(
class_specific=True,
ntk=True,
trace=True,
eigenvalues=True,
)

# Instantiate the recorder with traced NTK-like data.
recorder._ntk_rank = 2
data = self.class_specific_parsed_data
data["ntk"] = np.arange(100).reshape(10, 10)
recorder.instantiate_recorder(data_set=data)
# Test it for label 0
indices = np.array([0, 5])
selected_data = recorder._select_class_specific_data(indices, data)
assert np.all(selected_data["ntk"] == np.array([[0, 5], [50, 55]]))

# Instantiate the recorder with full NTK-like data.
recorder._ntk_rank = 4
data = self.class_specific_parsed_data
data["ntk"] = np.arange(2500).reshape(50, 50)
recorder.instantiate_recorder(data_set=data)
# Test a dummy selection
indices = np.array([0, 5])
selected_data = recorder._select_class_specific_data(indices, data)
# Calculate the true selection
ntk_ = unflatten_rank_2_tensor(data["ntk"], 10, 5)
true_selection = flatten_rank_4_tensor(ntk_[np.ix_(indices, indices)])
assert np.all(selected_data["ntk"] == true_selection)
21 changes: 21 additions & 0 deletions CI/unit_tests/utils/test_matrix_utils.py
@@ -38,6 +38,7 @@
compute_magnitude_density,
flatten_rank_4_tensor,
normalize_gram_matrix,
unflatten_rank_2_tensor,
)


@@ -168,3 +169,23 @@ def test_flatten_rank_4_tensor(self):
[[0, 1, 4, 5], [2, 3, 6, 7], [8, 9, 12, 13], [10, 11, 14, 15]]
)
assert_array_equal(flatten_rank_4_tensor(tensor), assertion_matrix)

def test_unflatten_rank_2_tensor(self):
"""
Test the unflattening of a rank 2 tensor.
It should invert the operation of flatten_rank_4_tensor.
"""
# Check that mismatched dimensions raise a ValueError
tensor = np.arange(24).reshape((6, 4))
n = 2
m = 3
assert_raises(ValueError, unflatten_rank_2_tensor, tensor, n, m)
tensor = np.arange(24).reshape((4, 6))
assert_raises(ValueError, unflatten_rank_2_tensor, tensor, n, m)

# Check the unflattening
tensor = np.arange(4 * 4).reshape(2, 2, 2, 2)
flattened_tensor = flatten_rank_4_tensor(tensor)
unflattened_tensor = unflatten_rank_2_tensor(flattened_tensor, 2, 2)
assert_array_equal(unflattened_tensor, tensor)
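As a standalone check of the layout convention (inferred from the assertion matrix in test_flatten_rank_4_tensor above, not from any documented guarantee), the rank 4 entry T[i1, i2, j1, j2] should land at M[i1 * m + j1, i2 * m + j2] in the flattened matrix:

import numpy as np
from znnl.utils import flatten_rank_4_tensor, unflatten_rank_2_tensor

# Hypothetical (n, n, m, m) tensor with n = 3 samples and m = 2 outputs.
n, m = 3, 2
tensor = np.arange(n * n * m * m).reshape(n, n, m, m)
flat = flatten_rank_4_tensor(tensor)

# Spot-check the inferred index mapping and the round trip.
i1, i2, j1, j2 = 2, 1, 0, 1
assert flat[i1 * m + j1, i2 * m + j2] == tensor[i1, i2, j1, j2]
assert np.array_equal(unflatten_rank_2_tensor(flat, n, m), tensor)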
74 changes: 55 additions & 19 deletions znnl/training_recording/jax_recording.py
@@ -48,6 +48,7 @@
compute_magnitude_density,
flatten_rank_4_tensor,
normalize_gram_matrix,
unflatten_rank_2_tensor,
)

logger = logging.getLogger(__name__)
@@ -248,6 +249,8 @@ def _get_class_indices(self):
"""
Get indices of the classes when class specific properties are recorded.
TODO: Generalize this method for rank 4 NTKs.
Returns
-------
class_specific_idx : tuple
@@ -261,8 +264,8 @@ def _get_class_indices(self):
numpy array of indices for each class.
"""
- class_idx = np.unique(self._data_set["targets"], axis=0, return_index=True)[1]
- class_idx = np.sort(class_idx)
t, class_idx = np.unique(self._data_set["targets"], axis=0, return_index=True)
class_idx = class_idx[np.argsort(np.argmax(t, axis=1))]
class_labels = np.take(self._data_set["targets"], class_idx, axis=0)
label_dim = class_labels.shape[1]
if label_dim > 1:
@@ -318,34 +321,38 @@ def _get_class_combinations(self):
class_combinations.extend(list(itertools.combinations(classes, i)))

return class_specific_idx, class_combinations
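The re-ordering introduced in _get_class_indices above is easier to see on a concrete input. A minimal NumPy sketch with hypothetical one-hot targets (not the recorder's internal data set):

import numpy as np

# Six samples with one-hot targets for the classes 2, 0, 1, 0, 2, 1.
targets = np.array(
    [[0, 0, 1], [1, 0, 0], [0, 1, 0], [1, 0, 0], [0, 0, 1], [0, 1, 0]]
)

# np.unique sorts the unique rows lexicographically and returns the index of
# the first occurrence of each; re-ordering by argmax restores label order.
t, class_idx = np.unique(targets, axis=0, return_index=True)
class_idx = class_idx[np.argsort(np.argmax(t, axis=1))]
print(class_idx)  # [1 2 0]: first sample of class 0, class 1 and class 2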

def read_class_specific_data(self, data_array: np.ndarray):
"""
Read class specific data.
Class specific data is recorded as a list of arrays.
This method separates recorded data for each class and returns it as a
dictionary.
Parameters
----------
data_array : np.ndarray
Data to read and separate.
Returns
-------
class_specific_data : dict
Dictionary containing the separated data.
"""
# Sort the data into class specific arrays.
class_specific_data = {}

for class_label, indices in zip(
self._class_idx[0].tolist(), self._class_idx[1]
):

# Check whether each class has one or multiple entries.
if np.shape(data_array)[1] == len(self._class_idx[0]):
# This is the case when there is only one entry per class.
class_specific_data[class_label] = np.take(
data_array, class_label, axis=1
)
elif np.shape(data_array)[1] > len(self._class_idx[0]):
# This is the case when there is one entry per sample.
class_specific_data[class_label] = np.take(data_array, indices, axis=1)
@@ -466,10 +473,12 @@ def class_specific_update_fn(
data = self._select_class_specific_data(indices, parsed_data)
call_fn(data)

- @staticmethod
- def _select_class_specific_data(indices: np.ndarray, parsed_data: dict):
def _select_class_specific_data(self, indices: np.ndarray, parsed_data: dict):
"""
Select class specific data from the parsed data.
Using pre-selected indices, the data corresponding to a single class is
selected from the parsed data.
Parameters
----------
@@ -482,16 +491,41 @@ def _select_class_specific_data(indices: np.ndarray, parsed_data: dict):
-------
Updates the class specific arrays.
"""

# Get the class specific predictions
data = {"predictions": np.take(parsed_data["predictions"], indices, axis=0)}
data["targets"] = np.take(parsed_data["targets"], indices, axis=0)

# Get the class specific NTK entries
try:
data["ntk"] = parsed_data["ntk"][np.ix_(indices, indices)]
ntk = parsed_data["ntk"]
raw_data_shape = np.shape(parsed_data["predictions"])

if self._ntk_rank == 2:
ntk = ntk[np.ix_(indices, indices)]
data["ntk"] = ntk
elif self._ntk_rank == 4 and self.flatten_ntk:
ntk = unflatten_rank_2_tensor(ntk, raw_data_shape[0], raw_data_shape[1])
ntk = ntk[np.ix_(indices, indices)]
data["ntk"] = flatten_rank_4_tensor(ntk)
elif self._ntk_rank == 4 and not self.flatten_ntk:
ntk = unflatten_rank_2_tensor(ntk, raw_data_shape[0], raw_data_shape[1])
ntk = ntk[np.ix_(indices, indices)]
data["ntk"] = ntk
logger.warning(
"The NTK is of rank 4 and not flattened. This may lead to "
"unexpected results or errors in calculations of observations."
)
else:
raise ValueError(
"The NTK rank is not supported for class specific recording."
)

except KeyError:
pass

return data
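For a concrete picture of the rank 4 branch, a minimal sketch using the helpers from znnl.utils with dummy shapes (10 samples, 5 output dimensions; not actual recorder state):

import numpy as np
from znnl.utils import flatten_rank_4_tensor, unflatten_rank_2_tensor

n_samples, out_dim = 10, 5
flat_ntk = np.arange((n_samples * out_dim) ** 2).reshape(
    n_samples * out_dim, n_samples * out_dim
)

# Recover the rank 4 form, select the samples of one class on both sample
# axes, then flatten again so downstream observables see the expected layout.
indices = np.array([0, 5])
ntk_4 = unflatten_rank_2_tensor(flat_ntk, n_samples, out_dim)  # (10, 10, 5, 5)
class_ntk = flatten_rank_4_tensor(ntk_4[np.ix_(indices, indices)])
print(class_ntk.shape)  # (10, 10) = (len(indices) * out_dim, len(indices) * out_dim)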

def _reformat_class_specific_recording(self, data_array: np.ndarray):
"""
Reformat class specific data from the recorder.
@@ -521,9 +555,9 @@ def _reformat_class_specific_recording(self, data_array: np.ndarray):
new_data = np.concatenate(new_data_unformatted, axis=0)
# Append the new data to the old data.
data = old_data + [new_data]

return data

def update_recorder(self, epoch: int, model: JaxModel):
"""
Update the values stored in the recorder.
@@ -588,16 +622,18 @@ def update_recorder(self, epoch: int, model: JaxModel):
if self.class_specific and item != "entropy_class_correlations":

# Loop over the classes and update the properties.
- for class_label in self._class_idx[0]:
for i, class_label in enumerate(self._class_idx[0]):
self.class_specific_update_fn(
call_fn=call_fn,
- indices=self._class_idx[1][class_label],
indices=self._class_idx[1][i],
parsed_data=parsed_data,
)

# Re-format the updated changes in the corresponding arrays.
array = self.__dict__[f"_{item}_array"]
self.__dict__[f"_{item}_array"] = self._reformat_class_specific_recording(array)
self.__dict__[f"_{item}_array"] = (
self._reformat_class_specific_recording(array)
)

elif item == "entropy_class_correlations":
for class_combination in self._class_combinations:
@@ -829,7 +865,7 @@ class combinations. This results in 2^n - 1 entropy computations, where n is the
cov_ntk = normalize_gram_matrix(parsed_data["ntk"])
calculator = EntropyAnalysis(matrix=cov_ntk)
entropy = calculator.compute_von_neumann_entropy(
- effective=False, normalize_eig=True
effective=True, normalize_eig=True
)
self._entropy_class_correlations_array.append(entropy)

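For orientation, the entropy recorded here is the von Neumann entropy of the normalized NTK gram matrix. A hedged sketch of the quantity (plain NumPy, not the EntropyAnalysis API; the exact meaning of effective=True is an assumption here):

import numpy as np

def von_neumann_entropy_sketch(cov_ntk: np.ndarray, effective: bool = True) -> float:
    """Von Neumann entropy -sum(l * log(l)) over the normalized eigenvalues."""
    eigenvalues = np.linalg.eigvalsh(cov_ntk)
    eigenvalues = eigenvalues[eigenvalues > 1e-12]
    eigenvalues = eigenvalues / eigenvalues.sum()  # normalize_eig=True
    entropy = -np.sum(eigenvalues * np.log(eigenvalues))
    if effective:
        # Assumed convention: rescale by the maximum entropy log(dim) so the
        # value lies in [0, 1].
        entropy = entropy / np.log(cov_ntk.shape[0])
    return float(entropy)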
2 changes: 2 additions & 0 deletions znnl/utils/__init__.py
@@ -29,12 +29,14 @@
compute_eigensystem,
flatten_rank_4_tensor,
normalize_gram_matrix,
unflatten_rank_2_tensor,
)
from znnl.utils.prng import PRNGKey

__all__ = [
compute_eigensystem.__name__,
normalize_gram_matrix.__name__,
flatten_rank_4_tensor.__name__,
unflatten_rank_2_tensor.__name__,
PRNGKey.__name__,
]
38 changes: 38 additions & 0 deletions znnl/utils/matrix_utils.py
@@ -171,6 +171,44 @@ def flatten_rank_4_tensor(tensor: np.ndarray) -> np.ndarray:
)


def unflatten_rank_2_tensor(tensor: np.ndarray, n: int, m: int) -> np.ndarray:
"""
Unflatten a rank 2 tensor to a rank 4 tensor using a specific reshaping.
The tensor is assumed to be of shape (n * m, n * m). It is reshaped to
(n, m, n, m) and the second and third axes are then swapped, resulting in a
tensor of shape (n, n, m, m).
Parameters
----------
tensor : np.ndarray (shape=(n * m, n * m))
Tensor to unflatten.
n : int
First dimension of the unflattened tensor.
m : int
Second dimension of the unflattened tensor.
Returns
-------
unflattened_tensor : np.ndarray (shape=(n, n, m, m))
Unflattened tensor.
"""

if not n * m == tensor.shape[0]:
raise ValueError(
"The shape of the tensor does not match the given dimensions. "
f"Expected {n * m} but got {tensor.shape[0]}."
)
if not n * m == tensor.shape[1]:
raise ValueError(
"The shape of the tensor does not match the given dimensions. "
f"Expected {n * m} but got {tensor.shape[1]}."
)

_tensor = tensor.reshape(n, m, n, m)
return np.moveaxis(_tensor, [2, 1], [1, 2])


def calculate_trace(matrix: np.ndarray, normalize: bool = False) -> np.ndarray:
"""
Calculate the trace of a matrix, including optional normalization.
