Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable grid reductions within loops. #1681

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion torch/csrc/jit/codegen/cuda/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1342,6 +1342,8 @@ class CudaKernelGenerator : private OptOutConstDispatch {
}
// Init val
func_args.arg(genCall(data_type, genInline(grop->init())));
func_args.arg(genInline(grop->entrance_index()));
func_args.arg(genInline(grop->entrances()));

indent() << "reduction::gridReduce<" << template_args << ">(\n";
indent() << kTab << func_args << ");\n";
Expand Down Expand Up @@ -1658,7 +1660,10 @@ class CudaKernelGenerator : private OptOutConstDispatch {
indent() << kTab << read_pred << ",\n";
}
// TODO : init value support or remove.
indent() << kTab << data_type << "(0));\n";
indent() << kTab << data_type << "(0),\n";
indent() << kTab << genInline(gwop->entrance_index()) << ",\n";
indent() << kTab << genInline(gwop->entrances());
code_ << ");\n";
}

void generateGridAllreduce(const kir::GridWelford* gwop) {
Expand Down
14 changes: 7 additions & 7 deletions torch/csrc/jit/codegen/cuda/kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,15 +115,17 @@ class KernelIrScanner : private IrVisitor {
void handle(GridWelford* grid_welford) final {
summary_.has_welford = true;
summary_.has_grid_welford = true;
const auto dom =
grid_welford->welford_op()->out()->as<TensorIndex>()->view()->domain();
updateGridReductionInLoop(dom);
summary_.has_grid_reductions = true;
if (grid_welford->welford_op()->isAllreduce()) {
summary_.has_cooperative_grid_reduction = true;
}
}

void handle(GridReduction* grid_reduction) final {
summary_.has_grid_reductions = true;
const auto dom = ir_utils::getTvOutput(grid_reduction)->domain();
updateGridReductionInLoop(dom);
if (grid_reduction->isAllreduce()) {
summary_.has_cooperative_grid_reduction = true;
}
}

void handle(GroupedGridReduction* grid_reduction) final {
Expand Down Expand Up @@ -156,8 +158,6 @@ class KernelIrScanner : private IrVisitor {

private:
void updateGridReductionInLoop(TensorDomain* dom) {
summary_.has_grid_reductions = true;

for (const auto i : c10::irange(dom->nDims())) {
const auto id = GpuLower::current()->caMap()->getConcreteMappedID(
dom->domain()[i], IdMappingMode::LOOP);
Expand Down
14 changes: 11 additions & 3 deletions torch/csrc/jit/codegen/cuda/kernel_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,8 @@ GridReduction::GridReduction(
Val* in,
Allocate* reduction_buffer,
Allocate* sync_buffer,
Val* entrance_index,
Val* entrances,
bool is_allreduce)
: ReductionOp(
passkey,
Expand All @@ -445,7 +447,9 @@ GridReduction::GridReduction(
is_allreduce,
ExprType::GridReduction),
reduction_buffer_(reduction_buffer),
sync_buffer_(sync_buffer) {
sync_buffer_(sync_buffer),
entrance_index_(entrance_index),
entrances_(entrances) {
TORCH_INTERNAL_ASSERT(
passkey.ir_container_->isA<kir::Kernel>(),
"IR type only valid for Kernel container.");
Expand Down Expand Up @@ -495,13 +499,17 @@ GridWelford::GridWelford(
Allocate* var_buffer,
Allocate* avg_buffer,
Allocate* n_buffer,
Allocate* sync_buffer)
Allocate* sync_buffer,
Val* entrance_index,
Val* entrances)
: Expr(passkey, ExprType::GridWelford),
welford_op_(welford_op),
var_buffer_(var_buffer),
avg_buffer_(avg_buffer),
n_buffer_(n_buffer),
sync_buffer_(sync_buffer) {
sync_buffer_(sync_buffer),
entrance_index_(entrance_index),
entrances_(entrances) {
TORCH_INTERNAL_ASSERT(
passkey.ir_container_->isA<kir::Kernel>(),
"IR type only valid for Kernel container.");
Expand Down
30 changes: 29 additions & 1 deletion torch/csrc/jit/codegen/cuda/kernel_ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,8 @@ class TORCH_CUDA_CU_API GridReduction final : public ReductionOp {
Val* in,
Allocate* reduction_buffer,
Allocate* sync_buffer,
Val* entrance_index,
Val* entrances,
bool is_fused = false);

Allocate* reduction_buffer() const {
Expand All @@ -523,6 +525,16 @@ class TORCH_CUDA_CU_API GridReduction final : public ReductionOp {
return sync_buffer_;
}

// Returns the entrance-index value supplied to the constructor. It
// identifies which entrance into this grid reduction the current iteration
// is (a grid reduction may be entered more than once).
Val* entrance_index() const {
return entrance_index_;
}

// Returns the entrance-count value supplied to the constructor: the number
// of times this grid reduction will be entered.
Val* entrances() const {
return entrances_;
}

const ParallelTypeBitmap& threadPredicate() const {
return thread_predicate_;
}
Expand All @@ -538,6 +550,8 @@ class TORCH_CUDA_CU_API GridReduction final : public ReductionOp {
// use them, the thread predicate is held here separately from
// Expr::predicate_.
ParallelTypeBitmap thread_predicate_;
Val* entrance_index_ = nullptr;
Val* entrances_ = nullptr;
};

class TORCH_CUDA_CU_API GroupedGridReduction final : public GroupedReductionOp {
Expand Down Expand Up @@ -629,7 +643,9 @@ class TORCH_CUDA_CU_API GridWelford final : public Expr {
Allocate* var_buffer,
Allocate* avg_buffer,
Allocate* n_buffer,
Allocate* sync_buffer);
Allocate* sync_buffer,
Val* entrance_index,
Val* entrances);

WelfordOp* welford_op() const {
return welford_op_;
Expand All @@ -651,6 +667,16 @@ class TORCH_CUDA_CU_API GridWelford final : public Expr {
return sync_buffer_;
}

// Returns the entrance-index value supplied to the constructor. It
// identifies which entrance into this GridWelford the current iteration is
// (the operation may be entered more than once).
Val* entrance_index() const {
return entrance_index_;
}

// Returns the entrance-count value supplied to the constructor: the number
// of times this GridWelford will be entered.
Val* entrances() const {
return entrances_;
}

const ParallelTypeBitmap& threadPredicate() const {
return thread_predicate_;
}
Expand All @@ -665,6 +691,8 @@ class TORCH_CUDA_CU_API GridWelford final : public Expr {
Allocate* avg_buffer_ = nullptr;
Allocate* n_buffer_ = nullptr;
Allocate* sync_buffer_ = nullptr;
Val* entrance_index_ = nullptr;
Val* entrances_ = nullptr;
// gridReduce has template flags for thread predicates. In order to
// use them, the thread predicate is held here separately from
// Expr::predicate_.
Expand Down
126 changes: 119 additions & 7 deletions torch/csrc/jit/codegen/cuda/lower_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ namespace {
// size. For example, FusedReduction should double the work buffer size.
Val* getGridCommWorkBufferSize(
const TensorDomain* td,
const std::vector<kir::ForLoop*>& for_loops = {},
int expansion_factor = 1) {
// The buffer size is the number of thread blocks multiplied by the
// number of threads not used for reduction domains.
Expand Down Expand Up @@ -172,10 +173,28 @@ Val* getGridCommWorkBufferSize(
}
buffer_size = SimplifyingIrBuilder::mulExpr(buffer_size, pt_dim);
}

// All iteration domains require a separate entry in the buffer for re-entrant
// grid reductions.
for (auto fl : for_loops) {
if (fl->isTrivial()) {
continue;
}
if (fl->iter_domain()->isThread()) {
// already accounted for.
continue;
}

buffer_size =
SimplifyingIrBuilder::mulExpr(buffer_size, fl->iter_domain()->extent());
}

return buffer_size;
}

Val* getGridSyncBufferSize(const TensorDomain* td) {
Val* getGridSyncBufferSize(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to expand the sync buffer?

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, we shouldn't need to, because we wait until all iterations are done before we start cleaning any of them up. Maybe that's a reason it's slow; I think we should use multiple sync buffers for each reduction!

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know if it would make a difference, but it seems like it's low risk.

const TensorDomain* td,
const std::vector<kir::ForLoop*>& for_loops = {}) {
// See the comment above for getGridCommWorkBufferSize.
Val* buffer_size = GpuLower::current()->kernel()->oneVal();
for (auto pt : kParallelTypeBIDs) {
Expand All @@ -191,9 +210,66 @@ Val* getGridSyncBufferSize(const TensorDomain* td) {
}
buffer_size = SimplifyingIrBuilder::mulExpr(buffer_size, pt_dim);
}

// All iteration domains require a separate semaphore for re-entrant grid
// reductions
for (auto fl : for_loops) {
if (fl->isTrivial()) {
continue;
}
if (fl->iter_domain()->isThread()) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fl->isTrivial() should include this condition.

// already accounted for.
continue;
}

buffer_size =
SimplifyingIrBuilder::mulExpr(buffer_size, fl->iter_domain()->extent());
}

return buffer_size;
}

// Computes how many times a grid reduction nested inside `for_loops` is
// entered.
//
// The result is the product of the extents of every loop in `for_loops`
// that is not skipped. A loop is skipped when it is trivial
// (`isTrivial()`) or when its iter domain reports `isThread()`; the latter
// is already accounted for elsewhere. The product starts from the kernel's
// constant one, so an empty loop nest (or one where every loop is skipped)
// yields one.
//
// `for_loops` is only read, hence the const reference.
Val* getEntranceCountGridReduce(const std::vector<kir::ForLoop*>& for_loops) {
  Val* grid_reduction_entrances = GpuLower::current()->kernel()->oneVal();

  for (auto* loop : for_loops) {
    // Trivial loops contribute nothing; thread-parallel loops are already
    // accounted for.
    if (loop->isTrivial() || loop->iter_domain()->isThread()) {
      continue;
    }
    // TODO: Does this work for shift/gather?
    grid_reduction_entrances = SimplifyingIrBuilder::mulExpr(
        grid_reduction_entrances, loop->iter_domain()->extent());
  }
  return grid_reduction_entrances;
}

// Linear indexing of for loops for multiple entrances into grid reduce.
//
// Folds the indices of the loops in `for_loops` into a single value using
//   linear_index = linear_index * extent(loop) + index(loop)
// applied in the order the loops appear, so earlier loops are the more
// significant "digits". A loop is skipped when it is trivial
// (`isTrivial()`) or when its iter domain reports `isThread()`; the latter
// is already accounted for elsewhere. The fold starts from the kernel's
// constant zero, so an empty loop nest (or one where every loop is skipped)
// yields zero.
//
// `for_loops` is only read, hence the const reference.
//
// TODO: What happens if there's a broadcast that's resolved (not present in the
// grid reduce) but the global buffer isn't expanded?
Val* getEntranceLinIndGridReduce(const std::vector<kir::ForLoop*>& for_loops) {
  Val* linear_index = GpuLower::current()->kernel()->zeroVal();

  for (auto* loop : for_loops) {
    // Trivial loops contribute nothing; thread-parallel loops are already
    // accounted for.
    if (loop->isTrivial() || loop->iter_domain()->isThread()) {
      continue;
    }
    // TODO: Does this work for shift/gather?
    Val* scaled_index = SimplifyingIrBuilder::mulExpr(
        linear_index, loop->iter_domain()->extent());
    linear_index = SimplifyingIrBuilder::addExpr(scaled_index, loop->index());
  }
  return linear_index;
}

} // namespace

void IndexLowering::handle(const ReductionOp* rop) {
Expand Down Expand Up @@ -271,12 +347,25 @@ void IndexLowering::handleGridReduction(

const auto reduce_buffer = ir_utils::allocGlobalBufferForGridComm(
getGridCommWorkBufferSize(
out_domain, rop->isAllreduce() && is_within_a_loop ? 2 : 1),
out_domain,
rop->isAllreduce() ? std::vector<kir::ForLoop*>() : for_loops_,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be nice if we could refactor the conditional logic that decides when to expand the buffer.

rop->isAllreduce() && is_within_a_loop ? 2 : 1),
out->dtype(),
false);

const auto sync_buffer = ir_utils::allocGlobalBufferForGridComm(
getGridSyncBufferSize(out_domain), DataType::Int, true);
getGridSyncBufferSize(
out_domain,
rop->isAllreduce() ? std::vector<kir::ForLoop*>() : for_loops_),
DataType::Int,
true);

const auto entrance_ind = rop->isAllreduce()
? GpuLower::current()->kernel()->zeroVal()
: getEntranceLinIndGridReduce(for_loops_);
const auto n_entrances = rop->isAllreduce()
? GpuLower::current()->kernel()->oneVal()
: getEntranceCountGridReduce(for_loops_);

// The thread predicate for GridReduction needs to be set
// separately from the main predicate. Do not combine them like
Expand All @@ -291,6 +380,8 @@ void IndexLowering::handleGridReduction(
in,
reduce_buffer,
sync_buffer,
entrance_ind,
n_entrances,
rop->isAllreduce());

grid_reduction->setThreadPredicate(thread_pred);
Expand Down Expand Up @@ -412,13 +503,14 @@ void IndexLowering::handleGridReduction(
return ir_utils::allocGlobalBufferForGridComm(
getGridCommWorkBufferSize(
out_domain,
for_loops_,
(grouped_rop->isAllreduce() && is_within_a_loop ? 2 : 1)),
output->dtype(),
false);
});

const auto sync_buffer = ir_utils::allocGlobalBufferForGridComm(
getGridSyncBufferSize(out_domain), DataType::Int, true);
getGridSyncBufferSize(out_domain, for_loops_), DataType::Int, true);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

GroupedReduction doesn't support reentrance, so this is not necessary right now.


// The thread predicate for GridReduction needs to be set
// separately from the main predicate. Do not combine them like
Expand Down Expand Up @@ -547,7 +639,9 @@ void IndexLowering::handleGridWelford(WelfordOp* indexed_wop) {
[](IterDomain* id) { return !isTrivialIterDomain(id); });

const auto work_buffer_size = getGridCommWorkBufferSize(
out_domain, indexed_wop->isAllreduce() && is_within_a_loop ? 2 : 1);
out_domain,
indexed_wop->isAllreduce() ? std::vector<kir::ForLoop*>() : for_loops_,
indexed_wop->isAllreduce() && is_within_a_loop ? 2 : 1);

const auto out_var_buffer = ir_utils::allocGlobalBufferForGridComm(
work_buffer_size, indexed_wop->outVar()->dtype(), false);
Expand All @@ -557,7 +651,19 @@ void IndexLowering::handleGridWelford(WelfordOp* indexed_wop) {
work_buffer_size, indexed_wop->outN()->dtype(), false);

const auto sync_buffer = ir_utils::allocGlobalBufferForGridComm(
getGridSyncBufferSize(out_domain), DataType::Int, true);
getGridSyncBufferSize(
out_domain,
indexed_wop->isAllreduce() ? std::vector<kir::ForLoop*>()
: for_loops_),
DataType::Int,
true);

const auto entrance_ind = indexed_wop->isAllreduce()
? GpuLower::current()->kernel()->zeroVal()
: getEntranceLinIndGridReduce(for_loops_);
const auto n_entrances = indexed_wop->isAllreduce()
? GpuLower::current()->kernel()->oneVal()
: getEntranceCountGridReduce(for_loops_);

// The thread predicate for GridReduction needs to be set
// separately from the main predicate. Do not combine them like
Expand All @@ -566,7 +672,13 @@ void IndexLowering::handleGridWelford(WelfordOp* indexed_wop) {
GpuLower::current()->threadPredMap().getPredicatedParallelTypes(out_tv);

auto grid_welford = IrBuilder::create<kir::GridWelford>(
indexed_wop, out_var_buffer, out_avg_buffer, out_N_buffer, sync_buffer);
indexed_wop,
out_var_buffer,
out_avg_buffer,
out_N_buffer,
sync_buffer,
entrance_ind,
n_entrances);

grid_welford->setThreadPredicate(thread_pred);

Expand Down
Loading