From 863b6e4b32479b44a5b22590c0adc2b4b08db7a0 Mon Sep 17 00:00:00 2001 From: Nadia Dencheva Date: Fri, 13 Dec 2024 16:07:34 -0500 Subject: [PATCH] Respect the bounding_box in inverse transforms (#498) --- CHANGES.rst | 8 +- gwcs/api.py | 9 +- gwcs/tests/conftest.py | 2 +- gwcs/tests/test_api.py | 17 +--- gwcs/tests/test_api_slicing.py | 4 +- gwcs/tests/test_bounding_box.py | 87 +++++++++++++++++ gwcs/tests/test_coordinate_systems.py | 3 - gwcs/tests/test_utils.py | 6 ++ gwcs/tests/test_wcs.py | 8 +- gwcs/utils.py | 11 +-- gwcs/wcs.py | 131 +++++++++++++++++--------- 11 files changed, 212 insertions(+), 74 deletions(-) create mode 100644 gwcs/tests/test_bounding_box.py diff --git a/CHANGES.rst b/CHANGES.rst index cca5803e..d308cb23 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -11,6 +11,7 @@ - Add support for compound bounding boxes and ignored bounding box entries. [#519] + - Add ``gwcs.examples`` module, based on the examples located in the testing ``conftest.py``. [#521] - Force ``bounding_box`` to always be returned as a ``F`` ordered box. [#522] @@ -19,8 +20,11 @@ - Adjust ``world_to_array_index_values`` to round to integer coordinates as specified by APE 14. [#525] -- Add warning filter to asdf extension to prevent the ``bounding_box`` order warning for gwcs - objects originating from a file. [#526] +- Add warning filter to asdf extension to prevent the ``bounding_box`` order warning for gwcs objects originating from a file. [#526] + +- Fixed a bug where evaluating the inverse transform did not + respect the bounding box. [#498] + 0.21.0 (2024-03-10) ------------------- diff --git a/gwcs/api.py b/gwcs/api.py index 2b1ca22e..32f6f9c3 100644 --- a/gwcs/api.py +++ b/gwcs/api.py @@ -132,7 +132,13 @@ def world_to_pixel_values(self, *world_arrays): be returned in the ``(x, y)`` order, where for an image, ``x`` is the horizontal coordinate and ``y`` is the vertical coordinate. """ - world_arrays = self._add_units_input(world_arrays, self.backward_transform, self.output_frame) + try: + backward_transform = self.backward_transform + world_arrays = self._add_units_input(world_arrays, + backward_transform, + self.output_frame) + except NotImplementedError: + pass result = self.invert(*world_arrays, with_units=False) @@ -317,7 +323,6 @@ def world_to_pixel(self, *world_objects): Convert world coordinates to pixel values. """ result = self.invert(*world_objects, with_units=True) - if self.input_frame.naxes > 1: first_res = result[0] if not utils.isnumerical(first_res): diff --git a/gwcs/tests/conftest.py b/gwcs/tests/conftest.py index 83fb216b..92979383 100644 --- a/gwcs/tests/conftest.py +++ b/gwcs/tests/conftest.py @@ -88,8 +88,8 @@ def sellmeier_zemax(): @pytest.fixture(scope="function") def gwcs_3d_galactic_spectral(): - return examples.gwcs_3d_galactic_spectral() + return examples.gwcs_3d_galactic_spectral() @pytest.fixture(scope="function") def gwcs_1d_spectral(): diff --git a/gwcs/tests/test_api.py b/gwcs/tests/test_api.py index f80c8d8a..a676d3fb 100644 --- a/gwcs/tests/test_api.py +++ b/gwcs/tests/test_api.py @@ -63,7 +63,6 @@ def wcs_ndim_types_units(request): @fixture_all_wcses def test_lowlevel_types(wcsobj): - pytest.importorskip("typeguard") try: # Skip this on older versions of astropy where it dosen't exist. from astropy.wcs.wcsapi.tests.utils import validate_low_level_wcs_types @@ -236,12 +235,12 @@ def test_world_axis_object_classes_4d(gwcs_4d_identity_units): def _compare_frame_output(wc1, wc2): if isinstance(wc1, coord.SkyCoord): assert isinstance(wc1.frame, type(wc2.frame)) - assert u.allclose(wc1.spherical.lon, wc2.spherical.lon) - assert u.allclose(wc1.spherical.lat, wc2.spherical.lat) - assert u.allclose(wc1.spherical.distance, wc2.spherical.distance) + assert u.allclose(wc1.spherical.lon, wc2.spherical.lon, equal_nan=True) + assert u.allclose(wc1.spherical.lat, wc2.spherical.lat, equal_nan=True) + assert u.allclose(wc1.spherical.distance, wc2.spherical.distance, equal_nan=True) elif isinstance(wc1, u.Quantity): - assert u.allclose(wc1, wc2) + assert u.allclose(wc1, wc2, equal_nan=True) elif isinstance(wc1, time.Time): assert u.allclose((wc1 - wc2).to(u.s), 0*u.s) @@ -258,12 +257,6 @@ def _compare_frame_output(wc1, wc2): @fixture_all_wcses def test_high_level_wrapper(wcsobj, request): - if request.node.callspec.params['wcsobj'] in ('gwcs_4d_identity_units', 'gwcs_stokes_lookup'): - pytest.importorskip("astropy", minversion="4.0dev0") - - # Remove the bounding box because the type test is a little broken with the - # bounding box. - del wcsobj._pipeline[0].transform.bounding_box hlvl = HighLevelWCSWrapper(wcsobj) @@ -286,8 +279,6 @@ def test_high_level_wrapper(wcsobj, request): def test_stokes_wrapper(gwcs_stokes_lookup): - pytest.importorskip("astropy", minversion="4.0dev0") - hlvl = HighLevelWCSWrapper(gwcs_stokes_lookup) pixel_input = [0, 1, 2, 3] diff --git a/gwcs/tests/test_api_slicing.py b/gwcs/tests/test_api_slicing.py index ba0a9b54..626a0e2a 100644 --- a/gwcs/tests/test_api_slicing.py +++ b/gwcs/tests/test_api_slicing.py @@ -262,8 +262,8 @@ def test_celestial_slice(gwcs_3d_galactic_spectral): assert_allclose(wcs.pixel_to_world_values(39, 44), (10.24, 20, 25)) assert_allclose(wcs.array_index_to_world_values(44, 39), (10.24, 20, 25)) - assert_allclose(wcs.world_to_pixel_values(12.4, 20, 25), (39., 44.)) - assert_equal(wcs.world_to_array_index_values(12.4, 20, 25), (44, 39)) + assert_allclose(wcs.world_to_pixel_values(10.24, 20, 25), (39., 44.)) + assert_equal(wcs.world_to_array_index_values(10.24, 20, 25), (44, 39)) assert_equal(wcs.pixel_bounds, [(-2, 45), (5, 50)]) diff --git a/gwcs/tests/test_bounding_box.py b/gwcs/tests/test_bounding_box.py new file mode 100644 index 00000000..1cc46404 --- /dev/null +++ b/gwcs/tests/test_bounding_box.py @@ -0,0 +1,87 @@ +import numpy as np +from numpy.testing import assert_array_equal, assert_allclose + +import pytest + + +x = [-1, 2, 4, 13] +y = [np.nan, np.nan, 4, np.nan] +y1 = [np.nan, np.nan, 4, np.nan] + + +@pytest.mark.parametrize((("input", "output")), [((2, 4), (2, 4)), + ((100, 200), (np.nan, np.nan)), + ((x, x),(y, y)) + ]) +def test_2d_spatial(gwcs_2d_spatial_shift, input, output): + w = gwcs_2d_spatial_shift + w.bounding_box = ((-.5, 21), (4, 12)) + + assert_array_equal(w.invert(*w(*input)), output) + assert_array_equal(w.world_to_pixel_values(*w.pixel_to_world_values(*input)), output) + assert_array_equal(w.world_to_pixel(w.pixel_to_world(*input)), output) + + +@pytest.mark.parametrize((("input", "output")), [((2, 4), (2, 4)), + ((100, 200), (np.nan, np.nan)), + ((x, x), (y, y)) + ]) +def test_2d_spatial_coordinate(gwcs_2d_quantity_shift, input, output): + w = gwcs_2d_quantity_shift + w.bounding_box = ((-.5, 21), (4, 12)) + + assert_array_equal(w.invert(*w(*input)), output) + assert_array_equal(w.world_to_pixel_values(*w.pixel_to_world_values(*input)), output) + assert_array_equal(w.world_to_pixel(*w.pixel_to_world(*input)), output) + + +@pytest.mark.parametrize((("input", "output")), [((2, 4), (2, 4)), + ((100, 200), (np.nan, np.nan)), + ((x, x), (y, y)) + ]) +def test_2d_spatial_coordinate_reordered(gwcs_2d_spatial_reordered, input, output): + w = gwcs_2d_spatial_reordered + w.bounding_box = ((-.5, 21), (4, 12)) + + assert_array_equal(w.invert(*w(*input)), output) + assert_array_equal(w.world_to_pixel_values(*w.pixel_to_world_values(*input)), output) + assert_array_equal(w.world_to_pixel(w.pixel_to_world(*input)), output) + + +@pytest.mark.parametrize((("input", "output")), [(2, 2), + ((10, 200), (10, np.nan)), + (x, (np.nan, 2, 4, 13)) + ]) +def test_1d_freq(gwcs_1d_freq, input, output): + w = gwcs_1d_freq + w.bounding_box = (-.5, 21) + print(f"input {input}, {output}") + assert_array_equal(w.invert(w(input)), output) + assert_array_equal(w.world_to_pixel_values(w.pixel_to_world_values(input)), output) + assert_array_equal(w.world_to_pixel(w.pixel_to_world(input)), output) + + +@pytest.mark.parametrize((("input", "output")), [((2, 4, 5), (2, 4, 5)), + ((100, 200, 5), (np.nan, np.nan, np.nan)), + ((x, x, x), (y1, y1, y1)) + ]) +def test_3d_spatial_wave(gwcs_3d_spatial_wave, input, output): + w = gwcs_3d_spatial_wave + w.bounding_box = ((-.5, 21), (4, 12), (3, 21)) + + assert_array_equal(w.invert(*w(*input)), output) + assert_array_equal(w.world_to_pixel_values(*w.pixel_to_world_values(*input)), output) + assert_array_equal(w.world_to_pixel(*w.pixel_to_world(*input)), output) + + +@pytest.mark.parametrize((("input", "output")), [((1, 2, 3, 4), (1., 2., 3., 4.)), + ((100, 3, 3, 3), (np.nan, 3, 3, 3)), + ((x, x, x, x), [[np.nan, 2., 4., 13.], + [np.nan, 2., 4., 13.], + [np.nan, 2., 4., 13.], + [np.nan, 2., 4., np.nan]]) + ]) +def test_gwcs_spec_cel_time_4d(gwcs_spec_cel_time_4d, input, output): + w = gwcs_spec_cel_time_4d + + assert_allclose(w.invert(*w(*input, with_bounding_box=False)), output, atol=1e-8) diff --git a/gwcs/tests/test_coordinate_systems.py b/gwcs/tests/test_coordinate_systems.py index 967657f8..ac2c7303 100644 --- a/gwcs/tests/test_coordinate_systems.py +++ b/gwcs/tests/test_coordinate_systems.py @@ -190,7 +190,6 @@ def test_temporal_relative(): assert a[1] == Time("2018-01-01T00:00:00") + 20 * u.s -@pytest.mark.skipif(astropy_version<"4", reason="Requires astropy 4.0 or higher") def test_temporal_absolute(): t = cf.TemporalFrame(reference_frame=Time([], format='isot')) assert t.coordinates("2018-01-01T00:00:00") == Time("2018-01-01T00:00:00") @@ -240,7 +239,6 @@ def test_coordinate_to_quantity_spectral(inp): (Time("2011-01-01T00:00:10"),), (10 * u.s,) ]) -@pytest.mark.skipif(astropy_version<"4", reason="Requires astropy 4.0 or higher.") def test_coordinate_to_quantity_temporal(inp): temp = cf.TemporalFrame(reference_frame=Time("2011-01-01T00:00:00"), unit=u.s) @@ -325,7 +323,6 @@ def test_coordinate_to_quantity_frame_2d(): assert_quantity_allclose(output, exp) -@pytest.mark.skipif(astropy_version<"4", reason="Requires astropy 4.0 or higher.") def test_coordinate_to_quantity_error(): frame = cf.Frame2D(unit=(u.one, u.arcsec)) with pytest.raises(ValueError): diff --git a/gwcs/tests/test_utils.py b/gwcs/tests/test_utils.py index e69ec536..748b1472 100644 --- a/gwcs/tests/test_utils.py +++ b/gwcs/tests/test_utils.py @@ -6,6 +6,8 @@ from astropy import units as u from astropy import coordinates as coord from astropy.modeling import models +from astropy import table + from astropy.tests.helper import assert_quantity_allclose import pytest from numpy.testing import assert_allclose @@ -104,6 +106,10 @@ def test_isnumerical(): assert gwutils.isnumerical(np.array(0, dtype='>f8')) assert gwutils.isnumerical(np.array(0, dtype='>i4')) + # check a table column + t = table.Table(data=[[1,2,3], [4,5,6]], names=['x', 'y']) + assert not gwutils.isnumerical(t['x']) + def test_get_values(): args = 2 * u.cm diff --git a/gwcs/tests/test_wcs.py b/gwcs/tests/test_wcs.py index 5c9af093..23fc82c1 100644 --- a/gwcs/tests/test_wcs.py +++ b/gwcs/tests/test_wcs.py @@ -1163,8 +1163,8 @@ def test_in_image(): assert np.isscalar(w2.in_image(2, 6)) assert not np.isscalar(w2.in_image([2], [6])) - assert w2.in_image(4, 6) - assert not w2.in_image(5, 0) + assert (w2.in_image(4, 6)) + assert not (w2.in_image(5, 0)) assert np.array_equal( w2.in_image( [[9, 10, 11, 15], [8, 9, 67, 98], [2, 2, np.nan, 102]], @@ -1199,6 +1199,7 @@ def test_iter_inv(): *w(x, y), adaptive=True, detect_divergence=True, + tolerance=1e-4, maxiter=50, quiet=False ) assert np.allclose((x, y), (xp, yp)) @@ -1218,6 +1219,7 @@ def test_iter_inv(): xp, yp = w.numerical_inverse( *w(x, y), adaptive=True, + tolerance=1e-5, maxiter=50, detect_divergence=False, quiet=False ) @@ -1252,6 +1254,7 @@ def test_iter_inv(): xp, yp = w.numerical_inverse( *w(x, y, with_bounding_box=False), adaptive=False, + tolerance=1e-5, maxiter=50, detect_divergence=True, quiet=False, with_bounding_box=False @@ -1265,6 +1268,7 @@ def test_iter_inv(): xp, yp = w.numerical_inverse( *w(x, y, with_bounding_box=False), adaptive=False, + tolerance=1e-5, maxiter=50, detect_divergence=True, quiet=False, with_bounding_box=False diff --git a/gwcs/utils.py b/gwcs/utils.py index 104558cf..dcae0558 100644 --- a/gwcs/utils.py +++ b/gwcs/utils.py @@ -12,6 +12,7 @@ from astropy import coordinates as coords from astropy import units as u from astropy.time import Time, TimeDelta +from astropy import table from astropy.wcs import Celprm @@ -470,14 +471,12 @@ def isnumerical(val): Determine if a value is numerical (number or np.array of numbers). """ isnum = True - if isinstance(val, coords.SkyCoord): - isnum = False - elif isinstance(val, u.Quantity): - isnum = False - elif isinstance(val, (Time, TimeDelta)): + astropy_types=(coords.SkyCoord, u.Quantity, Time, TimeDelta, table.Column, table.Row) + if isinstance(val, astropy_types): isnum = False elif (isinstance(val, np.ndarray) and not np.issubdtype(val.dtype, np.floating) - and not np.issubdtype(val.dtype, np.integer)): + and not np.issubdtype(val.dtype, np.integer) + ): isnum = False return isnum diff --git a/gwcs/wcs.py b/gwcs/wcs.py index 2ae714b2..a52dc06e 100644 --- a/gwcs/wcs.py +++ b/gwcs/wcs.py @@ -13,7 +13,11 @@ from astropy.modeling.models import (Const1D, Identity, Mapping, Polynomial2D, RotateCelestial2Native, Shift, Sky2Pix_TAN) +from astropy.modeling.parameters import _tofloat from astropy.wcs.utils import celestial_frame_to_wcs, proj_plane_pixel_scales +from astropy.wcs.wcsapi.high_level_api import high_level_objects_to_values, values_to_high_level_objects + +from astropy import units as u from scipy import linalg, optimize from . import coordinate_frames as cf @@ -396,34 +400,12 @@ def in_image(self, *args, **kwargs): and `False` if input is outside the footprint. """ - kwargs['with_bounding_box'] = True - kwargs['fill_value'] = np.nan - coords = self.invert(*args, **kwargs) result = np.isfinite(coords) if self.input_frame.naxes > 1: result = np.all(result, axis=0) - if self.bounding_box is None or not np.any(result): - return result - - if self.input_frame.naxes == 1: - x1, x2 = self.bounding_box.bounding_box() - - if len(np.shape(args[0])) > 0: - result[result] = (coords[result] >= x1) & (coords[result] <= x2) - elif result: - result = (coords >= x1) and (coords <= x2) - - else: - if len(np.shape(args[0])) > 0: - for c, (x1, x2) in zip(coords, self.bounding_box): - result[result] = (c[result] >= x1) & (c[result] <= x2) - - elif result: - result = all([(c >= x1) and (c <= x2) for c, (x1, x2) in zip(coords, self.bounding_box)]) - return result def invert(self, *args, **kwargs): @@ -470,29 +452,35 @@ def invert(self, *args, **kwargs): """ with_units = kwargs.pop('with_units', False) + try: + btrans = self.backward_transform + except NotImplementedError: + btrans = None if not utils.isnumerical(args[0]): + # convert astropy objects to numbers and arrays args = self.output_frame.coordinate_to_quantity(*args) if self.output_frame.naxes == 1: args = [args] - try: - if not self.backward_transform.uses_quantity: - args = utils.get_values(self.output_frame.unit, *args) - except (NotImplementedError, KeyError): - args = utils.get_values(self.output_frame.unit, *args) - if 'with_bounding_box' not in kwargs: - kwargs['with_bounding_box'] = True + # if the transform does not use units, getthe numerical values + if btrans is not None and not btrans.uses_quantity: + args = utils.get_values(self.output_frame.unit, *args) - if 'fill_value' not in kwargs: - kwargs['fill_value'] = np.nan + with_bounding_box = kwargs.pop('with_bounding_box', True) + fill_value = kwargs.pop('fill_value', np.nan) + akwargs = {k: v for k, v in kwargs.items() if k not in _ITER_INV_KWARGS} + if with_bounding_box and self.bounding_box is not None: + args = self.outside_footprint(args) - try: - # remove iterative inverse-specific keyword arguments: - akwargs = {k: v for k, v in kwargs.items() if k not in _ITER_INV_KWARGS} - result = self.backward_transform(*args, **akwargs) - except (NotImplementedError, KeyError): + if btrans is not None: + result = btrans(*args, **akwargs) + else: result = self.numerical_inverse(*args, **kwargs, with_units=with_units) + # deal with values outside the bounding box + if with_bounding_box and self.bounding_box is not None: + result = self.out_of_bounds(result, fill_value=fill_value) + if with_units and self.input_frame: if self.input_frame.naxes == 1: return self.input_frame.coordinates(result) @@ -501,7 +489,57 @@ def invert(self, *args, **kwargs): else: return result - def numerical_inverse(self, *args, tolerance=1e-5, maxiter=50, adaptive=True, + def outside_footprint(self, world_arrays): + world_arrays = list(world_arrays) + + axes_types = set(self.output_frame.axes_type) + footprint = self.footprint() + not_numerical = False + if not utils.isnumerical(world_arrays[0]): + not_numerical = True + world_arrays = high_level_objects_to_values(*world_arrays, low_level_wcs=self) + for axtyp in axes_types: + ind = np.asarray((np.asarray(self.output_frame.axes_type) == axtyp)) + + for idim, coord in enumerate(world_arrays): + coord = _tofloat(coord) + if np.asarray(ind).sum() > 1: + axis_range = footprint[:, idim] + else: + axis_range = footprint + range = [axis_range.min(), axis_range.max()] + outside = (coord < range[0]) | (coord > range[1]) + if np.any(outside): + if np.isscalar(coord): + coord = np.nan + else: + coord[outside] = np.nan + world_arrays[idim] = coord + if not_numerical: + world_arrays = values_to_high_level_objects(*world_arrays, low_level_wcs=self) + return world_arrays + + + def out_of_bounds(self, pixel_arrays, fill_value=np.nan): + if np.isscalar(pixel_arrays) or self.input_frame.naxes == 1: + pixel_arrays = [pixel_arrays] + + pixel_arrays = list(pixel_arrays) + bbox = self.bounding_box + for idim, pix in enumerate(pixel_arrays): + outside = (pix < bbox[idim][0]) | (pix > bbox[idim][1]) + if np.any(outside): + if np.isscalar(pix): + pixel_arrays[idim] = np.nan + else: + pix = pixel_arrays[idim].astype(float, copy=True) + pix[outside] = np.nan + pixel_arrays[idim] = pix + if self.input_frame.naxes == 1: + pixel_arrays = pixel_arrays[0] + return pixel_arrays + + def numerical_inverse(self, *args, tolerance=1e-5, maxiter=30, adaptive=True, detect_divergence=True, quiet=True, with_bounding_box=True, fill_value=np.nan, with_units=False, **kwargs): """ @@ -683,7 +721,7 @@ def numerical_inverse(self, *args, tolerance=1e-5, maxiter=50, adaptive=True, >>> import numpy as np >>> filename = get_pkg_data_filename('data/nircamwcs.asdf', package='gwcs.tests') - >>> with asdf.open(filename, memmap=False, lazy_load=False, ignore_missing_extensions=True) as af: + >>> with asdf.open(filename, lazy_load=False, ignore_missing_extensions=True) as af: ... w = af.tree['wcs'] >>> ra, dec = w([1,2,3], [1,1,1]) @@ -1411,17 +1449,22 @@ def _order_clockwise(v): if bounding_box is None: if self.bounding_box is None: raise TypeError("Need a valid bounding_box to compute the footprint.") - bb = self.bounding_box + bb = self.bounding_box.bounding_box(order='F') else: bb = bounding_box all_spatial = all([t.lower() == "spatial" for t in self.output_frame.axes_type]) - - if all_spatial: + if self.output_frame.naxes == 1: + if isinstance(bb[0], u.Quantity): + bb = np.asarray([b.value for b in bb]) * bb[0].unit + vertices = (bb,) + elif all_spatial: vertices = _order_clockwise(bb) else: vertices = np.array(list(itertools.product(*bb))).T + # workaround an issue with bbox with quantity, interval needs to be a cquantity, not a list of quantities + # strip units if center: vertices = utils._toindex(vertices) @@ -1435,14 +1478,16 @@ def _order_clockwise(v): axtyp_ind = np.array([t.lower() for t in self.output_frame.axes_type]) == axis_type if not axtyp_ind.any(): raise ValueError('This WCS does not have axis of type "{}".'.format(axis_type)) - result = np.asarray([(r.min(), r.max()) for r in result[axtyp_ind]]) + if len(axtyp_ind) > 1: + result = np.asarray([(r.min(), r.max()) for r in result[axtyp_ind]]) if axis_type == "spatial": result = _order_clockwise(result) else: result.sort() result = np.squeeze(result) - + if self.output_frame.naxes == 1: + return np.array([result]).T return result.T def fix_inputs(self, fixed):