Skip to content

Commit

Permalink
another one
Browse files Browse the repository at this point in the history
  • Loading branch information
SamTov committed Dec 20, 2023
1 parent 2d86ab0 commit 2cf3a6f
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 73 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@ fail_fast: true

repos:
- repo: https://github.com/psf/black
rev: 23.9.1
rev: 23.12.0
hooks:
- id: black

- repo: https://github.com/timothycrosley/isort
rev: 5.12.0
rev: 5.13.2
hooks:
- id: isort

Expand Down
24 changes: 11 additions & 13 deletions znnl/training_recording/jax_recording.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,19 +254,17 @@ def instantiate_recorder(self, data_set: dict = None, overwrite: bool = False):
self._index_count = 0

# Check if we need an NTK computation and update the class accordingly
if any(
[
"ntk" in self._selected_properties,
"covariance_ntk" in self._selected_properties,
"magnitude_ntk" in self._selected_properties,
"entropy" in self._selected_properties,
"magnitude_entropy" in self._selected_properties,
"magnitude_variance" in self._selected_properties,
"covariance_entropy" in self._selected_properties,
"eigenvalues" in self._selected_properties,
"trace" in self._selected_properties,
]
):
if any([
"ntk" in self._selected_properties,
"covariance_ntk" in self._selected_properties,
"magnitude_ntk" in self._selected_properties,
"entropy" in self._selected_properties,
"magnitude_entropy" in self._selected_properties,
"magnitude_variance" in self._selected_properties,
"covariance_entropy" in self._selected_properties,
"eigenvalues" in self._selected_properties,
"trace" in self._selected_properties,
]):
self._compute_ntk = True

if "loss_derivative" in self._selected_properties:
Expand Down
110 changes: 52 additions & 58 deletions znnl/visualization/tsne_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,47 +104,45 @@ def run_visualization(self):
fig_dict["layout"]["xaxis2"] = {"domain": [0.8, 1.0]}
fig_dict["layout"]["yaxis2"] = {"anchor": "x2"}
fig_dict["layout"]["hovermode"] = "closest"
fig_dict["layout"]["updatemenus"] = [
{
"buttons": [
{
"args": [
None,
{
"frame": {"duration": 500, "redraw": False},
"fromcurrent": True,
"transition": {
"duration": 300,
"easing": "quadratic-in-out",
},
},
],
"label": "Play",
"method": "animate",
},
{
"args": [
[None],
{
"frame": {"duration": 0, "redraw": False},
"mode": "immediate",
"transition": {"duration": 0},
fig_dict["layout"]["updatemenus"] = [{
"buttons": [
{
"args": [
None,
{
"frame": {"duration": 500, "redraw": False},
"fromcurrent": True,
"transition": {
"duration": 300,
"easing": "quadratic-in-out",
},
],
"label": "Pause",
"method": "animate",
},
],
"direction": "left",
"pad": {"r": 10, "t": 87},
"showactive": False,
"type": "buttons",
"x": 0.1,
"xanchor": "right",
"y": 0,
"yanchor": "top",
}
]
},
],
"label": "Play",
"method": "animate",
},
{
"args": [
[None],
{
"frame": {"duration": 0, "redraw": False},
"mode": "immediate",
"transition": {"duration": 0},
},
],
"label": "Pause",
"method": "animate",
},
],
"direction": "left",
"pad": {"r": 10, "t": 87},
"showactive": False,
"type": "buttons",
"x": 0.1,
"xanchor": "right",
"y": 0,
"yanchor": "top",
}]

sliders_dict = {
"active": 0,
Expand All @@ -165,24 +163,20 @@ def run_visualization(self):
}

# Add initial data
fig_dict["data"].append(
{
"x": self.dynamic[0][:, 0],
"y": self.dynamic[0][:, 1],
"mode": "markers",
"name": "Predictor",
}
)
fig_dict["data"].append(
{
"x": self.reference[0][:, 0],
"y": self.reference[0][:, 1],
"mode": "markers",
"xaxis": "x2",
"yaxis": "y2",
"name": "Target",
}
)
fig_dict["data"].append({
"x": self.dynamic[0][:, 0],
"y": self.dynamic[0][:, 1],
"mode": "markers",
"name": "Predictor",
})
fig_dict["data"].append({
"x": self.reference[0][:, 0],
"y": self.reference[0][:, 1],
"mode": "markers",
"xaxis": "x2",
"yaxis": "y2",
"name": "Target",
})

# Make the figure frames.
for i, item in enumerate(self.dynamic):
Expand Down

0 comments on commit 2cf3a6f

Please sign in to comment.