Skip to content

Commit

Permalink
Remove component tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Jan 18, 2025
1 parent 7d7c676 commit d328fe5
Show file tree
Hide file tree
Showing 6 changed files with 157 additions and 2 deletions.
15 changes: 15 additions & 0 deletions test/test_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
FacetNormal,
FunctionSpace,
Mesh,
SpatialCoordinate,
TestFunction,
TrialFunction,
adjoint,
Expand All @@ -21,9 +22,11 @@
dx,
grad,
inner,
sin,
triangle,
)
from ufl.algorithms import (
compute_form_data,
expand_derivatives,
expand_indices,
extract_arguments,
Expand Down Expand Up @@ -182,3 +185,15 @@ def test_adjoint(domain):
d = adjoint(b)
d_arg_degrees = [arg.ufl_element().embedded_superdegree for arg in extract_arguments(d)]
assert d_arg_degrees == [2, 1]


def test_remove_component_tensors(domain):
x = SpatialCoordinate(domain)
u = sin(x[0])

f = div(grad(div(grad(u))))
form = f * dx

fd = compute_form_data(form)

assert "ComponentTensor" not in repr(fd.preprocessed_form)
4 changes: 2 additions & 2 deletions test/test_derivative.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,7 +665,7 @@ def test_vector_coefficient_scalar_derivatives(self):
integrand = inner(f, g)

i0, i1, i2, i3, i4 = [Index(count=c) for c in range(5)]
expected = as_tensor(df[i1] * dv, (i1,))[i0] * g[i0]
expected = as_tensor(df[i1], (i1,))[i0] * dv * g[i0]

F = integrand * dx
J = derivative(F, u, dv, cd)
Expand Down Expand Up @@ -693,7 +693,7 @@ def test_vector_coefficient_derivatives(self):
integrand = inner(f, g)

i0, i1, i2, i3, i4 = [Index(count=c) for c in range(5)]
expected = as_tensor(df[i2, i1] * dv[i1], (i2,))[i0] * g[i0]
expected = as_tensor(df[i2, i1], (i2,))[i0] * dv[i1] * g[i0]

F = integrand * dx
J = derivative(F, u, dv, cd)
Expand Down
6 changes: 6 additions & 0 deletions ufl/algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from ufl.core.operator import Operator
from ufl.core.ufl_type import ufl_type
from ufl.index_combination_utils import merge_unique_indices
from ufl.indexed import Indexed
from ufl.precedence import parstr
from ufl.sorting import sorted_expr

Expand Down Expand Up @@ -89,6 +90,11 @@ def __init__(self, a, b):
"""Initialise."""
Operator.__init__(self)

def _simplify_indexed(self, multiindex):
"""Return a simplified Expr used in the constructor of Indexed(self, multiindex)."""
a, b = self.ufl_operands
return Sum(Indexed(a, multiindex), Indexed(b, multiindex))

def evaluate(self, x, mapping, component, index_values):
"""Evaluate."""
return sum(o.evaluate(x, mapping, component, index_values) for o in self.ufl_operands)
Expand Down
4 changes: 4 additions & 0 deletions ufl/algorithms/compute_form_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from ufl.algorithms.formdata import FormData
from ufl.algorithms.formtransformations import compute_form_arities
from ufl.algorithms.remove_complex_nodes import remove_complex_nodes
from ufl.algorithms.remove_component_tensors import remove_component_tensors
from ufl.classes import Coefficient, Form, FunctionSpace, GeometricFacetQuantity
from ufl.corealg.traversal import traverse_unique_terminals
from ufl.domain import extract_unique_domain
Expand Down Expand Up @@ -328,6 +329,9 @@ def compute_form_data(

form = apply_coordinate_derivatives(form)

# Remove component tensors
form = remove_component_tensors(form)

# Propagate restrictions to terminals
if do_apply_restrictions:
form = apply_restrictions(form, apply_default=do_apply_default_restrictions)
Expand Down
124 changes: 124 additions & 0 deletions ufl/algorithms/remove_component_tensors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
"""Remove component tensors.
This module contains classes and functions to remove component tensors.
"""
# Copyright (C) 2008-2016 Martin Sandve Alnæs
#
# This file is part of UFL (https://www.fenicsproject.org)
#
# SPDX-License-Identifier: LGPL-3.0-or-later

from ufl.classes import (
ComponentTensor,
Form,
Index,
MultiIndex,
Zero,
)
from ufl.corealg.map_dag import map_expr_dag
from ufl.corealg.multifunction import MultiFunction, memoized_handler


class IndexReplacer(MultiFunction):
"""Replace Indices."""

def __init__(self, fimap: dict):
"""Initialise.
Args:
fimap: map for index replacements.
"""
MultiFunction.__init__(self)
self.fimap = fimap
self._object_cache = {}

expr = MultiFunction.reuse_if_untouched

@memoized_handler
def zero(self, o):
"""Handle Zero."""
free_indices = []
index_dimensions = []
for i, d in zip(o.ufl_free_indices, o.ufl_index_dimensions):
if Index(i) in self.fimap:
ind_j = self.fimap[Index(i)]
if isinstance(ind_j, Index):
free_indices.append(ind_j.count())
index_dimensions.append(d)
else:
free_indices.append(i)
index_dimensions.append(d)
return Zero(
shape=o.ufl_shape,
free_indices=tuple(free_indices),
index_dimensions=tuple(index_dimensions),
)

@memoized_handler
def multi_index(self, o):
"""Handle MultiIndex."""
return MultiIndex(tuple(self.fimap.get(i, i) for i in o.indices()))


class IndexRemover(MultiFunction):
"""Remove Indexed."""

def __init__(self):
"""Initialise."""
MultiFunction.__init__(self)
self._object_cache = {}

expr = MultiFunction.reuse_if_untouched

@memoized_handler
def _zero_simplify(self, o):
"""Apply simplification for Zero()."""
(operand,) = o.ufl_operands
operand = map_expr_dag(self, operand)
if isinstance(operand, Zero):
return Zero(
shape=o.ufl_shape,
free_indices=o.ufl_free_indices,
index_dimensions=o.ufl_index_dimensions,
)
return o._ufl_expr_reconstruct_(operand)

@memoized_handler
def indexed(self, o):
"""Simplify Indexed."""
o1, i1 = o.ufl_operands
if isinstance(o1, ComponentTensor):
# Simplify Indexed ComponentTensor
o2, i2 = o1.ufl_operands
assert len(i2) == len(i1)
fimap = dict(zip(i2, i1))
rule = IndexReplacer(fimap)
v = map_expr_dag(self, o2)
return map_expr_dag(rule, v)

expr = map_expr_dag(self, o1)
if expr is o1:
# Reuse if untouched
return o
return o._ufl_expr_reconstruct_(expr, i1)

# Do something nicer
positive_restricted = _zero_simplify
negative_restricted = _zero_simplify
reference_grad = _zero_simplify
reference_value = _zero_simplify


def remove_component_tensors(o):
"""Remove component tensors."""
if isinstance(o, Form):
integrals = []
for integral in o.integrals():
integrand = remove_component_tensors(integral.integrand())
if not isinstance(integrand, Zero):
integrals.append(integral.reconstruct(integrand=integrand))
return o._ufl_expr_reconstruct_(integrals)
else:
rule = IndexRemover()
return map_expr_dag(rule, o)
6 changes: 6 additions & 0 deletions ufl/indexsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from ufl.core.multiindex import MultiIndex
from ufl.core.operator import Operator
from ufl.core.ufl_type import ufl_type
from ufl.indexed import Indexed
from ufl.precedence import parstr

# --- Sum over an index ---
Expand Down Expand Up @@ -69,6 +70,11 @@ def ufl_shape(self):
"""Get UFL shape."""
return self.ufl_operands[0].ufl_shape

def _simplify_indexed(self, multiindex):
"""Return a simplified Expr used in the constructor of Indexed(self, multiindex)."""
A, i = self.ufl_operands
return IndexSum(Indexed(A, multiindex), i)

def evaluate(self, x, mapping, component, index_values):
"""Evaluate."""
(i,) = self.ufl_operands[1]
Expand Down

0 comments on commit d328fe5

Please sign in to comment.