Accept Hopper matmuls and update default heuristic #3579
Merged
+457
−47
Commits (41, all by jacobhinkle):
cb13e25  Accept Hopper matmuls and update default heuristic
cd2d1e1  Merge remote-tracking branch 'origin/main' into hopper_matmul_heuristic
1692b5d  Merge remote-tracking branch 'origin/main' into hopper_matmul_heuristic
c8097b9  Fix bug in translating hopper LinearOps
3751734  Factor default heuristic by arch
076d56a  Unguard 2dA_2dB LinearOp translation tests on Hopper
700df1f  Check prologues in compile-time check
89f4887  Fix up getMmaOp
15200fe  Revert innocuous change to ampere path
bba9c88  Fix condition in prologue check
e5def4c  Merge remote-tracking branch 'origin/main' into hopper_matmul_heuristic
9a691c6  Fix up merge
80c1232  Add incomplete fix for repeated operands in patterns
fbde1e2  Enable BFloat16 in stmatrix
e311250  Remove mistakenly-pasted line
7305778  Merge remote-tracking branch 'origin/stmatrix_bfloat' into hopper_mat…
3f7b6a6  Fix compile error
d8f80e2  Merge remote-tracking branch 'origin/stmatrix_bfloat' into hopper_mat…
6c17823  Merge remote-tracking branch 'origin/main' into hopper_matmul_heuristic
486e4d9  clang-tidy
5c6f504  Merge remote-tracking branch 'origin/main' into hopper_matmul_heuristic
c9f1805  clang-tidy
0f1ad25  Add test that we skip fusing matmuls with 64-bit indexing
7ce8938  Default to not setting cluster_dims
da8f27f  Keep instruction size >=64 for default heuristic
7c3cd60  Don't increase CTA beyond 256
f11516d  Default to 6 stages. This will be limited based on smem usage later
0252ad4  Skip hopper matmuls that need int64 indexing
9e0118d  Merge remote-tracking branch 'origin/main' into hopper_matmul_heuristic
8f4b796  Limit macro to N=64
10489a9  Increase by multiples of 2, respect device num registers
c77f6e4  Set circular buffer stages to max, respecting smem limit
645bc74  Always promote smem reuse on Hopper
a5b79ac  Relax constraint of exactly one operand in heuristic plugin
3c2b0fb  Apply suggestions from code review
69ac0e5  Fix compile, accomodate extra warp group for warp specialization
c71d6b7  Add comment/clean up
0290b39  Add tests related to upcoming benchmarking effort
3c4170e  Remove comment
f386536  clang-tidy
c56d845  Merge remote-tracking branch 'origin/main' into hopper_matmul_heuristic
@@ -20,9 +20,14 @@
#include <ir/interface_nodes.h>
#include <ir/internal_nodes.h>
#include <ir/utils.h>
#include <mma_type.h>
#include <options.h>
#include <runtime/executor_utils.h>
#include <scheduler/mma_utils.h>
#include <type.h>
#include <utils.h>
#include <val_graph.h>

#include <algorithm>
#include <deque>
#include <iostream>
@@ -32,11 +37,8 @@
#include <type_traits>
#include <utility>
#include <variant>
#include "ATen/cuda/CUDAContext.h"
#include "mma_type.h"
#include "mma_utils.h"
#include "type.h"
#include "utils.h"

#include <ATen/cuda/CUDAContext.h>

namespace nvfuser {
namespace matmul_utils {
@@ -49,25 +51,44 @@ using ProblemShape = std::array<int64_t, 4>;
inline std::optional<MmaMacro> getMmaOp(
    const int dev_version,
    const ProblemShape& problem) {
  using MacroType = MmaMacro;
  const int64_t n_extent = problem[(size_t)MatmulDimRole::N];

  // NOTE: A temp condition
  const ProblemShape::value_type n_extend = problem[(size_t)MatmulDimRole::N];
  const bool use_small_n = ((n_extend % 8) == 0) && ((n_extend % 16) != 0);
  MmaMacroEncode macro_encode{MmaMacroEncode::Arch::NoMma, 16, 8, 16};

  switch (dev_version) {
    case 75:
      return (use_small_n) ? MacroType::Turing_16_8_16
                           : MacroType::Turing_16_16_16;
      macro_encode.arch = MmaMacroEncode::Arch::Turing;
      if ((n_extent % 16) == 0) {
        macro_encode.n = 16;
      }
      break;
    case 80:
    case 86:
    case 89:
    case 90: // NOTE: temp use ampere matmul for hopper
      return (use_small_n) ? MacroType::Ampere_16_8_16
                           : MacroType::Ampere_16_16_16;
      macro_encode.arch = MmaMacroEncode::Arch::Ampere;
      if ((n_extent % 16) == 0) {
        macro_encode.n = 16;
      }
      break;
    case 90:
      macro_encode.arch = MmaMacroEncode::Arch::Hopper;
      macro_encode.m = 64;
      // Find the largest instruction tile that divides the problem size and is
      // a power of two
      macro_encode.n = 64;
      // TODO: enable instructions smaller than 64_64_16
      while (macro_encode.n > 64) {
        if (n_extent % macro_encode.n != 0) {
          macro_encode.n /= 2;
          [Review comment, truncated: Currently this only chooses powers of two. For small problems I think we could choose one of the other sizes. For example if]
        } else {
          break;
        }
      }
      break;
    default:
      return std::nullopt;
  }
  return macro_encode;
}
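As the inline review comment notes, this selection only considers powers of two, and since n starts at 64 the while loop as written never shrinks it (the TODO about instructions smaller than 64_64_16 is still open). A minimal standalone sketch of the power-of-two selection described in the comment above; the fallback value of 8 is an assumption for illustration, not a statement about supported wgmma shapes:

#include <cstdint>

// Hypothetical helper (not part of nvfuser): pick the largest power-of-two
// instruction N that is <= 64 and divides the problem's N extent.
int64_t pickInstructionN(int64_t n_extent) {
  for (int64_t n = 64; n > 8; n /= 2) {
    if (n_extent % n == 0) {
      return n;
    }
  }
  return 8; // assumed fallback for this sketch
}
// Example: pickInstructionN(4096) == 64, pickInstructionN(24) == 8.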

//! Find the number of circular buffer stages for shared memory operands, so
@@ -93,9 +114,9 @@ void limitCircularBufferingSmemOperands(
  mparams->circular_buffer_options.smem_circular_buffer_stage = (int)stages;
}

//! A wrapper for core heuristics initialization.
//! We should have already set mparams->mma_macro before calling this function.
inline bool initCoreHeuristics(
namespace {

bool fillDefaultAmpereHeuristic(
    MatmulParams* mparams,
    const ProblemShape& problem_shape,
    const mma_utils::TensorRolesMap& tensor_roles,
@@ -170,6 +191,7 @@ inline bool initCoreHeuristics(
    }
    return min_size_bytes;
  };
  // Use cp.async on Ampere if possible
  mparams->async_gmem_load_operands = isCpAsyncOperandLoadSupported(
      mparams,
      std::min(
@@ -186,6 +208,180 @@
  return true;
}

bool fillDefaultHopperHeuristic(
    MatmulParams* mparams,
    const ProblemShape& problem_shape,
    const mma_utils::TensorRolesMap& tensor_roles,
    const size_t num_problems) {
  const auto device_prop = at::cuda::getCurrentDeviceProperties();

  const GemmTile instruction_tile = getMmaOpShape(mparams->mma_macro);
  GemmTile warp_tile = {-1, -1, -1};
  GemmTile cta_tile = {-1, -1, -1};

  using DimType = decltype(GemmTile::m);

  // We typically use larger macros on Hopper. By default we will set the
  // warp tile equal to the macro and increase the CTA tile until we hit
  // a limit. The limits are given by the maximum number of threads per CTA.

  // TODO: it might be advantageous in some cases to issue multiple wgmma
  // instructions per warp group
  warp_tile = instruction_tile;

  // The MmaOp output is a 32-bit float which requires one register per value

  // total accumulator registers for warp group
  const size_t accum_regs_per_warp_group =
      warp_tile.m * warp_tile.n * num_problems;

  // The cta tile is a multiple of the warp tile. This lambda checks that cta
  // tile given by warp_tile and multiple fits on the SM.
  const auto validate_cta_tile_multiple = [&](const DimType m_ratio,
                                              const DimType n_ratio) {
    DimType cta_m = warp_tile.m * m_ratio;
    DimType cta_n = warp_tile.n * n_ratio;
    DimType num_compute_warp_groups = m_ratio * n_ratio;

    // This assumes warp specialization:
    // tma warp group + compute warp groups
    DimType num_warp_groups = num_compute_warp_groups + 1;

    const int64_t threads_per_sm = num_warp_groups * 128;
    const size_t max_registers_per_sm =
        getRegPerThreadGivenThreadsPerSM(threads_per_sm) * threads_per_sm;
    return
        // We store one float per CTA tile element for each matmul problem we
        // compute
        num_warp_groups * accum_regs_per_warp_group < max_registers_per_sm
        // TMA box dimensions must be less than or equal to 256
        && cta_m <= 256 &&
        cta_n <= 256
        // Each warp group is 128 threads. We can only have a maximum of 1024
        // threads per SM, or 8 warp groups.
        && num_warp_groups <= 8 &&
        // Don't extend the CTA tile beyond the problem size
        cta_m <= problem_shape[(size_t)MatmulDimRole::M] &&
        cta_n <= problem_shape[(size_t)MatmulDimRole::N];
  };

  DimType m_ratio = 1;
  DimType n_ratio = 1;

  bool increased = true;
  while (increased) {
    DimType cta_m = warp_tile.m * m_ratio;
    DimType cta_n = warp_tile.n * n_ratio;
    increased = false;

    const auto try_increaseM = [&]() {
      if (validate_cta_tile_multiple(m_ratio * 2, n_ratio)) {
        m_ratio *= 2;
        increased = true;
      }
      return increased;
    };
    const auto try_increaseN = [&]() {
      if (validate_cta_tile_multiple(m_ratio, n_ratio * 2)) {
        n_ratio *= 2;
        increased = true;
      }
      return increased;
    };

    if (cta_m < cta_n) {
      // Try to increase smaller tile dimension first since square tiles are
      // optimal for reducing operand load redundancy
      if (try_increaseM()) {
        continue;
      }
      try_increaseN();
    } else {
      if (try_increaseN()) {
        continue;
      }
      try_increaseM();
    }
  }

  cta_tile = {warp_tile.m * m_ratio, warp_tile.n * n_ratio, warp_tile.k};

  mparams->tile_sizes = {cta_tile, warp_tile};
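The growth loop above doubles the CTA tile starting from one warp tile per compute warp group, preferring the smaller dimension, until the warp-group, TMA-box, register, or problem-size limit is hit. A simplified standalone sketch of that logic follows; the 64K-register budget and the one-register-per-accumulator cost are assumptions (the real code queries getRegPerThreadGivenThreadsPerSM and the device properties), and num_problems is taken as 1:

#include <cstdint>

struct Tile {
  int64_t m;
  int64_t n;
};

// Illustration only: grow a warp tile into a CTA tile by doubling the M/N
// ratios, preferring the smaller CTA dimension, subject to Hopper-style
// limits: <= 256 per CTA dimension, <= 8 warp groups (one reserved for TMA
// loads), one fp32 register per accumulator element, and an assumed
// 64K-register SM budget.
Tile growCtaTile(Tile warp, int64_t problem_m, int64_t problem_n) {
  const int64_t regs_per_sm = 64 * 1024; // assumption for this sketch
  const int64_t accum_regs_per_group = warp.m * warp.n;
  int64_t m_ratio = 1;
  int64_t n_ratio = 1;
  auto fits = [&](int64_t mr, int64_t nr) {
    const int64_t groups = mr * nr + 1; // +1 for the TMA warp group
    return groups <= 8 && warp.m * mr <= 256 && warp.n * nr <= 256 &&
        warp.m * mr <= problem_m && warp.n * nr <= problem_n &&
        groups * accum_regs_per_group < regs_per_sm;
  };
  bool grew = true;
  while (grew) {
    grew = false;
    // Square-ish tiles reduce operand load redundancy, so try to double the
    // smaller CTA dimension first and fall back to the other one.
    const bool prefer_m = warp.m * m_ratio < warp.n * n_ratio;
    if (prefer_m && fits(m_ratio * 2, n_ratio)) {
      m_ratio *= 2;
      grew = true;
    } else if (fits(m_ratio, n_ratio * 2)) {
      n_ratio *= 2;
      grew = true;
    } else if (fits(m_ratio * 2, n_ratio)) {
      m_ratio *= 2;
      grew = true;
    }
  }
  return {warp.m * m_ratio, warp.n * n_ratio};
}
// Example: a 64x64 warp tile on a large problem grows 64x64 -> 64x128 ->
// 128x128 and then stops, since doubling either dimension again would need
// nine warp groups, exceeding the eight-group (1024-thread) cap.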

  // stages and async mem copy
  mparams->circular_buffer_options.smem_circular_buffer_stage = 8;

  // TODO: We should take the main loop structure into account here to get a
  // more accurate estimate in case of horizontal fusion
  int64_t operand_smem_per_stage =
      (int64_t)num_problems * 2 * (cta_tile.m + cta_tile.n) * cta_tile.k;
  // We leave a bit of space for semaphores
  int64_t max_operand_smem =
      (int64_t)device_prop->sharedMemPerBlock - (1L << 7);

  while (mparams->circular_buffer_options.smem_circular_buffer_stage *
             operand_smem_per_stage >
         max_operand_smem) {
    mparams->circular_buffer_options.smem_circular_buffer_stage--;
  }

  mparams->circular_buffer_options.circular_buffer_smem_write =
      mparams->circular_buffer_options.smem_circular_buffer_stage > 1;
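A worked version of the clamp above with assumed numbers (the real budget comes from device_prop->sharedMemPerBlock; half-precision operands cost 2 bytes per element):

#include <cstdint>

// Sketch of the stage clamp: shrink the circular-buffer stage count until all
// operand stages fit in shared memory, keeping a little headroom for
// semaphores/mbarriers.
int64_t clampStages(
    int64_t stages,
    int64_t cta_m,
    int64_t cta_n,
    int64_t cta_k,
    int64_t num_problems,
    int64_t smem_per_block) {
  const int64_t per_stage = num_problems * 2 * (cta_m + cta_n) * cta_k;
  const int64_t budget = smem_per_block - 128; // reserve space for semaphores
  while (stages > 1 && stages * per_stage > budget) {
    --stages;
  }
  return stages;
}
// clampStages(8, 128, 128, 16, 1, 96 * 1024) == 8  (8 * 8192 bytes fit)
// clampStages(8, 256, 256, 64, 1, 96 * 1024) == 1  (65536 bytes per stage),
// and a single stage also disables circular buffering of the smem writes.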

  // Always use TMA on Hopper
  mparams->async_gmem_load_operands = true;

  // See here for more information:
  // https://research.colfax-intl.com/cutlass-tutorial-wgmma-hopper/

  // We count the number of tiles in each dimension to determine the
  // rasterization order. The fast rasterization axis is the shortest axis, to
  // encourage L2 hits by looping over the same rows or cols more frequently.
  int64_t Mtiles = ceilDiv(problem_shape[(size_t)MatmulDimRole::M], cta_tile.m);
  int64_t Ntiles = ceilDiv(problem_shape[(size_t)MatmulDimRole::N], cta_tile.n);

  mparams->cta_order = Ntiles >= Mtiles
      ? MatmulParams::TileRasterizationOrder::ColumnMajor
      : MatmulParams::TileRasterizationOrder::RowMajor;

  // We also swizzle the tiles as much as possible up to 4 tiles. Like choosing
  // the rasterization order, this is used to increase L2 locality
  mparams->grid_swizzle_factor = 4L;
  while (Mtiles % mparams->grid_swizzle_factor != 0 ||
         Ntiles % mparams->grid_swizzle_factor != 0) {
    // Decrease the swizzle factor if it would result in nondivisible splits,
    // since this would unnecessarily increase the grid size.
    mparams->grid_swizzle_factor /= 2L;
  }
  // TODO: grid swizzling is currently disabled on Hopper since we cannot
  // properly inline when we swizzle unmapped loop broadcasts
  mparams->grid_swizzle_factor = 1L;
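A small standalone sketch of the two decisions above; the problem sizes in the trailing example are assumed, and ceilDiv is redefined locally to keep the sketch self-contained:

#include <cstdint>
#include <utility>

// Pick the rasterization order from the tile counts and shrink the swizzle
// factor until it divides both counts.
int64_t ceilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; }

std::pair<bool /*column_major*/, int64_t /*swizzle*/> chooseRasterization(
    int64_t M, int64_t N, int64_t cta_m, int64_t cta_n) {
  const int64_t Mtiles = ceilDiv(M, cta_m);
  const int64_t Ntiles = ceilDiv(N, cta_n);
  const bool column_major = Ntiles >= Mtiles; // fast axis = shorter grid axis
  int64_t swizzle = 4;
  while (swizzle > 1 && (Mtiles % swizzle != 0 || Ntiles % swizzle != 0)) {
    swizzle /= 2; // nondivisible splits would only inflate the grid
  }
  return {column_major, swizzle};
}
// Example: chooseRasterization(512, 4096, 128, 256) gives Mtiles = 4,
// Ntiles = 16 -> {true /*ColumnMajor*/, 4}; the heuristic then overrides the
// swizzle factor to 1 until the inlining TODO above is resolved.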

  // TODO: Finally, we set the CGA size

  return true;
}

} // namespace

//! A wrapper for core heuristics initialization.
//! We should have already set mparams->mma_macro before calling this function.
inline bool initCoreHeuristics(
    MatmulParams* mparams,
    const ProblemShape& problem_shape,
    const mma_utils::TensorRolesMap& tensor_roles,
    const size_t num_problems) {
  if (isHopper(mparams->mma_macro)) {
    return fillDefaultHopperHeuristic(
        mparams, problem_shape, tensor_roles, num_problems);
  } else if (isAmpere(mparams->mma_macro) || isTuring(mparams->mma_macro)) {
    return fillDefaultAmpereHeuristic(
        mparams, problem_shape, tensor_roles, num_problems);
  }
  // Unsupported arch
  return false;
}

//! A helper for getting problem shape from fusion and runtime info.
//!
//! For a given domain, try to find the size by evaluating the extent of an
@@ -790,7 +986,15 @@ std::unique_ptr<MatmulParams> getMatmulHeuristics(
      mma_utils::generateSharedMemoryEpilogueHeuristics(
          mparams->tile_sizes,
          mparams->circular_buffer_options.smem_circular_buffer_stage,
          tensor_roles);
          tensor_roles,
          /*ignore_occupancy_drop=*/true);
  if (isHopper(mparams->mma_macro)) {
    // Always promote smem reuse for Hopper. This is needed because we use TMA
    // which has higher alignment requirements, so it's important that we place
    // our TMA buffers at an offset that's a multiple of 64 (like 0) if
    // possible.
    mparams->promote_prologue_smem_reuse = true;
  }

  if (isDebugDumpEnabled(DebugDumpOption::SchedulerDebug)) {
    debug() << mparams->toString() << std::endl;
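For the alignment concern mentioned in the Hopper branch above, a tiny hypothetical helper showing the kind of rounding involved; the 64 comes from the comment, not from a TMA specification quoted here:

#include <cstdint>

// Hypothetical helper: round a shared memory offset up to the next multiple
// of the required alignment (64 in the comment above, ideally landing at 0).
constexpr int64_t alignUp(int64_t offset, int64_t alignment) {
  return (offset + alignment - 1) / alignment * alignment;
}
// alignUp(0, 64) == 0, alignUp(40, 64) == 64, alignUp(64, 64) == 64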
@@ -842,13 +1046,25 @@ std::string getMatmulCompileTimeRejectReason(Fusion* fusion) {
  {
    for (const mma_utils::MatmulPattern& pattern : patterns) {
      Expr* op = pattern.output->definition();
      if (device_prop->major >= 9 && op->isA<ReductionOp>()) {
        bool found_reduction = false;
        for (size_t dim : c10::irange((size_t)pattern.output->nDims())) {
          if (found_reduction &&
              !pattern.output->axis((int64_t)dim)->isReduction()) {
            return "Mul+Sum patterns can only be translated to MmaOp "
                "on Hopper if the reduction dim is innermost";
      if (device_prop->major >= 9) {
        for (TensorView* operand : {pattern.A, pattern.B}) {
          if (!operand->isFusionInput() &&
              (operand->definition() == nullptr ||
               !operand->definition()->isA<LoadStoreOp>() ||
               !operand->definition()->input(0)->isFusionInput() ||
               operand->hasRoot())) {
            return "Operand " + operand->toString() +
                " must be a fusion input or non-permuting LoadStoreOp of an input on Hopper";
          }
        }
        if (op->isA<ReductionOp>()) {
          bool found_reduction = false;
          for (size_t dim : c10::irange((size_t)pattern.output->nDims())) {
            if (found_reduction &&
                !pattern.output->axis((int64_t)dim)->isReduction()) {
              return "Mul+Sum patterns can only be translated to MmaOp "
                  "on Hopper if the reduction dim is innermost";
            }
          }
        }
      }
@@ -922,7 +1138,14 @@ std::string getMatmulRunTimeRejectReason(
    Fusion* fusion,
    HeuristicDataCache* data_cache,
    SchedulerRuntimeInfo& runtime_info) {
  // TODO: add proper set of checks
  const auto device_prop = at::cuda::getCurrentDeviceProperties();

  if (device_prop->major >= 9 &&
      runtime_info.getIndexType() != DataType::Int32) {
    // See https://github.com/NVIDIA/Fuser/issues/3595
    return "Hopper matmul is not yet supported with problem sizes requiring 64-bit indexing";
  }

  return "";
}
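To illustrate when this runtime rejection fires (the threshold below is the usual 32-bit signed limit; the actual decision is made by the runtime info's index-type query, so treat this as a sketch):

#include <cstdint>
#include <limits>

// Sketch: a matmul needs 64-bit indexing once any single tensor's flat extent
// no longer fits in int32, at which point the Hopper scheduler rejects it.
bool needsInt64Indexing(int64_t m, int64_t n, int64_t k) {
  const int64_t imax = std::numeric_limits<int32_t>::max();
  return m * n > imax || m * k > imax || k * n > imax;
}
// Example: m = n = 50000, k = 1024 -> m * n = 2.5e9 > 2^31 - 1, so the
// fusion falls back to 64-bit indexing and getMatmulRunTimeRejectReason
// returns the message above.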
Review comment: If the grid size is not divisible by the cluster size then we get a launch error, so we should default to not use cluster dims unless explicitly handled by a heuristic.
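A minimal sketch of the divisibility constraint that comment describes, assuming CUDA's thread block cluster requirement that every grid dimension be a multiple of the corresponding cluster dimension:

#include <cstdint>

// If any grid dimension is not a multiple of the cluster dimension, the
// launch fails, hence the default of leaving cluster_dims unset.
bool clusterDimsValid(
    int64_t grid_x, int64_t grid_y, int64_t grid_z,
    int64_t cluster_x, int64_t cluster_y, int64_t cluster_z) {
  return grid_x % cluster_x == 0 && grid_y % cluster_y == 0 &&
      grid_z % cluster_z == 0;
}
// Example: a 7x5x1 grid with a 2x1x1 cluster is invalid (7 % 2 != 0), so a
// heuristic that sets cluster_dims must also ensure the grid divides evenly.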