Skip to content

Commit

Permalink
add support for netcdf/hdf groups with different shapes (#62)
Browse files Browse the repository at this point in the history
* add support for netcdf/hdf groups with different shapes

* break out open_rasterio into sub functions to reduce complexity of the main function

* added base tags for main dataset
  • Loading branch information
snowman2 authored Nov 5, 2019
1 parent 09c3d78 commit 9cf5a74
Show file tree
Hide file tree
Showing 4 changed files with 230 additions and 121 deletions.
264 changes: 159 additions & 105 deletions rioxarray/_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from distutils.version import LooseVersion

import numpy as np
import rasterio
from rasterio.vrt import WarpedVRT
from xarray import DataArray, Dataset
from xarray.backends.common import BackendArray
from xarray.backends.file_manager import CachingFileManager
Expand Down Expand Up @@ -233,6 +235,144 @@ def build_subdataset_filter(group_names, variable_names):
)


def _rio_transform(riods):
"""
Get the transform from a rasterio dataset
reguardless of rasterio version.
"""
try:
return riods.transform
except AttributeError:
return riods.affine # rasterio < 1.0


def _get_rasterio_attrs(riods, masked):
"""
Get rasterio specific attributes/encoding
"""
# Add rasterio attributes
attrs = _parse_tags(riods.tags(1))
encoding = dict()
# Affine transformation matrix (always available)
# This describes coefficients mapping pixel coordinates to CRS
# For serialization store as tuple of 6 floats, the last row being
# always (0, 0, 1) per definition (see
# https://github.com/sgillies/affine)
attrs["transform"] = tuple(_rio_transform(riods))[:6]
if hasattr(riods, "nodata") and riods.nodata is not None:
# The nodata values for the raster bands
if masked:
encoding["_FillValue"] = riods.nodata
else:
attrs["_FillValue"] = riods.nodata
if hasattr(riods, "scales"):
# The scale values for the raster bands
attrs["scales"] = riods.scales
if hasattr(riods, "offsets"):
# The offset values for the raster bands
attrs["offsets"] = riods.offsets
if hasattr(riods, "descriptions") and any(riods.descriptions):
# Descriptions for each dataset band
attrs["descriptions"] = riods.descriptions
if hasattr(riods, "units") and any(riods.units):
# A list of units string for each dataset band
attrs["units"] = riods.units
return attrs, encoding


def _parse_driver_tags(riods, attrs, coords):
# Parse extra metadata from tags, if supported
parsers = {"ENVI": _parse_envi}

driver = riods.driver
if driver in parsers:
meta = parsers[driver](riods.tags(ns=driver))

for k, v in meta.items():
# Add values as coordinates if they match the band count,
# as attributes otherwise
if isinstance(v, (list, np.ndarray)) and len(v) == riods.count:
coords[k] = ("band", np.asarray(v))
else:
attrs[k] = v


def _load_subdatasets(
riods, group, variable, parse_coordinates, chunks, cache, lock, masked
):
"""
Load in rasterio subdatasets
"""
base_tags = _parse_tags(riods.tags())
dim_groups = {}
subdataset_filter = None
if any((group, variable)):
subdataset_filter = build_subdataset_filter(group, variable)
for iii, subdataset in enumerate(riods.subdatasets):
if subdataset_filter is not None and not subdataset_filter.match(subdataset):
continue
with rasterio.open(subdataset) as rds:
shape = rds.shape
rioda = open_rasterio(
subdataset,
parse_coordinates=shape not in dim_groups and parse_coordinates,
chunks=chunks,
cache=cache,
lock=lock,
masked=masked,
default_name=subdataset.split(":")[-1].lstrip("/").replace("/", "_"),
)
if shape not in dim_groups:
dim_groups[shape] = {rioda.name: rioda}
else:
dim_groups[shape][rioda.name] = rioda

if len(dim_groups) > 1:
dataset = [
Dataset(dim_group, attrs=base_tags) for dim_group in dim_groups.values()
]
elif not dim_groups:
dataset = Dataset(attrs=base_tags)
else:
dataset = Dataset(list(dim_groups.values())[0], attrs=base_tags)
return dataset


def _prepare_dask(result, riods, filename, chunks):
"""
Prepare the data for dask computations
"""
from dask.base import tokenize

# augment the token with the file modification time
try:
mtime = os.path.getmtime(filename)
except OSError:
# the filename is probably an s3 bucket rather than a regular file
mtime = None

if chunks in (True, "auto"):
from dask.array.core import normalize_chunks
import dask

if LooseVersion(dask.__version__) < LooseVersion("0.18.0"):
msg = (
"Automatic chunking requires dask.__version__ >= 0.18.0 . "
"You currently have version %s" % dask.__version__
)
raise NotImplementedError(msg)
block_shape = (1,) + riods.block_shapes[0]
chunks = normalize_chunks(
chunks=(1, "auto", "auto"),
shape=(riods.count, riods.height, riods.width),
dtype=riods.dtypes[0],
previous_chunks=tuple((c,) for c in block_shape),
)
token = tokenize(filename, mtime, chunks)
name_prefix = "open_rasterio-%s" % token
return result.chunk(chunks, name_prefix=name_prefix, token=token)


