Skip to content

Commit

Permalink
clamp forward rule
Browse files (browse the repository at this point in the history)
  • Loading branch information
Pangoraw committed Oct 28, 2024
1 parent b67965e commit 4932b93
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 17 deletions.
9 changes: 8 additions & 1 deletion src/enzyme_ad/jax/Implementations/HLODerivatives.td
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,14 @@ def : HLODerivative<"ClampOp", (Op $min, $operand, $max), [
(DiffeRet),
(HLOConstantFP<"0"> $max)
),
- ]>;
+ ], (Select
+ (Compare (Max $operand, $min), $max, (GT)),
+ (SelectIfActive $max, (Shadow $max), (HLOConstantFP<"0"> $max)),
+ (Select
+ (Compare $operand, $min, (LT)),
+ (SelectIfActive $min, (Shadow $min), (HLOConstantFP<"0"> $min)),
+ (SelectIfActive $operand, (Shadow $operand), (HLOConstantFP<"0"> $operand)))
+ )>;

def : HLODerivative<"CbrtOp", (Op $x), [
(CheckedMul (DiffeRet), (Div (Pow $x, (Div (HLOConstantFP<"-2">), (HLOConstantFP<"3">))), (HLOConstantFP<"3">))),
Expand Down
31 changes: 15 additions & 16 deletions test/lit_tests/diffrules/stablehlo/clamp.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -25,27 +25,26 @@ module {
check.expect_eq_const %res#0, dense<[1.5, 1.5, 1.0, 1.0, 2.0, 1.5, 1.5, 1.5, 1.5, 1.5]> : tensor<10xf32>
check.expect_eq_const %res#1, dense<[1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0]> : tensor<10xf32>

+ %res_fwd:2 = enzyme.fwddiff @clamp(%min, %operand, %dclamp, %max) {
+ activity=[#enzyme<activity enzyme_const>, #enzyme<activity enzyme_dup>, #enzyme<activity enzyme_const>],
+ ret_activity=[#enzyme<activity enzyme_dup>]
+ } : (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>) -> (tensor<10xf32>, tensor<10xf32>)
+
+ check.expect_eq_const %res_fwd#0, dense<[1.5, 1.5, 1.0, 1.0, 2.0, 1.5, 1.5, 1.5, 1.5, 1.5]> : tensor<10xf32>
+ check.expect_eq_const %res_fwd#1, dense<[1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0]> : tensor<10xf32>

func.return
}
}

// FORWARD: func.func @clamp(%arg0: tensor<10xf32>, %arg1: tensor<10xf32>, %arg2: tensor<10xf32>, %arg3: tensor<10xf32>, %arg4: tensor<10xf32>, %arg5: tensor<10xf32>) -> (tensor<10xf32>, tensor<10xf32>) {
- // FORWARD-NEXT: %cst = stablehlo.constant dense<0.000000e+00> : tensor<10xf32>
- // FORWARD-NEXT: %0 = stablehlo.compare LT, %arg0, %arg4 : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
- // FORWARD-NEXT: %1 = stablehlo.compare GT, %arg0, %arg2 : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
- // FORWARD-NEXT: %2 = stablehlo.and %0, %1 : tensor<10xi1>
- // FORWARD-NEXT: %3 = stablehlo.select %2, %arg1, %cst : tensor<10xi1>, tensor<10xf32>
- // FORWARD-NEXT: %4 = stablehlo.compare LT, %arg2, %arg0 : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
- // FORWARD-NEXT: %5 = stablehlo.compare GT, %arg2, %arg4 : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
- // FORWARD-NEXT: %6 = stablehlo.or %4, %5 : tensor<10xi1>
- // FORWARD-NEXT: %7 = stablehlo.select %6, %cst, %arg3 : tensor<10xi1>, tensor<10xf32>
- // FORWARD-NEXT: %8 = stablehlo.add %3, %7 : tensor<10xf32>
- // FORWARD-NEXT: %9 = stablehlo.maximum %arg2, %arg0 : tensor<10xf32>
- // FORWARD-NEXT: %10 = stablehlo.compare GT, %9, %arg4 : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
- // FORWARD-NEXT: %11 = stablehlo.select %10, %arg5, %cst : tensor<10xi1>, tensor<10xf32>
- // FORWARD-NEXT: %12 = stablehlo.add %8, %11 : tensor<10xf32>
- // FORWARD-NEXT: %13 = stablehlo.clamp %arg0, %arg2, %arg4 : tensor<10xf32>
- // FORWARD-NEXT: return %13, %12 : tensor<10xf32>, tensor<10xf32>
+ // FORWARD-NEXT: %0 = stablehlo.maximum %arg2, %arg0 : tensor<10xf32>
+ // FORWARD-NEXT: %1 = stablehlo.compare GT, %0, %arg4 : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
+ // FORWARD-NEXT: %2 = stablehlo.compare LT, %arg2, %arg0 : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
+ // FORWARD-NEXT: %3 = stablehlo.select %2, %arg1, %arg3 : tensor<10xi1>, tensor<10xf32>
+ // FORWARD-NEXT: %4 = stablehlo.select %1, %arg5, %3 : tensor<10xi1>, tensor<10xf32>
+ // FORWARD-NEXT: %5 = stablehlo.clamp %arg0, %arg2, %arg4 : tensor<10xf32>
+ // FORWARD-NEXT: return %5, %4 : tensor<10xf32>, tensor<10xf32>
// FORWARD-NEXT: }

// REVERSE: func.func @clamp(%arg0: tensor<10xf32>, %arg1: tensor<10xf32>, %arg2: tensor<10xf32>, %arg3: tensor<10xf32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>) {
Expand Down

0 comments on commit 4932b93

Please sign in to comment.