Skip to content

Commit

Permalink
Merge branch 'main' into Konsti_distance_metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
knikolaou committed Apr 9, 2024
2 parents cfc5a49 + a60a80b commit e128da5
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 68 deletions.
7 changes: 4 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@ nbsphinx
tensorflow_probability
scipy
scikit-learn
jaxlib
jax
# Temp fix of version of jax and jaxlib until the next release
jax<=0.4.25
jaxlib<=0.4.25
plotly
flax
tqdm
pandas
neural-tangents==0.6.4
neural-tangents>=0.6.5
tensorflow-datasets
isort
tensorflow
Expand Down
3 changes: 1 addition & 2 deletions znnl/models/jax_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,6 @@ def __init__(
batch_size=ntk_batch_size,
store_on_device=store_on_device,
)
self.empirical_ntk_jit = jax.jit(self.empirical_ntk)
self.apply_jit = jax.jit(self.apply)

def init_model(
Expand Down Expand Up @@ -249,7 +248,7 @@ def compute_ntk(
"""
if x_j is None:
x_j = x_i
empirical_ntk = self.empirical_ntk_jit(
empirical_ntk = self.empirical_ntk(
x_i,
x_j,
{
Expand Down
24 changes: 13 additions & 11 deletions znnl/training_recording/jax_recording.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,17 +254,19 @@ def instantiate_recorder(self, data_set: dict = None, overwrite: bool = False):
self._index_count = 0

# Check if we need an NTK computation and update the class accordingly
if any([
"ntk" in self._selected_properties,
"covariance_ntk" in self._selected_properties,
"magnitude_ntk" in self._selected_properties,
"entropy" in self._selected_properties,
"magnitude_entropy" in self._selected_properties,
"magnitude_variance" in self._selected_properties,
"covariance_entropy" in self._selected_properties,
"eigenvalues" in self._selected_properties,
"trace" in self._selected_properties,
]):
if any(
[
"ntk" in self._selected_properties,
"covariance_ntk" in self._selected_properties,
"magnitude_ntk" in self._selected_properties,
"entropy" in self._selected_properties,
"magnitude_entropy" in self._selected_properties,
"magnitude_variance" in self._selected_properties,
"covariance_entropy" in self._selected_properties,
"eigenvalues" in self._selected_properties,
"trace" in self._selected_properties,
]
):
self._compute_ntk = True

if "loss_derivative" in self._selected_properties:
Expand Down
110 changes: 58 additions & 52 deletions znnl/visualization/tsne_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,45 +104,47 @@ def run_visualization(self):
fig_dict["layout"]["xaxis2"] = {"domain": [0.8, 1.0]}
fig_dict["layout"]["yaxis2"] = {"anchor": "x2"}
fig_dict["layout"]["hovermode"] = "closest"
fig_dict["layout"]["updatemenus"] = [{
"buttons": [
{
"args": [
None,
{
"frame": {"duration": 500, "redraw": False},
"fromcurrent": True,
"transition": {
"duration": 300,
"easing": "quadratic-in-out",
fig_dict["layout"]["updatemenus"] = [
{
"buttons": [
{
"args": [
None,
{
"frame": {"duration": 500, "redraw": False},
"fromcurrent": True,
"transition": {
"duration": 300,
"easing": "quadratic-in-out",
},
},
},
],
"label": "Play",
"method": "animate",
},
{
"args": [
[None],
{
"frame": {"duration": 0, "redraw": False},
"mode": "immediate",
"transition": {"duration": 0},
},
],
"label": "Pause",
"method": "animate",
},
],
"direction": "left",
"pad": {"r": 10, "t": 87},
"showactive": False,
"type": "buttons",
"x": 0.1,
"xanchor": "right",
"y": 0,
"yanchor": "top",
}]
],
"label": "Play",
"method": "animate",
},
{
"args": [
[None],
{
"frame": {"duration": 0, "redraw": False},
"mode": "immediate",
"transition": {"duration": 0},
},
],
"label": "Pause",
"method": "animate",
},
],
"direction": "left",
"pad": {"r": 10, "t": 87},
"showactive": False,
"type": "buttons",
"x": 0.1,
"xanchor": "right",
"y": 0,
"yanchor": "top",
}
]

sliders_dict = {
"active": 0,
Expand All @@ -163,20 +165,24 @@ def run_visualization(self):
}

# Add initial data
fig_dict["data"].append({
"x": self.dynamic[0][:, 0],
"y": self.dynamic[0][:, 1],
"mode": "markers",
"name": "Predictor",
})
fig_dict["data"].append({
"x": self.reference[0][:, 0],
"y": self.reference[0][:, 1],
"mode": "markers",
"xaxis": "x2",
"yaxis": "y2",
"name": "Target",
})
fig_dict["data"].append(
{
"x": self.dynamic[0][:, 0],
"y": self.dynamic[0][:, 1],
"mode": "markers",
"name": "Predictor",
}
)
fig_dict["data"].append(
{
"x": self.reference[0][:, 0],
"y": self.reference[0][:, 1],
"mode": "markers",
"xaxis": "x2",
"yaxis": "y2",
"name": "Target",
}
)

# Make the figure frames.
for i, item in enumerate(self.dynamic):
Expand Down

0 comments on commit e128da5

Please sign in to comment.