Skip to content

Commit

Permalink
Run black
Browse files Browse the repository at this point in the history
  • Loading branch information
SamTov committed Dec 20, 2023
1 parent 2eb4d79 commit 2d86ab0
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 31 deletions.
12 changes: 6 additions & 6 deletions CI/unit_tests/distance_metrics/test_mahalanobis_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,12 +135,12 @@ def create_sample_set():
-------
Creates a random normal distributed sample set
"""
point_1 = np.array([
onp.random.normal(0, 10, 100), onp.random.normal(0, 20, 100)
]).T
point_2 = np.array([
onp.random.normal(0, 10, 100), onp.random.normal(0, 20, 100)
]).T
point_1 = np.array(
[onp.random.normal(0, 10, 100), onp.random.normal(0, 20, 100)]
).T
point_2 = np.array(
[onp.random.normal(0, 10, 100), onp.random.normal(0, 20, 100)]
).T
return point_1, point_2

@staticmethod
Expand Down
24 changes: 13 additions & 11 deletions znnl/training_recording/jax_recording.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,17 +254,19 @@ def instantiate_recorder(self, data_set: dict = None, overwrite: bool = False):
self._index_count = 0

# Check if we need an NTK computation and update the class accordingly
if any([
"ntk" in self._selected_properties,
"covariance_ntk" in self._selected_properties,
"magnitude_ntk" in self._selected_properties,
"entropy" in self._selected_properties,
"magnitude_entropy" in self._selected_properties,
"magnitude_variance" in self._selected_properties,
"covariance_entropy" in self._selected_properties,
"eigenvalues" in self._selected_properties,
"trace" in self._selected_properties,
]):
if any(
[
"ntk" in self._selected_properties,
"covariance_ntk" in self._selected_properties,
"magnitude_ntk" in self._selected_properties,
"entropy" in self._selected_properties,
"magnitude_entropy" in self._selected_properties,
"magnitude_variance" in self._selected_properties,
"covariance_entropy" in self._selected_properties,
"eigenvalues" in self._selected_properties,
"trace" in self._selected_properties,
]
):
self._compute_ntk = True

if "loss_derivative" in self._selected_properties:
Expand Down
32 changes: 18 additions & 14 deletions znnl/visualization/tsne_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,20 +165,24 @@ def run_visualization(self):
}

# Add initial data
fig_dict["data"].append({
"x": self.dynamic[0][:, 0],
"y": self.dynamic[0][:, 1],
"mode": "markers",
"name": "Predictor",
})
fig_dict["data"].append({
"x": self.reference[0][:, 0],
"y": self.reference[0][:, 1],
"mode": "markers",
"xaxis": "x2",
"yaxis": "y2",
"name": "Target",
})
fig_dict["data"].append(
{
"x": self.dynamic[0][:, 0],
"y": self.dynamic[0][:, 1],
"mode": "markers",
"name": "Predictor",
}
)
fig_dict["data"].append(
{
"x": self.reference[0][:, 0],
"y": self.reference[0][:, 1],
"mode": "markers",
"xaxis": "x2",
"yaxis": "y2",
"name": "Target",
}
)

# Make the figure frames.
for i, item in enumerate(self.dynamic):
Expand Down

0 comments on commit 2d86ab0

Please sign in to comment.