Skip to content

Commit

Permalink
Add feature refinement for collector hfun
Browse files Browse the repository at this point in the history
  • Loading branch information
SorooshMani-NOAA committed Nov 23, 2022
1 parent d774886 commit 1ac2ac5
Show file tree
Hide file tree
Showing 4 changed files with 294 additions and 20 deletions.
144 changes: 128 additions & 16 deletions ocsmesh/hfun/collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,14 @@
import numpy.typing as npt
import geopandas as gpd
from pyproj import CRS, Transformer
from shapely.geometry import MultiPolygon, Polygon, GeometryCollection, box
from shapely.geometry import (
MultiLineString,
LineString,
MultiPolygon,
Polygon,
GeometryCollection,
box
)
from shapely import ops
from jigsawpy import jigsaw_msh_t
from rasterio.transform import from_origin
Expand All @@ -46,6 +53,7 @@
from ocsmesh.raster import Raster, get_iter_windows
from ocsmesh.features.contour import Contour
from ocsmesh.features.patch import Patch
from ocsmesh.features.linefeature import LineFeature
from ocsmesh.features.channel import Channel
from ocsmesh.features.constraint import (
TopoConstConstraint, TopoFuncConstraint, CourantNumConstraint
Expand Down Expand Up @@ -120,7 +128,7 @@ def __init__(
Parameters
----------
contours_info : _RefinementPatchInfoCollector
contours_info : _RefinementContourInfoCollector
Handle to the collection of user specified contours
specification.
"""
Expand Down Expand Up @@ -250,43 +258,46 @@ def __iter__(self) -> Tuple[Tuple[Sequence[int], Contour, Contour], dict]:



class _RefinementPatchInfoCollector:
"""Collection for patch refinement specifications"""
class _RefinementShapeInfoCollector:
"""Collection for shape (patch or line) refinement specifications"""

def __init__(self) -> None:
self._patch_info = {}
self._shape_info = {}

def add(self, patch_defn: Patch, **size_info: Any) -> None:
"""Add patch refinement specifications to the collection
def add(self,
shape_defn: Union[Patch, LineFeature],
**size_info: Any
) -> None:
"""Add shape refinement specifications to the collection
Parameters
----------
patch_defn : Patch
shape_defn : Patch or LineFeature
Shape of the region to apply the refinement during
application.
size_info : dict
Information related to patch application such as
Information related to shape application such as
target size, rate, etc.
Returns
-------
None
"""

self._patch_info[patch_defn] = size_info
self._shape_info[shape_defn] = size_info

def __iter__(self) -> Tuple[Patch, dict]:
def __iter__(self) -> Tuple[Union[Patch, LineFeature], dict]:
"""Iterator method for this collection object
Yields
------
defn : Patch
Patch object representing the shape of the patch area.
defn : Patch or LineFeature
Object representing the shape of refinement
info : dict
Dictionary of specifications for patch refinement.
Dictionary of specifications for shape refinement.
"""

for defn, info in self._patch_info.items():
for defn, info in self._shape_info.items():
yield defn, info


Expand Down Expand Up @@ -660,7 +671,8 @@ def __init__(

self._const_val_contour_coll = _ConstantValueContourInfoCollector()

self._refine_patch_info_coll = _RefinementPatchInfoCollector()
self._refine_patch_info_coll = _RefinementShapeInfoCollector()
self._refine_line_info_coll = _RefinementShapeInfoCollector()

self._flow_lim_coll = _FlowLimiterInfoCollector()

Expand Down Expand Up @@ -1268,6 +1280,63 @@ def add_patch(
expansion_rate=expansion_rate,
target_size=target_size)

def add_feature(
self,
shape: Union[MultiLineString, LineString, None] = None,
line_defn: Optional[LineString] = None,
shapefile: Union[None, str, Path] = None,
expansion_rate: float = 0.01,
target_size: Optional[float] = None,
crs: CRS = 4326
) -> None:
"""Add refinement as a region of fixed size with an optional rate
Add a refinement based on lines specified by `shape`,
`line_defn` or `shapefile`. The fixed `target_size`
refinement is expanded by the `expansion_rate`.
Parameters
----------
shape : MultiLineString or LineString or None, default=None
Shape of the region to use specified `target_size` for
refinement. Only one of `shape`, `line_defn` or `shapefile`
must be specified.
line_defn : LineFeature or None, default=None
Shape of the region to use specified `target_size` for
refinement. Only one of `shape`, `line_defn` or `shapefile`
must be specified.
shapefile : None or str or Path, default=None
Shape of the region to use specified `target_size` for
refinement. Only one of `shape`, `line_defn` or `shapefile`
must be specified.
expansion_rate : float, default=0.01
Rate to use for expanding refinement away from
the specified shape
target_size : float or None, default=None
Fixed target size of mesh to use for refinement in
`multipolygon`
crs : CRS, default 4326
The CRS of the input `shape`
Returns
-------
None
"""

self._applied = False

if not line_defn:
if shape:
line_defn = LineFeature(shape=shape, shape_crs=crs)

elif shapefile:
line_defn = LineFeature(shapefile=shapefile)

self._refine_line_info_coll.add(
line_defn,
expansion_rate=expansion_rate,
target_size=target_size)


@staticmethod
def _type_chk(input_list: List[Any]) -> None:
Expand Down Expand Up @@ -1319,6 +1388,7 @@ def _apply_features(self) -> None:
self._apply_contours()
self._apply_flow_limiters()
self._apply_const_val()
self._apply_linefeatures()
self._apply_patch()
self._apply_channels()
self._apply_constraints()
Expand Down Expand Up @@ -1588,6 +1658,47 @@ def _apply_patch(self, apply_to: Optional[SizeFuncList] = None) -> None:
shape, nprocs=self._nprocs, **size_info)


def _apply_linefeatures(self, apply_to: Optional[SizeFuncList] = None) -> None:
"""Internal: apply the specified line feature refinements.
Parameters
----------
apply_to : SizeFuncList or None, default=None
Size functions on which line features must be applied. If `None`
all inputs are used to apply the line features.
Returns
-------
None
"""

raster_hfun_list = [
i for i in self._hfun_list if isinstance(i, HfunRaster)]
if apply_to is None:
mesh_hfun_list = [
i for i in self._hfun_list if isinstance(i, HfunMesh)]
if self._base_mesh and self._base_as_hfun:
mesh_hfun_list.insert(0, self._base_mesh)
apply_to = [*mesh_hfun_list, *raster_hfun_list]

# TODO: Parallelize
with Pool(processes=self._nprocs) as p:
for hfun in apply_to:
for lineftr_defn, size_info in self._refine_line_info_coll:
shape, crs = lineftr_defn.get_multiline()
if hfun.crs != crs:
transformer = Transformer.from_crs(
crs, hfun.crs, always_xy=True)
shape = ops.transform(
transformer.transform, shape)

hfun.add_feature(
feature=shape,
pool=p,
**size_info
)


def _write_hfun_to_disk(
self,
out_path: Union[str, Path]
Expand Down Expand Up @@ -1925,6 +2036,7 @@ def _apply_features_fast(self, big_raster: HfunRaster):
self._apply_flow_limiters_fast(hfun_rast)
self._apply_const_val_fast(hfun_rast)
# Mesh hfun parts are still stateful
self._apply_linefeatures([*mesh_hfun_list, *rast_hfun_list])
self._apply_patch([*mesh_hfun_list, *rast_hfun_list])
self._apply_channels([*mesh_hfun_list, *rast_hfun_list])

Expand Down
2 changes: 1 addition & 1 deletion ocsmesh/hfun/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ def add_feature(
feature = [feature]

elif isinstance(feature, MultiLineString):
feature = list(feature)
feature = list(feature.geoms)

# check target size
target_size = self.hmin if target_size is None else target_size
Expand Down
2 changes: 1 addition & 1 deletion ocsmesh/hfun/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -1178,7 +1178,7 @@ def add_feature(
if isinstance(geom, LineString):
points.extend(geom.coords)
elif isinstance(geom, MultiLineString):
for linestring in geom:
for linestring in geom.geoms:
points.extend(linestring.coords)
_logger.info(f'Point concatenation took {time()-start}.')

Expand Down
Loading

0 comments on commit 1ac2ac5

Please sign in to comment.