Skip to content

Commit

Permalink
Merge pull request #205 from pycroscopy/complex_fit_viz
Browse files Browse the repository at this point in the history
complex fitter viz
  • Loading branch information
ramav87 committed Apr 12, 2024
2 parents d7352ba + f8ebbc2 commit 64dde6d
Showing 1 changed file with 152 additions and 0 deletions.
152 changes: 152 additions & 0 deletions sidpy/viz/dataset_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -2173,4 +2173,156 @@ def _update(self, ev=None):
self.axes[1].set_ylabel(self.ylabel)

self.fig.canvas.draw_idle()

class ComplexSpectralImageFitVisualizer(ComplexSpectralImageVisualizer):

def __init__(self, original_dataset, fit_parameters, xvec = None, figure=None, horizontal=True, **kwargs):
'''
Visualizer for complex spectral image datasets, fit by the Sidpy Fitter
This class is called by Sidpy Fitter for visualizing the raw/fit dataset interactively.
Inputs:
- original_dataset: sidpy.Dataset containing the raw data
- fit_dataset: sidpy.Dataset with the fitted parameters.
- xvec: Independent dimension vector, default is None (will be acquired from original_dataset if not provided).
- figure: (Optional, default None) - handle to existing figure
- horiziontal: (Optional, default True) - whether spectrum should be plotted horizontally
'''

super().__init__(original_dataset, figure, horizontal, **kwargs)

self.original_dataset = original_dataset
self.fit_parameters = fit_parameters
self.fit_dset = self._return_fit_dataset()
print('fit dataset shape is {}'.format(self.fit_dset.shape))
self.axes[1].clear()
self.get_fit_spectrum()
self.axes[1].plot(self.energy_scale, self.spectrum, 'bo')
self.axes[1].plot(self.energy_scale, self.fit_spectrum, 'r-')

def _return_fit_dataset(self):
#let's get back the fit function
fit_fn_packed = self.fit_parameters.metadata['fitting_functions']
key_f = list(self.fit_parameters.metadata['fitting_functions'].keys())[0]
encoded_value = fit_fn_packed[key_f]
serialized_value = base64.b64decode(encoded_value)
self._fit_function = dill.loads(serialized_value)

#Let's get the independent vector
ind_dims = []
for ind, (shape1, shape2) in enumerate(zip(self.fit_parameters.shape, self.original_dataset.shape)):
if shape1!=shape2:
ind_dims.append(ind)

#We need to get the vector.
if len(ind_dims)>1:
raise NotImplementedError("2 dimensional indepndent vectors are not implemented yet. TODO!")
else:
ind_vec = self.original_dataset._axes[ind_dims[0]].values

#create a copy of the original dataset
self.fitted_dataset = self.original_dataset.copy()
self.fitted_dataset = self.fitted_dataset.fold(method = 'spaspec') #TODO: this might not always be the case.

self.fit_parameters_folded = self.fit_parameters[:].reshape((self.fitted_dataset.shape[0],-1))

for ind in range(self.fitted_dataset.shape[0]):
output = self._fit_function(ind_vec, *self.fit_parameters_folded[ind])
real_part = output[:len(output)//2]
imag_part = output[len(output)//2:]
complex_n = real_part + 1j*imag_part
self.fitted_dataset[ind,:] = complex_n

fitted_dataset = self.fitted_dataset.unfold().squeeze()
return fitted_dataset

def get_fit_spectrum(self):

if self.x > self.dset.shape[self.image_dims[0]] - self.bin_x:
self.x = self.dset.shape[self.image_dims[0]] - self.bin_x
if self.y > self.dset.shape[self.image_dims[1]] - self.bin_y:
self.y = self.dset.shape[self.image_dims[1]] - self.bin_y
selection = []

for dim, axis in self.dset._axes.items():
if axis.dimension_type == sidpy.DimensionType.SPATIAL:
if dim == self.image_dims[0]:
selection.append(slice(self.x, self.x + self.bin_x))
else:
selection.append(slice(self.y, self.y + self.bin_y))

elif axis.dimension_type == sidpy.DimensionType.SPECTRAL:
selection.append(slice(None))
else:
selection.append(slice(0, 1))
self.selection = selection

self.spectrum = self.dset[tuple(selection)].mean(axis=tuple(self.image_dims))
self.fit_spectrum = self.fit_dset[tuple(selection)].mean(axis=tuple(self.image_dims))
# * self.intensity_scale[self.x,self.y]

return self.fit_spectrum.squeeze(), self.spectrum.squeeze()
#return self.spectrum.squeeze()


def _update_backup(self, ev=None):

xlim = self.axes[1].get_xlim()
ylim = self.axes[1].get_ylim()
self.axes[1].clear()
self.get_fit_spectrum()

self.axes[1].plot(self.energy_scale, self.spectrum, 'bo', label='experiment')
self.axes[1].plot(self.energy_scale, self.fit_spectrum, 'r-', label='fit')

if self.set_title:
self.axes[1].set_title('spectrum {}, {}'.format(self.x, self.y))

self.axes[1].set_xlim(xlim)
#self.axes[1].set_ylim(ylim)
self.axes[1].set_xlabel(self.xlabel)
self.axes[1].set_ylabel(self.ylabel)

self.fig.canvas.draw_idle()

def _update(self, ev=None):
"""
xlim_ax1 = self.axes[1].get_xlim()
ylim_ax1 = self.axes[1].get_ylim()
xlim_ax2 = self.axes[2].get_xlim()
ylim_ax2 = self.axes[2].get_ylim()
xlims = [xlim_ax1,xlim_ax2]
ylims = [ylim_ax1, ylim_ax2]
self.axes[1].clear()
self.axes[2].clear()"""
#xlim = self.axes[1].get_xlim()
#ylim = self.axes[1].get_ylim()
self.axes[1].clear()
self.axes[2].clear()

fit_spectrum, raw_spectrum = self.get_fit_spectrum()
print(fit_spectrum, raw_spectrum)
if self.ri_ap == 'Real and Imaginary':
self.axes[1].plot(self.energy_scale, np.real(raw_spectrum.compute()), 'bo', label='Real Raw')
self.axes[2].plot(self.energy_scale, np.imag(raw_spectrum.compute()), 'bo', label='Imaginary Raw')
self.axes[1].plot(self.energy_scale, np.real(fit_spectrum.compute()), 'r-', label='Real Fit')
self.axes[2].plot(self.energy_scale, np.imag(fit_spectrum.compute()), 'r-', label='Imaginary Fit')
else:
self.axes[1].plot(self.energy_scale, np.abs(raw_spectrum.compute()), 'bo', label='Raw Amplitude')
self.axes[2].plot(self.energy_scale, np.angle(raw_spectrum.compute()), 'bo', label='Raw Phase')
self.axes[1].plot(self.energy_scale, np.abs(fit_spectrum.compute()), 'r-', label='Fit Amplitude')
self.axes[2].plot(self.energy_scale, np.angle(fit_spectrum.compute()), 'r-', label='Fit Phase')

for ind,ax_ind in enumerate([1,2]):
if self.set_title:
self.axes[ax_ind].set_title('spectrum {}, {}'.format(self.x, self.y))

self.axes[ax_ind].set_xlabel(self.xlabel)
self.axes[ax_ind].set_ylabel(self.ylabel)
leg = self.axes[ax_ind].legend(loc = 'best')
leg.get_frame().set_linewidth(0.0)
self.fig.canvas.draw_idle()
self.fig.tight_layout()

0 comments on commit 64dde6d

Please sign in to comment.