From 076d0dffccd2192af5c8cdc59d14d2d85bb04a70 Mon Sep 17 00:00:00 2001
From: Marc Sauter
Date: Fri, 12 Jan 2024 18:23:35 +0100
Subject: [PATCH 01/35] starting out

---
 znnl/analysis/ntk_calculation.py | 47 ++++++++++++++++++++++++++++++++
 1 file changed, 47 insertions(+)
 create mode 100644 znnl/analysis/ntk_calculation.py

diff --git a/znnl/analysis/ntk_calculation.py b/znnl/analysis/ntk_calculation.py
new file mode 100644
index 0000000..8bea744
--- /dev/null
+++ b/znnl/analysis/ntk_calculation.py
@@ -0,0 +1,47 @@
+"""
+ZnNL: A Zincwarecode package.
+
+License
+-------
+This program and the accompanying materials are made available under the terms
+of the Eclipse Public License v2.0 which accompanies this distribution, and is
+available at https://www.eclipse.org/legal/epl-v20.html
+
+SPDX-License-Identifier: EPL-2.0
+
+Copyright Contributors to the Zincwarecode Project.
+
+Contact Information
+-------------------
+email: zincwarecode@gmail.com
+github: https://github.com/zincware
+web: https://zincwarecode.com/
+
+Citation
+--------
+If you use this module please cite us with:
+
+Summary
+-------
+"""
+
+import neural_tangents as nt
+
+
+class ntk_calculation:
+    def function_for_loss_ntk(params): ...
+    def calculate_loss_ntk(
+        self,
+        model,
+        metric_fn,
+        dataset,
+        ntk_batch_size: int = 10,
+        store_on_device: bool = True,
+    ):
+        # Prepare NTK calculation
+        empirical_ntk = nt.batch(
+            nt.empirical_ntk_fn(f=model._ntk_apply_fn, trace_axes=trace_axes),
+            batch_size=ntk_batch_size,
+            store_on_device=store_on_device,
+        )
+        empirical_ntk_jit = jax.jit(empirical_ntk)

From 930cac8e6aaca9f667e9a0776c193525bb61fd82 Mon Sep 17 00:00:00 2001
From: m-sauter
Date: Thu, 18 Jan 2024 18:58:46 +0100
Subject: [PATCH 02/35] Renaming file

---
 ...calculation.py => loss_ntk_calculation.py} | 30 ++++++++++++++-----
 1 file changed, 23 insertions(+), 7 deletions(-)
 rename znnl/analysis/{ntk_calculation.py => loss_ntk_calculation.py} (55%)

diff --git a/znnl/analysis/ntk_calculation.py b/znnl/analysis/loss_ntk_calculation.py
similarity index 55%
rename from znnl/analysis/ntk_calculation.py
rename to znnl/analysis/loss_ntk_calculation.py
index 8bea744..10dea10 100644
--- a/znnl/analysis/ntk_calculation.py
+++ b/znnl/analysis/loss_ntk_calculation.py
@@ -26,21 +26,37 @@
 """
 
 import neural_tangents as nt
+from typing import Callable
 
 
-class ntk_calculation:
-    def function_for_loss_ntk(params): ...
-    def calculate_loss_ntk(
+class loss_ntk_calculation:
+    def __init__(
         self,
-        model,
-        metric_fn,
-        dataset,
+        metric_fn: Callable,
+        dataset: dict,
         ntk_batch_size: int = 10,
         store_on_device: bool = True,
     ):
+        """Constructor for the loss ntk calculation class."""
+        self.metric_fn = metric_fn
+        self.dataset = dataset
+        self.ntk_batch_size = ntk_batch_size
+        self.store_on_device = store_on_device
+
+    def _function_for_loss_ntk_helper(params, dataset, metric_fn, apply_fn):
+        return metric_fn(apply_fn(params, dataset["inputs"]), dataset["targets"])
+
+    def calculate_loss_ntk(
+        self,
+        model,
+    ):
+        _function_for_loss_ntk = lambda x, y: self._function_for_loss_ntk_helper(
+            x, y, metric_fn, model._ntk_apply_fn
+        )
+
         # Prepare NTK calculation
         empirical_ntk = nt.batch(
-            nt.empirical_ntk_fn(f=model._ntk_apply_fn, trace_axes=trace_axes),
+            nt.empirical_ntk_fn(f=_function_for_loss_ntk, trace_axes=trace_axes),
             batch_size=ntk_batch_size,
             store_on_device=store_on_device,
         )

From 580194a71edc476ec5c6439c8f76c8eacdc3a0ce Mon Sep 17 00:00:00 2001
From: m-sauter
Date: Thu, 18 Jan 2024 19:45:16 +0100
Subject: [PATCH 03/35] analysis file should be done
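
What this series is building, in one formula: for a per-sample loss
l_i(theta) = metric(f_theta(x_i), y_i), the loss NTK is the Gram matrix of
per-sample loss gradients, K_loss[i, j] = <dl_i/dtheta, dl_j/dtheta>. The
sketch below illustrates that definition directly with `jax.grad`; it is an
illustration only, not part of the patch, and `toy_apply` / `toy_metric` are
stand-ins rather than ZnNL API.

```python
import jax
import jax.numpy as jnp


def toy_apply(params, x):
    """Stand-in network: a single linear layer."""
    return params @ x


def toy_metric(prediction, target):
    """L2 distance between prediction and target."""
    return jnp.linalg.norm(prediction - target)


def loss_ntk(params, inputs, targets):
    """Gram matrix of per-sample loss gradients, K[i, j] = <dl_i, dl_j>."""

    def single_loss(p, x, y):
        return toy_metric(toy_apply(p, x), y)

    # One parameter gradient per sample, with params held fixed.
    grads = jax.vmap(lambda x, y: jax.grad(single_loss)(params, x, y))(
        inputs, targets
    )
    flat = grads.reshape(inputs.shape[0], -1)
    return flat @ flat.T


params = jnp.ones((2, 3))
inputs = jnp.arange(12.0).reshape(4, 3)
targets = jnp.zeros((4, 2))
print(loss_ntk(params, inputs, targets).shape)  # (4, 4)
```
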
--- znnl/analysis/loss_ntk_calculation.py | 75 ++++++++++++++++++++++----- 1 file changed, 62 insertions(+), 13 deletions(-) diff --git a/znnl/analysis/loss_ntk_calculation.py b/znnl/analysis/loss_ntk_calculation.py index 10dea10..571944a 100644 --- a/znnl/analysis/loss_ntk_calculation.py +++ b/znnl/analysis/loss_ntk_calculation.py @@ -27,37 +27,86 @@ import neural_tangents as nt from typing import Callable +from znnl.models.jax_model import JaxModel +import jax +import jax.numpy as np class loss_ntk_calculation: def __init__( self, metric_fn: Callable, - dataset: dict, + model: JaxModel, ntk_batch_size: int = 10, store_on_device: bool = True, + trace_axes: Union[int, Sequence[int]] = (-1,), ): """Constructor for the loss ntk calculation class.""" + + # Set the attributes self.metric_fn = metric_fn - self.dataset = dataset self.ntk_batch_size = ntk_batch_size self.store_on_device = store_on_device + self.trace_axes = trace_axes - def _function_for_loss_ntk_helper(params, dataset, metric_fn, apply_fn): - return metric_fn(apply_fn(params, dataset["inputs"]), dataset["targets"]) - - def calculate_loss_ntk( - self, - model, - ): + # Set the loss ntk function _function_for_loss_ntk = lambda x, y: self._function_for_loss_ntk_helper( x, y, metric_fn, model._ntk_apply_fn ) # Prepare NTK calculation empirical_ntk = nt.batch( - nt.empirical_ntk_fn(f=_function_for_loss_ntk, trace_axes=trace_axes), - batch_size=ntk_batch_size, - store_on_device=store_on_device, + nt.empirical_ntk_fn(f=_function_for_loss_ntk, trace_axes=self.trace_axes), + batch_size=self.ntk_batch_size, + store_on_device=self.store_on_device, ) - empirical_ntk_jit = jax.jit(empirical_ntk) + self.empirical_ntk_jit = jax.jit(empirical_ntk) + + def _function_for_loss_ntk_helper(params, dataset, metric_fn, apply_fn) -> float: + """ + Helper function to create a subloss apply function of structure + (params, dataset) -> loss. + """ + return metric_fn(apply_fn(params, dataset["inputs"]), dataset["targets"]) + + def compute_loss_ntk( + self, x_i: np.ndarray, x_j: np.ndarray, model: JaxModel, infinite: bool = False + ): + """ + Compute the loss NTK matrix for the model. + + Parameters + ---------- + x_i : np.ndarray + Dataset for which to compute the loss NTK matrix. + x_j : np.ndarray (optional) + Dataset for which to compute the loss NTK matrix. + infinite : bool (default = False) + If true, compute the infinite width limit as well. + + Returns + ------- + NTK : dict + The NTK matrix for both the empirical and infinite width computation. 
+ """ + + if x_j is None: + x_j = x_i + empirical_ntk = self.empirical_ntk_jit( + x_i, + x_j, + { + "params": model.model_state.params, + "batch_stats": model.model_state.batch_stats, + }, + ) + + if infinite: + try: + infinite_ntk = self.kernel_fn(x_i, x_j, "ntk") + except AttributeError: + raise NotImplementedError("Infinite NTK not available for this model.") + else: + infinite_ntk = None + + return {"empirical": empirical_ntk, "infinite": infinite_ntk} From 768fc13bb5ca704b60c196fb542f19ac54f7eebf Mon Sep 17 00:00:00 2001 From: m-sauter Date: Thu, 18 Jan 2024 19:46:55 +0100 Subject: [PATCH 04/35] quick fix --- znnl/analysis/loss_ntk_calculation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/znnl/analysis/loss_ntk_calculation.py b/znnl/analysis/loss_ntk_calculation.py index 571944a..f6c02c7 100644 --- a/znnl/analysis/loss_ntk_calculation.py +++ b/znnl/analysis/loss_ntk_calculation.py @@ -26,7 +26,7 @@ """ import neural_tangents as nt -from typing import Callable +from typing import Callable, Union, Sequence from znnl.models.jax_model import JaxModel import jax import jax.numpy as np From c8020100576bd073bccb3beec2e1f7bad7d93e4c Mon Sep 17 00:00:00 2001 From: m-sauter Date: Fri, 19 Jan 2024 10:22:27 +0100 Subject: [PATCH 05/35] Included loss ntk, eigval and entropy in jax recorder --- znnl/analysis/__init__.py | 2 + znnl/analysis/loss_ntk_calculation.py | 9 +- znnl/training_recording/jax_recording.py | 101 ++++++++++++++++++++--- 3 files changed, 95 insertions(+), 17 deletions(-) diff --git a/znnl/analysis/__init__.py b/znnl/analysis/__init__.py index 94701c5..09105e9 100644 --- a/znnl/analysis/__init__.py +++ b/znnl/analysis/__init__.py @@ -28,9 +28,11 @@ from znnl.analysis.eigensystem import EigenSpaceAnalysis from znnl.analysis.entropy import EntropyAnalysis from znnl.analysis.loss_fn_derivative import LossDerivative +from znnl.analysis.loss_ntk_calculation import loss_ntk_calculation __all__ = [ EntropyAnalysis.__name__, EigenSpaceAnalysis.__name__, LossDerivative.__name__, + loss_ntk_calculation.__name__, ] diff --git a/znnl/analysis/loss_ntk_calculation.py b/znnl/analysis/loss_ntk_calculation.py index f6c02c7..0fc9d7d 100644 --- a/znnl/analysis/loss_ntk_calculation.py +++ b/znnl/analysis/loss_ntk_calculation.py @@ -37,17 +37,14 @@ def __init__( self, metric_fn: Callable, model: JaxModel, - ntk_batch_size: int = 10, - store_on_device: bool = True, - trace_axes: Union[int, Sequence[int]] = (-1,), ): """Constructor for the loss ntk calculation class.""" # Set the attributes self.metric_fn = metric_fn - self.ntk_batch_size = ntk_batch_size - self.store_on_device = store_on_device - self.trace_axes = trace_axes + self.ntk_batch_size = model.ntk_batch_size + self.store_on_device = model.store_on_device + self.trace_axes = model.trace_axes # Set the loss ntk function _function_for_loss_ntk = lambda x, y: self._function_for_loss_ntk_helper( diff --git a/znnl/training_recording/jax_recording.py b/znnl/training_recording/jax_recording.py index 862b7d4..08c5ea9 100644 --- a/znnl/training_recording/jax_recording.py +++ b/znnl/training_recording/jax_recording.py @@ -37,6 +37,7 @@ from znnl.analysis.eigensystem import EigenSpaceAnalysis from znnl.analysis.entropy import EntropyAnalysis from znnl.analysis.loss_fn_derivative import LossDerivative +from znnl.analysis.loss_ntk_calculation import loss_ntk_calculation from znnl.loss_functions import SimpleLoss from znnl.models.jax_model import JaxModel from znnl.training_recording.data_storage import DataStorage @@ -157,6 +158,18 
@@ class JaxRecorder: loss_derivative: bool = False _loss_derivative_array: list = None + # Loss NTK + loss_ntk: bool = False + _loss_ntk_array: list = None + + # Loss NTK eigenvalues + loss_ntk_eigenvalues: bool = False + _loss_ntk_eigenvalues_array: list = None + + # Loss NTK entropy + loss_ntk_entropy: bool = False + _loss_ntk_entropy_array: list = None + # Class helpers update_rate: int = 1 _loss_fn: SimpleLoss = None @@ -165,6 +178,9 @@ class JaxRecorder: _model: JaxModel = None _data_set: dict = None _compute_ntk: bool = False # Helps to know if we can compute it once and share. + _compute_loss_ntk: bool = ( + False # Helps to know if we can compute it once and share. + ) _compute_loss_derivative: bool = False _loss_derivative_fn: LossDerivative = False _index_count: int = 0 # Helps to avoid problems with non-1 update rates. @@ -254,19 +270,35 @@ def instantiate_recorder(self, data_set: dict = None, overwrite: bool = False): self._index_count = 0 # Check if we need an NTK computation and update the class accordingly - if any([ - "ntk" in self._selected_properties, - "covariance_ntk" in self._selected_properties, - "magnitude_ntk" in self._selected_properties, - "entropy" in self._selected_properties, - "magnitude_entropy" in self._selected_properties, - "magnitude_variance" in self._selected_properties, - "covariance_entropy" in self._selected_properties, - "eigenvalues" in self._selected_properties, - "trace" in self._selected_properties, - ]): + if any( + [ + "ntk" in self._selected_properties, + "covariance_ntk" in self._selected_properties, + "magnitude_ntk" in self._selected_properties, + "entropy" in self._selected_properties, + "magnitude_entropy" in self._selected_properties, + "magnitude_variance" in self._selected_properties, + "covariance_entropy" in self._selected_properties, + "eigenvalues" in self._selected_properties, + "trace" in self._selected_properties, + ] + ): self._compute_ntk = True + # Check if we need a loss NTK computation and update the class accordingly + + if any( + [ + "loss_ntk" in self._selected_properties, + "loss_ntk_eigenvalues" in self._selected_properties, + "loss_ntk_entropy" in self._selected_properties, + ] + ): + self._compute_loss_ntk = True + self._loss_ntk_calculator = loss_ntk_calculation( + metric_fn=self._loss_fn.metric, model=self._model + ) + if "loss_derivative" in self._selected_properties: self._loss_derivative_fn = LossDerivative(self._loss_fn) @@ -321,6 +353,14 @@ def update_recorder(self, epoch: int, model: JaxModel): self.eigenvalues = False self._read_selected_attributes() + # Compute loss ntk here to avoid repeated computation. + if self._compute_loss_ntk: + parsed_data["loss_ntk"] = self._loss_ntk_calculator.compute_loss_ntk( + x_i=self._data_set, + model=self._model, + infinite=False, # Set true to compute infinite width limit of loss ntk + ) + for item in self._selected_properties: call_fn = getattr(self, f"_update_{item}") # get the callable function @@ -587,6 +627,45 @@ def _update_loss_derivative(self, parsed_data): loss_derivative = calculate_l_pq_norm(vector_loss_derivative) self._loss_derivative_array.append(loss_derivative) + def _update_loss_ntk(self, parsed_data): + """ + Update the loss ntk array. + + Parameters + ---------- + parsed_data : dict + Data computed before the update to prevent repeated calculations. + """ + self._loss_ntk_array.append(parsed_data["loss_ntk"]) + + def _update_loss_ntk_eigenvalues(self, parsed_data): + """ + Update the loss ntk eigenvalue array. 
+ + Parameters + ---------- + parsed_data : dict + Data computed before the update to prevent repeated calculations. + """ + calculator = EigenSpaceAnalysis(matrix=parsed_data["loss_ntk"]) + eigenvalues = calculator.compute_eigenvalues(normalize=False) + self._loss_ntk_eigenvalues_array.append(eigenvalues) + + def _update_loss_ntk_entropy(self, parsed_data): + """ + Update the loss ntk entropy array. + + Parameters + ---------- + parsed_data : dict + Data computed before the update to prevent repeated calculations. + """ + calculator = EntropyAnalysis(matrix=parsed_data["loss_ntk"]) + entropy = calculator.compute_von_neumann_entropy( + effective=False, normalize_eig=True + ) + self._loss_ntk_entropy_array.append(entropy) + def gather_recording(self, selected_properties: list = None) -> dataclass: """ Export a dataclass of used properties. From 17fd83a6d51f9fc00ba91fec68777b97e5f01401 Mon Sep 17 00:00:00 2001 From: m-sauter Date: Fri, 19 Jan 2024 10:32:34 +0100 Subject: [PATCH 06/35] Updated recorder instantiation test --- CI/unit_tests/training_recording/test_training_recording.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CI/unit_tests/training_recording/test_training_recording.py b/CI/unit_tests/training_recording/test_training_recording.py index 5249b3e..39ee257 100644 --- a/CI/unit_tests/training_recording/test_training_recording.py +++ b/CI/unit_tests/training_recording/test_training_recording.py @@ -66,6 +66,9 @@ def test_instantiation(self): eigenvalues=True, trace=True, loss_derivative=True, + loss_ntk=True, + loss_ntk_derivative=True, + loss_ntk_eigenvalues=True, ) recorder.instantiate_recorder(data_set=self.dummy_data_set) _exclude_list = [ From e88292f4fc655dcc06c41e895364578ba176587b Mon Sep 17 00:00:00 2001 From: m-sauter Date: Fri, 19 Jan 2024 13:40:35 +0100 Subject: [PATCH 07/35] Bugfixing --- znnl/models/jax_model.py | 7 +++++++ znnl/training_recording/jax_recording.py | 17 ++++++++++++++--- znnl/training_strategies/simple_training.py | 1 + 3 files changed, 22 insertions(+), 3 deletions(-) diff --git a/znnl/models/jax_model.py b/znnl/models/jax_model.py index 8e8ea95..dc4f4ea 100644 --- a/znnl/models/jax_model.py +++ b/znnl/models/jax_model.py @@ -127,6 +127,13 @@ def __init__( batch_size=ntk_batch_size, store_on_device=store_on_device, ) + + # Next values need to be set to be available for the loss ntk calculation because + # it's implemented outside of the model class. + self.ntk_batch_size = ntk_batch_size + self.trace_axes = trace_axes + self.store_on_device = store_on_device + self.empirical_ntk_jit = jax.jit(self.empirical_ntk) self.apply_jit = jax.jit(self.apply) diff --git a/znnl/training_recording/jax_recording.py b/znnl/training_recording/jax_recording.py index 08c5ea9..5230d3e 100644 --- a/znnl/training_recording/jax_recording.py +++ b/znnl/training_recording/jax_recording.py @@ -183,6 +183,7 @@ class JaxRecorder: ) _compute_loss_derivative: bool = False _loss_derivative_fn: LossDerivative = False + _loss_ntk_calculator: loss_ntk_calculation = None _index_count: int = 0 # Helps to avoid problems with non-1 update rates. _data_storage: DataStorage = None # For writing to disk. 
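
For orientation between these two hunks: the three recorder flags added in
patch 05 are switched on from user code roughly as follows. This is a minimal
sketch, not part of the patch; the `name` and `storage_path` arguments are
assumed from the existing JaxRecorder constructor, and the loss function is
attached later by the training strategy (see the simple_training.py hunk
below).

```python
import jax.numpy as jnp

from znnl.training_recording import JaxRecorder

# A toy data set in the {"inputs", "targets"} layout the recorder expects.
train_ds = {"inputs": jnp.ones((8, 4)), "targets": jnp.ones((8, 2))}

recorder = JaxRecorder(
    name="train_recorder",      # assumed constructor arguments
    storage_path=".",
    loss_ntk=True,              # record the full loss NTK matrix
    loss_ntk_eigenvalues=True,  # record its eigenvalue spectrum
    loss_ntk_entropy=True,      # record its von Neumann entropy
    update_rate=1,
)
recorder.instantiate_recorder(data_set=train_ds)
```
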
@@ -295,9 +296,18 @@ def instantiate_recorder(self, data_set: dict = None, overwrite: bool = False): ] ): self._compute_loss_ntk = True - self._loss_ntk_calculator = loss_ntk_calculation( - metric_fn=self._loss_fn.metric, model=self._model - ) + print("instantiating") + print(self._loss_fn) + try: + self._loss_ntk_calculator = loss_ntk_calculation( + metric_fn=self._loss_fn.metric, model=self._model + ) + except AttributeError: + print("Warning") + logger.info( + "Warning: The loss function hasn't been set yet." + "Please set it before training." + ) if "loss_derivative" in self._selected_properties: self._loss_derivative_fn = LossDerivative(self._loss_fn) @@ -357,6 +367,7 @@ def update_recorder(self, epoch: int, model: JaxModel): if self._compute_loss_ntk: parsed_data["loss_ntk"] = self._loss_ntk_calculator.compute_loss_ntk( x_i=self._data_set, + x_j=None, model=self._model, infinite=False, # Set true to compute infinite width limit of loss ntk ) diff --git a/znnl/training_strategies/simple_training.py b/znnl/training_strategies/simple_training.py index 325bf7e..eb461d7 100644 --- a/znnl/training_strategies/simple_training.py +++ b/znnl/training_strategies/simple_training.py @@ -112,6 +112,7 @@ def __init__( # Add the loss and accuracy function to the recorders and re-instantiate them if self.recorders is not None: for item in self.recorders: + item._model = self.model item.loss_fn = loss_fn item.accuracy_fn = accuracy_fn item.instantiate_recorder() From 4eb83627d85debbfba3e0b66c1020f1287c9570a Mon Sep 17 00:00:00 2001 From: Marc Sauter Date: Fri, 19 Jan 2024 16:48:39 +0100 Subject: [PATCH 08/35] Stand vor Kino --- znnl/analysis/loss_ntk_calculation.py | 54 +++++++++++++++++++----- znnl/training_recording/jax_recording.py | 10 ++--- 2 files changed, 48 insertions(+), 16 deletions(-) diff --git a/znnl/analysis/loss_ntk_calculation.py b/znnl/analysis/loss_ntk_calculation.py index 0fc9d7d..0f7c84b 100644 --- a/znnl/analysis/loss_ntk_calculation.py +++ b/znnl/analysis/loss_ntk_calculation.py @@ -26,7 +26,7 @@ """ import neural_tangents as nt -from typing import Callable, Union, Sequence +from typing import Callable from znnl.models.jax_model import JaxModel import jax import jax.numpy as np @@ -37,6 +37,7 @@ def __init__( self, metric_fn: Callable, model: JaxModel, + dataset: dict, ): """Constructor for the loss ntk calculation class.""" @@ -45,32 +46,46 @@ def __init__( self.ntk_batch_size = model.ntk_batch_size self.store_on_device = model.store_on_device self.trace_axes = model.trace_axes - - # Set the loss ntk function - _function_for_loss_ntk = lambda x, y: self._function_for_loss_ntk_helper( - x, y, metric_fn, model._ntk_apply_fn - ) + self.input_shape = dataset["inputs"].shape + self.input_dimension = int(np.prod(np.array(self.input_shape[1:]))) + self.target_shape = dataset["targets"].shape + self.metric_fn = metric_fn + self.apply_fn = model._ntk_apply_fn # Prepare NTK calculation empirical_ntk = nt.batch( - nt.empirical_ntk_fn(f=_function_for_loss_ntk, trace_axes=self.trace_axes), + nt.empirical_ntk_fn( + f=self._function_for_loss_ntk, + trace_axes=self.trace_axes, + ), batch_size=self.ntk_batch_size, store_on_device=self.store_on_device, ) self.empirical_ntk_jit = jax.jit(empirical_ntk) - def _function_for_loss_ntk_helper(params, dataset, metric_fn, apply_fn) -> float: + def _function_for_loss_ntk(self, params, datapoint) -> float: """ - Helper function to create a subloss apply function of structure - (params, dataset) -> loss. + Helper function to create a subloss apply function. 
+ The datapoint here has to be shaped so that its an array of length + input dimension + output dimension. + This is done so that the inputs and targets can be understood + by the neural tangents empirical_ntk_fn function. """ - return metric_fn(apply_fn(params, dataset["inputs"]), dataset["targets"]) + _input = datapoint[: self.input_dimension] + _target = datapoint[self.input_dimension :] + return self.metric_fn( + self.apply_fn(params, _input), + _target, + ) def compute_loss_ntk( self, x_i: np.ndarray, x_j: np.ndarray, model: JaxModel, infinite: bool = False ): """ Compute the loss NTK matrix for the model. + The dataset gets reshaped to (n_data, input_dimension + output_dimension) + so that the neural tangents empirical_ntk_fn function can take each input + target pair as its input. Parameters ---------- @@ -87,8 +102,25 @@ def compute_loss_ntk( The NTK matrix for both the empirical and infinite width computation. """ + x_i = np.concatenate( + ( + x_i["inputs"].reshape(x_i["inputs"].shape[0], -1), + x_i["targets"].reshape(x_i["targets"].shape[0], -1), + ), + axis=1, + ) + if x_j is None: x_j = x_i + else: + x_j = np.concatenate( + ( + x_j["inputs"].reshape(x_j["inputs"].shape[0], -1), + x_j["targets"].reshape(x_j["targets"].shape[0], -1), + ), + axis=1, + ) + empirical_ntk = self.empirical_ntk_jit( x_i, x_j, diff --git a/znnl/training_recording/jax_recording.py b/znnl/training_recording/jax_recording.py index 5230d3e..84b41e3 100644 --- a/znnl/training_recording/jax_recording.py +++ b/znnl/training_recording/jax_recording.py @@ -296,11 +296,11 @@ def instantiate_recorder(self, data_set: dict = None, overwrite: bool = False): ] ): self._compute_loss_ntk = True - print("instantiating") - print(self._loss_fn) try: self._loss_ntk_calculator = loss_ntk_calculation( - metric_fn=self._loss_fn.metric, model=self._model + metric_fn=self._loss_fn.metric, + model=self._model, + dataset=self._data_set, ) except AttributeError: print("Warning") @@ -369,8 +369,8 @@ def update_recorder(self, epoch: int, model: JaxModel): x_i=self._data_set, x_j=None, model=self._model, - infinite=False, # Set true to compute infinite width limit of loss ntk - ) + infinite=False, # Set true to compute infinite width limit + )["empirical"] for item in self._selected_properties: call_fn = getattr(self, f"_update_{item}") # get the callable function From fb1da2e40d1ce03b0343eb068ae2ef8e0c5300ef Mon Sep 17 00:00:00 2001 From: m-sauter Date: Mon, 22 Jan 2024 01:48:45 +0100 Subject: [PATCH 09/35] First thing state that might be working --- znnl/analysis/loss_ntk_calculation.py | 7 +++++-- znnl/training_recording/jax_recording.py | 1 - 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/znnl/analysis/loss_ntk_calculation.py b/znnl/analysis/loss_ntk_calculation.py index 0f7c84b..07deab0 100644 --- a/znnl/analysis/loss_ntk_calculation.py +++ b/znnl/analysis/loss_ntk_calculation.py @@ -70,9 +70,12 @@ def _function_for_loss_ntk(self, params, datapoint) -> float: input dimension + output dimension. This is done so that the inputs and targets can be understood by the neural tangents empirical_ntk_fn function. + + Seems like during the NTK calculation, this function needs to handle + the whole dataset at once instead of just one datapoint. 
""" - _input = datapoint[: self.input_dimension] - _target = datapoint[self.input_dimension :] + _input = datapoint[:, : self.input_dimension] + _target = datapoint[:, self.input_dimension :] return self.metric_fn( self.apply_fn(params, _input), _target, diff --git a/znnl/training_recording/jax_recording.py b/znnl/training_recording/jax_recording.py index 84b41e3..dfb37d8 100644 --- a/znnl/training_recording/jax_recording.py +++ b/znnl/training_recording/jax_recording.py @@ -303,7 +303,6 @@ def instantiate_recorder(self, data_set: dict = None, overwrite: bool = False): dataset=self._data_set, ) except AttributeError: - print("Warning") logger.info( "Warning: The loss function hasn't been set yet." "Please set it before training." From d22bb682bdc8de314bfc0933aabac7e455af45f4 Mon Sep 17 00:00:00 2001 From: Marc Sauter Date: Mon, 22 Jan 2024 17:51:52 +0100 Subject: [PATCH 10/35] writing test --- .../analysis/test_loss_ntk_calculation.py | 82 +++++++++++++++++++ znnl/analysis/loss_ntk_calculation.py | 10 ++- 2 files changed, 91 insertions(+), 1 deletion(-) create mode 100644 CI/unit_tests/analysis/test_loss_ntk_calculation.py diff --git a/CI/unit_tests/analysis/test_loss_ntk_calculation.py b/CI/unit_tests/analysis/test_loss_ntk_calculation.py new file mode 100644 index 0000000..90a4632 --- /dev/null +++ b/CI/unit_tests/analysis/test_loss_ntk_calculation.py @@ -0,0 +1,82 @@ +""" +ZnNL: A Zincwarecode package. + +License +------- +This program and the accompanying materials are made available under the terms +of the Eclipse Public License v2.0 which accompanies this distribution, and is +available at https://www.eclipse.org/legal/epl-v20.html + +SPDX-License-Identifier: EPL-2.0 + +Copyright Contributors to the Zincwarecode Project. + +Contact Information +------------------- +email: zincwarecode@gmail.com +github: https://github.com/zincware +web: https://zincwarecode.com/ + +Citation +-------- +If you use this module please cite us with: + +Summary +------- +""" + +import os + +os.environ["CUDA_VISIBLE_DEVICES"] = "-1" + +import jax.numpy as np +import pytest + +from znnl.analysis import loss_ntk_calculation +from znnl.training_recording import JaxRecorder +from znnl.models import NTModel +from znnl.data import MNISTGenerator +from neural_tangents import stax + +import optax +import tensorflow_datasets as tfds + + +class TestLossNTKCalculation: + """ + Test Suite for the loss NTK calculation module. + """ + + def test_loss_ntk_calculation(self): + """ + Test the loss NTK calculation. 
+ """ + + # Define a test Network + dense_network = stax.serial( + stax.Dense(32), + stax.Relu(), + stax.Dense(32), + ) + + # Define a test model + fuel_model = NTModel( + nt_module=dense_network, + optimizer=optax.adam(learning_rate=0.005), + input_shape=(9,), + trace_axes=(), + batch_size=314, + ) + + # Initialize model parameters + + data_generator = MNISTGenerator(ds_size=10) + data_set = { + "inputs": data_generator.train_ds["inputs"], + "targets": data_generator.train_ds["targets"], + } + + print(fuel_model.model_state.params) + + +TestLossNTKCalculation().test_loss_ntk_calculation() diff --git a/znnl/analysis/loss_ntk_calculation.py b/znnl/analysis/loss_ntk_calculation.py index 07deab0..952e389 100644 --- a/znnl/analysis/loss_ntk_calculation.py +++ b/znnl/analysis/loss_ntk_calculation.py @@ -61,6 +61,10 @@ def __init__( batch_size=self.ntk_batch_size, store_on_device=self.store_on_device, ) + empirical_ntk = nt.empirical_ntk_fn( + f=self._function_for_loss_ntk, + trace_axes=self.trace_axes, + ) self.empirical_ntk_jit = jax.jit(empirical_ntk) def _function_for_loss_ntk(self, params, datapoint) -> float: @@ -82,7 +86,11 @@ def _function_for_loss_ntk(self, params, datapoint) -> float: ) def compute_loss_ntk( - self, x_i: np.ndarray, x_j: np.ndarray, model: JaxModel, infinite: bool = False + self, + x_i: np.ndarray, + model: JaxModel, + x_j: np.ndarray = None, + infinite: bool = False, ): """ Compute the loss NTK matrix for the model. From e141dd8c4a798f05e250a1d73bc5e343c8906f77 Mon Sep 17 00:00:00 2001 From: m-sauter Date: Tue, 30 Jan 2024 12:21:47 +0100 Subject: [PATCH 11/35] Included vmap_axes --- CI/unit_tests/analysis/test_loss_ntk_calculation.py | 12 ++++++++++-- znnl/analysis/loss_ntk_calculation.py | 11 ++++------- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/CI/unit_tests/analysis/test_loss_ntk_calculation.py b/CI/unit_tests/analysis/test_loss_ntk_calculation.py index 90a4632..80ee61a 100644 --- a/CI/unit_tests/analysis/test_loss_ntk_calculation.py +++ b/CI/unit_tests/analysis/test_loss_ntk_calculation.py @@ -39,7 +39,6 @@ from neural_tangents import stax import optax -import tensorflow_datasets as tfds class TestLossNTKCalculation: @@ -76,7 +75,16 @@ def test_loss_ntk_calculation(self): "targets": data_generator.train_ds["targets"], } - print(fuel_model.model_state.params) + # Initialize the loss NTK calculation + loss_ntk_calculator = loss_ntk_calculation( + metric_fn=lambda x, y: (x - y) ** 2, + model=fuel_model, + dataset=data_set, + ) + + # Compute the loss NTK + ntk = loss_ntk_calculator.compute_loss_ntk(x_i=data_set, model=fuel_model) + print(ntk.shape) TestLossNTKCalculation().test_loss_ntk_calculation() diff --git a/znnl/analysis/loss_ntk_calculation.py b/znnl/analysis/loss_ntk_calculation.py index 952e389..31e018d 100644 --- a/znnl/analysis/loss_ntk_calculation.py +++ b/znnl/analysis/loss_ntk_calculation.py @@ -30,7 +30,7 @@ from znnl.models.jax_model import JaxModel import jax import jax.numpy as np - +import numpy class loss_ntk_calculation: def __init__( @@ -57,14 +57,11 @@ def __init__( nt.empirical_ntk_fn( f=self._function_for_loss_ntk, trace_axes=self.trace_axes, + vmap_axes=0, ), batch_size=self.ntk_batch_size, store_on_device=self.store_on_device, ) - empirical_ntk = nt.empirical_ntk_fn( - f=self._function_for_loss_ntk, - trace_axes=self.trace_axes, - ) self.empirical_ntk_jit = jax.jit(empirical_ntk) def _function_for_loss_ntk(self, params, datapoint) -> float: @@ -109,8 +106,8 @@ def compute_loss_ntk( Returns ------- - NTK : dict - The NTK 
matrix for both the empirical and infinite width computation. + Loss NTK : dict + The Loss NTK matrix for both the empirical and infinite width computation. """ x_i = np.concatenate( From 22cfc904358addbe14c64e0f2a9c6f0c542d7b0f Mon Sep 17 00:00:00 2001 From: m-sauter Date: Tue, 30 Jan 2024 13:49:32 +0100 Subject: [PATCH 12/35] Working on calculating loss derivatives to calculate loss ntk comparison --- .../analysis/test_loss_ntk_calculation.py | 65 +++++++++++++++---- znnl/analysis/loss_ntk_calculation.py | 10 ++- 2 files changed, 62 insertions(+), 13 deletions(-) diff --git a/CI/unit_tests/analysis/test_loss_ntk_calculation.py b/CI/unit_tests/analysis/test_loss_ntk_calculation.py index 80ee61a..c2465c0 100644 --- a/CI/unit_tests/analysis/test_loss_ntk_calculation.py +++ b/CI/unit_tests/analysis/test_loss_ntk_calculation.py @@ -33,14 +33,41 @@ import pytest from znnl.analysis import loss_ntk_calculation +from znnl.distance_metrics import LPNorm from znnl.training_recording import JaxRecorder -from znnl.models import NTModel +from znnl.loss_functions import LPNormLoss +from znnl.analysis import LossDerivative +from znnl.models import FlaxModel from znnl.data import MNISTGenerator +from flax import linen as nn + from neural_tangents import stax import optax +# Defines a simple CNN module +class ProductionModule(nn.Module): + """ + Simple CNN module. + """ + + @nn.compact + def __call__(self, x): + x = nn.Conv(features=128, kernel_size=(3, 3))(x) + x = nn.relu(x) + x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2)) + x = nn.Conv(features=128, kernel_size=(3, 3))(x) + x = nn.relu(x) + x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2)) + x = x.reshape((x.shape[0], -1)) # flatten + x = nn.Dense(features=300)(x) + x = nn.relu(x) + x = nn.Dense(10)(x) + + return x + + class TestLossNTKCalculation: """ Test Suite for the loss NTK calculation module. 
@@ -59,14 +86,12 @@ def test_loss_ntk_calculation(self): ) # Define a test model - fuel_model = NTModel( - nt_module=dense_network, - optimizer=optax.adam(learning_rate=0.005), - input_shape=(9,), + production_model = FlaxModel( + flax_module=ProductionModule(), + optimizer=optax.adam(learning_rate=0.01), + input_shape=(1, 28, 28, 1), trace_axes=(), - batch_size=314, ) - # Initialize model parameters data_generator = MNISTGenerator(ds_size=10) @@ -77,14 +102,32 @@ def test_loss_ntk_calculation(self): # Initialize the loss NTK calculation loss_ntk_calculator = loss_ntk_calculation( - metric_fn=lambda x, y: (x - y) ** 2, - model=fuel_model, + metric_fn=LPNorm(order=2), + model=production_model, dataset=data_set, ) # Compute the loss NTK - ntk = loss_ntk_calculator.compute_loss_ntk(x_i=data_set, model=fuel_model) - print(ntk.shape) + loss_ntk = loss_ntk_calculator.compute_loss_ntk( + x_i=data_set, model=production_model + )["empirical"] + + # Now for comparison calculate regular ntk + ntk = production_model.compute_ntk(data_set["inputs"], infinite=False)[ + "empirical" + ] + # Calculate Loss derivative fn + loss_derivative_calculator = LossDerivative(LPNormLoss(order=2)) + # predictions calculation analogous to the one in jax recording + predictions = production_model(data_set["inputs"]) + if type(predictions) is tuple: + predictions = predictions[0] + # calculation of loss derivatives + loss_derivatives = loss_derivative_calculator.calculate( + predictions=predictions, + targets=data_set["targets"], + ) + print(loss_derivatives.shape) TestLossNTKCalculation().test_loss_ntk_calculation() diff --git a/znnl/analysis/loss_ntk_calculation.py b/znnl/analysis/loss_ntk_calculation.py index 31e018d..f943578 100644 --- a/znnl/analysis/loss_ntk_calculation.py +++ b/znnl/analysis/loss_ntk_calculation.py @@ -32,6 +32,7 @@ import jax.numpy as np import numpy + class loss_ntk_calculation: def __init__( self, @@ -75,8 +76,13 @@ def _function_for_loss_ntk(self, params, datapoint) -> float: Seems like during the NTK calculation, this function needs to handle the whole dataset at once instead of just one datapoint. """ - _input = datapoint[:, : self.input_dimension] - _target = datapoint[:, self.input_dimension :] + batch_length = datapoint.shape[0] + _input = datapoint[:, : self.input_dimension].reshape( + batch_length, *self.input_shape[1:] + ) + _target = datapoint[:, self.input_dimension :].reshape( + batch_length, *self.target_shape[1:] + ) return self.metric_fn( self.apply_fn(params, _input), _target, From 5f9b6bfa5a91f1add8c7d543302757348067f979 Mon Sep 17 00:00:00 2001 From: m-sauter Date: Sun, 4 Feb 2024 21:17:26 +0100 Subject: [PATCH 13/35] Calculation and test should work now --- .../analysis/test_loss_ntk_calculation.py | 36 ++++++++++++++----- 1 file changed, 27 insertions(+), 9 deletions(-) diff --git a/CI/unit_tests/analysis/test_loss_ntk_calculation.py b/CI/unit_tests/analysis/test_loss_ntk_calculation.py index c2465c0..456d7a8 100644 --- a/CI/unit_tests/analysis/test_loss_ntk_calculation.py +++ b/CI/unit_tests/analysis/test_loss_ntk_calculation.py @@ -30,6 +30,9 @@ os.environ["CUDA_VISIBLE_DEVICES"] = "-1" import jax.numpy as np +import numpy as onp +from numpy.testing import assert_array_almost_equal + import pytest from znnl.analysis import loss_ntk_calculation @@ -75,7 +78,10 @@ class TestLossNTKCalculation: def test_loss_ntk_calculation(self): """ - Test the loss NTK calculation. + Test the Loss NTK calculation. 
+ Here we test if the Loss NTK calculated through the neural tangents module is + the same as the Loss NTK calculated with the already implemented NTK and loss + derivatives. """ # Define a test Network @@ -118,16 +124,28 @@ def test_loss_ntk_calculation(self): ] # Calculate Loss derivative fn loss_derivative_calculator = LossDerivative(LPNormLoss(order=2)) + # predictions calculation analogous to the one in jax recording predictions = production_model(data_set["inputs"]) if type(predictions) is tuple: predictions = predictions[0] - # calculation of loss derivatives - loss_derivatives = loss_derivative_calculator.calculate( - predictions=predictions, - targets=data_set["targets"], - ) - print(loss_derivatives.shape) - -TestLossNTKCalculation().test_loss_ntk_calculation() + # calculation of loss derivatives + # note: here we need the derivatives of the subloss, not the regular loss fn + loss_derivatives = onp.empty(shape=(len(predictions), len(predictions[0]))) + for i in range(len(loss_derivatives)): + # The weird indexing here is because of axis constraints in the LPNormLoss module + loss_derivatives[i] = loss_derivative_calculator.calculate( + predictions[i : i + 1], data_set["targets"][i : i + 1] + )[0] + + # Calculate the loss NTK from the loss derivatives and the ntk + loss_ntk_2 = onp.zeros_like(loss_ntk) + for i in range(len(loss_ntk_2)): + for j in range(len(loss_ntk_2[0])): + loss_ntk_2[i, j] = np.einsum( + "i, j, ij", loss_derivatives[i], loss_derivatives[j], ntk[i, j] + ) + + # Assert that the loss NTKs are the same + assert_array_almost_equal(loss_ntk, loss_ntk_2, decimal=2) From 131d5cb9e6ff59a98a1a1d31f225499b5a40ac64 Mon Sep 17 00:00:00 2001 From: m-sauter Date: Sun, 4 Feb 2024 21:23:49 +0100 Subject: [PATCH 14/35] some linting updates --- .../analysis/test_loss_ntk_calculation.py | 17 ++++----- znnl/analysis/loss_ntk_calculation.py | 7 ++-- znnl/training_recording/jax_recording.py | 36 +++++++++---------- 3 files changed, 26 insertions(+), 34 deletions(-) diff --git a/CI/unit_tests/analysis/test_loss_ntk_calculation.py b/CI/unit_tests/analysis/test_loss_ntk_calculation.py index 456d7a8..4ec4755 100644 --- a/CI/unit_tests/analysis/test_loss_ntk_calculation.py +++ b/CI/unit_tests/analysis/test_loss_ntk_calculation.py @@ -31,22 +31,17 @@ import jax.numpy as np import numpy as onp +import optax +from flax import linen as nn +from neural_tangents import stax from numpy.testing import assert_array_almost_equal -import pytest - -from znnl.analysis import loss_ntk_calculation +from znnl.analysis import LossDerivative, loss_ntk_calculation +from znnl.data import MNISTGenerator from znnl.distance_metrics import LPNorm -from znnl.training_recording import JaxRecorder from znnl.loss_functions import LPNormLoss -from znnl.analysis import LossDerivative from znnl.models import FlaxModel -from znnl.data import MNISTGenerator -from flax import linen as nn - -from neural_tangents import stax - -import optax +from znnl.training_recording import JaxRecorder # Defines a simple CNN module diff --git a/znnl/analysis/loss_ntk_calculation.py b/znnl/analysis/loss_ntk_calculation.py index f943578..04cc979 100644 --- a/znnl/analysis/loss_ntk_calculation.py +++ b/znnl/analysis/loss_ntk_calculation.py @@ -25,12 +25,13 @@ ------- """ -import neural_tangents as nt from typing import Callable -from znnl.models.jax_model import JaxModel + import jax import jax.numpy as np -import numpy +import neural_tangents as nt + +from znnl.models.jax_model import JaxModel class loss_ntk_calculation: diff --git 
a/znnl/training_recording/jax_recording.py b/znnl/training_recording/jax_recording.py index dfb37d8..d01d74f 100644 --- a/znnl/training_recording/jax_recording.py +++ b/znnl/training_recording/jax_recording.py @@ -271,30 +271,26 @@ def instantiate_recorder(self, data_set: dict = None, overwrite: bool = False): self._index_count = 0 # Check if we need an NTK computation and update the class accordingly - if any( - [ - "ntk" in self._selected_properties, - "covariance_ntk" in self._selected_properties, - "magnitude_ntk" in self._selected_properties, - "entropy" in self._selected_properties, - "magnitude_entropy" in self._selected_properties, - "magnitude_variance" in self._selected_properties, - "covariance_entropy" in self._selected_properties, - "eigenvalues" in self._selected_properties, - "trace" in self._selected_properties, - ] - ): + if any([ + "ntk" in self._selected_properties, + "covariance_ntk" in self._selected_properties, + "magnitude_ntk" in self._selected_properties, + "entropy" in self._selected_properties, + "magnitude_entropy" in self._selected_properties, + "magnitude_variance" in self._selected_properties, + "covariance_entropy" in self._selected_properties, + "eigenvalues" in self._selected_properties, + "trace" in self._selected_properties, + ]): self._compute_ntk = True # Check if we need a loss NTK computation and update the class accordingly - if any( - [ - "loss_ntk" in self._selected_properties, - "loss_ntk_eigenvalues" in self._selected_properties, - "loss_ntk_entropy" in self._selected_properties, - ] - ): + if any([ + "loss_ntk" in self._selected_properties, + "loss_ntk_eigenvalues" in self._selected_properties, + "loss_ntk_entropy" in self._selected_properties, + ]): self._compute_loss_ntk = True try: self._loss_ntk_calculator = loss_ntk_calculation( From 5b244ffda9b03600712768bcebaeb7f8bfa43a45 Mon Sep 17 00:00:00 2001 From: m-sauter Date: Sun, 4 Feb 2024 21:46:05 +0100 Subject: [PATCH 15/35] Fixed tests --- CI/unit_tests/training_recording/test_training_recording.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CI/unit_tests/training_recording/test_training_recording.py b/CI/unit_tests/training_recording/test_training_recording.py index 39ee257..6218338 100644 --- a/CI/unit_tests/training_recording/test_training_recording.py +++ b/CI/unit_tests/training_recording/test_training_recording.py @@ -67,8 +67,8 @@ def test_instantiation(self): trace=True, loss_derivative=True, loss_ntk=True, - loss_ntk_derivative=True, loss_ntk_eigenvalues=True, + loss_ntk_entropy=True, ) recorder.instantiate_recorder(data_set=self.dummy_data_set) _exclude_list = [ From 3a06e75e21fb847a61993b6d28d76ee64f4840e7 Mon Sep 17 00:00:00 2001 From: m-sauter Date: Tue, 6 Feb 2024 18:43:48 +0100 Subject: [PATCH 16/35] Some modifications to simplify the loss ntk test code --- .../analysis/test_loss_ntk_calculation.py | 29 +++++++------------ 1 file changed, 11 insertions(+), 18 deletions(-) diff --git a/CI/unit_tests/analysis/test_loss_ntk_calculation.py b/CI/unit_tests/analysis/test_loss_ntk_calculation.py index 4ec4755..5e3ac27 100644 --- a/CI/unit_tests/analysis/test_loss_ntk_calculation.py +++ b/CI/unit_tests/analysis/test_loss_ntk_calculation.py @@ -52,14 +52,14 @@ class ProductionModule(nn.Module): @nn.compact def __call__(self, x): - x = nn.Conv(features=128, kernel_size=(3, 3))(x) + x = nn.Conv(features=16, kernel_size=(3, 3))(x) x = nn.relu(x) x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2)) - x = nn.Conv(features=128, kernel_size=(3, 3))(x) + x = 
nn.Conv(features=16, kernel_size=(3, 3))(x) x = nn.relu(x) x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2)) x = x.reshape((x.shape[0], -1)) # flatten - x = nn.Dense(features=300)(x) + x = nn.Dense(features=10)(x) x = nn.relu(x) x = nn.Dense(10)(x) @@ -79,13 +79,6 @@ def test_loss_ntk_calculation(self): derivatives. """ - # Define a test Network - dense_network = stax.serial( - stax.Dense(32), - stax.Relu(), - stax.Dense(32), - ) - # Define a test model production_model = FlaxModel( flax_module=ProductionModule(), @@ -95,7 +88,7 @@ def test_loss_ntk_calculation(self): ) # Initialize model parameters - data_generator = MNISTGenerator(ds_size=10) + data_generator = MNISTGenerator(ds_size=20) data_set = { "inputs": data_generator.train_ds["inputs"], "targets": data_generator.train_ds["targets"], @@ -135,12 +128,12 @@ def test_loss_ntk_calculation(self): )[0] # Calculate the loss NTK from the loss derivatives and the ntk - loss_ntk_2 = onp.zeros_like(loss_ntk) - for i in range(len(loss_ntk_2)): - for j in range(len(loss_ntk_2[0])): - loss_ntk_2[i, j] = np.einsum( - "i, j, ij", loss_derivatives[i], loss_derivatives[j], ntk[i, j] - ) + loss_ntk_2 = np.einsum( + "ik, jl, ijkl-> ij", loss_derivatives, loss_derivatives, ntk + ) # Assert that the loss NTKs are the same - assert_array_almost_equal(loss_ntk, loss_ntk_2, decimal=2) + assert_array_almost_equal(loss_ntk, loss_ntk_2, decimal=4) + + +TestLossNTKCalculation().test_loss_ntk_calculation() From 001e46e97f7e5312cdd30de7b76c1813a981ae47 Mon Sep 17 00:00:00 2001 From: m-sauter Date: Tue, 6 Feb 2024 18:49:14 +0100 Subject: [PATCH 17/35] Class renaming to follow convention --- CI/unit_tests/analysis/test_loss_ntk_calculation.py | 8 ++++---- znnl/analysis/__init__.py | 4 ++-- znnl/analysis/loss_ntk_calculation.py | 2 +- znnl/training_recording/jax_recording.py | 6 +++--- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/CI/unit_tests/analysis/test_loss_ntk_calculation.py b/CI/unit_tests/analysis/test_loss_ntk_calculation.py index 5e3ac27..79292c2 100644 --- a/CI/unit_tests/analysis/test_loss_ntk_calculation.py +++ b/CI/unit_tests/analysis/test_loss_ntk_calculation.py @@ -36,7 +36,7 @@ from neural_tangents import stax from numpy.testing import assert_array_almost_equal -from znnl.analysis import LossDerivative, loss_ntk_calculation +from znnl.analysis import LossDerivative, LossNTKCalculation from znnl.data import MNISTGenerator from znnl.distance_metrics import LPNorm from znnl.loss_functions import LPNormLoss @@ -71,7 +71,7 @@ class TestLossNTKCalculation: Test Suite for the loss NTK calculation module. """ - def test_loss_ntk_calculation(self): + def test_LossNTKCalculation(self): """ Test the Loss NTK calculation. 
Here we test if the Loss NTK calculated through the neural tangents module is @@ -95,7 +95,7 @@ def test_loss_ntk_calculation(self): } # Initialize the loss NTK calculation - loss_ntk_calculator = loss_ntk_calculation( + loss_ntk_calculator = LossNTKCalculation( metric_fn=LPNorm(order=2), model=production_model, dataset=data_set, @@ -136,4 +136,4 @@ def test_loss_ntk_calculation(self): assert_array_almost_equal(loss_ntk, loss_ntk_2, decimal=4) -TestLossNTKCalculation().test_loss_ntk_calculation() +TestLossNTKCalculation().test_LossNTKCalculation() diff --git a/znnl/analysis/__init__.py b/znnl/analysis/__init__.py index 09105e9..eed4fcb 100644 --- a/znnl/analysis/__init__.py +++ b/znnl/analysis/__init__.py @@ -28,11 +28,11 @@ from znnl.analysis.eigensystem import EigenSpaceAnalysis from znnl.analysis.entropy import EntropyAnalysis from znnl.analysis.loss_fn_derivative import LossDerivative -from znnl.analysis.loss_ntk_calculation import loss_ntk_calculation +from znnl.analysis.loss_ntk_calculation import LossNTKCalculation __all__ = [ EntropyAnalysis.__name__, EigenSpaceAnalysis.__name__, LossDerivative.__name__, - loss_ntk_calculation.__name__, + LossNTKCalculation.__name__, ] diff --git a/znnl/analysis/loss_ntk_calculation.py b/znnl/analysis/loss_ntk_calculation.py index 04cc979..8444cbf 100644 --- a/znnl/analysis/loss_ntk_calculation.py +++ b/znnl/analysis/loss_ntk_calculation.py @@ -34,7 +34,7 @@ from znnl.models.jax_model import JaxModel -class loss_ntk_calculation: +class LossNTKCalculation: def __init__( self, metric_fn: Callable, diff --git a/znnl/training_recording/jax_recording.py b/znnl/training_recording/jax_recording.py index d01d74f..2f9a819 100644 --- a/znnl/training_recording/jax_recording.py +++ b/znnl/training_recording/jax_recording.py @@ -37,7 +37,7 @@ from znnl.analysis.eigensystem import EigenSpaceAnalysis from znnl.analysis.entropy import EntropyAnalysis from znnl.analysis.loss_fn_derivative import LossDerivative -from znnl.analysis.loss_ntk_calculation import loss_ntk_calculation +from znnl.analysis.loss_ntk_calculation import LossNTKCalculation from znnl.loss_functions import SimpleLoss from znnl.models.jax_model import JaxModel from znnl.training_recording.data_storage import DataStorage @@ -183,7 +183,7 @@ class JaxRecorder: ) _compute_loss_derivative: bool = False _loss_derivative_fn: LossDerivative = False - _loss_ntk_calculator: loss_ntk_calculation = None + _loss_ntk_calculator: LossNTKCalculation = None _index_count: int = 0 # Helps to avoid problems with non-1 update rates. _data_storage: DataStorage = None # For writing to disk. 
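
The consistency check in the test above rests on the chain rule: with a
per-sample loss l_i = metric(f(x_i), y_i),

    dl_i/dtheta = sum_k (dl_i/df_k) df_k/dtheta,

so the loss NTK factorises through the ordinary, un-traced NTK:

    K_loss[i, j] = sum_{k, l} (dl_i/df_k) Theta[i, j, k, l] (dl_j/df_l).

A one-line sketch of that contraction, mirroring the einsum used in the test:

```python
import jax.numpy as jnp


def loss_ntk_from_parts(loss_derivatives, ntk):
    """Contract per-sample loss gradients (n, k) with a full NTK (n, n, k, k)."""
    return jnp.einsum("ik, jl, ijkl -> ij", loss_derivatives, loss_derivatives, ntk)
```
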
@@ -293,7 +293,7 @@ def instantiate_recorder(self, data_set: dict = None, overwrite: bool = False): ]): self._compute_loss_ntk = True try: - self._loss_ntk_calculator = loss_ntk_calculation( + self._loss_ntk_calculator = LossNTKCalculation( metric_fn=self._loss_fn.metric, model=self._model, dataset=self._data_set, From fd99b7706d9f26d554df134703f155a786fa94d8 Mon Sep 17 00:00:00 2001 From: m-sauter Date: Tue, 6 Feb 2024 19:02:33 +0100 Subject: [PATCH 18/35] added reshape and unshape methods in the loss_ntk_calculation --- znnl/analysis/loss_ntk_calculation.py | 55 ++++++++++++++++++--------- 1 file changed, 36 insertions(+), 19 deletions(-) diff --git a/znnl/analysis/loss_ntk_calculation.py b/znnl/analysis/loss_ntk_calculation.py index 8444cbf..39d8b8f 100644 --- a/znnl/analysis/loss_ntk_calculation.py +++ b/znnl/analysis/loss_ntk_calculation.py @@ -66,6 +66,34 @@ def __init__( ) self.empirical_ntk_jit = jax.jit(empirical_ntk) + @staticmethod + def _reshape_dataset(dataset): + """ + Helper function to reshape the dataset for the Loss NTK calculation. + """ + return np.concatenate( + ( + dataset["inputs"].reshape(dataset["inputs"].shape[0], -1), + dataset["targets"].reshape(dataset["targets"].shape[0], -1), + ), + axis=1, + ) + + @staticmethod + def _unshape_data( + datapoint: np.ndarray, + input_dimension: int, + input_shape: tuple, + target_shape: tuple, + batch_length: int, + ): + """ + Helper function to unshape the data for the subloss calculation. + """ + return datapoint[:, :input_dimension].reshape( + batch_length, *input_shape[1:] + ), datapoint[:, input_dimension:].reshape(batch_length, *target_shape[1:]) + def _function_for_loss_ntk(self, params, datapoint) -> float: """ Helper function to create a subloss apply function. @@ -78,11 +106,12 @@ def _function_for_loss_ntk(self, params, datapoint) -> float: the whole dataset at once instead of just one datapoint. """ batch_length = datapoint.shape[0] - _input = datapoint[:, : self.input_dimension].reshape( - batch_length, *self.input_shape[1:] - ) - _target = datapoint[:, self.input_dimension :].reshape( - batch_length, *self.target_shape[1:] + _input, _target = self._unshape_data( + datapoint, + self.input_dimension, + self.input_shape, + self.target_shape, + batch_length, ) return self.metric_fn( self.apply_fn(params, _input), @@ -117,24 +146,12 @@ def compute_loss_ntk( The Loss NTK matrix for both the empirical and infinite width computation. 
""" - x_i = np.concatenate( - ( - x_i["inputs"].reshape(x_i["inputs"].shape[0], -1), - x_i["targets"].reshape(x_i["targets"].shape[0], -1), - ), - axis=1, - ) + x_i = self._reshape_dataset(x_i) if x_j is None: x_j = x_i else: - x_j = np.concatenate( - ( - x_j["inputs"].reshape(x_j["inputs"].shape[0], -1), - x_j["targets"].reshape(x_j["targets"].shape[0], -1), - ), - axis=1, - ) + x_j = self._reshape_dataset(x_j) empirical_ntk = self.empirical_ntk_jit( x_i, From 0ef059f657b5fded15117687060d2c7b81a4040c Mon Sep 17 00:00:00 2001 From: Marc Sauter Date: Tue, 20 Feb 2024 12:37:03 +0100 Subject: [PATCH 19/35] quicksave --- znnl/analysis/loss_ntk_calculation.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/znnl/analysis/loss_ntk_calculation.py b/znnl/analysis/loss_ntk_calculation.py index 39d8b8f..c91d68b 100644 --- a/znnl/analysis/loss_ntk_calculation.py +++ b/znnl/analysis/loss_ntk_calculation.py @@ -58,8 +58,7 @@ def __init__( empirical_ntk = nt.batch( nt.empirical_ntk_fn( f=self._function_for_loss_ntk, - trace_axes=self.trace_axes, - vmap_axes=0, + trace_axes=(), ), batch_size=self.ntk_batch_size, store_on_device=self.store_on_device, From 98c6b307dda2b88c127a00b01044be72eff59992 Mon Sep 17 00:00:00 2001 From: m-sauter Date: Tue, 20 Feb 2024 13:57:15 +0100 Subject: [PATCH 20/35] Quick save --- .../analysis/test_loss_ntk_calculation.py | 101 +++++++++++++++++- znnl/analysis/loss_ntk_calculation.py | 1 + 2 files changed, 98 insertions(+), 4 deletions(-) diff --git a/CI/unit_tests/analysis/test_loss_ntk_calculation.py b/CI/unit_tests/analysis/test_loss_ntk_calculation.py index 79292c2..cebc70c 100644 --- a/CI/unit_tests/analysis/test_loss_ntk_calculation.py +++ b/CI/unit_tests/analysis/test_loss_ntk_calculation.py @@ -40,7 +40,7 @@ from znnl.data import MNISTGenerator from znnl.distance_metrics import LPNorm from znnl.loss_functions import LPNormLoss -from znnl.models import FlaxModel +from znnl.models import FlaxModel, NTModel from znnl.training_recording import JaxRecorder @@ -71,7 +71,100 @@ class TestLossNTKCalculation: Test Suite for the loss NTK calculation module. """ - def test_LossNTKCalculation(self): + def test_reshaping_methods(self): + """ + Test the _reshape_dataset and _unshape_dataset methods. + These are functions used in the loss NTK calculation. 
+ """ + # Define a dummy model and dataset to be able to define a LossNTKCalculation class + production_model = FlaxModel( + flax_module=ProductionModule(), + optimizer=optax.adam(learning_rate=0.01), + input_shape=(1, 28, 28, 1), + trace_axes=(), + ) + + data_generator = MNISTGenerator(ds_size=20) + data_set = { + "inputs": data_generator.train_ds["inputs"], + "targets": data_generator.train_ds["targets"], + } + + # Initialize the loss NTK calculation + loss_ntk_calculator = LossNTKCalculation( + metric_fn=LPNorm(order=2), + model=production_model, + dataset=data_set, + ) + + # Setup a test dataset for reshaping + test_data_set = { + "inputs": np.array([[1, 2, 3], [4, 5, 6]]), + "targets": np.array([[7], [10]]), + } + + # Test the reshaping + reshaped_test_data_set = loss_ntk_calculator._reshape_dataset(test_data_set) + + assert_array_almost_equal( + reshaped_test_data_set, np.array([[1, 2, 3, 7], [4, 5, 6, 10]]) + ) + + # Test the unshaping + input_0, target_0 = loss_ntk_calculator._unshape_data( + reshaped_test_data_set, + input_dimension=3, + input_shape=(2, 3), + target_shape=(2, 1), + batch_length=reshaped_test_data_set.shape[0], + ) + assert_array_almost_equal(input_0, test_data_set["inputs"]) + assert_array_almost_equal(target_0, test_data_set["targets"]) + + def test_function_for_loss_ntk(self): + """ + This method tests the function that is used for the correlation matrix in the loss NTK calculation. + It is supposed to yield the loss per single datapoint.""" + # Define a simple feed forward test model + feed_forward_model = stax.serial( + stax.Dense(5), + stax.Relu(), + stax.Dense(2), + stax.Relu(), + ) + + # Initialize the model + model = NTModel( + optimizer=optax.adam(learning_rate=0.01), + input_shape=(1, 5), + trace_axes=(), + nt_module=feed_forward_model, + ) + + # Define a test dataset with only two datapoints + test_data_set = { + "inputs": np.array([[1, 2, 3, 4, 5], [1, 2, 3, 4, 8]]), + "targets": np.array([[1, 3], [2, 5]]), + } + + # Initialize loss + loss = LPNormLoss(order=2) + # Initialize the loss NTK calculation + loss_ntk_calculator = LossNTKCalculation( + metric_fn=loss, + model=model, + dataset=test_data_set, + ) + + # Calculate the subloss from the NTK first + datapoint = np.array([1, 2, 3, 4, 5, 1]) + subloss_from_NTK = loss_ntk_calculator._function_for_loss_ntk( + model.model_state.params, datapoint + ) + + print(subloss_from_NTK) + + def test_loss_NTK_calculation(self): """ Test the Loss NTK calculation. 
Here we test if the Loss NTK calculated through the neural tangents module is @@ -133,7 +226,7 @@ def test_LossNTKCalculation(self): ) # Assert that the loss NTKs are the same - assert_array_almost_equal(loss_ntk, loss_ntk_2, decimal=4) + assert_array_almost_equal(loss_ntk, loss_ntk_2) -TestLossNTKCalculation().test_LossNTKCalculation() +TestLossNTKCalculation().test_function_for_loss_ntk() diff --git a/znnl/analysis/loss_ntk_calculation.py b/znnl/analysis/loss_ntk_calculation.py index c91d68b..1d7c3e6 100644 --- a/znnl/analysis/loss_ntk_calculation.py +++ b/znnl/analysis/loss_ntk_calculation.py @@ -59,6 +59,7 @@ def __init__( nt.empirical_ntk_fn( f=self._function_for_loss_ntk, trace_axes=(), + vmap_axes=0, ), batch_size=self.ntk_batch_size, store_on_device=self.store_on_device, From cd5644accc0f192a7c8e13f33143dc28cf5eccb9 Mon Sep 17 00:00:00 2001 From: m-sauter Date: Tue, 20 Feb 2024 18:24:36 +0100 Subject: [PATCH 21/35] Added test for eigenvalues, precision is still only e-4 --- .../analysis/test_loss_ntk_calculation.py | 36 ++++++++++++++----- znnl/analysis/loss_ntk_calculation.py | 5 +-- 2 files changed, 31 insertions(+), 10 deletions(-) diff --git a/CI/unit_tests/analysis/test_loss_ntk_calculation.py b/CI/unit_tests/analysis/test_loss_ntk_calculation.py index cebc70c..cf91056 100644 --- a/CI/unit_tests/analysis/test_loss_ntk_calculation.py +++ b/CI/unit_tests/analysis/test_loss_ntk_calculation.py @@ -36,7 +36,7 @@ from neural_tangents import stax from numpy.testing import assert_array_almost_equal -from znnl.analysis import LossDerivative, LossNTKCalculation +from znnl.analysis import LossDerivative, LossNTKCalculation, EigenSpaceAnalysis from znnl.data import MNISTGenerator from znnl.distance_metrics import LPNorm from znnl.loss_functions import LPNormLoss @@ -74,7 +74,7 @@ class TestLossNTKCalculation: def test_reshaping_methods(self): """ Test the _reshape_dataset and _unshape_dataset methods. - These are functions used in the loss NTK calculation. 
+        These are functions used in the loss NTK calculation to reshape the dataset into single datapoint arrays and back.
         """
         # Define a dummy model and dataset to be able to define a LossNTKCalculation class
         production_model = FlaxModel(
@@ -151,18 +151,33 @@ def test_function_for_loss_ntk(self):
         loss = LPNormLoss(order=2)
         # Initialize the loss NTK calculation
         loss_ntk_calculator = LossNTKCalculation(
-            metric_fn=loss,
+            metric_fn=loss.metric,
             model=model,
             dataset=test_data_set,
         )
 
         # Calculate the subloss from the NTK first
-        datapoint = np.array([1, 2, 3, 4, 5, 1])
+        datapoint = loss_ntk_calculator._reshape_dataset(test_data_set)[0:1]
         subloss_from_NTK = loss_ntk_calculator._function_for_loss_ntk(
-            model.model_state.params, datapoint
+            {
+                "params": model.model_state.params,
+                "batch_stats": model.model_state.batch_stats,
+            },
+            datapoint=datapoint,
         )
-        print(subloss_from_NTK)
+
+        # Now calculate subloss manually
+        applied_model = model.apply(
+            {
+                "params": model.model_state.params,
+                "batch_stats": model.model_state.batch_stats,
+            },
+            test_data_set["inputs"][0],
+        )
+        subloss = np.linalg.norm(applied_model - test_data_set["targets"][0], ord=2)
+
+        # Check that the two losses are the same
+        assert np.abs(subloss - subloss_from_NTK) < 1e-5
 
     def test_loss_NTK_calculation(self):
         """
@@ -226,7 +241,12 @@ def test_loss_NTK_calculation(self):
         )
 
         # Assert that the loss NTKs are the same
-        assert_array_almost_equal(loss_ntk, loss_ntk_2)
+        assert_array_almost_equal(loss_ntk, loss_ntk_2, decimal=4)
+
+        calculator1 = EigenSpaceAnalysis(matrix=loss_ntk)
+        calculator2 = EigenSpaceAnalysis(matrix=loss_ntk_2)
+        eigenvalues1 = calculator1.compute_eigenvalues(normalize=False)
+        eigenvalue2 = calculator2.compute_eigenvalues(normalize=False)
 
-TestLossNTKCalculation().test_function_for_loss_ntk()
+        assert_array_almost_equal(eigenvalues1, eigenvalue2, decimal=4)
diff --git a/znnl/analysis/loss_ntk_calculation.py b/znnl/analysis/loss_ntk_calculation.py
index 1d7c3e6..3fb5ad8 100644
--- a/znnl/analysis/loss_ntk_calculation.py
+++ b/znnl/analysis/loss_ntk_calculation.py
@@ -41,10 +41,11 @@ def __init__(
         model: JaxModel,
         dataset: dict,
     ):
-        """Constructor for the loss ntk calculation class."""
+        """Constructor for the loss ntk calculation class.
+
+        Metric fn has to be the Metric, not the Loss!"""
 
         # Set the attributes
-        self.metric_fn = metric_fn
         self.ntk_batch_size = model.ntk_batch_size
         self.store_on_device = model.store_on_device
         self.trace_axes = model.trace_axes
From b4fe246a56deed7bc2fe3a4ee42fe0be9e2f2b16 Mon Sep 17 00:00:00 2001
From: m-sauter
Date: Tue, 20 Feb 2024 18:38:29 +0100
Subject: [PATCH 22/35] Added some docstrings

---
 .../analysis/test_loss_ntk_calculation.py    | 21 ++++--
 znnl/analysis/loss_ntk_calculation.py        | 72 ++++++++++++++++---
 2 files changed, 78 insertions(+), 15 deletions(-)

diff --git a/CI/unit_tests/analysis/test_loss_ntk_calculation.py b/CI/unit_tests/analysis/test_loss_ntk_calculation.py
index cf91056..85f020e 100644
--- a/CI/unit_tests/analysis/test_loss_ntk_calculation.py
+++ b/CI/unit_tests/analysis/test_loss_ntk_calculation.py
@@ -36,12 +36,11 @@
 from neural_tangents import stax
 from numpy.testing import assert_array_almost_equal
 
-from znnl.analysis import LossDerivative, LossNTKCalculation, EigenSpaceAnalysis
+from znnl.analysis import EigenSpaceAnalysis, LossDerivative, LossNTKCalculation
 from znnl.data import MNISTGenerator
 from znnl.distance_metrics import LPNorm
 from znnl.loss_functions import LPNormLoss
 from znnl.models import FlaxModel, NTModel
-from znnl.training_recording import JaxRecorder
 
 
 # Defines a simple CNN module
@@ -68,7 +67,7 @@ def __call__(self, x):
 
 class TestLossNTKCalculation:
     """
-    Test Suite for the loss NTK calculation module.
+    Test Suite for the LossNTKCalculation module.
     """
 
     def test_reshaping_methods(self):
@@ -76,7 +75,8 @@ def test_reshaping_methods(self):
         Test the _reshape_dataset and _unshape_dataset methods.
         These are functions used in the loss NTK calculation to reshape the dataset into single datapoint arrays and back.
         """
-        # Define a dummy model and dataset to be able to define a LossNTKCalculation class
+        # Define a dummy model and dataset to be able to define a
+        # LossNTKCalculation class
         production_model = FlaxModel(
             flax_module=ProductionModule(),
             optimizer=optax.adam(learning_rate=0.01),
@@ -123,8 +123,10 @@ def test_reshaping_methods(self):
 
     def test_function_for_loss_ntk(self):
         """
-        This method tests the function that is used for the correlation matrix in the loss NTK calculation.
-        It is supposed to yield the loss per single datapoint."""
+        This method tests the function that is used for the correlation matrix
+        in the loss NTK calculation. It is supposed to yield the loss per single
+        datapoint.
+        """
         # Define a simple feed forward test model
         feed_forward_model = stax.serial(
             stax.Dense(5),
@@ -185,6 +187,11 @@ def test_loss_NTK_calculation(self):
         Here we test if the Loss NTK calculated through the neural tangents module is
         the same as the Loss NTK calculated with the already implemented NTK and loss
         derivatives.
+        We do this for a small CNN model and the MNIST dataset.
+        We also check if the eigenvalues of the two Loss NTKs are the same.
+
+        The current implementation yields a precision of 1e-4. Whether this is
+        numerical error or a mistake in the implementation is still to be decided.
         """
 
         # Define a test model
         production_model = FlaxModel(
diff --git a/znnl/analysis/loss_ntk_calculation.py b/znnl/analysis/loss_ntk_calculation.py
index 3fb5ad8..fd5d10d 100644
--- a/znnl/analysis/loss_ntk_calculation.py
+++ b/znnl/analysis/loss_ntk_calculation.py
@@ -41,9 +41,25 @@ def __init__(
         model: JaxModel,
         dataset: dict,
     ):
-        """Constructor for the loss ntk calculation class.
+        """
+        Constructor for the loss ntk calculation class.
+
+        Parameters
+        ----------
+
+        metric_fn : Callable
+            The metric function to be used for the loss calculation.
+            !This has to be the metric, not the Loss!
+            If you put in the Loss here you won't get an error but an
+            incorrect result.
+
+        model : JaxModel
+            The model for which to calculate the loss NTK.
 
-        Metric fn has to be the Metric, not the Loss!"""
+        dataset : dict
+            The dataset for which to calculate the loss NTK.
+            The dictionary should contain the keys "inputs" and "targets".
+        """
 
         # Set the attributes
         self.ntk_batch_size = model.ntk_batch_size
@@ -71,6 +87,17 @@ def __init__(
     def _reshape_dataset(dataset):
         """
         Helper function to reshape the dataset for the Loss NTK calculation.
+
+        Parameters
+        ----------
+        dataset : dict
+            The dataset to be reshaped.
+            Should contain the keys "inputs" and "targets".
+
+        Returns
+        -------
+        reshaped_dataset : np.ndarray
+            The reshaped dataset.
         """
         return np.concatenate(
             (
@@ -90,6 +117,24 @@ def _unshape_data(
     ):
         """
         Helper function to unshape the data for the subloss calculation.
+
+        Parameters
+        ----------
+        datapoint : np.ndarray
+            The datapoint to be unshaped.
+        input_dimension : int
+            The total dimension of the input, i.e. the product of its shape.
+        input_shape : tuple
+            The shape of the original input.
+        target_shape : tuple
+            The shape of the original target.
+
+        Returns
+        -------
+        input: np.ndarray
+            The unshaped input.
+        target: np.ndarray
+            The unshaped target.
         """
         return datapoint[:, :input_dimension].reshape(
             batch_length, *input_shape[1:]
@@ -99,12 +144,22 @@ def _function_for_loss_ntk(self, params, datapoint) -> float:
         """
         Helper function to create a subloss apply function.
         The datapoint here has to be shaped so that its an array of length
-        input dimension + output dimension.
-        This is done so that the inputs and targets can be understood
-        by the neural tangents empirical_ntk_fn function.
+        input dimension + output dimension. This is done so that the inputs
+        and targets can be understood by the neural tangents empirical_ntk_fn
+        function. It gets unpacked by the _unshape_data function in here.
+
+        Parameters
+        ----------
+        params : dict
+            The parameters of the model.
+        datapoint : np.ndarray
+            The datapoint for which to calculate the subloss. Shaped as
+            described in the description of this function.
 
-        Seems like during the NTK calculation, this function needs to handle
-        the whole dataset at once instead of just one datapoint.
+        Returns
+        -------
+        subloss : float
+            The subloss for the given datapoint.
""" batch_length = datapoint.shape[0] _input, _target = self._unshape_data( @@ -144,7 +199,8 @@ def compute_loss_ntk( Returns ------- Loss NTK : dict - The Loss NTK matrix for both the empirical and infinite width computation. + The Loss NTK matrix for both the empirical and + infinite width computation. """ x_i = self._reshape_dataset(x_i) From bf7a401ee2334c9b4c2ca3a0943ca3c34110a68e Mon Sep 17 00:00:00 2001 From: m-sauter Date: Tue, 20 Feb 2024 18:40:35 +0100 Subject: [PATCH 23/35] More docstrings --- znnl/training_recording/jax_recording.py | 40 ++++++++++++++---------- 1 file changed, 23 insertions(+), 17 deletions(-) diff --git a/znnl/training_recording/jax_recording.py b/znnl/training_recording/jax_recording.py index 2f9a819..373c76b 100644 --- a/znnl/training_recording/jax_recording.py +++ b/znnl/training_recording/jax_recording.py @@ -271,26 +271,29 @@ def instantiate_recorder(self, data_set: dict = None, overwrite: bool = False): self._index_count = 0 # Check if we need an NTK computation and update the class accordingly - if any([ - "ntk" in self._selected_properties, - "covariance_ntk" in self._selected_properties, - "magnitude_ntk" in self._selected_properties, - "entropy" in self._selected_properties, - "magnitude_entropy" in self._selected_properties, - "magnitude_variance" in self._selected_properties, - "covariance_entropy" in self._selected_properties, - "eigenvalues" in self._selected_properties, - "trace" in self._selected_properties, - ]): + if any( + [ + "ntk" in self._selected_properties, + "covariance_ntk" in self._selected_properties, + "magnitude_ntk" in self._selected_properties, + "entropy" in self._selected_properties, + "magnitude_entropy" in self._selected_properties, + "magnitude_variance" in self._selected_properties, + "covariance_entropy" in self._selected_properties, + "eigenvalues" in self._selected_properties, + "trace" in self._selected_properties, + ] + ): self._compute_ntk = True # Check if we need a loss NTK computation and update the class accordingly - - if any([ - "loss_ntk" in self._selected_properties, - "loss_ntk_eigenvalues" in self._selected_properties, - "loss_ntk_entropy" in self._selected_properties, - ]): + if any( + [ + "loss_ntk" in self._selected_properties, + "loss_ntk_eigenvalues" in self._selected_properties, + "loss_ntk_entropy" in self._selected_properties, + ] + ): self._compute_loss_ntk = True try: self._loss_ntk_calculator = LossNTKCalculation( @@ -299,6 +302,9 @@ def instantiate_recorder(self, data_set: dict = None, overwrite: bool = False): dataset=self._data_set, ) except AttributeError: + # This happens frequently during the instantiation of the recorder. + # As this shouldn't lead to a problem if the loss function is set later, + # before the loss NTK is computed, we just log the issue and continue. logger.info( "Warning: The loss function hasn't been set yet." "Please set it before training." 
From 43cdcb943df66d86eb8b7b7e4209a63890948966 Mon Sep 17 00:00:00 2001 From: m-sauter Date: Tue, 20 Feb 2024 18:43:07 +0100 Subject: [PATCH 24/35] Black formatting --- znnl/training_recording/jax_recording.py | 36 +++++++++++------------- 1 file changed, 16 insertions(+), 20 deletions(-) diff --git a/znnl/training_recording/jax_recording.py b/znnl/training_recording/jax_recording.py index 373c76b..c8f6fd1 100644 --- a/znnl/training_recording/jax_recording.py +++ b/znnl/training_recording/jax_recording.py @@ -271,29 +271,25 @@ def instantiate_recorder(self, data_set: dict = None, overwrite: bool = False): self._index_count = 0 # Check if we need an NTK computation and update the class accordingly - if any( - [ - "ntk" in self._selected_properties, - "covariance_ntk" in self._selected_properties, - "magnitude_ntk" in self._selected_properties, - "entropy" in self._selected_properties, - "magnitude_entropy" in self._selected_properties, - "magnitude_variance" in self._selected_properties, - "covariance_entropy" in self._selected_properties, - "eigenvalues" in self._selected_properties, - "trace" in self._selected_properties, - ] - ): + if any([ + "ntk" in self._selected_properties, + "covariance_ntk" in self._selected_properties, + "magnitude_ntk" in self._selected_properties, + "entropy" in self._selected_properties, + "magnitude_entropy" in self._selected_properties, + "magnitude_variance" in self._selected_properties, + "covariance_entropy" in self._selected_properties, + "eigenvalues" in self._selected_properties, + "trace" in self._selected_properties, + ]): self._compute_ntk = True # Check if we need a loss NTK computation and update the class accordingly - if any( - [ - "loss_ntk" in self._selected_properties, - "loss_ntk_eigenvalues" in self._selected_properties, - "loss_ntk_entropy" in self._selected_properties, - ] - ): + if any([ + "loss_ntk" in self._selected_properties, + "loss_ntk_eigenvalues" in self._selected_properties, + "loss_ntk_entropy" in self._selected_properties, + ]): self._compute_loss_ntk = True try: self._loss_ntk_calculator = LossNTKCalculation( From fd8626f28493710ca37d35edb1b56f8ee4a0eaae Mon Sep 17 00:00:00 2001 From: m-sauter Date: Tue, 20 Feb 2024 18:50:21 +0100 Subject: [PATCH 25/35] More docstrings --- znnl/analysis/loss_ntk_calculation.py | 3 ++ znnl/training_recording/jax_recording.py | 36 +++++++++++++----------- 2 files changed, 23 insertions(+), 16 deletions(-) diff --git a/znnl/analysis/loss_ntk_calculation.py b/znnl/analysis/loss_ntk_calculation.py index fd5d10d..695987c 100644 --- a/znnl/analysis/loss_ntk_calculation.py +++ b/znnl/analysis/loss_ntk_calculation.py @@ -128,6 +128,9 @@ def _unshape_data( The shape of the original input. target_shape : tuple The shape of the original target. + batch_length : int + The length of the batch, i.e. 
the number of actual datapoints + inside the datapoint variable Returns ------- diff --git a/znnl/training_recording/jax_recording.py b/znnl/training_recording/jax_recording.py index c8f6fd1..373c76b 100644 --- a/znnl/training_recording/jax_recording.py +++ b/znnl/training_recording/jax_recording.py @@ -271,25 +271,29 @@ def instantiate_recorder(self, data_set: dict = None, overwrite: bool = False): self._index_count = 0 # Check if we need an NTK computation and update the class accordingly - if any([ - "ntk" in self._selected_properties, - "covariance_ntk" in self._selected_properties, - "magnitude_ntk" in self._selected_properties, - "entropy" in self._selected_properties, - "magnitude_entropy" in self._selected_properties, - "magnitude_variance" in self._selected_properties, - "covariance_entropy" in self._selected_properties, - "eigenvalues" in self._selected_properties, - "trace" in self._selected_properties, - ]): + if any( + [ + "ntk" in self._selected_properties, + "covariance_ntk" in self._selected_properties, + "magnitude_ntk" in self._selected_properties, + "entropy" in self._selected_properties, + "magnitude_entropy" in self._selected_properties, + "magnitude_variance" in self._selected_properties, + "covariance_entropy" in self._selected_properties, + "eigenvalues" in self._selected_properties, + "trace" in self._selected_properties, + ] + ): self._compute_ntk = True # Check if we need a loss NTK computation and update the class accordingly - if any([ - "loss_ntk" in self._selected_properties, - "loss_ntk_eigenvalues" in self._selected_properties, - "loss_ntk_entropy" in self._selected_properties, - ]): + if any( + [ + "loss_ntk" in self._selected_properties, + "loss_ntk_eigenvalues" in self._selected_properties, + "loss_ntk_entropy" in self._selected_properties, + ] + ): self._compute_loss_ntk = True try: self._loss_ntk_calculator = LossNTKCalculation( From b6ebb8d5f0b5991fb8f17e4fcbfe7403a6566fa5 Mon Sep 17 00:00:00 2001 From: m-sauter Date: Wed, 21 Feb 2024 22:37:24 +0100 Subject: [PATCH 26/35] Some PR modifications --- .../analysis/test_loss_ntk_calculation.py | 88 +++++++++---------- znnl/analysis/loss_ntk_calculation.py | 24 +++-- 2 files changed, 58 insertions(+), 54 deletions(-) diff --git a/CI/unit_tests/analysis/test_loss_ntk_calculation.py b/CI/unit_tests/analysis/test_loss_ntk_calculation.py index 85f020e..1876bb3 100644 --- a/CI/unit_tests/analysis/test_loss_ntk_calculation.py +++ b/CI/unit_tests/analysis/test_loss_ntk_calculation.py @@ -36,7 +36,7 @@ from neural_tangents import stax from numpy.testing import assert_array_almost_equal -from znnl.analysis import EigenSpaceAnalysis, LossDerivative, LossNTKCalculation +from znnl.analysis import EigenSpaceAnalysis, LossNTKCalculation from znnl.data import MNISTGenerator from znnl.distance_metrics import LPNorm from znnl.loss_functions import LPNormLoss @@ -100,14 +100,14 @@ def test_reshaping_methods(self): # Setup a test dataset for reshaping test_data_set = { "inputs": np.array([[1, 2, 3], [4, 5, 6]]), - "targets": np.array([[7], [10]]), + "targets": np.array([[7, 9], [10, 12]]), } # Test the reshaping reshaped_test_data_set = loss_ntk_calculator._reshape_dataset(test_data_set) assert_array_almost_equal( - reshaped_test_data_set, np.array([[1, 2, 3, 7], [4, 5, 6, 10]]) + reshaped_test_data_set, np.array([[1, 2, 3, 7, 9], [4, 5, 6, 10, 12]]) ) # Test the unshaping @@ -115,7 +115,7 @@ def test_reshaping_methods(self): reshaped_test_data_set, input_dimension=3, input_shape=(2, 3), - target_shape=(2, 1), + 
target_shape=(2, 2),
             batch_length=reshaped_test_data_set.shape[0],
         )
         assert_array_almost_equal(input_0, test_data_set["inputs"])
@@ -123,9 +123,8 @@ def test_function_for_loss_ntk(self):
         """
-        This method tests the function that is used for the correlation matrix
-        in the loss NTK calculation. It is supposed to yield the loss per single
-        datapoint.
+        This method tests the function that calculates the loss for single
+        datapoints.
         """
         # Define a simple feed forward test model
         feed_forward_model = stax.serial(
             stax.Dense(5),
@@ -185,72 +182,69 @@ def test_loss_NTK_calculation(self):
         """
-        Test the Loss NTK calculation.
-        Here we test if the Loss NTK calculated through the neural tangents module is
-        the same as the Loss NTK calculated with the already implemented NTK and loss
-        derivatives.
-        We do this for a small CNN model and the MNIST dataset.
-        We also check if the eigenvalues of the two Loss NTKs are the same.
-
-        The current implementation yields a precision of 1e-4. Whether this is
-        numerical error or a mistake in the implementation is still to be decided.
+        Test the Loss NTK calculation with a manually hardcoded dataset
+        and a simple feed forward network.
         """
+        # Define a simple feed forward test model
+        feed_forward_model = stax.serial(
+            stax.Dense(5),
+            stax.Relu(),
+            stax.Dense(1),
+            stax.Relu(),
+        )
 
-        # Define a test model
-        production_model = FlaxModel(
-            flax_module=ProductionModule(),
+        # Initialize the model
+        model = NTModel(
             optimizer=optax.adam(learning_rate=0.01),
-            input_shape=(1, 28, 28, 1),
+            input_shape=(2,),
             trace_axes=(),
+            nt_module=feed_forward_model,
         )
 
-        # Initialize model parameters
-        data_generator = MNISTGenerator(ds_size=20)
-        data_set = {
-            "inputs": data_generator.train_ds["inputs"],
-            "targets": data_generator.train_ds["targets"],
-        }
+        # Create a test dataset
+        inputs = np.array(onp.random.rand(10, 2))
+        targets = np.array(100 * onp.random.rand(10, 1))
+        test_data_set = {"inputs": inputs, "targets": targets}
 
+        # Initialize loss
+        loss = LPNormLoss(order=2)
         # Initialize the loss NTK calculation
         loss_ntk_calculator = LossNTKCalculation(
-            metric_fn=LPNorm(order=2),
-            model=production_model,
-            dataset=data_set,
+            metric_fn=loss.metric,
+            model=model,
+            dataset=test_data_set,
         )
 
         # Compute the loss NTK
-        loss_ntk = loss_ntk_calculator.compute_loss_ntk(
-            x_i=data_set, model=production_model
-        )["empirical"]
-
-        # Now for comparison calculate regular ntk
-        ntk = production_model.compute_ntk(data_set["inputs"], infinite=False)[
+        loss_ntk = loss_ntk_calculator.compute_loss_ntk(x_i=test_data_set, model=model)[
             "empirical"
         ]
-        # Calculate Loss derivative fn
-        loss_derivative_calculator = LossDerivative(LPNormLoss(order=2))
+
+        # Now for comparison calculate regular ntk
+        ntk = model.compute_ntk(test_data_set["inputs"], infinite=False)["empirical"]
 
         # predictions calculation analogous to the one in jax recording
-        predictions = production_model(data_set["inputs"])
+        predictions = model(test_data_set["inputs"])
         if type(predictions) is tuple:
             predictions = predictions[0]
 
         # calculation of loss derivatives
         # note: here we need the derivatives of the subloss, not the regular loss fn
-        loss_derivatives = onp.empty(shape=(len(predictions), len(predictions[0])))
+        loss_derivatives = onp.empty(shape=(len(predictions)))
         for i in range(len(loss_derivatives)):
-            # The weird indexing here is because of axis constraints in LPNormLoss
-            loss_derivatives[i] = loss_derivative_calculator.calculate(
-                predictions[i : i + 1], data_set["targets"][i : i + 1]
-            )[0]
+            loss_derivatives[i] = (
+                predictions[i, 0] / np.abs(predictions[i, 0])
+                if predictions[i, 0] != 0
+                else 0
+            )
 
         # Calculate the loss NTK from the loss derivatives and the ntk
         loss_ntk_2 = np.einsum(
-            "ik, jl, ijkl-> ij", loss_derivatives, loss_derivatives, ntk
+            "i, j, ijkl-> ij", loss_derivatives, loss_derivatives, ntk
         )
 
         # Assert that the loss NTKs are the same
-        assert_array_almost_equal(loss_ntk, loss_ntk_2, decimal=4)
+        assert_array_almost_equal(loss_ntk, loss_ntk_2, decimal=6)
 
         calculator1 = EigenSpaceAnalysis(matrix=loss_ntk)
         calculator2 = EigenSpaceAnalysis(matrix=loss_ntk_2)
 
         eigenvalues1 = calculator1.compute_eigenvalues(normalize=False)
         eigenvalue2 = calculator2.compute_eigenvalues(normalize=False)
 
-        assert_array_almost_equal(eigenvalues1, eigenvalue2, decimal=4)
+        assert_array_almost_equal(eigenvalues1, eigenvalue2, decimal=6)
diff --git a/znnl/analysis/loss_ntk_calculation.py b/znnl/analysis/loss_ntk_calculation.py
index 695987c..0b37603 100644
--- a/znnl/analysis/loss_ntk_calculation.py
+++ b/znnl/analysis/loss_ntk_calculation.py
@@ -35,6 +35,17 @@
 
 
 class LossNTKCalculation:
+    """
+    Class to calculate the loss NTK matrix for a given model and dataset.
+
+    The loss NTK matrix is defined as
+    .. math::
+        (\\Lambda_L)_{ij} = \\sum_\\theta\\frac{\\partial\\ell_i}{\\partial\\theta}
+        \\frac{\\partial\\ell_j}{\\partial\\theta}.
+    where :math:`\\ell_i` is the loss function for the single datapoint :math:`i` and
+    :math:`\\theta` are the parameters of the network.
+    """
+
     def __init__(
         self,
         metric_fn: Callable,
@@ -145,19 +156,16 @@ def _function_for_loss_ntk(self, params, datapoint) -> float:
         """
-        Helper function to create a subloss apply function.
-        The datapoint here has to be shaped so that its an array of length
-        input dimension + output dimension. This is done so that the inputs
-        and targets can be understood by the neural tangents empirical_ntk_fn
-        function. It gets unpacked by the _unshape_data function in here.
+        Helper function that calculates the subloss for single datapoints.
 
         Parameters
        ----------
         params : dict
             The parameters of the model.
-        datapoint : np.ndarray
-            The datapoint for which to calculate the subloss. Shaped as
-            described in the description of this function.
+        datapoint : np.ndarray (batch_length, input_dimension + output_dimension)
+            The datapoint for which to calculate the subloss. The indicated
+            shaping is necessary so that this function can be understood by the
+            empirical NTK function of the neural tangents library.
Returns ------- From d17e5149e0e93df1ee89509ff16fca738950f430 Mon Sep 17 00:00:00 2001 From: Marc Sauter Date: Mon, 26 Feb 2024 14:40:30 +0100 Subject: [PATCH 27/35] fixing PR comment --- znnl/training_strategies/simple_training.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/znnl/training_strategies/simple_training.py b/znnl/training_strategies/simple_training.py index eb461d7..707e37c 100644 --- a/znnl/training_strategies/simple_training.py +++ b/znnl/training_strategies/simple_training.py @@ -109,7 +109,8 @@ def __init__( self.review_metric = None - # Add the loss and accuracy function to the recorders and re-instantiate them + # Add the loss and accuracy function and the model + # to the recorders and re-instantiate them if self.recorders is not None: for item in self.recorders: item._model = self.model From eb490430d25180b35d03d7f357ffb12ffce6f81b Mon Sep 17 00:00:00 2001 From: Marc Sauter Date: Mon, 26 Feb 2024 15:07:40 +0100 Subject: [PATCH 28/35] requirements change --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 3050110..cb1c088 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,7 +18,7 @@ plotly flax tqdm pandas -neural-tangents==0.6.4 +neural-tangents tensorflow-datasets isort tensorflow From 53c548dc45b228b832a4b48dd8727f3cda407b76 Mon Sep 17 00:00:00 2001 From: Marc Sauter Date: Mon, 26 Feb 2024 15:21:56 +0100 Subject: [PATCH 29/35] Black formatter changes --- znnl/visualization/tsne_visualizer.py | 110 ++++++++++++++------------ 1 file changed, 58 insertions(+), 52 deletions(-) diff --git a/znnl/visualization/tsne_visualizer.py b/znnl/visualization/tsne_visualizer.py index 856ecdd..d0d613a 100644 --- a/znnl/visualization/tsne_visualizer.py +++ b/znnl/visualization/tsne_visualizer.py @@ -104,45 +104,47 @@ def run_visualization(self): fig_dict["layout"]["xaxis2"] = {"domain": [0.8, 1.0]} fig_dict["layout"]["yaxis2"] = {"anchor": "x2"} fig_dict["layout"]["hovermode"] = "closest" - fig_dict["layout"]["updatemenus"] = [{ - "buttons": [ - { - "args": [ - None, - { - "frame": {"duration": 500, "redraw": False}, - "fromcurrent": True, - "transition": { - "duration": 300, - "easing": "quadratic-in-out", + fig_dict["layout"]["updatemenus"] = [ + { + "buttons": [ + { + "args": [ + None, + { + "frame": {"duration": 500, "redraw": False}, + "fromcurrent": True, + "transition": { + "duration": 300, + "easing": "quadratic-in-out", + }, }, - }, - ], - "label": "Play", - "method": "animate", - }, - { - "args": [ - [None], - { - "frame": {"duration": 0, "redraw": False}, - "mode": "immediate", - "transition": {"duration": 0}, - }, - ], - "label": "Pause", - "method": "animate", - }, - ], - "direction": "left", - "pad": {"r": 10, "t": 87}, - "showactive": False, - "type": "buttons", - "x": 0.1, - "xanchor": "right", - "y": 0, - "yanchor": "top", - }] + ], + "label": "Play", + "method": "animate", + }, + { + "args": [ + [None], + { + "frame": {"duration": 0, "redraw": False}, + "mode": "immediate", + "transition": {"duration": 0}, + }, + ], + "label": "Pause", + "method": "animate", + }, + ], + "direction": "left", + "pad": {"r": 10, "t": 87}, + "showactive": False, + "type": "buttons", + "x": 0.1, + "xanchor": "right", + "y": 0, + "yanchor": "top", + } + ] sliders_dict = { "active": 0, @@ -163,20 +165,24 @@ def run_visualization(self): } # Add initial data - fig_dict["data"].append({ - "x": self.dynamic[0][:, 0], - "y": self.dynamic[0][:, 1], - "mode": "markers", - "name": 
"Predictor", - }) - fig_dict["data"].append({ - "x": self.reference[0][:, 0], - "y": self.reference[0][:, 1], - "mode": "markers", - "xaxis": "x2", - "yaxis": "y2", - "name": "Target", - }) + fig_dict["data"].append( + { + "x": self.dynamic[0][:, 0], + "y": self.dynamic[0][:, 1], + "mode": "markers", + "name": "Predictor", + } + ) + fig_dict["data"].append( + { + "x": self.reference[0][:, 0], + "y": self.reference[0][:, 1], + "mode": "markers", + "xaxis": "x2", + "yaxis": "y2", + "name": "Target", + } + ) # Make the figure frames. for i, item in enumerate(self.dynamic): From 1dac43435ff563c7e89d931906c752d0d4788722 Mon Sep 17 00:00:00 2001 From: Marc Sauter Date: Mon, 26 Feb 2024 17:35:40 +0100 Subject: [PATCH 30/35] changed recorder to use the use_loss_ntk flag --- znnl/training_recording/jax_recording.py | 69 +++--------------------- 1 file changed, 7 insertions(+), 62 deletions(-) diff --git a/znnl/training_recording/jax_recording.py b/znnl/training_recording/jax_recording.py index 373c76b..4ec3cfd 100644 --- a/znnl/training_recording/jax_recording.py +++ b/znnl/training_recording/jax_recording.py @@ -158,17 +158,8 @@ class JaxRecorder: loss_derivative: bool = False _loss_derivative_array: list = None - # Loss NTK - loss_ntk: bool = False - _loss_ntk_array: list = None - - # Loss NTK eigenvalues - loss_ntk_eigenvalues: bool = False - _loss_ntk_eigenvalues_array: list = None - - # Loss NTK entropy - loss_ntk_entropy: bool = False - _loss_ntk_entropy_array: list = None + # Use Loss NTK + use_loss_ntk: bool = False # Class helpers update_rate: int = 1 @@ -195,7 +186,7 @@ def _read_selected_attributes(self): self._selected_properties = [ value for value in list(vars(self)) - if value[0] != "_" and vars(self)[value] is True + if value[0] != "_" and value != "use_loss_ntk" and vars(self)[value] is True ] def _build_or_resize_array(self, name: str, overwrite: bool): @@ -287,14 +278,7 @@ def instantiate_recorder(self, data_set: dict = None, overwrite: bool = False): self._compute_ntk = True # Check if we need a loss NTK computation and update the class accordingly - if any( - [ - "loss_ntk" in self._selected_properties, - "loss_ntk_eigenvalues" in self._selected_properties, - "loss_ntk_entropy" in self._selected_properties, - ] - ): - self._compute_loss_ntk = True + if self.use_loss_ntk: try: self._loss_ntk_calculator = LossNTKCalculation( metric_fn=self._loss_fn.metric, @@ -343,7 +327,7 @@ def update_recorder(self, epoch: int, model: JaxModel): parsed_data["predictions"] = predictions # Compute ntk here to avoid repeated computation. - if self._compute_ntk: + if self._compute_ntk and not self.use_loss_ntk: try: ntk = self._model.compute_ntk( self._data_set["inputs"], infinite=False @@ -365,8 +349,8 @@ def update_recorder(self, epoch: int, model: JaxModel): self._read_selected_attributes() # Compute loss ntk here to avoid repeated computation. - if self._compute_loss_ntk: - parsed_data["loss_ntk"] = self._loss_ntk_calculator.compute_loss_ntk( + if self._compute_ntk and self.use_loss_ntk: + parsed_data["ntk"] = self._loss_ntk_calculator.compute_loss_ntk( x_i=self._data_set, x_j=None, model=self._model, @@ -639,45 +623,6 @@ def _update_loss_derivative(self, parsed_data): loss_derivative = calculate_l_pq_norm(vector_loss_derivative) self._loss_derivative_array.append(loss_derivative) - def _update_loss_ntk(self, parsed_data): - """ - Update the loss ntk array. - - Parameters - ---------- - parsed_data : dict - Data computed before the update to prevent repeated calculations. 
- """ - self._loss_ntk_array.append(parsed_data["loss_ntk"]) - - def _update_loss_ntk_eigenvalues(self, parsed_data): - """ - Update the loss ntk eigenvalue array. - - Parameters - ---------- - parsed_data : dict - Data computed before the update to prevent repeated calculations. - """ - calculator = EigenSpaceAnalysis(matrix=parsed_data["loss_ntk"]) - eigenvalues = calculator.compute_eigenvalues(normalize=False) - self._loss_ntk_eigenvalues_array.append(eigenvalues) - - def _update_loss_ntk_entropy(self, parsed_data): - """ - Update the loss ntk entropy array. - - Parameters - ---------- - parsed_data : dict - Data computed before the update to prevent repeated calculations. - """ - calculator = EntropyAnalysis(matrix=parsed_data["loss_ntk"]) - entropy = calculator.compute_von_neumann_entropy( - effective=False, normalize_eig=True - ) - self._loss_ntk_entropy_array.append(entropy) - def gather_recording(self, selected_properties: list = None) -> dataclass: """ Export a dataclass of used properties. From 422eceb4d34f5de3795383dd2584403eecd406a4 Mon Sep 17 00:00:00 2001 From: Marc Sauter Date: Mon, 26 Feb 2024 17:55:20 +0100 Subject: [PATCH 31/35] Change recorder test for new flag --- CI/unit_tests/training_recording/test_training_recording.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/CI/unit_tests/training_recording/test_training_recording.py b/CI/unit_tests/training_recording/test_training_recording.py index 6218338..cf7ba60 100644 --- a/CI/unit_tests/training_recording/test_training_recording.py +++ b/CI/unit_tests/training_recording/test_training_recording.py @@ -66,9 +66,6 @@ def test_instantiation(self): eigenvalues=True, trace=True, loss_derivative=True, - loss_ntk=True, - loss_ntk_eigenvalues=True, - loss_ntk_entropy=True, ) recorder.instantiate_recorder(data_set=self.dummy_data_set) _exclude_list = [ @@ -78,6 +75,7 @@ def test_instantiation(self): "name", "storage_path", "chunk_size", + "use_loss_ntk", ] for key, val in vars(recorder).items(): if key[0] != "_" and key not in _exclude_list: From b5e8170de495e40249d2741198c1d00ac593bf81 Mon Sep 17 00:00:00 2001 From: Marc Sauter Date: Mon, 26 Feb 2024 18:13:15 +0100 Subject: [PATCH 32/35] removed unneccesary CNN model from loss_ntk calculation test --- .../analysis/test_loss_ntk_calculation.py | 39 ++++++------------- 1 file changed, 12 insertions(+), 27 deletions(-) diff --git a/CI/unit_tests/analysis/test_loss_ntk_calculation.py b/CI/unit_tests/analysis/test_loss_ntk_calculation.py index 1876bb3..ee51cb3 100644 --- a/CI/unit_tests/analysis/test_loss_ntk_calculation.py +++ b/CI/unit_tests/analysis/test_loss_ntk_calculation.py @@ -32,7 +32,6 @@ import jax.numpy as np import numpy as onp import optax -from flax import linen as nn from neural_tangents import stax from numpy.testing import assert_array_almost_equal @@ -43,28 +42,6 @@ from znnl.models import FlaxModel, NTModel -# Defines a simple CNN module -class ProductionModule(nn.Module): - """ - Simple CNN module. - """ - - @nn.compact - def __call__(self, x): - x = nn.Conv(features=16, kernel_size=(3, 3))(x) - x = nn.relu(x) - x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2)) - x = nn.Conv(features=16, kernel_size=(3, 3))(x) - x = nn.relu(x) - x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2)) - x = x.reshape((x.shape[0], -1)) # flatten - x = nn.Dense(features=10)(x) - x = nn.relu(x) - x = nn.Dense(10)(x) - - return x - - class TestLossNTKCalculation: """ Test Suite for the LossNTKCalculation module. 
@@ -77,11 +54,19 @@ def test_reshaping_methods(self): """ # Define a dummy model and dataset to be able to define a # LossNTKCalculation class - production_model = FlaxModel( - flax_module=ProductionModule(), + feed_forward_model = stax.serial( + stax.Dense(5), + stax.Relu(), + stax.Dense(2), + stax.Relu(), + ) + + # Initialize the model + test_model = NTModel( optimizer=optax.adam(learning_rate=0.01), - input_shape=(1, 28, 28, 1), + input_shape=(1, 5), trace_axes=(), + nt_module=feed_forward_model, ) data_generator = MNISTGenerator(ds_size=20) @@ -93,7 +78,7 @@ def test_reshaping_methods(self): # Initialize the loss NTK calculation loss_ntk_calculator = LossNTKCalculation( metric_fn=LPNorm(order=2), - model=production_model, + model=test_model, dataset=data_set, ) From 594f4cb14e92b1ef2acbf2f53ff62c17703e7034 Mon Sep 17 00:00:00 2001 From: Marc Sauter Date: Mon, 26 Feb 2024 18:48:34 +0100 Subject: [PATCH 33/35] Started integration test for loss_ntk_calculation --- .../test_loss_ntk_recording_deployment.py | 104 ++++++++++++++++++ 1 file changed, 104 insertions(+) create mode 100644 CI/integration_tests/training_recording/test_loss_ntk_recording_deployment.py diff --git a/CI/integration_tests/training_recording/test_loss_ntk_recording_deployment.py b/CI/integration_tests/training_recording/test_loss_ntk_recording_deployment.py new file mode 100644 index 0000000..2e9097a --- /dev/null +++ b/CI/integration_tests/training_recording/test_loss_ntk_recording_deployment.py @@ -0,0 +1,104 @@ +""" +ZnNL: A Zincwarecode package. + +License +------- +This program and the accompanying materials are made available under the terms +of the Eclipse Public License v2.0 which accompanies this distribution, and is +available at https://www.eclipse.org/legal/epl-v20.html + +SPDX-License-Identifier: EPL-2.0 + +Copyright Contributors to the Zincwarecode Project. + +Contact Information +------------------- +email: zincwarecode@gmail.com +github: https://github.com/zincware +web: https://zincwarecode.com/ + +Citation +-------- +If you use this module please cite us with: + +Summary +------- +""" + +import os + +os.environ["CUDA_VISIBLE_DEVICES"] = "-1" + +import numpy as np +from numpy.testing import assert_array_almost_equal +import optax +from neural_tangents import stax + +from znnl.loss_functions import LPNormLoss +from znnl.models import NTModel +from znnl.training_recording import JaxRecorder +from znnl.training_strategies import SimpleTraining + + +class TestLossNTKRecorderDeployment: + """ + Test suite for the loss and NTK recorder. + """ + + @classmethod + def setup_class(cls): + """ + Create a model and data for the tests. + """ + + network = stax.serial(stax.Dense(5), stax.Relu(), stax.Dense(1), stax.Relu()) + cls.model = NTModel( + nt_module=network, input_shape=(5,), optimizer=optax.adam(1e-3) + ) + + cls.data_set = { + "inputs": np.random.rand(10, 5), + "targets": np.random.randint(0, 2, 10), + } + + cls.ntk_recorder = JaxRecorder( + name="ntk_recorder", + ntk=True, + update_rate=1, + ) + cls.loss_ntk_recorder = JaxRecorder( + name="loss_ntk_recorder", + ntk=True, + use_loss_ntk=True, + update_rate=1, + ) + + cls.ntk_recorder.instantiate_recorder(data_set=cls.data_set) + cls.loss_ntk_recorder.instantiate_recorder(data_set=cls.data_set) + + cls.trainer = SimpleTraining( + model=cls.model, + loss_fn=LPNormLoss(order=2), + recorders=[cls.ntk_recorder, cls.loss_ntk_recorder], + ) + + def test_loss_ntk_deployment(self): + """ + Test the deployment of the loss_NTK recorder. 
+ """ + + # train the model + training_metrics = self.trainer.train_model( + train_ds=self.data_set, + test_ds=self.data_set, + epochs=10, + batch_size=2, + ) + + # gather the recording + ntk_recording = self.ntk_recorder.gather_recording() + loss_ntk_recording = self.loss_ntk_recorder.gather_recording() + + # assert_array_almost_equal( + # np.abs(ntk_recording.ntk), np.abs(loss_ntk_recording.ntk), decimal=5 + # ) From b55fc309925e6e9ffc4715c83b26b24d5812e532 Mon Sep 17 00:00:00 2001 From: Marc Sauter Date: Tue, 27 Feb 2024 12:20:00 +0100 Subject: [PATCH 34/35] Implemented integration test --- .../test_loss_ntk_recording_deployment.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/CI/integration_tests/training_recording/test_loss_ntk_recording_deployment.py b/CI/integration_tests/training_recording/test_loss_ntk_recording_deployment.py index 2e9097a..33a4c60 100644 --- a/CI/integration_tests/training_recording/test_loss_ntk_recording_deployment.py +++ b/CI/integration_tests/training_recording/test_loss_ntk_recording_deployment.py @@ -51,14 +51,16 @@ def setup_class(cls): Create a model and data for the tests. """ - network = stax.serial(stax.Dense(5), stax.Relu(), stax.Dense(1), stax.Relu()) + network = stax.serial( + stax.Dense(10), stax.Relu(), stax.Dense(10), stax.Relu(), stax.Dense(1) + ) cls.model = NTModel( nt_module=network, input_shape=(5,), optimizer=optax.adam(1e-3) ) cls.data_set = { "inputs": np.random.rand(10, 5), - "targets": np.random.randint(0, 2, 10), + "targets": np.random.randint(0, 2, (10, 1)), } cls.ntk_recorder = JaxRecorder( @@ -99,6 +101,8 @@ def test_loss_ntk_deployment(self): ntk_recording = self.ntk_recorder.gather_recording() loss_ntk_recording = self.loss_ntk_recorder.gather_recording() - # assert_array_almost_equal( - # np.abs(ntk_recording.ntk), np.abs(loss_ntk_recording.ntk), decimal=5 - # ) + # For LPNormLoss of order 2 and a 1D output Network, the NTK and the loss NTK + # should be the same up to a factor of +1 or -1. + assert_array_almost_equal( + np.abs(ntk_recording.ntk), np.abs(loss_ntk_recording.ntk), decimal=5 + ) From 080099ecae3136511ce8d504feb400df31800928 Mon Sep 17 00:00:00 2001 From: Marc Sauter Date: Tue, 27 Feb 2024 12:26:34 +0100 Subject: [PATCH 35/35] isort --- .../training_recording/test_loss_ntk_recording_deployment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CI/integration_tests/training_recording/test_loss_ntk_recording_deployment.py b/CI/integration_tests/training_recording/test_loss_ntk_recording_deployment.py index 33a4c60..d0032bb 100644 --- a/CI/integration_tests/training_recording/test_loss_ntk_recording_deployment.py +++ b/CI/integration_tests/training_recording/test_loss_ntk_recording_deployment.py @@ -30,9 +30,9 @@ os.environ["CUDA_VISIBLE_DEVICES"] = "-1" import numpy as np -from numpy.testing import assert_array_almost_equal import optax from neural_tangents import stax +from numpy.testing import assert_array_almost_equal from znnl.loss_functions import LPNormLoss from znnl.models import NTModel