diff --git a/ufl/coefficient.py b/ufl/coefficient.py index c10eeb1de..90d20316c 100644 --- a/ufl/coefficient.py +++ b/ufl/coefficient.py @@ -205,6 +205,7 @@ def __repr__(self): return self._repr @Expr.traverse_dag + @Expr.traverse_dag_apply_coefficient_split_cache def traverse_dag_apply_coefficient_split( self, coefficient_split, @@ -214,6 +215,7 @@ def traverse_dag_apply_coefficient_split( cache=None, ): from ufl.classes import ( + Terminal, ComponentTensor, MultiIndex, NegativeRestricted, @@ -228,25 +230,14 @@ def traverse_dag_apply_coefficient_split( from ufl.tensors import as_tensor if self not in coefficient_split: - c = self - if reference_value: - c = ReferenceValue(c) - for _ in range(reference_grad): - # Return zero if expression is trivially constant. This has to - # happen here because ReferenceGrad has no access to the - # topological dimension of a literal zero. - if is_cellwise_constant(c): - dim = extract_unique_domain(subcoeff).topological_dimension() - c = Zero(c.ufl_shape + (dim,), c.ufl_free_indices, c.ufl_index_dimensions) - else: - c = ReferenceGrad(c) - if restricted == "+": - c = PositiveRestricted(c) - elif restricted == "-": - c = NegativeRestricted(c) - elif restricted is not None: - raise RuntimeError(f"Got unknown restriction: {restricted}") - return c + return Terminal.traverse_dag_apply_coefficient_split( + self, + coefficient_split, + reference_value=reference_value, + reference_grad=reference_grad, + restricted=restricted, + cache=cache, + ) # Reference value expected if not reference_value: raise RuntimeError(f"ReferenceValue expected: got {o}") diff --git a/ufl/core/expr.py b/ufl/core/expr.py index 2179d4821..fa93e18ec 100644 --- a/ufl/core/expr.py +++ b/ufl/core/expr.py @@ -430,6 +430,39 @@ def traverse_dag_reuse_if_untouched(self, method_name: str, *args, **kwargs) -> else: return self._ufl_expr_reconstruct_(*ops) + @staticmethod + def traverse_dag_apply_coefficient_split_cache(f): + """Use method specific key for caching.""" + @functools.wraps(f) + def wrapper( + self, + coefficient_split, + reference_value=False, + reference_grad=0, + restricted=None, + cache=None, + ): + if cache is None: + raise RuntimeError(f""" + Can not have cache=None. + Must decorate {f} with ``Expr.traverse_dag``. + """) + key = (self, reference_value, reference_grad, restricted) + if key in cache: + return cache[key] + else: + result = f( + self, + coefficient_split, + reference_value=reference_value, + reference_grad=reference_grad, + restricted=restricted, + cache=cache, + ) + cache[key] = result + return result + return wrapper + def __getattr__(self, name): if name.startswith(PREFIX_TRAVERSE_DAG): # Traverse DAG with reuse_if_untouched by default. diff --git a/ufl/core/terminal.py b/ufl/core/terminal.py index 9ec56bc5a..918d3aab3 100644 --- a/ufl/core/terminal.py +++ b/ufl/core/terminal.py @@ -93,6 +93,8 @@ def __eq__(self, other): """Default comparison of terminals just compare repr strings.""" return repr(self) == repr(other) + @Expr.traverse_dag + @Expr.traverse_dag_apply_coefficient_split_cache def traverse_dag_apply_coefficient_split( self, coefficient_split, diff --git a/ufl/differentiation.py b/ufl/differentiation.py index 07392de89..59d7a7788 100644 --- a/ufl/differentiation.py +++ b/ufl/differentiation.py @@ -345,6 +345,7 @@ def __str__(self): return "reference_grad(%s)" % self.ufl_operands[0] @Expr.traverse_dag + @Expr.traverse_dag_apply_coefficient_split_cache def traverse_dag_apply_coefficient_split( self, coefficient_split, diff --git a/ufl/referencevalue.py b/ufl/referencevalue.py index 9af39dcf7..0f0c53417 100644 --- a/ufl/referencevalue.py +++ b/ufl/referencevalue.py @@ -37,6 +37,7 @@ def __str__(self): return f"reference_value({self.ufl_operands[0]})" @Expr.traverse_dag + @Expr.traverse_dag_apply_coefficient_split_cache def traverse_dag_apply_coefficient_split( self, coefficient_split: dict, diff --git a/ufl/restriction.py b/ufl/restriction.py index 7ac098b0b..1a27f4d96 100644 --- a/ufl/restriction.py +++ b/ufl/restriction.py @@ -50,6 +50,7 @@ def __str__(self): return f"{parstr(self.ufl_operands[0], self)}({self._side})" @Expr.traverse_dag + @Expr.traverse_dag_apply_coefficient_split_cache def traverse_dag_apply_coefficient_split( self, coefficient_split,