Skip to content

Commit

Permalink
Merge pull request #722 from HEXRD/polar-view-speedups
Browse files Browse the repository at this point in the history
Add option to cache coordinate map in PolarView projection
  • Loading branch information
saransh13 authored Oct 2, 2024
2 parents 0052c05 + 3be4204 commit 45847c9
Show file tree
Hide file tree
Showing 4 changed files with 200 additions and 16 deletions.
14 changes: 11 additions & 3 deletions hexrd/instrument/detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -1025,7 +1025,9 @@ def interpolate_nearest(self, xy, img, pad_with_nans=True):
int_xy[on_panel] = int_vals
return int_xy

def interpolate_bilinear(self, xy, img, pad_with_nans=True):
def interpolate_bilinear(self, xy, img, pad_with_nans=True,
clip_to_panel=True,
on_panel: Optional[np.ndarray] = None):
"""
Interpolate an image array at the specified cartesian points.
Expand All @@ -1039,6 +1041,9 @@ def interpolate_bilinear(self, xy, img, pad_with_nans=True):
pad_with_nans : bool, optional
Toggle for assigning NaN to points that fall off the detector.
The default is True.
on_panel : np.ndarray, optional
If you want to skip clip_to_panel() for performance reasons,
just provide an array of which pixels are on the panel.
Returns
-------
Expand Down Expand Up @@ -1066,8 +1071,11 @@ def interpolate_bilinear(self, xy, img, pad_with_nans=True):
else:
int_xy = np.zeros(len(xy))

# clip away points too close to or off the edges of the detector
xy_clip, on_panel = self.clip_to_panel(xy, buffer_edges=True)
if on_panel is None:
# clip away points too close to or off the edges of the detector
xy_clip, on_panel = self.clip_to_panel(xy, buffer_edges=True)
else:
xy_clip = xy[on_panel]

# grab fractional pixel indices of clipped points
ij_frac = self.cartToPixel(xy_clip)
Expand Down
114 changes: 101 additions & 13 deletions hexrd/projections/polar.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ class PolarView:

def __init__(self, plane_data, instrument,
eta_min=0., eta_max=360.,
pixel_size=(0.1, 0.25)):
pixel_size=(0.1, 0.25),
cache_coordinate_map=False):
"""
Instantiates a PolarView class.
Expand All @@ -36,6 +37,13 @@ def __init__(self, plane_data, instrument,
pixel_size : array_like, optional
The angular pixels sizes (2theta, eta) in degrees.
The default is (0.1, 0.25).
cache_coordinate_map : bool, optional
If True, the coordinate map will be cached so that calls to
`warp_image()` will be *significantly* faster.
If set to True, the caller *must* ensure that no parameters
on the instrument that would affect polar view generation,
and no parameters set on this class, will be modified,
because doing so would result in an incorrect `warp_image()`.
Returns
-------
Expand Down Expand Up @@ -67,6 +75,15 @@ def __init__(self, plane_data, instrument,

self._instrument = instrument

self._coordinate_mapping = None
self._cache_coordinate_map = cache_coordinate_map
if cache_coordinate_map:
# It is important to generate the cached map now, rather than
# later, because this object might be sent to other processes
# for parallelization, and it will be faster if the mapping
# is already generated.
self._coordinate_mapping = self._generate_coordinate_mapping()

@property
def instrument(self):
return self._instrument
Expand All @@ -75,6 +92,10 @@ def instrument(self):
def detectors(self):
return self._instrument.detectors

@property
def cache_coordinate_map(self):
return self._cache_coordinate_map

@property
def tvec(self):
return self._instrument.tvec
Expand Down Expand Up @@ -216,6 +237,10 @@ def warp_image(self, image_dict, pad_with_nans=False,
"""
Performs the polar mapping of the input images.
Note: this function has the potential to run much faster if
`cache_coordinate_map` is set to `True` on the `PolarView`
initialization.
Parameters
----------
image_dict : dict
Expand All @@ -232,16 +257,46 @@ def warp_image(self, image_dict, pad_with_nans=False,
Tested ouput using Maud.
"""

if self.cache_coordinate_map:
# The mapping should have already been generated.
mapping = self._coordinate_mapping
else:
# Otherwise, we must generate it every time
mapping = self._generate_coordinate_mapping()

return self._warp_image_from_coordinate_map(
image_dict,
mapping,
pad_with_nans=pad_with_nans,
do_interpolation=do_interpolation,
)

def _generate_coordinate_mapping(self) -> dict[str, dict[str, np.ndarray]]:
"""Generate mapping of detector coordinates to generate polar view
This function is, in general, the most time consuming part of creating
the polar view. Its results can be cached
If you plan to generate the polar view many times in a row using the
same instrument configuration, but different data files, this
function can be called once at the beginning to generate a mapping
of the detectors to the cartesian coordinates for each angular pixel,
followed by warp_image_from_mapping() to create the polar view.
This can be significantly faster than calling `warp_image()` every
time
The dictionary that returns has detector IDs as the first key, and
another dict as the second key.
The nested dict has "xypts" and "on_panel" as keys, and the
respective arrays as the values.
"""
angpts = self.angular_grid
dummy_ome = np.zeros((self.ntth*self.neta))

# lcount = 0
img_dict = dict.fromkeys(self.detectors)
mapping = {}
for detector_id, panel in self.detectors.items():
_project_on_detector = self._func_project_on_detector(panel)
img = image_dict[detector_id]

gvec_angs = np.vstack([
angpts[1].flatten(),
angpts[0].flatten(),
Expand All @@ -255,20 +310,53 @@ def warp_image(self, image_dict, pad_with_nans=False,
**kwargs)
xypts[on_plane, :] = valid_xys

_, on_panel = panel.clip_to_panel(xypts, buffer_edges=True)

mapping[detector_id] = {
'xypts': xypts,
'on_panel': on_panel,
}

return mapping

def _warp_image_from_coordinate_map(
self,
image_dict: dict[str, np.ndarray],
coordinate_map: dict[str, dict[str, np.ndarray]],
pad_with_nans: bool = False,
do_interpolation=True) -> np.ma.MaskedArray:
img_dict = dict.fromkeys(self.detectors)
nan_mask = None
for detector_id, panel in self.detectors.items():
img = image_dict[detector_id]
xypts = coordinate_map[detector_id]['xypts']
on_panel = coordinate_map[detector_id]['on_panel']

if do_interpolation:
this_img = panel.interpolate_bilinear(
xypts, img,
pad_with_nans=pad_with_nans).reshape(self.shape)
pad_with_nans=pad_with_nans,
on_panel=on_panel).reshape(self.shape)
else:
this_img = panel.interpolate_nearest(
xypts, img,
pad_with_nans=pad_with_nans).reshape(self.shape)
nan_mask = np.isnan(this_img)
img_dict[detector_id] = np.ma.masked_array(
data=this_img, mask=nan_mask, fill_value=0.
)
maimg = np.ma.sum(np.ma.stack(img_dict.values()), axis=0)
return maimg

# It is faster to keep track of the global nans like this
# rather than the previous way we were doing it...
img_nans = np.isnan(this_img)
if nan_mask is None:
nan_mask = img_nans
else:
nan_mask = np.logical_and(img_nans, nan_mask)

this_img[img_nans] = 0
img_dict[detector_id] = this_img

summed_img = np.sum(list(img_dict.values()), axis=0)
return np.ma.masked_array(
data=summed_img, mask=nan_mask, fill_value=0.
)

def tth_to_pixel(self, tth):
"""
Expand Down
Binary file added tests/data/test_polar_view_expected.npy
Binary file not shown.
88 changes: 88 additions & 0 deletions tests/test_polar_view.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from pathlib import Path

import h5py
import numpy as np
import pytest

from hexrd import imageseries
from hexrd.imageseries.process import ProcessedImageSeries
from hexrd.instrument import HEDMInstrument
from hexrd.projections.polar import PolarView


@pytest.fixture
def eiger_examples_path(example_repo_path: Path) -> Path:
return Path(example_repo_path) / 'eiger'


@pytest.fixture
def ceria_examples_path(eiger_examples_path: Path) -> Path:
return eiger_examples_path / 'first_ceria'


@pytest.fixture
def ceria_example_data(ceria_examples_path: Path) -> np.ndarray:
data_path = ceria_examples_path / 'ff_000_data_000001.h5'
with h5py.File(data_path, 'r') as rf:
# Just return the first frame
return rf['/entry/data/data'][0]


@pytest.fixture
def ceria_composite_instrument(ceria_examples_path: Path) -> HEDMInstrument:
instr_path = (
ceria_examples_path / 'eiger_ceria_uncalibrated_composite.hexrd'
)
with h5py.File(instr_path, 'r') as rf:
return HEDMInstrument(rf)


def test_polar_view(
ceria_composite_instrument: HEDMInstrument,
ceria_example_data: np.ndarray,
test_data_dir: Path,
):
instr = ceria_composite_instrument
image_data = ceria_example_data

# Break up the image data into separate images for each detector
# It's easiest to do this using hexrd's imageseries and
# ProcessedImageSeries
ims_dict = {}
ims = imageseries.open(None, format='array', data=image_data)
for det_key, panel in instr.detectors.items():
ims_dict[det_key] = ProcessedImageSeries(
ims, oplist=[('rectangle', panel.roi)]
)

# Create the img_dict
img_dict = {k: v[0] for k, v in ims_dict.items()}

# Create the PolarView
tth_range = [0, 14.0]
eta_min = -180.0
eta_max = 180.0
pixel_size = (0.01, 5.0)

pv = PolarView(tth_range, instr, eta_min, eta_max, pixel_size)
img = pv.warp_image(img_dict, pad_with_nans=True,
do_interpolation=True)

# This is a masked array. Just fill it with nans.
img = img.filled(np.nan)

# Verify that the image is identical to a reference image
ref = np.load(
test_data_dir / 'test_polar_view_expected.npy', allow_pickle=True
)
assert np.allclose(img, ref, equal_nan=True)

# Also generate it using the cache
pv = PolarView(tth_range, instr, eta_min, eta_max, pixel_size,
cache_coordinate_map=True)
fast_img = pv.warp_image(img_dict, pad_with_nans=True,
do_interpolation=True)

# This should also be identical
fast_img = fast_img.filled(np.nan)
assert np.allclose(fast_img, ref, equal_nan=True)

0 comments on commit 45847c9

Please sign in to comment.