stablehlo.compare derivative isn't implemented #57

Closed
avik-pal opened this issue Jul 27, 2024 · 5 comments

@avik-pal
Collaborator

error: Unimplemented derivative for argument 0 in reverse mode for op %4 = "stablehlo.select"(%3, %2, %1) : (tensor<10x10xi1>, tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>

error: could not compute the adjoint for this operation %3 = "stablehlo.compare"(%2, %1) <{comparison_direction = #stablehlo<comparison_direction GT>}> : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xi1>

Originally posted by @avik-pal in #55 (comment)

The relu activation test is marked as broken for now; once this is fixed, it should pass.
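
For reference, a minimal reproduction sketch (my reconstruction from the REPL session below; the actual test added in #55 may differ, and the helper names here are assumptions):

using Reactant, Enzyme, NNlib

# sumabs2(f, x) == sum(abs2, f.(x)); mirrors the call shown in the @code_hlo dumps below
sumabs2(f, x) = sum(abs2, f.(x))

x_act_ca = Reactant.ConcreteRArray(randn(Float32, 10, 10))

# Reverse-mode differentiation of the traced relu produces the
# stablehlo.compare / stablehlo.select ops whose adjoints are missing.
∇relu_sumabs2(x) = Enzyme.gradient(Reverse, Base.Fix1(sumabs2, relu), x)
grad_fn = Reactant.@compile ∇relu_sumabs2(x_act_ca)
grad_fn(x_act_ca)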

@wsmoses
Member

wsmoses commented Jul 27, 2024 via email

@wsmoses
Member

wsmoses commented Jul 27, 2024 via email

@avik-pal
Collaborator Author

julia> Reactant.@code_hlo sumabs2(relu, x_act_ca)
Module:
module attributes {transform.with_named_sequence} {
  func.func @main(%arg0: tensor<10x10xf32>) -> tensor<f32> {
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<10x10xf32>
    %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<10x10xf32>) -> tensor<10x10xf32>
    %1 = stablehlo.compare  GT, %0, %cst_0 : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xi1>
    %2 = stablehlo.select %1, %0, %cst_0 : tensor<10x10xi1>, tensor<10x10xf32>
    %3 = stablehlo.multiply %2, %2 : tensor<10x10xf32>
    %4 = stablehlo.reduce(%3 init: %cst) applies stablehlo.add across dimensions = [0, 1] : (tensor<10x10xf32>, tensor<f32>) -> tensor<f32>
    return %4 : tensor<f32>
  }
}

julia> Reactant.@code_hlo optimize=false sumabs2(relu, x_act_ca)
Module:
module {
  func.func private @abs2_broadcast_scalar(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
    %0 = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32>
    %1 = stablehlo.multiply %0, %0 : tensor<f32>
    %2 = stablehlo.transpose %0, dims = [] : (tensor<f32>) -> tensor<f32>
    %3 = stablehlo.transpose %1, dims = [] : (tensor<f32>) -> tensor<f32>
    return %2, %3 : tensor<f32>, tensor<f32>
  }
  func.func @main(%arg0: tensor<10x10xf32>) -> (tensor<10x10xf32>, tensor<f32>) {
    %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<10x10xf32>) -> tensor<10x10xf32>
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<10x10xf32>
    %1 = stablehlo.compare  GT, %0, %cst : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xi1>
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<10x10xf32>
    %2 = stablehlo.select %1, %0, %cst_0 : tensor<10x10xi1>, tensor<10x10xf32>
    %cst_1 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %3:2 = enzyme.batch @abs2_broadcast_scalar(%2) {batch_shape = array<i64: 10, 10>} : (tensor<10x10xf32>) -> (tensor<10x10xf32>, tensor<10x10xf32>)
    %4 = stablehlo.reduce(%3#1 init: %cst_1) applies stablehlo.add across dimensions = [0, 1] : (tensor<10x10xf32>, tensor<f32>) -> tensor<f32>
    %5 = stablehlo.transpose %0, dims = [1, 0] : (tensor<10x10xf32>) -> tensor<10x10xf32>
    %6 = stablehlo.transpose %4, dims = [] : (tensor<f32>) -> tensor<f32>
    return %5, %6 : tensor<10x10xf32>, tensor<f32>
  }
}

@wsmoses
Member

wsmoses commented Jul 27, 2024

Fun fact: the optimize=false module is printed before AD is run, so you could even use optimize=false on the function with the autodiff call, which will make it easier to reproduce in the future.
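
Concretely, something like the following should dump the pre-AD module together with the autodiff call (a sketch; ∇relu_sumabs2 and x_act_ca are the hypothetical names from the repro sketch above):

Reactant.@code_hlo optimize=false ∇relu_sumabs2(x_act_ca)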

@Pangoraw
Collaborator

This should be fixed with EnzymeAD/Enzyme-JAX#106.
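
For reference, the reverse rules involved are simple: stablehlo.compare returns booleans, so it contributes no cotangent to its float inputs, and stablehlo.select just routes the incoming cotangent to whichever operand was selected. A Julia-level sketch of the math (illustrative only, not the actual Enzyme-JAX implementation):

# Element-wise pullback of select(pred, on_true, on_false) with respect to its
# two data operands; the predicate itself gets no cotangent.
function select_pullback(pred, dy)
    d_on_true  = ifelse.(pred, dy, zero.(dy))   # cotangent flows where pred is true
    d_on_false = ifelse.(pred, zero.(dy), dy)   # cotangent flows where pred is false
    return d_on_true, d_on_false
end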
