
fix ir.pyt test
jax.arg_info was missing from the test's expected IR strings
ftynse committed Dec 6, 2023
1 parent: f67b028 · commit: 475e08a
Showing 1 changed file with 4 additions and 4 deletions.
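
The attribute in question is produced by JAX itself: when a jitted function is lowered, JAX releases of this era attach a jax.arg_info attribute to each entry-block argument, naming the Python parameter it came from. A minimal plain-JAX sketch (the function add and its shapes are illustrative, not taken from this repository) showing where the attribute appears:

import jax
import jax.numpy as jnp

@jax.jit
def add(a, b):
    return a + b

ones = jnp.ones((2, 3), jnp.float32)

# On a late-2023 JAX, each %argN in the printed module carries attributes
# such as {jax.arg_info = "a", mhlo.sharding = "{replicated}"}; newer
# releases may name or place this metadata differently.
print(add.lower(ones, ones).compiler_ir(dialect="mhlo"))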
test/lit_tests/ir.pyt (8 changes: 4 additions & 4 deletions)
@@ -34,7 +34,7 @@ def fwdmode(a, b, c, d):
 print(fwdmode.lower(ones, twos, ones, twos).compiler_ir(dialect="mhlo"))

 # CHECK: module @jit_fwdmode attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
-# CHECK-NEXT: func.func public @main(%arg0: tensor<2x3xf32> {mhlo.sharding = "{replicated}"}, %arg1: tensor<5x7xf32> {mhlo.sharding = "{replicated}"}, %arg2: tensor<2x3xf32> {mhlo.sharding = "{replicated}"}, %arg3: tensor<5x7xf32> {mhlo.sharding = "{replicated}"}) -> (tensor<6x9xf32> {jax.result_info = "[0][0]"}, tensor<4x6xf32> {jax.result_info = "[0][1]"}, tensor<6x9xf32> {jax.result_info = "[1][0]"}, tensor<4x6xf32> {jax.result_info = "[1][1]"}) {
+# CHECK-NEXT: func.func public @main(%arg0: tensor<2x3xf32> {jax.arg_info = "a", mhlo.sharding = "{replicated}"}, %arg1: tensor<5x7xf32> {jax.arg_info = "b", mhlo.sharding = "{replicated}"}, %arg2: tensor<2x3xf32> {jax.arg_info = "c", mhlo.sharding = "{replicated}"}, %arg3: tensor<5x7xf32> {jax.arg_info = "d", mhlo.sharding = "{replicated}"}) -> (tensor<6x9xf32> {jax.result_info = "[0][0]"}, tensor<4x6xf32> {jax.result_info = "[0][1]"}, tensor<6x9xf32> {jax.result_info = "[1][0]"}, tensor<4x6xf32> {jax.result_info = "[1][1]"}) {
 # CHECK-NEXT: %0 = mhlo.constant dense<1> : tensor<1xi64>
 # CHECK-NEXT: %1:4 = mhlo.custom_call @jaxzyme.fwd(%0, %arg0, %arg2, %arg1, %arg3) {backend_config = ""} : (tensor<1xi64>, tensor<2x3xf32>, tensor<2x3xf32>, tensor<5x7xf32>, tensor<5x7xf32>) -> (tensor<6x9xf32>, tensor<6x9xf32>, tensor<4x6xf32>, tensor<4x6xf32>)
 # CHECK-NEXT: return %1#0, %1#2, %1#1, %1#3 : tensor<6x9xf32>, tensor<4x6xf32>, tensor<6x9xf32>, tensor<4x6xf32>
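
For orientation, the interleaved results checked above (%1#0, %1#2, %1#1, %1#3) follow the usual forward-mode convention of returning each primal output next to its tangent. A plain-JAX sketch of the same calling pattern, with a stand-in body instead of Enzyme's jaxzyme.fwd custom call:

import jax
import jax.numpy as jnp

def fwd(a, b):
    # Stand-in body; the real test differentiates Enzyme-compiled code.
    return a * 2.0, b + 1.0

ones = jnp.ones((2, 3), jnp.float32)
twos = 2 * jnp.ones((5, 7), jnp.float32)

# jax.jvp evaluates primal and tangent outputs together, which is what
# the four interleaved custom-call results above encode.
primals_out, tangents_out = jax.jvp(fwd, (ones, twos), (ones, twos))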
@@ -48,7 +48,7 @@ def f(a, b):
 print(f.lower(ones, twos).compiler_ir(dialect="mhlo"))

 # CHECK: module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
-# CHECK-NEXT: func.func public @main(%arg0: tensor<2x3xf32> {mhlo.sharding = "{replicated}"}, %arg1: tensor<5x7xf32> {mhlo.sharding = "{replicated}"}) -> (tensor<6x9xf32> {jax.result_info = "[0][0]"}, tensor<4x6xf32> {jax.result_info = "[0][1]"}, tensor<16xi8> {jax.result_info = "[1][<flat index 0>][0][<flat index 0>][0][0]"}, tensor<2x3xf32> {jax.result_info = "[1][<flat index 0>][0][<flat index 0>][0][1]"}, tensor<5x7xf32> {jax.result_info = "[1][<flat index 0>][0][<flat index 0>][0][2]"}) {
+# CHECK-NEXT: func.func public @main(%arg0: tensor<2x3xf32> {jax.arg_info = "a", mhlo.sharding = "{replicated}"}, %arg1: tensor<5x7xf32> {jax.arg_info = "b", mhlo.sharding = "{replicated}"}) -> (tensor<6x9xf32> {jax.result_info = "[0][0]"}, tensor<4x6xf32> {jax.result_info = "[0][1]"}, tensor<16xi8> {jax.result_info = "[1][<flat index 0>][0][<flat index 0>][0][0]"}, tensor<2x3xf32> {jax.result_info = "[1][<flat index 0>][0][<flat index 0>][0][1]"}, tensor<5x7xf32> {jax.result_info = "[1][<flat index 0>][0][<flat index 0>][0][2]"}) {
 # CHECK-NEXT: %0 = mhlo.constant dense<2> : tensor<1xi64>
 # CHECK-NEXT: %1:3 = mhlo.custom_call @jaxzyme.aug(%0, %arg0, %arg1) {backend_config = ""} : (tensor<1xi64>, tensor<2x3xf32>, tensor<5x7xf32>) -> (tensor<6x9xf32>, tensor<4x6xf32>, tensor<16xi8>)
 # CHECK-NEXT: return %1#0, %1#1, %1#2, %arg0, %arg1 : tensor<6x9xf32>, tensor<4x6xf32>, tensor<16xi8>, tensor<2x3xf32>, tensor<5x7xf32>
@@ -64,7 +64,7 @@ def g(a, b, x, y):
     return primals, f_vjp((x, y))

 # CHECK: module @jit_g attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
-# CHECK-NEXT: func.func public @main(%arg0: tensor<2x3xf32> {mhlo.sharding = "{replicated}"}, %arg1: tensor<5x7xf32> {mhlo.sharding = "{replicated}"}, %arg2: tensor<6x9xf32> {mhlo.sharding = "{replicated}"}, %arg3: tensor<4x6xf32> {mhlo.sharding = "{replicated}"}) -> (tensor<6x9xf32> {jax.result_info = "[0][0]"}, tensor<4x6xf32> {jax.result_info = "[0][1]"}, tensor<2x3xf32> {jax.result_info = "[1][0]"}, tensor<5x7xf32> {jax.result_info = "[1][1]"}) {
+# CHECK-NEXT: func.func public @main(%arg0: tensor<2x3xf32> {jax.arg_info = "a", mhlo.sharding = "{replicated}"}, %arg1: tensor<5x7xf32> {jax.arg_info = "b", mhlo.sharding = "{replicated}"}, %arg2: tensor<6x9xf32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}, %arg3: tensor<4x6xf32> {jax.arg_info = "y", mhlo.sharding = "{replicated}"}) -> (tensor<6x9xf32> {jax.result_info = "[0][0]"}, tensor<4x6xf32> {jax.result_info = "[0][1]"}, tensor<2x3xf32> {jax.result_info = "[1][0]"}, tensor<5x7xf32> {jax.result_info = "[1][1]"}) {
 # CHECK-NEXT: %0 = mhlo.constant dense<3> : tensor<1xi64>
 # CHECK-NEXT: %1:3 = mhlo.custom_call @jaxzyme.aug(%0, %arg0, %arg1) {backend_config = ""} : (tensor<1xi64>, tensor<2x3xf32>, tensor<5x7xf32>) -> (tensor<6x9xf32>, tensor<4x6xf32>, tensor<16xi8>)
 # CHECK-NEXT: %2 = mhlo.constant dense<4> : tensor<1xi64>
@@ -89,4 +89,4 @@ print(jax.jit(f_vjp).lower((x, y)).compiler_ir(dialect="mhlo"))
 # CHECK-NEXT: %1:2 = mhlo.custom_call @jaxzyme.rev(%0, %arg0, %arg1, %arg2) {backend_config = ""} : (tensor<1xi64>, tensor<16xi8>, tensor<6x9xf32>, tensor<4x6xf32>) -> (tensor<2x3xf32>, tensor<5x7xf32>)
 # CHECK-NEXT: return %1#0, %1#1 : tensor<2x3xf32>, tensor<5x7xf32>
 # CHECK-NEXT: }
-# CHECK-NEXT: }
+# CHECK-NEXT: }
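
The last hunk closes the reverse-mode test, in which the pullback returned by jax.vjp is itself jitted and lowered. A plain-JAX sketch of that pattern (the body of f is a placeholder; the real test checks for the jaxzyme.rev custom call):

import jax
import jax.numpy as jnp

def f(a, b):
    # Placeholder body; the real test differentiates Enzyme-compiled code.
    return a @ b

a = jnp.ones((2, 3), jnp.float32)
b = jnp.ones((3, 5), jnp.float32)

# jax.vjp returns the primal output plus a pullback mapping output
# cotangents to input cotangents; the pullback is an ordinary callable,
# so it can be jitted and lowered like any other function.
primals, f_vjp = jax.vjp(f, a, b)
ct = jnp.ones_like(primals)
print(jax.jit(f_vjp).lower(ct).compiler_ir(dialect="mhlo"))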
