Skip to content

Commit

Permalink
Make quantized relu more flexible with quant params and use nnlib ker…
Browse files Browse the repository at this point in the history
…nel on HiFi (pytorch#4530)

Summary:
Pull Request resolved: pytorch#4530

As titled. This diff removes the requirement for inputs and outputs of ReLU to share quantization parameters. That should improve the numerics and allow less `requant` nodes in the graph. Since the nnlib kernel does that and is much faster on HiFi, it's a good deal all around.

Reviewed By: hsharma35

Differential Revision: D60696710

fbshipit-source-id: 4fe3faef607b252526cb5aa3a83d064084ba454e
  • Loading branch information
mcremon-meta authored and facebook-github-bot committed Aug 3, 2024
1 parent 7cd96f7 commit 9b06921
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 11 deletions.
2 changes: 1 addition & 1 deletion backends/cadence/aot/functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@
- arg_meta: null
kernel_name: impl::reference::quantized_linear_out

- func: cadence::quantized_relu.out(Tensor X, Tensor X_zero_point, *, Tensor(a!) out) -> Tensor(a!)
- func: cadence::quantized_relu.out(Tensor X, Tensor X_zero_point, int out_zero_point, Tensor out_multiplier, Tensor out_shift, *, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
kernel_name: impl::reference::quantized_relu_out
Expand Down
9 changes: 7 additions & 2 deletions backends/cadence/aot/ops_registrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,11 @@
"quantized_linear.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)"
)

lib.define("quantized_relu(Tensor X, Tensor X_zero_point) -> (Tensor Y)")
lib.define(
"quantized_relu.out(Tensor X, Tensor X_zero_point, *, Tensor(a!) out) -> Tensor (a!)"
"quantized_relu(Tensor X, Tensor X_zero_point, int out_zero_point, Tensor out_multiplier, Tensor out_shift) -> (Tensor Y)"
)
lib.define(
"quantized_relu.out(Tensor X, Tensor X_zero_point, int out_zero_point, Tensor out_multiplier, Tensor out_shift, *, Tensor(a!) out) -> Tensor (a!)"
)

lib.define(
Expand Down Expand Up @@ -168,6 +170,9 @@ def quantized_layer_norm_meta(
def quantized_relu_meta(
X: torch.Tensor,
X_zero_point: torch.Tensor,
out_zero_point: int,
out_multiplier: torch.Tensor,
out_shift: torch.Tensor,
):
return X.new_empty(X.size(), dtype=torch.uint8)

Expand Down
22 changes: 22 additions & 0 deletions backends/cadence/aot/quantizer/fusion_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,15 @@ def get_args_and_kwargs_relu(
graph_module: GraphModule,
inputs_inputs: List[fx.Node],
dequants_inputs: List[fx.Node],
quant_node: fx.Node,
) -> Tuple[Tuple[ArgsType], Dict[str, ArgsType]]:
input_scale = dequants_inputs[0].args[1]
# pyre-fixme[58]: Unsupported operand types
requantize_scale = input_scale / quant_node.args[1]
requantize_scale_t = torch.tensor([requantize_scale])

(out_multiplier, out_shift) = quantize_tensor_multiplier(requantize_scale_t)

# Make the args and kwargs for the replacement op
args = tuple(inputs_inputs)

Expand All @@ -296,9 +304,22 @@ def get_args_and_kwargs_relu(
([1], dequants_inputs[0].args[2]),
{"dtype": torch.int32},
)
out_multiplier_ = graph_module.graph.call_function(
torch.ops.aten.full.default,
([1], out_multiplier[0].item()),
{"dtype": torch.int32},
)
out_shift_ = graph_module.graph.call_function(
torch.ops.aten.full.default,
([1], out_shift[0].item()),
{"dtype": torch.int32},
)

kwargs = {
"X_zero_point": X_zero_point,
"out_zero_point": quant_node.args[2],
"out_multiplier": out_multiplier_,
"out_shift": out_shift_,
}
return args, kwargs

Expand Down Expand Up @@ -420,6 +441,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
graph_module,
inputs_inputs,
dequants_inputs,
quant_node,
)
fused = graph_module.graph.call_function(
pattern.replacement_op(),
Expand Down
4 changes: 1 addition & 3 deletions backends/cadence/aot/quantizer/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,9 +303,7 @@ def get_anchors(
inputs=[(relu_node, 0)],
weights=[],
biases=[],
output=[
(relu_node, SharedQuantizationSpec((relu_node.args[0], relu_node)))
],
output=[(relu_node,)],
)

def replacement_op(self) -> OpOverload:
Expand Down
36 changes: 31 additions & 5 deletions backends/cadence/reference/operators/quantized_relu_out.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,31 +16,57 @@ namespace native {
using Tensor = exec_aten::Tensor;
using RuntimeContext = torch::executor::RuntimeContext;

// Note: this kernel assumes that the input and output share quantization
// parameters. If that is not the case, it will produce incorrect results.
template <typename T>
void quantized_relu_(
const Tensor& input,
const Tensor& in_zero_point,
const int64_t out_zero_point,
const Tensor& out_multiplier,
const Tensor& out_shift,
Tensor& output) {
T q_zero_point = in_zero_point.const_data_ptr<T>()[0];
const T* __restrict__ in = input.const_data_ptr<T>();
T* __restrict__ out = output.mutable_data_ptr<T>();

const int32_t* __restrict__ out_multiplier_data =
out_multiplier.const_data_ptr<int32_t>();
const int32_t* __restrict__ out_shift_data =
out_shift.const_data_ptr<int32_t>();

// Compute the out_scale from out_multiplier and out_shift
const float out_scale =
-out_multiplier_data[0] * 1.0 / (1 << 31) * pow(2, out_shift_data[0]);

for (size_t i = 0, e = input.numel(); i < e; ++i) {
out[i] = in[i] > q_zero_point ? in[i] : q_zero_point;
const T temp = in[i] > q_zero_point ? (in[i] - q_zero_point) : 0;
out[i] = kernels::quantize<T>(temp, out_scale, out_zero_point);
}
}

void quantized_relu_out(
RuntimeContext& ctx,
const Tensor& input,
const Tensor& in_zero_point,
const int64_t out_zero_point,
const Tensor& out_multiplier,
const Tensor& out_shift,
Tensor& output) {
if (input.scalar_type() == exec_aten::ScalarType::Byte) {
quantized_relu_<uint8_t>(input, in_zero_point, output);
quantized_relu_<uint8_t>(
input,
in_zero_point,
out_zero_point,
out_multiplier,
out_shift,
output);
} else if (input.scalar_type() == exec_aten::ScalarType::Char) {
quantized_relu_<int8_t>(input, in_zero_point, output);
quantized_relu_<int8_t>(
input,
in_zero_point,
out_zero_point,
out_multiplier,
out_shift,
output);
} else {
ET_CHECK_MSG(false, "Unhandled input dtype %hhd", input.scalar_type());
}
Expand Down

0 comments on commit 9b06921

Please sign in to comment.