diff --git a/torch/csrc/jit/codegen/cuda/compute_at_map.cpp b/torch/csrc/jit/codegen/cuda/compute_at_map.cpp
index 71a994da149af..0e116563e5ccc 100644
--- a/torch/csrc/jit/codegen/cuda/compute_at_map.cpp
+++ b/torch/csrc/jit/codegen/cuda/compute_at_map.cpp
@@ -47,8 +47,10 @@ IterDomainGraph::IterDomainGraph(Fusion* fusion, bool allow_self_mapping) {
       IdMappingMode::PERMISSIVE,
       IdMappingMode::LOOP};
 
+  // Start every mapping mode with an empty IterDomain set and an empty Expr set
   for (auto mode : mapping_types) {
     disjoint_ids_[mode] = DisjointSets();
+    disjoint_exprs_[mode] = DisjointSets();
   }
 
   build(fusion);
@@ -89,6 +91,27 @@ DisjointSets& IterDomainGraph::disjointIdsSet(IdMappingMode mode) {
   return disjoint_ids_it->second;
 }
 
+const DisjointSets& IterDomainGraph::getDisjointExprsSet(
+    IdMappingMode mode) const {
+  auto disjoint_exprs_it = disjoint_exprs_.find(mode); // read-only lookup of the Expr set for `mode`
+  TORCH_INTERNAL_ASSERT(
+      disjoint_exprs_it != disjoint_exprs_.end(),
+      "Mapping mode ",
+      mode,
+      " not supported.");
+  return disjoint_exprs_it->second;
+}
+
+DisjointSets& IterDomainGraph::disjointExprsSet(IdMappingMode mode) {
+  auto disjoint_exprs_it = disjoint_exprs_.find(mode); // mutable twin of getDisjointExprsSet
+  TORCH_INTERNAL_ASSERT(
+      disjoint_exprs_it != disjoint_exprs_.end(),
+      "Mapping mode ",
+      mode,
+      " not supported.");
+  return disjoint_exprs_it->second;
+}
+
 bool IterDomainGraph::exprsMap(
     Expr* first,
     Expr* second,
@@ -103,7 +126,7 @@ bool IterDomainGraph::exprsMap(
   }
 
   TORCH_INTERNAL_ASSERT(
-      first->isA() || first->isA(),
+      first->isA() || first->isA() || first->isA(), // NOTE(review): a third expr type is accepted now, but the message below still names only merge and split
       "Merge and split are the only expressions supported through rfactor operations in compute at map, but found:\n",
       first->toString());
 
@@ -181,9 +204,61 @@ bool IterDomainGraph::exprsMap(
     }
   }
 
+  if (first->isA()) { // swizzle exprs only match when both swizzle mode and swizzle type agree
+    auto first_swizzle = first->as();
+    auto second_swizzle = second->as();
+    if (first_swizzle->swizzleMode() != second_swizzle->swizzleMode() ||
+        first_swizzle->swizzleType() != second_swizzle->swizzleType()) {
+      return false;
+    }
+  }
+
   return true;
 }
 
