Skip to content

Commit

Permalink
Create tests for the jax ntk computation
Browse files Browse the repository at this point in the history
  • Loading branch information
knikolaou committed May 16, 2024
1 parent 226cea1 commit d739dae
Show file tree
Hide file tree
Showing 2 changed files with 205 additions and 5 deletions.
171 changes: 171 additions & 0 deletions CI/unit_tests/ntk_computation/test_jax_ntk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
"""
ZnNL: A Zincwarecode package.
License
-------
This program and the accompanying materials are made available under the terms
of the Eclipse Public License v2.0 which accompanies this distribution, and is
available at https://www.eclipse.org/legal/epl-v20.html
SPDX-License-Identifier: EPL-2.0
Copyright Contributors to the Zincwarecode Project.
Contact Information
-------------------
email: zincwarecode@gmail.com
github: https://github.com/zincware
web: https://zincwarecode.com/
Citation
--------
If you use this module please cite us with:
Summary
-------
"""

import jax.numpy as np
import neural_tangents as nt
import optax
from flax import linen as nn
from jax import random

from znnl.models import FlaxModel
from znnl.ntk_computation import JAXNTKComputation


class FlaxTestModule(nn.Module):
"""
Test model for the Flax tests.
"""

@nn.compact
def __call__(self, x):
x = nn.Dense(5, use_bias=True)(x)
x = nn.relu(x)
x = nn.Dense(features=2, use_bias=True)(x)
return x


class TestJAXNTKComputation:
"""
Test class for the JAX NTK computation class.
"""

@classmethod
def setup_class(cls):
"""
Setup the test class.
"""
cls.flax_model = FlaxModel(
flax_module=FlaxTestModule(),
optimizer=optax.adam(learning_rate=0.001),
input_shape=(8,),
seed=17,
)

cls.dataset = {
"inputs": random.normal(random.PRNGKey(0), (10, 8)),
"targets": random.normal(random.PRNGKey(1), (10, 2)),
}

def test_constructor(self):
"""
Test the constructor of the JAX NTK computation class.
"""
apply_fn = lambda x: x
batch_size = 10
ntk_implementation = None
trace_axes = ()
store_on_device = False
flatten = True

jax_ntk_computation = JAXNTKComputation(
apply_fn=apply_fn,
batch_size=batch_size,
ntk_implementation=ntk_implementation,
trace_axes=trace_axes,
store_on_device=store_on_device,
flatten=flatten,
)

assert jax_ntk_computation.apply_fn == apply_fn
assert jax_ntk_computation.batch_size == batch_size
assert jax_ntk_computation.trace_axes == trace_axes
assert jax_ntk_computation.store_on_device == store_on_device
assert jax_ntk_computation.flatten == flatten

# Default ntk_implementation should be NTK_VECTOR_PRODUCTS
assert (
jax_ntk_computation.ntk_implementation
== nt.NtkImplementation.NTK_VECTOR_PRODUCTS
)

# Test the default trace_axes
jax_ntk_computation = JAXNTKComputation(
apply_fn=apply_fn,
batch_size=batch_size,
ntk_implementation=ntk_implementation,
store_on_device=store_on_device,
flatten=flatten,
)

assert jax_ntk_computation.trace_axes == ()

def test_check_shape(self):
"""
Test the shape checking function.
"""
jax_ntk_computation = JAXNTKComputation(apply_fn=self.flax_model.ntk_apply_fn)

ntk = np.ones((10, 10, 3, 3))
ntk_ = jax_ntk_computation._check_shape(ntk)

assert ntk_.shape == (30, 30)

def test_compute_ntk(self):
"""
Test the computation of the NTK.
"""
params = {"params": self.flax_model.model_state.params}

# Trace axes is empty and flatten is True
trace_axes = ()
jax_ntk_computation = JAXNTKComputation(
apply_fn=self.flax_model.ntk_apply_fn,
trace_axes=(),
flatten=True,
)
ntk = jax_ntk_computation.compute_ntk(params, self.dataset["inputs"])
assert np.shape(ntk) == (1, 20, 20)

