diff --git a/znnl/models/jax_model.py b/znnl/models/jax_model.py index 5bde9cc..6b18cef 100644 --- a/znnl/models/jax_model.py +++ b/znnl/models/jax_model.py @@ -81,11 +81,11 @@ def __init__( self.init_model(seed) # Prepare NTK calculation - self.empirical_ntk = nt.empirical_ntk_fn( + self.empirical_ntk = nt.batch(nt.empirical_ntk_fn( f=self._ntk_apply_fn, trace_axes=trace_axes - ) + ), batch_size=ntk_batch_size) - self.empirical_ntk_jit = jax.jit(self.empirical_ntk) + self.empirical_ntk_jit = self.empirical_ntk def init_model( self,