diff --git a/znnl/analysis/loss_ntk_calculation.py b/znnl/analysis/loss_ntk_calculation.py index 571944a..f6c02c7 100644 --- a/znnl/analysis/loss_ntk_calculation.py +++ b/znnl/analysis/loss_ntk_calculation.py @@ -26,7 +26,7 @@ """ import neural_tangents as nt -from typing import Callable +from typing import Callable, Union, Sequence from znnl.models.jax_model import JaxModel import jax import jax.numpy as np