diff --git a/CMakeLists.txt b/CMakeLists.txt index a7d5d08eec2..6325ac765c6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -353,6 +353,7 @@ if(BUILD_TEST) list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_gpu_outer_reduction.cpp) list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_loop_rotation.cpp) list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_gpu_shift.cpp) + list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_resize.cpp) list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_gpu_tensorcore.cpp) list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_gpu_matmul_sass.cpp) list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_gpu_view.cpp) diff --git a/csrc/codegen.cpp b/csrc/codegen.cpp index 922eede6b84..35c53d55bca 100644 --- a/csrc/codegen.cpp +++ b/csrc/codegen.cpp @@ -2774,6 +2774,41 @@ class CudaKernelGenerator : private OptOutConstDispatch { indent() << "NVFUSER_UPDATE_MAGIC_ZERO\n"; } + void handle(const CatOp* cat) final { + auto out = gen(cat->output(0)); + + // Generate code like: + // if (consumer_idx < producer_0_extent) { + // consumer[consumer_idx] = produce_0[producer_idx0]; + // } else if (consumer_idx < producer_1_extent) { + // consumer[consumer_idx] = produce_1[producer_idx1]; + // } else if (consumer_idx < producer_2_extent) { + // consumer[consumer_idx] = produce_2[producer_idx2]; + // } else { + // consumer[consumer_idx] = produce_3[producer_idx3]; + // } + + for (const auto i : c10::irange(cat->inputs().size())) { + auto inp = cat->input(i)->as(); + auto inp_str = gen(inp); + if (i < cat->inputs().size() - 1) { + if (i == 0) { + indent() << "if ("; + } else { + indent() << "} else if ("; + } + code_ << gen(cat->getPred(i)) << ") {\n"; + } else { + // last case doesn't need to be predicated + indent() << "} else {\n"; + } + + indent() << kTab << out << " = " << gen(inp) << ";\n"; + } + + indent() << "}\n"; + } + private: std::stringstream code_; const kir::Kernel* kernel_; diff --git a/csrc/compute_at_map.cpp b/csrc/compute_at_map.cpp index 43cc8b00411..854c8d1d602 100644 --- a/csrc/compute_at_map.cpp +++ b/csrc/compute_at_map.cpp @@ -99,8 +99,8 @@ bool IterDomainGraph::exprsMap( } TORCH_INTERNAL_ASSERT( - first->isA() || first->isA(), - "Merge and split are the only expressions supported through rfactor operations in compute at map, but found:\n", + first->isA() || first->isA() || first->isA(), + "Merge, split and resize are the only expressions supported through rfactor operations in compute at map, but found:\n", first->toString()); auto first_ids = ir_utils::filterByType( @@ -176,6 +176,15 @@ bool IterDomainGraph::exprsMap( } } + if (first->isA()) { + auto first_resize = first->as(); + auto second_resize = second->as(); + if (!first_resize->leftExpand()->sameAs(second_resize->leftExpand()) || + !first_resize->rightExpand()->sameAs(second_resize->rightExpand())) { + return false; + } + } + return true; } @@ -211,6 +220,7 @@ void IterDomainGraph::mapThroughExpr(Expr* first, Expr* second, bool forward) { for (auto out_i : c10::irange(first_ids.size())) { exact_nodes_.mapEntries(first_ids[out_i], second_ids[out_i]); permissive_nodes_.mapEntries(first_ids[out_i], second_ids[out_i]); + permissive_resize_nodes_.mapEntries(first_ids[out_i], second_ids[out_i]); } } @@ -407,6 +417,7 @@ void IterDomainGraph::build(Fusion* fusion) { auto id0 = *disjoint_set->begin(); for (auto id1 : disjoint_set->vector()) { permissive_nodes_.mapEntries(id0, id1); + permissive_resize_nodes_.mapEntries(id0, id1); exact_nodes_.mapEntries(id0, id1); sibling_sets_.mapEntries(id0, id1); } @@ 
-430,8 +441,22 @@ void IterDomainGraph::build(Fusion* fusion) { // Look for matching ID transformations in producer and consumer, replay // producer as consumer. We use the symmetric API of BestEffortReplay so // that both broadcast and squeeze are handled correctly. + // + // Note on the boolean flags: swizzles are skipped in both + // producer and consumer but resizes are not. const auto permissive_disjoint_sets = - BestEffortReplay::replayPasC(p_tv, c_tv, -1, pairwise_map) + BestEffortReplay::replayPasC( + p_tv, c_tv, -1, pairwise_map, true, true, false) + .getIterDomainEquivalence(); + + // Permissive-Resize map allows mappings of resize inputs and + // outputs + // + // Note on the boolean flags: swizzles and resizes are skipped + // in the permissive-resize map + const auto permissive_resize_disjoint_sets = + BestEffortReplay::replayPasC( + p_tv, c_tv, -1, pairwise_map, true, true, true) .getIterDomainEquivalence(); // For exact mapings do not map any broadcast dimensions to @@ -483,16 +508,12 @@ void IterDomainGraph::build(Fusion* fusion) { for (auto j : c10::irange(i + 1, vec.size())) { auto id2 = vec[j]; if (p_ids.count(id1) && c_ids.count(id2)) { - consumers_.at(id1).pushBack(id2); - producers_.at(id2).pushBack(id1); if (idIsAComputeAtLeafDomain(id1, p_tv, c_tv) && idIsALeafDomain(id2, c_tv)) { loop_nodes_.mapEntries(id1, id2); } } if (c_ids.count(id1) && p_ids.count(id2)) { - producers_.at(id1).pushBack(id2); - consumers_.at(id2).pushBack(id1); if (idIsAComputeAtLeafDomain(id2, p_tv, c_tv) && idIsALeafDomain(id1, c_tv)) { loop_nodes_.mapEntries(id1, id2); @@ -501,6 +522,31 @@ void IterDomainGraph::build(Fusion* fusion) { } } } + + // Mostly the same as the above for the permissive map but + // nothing to do for the loop map. + // The producer and consumer maps are based on the most + // permissive mappings, so they are set using the + // permissive-resize mappings. 
+ for (auto& dset : permissive_resize_disjoint_sets.disjointSets()) { + auto& vec = dset->vector(); + for (auto i : c10::irange(vec.size())) { + auto id1 = vec[i]; + permissive_resize_nodes_.mapEntries(id1, vec[0]); + mapMaybeSwizzleOp(permissive_resize_nodes_, id1); + for (auto j : c10::irange(i + 1, vec.size())) { + auto id2 = vec[j]; + if (p_ids.count(id1) && c_ids.count(id2)) { + consumers_.at(id1).pushBack(id2); + producers_.at(id2).pushBack(id1); + } + if (c_ids.count(id1) && p_ids.count(id2)) { + producers_.at(id1).pushBack(id2); + consumers_.at(id2).pushBack(id1); + } + } + } + } } } } @@ -561,7 +607,7 @@ void IterDomainGraph::build(Fusion* fusion) { for (auto expr : exprs) { auto rfactor_inp_ids = ir_utils::filterByType(expr->inputs()); TORCH_INTERNAL_ASSERT( - expr->isA() || expr->isA(), + expr->isA() || expr->isA() || expr->isA(), "Wasn't expecting the expression type of:\n", expr->toString(), "\nto be an expression defined in an rfactor transformation."); @@ -688,6 +734,7 @@ void IterDomainGraph::initializeId( bool is_rfactor_id, bool is_leaf_id) { permissive_nodes_.initializeSet(id); + permissive_resize_nodes_.initializeSet(id); exact_nodes_.initializeSet(id); if (is_leaf_id) { loop_nodes_.initializeSet(id); @@ -1127,6 +1174,17 @@ void ComputeAtMap::buildConcreteIds() { auto concrete_id = computeConcreteId(first_id, IdMappingMode::LOOP); concrete_id_cache_[disjoint_set_shared_ptr] = concrete_id; } + + for (const auto& disjoint_set_shared_ptr : + id_graph_.permissiveResizeNodes().disjointSets()) { + TORCH_INTERNAL_ASSERT( + disjoint_set_shared_ptr->vector().size(), + "Cannot compute concrete id of empty set."); + auto first_id = disjoint_set_shared_ptr->vector().front(); + auto concrete_id = + computeConcreteId(first_id, IdMappingMode::PERMISSIVE_RESIZE); + concrete_id_cache_[disjoint_set_shared_ptr] = concrete_id; + } } bool ComputeAtMap::areExactExprs(Expr* expr_1, Expr* expr_2) { @@ -1349,6 +1407,8 @@ std::string ComputeAtMap::toString() const { ss << "Loop map:\n" << idGraphNodesToString(*this, IdMappingMode::LOOP); ss << "Permissive map:\n" << idGraphNodesToString(*this, IdMappingMode::PERMISSIVE); + ss << "Permissive-Resize map:\n" + << idGraphNodesToString(*this, IdMappingMode::PERMISSIVE_RESIZE); ss << "Consumer maps:\n"; for (auto key : getSortedKeys(id_graph_.consumers(), Statement::lessThan)) { auto consumers = id_graph_.consumers().at(key); @@ -1408,6 +1468,8 @@ const DisjointSets& ComputeAtMap::getIdSets( return id_graph_.loopNodes(); case IdMappingMode::PERMISSIVE: return id_graph_.permissiveNodes(); + case IdMappingMode::PERMISSIVE_RESIZE: + return id_graph_.permissiveResizeNodes(); } TORCH_INTERNAL_ASSERT(false, "Error with mapping mode provided."); } diff --git a/csrc/compute_at_map.h b/csrc/compute_at_map.h index 66169709aff..1d0e67bfecb 100644 --- a/csrc/compute_at_map.h +++ b/csrc/compute_at_map.h @@ -53,6 +53,10 @@ namespace nvfuser { // Map all iteration domains // Always contain root mappings (otherwise they could have been forwarded in // broadcast) +// IdMappingMode::PERMISSIVE_RESIZE +// Include everything in PERMISSIVE. Map also domains that are +// inputs and outputs of resize ops. Used for, e.g., propagating +// parallel types across those domains. 
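// An illustrative sketch of the new mode, assuming a built ComputeAtMap named
// ca_map and resize_in / resize_out being the input and output IterDomains of
// a resize op (these names are hypothetical, not from the patch); the two
// modes should differ exactly in whether such a pair is mapped:
//
//   ca_map.areMapped(resize_in, resize_out, IdMappingMode::PERMISSIVE);        // not mapped
//   ca_map.areMapped(resize_in, resize_out, IdMappingMode::PERMISSIVE_RESIZE); // mapped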
// IdMappingMode::EXACT // Don't map any broadcast axes to non-broadcast axes // Do not forward through any broadcast IDs @@ -79,6 +83,9 @@ class TORCH_CUDA_CU_API IterDomainGraph { const DisjointSets<IterDomain*>& loopNodes() const { return loop_nodes_; } + const DisjointSets<IterDomain*>& permissiveResizeNodes() const { + return permissive_resize_nodes_; + } // Consumers and producers is not symmetric like the other sets const std::unordered_map<IterDomain*, VectorOfUniqueEntries<IterDomain*>>& @@ -132,8 +139,11 @@ class TORCH_CUDA_CU_API IterDomainGraph { DisjointSets<IterDomain*> exact_nodes_; DisjointSets<IterDomain*> almost_exact_nodes_; DisjointSets<IterDomain*> loop_nodes_; + DisjointSets<IterDomain*> permissive_resize_nodes_; - // Consumers and producers is not symmetric like the other sets + // Consumers and producers is not symmetric like the other sets. + // Mapping is based on the most permissive map, i.e., the + // permissive-resize map. std::unordered_map<IterDomain*, VectorOfUniqueEntries<IterDomain*>> consumers_; std::unordered_map<IterDomain*, VectorOfUniqueEntries<IterDomain*>> diff --git a/csrc/contiguity.cpp b/csrc/contiguity.cpp index 02d9646d43a..a9482bec0ae 100644 --- a/csrc/contiguity.cpp +++ b/csrc/contiguity.cpp @@ -331,6 +331,48 @@ void OrderedIdInformation::handle(Swizzle2D* swizzle) { } } +void OrderedIdInformation::handle(Resize* resize) { + // Find inputs in the active_ids_ vector + const auto in_it = + std::find(active_ids_.begin(), active_ids_.end(), resize->in()); + + if (in_it == active_ids_.end()) { + return; + } + + auto in_pos = std::distance(active_ids_.begin(), in_it); + + // Find inputs in the ordered transforms map + const auto in_ordered_it = consistently_ordered_ids_.find(resize->in()); + + bool in_ordered = in_ordered_it != consistently_ordered_ids_.end(); + + // Get root ids of the two inputs + const auto in_root_ids_it = id_to_root_ids_.find(resize->in()); + + TORCH_INTERNAL_ASSERT( + in_root_ids_it != id_to_root_ids_.end(), + "Error replaying transforms in contiguous ID checker."); + + const auto& in_root_ids = in_root_ids_it->second; + + // Update map for outputs + // Remove inputs from the active_ids_ and insert the output ID + active_ids_[in_pos] = resize->out(); + + // Not completely certain, but propagating these properties should be + // fine + if (in_ordered) { + consistently_ordered_ids_.emplace(resize->out()); + } + + if (exclusivelyConsumesRoots(resize->in())) { + exclusively_consumes_roots_.emplace(resize->out()); + } + + id_to_root_ids_[resize->out()] = in_root_ids; +} + NonDivisibleSplitDependencies::NonDivisibleSplitDependencies( // TODO: Revisit reduction rfactor axes and propagation. Should probably use // ca_map to propagate non divisibility dependencies across exact map.
Still @@ -500,6 +542,19 @@ void ContigIDs::build(const std::vector& ids) { {root_domain_.begin(), root_domain_.end()}, {ids.begin(), ids.end()}); for (auto expr : exprs) { + if (auto resize = dynamic_cast(expr)) { + resize_deps_.insert(resize->out()); + } else { + if (std::any_of( + expr->inputs().begin(), expr->inputs().end(), [&](Val* inp) { + return inp->isA() && + resize_deps_.count(inp->as()); + })) { + for (auto out : ir_utils::filterByType(expr->outputs())) { + resize_deps_.insert(out); + } + } + } handle(expr); } } @@ -576,6 +631,12 @@ void ContigIDs::handle(Merge* merge) { return; } + // Don't allow contig indexing after resize as we need traverse back + // at least to direct outputs of resize ops + if (resize_deps_.count(merge->out())) { + return; + } + // All broadcasting if (last_root == nullptr) { return; diff --git a/csrc/contiguity.h b/csrc/contiguity.h index 6d83d1f2a0f..cf67c1a8c21 100644 --- a/csrc/contiguity.h +++ b/csrc/contiguity.h @@ -62,6 +62,8 @@ class OrderedIdInformation : public OptInDispatch { void handle(Swizzle2D* swizzle) override; + void handle(Resize* resize) override; + // Track which root ids were used to generate each iter domain std::unordered_map> id_to_root_ids_; @@ -255,6 +257,8 @@ class ContigIDs : public OptInDispatch { // cases, depending on specific swizzle type and axes. void handle(Swizzle2D* swizzle) override {} + void handle(Resize* resize) override {} + IterDomain* getCAIndexConcreteId(IterDomain* id) const; //! True if an ID is indexable. @@ -307,6 +311,9 @@ class ContigIDs : public OptInDispatch { std::unique_ptr consistent_transform_info_; NonDivisibleSplitDependencies non_divisible_id_info_; + + //! IDs that depend on resize output IDs + std::unordered_set resize_deps_; }; } // namespace nvfuser diff --git a/csrc/dispatch.cpp b/csrc/dispatch.cpp index 83ac89e7acc..6fe7e39d94f 100644 --- a/csrc/dispatch.cpp +++ b/csrc/dispatch.cpp @@ -187,6 +187,18 @@ void Expr::dispatch(T handler, Expr* expr) { ptr(handler)->handle(expr->as()); return; } + if (expr->isStrictlyA()) { + ptr(handler)->handle(expr->as()); + return; + } + if (expr->isStrictlyA()) { + ptr(handler)->handle(expr->as()); + return; + } + if (expr->isStrictlyA()) { + ptr(handler)->handle(expr->as()); + return; + } if (expr->isStrictlyA()) { ptr(handler)->handle(expr->as()); return; @@ -199,6 +211,10 @@ void Expr::dispatch(T handler, Expr* expr) { ptr(handler)->handle(expr->as()); return; } + if (expr->isStrictlyA()) { + ptr(handler)->handle(expr->as()); + return; + } if (expr->isStrictlyA()) { ptr(handler)->handle(expr->as()); return; @@ -459,6 +475,18 @@ void Expr::constDispatch(T handler, const Expr* expr) { ptr(handler)->handle(expr->as()); return; } + if (expr->isStrictlyA()) { + ptr(handler)->handle(expr->as()); + return; + } + if (expr->isStrictlyA()) { + ptr(handler)->handle(expr->as()); + return; + } + if (expr->isStrictlyA()) { + ptr(handler)->handle(expr->as()); + return; + } if (expr->isStrictlyA()) { ptr(handler)->handle(expr->as()); return; @@ -471,6 +499,10 @@ void Expr::constDispatch(T handler, const Expr* expr) { ptr(handler)->handle(expr->as()); return; } + if (expr->isStrictlyA()) { + ptr(handler)->handle(expr->as()); + return; + } if (expr->isStrictlyA()) { ptr(handler)->handle(expr->as()); return; @@ -859,6 +891,15 @@ void OptOutConstDispatch::handle(const BroadcastOp* stmt) { void OptOutConstDispatch::handle(const SqueezeOp* stmt) { unhandled(stmt); } +void OptOutConstDispatch::handle(const CatOp* stmt) { + unhandled(stmt); +} +void 
OptOutConstDispatch::handle(const PadOp* stmt) { + unhandled(stmt); +} +void OptOutConstDispatch::handle(const SliceOp* stmt) { + unhandled(stmt); +} void OptOutConstDispatch::handle(const Split* stmt) { unhandled(stmt); @@ -869,6 +910,9 @@ void OptOutConstDispatch::handle(const Merge* stmt) { void OptOutConstDispatch::handle(const Swizzle2D* stmt) { unhandled(stmt); } +void OptOutConstDispatch::handle(const Resize* stmt) { + unhandled(stmt); +} void OptOutConstDispatch::handle(const TransposeOp* stmt) { unhandled(stmt); } @@ -1044,6 +1088,15 @@ void OptOutDispatch::handle(BroadcastOp* stmt) { void OptOutDispatch::handle(SqueezeOp* stmt) { unhandled(stmt); } +void OptOutDispatch::handle(CatOp* stmt) { + unhandled(stmt); +} +void OptOutDispatch::handle(PadOp* stmt) { + unhandled(stmt); +} +void OptOutDispatch::handle(SliceOp* stmt) { + unhandled(stmt); +} void OptOutDispatch::handle(Split* stmt) { unhandled(stmt); @@ -1054,6 +1107,9 @@ void OptOutDispatch::handle(Merge* stmt) { void OptOutDispatch::handle(Swizzle2D* stmt) { unhandled(stmt); } +void OptOutDispatch::handle(Resize* stmt) { + unhandled(stmt); +} void OptOutDispatch::handle(TransposeOp* stmt) { unhandled(stmt); } diff --git a/csrc/dispatch.h b/csrc/dispatch.h index 03db6673143..dd66cd901ca 100644 --- a/csrc/dispatch.h +++ b/csrc/dispatch.h @@ -102,6 +102,9 @@ class ShiftOp; class GatherOp; class ViewAsScalar; class ViewOp; +class CatOp; +class PadOp; +class SliceOp; class AggregateExpr; class SendRecv; @@ -110,6 +113,7 @@ class SendRecv; class Split; class Merge; class Swizzle2D; +class Resize; namespace kir { class Predicate; @@ -182,10 +186,14 @@ class TORCH_CUDA_CU_API OptOutConstDispatch : public PolymorphicBase { virtual void handle(const MmaOp* stmt); virtual void handle(const BroadcastOp* stmt); virtual void handle(const SqueezeOp* stmt); + virtual void handle(const CatOp* stmt); + virtual void handle(const PadOp* stmt); + virtual void handle(const SliceOp* stmt); virtual void handle(const Split* stmt); virtual void handle(const Merge* stmt); virtual void handle(const Swizzle2D* stmt); + virtual void handle(const Resize* stmt); virtual void handle(const TransposeOp* stmt); virtual void handle(const ExpandOp* stmt); virtual void handle(const ShiftOp* stmt); @@ -260,10 +268,14 @@ class TORCH_CUDA_CU_API OptOutDispatch : public PolymorphicBase { virtual void handle(MmaOp* stmt); virtual void handle(BroadcastOp* stmt); virtual void handle(SqueezeOp* stmt); + virtual void handle(CatOp* stmt); + virtual void handle(PadOp* stmt); + virtual void handle(SliceOp* stmt); virtual void handle(Split* stmt); virtual void handle(Merge* stmt); virtual void handle(Swizzle2D* stmt); + virtual void handle(Resize* stmt); virtual void handle(TransposeOp* stmt); virtual void handle(ExpandOp* stmt); virtual void handle(ShiftOp* stmt); diff --git a/csrc/index_compute.cpp b/csrc/index_compute.cpp index 6fd2716bdf6..4e0c98e00a8 100644 --- a/csrc/index_compute.cpp +++ b/csrc/index_compute.cpp @@ -589,8 +589,38 @@ void IndexCompute::handle(Swizzle2D* swizzle_2d) { } } +void IndexCompute::handle(Resize* resize) { + auto out_id = maybeGetExactMapConcreteID(resize->out()); + auto in_id = maybeGetExactMapConcreteID(resize->in()); + + auto out_it = index_map_.find(out_id); + + if (out_it == index_map_.end()) { + return; + } + + const auto out_ind = out_it->second; + + if (isZero(out_id) || hasZeroMerged(out_id)) { + // When the out ID is (partially) zero, the in ID is not indexable. Don't + // add any new mapping to the index and extent maps. 
This is fine since when + // a resize shows up as part of rfactor transformations, the input to the + // resize is not indexed as the indexing is done using the rfactor root + // domain. This could be an issue when a resize is shows up outside of + // rfactor transfomations, but currently that only can happen when a + // producer tensor is transformed to look like a consumer. Since inlining is + // not allowed with resize, the out ID should never be a zero domain in that + // case. + return; + } else { + index_map_[in_id] = sub(out_ind, resize->leftExpand()); + extent_map_[in_id] = sub( + sub(getExtent(out_id), resize->leftExpand()), resize->rightExpand()); + } +} + void IndexCompute::handle(Expr* e) { - auto is_expected_type = e->isOneOf(); + auto is_expected_type = e->isOneOf(); TORCH_INTERNAL_ASSERT( is_expected_type, "Invalid expr type found in transform traversal."); BackwardVisitor::handle(e); @@ -1369,97 +1399,10 @@ std::vector Index::getGlobalProducerStridedIndices( const std::unordered_map& override_index) { FUSER_PERF_SCOPE("GpuLower::Lower::getGlobalProducerIndex"); - // Replay producer to look like consumer so we can index on producer since - // our loop nests look like consumer - auto pairwise_map = PairwiseRootDomainMap(producer_tv, consumer_tv, false); - - TensorDomain* producerAsC = - TransformReplay::replayPasC(producer_tv, consumer_tv, -1, pairwise_map) - .first; - - // Make the producer_tv look like consumer while performing indexing math - ir_utils::TVDomainGuard domain_guard(producer_tv, producerAsC); + auto root_indices = getProducerRootIndices( + producer_tv, consumer_tv, loops, rotated_loops, override_index); - // Map sent to best effort replay needs to match the exact incantation for - // compute_at_mode.cpp with MappingMode::Index - auto c2p_root_map = - PairwiseRootDomainMap(producer_tv, consumer_tv, true) - .mapConsumerToProducer(consumer_tv->domain(), producer_tv->domain()); - - // This replay has to be consistent with compute at index map. - BestEffortReplay replay_producer_as_consumer( - producer_tv->domain()->domain(), - consumer_tv->domain()->domain(), - c2p_root_map); - - auto c2p_map = replay_producer_as_consumer.getReplay(); - - // Make sure at least root domains are mapped even when extents may - // be different. This mapping is important for the indexing lookup - // tensors of PyTorch gather as a producer. The IDs of a lookup - // tensor may have larger extents than those of the corresponding - // output tensor, but the index expressions to those output IDs can - // still be used for the producer. Note that we always do not map - // the indirectly accessed ID and its corresponding output ID. The - // above relaxed mapping is only for the rest of the IDs. - // - // Note that when the consumer has swizzle, the swizzle are skipped. For - // example, if we have: - // consumer: - // root: I0, I1, I2 - // leaf: I0, I3, I4 - // producer: - // root I5, I6, I7 - // where I3, I4 = swizzle(I1, I2) , then the c2p map will be I3->I6, I4->I7, - // I1 and I2 are not mapped. For this case, we should allow the root unmapped, - // If we add I1->I6 and I2->I7, the c2p map will no longer be injective, which - // is not what we want. 
- const auto p2c_map_ = invertOneToOneMap(c2p_map); - for (const auto& kv : - PairwiseRootDomainMap(producer_tv, consumer_tv, true, false) - .mapConsumerToProducer( - consumer_tv->domain(), producer_tv->domain())) { - auto consumer_root_id = kv.first; - auto producer_root_id = kv.second; - if (c2p_map.find(consumer_root_id) == c2p_map.end() && - p2c_map_.find(producer_root_id) == p2c_map_.end()) { - c2p_map.emplace(consumer_root_id, producer_root_id); - } - } - - const auto p2c_map = invertOneToOneMap(c2p_map); - - // Forward vectorized IDs to index into producer correctly - // We want p_id to be vectorized like consumer just for the indexing, then we - // need to switch it back later. Store previous state here when changing. We - // need to do this as replaying producer as consumer can use replay best - // effort which means some domains may be producer's original domains. - std::vector> p_id_backup; - for (auto entry : c2p_map) { - auto ref_id = GpuLower::current()->caMap()->getConcreteMappedID( - entry.first, IdMappingMode::EXACT); - auto p_id = entry.second; - if (ref_id->getParallelType() == ParallelType::Vectorize) { - p_id_backup.emplace_back(std::make_pair(p_id, p_id->getParallelType())); - p_id->parallelize(ParallelType::Vectorize); - } else if (ref_id->getParallelType() == ParallelType::MisalignedVectorize) { - p_id->parallelize(ParallelType::MisalignedVectorize); - } - } - - auto producer_indexing_from_idgraph = getTensorIndexFromIdGraph( - loops, rotated_loops, consumer_tv, producer_tv, true, c2p_map); - - auto producer_indexing = producer_indexing_from_idgraph.index; - - // Revert p_ids - for (auto entry : p_id_backup) { - entry.first->parallelize(entry.second); - } - - // Indices should now be mapped onto IterDomains in producer, so just grab - // and use them. 
- auto root_dom = producer_tv->getMaybeRFactorDomain(); + const auto& root_dom = producer_tv->getMaybeRFactorDomain(); // TODO: Abstract stride logic to reuse with consumer indexing std::vector strides(root_dom.size(), nullptr); @@ -1515,44 +1458,7 @@ std::vector Index::getGlobalProducerStridedIndices( std::vector strided_inds( root_dom.size(), GpuLower::current()->kernel()->zeroVal()); for (const auto i : c10::irange(root_dom.size())) { - if (root_dom[i]->isReduction() || root_dom[i]->isBroadcast()) { - continue; - } - - Val* root_ind = nullptr; - auto override_it = override_index.find(root_dom[i]); - const bool is_overriden = override_it != override_index.end(); - if (is_overriden) { - root_ind = override_it->second; - } else if ( - producer_indexing.indexMap().find(root_dom[i]) != - producer_indexing.indexMap().end()) { - root_ind = producer_indexing.indexMap().at(root_dom[i]); - } else if (root_dom[i]->isBroadcast()) { - root_ind = GpuLower::current()->kernel()->zeroVal(); - } - - TORCH_INTERNAL_ASSERT( - root_ind != nullptr, - "Couldn't find root mapping for ", - producer_tv->toString(), - " dim: ", - i, - " id: ", - root_dom[i]->toString()); - - root_ind = getProducerIndexWithHalo( - producer_tv, i, root_ind, consumer_tv, is_overriden); - - root_ind = getProducerIndexWithGather( - root_ind, - i, - producer_tv, - consumer_tv, - producer_indexing_from_idgraph.concrete_index.indexMap()); - - root_ind = getProducerIndexWithPartialSplit( - root_ind, root_dom[i], producer_tv, consumer_tv); + Val* root_ind = root_indices.at(i); if (root_ind->isZeroInt()) { continue; @@ -1624,12 +1530,13 @@ std::vector Index::getNonGlobalProducerStridedIndices( const std::unordered_set& rotated_loops, const std::unordered_map& override_index) { const auto gpu_lower = GpuLower::current(); - // Replay producer to look like consumer so we can index on producer since our // loop nests look like consumer auto pairwise_map = PairwiseRootDomainMap(producer_tv, consumer_tv); + // Resize ops can be and should be replayed. 
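// (Annotation: the two trailing boolean arguments in the replayPasC call below
// are presumably replay_swizzle = false and replay_resize = true; the exact
// parameter names are not shown in this patch. The point is that the
// producer-as-consumer replay now traverses resize exprs so that producer
// indexing sees the consumer's resized domains.)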
auto producer_replayed_as_consumer = - TransformReplay::replayPasC(producer_tv, consumer_tv, -1, pairwise_map) + TransformReplay::replayPasC( + producer_tv, consumer_tv, -1, pairwise_map, false, true) .first; ir_utils::TVDomainGuard domain_guard( @@ -1856,14 +1763,25 @@ Val* Index::getLinearLogicalIndex( getGlobalConsumerStridedIndices(consumer_tv, loops, rotated_loops)); } -std::vector Index::getPerDimLogicalIndex( +std::vector Index::getConsumerPerDimLogicalIndex( TensorView* consumer_tv, const std::vector& loops, const std::unordered_set& rotated_loops) { auto guard = ir_utils::overrideContiguityGuard(consumer_tv, false); IndexFromIdGraph index_from_id_graph = getTensorIndexFromIdGraph(loops, rotated_loops, consumer_tv); - return getRootIndices(consumer_tv, loops, index_from_id_graph); + return getConsumerRootIndices(consumer_tv, loops, index_from_id_graph); +} + +std::vector Index::getProducerPerDimLogicalIndex( + TensorView* producer_tv, + const TensorView* consumer_tv, + const std::vector& loops, + const std::unordered_set& rotated_loops, + const std::unordered_map& override_index) { + auto guard = ir_utils::overrideContiguityGuard(producer_tv, false); + return getProducerRootIndices( + producer_tv, consumer_tv, loops, rotated_loops, override_index); } std::vector Index::getStrides(const TensorView* tv) { @@ -1917,7 +1835,7 @@ std::vector Index::getStrides(const TensorView* tv) { return strides; } -std::vector Index::getRootIndices( +std::vector Index::getConsumerRootIndices( const TensorView* tv, const std::vector& loops, const IndexFromIdGraph& index_from_id_graph) { @@ -1951,6 +1869,153 @@ std::vector Index::getRootIndices( return root_inds; } +std::vector Index::getProducerRootIndices( + TensorView* producer_tv, + const TensorView* consumer_tv, + const std::vector& loops, + const std::unordered_set& rotated_loops, + const std::unordered_map& override_index) { + FUSER_PERF_SCOPE("GpuLower::Lower::getProducerRootIndices"); + // Replay producer to look like consumer so we can index on producer since + // our loop nests look like consumer + auto pairwise_map = PairwiseRootDomainMap(producer_tv, consumer_tv, false); + + TensorDomain* producerAsC = + TransformReplay::replayPasC( + producer_tv, consumer_tv, -1, pairwise_map, false, true) + .first; + + // Make the producer_tv look like consumer while performing indexing math + ir_utils::TVDomainGuard domain_guard(producer_tv, producerAsC); + + // Map sent to best effort replay needs to match the exact incantation for + // compute_at_mode.cpp with MappingMode::Index + auto c2p_root_map = + PairwiseRootDomainMap(producer_tv, consumer_tv, true) + .mapConsumerToProducer(consumer_tv->domain(), producer_tv->domain()); + + // This replay has to be consistent with compute at index map. + BestEffortReplay replay_producer_as_consumer( + producer_tv->domain()->domain(), + consumer_tv->domain()->domain(), + c2p_root_map); + + auto c2p_map = replay_producer_as_consumer.getReplay(); + + // Make sure at least root domains are mapped even when extents may + // be different. This mapping is important for the indexing lookup + // tensors of PyTorch gather as a producer. The IDs of a lookup + // tensor may have larger extents than those of the corresponding + // output tensor, but the index expressions to those output IDs can + // still be used for the producer. Note that we always do not map + // the indirectly accessed ID and its corresponding output ID. The + // above relaxed mapping is only for the rest of the IDs. 
+ // + // Note that when the consumer has swizzle, the swizzle are skipped. For + // example, if we have: + // consumer: + // root: I0, I1, I2 + // leaf: I0, I3, I4 + // producer: + // root I5, I6, I7 + // where I3, I4 = swizzle(I1, I2) , then the c2p map will be I3->I6, I4->I7, + // I1 and I2 are not mapped. For this case, we should allow the root unmapped, + // If we add I1->I6 and I2->I7, the c2p map will no longer be injective, which + // is not what we want. + const auto p2c_map_ = invertOneToOneMap(c2p_map); + for (const auto& kv : + PairwiseRootDomainMap(producer_tv, consumer_tv, true, false) + .mapConsumerToProducer( + consumer_tv->domain(), producer_tv->domain())) { + auto consumer_root_id = kv.first; + auto producer_root_id = kv.second; + if (c2p_map.find(consumer_root_id) == c2p_map.end() && + p2c_map_.find(producer_root_id) == p2c_map_.end()) { + c2p_map.emplace(consumer_root_id, producer_root_id); + } + } + + const auto p2c_map = invertOneToOneMap(c2p_map); + + // Forward vectorized IDs to index into producer correctly + // We want p_id to be vectorized like consumer just for the indexing, then we + // need to switch it back later. Store previous state here when changing. We + // need to do this as replaying producer as consumer can use replay best + // effort which means some domains may be producer's original domains. + std::vector> p_id_backup; + for (auto entry : c2p_map) { + auto ref_id = GpuLower::current()->caMap()->getConcreteMappedID( + entry.first, IdMappingMode::EXACT); + auto p_id = entry.second; + if (ref_id->getParallelType() == ParallelType::Vectorize) { + p_id_backup.emplace_back(std::make_pair(p_id, p_id->getParallelType())); + p_id->parallelize(ParallelType::Vectorize); + } else if (ref_id->getParallelType() == ParallelType::MisalignedVectorize) { + p_id->parallelize(ParallelType::MisalignedVectorize); + } + } + + auto producer_indexing_from_idgraph = getTensorIndexFromIdGraph( + loops, rotated_loops, consumer_tv, producer_tv, true, c2p_map); + + auto producer_indexing = producer_indexing_from_idgraph.index; + + // Revert p_ids + for (auto entry : p_id_backup) { + entry.first->parallelize(entry.second); + } + + // Indices should now be mapped onto IterDomains in producer, so just grab + // and use them. 
+ auto root_dom = producer_tv->getMaybeRFactorDomain(); + + std::vector root_inds( + root_dom.size(), GpuLower::current()->kernel()->zeroVal()); + + for (const auto i : c10::irange(root_dom.size())) { + if (root_dom[i]->isReduction() || root_dom[i]->isBroadcast()) { + continue; + } + + Val* root_ind = nullptr; + auto override_it = override_index.find(root_dom[i]); + const bool is_overriden = override_it != override_index.end(); + if (is_overriden) { + root_ind = override_it->second; + } else if ( + producer_indexing.indexMap().find(root_dom[i]) != + producer_indexing.indexMap().end()) { + root_ind = producer_indexing.indexMap().at(root_dom[i]); + } + + TORCH_INTERNAL_ASSERT( + root_ind != nullptr, + "Couldn't find root mapping for ", + producer_tv->toString(), + " dim: ", + i, + " id: ", + root_dom[i]->toString()); + + root_ind = getProducerIndexWithHalo( + producer_tv, i, root_ind, consumer_tv, is_overriden); + + root_ind = getProducerIndexWithGather( + root_ind, + i, + producer_tv, + consumer_tv, + producer_indexing_from_idgraph.concrete_index.indexMap()); + + root_ind = getProducerIndexWithPartialSplit( + root_ind, root_dom[i], producer_tv, consumer_tv); + + root_inds.at(i) = root_ind; + } + + return root_inds; +} + std::vector Index::getGlobalConsumerStridedIndices( const TensorView* consumer_tv, const std::vector& loops, @@ -1964,8 +2029,8 @@ std::vector Index::getGlobalConsumerStridedIndices( auto strides = getStrides(consumer_tv); // if we need to override index, we need to generate the index from each // root axis firstly. - auto root_inds = getRootIndices(consumer_tv, loops, index_from_id_graph); - auto root_dom = consumer_tv->getMaybeRFactorDomain(); + auto root_inds = + getConsumerRootIndices(consumer_tv, loops, index_from_id_graph); // Global striding auto vectorize_shift = @@ -2158,7 +2223,6 @@ std::vector Index::getNonGlobalConsumerStridedIndices( strided_inds.push_back(db_strided_index); } } - return strided_inds; } @@ -2284,8 +2348,9 @@ struct PredicateDomainInfo { // set is used to remove redundant predicates when gathering // unswitch predicates. std::unordered_set covered_ids; - // True if this predicate is for a non-divisible split - bool is_non_divisible_split = false; + // True if this predicate is for an intermediate domain. Examples + // include domains with non-divisible split and resized domains. + bool is_intermediate_domain = false; }; // Find iteration domains in the history of a consumer to predicate comprised @@ -2302,7 +2367,14 @@ std::vector getPredicateContigIds( const std::unordered_map& consumer_index_map) { const auto gpu_lower = GpuLower::current(); - const auto& consumer_root_domain = consumer_tv->getRootDomain(); + // When there's a resize expr between the root and the rfactor + // domains, predicate the rfactor domain. Otherwise, predicate the + // root domain. The actual size of an IterDomain after resize + // changes, and the output IterDomain needs to be used to generate + // its predicate. + const auto& consumer_root_domain = ir_utils::hasResizedRfactor(consumer_tv) + ? 
consumer_tv->getMaybeRFactorDomain() : consumer_tv->getRootDomain(); if (consumer_root_domain.empty()) { return std::vector<PredicateDomainInfo>(); @@ -2598,7 +2670,7 @@ std::pair<Val*, Val*> getStartAndStopOffsetsForGather( std::pair<Val*, Val*> getStartAndStopLimitOffsets( IterDomain* consumer_id, bool padding_predicate, - bool non_divisible_pred) { + bool intermediate_domain_pred) { const auto gpu_lower = GpuLower::current(); TORCH_INTERNAL_ASSERT(consumer_id != nullptr); @@ -2606,7 +2678,7 @@ std::pair<Val*, Val*> getStartAndStopLimitOffsets( Val* start_limit = consumer_id->start(); Val* stop_limit = SimplifyingIrBuilder::negExpr(consumer_id->stopOffset()); - if (!non_divisible_pred) { + if (!intermediate_domain_pred) { AxisHaloInfo halo_info = gpu_lower->haloInfo()->getRootAxisInfo(consumer_id); @@ -2650,11 +2722,11 @@ std::pair<Val*, Val*> getStartAndStopOffsets( const std::unordered_map<IterDomain*, Val*>& consumer_stop_index_map, bool padding_predicate, bool unswitch, - bool non_divisible_pred) { + bool intermediate_domain_pred) { // By default, the offsets for the start and stop predicates are // just zero. All halo-related adjustments are done at root domains, // so consumer_id is not a root domain, no adjustment is required. - if (consumer_id->definition() != nullptr && !non_divisible_pred) { + if (consumer_id->definition() != nullptr && !intermediate_domain_pred) { return { GpuLower::current()->kernel()->zeroVal(), GpuLower::current()->kernel()->zeroVal()}; @@ -2666,7 +2738,7 @@ std::pair<Val*, Val*> getStartAndStopOffsets( Val* stop_offset = GpuLower::current()->kernel()->zeroVal(); // These adjustments are not required when predicating non-divisible splits - if (!non_divisible_pred) { + if (!intermediate_domain_pred) { if (consumer_def->isA<ShiftOp>()) { std::tie(start_offset, stop_offset) = getStartAndStopOffsetsForShift( consumer_tv, consumer_id, padding_predicate); @@ -2702,7 +2774,7 @@ std::pair<Val*, Val*> getStartAndStopOffsets( // Get the boundaries of two ends auto limits = getStartAndStopLimitOffsets( - consumer_id, padding_predicate, non_divisible_pred); + consumer_id, padding_predicate, intermediate_domain_pred); // At this point, we have everything to create both start and stop // predicates as: @@ -2722,26 +2794,6 @@ std::pair<Val*, Val*> getStartAndStopOffsets( return {start_offset, stop_offset}; } -// A partial value of a start offset is returned if determined to be -// safe. Nullptr is returned if it can be omitted completely. -Val* simplifyStartOffset(Val* start_offset) { - // Start predicate can be omitted when start_offset >= 0. - auto offset_val = start_offset->as<Int>()->value(); - if (offset_val.has_value() && offset_val.value() >= 0) { - return nullptr; - } - - // start_offset may look like min(0, window_index - pad). Then, can - // remove min and leave the rhs only.
- auto def = dynamic_cast(start_offset->definition()); - if (def != nullptr && def->getBinaryOpType() == BinaryOpType::Min && - def->lhs()->isZeroInt()) { - return def->rhs(); - } - - return start_offset; -} - bool canOmitStopPredicate( Val* stop_index, Val* stop_offset, @@ -2949,7 +3001,7 @@ std::vector Index::getReferenceRootPredicates( consumer_stop_index_map, shift_padding, unswitch_or_vec_loop != nullptr, - contig_id_entry.is_non_divisible_split); + contig_id_entry.is_intermediate_domain); auto stop_index = consumer_stop_indexing_it->second; auto start_index = consumer_start_index_map.at(contig_id); @@ -2975,18 +3027,13 @@ std::vector Index::getReferenceRootPredicates( // Build predicates for start positions as: // start_index + start_offset >= 0 - auto start_offset = simplifyStartOffset(info.start_offset_); - if (start_offset == nullptr) { - info.start_predicate_ = GpuLower::current()->kernel()->trueVal(); - } else { - auto offsetted_start_index = - SimplifyingIrBuilder::addExpr(start_index, start_offset); - auto start_pred = - SimplifyingIrBuilder::geExpr( - offsetted_start_index, GpuLower::current()->kernel()->zeroVal()) - ->as(); - info.start_predicate_ = start_pred; - } + auto offsetted_start_index = + SimplifyingIrBuilder::addExpr(start_index, info.start_offset_); + auto start_pred = + SimplifyingIrBuilder::geExpr( + offsetted_start_index, GpuLower::current()->kernel()->zeroVal()) + ->as(); + info.start_predicate_ = start_pred; // Build predicates for stop positions as: // stop_index + stop_offset < IterDomain::extent @@ -3039,7 +3086,7 @@ Val* Index::eye( const std::unordered_set& rotated_loops, DataType dtype) { auto indices = - Index::getPerDimLogicalIndex(consumer_tv, loops, rotated_loops); + Index::getConsumerPerDimLogicalIndex(consumer_tv, loops, rotated_loops); TORCH_INTERNAL_ASSERT(indices.size() == 2); auto result = castOp(dtype, eq(indices[0], indices[1])); GpuLower::current()->commonScalarMap().hoistScalar(result, loops); diff --git a/csrc/index_compute.h b/csrc/index_compute.h index 24dca245c94..5a21a34f4ea 100644 --- a/csrc/index_compute.h +++ b/csrc/index_compute.h @@ -75,6 +75,7 @@ class IndexCompute : public BackwardVisitor { void handle(Merge*) override; void handle(Expr*) override; void handle(Swizzle2D*) override; + void handle(Resize*) override; // return extent_map_[id] if exists, else return id->extent() Val* getExtent(IterDomain* id) const; @@ -327,12 +328,20 @@ class Index { // get the strides of a tensor used for the index lowering static std::vector getStrides(const TensorView* tv); - // get the root indices of a tensor used for the index lowering - static std::vector getRootIndices( + // get the root indices of a consumer tensor + static std::vector getConsumerRootIndices( const TensorView* tv, const std::vector& loops, const IndexFromIdGraph& index_from_id_graph); + // get the root indices of a producer tensor + static std::vector getProducerRootIndices( + TensorView* producer, + const TensorView* consumer, + const std::vector& loops, + const std::unordered_set& rotated_loops, + const std::unordered_map& override_index = {}); + public: // Producer if it's in global memory static std::vector getGlobalProducerStridedIndices( @@ -412,11 +421,20 @@ class Index { //! root domain of a consumer tensor. The returned index is intended //! to be used for the computation of some tensor factories, such as: //! 
eye - static std::vector getPerDimLogicalIndex( + static std::vector getConsumerPerDimLogicalIndex( TensorView* consumer_tv, const std::vector& loops, const std::unordered_set& rotated_loops); + //! Returns a vector of logical indices mapped onto the (rfactor) + //! root domain of a producer tensor. + static std::vector getProducerPerDimLogicalIndex( + TensorView* producer_tv, + const TensorView* consumer_tv, + const std::vector& loops, + const std::unordered_set& rotated_loops, + const std::unordered_map& override_index = {}); + //! Take a consumer tensorview and loop nest and generates predicates //! associated with the concrete roots of the loop nest. Returns a list of //! predicates, and a list of concrete roots they're associated with. It diff --git a/csrc/ir_internal_nodes.h b/csrc/ir_internal_nodes.h index 36bf64ec130..dc028ade1a3 100644 --- a/csrc/ir_internal_nodes.h +++ b/csrc/ir_internal_nodes.h @@ -1472,6 +1472,23 @@ class TORCH_CUDA_CU_API IterDomain : public Val { bool inner_split, bool trim_out_of_bounds); + //! Resize an IterDomain by expanding both the left and right sides + //! by given widths. The resulting IterDomain has an extent of + //! (left_expansion + in->extent() + right_expansion). Note that the + //! expansion factors can be negative, meaning the input IterDomain + //! is shrunk. This is the case when resize is used to represent + //! slice. + //! + //! When mark_as_rfactor is true, the output IterDomain + //! is marked as an rfactor domain. For example, expressions such as + //! PadOp and SliceOp resize IterDomains and generate rfactor + //! resized domains. + static IterDomain* resize( + IterDomain* in, + Val* left_expansion, + Val* right_expansion, + bool mark_as_rfactor = false); + bool isReduction() const { return getIterType() == IterType::Reduction; } @@ -2112,6 +2129,46 @@ class TORCH_CUDA_CU_API Swizzle2D : public Expr { } }; +//! IterDomain expression to resize +class TORCH_CUDA_CU_API Resize : public Expr { + public: + using Expr::Expr; + + // Expand the input domain by left_expand and right_expand for each + // of the start and end sides, respectively + Resize( + IrBuilderPasskey, + IterDomain* out, + IterDomain* in, + Val* left_expand, + Val* right_expand); + + NVFUSER_DECLARE_CLONE_AND_CREATE + + virtual const char* getOpString() const override { + return "Resize"; + } + + std::string toString(int indent_size = 0) const override; + std::string toInlineString(int indent_size = 0) const override; + + IterDomain* out() const { + return output(0)->as(); + } + + IterDomain* in() const { + return input(0)->as(); + } + + Val* leftExpand() const { + return attributeVal(0); + } + + Val* rightExpand() const { + return attributeVal(1); + } +}; + //! Integer value which has a special name //! //! These could be: @@ -2168,4 +2225,161 @@ class TORCH_CUDA_CU_API NamedScalar : public Val { std::string name_; }; +class TORCH_CUDA_CU_API PadOp : public Expr { + public: + using Expr::Expr; + + //! Pad a tensor as specified by a vector of integer scalars. For + //! the actual semantics, see the torch.pad documentation. Note that + //! unlike torch.pad, the pad_widths vector parameter must contain + //! width vals for all dimensions. For non-padded dimensions, width + //! vals should be integer zero. 
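// Illustrative example (hypothetical values), reading pad_widths in the
// per-axis order used by the accessors declared below: a 2-D input padded only
// on axis 1, by 2 elements at its start and 3 at its end, would have
// pad_widths = {0, 0, 2, 3}, so that
//   getPaddedAxes() -> {1}
//   getPadWidths(0) -> {0, 0}
//   getPadWidths(1) -> {2, 3}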
+ PadOp( + IrBuilderPasskey passkey, + TensorView* out, + TensorView* inp, + const std::vector& pad_widths); + + NVFUSER_DECLARE_CLONE_AND_CREATE + + virtual const char* getOpString() const override { + return "PadOp"; + } + + std::string toString(int indent_size = 0) const override; + std::string toInlineString(int indent_size = 0) const override; + + Val* out() const { + return output(0); + } + + Val* in() const { + return input(0); + } + + //! Return axes that are actually paded, i.e., those that have + //! non-zero pad widths + std::vector getPaddedAxes() const; + + //! Return pad widths of the given axis, which are just zero for non padded + //! dimensions + std::pair getPadWidths(int axis) const; + + //! Return the pad widths of all dimensions, including non-padded ones + std::vector getPadWidths() const; + + private: + //! Offset of pad_width inputs in the input vector + int getPadWidthInputOffset() const { + return 1; + } + + //! Iterator to the first pad_width input + auto getPadWidthInputBegin() const { + return inputs().cbegin() + getPadWidthInputOffset(); + } + + //! Iterator to the end of the pad_width inputs + auto getPadWidthInputEnd() const { + return inputs().cend(); + } +}; + +// Similar to at::indexing::Slice +struct Slice { + Val* start = nullptr; + Val* stop = nullptr; + Val* step = nullptr; +}; + +class TORCH_CUDA_CU_API SliceOp : public Expr { + public: + using Expr::Expr; + + SliceOp( + IrBuilderPasskey passkey, + TensorView* out, + TensorView* inp, + const std::vector& ranges); + + NVFUSER_DECLARE_CLONE_AND_CREATE + + virtual const char* getOpString() const override { + return "SliceOp"; + } + + std::string toString(int indent_size = 0) const override; + std::string toInlineString(int indent_size = 0) const override; + + Val* out() const { + return output(0); + } + + Val* in() const { + return input(0); + } + + std::vector getRanges() const; + + private: + //! Offset of ranges input in the input vector + int getRangeInputOffset() const { + return 1; + } + + //! Iterator to the first range inputs + auto getRangeInputBegin() const { + return inputs().cbegin() + getRangeInputOffset(); + } + + //! Iterator to the end of the range inputs + auto getRangeInputEnd() const { + return inputs().cend(); + } +}; + +class TORCH_CUDA_CU_API CatOp : public Expr { + public: + using Expr::Expr; + + CatOp( + IrBuilderPasskey passkey, + Val* out, + const std::vector& inputs, + int concatenated_dim); + + //! Create a cat op with the index and predicates for codegen. Only + //! used for the Kernel container + CatOp( + IrBuilderPasskey passkey, + Val* out, + const std::vector& inputs, + int concatenated_dim, + Val* concatenated_domain_index, + const std::vector& preds); + + NVFUSER_DECLARE_CLONE_AND_CREATE + + virtual const char* getOpString() const override { + return "CatOp"; + } + + std::string toString(int indent_size = 0) const override; + std::string toInlineString(int indent_size = 0) const override; + + int concatenatedDim() const { + return attribute(0)->as>()->value; + } + + //! The index val that determines which input tensor should be used + //! to fill the particular output position of this expression. Only + //! valid after indexing + Val* getConcatenatedDomainIndex() const; + + //! Gets a Bool indicating if the input tensor specified by + //! tensor_idx should be used to fill the output tensor. Only valid + //! 
with the Kernel container + Bool* getPred(int input_idx) const; +}; + } // namespace nvfuser diff --git a/csrc/ir_nodes.cpp b/csrc/ir_nodes.cpp index b07bc00f8bb..36f05cf85b5 100644 --- a/csrc/ir_nodes.cpp +++ b/csrc/ir_nodes.cpp @@ -23,6 +23,8 @@ #include #include +#include +#include #include #include #include @@ -2135,6 +2137,52 @@ std::pair IterDomain::swizzle( return std::make_pair(out_x, out_y); } +IterDomain* IterDomain::resize( + IterDomain* in, + Val* left_expansion, + Val* right_expansion, + bool mark_as_rfactor) { + TORCH_CHECK( + left_expansion->isIntegralScalar(), + "Expansion factor must be an integer scalar: ", + left_expansion->toString()); + TORCH_CHECK( + right_expansion->isIntegralScalar(), + "Expansion factor must be an integer scalar: ", + right_expansion->toString()); + + // Only Inteation is considered for now. + TORCH_CHECK( + in->getIterType() == IterType::Iteration || + in->getIterType() == IterType::Broadcast, + "Not a valid IterType: ", + in->getIterType()); + + TORCH_CHECK( + in->start()->isZeroInt(), + "Non-zero start not supported: ", + in->toString()); + TORCH_CHECK( + in->stopOffset()->isZeroInt(), + "Non-zero stop offset not considered: ", + in->toString()); + + Val* resized_id_size = SimplifyingIrBuilder::addExpr( + SimplifyingIrBuilder::addExpr(in->extent(), left_expansion), + right_expansion); + + auto resized_id = + IterDomainBuilder(in->container()->zeroVal(), resized_id_size->as()) + .is_rfactor_domain(mark_as_rfactor) + .iter_type(in->getIterType()) + .build(); + + IrBuilder::create( + in->container(), resized_id, in, left_expansion, right_expansion); + + return resized_id; +} + // TODO: We should change parallelize interface to be on tensorview or at least // vectorize should be done on tensorview. This would let us check that we don't // vectorize to the left of the computeAt domain, and could allow us to do some @@ -2953,6 +3001,37 @@ std::string Swizzle2D::toInlineString(int indent_size) const { NVFUSER_DEFINE_CLONE_AND_CREATE(Swizzle2D) +Resize::Resize( + IrBuilderPasskey passkey, + IterDomain* out, + IterDomain* in, + Val* left, + Val* right) + : Expr(passkey) { + addOutput(out); + addInput(in); + addAttribute(left); + addAttribute(right); +} + +std::string Resize::toString(int indent_size) const { + std::stringstream ss; + ss << "Resize: "; + ss << in()->toString(); + ss << " by " << leftExpand()->toInlineString() << " and " + << rightExpand()->toInlineString(); + ss << " -> "; + ss << out()->toString(); + ss << "\n"; + return ss.str(); +} + +std::string Resize::toInlineString(int indent_size) const { + TORCH_CHECK(false, "Resize can not be printed inline"); +} + +NVFUSER_DEFINE_CLONE_AND_CREATE(Resize) + NamedScalar::NamedScalar( IrBuilderPasskey passkey, std::string name, @@ -3034,4 +3113,243 @@ c10::optional NamedScalar::getParallelIndex() const { return c10::nullopt; } +PadOp::PadOp( + IrBuilderPasskey passkey, + TensorView* out, + TensorView* inp, + const std::vector& pad_widths) + : Expr(passkey) { + const auto ndims = + TensorDomain::noReductions(inp->getMaybeRFactorDomain()).size(); + TORCH_INTERNAL_ASSERT( + pad_widths.size() % 2 == 0, + "Invalid size of padding width vector: ", + pad_widths.size(), + ". Number of width vals must be even."); + TORCH_INTERNAL_ASSERT( + pad_widths.size() == ndims * 2, + "Invalid size of padding width vector: ", + pad_widths.size(), + ". All dimensions, padded or not, must have width vals. 
Use zero for non non-padded dimensions."); + addOutput(out); + addInput(inp); + for (auto width : pad_widths) { + TORCH_CHECK(width != nullptr, "Padding width must not be nullptr"); + addInput(width); + } +} + +NVFUSER_DEFINE_CLONE_AND_CREATE(PadOp) + +std::string PadOp::toString(int indent_size) const { + std::stringstream ss; + indent(ss, indent_size) << out()->toString() << "\n"; + indent(ss, indent_size) << " = pad( " << in()->toString() << ", {" + << toDelimitedString(getPadWidths()) << "}" + << " )\n"; + return ss.str(); +} + +std::string PadOp::toInlineString(int indent_size) const { + TORCH_CHECK(false, "Tensor op can not be printed inline"); +} + +std::vector PadOp::getPaddedAxes() const { + auto num_dims = out()->as()->getRootDomain().size(); + std::vector padded_axes; + for (const auto i : c10::irange(num_dims)) { + auto [left_pad, right_pad] = getPadWidths(i); + // Filter out non-padded dimension + if (left_pad->isZeroInt() && right_pad->isZeroInt()) { + continue; + } + padded_axes.push_back(i); + } + return padded_axes; +} + +std::vector PadOp::getPadWidths() const { + return {getPadWidthInputBegin(), getPadWidthInputEnd()}; +} + +std::pair PadOp::getPadWidths(int axis) const { + const auto num_dims = + static_cast(out()->as()->getRootDomain().size()); + + if (axis < 0) { + axis += num_dims; + } + + TORCH_CHECK(axis >= 0 && axis < num_dims, "Invalid axis: ", axis); + + return std::make_pair( + (*(getPadWidthInputBegin() + axis * 2))->as(), + (*(getPadWidthInputBegin() + axis * 2 + 1))->as()); +} + +SliceOp::SliceOp( + IrBuilderPasskey passkey, + TensorView* out, + TensorView* inp, + const std::vector& ranges) + : Expr(passkey) { + const auto ndims = + TensorDomain::noReductions(inp->getMaybeRFactorDomain()).size(); + TORCH_INTERNAL_ASSERT( + ndims == ranges.size(), + "The range vector must have the same number of Slice descriptors. 
Given: ", + ranges.size(), + ", Expected: ", + ndims); + + addOutput(out); + addInput(inp); + for (const auto& range : ranges) { + TORCH_INTERNAL_ASSERT(range.start != nullptr, "nullptr not allowed"); + TORCH_INTERNAL_ASSERT(range.stop != nullptr, "nullptr not allowed"); + TORCH_INTERNAL_ASSERT(range.step != nullptr, "nullptr not allowed"); + addInput(range.start); + addInput(range.stop); + addInput(range.step); + } +} + +NVFUSER_DEFINE_CLONE_AND_CREATE(SliceOp) + +std::string SliceOp::toString(int indent_size) const { + std::stringstream ss; + indent(ss, indent_size) << out()->toString() << "\n"; + indent(ss, indent_size) << " = slice( " << in()->toString() << ", {"; + for (const auto& slice : getRanges()) { + ss << " {" + << toDelimitedString(std::vector{ + slice.start->toString(), + slice.stop->toString(), + slice.step->toString()}) + << "}"; + } + ss << " } )\n"; + return ss.str(); +} + +std::string SliceOp::toInlineString(int indent_size) const { + TORCH_CHECK(false, "Tensor op can not be printed inline"); +} + +std::vector SliceOp::getRanges() const { + const auto num_range_vals = + std::distance(getRangeInputBegin(), getRangeInputEnd()); + TORCH_INTERNAL_ASSERT( + num_range_vals % 3 == 0, + "Unexpected number of range vals: ", + num_range_vals); + auto ndims = num_range_vals / 3; + std::vector ranges(ndims); + auto range_val_it = getRangeInputBegin(); + for (const auto i : c10::irange(ndims)) { + ranges.at(i) = Slice{ + .start = *range_val_it, + .stop = *(range_val_it + 1), + .step = *(range_val_it + 2)}; + range_val_it += 3; + } + return ranges; +} + +CatOp::CatOp( + IrBuilderPasskey passkey, + Val* out, + const std::vector& inputs, + int concatenated_dim) + : Expr(passkey) { + addOutput(out); + for (auto inp : inputs) { + addInput(inp); + } + TORCH_INTERNAL_ASSERT( + concatenated_dim >= 0 && + concatenated_dim < + static_cast(ir_utils::getTv(out)->getRootDomain().size()), + "Invalid dimension to concatenate: ", + concatenated_dim); + + addAttribute(IrBuilder::create>( + passkey.ir_container_, concatenated_dim)); +} + +CatOp::CatOp( + IrBuilderPasskey passkey, + Val* out, + const std::vector& inputs, + int concatenated_dim, + Val* concatenated_domain_index, + const std::vector& preds) + : Expr(passkey) { + TORCH_INTERNAL_ASSERT( + passkey.ir_container_->isA(), + "Should only be used for Kernel container."); + + addOutput(out); + for (auto inp : inputs) { + addInput(inp); + } + addAttribute(IrBuilder::create>( + passkey.ir_container_, concatenated_dim)); + addAttribute(concatenated_domain_index); + for (auto pred : preds) { + addAttribute(pred); + } +} + +NVFUSER_DEFINE_CLONE_AND_CREATE(CatOp) + +std::string CatOp::toString(int indent_size) const { + std::stringstream ss; + indent(ss, indent_size) << output(0)->toString() << "\n"; + indent(ss, indent_size) << " = cat( "; + ss << toDelimitedString(inputs()); + ss << ", " << concatenatedDim(); + ss << " )\n"; + return ss.str(); +} + +std::string CatOp::toInlineString(int indent_size) const { + TORCH_CHECK(false, "Tensor op can not be printed inline"); +} + +Val* CatOp::getConcatenatedDomainIndex() const { + TORCH_INTERNAL_ASSERT( + container()->isA(), + "Should only be used for Kernel container."); + TORCH_INTERNAL_ASSERT(attributes().size() > 0, "No attribute found"); + TORCH_INTERNAL_ASSERT( + attribute(1) != nullptr, "nulllptr attribute is invalid"); + auto idx = attribute(1)->as(); + return idx; +} + +Bool* CatOp::getPred(int input_idx) const { + TORCH_INTERNAL_ASSERT( + container()->isA(), + "Should only be used for Kernel 
container."); + const auto num_input_tensors = static_cast(inputs().size()); + TORCH_INTERNAL_ASSERT( + input_idx < num_input_tensors, "Invalid input index: ", input_idx); + const auto attr_idx = input_idx + 2; + TORCH_INTERNAL_ASSERT( + attr_idx < static_cast(attributes().size()), + "Invalid attribute index: ", + attr_idx, + ", number of attributes: ", + attributes().size()); + auto attr = attribute(attr_idx); + TORCH_INTERNAL_ASSERT(attr != nullptr, "nullptr attribute is invalid"); + TORCH_INTERNAL_ASSERT( + attr->isA(), + "Attribute must be a Bool val: ", + attr->toInlineString()); + auto pred = attr->as(); + return pred; +} + } // namespace nvfuser diff --git a/csrc/ir_utils.cpp b/csrc/ir_utils.cpp index d6973be0928..83ed969d47a 100644 --- a/csrc/ir_utils.cpp +++ b/csrc/ir_utils.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include @@ -806,5 +807,19 @@ std::string varName(const Val* val) { return name.str(); } +bool hasResizedRfactor(const TensorView* tv) { + if (!tv->hasRFactor()) { + return false; + } + auto root_to_rf_exprs = StmtSort::getExprsBetween( + tv->fusion(), + {tv->getRootDomain().begin(), tv->getRootDomain().end()}, + {tv->getRFactorDomain().begin(), tv->getRFactorDomain().end()}); + return std::any_of( + root_to_rf_exprs.begin(), root_to_rf_exprs.end(), [](Expr* expr) { + return expr->isA(); + }); +} + } // namespace ir_utils } // namespace nvfuser diff --git a/csrc/ir_utils.h b/csrc/ir_utils.h index 3050266daa1..70e3052de92 100644 --- a/csrc/ir_utils.h +++ b/csrc/ir_utils.h @@ -371,5 +371,8 @@ TORCH_CUDA_CU_API bool isTorchGatherLookupTv(const Val* tv); TORCH_CUDA_CU_API std::string varName(const Val* val); +// Check if a tensor is resized as part of its root to rfactor transformations +bool hasResizedRfactor(const TensorView* tv); + } // namespace ir_utils } // namespace nvfuser diff --git a/csrc/lower2device.cpp b/csrc/lower2device.cpp index 5bd910e3abd..140daf8c3a8 100644 --- a/csrc/lower2device.cpp +++ b/csrc/lower2device.cpp @@ -358,6 +358,9 @@ void GpuLower::lower(Fusion* fusion) { validateSwizzle(fusion_); dumpExprsIfEnabled(fusion_->exprs(), "validateSwizzle"); + validateResize(fusion_); + dumpExprsIfEnabled(fusion_->exprs(), "validateResize"); + // Compute thread predicates. Depends on parallel_dimension_map_ thread_pred_map_.build(fusion_); dumpExprsIfEnabled(fusion_->exprs(), "build thread_pred_map_"); diff --git a/csrc/lower_index.cpp b/csrc/lower_index.cpp index 0bf61e02063..42c9512d308 100644 --- a/csrc/lower_index.cpp +++ b/csrc/lower_index.cpp @@ -1408,4 +1408,95 @@ void IndexLowering::allocateUniqueFusedReduction( insertAtTopLevel(fused_reduction_alloc_reduction); } +void IndexLowering::handle(const PadOp* pad) { + // Convert to a where op as: + // consumer[consumer_idx] = (produer_idx >= 0 && produer_idx < + // producer_extent) ? + // producer[producer_idx] : + // 0; + + auto producer_tv = pad->in()->as(); + auto consumer_tv = pad->out()->as(); + auto producer_doms = + TensorDomain::noReductions(producer_tv->getMaybeRFactorDomain()); + + const auto in = lowerSrcIndex(pad->in(), pad->out()); + const auto out = lowerDstIndex(pad->out()); + + DataType dt = producer_tv->getDataType().value(); + // Currently it's always padded by zero + const auto pad_val = isFloatingPointType(dt) + ? 
static_cast(IrBuilder::create(0, dt)) + : static_cast(IrBuilder::create(0, dt)); + + const auto producer_root_indices = Index::getProducerPerDimLogicalIndex( + producer_tv, consumer_tv, for_loops_, getRotatedLoop()); + + // Build a predicate for where + Val* pred = IrBuilder::create(true); + for (auto padded_axis : pad->getPaddedAxes()) { + auto producer_idx = producer_root_indices.at(padded_axis); + auto producer_root_id = producer_doms.at(padded_axis); + TORCH_INTERNAL_ASSERT(!producer_root_id->maybePartial()); + pred = SimplifyingIrBuilder::andExpr( + pred, + // idx >= 0 && idx < extent + SimplifyingIrBuilder::andExpr( + SimplifyingIrBuilder::geExpr( + producer_idx, GpuLower::current()->kernel()->zeroVal()), + SimplifyingIrBuilder::ltExpr( + producer_idx, producer_root_id->extent()))); + } + + pushBack(IrBuilder::create( + TernaryOpType::Where, out, pred, in, pad_val)); + GpuLower::current()->propagateExprInfo(pad, back()); +} + +void IndexLowering::handle(const SliceOp* slice) { + // TODO: Consider converting SliceOp to Set at the beginning of + // lowering + const auto in = lowerSrcIndex(slice->in(), slice->out()); + const auto out = lowerDstIndex(slice->out()); + + pushBack(IrBuilder::create(UnaryOpType::Set, out, in)); + GpuLower::current()->propagateExprInfo(slice, back()); +} + +void IndexLowering::handle(const CatOp* cat) { + // It's possible to lower CatOp to a series of IfThenElse or Where, + // but that would going to look really ugly. For now, rely on + // CudaKernelGenerator to produce code based on the predicates + // genereated here. + + const auto out = lowerDstIndex(cat->output(0)); + auto out_indices = Index::getConsumerPerDimLogicalIndex( + cat->output(0)->as(), for_loops_, getRotatedLoop()); + auto concatenated_dim_idx = out_indices.at(cat->concatenatedDim()); + + std::vector inputs(cat->inputs().size()); + std::vector preds(cat->inputs().size()); + Val* cur_extent = GpuLower::current()->kernel()->zeroVal(); + + for (const auto i : c10::irange(cat->inputs().size())) { + const auto inp = lowerSrcIndex(cat->input(i), cat->output(0)); + inputs.at(i) = inp; + + // Note the original extent is the extent of the root domain not + // rfactor domain + auto inp_concat_id = TensorDomain::noReductions( + cat->input(i)->as()->getRootDomain()) + .at(cat->concatenatedDim()); + cur_extent = add(cur_extent, inp_concat_id->extent()); + preds.at(i) = + IrBuilder::ltExpr(concatenated_dim_idx, cur_extent)->as(); + } + + auto lowered = IrBuilder::create( + out, inputs, cat->concatenatedDim(), concatenated_dim_idx, preds); + + pushBack(lowered); + GpuLower::current()->propagateExprInfo(cat, lowered); +} + } // namespace nvfuser diff --git a/csrc/lower_index.h b/csrc/lower_index.h index f121c482ad5..93212aefa73 100644 --- a/csrc/lower_index.h +++ b/csrc/lower_index.h @@ -63,6 +63,9 @@ class TORCH_CUDA_CU_API IndexLowering : private OptOutConstDispatch { void handle(const LoadStoreOp*) final; void handle(const MmaOp*) final; void handle(const BroadcastOp*) final; + void handle(const PadOp*) final; + void handle(const SliceOp*) final; + void handle(const CatOp*) final; void handle(const kir::ForLoop*) final; void handle(const kir::IfThenElse*) final; diff --git a/csrc/lower_index_compute.cpp b/csrc/lower_index_compute.cpp index 5b306cd3b51..3b9bca7a36a 100644 --- a/csrc/lower_index_compute.cpp +++ b/csrc/lower_index_compute.cpp @@ -928,7 +928,6 @@ IndexFromIdGraph getTensorIndexFromIdGraph( GpuLower::current()->haloInfo(), GpuLower::current()->concretizedBroadcastDomains(), p2c_map); - auto 
target_indexing = indexing.updateIndexCompute( target_tv->domain(), index_update_map, contig_finder); @@ -1288,9 +1287,30 @@ namespace { //! the vector ids by permissive compute at map. bool isPermissivelyMappedWithAny(IterDomain* id, const std::vector& ids) { return std::any_of(ids.begin(), ids.end(), [&](Val* val) { - return val->isA() && - GpuLower::current()->caMap()->areMapped( - id, val->as(), IdMappingMode::PERMISSIVE); + if (!(val->isA() && + GpuLower::current()->caMap()->areMapped( + id, val->as(), IdMappingMode::PERMISSIVE))) { + return false; + } + // When id is an input to resize, make sure the resize argumens + // are compatible. This is important when, for example, a tensor + // is padded two times differently but to the same shape, and the + // pad outputs are exactly mapped. In such a case, there're two + // paths from the post rfactor ID to the original input ID, and + // the correct path depends on the path where this producer is + // used as a producer. See the FusionPad8 test for a concrete + // example. + if (auto id_resize = dynamic_cast(id->uses().at(0))) { + auto mapped_id_resize = + dynamic_cast(val->as()->uses().at(0)); + TORCH_INTERNAL_ASSERT(mapped_id_resize != nullptr); + if (!(id_resize->leftExpand()->sameAs(mapped_id_resize->leftExpand()) && + id_resize->rightExpand()->sameAs( + mapped_id_resize->rightExpand()))) { + return false; + } + } + return true; }); } @@ -1390,7 +1410,6 @@ IterDomain* getRfactorIDToTraverse( const auto& rfactor_ids = GpuLower::current()->caMap()->getRfactorDomainsOfIdGroup( id, IdMappingMode::PERMISSIVE); - if (rfactor_ids.empty()) { return nullptr; } diff --git a/csrc/lower_predicate_elimination.cpp b/csrc/lower_predicate_elimination.cpp index a724b7033e1..093c366bb66 100644 --- a/csrc/lower_predicate_elimination.cpp +++ b/csrc/lower_predicate_elimination.cpp @@ -162,6 +162,12 @@ class PredicateAnalyzer : public OptOutDispatch { handle(merge->outer()); } + void handle(Resize* resize) override { + // resize outputs are guaranteed to match by the check above in + // handle(IterDomain*). + handle(resize->in()); + } + private: //! BestEffort map from consumer IDs to producer IDs const DisjointSets& disjoint_c2p_ids_; diff --git a/csrc/lower_shift.cpp b/csrc/lower_shift.cpp index e1f46e1edd2..18d038b3f63 100644 --- a/csrc/lower_shift.cpp +++ b/csrc/lower_shift.cpp @@ -445,25 +445,20 @@ void HaloInfo::build(TensorDomain* td) { } else { setHaloWidth(merge->out(), 0); } - } else if (auto swizzle = dynamic_cast(expr)) { - // Assume no halo on swizzled domain for now. - TORCH_INTERNAL_ASSERT( - getExtent(swizzle->inX()) == nullptr, - "Halo is not supported with swizzle. Halo-extended ID: ", - swizzle->inX()->toString(), - " used in ", - swizzle->toString()); - TORCH_INTERNAL_ASSERT( - getExtent(swizzle->inY()) == nullptr, - "Halo is not supported with swizzle. Halo-extended ID: ", - swizzle->inY()->toString(), - " used in ", - swizzle->toString()); - for (auto id : ir_utils::filterByType(expr->outputs())) { - setHaloWidth(id, 0); - } } else { - TORCH_INTERNAL_ASSERT(false, "Unsupported expr: ", expr); + // Assume no halo + for (auto input_id : ir_utils::filterByType(expr->inputs())) { + TORCH_INTERNAL_ASSERT( + getExtent(input_id) == nullptr, + "Halo is not supported. 
Halo-extended ID: ", + input_id->toString(), + " used in ", + expr->toString()); + } + for (auto output_id : + ir_utils::filterByType(expr->outputs())) { + setHaloWidth(output_id, 0); + } } } } @@ -839,28 +834,23 @@ std::unordered_map HaloInfo::buildConcreteHaloExtentMap( merge->out(), IdMappingMode::EXACT), 0); } - } else if (auto swizzle_2d = dynamic_cast(expr)) { - // Swizzle with halo not yet supported, just set the width - // to zero at the moment. - TORCH_INTERNAL_ASSERT( - local_halo_info.getHaloWidth( - GpuLower::current()->caMap()->getConcreteMappedID( - swizzle_2d->inX(), IdMappingMode::EXACT)) == 0 && - local_halo_info.getHaloWidth( - GpuLower::current()->caMap()->getConcreteMappedID( - swizzle_2d->inY(), IdMappingMode::EXACT)) == 0, - "Swizzle on ID with halo not yet supported."); - TORCH_INTERNAL_ASSERT("Swizzle on ID with halo not yet supported."); - local_halo_info.setHaloWidth( - GpuLower::current()->caMap()->getConcreteMappedID( - swizzle_2d->outX(), IdMappingMode::EXACT), - 0); - local_halo_info.setHaloWidth( - GpuLower::current()->caMap()->getConcreteMappedID( - swizzle_2d->outY(), IdMappingMode::EXACT), - 0); } else { - TORCH_INTERNAL_ASSERT(false, "Unsupported expr: ", expr); + // Halo not yet supported, just set the width to zero at the moment. + for (auto input_id : ir_utils::filterByType(expr->inputs())) { + TORCH_INTERNAL_ASSERT( + local_halo_info.getHaloWidth( + GpuLower::current()->caMap()->getConcreteMappedID( + input_id, IdMappingMode::EXACT)) == 0, + "Halo not yet supported: ", + input_id->toString()); + } + for (auto output_id : + ir_utils::filterByType(expr->outputs())) { + local_halo_info.setHaloWidth( + GpuLower::current()->caMap()->getConcreteMappedID( + output_id, IdMappingMode::EXACT), + 0); + } } } diff --git a/csrc/lower_unroll.cpp b/csrc/lower_unroll.cpp index 708f26c3d6d..feaab83796f 100644 --- a/csrc/lower_unroll.cpp +++ b/csrc/lower_unroll.cpp @@ -260,6 +260,9 @@ bool UnrollPass::canOmitElseClause(kir::ForLoop* fl) { const auto& pred_map = GpuLower::current()->threadPredMap(); + std::unordered_set all_exprs_inside_loop_nest; + std::unordered_set resize_exprs; + while (loops.size() > 0) { auto loop = loops.back(); loops.pop_back(); @@ -270,6 +273,16 @@ bool UnrollPass::canOmitElseClause(kir::ForLoop* fl) { if (lower_utils::hasBlockSync(expr, pred_map)) { return false; } + // Keep track of all expressions for additional check for + // resizing expressions + all_exprs_inside_loop_nest.insert(expr); + if (std::any_of( + expr->outputs().begin(), expr->outputs().end(), [](Val* output) { + return output->isA() && + ir_utils::hasResizedRfactor(output->as()); + })) { + resize_exprs.insert(expr); + } } // If the number of visits of the loop body per thread is one, the // unswitch predicate is sufficient. @@ -305,6 +318,35 @@ bool UnrollPass::canOmitElseClause(kir::ForLoop* fl) { } } + // If an expression generates a resized tensor and any of its + // dependencies appears in the loop nest, the else clause cannot be + // omitted. The tensors appearing before the resizing expression has + // a different shape than the output of the resizing expression and + // its subsequent consumers, so the unswitch predicates would + // include the predicates for both sizes, which means the larger + // tensors would still need the else clause. 
+ if (!resize_exprs.empty()) { + std::unordered_set resize_expr_inputs; + std::transform( + resize_exprs.begin(), + resize_exprs.end(), + std::inserter(resize_expr_inputs, resize_expr_inputs.begin()), + [](Expr* resize_expr) { return resize_expr->input(0); }); + if (std::any_of( + all_exprs_inside_loop_nest.begin(), + all_exprs_inside_loop_nest.end(), + [&](Expr* loop_expr) { + return std::any_of( + loop_expr->outputs().begin(), + loop_expr->outputs().end(), + [&](Val* expr_output) { + return resize_expr_inputs.count(expr_output); + }); + })) { + return false; + } + } + return true; } diff --git a/csrc/lower_utils.cpp b/csrc/lower_utils.cpp index 5c3d7fc75b0..e6a8025e4d2 100644 --- a/csrc/lower_utils.cpp +++ b/csrc/lower_utils.cpp @@ -148,6 +148,9 @@ bool isTvOp(const Expr* expr) { GatherOp, ViewAsScalar, ViewOp, + PadOp, + SliceOp, + CatOp, kir::GridReduction, kir::GroupedGridReduction, kir::GridBroadcast, diff --git a/csrc/lower_validation.cpp b/csrc/lower_validation.cpp index f11c35d3c59..92198e001aa 100644 --- a/csrc/lower_validation.cpp +++ b/csrc/lower_validation.cpp @@ -1337,4 +1337,25 @@ void validateLookupTV(Fusion* fusion) { } } +void validateResize(Fusion* fusion) { + auto fusion_vals = fusion->usedMathVals(); + for (auto tv : ir_utils::filterByType(fusion_vals)) { + // Make sure resize is only used as part of rfactor transformations + auto rf_to_leaf_exprs = StmtSort::getExprsBetween( + fusion, + {tv->getMaybeRFactorDomain().begin(), + tv->getMaybeRFactorDomain().end()}, + {tv->domain()->domain().begin(), tv->domain()->domain().end()}); + + TORCH_INTERNAL_ASSERT( + std::none_of( + rf_to_leaf_exprs.begin(), + rf_to_leaf_exprs.end(), + [](Expr* expr) { return expr->isA(); }), + "Invalid use of resize detected with ", + tv->toString(), + ". Resize may only be used as part of rfactor transformations."); + } +} + } // namespace nvfuser diff --git a/csrc/lower_validation.h b/csrc/lower_validation.h index 839608edf7e..2414adc08c7 100644 --- a/csrc/lower_validation.h +++ b/csrc/lower_validation.h @@ -77,4 +77,7 @@ void validateGroupedReductions(Fusion* fusion); //! Validate all of the lookup TVs are ensured to be fusion inputs void validateLookupTV(Fusion* fusion); +//! Validate resize usage +void validateResize(Fusion* fusion); + } // namespace nvfuser diff --git a/csrc/ops/alias.cpp b/csrc/ops/alias.cpp index 133cc4985b4..9ae559e1ef4 100644 --- a/csrc/ops/alias.cpp +++ b/csrc/ops/alias.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include @@ -378,4 +379,284 @@ TensorView* transpose(TensorView* x) { return transpose(x, 0, 1); } +// Padding widths are assumed to be non-negative. Currently there's no +// validation. +TensorView* pad(TensorView* inp, const std::vector& pad_widths) { + const auto inp_dom = TensorDomain::noReductions(inp->getMaybeRFactorDomain()); + const auto ndims = inp_dom.size(); + + TORCH_CHECK( + pad_widths.size() % 2 == 0 && pad_widths.size() / 2 <= ndims, + "Invalid number of padding widths: ", + pad_widths.size()); + + const auto num_padded_dims = pad_widths.size() / 2; + const auto num_non_padded_dims = ndims - num_padded_dims; + + std::vector root_ids(ndims); + std::vector rfactor_ids(ndims); + + // PadOp requires pad widths for all dimensions, even for non-padded + // ones. 
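A minimal integer-only sketch of the normalization performed below, assuming plain `int` widths in place of `Val*` (illustrative only):

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Returns one {left, right} pair per dimension, ordered outermost first to
// match the per-dimension loop that consumes it; outer dimensions that were
// not padded get zero widths.
std::vector<std::pair<int, int>> normalizePadWidths(
    size_t ndims,
    const std::vector<int>& pad_widths) {
  assert(pad_widths.size() % 2 == 0 && pad_widths.size() / 2 <= ndims);
  const size_t num_padded = pad_widths.size() / 2;
  const size_t num_non_padded = ndims - num_padded;

  std::vector<std::pair<int, int>> normalized;
  // Outer, non-padded dimensions are padded by zero.
  for (size_t i = 0; i < num_non_padded; ++i) {
    normalized.emplace_back(0, 0);
  }
  // torch-style widths list the innermost dimension first, so the pairs are
  // reversed to get outermost-first ordering.
  for (size_t i = 0; i < num_padded; ++i) {
    const size_t pair = num_padded - 1 - i;
    normalized.emplace_back(pad_widths.at(2 * pair), pad_widths.at(2 * pair + 1));
  }
  return normalized;
}

int main() {
  // {1, 2, 3, 4} pads the innermost dimension by (1, 2) and the next one by
  // (3, 4); the outermost dimension of a 3D tensor is left untouched.
  auto w = normalizePadWidths(3, {1, 2, 3, 4});
  assert(w.at(0) == std::make_pair(0, 0));
  assert(w.at(1) == std::make_pair(3, 4));
  assert(w.at(2) == std::make_pair(1, 2));
  return 0;
}
```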
+ + std::vector normalized_pad_widths; + + // Fill zero for non padded dimensions + for (const auto i : c10::irange(num_non_padded_dims)) { + (void)i; + normalized_pad_widths.push_back(FusionGuard::getCurFusion()->zeroVal()); + normalized_pad_widths.push_back(FusionGuard::getCurFusion()->zeroVal()); + } + + // torch.pad has padding widths of inner dimensions before outer + // dimensions + for (const auto i : c10::irange(num_padded_dims)) { + auto left_pad = pad_widths.at(num_padded_dims * 2 - (i + 1) * 2); + auto right_pad = pad_widths.at(num_padded_dims * 2 - (i + 1) * 2 + 1); + normalized_pad_widths.push_back(left_pad); + normalized_pad_widths.push_back(right_pad); + } + + // Indicates if any dimension is actually padded. Can be false even + // when non-empty padding width vector is passed + bool is_padded_any = false; + for (const auto idx : c10::irange(ndims)) { + auto inp_root_id = inp_dom.at(idx); + IterDomain* out_root_id = nullptr; + IterDomain* out_rf_id = nullptr; + auto left_pad = normalized_pad_widths.at(idx * 2); + auto right_pad = normalized_pad_widths.at(idx * 2 + 1); + if (idx < num_non_padded_dims || + (left_pad->isZeroInt() && right_pad->isZeroInt())) { + out_root_id = inp_root_id->cloneWithoutRFactor(); + out_rf_id = out_root_id; + } else { + out_root_id = + IterDomainBuilder(inp_root_id).is_rfactor_domain(true).build(); + // Expand the root domain and mark it as a rfactor domain + out_rf_id = IterDomain::resize(out_root_id, left_pad, right_pad, true); + is_padded_any = true; + } + root_ids.at(idx) = out_root_id; + rfactor_ids.at(idx) = out_rf_id; + } + + // If all of the padding widths are just zero, this is just a set op. + if (!is_padded_any) { + return set(inp); + } + + auto out = IrBuilder::create( + IrBuilder::create( + root_ids, + rfactor_ids, + rfactor_ids, + TensorDomain::getContiguityFilledWith(rfactor_ids, true)), + *inp->getDataType()); + + IrBuilder::create(out, inp, normalized_pad_widths); + + return out; +} + +// cat is implemented as PadOp and CatOp. Padding is done first to +// account for the size difference between each of the inputs and the +// output. All of the inputs to CatOp have the same shape as the +// output shape. +TensorView* cat(const std::vector& inputs, int cat_dim) { + TORCH_CHECK(!inputs.empty(), "No input tensor given"); + + const auto dtype = inputs.at(0)->getDataType().value(); + + std::vector> inp_doms; + int ndims = -1; + + for (auto inp : inputs) { + TORCH_CHECK( + inp->getDataType().value() == dtype, + "Can't concatenate tensors with different data types: ", + dtype, + ", ", + inp->getDataType().value()); + inp_doms.emplace_back( + TensorDomain::noReductions(inp->getMaybeRFactorDomain())); + auto i_ndims = static_cast(inp_doms.back().size()); + if (ndims == -1) { + ndims = i_ndims; + } else { + TORCH_CHECK( + ndims == i_ndims, + "Unexpected number of dimensions: ", + inp->toString(), + ", expected: ", + ndims); + } + } + + if (cat_dim < 0) { + cat_dim += ndims; + } + + TORCH_CHECK( + cat_dim >= 0 && cat_dim < ndims, "Invalid dimension to cat: ", cat_dim); + + // Special handling for the case where there's only one input + if (inputs.size() == 1) { + return set(inputs.at(0)); + } + + Val* concat_ext = nullptr; + + for (const auto i : c10::irange(inputs.size())) { + auto input_dim_extent = + inp_doms.at(i).at(cat_dim)->getMaybeExpandedExtent(); + concat_ext = SimplifyingIrBuilder::addExpr(concat_ext, input_dim_extent); + } + + // For each of the input tensors, create a new rfactor tensor by + // padding the concat dim. 
Padding is used here as it effectively + // embeds the resizing information of the concat operation. + + Val* left_pad = FusionGuard::getCurFusion()->zeroVal(); + Val* right_pad = concat_ext; + std::vector resized_inputs(inputs.size()); + for (const auto input_idx : c10::irange(inputs.size())) { + const auto& inp_dom = inp_doms.at(input_idx); + std::vector pad_widths(ndims * 2); + for (const auto dim : c10::irange(ndims)) { + auto inp_root_id = inp_dom.at(dim); + Val* left_pad_i = nullptr; + Val* right_pad_i = nullptr; + if (dim != cat_dim) { + left_pad_i = FusionGuard::getCurFusion()->zeroVal(); + right_pad_i = FusionGuard::getCurFusion()->zeroVal(); + } else { + // Resize the root ID so that it has the same extent as the + // concatenated ID. The expansion of both left and right sides + // is done so that this input tensor is positioned in a way + // that corresponds to the concatenated dimension. For + // example, the first input should be at the + // left-most position, so it is expanded only at the right side + // with the expansion factor of + // (total_concatenated_domain_extent - + // extent_of_the_input_tensor). Similarly, the second tensor + // is expanded by extent_of_the_input_tensor at its left side, + // and by (total_concatenated_domain_extent - + // extent_of_the_input_tensor - extent_of_the_second_tensor). + // + // TODO: what to do if inp_id is not a normal iterdomain, i.e., + // broadcast, partial, etc? For now, assume it's a normal + // IterDomain. + TORCH_INTERNAL_ASSERT( + inp_root_id->getIterType() == IterType::Iteration && + !inp_root_id->maybePartial(), + "Unsupported IterDomain to concatenate: ", + inp_root_id->toString()); + // The right pad of the last tensor is just zero + right_pad = input_idx < inputs.size() - 1 + ? sub(right_pad, inp_root_id->getMaybeExpandedExtent()) + : FusionGuard::getCurFusion()->zeroVal(); + left_pad_i = left_pad; + right_pad_i = right_pad; + left_pad = add(left_pad, inp_root_id->extent()); + } + // The pad width argument to pad should be ordered such that the + // widths of inner dimensions come first. + pad_widths.at((ndims - dim - 1) * 2) = left_pad_i; + pad_widths.at((ndims - dim - 1) * 2 + 1) = right_pad_i; + } + + resized_inputs.at(input_idx) = pad(inputs.at(input_idx), pad_widths); + } + + // Now all of resized_inputs have the same shape as the out tensor + auto out = ops::newOutputTV(resized_inputs, dtype); + + IrBuilder::create(out, resized_inputs, cat_dim); + + return out; +} + +// Currently there's no error check about the actual values of the +// Slice parameters. For example, the start parameter of a range of a +// domain is assumed to be >= 0 and < the extent of the domain. +TensorView* slice(TensorView* inp, const std::vector& ranges) { + const auto inp_dom = TensorDomain::noReductions(inp->getMaybeRFactorDomain()); + const int ndims = static_cast(inp_dom.size()); + + TORCH_CHECK( + ndims == static_cast(ranges.size()), + "The range vector must have the same number of Slice descriptors. 
Given: ", + ranges.size(), + ", Expected: ", + ndims); + + auto normalize_slice_range = [](Slice range, Val* extent) -> Slice { + if (range.start == nullptr) { + range.start = FusionGuard::getCurFusion()->zeroVal(); + } + if (range.stop == nullptr) { + range.stop = extent; + } + if (range.step == nullptr) { + range.step = FusionGuard::getCurFusion()->oneVal(); + } + return range; + }; + + for (auto& range : ranges) { + // Step not supported yet + TORCH_CHECK( + range.step == nullptr || range.step->isOneInt(), + "Unsupported step: ", + range.step->toString()); + } + + std::vector root_ids(ndims); + std::vector rfactor_ids(ndims); + std::vector normalized_ranges(ndims); + + bool needs_real_slicing = false; + for (const auto idx : c10::irange(ndims)) { + auto inp_root_id = inp_dom[idx]; + auto range = normalize_slice_range(ranges.at(idx), inp_root_id->extent()); + normalized_ranges.at(idx) = range; + IterDomain* out_root_id = nullptr; + IterDomain* out_rf_id = nullptr; + if (range.start->isZeroInt() && range.stop->sameAs(inp_root_id->extent()) && + range.step->isOneInt()) { + // This dim doesn't need slicing + out_root_id = inp_root_id->cloneWithoutRFactor(); + out_rf_id = out_root_id; + } else { + out_root_id = + IterDomainBuilder(inp_root_id).is_rfactor_domain(true).build(); + out_rf_id = IterDomain::resize( + out_root_id, + IrBuilder::negExpr(range.start), + sub(range.stop, inp_root_id->extent()), + true); + needs_real_slicing = true; + } + root_ids.at(idx) = out_root_id; + rfactor_ids.at(idx) = out_rf_id; + } + + // If slicing isn't actually needed, just return a copy + if (!needs_real_slicing) { + return set(inp); + } + + auto out = IrBuilder::create( + IrBuilder::create( + root_ids, + rfactor_ids, + rfactor_ids, + TensorDomain::getContiguityFilledWith(rfactor_ids, true)), + *inp->getDataType()); + + IrBuilder::create(out, inp, normalized_ranges); + return out; +} + } // namespace nvfuser diff --git a/csrc/ops/alias.h b/csrc/ops/alias.h index 9adfde9fe99..f0d84faa8bc 100644 --- a/csrc/ops/alias.h +++ b/csrc/ops/alias.h @@ -73,4 +73,23 @@ TORCH_CUDA_CU_API TensorView* transpose( //! Transpose a 2D tensor. TORCH_CUDA_CU_API TensorView* transpose(TensorView* x); +//! Pad a tensor by given widths of zero. Similar to torch.pad, the +//! pad_widths vector specifies the padding widths of the innermost N +//! dimensions, where N is half the size of the width vector. Padding +//! is always done just by zero. TODO: Support other padding types +TORCH_CUDA_CU_API TensorView* pad( + TensorView* x, + const std::vector& pad_widths); + +//! Concatenate tensors in the given dimension +TORCH_CUDA_CU_API TensorView* cat( + const std::vector& inputs, + int dim); + +//! Return a tensor where each dimension is sliced as specified by the +//! ranges parameter. Stepping must be one at this moment. 
+TORCH_CUDA_CU_API TensorView* slice( + TensorView* inp, + const std::vector& ranges); + } // namespace nvfuser diff --git a/csrc/root_domain_map.h b/csrc/root_domain_map.h index ca04ca664b7..91fa35a18f8 100644 --- a/csrc/root_domain_map.h +++ b/csrc/root_domain_map.h @@ -471,6 +471,20 @@ class TORCH_CUDA_CU_API ComputeAtRootDomainMapBuilder void handle(GatherOp* op) override; + void handle(PadOp* op) override { + // For compute-at, padded id should be mapped + mapPointwiseOrReductionOp(op); + } + + void handle(SliceOp* op) override { + mapPointwiseOrReductionOp(op); + } + + void handle(CatOp* op) override { + // For compute-at, concat id should be mapped + mapPointwiseOrReductionOp(op); + } + void handle(TensorView* tv) override; //! Maps all pending mappings. diff --git a/csrc/scheduler/normalization.cpp b/csrc/scheduler/normalization.cpp index 70cc0e19254..db0c5190283 100644 --- a/csrc/scheduler/normalization.cpp +++ b/csrc/scheduler/normalization.cpp @@ -1157,6 +1157,8 @@ void schedulePersistentKernel(Fusion* fusion, const ReductionParams& rparams) { // fusion segmentation scheduler_utils::clearMemorySpace(fusion); + scheduler_utils::prepareForMemoryTypePromotion(fusion); + auto reduction_tvs = scheduler_utils::getReductionTvs(fusion); TORCH_INTERNAL_ASSERT(reduction_tvs.size()); @@ -1226,6 +1228,8 @@ void schedulePersistentKernel(Fusion* fusion, const ReductionParams& rparams) { persistent_buffer->computeWith(-1, true); } } + + scheduler_utils::promoteProducerMemoryTypesOfResizedTensors(fusion); } } // namespace nvfuser diff --git a/csrc/scheduler/pointwise.cpp b/csrc/scheduler/pointwise.cpp index b999135ba69..e935f1b3388 100644 --- a/csrc/scheduler/pointwise.cpp +++ b/csrc/scheduler/pointwise.cpp @@ -448,6 +448,10 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) { // Cache and fork outputs auto cached_outputs = scheduler_utils::cacheAndForkOutputs(fusion, true); + // Create a cache for a tensor if it may need to be placed on a + // farther but shared memory space + scheduler_utils::prepareForMemoryTypePromotion(fusion); + std::vector input_tvs; { auto filtered_tvs = ir_utils::filterByType(fusion->inputs()); @@ -800,6 +804,8 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) { inner_most_tensors.erase(output); } inlineMost(inner_most_tensors); + + scheduler_utils::promoteProducerMemoryTypesOfResizedTensors(fusion); } } // namespace nvfuser diff --git a/csrc/scheduler/pointwise_utils.cpp b/csrc/scheduler/pointwise_utils.cpp index eeaef360671..829b3d4fba0 100644 --- a/csrc/scheduler/pointwise_utils.cpp +++ b/csrc/scheduler/pointwise_utils.cpp @@ -12,7 +12,7 @@ namespace nvfuser { namespace pointwise_utils { DomainMap::DomainMap(Fusion* fusion) : fusion_(fusion), ca_map_(fusion) { - view_tvs_ = scheduler_utils::getViewTVs(fusion); + tvs_with_rfactor_ = scheduler_utils::getTVsWithNonReductionRFactor(fusion); for (auto select : ir_utils::getSelectOps(fusion)) { select_ids_.emplace(select->getSelectAxis()); } @@ -49,9 +49,10 @@ bool DomainMap::areAllInputIdsMappedTo(TensorView* input_tv, TensorView* tv) // Ignore unresolved broadcast dimensions for (auto id : tv->getMaybeRFactorDomain()) { if (!eraseIfMapped(in_concrete_ids, id)) { - eraseIfInputMappedThroughViewTo(in_concrete_ids, id); + eraseIfInputMappedThroughRFactorDomain(in_concrete_ids, id); } } + return in_concrete_ids.empty(); } @@ -69,14 +70,14 @@ bool DomainMap::eraseIfMapped( return found_match; } -// Check if in_id is mapped to out_id through any view rfactor domain. 
+// Check if in_id is mapped to out_id through any rfactor domain. // Currently this function only allow having one view on the path from input to // output. If there are multiple views, then likely the pointwise scheduler will // reject the fusion because we can not correctly find a reference tensor. -void DomainMap::eraseIfInputMappedThroughViewTo( +void DomainMap::eraseIfInputMappedThroughRFactorDomain( std::unordered_set& in_concrete_ids, IterDomain* id) const { - for (auto view : view_tvs_) { + for (auto view : tvs_with_rfactor_) { // Find any ID in view rfactor domain that is mapped to output ID auto view_rfactor_id = anyMapped(view->getRFactorDomain(), id); if (view_rfactor_id == nullptr) { diff --git a/csrc/scheduler/pointwise_utils.h b/csrc/scheduler/pointwise_utils.h index 58eb7b848ab..37a97355ca1 100644 --- a/csrc/scheduler/pointwise_utils.h +++ b/csrc/scheduler/pointwise_utils.h @@ -45,8 +45,8 @@ class DomainMap { std::unordered_set& in_concrete_ids, IterDomain* out_id) const; - // Check if in_id is mapped to id through any view rfactor domain - void eraseIfInputMappedThroughViewTo( + // Check if in_id is mapped to id through any rfactor domain + void eraseIfInputMappedThroughRFactorDomain( std::unordered_set& in_concrete_ids, IterDomain* id) const; @@ -57,7 +57,7 @@ class DomainMap { Fusion* fusion_ = nullptr; ComputeAtMap ca_map_; - std::vector view_tvs_; + std::vector tvs_with_rfactor_; std::unordered_set select_ids_; }; diff --git a/csrc/scheduler/reduction.cpp b/csrc/scheduler/reduction.cpp index 54d0afa71ba..09db4857c60 100644 --- a/csrc/scheduler/reduction.cpp +++ b/csrc/scheduler/reduction.cpp @@ -995,6 +995,8 @@ void scheduleReduction(Fusion* fusion, const ReductionParams& rparams) { // fusion segmentation scheduler_utils::clearMemorySpace(fusion); + scheduler_utils::prepareForMemoryTypePromotion(fusion); + auto reduction_tvs = scheduler_utils::getReductionTvs(fusion); TORCH_INTERNAL_ASSERT(reduction_tvs.size()); @@ -1046,6 +1048,8 @@ void scheduleReduction(Fusion* fusion, const ReductionParams& rparams) { reduction_tvs, cached_inputs, cached_outputs); + + scheduler_utils::promoteProducerMemoryTypesOfResizedTensors(fusion); } } // namespace nvfuser diff --git a/csrc/scheduler/registry.cpp b/csrc/scheduler/registry.cpp index 9ccfa24981e..cd32c0d4e41 100644 --- a/csrc/scheduler/registry.cpp +++ b/csrc/scheduler/registry.cpp @@ -681,9 +681,9 @@ bool reductionInterferingView( // Make sure groups are disjoint based on view - auto disjoint_view_sets = scheduler_utils::disjointViewSets(fusion); - auto disjoint_set_information = scheduler_utils::getDisjointViewSetsOf( - fusion, reduction_reference, disjoint_view_sets); + auto disjoint_rfactor_sets = scheduler_utils::disjointRFactorSets(fusion); + auto disjoint_set_information = scheduler_utils::getDisjointRFactorSetsOf( + fusion, reduction_reference, disjoint_rfactor_sets); // Convert id's in groups to disjoint_set_ids of disjoint_set_information std::vector> disjoint_groups; diff --git a/csrc/scheduler/utils.cpp b/csrc/scheduler/utils.cpp index 9ead467cc2c..6cb7670b69a 100644 --- a/csrc/scheduler/utils.cpp +++ b/csrc/scheduler/utils.cpp @@ -15,6 +15,7 @@ #include #include #include +#include #include #include @@ -253,7 +254,8 @@ void parallelizeAllLike( const auto& reference_dom = reference_tv->domain()->domain(); for (auto it = reference_dom.begin(); it != reference_dom.begin() + pos; it++) { - auto ca_id = ca_map.getConcreteMappedID(*it, IdMappingMode::PERMISSIVE); + auto ca_id = + ca_map.getConcreteMappedID(*it, 
IdMappingMode::PERMISSIVE_RESIZE); concrete_to_reference_map[ca_id] = *it; } @@ -265,8 +267,8 @@ void parallelizeAllLike( continue; } for (const auto i : c10::irange(tv->domain()->domain().size())) { - auto ca_id = - ca_map.getConcreteMappedID(tv->axis(i), IdMappingMode::PERMISSIVE); + auto ca_id = ca_map.getConcreteMappedID( + tv->axis(i), IdMappingMode::PERMISSIVE_RESIZE); if (concrete_to_reference_map.count(ca_id) > 0) { auto reference_id = concrete_to_reference_map.at(ca_id); auto reference_parallel_type = reference_id->getParallelType(); @@ -953,6 +955,25 @@ std::vector getViewTVs(Fusion* fusion) { return view_tvs; } +std::vector getTVsWithNonReductionRFactor(Fusion* fusion) { + std::vector tvs_with_rfactor; + auto fusion_vals = fusion->usedMathVals(); + std::copy_if( + ir_utils::filterByType(fusion_vals).begin(), + ir_utils::filterByType(fusion_vals).end(), + std::back_inserter(tvs_with_rfactor), + [](TensorView* tv) { + return tv->hasRFactor() && + std::none_of( + tv->getMaybeRFactorDomain().begin(), + tv->getMaybeRFactorDomain().end(), + [](auto id) { + return id->isReduction() && id->isRFactorProduct(); + }); + }); + return tvs_with_rfactor; +} + // Reset inputs and outputs to global memory, everything else to local. void clearMemorySpace(Fusion* fusion) { for (auto tv : ir_utils::allTvs(fusion)) { @@ -1065,6 +1086,11 @@ IterDomain* projectIdToRoot( projected_id = split->in(); } } + } else if (expr->isA()) { + auto resize = expr->as(); + if (resize->out() == projected_id) { + projected_id = resize->in(); + } } else { TORCH_INTERNAL_ASSERT( false, "Didn't recognize the iterdomain expression: ", expr); @@ -1119,6 +1145,11 @@ IterDomain* projectIdToRFactor( if (split->in() == projected_id) { projected_id = split->inner(); } + } else if (expr->isA()) { + auto resize = expr->as(); + if (resize->in() == projected_id) { + projected_id = resize->out(); + } } else { TORCH_INTERNAL_ASSERT( false, "Didn't recognize the iterdomain expression: ", expr); @@ -1338,21 +1369,10 @@ std::vector getInputsOutputsWithInnerDim( return vectorizable_tensors; } -// Returns disjoint view sets mapped onto the given reference. Returns a pair -// of vectors of size rfactorDomain of reference. Vector of -// VectorOfUniqueEntries returns a const* to the disjoint set in -// disjoint_view_set the iterdomain is mapped to. Integer vector represents -// which disjoint view group the rfactor id belongs to. It's straight forward -// to map from the former to the latter, but not the latter to former. -// -// Since we return a const* to entries in disjoint_view_set, it must be passed -// in as a reference. Algorithm is N^2 based on number of dims in reference, -// but generating the disjoint view set is likely the limiter on perf of this -// function. 
-DisjointViewSetInfo getDisjointViewSetsOf( +DisjointRFactorSetInfo getDisjointRFactorSetsOf( Fusion* fusion, TensorView* of, - DisjointSets& disjoint_view_set) { + DisjointSets& disjoint_rfactor_set) { auto rfactor_dom = of->getMaybeRFactorDomain(); if (rfactor_dom.size() == 0) { return {}; @@ -1376,12 +1396,12 @@ DisjointViewSetInfo getDisjointViewSetsOf( } const auto& ref_group = - disjoint_view_set.getDisjointSetOf(rfactor_dom[ref_dim_i]); + disjoint_rfactor_set.getDisjointSetOf(rfactor_dom[ref_dim_i]); int other_dim_i = ref_dim_i; while (other_dim_i >= 0) { const auto& other_group = - disjoint_view_set.getDisjointSetOf(rfactor_dom[other_dim_i]); + disjoint_rfactor_set.getDisjointSetOf(rfactor_dom[other_dim_i]); if (&ref_group == &other_group) { disjoint_group_ids[other_dim_i] = current_group_id; disjoint_set_of_id[other_dim_i] = &ref_group; @@ -1398,7 +1418,7 @@ DisjointViewSetInfo getDisjointViewSetsOf( disjoint_group_ids.begin(), disjoint_group_ids.end(), [](int i) { return i == -1; }), - "Failed to generate the view disjoint groups of the reference ", + "Failed to generate the rfactor disjoint groups of the reference ", of->toString()); TORCH_INTERNAL_ASSERT( @@ -1408,10 +1428,10 @@ DisjointViewSetInfo getDisjointViewSetsOf( [](const VectorOfUniqueEntries* ptr) { return ptr == nullptr; }), - "Failed to generate the view disjoint groups of the reference ", + "Failed to generate the rfactor disjoint groups of the reference ", of->toString()); - DisjointViewSetInfo info; + DisjointRFactorSetInfo info; info.disjoint_sets_of_ref = disjoint_set_of_id; info.disjoint_set_ids = disjoint_group_ids; info.ref = of; @@ -1432,9 +1452,9 @@ BroadcastMultipleInformation getBroadcastMultiples( std::vector multiples(ref_root_domain.size()); - auto disjoint_view_sets = disjointViewSets(fusion); - auto disjoint_set_information = scheduler_utils::getDisjointViewSetsOf( - fusion, reference_tv, disjoint_view_sets); + auto disjoint_rfactor_sets = disjointRFactorSets(fusion); + auto disjoint_set_information = scheduler_utils::getDisjointRFactorSetsOf( + fusion, reference_tv, disjoint_rfactor_sets); auto ref_disjoint_sets = disjoint_set_information.disjoint_sets_of_ref; auto ref_disjoint_set_ids = disjoint_set_information.disjoint_set_ids; @@ -2100,10 +2120,10 @@ void BoundedDirectionalTransformPropagator::bothWays( propagate(from, pos, included_tvs, *options); } -DisjointSets disjointViewSets(Fusion* fusion) { +DisjointSets disjointRFactorSets(Fusion* fusion) { // Start from the exact iter domain graph of the fusion IterDomainGraph id_graph(fusion); - auto disjoint_view_ids = id_graph.exactNodes(); + auto disjoint_rfactor_ids = id_graph.exactNodes(); // If iter domains are involved in any transformation from root domains to // rfactor domains they should be considered "contaminated". 
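The hunk below extends the same treatment to `Resize`. A toy union-find sketch of the grouping rules, with string names standing in for `IterDomain*` (not the real `DisjointSets` class):

```cpp
#include <string>
#include <unordered_map>

// Minimal union-find keyed by ID names.
struct Sets {
  std::unordered_map<std::string, std::string> parent;
  std::string find(std::string x) {
    parent.emplace(x, x);
    while (parent[x] != x) {
      x = parent[x];
    }
    return x;
  }
  void map(const std::string& a, const std::string& b) {
    auto ra = find(a);
    auto rb = find(b);
    parent[ra] = rb;
  }
};

struct Merge { std::string outer, inner, out; };
struct Split { std::string in, outer, inner; };
struct Resize { std::string in, out; };

// Contamination rules: every ID touched by a root->rfactor transform is
// grouped with the IDs it was produced from or produces.
void mapThrough(Sets& s, const Merge& m)  { s.map(m.inner, m.out); s.map(m.outer, m.out); }
void mapThrough(Sets& s, const Split& sp) { s.map(sp.in, sp.inner); s.map(sp.in, sp.outer); }
void mapThrough(Sets& s, const Resize& r) { s.map(r.in, r.out); }  // added for resize

int main() {
  Sets s;
  // A pad/slice produces i0 -> resize -> i0r, and i0r is then merged with i1.
  mapThrough(s, Resize{"i0", "i0r"});
  mapThrough(s, Merge{"i0r", "i1", "i2"});
  // All four IDs now belong to a single disjoint rfactor set.
  return s.find("i0") == s.find("i2") ? 0 : 1;
}
```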
@@ -2114,19 +2134,22 @@ DisjointSets disjointViewSets(Fusion* fusion) { tv->getMaybeRFactorDomain().end()})) { if (expr->isA()) { auto merge = expr->as(); - disjoint_view_ids.mapEntries(merge->inner(), merge->out()); - disjoint_view_ids.mapEntries(merge->outer(), merge->out()); + disjoint_rfactor_ids.mapEntries(merge->inner(), merge->out()); + disjoint_rfactor_ids.mapEntries(merge->outer(), merge->out()); } else if (expr->isA()) { auto split = expr->as(); - disjoint_view_ids.mapEntries(split->in(), split->inner()); - disjoint_view_ids.mapEntries(split->in(), split->outer()); + disjoint_rfactor_ids.mapEntries(split->in(), split->inner()); + disjoint_rfactor_ids.mapEntries(split->in(), split->outer()); + } else if (expr->isA()) { + auto resize = expr->as(); + disjoint_rfactor_ids.mapEntries(resize->in(), resize->out()); } else { TORCH_INTERNAL_ASSERT( false, "Expression type: ", expr->toString(), " not supported."); } } } - return disjoint_view_ids; + return disjoint_rfactor_ids; } bool breakIsDisjoint(std::vector group_ids, int pos) { @@ -2203,6 +2226,16 @@ std::unordered_map domainReorderAsRfactorMap(TensorView* tv) { reordered_ids.erase(reordered_ids.begin() + pos0); reordered_ids[--pos1] = merge->out(); + } else if (const Resize* resize = dynamic_cast(expr)) { + auto find_it = + std::find(reordered_ids.begin(), reordered_ids.end(), resize->in()); + if (find_it == reordered_ids.end()) { + // Transformations before rfactor, ignore those. + continue; + } + *find_it = resize->out(); + } else { + TORCH_INTERNAL_ASSERT(false, "Unexpected expression: ", expr->toString()); } } @@ -2321,6 +2354,124 @@ bool isFastestDimReduction(TensorView* tv) { return false; } +namespace { + +std::vector getResizedTensors(Fusion* fusion) { + std::vector resized_tensors; + + auto fusion_vals = fusion->usedMathVals(); + for (auto tv : ir_utils::filterByType(fusion_vals)) { + if (ir_utils::hasResizedRfactor(tv)) { + resized_tensors.push_back(tv); + } + } + + return resized_tensors; +} + +} // namespace + +void prepareForMemoryTypePromotion(Fusion* fusion) { + auto resized_tensors = getResizedTensors(fusion); + std::unordered_set cached; + for (auto resized_tensor : resized_tensors) { + for (auto producer : ir_utils::producerTvsOf(resized_tensor)) { + if (cached.count(producer) != 0) { + continue; + } + producer->cacheAfter(); + cached.insert(producer); + } + } +} + +void promoteProducerMemoryTypesOfResizedTensors(Fusion* fusion) { + auto resized_tensors = getResizedTensors(fusion); + + // Just make it simpler to promote memory types. Minimum is + // preferred. Increased as required. 
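A condensed sketch of that ordering and of the promotion decision used further down (a thread-parallel mismatch requires shared memory, a block-parallel mismatch requires global memory); simplified enums only, not the nvfuser types:

```cpp
#include <algorithm>
#include <cassert>

// Promotion lattice: Local < Shared < Global. Promotion only moves upward.
enum class MemoryType { Local = 1, Shared = 2, Global = 3 };

enum class ParallelKind { Serial, ThreadDim, BlockDim };

// Memory space needed so a consumer parallelized differently from the
// producer on some non-inlined axis can still read the producer's data.
MemoryType requiredFor(ParallelKind producer_axis) {
  switch (producer_axis) {
    case ParallelKind::ThreadDim:
      return MemoryType::Shared;  // cross-thread dependency
    case ParallelKind::BlockDim:
      return MemoryType::Global;  // cross-block dependency
    default:
      return MemoryType::Local;
  }
}

MemoryType promote(MemoryType current, MemoryType required) {
  return std::max(current, required);
}

int main() {
  MemoryType mt = MemoryType::Local;
  mt = promote(mt, requiredFor(ParallelKind::ThreadDim));
  assert(mt == MemoryType::Shared);
  // A later block-parallel mismatch promotes further; nothing ever demotes.
  mt = promote(mt, requiredFor(ParallelKind::BlockDim));
  assert(mt == MemoryType::Global);
  return 0;
}
```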
+ auto memoryTypeToInt = [](MemoryType mt1) -> int { + switch (mt1) { + case MemoryType::Local: + return 1; + case MemoryType::Shared: + return 2; + case MemoryType::Global: + return 3; + default: + TORCH_INTERNAL_ASSERT(false); + } + }; + + std::unordered_map tvs_to_promote; + + auto setPromotion = [&](TensorView* tv, MemoryType m_type) { + // Initialize the memory type with the current type + tvs_to_promote.emplace(tv, tv->getMemoryType()); + + if (memoryTypeToInt(m_type) > memoryTypeToInt(tvs_to_promote.at(tv))) { + tvs_to_promote[tv] = m_type; + } + }; + + for (auto resized_tensor : resized_tensors) { + for (auto producer : ir_utils::producerTvsOf(resized_tensor)) { + auto c2p_map = BestEffortReplay( + producer->domain()->domain(), + resized_tensor->domain()->domain(), + PairwiseRootDomainMap(producer, resized_tensor, true) + .mapConsumerToProducer( + resized_tensor->domain(), producer->domain())) + .getReplay(); + + for (const auto i : + c10::irange(producer->nDims() - producer->getComputeAtPosition())) { + auto producer_non_ca_id = + producer->axis(i + producer->getComputeAtPosition()); + auto producer_non_ca_id_ptype = producer_non_ca_id->getParallelType(); + if (!isParallelTypeThread(producer_non_ca_id_ptype)) { + continue; + } + + auto resized_tensor_exact_map_id_it = std::find_if( + resized_tensor->domain()->domain().begin(), + resized_tensor->domain()->domain().end(), + [&](IterDomain* resized_tensor_leaf_id) { + auto it = c2p_map.find(resized_tensor_leaf_id); + return it != c2p_map.end() && it->second == producer_non_ca_id; + }); + if (resized_tensor_exact_map_id_it != + resized_tensor->domain()->domain().end() && + (*resized_tensor_exact_map_id_it)->getParallelType() == + producer_non_ca_id_ptype) { + continue; + } + + // Promotion required + if (isParallelTypeThreadDim(producer_non_ca_id_ptype)) { + setPromotion(producer, MemoryType::Shared); + } else if (isParallelTypeBlockDim(producer_non_ca_id_ptype)) { + setPromotion(producer, MemoryType::Global); + } + } + } + } + + // Iterate over resized_tensors so that promotion is done in a + // deterministic order + for (auto resized_tensor : resized_tensors) { + for (auto producer : ir_utils::producerTvsOf(resized_tensor)) { + auto it = tvs_to_promote.find(producer); + if (it == tvs_to_promote.end() || + it->second == producer->getMemoryType()) { + continue; + } + auto new_mem_type = it->second; + producer->setMemoryType(new_mem_type); + } + } +} + } // namespace scheduler_utils } // namespace nvfuser diff --git a/csrc/scheduler/utils.h b/csrc/scheduler/utils.h index f3bad044884..e158ab511db 100644 --- a/csrc/scheduler/utils.h +++ b/csrc/scheduler/utils.h @@ -210,6 +210,9 @@ TORCH_CUDA_CU_API std::vector getReductionTvs(Fusion* fusion); // Returns a list of TensorViews that are the consumer tv for a view operation. std::vector getViewTVs(Fusion* fusion); +// Returns a list of non-reduction TensorViews that have a rfactor domain +std::vector getTVsWithNonReductionRFactor(Fusion* fusion); + // Reset inputs and outputs to global memory, everything else to local. void clearMemorySpace(Fusion* fusion); @@ -283,12 +286,13 @@ std::vector getInputsOutputsWithInnerDim( bool vectorize_pass); // Holder return struct for the below function. -struct DisjointViewSetInfo { - // const* to the disjoint set in disjoint_view_set passed in to - // getDisjointViewSetsOf each iterdomain in the rfactor of ref is mapped to. 
+struct DisjointRFactorSetInfo { + // const* to the disjoint set in disjoint_rfactor_set passed in to + // getDisjointRFactorSetsOf each iterdomain in the rfactor of ref is mapped + // to. // - // WARNING: these pointers are relative to the disjoint_view_set reference - // passed into getDisjointViewSetsOf it's the user's responsibility to + // WARNING: these pointers are relative to the disjoint_rfactor_set reference + // passed into getDisjointRFactorSetsOf it's the user's responsibility to // maintain the lifetime of that reference to match this vector. std::vector*> disjoint_sets_of_ref; @@ -301,21 +305,21 @@ struct DisjointViewSetInfo { TensorView* ref; }; -// Returns disjoint view sets mapped onto the given reference. Returns a pair +// Returns disjoint rfactor sets mapped onto the given reference. Returns a pair // of vectors of size rfactorDomain of reference. Vector of // VectorOfUniqueEntries returns a const* to the disjoint set in -// disjoint_view_set the iterdomain is mapped to. Integer vector represents -// which disjoint view group the rfactor id belongs to. It's straight forward +// disjoint_rfactor_set the iterdomain is mapped to. Integer vector represents +// which disjoint rfactor group the rfactor id belongs to. It's straightforward // to map from the former to the latter, but not the latter to former. // -// Since we return a const* to entries in disjoint_view_set, it must be passed -// in as a reference. Algorithm is N^2 based on number of dims in reference, -// but generating the disjoint view set is likely the limiter on perf of this -// function. -DisjointViewSetInfo getDisjointViewSetsOf( +// Since we return a const* to entries in disjoint_rfactor_set, it must be +// passed in as a reference. Algorithm is N^2 based on number of dims in +// reference, but generating the disjoint rfactor set is likely the limiter on +// perf of this function. +DisjointRFactorSetInfo getDisjointRFactorSetsOf( Fusion* fusion, TensorView* of, - DisjointSets& disjoint_view_set); + DisjointSets& disjoint_rfactor_set); // Structure to hold byte multiples for break points. I.e. if we have the // tensors: @@ -532,7 +536,7 @@ struct TORCH_CUDA_CU_API BoundedDirectionalTransformPropagator { // If IterDomains are disjoint in the returned set, then they are considered // "separable". // Warning: This pass generates the IdGraphs, not intended for use at runtime. -TORCH_CUDA_CU_API DisjointSets disjointViewSets(Fusion* fusion); +TORCH_CUDA_CU_API DisjointSets disjointRFactorSets(Fusion* fusion); // Makes sure that there are no group id's left of pos that match right of pos. // e.g. @@ -564,5 +568,19 @@ inline void rotateLoop( loop_tv->fusion()->rotateLoop(loop_tv, axis, std::move(selection)); } +//! Certain tensors may need to be placed on shared or global memory +//! due to data dependencies caused by resize operations. Create +//! caches of those tensors so that original operations producing +//! them should keep using the same memory. This avoids, for example, +//! reductions to global memory. +TORCH_CUDA_CU_API void prepareForMemoryTypePromotion(Fusion* fusion); + +//! If a resized tensor induces a data dependency between threads, +//! move its producer to a shared memory that is sufficient to satisfy +//! the dependency. A proper RAW sync will be automatically inserted +//! when the fusion is lowered. 
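As a toy illustration of the dependency being described: after a pad, a consumer element reads a producer element at a shifted position, so if the axis is thread-parallelized on both tensors, a thread needs data produced by a different thread (plain integers below; the real analysis is the leaf-domain mapping in the implementation above):

```cpp
#include <cassert>

// After out = pad(in, {left, right}), out[i] reads in[i - left] (when in
// range). With the axis bound to threadIdx.x on both tensors, thread i of
// the consumer needs the element written by thread i - left of the producer,
// so the producer cannot stay in registers; it is promoted to shared memory
// and a RAW sync is inserted at lowering.
int producerIndexAfterPad(int consumer_index, int left_pad) {
  return consumer_index - left_pad;
}

int main() {
  const int left_pad = 2;
  const int consumer_thread = 5;
  // The data comes from a different thread than the one consuming it.
  assert(producerIndexAfterPad(consumer_thread, left_pad) == 3);
  return 0;
}
```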
+TORCH_CUDA_CU_API void promoteProducerMemoryTypesOfResizedTensors( + Fusion* fusion); + } // namespace scheduler_utils } // namespace nvfuser diff --git a/csrc/scheduler/vectorize_helper.cpp b/csrc/scheduler/vectorize_helper.cpp index fc41609684b..460a994eb97 100644 --- a/csrc/scheduler/vectorize_helper.cpp +++ b/csrc/scheduler/vectorize_helper.cpp @@ -579,6 +579,12 @@ std::vector ContiguousInnerDimensionsMapper::projectIdToRoot( ids.insert(ids.begin() + out_pos + 1, merge->inner()); propagateExtentMergeBackward(merge); + } else if (const Resize* resize = dynamic_cast(expr)) { + // Cannot vectorize through resize + auto find_out_it = std::find(ids.begin(), ids.end(), resize->out()); + if (find_out_it != ids.end()) { + ids.erase(ids.begin(), find_out_it + 1); + } } else { // TODO: I wonder if we should just remove all inputs instead of erroring. // Seems that would be safe. @@ -731,6 +737,12 @@ std::vector ContiguousInnerDimensionsMapper::projectIdToRFactor( } propagateExtentSplitForward(split); + } else if (const Resize* resize = dynamic_cast(expr)) { + // Cannot vectorize through resize + auto find_in_it = std::find(ids.begin(), ids.end(), resize->in()); + if (find_in_it != ids.end()) { + ids.erase(ids.begin(), find_in_it + 1); + } } else { // TODO: I wonder if we should just remove all inputs instead of erroring. // Seems that would be safe. diff --git a/csrc/transform_iter.cpp b/csrc/transform_iter.cpp index 728f5f9aa03..ce708f76edc 100644 --- a/csrc/transform_iter.cpp +++ b/csrc/transform_iter.cpp @@ -16,7 +16,7 @@ namespace nvfuser { // Transform dispatch void ReplayTransformations::handle(Expr* e) { - auto is_supported_expr = e->isOneOf(); + auto is_supported_expr = e->isOneOf(); TORCH_INTERNAL_ASSERT( is_supported_expr, "Invalid expr type found in transform traversal."); IterVisitor::handle(e); @@ -39,7 +39,7 @@ void ReplayTransformations::handle(Split* s) { } } - auto mapped = (*it).second; + auto mapped = it->second; // Make sure this ID is a leaf ID (meaning it has no uses we generated) TORCH_INTERNAL_ASSERT( leaf_ids_.find(mapped) != leaf_ids_.end(), @@ -150,8 +150,8 @@ void ReplayTransformations::handle(Swizzle2D* swizzle_2d) { } } - auto mapped_x = (*it_x).second; - auto mapped_y = (*it_y).second; + auto mapped_x = it_x->second; + auto mapped_y = it_y->second; // Make sure this ID is a leaf ID (meaning it has no uses we generated) TORCH_INTERNAL_ASSERT( @@ -179,6 +179,38 @@ void ReplayTransformations::handle(Swizzle2D* swizzle_2d) { id_map_[swizzle_2d->outY()] = outs.second; } +void ReplayTransformations::handle(Resize* exp) { + auto id_in = exp->in(); + + auto it = id_map_.find(id_in); + if (it == id_map_.end()) { + if (error_on_failure_) { + TORCH_INTERNAL_ASSERT( + false, "Transform traversal failed, dependencies not met."); + } else { + return; + } + } + + auto mapped = it->second; + // Make sure this ID is a leaf ID (meaning it has no uses we generated) + TORCH_INTERNAL_ASSERT( + leaf_ids_.find(mapped) != leaf_ids_.end(), + "Transform traversal failed, modified a node but it was not a leaf node."); + + auto out = mapped; + + if (replay_resize_) { + out = IterDomain::resize(mapped, exp->leftExpand(), exp->rightExpand()); + } + + leaf_ids_.erase(mapped); + + leaf_ids_[out] = newCounter(); + + id_map_[exp->out()] = out; +} + ReplayTransformations::ReplayTransformations( const std::vector& target_domain, std::unordered_map id_map) @@ -243,7 +275,7 @@ void ReplayTransformations::runReplay() { continue; } - auto id_replayed = (*it_replayed).second; + auto id_replayed = 
it_replayed->second; auto it_leaf = leaf_ids_.find(id_replayed); TORCH_INTERNAL_ASSERT( it_leaf != leaf_ids_.end(), @@ -274,7 +306,8 @@ BestEffortReplay::BestEffortReplay( std::unordered_map replay_forward_id_map, std::unordered_map target_forward_id_map, bool skip_replay_swizzle, - bool skip_target_swizzle) + bool skip_target_swizzle, + bool skip_resize) : target2replay_id_map_(std::move(target2replay_map)), replay_forward_id_map_(std::move(replay_forward_id_map)), target_forward_id_map_(std::move(target_forward_id_map)), @@ -348,6 +381,10 @@ BestEffortReplay::BestEffortReplay( skipSwizzles(target_id2expr_map, replay_id2expr_map); } + if (skip_resize) { + skipResizes(); + } + std::string err_str( "Error during replay, a transformation was called that conflicts with an rfactor call."); @@ -495,7 +532,13 @@ BestEffortReplay::BestEffortReplay( // If there isn't an rfactor id in the replay's inputs and there's a // mismatch in replay_expr's and target_expr's outputs, continue if (target_expr->outputs().size() != replay_expr->outputs().size()) { - TORCH_INTERNAL_ASSERT(!replay_has_rfactor_inp, err_str); + TORCH_INTERNAL_ASSERT( + !replay_has_rfactor_inp, + err_str, + ". Target: ", + target_expr->toString(), + ", repaly: ", + replay_expr->toString()); continue; } @@ -533,6 +576,16 @@ BestEffortReplay::BestEffortReplay( } } + if (replay_expr->isA()) { + auto r_resize = replay_expr->as(); + auto t_resize = target_expr->as(); + if (!r_resize->leftExpand()->sameAs(t_resize->leftExpand()) || + !r_resize->rightExpand()->sameAs(t_resize->rightExpand())) { + TORCH_INTERNAL_ASSERT(!replay_has_rfactor_inp, err_str); + continue; + } + } + // Take replay expr inputs out of map: for (const auto t_i : c10::irange(target_id_inps.size())) { auto t_inp = target_id_inps[t_i]; @@ -567,6 +620,10 @@ BestEffortReplay::BestEffortReplay( // swizzles on the mapping. skipSwizzles(target_id2expr_map, replay_id2expr_map); } + + if (skip_resize) { + skipResizes(); + } } } @@ -678,11 +735,10 @@ struct ForwardingInfo { return; } - TORCH_INTERNAL_ASSERT(active_root_dom.size() == active_dim_flags->size()); - // Collect which root ids are only in active_tv but not in the inactive // tensor. std::unordered_set forwarded_ids; + TORCH_INTERNAL_ASSERT(active_root_dom.size() == active_dim_flags->size()); for (auto i : c10::irange(active_dim_flags->size())) { if (active_dim_flags->at(i)) { forwarded_ids.emplace(active_root_dom.at(i)); @@ -835,7 +891,8 @@ void BestEffortReplay::addComplimentLeafIDs( auto compliment_map_it = compliment_map.find(forwarded_id); TORCH_INTERNAL_ASSERT( compliment_map_it != compliment_map.end(), - "Issue tracking forwarded broadcast merges in best effort replay."); + "Issue tracking forwarded broadcast merges in best effort replay. 
", + forwarded_id->toString()); compliments.insert( compliments.end(), compliment_map_it->second.begin(), @@ -872,7 +929,8 @@ BestEffortReplay BestEffortReplay::replayCasP( int producer_compute_at_axis, const RootDomainMap& root_map, bool skip_consumer_swizzle, - bool skip_producer_swizzle) { + bool skip_producer_swizzle, + bool skip_resize) { if (producer_compute_at_axis < 0) producer_compute_at_axis += (int)producer->nDims() + 1; @@ -919,7 +977,8 @@ BestEffortReplay BestEffortReplay::replayCasP( forwarding_info.consumer_forwarding_map, forwarding_info.producer_forwarding_map, skip_consumer_swizzle, - skip_producer_swizzle); + skip_producer_swizzle, + skip_resize); consumer_replay.addComplimentLeafIDs( forwarding_info.consumer_forwarding_map, @@ -936,7 +995,8 @@ BestEffortReplay BestEffortReplay::replayPasC( int consumer_compute_at_axis, const RootDomainMap& root_map, bool skip_producer_swizzle, - bool skip_consumer_swizzle) { + bool skip_consumer_swizzle, + bool skip_resize) { if (consumer_compute_at_axis < 0) consumer_compute_at_axis += (int)consumer->nDims() + 1; TORCH_INTERNAL_ASSERT( @@ -975,7 +1035,8 @@ BestEffortReplay BestEffortReplay::replayPasC( forwarding_info.producer_forwarding_map, forwarding_info.consumer_forwarding_map, skip_producer_swizzle, - skip_consumer_swizzle); + skip_consumer_swizzle, + skip_resize); producer_replay.addComplimentLeafIDs( forwarding_info.producer_forwarding_map, @@ -1024,6 +1085,50 @@ void BestEffortReplay::skipSwizzles( } } +// Same logic as skipSwizzles +void BestEffortReplay::skipResizes() { + auto isResizeInput = [](IterDomain* id) -> bool { + return id->uses().size() == 1 && id->uses().front()->isA(); + }; + + bool updated = true; + + while (updated) { + updated = false; + for (auto it : target2replay_id_map_) { + auto target_id = it.first; + auto new_target_id = target_id; + auto replay_id = it.second; + auto new_replay_id = replay_id; + if (isResizeInput(target_id)) { + new_target_id = target_id->uses().front()->as()->out(); + } + if (isResizeInput(replay_id)) { + new_replay_id = replay_id->uses().front()->as()->out(); + } + + if (new_target_id == target_id && new_replay_id == replay_id) { + continue; + } + + target2replay_id_map_.erase(target_id); + TORCH_INTERNAL_ASSERT( + target2replay_id_map_ + .insert(std::make_pair(new_target_id, new_replay_id)) + .second, + "Unexpected replay leaf"); + // Progress the leaf ids if the replay is updated + if (replay_id != new_replay_id && + leaf_ids_.find(replay_id) != leaf_ids_.end()) { + leaf_ids_.erase(replay_id); + leaf_ids_[new_replay_id] = counter++; + } + updated = true; + break; + } + } +} + DisjointSets BestEffortReplay::getIterDomainEquivalence() { DisjointSets result; const std::unordered_map* maps[3] = { diff --git a/csrc/transform_iter.h b/csrc/transform_iter.h index 236262e41f3..076ece21f95 100644 --- a/csrc/transform_iter.h +++ b/csrc/transform_iter.h @@ -61,6 +61,11 @@ class TORCH_CUDA_CU_API ReplayTransformations : public IterVisitor { return *this; } + ReplayTransformations& setReplayResize(bool replay_resize) { + replay_resize_ = replay_resize; + return *this; + } + // Replays outputs that were generated from ids.first on ids.second void runReplay(); @@ -107,6 +112,8 @@ class TORCH_CUDA_CU_API ReplayTransformations : public IterVisitor { // if replaying swizzle is enabled. 
void handle(Swizzle2D* m) override; + void handle(Resize* resize) override; + size_t newCounter() { return counter_++; } @@ -132,6 +139,10 @@ class TORCH_CUDA_CU_API ReplayTransformations : public IterVisitor { // this functionality could be useful. bool replay_swizzle_ = false; + // Indicates if we want to replay resize ops on the replayed + // tensor. + bool replay_resize_ = false; + size_t counter_ = 0; std::vector leaf_vec_; @@ -294,7 +305,11 @@ class TORCH_CUDA_CU_API BestEffortReplay { const std::unordered_map& target_id2expr, const std::unordered_map& replay_id2expr); + // Skip resize in both target and replay domains + void skipResizes(); + public: + // When skip_resize is true, resize is ignored or in other words forwarded BestEffortReplay( const std::vector& replay_domain, const std::vector& target_domain, @@ -302,7 +317,8 @@ class TORCH_CUDA_CU_API BestEffortReplay { std::unordered_map replay_forward_id_map = {}, std::unordered_map target_forward_id_map = {}, bool skip_replay_swizzle = true, - bool skip_target_swizzle = true); + bool skip_target_swizzle = true, + bool skip_resize = false); // Return iter domain map from target_domain IDs to their "replayed" // replay_domain IDs. If not in map, was not replayed. @@ -346,23 +362,29 @@ class TORCH_CUDA_CU_API BestEffortReplay { // Runs a best effort replay that ignores broadcast axes that appear in // consumer that are not mapped to producer in root_map. + // + // When skip_resize is true, resize is ignored or in other words forwarded static BestEffortReplay replayCasP( const TensorView* consumer, const TensorView* producer, int producer_compute_at_axis, const RootDomainMap& root_map, bool skip_consumer_swizzle = true, - bool skip_producer_swizzle = true); + bool skip_producer_swizzle = true, + bool skip_resize = true); // Runs a best effort replay that ignores broadcast axes that appear in // consumer that are not mapped to producer in root_map. + // + // When skip_resize is true, resize is ignored or in other words forwarded static BestEffortReplay replayPasC( const TensorView* producer, const TensorView* consumer, int consumer_compute_at_axis, const RootDomainMap& root_map, bool skip_producer_swizzle = true, - bool skip_consumer_swizzle = true); + bool skip_consumer_swizzle = true, + bool skip_resize = true); // Find the first position i where td1[i] is not the same as td2[i]. "Same" // means the DAG and input IDs to generate td1[i] and td2[i] are the same. 
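The `skip_resize` path above forwards mappings through `Resize` ops, much like the existing swizzle skipping. A toy version of that forwarding with stand-in types (not the real `BestEffortReplay` machinery):

```cpp
#include <string>
#include <unordered_map>

// Stand-in iter domain: resize_output is set iff the ID's only use is a Resize.
struct Id {
  std::string name;
  const Id* resize_output = nullptr;
};

// If either side of a target->replay mapping is the input of a resize,
// re-key the entry to the resize output so the replay proceeds as if the
// resize were not there.
void skipResizes(std::unordered_map<const Id*, const Id*>& target2replay) {
  bool updated = true;
  while (updated) {
    updated = false;
    for (auto it = target2replay.begin(); it != target2replay.end(); ++it) {
      const Id* target = it->first;
      const Id* replay = it->second;
      const Id* new_target = target->resize_output ? target->resize_output : target;
      const Id* new_replay = replay->resize_output ? replay->resize_output : replay;
      if (new_target == target && new_replay == replay) {
        continue;
      }
      target2replay.erase(it);
      target2replay.emplace(new_target, new_replay);
      updated = true;
      break;  // the map changed; restart the scan
    }
  }
}

int main() {
  Id p_root{"p_i0"};          // producer root, no resize
  Id c_rf{"c_i0_resized"};    // consumer rfactor ID, output of a resize
  Id c_root{"c_i0", &c_rf};   // consumer root whose only use is that resize
  std::unordered_map<const Id*, const Id*> map{{&c_root, &p_root}};
  skipResizes(map);
  // The entry is now keyed by the post-resize consumer ID.
  return map.count(&c_rf) == 1 ? 0 : 1;
}
```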
diff --git a/csrc/transform_replay.cpp b/csrc/transform_replay.cpp index ea290bbac8b..7a226b40ae5 100644 --- a/csrc/transform_replay.cpp +++ b/csrc/transform_replay.cpp @@ -7,6 +7,7 @@ // clang-format on #include +#include #include #include #include @@ -136,6 +137,42 @@ class ReplaySelf : public ReplayTransformations { id_map_[m->out()] = merged_id; } + void handle(Swizzle2D* swizzle) override { + TORCH_INTERNAL_ASSERT( + false, "Unexpected expr to self replay: ", swizzle->toString()); + } + + void handle(Resize* resize) override { + auto id_in = resize->in(); + + auto it = id_map_.find(id_in); + TORCH_INTERNAL_ASSERT( + it != id_map_.end(), + "Transform traversal failed, dependencies not met."); + + auto mapped = it->second; + + TORCH_INTERNAL_ASSERT( + leaf_ids_.find(mapped) != leaf_ids_.end(), + "Transform traversal failed, modified a node but it was not a leaf node."); + + // When the original output is an rfactor, make the replayed + // output domain also an rfactor + const auto resize_out_rfactor = resize->out()->isRFactorProduct(); + + auto replayed_out = IterDomain::resize( + mapped, + resize->leftExpand(), + resize->rightExpand(), + resize_out_rfactor); + + leaf_ids_.erase(mapped); + + leaf_ids_[replayed_out] = newCounter(); + + id_map_[resize->out()] = replayed_out; + } + public: ReplaySelf(const std::vector& _target_domain, id_map _id_map) : ReplayTransformations(_target_domain, std::move(_id_map)) { @@ -264,12 +301,12 @@ std::pair TransformReplay::replayPasC( const TensorView* consumer, int consumer_pos, const RootDomainMap& root_map, - bool replay_swizzle) { + bool replay_swizzle, + bool replay_resize) { FUSER_PERF_SCOPE("TransformReplay::replayPasC"); if (producer == consumer) { return {producer->domain(), producer->nDims()}; } - if (consumer_pos < 0) { consumer_pos += (int)consumer->nDims() + 1; } @@ -294,7 +331,13 @@ std::pair TransformReplay::replayPasC( // the inputs of the swizzles instead of the outputs, and therefore should not // skip swizzles in here. auto forward_replay = BestEffortReplay::replayPasC( - producer, consumer, consumer_pos, root_map, false, !replay_swizzle); + producer, + consumer, + consumer_pos, + root_map, + false, + !replay_swizzle, + !replay_resize); // Make a new map based on all the leaves resulting from best effort replay id_map forwarded_replay_map; @@ -309,7 +352,9 @@ std::pair TransformReplay::replayPasC( // Replay producer dimensions. ReplayTransformations replay_PasC(target_consumer_ids, forwarded_replay_map); - replay_PasC.setErrorOnFailure(false).setReplaySwizzle(replay_swizzle); + replay_PasC.setErrorOnFailure(false) + .setReplaySwizzle(replay_swizzle) + .setReplayResize(replay_resize); auto producer_leaf_ids(replay_PasC.getUnorderedLeafIDs()); @@ -520,9 +565,13 @@ std::pair TransformReplay::replayCasP( // axis that those ops match. // // Note on skip_swizzles: Similar constraints apply in replayPasC. See the - // corresponding notes there on not skipping swizzles in the matching here. + // corresponding notes there on not skipping swizzles in the + // matching here. + // + // The consumer may have resize, which replayCasP skips and forwards + // the mapping to the output domain of the resize. 
BestEffortReplay forward_replay = BestEffortReplay::replayCasP( - consumer, producer, producer_pos, root_map, false, !replay_swizzle); + consumer, producer, producer_pos, root_map, false, !replay_swizzle, true); // Track dangling leaves which can be produced in // BestEffortReplay::replayCasP these don't have any equivalent in producer @@ -538,9 +587,11 @@ std::pair TransformReplay::replayCasP( } } - // Replay producer dimensions. + // Replay producer dimensions. Currently, resize isn't replayed. ReplayTransformations replay_CasP(target_producer_ids, forwarded_replay_map); - replay_CasP.setErrorOnFailure(false).setReplaySwizzle(replay_swizzle); + replay_CasP.setErrorOnFailure(false) + .setReplaySwizzle(replay_swizzle) + .setReplayResize(false); auto consumer_leaf_ids(replay_CasP.getUnorderedLeafIDs()); @@ -715,11 +766,17 @@ std::pair TransformReplay::replayPasC( const TensorView* producer, const TensorView* consumer, int compute_at_axis, - bool replay_swizzle) { + bool replay_swizzle, + bool replay_resize) { // Use the pairwise root map as a default mapper PairwiseRootDomainMap root_map(producer, consumer); return replayPasC( - producer, consumer, compute_at_axis, root_map, replay_swizzle); + producer, + consumer, + compute_at_axis, + root_map, + replay_swizzle, + replay_resize); } std::pair TransformReplay::replayCasP( @@ -740,7 +797,8 @@ std::pair TransformReplay::replayCasP( int TransformReplay::getMatchedLeafPosWithoutReplayPasC( const TensorView* producer, const TensorView* consumer, - int consumer_pos) { + int consumer_pos, + bool skip_resize) { FUSER_PERF_SCOPE("transform_replay.cpp::getMatchedLeafPosWithoutReplayPasC"); const auto pairwise_map = PairwiseRootDomainMap(producer, consumer); @@ -768,7 +826,8 @@ int TransformReplay::getMatchedLeafPosWithoutReplayPasC( auto it_producer = producer_domain.begin(); auto disjoint_sets = - BestEffortReplay::replayPasC(producer, consumer, -1, pairwise_map) + BestEffortReplay::replayPasC( + producer, consumer, -1, pairwise_map, true, true, skip_resize) .getIterDomainEquivalence(); int mismatched_consumer_pos = 0; @@ -809,7 +868,8 @@ int TransformReplay::getMatchedLeafPosWithoutReplayPasC( int TransformReplay::getMatchedLeafPosWithoutReplayCasP( const TensorView* consumer, const TensorView* producer, - int producer_pos) { + int producer_pos, + bool skip_resize) { FUSER_PERF_SCOPE("transform_replay.cpp::getMatchedLeafPosWithoutReplayCasP"); const auto pairwise_map = PairwiseRootDomainMap(producer, consumer); @@ -841,7 +901,8 @@ int TransformReplay::getMatchedLeafPosWithoutReplayCasP( auto it_consumer = consumer_domain.begin(); auto disjoint_sets = - BestEffortReplay::replayPasC(producer, consumer, -1, pairwise_map) + BestEffortReplay::replayPasC( + producer, consumer, -1, pairwise_map, true, true, skip_resize) .getIterDomainEquivalence(); int mismatched_producer_pos = 0; @@ -937,8 +998,11 @@ void TransformPropagator::propagateC2P(TensorView* from, TensorView* to) { // current TransformPropagator might not contain the most amount of // information on how to do the correct transformation. The logic below tells // TransformPropagator to skip the replay when not necessary. + // + // Note on resize: When propagating transformations, resize is just + // skipped, or forwarded, so the matching here is done by skipping it. 
int new_pos = - TransformReplay::getMatchedLeafPosWithoutReplayPasC(to, from, pos); + TransformReplay::getMatchedLeafPosWithoutReplayPasC(to, from, pos, true); bool debug = isDebugDumpEnabled(DebugDumpOption::TransformPropagator); if (debug) { std::cout << "TransformPropagator::propagateC2P" << std::endl; @@ -969,7 +1033,7 @@ void TransformPropagator::propagateP2C(TensorView* from, TensorView* to) { int pos = replayed_pos_.at(from); // See note [Using multiple TransformPropagators] int new_pos = - TransformReplay::getMatchedLeafPosWithoutReplayCasP(to, from, pos); + TransformReplay::getMatchedLeafPosWithoutReplayCasP(to, from, pos, true); bool debug = isDebugDumpEnabled(DebugDumpOption::TransformPropagator); if (debug) { std::cout << "TransformPropagator::propagateP2C" << std::endl; @@ -1040,7 +1104,7 @@ void MostInlinedTransformPropagator::propagateC2P( int pos = from->nDims(); // See note [Using multiple TransformPropagators] int new_pos = - TransformReplay::getMatchedLeafPosWithoutReplayPasC(to, from, pos); + TransformReplay::getMatchedLeafPosWithoutReplayPasC(to, from, pos, true); bool debug = isDebugDumpEnabled(DebugDumpOption::TransformPropagator); if (debug) { std::cout << "MostInlinedTransformPropagator::propagateC2P" << std::endl; @@ -1071,7 +1135,7 @@ void MostInlinedTransformPropagator::propagateP2C( int pos = from->nDims(); // See note [Using multiple TransformPropagators] int new_pos = - TransformReplay::getMatchedLeafPosWithoutReplayCasP(to, from, pos); + TransformReplay::getMatchedLeafPosWithoutReplayCasP(to, from, pos, true); bool debug = isDebugDumpEnabled(DebugDumpOption::TransformPropagator); if (debug) { std::cout << "MostInlinedTransformPropagator::propagateP2C" << std::endl; diff --git a/csrc/transform_replay.h b/csrc/transform_replay.h index ae5766b0a4f..9b78c15d998 100644 --- a/csrc/transform_replay.h +++ b/csrc/transform_replay.h @@ -132,20 +132,28 @@ class RootDomainMap; class TORCH_CUDA_CU_API TransformReplay { public: // Replay producer as consumer, returns {producer, producer_compute_at_axis}. + // + // replay_resize indicates whether resize should be replayed or + // ignored. It is only replayed when replaying a producer for + // indexing. static std::pair replayPasC( const TensorView* producer, const TensorView* consumer, int consumer_compute_at_axis, - bool replay_swizzle = false); + bool replay_swizzle = false, + bool replay_resize = false); static std::pair replayPasC( const TensorView* producer, const TensorView* consumer, int consumer_compute_at_axis, const RootDomainMap& root_map, - bool replay_swizzle = false); + bool replay_swizzle = false, + bool replay_resize = false); // Replay producer as consumer, returns {replayed_consumer_domain, // consumer_compute_at_axis}. + // + // Unlike replayPasC, it always ignores resize. static std::pair replayCasP( const TensorView* consumer, const TensorView* producer, @@ -171,18 +179,32 @@ class TORCH_CUDA_CU_API TransformReplay { // position as `replayPasC`. However, this function is more tolerant than // fully matching `replayPasC`: if in the consumer, there are unmappable // dimensions, these dimensions are just ignored. + // + // When skip_resize is true, mapping is done more permissively by + // skipping resize ops. For example, that is done when this is used + // by TransformPropagator, whereas it isn't when used for + // determining the inlining position by MaxPosCalculator as inlining + // isn't allowed with different extents. 
   static int getMatchedLeafPosWithoutReplayPasC(
       const TensorView* producer,
       const TensorView* consumer,
-      int consumer_pos);
+      int consumer_pos,
+      bool skip_resize = false);
 
   // Returns the leaf position in consumer that matches with `producer_pos` in
   // producer. Behavior similar to getMatchedLeafPosWithoutReplayPasC, except
   // that we are also ignoring reductions in the producer.
+  //
+  // When skip_resize is true, mapping is done more permissively by
+  // skipping resize ops. For example, that is done when this is used
+  // by TransformPropagator, whereas it isn't when used for
+  // determining the inlining position by MaxPosCalculator as inlining
+  // isn't allowed with different extents.
   static int getMatchedLeafPosWithoutReplayCasP(
       const TensorView* consumer,
       const TensorView* producer,
-      int producer_pos);
+      int producer_pos,
+      bool skip_resize = false);
 
   // tests if two tensors has fully matching transformations
   static bool fullSelfMatching(
diff --git a/csrc/transform_rfactor.cpp b/csrc/transform_rfactor.cpp
index 75fbd116fa1..04ed387fa9d 100644
--- a/csrc/transform_rfactor.cpp
+++ b/csrc/transform_rfactor.cpp
@@ -217,6 +217,15 @@ class ReplayRFactor : public ReplayTransformations {
     }
   }
 
+  void handle(Resize* resize) override {
+    TORCH_INTERNAL_ASSERT(false, "Unexpected expression: ", resize->toString());
+  }
+
+  void handle(Swizzle2D* swizzle) override {
+    TORCH_INTERNAL_ASSERT(
+        false, "Unexpected expression: ", swizzle->toString());
+  }
+
   // The IterDomains in the original_domain that are being factored into the
   // first stage of the two stage reduction (the producer).
   std::unordered_set<IterDomain*> rfactor_axes_;
@@ -446,7 +455,8 @@ std::pair<TensorDomain*, TensorDomain*> TransformRFactor::runReplay(
   ReplayTransformations consumer_replay(
       original_td->domain(), original_to_consumer_root_map);
-  consumer_replay.setErrorOnFailure(false);
+  consumer_replay.setErrorOnFailure(false).setReplayResize(true);
+
   auto original_to_consumer_map = consumer_replay.getReplay();
 
   std::vector<IterDomain*> new_consumer_domain;
diff --git a/csrc/type.h b/csrc/type.h
index e5e54480e6b..8b71e32980a 100644
--- a/csrc/type.h
+++ b/csrc/type.h
@@ -370,13 +370,20 @@ enum class IterType {
 };
 
 // Used for Iteration Domain mapping modes in ComputeAtMap
-enum class IdMappingMode { EXACT, ALMOSTEXACT, LOOP, PERMISSIVE };
+enum class IdMappingMode {
+  EXACT,
+  ALMOSTEXACT,
+  LOOP,
+  PERMISSIVE,
+  PERMISSIVE_RESIZE
+};
 
-static constexpr std::array<IdMappingMode, 4> kIdMappingModes = {
+static constexpr std::array<IdMappingMode, 5> kIdMappingModes = {
     IdMappingMode::EXACT,
     IdMappingMode::ALMOSTEXACT,
     IdMappingMode::LOOP,
-    IdMappingMode::PERMISSIVE};
+    IdMappingMode::PERMISSIVE,
+    IdMappingMode::PERMISSIVE_RESIZE};
 
 // Used to annotate the special memory intrinsics that a loadstore
 // op will be lowered to.
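The test updates that follow replace the unqualified Slice(...) calls with the fully qualified at::indexing::Slice. As a quick reference, here is a minimal standalone libtorch snippet of that indexing API (shapes arbitrary; not part of the diff):

#include <torch/torch.h>

#include <iostream>

int main() {
  auto options = at::TensorOptions().dtype(at::kFloat);
  auto t = at::randn({128, 2049}, options);

  // "..." is the ellipsis index and Slice(3) means [3:], so this is the
  // equivalent of t[..., 3:]; it also offsets the view's data pointer,
  // which is what the misaligned-vectorization tests rely on.
  auto misaligned = t.index({"...", at::indexing::Slice(3)});

  // Slice() is [:], Slice(1, n - 1) is [1:n-1].
  auto trimmed = t.index(
      {at::indexing::Slice(), at::indexing::Slice(1, t.size(1) - 1)});

  std::cout << misaligned.sizes() << " " << trimmed.sizes() << std::endl;
  return 0;
}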
diff --git a/test/test_gpu2.cpp b/test/test_gpu2.cpp index bc62f622435..7e8f0694d00 100644 --- a/test/test_gpu2.cpp +++ b/test/test_gpu2.cpp @@ -4351,8 +4351,10 @@ TEST_F(NVFuserTest, FusionVectorizeMisalignedStride_CUDA) { auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); const int bx = 128; const int by = 2049; - at::Tensor t0 = at::randn({bx, by}, options).index({"...", Slice(3)}); - at::Tensor t1 = at::randn({bx, by}, options).index({"...", Slice(3)}); + at::Tensor t0 = + at::randn({bx, by}, options).index({"...", at::indexing::Slice(3)}); + at::Tensor t1 = + at::randn({bx, by}, options).index({"...", at::indexing::Slice(3)}); std::vector aten_inputs = {t0, t1}; FusionExecutor fe; @@ -4402,8 +4404,11 @@ TEST_F(NVFuserTest, FusionVectorizeMisalignedStrideFail_CUDA) { auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); const int bx = 128; const int by = 2049; - at::Tensor t0 = at::randn({bx, by}, options).index({"...", Slice(3)}); - at::Tensor t1 = at::randn({bx, by}, options).index({"...", Slice(3)}); + + at::Tensor t0 = + at::randn({bx, by}, options).index({"...", at::indexing::Slice(3)}); + at::Tensor t1 = + at::randn({bx, by}, options).index({"...", at::indexing::Slice(3)}); std::vector aten_inputs = {t0, t1}; FusionExecutor fe; @@ -4548,13 +4553,13 @@ TEST_F(NVFuserTest, FusionVectorization3_CUDA) { // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) ASSERT_ANY_THROW(fe.runFusion(aten_inputs)); - aten_inputs[0] = t0.index({"...", Slice(1)}); - aten_inputs[1] = t1.index({"...", Slice(1)}); + aten_inputs[0] = t0.index({"...", at::indexing::Slice(1)}); + aten_inputs[1] = t1.index({"...", at::indexing::Slice(1)}); // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) ASSERT_ANY_THROW(fe.runFusion(aten_inputs)); - t0 = at::randn({bx, 2048}, options).index({"...", Slice(4)}); - t1 = at::randn({bx, 2048}, options).index({"...", Slice(4)}); + t0 = at::randn({bx, 2048}, options).index({"...", at::indexing::Slice(4)}); + t1 = at::randn({bx, 2048}, options).index({"...", at::indexing::Slice(4)}); aten_inputs = {t0, t1}; auto cg_outputs = fe.runFusion(aten_inputs); diff --git a/test/test_gpu3.cpp b/test/test_gpu3.cpp index 1737637b3e0..3213f1d3348 100644 --- a/test/test_gpu3.cpp +++ b/test/test_gpu3.cpp @@ -1759,16 +1759,16 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor T2) { i39 = ((nvfuser_index_t)threadIdx.x) + (256 * ((nvfuser_index_t)blockIdx.x)); int64_t i7; i7 = T0.size[0] * T0.size[1]; - bool b75; - b75 = i39 < i7; + bool b86; + b86 = i39 < i7; float f8; f8 = (float)(i7); float T1[1]; - if (b75) { + if (b86) { T1[0] = sinf(T0[i39]); } - if (b75) { + if (b86) { T2[i39] = T1[0] + f8; @@ -2188,8 +2188,10 @@ TEST_F(NVFuserTest, FusionVectorizeInputToOutput_CUDA) { const int n = 12; auto t0 = at::randn({n}, options); // Shift by one to make it non-aligned - auto t0_misaligned = at::randn({n + 1}, options).index({Slice(1)}); - auto t1_misaligned = at::empty({n + 1}, options).index({Slice(1)}); + auto t0_misaligned = + at::randn({n + 1}, options).index({at::indexing::Slice(1)}); + auto t1_misaligned = + at::empty({n + 1}, options).index({at::indexing::Slice(1)}); FusionExecutor fe; fe.compileFusion(&fusion, {t0}); diff --git a/test/test_gpu_utils.cpp b/test/test_gpu_utils.cpp index 92cb6ee353e..985ea687529 100644 --- a/test/test_gpu_utils.cpp +++ b/test/test_gpu_utils.cpp @@ -108,7 +108,7 @@ TEST_F(NVFuserTest, FusionDisjointViewSet_CUDA) { auto tv3 = add(tv2, tv1); fusion->addOutput(tv3); - auto disjoint_exact = 
scheduler_utils::disjointViewSets(fusion.get()); + auto disjoint_exact = scheduler_utils::disjointRFactorSets(fusion.get()); TORCH_INTERNAL_ASSERT( disjoint_exact.strictAreMapped(tv0->axis(1), tv0->axis(2))); diff --git a/test/test_resize.cpp b/test/test_resize.cpp new file mode 100644 index 00000000000..3513173ae51 --- /dev/null +++ b/test/test_resize.cpp @@ -0,0 +1,1594 @@ +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace nvfuser { + +// Simple pad test +TEST_F(NVFuserTest, FusionResizePad1_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + std::vector shape({9}); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + auto tv1 = pad(tv0, {IrBuilder::create(1), IrBuilder::create(1)}); + fusion.addOutput(tv1); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + + auto t0 = at::randn(shape, options); + std::vector aten_inputs({t0}); + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + auto ref = at::pad(t0, {1, 1}); + + TORCH_CHECK(ref.equal(cg_outputs[0])); +} + +// pad + split +TEST_F(NVFuserTest, FusionResizePad2_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + std::vector shape({9}); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + auto tv1 = pad(tv0, {IrBuilder::create(1), IrBuilder::create(1)}); + fusion.addOutput(tv1); + + tv1->split(0, 4); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + + auto t0 = at::randn(shape, options); + std::vector aten_inputs({t0}); + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + auto ref = at::pad(t0, {1, 1}); + + TORCH_CHECK(ref.equal(cg_outputs[0])); +} + +// pad, merge + split, inlineMost +TEST_F(NVFuserTest, FusionResizePad3_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + std::vector shape({9, 11}); + std::vector padded_shape({9, 11 + 2}); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + auto tv1 = makeSymbolicTensor(2); + fusion.addInput(tv1); + + auto tv2 = set(tv0); + auto tv3 = pad(tv2, {IrBuilder::create(1), IrBuilder::create(1)}); + auto tv4 = add(tv3, tv1); + fusion.addOutput(tv4); + + tv4->merge(0); + tv4->split(0, 32); + + TransformPropagator propagator(tv4); + MaxRootDomainInfoSpanningTree(tv4).traverse(&propagator); + + inlineMost(); + + // TransformPropagator and inlineMost do not inline tv2, so it can't + // be on Local memory. It should be possible to expand tv2 such that + // it has the same extent as tv3, allowing it to be inlined. 
+ tv2->setMemoryType(MemoryType::Shared); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + + auto t0 = at::randn(shape, options); + auto t1 = at::randn(padded_shape, options); + std::vector aten_inputs({t0, t1}); + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + auto t3 = at::pad(t0, {1, 1}); + auto ref = t3 + t1; + + testValidate(&fusion, cg_outputs, aten_inputs, {ref}, __LINE__, __FILE__); +} + +// pad + parallelization +TEST_F(NVFuserTest, FusionResizePad4_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + std::vector shape({9}); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + auto tv1 = pad(tv0, {IrBuilder::create(1), IrBuilder::create(1)}); + fusion.addOutput(tv1); + + tv1->axis(0)->parallelize(ParallelType::TIDx); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + + auto t0 = at::randn(shape, options); + std::vector aten_inputs({t0}); + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + auto ref = at::pad(t0, {1, 1}); + + TORCH_CHECK(ref.equal(cg_outputs[0])); +} + +// pad + parallelization + RAW sync +TEST_F(NVFuserTest, FusionResizePad5_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + std::vector shape({9}); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + auto tv1 = set(tv0); + auto tv2 = pad(tv1, {IrBuilder::create(1), IrBuilder::create(1)}); + fusion.addOutput(tv2); + + tv1->axis(0)->parallelize(ParallelType::TIDx); + tv2->axis(0)->parallelize(ParallelType::TIDx); + + scheduler_utils::promoteProducerMemoryTypesOfResizedTensors(&fusion); + + TORCH_CHECK( + tv1->getMemoryType() == MemoryType::Shared, + "tv1 should be on shared memory: ", + tv1->getMemoryType()); + + GpuLower gpulw(&fusion); + auto all_lowered_exprs = KernelExprVisitor::getAllExprs(gpulw.kernel()); + TORCH_CHECK( + std::find_if( + all_lowered_exprs.begin(), + all_lowered_exprs.end(), + [](Expr* expr) { return expr->isA(); }) != + all_lowered_exprs.end(), + "Block sync not found"); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + + auto t0 = at::randn(shape, options); + std::vector aten_inputs({t0}); + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + auto ref = at::pad(t0, {1, 1}); + + TORCH_CHECK(ref.equal(cg_outputs[0])); +} + +// pad + merge + split parallelization +TEST_F(NVFuserTest, FusionResizePad6_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + std::vector shape({99, 111}); + std::vector padded_shape({shape[0], shape[1] + 2}); + + auto tv0 = makeConcreteTensor(shape); + fusion.addInput(tv0); + auto tv1 = makeConcreteTensor(padded_shape); + fusion.addInput(tv1); + + auto tv2 = add(tv0, IrBuilder::create(1)); + auto tv3 = pad(tv2, {IrBuilder::create(1), IrBuilder::create(1)}); + auto tv4 = add(tv3, tv1); + fusion.addOutput(tv4); + + tv4->merge(0); + tv4->split(0, 32); + + TransformPropagator propagator(tv4); + MaxRootDomainInfoSpanningTree(tv4).traverse(&propagator); + + inlineMost(); + + tv4->axis(0)->parallelize(ParallelType::BIDx); + tv4->axis(1)->parallelize(ParallelType::TIDx); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + + auto t0 = at::randn(shape, options); + auto t1 = at::randn(padded_shape, options); + std::vector aten_inputs({t0, t1}); + + 
FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + auto t2 = t0 + 1; + auto t3 = at::pad(t2, {1, 1}); + auto ref = t3 + t1; + + testValidate(&fusion, cg_outputs, aten_inputs, {ref}, __LINE__, __FILE__); +} + +// pad + unswitch. Having different extents in an unswitched loop nest +// needs a special care (see UnrollPass::canOmitElseClause) +TEST_F(NVFuserTest, FusionResizePad7_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + std::vector shape({9, 11}); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + auto tv1 = set(tv0); + auto tv2 = pad(tv1, {IrBuilder::create(1), IrBuilder::create(1)}); + auto tv3 = set(tv2); + fusion.addOutput(tv3); + + tv3->split(0, 1); + tv3->split(-1, 4); + tv3->reorder({{1, 2}}); + + TransformPropagator propagator(tv3); + MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); + + inlineMost(); + + tv3->axis(-1)->parallelize(ParallelType::TIDx); + tv3->axis(-2)->parallelize(ParallelType::Unswitch); + + scheduler_utils::parallelizeAllLike(tv3); + + scheduler_utils::promoteProducerMemoryTypesOfResizedTensors(&fusion); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + + auto t0 = at::randn(shape, options); + std::vector aten_inputs({t0}); + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + auto ref = at::pad(t0, {1, 1}); + + testValidate(&fusion, cg_outputs, aten_inputs, {ref}, __LINE__, __FILE__); +} + +// Disable for now. Unclear what would be the best way to handle +// when a tensor is resized multiple times. It would likely need a +// different transform propagator. +#if 0 +// Stencil-like pattern +TEST_F(NVFuserTest, FusionResizePad8_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + auto tv1 = set(tv0); + // Sort of shift(tv1, {-1}); + auto tv2 = pad(tv1, {IrBuilder::create(0), IrBuilder::create(1)}); + // Sort of shift(tv1, {1}); + auto tv3 = pad(tv1, {IrBuilder::create(1), IrBuilder::create(0)}); + auto tv4 = add(tv2, tv3); + fusion.addOutput(tv4); + + tv4->split(0, 128); + + TransformPropagator propagator(tv4); + MaxRootDomainInfoSpanningTree(tv4).traverse(&propagator); + + inlineMost(); + + tv4->axis(0)->parallelize(ParallelType::BIDx); + tv4->axis(1)->parallelize(ParallelType::TIDx); + scheduler_utils::parallelizeAllLike(tv4); + + scheduler_utils::promoteProducerMemoryTypesOfResizedTensors(&fusion); + + fusion.printMath(); + fusion.print(); + + fusion.printKernel(); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + + auto t0 = at::randn(999, options); + std::vector aten_inputs({t0}); + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + auto ref = at::pad(t0, {0, 1}) + at::pad(t0, {1, 0}); + + testValidate(&fusion, cg_outputs, aten_inputs, {ref}, __LINE__, __FILE__); +} +#endif + +TEST_F(NVFuserTest, FusionResizePadScheduler1_CUDA) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv0 = makeSymbolicTensor(2); + fusion->addInput(tv0); + + auto tv1 = pad(tv0, {IrBuilder::create(1), IrBuilder::create(1)}); + fusion->addOutput(tv1); + + std::vector shape({99, 111}); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + + auto t0 = at::randn(shape, options); + std::vector aten_inputs({t0}); + + 
FusionExecutorCache executor_cache(std::move(fusion)); + auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); + + auto ref = at::pad(t0, {1, 1}); + + TORCH_CHECK(ref.equal(cg_outputs[0])); +} + +TEST_F(NVFuserTest, FusionResizePadScheduler2_CUDA) { + auto fusion_ptr = std::make_unique(); + auto& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); + + std::vector shape({9, 11}); + std::vector padded_shape({9, 11 + 2}); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + auto tv1 = makeSymbolicTensor(2); + fusion.addInput(tv1); + + auto tv2 = set(tv0); + auto tv3 = pad(tv2, {IrBuilder::create(1), IrBuilder::create(1)}); + auto tv4 = add(tv3, tv1); + fusion.addOutput(tv4); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + + auto t0 = at::randn(shape, options); + auto t1 = at::randn(padded_shape, options); + std::vector aten_inputs({t0, t1}); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); + + auto t3 = at::pad(t0, {1, 1}); + auto ref = t3 + t1; + + testValidate( + executor_cache.fusion(), + cg_outputs, + aten_inputs, + {ref}, + __LINE__, + __FILE__); +} + +// Disabled due to the same reason as Pad8 +#if 0 +// Auto scheduled version of Pad8 +TEST_F(NVFuserTest, FusionResizePadScheduler3_CUDA) { + auto fusion_ptr = std::make_unique(); + auto& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + auto tv1 = set(tv0); + auto tv2 = pad(tv1, {IrBuilder::create(0), IrBuilder::create(1)}); + auto tv3 = pad(tv1, {IrBuilder::create(1), IrBuilder::create(0)}); + auto tv4 = add(tv2, tv3); + fusion.addOutput(tv4); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + + auto t0 = at::randn(999, options); + std::vector aten_inputs({t0}); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); + + auto ref = at::pad(t0, {0, 1}) + at::pad(t0, {1, 0}); + + testValidate( + executor_cache.fusion(), + cg_outputs, + aten_inputs, + {ref}, + __LINE__, + __FILE__); +} +#endif + +// Two pad exprs, both using the same symbolic pad widths, segmented +// into two kernels. Make sure the symbolic inputs are available to +// both of the segmented kernels. 
+TEST_F(NVFuserTest, FusionResizePadScheduler4_CUDA) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv0 = makeSymbolicTensor(2); + fusion->addInput(tv0); + + auto left_pad = IrBuilder::create(); + fusion->addInput(left_pad); + auto right_pad = IrBuilder::create(); + fusion->addInput(right_pad); + + auto tv1 = pad(tv0, {left_pad, right_pad}); + auto tv2 = sum(tv1, {0}); + fusion->addOutput(tv2); + + auto tv3 = pad(tv0, {left_pad, right_pad}); + auto tv4 = sum(tv3, {1}); + fusion->addOutput(tv4); + + std::vector shape({99, 111}); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + + auto t0 = at::randn(shape, options); + std::vector pad_extents{1, 1}; + std::vector aten_inputs({t0, 1, 1}); + + FusionExecutorCache executor_cache(std::move(fusion)); + auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); + + auto t0_double = t0.to(at::kDouble); + auto t2 = at::pad(t0_double, {1, 1}).sum({0}); + auto t4 = at::pad(t0_double, {1, 1}).sum({1}); + + testValidate( + executor_cache.fusion(), + cg_outputs, + aten_inputs, + {t2, t4}, + __LINE__, + __FILE__); +} + +// Trivial cat +TEST_F(NVFuserTest, FusionResizeCat1_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + std::vector shape0({2}); + std::vector shape1({3}); + + auto tv0 = makeConcreteTensor(shape0); + fusion.addInput(tv0); + + auto tv1 = makeConcreteTensor(shape1); + fusion.addInput(tv1); + + auto tv2 = cat({tv0, tv1}, 0); + fusion.addOutput(tv2); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + + auto t0 = at::randn(shape0, options); + auto t1 = at::randn(shape1, options); + std::vector aten_inputs({t0, t1}); + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + auto ref = at::cat({t0, t1}, 0); + + TORCH_CHECK(ref.equal(cg_outputs[0])); +} + +// Trivial 2D inner cat +TEST_F(NVFuserTest, FusionResizeCat2_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + std::vector shape0({2, 4}); + std::vector shape1({3, 4}); + + auto tv0 = makeConcreteTensor(shape0); + fusion.addInput(tv0); + + auto tv1 = makeConcreteTensor(shape1); + fusion.addInput(tv1); + + auto tv2 = cat({tv0, tv1}, 0); + fusion.addOutput(tv2); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + + auto t0 = at::randn(shape0, options); + auto t1 = at::randn(shape1, options); + std::vector aten_inputs({t0, t1}); + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + auto ref = at::cat({t0, t1}, 0); + + TORCH_CHECK(ref.equal(cg_outputs[0])); +} + +// Trivial 2D outer cat +TEST_F(NVFuserTest, FusionResizeCat3_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + std::vector shape0({4, 2}); + std::vector shape1({4, 3}); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + auto tv1 = makeSymbolicTensor(2); + fusion.addInput(tv1); + + auto tv2 = cat({tv0, tv1}, 1); + fusion.addOutput(tv2); + + tv2->merge(0); + tv2->split(0, 4); + + TransformPropagator propagator(tv2); + MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator); + + inlineMost(); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + + auto t0 = at::randn(shape0, options); + auto t1 = at::randn(shape1, options); + std::vector aten_inputs({t0, t1}); + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = 
fe.runFusion(aten_inputs); + + auto ref = at::cat({t0, t1}, 1); + + TORCH_CHECK(ref.equal(cg_outputs[0])); +} + +// Cat + merge + split + parallelization + inlineMost +TEST_F(NVFuserTest, FusionResizeCat4_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + std::vector shape0({11, 12}); + std::vector shape1({11, 13}); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + auto tv1 = makeSymbolicTensor(2); + fusion.addInput(tv1); + + auto tv2 = cat({tv0, tv1}, 1); + fusion.addOutput(tv2); + + tv2->merge(0); + tv2->split(0, 128); + + TransformPropagator propagator(tv2); + MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator); + + tv2->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(1)->parallelize(ParallelType::TIDx); + + inlineMost(); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + + auto t0 = at::randn(shape0, options); + auto t1 = at::randn(shape1, options); + std::vector aten_inputs({t0, t1}); + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + auto ref = at::cat({t0, t1}, 1); + + TORCH_CHECK(ref.equal(cg_outputs[0])); +} + +// Cat + arith op +TEST_F(NVFuserTest, FusionResizeCat5_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + auto tv1 = makeSymbolicTensor(2); + fusion.addInput(tv1); + auto tv2 = makeSymbolicTensor(2); + fusion.addInput(tv2); + + auto tv3 = cat({tv0, tv1}, 1); + auto tv4 = add(tv3, tv2); + fusion.addOutput(tv4); + + tv4->merge(0); + tv4->split(0, 128); + + TransformPropagator propagator(tv4); + MaxRootDomainInfoSpanningTree(tv4).traverse(&propagator); + + inlineMost(); + + tv4->axis(0)->parallelize(ParallelType::BIDx); + tv4->axis(1)->parallelize(ParallelType::TIDx); + scheduler_utils::parallelizeAllLike(tv4); + + std::vector shape0({11, 12}); + std::vector shape1({shape0[0], 13}); + std::vector shape2({shape0[0], shape0[1] + shape1[1]}); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + + auto t0 = at::randn(shape0, options); + auto t1 = at::randn(shape1, options); + auto t2 = at::randn(shape2, options); + std::vector aten_inputs({t0, t1, t2}); + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + auto ref = at::cat({t0, t1}, 1) + t2; + + testValidate(&fusion, cg_outputs, aten_inputs, {ref}, __LINE__, __FILE__); +} + +// Cat 3 tensors +TEST_F(NVFuserTest, FusionResizeCat6_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + std::vector shape0({2, 4}); + std::vector shape1({5, 4}); + std::vector shape2({3, 4}); + + auto tv0 = makeConcreteTensor(shape0); + fusion.addInput(tv0); + auto tv1 = makeConcreteTensor(shape1); + fusion.addInput(tv1); + auto tv2 = makeConcreteTensor(shape2); + fusion.addInput(tv2); + + auto tv3 = cat({tv0, tv1, tv2}, 0); + fusion.addOutput(tv3); + + tv3->merge(0); + tv3->split(0, 4); + TransformPropagator propagator(tv3); + MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); + + inlineMost(); + + tv3->axis(0)->parallelize(ParallelType::BIDx); + tv3->axis(1)->parallelize(ParallelType::TIDx); + scheduler_utils::parallelizeAllLike(tv3); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + + auto t0 = at::randn(shape0, options); + auto t1 = at::randn(shape1, options); + auto t2 = at::randn(shape2, options); + std::vector aten_inputs({t0, t1, t2}); + + FusionExecutor fe; + 
fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + auto ref = at::cat({t0, t1, t2}, 0); + + TORCH_CHECK(ref.equal(cg_outputs[0])); +} + +// Cat many tensors +TEST_F(NVFuserTest, FusionResizeCat7_CUDA) { + int num_tensors_to_concat = 10; + + for (int concat_dim : {0, 1}) { + Fusion fusion; + FusionGuard fg(&fusion); + + std::vector inputs; + for (const auto i : c10::irange(num_tensors_to_concat)) { + (void)i; + auto tv = makeSymbolicTensor(2); + fusion.addInput(tv); + inputs.push_back(tv); + } + + auto concat_tv = cat(inputs, concat_dim); + fusion.addOutput(concat_tv); + + concat_tv->merge(0); + concat_tv->split(0, 128); + + TransformPropagator propagator(concat_tv); + MaxRootDomainInfoSpanningTree(concat_tv).traverse(&propagator); + + inlineMost(); + + concat_tv->axis(0)->parallelize(ParallelType::BIDx); + concat_tv->axis(1)->parallelize(ParallelType::TIDx); + scheduler_utils::parallelizeAllLike(concat_tv); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + + std::vector base_shape({11, 13}); + std::vector aten_inputs; + for (const auto i : c10::irange(num_tensors_to_concat)) { + auto shape = base_shape; + shape[concat_dim] = 10 + (i % 5); + aten_inputs.emplace_back(at::randn(shape, options)); + } + + std::vector aten_inputs_ivalue( + {aten_inputs.begin(), aten_inputs.end()}); + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs_ivalue); + auto cg_outputs = fe.runFusion(aten_inputs_ivalue); + + auto ref = at::cat(aten_inputs, concat_dim); + + TORCH_CHECK(ref.equal(cg_outputs[0])); + } +} + +// Auto scheduled version of Cat1 +TEST_F(NVFuserTest, FusionResizeCatScheduler1_CUDA) { + auto fusion_ptr = std::make_unique(); + auto& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + auto tv1 = makeSymbolicTensor(1); + fusion.addInput(tv1); + + auto tv2 = cat({tv0, tv1}, 0); + fusion.addOutput(tv2); + + std::vector shape0({2}); + std::vector shape1({3}); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + + auto t0 = at::randn(shape0, options); + auto t1 = at::randn(shape1, options); + std::vector aten_inputs({t0, t1}); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); + + auto ref = at::cat({t0, t1}, 0); + + TORCH_CHECK(ref.equal(cg_outputs[0])); +} + +// Auto scheduled version of Cat5 +TEST_F(NVFuserTest, FusionResizeCatScheduler2_CUDA) { + auto fusion_ptr = std::make_unique(); + auto& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + auto tv1 = makeSymbolicTensor(2); + fusion.addInput(tv1); + auto tv2 = makeSymbolicTensor(2); + fusion.addInput(tv2); + + auto tv3 = cat({tv0, tv1}, 1); + auto tv4 = add(tv3, tv2); + fusion.addOutput(tv4); + + std::vector shape0({11, 12}); + std::vector shape1({shape0[0], 13}); + std::vector shape2({shape0[0], shape0[1] + shape1[1]}); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + + auto t0 = at::randn(shape0, options); + auto t1 = at::randn(shape1, options); + auto t2 = at::randn(shape2, options); + std::vector aten_inputs({t0, t1, t2}); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); + + auto ref = at::cat({t0, t1}, 1) + t2; + + testValidate( + 
executor_cache.fusion(), + cg_outputs, + aten_inputs, + {ref}, + __LINE__, + __FILE__); +} + +// Auto scheduled version of Cat6 +TEST_F(NVFuserTest, FusionResizeCatScheduler3_CUDA) { + auto fusion_ptr = std::make_unique(); + auto& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + auto tv1 = makeSymbolicTensor(2); + fusion.addInput(tv1); + auto tv2 = makeSymbolicTensor(2); + fusion.addInput(tv2); + + auto tv3 = cat({tv0, tv1, tv2}, 0); + fusion.addOutput(tv3); + + std::vector shape0({2, 4}); + std::vector shape1({5, 4}); + std::vector shape2({3, 4}); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + + auto t0 = at::randn(shape0, options); + auto t1 = at::randn(shape1, options); + auto t2 = at::randn(shape2, options); + std::vector aten_inputs({t0, t1, t2}); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); + + auto ref = at::cat({t0, t1, t2}, 0); + + TORCH_CHECK(ref.equal(cg_outputs[0])); +} + +// Trivial slice +TEST_F(NVFuserTest, FusionResizeSlice1_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + std::vector shape({9}); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + auto tv1 = slice( + tv0, + {{IrBuilder::create(1), + sub(tv0->axis(0)->extent(), IrBuilder::create(1))}}); + fusion.addOutput(tv1); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + + auto t0 = at::randn(shape, options); + std::vector aten_inputs({t0}); + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + auto ref = t0.index({at::indexing::Slice(1, shape[0] - 1)}); + + TORCH_CHECK(ref.equal(cg_outputs[0])); +} + +// Split a tensor to half and add them up +TEST_F(NVFuserTest, FusionResizeSlice2_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + std::vector shape({11, 30}); + + TORCH_CHECK(shape[1] % 2 == 0); + + auto tv0 = makeConcreteTensor(shape); + fusion.addInput(tv0); + + auto tv1 = slice( + tv0, + {Slice(), + {IrBuilder::create(0), IrBuilder::create(shape[1] / 2)}}); + auto tv2 = slice(tv0, {Slice(), {IrBuilder::create(shape[1] / 2)}}); + auto tv3 = add(tv1, tv2); + fusion.addOutput(tv3); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + + auto t0 = at::randn(shape, options); + std::vector aten_inputs({t0}); + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + auto t1 = t0.index( + {at::indexing::Slice(0, at::indexing::None), + at::indexing::Slice(0, shape[1] / 2)}); + auto t2 = t0.index( + {at::indexing::Slice(0, at::indexing::None), + at::indexing::Slice(shape[1] / 2)}); + auto ref = t1 + t2; + + testValidate(&fusion, cg_outputs, aten_inputs, {ref}, __LINE__, __FILE__); +} + +// "Trivial" slice is converted to Set +TEST_F(NVFuserTest, FusionResizeSlice3_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + // These should result in unary set op + auto tv1 = slice(tv0, {{nullptr, tv0->axis(0)->extent()}}); + auto tv2 = slice(tv0, {Slice()}); + auto tv3 = add(tv1, tv2); + fusion.addOutput(tv3); + + TORCH_CHECK( + tv1->definition()->isA() && + tv1->definition()->as()->getUnaryOpType() == UnaryOpType::Set); + TORCH_CHECK( + tv2->definition()->isA() && + tv2->definition()->as()->getUnaryOpType() == 
UnaryOpType::Set); +} + +// Partition an input, reduce each and concatenate them +TEST_F(NVFuserTest, FusionResizeSlice4_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + std::vector shape({5, 100}); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + // Consider a fusion of: + // auto tv1 = add(tv0, IrBuilder::create(1)); + // auto tv2 = sum(tv1, {1}); + + // Reproduce the above fusion with split tensors + + // Split the input to [0:2, :] and [2:, :] + auto tv1 = slice( + tv0, {{IrBuilder::create(0), IrBuilder::create(2)}, Slice()}); + auto tv2 = slice(tv0, {{IrBuilder::create(2)}, Slice()}); + + auto tv3 = add(tv1, IrBuilder::create(1)); + auto tv4 = add(tv2, IrBuilder::create(1)); + + auto tv5 = sum(tv3, {1}); + auto tv6 = sum(tv4, {1}); + auto tv7 = cat({tv5, tv6}, 0); + fusion.addOutput(tv7); + + // Schedule the two reductions separately + tv5->split(-1, 32); + auto tv5_rf = tv5->rFactor({-2}); + tv5_rf->reorder({{-1, -2}}); + auto tv5_cache = tv5->cacheBefore(); + tv5->setMemoryType(MemoryType::Global); + SetSelector tv5_rf_selector({tv1, tv3, tv5, tv5_cache}); + TransformPropagator tv5_rf_tp(tv5_rf); + MaxRootDomainInfoSpanningTree(tv5_rf, &tv5_rf_selector).traverse(&tv5_rf_tp); + inlineMost(std::vector{tv1, tv3, tv5_rf}); + tv5_rf->axis(0)->parallelize(ParallelType::BIDx); + tv5_rf->axis(1)->parallelize(ParallelType::TIDx); + scheduler_utils::parallelizeAllLike(tv5_rf, {tv1, tv3, tv5, tv5_cache}); + + tv6->split(-1, 32); + auto tv6_rf = tv6->rFactor({-2}); + tv6_rf->reorder({{-1, -2}}); + auto tv6_cache = tv6->cacheBefore(); + tv6->setMemoryType(MemoryType::Global); + SetSelector tv6_rf_selector({tv2, tv4, tv6, tv6_cache}); + TransformPropagator tv6_rf_tp(tv6_rf); + MaxRootDomainInfoSpanningTree(tv6_rf, &tv6_rf_selector).traverse(&tv6_rf_tp); + inlineMost(std::vector{tv2, tv4, tv6_rf}); + tv6_rf->axis(0)->parallelize(ParallelType::BIDx); + tv6_rf->axis(1)->parallelize(ParallelType::TIDx); + scheduler_utils::parallelizeAllLike(tv6_rf, {tv2, tv4, tv6, tv6_cache}); + + // cat consits of a PadOp and a CatOp. Fully inline the PadOp + for (auto tv7_inp : + ir_utils::filterByType(tv7->definition()->inputs())) { + tv7_inp->inlineAt(-1); + } + + // Use just one block to concat the two results + tv7->axis(0)->parallelize(ParallelType::TIDx); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + + auto t0 = at::randn(shape, options); + std::vector aten_inputs({t0}); + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + auto ref = (t0 + 1).to(at::kDouble).sum({1}); + + testValidate(&fusion, cg_outputs, aten_inputs, {ref}, __LINE__, __FILE__); +} + +// Multiple slices of the same tensor with the same arguments +TEST_F(NVFuserTest, FusionResizeSlice5_CUDA) { + auto fusion_ptr = std::make_unique(); + auto& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + auto tv1 = slice( + tv0, + {Slice(), + {IrBuilder::create(1), + sub(tv0->axis(1)->extent(), IrBuilder::create(1))}}); + auto tv2 = sum(tv1, {1}); + fusion.addOutput(tv2); + auto tv3 = slice( + tv0, + {Slice(), + {IrBuilder::create(1), + sub(tv0->axis(1)->extent(), IrBuilder::create(1))}}); + auto tv4 = sum(tv3, {1}); + fusion.addOutput(tv4); + + tv2->split(1, 128); + + // tv1 and tv3 are both slice outputs. Propagation should occur from + // tv1 to tv3 through tv0, which should work as both tensors are + // sliced in the same way. 
+ TransformPropagator propagator(tv2); + MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator); + + inlineMost(); + + tv2->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(-1)->parallelize(ParallelType::TIDx); + scheduler_utils::parallelizeAllLike(tv2); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + + std::vector shape({11, 1000}); + auto t0 = at::randn(shape, options); + std::vector aten_inputs({t0}); + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + auto t1 = t0.index( + {at::indexing::Slice(0, at::indexing::None), + at::indexing::Slice(1, shape[1] - 1)}); + auto t2 = t1.to(at::kDouble).sum({1}); + auto t3 = t0.index( + {at::indexing::Slice(0, at::indexing::None), + at::indexing::Slice(1, shape[1] - 1)}); + auto t4 = t3.to(at::kDouble).sum({1}); + + testValidate(&fusion, cg_outputs, aten_inputs, {t2, t4}, __LINE__, __FILE__); +} + +// Auto scheduled version of Slice1 +TEST_F(NVFuserTest, FusionResizeSliceScheduler1_CUDA) { + auto fusion_ptr = std::make_unique(); + auto& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + auto tv1 = slice( + tv0, + {{IrBuilder::create(1), + sub(tv0->axis(0)->extent(), IrBuilder::create(1))}}); + fusion.addOutput(tv1); + + std::vector shape({9}); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + + auto t0 = at::randn(shape, options); + std::vector aten_inputs({t0}); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); + + auto ref = t0.index({at::indexing::Slice(1, shape[0] - 1)}); + + TORCH_CHECK(ref.equal(cg_outputs[0])); +} + +TEST_F(NVFuserTest, FusionResizePadReduceScheduler1_CUDA) { + auto fusion_ptr = std::make_unique(); + auto& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + auto left_pad0 = IrBuilder::create(); + fusion.addInput(left_pad0); + auto right_pad0 = IrBuilder::create(); + fusion.addInput(right_pad0); + auto left_pad1 = IrBuilder::create(); + fusion.addInput(left_pad1); + auto right_pad1 = IrBuilder::create(); + fusion.addInput(right_pad1); + + auto tv1 = pad(tv0, {left_pad0, right_pad0, left_pad1, right_pad1}); + auto tv2 = sum(tv1, {1}); + fusion.addOutput(tv2); + + std::vector shape({123, 999}); + std::vector pad_extents{1, 2, 2, 1}; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + + auto t0 = at::randn(shape, options); + std::vector aten_inputs({t0}); + std::transform( + pad_extents.begin(), + pad_extents.end(), + std::back_inserter(aten_inputs), + [](auto pad_extent) { return pad_extent; }); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); + + auto ref = at::pad(t0, pad_extents).sum({1}); + + testValidate( + executor_cache.fusion(), + cg_outputs, + aten_inputs, + {ref}, + __LINE__, + __FILE__); +} + +TEST_F(NVFuserTest, FusionResizeSliceReduceScheduler1_CUDA) { + auto fusion_ptr = std::make_unique(); + auto& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + auto start0 = IrBuilder::create(); + fusion.addInput(start0); + auto end0 = IrBuilder::create(); + fusion.addInput(end0); + auto start1 = IrBuilder::create(); + 
fusion.addInput(start1); + auto end1 = IrBuilder::create(); + fusion.addInput(end1); + + auto tv1 = slice(tv0, {{start0, end0}, {start1, end1}}); + auto tv2 = sum(tv1, {1}); + fusion.addOutput(tv2); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + + std::vector shape({123, 999}); + std::vector slice_inputs({1, shape[0] - 2, 3, shape[1] - 4}); + + auto t0 = at::randn(shape, options); + std::vector aten_inputs({t0}); + std::copy( + slice_inputs.begin(), + slice_inputs.end(), + std::back_inserter(aten_inputs)); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); + + auto t1 = t0.index( + {at::indexing::Slice(slice_inputs[0], slice_inputs[1]), + at::indexing::Slice(slice_inputs[2], slice_inputs[3])}); + auto ref = t1.sum({1}); + + testValidate( + executor_cache.fusion(), + cg_outputs, + aten_inputs, + {ref}, + __LINE__, + __FILE__); +} + +// Multiple slice+reduction. Different slices. +TEST_F(NVFuserTest, FusionResizeSliceReduceScheduler2_CUDA) { + auto fusion_ptr = std::make_unique(); + auto& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); + + auto tv0 = makeContigTensor(2); + fusion.addInput(tv0); + + auto start0 = IrBuilder::create(); + fusion.addInput(start0); + auto end0 = IrBuilder::create(); + fusion.addInput(end0); + auto start1 = IrBuilder::create(); + fusion.addInput(start1); + auto end1 = IrBuilder::create(); + fusion.addInput(end1); + + auto tv1 = slice(tv0, {Slice(), {start0, end0}}); + auto tv2 = sum(tv1, {1}); + fusion.addOutput(tv2); + auto tv3 = slice(tv0, {Slice(), {start1, end1}}); + auto tv4 = sum(tv3, {1}); + fusion.addOutput(tv4); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + + std::vector shape({123, 1024}); + std::vector slice_inputs({1, shape[0] - 2, 3, shape[1] - 4}); + + auto t0 = at::randn(shape, options); + std::vector aten_inputs({t0}); + std::copy( + slice_inputs.begin(), + slice_inputs.end(), + std::back_inserter(aten_inputs)); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); + + auto t1 = t0.index( + {at::indexing::Slice(0, at::indexing::None), + at::indexing::Slice(slice_inputs[0], slice_inputs[1])}); + auto t2 = t1.sum({1}); + auto t3 = t0.index( + {at::indexing::Slice(0, at::indexing::None), + at::indexing::Slice(slice_inputs[2], slice_inputs[3])}); + auto t4 = t3.sum({1}); + + testValidate( + executor_cache.fusion(), + cg_outputs, + aten_inputs, + {t2, t4}, + __LINE__, + __FILE__); +} + +// Multiple slice+reduction. Same slices. Should be segmented at the moment. 
+TEST_F(NVFuserTest, FusionSliceReduceScheduler3_CUDA) { + auto fusion_ptr = std::make_unique(); + auto& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + auto start0 = IrBuilder::create(); + fusion.addInput(start0); + auto end0 = IrBuilder::create(); + fusion.addInput(end0); + + auto tv1 = slice(tv0, {Slice(), {start0, end0}}); + auto tv2 = sum(tv1, {1}); + fusion.addOutput(tv2); + auto tv3 = slice(tv0, {Slice(), {start0, end0}}); + auto tv4 = sum(tv3, {1}); + fusion.addOutput(tv4); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + + std::vector shape({123, 999}); + std::vector slice_inputs({1, shape[1] - 2}); + + auto t0 = at::randn(shape, options); + std::vector aten_inputs({t0}); + std::copy( + slice_inputs.begin(), + slice_inputs.end(), + std::back_inserter(aten_inputs)); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); + + auto t1 = t0.index( + {at::indexing::Slice(0, at::indexing::None), + at::indexing::Slice(slice_inputs[0], slice_inputs[1])}); + auto t2 = t1.to(at::kDouble).sum({1}); + auto t3 = t0.index( + {at::indexing::Slice(0, at::indexing::None), + at::indexing::Slice(slice_inputs[0], slice_inputs[1])}); + auto t4 = t3.to(at::kDouble).sum({1}); + + testValidate( + executor_cache.fusion(), + cg_outputs, + aten_inputs, + {t2, t4}, + __LINE__, + __FILE__); +} + +TEST_F(NVFuserTest, FusionResizeCatReduceScheduler1_CUDA) { + auto fusion_ptr = std::make_unique(); + auto& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + auto tv1 = makeSymbolicTensor(2); + fusion.addInput(tv1); + + auto tv2 = cat({tv0, tv1}, 1); + auto tv3 = sum(tv2, {1}); + fusion.addOutput(tv3); + + std::vector shape0({11, 12}); + std::vector shape1({shape0[0], 13}); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + + auto t0 = at::randn(shape0, options); + auto t1 = at::randn(shape1, options); + std::vector aten_inputs({t0, t1}); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); + + auto ref = at::cat({t0, t1}, 1).sum({1}); + + testValidate( + executor_cache.fusion(), + cg_outputs, + aten_inputs, + {ref}, + __LINE__, + __FILE__); +} + +TEST_F(NVFuserTest, FusionResizeCatSoftmaxScheduler1_CUDA) { + auto fusion_ptr = std::make_unique(); + auto& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + auto tv1 = makeSymbolicTensor(2); + fusion.addInput(tv1); + + auto tv2 = cat({tv0, tv1}, 1); + auto tv3 = softmax(tv2, 1); + fusion.addOutput(tv3); + + std::vector shape0({11, 99}); + std::vector shape1({shape0[0], 100}); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + + auto t0 = at::randn(shape0, options); + auto t1 = at::randn(shape1, options); + std::vector aten_inputs({t0, t1}); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); + + auto t2 = at::cat({t0, t1}, 1); + auto ref = at::_softmax(t2.to(at::kDouble), -1, false); + + testValidate( + executor_cache.fusion(), + cg_outputs, + aten_inputs, + {ref}, + __LINE__, + __FILE__); +} + +TEST_F(NVFuserTest, FusionResizeReductionSliceScheduler1_CUDA) { + auto 
fusion_ptr = std::make_unique(); + auto& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + auto tv1 = sum(tv0, {1}); + auto tv2 = slice( + tv1, + {{IrBuilder::create(1), + sub(tv1->axis(0)->extent(), IrBuilder::create(2))}}); + fusion.addOutput(tv2); + + std::vector shape0({10, 1234}); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + + auto t0 = at::randn(shape0, options); + std::vector aten_inputs({t0}); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); + + auto t1 = t0.to(at::kDouble).sum({1}); + auto t2 = t1.index({at::indexing::Slice(1, shape0[0] - 2)}); + + testValidate( + executor_cache.fusion(), + cg_outputs, + aten_inputs, + {t2}, + __LINE__, + __FILE__); +} + +// Softmax followed by slicing of a non-normalized dimension +TEST_F(NVFuserTest, FusionResizeSoftmaxSliceScheduler1_CUDA) { + auto fusion_ptr = std::make_unique(); + auto& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + auto tv1 = softmax(tv0, 1); + auto tv2 = slice( + tv1, + {{IrBuilder::create(1), + sub(tv1->axis(0)->extent(), IrBuilder::create(2))}, + Slice()}); + fusion.addOutput(tv2); + + std::vector shape0({13, 1234}); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + + auto t0 = at::randn(shape0, options); + std::vector aten_inputs({t0}); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); + + auto t1 = at::_softmax(t0.to(at::kDouble), -1, false); + auto t2 = t1.index( + {at::indexing::Slice(1, shape0[0] - 2), + at::indexing::Slice(0, at::indexing::None)}); + + testValidate( + executor_cache.fusion(), + cg_outputs, + aten_inputs, + {t2}, + __LINE__, + __FILE__); +} + +// Softmax followed by slicing of a normalized dimension +TEST_F(NVFuserTest, FusionResizeSoftmaxSliceScheduler2_CUDA) { + auto fusion_ptr = std::make_unique(); + auto& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + auto tv1 = softmax(tv0, 1); + auto tv2 = slice( + tv1, + {Slice(), + {IrBuilder::create(1), + sub(tv1->axis(1)->extent(), IrBuilder::create(2))}}); + fusion.addOutput(tv2); + + std::vector shape0({110, 12345}); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + + auto t0 = at::randn(shape0, options); + std::vector aten_inputs({t0}); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); + + auto t1 = at::_softmax(t0.to(at::kDouble), -1, false); + auto t2 = t1.index( + {at::indexing::Slice(0, at::indexing::None), + at::indexing::Slice(1, shape0[1] - 2)}); + + testValidate( + executor_cache.fusion(), + cg_outputs, + aten_inputs, + {t2}, + __LINE__, + __FILE__); +} + +} // namespace nvfuser
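As a companion to the reference computations used throughout test_resize.cpp above, here is a minimal standalone libtorch sketch of the ATen ops the fused kernels are validated against, composing pad, cat, and slice in one place (shapes arbitrary; not part of the diff):

#include <torch/torch.h>

#include <iostream>

int main() {
  auto options = at::TensorOptions().dtype(at::kFloat);
  auto t0 = at::randn({9, 11}, options);
  auto t1 = at::randn({9, 13}, options);

  // Pad the innermost dimension by one element on each side: [9, 13].
  auto padded = at::pad(t0, {1, 1});

  // Concatenate along the inner dimension: [9, 26].
  auto concatenated = at::cat({padded, t1}, 1);

  // Slice off the first and last rows via Tensor::index and
  // at::indexing::Slice, as the slice tests do: [7, 26].
  auto sliced = concatenated.index(
      {at::indexing::Slice(1, concatenated.size(0) - 1),
       at::indexing::Slice()});

  std::cout << sliced.sizes() << std::endl;
  return 0;
}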