diff --git a/raster_tools/creation.py b/raster_tools/creation.py index a949d17..cf95cf9 100644 --- a/raster_tools/creation.py +++ b/raster_tools/creation.py @@ -263,7 +263,9 @@ def full_like(raster_template, value, bands=1, dtype=None, copy_mask=False): return _build_result(rst, ndata, bands, copy_mask, copy_null) -def constant_raster(raster_template, value=1, bands=1, copy_mask=False): +def constant_raster( + raster_template, value=1, bands=1, dtype=None, copy_mask=False +): """Create a Raster filled with a constant value like a template raster. This is a convenience function that wraps :func:`full_like`. @@ -276,6 +278,8 @@ def constant_raster(raster_template, value=1, bands=1, copy_mask=False): Value to fill result with. Default is 1. bands : int, optional Number of bands desired for output. Default is 1. + dtype : data-type, optional + Overrides the result dtype. copy_mask : bool If `True`, the template raster's mask is copied to the result raster. If `bands` differs from `raster_template`, the first band's mask is @@ -287,7 +291,9 @@ def constant_raster(raster_template, value=1, bands=1, copy_mask=False): The resulting raster of constant values. """ - return full_like(raster_template, value, bands=bands, copy_mask=copy_mask) + return full_like( + raster_template, value, bands=bands, dtype=dtype, copy_mask=copy_mask + ) def zeros_like(raster_template, bands=1, dtype=None, copy_mask=False): diff --git a/raster_tools/raster.py b/raster_tools/raster.py index 6862f34..712b244 100644 --- a/raster_tools/raster.py +++ b/raster_tools/raster.py @@ -1798,6 +1798,23 @@ def set_null_value(self, value): # dtypes. This catches and fixes the issue. fvalue = float(value) new_dtype = get_common_dtype([fvalue, dtype]) + elif is_int(value) and is_int(dtype): + # Catch the case where value is 0 and dtype is int8, etc. In this + # case, get_common_dtype will use uint8 for the dtype of value, + # which will combine with int8 to produce int16 for the new dtype. + with warnings.catch_warnings(): + warnings.filterwarnings("error") + try: + test_value = dtype.type(value) + assert test_value == value + except ( + OverflowError, + AssertionError, + RuntimeWarning, + DeprecationWarning, + ): + new_dtype = get_common_dtype([value, dtype]) + # else: do nothing else: new_dtype = get_common_dtype([value, dtype]) diff --git a/tests/test_creation.py b/tests/test_creation.py index aaab2fe..253117c 100644 --- a/tests/test_creation.py +++ b/tests/test_creation.py @@ -125,6 +125,23 @@ def test_full_like(template, value, nbands, dtype, copy_mask): ) +@pytest.mark.parametrize("copy_mask", [0, 1]) +@pytest.mark.parametrize("dtype", [None, "int32"]) +@pytest.mark.parametrize("nbands", [1, 2, 3]) +@pytest.mark.parametrize("value", [9, 100, -10]) +@pytest.mark.parametrize("template", templates()) +def test_constant_raster(template, value, nbands, dtype, copy_mask): + run_constant_raster_tests( + creation.constant_raster, + template, + value, + nbands, + dtype, + copy_mask, + True, + ) + + @pytest.mark.parametrize("copy_mask", [0, 1]) @pytest.mark.parametrize("dtype", [None, "int32"]) @pytest.mark.parametrize("nbands", [1, 2, 3]) diff --git a/tests/test_raster.py b/tests/test_raster.py index 051aae0..67132c1 100644 --- a/tests/test_raster.py +++ b/tests/test_raster.py @@ -1653,6 +1653,11 @@ def test_set_null_value(): @pytest.mark.parametrize( "raster,value,expected_dtype", [ + ( + rts.creation.zeros_like(testdata.raster.dem_small, dtype="int8"), + 0, + I8, + ), ( rts.creation.zeros_like(testdata.raster.dem_small, dtype="int64"), 0,