diff --git a/CI/integration_tests/training_recording/test_loss_ntk_recording_deployment.py b/CI/integration_tests/training_recording/test_loss_ntk_recording_deployment.py
new file mode 100644
index 0000000..d0032bb
--- /dev/null
+++ b/CI/integration_tests/training_recording/test_loss_ntk_recording_deployment.py
@@ -0,0 +1,108 @@
+"""
+ZnNL: A Zincwarecode package.
+
+License
+-------
+This program and the accompanying materials are made available under the terms
+of the Eclipse Public License v2.0 which accompanies this distribution, and is
+available at https://www.eclipse.org/legal/epl-v20.html
+
+SPDX-License-Identifier: EPL-2.0
+
+Copyright Contributors to the Zincwarecode Project.
+
+Contact Information
+-------------------
+email: zincwarecode@gmail.com
+github: https://github.com/zincware
+web: https://zincwarecode.com/
+
+Citation
+--------
+If you use this module please cite us with:
+
+Summary
+-------
+"""
+
+import os
+
+os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
+
+import numpy as np
+import optax
+from neural_tangents import stax
+from numpy.testing import assert_array_almost_equal
+
+from znnl.loss_functions import LPNormLoss
+from znnl.models import NTModel
+from znnl.training_recording import JaxRecorder
+from znnl.training_strategies import SimpleTraining
+
+
+class TestLossNTKRecorderDeployment:
+    """
+    Test suite for the deployment of the NTK and loss NTK recorders.
+    """
+
+    @classmethod
+    def setup_class(cls):
+        """
+        Create a model and data for the tests.
+        """
+
+        network = stax.serial(
+            stax.Dense(10), stax.Relu(), stax.Dense(10), stax.Relu(), stax.Dense(1)
+        )
+        cls.model = NTModel(
+            nt_module=network, input_shape=(5,), optimizer=optax.adam(1e-3)
+        )
+
+        cls.data_set = {
+            "inputs": np.random.rand(10, 5),
+            "targets": np.random.randint(0, 2, (10, 1)),
+        }
+
+        cls.ntk_recorder = JaxRecorder(
+            name="ntk_recorder",
+            ntk=True,
+            update_rate=1,
+        )
+        cls.loss_ntk_recorder = JaxRecorder(
+            name="loss_ntk_recorder",
+            ntk=True,
+            use_loss_ntk=True,
+            update_rate=1,
+        )
+
+        cls.ntk_recorder.instantiate_recorder(data_set=cls.data_set)
+        cls.loss_ntk_recorder.instantiate_recorder(data_set=cls.data_set)
+
+        cls.trainer = SimpleTraining(
+            model=cls.model,
+            loss_fn=LPNormLoss(order=2),
+            recorders=[cls.ntk_recorder, cls.loss_ntk_recorder],
+        )
+
+    def test_loss_ntk_deployment(self):
+        """
+        Test the deployment of the loss NTK recorder.
+        """
+
+        # Train the model so that both recorders are updated each epoch.
+        self.trainer.train_model(
+            train_ds=self.data_set,
+            test_ds=self.data_set,
+            epochs=10,
+            batch_size=2,
+        )
+
+        # Gather the recordings.
+        ntk_recording = self.ntk_recorder.gather_recording()
+        loss_ntk_recording = self.loss_ntk_recorder.gather_recording()
+
+        # For an LPNormLoss of order 2 and a network with a one-dimensional
+        # output, the NTK and the loss NTK should agree entrywise up to a
+        # factor of +1 or -1, so their absolute values must match.
+        assert_array_almost_equal(
+            np.abs(ntk_recording.ntk), np.abs(loss_ntk_recording.ntk), decimal=5
+        )
diff --git a/CI/unit_tests/analysis/test_loss_ntk_calculation.py b/CI/unit_tests/analysis/test_loss_ntk_calculation.py
new file mode 100644
index 0000000..ee51cb3
--- /dev/null
+++ b/CI/unit_tests/analysis/test_loss_ntk_calculation.py
@@ -0,0 +1,240 @@
+"""
+ZnNL: A Zincwarecode package.
+
+License
+-------
+This program and the accompanying materials are made available under the terms
+of the Eclipse Public License v2.0 which accompanies this distribution, and is
+available at https://www.eclipse.org/legal/epl-v20.html
+
+SPDX-License-Identifier: EPL-2.0
+
+Copyright Contributors to the Zincwarecode Project.
+
+Contact Information
+-------------------
+email: zincwarecode@gmail.com
+github: https://github.com/zincware
+web: https://zincwarecode.com/
+
+Citation
+--------
+If you use this module please cite us with:
+
+Summary
+-------
+"""
+
+import os
+
+os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
+
+import jax.numpy as np
+import numpy as onp
+import optax
+from neural_tangents import stax
+from numpy.testing import assert_array_almost_equal
+
+from znnl.analysis import EigenSpaceAnalysis, LossNTKCalculation
+from znnl.data import MNISTGenerator
+from znnl.distance_metrics import LPNorm
+from znnl.loss_functions import LPNormLoss
+from znnl.models import NTModel
+
+
+class TestLossNTKCalculation:
+    """
+    Test suite for the LossNTKCalculation module.
+    """
+
+    def test_reshaping_methods(self):
+        """
+        Test the _reshape_dataset and _unshape_data methods.
+        These helpers are used in the loss NTK calculation to flatten each
+        input-target pair into a single vector and to recover the original
+        shapes afterwards.
+        """
+        # Define a dummy model and dataset so that a LossNTKCalculation
+        # object can be instantiated.
+        feed_forward_model = stax.serial(
+            stax.Dense(5),
+            stax.Relu(),
+            stax.Dense(2),
+            stax.Relu(),
+        )
+
+        # Initialize the model
+        test_model = NTModel(
+            optimizer=optax.adam(learning_rate=0.01),
+            input_shape=(1, 5),
+            trace_axes=(),
+            nt_module=feed_forward_model,
+        )
+
+        data_generator = MNISTGenerator(ds_size=20)
+        data_set = {
+            "inputs": data_generator.train_ds["inputs"],
+            "targets": data_generator.train_ds["targets"],
+        }
+
+        # Initialize the loss NTK calculation
+        loss_ntk_calculator = LossNTKCalculation(
+            metric_fn=LPNorm(order=2),
+            model=test_model,
+            dataset=data_set,
+        )
+
+        # Set up a test dataset for reshaping
+        test_data_set = {
+            "inputs": np.array([[1, 2, 3], [4, 5, 6]]),
+            "targets": np.array([[7, 9], [10, 12]]),
+        }
+
+        # Test the reshaping
+        reshaped_test_data_set = loss_ntk_calculator._reshape_dataset(test_data_set)
+
+        assert_array_almost_equal(
+            reshaped_test_data_set, np.array([[1, 2, 3, 7, 9], [4, 5, 6, 10, 12]])
+        )
+
+        # Test the unshaping
+        input_0, target_0 = loss_ntk_calculator._unshape_data(
+            reshaped_test_data_set,
+            input_dimension=3,
+            input_shape=(2, 3),
+            target_shape=(2, 2),
+            batch_length=reshaped_test_data_set.shape[0],
+        )
+        assert_array_almost_equal(input_0, test_data_set["inputs"])
+        assert_array_almost_equal(target_0, test_data_set["targets"])
+
+    def test_function_for_loss_ntk(self):
+        """
+        Test the function that calculates the subloss for single datapoints.
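+
+        The subloss returned for a single flattened datapoint is compared
+        against a manual evaluation of the metric below, i.e. the L2
+        distance between the model output and the target.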
+ """ + # Define a simple feed forward test model + feed_forward_model = stax.serial( + stax.Dense(5), + stax.Relu(), + stax.Dense(2), + stax.Relu(), + ) + + # Initialize the model + model = NTModel( + optimizer=optax.adam(learning_rate=0.01), + input_shape=(1, 5), + trace_axes=(), + nt_module=feed_forward_model, + ) + + # Define a test dataset with only two datapoints + test_data_set = { + "inputs": np.array([[1, 2, 3, 4, 5], [1, 2, 3, 4, 8]]), + "targets": np.array([[1, 3], [2, 5]]), + } + + # Initialize loss + loss = LPNormLoss(order=2) + # Initialize the loss NTK calculation + loss_ntk_calculator = LossNTKCalculation( + metric_fn=loss.metric, + model=model, + dataset=test_data_set, + ) + + # Calculate the subloss from the NTK first + datapoint = loss_ntk_calculator._reshape_dataset(test_data_set)[0:1] + subloss_from_NTK = loss_ntk_calculator._function_for_loss_ntk( + { + "params": model.model_state.params, + "batch_stats": model.model_state.batch_stats, + }, + datapoint=datapoint, + ) + + # Now calculate subloss manually + applied_model = model.apply( + { + "params": model.model_state.params, + "batch_stats": model.model_state.batch_stats, + }, + test_data_set["inputs"][0], + ) + subloss = np.linalg.norm(applied_model - test_data_set["targets"][0], ord=2) + + # Check that the two losses are the same + assert subloss - subloss_from_NTK < 1e-5 + + def test_loss_NTK_calculation(self): + """ + Test the Loss NTK calculation with a manually hardcoded dataset + and a simple feed forward network. + """ + # Define a simple feed forward test model + feed_forward_model = stax.serial( + stax.Dense(5), + stax.Relu(), + stax.Dense(1), + stax.Relu(), + ) + + # Initialize the model + model = NTModel( + optimizer=optax.adam(learning_rate=0.01), + input_shape=(2,), + trace_axes=(), + nt_module=feed_forward_model, + ) + + # Create a test dataset + inputs = np.array(onp.random.rand(10, 2)) + targets = np.array(100 * onp.random.rand(10, 1)) + test_data_set = {"inputs": inputs, "targets": targets} + + # Initialize loss + loss = LPNormLoss(order=2) + # Initialize the loss NTK calculation + loss_ntk_calculator = LossNTKCalculation( + metric_fn=loss.metric, + model=model, + dataset=test_data_set, + ) + + # Compute the loss NTK + loss_ntk = loss_ntk_calculator.compute_loss_ntk(x_i=test_data_set, model=model)[ + "empirical" + ] + + # Now for comparison calculate regular ntk + ntk = model.compute_ntk(test_data_set["inputs"], infinite=False)["empirical"] + + # predictions calculation analogous to the one in jax recording + predictions = model(test_data_set["inputs"]) + if type(predictions) is tuple: + predictions = predictions[0] + + # calculation of loss derivatives + # note: here we need the derivatives of the subloss, not the regular loss fn + loss_derivatives = onp.empty(shape=(len(predictions))) + for i in range(len(loss_derivatives)): + loss_derivatives[i] = ( + predictions[i, 0] / np.abs(predictions[i, 0]) + if predictions[i, 0] != 0 + else 0 + ) + + # Calculate the loss NTK from the loss derivatives and the ntk + loss_ntk_2 = np.einsum( + "i, j, ijkl-> ij", loss_derivatives, loss_derivatives, ntk + ) + + # Assert that the loss NTKs are the same + assert_array_almost_equal(loss_ntk, loss_ntk_2, decimal=6) + + calculator1 = EigenSpaceAnalysis(matrix=loss_ntk) + calculator2 = EigenSpaceAnalysis(matrix=loss_ntk_2) + + eigenvalues1 = calculator1.compute_eigenvalues(normalize=False) + eigenvalue2 = calculator2.compute_eigenvalues(normalize=False) + + assert_array_almost_equal(eigenvalues1, eigenvalue2, 
diff --git a/CI/unit_tests/training_recording/test_training_recording.py b/CI/unit_tests/training_recording/test_training_recording.py
index 5249b3e..cf7ba60 100644
--- a/CI/unit_tests/training_recording/test_training_recording.py
+++ b/CI/unit_tests/training_recording/test_training_recording.py
@@ -75,6 +75,7 @@ def test_instantiation(self):
             "name",
             "storage_path",
             "chunk_size",
+            "use_loss_ntk",
         ]
         for key, val in vars(recorder).items():
             if key[0] != "_" and key not in _exclude_list:
diff --git a/requirements.txt b/requirements.txt
index 3050110..cb1c088 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -18,7 +18,7 @@ plotly
 flax
 tqdm
 pandas
-neural-tangents==0.6.4
+neural-tangents
 tensorflow-datasets
 isort
 tensorflow
diff --git a/znnl/analysis/__init__.py b/znnl/analysis/__init__.py
index 94701c5..eed4fcb 100644
--- a/znnl/analysis/__init__.py
+++ b/znnl/analysis/__init__.py
@@ -28,9 +28,11 @@
 from znnl.analysis.eigensystem import EigenSpaceAnalysis
 from znnl.analysis.entropy import EntropyAnalysis
 from znnl.analysis.loss_fn_derivative import LossDerivative
+from znnl.analysis.loss_ntk_calculation import LossNTKCalculation
 
 __all__ = [
     EntropyAnalysis.__name__,
     EigenSpaceAnalysis.__name__,
     LossDerivative.__name__,
+    LossNTKCalculation.__name__,
 ]
diff --git a/znnl/analysis/loss_ntk_calculation.py b/znnl/analysis/loss_ntk_calculation.py
new file mode 100644
index 0000000..0b37603
--- /dev/null
+++ b/znnl/analysis/loss_ntk_calculation.py
@@ -0,0 +1,241 @@
+"""
+ZnNL: A Zincwarecode package.
+
+License
+-------
+This program and the accompanying materials are made available under the terms
+of the Eclipse Public License v2.0 which accompanies this distribution, and is
+available at https://www.eclipse.org/legal/epl-v20.html
+
+SPDX-License-Identifier: EPL-2.0
+
+Copyright Contributors to the Zincwarecode Project.
+
+Contact Information
+-------------------
+email: zincwarecode@gmail.com
+github: https://github.com/zincware
+web: https://zincwarecode.com/
+
+Citation
+--------
+If you use this module please cite us with:
+
+Summary
+-------
+"""
+
+from typing import Callable
+
+import jax
+import jax.numpy as np
+import neural_tangents as nt
+
+from znnl.models.jax_model import JaxModel
+
+
+class LossNTKCalculation:
+    """
+    Class to calculate the loss NTK matrix for a given model and dataset.
+
+    The loss NTK matrix is defined as
+
+    .. math::
+
+        (\\Lambda_L)_{ij} = \\sum_\\theta
+        \\frac{\\partial \\ell_i}{\\partial \\theta}
+        \\frac{\\partial \\ell_j}{\\partial \\theta},
+
+    where :math:`\\ell_i` is the loss of the single datapoint :math:`i` and
+    :math:`\\theta` are the parameters of the network.
+    """
+
+    def __init__(
+        self,
+        metric_fn: Callable,
+        model: JaxModel,
+        dataset: dict,
+    ):
+        """
+        Constructor for the loss NTK calculation class.
+
+        Parameters
+        ----------
+
+        metric_fn : Callable
+            The metric function to be used for the loss calculation.
+            Note: this has to be the metric, not the loss function.
+            Passing the loss function here will not raise an error,
+            but it will produce an incorrect result.
+
+        model : JaxModel
+            The model for which to calculate the loss NTK.
+
+        dataset : dict
+            The dataset for which to calculate the loss NTK.
+            The dictionary should contain the keys "inputs" and "targets".
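+
+        Examples
+        --------
+        A minimal sketch of the intended usage, assuming an already
+        constructed model ``model``, a dataset dict ``data_set`` with
+        "inputs" and "targets" entries, and ``LPNormLoss`` imported from
+        ``znnl.loss_functions``::
+
+            loss = LPNormLoss(order=2)
+            calculator = LossNTKCalculation(
+                metric_fn=loss.metric, model=model, dataset=data_set
+            )
+            loss_ntk = calculator.compute_loss_ntk(
+                x_i=data_set, model=model
+            )["empirical"]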
+ """ + + # Set the attributes + self.ntk_batch_size = model.ntk_batch_size + self.store_on_device = model.store_on_device + self.trace_axes = model.trace_axes + self.input_shape = dataset["inputs"].shape + self.input_dimension = int(np.prod(np.array(self.input_shape[1:]))) + self.target_shape = dataset["targets"].shape + self.metric_fn = metric_fn + self.apply_fn = model._ntk_apply_fn + + # Prepare NTK calculation + empirical_ntk = nt.batch( + nt.empirical_ntk_fn( + f=self._function_for_loss_ntk, + trace_axes=(), + vmap_axes=0, + ), + batch_size=self.ntk_batch_size, + store_on_device=self.store_on_device, + ) + self.empirical_ntk_jit = jax.jit(empirical_ntk) + + @staticmethod + def _reshape_dataset(dataset): + """ + Helper function to reshape the dataset for the Loss NTK calculation. + + Parameters + ---------- + dataset : dict + The dataset to be reshaped. + Should contain the keys "inputs" and "targets". + + Returns + ------- + reshaped_dataset : np.ndarray + The reshaped dataset. + """ + return np.concatenate( + ( + dataset["inputs"].reshape(dataset["inputs"].shape[0], -1), + dataset["targets"].reshape(dataset["targets"].shape[0], -1), + ), + axis=1, + ) + + @staticmethod + def _unshape_data( + datapoint: np.ndarray, + input_dimension: int, + input_shape: tuple, + target_shape: tuple, + batch_length: int, + ): + """ + Helper function to unshape the data for the subloss calculation. + + Parameters + ---------- + datapoint : np.ndarray + The datapoint to be unshaped. + input_dimension : int + The total dimension of the input, i.e. the product of its shape. + input_shape : tuple + The shape of the original input. + target_shape : tuple + The shape of the original target. + batch_length : int + The length of the batch, i.e. the number of actual datapoints + inside the datapoint variable + + Returns + ------- + input: np.ndarray + The unshaped input. + target: np.ndarray + The unshaped target. + """ + return datapoint[:, :input_dimension].reshape( + batch_length, *input_shape[1:] + ), datapoint[:, input_dimension:].reshape(batch_length, *target_shape[1:]) + + def _function_for_loss_ntk(self, params, datapoint) -> float: + """ + Helper function that calculates the subloss for single datapoints. + + Parameters + ---------- + params : dict + The parameters of the model. + datapoint : np.ndarray (batch_length, input_dimension + output_dimension) + The datapoint for which to calculate the subloss. The indicated + shaping is necessary so that this function can be understood by the + empirical NTK function of the neural tangents library. + + Returns + ------- + subloss : float + The subloss for the given datapoint. + """ + batch_length = datapoint.shape[0] + _input, _target = self._unshape_data( + datapoint, + self.input_dimension, + self.input_shape, + self.target_shape, + batch_length, + ) + return self.metric_fn( + self.apply_fn(params, _input), + _target, + ) + + def compute_loss_ntk( + self, + x_i: np.ndarray, + model: JaxModel, + x_j: np.ndarray = None, + infinite: bool = False, + ): + """ + Compute the loss NTK matrix for the model. + The dataset gets reshaped to (n_data, input_dimension + output_dimension) + so that the neural tangents empirical_ntk_fn function can take each input + target pair as its input. + + Parameters + ---------- + x_i : np.ndarray + Dataset for which to compute the loss NTK matrix. + x_j : np.ndarray (optional) + Dataset for which to compute the loss NTK matrix. + infinite : bool (default = False) + If true, compute the infinite width limit as well. 
+
+        Returns
+        -------
+        Loss NTK : dict
+            The loss NTK matrix for both the empirical and the
+            infinite-width computation.
+        """
+
+        x_i = self._reshape_dataset(x_i)
+
+        if x_j is None:
+            x_j = x_i
+        else:
+            x_j = self._reshape_dataset(x_j)
+
+        empirical_ntk = self.empirical_ntk_jit(
+            x_i,
+            x_j,
+            {
+                "params": model.model_state.params,
+                "batch_stats": model.model_state.batch_stats,
+            },
+        )
+
+        if infinite:
+            try:
+                infinite_ntk = self.kernel_fn(x_i, x_j, "ntk")
+            except AttributeError:
+                raise NotImplementedError("Infinite NTK not available for this model.")
+        else:
+            infinite_ntk = None
+
+        return {"empirical": empirical_ntk, "infinite": infinite_ntk}
diff --git a/znnl/models/jax_model.py b/znnl/models/jax_model.py
index 8e8ea95..dc4f4ea 100644
--- a/znnl/models/jax_model.py
+++ b/znnl/models/jax_model.py
@@ -127,6 +127,13 @@ def __init__(
             batch_size=ntk_batch_size,
             store_on_device=store_on_device,
         )
+
+        # These attributes are stored here so that the loss NTK calculation,
+        # which is implemented outside of the model class, can access them.
+        self.ntk_batch_size = ntk_batch_size
+        self.trace_axes = trace_axes
+        self.store_on_device = store_on_device
+
         self.empirical_ntk_jit = jax.jit(self.empirical_ntk)
         self.apply_jit = jax.jit(self.apply)
diff --git a/znnl/training_recording/jax_recording.py b/znnl/training_recording/jax_recording.py
index 862b7d4..4ec3cfd 100644
--- a/znnl/training_recording/jax_recording.py
+++ b/znnl/training_recording/jax_recording.py
@@ -37,6 +37,7 @@
 from znnl.analysis.eigensystem import EigenSpaceAnalysis
 from znnl.analysis.entropy import EntropyAnalysis
 from znnl.analysis.loss_fn_derivative import LossDerivative
+from znnl.analysis.loss_ntk_calculation import LossNTKCalculation
 from znnl.loss_functions import SimpleLoss
 from znnl.models.jax_model import JaxModel
 from znnl.training_recording.data_storage import DataStorage
@@ -157,6 +158,9 @@ class JaxRecorder:
     loss_derivative: bool = False
     _loss_derivative_array: list = None
 
+    # Use the loss NTK in place of the regular NTK
+    use_loss_ntk: bool = False
+
     # Class helpers
     update_rate: int = 1
     _loss_fn: SimpleLoss = None
@@ -165,8 +169,12 @@ class JaxRecorder:
     _model: JaxModel = None
     _data_set: dict = None
     _compute_ntk: bool = False  # Helps to know if we can compute it once and share.
+    _compute_loss_ntk: bool = (
+        False  # Helps to know if we can compute it once and share.
+    )
     _compute_loss_derivative: bool = False
     _loss_derivative_fn: LossDerivative = False
+    _loss_ntk_calculator: LossNTKCalculation = None
     _index_count: int = 0  # Helps to avoid problems with non-1 update rates.
     _data_storage: DataStorage = None  # For writing to disk.
 
@@ -178,7 +186,7 @@ def _read_selected_attributes(self):
         self._selected_properties = [
             value
             for value in list(vars(self))
-            if value[0] != "_" and vars(self)[value] is True
+            if value[0] != "_" and value != "use_loss_ntk" and vars(self)[value] is True
         ]
 
     def _build_or_resize_array(self, name: str, overwrite: bool):
@@ -254,19 +262,38 @@ def instantiate_recorder(self, data_set: dict = None, overwrite: bool = False):
             self._index_count = 0
 
         # Check if we need an NTK computation and update the class accordingly
-        if any([
-            "ntk" in self._selected_properties,
-            "covariance_ntk" in self._selected_properties,
-            "magnitude_ntk" in self._selected_properties,
-            "entropy" in self._selected_properties,
-            "magnitude_entropy" in self._selected_properties,
-            "magnitude_variance" in self._selected_properties,
-            "covariance_entropy" in self._selected_properties,
-            "eigenvalues" in self._selected_properties,
-            "trace" in self._selected_properties,
-        ]):
+        if any(
+            [
+                "ntk" in self._selected_properties,
+                "covariance_ntk" in self._selected_properties,
+                "magnitude_ntk" in self._selected_properties,
+                "entropy" in self._selected_properties,
+                "magnitude_entropy" in self._selected_properties,
+                "magnitude_variance" in self._selected_properties,
+                "covariance_entropy" in self._selected_properties,
+                "eigenvalues" in self._selected_properties,
+                "trace" in self._selected_properties,
+            ]
+        ):
             self._compute_ntk = True
 
+        # Check if we need a loss NTK computation and update the class accordingly
+        if self.use_loss_ntk:
+            try:
+                self._loss_ntk_calculator = LossNTKCalculation(
+                    metric_fn=self._loss_fn.metric,
+                    model=self._model,
+                    dataset=self._data_set,
+                )
+            except AttributeError:
+                # This happens regularly during instantiation, as the loss
+                # function is often set later. As long as it is set before the
+                # loss NTK is computed, we just log the issue and continue.
+                logger.info(
+                    "The loss function has not been set yet. "
+                    "Please set it before training."
+                )
+
         if "loss_derivative" in self._selected_properties:
             self._loss_derivative_fn = LossDerivative(self._loss_fn)
 
@@ -300,7 +327,7 @@ def update_recorder(self, epoch: int, model: JaxModel):
         parsed_data["predictions"] = predictions
 
         # Compute ntk here to avoid repeated computation.
-        if self._compute_ntk:
+        if self._compute_ntk and not self.use_loss_ntk:
             try:
                 ntk = self._model.compute_ntk(
                     self._data_set["inputs"], infinite=False
@@ -321,6 +348,15 @@ def update_recorder(self, epoch: int, model: JaxModel):
                 self.eigenvalues = False
                 self._read_selected_attributes()
 
+        # Compute the loss NTK here to avoid repeated computation.
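+        # When use_loss_ntk is set, the loss NTK is stored under the "ntk" key
+        # of parsed_data, so the downstream NTK-based observables (entropy,
+        # eigenvalues, trace, ...) are evaluated on the loss NTK instead.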
+        if self._compute_ntk and self.use_loss_ntk:
+            parsed_data["ntk"] = self._loss_ntk_calculator.compute_loss_ntk(
+                x_i=self._data_set,
+                x_j=None,
+                model=self._model,
+                infinite=False,  # Set to True to compute the infinite-width limit
+            )["empirical"]
+
         for item in self._selected_properties:
             call_fn = getattr(self, f"_update_{item}")  # get the callable function
diff --git a/znnl/training_strategies/simple_training.py b/znnl/training_strategies/simple_training.py
index 325bf7e..707e37c 100644
--- a/znnl/training_strategies/simple_training.py
+++ b/znnl/training_strategies/simple_training.py
@@ -109,9 +109,11 @@ def __init__(
 
         self.review_metric = None
 
-        # Add the loss and accuracy function to the recorders and re-instantiate them
+        # Add the loss and accuracy functions and the model
+        # to the recorders and re-instantiate them.
         if self.recorders is not None:
             for item in self.recorders:
+                item._model = self.model
                 item.loss_fn = loss_fn
                 item.accuracy_fn = accuracy_fn
                 item.instantiate_recorder()
diff --git a/znnl/visualization/tsne_visualizer.py b/znnl/visualization/tsne_visualizer.py
index 856ecdd..d0d613a 100644
--- a/znnl/visualization/tsne_visualizer.py
+++ b/znnl/visualization/tsne_visualizer.py
@@ -104,45 +104,47 @@ def run_visualization(self):
         fig_dict["layout"]["xaxis2"] = {"domain": [0.8, 1.0]}
         fig_dict["layout"]["yaxis2"] = {"anchor": "x2"}
         fig_dict["layout"]["hovermode"] = "closest"
-        fig_dict["layout"]["updatemenus"] = [{
-            "buttons": [
-                {
-                    "args": [
-                        None,
-                        {
-                            "frame": {"duration": 500, "redraw": False},
-                            "fromcurrent": True,
-                            "transition": {
-                                "duration": 300,
-                                "easing": "quadratic-in-out",
-                            },
-                        },
-                    ],
-                    "label": "Play",
-                    "method": "animate",
-                },
-                {
-                    "args": [
-                        [None],
-                        {
-                            "frame": {"duration": 0, "redraw": False},
-                            "mode": "immediate",
-                            "transition": {"duration": 0},
-                        },
-                    ],
-                    "label": "Pause",
-                    "method": "animate",
-                },
-            ],
-            "direction": "left",
-            "pad": {"r": 10, "t": 87},
-            "showactive": False,
-            "type": "buttons",
-            "x": 0.1,
-            "xanchor": "right",
-            "y": 0,
-            "yanchor": "top",
-        }]
+        fig_dict["layout"]["updatemenus"] = [
+            {
+                "buttons": [
+                    {
+                        "args": [
+                            None,
+                            {
+                                "frame": {"duration": 500, "redraw": False},
+                                "fromcurrent": True,
+                                "transition": {
+                                    "duration": 300,
+                                    "easing": "quadratic-in-out",
+                                },
+                            },
+                        ],
+                        "label": "Play",
+                        "method": "animate",
+                    },
+                    {
+                        "args": [
+                            [None],
+                            {
+                                "frame": {"duration": 0, "redraw": False},
+                                "mode": "immediate",
+                                "transition": {"duration": 0},
+                            },
+                        ],
+                        "label": "Pause",
+                        "method": "animate",
+                    },
+                ],
+                "direction": "left",
+                "pad": {"r": 10, "t": 87},
+                "showactive": False,
+                "type": "buttons",
+                "x": 0.1,
+                "xanchor": "right",
+                "y": 0,
+                "yanchor": "top",
+            }
+        ]
 
         sliders_dict = {
             "active": 0,
@@ -163,20 +165,24 @@
         }
 
         # Add initial data
-        fig_dict["data"].append({
-            "x": self.dynamic[0][:, 0],
-            "y": self.dynamic[0][:, 1],
-            "mode": "markers",
-            "name": "Predictor",
-        })
-        fig_dict["data"].append({
-            "x": self.reference[0][:, 0],
-            "y": self.reference[0][:, 1],
-            "mode": "markers",
-            "xaxis": "x2",
-            "yaxis": "y2",
-            "name": "Target",
-        })
+        fig_dict["data"].append(
+            {
+                "x": self.dynamic[0][:, 0],
+                "y": self.dynamic[0][:, 1],
+                "mode": "markers",
+                "name": "Predictor",
+            }
+        )
+        fig_dict["data"].append(
+            {
+                "x": self.reference[0][:, 0],
+                "y": self.reference[0][:, 1],
+                "mode": "markers",
+                "xaxis": "x2",
+                "yaxis": "y2",
+                "name": "Target",
+            }
+        )
 
         # Make the figure frames.
         for i, item in enumerate(self.dynamic):