Skip to content

Commit

Permalink
Bug/sc 456192/top value parameter should be present by default (#156)
Browse files Browse the repository at this point in the history
  • Loading branch information
rantolin authored Dec 11, 2024
1 parent 39236c8 commit 78a3684
Show file tree
Hide file tree
Showing 9 changed files with 145 additions and 53 deletions.
2 changes: 2 additions & 0 deletions .env
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
# === Raster Loader Environment Variables ===
GOOGLE_APPLICATION_CREDENTIALS=/usr/local/gcloud/credentials.json

16 changes: 16 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,22 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

<!-- insertion marker -->
## Unreleased

<small>[Compare with latest](https://github.com/CartoDB/raster-loader/compare/v0.9.1...HEAD)</small>
### Added

- Add: Compute top values only for integer bands ([6c10cc0](https://github.com/CartoDB/raster-loader/commit/6c10cc025f5691f7841beee560437fb591bddfe9) by Roberto Antolín).

### Fixed

- Fix: Tackle degenerate case of stdev computation ([b112c80](https://github.com/CartoDB/raster-loader/commit/b112c80be7d7c1adfd08f651b43dc591fd54a2ef) by Roberto Antolín).
- Fix: Get count stats from shape of raster band ([c066a30](https://github.com/CartoDB/raster-loader/commit/c066a307ee116598c54ea4871d563f79deebad0b) by Roberto Antolín).
- Fix: Raise error when 0 non-masked samples due to sparse rasters ([dfd89ae](https://github.com/CartoDB/raster-loader/commit/dfd89aef27726a3217843022769600315d8e5b6f) by Roberto Antolín).

### Changed

- Change '--all_stats' flag to '--basic_stats' ([2cb89cc](https://github.com/CartoDB/raster-loader/pull/156/commits/2cb89cca30eb15189c876760c026074e262cc10f) by Roberto Antolín).

## [0.9.1] 2024-11-26

Expand Down
8 changes: 4 additions & 4 deletions raster_loader/cli/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ def bigquery(args=None):
is_flag=True,
)
@click.option(
"--all_stats",
help="Compute all statistics including quantiles and most frequent values.",
"--basic_stats",
help="Compute basic stats and omit quantiles and most frequent values.",
required=False,
is_flag=True,
)
Expand All @@ -104,7 +104,7 @@ def upload(
append=False,
cleanup_on_failure=False,
exact_stats=False,
all_stats=False,
basic_stats=False,
):
from raster_loader.io.common import (
get_number_of_blocks,
Expand Down Expand Up @@ -176,7 +176,7 @@ def upload(
append=append,
cleanup_on_failure=cleanup_on_failure,
exact_stats=exact_stats,
all_stats=all_stats,
basic_stats=basic_stats,
)

click.echo("Raster file uploaded to Google BigQuery")
Expand Down
8 changes: 4 additions & 4 deletions raster_loader/cli/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ def snowflake(args=None):
is_flag=True,
)
@click.option(
"--all_stats",
help="Compute all statistics including quantiles and most frequent values.",
"--basic_stats",
help="Compute basic stats and omit quantiles and most frequent values.",
required=False,
is_flag=True,
)
Expand All @@ -117,7 +117,7 @@ def upload(
append=False,
cleanup_on_failure=False,
exact_stats=False,
all_stats=False,
basic_stats=False,
):
from raster_loader.io.common import (
get_number_of_blocks,
Expand Down Expand Up @@ -200,7 +200,7 @@ def upload(
append=append,
cleanup_on_failure=cleanup_on_failure,
exact_stats=exact_stats,
all_stats=all_stats,
basic_stats=basic_stats,
)

click.echo("Raster file uploaded to Snowflake")
Expand Down
4 changes: 2 additions & 2 deletions raster_loader/io/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def upload_raster(
append: bool = False,
cleanup_on_failure: bool = False,
exact_stats: bool = False,
all_stats: bool = False,
basic_stats: bool = False,
):
"""Write a raster file to a BigQuery table."""
print("Loading raster file to BigQuery...")
Expand All @@ -131,7 +131,7 @@ def upload_raster(
exit()

metadata = rasterio_metadata(
file_path, bands_info, self.band_rename_function, exact_stats, all_stats
file_path, bands_info, self.band_rename_function, exact_stats, basic_stats
)

overviews_records_gen = rasterio_overview_to_records(
Expand Down
92 changes: 53 additions & 39 deletions raster_loader/io/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,21 +162,36 @@ def get_resolution_and_block_sizes(
return block_width, block_height, resolution


def get_color_name(raster_dataset: rasterio.io.DatasetReader, band: int) -> str:
try:
# There is [an issue](https://github.com/OSGeo/gdal/issues/1928)
# in gdal with the same error message that we see in this line:
# "Failed to compute statistics, no valid pixels found in sampling."
#
# It seems to be an error with cropped rasters.
band_colorinterp = raster_dataset.colorinterp[band - 1].name
except Exception:
band_colorinterp = None

return band_colorinterp


def get_color_table(raster_dataset: rasterio.io.DatasetReader, band: int):
try:
if raster_dataset.colorinterp[band - 1].name == "palette":
if get_color_name(raster_dataset, band) == "palette":
return raster_dataset.colormap(band)
return None
except ValueError:
return None



def rasterio_metadata(
file_path: str,
bands_info: List[Tuple[int, str]],
band_rename_function: Callable,
exact_stats: bool = False,
all_stats: bool = False,
basic_stats: bool = False,
):
"""Open a raster file with rasterio."""
raster_info = rio_cogeo.cog_info(file_path).dict()
Expand Down Expand Up @@ -228,23 +243,14 @@ def rasterio_metadata(
"User is encourage to compute approximate statistics instead.",
UserWarning,
)
stats = raster_band_stats(raster_dataset, band, all_stats)
stats = raster_band_stats(raster_dataset, band, basic_stats)
else:
print("Computing approximate stats...")
stats = raster_band_approx_stats(
raster_dataset, samples, band, all_stats
raster_dataset, samples, band, basic_stats
)

try:
# There is [an issue](https://github.com/OSGeo/gdal/issues/1928)
# in gdal with the same error message that we see in this line:
# "Failed to compute statistics, no valid pixels found in sampling."
#
# It seems to be an error with cropped rasters.
band_colorinterp = raster_dataset.colorinterp[band - 1].name
except Exception:
band_colorinterp = None

band_colorinterp = get_color_name(raster_dataset, band)
if band_colorinterp == "alpha":
band_nodata = "0"
else:
Expand Down Expand Up @@ -302,7 +308,7 @@ def rasterio_metadata(

def get_alpha_band(raster_dataset: rasterio.io.DatasetReader):
for band in raster_dataset.indexes:
if raster_dataset.colorinterp[band - 1].name == "alpha":
if get_color_name(raster_dataset, band) == "alpha":
return band
return None

Expand Down Expand Up @@ -423,6 +429,13 @@ def not_enough_samples():
iterations += 1

if len(not_masked_samples[1]) < n_samples:
if len(not_masked_samples[1]) == 0:
raise ValueError(
"The data is very sparse and no non-masked samples were collected.\n"
"Please, consider to use the --exact_stats option to compute exact "
"stats"
)

warnings.warn(
"The data is very sparse and there are not enough non-masked samples.\n"
f"Only {len(not_masked_samples[1])} samples were collected and "
Expand Down Expand Up @@ -470,7 +483,7 @@ def raster_band_approx_stats(
raster_dataset: rasterio.io.DatasetReader,
samples: Samples,
band: int,
all_stats: bool,
basic_stats: bool,
) -> dict:
"""Get approximate statistics for a raster band."""

Expand All @@ -485,10 +498,10 @@ def raster_band_approx_stats(
_sum = int(np.sum(samples_band))
sum_squares = int(np.sum(np.array(samples_band) ** 2))

quantiles = None
most_common = None
if all_stats:

if basic_stats:
quantiles = None
most_common = None
else:
quantiles = compute_quantiles(samples_band, int)

most_common = dict()
Expand Down Expand Up @@ -551,7 +564,7 @@ def read_raster_band(raster_dataset: rasterio.io.DatasetReader, band: int) -> np


def raster_band_stats(
raster_dataset: rasterio.io.DatasetReader, band: int, all_stats: bool
raster_dataset: rasterio.io.DatasetReader, band: int, basic_stats: bool
) -> dict:
"""Get statistics for a raster band."""

Expand All @@ -563,35 +576,33 @@ def raster_band_stats(
_mean = _stats.mean
_std = _stats.std

count = math.prod(_stats.shape)
raster_band = read_raster_band(raster_dataset=raster_dataset, band=band)

count = math.prod(raster_band.shape)
if is_masked_band(raster_dataset, band):
count = np.count_nonzero(_stats.mask is False)
count = np.count_nonzero(raster_band.mask is False)

_sum = _mean * count
sum_squares = count * _std**2 + _mean**2

quantiles = None
most_common = None
if all_stats:
raster_band = read_raster_band(raster_dataset=raster_dataset, band=band)

print("Removing masked data...")
qdata = raster_band.compressed()
print("Removing masked data...")
qdata = raster_band.compressed()

if basic_stats:
quantiles = None
most_common = None
else:
casting_function = (
int if np.issubdtype(raster_band.dtype, np.integer) else float
)

quantiles = compute_quantiles(qdata, casting_function)

print("Computing most commons values...")
warnings.warn(
"Most common values are meant for categorical data. "
"Computing them for float bands can be meaningless."
)
most_common = Counter(qdata).most_common(100)
most_common.sort(key=lambda x: x[1], reverse=True)
most_common = dict([(casting_function(x[0]), x[1]) for x in most_common])
if casting_function == int:
print("Computing most commons values...")
most_common = Counter(qdata).most_common(100)
most_common.sort(key=lambda x: x[1], reverse=True)
most_common = dict([(casting_function(x[0]), x[1]) for x in most_common])

version = ".".join(__version__.split(".")[:3])

Expand Down Expand Up @@ -990,7 +1001,10 @@ def update_metadata(metadata, old_metadata):
stdev = math.sqrt(old_stats["stddev"] ** 2 + new_stats["stddev"] ** 2)
else:
mean = _sum / count
stdev = math.sqrt(sum_squares / count - mean * mean)
try:
stdev = math.sqrt(sum_squares / count - mean * mean)
except ValueError:
stdev = math.sqrt(old_stats["stddev"] ** 2 + new_stats["stddev"] ** 2)

approximated_stats = (
old_stats["approximated_stats"] or new_stats["approximated_stats"]
Expand Down
4 changes: 2 additions & 2 deletions raster_loader/io/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def upload_raster(
append: bool = False,
cleanup_on_failure: bool = False,
exact_stats: bool = False,
all_stats: bool = False,
basic_stats: bool = False,
) -> bool:
def band_rename_function(x):
return x.upper()
Expand Down Expand Up @@ -207,7 +207,7 @@ def band_rename_function(x):
exit()

metadata = rasterio_metadata(
file_path, bands_info, band_rename_function, exact_stats, all_stats
file_path, bands_info, band_rename_function, exact_stats, basic_stats
)

overviews_records_gen = rasterio_overview_to_records(
Expand Down
28 changes: 27 additions & 1 deletion raster_loader/tests/bigquery/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,33 @@ def test_bigquery_upload(*args, **kwargs):
assert result.exit_code == 0


@patch("raster_loader.cli.bigquery.BigQueryConnection.upload_raster", return_value=None)
@patch("raster_loader.cli.bigquery.BigQueryConnection.__init__", return_value=None)
def test_bigquery_upload_with_basic_stats(*args, **kwargs):
runner = CliRunner()
result = runner.invoke(
main,
[
"bigquery",
"upload",
"--file_path",
f"{tiff}",
"--project",
"project",
"--dataset",
"dataset",
"--table",
"table",
"--chunk_size",
1,
"--band",
1,
"--basic_stats",
],
)
assert result.exit_code == 0


@patch("raster_loader.cli.bigquery.BigQueryConnection.upload_raster", return_value=None)
@patch("raster_loader.cli.bigquery.BigQueryConnection.__init__", return_value=None)
def test_bigquery_upload_with_all_stats(*args, **kwargs):
Expand All @@ -59,7 +86,6 @@ def test_bigquery_upload_with_all_stats(*args, **kwargs):
1,
"--band",
1,
"--all_stats",
],
)
assert result.exit_code == 0
Expand Down
36 changes: 35 additions & 1 deletion raster_loader/tests/snowflake/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,41 @@ def test_snowflake_upload(*args, **kwargs):
assert result.exit_code == 0


@patch(
"raster_loader.io.snowflake.SnowflakeConnection.upload_raster", return_value=None
)
@patch("raster_loader.io.snowflake.SnowflakeConnection.__init__", return_value=None)
def test_snowflake_upload_with_basic_stats(*args, **kwargs):
runner = CliRunner()
result = runner.invoke(
main,
[
"snowflake",
"upload",
"--file_path",
f"{tiff}",
"--database",
"database",
"--schema",
"schema",
"--table",
"table",
"--account",
"account",
"--username",
"username",
"--password",
"password",
"--chunk_size",
1,
"--band",
1,
"--basic_stats",
],
)
assert result.exit_code == 0


@patch(
"raster_loader.io.snowflake.SnowflakeConnection.upload_raster", return_value=None
)
Expand Down Expand Up @@ -75,7 +110,6 @@ def test_snowflake_upload_with_all_stats(*args, **kwargs):
1,
"--band",
1,
"--all_stats",
],
)
assert result.exit_code == 0
Expand Down

0 comments on commit 78a3684

Please sign in to comment.