+void IterDomainGraph::mapIds( // maps id0/id1 in `mode`, then propagates through their matching definitions and uses
+    IterDomain* id0,
+    IterDomain* id1,
+    IdMappingMode mode) {
+  if (mode == IdMappingMode::LOOP) { // LOOP mode maps only the two IDs, no expr propagation
+    disjointIdsSet(mode).mapEntries(id0, id1);
+    return;
+  }
+
+  if (disjointIdsSet(mode).strictAreMapped(id0, id1)) {
+    // Already mapped together, nothing to do. This also stops the mapIds <-> mapThroughExpr recursion.
+    return;
+  }
+
+  disjointIdsSet(mode).mapEntries(id0, id1);
+
+  // Map definitions if expressions are not already mapped
+  auto def0 = id0->definition();
+  auto def1 = id1->definition();
+  if (def0 != nullptr && def1 != nullptr) {
+    if (!disjointExprsSet(mode).strictAreMapped(def0, def1)) {
+      if (exprsMap(def0, def1, false, mode)) {
+        if (mapThroughExpr(def0, def1, false, mode)) { // forward == false; mapThroughExpr calls back into mapIds
+          disjointExprsSet(mode).mapEntries(def0, def1);
+        }
+      }
+    }
+  }
+
+  // Map uses if expressions are not already mapped
+  auto use0 = id_uses_.at(id0); // .at(): both IDs must already have an id_uses_ entry
+  auto use1 = id_uses_.at(id1);
+  if (use0 != nullptr && use1 != nullptr) {
+    if (!disjointExprsSet(mode).strictAreMapped(use0, use1)) {
+      if (exprsMap(use0, use1, true, mode)) {
+        if (mapThroughExpr(use0, use1, true, mode)) { // forward == true; mapThroughExpr calls back into mapIds
+          disjointExprsSet(mode).mapEntries(use0, use1);
+        }
+      }
+    }
+  }
+}
+
 // Given first and second Exprs "match"
 // Expr type matches
 // IterDomain's in the inputs and outputs exact match, (including argument
