From fd3b7cec019ba4d62b31d03e9b4dfa4116ee683e Mon Sep 17 00:00:00 2001 From: "Soroosh.Mani" Date: Fri, 6 Oct 2023 10:27:05 -0400 Subject: [PATCH 1/5] Add band argument for interpolation & Cleanup older comments --- ocsmesh/hfun/mesh.py | 2 -- ocsmesh/mesh/mesh.py | 20 +++++++++++++------- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/ocsmesh/hfun/mesh.py b/ocsmesh/hfun/mesh.py index 4b121656..5669bbee 100644 --- a/ocsmesh/hfun/mesh.py +++ b/ocsmesh/hfun/mesh.py @@ -571,8 +571,6 @@ def add_region_constraint( Add fixed-value or fixed-matrix constraint. add_topo_func_constraint : Addint constraint based on function of topography - add_courant_num_constraint : - Add constraint based on approximated Courant number """ if crs is None: diff --git a/ocsmesh/mesh/mesh.py b/ocsmesh/mesh/mesh.py index 5da8d79b..ebb953f4 100644 --- a/ocsmesh/mesh/mesh.py +++ b/ocsmesh/mesh/mesh.py @@ -344,7 +344,8 @@ def interpolate( method: Literal['spline', 'linear', 'nearest'] = 'spline', nprocs: Optional[int] = None, info_out_path: Union[pathlib.Path, str, None] = None, - filter_by_shape: bool = False + filter_by_shape: bool = False, + band: int = 1, ) -> None: """Interplate values from raster inputs to the mesh nodes. @@ -359,8 +360,10 @@ def interpolate( Number of workers to use when interpolating data. info_out_path : pathlike or str or None Path for the output node interpolation information file - filter_by_shape : bool + filter_by_shape : bool, default=False Flag for node filtering based on raster bbox or shape + band : int, default=1 + The band from rasters to use for interpolation Returns ------- @@ -382,7 +385,7 @@ def interpolate( _mesh_interpolate_worker, [(self.vert2['coord'], self.crs, _raster.tmpfile, _raster.chunk_size, - method, filter_by_shape) + method, filter_by_shape, band) for _raster in raster] ) pool.join() @@ -390,7 +393,7 @@ def interpolate( res = [_mesh_interpolate_worker( self.vert2['coord'], self.crs, _raster.tmpfile, _raster.chunk_size, - method, filter_by_shape) + method, filter_by_shape, band) for _raster in raster] values = self.msh_t.value.flatten() @@ -2234,7 +2237,8 @@ def _mesh_interpolate_worker( raster_path: Union[str, Path], chunk_size: Optional[int], method: Literal['spline', 'linear', 'nearest'] = "spline", - filter_by_shape: bool = False): + filter_by_shape: bool = False, + band: int = 1): """Interpolator worker function to be used in parallel calls Parameters @@ -2249,8 +2253,10 @@ def _mesh_interpolate_worker( Chunk size for windowing over the raster. method : {'spline', 'linear', 'nearest'}, default='spline' Method of interpolation. - filter_by_shape : bool + filter_by_shape : bool, default=False Flag for node filtering based on raster bbox or shape + band : int, default=1 + The band from rasters to use for interpolation Returns ------- @@ -2281,7 +2287,7 @@ def _mesh_interpolate_worker( xi = raster.get_x(window) yi = raster.get_y(window) # Use masked array to ignore missing values from DEM - zi = raster.get_values(window=window, masked=True) + zi = raster.get_values(window=window, masked=True, band=band) if not filter_by_shape: _idxs = np.logical_and( From 503c6b8163269d6056e8694d78a032097a28c252 Mon Sep 17 00:00:00 2001 From: "Soroosh.Mani" Date: Fri, 6 Oct 2023 11:17:47 -0400 Subject: [PATCH 2/5] Placeholder for tests before merge --- tests/api/mesh.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/api/mesh.py b/tests/api/mesh.py index 0f2e8804..326fab23 100644 --- a/tests/api/mesh.py +++ b/tests/api/mesh.py @@ -317,5 +317,13 @@ def test_specify_boundary_on_mesh_with_no_boundary(self): self.assertEqual(bdry.open().iloc[0]['index_id'], [1, 2, 3]) +class RasterInterpolation(unittest.TestCase): + + def test_interpolation_io(self): + self.assert(False) + + def test_interpolation_band(self): + self.assert(False) + if __name__ == '__main__': unittest.main() From 199dc4e2a0ba076710709d03e4b7e1eb52962368 Mon Sep 17 00:00:00 2001 From: "Soroosh.Mani" Date: Tue, 10 Oct 2023 16:51:59 -0400 Subject: [PATCH 3/5] Add multiband raster creation function --- ocsmesh/utils.py | 29 ++++++++++++++++++++--------- tests/api/utils.py | 39 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+), 9 deletions(-) diff --git a/ocsmesh/utils.py b/ocsmesh/utils.py index 896ba59f..3e34e212 100644 --- a/ocsmesh/utils.py +++ b/ocsmesh/utils.py @@ -1385,13 +1385,13 @@ def msh_t_to_grd(msh: jigsaw_msh_t) -> Dict: desc = "EPSG:4326" if src_crs is not None: # TODO: Support non EPSG:4326 CRS -# desc = src_crs.to_string() - epsg_4326 = CRS.from_epsg(4326) - if not src_crs.equals(epsg_4326): - transformer = Transformer.from_crs( - src_crs, epsg_4326, always_xy=True) - coords = np.vstack( - transformer.transform(coords[:, 0], coords[:, 1])).T + desc = src_crs.to_string() +# epsg_4326 = CRS.from_epsg(4326) +# if not src_crs.equals(epsg_4326): +# transformer = Transformer.from_crs( +# src_crs, epsg_4326, always_xy=True) +# coords = np.vstack( +# transformer.transform(coords[:, 0], coords[:, 1])).T nodes = { i + 1: [tuple(p.tolist()), v] for i, (p, v) in @@ -2080,20 +2080,31 @@ def raster_from_numpy( if not isinstance(crs, CRS): crs = CRS.from_user_input(crs) + nbands = 1 + if data.ndim == 3: + nbands = data.shape[2] + elif data.ndim != 2: + raise ValueError("Invalid data dimensions!") + with rio.open( filename, 'w', driver='GTiff', height=data.shape[0], width=data.shape[1], - count=1, + count=nbands, dtype=data.dtype, crs=crs, transform=transform, ) as dst: if isinstance(data, np.ma.MaskedArray): dst.nodata = data.fill_value - dst.write(data, 1) + + data = data.reshape(data.shape[0], data.shape[1], -1) + for i in range(nbands): + dst.write(data.take(i, axis=2), i + 1) + + def msht_from_numpy( diff --git a/tests/api/utils.py b/tests/api/utils.py index 9955a6ee..b5611844 100644 --- a/tests/api/utils.py +++ b/tests/api/utils.py @@ -695,6 +695,45 @@ def test_data_masking(self): self.assertEqual(rast.src.nodata, fill_value) + def test_multiband_raster_data(self): + nbands = 5 + in_data = np.ones((3, 4, nbands)) + for i in range(nbands): + in_data[:, :, i] *= i + in_rast_xy = np.mgrid[-74:-71:1, 40.5:40.9:0.1] + with tempfile.NamedTemporaryFile(suffix='.tiff') as tf: + utils.raster_from_numpy( + tf.name, + data=in_data, + mgrid=in_rast_xy, + crs=4326 + ) + rast = Raster(tf.name) + self.assertEqual(rast.count, nbands) + for i in range(nbands): + with self.subTest(band_number=i): + self.assertTrue( + (rast.get_values(band=i+1) == i).all() + ) + + + def test_multiband_raster_invalid_io(self): + in_data = np.ones((3, 4, 5, 6)) + in_rast_xy = np.mgrid[-74:-71:1, 40.5:40.9:0.1] + with tempfile.NamedTemporaryFile(suffix='.tiff') as tf: + with self.assertRaises(ValueError) as cm: + utils.raster_from_numpy( + tf.name, + data=in_data, + mgrid=in_rast_xy, + crs=4326 + ) + exc = cm.exception + self.assertRegex(str(exc).lower(), '.*dimension.*') + + + + class ShapeToMeshT(unittest.TestCase): def setUp(self): From 87aa00f5f9da9e5b4f93540b4fac7a14210decfa Mon Sep 17 00:00:00 2001 From: "Soroosh.Mani" Date: Tue, 10 Oct 2023 17:19:29 -0400 Subject: [PATCH 4/5] Add test for interpolate band arg --- tests/api/mesh.py | 61 ++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 58 insertions(+), 3 deletions(-) diff --git a/tests/api/mesh.py b/tests/api/mesh.py index 326fab23..65a3a612 100644 --- a/tests/api/mesh.py +++ b/tests/api/mesh.py @@ -2,14 +2,16 @@ import tempfile import unittest import warnings +import shutil from pathlib import Path import numpy as np from jigsawpy import jigsaw_msh_t +from pyproj import CRS from shapely import geometry from ocsmesh import utils -from ocsmesh.mesh.mesh import Mesh +from ocsmesh.mesh.mesh import Mesh, Raster @@ -319,11 +321,64 @@ def test_specify_boundary_on_mesh_with_no_boundary(self): class RasterInterpolation(unittest.TestCase): + def setUp(self): + self.tdir = Path(tempfile.mkdtemp()) + + msht1 = utils.create_rectangle_mesh( + nx=13, ny=5, x_extent=(-73.9, -71.1), y_extent=(40.55, 40.85), + holes=[], + ) + msht1.crs = CRS.from_user_input(4326) + msht2 = utils.create_rectangle_mesh( + nx=11, ny=7, x_extent=(-73.9, -71.1), y_extent=(40.55, 40.85), + holes=[], + ) + msht2.crs = CRS.from_user_input(4326) + with warnings.catch_warnings(): + warnings.filterwarnings( + 'ignore', category=UserWarning, + message='Input mesh has no CRS information' + ) + self.mesh1 = Mesh(msht1) + self.mesh2 = Mesh(msht2) + + self.rast = self.tdir / 'rast.tif' + + rast_xy = np.mgrid[-74:-71:0.1, 40.9:40.5:-0.01] + rast_z = np.ones((rast_xy.shape[1], rast_xy.shape[2], 2)) + rast_z[:, :, 1] = 2 + utils.raster_from_numpy( + self.rast, rast_z, rast_xy, 4326 + ) + + + def tearDown(self): + shutil.rmtree(self.tdir) + + def test_interpolation_io(self): - self.assert(False) + rast = Raster(self.rast) + + self.mesh1.interpolate(rast) + self.assertTrue(np.isclose(self.mesh1.value, 1).all()) + + # TODO: Improve the assertion! + with self.assertRaises(Exception): + self.mesh1.interpolate(self.mesh2) + def test_interpolation_band(self): - self.assert(False) + rast = Raster(self.rast) + + self.mesh1.interpolate(rast) + self.assertTrue(np.isclose(self.mesh1.value, 1).all()) + + self.mesh1.interpolate(rast, band=2) + self.assertTrue(np.isclose(self.mesh1.value, 2).all()) + + + # TODO Add more interpolation tests + if __name__ == '__main__': unittest.main() From b0d9b0cace61fac13e5a0a892befd51dd8752ada Mon Sep 17 00:00:00 2001 From: "Soroosh.Mani" Date: Tue, 10 Oct 2023 17:30:04 -0400 Subject: [PATCH 5/5] Undo incorrect commit --- ocsmesh/utils.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/ocsmesh/utils.py b/ocsmesh/utils.py index 3e34e212..957cae95 100644 --- a/ocsmesh/utils.py +++ b/ocsmesh/utils.py @@ -1385,13 +1385,13 @@ def msh_t_to_grd(msh: jigsaw_msh_t) -> Dict: desc = "EPSG:4326" if src_crs is not None: # TODO: Support non EPSG:4326 CRS - desc = src_crs.to_string() -# epsg_4326 = CRS.from_epsg(4326) -# if not src_crs.equals(epsg_4326): -# transformer = Transformer.from_crs( -# src_crs, epsg_4326, always_xy=True) -# coords = np.vstack( -# transformer.transform(coords[:, 0], coords[:, 1])).T +# desc = src_crs.to_string() + epsg_4326 = CRS.from_epsg(4326) + if not src_crs.equals(epsg_4326): + transformer = Transformer.from_crs( + src_crs, epsg_4326, always_xy=True) + coords = np.vstack( + transformer.transform(coords[:, 0], coords[:, 1])).T nodes = { i + 1: [tuple(p.tolist()), v] for i, (p, v) in