Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add resolution gradient support in the mesh generator #177

Merged
merged 1 commit into from
Mar 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/fluidity/advection2d/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ $(OUT_DIR)/$(SIM_NAME).flml: $(SRC_DIR)/$(SIM_NAME).flml
$(OUT_DIR)/$(SIM_NAME).msh:
@echo "********** Building the mesh file..."
./envcheck.sh -m
pydrex-mesh -k="rectangle" -a xy $(WIDTH),$(DEPTH) $(RESOLUTION) $@
pydrex-mesh -k="rectangle" -a xy $(WIDTH),$(DEPTH) -r G:$(RESOLUTION) $@

.PHONY: clean
clean:
Expand Down
36 changes: 33 additions & 3 deletions src/pydrex/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@ def _get_args(self) -> argparse.Namespace:
class MeshGenerator(CliTool):
"""PyDRex script to generate various simple meshes.

Only rectangular (2D) meshes are currently supported.
Only rectangular (2D) meshes are currently supported. The RESOLUTION must be a comma
delimited set of directives of the form `<LOC>:<RES>` where `<LOC>` is a location
specifier, i.e. either "G" (global) or a compas direction like "N", "S", "NE", etc.,
and `<RES>` is a floating point value to be set as the resolution at that location.

"""

Expand All @@ -54,21 +57,48 @@ def __call__(self):
assert len(center) == 2

width, height = map(float, args.size.split(","))
_loc_map = {
"G": "global",
"N": "north",
"S": "south",
"E": "east",
"W": "west",
"NE": "north-east",
"NW": "north-west",
"SE": "south-east",
"SW": "south-west",
}
try:
resolution = {
_loc_map[k]: float(v)
for k, v in map(lambda s: s.split(":"), args.resolution.split(","))
}
except KeyError:
raise KeyError(
"invalid or unsupported location specified in resolution directive"
) from None
except ValueError:
raise ValueError(
"invalid resolution value. The format should be '<LOC1>:<RES1>,<LOC2>:<RES2>,...'"
) from None
_mesh.rectangle(
args.output[:-4],
(args.ref_axes[0], args.ref_axes[1]),
center,
width,
height,
args.resolution,
resolution,
)