@@ -192,17 +267,17 @@ bool IterDomainGraph::exprsMap(
 // better, as today it will just check it's the same symbol or evaluated to
 // the same constant. However, we know all the extents of all the
 // IterDomain's that exact map with eachother are the same value.
-void IterDomainGraph::mapThroughExpr(
+bool IterDomainGraph::mapThroughExpr(
     Expr* first,
     Expr* second,
     bool forward,
     IdMappingMode mode) {
   if (first == nullptr || second == nullptr) {
-    return;
+    return false;
   }
 
   if (!exprsMap(first, second, forward, mode)) {
-    return;
+    return false;
   }
 
   auto first_ids = ir_utils::filterByType(
@@ -220,6 +295,8 @@ void IterDomainGraph::mapThroughExpr(
   for (auto out_i : c10::irange(first_ids.size())) {
     mapIds(first_ids[out_i], second_ids[out_i], mode);
   }
+
+  return true;
 }
 
 namespace {
@@ -332,9 +409,19 @@ void IterDomainGraph::initializeId(
     bool is_leaf_id) {
   disjointIdsSet(IdMappingMode::PERMISSIVE).initializeSet(id);
   disjointIdsSet(IdMappingMode::EXACT).initializeSet(id);
+
+  if (id->definition() != nullptr) { // IDs without a definition contribute no expr entry
+    disjointExprsSet(IdMappingMode::PERMISSIVE).initializeSet(id->definition());
+    disjointExprsSet(IdMappingMode::EXACT).initializeSet(id->definition());
+  }
+
   if (is_leaf_id) {
     disjointIdsSet(IdMappingMode::LOOP).initializeSet(id);
+    if (id->definition() != nullptr) {
+      disjointExprsSet(IdMappingMode::LOOP).initializeSet(id->definition());
+    }
   }
+
   consumers_[id] = {};
   producers_[id] = {};
 
@@ -343,9 +430,41 @@ void IterDomainGraph::initializeId(
   }
 }
 
+void IterDomainGraph::buildIterDomainUses(Fusion* fusion) {
+  // Record in id_uses_, for every IterDomain of every TensorView, the one expr that consumes it (nullptr if none).
+  for (auto tv : ir_utils::allTvs(fusion)) {
+    auto all_ids = ir_utils::allIDsOf(tv);
+    for (auto id : all_ids) {
+      if (id_uses_.find(id) == id_uses_.end()) { // make sure IDs that are never consumed still get an entry
+        id_uses_[id] = nullptr;
+      }
+
+      auto def = id->definition();
+
+      if (def == nullptr) {
+        continue;
+      }
+      auto inp_ids = ir_utils::filterByType(def->inputs());
+      for (auto inp_id : inp_ids) {
+        // A def is visited once per output, so only a *different* existing use is an error.
+        auto& inp_use = id_uses_[inp_id]; // operator[] default-inserts nullptr for an unseen input
+        TORCH_INTERNAL_ASSERT(
+            inp_use == nullptr || inp_use == def,
+            "\nTried to set multiple uses to iteration domain: ",
+            inp_id->toString(),
+            "\nWhich is not supported, tried to set expr:\n ",
+            def->toString(),
+            "However the following expression was already set:\n ",
+            inp_use->toString());
+        inp_use = def;
+      }
+    }
+  }
+}
+
 void IterDomainGraph::initialIdProcessing(Fusion* fusion) {
-  // Initialize entries for every iteration domain and mark view like iteration
-  // domains and leaf iteration domains.
+  // Initialize entries for every iteration domain and mark view like
+  // iteration domains and leaf iteration domains.
   for (auto tv : ir_utils::allTvs(fusion)) {
     const auto& domain = tv->domain()->domain();
     auto all_ids = ir_utils::allIDsOf(tv);
@@ -357,9 +476,9 @@ void IterDomainGraph::initialIdProcessing(Fusion* fusion) {
       // Check if this id is a view like rfactor id
       bool is_view_rfactor_id = false;
       if (view_like_domain && id->isRFactorProduct()) {
-        // If the tensor domain is a view like domain, and the iteration domain
-        // is marked as an rfactor product and is in the rfactor domain, it's a
-        // view like rfactor iteration domain
+        // If the tensor domain is a view like domain, and the iteration
+        // domain is marked as an rfactor product and is in the rfactor
+        // domain, it's a view like rfactor iteration domain
         const auto& rfactor_domain = tv->domain()->getMaybeRFactorDomain();
         if (std::find(rfactor_domain.begin(), rfactor_domain.end(), id) !=
             rfactor_domain.end()) {
@@ -470,6 +589,21 @@ void mapMaybeSwizzleOp(
 }
 
 } // namespace
 
+void IterDomainGraph::mapThroughLoopSwizzles(IdMappingMode mode) {
+  for (const auto& use_it : id_uses_) {
+    auto use = use_it.second;
+    if (auto swizzle_2d = dynamic_cast(use)) {
+      // Map each input to its corresponding output on the given
+      // disjoint set if this is a loop swizzle. Loop swizzles don't impact
+      // indexing, only iteration order.
+      if (swizzle_2d->swizzleMode() == SwizzleMode::Loop) {
+        mapIds(swizzle_2d->inX(), swizzle_2d->outX(), mode); // X input <-> X output
+        mapIds(swizzle_2d->inY(), swizzle_2d->outY(), mode); // Y input <-> Y output
+      }
+    }
+  }
+}
+
 void IterDomainGraph::mapExact(Expr* expr) {
   TensorView* c_tv = ir_utils::getTvOutput(expr);
 
@@ -482,6 +616,11 @@ void IterDomainGraph::mapExact(Expr* expr) {
         PairwiseRootDomainMap(p_tv, c_tv, true)
             .mapConsumerToProducer(c_tv->domain(), p_tv->domain());
 
+    for (auto c_id : getSortedKeys(exact_c2p_root_map, Statement::lessThan)) { // map each consumer root ID to its producer root ID
+      auto p_id = exact_c2p_root_map.at(c_id);
+      mapIds(c_id, p_id, IdMappingMode::EXACT); // mapIds also propagates through matching definitions and uses
+    }
+
     // Same as permissive above but for exact
     auto exact_replay_PasC = BestEffortReplay(
         p_tv->domain()->domain(), c_tv->domain()->domain(), exact_c2p_root_map);
@@ -490,20 +629,15 @@ void IterDomainGraph::mapExact(Expr* expr) {
 
     for (auto c_id : getSortedKeys(exact_c2p_map, Statement::lessThan)) {
       auto p_id = exact_c2p_map.at(c_id);
-      mapIds(c_id, p_id, IdMappingMode::EXACT);
 
-      // TODO: consumers/producers should be on a per map basis, mapping should
-      // include unique expr between the disjoint sets
+      // TODO: consumers/producers should be on a per map basis, mapping
+      // should include unique expr between the disjoint sets
       consumers_.at(p_id).pushBack(c_id);
       producers_.at(c_id).pushBack(p_id);
-
-      // Add the swizzle inputs to the same
-      // disjoint set as well if either c_id
-      // or p_id is swizzle output.
-      mapMaybeSwizzleOp(disjointIdsSet(IdMappingMode::EXACT), p_id);
-      mapMaybeSwizzleOp(disjointIdsSet(IdMappingMode::EXACT), c_id);
     }
   }
+
+  mapThroughLoopSwizzles(IdMappingMode::EXACT);
 }
 
 void IterDomainGraph::mapPermissiveAndLoop(Expr* expr) {
@@ -562,6 +696,8 @@ void IterDomainGraph::mapPermissiveAndLoop(Expr* expr) {
       }
     }
   }
+
+  mapThroughLoopSwizzles(IdMappingMode::PERMISSIVE);
 }
 
 void IterDomainGraph::mapRFactorExprs(Fusion* fusion) {
@@ -719,6 +855,8 @@ void IterDomainGraph::buildAlmostExactMap() {
   // Build almost exact map by forwarding through broadcast axes
   disjointIdsSet(IdMappingMode::ALMOSTEXACT) =
       disjointIdsSet(IdMappingMode::EXACT);
+  disjointExprsSet(IdMappingMode::ALMOSTEXACT) = // the ALMOSTEXACT expr set starts as a copy of EXACT, like the ID set above
+      disjointExprsSet(IdMappingMode::EXACT);
   std::unordered_set visited;
   auto all_elements = disjointIdsSet(IdMappingMode::EXACT).getAllElements();
   for (auto entry : all_elements.vector()) {
@@ -752,27 +890,41 @@ void IterDomainGraph::buildAlmostExactMap() {
 void IterDomainGraph::build(Fusion* fusion) {
   FusionGuard fg(fusion);
 
+  // Add uses to all iter domains. Runs first: mapIds reads id_uses_ with .at().
+  buildIterDomainUses(fusion);
+
   // Initialize the maps with all the IterDomains defined in the fusion.
   initialIdProcessing(fusion);
 
-  for (auto expr : fusion->exprs()) {
-    if (!ir_utils::isTvOp(expr)) {
-      continue;
-    }
+  // Keep only TensorView expressions; all three mapping passes below run on these.
+  auto all_exprs = fusion->exprs();
+  std::vector tv_exprs;
+
+  std::copy_if(
+      all_exprs.begin(),
+      all_exprs.end(),
+      std::back_inserter(tv_exprs),
+      [](Expr* expr) { return ir_utils::isTvOp(expr); });
 
+  for (auto expr : tv_exprs) {
     // Connect multi-output expressions as they're trivial to connect.
     mapMultiOutput(expr);
+  }
 
+  for (auto expr : tv_exprs) { // non-TV exprs stay filtered out, as in the single loop this replaces
     // Connect ID's on the exact dimension
     mapExact(expr);
+  }
 
+  for (auto expr : tv_exprs) { // non-TV exprs stay filtered out, as in the single loop this replaces
     // Connect across the permissive, loop, and for now consumer_, producer_
     // dimensions.
     mapPermissiveAndLoop(expr);
   }
 
-  // Map forward and backward through TV root<->rfactor to cross map connections
-  // that are not explicitly defined through input<->output expression maps.
+  // Map forward and backward through TV root<->rfactor to cross map
+  // connections that are not explicitly defined through input<->output
+  // expression maps.
   mapRFactorExprs(fusion);
 
   buildAlmostExactMap();
@@ -825,7 +977,8 @@ void ComputeAtMap::allocateIndexVariables() {
     // first allocate thread and grid parallel indices:
     //  The validation pass will check that the parallel bindings within the
     //  loop disjoint IDs set are consistent so all the loops within this
-    //  disjoint set will be realized implicitly using parallel index variables.
+    //  disjoint set will be realized implicitly using parallel index
+    //  variables.
     if (std::any_of(
             loop_disjoint_set->vector().begin(),
             loop_disjoint_set->vector().end(),
diff --git a/torch/csrc/jit/codegen/cuda/compute_at_map.h b/torch/csrc/jit/codegen/cuda/compute_at_map.h
index dd5173fb72c03..46a9bc090618d 100644
--- a/torch/csrc/jit/codegen/cuda/compute_at_map.h
+++ b/torch/csrc/jit/codegen/cuda/compute_at_map.h
@@ -66,6 +66,9 @@ class TORCH_CUDA_CU_API IterDomainGraph {
   // Returns the disjoint set according to one of the mapping mode types.
   const DisjointSets& getDisjointIdsSet(IdMappingMode mode) const;
 
+  // Returns the disjoint set of expressions for one of the mapping mode types.
+  const DisjointSets& getDisjointExprsSet(IdMappingMode mode) const;
+
   // Consumers and producers is not symmetric like the other sets
   const std::unordered_map>&
   consumers() const {
@@ -103,6 +106,9 @@ class TORCH_CUDA_CU_API IterDomainGraph {
 
   // ======= START Iteration domain build process in order called =======
 
+  // Fills id_uses_ for all IterDomains active in the fusion.
+  void buildIterDomainUses(Fusion* fusion);
+
   // Initializes entries for the provided IterDomain in the overall
   // IterDomainGraph
   void initializeId(IterDomain* id, bool is_view_rfactor_id, bool is_leaf_id);
@@ -126,6 +132,10 @@ class TORCH_CUDA_CU_API IterDomainGraph {
   // producer_
   void mapPermissiveAndLoop(Expr* expr);
 
+  // Maps the inputs of every loop swizzle found in id_uses_ to its outputs in the
+  // given mode: loop swizzles only change traversal order, not indexing.
+  void mapThroughLoopSwizzles(IdMappingMode mode);
+
   // Propagates forward then backward through all view like rfactor
   // transformations to map cross view operations.
   //
@@ -144,10 +154,12 @@ class TORCH_CUDA_CU_API IterDomainGraph {
   // Non-const internal only version of getDisjointIdsSet.
   DisjointSets& disjointIdsSet(IdMappingMode mode);
 
-  // Simple alias
-  void mapIds(IterDomain* id0, IterDomain* id1, IdMappingMode mode) {
-    disjointIdsSet(mode).mapEntries(id0, id1);
-  }
+  // Non-const internal only version of getDisjointExprsSet.
+  DisjointSets& disjointExprsSet(IdMappingMode mode);
+
+  // Maps id0 and id1 in disjointIdsSet[mode]. Outside LOOP mode it also maps their
+  // matching definitions and uses in disjointExprsSet[mode] and propagates through them.
+  void mapIds(IterDomain* id0, IterDomain* id1, IdMappingMode mode);
 
   // Checks if expr's are considered "the same" where sameness inputs and
   // outputs in the same position across expressions map with provided
@@ -156,20 +168,29 @@ class TORCH_CUDA_CU_API IterDomainGraph {
   //   will map outputs
   // else
   //   will map inputs
-  // in the provided mode
-  void mapThroughExpr(
+  // in the provided mode.
+  // Returns true if the expressions matched and were mapped through, else false.
+  bool mapThroughExpr(
       Expr* first,
       Expr* second,
       bool forward,
       IdMappingMode mode);
 
-  // Keeps a disjoint set entry for all IterDomain mapping mode types.
+  // Keeps a disjoint set entry for all IterDomain for all mapping mode types.
   //
   // Using an array here might be nice, but it seems hard to use an enum as an
   // array key
   // https://stackoverflow.com/questions/2102582/how-can-i-count-the-items-in-an-enum
   std::unordered_map> disjoint_ids_;
 
+  // Keeps a disjoint set entry for all Expressions for all mapping mode types.
+  std::unordered_map> disjoint_exprs_;
+
+  // Active use of each IterDomain, filled by buildIterDomainUses (nullptr when the
+  // ID has no use). With multiple transformations an IterDomain could have several
+  // uses, however only one should be active in the given Fusion.
+  std::unordered_map id_uses_;
+
   // Consumers and producers is not symmetric like the other sets
   // TODO: Generalize to mapping type. Mappings between producer TV ids and
   // consumer TV ids depend on the mapping type.