From 78a368401963e2b683e2c2ab679e3e084587a450 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Roberto=20Antol=C3=ADn?= Date: Wed, 11 Dec 2024 07:16:21 +0100 Subject: [PATCH] Bug/sc 456192/top value parameter should be present by default (#156) --- .env | 2 + CHANGELOG.md | 16 ++++ raster_loader/cli/bigquery.py | 8 +- raster_loader/cli/snowflake.py | 8 +- raster_loader/io/bigquery.py | 4 +- raster_loader/io/common.py | 92 +++++++++++++---------- raster_loader/io/snowflake.py | 4 +- raster_loader/tests/bigquery/test_cli.py | 28 ++++++- raster_loader/tests/snowflake/test_cli.py | 36 ++++++++- 9 files changed, 145 insertions(+), 53 deletions(-) diff --git a/.env b/.env index d40c88b..ba6829f 100644 --- a/.env +++ b/.env @@ -1 +1,3 @@ # === Raster Loader Environment Variables === +GOOGLE_APPLICATION_CREDENTIALS=/usr/local/gcloud/credentials.json + diff --git a/CHANGELOG.md b/CHANGELOG.md index 4a2aa73..21fc47d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,22 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## Unreleased + +[Compare with latest](https://github.com/CartoDB/raster-loader/compare/v0.9.1...HEAD) +### Added + +- Add: Compute top values only for integer bands ([6c10cc0](https://github.com/CartoDB/raster-loader/commit/6c10cc025f5691f7841beee560437fb591bddfe9) by Roberto Antolín). + +### Fixed + +- Fix: Tackle degenerate case of stdev computation ([b112c80](https://github.com/CartoDB/raster-loader/commit/b112c80be7d7c1adfd08f651b43dc591fd54a2ef) by Roberto Antolín). +- Fix: Get count stats from shape of raster band ([c066a30](https://github.com/CartoDB/raster-loader/commit/c066a307ee116598c54ea4871d563f79deebad0b) by Roberto Antolín). +- Fix: Raise error when 0 non-masked samples due to sparse rasters ([dfd89ae](https://github.com/CartoDB/raster-loader/commit/dfd89aef27726a3217843022769600315d8e5b6f) by Roberto Antolín). + +### Changed + +- Change '--all_stats' flag to '--basic_stats' ([2cb89cc](https://github.com/CartoDB/raster-loader/pull/156/commits/2cb89cca30eb15189c876760c026074e262cc10f) by Roberto Antolín). ## [0.9.1] 2024-11-26 diff --git a/raster_loader/cli/bigquery.py b/raster_loader/cli/bigquery.py index e1c1279..6873bc8 100644 --- a/raster_loader/cli/bigquery.py +++ b/raster_loader/cli/bigquery.py @@ -84,8 +84,8 @@ def bigquery(args=None): is_flag=True, ) @click.option( - "--all_stats", - help="Compute all statistics including quantiles and most frequent values.", + "--basic_stats", + help="Compute basic stats and omit quantiles and most frequent values.", required=False, is_flag=True, ) @@ -104,7 +104,7 @@ def upload( append=False, cleanup_on_failure=False, exact_stats=False, - all_stats=False, + basic_stats=False, ): from raster_loader.io.common import ( get_number_of_blocks, @@ -176,7 +176,7 @@ def upload( append=append, cleanup_on_failure=cleanup_on_failure, exact_stats=exact_stats, - all_stats=all_stats, + basic_stats=basic_stats, ) click.echo("Raster file uploaded to Google BigQuery") diff --git a/raster_loader/cli/snowflake.py b/raster_loader/cli/snowflake.py index e5894c5..0620807 100644 --- a/raster_loader/cli/snowflake.py +++ b/raster_loader/cli/snowflake.py @@ -93,8 +93,8 @@ def snowflake(args=None): is_flag=True, ) @click.option( - "--all_stats", - help="Compute all statistics including quantiles and most frequent values.", + "--basic_stats", + help="Compute basic stats and omit quantiles and most frequent values.", required=False, is_flag=True, ) @@ -117,7 +117,7 @@ def upload( append=False, cleanup_on_failure=False, exact_stats=False, - all_stats=False, + basic_stats=False, ): from raster_loader.io.common import ( get_number_of_blocks, @@ -200,7 +200,7 @@ def upload( append=append, cleanup_on_failure=cleanup_on_failure, exact_stats=exact_stats, - all_stats=all_stats, + basic_stats=basic_stats, ) click.echo("Raster file uploaded to Snowflake") diff --git a/raster_loader/io/bigquery.py b/raster_loader/io/bigquery.py index 17848a8..35c96d6 100644 --- a/raster_loader/io/bigquery.py +++ b/raster_loader/io/bigquery.py @@ -108,7 +108,7 @@ def upload_raster( append: bool = False, cleanup_on_failure: bool = False, exact_stats: bool = False, - all_stats: bool = False, + basic_stats: bool = False, ): """Write a raster file to a BigQuery table.""" print("Loading raster file to BigQuery...") @@ -131,7 +131,7 @@ def upload_raster( exit() metadata = rasterio_metadata( - file_path, bands_info, self.band_rename_function, exact_stats, all_stats + file_path, bands_info, self.band_rename_function, exact_stats, basic_stats ) overviews_records_gen = rasterio_overview_to_records( diff --git a/raster_loader/io/common.py b/raster_loader/io/common.py index 8748a63..a232fc8 100644 --- a/raster_loader/io/common.py +++ b/raster_loader/io/common.py @@ -162,21 +162,36 @@ def get_resolution_and_block_sizes( return block_width, block_height, resolution +def get_color_name(raster_dataset: rasterio.io.DatasetReader, band: int) -> str: + try: + # There is [an issue](https://github.com/OSGeo/gdal/issues/1928) + # in gdal with the same error message that we see in this line: + # "Failed to compute statistics, no valid pixels found in sampling." + # + # It seems to be an error with cropped rasters. + band_colorinterp = raster_dataset.colorinterp[band - 1].name + except Exception: + band_colorinterp = None + + return band_colorinterp + + def get_color_table(raster_dataset: rasterio.io.DatasetReader, band: int): try: - if raster_dataset.colorinterp[band - 1].name == "palette": + if get_color_name(raster_dataset, band) == "palette": return raster_dataset.colormap(band) return None except ValueError: return None + def rasterio_metadata( file_path: str, bands_info: List[Tuple[int, str]], band_rename_function: Callable, exact_stats: bool = False, - all_stats: bool = False, + basic_stats: bool = False, ): """Open a raster file with rasterio.""" raster_info = rio_cogeo.cog_info(file_path).dict() @@ -228,23 +243,14 @@ def rasterio_metadata( "User is encourage to compute approximate statistics instead.", UserWarning, ) - stats = raster_band_stats(raster_dataset, band, all_stats) + stats = raster_band_stats(raster_dataset, band, basic_stats) else: print("Computing approximate stats...") stats = raster_band_approx_stats( - raster_dataset, samples, band, all_stats + raster_dataset, samples, band, basic_stats ) - try: - # There is [an issue](https://github.com/OSGeo/gdal/issues/1928) - # in gdal with the same error message that we see in this line: - # "Failed to compute statistics, no valid pixels found in sampling." - # - # It seems to be an error with cropped rasters. - band_colorinterp = raster_dataset.colorinterp[band - 1].name - except Exception: - band_colorinterp = None - + band_colorinterp = get_color_name(raster_dataset, band) if band_colorinterp == "alpha": band_nodata = "0" else: @@ -302,7 +308,7 @@ def rasterio_metadata( def get_alpha_band(raster_dataset: rasterio.io.DatasetReader): for band in raster_dataset.indexes: - if raster_dataset.colorinterp[band - 1].name == "alpha": + if get_color_name(raster_dataset, band) == "alpha": return band return None @@ -423,6 +429,13 @@ def not_enough_samples(): iterations += 1 if len(not_masked_samples[1]) < n_samples: + if len(not_masked_samples[1]) == 0: + raise ValueError( + "The data is very sparse and no non-masked samples were collected.\n" + "Please, consider to use the --exact_stats option to compute exact " + "stats" + ) + warnings.warn( "The data is very sparse and there are not enough non-masked samples.\n" f"Only {len(not_masked_samples[1])} samples were collected and " @@ -470,7 +483,7 @@ def raster_band_approx_stats( raster_dataset: rasterio.io.DatasetReader, samples: Samples, band: int, - all_stats: bool, + basic_stats: bool, ) -> dict: """Get approximate statistics for a raster band.""" @@ -485,10 +498,10 @@ def raster_band_approx_stats( _sum = int(np.sum(samples_band)) sum_squares = int(np.sum(np.array(samples_band) ** 2)) - quantiles = None - most_common = None - if all_stats: - + if basic_stats: + quantiles = None + most_common = None + else: quantiles = compute_quantiles(samples_band, int) most_common = dict() @@ -551,7 +564,7 @@ def read_raster_band(raster_dataset: rasterio.io.DatasetReader, band: int) -> np def raster_band_stats( - raster_dataset: rasterio.io.DatasetReader, band: int, all_stats: bool + raster_dataset: rasterio.io.DatasetReader, band: int, basic_stats: bool ) -> dict: """Get statistics for a raster band.""" @@ -563,35 +576,33 @@ def raster_band_stats( _mean = _stats.mean _std = _stats.std - count = math.prod(_stats.shape) + raster_band = read_raster_band(raster_dataset=raster_dataset, band=band) + + count = math.prod(raster_band.shape) if is_masked_band(raster_dataset, band): - count = np.count_nonzero(_stats.mask is False) + count = np.count_nonzero(raster_band.mask is False) _sum = _mean * count sum_squares = count * _std**2 + _mean**2 - quantiles = None - most_common = None - if all_stats: - raster_band = read_raster_band(raster_dataset=raster_dataset, band=band) - - print("Removing masked data...") - qdata = raster_band.compressed() + print("Removing masked data...") + qdata = raster_band.compressed() + if basic_stats: + quantiles = None + most_common = None + else: casting_function = ( int if np.issubdtype(raster_band.dtype, np.integer) else float ) quantiles = compute_quantiles(qdata, casting_function) - print("Computing most commons values...") - warnings.warn( - "Most common values are meant for categorical data. " - "Computing them for float bands can be meaningless." - ) - most_common = Counter(qdata).most_common(100) - most_common.sort(key=lambda x: x[1], reverse=True) - most_common = dict([(casting_function(x[0]), x[1]) for x in most_common]) + if casting_function == int: + print("Computing most commons values...") + most_common = Counter(qdata).most_common(100) + most_common.sort(key=lambda x: x[1], reverse=True) + most_common = dict([(casting_function(x[0]), x[1]) for x in most_common]) version = ".".join(__version__.split(".")[:3]) @@ -990,7 +1001,10 @@ def update_metadata(metadata, old_metadata): stdev = math.sqrt(old_stats["stddev"] ** 2 + new_stats["stddev"] ** 2) else: mean = _sum / count - stdev = math.sqrt(sum_squares / count - mean * mean) + try: + stdev = math.sqrt(sum_squares / count - mean * mean) + except ValueError: + stdev = math.sqrt(old_stats["stddev"] ** 2 + new_stats["stddev"] ** 2) approximated_stats = ( old_stats["approximated_stats"] or new_stats["approximated_stats"] diff --git a/raster_loader/io/snowflake.py b/raster_loader/io/snowflake.py index af0caa5..86bc0d6 100644 --- a/raster_loader/io/snowflake.py +++ b/raster_loader/io/snowflake.py @@ -179,7 +179,7 @@ def upload_raster( append: bool = False, cleanup_on_failure: bool = False, exact_stats: bool = False, - all_stats: bool = False, + basic_stats: bool = False, ) -> bool: def band_rename_function(x): return x.upper() @@ -207,7 +207,7 @@ def band_rename_function(x): exit() metadata = rasterio_metadata( - file_path, bands_info, band_rename_function, exact_stats, all_stats + file_path, bands_info, band_rename_function, exact_stats, basic_stats ) overviews_records_gen = rasterio_overview_to_records( diff --git a/raster_loader/tests/bigquery/test_cli.py b/raster_loader/tests/bigquery/test_cli.py index 9ff2214..5478e94 100644 --- a/raster_loader/tests/bigquery/test_cli.py +++ b/raster_loader/tests/bigquery/test_cli.py @@ -38,6 +38,33 @@ def test_bigquery_upload(*args, **kwargs): assert result.exit_code == 0 +@patch("raster_loader.cli.bigquery.BigQueryConnection.upload_raster", return_value=None) +@patch("raster_loader.cli.bigquery.BigQueryConnection.__init__", return_value=None) +def test_bigquery_upload_with_basic_stats(*args, **kwargs): + runner = CliRunner() + result = runner.invoke( + main, + [ + "bigquery", + "upload", + "--file_path", + f"{tiff}", + "--project", + "project", + "--dataset", + "dataset", + "--table", + "table", + "--chunk_size", + 1, + "--band", + 1, + "--basic_stats", + ], + ) + assert result.exit_code == 0 + + @patch("raster_loader.cli.bigquery.BigQueryConnection.upload_raster", return_value=None) @patch("raster_loader.cli.bigquery.BigQueryConnection.__init__", return_value=None) def test_bigquery_upload_with_all_stats(*args, **kwargs): @@ -59,7 +86,6 @@ def test_bigquery_upload_with_all_stats(*args, **kwargs): 1, "--band", 1, - "--all_stats", ], ) assert result.exit_code == 0 diff --git a/raster_loader/tests/snowflake/test_cli.py b/raster_loader/tests/snowflake/test_cli.py index 9c11771..a885231 100644 --- a/raster_loader/tests/snowflake/test_cli.py +++ b/raster_loader/tests/snowflake/test_cli.py @@ -46,6 +46,41 @@ def test_snowflake_upload(*args, **kwargs): assert result.exit_code == 0 +@patch( + "raster_loader.io.snowflake.SnowflakeConnection.upload_raster", return_value=None +) +@patch("raster_loader.io.snowflake.SnowflakeConnection.__init__", return_value=None) +def test_snowflake_upload_with_basic_stats(*args, **kwargs): + runner = CliRunner() + result = runner.invoke( + main, + [ + "snowflake", + "upload", + "--file_path", + f"{tiff}", + "--database", + "database", + "--schema", + "schema", + "--table", + "table", + "--account", + "account", + "--username", + "username", + "--password", + "password", + "--chunk_size", + 1, + "--band", + 1, + "--basic_stats", + ], + ) + assert result.exit_code == 0 + + @patch( "raster_loader.io.snowflake.SnowflakeConnection.upload_raster", return_value=None ) @@ -75,7 +110,6 @@ def test_snowflake_upload_with_all_stats(*args, **kwargs): 1, "--band", 1, - "--all_stats", ], ) assert result.exit_code == 0