Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IterDomain resize for pad, cat, slice #2480

Merged
merged 93 commits into from
Mar 14, 2023
Merged
Show file tree
Hide file tree
Changes from 81 commits
Commits
Show all changes
93 commits
Select commit Hold shift + click to select a range
a212e8d
WIP
naoyam Feb 10, 2023
8d66afa
WIP: expand
naoyam Feb 11, 2023
e5ee102
WIP
naoyam Feb 11, 2023
4b707e8
WIP
naoyam Feb 11, 2023
66e5e38
v1
naoyam Feb 11, 2023
b458177
wip
naoyam Feb 11, 2023
179679c
tests
naoyam Feb 11, 2023
c27900d
WIP: PadOp
naoyam Feb 12, 2023
ba6d5ef
checkpoint: decided to try rfactor
naoyam Feb 12, 2023
b826fbc
pad test
naoyam Feb 13, 2023
8267442
computeAt
naoyam Feb 13, 2023
395939d
parallelize
naoyam Feb 13, 2023
a7f42d2
concat
naoyam Feb 13, 2023
c748580
cat tests
naoyam Feb 14, 2023
f823b26
WIP: slice
naoyam Feb 14, 2023
0251bad
bug fix
naoyam Feb 14, 2023
5aed8ec
cleanup
naoyam Feb 15, 2023
c67927e
Merge branch 'devel' into catop
naoyam Feb 15, 2023
49195d4
Merge branch 'devel' into catop
naoyam Feb 15, 2023
c7f2118
Renamed expand to resize
naoyam Feb 15, 2023
ecb72a4
Merge branch 'devel' into catop
naoyam Feb 16, 2023
0f74cc8
cleanup
naoyam Feb 16, 2023
71a99d3
Merge branch 'devel' into catop
naoyam Feb 16, 2023
1864ec9
cleanup
naoyam Feb 16, 2023
32e7d39
cleanup
naoyam Feb 16, 2023
36f31e0
remove TensorView::expand
naoyam Feb 16, 2023
4b5e3da
cleanup
naoyam Feb 16, 2023
e23b0c9
fix
naoyam Feb 16, 2023
071f7d2
cleanup
naoyam Feb 16, 2023
03f3194
cleanup
naoyam Feb 16, 2023
13bb33a
cleanup
naoyam Feb 17, 2023
4374489
cleanup
naoyam Feb 17, 2023
5164c41
update
naoyam Feb 17, 2023
b056a86
cleanup
naoyam Feb 17, 2023
1871d7f
fix
naoyam Feb 17, 2023
27204dd
Merge branch 'devel' into iter_domain_resize
naoyam Feb 18, 2023
b4f24a4
test cleanup
naoyam Feb 18, 2023
bf21fb2
cleanup
naoyam Feb 18, 2023
bdc01ab
unswitch fix
naoyam Feb 22, 2023
d32a05f
updates for resize
naoyam Feb 22, 2023
dc3a4ee
more tests
naoyam Feb 22, 2023
497f330
add scheduler tests for cat
naoyam Feb 22, 2023
1faae1e
Don't propagate resize if not safe
naoyam Feb 22, 2023
0238749
update
naoyam Feb 23, 2023
c792e41
WIP: scheduler support
naoyam Feb 24, 2023
00f4a88
cleanup
naoyam Feb 24, 2023
66d75b3
fix
naoyam Feb 24, 2023
bf45b91
fix
naoyam Feb 24, 2023
594f50d
Softmax + slice working
naoyam Feb 25, 2023
aa4c356
cleanup, should be more robust
naoyam Feb 25, 2023
b036389
cleanup
naoyam Feb 25, 2023
6701006
All tests pass
naoyam Feb 25, 2023
bcb3c8f
fix
naoyam Feb 25, 2023
79e035c
cleanup
naoyam Feb 25, 2023
c492d57
cleanup
naoyam Feb 25, 2023
9c2f350
fix
naoyam Feb 25, 2023
2b8bb4d
cleanup
naoyam Feb 25, 2023
01df3c8
Merge branch 'devel' into iter_domain_resize
naoyam Feb 25, 2023
d8534ca
cleanup
naoyam Feb 28, 2023
460ad43
cleanup
naoyam Feb 28, 2023
4a31361
cleanup
naoyam Feb 28, 2023
c79a4cd
Merge branch 'devel' into iter_domain_resize
naoyam Feb 28, 2023
1665acd
cleanup
naoyam Feb 28, 2023
152c01a
cleanup
naoyam Feb 28, 2023
5e558e9
Merge branch 'devel' into iter_domain_resize
naoyam Mar 2, 2023
b476743
cleanup
naoyam Mar 2, 2023
64c5bd1
PR feedback
naoyam Mar 7, 2023
2a2e189
Mark pad widths of PadOp as inputs
naoyam Mar 7, 2023
30712e7
cleanup
naoyam Mar 7, 2023
4ad0b9d
comments
naoyam Mar 7, 2023
01722c7
cleanpu
naoyam Mar 7, 2023
c4d6b39
Rename "disjoint view" to "disjoint rfactor"
naoyam Mar 7, 2023
27c867f
Set the producer and consumer maps of the ID graph using the
naoyam Mar 7, 2023
622ee93
Do not create index maps for non-indexable domains
naoyam Mar 7, 2023
c4aab87
cleanup
naoyam Mar 7, 2023
b0ffb91
fix dependency analysis
naoyam Mar 7, 2023
da7feb1
Naming fix
naoyam Mar 7, 2023
8ad222c
Use pad instead of manually creating PadOp
naoyam Mar 7, 2023
1b5c23c
Remove unused func
naoyam Mar 7, 2023
6901da3
remove unnecessary func
naoyam Mar 7, 2023
2264406
PR feedback
naoyam Mar 8, 2023
a73ff45
Clean up pad_width vector
naoyam Mar 8, 2023
b3cffb2
Convert pad to set when valid
naoyam Mar 8, 2023
26061eb
More descriptive error messages
naoyam Mar 8, 2023
232e701
Merge branch 'devel' into iter_domain_resize
naoyam Mar 8, 2023
f187d6d
Missing override implementation
naoyam Mar 8, 2023
c1d2be5
test fix
naoyam Mar 8, 2023
3ed4797
Add validation of resize usage
naoyam Mar 8, 2023
82845fa
rename test file
naoyam Mar 8, 2023
1e2c711
fix
naoyam Mar 13, 2023
e366272
Merge branch 'devel' into iter_domain_resize
naoyam Mar 13, 2023
8c1e228
merge fix
naoyam Mar 13, 2023
39e603e
test fix
naoyam Mar 13, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions third_party/nvfuser/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,7 @@ if(BUILD_TEST)
list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_gpu_outer_reduction.cpp)
list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_gpu_loop_rotation.cpp)
list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_gpu_shift.cpp)
list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_gpu_resize.cpp)
list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_gpu_tensorcore.cpp)
list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_gpu_matmul_sass.cpp)
list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_gpu_view.cpp)
Expand Down
36 changes: 36 additions & 0 deletions third_party/nvfuser/csrc/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2770,6 +2770,42 @@ class CudaKernelGenerator : private OptOutConstDispatch {
indent() << "NVFUSER_UPDATE_MAGIC_ZERO\n";
}

