Skip to content

Commit

Permalink
implement derivative for real and imag
Browse files Browse the repository at this point in the history
  • Loading branch information
Pangoraw committed Oct 21, 2024
1 parent 37303d1 commit 4aef5f7
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/enzyme_ad/jax/Implementations/HLODerivatives.td
Original file line number Diff line number Diff line change
Expand Up @@ -602,6 +602,8 @@ def : HLODerivative<"FftOp", (Op $x),

def : HLOInactiveOp<"FloorOp">;

def : HLODerivative<"ImagOp", (Op $x), [(Complex (HLOConstantFP<"0">), (Neg (DiffeRet)))], (Imag (Shadow $x))>;

def : HLOInactiveOp<"IotaOp">;

def : HLOInactiveOp<"IsFiniteOp">;
Expand Down Expand Up @@ -655,6 +657,8 @@ def : HLODerivative<"AbsOp", (Op $x), [
(Select (Compare $x, (HLOConstantFP<"0"> $x), (GE)), (DiffeRet), (Neg (DiffeRet)))
]>;

def : HLODerivative<"RealOp", (Op $x), [(Complex (DiffeRet), (HLOConstantFP<"0">))], (Real (Shadow $x))>;

def : HLODerivative<"RemOp", (Op $x, $y),
[
(DiffeRet),
Expand Down
25 changes: 25 additions & 0 deletions test/lit_tests/diffrules/stablehlo/imag.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// RUN: enzymexlamlir-opt %s --enzyme-wrap="infn=main outfn= retTys=enzyme_dup argTys=enzyme_dup mode=ForwardMode" --enzyme-simplify-math | FileCheck %s --check-prefix=FORWARD
// RUN: enzymexlamlir-opt %s --enzyme-wrap="infn=main outfn= retTys=enzyme_active argTys=enzyme_active mode=ReverseModeCombined" --arith-raise --canonicalize --remove-unnecessary-enzyme-ops --verify-each=0 | FileCheck %s --check-prefix=REVERSE

func.func @main(%operand : tensor<2xcomplex<f32>>) -> tensor<2xf32> {
%result = "stablehlo.imag"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
return %result : tensor<2xf32>
}

// FORWARD: func.func @main(%arg0: tensor<2xcomplex<f32>>, %arg1: tensor<2xcomplex<f32>>) -> (tensor<2xf32>, tensor<2xf32>) {
// FORWARD-NEXT: %0 = stablehlo.imag %arg1 : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// FORWARD-NEXT: %1 = stablehlo.imag %arg0 : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// FORWARD-NEXT: return %1, %0 : tensor<2xf32>, tensor<2xf32>
// FORWARD-NEXT: }

// REVERSE: func.func @main(%arg0: tensor<2xcomplex<f32>>, %arg1: tensor<2xf32>) -> tensor<2xcomplex<f32>> {
// REVERSE-NEXT: %cst = stablehlo.constant dense<0.000000e+00> : tensor<2xf32>
// REVERSE-NEXT: %cst_0 = arith.constant dense<0.000000e+00> : tensor<2xf32>
// REVERSE-NEXT: %cst_1 = arith.constant dense<(0.000000e+00,0.000000e+00)> : tensor<2xcomplex<f32>>
// REVERSE-NEXT: %0 = stablehlo.add %cst_0, %arg1 : tensor<2xf32>
// REVERSE-NEXT: %1 = stablehlo.negate %0 : tensor<2xf32>
// REVERSE-NEXT: %2 = stablehlo.complex %cst, %1 : tensor<2xcomplex<f32>>
// REVERSE-NEXT: %3 = chlo.conj %2 : tensor<2xcomplex<f32>> -> tensor<2xcomplex<f32>>
// REVERSE-NEXT: %4 = stablehlo.add %cst_1, %3 : tensor<2xcomplex<f32>>
// REVERSE-NEXT: return %4 : tensor<2xcomplex<f32>>
// REVERSE-NEXT: }
24 changes: 24 additions & 0 deletions test/lit_tests/diffrules/stablehlo/real.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// RUN: enzymexlamlir-opt %s --enzyme-wrap="infn=main outfn= retTys=enzyme_dup argTys=enzyme_dup mode=ForwardMode" --enzyme-simplify-math | FileCheck %s --check-prefix=FORWARD
// RUN: enzymexlamlir-opt %s --enzyme-wrap="infn=main outfn= retTys=enzyme_active argTys=enzyme_active mode=ReverseModeCombined" --arith-raise --canonicalize --remove-unnecessary-enzyme-ops --verify-each=0 | FileCheck %s --check-prefix=REVERSE

func.func @main(%operand : tensor<2xcomplex<f32>>) -> tensor<2xf32> {
%result = "stablehlo.real"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
return %result : tensor<2xf32>
}

// FORWARD: func.func @main(%arg0: tensor<2xcomplex<f32>>, %arg1: tensor<2xcomplex<f32>>) -> (tensor<2xf32>, tensor<2xf32>) {
// FORWARD-NEXT: %0 = stablehlo.real %arg1 : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// FORWARD-NEXT: %1 = stablehlo.real %arg0 : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// FORWARD-NEXT: return %1, %0 : tensor<2xf32>, tensor<2xf32>
// FORWARD-NEXT: }

// REVERSE: func.func @main(%arg0: tensor<2xcomplex<f32>>, %arg1: tensor<2xf32>) -> tensor<2xcomplex<f32>> {
// REVERSE-NEXT: %cst = stablehlo.constant dense<0.000000e+00> : tensor<2xf32>
// REVERSE-NEXT: %cst_0 = arith.constant dense<0.000000e+00> : tensor<2xf32>
// REVERSE-NEXT: %cst_1 = arith.constant dense<(0.000000e+00,0.000000e+00)> : tensor<2xcomplex<f32>>
// REVERSE-NEXT: %0 = stablehlo.add %cst_0, %arg1 : tensor<2xf32>
// REVERSE-NEXT: %1 = stablehlo.complex %0, %cst : tensor<2xcomplex<f32>>
// REVERSE-NEXT: %2 = chlo.conj %1 : tensor<2xcomplex<f32>> -> tensor<2xcomplex<f32>>
// REVERSE-NEXT: %3 = stablehlo.add %cst_1, %2 : tensor<2xcomplex<f32>>
// REVERSE-NEXT: return %3 : tensor<2xcomplex<f32>>
// REVERSE-NEXT: }

0 comments on commit 4aef5f7

Please sign in to comment.