From d739dae7966aa5df7cec458a31b17158c4aaea63 Mon Sep 17 00:00:00 2001
From: knikolaou <>
Date: Thu, 16 May 2024 14:34:46 +0200
Subject: [PATCH] Create tests for the jax ntk computation

---
 CI/unit_tests/ntk_computation/test_jax_ntk.py | 171 ++++++++++++++++++
 znnl/ntk_computation/jax_ntk.py               |  39 +++-
 2 files changed, 205 insertions(+), 5 deletions(-)
 create mode 100644 CI/unit_tests/ntk_computation/test_jax_ntk.py

diff --git a/CI/unit_tests/ntk_computation/test_jax_ntk.py b/CI/unit_tests/ntk_computation/test_jax_ntk.py
new file mode 100644
index 0000000..e9ea32e
--- /dev/null
+++ b/CI/unit_tests/ntk_computation/test_jax_ntk.py
@@ -0,0 +1,171 @@
+"""
+ZnNL: A Zincwarecode package.
+
+License
+-------
+This program and the accompanying materials are made available under the terms
+of the Eclipse Public License v2.0 which accompanies this distribution, and is
+available at https://www.eclipse.org/legal/epl-v20.html
+
+SPDX-License-Identifier: EPL-2.0
+
+Copyright Contributors to the Zincwarecode Project.
+
+Contact Information
+-------------------
+email: zincwarecode@gmail.com
+github: https://github.com/zincware
+web: https://zincwarecode.com/
+
+Citation
+--------
+If you use this module please cite us with:
+
+Summary
+-------
+Unit tests for the JAX NTK computation class.
+"""
+
+import jax.numpy as np
+import neural_tangents as nt
+import optax
+from flax import linen as nn
+from jax import random
+
+from znnl.models import FlaxModel
+from znnl.ntk_computation import JAXNTKComputation
+
+
+class FlaxTestModule(nn.Module):
+    """
+    Test model for the Flax tests.
+    """
+
+    @nn.compact
+    def __call__(self, x):
+        x = nn.Dense(5, use_bias=True)(x)
+        x = nn.relu(x)
+        x = nn.Dense(features=2, use_bias=True)(x)
+        return x
+
+
+class TestJAXNTKComputation:
+    """
+    Test class for the JAX NTK computation class.
+    """
+
+    @classmethod
+    def setup_class(cls):
+        """
+        Set up the test class.
+        """
+        cls.flax_model = FlaxModel(
+            flax_module=FlaxTestModule(),
+            optimizer=optax.adam(learning_rate=0.001),
+            input_shape=(8,),
+            seed=17,
+        )
+
+        cls.dataset = {
+            "inputs": random.normal(random.PRNGKey(0), (10, 8)),
+            "targets": random.normal(random.PRNGKey(1), (10, 2)),
+        }
+
+    def test_constructor(self):
+        """
+        Test the constructor of the JAX NTK computation class.
+        """
+        apply_fn = lambda x: x
+        batch_size = 10
+        ntk_implementation = None
+        trace_axes = ()
+        store_on_device = False
+        flatten = True
+
+        jax_ntk_computation = JAXNTKComputation(
+            apply_fn=apply_fn,
+            batch_size=batch_size,
+            ntk_implementation=ntk_implementation,
+            trace_axes=trace_axes,
+            store_on_device=store_on_device,
+            flatten=flatten,
+        )
+
+        assert jax_ntk_computation.apply_fn == apply_fn
+        assert jax_ntk_computation.batch_size == batch_size
+        assert jax_ntk_computation.trace_axes == trace_axes
+        assert jax_ntk_computation.store_on_device == store_on_device
+        assert jax_ntk_computation.flatten == flatten
+
+        # Default ntk_implementation should be NTK_VECTOR_PRODUCTS
+        assert (
+            jax_ntk_computation.ntk_implementation
+            == nt.NtkImplementation.NTK_VECTOR_PRODUCTS
+        )
+
+        # Test the default trace_axes
+        jax_ntk_computation = JAXNTKComputation(
+            apply_fn=apply_fn,
+            batch_size=batch_size,
+            ntk_implementation=ntk_implementation,
+            store_on_device=store_on_device,
+            flatten=flatten,
+        )
+
+        assert jax_ntk_computation.trace_axes == ()
+
+    def test_check_shape(self):
+        """
+        Test the shape checking function.
+        """
+        jax_ntk_computation = JAXNTKComputation(apply_fn=self.flax_model.ntk_apply_fn)
+
+        ntk = np.ones((10, 10, 3, 3))
+        ntk_ = jax_ntk_computation._check_shape(ntk)
+
+        assert ntk_.shape == (30, 30)
+
+    def test_compute_ntk(self):
+        """
+        Test the computation of the NTK.
+        """
+        params = {"params": self.flax_model.model_state.params}
+
+        # Trace axes is empty and flatten is True
+        jax_ntk_computation = JAXNTKComputation(
+            apply_fn=self.flax_model.ntk_apply_fn,
+            trace_axes=(),
+            flatten=True,
+        )
+        ntk = jax_ntk_computation.compute_ntk(params, self.dataset["inputs"])
+        assert np.shape(ntk) == (1, 20, 20)
+
+        # Trace axes is empty and flatten is False
+        jax_ntk_computation = JAXNTKComputation(
+            apply_fn=self.flax_model.ntk_apply_fn,
+            trace_axes=(),
+            flatten=False,
+        )
+        ntk = jax_ntk_computation.compute_ntk(params, self.dataset["inputs"])
+
+        assert np.shape(ntk) == (1, 10, 10, 2, 2)
+
+        # Trace axes is (-1,) and flatten is True
+        jax_ntk_computation = JAXNTKComputation(
+            apply_fn=self.flax_model.ntk_apply_fn,
+            trace_axes=(-1,),
+            flatten=True,
+        )
+        ntk = jax_ntk_computation.compute_ntk(params, self.dataset["inputs"])
+
+        assert np.shape(ntk) == (1, 10, 10)
+
+        # Trace axes is (-1,) and flatten is False
+        jax_ntk_computation = JAXNTKComputation(
+            apply_fn=self.flax_model.ntk_apply_fn,
+            trace_axes=(-1,),
+            flatten=False,
+        )
+        ntk = jax_ntk_computation.compute_ntk(params, self.dataset["inputs"])
+
+        assert np.shape(ntk) == (1, 10, 10)
diff --git a/znnl/ntk_computation/jax_ntk.py b/znnl/ntk_computation/jax_ntk.py
index 80e9b7d..27ef8d7 100644
--- a/znnl/ntk_computation/jax_ntk.py
+++ b/znnl/ntk_computation/jax_ntk.py
@@ -29,6 +29,7 @@
 
 import jax.numpy as np
 import neural_tangents as nt
