Skip to content

Commit

Permalink
k
Browse files Browse the repository at this point in the history
  • Loading branch information
ksagiyam committed Jan 22, 2025
1 parent 08226e8 commit aae77cf
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 19 deletions.
29 changes: 10 additions & 19 deletions ufl/coefficient.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ def __repr__(self):
return self._repr

@Expr.traverse_dag
@Expr.traverse_dag_apply_coefficient_split_cache
def traverse_dag_apply_coefficient_split(
self,
coefficient_split,
Expand All @@ -214,6 +215,7 @@ def traverse_dag_apply_coefficient_split(
cache=None,
):
from ufl.classes import (
Terminal,
ComponentTensor,
MultiIndex,
NegativeRestricted,
Expand All @@ -228,25 +230,14 @@ def traverse_dag_apply_coefficient_split(
from ufl.tensors import as_tensor

if self not in coefficient_split:
c = self
if reference_value:
c = ReferenceValue(c)
for _ in range(reference_grad):
# Return zero if expression is trivially constant. This has to
# happen here because ReferenceGrad has no access to the
# topological dimension of a literal zero.
if is_cellwise_constant(c):
dim = extract_unique_domain(subcoeff).topological_dimension()
c = Zero(c.ufl_shape + (dim,), c.ufl_free_indices, c.ufl_index_dimensions)
else:
c = ReferenceGrad(c)
if restricted == "+":
c = PositiveRestricted(c)
elif restricted == "-":
c = NegativeRestricted(c)
elif restricted is not None:
raise RuntimeError(f"Got unknown restriction: {restricted}")
return c
return Terminal.traverse_dag_apply_coefficient_split(
self,
coefficient_split,
reference_value=reference_value,
reference_grad=reference_grad,
restricted=restricted,
cache=cache,
)
# Reference value expected
if not reference_value:
raise RuntimeError(f"ReferenceValue expected: got {o}")
Expand Down
33 changes: 33 additions & 0 deletions ufl/core/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,39 @@ def traverse_dag_reuse_if_untouched(self, method_name: str, *args, **kwargs) ->
else:
return self._ufl_expr_reconstruct_(*ops)

@staticmethod
def traverse_dag_apply_coefficient_split_cache(f):
"""Use method specific key for caching."""
@functools.wraps(f)
def wrapper(
self,
coefficient_split,
reference_value=False,
reference_grad=0,
restricted=None,
cache=None,
):
if cache is None:
raise RuntimeError(f"""
Can not have cache=None.
Must decorate {f} with ``Expr.traverse_dag``.
""")
key = (self, reference_value, reference_grad, restricted)
if key in cache:
return cache[key]
else:
result = f(
self,
coefficient_split,
reference_value=reference_value,
reference_grad=reference_grad,
restricted=restricted,
cache=cache,
)
cache[key] = result
return result
return wrapper

def __getattr__(self, name):
if name.startswith(PREFIX_TRAVERSE_DAG):
# Traverse DAG with reuse_if_untouched by default.
Expand Down
2 changes: 2 additions & 0 deletions ufl/core/terminal.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ def __eq__(self, other):
"""Default comparison of terminals just compare repr strings."""
return repr(self) == repr(other)

@Expr.traverse_dag
@Expr.traverse_dag_apply_coefficient_split_cache
def traverse_dag_apply_coefficient_split(
self,
coefficient_split,
Expand Down
1 change: 1 addition & 0 deletions ufl/differentiation.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,7 @@ def __str__(self):
return "reference_grad(%s)" % self.ufl_operands[0]

@Expr.traverse_dag
@Expr.traverse_dag_apply_coefficient_split_cache
def traverse_dag_apply_coefficient_split(
self,
coefficient_split,
Expand Down
1 change: 1 addition & 0 deletions ufl/referencevalue.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def __str__(self):
return f"reference_value({self.ufl_operands[0]})"

@Expr.traverse_dag
@Expr.traverse_dag_apply_coefficient_split_cache
def traverse_dag_apply_coefficient_split(
self,
coefficient_split: dict,
Expand Down
1 change: 1 addition & 0 deletions ufl/restriction.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def __str__(self):
return f"{parstr(self.ufl_operands[0], self)}({self._side})"

@Expr.traverse_dag
@Expr.traverse_dag_apply_coefficient_split_cache
def traverse_dag_apply_coefficient_split(
self,
coefficient_split,
Expand Down

0 comments on commit aae77cf

Please sign in to comment.