fix transform tests (#103)
makslevental authored Oct 23, 2024
1 parent b2b8d72 commit d47c553
Showing 1 changed file with 19 additions and 23 deletions.
42 changes: 19 additions & 23 deletions tests/test_transform.py
@@ -653,11 +653,10 @@ def tile_inner(target):
         %extracted_slice_1 = tensor.extract_slice %arg6[0, %arg3, %1, %2] [1, 1, 8, 8] [1, 1, 1, 1] : tensor<1x3x64x64xf32> to tensor<1x1x8x8xf32>
         %3 = scf.forall (%arg7, %arg8, %arg9) in (1, 8, 8) shared_outs(%arg10 = %extracted_slice_1) -> (tensor<1x1x8x8xf32>) {
           %extracted_slice_2 = tensor.extract_slice %extracted_slice[0, 0, %arg8, %arg9] [1, 1, 3, 3] [1, 1, 1, 1] : tensor<1x1x10x10xf32> to tensor<1x1x3x3xf32>
-          %extracted_slice_3 = tensor.extract_slice %extracted_slice_0[%arg7, 0, 0, 0] [1, 1, 3, 3] [1, 1, 1, 1] : tensor<1x1x3x3xf32> to tensor<1x1x3x3xf32>
-          %extracted_slice_4 = tensor.extract_slice %arg10[0, %arg7, %arg8, %arg9] [1, 1, 1, 1] [1, 1, 1, 1] : tensor<1x1x8x8xf32> to tensor<1x1x1x1xf32>
-          %4 = linalg.conv_2d_nchw_fchw ins(%extracted_slice_2, %extracted_slice_3 : tensor<1x1x3x3xf32>, tensor<1x1x3x3xf32>) outs(%extracted_slice_4 : tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32>
+          %extracted_slice_3 = tensor.extract_slice %arg10[0, 0, %arg8, %arg9] [1, 1, 1, 1] [1, 1, 1, 1] : tensor<1x1x8x8xf32> to tensor<1x1x1x1xf32>
+          %4 = linalg.conv_2d_nchw_fchw ins(%extracted_slice_2, %extracted_slice_0 : tensor<1x1x3x3xf32>, tensor<1x1x3x3xf32>) outs(%extracted_slice_3 : tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32>
           scf.forall.in_parallel {
-            tensor.parallel_insert_slice %4 into %arg10[0, %arg7, %arg8, %arg9] [1, 1, 1, 1] [1, 1, 1, 1] : tensor<1x1x1x1xf32> into tensor<1x1x8x8xf32>
+            tensor.parallel_insert_slice %4 into %arg10[0, 0, %arg8, %arg9] [1, 1, 1, 1] [1, 1, 1, 1] : tensor<1x1x1x1xf32> into tensor<1x1x8x8xf32>
           }
         } {mapping = [#gpu.thread<x>, #gpu.thread<y>, #gpu.thread<z>]}
         scf.forall.in_parallel {
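Aside: the nested scf.forall with a #gpu.thread mapping in the expected IR above is the shape produced by transform-dialect tiling of a convolution. Below is a minimal sketch of such a transform script, assuming the upstream transform.structured.tile_using_forall syntax; the handle names and tile sizes are illustrative, not the test's actual script.

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
    // Match the payload conv, then tile it to an scf.forall mapped onto GPU
    // threads; unit tiles on the three mapped dims of the 1x1x8x8 output
    // give loop bounds (1, 8, 8), as in the IR above.
    %conv = transform.structured.match ops{["linalg.conv_2d_nchw_fchw"]} in %root
        : (!transform.any_op) -> !transform.any_op
    // Two result handles: the tiled conv and the enclosing scf.forall (see
    // the upstream op documentation for their order).
    %res:2 = transform.structured.tile_using_forall %conv tile_sizes [0, 1, 1, 1]
        (mapping = [#gpu.thread<x>, #gpu.thread<y>, #gpu.thread<z>])
        : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.yield
  }
}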
@@ -983,37 +982,34 @@ def main(variant_op: any_op_t()):
 )
 correct = dedent(
     """\
-#map = affine_map<(d0) -> (d0 * 16)>
-#map1 = affine_map<(d0) -> (d0 * 64)>
-#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
-#map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d5, d4)>
-#map4 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
+#map = affine_map<(d0) -> (d0 * 64)>
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
+#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d5, d4)>
+#map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
 module {
   module attributes {transform.target_tag = "payload"} {
     func.func @matmul_i8_i8(%arg0: tensor<16x256xi8>, %arg1: tensor<256x256xi8>) -> tensor<16x256xi8> {
       %c0_i32 = arith.constant 0 : i32
       %0 = tensor.empty() : tensor<16x256xi8>
       %1 = tensor.empty() : tensor<1x4x16x64xi8>
+      %pack = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [16, 64] into %1 : tensor<16x256xi8> -> tensor<1x4x16x64xi8>
       %2 = tensor.empty() : tensor<4x1x64x64xi8>
       %3 = tensor.empty() : tensor<1x1x16x64xi8>
       %4 = linalg.fill ins(%c0_i32 : i32) outs(%3 : tensor<1x1x16x64xi8>) -> tensor<1x1x16x64xi8>
       %5 = scf.forall (%arg2, %arg3) in (1, 4) shared_outs(%arg4 = %0) -> (tensor<16x256xi8>) {
-        %6 = affine.apply #map(%arg2)
-        %7 = affine.apply #map1(%arg3)
-        %extracted_slice = tensor.extract_slice %arg0[%6, 0] [16, 256] [1, 1] : tensor<16x256xi8> to tensor<16x256xi8>
-        %extracted_slice_0 = tensor.extract_slice %arg1[0, %7] [256, 64] [1, 1] : tensor<256x256xi8> to tensor<256x64xi8>
-        %extracted_slice_1 = tensor.extract_slice %arg4[%6, %7] [16, 64] [1, 1] : tensor<16x256xi8> to tensor<16x64xi8>
-        %pack = tensor.pack %extracted_slice inner_dims_pos = [0, 1] inner_tiles = [16, 64] into %1 : tensor<16x256xi8> -> tensor<1x4x16x64xi8>
-        %pack_2 = tensor.pack %extracted_slice_0 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %2 : tensor<256x64xi8> -> tensor<4x1x64x64xi8>
-        %8 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack, %pack_2 : tensor<1x4x16x64xi8>, tensor<4x1x64x64xi8>) outs(%4 : tensor<1x1x16x64xi8>) {
-        ^bb0(%in: i8, %in_3: i8, %out: i8):
-          %9 = arith.muli %in, %in_3 : i8
-          %10 = arith.addi %out, %9 : i8
-          linalg.yield %10 : i8
+        %6 = affine.apply #map(%arg3)
+        %extracted_slice = tensor.extract_slice %arg1[0, %6] [256, 64] [1, 1] : tensor<256x256xi8> to tensor<256x64xi8>
+        %extracted_slice_0 = tensor.extract_slice %arg4[0, %6] [16, 64] [1, 1] : tensor<16x256xi8> to tensor<16x64xi8>
+        %pack_1 = tensor.pack %extracted_slice outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %2 : tensor<256x64xi8> -> tensor<4x1x64x64xi8>
+        %7 = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack, %pack_1 : tensor<1x4x16x64xi8>, tensor<4x1x64x64xi8>) outs(%4 : tensor<1x1x16x64xi8>) {
+        ^bb0(%in: i8, %in_2: i8, %out: i8):
+          %8 = arith.muli %in, %in_2 : i8
+          %9 = arith.addi %out, %8 : i8
+          linalg.yield %9 : i8
         } -> tensor<1x1x16x64xi8>
-        %unpack = tensor.unpack %8 inner_dims_pos = [0, 1] inner_tiles = [16, 64] into %extracted_slice_1 : tensor<1x1x16x64xi8> -> tensor<16x64xi8>
+        %unpack = tensor.unpack %7 inner_dims_pos = [0, 1] inner_tiles = [16, 64] into %extracted_slice_0 : tensor<1x1x16x64xi8> -> tensor<16x64xi8>
         scf.forall.in_parallel {
-          tensor.parallel_insert_slice %unpack into %arg4[%6, %7] [16, 64] [1, 1] : tensor<16x64xi8> into tensor<16x256xi8>
+          tensor.parallel_insert_slice %unpack into %arg4[0, %6] [16, 64] [1, 1] : tensor<16x64xi8> into tensor<16x256xi8>
         }
       } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
       return %5 : tensor<16x256xi8>
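Aside: the expected IR above hinges on tensor.pack/tensor.unpack shape arithmetic. Packing a 16x256 operand with inner_tiles = [16, 64] produces ceil(16/16) x ceil(256/64) = 1 x 4 outer tiles with the 16x64 tile shape appended, i.e. tensor<1x4x16x64xi8>, and tensor.unpack inverts it. A minimal round-trip sketch follows; the function @pack_roundtrip and its arguments are illustrative, not part of the test.

func.func @pack_roundtrip(%src: tensor<16x256xi8>, %dst: tensor<16x256xi8>) -> tensor<16x256xi8> {
  // Outer dims become 1 x 4 (one 16-row tile, four 64-column tiles); the
  // 16x64 inner tile dims are appended at the end.
  %empty = tensor.empty() : tensor<1x4x16x64xi8>
  %packed = tensor.pack %src inner_dims_pos = [0, 1] inner_tiles = [16, 64]
      into %empty : tensor<16x256xi8> -> tensor<1x4x16x64xi8>
  // unpack writes the tiles back into a 16x256 destination, restoring the
  // original row-major layout.
  %restored = tensor.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [16, 64]
      into %dst : tensor<1x4x16x64xi8> -> tensor<16x256xi8>
  return %restored : tensor<16x256xi8>
}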