Skip to content

Commit

Permalink
Merge pull request #204 from pycroscopy/fitter_viz_update
Browse files Browse the repository at this point in the history
Fitter viz update
  • Loading branch information
ramav87 committed Apr 9, 2024
2 parents fbdc67b + 3b49d7a commit d7352ba
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 14 deletions.
60 changes: 52 additions & 8 deletions sidpy/viz/dataset_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
import ipywidgets
from IPython.display import display
import scipy

import dill
import base64

# import matplotlib.animation as animation

Expand Down Expand Up @@ -2058,29 +2059,73 @@ def set_legend(self, set_legend):
def get_xy(self):
return [self.x, self.y]


class SpectralImageFitVisualizer(SpectralImageVisualizer):

def __init__(self, original_dataset, fit_dataset, figure=None, horizontal=True):
def __init__(self, original_dataset, fit_dataset, xvec = None, figure=None, horizontal=True):
'''
Visualizer for spectral image datasets, fit by the Sidpy Fitter
This class is called by Sidpy Fitter for visualizing the raw/fit dataset interactively.
Inputs:
- original_dataset: sidpy.Dataset containing the raw data
- fit_dataset: sidpy.Dataset with the fitted data. This is returned by the
Sidpy Fitter after functional fitting.
- fit_dataset: sidpy.Dataset with the fitted parameters, or the sidpy.Dataset returned by SidpyFitter.
- xvec: Independent dimension vector, default is None (will be acquired from original_dataset if not provided).
- figure: (Optional, default None) - handle to existing figure
- horiziontal: (Optional, default True) - whether spectrum should be plotted horizontally
'''

super().__init__(original_dataset, figure, horizontal)

self.original_dataset = original_dataset
if xvec is not None:
self.xvec = xvec
else:
self.xvec = None
if fit_dataset.shape != original_dataset.shape: #check if we have an actual fitted dataset or just the parameters
self.fit_parameters = fit_dataset
self.fit_dset = self._return_fit_dataset()
else:
self.fit_parameters = None
self.fit_dset = fit_dataset

self.fit_dset = fit_dataset
self.axes[1].clear()
self.get_fit_spectrum()
self.axes[1].plot(self.energy_scale, self.spectrum, 'bo')
self.axes[1].plot(self.energy_scale, self.fit_spectrum, 'r-')

def _return_fit_dataset(self):
#let's get back the fit function
fit_fn_packed = self.fit_parameters.metadata['fitting_functions']
key_f = list(self.fit_parameters.metadata['fitting_functions'].keys())[0]
encoded_value = fit_fn_packed[key_f]
serialized_value = base64.b64decode(encoded_value)
self._fit_function = dill.loads(serialized_value)

#Let's get the independent vector
if self.xvec is None:
ind_dims = []
for ind, (shape1, shape2) in enumerate(zip(self.fit_parameters.shape, self.original_dataset.shape)):
if shape1!=shape2:
ind_dims.append(ind)

#We need to get the vector.
if len(ind_dims)>1:
raise NotImplementedError("2 dimensional indepndent vectors are not implemented yet. TODO!")
else:
ind_vec = self.original_dataset._axes[ind_dims[0]].values
else:
ind_vec = self.xvec.copy()

#create a copy of the original dataset
self.fitted_dataset = self.original_dataset.copy()
self.fitted_dataset = self.fitted_dataset.fold(method = 'spaspec') #TODO: this might not always be the case.
self.fit_parameters_folded = self.fit_parameters[:].reshape((self.fitted_dataset.shape[0],-1))

for ind in range(self.fitted_dataset.shape[0]):
self.fitted_dataset[ind,:] = self._fit_function(ind_vec, *self.fit_parameters_folded[ind])
fitted_dataset = self.fitted_dataset.unfold()

return fitted_dataset

def get_fit_spectrum(self):

Expand Down Expand Up @@ -2128,5 +2173,4 @@ def _update(self, ev=None):
self.axes[1].set_ylabel(self.ylabel)

self.fig.canvas.draw_idle()



68 changes: 62 additions & 6 deletions tests/viz/test_dataset_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,48 @@
import numpy as np
sys.path.insert(0, "../../sidpy/")
import sidpy

from sidpy.proc.fitter import SidFitter


def get_fit_dataset(dset_shape=(5,5,32)):
#Define the function we want each spectrum to

def one_lin_func(xvec, *coeff):
a1,a2 = coeff
return a1*xvec + a2


#create a dataset
xvec = np.linspace(0,1, dset_shape[-1])
data_mat = np.zeros(shape=(dset_shape[0]*dset_shape[1], dset_shape[2]))
noise_level = 0.10

for xind in range(data_mat.shape[0]):
y_values = one_lin_func(xvec, *[np.random.uniform(0,1), np.random.normal()]) + \
noise_level*np.random.normal(size=len(xvec))
data_mat[xind] = y_values

data_mat = data_mat.reshape(dset_shape)

#make it a sidpy dataset
data_set = sidpy.Dataset.from_array(data_mat, name='test_dataset')
data_set.data_type = 'spectral_image'
data_set.units = 'nA'
data_set.quantity = 'Current'

data_set.set_dimension(0, sidpy.Dimension(np.arange(data_set.shape[0]),
name='x', units='um', quantity='Length',
dimension_type='spatial'))
data_set.set_dimension(1, sidpy.Dimension(np.arange(data_set.shape[0]),
'y', units='um', quantity='Length',
dimension_type='spatial'))
data_set.set_dimension(2, sidpy.Dimension(xvec,
name = 'bias',quantity = 'V', units = 'V', dimension_type='spectral'))
fitter = SidFitter(data_set, one_lin_func,num_workers=4,
threads=2, return_cov=False, return_fit=True, return_std=False,
km_guess=False,num_fit_parms = 2)
output = fitter.do_fit()
return data_set, output[0], output[1]

def get_spectrum(dtype=float):
x = np.array(np.random.normal(3, 2.5, size=1024), dtype=dtype)
Expand Down Expand Up @@ -434,11 +475,6 @@ def test_point_selection(self):
self.assertTrue(np.allclose(actual, expected, equal_nan=True, rtol=1e-05, atol=1e-08))







class Test4DImageStackPlot(unittest.TestCase):

def test_plot(self):
Expand Down Expand Up @@ -481,5 +517,25 @@ def test_plot_complex(self):
view = dataset.plot()
self.assertEqual(len(view.axes), 3)


class TestSpectralImageFitVisualizer(unittest.TestCase):

def test_plot_with_fit_parms(self):
original_dataset, fit_parameters, fitted_dataset = get_fit_dataset()
view = sidpy.viz.dataset_viz.SpectralImageFitVisualizer(original_dataset, fit_parameters)
self.assertEqual(len(view.axes), 2)

def test_plot_with_fitted_dataset(self):
original_dataset, fit_parameters, fitted_dataset = get_fit_dataset()
view = sidpy.viz.dataset_viz.SpectralImageFitVisualizer(original_dataset, fitted_dataset)
self.assertEqual(len(view.axes), 2)

def test_plot_with_custom_xvec(self):
original_dataset, fit_parameters, fitted_dataset = get_fit_dataset()
xvec = np.linspace(-1,2,32)
view = sidpy.viz.dataset_viz.SpectralImageFitVisualizer(original_dataset, fit_parameters, xvec = xvec)
self.assertEqual(len(view.axes), 2)


if __name__ == '__main__':
unittest.main()

0 comments on commit d7352ba

Please sign in to comment.