diff --git a/benchmarks/cpp/nvfuser/timm.cpp b/benchmarks/cpp/nvfuser/timm.cpp index da66cfe2e5e8c..60565a844d846 100644 --- a/benchmarks/cpp/nvfuser/timm.cpp +++ b/benchmarks/cpp/nvfuser/timm.cpp @@ -11,7 +11,7 @@ using namespace torch::jit::fuser::cuda; -static void setup_vit_base_patch16_224_kernel17(Fusion* fusion, void* null) { +static void setup_vit_base_patch16_224_bcast7(Fusion* fusion, void* null) { FusionGuard fg(fusion); auto t2 = makeContigTensor(3, DataType::Float); @@ -48,7 +48,7 @@ static void setup_vit_base_patch16_224_kernel17(Fusion* fusion, void* null) { fusion->addOutput(t39); } -static void NvFuserScheduler_TIMM_vit_base_patch16_224_kernel17( +static void NvFuserScheduler_TIMM_vit_base_patch16_224_bcast7( benchmark::State& benchmark_state, FusionExecutorCache* fusion_executor_cache, void* null) { @@ -82,17 +82,18 @@ static void NvFuserScheduler_TIMM_vit_base_patch16_224_kernel17( } NVFUSER_BENCHMARK_DEFINE( - NvFuserScheduler_TIMM_NCHW_vit_base_patch16_224_kernel17, - setup_vit_base_patch16_224_kernel17, - NvFuserScheduler_TIMM_vit_base_patch16_224_kernel17, + NvFuserScheduler_TIMM_NCHW_vit_base_patch16_224_bcast7, + setup_vit_base_patch16_224_bcast7, + NvFuserScheduler_TIMM_vit_base_patch16_224_bcast7, nullptr); -NVFUSER_BENCHMARK_RUN(NvFuserScheduler_TIMM_NCHW_vit_base_patch16_224_kernel17) +// pwise case, broadcasting both sides +NVFUSER_BENCHMARK_RUN(NvFuserScheduler_TIMM_NCHW_vit_base_patch16_224_bcast7) ->Args({64, 197, 768}) ->Unit(benchmark::kMicrosecond) ->UseManualTime(); -static void setup_vit_base_patch16_224_kernel5(Fusion* fusion, void* null) { +static void setup_vit_base_patch16_224_bcast5(Fusion* fusion, void* null) { FusionGuard fg(fusion); auto t2 = makeContigTensor(3, DataType::Float); @@ -162,7 +163,7 @@ static void setup_vit_base_patch16_224_kernel5(Fusion* fusion, void* null) { fusion->addOutput(t34); // full 3d half } -static void NvFuserScheduler_TIMM_vit_base_patch16_224_kernel5( +static void NvFuserScheduler_TIMM_vit_base_patch16_224_bcast5( benchmark::State& benchmark_state, FusionExecutorCache* fusion_executor_cache, void* null) { @@ -196,17 +197,20 @@ static void NvFuserScheduler_TIMM_vit_base_patch16_224_kernel5( } NVFUSER_BENCHMARK_DEFINE( - NvFuserScheduler_TIMM_vit_base_patch16_224_kernel5_NCHW, - setup_vit_base_patch16_224_kernel5, - NvFuserScheduler_TIMM_vit_base_patch16_224_kernel5, + NvFuserScheduler_TIMM_vit_base_patch16_224_bcast5_NCHW, + setup_vit_base_patch16_224_bcast5, + NvFuserScheduler_TIMM_vit_base_patch16_224_bcast5, nullptr); -NVFUSER_BENCHMARK_RUN(NvFuserScheduler_TIMM_vit_base_patch16_224_kernel5_NCHW) +// Broadcast on both sides +NVFUSER_BENCHMARK_RUN(NvFuserScheduler_TIMM_vit_base_patch16_224_bcast5_NCHW) ->Args({64, 197, 768}) ->Unit(benchmark::kMicrosecond) ->UseManualTime(); -static void setup_vit_base_patch16_224_kernel2(Fusion* fusion, void* null) { +static void setup_vit_base_patch16_224_bcast_outer2( + Fusion* fusion, + void* null) { FusionGuard fg(fusion); auto t0 = makeContigTensor(3, DataType::Half); @@ -226,7 +230,7 @@ static void setup_vit_base_patch16_224_kernel2(Fusion* fusion, void* null) { fusion->addOutput(t7); } -static void NvFuserScheduler_TIMM_vit_base_patch16_224_kernel2( +static void NvFuserScheduler_TIMM_vit_base_patch16_224_bcast_outer2( benchmark::State& benchmark_state, FusionExecutorCache* fusion_executor_cache, void* null) { @@ -254,17 +258,18 @@ static void NvFuserScheduler_TIMM_vit_base_patch16_224_kernel2( } NVFUSER_BENCHMARK_DEFINE( - 
NvFuserScheduler_TIMM_NCHW_vit_base_patch16_224_kernel2, - setup_vit_base_patch16_224_kernel2, - NvFuserScheduler_TIMM_vit_base_patch16_224_kernel2, + NvFuserScheduler_TIMM_NCHW_vit_base_patch16_224_bcast_outer2, + setup_vit_base_patch16_224_bcast_outer2, + NvFuserScheduler_TIMM_vit_base_patch16_224_bcast_outer2, nullptr); -NVFUSER_BENCHMARK_RUN(NvFuserScheduler_TIMM_NCHW_vit_base_patch16_224_kernel2) +NVFUSER_BENCHMARK_RUN( + NvFuserScheduler_TIMM_NCHW_vit_base_patch16_224_bcast_outer2) ->Args({64, 197, 2304}) ->Unit(benchmark::kMicrosecond) ->UseManualTime(); -static void setup_vit_base_patch16_224_kernel3(Fusion* fusion, void* null) { +static void setup_vit_base_patch16_224_norm_inner3(Fusion* fusion, void* null) { FusionGuard fg(fusion); auto t0 = makeContigTensor(4, DataType::Half); @@ -280,7 +285,7 @@ static void setup_vit_base_patch16_224_kernel3(Fusion* fusion, void* null) { auto t6 = broadcast(t5, {false, false, false, true}); auto t7 = sub(t4, t6); auto t8 = exp(t7); - auto t9 = sum(t8, {-2}); + auto t9 = sum(t8, {3}); auto t10 = broadcast(t9, {false, false, false, true}); auto t11 = reciprocal(t10); auto t12 = mul(t8, t11); @@ -303,7 +308,7 @@ static void setup_vit_base_patch16_224_kernel3(Fusion* fusion, void* null) { fusion->addOutput(t4); } -static void NvFuserScheduler_TIMM_vit_base_patch16_224_kernel3( +static void NvFuserScheduler_TIMM_vit_base_patch16_224_norm_inner3( benchmark::State& benchmark_state, FusionExecutorCache* fusion_executor_cache, void* null) { @@ -329,17 +334,21 @@ static void NvFuserScheduler_TIMM_vit_base_patch16_224_kernel3( } NVFUSER_BENCHMARK_DEFINE( - NvFuserScheduler_TIMM_NCHW_vit_base_patch16_224_kernel3, - setup_vit_base_patch16_224_kernel3, - NvFuserScheduler_TIMM_vit_base_patch16_224_kernel3, + NvFuserScheduler_TIMM_NCHW_vit_base_patch16_224_norm_inner3, + setup_vit_base_patch16_224_norm_inner3, + NvFuserScheduler_TIMM_vit_base_patch16_224_norm_inner3, nullptr); -NVFUSER_BENCHMARK_RUN(NvFuserScheduler_TIMM_NCHW_vit_base_patch16_224_kernel3) +// Norm inner dim +NVFUSER_BENCHMARK_RUN( + NvFuserScheduler_TIMM_NCHW_vit_base_patch16_224_norm_inner3) ->Args({64, 12, 197}) ->Unit(benchmark::kMicrosecond) ->UseManualTime(); -static void setup_vit_base_patch16_224_kernel6(Fusion* fusion, void* null) { +static void setup_vit_base_patch16_224_bcast_outer6( + Fusion* fusion, + void* null) { FusionGuard fg(fusion); auto t0 = makeContigTensor(3, DataType::Half); @@ -378,7 +387,7 @@ static void setup_vit_base_patch16_224_kernel6(Fusion* fusion, void* null) { fusion->addOutput(t19); } -static void NvFuserScheduler_TIMM_vit_base_patch16_224_kernel6( +static void NvFuserScheduler_TIMM_vit_base_patch16_224_bcast_outer6( benchmark::State& benchmark_state, FusionExecutorCache* fusion_executor_cache, void* null) { @@ -405,12 +414,13 @@ static void NvFuserScheduler_TIMM_vit_base_patch16_224_kernel6( } NVFUSER_BENCHMARK_DEFINE( - NvFuserScheduler_TIMM_NCHW_vit_base_patch16_224_kernel6, - setup_vit_base_patch16_224_kernel6, - NvFuserScheduler_TIMM_vit_base_patch16_224_kernel6, + NvFuserScheduler_TIMM_NCHW_vit_base_patch16_224_bcast_outer6, + setup_vit_base_patch16_224_bcast_outer6, + NvFuserScheduler_TIMM_vit_base_patch16_224_bcast_outer6, nullptr); -NVFUSER_BENCHMARK_RUN(NvFuserScheduler_TIMM_NCHW_vit_base_patch16_224_kernel6) +NVFUSER_BENCHMARK_RUN( + NvFuserScheduler_TIMM_NCHW_vit_base_patch16_224_bcast_outer6) // First size is original, the rest are variations to check perf // reliability. 
->Args({64, 197, 3 * 1024}) @@ -425,7 +435,7 @@ NVFUSER_BENCHMARK_RUN(NvFuserScheduler_TIMM_NCHW_vit_base_patch16_224_kernel6) ->UseManualTime(); // Reverse the broadcast dimensions to check for consistency in scheduling. -static void setup_vit_base_patch16_224_kernel6_reversed( +static void setup_vit_base_patch16_224_bcast_inner6( Fusion* fusion, void* null) { FusionGuard fg(fusion); @@ -466,7 +476,7 @@ static void setup_vit_base_patch16_224_kernel6_reversed( fusion->addOutput(t19); } -static void NvFuserScheduler_TIMM_vit_base_patch16_224_kernel6_reversed( +static void NvFuserScheduler_TIMM_vit_base_patch16_224_bcast_inner6( benchmark::State& benchmark_state, FusionExecutorCache* fusion_executor_cache, void* null) { @@ -494,13 +504,13 @@ static void NvFuserScheduler_TIMM_vit_base_patch16_224_kernel6_reversed( } NVFUSER_BENCHMARK_DEFINE( - NvFuserScheduler_TIMM_NCHW_vit_base_patch16_224_kernel6_reversed, - setup_vit_base_patch16_224_kernel6_reversed, - NvFuserScheduler_TIMM_vit_base_patch16_224_kernel6_reversed, + NvFuserScheduler_TIMM_NCHW_vit_base_patch16_224_bcast_inner6, + setup_vit_base_patch16_224_bcast_inner6, + NvFuserScheduler_TIMM_vit_base_patch16_224_bcast_inner6, nullptr); NVFUSER_BENCHMARK_RUN( - NvFuserScheduler_TIMM_NCHW_vit_base_patch16_224_kernel6_reversed) + NvFuserScheduler_TIMM_NCHW_vit_base_patch16_224_bcast_inner6) ->Args({64, 197, 3 * 1024}) ->Args({64, 197, 2 * 1024}) ->Args({64, 197, 1024}) @@ -511,3 +521,218 @@ NVFUSER_BENCHMARK_RUN( ->Args({2, 256, 64 * 197}) ->Unit(benchmark::kMicrosecond) ->UseManualTime(); + +static void setup_vit_base_patch16_224_LN_BWD(Fusion* fusion, void* null) { + FusionGuard fg(fusion); + + auto t0 = makeContigTensor(3, DataType::Bool); + fusion->addInput(t0); + + auto t1 = makeContigTensor(3, DataType::Half); + fusion->addInput(t1); + + auto t2 = castOp(DataType::Float, t1); + + auto t3 = makeContigTensor(3, DataType::Half); + fusion->addInput(t3); + + auto t4 = castOp(DataType::Float, t3); + + auto d35 = t3->axis(2)->extent(); + + auto t5 = TensorViewBuilder() + .shape({-1, -1, 1}) + .dtype(DataType::Float) + .contiguity({true, true, false}) + .build(); + fusion->addInput(t5); + + auto t6 = TensorViewBuilder() + .shape({-1, -1, 1}) + .dtype(DataType::Float) + .contiguity({true, true, false}) + .build(); + fusion->addInput(t6); + + auto t7 = makeContigTensor(1, DataType::Half); + fusion->addInput(t7); + + auto t8 = castOp(DataType::Float, t7); + + auto t9 = makeContigTensor(1, DataType::Half); + fusion->addInput(t9); + + auto t11 = sub(t4, t5); + auto t12 = mul(t11, t6); + + auto t13 = broadcast(t8, {true, true, false}); + auto t14 = mul(t2, t13); + auto t15 = mul(d35, t14); + auto t16 = sum(t14, {2}); + auto t17 = broadcast(t16, {false, false, true}); + auto t18 = mul(t14, t12); + auto t19 = sum(t18, {2}); + auto t20 = broadcast(t19, {false, false, true}); + + auto t40 = castOp(DataType::Half, t12); + auto t41 = castOp(DataType::Float, t40); + auto t42 = castOp(DataType::Half, t20); + auto t43 = castOp(DataType::Float, t42); + auto t21 = mul(t42, t43); + + auto t38 = castOp(DataType::Half, t15); + auto t39 = castOp(DataType::Float, t38); + auto t44 = castOp(DataType::Half, t17); + auto t45 = castOp(DataType::Float, t44); + auto t22 = sub(t39, t45); + + auto t23 = sub(t22, t21); + + auto d87 = reciprocal(d35); + auto t24 = mul(d87, t6); + + auto t25 = mul(t24, t23); + auto t26 = mul(t2, t41); + auto t27 = sum(t26, {0, 1}); + auto t28 = sum(t2, {0, 1}); + + auto t29 = castOp(DataType::Float, t0); + auto t30 = mul(t25, t29); + + auto d33 
= IrBuilder::create(); + fusion->addInput(d33); + auto t31 = mul(t30, d33); + auto t32 = sum(t31, {0, 1}); + auto t33 = castOp(DataType::Half, t32); + auto t34 = castOp(DataType::Half, t31); + auto t35 = castOp(DataType::Half, t25); + auto t36 = castOp(DataType::Half, t27); + auto t37 = castOp(DataType::Half, t28); + + fusion->addOutput(t33); + fusion->addOutput(t34); + fusion->addOutput(t35); + fusion->addOutput(t36); + fusion->addOutput(t37); +} + +static void NvFuserScheduler_TIMM_vit_base_patch16_224_LN_BWD( + benchmark::State& benchmark_state, + FusionExecutorCache* fusion_executor_cache, + void* null) { + std::vector input_shape{ + benchmark_state.range(0), + benchmark_state.range(1), + benchmark_state.range(2)}; + + at::manual_seed(0); + // auto bool_options = at::TensorOptions().dtype(at::kBool).device(at::kCUDA, + // 0); + auto fp16_options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + auto fp32_options = + at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + auto t0 = at::randn(input_shape, fp16_options).to(at::kBool); + auto t1 = at::randn(input_shape, fp16_options); + auto t3 = at::randn(input_shape, fp16_options); + auto t5 = at::randn({input_shape[0], input_shape[1], 1}, fp32_options); + auto t6 = at::randn({input_shape[0], input_shape[1], 1}, fp32_options); + auto t7 = at::randn({input_shape[2]}, fp16_options); + auto t9 = at::randn({input_shape[2]}, fp16_options); + + std::vector aten_inputs({t0, t1, t3, t5, t6, t7, t9, 1.0}); + runBenchmarkIterations(benchmark_state, fusion_executor_cache, aten_inputs); + + // Full tensors - bool, halfx4 - t0, t1, t3, t34, t35 + // Outer two dimensions - floatx2 - t5, t6 + // Inner dimension - halfx5 - t7, t9, t33, t36, t37 + benchmark_state.SetBytesProcessed( + int64_t(benchmark_state.iterations()) * ((t0.numel() * (4 * 2 + 1))) + + (t5.numel() * 4 * 2) + (t7.numel() * 5 * 2)); +} + +NVFUSER_BENCHMARK_DEFINE( + NvFuserScheduler_TIMM_NCHW_vit_base_patch16_224_LN_BWD, + setup_vit_base_patch16_224_LN_BWD, + NvFuserScheduler_TIMM_vit_base_patch16_224_LN_BWD, + nullptr); + +NVFUSER_BENCHMARK_RUN(NvFuserScheduler_TIMM_NCHW_vit_base_patch16_224_LN_BWD) + ->Args({128, 197, 768}) + ->Unit(benchmark::kMicrosecond) + ->UseManualTime(); + +static void setup_vit_base_patch16_224_norm_inner2(Fusion* fusion, void* null) { + FusionGuard fg(fusion); + + auto t0 = makeContigTensor(4, DataType::Half); + fusion->addInput(t0); + auto d13 = IrBuilder::create(); + fusion->addInput(d13); + + auto t1 = castOp(DataType::Float, t0); + auto t2 = mul(t1, d13); + auto t3 = max(t2, {3}); + auto t4 = broadcast(t3, {false, false, false, true}); + auto t5 = sub(t2, t4); + auto t6 = exp(t5); + auto t7 = sum(t2, {3}); + auto t8 = broadcast(t7, {false, false, false, true}); + auto t9 = reciprocal(t8); + auto t10 = mul(t6, t9); + auto t11 = randlike(t10); + auto d59 = sub(IrBuilder::create(1), IrBuilder::create(0)); + ; + auto t12 = lt(t10, d59); + auto t13 = castOp(DataType::Float, t12); + auto t14 = mul(t10, t13); + auto b61 = eq(d59, IrBuilder::create(0)); + auto d62 = castOp(DataType::Float, b61); + auto d63 = add(d62, d59); + auto d65 = div(IrBuilder::create(1), d63); + auto t15 = mul(t14, d65); + auto t16 = castOp(DataType::Half, t15); + auto t17 = castOp(DataType::Half, t10); + auto t18 = castOp(DataType::Half, t2); + + fusion->addOutput(t16); + fusion->addOutput(t12); + fusion->addOutput(t17); + fusion->addOutput(t18); +} + +static void NvFuserScheduler_TIMM_vit_base_patch16_224_norm_inner2( + benchmark::State& benchmark_state, + 
FusionExecutorCache* fusion_executor_cache, + void* null) { + std::vector input_shape{ + benchmark_state.range(0), + benchmark_state.range(1), + benchmark_state.range(2), + benchmark_state.range(2)}; + + at::manual_seed(0); + auto fp16_options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + + auto t0 = at::randn(input_shape, fp16_options); + + std::vector aten_inputs({t0, 0.125}); + runBenchmarkIterations(benchmark_state, fusion_executor_cache, aten_inputs); + + // Full tensors - halfx4, bool - t12, t4, t0, t19, t14 + benchmark_state.SetBytesProcessed( + int64_t(benchmark_state.iterations()) * t0.numel() * 4 * 2 + 1); +} + +NVFUSER_BENCHMARK_DEFINE( + NvFuserScheduler_TIMM_NCHW_vit_base_patch16_224_norm_inner2, + setup_vit_base_patch16_224_norm_inner2, + NvFuserScheduler_TIMM_vit_base_patch16_224_norm_inner2, + nullptr); + +// Norm inner dim Half version of vit_base_patch16_224_norm_inner3 +NVFUSER_BENCHMARK_RUN( + NvFuserScheduler_TIMM_NCHW_vit_base_patch16_224_norm_inner2) + ->Args({128, 12, 197}) + ->Unit(benchmark::kMicrosecond) + ->UseManualTime(); \ No newline at end of file diff --git a/torch/csrc/jit/codegen/cuda/codegen.cpp b/torch/csrc/jit/codegen/cuda/codegen.cpp index 8d2a87f195858..ef223bae6d5b3 100644 --- a/torch/csrc/jit/codegen/cuda/codegen.cpp +++ b/torch/csrc/jit/codegen/cuda/codegen.cpp @@ -1342,6 +1342,8 @@ class CudaKernelGenerator : private OptOutConstDispatch { } // Init val func_args.arg(genCall(data_type, genInline(grop->init()))); + func_args.arg(genInline(grop->entrance_index())); + func_args.arg(genInline(grop->entrances())); indent() << "reduction::gridReduce<" << template_args << ">(\n"; indent() << kTab << func_args << ");\n"; @@ -1658,7 +1660,10 @@ class CudaKernelGenerator : private OptOutConstDispatch { indent() << kTab << read_pred << ",\n"; } // TODO : init value support or remove. 
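The LN_BWD benchmark above estimates effective bandwidth from the tensor traffic listed in its comments (one bool and four half tensors over the full shape, two float tensors over the outer two dimensions, five half tensors over the inner dimension). The following standalone program is only an illustration of that per-iteration accounting for the {128, 197, 768} case; the shapes and tensor counts come from the benchmark body, the program itself is not part of the suite.

#include <cstdint>
#include <iostream>

int main() {
  const int64_t N = 128, S = 197, C = 768; // Args({128, 197, 768})
  const int64_t full_numel = N * S * C;    // t0, t1, t3, t34, t35
  const int64_t outer_numel = N * S;       // t5, t6 have shape {N, S, 1}
  const int64_t inner_numel = C;           // t7, t9, t33, t36, t37 have shape {C}

  // 4 half (2 B) + 1 bool (1 B) full tensors, 2 float (4 B) outer tensors,
  // 5 half (2 B) inner tensors.
  const int64_t bytes_per_iteration =
      full_numel * (4 * 2 + 1) + outer_numel * 2 * 4 + inner_numel * 5 * 2;
  std::cout << "bytes per iteration: " << bytes_per_iteration << "\n";
  return 0;
}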
- indent() << kTab << data_type << "(0));\n"; + indent() << kTab << data_type << "(0),\n"; + indent() << kTab << genInline(gwop->entrance_index()) << ",\n"; + indent() << kTab << genInline(gwop->entrances()); + code_ << ");\n"; } void generateGridAllreduce(const kir::GridWelford* gwop) { diff --git a/torch/csrc/jit/codegen/cuda/kernel.cpp b/torch/csrc/jit/codegen/cuda/kernel.cpp index f0fb2b5db2652..cbbc4f53462e9 100644 --- a/torch/csrc/jit/codegen/cuda/kernel.cpp +++ b/torch/csrc/jit/codegen/cuda/kernel.cpp @@ -115,15 +115,17 @@ class KernelIrScanner : private IrVisitor { void handle(GridWelford* grid_welford) final { summary_.has_welford = true; summary_.has_grid_welford = true; - const auto dom = - grid_welford->welford_op()->out()->as()->view()->domain(); - updateGridReductionInLoop(dom); + summary_.has_grid_reductions = true; + if (grid_welford->welford_op()->isAllreduce()) { + summary_.has_cooperative_grid_reduction = true; + } } void handle(GridReduction* grid_reduction) final { summary_.has_grid_reductions = true; - const auto dom = ir_utils::getTvOutput(grid_reduction)->domain(); - updateGridReductionInLoop(dom); + if (grid_reduction->isAllreduce()) { + summary_.has_cooperative_grid_reduction = true; + } } void handle(GroupedGridReduction* grid_reduction) final { @@ -156,8 +158,6 @@ class KernelIrScanner : private IrVisitor { private: void updateGridReductionInLoop(TensorDomain* dom) { - summary_.has_grid_reductions = true; - for (const auto i : c10::irange(dom->nDims())) { const auto id = GpuLower::current()->caMap()->getConcreteMappedID( dom->domain()[i], IdMappingMode::LOOP); diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp index a2da12bd0e18c..35537f7a4fcb2 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp @@ -435,6 +435,8 @@ GridReduction::GridReduction( Val* in, Allocate* reduction_buffer, Allocate* sync_buffer, + Val* entrance_index, + Val* entrances, bool is_allreduce) : ReductionOp( passkey, @@ -445,7 +447,9 @@ GridReduction::GridReduction( is_allreduce, ExprType::GridReduction), reduction_buffer_(reduction_buffer), - sync_buffer_(sync_buffer) { + sync_buffer_(sync_buffer), + entrance_index_(entrance_index), + entrances_(entrances) { TORCH_INTERNAL_ASSERT( passkey.ir_container_->isA(), "IR type only valid for Kernel container."); @@ -495,13 +499,17 @@ GridWelford::GridWelford( Allocate* var_buffer, Allocate* avg_buffer, Allocate* n_buffer, - Allocate* sync_buffer) + Allocate* sync_buffer, + Val* entrance_index, + Val* entrances) : Expr(passkey, ExprType::GridWelford), welford_op_(welford_op), var_buffer_(var_buffer), avg_buffer_(avg_buffer), n_buffer_(n_buffer), - sync_buffer_(sync_buffer) { + sync_buffer_(sync_buffer), + entrance_index_(entrance_index), + entrances_(entrances) { TORCH_INTERNAL_ASSERT( passkey.ir_container_->isA(), "IR type only valid for Kernel container."); diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.h b/torch/csrc/jit/codegen/cuda/kernel_ir.h index 2b39aa4384ac1..99ebdba5bab3e 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.h +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.h @@ -513,6 +513,8 @@ class TORCH_CUDA_CU_API GridReduction final : public ReductionOp { Val* in, Allocate* reduction_buffer, Allocate* sync_buffer, + Val* entrance_index, + Val* entrances, bool is_fused = false); Allocate* reduction_buffer() const { @@ -523,6 +525,16 @@ class TORCH_CUDA_CU_API GridReduction final : public ReductionOp { return sync_buffer_; } + // Which 
instance of entering this grid reduction is this iteration? + Val* entrance_index() const { + return entrance_index_; + } + + // How many times will this grid reduction be entered + Val* entrances() const { + return entrances_; + } + const ParallelTypeBitmap& threadPredicate() const { return thread_predicate_; } @@ -538,6 +550,8 @@ class TORCH_CUDA_CU_API GridReduction final : public ReductionOp { // use them, the thread predicate is held here separately from // Expr::predicate_. ParallelTypeBitmap thread_predicate_; + Val* entrance_index_ = nullptr; + Val* entrances_ = nullptr; }; class TORCH_CUDA_CU_API GroupedGridReduction final : public GroupedReductionOp { @@ -629,7 +643,9 @@ class TORCH_CUDA_CU_API GridWelford final : public Expr { Allocate* var_buffer, Allocate* avg_buffer, Allocate* n_buffer, - Allocate* sync_buffer); + Allocate* sync_buffer, + Val* entrance_index, + Val* entrances); WelfordOp* welford_op() const { return welford_op_; @@ -651,6 +667,16 @@ class TORCH_CUDA_CU_API GridWelford final : public Expr { return sync_buffer_; } + // Which instance of entering this grid reduction is this iteration? + Val* entrance_index() const { + return entrance_index_; + } + + // How many times will this grid reduction be entered + Val* entrances() const { + return entrances_; + } + const ParallelTypeBitmap& threadPredicate() const { return thread_predicate_; } @@ -665,6 +691,8 @@ class TORCH_CUDA_CU_API GridWelford final : public Expr { Allocate* avg_buffer_ = nullptr; Allocate* n_buffer_ = nullptr; Allocate* sync_buffer_ = nullptr; + Val* entrance_index_ = nullptr; + Val* entrances_ = nullptr; // gridReduce has template flags for thread predicates. In order to // use them, the thread predicate is held here separately from // Expr::predicate_. diff --git a/torch/csrc/jit/codegen/cuda/lower_index.cpp b/torch/csrc/jit/codegen/cuda/lower_index.cpp index 610e6146bb9ca..82c5e017fe480 100644 --- a/torch/csrc/jit/codegen/cuda/lower_index.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_index.cpp @@ -143,6 +143,7 @@ namespace { // size. For example, FusedReduction should double the work buffer size. Val* getGridCommWorkBufferSize( const TensorDomain* td, + const std::vector& for_loops = {}, int expansion_factor = 1) { // The buffer size is the number of thread blocks multiplied by the // number of threads not used for reduction domains. @@ -172,10 +173,28 @@ Val* getGridCommWorkBufferSize( } buffer_size = SimplifyingIrBuilder::mulExpr(buffer_size, pt_dim); } + + // All iteration domains require a separate entry in the buffer for re-entrant + // grid reductions. + for (auto fl : for_loops) { + if (fl->isTrivial()) { + continue; + } + if (fl->iter_domain()->isThread()) { + // already accounted for. + continue; + } + + buffer_size = + SimplifyingIrBuilder::mulExpr(buffer_size, fl->iter_domain()->extent()); + } + return buffer_size; } -Val* getGridSyncBufferSize(const TensorDomain* td) { +Val* getGridSyncBufferSize( + const TensorDomain* td, + const std::vector& for_loops = {}) { // See the comment above for getGridCommWorkBufferSize. Val* buffer_size = GpuLower::current()->kernel()->oneVal(); for (auto pt : kParallelTypeBIDs) { @@ -191,9 +210,66 @@ Val* getGridSyncBufferSize(const TensorDomain* td) { } buffer_size = SimplifyingIrBuilder::mulExpr(buffer_size, pt_dim); } + + // All iteration domains require a separate semaphore for re-entrant grid + // reductions + for (auto fl : for_loops) { + if (fl->isTrivial()) { + continue; + } + if (fl->iter_domain()->isThread()) { + // already accounted for. 
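getGridCommWorkBufferSize and getGridSyncBufferSize now also multiply the buffer size by the extent of every non-trivial, non-thread loop enclosing the reduction, so each entrance into the grid reduction gets its own slot. The lowering builds that product symbolically with SimplifyingIrBuilder; the following is only a constant-folded sketch of the same arithmetic with invented extents.

#include <cstdint>
#include <iostream>
#include <vector>

// Illustration only: extents are plain integers here, IR expressions in the lowering.
int64_t gridCommBufferSlots(
    const std::vector<int64_t>& parallel_extents,      // grid dims (and block dims not used by the reduction)
    const std::vector<int64_t>& serial_loop_extents) { // non-trivial, non-thread loops around the reduction
  int64_t slots = 1;
  for (auto e : parallel_extents)
    slots *= e;
  // One slot per entrance: every enclosing serial loop scales the buffer.
  for (auto e : serial_loop_extents)
    slots *= e;
  return slots;
}

int main() {
  // e.g. 128 blocks along BIDx, with the grid reduction inside a serial loop of extent 4
  std::cout << gridCommBufferSlots({128}, {4}) << "\n"; // 512
  // the allreduce path passes an empty loop list, so the size is unchanged
  std::cout << gridCommBufferSlots({128}, {}) << "\n"; // 128
  return 0;
}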
+ continue; + } + + buffer_size = + SimplifyingIrBuilder::mulExpr(buffer_size, fl->iter_domain()->extent()); + } + return buffer_size; } +Val* getEntranceCountGridReduce(std::vector& for_loops) { + Val* grid_reduction_entrances = GpuLower::current()->kernel()->oneVal(); + + for (const auto loop : for_loops) { + if (loop->isTrivial()) { + continue; + } + if (loop->iter_domain()->isThread()) { + // already accounted for. + continue; + } + // TODO: Does this work for shift/gather? + grid_reduction_entrances = SimplifyingIrBuilder::mulExpr( + grid_reduction_entrances, loop->iter_domain()->extent()); + } + return grid_reduction_entrances; +} + +// Linear indexing of for loops for multiple entrances into grid reduce +// TODO: What happens if there's a broadcast that's resolved (not present in the +// grid reduce) but the global buffer isn't expanded? +Val* getEntranceLinIndGridReduce(std::vector& for_loops) { + Val* linear_index = GpuLower::current()->kernel()->zeroVal(); + + for (const auto loop : for_loops) { + if (loop->isTrivial()) { + continue; + } + if (loop->iter_domain()->isThread()) { + // already accounted for. + continue; + } + // TODO: Does this work for shift/gather? + linear_index = SimplifyingIrBuilder::addExpr( + SimplifyingIrBuilder::mulExpr( + linear_index, loop->iter_domain()->extent()), + loop->index()); + } + return linear_index; +} + } // namespace void IndexLowering::handle(const ReductionOp* rop) { @@ -271,12 +347,25 @@ void IndexLowering::handleGridReduction( const auto reduce_buffer = ir_utils::allocGlobalBufferForGridComm( getGridCommWorkBufferSize( - out_domain, rop->isAllreduce() && is_within_a_loop ? 2 : 1), + out_domain, + rop->isAllreduce() ? std::vector() : for_loops_, + rop->isAllreduce() && is_within_a_loop ? 2 : 1), out->dtype(), false); const auto sync_buffer = ir_utils::allocGlobalBufferForGridComm( - getGridSyncBufferSize(out_domain), DataType::Int, true); + getGridSyncBufferSize( + out_domain, + rop->isAllreduce() ? std::vector() : for_loops_), + DataType::Int, + true); + + const auto entrance_ind = rop->isAllreduce() + ? GpuLower::current()->kernel()->zeroVal() + : getEntranceLinIndGridReduce(for_loops_); + const auto n_entrances = rop->isAllreduce() + ? GpuLower::current()->kernel()->oneVal() + : getEntranceCountGridReduce(for_loops_); // The thread predicate for GridReduction needs to be set // separately from the main predicate. Do not combine them like @@ -291,6 +380,8 @@ void IndexLowering::handleGridReduction( in, reduce_buffer, sync_buffer, + entrance_ind, + n_entrances, rop->isAllreduce()); grid_reduction->setThreadPredicate(thread_pred); @@ -412,13 +503,14 @@ void IndexLowering::handleGridReduction( return ir_utils::allocGlobalBufferForGridComm( getGridCommWorkBufferSize( out_domain, + for_loops_, (grouped_rop->isAllreduce() && is_within_a_loop ? 2 : 1)), output->dtype(), false); }); const auto sync_buffer = ir_utils::allocGlobalBufferForGridComm( - getGridSyncBufferSize(out_domain), DataType::Int, true); + getGridSyncBufferSize(out_domain, for_loops_), DataType::Int, true); // The thread predicate for GridReduction needs to be set // separately from the main predicate. Do not combine them like @@ -547,7 +639,9 @@ void IndexLowering::handleGridWelford(WelfordOp* indexed_wop) { [](IterDomain* id) { return !isTrivialIterDomain(id); }); const auto work_buffer_size = getGridCommWorkBufferSize( - out_domain, indexed_wop->isAllreduce() && is_within_a_loop ? 2 : 1); + out_domain, + indexed_wop->isAllreduce() ? 
std::vector() : for_loops_, + indexed_wop->isAllreduce() && is_within_a_loop ? 2 : 1); const auto out_var_buffer = ir_utils::allocGlobalBufferForGridComm( work_buffer_size, indexed_wop->outVar()->dtype(), false); @@ -557,7 +651,19 @@ void IndexLowering::handleGridWelford(WelfordOp* indexed_wop) { work_buffer_size, indexed_wop->outN()->dtype(), false); const auto sync_buffer = ir_utils::allocGlobalBufferForGridComm( - getGridSyncBufferSize(out_domain), DataType::Int, true); + getGridSyncBufferSize( + out_domain, + indexed_wop->isAllreduce() ? std::vector() + : for_loops_), + DataType::Int, + true); + + const auto entrance_ind = indexed_wop->isAllreduce() + ? GpuLower::current()->kernel()->zeroVal() + : getEntranceLinIndGridReduce(for_loops_); + const auto n_entrances = indexed_wop->isAllreduce() + ? GpuLower::current()->kernel()->oneVal() + : getEntranceCountGridReduce(for_loops_); // The thread predicate for GridReduction needs to be set // separately from the main predicate. Do not combine them like @@ -566,7 +672,13 @@ void IndexLowering::handleGridWelford(WelfordOp* indexed_wop) { GpuLower::current()->threadPredMap().getPredicatedParallelTypes(out_tv); auto grid_welford = IrBuilder::create( - indexed_wop, out_var_buffer, out_avg_buffer, out_N_buffer, sync_buffer); + indexed_wop, + out_var_buffer, + out_avg_buffer, + out_N_buffer, + sync_buffer, + entrance_ind, + n_entrances); grid_welford->setThreadPredicate(thread_pred); diff --git a/torch/csrc/jit/codegen/cuda/runtime/grid_reduction.cu b/torch/csrc/jit/codegen/cuda/runtime/grid_reduction.cu index 8b4151e0e6f45..4d6d9e42e0230 100644 --- a/torch/csrc/jit/codegen/cuda/runtime/grid_reduction.cu +++ b/torch/csrc/jit/codegen/cuda/runtime/grid_reduction.cu @@ -174,6 +174,11 @@ __device__ void gridReduceLastBlock( // gets valid reduction results. There is no guarantee which particular block // gets the final results. // +// entrance_ind and n_entrances are allowed when PERSISTENT_REDUCTION = false. +// If a grid reduction call is only called once per thread, entrance_ind == 0 +// and n_entrances == 1. However, grid reduction can be called in a loop in a +// thread, in that case entrance_ind is the count of times the function has been +// called, and n_entrances is the total number of times it will be called. template < bool X_BLOCK, bool Y_BLOCK, @@ -193,9 +198,15 @@ __device__ void gridReduce( T* shared_buf, bool read_pred, bool write_pred, - T init_val) { + T init_val, + const nvfuser_index_t entrance_ind, + const nvfuser_index_t n_entrances) { T block_reduction_val = init_val; + // entrance index only matters for non-persistent re-entrant grid reductions. + const nvfuser_index_t entrance_ind_ = PERSISTENT_REDUCTION ? 0 : entrance_ind; + const nvfuser_index_t n_entrances_ = PERSISTENT_REDUCTION ? 1 : n_entrances; + // Do block reduction when required if (X_THREAD || Y_THREAD || Z_THREAD) { blockReduce( @@ -227,10 +238,15 @@ __device__ void gridReduce( const auto block_reduction_segment_size = index_utils::maskedSize(blockDim); + // Number of reductions in the grid + const nvfuser_index_t grid_segment_size = PERSISTENT_REDUCTION + ? 
1 + : index_utils::maskedSize(gridDim); + // advance to the offset for this segment // index of reduction * size of the reduction * size of threads - work_buf += idx_in_grid_segment * grid_reduction_segment_size * - block_reduction_segment_size; + work_buf += (entrance_ind * grid_segment_size + idx_in_grid_segment) * + grid_reduction_segment_size * block_reduction_segment_size; if ((!X_THREAD || threadIdx.x == 0) && (!Y_THREAD || threadIdx.y == 0) && (!Z_THREAD || threadIdx.z == 0)) { @@ -243,9 +259,16 @@ __device__ void gridReduce( block_offset * block_reduction_segment_size + thread_offset; work_buf[work_buf_offset] = block_reduction_val; } + if (PERSISTENT_REDUCTION) { + grid_sync::sync( + sync_flags[idx_in_grid_segment], grid_reduction_segment_size); - grid_sync::sync( - sync_flags[idx_in_grid_segment], grid_reduction_segment_size); + } else { + grid_sync::sync( + sync_flags[idx_in_grid_segment], + grid_reduction_segment_size, + n_entrances); + } bool last_block = index_utils::maskedIsLast(blockIdx, gridDim); diff --git a/torch/csrc/jit/codegen/cuda/runtime/grid_sync.cu b/torch/csrc/jit/codegen/cuda/runtime/grid_sync.cu index 15bf59ecfbe95..28fe2e0f1e94e 100644 --- a/torch/csrc/jit/codegen/cuda/runtime/grid_sync.cu +++ b/torch/csrc/jit/codegen/cuda/runtime/grid_sync.cu @@ -77,4 +77,54 @@ __device__ void sync(int64_t& semaphore, const uint64_t& segment_size) { index_utils::maskedIsLast(blockIdx, gridDim)); } +// Grid sync that can be called multiple times in the same kernel without all +// blocks being resident on device. This allows grid sync to be called multiple +// times as long as it's not broadcasted on the parallel axis it was reduced on. +// +// n_entrances is how many times every block is expected to enter into this +// function. All blocks must enter n_entrances times. The last block is only +// allowed to proceed once all other blocks have entered n_entrance times. +template +__device__ void sync( + int64_t& semaphore, + const uint64_t& segment_size, + const nvfuser_index_t n_entrances) { + // Finish all global memory transactions before synchronizing + __threadfence(); + + // Synchronize all threads in a block before synchronizing blocks + block_sync::sync(); + + // Only allow linear_tid == 0 to participate in the synchronization + if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0) { + // Makes the assumption that blocks are in increasing order, this is not + // guaranteed by CUDA but this is the current behavior, and unlikely to + // change. 
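The new sync overload relies on simple counting: every block other than the (assumed) last one bumps the semaphore once per entrance, and the last block may only proceed once the count reaches (num_blocks - 1) * n_entrances. A host-side simulation of that arithmetic follows; the kernel itself does this with atomicAdd, a volatile read, and __nanosleep.

#include <cassert>
#include <cstdint>

int main() {
  const int64_t num_blocks = 8;
  const int64_t n_entrances = 3;
  const int64_t finished_val = (num_blocks - 1) * n_entrances;

  int64_t semaphore = 0;
  for (int64_t entrance = 0; entrance < n_entrances; ++entrance) {
    for (int64_t block = 0; block < num_blocks - 1; ++block) {
      ++semaphore; // atomicAdd(&semaphore, 1) in the kernel
    }
  }
  // Only now would the last block's spin loop observe completion.
  assert(semaphore == finished_val);
  return 0;
}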
+ bool last_block = + index_utils::maskedIsLast(blockIdx, gridDim); + if (last_block) { + int64_t finished_val = + ((int64_t)(index_utils::maskedSize(gridDim) - 1)) * + ((int64_t)n_entrances); + + unsigned int ns = 8; + // Last block needs to wait for all other blocks to finish + while (globalAsVolatile(semaphore) < finished_val) { +#if __CUDA_ARCH__ >= 700 + // __nanosleep only available on compute capability 7.0 or higher + __nanosleep(ns); // avoids busy waiting + if (ns < 256) { + ns *= 2; + } +#endif + } + } else { + auto old = atomicAdd(reinterpret_cast(&semaphore), 1); + } + } + + // Sync block to make sure all other threads are waiting on the sync + block_sync::sync(); +} + } // namespace grid_sync diff --git a/torch/csrc/jit/codegen/cuda/runtime/welford.cu b/torch/csrc/jit/codegen/cuda/runtime/welford.cu index 4d4fd3876bc19..90502d72ce36d 100644 --- a/torch/csrc/jit/codegen/cuda/runtime/welford.cu +++ b/torch/csrc/jit/codegen/cuda/runtime/welford.cu @@ -261,7 +261,7 @@ __device__ void gridWelfordLastBlock( } } -// Grid welford combine +// Grid welford combine. See GridReduction for more information template < bool X_BLOCK, bool Y_BLOCK, @@ -288,7 +288,13 @@ __device__ void gridWelford( TN* shared_buf_N, bool read_pred, bool write_pred, - T init_val) { + T init_val, + const nvfuser_index_t entrance_ind, + const nvfuser_index_t n_entrances) { + // entrance index only matters for non-persistent re-entrant grid reductions. + const nvfuser_index_t entrance_ind_ = PERSISTENT_REDUCTION ? 0 : entrance_ind; + const nvfuser_index_t n_entrances_ = PERSISTENT_REDUCTION ? 1 : n_entrances; + // Number of values to reduce in the reduction segment const auto grid_reduction_segment_size = index_utils::maskedSize(gridDim); @@ -304,14 +310,21 @@ __device__ void gridWelford( const auto block_reduction_segment_size = index_utils::maskedSize(blockDim); + // Number of reductions in the grid + const nvfuser_index_t grid_segment_size = PERSISTENT_REDUCTION + ? 
1 + : index_utils::maskedSize(gridDim); // advance to the offset for this segment // index of reduction * size of the reduction * size of threads - work_buf_avg += idx_in_grid_segment * grid_reduction_segment_size * - block_reduction_segment_size; - work_buf_M2 += idx_in_grid_segment * grid_reduction_segment_size * - block_reduction_segment_size; - work_buf_N += idx_in_grid_segment * grid_reduction_segment_size * - block_reduction_segment_size; + work_buf_avg += (entrance_ind_ * grid_segment_size + idx_in_grid_segment) * + grid_reduction_segment_size * block_reduction_segment_size; + work_buf_M2 += (entrance_ind_ * grid_segment_size + idx_in_grid_segment) * + grid_reduction_segment_size * block_reduction_segment_size; + work_buf_N += (entrance_ind_ * grid_segment_size + idx_in_grid_segment) * + grid_reduction_segment_size * block_reduction_segment_size; if ((X_THREAD || threadIdx.x == 0) && (Y_THREAD || threadIdx.y == 0) && (Z_THREAD || threadIdx.z == 0)) { @@ -333,8 +346,15 @@ __device__ void gridWelford( } } - grid_sync::sync( - sync_flags[idx_in_grid_segment], grid_reduction_segment_size); + if (PERSISTENT_REDUCTION) { + grid_sync::sync( + sync_flags[idx_in_grid_segment], grid_reduction_segment_size); + } else { + grid_sync::sync( + sync_flags[idx_in_grid_segment], + grid_reduction_segment_size, + n_entrances_); + } bool last_block = index_utils::maskedIsLast(blockIdx, gridDim); diff --git a/torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp b/torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp index 6f3b736579d93..752416f2b01cc 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp @@ -13,6 +13,8 @@ #include +#include + namespace torch { namespace jit { namespace fuser { @@ -21,14 +23,40 @@ namespace cuda { namespace { // round up to multiple of 8 or pow2 whichever smaller -int64_t roundUpPow2Or8(const int64_t x) { +int64_t roundUpPow2OrMultipleOf(const int64_t x, const int64_t multiple) { auto round_up_pow2 = scheduler_utils::lastPow2(x); if (round_up_pow2 < x) { round_up_pow2 *= 2; } - constexpr int64_t kEight = 8; // clang tidy - auto round_up_8 = x % kEight == 0 ? x : x + (kEight - x % kEight); - return std::min(round_up_8, round_up_pow2); + auto round_up_multiple = + x % multiple == 0 ? x : x + (multiple - x % multiple); + return std::min(round_up_multiple, round_up_pow2); +} + +int64_t safeDiv(const int64_t x, const int64_t y) { + return x / y == 0 ? 1 : x / y; +} + +int64_t clamp(const int64_t val, const int64_t min_val, const int64_t max_val) { + return std::min(std::max(val, min_val), max_val); +} + +// Reduce x, y, z until their product is less than max value, reduce round robin +// starting with x +void reduceProductTo(int64_t& x, int64_t& y, int64_t& z, const int64_t max) { + TORCH_INTERNAL_ASSERT(max > 1); + if (x * y * z > max) { + x = safeDiv(x, 2); + } + if (x * y * z > max) { + y = safeDiv(y, 2); + } + if (x * y * z > max) { + z = safeDiv(z, 2); + } + if (x * y * z > max) { + reduceProductTo(x, y, z, max); + } } // Copied from reduction scheduler, should generalize. Simply needed to take out @@ -53,22 +81,30 @@ ReductionParams innerPersistentHeuristic( (int64_t)at::cuda::getCurrentDeviceProperties() ->maxThreadsPerMultiProcessor; + // No strict reason for this except a quarter of max available threads is + // typically enough to get performance for these types of kernels.
Going under + // should be fine, but often unnecsesary to go with more. + const int64_t thread_target = + ceilDiv(device_max_threads_per_multiprocessor, 4); + const int64_t device_multiprocessor_count = (int64_t)at::cuda::getCurrentDeviceProperties()->multiProcessorCount; - auto const max_unroll = ceilDiv( + // Try to set a minmum amount of work for each thread, as cross thread + // communication is slow so it shouldn't be done for every element in the + // reduction. It also includes what we may want to put in unrolling, so just + // double the simple max unrolling as target iterations shouldn't cost + // additional registers since this is persitent anyways. + auto const max_total_unroll = + 2 * std::max((int64_t)32 / (int64_t)max_input_dtype_size, (int64_t)1); + + auto const register_safe_unroll = ceilDiv( // Available unrolling based on size of data type (int64_t)16 / (int64_t)max_input_dtype_size, // Reduce unrolling if we have many inputs, start reduction at 4 inputs scheduler_utils::lastPow2( std::max((int64_t)n_tensor_inputs >> 2, (int64_t)1))); - // Conservative value, could be set to larger based on arch if necessary. - constexpr int64_t l1_cache = 32 * 1024; - // Could change per generation, but for l1 we want to consider active threads, - // not resident - constexpr int64_t active_threads = 1024; - // if data fits in l2 and we need more parallelization in the reduction dim, // we can use a smaller warp size. While thread local data fits in l1, and // reduction dim is really small, we can use <32 threads per warp. @@ -78,102 +114,15 @@ ReductionParams innerPersistentHeuristic( // If it fits in l2, we just want to make sure each warp uses 32Bytes. Set // minimum warp as 16 threads instead of 32 as if we have a small reduction // dim going a bit smaller than 32 usually helps. - const int64_t warp_size_based_on_l2 = + const int64_t min_warp_size = fits_in_l2 ? (int64_t)32 / max_input_dtype_size : 16; - // Check how many elements it would take per thread to start thrashing l1 - // set that to minimum number we want to reduce per thread. - const int64_t warp_size_based_on_l1 = std::min( - ceilDiv( - total_reduction_numel, - std::max( - l1_cache / - (n_tensor_inputs * max_input_dtype_size * active_threads), - (int64_t)1)), - (int64_t)16); - - // Take the smaller - const int64_t warp_size = - std::min(warp_size_based_on_l1, warp_size_based_on_l2); - - // Initialization - int64_t target_blocks = 1; - int64_t target_unroll = 1; - int64_t target_iterations = 1; - - // Try to set a minmum amount of work for each thread, as cross thread - // communication is slow so it shouldn't be done for every element in the - // reduction. - int64_t min_target_iterations = - std::max((int64_t)32 / (int64_t)max_input_dtype_size, (int64_t)1); - - // Start trying to break parallelization up across threads, - // unrolling/iterations, and blocks. 
- - // max_threads_in_block is the cap on a thread block, the minimum is based on - // warp_size - int64_t max_threads_in_block = std::max( - warp_size, ceilDiv(total_reduction_numel, min_target_iterations)); - - // If we have one warp per block, check if that's enough to saturate the SMs - target_blocks = ceilDiv(n_elems, warp_size); - - // If we have more than a wave of blocks, put parallelism into unrolling and - // target iterations - if (target_blocks > device_multiprocessor_count) { - auto available_unroll = std::max( - n_elems / (warp_size * device_multiprocessor_count), (int64_t)1); - - // Spread across unrolling and iterations, want a balance of the two so flip - // back and forth to alternate adding to them. - bool flip = true; - - while (available_unroll > 1 && - (target_unroll < max_unroll || - // Prefer unrolling - target_iterations < max_unroll)) { - if (target_unroll * 2 <= max_unroll && flip) { - target_unroll *= 2; - } - - if (target_iterations * 2 <= max_unroll && !flip) { - target_iterations *= 2; - } - - available_unroll = std::max( - n_elems / - (warp_size * device_multiprocessor_count * target_unroll * - target_iterations), - (int64_t)1); - - flip = !flip; - } - - // Recompute target blocks - target_blocks = - ceilDiv(n_elems, warp_size * target_unroll * target_iterations); - } - - // Cap target blocks to 4 waves - target_blocks = std::min(target_blocks, device_multiprocessor_count * 4); - - if (target_blocks * target_unroll * target_iterations < n_elems) { - // targetting 4 waves, so try to use a quarter of available threads - max_threads_in_block = std::min( - ceilDiv(n_elems, target_blocks * target_unroll), - ceilDiv(device_max_threads_per_multiprocessor, (int64_t)4)); - } - - // Round up to nearest warp. - if (max_threads_in_block % warp_size != 0) { - max_threads_in_block += warp_size - max_threads_in_block % warp_size; - } - // Compute maximum number of reductions we could do in the same kernel based - // on persistent buffer size - const int64_t max_multi_reduction_factor = std::max( - scheduler_utils::register_file_size / max_persistent_buffer_size, - (int64_t)1); + // on persistent buffer size, it's unlikely approaching + // max_multi_reduction_factor is productive in this scheduler, leave a factor + // of 4 in it. + const int64_t max_multi_reduction_factor = safeDiv( + scheduler_utils::register_file_size, (max_persistent_buffer_size * 4)); // To get to target threads: // Prioritize @@ -200,136 +149,164 @@ ReductionParams innerPersistentHeuristic( int64_t outer_reduction_unroll_factor = 1; int64_t iter_unroll_factor = 1; - inner_reduction_unroll_factor = - vectorize_factor > 1 ? (int64_t)vectorize_factor : 1; - - // Grab what we can out of reduction domain, but don't go over a warp size yet - bdimx = std::min( - std::max( - ceilDiv(inner_most_dimension_numel, inner_reduction_unroll_factor), - (int64_t)warp_size), - max_threads_in_block); + bool vectorize = false; - // If we're not just barely covering the dimension, round to a more friendly - // number - if (bdimx * inner_reduction_unroll_factor != inner_most_dimension_numel) { - bdimx = bdimx > warp_size ? 
bdimx - bdimx % warp_size - : scheduler_utils::lastPow2(bdimx); + // If vectorize is available on the inner dimension always use it + if (vectorize_factor > 1) { + inner_reduction_unroll_factor = scheduler_utils::lastPow2( + std::min(inner_most_dimension_numel, (int64_t)vectorize_factor)); + // reset bdimx as we could have gone lower than a warp + bdimx = std::min( + bdimx, + ceilDiv(inner_reduction_unroll_factor, inner_reduction_unroll_factor)); - // Round bdimx down to multiple of warp size or power 2 - if (bdimx < warp_size) { - bdimx = scheduler_utils::lastPow2(bdimx); - } else { - bdimx = bdimx - bdimx % warp_size; - } + vectorize = inner_reduction_unroll_factor > 1; } - // Put everything else in bdimy for now - bdimy = std::min( - std::max(warp_size / bdimx, (int64_t)1), max_multi_reduction_factor); - - // If 3D fill the rest of the threads into bdimz - bdimz = std::min( - std::min( - std::max(max_threads_in_block / (bdimx * bdimy), (int64_t)1), - outer_reduction_numel), - scheduler_utils::z_block_limit); + bdimx = min_warp_size; - // If 3D doesn't fill out the threads, adjust to add to bdimy - bdimy = std::min( - std::max(max_threads_in_block / (bdimx * bdimz), (int64_t)1), - max_multi_reduction_factor); + // Start buliding into the unrolling of the inner reduction dimension if not + // already vectorized - // If we don't have a full warp and have an unroll factor, move unroll into - // bdimx - if (bdimx * bdimy * bdimz < warp_size && inner_reduction_unroll_factor > 1) { - bdimx = std::min( - std::max(inner_most_dimension_numel, warp_size), max_threads_in_block); + // If the inner dimension is small, but there's a significant outer dimension, + // unroll something other than the inner dimension (probably prefer the + // outer). + // + // If there isn't an outer dimension, but some unroll available in the inner + // dimension, prefer that. + // Don't unroll bdimx right now if: + // There's not a lot of threads here, but is in the outer dimension + // vectorized + if (!vectorize && + !(inner_most_dimension_numel < thread_target && + outer_reduction_numel > register_safe_unroll)) { + // Unroll if there's something to unroll + if (ceilDiv( + inner_most_dimension_numel, bdimx * inner_reduction_unroll_factor) > + 1) { + // bdimx actually ends up being the compliment of the persistent space + // and inner dimension. Dive into unrolling not going above the + // "register safe unroll". auto max_inner_unroll_factor = + // ceilDiv(inner_most_dimension_numel, bdimx); + inner_reduction_unroll_factor = std::min( + ceilDiv(inner_most_dimension_numel, bdimx), register_safe_unroll); + inner_reduction_unroll_factor = + scheduler_utils::lastPow2(inner_reduction_unroll_factor); + } + } - inner_reduction_unroll_factor = - std::min(ceilDiv(inner_most_dimension_numel, bdimx), max_unroll); + // Now that getting warp number of threads then unrolling was prioritized on + // the inner dimension, push it further if it has target iterations. + bdimx = clamp( + ceilDiv(inner_most_dimension_numel, inner_reduction_unroll_factor), + bdimx, + thread_target); + + // Since this is a persistent kernel and the thread dim is set based on + // fitting the reduction domain in registers bdimx is not accurate. + { + // Will adjust this later in the heuristics, but need to estimate bdimx + // more accurately before filling the rest of the block dimensions. 
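A worked numeric trace of the bdimx estimate in the block just below may help. The sizes are invented and the helpers are local copies, so this is an illustration of the arithmetic rather than the scheduler itself.

#include <algorithm>
#include <cstdint>
#include <iostream>

int64_t ceilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; }

int64_t lastPow2(int64_t x) {
  int64_t p = 1;
  while (p * 2 <= x)
    p *= 2;
  return p;
}

int64_t roundUpPow2OrMultipleOf(int64_t x, int64_t multiple) {
  auto pow2 = lastPow2(x) < x ? lastPow2(x) * 2 : lastPow2(x);
  auto mult = x % multiple == 0 ? x : x + (multiple - x % multiple);
  return std::min(mult, pow2);
}

int main() {
  // Made-up sizes: inner reduction of 1800 elements, unroll factor 2,
  // a provisional bdimx of 256, and a 16-thread minimum warp.
  int64_t inner_numel = 1800, unroll = 2, bdimx = 256, min_warp = 16;

  // Estimate how many serial "batches" each thread keeps in registers...
  int64_t batches =
      roundUpPow2OrMultipleOf(ceilDiv(inner_numel, unroll * bdimx), 8); // 4
  // ...then recompute bdimx from that batch count and round it to a clean size.
  bdimx = ceilDiv(inner_numel, batches * unroll);   // 225
  bdimx = roundUpPow2OrMultipleOf(bdimx, min_warp); // 240
  std::cout << "batches=" << batches << " bdimx=" << bdimx << "\n";
  return 0;
}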
+ auto batches_per_inner_dim = roundUpPow2OrMultipleOf( + ceilDiv( + inner_most_dimension_numel, inner_reduction_unroll_factor * bdimx), + 8); + bdimx = ceilDiv( + inner_most_dimension_numel, + batches_per_inner_dim * inner_reduction_unroll_factor); + // Keep factors for now at clean values + bdimx = roundUpPow2OrMultipleOf(bdimx, min_warp_size); + } - // Readjust bdimy and bdimz - bdimy = std::min( - std::max(warp_size / bdimx, (int64_t)1), max_multi_reduction_factor); + // If we have a wave put everything else we can in bdimy + bdimy = std::min( + safeDiv(thread_target, bdimx), + std::min( + ceilDiv(total_iteration_numel, device_multiprocessor_count), + max_multi_reduction_factor)); - bdimz = std::min( - std::max(max_threads_in_block / (bdimx * bdimy), (int64_t)1), - outer_reduction_numel); + bdimy = roundUpPow2OrMultipleOf(bdimy, min_warp_size); - bdimy = std::min( - std::max(max_threads_in_block / (bdimx * bdimz), (int64_t)1), - max_multi_reduction_factor); - } + // Then if 3D into bdimz + bdimz = std::min( + std::min( + std::min( + safeDiv(thread_target, bdimx * bdimy), outer_reduction_numel), + scheduler_utils::z_block_limit), + ceilDiv(total_reduction_numel, bdimx * max_total_unroll)); - bool vectorize = false; + bdimz = roundUpPow2OrMultipleOf(bdimz, min_warp_size); - // Move unrolling factor into vectorization upto vectorization limit. - if (vectorize_factor > 1 && inner_reduction_unroll_factor > 1) { - vectorize = true; - inner_reduction_unroll_factor = std::min( - scheduler_utils::lastPow2(inner_reduction_unroll_factor), - (int64_t)vectorize_factor); - } + // Make sure we're not over thread target at this point + reduceProductTo(bdimz, bdimy, bdimx, thread_target); - // Attempt to put some unrolling into the outer reduction if inner hasn't - // taken the max unrolling - if (inner_reduction_unroll_factor < max_unroll) { + // If reduction isn't fully unrolled, try unrolling iter dimension + iter_unroll_factor = std::min( + std::min( + // take what ever unrolling is left + safeDiv(register_safe_unroll, inner_reduction_unroll_factor), + // Don't go under a wave + ceilDiv(total_iteration_numel, device_multiprocessor_count * bdimy)), + // Don't go over the max multi reduction factor + safeDiv(max_multi_reduction_factor, bdimy)); + + // Attempt to put some unrolling into the outer reduction if nothing else + // has taken it + if (inner_reduction_unroll_factor * iter_unroll_factor < max_total_unroll) { outer_reduction_unroll_factor = std::min( - ceilDiv(max_unroll, inner_reduction_unroll_factor), - ceilDiv(outer_reduction_numel, bdimz)); + std::min( + safeDiv( + max_total_unroll, + inner_reduction_unroll_factor * iter_unroll_factor), + ceilDiv(outer_reduction_numel, bdimz)), + register_safe_unroll); } godim = ceilDiv(total_iteration_numel, bdimy); // Set size of persistent per thread buffer on inner reduction buffer - int64_t batches_per_block_inner_reduction = roundUpPow2Or8(ceilDiv( - inner_most_dimension_numel, bdimx * inner_reduction_unroll_factor)); + int64_t batches_per_block_inner_reduction = roundUpPow2OrMultipleOf( + ceilDiv( + inner_most_dimension_numel, bdimx * inner_reduction_unroll_factor), + 8); // Prefer putting iterations into unrolling over having a very large // persistent buffer. 
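reduceProductTo, called above to keep bdimz * bdimy * bdimx under the thread target, halves the dimensions round-robin starting with the first argument. A self-contained copy with one worked trace (the starting sizes are made up):

#include <cassert>
#include <cstdint>

int64_t safeDiv(const int64_t x, const int64_t y) {
  return x / y == 0 ? 1 : x / y;
}

// Copy of the helper added above: halve x, y, z round-robin (x first) until
// their product no longer exceeds `max`.
void reduceProductTo(int64_t& x, int64_t& y, int64_t& z, const int64_t max) {
  assert(max > 1); // TORCH_INTERNAL_ASSERT in the original
  if (x * y * z > max) {
    x = safeDiv(x, 2);
  }
  if (x * y * z > max) {
    y = safeDiv(y, 2);
  }
  if (x * y * z > max) {
    z = safeDiv(z, 2);
  }
  if (x * y * z > max) {
    reduceProductTo(x, y, z, max);
  }
}

int main() {
  // e.g. (bdimz, bdimy, bdimx) = (8, 16, 64) capped to 256 threads
  int64_t z_dim = 8, y_dim = 16, x_dim = 64;
  reduceProductTo(z_dim, y_dim, x_dim, 256);
  assert(z_dim == 2 && y_dim == 4 && x_dim == 32); // 2 * 4 * 32 == 256
  return 0;
}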
- while (!vectorize && inner_reduction_unroll_factor < max_unroll && + while (!vectorize && inner_reduction_unroll_factor < max_total_unroll && batches_per_block_inner_reduction >= 2) { inner_reduction_unroll_factor *= 2; - batches_per_block_inner_reduction = roundUpPow2Or8(ceilDiv( - inner_most_dimension_numel, bdimx * inner_reduction_unroll_factor)); + batches_per_block_inner_reduction = roundUpPow2OrMultipleOf( + ceilDiv( + inner_most_dimension_numel, bdimx * inner_reduction_unroll_factor), + 8); } // Set size of persistent per thread buffer on outer reduction buffer - int64_t batches_per_block_outer_reduction = roundUpPow2Or8(ceilDiv( - ceilDiv(total_reduction_numel, inner_most_dimension_numel), - bdimz * outer_reduction_unroll_factor)); + int64_t batches_per_block_outer_reduction = roundUpPow2OrMultipleOf( + ceilDiv( + ceilDiv(total_reduction_numel, inner_most_dimension_numel), + bdimz * outer_reduction_unroll_factor), + 8); // Prefer putting iterations into unrolling over having a very large // persistent buffer. - while (outer_reduction_unroll_factor < max_unroll && - batches_per_block_outer_reduction >= 2) { + while (outer_reduction_unroll_factor * 2 <= register_safe_unroll && + // Roughly try to leave number of iterations to 4, because unroll will + // unswitch the predicate, if we can frequently miss the unswitch path + // then we'll start triggering the inlinepredicates frequently. + batches_per_block_outer_reduction >= 8) { + // std::cout << batches_per_block_outer_reduction << ", " + // << outer_reduction_unroll_factor << std::endl; outer_reduction_unroll_factor *= 2; - batches_per_block_outer_reduction = roundUpPow2Or8( - ceilDiv(outer_reduction_numel, bdimz * outer_reduction_unroll_factor)); - } - - // If we haven't gotten to the max_unroll case, try to take it out of the - // iteration domain - if (inner_reduction_unroll_factor * outer_reduction_unroll_factor < - max_unroll && - std::max(max_multi_reduction_factor / bdimy, (int64_t)1) > 2) { - // Don't go over a combined inner/outer unroll of max_unroll - auto unroll_available = std::min( - std::max( - max_unroll / - (inner_reduction_unroll_factor * outer_reduction_unroll_factor), - (int64_t)1), - std::max(max_multi_reduction_factor / bdimy, (int64_t)1)); - if (unroll_available > 1 && godim > 2 * device_multiprocessor_count) { - unroll_available = std::min( - unroll_available, ceilDiv(godim, 2 * device_multiprocessor_count)); - iter_unroll_factor = unroll_available; - } + batches_per_block_outer_reduction = roundUpPow2OrMultipleOf( + ceilDiv(outer_reduction_numel, bdimz * outer_reduction_unroll_factor), + 8); + // std::cout << batches_per_block_outer_reduction << ", " + // << outer_reduction_unroll_factor << std::endl; } - // Adjust bdimx based on batches_per_block and unroll factor set as they could - // have moved a bit since they're the free variables, not the buffers + // Adjust bdimx based on batches_per_block and unroll factor set as they + // could have moved a bit since they're the free variables, not the buffers bdimx = ceilDiv( inner_most_dimension_numel, inner_reduction_unroll_factor * batches_per_block_inner_reduction); @@ -345,39 +322,56 @@ ReductionParams innerPersistentHeuristic( ceilDiv(device_max_threads_per_multiprocessor, (int64_t)4) && // And batches_per_block_inner_reduction can be divided by two (batches_per_block_inner_reduction >= 2 || - batches_per_block_outer_reduction >= 2)) { - // Try to decrease per thread register allocation persistence size on inner - // reduction + 
batches_per_block_outer_reduction >= 2 || + outer_reduction_unroll_factor >= 2)) { + // Try to decrease per thread register allocation persistence size on + // inner reduction if (batches_per_block_inner_reduction >= 2 && batches_per_block_inner_reduction != - roundUpPow2Or8(batches_per_block_inner_reduction / 2)) { + roundUpPow2OrMultipleOf(batches_per_block_inner_reduction / 2, 8)) { batches_per_block_inner_reduction = - roundUpPow2Or8(batches_per_block_inner_reduction / 2); + roundUpPow2OrMultipleOf(batches_per_block_inner_reduction / 2, 8); bdimx = ceilDiv( inner_most_dimension_numel, inner_reduction_unroll_factor * batches_per_block_inner_reduction); continue; } - // Try to decrease per thread register allocation persistence size on outer - // reduction - if (batches_per_block_outer_reduction >= 2 && - batches_per_block_outer_reduction != - roundUpPow2Or8(batches_per_block_outer_reduction / 2) && - bdimz * 2 <= scheduler_utils::z_block_limit) { - batches_per_block_outer_reduction = - roundUpPow2Or8(batches_per_block_outer_reduction / 2); - bdimz = ceilDiv( - outer_reduction_numel, - batches_per_block_outer_reduction * outer_reduction_unroll_factor); - continue; + // Try to decrease per thread register allocation persistence size on + // outer reduction + if (bdimz * 2 <= scheduler_utils::z_block_limit) { + if (outer_reduction_unroll_factor > batches_per_block_outer_reduction) { + outer_reduction_unroll_factor = + safeDiv(outer_reduction_unroll_factor, 2); + bdimz = ceilDiv( + outer_reduction_numel, + batches_per_block_outer_reduction * outer_reduction_unroll_factor); + continue; + } else { + if (batches_per_block_outer_reduction >= 2 && + batches_per_block_outer_reduction != + roundUpPow2OrMultipleOf( + safeDiv(batches_per_block_outer_reduction, 2), 8) && + bdimz * 2 <= scheduler_utils::z_block_limit) { + batches_per_block_outer_reduction = roundUpPow2OrMultipleOf( + safeDiv(batches_per_block_outer_reduction, 2), 8); + bdimz = ceilDiv( + outer_reduction_numel, + batches_per_block_outer_reduction * + outer_reduction_unroll_factor); + continue; + } + } } + // Nothing could be modified, break break; } + // Make sure we're not over max thread count at this point + reduceProductTo(bdimz, bdimy, bdimx, device_max_threads_per_multiprocessor); + // Register pressure is really high per thread, which could lead to local - // memory leaks, if using less than maximum threads, decrease batches per - // block by a factor of 2 + // memory leaks, if using less than maximum threads increase thread usage. if (batches_per_block_outer_reduction * batches_per_block_inner_reduction * inner_reduction_unroll_factor * outer_reduction_unroll_factor * 4 > @@ -393,8 +387,15 @@ ReductionParams innerPersistentHeuristic( 4 > 255 * 3 && bdimx * bdimy * bdimz * 2 <= device_max_threads_per_multiprocessor && - batches_per_block_outer_reduction >= 2) { - batches_per_block_outer_reduction /= 2; + (batches_per_block_outer_reduction >= 2 || + outer_reduction_unroll_factor >= 2)) { + // If unroll factor is large, prefer taking it out of there rather than + // the persistent buffer. 
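The guard above compares an estimated per-thread persistent footprint (scaled by 4) against a 255 * 3 budget; the diff does not spell out what the constants stand for, so read the numbers below only as a sketch of the arithmetic, including the preference for shrinking the outer unroll factor before the persistent buffer.

#include <cstdint>
#include <iostream>

int64_t safeDiv(int64_t x, int64_t y) { return x / y == 0 ? 1 : x / y; }

int main() {
  // Made-up state for one scheduling attempt.
  int64_t batches_inner = 8, batches_outer = 2;
  int64_t unroll_inner = 4, unroll_outer = 4;

  // Same shape of check as the heuristic above.
  const int64_t footprint =
      batches_outer * batches_inner * unroll_inner * unroll_outer * 4;
  if (footprint > 255 * 3) { // 1024 > 765 here
    // Prefer shrinking the outer unroll factor when it dominates the outer
    // batch count, mirroring the branch that follows this comment.
    if (unroll_outer > batches_outer) {
      unroll_outer = safeDiv(unroll_outer, 2);
    } else {
      batches_outer = safeDiv(batches_outer, 2);
    }
  }
  std::cout << "outer unroll now " << unroll_outer << "\n"; // 2
  return 0;
}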
+ if (outer_reduction_unroll_factor > batches_per_block_outer_reduction) { + outer_reduction_unroll_factor /= 2; + } else { + batches_per_block_outer_reduction /= 2; + } } auto device_warp_size = at::cuda::warp_size(); @@ -403,8 +404,7 @@ ReductionParams innerPersistentHeuristic( : bdimx + (device_warp_size - bdimx % device_warp_size); bool pad_bdimx = bdimx > 16 && - padded_bdimx * bdimy * bdimz < - (int64_t)at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock; + padded_bdimx * bdimy * bdimz < device_max_threads_per_multiprocessor; pad_bdimx = pad_bdimx && bdimx * inner_reduction_unroll_factor != inner_most_dimension_numel; @@ -656,7 +656,7 @@ ReductionParams OuterPersistentHeuristic( int64_t batches_per_block = ceilDiv(total_reduction_numel, bdimy * inner_reduction_unroll_factor); - batches_per_block = roundUpPow2Or8(batches_per_block); + batches_per_block = roundUpPow2OrMultipleOf(batches_per_block, 8); // Adjust bdimy based on batches_per_block and unroll factor set bdimy = ceilDiv( @@ -671,8 +671,8 @@ ReductionParams OuterPersistentHeuristic( // And batches_per_block can be divided by two batches_per_block >= 2 && // Make sure batches_per_block will be updated - batches_per_block != roundUpPow2Or8(batches_per_block / 2)) { - batches_per_block = roundUpPow2Or8(batches_per_block / 2); + batches_per_block != roundUpPow2OrMultipleOf(batches_per_block / 2, 8)) { + batches_per_block = roundUpPow2OrMultipleOf(batches_per_block / 2, 8); // Adjust bdimx based on batches_per_block and unroll factor set bdimy = ceilDiv( diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp index 3e8c14e924175..a457df6c2d125 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp @@ -12629,7 +12629,9 @@ __global__ void kernel1( (long*)shared_buf_N, threadIdx.x fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + int X = 256, Y = 7, Z = 2048; + + // setup fusion + auto tv0 = makeContigTensor(4, DataType::Half); + fusion.addInput(tv0); + auto tv1 = castOp(DataType::Float, tv0); + + auto tvs = Welford(tv1, {0, 1, 2}); + auto tv_avg = tvs.avg; + auto tv_M2 = tvs.var_sum; + auto tv_N = tvs.n; + fusion.addOutput(tv_avg); + fusion.addOutput(tv_M2); + + auto cached_input = tv0->cacheAfter(); + auto cached_avg = tv_avg->cacheBefore(); + auto cached_M2 = tv_M2->cacheBefore(); + + auto reduction_tv = scheduler_utils::getReductionTvs(&fusion)[0]; + + reduction_tv->merge(0); + reduction_tv->merge(0); + + int TIDx = 16; + int vec = 4; + + int TIDy = 16; + int outer_tidy_fact = 16; + + reduction_tv->split(-1, TIDx * vec); + reduction_tv->split(-1, vec); + reduction_tv->axis(-2)->parallelize(ParallelType::TIDx); + // TODO: enable: + // reduction_tv->axis(-1)->parallelize(ParallelType::Vectorize); + reduction_tv->axis(-1)->parallelize(ParallelType::Unroll); + reduction_tv->axis(-3)->parallelize(ParallelType::BIDx); + + reduction_tv->split(0, TIDy); + reduction_tv->axis(1)->parallelize(ParallelType::TIDy); + reduction_tv->split(0, outer_tidy_fact); + reduction_tv->axis(0)->parallelize(ParallelType::BIDy); + + // T2_g[ rblockIdx.y, rS{16}, rthreadIdx.y, iblockIdx.x, ithreadIdx.x24, + // iV25{4} ] + reduction_tv->reorder({{3, 0}, {4, 1}, {0, 2}, {2, 3}, {1, 4}, {5, 5}}); + // T2_g[iblockIdx.x, ithreadIdx.x24, rblockIdx.y, rthreadIdx.y, rS{16}, + // iV25{4}] + + TransformPropagator::from(reduction_tv); + auto rfactor_tv = ir_utils::rfactorHelper(reduction_tv, {4}); + 
scheduler_utils::parallelizeAllLike(rfactor_tv, ir_utils::allTvs(&fusion)); + + tv0->computeAt(tv_avg, 2); + tv0->computeAt(cached_input, -2); + + cached_input->computeAt(rfactor_tv, 4, ComputeAtMode::BestEffort); + + for (auto tv : ir_utils::allTvs(&fusion)) { + if (tv == cached_input || tv == tv_avg || tv == tv_M2) { + continue; + } + tv->axis(-1)->parallelize(ParallelType::Serial); + } + + CompileOptions co; + co.index_mode = KernelIndexMode::INT32; + + FusionExecutor fe; + fe.compileFusion(&fusion, {}, LaunchParams(), co); + + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({X, Y, Y, Z}, options); + + auto cg_outputs = fe.runFusion({t0}, LaunchParams(-1, -1, -1, -1, -1, -1)); + + // by default Welford outputs sum of square diff so need to divide to get var + cg_outputs[1] = cg_outputs[1].div((float)(X * Y * Y)); + + auto at_mu = at::mean(t0.to(at::kDouble), {0, 1, 2}); + auto at_var = at::var(t0.to(at::kDouble), {0, 1, 2}, false); + + testValidate( + &fusion, + cg_outputs, + {t0}, + {at_mu, at_var}, + __LINE__, + __FILE__, + "", + LaunchParams(-1, -1, -1, -1, -1, -1)); +} + } // namespace jit } // namespace torch #endif // #if defined(USE_CUDA)
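The test divides the second Welford output by X * Y * Y because Welford produces the running sum of squared deviations (M2), and the population variance is M2 / N, which is what at::var with unbiased = false returns. A dependency-free check of that identity:

#include <cassert>
#include <cmath>
#include <vector>

int main() {
  std::vector<double> data = {1.0, 2.0, 4.0, 8.0, 16.0};

  // Plain Welford update producing avg and M2 (sum of squared deviations).
  double avg = 0.0, m2 = 0.0, n = 0.0;
  for (double x : data) {
    n += 1.0;
    const double delta = x - avg;
    avg += delta / n;
    m2 += delta * (x - avg); // running sum of squared deviations
  }

  // Direct population variance for reference.
  double mean = 0.0;
  for (double x : data)
    mean += x;
  mean /= data.size();
  double var = 0.0;
  for (double x : data)
    var += (x - mean) * (x - mean);
  var /= data.size();

  assert(std::abs(m2 / n - var) < 1e-12);
  return 0;
}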