Skip to content

Commit

Permalink
Merge branch 'main' into Konsti_remove_ntk_jit
Browse files Browse the repository at this point in the history
  • Loading branch information
knikolaou committed Apr 8, 2024
2 parents f64bcfa + cc82976 commit 83ac7ea
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 66 deletions.
7 changes: 4 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@ nbsphinx
tensorflow_probability
scipy
scikit-learn
jaxlib
jax
# Temp fix of version of jax and jaxlib until the next release
jax<=0.4.25
jaxlib<=0.4.25
plotly
flax
tqdm
pandas
neural-tangents==0.6.4
neural-tangents>=0.6.5
tensorflow-datasets
isort
tensorflow
Expand Down
24 changes: 13 additions & 11 deletions znnl/training_recording/jax_recording.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,17 +254,19 @@ def instantiate_recorder(self, data_set: dict = None, overwrite: bool = False):
self._index_count = 0

# Check if we need an NTK computation and update the class accordingly
if any([
"ntk" in self._selected_properties,
"covariance_ntk" in self._selected_properties,
"magnitude_ntk" in self._selected_properties,
"entropy" in self._selected_properties,
"magnitude_entropy" in self._selected_properties,
"magnitude_variance" in self._selected_properties,
"covariance_entropy" in self._selected_properties,
"eigenvalues" in self._selected_properties,
"trace" in self._selected_properties,
]):
if any(
[
"ntk" in self._selected_properties,
"covariance_ntk" in self._selected_properties,
"magnitude_ntk" in self._selected_properties,
"entropy" in self._selected_properties,
"magnitude_entropy" in self._selected_properties,
"magnitude_variance" in self._selected_properties,
"covariance_entropy" in self._selected_properties,
"eigenvalues" in self._selected_properties,
"trace" in self._selected_properties,
]
):
self._compute_ntk = True

if "loss_derivative" in self._selected_properties:
Expand Down
110 changes: 58 additions & 52 deletions znnl/visualization/tsne_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,45 +104,47 @@ def run_visualization(self):
fig_dict["layout"]["xaxis2"] = {"domain": [0.8, 1.0]}
fig_dict["layout"]["yaxis2"] = {"anchor": "x2"}
fig_dict["layout"]["hovermode"] = "closest"
fig_dict["layout"]["updatemenus"] = [{
"buttons": [
{
"args": [
None,
{
"frame": {"duration": 500, "redraw": False},
"fromcurrent": True,
"transition": {
"duration": 300,
"easing": "quadratic-in-out",
fig_dict["layout"]["updatemenus"] = [
{
"buttons": [
{
"args": [
None,
{
"frame": {"duration": 500, "redraw": False},
"fromcurrent": True,
"transition": {
"duration": 300,
"easing": "quadratic-in-out",
},
},
},
],
"label": "Play",
"method": "animate",
},
{
"args": [
[None],
{
"frame": {"duration": 0, "redraw": False},
"mode": "immediate",
"transition": {"duration": 0},
},
],
"label": "Pause",
"method": "animate",
},
],
"direction": "left",
"pad": {"r": 10, "t": 87},
"showactive": False,
"type": "buttons",
"x": 0.1,
"xanchor": "right",
"y": 0,
"yanchor": "top",
}]
],
"label": "Play",
"method": "animate",
},
{
"args": [
[None],
{
"frame": {"duration": 0, "redraw": False},
"mode": "immediate",
"transition": {"duration": 0},
},
],
"label": "Pause",
"method": "animate",
},
],
"direction": "left",
"pad": {"r": 10, "t": 87},
"showactive": False,
"type": "buttons",
"x": 0.1,
"xanchor": "right",
"y": 0,
"yanchor": "top",
}
]

sliders_dict = {
"active": 0,
Expand All @@ -163,20 +165,24 @@ def run_visualization(self):
}

# Add initial data
fig_dict["data"].append({
"x": self.dynamic[0][:, 0],
"y": self.dynamic[0][:, 1],
"mode": "markers",
"name": "Predictor",
})
fig_dict["data"].append({
"x": self.reference[0][:, 0],
"y": self.reference[0][:, 1],
"mode": "markers",
"xaxis": "x2",
"yaxis": "y2",
"name": "Target",
})
fig_dict["data"].append(
{
"x": self.dynamic[0][:, 0],
"y": self.dynamic[0][:, 1],
"mode": "markers",
"name": "Predictor",
}
)
fig_dict["data"].append(
{
"x": self.reference[0][:, 0],
"y": self.reference[0][:, 1],
"mode": "markers",
"xaxis": "x2",
"yaxis": "y2",
"name": "Target",
}
)

# Make the figure frames.
for i, item in enumerate(self.dynamic):
Expand Down

0 comments on commit 83ac7ea

Please sign in to comment.