From d9bb28262b9a7753685d9b5baaee73a9c723920e Mon Sep 17 00:00:00 2001 From: knikolaou <> Date: Thu, 2 May 2024 20:21:23 +0200 Subject: [PATCH] run black --- znnl/training_recording/jax_recording.py | 10 +++++----- znnl/utils/matrix_utils.py | 1 + 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/znnl/training_recording/jax_recording.py b/znnl/training_recording/jax_recording.py index 8d0c916..0efcff9 100644 --- a/znnl/training_recording/jax_recording.py +++ b/znnl/training_recording/jax_recording.py @@ -29,6 +29,7 @@ from dataclasses import dataclass, make_dataclass from os import path from pathlib import Path +from typing import Optional import jax.numpy as np import numpy as onp @@ -41,12 +42,11 @@ from znnl.models.jax_model import JaxModel from znnl.training_recording.data_storage import DataStorage from znnl.utils.matrix_utils import ( + calculate_trace, compute_magnitude_density, flatten_rank_4_tensor, normalize_gram_matrix, - calculate_trace, ) -from typing import Optional logger = logging.getLogger(__name__) @@ -181,7 +181,7 @@ class JaxRecorder: _loss_derivative_fn: LossDerivative = False _index_count: int = 0 # Helps to avoid problems with non-1 update rates. _data_storage: DataStorage = None # For writing to disk. - _ntk_rank: Optional[int] = None # Rank of the NTK matrix. + _ntk_rank: Optional[int] = None # Rank of the NTK matrix. def _read_selected_attributes(self): """ @@ -320,7 +320,7 @@ def update_recorder(self, epoch: int, model: JaxModel): ntk = self._model.compute_ntk( self._data_set["inputs"], infinite=False )["empirical"] - self._ntk_rank = len(ntk.shape) + self._ntk_rank = len(ntk.shape) if self.flatten_ntk and self._ntk_rank == 4: ntk = flatten_rank_4_tensor(ntk) parsed_data["ntk"] = ntk @@ -591,7 +591,7 @@ def _update_trace(self, parsed_data: dict): """ Update the trace of the NTK. - The trace of the NTK is computed as the mean of the diagonal elements of the + The trace of the NTK is computed as the mean of the diagonal elements of the NTK. Parameters diff --git a/znnl/utils/matrix_utils.py b/znnl/utils/matrix_utils.py index 95819c7..496fd8a 100644 --- a/znnl/utils/matrix_utils.py +++ b/znnl/utils/matrix_utils.py @@ -168,6 +168,7 @@ def flatten_rank_4_tensor(tensor: np.ndarray) -> np.ndarray: _tensor.shape[0] * _tensor.shape[1], _tensor.shape[0] * _tensor.shape[1] ) + def calculate_trace(matrix: np.ndarray, normalize: bool = False) -> np.ndarray: """ Calculate the trace of a matrix, including optional normalization.