Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Indexing Refactor #2316

Draft
wants to merge 36 commits into
base: devel
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
7e43d1d
Start reducing down IdGraph.
csarofeen Dec 3, 2022
7bbf730
Join the different sets into one structure based on MappingMode.
csarofeen Dec 3, 2022
5581f6e
Small alias.
csarofeen Dec 3, 2022
1443fcc
More minor refactoring.
csarofeen Dec 3, 2022
08b35c3
Code movement. Split up IdGraph build process.
csarofeen Dec 4, 2022
a6d824e
Rename node -> disjoint sets.
csarofeen Dec 5, 2022
fb13a0c
Make IterDomainGraph more self contained with a recursive ID mapping …
csarofeen Dec 6, 2022
ada2af5
Expand comment for BestEffortReplay.
csarofeen Dec 22, 2022
4e9d268
Minor interface changes in compute at map, code movement.
csarofeen Dec 23, 2022
ee50ddd
Expose forwarding info for permissive mapping.
csarofeen Dec 26, 2022
efda269
Make exact and permissive maps self building.
csarofeen Dec 26, 2022
fe08382
Remove BestEffortReplay use from Loop Map construction. Remove consum…
csarofeen Dec 28, 2022
37e81b3
Small name refactor, add definition and uses API to IterDomainGraph.
csarofeen Dec 28, 2022
bf202a8
Remove ComputeAtMap's definition and uses in favor of IterDomainGraphs.
csarofeen Dec 28, 2022
d787043
Remove multi output function in IdGraph building, remove explicit rfa…
csarofeen Dec 29, 2022
0321dba
IdGraph/ComputeAt Interface Tuning and cleanup.
csarofeen Dec 30, 2022
c29b427
Minor IdGraph tweaks, continue removing some BestEffortReplay usage.
csarofeen Jan 2, 2023
6f28803
Merge getMatchedLeafPosWithoutReplay[CasP, PasC] into ...TasR.
csarofeen Jan 2, 2023
23c05ad
Remove fullSelfMatching
csarofeen Jan 2, 2023
a390667
Treat braodcast as the exception not the rule for unskippable inlined…
csarofeen Jan 4, 2023
3b93604
Minor build fixes.
csarofeen Jan 6, 2023
9511d76
Minor cleanup.
csarofeen Jan 7, 2023
14c9844
Draft loop promotion in compute at map, validate with warning initial…
csarofeen Jan 8, 2023
af28d38
Minor loop promotion map fix.
csarofeen Jan 8, 2023
b67e245
minor build fix.
csarofeen Jan 20, 2023
4e55bd1
Minor refactoring.
csarofeen Jan 21, 2023
e38e8f7
Some more minor refactoring.
csarofeen Jan 21, 2023
0e6a85a
Update iter domain graph with broadcast promotion logic. WARNING Brea…
csarofeen Jan 26, 2023
33d0e04
Reduce verbosity, add name only option to lower dump.
csarofeen Jan 27, 2023
6651752
Minor loop promote fix.
csarofeen Jan 27, 2023
b465a10
Minor cleanup.
csarofeen Jan 28, 2023
c9e8710
Disable/WAR Gather/Shift for now.
csarofeen Jan 28, 2023
bd86f5c
Add id_definitions_, fix idGraph construction.
csarofeen Feb 6, 2023
33f5dcc
Stash current iter domain graph attempt.
csarofeen Feb 12, 2023
981a14a
Another draft of loop promotion and introduction of index map.
csarofeen Feb 21, 2023
e3cbe70
Factor out IdGraph from multiple IterDomainGraphs.
csarofeen Feb 25, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3,786 changes: 2,828 additions & 958 deletions third_party/nvfuser/csrc/compute_at_map.cpp

Large diffs are not rendered by default.

