forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathqualified_name.h
160 lines (135 loc) · 4.22 KB
/
qualified_name.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
#pragma once
#include <c10/util/ArrayRef.h>
#include <c10/util/Exception.h>
#include <c10/util/StringUtil.h>
#include <string>
namespace c10 {

// Represents a name of the form "foo.bar.baz".
//
// The name is stored as a list of non-empty "atoms" (e.g. {"foo", "bar",
// "baz"}), plus three strings computed once at construction time:
//   - qualifiedName(): all atoms joined with '.', e.g. "foo.bar.baz"
//   - prefix():        all atoms but the last, joined, e.g. "foo.bar"
//   - name():          the last atom, e.g. "baz"
// A default-constructed (or empty-vector-constructed) name has no atoms and
// all three strings are empty.
struct QualifiedName {
  QualifiedName() = default;

  // `name` can be a dotted string, like "foo.bar.baz", or just a bare name.
  // Throws c10::Error (via TORCH_CHECK) if `name` is empty or if any
  // dot-separated component is empty, e.g. ".foo", "foo..bar" or "foo.".
  /* implicit */ QualifiedName(const std::string& name) {
    TORCH_CHECK(!name.empty());
    // Split the string into its atoms.
    size_t startSearchFrom = 0;
    size_t pos = name.find(delimiter_, startSearchFrom);
    while (pos != std::string::npos) {
      auto atom = name.substr(startSearchFrom, pos - startSearchFrom);
      // Malformed input is a user error, not an internal invariant violation.
      TORCH_CHECK(
          !atom.empty(), "Invalid name for qualified name: '", name, "'");
      atoms_.push_back(std::move(atom));
      startSearchFrom = pos + 1;
      pos = name.find(delimiter_, startSearchFrom);
    }
    // Everything after the last delimiter (or the whole string, if there was
    // no delimiter) forms the final atom.
    auto finalAtom = name.substr(startSearchFrom);
    TORCH_CHECK(
        !finalAtom.empty(), "Invalid name for qualified name: '", name, "'");
    atoms_.push_back(std::move(finalAtom));
    cacheAccessors();
  }

  // Builds a name directly from its atoms, e.g. {"foo", "bar", "baz"}.
  // Throws c10::Error if any atom is empty or contains the delimiter.
  // `atoms` is taken by value and moved into place, so callers passing an
  // rvalue pay for no copy.
  explicit QualifiedName(std::vector<std::string> atoms)
      : atoms_(std::move(atoms)) {
    for (const auto& atom : atoms_) {
      TORCH_CHECK(!atom.empty(), "Atom cannot be empty");
      TORCH_CHECK(
          atom.find(delimiter_) == std::string::npos,
          "Delimiter not allowed in atom");
    }
    cacheAccessors();
  }

  // Lets string literals convert implicitly. This makes an unnecessary
  // temporary std::string copy; ideally we'd use something like
  // std::string_view.
  /* implicit */ QualifiedName(const char* name)
      : QualifiedName(std::string(name)) {}

  // Appends the bare name `name` to `prefix`,
  // e.g. ("foo.bar", "baz") -> "foo.bar.baz".
  // `name` must be a bare name (non-empty, no dots!); violating that is an
  // internal assertion failure, not a user error.
  explicit QualifiedName(const QualifiedName& prefix, std::string name) {
    TORCH_INTERNAL_ASSERT(!name.empty());
    TORCH_INTERNAL_ASSERT(name.find(delimiter_) == std::string::npos);
    atoms_.reserve(prefix.atoms_.size() + 1);
    atoms_.assign(prefix.atoms_.begin(), prefix.atoms_.end());
    atoms_.push_back(std::move(name));
    cacheAccessors();
  }

  // Is `this` a prefix of `other`? Comparison is atom-wise, so "foo.bar" is a
  // prefix of "foo.bar.baz" (and of itself), but "foo.ba" is not a prefix of
  // "foo.bar". An empty name is a prefix of every name.
  bool isPrefixOf(const QualifiedName& other) const {
    const auto& thisAtoms = atoms_;
    const auto& otherAtoms = other.atoms_;
    if (thisAtoms.size() > otherAtoms.size()) {
      // Can't be a prefix if it's bigger.
      return false;
    }
    for (size_t i = 0; i < thisAtoms.size(); ++i) {
      if (thisAtoms[i] != otherAtoms[i]) {
        return false;
      }
    }
    return true;
  }

  // The fully qualified name, like "foo.bar.baz".
  const std::string& qualifiedName() const {
    return qualifiedName_;
  }

  // The leading qualifier, like "foo.bar". Empty for a single-atom name.
  const std::string& prefix() const {
    return prefix_;
  }

  // The base name, like "baz".
  const std::string& name() const {
    return name_;
  }

  // The individual components, like {"foo", "bar", "baz"}.
  const std::vector<std::string>& atoms() const {
    return atoms_;
  }

  // Two names are equal iff their fully qualified strings are equal.
  bool operator==(const QualifiedName& other) const {
    return this->qualifiedName_ == other.qualifiedName_;
  }

  bool operator!=(const QualifiedName& other) const {
    return !(*this == other);
  }

 private:
  static constexpr char delimiter_ = '.';

  // Joins the elements of `v` with `delimiter`. Helper for cacheAccessors();
  // `T` only needs size(), operator[] and range iteration over strings.
  template <typename T>
  static std::string join(char delimiter, const T& v) {
    std::string out;
    size_t reserve = 0;
    for (const auto& e : v) {
      reserve += e.size() + 1;
    }
    out.reserve(reserve);
    for (size_t i = 0; i < v.size(); ++i) {
      if (i != 0) {
        out.push_back(delimiter);
      }
      out.append(v[i]);
    }
    return out;
  }

  // Fills qualifiedName_, prefix_ and name_ from atoms_. Only called from the
  // constructors, so prefix_/name_ start out empty and stay empty when there
  // are too few atoms to define them.
  void cacheAccessors() {
    qualifiedName_ = join(delimiter_, atoms_);
    if (atoms_.size() > 1) {
      ArrayRef<std::string> view(atoms_);
      const auto prefixView = view.slice(0, view.size() - 1);
      prefix_ = join(delimiter_, prefixView);
    }
    if (!atoms_.empty()) {
      name_ = atoms_.back();
    }
  }

  // The actual list of names, like "{foo, bar, baz}".
  std::vector<std::string> atoms_;

  /*
   * Cached accessors, derived from `atoms_`.
   */
  std::string qualifiedName_;
  std::string prefix_;
  std::string name_;
};

} // namespace c10
namespace std {

// Lets c10::QualifiedName be used as a key in unordered containers.
// The hash is taken over the fully qualified dotted string, which is exactly
// what QualifiedName::operator== compares, so equal names hash equally.
template <>
struct hash<c10::QualifiedName> {
  size_t operator()(const c10::QualifiedName& qualified) const noexcept {
    const std::string& key = qualified.qualifiedName();
    return std::hash<std::string>{}(key);
  }
};

} // namespace std