
Accept Hopper matmuls and update default heuristic #3579

Open · wants to merge 33 commits into base: main
Conversation

@jacobhinkle (Collaborator) commented Dec 12, 2024

This updates the default (non-plugin) matmul heuristic to support Hopper matmuls. This change means that we can now run matmuls on Hopper similarly to how we do it on Ampere and Turing, including using the Python interface.

I tried to make the default heuristic somewhat thoughtful and not just a placeholder. Here are some notes about the Hopper heuristic in its current form:

  • I set the macro to Hopper_64_64_16. I intended to always use the largest macro for which the N size divided the problem's N, but this led to lower perf on the handful of examples I looked at. We should benchmark more and find out why this is once we have warp specialization and register stealing fully plumbed in, but for the time being I simply left it at N=64.
  • Once the instruction tile is set we set the warp tile equal to the instruction tile (we can revisit this in the future). Then to find the CTA tile we double the instruction tile in the M or N dimension until we run out of registers.
  • We start with 8 circular buffering stages and decrease until the circular buffers fit into smem.
  • We use use_smem_epilogue when possible. Whenever that is possible, we always use promote_prologue_smem_reuse even if it's not needed. This is to try to avoid bugs like Misaligned read from smem doing TMA store #3602.
  • I set the tile rasterization order so that the fast axis is the axis with the fewest tiles, which should encourage more L2 hits unless there are tons of tiles in each dimension.
  • I cannot yet set grid swizzling due to Inlining error in Hopper matmul with AxisMapping and grid swizzling #3671, but I placed a TODO comment and some code to do the proper swizzling.
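The tile-growth and stage-count steps above can be sketched roughly as follows. This is a hypothetical illustration, not the actual nvFuser implementation: the names `growCtaTile` and `pickStages`, the greedy M-before-N doubling order, and the register/byte accounting (one fp32 accumulator register per tile element, fp16 operands) are all assumptions.

```cpp
#include <cstdint>

// Illustrative tile descriptor; the real heuristic uses nvFuser's own types.
struct Tile {
  int64_t m, n, k;
};

// Starting from the warp tile, double the CTA tile in M or N until the
// accumulator registers no longer fit in the per-SM budget.
Tile growCtaTile(Tile warp_tile, int64_t max_registers_per_sm) {
  Tile cta = warp_tile;
  bool grew = true;
  while (grew) {
    grew = false;
    // Assume one fp32 accumulator register per CTA-tile output element.
    if ((cta.m * 2) * cta.n <= max_registers_per_sm) {
      cta.m *= 2;
      grew = true;
    } else if (cta.m * (cta.n * 2) <= max_registers_per_sm) {
      cta.n *= 2;
      grew = true;
    }
  }
  return cta;
}

// Start at 8 circular-buffering stages and decrease until the operand
// buffers fit in shared memory. Assumes fp16 (2-byte) A and B operands.
int64_t pickStages(Tile cta, int64_t smem_bytes) {
  int64_t stages = 8;
  auto bytesPerStage = [&]() { return 2 * (cta.m + cta.n) * cta.k; };
  while (stages > 1 && stages * bytesPerStage() > smem_bytes) {
    stages--;
  }
  return stages;
}
```

For example, with a 64x64x16 warp tile and a budget of 32768 registers (half of the 65536-register Hopper register file, per the comment later in the diff), this sketch grows M before N; the real heuristic may interleave the two dimensions differently.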

jacobhinkle added a commit that referenced this pull request Dec 16, 2024
This enables Hopper matmul in our automatic scheduler by translating
them without introducing new broadcasts. Specifically:

1. Update `mma_utils::MatmulPattern::translateToMmaOp` to optionally
avoid intermediates by using an `MmaOp::AxisMapping`. Enable this option
when the target arch is not Ampere or Turing.
2. Unguard some tests in `test_translate_mma.cpp`

This does not update the default heuristic or change the `canSchedule`
checks. See #3579 for that follow-up PR

---------

Co-authored-by: Ryan Spring <rspring@nvidia.com>
Co-authored-by: Naoya Maruyama <naoyam@users.noreply.github.com>
Co-authored-by: Jingyue Wu <wujingyue@gmail.com>
Co-authored-by: nsarka <nsarkauskas@nvidia.com>
Co-authored-by: Protonu <pbasu@nvidia.com>
Co-authored-by: samnordmann <snordmann@nvidia.com>
Comment on lines -1868 to -1865
axis_mapping.a_axes.push_back(d);
}
axis_mapping.a_axes.reserve(out_dim);
for (size_t d : c10::irange(out_dim - 2)) {
jacobhinkle (Collaborator, Author):
I think this was just due to a busted merge.

macro_encode.n = 256;
while (macro_encode.n >= 8) {
if (n_extent % macro_encode.n != 0) {
macro_encode.n /= 2;
jacobhinkle (Collaborator, Author):

Currently this only chooses powers of two. For small problems I think we could choose one of the other sizes. For example if n_extent == 72 then we should probably use that size.
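A hedged sketch of the alternative this comment suggests: instead of halving from 256 (which visits only powers of two), scan the candidate N sizes downward in steps of 8 and take the largest one that divides the problem's N, so that for example n_extent == 72 selects 72. The helper name `pickMacroN`, the 256 upper bound, and the multiple-of-8 granularity are assumptions for illustration, not nvFuser code.

```cpp
#include <cstdint>

// Pick the largest candidate macro-N (multiple of 8, at most 256) that
// evenly divides the problem's N extent; fall back to the smallest size.
int64_t pickMacroN(int64_t n_extent) {
  for (int64_t n = 256; n >= 8; n -= 8) {
    if (n_extent % n == 0) {
      return n;
    }
  }
  return 8; // no multiple-of-8 divisor; fall back to the smallest macro N
}
```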


const auto tryIncreaseM = [&]() {
if (ratiosValid(m_ratio + 1, n_ratio)) {
m_ratio++;
jacobhinkle (Collaborator, Author):

Should these also be powers of two? Currently this will choose sizes like 192.

@jacobhinkle jacobhinkle changed the title [WIP] Accept Hopper matmuls and update default heuristic Accept Hopper matmuls and update default heuristic Dec 20, 2024
@jacobhinkle (Collaborator, Author):

!test

@jacobhinkle jacobhinkle marked this pull request as ready for review January 7, 2025 02:01
@jacobhinkle (Collaborator, Author):

!test

@jacobhinkle (Collaborator, Author):

!test --matmul-bench

@jacobhinkle jacobhinkle requested a review from rdspring1 January 7, 2025 02:03
@@ -193,7 +193,7 @@ class MatmulParams : public HeuristicParams {

//! This is the CGA size on Hopper+ devices. This parameter is ignored on
//! Ampere and Turing.
- std::tuple<int64_t, int64_t, int64_t> cluster_dims = {2, 1, 1};
+ std::tuple<int64_t, int64_t, int64_t> cluster_dims = {1, 1, 1};
jacobhinkle (Collaborator, Author):

If the grid size is not divisible by the cluster size then we get a launch error, so we should default to not use cluster dims unless explicitly handled by a heuristic.
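The constraint described here can be stated as a simple divisibility check: a cluster launch fails unless every grid dimension is a multiple of the corresponding cluster dimension, so a heuristic should only enable a non-trivial cluster when this holds. `clusterDimsValid` below is a hypothetical helper for illustration, not the actual nvFuser check.

```cpp
#include <cstdint>
#include <tuple>

// True if each grid dimension is divisible by the matching cluster
// dimension, i.e. the grid can be partitioned into whole clusters.
bool clusterDimsValid(
    std::tuple<int64_t, int64_t, int64_t> grid,
    std::tuple<int64_t, int64_t, int64_t> cluster) {
  return std::get<0>(grid) % std::get<0>(cluster) == 0 &&
      std::get<1>(grid) % std::get<1>(cluster) == 0 &&
      std::get<2>(grid) % std::get<2>(cluster) == 0;
}
```

With the previous default of {2, 1, 1}, any odd grid X dimension would fail this check, which motivates the {1, 1, 1} default.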

@rdspring1 (Collaborator) left a comment:
LGTM. I just left some thoughts.

// The Hopper register file is 256KiB. We reduce this by a factor of 1/2 to
// account for overhead, since not all of the registers will hold MMA
// outputs.
const size_t max_registers_per_sm = device_prop->regsPerMultiprocessor / 2L;
rdspring1 (Collaborator):

I wonder if getRegPerThreadGivenThreadsPerSM is more accurate than registers_per_sm / 2.

Suggested change
const size_t max_registers_per_sm = device_prop->regsPerMultiprocessor / 2L;
// tma warp group + 2 * compute warp groups
constexpr int64_t threads_per_sm = 384;
const size_t max_registers_per_sm = getRegPerThreadGivenThreadsPerSM(threads_per_sm) * threads_per_sm;
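For reference, a rough sketch of what such a computation could look like. This is a hypothetical stand-in for the helper named in the suggestion, under assumed Hopper parameters: a 65536-register file per SM, a per-thread allocation granularity of 8 registers, and a 255-register architectural cap per thread; the real `getRegPerThreadGivenThreadsPerSM` may differ.

```cpp
#include <algorithm>
#include <cstdint>

// Registers usable per thread when threads_per_sm threads are resident:
// divide the register file evenly, round down to the allocation
// granularity, and clamp to the per-thread architectural maximum.
int64_t regPerThreadGivenThreadsPerSM(int64_t threads_per_sm) {
  constexpr int64_t regs_per_sm = 65536; // assumed 256 KiB register file
  constexpr int64_t granularity = 8;     // assumed allocation step per thread
  int64_t regs = regs_per_sm / threads_per_sm;
  regs = (regs / granularity) * granularity; // round down to a multiple of 8
  return std::min<int64_t>(regs, 255);       // per-thread cap
}
```

Under these assumptions, 384 threads (one TMA warp group plus two compute warp groups) gives 168 registers per thread, i.e. 64512 registers total, which is slightly tighter than the regsPerMultiprocessor / 2 = 32768 estimate is loose.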

// outputs.
const size_t max_registers_per_sm = device_prop->regsPerMultiprocessor / 2L;

const size_t regs_per_warp_group = warp_tile.m * warp_tile.n * num_problems;
rdspring1 (Collaborator):

Suggested change
const size_t regs_per_warp_group = warp_tile.m * warp_tile.n * num_problems;
// total accumulator registers for warp group
const size_t accum_regs_per_warp_group = warp_tile.m * warp_tile.n * num_problems;


const size_t regs_per_warp_group = warp_tile.m * warp_tile.n * num_problems;

const auto ratiosValid = [&](const DimType m_ratio, const DimType n_ratio) {
rdspring1 (Collaborator):

nitpick: snake_case for lambda functions.

Suggested change
const auto ratiosValid = [&](const DimType m_ratio, const DimType n_ratio) {
const auto ratios_valid = [&](const DimType m_ratio, const DimType n_ratio) {

DimType cta_n = warp_tile.n * n_ratio;
increased = false;

const auto tryIncreaseM = [&]() {
rdspring1 (Collaborator):

Suggested change
const auto tryIncreaseM = [&]() {
const auto try_increaseM = [&]() {

}
return increased;
};
const auto tryIncreaseN = [&]() {
rdspring1 (Collaborator):

Suggested change
const auto tryIncreaseN = [&]() {
const auto try_increaseN = [&]() {


const size_t regs_per_warp_group = warp_tile.m * warp_tile.n * num_problems;

const auto ratiosValid = [&](const DimType m_ratio, const DimType n_ratio) {
rdspring1 (Collaborator):

alternate name and description.

Suggested change
const auto ratiosValid = [&](const DimType m_ratio, const DimType n_ratio) {
// The cta tile is a multiple of the warp tile. This lambda checks that cta tile given by warp_tile and multiple fits on the SM.
const auto validate_cta_tile_multiple = [&](const DimType m_ratio, const DimType n_ratio) {

2 participants