Accept Hopper matmuls and update default heuristic #3579
Conversation
This enables Hopper matmuls in our automatic scheduler by translating them without introducing new broadcasts. Specifically:

1. Update `mma_utils::MatmulPattern::translateToMmaOp` to optionally avoid intermediates by using an `MmaOp::AxisMapping`.
2. Enable this option when the target arch is not Ampere or Turing.
3. Unguard some tests in `test_translate_mma.cpp`.

This does not update the default heuristic or change the `canSchedule` checks. See #3579 for that follow-up PR.

---------

Co-authored-by: Ryan Spring <rspring@nvidia.com>
Co-authored-by: Naoya Maruyama <naoyam@users.noreply.github.com>
Co-authored-by: Jingyue Wu <wujingyue@gmail.com>
Co-authored-by: nsarka <nsarkauskas@nvidia.com>
Co-authored-by: Protonu <pbasu@nvidia.com>
Co-authored-by: samnordmann <snordmann@nvidia.com>
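As a rough illustration of the translation switch described above (the helper names and the exact `translateToMmaOp` signature are assumptions, not the actual implementation):

```cpp
// Sketch only: gate the AxisMapping-based translation on the target arch.
// isAmpere/isTuring stand in for whatever arch query the scheduler uses.
const bool avoid_intermediates = !isAmpere(macro) && !isTuring(macro);
// When avoid_intermediates is true, translateToMmaOp builds an
// MmaOp::AxisMapping instead of inserting broadcast intermediates.
MmaOp* mma = pattern.translateToMmaOp(avoid_intermediates);
```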
Must have been a broken merge
I'm still skipping the ones with batch dimensions on A since these currently hit an error. I'll investigate later, but we only need 2D A for now.
axis_mapping.a_axes.push_back(d);
}
axis_mapping.a_axes.reserve(out_dim);
for (size_t d : c10::irange(out_dim - 2)) {
I think this was just due to a busted merge.
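For reference, the presumed intended ordering after fixing the merge is to reserve capacity before the loop that fills `a_axes`:

```cpp
// reserve() should come before the loop that pushes the batch axes.
axis_mapping.a_axes.reserve(out_dim);
for (size_t d : c10::irange(out_dim - 2)) {
  axis_mapping.a_axes.push_back(d);
}
```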
macro_encode.n = 256;
while (macro_encode.n >= 8) {
if (n_extent % macro_encode.n != 0) {
macro_encode.n /= 2;
Currently this only chooses powers of two. For small problems I think we could choose one of the other sizes. For example, if `n_extent == 72` then we should probably use that size.
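A sketch of what considering non-power-of-two sizes could look like, assuming the supported macro N sizes on Hopper are the multiples of 8 up to 256:

```cpp
// Sketch: walk down in steps of 8 and take the largest instruction N that
// divides n_extent (picks 72 for n_extent == 72 instead of 8).
macro_encode.n = 256;
while (macro_encode.n > 8 && n_extent % macro_encode.n != 0) {
  macro_encode.n -= 8;
}
```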
csrc/scheduler/matmul_utils.cpp
const auto tryIncreaseM = [&]() {
if (ratiosValid(m_ratio + 1, n_ratio)) {
m_ratio++;
Should these also be powers of two? Currently this will choose sizes like 192.
Should fix this for both matmul and linear, and both with and without `avoid_intermediates_`.
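If we do want powers of two, one option is to grow the ratio by doubling instead of incrementing it; a sketch against the surrounding lambdas:

```cpp
// Sketch: doubling keeps cta_m a power-of-two multiple of warp_tile.m,
// so sizes like 192 (3 * 64) are never produced.
const auto tryIncreaseM = [&]() {
  if (ratiosValid(m_ratio * 2, n_ratio)) {
    m_ratio *= 2;
    increased = true;
  }
  return increased;
};
```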
The dtype for stmatrix should never have been constrained to only Half. The only constraint we have is that the dtype size is 16 bits. This PR is needed for us to actually use stmatrix in bfloat16 matmuls.
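A minimal sketch of the relaxed constraint (the helper and macro names here are approximations, not necessarily the exact nvFuser API):

```cpp
// Sketch: accept any 16-bit element type (Half, BFloat16) for stmatrix
// rather than hard-coding Half.
NVF_ERROR(
    dataTypeSize(dtype) == 2,
    "stmatrix requires a 16-bit element type, e.g. Half or BFloat16");
```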
!test
We should re-enable this when we plumb through cooperative launch and when we guard against invalid configs in the heuristic
!test
!test --matmul-bench
@@ -193,7 +193,7 @@ class MatmulParams : public HeuristicParams {
   //! This is the CGA size on Hopper+ devices. This parameter is ignored on
   //! Ampere and Turing.
-  std::tuple<int64_t, int64_t, int64_t> cluster_dims = {2, 1, 1};
+  std::tuple<int64_t, int64_t, int64_t> cluster_dims = {1, 1, 1};
If the grid size is not divisible by the cluster size then we get a launch error, so we should default to not using cluster dims unless they are explicitly handled by a heuristic.
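A sketch of the kind of guard a heuristic could add before opting into a non-trivial cluster (variable names here are hypothetical; only `cluster_dims` comes from `MatmulParams`):

```cpp
// Sketch: only opt in to a 2-wide cluster when the grid dimension is
// divisible by it, since an indivisible grid/cluster pair fails at launch.
int64_t cluster_x = 2;
if (grid_x % cluster_x != 0) {
  cluster_x = 1;  // fall back to the safe {1, 1, 1} default
}
params->cluster_dims = {cluster_x, 1, 1};
```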
LGTM. I just left some thoughts.
// The Hopper register file is 256KiB. We reduce this by a factor of 1/2 to
// account for overhead, since not all of the registers will hold MMA
// outputs.
const size_t max_registers_per_sm = device_prop->regsPerMultiprocessor / 2L;
I wonder if `getRegPerThreadGivenThreadsPerSM` is more accurate than `registers_per_sm / 2`.
Suggested change:
-  const size_t max_registers_per_sm = device_prop->regsPerMultiprocessor / 2L;
+  // tma warp group + 2 * compute warp groups
+  constexpr int64_t threads_per_sm = 384;
+  const size_t max_registers_per_sm =
+      getRegPerThreadGivenThreadsPerSM(threads_per_sm) * threads_per_sm;
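(For reference, the 384 here corresponds to three warp groups of 128 threads each: the TMA warp group plus the two compute warp groups named in the suggested comment.)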
// outputs.
const size_t max_registers_per_sm = device_prop->regsPerMultiprocessor / 2L;

const size_t regs_per_warp_group = warp_tile.m * warp_tile.n * num_problems;
Suggested change:
-  const size_t regs_per_warp_group = warp_tile.m * warp_tile.n * num_problems;
+  // total accumulator registers for warp group
+  const size_t accum_regs_per_warp_group =
+      warp_tile.m * warp_tile.n * num_problems;
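(For a sense of scale: with a 64x256 warp tile and `num_problems == 1`, this is 16384 accumulator values, i.e. roughly 128 registers per thread across a 128-thread warp group, assuming one 32-bit value per register.)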
const size_t regs_per_warp_group = warp_tile.m * warp_tile.n * num_problems;

const auto ratiosValid = [&](const DimType m_ratio, const DimType n_ratio) {
nitpick: snake_case for lambda functions.
Suggested change:
-  const auto ratiosValid = [&](const DimType m_ratio, const DimType n_ratio) {
+  const auto ratios_valid = [&](const DimType m_ratio, const DimType n_ratio) {
DimType cta_n = warp_tile.n * n_ratio;
increased = false;

const auto tryIncreaseM = [&]() {
Suggested change:
-  const auto tryIncreaseM = [&]() {
+  const auto try_increaseM = [&]() {
}
return increased;
};
const auto tryIncreaseN = [&]() {
Suggested change:
-  const auto tryIncreaseN = [&]() {
+  const auto try_increaseN = [&]() {
const size_t regs_per_warp_group = warp_tile.m * warp_tile.n * num_problems;

const auto ratiosValid = [&](const DimType m_ratio, const DimType n_ratio) {
Alternate name and description:
Suggested change:
-  const auto ratiosValid = [&](const DimType m_ratio, const DimType n_ratio) {
+  // The CTA tile is a multiple of the warp tile. This lambda checks that the
+  // CTA tile given by warp_tile and that multiple fits on the SM.
+  const auto validate_cta_tile_multiple = [&](const DimType m_ratio, const DimType n_ratio) {
This updates the default (non-plugin) matmul heuristic to support Hopper matmuls. This change means that we can now run matmuls on Hopper similarly to how we do on Ampere and Turing, including through the Python interface.
I tried to make the default heuristic somewhat thoughtful and not just a placeholder. Here are some notes about the Hopper heuristic in its current form:
We use `use_smem_epilogue` when possible. Whenever that is possible, we always use `promote_prologue_smem_reuse` even if it's not needed. This is to try and avoid bugs like "Misaligned read from smem doing TMA store" (#3602).
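As a rough sketch of that coupling (the fit condition is a placeholder; the field names are the `MatmulParams` members mentioned above):

```cpp
// Sketch: whenever the smem epilogue is usable, also promote prologue smem
// reuse, even when it is not strictly required, to avoid issues like #3602.
bool smem_epilogue_fits = /* placeholder for the real shared-memory fit check */ true;
params->use_smem_epilogue = smem_epilogue_fits;
params->promote_prologue_smem_reuse = params->use_smem_epilogue;
```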