From adbeeca967549059dec9880fcc27806cdb78eb11 Mon Sep 17 00:00:00 2001 From: Pearu Peterson Date: Thu, 19 Dec 2024 17:36:17 +0200 Subject: [PATCH] Add StableHLO complex log1p operation. Add pass stablehlo-complex-math-expander (#2636) As in the title. This PR introduces a new pass stablehlo-complex-math-expander that expands StableHLO complex functions in terms of StableHLO real functions. Currently, only StableHLO_Log1pOp on complex inputs is included in this expander, more will be added as follow-ups. The provided complex `log1p` operation fixes the inaccuracy problems in jax.numpy.log1p for complex inputs `x+i*y` when `x` is close to `-0.5 * y * y` which triggers catastrophic cancellations when using straightforward definition of log1p on complex inputs. The current state of `jax.numpy.log1p` inaccuracies is given in https://github.com/pearu/functional_algorithms/issues/47 . With this PR, the accuracy statistics of log1p is: ``` complex64: ULP difference == 0 count is 961936 ULP difference == 1 count is 39644 ULP difference == 2 count is 455 ULP difference == 3 count is 0 ULP difference >= 4 count is 0 complex128: ULP difference == 0 count is 988144 ULP difference == 1 count is 13891 ULP difference == 2 count is 0 ULP difference == 3 count is 0 ULP difference >= 4 count is 0 ``` --- BUILD.bazel | 17 ++ build_tools/math/README.md | 16 +- .../generate_ChloDecompositionPatternsMath.py | 57 +++- build_tools/math/generate_tests.py | 66 +++-- docs/generated/stablehlo_passes.md | 27 ++ docs/spec.md | 3 +- .../tests/math/log_plus_one_complex128.mlir | 19 ++ .../tests/math/log_plus_one_complex64.mlir | 19 ++ .../tests/math/log_plus_one_float32.mlir | 19 ++ .../tests/math/log_plus_one_float64.mlir | 19 ++ .../stablehlo_complex_math_expander.mlir | 110 +++++++ stablehlo/transforms/CMakeLists.txt | 6 + .../ChloDecompositionPatternsMath.td | 2 +- stablehlo/transforms/Passes.h | 5 + stablehlo/transforms/Passes.td | 34 +++ .../StablehloComplexMathExpander.cpp | 76 +++++ .../StablehloComplexMathExpanderPatterns.td | 275 ++++++++++++++++++ 17 files changed, 724 insertions(+), 46 deletions(-) create mode 100644 stablehlo/tests/math/log_plus_one_complex128.mlir create mode 100644 stablehlo/tests/math/log_plus_one_complex64.mlir create mode 100644 stablehlo/tests/math/log_plus_one_float32.mlir create mode 100644 stablehlo/tests/math/log_plus_one_float64.mlir create mode 100644 stablehlo/tests/stablehlo_complex_math_expander.mlir create mode 100644 stablehlo/transforms/StablehloComplexMathExpander.cpp create mode 100644 stablehlo/transforms/StablehloComplexMathExpanderPatterns.td diff --git a/BUILD.bazel b/BUILD.bazel index 8d2b26c32a..866e971507 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -371,6 +371,21 @@ gentbl_cc_library( ], ) +gentbl_cc_library( + name = "stablehlo_create_complex_math_expander_inc_gen", + tbl_outs = [ + ( + ["--gen-rewriters"], + "stablehlo/transforms/StablehloComplexMathExpanderPatterns.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "stablehlo/transforms/StablehloComplexMathExpanderPatterns.td", + deps = [ + ":stablehlo_ops_td_files", + ], +) + cc_library( name = "interpreter_ops", srcs = [ @@ -1121,6 +1136,7 @@ cc_library( "stablehlo/transforms/StablehloAggressiveSimplification.cpp", "stablehlo/transforms/StablehloCanonicalizeDynamism.cpp", "stablehlo/transforms/StablehloCompatibilityExpander.cpp", + "stablehlo/transforms/StablehloComplexMathExpander.cpp", "stablehlo/transforms/StablehloConvertToSignless.cpp", "stablehlo/transforms/StablehloLegalizeCompositeToCall.cpp", "stablehlo/transforms/StablehloLegalizeDeprecatedOps.cpp", @@ -1149,6 +1165,7 @@ cc_library( ":linalg_passes", ":stablehlo_aggressive_simplification_inc_gen", ":stablehlo_create_compatibility_expander_inc_gen", + ":stablehlo_create_complex_math_expander_inc_gen", ":stablehlo_legalize_deprecated_ops_inc_gen", ":stablehlo_ops", ":stablehlo_ops_inc_gen", diff --git a/build_tools/math/README.md b/build_tools/math/README.md index 8e3a8f180d..e2f10a8881 100644 --- a/build_tools/math/README.md +++ b/build_tools/math/README.md @@ -31,7 +31,7 @@ following requirements: - Python 3.11 or newer - mpmath 1.3 or newer -- functional_algorithms 0.11.1 or newer +- functional_algorithms 0.12 or newer that can be installed via pypi: @@ -62,7 +62,7 @@ To execute generated tests from a `build` directory, use: ```sh for t in $(ls ../stablehlo/tests/math/*.mlir); \ -do echo $t && ( bin/stablehlo-opt --chlo-legalize-to-stablehlo $t \ +do echo $t && ( bin/stablehlo-opt --stablehlo-complex-math-expander --chlo-legalize-to-stablehlo $t \ | bin/stablehlo-translate --interpret 2>&1 | grep "^ULP difference" ) ; done ``` @@ -77,6 +77,14 @@ build/bin/stablehlo-opt --chlo-legalize-to-stablehlo --split-input-file --verify and copy relevant checks to `chlo_legalize_to_stablehlo.mlir`. +A similar procedure is applied for updating +`stablehlo/tests/stablehlo_complex_math_expander.mlir`: + +```sh +build/bin/stablehlo-opt --stablehlo-complex-math-expander --split-input-file --verify-diagnostics \ + stablehlo/tests/stablehlo_complex_math_expander.mlir | python llvm-project/mlir/utils/generate-test-checks.py | less +``` + ## A procedure for adding a new algorithm to an existing operation 1. Implement a new algorithm in @@ -98,6 +106,10 @@ and copy relevant checks to `chlo_legalize_to_stablehlo.mlir`. 7. Add a record of the operation to `generate_ChloDecompositionPatternsMath.py`, see the for-loop in `main` function. + - If the operation is a StableHLO operation on complex inputs, add + it to `stable-complex-math-expander` pass: update + `populateStablehloComplexMathExpanderPatterns` function in + `stablehlo/transforms/StablehloComplexMathExpander.cpp`. 8. Generate new implementations by running `generate_ChloDecompositionPatternsMath.py` and remove existing implementations in diff --git a/build_tools/math/generate_ChloDecompositionPatternsMath.py b/build_tools/math/generate_ChloDecompositionPatternsMath.py index 885849ce89..62b99474dc 100644 --- a/build_tools/math/generate_ChloDecompositionPatternsMath.py +++ b/build_tools/math/generate_ChloDecompositionPatternsMath.py @@ -44,7 +44,7 @@ def get_functional_algorithms_required_version(): ) -def main(): +def main(kind="CHLO"): try: import functional_algorithms as fa except ImportError as msg: @@ -64,16 +64,15 @@ def main(): warnings.warn(msg) return + output_filename = dict( + CHLO="ChloDecompositionPatternsMath.td", + StableHLO="StablehloComplexMathExpanderPatterns.td", + )[kind] + output_file = os.path.relpath( os.path.normpath( - os.path.join( - os.path.dirname(__file__), - "..", - "..", - "stablehlo", - "transforms", - "ChloDecompositionPatternsMath.td", - )), + os.path.join(os.path.dirname(__file__), "..", "..", "stablehlo", + "transforms", output_filename)), os.getcwd(), ) @@ -98,13 +97,15 @@ def main(): ("CHLO_AtanhOp", "complex_atanh", ("z:complex",)), ("CHLO_SquareOp", "complex_square", ("z:complex",)), ("CHLO_SquareOp", "real_square", ("x:float",)), + ("StableHLO_Log1pOp", "complex_log1p", ("z:complex",)), ]: + if not chloname.startswith(kind): + continue print(f'Generating {chloname} from {fname}{args}') func = getattr(fa.algorithms, fname, None) if func is None: warnings.warn( - f"{fa.algorithms.__name__} does not define {fname}. Skipping." - ) + f"{fa.algorithms.__name__} does not define {fname}. Skipping.") continue ctx = fa.Context(paths=[fa.algorithms], parameters=dict(rewrite_keep_integer_literals=True)) @@ -115,6 +116,16 @@ def main(): sources[-1] += src source = "\n\n".join(sources) + "\n" + if chloname.startswith('StableHLO_'): + # an ugly hack to fix the definition of stablehlo complex math + # functions. TODO(pearu): add the corresponding feature to + # functional_algorithms stablehlo printer + NameOp = chloname.split('_', 1)[1] + source = source.replace( + f'def : Pat<({chloname}', + f'def {NameOp}_ComplexElementType_ComplexMathExpander : Pat<({chloname}' + ) + if os.path.isfile(output_file): f = open(output_file, "r") content = f.read() @@ -146,10 +157,32 @@ def main(): This file is generated using functional_algorithms tool ({fa.__version__}). See build_tools/math/README.md for more information.""") + "\n") + + if kind == "StableHLO": + f.write("""\ +include "mlir/IR/OpBase.td" +include "stablehlo/dialect/StablehloOps.td" + +class StableHLO_ComparisonDirectionValue : + ConstantAttr; + +class StableHLO_ConstantLike : NativeCodeCall< + "::mlir::stablehlo::getConstantLike($_builder, $_loc, " # value # ", $0)">; + +def ComplexElementType : Type< + CPred<"isa(cast($_self).getElementType())">, + "Complex element type">; + +def StableHLO_ConstantLikeMaxFiniteValue : NativeCodeCall< + "::mlir::stablehlo::getConstantLikeMaxFiniteValue($_builder, $_loc, $0)">; + +""") f.write(source) f.close() print(f"Created {output_file}") if __name__ == "__main__": - main() + main(kind="CHLO") + main(kind="StableHLO") diff --git a/build_tools/math/generate_tests.py b/build_tools/math/generate_tests.py index ecd413cda2..fe20a1ae0b 100644 --- a/build_tools/math/generate_tests.py +++ b/build_tools/math/generate_tests.py @@ -43,27 +43,31 @@ default_max_ulp_difference = 1 operations = [ - # The following dictionaries may have additional keys like - # - # size - defines the number of samples: size ** 2 - # - # max_ulp_difference - the maximal allowed ULP difference between - # function and reference values - # - # extra_prec_multiplier - the precison multiplier for mpmath.mp - # that defines the precision of computing reference values: - # mpmath.mp.prec * extra_prec_multiplier - # - # When unspecifed, these parameters are retrieved from - # functional_algorithms database of support functions. - # - dict(name="asin", mpmath_name="arcsin"), - dict(name="acos", mpmath_name="arccos"), - dict(name="atan", mpmath_name="arctan"), - dict(name="asinh", mpmath_name="arcsinh"), - dict(name="acosh", mpmath_name="arccosh"), - dict(name="atanh", mpmath_name="arctanh"), - dict(name="square", mpmath_name="square"), + # The following dictionaries may have additional keys like + # + # size - defines the number of samples: size ** 2 + # + # max_ulp_difference - the maximal allowed ULP difference between + # function and reference values + # + # extra_prec_multiplier - the precison multiplier for mpmath.mp + # that defines the precision of computing reference values: + # mpmath.mp.prec * extra_prec_multiplier + # + # When unspecifed, these parameters are retrieved from + # functional_algorithms database of support functions. + # + dict(name="asin", mpmath_name="arcsin"), + dict(name="acos", mpmath_name="arccos"), + dict(name="atan", mpmath_name="arctan"), + dict(name="asinh", mpmath_name="arcsinh"), + dict(name="acosh", mpmath_name="arccosh"), + dict(name="atanh", mpmath_name="arctanh"), + dict(name="square", mpmath_name="square"), + dict(name="log_plus_one", + mpmath_name="log1p", + namespace="stablehlo", + passes="--stablehlo-complex-math-expander"), ] @@ -127,19 +131,21 @@ def main(): for op in operations: opname = op["name"] mpmath_opname = op.get("mpmath_name", opname) + namespace = op.get("namespace", "chlo") size_re = size_im = op.get("size", default_size) - + passes = op.get("passes", "--chlo-legalize-to-stablehlo") for dtype in [np.complex64, np.complex128, np.float32, np.float64]: params = fa.utils.function_validation_parameters(opname, dtype) max_ulp_difference = op.get( - "max_ulp_difference", - params.get("max_valid_ulp_count", default_max_ulp_difference)) + "max_ulp_difference", + params.get("max_valid_ulp_count", default_max_ulp_difference)) nmp = fa.utils.numpy_with_mpmath( - extra_prec_multiplier = op.get( - "extra_prec_multiplier", - params.get("extra_prec_multiplier", default_extra_prec_multiplier)), - flush_subnormals=flush_subnormals, + extra_prec_multiplier=op.get( + "extra_prec_multiplier", + params.get("extra_prec_multiplier", + default_extra_prec_multiplier)), + flush_subnormals=flush_subnormals, ) fi = np.finfo(dtype) @@ -180,7 +186,7 @@ def main(): main_func = m.make_function("main", "", "", "public") ref_samples = main_func.call("samples") - actual = main_func.composite(f"chlo.{opname}", ref_samples) + actual = main_func.composite(f"{namespace}.{opname}", ref_samples) expected = main_func.call("expected") main_func.void_call( @@ -202,7 +208,7 @@ def main(): continue f = open(fname, "w") - f.write("// RUN: stablehlo-opt --chlo-legalize-to-stablehlo %s |" + f.write(f"// RUN: stablehlo-opt {passes} %s |" " stablehlo-translate --interpret\n") f.write( "// This file is generated, see build_tools/math/README.md for more" diff --git a/docs/generated/stablehlo_passes.md b/docs/generated/stablehlo_passes.md index 4f5b6e7447..2fdda99136 100755 --- a/docs/generated/stablehlo_passes.md +++ b/docs/generated/stablehlo_passes.md @@ -86,6 +86,33 @@ func.func @tan_op_non_complex(%arg0: tensor<4xf64>) -> tensor<4xf64> { ``` -target : The target version. Must be a version of the form #.#.#. ``` +### `-stablehlo-complex-math-expander` + +_Expander for StableHLO complex math operations._ + +StableHLO complex math operations are decompositions using +StableHLO real math operations. + +This statement is based on the assumption that no hardware exists +that supports complex numbers nor complex math operations +natively. This means that the fallback mechanisms on complex math +operations that compilers may implement, are redundant. With +enabling this pass, all StableHLO complex math operations will be +expanded. + +```mlir +func.func @sqrt_op_complex(%arg0: tensor<4xcomplex>) -> tensor<4xcomplex> { + %1 = stablehlo.sqrt %arg0 : tensor<4xcomplex> + func.return %1 : tensor<4xcomplex> +} + +==> + +func.func @sqrt_op_complex(%arg0: tensor<4xcomplex>) -> tensor<4xcomplex> { + TBD + return %2 : tensor<4xcomplex> +} +``` ### `-stablehlo-convert-to-signless` _Pass to transform the IR to be on signless integers._ diff --git a/docs/spec.md b/docs/spec.md index 48760dbd96..e59de3f716 100644 --- a/docs/spec.md +++ b/docs/spec.md @@ -4066,7 +4066,8 @@ Performs element-wise logarithm plus one operation on `operand` tensor and produces a `result` tensor. Depending on the element type, does the following: * For floats: `logp1` from IEEE-754. -* For complex numbers: complex logarithm plus one. +* For complex numbers: + `complex(log(hypot(real(x) + 1, imag(x))), atan2(imag(x), real(x) + 1))` * For quantized types: `dequantize_op_quantize(log_plus_one, operand, type(result))`. diff --git a/stablehlo/tests/math/log_plus_one_complex128.mlir b/stablehlo/tests/math/log_plus_one_complex128.mlir new file mode 100644 index 0000000000..5d95770d4d --- /dev/null +++ b/stablehlo/tests/math/log_plus_one_complex128.mlir @@ -0,0 +1,19 @@ +// RUN: stablehlo-opt --stablehlo-complex-math-expander %s | stablehlo-translate --interpret +// This file is generated, see build_tools/math/README.md for more information. +module @log_plus_one_complex128 { + func.func private @samples() -> tensor<169xcomplex> { + %0 = stablehlo.constant dense<"0x000000000000F0FF000000000000F0FFFFFFFFFFFFFFEFFF000000000000F0FFFEFFFFFFFFFFEFFF000000000000F0FF000000000000F8BF000000000000F0FF000000000000FC9F000000000000F0FF0100000000000080000000000000F0FF0000000000000000000000000000F0FF0100000000000000000000000000F0FF000000000000FC1F000000000000F0FF000000000000F83F000000000000F0FFFEFFFFFFFFFFEF7F000000000000F0FFFFFFFFFFFFFFEF7F000000000000F0FF000000000000F07F000000000000F0FF000000000000F0FFFFFFFFFFFFFFEFFFFFFFFFFFFFFFEFFFFFFFFFFFFFFFEFFFFEFFFFFFFFFFEFFFFFFFFFFFFFFFEFFF000000000000F8BFFFFFFFFFFFFFEFFF000000000000FC9FFFFFFFFFFFFFEFFF0100000000000080FFFFFFFFFFFFEFFF0000000000000000FFFFFFFFFFFFEFFF0100000000000000FFFFFFFFFFFFEFFF000000000000FC1FFFFFFFFFFFFFEFFF000000000000F83FFFFFFFFFFFFFEFFFFEFFFFFFFFFFEF7FFFFFFFFFFFFFEFFFFFFFFFFFFFFFEF7FFFFFFFFFFFFFEFFF000000000000F07FFFFFFFFFFFFFEFFF000000000000F0FFFEFFFFFFFFFFEFFFFFFFFFFFFFFFEFFFFEFFFFFFFFFFEFFFFEFFFFFFFFFFEFFFFEFFFFFFFFFFEFFF000000000000F8BFFEFFFFFFFFFFEFFF000000000000FC9FFEFFFFFFFFFFEFFF0100000000000080FEFFFFFFFFFFEFFF0000000000000000FEFFFFFFFFFFEFFF0100000000000000FEFFFFFFFFFFEFFF000000000000FC1FFEFFFFFFFFFFEFFF000000000000F83FFEFFFFFFFFFFEFFFFEFFFFFFFFFFEF7FFEFFFFFFFFFFEFFFFFFFFFFFFFFFEF7FFEFFFFFFFFFFEFFF000000000000F07FFEFFFFFFFFFFEFFF000000000000F0FF000000000000F8BFFFFFFFFFFFFFEFFF000000000000F8BFFEFFFFFFFFFFEFFF000000000000F8BF000000000000F8BF000000000000F8BF000000000000FC9F000000000000F8BF0100000000000080000000000000F8BF0000000000000000000000000000F8BF0100000000000000000000000000F8BF000000000000FC1F000000000000F8BF000000000000F83F000000000000F8BFFEFFFFFFFFFFEF7F000000000000F8BFFFFFFFFFFFFFEF7F000000000000F8BF000000000000F07F000000000000F8BF000000000000F0FF000000000000FC9FFFFFFFFFFFFFEFFF000000000000FC9FFEFFFFFFFFFFEFFF000000000000FC9F000000000000F8BF000000000000FC9F000000000000FC9F000000000000FC9F0100000000000080000000000000FC9F0000000000000000000000000000FC9F0100000000000000000000000000FC9F000000000000FC1F000000000000FC9F000000000000F83F000000000000FC9FFEFFFFFFFFFFEF7F000000000000FC9FFFFFFFFFFFFFEF7F000000000000FC9F000000000000F07F000000000000FC9F000000000000F0FF0100000000000080FFFFFFFFFFFFEFFF0100000000000080FEFFFFFFFFFFEFFF0100000000000080000000000000F8BF0100000000000080000000000000FC9F0100000000000080010000000000008001000000000000800000000000000000010000000000008001000000000000000100000000000080000000000000FC1F0100000000000080000000000000F83F0100000000000080FEFFFFFFFFFFEF7F0100000000000080FFFFFFFFFFFFEF7F0100000000000080000000000000F07F0100000000000080000000000000F0FF0000000000000000FFFFFFFFFFFFEFFF0000000000000000FEFFFFFFFFFFEFFF0000000000000000000000000000F8BF0000000000000000000000000000FC9F0000000000000000010000000000008000000000000000000000000000000000000000000000000001000000000000000000000000000000000000000000FC1F0000000000000000000000000000F83F0000000000000000FEFFFFFFFFFFEF7F0000000000000000FFFFFFFFFFFFEF7F0000000000000000000000000000F07F0000000000000000000000000000F0FF0100000000000000FFFFFFFFFFFFEFFF0100000000000000FEFFFFFFFFFFEFFF0100000000000000000000000000F8BF0100000000000000000000000000FC9F0100000000000000010000000000008001000000000000000000000000000000010000000000000001000000000000000100000000000000000000000000FC1F0100000000000000000000000000F83F0100000000000000FEFFFFFFFFFFEF7F0100000000000000FFFFFFFFFFFFEF7F0100000000000000000000000000F07F0100000000000000000000000000F0FF000000000000FC1FFFFFFFFFFFFFEFFF000000000000FC1FFEFFFFFFFFFFEFFF000000000000FC1F000000000000F8BF000000000000FC1F000000000000FC9F000000000000FC1F0100000000000080000000000000FC1F0000000000000000000000000000FC1F0100000000000000000000000000FC1F000000000000FC1F000000000000FC1F000000000000F83F000000000000FC1FFEFFFFFFFFFFEF7F000000000000FC1FFFFFFFFFFFFFEF7F000000000000FC1F000000000000F07F000000000000FC1F000000000000F0FF000000000000F83FFFFFFFFFFFFFEFFF000000000000F83FFEFFFFFFFFFFEFFF000000000000F83F000000000000F8BF000000000000F83F000000000000FC9F000000000000F83F0100000000000080000000000000F83F0000000000000000000000000000F83F0100000000000000000000000000F83F000000000000FC1F000000000000F83F000000000000F83F000000000000F83FFEFFFFFFFFFFEF7F000000000000F83FFFFFFFFFFFFFEF7F000000000000F83F000000000000F07F000000000000F83F000000000000F0FFFEFFFFFFFFFFEF7FFFFFFFFFFFFFEFFFFEFFFFFFFFFFEF7FFEFFFFFFFFFFEFFFFEFFFFFFFFFFEF7F000000000000F8BFFEFFFFFFFFFFEF7F000000000000FC9FFEFFFFFFFFFFEF7F0100000000000080FEFFFFFFFFFFEF7F0000000000000000FEFFFFFFFFFFEF7F0100000000000000FEFFFFFFFFFFEF7F000000000000FC1FFEFFFFFFFFFFEF7F000000000000F83FFEFFFFFFFFFFEF7FFEFFFFFFFFFFEF7FFEFFFFFFFFFFEF7FFFFFFFFFFFFFEF7FFEFFFFFFFFFFEF7F000000000000F07FFEFFFFFFFFFFEF7F000000000000F0FFFFFFFFFFFFFFEF7FFFFFFFFFFFFFEFFFFFFFFFFFFFFFEF7FFEFFFFFFFFFFEFFFFFFFFFFFFFFFEF7F000000000000F8BFFFFFFFFFFFFFEF7F000000000000FC9FFFFFFFFFFFFFEF7F0100000000000080FFFFFFFFFFFFEF7F0000000000000000FFFFFFFFFFFFEF7F0100000000000000FFFFFFFFFFFFEF7F000000000000FC1FFFFFFFFFFFFFEF7F000000000000F83FFFFFFFFFFFFFEF7FFEFFFFFFFFFFEF7FFFFFFFFFFFFFEF7FFFFFFFFFFFFFEF7FFFFFFFFFFFFFEF7F000000000000F07FFFFFFFFFFFFFEF7F000000000000F0FF000000000000F07FFFFFFFFFFFFFEFFF000000000000F07FFEFFFFFFFFFFEFFF000000000000F07F000000000000F8BF000000000000F07F000000000000FC9F000000000000F07F0100000000000080000000000000F07F0000000000000000000000000000F07F0100000000000000000000000000F07F000000000000FC1F000000000000F07F000000000000F83F000000000000F07FFEFFFFFFFFFFEF7F000000000000F07FFFFFFFFFFFFFEF7F000000000000F07F000000000000F07F000000000000F07F"> : tensor<169xcomplex> + return %0 : tensor<169xcomplex> + } + func.func private @expected() -> tensor<169xcomplex> { + %0 = stablehlo.constant dense<"0x000000000000F07FD221337F7CD902C0000000000000F07F182D4454FB21F9BF000000000000F07F182D4454FB21F9BF000000000000F07F182D4454FB21F9BF000000000000F07F182D4454FB21F9BF000000000000F07F182D4454FB21F9BF000000000000F07F182D4454FB21F9BF000000000000F07F182D4454FB21F9BF000000000000F07F182D4454FB21F9BF000000000000F07F182D4454FB21F9BF000000000000F07F182D4454FB21F9BF000000000000F07F182D4454FB21F9BF000000000000F07F182D4454FB21E9BF000000000000F07F182D4454FB2109C036195AC708318640D221337F7CD902C036195AC708318640D221337F7CD902C0EF39FAFE422E8640182D4454FB21F9BFEF39FAFE422E8640182D4454FB21F9BFEF39FAFE422E8640182D4454FB21F9BFEF39FAFE422E8640182D4454FB21F9BFEF39FAFE422E8640182D4454FB21F9BFEF39FAFE422E8640182D4454FB21F9BFEF39FAFE422E8640182D4454FB21F9BF36195AC708318640192D4454FB21E9BF36195AC708318640182D4454FB21E9BF000000000000F07F0000000000000000000000000000F07F182D4454FB2109C036195AC708318640D221337F7CD902C036195AC708318640D221337F7CD902C0EF39FAFE422E8640182D4454FB21F9BFEF39FAFE422E8640182D4454FB21F9BFEF39FAFE422E8640182D4454FB21F9BFEF39FAFE422E8640182D4454FB21F9BFEF39FAFE422E8640182D4454FB21F9BFEF39FAFE422E8640182D4454FB21F9BFEF39FAFE422E8640182D4454FB21F9BF36195AC708318640182D4454FB21E9BF36195AC708318640182D4454FB21E9BF000000000000F07F0000000000000000000000000000F07F182D4454FB2109C0EF39FAFE422E8640182D4454FB2109C0EF39FAFE422E8640182D4454FB2109C078E0E0F04052DD3FD1D40D3DDF47FEBFB3DE6857C5DBE23F9BF681D20B73EFBFB3DE6857C5DBE23F9BF681D20B73EFBFB3DE6857C5DBE23F9BF681D20B73EFBFB3DE6857C5DBE23F9BF681D20B73EFBFB3DE6857C5DBE23F9BF681D20B73EFBFFEF03B02DB1EF13FE10CF9D51D4BE1BFEF39FAFE422E86400000000000000680EF39FAFE422E86400000000000000680000000000000F07F0000000000000000000000000000F07F182D4454FB2109C0EF39FAFE422E8640182D4454FB2109C0EF39FAFE422E8640182D4454FB2109C0EF39FAFE422EE6BF182D4454FB2109C0000000000000FC9F000000000000FC9FFFFFFFFFFF1F0600000000000000FC9F0000000000200600000000000000FC9F0100000000200600000000000000FC9F000000000000FC1F000000000000FC9F78E0E0F04052ED3F666666666666E69FEF39FAFE422E86400000000000000080EF39FAFE422E86400000000000000080000000000000F07F0000000000000000000000000000F07F182D4454FB2109C0EF39FAFE422E8640182D4454FB2109C0EF39FAFE422E8640182D4454FB2109C0EF39FAFE422EE6BF182D4454FB2109C0000000000000FC9F0100000000000080010000000000008001000000000000800000000000000000010000000000008001000000000000000100000000000080000000000000FC1F010000000000008078E0E0F04052ED3F0000000000000080EF39FAFE422E86400000000000000080EF39FAFE422E86400000000000000080000000000000F07F0000000000000000000000000000F07F182D4454FB210940EF39FAFE422E8640182D4454FB210940EF39FAFE422E8640182D4454FB210940EF39FAFE422EE6BF182D4454FB210940000000000000FC9F0000000000000000010000000000008000000000000000000000000000000000000000000000000001000000000000000000000000000000000000000000FC1F000000000000000078E0E0F04052ED3F0000000000000000EF39FAFE422E86400000000000000000EF39FAFE422E86400000000000000000000000000000F07F0000000000000000000000000000F07F182D4454FB210940EF39FAFE422E8640182D4454FB210940EF39FAFE422E8640182D4454FB210940EF39FAFE422EE6BF182D4454FB210940000000000000FC9F0100000000000000010000000000008001000000000000000000000000000000010000000000000001000000000000000100000000000000000000000000FC1F010000000000000078E0E0F04052ED3F0000000000000000EF39FAFE422E86400000000000000000EF39FAFE422E86400000000000000000000000000000F07F0000000000000000000000000000F07F182D4454FB210940EF39FAFE422E8640182D4454FB210940EF39FAFE422E8640182D4454FB210940EF39FAFE422EE6BF182D4454FB210940000000000000FC9F000000000000FC1FFFFFFFFFFF1F0600000000000000FC1F0000000000200600000000000000FC1F0100000000200600000000000000FC1F000000000000FC1F000000000000FC1F78E0E0F04052ED3F666666666666E61FEF39FAFE422E86400000000000000000EF39FAFE422E86400000000000000000000000000000F07F0000000000000000000000000000F07F182D4454FB210940EF39FAFE422E8640182D4454FB210940EF39FAFE422E8640182D4454FB21094078E0E0F04052DD3FD1D40D3DDF47FE3FB3DE6857C5DBE23F9BF681D20B73EF3FB3DE6857C5DBE23F9BF681D20B73EF3FB3DE6857C5DBE23F9BF681D20B73EF3FB3DE6857C5DBE23F9BF681D20B73EF3FB3DE6857C5DBE23F9BF681D20B73EF3FFEF03B02DB1EF13FE10CF9D51D4BE13FEF39FAFE422E86400000000000000600EF39FAFE422E86400000000000000600000000000000F07F0000000000000000000000000000F07F182D4454FB21094036195AC708318640D221337F7CD9024036195AC708318640D221337F7CD90240EF39FAFE422E8640182D4454FB21F93FEF39FAFE422E8640182D4454FB21F93FEF39FAFE422E8640182D4454FB21F93FEF39FAFE422E8640182D4454FB21F93FEF39FAFE422E8640182D4454FB21F93FEF39FAFE422E8640182D4454FB21F93FEF39FAFE422E8640182D4454FB21F93F36195AC708318640182D4454FB21E93F36195AC708318640182D4454FB21E93F000000000000F07F0000000000000000000000000000F07F182D4454FB21094036195AC708318640D221337F7CD9024036195AC708318640D221337F7CD90240EF39FAFE422E8640182D4454FB21F93FEF39FAFE422E8640182D4454FB21F93FEF39FAFE422E8640182D4454FB21F93FEF39FAFE422E8640182D4454FB21F93FEF39FAFE422E8640182D4454FB21F93FEF39FAFE422E8640182D4454FB21F93FEF39FAFE422E8640182D4454FB21F93F36195AC708318640192D4454FB21E93F36195AC708318640182D4454FB21E93F000000000000F07F0000000000000000000000000000F07FD221337F7CD90240000000000000F07F182D4454FB21F93F000000000000F07F182D4454FB21F93F000000000000F07F182D4454FB21F93F000000000000F07F182D4454FB21F93F000000000000F07F182D4454FB21F93F000000000000F07F182D4454FB21F93F000000000000F07F182D4454FB21F93F000000000000F07F182D4454FB21F93F000000000000F07F182D4454FB21F93F000000000000F07F182D4454FB21F93F000000000000F07F182D4454FB21F93F000000000000F07F182D4454FB21E93F"> : tensor<169xcomplex> + return %0 : tensor<169xcomplex> + } + func.func public @main() { + %0 = call @samples() : () -> tensor<169xcomplex> + %1 = "stablehlo.log_plus_one"(%0) : (tensor<169xcomplex>) -> tensor<169xcomplex> + %2 = call @expected() : () -> tensor<169xcomplex> + check.expect_close %1, %2, max_ulp_difference = 3 : tensor<169xcomplex>, tensor<169xcomplex> + func.return + } +} diff --git a/stablehlo/tests/math/log_plus_one_complex64.mlir b/stablehlo/tests/math/log_plus_one_complex64.mlir new file mode 100644 index 0000000000..e00c93251a --- /dev/null +++ b/stablehlo/tests/math/log_plus_one_complex64.mlir @@ -0,0 +1,19 @@ +// RUN: stablehlo-opt --stablehlo-complex-math-expander %s | stablehlo-translate --interpret +// This file is generated, see build_tools/math/README.md for more information. +module @log_plus_one_complex64 { + func.func private @samples() -> tensor<169xcomplex> { + %0 = stablehlo.constant dense<"0x000080FF000080FFFFFF7FFF000080FFFEFF7FFF000080FF0000C0BF000080FF0000E09F000080FF01000080000080FF00000000000080FF01000000000080FF0000E01F000080FF0000C03F000080FFFEFF7F7F000080FFFFFF7F7F000080FF0000807F000080FF000080FFFFFF7FFFFFFF7FFFFFFF7FFFFEFF7FFFFFFF7FFF0000C0BFFFFF7FFF0000E09FFFFF7FFF01000080FFFF7FFF00000000FFFF7FFF01000000FFFF7FFF0000E01FFFFF7FFF0000C03FFFFF7FFFFEFF7F7FFFFF7FFFFFFF7F7FFFFF7FFF0000807FFFFF7FFF000080FFFEFF7FFFFFFF7FFFFEFF7FFFFEFF7FFFFEFF7FFF0000C0BFFEFF7FFF0000E09FFEFF7FFF01000080FEFF7FFF00000000FEFF7FFF01000000FEFF7FFF0000E01FFEFF7FFF0000C03FFEFF7FFFFEFF7F7FFEFF7FFFFFFF7F7FFEFF7FFF0000807FFEFF7FFF000080FF0000C0BFFFFF7FFF0000C0BFFEFF7FFF0000C0BF0000C0BF0000C0BF0000E09F0000C0BF010000800000C0BF000000000000C0BF010000000000C0BF0000E01F0000C0BF0000C03F0000C0BFFEFF7F7F0000C0BFFFFF7F7F0000C0BF0000807F0000C0BF000080FF0000E09FFFFF7FFF0000E09FFEFF7FFF0000E09F0000C0BF0000E09F0000E09F0000E09F010000800000E09F000000000000E09F010000000000E09F0000E01F0000E09F0000C03F0000E09FFEFF7F7F0000E09FFFFF7F7F0000E09F0000807F0000E09F000080FF01000080FFFF7FFF01000080FEFF7FFF010000800000C0BF010000800000E09F010000800100008001000080000000000100008001000000010000800000E01F010000800000C03F01000080FEFF7F7F01000080FFFF7F7F010000800000807F01000080000080FF00000000FFFF7FFF00000000FEFF7FFF000000000000C0BF000000000000E09F000000000100008000000000000000000000000001000000000000000000E01F000000000000C03F00000000FEFF7F7F00000000FFFF7F7F000000000000807F00000000000080FF01000000FFFF7FFF01000000FEFF7FFF010000000000C0BF010000000000E09F010000000100008001000000000000000100000001000000010000000000E01F010000000000C03F01000000FEFF7F7F01000000FFFF7F7F010000000000807F01000000000080FF0000E01FFFFF7FFF0000E01FFEFF7FFF0000E01F0000C0BF0000E01F0000E09F0000E01F010000800000E01F000000000000E01F010000000000E01F0000E01F0000E01F0000C03F0000E01FFEFF7F7F0000E01FFFFF7F7F0000E01F0000807F0000E01F000080FF0000C03FFFFF7FFF0000C03FFEFF7FFF0000C03F0000C0BF0000C03F0000E09F0000C03F010000800000C03F000000000000C03F010000000000C03F0000E01F0000C03F0000C03F0000C03FFEFF7F7F0000C03FFFFF7F7F0000C03F0000807F0000C03F000080FFFEFF7F7FFFFF7FFFFEFF7F7FFEFF7FFFFEFF7F7F0000C0BFFEFF7F7F0000E09FFEFF7F7F01000080FEFF7F7F00000000FEFF7F7F01000000FEFF7F7F0000E01FFEFF7F7F0000C03FFEFF7F7FFEFF7F7FFEFF7F7FFFFF7F7FFEFF7F7F0000807FFEFF7F7F000080FFFFFF7F7FFFFF7FFFFFFF7F7FFEFF7FFFFFFF7F7F0000C0BFFFFF7F7F0000E09FFFFF7F7F01000080FFFF7F7F00000000FFFF7F7F01000000FFFF7F7F0000E01FFFFF7F7F0000C03FFFFF7F7FFEFF7F7FFFFF7F7FFFFF7F7FFFFF7F7F0000807FFFFF7F7F000080FF0000807FFFFF7FFF0000807FFEFF7FFF0000807F0000C0BF0000807F0000E09F0000807F010000800000807F000000000000807F010000000000807F0000E01F0000807F0000C03F0000807FFEFF7F7F0000807FFFFF7F7F0000807F0000807F0000807F"> : tensor<169xcomplex> + return %0 : tensor<169xcomplex> + } + func.func private @expected() -> tensor<169xcomplex> { + %0 = stablehlo.constant dense<"0x0000807FE4CB16C00000807FDB0FC9BF0000807FDB0FC9BF0000807FDB0FC9BF0000807FDB0FC9BF0000807FDB0FC9BF0000807FDB0FC9BF0000807FDB0FC9BF0000807FDB0FC9BF0000807FDB0FC9BF0000807FDB0FC9BF0000807FDB0FC9BF0000807FDB0F49BF0000807FDB0F49C08A23B242E4CB16C08A23B242E4CB16C01872B142DB0FC9BF1872B142DB0FC9BF1872B142DB0FC9BF1872B142DB0FC9BF1872B142DB0FC9BF1872B142DB0FC9BF1872B142DB0FC9BF8A23B242DB0F49BF8A23B242DB0F49BF0000807F000000000000807FDB0F49C08A23B242E4CB16C08A23B242E4CB16C01872B142DB0FC9BF1872B142DB0FC9BF1872B142DB0FC9BF1872B142DB0FC9BF1872B142DB0FC9BF1872B142DB0FC9BF1872B142DB0FC9BF8A23B242DB0F49BF8A23B242DA0F49BF0000807F000000000000807FDB0F49C01872B142DB0F49C01872B142DB0F49C00892EA3EFA3EF2BF2BDE163F5F987BBF2BDE163F5F987BBF2BDE163F5F987BBF2BDE163F5F987BBF2BDE163F5F987BBFD8F6883FEF580ABF1872B142000030801872B142000030800000807F000000000000807FDB0F49C01872B142DB0F49C01872B142DB0F49C0187231BFDB0F49C00000E09F0000E09FFFFF30000000E09F000031000000E09F010031000000E09F0000E01F0000E09F08926A3F3333339F1872B142000000801872B142000000800000807F000000000000807FDB0F49C01872B142DB0F49C01872B142DB0F49C0187231BFDB0F49C00000E09F010000800100008001000080000000000100008001000000010000800000E01F0100008008926A3F000000801872B142000000801872B142000000800000807F000000000000807FDB0F49401872B142DB0F49401872B142DB0F4940187231BFDB0F49400000E09F000000000100008000000000000000000000000001000000000000000000E01F0000000008926A3F000000001872B142000000001872B142000000000000807F000000000000807FDB0F49401872B142DB0F49401872B142DB0F4940187231BFDB0F49400000E09F010000000100008001000000000000000100000001000000010000000000E01F0100000008926A3F000000001872B142000000001872B142000000000000807F000000000000807FDB0F49401872B142DB0F49401872B142DB0F4940187231BFDB0F49400000E09F0000E01FFFFF30000000E01F000031000000E01F010031000000E01F0000E01F0000E01F08926A3F3333331F1872B142000000001872B142000000000000807F000000000000807FDB0F49401872B142DB0F49401872B142DB0F49400892EA3EFA3EF23F2BDE163F5F987B3F2BDE163F5F987B3F2BDE163F5F987B3F2BDE163F5F987B3F2BDE163F5F987B3FD8F6883FEF580A3F1872B142000030001872B142000030000000807F000000000000807FDB0F49408A23B242E4CB16408A23B242E4CB16401872B142DB0FC93F1872B142DB0FC93F1872B142DB0FC93F1872B142DB0FC93F1872B142DB0FC93F1872B142DB0FC93F1872B142DB0FC93F8A23B242DB0F493F8A23B242DA0F493F0000807F000000000000807FDB0F49408A23B242E4CB16408A23B242E4CB16401872B142DB0FC93F1872B142DB0FC93F1872B142DB0FC93F1872B142DB0FC93F1872B142DB0FC93F1872B142DB0FC93F1872B142DB0FC93F8A23B242DB0F493F8A23B242DB0F493F0000807F000000000000807FE4CB16400000807FDB0FC93F0000807FDB0FC93F0000807FDB0FC93F0000807FDB0FC93F0000807FDB0FC93F0000807FDB0FC93F0000807FDB0FC93F0000807FDB0FC93F0000807FDB0FC93F0000807FDB0FC93F0000807FDB0FC93F0000807FDB0F493F"> : tensor<169xcomplex> + return %0 : tensor<169xcomplex> + } + func.func public @main() { + %0 = call @samples() : () -> tensor<169xcomplex> + %1 = "stablehlo.log_plus_one"(%0) : (tensor<169xcomplex>) -> tensor<169xcomplex> + %2 = call @expected() : () -> tensor<169xcomplex> + check.expect_close %1, %2, max_ulp_difference = 3 : tensor<169xcomplex>, tensor<169xcomplex> + func.return + } +} diff --git a/stablehlo/tests/math/log_plus_one_float32.mlir b/stablehlo/tests/math/log_plus_one_float32.mlir new file mode 100644 index 0000000000..4384eadb70 --- /dev/null +++ b/stablehlo/tests/math/log_plus_one_float32.mlir @@ -0,0 +1,19 @@ +// RUN: stablehlo-opt --stablehlo-complex-math-expander %s | stablehlo-translate --interpret +// This file is generated, see build_tools/math/README.md for more information. +module @log_plus_one_float32 { + func.func private @samples() -> tensor<169xf32> { + %0 = stablehlo.constant dense<"0x000080FFFFFF7FFFFEFF7FFF05E763FC88DAD5FA0BCE47F98EC1B9F711B52BF695A89DF4189C0FF39B8F81F11E83F3EFA17665EE246AD7ECA75D49EB2B51BBE9AE442DE831389FE6B42B11E5371F83E3BA12F5E13D0667E0C1F9D8DE44ED4ADDC7E0BCDB4AD42EDACDC7A0D850BB12D7D3AE84D557A2F6D3DA9568D25D89DAD0E07C4CCF6370BECDE66330CC6957A2CAED4A14C9703E86C7F331F8C576256AC4F918DCC27C0C4EC10000C0BF83F331BE06E7A3BC89DA15BB0CCE87B98FC1F9B712B56BB696A8DDB4199C4FB39C8FC1B11F8333B0A276A5AE256A17ADA85D89AB2C51FBA9AF446DA83238DFA6B52B51A5381FC3A3BB1235A23E06A7A0C2F9189F45ED8A9DC8E0FC9B4BD46E9ACEC7E09851BB5297D4AEC49558A23694DB95A8925E891A91E17C8C8F6470FE8DE763708C6A57E28AEE4A5489713EC687F43138867725AA84FA181C837D0C8E810100008000000000010000007D0C8E01FA181C037725AA04F4313806713EC607EE4A54096A57E20AE763700C6470FE0DE17C8C0F5E891A11DB95A81258A23614D4AEC41551BB5217CEC7E0184BD46E1AC8E0FC1B45ED8A1DC2F9181F3E06A720BB123522381FC323B52B51253238DF26AF446D282C51FB29A85D892B256A172DA276A52E1F8333309C8FC131199C4F3396A8DD3412B56B368FC1F9370CCE873989DA153B06E7A33C83F3313E0000C03F7C0C4E41F918DC4276256A44F331F845703E8647ED4A14496957A24AE663304C6370BE4DE07C4C4F5D89DA50DA95685257A2F653D3AE845550BB1257CDC7A0584AD42E5AC7E0BC5B44ED4A5DC1F9D85E3D066760BA12F561371F8363B42B116531389F66AE442D682B51BB69A75D496B246AD76CA176656E1E83F36F9B8F8171189C0F7395A89D7411B52B768EC1B9770BCE477988DAD57A05E7637CFEFF7F7FFFFF7F7F0000807F"> : tensor<169xf32> + return %0 : tensor<169xf32> + } + func.func private @expected() -> tensor<169xf32> { + %0 = stablehlo.constant dense<"0x0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07FEC7943BE7590A5BC760616BB8DD287B983C2F9B72DB56BB699A8DDB4199C4FB39C8FC1B11F8333B0A276A5AE256A17ADA85D89AB2C51FBA9AF446DA83238DFA6B52B51A5381FC3A3BB1235A23E06A7A0C2F9189F45ED8A9DC8E0FC9B4BD46E9ACEC7E09851BB5297D4AEC49558A23694DB95A8925E891A91E17C8C8F6470FE8DE763708C6A57E28AEE4A5489713EC687F43138867725AA84FA181C837D0C8E810100008000000000010000007D0C8E01FA181C037725AA04F4313806713EC607EE4A54096A57E20AE763700C6470FE0DE17C8C0F5E891A11DB95A81258A23614D4AEC41551BB5217CEC7E0184BD46E1AC8E0FC1B45ED8A1DC2F9181F3E06A720BB123522381FC323B52B51253238DF26AF446D282C51FB29A85D892B256A172DA276A52E1F8333309C8FC131199C4F3393A8DD34F7B46B369BC0F9378CC98739BEAE153BCA48A23C4313243E08926A3FF956284012B8964061FCDA4065AE0F413D3532414012554162CA7741DD318D4182719E41F9A5AF41BFD0C0410BF3D141DA0DE3417846F441A4DB02429E8A0B42BF3114423ED21C420E6D2542F3022E428B94364257223F4208B947429B7250428B2259427ACA6142A96B6A4210077342799D7B42C2178242DB5E864257A78A42BA048F422F5D934290B1974281029C428150A042F89BA44238E5A8421872B1421872B1420000807F"> : tensor<169xf32> + return %0 : tensor<169xf32> + } + func.func public @main() { + %0 = call @samples() : () -> tensor<169xf32> + %1 = "stablehlo.log_plus_one"(%0) : (tensor<169xf32>) -> tensor<169xf32> + %2 = call @expected() : () -> tensor<169xf32> + check.expect_close %1, %2, max_ulp_difference = 3 : tensor<169xf32>, tensor<169xf32> + func.return + } +} diff --git a/stablehlo/tests/math/log_plus_one_float64.mlir b/stablehlo/tests/math/log_plus_one_float64.mlir new file mode 100644 index 0000000000..8d0c0c7f1d --- /dev/null +++ b/stablehlo/tests/math/log_plus_one_float64.mlir @@ -0,0 +1,19 @@ +// RUN: stablehlo-opt --stablehlo-complex-math-expander %s | stablehlo-translate --interpret +// This file is generated, see build_tools/math/README.md for more information. +module @log_plus_one_float64 { + func.func private @samples() -> tensor<169xf64> { + %0 = stablehlo.constant dense<"0x000000000000F0FFFFFFFFFFFFFFEFFFFEFFFFFFFFFFEFFF2A51BB12B52BD1FCC0F9189C8FC141FB56A276256A57B2F9EC4AD4AE44ED22F882F331381F8393F6189C8FC1F91804F5AE44ED4AD4AE74F343ED4AD4AE44E5F1D995A85D89DA55F06F3E06E76370C6EE05E763703E0637ED9B8FC1F9189CA7EB31381F83F33118EAC7E07C0CCEC788E85D89DA95A85DF9E6F231381F83F369E588DA95A85D89DAE31E83F331381F4BE2B42B51BB12B5BBE04AD4AE44ED4A2CDFE07C0CCEC7E09CDD76256A57A2760DDC0CCEC7E07C0C7EDAA176256A57A2EED8371F83F331385FD7CDC7E07C0CCECFD563703E06E76340D4F9189C8FC1F9B0D28FC1F9189C8F21D1256A57A2762592CFBB12B52B51BB02CE50BB12B52B5173CCE663703E06E7E3CA7C0CCEC7E07C54C912B52B51BB12C5C7A85D89DA95A835C63E06E763703EA6C4D4AE44ED4AD416C36A57A276256A87C1000000000000F8BF95A85D89DA9568BE2B51BB12B52BD9BCC1F9189C8FC149BB57A276256A57BAB9ED4AD4AE44ED2AB883F331381F839BB6199C8FC1F9180CB5AF44ED4AD4AE7CB344ED4AD4AE44EDB1DA95A85D89DA5DB0703E06E76370CEAE06E763703E063FAD9C8FC1F9189CAFAB32381F83F33120AAC8E07C0CCEC790A85E89DA95A85D01A7F331381F83F371A589DA95A85D89E2A31F83F331381F53A2B52B51BB12B5C3A04BD4AE44ED4A349FE17C0CCEC7E0A49D77256A57A276159C0DCEC7E07C0C869AA276256A57A2F698381F83F331386797CEC7E07C0CCED79564703E06E7634894FA189C8FC1F9B89290C1F9189C8F2991266A57A276259A8FBC12B52B51BB0A8E51BB12B52B517B8CE763703E06E7EB8A7D0CCEC7E07C5C8913B52B51BB12CD87A95D89DA95A83D863F06E763703EAE84D5AE44ED4AD41E836B57A276256A8F810100000000000080000000000000000001000000000000006B57A276256A8F01D5AE44ED4AD41E033F06E763703EAE04A95D89DA95A83D0613B52B51BB12CD077D0CCEC7E07C5C09E763703E06E7EB0A51BB12B52B517B0CBC12B52B51BB0A0E266A57A276259A0F90C1F9189C8F2911FA189C8FC1F9B81264703E06E7634814CEC7E07C0CCED715381F83F331386717A276256A57A2F6180DCEC7E07C0C861A77256A57A276151CE17C0CCEC7E0A41D4BD4AE44ED4A341FB52B51BB12B5C3201F83F331381F532289DA95A85D89E223F331381F83F371255E89DA95A85D0127C8E07C0CCEC7902832381F83F331202A9C8FC1F9189CAF2B06E763703E063F2D703E06E76370CE2EDA95A85D89DA5D3044ED4AD4AE44ED31AF44ED4AD4AE7C33199C8FC1F9180C3583F331381F839B36ED4AD4AE44ED2A3857A276256A57BA39C1F9189C8FC1493B2B51BB12B52BD93C95A85D89DA95683E000000000000F83F6A57A276256A8741D4AE44ED4AD416433E06E763703EA644A85D89DA95A8354612B52B51BB12C5477C0CCEC7E07C5449E663703E06E7E34A50BB12B52B51734CBB12B52B51BB024E256A57A27625924F8FC1F9189C8F2151F9189C8FC1F9B05263703E06E7634054CDC7E07C0CCECF55371F83F331385F57A176256A57A2EE580CCEC7E07C0C7E5A76256A57A2760D5CE07C0CCEC7E09C5D4AD4AE44ED4A2C5FB42B51BB12B5BB601E83F331381F4B6288DA95A85D89DA63F231381F83F369655D89DA95A85DF966C7E07C0CCEC7886831381F83F331186A9B8FC1F9189CA76B05E763703E06376D6F3E06E76370C66ED995A85D89DA557043ED4AD4AE44E571AE44ED4AD4AE7473189C8FC1F918047582F331381F839376EC4AD4AE44ED227856A276256A57B279C0F9189C8FC1417B2A51BB12B52BD17CFEFFFFFFFFFFEF7FFFFFFFFFFFFFEF7F000000000000F07F"> : tensor<169xf64> + return %0 : tensor<169xf64> + } + func.func private @expected() -> tensor<169xf64> { + %0 = stablehlo.constant dense<"0x000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F616BCF92DA9568BE3051BB12B52BD9BCC1F9189C8FC149BB57A276256A57BAB9ED4AD4AE44ED2AB883F331381F839BB6199C8FC1F9180CB5AF44ED4AD4AE7CB344ED4AD4AE44EDB1DA95A85D89DA5DB0703E06E76370CEAE06E763703E063FAD9C8FC1F9189CAFAB32381F83F33120AAC8E07C0CCEC790A85E89DA95A85D01A7F331381F83F371A589DA95A85D89E2A31F83F331381F53A2B52B51BB12B5C3A04BD4AE44ED4A349FE17C0CCEC7E0A49D77256A57A276159C0DCEC7E07C0C869AA276256A57A2F698381F83F331386797CEC7E07C0CCED79564703E06E7634894FA189C8FC1F9B89290C1F9189C8F2991266A57A276259A8FBC12B52B51BB0A8E51BB12B52B517B8CE763703E06E7EB8A7D0CCEC7E07C5C8913B52B51BB12CD87A95D89DA95A83D863F06E763703EAE84D5AE44ED4AD41E836B57A276256A8F810100000000000080000000000000000001000000000000006B57A276256A8F01D5AE44ED4AD41E033F06E763703EAE04A95D89DA95A83D0613B52B51BB12CD077D0CCEC7E07C5C09E763703E06E7EB0A51BB12B52B517B0CBC12B52B51BB0A0E266A57A276259A0F90C1F9189C8F2911FA189C8FC1F9B81264703E06E7634814CEC7E07C0CCED715381F83F331386717A276256A57A2F6180DCEC7E07C0C861A77256A57A276151CE17C0CCEC7E0A41D4BD4AE44ED4A341FB52B51BB12B5C3201F83F331381F532289DA95A85D89E223F331381F83F371255E89DA95A85D0127C8E07C0CCEC7902832381F83F331202A9C8FC1F9189CAF2B06E763703E063F2D703E06E76370CE2EDA95A85D89DA5D3044ED4AD4AE44ED31AF44ED4AD4AE7C33199C8FC1F9180C3583F331381F839B36ED4AD4AE44ED2A3857A276256A57BA39C1F9189C8FC1493B2651BB12B52BD93CD2E5EB7FDA95683E78E0E0F04052ED3FD36BF2A59EB531408C923CE1A38141406B04AAE262284A40B79D659885675140CAF791C6CDBA5540B805264F090E5A40DB2D687637615E4041F9E6B72B5A614028F5F42DB48363405CC602A334AD6540B120AB90ACD66740F898DC621B006A40D029EE7580296C40BCBAF17D0D536E40FA8CABED7C3E7040D2CDA9AB715371406A4BB9EA646872405AB5C89B567D7340F67CDEAE469274408BCE051335A775404D8D39B621BC7640D6104D850CD17740D358D26BF5E578401260FD53DCFA7940E9298426C10F7B40B3147BCAA3247C405FE82C2584397D408000EE19624E7E403AD5E9893D637F402002F5290B3C8040F7E80A2A76C680402474D4B1DF5081402FEE4FAC47DB8140765BAC02AE658240E2AF129C12F08240BAAB665D757A83408BCCFD28D6048440B0634ADE348F8440C36F785991198540EF39FAFE422E8640EF39FAFE422E8640000000000000F07F"> : tensor<169xf64> + return %0 : tensor<169xf64> + } + func.func public @main() { + %0 = call @samples() : () -> tensor<169xf64> + %1 = "stablehlo.log_plus_one"(%0) : (tensor<169xf64>) -> tensor<169xf64> + %2 = call @expected() : () -> tensor<169xf64> + check.expect_close %1, %2, max_ulp_difference = 3 : tensor<169xf64>, tensor<169xf64> + func.return + } +} diff --git a/stablehlo/tests/stablehlo_complex_math_expander.mlir b/stablehlo/tests/stablehlo_complex_math_expander.mlir new file mode 100644 index 0000000000..7a43f76d6e --- /dev/null +++ b/stablehlo/tests/stablehlo_complex_math_expander.mlir @@ -0,0 +1,110 @@ +// RUN: stablehlo-opt --stablehlo-complex-math-expander --split-input-file --verify-diagnostics %s | FileCheck %s + +// CHECK-LABEL: func.func @log_plus_one_complex_f32( +// CHECK-SAME: %[[VAL_0:.*]]: tensor>) -> tensor> { +// CHECK: %[[VAL_1:.*]] = stablehlo.constant dense<6.500000e+01> : tensor +// CHECK: %[[VAL_2:.*]] = stablehlo.constant dense<4.097000e+03> : tensor +// CHECK: %[[VAL_3:.*]] = stablehlo.constant dense<9.99999968E+37> : tensor +// CHECK: %[[VAL_4:.*]] = stablehlo.constant dense<0x4D000000> : tensor +// CHECK: %[[VAL_5:.*]] = stablehlo.constant dense<0x7F800000> : tensor +// CHECK: %[[VAL_6:.*]] = stablehlo.constant dense<2.000000e-01> : tensor +// CHECK: %[[VAL_7:.*]] = stablehlo.constant dense<1.000000e+00> : tensor +// CHECK: %[[VAL_8:.*]] = stablehlo.constant dense<5.000000e-01> : tensor +// CHECK: %[[VAL_9:.*]] = stablehlo.constant dense<0.00999999977> : tensor +// CHECK: %[[VAL_10:.*]] = stablehlo.constant dense<3.40282347E+38> : tensor +// CHECK: %[[VAL_11:.*]] = stablehlo.real %[[VAL_0]] : (tensor>) -> tensor +// CHECK: %[[VAL_12:.*]] = stablehlo.abs %[[VAL_11]] : tensor +// CHECK: %[[VAL_13:.*]] = stablehlo.imag %[[VAL_0]] : (tensor>) -> tensor +// CHECK: %[[VAL_14:.*]] = stablehlo.abs %[[VAL_13]] : tensor +// CHECK: %[[VAL_15:.*]] = stablehlo.maximum %[[VAL_12]], %[[VAL_14]] : tensor +// CHECK: %[[VAL_16:.*]] = stablehlo.sqrt %[[VAL_10]] : tensor +// CHECK: %[[VAL_17:.*]] = stablehlo.multiply %[[VAL_16]], %[[VAL_9]] : tensor +// CHECK: %[[VAL_18:.*]] = stablehlo.compare GT, %[[VAL_15]], %[[VAL_17]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_19:.*]] = stablehlo.log %[[VAL_15]] : tensor +// CHECK: %[[VAL_20:.*]] = stablehlo.minimum %[[VAL_12]], %[[VAL_14]] : tensor +// CHECK: %[[VAL_21:.*]] = stablehlo.compare EQ, %[[VAL_20]], %[[VAL_15]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_22:.*]] = stablehlo.divide %[[VAL_20]], %[[VAL_15]] : tensor +// CHECK: %[[VAL_23:.*]] = stablehlo.multiply %[[VAL_22]], %[[VAL_22]] : tensor +// CHECK: %[[VAL_24:.*]] = stablehlo.select %[[VAL_21]], %[[VAL_7]], %[[VAL_23]] : tensor, tensor +// CHECK: %[[VAL_25:.*]] = stablehlo.log_plus_one %[[VAL_24]] : tensor +// CHECK: %[[VAL_26:.*]] = stablehlo.multiply %[[VAL_8]], %[[VAL_25]] : tensor +// CHECK: %[[VAL_27:.*]] = stablehlo.add %[[VAL_19]], %[[VAL_26]] : tensor +// CHECK: %[[VAL_28:.*]] = stablehlo.add %[[VAL_11]], %[[VAL_7]] : tensor +// CHECK: %[[VAL_29:.*]] = stablehlo.abs %[[VAL_28]] : tensor +// CHECK: %[[VAL_30:.*]] = stablehlo.add %[[VAL_29]], %[[VAL_14]] : tensor +// CHECK: %[[VAL_31:.*]] = stablehlo.compare LT, %[[VAL_30]], %[[VAL_6]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_32:.*]] = stablehlo.multiply %[[VAL_28]], %[[VAL_28]] : tensor +// CHECK: %[[VAL_33:.*]] = stablehlo.multiply %[[VAL_13]], %[[VAL_13]] : tensor +// CHECK: %[[VAL_34:.*]] = stablehlo.add %[[VAL_32]], %[[VAL_33]] : tensor +// CHECK: %[[VAL_35:.*]] = stablehlo.log %[[VAL_34]] : tensor +// CHECK: %[[VAL_36:.*]] = stablehlo.multiply %[[VAL_8]], %[[VAL_35]] : tensor +// CHECK: %[[VAL_37:.*]] = stablehlo.add %[[VAL_11]], %[[VAL_11]] : tensor +// CHECK: %[[VAL_38:.*]] = stablehlo.add %[[VAL_37]], %[[VAL_33]] : tensor +// CHECK: %[[VAL_39:.*]] = stablehlo.multiply %[[VAL_11]], %[[VAL_11]] : tensor +// CHECK: %[[VAL_40:.*]] = stablehlo.add %[[VAL_38]], %[[VAL_39]] : tensor +// CHECK: %[[VAL_41:.*]] = stablehlo.negate %[[VAL_33]] : tensor +// CHECK: %[[VAL_42:.*]] = stablehlo.compare GT, %[[VAL_10]], %[[VAL_5]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_43:.*]] = stablehlo.compare GT, %[[VAL_10]], %[[VAL_3]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_44:.*]] = stablehlo.select %[[VAL_43]], %[[VAL_2]], %[[VAL_1]] : tensor, tensor +// CHECK: %[[VAL_45:.*]] = stablehlo.select %[[VAL_42]], %[[VAL_4]], %[[VAL_44]] : tensor, tensor +// CHECK: %[[VAL_46:.*]] = stablehlo.multiply %[[VAL_45]], %[[VAL_13]] : tensor +// CHECK: %[[VAL_47:.*]] = stablehlo.subtract %[[VAL_13]], %[[VAL_46]] : tensor +// CHECK: %[[VAL_48:.*]] = stablehlo.add %[[VAL_46]], %[[VAL_47]] : tensor +// CHECK: %[[VAL_49:.*]] = stablehlo.multiply %[[VAL_48]], %[[VAL_48]] : tensor +// CHECK: %[[VAL_50:.*]] = stablehlo.add %[[VAL_41]], %[[VAL_49]] : tensor +// CHECK: %[[VAL_51:.*]] = stablehlo.subtract %[[VAL_13]], %[[VAL_48]] : tensor +// CHECK: %[[VAL_52:.*]] = stablehlo.multiply %[[VAL_48]], %[[VAL_51]] : tensor +// CHECK: %[[VAL_53:.*]] = stablehlo.add %[[VAL_50]], %[[VAL_52]] : tensor +// CHECK: %[[VAL_54:.*]] = stablehlo.add %[[VAL_53]], %[[VAL_52]] : tensor +// CHECK: %[[VAL_55:.*]] = stablehlo.multiply %[[VAL_51]], %[[VAL_51]] : tensor +// CHECK: %[[VAL_56:.*]] = stablehlo.add %[[VAL_54]], %[[VAL_55]] : tensor +// CHECK: %[[VAL_57:.*]] = stablehlo.add %[[VAL_40]], %[[VAL_56]] : tensor +// CHECK: %[[VAL_58:.*]] = stablehlo.negate %[[VAL_39]] : tensor +// CHECK: %[[VAL_59:.*]] = stablehlo.multiply %[[VAL_45]], %[[VAL_11]] : tensor +// CHECK: %[[VAL_60:.*]] = stablehlo.subtract %[[VAL_11]], %[[VAL_59]] : tensor +// CHECK: %[[VAL_61:.*]] = stablehlo.add %[[VAL_59]], %[[VAL_60]] : tensor +// CHECK: %[[VAL_62:.*]] = stablehlo.multiply %[[VAL_61]], %[[VAL_61]] : tensor +// CHECK: %[[VAL_63:.*]] = stablehlo.add %[[VAL_58]], %[[VAL_62]] : tensor +// CHECK: %[[VAL_64:.*]] = stablehlo.subtract %[[VAL_11]], %[[VAL_61]] : tensor +// CHECK: %[[VAL_65:.*]] = stablehlo.multiply %[[VAL_61]], %[[VAL_64]] : tensor +// CHECK: %[[VAL_66:.*]] = stablehlo.add %[[VAL_63]], %[[VAL_65]] : tensor +// CHECK: %[[VAL_67:.*]] = stablehlo.add %[[VAL_66]], %[[VAL_65]] : tensor +// CHECK: %[[VAL_68:.*]] = stablehlo.multiply %[[VAL_64]], %[[VAL_64]] : tensor +// CHECK: %[[VAL_69:.*]] = stablehlo.add %[[VAL_67]], %[[VAL_68]] : tensor +// CHECK: %[[VAL_70:.*]] = stablehlo.add %[[VAL_57]], %[[VAL_69]] : tensor +// CHECK: %[[VAL_71:.*]] = stablehlo.subtract %[[VAL_38]], %[[VAL_37]] : tensor +// CHECK: %[[VAL_72:.*]] = stablehlo.subtract %[[VAL_38]], %[[VAL_71]] : tensor +// CHECK: %[[VAL_73:.*]] = stablehlo.subtract %[[VAL_37]], %[[VAL_72]] : tensor +// CHECK: %[[VAL_74:.*]] = stablehlo.subtract %[[VAL_33]], %[[VAL_71]] : tensor +// CHECK: %[[VAL_75:.*]] = stablehlo.add %[[VAL_73]], %[[VAL_74]] : tensor +// CHECK: %[[VAL_76:.*]] = stablehlo.subtract %[[VAL_40]], %[[VAL_38]] : tensor +// CHECK: %[[VAL_77:.*]] = stablehlo.subtract %[[VAL_40]], %[[VAL_76]] : tensor +// CHECK: %[[VAL_78:.*]] = stablehlo.subtract %[[VAL_38]], %[[VAL_77]] : tensor +// CHECK: %[[VAL_79:.*]] = stablehlo.subtract %[[VAL_39]], %[[VAL_76]] : tensor +// CHECK: %[[VAL_80:.*]] = stablehlo.add %[[VAL_78]], %[[VAL_79]] : tensor +// CHECK: %[[VAL_81:.*]] = stablehlo.add %[[VAL_75]], %[[VAL_80]] : tensor +// CHECK: %[[VAL_82:.*]] = stablehlo.subtract %[[VAL_57]], %[[VAL_40]] : tensor +// CHECK: %[[VAL_83:.*]] = stablehlo.subtract %[[VAL_57]], %[[VAL_82]] : tensor +// CHECK: %[[VAL_84:.*]] = stablehlo.subtract %[[VAL_40]], %[[VAL_83]] : tensor +// CHECK: %[[VAL_85:.*]] = stablehlo.subtract %[[VAL_56]], %[[VAL_82]] : tensor +// CHECK: %[[VAL_86:.*]] = stablehlo.add %[[VAL_84]], %[[VAL_85]] : tensor +// CHECK: %[[VAL_87:.*]] = stablehlo.add %[[VAL_81]], %[[VAL_86]] : tensor +// CHECK: %[[VAL_88:.*]] = stablehlo.subtract %[[VAL_70]], %[[VAL_57]] : tensor +// CHECK: %[[VAL_89:.*]] = stablehlo.subtract %[[VAL_70]], %[[VAL_88]] : tensor +// CHECK: %[[VAL_90:.*]] = stablehlo.subtract %[[VAL_57]], %[[VAL_89]] : tensor +// CHECK: %[[VAL_91:.*]] = stablehlo.subtract %[[VAL_69]], %[[VAL_88]] : tensor +// CHECK: %[[VAL_92:.*]] = stablehlo.add %[[VAL_90]], %[[VAL_91]] : tensor +// CHECK: %[[VAL_93:.*]] = stablehlo.add %[[VAL_87]], %[[VAL_92]] : tensor +// CHECK: %[[VAL_94:.*]] = stablehlo.add %[[VAL_70]], %[[VAL_93]] : tensor +// CHECK: %[[VAL_95:.*]] = stablehlo.log_plus_one %[[VAL_94]] : tensor +// CHECK: %[[VAL_96:.*]] = stablehlo.multiply %[[VAL_8]], %[[VAL_95]] : tensor +// CHECK: %[[VAL_97:.*]] = stablehlo.select %[[VAL_31]], %[[VAL_36]], %[[VAL_96]] : tensor, tensor +// CHECK: %[[VAL_98:.*]] = stablehlo.select %[[VAL_18]], %[[VAL_27]], %[[VAL_97]] : tensor, tensor +// CHECK: %[[VAL_99:.*]] = stablehlo.atan2 %[[VAL_13]], %[[VAL_28]] : tensor +// CHECK: %[[VAL_100:.*]] = stablehlo.complex %[[VAL_98]], %[[VAL_99]] : tensor> +// CHECK: return %[[VAL_100]] : tensor> +// CHECK: } +func.func @log_plus_one_complex_f32(%arg : tensor>) -> tensor> { + %result = "stablehlo.log_plus_one"(%arg) : (tensor>) -> tensor> + func.return %result : tensor> +} diff --git a/stablehlo/transforms/CMakeLists.txt b/stablehlo/transforms/CMakeLists.txt index c69f7ff58c..f5193cb61e 100644 --- a/stablehlo/transforms/CMakeLists.txt +++ b/stablehlo/transforms/CMakeLists.txt @@ -28,6 +28,10 @@ set(LLVM_TARGET_DEFINITIONS StablehloCompatibilityExpanderPatterns.td) mlir_tablegen(StablehloCompatibilityExpanderPatterns.h.inc --gen-rewriters) add_public_tablegen_target(StablehloCompatibilityExpanderPatternsIncGen) +set(LLVM_TARGET_DEFINITIONS StablehloComplexMathExpanderPatterns.td) +mlir_tablegen(StablehloComplexMathExpanderPatterns.h.inc --gen-rewriters) +add_public_tablegen_target(StablehloComplexMathExpanderPatternsIncGen) + set(LLVM_TARGET_DEFINITIONS StablehloLegalizeDeprecatedOpsPatterns.td) mlir_tablegen(StablehloLegalizeDeprecatedOpsPatterns.h.inc --gen-rewriters) add_public_tablegen_target(StablehloLegalizeDeprecatedOpsPatternsIncGen) @@ -47,6 +51,7 @@ add_mlir_dialect_library(StablehloPasses StablehloCanonicalizeDynamism.cpp StablehloConvertToSignless.cpp StablehloCompatibilityExpander.cpp + StablehloComplexMathExpander.cpp StablehloLegalizeCompositeToCall.cpp StablehloLegalizeDeprecatedOps.cpp StablehloLegalizeQuantToMath.cpp @@ -64,6 +69,7 @@ add_mlir_dialect_library(StablehloPasses PassesIncGen StablehloAggressiveSimplificationPatternsIncGen StablehloCompatibilityExpanderPatternsIncGen + StablehloComplexMathExpanderPatternsIncGen StablehloLegalizeDeprecatedOpsPatternsIncGen VhloToVersionPatterns diff --git a/stablehlo/transforms/ChloDecompositionPatternsMath.td b/stablehlo/transforms/ChloDecompositionPatternsMath.td index 3d1be0d4f6..85bfb0cf26 100644 --- a/stablehlo/transforms/ChloDecompositionPatternsMath.td +++ b/stablehlo/transforms/ChloDecompositionPatternsMath.td @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ // -// This file is generated using functional_algorithms tool (0.11.1). +// This file is generated using functional_algorithms tool (0.12.0). // See build_tools/math/README.md for more information. // A kernel for evaluating asin and acos functions on complex inputs. diff --git a/stablehlo/transforms/Passes.h b/stablehlo/transforms/Passes.h index 43ee22f355..bf01394cda 100644 --- a/stablehlo/transforms/Passes.h +++ b/stablehlo/transforms/Passes.h @@ -144,6 +144,11 @@ void createStablehloRemoveDynamismPipeline(OpPassManager &pm, // operations into a primitive math operations. void createStablehloLowerQuantPipeline(OpPassManager &pm); +/// Collection of patterns to create expander for StableHLO complex +/// math operations. +void populateStablehloComplexMathExpanderPatterns(RewritePatternSet *patterns, + MLIRContext *context); + // Adds `stablehlo-deserialize` pipeline as a registered pass pipeline // for opt tools. void registerPassPipelines(); diff --git a/stablehlo/transforms/Passes.td b/stablehlo/transforms/Passes.td index 186560d012..aa9d696664 100644 --- a/stablehlo/transforms/Passes.td +++ b/stablehlo/transforms/Passes.td @@ -123,6 +123,40 @@ def StablehloCompatibilityExpanderPass : Pass<"stablehlo-compatibility-expander" ]; } +def StablehloComplexMathExpanderPass : Pass<"stablehlo-complex-math-expander", "mlir::func::FuncOp"> { + let summary = "Expander for StableHLO complex math operations."; + + let description = [{ + StableHLO complex math operations are decompositions using + StableHLO real math operations. + + This statement is based on the assumption that no hardware exists + that supports complex numbers nor complex math operations + natively. This means that the fallback mechanisms on complex math + operations that compilers may implement, are redundant. With + enabling this pass, all StableHLO complex math operations will be + expanded. + + ```mlir + func.func @sqrt_op_complex(%arg0: tensor<4xcomplex>) -> tensor<4xcomplex> { + %1 = stablehlo.sqrt %arg0 : tensor<4xcomplex> + func.return %1 : tensor<4xcomplex> + } + + ==> + + func.func @sqrt_op_complex(%arg0: tensor<4xcomplex>) -> tensor<4xcomplex> { + TBD + return %2 : tensor<4xcomplex> + } + ``` + }]; + let dependentDialects = [ + "mlir::stablehlo::StablehloDialect", + "mlir::chlo::ChloDialect", + ]; +} + def StablehloConvertToSignlessPass : Pass<"stablehlo-convert-to-signless", "ModuleOp"> { let summary = "Pass to transform the IR to be on signless integers."; } diff --git a/stablehlo/transforms/StablehloComplexMathExpander.cpp b/stablehlo/transforms/StablehloComplexMathExpander.cpp new file mode 100644 index 0000000000..b830db6196 --- /dev/null +++ b/stablehlo/transforms/StablehloComplexMathExpander.cpp @@ -0,0 +1,76 @@ +/* Copyright 2024 The StableHLO Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "stablehlo/dialect/StablehloOps.h" +#include "stablehlo/transforms/PassUtils.h" +#include "stablehlo/transforms/Passes.h" + +namespace mlir { +namespace stablehlo { +#define GEN_PASS_DEF_STABLEHLOCOMPLEXMATHEXPANDERPASS +#include "stablehlo/transforms/Passes.h.inc" + +namespace { + +static Value getConstantLikeMaxFiniteValue(OpBuilder &b, Location loc, + Value val) { + auto ty = cast(getElementTypeOrSelf(val.getType())); + return getConstantLike( + b, loc, llvm::APFloat::getLargest(ty.getFloatSemantics()), val); +} + +//===----------------------------------------------------------------------===// +// Pass +//===----------------------------------------------------------------------===// + +struct StablehloComplexMathExpanderPass + : public impl::StablehloComplexMathExpanderPassBase< + StablehloComplexMathExpanderPass> { + StablehloComplexMathExpanderPass() + : StablehloComplexMathExpanderPassBase< + StablehloComplexMathExpanderPass>() {} + + public: + LogicalResult initialize(MLIRContext *context) override { + config.useTopDownTraversal = true; + RewritePatternSet patterns_(context); + populateStablehloComplexMathExpanderPatterns(&patterns_, context); + patterns = std::move(patterns_); + return success(); + } + + void runOnOperation() override { + auto func = getOperation(); + if (failed(applyPatternsAndFoldGreedily(func, patterns, config))) { + func.emitError("Failed to converge StableHLOComplexMathExpanderPass in ") + << config.maxIterations << " iterations"; + signalPassFailure(); + } + } + + private: + FrozenRewritePatternSet patterns; + GreedyRewriteConfig config; +}; + +#include "stablehlo/transforms/StablehloComplexMathExpanderPatterns.h.inc" + +} // namespace + +void populateStablehloComplexMathExpanderPatterns(RewritePatternSet *patterns, + MLIRContext *context) { + populateWithGenerated(*patterns); +} + +} // namespace stablehlo +} // namespace mlir diff --git a/stablehlo/transforms/StablehloComplexMathExpanderPatterns.td b/stablehlo/transforms/StablehloComplexMathExpanderPatterns.td new file mode 100644 index 0000000000..42cb7f72fc --- /dev/null +++ b/stablehlo/transforms/StablehloComplexMathExpanderPatterns.td @@ -0,0 +1,275 @@ +/* Copyright 2024 The StableHLO Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// +// This file is generated using functional_algorithms tool (0.12.0). +// See build_tools/math/README.md for more information. + +include "mlir/IR/OpBase.td" +include "stablehlo/dialect/StablehloOps.td" + +class StableHLO_ComparisonDirectionValue : + ConstantAttr; + +class StableHLO_ConstantLike : NativeCodeCall< + "::mlir::stablehlo::getConstantLike($_builder, $_loc, " # value # ", $0)">; + +def ComplexElementType : Type< + CPred<"isa(cast($_self).getElementType())">, + "Complex element type">; + +def StableHLO_ConstantLikeMaxFiniteValue : NativeCodeCall< + "::mlir::stablehlo::getConstantLikeMaxFiniteValue($_builder, $_loc, $0)">; + +// Logarithm of 1 + z on complex input: +// +// log1p(x + I * y) = 0.5 * log((x + 1) ** 2 + y ** 2) + I * arctan2(y, x + 1) +// +// where +// +// x and y are real and imaginary parts of the input to log1p, and +// I is imaginary unit. +// +// For evaluating the real part of log1p accurately on the whole +// complex plane, the following cases must be handled separately: +// +// A) Avoid catastrophic cancellation errors when x is close `-0.5 * y * y` +// and `abs(y) < 1`. +// B) Avoid overflow from square when x or y are large in absolute value. +// C) Avoid cancellation errors when x is close to -1 and y is not large. +// D) Avoid cancellation errors when x is close to -2 and y is not large. +// +// Case A +// ------ +// +// The real part of log1p reads: +// +// 0.5 * log((x + 1) ** 2 + y ** 2) = 0.5 * log1p(x + x + x * x + y * y) +// +// When abs(y) < 1 and abs(x + 0.5 * y ** 2) is small, catastrophic +// cancellation errors occur when evaluating `x + x + x * x + y * y` +// using floating-point arithmetics. To avoid these errors, we'll use +// Dekker's product for computing `x * x` and `y * y` which +// effectively doubles the precision of the used floating-point +// system. In addition, the terms are summed together using 2Sum +// algorithm that minimizes cancellation errors. We'll have +// +// xxh, xxl = square_dekker(x) +// yyh, yyl = square_dekker(y) +// x + x + x * x + y * y = sum_2sum([x + x, yyh, xxh, yyl, xxl]) +// +// which is accurate when the following inequalities hold: +// +// abs(x) < sqrt(largest) * 0.1 +// abs(y) < sqrt(largest) * 0.99 +// +// [verified numerically for float32 and float64], except when x is +// close to -1 (see Case C). +// +// Case B +// ------ +// +// If abs(x) or abs(y) is larger than sqrt(largest), squareing +// these will overflow. To avoid such overflows, we'll apply +// rescaling of log1p arguments. +// +// First notice that if `abs(x) > sqrt(largest) > 4 / eps` holds then +// `x + 1 ~= x`. Also, if `abs(x) < 4 / eps` then `(x + 1) ** 2 + y +// ** 2 ~= y ** 2`. Proof: +// +// (x + 1) ** 2 + y ** 2 ~= y ** 2 iff y ** 2 > 4 * (x + 1) ** 2 / eps +// +// The lower limit to `y ** 2` is largest. The upper limit to +// `4 * (x + 1) ** 2 / eps` is `64 / eps ** 3` which is smaller than +// largest. QED. +// +// In conclusion, we can write +// +// (x + 1) ** 2 + y ** 2 ~= x ** 2 + y ** 2 +// +// whenever abs(x) or abs(y) is greater than sqrt(largest). +// +// Define +// +// mx = max(abs(x), abs(y)) +// mn = min(abs(x), abs(y)) +// +// then under the given restrictions we'll have +// +// real(log(x + I * y)) ~= 0.5 * log(x ** 2 + y ** 2) +// = 0.5 * log(mx ** 2 * (1 + (mn / mx) ** 2)) +// = log(mx) + 0.5 * log1p((mn / mx) ** 2) +// +// If mn == inf and mx == inf, we'll define `mn / mx == 1` for the +// sake of reusing the above expression for complex infinities +// (recall, `real(log(+-inf +-inf * I)) == inf`). +// +// Case C +// ------ +// +// If x is close to -1, then we'll use +// +// real(log1p(x + I * y)) = 0.5 * log((1 + x) ** 2 + y ** 2) +// +// which is accurate when the following inequalities hold: +// +// -1.5 < x < -0.5 or abs(x + 1) < 0.5 +// abs(y) < sqrt(largest) +// +// [verified numerically for float32 and float64]. For simplicity, +// we'll use the case C only when `abs(x) + abs(y) < 0.2`. +// +// Case D +// ------ +// +// If x is close to -2, the cancellation errors are avoided by using +// the Case A method [verified numerically for float32 and float64]. +// +// +def Log1pOp_ComplexElementType_ComplexMathExpander : Pat<(StableHLO_Log1pOp ComplexElementType:$z), + (StableHLO_ComplexOp + (StableHLO_SelectOp + (StableHLO_CompareOp + (StableHLO_MaxOp:$mx + (StableHLO_AbsOp:$ax + (StableHLO_RealOp:$x $z)), + (StableHLO_AbsOp:$ay + (StableHLO_ImagOp:$y $z))), + (StableHLO_MulOp + (StableHLO_SqrtOp + (StableHLO_ConstantLikeMaxFiniteValue:$largest $x)), + (StableHLO_ConstantLike<"0.01"> $x)), + StableHLO_ComparisonDirectionValue<"GT">, + (STABLEHLO_DEFAULT_COMPARISON_TYPE)), + (StableHLO_AddOp + (StableHLO_LogOp $mx), + (StableHLO_MulOp + (StableHLO_ConstantLike<"0.5">:$half $x), + (StableHLO_Log1pOp + (StableHLO_SelectOp + (StableHLO_CompareOp + (StableHLO_MinOp:$mn $ax, $ay), + $mx, + StableHLO_ComparisonDirectionValue<"EQ">, + (STABLEHLO_DEFAULT_COMPARISON_TYPE)), + (StableHLO_ConstantLike<"1">:$one $x), + (StableHLO_MulOp + (StableHLO_DivOp:$r $mn, $mx), + $r))))), + (StableHLO_SelectOp + (StableHLO_CompareOp + (StableHLO_AddOp + (StableHLO_AbsOp + (StableHLO_AddOp:$xp1 $x, $one)), + $ay), + (StableHLO_ConstantLike<"0.2"> $x), + StableHLO_ComparisonDirectionValue<"LT">, + (STABLEHLO_DEFAULT_COMPARISON_TYPE)), + (StableHLO_MulOp + $half, + (StableHLO_LogOp + (StableHLO_AddOp + (StableHLO_MulOp $xp1, $xp1), + (StableHLO_MulOp:$square_dekker_high $y, $y)))), + (StableHLO_MulOp + $half, + (StableHLO_Log1pOp + (StableHLO_AddOp:$sum_2sum_high + (StableHLO_AddOp:$add_2sum_high + (StableHLO_AddOp:$_add_2sum_high_0_ + (StableHLO_AddOp:$_add_2sum_high_1_ + (StableHLO_AddOp:$_add_2sum_high_2_ + (StableHLO_AddOp:$x2h $x, $x), + $square_dekker_high), + (StableHLO_MulOp:$_square_dekker_high_0_ $x, $x)), + (StableHLO_AddOp:$square_dekker_low + (StableHLO_AddOp + (StableHLO_AddOp + (StableHLO_AddOp + (StableHLO_NegOp $square_dekker_high), + (StableHLO_MulOp + (StableHLO_AddOp:$yh + (StableHLO_MulOp:$multiply_veltkamp_splitter_constant_y + (StableHLO_SelectOp:$veltkamp_splitter_constant + (StableHLO_CompareOp + $largest, + (StableHLO_ConstantLike<"1e+308"> $x), + StableHLO_ComparisonDirectionValue<"GT">, + (STABLEHLO_DEFAULT_COMPARISON_TYPE)), + (StableHLO_ConstantLike<"134217729"> $x), + (StableHLO_SelectOp + (StableHLO_CompareOp + $largest, + (StableHLO_ConstantLike<"1e+38"> $x), + StableHLO_ComparisonDirectionValue<"GT">, + (STABLEHLO_DEFAULT_COMPARISON_TYPE)), + (StableHLO_ConstantLike<"4097"> $x), + (StableHLO_ConstantLike<"65"> $x))), + $y), + (StableHLO_SubtractOp $y, $multiply_veltkamp_splitter_constant_y)), + $yh)), + (StableHLO_MulOp:$multiply_yh_yl + $yh, + (StableHLO_SubtractOp:$yl $y, $yh))), + $multiply_yh_yl), + (StableHLO_MulOp $yl, $yl))), + (StableHLO_AddOp:$_square_dekker_low_0_ + (StableHLO_AddOp + (StableHLO_AddOp + (StableHLO_AddOp + (StableHLO_NegOp $_square_dekker_high_0_), + (StableHLO_MulOp + (StableHLO_AddOp:$xh + (StableHLO_MulOp:$multiply_veltkamp_splitter_constant_x $veltkamp_splitter_constant, $x), + (StableHLO_SubtractOp $x, $multiply_veltkamp_splitter_constant_x)), + $xh)), + (StableHLO_MulOp:$multiply_xh_xl + $xh, + (StableHLO_SubtractOp:$xl $x, $xh))), + $multiply_xh_xl), + (StableHLO_MulOp $xl, $xl))), + (StableHLO_AddOp + (StableHLO_AddOp + (StableHLO_AddOp + (StableHLO_AddOp:$add_2sum_low + (StableHLO_SubtractOp + $x2h, + (StableHLO_SubtractOp + $_add_2sum_high_2_, + (StableHLO_SubtractOp:$subtract__add_2sum_high_2__x2h $_add_2sum_high_2_, $x2h))), + (StableHLO_SubtractOp $square_dekker_high, $subtract__add_2sum_high_2__x2h)), + (StableHLO_AddOp:$_add_2sum_low_0_ + (StableHLO_SubtractOp + $_add_2sum_high_2_, + (StableHLO_SubtractOp + $_add_2sum_high_1_, + (StableHLO_SubtractOp:$subtract__add_2sum_high_1___add_2sum_high_2_ $_add_2sum_high_1_, $_add_2sum_high_2_))), + (StableHLO_SubtractOp $_square_dekker_high_0_, $subtract__add_2sum_high_1___add_2sum_high_2_))), + (StableHLO_AddOp:$_add_2sum_low_1_ + (StableHLO_SubtractOp + $_add_2sum_high_1_, + (StableHLO_SubtractOp + $_add_2sum_high_0_, + (StableHLO_SubtractOp:$subtract__add_2sum_high_0___add_2sum_high_1_ $_add_2sum_high_0_, $_add_2sum_high_1_))), + (StableHLO_SubtractOp $square_dekker_low, $subtract__add_2sum_high_0___add_2sum_high_1_))), + (StableHLO_AddOp:$_add_2sum_low_2_ + (StableHLO_SubtractOp + $_add_2sum_high_0_, + (StableHLO_SubtractOp + $add_2sum_high, + (StableHLO_SubtractOp:$subtract_add_2sum_high__add_2sum_high_0_ $add_2sum_high, $_add_2sum_high_0_))), + (StableHLO_SubtractOp $_square_dekker_low_0_, $subtract_add_2sum_high__add_2sum_high_0_)))))))), + (StableHLO_Atan2Op $y, $xp1))>;