def open_rasterio(
filename,
parse_coordinates=None,
Expand Down Expand Up @@ -306,10 +446,6 @@ def open_rasterio(
The newly created DataArray.
"""
parse_coordinates = True if parse_coordinates is None else parse_coordinates

import rasterio
from rasterio.vrt import WarpedVRT

vrt_params = None
if isinstance(filename, rasterio.io.DatasetReader):
filename = filename.name
Expand Down Expand Up @@ -338,50 +474,38 @@ def open_rasterio(
rasterio.open, filename, lock=lock, mode="r", kwargs=open_kwargs
)
riods = manager.acquire()

# open the subdatasets if they exist
if riods.subdatasets:
subdataset_filter = None
if any((group, variable)):
subdataset_filter = build_subdataset_filter(group, variable)
data_arrays = {}
for iii, subdataset in enumerate(riods.subdatasets):
if subdataset_filter is not None and not subdataset_filter.match(
subdataset
):
continue
rioda = open_rasterio(
subdataset,
parse_coordinates=iii == 0 and parse_coordinates,
chunks=chunks,
cache=cache,
lock=lock,
masked=masked,
default_name=subdataset.split(":")[-1].lstrip("/").replace("/", "_"),
)
data_arrays[rioda.name] = rioda
return Dataset(data_arrays)
return _load_subdatasets(
riods=riods,
group=group,
variable=variable,
parse_coordinates=parse_coordinates,
chunks=chunks,
cache=cache,
lock=lock,
masked=masked,
)

if vrt_params is not None:
riods = WarpedVRT(riods, **vrt_params)

if cache is None:
cache = chunks is None

coords = OrderedDict()

# Get bands
if riods.count < 1:
raise ValueError("Unknown dims")
coords = OrderedDict()
coords["band"] = np.asarray(riods.indexes)

# Get coordinates
if LooseVersion(rasterio.__version__) < LooseVersion("1.0"):
transform = riods.affine
else:
transform = riods.transform
# parse tags
attrs, encoding = _get_rasterio_attrs(riods=riods, masked=masked)
_parse_driver_tags(riods=riods, attrs=attrs, coords=coords)

if transform.is_rectilinear and parse_coordinates:
# Get geospatial coordinates
transform = _rio_transform(riods)
if parse_coordinates and transform.is_rectilinear:
# 1d coordinates
coords.update(affine_to_coords(riods.transform, riods.width, riods.height))
elif parse_coordinates:
Expand All @@ -395,49 +519,6 @@ def open_rasterio(
stacklevel=3,
)

# Attributes
attrs = _parse_tags(riods.tags(1))
encoding = dict()
# Affine transformation matrix (always available)
# This describes coefficients mapping pixel coordinates to CRS
# For serialization store as tuple of 6 floats, the last row being
# always (0, 0, 1) per definition (see
# https://github.com/sgillies/affine)
attrs["transform"] = tuple(transform)[:6]
if hasattr(riods, "nodata") and riods.nodata is not None:
# The nodata values for the raster bands
if masked:
encoding["_FillValue"] = riods.nodata
else:
attrs["_FillValue"] = riods.nodata
if hasattr(riods, "scales"):
# The scale values for the raster bands
attrs["scales"] = riods.scales
if hasattr(riods, "offsets"):
# The offset values for the raster bands
attrs["offsets"] = riods.offsets
if hasattr(riods, "descriptions") and any(riods.descriptions):
# Descriptions for each dataset band
attrs["descriptions"] = riods.descriptions
if hasattr(riods, "units") and any(riods.units):
# A list of units string for each dataset band
attrs["units"] = riods.units

# Parse extra metadata from tags, if supported
parsers = {"ENVI": _parse_envi}

driver = riods.driver
if driver in parsers:
meta = parsers[driver](riods.tags(ns=driver))

for k, v in meta.items():
# Add values as coordinates if they match the band count,
# as attributes otherwise
if isinstance(v, (list, np.ndarray)) and len(v) == riods.count:
coords[k] = ("band", np.asarray(v))
else:
attrs[k] = v

data = indexing.LazilyOuterIndexedArray(
RasterioArrayWrapper(manager, lock, vrt_params, masked=masked)
)
Expand All @@ -447,6 +528,7 @@ def open_rasterio(
if cache and chunks is None:
data = indexing.MemoryCachedArray(data)

# create the output data array
da_name = attrs.pop("NETCDF_VARNAME", default_name)
result = DataArray(
data=data, dims=("band", "y", "x"), coords=coords, attrs=attrs, name=da_name
Expand All @@ -457,35 +539,7 @@ def open_rasterio(
result.rio.write_crs(riods.crs, inplace=True)

if chunks is not None:
from dask.base import tokenize

# augment the token with the file modification time
try:
mtime = os.path.getmtime(filename)
except OSError:
# the filename is probably an s3 bucket rather than a regular file
mtime = None

if chunks in (True, "auto"):
from dask.array.core import normalize_chunks
import dask

if LooseVersion(dask.__version__) < LooseVersion("0.18.0"):
msg = (
"Automatic chunking requires dask.__version__ >= 0.18.0 . "
"You currently have version %s" % dask.__version__
)
raise NotImplementedError(msg)
block_shape = (1,) + riods.block_shapes[0]
chunks = normalize_chunks(
chunks=(1, "auto", "auto"),
shape=(riods.count, riods.height, riods.width),
dtype=riods.dtypes[0],
previous_chunks=tuple((c,) for c in block_shape),
)
token = tokenize(filename, mtime, chunks)
name_prefix = "open_rasterio-%s" % token
result = result.chunk(chunks, name_prefix=name_prefix, token=token)
result = _prepare_dask(result, riods, filename, chunks)

# Make the file closeable
result._file_obj = manager
Expand Down
4 changes: 4 additions & 0 deletions sphinx/history.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
History
=======

0.0.16
------
- Add support for netcdf/hdf groups with different shapes (pull #62)

0.0.15
------
- Added `variable` and `group` kwargs to `rioxarray.open_rasterio()` to allow filtering of subdatasets (pull #57)
Expand Down
Loading

0 comments on commit 9cf5a74

Please sign in to comment.