Commit 282f193

adding support for visualizing all pixels in subsets, and per pixel on hover

bmorris3 committed Aug 7, 2024
1 parent 46aa6a4 commit 282f193
Showing 11 changed files with 216 additions and 118 deletions.
2 changes: 1 addition & 1 deletion jdaviz/configs/cubeviz/plugins/__init__.py
@@ -1,7 +1,7 @@
from .tools import * # noqa
from .mixins import * # noqa
from .viewers import * # noqa
from .parsers import * # noqa
from .moment_maps.moment_maps import * # noqa
from .slice.slice import * # noqa
from .spectral_extraction.spectral_extraction import * # noqa
from .tools import * # noqa
52 changes: 8 additions & 44 deletions jdaviz/configs/cubeviz/plugins/tools.py
@@ -3,14 +3,13 @@

from glue.config import viewer_tool
from glue_jupyter.bqplot.image import BqplotImageView
from glue_jupyter.bqplot.profile import BqplotProfileView
from glue.viewers.common.tool import CheckableTool
import numpy as np
from specutils import Spectrum1D

from jdaviz.core.events import SliceToolStateMessage, SliceSelectSliceMessage
from jdaviz.core.tools import PanZoom, BoxZoom, SinglePixelRegion, _MatchedZoomMixin
from jdaviz.core.marks import PluginLine
from jdaviz.core.tools import PanZoom, BoxZoom, _MatchedZoomMixin
from jdaviz.configs.default.plugins.tools import ProfileFromCube

__all__ = []

@@ -81,52 +80,17 @@ def on_mouse_event(self, data):


@viewer_tool
class SpectrumPerSpaxel(SinglePixelRegion):
class SpectrumPerSpaxel(ProfileFromCube):

icon = os.path.join(ICON_DIR, 'pixelspectra.svg')
tool_id = 'jdaviz:spectrumperspaxel'
action_text = 'See spectrum at a single spaxel'
tool_tip = 'Click on the viewer and see the spectrum at that spaxel in the spectrum viewer'

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._spectrum_viewer = None
self._previous_bounds = None
self._mark = None
self._data = None

def _reset_spectrum_viewer_bounds(self):
sv_state = self._spectrum_viewer.state
sv_state.x_min = self._previous_bounds[0]
sv_state.x_max = self._previous_bounds[1]
sv_state.y_min = self._previous_bounds[2]
sv_state.y_max = self._previous_bounds[3]

def activate(self):
self.viewer.add_event_callback(self.on_mouse_move, events=['mousemove', 'mouseleave'])
if self._spectrum_viewer is None:
# Get first profile viewer
for _, viewer in self.viewer.jdaviz_helper.app._viewer_store.items():
if isinstance(viewer, BqplotProfileView):
self._spectrum_viewer = viewer
break
if self._mark is None:
self._mark = PluginLine(self._spectrum_viewer, visible=False)
self._spectrum_viewer.figure.marks = self._spectrum_viewer.figure.marks + [self._mark,]
# Store these so we can revert to previous user-set zoom after preview view
sv_state = self._spectrum_viewer.state
self._previous_bounds = [sv_state.x_min, sv_state.x_max, sv_state.y_min, sv_state.y_max]
super().activate()

def deactivate(self):
self.viewer.remove_event_callback(self.on_mouse_move)
self._reset_spectrum_viewer_bounds()
super().deactivate()

def on_mouse_move(self, data):
if data['event'] == 'mouseleave':
self._mark.visible = False
self._reset_spectrum_viewer_bounds()
self._reset_profile_viewer_bounds()
return

x = int(np.round(data['domain']['x']))
@@ -157,13 +121,13 @@ def on_mouse_move(self, data):
else:
spectrum = cube_data.get_object(statistic=None)
# Note: change this when Spectrum1D.with_spectral_axis is fixed.
x_unit = self._spectrum_viewer.state.x_display_unit
x_unit = self._profile_viewer.state.x_display_unit
if spectrum.spectral_axis.unit != x_unit:
new_spectral_axis = spectrum.spectral_axis.to(x_unit)
spectrum = Spectrum1D(spectrum.flux, new_spectral_axis)

if x >= spectrum.flux.shape[0] or x < 0 or y >= spectrum.flux.shape[1] or y < 0:
self._reset_spectrum_viewer_bounds()
self._reset_profile_viewer_bounds()
self._mark.visible = False
else:
y_values = spectrum.flux[x, y, :]
@@ -172,5 +136,5 @@ def on_mouse_move(self, data):
return
self._mark.update_xy(spectrum.spectral_axis.value, y_values)
self._mark.visible = True
self._spectrum_viewer.state.y_max = np.nanmax(y_values.value) * 1.2
self._spectrum_viewer.state.y_min = np.nanmin(y_values.value) * 0.8
self._profile_viewer.state.y_max = np.nanmax(y_values.value) * 1.2
self._profile_viewer.state.y_min = np.nanmin(y_values.value) * 0.8
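The on_mouse_move hunk above keeps the existing workaround for matching the hovered spaxel's spectrum to the profile viewer's display unit, noted in the code as pending a fix to Spectrum1D.with_spectral_axis. A minimal standalone sketch of that conversion, with made-up flux values and units (x_unit stands in for self._profile_viewer.state.x_display_unit):

import astropy.units as u
import numpy as np
from specutils import Spectrum1D

# toy spectrum; in the tool this comes from the hovered data cube
spectrum = Spectrum1D(flux=np.ones(5) * u.Jy,
                      spectral_axis=np.linspace(1, 2, 5) * u.um)
x_unit = u.AA  # stands in for the profile viewer's x display unit

if spectrum.spectral_axis.unit != x_unit:
    # rebuild the Spectrum1D with a converted spectral axis, as the tool does
    new_spectral_axis = spectrum.spectral_axis.to(x_unit)
    spectrum = Spectrum1D(spectrum.flux, new_spectral_axis)

print(spectrum.spectral_axis.unit)  # Angstrom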
47 changes: 47 additions & 0 deletions jdaviz/configs/default/plugins/tools.py
@@ -0,0 +1,47 @@
from glue_jupyter.bqplot.profile import BqplotProfileView
from jdaviz.core.tools import SinglePixelRegion
from jdaviz.core.marks import PluginLine


__all__ = ['ProfileFromCube']


class ProfileFromCube(SinglePixelRegion):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._profile_viewer = None
self._previous_bounds = None
self._mark = None
self._data = None

def _reset_profile_viewer_bounds(self):
pv_state = self._profile_viewer.state
pv_state.x_min = self._previous_bounds[0]
pv_state.x_max = self._previous_bounds[1]
pv_state.y_min = self._previous_bounds[2]
pv_state.y_max = self._previous_bounds[3]

def activate(self):
self.viewer.add_event_callback(self.on_mouse_move, events=['mousemove', 'mouseleave'])
if self._profile_viewer is None:
# Get first profile viewer
for _, viewer in self.viewer.jdaviz_helper.app._viewer_store.items():
if isinstance(viewer, BqplotProfileView):
self._profile_viewer = viewer
break
if self._mark is None:
self._mark = PluginLine(self._profile_viewer, visible=False)
self._profile_viewer.figure.marks = self._profile_viewer.figure.marks + [self._mark, ]
# Store these so we can revert to previous user-set zoom after preview view
pv_state = self._profile_viewer.state
self._previous_bounds = [pv_state.x_min, pv_state.x_max, pv_state.y_min, pv_state.y_max]
super().activate()

def deactivate(self):
self.viewer.remove_event_callback(self.on_mouse_move)
self._reset_profile_viewer_bounds()
super().deactivate()

def on_mouse_move(self, data):
raise NotImplementedError("must be implemented by subclasses")
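ProfileFromCube now owns the shared plumbing (finding the first profile viewer, creating the preview PluginLine, saving and restoring the viewer's zoom bounds), so a tool only has to implement on_mouse_move. A hypothetical minimal subclass, sketched after the SpectrumPerSpaxel changes above (the class name and the placeholder profile values are illustrative only):

import numpy as np

from jdaviz.configs.default.plugins.tools import ProfileFromCube


class ExamplePreviewTool(ProfileFromCube):
    # Hypothetical subclass: preview a profile for the pixel under the cursor.

    def on_mouse_move(self, data):
        if data['event'] == 'mouseleave':
            # hide the preview and restore the user's zoom bounds
            self._mark.visible = False
            self._reset_profile_viewer_bounds()
            return

        x = int(np.round(data['domain']['x']))
        y = int(np.round(data['domain']['y']))

        # placeholder profile; a real tool would read the cube at (x, y)
        x_values = np.arange(10)
        y_values = np.full(10, float(x + y))

        self._mark.update_xy(x_values, y_values)
        self._mark.visible = True
        self._profile_viewer.state.y_max = np.nanmax(y_values) * 1.2
        self._profile_viewer.state.y_min = np.nanmin(y_values) * 0.8

Registering such a tool with @viewer_tool plus an icon and tool_id, as SpectrumPerSpaxel does above, is what exposes it in a viewer's toolbar.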
2 changes: 1 addition & 1 deletion jdaviz/configs/default/plugins/viewers.py
@@ -191,7 +191,7 @@ def _apply_layer_defaults(self, layer_state):
layer_state.add_callback('as_steps', self._show_uncertainty_changed)

def _expected_subset_layer_default(self, layer_state):
if self.__class__.__name__ == 'CubevizImageView':
if self.__class__.__name__ in ('CubevizImageView', 'RampvizImageView'):
# Do not override default for subsets as for some reason
# this isn't getting called when they're first added, but rather when
# the next state change is made (for example: manually changing the visibility)
39 changes: 4 additions & 35 deletions jdaviz/configs/rampviz/helper.py
@@ -1,5 +1,5 @@
from jdaviz.core.events import SliceSelectSliceMessage
from jdaviz.core.events import AddDataMessage, SnackbarMessage
from jdaviz.core.events import AddDataMessage
from jdaviz.core.helpers import CubeConfigHelper
from jdaviz.configs.rampviz.plugins.viewers import RampvizImageView

@@ -49,42 +49,11 @@ def load_data(self, data, data_label=None, **kwargs):
self.app.hub.subscribe(self, AddDataMessage,
handler=self._set_x_axis)

if 'Ramp Extraction' not in self.plugins: # pragma: no cover
msg = SnackbarMessage(
"Automatic ramp extraction requires the Ramp Extraction plugin to be enabled", # noqa
color='error', sender=self, timeout=10000)
self.app.hub.broadcast(msg)
else:
try:
self.plugins['Ramp Extraction']._obj._extract_in_new_instance(auto_update=False, add_data=True) # noqa
except Exception as err:
msg = SnackbarMessage(
"Automatic ramp extraction for the entire cube failed."
f" See the ramp extraction plugin to perform a custom extraction: {err}",
color='error', sender=self, timeout=10000)
else:
msg = SnackbarMessage(
"The extracted ramp profile was generated automatically for the entire cube."
" See the ramp extraction plugin for details or to"
" perform a custom extraction.",
color='warning', sender=self, timeout=10000)
self.app.hub.broadcast(msg)

def _set_x_axis(self, msg):
viewer = self.app.get_viewer(self._default_integration_viewer_reference_name)
if msg.viewer_id != viewer.reference_id:
return
viewer = self.app.get_viewer(self._default_group_viewer_reference_name)
ref_data = viewer.state.reference_data
if ref_data and ref_data.ndim == 3:
for att_name in _temporal_axis_names:
if att_name in ref_data.component_ids():
if viewer.state.x_att != ref_data.id[att_name]:
viewer.state.x_att = ref_data.id[att_name]
viewer.state.reset_limits()
break
else:
viewer.state.x_att = ref_data.id["Pixel Axis 2 [x]"]
viewer.state.reset_limits()
viewer.state.x_att = ref_data.id["Pixel Axis 2 [x]"]
viewer.state.reset_limits()

def select_group(self, group_index):
"""
1 change: 1 addition & 0 deletions jdaviz/configs/rampviz/plugins/__init__.py
@@ -1,3 +1,4 @@
from .viewers import * # noqa
from .parsers import * # noqa
from .ramp_extraction import * # noqa
from .tools import * # noqa
15 changes: 7 additions & 8 deletions jdaviz/configs/rampviz/plugins/parsers.py
@@ -112,8 +112,6 @@ def parse_data(app, file_obj, data_type=None, data_label=None,
)

elif isinstance(file_obj, (np.ndarray, NDData)) and file_obj.ndim in (1, 2):
if file_obj.ndim == 2:
app.get_viewer(integration_viewer_reference_name).is2d = True
# load 1D profile(s) to integration_viewer
_parse_ndarray(
app, file_obj, data_label=data_label,
@@ -166,24 +164,25 @@ def _swap_axes(x):
])

ramp_cube_data_label = f"{data_label}[DATA]"
ramp_diff_data_label = f"{data_label}[DIFF]"

# load these cubes into the cache:
app._jdaviz_helper.cube_cache[ramp_cube_data_label] = NDDataArray(_swap_axes(data))
app._jdaviz_helper.cube_cache[ramp_diff_data_label] = NDDataArray(_swap_axes(diff_data))

# load these cubes into the app:
_parse_ndarray(
app,
file_obj=_swap_axes(data),
data_label=ramp_cube_data_label,
viewer_reference_name=group_viewer_reference_name,
)

app._jdaviz_helper.cube_cache[ramp_cube_data_label] = NDDataArray(_swap_axes(data))

# load the diff of the data cube
ramp_diff_data_label = f"{data_label}[DIFF]"
_parse_ndarray(
app,
file_obj=_swap_axes(diff_data),
data_label=ramp_diff_data_label,
viewer_reference_name=diff_viewer_reference_name,
)
app._jdaviz_helper.cube_cache[ramp_diff_data_label] = NDDataArray(_swap_axes(diff_data))

# the default collapse function in the profile viewer is "sum",
# but for ramp files, "median" is more useful:
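The parser change above registers both ramp cubes in the helper's cube_cache before loading them into the app, keyed by the [DATA]/[DIFF] label convention; the Ramp Extraction plugin later reads the first cached cube. A rough sketch of that caching pattern, with made-up arrays and a placeholder label:

import numpy as np
from astropy.nddata import NDDataArray

cube_cache = {}
data_label = "my_ramp"  # placeholder label

data = np.zeros((10, 10, 5))        # e.g. (x, y, group) after _swap_axes
diff_data = np.diff(data, axis=-1)  # placeholder stand-in for the diff cube

cube_cache[f"{data_label}[DATA]"] = NDDataArray(data)
cube_cache[f"{data_label}[DIFF]"] = NDDataArray(diff_data)

# consumers such as _on_subset_update grab the first cached cube:
cube = cube_cache[list(cube_cache.keys())[0]]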
82 changes: 55 additions & 27 deletions jdaviz/configs/rampviz/plugins/ramp_extraction/ramp_extraction.py
@@ -4,6 +4,7 @@

from functools import cached_property
from traitlets import Bool, Float, List, Unicode, observe
from glue.core.message import DataCollectionAddMessage, SubsetUpdateMessage

from jdaviz.core.events import SnackbarMessage, SliceValueUpdatedMessage
from jdaviz.core.marks import PluginLine
@@ -102,13 +103,61 @@ def __init__(self, *args, **kwargs):
self.session.hub.subscribe(self, SliceValueUpdatedMessage,
handler=self._on_slice_changed)

self.session.hub.subscribe(self, DataCollectionAddMessage,
handler=self._on_data_added)

self.session.hub.subscribe(self, SubsetUpdateMessage,
handler=self._on_subset_update)

self._update_disabled_msg()

if self.app.state.settings.get('server_is_remote', False):
# when the server is remote, saving the file in python would save on the server, not
# on the user's machine, so export support in cubeviz should be disabled
self.export_enabled = False

@property
def integration_viewer(self):
viewer = self.app.get_viewer(
self.app._jdaviz_helper._default_integration_viewer_reference_name
)
return viewer

def _on_data_added(self, msg={}):
if msg.data.label.endswith('[DATA]'):
self.extract(add_data=True)
self.integration_viewer._initialize_x_axis()

def _on_subset_update(self, msg={}):

if not hasattr(self, 'aperture') or not hasattr(self.app._jdaviz_helper, 'cube_cache'):
return

cube_cache = self.app._jdaviz_helper.cube_cache
cube = cube_cache[list(cube_cache.keys())[0]]

subset_lbl = msg.subset.label
color = msg.subset.style.color

subset = self.app.get_subsets(subset_lbl)[0]
region = subset['region']
# glue region has transposed coords relative to cached cube:
region_mask = region.to_mask().to_image(cube.shape[:-1]).astype(bool).T
cube_subset = cube[region_mask]

mark = [
PluginLine(self.integration_viewer, x=np.arange(cube_subset.shape[1]), y=y,
stroke_width=1.5, colors=[color], opacities=[0.3], label=subset_lbl)
for y in cube_subset
]

self.integration_viewer.figure.marks = [
mark for mark in self.integration_viewer.figure.marks
if getattr(mark, 'label', None) != subset_lbl
] + mark

self.integration_viewer.reset_limits()

@property
def user_api(self):
expose = ['dataset', 'function', 'aperture',
@@ -158,31 +207,6 @@ def _active_step_changed(self, *args):
def slice_plugin(self):
return self.app._jdaviz_helper.plugins['Slice']

@observe('aperture_items')
@skip_if_not_tray_instance()
def _aperture_items_changed(self, msg):
if not self.do_auto_extraction:
return
if not hasattr(self, 'aperture'):
return
for item in msg['new']:
if item not in msg['old']:
if item.get('type') != 'spatial':
continue
subset_lbl = item.get('label')
try:
self._extract_in_new_instance(subset_lbl=subset_lbl,
auto_update=True, add_data=True)
except Exception as err:
msg = SnackbarMessage(
f"Automatic {self.resulting_product_name} extraction for {subset_lbl} failed: {err}", # noqa
color='error', sender=self, timeout=10000)
else:
msg = SnackbarMessage(
f"Automatic {self.resulting_product_name} extraction for {subset_lbl} successful", # noqa
color='success', sender=self)
self.app.hub.broadcast(msg)

def _extract_in_new_instance(self, dataset=None, function='Mean', subset_lbl=None,
auto_update=False, add_data=False):
# create a new instance of the Ramp Extraction plugin (to not affect the instance in
@@ -261,9 +285,13 @@ def _extract_from_aperture(self, **kwargs):
collapsed = getattr(np, selected_func)(
nddata.data, **collapse_kwargs
) << nddata.unit

def expand(x):
return np.expand_dims(x, axis=(0, 1))

return NDDataArray(
data=collapsed,
mask=mask.all(axis=self.spatial_axes),
data=expand(collapsed),
mask=expand(mask.all(axis=self.spatial_axes)),
meta=nddata.meta
)

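The new _on_subset_update handler is where "all pixels in subsets" get visualized: the subset's region is rasterized onto the cube's spatial shape, transposed to match the cached cube's axis order, and each selected pixel's ramp becomes a low-opacity PluginLine in the integration viewer (old marks with the same subset label are dropped first). A standalone sketch of the masking step, using the regions package with made-up cube and subset geometry:

import numpy as np
from regions import CirclePixelRegion, PixCoord

# made-up cube with shape (nx, ny, n_groups); the cached cubes are NDDataArray
cube = np.random.default_rng(0).normal(size=(10, 10, 5))

# made-up circular spatial subset in pixel coordinates
region = CirclePixelRegion(center=PixCoord(x=4, y=5), radius=2.5)

# rasterize onto the spatial shape, then transpose to the cube's (x, y) order,
# mirroring the glue-vs-cache coordinate note in _on_subset_update
region_mask = region.to_mask().to_image(cube.shape[:-1]).astype(bool).T

# one ramp per selected pixel: shape (n_selected_pixels, n_groups)
cube_subset = cube[region_mask]
print(cube_subset.shape)

In the plugin, each row of cube_subset supplies the y data of a PluginLine labeled with the subset name, which is what lets stale marks be filtered out and redrawn on the next subset update.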
(3 more changed files not shown)