forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
custom_class.cpp
40 lines (29 loc) · 910 Bytes
/
custom_class.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
#include <torch/custom_class.h>
#include <atomic>
namespace torch {
namespace jit {

/// Returns a reference to a process-wide, lazily constructed list of
/// `c10::RegisterOperators` objects.
///
/// The vector is a block-scope static, so its construction is thread-safe
/// (C++11 "magic statics"). The container itself is NOT synchronized here:
/// callers that mutate it concurrently must provide their own locking.
/// NOTE(review): presumably this keeps operator registrations alive for the
/// lifetime of the program; confirm against the callers that push into it.
std::vector<c10::RegisterOperators>& registeredOps() {
  static std::vector<c10::RegisterOperators> ops;
  return ops;
}

/// Returns a reference to the shared `script::CompilationUnit` used to hold
/// custom class types.
///
/// The unit is created on first call via `std::make_shared`. Every later
/// call returns the same `shared_ptr` object by reference.
std::shared_ptr<script::CompilationUnit>& classCU() {
  static std::shared_ptr<script::CompilationUnit> cu =
      std::make_shared<script::CompilationUnit>();
  return cu;
}

/// Reports whether `v` holds an instance of a registered custom class.
///
/// @param v  The value to inspect.
/// @return   true iff `v` is an object, its class type has a name, and
///           `getCustomClass` returns a non-null type for that qualified name.
bool isCustomClass(const c10::IValue& v) {
  if (!v.isObject()) {
    return false;
  }
  // Fetch the object and its class type once. The previous chained form
  // called toObject()/type() twice, copying ref-counted pointers each time.
  const auto obj = v.toObject();
  const auto type = obj->type();
  const auto& maybe_name = type->name();
  if (!maybe_name) {
    // Unnamed class types can never be registered custom classes.
    return false;
  }
  return getCustomClass(maybe_name->qualifiedName()) != nullptr;
}

namespace {
/// Lookup callback installed via `setGetCustomClassFn`. It resolves `name`
/// against the custom-class compilation unit returned by `classCU()`.
TypePtr realCustomClassHandler(const std::string& name) {
  return classCU()->get_type(name);
}
} // namespace

/// Installs `realCustomClassHandler` as the custom-class lookup function.
///
/// @return Always 0. The value exists only so this call can seed the
///         namespace-scope static below.
int register_custom_class_handler() {
  setGetCustomClassFn(realCustomClassHandler);
  return 0;
}

// Dynamic initialization of this variable runs register_custom_class_handler()
// when this translation unit is initialized, so the handler is installed
// without an explicit call. NOTE(review): this relies on the storage behind
// setGetCustomClassFn already being usable at that point (static init order
// across TUs); confirm its definition is constant-initialized.
static int ensure_custom_class_handler_registered = register_custom_class_handler();

} // namespace jit
} // namespace torch