forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
custom_class.h
205 lines (180 loc) · 7.76 KB
/
custom_class.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
#pragma once
#include <ATen/core/function_schema.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/jit_type.h>
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/core/stack.h>
#include <c10/util/C++17.h>
#include <c10/util/Metaprogramming.h>
#include <c10/util/TypeList.h>
#include <c10/util/TypeTraits.h>
#include <torch/csrc/jit/custom_class.h>
#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/script/compilation_unit.h>
#include <torch/csrc/jit/tracer.h>
#include <torch/csrc/utils/variadic.h>
#include <torch/custom_class_detail.h>
#include <iostream>
#include <sstream>
namespace torch {
namespace jit {
// Tag-type factory: produces an empty detail::types<void, Types...> value
// that carries a constructor signature CurClass(Types...) at the type level.
// Pass the result to class_<T>::def(...) to register that constructor.
template <class... Types>
detail::types<void, Types...> init() {
  return {};
}
// To bind custom classes into Torchscript, use an API very similar to Pybind's.
// Currently exposes one class `torch::jit::class_<T>` and 2 methods.
// - Constructing `torch::jit::class_<Foo>` registers `Foo` in Python and
// Torchscript, and puts it under `torch.classes.Foo` in Python.
// - torch::jit::class_<Foo>.def("method1", &Foo::method1) does some template
// metaprogramming to introspect the function types and register the operator
// for use in Torchscript.
// - torch::jit::class_<Foo>.def(torch::jit::init<int64_t, int64_t>()) registers
// the Foo(int, int) constructor.
// see test/custom_operator/classes.cpp and
// test/custom_operator/test_custom_classes.py for example usages
// Binds the C++ class `CurClass` into TorchScript under the qualified name
// `__torch__.torch.classes.<className>`, with an API modeled on pybind11's
// `class_`. Instances are represented as TorchScript objects whose single
// attribute ("capsule", slot 0) holds an intrusive_ptr to the C++ object.
template <class CurClass>
class class_ {
  static_assert(std::is_base_of<CustomClassHolder, CurClass>::value,
      "torch::jit::class_<T> requires T to inherit from CustomClassHolder");

  std::string className;      // Unqualified name supplied by the user.
  std::string qualClassName;  // topModule + "." + parentModule + "." + className
  ClassTypePtr classTypePtr;  // TorchScript class type backing this binding.

  const std::string parentModule = "classes";
  const std::string topModule = "__torch__.torch";

 public:
  // Creates and registers the TorchScript class type for `className_`.
  // Also records mappings from the C++ runtime type names of
  // c10::intrusive_ptr<CurClass> and c10::tagged_capsule<CurClass> to the
  // new class type, so IValue conversions can resolve them later.
  class_(std::string className_) : className(std::move(className_)) {
    qualClassName = topModule + "." + parentModule + "." + className;

    // We currently represent custom classes as torchscript classes with a
    // capsule attribute. It is the only attribute, so it occupies slot 0.
    classTypePtr =
        ClassType::create(c10::QualifiedName(qualClassName), classCU());
    classTypePtr->addAttribute("capsule", CapsuleType::get());

    c10::getCustomClassTypeMap().insert({typeid(c10::intrusive_ptr<CurClass>).name(),
        c10::StrongTypePtr(classCU(), classTypePtr)});
    c10::getCustomClassTypeMap().insert({typeid(c10::tagged_capsule<CurClass>).name(),
        c10::StrongTypePtr(classCU(), classTypePtr)});

    classCU()->register_type(classTypePtr);
  }

  // Registers a constructor taking `Types...`. Used in combination with
  // torch::jit::init<...>(). The generated __init__ constructs the C++
  // object and stores it (upcast to CustomClassHolder) in the object's
  // capsule slot (slot 0).
  template <typename... Types>
  class_& def(detail::types<void, Types...>) {
    auto func = [](c10::tagged_capsule<CurClass> self, Types... args) {
      auto classObj = c10::make_intrusive<CurClass>(args...);
      auto genericPtr = c10::static_intrusive_pointer_cast<torch::jit::CustomClassHolder>(classObj);
      auto capsule = IValue(genericPtr);
      auto object = self.ivalue.toObject();
      object->setSlot(0, capsule);
    };
    defineMethod("__init__", std::move(func));
    return *this;
  }

  // Registers a method `name` implemented by the callable `f`; `f` is
  // wrapped by detail::wrap_func so the TorchScript `self` object is
  // translated to the C++ object before the call.
  template <typename Func>
  class_& def(std::string name, Func f) {
    auto wrapped_f = detail::wrap_func<CurClass, Func>(std::move(f));
    defineMethod(std::move(name), std::move(wrapped_f));
    return *this;
  }

