Skip to content

Commit

Permalink
Merge pull request #38 from c-hydro:dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
ltrotter authored Dec 20, 2024
2 parents d9899e6 + fcb5975 commit da33ab5
Show file tree
Hide file tree
Showing 6 changed files with 206 additions and 12 deletions.
16 changes: 12 additions & 4 deletions data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,7 +710,14 @@ def find_times(self, times: list[TimeStep|dt.datetime], id = False, rev = False,
Find the times for which data is available.
"""
all_ids = list(range(len(times)))
ids = [i for i in all_ids if self.check_data(times[i], **kwargs)] or []
if hasattr(times[0].start):
tr = TimeRange(min(times).start, max(times).end)
else:
tr = TimeRange(min(times), max(times))

all_times = self.get_available_tags(tr, **kwargs).get('time', [])

ids = [i for i in all_ids if times[i] in all_times] or []
if rev:
ids = [i for i in all_ids if i not in ids] or []

Expand All @@ -720,14 +727,15 @@ def find_times(self, times: list[TimeStep|dt.datetime], id = False, rev = False,
return [times[i] for i in ids]

@withcases
def find_tiles(self, time: Optional[TimeStep|dt.datetime] = None, rev = False,**kwargs) -> list[str]:
def find_tiles(self, time: Optional[TimeStep|dt.datetime] = None, rev = False, **kwargs) -> list[str]:
"""
Find the tiles for which data is available.
"""
all_tiles = self.tile_names
available_tiles = [tile for tile in all_tiles if self.check_data(time, tile = tile, **kwargs)]
available_tiles = self.get_available_tags(time, **kwargs).get('tile', [])

if not rev:
return available_tiles
return [tile for tile in all_tiles if tile in available_tiles]
else:
return [tile for tile in all_tiles if tile not in available_tiles]

Expand Down
4 changes: 2 additions & 2 deletions data/io_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def get_format_from_path(path: str) -> str:
elif extension == 'nc':
return 'netcdf'

elif extension == 'json':
elif extension in ['json', 'geojson']:
return 'json'

elif extension == 'txt':
Expand Down Expand Up @@ -54,7 +54,7 @@ def read_from_file(path, format: Optional[str] = None) -> xr.DataArray|xr.Datase
data = json.load(f)
# understand if the data is actually in a geodataframe format
if 'features' in data.keys():
data = gpd.GeoDataFrame.from_features(data['features'])
data = gpd.read_file(path)

# read the data from a txt file
elif format == 'txt':
Expand Down
Empty file added spatial/__init__.py
Empty file.
113 changes: 113 additions & 0 deletions spatial/bounding_box.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import xarray as xr
import rioxarray as rxr
from pyproj import CRS, Transformer

from .space_utils import get_crs, buffer_bbox
from ..data import Dataset

class BoundingBox():
    """
    Rectangular extent stored as ``(left, bottom, right, top)`` together with
    the coordinate reference system the coordinates are expressed in.
    """

    def __init__(self,
                 left: float, bottom: float, right: float, top: float,
                 datum: str|int|CRS = 'EPSG:4326',
                 buffer: float = 0.0,):
        """
        Create a bounding box object.

        left, bottom, right, top: edges of the box, in units of `datum`.
        datum: EPSG code, WKT string or CRS object; defaults to WGS84.
        buffer: amount added on every side, in units of coordinates.
        """

        # datum accepts both EPSG codes and WKT strings and defaults to WGS84
        self.crs = get_crs(datum)

        # set the bounding box
        self.bbox = (left, bottom, right, top)

        # buffer the bounding box
        self.add_buffer(buffer)

    @property
    def epsg_code(self):
        """The CRS as an 'EPSG:<code>' string."""
        # NOTE(review): pyproj's to_epsg() can return None when no EPSG code
        # matches, which would yield 'EPSG:None' - confirm callers cope.
        return f'EPSG:{self.crs.to_epsg()}'

    @property
    def wkt_datum(self):
        """The CRS as a WKT string."""
        return self.crs.to_wkt()

    @staticmethod
    def from_xarray(data: xr.DataArray, buffer: float = 0.0):
        """
        Build a bounding box from the rio bounds and CRS of `data`.

        Raises ValueError if `data` carries no CRS.
        """
        left, bottom, right, top = data.rio.bounds()

        data_crs = data.rio.crs
        if data_crs is None:
            raise ValueError('Cannot build a BoundingBox: the data has no CRS.')

        # rio.crs is not a str/int/pyproj CRS, so hand it over as WKT
        return BoundingBox(left, bottom, right, top,
                           datum = data_crs.to_wkt(), buffer = buffer)

    @staticmethod
    def from_dataset(dataset: Dataset, buffer: float = 0.0):
        """
        Build a bounding box from the data returned by `dataset.get_data()`.
        """
        # assumes get_data() with no arguments returns a DataArray - TODO confirm
        data:xr.DataArray = dataset.get_data()

        return BoundingBox.from_xarray(data, buffer)

    @staticmethod
    def from_file(grid_file, buffer: float = 0.0):
        """
        Build a bounding box from the bounds and CRS of the raster `grid_file`.
        """
        data = rxr.open_rasterio(grid_file)
        return BoundingBox.from_xarray(data, buffer)

    def add_buffer(self, buffer: float) -> None:
        """
        Buffer the bounding box in place, the buffer is in units of coordinates
        """
        # most recently applied buffer, kept for reference
        self._buffer = buffer
        self.bbox = buffer_bbox(self.bbox, buffer)

    def buffer(self, buffer: float) -> None:
        """
        Alias of `add_buffer`, kept so the original method name still works.
        """
        self.add_buffer(buffer)

    def transform(self, new_datum: str|int|CRS, inplace = False) -> 'BoundingBox | None':
        """
        Transform the bounding box to a new datum
        new_datum: the new datum (EPSG code, WKT string or CRS object)
        inplace: if True, update this object and return None;
                 otherwise return a new BoundingBox and leave this one untouched
        """

        # figure out if we were given an EPSG code or a WKT string
        new_crs: CRS = get_crs(new_datum)

        if new_crs == self.crs:
            # same datum: the coordinates stay as they are
            min_x, min_y, max_x, max_y = self.bbox
        else:
            # Create a transformer to convert coordinates
            transformer = Transformer.from_crs(self.crs, new_crs, always_xy=True)

            # because the box might be warped, transform all 4 corners and
            # take the envelope of all of them
            left, bottom, right, top = self.bbox
            corners = [(left, bottom), (left, top), (right, bottom), (right, top)]
            xs, ys = zip(*(transformer.transform(x, y) for x, y in corners))

            min_x, max_x = min(xs), max(xs)
            min_y, max_y = min(ys), max(ys)

        if inplace:
            self.bbox = (min_x, min_y, max_x, max_y)
            self.crs = new_crs
            return None

        return BoundingBox(min_x, min_y, max_x, max_y, datum = new_crs)
73 changes: 73 additions & 0 deletions spatial/space_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Sequence

import rioxarray  # noqa: F401 - importing registers the `.rio` accessor used below
import xarray as xr
from pyproj import CRS

if TYPE_CHECKING:
    # Needed for annotations only: bounding_box imports this module at load
    # time, so importing it here at runtime would be a circular import.
    from .bounding_box import BoundingBox


def get_crs(datum: str|int|CRS) -> CRS:
    """
    Get the CRS object from the datum

    datum: a string understood by CRS.from_string (e.g. "EPSG:4326" or WKT),
           an integer EPSG code, a pyproj CRS (returned unchanged), or any
           other object exposing a `to_wkt()` method.

    Raises ValueError for any other input.
    """

    if isinstance(datum, CRS):
        return datum
    if isinstance(datum, str):
        return CRS.from_string(datum)
    if isinstance(datum, int):
        return CRS.from_epsg(datum)

    # CRS objects from other libraries (e.g. what `.rio.crs` returns) are not
    # pyproj CRS instances, but can be converted through their WKT form.
    if hasattr(datum, 'to_wkt'):
        return CRS.from_wkt(datum.to_wkt())

    raise ValueError(f'Unknown datum type: {datum}, please provide an EPSG code ("EPSG:#####") or a WKT string.')

def crop_to_bb(src: str|xr.DataArray|xr.Dataset,
               BBox: "BoundingBox") -> xr.DataArray:
    """
    Crop gridded data to a bounding box.

    src: path to a ".nc" or ".grib" file, or an already loaded
         DataArray/Dataset.
    BBox: the bounding box to crop to.

    Raises ValueError for a path with an unsupported extension and
    TypeError for any other kind of `src`.
    """
    if isinstance(src, str):
        if src.endswith(".nc"):
            src_ds = xr.load_dataset(src, engine="netcdf4")
        elif src.endswith(".grib"):
            src_ds = xr.load_dataset(src, engine="cfgrib")
        else:
            raise ValueError(f'Unsupported file type: {src}, expected a ".nc" or ".grib" file.')
    elif isinstance(src, (xr.DataArray, xr.Dataset)):
        src_ds = src
    else:
        raise TypeError(f'Unsupported source type: {type(src).__name__}')

    # x values above 180 have 360 subtracted, then the data is re-sorted along x
    x_dim = src_ds.rio.x_dim
    lon_values = src_ds[x_dim].values
    if (lon_values > 180).any():
        new_lon_values = xr.where(lon_values > 180, lon_values - 360, lon_values)
        new = src_ds.assign_coords({x_dim:new_lon_values}).sortby(x_dim)
        src_ds = new.rio.set_spatial_dims(x_dim, new.rio.y_dim)

    if src_ds.rio.crs is not None:
        # transform the bounding box to the projection of the data
        transformed_BBox = BBox.transform(src_ds.rio.crs.to_wkt())
    else:
        # the data has no CRS: assume the bounding box is already in the right
        # projection and tag the data with it (written onto the object itself)
        #TODO: eventually fix this...
        src_ds = src_ds.rio.write_crs(BBox.wkt_datum, inplace=True)
        transformed_BBox = BBox

    # Crop the raster; clip_xarray expects the plain (left, bottom, right, top) tuple
    cropped = clip_xarray(src_ds, transformed_BBox.bbox)

    return cropped

def clip_xarray(input: xr.DataArray,
                bounds: tuple[float, float, float, float],
                ) -> xr.DataArray:
    """
    Clip `input` to `bounds`, given as (left, bottom, right, top).

    The bounds are pulled in by 1e-6 on every side before clipping.
    """
    # NOTE(review): presumably the tiny shrink keeps cells that only touch
    # the edge out of the result - confirm.
    minx, miny, maxx, maxy = buffer_bbox(bounds, -1e-6)

    return input.rio.clip_box(minx, miny, maxx, maxy)

def buffer_bbox(bbox: Sequence[float], buffer: int) -> Sequence[float]:
    """
    Grow a (left, bottom, right, top) box by `buffer` on every side.

    The buffer is in units of coordinates; a negative value shrinks the box.
    Returns a new 4-tuple, `bbox` itself is not modified.
    """
    xmin, ymin, xmax, ymax = bbox

    lower_corner = (xmin - buffer, ymin - buffer)
    upper_corner = (xmax + buffer, ymax + buffer)

    return lower_corner + upper_corner
12 changes: 6 additions & 6 deletions thumbnails/thumbnail.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,21 +185,22 @@ def save(self, destination:str, **kwargs):
self.add_overlay(kwargs['overlay'])
elif kwargs['overlay'] == False or kwargs['overlay'] is None:
pass

if 'annotation' in kwargs:
if isinstance(kwargs['annotation'], dict):
annotation_opts = kwargs.pop('annotation')
if 'text' not in annotation_opts:
if 'source_key' in self.src.attrs:
annotation_opts['text'] = os.path.basename(self.src.attrs['source_key'])
text = os.path.basename(self.src.attrs['source_key'])
elif hasattr(self, 'raster_file'):
annotation_opts['text'] = os.path.basename(self.raster_file)
self.add_annotation(**annotation_opts)
text = os.path.basename(self.raster_file)
else:
text = annotation_opts.pop('text')
self.add_annotation(text, **annotation_opts)
elif isinstance(kwargs['annotation'], str):
self.add_annotation(kwargs['annotation'])
elif kwargs['annotation'] == False or kwargs['annotation'] is None or kwargs['annotation'].lower == 'none':
pass

elif 'source_key' in self.src.attrs:
annotation_txt = os.path.basename(self.src.attrs['source_key'])
self.add_annotation(annotation_txt)
Expand All @@ -220,5 +221,4 @@ def save(self, destination:str, **kwargs):

os.makedirs(os.path.dirname(destination), exist_ok=True)
self.fig.savefig(destination, dpi=self.dpi, bbox_inches='tight', pad_inches=0)

plt.close(self.fig)

0 comments on commit da33ab5

Please sign in to comment.