From f64bcfa986adbdddbb367767d45569ba397d47ac Mon Sep 17 00:00:00 2001 From: knikolaou <> Date: Fri, 5 Apr 2024 10:47:33 +0200 Subject: [PATCH] Remove jitting from ntk calculation --- znnl/models/jax_model.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/znnl/models/jax_model.py b/znnl/models/jax_model.py index 8e8ea95..9dee179 100644 --- a/znnl/models/jax_model.py +++ b/znnl/models/jax_model.py @@ -127,7 +127,6 @@ def __init__( batch_size=ntk_batch_size, store_on_device=store_on_device, ) - self.empirical_ntk_jit = jax.jit(self.empirical_ntk) self.apply_jit = jax.jit(self.apply) def init_model( @@ -249,7 +248,7 @@ def compute_ntk( """ if x_j is None: x_j = x_i - empirical_ntk = self.empirical_ntk_jit( + empirical_ntk = self.empirical_ntk( x_i, x_j, {