
Commit

Merge pull request #165 from NOAA-OWP/continuous_nan
Continuous int dtype nodata and DataArray Multiband Processing
fernando-aristizabal authored Nov 6, 2023
2 parents 0acd820 + a874dfa commit c5964b6
Showing 7 changed files with 209 additions and 46 deletions.
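
Taken together, the changes below let continuous comparisons run on integer-typed rasters, masking nodata cells instead of rejecting the dtype outright, and on multiband inputs, which are normalized to `band_1`, `band_2`, ... Datasets internally. A minimal usage sketch, assuming hypothetical file paths and that the `continuous_compare` accessor (whose integer guard is removed below) returns an agreement map and a metrics DataFrame:

```python
import gval  # noqa: F401  (registers the .gval accessor)
import rioxarray as rxr

# Hypothetical integer-typed, multiband rasters; opening without
# mask_and_scale keeps the native int dtype and its nodata value.
candidate = rxr.open_rasterio("candidate_int.tif")
benchmark = rxr.open_rasterio("benchmark_int.tif")

# Before this commit, integer dtypes raised a TypeError here; nodata
# cells are now excluded from the statistics instead.
agreement, metrics_df = candidate.gval.continuous_compare(benchmark)
```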
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -12,7 +12,7 @@ authors = [
requires-python = ">=3.8"
keywords = ["geospatial", "evaluations"]
license = {text = "MIT"}
version = "0.2.3.1"
version = "0.2.4"
dynamic = ["readme", "dependencies"]


21 changes: 0 additions & 21 deletions src/gval/accessors/gval_xarray.py
@@ -1,6 +1,5 @@
from typing import Iterable, Optional, Tuple, Union, Callable, Dict, List
from numbers import Number
from functools import partial

import numpy as np
import numba as nb
@@ -350,26 +349,6 @@ def continuous_compare(
benchmark_map, target_map, resampling, rasterize_attributes
)

# Check whether either dataset is of integer type
integer_check = partial(np.issubdtype, arg2=np.integer)

# Temporary code to check if the datatype is an integer --------------------------------
if isinstance(candidate, xr.Dataset):
c_is_int, b_is_int = False, False
for c_var, b_var in zip(candidate.data_vars, benchmark.data_vars):
c_is_int, b_is_int = np.bitwise_or(
list(map(integer_check, [candidate[c_var], benchmark[b_var]])),
[c_is_int, b_is_int],
)
else:
c_is_int, b_is_int = map(integer_check, [candidate, benchmark])

if c_is_int or b_is_int: # pragma: no cover
raise TypeError(
"Cannot compute continuous statistics on data with type integer"
)
# ---------------------------------------------------------------------------------------

results = candidate.gval.compute_agreement_map(
benchmark_map=benchmark,
comparison_function=difference,
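
The deleted guard above simply tested the input dtypes with `np.issubdtype` and raised on integers; nodata masking in `_compute_continuous_metrics` (next file) now makes that rejection unnecessary. A standalone sketch of the check it relied on:

```python
import numpy as np

# Core test of the removed guard: True for any integer dtype.
print(np.issubdtype(np.dtype("int32"), np.integer))    # True
print(np.issubdtype(np.dtype("float32"), np.integer))  # False
```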
139 changes: 132 additions & 7 deletions src/gval/comparison/compute_continuous_metrics.py
@@ -5,16 +5,101 @@
# __all__ = ['*']
__author__ = "Fernando Aristizabal"

from typing import Iterable, Union
from typing import Iterable, Union, List, Tuple

import numpy as np
import pandera as pa
import pandas as pd
from pandera.typing import DataFrame
import xarray as xr
import geopandas as gpd
import dask as da
import dask.array  # ensures da.array resolves without relying on other imports

from gval import ContStats
from gval.utils.schemas import Metrics_df, Subsample_identifiers, Sample_identifiers
from gval.utils.loading_datasets import _check_dask_array, _convert_to_dataset


def _get_selected_datasets(
agreement: xr.Dataset,
candidate: xr.Dataset,
benchmark: xr.Dataset,
nodata: list,
var_name: str,
) -> Tuple[xr.DataArray, xr.DataArray, xr.DataArray]:
"""
Selects coordinates from integer-valued datasets so that nodata values are excluded from processing
Parameters
----------
agreement : xr.Dataset
Agreement Map
candidate : xr.Dataset
Candidate Map
benchmark : xr.Dataset
Benchmark Map
nodata : Union[int, float]
Nodata value to exclude from processing
var_name : str
Name of variable
Returns
-------
Tuple[xr.DataArray, xr.DataArray, xr.DataArray]
Agreement, candidate, and benchmark maps restricted to non-nodata coordinates
"""

is_dsk = _check_dask_array(agreement)
cmask, bmask = (
xr.where(candidate[var_name] == nodata, 0, 1),
xr.where(benchmark[var_name] == nodata, 0, 1),
)
tmask = cmask & bmask

# Create a coord meshgrid and select appropriate coords to select from xarray
if is_dsk:
with da.config.set({"array.slicing.split_large_chunks": True}):
grid_coords = da.array.asarray(
da.array.meshgrid(candidate.coords["x"], candidate.coords["y"])
).T.reshape(-1, 2)
picked_coords = grid_coords[da.array.ravel(tmask.data).astype(bool), :]
else:
grid_coords = np.array(
np.meshgrid(candidate.coords["x"], candidate.coords["y"])
).T.reshape(-1, 2)
picked_coords = grid_coords[np.ravel(tmask.data).astype(bool), :]

# Select coordinates from xarray
with da.config.set({"array.slicing.split_large_chunks": True}):
agreement_sel = (
agreement[var_name].sel(
{"x": picked_coords[:, 0], "y": picked_coords[:, 1]}
)
if picked_coords is not None
else agreement[var_name]
)

candidate_sel = (
candidate[var_name].sel(
{"x": picked_coords[:, 0], "y": picked_coords[:, 1]}
)
if picked_coords is not None
else candidate[var_name]
)

benchmark_sel = (
benchmark[var_name].sel(
{"x": picked_coords[:, 0], "y": picked_coords[:, 1]}
)
if picked_coords is not None
else benchmark[var_name]
)

return (
agreement_sel,
candidate_sel,
benchmark_sel,
)
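
The helper above flattens the raster's (x, y) grid and keeps only the coordinates where neither map holds its nodata value. A simplified NumPy-only sketch of the masking idea, with made-up values (the real code additionally routes through dask when the arrays are chunked):

```python
import numpy as np

x = np.array([10.0, 20.0, 30.0])
y = np.array([1.0, 2.0])
nodata = -9999

candidate = np.array([[5, nodata, 7], [8, 9, 10]])  # shape (y, x)
benchmark = np.array([[1, 2, nodata], [4, 5, 6]])

# True where both maps hold valid data, False where either is nodata
tmask = (candidate != nodata) & (benchmark != nodata)

# Flat (x, y) pairs in the same row-major order as tmask.ravel()
xx, yy = np.meshgrid(x, y)  # each shaped (len(y), len(x))
grid_coords = np.column_stack([xx.ravel(), yy.ravel()])
picked_coords = grid_coords[tmask.ravel()]  # coordinates of the 4 valid cells
```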


@pa.check_types
@@ -69,13 +154,53 @@ def _compute_continuous_metrics(
for idx, (agreement, benchmark, candidate) in enumerate(
zip(agreement_map, benchmark_map, candidate_map)
):
# compute error based metrics such as MAE, MSE, RMSE, etc. from agreement map and produce metrics_df
statistics, names = ContStats.process_statistics(
metrics,
error=agreement,
candidate_map=candidate,
benchmark_map=benchmark,
# Change data to Dataset if DataArray
agreement = _convert_to_dataset(agreement)
candidate = _convert_to_dataset(candidate)
benchmark = _convert_to_dataset(benchmark)

# Check if integer type and nodata values
is_int = (
np.issubdtype(candidate.dtype, np.integer)
if isinstance(candidate, xr.DataArray)
else np.issubdtype(candidate["band_1"].dtype, np.integer)
)
nodata = [agreement[x].rio.nodata for x in agreement.data_vars]

# Remove nodata values from the calculation if integer typed, otherwise leave all values in
# Necessary because integer dtypes have no NaN sentinel value
if is_int and np.all([x is not None for x in nodata]):
final_stats = []
# Iterate through each band and gather statistics
for nodata_idx, var_name in enumerate(agreement.data_vars):
# Create mask for all nodata values
agreement_sel, candidate_sel, benchmark_sel = _get_selected_datasets(
agreement, candidate, benchmark, nodata[nodata_idx], var_name
)

statistics, names = ContStats.process_statistics(
metrics,
error=agreement_sel,
candidate_map=candidate_sel,
benchmark_map=benchmark_sel,
)

del agreement_sel, candidate_sel, benchmark_sel

final_stats.append(statistics)

statistics = [
{f"band_{idx + 1}": val for idx, val in enumerate(lst)}
for lst in np.array(final_stats).T
]

else:
statistics, names = ContStats.process_statistics(
metrics,
error=agreement,
candidate_map=candidate,
benchmark_map=benchmark,
)

# create metrics_df
metric_df = dict()
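
The transpose-and-relabel step at the end of the integer branch converts per-band lists of statistics into one dict per statistic, keyed by band name. A tiny illustration with made-up numbers:

```python
import numpy as np

# One list of statistic values per band, e.g. [MAE, RMSE]:
final_stats = [
    [0.31, 0.52],  # band 1
    [0.48, 0.61],  # band 2
]

statistics = [
    {f"band_{idx + 1}": val for idx, val in enumerate(lst)}
    for lst in np.array(final_stats).T  # transpose: rows become statistics
]
# [{'band_1': 0.31, 'band_2': 0.48}, {'band_1': 0.52, 'band_2': 0.61}]
```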
12 changes: 10 additions & 2 deletions src/gval/statistics/continuous_stat_utils.py
@@ -7,6 +7,8 @@

import xarray as xr

from gval.utils.loading_datasets import _check_dask_array


def convert_output(func: Callable) -> Callable: # pragma: no cover
"""
@@ -31,12 +33,18 @@ def wrapper(*args, **kwargs):
# Call the decorated function
result = func(*args, **kwargs)

is_dsk = _check_dask_array(result)
if isinstance(result, xr.DataArray):
# Convert to a single numeric value
return result.item()
return result.compute().item() if is_dsk else result.item()
elif isinstance(result, xr.Dataset):
# Convert to a dictionary with band names as keys
return {band: result[band].item() for band in result.data_vars}
if is_dsk:
return {
band: result[band].compute().item() for band in result.data_vars
}
else:
return {band: result[band].item() for band in result.data_vars}

return result

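
With chunked inputs, the wrapped statistic can come back as a lazy, dask-backed object, so the wrapper calls `.compute()` before extracting scalars. A condensed sketch of that pattern with synthetic values:

```python
import dask.array as da
import numpy as np
import xarray as xr

lazy = xr.DataArray(da.from_array(np.array([0.40, 0.44]), chunks=1))
result = lazy.mean()  # still dask-backed; nothing evaluated yet

value = result.compute().item()  # materialize, then extract -> 0.42
```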
30 changes: 30 additions & 0 deletions src/gval/utils/loading_datasets.py
@@ -290,3 +290,33 @@ def _parse_string_attributes(
}

return obj


def _convert_to_dataset(xr_object: Union[xr.DataArray, xr.Dataset]) -> xr.Dataset:
"""
Converts xarray object to dataset if it is not one already.
Parameters
----------
xr_object : Union[xr.DataArray, xr.Dataset]
Xarray object to convert or simply return
Returns
-------
xr.Dataset
Dataset object
"""

if isinstance(xr_object, xr.DataArray):
nodata = xr_object.rio.nodata
xr_object = xr_object.to_dataset(dim="band")
xr_object = xr_object.rename_vars({x: f"band_{x}" for x in xr_object.data_vars})

# Account for nodata
for var_name in xr_object.data_vars:
xr_object[var_name] = xr_object[var_name].rio.write_nodata(nodata)

return xr_object
else:
return xr_object
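
For a two-band DataArray, the helper above yields a Dataset with variables `band_1` and `band_2`, each re-tagged with the original nodata value. A hedged usage sketch with a synthetic array (assumes rioxarray is installed so the `.rio` accessor is registered):

```python
import numpy as np
import rioxarray  # noqa: F401  (registers the .rio accessor)
import xarray as xr

two_band = xr.DataArray(
    np.zeros((2, 3, 3), dtype="int32"),
    dims=("band", "y", "x"),
    coords={"band": [1, 2]},
).rio.write_nodata(-9999)

ds = _convert_to_dataset(two_band)
print(list(ds.data_vars))       # ['band_1', 'band_2']
print(ds["band_1"].rio.nodata)  # -9999, preserved on every band variable
```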
49 changes: 35 additions & 14 deletions tests/cases_catalogs.py
@@ -79,26 +79,41 @@
pd.DataFrame(
{
"map_id_candidate": [
f"{TEST_DATA_DIR}/candidate_continuous_0.tif",
f"{TEST_DATA_DIR}/candidate_continuous_1.tif",
"s3://gval-test/candidate_continuous_0.tif",
"s3://gval-test/candidate_continuous_1.tif",
"s3://gval-test/candidate_continuous_1.tif",
],
"compare_id": ["compare1", "compare2"],
"compare_id": ["compare1", "compare2", "compare2"],
"map_id_benchmark": [
f"{TEST_DATA_DIR}/benchmark_continuous_0.tif",
f"{TEST_DATA_DIR}/benchmark_continuous_1.tif",
"s3://gval-test/benchmark_continuous_0.tif",
"s3://gval-test/benchmark_continuous_1.tif",
"s3://gval-test/benchmark_continuous_1.tif",
],
"value1_candidate": [1, 2],
"value2_candidate": [5, 6],
"value1_candidate": [1, 2, 2],
"value2_candidate": [5, 6, 6],
"agreement_maps": [
"agreement_continuous_0.tif",
"agreement_continuous_1.tif",
"agreement_continuous_1.tif",
],
"value1_benchmark": [1, 2, 2],
"value2_benchmark": [5, 6, 6],
"band": [1.0, 1.0, 2.0],
"coefficient_of_determination": [
-0.06615996360778809,
-2.829420804977417,
0.10903036594390869,
],
"mean_absolute_error": [
0.3173885941505432,
0.48503121733665466,
0.48503121733665466,
],
"mean_absolute_percentage_error": [
0.15956786274909973,
0.20223499834537506,
0.15323485434055328,
],
"value1_benchmark": [1, 2],
"value2_benchmark": [5, 6],
"band": [1.0, 1.0],
"coefficient_of_determination": [-0.066160, -0.329965],
"mean_absolute_error": [0.317389, 0.485031],
"mean_absolute_percentage_error": [0.159568, 0.177735],
}
)
] * 2 + [
@@ -175,7 +190,12 @@
},
]

open_kwargs = [{"mask_and_scale": True, "masked": True}] * 4
open_kwargs = [
{"mask_and_scale": True, "masked": True},
{"mask_and_scale": True, "masked": True, "chunks": "auto"},
{"mask_and_scale": True, "masked": True},
{"mask_and_scale": True, "masked": True},
]
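
The second catalog entry now opens its rasters with `chunks="auto"`, which makes rioxarray return dask-backed arrays and exercises the lazy code paths added above. A brief illustration with a hypothetical local path:

```python
import rioxarray as rxr

lazy = rxr.open_rasterio(
    "candidate_continuous_1.tif",  # hypothetical local copy
    mask_and_scale=True,
    masked=True,
    chunks="auto",  # dask-backed; evaluation is deferred
)
print(type(lazy.data))  # <class 'dask.array.core.Array'>
```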

# agreement_map_field = [None, "agreement_maps"]
agreement_map_field = ["agreement_maps"] * 4
@@ -185,6 +205,7 @@
(
f"{TEST_DATA_DIR}/agreement_continuous_0.tif",
f"{TEST_DATA_DIR}/agreement_continuous_1.tif",
f"{TEST_DATA_DIR}/agreement_continuous_1.tif",
)
] * 2 + [
(
2 changes: 1 addition & 1 deletion tests/cases_compute_continuous_metrics.py
@@ -87,7 +87,7 @@ def case_compute_continuous_metrics_success(


stat_names = ["non_existent_function"]
exceptions = [KeyError]
exceptions = [TypeError]


@parametrize(
