Use stride instead of order to determine block attr #2349

Merged: 5 commits, Sep 26, 2024


alexbaden
Contributor

Per the Triton Slack, order is unused on architectures below Hopper. More importantly, order duplicates information that stride already provides. In fact, order can disagree completely with stride (i.e. be wrong) and we still generate correct code. I think it is better to use the stride, assuming the logic I added here makes sense.

Note this depends on #2348. I'd like to land the debug logging separately so that we have it even if we decide to modify this approach; it was very useful in debugging this problem.

cc #2347
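The stride-based rule described above can be sketched as follows. This is a minimal Python illustration, not the code in this PR: `classify_block_io` is a hypothetical helper, and using `None` for a stride that is only known at runtime (such as `%pitch`) is an assumption made for the example. The real change may apply further checks (for instance on rank or pitch alignment) that are omitted here.

```python
from typing import Optional, Sequence


def classify_block_io(strides: Sequence[Optional[int]]) -> Optional[str]:
    """Pick a block_io value for a 2D tensor pointer from its strides alone.

    Each stride is an int when it is a compile-time constant and None when
    it is only known at runtime (an assumption of this sketch).
    """
    if len(strides) != 2:
        return None  # this sketch only handles 2D block pointers
    row_stride, col_stride = strides
    if col_stride == 1:
        return "row_major"     # last dimension is contiguous
    if row_stride == 1:
        return "column_major"  # first dimension is contiguous
    return None                # no unit stride: leave the load unannotated


print(classify_block_io([None, 1]))  # strides [%pitch, 1]
print(classify_block_io([1, None]))  # strides [1, %pitch]
```

Note that the order attribute never appears in this function, which is the point of the change: the layout is derived entirely from which dimension has unit stride.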

alexbaden force-pushed the alex/materialize_block_pointer_stride branch from db2a0e9 to c09e10e (September 26, 2024 00:52)
Contributor

@whitneywhtsang whitneywhtsang left a comment


can we add a lit test?

alexbaden and others added 2 commits September 25, 2024 20:56
Co-authored-by: Whitney Tsang <whitney.tsang@intel.com>
@alexbaden
Contributor Author

> can we add a lit test?

The existing tests actually cover this scenario, because they change both the order and the stride.

    // CHECK: tt.load {{.*}} {boundaryCheck = array<i32: 1>, padding = 1 : i32, triton_intel_gpu.block_io = "row_major"}
    // CHECK: tt.load {{.*}} {boundaryCheck = array<i32: 0>, padding = 1 : i32, triton_intel_gpu.block_io = "row_major"}
    %3 = tt.make_tensor_ptr %arg0, [%c0_i64, %c0_i64], [%pitch, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x32xf16, #dot_a>>
    %4 = tt.make_tensor_ptr %arg0, [%c0_i64, %c0_i64], [%pitch, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x64xf16, #dot_b>>
    %5 = tt.load %3 {boundaryCheck = array<i32: 1>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, padding = 1 : i32} : !tt.ptr<tensor<64x32xf16, #dot_a>>
    %6 = tt.load %4 {boundaryCheck = array<i32: 0>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, padding = 1 : i32} : !tt.ptr<tensor<32x64xf16, #dot_b>>

    // CHECK: tt.load {{.*}} {boundaryCheck = array<i32: 1>, padding = 1 : i32, triton_intel_gpu.block_io = "column_major"}
    // CHECK: tt.load {{.*}} {boundaryCheck = array<i32: 0>, padding = 1 : i32, triton_intel_gpu.block_io = "column_major"}
    %7 = tt.make_tensor_ptr %arg0, [%c0_i64, %c0_i64], [%c1_i64, %pitch], [%c0_i32, %c0_i32] {order = array<i32: 0, 1>} : <tensor<64x32xf16, #dot_a>>
    %8 = tt.make_tensor_ptr %arg0, [%c0_i64, %c0_i64], [%c1_i64, %pitch], [%c0_i32, %c0_i32] {order = array<i32: 0, 1>} : <tensor<32x64xf16, #dot_b>>
    %9 = tt.load %7 {boundaryCheck = array<i32: 1>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, padding = 1 : i32} : !tt.ptr<tensor<64x32xf16, #dot_a>>
    %10 = tt.load %8 {boundaryCheck = array<i32: 0>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, padding = 1 : i32} : !tt.ptr<tensor<32x64xf16, #dot_b>>

But I added one that covers this scenario plus rewrite tensor pointer here: #2347
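To make the difference between the two approaches concrete, the sketch below compares an order-based rule with a stride-based one. Both helpers are hypothetical. The order-based rule assumes that `order[0]` names the fastest-varying dimension, which is consistent with the test cases quoted above. The last case, where order contradicts the strides, is an invented input and is not taken from the test added in #2347.

```python
from typing import Optional, Sequence


def by_order(order: Sequence[int]) -> str:
    # Assumption: order[0] is the fastest-varying dimension.
    return "row_major" if order[0] == len(order) - 1 else "column_major"


def by_stride(strides: Sequence[Optional[int]]) -> Optional[str]:
    # None stands for a runtime stride such as %pitch (sketch convention).
    if strides[-1] == 1:
        return "row_major"
    if strides[0] == 1:
        return "column_major"
    return None


cases = [
    # (strides, order)
    ([None, 1], [1, 0]),  # first pair of loads in the quoted test
    ([1, None], [0, 1]),  # second pair of loads in the quoted test
    ([None, 1], [0, 1]),  # hypothetical: order contradicts the strides
]

for strides, order in cases:
    print(strides, order, by_order(order), by_stride(strides))
```

The two rules agree on the first two cases and differ only on the third, which is the kind of input where the choice of rule becomes observable.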

@alexbaden alexbaden merged commit 979301f into main Sep 26, 2024
4 checks passed
@alexbaden alexbaden deleted the alex/materialize_block_pointer_stride branch September 26, 2024 01:49