Enable grid reductions within loops. #1681
@@ -143,6 +143,7 @@ namespace {
 // size. For example, FusedReduction should double the work buffer size.
 Val* getGridCommWorkBufferSize(
     const TensorDomain* td,
+    const std::vector<kir::ForLoop*>& for_loops = {},
     int expansion_factor = 1) {
   // The buffer size is the number of thread blocks multiplied by the
   // number of threads not used for reduction domains.
@@ -172,10 +173,28 @@ Val* getGridCommWorkBufferSize(
     }
     buffer_size = SimplifyingIrBuilder::mulExpr(buffer_size, pt_dim);
   }

+  // All iteration domains require a separate entry in the buffer for
+  // re-entrant grid reductions.
+  for (auto fl : for_loops) {
+    if (fl->isTrivial()) {
+      continue;
+    }
+    if (fl->iter_domain()->isThread()) {
+      // already accounted for.
+      continue;
+    }
+
+    buffer_size =
+        SimplifyingIrBuilder::mulExpr(buffer_size, fl->iter_domain()->extent());
+  }
+
   return buffer_size;
 }

-Val* getGridSyncBufferSize(const TensorDomain* td) {
+Val* getGridSyncBufferSize(
+    const TensorDomain* td,
+    const std::vector<kir::ForLoop*>& for_loops = {}) {
   // See the comment above for getGridCommWorkBufferSize.
   Val* buffer_size = GpuLower::current()->kernel()->oneVal();
   for (auto pt : kParallelTypeBIDs) {
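The sizing rule is the same for the work buffer and (below) the sync buffer: start from the product of the grid dimensions that participate in the communication, then multiply in the extent of every non-trivial, non-parallelized loop surrounding the reduction, so that each entrance owns its own slice. A minimal plain-integer sketch of that rule (illustrative only; SerialLoop, gridCommWorkBufferSize, and the concrete extents are made-up names, not nvfuser API):

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical stand-in for a serial loop around the grid reduction; only
// the extent and the skip conditions matter for sizing.
struct SerialLoop {
  int64_t extent;
  bool trivial;      // extent-1 loops contribute nothing
  bool parallelized; // thread/block loops are already counted
};

// Mirrors the sizing rule in the diff with plain integers: the product of
// the communicating grid/block dimensions, times the extent of every
// non-trivial serial loop, times the expansion factor.
int64_t gridCommWorkBufferSize(
    const std::vector<int64_t>& comm_parallel_extents,
    const std::vector<SerialLoop>& for_loops,
    int64_t expansion_factor = 1) {
  int64_t size = 1;
  for (int64_t e : comm_parallel_extents) {
    size *= e;
  }
  for (const SerialLoop& fl : for_loops) {
    if (fl.trivial || fl.parallelized) {
      continue; // matches the isTrivial() / isThread() skips in the diff
    }
    size *= fl.extent;
  }
  return size * expansion_factor;
}

int main() {
  // 8 blocks x 32 threads feeding the reduction, re-entered by a serial
  // loop of extent 4: the buffer holds 8 * 32 * 4 = 1024 entries, not 256.
  std::cout << gridCommWorkBufferSize({8, 32}, {{4, false, false}}) << "\n";
  return 0;
}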
@@ -191,9 +210,66 @@ Val* getGridSyncBufferSize(const TensorDomain* td) {
     }
     buffer_size = SimplifyingIrBuilder::mulExpr(buffer_size, pt_dim);
   }

+  // All iteration domains require a separate semaphore for re-entrant grid
+  // reductions.
+  for (auto fl : for_loops) {
+    if (fl->isTrivial()) {
+      continue;
+    }
+    if (fl->iter_domain()->isThread()) {
+      // already accounted for.
+      continue;
+    }
+
+    buffer_size =
+        SimplifyingIrBuilder::mulExpr(buffer_size, fl->iter_domain()->extent());
+  }
+
   return buffer_size;
 }

+Val* getEntranceCountGridReduce(std::vector<kir::ForLoop*>& for_loops) {
+  Val* grid_reduction_entrances = GpuLower::current()->kernel()->oneVal();
+
+  for (const auto loop : for_loops) {
+    if (loop->isTrivial()) {
+      continue;
+    }
+    if (loop->iter_domain()->isThread()) {
+      // already accounted for.
+      continue;
+    }
+    // TODO: Does this work for shift/gather?
+    grid_reduction_entrances = SimplifyingIrBuilder::mulExpr(
+        grid_reduction_entrances, loop->iter_domain()->extent());
+  }
+  return grid_reduction_entrances;
+}
+
+// Linear indexing of for loops for multiple entrances into grid reduce
+// TODO: What happens if there's a broadcast that's resolved (not present in
+// the grid reduce) but the global buffer isn't expanded?
+Val* getEntranceLinIndGridReduce(std::vector<kir::ForLoop*>& for_loops) {
+  Val* linear_index = GpuLower::current()->kernel()->zeroVal();
+
+  for (const auto loop : for_loops) {
+    if (loop->isTrivial()) {
+      continue;
+    }
+    if (loop->iter_domain()->isThread()) {
+      // already accounted for.
+      continue;
+    }
+    // TODO: Does this work for shift/gather?
+    linear_index = SimplifyingIrBuilder::addExpr(
+        SimplifyingIrBuilder::mulExpr(
+            linear_index, loop->iter_domain()->extent()),
+        loop->index());
+  }
+  return linear_index;
+}
+
 } // namespace

 void IndexLowering::handle(const ReductionOp* rop) {
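Together these two helpers implement a standard row-major flattening of the surrounding serial loops: the entrance count is the product of the loop extents, and the linear index is accumulated Horner-style as linear = linear * extent + index, so every combination of loop indices lands in a unique slot of [0, count). A self-contained sketch with plain integers (the isTrivial()/isThread() skips are dropped and all names are illustrative, not nvfuser API):

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical (index, extent) snapshot of one surrounding serial loop.
struct LoopState {
  int64_t index;
  int64_t extent;
};

// Plain-integer analogue of getEntranceCountGridReduce: the reduction is
// entered once per combination of surrounding loop indices.
int64_t entranceCount(const std::vector<LoopState>& loops) {
  int64_t count = 1;
  for (const LoopState& l : loops) {
    count *= l.extent;
  }
  return count;
}

// Plain-integer analogue of getEntranceLinIndGridReduce: Horner-style
// row-major flattening, linear = (i0 * e1 + i1) * e2 + i2 ...
int64_t entranceLinearIndex(const std::vector<LoopState>& loops) {
  int64_t linear = 0;
  for (const LoopState& l : loops) {
    linear = linear * l.extent + l.index;
  }
  return linear;
}

int main() {
  // Loops of extents 3 and 5: 15 entrances, and (i, j) maps to the unique
  // slot i * 5 + j, so each entrance gets its own buffer entry/semaphore.
  std::cout << "entrances: " << entranceCount({{0, 3}, {0, 5}}) << "\n";
  for (int64_t i = 0; i < 3; ++i) {
    for (int64_t j = 0; j < 5; ++j) {
      std::cout << entranceLinearIndex({{i, 3}, {j, 5}}) << " "; // 0..14
    }
  }
  std::cout << "\n";
  return 0;
}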
@@ -271,12 +347,25 @@ void IndexLowering::handleGridReduction(

   const auto reduce_buffer = ir_utils::allocGlobalBufferForGridComm(
       getGridCommWorkBufferSize(
-          out_domain, rop->isAllreduce() && is_within_a_loop ? 2 : 1),
+          out_domain,
+          rop->isAllreduce() ? std::vector<kir::ForLoop*>() : for_loops_,
+          rop->isAllreduce() && is_within_a_loop ? 2 : 1),
       out->dtype(),
       false);

   const auto sync_buffer = ir_utils::allocGlobalBufferForGridComm(
-      getGridSyncBufferSize(out_domain), DataType::Int, true);
+      getGridSyncBufferSize(
+          out_domain,
+          rop->isAllreduce() ? std::vector<kir::ForLoop*>() : for_loops_),
+      DataType::Int,
+      true);

+  const auto entrance_ind = rop->isAllreduce()
+      ? GpuLower::current()->kernel()->zeroVal()
+      : getEntranceLinIndGridReduce(for_loops_);
+  const auto n_entrances = rop->isAllreduce()
+      ? GpuLower::current()->kernel()->oneVal()
+      : getEntranceCountGridReduce(for_loops_);
+
   // The thread predicate for GridReduction needs to be set
   // separately from the main predicate. Do not combine them like

Review comment (on the getGridCommWorkBufferSize call): It would be nice if we could refactor the conditional logic that decides when to expand the buffer.
@@ -291,6 +380,8 @@ void IndexLowering::handleGridReduction(
       in,
       reduce_buffer,
       sync_buffer,
+      entrance_ind,
+      n_entrances,
       rop->isAllreduce());

   grid_reduction->setThreadPredicate(thread_pred);
@@ -412,13 +503,14 @@ void IndexLowering::handleGridReduction(
     return ir_utils::allocGlobalBufferForGridComm(
         getGridCommWorkBufferSize(
             out_domain,
+            for_loops_,
             (grouped_rop->isAllreduce() && is_within_a_loop ? 2 : 1)),
         output->dtype(),
         false);
   });

   const auto sync_buffer = ir_utils::allocGlobalBufferForGridComm(
-      getGridSyncBufferSize(out_domain), DataType::Int, true);
+      getGridSyncBufferSize(out_domain, for_loops_), DataType::Int, true);

   // The thread predicate for GridReduction needs to be set
   // separately from the main predicate. Do not combine them like

Review comment (on the sync buffer): GroupedReduction doesn't support reentrance, so this is not necessary right now.
@@ -547,7 +639,9 @@ void IndexLowering::handleGridWelford(WelfordOp* indexed_wop) {
       [](IterDomain* id) { return !isTrivialIterDomain(id); });

   const auto work_buffer_size = getGridCommWorkBufferSize(
-      out_domain, indexed_wop->isAllreduce() && is_within_a_loop ? 2 : 1);
+      out_domain,
+      indexed_wop->isAllreduce() ? std::vector<kir::ForLoop*>() : for_loops_,
+      indexed_wop->isAllreduce() && is_within_a_loop ? 2 : 1);

   const auto out_var_buffer = ir_utils::allocGlobalBufferForGridComm(
       work_buffer_size, indexed_wop->outVar()->dtype(), false);
@@ -557,7 +651,19 @@ void IndexLowering::handleGridWelford(WelfordOp* indexed_wop) {
       work_buffer_size, indexed_wop->outN()->dtype(), false);

   const auto sync_buffer = ir_utils::allocGlobalBufferForGridComm(
-      getGridSyncBufferSize(out_domain), DataType::Int, true);
+      getGridSyncBufferSize(
+          out_domain,
+          indexed_wop->isAllreduce() ? std::vector<kir::ForLoop*>()
+                                     : for_loops_),
+      DataType::Int,
+      true);

+  const auto entrance_ind = indexed_wop->isAllreduce()
+      ? GpuLower::current()->kernel()->zeroVal()
+      : getEntranceLinIndGridReduce(for_loops_);
+  const auto n_entrances = indexed_wop->isAllreduce()
+      ? GpuLower::current()->kernel()->oneVal()
+      : getEntranceCountGridReduce(for_loops_);
+
   // The thread predicate for GridReduction needs to be set
   // separately from the main predicate. Do not combine them like
@@ -566,7 +672,13 @@ void IndexLowering::handleGridWelford(WelfordOp* indexed_wop) {
   GpuLower::current()->threadPredMap().getPredicatedParallelTypes(out_tv);

   auto grid_welford = IrBuilder::create<kir::GridWelford>(
-      indexed_wop, out_var_buffer, out_avg_buffer, out_N_buffer, sync_buffer);
+      indexed_wop,
+      out_var_buffer,
+      out_avg_buffer,
+      out_N_buffer,
+      sync_buffer,
+      entrance_ind,
+      n_entrances);

   grid_welford->setThreadPredicate(thread_pred);
Review thread (on the sync buffer allocation):

Do we need to expand the sync buffer?

No, we shouldn't need to, because we wait until all iterations are done before starting to clean any of them up. Maybe that's one reason it's slow; I think we should use multiple sync buffers, one for each reduction!

I don't know if it would make a difference, but it seems low risk.
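On the reviewer's suggestion of one sync buffer per reduction: a rough host-side sketch of what per-entrance semaphores could look like, with std::atomic standing in for device-scope atomics on the int sync buffer. This is not part of the PR, and every name here is hypothetical.

#include <atomic>
#include <cstdint>
#include <iostream>
#include <vector>

// One counter per entrance; the last block to arrive at a given entrance
// may reset that counter and consume the partial results immediately,
// without waiting for the other loop iterations' reductions to finish.
struct EntranceSync {
  std::vector<std::atomic<int64_t>> semaphores;
  int64_t blocks_per_reduction;

  EntranceSync(int64_t n_entrances, int64_t blocks)
      : semaphores(n_entrances), blocks_per_reduction(blocks) {}

  // Returns true for the final arrival at this entrance.
  bool arrive(int64_t entrance_ind) {
    int64_t arrived =
        semaphores[entrance_ind].fetch_add(1, std::memory_order_acq_rel) + 1;
    if (arrived == blocks_per_reduction) {
      semaphores[entrance_ind].store(0, std::memory_order_release);
      return true;
    }
    return false;
  }
};

int main() {
  EntranceSync sync(/*n_entrances=*/4, /*blocks=*/2);
  sync.arrive(0);                      // first block at entrance 0
  std::cout << sync.arrive(0) << "\n"; // second (last) block: prints 1
  return 0;
}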