Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Compute At Refactoring 4 #2244

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
203 changes: 178 additions & 25 deletions torch/csrc/jit/codegen/cuda/compute_at_map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,10 @@ IterDomainGraph::IterDomainGraph(Fusion* fusion, bool allow_self_mapping) {
IdMappingMode::PERMISSIVE,
IdMappingMode::LOOP};

// Initialize disjoint sets
for (auto mode : mapping_types) {
disjoint_ids_[mode] = DisjointSets<IterDomain*>();
disjoint_exprs_[mode] = DisjointSets<Expr*>();
}

build(fusion);
Expand Down Expand Up @@ -89,6 +91,27 @@ DisjointSets<IterDomain*>& IterDomainGraph::disjointIdsSet(IdMappingMode mode) {
return disjoint_ids_it->second;
}

const DisjointSets<Expr*>& IterDomainGraph::getDisjointExprsSet(
IdMappingMode mode) const {
auto disjoint_exprs_it = disjoint_exprs_.find(mode);
TORCH_INTERNAL_ASSERT(
disjoint_exprs_it != disjoint_exprs_.end(),
"Mapping mode ",
mode,
" not supported.");
return disjoint_exprs_it->second;
}

DisjointSets<Expr*>& IterDomainGraph::disjointExprsSet(IdMappingMode mode) {
auto disjoint_exprs_it = disjoint_exprs_.find(mode);
TORCH_INTERNAL_ASSERT(
disjoint_exprs_it != disjoint_exprs_.end(),
"Mapping mode ",
mode,
" not supported.");
return disjoint_exprs_it->second;
}

bool IterDomainGraph::exprsMap(
Expr* first,
Expr* second,
Expand All @@ -103,7 +126,7 @@ bool IterDomainGraph::exprsMap(
}

TORCH_INTERNAL_ASSERT(
first->isA<Merge>() || first->isA<Split>(),
first->isA<Merge>() || first->isA<Split>() || first->isA<Swizzle2D>(),
"Merge and split are the only expressions supported through rfactor operations in compute at map, but found:\n",
first->toString());

Expand Down Expand Up @@ -181,9 +204,61 @@ bool IterDomainGraph::exprsMap(
}
}

if (first->isA<Swizzle2D>()) {
auto first_swizzle = first->as<Swizzle2D>();
auto second_swizzle = second->as<Swizzle2D>();
if (first_swizzle->swizzleMode() != second_swizzle->swizzleMode() ||
first_swizzle->swizzleType() != second_swizzle->swizzleType()) {
return false;
}
}

return true;
}

void IterDomainGraph::mapIds(
IterDomain* id0,
IterDomain* id1,
IdMappingMode mode) {
if (mode == IdMappingMode::LOOP) {
disjointIdsSet(mode).mapEntries(id0, id1);
return;
}

if (disjointIdsSet(mode).strictAreMapped(id0, id1)) {
// Already mapped together, nothing to do.
return;
}

disjointIdsSet(mode).mapEntries(id0, id1);

// Map definitions if expressions are not already mapped
auto def0 = id0->definition();
auto def1 = id1->definition();
if (def0 != nullptr && def1 != nullptr) {
if (!disjointExprsSet(mode).strictAreMapped(def0, def1)) {
if (exprsMap(def0, def1, false, mode)) {
if (mapThroughExpr(def0, def1, false, mode)) {
disjointExprsSet(mode).mapEntries(def0, def1);
}
}
}
}

// Map uses if expressions are not already mapped
auto use0 = id_uses_.at(id0);
auto use1 = id_uses_.at(id1);
if (use0 != nullptr && use1 != nullptr) {
if (!disjointExprsSet(mode).strictAreMapped(use0, use1)) {
if (exprsMap(use0, use1, true, mode)) {
if (mapThroughExpr(use0, use1, true, mode)) {
disjointExprsSet(mode).mapEntries(use0, use1);
}
}
}
}
}

