Add support for int64_t indices in TBE inference [1/N]
Summary: - Add support for int64_t indices in TBE inference [1/N]

Differential Revision: D61813383
Benson Ma authored and facebook-github-bot committed Sep 30, 2024
1 parent 7a881f2 commit dfbe4cb
Showing 2 changed files with 207 additions and 66 deletions.
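
At a high level, the diff splits the host entry point into a templated `..._cuda_impl<index_t>` function plus a thin untemplated wrapper that dispatches on the scalar type of `indices` via `AT_DISPATCH_INDEX_TYPES`, and adds a `TENSOR_SCALAR_TYPE_IS_ONE_OF` check so that only int32_t and int64_t indices are accepted. A minimal sketch of that dispatch pattern is below; `forward` and `forward_impl` are illustrative placeholder names, not the generated function names:

#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

// Templated implementation: index_t is either int32_t or int64_t.
template <typename index_t>
at::Tensor forward_impl(at::Tensor indices) {
  // ... kernels would read indices through index_t-typed accessors ...
  return at::empty({0}, indices.options());
}

// Untemplated wrapper: pick the instantiation from the runtime dtype of indices.
at::Tensor forward(at::Tensor indices) {
  at::Tensor output;
  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "forward", [&] {
    // The macro defines the alias index_t (int32_t for Int, int64_t for Long).
    output = forward_impl<index_t>(indices);
  });
  return output;
}
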
@@ -7,7 +7,7 @@
*/

// clang-format off
{% set wdesc = "weighted" if weighted else "unweighted" %}
{%- set wdesc = "weighted" if weighted else "unweighted" %}
#include "fbgemm_gpu/embedding_forward_template_helpers.cuh"
#include "fbgemm_gpu/utils/tensor_accessor.h"

@@ -22,7 +22,7 @@ namespace nbit {
`Tensor int_nbit_split_embedding*_codegen_forward_*_cuda(...)` later in the
same generated source file.
*/
{% for emb_weight_type in ["FP32", "FP16", "FP8", "INT8", "INT4", "INT2"] %}
{%- for emb_weight_type in ["FP32", "FP16", "FP8", "INT8", "INT4", "INT2"] %}
template<typename index_t, typename output_t, size_t OutputRowsPerThread, size_t WarpsPerBlock, size_t InputRowsInFlight, size_t MinNum128BRows, size_t MaxNum128BRows, bool DeviceOnly>
__launch_bounds__(WarpsPerBlock * kWarpSize)
__global__ void {{ type_map[emb_weight_type].enum_name }}_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{{ wdesc }}_kernel_small_L(
@@ -31,30 +31,30 @@ __global__ void {{ type_map[emb_weight_type].enum_name }}_split_embedding{{ "_no
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> weights_placements,
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> weights_offsets,
const pta::PackedTensorAccessor32<uint8_t, 1, at::RestrictPtrTraits> weights_tys,
{% if not nobag %}
{%- if not nobag %}
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> D_offsets,
{% else %}
{%- else %}
const int64_t D,
{% endif %}
{%- endif %}
FixedDivisor fd_B, // FixedDivisor(div_round_up(B, OutputRowsPerThread))
const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> indices,
const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> offsets,
{% if not nobag %}
{%- if not nobag %}
const int64_t pooling_mode,
{% endif %}
{%- endif %}
const int64_t row_alignment,
{% if weighted %}
{%- if weighted %}
pta::PackedTensorAccessor32<float, 1, at::RestrictPtrTraits> indice_weights,
{% endif %}
{% if type_map[emb_weight_type].enum_name == "FP8" %}
{%- endif %}
{%- if type_map[emb_weight_type].enum_name == "FP8" %}
const int fp8_exponent_bits,
const int fp8_exponent_bias,
{% endif %}
{%- endif %}
pta::PackedTensorAccessor32<output_t, 2, at::RestrictPtrTraits> output, // [B][total_D],
const pta::PackedTensorAccessor64<uint8_t, 2, at::RestrictPtrTraits> lxu_cache_weights,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> lxu_cache_locations
);
{% endfor %} // for emb_weight_type in ["FP32", "FP16", "FP8", "INT8", "INT4", "INT2"]
{%- endfor %} // for emb_weight_type in ["FP32", "FP16", "FP8", "INT8", "INT4", "INT2"]

}

@@ -107,58 +107,7 @@ __global__ void {{ type_map[emb_weight_type].enum_name }}_split_embedding{{ "_no
C10_CUDA_KERNEL_LAUNCH_CHECK(); \
{%- endmacro %}
Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{{ wdesc }}_cuda(
Tensor dev_weights,
Tensor uvm_weights,
Tensor weights_placements,
Tensor weights_offsets,
Tensor weights_tys,
{% if not nobag %}
Tensor D_offsets,
const int64_t total_D,
{% else %}
const int64_t D,
{% endif %}
const int64_t max_int2_D,
const int64_t max_int4_D,
const int64_t max_int8_D,
const int64_t max_float16_D,
const int64_t max_float32_D,
Tensor indices,
Tensor offsets,
{% if not nobag %}
const int64_t pooling_mode,
{% endif %}
const int64_t row_alignment,
{% if weighted %}
Tensor indice_weights,
{% endif %}
const int64_t output_dtype,
Tensor lxu_cache_weights,
Tensor lxu_cache_locations,
const int64_t max_float8_D,
const int64_t fp8_exponent_bits,
const int64_t fp8_exponent_bias
) {
TENSOR_ON_CUDA_GPU(dev_weights);
TENSORS_ON_SAME_DEVICE(uvm_weights, dev_weights);
TENSORS_ON_SAME_DEVICE(weights_placements, dev_weights);
TENSORS_ON_SAME_DEVICE(weights_offsets, dev_weights);
TENSORS_ON_SAME_DEVICE(weights_tys, dev_weights);
{% if not nobag %}
TENSORS_ON_SAME_DEVICE(D_offsets, dev_weights);
{% endif %}
TENSORS_ON_SAME_DEVICE(indices, dev_weights);
TENSORS_ON_SAME_DEVICE(offsets, dev_weights);
{% if weighted %}
TENSORS_EMPTY_OR_ON_SAME_DEVICE(indice_weights, dev_weights);
{% endif %}
TENSORS_EMPTY_OR_ON_SAME_DEVICE(lxu_cache_weights, dev_weights);
TENSORS_EMPTY_OR_ON_SAME_DEVICE(lxu_cache_locations, dev_weights);
CUDA_DEVICE_GUARD(dev_weights);
{%- macro construct_and_return_output_tensor() %}
// kernels assume indices are contiguous.
indices = indices.contiguous();
@@ -180,8 +129,10 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
TORCH_CHECK(D > 0);
{%- endif %}
// Construct output tensor
Tensor output;
const int kINT8QparamsBytes = 8;
SparseType o_dtype = static_cast<SparseType>(output_dtype);
TORCH_CHECK(o_dtype == SparseType::FP32 || o_dtype == SparseType::FP16 || o_dtype == SparseType::BF16 || o_dtype == SparseType::INT8);
@@ -216,11 +167,63 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
if (B == 0 || indices.numel() == 0) {
return output;
}
{%- endmacro %}
using index_t = int32_t;
template <typename index_t>
Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{{ wdesc }}_cuda_impl(
Tensor dev_weights,
Tensor uvm_weights,
Tensor weights_placements,
Tensor weights_offsets,
Tensor weights_tys,
{%- if not nobag %}
Tensor D_offsets,
const int64_t total_D,
{%- else %}
const int64_t D,
{%- endif %}
const int64_t max_int2_D,
const int64_t max_int4_D,
const int64_t max_int8_D,
const int64_t max_float16_D,
const int64_t max_float32_D,
Tensor indices,
Tensor offsets,
{%- if not nobag %}
const int64_t pooling_mode,
{%- endif %}
const int64_t row_alignment,
{%- if weighted %}
Tensor indice_weights,
{%- endif %}
const int64_t output_dtype,
Tensor lxu_cache_weights,
Tensor lxu_cache_locations,
const int64_t max_float8_D,
const int64_t fp8_exponent_bits,
const int64_t fp8_exponent_bias
) {
TENSOR_ON_CUDA_GPU(dev_weights);
TENSORS_ON_SAME_DEVICE(uvm_weights, dev_weights);
TENSORS_ON_SAME_DEVICE(weights_placements, dev_weights);
TENSORS_ON_SAME_DEVICE(weights_offsets, dev_weights);
TENSORS_ON_SAME_DEVICE(weights_tys, dev_weights);
{%- if not nobag %}
TENSORS_ON_SAME_DEVICE(D_offsets, dev_weights);
{%- endif %}
TENSORS_ON_SAME_DEVICE(indices, dev_weights);
TENSORS_ON_SAME_DEVICE(offsets, dev_weights);
{%- if weighted %}
TENSORS_EMPTY_OR_ON_SAME_DEVICE(indice_weights, dev_weights);
{%- endif %}
TENSORS_EMPTY_OR_ON_SAME_DEVICE(lxu_cache_weights, dev_weights);
TENSORS_EMPTY_OR_ON_SAME_DEVICE(lxu_cache_locations, dev_weights);
constexpr int32_t kWarpsPerBlock = 4;
CUDA_DEVICE_GUARD(dev_weights);
{{- construct_and_return_output_tensor() }}
constexpr int32_t kWarpsPerBlock = 4;
const auto device_only = lxu_cache_weights.numel() == 0 && uvm_weights.numel() == 0;
#define Y(...) \
if (device_only) { \
@@ -397,6 +400,104 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
}));
#undef X
return output;
}
Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{{ wdesc }}_cuda(
Tensor dev_weights,
Tensor uvm_weights,
Tensor weights_placements,
Tensor weights_offsets,
Tensor weights_tys,
{%- if not nobag %}
Tensor D_offsets,
const int64_t total_D,
{%- else %}
const int64_t D,
{%- endif %}
const int64_t max_int2_D,
const int64_t max_int4_D,
const int64_t max_int8_D,
const int64_t max_float16_D,
const int64_t max_float32_D,
Tensor indices,
Tensor offsets,
{%- if not nobag %}
const int64_t pooling_mode,
{%- endif %}
const int64_t row_alignment,
{%- if weighted %}
Tensor indice_weights,
{%- endif %}
const int64_t output_dtype,
Tensor lxu_cache_weights,
Tensor lxu_cache_locations,
const int64_t max_float8_D,
const int64_t fp8_exponent_bits,
const int64_t fp8_exponent_bias
) {
// All argument tensors need to be on the same CUDA device
TENSOR_ON_CUDA_GPU(dev_weights);
TENSORS_ON_SAME_DEVICE(uvm_weights, dev_weights);
TENSORS_ON_SAME_DEVICE(weights_placements, dev_weights);
TENSORS_ON_SAME_DEVICE(weights_offsets, dev_weights);
TENSORS_ON_SAME_DEVICE(weights_tys, dev_weights);
{%- if not nobag %}
TENSORS_ON_SAME_DEVICE(D_offsets, dev_weights);
{%- endif %}
TENSORS_ON_SAME_DEVICE(indices, dev_weights);
TENSORS_ON_SAME_DEVICE(offsets, dev_weights);
{%- if weighted %}
TENSORS_EMPTY_OR_ON_SAME_DEVICE(indice_weights, dev_weights);
{%- endif %}
TENSORS_EMPTY_OR_ON_SAME_DEVICE(lxu_cache_weights, dev_weights);
TENSORS_EMPTY_OR_ON_SAME_DEVICE(lxu_cache_locations, dev_weights);
// indices and offsets need to have the same scalar type
TENSORS_HAVE_SAME_TYPE(indices, offsets);
// Only int32_t and int64_t indices are supported at the moment
TENSOR_SCALAR_TYPE_IS_ONE_OF(indices, at::ScalarType::Long, at::ScalarType::Int);
CUDA_DEVICE_GUARD(dev_weights);
// Create output tensor ref
Tensor output;
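// Dispatch on the scalar type of `indices` (Int or Long); the macro defines index_t inside the lambda.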
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "{{ 'int_nbit_split_embedding' + ('_nobag' if nobag else '') + '_codegen_forward_' + wdesc + '_cuda' }}", [&] {
output = int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{{ wdesc }}_cuda_impl<index_t>(
dev_weights,
uvm_weights,
weights_placements,
weights_offsets,
weights_tys,
{%- if not nobag %}
D_offsets,
total_D,
{%- else %}
D,
{%- endif %}
max_int2_D,
max_int4_D,
max_int8_D,
max_float16_D,
max_float32_D,
indices,
offsets,
{%- if not nobag %}
pooling_mode,
{%- endif %}
row_alignment,
{%- if weighted %}
indice_weights,
{%- endif %}
output_dtype,
lxu_cache_weights,
lxu_cache_locations,
max_float8_D,
fp8_exponent_bits,
fp8_exponent_bias);
});
return output;
}
40 changes: 40 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/utils/tensor_utils.h
@@ -299,3 +299,43 @@ inline at::Tensor aligned_grad_output_tensor_for_cuda_backwards(
}
return aligned_grad_output;
}

template <typename... ScalarTypes>
std::string tensor_scalar_type_is_one_of(
const at::Tensor& ten,
const ScalarTypes&... ttypes) {
auto has_match = false;

// Check whether the tensor's scalar type matches any one of the given
// scalar types.
(
[&](const auto& ttype) {
if (ten.scalar_type() == ttype) {
has_match = true;
}
}(ttypes),
...);

if (has_match) {
return "";
}

std::string msg = "Tensor's scalar type (";
msg.append(toString(ten.scalar_type()));
msg.append(") did not match any one of the following types: [");
(
[&](const auto& ttype) {
msg.append(toString(ttype));
msg.append(", ");
}(ttypes),
...);

msg.append("]");
return msg;
}

#define TENSOR_SCALAR_TYPE_IS_ONE_OF(...) \
do { \
const auto has_match = tensor_scalar_type_is_one_of(__VA_ARGS__); \
TORCH_CHECK(has_match.empty(), has_match); \
} while (false)
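
For reference, a minimal, hypothetical call site for the new check, mirroring its use in the generated forward wrapper above (`check_index_dtype` is an illustrative name, not part of the diff):

// Hypothetical helper: accept only int32_t or int64_t indices,
// as done in the generated int_nbit_split_embedding forward wrapper.
void check_index_dtype(const at::Tensor& indices) {
  TENSOR_SCALAR_TYPE_IS_ONE_OF(indices, at::ScalarType::Long, at::ScalarType::Int);
}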
