Skip to content

Commit

Permalink
Memory management progress
Browse files Browse the repository at this point in the history
  • Loading branch information
GregoryPetrochenkov-NOAA committed May 15, 2024
1 parent 9d5bc26 commit a241aad
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 24 deletions.
2 changes: 1 addition & 1 deletion notebooks/Tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1158,7 +1158,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
"version": "3.10.14"
}
},
"nbformat": 4,
Expand Down
53 changes: 41 additions & 12 deletions src/gval/comparison/pairing_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,34 @@
from numbers import Number

import numpy as np
import numba as nb


@nb.vectorize(nopython=True)
from numba import vectorize, uint8, int32, int64, float32, float64, boolean


# Numba Type Definitions
one_param_function_types = [
uint8(uint8),
int32(int32),
int64(int64),
float32(float32),
float64(float64),
]
two_param_function_types = [
uint8(uint8, uint8),
int32(int32, int32),
int64(int64, int64),
float32(float32, float32),
float64(float64, float64),
]
not_natural_number_types = [
boolean(uint8, boolean),
boolean(int32, boolean),
int64(int64, boolean),
float32(float32, boolean),
float64(float64, boolean),
]


@vectorize(not_natural_number_types, nopython=True)
def _is_not_natural_number(
x: Number, raise_exception: bool
) -> bool: # pragma: no cover
Expand Down Expand Up @@ -49,7 +73,7 @@ def _is_not_natural_number(
return False # treated as natural for this use case

# checks for non-negative and whole number
elif (x < 0) | ((x - nb.int64(x)) != 0):
elif (x < 0) | ((x - int64(x)) != 0):
if raise_exception:
raise ValueError(
"Non natural number found (non-negative integers, excluding Inf) [0, 1, 2, 3, 4, ...)"
Expand All @@ -62,7 +86,7 @@ def _is_not_natural_number(
return False


@nb.vectorize(nopython=True)
@vectorize(two_param_function_types, nopython=True)
def cantor_pair(c: Number, b: Number) -> Number: # pragma: no cover
"""
Produces unique natural number for two non-negative natural numbers (0,1,2,...)
Expand Down Expand Up @@ -92,7 +116,7 @@ def cantor_pair(c: Number, b: Number) -> Number: # pragma: no cover
return 0.5 * (c**2 + c + 2 * c * b + 3 * b + b**2)


@nb.vectorize(nopython=True)
@vectorize(two_param_function_types, nopython=True)
def szudzik_pair(c: Number, b: Number) -> Number: # pragma: no cover
"""
Produces unique natural number for two non-negative natural numbers (0,1,2,3,...).
Expand Down Expand Up @@ -122,7 +146,7 @@ def szudzik_pair(c: Number, b: Number) -> Number: # pragma: no cover
return c**2 + c + b if c >= b else b**2 + c


@nb.vectorize(nopython=True)
@vectorize(one_param_function_types, nopython=True)
def _negative_value_transformation(x: Number) -> Number: # pragma: no cover
"""
Transforms negative values for use with pairing functions that only accept non-negative integers.
Expand All @@ -147,7 +171,7 @@ def _negative_value_transformation(x: Number) -> Number: # pragma: no cover
return 2 * x if x >= 0 else -2 * x - 1


@nb.vectorize(nopython=True)
@vectorize(two_param_function_types, nopython=True)
def cantor_pair_signed(c: Number, b: Number) -> Number: # pragma: no cover
"""
Output unique natural number for each unique combination of whole numbers using Cantor signed method.
Expand Down Expand Up @@ -177,7 +201,12 @@ def cantor_pair_signed(c: Number, b: Number) -> Number: # pragma: no cover
return cantor_pair(ct, bt)


@nb.vectorize(nopython=True)
# from typing import TypeVar
#
# T = TypeVar("T")


@vectorize(two_param_function_types, nopython=True)
def szudzik_pair_signed(c: Number, b: Number) -> Number: # pragma: no cover
"""
Output unique natural number for each unique combination of whole numbers using Szudzik signed method._summary_
Expand Down Expand Up @@ -386,10 +415,10 @@ def pairing_dict_fn(
"Value combination found not accounted for in pairing dictionary"
)

return nb.vectorize(nopython=True)(pairing_dict_fn)
return vectorize(two_param_function_types, nopython=True)(pairing_dict_fn)


@nb.vectorize(nopython=True)
@vectorize(two_param_function_types, nopython=True)
def difference(c: Number, b: Number) -> Number: # pragma: no cover
"""
Calculates the difference between candidate and benchmark.
Expand Down
14 changes: 10 additions & 4 deletions src/gval/comparison/tabulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,16 +91,20 @@ def _crosstab_2d_DataArrays(
is_dsk = True

agreement_map.name = "group"
ag_dtype = agreement_map.dtype

if is_dsk:
agreement_counts = xarray_reduce(
agreement_map,
agreement_map,
engine="numba",
expected_groups=dask.array.unique(agreement_map.data),
func="count",
)
else:
agreement_counts = xarray_reduce(agreement_map, agreement_map, func="count")
agreement_counts = xarray_reduce(
agreement_map, agreement_map, engine="numba", func="count"
)

def not_nan(number):
return not np.isnan(number)
Expand Down Expand Up @@ -129,13 +133,15 @@ def not_nan(number):
for x in filter(not_nan, agreement_counts.coords["group"].values)
],
"agreement_values": list(
filter(not_nan, agreement_counts.coords["group"].values.astype(float))
filter(
not_nan, agreement_counts.coords["group"].values.astype(ag_dtype)
)
),
"counts": [
x
for x, y in zip(
agreement_counts.values.astype(float),
agreement_counts.coords["group"].values.astype(float),
agreement_counts.values.astype(ag_dtype),
agreement_counts.coords["group"].values.astype(ag_dtype),
)
if not np.isnan(y)
],
Expand Down
12 changes: 6 additions & 6 deletions tests/cases_catalogs.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,15 +79,15 @@
pd.DataFrame(
{
"map_id_candidate": [
"s3://gval-test/candidate_continuous_0.tif",
"s3://gval-test/candidate_continuous_1.tif",
"s3://gval-test/candidate_continuous_1.tif",
f"{TEST_DATA_DIR}/candidate_continuous_0.tif",
f"{TEST_DATA_DIR}/candidate_continuous_1.tif",
f"{TEST_DATA_DIR}/candidate_continuous_1.tif",
],
"compare_id": ["compare1", "compare2", "compare2"],
"map_id_benchmark": [
"s3://gval-test/benchmark_continuous_0.tif",
"s3://gval-test/benchmark_continuous_1.tif",
"s3://gval-test/benchmark_continuous_1.tif",
f"{TEST_DATA_DIR}/benchmark_continuous_0.tif",
f"{TEST_DATA_DIR}/benchmark_continuous_1.tif",
f"{TEST_DATA_DIR}/benchmark_continuous_1.tif",
],
"value1_candidate": [1, 2, 2],
"value2_candidate": [5, 6, 6],
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from gval.comparison.pairing_functions import PairingDict

# name of S3 for test data
TEST_DATA_S3_NAME = "gval-test"
TEST_DATA_S3_NAME = "gval"
TEST_DATA_DIR = f"s3://{TEST_DATA_S3_NAME}"


Expand Down

0 comments on commit a241aad

Please sign in to comment.