+from papyrus.utils.matrix_utils import flatten_rank_4_tensor
 
 
 class JAXNTKComputation:
@@ -44,6 +45,7 @@ def __init__(
         ntk_implementation: nt.NtkImplementation = None,
         trace_axes: tuple = (),
         store_on_device: bool = False,
+        flatten: bool = True,
     ):
         """
         Constructor the JAX NTK computation class.
@@ -80,29 +82,54 @@ def apply_fn(params, x):
         store_on_device : bool, default True
                 Whether to store the NTK on the device or not.
                 This should be set False for large NTKs that do not fit in GPU memory.
+        flatten : bool, default True
+                If True, the NTK shape is checked and flattened into a 2D matrix, if
+                required.
         """
         self.apply_fn = apply_fn
         self.batch_size = batch_size
         self.ntk_implementation = ntk_implementation
         self.trace_axes = trace_axes
         self.store_on_device = store_on_device
+        self.flatten = flatten
+
+        self._ntk_shape = None
 
         # Prepare NTK calculation
-        if not ntk_implementation:
+        if self.ntk_implementation is None:
             if trace_axes == ():
-                ntk_implementation = nt.NtkImplementation.NTK_VECTOR_PRODUCTS
+                self.ntk_implementation = nt.NtkImplementation.NTK_VECTOR_PRODUCTS
             else:
-                ntk_implementation = nt.NtkImplementation.JACOBIAN_CONTRACTION
+                self.ntk_implementation = nt.NtkImplementation.JACOBIAN_CONTRACTION
         self.empirical_ntk = nt.batch(
             nt.empirical_ntk_fn(
                 f=apply_fn,
                 trace_axes=trace_axes,
-                implementation=ntk_implementation,
+                implementation=self.ntk_implementation,
             ),
             batch_size=batch_size,
             store_on_device=store_on_device,
         )
 
+    def _check_shape(self, ntk: np.ndarray) -> np.ndarray:
+        """
+        Check the shape of the NTK matrix and flatten it if required.
+
+        Parameters
+        ----------
+        ntk : np.ndarray
+                The NTK matrix.
+
+        Returns
+        -------
+        np.ndarray
+                The (possibly flattened) NTK matrix.
+        """
+        self._ntk_shape = ntk.shape
+        if self.flatten and len(self._ntk_shape) > 2:
+            ntk, _ = flatten_rank_4_tensor(ntk)
+        return ntk
+
     def compute_ntk(
         self, params: dict, x_i: np.ndarray, x_j: Optional[np.ndarray] = None
     ) -> List[np.ndarray]:
@@ -121,4 +148,6 @@ def compute_ntk(
         List[np.ndarray]
                 The NTK matrix.
         """
-        return [self.empirical_ntk(x_i, x_j, params)]
+        ntk = self.empirical_ntk(x_i, x_j, params)
+        ntk = self._check_shape(ntk)
+        return [ntk]