Skip to content

Commit

Permalink
Play with optimizer options
Browse files Browse the repository at this point in the history
  • Loading branch information
SamTov committed Sep 26, 2023
1 parent b093a08 commit ab13509
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions znnl/models/jax_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,11 @@ def __init__(
self.init_model(seed)

# Prepare NTK calculation
self.empirical_ntk = nt.empirical_ntk_fn(
self.empirical_ntk = nt.batch(nt.empirical_ntk_fn(
f=self._ntk_apply_fn, trace_axes=trace_axes
)
), batch_size=ntk_batch_size)

self.empirical_ntk_jit = jax.jit(self.empirical_ntk)
self.empirical_ntk_jit = self.empirical_ntk

def init_model(
self,
Expand Down

0 comments on commit ab13509

Please sign in to comment.