MmaOp definition - cleaning (#166)

  - remove redundant input layout from struct with mma options and matmul
    heuristics, MmaOp attribute is the single source of truth (see the
    sketch below),
  - fix layout inferring, works correctly for gemm, batch gemm and
    inter/intra split-k reductions,
  - update MmaOp constructors to ensure that all attributes are
    initialized,
  - rename mma_op to mma_macro in matmul heuristics struct,
  - replace c10::optional with std::optional,
  - fix typos,
drzejan2 authored Apr 19, 2023
1 parent 26ecf7e commit 6904a74
Showing 13 changed files with 103 additions and 94 deletions.
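
To make the new single-source-of-truth layout handling concrete, here is a minimal, self-contained sketch of the pattern (not nvFuser code; all names below are invented for illustration): the op stores an optional input layout inferred from its operands, and the extended constructor only validates a caller-provided layout against the inferred one, or fills it in when inference was not possible.

#include <cassert>
#include <iostream>
#include <optional>

enum class InputLayout { TT, TN, NT };

const char* toString(InputLayout layout) {
  switch (layout) {
    case InputLayout::TT: return "TT";
    case InputLayout::TN: return "TN";
    case InputLayout::NT: return "NT";
  }
  return "?";
}

struct SketchMmaOp {
  // Single source of truth: the layout inferred from operand shapes.
  std::optional<InputLayout> input_layout;

  // Mirrors the basic constructor: only the inferred value is stored.
  explicit SketchMmaOp(std::optional<InputLayout> inferred)
      : input_layout(inferred) {}

  // Mirrors the extended constructor: a caller-provided layout must agree
  // with the inferred attribute, and only fills it when inference gave nothing.
  SketchMmaOp(std::optional<InputLayout> inferred, InputLayout provided)
      : SketchMmaOp(inferred) {
    if (input_layout.has_value()) {
      assert(input_layout.value() == provided && "Input layout mismatch");
    } else {
      input_layout = provided;
    }
  }
};

int main() {
  SketchMmaOp mma(InputLayout::TN, InputLayout::TN);
  std::cout << "layout: " << toString(mma.input_layout.value()) << "\n";
  return 0;
}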
3 changes: 1 addition & 2 deletions benchmark/matmul.cpp
@@ -309,8 +309,7 @@ MatmulParams getMatmulParams(
gemm_tile.instruction_tile = GemmTile(16, 16, 16);

MatmulParams params;
params.mma_op = MmaOptions::MacroType::Ampere_16_16_16;
params.layout = layout;
params.mma_macro = MmaOptions::MacroType::Ampere_16_16_16;
params.tile_sizes = gemm_tile;
params.async_gmem_load_operands = true;
params.double_buffer_options.double_buffer_smem_write = true;
4 changes: 3 additions & 1 deletion csrc/codegen.cpp
@@ -935,7 +935,9 @@ class CudaKernelGenerator : private OptOutConstDispatch {
ss << toString(options.macro);

if (isVolta(options.macro)) {
ss << toString(options.operand_layout);
TORCH_INTERNAL_ASSERT(
mma->inputLayout().has_value(), "mma unknown input layout");
ss << toString(mma->inputLayout().value());
} else if (isTuring(options.macro) || isAmpere(options.macro)) {
// mma's in turing and ampere TN only, transpose is handled either
// via ldmatrix for fp16 or explicitly for other types.
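
The branch above means only Volta macros carry a layout suffix in the generated instruction name; Turing/Ampere mma is TN-only, so nothing is appended. A rough standalone sketch of that decision (the string pieces are placeholders, not the exact text nvFuser emits):

#include <iostream>
#include <optional>
#include <stdexcept>
#include <string>

enum class Arch { Volta, Turing, Ampere };
enum class Layout { TT, TN, NT };

std::string toString(Layout layout) {
  switch (layout) {
    case Layout::TT: return "TT";
    case Layout::TN: return "TN";
    case Layout::NT: return "NT";
  }
  return "?";
}

// Volta: append the layout taken from the op itself (must be known).
// Turing/Ampere: no suffix, the hardware macro is TN-only.
std::string macroName(Arch arch, std::optional<Layout> input_layout) {
  std::string name = "mma_macro";  // placeholder for toString(options.macro)
  if (arch == Arch::Volta) {
    if (!input_layout.has_value()) {
      throw std::runtime_error("mma unknown input layout");
    }
    name += toString(input_layout.value());
  }
  return name;
}

int main() {
  std::cout << macroName(Arch::Volta, Layout::TN) << "\n";    // mma_macroTN
  std::cout << macroName(Arch::Ampere, std::nullopt) << "\n"; // mma_macro
  return 0;
}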
12 changes: 6 additions & 6 deletions csrc/executor_utils.cpp
@@ -1198,7 +1198,7 @@ void fillCompileOptions(

// CUDA 11.1 allows going directly to SASS (sm_) instead of PTX (compute_)
// which gives better backwards compatibility to work on older driver,
// (since older driver doesn't necessrily recognize PTX emitted by new
// (since older driver doesn't necessarily recognize PTX emitted by new
// toolkit);
// Meanwhile, for forward compatibility (future device with
// `unsupported_arch==True`), since SASS are not necessarily compatible,
@@ -1399,11 +1399,11 @@ std::tuple<NvrtcFunction, std::string, std::vector<char>> getCompiledKernel(
compile_to_sass = false;
}

NvrtcCompileDriver nvrtc_comiple_driver;
NvrtcCompileDriver nvrtc_compile_driver;
CuModuleLoadDataDriver module_load_driver;

fillCompileOptions(
nvrtc_comiple_driver,
nvrtc_compile_driver,
module_load_driver,
compile_to_sass,
major,
@@ -1415,7 +1415,7 @@ std::tuple<NvrtcFunction, std::string, std::vector<char>> getCompiledKernel(

if (compile_to_sass) {
log << "\nCompile options: ";
for (const auto& opt : nvrtc_comiple_driver.options()) {
for (const auto& opt : nvrtc_compile_driver.options()) {
log << opt << " ";
}
if (opt_block_size.has_value()) {
@@ -1427,7 +1427,7 @@ std::tuple<NvrtcFunction, std::string, std::vector<char>> getCompiledKernel(
std::vector<char> object_code;
std::string lowered_kernel_name_str;
const auto compile_args =
toDelimitedString(nvrtc_comiple_driver.options(), " ");
toDelimitedString(nvrtc_compile_driver.options(), " ");

auto& kernel_db = KernelDb::get();
const auto use_kernel_db = kernel_db.enabled() && kernel_code.has_value();
@@ -1440,7 +1440,7 @@ std::tuple<NvrtcFunction, std::string, std::vector<char>> getCompiledKernel(
lowered_kernel_name_str,
object_code))) {
std::tie(object_code, lowered_kernel_name_str) = compileSource(
full_src_code, func_name, id, compile_to_sass, nvrtc_comiple_driver);
full_src_code, func_name, id, compile_to_sass, nvrtc_compile_driver);

if (use_kernel_db) {
auto result = kernel_db.write(
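
The SASS-versus-PTX comment near the top of this file's hunk boils down to which architecture flag is handed to NVRTC: sm_XX compiles straight to SASS, while compute_XX stops at PTX so a newer driver can JIT it. A simplified sketch of how such an option string is typically built (not the exact logic in fillCompileOptions):

#include <iostream>
#include <string>

// Build the NVRTC architecture option. "sm_XX" targets SASS directly;
// "compute_XX" targets PTX for forward compatibility via driver JIT.
std::string archOption(int major, int minor, bool compile_to_sass) {
  const std::string cc = std::to_string(major) + std::to_string(minor);
  return std::string("--gpu-architecture=") +
      (compile_to_sass ? "sm_" : "compute_") + cc;
}

int main() {
  std::cout << archOption(8, 0, true) << "\n";   // --gpu-architecture=sm_80
  std::cout << archOption(8, 0, false) << "\n";  // --gpu-architecture=compute_80
  return 0;
}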
8 changes: 4 additions & 4 deletions csrc/ir_internal_nodes.h
@@ -1035,17 +1035,16 @@ class TORCH_CUDA_CU_API MmaOp : public Expr {
// after additional cleaning ups.
struct OptionsInMma {
MmaOptions::MacroType macro = MmaOptions::MacroType::NoMMA;
MmaOptions::MmaInputLayout operand_layout = MmaOptions::MmaInputLayout::TT;
int accumulator_stride = 0;

bool operator==(const OptionsInMma& other) const {
return macro == other.macro && operand_layout == other.operand_layout &&
return macro == other.macro &&
accumulator_stride == other.accumulator_stride;
}
};

using AxesData = std::vector<int>;
using MmaInputLayoutOpt = c10::optional<MmaOptions::MmaInputLayout>;
using MmaInputLayoutOpt = std::optional<MmaOptions::MmaInputLayout>;
using Expr::Expr;

MmaOp(IrBuilderPasskey, Val* out, Val* in_a, Val* in_b, Val* init);
@@ -1056,7 +1055,8 @@ class TORCH_CUDA_CU_API MmaOp : public Expr {
Val* in_a,
Val* in_b,
Val* init,
OptionsInMma options);
const OptionsInMma& options,
const MmaInputLayoutOpt& input_layout);

NVFUSER_DECLARE_CLONE_AND_CREATE

65 changes: 41 additions & 24 deletions csrc/ir_nodes.cpp
@@ -1350,7 +1350,7 @@ struct MmaOpDetails {
// and output
AxesData batch_axes;
// A placeholder for mma input layout
c10::optional<MmaOptions::MmaInputLayout> input_layout = c10::nullopt;
std::optional<MmaOptions::MmaInputLayout> input_layout = std::nullopt;
};

// A helper structure with pieces of information about TensorView
@@ -1393,30 +1393,30 @@ MmaOptions::MmaInputLayout getInputLayout(
// A = [M, K, b]
// B = [b, K, N]
// C = [M, r, N]
if ((m_axes.back() < in_a.bcasts.back()) &&
(k_axes.back() < in_a.bcasts.back()) &&
(in_b.bcasts.back() < k_axes.back()) &&
(in_b.bcasts.back() < n_axes.back())) {
if ((m_axes.front() < in_a.bcasts.front()) &&
(k_axes.front() < in_a.bcasts.front()) &&
(in_b.bcasts.front() < k_axes.front()) &&
(in_b.bcasts.front() < n_axes.front())) {
return MmaOptions::MmaInputLayout::TT;
}
// TN layout (b - broadcast, r - reduction):
// A = [M, b, K]
// B = [b, N, K]
// C = [M, N, r]
if ((m_axes.back() < in_a.bcasts.back()) &&
(in_a.bcasts.back() < k_axes.back()) &&
(in_b.bcasts.back() < n_axes.back()) &&
(in_b.bcasts.back() < k_axes.back())) {
if ((m_axes.front() < in_a.bcasts.front()) &&
(in_a.bcasts.front() < k_axes.front()) &&
(in_b.bcasts.front() < n_axes.front()) &&
(in_b.bcasts.front() < k_axes.front())) {
return MmaOptions::MmaInputLayout::TN;
}
// NT layout (b - broadcast, r - reduction):
// A = [K, M, b]
// B = [K, b, N]
// C = [r, M, N]
if ((k_axes.back() < in_a.bcasts.back()) &&
(m_axes.back() < in_a.bcasts.back()) &&
(k_axes.back() < in_b.bcasts.back()) &&
(in_b.bcasts.back() < n_axes.back())) {
if ((k_axes.front() < in_a.bcasts.front()) &&
(m_axes.front() < in_a.bcasts.front()) &&
(k_axes.front() < in_b.bcasts.front()) &&
(in_b.bcasts.front() < n_axes.front())) {
return MmaOptions::MmaInputLayout::NT;
}
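
The three checks above classify the layout purely from the relative ordering of the M, N, K and broadcast axes, now comparing the outermost (front) position of each group. A tiny standalone illustration of the TT test with made-up axis positions (one shared three-axis problem domain, a simplification of the real MmaOpDetails bookkeeping):

#include <iostream>
#include <vector>

// TT layout: A = [M, K, b], B = [b, K, N], C = [M, r, N].
// In a shared [M, K, N] ordering: M -> 0, K -> 1, N -> 2,
// A's broadcast sits at position 2, B's broadcast at position 0.
int main() {
  std::vector<int> m_axes = {0};
  std::vector<int> n_axes = {2};
  std::vector<int> k_axes = {1};
  std::vector<int> a_bcasts = {2};  // broadcast axes of operand A
  std::vector<int> b_bcasts = {0};  // broadcast axes of operand B

  const bool is_tt = (m_axes.front() < a_bcasts.front()) &&
      (k_axes.front() < a_bcasts.front()) &&
      (b_bcasts.front() < k_axes.front()) &&
      (b_bcasts.front() < n_axes.front());

  std::cout << (is_tt ? "TT" : "not TT") << "\n";  // prints TT
  return 0;
}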

@@ -1578,7 +1578,6 @@ MmaOp::MmaOp(
Val* in_b,
Val* init)
: Expr(passkey) {
// Check output type
TORCH_INTERNAL_ASSERT(
out->getValType().value() == ValType::TensorView ||
out->getValType().value() == ValType::TensorIndex,
@@ -1594,14 +1593,6 @@ MmaOp::MmaOp(
in_b->getValType().value() == ValType::TensorIndex,
in_b->getValType().value());

MmaOpUtils::MmaOpDetails mma_details;
// Detailed consistency checks for use case with TensorViews as inputs/output
if (in_a->isA<TensorView>() && in_b->isA<TensorView>() &&
out->isA<TensorView>()) {
mma_details = MmaOpUtils::getMmaOpDetails(
out->as<TensorView>(), in_a->as<TensorView>(), in_b->as<TensorView>());
}

addOutput(out);
addInput(in_a);
addInput(in_b);
@@ -1622,6 +1613,15 @@ MmaOp::MmaOp(
addAttribute(
IrBuilder::create<Attribute<MmaInputLayoutOpt>>(passkey.ir_container_));

MmaOpUtils::MmaOpDetails mma_details;
// Detailed consistency checks for use case with TensorViews as
// inputs/output
if (in_a->isA<TensorView>() && in_b->isA<TensorView>() &&
out->isA<TensorView>()) {
mma_details = MmaOpUtils::getMmaOpDetails(
out->as<TensorView>(), in_a->as<TensorView>(), in_b->as<TensorView>());
}

attribute(ATTR_POS_M_AXES)->as<Attribute<AxesData>>()->value =
std::move(mma_details.m_axes);
attribute(ATTR_POS_N_AXES)->as<Attribute<AxesData>>()->value =
@@ -1640,9 +1640,27 @@ MmaOp::MmaOp(
Val* in_a,
Val* in_b,
Val* init,
OptionsInMma options)
const OptionsInMma& options,
const MmaInputLayoutOpt& input_layout)
: MmaOp(passkey, out, in_a, in_b, init) {
attribute(ATTR_POS_OPTS)->as<Attribute<OptionsInMma>>()->value = options;

const auto input_layout_ = attribute(ATTR_POS_INPUT_LAYOUT)
->as<Attribute<MmaInputLayoutOpt>>()
->value;
if (input_layout_.has_value()) {
TORCH_INTERNAL_ASSERT(
input_layout_.value() == input_layout.value(),
"Input layout mismatch, infered attribute (",
nvfuser::toString(input_layout_.value()),
"), provided attribute (",
nvfuser::toString(input_layout.value()),
")");
} else {
attribute(ATTR_POS_INPUT_LAYOUT)
->as<Attribute<MmaInputLayoutOpt>>()
->value = input_layout;
}
}

std::string MmaOp::toString(int indent_size) const {
@@ -1667,7 +1685,6 @@ void MmaOp::configureOptions(MmaOptions options) {
options.accumulator_stride > 0, "Un-configured accumulator stride.");
opt.accumulator_stride = options.accumulator_stride;
opt.macro = options.macro;
opt.operand_layout = options.operand_layout;
}

NVFUSER_DEFINE_CLONE_AND_CREATE(MmaOp)
4 changes: 2 additions & 2 deletions csrc/lower_index.cpp
@@ -1258,8 +1258,8 @@ void IndexLowering::handle(const MmaOp* mma) {
const auto a = lowerSrcIndex(mma->inA(), mma->out());
const auto b = lowerSrcIndex(mma->inB(), mma->out());
const auto out = lowerDstIndex(mma->out());
auto mma_indexed =
IrBuilder::create<MmaOp>(out, a, b, mma->init(), mma->options());
auto mma_indexed = IrBuilder::create<MmaOp>(
out, a, b, mma->init(), mma->options(), mma->inputLayout());
pushBack(mma_indexed);
GpuLower::current()->propagateExprInfo(mma, back());
}
3 changes: 2 additions & 1 deletion csrc/lower_utils.cpp
@@ -556,7 +556,8 @@ class ReplaceExprInput : private kir::ExprMutator {
replaced_inputs->at(node->inA()),
replaced_inputs->at(node->inB()),
node->init(),
node->options());
node->options(),
node->inputLayout());
registerReplaceWithPredicate(node, replacement);
}
}
16 changes: 13 additions & 3 deletions csrc/scheduler/matmul.cpp
@@ -87,7 +87,7 @@ void prologSwizzle(TensorView* shared_mem_tv, const MatmulParams& params) {
const auto tile_size_x = shared_mem_tv->axis(-2)->extent()->evaluateInt();
const auto tile_size_y = shared_mem_tv->axis(-1)->extent()->evaluateInt();

if (isTuring(params.mma_op) || isAmpere(params.mma_op)) {
if (isTuring(params.mma_macro) || isAmpere(params.mma_macro)) {
// TODO: right now, we are assuming ldmatrix access, which only supports
// sizeof(T) == 16bit (i.e. half/bfloat16) load according to official doc:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-load-instruction-ldmatrix
@@ -392,7 +392,7 @@ void prologSwizzle(TensorView* shared_mem_tv, const MatmulParams& params) {
shared_mem_tv->merge(-5);
shared_mem_tv->merge(-3);
shared_mem_tv->merge(-2);
} else if (isVolta(params.mma_op)) {
} else if (isVolta(params.mma_macro)) {
// TODO: Volta is slightly more complex, and a fixed recipe would
// not scale. In a follow up this would be inferred from the mma
// macro layout themselves as we already have them registered in
@@ -440,6 +440,7 @@ void scheduleProlog(TensorView* shared_mem_tv, const MatmulParams& params) {
void scheduleMatmul(Fusion* fusion, const MatmulParams& params) {
const auto& inputs = fusion->inputs();
const auto& outputs = fusion->outputs();
const auto mma_ops = ir_utils::getMmaOps(fusion);

TORCH_INTERNAL_ASSERT(
inputs.size() == 2,
@@ -458,13 +459,22 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) {
outputs[0]->isA<TensorView>(),
"fusion's output is not an instance of TensorView class");

TORCH_INTERNAL_ASSERT(
mma_ops.size() == 1,
"scheduleMatmul supports fusion with single mma op in definition, got ",
mma_ops.size());
TORCH_INTERNAL_ASSERT(
mma_ops.front()->inputLayout().has_value(),
"fusion mma op has undefined input layout");

TensorView* a = inputs[0]->as<TensorView>();
TensorView* b = inputs[1]->as<TensorView>();
TensorView* c = outputs[0]->as<TensorView>();

// Collect mma swizzle info
const auto layout = mma_ops.front()->inputLayout().value();
auto mma_builder =
MmaBuilder(params.mma_op, params.tile_sizes).layout(params.layout);
MmaBuilder(params.mma_macro, params.tile_sizes).layout(layout);
const auto& gemm_tile = params.tile_sizes;

// Including current tensor naming convention for reference,
21 changes: 8 additions & 13 deletions csrc/scheduler/matmul_heuristic.h
@@ -67,11 +67,7 @@ class MatmulParams : public HeuristicParams {
MatMulTileOptions tile_sizes = {};

//! Specify the type of MMA op to be used in generated kernel.
MmaOptions::MacroType mma_op = MmaOptions::MacroType::NoMMA;

//! Specify the input layout of input tensors.
MmaOptions::MmaInputLayout layout =
static_cast<MmaOptions::MmaInputLayout>(-1);
MmaOptions::MacroType mma_macro = MmaOptions::MacroType::NoMMA;

//! Specify CTA rasterization order.
TileRasterizationOrder cta_order = TileRasterizationOrder::RowMajor;
@@ -98,8 +94,7 @@ class MatmulParams : public HeuristicParams {
std::stringstream ss;
ss << "\n===== Matmul Parameters ========\n"
<< (tag.empty() ? "" : "Tag: ") << tag << "\n"
<< "MMA op: " << nvfuser::toString(mma_op, true) << "\n"
<< "Layout: " << nvfuser::toString(layout) << "\n"
<< "MMA macro: " << nvfuser::toString(mma_macro, true) << "\n"
<< double_buffer_options.toString() << "\n"
<< nvfuser::toString(tile_sizes) << "\n"
<< "Rotate ldmatrix out of main loop: "
@@ -127,11 +122,11 @@ class MatmulParams : public HeuristicParams {
(static_cast<size_t>(async_gmem_load_operands));

// combined hash
attr_hash = std::hash<size_t>{}(attr_hash) ^ (nvfuser::hash(mma_op) << 1) ^
(nvfuser::hash(layout) << 2) ^ (double_buffer_options.hash() << 3) ^
(nvfuser::hash(tile_sizes) << 4) ^
(std::hash<size_t>{}(static_cast<size_t>(cta_order)) << 5) ^
(std::hash<size_t>{}(grid_swizzle_factor) << 6);
attr_hash = std::hash<size_t>{}(attr_hash) ^
(nvfuser::hash(mma_macro) << 1) ^ (double_buffer_options.hash() << 2) ^
(nvfuser::hash(tile_sizes) << 3) ^
(std::hash<size_t>{}(static_cast<size_t>(cta_order)) << 4) ^
(std::hash<size_t>{}(grid_swizzle_factor) << 5);
return attr_hash;
}

@@ -142,7 +137,7 @@ class MatmulParams : public HeuristicParams {
return false;
}

return other_casted->layout == layout && other_casted->mma_op == mma_op &&
return other_casted->mma_macro == mma_macro &&
other_casted->async_gmem_load_operands == async_gmem_load_operands &&
other_casted->rotate_ldmatrix_out_of_main_loop ==
rotate_ldmatrix_out_of_main_loop &&