Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Various improvements, some for TP elements #337

Merged
merged 5 commits into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions examples/wave/wave-min-mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from grudge.array_context import MPIPyOpenCLArrayContext

from grudge.shortcuts import set_up_rk4
from grudge import DiscretizationCollection
from grudge import make_discretization_collection

from mpi4py import MPI

Expand All @@ -47,7 +47,7 @@ class WaveTag:
pass


def main(ctx_factory, dim=2, order=4, visualize=False):
def main(dim=2, order=4, visualize=True):
comm = MPI.COMM_WORLD
num_parts = comm.size

Expand Down Expand Up @@ -83,7 +83,7 @@ def main(ctx_factory, dim=2, order=4, visualize=False):
else:
local_mesh = comm.scatter(None)

dcoll = DiscretizationCollection(actx, local_mesh, order=order)
dcoll = make_discretization_collection(actx, local_mesh, order=order)

def source_f(actx, dcoll, t=0):
source_center = np.array([0.1, 0.22, 0.33])[:dcoll.dim]
Expand Down Expand Up @@ -196,7 +196,7 @@ def norm(u):
args = parser.parse_args()

logging.basicConfig(level=logging.INFO)
main(cl.create_some_context,
main(
dim=args.dim,
order=args.order,
visualize=args.visualize)
11 changes: 4 additions & 7 deletions grudge/discretization.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@
"""

from typing import Mapping, Optional, Union, TYPE_CHECKING, Any
from meshmode.discretization.poly_element import ModalGroupFactory
from meshmode.discretization.poly_element import (
InterpolatoryEdgeClusteredGroupFactory, ModalGroupFactory)

from pytools import memoize_method, single_valued

Expand Down Expand Up @@ -79,18 +80,14 @@ def _normalize_discr_tag_to_group_factory(
Mapping[DiscretizationTag, ElementGroupFactory]],
order: Optional[int]
) -> Mapping[DiscretizationTag, ElementGroupFactory]:
from meshmode.discretization.poly_element import \
default_simplex_group_factory

if discr_tag_to_group_factory is None:
if order is None:
raise TypeError(
"one of 'order' and 'discr_tag_to_group_factory' must be given"
)

discr_tag_to_group_factory = {
DISCR_TAG_BASE: default_simplex_group_factory(
base_dim=dim, order=order)}
DISCR_TAG_BASE: InterpolatoryEdgeClusteredGroupFactory(order=order)}
else:
discr_tag_to_group_factory = dict(discr_tag_to_group_factory)

Expand All @@ -102,7 +99,7 @@ def _normalize_discr_tag_to_group_factory(
)

discr_tag_to_group_factory[DISCR_TAG_BASE] = \
default_simplex_group_factory(base_dim=dim, order=order)
InterpolatoryEdgeClusteredGroupFactory(order)

assert discr_tag_to_group_factory is not None

Expand Down
6 changes: 0 additions & 6 deletions grudge/dof_desc.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@
THE SOFTWARE.
"""

import sys
from warnings import warn
from typing import Hashable, Union, Type, Optional, Any, Tuple
from dataclasses import dataclass, replace
Expand Down Expand Up @@ -491,11 +490,6 @@ def __getattr__(name):

raise AttributeError(f"module {__name__} has no attribute {name}")


if sys.version_info < (3, 7):
for name in _deprecated_name_to_new_name:
globals()[name] = globals()[_deprecated_name_to_new_name[name]]

# }}}


Expand Down
12 changes: 9 additions & 3 deletions grudge/models/wave.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

import numpy as np

from grudge.dof_desc import DISCR_TAG_BASE, as_dofdesc
from grudge.models import HyperbolicOperator

from meshmode.mesh import BTAG_ALL, BTAG_NONE
Expand Down Expand Up @@ -113,6 +114,8 @@ def operator(self, t, w):
v = w[1:]
actx = u.array_context

base_dd = as_dofdesc("vol", DISCR_TAG_BASE)

# boundary conditions -------------------------------------------------

# dirichlet BCs -------------------------------------------------------
Expand Down Expand Up @@ -160,9 +163,12 @@ def flux(tpair):
dcoll,
sum(flux(tpair) for tpair in op.interior_trace_pairs(
dcoll, w, comm_tag=self.comm_tag))
+ flux(op.bv_trace_pair(dcoll, self.dirichlet_tag, w, dir_bc))
+ flux(op.bv_trace_pair(dcoll, self.neumann_tag, w, neu_bc))
+ flux(op.bv_trace_pair(dcoll, self.radiation_tag, w, rad_bc))
+ flux(op.bv_trace_pair(
dcoll, base_dd.trace(self.dirichlet_tag), w, dir_bc))
+ flux(op.bv_trace_pair(
dcoll, base_dd.trace(self.neumann_tag), w, neu_bc))
+ flux(op.bv_trace_pair(
dcoll, base_dd.trace(self.radiation_tag), w, rad_bc))
)
)
)
Expand Down
Loading
Loading