void handle(const CatOp* cat) final {
  // Emits a cascaded if-else statement that copies each concatenation
  // input into the region of the output it occupies along the
  // concatenated domain.
  //
  // Generate code like:
  //  if (consumer_idx < producer_0_extent) {
  //    consumer[consumer_idx] = producer_0[producer_idx0];
  //  } else if (consumer_idx < producer_1_extent) {
  //    consumer[consumer_idx] = producer_1[producer_idx1];
  //  } else if (consumer_idx < producer_2_extent) {
  //    consumer[consumer_idx] = producer_2[producer_idx2];
  //  } else {
  //    consumer[consumer_idx] = producer_3[producer_idx3];
  //  }

  auto out = gen(cat->output(0));
  // NOTE(review): cat_idx is not referenced below. The call to gen()
  // is kept in case it has side effects needed for index generation —
  // confirm whether it can be removed.
  auto cat_idx = gen(cat->getConcatenatedDomainIndex());

  for (const auto i : c10::irange(cat->inputs().size())) {
    auto inp = cat->input(i)->as<kir::TensorIndex>();
    // Generate the input expression once and reuse it in the
    // assignment below (previously gen(inp) was redundantly called a
    // second time and this string was left unused).
    auto inp_str = gen(inp);
    if (i < cat->inputs().size() - 1) {
      // All but the last input are guarded by a predicate on the
      // concatenated domain index.
      if (i == 0) {
        indent() << "if (";
      } else {
        indent() << "} else if (";
      }
      code_ << gen(cat->getPred(i)) << ") {\n";
    } else {
      // last case doesn't need to be predicated
      indent() << "} else {\n";
    }

    indent() << kTab << out << " = " << inp_str << ";\n";
  }

  indent() << "}\n";
}

