// Resize.cpp
#include <ATen/ATen.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/native/Resize.h>
#include <ATen/native/ResizeCommon.h>
#include <ATen/core/op_registration/op_registration.h>
#include <c10/core/TensorOptions.h>

namespace at { namespace native {

// Calls the sparse implementation in SparseTensor.cpp directly.
// Dynamic dispatch here is NOT necessary, so this function is not
// listed in native_functions.yaml.
Tensor& resize_as_sparse_(Tensor& self, const Tensor& src);

// TODO(VitalyFedyunin): Move this documentation to the HTML docs.
//
// The strides of the output tensor of the `resize_as_` operator are defined
// by the strides of the input tensor and the value of the memory_format
// argument.
//
// If memory_format is omitted and the input tensor has the same shape as the
// output tensor, the strides of the output remain unchanged. If the shapes
// differ, the strides are set to contiguous.
//
// If memory_format equals MemoryFormat::Contiguous (torch.contiguous_format),
// the output tensor will have contiguous strides.
//
// If memory_format equals MemoryFormat::ChannelsLast (torch.channels_last)
// and the input tensor is 4D, the output tensor will have a channels-last
// memory layout.
//
// If memory_format equals MemoryFormat::Preserve (torch.preserve_format),
// the strides of the output tensor are derived from the strides of the input
// tensor, following this memory format preservation rule:
//
//  - If the strides of the input tensor are in the channels-last format, the
//    output tensor will have a channels-last memory layout.
//
//  - Otherwise, the output tensor will have a contiguous memory layout.
//
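// Illustrative example of the rules above (a sketch only, not exercised by
// this file; it calls the function below directly and assumes the standard
// at::empty factory):
//
//   at::Tensor src = at::empty({2, 3, 4, 5});  // contiguous NCHW template
//   at::Tensor dst = at::empty({0});
//   // Explicit channels-last: dst is resized to src's shape and restrided
//   // to NHWC, i.e. strides (60, 1, 15, 3).
//   at::native::resize_as_(dst, src, at::MemoryFormat::ChannelsLast);
//   // Preserve: src has contiguous strides, so dst is restrided back to
//   // contiguous, i.e. strides (60, 20, 5, 1).
//   at::native::resize_as_(dst, src, at::MemoryFormat::Preserve);
//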
Tensor& resize_as_(
    Tensor& self,
    const Tensor& the_template,
    c10::optional<MemoryFormat> optional_memory_format) {
  if (self.is_sparse() && the_template.is_sparse()) {
    TORCH_CHECK(
        !optional_memory_format.has_value(),
        "Unsupported memory format for sparse tensor resize_as_: ",
        optional_memory_format.value());
    return native::resize_as_sparse_(self, the_template);
  }
  Tensor& result = self.resize_(the_template.sizes());
  if (optional_memory_format.has_value()) {
    auto memory_format = optional_memory_format.value();
    // Preserve resolves to whichever memory format the template's current
    // strides suggest (channels-last or contiguous).
    if (memory_format == MemoryFormat::Preserve) {
      memory_format = the_template.suggest_memory_format();
    }
    self.unsafeGetTensorImpl()->empty_tensor_restride(memory_format);
  }
  namedinference::propagate_names(result, the_template);
  return result;
}

Tensor& resize_(
    Tensor& self,
    IntArrayRef size,
    c10::optional<MemoryFormat> optional_memory_format) {
  if (self.has_names()) {
    return resize_named_tensor_(self, size, optional_memory_format);
  }
  auto* self_ = self.unsafeGetTensorImpl();
  resize_impl_cpu_(self_, size, /*strides=*/c10::nullopt);
  if (optional_memory_format.has_value()) {
    auto memory_format = optional_memory_format.value();
    TORCH_CHECK(
        memory_format != MemoryFormat::Preserve,
        "Unsupported memory format ",
        memory_format);
    self_->empty_tensor_restride(memory_format);
  }
  return self;
}
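
// Illustrative usage of resize_ (a sketch only, not exercised by this file;
// it calls the function above directly, bypassing the dispatcher):
//
//   at::Tensor t = at::empty({0});
//   at::native::resize_(t, {2, 3}, c10::nullopt);  // contiguous strides (3, 1)
//   at::native::resize_(t, {1, 3, 4, 5}, at::MemoryFormat::ChannelsLast);
//   // t now has channels-last strides (60, 1, 15, 3); passing
//   // MemoryFormat::Preserve here would trip the TORCH_CHECK above.
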
static auto registry = torch::RegisterOperators()
    .op(torch::RegisterOperators::options()
        .schema("aten::resize_(Tensor(a!) self, int[] size, *, MemoryFormat? memory_format=None) -> Tensor(a!)")
        .impl_unboxedOnlyKernel<decltype(resize_), &resize_>(DispatchKey::CPUTensorId)
        .aliasAnalysis(AliasAnalysisKind::FROM_SCHEMA))
    .op(torch::RegisterOperators::options()
        .schema("aten::resize_as_(Tensor(a!) self, Tensor the_template, *, MemoryFormat? memory_format=None) -> Tensor(a!)")
        .impl_unboxedOnlyCatchAllKernel<decltype(resize_as_), &resize_as_>()
        .aliasAnalysis(AliasAnalysisKind::FROM_SCHEMA));
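
// Note on the registrations above: impl_unboxedOnlyKernel and
// impl_unboxedOnlyCatchAllKernel register kernels reachable only through the
// unboxed calling convention, and AliasAnalysisKind::FROM_SCHEMA tells the
// JIT to derive mutation/aliasing information from the (a!) annotations in
// the schema strings.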
} // namespace native
} // namespace at