From 4932b9348d294cf3d99633c65c9f82c826ad28fc Mon Sep 17 00:00:00 2001 From: Paul Berg Date: Mon, 28 Oct 2024 12:06:18 +0100 Subject: [PATCH] clamp forward rule --- .../jax/Implementations/HLODerivatives.td | 9 +++++- test/lit_tests/diffrules/stablehlo/clamp.mlir | 31 +++++++++---------- 2 files changed, 23 insertions(+), 17 deletions(-) diff --git a/src/enzyme_ad/jax/Implementations/HLODerivatives.td b/src/enzyme_ad/jax/Implementations/HLODerivatives.td index e8f3d173..b722be91 100644 --- a/src/enzyme_ad/jax/Implementations/HLODerivatives.td +++ b/src/enzyme_ad/jax/Implementations/HLODerivatives.td @@ -558,7 +558,14 @@ def : HLODerivative<"ClampOp", (Op $min, $operand, $max), [ (DiffeRet), (HLOConstantFP<"0"> $max) ), -]>; +], (Select + (Compare (Max $operand, $min), $max, (GT)), + (SelectIfActive $max, (Shadow $max), (HLOConstantFP<"0"> $max)), + (Select + (Compare $operand, $min, (LT)), + (SelectIfActive $min, (Shadow $min), (HLOConstantFP<"0"> $min)), + (SelectIfActive $operand, (Shadow $operand), (HLOConstantFP<"0"> $operand))) +)>; def : HLODerivative<"CbrtOp", (Op $x), [ (CheckedMul (DiffeRet), (Div (Pow $x, (Div (HLOConstantFP<"-2">), (HLOConstantFP<"3">))), (HLOConstantFP<"3">))), diff --git a/test/lit_tests/diffrules/stablehlo/clamp.mlir b/test/lit_tests/diffrules/stablehlo/clamp.mlir index 4b5dbf71..002037f2 100644 --- a/test/lit_tests/diffrules/stablehlo/clamp.mlir +++ b/test/lit_tests/diffrules/stablehlo/clamp.mlir @@ -25,27 +25,26 @@ module { check.expect_eq_const %res#0, dense<[1.5, 1.5, 1.0, 1.0, 2.0, 1.5, 1.5, 1.5, 1.5, 1.5]> : tensor<10xf32> check.expect_eq_const %res#1, dense<[1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0]> : tensor<10xf32> + %res_fwd:2 = enzyme.fwddiff @clamp(%min, %operand, %dclamp, %max) { + activity=[#enzyme, #enzyme, #enzyme], + ret_activity=[#enzyme] + } : (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>) -> (tensor<10xf32>, tensor<10xf32>) + + check.expect_eq_const %res_fwd#0, dense<[1.5, 1.5, 1.0, 1.0, 2.0, 1.5, 1.5, 1.5, 1.5, 1.5]> : tensor<10xf32> + check.expect_eq_const %res_fwd#1, dense<[1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0]> : tensor<10xf32> + func.return } } // FORWARD: func.func @clamp(%arg0: tensor<10xf32>, %arg1: tensor<10xf32>, %arg2: tensor<10xf32>, %arg3: tensor<10xf32>, %arg4: tensor<10xf32>, %arg5: tensor<10xf32>) -> (tensor<10xf32>, tensor<10xf32>) { -// FORWARD-NEXT: %cst = stablehlo.constant dense<0.000000e+00> : tensor<10xf32> -// FORWARD-NEXT: %0 = stablehlo.compare LT, %arg0, %arg4 : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1> -// FORWARD-NEXT: %1 = stablehlo.compare GT, %arg0, %arg2 : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1> -// FORWARD-NEXT: %2 = stablehlo.and %0, %1 : tensor<10xi1> -// FORWARD-NEXT: %3 = stablehlo.select %2, %arg1, %cst : tensor<10xi1>, tensor<10xf32> -// FORWARD-NEXT: %4 = stablehlo.compare LT, %arg2, %arg0 : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1> -// FORWARD-NEXT: %5 = stablehlo.compare GT, %arg2, %arg4 : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1> -// FORWARD-NEXT: %6 = stablehlo.or %4, %5 : tensor<10xi1> -// FORWARD-NEXT: %7 = stablehlo.select %6, %cst, %arg3 : tensor<10xi1>, tensor<10xf32> -// FORWARD-NEXT: %8 = stablehlo.add %3, %7 : tensor<10xf32> -// FORWARD-NEXT: %9 = stablehlo.maximum %arg2, %arg0 : tensor<10xf32> -// FORWARD-NEXT: %10 = stablehlo.compare GT, %9, %arg4 : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1> -// FORWARD-NEXT: %11 = stablehlo.select %10, %arg5, %cst : tensor<10xi1>, tensor<10xf32> -// FORWARD-NEXT: %12 = stablehlo.add %8, %11 : tensor<10xf32> -// FORWARD-NEXT: %13 = stablehlo.clamp %arg0, %arg2, %arg4 : tensor<10xf32> -// FORWARD-NEXT: return %13, %12 : tensor<10xf32>, tensor<10xf32> +// FORWARD-NEXT: %0 = stablehlo.maximum %arg2, %arg0 : tensor<10xf32> +// FORWARD-NEXT: %1 = stablehlo.compare GT, %0, %arg4 : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1> +// FORWARD-NEXT: %2 = stablehlo.compare LT, %arg2, %arg0 : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1> +// FORWARD-NEXT: %3 = stablehlo.select %2, %arg1, %arg3 : tensor<10xi1>, tensor<10xf32> +// FORWARD-NEXT: %4 = stablehlo.select %1, %arg5, %3 : tensor<10xi1>, tensor<10xf32> +// FORWARD-NEXT: %5 = stablehlo.clamp %arg0, %arg2, %arg4 : tensor<10xf32> +// FORWARD-NEXT: return %5, %4 : tensor<10xf32>, tensor<10xf32> // FORWARD-NEXT: } // REVERSE: func.func @clamp(%arg0: tensor<10xf32>, %arg1: tensor<10xf32>, %arg2: tensor<10xf32>, %arg3: tensor<10xf32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>) {