diff --git a/CI/unit_tests/training_recording/test_training_recording.py b/CI/unit_tests/training_recording/test_training_recording.py index 7e8b9d8..80cb1dc 100644 --- a/CI/unit_tests/training_recording/test_training_recording.py +++ b/CI/unit_tests/training_recording/test_training_recording.py @@ -39,6 +39,7 @@ from znnl import models from znnl.loss_functions import MeanPowerLoss from znnl.training_recording import JaxRecorder +from znnl.utils import flatten_rank_4_tensor, unflatten_rank_2_tensor class TestModelRecording: @@ -244,6 +245,7 @@ def test_class_specific_update_fn(self): trace=True, eigenvalues=True, ) + recorder._ntk_rank = 2 # Instantiate the recorder recorder.instantiate_recorder(data_set=self.class_specific_data) @@ -401,18 +403,52 @@ def test_read_class_specific_data(self): # Instantiate the recorder recorder.instantiate_recorder(data_set=self.class_specific_data) - # Test the case of having one entry per sample, e.g. for recording the + # Test the case of having one entry per sample, e.g. for recording the # eigenvalues. test_record = np.arange(20).reshape(2, 10) class_specific_dict = recorder.read_class_specific_data(test_record) for i, (key, val) in enumerate(class_specific_dict.items()): assert key == i print(val) - assert np.all(val == np.array([[i, i + 5], [i+10, i + 15]])) + assert np.all(val == np.array([[i, i + 5], [i + 10, i + 15]])) # Test the case of having one entry per class, e.g. for recording the trace. test_record = np.arange(10).reshape(2, 5) class_specific_dict = recorder.read_class_specific_data(test_record) for i, (key, val) in enumerate(class_specific_dict.items()): assert key == i - assert np.all(val == np.array([i, i + 5])) \ No newline at end of file + assert np.all(val == np.array([i, i + 5])) + + def test_select_class_specific_data(self): + """ + Test the selection of class specific data. + """ + recorder = JaxRecorder( + class_specific=True, + ntk=True, + trace=True, + eigenvalues=True, + ) + + # Instantiate the recorder with traced NTK-like data. + recorder._ntk_rank = 2 + data = self.class_specific_parsed_data + data["ntk"] = np.arange(100).reshape(10, 10) + recorder.instantiate_recorder(data_set=data) + # Test it for label 0 + indices = np.array([0, 5]) + selected_data = recorder._select_class_specific_data(indices, data) + assert np.all(selected_data["ntk"] == np.array([[0, 5], [50, 55]])) + + # Instantiate the recorder with full NTK-like data. + recorder._ntk_rank = 4 + data = self.class_specific_parsed_data + data["ntk"] = np.arange(2500).reshape(50, 50) + recorder.instantiate_recorder(data_set=data) + # Test a dummy selection + indices = np.array([0, 5]) + selected_data = recorder._select_class_specific_data(indices, data) + # Calculate the true selection + ntk_ = unflatten_rank_2_tensor(data["ntk"], 10, 5) + true_selection = flatten_rank_4_tensor(ntk_[np.ix_(indices, indices)]) + assert np.all(selected_data["ntk"] == true_selection) diff --git a/CI/unit_tests/utils/test_matrix_utils.py b/CI/unit_tests/utils/test_matrix_utils.py index 44c5a5a..7fa4b3b 100644 --- a/CI/unit_tests/utils/test_matrix_utils.py +++ b/CI/unit_tests/utils/test_matrix_utils.py @@ -38,6 +38,7 @@ compute_magnitude_density, flatten_rank_4_tensor, normalize_gram_matrix, + unflatten_rank_2_tensor, ) @@ -168,3 +169,23 @@ def test_flatten_rank_4_tensor(self): [[0, 1, 4, 5], [2, 3, 6, 7], [8, 9, 12, 13], [10, 11, 14, 15]] ) assert_array_equal(flatten_rank_4_tensor(tensor), assertion_matrix) + + def test_unflatten_rank_2_tensor(self): + """ + Test the unflattening of a rank 2 tensor. + + It should invert the operation of flatten_rank_4_tensor. + """ + # Check for assertion errors + tensor = np.arange(24).reshape((6, 4)) + n = 2 + m = 3 + assert_raises(ValueError, unflatten_rank_2_tensor, tensor, n, m) + tensor = np.arange(24).reshape((4, 6)) + assert_raises(ValueError, unflatten_rank_2_tensor, tensor, n, m) + + # Check the unflattening + tensor = np.arange(4 * 4).reshape(2, 2, 2, 2) + flattened_tensor = flatten_rank_4_tensor(tensor) + unflattened_tensor = unflatten_rank_2_tensor(flattened_tensor, 2, 2) + assert_array_equal(unflattened_tensor, tensor) diff --git a/znnl/training_recording/jax_recording.py b/znnl/training_recording/jax_recording.py index 5d360ad..34e7f31 100644 --- a/znnl/training_recording/jax_recording.py +++ b/znnl/training_recording/jax_recording.py @@ -48,6 +48,7 @@ compute_magnitude_density, flatten_rank_4_tensor, normalize_gram_matrix, + unflatten_rank_2_tensor, ) logger = logging.getLogger(__name__) @@ -248,6 +249,8 @@ def _get_class_indices(self): """ Get indices of the classes, when class specific properties are recorded. + TODO: Generalize this method for rank 4 NTKs. + Returns ------- class_specific_idx : tuple @@ -261,8 +264,8 @@ def _get_class_indices(self): numpy array of indices for each class. """ - class_idx = np.unique(self._data_set["targets"], axis=0, return_index=True)[1] - class_idx = np.sort(class_idx) + t, class_idx = np.unique(self._data_set["targets"], axis=0, return_index=True) + class_idx = class_idx[np.argsort(np.argmax(t, axis=1))] class_labels = np.take(self._data_set["targets"], class_idx, axis=0) label_dim = class_labels.shape[1] if label_dim > 1: @@ -318,34 +321,38 @@ def _get_class_combinations(self): class_combinations.extend(list(itertools.combinations(classes, i))) return class_specific_idx, class_combinations - + def read_class_specific_data(self, data_array: np.ndarray): """ Read class specific data. Due to storage of arrays, class specific data is stored in a list of arrays. - This method separates recorded data for each class and returns it as a + This method separates recorded data for each class and returns it as a dictionary. Parameters ---------- data_array : np.ndarray Data to read and separate. - + Returns ------- class_specific_data : dict Dictionary containing the separated data. """ # Sort the data into class specific arrays. - class_specific_data = {} + class_specific_data = {} - for class_label, indices in zip(self._class_idx[0].tolist(), self._class_idx[1]): + for class_label, indices in zip( + self._class_idx[0].tolist(), self._class_idx[1] + ): # Check whether each class has one or multiple entries. if np.shape(data_array)[1] == len(self._class_idx[0]): # This is the case when there is only one entry per class. - class_specific_data[class_label] = np.take(data_array, class_label, axis=1) + class_specific_data[class_label] = np.take( + data_array, class_label, axis=1 + ) elif np.shape(data_array)[1] > len(self._class_idx[0]): # This is the case when there is one entry per sample. class_specific_data[class_label] = np.take(data_array, indices, axis=1) @@ -466,10 +473,12 @@ def class_specific_update_fn( data = self._select_class_specific_data(indices, parsed_data) call_fn(data) - @staticmethod - def _select_class_specific_data(indices: np.ndarray, parsed_data: dict): + def _select_class_specific_data(self, indices: np.ndarray, parsed_data: dict): """ - Select class specific data. + Select class specific data from the parsed data. + + Using pre-selected indices, the data corresponding to a single class is + selected from the parsed data. Parameters ---------- @@ -482,16 +491,41 @@ def _select_class_specific_data(indices: np.ndarray, parsed_data: dict): ------- Updates the class specific arrays. """ + # Get the class specific predictions data = {"predictions": np.take(parsed_data["predictions"], indices, axis=0)} data["targets"] = np.take(parsed_data["targets"], indices, axis=0) + + # Get the class specific NTK entries try: - data["ntk"] = parsed_data["ntk"][np.ix_(indices, indices)] + ntk = parsed_data["ntk"] + raw_data_shape = np.shape(parsed_data["predictions"]) + + if self._ntk_rank == 2: + ntk = ntk[np.ix_(indices, indices)] + data["ntk"] = ntk + elif self._ntk_rank == 4 and self.flatten_ntk: + ntk = unflatten_rank_2_tensor(ntk, raw_data_shape[0], raw_data_shape[1]) + ntk = ntk[np.ix_(indices, indices)] + data["ntk"] = flatten_rank_4_tensor(ntk) + elif self._ntk_rank == 4 and not self.flatten_ntk: + ntk = unflatten_rank_2_tensor(ntk, raw_data_shape[0], raw_data_shape[1]) + ntk = ntk[np.ix_(indices, indices)] + data["ntk"] = ntk + logger.warning( + "The NTK is of rank 4 and not flattened. This may lead to " + "unexpected results or errors in calculations of observations." + ) + else: + raise ValueError( + "The NTK rank is not supported for class specific recording." + ) + except KeyError: pass return data - + def _reformat_class_specific_recording(self, data_array: np.ndarray): """ Reformat class specific data from the recorder. @@ -521,9 +555,9 @@ def _reformat_class_specific_recording(self, data_array: np.ndarray): new_data = np.concatenate(new_data_unformatted, axis=0) # Append the new data to the old data. data = old_data + [new_data] - + return data - + def update_recorder(self, epoch: int, model: JaxModel): """ Update the values stored in the recorder. @@ -588,16 +622,18 @@ def update_recorder(self, epoch: int, model: JaxModel): if self.class_specific and item != "entropy_class_correlations": # Loop over the classes and update the properties. - for class_label in self._class_idx[0]: + for i, class_label in enumerate(self._class_idx[0]): self.class_specific_update_fn( call_fn=call_fn, - indices=self._class_idx[1][class_label], + indices=self._class_idx[1][i], parsed_data=parsed_data, ) # Re-format the updated changes in the corresponding arrays. array = self.__dict__[f"_{item}_array"] - self.__dict__[f"_{item}_array"] = self._reformat_class_specific_recording(array) + self.__dict__[f"_{item}_array"] = ( + self._reformat_class_specific_recording(array) + ) elif item == "entropy_class_correlations": for class_combination in self._class_combinations: @@ -829,7 +865,7 @@ class combinations. This results in 2^n - 1 entropy computations, where n is the cov_ntk = normalize_gram_matrix(parsed_data["ntk"]) calculator = EntropyAnalysis(matrix=cov_ntk) entropy = calculator.compute_von_neumann_entropy( - effective=False, normalize_eig=True + effective=True, normalize_eig=True ) self._entropy_class_correlations_array.append(entropy) diff --git a/znnl/utils/__init__.py b/znnl/utils/__init__.py index 2cbd161..2f9da02 100644 --- a/znnl/utils/__init__.py +++ b/znnl/utils/__init__.py @@ -29,6 +29,7 @@ compute_eigensystem, flatten_rank_4_tensor, normalize_gram_matrix, + unflatten_rank_2_tensor, ) from znnl.utils.prng import PRNGKey @@ -36,5 +37,6 @@ compute_eigensystem.__name__, normalize_gram_matrix.__name__, flatten_rank_4_tensor.__name__, + unflatten_rank_2_tensor.__name__, PRNGKey.__name__, ] diff --git a/znnl/utils/matrix_utils.py b/znnl/utils/matrix_utils.py index f5b2ff2..ddb703b 100644 --- a/znnl/utils/matrix_utils.py +++ b/znnl/utils/matrix_utils.py @@ -171,6 +171,44 @@ def flatten_rank_4_tensor(tensor: np.ndarray) -> np.ndarray: ) +def unflatten_rank_2_tensor(tensor: np.ndarray, n: int, m: int) -> np.ndarray: + """ + Unflatten a rank 2 tensor to a rank 4 tensor using a specific reshaping. + + The tensor is assumed to be of shape (n * m, n * m). The reshaping is done by + concatenating first with the third and then with the fourth dimension, resulting + in a tensor of shape (n, n, m, m). + + Parameters + ---------- + tensor : np.ndarray (shape=(n * m, n * m)) + Tensor to unflatten. + n : int + First dimension of the unflattened tensor. + m : int + Second dimension of the unflattened tensor. + + Returns + ------- + unflattened_tensor : np.ndarray (shape=(n, n, m, m)) + Unflattened tensor. + """ + + if not n * m == tensor.shape[0]: + raise ValueError( + "The shape of the tensor does not match the given dimensions. " + f"Expected {n * m} but got {tensor.shape[0]}." + ) + if not n * m == tensor.shape[1]: + raise ValueError( + "The shape of the tensor does not match the given dimensions. " + f"Expected {n * m} but got {tensor.shape[1]}." + ) + + _tensor = tensor.reshape(n, m, n, m) + return np.moveaxis(_tensor, [2, 1], [1, 2]) + + def calculate_trace(matrix: np.ndarray, normalize: bool = False) -> np.ndarray: """ Calculate the trace of a matrix, including optional normalization.