Skip to content

Commit

Permalink
Generalize cube extract (#2939)
Browse files Browse the repository at this point in the history
* generalize use of term "spectrum"
* move code to allow importing slice mixins without circular imports
* is_slice_indicator filter for extract results instead of relying on a spectrum viewer existing and being set
* attach marks to slice_indicator_viewers[0] instead of spectrum_viewer (temporarily to just the first instance, later we'll make this more flexible to allow any number of slice indicator viewers)
* remove reliance on loaded_flux_cube
* rename (internal) spectral_cube -> cube
* generalize logic for setting disabled_msg
* generalize internal marks names
* generalize some internal traitlet names
* update disabled message to specify cube
* update tests for renamed marks/method
* generalize logic within live-preview and extract
* allow skipping wavelength-dependence
* hide aperture masking method for entire cube and no bg
* fix (and allow disabling) bg export
* fix updating live preview on change to subset
* skip updating marks if not tray instance
  • Loading branch information
kecnry authored Jul 30, 2024
1 parent dc2186e commit 682e302
Show file tree
Hide file tree
Showing 20 changed files with 473 additions and 411 deletions.
4 changes: 4 additions & 0 deletions docs/reference/api_nuts_bolts.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ Utilities API
:no-inheritance-diagram:
:no-inherited-members:

.. automodapi:: jdaviz.configs.cubeviz.plugins.mixins
:no-inheritance-diagram:
:no-inherited-members:

.. automodapi:: jdaviz.configs.imviz.wcs_utils
:no-inheritance-diagram:
:no-inherited-members:
Expand Down
1 change: 1 addition & 0 deletions jdaviz/configs/cubeviz/plugins/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .tools import * # noqa
from .mixins import * # noqa
from .viewers import * # noqa
from .parsers import * # noqa
from .moment_maps.moment_maps import * # noqa
Expand Down
148 changes: 148 additions & 0 deletions jdaviz/configs/cubeviz/plugins/mixins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
import numpy as np
import astropy.units as u
from functools import cached_property

from jdaviz.core.marks import SliceIndicatorMarks

__all__ = ['WithSliceIndicator', 'WithSliceSelection']


class WithSliceIndicator:
@property
def slice_component_label(self):
return str(self.state.x_att)

@property
def slice_display_unit_name(self):
return 'spectral'

@cached_property
def slice_indicator(self):
# SliceIndicatorMarks does not yet exist
slice_indicator = SliceIndicatorMarks(self)
self.figure.marks = self.figure.marks + slice_indicator.marks
return slice_indicator

@property
def slice_values(self):
# NOTE: these are cached at the slice-plugin level
# Retrieve display units
slice_display_units = self.jdaviz_app._get_display_unit(
self.slice_display_unit_name
)

def _get_component(layer):
try:
# Retrieve layer data and units
data_comp = layer.layer.data.get_component(self.slice_component_label)
except (AttributeError, KeyError):
# layer either does not have get_component (because its a subset)
# or slice_component_label is not a component in this layer
# either way, return an empty array and skip this layer
return np.array([])

# Convert axis if display units are set and are different
data_units = getattr(data_comp, 'units', None)
if slice_display_units and data_units and slice_display_units != data_units:
data = np.asarray(data_comp.data, dtype=float) * u.Unit(data_units)
return data.to_value(slice_display_units,
equivalencies=u.spectral())
else:
return data_comp.data
try:
return np.asarray(np.unique(np.concatenate([_get_component(layer) for layer in self.layers])), # noqa
dtype=float)
except ValueError:
# NOTE: this will result in caching an empty list
return np.array([])

def _set_slice_indicator_value(self, value):
# this is a separate method so that viewers can override and map value if necessary
# NOTE: on first call, this will initialize the indicator itself
self.slice_indicator.value = value


class WithSliceSelection:
@property
def slice_index(self):
# index in state.slices corresponding to the slice axis
return 2

@property
def slice_component_label(self):
slice_plg = self.jdaviz_helper.plugins.get('Slice', None)
if slice_plg is None: # pragma: no cover
raise ValueError("slice plugin must be activated to access slice_component_label")
return slice_plg._obj.slice_indicator_viewers[0].slice_component_label

@property
def slice_display_unit_name(self):
return 'spectral'

@property
def slice_values(self):
# NOTE: these are cached at the slice-plugin level
# TODO: add support for multiple cubes (but then slice selection needs to be more complex)
# if slice_index is 0, then we want the equivalent of [:, 0, 0]
# if slice_index is 1, then we want the equivalent of [0, :, 0]
# if slice_index is 2, then we want the equivalent of [0, 0, :]
take_inds = [2, 1, 0]
take_inds.remove(self.slice_index)
converted_axis = np.array([])
for layer in self.layers:
world_comp_ids = layer.layer.data.world_component_ids
if self.slice_index >= len(world_comp_ids):
# Case where 2D image is loaded in image viewer
continue

# Retrieve display units
slice_display_units = self.jdaviz_app._get_display_unit(
self.slice_display_unit_name
)

try:
# Retrieve layer data and units using the slice index of the world components ids
data_comp = layer.layer.data.get_component(world_comp_ids[self.slice_index])
except (AttributeError, KeyError):
continue

data = np.asarray(data_comp.data.take(0, take_inds[0]).take(0, take_inds[1]), # noqa
dtype=float)

# Convert to display units if applicable
data_units = getattr(data_comp, 'units', None)
if slice_display_units and data_units and slice_display_units != data_units:
converted_axis = (data * u.Unit(data_units)).to_value(
slice_display_units,
equivalencies=u.spectral() + u.pixel_scale(1*u.pix)
)
else:
converted_axis = data

return converted_axis

@property
def slice(self):
return self.state.slices[self.slice_index]

@slice.setter
def slice(self, slice):
# NOTE: not intended for user-access - this should be controlled through the slice plugin
# in order to sync with all other viewers/slice indicators
slices = [0, 0, 0]
slices[self.slice_index] = slice
self.state.slices = tuple(slices)

@property
def slice_value(self):
return self.slice_values[self.slice]

@slice_value.setter
def slice_value(self, slice_value):
# NOTE: not intended for user-access - this should be controlled through the slice plugin
# in order to sync with all other viewers/slice indicators
# find the slice nearest slice_value
slice_values = self.slice_values
if not len(slice_values):
return
self.slice = np.argmin(abs(slice_values - slice_value))
4 changes: 0 additions & 4 deletions jdaviz/configs/cubeviz/plugins/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ def parse_data(app, file_obj, data_type=None, data_label=None,
flux_viewer_reference_name=flux_viewer_reference_name,
uncert_viewer_reference_name=uncert_viewer_reference_name
)
app.get_tray_item_from_name("Spectral Extraction").disabled_msg = ""
elif isinstance(file_obj, str):
if file_obj.lower().endswith('.gif'): # pragma: no cover
_parse_gif(app, file_obj, data_label,
Expand Down Expand Up @@ -135,7 +134,6 @@ def parse_data(app, file_obj, data_type=None, data_label=None,
flux_viewer_reference_name=flux_viewer_reference_name,
uncert_viewer_reference_name=uncert_viewer_reference_name
)
app.get_tray_item_from_name("Spectral Extraction").disabled_msg = ""

# If the data types are custom data objects, use explicit parsers. Note
# that this relies on the glue-astronomy machinery to turn the data object
Expand All @@ -152,13 +150,11 @@ def parse_data(app, file_obj, data_type=None, data_label=None,
app, file_obj, data_label=data_label,
spectrum_viewer_reference_name=spectrum_viewer_reference_name
)
app.get_tray_item_from_name("Spectral Extraction").disabled_msg = ""

elif isinstance(file_obj, np.ndarray) and file_obj.ndim == 3:
_parse_ndarray(app, file_obj, data_label=data_label, data_type=data_type,
flux_viewer_reference_name=flux_viewer_reference_name,
uncert_viewer_reference_name=uncert_viewer_reference_name)
app.get_tray_item_from_name("Spectral Extraction").disabled_msg = ""
else:
raise NotImplementedError(f'Unsupported data format: {file_obj}')

Expand Down
Loading

0 comments on commit 682e302

Please sign in to comment.