521 changes: 360 additions & 161 deletions third_party/nvfuser/csrc/compute_at_map.h

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion third_party/nvfuser/csrc/contiguity.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,9 @@ bool ContigIDs::isIndexable(IterDomain* id) const {
// If ID is mapped to consumer through persmissive map but not exact map it
// will not be mapped through to the exact map through the p2c map. Therefore
// reject because it involves broadcast resolution.
if (!ca_map_->idExistsInMap(getMappedId(id))) {
if (!ca_map_->idGraph(IdMappingMode::EXACT)
.disjointIdSets()
.mappingExists(getMappedId(id))) {
return false;
}
auto c_id =
Expand Down
184 changes: 139 additions & 45 deletions third_party/nvfuser/csrc/disjoint_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,38 @@ std::string abstractToString(T ref) {

// Vector like class that will prevent adding duplicate entries by also
// maintaing a set
//
// TODO: Can we support std::back_inserter with this class?
template <typename T, typename Hash = std::hash<T>>
class VectorOfUniqueEntries {
public:
VectorOfUniqueEntries() = default;

VectorOfUniqueEntries(const std::initializer_list<T>& x)
: vector_(x), set_(x) {}
VectorOfUniqueEntries(const std::initializer_list<T>& initializer) {
for (auto entry : initializer) {
pushBack(entry);
}
}

VectorOfUniqueEntries(const VectorOfUniqueEntries<T>& other) {
vector_ = other.vector();
set_ = other.set();
}

VectorOfUniqueEntries& operator=(const VectorOfUniqueEntries<T>& other) {
if (this != &other) {
vector_ = other.vector();
set_ = other.set();
}
return *this;
}

template <class InputIt>
VectorOfUniqueEntries(InputIt first, InputIt last) {
while (first != last) {
pushBack(*first++);
}
}

// Returns if a node was actually added
bool pushBack(T entry) {
Expand All @@ -49,6 +74,15 @@ class VectorOfUniqueEntries {
return false;
}

// Returns if a node was actually added
bool pushFront(T entry) {
if (set_.emplace(entry).second) {
vector_.insert(vector_.begin(), entry);
return true;
}
return false;
}

// Returns if any node was added
bool pushBack(const VectorOfUniqueEntries<T, Hash>& other) {
bool any_added = false;
Expand All @@ -58,11 +92,53 @@ class VectorOfUniqueEntries {
return any_added;
}

// Returns a new VectorOfUniqueEntries with entries that are in both this and
// other, order is preserved as this.
VectorOfUniqueEntries<T, Hash> intersect(
const VectorOfUniqueEntries<T, Hash>& other) {
VectorOfUniqueEntries<T, Hash> intersection;
for (auto entry : vector()) {
if (other.has(entry)) {
intersection.pushBack(entry);
}
}
return intersection;
}

// Returns a new VectorOfUniqueEntries with entries that are in this but not
// in other.
VectorOfUniqueEntries<T, Hash> subtract(
const VectorOfUniqueEntries<T, Hash>& other) const {
VectorOfUniqueEntries<T, Hash> subtraction;
for (auto entry : vector()) {
if (!other.has(entry)) {
subtraction.pushBack(entry);
}
}
return subtraction;
}

// Returns a new VectorOfUniqueEntries with entries that are either in this or
// other.
VectorOfUniqueEntries<T, Hash> computeUnion(
const VectorOfUniqueEntries<T, Hash>& other) const {
const VectorOfUniqueEntries<T, Hash>& this_ref = *this;
VectorOfUniqueEntries<T, Hash> union_(this_ref);
for (auto entry : other.vector()) {
union_.pushBack(entry);
}
return union_;
}

// Returns a const vector useful for iterating on
const std::vector<T>& vector() const {
return vector_;
}

const std::unordered_set<T>& set() const {
return set_;
}

// Returns first element in vector
T front() const {
return vector_.front();
Expand All @@ -81,6 +157,14 @@ class VectorOfUniqueEntries {
return v;
}

// Remove and returns the last element in vector
T popFront() {
T v = vector_.front();
set_.erase(v);
vector_.erase(vector_.begin());
return v;
}

// Returns if this container is empty
bool empty() const {
return vector_.empty();
Expand Down Expand Up @@ -137,7 +221,7 @@ class VectorOfUniqueEntries {
return vector_.end();
}

std::string toString() {
std::string toString() const {
std::stringstream ss;
ss << "{ ";
for (auto entry : vector()) {
Expand Down Expand Up @@ -206,64 +290,78 @@ class DisjointSets {
}

// Initializes a new set for provided entry
//
// TODO: Return iterator
void initializeSet(T entry) {
if (disjoint_set_maps_.find(entry) != disjoint_set_maps_.end()) {
return;
std::pair<
typename std::unordered_map<
T,
std::shared_ptr<VectorOfUniqueEntries<T, Hash>>,
Hash>::iterator,
bool>
initializeSet(T entry) {
auto disjoint_set_maps_it = disjoint_set_maps_.find(entry);
if (disjoint_set_maps_it != disjoint_set_maps_.end()) {
return std::make_pair(disjoint_set_maps_it, false);
}

disjoint_sets_.push_back(
std::make_shared<VectorOfUniqueEntries<T, Hash>>());
disjoint_sets_.back()->pushBack(entry);
disjoint_set_maps_.emplace(std::make_pair(entry, disjoint_sets_.back()));
return disjoint_set_maps_.emplace(
std::make_pair(entry, disjoint_sets_.back()));
}

// Adds all of the disjoint set belonging to entry1 to the disjoint set
// belonging to entry0, maps all entries of disjoint set belonging to entry1
// to entry0, removes original disjoint set belonging to entry1.
void mapEntries(T entry0, T entry1) {
if (entry0 == entry1) {
return;
}

auto set_it_0 = disjoint_set_maps_.find(entry0);
auto set_it_1 = disjoint_set_maps_.find(entry1);

// Track if we need to reset iterators, optimize for case where both entries
// exist
bool invalid_iterators = false;
if (set_it_0 == disjoint_set_maps_.end()) {
initializeSet(entry0);
invalid_iterators = true;
}
auto set_0_found = set_it_0 != disjoint_set_maps_.end();
auto set_1_found = set_it_1 != disjoint_set_maps_.end();

if (set_it_1 == disjoint_set_maps_.end()) {
initializeSet(entry1);
invalid_iterators = true;
// Sets already joined
if (set_0_found && set_1_found && set_it_0->second == set_it_1->second) {
return;
}

// TODO: We can avoid refinding one iterator if initialize set returns an
// iterator, though if we insert entry1 we'd have to refind entry0 as it
// could invalidate all iterators
if (invalid_iterators) {
set_it_0 = disjoint_set_maps_.find(entry0);
// Make and map new set
disjoint_sets_.push_back(
std::make_shared<VectorOfUniqueEntries<T, Hash>>());
auto new_set = disjoint_sets_.back();

if (set_0_found) {
auto set_0 = set_it_0->second;
for (auto set_0_entry : *set_0) {
TORCH_INTERNAL_ASSERT(set_0_entry != entry1);
new_set->pushBack(set_0_entry);
disjoint_set_maps_[set_0_entry] = new_set;
}
disjoint_sets_.erase(
std::find(disjoint_sets_.begin(), disjoint_sets_.end(), set_0));
// Erase invalidates iterators, regrab.
set_it_1 = disjoint_set_maps_.find(entry1);
set_1_found = set_it_1 != disjoint_set_maps_.end();
} else {
new_set->pushBack(entry0);
disjoint_set_maps_[entry0] = new_set;
}

auto set0_shared_ptr = set_it_0->second;
auto set1_shared_ptr = set_it_1->second;

// If the sets are already the same, do nothing
if (set0_shared_ptr == set1_shared_ptr) {
return;
}

// Place everything in set1 into set0 and remap all entries in set1 to set0
for (auto entry : set1_shared_ptr->vector()) {
set0_shared_ptr->pushBack(entry);
disjoint_set_maps_[entry] = set0_shared_ptr;
if (set_1_found) {
auto set_1 = set_it_1->second;
for (auto set_1_entry : *set_1) {
new_set->pushBack(set_1_entry);
disjoint_set_maps_[set_1_entry] = new_set;
}
disjoint_sets_.erase(
std::find(disjoint_sets_.begin(), disjoint_sets_.end(), set_1));
} else {
new_set->pushBack(entry1);
disjoint_set_maps_[entry1] = new_set;
}

// set1 no longer needed as its entries are copied into set0
disjoint_sets_.erase(std::find(
disjoint_sets_.begin(), disjoint_sets_.end(), set1_shared_ptr));
}

// Will assert if provided entry0 is not in any disjoint set, otherwise
Expand Down Expand Up @@ -319,11 +417,7 @@ class DisjointSets {
const std::string sep(" ");
for (auto s_ptr : disjoint_sets_) {
auto& set = *s_ptr;
ss << sep << "{\n";
for (auto entry : set.vector()) {
ss << sep << sep << abstractToString(entry) << "\n";
}
ss << sep << "}\n";
ss << sep << abstractToString(set) << "\n";
}
ss << "}";
return ss.str();
Expand Down
35 changes: 10 additions & 25 deletions third_party/nvfuser/csrc/grouped_reduction.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include <compute_at_map.h>
#include <ir_builder.h>
#include <ir_utils.h>
#include <root_domain_map.h>
Expand All @@ -13,24 +14,17 @@ namespace cuda {
namespace {

// Return if ref and other are transformed in the same way.
bool hasMatchingTransformations(TensorView* ref, TensorView* other) {
std::unordered_map<IterDomain*, IterDomain*> ref_2_other;
for (const auto i : c10::irange(ref->getRootDomain().size())) {
ref_2_other.emplace(
ref->getRootDomain().at(i), other->getRootDomain().at(i));
}

auto replay =
BestEffortReplay(
other->domain()->domain(), ref->domain()->domain(), ref_2_other)
.getIterDomainEquivalence();

bool hasMatchingTransformations(
TensorView* ref,
TensorView* other,
const IterDomainGraphs& id_graphs) {
for (const auto i : c10::irange(ref->nDims())) {
if (!replay.permissiveAreMapped(ref->axis(i), other->axis(i))) {
if (!id_graphs.idGraph(IdMappingMode::EXACT)
.disjointIdSets()
.permissiveAreMapped(ref->axis(i), other->axis(i))) {
return false;
}
}

return true;
}

Expand All @@ -45,7 +39,7 @@ void validateReductionGrouping(
TORCH_INTERNAL_ASSERT(
fusion != nullptr, "Grouping of reductions must be done within a Fusion");

ExactRootDomainMap exact_map(fusion);
IterDomainGraphs id_graphs(fusion);

// Pick the first output TV as a reference and compare it with the
// rest. Do not allow grouping if any mismatch is detected.
Expand Down Expand Up @@ -112,19 +106,10 @@ void validateReductionGrouping(
output_id->toString(),
". Invalid tensor: ",
output_tv->toString());
TORCH_INTERNAL_ASSERT(
exact_map.areMapped(ref_id, output_id) || ref_id->sameAs(output_id),
"Invalid grouped reduction due to mismatched root domains. ",
"Reference domain: ",
ref_id->toString(),
". Mismatched domain: ",
output_id->toString(),
". Invalid tensor: ",
output_tv->toString());
}

TORCH_INTERNAL_ASSERT(
hasMatchingTransformations(ref_tv, output_tv),
hasMatchingTransformations(ref_tv, output_tv, id_graphs),
"Invalid grouped reduction due to mismatched transformations. ",
"Reference tensor: ",
ref_tv->toString(),
Expand Down
Loading