From 255af2bf02a25e8b48eade7525cd1df3ca76c284 Mon Sep 17 00:00:00 2001
From: Michel Migdal <120487391+mmigdal-nv@users.noreply.github.com>
Date: Thu, 20 Apr 2023 00:41:55 +0200
Subject: [PATCH] FusionExecutorCache determines index type by taking all tensors into consideration (#108)

Fixes #79

Stacked on top of #159

Currently we determine the index type by looking only at the sizes and strides of the fusion inputs. That is not correct when a fusion has an output larger than its inputs, for example a matmul of `{M, K} * {K, N}` where both inputs are small enough to use 32-bit indexing but the output, `{M, N}`, is not (a concrete sketch of this size arithmetic is included below).

This turned out to require a lot of changes in the code around the fusion executor and kernel cache, since we had been assuming that a `KernelArgumentHolder` holding only the fusion inputs is enough to determine the index type. That assumption is not valid, so one of the major changes in this PR is that the index type is no longer dictated by `KernelArgumentHolder`.

Another major change is that we can no longer create a valid `TensorArgCodegen` for a `KernelArgumentHolder` until the fusion index type is determined, since `TensorArgCodegen` is templated on the index type. However, `KernelArgumentHolder` is also used before the index type is determined, so we need some form of placeholder for tensor arguments until then. Once the index type is resolved, each placeholder is replaced with a `TensorArg` of that type. Getting the pointer of a placeholder, i.e., `KernelArgumentHolder::at`, is invalid because a placeholder is not a valid `TensorArg`.

Resolution of the index type is currently based on the complete, unsegmented original fusion. Fusion inputs and outputs may be non-contiguous, so if we have strides for them, those are also considered; if not, we conservatively fall back to 64-bit indexing. For intermediate tensors that are neither inputs nor outputs, the rfactor root domains are used to find the smallest required index type. Since they are intermediates, they are assumed to be contiguous.

This analysis is not ideal at all, though. We should only need to check global memory tensors, as shared and local tensors should be smaller than what 32-bit indexing can represent. However, when the index type is determined, the fusion is neither scheduled nor segmented, so we don't know exactly which tensors will be global memory tensors. Furthermore, none of the tensors are inlined yet, so we need to look at rfactor root domains, which likely overestimates the sizes of many shared and local tensors. For example, consider a fusion with a matmul-like pattern, where two 2D tensors are broadcast and fed into a binary expression, producing a 3D tensor, which is then reduced to a 2D tensor and returned as a fusion output. As long as the 3D intermediate tensor is inlined into the 2D output, the whole 3D tensor is never instantiated, so it should not affect the fusion index type, but the current analysis conservatively assumes that it may indeed be instantiated and decides to use 64-bit indexing.

As a temporary workaround, `FusionExecutorCache::runFusionWithInputs` has an optional `forced_index_type` parameter. When it is given, no analysis of actual sizes is done, and the given type is simply used without validation. No guarantee of correctness is made when this option is used (a usage sketch is included below).

To completely fix the conservative analysis, we would need to do this index-type analysis after segmentation and scheduling, but a problem is that there is a circular dependency between the index type and scheduling, since the index type is used by the schedulers as well.
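To make the input-vs-output size issue concrete, here is a minimal standalone sketch (not nvFuser code, with hypothetical shapes) of the size-and-stride arithmetic involved. It mirrors the accumulation done by the `KernelIndexTypeCompute` helper added to `csrc/utils.h` in this PR: the largest reachable linear index is the sum of `(size - 1) * stride` over non-broadcast dimensions, compared against half of the `int32` maximum to stay conservative.

```cpp
#include <cstdint>
#include <iostream>
#include <limits>
#include <vector>

// One extra bit beyond the sign bit is reserved, matching the helper's
// conservative threshold.
constexpr int64_t kMostPositiveInt32Index = std::numeric_limits<int>::max() / 2;

// Returns true if the largest linear index of a tensor with the given sizes
// and (non-negative) strides fits within the conservative 32-bit limit.
bool fitsInInt32(const std::vector<int64_t>& sizes,
                 const std::vector<int64_t>& strides) {
  int64_t most_positive_index = 0;
  for (size_t i = 0; i < sizes.size(); ++i) {
    if (sizes[i] > 1) {
      most_positive_index += (sizes[i] - 1) * strides[i];
    }
  }
  return most_positive_index <= kMostPositiveInt32Index;
}

int main() {
  // Hypothetical matmul-like sizes: {M, K} * {K, N} -> {M, N}.
  const int64_t M = 65536, K = 8192, N = 65536;

  // Contiguous 2D strides are {inner_size, 1}.
  std::cout << std::boolalpha;
  std::cout << "A {M, K}: " << fitsInInt32({M, K}, {K, 1}) << "\n"; // true
  std::cout << "B {K, N}: " << fitsInInt32({K, N}, {N, 1}) << "\n"; // true
  std::cout << "C {M, N}: " << fitsInInt32({M, N}, {N, 1}) << "\n"; // false
  return 0;
}
```

With these hypothetical sizes, both matmul inputs stay within the 32-bit limit while the `{M, N}` output does not, so the whole fusion must use 64-bit indexing even though an inputs-only analysis would pick 32-bit.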
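And here is a rough usage sketch of the `forced_index_type` workaround. The fusion-building calls (`TensorViewBuilder`, `add`) and include paths are assumed from the nvFuser tree and may differ slightly; the only point being illustrated is the second argument to `runFusionWithInputs`, which bypasses the size analysis entirely, so correctness is the caller's responsibility.

```cpp
// Hypothetical sketch: force 32-bit indexing through FusionExecutorCache.
// Include paths and builder helpers are assumed, not taken from this diff.
#include <fusion.h>
#include <kernel_cache.h>
#include <ops/arith.h>

using namespace nvfuser;

std::vector<at::Tensor> runWithForcedInt32(
    const at::Tensor& a,
    const at::Tensor& b) {
  auto fusion = std::make_unique<Fusion>();
  FusionGuard fg(fusion.get());

  // A trivial pointwise fusion: out = a + b.
  TensorView* tv_a = TensorViewBuilder().ndims(2).dtype(DataType::Float).build();
  TensorView* tv_b = TensorViewBuilder().ndims(2).dtype(DataType::Float).build();
  fusion->addInput(tv_a);
  fusion->addInput(tv_b);
  fusion->addOutput(add(tv_a, tv_b));

  FusionExecutorCache executor_cache(std::move(fusion));

  // Skip the conservative index-type analysis and force 32-bit indexing.
  // No validation is performed; the caller must guarantee all tensors fit.
  std::vector<c10::IValue> inputs = {a, b};
  return executor_cache.runFusionWithInputs(inputs, PrimDataType::Int32);
}
```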
I haven't looked into whether the schedulers actually require the final index type, and I don't know how easily this circular dependency could be resolved. Maybe we could do something like this: segment and schedule the original fusion optimistically with 32-bit indexing, without considering intermediates, and then, after scheduling is done, check the intermediates; if any scheduled tensor actually requires 64-bit indexing, scrap the scheduled fusion and restart with 64-bit indexing. For now, I suppose the above workaround is sufficient. I'm still doing some final tests, but nothing seems to be failing. The logic change doesn't seem to affect any of the existing benchmarks, at least on A100-80G. Note that this index type analysis is only done when a fusion is executed through `FusionExecutorCache`. When a fusion is manually scheduled and executed with `FusionExecutor::compileFusion` and `FusionExecutor::runFusion`, the default type is 64-bit, but it can be overridden through `CompileParams`.
--------- Co-authored-by: Ryan Spring Co-authored-by: Naoya Maruyama --- benchmark/matmul.cpp | 5 +- benchmark/softmax.cpp | 4 +- csrc/evaluator_common.cpp | 9 +- csrc/evaluator_common.h | 2 +- csrc/executor.cpp | 46 +++--- csrc/executor.h | 6 +- csrc/executor_kernel_arg.cpp | 162 +++++++++++++-------- csrc/executor_kernel_arg.h | 164 +++++++++++---------- csrc/executor_utils.cpp | 1 + csrc/expr_evaluator.h | 6 + csrc/fusion_segmenter.cpp | 8 +- csrc/fusion_segmenter.h | 3 +- csrc/kernel_cache.cpp | 71 ++++++---- csrc/kernel_cache.h | 40 +++++- csrc/multidevice/multidevice_runtime.cpp | 2 +- csrc/scheduler/normalization.cpp | 8 +- csrc/scheduler/pointwise.cpp | 4 +- csrc/scheduler/reduction.cpp | 7 +- csrc/scheduler/registry.cpp | 172 ++++++++++++++++++----- csrc/scheduler/registry.h | 40 +++--- csrc/scheduler/transpose.cpp | 4 +- csrc/scheduler/utils.cpp | 3 +- csrc/scheduler/vectorize_helper.cpp | 4 +- csrc/utils.cpp | 43 ------ csrc/utils.h | 36 ++++- test/test_gpu2.cpp | 45 +++--- test/test_gpu3.cpp | 118 ++++++++++++---- test/test_gpu_indexing_ops.cpp | 7 +- test/test_gpu_tensorcore.cpp | 2 +- test/test_gpu_validator.h | 3 +- 30 files changed, 648 insertions(+), 377 deletions(-) diff --git a/benchmark/matmul.cpp b/benchmark/matmul.cpp index 42ca0296d1a..4f4b2c34465 100644 --- a/benchmark/matmul.cpp +++ b/benchmark/matmul.cpp @@ -224,12 +224,11 @@ static void SingleMatmulBase( KernelArgumentHolder args = KernelArgumentHolder::createKernelArgumentHolder( {inputs.first, inputs.second}); - // Always use 32b indexing mode for now. - TORCH_INTERNAL_ASSERT(args.getIndexMode() == KernelIndexMode::INT32); - // Disable magic zero CompileParams cparams; cparams.enable_magic_zero = false; + // Always use 32b indexing mode for now.
+ cparams.index_type = PrimDataType::Int32; // Compile kernel auto launch_constraints = LaunchParams(); diff --git a/benchmark/softmax.cpp b/benchmark/softmax.cpp index 81840d1287e..0f58d9ed574 100644 --- a/benchmark/softmax.cpp +++ b/benchmark/softmax.cpp @@ -92,7 +92,7 @@ static void Softmax_WarpReduceReference(benchmark::State& benchmark_state) { std::vector aten_inputs({aten_input}); // Schedule through magic scheduler: - SchedulerRuntimeInfo runtime_info(fusion, aten_inputs, true); + SchedulerRuntimeInfo runtime_info(fusion, aten_inputs); TORCH_INTERNAL_ASSERT(SchedulerEntry::canSchedule( ScheduleHeuristic::Persistent, fusion, runtime_info)); auto scheduler = SchedulerEntry::makeEntry( @@ -137,7 +137,7 @@ static void Softmax_WarpReduce(benchmark::State& benchmark_state) { std::vector aten_inputs({aten_input}); // Schedule through magic scheduler: - SchedulerRuntimeInfo runtime_info(fusion, aten_inputs, true); + SchedulerRuntimeInfo runtime_info(fusion, aten_inputs); TORCH_INTERNAL_ASSERT(SchedulerEntry::canSchedule( ScheduleHeuristic::Persistent, fusion, runtime_info)); auto scheduler = SchedulerEntry::makeEntry( diff --git a/csrc/evaluator_common.cpp b/csrc/evaluator_common.cpp index 84422532f30..4f18fb4969d 100644 --- a/csrc/evaluator_common.cpp +++ b/csrc/evaluator_common.cpp @@ -158,7 +158,7 @@ void PrecomputedValues::bindInputs(const KernelArgumentHolder& args) { TORCH_INTERNAL_ASSERT( args.size() == inputs.size(), "kernel inputs size does not match args"); - for (const auto i : c10::irange(inputs.size())) { + for (const auto i : c10::irange((int64_t)inputs.size())) { const auto input = inputs[i]; const ArgAbstract* arg = args[i]; if (auto tensor_input = dynamic_cast(input)) { @@ -216,7 +216,7 @@ void PrecomputedValues::initializeValueList( } c10::optional PrecomputedValues::getMaybeValueFor( - const Val* val) { + const Val* val) const { auto index = val->evaluatorIndex(); if (index < 0) { return c10::nullopt; @@ -308,11 +308,12 @@ void PrecomputedValues::bindTensorMetaData( const auto root_domain = TensorDomain::noReductions(tv->getMaybeRFactorDomain()); TORCH_INTERNAL_ASSERT( - tensor_arg_abstract->getRank() == static_cast(root_domain.size()), + tensor_arg_abstract->getRank() == + static_cast(root_domain.size()), "Something went wrong configuring launch. Inputs do not match."); for (const auto dim : c10::irange(root_domain.size())) { - auto value = tensor_arg_abstract->getSize((int)dim); + auto value = tensor_arg_abstract->getSize(static_cast(dim)); if (root_domain[dim]->hasExpandedExtent()) { auto extent = root_domain[dim]->extent(); auto expanded_extent = root_domain[dim]->expandedExtent(); diff --git a/csrc/evaluator_common.h b/csrc/evaluator_common.h index 593b72e31e6..de6c6c92d30 100644 --- a/csrc/evaluator_common.h +++ b/csrc/evaluator_common.h @@ -150,7 +150,7 @@ class PrecomputedValues { //! Returns value for the given IR node if it's stored //! in the workspace and has been evaluated. - c10::optional getMaybeValueFor(const Val* val); + c10::optional getMaybeValueFor(const Val* val) const; //! 
Debugging helper, prints all the currently known values void print() const; diff --git a/csrc/executor.cpp b/csrc/executor.cpp index 163ecedba70..e43a0c1400f 100644 --- a/csrc/executor.cpp +++ b/csrc/executor.cpp @@ -178,7 +178,6 @@ void FusionExecutor::debugCompileFusionFromStr( if (!kernel_summary.static_smem_allocations.empty()) { ExpressionEvaluator static_evaluator; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) const auto static_smem_size = computeSharedMemory( static_evaluator, kernel_summary.static_smem_allocations); TORCH_INTERNAL_ASSERT( @@ -250,17 +249,20 @@ void FusionExecutor::compileFusion( // Set the index type of compile params if not already set. If set, // make sure the compile param type is valid with the given kernel // arguments. - auto arg_index_type = args.getIndexType(); + auto arg_index_type = args.getSmallestIndexTypeOfArguments(); if (compile_params.index_type.has_value()) { // If the int32 compilation is requested, but the arguments demand // int64, that's an error TORCH_INTERNAL_ASSERT( - !(compile_params.index_type.value() == DataType::Int32 && - arg_index_type == DataType::Int), + !(compile_params.index_type.value() == PrimDataType::Int32 && + arg_index_type == PrimDataType::Int), "Compilation with int32 is requested but int64 is required for the arguments"); - } else { - // If the given compile option doesn't specify the index type, use - // the type determined by the arguments + } else if (arg_index_type == PrimDataType::Int) { + // If the given compile option doesn't specify the index type, and + // the arguments require 64-bit indexing, we need to use 64-bit + // indexing. Note that if the arg type is 32-bit, it doesn't mean + // it's safe to use 32-bit for the whole kernel, so unless it's + // specified through CompileParams, we do not use 32-bit indexing. compile_params.index_type = arg_index_type; } @@ -1036,7 +1038,7 @@ KernelArgumentHolder FusionExecutor::evaluateOutputSizes( FUSER_PERF_SCOPE("FusionExecutor::AllocOutputs"); const auto kernel = lowered_->kernel(); - KernelArgumentHolder ret(args.getIndexMode()); + KernelArgumentHolder ret; ret.setDeviceIndex(args.getDeviceIndex()); CompileOptions meta_options = options_; @@ -1086,7 +1088,7 @@ KernelArgumentHolder FusionExecutor::inferOutputSizes( FUSER_PERF_SCOPE("FusionExecutor::RunFusion"); ExecutorEntry* executor_entry = nullptr; - c10::optional opt_code = args.getCacheId(); + auto opt_code = args.getCacheId(); if (opt_code.has_value()) { executor_entry = &executor_entry_lookup_[*opt_code]; } @@ -1151,24 +1153,9 @@ KernelArgumentHolder FusionExecutor::inferOutputSizes( namespace { // Make sure the index type of Kernel is valid -// TODO: Check the size of all tensors, not just inputs. void validateIndexType( kir::Kernel* kernel, - KernelArgumentHolder& args, const CompileParams& compile_params) { - // Currently, once a Fusion is lowered to a Kernel, the index type - // has to be resolved completely. This means that - // args.getIndexType() must be equal to the index type of the - // compiled kernel. - TORCH_INTERNAL_ASSERT( - kernel->indexType() == args.getIndexType(), - "Invalid pair of kernel index type and argument index type. Kernel type: ", - kernel->indexType(), - ". Argument index type: ", - args.getIndexType()); - - // Similarly, if the type of the index type in the given compile - // parameters doesn't match, that's also an error. 
TORCH_INTERNAL_ASSERT( !compile_params.index_type.has_value() || kernel->indexType() == compile_params.index_type.value(), @@ -1383,7 +1370,7 @@ std::vector FusionExecutor::runFusion( !args.getCacheId().has_value() || outputs.empty(), "short cut input cache is not compatible with pre-allocated output"); - validateIndexType(kernel(), args, compile_params); + validateIndexType(kernel(), compile_params); const auto num_inputs = args.size(); @@ -1505,6 +1492,7 @@ std::vector FusionExecutor::runFusion( CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, launch_params_.smem())); } + auto arg_buffer = args.getBuffer(kernel()->indexType()); if (!kernel()->summary().has_cooperative_grid_reduction) { FUSER_PERF_SCOPE("ExecutorRunFusion::cuLaunchKernel"); CUDA_SAFE_CALL(cuLaunchKernel( @@ -1517,7 +1505,7 @@ std::vector FusionExecutor::runFusion( launch_params_.bdimz(), launch_params_.smem(), stream, - args.getBuffer(), + arg_buffer, nullptr)); } else { FUSER_PERF_SCOPE("ExecutorRunFusion::cuLaunchCooperativeKernel"); @@ -1531,7 +1519,7 @@ std::vector FusionExecutor::runFusion( launch_params_.bdimz(), launch_params_.smem(), stream, - args.getBuffer())); + arg_buffer)); } } @@ -1613,7 +1601,7 @@ float FusionExecutor::runRtc( CUDA_RT_SAFE_CALL(cudaEventCreate(&start_event)); CUDA_RT_SAFE_CALL(cudaEventCreate(&finish_event)); - KernelArgumentHolder kernel_arguments(index_type); + KernelArgumentHolder kernel_arguments; kernel_arguments.push(args); CUDA_RT_SAFE_CALL(cudaEventRecord(start_event, stream)); @@ -1628,7 +1616,7 @@ float FusionExecutor::runRtc( launch_params.bdimz(), launch_params.smem(), stream, - kernel_arguments.getBuffer(), + kernel_arguments.getBuffer(index_type), nullptr)); CUDA_RT_SAFE_CALL(cudaEventRecord(finish_event, stream)); diff --git a/csrc/executor.h b/csrc/executor.h index e68a71b892e..aa65d5f200f 100644 --- a/csrc/executor.h +++ b/csrc/executor.h @@ -74,6 +74,9 @@ class TORCH_CUDA_CU_API FusionExecutor : public NonCopyable { const KernelArgumentHolder& args, const LaunchParams& launch_constraints); + //! To compile a fusion with the 32-bit index type, CompileParams + //! must be passed in. There used to be an index type associated + //! with KernelArgumentHolder, but it is no longer the case. 
void compileFusion( Fusion* fusion, const KernelArgumentHolder& args, @@ -106,8 +109,7 @@ class TORCH_CUDA_CU_API FusionExecutor : public NonCopyable { CompileParams compile_params = CompileParams(), const c10::optional& opt_code = c10::nullopt) { KernelArgumentHolder args = - KernelArgumentHolder::createKernelArgumentHolder( - inputs, indexTypeToMode(kernel()->indexType())); + KernelArgumentHolder::createKernelArgumentHolder(inputs); if (opt_code.has_value()) { args.setCacheId(*opt_code); } diff --git a/csrc/executor_kernel_arg.cpp b/csrc/executor_kernel_arg.cpp index 795ebf48b83..613c4ffd936 100644 --- a/csrc/executor_kernel_arg.cpp +++ b/csrc/executor_kernel_arg.cpp @@ -32,35 +32,46 @@ std::string TensorArgAbstract::toString() const { namespace { template -std::unique_ptr getTensorArg(const at::Tensor& tensor) { +std::unique_ptr getTensorArg( + const at::Tensor& tensor, + bool index_type_resolved) { switch (tensor.ndimension()) { case (0): return std::make_unique< - TensorArg>>(tensor); + TensorArg>>( + tensor, index_type_resolved); case (1): return std::make_unique< - TensorArg>>(tensor); + TensorArg>>( + tensor, index_type_resolved); case (2): return std::make_unique< - TensorArg>>(tensor); + TensorArg>>( + tensor, index_type_resolved); case (3): return std::make_unique< - TensorArg>>(tensor); + TensorArg>>( + tensor, index_type_resolved); case (4): return std::make_unique< - TensorArg>>(tensor); + TensorArg>>( + tensor, index_type_resolved); case (5): return std::make_unique< - TensorArg>>(tensor); + TensorArg>>( + tensor, index_type_resolved); case (6): return std::make_unique< - TensorArg>>(tensor); + TensorArg>>( + tensor, index_type_resolved); case (7): return std::make_unique< - TensorArg>>(tensor); + TensorArg>>( + tensor, index_type_resolved); case (8): return std::make_unique< - TensorArg>>(tensor); + TensorArg>>( + tensor, index_type_resolved); default: TORCH_INTERNAL_ASSERT( false, @@ -74,66 +85,56 @@ std::unique_ptr getTensorArg(const at::Tensor& tensor) { template struct GetTensorArgWithNativeType { template - std::unique_ptr operator()(const at::Tensor& tensor) { - return getTensorArg(tensor); + std::unique_ptr operator()( + const at::Tensor& tensor, + bool index_type_resolved) { + return getTensorArg(tensor, index_type_resolved); }; }; -template -std::unique_ptr getTensorArg(const at::Tensor& tensor) { +template +std::unique_ptr getTensorArg( + const at::Tensor& tensor, + bool index_type_resolved) { return atenTypeDispatchWithC10Complex( - tensor.scalar_type(), GetTensorArgWithNativeType(), tensor); + tensor.scalar_type(), + GetTensorArgWithNativeType(), + tensor, + index_type_resolved); } std::unique_ptr getTensorArg( - KernelIndexMode index_mode, - const at::Tensor& tensor) { - switch (index_mode) { - case KernelIndexMode::INT32: - return getTensorArg(tensor); - case KernelIndexMode::INT64: - return getTensorArg(tensor); - default: - break; + const at::Tensor& tensor, + std::optional index_type) { + if (index_type.has_value()) { + switch (index_type.value()) { + case PrimDataType::Int32: + return getTensorArg(tensor, true); + case PrimDataType::Int: + return getTensorArg(tensor, true); + default: + TORCH_INTERNAL_ASSERT(false, "unknown index mode"); + break; + } + } else { + // Tentatively create TensorArgAbstract with int64_t + return getTensorArg(tensor, false); } - - TORCH_INTERNAL_ASSERT(false, "unknown index mode"); - return nullptr; } } // namespace KernelArgumentHolder KernelArgumentHolder::createKernelArgumentHolder( - const c10::ArrayRef& inputs, - const 
std::optional& opt_index_mode) { + const c10::ArrayRef& inputs) { if (inputs.empty()) { - // default to int32 on device 0 - KernelArgumentHolder args( - opt_index_mode.has_value() ? opt_index_mode.value() - : KernelIndexMode::INT32); + // default to device 0 + KernelArgumentHolder args; args.setDeviceIndex(0); return args; } auto device_index = getCommonDeviceCUDA(inputs); - auto input_index_mode = collectIndexMode(inputs); - - auto index_mode = input_index_mode; - - // Use index_mode if given. Make sure it is as large as the index - // mode required for the inputs - if (opt_index_mode.has_value()) { - TORCH_INTERNAL_ASSERT( - (opt_index_mode == input_index_mode) || - opt_index_mode == KernelIndexMode::INT64, - "Given index mode and argument index mode don't match.", - "Index mode: ", - opt_index_mode.value(), - ", argument index mode: ", - input_index_mode); - index_mode = opt_index_mode.value(); - } - KernelArgumentHolder args(index_mode); + KernelArgumentHolder args; args.setDeviceIndex(device_index); args.push(inputs); @@ -150,22 +151,28 @@ struct MakeCpuScalarTensor { } }; +PrimDataType getIndexTypeOfAtenTensor(const at::Tensor& tensor) { + KernelIndexTypeCompute index_type_helper; + for (const auto i : c10::irange(tensor.ndimension())) { + index_type_helper.addDim(tensor.sizes()[i], tensor.strides()[i]); + } + return index_type_helper.getType(); +} + } // namespace // Push a tensor to the arguments void KernelArgumentHolder::push(const at::Tensor& tensor) { - changed_ = true; if (is_cpu_scalar(tensor)) { arguments_.push_back(atenTypeDispatchWithC10Complex( tensor.scalar_type(), MakeCpuScalarTensor(), tensor)); } else { - arguments_.push_back(getTensorArg(index_mode_, tensor)); + arguments_.push_back(getTensorArg(tensor, std::nullopt)); } } // Push a scalar or integer to the arguments void KernelArgumentHolder::push(const c10::IValue& val) { - changed_ = true; TORCH_INTERNAL_ASSERT( val.isScalar(), "Tried to push an arg to run in a fused kernel, expected a scalar but got, ", @@ -205,13 +212,20 @@ void KernelArgumentHolder::push(const at::PhiloxCudaState& val) { // Create buffer, flatten arguments into it, align by 8 Bytes, return pointers // in the buffer -void** KernelArgumentHolder::getBuffer() { - if (changed_) { - void_ptrs_ = std::vector(arguments_.size(), nullptr); - for (const auto i : c10::irange(arguments_.size())) { - void_ptrs_[i] = static_cast(arguments_[i]->arg()); +void** KernelArgumentHolder::getBuffer(PrimDataType index_type) { + if (void_ptrs_.size() < arguments_.size()) { + void_ptrs_.resize(arguments_.size()); + } + for (const auto i : c10::irange(arguments_.size())) { + if (auto tensor_arg = + dynamic_cast(arguments_.at(i).get())) { + if (!tensor_arg->isIndexTypeResolved() || + tensor_arg->getIndexType() != index_type) { + auto resolved_arg = getTensorArg(tensor_arg->getTensor(), index_type); + arguments_.at(i) = std::move(resolved_arg); + } } - changed_ = false; + void_ptrs_.at(i) = static_cast(arguments_.at(i)->arg()); } return void_ptrs_.data(); } @@ -236,12 +250,10 @@ void KernelArgumentHolder::push(const std::vector& tensors) { } void KernelArgumentHolder::push(const ArgAbstract* arg) { - changed_ = true; arguments_.emplace_back(arg->copy_unique_ptr()); } void KernelArgumentHolder::swap(int i, const ArgAbstract* arg) { - changed_ = true; auto holder = arg->copy_unique_ptr(); arguments_[i].swap(holder); } @@ -259,4 +271,30 @@ void KernelArgumentHolder::appendPhiloxRNGSeed(uint64_t rand_offset) { push(philox_engine_inputs); } +std::string 
KernelArgumentHolder::toString() const { + std::stringstream ss; + for (const auto& arg : arguments_) { + ss << arg->toString() << "\n"; + } + return ss.str(); +} + +PrimDataType KernelArgumentHolder::getSmallestIndexTypeOfArguments() const { + for (const auto& arg : arguments_) { + auto tensor_arg = dynamic_cast(arg.get()); + if (tensor_arg == nullptr) { + continue; + } + KernelIndexTypeCompute index_type_helper; + for (const auto dim_i : c10::irange(tensor_arg->getRank())) { + auto size = tensor_arg->getSize(dim_i); + auto stride = tensor_arg->getStride(dim_i); + if (index_type_helper.addDim(size, stride) == PrimDataType::Int) { + return PrimDataType::Int; + } + } + } + return PrimDataType::Int32; +} + } // namespace nvfuser diff --git a/csrc/executor_kernel_arg.h b/csrc/executor_kernel_arg.h index 50b42534133..28fc60a398b 100644 --- a/csrc/executor_kernel_arg.h +++ b/csrc/executor_kernel_arg.h @@ -17,6 +17,45 @@ namespace nvfuser { +// TODO: macro this and the printer below +enum class ArgType { + PhiloxCudaState, + Long, + Double, + ComplexDouble, + Bool, + Tensor, + CpuScalarTensor +}; + +inline std::string argTypeToString(ArgType type) { + std::string ret; + switch (type) { + case ArgType::PhiloxCudaState: + ret = "PhiloxCudaState"; + break; + case ArgType::Long: + ret = "Long"; + break; + case ArgType::Double: + ret = "Double"; + break; + case ArgType::ComplexDouble: + ret = "ComplexDouble"; + break; + case ArgType::Bool: + ret = "Bool"; + break; + case ArgType::Tensor: + ret = "Tensor"; + break; + case ArgType::CpuScalarTensor: + ret = "CpuScalarTensor"; + break; + } + return ret; +} + // This should match the tensor used in the code generation (almost exactly) template struct TensorArgCodegen { @@ -88,45 +127,6 @@ struct CpuScalarTensorCodegen { T data; }; -// TODO: macro this and the printer below -enum class ArgType { - PhiloxCudaState, - Long, - Double, - ComplexDouble, - Bool, - Tensor, - CpuScalarTensor -}; - -inline std::string argTypeToString(ArgType type) { - std::string ret; - switch (type) { - case ArgType::PhiloxCudaState: - ret = "PhiloxCudaState"; - break; - case ArgType::Long: - ret = "Long"; - break; - case ArgType::Double: - ret = "Double"; - break; - case ArgType::ComplexDouble: - ret = "ComplexDouble"; - break; - case ArgType::Bool: - ret = "Bool"; - break; - case ArgType::Tensor: - ret = "Tensor"; - break; - case ArgType::CpuScalarTensor: - ret = "CpuScalarTensor"; - break; - } - return ret; -} - struct ArgAbstract { virtual ~ArgAbstract() = default; virtual const void* arg() const = 0; @@ -205,6 +205,10 @@ struct TensorArgAbstract : ArgAbstract { virtual DataType getDataType() const = 0; virtual int64_t numel() const = 0; virtual at::Tensor getTensor() const = 0; + virtual bool isIndexTypeResolved() const = 0; + //! Returns the index type of the tensor. It's an error if the + //! tensor does not have a resolved index type. 
+ virtual PrimDataType getIndexType() const = 0; std::string toString() const override; }; @@ -213,8 +217,10 @@ template struct TensorArg : public TensorArgAbstract { TENSOR_TYPE instance_; at::Tensor tensor_; + bool index_type_resolved_ = false; - TensorArg(const at::Tensor& tensor) : tensor_(tensor) { + TensorArg(const at::Tensor& tensor, bool index_type_resolved) + : tensor_(tensor), index_type_resolved_(index_type_resolved) { setPointer(tensor.data_ptr()); for (const auto i : c10::irange(tensor.ndimension())) { setSize(i, tensor.sizes()[i]); @@ -262,7 +268,40 @@ struct TensorArg : public TensorArgAbstract { return ret; } - DEF_HELPEE_FUNC(Tensor, instance_) + bool isIndexTypeResolved() const override { + return index_type_resolved_; + } + + PrimDataType getIndexType() const override { + TORCH_INTERNAL_ASSERT(isIndexTypeResolved()); + return NativeTypeToDataType::type; + } + + bool isType(ArgType t) const override { + return type() == t; + } + + ArgType type() const override { + return ArgType::Tensor; + } + + //! Returns the address of an tensor argument struct. It's an error + //! if called with a tensor with no resolved index type + const void* arg() const override { + TORCH_INTERNAL_ASSERT(isIndexTypeResolved()); + return &instance_; + } + + //! Returns the address of an tensor argument struct. It's an error + //! if called with a tensor with no resolved index type + void* arg() override { + TORCH_INTERNAL_ASSERT(isIndexTypeResolved()); + return &instance_; + } + + std::unique_ptr copy_unique_ptr() const override { + return std::make_unique(*this); + } }; template @@ -290,31 +329,12 @@ class TORCH_CUDA_CU_API KernelArgumentHolder { //! the ownership of the memory from the original inputs, but just recording //! its meta data for kernel execution/compilation. static KernelArgumentHolder createKernelArgumentHolder( - const c10::ArrayRef& inputs, - const std::optional& index_mode = std::nullopt); - - KernelIndexMode getIndexMode() const { - return index_mode_; - } + const c10::ArrayRef& inputs); - void setIndexMode(KernelIndexMode mode) { - index_mode_ = mode; - } - - PrimDataType getIndexType() const { - return indexModeToDtype(index_mode_); - } - - explicit KernelArgumentHolder(KernelIndexMode index_mode) - : index_mode_(index_mode) {} - - explicit KernelArgumentHolder(PrimDataType index_type) - : index_mode_(indexTypeToMode(index_type)) {} + KernelArgumentHolder() = default; KernelArgumentHolder(const KernelArgumentHolder& self) - : device_index_(self.getDeviceIndex()), - cache_id_(self.getCacheId()), - index_mode_(self.getIndexMode()) { + : device_index_(self.getDeviceIndex()), cache_id_(self.getCacheId()) { for (const auto& arg : self.arguments_) { push(arg.get()); } @@ -322,13 +342,16 @@ class TORCH_CUDA_CU_API KernelArgumentHolder { KernelArgumentHolder& operator=(const KernelArgumentHolder& self) { device_index_ = self.getDeviceIndex(); - index_mode_ = self.getIndexMode(); for (const auto& arg : self.arguments_) { push(arg.get()); } return *this; } + //! Computes the smallest index type for the currently held + //! arguments. It does not consider any other tensors used in a kernel. 
+ PrimDataType getSmallestIndexTypeOfArguments() const; + // Push a tensor to the arguments void push(const at::Tensor& tensor); @@ -337,9 +360,10 @@ class TORCH_CUDA_CU_API KernelArgumentHolder { void push(const at::PhiloxCudaState& val); - // Create buffer, flatten arguments into it, align by 8 Bytes, return pointers - // in the buffer - void** getBuffer(); + // Create a buffer, flatten arguments into it, align by 8 Bytes, return + // pointers in the buffer. Tensor arguments are passed with the given index + // type. + void** getBuffer(PrimDataType index_type); void push(const c10::ArrayRef& args); @@ -386,7 +410,7 @@ class TORCH_CUDA_CU_API KernelArgumentHolder { cache_id_ = id; } - c10::optional getCacheId() const { + std::optional getCacheId() const { return cache_id_; } @@ -395,11 +419,9 @@ class TORCH_CUDA_CU_API KernelArgumentHolder { private: std::vector> arguments_; std::vector void_ptrs_; - bool changed_ = true; int8_t device_index_ = 0; - c10::optional cache_id_ = c10::nullopt; - KernelIndexMode index_mode_ = KernelIndexMode::INT64; + std::optional cache_id_ = std::nullopt; }; } // namespace nvfuser diff --git a/csrc/executor_utils.cpp b/csrc/executor_utils.cpp index f2d869eaf52..255daa63c2a 100644 --- a/csrc/executor_utils.cpp +++ b/csrc/executor_utils.cpp @@ -838,6 +838,7 @@ void bindInputForExprEvaluation( "Something went wrong configuring launch. Inputs do not match."); auto tensor_arg_abstract = dynamic_cast(arg); + TORCH_INTERNAL_ASSERT( tensor_arg_abstract && tensor_arg_abstract->getRank() == (int64_t)root_domain.size(), diff --git a/csrc/expr_evaluator.h b/csrc/expr_evaluator.h index 0bea39d5afa..3f705957848 100644 --- a/csrc/expr_evaluator.h +++ b/csrc/expr_evaluator.h @@ -64,6 +64,12 @@ class TORCH_CUDA_CU_API ExpressionEvaluator { c10::optional getValue(const Val* value); private: + // TODO: Consider make this const. It can't be const as bind() of + // this class calls + // PrecomputedValuess::bindConcreteParallelTypeValue, but it's + // unclear why the precompute values cannot be kept constant and + // binding a value to ExpressionEvaluator just updates + // known_named_scalars_. 
PrecomputedValues* precomputed_values_ = nullptr; std::unordered_map known_values_; std::unordered_map known_named_scalars_; diff --git a/csrc/fusion_segmenter.cpp b/csrc/fusion_segmenter.cpp index 56763db8417..98133081e7e 100644 --- a/csrc/fusion_segmenter.cpp +++ b/csrc/fusion_segmenter.cpp @@ -2285,7 +2285,7 @@ bool TranslateApplicableWelford::wouldTranslateToPersistent( translateSingleWelford(welford_to_translate); } - SchedulerRuntimeInfo runtime_info(test_copy.get(), runtime_inputs_, true); + SchedulerRuntimeInfo runtime_info(test_copy.get(), runtime_inputs_); // If we are looking at a segment of fusion, // we maintain the segmented group boundary, // one set for in_progress copy and one set @@ -3000,7 +3000,7 @@ SegmentCandidateFinder::SegmentCandidateFinder( const KernelArgumentHolder& inputs, SegmentCandidateFinderOptions options) : options_(options), - runtime_info_(fusion.get(), inputs, true), + runtime_info_(fusion.get(), inputs), runtime_inputs_(inputs) { segmented_fusion_ = std::make_unique(std::move(fusion)); findSegments(); @@ -3501,9 +3501,9 @@ FusionKernelRuntime::SchedulerEntryPtr SegmentedFusion:: } std::unique_ptr SegmentedFusion::makeInitialHeuristics( - const KernelArgumentHolder& inputs) { + const KernelArgumentHolder& inputs, + SchedulerRuntimeInfo& runtime_info) { auto ret = std::make_unique(); - SchedulerRuntimeInfo runtime_info(completeFusion(), inputs, true); for (auto g : groups()) { ret->emplaceBack(makeInitialSchedulerEntry(g, runtime_info)); } diff --git a/csrc/fusion_segmenter.h b/csrc/fusion_segmenter.h index ea1e339db90..6930e8530dc 100644 --- a/csrc/fusion_segmenter.h +++ b/csrc/fusion_segmenter.h @@ -333,7 +333,8 @@ class TORCH_CUDA_CU_API SegmentedFusion { //! Make heuristics for all groups in this segmented fusion std::unique_ptr makeInitialHeuristics( - const KernelArgumentHolder& inputs); + const KernelArgumentHolder& inputs, + SchedulerRuntimeInfo& runtime_info); //! Inline Debug print for segmented fusion std::string toString(int verbosity) const; diff --git a/csrc/kernel_cache.cpp b/csrc/kernel_cache.cpp index 858a20946eb..4bb7d47797c 100644 --- a/csrc/kernel_cache.cpp +++ b/csrc/kernel_cache.cpp @@ -141,7 +141,6 @@ void FusionExecutorCache::compileFusionAsync( KernelArgumentHolder args = prepareInputs(inputs); auto kernel_runtime = getKernelRuntimeFor(args); - kernel_runtime->startAsyncCompile(args); } @@ -179,7 +178,8 @@ void FusionExecutorCache::compileFusionAsync( // For details on Part_2, refer to the implementation note. [ Permutation // Bookkeeping and Propagation in Parser ] std::vector FusionExecutorCache::runFusionWithInputs( - const at::ArrayRef& inputs) { + const at::ArrayRef& inputs, + std::optional forced_index_type) { FUSER_PERF_SCOPE("FusionExecutorCache::runFusionWithInputs"); // Permute input tensor for kernel execution. 
@@ -201,8 +201,18 @@ std::vector FusionExecutorCache::runFusionWithInputs( KernelArgumentHolder args = prepareInputs(perm_inputs); - auto kernel_runtime = getKernelRuntimeFor(args); + auto kernel_runtime = getKernelRuntimeFor(args, forced_index_type); most_recent_runtime_ = kernel_runtime; + + // Make sure the forced index type is indeed used + if (forced_index_type.has_value()) { + TORCH_INTERNAL_ASSERT( + kernel_runtime->getIndexType() == forced_index_type.value(), + "Enforcing index type of ", + forced_index_type.value(), + " failed"); + } + int seq_id = 0; // Record kernel input and output tensors so profiler can construct // the data flow graph @@ -326,12 +336,18 @@ void FusionExecutorCache::evictCache(size_t cache_id) { } FusionKernelRuntime* FusionExecutorCache::getKernelRuntimeFor( - const KernelArgumentHolder& args) { + const KernelArgumentHolder& args, + std::optional forced_index_type) { // Check for id hit case auto unique_id = *args.getCacheId(); auto id_it = id_to_kernel_runtime_.find(unique_id); if (id_it != id_to_kernel_runtime_.end()) { - return id_it->second; + // If the forced index type is given, don't use the cached runtime + // if its index type does not match with the forced type + if (!forced_index_type.has_value() || + forced_index_type.value() == id_it->second->getIndexType()) { + return id_it->second; + } } // Access kernels associated with the common device id @@ -345,8 +361,9 @@ FusionKernelRuntime* FusionExecutorCache::getKernelRuntimeFor( auto reuse_it = std::find_if( kernel_runtimes.begin(), kernel_runtimes.end(), - [&args, &new_heuristics](auto& kernel_runtime) { - auto maybe_heuristics = kernel_runtime->getMaybeHeuristicsFor(args); + [&args, &new_heuristics, &forced_index_type](auto& kernel_runtime) { + auto maybe_heuristics = + kernel_runtime->getMaybeHeuristicsFor(args, forced_index_type); if (!maybe_heuristics.has_value()) { return false; } @@ -360,8 +377,8 @@ FusionKernelRuntime* FusionExecutorCache::getKernelRuntimeFor( kernel_runtime->updateHeuristicsLaunchParams(new_heuristics.get()); } else { // graph miss, need to re-build an optimized graph for this case - kernel_runtimes.emplace_back( - std::make_unique(fusion_.get(), args)); + kernel_runtimes.emplace_back(std::make_unique( + fusion_.get(), args, forced_index_type)); kernel_runtime = kernel_runtimes.back().get(); if (profiling_) { kernel_runtime->profile(true); @@ -374,15 +391,19 @@ FusionKernelRuntime* FusionExecutorCache::getKernelRuntimeFor( FusionKernelRuntime::FusionKernelRuntime( Fusion* fusion, - const KernelArgumentHolder& args) { + const KernelArgumentHolder& args, + std::optional forced_index_type) { FUSER_PERF_SCOPE("FusionKernelRuntime::FusionKernelRuntime"); // Make a copy of fusion and do segmentation and translation // on this copy auto fusion_copy = std::make_unique(*fusion); + all_tvs_ = ir_utils::allTvs(fusion_copy.get()); + // Run segmentation on the copied fusion - SchedulerRuntimeInfo runtime_info(fusion_copy.get(), args, true); + SchedulerRuntimeInfo runtime_info( + fusion_copy.get(), args, nullptr, all_tvs_, forced_index_type); // Initialize the evaluator simplifer precomputed_values_ = std::make_unique(fusion_copy.get()); @@ -406,7 +427,8 @@ FusionKernelRuntime::FusionKernelRuntime( std::move(fusion_copy), maybe_complete_fusion_heuristic.value(), args); } - heuristics_ = segmented_fusion_->makeInitialHeuristics(args); + heuristics_ = segmented_fusion_->makeInitialHeuristics(args, runtime_info); + executors_ = std::vector(segmented_fusion_->groups().size()); if 
(isDebugDumpEnabled(DebugDumpOption::FusionSegments)) { segmented_fusion_->print(); @@ -759,7 +781,7 @@ std::unordered_map FusionKernelRuntime:: for (auto group_to_run : runtime_workspace_.group_run_order) { // TODO: index mode should be updated per segmented kernel // Prepare input vector - KernelArgumentHolder group_runtime_inputs(args.getIndexMode()); + KernelArgumentHolder group_runtime_inputs; group_runtime_inputs.setDeviceIndex(args.getDeviceIndex()); if (group_cache_id.has_value()) { group_runtime_inputs.setCacheId(group_cache_id.value()); @@ -786,7 +808,7 @@ std::unordered_map FusionKernelRuntime:: } const std::vector& FusionKernelRuntime:: - schedulers() { + schedulers() const { return heuristics_->heuristicsList(); } @@ -804,14 +826,19 @@ void FusionKernelRuntime::updateHeuristicsLaunchParams( } c10::optional FusionKernelRuntime:: - getMaybeHeuristicsFor(const KernelArgumentHolder& args) { + getMaybeHeuristicsFor( + const KernelArgumentHolder& args, + std::optional forced_index_type) { FUSER_PERF_SCOPE("FusionKernelRuntime::getMaybeHeuristicsFor"); auto complete_fusion = segmented_fusion_->completeFusion(); - SchedulerRuntimeInfo runtime_info(complete_fusion, args); precomputed_values_->bindInputs(args); precomputed_values_->evaluate(); - runtime_info.expressionEvaluator().bindPrecomputedValues( - precomputed_values_.get()); + SchedulerRuntimeInfo runtime_info( + complete_fusion, + args, + precomputed_values_.get(), + all_tvs_, + forced_index_type); c10::optional ret; ret = std::make_unique(); @@ -871,12 +898,4 @@ std::vector GraphCache::runGraphWithInputs( return outputs; } -std::string KernelArgumentHolder::toString() const { - std::stringstream ss; - for (const auto& arg : arguments_) { - ss << arg->toString() << "\n"; - } - return ss.str(); -} - } // namespace nvfuser diff --git a/csrc/kernel_cache.h b/csrc/kernel_cache.h index 474c629a759..ee1484d0dd7 100644 --- a/csrc/kernel_cache.h +++ b/csrc/kernel_cache.h @@ -46,7 +46,8 @@ class TORCH_CUDA_CU_API FusionKernelRuntime { public: explicit FusionKernelRuntime( Fusion* fusion, - const KernelArgumentHolder& inputs); + const KernelArgumentHolder& inputs, + std::optional forced_index_type = std::nullopt); //! Type notations within FusionKernelRuntime Context using HashType = size_t; @@ -75,6 +76,18 @@ class TORCH_CUDA_CU_API FusionKernelRuntime { }); } + //! Note that all heuristics use the same index type. + PrimDataType getIndexType() const { + // No scheduler means nothing to run. It may still be unsafe to + // save tensor sizes and strides in Int32 + if (schedulers().empty()) { + return PrimDataType::Int; + } + auto index_type = schedulers().at(0).get()->params()->cparams.index_type; + TORCH_INTERNAL_ASSERT(index_type.has_value()); + return index_type.value(); + } + //! Unified interface to run the managed kernels with given input std::vector runWithInputs(KernelArgumentHolder& args); @@ -128,9 +141,12 @@ class TORCH_CUDA_CU_API FusionKernelRuntime { // Try to compute heuristics based on the SegmentedFusion managed // in this kernel runtime, and will return a nullopt if either // any segment cannot be scheduled or the parameters don't match + // + // Heuristics must use the index type of forced_index_type if given. using HeuristicsPtr = std::unique_ptr; c10::optional getMaybeHeuristicsFor( - const KernelArgumentHolder& args); + const KernelArgumentHolder& args, + std::optional forced_index_type = std::nullopt); //! Copy the launch params given in the parameter heuristics to prepare //! 
for kernel launch for a new input dimension but same heuristics @@ -178,7 +194,7 @@ class TORCH_CUDA_CU_API FusionKernelRuntime { SegmentedGroup* sg); //! Access the list of schedulers maintained in this runtime instance - const std::vector& schedulers(); + const std::vector& schedulers() const; void prepareRuntimeOrder(); @@ -209,6 +225,9 @@ class TORCH_CUDA_CU_API FusionKernelRuntime { //! Utility to speed up value evaluation at runtime std::unique_ptr precomputed_values_; + //! Cache of all tensors in the complete fusion + std::vector all_tvs_; + // States for profiling support bool profiling_ = false; @@ -352,8 +371,15 @@ class TORCH_CUDA_CU_API FusionExecutorCache { //! Execute fusion graph with given inputs, create `FusionExecutor` as needed //! Note this function also handles permutation & input update outside of //! codegen. + //! + //! If given, the index type of forced_index_type is used no matter + //! what inputs and the fusion look like. This may be useful in some + //! cases as our analysis of index type may be overly conservative + //! for intermediate tensors. + //! WARING: Correctness is not guaranteed. std::vector runFusionWithInputs( - const at::ArrayRef& inputs); + const at::ArrayRef& inputs, + std::optional forced_index_type = std::nullopt); //! Compile a kernel executor for given inputs. Note: The compilation is //! async, there's some restriction on the user side. e.g. Do not overlap @@ -445,7 +471,11 @@ class TORCH_CUDA_CU_API FusionExecutorCache { //! entry in `FusionExecutor` void evictCache(size_t cache_id); - FusionKernelRuntime* getKernelRuntimeFor(const KernelArgumentHolder& inputs); + //! The index type of forced_index_type is used to get a kernel + //! runtime no matter what sizes inputs have + FusionKernelRuntime* getKernelRuntimeFor( + const KernelArgumentHolder& inputs, + std::optional forced_index_type = std::nullopt); private: //! original un-scheduled `Fusion`; diff --git a/csrc/multidevice/multidevice_runtime.cpp b/csrc/multidevice/multidevice_runtime.cpp index 86b936c8a4e..e0fcd4c4b6f 100644 --- a/csrc/multidevice/multidevice_runtime.cpp +++ b/csrc/multidevice/multidevice_runtime.cpp @@ -40,7 +40,7 @@ MultiDeviceRuntime::CompiledKernelPtr MultiDeviceRuntime::compileCluster( if (cluster->params().auto_schedule) { // Get runtime info from fusion graph and concrete tensor inputs. SchedulerRuntimeInfo runtime_info( - fusion_from_cluster.get(), cluster_inputs, true); + fusion_from_cluster.get(), cluster_inputs); // Get heuristic tag that applies to the given fusion and input info. 
auto heuristic = SchedulerEntry::proposeHeuristics( diff --git a/csrc/scheduler/normalization.cpp b/csrc/scheduler/normalization.cpp index 471148d2138..fc8ed02e13f 100644 --- a/csrc/scheduler/normalization.cpp +++ b/csrc/scheduler/normalization.cpp @@ -1077,9 +1077,7 @@ std::shared_ptr getPersistentHeuristics( max_dtype_size = std::max( max_dtype_size, - dataTypeSize( - tv->getDataType().value(), - indexModeToDtype(runtime_info.getIndexMode()))); + dataTypeSize(tv->getDataType().value(), runtime_info.getIndexType())); n_tensor_inputs++; } @@ -1096,7 +1094,7 @@ std::shared_ptr getPersistentHeuristics( max_persistent_size, vectorize_factor, project_persistent_buffers); - heuristic->cparams.index_type = indexModeToDtype(runtime_info.getIndexMode()); + heuristic->cparams.index_type = runtime_info.getIndexType(); return heuristic; } @@ -1105,7 +1103,7 @@ std::shared_ptr getPersistentHeuristics( const at::ArrayRef& runtime_inputs, HeuristicSummary* data_cache) { FUSER_PERF_SCOPE("getPersistentHeuristicsFromIValue"); - SchedulerRuntimeInfo runtime_info(fusion, runtime_inputs, true); + SchedulerRuntimeInfo runtime_info(fusion, runtime_inputs); return getPersistentHeuristics(fusion, runtime_info, data_cache); } diff --git a/csrc/scheduler/pointwise.cpp b/csrc/scheduler/pointwise.cpp index 090e0c04da8..0283b0a3691 100644 --- a/csrc/scheduler/pointwise.cpp +++ b/csrc/scheduler/pointwise.cpp @@ -71,7 +71,7 @@ std::shared_ptr getPointwiseHeuristics( Fusion* fusion, const at::ArrayRef& runtime_inputs, HeuristicSummary* data_cache) { - SchedulerRuntimeInfo runtime_info(fusion, runtime_inputs, true); + SchedulerRuntimeInfo runtime_info(fusion, runtime_inputs); return getPointwiseHeuristics(fusion, runtime_info, data_cache); } @@ -84,7 +84,7 @@ std::shared_ptr getPointwiseHeuristics( FusionGuard fg(fusion); // Incase any buffer is of type DataType::Index - const auto index_type = indexModeToDtype(runtime_info.getIndexMode()); + const auto index_type = runtime_info.getIndexType(); auto in_tvs = ir_utils::filterByType(fusion->inputs()); diff --git a/csrc/scheduler/reduction.cpp b/csrc/scheduler/reduction.cpp index a062940815f..336545cd3aa 100644 --- a/csrc/scheduler/reduction.cpp +++ b/csrc/scheduler/reduction.cpp @@ -883,7 +883,7 @@ std::shared_ptr getReductionHeuristics( HeuristicSummary* data_cache) { FUSER_PERF_SCOPE("getReductionHeuristics"); - SchedulerRuntimeInfo runtime_info(fusion, runtime_inputs, true); + SchedulerRuntimeInfo runtime_info(fusion, runtime_inputs); return getReductionHeuristics(fusion, runtime_info, data_cache); } @@ -958,8 +958,7 @@ std::shared_ptr getReductionHeuristics( max_dtype_size = std::max( max_dtype_size, static_cast(dataTypeSize( - tv->getDataType().value(), - indexModeToDtype(runtime_info.getIndexMode())))); + tv->getDataType().value(), runtime_info.getIndexType()))); n_tensor_inputs++; } @@ -974,7 +973,7 @@ std::shared_ptr getReductionHeuristics( n_tensor_inputs, max_dtype_size, vectorize_factor); - heuristic->cparams.index_type = indexModeToDtype(runtime_info.getIndexMode()); + heuristic->cparams.index_type = runtime_info.getIndexType(); return heuristic; } diff --git a/csrc/scheduler/registry.cpp b/csrc/scheduler/registry.cpp index 3ae4dc0622b..a5efb4680f0 100644 --- a/csrc/scheduler/registry.cpp +++ b/csrc/scheduler/registry.cpp @@ -748,16 +748,121 @@ bool reductionInterferingView( return false; } +PrimDataType getTensorIndexType(TensorView* tv, ExpressionEvaluator& ee) { + TORCH_INTERNAL_ASSERT( + !tv->isFusionInput(), + "This function is not supposed to be used for 
fusion inputs: ", + tv->toString()); + + auto non_contig = std::any_of( + tv->domain()->contiguity().begin(), + tv->domain()->contiguity().end(), + [](const auto contig) { return contig.has_value() && !contig.value(); }); + + // When a fusion output is non-contiguous, currently there's no + // way to obtain its strides. This is an interface problem and + // should be fixed. + if (tv->isFusionOutput() && non_contig) { + return PrimDataType::Int; + } + + // This function should not be used for fusion inputs, so any + // non-contig tensor means a fusion intermediate tensor. However, + // since we don't support non-contiguous intermediates, there must be + // something wrong. + TORCH_INTERNAL_ASSERT( + !non_contig, "Unexpected non-contiguous tensor found: ", tv->toString()); + + // Note that at this point tensors are not scheduled yet. Each + // tensor may end up being inlined, stored on Shared or Local, but + // the index type is currently supposed to be determined before + // any of scheduling decisions is made, so we may end up making a + // conservative decision. + // TODO: Consider index type resolution after segmentation and + // scheduling. At that point we have the final scheduled fusions + // with which we can make more precise analyses. It would require + // that the scheduling and segmentation should not have any + // assumption about the index type as it may change. + int64_t stride = 1; + KernelIndexTypeCompute index_type_helper; + for (auto i = tv->getMaybeRFactorDomain().size(); i > 0; --i) { + auto id = tv->getMaybeRFactorDomain().at(i - 1); + if (id->isReduction() || id->isStride() || id->isBroadcast()) { + continue; + } + + auto extent = ee.evaluate(id->extent()); + // We could also just conservatively use 64-bit indexing if the + // extent size is not determined, but this should be possible to + // evaluate. + TORCH_INTERNAL_ASSERT( + extent.has_value(), + "Axis with unknown extent found: ", + id->toString(), + ", tensor: ", + tv->toString()); + + auto extent_int = extent->as(); + + TORCH_INTERNAL_ASSERT( + extent_int >= 0, "Unexpected size of axis: ", extent_int); + + if (extent_int > 0) { + if (index_type_helper.addDim(extent->as(), stride) == + PrimDataType::Int) { + return PrimDataType::Int; + } + stride *= extent->as(); + } + } + + return index_type_helper.getType(); +} + +// Check inputs, outputs and intermediates +// Intermediates are contiguous, so strides are not necessary +// Strides are required for inputs and also maybe for outputs as +// they may be non-contiguous. 
However, in our current interface, +// output strides are not available, so if there's any outputs that +// are non contiguous, need to fall back to 64-bit indexing +PrimDataType getIndexTypeOfKernel( + Fusion* fusion, + const std::vector& all_tvs, + const KernelArgumentHolder& inputs, + ExpressionEvaluator& ee) { + if (inputs.getSmallestIndexTypeOfArguments() == PrimDataType::Int) { + return PrimDataType::Int; + } + + for (auto tv : all_tvs) { + // Fusion input tensors are included in the args parameter, and + // they are checked separately + if (tv->isFusionInput()) { + continue; + } + + if (getTensorIndexType(tv, ee) == PrimDataType::Int) { + return PrimDataType::Int; + } + } + + return PrimDataType::Int32; +} + } // namespace -void SchedulerRuntimeInfo::initialize( +SchedulerRuntimeInfo::SchedulerRuntimeInfo( + Fusion* complete_fusion, const KernelArgumentHolder& args, - bool create_expr_evaluator) { + PrecomputedValues* precomputed_values, + const std::vector& all_tvs, + std::optional forced_index_type) + : complete_fusion_(complete_fusion) { TORCH_INTERNAL_ASSERT( complete_fusion_->inputs().size() == args.size(), "Invalid number of arguments passed in for provided fusion group."); - for (auto inp_i : c10::irange(args.size())) { + for (auto inp_i : c10::irange(static_cast(args.size()))) { auto kernel_arg = args[inp_i]; // Note: we are skipping CpuScalar tensor here if (auto tensor_arg_abstract = @@ -770,7 +875,7 @@ void SchedulerRuntimeInfo::initialize( auto dtype_size = dataTypeSize(tensor_arg_abstract->getDataType()); input_discontig_strides_[fusion_inp] = {}; auto dims = tensor_arg_abstract->getRank(); - auto expected_stride = 1; + int64_t expected_stride = 1; for (auto dim = dims - 1; dim >= 0; dim--) { auto size = tensor_arg_abstract->getSize(dim); if (size <= 1) { @@ -786,44 +891,47 @@ void SchedulerRuntimeInfo::initialize( } } - expression_evaluator_ = std::make_unique(); - if (create_expr_evaluator) { - initializeExpressionEvaluator(args); - } - index_mode_ = args.getIndexMode(); -} + expression_evaluator_ = getExpressionEvaluator(args, precomputed_values); -SchedulerRuntimeInfo::SchedulerRuntimeInfo( - Fusion* complete_fusion, - const KernelArgumentHolder& args, - bool create_expr_evaluator) - : complete_fusion_(complete_fusion) { - initialize(args, create_expr_evaluator); + if (forced_index_type.has_value()) { + index_type_ = forced_index_type.value(); + } else { + index_type_ = getIndexTypeOfKernel( + complete_fusion_, + all_tvs.empty() ? ir_utils::allTvs(complete_fusion_) : all_tvs, + args, + *expression_evaluator_); + } } -// TODO: remove this one SchedulerRuntimeInfo::SchedulerRuntimeInfo( Fusion* complete_fusion, - const at::ArrayRef& aten_inputs, - bool create_expr_evaluator) - : complete_fusion_(complete_fusion) { - KernelArgumentHolder args = - KernelArgumentHolder::createKernelArgumentHolder(aten_inputs); - initialize(args, create_expr_evaluator); -} + const at::ArrayRef& aten_inputs) + : SchedulerRuntimeInfo( + complete_fusion, + KernelArgumentHolder::createKernelArgumentHolder(aten_inputs)) {} // TODO: Output tensors could have an alignment that is not 16 Bytes passed in // from user. 
-size_t SchedulerRuntimeInfo::ptrOf(TensorView* tv) { +size_t SchedulerRuntimeInfo::ptrOf(TensorView* tv) const { if (input_ptrs_.find(tv) != input_ptrs_.end()) { return input_ptrs_.at(tv); } return max_alignment_size_in_byte; } -void SchedulerRuntimeInfo::initializeExpressionEvaluator( - const KernelArgumentHolder& args) { - *expression_evaluator_ = executor_utils::bindInputs(args, complete_fusion_); +std::unique_ptr SchedulerRuntimeInfo:: + getExpressionEvaluator( + const KernelArgumentHolder& args, + PrecomputedValues* precomputed_values) { + std::unique_ptr ee = + std::make_unique(); + if (precomputed_values) { + ee->bindPrecomputedValues(precomputed_values); + } else { + *ee = executor_utils::bindInputs(args, complete_fusion_); + } + return ee; } size_t SchedulerRuntimeInfo::computeAlignmentSize(size_t ptr_address) { @@ -904,8 +1012,7 @@ size_t SchedulerRuntimeInfo::getMaxVectorizableWidth(TensorView* tv) { return 1; } - size_t item_size = - dataTypeSize(tv->dtype(), indexModeToDtype(getIndexMode())); + size_t item_size = dataTypeSize(tv->dtype(), getIndexType()); // Alignment should always at least be the data type size TORCH_INTERNAL_ASSERT(getAlignmentSize(tv) % item_size == 0); @@ -1021,8 +1128,7 @@ size_t SchedulerRuntimeInfo::getInnerDimVectorizableWidth(TensorView* tv) { return 1; } - size_t item_size = - dataTypeSize(tv->dtype(), indexModeToDtype(getIndexMode())); + size_t item_size = dataTypeSize(tv->dtype(), getIndexType()); // Alignment should always at least be the data type size TORCH_INTERNAL_ASSERT(getAlignmentSize(tv) % item_size == 0); @@ -1173,7 +1279,7 @@ class NoOpScheduler : public SchedulerEntry { SchedulerRuntimeInfo& runtime_info, HeuristicSummary* data_cache = nullptr) : SchedulerEntry(ScheduleHeuristic::NoOp) { - params_ = std::make_shared("", runtime_info.getIndexMode()); + params_ = std::make_shared("", runtime_info.getIndexType()); } //! Check if the no-op heuristics apply in given fusion diff --git a/csrc/scheduler/registry.h b/csrc/scheduler/registry.h index 9b3d416514f..82677f52956 100644 --- a/csrc/scheduler/registry.h +++ b/csrc/scheduler/registry.h @@ -6,6 +6,7 @@ */ // clang-format on #pragma once +#include #include #include #include @@ -40,20 +41,24 @@ class TORCH_CUDA_CU_API SchedulerRuntimeInfo : public NonCopyable { static constexpr size_t max_alignment_size_in_byte = 16; //! Create runtime info for given fusion and input. Creating and binding - //! evaluator is optional. The evaluator is used to manage intermediate - //! integers in the fusion. We need them for segmenter and schedulers, - //! but we don't need them when we are just using this class to provide - //! additional encoding for kernel cache lookup. + //! evaluator is optional. The evaluator is used to manage intermediate + //! integers in the fusion. We need them for segmenter and schedulers, + //! but we don't need them when we are just using this class to provide + //! additional encoding for kernel cache lookup. + //! + //! The index type of forced_index_type is used if given, no matter + //! how large the actual arguments and fusion tensors + //! are. CORRECTNESS IS NOT GUARANTEED. SchedulerRuntimeInfo( Fusion* complete_fusion, - const KernelArgumentHolder& inputs, - bool create_expr_evaluator = false); + const KernelArgumentHolder& args, + PrecomputedValues* precomputed_values = nullptr, + const std::vector& all_tvs = {}, + std::optional forced_index_type = std::nullopt); - // TODO: Remove this guy below. 
Everything needs to go into the other ctor SchedulerRuntimeInfo( Fusion* complete_fusion, - const at::ArrayRef& aten_inputs, - bool create_expr_evaluator = false); + const at::ArrayRef& aten_inputs); //! Lookup for the alignment sizes of the given tv. Currently only returns //! actual alignment info for input tensors to the complete fusion, @@ -74,10 +79,10 @@ class TORCH_CUDA_CU_API SchedulerRuntimeInfo : public NonCopyable { static size_t computeAlignmentSize(size_t ptr_address); // Return the runtime pointer value for provided tensor view - size_t ptrOf(TensorView* tv); + size_t ptrOf(TensorView* tv) const; - KernelIndexMode getIndexMode() { - return index_mode_; + PrimDataType getIndexType() const { + return index_type_; } Fusion* fusion() { @@ -90,11 +95,10 @@ class TORCH_CUDA_CU_API SchedulerRuntimeInfo : public NonCopyable { } private: - // Bind full fusion inputs to the internal expression evaluator - void initializeExpressionEvaluator(const KernelArgumentHolder& inputs); - - // Initialize SchedulerRuntimeInfo - void initialize(const KernelArgumentHolder& args, bool create_expr_evaluator); + // Build and bind full fusion inputs to an expression evaluator + std::unique_ptr getExpressionEvaluator( + const KernelArgumentHolder& inputs, + PrecomputedValues* precomputed_values); bool isInputTv(TensorView* tv) { return std::find( @@ -129,7 +133,7 @@ class TORCH_CUDA_CU_API SchedulerRuntimeInfo : public NonCopyable { std::unordered_map inner_vectorword_map_; // Found index mode kernel needs to be run in - KernelIndexMode index_mode_ = KernelIndexMode::INT64; + PrimDataType index_type_ = PrimDataType::Int; // TODO: Remove std::unordered_map vectorword_map_; diff --git a/csrc/scheduler/transpose.cpp b/csrc/scheduler/transpose.cpp index b6b0253d59c..5f8af3217a0 100644 --- a/csrc/scheduler/transpose.cpp +++ b/csrc/scheduler/transpose.cpp @@ -571,7 +571,7 @@ std::shared_ptr getTransposeHeuristics( Fusion* fusion, const at::ArrayRef& runtime_inputs, HeuristicSummary* data_cache) { - SchedulerRuntimeInfo runtime_info(fusion, runtime_inputs, true); + SchedulerRuntimeInfo runtime_info(fusion, runtime_inputs); return getTransposeHeuristics(fusion, runtime_info, data_cache); } @@ -584,7 +584,7 @@ std::shared_ptr getTransposeHeuristics( FusionGuard fg(fusion); // Incase any buffer is of type DataType::Index - const auto index_type = indexModeToDtype(runtime_info.getIndexMode()); + const auto index_type = runtime_info.getIndexType(); auto domain_map_entry = getDomainMap(data_cache, fusion); auto& domain_map = dynamic_cast(domain_map_entry.get()); diff --git a/csrc/scheduler/utils.cpp b/csrc/scheduler/utils.cpp index a43560ba135..63d444dd73d 100644 --- a/csrc/scheduler/utils.cpp +++ b/csrc/scheduler/utils.cpp @@ -810,8 +810,7 @@ PersistentBufferSizeReturn persistentBufferSize( ? 
0 : persistent_buffer_sizes[buffer_i] * dataTypeSize( - buffer->getDataType().value(), - indexModeToDtype(runtime_info.getIndexMode())); + buffer->getDataType().value(), runtime_info.getIndexType()); } // Buffers involved in normal persistence diff --git a/csrc/scheduler/vectorize_helper.cpp b/csrc/scheduler/vectorize_helper.cpp index 460a994eb97..079c006bfea 100644 --- a/csrc/scheduler/vectorize_helper.cpp +++ b/csrc/scheduler/vectorize_helper.cpp @@ -1212,8 +1212,8 @@ size_t getExpandedVectorization( SchedulerRuntimeInfo::max_alignment_size_in_byte; for (auto inp_or_out : vectorizable_inputs_outputs) { - auto dtype_size = dataTypeSize( - inp_or_out->dtype(), indexModeToDtype(runtime_info.getIndexMode())); + auto dtype_size = + dataTypeSize(inp_or_out->dtype(), runtime_info.getIndexType()); max_expand_size = std::min( max_expand_size, diff --git a/csrc/utils.cpp b/csrc/utils.cpp index f7535477369..1ca10d26d03 100644 --- a/csrc/utils.cpp +++ b/csrc/utils.cpp @@ -308,49 +308,6 @@ int8_t getCommonDeviceCUDA(const at::ArrayRef& inputs) { } } -KernelIndexMode collectIndexMode(const at::ArrayRef& inputs) { - // Save 1 more bit besides the sign bit to be conservative - constexpr int64_t most_positive_int32_index = - std::numeric_limits::max() / 2; - constexpr int64_t most_negative_int32_index = - std::numeric_limits::min() / 2; - - // Check all runtime inputs, and if any one of - // the input's index exceeds max_int32 will - // fall back to int64 indexing - for (auto ivalue_input : inputs) { - if (ivalue_input.isTensor()) { - auto tensor_input = ivalue_input.toTensor(); - int64_t tensor_most_positive_index = 0; - int64_t tensor_most_negative_index = 0; - for (auto dim_i = 0; dim_i < tensor_input.ndimension(); dim_i++) { - // Ignore broadcast dimensions - if (tensor_input.size(dim_i) > 1) { - // accumulate based on the sign of stride - if (tensor_input.stride(dim_i) > 0) { - // Acuumulate positive stride - tensor_most_positive_index += - (tensor_input.size(dim_i) - 1) * tensor_input.stride(dim_i); - } else { - // Acuumulate negative stride - tensor_most_negative_index += - (tensor_input.size(dim_i) - 1) * tensor_input.stride(dim_i); - } - } - } - - // Fall back to int64 if it can be either too positive - // or too negative. - if (tensor_most_positive_index > most_positive_int32_index || - tensor_most_negative_index < most_negative_int32_index) { - return KernelIndexMode::INT64; - } - } - } - // return index mode as int32 - return KernelIndexMode::INT32; -} - bool isDebugDumpEnabled(DebugDumpOption option) { return getDebugDumpOptions().count(option); } diff --git a/csrc/utils.h b/csrc/utils.h index 0c9eaa59031..294e69e565b 100644 --- a/csrc/utils.h +++ b/csrc/utils.h @@ -32,7 +32,6 @@ bool is_cpu_scalar(const c10::TensorType& tensor_type); // TODO: merge these two // check if input is compatible with 32b index mode int8_t getCommonDeviceCUDA(const at::ArrayRef& inputs); -KernelIndexMode collectIndexMode(const at::ArrayRef& inputs); //! Types of debug print-outs //! @@ -544,4 +543,39 @@ auto atenTypeDispatchWithC10Complex( } } +// Computes the index type required. 
+// Made into a class w/ state to allow reuse with +// different tensors and without needing to pass an allocated +// vector of size+stride +class KernelIndexTypeCompute { + // Save 1 more bit besides the sign bit to be conservative + static constexpr int64_t most_positive_int32_index = + std::numeric_limits::max() / 2; + + public: + // Updates counters and returns current reqd mode + inline PrimDataType addDim(int64_t size, int64_t stride) { + if (size > 1) { + TORCH_INTERNAL_ASSERT( + stride >= 0, "Negative stride is not supported: ", stride); + if (stride > 0) { + // Accumulate positive stride + tensor_most_positive_index_ += (size - 1) * stride; + } + } + return getType(); + } + + inline PrimDataType getType() const { + if (tensor_most_positive_index_ > most_positive_int32_index) { + return PrimDataType::Int; + } else { + return PrimDataType::Int32; + } + } + + private: + int64_t tensor_most_positive_index_ = 0; +}; + } // namespace nvfuser diff --git a/test/test_gpu2.cpp b/test/test_gpu2.cpp index 226ca48931e..7a14d33bd54 100644 --- a/test/test_gpu2.cpp +++ b/test/test_gpu2.cpp @@ -1364,11 +1364,10 @@ TEST_F(NVFuserTest, FusionBiasGeluFwd_CUDA) { auto at_input = at::randn(input_shape, options); auto at_bias = at::randn(bias_shape, options); - auto at_x = - at_bias.to(c10::ScalarType::Float) + at_input.to(c10::ScalarType::Float); - auto aten_output_float = + auto at_x = at_bias.to(c10::ScalarType::Double) + + at_input.to(c10::ScalarType::Double); + auto aten_output_double = at_x * 0.5 * (1.0 + (k_079 * at_x * (1 + k_004 * at_x * at_x)).tanh()); - auto aten_output = aten_output_float.to(c10::ScalarType::Half); std::vector aten_inputs = {at_bias, at_input}; auto lparams = schedulePointwise(&fusion, aten_inputs); @@ -1378,7 +1377,12 @@ TEST_F(NVFuserTest, FusionBiasGeluFwd_CUDA) { auto cg_outputs = fe.runFusion(aten_inputs, lparams); testValidate( - &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); + &fusion, + cg_outputs, + aten_inputs, + {aten_output_double}, + __LINE__, + __FILE__); } TEST_F(NVFuserTest, FusionBiasGeluBwd_CUDA) { @@ -1431,24 +1435,23 @@ TEST_F(NVFuserTest, FusionBiasGeluBwd_CUDA) { fusion.addOutput(t27); auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - at::manual_seed(1); + at::manual_seed(0); std::vector input_shape{6, 512, 4096}; std::vector bias_shape{4096}; auto at_input = at::randn(input_shape, options); auto at_bias = at::randn(bias_shape, options); auto at_grad = at::randn(input_shape, options); - auto at_x = - at_bias.to(c10::ScalarType::Float) + at_input.to(c10::ScalarType::Float); + auto at_x = at_bias.to(c10::ScalarType::Double) + + at_input.to(c10::ScalarType::Double); auto at_tanh_out = (k_079 * at_x * (1 + k_004 * at_x * at_x)).tanh(); auto at_ff = 0.5 * at_x * ((1 - at_tanh_out * at_tanh_out) * (k_079 + k_010 * at_x * at_x)) + 0.5 * (1 + at_tanh_out); auto at_out = at_ff * at_grad; - auto at_out_half = at_out.to(c10::ScalarType::Half); std::vector aten_inputs = {at_grad, at_bias, at_input}; - std::vector aten_outputs = {at_out, at_out_half}; + std::vector aten_outputs = {at_out, at_out}; auto lparams = schedulePointwise(&fusion, aten_inputs); @@ -5131,7 +5134,7 @@ TEST_F(NVFuserTest, FusionDAGMerging_CUDA) { std::vector aten_inputs = {t0, t1}; - KernelArgumentHolder args(KernelIndexMode::INT32); + KernelArgumentHolder args; args.setDeviceIndex(0); args.push(aten_inputs); @@ -5465,7 +5468,7 @@ TEST_F(NVFuserTest, FusionSegmentVerticalMerge_CUDA) { auto options = 
at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({2, 2, 2}, options); - KernelArgumentHolder args(KernelIndexMode::INT32); + KernelArgumentHolder args; args.setDeviceIndex(0); args.push(t0); @@ -5509,7 +5512,7 @@ TEST_F(NVFuserTest, FusionSegmentHorizontalMerge_CUDA) { auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({2, 2, 2}, options); - KernelArgumentHolder args(KernelIndexMode::INT32); + KernelArgumentHolder args; args.setDeviceIndex(0); args.push(t0); c10::IValue scalar = 1.0; @@ -5554,7 +5557,7 @@ TEST_F(NVFuserTest, FusionSegmentMixReduction_CUDA) { auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({2, 2, 2}, options); - KernelArgumentHolder args(KernelIndexMode::INT32); + KernelArgumentHolder args; args.setDeviceIndex(0); args.push(t0); @@ -8185,7 +8188,7 @@ TEST_F(NVFuserTest, FusionSegmenterCombineReductionsCycleRepro_CUDA) { c10::IValue val = at_d56; - KernelArgumentHolder args(KernelIndexMode::INT32); + KernelArgumentHolder args; args.setDeviceIndex(0); args.push(aten_inputs); args.push(val); @@ -9159,7 +9162,7 @@ TEST_F(NVFuserTest, FusionTestWarpSoftMax_CUDA) { std::vector aten_inputs({aten_input}); // Schedule through magic scheduler - SchedulerRuntimeInfo runtime_info(&fusion, aten_inputs, true); + SchedulerRuntimeInfo runtime_info(&fusion, aten_inputs); TORCH_CHECK(SchedulerEntry::canSchedule( ScheduleHeuristic::Persistent, &fusion, runtime_info)); auto scheduler = SchedulerEntry::makeEntry( @@ -9347,7 +9350,7 @@ TEST_F(NVFuserTest, FusionPersistentBufferCalculation1_CUDA) { at::Tensor aten_t0 = at::randn({99, 101}, options); // Schedule through magic scheduler - SchedulerRuntimeInfo runtime_info(&fusion, {aten_t0}, true); + SchedulerRuntimeInfo runtime_info(&fusion, {aten_t0}); auto persistent_buffer_size = persistentBufferSize(&fusion, runtime_info, persistent_buffer_info); @@ -9410,7 +9413,7 @@ TEST_F(NVFuserTest, FusionPersistentBufferCalculation2_CUDA) { at::Tensor aten_t0 = at::randn({99, 101}, options); // Schedule through magic scheduler - SchedulerRuntimeInfo runtime_info(&fusion, {aten_t0}, true); + SchedulerRuntimeInfo runtime_info(&fusion, {aten_t0}); auto persistent_buffer_size = persistentBufferSize(&fusion, runtime_info, persistent_buffer_info); @@ -9494,7 +9497,7 @@ TEST_F(NVFuserTest, FusionPersistentBufferCalculation3_CUDA) { at::Tensor aten_t5 = at::randn({99, 101}, options); // Schedule through magic scheduler - SchedulerRuntimeInfo runtime_info(&fusion, {aten_t0, aten_t5}, true); + SchedulerRuntimeInfo runtime_info(&fusion, {aten_t0, aten_t5}); auto persistent_buffer_size = persistentBufferSize(&fusion, runtime_info, persistent_buffer_info); @@ -9573,7 +9576,7 @@ TEST_F(NVFuserTest, FusionPersistentBufferCalculation4_CUDA) { at::Tensor aten_t0 = at::randn({99, 101}, options); // Schedule through magic scheduler - SchedulerRuntimeInfo runtime_info(&fusion, {aten_t0}, true); + SchedulerRuntimeInfo runtime_info(&fusion, {aten_t0}); auto persistent_buffer_size = persistentBufferSize(&fusion, runtime_info, persistent_buffer_info); @@ -9696,7 +9699,7 @@ TEST_F(NVFuserTest, FusionPersistentBufferProjection2_CUDA) { tv->toString()); } - SchedulerRuntimeInfo runtime_info(&fusion, {t0, t1}, true); + SchedulerRuntimeInfo runtime_info(&fusion, {t0, t1}); auto persistent_buffer_size = persistentBufferSize(&fusion, runtime_info, persistent_info); diff --git a/test/test_gpu3.cpp b/test/test_gpu3.cpp index fd4eec72598..fea97167abc 
100644 --- a/test/test_gpu3.cpp +++ b/test/test_gpu3.cpp @@ -7806,10 +7806,10 @@ TEST_F(NVFuserTest, FusionCompileIndexType_CUDA) { TORCH_CHECK( KernelArgumentHolder::createKernelArgumentHolder(large_inputs) - .getIndexMode() == KernelIndexMode::INT64); + .getSmallestIndexTypeOfArguments() == PrimDataType::Int); TORCH_CHECK( KernelArgumentHolder::createKernelArgumentHolder(small_inputs) - .getIndexMode() == KernelIndexMode::INT32); + .getSmallestIndexTypeOfArguments() == PrimDataType::Int32); { FusionExecutor fe; @@ -7847,43 +7847,28 @@ TEST_F(NVFuserTest, FusionCompileIndexType_CUDA) { { FusionExecutor fe; - fe.compileFusion(&fusion, small_inputs); - TORCH_CHECK( - fe.kernel()->indexType() == PrimDataType::Int32, - "Unexpected kernel index type: ", - fe.kernel()->indexType()); - - // This should complete successfully as the arguments are small - // enough to use the int32 index type - fe.runFusion(small_inputs); - - // This should fail as the Kernel is already compiled for Int32, but - // the arguments are too large - EXPECT_THAT( - [&]() { fe.runFusion(large_inputs); }, - testing::ThrowsMessage(testing::HasSubstr( - "Given index mode and argument index mode don't match"))); - } - - { - FusionExecutor fe; - // Lower the kernel with int32 index type. + LaunchParams launch_params; CompileParams compile_opts = {.index_type = PrimDataType::Int32}; + fe.compileFusion(&fusion, small_inputs, launch_params, compile_opts); - fe.compileFusion(&fusion, {}, LaunchParams(), compile_opts); TORCH_CHECK( fe.kernel()->indexType() == PrimDataType::Int32, "Unexpected kernel index type: ", fe.kernel()->indexType()); + // This should complete successfully as the arguments are small + // enough to use the int32 index type fe.runFusion(small_inputs); // This should fail as the Kernel is already compiled for Int32, but // the arguments are too large + CompileParams compile_opts_large = {.index_type = PrimDataType::Int}; EXPECT_THAT( - [&]() { fe.runFusion(large_inputs); }, + [&]() { + fe.runFusion(large_inputs, launch_params, compile_opts_large); + }, testing::ThrowsMessage(testing::HasSubstr( - "Given index mode and argument index mode don't match"))); + "Kernel index type and compilation index type don't match"))); } { @@ -7904,6 +7889,87 @@ TEST_F(NVFuserTest, FusionCompileIndexType_CUDA) { c10::cuda::CUDACachingAllocator::emptyCache(); } +// Make sure the index type is determined both fusion inputs and outputs +TEST_F(NVFuserTest, FusionExecutorCacheIndexType1_CUDA) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(fusion_ptr.get()); + + auto tv0 = makeSymbolicTensor(2, DataType::Half); + fusion.addInput(tv0); + auto tv1 = makeSymbolicTensor(2, DataType::Half); + fusion.addInput(tv1); + + auto tv2 = castOp(DataType::Float, tv0); + auto tv3 = castOp(DataType::Float, tv1); + auto tv4 = broadcast(tv2, {false, true, false}); + auto tv5 = broadcast(tv3, {true, false, false}); + auto tv6 = add(tv4, tv5); + auto tv7 = castOp(DataType::Half, tv6); + + fusion.addOutput(tv7); + + c10::cuda::CUDACachingAllocator::emptyCache(); + + // Inputs are small enough to use 32-bit indexing, but the output is + // not + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({2024, 1024}, options); + at::Tensor t1 = at::randn({2024, 1024}, options); + std::vector aten_inputs({t0, t1}); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); + + auto kernel_runtime 
      = executor_cache.getMostRecentKernelRuntime();
+  TORCH_CHECK(kernel_runtime->getIndexType() == PrimDataType::Int);
+
+  c10::cuda::CUDACachingAllocator::emptyCache();
+}
+
+// Make sure the index type is also determined by intermediate
+// tensors. This is not ideal but just tests if the logic produces
+// what is expected at this moment
+TEST_F(NVFuserTest, FusionExecutorCacheIndexType2_CUDA) {
+  auto fusion_ptr = std::make_unique<Fusion>();
+  Fusion& fusion = *fusion_ptr.get();
+  FusionGuard fg(fusion_ptr.get());
+
+  auto tv0 = makeSymbolicTensor(2, DataType::Half);
+  fusion.addInput(tv0);
+  auto tv1 = makeSymbolicTensor(2, DataType::Half);
+  fusion.addInput(tv1);
+
+  auto tv2 = broadcast(tv0, {false, true, false});
+  auto tv3 = broadcast(tv1, {true, false, false});
+  auto tv4 = add(tv2, tv3);
+  auto tv5 = sum(tv4, {-1});
+
+  fusion.addOutput(tv5);
+
+  // Inputs and outputs are small enough to use 32-bit indexing,
+  // however the intermediate, tv4, should cause the kernel to use
+  // 64-bit indexing. This is not ideal as tv4 should be inlined, and
+  // its allocation size should be small enough to use 32-bit
+  // indexing. However, the current logic should result in forcing
+  // 64-bit indexing. This would need to be fixed for matmul, for
+  // example.
+  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
+  at::Tensor t0 = at::randn({2024, 1024}, options);
+  at::Tensor t1 = at::randn({2024, 1024}, options);
+  std::vector<c10::IValue> aten_inputs({t0, t1});
+
+  FusionExecutorCache executor_cache(std::move(fusion_ptr));
+  executor_cache.runFusionWithInputs(aten_inputs);
+  auto kernel_runtime = executor_cache.getMostRecentKernelRuntime();
+  TORCH_CHECK(kernel_runtime->getIndexType() == PrimDataType::Int);
+
+  // Run again with a forced index type of Int32
+  executor_cache.runFusionWithInputs(aten_inputs, PrimDataType::Int32);
+  kernel_runtime = executor_cache.getMostRecentKernelRuntime();
+  TORCH_CHECK(kernel_runtime->getIndexType() == PrimDataType::Int32);
+}
+
+//!
Test whether we can create and use float16 scalars TEST_F(NVFuserTest, FusionHalfScalars_CUDA) { auto fusion = std::make_unique(); diff --git a/test/test_gpu_indexing_ops.cpp b/test/test_gpu_indexing_ops.cpp index 6840a1d82ef..3a221a26fce 100644 --- a/test/test_gpu_indexing_ops.cpp +++ b/test/test_gpu_indexing_ops.cpp @@ -347,7 +347,7 @@ TEST_F(NVFuserTest, FusionIndexSelectCanSch_CUDA) { std::vector aten_inputs = {input_pre, input1, input0, input_idx}; // Schedule through magic scheduler - SchedulerRuntimeInfo runtime_info(&fusion_fail, aten_inputs, true); + SchedulerRuntimeInfo runtime_info(&fusion_fail, aten_inputs); auto sch_fail = SchedulerEntry::canSchedule( ScheduleHeuristic::PointWise, &fusion_fail, runtime_info); @@ -375,8 +375,7 @@ TEST_F(NVFuserTest, FusionIndexSelectCanSch_CUDA) { std::vector aten_sum_inputs = { input_pre, input1, input0, input_idx}; // Schedule through magic scheduler - SchedulerRuntimeInfo runtime_sum_info( - &fusion_sum_fail, aten_sum_inputs, true); + SchedulerRuntimeInfo runtime_sum_info(&fusion_sum_fail, aten_sum_inputs); auto sch_sum_fail = SchedulerEntry::canSchedule( ScheduleHeuristic::Reduction, &fusion_sum_fail, runtime_sum_info); @@ -397,7 +396,7 @@ TEST_F(NVFuserTest, FusionIndexSelectCanSch_CUDA) { fusion_pass.addOutput(tv3_p); // Schedule through magic scheduler std::vector aten_inputs_pass = {input1, input0, input_idx}; - SchedulerRuntimeInfo runtime_info_pass(&fusion_pass, aten_inputs_pass, true); + SchedulerRuntimeInfo runtime_info_pass(&fusion_pass, aten_inputs_pass); auto sch_pass = SchedulerEntry::canSchedule( ScheduleHeuristic::PointWise, &fusion_pass, runtime_info_pass); diff --git a/test/test_gpu_tensorcore.cpp b/test/test_gpu_tensorcore.cpp index 860a7acb47d..09bf6df19c2 100644 --- a/test/test_gpu_tensorcore.cpp +++ b/test/test_gpu_tensorcore.cpp @@ -2786,7 +2786,7 @@ TEST_F(NVFuserTest, FusionAmpereMatmulTNSwizzled_CUDA) { auto t1 = at::randn({N, K}, options); FusionExecutor fe; - fe.compileFusion(&fusion, {}, LaunchParams(), matmul_cparams); + fe.compileFusion(&fusion, {t0, t1}, LaunchParams(), matmul_cparams); auto cg_outputs = fe.runFusion({t0, t1}); auto tref = t0.to(at::kFloat).matmul(t1.t().to(at::kFloat)); diff --git a/test/test_gpu_validator.h b/test/test_gpu_validator.h index a30314196b0..5fbddb6a0c9 100644 --- a/test/test_gpu_validator.h +++ b/test/test_gpu_validator.h @@ -265,8 +265,7 @@ ExpressionEvaluator bindInputsAndLaunchParams( Fusion* fusion, const at::ArrayRef& aten_inputs, const LaunchParams& launch_constraints) { - // index_mode is not important here - KernelArgumentHolder argument_holder(KernelIndexMode::INT64); + KernelArgumentHolder argument_holder; argument_holder.push(aten_inputs); auto expr_eval = executor_utils::bindInputs(argument_holder, fusion);
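Usage note (sketch, not part of the diff): with the constructor change in csrc/scheduler/registry.h above, the primary SchedulerRuntimeInfo constructor takes a KernelArgumentHolder plus optional PrecomputedValues, an optional tensor list, and an optional forced index type, instead of the old create_expr_evaluator flag. Something along these lines, where `fusion` and `aten_inputs` are placeholders:

  // Assumes `fusion` is the complete Fusion* and `aten_inputs` holds the
  // runtime inputs as c10::IValue.
  KernelArgumentHolder args =
      KernelArgumentHolder::createKernelArgumentHolder(aten_inputs);
  SchedulerRuntimeInfo runtime_info(
      fusion,
      args,
      /*precomputed_values=*/nullptr,
      /*all_tvs=*/{},
      /*forced_index_type=*/std::nullopt);
  // PrimDataType::Int means 64-bit indexing; PrimDataType::Int32 means 32-bit.
  PrimDataType index_type = runtime_info.getIndexType();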
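The KernelIndexTypeCompute helper added to csrc/utils.h replaces collectIndexMode and accumulates the largest reachable positive index one dimension at a time. A minimal sketch of how a caller might drive it for a single at::Tensor; the free function here is illustrative, the real call sites live in the argument holder and kernel cache code:

  // Illustrative only. Unlike the removed collectIndexMode(), addDim() asserts
  // that strides are non-negative instead of tracking a most-negative index.
  PrimDataType indexTypeForTensor(const at::Tensor& tensor) {
    KernelIndexTypeCompute index_type_helper;
    for (const auto dim_i : c10::irange(tensor.ndimension())) {
      // Size-1 (broadcast) dimensions are ignored inside addDim().
      index_type_helper.addDim(tensor.size(dim_i), tensor.stride(dim_i));
    }
    // Int32 if the largest index stays within half of int32's range,
    // Int (64-bit) otherwise.
    return index_type_helper.getType();
  }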
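Finally, the two override paths exercised by the updated FusionCompileIndexType and new FusionExecutorCacheIndexType tests, condensed into one sketch with placeholder `fusion`, `fusion_ptr`, and `inputs` objects:

  // (1) Through FusionExecutorCache: a forced index type skips the size-based
  //     analysis entirely, so correctness is the caller's responsibility.
  FusionExecutorCache executor_cache(std::move(fusion_ptr));
  auto outputs =
      executor_cache.runFusionWithInputs(inputs, PrimDataType::Int32);
  TORCH_CHECK(
      executor_cache.getMostRecentKernelRuntime()->getIndexType() ==
      PrimDataType::Int32);

  // (2) Through FusionExecutor: the default is 64-bit, overridden via
  //     CompileParams at compilation time.
  FusionExecutor fe;
  CompileParams compile_opts = {.index_type = PrimDataType::Int32};
  fe.compileFusion(&fusion, inputs, LaunchParams(), compile_opts);
  auto cg_outputs = fe.runFusion(inputs);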