c10::optional -> std::optional in deeplearning/fbgemm/fbgemm_gpu/codegen/inference/embedding_forward_quantized_host.cpp +51 (#2623)

Summary:
Pull Request resolved: #2623

Generated with
```
fbgs -f '.*\.(cpp|cxx|cc|h|hpp|cu|cuh)$' c10::optional -l | perl -pe 's/^fbsource.fbcode.//' | grep -v executorch | xargs -n 50 perl -pi -e 's/c10::optional/std::optional/g'
```
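
For repositories without Meta's internal fbgs code search, a rough open-source equivalent of the same codemod might look like the sketch below (the ripgrep invocation and glob are illustrative assumptions, not part of the original command):

```
# Hypothetical equivalent of the fbgs-based codemod, using ripgrep (rg):
# list matching C/C++/CUDA sources, skip the excluded executorch tree,
# then rewrite c10::optional to std::optional in place, 50 files per batch.
rg -l 'c10::optional' -g '*.{cpp,cxx,cc,h,hpp,cu,cuh}' \
  | grep -v executorch \
  | xargs -n 50 perl -pi -e 's/c10::optional/std::optional/g'
```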

 - If you approve of this diff, please use the "Accept & Ship" button :-)

(51 files modified.)
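
The blanket regex is safe here because c10::optional in PyTorch of this vintage is an alias for std::optional, so both spellings name the same type. As an illustrative sanity check after the codemod (assumed, not part of the original workflow):

```
# List any c10::optional left behind, ignoring the intentionally
# excluded executorch tree; no output means the codemod is clean.
rg 'c10::optional' -g '*.{cpp,cxx,cc,h,hpp,cu,cuh}' | grep -v executorch
```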

Reviewed By: palmje

Differential Revision: D57631089
r-barnes authored and facebook-github-bot committed May 22, 2024
1 parent 3ff2f66 commit 21a6dc7
Showing 49 changed files with 366 additions and 366 deletions.
44 changes: 22 additions & 22 deletions fbgemm_gpu/codegen/inference/embedding_forward_quantized_host.cpp
@@ -275,14 +275,14 @@ Tensor int_nbit_split_embedding_codegen_lookup_function(
Tensor indices,
Tensor offsets,
int64_t pooling_mode,
-c10::optional<Tensor> indice_weights,
+std::optional<Tensor> indice_weights,
int64_t output_dtype,
-c10::optional<Tensor> lxu_cache_weights,
-c10::optional<Tensor> lxu_cache_locations,
-c10::optional<int64_t> row_alignment,
-c10::optional<int64_t> max_float8_D,
-c10::optional<int64_t> fp8_exponent_bits,
-c10::optional<int64_t> fp8_exponent_bias) {
+std::optional<Tensor> lxu_cache_weights,
+std::optional<Tensor> lxu_cache_locations,
+std::optional<int64_t> row_alignment,
+std::optional<int64_t> max_float8_D,
+std::optional<int64_t> fp8_exponent_bits,
+std::optional<int64_t> fp8_exponent_bias) {
if (static_cast<PoolingMode>(pooling_mode) == PoolingMode::NONE) {
std::vector<int64_t> max_D_list{
max_int2_D,
@@ -390,29 +390,29 @@ Tensor int_nbit_split_embedding_uvm_caching_codegen_lookup_function(
Tensor indices,
Tensor offsets,
int64_t pooling_mode,
-c10::optional<Tensor> indice_weights,
+std::optional<Tensor> indice_weights,
int64_t output_dtype,
-c10::optional<Tensor> lxu_cache_weights,
-c10::optional<Tensor> lxu_cache_locations,
-c10::optional<int64_t> row_alignment,
-c10::optional<int64_t> max_float8_D,
-c10::optional<int64_t> fp8_exponent_bits,
-c10::optional<int64_t> fp8_exponent_bias,
+std::optional<Tensor> lxu_cache_weights,
+std::optional<Tensor> lxu_cache_locations,
+std::optional<int64_t> row_alignment,
+std::optional<int64_t> max_float8_D,
+std::optional<int64_t> fp8_exponent_bits,
+std::optional<int64_t> fp8_exponent_bias,
// Additional args for UVM_CACHING.
// cache_hash_size_cumsum: cumulative sum of # embedding rows of all the
// tables. 1D tensor, dtype=int64.
-c10::optional<Tensor> cache_hash_size_cumsum,
+std::optional<Tensor> cache_hash_size_cumsum,
// total_cache_hash_size: sum of # embedding rows of all the tables.
-c10::optional<int64_t> total_cache_hash_size,
+std::optional<int64_t> total_cache_hash_size,
// cache_index_table_map: (linearized) index to table number map.
// 1D tensor, dtype=int32.
-c10::optional<Tensor> cache_index_table_map,
+std::optional<Tensor> cache_index_table_map,
// lxu_cache_state: Cache state (cached index, or invalid).
// 2D tensor: # sets x assoc. dtype=int64.
-c10::optional<Tensor> lxu_cache_state,
+std::optional<Tensor> lxu_cache_state,
// lxu_state: meta info for replacement (time stamp for LRU).
// 2D tensor: # sets x assoc. dtype=int64.
-c10::optional<Tensor> lxu_state) {
+std::optional<Tensor> lxu_state) {
// This function does prefetch() and foward() methods in
// IntNBitTableBatchedEmbeddingBagsCodegen, but run them in sequence.
// Prefetching of multiple batches of requests is not yet supported.
@@ -435,7 +435,7 @@ Tensor int_nbit_split_embedding_uvm_caching_codegen_lookup_function(
cache_hash_size_cumsum.value(),
indices,
offsets,
-/*B_offsets=*/c10::optional<Tensor>(),
+/*B_offsets=*/std::optional<Tensor>(),
/*max_B=*/-1,
/*indices_base_offset=*/0);

@@ -506,8 +506,8 @@ Tensor int_nbit_split_embedding_uvm_caching_codegen_lookup_function(
total_cache_hash_size.value(),
gather_uvm_stats,
uvm_cache_stats,
-c10::optional<Tensor>(), // num_uniq_cache_indices
-c10::optional<Tensor>() // lxu_cache_locations_output
+std::optional<Tensor>(), // num_uniq_cache_indices
+std::optional<Tensor>() // lxu_cache_locations_output
);

#ifdef FBCODE_CAFFE2
@@ -93,16 +93,16 @@ Tensor int_nbit_split_embedding_codegen_lookup_function_cpu(
Tensor indices,
Tensor offsets,
int64_t pooling_mode,
-c10::optional<Tensor> indice_weights,
+std::optional<Tensor> indice_weights,
int64_t output_dtype,
-c10::optional<Tensor>
+std::optional<Tensor>
lxu_cache_weights, // Not used, to match cache interface for CUDA op
-c10::optional<Tensor>
+std::optional<Tensor>
lxu_cache_locations, // Not used, to match cache interface for CUDA op
-c10::optional<int64_t> row_alignment,
-c10::optional<int64_t> max_float8_D,
-c10::optional<int64_t> fp8_exponent_bits,
-c10::optional<int64_t> fp8_exponent_bias) {
+std::optional<int64_t> row_alignment,
+std::optional<int64_t> max_float8_D,
+std::optional<int64_t> fp8_exponent_bits,
+std::optional<int64_t> fp8_exponent_bias) {
if (static_cast<PoolingMode>(pooling_mode) == PoolingMode::NONE) {
std::vector<int64_t> max_D_list{
max_int2_D,
@@ -179,20 +179,20 @@ Tensor int_nbit_split_embedding_uvm_caching_codegen_lookup_function_cpu(
Tensor indices,
Tensor offsets,
int64_t pooling_mode,
-c10::optional<Tensor> indice_weights,
+std::optional<Tensor> indice_weights,
int64_t output_dtype,
-c10::optional<Tensor> lxu_cache_weights,
-c10::optional<Tensor> lxu_cache_locations,
-c10::optional<int64_t> row_alignment,
-c10::optional<int64_t> max_float8_D,
-c10::optional<int64_t> fp8_exponent_bits,
-c10::optional<int64_t> fp8_exponent_bias,
+std::optional<Tensor> lxu_cache_weights,
+std::optional<Tensor> lxu_cache_locations,
+std::optional<int64_t> row_alignment,
+std::optional<int64_t> max_float8_D,
+std::optional<int64_t> fp8_exponent_bits,
+std::optional<int64_t> fp8_exponent_bias,
// Additinal args for uvm_caching version.
-c10::optional<Tensor> cache_hash_size_cumsum [[maybe_unused]],
-c10::optional<int64_t> total_cache_hash_size [[maybe_unused]],
-c10::optional<Tensor> cache_index_table_map [[maybe_unused]],
-c10::optional<Tensor> lxu_cache_state [[maybe_unused]],
-c10::optional<Tensor> lxu_state [[maybe_unused]]) {
+std::optional<Tensor> cache_hash_size_cumsum [[maybe_unused]],
+std::optional<int64_t> total_cache_hash_size [[maybe_unused]],
+std::optional<Tensor> cache_index_table_map [[maybe_unused]],
+std::optional<Tensor> lxu_cache_state [[maybe_unused]],
+std::optional<Tensor> lxu_state [[maybe_unused]]) {
LOG(WARNING)
<< "int_nbit_split_embedding_uvm_caching_codegen_lookup_function shouldn't be called for CPU; it is only for GPU.";
return int_nbit_split_embedding_codegen_lookup_function_cpu(
@@ -99,8 +99,8 @@ class SplitLookupFunction_Dense_Op
Tensor indices,
Tensor offsets,
int64_t pooling_mode,
-c10::optional<Tensor> indice_weights,
-c10::optional<Tensor> feature_requires_grad) {
+std::optional<Tensor> indice_weights,
+std::optional<Tensor> feature_requires_grad) {
ctx->save_for_backward({
dev_weights,
weights_offsets,
@@ -388,8 +388,8 @@ Tensor split_embedding_codegen_lookup_dense_function(
Tensor indices,
Tensor offsets,
int64_t pooling_mode,
-c10::optional<Tensor> indice_weights,
-c10::optional<Tensor> feature_requires_grad,
+std::optional<Tensor> indice_weights,
+std::optional<Tensor> feature_requires_grad,
int64_t output_dtype = static_cast<int64_t>(SparseType::FP32)) {
if (static_cast<PoolingMode>(pooling_mode) == PoolingMode::NONE) {
return SplitNoBagLookupFunction_Dense_Op::apply(
@@ -48,8 +48,8 @@ class SplitLookupFunction_Dense_Op
Tensor indices,
Tensor offsets,
int64_t pooling_mode,
-c10::optional<Tensor> indice_weights,
-c10::optional<Tensor> feature_requires_grad) {
+std::optional<Tensor> indice_weights,
+std::optional<Tensor> feature_requires_grad) {
Tensor indice_weights_value = indice_weights.value_or(Tensor());
Tensor feature_requires_grad_value =
feature_requires_grad.value_or(Tensor());
@@ -161,8 +161,8 @@ Tensor split_embedding_codegen_lookup_dense_function(
Tensor indices,
Tensor offsets,
int64_t pooling_mode,
-c10::optional<Tensor> indice_weights,
-c10::optional<Tensor> feature_requires_grad,
+std::optional<Tensor> indice_weights,
+std::optional<Tensor> feature_requires_grad,
int64_t /* output_dtype = static_cast<int64_t>(SparseType::FP32) */) {
return SplitLookupFunction_Dense_Op::apply(
host_weights,
@@ -57,8 +57,8 @@ class SplitLookupFunction_{{ optimizer }}_Op : public torch::autograd::Function<
Tensor indices,
Tensor offsets,
int64_t pooling_mode,
-c10::optional<Tensor> indice_weights,
-c10::optional<Tensor> feature_requires_grad,
+std::optional<Tensor> indice_weights,
+std::optional<Tensor> feature_requires_grad,
bool gradient_clipping,
double max_gradient,
bool stochastic_rounding,
@@ -208,8 +208,8 @@ Tensor split_embedding_codegen_lookup_{{ optimizer }}_function_cpu(
Tensor indices,
Tensor offsets,
int64_t pooling_mode,
-c10::optional<Tensor> indice_weights,
-c10::optional<Tensor> feature_requires_grad,
+std::optional<Tensor> indice_weights,
+std::optional<Tensor> feature_requires_grad,
bool gradient_clipping,
double max_gradient,
bool stochastic_rounding,
@@ -470,20 +470,20 @@ class {{ autograd_func }} :
const Tensor& offsets,
{%- if not nobag %}
const int64_t pooling_mode,
-const c10::optional<Tensor>& indice_weights,
-const c10::optional<Tensor>& feature_requires_grad,
+const std::optional<Tensor>& indice_weights,
+const std::optional<Tensor>& feature_requires_grad,
{%- endif %}
const Tensor& lxu_cache_locations,
-c10::optional<Tensor> uvm_cache_stats,
+std::optional<Tensor> uvm_cache_stats,
{%- if optimizer != "none" %}
const bool gradient_clipping,
const double max_gradient,
const bool stochastic_rounding,
{%- endif %}
{%- if vbe %}
-const c10::optional<Tensor>& B_offsets,
-const c10::optional<Tensor>& vbe_output_offsets_feature_rank,
-const c10::optional<Tensor>& vbe_B_offsets_rank_per_feature,
+const std::optional<Tensor>& B_offsets,
+const std::optional<Tensor>& vbe_output_offsets_feature_rank,
+const std::optional<Tensor>& vbe_B_offsets_rank_per_feature,
const c10::SymInt max_B,
const c10::SymInt max_B_feature_rank,
const c10::SymInt vbe_output_size,
@@ -493,7 +493,7 @@
const bool use_homogeneous_placements,
{%- if is_gwd %}
{%- if "prev_iter_dev" not in args.split_function_arg_names %}
-const c10::optional<Tensor>& prev_iter_dev,
+const std::optional<Tensor>& prev_iter_dev,
{%- endif %}
{%- if "iter" not in args.split_function_arg_names %}
const int64_t iter,
@@ -790,8 +790,8 @@ Tensor split_embedding_codegen_lookup_{{ optimizer }}_function(
const Tensor& indices,
const Tensor& offsets,
const int64_t pooling_mode,
-const c10::optional<Tensor>& indice_weights,
-const c10::optional<Tensor>& feature_requires_grad,
+const std::optional<Tensor>& indice_weights,
+const std::optional<Tensor>& feature_requires_grad,
const Tensor& lxu_cache_locations,
{%- if optimizer != "none" %}
const bool gradient_clipping,
@@ -800,18 +800,18 @@
{%- endif %}
{{ args.split_function_args | join(", ") }},
const int64_t output_dtype = static_cast<int64_t>(SparseType::FP32),
-const c10::optional<Tensor>& B_offsets = c10::nullopt,
-const c10::optional<Tensor>& vbe_output_offsets_feature_rank = c10::nullopt,
-const c10::optional<Tensor>& vbe_B_offsets_rank_per_feature = c10::nullopt,
+const std::optional<Tensor>& B_offsets = c10::nullopt,
+const std::optional<Tensor>& vbe_output_offsets_feature_rank = c10::nullopt,
+const std::optional<Tensor>& vbe_B_offsets_rank_per_feature = c10::nullopt,
const c10::SymInt max_B = -1,
const c10::SymInt max_B_feature_rank = -1,
const c10::SymInt vbe_output_size = -1,
const bool is_experimental = false,
const bool use_uniq_cache_locations_bwd = false,
const bool use_homogeneous_placements = false,
-const c10::optional<Tensor>& uvm_cache_stats = c10::nullopt,
+const std::optional<Tensor>& uvm_cache_stats = c10::nullopt,
{%- if "prev_iter_dev" not in args.split_function_arg_names %}
-const c10::optional<Tensor>& prev_iter_dev = c10::nullopt,
+const std::optional<Tensor>& prev_iter_dev = c10::nullopt,
{%- endif %}
{%- if "iter" not in args.split_function_arg_names %}
const int64_t iter = 0,
@@ -661,13 +661,13 @@ Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_e
indices,
{{ "offsets" if not is_index_select else "Tensor()" }},
{{ "true" if nobag else "false" }},
{{ "c10::optional<Tensor>(vbe_b_t_map)" if vbe else "c10::optional<Tensor>()" }},
{{ "std::optional<Tensor>(vbe_b_t_map)" if vbe else "std::optional<Tensor>()" }},
info_B_num_bits,
info_B_mask,
total_unique_indices,
{%- if is_index_select %}
true, // is_index_select
-c10::optional<Tensor>(total_L_offsets),
+std::optional<Tensor>(total_L_offsets),
fixed_L_per_warp,
num_warps_per_feature
{%- else %}
@@ -87,20 +87,20 @@ class {{ autograd_func }} :
const Tensor& offsets,
{%- if not nobag %}
const int64_t pooling_mode,
-const c10::optional<Tensor>& indice_weights,
-const c10::optional<Tensor>& feature_requires_grad,
+const std::optional<Tensor>& indice_weights,
+const std::optional<Tensor>& feature_requires_grad,
{%- endif %}
const Tensor& lxu_cache_locations,
-c10::optional<Tensor> uvm_cache_stats,
+std::optional<Tensor> uvm_cache_stats,
{%- if optimizer != "none" %}
const bool gradient_clipping,
const double max_gradient,
const bool stochastic_rounding,
{%- endif %}
{%- if vbe %}
-const c10::optional<Tensor>& B_offsets,
-const c10::optional<Tensor>& vbe_output_offsets_feature_rank,
-const c10::optional<Tensor>& vbe_B_offsets_rank_per_feature,
+const std::optional<Tensor>& B_offsets,
+const std::optional<Tensor>& vbe_output_offsets_feature_rank,
+const std::optional<Tensor>& vbe_B_offsets_rank_per_feature,
const c10::SymInt max_B,
const c10::SymInt max_B_feature_rank,
const c10::SymInt vbe_output_size,
@@ -773,8 +773,8 @@ Tensor split_embedding_codegen_lookup_{{ optimizer }}_function_pt2(
const Tensor& indices,
const Tensor& offsets,
const int64_t pooling_mode,
-const c10::optional<Tensor>& indice_weights,
-const c10::optional<Tensor>& feature_requires_grad,
+const std::optional<Tensor>& indice_weights,
+const std::optional<Tensor>& feature_requires_grad,
const Tensor& lxu_cache_locations,
{%- if optimizer != "none" %}
const bool gradient_clipping,
@@ -783,16 +783,16 @@
{%- endif %}
{{ args_pt2.split_function_args | join(", ") }},
const int64_t output_dtype = static_cast<int64_t>(SparseType::FP32),
-const c10::optional<Tensor>& B_offsets = c10::optional<Tensor>(),
-const c10::optional<Tensor>& vbe_output_offsets_feature_rank = c10::optional<Tensor>(),
-const c10::optional<Tensor>& vbe_B_offsets_rank_per_feature = c10::optional<Tensor>(),
+const std::optional<Tensor>& B_offsets = std::optional<Tensor>(),
+const std::optional<Tensor>& vbe_output_offsets_feature_rank = std::optional<Tensor>(),
+const std::optional<Tensor>& vbe_B_offsets_rank_per_feature = std::optional<Tensor>(),
const c10::SymInt max_B = -1,
const c10::SymInt max_B_feature_rank = -1,
const c10::SymInt vbe_output_size = -1,
const bool is_experimental = false,
const bool use_uniq_cache_locations_bwd = false,
const bool use_homogeneous_placements = false,
-const c10::optional<Tensor>& uvm_cache_stats = c10::optional<Tensor>()) {
+const std::optional<Tensor>& uvm_cache_stats = std::optional<Tensor>()) {
{%- for vbe in ([True, False] if has_vbe_support else [False]) %}
{%- if has_vbe_support %}
{%- if vbe %}
4 changes: 2 additions & 2 deletions fbgemm_gpu/codegen/utils/embedding_bounds_check.cu
@@ -184,8 +184,8 @@ void bounds_check_indices_cuda(
Tensor& offsets,
int64_t bounds_check_mode_,
Tensor& warning,
-const c10::optional<Tensor>& weights,
-const c10::optional<Tensor>& B_offsets,
+const std::optional<Tensor>& weights,
+const std::optional<Tensor>& B_offsets,
const int64_t max_B) {
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
rows_per_table, indices, offsets, warning, weights, B_offsets);
4 changes: 2 additions & 2 deletions fbgemm_gpu/codegen/utils/embedding_bounds_check_host.cpp
@@ -25,8 +25,8 @@ void bounds_check_indices_cuda(
Tensor& offsets,
int64_t bounds_check_mode,
Tensor& warning,
-const c10::optional<Tensor>& weights,
-const c10::optional<Tensor>& B_ofsets,
+const std::optional<Tensor>& weights,
+const std::optional<Tensor>& B_ofsets,
const int64_t max_B);

// Deprecated for fb namespace! Please use fbgemm namespace instead!
4 changes: 2 additions & 2 deletions fbgemm_gpu/codegen/utils/embedding_bounds_check_host_cpu.cpp
@@ -45,8 +45,8 @@ void bounds_check_indices_cpu(
Tensor& offsets,
int64_t bounds_check_mode_,
Tensor& warning,
-const c10::optional<Tensor>& weights,
-const c10::optional<Tensor>& B_offsets,
+const std::optional<Tensor>& weights,
+const std::optional<Tensor>& B_offsets,
const int64_t /*max_B*/) {
TORCH_CHECK(
!B_offsets.has_value(),
@@ -1836,7 +1836,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> gqa_attn_splitk_impl(
const at::Tensor& seq_positions, // [B]
const double qk_scale,
const int64_t split_k,
-const c10::optional<int64_t>& num_groups) {
+const std::optional<int64_t>& num_groups) {
at::OptionalDeviceGuard guard(XQ.device());
TORCH_CHECK(XQ.is_cuda());
TORCH_CHECK(cache_K.is_cuda());
(Diff truncated; the remaining changed files are not shown.)