Skip to content

Commit

Permalink
Merge branch 'dev'
Browse files Browse the repository at this point in the history
  • Loading branch information
fbunt committed Nov 28, 2024
2 parents 44c26d5 + 96cc3b8 commit 0bba84b
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 2 deletions.
10 changes: 8 additions & 2 deletions raster_tools/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,9 @@ def full_like(raster_template, value, bands=1, dtype=None, copy_mask=False):
return _build_result(rst, ndata, bands, copy_mask, copy_null)


def constant_raster(raster_template, value=1, bands=1, copy_mask=False):
def constant_raster(
raster_template, value=1, bands=1, dtype=None, copy_mask=False
):
"""Create a Raster filled with a constant value like a template raster.
This is a convenience function that wraps :func:`full_like`.
Expand All @@ -276,6 +278,8 @@ def constant_raster(raster_template, value=1, bands=1, copy_mask=False):
Value to fill result with. Default is 1.
bands : int, optional
Number of bands desired for output. Default is 1.
dtype : data-type, optional
Overrides the result dtype.
copy_mask : bool
If `True`, the template raster's mask is copied to the result raster.
If `bands` differs from `raster_template`, the first band's mask is
Expand All @@ -287,7 +291,9 @@ def constant_raster(raster_template, value=1, bands=1, copy_mask=False):
The resulting raster of constant values.
"""
return full_like(raster_template, value, bands=bands, copy_mask=copy_mask)
return full_like(
raster_template, value, bands=bands, dtype=dtype, copy_mask=copy_mask
)


def zeros_like(raster_template, bands=1, dtype=None, copy_mask=False):
Expand Down
17 changes: 17 additions & 0 deletions raster_tools/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -1798,6 +1798,23 @@ def set_null_value(self, value):
# dtypes. This catches and fixes the issue.
fvalue = float(value)
new_dtype = get_common_dtype([fvalue, dtype])
elif is_int(value) and is_int(dtype):
# Catch the case where value is 0 and dtype is int8, etc. In this
# case, get_common_dtype will use uint8 for the dtype of value,
# which will combine with int8 to produce int16 for the new dtype.
with warnings.catch_warnings():
warnings.filterwarnings("error")
try:
test_value = dtype.type(value)
assert test_value == value
except (
OverflowError,
AssertionError,
RuntimeWarning,
DeprecationWarning,
):
new_dtype = get_common_dtype([value, dtype])
# else: do nothing
else:
new_dtype = get_common_dtype([value, dtype])

Expand Down
17 changes: 17 additions & 0 deletions tests/test_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,23 @@ def test_full_like(template, value, nbands, dtype, copy_mask):
)


@pytest.mark.parametrize("copy_mask", [0, 1])
@pytest.mark.parametrize("dtype", [None, "int32"])
@pytest.mark.parametrize("nbands", [1, 2, 3])
@pytest.mark.parametrize("value", [9, 100, -10])
@pytest.mark.parametrize("template", templates())
def test_constant_raster(template, value, nbands, dtype, copy_mask):
run_constant_raster_tests(
creation.constant_raster,
template,
value,
nbands,
dtype,
copy_mask,
True,
)


@pytest.mark.parametrize("copy_mask", [0, 1])
@pytest.mark.parametrize("dtype", [None, "int32"])
@pytest.mark.parametrize("nbands", [1, 2, 3])
Expand Down
5 changes: 5 additions & 0 deletions tests/test_raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -1653,6 +1653,11 @@ def test_set_null_value():
@pytest.mark.parametrize(
"raster,value,expected_dtype",
[
(
rts.creation.zeros_like(testdata.raster.dem_small, dtype="int8"),
0,
I8,
),
(
rts.creation.zeros_like(testdata.raster.dem_small, dtype="int64"),
0,
Expand Down

0 comments on commit 0bba84b

Please sign in to comment.