Skip to content

Commit

Permalink
Add StableHLO complex log1p operation. Add pass stablehlo-complex-mat…
Browse files Browse the repository at this point in the history
…h-expander (#2636)

As in the title.

This PR introduces a new pass stablehlo-complex-math-expander that
expands StableHLO complex functions in terms of StableHLO real
functions. Currently, only StableHLO_Log1pOp on complex inputs is
included in this expander, more will be added as follow-ups.

The provided complex `log1p` operation fixes the inaccuracy problems in
jax.numpy.log1p for complex inputs `x+i*y` when `x` is close to `-0.5 *
y * y` which triggers catastrophic cancellations when using
straightforward definition of log1p on complex inputs. The current state
of `jax.numpy.log1p` inaccuracies is given in
pearu/functional_algorithms#47 .

With this PR, the accuracy statistics of log1p is:
```
complex64:
ULP difference == 0 count is 961936
ULP difference == 1 count is 39644
ULP difference == 2 count is 455
ULP difference == 3 count is 0
ULP difference >= 4 count is 0

complex128:
ULP difference == 0 count is 988144
ULP difference == 1 count is 13891
ULP difference == 2 count is 0
ULP difference == 3 count is 0
ULP difference >= 4 count is 0
```
  • Loading branch information
pearu authored Dec 19, 2024
1 parent 33f19a4 commit adbeeca
Show file tree
Hide file tree
Showing 17 changed files with 724 additions and 46 deletions.
17 changes: 17 additions & 0 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,21 @@ gentbl_cc_library(
],
)

