From 9a716a71e1e04ac6f5cce77a2b54b39ee22c2d4b Mon Sep 17 00:00:00 2001 From: Gregory Petrochenkov Date: Thu, 26 Oct 2023 09:51:42 -0700 Subject: [PATCH 1/2] Progress on integer sentinel value --- pyproject.toml | 2 +- src/gval/accessors/gval_xarray.py | 4 -- .../comparison/compute_continuous_metrics.py | 39 +++++++++++++++++++ src/gval/statistics/continuous_stat_utils.py | 4 ++ 4 files changed, 44 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a6838128..171813b9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,7 @@ authors = [ requires-python = ">=3.8" keywords = ["geospatial", "evaluations"] license = {text = "MIT"} -version = "0.2.3.1" +version = "0.2.4" dynamic = ["readme", "dependencies"] diff --git a/src/gval/accessors/gval_xarray.py b/src/gval/accessors/gval_xarray.py index be74e608..8779a151 100644 --- a/src/gval/accessors/gval_xarray.py +++ b/src/gval/accessors/gval_xarray.py @@ -364,10 +364,6 @@ def continuous_compare( else: c_is_int, b_is_int = map(integer_check, [candidate, benchmark]) - if c_is_int or b_is_int: # pragma: no cover - raise TypeError( - "Cannot compute continuous statistics on data with type integer" - ) # --------------------------------------------------------------------------------------- results = candidate.gval.compute_agreement_map( diff --git a/src/gval/comparison/compute_continuous_metrics.py b/src/gval/comparison/compute_continuous_metrics.py index 73f73d8b..c4f93142 100644 --- a/src/gval/comparison/compute_continuous_metrics.py +++ b/src/gval/comparison/compute_continuous_metrics.py @@ -7,14 +7,17 @@ from typing import Iterable, Union +import numpy as np import pandera as pa import pandas as pd from pandera.typing import DataFrame import xarray as xr import geopandas as gpd +import dask as da from gval import ContStats from gval.utils.schemas import Metrics_df, Subsample_identifiers, Sample_identifiers +from gval.utils.loading_datasets import _check_dask_array @pa.check_types @@ -69,6 +72,42 @@ def _compute_continuous_metrics( for idx, (agreement, benchmark, candidate) in enumerate( zip(agreement_map, benchmark_map, candidate_map) ): + + is_dsk = _check_dask_array(candidate) + is_int = np.issubdtype(candidate.dtype, np.integer) if isinstance(candidate, xr.DataArray) \ + else np.issubdtype(candidate['band_1'].dtype, np.integer) + nodata = candidate.rio.nodata if isinstance(candidate, xr.DataArray) else candidate['band_1'].rio.nodata + + picked_coords = None + + # Remove no data value if int type form calculation, otherwise leave all values in + # Necessary because there is not int sentinel value + if is_int and nodata is not None: + + cmask, bmask = (xr.where(candidate == nodata, 0, 1), xr.where(benchmark == nodata, 0, 1)) + tmask = cmask & bmask + + if is_dsk: + grid_coords = da.array.asarray( + da.array.meshgrid(candidate.coords['x'], candidate.coords['y']) + ).T.reshape(-1, 2) + picked_coords = grid_coords[da.array.ravel(tmask.data).astype(bool), :] + else: + grid_coords = np.array( + np.meshgrid(candidate.coords['x'], candidate.coords['y']) + ).T.reshape(-1, 2) + picked_coords = grid_coords[np.ravel(tmask.data).astype(bool), :] + + candidate = candidate.sel({'x': picked_coords[:, 0], 'y': picked_coords[:, 1]}) \ + if picked_coords is not None else candidate + + benchmark = benchmark.sel({'x': picked_coords[:, 0], 'y': picked_coords[:, 1]}) \ + if picked_coords is not None else benchmark + + agreement = agreement.sel({'x': picked_coords[:, 0], 'y': picked_coords[:, 1]}) \ + if picked_coords is not None else agreement + + # compute error based metrics such as MAE, MSE, RMSE, etc. from agreement map and produce metrics_df statistics, names = ContStats.process_statistics( metrics, diff --git a/src/gval/statistics/continuous_stat_utils.py b/src/gval/statistics/continuous_stat_utils.py index 00cfde58..407fe415 100644 --- a/src/gval/statistics/continuous_stat_utils.py +++ b/src/gval/statistics/continuous_stat_utils.py @@ -7,6 +7,8 @@ import xarray as xr +from gval.utils.loading_datasets import _check_dask_array + def convert_output(func: Callable) -> Callable: # pragma: no cover """ @@ -31,6 +33,8 @@ def wrapper(*args, **kwargs): # Call the decorated function result = func(*args, **kwargs) + if _check_dask_array(result): + result = result.compute() if isinstance(result, xr.DataArray): # Convert to a single numeric value return result.item() From a874dfa7447d12965d179394dec7d050ba97859c Mon Sep 17 00:00:00 2001 From: Gregory Petrochenkov Date: Thu, 2 Nov 2023 21:52:14 -0400 Subject: [PATCH 2/2] Address continuous compare int dtypes --- src/gval/accessors/gval_xarray.py | 17 -- .../comparison/compute_continuous_metrics.py | 172 +++++++++++++----- src/gval/statistics/continuous_stat_utils.py | 12 +- src/gval/utils/loading_datasets.py | 30 +++ tests/cases_catalogs.py | 49 +++-- tests/cases_compute_continuous_metrics.py | 2 +- 6 files changed, 203 insertions(+), 79 deletions(-) diff --git a/src/gval/accessors/gval_xarray.py b/src/gval/accessors/gval_xarray.py index 8779a151..61224674 100644 --- a/src/gval/accessors/gval_xarray.py +++ b/src/gval/accessors/gval_xarray.py @@ -1,6 +1,5 @@ from typing import Iterable, Optional, Tuple, Union, Callable, Dict, List from numbers import Number -from functools import partial import numpy as np import numba as nb @@ -350,22 +349,6 @@ def continuous_compare( benchmark_map, target_map, resampling, rasterize_attributes ) - # Check whether either dataset is of integer type - integer_check = partial(np.issubdtype, arg2=np.integer) - - # Temporary code to check if the datatype is an integer -------------------------------- - if isinstance(candidate, xr.Dataset): - c_is_int, b_is_int = False, False - for c_var, b_var in zip(candidate.data_vars, benchmark.data_vars): - c_is_int, b_is_int = np.bitwise_or( - list(map(integer_check, [candidate[c_var], benchmark[b_var]])), - [c_is_int, b_is_int], - ) - else: - c_is_int, b_is_int = map(integer_check, [candidate, benchmark]) - - # --------------------------------------------------------------------------------------- - results = candidate.gval.compute_agreement_map( benchmark_map=benchmark, comparison_function=difference, diff --git a/src/gval/comparison/compute_continuous_metrics.py b/src/gval/comparison/compute_continuous_metrics.py index c4f93142..01a6b669 100644 --- a/src/gval/comparison/compute_continuous_metrics.py +++ b/src/gval/comparison/compute_continuous_metrics.py @@ -5,7 +5,7 @@ # __all__ = ['*'] __author__ = "Fernando Aristizabal" -from typing import Iterable, Union +from typing import Iterable, Union, List import numpy as np import pandera as pa @@ -17,7 +17,89 @@ from gval import ContStats from gval.utils.schemas import Metrics_df, Subsample_identifiers, Sample_identifiers -from gval.utils.loading_datasets import _check_dask_array +from gval.utils.loading_datasets import _check_dask_array, _convert_to_dataset + + +def _get_selected_datasets( + agreement: xr.Dataset, + candidate: xr.Dataset, + benchmark: xr.Dataset, + nodata: list, + var_name: str, +) -> List[xr.Dataset]: + """ + Selects specific coordinates for integer valued datasets to not process nodata values + + Parameters + ---------- + agreement : xr.Dataset + Agreement Map + candidate : xr.Dataset + Candidate Map + benchmark : xr.Dataset + Benchmark Map + nodata : list + Nodata values in the list + var_name : str + Name of variable + + Returns + ------- + List[xr.Dataset, xr.Dataset, xr.Dataset] + Datasets with selected coordinates + """ + + is_dsk = _check_dask_array(agreement) + cmask, bmask = ( + xr.where(candidate[var_name] == nodata, 0, 1), + xr.where(benchmark[var_name] == nodata, 0, 1), + ) + tmask = cmask & bmask + + # Create a coord meshgrid and select appropriate coords to select from xarray + if is_dsk: + with da.config.set({"array.slicing.split_large_chunks": True}): + grid_coords = da.array.asarray( + da.array.meshgrid(candidate.coords["x"], candidate.coords["y"]) + ).T.reshape(-1, 2) + picked_coords = grid_coords[da.array.ravel(tmask.data).astype(bool), :] + else: + grid_coords = np.array( + np.meshgrid(candidate.coords["x"], candidate.coords["y"]) + ).T.reshape(-1, 2) + picked_coords = grid_coords[np.ravel(tmask.data).astype(bool), :] + + # Select coordinates from xarray + with da.config.set({"array.slicing.split_large_chunks": True}): + agreement_sel = ( + agreement[var_name].sel( + {"x": picked_coords[:, 0], "y": picked_coords[:, 1]} + ) + if picked_coords is not None + else agreement[var_name] + ) + + candidate_sel = ( + candidate[var_name].sel( + {"x": picked_coords[:, 0], "y": picked_coords[:, 1]} + ) + if picked_coords is not None + else candidate[var_name] + ) + + benchmark_sel = ( + benchmark[var_name].sel( + {"x": picked_coords[:, 0], "y": picked_coords[:, 1]} + ) + if picked_coords is not None + else benchmark[var_name] + ) + + return ( + agreement_sel, + candidate_sel, + benchmark_sel, + ) @pa.check_types @@ -72,49 +154,53 @@ def _compute_continuous_metrics( for idx, (agreement, benchmark, candidate) in enumerate( zip(agreement_map, benchmark_map, candidate_map) ): - - is_dsk = _check_dask_array(candidate) - is_int = np.issubdtype(candidate.dtype, np.integer) if isinstance(candidate, xr.DataArray) \ - else np.issubdtype(candidate['band_1'].dtype, np.integer) - nodata = candidate.rio.nodata if isinstance(candidate, xr.DataArray) else candidate['band_1'].rio.nodata - - picked_coords = None + # Change data to Dataset if DataArray + agreement = _convert_to_dataset(agreement) + candidate = _convert_to_dataset(candidate) + benchmark = _convert_to_dataset(benchmark) + + # Check if integer type and nodata values + is_int = ( + np.issubdtype(candidate.dtype, np.integer) + if isinstance(candidate, xr.DataArray) + else np.issubdtype(candidate["band_1"].dtype, np.integer) + ) + nodata = [agreement[x].rio.nodata for x in agreement.data_vars] # Remove no data value if int type form calculation, otherwise leave all values in - # Necessary because there is not int sentinel value - if is_int and nodata is not None: - - cmask, bmask = (xr.where(candidate == nodata, 0, 1), xr.where(benchmark == nodata, 0, 1)) - tmask = cmask & bmask - - if is_dsk: - grid_coords = da.array.asarray( - da.array.meshgrid(candidate.coords['x'], candidate.coords['y']) - ).T.reshape(-1, 2) - picked_coords = grid_coords[da.array.ravel(tmask.data).astype(bool), :] - else: - grid_coords = np.array( - np.meshgrid(candidate.coords['x'], candidate.coords['y']) - ).T.reshape(-1, 2) - picked_coords = grid_coords[np.ravel(tmask.data).astype(bool), :] - - candidate = candidate.sel({'x': picked_coords[:, 0], 'y': picked_coords[:, 1]}) \ - if picked_coords is not None else candidate - - benchmark = benchmark.sel({'x': picked_coords[:, 0], 'y': picked_coords[:, 1]}) \ - if picked_coords is not None else benchmark - - agreement = agreement.sel({'x': picked_coords[:, 0], 'y': picked_coords[:, 1]}) \ - if picked_coords is not None else agreement - - - # compute error based metrics such as MAE, MSE, RMSE, etc. from agreement map and produce metrics_df - statistics, names = ContStats.process_statistics( - metrics, - error=agreement, - candidate_map=candidate, - benchmark_map=benchmark, - ) + # Necessary because there is not an int sentinel value + if is_int and np.all([x is not None for x in nodata]): + final_stats = [] + # Iterate through each band and gather statistics + for nodata_idx, var_name in enumerate(agreement.data_vars): + # Create mask for all nodata values + agreement_sel, candidate_sel, benchmark_sel = _get_selected_datasets( + agreement, candidate, benchmark, nodata[nodata_idx], var_name + ) + + statistics, names = ContStats.process_statistics( + metrics, + error=agreement_sel, + candidate_map=candidate_sel, + benchmark_map=benchmark_sel, + ) + + del agreement_sel, candidate_sel, benchmark_sel + + final_stats.append(statistics) + + statistics = [ + {f"band_{idx + 1}": val for idx, val in enumerate(lst)} + for lst in np.array(final_stats).T + ] + + else: + statistics, names = ContStats.process_statistics( + metrics, + error=agreement, + candidate_map=candidate, + benchmark_map=benchmark, + ) # create metrics_df metric_df = dict() diff --git a/src/gval/statistics/continuous_stat_utils.py b/src/gval/statistics/continuous_stat_utils.py index 407fe415..f9b7f259 100644 --- a/src/gval/statistics/continuous_stat_utils.py +++ b/src/gval/statistics/continuous_stat_utils.py @@ -33,14 +33,18 @@ def wrapper(*args, **kwargs): # Call the decorated function result = func(*args, **kwargs) - if _check_dask_array(result): - result = result.compute() + is_dsk = _check_dask_array(result) if isinstance(result, xr.DataArray): # Convert to a single numeric value - return result.item() + return result.compute().item() if is_dsk else result.item() elif isinstance(result, xr.Dataset): # Convert to a dictionary with band names as keys - return {band: result[band].item() for band in result.data_vars} + if is_dsk: + return { + band: result[band].compute().item() for band in result.data_vars + } + else: + return {band: result[band].item() for band in result.data_vars} return result diff --git a/src/gval/utils/loading_datasets.py b/src/gval/utils/loading_datasets.py index bb61dae9..e0481d7e 100644 --- a/src/gval/utils/loading_datasets.py +++ b/src/gval/utils/loading_datasets.py @@ -290,3 +290,33 @@ def _parse_string_attributes( } return obj + + +def _convert_to_dataset(xr_object=Union[xr.DataArray, xr.Dataset]) -> xr.Dataset: + """ + Converts xarray object to dataset if it is not one already. + + Parameters + ---------- + xr_object : Union[xr.DataArray, xr.Dataset] + Xarray object to convert or simply return + + Returns + ------- + xr.Dataset + Dataset object + + """ + + if isinstance(xr_object, xr.DataArray): + nodata = xr_object.rio.nodata + xr_object = xr_object.to_dataset(dim="band") + xr_object = xr_object.rename_vars({x: f"band_{x}" for x in xr_object.data_vars}) + + # Account for nodata + for var_name in xr_object.data_vars: + xr_object[var_name] = xr_object[var_name].rio.write_nodata(nodata) + + return xr_object + else: + return xr_object diff --git a/tests/cases_catalogs.py b/tests/cases_catalogs.py index 2a19ab9b..e216d980 100644 --- a/tests/cases_catalogs.py +++ b/tests/cases_catalogs.py @@ -79,26 +79,41 @@ pd.DataFrame( { "map_id_candidate": [ - f"{TEST_DATA_DIR}/candidate_continuous_0.tif", - f"{TEST_DATA_DIR}/candidate_continuous_1.tif", + "s3://gval-test/candidate_continuous_0.tif", + "s3://gval-test/candidate_continuous_1.tif", + "s3://gval-test/candidate_continuous_1.tif", ], - "compare_id": ["compare1", "compare2"], + "compare_id": ["compare1", "compare2", "compare2"], "map_id_benchmark": [ - f"{TEST_DATA_DIR}/benchmark_continuous_0.tif", - f"{TEST_DATA_DIR}/benchmark_continuous_1.tif", + "s3://gval-test/benchmark_continuous_0.tif", + "s3://gval-test/benchmark_continuous_1.tif", + "s3://gval-test/benchmark_continuous_1.tif", ], - "value1_candidate": [1, 2], - "value2_candidate": [5, 6], + "value1_candidate": [1, 2, 2], + "value2_candidate": [5, 6, 6], "agreement_maps": [ "agreement_continuous_0.tif", "agreement_continuous_1.tif", + "agreement_continuous_1.tif", + ], + "value1_benchmark": [1, 2, 2], + "value2_benchmark": [5, 6, 6], + "band": [1.0, 1.0, 2.0], + "coefficient_of_determination": [ + -0.06615996360778809, + -2.829420804977417, + 0.10903036594390869, + ], + "mean_absolute_error": [ + 0.3173885941505432, + 0.48503121733665466, + 0.48503121733665466, + ], + "mean_absolute_percentage_error": [ + 0.15956786274909973, + 0.20223499834537506, + 0.15323485434055328, ], - "value1_benchmark": [1, 2], - "value2_benchmark": [5, 6], - "band": [1.0, 1.0], - "coefficient_of_determination": [-0.066160, -0.329965], - "mean_absolute_error": [0.317389, 0.485031], - "mean_absolute_percentage_error": [0.159568, 0.177735], } ) ] * 2 + [ @@ -175,7 +190,12 @@ }, ] -open_kwargs = [{"mask_and_scale": True, "masked": True}] * 4 +open_kwargs = [ + {"mask_and_scale": True, "masked": True}, + {"mask_and_scale": True, "masked": True, "chunks": "auto"}, + {"mask_and_scale": True, "masked": True}, + {"mask_and_scale": True, "masked": True}, +] # agreement_map_field = [None, "agreement_maps"] agreement_map_field = ["agreement_maps"] * 4 @@ -185,6 +205,7 @@ ( f"{TEST_DATA_DIR}/agreement_continuous_0.tif", f"{TEST_DATA_DIR}/agreement_continuous_1.tif", + f"{TEST_DATA_DIR}/agreement_continuous_1.tif", ) ] * 2 + [ ( diff --git a/tests/cases_compute_continuous_metrics.py b/tests/cases_compute_continuous_metrics.py index 263d0957..5c7943e9 100644 --- a/tests/cases_compute_continuous_metrics.py +++ b/tests/cases_compute_continuous_metrics.py @@ -87,7 +87,7 @@ def case_compute_continuous_metrics_success( stat_names = ["non_existent_function"] -exceptions = [KeyError] +exceptions = [TypeError] @parametrize(