Skip to content

Commit

Permalink
Remove jitting from ntk calculation (#114)
Browse files Browse the repository at this point in the history
  • Loading branch information
KonstiNik authored Apr 8, 2024
1 parent cc82976 commit a60a80b
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions znnl/models/jax_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,6 @@ def __init__(
batch_size=ntk_batch_size,
store_on_device=store_on_device,
)
self.empirical_ntk_jit = jax.jit(self.empirical_ntk)
self.apply_jit = jax.jit(self.apply)

def init_model(
Expand Down Expand Up @@ -249,7 +248,7 @@ def compute_ntk(
"""
if x_j is None:
x_j = x_i
empirical_ntk = self.empirical_ntk_jit(
empirical_ntk = self.empirical_ntk(
x_i,
x_j,
{
Expand Down

0 comments on commit a60a80b

Please sign in to comment.