Skip to content

Commit

Permalink
[BUG Fix]: restrict max 2d block load heigth according to spec (#1152)
Browse files Browse the repository at this point in the history
![image](https://github.com/intel/intel-xpu-backend-for-triton/assets/68101902/8598d5e4-d70e-491a-a042-91f386b5b3b3)
For 2d load, max height is 32(also for VNNI load). For those loads that
have height greater than 32, they will return all zeros. This will lead
to incorrect results.
  • Loading branch information
quintinwang5 authored May 20, 2024
1 parent 7141854 commit 1bd2c8e
Showing 1 changed file with 2 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,8 @@ MatchTargetSizePass::getSubOpSize(RankedTensorType type) const {
(isa<ttgi::WarpEncodingAttr, ttg::DotOperandEncodingAttr>(layout)) ? 32
: 0;
subSize[1] = (shape[1] > colLimit) ? colLimit : shape[1];
int64_t max = maxLoadStoreSize * 4 / sizeInBytes / subSize[1];
// FIXME: From gfxspec, max 2d block load height is 32
int64_t max = 32;
subSize[0] = std::min(max, shape[0]);
} break;
default:
Expand Down

0 comments on commit 1bd2c8e

Please sign in to comment.