From 768fc13bb5ca704b60c196fb542f19ac54f7eebf Mon Sep 17 00:00:00 2001 From: m-sauter Date: Thu, 18 Jan 2024 19:46:55 +0100 Subject: [PATCH] quick fix --- znnl/analysis/loss_ntk_calculation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/znnl/analysis/loss_ntk_calculation.py b/znnl/analysis/loss_ntk_calculation.py index 571944a..f6c02c7 100644 --- a/znnl/analysis/loss_ntk_calculation.py +++ b/znnl/analysis/loss_ntk_calculation.py @@ -26,7 +26,7 @@ """ import neural_tangents as nt -from typing import Callable +from typing import Callable, Union, Sequence from znnl.models.jax_model import JaxModel import jax import jax.numpy as np