Skip to content

Commit

Permalink
IterDomain resize for pad, cat, slice
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobhinkle committed Mar 15, 2023
1 parent 48b0cb4 commit 406d0d1
Show file tree
Hide file tree
Showing 55 changed files with 3,755 additions and 404 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,7 @@ if(BUILD_TEST)
list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_gpu_outer_reduction.cpp)
list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_loop_rotation.cpp)
list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_gpu_shift.cpp)
list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_resize.cpp)
list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_gpu_tensorcore.cpp)
list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_gpu_matmul_sass.cpp)
list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_gpu_view.cpp)
Expand Down
35 changes: 35 additions & 0 deletions csrc/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2774,6 +2774,41 @@ class CudaKernelGenerator : private OptOutConstDispatch {
indent() << "NVFUSER_UPDATE_MAGIC_ZERO\n";
}

void handle(const CatOp* cat) final {
auto out = gen(cat->output(0));

// Generate code like:
// if (consumer_idx < producer_0_extent) {
// consumer[consumer_idx] = produce_0[producer_idx0];
// } else if (consumer_idx < producer_1_extent) {
// consumer[consumer_idx] = produce_1[producer_idx1];
// } else if (consumer_idx < producer_2_extent) {
// consumer[consumer_idx] = produce_2[producer_idx2];
// } else {
// consumer[consumer_idx] = produce_3[producer_idx3];
// }

for (const auto i : c10::irange(cat->inputs().size())) {
auto inp = cat->input(i)->as<kir::TensorIndex>();
auto inp_str = gen(inp);
if (i < cat->inputs().size() - 1) {
if (i == 0) {
indent() << "if (";
} else {
indent() << "} else if (";
}
code_ << gen(cat->getPred(i)) << ") {\n";
} else {
// last case doesn't need to be predicated
indent() << "} else {\n";
}

indent() << kTab << out << " = " << gen(inp) << ";\n";
}

indent() << "}\n";
}

private:
std::stringstream code_;
const kir::Kernel* kernel_;
Expand Down
78 changes: 70 additions & 8 deletions csrc/compute_at_map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ bool IterDomainGraph::exprsMap(
}

TORCH_INTERNAL_ASSERT(
first->isA<Merge>() || first->isA<Split>(),
"Merge and split are the only expressions supported through rfactor operations in compute at map, but found:\n",
first->isA<Merge>() || first->isA<Split>() || first->isA<Resize>(),
"Merge, split and resize are the only expressions supported through rfactor operations in compute at map, but found:\n",
first->toString());

auto first_ids = ir_utils::filterByType<IterDomain>(
Expand Down Expand Up @@ -176,6 +176,15 @@ bool IterDomainGraph::exprsMap(
}
}

if (first->isA<Resize>()) {
auto first_resize = first->as<Resize>();
auto second_resize = second->as<Resize>();
if (!first_resize->leftExpand()->sameAs(second_resize->leftExpand()) ||
!first_resize->rightExpand()->sameAs(second_resize->rightExpand())) {
return false;
}
}

return true;
}

Expand Down Expand Up @@ -211,6 +220,7 @@ void IterDomainGraph::mapThroughExpr(Expr* first, Expr* second, bool forward) {
for (auto out_i : c10::irange(first_ids.size())) {
exact_nodes_.mapEntries(first_ids[out_i], second_ids[out_i]);
permissive_nodes_.mapEntries(first_ids[out_i], second_ids[out_i]);
permissive_resize_nodes_.mapEntries(first_ids[out_i], second_ids[out_i]);
}
}