# Trace axes is empty and flatten is False
jax_ntk_computation = JAXNTKComputation(
apply_fn=self.flax_model.ntk_apply_fn,
trace_axes=(),
flatten=False,
)
ntk = jax_ntk_computation.compute_ntk(params, self.dataset["inputs"])

assert np.shape(ntk) == (1, 10, 10, 2, 2)

# Trace axes is (-1,) and flatten is True
jax_ntk_computation = JAXNTKComputation(
apply_fn=self.flax_model.ntk_apply_fn,
trace_axes=(-1,),
flatten=True,
)
ntk = jax_ntk_computation.compute_ntk(params, self.dataset["inputs"])

assert np.shape(ntk) == (1, 10, 10)

# Trace axes is (-1,) and flatten is False
jax_ntk_computation = JAXNTKComputation(
apply_fn=self.flax_model.ntk_apply_fn,
trace_axes=(-1,),
flatten=False,
)
ntk = jax_ntk_computation.compute_ntk(params, self.dataset["inputs"])

assert np.shape(ntk) == (1, 10, 10)
39 changes: 34 additions & 5 deletions znnl/ntk_computation/jax_ntk.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

import jax.numpy as np
import neural_tangents as nt
from papyrus.utils.matrix_utils import flatten_rank_4_tensor


class JAXNTKComputation:
Expand All @@ -44,6 +45,7 @@ def __init__(
ntk_implementation: nt.NtkImplementation = None,
trace_axes: tuple = (),
store_on_device: bool = False,
flatten: bool = True,
):
"""
Constructor the JAX NTK computation class.
Expand Down Expand Up @@ -80,29 +82,54 @@ def apply_fn(params, x):
store_on_device : bool, default True
Whether to store the NTK on the device or not.
This should be set False for large NTKs that do not fit in GPU memory.
flatten : bool, default True
If True, the NTK shape is checked and flattened into a 2D matrix, if
required.
"""
self.apply_fn = apply_fn
self.batch_size = batch_size
self.ntk_implementation = ntk_implementation
self.trace_axes = trace_axes
self.store_on_device = store_on_device
self.flatten = flatten

self._ntk_shape = None

# Prepare NTK calculation
if not ntk_implementation:
if self.ntk_implementation is None:
if trace_axes == ():
ntk_implementation = nt.NtkImplementation.NTK_VECTOR_PRODUCTS
self.ntk_implementation = nt.NtkImplementation.NTK_VECTOR_PRODUCTS
else:
ntk_implementation = nt.NtkImplementation.JACOBIAN_CONTRACTION
self.ntk_implementation = nt.NtkImplementation.JACOBIAN_CONTRACTION
self.empirical_ntk = nt.batch(
nt.empirical_ntk_fn(
f=apply_fn,
trace_axes=trace_axes,
implementation=ntk_implementation,
implementation=self.ntk_implementation,
),
batch_size=batch_size,
store_on_device=store_on_device,
)

def _check_shape(self, ntk: np.ndarray) -> np.ndarray:
"""
Check the shape of the NTK matrix and flatten it if required.
Parameters
----------
ntk : np.ndarray
The NTK matrix.
Returns
-------
np.ndarray
The NTK matrix.
"""
self._ntk_shape = ntk.shape
if self.flatten and len(self._ntk_shape) > 2:
ntk, _ = flatten_rank_4_tensor(ntk)
return ntk

def compute_ntk(
self, params: dict, x_i: np.ndarray, x_j: Optional[np.ndarray] = None
) -> List[np.ndarray]:
Expand All @@ -121,4 +148,6 @@ def compute_ntk(
List[np.ndarray]
The NTK matrix.
"""
return [self.empirical_ntk(x_i, x_j, params)]
ntk = self.empirical_ntk(x_i, x_j, params)
ntk = self._check_shape(ntk)
return [ntk]

0 comments on commit d739dae

Please sign in to comment.