private:
std::stringstream code_;
const kir::Kernel* kernel_;
Expand Down
78 changes: 70 additions & 8 deletions third_party/nvfuser/csrc/compute_at_map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@ bool IterDomainGraph::exprsMap(
}

TORCH_INTERNAL_ASSERT(
first->isA<Merge>() || first->isA<Split>(),
"Merge and split are the only expressions supported through rfactor operations in compute at map, but found:\n",
first->isA<Merge>() || first->isA<Split>() || first->isA<Resize>(),
"Merge, split and resize are the only expressions supported through rfactor operations in compute at map, but found:\n",
first->toString());

auto first_ids = ir_utils::filterByType<IterDomain>(
Expand Down Expand Up @@ -169,6 +169,15 @@ bool IterDomainGraph::exprsMap(
}
}

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Am I understanding correct that, if I have I1 and I3 exact mapped, and

producer: I2 = resize(I1, l1, r1)
consumer: I4 = resize(I3, l2, r2)

If l1->sameAs(l2) && r1->sameAs(r2), then

exact map, permissive map:

{I1, I3}
{I2, I4}

permissive resize map:

{I1, I2, I3, I4}

And if !(l1->sameAs(l2) && r1->sameAs(r2)), then

exact map, permissive map:

{I1, I3}
{I2}
{I4}

permissive resize map:

{I1, I2, I3, I4}

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure what you mean by producer: I2 and consumer: I4, but assuming they are just produced from I1 and I3, respectively, yes, that is correct.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I meant I have

Tconsumer = pad(Tproducer)

where

Tproducer:
root domain: I1
leaf domain: I2
transformations: I2 = resize(I1, l1, r1)

Tconsumer
root domain: I3
rfactor domain: I4
leaf domain: I4
transformations: I4 = resize(I3, l2, r2)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Assuming transformations: I2 = resize(I1, l1, r1) is not done as part of the rfactor transformation of Tproducer, it only happens when Tproducer is transformed as Tconsumer for producer indexing. In that case,while I2 and I4 should map, we don't update the ComputeAt map, so it won't affect how any of the exact/permissive/etc maps.

if (first->isA<Resize>()) {
auto first_resize = first->as<Resize>();
auto second_resize = second->as<Resize>();
if (!first_resize->leftExpand()->sameAs(second_resize->leftExpand()) ||
!first_resize->rightExpand()->sameAs(second_resize->rightExpand())) {
return false;
}
}

return true;
}

Expand Down Expand Up @@ -204,6 +213,7 @@ void IterDomainGraph::mapThroughExpr(Expr* first, Expr* second, bool forward) {
for (auto out_i : c10::irange(first_ids.size())) {
exact_nodes_.mapEntries(first_ids[out_i], second_ids[out_i]);
permissive_nodes_.mapEntries(first_ids[out_i], second_ids[out_i]);
permissive_resize_nodes_.mapEntries(first_ids[out_i], second_ids[out_i]);
}
}

