From adbeeca967549059dec9880fcc27806cdb78eb11 Mon Sep 17 00:00:00 2001 From: Pearu Peterson Date: Thu, 19 Dec 2024 17:36:17 +0200 Subject: [PATCH] Add StableHLO complex log1p operation. Add pass stablehlo-complex-math-expander (#2636) As in the title. This PR introduces a new pass stablehlo-complex-math-expander that expands StableHLO complex functions in terms of StableHLO real functions. Currently, only StableHLO_Log1pOp on complex inputs is included in this expander; more will be added as follow-ups. The provided complex `log1p` operation fixes the inaccuracy problems in jax.numpy.log1p for complex inputs `x+i*y` when `x` is close to `-0.5 * y * y` which triggers catastrophic cancellations when using the straightforward definition of log1p on complex inputs. The current state of `jax.numpy.log1p` inaccuracies is given in https://github.com/pearu/functional_algorithms/issues/47 . With this PR, the accuracy statistics of log1p are: ``` complex64: ULP difference == 0 count is 961936 ULP difference == 1 count is 39644 ULP difference == 2 count is 455 ULP difference == 3 count is 0 ULP difference >= 4 count is 0 complex128: ULP difference == 0 count is 988144 ULP difference == 1 count is 13891 ULP difference == 2 count is 0 ULP difference == 3 count is 0 ULP difference >= 4 count is 0 ``` --- BUILD.bazel | 17 ++ build_tools/math/README.md | 16 +- .../generate_ChloDecompositionPatternsMath.py | 57 +++- build_tools/math/generate_tests.py | 66 +++-- docs/generated/stablehlo_passes.md | 27 ++ docs/spec.md | 3 +- .../tests/math/log_plus_one_complex128.mlir | 19 ++ .../tests/math/log_plus_one_complex64.mlir | 19 ++ .../tests/math/log_plus_one_float32.mlir | 19 ++ .../tests/math/log_plus_one_float64.mlir | 19 ++ .../stablehlo_complex_math_expander.mlir | 110 +++++++ stablehlo/transforms/CMakeLists.txt | 6 + .../ChloDecompositionPatternsMath.td | 2 +- stablehlo/transforms/Passes.h | 5 + stablehlo/transforms/Passes.td | 34 +++ .../StablehloComplexMathExpander.cpp | 76 +++++ 
.../StablehloComplexMathExpanderPatterns.td | 275 ++++++++++++++++++ 17 files changed, 724 insertions(+), 46 deletions(-) create mode 100644 stablehlo/tests/math/log_plus_one_complex128.mlir create mode 100644 stablehlo/tests/math/log_plus_one_complex64.mlir create mode 100644 stablehlo/tests/math/log_plus_one_float32.mlir create mode 100644 stablehlo/tests/math/log_plus_one_float64.mlir create mode 100644 stablehlo/tests/stablehlo_complex_math_expander.mlir create mode 100644 stablehlo/transforms/StablehloComplexMathExpander.cpp create mode 100644 stablehlo/transforms/StablehloComplexMathExpanderPatterns.td diff --git a/BUILD.bazel b/BUILD.bazel index 8d2b26c32a..866e971507 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -371,6 +371,21 @@ gentbl_cc_library( ], ) +gentbl_cc_library( + name = "stablehlo_create_complex_math_expander_inc_gen", + tbl_outs = [ + ( + ["--gen-rewriters"], + "stablehlo/transforms/StablehloComplexMathExpanderPatterns.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "stablehlo/transforms/StablehloComplexMathExpanderPatterns.td", + deps = [ + ":stablehlo_ops_td_files", + ], +) + cc_library( name = "interpreter_ops", srcs = [ @@ -1121,6 +1136,7 @@ cc_library( "stablehlo/transforms/StablehloAggressiveSimplification.cpp", "stablehlo/transforms/StablehloCanonicalizeDynamism.cpp", "stablehlo/transforms/StablehloCompatibilityExpander.cpp", + "stablehlo/transforms/StablehloComplexMathExpander.cpp", "stablehlo/transforms/StablehloConvertToSignless.cpp", "stablehlo/transforms/StablehloLegalizeCompositeToCall.cpp", "stablehlo/transforms/StablehloLegalizeDeprecatedOps.cpp", @@ -1149,6 +1165,7 @@ cc_library( ":linalg_passes", ":stablehlo_aggressive_simplification_inc_gen", ":stablehlo_create_compatibility_expander_inc_gen", + ":stablehlo_create_complex_math_expander_inc_gen", ":stablehlo_legalize_deprecated_ops_inc_gen", ":stablehlo_ops", ":stablehlo_ops_inc_gen", diff --git a/build_tools/math/README.md 
b/build_tools/math/README.md index 8e3a8f180d..e2f10a8881 100644 --- a/build_tools/math/README.md +++ b/build_tools/math/README.md @@ -31,7 +31,7 @@ following requirements: - Python 3.11 or newer - mpmath 1.3 or newer -- functional_algorithms 0.11.1 or newer +- functional_algorithms 0.12 or newer that can be installed via pypi: @@ -62,7 +62,7 @@ To execute generated tests from a `build` directory, use: ```sh for t in $(ls ../stablehlo/tests/math/*.mlir); \ -do echo $t && ( bin/stablehlo-opt --chlo-legalize-to-stablehlo $t \ +do echo $t && ( bin/stablehlo-opt --stablehlo-complex-math-expander --chlo-legalize-to-stablehlo $t \ | bin/stablehlo-translate --interpret 2>&1 | grep "^ULP difference" ) ; done ``` @@ -77,6 +77,14 @@ build/bin/stablehlo-opt --chlo-legalize-to-stablehlo --split-input-file --verify and copy relevant checks to `chlo_legalize_to_stablehlo.mlir`. +A similar procedure is applied for updating +`stablehlo/tests/stablehlo_complex_math_expander.mlir`: + +```sh +build/bin/stablehlo-opt --stablehlo-complex-math-expander --split-input-file --verify-diagnostics \ + stablehlo/tests/stablehlo_complex_math_expander.mlir | python llvm-project/mlir/utils/generate-test-checks.py | less +``` + ## A procedure for adding a new algorithm to an existing operation 1. Implement a new algorithm in @@ -98,6 +106,10 @@ and copy relevant checks to `chlo_legalize_to_stablehlo.mlir`. 7. Add a record of the operation to `generate_ChloDecompositionPatternsMath.py`, see the for-loop in `main` function. + - If the operation is a StableHLO operation on complex inputs, add + it to the `stablehlo-complex-math-expander` pass: update + the `populateStablehloComplexMathExpanderPatterns` function in + `stablehlo/transforms/StablehloComplexMathExpander.cpp`. 8. 
Generate new implementations by running `generate_ChloDecompositionPatternsMath.py` and remove existing implementations in diff --git a/build_tools/math/generate_ChloDecompositionPatternsMath.py b/build_tools/math/generate_ChloDecompositionPatternsMath.py index 885849ce89..62b99474dc 100644 --- a/build_tools/math/generate_ChloDecompositionPatternsMath.py +++ b/build_tools/math/generate_ChloDecompositionPatternsMath.py @@ -44,7 +44,7 @@ def get_functional_algorithms_required_version(): ) -def main(): +def main(kind="CHLO"): try: import functional_algorithms as fa except ImportError as msg: @@ -64,16 +64,15 @@ def main(): warnings.warn(msg) return + output_filename = dict( + CHLO="ChloDecompositionPatternsMath.td", + StableHLO="StablehloComplexMathExpanderPatterns.td", + )[kind] + output_file = os.path.relpath( os.path.normpath( - os.path.join( - os.path.dirname(__file__), - "..", - "..", - "stablehlo", - "transforms", - "ChloDecompositionPatternsMath.td", - )), + os.path.join(os.path.dirname(__file__), "..", "..", "stablehlo", + "transforms", output_filename)), os.getcwd(), ) @@ -98,13 +97,15 @@ def main(): ("CHLO_AtanhOp", "complex_atanh", ("z:complex",)), ("CHLO_SquareOp", "complex_square", ("z:complex",)), ("CHLO_SquareOp", "real_square", ("x:float",)), + ("StableHLO_Log1pOp", "complex_log1p", ("z:complex",)), ]: + if not chloname.startswith(kind): + continue print(f'Generating {chloname} from {fname}{args}') func = getattr(fa.algorithms, fname, None) if func is None: warnings.warn( - f"{fa.algorithms.__name__} does not define {fname}. Skipping." - ) + f"{fa.algorithms.__name__} does not define {fname}. Skipping.") continue ctx = fa.Context(paths=[fa.algorithms], parameters=dict(rewrite_keep_integer_literals=True)) @@ -115,6 +116,16 @@ def main(): sources[-1] += src source = "\n\n".join(sources) + "\n" + if chloname.startswith('StableHLO_'): + # an ugly hack to fix the definition of stablehlo complex math + # functions. 
TODO(pearu): add the corresponding feature to + # functional_algorithms stablehlo printer + NameOp = chloname.split('_', 1)[1] + source = source.replace( + f'def : Pat<({chloname}', + f'def {NameOp}_ComplexElementType_ComplexMathExpander : Pat<({chloname}' + ) + if os.path.isfile(output_file): f = open(output_file, "r") content = f.read() @@ -146,10 +157,32 @@ def main(): This file is generated using functional_algorithms tool ({fa.__version__}). See build_tools/math/README.md for more information.""") + "\n") + + if kind == "StableHLO": + f.write("""\ +include "mlir/IR/OpBase.td" +include "stablehlo/dialect/StablehloOps.td" + +class StableHLO_ComparisonDirectionValue : + ConstantAttr; + +class StableHLO_ConstantLike : NativeCodeCall< + "::mlir::stablehlo::getConstantLike($_builder, $_loc, " # value # ", $0)">; + +def ComplexElementType : Type< + CPred<"isa(cast($_self).getElementType())">, + "Complex element type">; + +def StableHLO_ConstantLikeMaxFiniteValue : NativeCodeCall< + "::mlir::stablehlo::getConstantLikeMaxFiniteValue($_builder, $_loc, $0)">; + +""") f.write(source) f.close() print(f"Created {output_file}") if __name__ == "__main__": - main() + main(kind="CHLO") + main(kind="StableHLO") diff --git a/build_tools/math/generate_tests.py b/build_tools/math/generate_tests.py index ecd413cda2..fe20a1ae0b 100644 --- a/build_tools/math/generate_tests.py +++ b/build_tools/math/generate_tests.py @@ -43,27 +43,31 @@ default_max_ulp_difference = 1 operations = [ - # The following dictionaries may have additional keys like - # - # size - defines the number of samples: size ** 2 - # - # max_ulp_difference - the maximal allowed ULP difference between - # function and reference values - # - # extra_prec_multiplier - the precison multiplier for mpmath.mp - # that defines the precision of computing reference values: - # mpmath.mp.prec * extra_prec_multiplier - # - # When unspecifed, these parameters are retrieved from - # functional_algorithms database of support 
functions. - # - dict(name="asin", mpmath_name="arcsin"), - dict(name="acos", mpmath_name="arccos"), - dict(name="atan", mpmath_name="arctan"), - dict(name="asinh", mpmath_name="arcsinh"), - dict(name="acosh", mpmath_name="arccosh"), - dict(name="atanh", mpmath_name="arctanh"), - dict(name="square", mpmath_name="square"), + # The following dictionaries may have additional keys like + # + # size - defines the number of samples: size ** 2 + # + # max_ulp_difference - the maximal allowed ULP difference between + # function and reference values + # + # extra_prec_multiplier - the precision multiplier for mpmath.mp + # that defines the precision of computing reference values: + # mpmath.mp.prec * extra_prec_multiplier + # + # When unspecified, these parameters are retrieved from + # functional_algorithms database of support functions. + # + dict(name="asin", mpmath_name="arcsin"), + dict(name="acos", mpmath_name="arccos"), + dict(name="atan", mpmath_name="arctan"), + dict(name="asinh", mpmath_name="arcsinh"), + dict(name="acosh", mpmath_name="arccosh"), + dict(name="atanh", mpmath_name="arctanh"), + dict(name="square", mpmath_name="square"), + dict(name="log_plus_one", + mpmath_name="log1p", + namespace="stablehlo", + passes="--stablehlo-complex-math-expander"), ] @@ -127,19 +131,21 @@ def main(): for op in operations: opname = op["name"] mpmath_opname = op.get("mpmath_name", opname) + namespace = op.get("namespace", "chlo") size_re = size_im = op.get("size", default_size) - + passes = op.get("passes", "--chlo-legalize-to-stablehlo") for dtype in [np.complex64, np.complex128, np.float32, np.float64]: params = fa.utils.function_validation_parameters(opname, dtype) max_ulp_difference = op.get( - "max_ulp_difference", - params.get("max_valid_ulp_count", default_max_ulp_difference)) + "max_ulp_difference", + params.get("max_valid_ulp_count", default_max_ulp_difference)) nmp = fa.utils.numpy_with_mpmath( - extra_prec_multiplier = op.get( - "extra_prec_multiplier", - 
params.get("extra_prec_multiplier", default_extra_prec_multiplier)), - flush_subnormals=flush_subnormals, + extra_prec_multiplier=op.get( + "extra_prec_multiplier", + params.get("extra_prec_multiplier", + default_extra_prec_multiplier)), + flush_subnormals=flush_subnormals, ) fi = np.finfo(dtype) @@ -180,7 +186,7 @@ def main(): main_func = m.make_function("main", "", "", "public") ref_samples = main_func.call("samples") - actual = main_func.composite(f"chlo.{opname}", ref_samples) + actual = main_func.composite(f"{namespace}.{opname}", ref_samples) expected = main_func.call("expected") main_func.void_call( @@ -202,7 +208,7 @@ def main(): continue f = open(fname, "w") - f.write("// RUN: stablehlo-opt --chlo-legalize-to-stablehlo %s |" + f.write(f"// RUN: stablehlo-opt {passes} %s |" " stablehlo-translate --interpret\n") f.write( "// This file is generated, see build_tools/math/README.md for more" diff --git a/docs/generated/stablehlo_passes.md b/docs/generated/stablehlo_passes.md index 4f5b6e7447..2fdda99136 100755 --- a/docs/generated/stablehlo_passes.md +++ b/docs/generated/stablehlo_passes.md @@ -86,6 +86,33 @@ func.func @tan_op_non_complex(%arg0: tensor<4xf64>) -> tensor<4xf64> { ``` -target : The target version. Must be a version of the form #.#.#. ``` +### `-stablehlo-complex-math-expander` + +_Expander for StableHLO complex math operations._ + +StableHLO complex math operations are decompositions using +StableHLO real math operations. + +This statement is based on the assumption that no hardware exists +that supports complex numbers nor complex math operations +natively. This means that the fallback mechanisms on complex math +operations that compilers may implement, are redundant. With +enabling this pass, all StableHLO complex math operations will be +expanded. 
+ +```mlir +func.func @sqrt_op_complex(%arg0: tensor<4xcomplex>) -> tensor<4xcomplex> { + %1 = stablehlo.sqrt %arg0 : tensor<4xcomplex> + func.return %1 : tensor<4xcomplex> +} + +==> + +func.func @sqrt_op_complex(%arg0: tensor<4xcomplex>) -> tensor<4xcomplex> { + TBD + return %2 : tensor<4xcomplex> +} +``` ### `-stablehlo-convert-to-signless` _Pass to transform the IR to be on signless integers._ diff --git a/docs/spec.md b/docs/spec.md index 48760dbd96..e59de3f716 100644 --- a/docs/spec.md +++ b/docs/spec.md @@ -4066,7 +4066,8 @@ Performs element-wise logarithm plus one operation on `operand` tensor and produces a `result` tensor. Depending on the element type, does the following: * For floats: `logp1` from IEEE-754. -* For complex numbers: complex logarithm plus one. +* For complex numbers: + `complex(log(hypot(real(x) + 1, imag(x))), atan2(imag(x), real(x) + 1))` * For quantized types: `dequantize_op_quantize(log_plus_one, operand, type(result))`. diff --git a/stablehlo/tests/math/log_plus_one_complex128.mlir b/stablehlo/tests/math/log_plus_one_complex128.mlir new file mode 100644 index 0000000000..5d95770d4d --- /dev/null +++ b/stablehlo/tests/math/log_plus_one_complex128.mlir @@ -0,0 +1,19 @@ +// RUN: stablehlo-opt --stablehlo-complex-math-expander %s | stablehlo-translate --interpret +// This file is generated, see build_tools/math/README.md for more information. 
+module @log_plus_one_complex128 { + func.func private @samples() -> tensor<169xcomplex> { + %0 = stablehlo.constant dense<"0xtensor<169xcomplex> + return %0 : tensor<169xcomplex> + } + func.func private @expected() -> tensor<169xcomplex> { + %0 = stablehlo.constant dense<"0x000000000000F07FD221337F7CD902C0000000000000F07F182D4454FB21F9BF000000000000F07F182D4454FB21F9BF000000000000F07F182D4454FB21F9BF000000000000F07F182D4454FB21F9BF000000000000F07F182D4454FB21F9BF000000000000F07F182D4454FB21F9BF000000000000F07F182D4454FB21F9BF000000000000F07F182D4454FB21F9BF000000000000F07F182D4454FB21F9BF000000000000F07F182D4454FB21F9BF000000000000F07F182D4454FB21F9BF000000000000F07F182D4454FB21E9BF000000000000F07F182D4454FB2109C036195AC708318640D221337F7CD902C036195AC708318640D221337F7CD902C0EF39FAFE422E8640182D4454FB21F9BFEF39FAFE422E8640182D4454FB21F9BFEF39FAFE422E8640182D4454FB21F9BFEF39FAFE422E8640182D4454FB21F9BFEF39FAFE422E8640182D4454FB21F9BFEF39FAFE422E8640182D4454FB21F9BFEF39FAFE422E8640182D4454FB21F9BF36195AC708318640192D4454FB21E9BF36195AC708318640182D4454FB21E9BF000000000000F07F0000000000000000000000000000F07F182D4454FB2109C036195AC708318640D221337F7CD902C036195AC708318640D221337F7CD902C0EF39FAFE422E8640182D4454FB21F9BFEF39FAFE422E8640182D4454FB21F9BFEF39FAFE422E8640182D4454FB21F9BFEF39FAFE422E8640182D4454FB21F9BFEF39FAFE422E8640182D4454FB21F9BFEF39FAFE422E8640182D4454FB21F9BFEF39FAFE422E8640182D4454FB21F9BF36195AC708318640182D4454FB21E9BF36195AC708318640182D4454FB21E9BF000000000000F07F0000000000000000000000000000F07F182D4454FB2109C0EF39FAFE422E8640182D4454FB2109C0EF39FAFE422E8640182D4454FB2109C078E0E0F04052DD3FD1D40D3DDF47FEBFB3DE6857C5DBE23F9BF681D20B73EFBFB3DE6857C5DBE23F9BF681D20B73EFBFB3DE6857C5DBE23F9BF681D20B73EFBFB3DE6857C5DBE23F9BF681D20B73EFBFB3DE6857C5DBE23F9BF681D20B73EFBFFEF03B02DB1EF13FE10CF9D51D4BE1BFEF39FAFE422E86400000000000000680EF39FAFE422E86400000000000000680000000000000F07F0000000000000000000000000000F07F182D4454FB2109C0EF39FAFE422E8640182D4454FB21
09C0EF39FAFE422E8640182D4454FB2109C0EF39FAFE422EE6BF182D4454FB2109C0000000000000FC9F000000000000FC9FFFFFFFFFFF1F0600000000000000FC9F0000000000200600000000000000FC9F0100000000200600000000000000FC9F000000000000FC1F000000000000FC9F78E0E0F04052ED3F666666666666E69FEF39FAFE422E86400000000000000080EF39FAFE422E86400000000000000080000000000000F07F0000000000000000000000000000F07F182D4454FB2109C0EF39FAFE422E8640182D4454FB2109C0EF39FAFE422E8640182D4454FB2109C0EF39FAFE422EE6BF182D4454FB2109C0000000000000FC9F0100000000000080010000000000008001000000000000800000000000000000010000000000008001000000000000000100000000000080000000000000FC1F010000000000008078E0E0F04052ED3F0000000000000080EF39FAFE422E86400000000000000080EF39FAFE422E86400000000000000080000000000000F07F0000000000000000000000000000F07F182D4454FB210940EF39FAFE422E8640182D4454FB210940EF39FAFE422E8640182D4454FB210940EF39FAFE422EE6BF182D4454FB210940000000000000FC9F0000000000000000010000000000008000000000000000000000000000000000000000000000000001000000000000000000000000000000000000000000FC1F000000000000000078E0E0F04052ED3F0000000000000000EF39FAFE422E86400000000000000000EF39FAFE422E86400000000000000000000000000000F07F0000000000000000000000000000F07F182D4454FB210940EF39FAFE422E8640182D4454FB210940EF39FAFE422E8640182D4454FB210940EF39FAFE422EE6BF182D4454FB210940000000000000FC9F0100000000000000010000000000008001000000000000000000000000000000010000000000000001000000000000000100000000000000000000000000FC1F010000000000000078E0E0F04052ED3F0000000000000000EF39FAFE422E86400000000000000000EF39FAFE422E86400000000000000000000000000000F07F0000000000000000000000000000F07F182D4454FB210940EF39FAFE422E8640182D4454FB210940EF39FAFE422E8640182D4454FB210940EF39FAFE422EE6BF182D4454FB210940000000000000FC9F000000000000FC1FFFFFFFFFFF1F0600000000000000FC1F0000000000200600000000000000FC1F0100000000200600000000000000FC1F000000000000FC1F000000000000FC1F78E0E0F04052ED3F666666666666E61FEF39FAFE422E86400000000000000000EF39FAFE422E86400000000000000000000000000000
F07F0000000000000000000000000000F07F182D4454FB210940EF39FAFE422E8640182D4454FB210940EF39FAFE422E8640182D4454FB21094078E0E0F04052DD3FD1D40D3DDF47FE3FB3DE6857C5DBE23F9BF681D20B73EF3FB3DE6857C5DBE23F9BF681D20B73EF3FB3DE6857C5DBE23F9BF681D20B73EF3FB3DE6857C5DBE23F9BF681D20B73EF3FB3DE6857C5DBE23F9BF681D20B73EF3FFEF03B02DB1EF13FE10CF9D51D4BE13FEF39FAFE422E86400000000000000600EF39FAFE422E86400000000000000600000000000000F07F0000000000000000000000000000F07F182D4454FB21094036195AC708318640D221337F7CD9024036195AC708318640D221337F7CD90240EF39FAFE422E8640182D4454FB21F93FEF39FAFE422E8640182D4454FB21F93FEF39FAFE422E8640182D4454FB21F93FEF39FAFE422E8640182D4454FB21F93FEF39FAFE422E8640182D4454FB21F93FEF39FAFE422E8640182D4454FB21F93FEF39FAFE422E8640182D4454FB21F93F36195AC708318640182D4454FB21E93F36195AC708318640182D4454FB21E93F000000000000F07F0000000000000000000000000000F07F182D4454FB21094036195AC708318640D221337F7CD9024036195AC708318640D221337F7CD90240EF39FAFE422E8640182D4454FB21F93FEF39FAFE422E8640182D4454FB21F93FEF39FAFE422E8640182D4454FB21F93FEF39FAFE422E8640182D4454FB21F93FEF39FAFE422E8640182D4454FB21F93FEF39FAFE422E8640182D4454FB21F93FEF39FAFE422E8640182D4454FB21F93F36195AC708318640192D4454FB21E93F36195AC708318640182D4454FB21E93F000000000000F07F0000000000000000000000000000F07FD221337F7CD90240000000000000F07F182D4454FB21F93F000000000000F07F182D4454FB21F93F000000000000F07F182D4454FB21F93F000000000000F07F182D4454FB21F93F000000000000F07F182D4454FB21F93F000000000000F07F182D4454FB21F93F000000000000F07F182D4454FB21F93F000000000000F07F182D4454FB21F93F000000000000F07F182D4454FB21F93F000000000000F07F182D4454FB21F93F000000000000F07F182D4454FB21F93F000000000000F07F182D4454FB21E93F"> : tensor<169xcomplex> + return %0 : tensor<169xcomplex> + } + func.func public @main() { + %0 = call @samples() : () -> tensor<169xcomplex> + %1 = "stablehlo.log_plus_one"(%0) : (tensor<169xcomplex>) -> tensor<169xcomplex> + %2 = call @expected() : () -> tensor<169xcomplex> + check.expect_close %1, %2, 
max_ulp_difference = 3 : tensor<169xcomplex>, tensor<169xcomplex> + func.return + } +} diff --git a/stablehlo/tests/math/log_plus_one_complex64.mlir b/stablehlo/tests/math/log_plus_one_complex64.mlir new file mode 100644 index 0000000000..e00c93251a --- /dev/null +++ b/stablehlo/tests/math/log_plus_one_complex64.mlir @@ -0,0 +1,19 @@ +// RUN: stablehlo-opt --stablehlo-complex-math-expander %s | stablehlo-translate --interpret +// This file is generated, see build_tools/math/README.md for more information. +module @log_plus_one_complex64 { + func.func private @samples() -> tensor<169xcomplex> { + %0 = stablehlo.constant dense<"0xtensor<169xcomplex> + return %0 : tensor<169xcomplex> + } + func.func private @expected() -> tensor<169xcomplex> { + %0 = stablehlo.constant dense<"0xtensor<169xcomplex> + return %0 : tensor<169xcomplex> + } + func.func public @main() { + %0 = call @samples() : () -> tensor<169xcomplex> + %1 = "stablehlo.log_plus_one"(%0) : (tensor<169xcomplex>) -> tensor<169xcomplex> + %2 = call @expected() : () -> tensor<169xcomplex> + check.expect_close %1, %2, max_ulp_difference = 3 : tensor<169xcomplex>, tensor<169xcomplex> + func.return + } +} diff --git a/stablehlo/tests/math/log_plus_one_float32.mlir b/stablehlo/tests/math/log_plus_one_float32.mlir new file mode 100644 index 0000000000..4384eadb70 --- /dev/null +++ b/stablehlo/tests/math/log_plus_one_float32.mlir @@ -0,0 +1,19 @@ +// RUN: stablehlo-opt --stablehlo-complex-math-expander %s | stablehlo-translate --interpret +// This file is generated, see build_tools/math/README.md for more information. 
+module @log_plus_one_float32 { + func.func private @samples() -> tensor<169xf32> { + %0 = stablehlo.constant dense<"0x000080FFFFFF7FFFFEFF7FFF05E763FC88DAD5FA0BCE47F98EC1B9F711B52BF695A89DF4189C0FF39B8F81F11E83F3EFA17665EE246AD7ECA75D49EB2B51BBE9AE442DE831389FE6B42B11E5371F83E3BA12F5E13D0667E0C1F9D8DE44ED4ADDC7E0BCDB4AD42EDACDC7A0D850BB12D7D3AE84D557A2F6D3DA9568D25D89DAD0E07C4CCF6370BECDE66330CC6957A2CAED4A14C9703E86C7F331F8C576256AC4F918DCC27C0C4EC10000C0BF83F331BE06E7A3BC89DA15BB0CCE87B98FC1F9B712B56BB696A8DDB4199C4FB39C8FC1B11F8333B0A276A5AE256A17ADA85D89AB2C51FBA9AF446DA83238DFA6B52B51A5381FC3A3BB1235A23E06A7A0C2F9189F45ED8A9DC8E0FC9B4BD46E9ACEC7E09851BB5297D4AEC49558A23694DB95A8925E891A91E17C8C8F6470FE8DE763708C6A57E28AEE4A5489713EC687F43138867725AA84FA181C837D0C8E810100008000000000010000007D0C8E01FA181C037725AA04F4313806713EC607EE4A54096A57E20AE763700C6470FE0DE17C8C0F5E891A11DB95A81258A23614D4AEC41551BB5217CEC7E0184BD46E1AC8E0FC1B45ED8A1DC2F9181F3E06A720BB123522381FC323B52B51253238DF26AF446D282C51FB29A85D892B256A172DA276A52E1F8333309C8FC131199C4F3396A8DD3412B56B368FC1F9370CCE873989DA153B06E7A33C83F3313E0000C03F7C0C4E41F918DC4276256A44F331F845703E8647ED4A14496957A24AE663304C6370BE4DE07C4C4F5D89DA50DA95685257A2F653D3AE845550BB1257CDC7A0584AD42E5AC7E0BC5B44ED4A5DC1F9D85E3D066760BA12F561371F8363B42B116531389F66AE442D682B51BB69A75D496B246AD76CA176656E1E83F36F9B8F8171189C0F7395A89D7411B52B768EC1B9770BCE477988DAD57A05E7637CFEFF7F7FFFFF7F7F0000807F"> : tensor<169xf32> + return %0 : tensor<169xf32> + } + func.func private @expected() -> tensor<169xf32> { + %0 = stablehlo.constant dense<"0xtensor<169xf32> + return %0 : tensor<169xf32> + } + func.func public @main() { + %0 = call @samples() : () -> tensor<169xf32> + %1 = "stablehlo.log_plus_one"(%0) : (tensor<169xf32>) -> tensor<169xf32> + %2 = call @expected() : () -> tensor<169xf32> + check.expect_close %1, %2, max_ulp_difference = 3 : tensor<169xf32>, tensor<169xf32> + func.return + } +} diff --git 
a/stablehlo/tests/math/log_plus_one_float64.mlir b/stablehlo/tests/math/log_plus_one_float64.mlir new file mode 100644 index 0000000000..8d0c0c7f1d --- /dev/null +++ b/stablehlo/tests/math/log_plus_one_float64.mlir @@ -0,0 +1,19 @@ +// RUN: stablehlo-opt --stablehlo-complex-math-expander %s | stablehlo-translate --interpret +// This file is generated, see build_tools/math/README.md for more information. +module @log_plus_one_float64 { + func.func private @samples() -> tensor<169xf64> { + %0 = stablehlo.constant dense<"0xtensor<169xf64> + return %0 : tensor<169xf64> + } + func.func private @expected() -> tensor<169xf64> { + %0 = stablehlo.constant dense<"0xtensor<169xf64> + return %0 : tensor<169xf64> + } + func.func public @main() { + %0 = call @samples() : () -> tensor<169xf64> + %1 = "stablehlo.log_plus_one"(%0) : (tensor<169xf64>) -> tensor<169xf64> + %2 = call @expected() : () -> tensor<169xf64> + check.expect_close %1, %2, max_ulp_difference = 3 : tensor<169xf64>, tensor<169xf64> + func.return + } +} diff --git a/stablehlo/tests/stablehlo_complex_math_expander.mlir b/stablehlo/tests/stablehlo_complex_math_expander.mlir new file mode 100644 index 0000000000..7a43f76d6e --- /dev/null +++ b/stablehlo/tests/stablehlo_complex_math_expander.mlir @@ -0,0 +1,110 @@ +// RUN: stablehlo-opt --stablehlo-complex-math-expander --split-input-file --verify-diagnostics %s | FileCheck %s + +// CHECK-LABEL: func.func @log_plus_one_complex_f32( +// CHECK-SAME: %[[VAL_0:.*]]: tensor>) -> tensor> { +// CHECK: %[[VAL_1:.*]] = stablehlo.constant dense<6.500000e+01> : tensor +// CHECK: %[[VAL_2:.*]] = stablehlo.constant dense<4.097000e+03> : tensor +// CHECK: %[[VAL_3:.*]] = stablehlo.constant dense<9.99999968E+37> : tensor +// CHECK: %[[VAL_4:.*]] = stablehlo.constant dense<0x4D000000> : tensor +// CHECK: %[[VAL_5:.*]] = stablehlo.constant dense<0x7F800000> : tensor +// CHECK: %[[VAL_6:.*]] = stablehlo.constant dense<2.000000e-01> : tensor +// CHECK: %[[VAL_7:.*]] = 
stablehlo.constant dense<1.000000e+00> : tensor +// CHECK: %[[VAL_8:.*]] = stablehlo.constant dense<5.000000e-01> : tensor +// CHECK: %[[VAL_9:.*]] = stablehlo.constant dense<0.00999999977> : tensor +// CHECK: %[[VAL_10:.*]] = stablehlo.constant dense<3.40282347E+38> : tensor +// CHECK: %[[VAL_11:.*]] = stablehlo.real %[[VAL_0]] : (tensor>) -> tensor +// CHECK: %[[VAL_12:.*]] = stablehlo.abs %[[VAL_11]] : tensor +// CHECK: %[[VAL_13:.*]] = stablehlo.imag %[[VAL_0]] : (tensor>) -> tensor +// CHECK: %[[VAL_14:.*]] = stablehlo.abs %[[VAL_13]] : tensor +// CHECK: %[[VAL_15:.*]] = stablehlo.maximum %[[VAL_12]], %[[VAL_14]] : tensor +// CHECK: %[[VAL_16:.*]] = stablehlo.sqrt %[[VAL_10]] : tensor +// CHECK: %[[VAL_17:.*]] = stablehlo.multiply %[[VAL_16]], %[[VAL_9]] : tensor +// CHECK: %[[VAL_18:.*]] = stablehlo.compare GT, %[[VAL_15]], %[[VAL_17]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_19:.*]] = stablehlo.log %[[VAL_15]] : tensor +// CHECK: %[[VAL_20:.*]] = stablehlo.minimum %[[VAL_12]], %[[VAL_14]] : tensor +// CHECK: %[[VAL_21:.*]] = stablehlo.compare EQ, %[[VAL_20]], %[[VAL_15]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_22:.*]] = stablehlo.divide %[[VAL_20]], %[[VAL_15]] : tensor +// CHECK: %[[VAL_23:.*]] = stablehlo.multiply %[[VAL_22]], %[[VAL_22]] : tensor +// CHECK: %[[VAL_24:.*]] = stablehlo.select %[[VAL_21]], %[[VAL_7]], %[[VAL_23]] : tensor, tensor +// CHECK: %[[VAL_25:.*]] = stablehlo.log_plus_one %[[VAL_24]] : tensor +// CHECK: %[[VAL_26:.*]] = stablehlo.multiply %[[VAL_8]], %[[VAL_25]] : tensor +// CHECK: %[[VAL_27:.*]] = stablehlo.add %[[VAL_19]], %[[VAL_26]] : tensor +// CHECK: %[[VAL_28:.*]] = stablehlo.add %[[VAL_11]], %[[VAL_7]] : tensor +// CHECK: %[[VAL_29:.*]] = stablehlo.abs %[[VAL_28]] : tensor +// CHECK: %[[VAL_30:.*]] = stablehlo.add %[[VAL_29]], %[[VAL_14]] : tensor +// CHECK: %[[VAL_31:.*]] = stablehlo.compare LT, %[[VAL_30]], %[[VAL_6]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_32:.*]] = stablehlo.multiply %[[VAL_28]], 
%[[VAL_28]] : tensor +// CHECK: %[[VAL_33:.*]] = stablehlo.multiply %[[VAL_13]], %[[VAL_13]] : tensor +// CHECK: %[[VAL_34:.*]] = stablehlo.add %[[VAL_32]], %[[VAL_33]] : tensor +// CHECK: %[[VAL_35:.*]] = stablehlo.log %[[VAL_34]] : tensor +// CHECK: %[[VAL_36:.*]] = stablehlo.multiply %[[VAL_8]], %[[VAL_35]] : tensor +// CHECK: %[[VAL_37:.*]] = stablehlo.add %[[VAL_11]], %[[VAL_11]] : tensor +// CHECK: %[[VAL_38:.*]] = stablehlo.add %[[VAL_37]], %[[VAL_33]] : tensor +// CHECK: %[[VAL_39:.*]] = stablehlo.multiply %[[VAL_11]], %[[VAL_11]] : tensor +// CHECK: %[[VAL_40:.*]] = stablehlo.add %[[VAL_38]], %[[VAL_39]] : tensor +// CHECK: %[[VAL_41:.*]] = stablehlo.negate %[[VAL_33]] : tensor +// CHECK: %[[VAL_42:.*]] = stablehlo.compare GT, %[[VAL_10]], %[[VAL_5]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_43:.*]] = stablehlo.compare GT, %[[VAL_10]], %[[VAL_3]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_44:.*]] = stablehlo.select %[[VAL_43]], %[[VAL_2]], %[[VAL_1]] : tensor, tensor +// CHECK: %[[VAL_45:.*]] = stablehlo.select %[[VAL_42]], %[[VAL_4]], %[[VAL_44]] : tensor, tensor +// CHECK: %[[VAL_46:.*]] = stablehlo.multiply %[[VAL_45]], %[[VAL_13]] : tensor +// CHECK: %[[VAL_47:.*]] = stablehlo.subtract %[[VAL_13]], %[[VAL_46]] : tensor +// CHECK: %[[VAL_48:.*]] = stablehlo.add %[[VAL_46]], %[[VAL_47]] : tensor +// CHECK: %[[VAL_49:.*]] = stablehlo.multiply %[[VAL_48]], %[[VAL_48]] : tensor +// CHECK: %[[VAL_50:.*]] = stablehlo.add %[[VAL_41]], %[[VAL_49]] : tensor +// CHECK: %[[VAL_51:.*]] = stablehlo.subtract %[[VAL_13]], %[[VAL_48]] : tensor +// CHECK: %[[VAL_52:.*]] = stablehlo.multiply %[[VAL_48]], %[[VAL_51]] : tensor +// CHECK: %[[VAL_53:.*]] = stablehlo.add %[[VAL_50]], %[[VAL_52]] : tensor +// CHECK: %[[VAL_54:.*]] = stablehlo.add %[[VAL_53]], %[[VAL_52]] : tensor +// CHECK: %[[VAL_55:.*]] = stablehlo.multiply %[[VAL_51]], %[[VAL_51]] : tensor +// CHECK: %[[VAL_56:.*]] = stablehlo.add %[[VAL_54]], %[[VAL_55]] : tensor +// CHECK: %[[VAL_57:.*]] = 
stablehlo.add %[[VAL_40]], %[[VAL_56]] : tensor +// CHECK: %[[VAL_58:.*]] = stablehlo.negate %[[VAL_39]] : tensor +// CHECK: %[[VAL_59:.*]] = stablehlo.multiply %[[VAL_45]], %[[VAL_11]] : tensor +// CHECK: %[[VAL_60:.*]] = stablehlo.subtract %[[VAL_11]], %[[VAL_59]] : tensor +// CHECK: %[[VAL_61:.*]] = stablehlo.add %[[VAL_59]], %[[VAL_60]] : tensor +// CHECK: %[[VAL_62:.*]] = stablehlo.multiply %[[VAL_61]], %[[VAL_61]] : tensor +// CHECK: %[[VAL_63:.*]] = stablehlo.add %[[VAL_58]], %[[VAL_62]] : tensor +// CHECK: %[[VAL_64:.*]] = stablehlo.subtract %[[VAL_11]], %[[VAL_61]] : tensor +// CHECK: %[[VAL_65:.*]] = stablehlo.multiply %[[VAL_61]], %[[VAL_64]] : tensor +// CHECK: %[[VAL_66:.*]] = stablehlo.add %[[VAL_63]], %[[VAL_65]] : tensor +// CHECK: %[[VAL_67:.*]] = stablehlo.add %[[VAL_66]], %[[VAL_65]] : tensor +// CHECK: %[[VAL_68:.*]] = stablehlo.multiply %[[VAL_64]], %[[VAL_64]] : tensor +// CHECK: %[[VAL_69:.*]] = stablehlo.add %[[VAL_67]], %[[VAL_68]] : tensor +// CHECK: %[[VAL_70:.*]] = stablehlo.add %[[VAL_57]], %[[VAL_69]] : tensor +// CHECK: %[[VAL_71:.*]] = stablehlo.subtract %[[VAL_38]], %[[VAL_37]] : tensor +// CHECK: %[[VAL_72:.*]] = stablehlo.subtract %[[VAL_38]], %[[VAL_71]] : tensor +// CHECK: %[[VAL_73:.*]] = stablehlo.subtract %[[VAL_37]], %[[VAL_72]] : tensor +// CHECK: %[[VAL_74:.*]] = stablehlo.subtract %[[VAL_33]], %[[VAL_71]] : tensor +// CHECK: %[[VAL_75:.*]] = stablehlo.add %[[VAL_73]], %[[VAL_74]] : tensor +// CHECK: %[[VAL_76:.*]] = stablehlo.subtract %[[VAL_40]], %[[VAL_38]] : tensor +// CHECK: %[[VAL_77:.*]] = stablehlo.subtract %[[VAL_40]], %[[VAL_76]] : tensor +// CHECK: %[[VAL_78:.*]] = stablehlo.subtract %[[VAL_38]], %[[VAL_77]] : tensor +// CHECK: %[[VAL_79:.*]] = stablehlo.subtract %[[VAL_39]], %[[VAL_76]] : tensor +// CHECK: %[[VAL_80:.*]] = stablehlo.add %[[VAL_78]], %[[VAL_79]] : tensor +// CHECK: %[[VAL_81:.*]] = stablehlo.add %[[VAL_75]], %[[VAL_80]] : tensor +// CHECK: %[[VAL_82:.*]] = stablehlo.subtract %[[VAL_57]], 
%[[VAL_40]] : tensor +// CHECK: %[[VAL_83:.*]] = stablehlo.subtract %[[VAL_57]], %[[VAL_82]] : tensor +// CHECK: %[[VAL_84:.*]] = stablehlo.subtract %[[VAL_40]], %[[VAL_83]] : tensor +// CHECK: %[[VAL_85:.*]] = stablehlo.subtract %[[VAL_56]], %[[VAL_82]] : tensor +// CHECK: %[[VAL_86:.*]] = stablehlo.add %[[VAL_84]], %[[VAL_85]] : tensor +// CHECK: %[[VAL_87:.*]] = stablehlo.add %[[VAL_81]], %[[VAL_86]] : tensor +// CHECK: %[[VAL_88:.*]] = stablehlo.subtract %[[VAL_70]], %[[VAL_57]] : tensor +// CHECK: %[[VAL_89:.*]] = stablehlo.subtract %[[VAL_70]], %[[VAL_88]] : tensor +// CHECK: %[[VAL_90:.*]] = stablehlo.subtract %[[VAL_57]], %[[VAL_89]] : tensor +// CHECK: %[[VAL_91:.*]] = stablehlo.subtract %[[VAL_69]], %[[VAL_88]] : tensor +// CHECK: %[[VAL_92:.*]] = stablehlo.add %[[VAL_90]], %[[VAL_91]] : tensor +// CHECK: %[[VAL_93:.*]] = stablehlo.add %[[VAL_87]], %[[VAL_92]] : tensor +// CHECK: %[[VAL_94:.*]] = stablehlo.add %[[VAL_70]], %[[VAL_93]] : tensor +// CHECK: %[[VAL_95:.*]] = stablehlo.log_plus_one %[[VAL_94]] : tensor +// CHECK: %[[VAL_96:.*]] = stablehlo.multiply %[[VAL_8]], %[[VAL_95]] : tensor +// CHECK: %[[VAL_97:.*]] = stablehlo.select %[[VAL_31]], %[[VAL_36]], %[[VAL_96]] : tensor, tensor +// CHECK: %[[VAL_98:.*]] = stablehlo.select %[[VAL_18]], %[[VAL_27]], %[[VAL_97]] : tensor, tensor +// CHECK: %[[VAL_99:.*]] = stablehlo.atan2 %[[VAL_13]], %[[VAL_28]] : tensor +// CHECK: %[[VAL_100:.*]] = stablehlo.complex %[[VAL_98]], %[[VAL_99]] : tensor> +// CHECK: return %[[VAL_100]] : tensor> +// CHECK: } +func.func @log_plus_one_complex_f32(%arg : tensor>) -> tensor> { + %result = "stablehlo.log_plus_one"(%arg) : (tensor>) -> tensor> + func.return %result : tensor> +} diff --git a/stablehlo/transforms/CMakeLists.txt b/stablehlo/transforms/CMakeLists.txt index c69f7ff58c..f5193cb61e 100644 --- a/stablehlo/transforms/CMakeLists.txt +++ b/stablehlo/transforms/CMakeLists.txt @@ -28,6 +28,10 @@ set(LLVM_TARGET_DEFINITIONS StablehloCompatibilityExpanderPatterns.td) 
mlir_tablegen(StablehloCompatibilityExpanderPatterns.h.inc --gen-rewriters) add_public_tablegen_target(StablehloCompatibilityExpanderPatternsIncGen) +set(LLVM_TARGET_DEFINITIONS StablehloComplexMathExpanderPatterns.td) +mlir_tablegen(StablehloComplexMathExpanderPatterns.h.inc --gen-rewriters) +add_public_tablegen_target(StablehloComplexMathExpanderPatternsIncGen) + set(LLVM_TARGET_DEFINITIONS StablehloLegalizeDeprecatedOpsPatterns.td) mlir_tablegen(StablehloLegalizeDeprecatedOpsPatterns.h.inc --gen-rewriters) add_public_tablegen_target(StablehloLegalizeDeprecatedOpsPatternsIncGen) @@ -47,6 +51,7 @@ add_mlir_dialect_library(StablehloPasses StablehloCanonicalizeDynamism.cpp StablehloConvertToSignless.cpp StablehloCompatibilityExpander.cpp + StablehloComplexMathExpander.cpp StablehloLegalizeCompositeToCall.cpp StablehloLegalizeDeprecatedOps.cpp StablehloLegalizeQuantToMath.cpp @@ -64,6 +69,7 @@ add_mlir_dialect_library(StablehloPasses PassesIncGen StablehloAggressiveSimplificationPatternsIncGen StablehloCompatibilityExpanderPatternsIncGen + StablehloComplexMathExpanderPatternsIncGen StablehloLegalizeDeprecatedOpsPatternsIncGen VhloToVersionPatterns diff --git a/stablehlo/transforms/ChloDecompositionPatternsMath.td b/stablehlo/transforms/ChloDecompositionPatternsMath.td index 3d1be0d4f6..85bfb0cf26 100644 --- a/stablehlo/transforms/ChloDecompositionPatternsMath.td +++ b/stablehlo/transforms/ChloDecompositionPatternsMath.td @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ // -// This file is generated using functional_algorithms tool (0.11.1). +// This file is generated using functional_algorithms tool (0.12.0). // See build_tools/math/README.md for more information. // A kernel for evaluating asin and acos functions on complex inputs. 
diff --git a/stablehlo/transforms/Passes.h b/stablehlo/transforms/Passes.h index 43ee22f355..bf01394cda 100644 --- a/stablehlo/transforms/Passes.h +++ b/stablehlo/transforms/Passes.h @@ -144,6 +144,11 @@ void createStablehloRemoveDynamismPipeline(OpPassManager &pm, // operations into a primitive math operations. void createStablehloLowerQuantPipeline(OpPassManager &pm); +/// Populates `patterns` with the rewrite patterns that expand StableHLO +/// complex math operations into StableHLO real math operations. +void populateStablehloComplexMathExpanderPatterns(RewritePatternSet *patterns, + MLIRContext *context); + // Adds `stablehlo-deserialize` pipeline as a registered pass pipeline // for opt tools. void registerPassPipelines(); diff --git a/stablehlo/transforms/Passes.td b/stablehlo/transforms/Passes.td index 186560d012..aa9d696664 100644 --- a/stablehlo/transforms/Passes.td +++ b/stablehlo/transforms/Passes.td @@ -123,6 +123,40 @@ def StablehloCompatibilityExpanderPass : Pass<"stablehlo-compatibility-expander" ]; } +def StablehloComplexMathExpanderPass : Pass<"stablehlo-complex-math-expander", "mlir::func::FuncOp"> { + let summary = "Expander for StableHLO complex math operations."; + + let description = [{ + StableHLO complex math operations are decompositions using + StableHLO real math operations. + + This statement is based on the assumption that no hardware exists + that supports complex numbers nor complex math operations + natively. This means that the fallback mechanisms on complex math + operations that compilers may implement, are redundant. With + enabling this pass, all StableHLO complex math operations will be + expanded. 
+ + ```mlir + func.func @sqrt_op_complex(%arg0: tensor<4xcomplex>) -> tensor<4xcomplex> { + %1 = stablehlo.sqrt %arg0 : tensor<4xcomplex> + func.return %1 : tensor<4xcomplex> + } + + ==> + + func.func @sqrt_op_complex(%arg0: tensor<4xcomplex>) -> tensor<4xcomplex> { + TBD + return %2 : tensor<4xcomplex> + } + ``` + }]; + let dependentDialects = [ + "mlir::stablehlo::StablehloDialect", + "mlir::chlo::ChloDialect", + ]; +} + def StablehloConvertToSignlessPass : Pass<"stablehlo-convert-to-signless", "ModuleOp"> { let summary = "Pass to transform the IR to be on signless integers."; } diff --git a/stablehlo/transforms/StablehloComplexMathExpander.cpp b/stablehlo/transforms/StablehloComplexMathExpander.cpp new file mode 100644 index 0000000000..b830db6196 --- /dev/null +++ b/stablehlo/transforms/StablehloComplexMathExpander.cpp @@ -0,0 +1,76 @@ +/* Copyright 2024 The StableHLO Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "stablehlo/dialect/StablehloOps.h" +#include "stablehlo/transforms/PassUtils.h" +#include "stablehlo/transforms/Passes.h" + +namespace mlir { +namespace stablehlo { +#define GEN_PASS_DEF_STABLEHLOCOMPLEXMATHEXPANDERPASS +#include "stablehlo/transforms/Passes.h.inc" + +namespace { +// Returns a constant like `val` holding its float type's largest finite value. +static Value getConstantLikeMaxFiniteValue(OpBuilder &b, Location loc, + Value val) { + auto ty = cast<FloatType>(getElementTypeOrSelf(val.getType())); + return getConstantLike( + b, loc, llvm::APFloat::getLargest(ty.getFloatSemantics()), val); +} + +//===----------------------------------------------------------------------===// +// Pass +//===----------------------------------------------------------------------===// + +struct StablehloComplexMathExpanderPass + : public impl::StablehloComplexMathExpanderPassBase< + StablehloComplexMathExpanderPass> { + StablehloComplexMathExpanderPass() + : StablehloComplexMathExpanderPassBase< + StablehloComplexMathExpanderPass>() {} + + public: + LogicalResult initialize(MLIRContext *context) override { + config.useTopDownTraversal = true; + RewritePatternSet patterns_(context); + populateStablehloComplexMathExpanderPatterns(&patterns_, context); + patterns = std::move(patterns_); + return success(); + } + + void runOnOperation() override { + auto func = getOperation(); + if (failed(applyPatternsAndFoldGreedily(func, patterns, config))) { + func.emitError("Failed to converge StableHLOComplexMathExpanderPass in ") + << config.maxIterations << " iterations"; + signalPassFailure(); + } + } + + private: + FrozenRewritePatternSet patterns; + GreedyRewriteConfig config; +}; + +#include "stablehlo/transforms/StablehloComplexMathExpanderPatterns.h.inc" + +} // namespace + +void 
populateStablehloComplexMathExpanderPatterns(RewritePatternSet *patterns, + MLIRContext *context) { + populateWithGenerated(*patterns); +} + +} 
// namespace stablehlo +} // namespace mlir diff --git a/stablehlo/transforms/StablehloComplexMathExpanderPatterns.td b/stablehlo/transforms/StablehloComplexMathExpanderPatterns.td new file mode 100644 index 0000000000..42cb7f72fc --- /dev/null +++ b/stablehlo/transforms/StablehloComplexMathExpanderPatterns.td @@ -0,0 +1,275 @@ +/* Copyright 2024 The StableHLO Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// +// This file is generated using functional_algorithms tool (0.12.0). +// See build_tools/math/README.md for more information. + +include "mlir/IR/OpBase.td" +include "stablehlo/dialect/StablehloOps.td" +// Constant ComparisonDirection attribute selected by name, e.g. <"GT">. +class StableHLO_ComparisonDirectionValue<string enumStr> : + ConstantAttr<StableHLO_ComparisonDirectionAttr, "::mlir::stablehlo::ComparisonDirection::" # enumStr>; +// Calls getConstantLike with the literal `value` and the bound operand `$0`. +class StableHLO_ConstantLike<string value> : NativeCodeCall< + "::mlir::stablehlo::getConstantLike($_builder, $_loc, " # value # ", $0)">; +// Constrains an operand to shaped types whose element type is complex. +def ComplexElementType : Type< + CPred<"isa<ComplexType>(cast<ShapedType>($_self).getElementType())">, + "Complex element type">; +// Calls the C++ helper getConstantLikeMaxFiniteValue on the bound operand `$0`. 
+def StableHLO_ConstantLikeMaxFiniteValue : NativeCodeCall< + "::mlir::stablehlo::getConstantLikeMaxFiniteValue($_builder, $_loc, $0)">; + +// Logarithm of 1 + z on complex input: +// +// log1p(x + I * y) = 0.5 * log((x + 1) ** 2 + y ** 2) + I * arctan2(y, x + 1) +// +// where +// +// x and y are real and imaginary parts of the input to log1p, and +// I is imaginary unit. 
+// +// For evaluating the real part of log1p accurately on the whole +// complex plane, the following cases must be handled separately: +// +// A) Avoid catastrophic cancellation errors when x is close to `-0.5 * y * y` +// and `abs(y) < 1`. +// B) Avoid overflow from squaring when x or y are large in absolute value. +// C) Avoid cancellation errors when x is close to -1 and y is not large. +// D) Avoid cancellation errors when x is close to -2 and y is not large. +// +// Case A +// ------ +// +// The real part of log1p reads: +// +// 0.5 * log((x + 1) ** 2 + y ** 2) = 0.5 * log1p(x + x + x * x + y * y) +// +// When abs(y) < 1 and abs(x + 0.5 * y ** 2) is small, catastrophic +// cancellation errors occur when evaluating `x + x + x * x + y * y` +// using floating-point arithmetic. To avoid these errors, we'll use +// Dekker's product for computing `x * x` and `y * y` which +// effectively doubles the precision of the used floating-point +// system. In addition, the terms are summed together using 2Sum +// algorithm that minimizes cancellation errors. We'll have +// +// xxh, xxl = square_dekker(x) +// yyh, yyl = square_dekker(y) +// x + x + x * x + y * y = sum_2sum([x + x, yyh, xxh, yyl, xxl]) +// +// which is accurate when the following inequalities hold: +// +// abs(x) < sqrt(largest) * 0.1 +// abs(y) < sqrt(largest) * 0.99 +// +// [verified numerically for float32 and float64], except when x is +// close to -1 (see Case C). +// +// Case B +// ------ +// +// If abs(x) or abs(y) is larger than sqrt(largest), squaring +// these will overflow. To avoid such overflows, we'll apply +// rescaling of log1p arguments. +// +// First notice that if `abs(x) > sqrt(largest) > 4 / eps` holds then +// `x + 1 ~= x`. Also, if `abs(x) < 4 / eps` then `(x + 1) ** 2 + y +// ** 2 ~= y ** 2`. Proof: +// +// (x + 1) ** 2 + y ** 2 ~= y ** 2 iff y ** 2 > 4 * (x + 1) ** 2 / eps +// +// The lower limit to `y ** 2` is largest. 
The upper limit to +// `4 * (x + 1) ** 2 / eps` is `64 / eps ** 3` which is smaller than +// largest. QED. +// +// In conclusion, we can write +// +// (x + 1) ** 2 + y ** 2 ~= x ** 2 + y ** 2 +// +// whenever abs(x) or abs(y) is greater than sqrt(largest). +// +// Define +// +// mx = max(abs(x), abs(y)) +// mn = min(abs(x), abs(y)) +// +// then under the given restrictions we'll have +// +// real(log(x + I * y)) ~= 0.5 * log(x ** 2 + y ** 2) +// = 0.5 * log(mx ** 2 * (1 + (mn / mx) ** 2)) +// = log(mx) + 0.5 * log1p((mn / mx) ** 2) +// +// If mn == inf and mx == inf, we'll define `mn / mx == 1` for the +// sake of reusing the above expression for complex infinities +// (recall, `real(log(+-inf +-inf * I)) == inf`). +// +// Case C +// ------ +// +// If x is close to -1, then we'll use +// +// real(log1p(x + I * y)) = 0.5 * log((1 + x) ** 2 + y ** 2) +// +// which is accurate when the following inequalities hold: +// +// -1.5 < x < -0.5 or abs(x + 1) < 0.5 +// abs(y) < sqrt(largest) +// +// [verified numerically for float32 and float64]. For simplicity, +// we'll use the case C only when `abs(x + 1) + abs(y) < 0.2`. +// +// Case D +// ------ +// +// If x is close to -2, the cancellation errors are avoided by using +// the Case A method [verified numerically for float32 and float64]. 
+// Rewrites log1p(z), z = x + I * y, into complex(re, atan2(y, x + 1)), where +// re = Case B if mx > 0.01 * sqrt(largest), Case C if abs(x + 1) + ay < 0.2, else Case A. +def Log1pOp_ComplexElementType_ComplexMathExpander : Pat<(StableHLO_Log1pOp ComplexElementType:$z), + (StableHLO_ComplexOp + (StableHLO_SelectOp + (StableHLO_CompareOp + (StableHLO_MaxOp:$mx + (StableHLO_AbsOp:$ax + (StableHLO_RealOp:$x $z)), + (StableHLO_AbsOp:$ay + (StableHLO_ImagOp:$y $z))), + (StableHLO_MulOp + (StableHLO_SqrtOp + (StableHLO_ConstantLikeMaxFiniteValue:$largest $x)), + (StableHLO_ConstantLike<"0.01"> $x)), + StableHLO_ComparisonDirectionValue<"GT">, + (STABLEHLO_DEFAULT_COMPARISON_TYPE)), + (StableHLO_AddOp + (StableHLO_LogOp $mx), + (StableHLO_MulOp + (StableHLO_ConstantLike<"0.5">:$half $x), + (StableHLO_Log1pOp + (StableHLO_SelectOp + (StableHLO_CompareOp + (StableHLO_MinOp:$mn $ax, $ay), + $mx, + StableHLO_ComparisonDirectionValue<"EQ">, + (STABLEHLO_DEFAULT_COMPARISON_TYPE)), + (StableHLO_ConstantLike<"1">:$one $x), + (StableHLO_MulOp + (StableHLO_DivOp:$r $mn, $mx), + $r))))), + (StableHLO_SelectOp + (StableHLO_CompareOp + (StableHLO_AddOp + (StableHLO_AbsOp + (StableHLO_AddOp:$xp1 $x, $one)), + $ay), + (StableHLO_ConstantLike<"0.2"> $x), + StableHLO_ComparisonDirectionValue<"LT">, + (STABLEHLO_DEFAULT_COMPARISON_TYPE)), + (StableHLO_MulOp + $half, + (StableHLO_LogOp + (StableHLO_AddOp + (StableHLO_MulOp $xp1, $xp1), + (StableHLO_MulOp:$square_dekker_high $y, $y)))), + (StableHLO_MulOp + $half, + (StableHLO_Log1pOp + (StableHLO_AddOp:$sum_2sum_high + (StableHLO_AddOp:$add_2sum_high + (StableHLO_AddOp:$_add_2sum_high_0_ + (StableHLO_AddOp:$_add_2sum_high_1_ + (StableHLO_AddOp:$_add_2sum_high_2_ + (StableHLO_AddOp:$x2h $x, $x), + $square_dekker_high), + (StableHLO_MulOp:$_square_dekker_high_0_ $x, $x)), + (StableHLO_AddOp:$square_dekker_low + (StableHLO_AddOp + (StableHLO_AddOp + (StableHLO_AddOp + (StableHLO_NegOp $square_dekker_high), + (StableHLO_MulOp + 
(StableHLO_AddOp:$yh + (StableHLO_MulOp:$multiply_veltkamp_splitter_constant_y + (StableHLO_SelectOp:$veltkamp_splitter_constant + (StableHLO_CompareOp + $largest, + 
(StableHLO_ConstantLike<"1e+308"> $x), + StableHLO_ComparisonDirectionValue<"GT">, + (STABLEHLO_DEFAULT_COMPARISON_TYPE)), + (StableHLO_ConstantLike<"134217729"> $x), + (StableHLO_SelectOp + (StableHLO_CompareOp + $largest, + (StableHLO_ConstantLike<"1e+38"> $x), + StableHLO_ComparisonDirectionValue<"GT">, + (STABLEHLO_DEFAULT_COMPARISON_TYPE)), + (StableHLO_ConstantLike<"4097"> $x), + (StableHLO_ConstantLike<"65"> $x))), + $y), + (StableHLO_SubtractOp $y, $multiply_veltkamp_splitter_constant_y)), + $yh)), + (StableHLO_MulOp:$multiply_yh_yl + $yh, + (StableHLO_SubtractOp:$yl $y, $yh))), + $multiply_yh_yl), + (StableHLO_MulOp $yl, $yl))), + (StableHLO_AddOp:$_square_dekker_low_0_ + (StableHLO_AddOp + (StableHLO_AddOp + (StableHLO_AddOp + (StableHLO_NegOp $_square_dekker_high_0_), + (StableHLO_MulOp + (StableHLO_AddOp:$xh + (StableHLO_MulOp:$multiply_veltkamp_splitter_constant_x $veltkamp_splitter_constant, $x), + (StableHLO_SubtractOp $x, $multiply_veltkamp_splitter_constant_x)), + $xh)), + (StableHLO_MulOp:$multiply_xh_xl + $xh, + (StableHLO_SubtractOp:$xl $x, $xh))), + $multiply_xh_xl), + (StableHLO_MulOp $xl, $xl))), + (StableHLO_AddOp + (StableHLO_AddOp + (StableHLO_AddOp + (StableHLO_AddOp:$add_2sum_low + (StableHLO_SubtractOp + $x2h, + (StableHLO_SubtractOp + $_add_2sum_high_2_, + (StableHLO_SubtractOp:$subtract__add_2sum_high_2__x2h $_add_2sum_high_2_, $x2h))), + (StableHLO_SubtractOp $square_dekker_high, $subtract__add_2sum_high_2__x2h)), + (StableHLO_AddOp:$_add_2sum_low_0_ + (StableHLO_SubtractOp + $_add_2sum_high_2_, + (StableHLO_SubtractOp + $_add_2sum_high_1_, + (StableHLO_SubtractOp:$subtract__add_2sum_high_1___add_2sum_high_2_ $_add_2sum_high_1_, $_add_2sum_high_2_))), + (StableHLO_SubtractOp $_square_dekker_high_0_, $subtract__add_2sum_high_1___add_2sum_high_2_))), + (StableHLO_AddOp:$_add_2sum_low_1_ + (StableHLO_SubtractOp + $_add_2sum_high_1_, + (StableHLO_SubtractOp + $_add_2sum_high_0_, + 
(StableHLO_SubtractOp:$subtract__add_2sum_high_0___add_2sum_high_1_ $_add_2sum_high_0_, $_add_2sum_high_1_))), + (StableHLO_SubtractOp $square_dekker_low, $subtract__add_2sum_high_0___add_2sum_high_1_))), + (StableHLO_AddOp:$_add_2sum_low_2_ + (StableHLO_SubtractOp + $_add_2sum_high_0_, + (StableHLO_SubtractOp + $add_2sum_high, + (StableHLO_SubtractOp:$subtract_add_2sum_high__add_2sum_high_0_ $add_2sum_high, $_add_2sum_high_0_))), + (StableHLO_SubtractOp $_square_dekker_low_0_, $subtract_add_2sum_high__add_2sum_high_0_)))))))), + (StableHLO_Atan2Op $y, $xp1))>;