  // Pickle
  // Registers __getstate__ / __setstate__ for (de)serialization, then
  // validates that their schemas agree with each other and with this class.
  template <typename GetStateFn, typename SetStateFn>
  class_& def_pickle(GetStateFn&& get_state, SetStateFn&& set_state) {
    static_assert(
        c10::guts::is_stateless_lambda<std::decay_t<GetStateFn>>::value &&
            c10::guts::is_stateless_lambda<std::decay_t<SetStateFn>>::value,
        "torch::jit::pickle_ currently only supports lambdas as "
        "__getstate__ and __setstate__ arguments.");
    def("__getstate__", std::forward<GetStateFn>(get_state));

    // __setstate__ needs to be registered with some custom handling:
    // We need to wrap the invocation of the user-provided function
    // such that we take the return value (i.e. c10::intrusive_ptr<CurClass>)
    // and assign it to the `capsule` attribute.
    using SetStateTraits =
        c10::guts::infer_function_traits_t<std::decay_t<SetStateFn>>;
    // First (only) parameter of the user's __setstate__: the serialized state.
    using SetStateArg = typename c10::guts::typelist::head_t<
        typename SetStateTraits::parameter_types>;
    auto setstate_wrapper = [set_state = std::move(set_state)](
                                c10::tagged_capsule<CurClass> self,
                                SetStateArg&& arg) {
      c10::intrusive_ptr<CurClass> classObj =
          at::guts::invoke(set_state, std::forward<SetStateArg>(arg));
      auto genericPtr =
          c10::static_intrusive_pointer_cast<torch::jit::CustomClassHolder>(
              classObj);
      auto capsule = IValue(genericPtr);
      auto object = self.ivalue.toObject();
      // Install the freshly constructed object into the capsule slot.
      object->setSlot(0, capsule);
    };
    defineMethod(
        "__setstate__",
        detail::wrap_func<CurClass, decltype(setstate_wrapper)>(
            std::move(setstate_wrapper)));

    // type validation: __getstate__ must be (self) -> T, and __setstate__
    // must accept that same T.
    auto getstate_schema = classTypePtr->getMethod("__getstate__")->getSchema();
    auto format_getstate_schema = [&getstate_schema]() {
      std::stringstream ss;
      ss << getstate_schema;
      return ss.str();
    };
    TORCH_CHECK(
        getstate_schema.arguments().size() == 1,
        "__getstate__ should take exactly one argument: self. Got: ",
        format_getstate_schema());
    auto first_arg_type = getstate_schema.arguments().at(0).type();
    TORCH_CHECK(
        *first_arg_type == *classTypePtr,
        "self argument of __getstate__ must be the custom class type. Got ",
        first_arg_type->python_str());
    TORCH_CHECK(
        getstate_schema.returns().size() == 1,
        "__getstate__ should return exactly one value for serialization. Got: ",
        format_getstate_schema());
    auto ser_type = getstate_schema.returns().at(0).type();
    auto setstate_schema = classTypePtr->getMethod("__setstate__")->getSchema();
    // Argument 0 of __setstate__ is self; argument 1 is the serialized state.
    auto arg_type = setstate_schema.arguments().at(1).type();
    TORCH_CHECK(
        (*arg_type == *ser_type),
        "__setstate__'s argument should be the same type as the "
        "return value of __getstate__. Got ",
        arg_type->python_str(),
        " but expected ",
        ser_type->python_str());
    return *this;
  }

 private:
  // Registers `func` as a c10 operator named `<className>::<name>`, then
  // creates a TorchScript method on the class whose graph simply calls that
  // operator: inputs mirror the operator's arguments, and the outputs are
  // packaged as a tuple when there is more than one return, passed through
  // for exactly one, or None when there are none.
  template <typename Func>
  void defineMethod(std::string name, Func func) {
    auto graph = std::make_shared<Graph>();
    auto qualFuncName = className + "::" + name;
    ensure_c10_registerer_defined();
    // Retain the RegisterOperators handle so the registration isn't dropped.
    registeredOps().push_back(
        torch::RegisterOperators().op(qualFuncName, std::move(func)));
    auto func_symbol = c10::Symbol::fromQualString(qualFuncName);
    auto ops = torch::jit::getAllOperatorsFor(func_symbol);
    // Exactly one overload is expected for the operator we just registered.
    TORCH_CHECK(ops.size() == 1);
    auto &schema = ops[0]->schema();
    // Mirror the operator's arguments as typed graph inputs.
    for (const auto& arg : schema.arguments()) {
      graph->addInput()->setType(arg.type());
    }
    auto opCall = graph->insertNode(graph->create(
        func_symbol, graph->inputs(), schema.returns().size()));
    Value* res;
    if (schema.returns().size() > 1) {
      // Multiple returns: type each output, then bundle them into a tuple.
      const auto& returns = schema.returns();
      size_t op_invocation_idx = 0;
      for (const auto& ret : returns) {
        opCall->output(op_invocation_idx++)->setType(ret.type());
      }
      res = graph->insertNode(graph->createTuple(opCall->outputs()))->output();
    } else if (schema.returns().size() == 1) {
      // Single return: forward the typed operator output directly.
      const auto& returns = schema.returns();
      res = opCall->output()->setType(returns[0].type());
    } else {
      // No returns: the generated method returns None.
      res = graph->insertConstant(IValue())->setType(NoneType::get());
    }
    graph->registerOutput(res);
    auto method = classCU()->create_function(qualClassName + "." + name, graph);
    classTypePtr->addMethod(method);
  }
};
} // namespace jit
} // namespace torch