diff --git a/.github/labeler.yml b/.github/labeler.yml index 71b6f2320d..6078d8ce02 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -60,3 +60,8 @@ specviz2d: - changed-files: - any-glob-to-any-file: - jdaviz/configs/specviz2d/**/* + +rampviz: +- changed-files: + - any-glob-to-any-file: + - jdaviz/configs/rampviz/**/* diff --git a/docs/cubeviz/index.rst b/docs/cubeviz/index.rst index 6282b9c0ba..0500ff8975 100644 --- a/docs/cubeviz/index.rst +++ b/docs/cubeviz/index.rst @@ -11,7 +11,7 @@ :alt: Introductory video tour of the Cubeviz configuration and its features Cubeviz is a visualization and analysis toolbox for data cubes from -integral field units (IFUs). It is built as part of the +integral field units (IFUs). It is built on top of the `glue visualization `_ tool. Cubeviz is designed to work with data cubes from the NIRSpec and MIRI instruments on JWST, and will work with IFU data cubes. It uses diff --git a/docs/index.rst b/docs/index.rst index 79a848f55b..9978d79e66 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -55,6 +55,16 @@ Jdaviz Jump to Mosviz + .. grid-item-card:: + :img-top: logos/cube.svg + + .. button-ref:: rampviz/index + :expand: + :color: primary + :click-parent: + + Jump to Rampviz + ``jdaviz`` is a package of astronomical data analysis visualization tools based on the Jupyter platform. These GUI-based tools link data visualization and interactive analysis. They are designed to work diff --git a/docs/index_using_jdaviz.rst b/docs/index_using_jdaviz.rst index 6f47c01235..19f4a6e750 100644 --- a/docs/index_using_jdaviz.rst +++ b/docs/index_using_jdaviz.rst @@ -15,6 +15,7 @@ User Guide cubeviz/index specviz2d/index mosviz/index + rampviz/index plugin_api save_state display diff --git a/docs/rampviz/index.rst b/docs/rampviz/index.rst new file mode 100644 index 0000000000..672ec971e3 --- /dev/null +++ b/docs/rampviz/index.rst @@ -0,0 +1,21 @@ +.. |cubeviz_logo| image:: ../logos/cube.svg + :height: 42px + +.. 
_rampviz: + +###################### +|cubeviz_logo| Rampviz +###################### + + +Rampviz is a visualization and analysis toolbox for ramp cubes from +infrared detectors. It is built on top of the +`glue visualization `_ tool. Rampviz is designed to work +with ramp files from the Roman Space Telescope and JWST. + +**Using Rampviz** + +.. toctree:: + :maxdepth: 2 + + plugins diff --git a/docs/rampviz/plugins.rst b/docs/rampviz/plugins.rst new file mode 100644 index 0000000000..ccd94760d8 --- /dev/null +++ b/docs/rampviz/plugins.rst @@ -0,0 +1,66 @@ +********************* +Data Analysis Plugins +********************* + + + +.. _rampviz-metadata-viewer: + +Metadata Viewer +=============== + +.. seealso:: + + :ref:`Metadata Viewer ` + Documentation on using the metadata viewer. + + +.. _rampviz-plot-options: + +Plot Options +============ + +This plugin gives access to per-viewer and per-layer plotting options. +To show axes on image viewers, toggle on the "Show axes" option at the bottom of the plugin. + +.. seealso:: + + :ref:`Image Plot Options ` + Documentation on Imviz display settings in the Jdaviz viewers. + +.. _rampviz-subset-plugin: + +Subset Tools +============ + +.. seealso:: + + :ref:`Subset Tools ` + Imviz documentation describing the concept of subsets in Jdaviz. + + +Markers +======= + +.. seealso:: + + :ref:`Markers ` + Imviz documentation describing the markers plugin. + +.. _rampviz-slice: + +Slice +===== + +.. seealso:: + + :ref:`Slice ` + Documentation on using the Slice plugin. + +.. _ramp-extraction: + +Ramp Extraction +=============== + + + diff --git a/docs/reference/api_configs.rst b/docs/reference/api_configs.rst index c5826d2ba4..bceb0a73c3 100644 --- a/docs/reference/api_configs.rst +++ b/docs/reference/api_configs.rst @@ -20,3 +20,6 @@ Helpers API .. automodapi:: jdaviz.configs.specviz2d.helper :no-inheritance-diagram: + +.. 
automodapi:: jdaviz.configs.rampviz.helper + :no-inheritance-diagram: diff --git a/docs/reference/api_parsers.rst b/docs/reference/api_parsers.rst index 9272aa20c3..99665e245c 100644 --- a/docs/reference/api_parsers.rst +++ b/docs/reference/api_parsers.rst @@ -17,3 +17,6 @@ Parsers API .. automodapi:: jdaviz.configs.specviz2d.plugins.parsers :no-inheritance-diagram: + +.. automodapi:: jdaviz.configs.rampviz.plugins.parsers + :no-inheritance-diagram: \ No newline at end of file diff --git a/docs/reference/api_plugins.rst b/docs/reference/api_plugins.rst index 185b1a652a..2c15d1ea1e 100644 --- a/docs/reference/api_plugins.rst +++ b/docs/reference/api_plugins.rst @@ -48,6 +48,9 @@ Plugins API .. automodapi:: jdaviz.configs.cubeviz.plugins.slice.slice :no-inheritance-diagram: +.. automodapi:: jdaviz.configs.cubeviz.plugins.spectral_extraction.spectral_extraction + :no-inheritance-diagram: + .. automodapi:: jdaviz.configs.imviz.plugins.aper_phot_simple.aper_phot_simple :no-inheritance-diagram: @@ -86,3 +89,6 @@ Plugins API .. automodapi:: jdaviz.configs.specviz2d.plugins.spectral_extraction.spectral_extraction :no-inheritance-diagram: + +.. automodapi:: jdaviz.configs.rampviz.plugins.ramp_extraction.ramp_extraction + :no-inheritance-diagram: diff --git a/docs/reference/api_viewers.rst b/docs/reference/api_viewers.rst index 5908516f5c..21a5800c3b 100644 --- a/docs/reference/api_viewers.rst +++ b/docs/reference/api_viewers.rst @@ -17,3 +17,6 @@ Viewers API .. automodapi:: jdaviz.configs.specviz.plugins.viewers :no-inheritance-diagram: + +.. automodapi:: jdaviz.configs.rampviz.plugins.viewers + :no-inheritance-diagram: diff --git a/jdaviz/__init__.py b/jdaviz/__init__.py index 8553c6671e..1da9ab765b 100644 --- a/jdaviz/__init__.py +++ b/jdaviz/__init__.py @@ -13,11 +13,14 @@ # Top-level API as exposed to users. 
from jdaviz.app import * # noqa: F401, F403 -from jdaviz.configs.specviz import Specviz # noqa: F401 -from jdaviz.configs.specviz2d import Specviz2d # noqa: F401 -from jdaviz.configs.mosviz import Mosviz # noqa: F401 + from jdaviz.configs.cubeviz import Cubeviz # noqa: F401 from jdaviz.configs.imviz import Imviz # noqa: F401 +from jdaviz.configs.mosviz import Mosviz # noqa: F401 +from jdaviz.configs.rampviz import Rampviz # noqa: F401 +from jdaviz.configs.specviz import Specviz # noqa: F401 +from jdaviz.configs.specviz2d import Specviz2d # noqa: F401 + from jdaviz.utils import enable_hot_reloading # noqa: F401 from jdaviz.core.launcher import open # noqa: F401 diff --git a/jdaviz/app.py b/jdaviz/app.py index b552fc018e..3a3c7ad0e7 100644 --- a/jdaviz/app.py +++ b/jdaviz/app.py @@ -99,10 +99,14 @@ def to_unit(self, data, cid, values, original_units, target_units): # should return the converted values. Note that original_units # gives the units of the values array, which might not be the same # as the original native units of the component in the data. 
- if cid.label == "flux": + + if cid.label == 'Pixel Axis 0 [z]' and target_units == '': + # handle ramps loaded into Rampviz by avoiding conversion + # of the groups axis: + return values + elif cid.label == "flux": try: spec = data.get_object(cls=Spectrum1D) - except RuntimeError: data = data.get_object(cls=NDDataArray) spec = Spectrum1D(flux=data.data * u.Unit(original_units)) @@ -1290,6 +1294,9 @@ def _get_display_unit(self, axis): if check_if_unit_is_per_solid_angle(sv_y_unit): return sv_y_unit return sv_y_unit / u.sr + elif axis == 'temporal': + # No unit for ramp's time (group/resultant) axis: + return None else: raise ValueError(f"could not find units for axis='{axis}'") uc = self._jdaviz_helper.plugins.get('Unit Conversion')._obj @@ -1747,7 +1754,8 @@ def _get_first_viewer_reference_name( require_spectrum_2d_viewer=False, require_table_viewer=False, require_flux_viewer=False, - require_image_viewer=False + require_image_viewer=False, + require_profile_viewer=False, ): """ Return the viewer reference name of the first available viewer. 
@@ -1761,12 +1769,16 @@ def _get_first_viewer_reference_name( from jdaviz.configs.mosviz.plugins.viewers import ( MosvizTableViewer, MosvizProfile2DView ) + from jdaviz.configs.rampviz.plugins.viewers import ( + RampvizImageView, RampvizProfileView + ) spectral_viewers = (SpecvizProfileView, CubevizProfileView) spectral_2d_viewers = (MosvizProfile2DView, ) table_viewers = (MosvizTableViewer, ) - image_viewers = (ImvizImageView, CubevizImageView) - flux_viewers = (CubevizImageView, ) + image_viewers = (ImvizImageView, CubevizImageView, RampvizImageView) + flux_viewers = (CubevizImageView, RampvizImageView) + ramp_viewers = (RampvizProfileView, ) for vid in self._viewer_store: viewer_item = self._viewer_item_by_id(vid) @@ -1789,6 +1801,9 @@ def _get_first_viewer_reference_name( elif require_flux_viewer: if isinstance(self._viewer_store[vid], flux_viewers) and is_returnable: return viewer_item['reference'] + elif require_profile_viewer: + if isinstance(self._viewer_store[vid], ramp_viewers) and is_returnable: + return viewer_item['reference'] else: if is_returnable: return viewer_item['reference'] diff --git a/jdaviz/configs/__init__.py b/jdaviz/configs/__init__.py index eb078d8c5c..fca0324a7a 100644 --- a/jdaviz/configs/__init__.py +++ b/jdaviz/configs/__init__.py @@ -1,6 +1,7 @@ from .cubeviz import * # noqa -from .specviz import * # noqa -from .specviz2d import * # noqa from .default import * # noqa -from .mosviz import * # noqa from .imviz import * # noqa +from .mosviz import * # noqa +from .rampviz import * # noqa +from .specviz import * # noqa +from .specviz2d import * # noqa diff --git a/jdaviz/configs/cubeviz/cubeviz.yaml b/jdaviz/configs/cubeviz/cubeviz.yaml index 038df00431..d1e7dc9186 100644 --- a/jdaviz/configs/cubeviz/cubeviz.yaml +++ b/jdaviz/configs/cubeviz/cubeviz.yaml @@ -23,7 +23,7 @@ tray: - g-data-quality - g-subset-plugin - g-markers - - cubeviz-slice + - cube-slice - g-unit-conversion - cubeviz-spectral-extraction - g-gaussian-smooth diff --git 
a/jdaviz/configs/cubeviz/helper.py b/jdaviz/configs/cubeviz/helper.py index 36caca8029..fce4f61b15 100644 --- a/jdaviz/configs/cubeviz/helper.py +++ b/jdaviz/configs/cubeviz/helper.py @@ -1,9 +1,8 @@ -from jdaviz.core.events import SnackbarMessage -from jdaviz.core.helpers import ImageConfigHelper from jdaviz.configs.default.plugins.line_lists.line_list_mixin import LineListMixin from jdaviz.configs.specviz import Specviz -from jdaviz.core.events import (AddDataMessage, - SliceSelectSliceMessage) +from jdaviz.core.events import AddDataMessage, SnackbarMessage +from jdaviz.core.helpers import CubeConfigHelper +from jdaviz.configs.cubeviz.plugins.viewers import CubevizImageView __all__ = ['Cubeviz'] @@ -12,19 +11,22 @@ "Wavenumber", "Velocity", "Energy"] -class Cubeviz(ImageConfigHelper, LineListMixin): +class Cubeviz(CubeConfigHelper, LineListMixin): """Cubeviz Helper class""" _default_configuration = 'cubeviz' _default_spectrum_viewer_reference_name = "spectrum-viewer" _default_uncert_viewer_reference_name = "uncert-viewer" _default_flux_viewer_reference_name = "flux-viewer" _default_image_viewer_reference_name = "image-viewer" + _cube_viewer_default_label = _default_flux_viewer_reference_name _loaded_flux_cube = None _loaded_uncert_cube = None + _cube_viewer_cls = CubevizImageView def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + self.app.hub.subscribe(self, AddDataMessage, handler=self._set_spectrum_x_axis) @@ -106,8 +108,7 @@ def select_wavelength(self, wavelength): """ if not isinstance(wavelength, (int, float)): raise TypeError("wavelength must be a float or int") - msg = SliceSelectSliceMessage(value=wavelength, sender=self) - self.app.hub.broadcast(msg) + self.select_slice(wavelength) @property def specviz(self): diff --git a/jdaviz/configs/cubeviz/plugins/__init__.py b/jdaviz/configs/cubeviz/plugins/__init__.py index 61390844ae..4cc5c65a3a 100644 --- a/jdaviz/configs/cubeviz/plugins/__init__.py +++ 
b/jdaviz/configs/cubeviz/plugins/__init__.py @@ -1,7 +1,7 @@ -from .tools import * # noqa from .mixins import * # noqa from .viewers import * # noqa from .parsers import * # noqa from .moment_maps.moment_maps import * # noqa from .slice.slice import * # noqa from .spectral_extraction.spectral_extraction import * # noqa +from .tools import * # noqa diff --git a/jdaviz/configs/cubeviz/plugins/mixins.py b/jdaviz/configs/cubeviz/plugins/mixins.py index a5be11bf59..ea57ab01cb 100644 --- a/jdaviz/configs/cubeviz/plugins/mixins.py +++ b/jdaviz/configs/cubeviz/plugins/mixins.py @@ -14,7 +14,7 @@ def slice_component_label(self): @property def slice_display_unit_name(self): - return 'spectral' + return 'spectral' if self.jdaviz_app.config == 'cubeviz' else 'temporal' @cached_property def slice_indicator(self): @@ -77,7 +77,7 @@ def slice_component_label(self): @property def slice_display_unit_name(self): - return 'spectral' + return 'spectral' if self.jdaviz_app.config == 'cubeviz' else 'temporal' @property def slice_values(self): @@ -91,6 +91,11 @@ def slice_values(self): converted_axis = np.array([]) for layer in self.layers: world_comp_ids = layer.layer.data.world_component_ids + + if not len(world_comp_ids): + # rampviz uses coordinate components: + world_comp_ids = layer.layer.data.coordinate_components + if self.slice_index >= len(world_comp_ids): # Case where 2D image is loaded in image viewer continue diff --git a/jdaviz/configs/cubeviz/plugins/parsers.py b/jdaviz/configs/cubeviz/plugins/parsers.py index fcdeb44f1c..d9b2484a11 100644 --- a/jdaviz/configs/cubeviz/plugins/parsers.py +++ b/jdaviz/configs/cubeviz/plugins/parsers.py @@ -10,7 +10,6 @@ from astropy.wcs import WCS from specutils import Spectrum1D -from jdaviz.configs.imviz.plugins.parsers import prep_data_layer_as_dq from jdaviz.core.registries import data_parser_registry from jdaviz.utils import standardize_metadata, PRIHDR_KEY, download_uri_to_path @@ -342,7 +341,11 @@ def _parse_jwst_s3d(app, hdulist, 
data_label, ext='SCI', app.add_data(data, data_label, parent=parent) # get glue data and update if DQ: + if ext == 'DQ': + # prevent circular import: + from jdaviz.configs.imviz.plugins.parsers import prep_data_layer_as_dq + data = app.data_collection[-1] prep_data_layer_as_dq(data) diff --git a/jdaviz/configs/cubeviz/plugins/slice/slice.py b/jdaviz/configs/cubeviz/plugins/slice/slice.py index a1cf014677..0bd4380ca0 100644 --- a/jdaviz/configs/cubeviz/plugins/slice/slice.py +++ b/jdaviz/configs/cubeviz/plugins/slice/slice.py @@ -8,9 +8,12 @@ from astropy.units import UnitsWarning from traitlets import Bool, Int, Unicode, observe -from jdaviz.configs.cubeviz.plugins.viewers import (WithSliceIndicator, WithSliceSelection, - CubevizImageView) +from jdaviz.configs.cubeviz.plugins.viewers import ( + WithSliceIndicator, WithSliceSelection, CubevizImageView +) from jdaviz.configs.cubeviz.helper import _spectral_axis_names +from jdaviz.configs.rampviz.helper import _temporal_axis_names +from jdaviz.configs.rampviz.plugins.viewers import RampvizImageView from jdaviz.core.custom_traitlets import FloatHandleEmpty from jdaviz.core.events import (AddDataMessage, RemoveDataMessage, SliceToolStateMessage, SliceSelectSliceMessage, SliceValueUpdatedMessage, @@ -24,7 +27,7 @@ __all__ = ['Slice'] -@tray_registry('cubeviz-slice', label="Slice", viewer_requirements='spectrum') +@tray_registry('cube-slice', label="Slice") class Slice(PluginTemplateMixin): """ See the :ref:`Slice Plugin Documentation ` for more details. @@ -44,8 +47,7 @@ class Slice(PluginTemplateMixin): * ``show_value`` Whether to show slice value in label to right of indicator. 
""" - _cube_viewer_cls = CubevizImageView - _cube_viewer_default_label = 'flux-viewer' + cube_viewer_exists = Bool(True).tag(sync=True) allow_disable_snapping = Bool(False).tag(sync=True) # noqa internal use to show and allow disabling snap-to-slice @@ -101,6 +103,18 @@ def __init__(self, *args, **kwargs): handler=self._on_global_display_unit_changed) self._initialize_location() + @property + def _cube_viewer_default_label(self): + if hasattr(self.app, '_jdaviz_helper') and self.app._jdaviz_helper is not None: + return getattr(self.app._jdaviz_helper, '_cube_viewer_default_label') + return tuple() + + @property + def _cube_viewer_cls(self): + if hasattr(self.app, '_jdaviz_helper') and self.app._jdaviz_helper is not None: + return getattr(self.app._jdaviz_helper, '_cube_viewer_cls') + return tuple() + def _initialize_location(self, *args): # initialize value_unit (this has to wait until data is loaded to an existing # slice_indicator_viewer, so we'll keep trying until it is set - after that, changes @@ -137,11 +151,17 @@ def _initialize_location(self, *args): @property def slice_display_unit_name(self): # global display unit "axis" corresponding to the slice axis - return 'spectral' + if self.app.config == 'cubeviz': + return 'spectral' + elif self.app.config == 'rampviz': + return 'temporal' @property def valid_slice_att_names(self): - return _spectral_axis_names + ['Pixel Axis 2 [x]', 'World 0'] + if self.app.config == 'cubeviz': + return _spectral_axis_names + ['Pixel Axis 2 [x]', 'World 0'] + elif self.app.config == 'rampviz': + return _temporal_axis_names + ['Pixel Axis 2 [x]'] @property def slice_selection_viewers(self): @@ -166,7 +186,8 @@ def _check_if_cube_viewer_exists(self, *args): self.cube_viewer_exists = False def vue_create_cube_viewer(self, *args): - self.app._on_new_viewer(NewViewerMessage(self._cube_viewer_cls, data=None, sender=self.app), + cls = RampvizImageView if self.app.config == 'rampviz' else CubevizImageView + 
self.app._on_new_viewer(NewViewerMessage(cls, data=None, sender=self.app), vid=self._cube_viewer_default_label, name=self._cube_viewer_default_label) @@ -194,6 +215,7 @@ def _on_viewer_removed(self, msg): self._check_if_cube_viewer_exists() def _on_add_data(self, msg): + self._check_if_cube_viewer_exists() self._clear_cache() self._initialize_location() if isinstance(msg.viewer, WithSliceSelection): @@ -207,6 +229,9 @@ def _on_select_slice_message(self, msg): self.value = msg.value def _on_global_display_unit_changed(self, msg): + if not self.app.config == 'cubeviz': + return + if msg.axis != self.slice_display_unit_name: return if not self.value_unit: diff --git a/jdaviz/configs/cubeviz/plugins/spectral_extraction/spectral_extraction.py b/jdaviz/configs/cubeviz/plugins/spectral_extraction/spectral_extraction.py index f59ba643a5..3762b1649b 100644 --- a/jdaviz/configs/cubeviz/plugins/spectral_extraction/spectral_extraction.py +++ b/jdaviz/configs/cubeviz/plugins/spectral_extraction/spectral_extraction.py @@ -37,7 +37,7 @@ class SpectralExtraction(PluginTemplateMixin, ApertureSubsetSelectMixin, DatasetSelectMixin, AddResultsMixin): """ - See the :ref:`Spectral Extraction Plugin Documentation ` for more details. + See the :ref:`Spectral Extraction Plugin Documentation ` for more details. Only the following attributes and methods are available through the :ref:`public plugin API `: @@ -52,7 +52,7 @@ class SpectralExtraction(PluginTemplateMixin, ApertureSubsetSelectMixin, to intersect ``aperture`` at ``reference_spectral_value``. * ``reference_spectral_value``: The wavelength that will be used to calculate the radius of the cone through the cube. - * ``background`` (:class:`~jdaviz.comre.template_mixin.ApertureSubsetSelect`): + * ``background`` (:class:`~jdaviz.core.template_mixin.ApertureSubsetSelect`): Subset to use for background subtraction, or ``None``. 
* ``bg_wavelength_dependent``: Whether the ``background`` aperture should be considered wavelength-dependent (requires @@ -67,7 +67,7 @@ class SpectralExtraction(PluginTemplateMixin, ApertureSubsetSelectMixin, * ``aperture_method`` (:class:`~jdaviz.core.template_mixin.SelectPluginComponent`): Method to use for extracting spectrum (and background, if applicable). * ``add_results`` (:class:`~jdaviz.core.template_mixin.AddResults`) - * :meth:`collapse` + * :meth:`extract` """ template_file = __file__, "spectral_extraction.vue" uses_active_status = Bool(True).tag(sync=True) diff --git a/jdaviz/configs/cubeviz/plugins/tools.py b/jdaviz/configs/cubeviz/plugins/tools.py index 0974b03299..5914151dcc 100644 --- a/jdaviz/configs/cubeviz/plugins/tools.py +++ b/jdaviz/configs/cubeviz/plugins/tools.py @@ -3,14 +3,13 @@ from glue.config import viewer_tool from glue_jupyter.bqplot.image import BqplotImageView -from glue_jupyter.bqplot.profile import BqplotProfileView from glue.viewers.common.tool import CheckableTool import numpy as np from specutils import Spectrum1D from jdaviz.core.events import SliceToolStateMessage, SliceSelectSliceMessage -from jdaviz.core.tools import PanZoom, BoxZoom, SinglePixelRegion, _MatchedZoomMixin -from jdaviz.core.marks import PluginLine +from jdaviz.core.tools import PanZoom, BoxZoom, _MatchedZoomMixin +from jdaviz.configs.default.plugins.tools import ProfileFromCube __all__ = [] @@ -81,52 +80,17 @@ def on_mouse_event(self, data): @viewer_tool -class SpectrumPerSpaxel(SinglePixelRegion): +class SpectrumPerSpaxel(ProfileFromCube): icon = os.path.join(ICON_DIR, 'pixelspectra.svg') tool_id = 'jdaviz:spectrumperspaxel' action_text = 'See spectrum at a single spaxel' tool_tip = 'Click on the viewer and see the spectrum at that spaxel in the spectrum viewer' - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._spectrum_viewer = None - self._previous_bounds = None - self._mark = None - self._data = None - - def 
_reset_spectrum_viewer_bounds(self): - sv_state = self._spectrum_viewer.state - sv_state.x_min = self._previous_bounds[0] - sv_state.x_max = self._previous_bounds[1] - sv_state.y_min = self._previous_bounds[2] - sv_state.y_max = self._previous_bounds[3] - - def activate(self): - self.viewer.add_event_callback(self.on_mouse_move, events=['mousemove', 'mouseleave']) - if self._spectrum_viewer is None: - # Get first profile viewer - for _, viewer in self.viewer.jdaviz_helper.app._viewer_store.items(): - if isinstance(viewer, BqplotProfileView): - self._spectrum_viewer = viewer - break - if self._mark is None: - self._mark = PluginLine(self._spectrum_viewer, visible=False) - self._spectrum_viewer.figure.marks = self._spectrum_viewer.figure.marks + [self._mark,] - # Store these so we can revert to previous user-set zoom after preview view - sv_state = self._spectrum_viewer.state - self._previous_bounds = [sv_state.x_min, sv_state.x_max, sv_state.y_min, sv_state.y_max] - super().activate() - - def deactivate(self): - self.viewer.remove_event_callback(self.on_mouse_move) - self._reset_spectrum_viewer_bounds() - super().deactivate() - def on_mouse_move(self, data): if data['event'] == 'mouseleave': self._mark.visible = False - self._reset_spectrum_viewer_bounds() + self._reset_profile_viewer_bounds() return x = int(np.round(data['domain']['x'])) @@ -157,13 +121,13 @@ def on_mouse_move(self, data): else: spectrum = cube_data.get_object(statistic=None) # Note: change this when Spectrum1D.with_spectral_axis is fixed. 
- x_unit = self._spectrum_viewer.state.x_display_unit + x_unit = self._profile_viewer.state.x_display_unit if spectrum.spectral_axis.unit != x_unit: new_spectral_axis = spectrum.spectral_axis.to(x_unit) spectrum = Spectrum1D(spectrum.flux, new_spectral_axis) if x >= spectrum.flux.shape[0] or x < 0 or y >= spectrum.flux.shape[1] or y < 0: - self._reset_spectrum_viewer_bounds() + self._reset_profile_viewer_bounds() self._mark.visible = False else: y_values = spectrum.flux[x, y, :] @@ -172,5 +136,5 @@ def on_mouse_move(self, data): return self._mark.update_xy(spectrum.spectral_axis.value, y_values) self._mark.visible = True - self._spectrum_viewer.state.y_max = np.nanmax(y_values.value) * 1.2 - self._spectrum_viewer.state.y_min = np.nanmin(y_values.value) * 0.8 + self._profile_viewer.state.y_max = np.nanmax(y_values.value) * 1.2 + self._profile_viewer.state.y_min = np.nanmin(y_values.value) * 0.8 diff --git a/jdaviz/configs/default/plugins/tools.py b/jdaviz/configs/default/plugins/tools.py new file mode 100644 index 0000000000..342dbbc298 --- /dev/null +++ b/jdaviz/configs/default/plugins/tools.py @@ -0,0 +1,47 @@ +from glue_jupyter.bqplot.profile import BqplotProfileView +from jdaviz.core.tools import SinglePixelRegion +from jdaviz.core.marks import PluginLine + + +__all__ = ['ProfileFromCube'] + + +class ProfileFromCube(SinglePixelRegion): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._profile_viewer = None + self._previous_bounds = None + self._mark = None + self._data = None + + def _reset_profile_viewer_bounds(self): + pv_state = self._profile_viewer.state + pv_state.x_min = self._previous_bounds[0] + pv_state.x_max = self._previous_bounds[1] + pv_state.y_min = self._previous_bounds[2] + pv_state.y_max = self._previous_bounds[3] + + def activate(self): + self.viewer.add_event_callback(self.on_mouse_move, events=['mousemove', 'mouseleave']) + if self._profile_viewer is None: + # Get first profile viewer + for _, viewer in 
self.viewer.jdaviz_helper.app._viewer_store.items(): + if isinstance(viewer, BqplotProfileView): + self._profile_viewer = viewer + break + if self._mark is None: + self._mark = PluginLine(self._profile_viewer, visible=False) + self._profile_viewer.figure.marks = self._profile_viewer.figure.marks + [self._mark, ] + # Store these so we can revert to previous user-set zoom after preview view + pv_state = self._profile_viewer.state + self._previous_bounds = [pv_state.x_min, pv_state.x_max, pv_state.y_min, pv_state.y_max] + super().activate() + + def deactivate(self): + self.viewer.remove_event_callback(self.on_mouse_move) + self._reset_profile_viewer_bounds() + super().deactivate() + + def on_mouse_move(self, data): + raise NotImplementedError("must be implemented by subclasses") diff --git a/jdaviz/configs/default/plugins/viewers.py b/jdaviz/configs/default/plugins/viewers.py index 4f11fd511f..f99c039d3e 100644 --- a/jdaviz/configs/default/plugins/viewers.py +++ b/jdaviz/configs/default/plugins/viewers.py @@ -1,20 +1,47 @@ -import numpy as np from echo import delay_callback +import numpy as np + +from glue.config import data_translator +from glue.core import BaseData +from glue.core.exceptions import IncompatibleAttribute +from glue.core.units import UnitConverter +from glue.core.subset import Subset +from glue.core.subset_group import GroupedSubset from glue.viewers.scatter.state import ScatterLayerState as BqplotScatterLayerState + +from glue_astronomy.spectral_coordinates import SpectralCoordinates from glue_jupyter.bqplot.profile import BqplotProfileView from glue_jupyter.bqplot.image import BqplotImageView from glue_jupyter.table import TableViewer +from astropy import units as u +from astropy.nddata import ( + NDDataArray, StdDevUncertainty, VarianceUncertainty, InverseVariance +) +from specutils import Spectrum1D + from jdaviz.components.toolbar_nested import NestedJupyterToolbar from jdaviz.core.astrowidgets_api import AstrowidgetsImageViewerMixin +from 
jdaviz.core.events import SnackbarMessage +from jdaviz.core.freezable_state import FreezableProfileViewerState +from jdaviz.core.marks import LineUncertainties, ScatterMask, OffscreenLinesMarks from jdaviz.core.registries import viewer_registry from jdaviz.core.template_mixin import WithCache from jdaviz.core.user_api import ViewerUserApi from jdaviz.utils import (ColorCycler, get_subset_type, _wcs_only_label, layer_is_image_data, layer_is_not_dq) -__all__ = ['JdavizViewerMixin'] +uc = UnitConverter() + +uncertainty_str_to_cls_mapping = { + "std": StdDevUncertainty, + "var": VarianceUncertainty, + "ivar": InverseVariance +} + + +__all__ = ['JdavizViewerMixin', 'JdavizProfileView'] viewer_registry.add("g-profile-viewer", label="Profile 1D", cls=BqplotProfileView) viewer_registry.add("g-image-viewer", label="Image 2D", cls=BqplotImageView) @@ -164,7 +191,7 @@ def _apply_layer_defaults(self, layer_state): layer_state.add_callback('as_steps', self._show_uncertainty_changed) def _expected_subset_layer_default(self, layer_state): - if self.__class__.__name__ == 'CubevizImageView': + if self.__class__.__name__ in ('CubevizImageView', 'RampvizImageView'): # Do not override default for subsets as for some reason # this isn't getting called when they're first added, but rather when # the next state change is made (for example: manually changing the visibility) @@ -214,8 +241,11 @@ def _get_layer_info(layer): ) if layer.visible and not layer_is_wcs_only: prefix_icon, subset_type = _get_layer_info(layer) - if self.__class__.__name__ == 'CubevizProfileView' and subset_type == 'spatial': - # do not show spatial subsets in spectral-viewer + if ( + subset_type == 'spatial' and + self.__class__.__name__ in ('CubevizProfileView', 'RampvizProfileView') + ): + # do not show spatial subsets in profile viewer continue visible_layers[layer.layer.label] = {'color': _get_layer_color(layer), 'linewidth': _get_layer_linewidth(layer), @@ -362,3 +392,413 @@ def _ref_or_id(self): def 
set_plot_axes(self): # individual viewers can override to set custom axes labels/ticks/styling return + + +@viewer_registry("jdaviz-profile-viewer", label="Profile 1D") +class JdavizProfileView(JdavizViewerMixin, BqplotProfileView): + # categories: zoom resets, zoom, pan, subset, select tools, shortcuts + tools_nested = [ + ['jdaviz:homezoom', 'jdaviz:prevzoom'], + ['jdaviz:boxzoom', 'jdaviz:xrangezoom', 'jdaviz:yrangezoom'], + ['jdaviz:panzoom', 'jdaviz:panzoom_x', 'jdaviz:panzoom_y'], + ['bqplot:xrange'], + ['jdaviz:sidebar_plot', 'jdaviz:sidebar_export'] + ] + + default_class = NDDataArray + _state_cls = FreezableProfileViewerState + _default_profile_subset_type = None + + def __init__(self, *args, **kwargs): + default_tool_priority = kwargs.pop('default_tool_priority', []) + super().__init__(*args, **kwargs) + + self._subscribe_to_layers_update() + self.initialize_toolbar(default_tool_priority=default_tool_priority) + self._offscreen_lines_marks = OffscreenLinesMarks(self) + self.figure.marks = self.figure.marks + self._offscreen_lines_marks.marks + + self.state.add_callback('show_uncertainty', self._show_uncertainty_changed) + + self.display_mask = False + + # Change collapse function to sum + default_collapse_function = kwargs.pop('default_collapse_function', 'sum') + + self.state.function = default_collapse_function + + def _expected_subset_layer_default(self, layer_state): + super()._expected_subset_layer_default(layer_state) + + layer_state.linewidth = 3 + + def data(self, cls=None): + # Grab the user's chosen statistic for collapsing data + statistic = getattr(self.state, 'function', None) + data = [] + + for layer_state in self.state.layers: + if hasattr(layer_state, 'layer'): + lyr = layer_state.layer + + # For raw data, just include the data itself + if isinstance(lyr, BaseData): + _class = cls or self.default_class + + if _class is not None: + cache_key = (lyr.label, statistic) + if cache_key in self.jdaviz_app._get_object_cache: + layer_data = 
self.jdaviz_app._get_object_cache[cache_key] + else: + # If spectrum, collapse via the defined statistic + if _class == Spectrum1D: + layer_data = lyr.get_object(cls=_class, statistic=statistic) + else: + layer_data = lyr.get_object(cls=_class) + self.jdaviz_app._get_object_cache[cache_key] = layer_data + + data.append(layer_data) + + # For subsets, make sure to apply the subset mask to the layer data first + elif isinstance(lyr, Subset): + layer_data = lyr + + if _class is not None: + handler, _ = data_translator.get_handler_for(_class) + try: + layer_data = handler.to_object(layer_data, statistic=statistic) + except IncompatibleAttribute: + continue + data.append(layer_data) + + return data + + def get_scales(self): + fig = self.figure + # Deselect any pan/zoom or subsetting tools so they don't interfere + # with the scale retrieval + if self.toolbar.active_tool is not None: + self.toolbar.active_tool = None + return {'x': fig.interaction.x_scale, 'y': fig.interaction.y_scale} + + def _show_uncertainty_changed(self, msg=None): + # this is subscribed in init to watch for changes to the state + # object since uncertainty handling is in jdaviz instead of glue/glue-jupyter + if self.state.show_uncertainty: + self._plot_uncertainties() + else: + self._clean_error() + + def show_mask(self): + self.display_mask = True + self._plot_mask() + + def clean(self): + # Remove extra traces, in case they exist. + self.display_mask = False + self._clean_mask() + + # this will automatically call _clean_error via _show_uncertainty_changed + self.state.show_uncertainty = False + + def _clean_mask(self): + fig = self.figure + fig.marks = [x for x in fig.marks if not isinstance(x, ScatterMask)] + + def _clean_error(self): + fig = self.figure + fig.marks = [x for x in fig.marks if not isinstance(x, LineUncertainties)] + + def add_data(self, data, color=None, alpha=None, **layer_state): + """ + Overrides the base class to add markers for plotting + uncertainties and data quality flags. 
+ + Parameters + ---------- + spectrum : :class:`glue.core.data.Data` + Data object with the spectrum. + color : obj + Color value for plotting. + alpha : float + Alpha value for plotting. + + Returns + ------- + result : bool + `True` if successful, `False` otherwise. + """ + # If this is the first loaded data, set things up for unit conversion. + if len(self.layers) == 0: + reset_plot_axes = True + else: + # Check if the new data flux unit is actually compatible since flux not linked. + try: + if self.state.y_display_unit not in ['None', None, 'DN']: + uc.to_unit(data, data.find_component_id("flux"), [1, 1], + u.Unit(self.state.y_display_unit)) # Error if incompatible + except Exception as err: + # Raising exception here introduces a dirty state that messes up next load_data + # but not raising exception also causes weird behavior unless we remove the data + # completely. + self.session.hub.broadcast(SnackbarMessage( + f"Failed to load {data.label}, so removed it: {repr(err)}", + sender=self, color='error')) + self.jdaviz_app.data_collection.remove(data) + return False + reset_plot_axes = False + + # The base class handles the plotting of the main + # trace representing the profile itself. 
+ result = super().add_data(data, color, alpha, **layer_state) + + if reset_plot_axes: + x_units = data.get_component(self.state.x_att.label).units + + y_axis_component = ( + 'flux' if 'flux' in [comp.label for comp in self.state.layers[0].layer.components] + else 'data' + ) + y_units = data.get_component(y_axis_component).units + with delay_callback(self.state, "x_display_unit", "y_display_unit"): + self.state.x_display_unit = x_units if len(x_units) else None + self.state.y_display_unit = y_units if len(y_units) else None + self.set_plot_axes() + + self._plot_uncertainties() + + self._plot_mask() + + # Set default linewidth on any created spectral subset layers + # NOTE: this logic will need updating if we add support for multiple cubes as this assumes + # that new data entries (from model fitting or gaussian smooth, etc) will only be spectra + # and all subsets affected will be spectral + for layer in self.state.layers: + if (isinstance(layer.layer, GroupedSubset) + and get_subset_type(layer.layer) == self._default_profile_subset_type + and layer.layer.data.label == data.label): + layer.linewidth = 3 + + return result + + def _plot_mask(self): + if not self.display_mask: + return + + # Remove existing mask marks + self._clean_mask() + + # Loop through all active data in the viewer + for index, layer_state in enumerate(self.state.layers): + lyr = layer_state.layer + comps = [str(component) for component in lyr.components] + + # Skip subsets + if hasattr(lyr, "subset_state"): + continue + + # Ignore data that does not have a mask component + if "mask" in comps: + mask = np.array(lyr['mask'].data) + + data_obj = lyr.data.get_object(cls=self.default_class) + + if self.default_class == Spectrum1D: + data_x = data_obj.spectral_axis.value + data_y = data_obj.flux.value + else: + data_x = np.arange(data_obj.shape[-1]) + data_y = data_obj.data.value + + # For plotting markers only for the masked data + # points, erase un-masked data from trace. 
+ y = np.where(np.asarray(mask) == 0, np.nan, data_y) + + # A subclass of the bqplot Scatter object, ScatterMask places + # 'X' marks where there is masked data in the viewer. + color = layer_state.color + alpha_shade = layer_state.alpha / 3 + mask_line_mark = ScatterMask(scales=self.scales, + marker='cross', + x=data_x, + y=y, + stroke_width=0.5, + colors=[color], + default_size=25, + default_opacities=[alpha_shade] + ) + # Add mask marks to viewer + self.figure.marks = list(self.figure.marks) + [mask_line_mark] + + def _plot_uncertainties(self): + if not self.state.show_uncertainty: + return + + # Remove existing error bars + self._clean_error() + + # Loop through all active data in the viewer + for index, layer_state in enumerate(self.state.layers): + lyr = layer_state.layer + + # Skip subsets + if hasattr(lyr, "subset_state"): + continue + + comps = [str(component) for component in lyr.components] + + # Ignore data that does not have an uncertainty component + if "uncertainty" in comps: # noqa + error = np.array(lyr['uncertainty'].data) + + # ensure that the uncertainties are represented as stddev: + uncertainty_type_str = lyr.meta.get('uncertainty_type', 'stddev') + uncert_cls = uncertainty_str_to_cls_mapping[uncertainty_type_str] + error = uncert_cls(error).represent_as(StdDevUncertainty).array + + # Then we assume that last axis is always wavelength. 
+ # This may need adjustment after the following + # specutils PR is merged: https://github.com/astropy/specutils/pull/1033 + spectral_axis = -1 + data_obj = lyr.data.get_object(cls=self.default_class, statistic=None) + + if isinstance(lyr.data.coords, SpectralCoordinates): + spectral_wcs = lyr.data.coords + data_x = spectral_wcs.pixel_to_world_values( + np.arange(lyr.data.shape[spectral_axis]) + ) + if isinstance(data_x, tuple): + data_x = data_x[0] + else: + if hasattr(lyr.data.coords, 'spectral_wcs'): + spectral_wcs = lyr.data.coords.spectral_wcs + elif hasattr(lyr.data.coords, 'spectral'): + spectral_wcs = lyr.data.coords.spectral + data_x = spectral_wcs.pixel_to_world( + np.arange(lyr.data.shape[spectral_axis]) + ) + + data_y = data_obj.data + + # The shaded band around the spectrum trace is bounded by + # two lines, above and below the spectrum trace itself. + data_x_list = np.ndarray.tolist(data_x) + x = [data_x_list, data_x_list] + y = [np.ndarray.tolist(data_y - error), + np.ndarray.tolist(data_y + error)] + + if layer_state.as_steps: + for i in (0, 1): + a = np.insert(x[i], 0, 2*x[i][0] - x[i][1]) + b = np.append(x[i], 2*x[i][-1] - x[i][-2]) + edges = (a + b) / 2 + x[i] = np.concatenate((edges[:1], np.repeat(edges[1:-1], 2), edges[-1:])) + y[i] = np.repeat(y[i], 2) + x, y = np.asarray(x), np.asarray(y) + + # A subclass of the bqplot Lines object, LineUncertainties keeps + # track of uncertainties plotted in the viewer. LineUncertainties + # appear with two lines and shaded area in between. 
+ color = layer_state.color + alpha_shade = layer_state.alpha / 3 + error_line_mark = LineUncertainties(viewer=self, + x=[x], + y=[y], + scales=self.scales, + stroke_width=1, + colors=[color, color], + fill_colors=[color, color], + opacities=[0.0, 0.0], + fill_opacities=[alpha_shade, + alpha_shade], + fill='between', + close_path=False + ) + + # Add error lines to viewer + self.figure.marks = list(self.figure.marks) + [error_line_mark] + + def set_plot_axes(self): + # Set y axes labels for the spectrum viewer + y_display_unit = self.state.y_display_unit + y_unit = ( + u.Unit(y_display_unit) if y_display_unit and y_display_unit != 'None' + else u.dimensionless_unscaled + ) + + # Get local units. + locally_defined_flux_units = [ + u.Jy, u.mJy, u.uJy, u.MJy, + u.W / (u.m**2 * u.Hz), + u.eV / (u.s * u.m**2 * u.Hz), + u.erg / (u.s * u.cm**2), + u.erg / (u.s * u.cm**2 * u.Angstrom), + u.erg / (u.s * u.cm**2 * u.Hz), + u.ph / (u.s * u.cm**2 * u.Angstrom), + u.ph / (u.s * u.cm**2 * u.Hz), + u.bol, u.AB, u.ST + ] + + locally_defined_sb_units = [ + unit / u.sr for unit in locally_defined_flux_units + ] + + if any(y_unit.is_equivalent(unit) for unit in locally_defined_sb_units): + flux_unit_type = "Surface Brightness" + elif any(y_unit.is_equivalent(unit) for unit in locally_defined_flux_units): + flux_unit_type = 'Flux' + elif ( + y_unit.is_equivalent(u.DN) or + y_unit.is_equivalent(u.electron / u.s) or + y_unit.physical_type == 'dimensionless' + ): + # electron / s or 'dimensionless_unscaled' should be labeled counts + flux_unit_type = "Counts" + elif y_unit.is_equivalent(u.W): + flux_unit_type = "Luminosity" + else: + # default to Flux Density for flux density or uncaught types + flux_unit_type = "Flux density" + + # Set x axes labels for the spectrum viewer + x_disp_unit = self.state.x_display_unit + x_unit = u.Unit(x_disp_unit) if x_disp_unit else u.dimensionless_unscaled + + if x_unit.is_equivalent(u.m): + spectral_axis_unit_type = "Wavelength" + elif 
x_unit.is_equivalent(u.Hz):
+            spectral_axis_unit_type = "Frequency"
+        elif x_unit.is_equivalent(u.pixel):
+            spectral_axis_unit_type = "Pixel"
+        elif x_unit.is_equivalent(u.dimensionless_unscaled):
+            # case for rampviz
+            spectral_axis_unit_type = "Group"
+        else:
+            spectral_axis_unit_type = str(x_unit.physical_type).title()
+
+        with self.figure.hold_sync():
+            self.figure.axes[0].label = f"{spectral_axis_unit_type}" + (
+                f" [{self.state.x_display_unit}]"
+                if self.state.x_display_unit not in ["None", None] else ""
+            )
+            self.figure.axes[1].label = f"{flux_unit_type}" + (
+                f" [{self.state.y_display_unit}]"
+                if self.state.y_display_unit not in ["None", None] else ""
+            )
+
+            # Make it so axis labels are not covering tick numbers.
+            self.figure.fig_margin["left"] = 95
+            self.figure.fig_margin["bottom"] = 60
+            self.figure.send_state('fig_margin')  # Force update
+            self.figure.axes[0].label_offset = "40"
+            self.figure.axes[1].label_offset = "-70"
+            # NOTE: with tick_style changed below, the default responsive ticks in bqplot result
+            # in overlapping tick labels. For now we'll hardcode at 8, but this could be removed
+            # (default to None) if/when bqplot auto ticks react to styling options.
+ self.figure.axes[1].num_ticks = 8 + + # Set Y-axis to scientific notation + self.figure.axes[1].tick_format = '0.1e' + + for i in (0, 1): + self.figure.axes[i].tick_style = {'font-size': 15, 'font-weight': 600} diff --git a/jdaviz/configs/imviz/plugins/coords_info/coords_info.py b/jdaviz/configs/imviz/plugins/coords_info/coords_info.py index 1e0b82983b..aa2c99c714 100644 --- a/jdaviz/configs/imviz/plugins/coords_info/coords_info.py +++ b/jdaviz/configs/imviz/plugins/coords_info/coords_info.py @@ -11,6 +11,7 @@ from jdaviz.configs.imviz.plugins.viewers import ImvizImageView from jdaviz.configs.mosviz.plugins.viewers import (MosvizImageView, MosvizProfileView, MosvizProfile2DView) +from jdaviz.configs.rampviz.plugins.viewers import RampvizImageView, RampvizProfileView from jdaviz.configs.specviz.plugins.viewers import SpecvizProfileView from jdaviz.core.events import ViewerAddedMessage, GlobalDisplayUnitChanged from jdaviz.core.helpers import data_has_valid_wcs @@ -29,10 +30,12 @@ class CoordsInfo(TemplateMixin, DatasetSelectMixin): _supported_viewer_classes = (SpecvizProfileView, ImvizImageView, CubevizImageView, + RampvizImageView, + RampvizProfileView, MosvizImageView, MosvizProfile2DView) - _viewer_classes_with_marker = (SpecvizProfileView, MosvizProfile2DView) + _viewer_classes_with_marker = (RampvizProfileView, SpecvizProfileView, MosvizProfile2DView) dataset_icon = Unicode("").tag( sync=True @@ -240,10 +243,13 @@ def vue_next_layer(self, *args, **kwargs): def update_display(self, viewer, x, y): self._dict = {} - if isinstance(viewer, SpecvizProfileView): + if isinstance(viewer, (SpecvizProfileView, RampvizProfileView)): self._spectrum_viewer_update(viewer, x, y) elif isinstance(viewer, - (ImvizImageView, CubevizImageView, MosvizImageView, MosvizProfile2DView)): + (ImvizImageView, CubevizImageView, + MosvizImageView, MosvizProfile2DView, + RampvizImageView) + ): self._image_viewer_update(viewer, x, y) def _image_shape_inds(self, image): @@ -376,6 +382,9 @@ 
def _image_viewer_update(self, viewer, x, y): self._dict['spectral_axis'] = slice_plugin.value self._dict['spectral_axis:unit'] = slice_plugin._obj.value_unit + elif isinstance(viewer, RampvizImageView): + coords_status = False + elif isinstance(viewer, MosvizImageView): if data_has_valid_wcs(image, ndim=2): @@ -476,7 +485,7 @@ def _image_viewer_update(self, viewer, x, y): dq_data = associated_dq_layer.layer.get_data(dq_attribute) dq_value = dq_data[int(round(y)), int(round(x))] unit = image.get_component(attribute).units - elif isinstance(viewer, CubevizImageView): + elif isinstance(viewer, (CubevizImageView, RampvizImageView)): arr = image.get_component(attribute).data unit = image.get_component(attribute).units value = self._get_cube_value( diff --git a/jdaviz/configs/imviz/plugins/parsers.py b/jdaviz/configs/imviz/plugins/parsers.py index 2992b46963..7267341f73 100644 --- a/jdaviz/configs/imviz/plugins/parsers.py +++ b/jdaviz/configs/imviz/plugins/parsers.py @@ -14,7 +14,10 @@ from jdaviz.core.registries import data_parser_registry from jdaviz.core.events import SnackbarMessage -from jdaviz.utils import standardize_metadata, PRIHDR_KEY, _wcs_only_label, download_uri_to_path +from jdaviz.utils import ( + standardize_metadata, standardize_roman_metadata, + PRIHDR_KEY, _wcs_only_label, download_uri_to_path +) try: from roman_datamodels import datamodels as rdd @@ -400,8 +403,8 @@ def _roman_2d_to_glue_data(file_obj, data_label, ext=None): else: ext_list = (ext, ) - meta = getattr(file_obj, 'meta', {}) - coords = getattr(meta, 'wcs', None) + meta = standardize_roman_metadata(file_obj) + coords = getattr(getattr(file_obj, 'meta', {}), 'wcs', None) for cur_ext in ext_list: comp_label = cur_ext.upper() @@ -413,7 +416,7 @@ def _roman_2d_to_glue_data(file_obj, data_label, ext=None): bunit = getattr(ext_values, 'unit', '') component = Component.autotyped(np.array(ext_values), units=bunit) data.add_component(component=component, label=comp_label) - 
data.meta.update(standardize_metadata(dict(meta)))
+        data.meta.update(meta)

         if comp_label == 'dq':
             prep_data_layer_as_dq(data)
diff --git a/jdaviz/configs/rampviz/__init__.py b/jdaviz/configs/rampviz/__init__.py
new file mode 100644
index 0000000000..af257ac8cc
--- /dev/null
+++ b/jdaviz/configs/rampviz/__init__.py
@@ -0,0 +1,2 @@
+from .plugins import *  # noqa
+from .helper import Rampviz  # noqa
diff --git a/jdaviz/configs/rampviz/helper.py b/jdaviz/configs/rampviz/helper.py
new file mode 100644
index 0000000000..103b1e87c5
--- /dev/null
+++ b/jdaviz/configs/rampviz/helper.py
@@ -0,0 +1,99 @@
+from jdaviz.core.events import SliceSelectSliceMessage
+from jdaviz.core.helpers import CubeConfigHelper
+from jdaviz.configs.rampviz.plugins.viewers import RampvizImageView
+
+__all__ = ['Rampviz']
+
+_temporal_axis_names = ['group', 'groups']
+
+
+class Rampviz(CubeConfigHelper):
+    """Rampviz Helper class"""
+    _default_configuration = 'rampviz'
+    _default_group_viewer_reference_name = "group-viewer"
+    _default_diff_viewer_reference_name = "diff-viewer"
+    _default_integration_viewer_reference_name = "integration-viewer"
+    _cube_viewer_default_label = _default_group_viewer_reference_name
+
+    _loaded_flux_cube = None
+    _cube_viewer_cls = RampvizImageView
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self.cube_cache = {}
+
+    def load_data(self, data, data_label=None, **kwargs):
+        """
+        Load and parse a ramp cube with Rampviz.
+        (Note that only one cube may be loaded per Rampviz instance.)
+
+        Parameters
+        ----------
+        data : str, `~roman_datamodels.datamodels.DataModel`, `~astropy.nddata.NDDataArray` or ndarray
+            A string file path, Roman DataModel object pointing to the
+            data cube, an NDDataArray, or a Numpy array.
+            If plain array is given, axes order must be ``(x, y, z)``.
+        data_label : str or `None`
+            Data label to go with the given data. If not given,
+            one will be automatically generated.
+
+        **kwargs : dict
+            Extra keywords accepted by Jdaviz application-level parser.
+        """  # noqa
+        if data_label:
+            kwargs['data_label'] = data_label
+
+        super().load_data(data, parser_reference="ramp-data-parser", **kwargs)
+        self._set_x_axis()
+
+    def _set_x_axis(self, msg={}):
+        group_viewer = self.app.get_viewer(self._default_group_viewer_reference_name)
+        ref_data = group_viewer.state.reference_data
+        group_viewer.state.x_att = ref_data.id["Pixel Axis 0 [z]"]
+        group_viewer.state.y_att = ref_data.id["Pixel Axis 1 [y]"]
+
+    def select_group(self, group_index):
+        """
+        Select the group closest to the provided group index.
+
+        Parameters
+        ----------
+        group_index : int or float
+            Group index to select, in units of the x-axis of the integration.
+            The nearest group will be selected if "snap to slice" is enabled
+            in the slice plugin.
+        """
+        if not isinstance(group_index, (int, float)):
+            raise TypeError("group_index must be convertible to an integer")
+        if group_index < 0:
+            raise ValueError("group_index must be positive")
+
+        msg = SliceSelectSliceMessage(value=int(group_index), sender=self)
+        self.app.hub.broadcast(msg)
+
+    def get_data(self, data_label=None, spatial_subset=None,
+                 temporal_subset=None, cls=None, use_display_units=False):
+        """
+        Returns data with name equal to ``data_label`` of type ``cls`` with subsets applied from
+        ``temporal_subset``, if applicable.
+
+        Parameters
+        ----------
+        data_label : str, optional
+            Provide a label to retrieve a specific data set from data_collection.
+        spatial_subset : str, optional
+            Spatial subset applied to data. Only applicable if ``data_label`` points to a cube or
+            image. To extract a ramp profile from a cube, use the ramp extraction plugin instead.
+        temporal_subset : str, optional
+        cls : `~specutils.Spectrum1D`, `~astropy.nddata.CCDData`, optional
+            The type that data will be returned as.
+
+        Returns
+        -------
+        data : cls
+            Data is returned as type cls with subsets applied.
+ + """ + return self._get_data(data_label=data_label, spatial_subset=spatial_subset, + temporal_subset=temporal_subset, + cls=cls, use_display_units=use_display_units) diff --git a/jdaviz/configs/rampviz/plugins/__init__.py b/jdaviz/configs/rampviz/plugins/__init__.py new file mode 100644 index 0000000000..e00c3cbac2 --- /dev/null +++ b/jdaviz/configs/rampviz/plugins/__init__.py @@ -0,0 +1,4 @@ +from .viewers import * # noqa +from .parsers import * # noqa +from .ramp_extraction import * # noqa +from .tools import * # noqa diff --git a/jdaviz/configs/rampviz/plugins/parsers.py b/jdaviz/configs/rampviz/plugins/parsers.py new file mode 100644 index 0000000000..ee5e9b933d --- /dev/null +++ b/jdaviz/configs/rampviz/plugins/parsers.py @@ -0,0 +1,304 @@ +import logging +import os +import numpy as np +import astropy.units as u +from astropy.io import fits +from astropy.nddata import NDData, NDDataArray +from astropy.time import Time + +from jdaviz.core.registries import data_parser_registry +from jdaviz.configs.cubeviz.plugins.parsers import _get_data_type_by_hdu +from jdaviz.utils import ( + standardize_metadata, download_uri_to_path, + PRIHDR_KEY, standardize_roman_metadata +) + +try: + from roman_datamodels import datamodels as rdd +except ImportError: + HAS_ROMAN_DATAMODELS = False +else: + HAS_ROMAN_DATAMODELS = True + +__all__ = ['parse_data'] + + +@data_parser_registry("ramp-data-parser") +def parse_data(app, file_obj, data_type=None, data_label=None, + parent=None, cache=None, local_path=None, timeout=None): + """ + Attempts to parse a data file and auto-populate available viewers in + rampviz. + + Parameters + ---------- + app : `~jdaviz.app.Application` + The application-level object used to reference the viewers. + file_obj : str + The path to a cube-like data file. + data_type : str, {'flux', 'mask', 'uncert'} + The data type used to explicitly differentiate parsed data. + data_label : str, optional + The label to be applied to the Glue data component. 
+ parent : str, optional + Data label for "parent" data to associate with the loaded data as "child". + cache : None, bool, or str + Cache the downloaded file if the data are retrieved by a query + to a URL or URI. + local_path : str, optional + Cache remote files to this path. This is only used if data is + requested from `astroquery.mast`. + timeout : float, optional + If downloading from a remote URI, set the timeout limit for + remote requests in seconds (passed to + `~astropy.utils.data.download_file` or + `~astroquery.mast.Conf.timeout`). + """ + + group_viewer_reference_name = app._jdaviz_helper._default_group_viewer_reference_name + diff_viewer_reference_name = app._jdaviz_helper._default_diff_viewer_reference_name + integration_viewer_reference_name = ( + app._jdaviz_helper._default_integration_viewer_reference_name + ) + + if data_type is not None and data_type.lower() not in ('flux', 'mask', 'uncert'): + raise TypeError("Data type must be one of 'flux', 'mask', or 'uncert' " + f"but got '{data_type}'") + + # If the file object is an hdulist or a string, use the generic parser for + # fits files. + # TODO: this currently only supports fits files. We will want to make this + # generic enough to work with other file types (e.g. ASDF). For now, this + # supports MaNGA and JWST data. + if isinstance(file_obj, fits.hdu.hdulist.HDUList): + _parse_hdulist( + app, file_obj, file_name=data_label, + group_viewer_reference_name=group_viewer_reference_name, + diff_viewer_reference_name=diff_viewer_reference_name, + ) + elif isinstance(file_obj, str): + if file_obj.lower().endswith('.asdf'): + if not HAS_ROMAN_DATAMODELS: + raise ImportError( + "ASDF detected but roman-datamodels is not installed." 
+ ) + with rdd.open(file_obj) as pf: + _roman_3d_to_glue_data( + app, pf, data_label, + group_viewer_reference_name=group_viewer_reference_name, + diff_viewer_reference_name=diff_viewer_reference_name, + meta=dict(pf.meta) + ) + return + + # try parsing file_obj as a URI/URL: + file_obj = download_uri_to_path( + file_obj, cache=cache, local_path=local_path, timeout=timeout + ) + + file_name = os.path.basename(file_obj) + + with fits.open(file_obj) as hdulist: + _parse_hdulist( + app, hdulist, file_name=data_label or file_name, + group_viewer_reference_name=group_viewer_reference_name, + diff_viewer_reference_name=diff_viewer_reference_name, + ) + + elif isinstance(file_obj, np.ndarray) and file_obj.ndim == 3: + # load 3D cube to group viewer + _parse_ndarray( + app, file_obj, data_label=data_label, data_type=data_type, + viewer_reference_name=group_viewer_reference_name, + meta=getattr(file_obj, 'meta') + ) + + elif isinstance(file_obj, (np.ndarray, NDData)) and file_obj.ndim in (1, 2): + # load 1D profile(s) to integration_viewer + _parse_ndarray( + app, file_obj, data_label=data_label, + viewer_reference_name=integration_viewer_reference_name, + meta=getattr(file_obj, 'meta') + ) + + elif HAS_ROMAN_DATAMODELS and isinstance(file_obj, rdd.DataModel): + with rdd.open(file_obj) as pf: + _roman_3d_to_glue_data( + app, pf, data_label, meta=pf.meta, + group_viewer_reference_name=group_viewer_reference_name, + diff_viewer_reference_name=diff_viewer_reference_name, + ) + + else: + raise NotImplementedError(f'Unsupported data format: {file_obj}') + + +def _roman_3d_to_glue_data( + app, file_obj, data_label, + group_viewer_reference_name=None, + diff_viewer_reference_name=None, + meta=None +): + """ + Parse a Roman 3D ramp cube file (Level 1), + usually with suffix '_uncal.asdf'. 
+ """ + def _swap_axes(x): + # swap axes per the conventions of Roman cubes + # (group axis comes first) and the default in + # Cubeviz (wavelength axis expected last) + return np.swapaxes(x, 0, -1) + + # update viewer reference names for Roman ramp cubes: + # app._update_viewer_reference_name() + + data = file_obj.data + + if data_label is None: + data_label = app.return_data_label(file_obj) + + # last axis is the group axis, first two are spatial axes: + diff_data = np.vstack([ + # begin with a group of zeros, so + # that `diff_data.ndim == data.ndim` + np.zeros((1, *data[0].shape)), + np.diff(data, axis=0) + ]) + + ramp_cube_data_label = f"{data_label}[DATA]" + ramp_diff_data_label = f"{data_label}[DIFF]" + + # load these cubes into the cache: + app._jdaviz_helper.cube_cache[ramp_cube_data_label] = NDDataArray(_swap_axes(data)) + app._jdaviz_helper.cube_cache[ramp_diff_data_label] = NDDataArray(_swap_axes(diff_data)) + + if meta is not None: + meta = standardize_roman_metadata(file_obj) + + # load these cubes into the app: + _parse_ndarray( + app, + file_obj=_swap_axes(data), + data_label=ramp_cube_data_label, + viewer_reference_name=group_viewer_reference_name, + meta=meta + ) + _parse_ndarray( + app, + file_obj=_swap_axes(diff_data), + data_label=ramp_diff_data_label, + viewer_reference_name=diff_viewer_reference_name, + meta=meta + ) + + # the default collapse function in the profile viewer is "sum", + # but for ramp files, "median" is more useful: + viewer = app.get_viewer('integration-viewer') + viewer.state.function = 'median' + + +def _parse_hdulist( + app, hdulist, file_name=None, + viewer_reference_name=None +): + if file_name is None and hasattr(hdulist, 'file_name'): + file_name = hdulist.file_name + else: + file_name = file_name or "Unknown HDU object" + + is_loaded = [] + + # TODO: This needs refactoring to be more robust. + # Current logic fails if there are multiple EXTVER. 
+ for hdu in hdulist: + if hdu.data is None or not hdu.is_image or hdu.data.ndim != 3: + continue + + data_type = _get_data_type_by_hdu(hdu) + if not data_type: + continue + + # Only load each type once. + if data_type in is_loaded: + continue + + is_loaded.append(data_type) + data_label = app.return_data_label(file_name, hdu.name) + + if 'BUNIT' in hdu.header: + try: + flux_unit = u.Unit(hdu.header['BUNIT']) + except Exception: + logging.warning("Invalid BUNIT, using DN as data unit") + flux_unit = u.DN + else: + logging.warning("Invalid BUNIT, using DN as data unit") + flux_unit = u.DN + + flux = hdu.data << flux_unit + metadata = standardize_metadata(hdu.header) + if hdu.name != 'PRIMARY' and 'PRIMARY' in hdulist: + metadata[PRIHDR_KEY] = standardize_metadata(hdulist['PRIMARY'].header) + + app.add_data(flux, data_label) + app.data_collection[data_label].get_component("data").units = flux_unit + app.add_data_to_viewer(viewer_reference_name, data_label) + app._jdaviz_helper._loaded_flux_cube = app.data_collection[data_label] + + +def _parse_jwst_level1( + app, hdulist, data_label, ext='SCI', + viewer_name=None, +): + hdu = hdulist[ext] + data_type = _get_data_type_by_hdu(hdu) + + # Manually inject MJD-OBS until we can support GWCS, see + # https://github.com/spacetelescope/jdaviz/issues/690 and + # https://github.com/glue-viz/glue-astronomy/issues/59 + if ext == 'SCI' and 'MJD-OBS' not in hdu.header: + for key in ('MJD-BEG', 'DATE-OBS'): # Possible alternatives + if key in hdu.header: + if key.startswith('MJD'): + hdu.header['MJD-OBS'] = hdu.header[key] + break + else: + t = Time(hdu.header[key]) + hdu.header['MJD-OBS'] = t.mjd + break + + unit = u.Unit(hdu.header.get('BUNIT', 'count')) + flux = hdu.data << unit + + metadata = standardize_metadata(hdu.header) + app.data_collection[data_label] = NDData(data=flux, meta=metadata) + + if data_type == 'flux': + app.data_collection[-1].get_component("data").units = flux.unit + + if viewer_name is not None: + 
app.add_data_to_viewer(viewer_name, data_label) + + if data_type == 'flux': + app._jdaviz_helper._loaded_flux_cube = app.data_collection[data_label] + + +def _parse_ndarray( + app, file_obj, data_label=None, + viewer_reference_name=None, + meta=None +): + if data_label is None: + data_label = app.return_data_label(file_obj) + + # Cannot change axis to ensure roundtripping within Rampviz. + # Axes must already be (x, y, z) at this point. + + if isinstance(file_obj, NDData): + ndd = file_obj + else: + ndd = NDDataArray(data=file_obj, meta=meta) + app.add_data(ndd, data_label) + + app.add_data_to_viewer(viewer_reference_name, data_label) + app._jdaviz_helper._loaded_flux_cube = app.data_collection[data_label] diff --git a/jdaviz/configs/rampviz/plugins/ramp_extraction/__init__.py b/jdaviz/configs/rampviz/plugins/ramp_extraction/__init__.py new file mode 100644 index 0000000000..ad40821f85 --- /dev/null +++ b/jdaviz/configs/rampviz/plugins/ramp_extraction/__init__.py @@ -0,0 +1 @@ +from .ramp_extraction import * # noqa \ No newline at end of file diff --git a/jdaviz/configs/rampviz/plugins/ramp_extraction/ramp_extraction.py b/jdaviz/configs/rampviz/plugins/ramp_extraction/ramp_extraction.py new file mode 100644 index 0000000000..817040e96f --- /dev/null +++ b/jdaviz/configs/rampviz/plugins/ramp_extraction/ramp_extraction.py @@ -0,0 +1,470 @@ +import numpy as np +import astropy.units as u +from astropy.nddata import NDDataArray + +from functools import cached_property +from traitlets import Bool, Float, List, Unicode, observe, Int +from glue.core.message import ( + DataCollectionAddMessage, SubsetCreateMessage, SubsetDeleteMessage, SubsetUpdateMessage +) + +from jdaviz.core.events import SnackbarMessage, SliceValueUpdatedMessage +from jdaviz.core.marks import PluginLine +from jdaviz.core.registries import tray_registry +from jdaviz.core.template_mixin import (PluginTemplateMixin, + DatasetSelectMixin, + SelectPluginComponent, + ApertureSubsetSelectMixin, + 
ApertureSubsetSelect, + AddResultsMixin, + skip_if_not_tray_instance, + skip_if_no_updates_since_last_active, + with_spinner, with_temp_disable) +from jdaviz.core.user_api import PluginUserApi +from jdaviz.configs.cubeviz.plugins.viewers import WithSliceIndicator + + +__all__ = ['RampExtraction'] + +rng = np.random.default_rng(seed=42) + + +@tray_registry( + 'ramp-extraction', label="Ramp Extraction", viewer_requirements='profile' +) +class RampExtraction(PluginTemplateMixin, ApertureSubsetSelectMixin, + DatasetSelectMixin, AddResultsMixin): + """ + See the :ref:`Ramp Extraction Plugin Documentation ` for more details. + + Only the following attributes and methods are available through the + :ref:`public plugin API `: + + * :meth:`~jdaviz.core.template_mixin.PluginTemplateMixin.show` + * :meth:`~jdaviz.core.template_mixin.PluginTemplateMixin.open_in_tray` + * :meth:`~jdaviz.core.template_mixin.PluginTemplateMixin.close_in_tray` + * ``aperture`` (:class:`~jdaviz.core.template_mixin.ApertureSubsetSelect`): + Subset to use for the ramp extraction, or ``Entire Cube``. 
+ * ``aperture_method`` (:class:`~jdaviz.core.template_mixin.SelectPluginComponent`): + Method to use for extracting a ramp profile + * ``add_results`` (:class:`~jdaviz.core.template_mixin.AddResults`) + * :meth:`extract` + """ + template_file = __file__, "ramp_extraction.vue" + uses_active_status = Bool(True).tag(sync=True) + show_live_preview = Bool(False).tag(sync=True) + show_subset_preview = Bool(True).tag(sync=True) + subset_preview_warning = Bool(False).tag(sync=True) + subset_preview_limit = Int(250).tag(sync=True) + + active_step = Unicode().tag(sync=True) + + resulting_product_name = Unicode("ramp").tag(sync=True) + do_auto_extraction = True + + slice_group_value = Float().tag(sync=True) + + function_items = List().tag(sync=True) + function_selected = Unicode('Mean').tag(sync=True) + filename = Unicode().tag(sync=True) + extraction_available = Bool(False).tag(sync=True) + overwrite_warn = Bool(False).tag(sync=True) + + aperture_method_items = List().tag(sync=True) + aperture_method_selected = Unicode('Center').tag(sync=True) + + conflicting_aperture_and_function = Bool(False).tag(sync=True) + conflicting_aperture_error_message = Unicode('Aperture method Exact cannot be selected along' + ' with Min or Max.').tag(sync=True) + + # export_enabled controls whether saving to a file is enabled via the UI. 
This + # is a temporary measure to allow server-installations to disable saving server-side until + # saving client-side is supported + export_enabled = Bool(False).tag(sync=True) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.dataset.filters = ['is_flux_cube'] + + # TODO: in the future this could be generalized with support in SelectPluginComponent + self.aperture._default_text = 'Entire Cube' + self.aperture._manual_options = ['Entire Cube'] + self.aperture.items = [{"label": "Entire Cube"}] + self.aperture._subset_selected_changed_callback = self._update_extract + # need to reinitialize choices since we overwrote items and some subsets may already + # exist. + self.aperture._initialize_choices() + self.aperture.select_default() + + self.extracted_ramp = None + + self.function = SelectPluginComponent( + self, + items='function_items', + selected='function_selected', + manual_options=['Mean', 'Median', 'Min', 'Max', 'Sum'] + ) + self._set_default_results_label() + self.add_results.viewer.filters = ['is_slice_indicator_viewer'] + + self.session.hub.subscribe(self, SliceValueUpdatedMessage, + handler=self._on_slice_changed) + + self.session.hub.subscribe(self, DataCollectionAddMessage, + handler=self._on_data_added) + + self.session.hub.subscribe(self, SubsetCreateMessage, + handler=self._on_subset_update) + + self.session.hub.subscribe(self, SubsetUpdateMessage, + handler=self._on_subset_update) + + self.session.hub.subscribe(self, SubsetDeleteMessage, + handler=self._on_subset_delete) + + self._update_disabled_msg() + + if self.app.state.settings.get('server_is_remote', False): + # when the server is remote, saving the file in python would save on the server, not + # on the user's machine, so export support in cubeviz should be disabled + self.export_enabled = False + + @property + def integration_viewer(self): + viewer = self.app.get_viewer( + self.app._jdaviz_helper._default_integration_viewer_reference_name + ) + return 
viewer + + def _on_data_added(self, msg={}): + # only perform the default collapse after the first data load: + if len(self.app.data_collection) == 2: + self.extract(add_data=True) + self.integration_viewer._initialize_x_axis() + + def _on_subset_update(self, msg={}): + if not hasattr(self.app._jdaviz_helper, 'cube_cache'): + # if called before fully initialized + return + + subset_lbl = msg.subset.label + color = msg.subset.style.color + subset = self.app.get_subsets(subset_lbl)[0] + region = subset['region'] + + if region is None: + return + + # glue region has transposed coords relative to cached cube: + region_mask = region.to_mask().to_image(self.cube.shape[:-1]).astype(bool).T + cube_subset = self.cube[region_mask] # shape: (N pixels extracted, M groups) + + n_pixels_in_extraction = cube_subset.shape[0] + + if n_pixels_in_extraction < self.subset_preview_limit: + self.subset_preview_warning = False + select_from_cube_subset = Ellipsis + else: + self.subset_preview_warning = True + select_from_cube_subset = rng.integers( + 0, n_pixels_in_extraction, size=self.subset_preview_limit + ) + + marks = [ + PluginLine( + self.integration_viewer, x=np.arange(cube_subset.shape[1]), y=y, + stroke_width=1, colors=[color], opacities=[0.25], label=subset_lbl, + visible=self._subset_preview_visible and subset_lbl == self.aperture.selected + ) + for y in cube_subset[select_from_cube_subset] + ] + + self.integration_viewer.figure.marks = [ + mark for mark in self.integration_viewer.figure.marks + if getattr(mark, 'label', None) != subset_lbl + ] + marks + + def _on_subset_delete(self, msg={}): + subset_lbl = msg.subset.label + self.integration_viewer.figure.marks = [ + mark for mark in self.integration_viewer.figure.marks + if getattr(mark, 'label', None) != subset_lbl + ] + + @observe('is_active', 'show_subset_preview', 'aperture_selected') + def _update_subset_previews(self, msg={}): + # remove preview marks for non-selected subsets + + if not 
hasattr(self.app._jdaviz_helper, '_default_integration_viewer_reference_name'): + return + + redraw_limits = False + for mark in self.integration_viewer.figure.marks: + if isinstance(mark, PluginLine) and mark.label is not None: + new_visibility = ( + self._subset_preview_visible and + self.aperture.selected == mark.label + ) + if mark.visible != new_visibility: + mark.visible = new_visibility + redraw_limits = True + + if redraw_limits: + self.integration_viewer.reset_limits() + + @property + def _subset_preview_visible(self): + return self.show_subset_preview and self.is_active + + @property + def user_api(self): + expose = [ + 'dataset', 'function', 'aperture', 'add_results', 'extract' + ] + + return PluginUserApi(self, expose=expose) + + @observe('dataset_items') + def _update_disabled_msg(self, msg={}): + for data in self.app.data_collection: + if data.data.ndim == 3: + self.disabled_msg = '' + break + else: + # no cube-like data loaded. Once loaded, the parser will unset this + self.disabled_msg = ( + f"{self.__class__.__name__} requires a 3d cube dataset to be loaded, " + "please load data to enable this plugin." + ) + + @property + def live_update_subscriptions(self): + return {'data': ('dataset',), 'subset': ('aperture', )} + + def __call__(self, add_data=True): + return self.extract(add_data=add_data) + + @property + def slice_display_unit_name(self): + return 'temporal' + + @property + def spatial_axes(self): + # Collapse an e.g. 
3D ramp cube to 1D ramp profile, assuming that last axis + # is always the group/resultant index + return (0, 1) + + @property + def slice_indicator_viewers(self): + return [v for v in self.app._viewer_store.values() if isinstance(v, WithSliceIndicator)] + + @observe('active_step', 'is_active') + def _active_step_changed(self, *args): + self.aperture._set_mark_visiblities(self.active_step in ('', 'ap', 'extract')) + + @property + def slice_plugin(self): + return self.app._jdaviz_helper.plugins['Slice'] + + def _extract_in_new_instance(self, dataset=None, function='Mean', subset_lbl=None, + auto_update=False, add_data=False): + # create a new instance of the Ramp Extraction plugin (to not affect the instance in + # the tray) and extract the entire cube with defaults. + plg = self.new() + plg.dataset.selected = self.dataset.selected + if subset_lbl is not None: + plg.aperture.selected = subset_lbl + plg.function.selected = function + plg.add_results.auto_update_result = auto_update + # all other settings remain at their plugin defaults + return plg(add_data=add_data) + + def _on_slice_changed(self, msg): + self.slice_group_value = msg.value + + @observe('function_selected', 'aperture_method_selected') + def _update_aperture_method_on_function_change(self, *args): + if (self.function_selected.lower() in ('min', 'max') and + self.aperture_method_selected.lower() != 'center'): + self.conflicting_aperture_and_function = True + else: + self.conflicting_aperture_and_function = False + + @property + def cube(self): + return self.app._jdaviz_helper.cube_cache[self.dataset.selected] + + @property + def slice_display_unit(self): + x_display_unit = self.app._get_display_unit(self.slice_display_unit_name) + if x_display_unit not in ['None', None]: + return u.Unit(x_display_unit) + return u.dimensionless_unscaled + + @property + def aperture_weight_mask(self): + cube_shape = self.cube.shape + if self.aperture.selected != self.aperture.default_text: + + # note: glue subset mask is 
transposed relative to cube + region_mask = self.app.get_subsets( + subset_name=self.aperture.selected + )[0]['region'].to_mask().to_image( + cube_shape[:-1] + ).astype(bool).T + return region_mask + + return np.ones(cube_shape[:-1]).astype(bool) + + def _extract_from_aperture(self, **kwargs): + # This plugin collapses over the *spatial axes* (optionally over a spatial subset, + # defaults to ``No Subset``). Since the Cubeviz parser puts the fluxes + # and uncertainties in different glue Data objects, we translate the ramp + # cube and its uncertainties into separate NDDataArrays, then combine them: + selected_func = self.function_selected.lower() + + if not isinstance(self.aperture, ApertureSubsetSelect): + raise ValueError("aperture must be an ApertureSubsetSelect object") + + nddata = self.cube + mask = self.aperture_weight_mask + + if nddata.mask is not None: + mask = mask & ~nddata.mask + + collapsed = getattr(np, selected_func)( + nddata.data[mask], + # after the fancy indexing above, axis=1 corresponds to groups, and + # operations over axis=0 corresponds to individual pixels: + axis=0 + ) << nddata.unit + + def expand(x): + # put the resulting 1D profile (counts vs. groups) into the + # third dimension, which is the group dimension in the + # original 3D cube: + return np.expand_dims(x, axis=(0, 1)) + + return NDDataArray( + data=expand(collapsed), + meta=nddata.meta + ) + + def _preview_x_from_extracted(self, extracted): + return np.arange(extracted.shape[-1]) + + def _preview_y_from_extracted(self, extracted): + return extracted.data + + @with_spinner() + def extract(self, add_data=True, **kwargs): + """ + Extract the ramp profile from the data cube according to the plugin inputs. + + Parameters + ---------- + add_data : bool, optional + Whether to load the resulting data back into the application according to + ``add_results``. + kwargs : dict + Additional keyword arguments passed to the NDDataArray collapse operation. 
+ Examples include ``propagate_uncertainties`` and ``operation_ignores_mask``. + """ + if self.conflicting_aperture_and_function: + raise ValueError(self.conflicting_aperture_error_message) + + selected_func = self.function_selected.lower() + ndd = self._extract_from_aperture(**kwargs) + self.extracted_ramp = ndd + self.extraction_available = True + fname_label = self.dataset_selected.replace("[", "_").replace("]", "") + self.filename = f"extracted_{selected_func}_{fname_label}.fits" + + if add_data: + if default_color := self.aperture.selected_item.get('color', None): + ndd.meta['_default_color'] = default_color + self.add_results.add_results_from_plugin(ndd) + + snackbar_message = SnackbarMessage( + f"{self.resulting_product_name.title()} extracted successfully.", + color="success", + sender=self) + self.hub.broadcast(snackbar_message) + + return ndd + + def vue_ramp_extraction(self, *args, **kwargs): + try: + self.extract(add_data=True) + except Exception as e: + self.hub.broadcast(SnackbarMessage( + f"Extraction failed: {repr(e)}", + sender=self, color="error")) + + @observe('aperture_selected', 'function_selected') + def _set_default_results_label(self, event={}): + if not hasattr(self, 'aperture'): + return + if self.aperture.selected == self.aperture.default_text: + self.results_label_default = (f"{self.resulting_product_name.title()} " + f"({self.function_selected.lower()})") + + else: + self.results_label_default = (f"{self.resulting_product_name.title()} " + f"({self.aperture_selected}, " + f"{self.function_selected.lower()})") + + @cached_property + def marks(self): + if not self._tray_instance: + return {} + # TODO: iterate over self.slice_indicator_viewers and handle adding/removing viewers + + sv = self.slice_indicator_viewers[0] + marks = {'extract': PluginLine(sv, visible=self.is_active)} + sv.figure.marks = sv.figure.marks + [marks['extract'],] + return marks + + def _clear_marks(self): + for mark in self.marks.values(): + if getattr(mark, 
'visible', False): + mark.visible = False + + @observe('is_active', 'show_live_preview', + 'dataset_selected', 'aperture_selected', + 'function_selected', + 'aperture_method_selected', + 'previews_temp_disabled') + def _live_update_marks(self, event={}): + self._update_marks(event) + + @skip_if_not_tray_instance() + def _update_marks(self, event={}): + visible = self.show_live_preview and self.is_active + + if not visible: + self._clear_marks() + return + + # ensure the correct visibility, always (whether or not there have been updates) + if hasattr(self.marks['extract'], 'visible'): + self.marks['extract'].visible = True + + # _live_update will skip if no updates since last active + self._live_update_extract(event) + + @skip_if_no_updates_since_last_active() + @with_temp_disable(timeout=0.4) + def _live_update_extract(self, event={}): + self._update_extract() + + @skip_if_not_tray_instance() + def _update_extract(self): + try: + ext = self.extract(add_data=False) + except (ValueError, Exception): + self._clear_marks() + return False + + self.marks['extract'].update_xy(self._preview_x_from_extracted(ext), + self._preview_y_from_extracted(ext)) diff --git a/jdaviz/configs/rampviz/plugins/ramp_extraction/ramp_extraction.vue b/jdaviz/configs/rampviz/plugins/ramp_extraction/ramp_extraction.vue new file mode 100644 index 0000000000..da8194b206 --- /dev/null +++ b/jdaviz/configs/rampviz/plugins/ramp_extraction/ramp_extraction.vue @@ -0,0 +1,154 @@ + + \ No newline at end of file diff --git a/jdaviz/configs/rampviz/plugins/tools.py b/jdaviz/configs/rampviz/plugins/tools.py new file mode 100644 index 0000000000..a50fe7ffbe --- /dev/null +++ b/jdaviz/configs/rampviz/plugins/tools.py @@ -0,0 +1,43 @@ +import os +import numpy as np +from glue.config import viewer_tool +from jdaviz.configs.default.plugins.tools import ProfileFromCube + +__all__ = ['RampPerPixel'] + +ICON_DIR = os.path.join(os.path.dirname(__file__), '..', '..', '..', 'data', 'icons') + + +@viewer_tool +class 
RampPerPixel(ProfileFromCube): + + # TODO: replace "pixelspectra" graphic with a "pixelramp" equivalent + icon = os.path.join(ICON_DIR, 'pixelspectra.svg') + tool_id = 'jdaviz:rampperpixel' + action_text = 'See ramp at a single pixel' + tool_tip = ( + 'Click on the viewer and see the ramp profile ' + 'at that pixel in the integration viewer' + ) + + def on_mouse_move(self, data): + if data['event'] == 'mouseleave': + self._mark.visible = False + self._reset_profile_viewer_bounds() + return + + x = int(np.round(data['domain']['x'])) + y = int(np.round(data['domain']['y'])) + + cube_cache = self.viewer.jdaviz_app._jdaviz_helper.cube_cache + spectrum = cube_cache[list(cube_cache.keys())[0]].data + + if x >= spectrum.shape[0] or x < 0 or y >= spectrum.shape[1] or y < 0: + self._mark.visible = False + else: + y_values = spectrum[x, y, :] + if np.all(np.isnan(y_values)): + self._mark.visible = False + return + self._mark.update_xy(np.arange(y_values.size), y_values) + self._mark.visible = True diff --git a/jdaviz/configs/rampviz/plugins/viewers.py b/jdaviz/configs/rampviz/plugins/viewers.py new file mode 100644 index 0000000000..63b5a59b7f --- /dev/null +++ b/jdaviz/configs/rampviz/plugins/viewers.py @@ -0,0 +1,140 @@ +import numpy as np +from astropy.nddata import NDDataArray +from glue.core import BaseData +from glue_jupyter.bqplot.image import BqplotImageView + +from jdaviz.configs.default.plugins.viewers import JdavizViewerMixin, JdavizProfileView +from jdaviz.configs.cubeviz.plugins.mixins import WithSliceSelection, WithSliceIndicator +from jdaviz.core.registries import viewer_registry +from jdaviz.core.freezable_state import FreezableBqplotImageViewerState + +__all__ = ['RampvizProfileView', 'RampvizImageView'] + + +@viewer_registry("rampviz-profile-viewer", label="Profile 1D (Rampviz)") +class RampvizProfileView(JdavizProfileView, WithSliceIndicator): + # categories: zoom resets, zoom, pan, subset, select tools, shortcuts + tools_nested = [ + ['jdaviz:homezoom', 
'jdaviz:prevzoom'], + ['jdaviz:boxzoom', 'jdaviz:xrangezoom', 'jdaviz:yrangezoom'], + ['jdaviz:panzoom', 'jdaviz:panzoom_x', 'jdaviz:panzoom_y'], + ['jdaviz:selectslice'], + ['jdaviz:sidebar_plot', 'jdaviz:sidebar_export'] + ] + + default_class = NDDataArray + _default_profile_subset_type = 'temporal' + + def __init__(self, *args, **kwargs): + kwargs.setdefault('default_tool_priority', ['jdaviz:selectslice']) + super().__init__(*args, **kwargs) + + def _initialize_x_axis(self): + if len(self.state.x_att_helper.choices): + self.state.x_att = self.state.x_att_helper.choices[-1] + self.set_plot_axes() + self.reset_limits() + + def reset_limits(self): + super().reset_limits() + + # override to reset to the global y limits including marks: + global_y_min = float(self.state.y_min) + global_y_max = float(self.state.y_max) + for mark in self.figure.marks: + if len(mark.y) and mark.visible: + global_y_min = min(global_y_min, np.nanmin(mark.y)) + global_y_max = max(global_y_max, np.nanmax(mark.y)) + + if global_y_min != self.state.y_min or global_y_max != self.state.y_max: + self.set_limits( + y_min=global_y_min * 0.9, + y_max=global_y_max * 1.1 + ) + + def set_plot_axes(self): + + with self.figure.hold_sync(): + self.figure.axes[0].label = "Group" + self.figure.axes[1].label = self.state.y_display_unit + + # Make it so axis labels are not covering tick numbers. + self.figure.fig_margin["left"] = 95 + self.figure.fig_margin["bottom"] = 60 + self.figure.send_state('fig_margin') # Force update + self.figure.axes[0].label_offset = "40" + self.figure.axes[1].label_offset = "-70" + # NOTE: with tick_style changed below, the default responsive ticks in bqplot result + # in overlapping tick labels. For now we'll hardcode at 8, but this could be removed + # (default to None) if/when bqplot auto ticks react to styling options. 
+ self.figure.axes[1].num_ticks = 8 + + # Set Y-axis to scientific notation + self.figure.axes[1].tick_format = '0.1e' + + for i in (0, 1): + self.figure.axes[i].tick_style = {'font-size': 15, 'font-weight': 600} + + +@viewer_registry("rampviz-image-viewer", label="Image 2D (Rampviz)") +class RampvizImageView(JdavizViewerMixin, WithSliceSelection, BqplotImageView): + # categories: zoom resets, (zoom, pan), subset, select tools, shortcuts + # NOTE: zoom and pan are merged here for space consideration and to avoid + # overflow to second row when opening the tray + tools_nested = [ + ['jdaviz:homezoom', 'jdaviz:prevzoom'], + ['jdaviz:pixelboxzoommatch', 'jdaviz:boxzoom'], + ['jdaviz:pixelpanzoommatch', 'jdaviz:panzoom'], + ['bqplot:truecircle', 'bqplot:rectangle', 'bqplot:ellipse', + 'bqplot:circannulus'], + ['jdaviz:rampperpixel'], + ['jdaviz:sidebar_plot', 'jdaviz:sidebar_export'] + ] + + default_class = NDDataArray + _state_cls = FreezableBqplotImageViewerState + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # provide reference from state back to viewer to use for zoom syncing + self.state._viewer = self + + self._subscribe_to_layers_update() + self.state.add_callback('reference_data', self._initial_x_axis) + + # Hide axes by default + self.state.show_axes = False + + @property + def _default_group_viewer_reference_name(self): + return self.jdaviz_helper._default_group_viewer_reference_name + + @property + def _default_diff_viewer_reference_name(self): + return self.jdaviz_helper._default_diff_viewer_reference_name + + @property + def _default_integration_viewer_reference_name(self): + return self.jdaviz_helper._default_integration_viewer_reference_name + + def _initial_x_axis(self, *args): + # Make sure that the x_att is correct on data load + ref_data = self.state.reference_data + if ref_data and ref_data.ndim == 3: + self.state.x_att = ref_data.id["Pixel Axis 0 [z]"] + + def set_plot_axes(self): + self.figure.axes[1].tick_format = 
None + self.figure.axes[0].tick_format = None + + self.figure.axes[1].label = "y: pixels" + self.figure.axes[0].label = "x: pixels" + + # Make it so y axis label is not covering tick numbers. + self.figure.axes[1].label_offset = "-50" + + def data(self, cls=None): + return [layer_state.layer # .get_object(cls=cls or self.default_class) + for layer_state in self.state.layers + if hasattr(layer_state, 'layer') and + isinstance(layer_state.layer, BaseData)] diff --git a/jdaviz/configs/rampviz/rampviz.yaml b/jdaviz/configs/rampviz/rampviz.yaml new file mode 100644 index 0000000000..d961561107 --- /dev/null +++ b/jdaviz/configs/rampviz/rampviz.yaml @@ -0,0 +1,45 @@ +settings: + configuration: rampviz + data: + auto_populate: true + parser: rampviz-data-parser + visible: + menu_bar: false + toolbar: true + tray: true + tab_headers: true + dense_toolbar: false + server_is_remote: false + context: + notebook: + max_height: 750px +toolbar: + - g-data-tools + - g-subset-tools + - g-coords-info +tray: + - g-metadata-viewer + - g-plot-options + - g-data-quality + - g-subset-plugin + - ramp-extraction + - g-markers + - cube-slice + - export + - about +viewer_area: + - container: col + children: + - container: row + viewers: + - name: Flux + plot: rampviz-image-viewer + reference: group-viewer + - name: Difference + plot: rampviz-image-viewer + reference: diff-viewer + - container: row + viewers: + - name: Integration profile + plot: rampviz-profile-viewer + reference: integration-viewer diff --git a/jdaviz/configs/rampviz/tests/__init__.py b/jdaviz/configs/rampviz/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/jdaviz/configs/rampviz/tests/test_helper.py b/jdaviz/configs/rampviz/tests/test_helper.py new file mode 100644 index 0000000000..7478186b8f --- /dev/null +++ b/jdaviz/configs/rampviz/tests/test_helper.py @@ -0,0 +1,19 @@ +import pytest +from jdaviz.configs.imviz.plugins.parsers import HAS_ROMAN_DATAMODELS + + +@pytest.mark.skipif(not 
HAS_ROMAN_DATAMODELS, reason="roman_datamodels is not installed") +def test_load_data(rampviz_helper, roman_level_1_ramp): + rampviz_helper.load_data(roman_level_1_ramp) + + # on ramp cube load (1), the parser loads a diff cube (2) and + # the ramp extraction plugin produces a default extraction (3): + assert len(rampviz_helper.app.data_collection) == 3 + + # each viewer should have one loaded data entry: + for refname in 'group-viewer, diff-viewer, integration-viewer'.split(', '): + viewer = rampviz_helper.app.get_viewer(refname) + assert len(viewer.state.layers) == 1 + + assert viewer.axis_x.label == 'Group' + assert viewer.axis_y.label == 'DN' diff --git a/jdaviz/configs/rampviz/tests/test_ramp_extraction.py b/jdaviz/configs/rampviz/tests/test_ramp_extraction.py new file mode 100644 index 0000000000..5ce04b895d --- /dev/null +++ b/jdaviz/configs/rampviz/tests/test_ramp_extraction.py @@ -0,0 +1,57 @@ +import pytest +from regions import CirclePixelRegion, PixCoord +from jdaviz.core.marks import Lines +from jdaviz.configs.imviz.plugins.parsers import HAS_ROMAN_DATAMODELS + + +@pytest.mark.skipif(not HAS_ROMAN_DATAMODELS, reason="roman_datamodels is not installed") +def test_previews(rampviz_helper, roman_level_1_ramp): + rampviz_helper.load_data(roman_level_1_ramp) + + # add subset: + region = CirclePixelRegion(center=PixCoord(12.5, 15.5), radius=2) + rampviz_helper.load_regions(region) + ramp_extr = rampviz_helper.plugins['Ramp Extraction']._obj + + subsets = rampviz_helper.app.get_subsets() + ramp_cube = rampviz_helper.app.data_collection[0] + n_groups = ramp_cube.shape[-1] + + assert len(subsets) == 1 + assert 'Subset 1' in subsets + + integration_viewer = rampviz_helper.app.get_viewer('integration-viewer') + + # contains a layer for the default ramp extraction and the subset: + assert len(integration_viewer.layers) == 2 + + # profile viewer x-axis is the group dimension + assert str(integration_viewer.state.x_att) == 'Pixel Axis 2 [x]' + + # no subset previews 
should be visible yet: + assert len([ + mark for mark in integration_viewer.native_marks + # should be a subclass of Lines, should be visible, + if mark.visible and isinstance(mark, Lines) and + # and the default profile is a 1D series with length n_groups: + len(mark.x) == n_groups + ]) == 1 + + # check that when the plugin is active, there's one ramp profile generated by the + # plugin per pixel in the subset (if show_subset_preview), + # plus one live preview (if show_live_preview): + for show_live_preview in [True, False]: + for show_subset_preview in [True, False]: + with ramp_extr.as_active(): + ramp_extr.show_live_preview = show_live_preview + ramp_extr.show_subset_preview = show_subset_preview + ramp_extr.aperture_selected = 'Subset 1' + + subset_state = subsets[ramp_extr.aperture_selected][0]['subset_state'] + n_pixels_in_subset = subset_state.to_mask(ramp_cube)[..., 0].sum() + + assert len([ + mark for mark in integration_viewer.custom_marks + if mark.visible and isinstance(mark, Lines) and + len(mark.x) == n_groups + ]) == int(show_subset_preview) * n_pixels_in_subset + int(show_live_preview) diff --git a/jdaviz/configs/rampviz/tests/test_slice.py b/jdaviz/configs/rampviz/tests/test_slice.py new file mode 100644 index 0000000000..7aad59d4fd --- /dev/null +++ b/jdaviz/configs/rampviz/tests/test_slice.py @@ -0,0 +1,103 @@ +import pytest +from jdaviz.configs.cubeviz.plugins.slice.slice import Slice +from jdaviz.configs.imviz.plugins.parsers import HAS_ROMAN_DATAMODELS + + +@pytest.mark.skipif(not HAS_ROMAN_DATAMODELS, reason="roman_datamodels is not installed") +def test_slice(rampviz_helper, roman_level_1_ramp): + app = rampviz_helper.app + sl = Slice(app=app) + + # No data yet + assert len(sl.slice_selection_viewers) == 2 # group-viewer, diff-viewer + assert len(sl.slice_indicator_viewers) == 1 # integration-viewer + assert len(sl.valid_indicator_values_sorted) == 0 + assert len(sl.valid_selection_values_sorted) == 0 + + # Make sure nothing crashes if 
plugin used without data
+    sl.vue_play_next()
+    sl.vue_play_start_stop()
+    assert not sl.is_playing
+
+    rampviz_helper.load_data(roman_level_1_ramp, data_label='test')
+    app.add_data_to_viewer("group-viewer", "test[DATA]")
+    app.add_data_to_viewer("diff-viewer", "test[DATA]")
+    app.add_data_to_viewer("integration-viewer", "Ramp (mean)")
+    sv = rampviz_helper.viewers['integration-viewer']._obj
+
+    # sample ramp only has 10 groups
+    assert len(sv.slice_values) == 10
+    assert len(sl.valid_indicator_values_sorted) == 10
+    slice_values = sl.valid_selection_values_sorted
+    assert len(slice_values) == 10
+
+    assert sl.value == slice_values[len(slice_values) // 2]
+    assert rampviz_helper.app.get_viewer("group-viewer").slice == len(slice_values) // 2
+    assert rampviz_helper.app.get_viewer("group-viewer").state.slices[-1] == 5
+    assert rampviz_helper.app.get_viewer("diff-viewer").state.slices[-1] == 5
+    rampviz_helper.select_group(slice_values[0])
+    assert rampviz_helper.app.get_viewer("group-viewer").slice == 0
+    assert sl.value == slice_values[0]
+
+    rampviz_helper.select_group(slice_values[1])
+    assert sl.value == slice_values[1]
+
+    # Retrieve updated slice_values
+    slice_values = sl.valid_selection_values_sorted
+
+    # Test player buttons API
+
+    sl.vue_goto_first()
+    assert sl.value == slice_values[0]
+
+    sl.vue_goto_last()
+    assert sl.value == slice_values[-1]
+
+    sl.vue_play_next()  # Should automatically wrap to beginning
+    assert sl.value == slice_values[0]
+
+    sl.vue_play_start_stop()  # Start
+    assert sl.is_playing
+    assert sl._player.is_alive()
+    sl.vue_play_next()  # Should be no-op
+    sl.vue_goto_last()  # Should be no-op
+    sl.vue_goto_first()  # Should be no-op
+    sl.vue_play_start_stop()  # Stop
+    assert not sl.is_playing
+    assert not sl._player
+    # NOTE: Hard to check sl.slice here because it is non-deterministic.
+ + +@pytest.mark.skipif(not HAS_ROMAN_DATAMODELS, reason="roman_datamodels is not installed") +def test_indicator_settings(rampviz_helper, roman_level_1_ramp): + rampviz_helper.load_data(roman_level_1_ramp, data_label='test') + app = rampviz_helper.app + app.add_data_to_viewer("group-viewer", "test[DATA]") + app.add_data_to_viewer("integration-viewer", "Ramp (mean)") + sl = rampviz_helper.plugins['Slice']._obj + sv = app.get_viewer('integration-viewer') + indicator = sv.slice_indicator + + assert sl.show_indicator is True + assert indicator._show_if_inactive is True + assert sl.show_value is True + assert indicator.label.visible is True + + sl.show_indicator = False + assert indicator._show_if_inactive is False + + sl.show_value = False + assert indicator.label.visible is False + + +@pytest.mark.skipif(not HAS_ROMAN_DATAMODELS, reason="roman_datamodels is not installed") +def test_init_slice(rampviz_helper, roman_level_1_ramp): + rampviz_helper.load_data(roman_level_1_ramp, data_label='test') + + fv = rampviz_helper.app.get_viewer('group-viewer') + sl = rampviz_helper.plugins['Slice'] + slice_values = sl._obj.valid_selection_values_sorted + + assert sl.value == slice_values[len(slice_values)//2] + assert fv.slice == 5 + assert fv.state.slices == (0, 0, 5) diff --git a/jdaviz/configs/specviz/plugins/viewers.py b/jdaviz/configs/specviz/plugins/viewers.py index a228c2b026..bde414f458 100644 --- a/jdaviz/configs/specviz/plugins/viewers.py +++ b/jdaviz/configs/specviz/plugins/viewers.py @@ -2,41 +2,21 @@ import numpy as np from astropy import table -from astropy import units as u -from astropy.nddata import StdDevUncertainty, VarianceUncertainty, InverseVariance -from echo import delay_callback -from glue.config import data_translator -from glue.core import BaseData -from glue.core.exceptions import IncompatibleAttribute -from glue.core.units import UnitConverter -from glue.core.subset import Subset -from glue.core.subset_group import GroupedSubset -from 
glue_astronomy.spectral_coordinates import SpectralCoordinates -from glue_jupyter.bqplot.profile import BqplotProfileView from matplotlib.colors import cnames from specutils import Spectrum1D -from jdaviz.core.events import SpectralMarksChangedMessage, LineIdentifyMessage, SnackbarMessage +from jdaviz.core.events import SpectralMarksChangedMessage, LineIdentifyMessage from jdaviz.core.registries import viewer_registry -from jdaviz.core.marks import SpectralLine, LineUncertainties, ScatterMask, OffscreenLinesMarks +from jdaviz.core.marks import SpectralLine from jdaviz.core.linelists import load_preset_linelist, get_available_linelists from jdaviz.core.freezable_state import FreezableProfileViewerState -from jdaviz.configs.default.plugins.viewers import JdavizViewerMixin -from jdaviz.utils import get_subset_type +from jdaviz.configs.default.plugins.viewers import JdavizProfileView __all__ = ['SpecvizProfileView'] -uc = UnitConverter() - -uncertainty_str_to_cls_mapping = { - "std": StdDevUncertainty, - "var": VarianceUncertainty, - "ivar": InverseVariance -} - @viewer_registry("specviz-profile-viewer", label="Profile 1D (Specviz)") -class SpecvizProfileView(JdavizViewerMixin, BqplotProfileView): +class SpecvizProfileView(JdavizProfileView): # categories: zoom resets, zoom, pan, subset, select tools, shortcuts tools_nested = [ ['jdaviz:homezoom', 'jdaviz:prevzoom'], @@ -50,68 +30,7 @@ class SpecvizProfileView(JdavizViewerMixin, BqplotProfileView): default_class = Spectrum1D spectral_lines = None _state_cls = FreezableProfileViewerState - - def __init__(self, *args, **kwargs): - default_tool_priority = kwargs.pop('default_tool_priority', []) - super().__init__(*args, **kwargs) - - self._subscribe_to_layers_update() - self.initialize_toolbar(default_tool_priority=default_tool_priority) - self._offscreen_lines_marks = OffscreenLinesMarks(self) - self.figure.marks = self.figure.marks + self._offscreen_lines_marks.marks - - self.state.add_callback('show_uncertainty', 
self._show_uncertainty_changed) - - self.display_mask = False - - # Change collapse function to sum - self.state.function = 'sum' - - def _expected_subset_layer_default(self, layer_state): - super()._expected_subset_layer_default(layer_state) - - layer_state.linewidth = 3 - - def data(self, cls=None): - # Grab the user's chosen statistic for collapsing data - statistic = getattr(self.state, 'function', None) - data = [] - - for layer_state in self.state.layers: - if hasattr(layer_state, 'layer'): - lyr = layer_state.layer - - # For raw data, just include the data itself - if isinstance(lyr, BaseData): - _class = cls or self.default_class - - if _class is not None: - cache_key = (lyr.label, statistic) - if cache_key in self.jdaviz_app._get_object_cache: - layer_data = self.jdaviz_app._get_object_cache[cache_key] - else: - # If spectrum, collapse via the defined statistic - if _class == Spectrum1D: - layer_data = lyr.get_object(cls=_class, statistic=statistic) - else: - layer_data = lyr.get_object(cls=_class) - self.jdaviz_app._get_object_cache[cache_key] = layer_data - - data.append(layer_data) - - # For subsets, make sure to apply the subset mask to the layer data first - elif isinstance(lyr, Subset): - layer_data = lyr - - if _class is not None: - handler, _ = data_translator.get_handler_for(_class) - try: - layer_data = handler.to_object(layer_data, statistic=statistic) - except IncompatibleAttribute: - continue - data.append(layer_data) - - return data + _default_profile_subset_type = 'spectral' @property def redshift(self): @@ -253,14 +172,6 @@ def erase_spectral_lines(self, name=None, name_rest=None, show_none=True): fig.marks = temp_marks self._broadcast_plotted_lines() - def get_scales(self): - fig = self.figure - # Deselect any pan/zoom or subsetting tools so they don't interfere - # with the scale retrieval - if self.toolbar.active_tool is not None: - self.toolbar.active_tool = None - return {'x': fig.interaction.x_scale, 'y': fig.interaction.y_scale} - 
def plot_spectral_line(self, line, global_redshift=None, plot_units=None, **kwargs): if isinstance(line, str): # Try the full index first (for backend calls), otherwise name only @@ -327,297 +238,3 @@ def plot_spectral_lines(self, colors=["blue"], global_redshift=None, **kwargs): def available_linelists(self): return get_available_linelists() - - def _show_uncertainty_changed(self, msg=None): - # this is subscribed in init to watch for changes to the state - # object since uncertainty handling is in jdaviz instead of glue/glue-jupyter - if self.state.show_uncertainty: - self._plot_uncertainties() - else: - self._clean_error() - - def show_mask(self): - self.display_mask = True - self._plot_mask() - - def clean(self): - # Remove extra traces, in case they exist. - self.display_mask = False - self._clean_mask() - - # this will automatically call _clean_error via _show_uncertainty_changed - self.state.show_uncertainty = False - - def _clean_mask(self): - fig = self.figure - fig.marks = [x for x in fig.marks if not isinstance(x, ScatterMask)] - - def _clean_error(self): - fig = self.figure - fig.marks = [x for x in fig.marks if not isinstance(x, LineUncertainties)] - - def add_data(self, data, color=None, alpha=None, **layer_state): - """ - Overrides the base class to add markers for plotting - uncertainties and data quality flags. - - Parameters - ---------- - spectrum : :class:`glue.core.data.Data` - Data object with the spectrum. - color : obj - Color value for plotting. - alpha : float - Alpha value for plotting. - - Returns - ------- - result : bool - `True` if successful, `False` otherwise. - """ - # If this is the first loaded data, set things up for unit conversion. - if len(self.layers) == 0: - reset_plot_axes = True - else: - # Check if the new data flux unit is actually compatible since flux not linked. 
- try: - uc.to_unit(data, data.find_component_id("flux"), [1, 1], - u.Unit(self.state.y_display_unit)) # Error if incompatible - except Exception as err: - # Raising exception here introduces a dirty state that messes up next load_data - # but not raising exception also causes weird behavior unless we remove the data - # completely. - self.session.hub.broadcast(SnackbarMessage( - f"Failed to load {data.label}, so removed it: {repr(err)}", - sender=self, color='error')) - self.jdaviz_app.data_collection.remove(data) - return False - reset_plot_axes = False - - # The base class handles the plotting of the main - # trace representing the spectrum itself. - result = super().add_data(data, color, alpha, **layer_state) - - if reset_plot_axes: - x_units = data.get_component(self.state.x_att.label).units - y_units = data.get_component("flux").units - with delay_callback(self.state, "x_display_unit", "y_display_unit"): - self.state.x_display_unit = x_units if len(x_units) else None - self.state.y_display_unit = y_units if len(y_units) else None - self.set_plot_axes() - - self._plot_uncertainties() - - self._plot_mask() - - # Set default linewidth on any created spectral subset layers - # NOTE: this logic will need updating if we add support for multiple cubes as this assumes - # that new data entries (from model fitting or gaussian smooth, etc) will only be spectra - # and all subsets affected will be spectral - for layer in self.state.layers: - if (isinstance(layer.layer, GroupedSubset) - and get_subset_type(layer.layer) == 'spectral' - and layer.layer.data.label == data.label): - layer.linewidth = 3 - - return result - - def _plot_mask(self): - if not self.display_mask: - return - - # Remove existing mask marks - self._clean_mask() - - # Loop through all active data in the viewer - for index, layer_state in enumerate(self.state.layers): - lyr = layer_state.layer - comps = [str(component) for component in lyr.components] - - # Skip subsets - if hasattr(lyr, 
"subset_state"): - continue - - # Ignore data that does not have a mask component - if "mask" in comps: - mask = np.array(lyr['mask'].data) - - data_obj = lyr.data.get_object() - data_x = data_obj.spectral_axis.value - data_y = data_obj.flux.value - - # For plotting markers only for the masked data - # points, erase un-masked data from trace. - y = np.where(np.asarray(mask) == 0, np.nan, data_y) - - # A subclass of the bqplot Scatter object, ScatterMask places - # 'X' marks where there is masked data in the viewer. - color = layer_state.color - alpha_shade = layer_state.alpha / 3 - mask_line_mark = ScatterMask(scales=self.scales, - marker='cross', - x=data_x, - y=y, - stroke_width=0.5, - colors=[color], - default_size=25, - default_opacities=[alpha_shade] - ) - # Add mask marks to viewer - self.figure.marks = list(self.figure.marks) + [mask_line_mark] - - def _plot_uncertainties(self): - if not self.state.show_uncertainty: - return - - # Remove existing error bars - self._clean_error() - - # Loop through all active data in the viewer - for index, layer_state in enumerate(self.state.layers): - lyr = layer_state.layer - - # Skip subsets - if hasattr(lyr, "subset_state"): - continue - - comps = [str(component) for component in lyr.components] - - # Ignore data that does not have an uncertainty component - if "uncertainty" in comps: # noqa - error = np.array(lyr['uncertainty'].data) - - # ensure that the uncertainties are represented as stddev: - uncertainty_type_str = lyr.meta.get('uncertainty_type', 'stddev') - uncert_cls = uncertainty_str_to_cls_mapping[uncertainty_type_str] - error = uncert_cls(error).represent_as(StdDevUncertainty).array - - # Then we assume that last axis is always wavelength. 
- # This may need adjustment after the following - # specutils PR is merged: https://github.com/astropy/specutils/pull/1033 - spectral_axis = -1 - data_obj = lyr.data.get_object(cls=Spectrum1D, statistic=None) - - if isinstance(lyr.data.coords, SpectralCoordinates): - spectral_wcs = lyr.data.coords - data_x = spectral_wcs.pixel_to_world_values( - np.arange(lyr.data.shape[spectral_axis]) - ) - if isinstance(data_x, tuple): - data_x = data_x[0] - else: - if hasattr(lyr.data.coords, 'spectral_wcs'): - spectral_wcs = lyr.data.coords.spectral_wcs - elif hasattr(lyr.data.coords, 'spectral'): - spectral_wcs = lyr.data.coords.spectral - data_x = spectral_wcs.pixel_to_world( - np.arange(lyr.data.shape[spectral_axis]) - ) - - data_y = data_obj.data - - # The shaded band around the spectrum trace is bounded by - # two lines, above and below the spectrum trace itself. - data_x_list = np.ndarray.tolist(data_x) - x = [data_x_list, data_x_list] - y = [np.ndarray.tolist(data_y - error), - np.ndarray.tolist(data_y + error)] - - if layer_state.as_steps: - for i in (0, 1): - a = np.insert(x[i], 0, 2*x[i][0] - x[i][1]) - b = np.append(x[i], 2*x[i][-1] - x[i][-2]) - edges = (a + b) / 2 - x[i] = np.concatenate((edges[:1], np.repeat(edges[1:-1], 2), edges[-1:])) - y[i] = np.repeat(y[i], 2) - x, y = np.asarray(x), np.asarray(y) - - # A subclass of the bqplot Lines object, LineUncertainties keeps - # track of uncertainties plotted in the viewer. LineUncertainties - # appear with two lines and shaded area in between. 
- color = layer_state.color - alpha_shade = layer_state.alpha / 3 - error_line_mark = LineUncertainties(viewer=self, - x=[x], - y=[y], - scales=self.scales, - stroke_width=1, - colors=[color, color], - fill_colors=[color, color], - opacities=[0.0, 0.0], - fill_opacities=[alpha_shade, - alpha_shade], - fill='between', - close_path=False - ) - - # Add error lines to viewer - self.figure.marks = list(self.figure.marks) + [error_line_mark] - - def set_plot_axes(self): - # Set y axes labels for the spectrum viewer - y_display_unit = self.state.y_display_unit - y_unit = u.Unit(y_display_unit) if y_display_unit else u.dimensionless_unscaled - - # Get local units. - locally_defined_flux_units = [ - u.Jy, u.mJy, u.uJy, u.MJy, - u.W / (u.m**2 * u.Hz), - u.eV / (u.s * u.m**2 * u.Hz), - u.erg / (u.s * u.cm**2), - u.erg / (u.s * u.cm**2 * u.Angstrom), - u.erg / (u.s * u.cm**2 * u.Hz), - u.ph / (u.s * u.cm**2 * u.Angstrom), - u.ph / (u.s * u.cm**2 * u.Hz), - u.bol, u.AB, u.ST - ] - - locally_defined_sb_units = [ - unit / u.sr for unit in locally_defined_flux_units - ] - - if any(y_unit.is_equivalent(unit) for unit in locally_defined_sb_units): - flux_unit_type = "Surface Brightness" - elif any(y_unit.is_equivalent(unit) for unit in locally_defined_flux_units): - flux_unit_type = 'Flux' - elif y_unit.is_equivalent(u.electron / u.s) or y_unit.physical_type == 'dimensionless': - # electron / s or 'dimensionless_unscaled' should be labeled counts - flux_unit_type = "Counts" - elif y_unit.is_equivalent(u.W): - flux_unit_type = "Luminosity" - else: - # default to Flux Density for flux density or uncaught types - flux_unit_type = "Flux density" - - # Set x axes labels for the spectrum viewer - x_disp_unit = self.state.x_display_unit - x_unit = u.Unit(x_disp_unit) if x_disp_unit else u.dimensionless_unscaled - if x_unit.is_equivalent(u.m): - spectral_axis_unit_type = "Wavelength" - elif x_unit.is_equivalent(u.Hz): - spectral_axis_unit_type = "Frequency" - elif 
x_unit.is_equivalent(u.pixel): - spectral_axis_unit_type = "Pixel" - else: - spectral_axis_unit_type = str(x_unit.physical_type).title() - - with self.figure.hold_sync(): - self.figure.axes[0].label = f"{spectral_axis_unit_type} [{self.state.x_display_unit}]" - self.figure.axes[1].label = f"{flux_unit_type} [{self.state.y_display_unit}]" - - # Make it so axis labels are not covering tick numbers. - self.figure.fig_margin["left"] = 95 - self.figure.fig_margin["bottom"] = 60 - self.figure.send_state('fig_margin') # Force update - self.figure.axes[0].label_offset = "40" - self.figure.axes[1].label_offset = "-70" - # NOTE: with tick_style changed below, the default responsive ticks in bqplot result - # in overlapping tick labels. For now we'll hardcode at 8, but this could be removed - # (default to None) if/when bqplot auto ticks react to styling options. - self.figure.axes[1].num_ticks = 8 - - # Set Y-axis to scientific notation - self.figure.axes[1].tick_format = '0.1e' - - for i in (0, 1): - self.figure.axes[i].tick_style = {'font-size': 15, 'font-weight': 600} diff --git a/jdaviz/conftest.py b/jdaviz/conftest.py index 8a62132da2..ae90c0706a 100644 --- a/jdaviz/conftest.py +++ b/jdaviz/conftest.py @@ -14,7 +14,7 @@ from astropy.wcs import WCS from specutils import Spectrum1D, SpectrumCollection, SpectrumList -from jdaviz import __version__, Cubeviz, Imviz, Mosviz, Specviz, Specviz2d +from jdaviz import __version__, Cubeviz, Imviz, Mosviz, Specviz, Specviz2d, Rampviz from jdaviz.configs.imviz.tests.utils import create_wfi_image_model from jdaviz.configs.imviz.plugins.parsers import HAS_ROMAN_DATAMODELS from jdaviz.utils import NUMPY_LT_2_0 @@ -50,6 +50,26 @@ def specviz2d_helper(): return Specviz2d() +@pytest.fixture +def rampviz_helper(): + return Rampviz() + + +@pytest.fixture +def roman_level_1_ramp(): + from roman_datamodels.maker_utils import mk_datamodel + from roman_datamodels.datamodels import RampModel + rng = np.random.default_rng(seed=42) + + shape = (10, 
25, 25) + data_model = mk_datamodel(RampModel, shape=shape, dq=False) + + data_model.data = u.Quantity( + 100 + 3 * np.cumsum(rng.uniform(size=shape), axis=0), u.DN + ) + return data_model + + @pytest.fixture def image_2d_wcs(): return WCS({'CTYPE1': 'RA---TAN', 'CUNIT1': 'deg', 'CDELT1': -0.0002777777778, diff --git a/jdaviz/core/config.py b/jdaviz/core/config.py index f5d883f966..eca26108ed 100644 --- a/jdaviz/core/config.py +++ b/jdaviz/core/config.py @@ -40,6 +40,8 @@ def read_configuration(path=None): path = default_path / "specviz2d" / "specviz2d.yaml" elif path == 'imviz': path = default_path / "imviz" / "imviz.yaml" + elif path == 'rampviz': + path = default_path / "rampviz" / "rampviz.yaml" elif not os.path.isfile(path): raise ValueError("Configuration must be path to a .yaml file.") diff --git a/jdaviz/core/helpers.py b/jdaviz/core/helpers.py index 194a3dfc80..ebee7d0e86 100644 --- a/jdaviz/core/helpers.py +++ b/jdaviz/core/helpers.py @@ -12,25 +12,25 @@ from inspect import isclass import numpy as np -import astropy.units as u -from astropy.nddata import CCDData, StdDevUncertainty -from regions.core.core import Region from glue.core import HubListener from glue.core.edit_subset_mode import NewMode from glue.core.message import SubsetCreateMessage, SubsetDeleteMessage from glue.core.subset import Subset, MaskSubsetState from glue.config import data_translator from ipywidgets.widgets import widget_serialization -from specutils import Spectrum1D, SpectralRegion +import astropy.units as u +from astropy.nddata import NDDataArray, CCDData, StdDevUncertainty +from regions.core.core import Region +from specutils import Spectrum1D, SpectralRegion from jdaviz.app import Application -from jdaviz.core.events import SnackbarMessage, ExitBatchLoadMessage +from jdaviz.core.events import SnackbarMessage, ExitBatchLoadMessage, SliceSelectSliceMessage from jdaviz.core.template_mixin import show_widget from jdaviz.utils import data_has_valid_wcs, flux_conversion, 
spectral_axis_conversion -__all__ = ['ConfigHelper', 'ImageConfigHelper'] +__all__ = ['ConfigHelper', 'ImageConfigHelper', 'CubeConfigHelper'] class ConfigHelper(HubListener): @@ -515,7 +515,7 @@ def _handle_display_units(self, data, use_display_units=True): return data def _get_data(self, data_label=None, spatial_subset=None, spectral_subset=None, - mask_subset=None, cls=None, use_display_units=False): + temporal_subset=None, mask_subset=None, cls=None, use_display_units=False): list_of_valid_subset_names = [x.label for x in self.app.data_collection.subset_groups] for subset in (spatial_subset, spectral_subset, mask_subset): if subset and subset not in list_of_valid_subset_names: @@ -542,6 +542,11 @@ def _get_data(self, data_label=None, spatial_subset=None, spectral_subset=None, # apps which would then need to do their own type checks, if necessary) mask_subset = spectral_subset + if temporal_subset: + if mask_subset is not None: + raise ValueError("cannot use both mask_subset and spectral_subset") + mask_subset = temporal_subset + # End validity checks and start data retrieval data = self.app.data_collection[data_label] @@ -553,7 +558,11 @@ def _get_data(self, data_label=None, spatial_subset=None, spectral_subset=None, elif data.ndim == 2: cls = CCDData elif data.ndim in [1, 3]: - cls = Spectrum1D + if self.app.config == 'rampviz': + cls = NDDataArray + else: + # for cubeviz, specviz, mosviz, this must be a spectrum: + cls = Spectrum1D object_kwargs = {} if cls == Spectrum1D: @@ -982,3 +991,27 @@ def _next_subset_num(label_prefix, subset_groups): max_i = i return max_i + 1 + + +class CubeConfigHelper(ImageConfigHelper): + """Base config helper class for cubes""" + + _loaded_flux_cube = None + _loaded_uncert_cube = None + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def select_slice(self, value): + """ + Select the slice closest to the provided value. 
+ + Parameters + ---------- + value : float or int, optional + Slice value to select in units of the x-axis of the profile viewer. + The nearest slice will be selected if "snap to slice" is enabled in + the slice plugin. + """ + msg = SliceSelectSliceMessage(value=value, sender=self) + self.app.hub.broadcast(msg) diff --git a/jdaviz/core/marks.py b/jdaviz/core/marks.py index eed9a85fce..af21d275b2 100644 --- a/jdaviz/core/marks.py +++ b/jdaviz/core/marks.py @@ -158,8 +158,10 @@ def __init__(self, viewer, x, **kwargs): **kwargs) def _update_reference_data(self, reference_data): - if reference_data is None: + # don't update x units before initialization or in rampviz + if reference_data is None or self.viewer.jdaviz_app.config == 'rampviz': return + self._update_unit(reference_data.get_object(cls=Spectrum1D).spectral_axis.unit) def _update_unit(self, new_unit): @@ -538,6 +540,7 @@ def __init__(self, viewer, x=[], y=[], **kwargs): self.viewer = viewer # color is same blue as import button kwargs.setdefault('colors', [accent_color]) + self.label = kwargs.get('label') super().__init__(x=x, y=y, scales=kwargs.pop('scales', viewer.scales), **kwargs) diff --git a/jdaviz/core/registries.py b/jdaviz/core/registries.py index 170f0fc412..84ddc4370b 100644 --- a/jdaviz/core/registries.py +++ b/jdaviz/core/registries.py @@ -93,7 +93,7 @@ class TrayRegistry(UniqueDictRegistry): """ default_viewer_category = [ - "spectrum", "table", "image", "spectrum-2d", "flux", "uncert" + "spectrum", "table", "image", "spectrum-2d", "flux", "uncert", "profile" ] default_viewer_reqs = { category: { diff --git a/jdaviz/core/tools.py b/jdaviz/core/tools.py index cae2bc86a6..91f8ea0bc5 100644 --- a/jdaviz/core/tools.py +++ b/jdaviz/core/tools.py @@ -164,7 +164,13 @@ class HomeZoom(HomeTool, _BaseZoomHistory): def activate(self): self.save_prev_zoom() - super().activate() + + # typical case: + if not hasattr(self.viewer, 'reset_limits'): + super().activate() + else: + # if the viewer has its own 
reset_limits method, use it: + self.viewer.reset_limits() @viewer_tool diff --git a/jdaviz/tests/test_data_formats.py b/jdaviz/tests/test_data_formats.py index 414be4bcb4..3f32b914f0 100644 --- a/jdaviz/tests/test_data_formats.py +++ b/jdaviz/tests/test_data_formats.py @@ -123,7 +123,7 @@ def test_list_configurations(): """ test correct configurations are listed """ configs = list_configurations() assert set(configs).issubset({'default', 'cubeviz', 'specviz', 'mosviz', - 'imviz', 'specviz2d'}) + 'imviz', 'specviz2d', 'rampviz'}) @pytest.mark.parametrize('name, expconf, expstat', diff --git a/jdaviz/utils.py b/jdaviz/utils.py index ff19929662..389e274bf2 100644 --- a/jdaviz/utils.py +++ b/jdaviz/utils.py @@ -19,6 +19,7 @@ from glue.core import BaseData from glue.core.exceptions import IncompatibleAttribute from glue.core.subset import SubsetState, RangeSubsetState, RoiSubsetState +from glue_astronomy.spectral_coordinates import SpectralCoordinates from ipyvue import watch from jdaviz.core.validunits import check_if_unit_is_per_solid_angle @@ -27,7 +28,8 @@ 'standardize_metadata', 'ColorCycler', 'alpha_index', 'get_subset_type', 'download_uri_to_path', 'flux_conversion', 'spectral_axis_conversion', 'layer_is_2d', 'layer_is_2d_or_3d', 'layer_is_image_data', 'layer_is_wcs_only', - 'get_wcs_only_layer_labels', 'get_top_layer_index', 'get_reference_image_data'] + 'get_wcs_only_layer_labels', 'get_top_layer_index', 'get_reference_image_data', + 'standardize_roman_metadata'] NUMPY_LT_2_0 = not minversion("numpy", "2.0.dev") @@ -288,6 +290,36 @@ def standardize_metadata(metadata): return out_meta +def standardize_roman_metadata(data_model): + """ + Metadata standardization for Roman datamodels ``meta`` attributes. + + Converts to a flat dictionary and strips the redundant top-level + tags ("roman", and "meta"). + + Parameters + ---------- + data_model : `~roman_datamodels.datamodels.DataModel` + Roman datamodel. 
+ + Returns + ------- + d : dict + Flattened dictionary of metadata + """ + import roman_datamodels.datamodels as rdm + if isinstance(data_model, rdm.DataModel): + # Roman metadata are in nested dicts that we flatten: + flat_dict_meta = data_model.to_flat_dict() + + # split off the redundant parts of the metadata: + return { + k.split('roman.meta.')[1]: v + for k, v in flat_dict_meta.items() + if 'roman.meta' in k + } + + def indirect_units(): return [ u.erg / (u.s * u.cm**2 * u.Angstrom * u.sr), @@ -517,7 +549,7 @@ def get_subset_type(subset): Returns ------- subset_type : str or None - 'spatial', 'spectral', or None + 'spatial', 'spectral', 'temporal', or None """ if not hasattr(subset, 'subset_state'): return None @@ -530,7 +562,40 @@ def get_subset_type(subset): if isinstance(subset.subset_state, RoiSubsetState): return 'spatial' elif isinstance(subset.subset_state, RangeSubsetState): - return 'spectral' + # look within a SubsetGroup, or a single Subset + subset_list = getattr(subset, 'subsets', [subset]) + + for ss in subset_list: + if hasattr(ss, 'data'): + ss_data = ss.data + elif hasattr(ss.att, 'parent'): + # if `ss` is a subset state, it won't have a `data` attr, + # check the world coordinate's parent data: + ss_data = ss.att.parent + else: + # if we reach this `else`, continue searching + # through other subsets in the group to identify the + # subset type: + continue + + # check for a spectral coordinate in FITS WCS: + wcs_coords = ( + ss_data.coords.wcs.ctype if hasattr(ss_data.coords, 'wcs') + else [] + ) + + has_spectral_coords = ( + any(str(coord).startswith('WAVE') for coord in wcs_coords) or + + # also check for a spectral coordinate from the glue_astronomy translator: + isinstance(ss_data.coords, SpectralCoordinates) + ) + + if has_spectral_coords: + return 'spectral' + + # otherwise, assume temporal: + return 'temporal' else: return None diff --git a/notebooks/concepts/RampvizExample.ipynb b/notebooks/concepts/RampvizExample.ipynb new file mode 
100644 index 0000000000..7d817f4efd --- /dev/null +++ b/notebooks/concepts/RampvizExample.ipynb @@ -0,0 +1,162 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "39fa8c0d-da61-4704-9bba-c438347abf95", + "metadata": {}, + "source": [ + "# Visualize Roman L1 ramp files with Rampviz\n", + "\n", + "\n", + "To install jdaviz from source with the optional Roman dependencies:\n", + "```bash\n", + "pip install .[roman]\n", + "```\n", + "\n", + "\n", + "First, let's download a ramp file:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b69aafe8-264b-4b11-a908-4603ad86ba7b", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from urllib.request import urlretrieve\n", + "\n", + "from jdaviz import Rampviz\n", + "\n", + "force_download = False\n", + "\n", + "url_L1 = \"https://stsci.box.com/shared/static/80vahj27t3y02itfohc22p999snkcocw.asdf\"\n", + "local_path = \"L1.asdf\"\n", + "\n", + "if not os.path.exists(local_path) or force_download:\n", + " urlretrieve(url_L1, local_path)" + ] + }, + { + "cell_type": "markdown", + "id": "64fb7c99-084b-40af-8180-5af16a04069e", + "metadata": {}, + "source": [ + "Let's load the file into the helper:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2b15f559-a7ad-4279-b777-4eff0b465b56", + "metadata": {}, + "outputs": [], + "source": [ + "rampviz = Rampviz()\n", + "rampviz.load_data(local_path, data_label='Roman L1')\n", + "rampviz.show(height=1000)" + ] + }, + { + "cell_type": "markdown", + "id": "7d73fbd0-8f08-4665-b2f5-858da26f51af", + "metadata": {}, + "source": [ + "We now reset viewer limits to center on a star:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e45a6c1a-bd9a-41a4-8437-905d9c690627", + "metadata": {}, + "outputs": [], + "source": [ + "plot_options = rampviz.plugins['Plot Options']\n", + "for viewer in ['group-viewer', 'diff-viewer']:\n", + " plot_options.viewer = viewer\n", + " plot_options.zoom_center_x = 1797\n", + " 
plot_options.zoom_center_y = 2051\n", + " plot_options.zoom_radius = 20" + ] + }, + { + "cell_type": "markdown", + "id": "d5aad2c7-bc44-4e93-85ed-9cb6511b66c2", + "metadata": {}, + "source": [ + "Let's load a spatial region to preview individual ramp profiles for each pixel within the subset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e26c0804-3b78-4b2b-8c9d-a9fc5c1c91ea", + "metadata": {}, + "outputs": [], + "source": [ + "from regions import CirclePixelRegion, PixCoord\n", + "\n", + "region = CirclePixelRegion(center=PixCoord(1797.2, 2051.2), radius=2)\n", + "rampviz.load_regions(region)" + ] + }, + { + "cell_type": "markdown", + "id": "b6c01020-f266-4a8e-ad23-0a60e55b993b", + "metadata": {}, + "source": [ + "And let's take the median of the ramps within Subset 1:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5407496a-96ac-4202-a2c0-345bb5649263", + "metadata": {}, + "outputs": [], + "source": [ + "ramp_extract = rampviz.plugins['Ramp Extraction']\n", + "\n", + "# # If you re-run this cell, you may need to re-enable previews:\n", + "# ramp_extract._obj.previews_temp_disabled = False\n", + "\n", + "# ramp_extract.keep_active = True\n", + "ramp_extract.function = 'Median'\n", + "ramp_extract.aperture = 'Subset 1'\n", + "\n", + "ramp_extract.extract()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ace3d5db-247e-4136-9851-e4ebab0ea9ea", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}