diff --git a/fbgemm_gpu/codegen/genscript/generate_backward_split.py b/fbgemm_gpu/codegen/genscript/generate_backward_split.py
index 9d7eaa6669..37c791dc1a 100644
--- a/fbgemm_gpu/codegen/genscript/generate_backward_split.py
+++ b/fbgemm_gpu/codegen/genscript/generate_backward_split.py
@@ -55,7 +55,7 @@ def render_backward_templates(
         ):
             if nobag and (weighted or vbe):
                 continue
-            if kwargs.get("dense") and (vbe or ssd):
+            if kwargs.get("dense") and ssd:
                 continue
            if ssd and (vbe or is_gwd):
                 continue
@@ -152,9 +152,10 @@ def generate_backward_split_gpu(**kwargs: Any) -> None:
             **kwargs,
         )

-        # Generate the backward splits (non-dense)
+        # Generate the backward splits
         # We generate only the API to preserve the backward compatibility if
         # has_gpu_support=True
+        # Generate CUDA autograd, PT2 unified autograd, and PT2 backward wrapper
         if not kwargs.get("dense"):
             # Generate CUDA autograd
             template_filepath = (
@@ -197,6 +198,17 @@ def generate_backward_split_gpu(**kwargs: Any) -> None:
                 filename, is_fbcode=args.is_fbcode, ssd=ssd, **kwargs
             )
