FusionExecutorCache determines index type by taking all tensors into consideration (#108)

Fixes #79 
Stacked on top of #159 

Currently we determine the index type by looking only at the sizes and
strides of fusion inputs. That's not correct when a fusion has an output
that's larger than its inputs, for example a matmul of `{M, K} * {K, N}`
where both inputs are small enough to use 32-bit indexing but the
output, `{M, N}`, is not.
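
To make this concrete with made-up sizes (not taken from the PR), a small standalone check:

```cpp
// Hypothetical sizes: both matmul inputs fit within 32-bit indexing,
// but the output does not.
#include <cstdint>
#include <iostream>
#include <limits>

int main() {
  constexpr int64_t M = 1LL << 17, N = 1LL << 17, K = 1LL << 7;
  constexpr int64_t a_elems = M * K; // 2^24 elements: fits in int32
  constexpr int64_t b_elems = K * N; // 2^24 elements: fits in int32
  constexpr int64_t c_elems = M * N; // 2^34 elements: exceeds int32
  constexpr int64_t max32 = std::numeric_limits<int32_t>::max();
  std::cout << (a_elems <= max32) << " " << (b_elems <= max32) << " "
            << (c_elems <= max32) << "\n"; // prints "1 1 0"
  return 0;
}
```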

This turned out to require a lot of changes in the code around the
fusion executor and kernel cache, as we had been assuming that a
`KernelArgumentHolder` holding only fusion inputs is enough to determine
the index type. That isn't a valid assumption, so one of the major
changes in this PR is that the index type is no longer dictated by
`KernelArgumentHolder`.

Another major change is that we can no longer create a valid
`TensorArgCodegen` for a `KernelArgumentHolder` until the fusion index
type is determined, since `TensorArgCodegen` is templated on the index
type. However, `KernelArgumentHolder` is also used before the index type
is determined, so we need some form of placeholder for tensor arguments
until then. Once the index type is resolved, each placeholder object is
replaced with a `TensorArg` of that type. Getting a pointer to a
placeholder, e.g., through `KernelArgumentHolder::at`, is invalid
because the placeholder is not a valid `TensorArg`.
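
The shape of that change can be sketched as follows. This is an illustrative standalone sketch, not the actual nvFuser class definitions; every name here is invented for the example:

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <vector>

enum class IndexTypeSketch { Int32, Int64 };

struct ArgAbstractSketch {
  virtual ~ArgAbstractSketch() = default;
};

// Until the fusion index type is known, a tensor argument is kept as a
// placeholder that only records sizes/strides/data pointer.
struct TensorArgPlaceholderSketch : ArgAbstractSketch {
  std::vector<int64_t> sizes, strides;
  void* data = nullptr;
};

// The codegen-facing argument is templated on the index type, so it can
// only be constructed once that type is resolved.
template <typename index_t>
struct TensorArgSketch : ArgAbstractSketch {};

struct ArgHolderSketch {
  std::vector<std::unique_ptr<ArgAbstractSketch>> args;

  // Resolving the index type replaces every placeholder in place.
  void resolveIndexType(IndexTypeSketch t) {
    for (auto& a : args) {
      if (dynamic_cast<TensorArgPlaceholderSketch*>(a.get()) != nullptr) {
        if (t == IndexTypeSketch::Int32) {
          a = std::make_unique<TensorArgSketch<int32_t>>();
        } else {
          a = std::make_unique<TensorArgSketch<int64_t>>();
        }
      }
    }
  }

  // Mirrors the rule that taking the pointer of an unresolved placeholder
  // (KernelArgumentHolder::at in the PR) is invalid.
  const ArgAbstractSketch* at(size_t i) const {
    assert(
        dynamic_cast<const TensorArgPlaceholderSketch*>(args[i].get()) ==
        nullptr);
    return args[i].get();
  }
};
```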

Resolution of the index type is currently based on the complete,
unsegmented original fusion. Fusion inputs and outputs may be
non-contiguous, so if strides are available for them, they are also
considered. If not, we conservatively fall back to 64-bit indexing.
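
A rough sketch of what a per-tensor size/stride check looks like; this is illustrative only, not the code added by this PR:

```cpp
#include <cstdint>
#include <limits>
#include <vector>

// Returns true if the tensor described by sizes/strides can be addressed
// with 32-bit offsets: the largest linear offset must fit in int32.
bool fitsInInt32(
    const std::vector<int64_t>& sizes,
    const std::vector<int64_t>& strides) {
  int64_t max_offset = 0;
  for (size_t i = 0; i < sizes.size(); ++i) {
    if (sizes[i] == 0) {
      return true; // empty tensor, nothing to index
    }
    max_offset += (sizes[i] - 1) * strides[i];
  }
  return max_offset <= std::numeric_limits<int32_t>::max();
}
```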

For intermediate tensors (neither inputs nor outputs), rfactor root
domains are used to find the smallest index type required. Since they
are intermediates, they are assumed to be contiguous. This analysis is
far from ideal, though. We should only need to check global-memory
tensors, as shared and local tensors should be smaller than what 32-bit
indexing can represent. However, when the index type is determined, the
fusion is not yet scheduled or segmented, so we don't know exactly which
tensors will live in global memory. Furthermore, none of the tensors are
inlined yet, so we have to look at rfactor root domains, which likely
overestimates the size of many shared and local tensors. Consider, for
example, a fusion with a matmul-like pattern, where two 2D tensors are
broadcast and fed into a binary expression, producing a 3D tensor that
is then reduced back to a 2D tensor and returned as a fusion output (see
the sketch below). As long as the 3D intermediate tensor is inlined into
the 2D output, the whole 3D tensor is never instantiated, so it should
not affect the fusion index type, but the current analysis conservatively
assumes it may be instantiated and decides to use 64-bit indexing.
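
A rough sketch of such a fusion using the nvFuser C++ API (fragment; headers and setup omitted, and exact helper names and signatures may differ slightly):

```cpp
Fusion fusion;
FusionGuard fg(&fusion);

TensorView* tv0 = makeSymbolicTensor(2); // [M, K]
TensorView* tv1 = makeSymbolicTensor(2); // [K, N]
fusion.addInput(tv0);
fusion.addInput(tv1);

TensorView* tv2 = broadcast(tv0, {false, false, true}); // [M, K, b]
TensorView* tv3 = broadcast(tv1, {true, false, false}); // [b, K, N]
TensorView* tv4 = mul(tv2, tv3); // [M, K, N] intermediate
TensorView* tv5 = sum(tv4, {1}); // [M, N] output
fusion.addOutput(tv5);

// If tv4 is inlined into tv5, the full [M, K, N] buffer never exists, yet
// the pre-scheduling analysis still sizes it from its rfactor root domain.
```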

As a temporary workaround, `FusionExecutorCache::runFusionWithInputs`
has an optional `forced_index_type` parameter. When it is given, no
analysis of actual sizes is done, and the given type is used without
validation. No guarantee of correctness is made when this option is
used.
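
Hypothetical usage of the workaround (the surrounding variables are illustrative):

```cpp
// Force 32-bit indexing; the size analysis is skipped entirely, so the
// caller is responsible for correctness.
FusionExecutorCache executor_cache(std::move(fusion_ptr));
auto outputs = executor_cache.runFusionWithInputs(
    aten_inputs, /*forced_index_type=*/PrimDataType::Int32);
```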

To completely fix the conservative analysis, we would need to do the
index-type analysis after segmentation and scheduling, but there is a
circular dependency between the index type and scheduling, since the
schedulers use the index type as well. I haven't looked into whether
those uses actually require the final type, so I have no idea how easily
this circular dependency could be resolved.

Maybe we could do something like this: segment and schedule the original
fusion optimistically with 32-bit indexing, ignoring intermediates, and
then, after scheduling is done, check the intermediates; if any
scheduled tensor actually requires 64-bit indexing, scrap the scheduled
fusion and restart with 64-bit indexing.
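
In pseudocode, the idea would look roughly like this; none of these helper functions exist, they only stand in for the steps described above:

```cpp
// Optimistically assume 32-bit indexing, ignoring intermediates.
PrimDataType index_type = PrimDataType::Int32;
auto scheduled = segmentAndSchedule(fusion, args, index_type); // hypothetical

// After scheduling, the memory type of every tensor is known, so the
// intermediates can be checked precisely.
if (anyGlobalTensorRequiresInt64(scheduled)) { // hypothetical
  index_type = PrimDataType::Int;
  scheduled = segmentAndSchedule(fusion, args, index_type); // redo in 64-bit
}
```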

For now, I suppose the above workaround is sufficient. I'm still running
some final tests, but nothing seems to be failing. The change in logic
doesn't seem to affect any of the existing benchmarks, at least on
A100-80G.

Note that this index type analysis is only done when a fusion is
executed through `FusionExecutorCache`. When a fusion is manually
scheduled and executed with `FusionExecutor::compileFusion` and
`FusionExecutor::runFusion`, the default index type is 64-bit, but it
can be overridden through `CompileParams` (see the example below).
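
Roughly, the manual path looks like this (fragment; variable setup omitted, and the exact `compileFusion` overload may differ):

```cpp
FusionExecutor fe;
CompileParams cparams;
cparams.index_type = PrimDataType::Int32; // override the 64-bit default
fe.compileFusion(&fusion, aten_inputs, LaunchParams(), cparams);
auto outputs = fe.runFusion(aten_inputs);
```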

---------

Co-authored-by: Ryan Spring <rdspring1@gmail.com>
Co-authored-by: Naoya Maruyama <nmaruyama@nvidia.com>
3 people authored Apr 19, 2023
1 parent 6904a74 commit 255af2b
Showing 30 changed files with 648 additions and 377 deletions.
5 changes: 2 additions & 3 deletions benchmark/matmul.cpp
@@ -224,12 +224,11 @@ static void SingleMatmulBase(
KernelArgumentHolder args = KernelArgumentHolder::createKernelArgumentHolder(
{inputs.first, inputs.second});

- // Always use 32b indexing mode for now.
- TORCH_INTERNAL_ASSERT(args.getIndexMode() == KernelIndexMode::INT32);
-
// Disable magic zero
CompileParams cparams;
cparams.enable_magic_zero = false;
+ // Always use 32b indexing mode for now.
+ cparams.index_type = PrimDataType::Int32;

// Compile kernel
auto launch_constraints = LaunchParams();
4 changes: 2 additions & 2 deletions benchmark/softmax.cpp
@@ -92,7 +92,7 @@ static void Softmax_WarpReduceReference(benchmark::State& benchmark_state) {
std::vector<c10::IValue> aten_inputs({aten_input});

// Schedule through magic scheduler:
- SchedulerRuntimeInfo runtime_info(fusion, aten_inputs, true);
+ SchedulerRuntimeInfo runtime_info(fusion, aten_inputs);
TORCH_INTERNAL_ASSERT(SchedulerEntry::canSchedule(
ScheduleHeuristic::Persistent, fusion, runtime_info));
auto scheduler = SchedulerEntry::makeEntry(
@@ -137,7 +137,7 @@ static void Softmax_WarpReduce(benchmark::State& benchmark_state) {
std::vector<c10::IValue> aten_inputs({aten_input});

// Schedule through magic scheduler:
- SchedulerRuntimeInfo runtime_info(fusion, aten_inputs, true);
+ SchedulerRuntimeInfo runtime_info(fusion, aten_inputs);
TORCH_INTERNAL_ASSERT(SchedulerEntry::canSchedule(
ScheduleHeuristic::Persistent, fusion, runtime_info));
auto scheduler = SchedulerEntry::makeEntry(
9 changes: 5 additions & 4 deletions csrc/evaluator_common.cpp
@@ -158,7 +158,7 @@ void PrecomputedValues::bindInputs(const KernelArgumentHolder& args) {
TORCH_INTERNAL_ASSERT(
args.size() == inputs.size(), "kernel inputs size does not match args");

- for (const auto i : c10::irange(inputs.size())) {
+ for (const auto i : c10::irange((int64_t)inputs.size())) {
const auto input = inputs[i];
const ArgAbstract* arg = args[i];
if (auto tensor_input = dynamic_cast<TensorView*>(input)) {
@@ -216,7 +216,7 @@ void PrecomputedValues::initializeValueList(
}

c10::optional<EvaluatorValue> PrecomputedValues::getMaybeValueFor(
- const Val* val) {
+ const Val* val) const {
auto index = val->evaluatorIndex();
if (index < 0) {
return c10::nullopt;
@@ -308,11 +308,12 @@ void PrecomputedValues::bindTensorMetaData(
const auto root_domain =
TensorDomain::noReductions(tv->getMaybeRFactorDomain());
TORCH_INTERNAL_ASSERT(
- tensor_arg_abstract->getRank() == static_cast<int>(root_domain.size()),
+ tensor_arg_abstract->getRank() ==
+ static_cast<int64_t>(root_domain.size()),
"Something went wrong configuring launch. Inputs do not match.");

for (const auto dim : c10::irange(root_domain.size())) {
- auto value = tensor_arg_abstract->getSize((int)dim);
+ auto value = tensor_arg_abstract->getSize(static_cast<int64_t>(dim));
if (root_domain[dim]->hasExpandedExtent()) {
auto extent = root_domain[dim]->extent();
auto expanded_extent = root_domain[dim]->expandedExtent();
2 changes: 1 addition & 1 deletion csrc/evaluator_common.h
@@ -150,7 +150,7 @@ class PrecomputedValues {

//! Returns value for the given IR node if it's stored
//! in the workspace and has been evaluated.
- c10::optional<EvaluatorValue> getMaybeValueFor(const Val* val);
+ c10::optional<EvaluatorValue> getMaybeValueFor(const Val* val) const;

//! Debugging helper, prints all the currently known values
void print() const;
46 changes: 17 additions & 29 deletions csrc/executor.cpp
@@ -178,7 +178,6 @@ void FusionExecutor::debugCompileFusionFromStr(

if (!kernel_summary.static_smem_allocations.empty()) {
ExpressionEvaluator static_evaluator;
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
const auto static_smem_size = computeSharedMemory(
static_evaluator, kernel_summary.static_smem_allocations);
TORCH_INTERNAL_ASSERT(
@@ -250,17 +249,20 @@ void FusionExecutor::compileFusion(
// Set the index type of compile params if not already set. If set,
// make sure the compile param type is valid with the given kernel
// arguments.
- auto arg_index_type = args.getIndexType();
+ auto arg_index_type = args.getSmallestIndexTypeOfArguments();
if (compile_params.index_type.has_value()) {
// If the int32 compilation is requested, but the arguments demand
// int64, that's an error
TORCH_INTERNAL_ASSERT(
- !(compile_params.index_type.value() == DataType::Int32 &&
- arg_index_type == DataType::Int),
+ !(compile_params.index_type.value() == PrimDataType::Int32 &&
+ arg_index_type == PrimDataType::Int),
"Compilation with int32 is requested but int64 is required for the arguments");
- } else {
- // If the given compile option doesn't specify the index type, use
- // the type determined by the arguments
+ } else if (arg_index_type == PrimDataType::Int) {
+ // If the given compile option doesn't specify the index type, and
+ // the arguments require 64-bit indexing, we need to use 64-bit
+ // indexing. Note that if the arg type is 32-bit, it doesn't mean
+ // it's safe to use 32-bit for the whole kernel, so unless it's
+ // specified through CompileParams, we do not use 32-bit indexing.
compile_params.index_type = arg_index_type;
}
@@ -1036,7 +1038,7 @@ KernelArgumentHolder FusionExecutor::evaluateOutputSizes(
FUSER_PERF_SCOPE("FusionExecutor::AllocOutputs");
const auto kernel = lowered_->kernel();

- KernelArgumentHolder ret(args.getIndexMode());
+ KernelArgumentHolder ret;
ret.setDeviceIndex(args.getDeviceIndex());

CompileOptions meta_options = options_;
@@ -1086,7 +1088,7 @@ KernelArgumentHolder FusionExecutor::inferOutputSizes(
FUSER_PERF_SCOPE("FusionExecutor::RunFusion");

ExecutorEntry* executor_entry = nullptr;
- c10::optional<size_t> opt_code = args.getCacheId();
+ auto opt_code = args.getCacheId();
if (opt_code.has_value()) {
executor_entry = &executor_entry_lookup_[*opt_code];
}
@@ -1151,24 +1153,9 @@ KernelArgumentHolder FusionExecutor::inferOutputSizes(
namespace {

// Make sure the index type of Kernel is valid
- // TODO: Check the size of all tensors, not just inputs.
void validateIndexType(
kir::Kernel* kernel,
- KernelArgumentHolder& args,
const CompileParams& compile_params) {
- // Currently, once a Fusion is lowered to a Kernel, the index type
- // has to be resolved completely. This means that
- // args.getIndexType() must be equal to the index type of the
- // compiled kernel.
- TORCH_INTERNAL_ASSERT(
- kernel->indexType() == args.getIndexType(),
- "Invalid pair of kernel index type and argument index type. Kernel type: ",
- kernel->indexType(),
- ". Argument index type: ",
- args.getIndexType());
-
- // Similarly, if the type of the index type in the given compile
- // parameters doesn't match, that's also an error.
TORCH_INTERNAL_ASSERT(
!compile_params.index_type.has_value() ||
kernel->indexType() == compile_params.index_type.value(),
@@ -1383,7 +1370,7 @@ std::vector<at::Tensor> FusionExecutor::runFusion(
!args.getCacheId().has_value() || outputs.empty(),
"short cut input cache is not compatible with pre-allocated output");

- validateIndexType(kernel(), args, compile_params);
+ validateIndexType(kernel(), compile_params);

const auto num_inputs = args.size();

@@ -1505,6 +1492,7 @@ std::vector<at::Tensor> FusionExecutor::runFusion(
CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
launch_params_.smem()));
}
+ auto arg_buffer = args.getBuffer(kernel()->indexType());
if (!kernel()->summary().has_cooperative_grid_reduction) {
FUSER_PERF_SCOPE("ExecutorRunFusion::cuLaunchKernel");
CUDA_SAFE_CALL(cuLaunchKernel(
@@ -1517,7 +1505,7 @@
launch_params_.bdimz(),
launch_params_.smem(),
stream,
- args.getBuffer(),
+ arg_buffer,
nullptr));
} else {
FUSER_PERF_SCOPE("ExecutorRunFusion::cuLaunchCooperativeKernel");
@@ -1531,7 +1519,7 @@
launch_params_.bdimz(),
launch_params_.smem(),
stream,
- args.getBuffer()));
+ arg_buffer));
}
}

@@ -1613,7 +1601,7 @@ float FusionExecutor::runRtc(
CUDA_RT_SAFE_CALL(cudaEventCreate(&start_event));
CUDA_RT_SAFE_CALL(cudaEventCreate(&finish_event));

- KernelArgumentHolder kernel_arguments(index_type);
+ KernelArgumentHolder kernel_arguments;
kernel_arguments.push(args);

CUDA_RT_SAFE_CALL(cudaEventRecord(start_event, stream));
@@ -1628,7 +1616,7 @@
launch_params.bdimz(),
launch_params.smem(),
stream,
- kernel_arguments.getBuffer(),
+ kernel_arguments.getBuffer(index_type),
nullptr));

CUDA_RT_SAFE_CALL(cudaEventRecord(finish_event, stream));
6 changes: 4 additions & 2 deletions csrc/executor.h
@@ -74,6 +74,9 @@ class TORCH_CUDA_CU_API FusionExecutor : public NonCopyable {
const KernelArgumentHolder& args,
const LaunchParams& launch_constraints);

+ //! To compile a fusion with the 32-bit index type, CompileParams
+ //! must be passed in. There used to be an index type associated
+ //! with KernelArgumentHolder, but it is no longer the case.
void compileFusion(
Fusion* fusion,
const KernelArgumentHolder& args,
@@ -106,8 +109,7 @@ class TORCH_CUDA_CU_API FusionExecutor : public NonCopyable {
CompileParams compile_params = CompileParams(),
const c10::optional<size_t>& opt_code = c10::nullopt) {
KernelArgumentHolder args =
- KernelArgumentHolder::createKernelArgumentHolder(
- inputs, indexTypeToMode(kernel()->indexType()));
+ KernelArgumentHolder::createKernelArgumentHolder(inputs);
if (opt_code.has_value()) {
args.setCacheId(*opt_code);
}