diff --git a/CI/unit_tests/distance_metrics/test_mahalanobis_distance.py b/CI/unit_tests/distance_metrics/test_mahalanobis_distance.py index d156d03..d91687c 100644 --- a/CI/unit_tests/distance_metrics/test_mahalanobis_distance.py +++ b/CI/unit_tests/distance_metrics/test_mahalanobis_distance.py @@ -135,12 +135,12 @@ def create_sample_set(): ------- Creates a random normal distributed sample set """ - point_1 = np.array([ - onp.random.normal(0, 10, 100), onp.random.normal(0, 20, 100) - ]).T - point_2 = np.array([ - onp.random.normal(0, 10, 100), onp.random.normal(0, 20, 100) - ]).T + point_1 = np.array( + [onp.random.normal(0, 10, 100), onp.random.normal(0, 20, 100)] + ).T + point_2 = np.array( + [onp.random.normal(0, 10, 100), onp.random.normal(0, 20, 100)] + ).T return point_1, point_2 @staticmethod diff --git a/znnl/training_recording/jax_recording.py b/znnl/training_recording/jax_recording.py index 862b7d4..9ba92f8 100644 --- a/znnl/training_recording/jax_recording.py +++ b/znnl/training_recording/jax_recording.py @@ -254,17 +254,19 @@ def instantiate_recorder(self, data_set: dict = None, overwrite: bool = False): self._index_count = 0 # Check if we need an NTK computation and update the class accordingly - if any([ - "ntk" in self._selected_properties, - "covariance_ntk" in self._selected_properties, - "magnitude_ntk" in self._selected_properties, - "entropy" in self._selected_properties, - "magnitude_entropy" in self._selected_properties, - "magnitude_variance" in self._selected_properties, - "covariance_entropy" in self._selected_properties, - "eigenvalues" in self._selected_properties, - "trace" in self._selected_properties, - ]): + if any( + [ + "ntk" in self._selected_properties, + "covariance_ntk" in self._selected_properties, + "magnitude_ntk" in self._selected_properties, + "entropy" in self._selected_properties, + "magnitude_entropy" in self._selected_properties, + "magnitude_variance" in self._selected_properties, + "covariance_entropy" in self._selected_properties, + "eigenvalues" in self._selected_properties, + "trace" in self._selected_properties, + ] + ): self._compute_ntk = True if "loss_derivative" in self._selected_properties: diff --git a/znnl/visualization/tsne_visualizer.py b/znnl/visualization/tsne_visualizer.py index 40618e0..d0d613a 100644 --- a/znnl/visualization/tsne_visualizer.py +++ b/znnl/visualization/tsne_visualizer.py @@ -165,20 +165,24 @@ def run_visualization(self): } # Add initial data - fig_dict["data"].append({ - "x": self.dynamic[0][:, 0], - "y": self.dynamic[0][:, 1], - "mode": "markers", - "name": "Predictor", - }) - fig_dict["data"].append({ - "x": self.reference[0][:, 0], - "y": self.reference[0][:, 1], - "mode": "markers", - "xaxis": "x2", - "yaxis": "y2", - "name": "Target", - }) + fig_dict["data"].append( + { + "x": self.dynamic[0][:, 0], + "y": self.dynamic[0][:, 1], + "mode": "markers", + "name": "Predictor", + } + ) + fig_dict["data"].append( + { + "x": self.reference[0][:, 0], + "y": self.reference[0][:, 1], + "mode": "markers", + "xaxis": "x2", + "yaxis": "y2", + "name": "Target", + } + ) # Make the figure frames. for i, item in enumerate(self.dynamic):