Add a CPU nbit to float dequantization op that supports torch.quintMxN type and QuantizedCPU backend

Differential Revision: D61305979
wsu authored and facebook-github-bot committed Aug 15, 2024
1 parent 1ec08ad commit 21430ec
Showing 5 changed files with 99 additions and 6 deletions.
20 changes: 20 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/quantize_ops_utils.h
@@ -9,6 +9,7 @@
#pragma once

#include <ATen/ATen.h>
#include "fbgemm/Types.h"
#include "fbgemm_gpu/utils/types.h"

namespace fbgemm_gpu {
@@ -111,4 +112,23 @@ hfp8_to_float(uint8_t hfp8_val, int ebits, int exponent_bias) {
return val_out.F;
}

// Get the number of bytes of a row in a tensor with quantized nbit integers
inline int32_t nbit_elems_to_bytes(const at::Tensor& input) {
const auto input_sizes = input.sizes();
const int32_t ncols = input_sizes[1];
// at::kQUInt4x2 is the dtype for quantized int4 tensors and at::kQUInt2x4 is
// for quantized int2 tensors. QUIntMxN (M*N=8) means quantized M-bit integer
// with each byte holding N such elements.
// input_sizes[1] is the number of elements in each row, so we need to divide
// it by 2 or 4 for quint4x2 or quint2x4 respectively to get the number of
// bytes in each row.
if (input.dtype() == at::kQUInt2x4) {
return fbgemm::div_up(ncols, 4);
} else if (input.dtype() == at::kQUInt4x2) {
return fbgemm::div_up(ncols, 2);
} else {
return ncols;
}
}

} // namespace fbgemm_gpu
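As a concrete illustration of the byte math above, here is a minimal standalone sketch, assuming fbgemm::div_up is a plain round-up integer division (the helper name below is illustrative and has no ATen dependency):

```cpp
#include <cassert>
#include <cstdint>

// Round-up integer division, standing in for fbgemm::div_up.
static int32_t div_up_sketch(int32_t n, int32_t d) {
  return (n + d - 1) / d;
}

int main() {
  // quint4x2: two 4-bit elements per byte, so 11 elements need 6 bytes.
  assert(div_up_sketch(11, 2) == 6);
  // quint2x4: four 2-bit elements per byte, so 11 elements need 3 bytes.
  assert(div_up_sketch(11, 4) == 3);
  // Any other dtype: one element per byte, so the column count is unchanged.
  assert(div_up_sketch(8, 1) == 8);
  return 0;
}
```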
5 changes: 5 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/sparse_ops_utils.h
@@ -102,6 +102,11 @@ inline bool torch_tensor_empty_or_on_cpu_check(
#define DISPATCH_TO_CPU(name, function) \
m.impl(name, torch::dispatch(c10::DispatchKey::CPU, TORCH_FN(function)))

#define DISPATCH_TO_QUANTIZED_CPU(name, function) \
m.impl( \
name, \
torch::dispatch(c10::DispatchKey::QuantizedCPU, TORCH_FN(function)))

#define DISPATCH_TO_META(name, function) \
m.impl(name, torch::dispatch(c10::DispatchKey::Meta, TORCH_FN(function)))

63 changes: 62 additions & 1 deletion fbgemm_gpu/src/quantize_ops/quantize_ops_cpu.cpp
@@ -122,7 +122,8 @@ Tensor _fusednbitrowwise_to_float_cpu(

const auto input_sizes = input.sizes();
const int64_t nrows = input_sizes[0];
const int32_t ncols = input_sizes[1];
// Here we want the number of bytes in a row
const int32_t ncols = nbit_elems_to_bytes(input);
const int32_t num_elem_per_byte = 8 / bit_rate;
const int32_t output_columns =
(ncols - 2 * sizeof(at::Half)) * num_elem_per_byte;
@@ -148,6 +149,40 @@ Tensor _fusednbitrowwise_to_float_cpu(
return output;
}

Tensor _fusednbitrowwise_sbfront_to_float_cpu(
const Tensor& input,
const int64_t bit_rate) {
TENSOR_ON_CPU(input);
TENSOR_NDIM_EQUALS(input, 2);

const auto input_sizes = input.sizes();
const int64_t nrows = input_sizes[0];
// Here we want the number of bytes in a row
const int32_t ncols = nbit_elems_to_bytes(input);
const int32_t num_elem_per_byte = 8 / bit_rate;
const int32_t output_columns =
(ncols - 2 * sizeof(at::Half)) * num_elem_per_byte;

Tensor output;
output = at::empty(
{nrows, output_columns}, // 4 = sizeof(float)
input.options().dtype(at::kFloat));

float* output_data = static_cast<float*>(
output.data_ptr()); // output.data_ptr<output_t>(); -> Yields
// unresolved data_ptr symbol.

fbgemm::FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef<float>(
bit_rate,
input.data_ptr<uint8_t>(),
nrows,
ncols,
output_data,
/*scale_bias_last=*/false);

return output;
}
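For a concrete feel of the width arithmetic above, a small sketch with illustrative numbers: each fused row spends 2 * sizeof(at::Half) = 4 bytes on the scale/bias header, and every remaining byte holds 8 / bit_rate elements.

```cpp
#include <cassert>

// Illustrative numbers only: a 10-byte int4 row has a 4-byte fp16 scale/bias
// header, leaving 6 data bytes that hold 2 elements each.
int main() {
  const int bit_rate = 4;
  const int num_elem_per_byte = 8 / bit_rate; // 2
  const int row_bytes = 10;                   // what nbit_elems_to_bytes returns
  const int header_bytes = 2 * 2;             // two fp16 values
  const int output_columns = (row_bytes - header_bytes) * num_elem_per_byte;
  assert(output_columns == 12);
  return 0;
}
```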

/// @ingroup quantize-data-cpu
///
Tensor& _fused8bitrowwise_to_float_cpu_out(
@@ -273,6 +308,24 @@ Tensor fusednbitrowwise_to_float_cpu(
return _fusednbitrowwise_to_float_cpu<float>(input, bit_rate);
}

/// @ingroup quantize-data-cpu
/// @brief Dequantize int4/int2 rows with scale and bias stored in the front
/// into float32.
/// @param input Tensor of int4/int2 rows with scale and bias stored in the
/// front.
/// @param bit_rate Bit rate of each element. Should be 4 or 2.
/// @return Tensor of float32, holding dequantized numbers.
///
/// Dequantize int4/int2 rows with scale and bias stored in the front into
/// float32. The input tensor should have torch.quint4x2 or torch.quint2x4
/// dtype and the QuantizedCPU backend. This operator is recommended only for
/// testing purposes because its kernel is a reference implementation and is
/// not optimized.
Tensor fusednbitrowwise_sbfront_to_float_cpu(
const Tensor& input,
const int64_t bit_rate) {
return _fusednbitrowwise_sbfront_to_float_cpu(input, bit_rate);
}
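To make the documented behavior concrete, here is a self-contained sketch of dequantizing one int4 row whose scale/bias header sits at the front. This is not the fbgemm kernel (the op delegates to FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef with scale_bias_last=false); for readability the header values are plain floats here, whereas the real row format stores them as two fp16 values.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  const int bit_rate = 4;
  const int num_elem_per_byte = 8 / bit_rate; // two nibbles per byte
  // Front-layout header (fp16 in the real format, plain float here).
  const float scale = 0.5f;
  const float bias = -1.0f;
  // Packed payload that follows the header: 0x21 holds the quantized values
  // 1 (low nibble) and 2 (high nibble); 0x43 holds 3 and 4.
  const std::vector<uint8_t> packed = {0x21, 0x43};

  std::vector<float> out;
  for (int col = 0; col < 4; ++col) {
    uint8_t q = packed[col / num_elem_per_byte];
    q >>= (col % num_elem_per_byte) * bit_rate; // select the nibble
    q &= (1 << bit_rate) - 1;
    out.push_back(scale * q + bias); // affine dequantization
  }
  // Dequantized row: -0.5, 0.0, 0.5, 1.0
  assert(out.front() == -0.5f && out.back() == 1.0f);
  return 0;
}
```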

/// @ingroup quantize-data-cpu
///
Tensor fusednbitrowwise_to_half_cpu(
@@ -465,6 +518,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
"FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf(Tensor input, int bit_rate) -> Tensor");
m.def(
"FusedNBitRowwiseQuantizedSBHalfToFloat(Tensor input, int bit_rate) -> Tensor");
m.def(
"FusedNBitRowwiseQuantizedSBHalfFrontToFloat(Tensor input, int bit_rate) -> Tensor");
m.def(
"FusedNBitRowwiseQuantizedSBHalfToHalf(Tensor input, int bit_rate) -> Tensor");
m.def(
@@ -484,6 +539,12 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.def("dequantize_mx_cuda(Tensor input, int mx_group_size) -> Tensor");
}

TORCH_LIBRARY_IMPL(fbgemm, QuantizedCPU, m) {
DISPATCH_TO_QUANTIZED_CPU(
"FusedNBitRowwiseQuantizedSBHalfFrontToFloat",
fbgemm_gpu::fusednbitrowwise_sbfront_to_float_cpu);
}
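A hypothetical caller-side sketch (not part of this commit) of reaching the new kernel through the PyTorch dispatcher. The helper name is made up; `qinput` is assumed to be a 2D CPU tensor with torch.quint4x2 or torch.quint2x4 dtype produced elsewhere, whose quantized dtype carries the QuantizedCPU dispatch key and therefore routes to fusednbitrowwise_sbfront_to_float_cpu.

```cpp
#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>

// Hypothetical helper: resolve the registered schema once, then call it.
at::Tensor dequantize_nbit_front(const at::Tensor& qinput, int64_t bit_rate) {
  static auto op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow(
              "fbgemm::FusedNBitRowwiseQuantizedSBHalfFrontToFloat", "")
          .typed<at::Tensor(const at::Tensor&, int64_t)>();
  return op.call(qinput, bit_rate);
}
```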

TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
DISPATCH_TO_CPU(
"FloatToFused8BitRowwiseQuantized",
3 changes: 2 additions & 1 deletion include/fbgemm/QuantUtils.h
@@ -366,7 +366,8 @@ FBGEMM_API void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef(
const uint8_t* input,
size_t input_rows,
int input_columns,
OutputType* output);
OutputType* output,
bool scale_bias_last = true);

/**
* Same as Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf but unoptimized.
14 changes: 10 additions & 4 deletions src/QuantUtils.cc
@@ -729,7 +729,8 @@ void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef(
const uint8_t* input,
size_t input_rows,
int input_columns,
OutputType* output) {
OutputType* output,
bool scale_bias_last) {
static_assert(
std::is_same<OutputType, float>() || std::is_same<OutputType, float16>(),
"Only float and float16 types are allowed.");
Expand All @@ -742,13 +743,17 @@ void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef(
const std::uint8_t* input_row = input + row * input_columns;
const float16* input_row_scale_bias = reinterpret_cast<const float16*>(
input_row +
(output_columns + num_elem_per_byte - 1) / num_elem_per_byte);
(scale_bias_last
? (output_columns + num_elem_per_byte - 1) / num_elem_per_byte
: 0));
float scale = cpu_half2float(input_row_scale_bias[0]);
float bias = cpu_half2float(input_row_scale_bias[1]);
const std::uint8_t* nums =
(scale_bias_last) ? input_row : input_row + 2 * sizeof(float16);
OutputType* output_row = output + row * output_columns;

for (int64_t col = 0; col < output_columns; ++col) {
std::uint8_t quantized = input_row[col / num_elem_per_byte];
std::uint8_t quantized = nums[col / num_elem_per_byte];
quantized >>= (col % num_elem_per_byte) * bit_rate;
quantized &= (1 << bit_rate) - 1;
float output_value = scale * quantized + bias;
@@ -857,7 +862,8 @@ void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf(
const uint8_t* input, \
size_t input_rows, \
int input_columns, \
type* output); \
type* output, \
bool scale_bias_last); \
template FBGEMM_API void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf<type>( \
int bit_rate, \
const uint8_t* input, \
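To summarize the layout handling added to the reference kernel, a small sketch of the per-row byte offsets it reads in the two modes, with illustrative sizes (bit_rate = 4, six output columns, and a 4-byte fp16 scale/bias pair):

```cpp
#include <cassert>

int main() {
  const int bit_rate = 4;
  const int num_elem_per_byte = 8 / bit_rate;
  const int output_columns = 6;
  const int data_bytes =
      (output_columns + num_elem_per_byte - 1) / num_elem_per_byte; // 3
  const int header_bytes = 2 * 2; // two fp16 values

  // scale_bias_last == true  : [ data (3 B) | scale | bias ]
  // scale_bias_last == false : [ scale | bias | data (3 B) ]
  const int scale_bias_offset_last = data_bytes; // header read after the data
  const int scale_bias_offset_front = 0;         // header read at the start
  const int nums_offset_last = 0;                // data starts the row
  const int nums_offset_front = header_bytes;    // data follows the header

  assert(scale_bias_offset_last == 3 && nums_offset_front == 4);
  assert(scale_bias_offset_front == 0 && nums_offset_last == 0);
  return 0;
}
```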