def _get_args(self) -> argparse.Namespace:
description, epilog = self.__doc__.split(os.linesep + os.linesep, 1)
parser = argparse.ArgumentParser(description=description, epilog=epilog)
parser.add_argument("size", help="width,height[,depth] of the mesh")
parser.add_argument(
"resolution", help="base resolution of the mesh (edge length hint for gmsh)"
"-r",
"--resolution",
help="resolution for the mesh (edge length hint(s) for gmsh)",
required=True,
)
parser.add_argument("output", help="output file (.msh)")
parser.add_argument(
Expand Down
111 changes: 99 additions & 12 deletions src/pydrex/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

@dataclass
class Model:
"""An object-oriented gmsh model API.
"""A context manager for using the gmsh model API.

>>> with Model("example_model", 2, _write_file=False) as model:
... model.point_constraints = [
Expand All @@ -32,6 +32,8 @@ class Model:

"""

# https://gitlab.onelab.info/gmsh/gmsh/-/raw/master/api/gmsh.py

name: str
dim: int
optimize_args: dict = field(default_factory=dict)
Expand All @@ -42,9 +44,15 @@ class Model:
surface_tags: list = field(default_factory=list)
physical_line_tags: list = field(default_factory=list)
physical_group_tags: list = field(default_factory=list)
mesh_info: dict = field(default_factory=dict)
_write_file: bool = True
_was_entered: bool = False

# TODO: Possible attributes worth adding, no particular order:
# - boundary/boundaries, see gm.model.getBoundary()
# - bounding_box, see gm.model.getBoundingBox()
# - wrapper method for gm.model.getClosestPoint() ? maybe we need entities for that

def __enter__(self):
# See: <https://gitlab.onelab.info/gmsh/gmsh/-/issues/1142>
gm.initialize(["-noenv"]) # Don't let gmsh mess with $PYTHONPATH and $PATH.
Expand All @@ -56,6 +64,22 @@ def __exit__(self, exc_type, exc_val, exc_tb):
gm.model.geo.synchronize()
self.add_physical_groups()
gm.model.mesh.generate(self.dim)
# Populate some mesh info for later reference.
node_tags, node_coords, _ = gm.model.mesh.getNodes()
self.mesh_info["node_tags"] = node_tags
self.mesh_info["node_coords"] = node_coords.reshape((node_tags.size, 3))
element_types, element_tags, _ = gm.model.mesh.getElements()
self.mesh_info["element_types"] = element_types
self.mesh_info["element_tags"] = element_tags
edge_tags, edge_orientations = gm.model.mesh.getAllEdges()
self.mesh_info["edge_tags"] = edge_tags
self.mesh_info["edge_orientations"] = edge_orientations
tri_face_tags, tri_face_nodes = gm.model.mesh.getAllFaces(3)
self.mesh_info["tri_face_tags"] = tri_face_tags
self.mesh_info["tri_face_nodes"] = tri_face_nodes
quad_face_tags, quad_face_nodes = gm.model.mesh.getAllFaces(4)
self.mesh_info["quad_face_tags"] = quad_face_tags
self.mesh_info["quad_face_nodes"] = quad_face_nodes
if len(self.optimize_args) > 0:
gm.model.mesh.optimize(**self.optimize_args)
_log.info(
Expand Down Expand Up @@ -111,21 +135,83 @@ def add_physical_groups(self):
)


def rectangle(name, ref_axes, center, width, height, resolution):
"""Generate a rectangular (2D) mesh."""
def rectangle(name, ref_axes, center, width, height, resolution, **kwargs):
"""Generate a rectangular (2D) mesh.

>>> rect = rectangle(
... "test_rect",
... ("x", "z"),
... center=(0, 0),
... width=1,
... height=1,
... resolution={"global": 1e-2},
... _write_file=False
... )
>>> rect.dim
2
>>> rect.name
'test_rect'
>>> rect.line_tags
[1, 2, 3, 4]
>>> rect.loop_tags
[1]
>>> [p[-1] for p in rect.point_constraints]
[0.01, 0.01, 0.01, 0.01]

>>> rect = rectangle(
... "test_rect",
... ("x", "z"),
... center=(0, 0),
... width=1,
... height=1,
... resolution={"north": 1e-2, "south": 1e-3},
... _write_file=False
... )
>>> [p[-1] for p in rect.point_constraints]
[0.001, 0.001, 0.01, 0.01]

>>> rect = rectangle(
... "test_rect",
... ("x", "z"),
... center=(0, 0),
... width=1,
... height=1,
... resolution={"north-west": 1e-3, "south-east": 1e-2},
... _write_file=False
... )
>>> rect.point_constraints[1][-1]
0.01
>>> rect.point_constraints[3][-1]
0.001
>>> rect.point_constraints[0][-1] == rect.point_constraints[2][-1]
True
>>> rect.point_constraints[0][-1]
0.0055

# TODO: Support resolution gradients like:
# resolution_gradient=(1e-2, "radial_grow") # from 1e-2 at the center to `resolution` at the edges
# resolution_gradient=(1e-2, "radial_shrink") # opposite of the above
# resolution_gradient=(1e-2, "south") # from `resolution` at the top to 1e-2 at the bottom
# resolution_gradient=(1e-2, "west")
# default should be resolution_gradient=None
"""

h, v = _geo.to_indices(*ref_axes)
center_h, center_v = center
point_constraints = np.zeros((4, 4)) # x, y, z, nearby_edge_length
# TODO: Support "center" which should trigger creation of an additional
# point_constraint that is not connected by lines but just anchors the central
# resolution constraint (assuming this is possible in gmsh).
_loc_map = {
"global": range(4),
"north": (2, 3),
"south": (0, 1),
"east": (1, 2),
"west": (0, 3),
"north-east": (2,),
"north-west": (3,),
"south-east": (1,),
"south-west": (0,),
}
for i, p in enumerate(point_constraints):
p[-1] = resolution
p[-1] = np.mean(list(resolution.values()))
for k, res in resolution.items():
if i in _loc_map[k]:
p[-1] = res
match i:
case 0:
p[h] = center_h - width / 2
Expand All @@ -140,12 +226,13 @@ def rectangle(name, ref_axes, center, width, height, resolution):
p[h] = center_h - width / 2
p[v] = center_v + height / 2

with Model(name, 2) as model:
with Model(name, 2, **kwargs) as model:
model.point_constraints = point_constraints
model.add_tags()
model.add_physical_groups()
return model


#
# def orthopiped():
# """Generate an orthopiped (3D “box”) mesh."""
# ...
Expand Down
Loading