Accept Hopper matmuls and update default heuristic #3579

Merged: 41 commits, merged on Jan 8, 2025
Changes from 8 commits

Commits (41)
cb13e25
Accept Hopper matmuls and update default heuristic
jacobhinkle Dec 12, 2024
cd2d1e1
Merge remote-tracking branch 'origin/main' into hopper_matmul_heuristic
jacobhinkle Dec 18, 2024
1692b5d
Merge remote-tracking branch 'origin/main' into hopper_matmul_heuristic
jacobhinkle Dec 19, 2024
c8097b9
Fix bug in translating hopper LinearOps
jacobhinkle Dec 19, 2024
3751734
Factor default heuristic by arch
jacobhinkle Dec 19, 2024
076d56a
Unguard 2dA_2dB LinearOp translation tests on Hopper
jacobhinkle Dec 19, 2024
700df1f
Check prologues in compile-time check
jacobhinkle Dec 19, 2024
89f4887
Fix up getMmaOp
jacobhinkle Dec 19, 2024
15200fe
Revert innocuous change to ampere path
jacobhinkle Dec 19, 2024
bba9c88
Fix condition in prologue check
jacobhinkle Dec 19, 2024
e5def4c
Merge remote-tracking branch 'origin/main' into hopper_matmul_heuristic
jacobhinkle Dec 19, 2024
9a691c6
Fix up merge
jacobhinkle Dec 19, 2024
80c1232
Add incomplete fix for repeated operands in patterns
jacobhinkle Dec 19, 2024
fbde1e2
Enable BFloat16 in stmatrix
jacobhinkle Dec 23, 2024
e311250
Remove mistakenly-pasted line
jacobhinkle Dec 23, 2024
7305778
Merge remote-tracking branch 'origin/stmatrix_bfloat' into hopper_mat…
jacobhinkle Dec 23, 2024
3f7b6a6
Fix compile error
jacobhinkle Dec 23, 2024
d8f80e2
Merge remote-tracking branch 'origin/stmatrix_bfloat' into hopper_mat…
jacobhinkle Dec 23, 2024
6c17823
Merge remote-tracking branch 'origin/main' into hopper_matmul_heuristic
jacobhinkle Dec 23, 2024
486e4d9
clang-tidy
jacobhinkle Dec 30, 2024
5c6f504
Merge remote-tracking branch 'origin/main' into hopper_matmul_heuristic
jacobhinkle Dec 30, 2024
c9f1805
clang-tidy
jacobhinkle Dec 30, 2024
0f1ad25
Add test that we skip fusing matmuls with 64-bit indexing
jacobhinkle Jan 2, 2025
7ce8938
Default to not setting cluster_dims
jacobhinkle Jan 2, 2025
da8f27f
Keep instruction size >=64 for default heuristic
jacobhinkle Jan 2, 2025
7c3cd60
Don't increase CTA beyond 256
jacobhinkle Jan 2, 2025
f11516d
Default to 6 stages. This will be limited based on smem usage later
jacobhinkle Jan 2, 2025
0252ad4
Skip hopper matmuls that need int64 indexing
jacobhinkle Jan 2, 2025
9e0118d
Merge remote-tracking branch 'origin/main' into hopper_matmul_heuristic
jacobhinkle Jan 7, 2025
8f4b796
Limit macro to N=64
jacobhinkle Jan 7, 2025
10489a9
Increase by multiples of 2, respect device num registers
jacobhinkle Jan 7, 2025
c77f6e4
Set circular buffer stages to max, respecting smem limit
jacobhinkle Jan 7, 2025
645bc74
Always promote smem reuse on Hopper
jacobhinkle Jan 7, 2025
a5b79ac
Relax constraint of exactly one operand in heuristic plugin
jacobhinkle Jan 7, 2025
3c2b0fb
Apply suggestions from code review
jacobhinkle Jan 7, 2025
69ac0e5
Fix compile, accomodate extra warp group for warp specialization
jacobhinkle Jan 7, 2025
c71d6b7
Add comment/clean up
jacobhinkle Jan 7, 2025
0290b39
Add tests related to upcoming benchmarking effort
jacobhinkle Jan 7, 2025
3c4170e
Remove comment
jacobhinkle Jan 7, 2025
f386536
clang-tidy
jacobhinkle Jan 8, 2025
c56d845
Merge remote-tracking branch 'origin/main' into hopper_matmul_heuristic
jacobhinkle Jan 8, 2025
188 changes: 168 additions & 20 deletions csrc/scheduler/matmul_utils.cpp
@@ -49,25 +49,43 @@ using ProblemShape = std::array<int64_t, 4>;
inline std::optional<MmaMacro> getMmaOp(
const int dev_version,
const ProblemShape& problem) {
using MacroType = MmaMacro;
const int64_t n_extent = problem[(size_t)MatmulDimRole::N];

// NOTE: A temp condition
const ProblemShape::value_type n_extend = problem[(size_t)MatmulDimRole::N];
const bool use_small_n = ((n_extend % 8) == 0) && ((n_extend % 16) != 0);
MmaMacroEncode macro_encode{MmaMacroEncode::Arch::NoMma, 16, 8, 16};

switch (dev_version) {
case 75:
return (use_small_n) ? MacroType::Turing_16_8_16
: MacroType::Turing_16_16_16;
macro_encode.arch = MmaMacroEncode::Arch::Turing;
if ((n_extent % 16) == 0) {
macro_encode.n = 16;
}
break;
case 80:
case 86:
case 89:
case 90: // NOTE: temp use ampere matmul for hopper
return (use_small_n) ? MacroType::Ampere_16_8_16
: MacroType::Ampere_16_16_16;
macro_encode.arch = MmaMacroEncode::Arch::Ampere;
if ((n_extent % 16) == 0) {
macro_encode.n = 16;
}
break;
case 90:
macro_encode.arch = MmaMacroEncode::Arch::Hopper;
macro_encode.m = 64;
// Find the largest instruction tile that divides the problem size and is
// a power of two
macro_encode.n = 256;
while (macro_encode.n >= 8) {
if (n_extent % macro_encode.n != 0) {
macro_encode.n /= 2;
Review comment (jacobhinkle): Currently this only chooses powers of two. For small problems I think we could choose one of the other sizes. For example, if n_extent == 72 then we should probably use that size.

} else {
break;
}
}
break;
default:
return std::nullopt;
}
return macro_encode;
}

//! Find the number of circular buffer stages for shared memory operands, so
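A minimal standalone sketch of the Hopper instruction-N selection above: the largest power of two in [8, 256] that divides the problem's N extent. Names are illustrative and this mirrors the loop rather than quoting it; for n_extent == 72 it returns 8, which is the case the review comment calls out, since 72 itself divides the extent but is never tried.

#include <cstdint>
#include <iostream>

// Pick the largest power of two in [8, 256] dividing the N extent.
int64_t pickInstructionN(int64_t n_extent) {
  int64_t n = 256;
  while (n >= 8 && n_extent % n != 0) {
    n /= 2;
  }
  return n;
}

int main() {
  std::cout << pickInstructionN(4096) << "\n"; // 256
  std::cout << pickInstructionN(72) << "\n";   // 8 (not 72)
  return 0;
}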
@@ -93,9 +111,9 @@ void limitCircularBufferingSmemOperands(
mparams->circular_buffer_options.smem_circular_buffer_stage = (int)stages;
}

//! A wrapper for core heuristics initialization.
//! We should have already set mparams->mma_macro before calling this function.
inline bool initCoreHeuristics(
namespace {

bool fillDefaultAmpereHeuristic(
MatmulParams* mparams,
const ProblemShape& problem_shape,
const mma_utils::TensorRolesMap& tensor_roles) {
@@ -150,7 +168,7 @@ inline bool initCoreHeuristics(
// stages and async mem copy
{
// NOTE: compilation errors when async is enabled on Turing devices
if (isAmpere(mparams->mma_macro)) {
if (!isTuring(mparams->mma_macro)) {
constexpr int stages = 3;

mparams->circular_buffer_options.circular_buffer_smem_write = true;
Expand All @@ -169,6 +187,7 @@ inline bool initCoreHeuristics(
}
return min_size_bytes;
};
// Use cp.async on Ampere if possible
mparams->async_gmem_load_operands = isCpAsyncOperandLoadSupported(
mparams,
std::min(
@@ -185,6 +204,124 @@
return true;
}

bool fillDefaultHopperHeuristic(
MatmulParams* mparams,
const ProblemShape& problem_shape,
const mma_utils::TensorRolesMap& tensor_roles) {
const GemmTile instruction_tile = getMmaOpShape(mparams->mma_macro);
GemmTile warp_tile = {-1, -1, -1};
GemmTile cta_tile = {-1, -1, -1};

using DimType = decltype(GemmTile::m);

// We typically use larger macros on Hopper. By default we will set the
// warp tile equal to the macro and increase the CTA tile until we hit
// a limit. The limits are given by the maximum number of threads per CTA.

// TODO: it might be advantageous in some cases to issue multiple wgmma
// instructions per warp group
warp_tile = instruction_tile;

// The MmaOp output is a 32-bit float which requires one register per value

const DimType max_registers_per_sm = 512 * 100;

const auto ratiosValid = [&](const DimType m_ratio, const DimType n_ratio) {
Review comment (collaborator): nitpick: snake_case for lambda functions.
Suggested change:
-  const auto ratiosValid = [&](const DimType m_ratio, const DimType n_ratio) {
+  const auto ratios_valid = [&](const DimType m_ratio, const DimType n_ratio) {
(jacobhinkle marked this conversation as resolved.)
DimType cta_m = warp_tile.m * m_ratio;
DimType cta_n = warp_tile.n * n_ratio;
DimType num_warp_groups = m_ratio * n_ratio;
return cta_n * cta_m < max_registers_per_sm
// Each warp group is 128 threads. We can only have a maximum of 1024
// threads per CTA, or 8 warp groups.
&& num_warp_groups <= 8 &&
// Don't extend the CTA tile beyond the problem size
warp_tile.m * (m_ratio + 1) <=
problem_shape[(size_t)MatmulDimRole::M] &&
warp_tile.n * (n_ratio + 1) <= problem_shape[(size_t)MatmulDimRole::N];
};

DimType m_ratio = 1;
DimType n_ratio = 1;

bool increased = true;
while (increased) {
DimType cta_m = warp_tile.m * m_ratio;
DimType cta_n = warp_tile.n * n_ratio;
increased = false;

const auto tryIncreaseM = [&]() {
if (ratiosValid(m_ratio + 1, n_ratio)) {
m_ratio++;
Review comment (jacobhinkle): Should these also be powers of two? Currently this will choose sizes like 192.

increased = true;
}
return increased;
};
const auto tryIncreaseN = [&]() {
if (ratiosValid(m_ratio, n_ratio + 1)) {
n_ratio++;
increased = true;
}
return increased;
};

if (cta_m < cta_n) {
// Try to increase smaller tile dimension first since square tiles are
// optimal for reducing operand load redundancy
if (tryIncreaseM()) {
continue;
}
tryIncreaseN();
} else {
if (tryIncreaseN()) {
continue;
}
tryIncreaseM();
}
}

cta_tile = {warp_tile.m * m_ratio, warp_tile.n * n_ratio, warp_tile.k};

mparams->tile_sizes = {cta_tile, warp_tile};

// stages and async mem copy
{
constexpr int stages = 3;

mparams->circular_buffer_options.circular_buffer_smem_write = true;
mparams->circular_buffer_options.circular_buffer_smem_read = true;
mparams->circular_buffer_options.smem_circular_buffer_stage = stages;
}

// Always use TMA on Hopper
mparams->async_gmem_load_operands = true;

if (!mparams->async_gmem_load_operands) {
// Circular buffering requires async load. If we cannot use async load due
// to unsupported vectorization width, then we can use at most two circular
// buffer stages.
mparams->circular_buffer_options.smem_circular_buffer_stage = std::min(
2, mparams->circular_buffer_options.smem_circular_buffer_stage);
}
return true;
}

} // namespace

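For intuition about the CTA-tile growth in fillDefaultHopperHeuristic above, here is a rough self-contained sketch under the same constraints; the constant and helper names are assumptions for illustration, not the exact implementation.

#include <cstdint>

struct Tile { int64_t m, n, k; };

// Grow the CTA tile from the warp tile by increasing the M and N warp-group
// ratios until a limit is hit.
Tile growCtaTile(Tile warp, int64_t problem_m, int64_t problem_n) {
  const int64_t max_accum_registers = 512 * 100; // register budget used above
  int64_t m_ratio = 1, n_ratio = 1;

  auto valid = [&](int64_t mr, int64_t nr) {
    int64_t cta_m = warp.m * mr, cta_n = warp.n * nr;
    return cta_m * cta_n < max_accum_registers // one fp32 register per accumulator value
        && mr * nr <= 8                        // at most 8 warp groups (1024 threads) per CTA
        && warp.m * (mr + 1) <= problem_m      // don't grow past the problem size
        && warp.n * (nr + 1) <= problem_n;
  };

  bool increased = true;
  while (increased) {
    increased = false;
    // Grow the smaller CTA dimension first; near-square tiles reduce operand reloads.
    if (warp.m * m_ratio < warp.n * n_ratio) {
      if (valid(m_ratio + 1, n_ratio)) { ++m_ratio; increased = true; }
      else if (valid(m_ratio, n_ratio + 1)) { ++n_ratio; increased = true; }
    } else {
      if (valid(m_ratio, n_ratio + 1)) { ++n_ratio; increased = true; }
      else if (valid(m_ratio + 1, n_ratio)) { ++m_ratio; increased = true; }
    }
  }
  return {warp.m * m_ratio, warp.n * n_ratio, warp.k};
}

With a 64x256x16 macro and a large problem, this stops at a 192x256 CTA tile with three warp groups: 256x256 would need 65536 accumulator registers, over the 51200 budget. That is the non-power-of-two size the review comment about 192 refers to.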
//! A wrapper for core heuristics initialization.
//! We should have already set mparams->mma_macro before calling this function.
inline bool initCoreHeuristics(
MatmulParams* mparams,
const ProblemShape& problem_shape,
const mma_utils::TensorRolesMap& tensor_roles) {
if (isHopper(mparams->mma_macro)) {
return fillDefaultHopperHeuristic(mparams, problem_shape, tensor_roles);
} else if (isAmpere(mparams->mma_macro) || isTuring(mparams->mma_macro)) {
return fillDefaultAmpereHeuristic(mparams, problem_shape, tensor_roles);
}
// Unsupported arch
return false;
}

//! A helper for getting problem shape from fusion and runtime info.
//!
//! For a given domain, try to find the size by evaluating the extent of an
@@ -825,13 +962,24 @@ std::string getMatmulCompileTimeRejectReason(Fusion* fusion) {
{
for (const mma_utils::MatmulPattern& pattern : patterns) {
Expr* op = pattern.output->definition();
if (device_prop->major >= 9 && op->isA<ReductionOp>()) {
bool found_reduction = false;
for (size_t dim : c10::irange((size_t)pattern.output->nDims())) {
if (found_reduction &&
!pattern.output->axis((int64_t)dim)->isReduction()) {
return "Mul+Sum patterns can only be translated to MmaOp "
"on Hopper if the reduction dim is innermost";
if (device_prop->major >= 9) {
for (TensorView* operand : {pattern.A, pattern.B}) {
if (!operand->isFusionInput() || operand->definition() == nullptr ||
!operand->definition()->isA<LoadStoreOp>() ||
!operand->definition()->input(0)->isFusionInput() ||
operand->hasRoot()) {
return "Operand " + operand->toString() +
" must be a fusion input or non-permuting LoadStoreOp of an input on Hopper";
}
}
if (op->isA<ReductionOp>()) {
bool found_reduction = false;
for (size_t dim : c10::irange((size_t)pattern.output->nDims())) {
if (found_reduction &&
!pattern.output->axis((int64_t)dim)->isReduction()) {
return "Mul+Sum patterns can only be translated to MmaOp "
"on Hopper if the reduction dim is innermost";
}
}
}
}
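To illustrate the new operand check above, here is a rough fragment of what it accepts and rejects, assuming the usual nvFuser C++ test helpers (makeContigTensor, set, transpose); the helper names are assumptions for illustration and are not taken from this diff.

Fusion fusion;
FusionGuard fg(&fusion);

TensorView* a = makeContigTensor(2, DataType::Half);
TensorView* b = makeContigTensor(2, DataType::Half);
fusion.addInput(a);
fusion.addInput(b);

// Accepted: the operand is a fusion input, or a non-permuting LoadStoreOp
// (a plain set) that directly consumes a fusion input.
TensorView* a_ok = set(a);

// Rejected on Hopper by the check above: transpose produces a LoadStoreOp
// whose output carries a root domain (a permutation), so the matmul
// scheduler declines the fusion.
TensorView* b_rejected = transpose(b, 0, 1);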
4 changes: 0 additions & 4 deletions csrc/scheduler/mma_utils.cpp
@@ -1865,10 +1865,6 @@ class MatmulTranslator : public OptInDispatch {
int64_t out_dim = pattern_.A->nDims() + 1L;
axis_mapping.a_axes.reserve(out_dim);
for (int64_t d : c10::irange(out_dim - 2L)) {
axis_mapping.a_axes.push_back(d);
}
axis_mapping.a_axes.reserve(out_dim);
for (size_t d : c10::irange(out_dim - 2)) {
Review comment (jacobhinkle, on lines -1862 to -1865): I think this was just due to a busted merge.

axis_mapping.a_axes.push_back((int64_t)d);
}
axis_mapping.a_axes.push_back(-1); // missing N dimension
7 changes: 6 additions & 1 deletion tests/cpp/test_translate_mma.cpp
@@ -468,7 +468,7 @@ using LinearNodeTranslationTest =
// Test that a simple linear op fusion is picked up by the appropriate scheduler
// and the translation to MmaOp is performed properly.
TEST_P(LinearNodeTranslationTest, AutomaticSchedulerLinearNode) {
NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0);
NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 10, 0);
// The allocation domain propagation pass sets the output allocation domain,
// which sometimes causes the matmul scheduler to decline the whole fusion
// when it could compile it otherwise.
@@ -488,6 +488,11 @@ TEST_P(LinearNodeTranslationTest, AutomaticSchedulerLinearNode) {

EnableOptionsGuard eog;
if (enable_fusion) {
if (A_dim != 2 && !cudaArchGuardShouldSkip(9, 0)) {
GTEST_SKIP()
<< "Translating linear with batch dims is not yet supported on Hopper";
}

EnableOptionsGuard::getCurOptions().set(EnableOption::FuseMatmul);
} else {
EnableOptionsGuard::getCurOptions().unset(EnableOption::FuseMatmul);