// Given first and second Exprs "match"
// Expr type matches
// IterDomain's in the inputs and outputs exact match, (including argument
Expand All @@ -192,17 +267,17 @@ bool IterDomainGraph::exprsMap(
// better, as today it will just check it's the same symbol or evaluated to
// the same constant. However, we know all the extents of all the
// IterDomain's that exact map with eachother are the same value.
void IterDomainGraph::mapThroughExpr(
bool IterDomainGraph::mapThroughExpr(
Expr* first,
Expr* second,
bool forward,
IdMappingMode mode) {
if (first == nullptr || second == nullptr) {
return;
return false;
}

if (!exprsMap(first, second, forward, mode)) {
return;
return false;
}

auto first_ids = ir_utils::filterByType<IterDomain>(
Expand All @@ -220,6 +295,8 @@ void IterDomainGraph::mapThroughExpr(
for (auto out_i : c10::irange(first_ids.size())) {
mapIds(first_ids[out_i], second_ids[out_i], mode);
}

return true;
}

namespace {
Expand Down Expand Up @@ -332,9 +409,19 @@ void IterDomainGraph::initializeId(
bool is_leaf_id) {
disjointIdsSet(IdMappingMode::PERMISSIVE).initializeSet(id);
disjointIdsSet(IdMappingMode::EXACT).initializeSet(id);

if (id->definition() != nullptr) {
disjointExprsSet(IdMappingMode::PERMISSIVE).initializeSet(id->definition());
disjointExprsSet(IdMappingMode::EXACT).initializeSet(id->definition());
}

if (is_leaf_id) {
disjointIdsSet(IdMappingMode::LOOP).initializeSet(id);
if (id->definition() != nullptr) {
disjointExprsSet(IdMappingMode::LOOP).initializeSet(id->definition());
}
}

consumers_[id] = {};
producers_[id] = {};

Expand All @@ -343,9 +430,41 @@ void IterDomainGraph::initializeId(
}
}

void IterDomainGraph::buildIterDomainUses(Fusion* fusion) {
// Generate IterDomain uses:
for (auto tv : ir_utils::allTvs(fusion)) {
auto all_ids = ir_utils::allIDsOf(tv);
for (auto id : all_ids) {
if (id_uses_.find(id) == id_uses_.end()) {
id_uses_[id] = nullptr;
}

auto def = id->definition();

if (def == nullptr) {
continue;
}
auto inp_ids = ir_utils::filterByType<IterDomain>(def->inputs());
for (auto inp_id : inp_ids) {
if (id_uses_.find(id) != id_uses_.end()) {
TORCH_INTERNAL_ASSERT(
id_uses_[id] == nullptr,
"\nTried to set multiple uses to iteration domain: ",
id->toString(),
"\nWhich is not supported, tried to set expr:\n ",
def->toString(),
"However the following expression was already set:\n ",
id_uses_[id]->toString());
}
id_uses_[inp_id] = def;
}
}
}
}

void IterDomainGraph::initialIdProcessing(Fusion* fusion) {
// Initialize entries for every iteration domain and mark view like iteration
// domains and leaf iteration domains.
// Initialize entries for every iteration domain and mark view like
// iteration domains and leaf iteration domains.
for (auto tv : ir_utils::allTvs(fusion)) {
const auto& domain = tv->domain()->domain();
auto all_ids = ir_utils::allIDsOf(tv);
Expand All @@ -357,9 +476,9 @@ void IterDomainGraph::initialIdProcessing(Fusion* fusion) {
// Check if this id is a view like rfactor id
bool is_view_rfactor_id = false;
if (view_like_domain && id->isRFactorProduct()) {
// If the tensor domain is a view like domain, and the iteration domain
// is marked as an rfactor product and is in the rfactor domain, it's a
// view like rfactor iteration domain
// If the tensor domain is a view like domain, and the iteration
// domain is marked as an rfactor product and is in the rfactor
// domain, it's a view like rfactor iteration domain
const auto& rfactor_domain = tv->domain()->getMaybeRFactorDomain();
if (std::find(rfactor_domain.begin(), rfactor_domain.end(), id) !=
rfactor_domain.end()) {
Expand Down Expand Up @@ -470,6 +589,21 @@ void mapMaybeSwizzleOp(
}
} // namespace

void IterDomainGraph::mapThroughLoopSwizzles(IdMappingMode mode) {
for (auto use_it : id_uses_) {
auto use = use_it.second;
if (auto swizzle_2d = dynamic_cast<Swizzle2D*>(use)) {
// Map each input to its corresponding output on the given
// disjoint set if this is a loop swizzle. Loop swizzles don't impact
// indexing, only iteration order.
if (swizzle_2d->swizzleMode() == SwizzleMode::Loop) {
mapIds(swizzle_2d->inX(), swizzle_2d->outX(), mode);
mapIds(swizzle_2d->inY(), swizzle_2d->outY(), mode);
}
}
}
}

void IterDomainGraph::mapExact(Expr* expr) {
TensorView* c_tv = ir_utils::getTvOutput(expr);

Expand All @@ -482,6 +616,11 @@ void IterDomainGraph::mapExact(Expr* expr) {
PairwiseRootDomainMap(p_tv, c_tv, true)
.mapConsumerToProducer(c_tv->domain(), p_tv->domain());

for (auto c_id : getSortedKeys(exact_c2p_root_map, Statement::lessThan)) {
auto p_id = exact_c2p_root_map.at(c_id);
mapIds(c_id, p_id, IdMappingMode::EXACT);
}

// Same as permissive above but for exact
auto exact_replay_PasC = BestEffortReplay(
p_tv->domain()->domain(), c_tv->domain()->domain(), exact_c2p_root_map);
Expand All @@ -490,20 +629,15 @@ void IterDomainGraph::mapExact(Expr* expr) {

for (auto c_id : getSortedKeys(exact_c2p_map, Statement::lessThan)) {
auto p_id = exact_c2p_map.at(c_id);
mapIds(c_id, p_id, IdMappingMode::EXACT);

// TODO: consumers/producers should be on a per map basis, mapping should
// include unique expr between the disjoint sets
// TODO: consumers/producers should be on a per map basis, mapping
// should include unique expr between the disjoint sets
consumers_.at(p_id).pushBack(c_id);
producers_.at(c_id).pushBack(p_id);

// Add the swizzle inputs to the same
// disjoint set as well if either c_id
// or p_id is swizzle output.
mapMaybeSwizzleOp(disjointIdsSet(IdMappingMode::EXACT), p_id);
mapMaybeSwizzleOp(disjointIdsSet(IdMappingMode::EXACT), c_id);
}
}

mapThroughLoopSwizzles(IdMappingMode::EXACT);
}

void IterDomainGraph::mapPermissiveAndLoop(Expr* expr) {
Expand Down Expand Up @@ -562,6 +696,8 @@ void IterDomainGraph::mapPermissiveAndLoop(Expr* expr) {
}
}
}

mapThroughLoopSwizzles(IdMappingMode::PERMISSIVE);
}

void IterDomainGraph::mapRFactorExprs(Fusion* fusion) {
Expand Down Expand Up @@ -719,6 +855,8 @@ void IterDomainGraph::buildAlmostExactMap() {
// Build almost exact map by forwarding through broadcast axes
disjointIdsSet(IdMappingMode::ALMOSTEXACT) =
disjointIdsSet(IdMappingMode::EXACT);
disjointExprsSet(IdMappingMode::ALMOSTEXACT) =
disjointExprsSet(IdMappingMode::EXACT);
std::unordered_set<Expr*> visited;
auto all_elements = disjointIdsSet(IdMappingMode::EXACT).getAllElements();
for (auto entry : all_elements.vector()) {
Expand Down Expand Up @@ -752,27 +890,41 @@ void IterDomainGraph::buildAlmostExactMap() {
void IterDomainGraph::build(Fusion* fusion) {
FusionGuard fg(fusion);

// Add uses to all iter domains.
buildIterDomainUses(fusion);

// Initialize the maps with all the IterDomains defined in the fusion.
initialIdProcessing(fusion);

for (auto expr : fusion->exprs()) {
if (!ir_utils::isTvOp(expr)) {
continue;
}
// Filter non-TensorView expressions
auto all_exprs = fusion->exprs();
std::vector<Expr*> tv_exprs;

std::copy_if(
all_exprs.begin(),
all_exprs.end(),
std::back_inserter(tv_exprs),
[](Expr* expr) { return ir_utils::isTvOp(expr); });

for (auto expr : tv_exprs) {
// Connect multi-output expressions as they're trivial to connect.
mapMultiOutput(expr);
}

for (auto expr : fusion->exprs()) {
// Connect ID's on the exact dimension
mapExact(expr);
}

for (auto expr : fusion->exprs()) {
// Connect across the permissive, loop, and for now consumer_, producer_
// dimensions.
mapPermissiveAndLoop(expr);
}

// Map forward and backward through TV root<->rfactor to cross map connections
// that are not explicitly defined through input<->output expression maps.
// Map forward and backward through TV root<->rfactor to cross map
// connections that are not explicitly defined through input<->output
// expression maps.
mapRFactorExprs(fusion);

buildAlmostExactMap();
Expand Down Expand Up @@ -825,7 +977,8 @@ void ComputeAtMap::allocateIndexVariables() {
// first allocate thread and grid parallel indices:
// The validation pass will check that the parallel bindings within the
// loop disjoint IDs set are consistent so all the loops within this
// disjoint set will be realized implicitly using parallel index variables.
// disjoint set will be realized implicitly using parallel index
// variables.
if (std::any_of(
loop_disjoint_set->vector().begin(),
loop_disjoint_set->vector().end(),
Expand Down
Loading