diff --git a/esmvalcore/preprocessor/_area.py b/esmvalcore/preprocessor/_area.py index a47ca29892..7d8d867155 100644 --- a/esmvalcore/preprocessor/_area.py +++ b/esmvalcore/preprocessor/_area.py @@ -21,6 +21,8 @@ from iris.exceptions import CoordinateNotFoundError from esmvalcore.preprocessor._shared import ( + apply_mask, + get_dims_along_axes, get_iris_aggregator, get_normalized_cube, preserve_float_dtype, @@ -188,8 +190,8 @@ def _extract_irregular_region( cube = cube[..., i_slice, j_slice] selection = selection[i_slice, j_slice] # Mask remaining coordinates outside region - mask = da.broadcast_to(~selection, cube.shape) - cube.data = da.ma.masked_where(mask, cube.core_data()) + horizontal_dims = get_dims_along_axes(cube, ["X", "Y"]) + cube.data = apply_mask(~selection, cube.core_data(), horizontal_dims) return cube @@ -857,31 +859,45 @@ def _mask_cube(cube: Cube, masks: dict[str, np.ndarray]) -> Cube: _cube.add_aux_coord( AuxCoord(id_, units="no_unit", long_name="shape_id") ) - mask = da.broadcast_to(mask, _cube.shape) - _cube.data = da.ma.masked_where(~mask, _cube.core_data()) + horizontal_dims = get_dims_along_axes(cube, axes=["X", "Y"]) + _cube.data = apply_mask(~mask, _cube.core_data(), horizontal_dims) cubelist.append(_cube) result = fix_coordinate_ordering(cubelist.merge_cube()) - if cube.cell_measures(): - for measure in cube.cell_measures(): - # Cell measures that are time-dependent, with 4 dimension and - # an original shape of (time, depth, lat, lon), need to be - # broadcasted to the cube with 5 dimensions and shape - # (time, shape_id, depth, lat, lon) - if measure.ndim > 3 and result.ndim > 4: - data = measure.core_data() - data = da.expand_dims(data, axis=(1,)) - data = da.broadcast_to(data, result.shape) - measure = iris.coords.CellMeasure( + for measure in cube.cell_measures(): + # Cell measures that are time-dependent, with 4 dimension and + # an original shape of (time, depth, lat, lon), need to be + # broadcast to the cube with 5 dimensions and shape + # (time, shape_id, depth, lat, lon) + if measure.ndim > 3 and result.ndim > 4: + data = measure.core_data() + if result.has_lazy_data(): + # Make the cell measure lazy if the result is lazy. + cube_chunks = cube.lazy_data().chunks + chunk_dims = cube.cell_measure_dims(measure) + data = da.asarray( data, - standard_name=measure.standard_name, - long_name=measure.long_name, - units=measure.units, - measure=measure.measure, - var_name=measure.var_name, - attributes=measure.attributes, + chunks=tuple(cube_chunks[i] for i in chunk_dims), ) - add_cell_measure(result, measure, measure.measure) - if cube.ancillary_variables(): - for ancillary_variable in cube.ancillary_variables(): - add_ancillary_variable(result, ancillary_variable) + chunks = result.lazy_data().chunks + else: + chunks = None + dim_map = get_dims_along_axes(result, ["T", "Z", "Y", "X"]) + data = iris.util.broadcast_to_shape( + data, + result.shape, + dim_map=dim_map, + chunks=chunks, + ) + measure = iris.coords.CellMeasure( + data, + standard_name=measure.standard_name, + long_name=measure.long_name, + units=measure.units, + measure=measure.measure, + var_name=measure.var_name, + attributes=measure.attributes, + ) + add_cell_measure(result, measure, measure.measure) + for ancillary_variable in cube.ancillary_variables(): + add_ancillary_variable(result, ancillary_variable) return result diff --git a/esmvalcore/preprocessor/_io.py b/esmvalcore/preprocessor/_io.py index 5f83b1946c..83f4d9bae5 100644 --- a/esmvalcore/preprocessor/_io.py +++ b/esmvalcore/preprocessor/_io.py @@ -22,6 +22,7 @@ from esmvalcore.cmor.check import CheckLevels from esmvalcore.esgf.facets import FACETS from esmvalcore.iris_helpers import merge_cube_attributes +from esmvalcore.preprocessor._shared import _rechunk_aux_factory_dependencies from .._task import write_ncl_settings @@ -392,6 +393,7 @@ def concatenate(cubes, check_level=CheckLevels.DEFAULT): cubes = _sort_cubes_by_time(cubes) _fix_calendars(cubes) cubes = _check_time_overlaps(cubes) + cubes = [_rechunk_aux_factory_dependencies(cube) for cube in cubes] result = _concatenate_cubes(cubes, check_level=check_level) if len(result) == 1: diff --git a/esmvalcore/preprocessor/_mask.py b/esmvalcore/preprocessor/_mask.py index 1896475704..1f1d0ddc00 100644 --- a/esmvalcore/preprocessor/_mask.py +++ b/esmvalcore/preprocessor/_mask.py @@ -9,8 +9,7 @@ import logging import os -from collections.abc import Iterable -from typing import Literal, Optional +from typing import Literal import cartopy.io.shapereader as shpreader import dask.array as da @@ -22,7 +21,7 @@ from iris.cube import Cube from iris.util import rolling_window -from esmvalcore.preprocessor._shared import get_array_module +from esmvalcore.preprocessor._shared import apply_mask from ._supplementary_vars import register_supplementaries @@ -61,24 +60,6 @@ def _get_fx_mask( return inmask -def _apply_mask( - mask: np.ndarray | da.Array, - array: np.ndarray | da.Array, - dim_map: Optional[Iterable[int]] = None, -) -> np.ndarray | da.Array: - """Apply a (broadcasted) mask on an array.""" - npx = get_array_module(mask, array) - if dim_map is not None: - if isinstance(array, da.Array): - chunks = array.chunks - else: - chunks = None - mask = iris.util.broadcast_to_shape( - mask, array.shape, dim_map, chunks=chunks - ) - return npx.ma.masked_where(mask, array) - - @register_supplementaries( variables=["sftlf", "sftof"], required="prefer_at_least_one", @@ -145,7 +126,7 @@ def mask_landsea(cube: Cube, mask_out: Literal["land", "sea"]) -> Cube: landsea_mask = _get_fx_mask( ancillary_var.core_data(), mask_out, ancillary_var.var_name ) - cube.data = _apply_mask( + cube.data = apply_mask( landsea_mask, cube.core_data(), cube.ancillary_variable_dims(ancillary_var), @@ -212,7 +193,7 @@ def mask_landseaice(cube: Cube, mask_out: Literal["landsea", "ice"]) -> Cube: landseaice_mask = _get_fx_mask( ancillary_var.core_data(), mask_out, ancillary_var.var_name ) - cube.data = _apply_mask( + cube.data = apply_mask( landseaice_mask, cube.core_data(), cube.ancillary_variable_dims(ancillary_var), @@ -350,10 +331,7 @@ def _mask_with_shp(cube, shapefilename, region_indices=None): else: mask |= shp_vect.contains(region, x_p_180, y_p_90) - if cube.has_lazy_data(): - mask = da.array(mask) - - cube.data = _apply_mask( + cube.data = apply_mask( mask, cube.core_data(), cube.coord_dims("latitude") + cube.coord_dims("longitude"), diff --git a/esmvalcore/preprocessor/_regrid.py b/esmvalcore/preprocessor/_regrid.py index 10511102e2..5bbed48dcf 100644 --- a/esmvalcore/preprocessor/_regrid.py +++ b/esmvalcore/preprocessor/_regrid.py @@ -32,6 +32,7 @@ from esmvalcore.exceptions import ESMValCoreDeprecationWarning from esmvalcore.iris_helpers import has_irregular_grid, has_unstructured_grid from esmvalcore.preprocessor._shared import ( + _rechunk_aux_factory_dependencies, get_array_module, get_dims_along_axes, preserve_float_dtype, @@ -1174,36 +1175,6 @@ def parse_vertical_scheme(scheme): return scheme, extrap_scheme -def _rechunk_aux_factory_dependencies( - cube: iris.cube.Cube, - coord_name: str, -) -> iris.cube.Cube: - """Rechunk coordinate aux factory dependencies. - - This ensures that the resulting coordinate has reasonably sized - chunks that are aligned with the cube data for optimal computational - performance. - """ - # Workaround for https://github.com/SciTools/iris/issues/5457 - try: - factory = cube.aux_factory(coord_name) - except iris.exceptions.CoordinateNotFoundError: - return cube - - cube = cube.copy() - cube_chunks = cube.lazy_data().chunks - for coord in factory.dependencies.values(): - coord_dims = cube.coord_dims(coord) - if coord_dims: - coord = coord.copy() - chunks = tuple(cube_chunks[i] for i in coord_dims) - coord.points = coord.lazy_points().rechunk(chunks) - if coord.has_bounds(): - coord.bounds = coord.lazy_bounds().rechunk(chunks + (None,)) - cube.replace_coord(coord) - return cube - - @preserve_float_dtype def extract_levels( cube: iris.cube.Cube, diff --git a/esmvalcore/preprocessor/_shared.py b/esmvalcore/preprocessor/_shared.py index 2355215800..adf45ca1c2 100644 --- a/esmvalcore/preprocessor/_shared.py +++ b/esmvalcore/preprocessor/_shared.py @@ -517,3 +517,81 @@ def get_dims_along_coords( """Get a tuple with the dimensions along one or more coordinates.""" dims = {d for coord in coords for d in _get_dims_along(cube, coord)} return tuple(sorted(dims)) + + +def apply_mask( + mask: np.ndarray | da.Array, + array: np.ndarray | da.Array, + dim_map: Iterable[int], +) -> np.ma.MaskedArray | da.Array: + """Apply a (broadcasted) mask on an array. + + Parameters + ---------- + mask: + The mask to apply to array. + array: + The array to mask out. + dim_map : + A mapping of the dimensions of *mask* to their corresponding + dimension in *array*. + See :func:`iris.util.broadcast_to_shape` for additional details. + + Returns + ------- + np.ma.MaskedArray or da.Array: + A copy of the input array with the mask applied. + + """ + if isinstance(array, da.Array): + array_chunks = array.chunks + # If the mask is not a Dask array yet, we make it into a Dask array + # before broadcasting to avoid inserting a large array into the Dask + # graph. + mask_chunks = tuple(array_chunks[i] for i in dim_map) + mask = da.asarray(mask, chunks=mask_chunks) + else: + array_chunks = None + + mask = iris.util.broadcast_to_shape( + mask, array.shape, dim_map=dim_map, chunks=array_chunks + ) + + array_module = get_array_module(mask, array) + return array_module.ma.masked_where(mask, array) + + +def _rechunk_aux_factory_dependencies( + cube: iris.cube.Cube, + coord_name: str | None = None, +) -> iris.cube.Cube: + """Rechunk coordinate aux factory dependencies. + + This ensures that the resulting coordinate has reasonably sized + chunks that are aligned with the cube data for optimal computational + performance. + """ + # Workaround for https://github.com/SciTools/iris/issues/5457 + if coord_name is None: + factories = cube.aux_factories + else: + try: + factories = [cube.aux_factory(coord_name)] + except iris.exceptions.CoordinateNotFoundError: + return cube + + cube = cube.copy() + cube_chunks = cube.lazy_data().chunks + for factory in factories: + for coord in factory.dependencies.values(): + coord_dims = cube.coord_dims(coord) + if coord_dims: + coord = coord.copy() + chunks = tuple(cube_chunks[i] for i in coord_dims) + coord.points = coord.lazy_points().rechunk(chunks) + if coord.has_bounds(): + coord.bounds = coord.lazy_bounds().rechunk( + chunks + (None,) + ) + cube.replace_coord(coord) + return cube diff --git a/esmvalcore/preprocessor/_time.py b/esmvalcore/preprocessor/_time.py index b3e4ab5b0f..ac00f13d01 100644 --- a/esmvalcore/preprocessor/_time.py +++ b/esmvalcore/preprocessor/_time.py @@ -1286,6 +1286,10 @@ def timeseries_filter( # Apply filter (agg, agg_kwargs) = get_iris_aggregator(filter_stats, **operator_kwargs) agg_kwargs["weights"] = wgts + if cube.has_lazy_data(): + # Ensure the cube data chunktype is np.MaskedArray so rolling_window + # does not ignore a potential mask. + cube.data = da.ma.masked_array(cube.core_data()) cube = cube.rolling_window("time", agg, len(wgts), **agg_kwargs) return cube diff --git a/tests/unit/preprocessor/_area/test_area.py b/tests/unit/preprocessor/_area/test_area.py index ec741b629a..4e5f19c28d 100644 --- a/tests/unit/preprocessor/_area/test_area.py +++ b/tests/unit/preprocessor/_area/test_area.py @@ -871,7 +871,10 @@ def test_extract_shape_natural_earth(make_testcube, ne_ocean_shapefile): np.testing.assert_array_equal(result.data.data, expected) -def test_extract_shape_fx(make_testcube, ne_ocean_shapefile): +@pytest.mark.parametrize("lazy", [True, False]) +def test_extract_shape_with_supplementaries( + make_testcube, ne_ocean_shapefile, lazy +): """Test for extracting a shape from NE file.""" expected = np.ones((5, 5)) cube = make_testcube @@ -888,6 +891,10 @@ def test_extract_shape_fx(make_testcube, ne_ocean_shapefile): var_name="sftgif", units="%", ) + if lazy: + cube.data = cube.lazy_data() + measure.data = measure.lazy_data() + ancillary_var.data = ancillary_var.lazy_data() cube.add_cell_measure(measure, (0, 1)) cube.add_ancillary_variable(ancillary_var, (0, 1)) result = extract_shape( @@ -895,17 +902,20 @@ def test_extract_shape_fx(make_testcube, ne_ocean_shapefile): ne_ocean_shapefile, crop=False, ) + assert result.has_lazy_data() is lazy np.testing.assert_array_equal(result.data.data, expected) assert result.cell_measures() - result_measure = result.cell_measure("cell_area").data - np.testing.assert_array_equal(measure.data, result_measure) + result_measure = result.cell_measure("cell_area") + assert result_measure.has_lazy_data() is lazy + np.testing.assert_array_equal(measure.data, result_measure.data) assert result.ancillary_variables() - result_ancillary_var = result.ancillary_variable( - "land_ice_area_fraction" - ).data - np.testing.assert_array_equal(ancillary_var.data, result_ancillary_var) + result_ancillary_var = result.ancillary_variable("land_ice_area_fraction") + assert result_ancillary_var.has_lazy_data() is lazy + np.testing.assert_array_equal( + ancillary_var.data, result_ancillary_var.data + ) def test_extract_shape_ne_check_nans(ne_ocean_shapefile): @@ -1471,7 +1481,8 @@ def test_meridional_statistics_invalid_norm_fail(make_testcube): meridional_statistics(make_testcube, "sum", normalize="x") -def test_time_dependent_volcello(): +@pytest.mark.parametrize("lazy", [True, False]) +def test_time_dependent_volcello(lazy): coord_sys = iris.coord_systems.GeogCS(iris.fileformats.pp.EARTH_RADIUS) data = np.ma.ones((2, 3, 2, 2)) @@ -1508,8 +1519,11 @@ def test_time_dependent_volcello(): volcello = iris.coords.CellMeasure( data, standard_name="ocean_volume", units="m3", measure="volume" ) + if lazy: + cube.data = cube.lazy_data() + volcello.data = volcello.lazy_data() cube.add_cell_measure(volcello, range(0, volcello.ndim)) - cube = extract_shape( + result = extract_shape( cube, "AR6", method="contains", @@ -1517,8 +1531,11 @@ def test_time_dependent_volcello(): decomposed=True, ids={"Acronym": ["EAO", "WAF"]}, ) + assert cube.has_lazy_data() is lazy + assert volcello.has_lazy_data() is lazy + assert result.has_lazy_data() is lazy - assert cube.shape == cube.cell_measure("ocean_volume").shape + assert result.shape == result.cell_measure("ocean_volume").shape if __name__ == "__main__": diff --git a/tests/unit/preprocessor/_mask/test_mask.py b/tests/unit/preprocessor/_mask/test_mask.py index 59b383c59a..dc6bfba162 100644 --- a/tests/unit/preprocessor/_mask/test_mask.py +++ b/tests/unit/preprocessor/_mask/test_mask.py @@ -9,7 +9,6 @@ import tests from esmvalcore.preprocessor._mask import ( - _apply_mask, _get_fx_mask, count_spells, mask_above_threshold, @@ -59,30 +58,6 @@ def setUp(self): ) self.fx_data = np.array([20.0, 60.0, 50.0]) - def test_apply_fx_mask_on_nonmasked_data(self): - """Test _apply_fx_mask func.""" - dummy_fx_mask = np.ma.array((True, False, True)) - app_mask = _apply_mask( - dummy_fx_mask, self.time_cube.data[0:3].astype("float64") - ) - fixed_mask = np.ma.array( - self.time_cube.data[0:3].astype("float64"), mask=dummy_fx_mask - ) - self.assert_array_equal(fixed_mask, app_mask) - - def test_apply_fx_mask_on_masked_data(self): - """Test _apply_fx_mask func.""" - dummy_fx_mask = np.ma.array((True, True, True)) - masked_data = np.ma.array( - self.time_cube.data[0:3].astype("float64"), - mask=np.ma.array((False, True, False)), - ) - app_mask = _apply_mask(dummy_fx_mask, masked_data) - fixed_mask = np.ma.array( - self.time_cube.data[0:3].astype("float64"), mask=dummy_fx_mask - ) - self.assert_array_equal(fixed_mask, app_mask) - def test_count_spells(self): """Test count_spells func.""" ref_spells = count_spells(self.time_cube.data, -1000.0, 0, 1) diff --git a/tests/unit/preprocessor/_regrid/test_extract_levels.py b/tests/unit/preprocessor/_regrid/test_extract_levels.py index e1b14b7a14..ec00d45438 100644 --- a/tests/unit/preprocessor/_regrid/test_extract_levels.py +++ b/tests/unit/preprocessor/_regrid/test_extract_levels.py @@ -2,10 +2,8 @@ from unittest import mock -import dask.array as da import iris import numpy as np -from iris.aux_factory import HybridPressureFactory from numpy import ma import tests @@ -13,7 +11,6 @@ _MDI, VERTICAL_SCHEMES, _preserve_fx_vars, - _rechunk_aux_factory_dependencies, extract_levels, parse_vertical_scheme, ) @@ -349,56 +346,3 @@ def test_interpolation__masked(self): self.assert_array_equal(args[3], levels) # Check the _create_cube kwargs ... self.assertEqual(kwargs, dict()) - - -def test_rechunk_aux_factory_dependencies(): - delta = iris.coords.AuxCoord( - points=np.array([0.0, 1.0, 2.0], dtype=np.float64), - bounds=np.array( - [[-0.5, 0.5], [0.5, 1.5], [1.5, 2.5]], dtype=np.float64 - ), - long_name="level_pressure", - units="Pa", - ) - sigma = iris.coords.AuxCoord( - np.array([1.0, 0.9, 0.8], dtype=np.float64), - long_name="sigma", - units="1", - ) - surface_air_pressure = iris.coords.AuxCoord( - np.arange(4).astype(np.float64).reshape(2, 2), - long_name="surface_air_pressure", - units="Pa", - ) - factory = HybridPressureFactory( - delta=delta, - sigma=sigma, - surface_air_pressure=surface_air_pressure, - ) - - cube = iris.cube.Cube( - da.asarray( - np.arange(3 * 2 * 2).astype(np.float32).reshape(3, 2, 2), - chunks=(1, 2, 2), - ), - ) - cube.add_aux_coord(delta, 0) - cube.add_aux_coord(sigma, 0) - cube.add_aux_coord(surface_air_pressure, [1, 2]) - cube.add_aux_factory(factory) - - result = _rechunk_aux_factory_dependencies(cube, "air_pressure") - - # Check that the 'air_pressure' coordinate of the resulting cube has been - # rechunked: - assert ( - (1, 1, 1), - (2,), - (2,), - ) == result.coord("air_pressure").core_points().chunks - # Check that the original cube has not been modified: - assert ( - (3,), - (2,), - (2,), - ) == cube.coord("air_pressure").core_points().chunks diff --git a/tests/unit/preprocessor/_time/test_time.py b/tests/unit/preprocessor/_time/test_time.py index 6a9d78747d..e6c7ca09e6 100644 --- a/tests/unit/preprocessor/_time/test_time.py +++ b/tests/unit/preprocessor/_time/test_time.py @@ -1528,18 +1528,24 @@ def test_regrid_time_hour_no_divisor_of_24(cube_1d_time, freq): regrid_time(cube_1d_time, freq) -class TestTimeseriesFilter(tests.Test): +class TestTimeseriesFilter: """Tests for timeseries filter.""" + @pytest.fixture(autouse=True) def setUp(self): """Prepare tests.""" self.cube = _create_sample_cube() - def test_timeseries_filter_simple(self): + @pytest.mark.parametrize("lazy", [True, False]) + def test_timeseries_filter_simple(self, lazy): """Test timeseries_filter func.""" + if lazy: + self.cube.data = self.cube.lazy_data() filtered_cube = timeseries_filter( self.cube, 7, 14, filter_type="lowpass", filter_stats="sum" ) + if lazy: + assert filtered_cube.has_lazy_data() expected_data = np.array( [ 2.44824568, @@ -1569,14 +1575,14 @@ def test_timeseries_filter_timecoord(self): """Test missing time axis.""" new_cube = self.cube.copy() new_cube.remove_coord(new_cube.coord("time")) - with self.assertRaises(iris.exceptions.CoordinateNotFoundError): + with pytest.raises(iris.exceptions.CoordinateNotFoundError): timeseries_filter( new_cube, 7, 14, filter_type="lowpass", filter_stats="sum" ) def test_timeseries_filter_implemented(self): """Test a not implemented filter.""" - with self.assertRaises(NotImplementedError): + with pytest.raises(NotImplementedError): timeseries_filter( self.cube, 7, 14, filter_type="bypass", filter_stats="sum" ) diff --git a/tests/unit/preprocessor/test_shared.py b/tests/unit/preprocessor/test_shared.py index 90dd04135b..b449e1998f 100644 --- a/tests/unit/preprocessor/test_shared.py +++ b/tests/unit/preprocessor/test_shared.py @@ -8,6 +8,7 @@ import numpy as np import pytest from cf_units import Unit +from iris.aux_factory import HybridPressureFactory from iris.coords import AuxCoord from iris.cube import Cube @@ -15,12 +16,15 @@ from esmvalcore.preprocessor._shared import ( _compute_area_weights, _group_products, + _rechunk_aux_factory_dependencies, aggregator_accept_weights, + apply_mask, get_array_module, get_iris_aggregator, preserve_float_dtype, try_adding_calculated_cell_area, ) +from tests import assert_array_equal @pytest.mark.parametrize("operator", ["gmean", "GmEaN", "GMEAN"]) @@ -330,3 +334,101 @@ def test_try_adding_calculated_cell_area(): try_adding_calculated_cell_area(cube) assert cube.cell_measures("cell_area") + + +@pytest.mark.parametrize( + ["mask", "array", "dim_map", "expected"], + [ + ( + np.arange(2), + da.arange(2), + (0,), + da.ma.masked_array(np.arange(2), np.arange(2)), + ), + ( + da.arange(2), + np.arange(2), + (0,), + da.ma.masked_array(np.arange(2), np.arange(2)), + ), + ( + np.ma.masked_array(np.arange(2), mask=[1, 0]), + da.arange(2), + (0,), + da.ma.masked_array(np.ones(2), np.arange(2)), + ), + ( + np.ones((2, 5)), + da.zeros((2, 3, 5), chunks=(1, 2, 3)), + (0, 2), + da.ma.masked_array( + da.zeros((2, 3, 5), da.ones(2, 3, 5), chunks=(1, 2, 3)) + ), + ), + ( + np.arange(2), + np.ones((3, 2)), + (1,), + np.ma.masked_array(np.ones((3, 2)), mask=[[0, 1], [0, 1], [0, 1]]), + ), + ], +) +def test_apply_mask(mask, array, dim_map, expected): + result = apply_mask(mask, array, dim_map) + assert isinstance(result, type(expected)) + if isinstance(expected, da.Array): + assert result.chunks == expected.chunks + assert_array_equal(result, expected) + + +def test_rechunk_aux_factory_dependencies(): + delta = iris.coords.AuxCoord( + points=np.array([0.0, 1.0, 2.0], dtype=np.float64), + bounds=np.array( + [[-0.5, 0.5], [0.5, 1.5], [1.5, 2.5]], dtype=np.float64 + ), + long_name="level_pressure", + units="Pa", + ) + sigma = iris.coords.AuxCoord( + np.array([1.0, 0.9, 0.8], dtype=np.float64), + long_name="sigma", + units="1", + ) + surface_air_pressure = iris.coords.AuxCoord( + np.arange(4).astype(np.float64).reshape(2, 2), + long_name="surface_air_pressure", + units="Pa", + ) + factory = HybridPressureFactory( + delta=delta, + sigma=sigma, + surface_air_pressure=surface_air_pressure, + ) + + cube = iris.cube.Cube( + da.asarray( + np.arange(3 * 2 * 2).astype(np.float32).reshape(3, 2, 2), + chunks=(1, 2, 2), + ), + ) + cube.add_aux_coord(delta, 0) + cube.add_aux_coord(sigma, 0) + cube.add_aux_coord(surface_air_pressure, [1, 2]) + cube.add_aux_factory(factory) + + result = _rechunk_aux_factory_dependencies(cube, "air_pressure") + + # Check that the 'air_pressure' coordinate of the resulting cube has been + # rechunked: + assert ( + (1, 1, 1), + (2,), + (2,), + ) == result.coord("air_pressure").core_points().chunks + # Check that the original cube has not been modified: + assert ( + (3,), + (2,), + (2,), + ) == cube.coord("air_pressure").core_points().chunks