diff --git a/requirements.txt b/requirements.txt index 3050110..d31d588 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,13 +12,14 @@ nbsphinx tensorflow_probability scipy scikit-learn -jaxlib -jax +# Temp fix of version of jax and jaxlib until the next release +jax<=0.4.25 +jaxlib<=0.4.25 plotly flax tqdm pandas -neural-tangents==0.6.4 +neural-tangents>=0.6.5 tensorflow-datasets isort tensorflow diff --git a/znnl/training_recording/jax_recording.py b/znnl/training_recording/jax_recording.py index 862b7d4..9ba92f8 100644 --- a/znnl/training_recording/jax_recording.py +++ b/znnl/training_recording/jax_recording.py @@ -254,17 +254,19 @@ def instantiate_recorder(self, data_set: dict = None, overwrite: bool = False): self._index_count = 0 # Check if we need an NTK computation and update the class accordingly - if any([ - "ntk" in self._selected_properties, - "covariance_ntk" in self._selected_properties, - "magnitude_ntk" in self._selected_properties, - "entropy" in self._selected_properties, - "magnitude_entropy" in self._selected_properties, - "magnitude_variance" in self._selected_properties, - "covariance_entropy" in self._selected_properties, - "eigenvalues" in self._selected_properties, - "trace" in self._selected_properties, - ]): + if any( + [ + "ntk" in self._selected_properties, + "covariance_ntk" in self._selected_properties, + "magnitude_ntk" in self._selected_properties, + "entropy" in self._selected_properties, + "magnitude_entropy" in self._selected_properties, + "magnitude_variance" in self._selected_properties, + "covariance_entropy" in self._selected_properties, + "eigenvalues" in self._selected_properties, + "trace" in self._selected_properties, + ] + ): self._compute_ntk = True if "loss_derivative" in self._selected_properties: diff --git a/znnl/visualization/tsne_visualizer.py b/znnl/visualization/tsne_visualizer.py index 856ecdd..d0d613a 100644 --- a/znnl/visualization/tsne_visualizer.py +++ b/znnl/visualization/tsne_visualizer.py @@ -104,45 +104,47 @@ def run_visualization(self): fig_dict["layout"]["xaxis2"] = {"domain": [0.8, 1.0]} fig_dict["layout"]["yaxis2"] = {"anchor": "x2"} fig_dict["layout"]["hovermode"] = "closest" - fig_dict["layout"]["updatemenus"] = [{ - "buttons": [ - { - "args": [ - None, - { - "frame": {"duration": 500, "redraw": False}, - "fromcurrent": True, - "transition": { - "duration": 300, - "easing": "quadratic-in-out", + fig_dict["layout"]["updatemenus"] = [ + { + "buttons": [ + { + "args": [ + None, + { + "frame": {"duration": 500, "redraw": False}, + "fromcurrent": True, + "transition": { + "duration": 300, + "easing": "quadratic-in-out", + }, }, - }, - ], - "label": "Play", - "method": "animate", - }, - { - "args": [ - [None], - { - "frame": {"duration": 0, "redraw": False}, - "mode": "immediate", - "transition": {"duration": 0}, - }, - ], - "label": "Pause", - "method": "animate", - }, - ], - "direction": "left", - "pad": {"r": 10, "t": 87}, - "showactive": False, - "type": "buttons", - "x": 0.1, - "xanchor": "right", - "y": 0, - "yanchor": "top", - }] + ], + "label": "Play", + "method": "animate", + }, + { + "args": [ + [None], + { + "frame": {"duration": 0, "redraw": False}, + "mode": "immediate", + "transition": {"duration": 0}, + }, + ], + "label": "Pause", + "method": "animate", + }, + ], + "direction": "left", + "pad": {"r": 10, "t": 87}, + "showactive": False, + "type": "buttons", + "x": 0.1, + "xanchor": "right", + "y": 0, + "yanchor": "top", + } + ] sliders_dict = { "active": 0, @@ -163,20 +165,24 @@ def run_visualization(self): } # Add initial data - fig_dict["data"].append({ - "x": self.dynamic[0][:, 0], - "y": self.dynamic[0][:, 1], - "mode": "markers", - "name": "Predictor", - }) - fig_dict["data"].append({ - "x": self.reference[0][:, 0], - "y": self.reference[0][:, 1], - "mode": "markers", - "xaxis": "x2", - "yaxis": "y2", - "name": "Target", - }) + fig_dict["data"].append( + { + "x": self.dynamic[0][:, 0], + "y": self.dynamic[0][:, 1], + "mode": "markers", + "name": "Predictor", + } + ) + fig_dict["data"].append( + { + "x": self.reference[0][:, 0], + "y": self.reference[0][:, 1], + "mode": "markers", + "xaxis": "x2", + "yaxis": "y2", + "name": "Target", + } + ) # Make the figure frames. for i, item in enumerate(self.dynamic):