diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f102bd4..ce4d47f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,12 +4,12 @@ fail_fast: true repos: - repo: https://github.com/psf/black - rev: 23.9.1 + rev: 23.12.0 hooks: - id: black - repo: https://github.com/timothycrosley/isort - rev: 5.12.0 + rev: 5.13.2 hooks: - id: isort diff --git a/znnl/training_recording/jax_recording.py b/znnl/training_recording/jax_recording.py index 9ba92f8..862b7d4 100644 --- a/znnl/training_recording/jax_recording.py +++ b/znnl/training_recording/jax_recording.py @@ -254,19 +254,17 @@ def instantiate_recorder(self, data_set: dict = None, overwrite: bool = False): self._index_count = 0 # Check if we need an NTK computation and update the class accordingly - if any( - [ - "ntk" in self._selected_properties, - "covariance_ntk" in self._selected_properties, - "magnitude_ntk" in self._selected_properties, - "entropy" in self._selected_properties, - "magnitude_entropy" in self._selected_properties, - "magnitude_variance" in self._selected_properties, - "covariance_entropy" in self._selected_properties, - "eigenvalues" in self._selected_properties, - "trace" in self._selected_properties, - ] - ): + if any([ + "ntk" in self._selected_properties, + "covariance_ntk" in self._selected_properties, + "magnitude_ntk" in self._selected_properties, + "entropy" in self._selected_properties, + "magnitude_entropy" in self._selected_properties, + "magnitude_variance" in self._selected_properties, + "covariance_entropy" in self._selected_properties, + "eigenvalues" in self._selected_properties, + "trace" in self._selected_properties, + ]): self._compute_ntk = True if "loss_derivative" in self._selected_properties: diff --git a/znnl/visualization/tsne_visualizer.py b/znnl/visualization/tsne_visualizer.py index d0d613a..856ecdd 100644 --- a/znnl/visualization/tsne_visualizer.py +++ b/znnl/visualization/tsne_visualizer.py @@ -104,47 +104,45 @@ def run_visualization(self): fig_dict["layout"]["xaxis2"] = {"domain": [0.8, 1.0]} fig_dict["layout"]["yaxis2"] = {"anchor": "x2"} fig_dict["layout"]["hovermode"] = "closest" - fig_dict["layout"]["updatemenus"] = [ - { - "buttons": [ - { - "args": [ - None, - { - "frame": {"duration": 500, "redraw": False}, - "fromcurrent": True, - "transition": { - "duration": 300, - "easing": "quadratic-in-out", - }, - }, - ], - "label": "Play", - "method": "animate", - }, - { - "args": [ - [None], - { - "frame": {"duration": 0, "redraw": False}, - "mode": "immediate", - "transition": {"duration": 0}, + fig_dict["layout"]["updatemenus"] = [{ + "buttons": [ + { + "args": [ + None, + { + "frame": {"duration": 500, "redraw": False}, + "fromcurrent": True, + "transition": { + "duration": 300, + "easing": "quadratic-in-out", }, - ], - "label": "Pause", - "method": "animate", - }, - ], - "direction": "left", - "pad": {"r": 10, "t": 87}, - "showactive": False, - "type": "buttons", - "x": 0.1, - "xanchor": "right", - "y": 0, - "yanchor": "top", - } - ] + }, + ], + "label": "Play", + "method": "animate", + }, + { + "args": [ + [None], + { + "frame": {"duration": 0, "redraw": False}, + "mode": "immediate", + "transition": {"duration": 0}, + }, + ], + "label": "Pause", + "method": "animate", + }, + ], + "direction": "left", + "pad": {"r": 10, "t": 87}, + "showactive": False, + "type": "buttons", + "x": 0.1, + "xanchor": "right", + "y": 0, + "yanchor": "top", + }] sliders_dict = { "active": 0, @@ -165,24 +163,20 @@ def run_visualization(self): } # Add initial data - fig_dict["data"].append( - { - "x": self.dynamic[0][:, 0], - "y": self.dynamic[0][:, 1], - "mode": "markers", - "name": "Predictor", - } - ) - fig_dict["data"].append( - { - "x": self.reference[0][:, 0], - "y": self.reference[0][:, 1], - "mode": "markers", - "xaxis": "x2", - "yaxis": "y2", - "name": "Target", - } - ) + fig_dict["data"].append({ + "x": self.dynamic[0][:, 0], + "y": self.dynamic[0][:, 1], + "mode": "markers", + "name": "Predictor", + }) + fig_dict["data"].append({ + "x": self.reference[0][:, 0], + "y": self.reference[0][:, 1], + "mode": "markers", + "xaxis": "x2", + "yaxis": "y2", + "name": "Target", + }) # Make the figure frames. for i, item in enumerate(self.dynamic):