Expand Down Expand Up @@ -407,6 +417,7 @@ void IterDomainGraph::build(Fusion* fusion) {
auto id0 = *disjoint_set->begin();
for (auto id1 : disjoint_set->vector()) {
permissive_nodes_.mapEntries(id0, id1);
permissive_resize_nodes_.mapEntries(id0, id1);
exact_nodes_.mapEntries(id0, id1);
sibling_sets_.mapEntries(id0, id1);
}
Expand All @@ -430,8 +441,22 @@ void IterDomainGraph::build(Fusion* fusion) {
// Look for matching ID transformations in producer and consumer, replay
// producer as consumer. We use the symmetric API of BestEffortReplay so
// that both broadcast and squeeze are handled correctly.
//
// Note on the boolean flags: swizzles are skipped in both
// producer and consumer but resizes are not.
const auto permissive_disjoint_sets =
BestEffortReplay::replayPasC(p_tv, c_tv, -1, pairwise_map)
BestEffortReplay::replayPasC(
p_tv, c_tv, -1, pairwise_map, true, true, false)
.getIterDomainEquivalence();

// Permissive-Resize map allows mappings of resize inputs and
// outputs
//
// Note on the boolean flags: swizzles and resizes are skipped
// in the permissive-resize map
const auto permissive_resize_disjoint_sets =
BestEffortReplay::replayPasC(
p_tv, c_tv, -1, pairwise_map, true, true, true)
.getIterDomainEquivalence();

// For exact mapings do not map any broadcast dimensions to
Expand Down Expand Up @@ -483,16 +508,12 @@ void IterDomainGraph::build(Fusion* fusion) {
for (auto j : c10::irange(i + 1, vec.size())) {
auto id2 = vec[j];
if (p_ids.count(id1) && c_ids.count(id2)) {
consumers_.at(id1).pushBack(id2);
producers_.at(id2).pushBack(id1);
if (idIsAComputeAtLeafDomain(id1, p_tv, c_tv) &&
idIsALeafDomain(id2, c_tv)) {
loop_nodes_.mapEntries(id1, id2);
}
}
if (c_ids.count(id1) && p_ids.count(id2)) {
producers_.at(id1).pushBack(id2);
consumers_.at(id2).pushBack(id1);
if (idIsAComputeAtLeafDomain(id2, p_tv, c_tv) &&
idIsALeafDomain(id1, c_tv)) {
loop_nodes_.mapEntries(id1, id2);
Expand All @@ -501,6 +522,31 @@ void IterDomainGraph::build(Fusion* fusion) {
}
}
}

// Mostly the same as the above for the permissive map but
// nothing to do for the loop map.
// The producer and consumer maps are based on the most
// permissive mappings, so they are set using the
// permissive-resize mappings.
for (auto& dset : permissive_resize_disjoint_sets.disjointSets()) {
auto& vec = dset->vector();
for (auto i : c10::irange(vec.size())) {
auto id1 = vec[i];
permissive_resize_nodes_.mapEntries(id1, vec[0]);
mapMaybeSwizzleOp(permissive_resize_nodes_, id1);
for (auto j : c10::irange(i + 1, vec.size())) {
auto id2 = vec[j];
if (p_ids.count(id1) && c_ids.count(id2)) {
consumers_.at(id1).pushBack(id2);
producers_.at(id2).pushBack(id1);
}
if (c_ids.count(id1) && p_ids.count(id2)) {
producers_.at(id1).pushBack(id2);
consumers_.at(id2).pushBack(id1);
}
}
}
}
}
}
}
Expand Down Expand Up @@ -561,7 +607,7 @@ void IterDomainGraph::build(Fusion* fusion) {
for (auto expr : exprs) {
auto rfactor_inp_ids = ir_utils::filterByType<IterDomain>(expr->inputs());
TORCH_INTERNAL_ASSERT(
expr->isA<Split>() || expr->isA<Merge>(),
expr->isA<Split>() || expr->isA<Merge>() || expr->isA<Resize>(),
"Wasn't expecting the expression type of:\n",
expr->toString(),
"\nto be an expression defined in an rfactor transformation.");
Expand Down Expand Up @@ -688,6 +734,7 @@ void IterDomainGraph::initializeId(
bool is_rfactor_id,
bool is_leaf_id) {
permissive_nodes_.initializeSet(id);
permissive_resize_nodes_.initializeSet(id);
exact_nodes_.initializeSet(id);
if (is_leaf_id) {
loop_nodes_.initializeSet(id);
Expand Down Expand Up @@ -1127,6 +1174,17 @@ void ComputeAtMap::buildConcreteIds() {
auto concrete_id = computeConcreteId(first_id, IdMappingMode::LOOP);
concrete_id_cache_[disjoint_set_shared_ptr] = concrete_id;
}

for (const auto& disjoint_set_shared_ptr :
id_graph_.permissiveResizeNodes().disjointSets()) {
TORCH_INTERNAL_ASSERT(
disjoint_set_shared_ptr->vector().size(),
"Cannot compute concrete id of empty set.");
auto first_id = disjoint_set_shared_ptr->vector().front();
auto concrete_id =
computeConcreteId(first_id, IdMappingMode::PERMISSIVE_RESIZE);
concrete_id_cache_[disjoint_set_shared_ptr] = concrete_id;
}
}

bool ComputeAtMap::areExactExprs(Expr* expr_1, Expr* expr_2) {
Expand Down Expand Up @@ -1349,6 +1407,8 @@ std::string ComputeAtMap::toString() const {
ss << "Loop map:\n" << idGraphNodesToString(*this, IdMappingMode::LOOP);
ss << "Permissive map:\n"
<< idGraphNodesToString(*this, IdMappingMode::PERMISSIVE);
ss << "Permissive-Resize map:\n"
<< idGraphNodesToString(*this, IdMappingMode::PERMISSIVE_RESIZE);
ss << "Consumer maps:\n";
for (auto key : getSortedKeys(id_graph_.consumers(), Statement::lessThan)) {
auto consumers = id_graph_.consumers().at(key);
Expand Down Expand Up @@ -1408,6 +1468,8 @@ const DisjointSets<IterDomain*>& ComputeAtMap::getIdSets(
return id_graph_.loopNodes();
case IdMappingMode::PERMISSIVE:
return id_graph_.permissiveNodes();
case IdMappingMode::PERMISSIVE_RESIZE:
return id_graph_.permissiveResizeNodes();
}
TORCH_INTERNAL_ASSERT(false, "Error with mapping mode provided.");
}
Expand Down
12 changes: 11 additions & 1 deletion csrc/compute_at_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ namespace nvfuser {
// Map all iteration domains
// Always contain root mappings (otherwise they could have been forwarded in
// broadcast)
// IdMappingMode::PERMISSIVE_RESIZE
// Include everything in PERMISSIVE. Map also domains that are
// inputs and outputs of resize ops. Used for, e.g., propagating
// parallel types across those domains.
// IdMappingMode::EXACT
// Don't map any broadcast axes to non-broadcast axes
// Do not forward through any broadcast IDs
Expand All @@ -79,6 +83,9 @@ class TORCH_CUDA_CU_API IterDomainGraph {
const DisjointSets<IterDomain*>& loopNodes() const {
return loop_nodes_;
}
const DisjointSets<IterDomain*>& permissiveResizeNodes() const {
return permissive_resize_nodes_;
}

// Consumers and producers is not symmetric like the other sets
const std::unordered_map<IterDomain*, VectorOfUniqueEntries<IterDomain*>>&
Expand Down Expand Up @@ -132,8 +139,11 @@ class TORCH_CUDA_CU_API IterDomainGraph {
DisjointSets<IterDomain*> exact_nodes_;
DisjointSets<IterDomain*> almost_exact_nodes_;
DisjointSets<IterDomain*> loop_nodes_;
DisjointSets<IterDomain*> permissive_resize_nodes_;

// Consumers and producers is not symmetric like the other sets
// Consumers and producers is not symmetric like the other sets.
// Mapping is based on the most permissive map, i.e., the
// permissive-resize map.
std::unordered_map<IterDomain*, VectorOfUniqueEntries<IterDomain*>>
consumers_;
std::unordered_map<IterDomain*, VectorOfUniqueEntries<IterDomain*>>
Expand Down
61 changes: 61 additions & 0 deletions csrc/contiguity.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,48 @@ void OrderedIdInformation::handle(Swizzle2D* swizzle) {
}
}

void OrderedIdInformation::handle(Resize* resize) {
// Find inputs in the active_ids_ vector
const auto in_it =
std::find(active_ids_.begin(), active_ids_.end(), resize->in());

if (in_it == active_ids_.end()) {
return;
}

auto in_pos = std::distance(active_ids_.begin(), in_it);

// Find inputs in the ordered transforms map
const auto in_ordered_it = consistently_ordered_ids_.find(resize->in());

bool in_ordered = in_ordered_it != consistently_ordered_ids_.end();

// Get root ids of the two inputs
const auto in_root_ids_it = id_to_root_ids_.find(resize->in());

TORCH_INTERNAL_ASSERT(
in_root_ids_it != id_to_root_ids_.end(),
"Error replaying transforms in contiguous ID checker.");

const auto& in_root_ids = in_root_ids_it->second;

// Update map for outputs
// Remove inputs from the active_ids_ and insert the output ID
active_ids_[in_pos] = resize->out();

// Not completely certain, but propagating these properties should e
// fine
if (in_ordered) {
consistently_ordered_ids_.emplace(resize->out());
}

if (exclusivelyConsumesRoots(resize->in())) {
exclusively_consumes_roots_.emplace(resize->out());
}

id_to_root_ids_[resize->out()] = in_root_ids;
}

NonDivisibleSplitDependencies::NonDivisibleSplitDependencies(
// TODO: Revisit reduction rfactor axes and propagation. Should probably use
// ca_map to propogate non divisibility dependencies across exact map. Still
Expand Down Expand Up @@ -500,6 +542,19 @@ void ContigIDs::build(const std::vector<IterDomain*>& ids) {
{root_domain_.begin(), root_domain_.end()},
{ids.begin(), ids.end()});
for (auto expr : exprs) {
if (auto resize = dynamic_cast<Resize*>(expr)) {
resize_deps_.insert(resize->out());
} else {
if (std::any_of(
expr->inputs().begin(), expr->inputs().end(), [&](Val* inp) {
return inp->isA<IterDomain>() &&
resize_deps_.count(inp->as<IterDomain>());
})) {
for (auto out : ir_utils::filterByType<IterDomain>(expr->outputs())) {
resize_deps_.insert(out);
}
}
}
handle(expr);
}
}
Expand Down Expand Up @@ -576,6 +631,12 @@ void ContigIDs::handle(Merge* merge) {
return;
}

// Don't allow contig indexing after resize as we need traverse back
// at least to direct outputs of resize ops
if (resize_deps_.count(merge->out())) {
return;
}

// All broadcasting
if (last_root == nullptr) {
return;
Expand Down
7 changes: 7 additions & 0 deletions csrc/contiguity.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ class OrderedIdInformation : public OptInDispatch {

void handle(Swizzle2D* swizzle) override;

void handle(Resize* resize) override;

// Track which root ids were used to generate each iter domain
std::unordered_map<IterDomain*, VectorOfUniqueEntries<IterDomain*>>
id_to_root_ids_;
Expand Down Expand Up @@ -255,6 +257,8 @@ class ContigIDs : public OptInDispatch {
// cases, depending on specific swizzle type and axes.
void handle(Swizzle2D* swizzle) override {}

void handle(Resize* resize) override {}

IterDomain* getCAIndexConcreteId(IterDomain* id) const;

//! True if an ID is indexable.
Expand Down Expand Up @@ -307,6 +311,9 @@ class ContigIDs : public OptInDispatch {
std::unique_ptr<const OrderedIdInformation> consistent_transform_info_;

NonDivisibleSplitDependencies non_divisible_id_info_;

//! IDs that depend on resize output IDs
std::unordered_set<IterDomain*> resize_deps_;
};

} // namespace nvfuser
Loading

0 comments on commit 406d0d1

Please sign in to comment.