Expand Down Expand Up @@ -409,6 +419,7 @@ void IterDomainGraph::build(Fusion* fusion) {
auto id0 = *disjoint_set->begin();
for (auto id1 : disjoint_set->vector()) {
permissive_nodes_.mapEntries(id0, id1);
permissive_resize_nodes_.mapEntries(id0, id1);
exact_nodes_.mapEntries(id0, id1);
sibling_sets_.mapEntries(id0, id1);
}
Expand All @@ -432,8 +443,22 @@ void IterDomainGraph::build(Fusion* fusion) {
// Look for matching ID transformations in producer and consumer, replay
// producer as consumer. We use the symmetric API of BestEffortReplay so
// that both broadcast and squeeze are handled correctly.
//
// Note on the boolean flags: swizzles are skipped in both
// producer and consumer but resizes are not.
const auto permissive_disjoint_sets =
BestEffortReplay::replayPasC(p_tv, c_tv, -1, pairwise_map)
BestEffortReplay::replayPasC(
p_tv, c_tv, -1, pairwise_map, true, true, false)
naoyam marked this conversation as resolved.
Show resolved Hide resolved
.getIterDomainEquivalence();

// Permissive-Resize map allows mappings of resize inputs and
// outputs
//
// Note on the boolean flags: swizzles and resizes are skipped
// in the permissive-resize map
const auto permissive_resize_disjoint_sets =
BestEffortReplay::replayPasC(
p_tv, c_tv, -1, pairwise_map, true, true, true)
naoyam marked this conversation as resolved.
Show resolved Hide resolved
.getIterDomainEquivalence();

// For exact mapings do not map any broadcast dimensions to
Expand Down Expand Up @@ -485,16 +510,12 @@ void IterDomainGraph::build(Fusion* fusion) {
for (auto j : c10::irange(i + 1, vec.size())) {
auto id2 = vec[j];
if (p_ids.count(id1) && c_ids.count(id2)) {
consumers_.at(id1).pushBack(id2);
producers_.at(id2).pushBack(id1);
if (idIsAComputeAtLeafDomain(id1, p_tv, c_tv) &&
idIsALeafDomain(id2, c_tv)) {
loop_nodes_.mapEntries(id1, id2);
}
}
if (c_ids.count(id1) && p_ids.count(id2)) {
producers_.at(id1).pushBack(id2);
consumers_.at(id2).pushBack(id1);
if (idIsAComputeAtLeafDomain(id2, p_tv, c_tv) &&
idIsALeafDomain(id1, c_tv)) {
loop_nodes_.mapEntries(id1, id2);
Expand All @@ -503,6 +524,31 @@ void IterDomainGraph::build(Fusion* fusion) {
}
}
}

// Mostly the same as the above for the permissive map but
// nothing to do for the loop map.
// The producer and consumer maps are based on the most
// permissive mappings, so they are set using the
// permissive-resize mappings.
for (auto& dset : permissive_resize_disjoint_sets.disjointSets()) {
auto& vec = dset->vector();
for (auto i : c10::irange(vec.size())) {
zasdfgbnm marked this conversation as resolved.
Show resolved Hide resolved
auto id1 = vec[i];
permissive_resize_nodes_.mapEntries(id1, vec[0]);
mapMaybeSwizzleOp(permissive_resize_nodes_, id1);
for (auto j : c10::irange(i + 1, vec.size())) {
auto id2 = vec[j];
if (p_ids.count(id1) && c_ids.count(id2)) {
consumers_.at(id1).pushBack(id2);
producers_.at(id2).pushBack(id1);
}
if (c_ids.count(id1) && p_ids.count(id2)) {
producers_.at(id1).pushBack(id2);
consumers_.at(id2).pushBack(id1);
}
}
}
}
}
}
}
Expand Down Expand Up @@ -563,7 +609,7 @@ void IterDomainGraph::build(Fusion* fusion) {
for (auto expr : exprs) {
auto rfactor_inp_ids = ir_utils::filterByType<IterDomain>(expr->inputs());
TORCH_INTERNAL_ASSERT(
expr->isA<Split>() || expr->isA<Merge>(),
expr->isA<Split>() || expr->isA<Merge>() || expr->isA<Resize>(),
"Wasn't expecting the expression type of:\n",
expr->toString(),
"\nto be an expression defined in an rfactor transformation.");
Expand Down Expand Up @@ -690,6 +736,7 @@ void IterDomainGraph::initializeId(
bool is_view_rfactor_id,
bool is_leaf_id) {
permissive_nodes_.initializeSet(id);
permissive_resize_nodes_.initializeSet(id);
exact_nodes_.initializeSet(id);
if (is_leaf_id) {
loop_nodes_.initializeSet(id);
Expand Down Expand Up @@ -1129,6 +1176,17 @@ void ComputeAtMap::buildConcreteIds() {
auto concrete_id = computeConcreteId(first_id, IdMappingMode::LOOP);
concrete_id_cache_[disjoint_set_shared_ptr] = concrete_id;
}

for (const auto& disjoint_set_shared_ptr :
id_graph_.permissiveResizeNodes().disjointSets()) {
TORCH_INTERNAL_ASSERT(
disjoint_set_shared_ptr->vector().size(),
"Cannot compute concrete id of empty set.");
auto first_id = disjoint_set_shared_ptr->vector().front();
auto concrete_id =
computeConcreteId(first_id, IdMappingMode::PERMISSIVE_RESIZE);
concrete_id_cache_[disjoint_set_shared_ptr] = concrete_id;
}
}

bool ComputeAtMap::areExactExprs(Expr* expr_1, Expr* expr_2) {
Expand Down Expand Up @@ -1351,6 +1409,8 @@ std::string ComputeAtMap::toString() const {
ss << "Loop map:\n" << idGraphNodesToString(*this, IdMappingMode::LOOP);
ss << "Permissive map:\n"
<< idGraphNodesToString(*this, IdMappingMode::PERMISSIVE);
ss << "Permissive-Resize map:\n"
<< idGraphNodesToString(*this, IdMappingMode::PERMISSIVE_RESIZE);
ss << "Consumer maps:\n";
for (auto key : getSortedKeys(id_graph_.consumers(), Statement::lessThan)) {
auto consumers = id_graph_.consumers().at(key);
Expand Down Expand Up @@ -1411,6 +1471,8 @@ const DisjointSets<IterDomain*>& ComputeAtMap::getIdSets(
return id_graph_.loopNodes();
case IdMappingMode::PERMISSIVE:
return id_graph_.permissiveNodes();
case IdMappingMode::PERMISSIVE_RESIZE:
return id_graph_.permissiveResizeNodes();
}
TORCH_INTERNAL_ASSERT(false, "Error with mapping mode provided.");
}
Expand Down
12 changes: 11 additions & 1 deletion third_party/nvfuser/csrc/compute_at_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ namespace nvfuser {
// Map all iteration domains
// Always contain root mappings (otherwise they could have been forwarded in
// broadcast)
// IdMappingMode::PERMISSIVE_RESIZE
zasdfgbnm marked this conversation as resolved.
Show resolved Hide resolved
// Include everything in PERMISSIVE. Map also domains that are
// inputs and outputs of resize ops. Used for, e.g., propagating
// parallel types across those domains.
// IdMappingMode::EXACT
// Don't map any broadcast axes to non-broadcast axes
// Do not forward through any broadcast IDs
Expand All @@ -72,6 +76,9 @@ class TORCH_CUDA_CU_API IterDomainGraph {
const DisjointSets<IterDomain*>& loopNodes() const {
return loop_nodes_;
}
const DisjointSets<IterDomain*>& permissiveResizeNodes() const {
return permissive_resize_nodes_;
}

// Consumers and producers is not symmetric like the other sets
const std::unordered_map<IterDomain*, VectorOfUniqueEntries<IterDomain*>>&
Expand Down Expand Up @@ -125,8 +132,11 @@ class TORCH_CUDA_CU_API IterDomainGraph {
DisjointSets<IterDomain*> exact_nodes_;
DisjointSets<IterDomain*> almost_exact_nodes_;
DisjointSets<IterDomain*> loop_nodes_;
DisjointSets<IterDomain*> permissive_resize_nodes_;

// Consumers and producers is not symmetric like the other sets
// Consumers and producers is not symmetric like the other sets.
// Mapping is based on the most permissive map, i.e., the
// permissive-resize map.
std::unordered_map<IterDomain*, VectorOfUniqueEntries<IterDomain*>>
consumers_;
std::unordered_map<IterDomain*, VectorOfUniqueEntries<IterDomain*>>
Expand Down
61 changes: 61 additions & 0 deletions third_party/nvfuser/csrc/contiguity.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,48 @@ void OrderedIdInformation::handle(Swizzle2D* swizzle) {
}
}

void OrderedIdInformation::handle(Resize* resize) {
  // A resize is a unary ID transformation: the input ID is replaced
  // in-place by the output ID among the active IDs, and the tracked
  // ordering/root-ID properties of the input are propagated to the
  // output.
  auto* resize_in = resize->in();
  auto* resize_out = resize->out();

  // Locate the input among the currently active IDs. If it is not
  // active, this resize does not affect the tracked domains.
  const auto active_it =
      std::find(active_ids_.begin(), active_ids_.end(), resize_in);
  if (active_it == active_ids_.end()) {
    return;
  }
  const auto active_pos = std::distance(active_ids_.begin(), active_it);

  // Whether the input is known to be consistently ordered
  const bool input_is_ordered = consistently_ordered_ids_.find(resize_in) !=
      consistently_ordered_ids_.end();

  // Root IDs the input was derived from; every active ID must have an
  // entry.
  const auto root_ids_it = id_to_root_ids_.find(resize_in);
  TORCH_INTERNAL_ASSERT(
      root_ids_it != id_to_root_ids_.end(),
      "Error replaying transforms in contiguous ID checker.");
  const auto& input_root_ids = root_ids_it->second;

  // Swap the output in for the input among the active IDs
  active_ids_[active_pos] = resize_out;

  // Not completely certain, but propagating these properties should be
  // fine
  if (input_is_ordered) {
    consistently_ordered_ids_.emplace(resize_out);
  }

  if (exclusivelyConsumesRoots(resize_in)) {
    exclusively_consumes_roots_.emplace(resize_out);
  }

  id_to_root_ids_[resize_out] = input_root_ids;
}

NonDivisibleSplitDependencies::NonDivisibleSplitDependencies(
// TODO: Revisit reduction rfactor axes and propagation. Should probably use
// ca_map to propogate non divisibility dependencies across exact map. Still
Expand Down Expand Up @@ -489,6 +531,19 @@ void ContigIDs::build(const std::vector<IterDomain*>& ids) {
{root_domain_.begin(), root_domain_.end()},
{ids.begin(), ids.end()});
for (auto expr : exprs) {
if (auto resize = dynamic_cast<Resize*>(expr)) {
resize_deps_.insert(resize->out());
} else {
if (std::any_of(
expr->inputs().begin(), expr->inputs().end(), [&](Val* inp) {
return inp->isA<IterDomain>() &&
resize_deps_.count(inp->as<IterDomain>());
})) {
for (auto out : ir_utils::filterByType<IterDomain>(expr->outputs())) {
resize_deps_.insert(out);
}
}
}
handle(expr);
}
}
Expand Down Expand Up @@ -561,6 +616,12 @@ void ContigIDs::handle(Merge* merge) {
return;
}

// Don't allow contig indexing after resize as we need traverse back
// at least to direct outputs of resize ops
if (resize_deps_.count(merge->out())) {
return;
}

// Now we know merge->out is a contiguously indexable ID

TORCH_INTERNAL_ASSERT(
Expand Down
7 changes: 7 additions & 0 deletions third_party/nvfuser/csrc/contiguity.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ class OrderedIdInformation : public OptInDispatch {

void handle(Swizzle2D* swizzle) override;

void handle(Resize* resize) override;

// Track which root ids were used to generate each iter domain
std::unordered_map<IterDomain*, VectorOfUniqueEntries<IterDomain*>>
id_to_root_ids_;
Expand Down Expand Up @@ -248,6 +250,8 @@ class ContigIDs : public OptInDispatch {
// cases, depending on specific swizzle type and axes.
void handle(Swizzle2D* swizzle) override {}

void handle(Resize* resize) override {}

IterDomain* getCAIndexConcreteId(IterDomain* id) const;

//! True if an ID is indexable.
Expand Down Expand Up @@ -300,6 +304,9 @@ class ContigIDs : public OptInDispatch {
std::unique_ptr<const OrderedIdInformation> consistent_transform_info_;

NonDivisibleSplitDependencies non_divisible_id_info_;

//! IDs that depend on resize output IDs
std::unordered_set<IterDomain*> resize_deps_;
};

} // namespace nvfuser
Loading