+        else:
+            template_filepath = (
+                "training/backward/embedding_backward_split_host_template.cpp"
+            )
+            filename = "gen_embedding_backward_split_dense.cpp"
+            CodeTemplate.load(template_filepath).write(
+                filename,
+                is_forward=False,
+                **kwargs,
+            )
+
     @staticmethod
     def generate_backward_split_cpu(**kwargs: Any) -> None:
         """
diff --git a/fbgemm_gpu/codegen/genscript/generate_forward_split.py b/fbgemm_gpu/codegen/genscript/generate_forward_split.py
index bb6488b54a..0f5885b375 100644
--- a/fbgemm_gpu/codegen/genscript/generate_forward_split.py
+++ b/fbgemm_gpu/codegen/genscript/generate_forward_split.py
@@ -39,7 +39,7 @@ def render_forward_templates(
         ):
             if nobag and (weighted or vbe):
                 continue
-            if dense and (vbe or ssd):
+            if dense and ssd:
                 continue
             if ssd and (vbe or is_gwd):
                 continue
@@ -64,6 +64,7 @@
                 is_gwd=is_gwd,
             )

+    @staticmethod
     def generate_pt2_wrappers() -> None:
         # Generate PT2 forward wrapper (CUDA)
         CodeTemplate.load(
diff --git a/fbgemm_gpu/codegen/genscript/optimizers.py b/fbgemm_gpu/codegen/genscript/optimizers.py
index 3b2073692a..f7ce271ca5 100644
--- a/fbgemm_gpu/codegen/genscript/optimizers.py
+++ b/fbgemm_gpu/codegen/genscript/optimizers.py
@@ -40,7 +40,7 @@ def dense() -> Dict[str, Any]:
         ),
         "has_cpu_support": True,
         "has_gpu_support": True,
-        "has_vbe_support": False,
+        "has_vbe_support": True,
         "has_global_weight_decay_support": False,
         "has_ssd_support": False,
     }
diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_dense_host.cpp b/fbgemm_gpu/codegen/training/backward/embedding_backward_dense_host.cpp
deleted file mode 100644
index db72a53d0c..0000000000
--- a/fbgemm_gpu/codegen/training/backward/embedding_backward_dense_host.cpp
+++ /dev/null
@@ -1,433 +0,0 @@
-/*
- * Copyright (c) Meta Platforms, Inc. and affiliates.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
- */
-
-#include
-#include
-#include
-#include
-
-#include "fbgemm_gpu/embedding_common.h"
-#include "fbgemm_gpu/sparse_ops_utils.h"
-
-using Tensor = at::Tensor;
-using namespace fbgemm_gpu;
-
-Tensor dense_embedding_codegen_forward_unweighted_cuda(
-    const Tensor& dev_weights,
-    const Tensor& weights_offsets,
-    const Tensor& D_offsets,
-    const c10::SymInt total_D,
-    const c10::SymInt max_D,
-    const Tensor& indices,
-    const Tensor& offsets,
-    const int64_t pooling_mode,
-    const int64_t output_dtype,
-    const bool is_experimental);
-
-Tensor dense_embedding_codegen_forward_weighted_cuda(
-    const Tensor& dev_weights,
-    const Tensor& weights_offsets,
-    const Tensor& D_offsets,
-    const c10::SymInt total_D,
-    const c10::SymInt max_D,
-    const Tensor& indices,
-    const Tensor& offsets,
-    const int64_t pooling_mode,
-    const Tensor& indice_weights,
-    const int64_t output_dtype,
-    const bool is_experimental);
-
-Tensor dense_embedding_codegen_grad_indice_weights_cuda(
-    const Tensor& grad_output,
-    const Tensor& dev_weights,
-    const Tensor& weights_offsets,
-    const Tensor& D_offsets,
-    const c10::SymInt max_D,
-    const Tensor& indices,
-    const Tensor& offsets,
-    const Tensor& feature_requires_grad);
-
-Tensor split_embedding_backward_codegen_dense_unweighted_exact_cuda(
-    const Tensor& grad_output,
-    const Tensor& dev_weights,
-    const Tensor& weights_offsets,
-    const Tensor& D_offsets,
-    const c10::SymInt max_D,
-    const Tensor& hash_size_cumsum,
-    const int64_t total_hash_size_bits,
-    const Tensor& indices,
-    const Tensor& offsets,
-    const int64_t pooling_mode,
-    const int64_t BT_block_size,
-    const int64_t max_segment_length_per_warp,
-    const double unused);
-
-Tensor split_embedding_backward_codegen_dense_weighted_exact_cuda(
-    const Tensor& grad_output,
-    const Tensor& dev_weights,
-    const Tensor& weights_offsets,
-    const Tensor& D_offsets,
-    const c10::SymInt max_D,
-    const Tensor& hash_size_cumsum,
-    const int64_t total_hash_size_bits,
-    const Tensor& indices,
-    const Tensor& offsets,
-    const int64_t pooling_mode,
-    const Tensor& indice_weights,
-    const int64_t BT_block_size,
-    const int64_t max_segment_length_per_warp,
-    const double unused);
-
-class SplitLookupFunction_Dense_Op
-    : public torch::autograd::Function<SplitLookupFunction_Dense_Op> {
- public:
-  static torch::autograd::variable_list forward(
-      torch::autograd::AutogradContext* ctx,
-      int64_t output_dtype,
-      Tensor dev_weights,
-      Tensor weights_offsets,
-      Tensor D_offsets,
-      c10::SymInt total_D,
-      c10::SymInt max_D,
-      Tensor hash_size_cumsum,
-      int64_t total_hash_size_bits,
-      Tensor indices,
-      Tensor offsets,
-      int64_t pooling_mode,
-      std::optional<Tensor> indice_weights,
-      std::optional<Tensor> feature_requires_grad) {
-    ctx->save_for_backward({
-        dev_weights,
-        weights_offsets,
-        D_offsets,
-        hash_size_cumsum,
-        indices,
-        offsets,
-        indice_weights.value_or(Tensor()),
-        feature_requires_grad.value_or(Tensor()),
-    });
-
-    ctx->saved_data["total_D"] = total_D;
-    ctx->saved_data["max_D"] = max_D;
-    ctx->saved_data["total_hash_size_bits"] = total_hash_size_bits;
-    ctx->saved_data["pooling_mode"] = pooling_mode;
-
-    if (!indice_weights.has_value()) {
-      return {dense_embedding_codegen_forward_unweighted_cuda(
-          dev_weights,
-          weights_offsets,
-          D_offsets,
-          total_D,
-          max_D,
-          indices,
-          offsets,
-          pooling_mode,
-          output_dtype,
-          /*is_experimental=*/false)};
-    } else {
-      return {dense_embedding_codegen_forward_weighted_cuda(
-          dev_weights,
-          weights_offsets,
-          D_offsets,
-          total_D,
-          max_D,
-          indices,
-          offsets,
-          pooling_mode,
-          indice_weights.value(),
-          output_dtype,
-          /*is_experimental=*/false)};
-    }
-  }
-
-  static torch::autograd::variable_list backward(
-      torch::autograd::AutogradContext* ctx,
-      torch::autograd::variable_list grad_outputs) {
-    const auto saved = ctx->get_saved_variables();
-    auto savedItr = std::begin(saved);
-    auto dev_weights = *savedItr++;
-    auto weights_offsets = *savedItr++;
-    auto D_offsets = *savedItr++;
-    auto hash_size_cumsum = *savedItr++;
-    auto indices = *savedItr++;
-    auto offsets = *savedItr++;
-    auto indice_weights = *savedItr++;
-    auto feature_requires_grad = *savedItr++;
-
-    auto total_D = ctx->saved_data["total_D"].toSymInt();
-    auto max_D = ctx->saved_data["max_D"].toSymInt();
-    auto total_hash_size_bits = ctx->saved_data["total_hash_size_bits"].toInt();
-    auto pooling_mode = ctx->saved_data["pooling_mode"].toInt();
-
-    TORCH_CHECK_EQ(grad_outputs.size(), 1);
-
-#ifdef USE_ROCM
-    constexpr int32_t BT_block_size = 64;
-    constexpr int32_t max_segment_length_per_warp = 64;
-#else
-    constexpr int32_t BT_block_size = 32;
-    constexpr int32_t max_segment_length_per_warp = 32;
-#endif
-    using torch::autograd::Variable;
-
-    auto grad_output = grad_outputs[0];
-    // FIXME: to support aligned memory access in Vec4T load/store function
-    // 16 for FP32 and 8 for FP16
-    if (reinterpret_cast<uint64_t>(grad_output.data_ptr()) % 16 != 0 ||
-        grad_output.stride(1) != 1 || grad_output.stride(0) % 4 != 0) {
-      grad_output = grad_output.contiguous();
-    }
-    if (reinterpret_cast<uint64_t>(grad_output.data_ptr()) % 16 != 0) {
-      grad_output = at::empty_like(grad_output).copy_(grad_output);
-    }
-
-    if (!indice_weights.defined()) {
-      auto grad_dev_weights =
-          split_embedding_backward_codegen_dense_unweighted_exact_cuda(
-              grad_output,
-              dev_weights,
-              weights_offsets,
-              D_offsets,
-              max_D,
-              hash_size_cumsum,
-              total_hash_size_bits,
-              indices,
-              offsets,
-              pooling_mode,
-              BT_block_size,
-              max_segment_length_per_warp,
-              /* unused=*/0.0);
-      return {
-          Variable(), // output_dtype
-          grad_dev_weights,
-          Variable(), // weights_offsets
-          Variable(), // D_offsets
-          Variable(), // total_D
-          Variable(), // max_D
-          Variable(), // hash_size_cumsum
-          Variable(), // total_hash_size_bits
-          Variable(), // indices
-          Variable(), // offsets
-          Variable(), // pooling_mode
-          Variable(), // indice_weights
-          Variable(), // feature_requires_grad
-      };
-    } else {
-      auto grad_indice_weights =
-          dense_embedding_codegen_grad_indice_weights_cuda(
-              grad_output,
-              dev_weights,
-              weights_offsets,
-              D_offsets,
-              max_D,
-              indices,
-              offsets,
-              feature_requires_grad);
-      auto grad_dev_weights =
-          split_embedding_backward_codegen_dense_weighted_exact_cuda(
-              grad_output,
-              dev_weights,
-              weights_offsets,
-              D_offsets,
-              max_D,
-              hash_size_cumsum,
-              total_hash_size_bits,
-              indices,
-              offsets,
-              pooling_mode,
-              indice_weights,
-              BT_block_size,
-              max_segment_length_per_warp,
-              /* unused=*/0.0);
-      return {
-          Variable(), // output_dtype
-          grad_dev_weights,
-          Variable(), // weights_offsets
-          Variable(), // D_offsets
-          Variable(), // total_D
-          Variable(), // max_D
-          Variable(), // hash_size_cumsum
-          Variable(), // total_hash_size_bits
-          Variable(), // indices
-          Variable(), // offsets
-          Variable(), // pooling_mode
-          grad_indice_weights,
-          Variable(), // feature_requires_grad
-      };
-    }
-  }
-};
-
-/******** nobag ops ********/
-Tensor dense_embedding_nobag_codegen_forward_unweighted_cuda(
-    const Tensor& dev_weights,
-    const Tensor& weights_offsets,
-    const c10::SymInt D,
-    const Tensor& indices,
-    const Tensor& offsets,
-    const int64_t output_dtype,
-    const bool is_experimental);
-
-Tensor
-split_embedding_nobag_backward_codegen_dense_unweighted_exact_cuda(
-    const Tensor& grad_output,
-    const Tensor& dev_weights,
-    const Tensor& weights_offsets,
-    const c10::SymInt D,
-    const Tensor& hash_size_cumsum,
-    const int64_t total_hash_size_bits,
-    const Tensor& indices,
-    const Tensor& offsets,
-    const int64_t BT_block_size,
-    const int64_t max_segment_length_per_warp,
-    const double unused);
-
-class SplitNoBagLookupFunction_Dense_Op
-    : public torch::autograd::Function<SplitNoBagLookupFunction_Dense_Op> {
- public:
-  static torch::autograd::variable_list forward(
-      torch::autograd::AutogradContext* ctx,
-      int64_t output_dtype,
-      Tensor dev_weights,
-      Tensor weights_offsets,
-      c10::SymInt D,
-      Tensor hash_size_cumsum,
-      int64_t total_hash_size_bits,
-      Tensor indices,
-      Tensor offsets) {
-    ctx->save_for_backward({
-        dev_weights,
-        weights_offsets,
-        hash_size_cumsum,
-        indices,
-        offsets,
-    });
-
-    ctx->saved_data["D"] = D;
-    ctx->saved_data["total_hash_size_bits"] = total_hash_size_bits;
-
-    return {dense_embedding_nobag_codegen_forward_unweighted_cuda(
-        dev_weights,
-        weights_offsets,
-        D,
-        indices,
-        offsets,
-        output_dtype,
-        /*is_experimental*/ false)};
-  }
-
-  static torch::autograd::variable_list backward(
-      torch::autograd::AutogradContext* ctx,
-      torch::autograd::variable_list grad_outputs) {
-    const auto saved = ctx->get_saved_variables();
-    auto savedItr = std::begin(saved);
-    auto dev_weights = *savedItr++;
-    auto weights_offsets = *savedItr++;
-    auto hash_size_cumsum = *savedItr++;
-    auto indices = *savedItr++;
-    auto offsets = *savedItr++;
-
-    auto D = ctx->saved_data["D"].toSymInt();
-    auto total_hash_size_bits = ctx->saved_data["total_hash_size_bits"].toInt();
-
-    TORCH_CHECK_EQ(grad_outputs.size(), 1);
-
-    constexpr int32_t BT_block_size = 32;
-    constexpr int32_t max_segment_length_per_warp = 32;
-    using torch::autograd::Variable;
-
-    auto grad_output = grad_outputs[0];
-    // FIXME: to support aligned memory access in Vec4T load/store function
-    // 16 for FP32 and 8 for FP16
-    if (reinterpret_cast<uint64_t>(grad_output.data_ptr()) % 16 != 0 ||
-        grad_output.stride(1) != 1 || grad_output.stride(0) % 4 != 0) {
-      grad_output = grad_output.contiguous();
-    }
-    if (reinterpret_cast<uint64_t>(grad_output.data_ptr()) % 16 != 0) {
-      grad_output = at::empty_like(grad_output).copy_(grad_output);
-    }
-
-    auto grad_dev_weights =
-        split_embedding_nobag_backward_codegen_dense_unweighted_exact_cuda(
-            grad_output,
-            dev_weights,
-            weights_offsets,
-            D,
-            hash_size_cumsum,
-            total_hash_size_bits,
-            indices,
-            offsets,
-            BT_block_size,
-            max_segment_length_per_warp,
-            0);
-    return {
-        Variable(), // output_dtype
-        grad_dev_weights, // grad_dev_weights
-        Variable(), // weights_offsets
-        Variable(), // D
-        Variable(), // hash_size_cumsum
-        Variable(), // total_hash_size_bits
-        Variable(), // indices
-        Variable(), // offsets
-    };
-  }
-};
-
-Tensor split_embedding_codegen_lookup_dense_function(
-    Tensor dev_weights,
-    Tensor weights_offsets,
-    Tensor D_offsets,
-    int64_t total_D,
-    int64_t max_D,
-    Tensor hash_size_cumsum,
-    int64_t total_hash_size_bits,
-    Tensor indices,
-    Tensor offsets,
-    int64_t pooling_mode,
-    std::optional<Tensor> indice_weights,
-    std::optional<Tensor> feature_requires_grad,
-    int64_t output_dtype = static_cast<int64_t>(SparseType::FP32)) {
-  if (static_cast<PoolingMode>(pooling_mode) == PoolingMode::NONE) {
-    return SplitNoBagLookupFunction_Dense_Op::apply(
-        output_dtype,
-        dev_weights,
-        weights_offsets,
-        max_D,
-        hash_size_cumsum,
-        total_hash_size_bits,
-        indices,
-        offsets)[0];
-  } else {
-    return SplitLookupFunction_Dense_Op::apply(
-        output_dtype,
-        dev_weights,
-        weights_offsets,
-        D_offsets,
-        total_D,
-        max_D,
-        hash_size_cumsum,
-        total_hash_size_bits,
-        indices,
-        offsets,
-        pooling_mode,
-        indice_weights,
-        feature_requires_grad)[0];
-  }
-}
-
-// Deprecated for fb namespace! Please use fbgemm namespace instead!
-TORCH_LIBRARY_FRAGMENT(fb, m) {
-  DISPATCH_TO_CUDA(
-      "dense_embedding_codegen_lookup_function",
-      split_embedding_codegen_lookup_dense_function);
-}
-
-TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
-  DISPATCH_TO_CUDA(
-      "dense_embedding_codegen_lookup_function",
-      split_embedding_codegen_lookup_dense_function);
-}
diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_dense_host_cpu.cpp b/fbgemm_gpu/codegen/training/backward/embedding_backward_dense_host_cpu.cpp
index ebda352053..ab523ccad5 100644
--- a/fbgemm_gpu/codegen/training/backward/embedding_backward_dense_host_cpu.cpp
+++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_dense_host_cpu.cpp
@@ -38,18 +38,18 @@ class SplitLookupFunction_Dense_Op
  public:
   static torch::autograd::variable_list forward(
       torch::autograd::AutogradContext* ctx,
-      Tensor host_weights,
-      Tensor weights_offsets,
-      Tensor D_offsets,
-      int64_t total_D,
-      int64_t max_D,
-      Tensor hash_size_cumsum,
+      const Tensor& host_weights,
+      const Tensor& weights_offsets,
+      const Tensor& D_offsets,
+      c10::SymInt total_D,
+      c10::SymInt max_D,
+      const Tensor& hash_size_cumsum,
       int64_t total_hash_size_bits,
-      Tensor indices,
-      Tensor offsets,
+      const Tensor& indices,
+      const Tensor& offsets,
       int64_t pooling_mode,
-      std::optional<Tensor> indice_weights,
-      std::optional<Tensor> feature_requires_grad) {
+      const std::optional<Tensor>& indice_weights,
+      const c10::optional<Tensor>& feature_requires_grad) {
     Tensor indice_weights_value = indice_weights.value_or(Tensor());
     Tensor feature_requires_grad_value =
         feature_requires_grad.value_or(Tensor());
@@ -151,19 +151,25 @@
 };

 Tensor split_embedding_codegen_lookup_dense_function(
-    Tensor host_weights,
-    Tensor weights_offsets,
-    Tensor D_offsets,
-    int64_t total_D,
-    int64_t max_D,
-    Tensor hash_size_cumsum,
-    int64_t total_hash_size_bits,
-    Tensor indices,
-    Tensor offsets,
-    int64_t pooling_mode,
-    std::optional<Tensor> indice_weights,
-    std::optional<Tensor> feature_requires_grad,
-    int64_t /* output_dtype = static_cast<int64_t>(SparseType::FP32) */) {
+    const Tensor& host_weights,
+    const Tensor& weights_offsets,
+    const Tensor& D_offsets,
+    c10::SymInt total_D,
+    c10::SymInt max_D,
+    const Tensor& hash_size_cumsum,
+    const int64_t total_hash_size_bits,
+    const Tensor& indices,
+    const Tensor& offsets,
+    const int64_t pooling_mode,
+    const c10::optional<Tensor>& indice_weights,
+    const c10::optional<Tensor>& feature_requires_grad,
+    int64_t output_dtype = static_cast<int64_t>(SparseType::FP32),
+    const c10::optional<Tensor>& B_offsets = c10::nullopt,
+    const c10::optional<Tensor>& vbe_output_offsets_feature_rank = c10::nullopt,
+    const c10::optional<Tensor>& vbe_B_offsets_rank_per_feature = c10::nullopt,
+    c10::SymInt max_B = -1,
+    c10::SymInt max_B_feature_rank = -1,
+    c10::SymInt vbe_output_size = -1) {
   return SplitLookupFunction_Dense_Op::apply(
       host_weights,
       weights_offsets,
@@ -182,7 +188,7 @@ Tensor split_embedding_codegen_lookup_dense_function(

 // Deprecated for fb namespace! Please use fbgemm namespace instead!
 TORCH_LIBRARY_FRAGMENT(fb, m) {
   m.def(
-      "dense_embedding_codegen_lookup_function(Tensor dev_weights, Tensor weights_offsets, Tensor D_offsets, int total_D, int max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, int output_dtype=0) -> Tensor");
+      "dense_embedding_codegen_lookup_function(Tensor dev_weights, Tensor weights_offsets, Tensor D_offsets, SymInt total_D, SymInt max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, int output_dtype=0, Tensor? B_offsets=None, Tensor? vbe_output_offsets_feature_rank=None, Tensor? vbe_B_offsets_rank_per_feature=None, SymInt max_B=-1, SymInt max_B_feature_rank=-1, SymInt vbe_output_size=-1) -> Tensor");
   DISPATCH_TO_CPU(
       "dense_embedding_codegen_lookup_function",
       split_embedding_codegen_lookup_dense_function);
@@ -190,7 +196,7 @@ TORCH_LIBRARY_FRAGMENT(fb, m) {

 TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   m.def(
-      "dense_embedding_codegen_lookup_function(Tensor dev_weights, Tensor weights_offsets, Tensor D_offsets, int total_D, int max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, int output_dtype=0) -> Tensor");
+      "dense_embedding_codegen_lookup_function(Tensor dev_weights, Tensor weights_offsets, Tensor D_offsets, SymInt total_D, SymInt max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, int output_dtype=0, Tensor? B_offsets=None, Tensor? vbe_output_offsets_feature_rank=None, Tensor? vbe_B_offsets_rank_per_feature=None, SymInt max_B=-1, SymInt max_B_feature_rank=-1, SymInt vbe_output_size=-1) -> Tensor");
   DISPATCH_TO_CPU(
       "dense_embedding_codegen_lookup_function",
       split_embedding_codegen_lookup_dense_function);
diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp
index a62bcea41f..4cf003a1ac 100644
--- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp
+++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp
@@ -21,7 +21,9 @@ using Tensor = at::Tensor;
 using namespace fbgemm_gpu;

 {#/* Module description */#}
-{%- set mdesc = "ssd" if ssd else "split" %}
+{%- set fwd_mdesc = "ssd" if ssd else ("dense" if dense else "split") %}
+{%- set bwd_mdesc = "ssd" if ssd else "split" %}
+
 {%- if ssd %}
 enum SSDTensor {
@@ -43,7 +45,7 @@ enum SSDTensor {
 #}
 {%- macro call_forward_op_dispatch(nobag, weighted, vbe, is_gwd) %}
     {%- set forward_op = "{}_embedding{}_codegen_forward_{}{}{}_cuda".format(
-            mdesc,
+            fwd_mdesc,
             "_nobag" if nobag else "",
             "weighted" if weighted else "unweighted",
             "_vbe" if vbe else "",
@@ -58,9 +60,11 @@ enum SSDTensor {
     return {
         embedding_codegen_forward_op.call(
             flatten_dev_weights,
+            {%- if not dense %}
             uvm_weights,
             lxu_cache_weights,
             weights_placements,
+            {%- endif %}
             weights_offsets,
             {%- if nobag %}
             D,
@@ -79,10 +83,10 @@ enum SSDTensor {
             {%- endif %} {# /* if not nobag */ #}
             {%- if ssd %}
             ssd_tensors[SSDTensor::ROW_ADDRS],
-            {%- else %}
+            {%- elif not dense %}
             lxu_cache_locations,
-            {%- endif %}
             uvm_cache_stats_,
+            {%- endif %}
             output_dtype,
             {%- if not nobag %}
             {%- if vbe %}
@@ -100,7 +104,7 @@ enum SSDTensor {
             weight_decay,
             iter
             {%- else %}
-            is_experimental
+            {{ "is_experimental" if not dense else "false" }}
             {%- endif %} {# /* if is_gwd */ #}
             {%- else %}
             /*is_experimental=*/false
@@ -115,7 +119,7 @@ enum SSDTensor {
 {%- macro call_backward_op_dispatch(nobag, weighted, vbe, is_gwd) %}
"_weighted" if weighted else "_unweighted" %} {%- set backward_op = "{}_embedding{}_backward_codegen_{}{}{}{}_exact_cuda".format( - mdesc, + bwd_mdesc, "_nobag" if nobag else "", optimizer, wdesc, @@ -131,9 +135,11 @@ enum SSDTensor { grad_dev_weights = embedding_codegen_{{ wdesc }}_backward_op.call( grad_output, dev_weights, + {% if not dense %} uvm_weights, lxu_cache_weights, weights_placements, + {%- endif %} weights_offsets, {% if nobag %} D, @@ -153,23 +159,27 @@ enum SSDTensor { {%- endif %} {# /* if not nobag */ #} {%- if ssd %} ssd_row_addrs, - {%- else %} + {%- elif not dense %} lxu_cache_locations, {%- endif %} BT_block_size, max_segment_length_per_warp, - {%- if optimizer != "none" %} + {%- if optimizer != "none" and not dense %} stochastic_rounding, {%- endif %} + {%- if not dense %} info_B_num_bits, info_B_mask_int64, + {%- endif %} {%- if vbe %} B_offsets, vbe_row_output_offsets, vbe_b_t_map, {%- endif %} {# /* if vbe */ #} + {%- if not dense %} use_uniq_cache_locations_bwd, use_homogeneous_placements, + {%- endif %} {%- if is_gwd %} {%- if "prev_iter_dev" not in args.split_function_arg_names %} prev_iter_dev, @@ -178,14 +188,23 @@ enum SSDTensor { iter, {%- endif %} {%- endif %} {# /* if is_gwd */ #} - {{ args.split_function_arg_names | join(", ") }}); + {%- if not dense %} + {{ args.split_function_arg_names | join(", ") }} + {%- else %} + /*unused=*/0 + {%- endif %} + ); return { + {%- if not dense %} Tensor(), // placeholder autograd tensor + {%- endif %} Variable(), // output_dtype grad_dev_weights, // dev_weights + {%- if not dense %} Variable(), // uvm_weights Variable(), // lxu_cache_weights Variable(), // weights_placements + {%- endif %} Variable(), // weights_offsets {%- if nobag %} Variable(), // D @@ -203,9 +222,11 @@ enum SSDTensor { grad_indice_weights, // indice_weights Variable(), // feature_requires_grad {%- endif %} + {%- if not dense %} Variable(), // lxu_cache_locations Variable(), // uvm_cache_stats - {%- if optimizer != "none" %} + {%- endif %} + {%- if optimizer != "none" and not dense %} Variable(), // gradient_clipping Variable(), // max_gradient Variable(), // stochastic_rounding @@ -218,9 +239,11 @@ enum SSDTensor { Variable(), // max_B_feature_rank Variable(), // vbe_output_size {%- endif %} + {%- if not dense %} Variable(), // is_experimental Variable(), // use_uniq_cache_locations_bwd Variable(), // use_homogeneous_placements + {%- endif %} {%- if is_gwd %} {%- if "prev_iter_dev" not in args.split_function_arg_names %} Variable(), // prev_iter_dev @@ -251,12 +274,16 @@ enum SSDTensor { ) %} return {{ autograd_func }}::apply( + {%- if not dense %} placeholder_autograd_tensor, + {%- endif %} output_dtype, dev_weights, + {%- if not dense %} uvm_weights, lxu_cache_weights, weights_placements, + {%- endif %} weights_offsets, {%- if nobag %} max_D, @@ -268,15 +295,26 @@ enum SSDTensor { hash_size_cumsum, total_hash_size_bits, indices, + {%- if not nobag and dense and not vbe %} + offsets, + pooling_mode, + indice_weights, + feature_requires_grad + {%- elif not nobag %} offsets, - {%- if not nobag %} pooling_mode, indice_weights, feature_requires_grad, + {%- elif nobag and dense and not vbe %} + offsets + {%- else %} + offsets, {%- endif %} + {%- if not dense %} lxu_cache_locations, uvm_cache_stats, - {%- if optimizer != "none" %} + {%- endif %} + {%- if optimizer != "none" and not dense %} gradient_clipping, max_gradient, stochastic_rounding, @@ -287,8 +325,13 @@ enum SSDTensor { vbe_B_offsets_rank_per_feature, max_B, max_B_feature_rank, + {%- endif %} 
+        {%- if vbe and not dense %}
        vbe_output_size,
+        {%- elif vbe and dense %}
+        vbe_output_size
        {%- endif %}
+        {%- if not dense %}
        is_experimental,
        use_uniq_cache_locations_bwd,
        use_homogeneous_placements,
@@ -303,7 +346,9 @@ enum SSDTensor {
        {%- if ssd %}
        ssd_tensors.value(),
        {%- endif %}
-        {{ args.split_function_arg_names | join(", ") }})[0];
+        {{ args.split_function_arg_names | join(", ") }}
+        {%- endif %}
+        )[0];
 {%- endmacro %}

 ////////////////////////////////////////////////////////////////////////////////
@@ -316,12 +361,14 @@ enum SSDTensor {
 {%- for vbe in ([True, False] if has_vbe_support else [False]) %}
 {%- set vdesc = "_vbe" if vbe else "" %}

-Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda(
+Tensor {{ fwd_mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda(
    const Tensor& grad_output,
    const Tensor& dev_weights,
+    {%- if not dense %}
    const Tensor& uvm_weights,
    const Tensor& lxu_cache_weights,
    const Tensor& weights_placements,
+    {%- endif %}
    const Tensor& weights_offsets,
    const Tensor& D_offsets,
    const c10::SymInt max_D,
@@ -329,7 +376,7 @@ Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda(
    const Tensor& offsets,
    {%- if ssd %}
    const Tensor& ssd_row_addrs,
-    {%- else %}
+    {%- elif not dense %}
    const Tensor& lxu_cache_locations,
    {%- endif %}
    {%- if vbe %}
@@ -362,11 +409,13 @@ Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda(

 {%- set desc_suffix = wdesc + vdesc + gwddesc %}

-Tensor {{ mdesc }}_embedding{{ ndesc }}_codegen_forward{{ desc_suffix }}_cuda(
+Tensor {{ fwd_mdesc }}_embedding{{ ndesc }}_codegen_forward{{ desc_suffix }}_cuda(
    const Tensor& dev_weights,
+    {%- if not dense %}
    const Tensor& uvm_weights,
    const Tensor& lxu_cache_weights,
    const Tensor& weights_placements,
+    {%- endif %}
    const Tensor& weights_offsets,
    {%- if nobag %}
    const c10::SymInt D,
@@ -385,10 +434,10 @@ Tensor {{ mdesc }}_embedding{{ ndesc }}_codegen_forward{{ desc_suffix }}_cuda(
    {%- endif %}
    {%- if ssd %}
    const Tensor& ssd_row_addrs,
-    {%- else %}
+    {%- elif not dense %}
    const Tensor& lxu_cache_locations,
-    {%- endif %}
    const Tensor& uvm_cache_stats,
+    {%- endif %}
    const int64_t output_dtype,
    {%- if vbe %}
    const Tensor& vbe_row_output_offsets,
@@ -410,12 +459,14 @@ Tensor {{ mdesc }}_embedding{{ ndesc }}_codegen_forward{{ desc_suffix }}_cuda(
 );

 Tensor
-{{ mdesc }}_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}{{ desc_suffix }}_exact_cuda(
+{{ bwd_mdesc }}_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}{{ desc_suffix }}_exact_cuda(
    const Tensor& grad_output,
    const Tensor& dev_weights,
+    {%- if not dense %}
    const Tensor& uvm_weights,
    const Tensor& lxu_cache_weights,
    const Tensor& weights_placements,
+    {%- endif %}
    const Tensor& weights_offsets,
    {%- if nobag %}
    const c10::SymInt D,
@@ -435,23 +486,27 @@ Tensor
    {%- endif %}
    {%- if ssd %}
    const Tensor& ssd_row_addrs,
-    {%- else %}
+    {%- elif not dense %}
    const Tensor& lxu_cache_locations,
    {%- endif %}
    const int64_t BT_block_size,
    const int64_t max_segment_length_per_warp,
-    {%- if optimizer != "none" %}
+    {%- if optimizer != "none" and not dense %}
    const bool stochastic_rounding,
    {%- endif %}
+    {%- if not dense %}
    const int64_t info_B_num_bits,
    const int64_t info_B_mask_int64,
+    {%- endif %}
    {%- if vbe %}
    const Tensor& B_offsets,
    const Tensor& vbe_row_output_offsets,
    const Tensor& vbe_b_t_map,
    {%- endif %}
+    {%- if not dense %}
    const bool use_uniq_cache_locations,
    const bool use_homogeneous_placements,
+    {%- endif %}
    {%- if is_gwd %}
    {%- if "prev_iter_dev" not in args.split_function_arg_names %}
    const Tensor& prev_iter_dev,
@@ -482,12 +537,16 @@ class {{ autograd_func }} :
  public:
   static torch::autograd::variable_list forward(
       torch::autograd::AutogradContext* ctx,
+      {%- if not dense %}
       const Tensor& placeholder_autograd_tensor,
+      {%- endif %}
       const int64_t output_dtype,
       const Tensor& dev_weights,
+      {%- if not dense %}
       const Tensor& uvm_weights,
       const Tensor& lxu_cache_weights,
       const Tensor& weights_placements,
+      {%- endif %}
       const Tensor& weights_offsets,
       {%- if not nobag %}
       const Tensor& D_offsets,
@@ -499,12 +558,22 @@ class {{ autograd_func }} :
       const Tensor& hash_size_cumsum,
       const int64_t total_hash_size_bits,
       const Tensor& indices,
+      {%- if not nobag and dense and not vbe %}
+      const Tensor& offsets,
+      const int64_t pooling_mode,
+      const c10::optional<Tensor>& indice_weights,
+      const c10::optional<Tensor>& feature_requires_grad
+      {%- elif not nobag %}
       const Tensor& offsets,
-      {%- if not nobag %}
       const int64_t pooling_mode,
       const std::optional<Tensor>& indice_weights,
       const std::optional<Tensor>& feature_requires_grad,
+      {%- elif nobag and dense and not vbe %}
+      const Tensor& offsets
+      {%- else %}
+      const Tensor& offsets,
       {%- endif %}
+      {%- if not dense %}
       const Tensor& lxu_cache_locations,
       std::optional<Tensor> uvm_cache_stats,
       {%- if optimizer != "none" %}
@@ -534,7 +603,17 @@ class {{ autograd_func }} :
       {%- if ssd %}
       const at::TensorList& ssd_tensors,
       {%- endif %}
-      {{ args.split_function_args | join(", ") }}) {
+      {{ args.split_function_args | join(", ") }}
+      {%- else %}
+      {%- if vbe %}
+      const c10::optional<Tensor>& B_offsets,
+      const c10::optional<Tensor>& vbe_output_offsets_feature_rank,
+      const c10::optional<Tensor>& vbe_B_offsets_rank_per_feature,
+      const c10::SymInt max_B,
+      const c10::SymInt max_B_feature_rank,
+      const c10::SymInt vbe_output_size
+      {%- endif %}
+      {%- endif %}) {
     const auto T = weights_offsets.sym_numel();

     {%- if vbe %}
@@ -547,10 +626,12 @@ class {{ autograd_func }} :
     const auto max_B_ = offsets.sym_size(0) / T;
     {%- endif %}

+    {%- if not dense %}
     // NOTE: The `local_uvm_cache_stats` variable held by the nn.Module has dtype int32_t
     // TODO: Hook up with frontend code
     const auto uvm_cache_stats_ = uvm_cache_stats
         .value_or(at::empty({0}, uvm_weights.options().dtype(at::kInt)));
+    {%- endif %}

     // Default values for Dynamo tracing
     // SymInt does not support bitshifts operator
@@ -602,9 +683,11 @@ class {{ autograd_func }} :
     {%- endif %}
     ctx->save_for_backward({
         dev_weights,
+        {%- if not dense %}
         uvm_weights,
         lxu_cache_weights,
         weights_placements,
+        {%- endif %}
         weights_offsets,
         {%- if not nobag %}
         D_offsets,
@@ -616,7 +699,9 @@ class {{ autograd_func }} :
         indice_weights.value_or(Tensor()),
         feature_requires_grad.value_or(Tensor()),
         {%- endif %}
+        {%- if not dense %}
         lxu_cache_locations,
+        {%- endif %}
         {%- if vbe %}
         B_offsets_,
         vbe_row_output_offsets,
@@ -641,7 +726,7 @@ class {{ autograd_func }} :
     {%- endif %}
     ctx->saved_data["total_hash_size_bits"] = total_hash_size_bits;

-    {%- if optimizer != "none" %}
+    {%- if optimizer != "none" and not dense %}
     ctx->saved_data["gradient_clipping"] = gradient_clipping;
     ctx->saved_data["max_gradient"] = max_gradient;
     ctx->saved_data["stochastic_rounding"] = stochastic_rounding;
@@ -649,15 +734,19 @@ class {{ autograd_func }} :
     ctx->saved_data["info_B_num_bits"] = info_B_num_bits;
     const auto info_B_mask_int64 = static_cast<int64_t>(info_B_mask);
     ctx->saved_data["info_B_mask"] = info_B_mask_int64;
+    {%- if not dense %}
     ctx->saved_data["use_uniq_cache_locations_bwd"] = use_uniq_cache_locations_bwd;
     ctx->saved_data["use_homogeneous_placements"] =
         use_homogeneous_placements;
+    {%- endif %}
     {%- if is_gwd and "iter" not in args.split_function_arg_names %}
     ctx->saved_data["iter"] = iter;
     {%- endif %}
+    {%- if not dense %}
     {%- for (var, _) in args.saved_data %}
     ctx->saved_data["{{ var }}"] = {{ var }};
     {%- endfor %}
+    {%- endif %}

     {%- if optimizer == "none" %}
     // Flatten
@@ -704,9 +793,11 @@ class {{ autograd_func }} :
     const auto saved = ctx->get_saved_variables();
     auto savedItr = std::begin(saved);
     auto dev_weights = *savedItr++;
+    {%- if not dense %}
     auto uvm_weights = *savedItr++;
     auto lxu_cache_weights = *savedItr++;
     auto weights_placements = *savedItr++;
+    {%- endif %}
     auto weights_offsets = *savedItr++;
     {%- if not nobag %}
     auto D_offsets = *savedItr++;
@@ -718,7 +809,9 @@ class {{ autograd_func }} :
     auto indice_weights = *savedItr++;
     auto feature_requires_grad = *savedItr++;
     {%- endif %}
+    {%- if not dense %}
     auto lxu_cache_locations = *savedItr++;
+    {%- endif %}
     {%- if vbe %}
     auto B_offsets = *savedItr++;
     auto vbe_row_output_offsets = *savedItr++;
@@ -745,25 +838,29 @@ class {{ autograd_func }} :
     {%- endif %}
     auto total_hash_size_bits = ctx->saved_data["total_hash_size_bits"].toInt();

-    {%- if optimizer != "none" %}
+    {%- if optimizer != "none" and not dense %}
     auto gradient_clipping = ctx->saved_data["gradient_clipping"].toBool();
     auto max_gradient = ctx->saved_data["max_gradient"].toDouble();
     auto stochastic_rounding = ctx->saved_data["stochastic_rounding"].toBool();
     {%- endif %} {#-/* if optimizer != "none" */#}
     const int32_t info_B_num_bits = ctx->saved_data["info_B_num_bits"].toInt();
     const int64_t info_B_mask_int64 = ctx->saved_data["info_B_mask"].toInt();
+    {%- if not dense %}
     const auto use_uniq_cache_locations_bwd =
         ctx->saved_data["use_uniq_cache_locations_bwd"].toBool();
     const auto use_homogeneous_placements =
         ctx->saved_data["use_homogeneous_placements"].toBool();
+    {%- endif %}
     {%- if is_gwd and "iter" not in args.split_function_arg_names %}
     const auto iter = ctx->saved_data["iter"].toInt();
     {%- endif %}
+    {%- if not dense%}
     {%- for (var, ivalue_cast) in args.saved_data %}
     auto {{ var }} = ctx->saved_data["{{ var }}"].{{ ivalue_cast }}();
     {%- endfor %}
+    {%- endif %}

     TORCH_CHECK_EQ(grad_outputs.size(), 1);

@@ -776,7 +873,7 @@ class {{ autograd_func }} :
 #endif
     using torch::autograd::Variable;

-    {%- if optimizer != "none" %}
+    {%- if optimizer != "none" and not dense %}
     auto grad_output = gradient_clipping ? clamp(grad_outputs[0], -max_gradient, max_gradient) : grad_outputs[0];
     {%- else %}
     auto& grad_output = grad_outputs[0];
@@ -785,26 +882,28 @@ class {{ autograd_func }} :
     {%- if not nobag %}
     {%- if optimizer == "none" %}
     // Flatten (dev_weights is used in
-    // {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda)
+    // {{ fwd_mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda)
     dev_weights = dev_weights.flatten();
     {%- endif %}

     {%- set grad_indice_weights_op =
-        "{}_embedding_codegen_grad_indice_weights{}_cuda".format(mdesc, vdesc)
+        "{}_embedding_codegen_grad_indice_weights{}_cuda".format(fwd_mdesc, vdesc)
     %}
-    static auto {{ mdesc }}_embedding_codegen_grad_indice_weights_op =
+    static auto embedding_codegen_grad_indice_weights_op =
         torch::Dispatcher::singleton()
             .findSchemaOrThrow("fbgemm::{{ grad_indice_weights_op }}", "")
             .typed<decltype({{ grad_indice_weights_op }})>();

     const auto grad_indice_weights = !indice_weights.defined() ?
       Variable() :
-      {{ mdesc }}_embedding_codegen_grad_indice_weights_op.call(
+      embedding_codegen_grad_indice_weights_op.call(
           grad_output,
           dev_weights,
+          {%- if not dense %}
          uvm_weights,
          lxu_cache_weights,
          weights_placements,
+          {%- endif %}
          weights_offsets,
          D_offsets,
          max_D,
@@ -812,7 +911,7 @@ class {{ autograd_func }} :
          offsets,
          {%- if ssd %}
          ssd_row_addrs,
-          {%- else %}
+          {%- elif not dense %}
          lxu_cache_locations,
          {%- endif %}
          {%- if vbe %}
@@ -864,12 +963,16 @@ class {{ autograd_func }} :
 {%- endif %} {#-/* if has_gpu_support */#}

 ///@ingroup embedding-cuda
-Tensor {{ mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function(
+Tensor {{ bwd_mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function(
+    {%- if dense %}
+    const Tensor& dev_weights,
+    {%- else %}
    const Tensor& placeholder_autograd_tensor,
    const Tensor& dev_weights,
    const Tensor& uvm_weights,
    const Tensor& lxu_cache_weights,
    const Tensor& weights_placements,
+    {%- endif %}
    const Tensor& weights_offsets,
    const Tensor& D_offsets,
    const c10::SymInt total_D,
@@ -881,19 +984,22 @@ Tensor {{ mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function(
    const int64_t pooling_mode,
    const std::optional<Tensor>& indice_weights,
    const std::optional<Tensor>& feature_requires_grad,
+    {%- if not dense %}
    const Tensor& lxu_cache_locations,
-    {%- if optimizer != "none" %}
+    {%- if optimizer != "none"%}
    const bool gradient_clipping,
    const double max_gradient,
    const bool stochastic_rounding,
    {%- endif %}
    {{ args.split_function_args | join(", ") }},
+    {%- endif %}
    const int64_t output_dtype = static_cast<int64_t>(SparseType::FP32),
    const std::optional<Tensor>& B_offsets = c10::nullopt,
    const std::optional<Tensor>& vbe_output_offsets_feature_rank = c10::nullopt,
    const std::optional<Tensor>& vbe_B_offsets_rank_per_feature = c10::nullopt,
    const c10::SymInt max_B = -1,
    const c10::SymInt max_B_feature_rank = -1,
+    {%- if not dense %}
    const c10::SymInt vbe_output_size = -1,
    const bool is_experimental = false,
    const bool use_uniq_cache_locations_bwd = false,
@@ -911,6 +1017,9 @@ Tensor {{ mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function(
    {%- else %}
    const bool apply_global_weight_decay = false
    {%- endif %}
+    {%- else %}
+    const c10::SymInt vbe_output_size = -1
+    {%- endif %}
 ) {
   // TODO: refactor into macro
   {%- if has_gpu_support %}
@@ -949,7 +1058,7 @@ Tensor {{ mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function(
   {%- else %}
   TORCH_CHECK(
       false,
-      "{{ mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function is deprecated. Please see https://github.com/pytorch/FBGEMM/discussions/1727 for more detail."
+      "{{ bwd_mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function is deprecated. Please see https://github.com/pytorch/FBGEMM/discussions/1727 for more detail."
   );
   return Tensor();
   {%- endif %} {#-/* if has_gpu_support */#}
@@ -958,7 +1067,8 @@ Tensor {{ mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function(

 // Deprecated for fb namespace! Please use fbgemm namespace instead!
 {%- for lib_name in ["fb", "fbgemm"] %}
 TORCH_LIBRARY_FRAGMENT({{ lib_name }}, m) {
-  {%- set op_name = "{}_embedding_codegen_lookup_{}_function".format(mdesc, optimizer) %}
+  {%- set op_name = "{}_embedding_codegen_lookup_{}_function".format(bwd_mdesc, optimizer) %}
+  {%- if not dense %}
   m.def("{{ op_name }}("
       " Tensor placeholder_autograd_tensor, "
       " Tensor dev_weights, "
@@ -1008,6 +1118,7 @@ TORCH_LIBRARY_FRAGMENT({{ lib_name }}, m) {
       {%- endif %}
       ") -> Tensor",
       {PT2_COMPLIANT_TAG});
+
   // We're playing a funny trick here: we're using the autograd
   // implementation of the operator at all the dispatch keys. This is OK
   // because autograd.Function works even in a context where there is
@@ -1023,8 +1134,14 @@ TORCH_LIBRARY_FRAGMENT({{ lib_name }}, m) {
       torch::dispatch(
           c10::DispatchKey::Meta,
           TORCH_FN({{ op_name }})));
+  {%- endif %} {#/* if not dense */#}
+
   DISPATCH_TO_CUDA(
+      {%- if not dense %}
       "{{ op_name }}",
+      {%- else %}
+      "dense_embedding_codegen_lookup_function",
+      {%- endif %}
       {{ op_name }});
 }
 {%- endfor %} {#-/* for lib_name */#}