Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Two speedups #262

Merged
merged 2 commits into from
Apr 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions samgeo/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1122,6 +1122,8 @@ def tiff_to_tiff(
func,
data_to_rgb=chw_to_hwc,
sample_size=(512, 512),
sample_nodata_threshold=1.0,
nodata_value=None,
sample_resize=None,
bound=128,
foreground=True,
Expand All @@ -1132,6 +1134,9 @@ def tiff_to_tiff(
with rasterio.open(src_fp) as src:
profile = src.profile

if nodata_value is None:
nodata_values = profile.get("nodata", None)

# Compute blocks
rh, rw = profile["height"], profile["width"]
sh, sw = sample_size
Expand All @@ -1154,6 +1159,11 @@ def tiff_to_tiff(
for b in tqdm(sample_grid):
# Read each tile from the source
r = read_block(src, **b)

if nodata_value is not None:
if (r == nodata_value).mean() >= sample_nodata_threshold:
continue

# Extract the first 3 channels as RGB
uint8_rgb_in = data_to_rgb(r)
orig_size = uint8_rgb_in.shape[:2]
Expand Down
12 changes: 12 additions & 0 deletions samgeo/samgeo.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,9 @@ def generate(
output=None,
foreground=True,
batch=False,
batch_sample_size=(512, 512),
batch_nodata_threshold=1.0,
nodata_value=None,
erosion_kernel=None,
mask_multiplier=255,
unique=True,
Expand All @@ -164,6 +167,12 @@ def generate(
output (str, optional): The path to the output image. Defaults to None.
foreground (bool, optional): Whether to generate the foreground mask. Defaults to True.
batch (bool, optional): Whether to generate masks for a batch of image tiles. Defaults to False.
batch_sample_size (tuple, optional): When batch=True, the size of the sample window used when iterating over rasters. Defaults to (512, 512).
batch_nodata_threshold (float, optional): Batch samples whose fraction of nodata pixels is at or above this threshold will
not be used to generate a mask. The default, 1.0, skips only samples that are 100% nodata. This is useful
when rasters have large areas of nodata values that can be skipped.
nodata_value (int, optional): Nodata value to use in checking batch_nodata_threshold. The default, None,
will use the nodata value in the raster metadata if present.
erosion_kernel (tuple, optional): The erosion kernel for filtering object masks and extract borders.
Such as (3, 3) or (5, 5). Set to None to disable it. Defaults to None.
mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].
Expand All @@ -190,6 +199,9 @@ def generate(
output,
self,
foreground=foreground,
sample_size=batch_sample_size,
sample_nodata_threshold=batch_nodata_threshold,
nodata_value=nodata_value,
erosion_kernel=erosion_kernel,
mask_multiplier=mask_multiplier,
**kwargs,
Expand Down
Loading