Accept Hopper matmuls and update default heuristic #3579
Changes from 8 commits
```diff
@@ -49,25 +49,43 @@ using ProblemShape = std::array<int64_t, 4>;
 inline std::optional<MmaMacro> getMmaOp(
     const int dev_version,
     const ProblemShape& problem) {
-  using MacroType = MmaMacro;
+  const int64_t n_extent = problem[(size_t)MatmulDimRole::N];
 
-  // NOTE: A temp condition
-  const ProblemShape::value_type n_extend = problem[(size_t)MatmulDimRole::N];
-  const bool use_small_n = ((n_extend % 8) == 0) && ((n_extend % 16) != 0);
+  MmaMacroEncode macro_encode{MmaMacroEncode::Arch::NoMma, 16, 8, 16};
 
   switch (dev_version) {
     case 75:
-      return (use_small_n) ? MacroType::Turing_16_8_16
-                           : MacroType::Turing_16_16_16;
+      macro_encode.arch = MmaMacroEncode::Arch::Turing;
+      if ((n_extent % 16) == 0) {
+        macro_encode.n = 16;
+      }
+      break;
     case 80:
     case 86:
     case 89:
-    case 90: // NOTE: temp use ampere matmul for hopper
-      return (use_small_n) ? MacroType::Ampere_16_8_16
-                           : MacroType::Ampere_16_16_16;
+      macro_encode.arch = MmaMacroEncode::Arch::Ampere;
+      if ((n_extent % 16) == 0) {
+        macro_encode.n = 16;
+      }
+      break;
+    case 90:
+      macro_encode.arch = MmaMacroEncode::Arch::Hopper;
+      macro_encode.m = 64;
+      // Find the largest instruction tile that divides the problem size and is
+      // a power of two
+      macro_encode.n = 256;
+      while (macro_encode.n >= 8) {
+        if (n_extent % macro_encode.n != 0) {
+          macro_encode.n /= 2;
+        } else {
+          break;
+        }
+      }
+      break;
     default:
       return std::nullopt;
   }
+  return macro_encode;
 }
 
 //! Find the number of circular buffer stages for shared memory operands, so
```
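For case 90, the loop above picks the wgmma instruction N dimension by halving from 256 until it reaches a power of two that divides the problem's N extent. A minimal standalone sketch of that selection, assuming nothing beyond the loop itself (the helper name pickInstructionN is ours, not part of the PR):

```cpp
#include <cstdint>
#include <cstdio>

// Largest power of two in [8, 256] dividing n_extent; mirrors the case-90
// loop above. Falls through to 4 if even 8 does not divide n_extent.
int64_t pickInstructionN(int64_t n_extent) {
  int64_t n = 256;
  while (n >= 8 && n_extent % n != 0) {
    n /= 2;
  }
  return n;
}

int main() {
  for (int64_t n_extent : {4096, 512, 136, 72}) {
    std::printf("n_extent=%lld -> n=%lld\n",
                (long long)n_extent,
                (long long)pickInstructionN(n_extent));
  }
  // Prints: 4096 -> 256, 512 -> 256, 136 -> 8, 72 -> 8
}
```

Note how any extent that is a multiple of 8 but not 16 (such as 72 or 136) falls all the way down to n = 8; the review comment at the end of this page picks up on exactly this behavior.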
```diff
@@ -93,9 +111,9 @@ void limitCircularBufferingSmemOperands(
   mparams->circular_buffer_options.smem_circular_buffer_stage = (int)stages;
 }
 
-//! A wrapper for core heuristics initialization.
-//! We should have already set mparams->mma_macro before calling this function.
-inline bool initCoreHeuristics(
+namespace {
+
+bool fillDefaultAmpereHeuristic(
     MatmulParams* mparams,
     const ProblemShape& problem_shape,
     const mma_utils::TensorRolesMap& tensor_roles) {
```
```diff
@@ -150,7 +168,7 @@ inline bool initCoreHeuristics(
   // stages and async mem copy
   {
     // NOTE: compilation errors when async is enabled on Turing devices
-    if (isAmpere(mparams->mma_macro)) {
+    if (!isTuring(mparams->mma_macro)) {
       constexpr int stages = 3;
 
       mparams->circular_buffer_options.circular_buffer_smem_write = true;
```
```diff
@@ -169,6 +187,7 @@ inline bool initCoreHeuristics(
     }
     return min_size_bytes;
   };
+  // Use cp.async on Ampere if possible
   mparams->async_gmem_load_operands = isCpAsyncOperandLoadSupported(
       mparams,
       std::min(
```
```diff
@@ -185,6 +204,124 @@ inline bool initCoreHeuristics(
   return true;
 }
 
+bool fillDefaultHopperHeuristic(
+    MatmulParams* mparams,
+    const ProblemShape& problem_shape,
+    const mma_utils::TensorRolesMap& tensor_roles) {
+  const GemmTile instruction_tile = getMmaOpShape(mparams->mma_macro);
+  GemmTile warp_tile = {-1, -1, -1};
+  GemmTile cta_tile = {-1, -1, -1};
+
+  using DimType = decltype(GemmTile::m);
+
+  // We typically use larger macros on Hopper. By default we will set the
+  // warp tile equal to the macro and increase the CTA tile until we hit
+  // a limit. The limits are given by the maximum number of threads per CTA.
+
+  // TODO: it might be advantageous in some cases to issue multiple wgmma
+  // instructions per warp group
+  warp_tile = instruction_tile;
+
+  // The MmaOp output is a 32-bit float which requires one register per value
+  const DimType max_registers_per_sm = 512 * 100;
+
+  const auto ratiosValid = [&](const DimType m_ratio, const DimType n_ratio) {
+    DimType cta_m = warp_tile.m * m_ratio;
+    DimType cta_n = warp_tile.n * n_ratio;
+    DimType num_warp_groups = m_ratio * n_ratio;
+    return cta_n * cta_m < max_registers_per_sm
+        // Each warp group is 128 threads. We can only have a maximum of 1024
+        // threads per SM, or 8 warp groups.
+        && num_warp_groups <= 8 &&
+        // Don't extend the CTA tile beyond the problem size
+        warp_tile.m * (m_ratio + 1) <=
+            problem_shape[(size_t)MatmulDimRole::M] &&
+        warp_tile.n * (n_ratio + 1) <= problem_shape[(size_t)MatmulDimRole::N];
+  };
+
+  DimType m_ratio = 1;
+  DimType n_ratio = 1;
+
+  bool increased = true;
+  while (increased) {
+    DimType cta_m = warp_tile.m * m_ratio;
+    DimType cta_n = warp_tile.n * n_ratio;
+    increased = false;
+
+    const auto tryIncreaseM = [&]() {
+      if (ratiosValid(m_ratio + 1, n_ratio)) {
+        m_ratio++;
+        increased = true;
+      }
+      return increased;
+    };
+    const auto tryIncreaseN = [&]() {
+      if (ratiosValid(m_ratio, n_ratio + 1)) {
+        n_ratio++;
+        increased = true;
+      }
+      return increased;
+    };
+
+    if (cta_m < cta_n) {
+      // Try to increase smaller tile dimension first since square tiles are
+      // optimal for reducing operand load redundancy
+      if (tryIncreaseM()) {
+        continue;
+      }
+      tryIncreaseN();
+    } else {
+      if (tryIncreaseN()) {
+        continue;
+      }
+      tryIncreaseM();
+    }
+  }
+
+  cta_tile = {warp_tile.m * m_ratio, warp_tile.n * n_ratio, warp_tile.k};
+
+  mparams->tile_sizes = {cta_tile, warp_tile};
+
+  // stages and async mem copy
+  {
+    constexpr int stages = 3;
+
+    mparams->circular_buffer_options.circular_buffer_smem_write = true;
+    mparams->circular_buffer_options.circular_buffer_smem_read = true;
+    mparams->circular_buffer_options.smem_circular_buffer_stage = stages;
+  }
+
+  // Always use TMA on Hopper
+  mparams->async_gmem_load_operands = true;
+
+  if (!mparams->async_gmem_load_operands) {
+    // Circular buffering requires async load. If we cannot use async load due
+    // to unsupported vectorization width, then we can only circular buffer at
+    // most two stages.
+    mparams->circular_buffer_options.smem_circular_buffer_stage = std::min(
+        2, mparams->circular_buffer_options.smem_circular_buffer_stage);
+  }
+  return true;
+}
+
+} // namespace
+
+//! A wrapper for core heuristics initialization.
+//! We should have already set mparams->mma_macro before calling this function.
+inline bool initCoreHeuristics(
+    MatmulParams* mparams,
+    const ProblemShape& problem_shape,
+    const mma_utils::TensorRolesMap& tensor_roles) {
+  if (isHopper(mparams->mma_macro)) {
+    return fillDefaultHopperHeuristic(mparams, problem_shape, tensor_roles);
+  } else if (isAmpere(mparams->mma_macro) || isTuring(mparams->mma_macro)) {
+    return fillDefaultAmpereHeuristic(mparams, problem_shape, tensor_roles);
+  }
+  // Unsupported arch
+  return false;
+}
+
 //! A helper for getting problem shape from fusion and runtime info.
 //!
 //! For a given domain, try to find the size by evaluating the extent of an
```

Review comments on this hunk:
- On the ratiosValid lambda (resolved): nitpick: snake_case for lambda functions.
- On the ratio increments in tryIncreaseM / tryIncreaseN: Should these also be powers of two? Currently this will choose sizes like 192.
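For intuition, here is a standalone trace of the CTA-growth loop above. It is a sketch under stated assumptions: a 64x256 warp tile (e.g. from a Hopper_64_256_16 macro) and a 4096x4096 problem, with the same limits as the diff (512 * 100 = 51200 accumulator registers, at most 8 warp groups). It stops at a 192x256 CTA tile, the non-power-of-two size flagged in the review comment above.

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  // Illustrative assumptions, not values from the PR:
  const int64_t warp_m = 64, warp_n = 256;        // warp tile = macro shape
  const int64_t problem_m = 4096, problem_n = 4096;
  const int64_t max_registers_per_sm = 512 * 100; // 51200, as in the diff

  // Same validity test as ratiosValid in the diff (snake_case per the review)
  auto ratios_valid = [&](int64_t m_ratio, int64_t n_ratio) {
    return warp_m * m_ratio * warp_n * n_ratio < max_registers_per_sm &&
        m_ratio * n_ratio <= 8 && // 8 warp groups = 1024 threads per CTA
        warp_m * (m_ratio + 1) <= problem_m &&
        warp_n * (n_ratio + 1) <= problem_n;
  };

  int64_t m_ratio = 1, n_ratio = 1;
  bool increased = true;
  while (increased) {
    const int64_t cta_m = warp_m * m_ratio;
    const int64_t cta_n = warp_n * n_ratio;
    increased = false;
    // Grow the smaller CTA dimension first, as in the heuristic
    if (cta_m < cta_n) {
      if (ratios_valid(m_ratio + 1, n_ratio)) {
        m_ratio++;
        increased = true;
      } else if (ratios_valid(m_ratio, n_ratio + 1)) {
        n_ratio++;
        increased = true;
      }
    } else {
      if (ratios_valid(m_ratio, n_ratio + 1)) {
        n_ratio++;
        increased = true;
      } else if (ratios_valid(m_ratio + 1, n_ratio)) {
        m_ratio++;
        increased = true;
      }
    }
    std::printf("cta_tile = %lldx%lld\n",
                (long long)(warp_m * m_ratio),
                (long long)(warp_n * n_ratio));
  }
  // Prints 128x256, 192x256, then 192x256 again on the final (no-growth)
  // pass: 256x256 would need 65536 registers, and 192x512 would need 98304.
}
```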
```diff
@@ -825,13 +962,24 @@ std::string getMatmulCompileTimeRejectReason(Fusion* fusion) {
   {
     for (const mma_utils::MatmulPattern& pattern : patterns) {
       Expr* op = pattern.output->definition();
-      if (device_prop->major >= 9 && op->isA<ReductionOp>()) {
-        bool found_reduction = false;
-        for (size_t dim : c10::irange((size_t)pattern.output->nDims())) {
-          if (found_reduction &&
-              !pattern.output->axis((int64_t)dim)->isReduction()) {
-            return "Mul+Sum patterns can only be translated to MmaOp "
-                   "on Hopper if the reduction dim is innermost";
+      if (device_prop->major >= 9) {
+        for (TensorView* operand : {pattern.A, pattern.B}) {
+          if (!operand->isFusionInput() &&
+              (operand->definition() == nullptr ||
+               !operand->definition()->isA<LoadStoreOp>() ||
+               !operand->definition()->input(0)->isFusionInput() ||
+               operand->hasRoot())) {
+            return "Operand " + operand->toString() +
+                " must be a fusion input or non-permuting LoadStoreOp of an input on Hopper";
+          }
+        }
+        if (op->isA<ReductionOp>()) {
+          bool found_reduction = false;
+          for (size_t dim : c10::irange((size_t)pattern.output->nDims())) {
+            if (found_reduction &&
+                !pattern.output->axis((int64_t)dim)->isReduction()) {
+              return "Mul+Sum patterns can only be translated to MmaOp "
+                     "on Hopper if the reduction dim is innermost";
+            }
+          }
+        }
       }
```
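The second rejection enforces that no iteration axis follows a reduction axis in the output domain. A self-contained sketch of the same scan, using a plain vector of per-axis reduction flags in place of the TensorView API (the helper name is ours, for illustration):

```cpp
#include <vector>

// True when every axis after the first reduction axis is also a reduction,
// i.e. the reduction (K) dims are innermost in the output domain.
bool reductionIsInnermost(const std::vector<bool>& is_reduction) {
  bool found_reduction = false;
  for (bool r : is_reduction) {
    if (found_reduction && !r) {
      return false; // an iteration axis appears after a reduction axis
    }
    found_reduction = found_reduction || r;
  }
  return true;
}

// reductionIsInnermost({false, false, true}) -> true  (M, N, rK)
// reductionIsInnermost({false, true, false}) -> false (rK not innermost)
```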
```diff
@@ -1865,10 +1865,6 @@ class MatmulTranslator : public OptInDispatch {
     int64_t out_dim = pattern_.A->nDims() + 1L;
     axis_mapping.a_axes.reserve(out_dim);
     for (int64_t d : c10::irange(out_dim - 2L)) {
       axis_mapping.a_axes.push_back(d);
     }
-    axis_mapping.a_axes.reserve(out_dim);
-    for (size_t d : c10::irange(out_dim - 2)) {
-      axis_mapping.a_axes.push_back((int64_t)d);
-    }
     axis_mapping.a_axes.push_back(-1); // missing N dimension
```

Review comment on lines -1862 to -1865: I think this was just due to a busted merge.
Review comment on the instruction-tile selection loop in getMmaOp: Currently this only chooses powers of two. For small problems I think we could choose one of the other sizes. For example if n_extent == 72 then we should probably use that size.
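One way to implement that suggestion, sketched under the assumption that the instruction N dimension may be any multiple of 8 up to 256 rather than only a power of two: scan multiples of 8 downward and take the largest divisor. The helper name is hypothetical, not part of the PR.

```cpp
#include <cstdint>

// Hypothetical alternative to the power-of-two halving loop: the largest
// multiple of 8 in [8, 256] that divides n_extent, so n_extent == 72
// selects n == 72 instead of 8.
int64_t pickInstructionN(int64_t n_extent) {
  for (int64_t n = 256; n >= 8; n -= 8) {
    if (n_extent % n == 0) {
      return n;
    }
  }
  return 8; // fallback when no multiple of 8 divides n_extent
}
```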