gentbl_cc_library(
name = "stablehlo_create_complex_math_expander_inc_gen",
tbl_outs = [
(
["--gen-rewriters"],
"stablehlo/transforms/StablehloComplexMathExpanderPatterns.h.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "stablehlo/transforms/StablehloComplexMathExpanderPatterns.td",
deps = [
":stablehlo_ops_td_files",
],
)

cc_library(
name = "interpreter_ops",
srcs = [
Expand Down Expand Up @@ -1121,6 +1136,7 @@ cc_library(
"stablehlo/transforms/StablehloAggressiveSimplification.cpp",
"stablehlo/transforms/StablehloCanonicalizeDynamism.cpp",
"stablehlo/transforms/StablehloCompatibilityExpander.cpp",
"stablehlo/transforms/StablehloComplexMathExpander.cpp",
"stablehlo/transforms/StablehloConvertToSignless.cpp",
"stablehlo/transforms/StablehloLegalizeCompositeToCall.cpp",
"stablehlo/transforms/StablehloLegalizeDeprecatedOps.cpp",
Expand Down Expand Up @@ -1149,6 +1165,7 @@ cc_library(
":linalg_passes",
":stablehlo_aggressive_simplification_inc_gen",
":stablehlo_create_compatibility_expander_inc_gen",
":stablehlo_create_complex_math_expander_inc_gen",
":stablehlo_legalize_deprecated_ops_inc_gen",
":stablehlo_ops",
":stablehlo_ops_inc_gen",
Expand Down
16 changes: 14 additions & 2 deletions build_tools/math/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ following requirements:

- Python 3.11 or newer
- mpmath 1.3 or newer
- functional_algorithms 0.11.1 or newer
- functional_algorithms 0.12 or newer

that can be installed via pypi:

Expand Down Expand Up @@ -62,7 +62,7 @@ To execute generated tests from a `build` directory, use:

```sh
for t in $(ls ../stablehlo/tests/math/*.mlir); \
do echo $t && ( bin/stablehlo-opt --chlo-legalize-to-stablehlo $t \
do echo $t && ( bin/stablehlo-opt --stablehlo-complex-math-expander --chlo-legalize-to-stablehlo $t \
| bin/stablehlo-translate --interpret 2>&1 | grep "^ULP difference" ) ; done
```

Expand All @@ -77,6 +77,14 @@ build/bin/stablehlo-opt --chlo-legalize-to-stablehlo --split-input-file --verify

and copy relevant checks to `chlo_legalize_to_stablehlo.mlir`.

A similar procedure is applied for updating
`stablehlo/tests/stablehlo_complex_math_expander.mlir`:

```sh
build/bin/stablehlo-opt --stablehlo-complex-math-expander --split-input-file --verify-diagnostics \
stablehlo/tests/stablehlo_complex_math_expander.mlir | python llvm-project/mlir/utils/generate-test-checks.py | less
```

## A procedure for adding a new algorithm to an existing operation

1. Implement a new algorithm in
Expand All @@ -98,6 +106,10 @@ and copy relevant checks to `chlo_legalize_to_stablehlo.mlir`.
7. Add a record of the operation to
`generate_ChloDecompositionPatternsMath.py`, see the for-loop in
`main` function.
- If the operation is a StableHLO operation on complex inputs, add
it to `stable-complex-math-expander` pass: update
`populateStablehloComplexMathExpanderPatterns` function in
`stablehlo/transforms/StablehloComplexMathExpander.cpp`.
8. Generate new implementations by running
`generate_ChloDecompositionPatternsMath.py` and remove existing
implementations in
Expand Down
57 changes: 45 additions & 12 deletions build_tools/math/generate_ChloDecompositionPatternsMath.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def get_functional_algorithms_required_version():
)


def main():
def main(kind="CHLO"):
try:
import functional_algorithms as fa
except ImportError as msg:
Expand All @@ -64,16 +64,15 @@ def main():
warnings.warn(msg)
return

output_filename = dict(
CHLO="ChloDecompositionPatternsMath.td",
StableHLO="StablehloComplexMathExpanderPatterns.td",
)[kind]

output_file = os.path.relpath(
os.path.normpath(
os.path.join(
os.path.dirname(__file__),
"..",
"..",
"stablehlo",
"transforms",
"ChloDecompositionPatternsMath.td",
)),
os.path.join(os.path.dirname(__file__), "..", "..", "stablehlo",
"transforms", output_filename)),
os.getcwd(),
)

Expand All @@ -98,13 +97,15 @@ def main():
("CHLO_AtanhOp", "complex_atanh", ("z:complex",)),
("CHLO_SquareOp", "complex_square", ("z:complex",)),
("CHLO_SquareOp", "real_square", ("x:float",)),
("StableHLO_Log1pOp", "complex_log1p", ("z:complex",)),
]:
if not chloname.startswith(kind):
continue
print(f'Generating {chloname} from {fname}{args}')
func = getattr(fa.algorithms, fname, None)
if func is None:
warnings.warn(
f"{fa.algorithms.__name__} does not define {fname}. Skipping."
)
f"{fa.algorithms.__name__} does not define {fname}. Skipping.")
continue
ctx = fa.Context(paths=[fa.algorithms],
parameters=dict(rewrite_keep_integer_literals=True))
Expand All @@ -115,6 +116,16 @@ def main():
sources[-1] += src
source = "\n\n".join(sources) + "\n"

if chloname.startswith('StableHLO_'):
# an ugly hack to fix the definition of stablehlo complex math
# functions. TODO(pearu): add the corresponding feature to
# functional_algorithms stablehlo printer
NameOp = chloname.split('_', 1)[1]
source = source.replace(
f'def : Pat<({chloname}',
f'def {NameOp}_ComplexElementType_ComplexMathExpander : Pat<({chloname}'
)

if os.path.isfile(output_file):
f = open(output_file, "r")
content = f.read()
Expand Down Expand Up @@ -146,10 +157,32 @@ def main():
This file is generated using functional_algorithms tool ({fa.__version__}).
See build_tools/math/README.md for more information.""") + "\n")

if kind == "StableHLO":
f.write("""\
include "mlir/IR/OpBase.td"
include "stablehlo/dialect/StablehloOps.td"
class StableHLO_ComparisonDirectionValue<string enumStr> :
ConstantAttr<StableHLO_ComparisonDirectionAttr,
"::mlir::stablehlo::ComparisonDirection::" # enumStr>;
class StableHLO_ConstantLike<string value> : NativeCodeCall<
"::mlir::stablehlo::getConstantLike($_builder, $_loc, " # value # ", $0)">;
def ComplexElementType : Type<
CPred<"isa<ComplexType>(cast<ShapedType>($_self).getElementType())">,
"Complex element type">;
def StableHLO_ConstantLikeMaxFiniteValue : NativeCodeCall<
"::mlir::stablehlo::getConstantLikeMaxFiniteValue($_builder, $_loc, $0)">;
""")
f.write(source)
f.close()
print(f"Created {output_file}")


if __name__ == "__main__":
main()
main(kind="CHLO")
main(kind="StableHLO")
66 changes: 36 additions & 30 deletions build_tools/math/generate_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,27 +43,31 @@
default_max_ulp_difference = 1

operations = [
# The following dictionaries may have additional keys like
#
# size - defines the number of samples: size ** 2
#
# max_ulp_difference - the maximal allowed ULP difference between
# function and reference values
#
# extra_prec_multiplier - the precison multiplier for mpmath.mp
# that defines the precision of computing reference values:
# mpmath.mp.prec * extra_prec_multiplier
#
# When unspecifed, these parameters are retrieved from
# functional_algorithms database of support functions.
#
dict(name="asin", mpmath_name="arcsin"),
dict(name="acos", mpmath_name="arccos"),
dict(name="atan", mpmath_name="arctan"),
dict(name="asinh", mpmath_name="arcsinh"),
dict(name="acosh", mpmath_name="arccosh"),
dict(name="atanh", mpmath_name="arctanh"),
dict(name="square", mpmath_name="square"),
# The following dictionaries may have additional keys like
#
# size - defines the number of samples: size ** 2
#
# max_ulp_difference - the maximal allowed ULP difference between
# function and reference values
#
# extra_prec_multiplier - the precison multiplier for mpmath.mp
# that defines the precision of computing reference values:
# mpmath.mp.prec * extra_prec_multiplier
#
# When unspecifed, these parameters are retrieved from
# functional_algorithms database of support functions.
#
dict(name="asin", mpmath_name="arcsin"),
dict(name="acos", mpmath_name="arccos"),
dict(name="atan", mpmath_name="arctan"),
dict(name="asinh", mpmath_name="arcsinh"),
dict(name="acosh", mpmath_name="arccosh"),
dict(name="atanh", mpmath_name="arctanh"),
dict(name="square", mpmath_name="square"),
dict(name="log_plus_one",
mpmath_name="log1p",
namespace="stablehlo",
passes="--stablehlo-complex-math-expander"),
]


Expand Down Expand Up @@ -127,19 +131,21 @@ def main():
for op in operations:
opname = op["name"]
mpmath_opname = op.get("mpmath_name", opname)
namespace = op.get("namespace", "chlo")
size_re = size_im = op.get("size", default_size)

passes = op.get("passes", "--chlo-legalize-to-stablehlo")
for dtype in [np.complex64, np.complex128, np.float32, np.float64]:
params = fa.utils.function_validation_parameters(opname, dtype)
max_ulp_difference = op.get(
"max_ulp_difference",
params.get("max_valid_ulp_count", default_max_ulp_difference))
"max_ulp_difference",
params.get("max_valid_ulp_count", default_max_ulp_difference))

nmp = fa.utils.numpy_with_mpmath(
extra_prec_multiplier = op.get(
"extra_prec_multiplier",
params.get("extra_prec_multiplier", default_extra_prec_multiplier)),
flush_subnormals=flush_subnormals,
extra_prec_multiplier=op.get(
"extra_prec_multiplier",
params.get("extra_prec_multiplier",
default_extra_prec_multiplier)),
flush_subnormals=flush_subnormals,
)

fi = np.finfo(dtype)
Expand Down Expand Up @@ -180,7 +186,7 @@ def main():
main_func = m.make_function("main", "", "", "public")

ref_samples = main_func.call("samples")
actual = main_func.composite(f"chlo.{opname}", ref_samples)
actual = main_func.composite(f"{namespace}.{opname}", ref_samples)
expected = main_func.call("expected")

main_func.void_call(
Expand All @@ -202,7 +208,7 @@ def main():
continue

f = open(fname, "w")
f.write("// RUN: stablehlo-opt --chlo-legalize-to-stablehlo %s |"
f.write(f"// RUN: stablehlo-opt {passes} %s |"
" stablehlo-translate --interpret\n")
f.write(
"// This file is generated, see build_tools/math/README.md for more"
Expand Down
27 changes: 27 additions & 0 deletions docs/generated/stablehlo_passes.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,33 @@ func.func @tan_op_non_complex(%arg0: tensor<4xf64>) -> tensor<4xf64> {
```
-target : The target version. Must be a version of the form #.#.#.
```
### `-stablehlo-complex-math-expander`

_Expander for StableHLO complex math operations._

StableHLO complex math operations are decompositions using
StableHLO real math operations.

This statement is based on the assumption that no hardware exists
that supports complex numbers nor complex math operations
natively. This means that the fallback mechanisms on complex math
operations that compilers may implement, are redundant. With
enabling this pass, all StableHLO complex math operations will be
expanded.

```mlir
func.func @sqrt_op_complex(%arg0: tensor<4xcomplex<f64>>) -> tensor<4xcomplex<f64>> {
%1 = stablehlo.sqrt %arg0 : tensor<4xcomplex<f64>>
func.return %1 : tensor<4xcomplex<f64>>
}
==>
func.func @sqrt_op_complex(%arg0: tensor<4xcomplex<f64>>) -> tensor<4xcomplex<f64>> {
TBD
return %2 : tensor<4xcomplex<f64>>
}
```
### `-stablehlo-convert-to-signless`

_Pass to transform the IR to be on signless integers._
Expand Down
3 changes: 2 additions & 1 deletion docs/spec.md
Original file line number Diff line number Diff line change
Expand Up @@ -4066,7 +4066,8 @@ Performs element-wise logarithm plus one operation on `operand` tensor and
produces a `result` tensor. Depending on the element type, does the following:

* For floats: `logp1` from IEEE-754.
* For complex numbers: complex logarithm plus one.
* For complex numbers:
`complex(log(hypot(real(x) + 1, imag(x))), atan2(imag(x), real(x) + 1))`
* For quantized types:
`dequantize_op_quantize(log_plus_one, operand, type(result))`.

Expand Down
Loading

0 comments on commit adbeeca

Please sign in to comment.