Skip to content

Commit

Permalink
run black
Browse files Browse the repository at this point in the history
  • Loading branch information
knikolaou committed May 2, 2024
1 parent 2f17a51 commit d9bb282
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
10 changes: 5 additions & 5 deletions znnl/training_recording/jax_recording.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from dataclasses import dataclass, make_dataclass
from os import path
from pathlib import Path
from typing import Optional

import jax.numpy as np
import numpy as onp
Expand All @@ -41,12 +42,11 @@
from znnl.models.jax_model import JaxModel
from znnl.training_recording.data_storage import DataStorage
from znnl.utils.matrix_utils import (
calculate_trace,
compute_magnitude_density,
flatten_rank_4_tensor,
normalize_gram_matrix,
calculate_trace,
)
from typing import Optional

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -181,7 +181,7 @@ class JaxRecorder:
_loss_derivative_fn: LossDerivative = False
_index_count: int = 0 # Helps to avoid problems with non-1 update rates.
_data_storage: DataStorage = None # For writing to disk.
_ntk_rank: Optional[int] = None # Rank of the NTK matrix.
_ntk_rank: Optional[int] = None # Rank of the NTK matrix.

def _read_selected_attributes(self):
"""
Expand Down Expand Up @@ -320,7 +320,7 @@ def update_recorder(self, epoch: int, model: JaxModel):
ntk = self._model.compute_ntk(
self._data_set["inputs"], infinite=False
)["empirical"]
self._ntk_rank = len(ntk.shape)
self._ntk_rank = len(ntk.shape)
if self.flatten_ntk and self._ntk_rank == 4:
ntk = flatten_rank_4_tensor(ntk)
parsed_data["ntk"] = ntk
Expand Down Expand Up @@ -591,7 +591,7 @@ def _update_trace(self, parsed_data: dict):
"""
Update the trace of the NTK.
The trace of the NTK is computed as the mean of the diagonal elements of the
The trace of the NTK is computed as the mean of the diagonal elements of the
NTK.
Parameters
Expand Down
1 change: 1 addition & 0 deletions znnl/utils/matrix_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ def flatten_rank_4_tensor(tensor: np.ndarray) -> np.ndarray:
_tensor.shape[0] * _tensor.shape[1], _tensor.shape[0] * _tensor.shape[1]
)


def calculate_trace(matrix: np.ndarray, normalize: bool = False) -> np.ndarray:
"""
Calculate the trace of a matrix, including optional normalization.
Expand Down

0 comments on commit d9bb282

Please sign in to comment.