diff --git a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h
index 24b70ebd06fd0..a832794a7c9b2 100644
--- a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h
+++ b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h
@@ -468,7 +468,9 @@ class TORCH_CUDA_CU_API TensorView : public Val {
     domain_ = td;
   }
 
+ public:
   void setComputeAt(unsigned int this_pos, bool decrease = false);
+
   void setMaxProducer(unsigned int this_pos, bool decrease = false);
diff --git a/torch/csrc/jit/codegen/cuda/runtime/welford.cu b/torch/csrc/jit/codegen/cuda/runtime/welford.cu
index 90502d72ce36d..d0ab5cf79db37 100644
--- a/torch/csrc/jit/codegen/cuda/runtime/welford.cu
+++ b/torch/csrc/jit/codegen/cuda/runtime/welford.cu
@@ -15,7 +15,12 @@ __inline__ __device__ void welfordCombine(
     return;
   }
   TN ab_N = a_N + b_N;
+#if 1
   T b_N_div_ab_N = ((T)(nvfuser_index_t)(b_N)) / ((T)(nvfuser_index_t)(ab_N));
+#else
+  // No perf change
+  T b_N_div_ab_N = a_N == b_N ? 0.5f : (((T)(nvfuser_index_t)(b_N)) / ((T)(nvfuser_index_t)(ab_N)));
+#endif
   T delta = b_avg - a_avg;
   a_avg += delta * b_N_div_ab_N;
   a_M2 += b_M2 + delta * delta * ((T)(nvfuser_index_t)(a_N)) * b_N_div_ab_N;
@@ -350,10 +355,17 @@ __device__ void gridWelford(
     grid_sync::sync<X_BLOCK, Y_BLOCK, Z_BLOCK, PERSISTENT_REDUCTION>(
         sync_flags[idx_in_grid_segment], grid_reduction_segment_size);
   } else {
+#if 0
     grid_sync::sync<X_BLOCK, Y_BLOCK, Z_BLOCK, PERSISTENT_REDUCTION>(
         sync_flags[idx_in_grid_segment],
         grid_reduction_segment_size,
         n_entrances_);
+#else
+    // Assumes separate sync flags are allocated for each call.
+    grid_sync::sync<X_BLOCK, Y_BLOCK, Z_BLOCK, PERSISTENT_REDUCTION>(
+        sync_flags[entrance_ind_ * grid_segment_size + idx_in_grid_segment],
+        grid_reduction_segment_size);
+#endif
   }
 
   bool last_block =
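For reference, the welfordCombine hunk above keeps the standard single-pass merge of two partial (avg, M2, N) triples; only how b_N / (a_N + b_N) is formed is toggled. A minimal host-side sketch of the same update rule, assuming double precision (the struct and function names here are illustrative, not from the patch):

struct WelfordTriple {
  double avg = 0.0; // running mean
  double M2 = 0.0;  // running sum of squared deviations from the mean
  long long N = 0;  // number of elements combined so far
};

// Merge partial result b into a, mirroring welfordCombine's update rule.
WelfordTriple combine(WelfordTriple a, const WelfordTriple& b) {
  if (b.N == 0) {
    return a;
  }
  const long long ab_N = a.N + b.N;
  const double b_frac = static_cast<double>(b.N) / static_cast<double>(ab_N);
  const double delta = b.avg - a.avg;
  a.avg += delta * b_frac;
  a.M2 += b.M2 + delta * delta * static_cast<double>(a.N) * b_frac;
  a.N = ab_N;
  return a;
}
// Once all partials are merged: variance = M2 / N (biased) or M2 / (N - 1) (unbiased).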
diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_fused_reduction.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu_fused_reduction.cpp
index 61ff194862d3b..5191c827432b6 100644
--- a/torch/csrc/jit/codegen/cuda/test/test_gpu_fused_reduction.cpp
+++ b/torch/csrc/jit/codegen/cuda/test/test_gpu_fused_reduction.cpp
@@ -29,6 +29,8 @@
 #include
 #include
 
+//#include
+
 // fuser and IR parser
 #include
 #include
@@ -1267,6 +1269,1070 @@ TEST_F(NVFuserTest, FusionPersistentBNBackwardAllreduce_CUDA) {
       fe.kernel(), outputs, aten_inputs, {at_grad_input}, __LINE__, __FILE__);
 }
 
+namespace {
+
+void clearL2Cache() {
+  // torch::NoGradGuard no_grad;
+  at::NoGradGuard no_grad;
+  auto l2_cache_size = at::cuda::getCurrentDeviceProperties()->l2CacheSize;
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+
+  auto l2_elems = l2_cache_size / 4;
+  auto t0 = at::empty(l2_elems, options);
+  auto t1 = at::clone(t0);
+};
+
+const std::vector<int64_t> default_shape({256, 56, 56, 64});
+
+std::vector<int64_t> getShape() {
+  auto shape = default_shape;
+  if (auto env = getenv("N")) {
+    shape.at(0) = atoi(env);
+  }
+  if (auto env = getenv("HW")) {
+    shape.at(1) = atoi(env);
+    shape.at(2) = atoi(env);
+  }
+  if (auto env = getenv("C")) {
+    shape.at(3) = atoi(env);
+  }
+  std::cout << "Shape: " << shape << std::endl;
+  return shape;
+}
+
+const std::vector<int> reduction_axes({0, 1, 2});
+const std::vector<int64_t> reduction_axes_at({0, 1, 2});
+
+const bool kTraining = true;
+const float kMomentum = 0.1;
+const float kEps = 1e-5;
+
+std::vector<TensorView*> batch_norm_1st(
+    TensorView* x,
+    TensorView* weight,
+    TensorView* bias,
+    TensorView* running_mean,
+    TensorView* running_var,
+    const bool kTraining,
+    Val* momentum,
+    Val* eps,
+    bool channels_last) {
+  auto fusion = FusionGuard::getCurFusion();
+
+  TORCH_INTERNAL_ASSERT(x != nullptr, "Input is invalid.");
+
+  TORCH_INTERNAL_ASSERT(
+      !((running_var == nullptr) ^ (running_mean == nullptr)),
+      "running stats should come in pairs");
+
+  TORCH_INTERNAL_ASSERT(
+      momentum != nullptr && momentum->getDataType().has_value() &&
+          momentum->getDataType().value() == DataType::Double,
+      "Momentum is not a valid Double.");
+
+  TORCH_INTERNAL_ASSERT(
+      eps != nullptr && eps->getDataType().has_value() &&
+          eps->getDataType().value() == DataType::Double,
+      "Epsilon (eps) is not a valid Double.");
+
+  // (B, C, H, W, D) tensor
+  // M = outer = channels
+  // N = reduction = B * H * W * D
+  // weight = bias = (C) tensor
+  const size_t kNumberOfDims =
+      TensorDomain::noReductions(x->getMaybeRFactorDomain()).size();
+  // channels-last format means the C dimension is at axis kNumberOfDims-1 of x
+  size_t c_axis = channels_last ? kNumberOfDims - 1 : 1;
+
+  std::vector<int> reduction_axes;
+  std::vector<bool> broadcast_mask(kNumberOfDims, false);
+  Val* num_features = IrBuilder::create<Double>(x->container(), 1);
+
+  for (const auto axis : c10::irange(kNumberOfDims)) {
+    if (axis != c_axis) {
+      reduction_axes.push_back(axis);
+      broadcast_mask[axis] = true;
+      num_features = mul(num_features, x->domain()->domain()[axis]->extent());
+    }
+  }
+
+  TensorView* y = nullptr;
+  TensorView* mean = nullptr;
+  TensorView* invstd = nullptr;
+#if 0
+  if (!getenv("welford")) {
+    auto welford_out = sum(x, reduction_axes);
+    return {welford_out};
+  }
+#endif
+  auto welford_out = Welford(x, reduction_axes);
+
+  // updating running mean and running var
+  if (running_mean != nullptr && running_var != nullptr) {
+    // Note: kTraining is true here!
+    TORCH_INTERNAL_ASSERT(
+        kTraining,
+        "When running stats are provided, batch stats should only be computed during training");
+
+    auto rev_momentum =
+        sub(IrBuilder::create<Double>(x->container(), 1.0), momentum);
+    // auto current_mean_hat = mul(welford_out.avg, momentum);
+    // auto mean_hat = mul(running_mean, rev_momentum);
+    // auto new_mean_hat = add(mean_hat, current_mean_hat);
+
+    // auto num_feature_decrement = sub(num_features, x->container()->oneVal());
+    // auto unbiased_var =
+    //     mul(welford_out.var_sum, reciprocal(num_feature_decrement));
+    // auto current_var_hat = mul(unbiased_var, momentum);
+    // auto var_hat = mul(running_var, rev_momentum);
+    // auto new_var_hat = add(var_hat, current_var_hat);
+
+    // When inputs have been cast by the parser, we want to alias the output to
+    // the pre-cast input, so we can still update running stats.
+    auto cast_to_input_dtype = [fusion](Val* cast_input, Val* aliased_output) {
+      auto unary_op = cast_input->definition();
+      TORCH_INTERNAL_ASSERT(
+          unary_op->isA<UnaryOp>() &&
+              unary_op->as<UnaryOp>()->getUnaryOpType() == UnaryOpType::Cast,
+          "check for cast op");
+      auto input_to_cast = unary_op->input(0);
+      TORCH_INTERNAL_ASSERT(
+          input_to_cast->isFusionInput(),
+          "IO_tensor batch_norm::running_stats can only update an input tensor to the fusion");
+      auto rm_dtype = input_to_cast->getDataType();
+      TORCH_INTERNAL_ASSERT(
+          rm_dtype.has_value(),
+          "Input running stats must have dtype defined");
+      auto cast_output = castOp(*rm_dtype, aliased_output);
+
+      fusion->aliasOutputToInput(cast_output, input_to_cast);
+    };
+
+#if 0
+    if (running_mean->isFusionInput()) {
+      fusion->aliasOutputToInput(new_mean_hat, running_mean);
+    } else {
+      cast_to_input_dtype(running_mean, new_mean_hat);
+    }
+
+    if (running_var->isFusionInput()) {
+      fusion->aliasOutputToInput(new_var_hat, running_var);
+    } else {
+      cast_to_input_dtype(running_var, new_var_hat);
+    }
+#endif
+  }
+
+  mean = welford_out.avg;
+
+  auto var = mul(welford_out.var_sum, reciprocal(num_features));
+  auto var_eps = add(var, eps);
+  invstd = rsqrt(var_eps);
+  auto invstd_bcast = broadcast(invstd, broadcast_mask);
+
+  y = invstd_bcast;
+
+  // return {y, mean};
+  return {y};
+}
+
+static void setupBatchNorm_nhwc(Fusion* fusion, DataType dtype) {
+  TORCH_INTERNAL_ASSERT(dtype == DataType::Float || dtype == DataType::Half);
+
+  FusionGuard fg(fusion);
+
+  // setup fusion
+  auto input = makeContigTensor(4, dtype);
+  auto weight = makeContigTensor(1, dtype);
+  auto bias = makeContigTensor(1, dtype);
+  auto running_mean = makeContigTensor(1, DataType::Float);
+  auto running_var = makeContigTensor(1, DataType::Float);
+
+  fusion->addInput(input);
+  fusion->addInput(weight);
+  fusion->addInput(bias);
+  fusion->addInput(running_mean);
+  fusion->addInput(running_var);
+
+  if (dtype == DataType::Half) {
+    input = castOp(DataType::Float, input);
+    weight = castOp(DataType::Float, weight);
+    bias = castOp(DataType::Float, bias);
+  }
+
+  auto momentum_ptr = IrBuilder::create<Double>(kMomentum);
+  auto eps_ptr = IrBuilder::create<Double>(kEps);
+
+  auto results = batch_norm_1st(
+      input,
+      weight,
+      bias,
+      running_mean,
+      running_var,
+      kTraining,
+      momentum_ptr,
+      eps_ptr,
+      true);
+
+  for (auto tv : results) {
+    if (dtype == DataType::Half) {
+      tv = castOp(DataType::Half, tv);
+    }
+    fusion->addOutput(tv);
+  }
+}
+
+} // namespace
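As a reference for the scalar math batch_norm_1st assembles as fusion IR: the Welford var_sum is divided by the number of reduced elements to get the biased variance, and the inverse standard deviation follows. A minimal host-side sketch under that reading (invstdFromWelford is an illustrative name, not a function in this patch):

#include <cmath>

// var = var_sum / num_features (biased), invstd = rsqrt(var + eps).
// num_features is the product of all non-channel extents (B * H * W for NHWC).
double invstdFromWelford(double var_sum, double num_features, double eps) {
  const double var = var_sum / num_features; // biased variance
  return 1.0 / std::sqrt(var + eps);
}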
+
+TEST_F(NVFuserTest, FusionGridReductionPerf0_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  auto tv0 = makeContigTensor(4, DataType::Half);
+  fusion.addInput(tv0);
+
+  auto tv1 = castOp(DataType::Float, tv0);
+  auto tv2 = castOp(DataType::Half, tv1);
+  fusion.addOutput(tv2);
+
+  auto tv0_cache = tv0->cacheAfter();
+  auto tv2_cache = tv2->cacheBefore();
+
+  const int vec = 2;
+
+  tv2->merge(0);
+  tv2->merge(0);
+  tv2->merge(0);
+  if (vec > 1) {
+    tv2->split(0, vec);
+  }
+  tv2->split(0, 256);
+
+  TransformPropagator::from(tv2);
+
+  tv0->computeAt(tv2, 2);
+
+  tv2->axis(0)->parallelize(ParallelType::BIDx);
+  tv2->axis(1)->parallelize(ParallelType::TIDx);
+
+  scheduler_utils::parallelizeAllLike(tv2, ir_utils::allTvs(&fusion));
+
+  if (vec > 1) {
+    tv0_cache->axis(2)->parallelize(ParallelType::Vectorize);
+    tv2->axis(2)->parallelize(ParallelType::Vectorize);
+  }
+
+  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
+  auto t0 = at::randn(getShape(), options);
+  std::vector<IValue> aten_inputs({t0});
+
+  FusionExecutor fe;
+  fe.compileFusion(&fusion, aten_inputs);
+  auto outputs = fe.runFusion(aten_inputs);
+
+  auto ref = t0.to(c10::kFloat).to(c10::kHalf);
+
+  testValidate(fe.kernel(), outputs, aten_inputs, {ref}, __LINE__, __FILE__);
+
+  for (int i = 0; i < 10; ++i) {
+    clearL2Cache();
+    outputs = fe.runFusion(aten_inputs);
+  }
+}
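The copy test above flattens the 4-D input into one iteration domain and splits it into [BIDx, TIDx, vec]. A small sketch of the resulting index mapping, assuming the defaults tidx = 256 and vec = 2 (flatIndex is a hypothetical helper, not part of the patch):

// After the three merges and split(0, vec), split(0, 256), the iteration
// domain is [N / (256 * vec), 256, vec], mapped to [BIDx, TIDx, Vectorize].
int64_t flatIndex(
    int64_t bid,
    int64_t tid,
    int64_t lane,
    int64_t tidx = 256,
    int64_t vec = 2) {
  return (bid * tidx + tid) * vec + lane;
}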
+
+TEST_F(NVFuserTest, FusionGridReductionPerf1_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  auto tv0 = makeContigTensor(4, DataType::Half);
+  fusion.addInput(tv0);
+
+  auto tv1 = castOp(DataType::Float, tv0);
+  auto tv2 = sum(tv1, reduction_axes);
+  auto tv3 = castOp(DataType::Half, tv2);
+  fusion.addOutput(tv3);
+
+  fusion.printMath();
+  fusion.printKernel();
+
+  const int tidy = 16;
+
+  auto tv0_cache = tv0->cacheAfter();
+
+  tv2->merge(0);
+  tv2->merge(0);
+  tv2->split(0, tidy);
+
+  TransformPropagator::from(tv2);
+
+  // tv2->computeAt(tv3, -1);
+  // tv0->computeAt(tv2, -1);
+
+  fusion.printMath();
+
+  tv2->axis(0)->parallelize(ParallelType::BIDx);
+  tv2->axis(1)->parallelize(ParallelType::TIDy);
+  tv2->axis(2)->parallelize(ParallelType::TIDx);
+  scheduler_utils::parallelizeAllLike(tv2, ir_utils::allTvs(&fusion));
+
+  fusion.printMath();
+  fusion.printKernel();
+
+  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
+  auto t0 = at::randn(getShape(), options);
+  std::vector<IValue> aten_inputs({t0});
+
+  FusionExecutor fe;
+  fe.compileFusion(&fusion, aten_inputs);
+  auto outputs = fe.runFusion(aten_inputs);
+
+  auto ref = t0.to(c10::kFloat).sum(reduction_axes_at).to(c10::kHalf);
+
+  testValidate(fe.kernel(), outputs, aten_inputs, {ref}, __LINE__, __FILE__);
+
+  for (int i = 0; i < 10; ++i) {
+    clearL2Cache();
+    outputs = fe.runFusion(aten_inputs);
+  }
+}
+
+TEST_F(NVFuserTest, FusionGridReductionPerf2_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  auto tv0 = makeContigTensor(4, DataType::Half);
+  fusion.addInput(tv0);
+
+  auto tv1 = castOp(DataType::Float, tv0);
+  auto tv2 = sum(tv1, reduction_axes);
+  auto tv3 = castOp(DataType::Half, tv2);
+  fusion.addOutput(tv3);
+
+  fusion.printMath();
+
+  const int tidy = 8;
+  const int bidx = 72 * 4;
+
+  auto tv0_cache = tv0->cacheAfter();
+
+  tv2->merge(0);
+  tv2->merge(0);
+  tv2->split(0, tidy);
+  tv2->split(0, bidx, false);
+
+  tv2->reorder({{1, 2}, {2, 1}});
+
+  TransformPropagator::from(tv2);
+
+  fusion.printMath();
+
+  auto tv2_rf = tv2->rFactor({2});
+
+  tv0->computeAt(tv2_rf, -1);
+
+  fusion.printMath();
+
+  tv2_rf->axis(0)->parallelize(ParallelType::BIDx);
+  tv2_rf->axis(1)->parallelize(ParallelType::TIDy);
+  tv2_rf->axis(2)->parallelize(ParallelType::Serial);
+  tv2_rf->axis(3)->parallelize(ParallelType::TIDx);
+  scheduler_utils::parallelizeAllLike(tv2_rf, ir_utils::allTvs(&fusion));
+
+  fusion.printMath();
+  fusion.printKernel();
+
+  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
+  auto t0 = at::randn(getShape(), options);
+  std::vector<IValue> aten_inputs({t0});
+
+  FusionExecutor fe;
+  fe.compileFusion(&fusion, aten_inputs);
+  auto outputs = fe.runFusion(aten_inputs);
+
+  auto ref = t0.to(c10::kFloat).sum(reduction_axes_at).to(c10::kHalf);
+
+  testValidate(fe.kernel(), outputs, aten_inputs, {ref}, __LINE__, __FILE__);
+
+  for (int i = 0; i < 10; ++i) {
+    clearL2Cache();
+    outputs = fe.runFusion(aten_inputs);
+  }
+}
+
+TEST_F(NVFuserTest, FusionGridReductionPerf3_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  auto tv0 = makeContigTensor(4, DataType::Half);
+  fusion.addInput(tv0);
+
+  auto tv1 = castOp(DataType::Float, tv0);
+  auto tv2 = sum(tv1, reduction_axes);
+  auto tv3 = castOp(DataType::Half, tv2);
+  fusion.addOutput(tv3);
+
+  fusion.printMath();
+
+  auto shape = getShape();
+
+  const int tidx = 32;
+  const int vec = shape.back() / tidx;
+  TORCH_CHECK(shape.back() == tidx * vec);
+  const int tidy = 16;
+  const int bidx = 72 * 4;
+
+  auto tv0_cache = tv0->cacheAfter();
+  auto tv3_cache = tv3->cacheBefore();
+
+  tv2->merge(0);
+  tv2->merge(0);
+  tv2->split(0, tidy);
+  tv2->split(0, bidx, false);
+  tv2->split(-1, vec);
+
+  tv2->reorder({{1, 2}, {2, 1}});
+
+  TransformPropagator::from(tv2);
+
+  fusion.printMath();
+
+  auto tv2_rf = tv2->rFactor({2});
+
+  tv0->computeAt(tv2_rf, -2);
+
+  fusion.printMath();
+
+  tv2_rf->axis(0)->parallelize(ParallelType::BIDx);
+  tv2_rf->axis(1)->parallelize(ParallelType::TIDy);
+  tv2_rf->axis(2)->parallelize(ParallelType::Serial);
+  tv2_rf->axis(3)->parallelize(ParallelType::TIDx);
+  scheduler_utils::parallelizeAllLike(tv2_rf, ir_utils::allTvs(&fusion));
+
+  tv0_cache->axis(-1)->parallelize(ParallelType::Vectorize);
+  tv3->axis(-1)->parallelize(ParallelType::Vectorize);
+
+  fusion.printMath();
+  fusion.printKernel();
+
+  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
+  auto t0 = at::randn(shape, options);
+  std::vector<IValue> aten_inputs({t0});
+
+  FusionExecutor fe;
+  fe.compileFusion(&fusion, aten_inputs);
+  auto outputs = fe.runFusion(aten_inputs);
+
+  auto ref = t0.to(c10::kFloat).sum(reduction_axes_at).to(c10::kHalf);
+
+  testValidate(fe.kernel(), outputs, aten_inputs, {ref}, __LINE__, __FILE__);
+
+  for (int i = 0; i < 10; ++i) {
+    clearL2Cache();
+    outputs = fe.runFusion(aten_inputs);
+  }
+}
+
+TEST_F(NVFuserTest, FusionBNBaseline_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  setupBatchNorm_nhwc(&fusion, DataType::Half);
+  fusion.printMath();
+
+  auto input = fusion.inputs().at(0)->as<TensorView>();
+  TORCH_CHECK(input->nDims() == 4);
+
+  TensorView* output = nullptr;
+  for (auto out_tv : ir_utils::filterByType<TensorView>(fusion.outputs())) {
+    if (out_tv->nDims() == 4) {
+      TORCH_CHECK(output == nullptr);
+      output = out_tv;
+    }
+  }
+
+  TensorView* input_cache = nullptr;
+
+  for (auto tv : ir_utils::filterByType<TensorView>(fusion.inputs())) {
+    auto cache = tv->cacheAfter();
+    if (tv == input) {
+      input_cache = cache;
+    }
+  }
+
+  for (auto tv : ir_utils::filterByType<TensorView>(fusion.outputs())) {
+    tv->cacheBefore();
+  }
+
+  const int tidx = 32;
+  // const int tidx = 64;
+  // const int tidy = 16;
+  int tidy = 8;
+  // int bidx = 72 * 4;
+  int bidy = 1568;
+  int unswitch = 4;
+  if (getenv("TIDY")) {
+    tidy = atoi(getenv("TIDY"));
+  }
+  if (getenv("BIDY")) {
+    bidy = atoi(getenv("BIDY"));
+  }
+
+  input->merge(1, 2);
+  input->merge(0, 1);
+
+  input->reorder({{0, 1}, {1, 0}});
+
+  input->split(1, tidy);
+  input->split(1, unswitch);
+  input->split(1, 1);
+  input->split(1, bidy, false);
+  // Move tidy before unswitch
+  input->reorder({{3, 4}, {4, 5}, {5, 3}});
+
+  input->split(0, tidx);
+
+  TransformPropagator::from(input);
+
+  // Rfactor welford op
+  WelfordOp* wop = nullptr;
+  for (auto expr : fusion.exprs()) {
+    if (expr->isA<WelfordOp>()) {
+      TORCH_CHECK(wop == nullptr);
+      wop = expr->as<WelfordOp>();
+    }
+  }
+  TORCH_CHECK(wop != nullptr);
+
+  auto avg_rf =
+      ir_utils::rfactorHelper(wop->outAvg()->as<TensorView>(), {3, 5, 6});
+  auto var_rf = avg_rf->definition()->outputs().at(1)->as<TensorView>();
+  auto n_rf = avg_rf->definition()->outputs().at(2)->as<TensorView>();
+
+  input_cache->setComputeAt(4);
+  auto input_cache_float =
+      input_cache->uses().at(0)->outputs().at(0)->as<TensorView>();
+  input_cache_float->setMaxProducer(4);
+  input_cache_float->setComputeAt(4);
+  avg_rf->setMaxProducer(4);
+  var_rf->setMaxProducer(4);
+  n_rf->setMaxProducer(4);
+
+  input->axis(0)->parallelize(ParallelType::BIDx);
+  input->axis(1)->parallelize(ParallelType::TIDx);
+  input->axis(2)->parallelize(ParallelType::BIDy);
+  input->axis(3)->parallelize(ParallelType::Serial);
+  input->axis(4)->parallelize(ParallelType::TIDy);
+  scheduler_utils::parallelizeAllLike(input, ir_utils::allTvs(&fusion));
+
+  fusion.printMath();
+  fusion.printKernel();
+
+  auto shape = getShape();
+
+  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
+  auto options_float =
+      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  auto at_input = at::randn(shape, options);
+  auto at_weight = at::randn({shape.back()}, options);
+  auto at_bias = at::randn({shape.back()}, options);
+  auto at_running_mean = at::randn({shape.back()}, options_float);
+  auto at_running_var = at::randn({shape.back()}, options_float);
+  std::vector<IValue> aten_inputs(
+      {at_input, at_weight, at_bias, at_running_mean, at_running_var});
+
+  LaunchParams launch_constraints;
+  CompileOptions compile_options;
+  compile_options.index_mode = KernelIndexMode::INT32;
+
+  FusionExecutor fe;
+  fe.compileFusion(&fusion, aten_inputs, launch_constraints, compile_options);
+  auto outputs = fe.runFusion(aten_inputs);
+
+  for (int i = 0; i < 10; ++i) {
+    clearL2Cache();
+    outputs = fe.runFusion(aten_inputs);
+  }
+}
+
+TEST_F(NVFuserTest, FusionBNOpt1_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  setupBatchNorm_nhwc(&fusion, DataType::Half);
+  fusion.printMath();
+
+  auto input = fusion.inputs().at(0)->as<TensorView>();
+  TORCH_CHECK(input->nDims() == 4);
+
+  TensorView* output = nullptr;
+  for (auto out_tv : ir_utils::filterByType<TensorView>(fusion.outputs())) {
+    if (out_tv->nDims() == 4) {
+      TORCH_CHECK(output == nullptr);
+      output = out_tv;
+    }
+  }
+
+  TensorView* input_cache = nullptr;
+
+  for (auto tv : ir_utils::filterByType<TensorView>(fusion.inputs())) {
+    auto cache = tv->cacheAfter();
+    if (tv == input) {
+      input_cache = cache;
+    }
+  }
+  TORCH_CHECK(input_cache != nullptr);
+
+  TensorView* output_cache = nullptr;
+
+  for (auto tv : ir_utils::filterByType<TensorView>(fusion.outputs())) {
+    auto cache = tv->cacheBefore();
+    if (tv == output) {
+      output_cache = cache;
+    }
+  }
+  TORCH_CHECK(output_cache != nullptr);
+
+  int tidx = 32;
+  int tidy = 8;
+  int bidy = 72 * 4; // #SMs * 4
+  int unroll = 4;
+  int vec = 2;
+
+  if (auto env = getenv("TIDX")) {
+    tidx = atoi(env);
+  }
+  if (auto env = getenv("TIDY")) {
+    tidy = atoi(env);
+  }
+  if (auto env = getenv("BIDY")) {
+    bidy = atoi(env);
+  }
+  if (auto env = getenv("VEC")) {
+    vec = atoi(env);
+  }
+
+  input->merge(1, 2);
+  input->merge(0, 1);
+
+  input->split(0, tidy);
+  input->split(0, unroll);
+  input->split(0, 1);
+  input->split(0, bidy, false);
+  // Move tidy before unswitch
+  input->reorder({{2, 3}, {3, 4}, {4, 2}});
+
+  if (vec > 1) {
+    input->split(-1, vec);
+    input->split(-2, tidx);
+  } else {
+    input->split(-1, tidx);
+  }
+
+  TransformPropagator::from(input);
+
+  // Rfactor welford op
+  WelfordOp* wop = nullptr;
+  for (auto expr : fusion.exprs()) {
+    if (expr->isA<WelfordOp>()) {
+      TORCH_CHECK(wop == nullptr);
+      wop = expr->as<WelfordOp>();
+    }
+  }
+
+  auto avg_rf =
+      ir_utils::rfactorHelper(wop->outAvg()->as<TensorView>(), {1, 3, 4});
+  auto var_rf = avg_rf->definition()->outputs().at(1)->as<TensorView>();
+  auto n_rf = avg_rf->definition()->outputs().at(2)->as<TensorView>();
+
+  int ca_pos = 2;
+  input_cache->setComputeAt(ca_pos);
+  auto input_cache_float =
+      input_cache->uses().at(0)->outputs().at(0)->as<TensorView>();
+  input_cache_float->setMaxProducer(ca_pos);
+  input_cache_float->setComputeAt(ca_pos);
+  avg_rf->setMaxProducer(ca_pos);
+  var_rf->setMaxProducer(ca_pos);
+  n_rf->setMaxProducer(ca_pos);
+
+  input->axis(0)->parallelize(ParallelType::BIDy);
+  input->axis(1)->parallelize(ParallelType::Serial);
+  input->axis(2)->parallelize(ParallelType::TIDy);
+  input->axis(5)->parallelize(ParallelType::BIDx);
+  input->axis(6)->parallelize(ParallelType::TIDx);
+
+  scheduler_utils::parallelizeAllLike(input, ir_utils::allTvs(&fusion));
+
+  if (vec > 1 && getenv("SKIP_VEC") == nullptr) {
+    input_cache->axis(-1)->parallelize(ParallelType::Vectorize);
+    output->axis(-1)->parallelize(ParallelType::Vectorize);
+  }
+
+  fusion.printMath();
+  fusion.printKernel();
+
+  auto shape = getShape();
+
+  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
+  auto options_float =
+      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  auto at_input = at::randn(shape, options);
+  auto at_weight = at::randn({shape.back()}, options);
+  auto at_bias = at::randn({shape.back()}, options);
+  auto at_running_mean = at::randn({shape.back()}, options_float);
+  auto at_running_var = at::randn({shape.back()}, options_float);
+  std::vector<IValue> aten_inputs(
+      {at_input, at_weight, at_bias, at_running_mean, at_running_var});
+
+  LaunchParams launch_constraints;
+  CompileOptions compile_options;
+  compile_options.index_mode = KernelIndexMode::INT32;
+
+  FusionExecutor fe;
+  fe.compileFusion(&fusion, aten_inputs, launch_constraints, compile_options);
+  auto outputs = fe.runFusion(aten_inputs);
+
+  for (int i = 0; i < 10; ++i) {
+    clearL2Cache();
+    outputs = fe.runFusion(aten_inputs);
+  }
+}
+
+TEST_F(NVFuserTest, FusionBNPerf5_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  // const bool is_welford = getenv("welford") != nullptr;
+  const bool is_welford = true;
+
+  setupBatchNorm_nhwc(&fusion, DataType::Half);
+  fusion.printMath();
+
+  auto input = fusion.inputs().at(0)->as<TensorView>();
+  TORCH_CHECK(input->nDims() == 4);
+
+  TensorView* output = nullptr;
+  for (auto out_tv : ir_utils::filterByType<TensorView>(fusion.outputs())) {
+    if (out_tv->nDims() == 4) {
+      TORCH_CHECK(output == nullptr);
+      output = out_tv;
+    }
+  }
+
+  TensorView* input_cache = nullptr;
+
+  for (auto tv : ir_utils::filterByType<TensorView>(fusion.inputs())) {
+    auto cache = tv->cacheAfter();
+    if (tv == input) {
+      input_cache = cache;
+    }
+  }
+
+  for (auto tv : ir_utils::filterByType<TensorView>(fusion.outputs())) {
+    tv->cacheBefore();
+  }
+
+  auto shape = getShape();
+
+  const int tidx = 32;
+  // const int tidx = 64;
+  const int vec = shape.back() / tidx;
+  TORCH_CHECK(shape.back() == tidx * vec);
+  // const int tidy = 16;
+  int tidy = 16;
+  int bidx = 72 * 4;
+  if (getenv("TIDY")) {
+    tidy = atoi(getenv("TIDY"));
+  }
+  if (getenv("BIDX")) {
+    bidx = atoi(getenv("BIDX"));
+  }
+
+  // input->merge(1, 2);
+  // input->merge(0, 1);
+  input->merge(0);
+  input->merge(0);
+  input->split(0, tidy);
+  input->split(0, bidx, false);
+
+  if (vec > 1) {
+    input->split(-1, vec);
+  }
+
+  input->reorder({{1, 2}, {2, 1}});
+
+  TransformPropagator::from(input);
+
+  if (is_welford) {
+    // Rfactor welford op
+    WelfordOp* wop = nullptr;
+    for (auto expr : fusion.exprs()) {
+      if (expr->isA<WelfordOp>()) {
+        TORCH_CHECK(wop == nullptr);
+        wop = expr->as<WelfordOp>();
+      }
+    }
+
+    std::cerr << wop->toString();
+
+    auto avg_rf =
+        ir_utils::rfactorHelper(wop->outAvg()->as<TensorView>(), {2});
+    auto var_rf = avg_rf->definition()->outputs().at(1)->as<TensorView>();
+    auto n_rf = avg_rf->definition()->outputs().at(2)->as<TensorView>();
+
+    fusion.printMath();
+
+    input_cache->setComputeAt(3);
+    auto input_cache_float =
+        input_cache->uses().at(0)->outputs().at(0)->as<TensorView>();
+    input_cache_float->setMaxProducer(3);
+    input_cache_float->setComputeAt(3);
+    avg_rf->setMaxProducer(3);
+    var_rf->setMaxProducer(3);
+    n_rf->setMaxProducer(3);
+  } else {
+    // Rfactor reduction op
+    ReductionOp* rop = nullptr;
+    for (auto expr : fusion.exprs()) {
+      if (expr->isA<ReductionOp>()) {
+        TORCH_CHECK(rop == nullptr);
+        rop = expr->as<ReductionOp>();
+      }
+    }
+
+    auto avg_rf = rop->out()->as<TensorView>()->rFactor({2});
+
+    input_cache->setComputeAt(3);
+    auto input_cache_float =
+        input_cache->uses().at(0)->outputs().at(0)->as<TensorView>();
+    input_cache_float->setMaxProducer(3);
+    input_cache_float->setComputeAt(3);
+    avg_rf->setMaxProducer(3);
+  }
+
+  fusion.printMath();
+
+  input->axis(0)->parallelize(ParallelType::BIDx);
+  input->axis(1)->parallelize(ParallelType::TIDy);
+  input->axis(2)->parallelize(ParallelType::Serial);
+  input->axis(3)->parallelize(ParallelType::TIDx);
+  scheduler_utils::parallelizeAllLike(input, ir_utils::allTvs(&fusion));
+
+  if (vec > 1) {
+    input_cache->axis(-1)->parallelize(ParallelType::Vectorize);
+  }
+
+  fusion.printMath();
+  fusion.printKernel();
+
+  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
+  auto options_float =
+      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  auto at_input = at::randn(shape, options);
+  auto at_weight = at::randn({shape.back()}, options);
+  auto at_bias = at::randn({shape.back()}, options);
+  auto at_running_mean = at::randn({shape.back()}, options_float);
+  auto at_running_var = at::randn({shape.back()}, options_float);
+  std::vector<IValue> aten_inputs(
+      {at_input, at_weight, at_bias, at_running_mean, at_running_var});
+
+  LaunchParams launch_constraints;
+  CompileOptions compile_options;
+  compile_options.index_mode = KernelIndexMode::INT32;
+
+  FusionExecutor fe;
+  fe.compileFusion(&fusion, aten_inputs, launch_constraints, compile_options);
+  auto outputs = fe.runFusion(aten_inputs);
+
+  for (int i = 0; i < 10; ++i) {
+    clearL2Cache();
+    outputs = fe.runFusion(aten_inputs);
+  }
+}
+
+TEST_F(NVFuserTest, FusionWelfordInnerReductionPerf_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  bool is_welford = true;
+
+  auto shape = getShape();
+
+  auto tv0 = makeContigTensor(4, DataType::Half);
+  fusion.addInput(tv0);
+
+  auto tv1 = castOp(DataType::Float, tv0);
+  auto tvs = Welford(tv1, {1, 2, 3});
+  auto tv2 = tvs.avg;
+  auto tv3 = tvs.var_sum;
+  auto tv4 = tvs.n;
+  auto welford_out = tv2;
+  auto tv5 = castOp(DataType::Half, welford_out);
+  fusion.addOutput(tv5);
+  auto welford_out2 = tv3;
+  auto divided = div(
+      welford_out2,
+      IrBuilder::create<Double>(shape.at(0) * shape.at(1) * shape.at(2)));
+  auto tv6 = castOp(DataType::Half, divided);
+  fusion.addOutput(tv6);
+
+  fusion.printMath();
+
+  auto input = tv0;
+  auto output = tv5;
+
+  auto input_cache = tv0->cacheAfter();
+  output->cacheBefore();
+
+  const int tidx = 256;
+  // const int tidx = 64;
+  const int vec = 4;
+  // const int vec = 1;
+  // TORCH_CHECK(shape.back() == tidx * vec);
+  const int tidy = 16;
+  const int bidx = 72 * 2;
+
+  input->merge(-2, -1);
+  input->merge(-2, -1);
+
+  if (vec > 1) {
+    input->split(1, vec);
+  }
+
+  input->split(1, tidx);
+  input->split(1, bidx, false);
+
+  TransformPropagator::from(input);
+
+  auto rf = ir_utils::rfactorHelper(welford_out, {2, -1});
+
+  fusion.printMath();
+
+  input_cache->setComputeAt(4);
+  auto input_cache_float =
+      input_cache->uses().at(0)->outputs().at(0)->as<TensorView>();
+  input_cache_float->setMaxProducer(4);
+  input_cache_float->setComputeAt(4);
+  for (auto rf_tv :
+       ir_utils::filterByType<TensorView>(rf->definition()->outputs())) {
+    rf_tv->setMaxProducer(4);
+  }
+
+  fusion.printMath();
+
+  input->axis(0)->parallelize(ParallelType::BIDy);
+  input->axis(1)->parallelize(ParallelType::BIDx);
+  input->axis(2)->parallelize(ParallelType::Serial);
+  input->axis(3)->parallelize(ParallelType::TIDx);
+  scheduler_utils::parallelizeAllLike(input, ir_utils::allTvs(&fusion));
+
+  if (vec > 1) {
+    input_cache->axis(-1)->parallelize(ParallelType::Vectorize);
+  }
+
+  fusion.printMath();
+  fusion.printKernel();
+
+  auto shape_r = shape;
+  std::reverse(shape_r.begin(), shape_r.end());
+
+  std::cerr << "Shape: " << shape_r << std::endl;
+
+  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
+  auto options_float =
+      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  auto at_input = at::randn(shape_r, options);
+  std::vector<IValue> aten_inputs({at_input});
+
+  LaunchParams launch_constraints;
+  CompileOptions compile_options;
+  compile_options.index_mode = KernelIndexMode::INT32;
+
+  FusionExecutor fe;
+  fe.compileFusion(&fusion, aten_inputs, launch_constraints, compile_options);
+  auto outputs = fe.runFusion(aten_inputs);
+
+  auto ref = at_input.to(c10::kDouble).mean({1, 2, 3}).to(c10::kHalf);
+  auto ref_var =
+      at_input.to(c10::kDouble).var({1, 2, 3}, false).to(c10::kHalf);
+
+  testValidate(
+      fe.kernel(), outputs, aten_inputs, {ref, ref_var}, __LINE__, __FILE__);
+
+  for (int i = 0; i < 10; ++i) {
+    clearL2Cache();
+    outputs = fe.runFusion(aten_inputs);
+  }
+}
+
+TEST_F(NVFuserTest, FusionHalfVecPerf_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  auto tv0 = makeContigTensor(1, DataType::Half);
+  fusion.addInput(tv0);
+
+  auto tv1 = set(tv0);
+  auto tv2 = set(tv1);
+  fusion.addOutput(tv2);
+
+  int tidx = 256;
+  int tidy = 4;
+  int vec = 2;
+  int bidx = 1024;
+  int seq = 32;
+
+  if (auto env = getenv("TIDX")) {
+    tidx = atoi(env);
+  }
+  if (auto env = getenv("TIDY")) {
+    tidy = atoi(env);
+  }
+  if (auto env = getenv("VEC")) {
+    vec = atoi(env);
+  }
+  if (auto env = getenv("BIDX")) {
+    bidx = atoi(env);
+  }
+  if (auto env = getenv("SEQ")) {
+    seq = atoi(env);
+  }
+
+  const int len = tidx * vec * bidx * seq;
+
+  tv2->split(0, vec * tidx);
+
+  if (vec > 1) {
+    tv2->split(1, vec);
+  }
+
+  tv2->split(0, tidy);
+  tv2->split(0, seq);
+
+  TransformPropagator::from(tv2);
+
+  tv2->axis(0)->parallelize(ParallelType::BIDx);
+  tv2->axis(1)->parallelize(ParallelType::Serial);
+  tv2->axis(2)->parallelize(ParallelType::TIDy);
+  tv2->axis(3)->parallelize(ParallelType::TIDx);
+  scheduler_utils::parallelizeAllLike(tv2, ir_utils::allTvs(&fusion));
+
+  if (vec > 1 && !getenv("SKIP_VEC")) {
+    tv1->axis(-1)->parallelize(ParallelType::Vectorize);
+    tv2->axis(-1)->parallelize(ParallelType::Vectorize);
+  }
+
+  fusion.printMath();
+  fusion.printKernel();
+
+  std::cerr << "Buffer length: " << len << std::endl;
+
+  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
+  auto options_float =
+      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  auto at_input = at::randn({len}, options);
+  std::vector<IValue> aten_inputs({at_input});
+
+  LaunchParams launch_constraints;
+  CompileOptions compile_options;
+  compile_options.index_mode = KernelIndexMode::INT32;
+
+  FusionExecutor fe;
+  fe.compileFusion(&fusion, aten_inputs, launch_constraints, compile_options);
+  auto outputs = fe.runFusion(aten_inputs);
+
+  for (int i = 0; i < 10; ++i) {
+    clearL2Cache();
+    outputs = fe.runFusion(aten_inputs);
+  }
+}
+
 } // namespace jit
 } // namespace torch
 
 #endif // #if defined(USE_CUDA)
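The trailing loops in each test re-run the compiled fusion ten times with clearL2Cache() in between, so an external profiler sees cold-cache kernel times. If in-test timing were wanted instead, a CUDA-event helper along these lines could be used (a sketch under that assumption; timeMs is a hypothetical name, not part of the patch):

#include <cuda_runtime.h>

#include <functional>

// Time a callable on the current CUDA stream and return elapsed milliseconds.
float timeMs(const std::function<void()>& fn) {
  cudaEvent_t start = nullptr;
  cudaEvent_t stop = nullptr;
  cudaEventCreate(&start);
  cudaEventCreate(&stop);
  cudaEventRecord(start);
  fn();
  cudaEventRecord(stop);
  cudaEventSynchronize(stop);
  float ms = 0.f;
  cudaEventElapsedTime(&ms, start, stop);
  cudaEventDestroy(start);
  cudaEventDestroy(stop);
  return ms;
}

// Usage against the repeat-run loops above:
//   clearL2Cache();
//   auto ms = timeMs([&] { outputs = fe.runFusion(aten_inputs); });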