From 3d84b252efe90d2ebe09806ace12befee3bc27c8 Mon Sep 17 00:00:00 2001
From: Sarunya Pumma
Date: Tue, 21 May 2024 23:58:10 -0700
Subject: [PATCH] Add cache conflict miss support (#2596)

Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/2596

Prior to this diff, SSD TBE lacked support for the conflict cache miss
scenario. It operated under the assumption that the cache, located in GPU
memory, was large enough to hold all data prefetched from SSD. In the event
of a conflict cache miss, the behavior of SSD TBE was unpredictable: it could
either fail or access illegal memory. Note that a conflict cache miss happens
when an embedding row is absent from the cache and, after being fetched from
SSD, cannot be inserted into the cache due to capacity or associativity
constraints.

This diff introduces support for conflict cache misses by storing rows that
cannot be inserted into the cache in a scratch pad, which is a temporary GPU
tensor. When rows miss the cache, TBE kernels can access them through the
scratch pad.

Prior to this diff, during the SSD prefetch stage, any row that missed the
cache and required fetching from SSD was first fetched into a CPU scratch pad
and then transferred to the GPU. Rows that could be inserted into the cache
were subsequently copied from the GPU scratch pad into the cache. If conflict
misses occurred, the prefetch behavior was unpredictable. With this diff,
conflict-missed rows are now retained in the scratch pad, which is kept alive
until the current iteration completes. Throughout the forward and
backward + optimizer stages of TBE, the cache and the scratch pad are used
equivalently. However, after the backward + optimizer step completes, rows in
the scratch pad are flushed back to SSD, whereas rows residing in the cache
are not evicted and remain available for future use (see the diagram below
for more details).
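As a rough illustration of this flow, the toy Python sketch below mirrors the
prefetch / lookup / flush steps described above. All names here
(SimpleSetCache, prefetch_with_scratch_pad, lookup_row, flush_scratch_pad)
are hypothetical and are not the actual FBGEMM / SSD TBE APIs; in the real
implementation the rows live in GPU tensors, the cache is the LXU cache, and
the kernels resolve row addresses rather than Python dict lookups.

# Conceptual sketch only (hypothetical names); not the FBGEMM implementation.
class SimpleSetCache:
    """Toy set-associative cache: each set holds at most `ways` rows."""
    def __init__(self, num_sets, ways):
        self.num_sets = num_sets
        self.ways = ways
        self.sets = [dict() for _ in range(num_sets)]

    def _set_for(self, idx):
        return self.sets[idx % self.num_sets]

    def __contains__(self, idx):
        return idx in self._set_for(idx)

    def get(self, idx):
        return self._set_for(idx)[idx]

    def try_insert(self, idx, row):
        s = self._set_for(idx)
        if idx in s or len(s) < self.ways:
            s[idx] = row
            return True
        return False  # set is full -> conflict miss

def prefetch_with_scratch_pad(indices, cache, ssd_storage):
    # Rows absent from the cache are fetched from SSD into a scratch pad
    # (a temporary GPU tensor in the real implementation).
    scratch_pad = {}
    conflict_missed = set()
    for idx in indices:
        if idx in cache:
            continue  # cache hit: nothing to fetch
        row = ssd_storage[idx]
        scratch_pad[idx] = row
        if not cache.try_insert(idx, row):
            # Conflict miss: the row stays only in the scratch pad, which is
            # kept alive until the current iteration completes.
            conflict_missed.add(idx)
    return scratch_pad, conflict_missed

def lookup_row(idx, cache, scratch_pad):
    # During forward and backward + optimizer, the cache and the scratch pad
    # are equivalent sources for a row.
    return cache.get(idx) if idx in cache else scratch_pad[idx]

def flush_scratch_pad(scratch_pad, conflict_missed, ssd_storage):
    # After backward + optimizer, conflict-missed rows are written back to
    # SSD; rows residing in the cache are not evicted.
    for idx in conflict_missed:
        ssd_storage[idx] = scratch_pad[idx]
    scratch_pad.clear()

# Example usage (toy data):
#   ssd = {i: f"row{i}" for i in range(100)}
#   cache = SimpleSetCache(num_sets=4, ways=2)
#   sp, missed = prefetch_with_scratch_pad([3, 7, 11, 15, 19], cache, ssd)
#   # ... forward / backward read and update rows via lookup_row(...)
#   flush_scratch_pad(sp, missed, ssd)  # after backward + optimizer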
{F1645878181} Differential Revision: D55998215 --- fbgemm_gpu/FbgemmGpu.cmake | 52 ++- .../genscript/generate_backward_split.py | 172 +++++++--- .../genscript/generate_forward_split.py | 79 +++-- .../genscript/generate_index_select.py | 4 +- .../codegen/genscript/jinja_environment.py | 17 +- fbgemm_gpu/codegen/genscript/optimizers.py | 17 + ...embedding_backward_split_host_template.cpp | 302 ++++++++++++------ ..._backward_split_indice_weights_template.cu | 70 ++-- ...ding_backward_split_kernel_cta_template.cu | 31 +- ...ing_backward_split_kernel_warp_template.cu | 29 +- ...embedding_backward_split_meta_template.cpp | 10 +- .../embedding_backward_split_template.cu | 109 ++++--- ...rward_split_kernel_nobag_small_template.cu | 54 +++- ...embedding_forward_split_kernel_template.cu | 73 +++-- .../embedding_forward_split_meta_template.cpp | 16 +- .../embedding_forward_split_template.cu | 76 +++-- ...optimizer_split_device_kernel_template.cuh | 20 +- .../codegen/training/python/__init__.template | 28 +- .../{lookup_args.py => lookup_args.template} | 5 +- ..._embedding_codegen_lookup_invoker.template | 25 +- .../ssd_split_table_batched_embeddings_ops.py | 204 +++++++++--- .../include/fbgemm_gpu/fbgemm_cuda_utils.cuh | 13 +- ...ssd_split_table_batched_embeddings_test.py | 241 +++++++++++--- 23 files changed, 1157 insertions(+), 490 deletions(-) rename fbgemm_gpu/codegen/training/python/{lookup_args.py => lookup_args.template} (94%) diff --git a/fbgemm_gpu/FbgemmGpu.cmake b/fbgemm_gpu/FbgemmGpu.cmake index a6824b57c..bb7c76507 100644 --- a/fbgemm_gpu/FbgemmGpu.cmake +++ b/fbgemm_gpu/FbgemmGpu.cmake @@ -97,6 +97,10 @@ set(GWD_OPTIMIZERS set(DEFUSED_OPTIMIZERS rowwise_adagrad) +# Optimizers with the SSD support +set(SSD_OPTIMIZERS + rowwise_adagrad) + set(WEIGHT_OPTIONS weighted unweighted_nobag @@ -143,6 +147,7 @@ set(gen_gpu_kernel_source_files "gen_embedding_forward_split_unweighted_codegen_cuda.cu" "gen_embedding_backward_dense_indice_weights_codegen_cuda.cu" "gen_embedding_backward_split_indice_weights_codegen_cuda.cu" + "gen_embedding_backward_ssd_indice_weights_codegen_cuda.cu" "gen_embedding_forward_split_weighted_vbe_codegen_cuda.cu" "gen_embedding_forward_split_unweighted_vbe_codegen_cuda.cu" "gen_batch_index_select_dim0_forward_codegen_cuda.cu" @@ -153,10 +158,13 @@ set(gen_gpu_kernel_source_files "gen_batch_index_select_dim0_backward_kernel_warp.cu" "gen_embedding_backward_split_grad_embedding_ops.cu" "gen_embedding_backward_split_grad_index_select.cu" - "gen_embedding_backward_common_split_device_kernel.cuh" - "gen_embedding_backward_batch_index_select_split_device_kernel.cuh" + "gen_embedding_backward_split_common_device_kernel.cuh" + "gen_embedding_backward_split_batch_index_select_device_kernel.cuh" "gen_embedding_forward_split_weighted_gwd_codegen_cuda.cu" "gen_embedding_forward_split_unweighted_gwd_codegen_cuda.cu" + "gen_embedding_forward_ssd_weighted_codegen_cuda.cu" + "gen_embedding_forward_ssd_unweighted_codegen_cuda.cu" + "gen_embedding_forward_ssd_unweighted_nobag_kernel_small.cu" ) if(NOT USE_ROCM) @@ -179,7 +187,8 @@ foreach(wdesc ${WEIGHT_OPTIONS}) "gen_embedding_backward_dense_split_${wdesc}_kernel_cta.cu" "gen_embedding_backward_dense_split_${wdesc}_kernel_warp.cu" "gen_embedding_forward_split_${wdesc}_kernel.cu" - "gen_embedding_backward_${wdesc}_split_device_kernel.cuh") + "gen_embedding_forward_ssd_${wdesc}_kernel.cu" + "gen_embedding_backward_split_${wdesc}_device_kernel.cuh") foreach(etype fp32 fp16 fp8 int8 int4 int2) list(APPEND gen_gpu_kernel_source_files @@ 
-191,7 +200,7 @@ endforeach() foreach(wdesc weighted unweighted) list(APPEND gen_gpu_kernel_source_files "gen_embedding_forward_split_${wdesc}_vbe_kernel.cu" - "gen_embedding_backward_${wdesc}_vbe_split_device_kernel.cuh") + "gen_embedding_backward_split_${wdesc}_vbe_device_kernel.cuh") endforeach() # Generate GWD files @@ -207,22 +216,31 @@ set(gen_cpu_source_files set(gen_python_source_files ${CMAKE_BINARY_DIR}/__init__.py - ${CMAKE_BINARY_DIR}/lookup_args.py) + ${CMAKE_BINARY_DIR}/lookup_args.py + ${CMAKE_BINARY_DIR}/lookup_args_ssd.py +) # For each of the optimizers, generate the backward split variant by adding # the Python, CPU-only, GPU host, and GPU kernel source files -# Generate the Python functions only if there is the backend support +# Generate the Python functions only if there is the backend support (for all +# optimizers) foreach(optimizer ${COMMON_OPTIMIZERS} ${CPU_ONLY_OPTIMIZERS} ${GPU_ONLY_OPTIMIZERS}) list(APPEND gen_python_source_files - "${CMAKE_BINARY_DIR}/lookup_${optimizer}.py") - list(APPEND gen_python_source_files + "${CMAKE_BINARY_DIR}/lookup_${optimizer}.py" "${CMAKE_BINARY_DIR}/lookup_${optimizer}_pt2.py") endforeach() +# Generate the Python functions only if there is the backend support (for SSD +# optimizers) +foreach(optimizer ${SSD_OPTIMIZERS}) + list(APPEND gen_python_source_files + "${CMAKE_BINARY_DIR}/lookup_${optimizer}_ssd.py") +endforeach() + # Generate the backend API for all optimizers to preserve the backward # compatibility list(APPEND gen_cpu_source_files @@ -285,6 +303,24 @@ foreach(optimizer ${DEFUSED_OPTIMIZERS}) "${CMAKE_BINARY_DIR}/split_embedding_optimizer_${optimizer}.py") endforeach() +foreach(optimizer ${SSD_OPTIMIZERS}) + list(APPEND gen_gpu_kernel_source_files + "gen_embedding_optimizer_${optimizer}_ssd_device_kernel.cuh" + ) + + list(APPEND gen_gpu_host_source_files + "gen_embedding_backward_ssd_${optimizer}.cpp" + ) + + foreach(wdesc weighted unweighted unweighted_nobag) + list(APPEND gen_gpu_kernel_source_files + "gen_embedding_backward_${optimizer}_ssd_${wdesc}_cuda.cu" + "gen_embedding_backward_${optimizer}_ssd_${wdesc}_kernel_cta.cu" + "gen_embedding_backward_${optimizer}_ssd_${wdesc}_kernel_warp.cu") + endforeach() + +endforeach() + list(APPEND gen_defused_optim_py_files ${CMAKE_BINARY_DIR}/optimizer_args.py) diff --git a/fbgemm_gpu/codegen/genscript/generate_backward_split.py b/fbgemm_gpu/codegen/genscript/generate_backward_split.py index 9ccc23f57..9d7eaa666 100644 --- a/fbgemm_gpu/codegen/genscript/generate_backward_split.py +++ b/fbgemm_gpu/codegen/genscript/generate_backward_split.py @@ -8,7 +8,9 @@ # pyre-strict # flake8: noqa F401 +import itertools import sys +from typing import List try: from .optimizers import * @@ -39,28 +41,44 @@ def render_backward_templates( ) -> None: if not kwargs.get("has_gpu_support"): return + + weighted_options = [True, False] + nobag_options = [True, False] if (not is_gwd) else [False] vbe_options = ( [True, False] if (kwargs.get("has_vbe_support") and not is_gwd) else [False] ) + ssd_options = [True, False] if kwargs.get("has_ssd_support") else [False] template = CodeTemplate.load(template_filepath) - for weighted in [True, False]: - for nobag in [True, False] if (not is_gwd) else [False]: - for vbe in vbe_options: - if (not nobag or (not weighted and not vbe)) and ( - not kwargs.get("dense") or not vbe - ): - wdesc = f"{ 'weighted' if weighted else 'unweighted' }{ '_nobag' if nobag else '' }{ '_vbe' if vbe else '' }" - template.write( - filename_format.format(optimizer, wdesc), - 
weighted=weighted, - nobag=nobag, - vbe=vbe, - is_index_select=False, - kdesc=wdesc, - **kwargs, - is_gwd=is_gwd, - ) + for weighted, nobag, vbe, ssd in itertools.product( + weighted_options, nobag_options, vbe_options, ssd_options + ): + if nobag and (weighted or vbe): + continue + if kwargs.get("dense") and (vbe or ssd): + continue + if ssd and (vbe or is_gwd): + continue + + kdesc = "".join( + [ + f"{ 'weighted' if weighted else 'unweighted' }", + f"{ '_nobag' if nobag else '' }", + f"{ '_vbe' if vbe else '' }", + ] + ) + desc = "_".join([f"{ 'ssd' if ssd else 'split' }", kdesc]) + template.write( + filename_format.format(optimizer, desc), + weighted=weighted, + nobag=nobag, + vbe=vbe, + is_index_select=False, + kdesc=kdesc, + is_gwd=is_gwd, + ssd=ssd, + **kwargs, + ) @staticmethod def generate_backward_split_gpu(**kwargs: Any) -> None: @@ -73,19 +91,19 @@ def generate_backward_split_gpu(**kwargs: Any) -> None: for template_filepath, filename_format in [ ( "training/backward/embedding_backward_split_template.cu", - "gen_embedding_backward_{}_split_{}_cuda.cu", + "gen_embedding_backward_{}_{}_cuda.cu", ), ( "training/backward/embedding_backward_split_meta_template.cpp", - "gen_embedding_backward_{}_split_{}_meta.cpp", + "gen_embedding_backward_{}_{}_meta.cpp", ), ( "training/backward/embedding_backward_split_kernel_cta_template.cu", - "gen_embedding_backward_{}_split_{}_kernel_cta.cu", + "gen_embedding_backward_{}_{}_kernel_cta.cu", ), ( "training/backward/embedding_backward_split_kernel_warp_template.cu", - "gen_embedding_backward_{}_split_{}_kernel_warp.cu", + "gen_embedding_backward_{}_{}_kernel_warp.cu", ), ]: BackwardSplitGenerator.render_backward_templates( @@ -94,20 +112,21 @@ def generate_backward_split_gpu(**kwargs: Any) -> None: filename_format, kwargs, ) + # Generate the global weight decay CUDA kernels if kwargs.get("has_global_weight_decay_support"): for template_filepath, filename_format in [ ( "training/backward/embedding_backward_split_kernel_cta_template.cu", - "gen_embedding_backward_{}_split_{}_gwd_kernel_cta.cu", + "gen_embedding_backward_{}_{}_gwd_kernel_cta.cu", ), ( "training/backward/embedding_backward_split_kernel_warp_template.cu", - "gen_embedding_backward_{}_split_{}_gwd_kernel_warp.cu", + "gen_embedding_backward_{}_{}_gwd_kernel_warp.cu", ), ( "training/backward/embedding_backward_split_template.cu", - "gen_embedding_backward_{}_split_{}_gwd_cuda.cu", + "gen_embedding_backward_{}_{}_gwd_cuda.cu", ), ]: BackwardSplitGenerator.render_backward_templates( @@ -118,23 +137,38 @@ def generate_backward_split_gpu(**kwargs: Any) -> None: is_gwd=True, ) - # Generate optimizer kernel - CodeTemplate.load( - "training/optimizer/embedding_optimizer_split_device_kernel_template.cuh" - ).write( - f"gen_embedding_optimizer_{optimizer}_split_device_kernel.cuh", **kwargs - ) + for ssd in ( + [True, False] + if kwargs.get("has_ssd_support") and not kwargs.get("dense") + else [False] + ): + desc = f"{ 'ssd' if ssd else 'split' }" + # Generate optimizer kernel + CodeTemplate.load( + "training/optimizer/embedding_optimizer_split_device_kernel_template.cuh" + ).write( + f"gen_embedding_optimizer_{optimizer}_{desc}_device_kernel.cuh", + ssd=ssd, + **kwargs, + ) # Generate the backward splits (non-dense) # We generate only the API to preserve the backward compatibility if # has_gpu_support=True if not kwargs.get("dense"): - # Generate CUDA autograd, PT2 unified autograd, and PT2 backward wrapper + # Generate CUDA autograd + template_filepath = ( + 
"training/backward/embedding_backward_split_host_template.cpp" + ) + for ssd in [True, False] if kwargs.get("has_ssd_support") else [False]: + desc = "ssd" if ssd else "split" + filename = f"gen_embedding_backward_{desc}_{optimizer}.cpp" + CodeTemplate.load(template_filepath).write( + filename, is_forward=False, ssd=ssd, **kwargs + ) + + # Generate PT2 unified autograd, and PT2 backward wrapper for template_filepath, filename in [ - ( - "training/backward/embedding_backward_split_host_template.cpp", - f"gen_embedding_backward_split_{optimizer}.cpp", - ), ( "training/pt2/embedding_split_host_pt2_autograd_template.cpp", f"gen_embedding_split_{optimizer}_pt2_autograd.cpp", @@ -153,11 +187,15 @@ def generate_backward_split_gpu(**kwargs: Any) -> None: template = CodeTemplate.load( "training/python/split_embedding_codegen_lookup_invoker.template" ) - for filename in [ - f"lookup_{optimizer}.py", - f"lookup_{optimizer}_pt2.py", - ]: - template.write(filename, is_fbcode=args.is_fbcode, **kwargs) + for ssd in [True, False] if kwargs.get("has_ssd_support") else [False]: + sdesc = "_ssd" if ssd else "" + for filename in [ + f"lookup_{optimizer}{sdesc}.py", + f"lookup_{optimizer}{sdesc}_pt2.py", + ]: + template.write( + filename, is_fbcode=args.is_fbcode, ssd=ssd, **kwargs + ) @staticmethod def generate_backward_split_cpu(**kwargs: Any) -> None: @@ -213,10 +251,11 @@ def generate_backward_device() -> None: BackwardSplitGenerator.render_backward_templates( template_filepath, "", - "{}gen_embedding_backward_{}_split_device_kernel.cuh", + "{}gen_embedding_backward_{}_device_kernel.cuh", { "has_gpu_support": True, "has_vbe_support": True, + "has_ssd_support": True, "dense": False, "gen_once": False, }, @@ -224,7 +263,7 @@ def generate_backward_device() -> None: # Generate common backward device kernels (generate only once) CodeTemplate.load(template_filepath).write( - "gen_embedding_backward_common_split_device_kernel.cuh", + "gen_embedding_backward_split_common_device_kernel.cuh", gen_once=True, ) @@ -242,16 +281,31 @@ def generate_backward_indices() -> None: template = CodeTemplate.load( "training/backward/embedding_backward_split_indice_weights_template.cu" ) - for dense in [True, False]: + dense_options = [True, False] + ssd_options = [True, False] + for dense, ssd in itertools.product(dense_options, ssd_options): + if dense and ssd: + continue + desc = "dense" if dense else ("ssd" if ssd else "split") template.write( - f"gen_embedding_backward_{'dense' if dense else 'split'}_indice_weights_codegen_cuda.cu", + f"gen_embedding_backward_{ desc }_indice_weights_codegen_cuda.cu", dense=dense, + ssd=ssd, ) @staticmethod - def generate_python_sources() -> None: - CodeTemplate.load("training/python/__init__.template").write("__init__.py") - CodeTemplate.copy_to_root("training/python/lookup_args.py") + def generate_python_sources( + all_optimizers: List[str], ssd_optimizers: List[str] + ) -> None: + CodeTemplate.load("training/python/__init__.template").write( + "__init__.py", all_optimizers=all_optimizers, ssd_optimizers=ssd_optimizers + ) + + template = CodeTemplate.load("training/python/lookup_args.template") + for ssd in [True, False]: + sdesc = "_ssd" if ssd else "" + filename = f"lookup_args{sdesc}.py" + template.write(filename, ssd=ssd) @staticmethod def generate() -> None: @@ -276,8 +330,28 @@ def generate() -> None: none_optimizer(), ] + ssd_tensors = [ + "row_addrs", + "inserted_rows", + "post_bwd_evicted_indices", + "actions_count", + ] + + all_optimizers = [] + ssd_optimizers = [] + for optimizer 
in optimizers: - BackwardSplitGenerator.generate_backward_split(**optimizer) + optim = optimizer["optimizer"] + if ( + optimizer["has_cpu_support"] or optimizer["has_gpu_support"] + ) and optim != "dense": + all_optimizers.append(optim) + if optimizer["has_ssd_support"]: + ssd_optimizers.append(optim) + + BackwardSplitGenerator.generate_backward_split( + ssd_tensors=ssd_tensors, **optimizer + ) # Generate common device kernels for backwards BackwardSplitGenerator.generate_backward_device() @@ -286,7 +360,7 @@ def generate() -> None: BackwardSplitGenerator.generate_backward_grad() BackwardSplitGenerator.generate_backward_indices() - BackwardSplitGenerator.generate_python_sources() + BackwardSplitGenerator.generate_python_sources(all_optimizers, ssd_optimizers) def main() -> None: diff --git a/fbgemm_gpu/codegen/genscript/generate_forward_split.py b/fbgemm_gpu/codegen/genscript/generate_forward_split.py index d907a11ff..bb6488b54 100644 --- a/fbgemm_gpu/codegen/genscript/generate_forward_split.py +++ b/fbgemm_gpu/codegen/genscript/generate_forward_split.py @@ -9,6 +9,7 @@ # flake8: noqa F401 import argparse +import itertools import sys from typing import List @@ -27,36 +28,42 @@ def render_forward_templates( dense_options: List[bool], nobag_options: List[bool], vbe_options: List[bool], + ssd_options: List[bool], is_gwd: bool = False, ) -> None: template = CodeTemplate.load(template_filepath) - for dense in dense_options: - for weighted in [True, False]: - for nobag in nobag_options: - for vbe in vbe_options: - if (not nobag or (not weighted and not vbe)) and ( - not dense or not vbe - ): - dense_desc = f"{ 'dense' if dense else 'split'}" - weight_desc = ( - f"{ 'weighted' if weighted else 'unweighted' }" - ) - nobag_desc = f"{ '_nobag' if nobag else '' }" - vbe_desc = f"{ '_vbe' if vbe else '' }" - - template.write( - filename_format.format( - f"{ dense_desc }_{ weight_desc }{ nobag_desc }{ vbe_desc }" - ), - dense=dense, - weighted=weighted, - nobag=nobag, - vbe=vbe, - is_index_select=False, - is_gwd=is_gwd, - ) + weighted_options = [True, False] + + for dense, weighted, nobag, vbe, ssd in itertools.product( + dense_options, weighted_options, nobag_options, vbe_options, ssd_options + ): + if nobag and (weighted or vbe): + continue + if dense and (vbe or ssd): + continue + if ssd and (vbe or is_gwd): + continue + + desc = "".join( + [ + f"{ 'dense' if dense else ('ssd' if ssd else 'split') }", + f"{ '_weighted' if weighted else '_unweighted' }", + f"{ '_nobag' if nobag else '' }", + f"{ '_vbe' if vbe else '' }", + ] + ) + fname = filename_format.format(desc) + template.write( + fname, + dense=dense, + weighted=weighted, + nobag=nobag, + vbe=vbe, + ssd=ssd, + is_index_select=False, + is_gwd=is_gwd, + ) - @staticmethod def generate_pt2_wrappers() -> None: # Generate PT2 forward wrapper (CUDA) CodeTemplate.load( @@ -84,12 +91,14 @@ def generate_small_kernels() -> None: "training/forward/embedding_forward_split_kernel_nobag_small_template.cu" ) for dense in [True, False]: - wdesc = f"{ 'dense' if dense else 'split' }" - template.write( - f"gen_embedding_forward_{wdesc}_unweighted_nobag_kernel_small.cu", - dense=dense, - is_index_select=False, - ) + for ssd in [True, False]: + ddesc = f"{ 'dense' if dense else ('ssd' if ssd else 'split') }" + template.write( + f"gen_embedding_forward_{ ddesc }_unweighted_nobag_kernel_small.cu", + dense=dense, + ssd=ssd, + is_index_select=False, + ) @staticmethod def generate_kernels() -> None: @@ -100,6 +109,7 @@ def generate_kernels() -> None: 
dense_options=[True, False], nobag_options=[False], # nobag is not used vbe_options=[True, False], + ssd_options=[True, False], ) # Generate the CUDA host code for global weight decay ForwardSplitGenerator.render_forward_templates( @@ -109,6 +119,7 @@ def generate_kernels() -> None: nobag_options=[False], # nobag is not used vbe_options=[False], is_gwd=True, + ssd_options=[False], ) # Generate the meta kernels @@ -118,6 +129,7 @@ def generate_kernels() -> None: dense_options=[True, False], nobag_options=[False], # nobag is not used vbe_options=[True, False], + ssd_options=[True, False], ) # Generate the CUDA kernels @@ -127,6 +139,7 @@ def generate_kernels() -> None: dense_options=[True, False], nobag_options=[True, False], vbe_options=[True, False], + ssd_options=[True, False], ) # Generate the global weight decay CUDA kernels ForwardSplitGenerator.render_forward_templates( @@ -135,6 +148,7 @@ def generate_kernels() -> None: dense_options=[False], nobag_options=[False], vbe_options=[False], + ssd_options=[False], is_gwd=True, ) @@ -145,6 +159,7 @@ def generate_kernels() -> None: dense_options=[False], # dense is not supported nobag_options=[False], # nobag is not supported vbe_options=[False], # vbe is not supported + ssd_options=[False], # ssd is not supported ) @staticmethod diff --git a/fbgemm_gpu/codegen/genscript/generate_index_select.py b/fbgemm_gpu/codegen/genscript/generate_index_select.py index c28ab9b60..22dc67280 100644 --- a/fbgemm_gpu/codegen/genscript/generate_index_select.py +++ b/fbgemm_gpu/codegen/genscript/generate_index_select.py @@ -58,7 +58,7 @@ def generate() -> None: ), ( "training/backward/embedding_backward_split_device_kernel_template.cuh", - "gen_embedding_backward_batch_index_select_split_device_kernel.cuh", + "gen_embedding_backward_split_batch_index_select_device_kernel.cuh", ), ]: CodeTemplate.load(template_file).write( @@ -84,7 +84,7 @@ def generate() -> None: CodeTemplate.load( "training/backward/embedding_backward_split_device_kernel_template.cuh" ).write( - "gen_embedding_backward_common_split_device_kernel.cuh", + "gen_embedding_backward_split_common_device_kernel.cuh", gen_once=True, ) diff --git a/fbgemm_gpu/codegen/genscript/jinja_environment.py b/fbgemm_gpu/codegen/genscript/jinja_environment.py index 5e6d96a4a..1ac6cd705 100644 --- a/fbgemm_gpu/codegen/genscript/jinja_environment.py +++ b/fbgemm_gpu/codegen/genscript/jinja_environment.py @@ -289,13 +289,20 @@ def is_valid_forward_config( def has_experimental_support( - dense: bool, nobag: bool, vbe: bool, is_index_select: bool, is_rocm: bool + dense: bool, nobag: bool, vbe: bool, is_index_select: bool, is_rocm: bool, ssd: bool ) -> bool: """ Check if the given combination of configs has TBE v2 support - - TBE v2 does not support dense, nobag, vbe, is_index_select, and is_rocm + - TBE v2 does not support dense, nobag, vbe, is_index_select, is_rocm, and ssd """ - return not dense and not nobag and not vbe and not is_index_select and not is_rocm + return ( + not dense + and not nobag + and not vbe + and not is_index_select + and not is_rocm + and not ssd + ) def is_valid_gwd_config( @@ -303,7 +310,8 @@ def is_valid_gwd_config( nobag: bool, vbe: bool, is_index_select: bool, - has_global_weight_decay_support: bool = True, + has_global_weight_decay_support: bool, + ssd: bool, ) -> bool: """ Check if the given combination of configs is valid for global weight decay support @@ -318,6 +326,7 @@ def is_valid_gwd_config( and not vbe and not is_index_select and has_global_weight_decay_support + and not ssd ) 
diff --git a/fbgemm_gpu/codegen/genscript/optimizers.py b/fbgemm_gpu/codegen/genscript/optimizers.py index 179b80f74..3b2073692 100644 --- a/fbgemm_gpu/codegen/genscript/optimizers.py +++ b/fbgemm_gpu/codegen/genscript/optimizers.py @@ -42,6 +42,7 @@ def dense() -> Dict[str, Any]: "has_gpu_support": True, "has_vbe_support": False, "has_global_weight_decay_support": False, + "has_ssd_support": False, } @@ -86,6 +87,7 @@ def adagrad() -> Dict[str, Any]: "has_gpu_support": True, "has_vbe_support": False, "has_global_weight_decay_support": False, + "has_ssd_support": False, } @@ -255,6 +257,7 @@ def rowwise_adagrad() -> Dict[str, Any]: "has_gpu_support": True, "has_vbe_support": True, "has_global_weight_decay_support": True, + "has_ssd_support": True, } @@ -286,6 +289,7 @@ def approx_rowwise_adagrad() -> Dict[str, Any]: "has_gpu_support": False, "has_vbe_support": False, "has_global_weight_decay_support": False, + "has_ssd_support": False, } @@ -392,6 +396,7 @@ def rowwise_adagrad_with_weight_decay() -> Dict[str, Any]: "has_gpu_support": False, "has_vbe_support": False, "has_global_weight_decay_support": False, + "has_ssd_support": False, } @@ -428,6 +433,7 @@ def approx_rowwise_adagrad_with_weight_decay() -> Dict[str, Any]: "has_gpu_support": False, "has_vbe_support": False, "has_global_weight_decay_support": False, + "has_ssd_support": False, } @@ -599,6 +605,7 @@ def rowwise_adagrad_with_counter() -> Dict[str, Any]: "has_gpu_support": True, "has_vbe_support": True, "has_global_weight_decay_support": False, + "has_ssd_support": False, } @@ -648,6 +655,7 @@ def approx_rowwise_adagrad_with_counter() -> Dict[str, Any]: "has_gpu_support": False, "has_vbe_support": False, "has_global_weight_decay_support": False, + "has_ssd_support": False, } @@ -726,6 +734,7 @@ def rowwise_weighted_adagrad() -> Dict[str, Any]: "has_gpu_support": False, "has_vbe_support": False, "has_global_weight_decay_support": False, + "has_ssd_support": False, } @@ -750,6 +759,7 @@ def sgd() -> Dict[str, Any]: "has_gpu_support": True, "has_vbe_support": True, "has_global_weight_decay_support": False, + "has_ssd_support": False, } @@ -774,6 +784,7 @@ def approx_sgd() -> Dict[str, Any]: "has_gpu_support": False, "has_vbe_support": False, "has_global_weight_decay_support": False, + "has_ssd_support": False, } @@ -852,6 +863,7 @@ def lamb() -> Dict[str, Any]: "has_gpu_support": True, "has_vbe_support": False, "has_global_weight_decay_support": False, + "has_ssd_support": False, } @@ -944,6 +956,7 @@ def partial_rowwise_lamb() -> Dict[str, Any]: "has_gpu_support": True, "has_vbe_support": False, "has_global_weight_decay_support": False, + "has_ssd_support": False, } @@ -1000,6 +1013,7 @@ def adam() -> Dict[str, Any]: "has_gpu_support": True, "has_vbe_support": False, "has_global_weight_decay_support": False, + "has_ssd_support": False, } @@ -1075,6 +1089,7 @@ def partial_rowwise_adam() -> Dict[str, Any]: "has_gpu_support": True, "has_vbe_support": False, "has_global_weight_decay_support": False, + "has_ssd_support": False, } @@ -1140,6 +1155,7 @@ def lars_sgd() -> Dict[str, Any]: "has_gpu_support": True, "has_vbe_support": False, "has_global_weight_decay_support": False, + "has_ssd_support": False, } @@ -1158,4 +1174,5 @@ def none_optimizer() -> Dict[str, Any]: "has_gpu_support": True, "has_vbe_support": False, "has_global_weight_decay_support": False, + "has_ssd_support": False, } diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp 
b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp index 8fcc77cce..137a09302 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp @@ -20,6 +20,17 @@ using Tensor = at::Tensor; using namespace fbgemm_gpu; +{#/* Module description */#} +{%- set mdesc = "ssd" if ssd else "split" %} + +{%- if ssd %} +enum SSDTensor { + {%- for tensor in ssd_tensors %} + {{ tensor | upper }} = {{ loop.index - 1 }}, + {%- endfor %} +}; +{%- endif %} + //////////////////////////////////////////////////////////////////////////////// // Macro Helper Functions //////////////////////////////////////////////////////////////////////////////// @@ -31,22 +42,21 @@ using namespace fbgemm_gpu; */ #} {%- macro call_forward_op_dispatch(nobag, weighted, vbe, is_gwd) %} - {%- set vdesc = "_vbe" if vbe else "" %} - {%- set gwddesc = "_gwd" if is_gwd else "" %} - {%- set wdesc = "weighted" if weighted else "unweighted" %} - {%- set nobag = "_nobag" if nobag else "" %} - - {%- set forward_op = "split_embedding{}_codegen_forward_{}{}{}_cuda".format( - nobag, wdesc, vdesc, gwddesc + {%- set forward_op = "{}_embedding{}_codegen_forward_{}{}{}_cuda".format( + mdesc, + "_nobag" if nobag else "", + "weighted" if weighted else "unweighted", + "_vbe" if vbe else "", + "_gwd" if is_gwd else "", ) %} - static auto split_embedding_codegen_forward_op = + static auto embedding_codegen_forward_op = torch::Dispatcher::singleton() .findSchemaOrThrow("fbgemm::{{ forward_op }}", "") .typed(); return { - split_embedding_codegen_forward_op.call( + embedding_codegen_forward_op.call( flatten_dev_weights, uvm_weights, lxu_cache_weights, @@ -67,7 +77,11 @@ using namespace fbgemm_gpu; *indice_weights, {%- endif %} {%- endif %} {# /* if not nobag */ #} + {%- if ssd %} + ssd_tensors[SSDTensor::ROW_ADDRS], + {%- else %} lxu_cache_locations, + {%- endif %} uvm_cache_stats_, output_dtype, {%- if not nobag %} @@ -99,20 +113,22 @@ using namespace fbgemm_gpu; unweighted backward op via Pytorch dispatcher */ {%- macro call_backward_op_dispatch(nobag, weighted, vbe, is_gwd) %} - {%- set nobag = "_nobag" if nobag else "" %} - {%- set vdesc = "_vbe" if vbe else "" %} - {%- set gwddesc = "_gwd" if is_gwd else "" %} - {%- set wdesc = "weighted" if weighted else "unweighted" %} - {%- set backward_op = "split_embedding{}_backward_codegen_{}_{}_exact{}{}_cuda".format( - nobag, optimizer, wdesc, vdesc, gwddesc + {%- set wdesc = "_weighted" if weighted else "_unweighted" %} + {%- set backward_op = "{}_embedding{}_backward_codegen_{}{}{}{}_exact_cuda".format( + mdesc, + "_nobag" if nobag else "", + optimizer, + wdesc, + "_vbe" if vbe else "", + "_gwd" if is_gwd else "", ) %} - static auto split_embedding_codegen_{{ wdesc }}_backward_op = + static auto embedding_codegen_{{ wdesc }}_backward_op = torch::Dispatcher::singleton() .findSchemaOrThrow("fbgemm::{{ backward_op }}", "") .typed(); - grad_dev_weights = split_embedding_codegen_{{ wdesc }}_backward_op.call( + grad_dev_weights = embedding_codegen_{{ wdesc }}_backward_op.call( grad_output, dev_weights, uvm_weights, @@ -135,7 +151,11 @@ using namespace fbgemm_gpu; indice_weights, {%- endif %} {%- endif %} {# /* if not nobag */ #} + {%- if ssd %} + ssd_row_addrs, + {%- else %} lxu_cache_locations, + {%- endif %} BT_block_size, max_segment_length_per_warp, {%- if optimizer != "none" %} @@ -209,6 +229,11 @@ using namespace fbgemm_gpu; Variable(), // iter {%- endif %} {%- 
endif %} + {%- if ssd %} + {%- for tensor in ssd_tensors %} + Variable(), // {{ tensor }} + {%- endfor %} + {%- endif %} {{ args.split_variables | join(", ") }} }; {%- endmacro %} @@ -217,7 +242,8 @@ using namespace fbgemm_gpu; from lookup_function */ {%- macro call_autograd(nobag, vbe, is_gwd) %} - {%- set autograd_func = "Split{}{}{}LookupFunction_{}_Op".format( + {%- set autograd_func = "{}{}{}{}LookupFunction_{}_Op".format( + "SSD" if ssd else "Split", "NoBag" if nobag else "", "VBE" if vbe else "", "GWD" if is_gwd else "", @@ -274,6 +300,9 @@ using namespace fbgemm_gpu; iter, {%- endif %} {%- endif %} + {%- if ssd %} + ssd_tensors.value(), + {%- endif %} {{ args.split_function_arg_names | join(", ") }})[0]; {%- endmacro %} @@ -287,9 +316,34 @@ using namespace fbgemm_gpu; {%- for vbe in ([True, False] if has_vbe_support else [False]) %} {%- set vdesc = "_vbe" if vbe else "" %} -{%- for weighted in [True, False] %} +Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda( + const Tensor& grad_output, + const Tensor& dev_weights, + const Tensor& uvm_weights, + const Tensor& lxu_cache_weights, + const Tensor& weights_placements, + const Tensor& weights_offsets, + const Tensor& D_offsets, + const c10::SymInt max_D, + const Tensor& indices, + const Tensor& offsets, + {%- if ssd %} + const Tensor& ssd_row_addrs, + {%- else %} + const Tensor& lxu_cache_locations, + {%- endif %} + {%- if vbe %} + const Tensor& feature_requires_grad, + const Tensor& vbe_row_output_offsets, + const Tensor& vbe_b_t_map, + const int64_t info_B_num_bits, + const int64_t info_B_mask_int64 + {%- else %} + const Tensor& feature_requires_grad + {%- endif %} +); + {%- for nobag in ([False] if (weighted or vbe) else [True, False]) %} -{%- set wdesc = "weighted" if weighted else "unweighted" %} {%- set ndesc = "_nobag" if nobag else "" %} {%- for is_gwd in ([True, False] @@ -298,11 +352,17 @@ using namespace fbgemm_gpu; nobag, vbe, is_index_select, - has_global_weight_decay_support - ) + has_global_weight_decay_support, + ssd) else [False]) %} {%- set gwddesc = "_gwd" if is_gwd else "" %} -Tensor split_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ vdesc }}{{ gwddesc }}_cuda( + +{%- for weighted in [True, False] %} +{%- set wdesc = "_weighted" if weighted else "_unweighted" %} + +{%- set desc_suffix = wdesc + vdesc + gwddesc %} + +Tensor {{ mdesc }}_embedding{{ ndesc }}_codegen_forward{{ desc_suffix }}_cuda( const Tensor& dev_weights, const Tensor& uvm_weights, const Tensor& lxu_cache_weights, @@ -323,7 +383,11 @@ Tensor split_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ vdesc }}{{ gwdde {%- if weighted %} const Tensor& indice_weights, {%- endif %} + {%- if ssd %} + const Tensor& ssd_row_addrs, + {%- else %} const Tensor& lxu_cache_locations, + {%- endif %} const Tensor& uvm_cache_stats, const int64_t output_dtype, {%- if vbe %} @@ -343,9 +407,10 @@ Tensor split_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ vdesc }}{{ gwdde {%- else %} const bool is_experimental {%- endif %} - ); +); -Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_exact{{ vdesc }}{{ gwddesc }}_cuda( +Tensor +{{ mdesc }}_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}{{ desc_suffix }}_exact_cuda( const Tensor& grad_output, const Tensor& dev_weights, const Tensor& uvm_weights, @@ -368,7 +433,11 @@ Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_e {%- if weighted %} const Tensor& indice_weights, {%- endif %} + {%- if ssd %} + const Tensor& ssd_row_addrs, + 
{%- else %} const Tensor& lxu_cache_locations, + {%- endif %} const int64_t BT_block_size, const int64_t max_segment_length_per_warp, {%- if optimizer != "none" %} @@ -392,51 +461,15 @@ Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_e {%- endif %} {%- endif %} {{ args.split_function_args | join(", ") }}); - -{%- endfor %} {#-/*for is_gwd*/#} -{%- endfor %} {#-/*for nobag*/#} -{%- endfor %} {#-/*for weighted*/#} - -Tensor split_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda( - const Tensor& grad_output, - const Tensor& dev_weights, - const Tensor& uvm_weights, - const Tensor& lxu_cache_weights, - const Tensor& weights_placements, - const Tensor& weights_offsets, - const Tensor& D_offsets, - const c10::SymInt max_D, - const Tensor& indices, - const Tensor& offsets, - const Tensor& lxu_cache_locations, - {%- if vbe %} - const Tensor& feature_requires_grad, - const Tensor& vbe_row_output_offsets, - const Tensor& vbe_b_t_map, - const int64_t info_B_num_bits, - const int64_t info_B_mask_int64 - {%- else %} - const Tensor& feature_requires_grad - {%- endif %} -); +{%- endfor %} {#-/* for weighted*/#} //////////////////////////////////////////////////////////////////////////////// // Autograd Function Declarations //////////////////////////////////////////////////////////////////////////////// -{%- for nobag in [True, False] %} -{%- if not nobag or not vbe %} {#-/* nobag does not support vbe */#} {#- /* Generate a separate autograd function for global weight decay */ #} -{%- for is_gwd in ([True, False] - if is_valid_gwd_config( - dense, - nobag, - vbe, - is_index_select, - has_global_weight_decay_support - ) - else [False]) %} -{%- set autograd_func = "Split{}{}{}LookupFunction_{}_Op".format( +{%- set autograd_func = "{}{}{}{}LookupFunction_{}_Op".format( + "SSD" if ssd else "Split", "NoBag" if nobag else "", "VBE" if vbe else "", "GWD" if is_gwd else "", @@ -444,7 +477,6 @@ Tensor split_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda( ) %} - class {{ autograd_func }} : public torch::autograd::Function<{{ autograd_func }}> { public: @@ -499,6 +531,9 @@ class {{ autograd_func }} : const int64_t iter, {%- endif %} {%- endif %} + {%- if ssd %} + const at::TensorList& ssd_tensors, + {%- endif %} {{ args.split_function_args | join(", ") }}) { const auto T = weights_offsets.sym_numel(); @@ -590,6 +625,11 @@ class {{ autograd_func }} : {%- if is_gwd and "prev_iter_dev" not in args.split_function_arg_names %} prev_iter_dev_, {%- endif %} + {%- if ssd %} + {%- for tensor in ssd_tensors %} + ssd_tensors[SSDTensor::{{ tensor | upper }}], + {%- endfor %} + {%- endif %} {{ args.split_saved_tensors | join(", ") }} }); @@ -627,12 +667,33 @@ class {{ autograd_func }} : {%- endif %} {%- if nobag %} - {{ call_forward_op_dispatch(nobag=True, weighted=False, vbe=vbe, is_gwd=is_gwd) }} + {{ + call_forward_op_dispatch( + nobag=True, + weighted=False, + vbe=vbe, + is_gwd=is_gwd, + ) + }} {%- else %} if (indice_weights) { - {{ call_forward_op_dispatch(nobag=False, weighted=True, vbe=vbe, is_gwd=is_gwd) }} + {{ + call_forward_op_dispatch( + nobag=False, + weighted=True, + vbe=vbe, + is_gwd=is_gwd, + ) + }} } - {{ call_forward_op_dispatch(nobag=False, weighted=False, vbe=vbe, is_gwd=is_gwd) }} + {{ + call_forward_op_dispatch( + nobag=False, + weighted=False, + vbe=vbe, + is_gwd=is_gwd, + ) + }} {%- endif %} {#-/* if not nobag */ #} } @@ -666,6 +727,11 @@ class {{ autograd_func }} : {%- if is_gwd and "prev_iter_dev" not in args.split_function_arg_names %} auto prev_iter_dev = 
*savedItr++; {%- endif %} + {%- if ssd %} + {%- for tensor in ssd_tensors %} + auto ssd_{{ tensor }} = *savedItr++; + {%- endfor %} + {%- endif %} {%- for tensor in args.split_saved_tensors %} auto {{ tensor }} = *savedItr++; @@ -719,21 +785,21 @@ class {{ autograd_func }} : {%- if not nobag %} {%- if optimizer == "none" %} // Flatten (dev_weights is used in - // split_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda) + // {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda) dev_weights = dev_weights.flatten(); {%- endif %} {%- set grad_indice_weights_op = - "split_embedding_codegen_grad_indice_weights{}_cuda".format(vdesc) + "{}_embedding_codegen_grad_indice_weights{}_cuda".format(mdesc, vdesc) %} - static auto split_embedding_codegen_grad_indice_weights_op = + static auto {{ mdesc }}_embedding_codegen_grad_indice_weights_op = torch::Dispatcher::singleton() .findSchemaOrThrow("fbgemm::{{ grad_indice_weights_op }}", "") .typed(); const auto grad_indice_weights = !indice_weights.defined() ? Variable() : - split_embedding_codegen_grad_indice_weights_op.call( + {{ mdesc }}_embedding_codegen_grad_indice_weights_op.call( grad_output, dev_weights, uvm_weights, @@ -744,7 +810,11 @@ class {{ autograd_func }} : max_D, indices, offsets, + {%- if ssd %} + ssd_row_addrs, + {%- else %} lxu_cache_locations, + {%- endif %} {%- if vbe %} feature_requires_grad, vbe_row_output_offsets, @@ -758,24 +828,43 @@ class {{ autograd_func }} : Tensor grad_dev_weights; if (indice_weights.defined()) { - {{ call_backward_op_dispatch(nobag=False, weighted=True, vbe=vbe, is_gwd=is_gwd) }} + {{ + call_backward_op_dispatch( + nobag=False, + weighted=True, + vbe=vbe, + is_gwd=is_gwd, + ) + }} } - {{ call_backward_op_dispatch(nobag=False, weighted=False, vbe=vbe, is_gwd=is_gwd) }} - + {{ + call_backward_op_dispatch( + nobag=False, + weighted=False, + vbe=vbe, + is_gwd=is_gwd, + ) + }} {%- else %} Tensor grad_dev_weights; - {{ call_backward_op_dispatch(nobag=True, weighted=False, vbe=vbe, is_gwd=is_gwd) }} + {{ + call_backward_op_dispatch( + nobag=True, + weighted=False, + vbe=vbe, + is_gwd=is_gwd, + ) + }} {%- endif %} } }; {%- endfor %} {#-/* for is_gwd */#} -{%- endif %} {#-/* if not nobag or not vbe */#} -{%- endfor %} {#-/* for nobag in [True, False] */#} -{%- endfor %} {#-/* for vbe in [True, False] */#} +{%- endfor %} {#-/* for nobag */#} +{%- endfor %} {#-/* for vbe */#} {%- endif %} {#-/* if has_gpu_support */#} ///@ingroup embedding-cuda -Tensor split_embedding_codegen_lookup_{{ optimizer }}_function( +Tensor {{ mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function( const Tensor& placeholder_autograd_tensor, const Tensor& dev_weights, const Tensor& uvm_weights, @@ -816,15 +905,17 @@ Tensor split_embedding_codegen_lookup_{{ optimizer }}_function( {%- if "iter" not in args.split_function_arg_names %} const int64_t iter = 0, {%- endif %} + {%- if ssd %} + const bool apply_global_weight_decay = false, + const c10::optional& ssd_tensors = c10::nullopt + {%- else %} const bool apply_global_weight_decay = false + {%- endif %} ) { // TODO: refactor into macro {%- if has_gpu_support %} - if (static_cast(pooling_mode) == PoolingMode::NONE) { - // no bag - {{ call_autograd(nobag=True, vbe=False, is_gwd=False) }} - } + {%- if not ssd %} {%- if has_vbe_support %} // has vbe support if (B_offsets.has_value()) { @@ -840,11 +931,26 @@ Tensor split_embedding_codegen_lookup_{{ optimizer }}_function( {{ call_autograd(nobag=False, vbe=False, is_gwd=True) }} } {%- endif %} {#-/* if has_global_weight_decay_support 
*/ #} + {%- endif %} {#-/* if not ssd */#} - {{ call_autograd(nobag=False, vbe=False, is_gwd=False) }} + {%- if ssd %} + TORCH_CHECK( + ssd_tensors.value().size() == {{ ssd_tensors | length }}, + "SSD TBE expects {{ ssd_tensors | length }} in ssd_tensors"); + {%- endif %} + if (static_cast(pooling_mode) == PoolingMode::NONE) { + // no bag + {{ call_autograd(nobag=True, vbe=False, is_gwd=False) }} + } + else { + {{ call_autograd(nobag=False, vbe=False, is_gwd=False) }} + } {%- else %} - TORCH_CHECK(false, "split_embedding_codegen_lookup_{{ optimizer }}_function is deprecated. Please see https://github.com/pytorch/FBGEMM/discussions/1727 for more detail."); + TORCH_CHECK( + false, + "{{ mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function is deprecated. Please see https://github.com/pytorch/FBGEMM/discussions/1727 for more detail." + ); return Tensor(); {%- endif %} {#-/* if has_gpu_support */#} } @@ -852,7 +958,8 @@ Tensor split_embedding_codegen_lookup_{{ optimizer }}_function( // Deprecated for fb namespace! Please use fbgemm namespace instead! {%- for lib_name in ["fb", "fbgemm"] %} TORCH_LIBRARY_FRAGMENT({{ lib_name }}, m) { - m.def("split_embedding_codegen_lookup_{{ optimizer }}_function(" + {%- set op_name = "{}_embedding_codegen_lookup_{}_function".format(mdesc, optimizer) %} + m.def("{{ op_name }}(" " Tensor placeholder_autograd_tensor, " " Tensor dev_weights, " " Tensor uvm_weights, " @@ -893,7 +1000,12 @@ TORCH_LIBRARY_FRAGMENT({{ lib_name }}, m) { {%- if "iter" not in args.split_function_arg_names %} " int iter=0, " {%- endif %} - " bool apply_global_weight_decay=False " + {%- if ssd %} + " bool apply_global_weight_decay=False, " + " Tensor[]? ssd_tensors=None" + {%- else %} + " bool apply_global_weight_decay=False" + {%- endif %} ") -> Tensor", {PT2_COMPLIANT_TAG}); // We're playing a funny trick here: we're using the autograd @@ -902,18 +1014,18 @@ TORCH_LIBRARY_FRAGMENT({{ lib_name }}, m) { // no autograd enabled, and all of the internal implementations redispatch // appropriately m.impl( - "split_embedding_codegen_lookup_{{ optimizer }}_function", + "{{ op_name }}", torch::dispatch( c10::DispatchKey::Autograd, - TORCH_FN(split_embedding_codegen_lookup_{{ optimizer }}_function))); + TORCH_FN({{ op_name }}))); m.impl( - "split_embedding_codegen_lookup_{{ optimizer }}_function", + "{{ op_name }}", torch::dispatch( c10::DispatchKey::Meta, - TORCH_FN(split_embedding_codegen_lookup_{{ optimizer }}_function))); + TORCH_FN({{ op_name }}))); DISPATCH_TO_CUDA( - "split_embedding_codegen_lookup_{{ optimizer }}_function", - split_embedding_codegen_lookup_{{ optimizer }}_function); + "{{ op_name }}", + {{ op_name }}); } {%- endfor %} {#-/* for lib_name */#} - // clang-format on + // clang-format on diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu index 501137b61..8a1c1b052 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu @@ -8,7 +8,11 @@ // clang-format off -{%- set ddesc = "dense" if dense else "split" %} +{%- set mdesc = "dense" if dense else ("ssd" if ssd else "split") %} + +{%- set locs_or_addrs_tensor = "ssd_row_addrs" if ssd else "lxu_cache_locations" %} +{%- set locs_or_addrs_type = "int64_t" if ssd else "int32_t" %} +{%- set locs_or_addrs_idx = "row_idx" if ssd else "cache_idx" %} 
//////////////////////////////////////////////////////////////////////////////// // Required for op registrations @@ -41,9 +45,8 @@ using namespace fbgemm_gpu; -}} }() -{%- for vbe in [True, False] %} +{%- for vbe in ([True, False] if (not dense and not ssd) else [False]) %} {%- set vdesc = "_vbe" if vbe else "" %} -{%- if not dense or not vbe %} {#- /* Generate different kernels for different kUseVecBlocking using Jinja @@ -64,7 +67,7 @@ template < int32_t kFixedMaxVecsPerThread > __global__ __launch_bounds__(kForwardMaxThreads) void -{{ ddesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_{{ vbdesc }}kernel( +{{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_{{ vbdesc }}kernel( // [\sum_t E_t x D_t] const pta::PackedTensorAccessor64 grad_output, pta::PackedTensorAccessor64 dev_weights, @@ -78,7 +81,7 @@ __global__ __launch_bounds__(kForwardMaxThreads) void const pta::PackedTensorAccessor32 indices, // [N = \sum_{b,t} L_{b,t} total indices, i.e. flattened [B][T][L] const pta::PackedTensorAccessor32 offsets, // [B x T + 1] {%- if not dense %} - const pta::PackedTensorAccessor32 lxu_cache_locations, + const pta::PackedTensorAccessor32<{{ locs_or_addrs_type }}, 1, at::RestrictPtrTraits> {{ locs_or_addrs_tensor }}, {%- endif %} pta::PackedTensorAccessor32 feature_requires_grad, // [T], pta::PackedTensorAccessor32, 1, at::RestrictPtrTraits> grad_indice_weights, @@ -172,14 +175,14 @@ __global__ __launch_bounds__(kForwardMaxThreads) void int32_t l = l_start + threadIdx.x; int64_t idx = l < L ? indices[indices_start + l] : 0; {%- if not dense %} - int32_t cache_idx = + const auto {{ locs_or_addrs_idx }} = (placement == PlacementType::MANAGED_CACHING && l < L) - ? lxu_cache_locations[indices_start + l] : 0; + ? {{ locs_or_addrs_tensor }}[indices_start + l] : 0; {%- endif %} for (auto j = 0; j < kWarpSize && l_start + j < L; ++j) { int64_t idx_j = shfl_sync(idx, j); {%- if not dense %} - int32_t cache_idx_j = shfl_sync(cache_idx, j); + const auto {{ locs_or_addrs_idx }}_j = shfl_sync({{ locs_or_addrs_idx }}, j); {%- endif %} at::acc_type grad_indice_weight = 0.0; @@ -189,8 +192,20 @@ __global__ __launch_bounds__(kForwardMaxThreads) void ++vec) { const int32_t d = {{ d }}; {%- if not dense %} - if (placement == PlacementType::MANAGED_CACHING && cache_idx_j != kCacheLocationMissing) { - Vec4T weight(&lxu_cache_weights[cache_idx_j][d]); + if ({{ "true || " if ssd else "" }} + ( + placement == PlacementType::MANAGED_CACHING + && ({{ locs_or_addrs_idx }}_j != kCacheLocationMissing) + ) + ) { + const cache_t* cache_weights = + {%- if ssd %} + reinterpret_cast( + *reinterpret_cast(&{{ locs_or_addrs_idx }}_j)); + {%- else %} + &lxu_cache_weights[{{ locs_or_addrs_idx }}_j][d]; + {%- endif %} + Vec4T weight(cache_weights); grad_indice_weight += weight.acc.x * grad_out[vec].acc.x + weight.acc.y * grad_out[vec].acc.y + weight.acc.z * grad_out[vec].acc.z + @@ -261,7 +276,7 @@ __global__ __launch_bounds__(kForwardMaxThreads) void } {%- endfor %} {# /* for use_vec_blocking */ #} -Tensor {{ ddesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda( +Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda( const Tensor& grad_output, const Tensor& dev_weights, {%- if not dense %} @@ -275,7 +290,7 @@ Tensor {{ ddesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda( const Tensor& indices, const Tensor& offsets, {%- if not dense %} - const Tensor& lxu_cache_locations, + const Tensor& {{ locs_or_addrs_tensor }}, {%- endif %} {%- if vbe %} const Tensor& feature_requires_grad, 
@@ -300,7 +315,7 @@ Tensor {{ ddesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda( indices, offsets, {%- if not dense %} - lxu_cache_locations, + {{ locs_or_addrs_tensor }}, {%- endif %} {%- if vbe %} vbe_row_output_offsets, @@ -347,7 +362,7 @@ Tensor {{ ddesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda( {%- else %} dev_weights.scalar_type(), {%- endif %} - "split_embedding_codegen_grad_indice_weights_kernel", + "split_embedding_codegen_grad_indice_weights{{ vdesc }}_kernel", [&] { {%- if vbe %} const auto& grad_output_reshaped = aligned_grad_output.reshape({1, -1}); @@ -361,7 +376,7 @@ Tensor {{ ddesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda( DISPATCH_{{ dpdesc }}VEC_BLOCKING_KERNEL(max_D, [&] { {%- set kernel_name = "{}_embedding_codegen_grad_indice_weights{}_{}kernel".format( - ddesc, vdesc, vbdesc) + mdesc, vdesc, vbdesc) %} #ifdef FBGEMM_GPU_MEMCHECK const auto func_name = @@ -388,7 +403,7 @@ Tensor {{ ddesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda( MAKE_PTA_WITH_NAME(func_name, indices, int64_t, 1, 32), MAKE_PTA_WITH_NAME(func_name, offsets, int64_t, 1, 32), {%- if not dense %} - MAKE_PTA_WITH_NAME(func_name, lxu_cache_locations, int32_t, 1, 32), + MAKE_PTA_WITH_NAME(func_name, {{ locs_or_addrs_tensor }}, {{ locs_or_addrs_type }}, 1, 32), {%- endif %} MAKE_PTA_WITH_NAME(func_name, feature_requires_grad_, int32_t, 1, 32), MAKE_PTA_ACC_WITH_NAME(func_name, grad_indice_weights, grad_t, 1, 32), @@ -412,7 +427,7 @@ Tensor {{ ddesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda( } -Tensor {{ ddesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_meta( +Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_meta( const Tensor& grad_output, const Tensor& dev_weights, {%- if not dense %} @@ -426,7 +441,7 @@ Tensor {{ ddesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_meta( const Tensor& indices, const Tensor& offsets, {%- if not dense %} - const Tensor& lxu_cache_locations, + const Tensor& {{ locs_or_addrs_tensor }}, {%- endif %} {%- if vbe %} const Tensor& feature_requires_grad, @@ -457,11 +472,12 @@ Tensor {{ ddesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_meta( //////////////////////////////////////////////////////////////////////////////// TORCH_LIBRARY_FRAGMENT(fbgemm, m) { {%- set embedding_codegen_grad_indice_weights_op = - "{}_embedding_codegen_grad_indice_weights{}_cuda".format( - ddesc, vdesc + "{}_embedding_codegen_grad_indice_weights{}".format( + mdesc, vdesc ) %} - m.def("{{ embedding_codegen_grad_indice_weights_op }}(" + {%- set embedding_codegen_grad_indice_weights_op_cuda = embedding_codegen_grad_indice_weights_op + "_cuda" %} + m.def("{{ embedding_codegen_grad_indice_weights_op_cuda }}(" " Tensor grad_output, " " Tensor dev_weights, " {%- if not dense %} @@ -475,7 +491,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { " Tensor indices, " " Tensor offsets, " {%- if not dense %} - " Tensor lxu_cache_locations, " + " Tensor {{ locs_or_addrs_tensor }}, " {%- endif %} {%- if vbe %} " Tensor feature_requires_grad, " @@ -488,12 +504,12 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { {%- endif %} ") -> Tensor"); DISPATCH_TO_CUDA( - "{{ embedding_codegen_grad_indice_weights_op }}", - {{ embedding_codegen_grad_indice_weights_op }} + "{{ embedding_codegen_grad_indice_weights_op_cuda }}", + {{ embedding_codegen_grad_indice_weights_op_cuda }} ); - m.impl("{{ embedding_codegen_grad_indice_weights_op }}", - torch::dispatch(c10::DispatchKey::Meta, TORCH_FN({{ ddesc }}_embedding_codegen_grad_indice_weights{{ 
vdesc }}_meta))); + m.impl("{{ embedding_codegen_grad_indice_weights_op_cuda }}", + torch::dispatch(c10::DispatchKey::Meta, + TORCH_FN({{ embedding_codegen_grad_indice_weights_op }}_meta))); } -{%- endif %} {#-/* if not dense or not vbe */#} {%- endfor %} {#-/* for vbe */#} // clang-format on diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_cta_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_cta_template.cu index f56bb818d..77590f06d 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_cta_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_cta_template.cu @@ -7,6 +7,7 @@ */ // clang-format off +{%- set mdesc = "ssd" if ssd else "split" %} {%- set wdesc = "weighted" if weighted else "unweighted" %} {%- set ndesc = "_nobag" if nobag else "" %} {%- set vdesc = "_vbe" if vbe else "" %} @@ -23,16 +24,23 @@ nobag, vbe, is_index_select, - has_global_weight_decay_support) %} + has_global_weight_decay_support, + ssd) %} +{%- set gwddesc = "_gwd" if is_gwd_kernel else "" %} + +{%- set desc_suffix = wdesc + vdesc + gwddesc %} + +{%- set locs_or_addrs_tensor = "ssd_row_addrs" if ssd else "lxu_cache_locations" %} +{%- set locs_or_addrs_type = "int64_t" if ssd else "int32_t" %} #include "fbgemm_gpu/embedding_backward_template_helpers.cuh" #include "fbgemm_gpu/fbgemm_tensor_accessor.h" #include "fbgemm_gpu/split_embeddings_utils.cuh" {%- if optimizer != "none" and not dense %} -#include "gen_embedding_optimizer_{{ optimizer }}_split_device_kernel.cuh" +#include "gen_embedding_optimizer_{{ optimizer }}_{{ mdesc }}_device_kernel.cuh" {%- endif %} -#include "gen_embedding_backward_{{ kdesc }}_split_device_kernel.cuh" -#include "gen_embedding_backward_common_split_device_kernel.cuh" +#include "gen_embedding_backward_split_{{ kdesc }}_device_kernel.cuh" +#include "gen_embedding_backward_split_common_device_kernel.cuh" using Tensor = at::Tensor; using namespace fbgemm_gpu; @@ -65,7 +73,6 @@ using namespace fbgemm_gpu; thus reduce the kernel occupancy, which can degrade the kernel performance. This increases the binary size, but the increase is minimal. 
*/ #} -{%- set gwddesc = "_gwd" if is_gwd_kernel else "" %} template < typename emb_t, typename grad_t, @@ -80,7 +87,7 @@ __global__ __launch_bounds__(kMaxThreads) void {%- if is_index_select %} batch_index_select_dim0_codegen_backward_kernel_cta_per_row( {%- else %} -split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}{{ gwddesc }}_kernel_cta_per_row_1( +{{ mdesc }}_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ desc_suffix }}_kernel_cta_per_row_1( {%- endif %} const pta::PackedTensorAccessor64 grad_output, {%- if optimizer != "none" %} @@ -108,7 +115,7 @@ split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc const pta::PackedTensorAccessor32 sorted_infos, {%- endif %} {%- if not dense %} - const pta::PackedTensorAccessor32 sorted_lxu_cache_locations, + const pta::PackedTensorAccessor32<{{ locs_or_addrs_type }}, 1, at::RestrictPtrTraits> sorted_{{ locs_or_addrs_tensor }}, const bool use_uniq_cache_locations, const pta::PackedTensorAccessor32 table_unique_indices_offsets, {%- endif %} @@ -341,7 +348,7 @@ split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc {{ compute_global_weight_decay(is_gwd_kernel) }} {%- if not dense and optimizer != "none" %} - split_{{ optimizer }}_table_update_kernel< + {{ mdesc }}_{{ optimizer }}_table_update_kernel< emb_t, cache_t, {%- for ph_name in args.placeholder_tensor_names %} @@ -356,7 +363,7 @@ split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc lxu_cache_weights, weights_placements, weights_offsets, - sorted_lxu_cache_locations, + sorted_{{ locs_or_addrs_tensor }}, grad_sum, kUseVecBlocking ? smem_grad_sum : nullptr, kIsInt8 ? smem_grad_sum : nullptr, @@ -433,7 +440,7 @@ template __global__ __launch_bounds__(kMaxThreads) void {%- if is_index_select %} batch_index_select_dim0_codegen_backward_kernel_cta_per_row {%- else %} -split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}{{ gwddesc }}_kernel_cta_per_row_1 +{{ mdesc }}_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ desc_suffix }}_kernel_cta_per_row_1 {%- endif %} < {{ emb_type }}, {{ grad_type }}, @@ -472,8 +479,8 @@ split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc const pta::PackedTensorAccessor32 sorted_infos, {%- endif %} {%- if not dense %} - const pta::PackedTensorAccessor32 - sorted_lxu_cache_locations, + const pta::PackedTensorAccessor32<{{ locs_or_addrs_type }}, 1, at::RestrictPtrTraits> + sorted_{{ locs_or_addrs_tensor }}, const bool use_uniq_cache_locations, const pta::PackedTensorAccessor32 table_unique_indices_offsets, {%- endif %} diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu index e91296e58..9ff187dd8 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu @@ -7,9 +7,11 @@ */ // clang-format off +{%- set mdesc = "ssd" if ssd else "split" %} {%- set wdesc = "weighted" if weighted else "unweighted" %} {%- set ndesc = "_nobag" if nobag else "" %} {%- set vdesc = "_vbe" if vbe else "" %} + {# /* `has_global_weight_decay_support` tells whether the optimizer has support for global weight decay (gwd) @@ -22,16 +24,23 @@ nobag, vbe, is_index_select, - has_global_weight_decay_support) %} + has_global_weight_decay_support, + ssd) %} +{%- 
set gwddesc = "_gwd" if is_gwd_kernel else "" %} + +{%- set desc_suffix = wdesc + vdesc + gwddesc %} + +{%- set locs_or_addrs_tensor = "ssd_row_addrs" if ssd else "lxu_cache_locations" %} +{%- set locs_or_addrs_type = "int64_t" if ssd else "int32_t" %} #include "fbgemm_gpu/embedding_backward_template_helpers.cuh" #include "fbgemm_gpu/fbgemm_tensor_accessor.h" #include "fbgemm_gpu/split_embeddings_utils.cuh" {%- if optimizer != "none" and not dense %} -#include "gen_embedding_optimizer_{{ optimizer }}_split_device_kernel.cuh" +#include "gen_embedding_optimizer_{{ optimizer }}_{{ mdesc }}_device_kernel.cuh" {%- endif %} -#include "gen_embedding_backward_{{ kdesc }}_split_device_kernel.cuh" -#include "gen_embedding_backward_common_split_device_kernel.cuh" +#include "gen_embedding_backward_split_{{ kdesc }}_device_kernel.cuh" +#include "gen_embedding_backward_split_common_device_kernel.cuh" using Tensor = at::Tensor; using namespace fbgemm_gpu; @@ -63,7 +72,7 @@ __global__ __launch_bounds__(kBackwardMaxThreads) void {%- if is_index_select %} batch_index_select_dim0_codegen_backward_kernel_warp_per_row( {%- else %} -split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}{{ gwddesc }}_kernel_warp_per_row_1( +{{ mdesc }}_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ desc_suffix }}_kernel_warp_per_row_1( {%- endif %} const pta::PackedTensorAccessor64 grad_output, {%- if optimizer != "none" %} @@ -89,7 +98,7 @@ split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc const pta::PackedTensorAccessor32 sorted_infos, {%- endif %} {%- if not dense %} - const pta::PackedTensorAccessor32 sorted_lxu_cache_locations, + const pta::PackedTensorAccessor32<{{ locs_or_addrs_type }}, 1, at::RestrictPtrTraits> sorted_{{ locs_or_addrs_tensor }}, const bool use_uniq_cache_locations, const pta::PackedTensorAccessor32 table_unique_indices_offsets, {%- endif %} @@ -249,7 +258,7 @@ split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc kUseVecBlocking ? 
max_vecs_per_thread : kFixedMaxVecsPerThread; {%- if not dense and optimizer != "none" %} - split_{{ optimizer }}_table_update_kernel< + {{ mdesc }}_{{ optimizer }}_table_update_kernel< emb_t, cache_t, {%- for ph_name in args.placeholder_tensor_names %} @@ -264,7 +273,7 @@ split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc lxu_cache_weights, weights_placements, weights_offsets, - sorted_lxu_cache_locations, + sorted_{{ locs_or_addrs_tensor }}, grad_sum, smem_grad_sum, smem_grad_sum, // shared_weight_update_row (reuse smem_grad_sum) @@ -343,7 +352,7 @@ template __global__ __launch_bounds__(kBackwardMaxThreads) void {%- if is_index_select %} batch_index_select_dim0_codegen_backward_kernel_warp_per_row {%- else %} -split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}{{ gwddesc }}_kernel_warp_per_row_1 +{{ mdesc }}_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ desc_suffix }}_kernel_warp_per_row_1 {%- endif %} < {{ emb_type }}, {{ grad_type }}, @@ -379,7 +388,7 @@ split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc const pta::PackedTensorAccessor32 sorted_infos, {%- endif %} {%- if not dense %} - const pta::PackedTensorAccessor32 sorted_lxu_cache_locations, + const pta::PackedTensorAccessor32<{{ locs_or_addrs_type }}, 1, at::RestrictPtrTraits> sorted_{{ locs_or_addrs_tensor }}, const bool use_uniq_cache_locations, const pta::PackedTensorAccessor32 table_unique_indices_offsets, {%- endif %} diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_meta_template.cpp b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_meta_template.cpp index 3aded4a1e..28eb82ca5 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_meta_template.cpp +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_meta_template.cpp @@ -17,10 +17,13 @@ // Companion template is embedding_backward_split_template.cu +{%- set mdesc = "ssd" if ssd else "split" %} {%- set wdesc = "weighted" if weighted else "unweighted" %} {%- set vdesc = "_vbe" if vbe else "" %} {%- set ndesc = "_nobag" if nobag else "" %} +{%- set locs_or_addrs_tensor = "ssd_row_addrs" if ssd else "lxu_cache_locations" %} + //////////////////////////////////////////////////////////////////////////////// // Required for op registrations #include "fbgemm_gpu/embedding_op_registration.h" @@ -42,14 +45,15 @@ using Tensor = at::Tensor; nobag, vbe, is_index_select, - has_global_weight_decay_support + has_global_weight_decay_support, + ssd=False ) else [False]) %} {%- set gwddesc = "_gwd" if is_gwd else "" %} {%- if is_index_select %} Tensor batch_index_select_dim0_codegen_backward_meta( {%- else %} -Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_exact{{ vdesc }}{{ gwddesc }}_meta( +Tensor {{ mdesc }}_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_exact{{ vdesc }}{{ gwddesc }}_meta( {%- endif %} const Tensor& grad_output, const Tensor& dev_weights, @@ -78,7 +82,7 @@ Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_e const Tensor& indice_weights, {%- endif %} {%- if not dense %} - const Tensor& lxu_cache_locations, + const Tensor& {{ locs_or_addrs_tensor }}, {%- endif %} {%- if not is_index_select %} const int64_t unused_, diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index a31bc8c39..b1fa13294 100644 --- 
a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -8,8 +8,17 @@ // clang-format off {%- set wdesc = "weighted" if weighted else "unweighted" %} -{%- set vdesc = "_vbe" if vbe else "" %} {%- set ndesc = "_nobag" if nobag else "" %} +{%- set vdesc = "_vbe" if vbe else "" %} +{%- set mdesc = "ssd" if ssd else "split" %} + +{%- macro get_desc_suffix(gwd) %} +{%- set gwddesc = "_gwd" if gwd else "" %} +{{- wdesc + vdesc + gwddesc }} +{%- endmacro %} + +{%- set locs_or_addrs_tensor = "ssd_row_addrs" if ssd else "lxu_cache_locations" %} +{%- set locs_or_addrs_type = "int64_t" if ssd else "int32_t" %} {%- if not is_index_select %} //////////////////////////////////////////////////////////////////////////////// @@ -35,8 +44,9 @@ using namespace fbgemm_gpu; nobag, vbe, is_index_select, - has_global_weight_decay_support) %} -{%- set gwddesc = "_gwd" if is_gwd_kernel else "" %} + has_global_weight_decay_support, + ssd) %} +{%- set desc_suffix = get_desc_suffix(is_gwd_kernel) %} template < typename emb_t, typename grad_t, @@ -51,7 +61,7 @@ __global__ __launch_bounds__(kMaxThreads) void {%- if is_index_select %} batch_index_select_dim0_codegen_backward_kernel_cta_per_row( {%- else %} -split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}{{ gwddesc }}_kernel_cta_per_row_1( +{{ mdesc }}_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ desc_suffix }}_kernel_cta_per_row_1( {%- endif %} const pta::PackedTensorAccessor64 grad_output, {%- if optimizer != "none" %} @@ -79,7 +89,7 @@ split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc const pta::PackedTensorAccessor32 sorted_infos, {%- endif %} {%- if not dense %} - const pta::PackedTensorAccessor32 sorted_lxu_cache_locations, + const pta::PackedTensorAccessor32<{{ locs_or_addrs_type }}, 1, at::RestrictPtrTraits> sorted_{{ locs_or_addrs_tensor }}, const bool use_uniq_cache_locations, const pta::PackedTensorAccessor32 table_unique_indices_offsets, {%- endif %} @@ -139,7 +149,7 @@ __global__ __launch_bounds__(kBackwardMaxThreads) void {%- if is_index_select %} batch_index_select_dim0_codegen_backward_kernel_warp_per_row( {%- else %} -split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}{{ gwddesc }}_kernel_warp_per_row_1( +{{ mdesc }}_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ desc_suffix }}_kernel_warp_per_row_1( {%- endif %} const pta::PackedTensorAccessor64 grad_output, {%- if optimizer != "none" %} @@ -165,7 +175,7 @@ split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc const pta::PackedTensorAccessor32 sorted_infos, {%- endif %} {%- if not dense %} - const pta::PackedTensorAccessor32 sorted_lxu_cache_locations, + const pta::PackedTensorAccessor32<{{ locs_or_addrs_type }}, 1, at::RestrictPtrTraits> sorted_{{ locs_or_addrs_tensor }}, const bool use_uniq_cache_locations, const pta::PackedTensorAccessor32 table_unique_indices_offsets, {%- endif %} @@ -417,22 +427,21 @@ int32_t compute_num_groups_and_dynamic_smem_bytes( nobag, vbe, is_index_select, - has_global_weight_decay_support) %} -{%- set gwddesc = "_gwd" if is_gwd_kernel else "" %} - -{%- set func_name0 = "split_embedding{}_backward_codegen_{}_{}_exact{}{}_cuda".format( + has_global_weight_decay_support, + ssd) %} +{%- set desc_suffix = get_desc_suffix(is_gwd_kernel) %} + +{%- set embedding_cuda_op = + "batch_index_select_dim0_codegen_backward_cuda" + if 
is_index_select + else "{}_embedding{}_backward_codegen_{}_{}_exact_cuda".format( + mdesc, ndesc, optimizer, - wdesc, - vdesc, - gwddesc) + desc_suffix) %} -{%- if is_index_select %} -Tensor batch_index_select_dim0_codegen_backward_cuda( -{%- else %} -Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_exact{{ vdesc }}{{ gwddesc }}_cuda( -{%- endif %} +Tensor {{ embedding_cuda_op }}( const Tensor& grad_output, const Tensor& dev_weights, {%- if not dense %} @@ -460,7 +469,7 @@ Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_e const Tensor& indice_weights, {%- endif %} {%- if not dense %} - const Tensor& lxu_cache_locations, + const Tensor& {{ locs_or_addrs_tensor }}, {%- endif %} {%- if not is_index_select %} const int64_t unused_, @@ -538,7 +547,7 @@ Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_e indice_weights, {%- endif %} {%- if not dense %} - lxu_cache_locations, + {{ locs_or_addrs_tensor }}, {%- endif %} {%- if is_gwd_kernel %} prev_iter_dev, @@ -676,24 +685,24 @@ Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_e ); {%- if not dense %} - Tensor lxu_cache_locations_sorted = lxu_cache_locations; + Tensor {{ locs_or_addrs_tensor }}_sorted = {{ locs_or_addrs_tensor }}; Tensor table_unique_indices_offsets; - if (lxu_cache_locations.size(0) > 0) { + if ({{ locs_or_addrs_tensor }}.size(0) > 0) { if (use_uniq_cache_locations) { if (!use_homogeneous_placements) { - // When use_uniq_cache_locations=true, lxu_cache_locations are unique + // When use_uniq_cache_locations=true, {{ locs_or_addrs_tensor }} are unique // and sorted in an ascending order based on the linear cache indices. // Linear cache indices of tables that are not placed in cache are set // to a sentinel value (i.e., the sum of hash sizes of all embedding // tables). Since the sentinel value is larger than the max linear - // cache index value, the lxu_cache_locations can be sorted differently + // cache index value, the {{ locs_or_addrs_tensor }} can be sorted differently // than the sorted_linear_indices. // // For this reason, the run ids of sorted and unique - // lxu_cache_locations can be different from those of the + // {{ locs_or_addrs_tensor }} can be different from those of the // sorted_linear_indices. We need the following code to compute // table_unique_indices_offsets which contains the differences between - // lxu_cache_locations run ids and sorted_linear_indices run ids. + // {{ locs_or_addrs_tensor }} run ids and sorted_linear_indices run ids. 
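An aside on the terminology the comment above leans on: a "run id" is just the index of a segment of consecutive equal values in a sorted tensor, and table_unique_indices_offsets stores the per-table difference between the run ids of the two sort orders. A toy, CUDA-free illustration of the term only (hypothetical values; this is not the dev_or_uvm_unique_indices computation that follows):

import torch

# Toy illustration: "run ids" index the segments of consecutive equal values
# in a sorted tensor; the counts are the lengths of those segments.
sorted_linear_indices = torch.tensor([2, 2, 5, 9, 9, 9])
_, run_ids, run_lengths = torch.unique_consecutive(
    sorted_linear_indices, return_inverse=True, return_counts=True
)
print(run_ids)      # tensor([0, 0, 1, 2, 2, 2])
print(run_lengths)  # tensor([2, 1, 3])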
auto dev_or_uvm_unique_indices = at::zeros_like(weights_placements); #ifdef FBGEMM_GPU_MEMCHECK @@ -728,15 +737,15 @@ Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_e } } else { - lxu_cache_locations_sorted = at::empty_like(lxu_cache_locations); + {{ locs_or_addrs_tensor }}_sorted = at::empty_like({{ locs_or_addrs_tensor }}); size_t temp_storage_bytes = 0; AT_CUDA_CHECK(radix_sort_pairs( nullptr, temp_storage_bytes, linear_indices.data_ptr(), linear_indices_sorted.data_ptr(), - lxu_cache_locations.data_ptr(), - lxu_cache_locations_sorted.data_ptr(), + {{ locs_or_addrs_tensor }}.data_ptr<{{ locs_or_addrs_type }}>(), + {{ locs_or_addrs_tensor }}_sorted.data_ptr<{{ locs_or_addrs_type }}>(), linear_indices.numel(), 0, total_hash_size_bits, @@ -749,8 +758,8 @@ Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_e temp_storage_bytes, linear_indices.data_ptr(), linear_indices_sorted.data_ptr(), - lxu_cache_locations.data_ptr(), - lxu_cache_locations_sorted.data_ptr(), + {{ locs_or_addrs_tensor }}.data_ptr<{{ locs_or_addrs_type }}>(), + {{ locs_or_addrs_tensor }}_sorted.data_ptr<{{ locs_or_addrs_type }}>(), linear_indices.numel(), 0, total_hash_size_bits, @@ -758,7 +767,7 @@ Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_e } } - if (lxu_cache_locations.size(0) == 0 || !use_uniq_cache_locations || use_homogeneous_placements) { + if ({{ locs_or_addrs_tensor }}.size(0) == 0 || !use_uniq_cache_locations || use_homogeneous_placements) { table_unique_indices_offsets = at::zeros_like(weights_placements); } {%- endif %} @@ -771,7 +780,7 @@ Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_e {%- else %} dev_weights.scalar_type(), {%- endif %} - "split_embedding_backward_{{ optimizer }}_exact_kernel", + "{{ embedding_cuda_op }}", [&] { {%- if weighted %} auto indice_weights_sorted = at::empty_like(indice_weights); @@ -816,7 +825,7 @@ Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_e {%- endif %} auto grad_output_accessor = MAKE_PTA_WITH_NAME( - "{{ func_name0 }}.1", + "{{ embedding_cuda_op }}.1", grad_output_reshaped, grad_t, {{ "1" if is_index_select else "2" }}, 64 @@ -855,7 +864,7 @@ Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_e C10_CUDA_KERNEL_LAUNCH_CHECK(); {%- endif %} // if not dense or not vbe - grad_output_accessor = MAKE_PTA_WITH_NAME("{{ func_name0 }}.2", grad_output_mean, grad_t, 2, 64); + grad_output_accessor = MAKE_PTA_WITH_NAME("{{ embedding_cuda_op }}.2", grad_output_mean, grad_t, 2, 64); } {%- endif %} @@ -924,18 +933,17 @@ Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_e {%- for ph_name in args.placeholder_tensor_names %} {{ ph_name + "_dev" }}.scalar_type(), {%- endfor %} - "split_embedding_backward_{{ optimizer }}_exact_placeholder_type_kernel", + "{{ mdesc }}_embedding_backward_{{ optimizer }}_exact_placeholder_type_kernel", [&] { {%- set cta_kernel = "batch_index_select_dim0_codegen_backward_kernel_cta_per_row" if is_index_select else - "split_embedding{}_backward_codegen_{}_{}{}{}_kernel_cta_per_row_1".format( + "{}_embedding{}_backward_codegen_{}_{}_kernel_cta_per_row_1".format( + mdesc, ndesc, optimizer, - wdesc, - vdesc, - gwddesc + desc_suffix, ) %} @@ -1003,7 +1011,7 @@ Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_e MAKE_PTA_WITH_NAME(func_name3, infos_sorted, int64_t, 1, 32), {%- endif %} {%- if not dense %} - 
MAKE_PTA_WITH_NAME(func_name3, lxu_cache_locations_sorted, int32_t, 1, 32), + MAKE_PTA_WITH_NAME(func_name3, {{ locs_or_addrs_tensor }}_sorted, {{ locs_or_addrs_type }}, 1, 32), use_uniq_cache_locations, MAKE_PTA_WITH_NAME(func_name3, table_unique_indices_offsets, int32_t, 1, 32), {%- endif %} @@ -1051,12 +1059,11 @@ Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_e {%- set warp_kernel = "batch_index_select_dim0_codegen_backward_kernel_warp_per_row" if is_index_select else - "split_embedding{}_backward_codegen_{}_{}{}{}_kernel_warp_per_row_1".format( + "{}_embedding{}_backward_codegen_{}_{}_kernel_warp_per_row_1".format( + mdesc, ndesc, optimizer, - wdesc, - vdesc, - gwddesc + desc_suffix, ) %} const auto backward_warp_per_row_kernel = @@ -1125,7 +1132,7 @@ Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_e MAKE_PTA_WITH_NAME(func_name4, infos_sorted, int64_t, 1, 32), {%- endif %} {%- if not dense %} - MAKE_PTA_WITH_NAME(func_name4, lxu_cache_locations_sorted, int32_t, 1, 32), + MAKE_PTA_WITH_NAME(func_name4, {{ locs_or_addrs_tensor }}_sorted, {{ locs_or_addrs_type }}, 1, 32), use_uniq_cache_locations, MAKE_PTA_WITH_NAME(func_name4, table_unique_indices_offsets, int32_t, 1, 32), {%- endif %} @@ -1191,8 +1198,8 @@ Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_e {%- if not is_index_select %} TORCH_LIBRARY_FRAGMENT(fbgemm, m) { {%- set embedding_codegen_backward_op = - "split_embedding{}_backward_codegen_{}_{}_exact{}{}_cuda".format( - ndesc, optimizer, wdesc, vdesc, gwddesc + "{}_embedding{}_backward_codegen_{}_{}_exact_cuda".format( + mdesc, ndesc, optimizer, desc_suffix ) %} m.def("{{ embedding_codegen_backward_op }}(" @@ -1223,7 +1230,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { " Tensor indice_weights, " {%- endif %} {%- if not dense %} - " Tensor lxu_cache_locations, " + " Tensor {{ locs_or_addrs_tensor }}, " {%- endif %} {%- if not is_index_select %} " int unused_, " @@ -1261,4 +1268,4 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { ); } {%- endif %} {#-/* if not is_index_select */#} - // clang-format on +// clang-format on diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_nobag_small_template.cu b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_nobag_small_template.cu index c1b876a0b..d818d3b25 100644 --- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_nobag_small_template.cu +++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_nobag_small_template.cu @@ -15,6 +15,11 @@ // See https://fburl.com/dw9ljh4h #} +{%- set mdesc = "dense" if dense else ("ssd" if ssd else "split") %} +{%- set locs_or_addrs_tensor = "ssd_row_addrs" if ssd else "lxu_cache_locations" %} +{%- set locs_or_addrs_type = "int64_t" if ssd else "int32_t" %} +{%- set locs_or_addrs_idx = "row_idx" if ssd else "cache_idx" %} + #include "fbgemm_gpu/embedding_forward_template_helpers.cuh" using Tensor = at::Tensor; @@ -31,7 +36,7 @@ __launch_bounds__(kForwardMaxThreads) __global__ void {%- if is_index_select %} batch_index_select_dim0_codegen_forward_small_kernel( {%- else %} -{{ "dense" if dense else "split" }}_embedding_nobag_codegen_forward_unweighted_small_kernel( +{{ mdesc }}_embedding_nobag_codegen_forward_unweighted_small_kernel( {%- endif %} const pta::PackedTensorAccessor64 dev_weights, {%- if not dense %} @@ -51,7 +56,7 @@ batch_index_select_dim0_codegen_forward_small_kernel( const pta::PackedTensorAccessor32 offsets, {%- endif %} {%- if not dense 
%} - const pta::PackedTensorAccessor32 lxu_cache_locations, + const pta::PackedTensorAccessor32<{{ locs_or_addrs_type }}, 1, at::RestrictPtrTraits> {{ locs_or_addrs_tensor }}, {%- endif %} {%- if is_index_select %} const at::PackedTensorAccessor32 output_offsets, @@ -112,7 +117,7 @@ batch_index_select_dim0_codegen_forward_small_kernel( if (placement == PlacementType::DEVICE) { weights = &dev_weights[weights_offset]; } else { - weights = &uvm_weights[weights_offset]; + weights = {{ "nullptr" if ssd else "&uvm_weights[weights_offset]" }}; } {%- else %} weights = &dev_weights[weights_offset]; @@ -131,7 +136,9 @@ batch_index_select_dim0_codegen_forward_small_kernel( int32_t l = l_start + threadIdx.x; int64_t idx = l < L ? indices[indices_start + l] : 0; {%- if not dense %} - int32_t cache_idx = (placement == PlacementType::MANAGED_CACHING && l < L) ? lxu_cache_locations[indices_start + l] : 0; + const {{ locs_or_addrs_type }} {{ locs_or_addrs_idx }} = + (placement == PlacementType::MANAGED_CACHING && l < L) + ? {{ locs_or_addrs_tensor }}[indices_start + l] : 0; {%- endif %} for (auto j = group_start; j < group_end && l_start + j < L; ++j) { int64_t idx_j = shfl_sync(idx, j); @@ -141,7 +148,8 @@ batch_index_select_dim0_codegen_forward_small_kernel( int64_t output_j = indices_start + l_start + j; {%- endif %} {%- if not dense %} - int32_t cache_idx_j = shfl_sync(cache_idx, j); + const {{ locs_or_addrs_type }} {{ locs_or_addrs_idx }}_j = + shfl_sync({{ locs_or_addrs_idx }}, j); {%- endif %} {%- if not dense %} @@ -150,10 +158,13 @@ batch_index_select_dim0_codegen_forward_small_kernel( float2 qparams_cache = make_float2(0.0f, 0.0f); {%- endif %} - auto weight_row_emb = WeightRow( - const_cast(&weights[idx_j * D_emb]), + auto weight_row_emb = WeightRowAccessor< + emb_t, cache_t, cache_t, false + >( + &weights[idx_j * D_emb], nullptr, - D); + D + ); [[maybe_unused]] float2 qparams_emb; if (std::is_same::value) { qparams_emb = weight_row_emb.load_qparams(); @@ -161,11 +172,24 @@ batch_index_select_dim0_codegen_forward_small_kernel( if (d < D) { {%- if not dense %} - if (placement == PlacementType::MANAGED_CACHING && cache_idx_j != kCacheLocationMissing) { - auto weight_row_cache = WeightRow( - const_cast(&weights[idx_j * D_emb]), - const_cast(&lxu_cache_weights[cache_idx_j][0]), - D); + if (placement == PlacementType::MANAGED_CACHING && + {{ locs_or_addrs_idx }}_j != kCacheLocationMissing) { + const cache_t* cache_weights; + {%- if ssd %} + cache_weights = reinterpret_cast( + *reinterpret_cast(&{{ locs_or_addrs_idx }}_j)); + {%- else %} + cache_weights = reinterpret_cast( + &lxu_cache_weights[{{ locs_or_addrs_idx }}_j][0]); + {%- endif %} + + auto weight_row_cache = WeightRowAccessor< + emb_t, cache_t, cache_t, true + >( + &weights[idx_j * D_emb], + cache_weights, + D + ); Vec4T weight = weight_row_cache.load(d, qparams_cache); weight.store(&output[output_j][d]); } else { @@ -203,7 +227,7 @@ template __launch_bounds__(kForwardMaxThreads) __global__ void {%- if is_index_select %} batch_index_select_dim0_codegen_forward_small_kernel {%- else %} -{{ "dense" if dense else "split" }}_embedding_nobag_codegen_forward_unweighted_small_kernel +{{ mdesc }}_embedding_nobag_codegen_forward_unweighted_small_kernel {%- endif %} < {{ emb_type }}, @@ -230,7 +254,7 @@ batch_index_select_dim0_codegen_forward_small_kernel const pta::PackedTensorAccessor32<{{ index_type }}, 1, at::RestrictPtrTraits> offsets, {%- endif %} {%- if not dense %} - const pta::PackedTensorAccessor32 lxu_cache_locations, + const 
pta::PackedTensorAccessor32<{{ locs_or_addrs_type }}, 1, at::RestrictPtrTraits> {{ locs_or_addrs_tensor }}, {%- endif %} {%- if is_index_select %} const pta::PackedTensorAccessor32 output_offsets, diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu index d7ce2c741..b80e04c01 100644 --- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu +++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu @@ -15,7 +15,7 @@ // See https://fburl.com/dw9ljh4h #} -{%- set ddesc = "dense" if dense else "split" %} +{%- set mdesc = "dense" if dense else ("ssd" if ssd else "split") %} {%- set wdesc = "weighted" if weighted else "unweighted" %} {%- set ndesc = "_nobag" if nobag else "" %} {%- set vdesc = "_vbe" if vbe else "" %} @@ -23,7 +23,16 @@ dense, nobag, vbe, - is_index_select) %} + is_index_select, + has_global_weight_decay_support=True, + ssd=ssd) %} +{%- set gwddesc = "_gwd" if is_gwd_kernel else "" %} +{%- set desc_suffix = wdesc + vdesc + gwddesc %} + +{%- set locs_or_addrs_tensor = "ssd_row_addrs" if ssd else "lxu_cache_locations" %} +{%- set locs_or_addrs_type = "int64_t" if ssd else "int32_t" %} +{%- set locs_or_addrs_idx = "row_idx" if ssd else "cache_idx" %} + #include "fbgemm_gpu/embedding_forward_template_helpers.cuh" #include "fbgemm_gpu/split_embeddings_cache_cuda.cuh" @@ -50,7 +59,7 @@ using namespace fbgemm_gpu; idx_j D_emb lxu_cache_weights - cache_idx_j + {{ locs_or_addrs_idx }}_j idx_weight_j VEC_WIDTH D @@ -58,6 +67,16 @@ using namespace fbgemm_gpu; output_j */#} {%- macro load_and_accumulate(from_cache) %} + {%- if from_cache %} + const cache_t* cache_weights; + {%- if ssd %} + cache_weights = reinterpret_cast( + *reinterpret_cast(&{{ locs_or_addrs_idx }}_j)); + {%- else %} + cache_weights = reinterpret_cast( + &lxu_cache_weights[{{ locs_or_addrs_idx }}_j][0]); + {%- endif %} + {%- endif %} {#-/* Set the weights row */#} const auto weights_row = WeightRowAccessor < @@ -75,15 +94,14 @@ using namespace fbgemm_gpu; // memory into the registers as a side effect nullptr, // Load from the cache - const_cast(&lxu_cache_weights[cache_idx_j][0]), + cache_weights, {%- else %} // Load from the embedding table - const_cast(&weights[idx_j * D_emb]), + &weights[idx_j * D_emb], // Pass nullptr bc we are loading from the embedding table nullptr, {%- endif %} - D, - nullptr); + D); {#-/* Set the quantization params */#} {%- if from_cache %} @@ -160,7 +178,7 @@ using namespace fbgemm_gpu; {%- if not dense and lxu_miss_rate != "cache_conflict_miss_rate::all" %} // Cooperatively load the cache's indices - [[maybe_unused]] int32_t cache_idx = (use_lxu_cache && placement == PlacementType::MANAGED_CACHING && l < L) ? lxu_cache_locations[indices_start + l] : 0; + [[maybe_unused]] {{ locs_or_addrs_type }} {{ locs_or_addrs_idx }} = (use_lxu_cache && placement == PlacementType::MANAGED_CACHING && l < L) ? {{ locs_or_addrs_tensor }}[indices_start + l] : 0; {%- endif %} {%- if lxu_miss_rate == "cache_conflict_miss_rate::zero" and is_gwd_kernel %} int64_t idx = l < L ? indices[indices_start + l] : 0; // only used for accessing prev_iter @@ -191,7 +209,8 @@ using namespace fbgemm_gpu; {%- if not dense and lxu_miss_rate != "cache_conflict_miss_rate::all" %} // Load cache's index from thread j in the group - [[maybe_unused]] int32_t cache_idx_j = use_lxu_cache ? 
SHFL_SYNC(cache_idx, j) : 0; + [[maybe_unused]] {{ locs_or_addrs_type }} {{ locs_or_addrs_idx }}_j + = use_lxu_cache ? SHFL_SYNC({{ locs_or_addrs_idx }}, j) : 0; {%- endif %} {%- if weighted %} @@ -223,7 +242,9 @@ using namespace fbgemm_gpu; {{ load_and_accumulate(true) }} {%- else %} {#-/* Else we defer to run-time selection */#} - if (placement == PlacementType::MANAGED_CACHING && cache_idx_j != kCacheLocationMissing) { + if (placement == PlacementType::MANAGED_CACHING + && {{ locs_or_addrs_idx }}_j != kCacheLocationMissing + ) { {#-/* If the row is available in the cache, fetch from the cache */#} {{ load_and_accumulate(true) }} } else { @@ -246,7 +267,6 @@ using namespace fbgemm_gpu; with global_weight_decay, otherwise, only generate regular kernel. */ #} -{%- set gwddesc = "_gwd" if is_gwd_kernel else "" %} template < typename emb_t, typename cache_t, @@ -258,12 +278,12 @@ template < {%- if not nobag %} size_t kMaxVecsPerThread, {%- endif %} - size_t kThreadGroupSize > + size_t kThreadGroupSize> __launch_bounds__(kForwardMaxThreads) __global__ void {%- if is_index_select %} batch_index_select_dim0_codegen_forward_kernel( {%- else %} -{{ ddesc }}_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ vdesc }}{{ gwddesc }}_kernel( +{{ mdesc }}_embedding{{ ndesc }}_codegen_forward_{{ desc_suffix }}_kernel( {%- endif %} const pta::PackedTensorAccessor64 dev_weights, {%- if not dense %} @@ -296,7 +316,7 @@ batch_index_select_dim0_codegen_forward_kernel( pta::PackedTensorAccessor32, 1, at::RestrictPtrTraits> indice_weights, {%- endif %} {%- if not dense %} - const pta::PackedTensorAccessor32 lxu_cache_locations, + const pta::PackedTensorAccessor32<{{ locs_or_addrs_type }}, 1, at::RestrictPtrTraits> {{ locs_or_addrs_tensor }}, /* NOTE: We pass in `lxu_cache_conflict_misses = uvm_cache_stats[uvm_cache_stats_index::num_conflict_unique_misses]` as a @@ -533,13 +553,19 @@ batch_index_select_dim0_codegen_forward_kernel( embedding_forward_split_template.cu */ -{%- macro template_instantiation(emb_type, cache_type, output_type, use_cache, kMaxVecsPerThread, kThreadGroupSize) %} -{%- set gwddesc = "_gwd" if is_gwd_kernel else "" %} +{%- macro template_instantiation( + emb_type, + cache_type, + output_type, + use_cache, + kMaxVecsPerThread, + kThreadGroupSize) +%} template __launch_bounds__(kForwardMaxThreads) __global__ void {%- if is_index_select %} batch_index_select_dim0_codegen_forward_kernel {%- else %} -{{ ddesc }}_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ vdesc }}{{ gwddesc }}_kernel +{{ mdesc }}_embedding{{ ndesc }}_codegen_forward_{{ desc_suffix }}_kernel {%- endif %} < {{ emb_type }}, @@ -585,7 +611,7 @@ batch_index_select_dim0_codegen_forward_kernel pta::PackedTensorAccessor32, 1, at::RestrictPtrTraits> indice_weights, {%- endif %} {%- if not dense %} - const pta::PackedTensorAccessor32 lxu_cache_locations, + const pta::PackedTensorAccessor32<{{ locs_or_addrs_type }}, 1, at::RestrictPtrTraits> {{ locs_or_addrs_tensor }}, const int32_t* lxu_cache_conflict_misses, {%- endif %} {%- if is_index_select %} @@ -608,7 +634,14 @@ batch_index_select_dim0_codegen_forward_kernel {%- for emb_type in ['float', 'at::Half'] %} {%- for cache_type in ['float', 'at::Half'] %} {%- for output_type in ['float', 'at::Half', 'at::BFloat16'] %} - {{ template_instantiation(emb_type, cache_type, output_type, use_cache, kMaxVecsPerThread, kThreadGroupSize) }} + {{ template_instantiation( + emb_type, + cache_type, + output_type, + use_cache, + kMaxVecsPerThread, + kThreadGroupSize) + }} {%- endfor %} {%- endfor 
%} {%- endfor %} @@ -616,7 +649,7 @@ batch_index_select_dim0_codegen_forward_kernel {%- macro instantiate_templates(use_subwarp_shuffle) %} {%- set has_experimental = - has_experimental_support(dense, nobag, vbe, is_index_select, is_rocm) + has_experimental_support(dense, nobag, vbe, is_index_select, is_rocm, ssd) %} {%- set max_forward_embedding_dim = legacy_max_embedding_dim if has_experimental else max_embedding_dim diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_meta_template.cpp b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_meta_template.cpp index 15611fb0c..1e0860d00 100644 --- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_meta_template.cpp +++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_meta_template.cpp @@ -17,10 +17,12 @@ // Companion template is embedding_forward_split_template.cu -{%- set ddesc = "dense" if dense else "split" %} +{%- set mdesc = "dense" if dense else ("ssd" if ssd else "split") %} {%- set wdesc = "weighted" if weighted else "unweighted" %} {%- set vdesc = "_vbe" if vbe else "" %} +{%- set locs_or_addrs_tensor = "ssd_row_addrs" if ssd else "lxu_cache_locations" %} + //////////////////////////////////////////////////////////////////////////////// // Required for op registrations #include "fbgemm_gpu/embedding_op_registration.h" @@ -47,12 +49,14 @@ static constexpr float kINT8QparamsBytes = 8; dense, nobag, vbe, - is_index_select + is_index_select, + has_global_weight_decay_support=True, + ssd=False, ) else [False]) %} {%- set gwddesc = "_gwd" if is_gwd else "" %} Tensor -{{ ddesc }}_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ vdesc }}{{ gwddesc }}_meta( +{{ mdesc }}_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ vdesc }}{{ gwddesc }}_meta( const Tensor& dev_weights, {%- if not dense %} const Tensor& uvm_weights, @@ -80,7 +84,7 @@ Tensor const Tensor& indice_weights, {%- endif %} {%- if not dense %} - const Tensor& lxu_cache_locations, + const Tensor& {{ locs_or_addrs_tensor }}, const Tensor& uvm_cache_stats, {%- endif %} const int64_t output_dtype, @@ -200,10 +204,10 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { // NB: yes cuda here {%- set embedding_codegen_forward_op = "{}_embedding{}_codegen_forward_{}{}{}_cuda".format( - ddesc, ndesc, wdesc, vdesc, gwddesc + mdesc, ndesc, wdesc, vdesc, gwddesc ) %} - m.impl("{{ embedding_codegen_forward_op }}", torch::dispatch(c10::DispatchKey::Meta, TORCH_FN({{ ddesc }}_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ vdesc }}{{ gwddesc }}_meta))); + m.impl("{{ embedding_codegen_forward_op }}", torch::dispatch(c10::DispatchKey::Meta, TORCH_FN({{ mdesc }}_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ vdesc }}{{ gwddesc }}_meta))); } {%- endfor %} {#-/* for is_gwd */#} {%- endif %} {#/* if (not nobag or (not weighted and not vbe)) */#} diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu index 42a59272d..3205d34a2 100644 --- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu +++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu @@ -15,9 +15,17 @@ // See https://fburl.com/dw9ljh4h #} -{%- set ddesc = "dense" if dense else "split" %} +{%- set mdesc = "dense" if dense else ("ssd" if ssd else "split") %} {%- set wdesc = "weighted" if weighted else "unweighted" %} + +{%- macro get_desc_suffix(gwd) %} {%- set vdesc = "_vbe" if vbe else "" %} +{%- set gwddesc = "_gwd" if gwd else "" %} +{{- wdesc + 
vdesc + gwddesc }} +{%- endmacro %} + +{%- set locs_or_addrs_tensor = "ssd_row_addrs" if ssd else "lxu_cache_locations" %} +{%- set locs_or_addrs_type = "int64_t" if ssd else "int32_t" %} {%- if not dense and not nobag and not vbe %} #include "fbgemm_gpu/dispatch_macros.h" @@ -50,7 +58,7 @@ __launch_bounds__(kForwardMaxThreads) __global__ void {%- if is_index_select %} batch_index_select_dim0_codegen_forward_small_kernel( {%- else %} -{{ ddesc }}_embedding_nobag_codegen_forward_unweighted_small_kernel( +{{ mdesc }}_embedding_nobag_codegen_forward_unweighted_small_kernel( {%- endif %} const pta::PackedTensorAccessor64 dev_weights, {%- if not dense %} @@ -70,7 +78,7 @@ batch_index_select_dim0_codegen_forward_small_kernel( const pta::PackedTensorAccessor32 offsets, {%- endif %} {%- if not dense %} - const pta::PackedTensorAccessor32 lxu_cache_locations, + const pta::PackedTensorAccessor32<{{ locs_or_addrs_type }}, 1, at::RestrictPtrTraits> {{ locs_or_addrs_tensor }}, {%- endif %} {%- if is_index_select %} const pta::PackedTensorAccessor32 output_offsets, @@ -119,14 +127,15 @@ __global__ void split_embedding_codegen_forward_{{ wdesc }}_v2_kernel( {%- for nobag in ([True, False] if (not is_gwd) else [False]) %} {%- set ndesc = "_nobag" if nobag else "" %} {%- if is_valid_forward_config(nobag, weighted, vbe, is_index_select) %} -{%- set has_experimental = has_experimental_support(dense, nobag, vbe, is_index_select, is_rocm) %} +{%- set has_experimental = has_experimental_support(dense, nobag, vbe, is_index_select, is_rocm, ssd) %} {%- set is_gwd_kernel = is_gwd and is_valid_gwd_config( dense, nobag, vbe, - is_index_select) %} -{%- set gwddesc = "_gwd" if is_gwd_kernel else "" %} + is_index_select, + has_global_weight_decay_support=True, + ssd=ssd) %} template < typename emb_t, typename cache_t, @@ -144,7 +153,7 @@ __launch_bounds__(kForwardMaxThreads) __global__ void {%- if is_index_select %} batch_index_select_dim0_codegen_forward_kernel( {%- else %} -{{ ddesc }}_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ vdesc }}{{ gwddesc }}_kernel( +{{ mdesc }}_embedding{{ ndesc }}_codegen_forward_{{ get_desc_suffix(is_gwd_kernel) }}_kernel( {%- endif %} const pta::PackedTensorAccessor64 dev_weights, {%- if not dense %} @@ -177,7 +186,7 @@ batch_index_select_dim0_codegen_forward_kernel( pta::PackedTensorAccessor32, 1, at::RestrictPtrTraits> indice_weights, {%- endif %} {%- if not dense %} - const pta::PackedTensorAccessor32 lxu_cache_locations, + const pta::PackedTensorAccessor32<{{ locs_or_addrs_type }}, 1, at::RestrictPtrTraits> {{ locs_or_addrs_tensor }}, const int32_t* lxu_cache_conflict_misses, {%- endif %} {%- if is_index_select %} @@ -307,20 +316,22 @@ batch_index_select_dim0_codegen_forward_kernel( {%- for nobag in ([True, False] if (not is_gwd) else [False]) %} {%- set ndesc = "_nobag" if nobag else "" %} {%- if is_valid_forward_config(nobag, weighted, vbe, is_index_select) %} -{%- set has_experimental = has_experimental_support(dense, nobag, vbe, is_index_select, is_rocm) %} +{%- set has_experimental = has_experimental_support(dense, nobag, vbe, is_index_select, is_rocm, ssd) %} {#- /* Generate a separate cuda host to enable global weight decay using Jinja */ #} {%- set is_gwd_kernel = is_gwd and is_valid_gwd_config( dense, nobag, vbe, - is_index_select) %} -{%- set gwddesc = "_gwd" if is_gwd_kernel else "" %} + is_index_select, + has_global_weight_decay_support=True, + ssd=ssd) %} +{%- set desc_suffix = get_desc_suffix(is_gwd_kernel) %} Tensor {%- if is_index_select %} 
batch_index_select_dim0_codegen_forward_cuda( {%- else %} -{{ ddesc }}_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ vdesc }}{{ gwddesc }}_cuda( +{{ mdesc }}_embedding{{ ndesc }}_codegen_forward_{{ desc_suffix }}_cuda( {%- endif %} const Tensor& dev_weights, {%- if not dense %} @@ -351,7 +362,7 @@ batch_index_select_dim0_codegen_forward_cuda( const Tensor& indice_weights, {%- endif %} {%- if not dense %} - const Tensor& lxu_cache_locations, + const Tensor& {{ locs_or_addrs_tensor }}, const Tensor& uvm_cache_stats, {%- endif %} const int64_t output_dtype, @@ -391,7 +402,6 @@ batch_index_select_dim0_codegen_forward_cuda( const int64_t total_D = total_D_.guard_int(__FILE__, __LINE__); {%- endif %} - {%- if not nobag or is_index_select %} const int64_t max_D = max_D_.guard_int(__FILE__, __LINE__); {%- endif %} @@ -417,7 +427,7 @@ batch_index_select_dim0_codegen_forward_cuda( indice_weights, {%- endif %} {%- if not dense %} - lxu_cache_locations, + {{ locs_or_addrs_tensor }}, {%- endif %} {%- if vbe %} vbe_row_output_offsets, @@ -576,12 +586,17 @@ batch_index_select_dim0_codegen_forward_cuda( {%- set nobag_small_kernel = "batch_index_select_dim0_codegen_forward_small_kernel" if is_index_select else - "{}_embedding_nobag_codegen_forward_unweighted_small_kernel".format(ddesc) + "{}_embedding_nobag_codegen_forward_unweighted_small_kernel".format(mdesc) %} #ifdef FBGEMM_GPU_MEMCHECK const auto func_name = "{{ nobag_small_kernel }}"; #endif - {{ nobag_small_kernel }} + {{ nobag_small_kernel }}< + emb_t, + cache_t, + output_t, + int64_t, + kEmbeddingSize / 4> <<< div_round_up(total_B, kForwardMaxThreads / kWarpSize), dim3(kWarpSize, kForwardMaxThreads / kWarpSize), @@ -606,7 +621,7 @@ batch_index_select_dim0_codegen_forward_cuda( MAKE_PTA_WITH_NAME(func_name, offsets, int64_t, 1, 32), {%- endif %} {%- if not dense %} - MAKE_PTA_WITH_NAME(func_name, lxu_cache_locations, int32_t, 1, 32), + MAKE_PTA_WITH_NAME(func_name, {{ locs_or_addrs_tensor }}, {{ locs_or_addrs_type }}, 1, 32), {%- endif %} {%- if is_index_select %} MAKE_PTA_WITH_NAME(func_name, output_offsets, int64_t, 1, 32), @@ -617,6 +632,7 @@ batch_index_select_dim0_codegen_forward_cuda( MAKE_PTA_WITH_NAME(func_name, output, output_t, {{ "1" if is_index_select else "2" }}, 64) ); + C10_CUDA_KERNEL_LAUNCH_CHECK(); return; }); @@ -625,7 +641,7 @@ batch_index_select_dim0_codegen_forward_cuda( {%- set nobag_kernel = "batch_index_select_dim0_codegen_forward_kernel" if is_index_select else - "{}_embedding_nobag_codegen_forward_unweighted_kernel".format(ddesc) + "{}_embedding_nobag_codegen_forward_unweighted_kernel".format(mdesc) %} #ifdef FBGEMM_GPU_MEMCHECK const auto func_name = "{{ nobag_kernel }}"; @@ -661,7 +677,7 @@ batch_index_select_dim0_codegen_forward_cuda( MAKE_PTA_WITH_NAME(func_name, offsets, int64_t, 1, 32), {%- endif %} {%- if not dense %} - MAKE_PTA_WITH_NAME(func_name, lxu_cache_locations, int32_t, 1, 32), + MAKE_PTA_WITH_NAME(func_name, {{ locs_or_addrs_tensor }}, {{ locs_or_addrs_type }}, 1, 32), uvm_cache_stats.size(0) == 0 ? 
nullptr : (uvm_cache_stats.data_ptr() + uvm_cache_stats_index::num_conflict_unique_misses), @@ -674,7 +690,6 @@ batch_index_select_dim0_codegen_forward_cuda( {%- endif %} MAKE_PTA_WITH_NAME(func_name, output, output_t, {{ "1" if is_index_select else "2" }}, 64) ); - C10_CUDA_KERNEL_LAUNCH_CHECK(); return; }); @@ -692,7 +707,7 @@ batch_index_select_dim0_codegen_forward_cuda( {{ dispatcher }}(max_D, [&] { #ifdef FBGEMM_GPU_MEMCHECK - const auto func_name = "{{ ddesc }}_embedding_codegen_forward_{{ wdesc }}{{ vdesc }}{{ gwddesc }}_kernel"; + const auto func_name = "{{ mdesc }}_embedding_codegen_forward_{{ desc_suffix }}_kernel"; #endif // Other components in TBE (backward, backward_indice_weights) use // kFixedMaxVecsPerThread. Thus, the codegen generates @@ -700,7 +715,7 @@ batch_index_select_dim0_codegen_forward_cuda( // kMaxVecsPerThread and kFixedMaxVecsPerThread are the same // forward constexpr auto kMaxVecsPerThread = kFixedMaxVecsPerThread; - {{ ddesc }}_embedding_codegen_forward_{{ wdesc }}{{ vdesc }}{{ gwddesc }}_kernel + {{ mdesc }}_embedding_codegen_forward_{{ desc_suffix }}_kernel () + uvm_cache_stats_index::num_conflict_unique_misses), @@ -753,7 +768,6 @@ batch_index_select_dim0_codegen_forward_cuda( {%- endif %} // if not dense MAKE_PTA_WITH_NAME(func_name, output, output_t, 2, 64) ); - C10_CUDA_KERNEL_LAUNCH_CHECK(); {%- if vbe %} output = output.reshape({-1}); @@ -777,8 +791,10 @@ batch_index_select_dim0_codegen_forward_cuda( const uint32_t num_warps_per_threadblock = kForwardMaxThreads / kWarpSize; const auto kernel_func = - (use_lxu_cache ? split_embedding_codegen_forward_{{ wdesc }}_v2_kernel - : split_embedding_codegen_forward_{{ wdesc }}_v2_kernel); + (use_lxu_cache ? split_embedding_codegen_forward_{{ wdesc }}_v2_kernel< + emb_t, cache_t, output_t, int64_t, true> + : split_embedding_codegen_forward_{{ wdesc }}_v2_kernel< + emb_t, cache_t, output_t, int64_t, false>); kernel_func <<< @@ -823,8 +839,8 @@ batch_index_select_dim0_codegen_forward_cuda( {%- if not is_index_select %} TORCH_LIBRARY_FRAGMENT(fbgemm, m) { {%- set embedding_codegen_forward_op = - "{}_embedding{}_codegen_forward_{}{}{}_cuda".format( - ddesc, ndesc, wdesc, vdesc, gwddesc + "{}_embedding{}_codegen_forward_{}_cuda".format( + mdesc, ndesc, desc_suffix ) %} m.def("{{ embedding_codegen_forward_op }}(" @@ -851,7 +867,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { " Tensor indice_weights, " {%- endif %} {%- if not dense %} - " Tensor lxu_cache_locations, " + " Tensor {{ locs_or_addrs_tensor }}, " " Tensor uvm_cache_stats, " {%- endif %} " int output_dtype, " diff --git a/fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_device_kernel_template.cuh b/fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_device_kernel_template.cuh index d4cbb962a..cb3977cab 100644 --- a/fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_device_kernel_template.cuh +++ b/fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_device_kernel_template.cuh @@ -14,6 +14,11 @@ #define GROUP_REDUCE_ALL_SUM(val, ...) 
\ warpReduceAllSum<__VA_ARGS__, kThreadGroupSize>(val, shfl_sync_mask) +{%- set mdesc = "ssd" if ssd else "split" %} +{%- set locs_or_addrs_tensor = "ssd_row_addrs" if ssd else "lxu_cache_locations" %} +{%- set locs_or_addrs_type = "int64_t" if ssd else "int32_t" %} +{%- set locs_or_addrs_idx = "row_idx" if ssd else "cache_idx" %} + using namespace fbgemm_gpu; template < @@ -28,13 +33,13 @@ template < int32_t VEC_WIDTH, bool kUseVecBlocking > -DEVICE_INLINE void split_{{ optimizer }}_table_update_kernel( +DEVICE_INLINE void {{ mdesc }}_{{ optimizer }}_table_update_kernel( pta::PackedTensorAccessor64& dev_weights, pta::PackedTensorAccessor64& uvm_weights, pta::PackedTensorAccessor64& lxu_cache_weights, const pta::PackedTensorAccessor32& weights_placements, const pta::PackedTensorAccessor32& weights_offsets, - const pta::PackedTensorAccessor32& sorted_lxu_cache_locations, + const pta::PackedTensorAccessor32<{{ locs_or_addrs_type }}, 1, at::RestrictPtrTraits>& sorted_{{ locs_or_addrs_tensor }}, Vec4TAcc* grad_sum, Vec4TAcc* smem_grad_sum, Vec4TAcc* shared_weight_update_row, @@ -71,10 +76,15 @@ DEVICE_INLINE void split_{{ optimizer }}_table_update_kernel( weights = &uvm_weights[weights_offset + idx * D_emb]; } if (weights_placement == PlacementType::MANAGED_CACHING) { - const int32_t cache_idx = sorted_lxu_cache_locations[cache_loc_run_id]; - if (cache_idx != kCacheLocationMissing) { - cache_weights = &lxu_cache_weights[cache_idx][0]; + const auto {{ locs_or_addrs_idx }} = sorted_{{ locs_or_addrs_tensor }}[cache_loc_run_id]; + {%- if ssd %} + cache_weights = reinterpret_cast( + *reinterpret_cast(&{{ locs_or_addrs_idx }})); + {%- else %} + if ({{ locs_or_addrs_idx }} != kCacheLocationMissing) { + cache_weights = &lxu_cache_weights[{{ locs_or_addrs_idx }}][0]; } + {%- endif %} } {%- for tensor in args.split_tensors %} {{ args.split_tensor_types[tensor] }}* __restrict__ {{ tensor }}; diff --git a/fbgemm_gpu/codegen/training/python/__init__.template b/fbgemm_gpu/codegen/training/python/__init__.template index 42f49ee3c..66c774282 100644 --- a/fbgemm_gpu/codegen/training/python/__init__.template +++ b/fbgemm_gpu/codegen/training/python/__init__.template @@ -6,14 +6,22 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-import fbgemm_gpu.split_embedding_codegen_lookup_invokers.lookup_adagrad as lookup_adagrad # noqa: F401 -import fbgemm_gpu.split_embedding_codegen_lookup_invokers.lookup_adam as lookup_adam # noqa: F401 +# All optimizers import fbgemm_gpu.split_embedding_codegen_lookup_invokers.lookup_args as lookup_args # noqa: F401 -import fbgemm_gpu.split_embedding_codegen_lookup_invokers.lookup_lamb as lookup_lamb # noqa: F401 -import fbgemm_gpu.split_embedding_codegen_lookup_invokers.lookup_lars_sgd as lookup_lars_sgd # noqa: F401 -import fbgemm_gpu.split_embedding_codegen_lookup_invokers.lookup_partial_rowwise_adam as lookup_partial_rowwise_adam # noqa: F401 -import fbgemm_gpu.split_embedding_codegen_lookup_invokers.lookup_partial_rowwise_lamb as lookup_partial_rowwise_lamb # noqa: F401 -import fbgemm_gpu.split_embedding_codegen_lookup_invokers.lookup_rowwise_adagrad as lookup_rowwise_adagrad # noqa: F401 -import fbgemm_gpu.split_embedding_codegen_lookup_invokers.lookup_rowwise_adagrad_with_counter as lookup_rowwise_adagrad_with_counter # noqa: F401 -import fbgemm_gpu.split_embedding_codegen_lookup_invokers.lookup_sgd as lookup_sgd # noqa: F401 -import fbgemm_gpu.split_embedding_codegen_lookup_invokers.lookup_none as lookup_none # noqa: F401 +{%- for optim in all_optimizers %} +import fbgemm_gpu.split_embedding_codegen_lookup_invokers.lookup_{{ optim }} as lookup_{{optim}} # noqa: F401 +{%- endfor %} + +# SSD optimizers (putting them under try-except for BC as they are +# experimental ops which can be removed/updated in the future) +try: + import fbgemm_gpu.split_embedding_codegen_lookup_invokers.lookup_args_ssd as lookup_args_ssd + {%- for optim in ssd_optimizers %} + import fbgemm_gpu.split_embedding_codegen_lookup_invokers.lookup_{{ optim }}_ssd as lookup_{{ optim }}_ssd + {%- endfor %} +except: + import logging + logging.warning("fbgemm_gpu.split_embedding_codegen_lookup_invokers.lookup_args_ssd import failed") + {%- for optim in ssd_optimizers %} + logging.warning("fbgemm_gpu.split_embedding_codegen_lookup_invokers.lookup_{{ optim }}_ssd import failed") + {%- endfor %} diff --git a/fbgemm_gpu/codegen/training/python/lookup_args.py b/fbgemm_gpu/codegen/training/python/lookup_args.template similarity index 94% rename from fbgemm_gpu/codegen/training/python/lookup_args.py rename to fbgemm_gpu/codegen/training/python/lookup_args.template index eb752ee6a..357aad622 100644 --- a/fbgemm_gpu/codegen/training/python/lookup_args.py +++ b/fbgemm_gpu/codegen/training/python/lookup_args.template @@ -7,7 +7,7 @@ # pyre-strict -from typing import NamedTuple, Optional +from typing import NamedTuple, Optional, Dict import torch @@ -46,6 +46,9 @@ class CommonArgs(NamedTuple): is_experimental: bool use_uniq_cache_locations_bwd: bool use_homogeneous_placements: bool + {%- if ssd %} + ssd_tensors: Dict[str, torch.Tensor] + {%- endif %} class OptimizerArgs(NamedTuple): diff --git a/fbgemm_gpu/codegen/training/python/split_embedding_codegen_lookup_invoker.template b/fbgemm_gpu/codegen/training/python/split_embedding_codegen_lookup_invoker.template index cedc0c41c..2de471eea 100644 --- a/fbgemm_gpu/codegen/training/python/split_embedding_codegen_lookup_invoker.template +++ b/fbgemm_gpu/codegen/training/python/split_embedding_codegen_lookup_invoker.template @@ -7,12 +7,14 @@ # pyre-ignore-all-errors +{%- set mdesc = "ssd" if ssd else "split" %} +{%- set sdesc = "_ssd" if ssd else "" %} + import torch {%- if is_experimental_optimizer %} import warnings {%- endif %} -from .lookup_args import * - +from .lookup_args{{ sdesc }}
import * {%- if is_fbcode %} @@ -73,7 +75,7 @@ def invoke( ) {%- endif %} - {%- if has_cpu_support %} + {%- if has_cpu_support and not ssd %} if (common_args.host_weights.numel() > 0): return torch.ops.fbgemm.split_embedding_codegen_lookup_{{ optimizer }}_function_cpu( # common_args @@ -192,7 +194,19 @@ def invoke( {%- if has_gpu_support %} vbe_metadata = common_args.vbe_metadata - return torch.ops.fbgemm.split_embedding_codegen_lookup_{{ optimizer }}_function( + + {%- if ssd %} + ssd_tensors = [] + {%- for tensor in ssd_tensors %} + assert "{{ tensor }}" in common_args.ssd_tensors, ( + "{{ tensor }} must be in common_args.ssd_tensors. " + "Please check the backend version" + ) + ssd_tensors.append(common_args.ssd_tensors["{{ tensor }}"]) + {%- endfor %} + {%- endif %} + + return torch.ops.fbgemm.{{ mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function( # common_args {%- if not dense %} placeholder_autograd_tensor=common_args.placeholder_autograd_tensor, @@ -214,6 +228,9 @@ def invoke( feature_requires_grad=common_args.feature_requires_grad, lxu_cache_locations=common_args.lxu_cache_locations, uvm_cache_stats=common_args.uvm_cache_stats, + {%- if ssd %} + ssd_tensors=ssd_tensors, + {%- endif %} # VBE metadata B_offsets=vbe_metadata.B_offsets, vbe_output_offsets_feature_rank=vbe_metadata.output_offsets_feature_rank, diff --git a/fbgemm_gpu/fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py b/fbgemm_gpu/fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py index b483852c1..3518bf868 100644 --- a/fbgemm_gpu/fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +++ b/fbgemm_gpu/fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py @@ -17,7 +17,7 @@ import torch # usort:skip import fbgemm_gpu.split_embedding_codegen_lookup_invokers as invokers -from fbgemm_gpu.split_embedding_configs import SparseType +from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType, SparseType from fbgemm_gpu.split_table_batched_embeddings_ops_common import ( CacheAlgorithm, @@ -93,6 +93,7 @@ def __init__( ssd_uniform_init_lower: float = -0.01, ssd_uniform_init_upper: float = 0.01, ssd_block_cache_size: int = 0, + optimizer: OptimType = OptimType.EXACT_ROWWISE_ADAGRAD, # General Optimizer args stochastic_rounding: bool = True, gradient_clipping: bool = False, @@ -242,6 +243,9 @@ def __init__( self.ssd_set_start = torch.cuda.Event() self.ssd_set_end = torch.cuda.Event() self.timesteps_prefetched: List[int] = [] + self.ssd_scratch_pads: List[Tuple[Tensor, Tensor, Tensor]] = [] + # TODO: add type annotation + self.ssd_prefetch_data = [] if weight_decay_mode == WeightDecayMode.COUNTER or counter_based_regularization: raise AssertionError( @@ -255,7 +259,7 @@ def __init__( ) cowclip_regularization = CowClipDefinition() - self.optimizer_args = invokers.lookup_args.OptimizerArgs( + self.optimizer_args = invokers.lookup_args_ssd.OptimizerArgs( stochastic_rounding=stochastic_rounding, gradient_clipping=gradient_clipping, max_gradient=max_gradient, @@ -301,9 +305,10 @@ def __init__( dtype=torch.int32, ), ) - weights_offsets = [0] + list( - itertools.accumulate([row * dim for (row, dim) in zip(rows, dims)]) - ) + # weights_offsets = [0] + list( + # itertools.accumulate([row * dim for (row, dim) in zip(rows, dims)]) + # ) + weights_offsets = [0] * (len(rows) + 1) self.register_buffer( "weights_offsets", torch.tensor( @@ -352,8 +357,92 @@ def __init__( torch.zeros(0, device=self.current_device, dtype=torch.float) ) + # Register backward hook for evicting rows from a scratch pad to SSD + # post backward + 
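The register_hook call just below is what ties the scratch pad's lifetime to the training iteration: the hook fires once the gradient for the placeholder tensor has been computed, i.e. only after backward (with the fused optimizer) has run, at which point the conflict-missed rows can be flushed to SSD. A minimal, CUDA-free sketch of that pattern (ScratchPadDemo, the dict-as-SSD store, and the toy loss are illustrative stand-ins; the real module evicts via ssd_db.set_cuda on a side stream with pinned-memory copies):

import torch

class ScratchPadDemo(torch.nn.Module):
    """Toy model of 'keep the scratch pad alive, flush it after backward'."""

    def __init__(self) -> None:
        super().__init__()
        self.ssd = {}               # stand-in for the RocksDB-backed store
        self.scratch_pads = []      # FIFO of (rows, indices) awaiting eviction
        self.placeholder = torch.nn.Parameter(torch.zeros(1))
        # Fires once the gradient w.r.t. the placeholder exists, i.e. after
        # backward has passed through this module for the current iteration.
        self.placeholder.register_hook(self._evict_after_backward)

    def _evict_after_backward(self, grad: torch.Tensor) -> None:
        rows, indices = self.scratch_pads.pop(0)
        for idx, row in zip(indices.tolist(), rows.detach()):
            self.ssd[idx] = row.clone()  # "write back" the conflict-missed rows

    def forward(self, rows: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
        self.scratch_pads.append((rows, indices))     # keep rows alive
        return (rows.sum() + self.placeholder).sum()  # placeholder joins the graph

demo = ScratchPadDemo()
loss = demo(torch.randn(3, 4, requires_grad=True), torch.tensor([7, 11, 42]))
loss.backward()          # hook runs here; the scratch pad is "evicted"
print(sorted(demo.ssd))  # [7, 11, 42]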
self.placeholder_autograd_tensor.register_hook(self._evict_from_scratch_pad) + + assert optimizer in ( + OptimType.EXACT_ROWWISE_ADAGRAD, + ), f"Optimizer {optimizer} is not supported by SSDTableBatchedEmbeddingBags" + self.optimizer = optimizer + + def to_pinned_cpu(self, t: torch.Tensor) -> torch.Tensor: + t_cpu = torch.empty(t.shape, pin_memory=True, dtype=t.dtype) + t_cpu.copy_(t, non_blocking=True) + return t_cpu + + def evict( + self, evicted_rows: Tensor, evicted_indices: Tensor, actions_count_cpu: Tensor + ) -> None: + """ + Evict data from the given input tensors to SSD via RocksDB + """ + with torch.cuda.stream(self.ssd_stream): + self.ssd_stream.wait_event(self.ssd_set_start) + evicted_rows_cpu = self.to_pinned_cpu(evicted_rows) + evicted_indices_cpu = self.to_pinned_cpu(evicted_indices) + evicted_rows.record_stream(self.ssd_stream) + evicted_indices.record_stream(self.ssd_stream) + self.ssd_db.set_cuda( + evicted_indices_cpu, evicted_rows_cpu, actions_count_cpu, self.timestep + ) + # TODO: is this needed? + # Need a way to synchronize + # actions_count_cpu.record_stream(self.ssd_stream) + self.ssd_stream.record_event(self.ssd_set_end) + + def _evict_from_scratch_pad(self, grad: Tensor) -> None: + assert len(self.ssd_scratch_pads) > 0, "There must be at least one scratch pad" + (inserted_rows_gpu, post_bwd_evicted_indices, actions_count_cpu) = ( + self.ssd_scratch_pads.pop(0) + ) + self.evict(inserted_rows_gpu, post_bwd_evicted_indices, actions_count_cpu) + + def _compute_cache_ptrs( + self, + linear_cache_indices: torch.Tensor, + assigned_cache_slots: torch.Tensor, + linear_index_inverse_indices: torch.Tensor, + unique_indices_count_cumsum: torch.Tensor, + cache_set_inverse_indices: torch.Tensor, + inserted_rows_gpu: torch.Tensor, + unique_indices_length: torch.Tensor, + inserted_indices: torch.Tensor, + actions_count_cpu: torch.Tensor, + ) -> torch.Tensor: + lxu_cache_locations = torch.ops.fbgemm.lxu_cache_lookup( + linear_cache_indices, + self.lxu_cache_state, + self.hash_size_cumsum[-1].item(), + ) + lxu_cache_ptrs, post_bwd_evicted_indices = ( + torch.ops.fbgemm.ssd_generate_row_addrs( + lxu_cache_locations, + assigned_cache_slots, + linear_index_inverse_indices, + unique_indices_count_cumsum, + cache_set_inverse_indices, + self.lxu_cache_weights, + inserted_rows_gpu, + unique_indices_length, + inserted_indices, + ) + ) + # Store scratch pad info for post backward eviction + self.ssd_scratch_pads.append( + (inserted_rows_gpu, post_bwd_evicted_indices, actions_count_cpu) + ) + assert lxu_cache_ptrs[lxu_cache_ptrs == 0].numel() == 0 + return ( + lxu_cache_ptrs, + inserted_rows_gpu, + post_bwd_evicted_indices, + actions_count_cpu, + ) + def prefetch(self, indices: Tensor, offsets: Tensor) -> Optional[Tensor]: (indices, offsets) = indices.long(), offsets.long() + linear_cache_indices = torch.ops.fbgemm.linearize_cache_indices( self.hash_size_cumsum, indices, @@ -366,10 +455,10 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> Optional[Tensor]: evicted_indices, assigned_cache_slots, actions_count_gpu, - _, - _, - _, - _, + linear_index_inverse_indices, + unique_indices_count_cumsum, + cache_set_inverse_indices, + unique_indices_length, ) = torch.ops.fbgemm.ssd_cache_populate_actions( linear_cache_indices, self.total_hash_size, @@ -379,15 +468,10 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> Optional[Tensor]: self.lru_state, ) - def to_pinned_cpu(t: torch.Tensor) -> torch.Tensor: - t_cpu = torch.empty(t.shape, pin_memory=True, dtype=t.dtype) - t_cpu.copy_(t, 
non_blocking=True) - return t_cpu - - actions_count_cpu = to_pinned_cpu(actions_count_gpu) + actions_count_cpu = self.to_pinned_cpu(actions_count_gpu) assigned_cache_slots = assigned_cache_slots.long() evicted_rows = self.lxu_cache_weights[ - assigned_cache_slots.clamp_(min=0).long(), : + assigned_cache_slots.clamp(min=0).long(), : ] inserted_rows = torch.empty( evicted_rows.shape, @@ -400,14 +484,13 @@ def to_pinned_cpu(t: torch.Tensor) -> torch.Tensor: # Ensure the previous iterations l3_db.set(..) has completed. current_stream.wait_event(self.ssd_set_end) self.ssd_db.get_cuda( - to_pinned_cpu(inserted_indices), inserted_rows, actions_count_cpu + self.to_pinned_cpu(inserted_indices), inserted_rows, actions_count_cpu ) current_stream.record_event(self.ssd_set_start) # TODO: T123943415 T123943414 this is a big copy that is (mostly) unnecessary with a decent cache hit rate. # Should we allocate on HBM? inserted_rows_gpu = inserted_rows.cuda(non_blocking=True) - # self.lxu_cache_weights[assigned_cache_slots, :] = inserted_rows.cuda(non_blocking=True) torch.ops.fbgemm.masked_index_put( self.lxu_cache_weights, assigned_cache_slots, @@ -415,20 +498,23 @@ def to_pinned_cpu(t: torch.Tensor) -> torch.Tensor: actions_count_gpu, ) - with torch.cuda.stream(self.ssd_stream): - self.ssd_stream.wait_event(self.ssd_set_start) - evicted_rows_cpu = to_pinned_cpu(evicted_rows) - evicted_indices_cpu = to_pinned_cpu(evicted_indices) - evicted_rows.record_stream(self.ssd_stream) - evicted_indices.record_stream(self.ssd_stream) - self.ssd_db.set_cuda( - evicted_indices_cpu, evicted_rows_cpu, actions_count_cpu, self.timestep + # Evict rows from cache to SSD + self.evict(evicted_rows, evicted_indices, actions_count_cpu) + + # TODO: keep only necessary tensors + self.ssd_prefetch_data.append( + ( + linear_cache_indices, + assigned_cache_slots, + linear_index_inverse_indices, + unique_indices_count_cumsum, + cache_set_inverse_indices, + inserted_rows_gpu, + unique_indices_length, + inserted_indices, + actions_count_cpu, ) - # TODO: is this needed? 
- # Need a way to synchronize - # actions_count_cpu.record_stream(self.ssd_stream) - self.ssd_stream.record_event(self.ssd_set_end) - return linear_cache_indices + ) def forward( self, @@ -443,19 +529,18 @@ def forward( per_sample_weights = per_sample_weights.float() if len(self.timesteps_prefetched) == 0: with record_function("## prefetch ##"): - linear_cache_indices = self.prefetch(indices, offsets) - else: - linear_cache_indices = torch.ops.fbgemm.linearize_cache_indices( - self.hash_size_cumsum, - indices, - offsets, - ) - lxu_cache_locations = torch.ops.fbgemm.lxu_cache_lookup( - linear_cache_indices, - self.lxu_cache_state, - self.hash_size_cumsum[-1].item(), - ) - common_args = invokers.lookup_args.CommonArgs( + self.prefetch(indices, offsets) + assert len(self.ssd_prefetch_data) > 0 + + prefetch_data = self.ssd_prefetch_data.pop(0) + ( + lxu_cache_ptrs, + inserted_rows_gpu, + post_bwd_evicted_indices, + actions_count_cpu, + ) = self._compute_cache_ptrs(*prefetch_data) + + common_args = invokers.lookup_args_ssd.CommonArgs( placeholder_autograd_tensor=self.placeholder_autograd_tensor, output_dtype=SparseType.FP32.as_int(), dev_weights=self.weights_dev, @@ -474,9 +559,9 @@ def forward( pooling_mode=self.pooling_mode, indice_weights=per_sample_weights, feature_requires_grad=feature_requires_grad, - lxu_cache_locations=lxu_cache_locations, + lxu_cache_locations=lxu_cache_ptrs, uvm_cache_stats=None, - vbe_metadata=invokers.lookup_args.VBEMetadata( + vbe_metadata=invokers.lookup_args_ssd.VBEMetadata( B_offsets=None, output_offsets_feature_rank=None, B_offsets_rank_per_feature=None, @@ -488,9 +573,25 @@ def forward( is_experimental=False, use_uniq_cache_locations_bwd=False, use_homogeneous_placements=True, + # The keys for ssd_tensors are controlled by ssd_tensors in + # codegen/genscript/optimizer_args.py + ssd_tensors={ + "row_addrs": lxu_cache_ptrs, + "inserted_rows": inserted_rows_gpu, + "post_bwd_evicted_indices": post_bwd_evicted_indices, + "actions_count": actions_count_cpu, + }, ) - momentum1 = invokers.lookup_args.Momentum( + self.timesteps_prefetched.pop(0) + + if self.optimizer == OptimType.EXACT_SGD: + raise AssertionError( + "SSDTableBatchedEmbeddingBags currently does not support SGD" + ) + return invokers.lookup_sgd_ssd.invoke(common_args, self.optimizer_args) + + momentum1 = invokers.lookup_args_ssd.Momentum( dev=self.momentum1_dev, host=self.momentum1_host, uvm=self.momentum1_uvm, @@ -498,10 +599,10 @@ def forward( placements=self.momentum1_placements, ) - self.timesteps_prefetched.pop(0) - return invokers.lookup_rowwise_adagrad.invoke( - common_args, self.optimizer_args, momentum1 - ) + if self.optimizer == OptimType.EXACT_ROWWISE_ADAGRAD: + return invokers.lookup_rowwise_adagrad_ssd.invoke( + common_args, self.optimizer_args, momentum1 + ) @torch.jit.ignore def debug_split_optimizer_states(self) -> List[Tuple[torch.Tensor]]: @@ -892,7 +993,6 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> Tensor: 1, # for now assume prefetch_dist == 1 self.lru_state, ) - actions_count_cpu = torch.empty( actions_count_gpu.shape, pin_memory=True, dtype=actions_count_gpu.dtype ) diff --git a/fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh b/fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh index 46fd14c9e..ab68ee8da 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh @@ -1279,15 +1279,12 @@ struct WeightRow { template struct WeightRowAccessor { - emb_t* row_; - cache_t* cache_row_; - int dim_; + const 
emb_t* row_; + const cache_t* cache_row_; + const int dim_; - DEVICE_INLINE WeightRowAccessor( - emb_t* row, - cache_t* cache_row, - int dim, - StochasticRoundingRNGState* stoc_rounding_state) + DEVICE_INLINE + WeightRowAccessor(const emb_t* row, const cache_t* cache_row, const int dim) : row_(row), cache_row_(cache_row), dim_(dim) {} DEVICE_INLINE Vec4T load(const int32_t d, const float2 qparams) const { diff --git a/fbgemm_gpu/test/ssd_split_table_batched_embeddings_test.py b/fbgemm_gpu/test/ssd_split_table_batched_embeddings_test.py index 5f3c7ae33..24c4f8936 100644 --- a/fbgemm_gpu/test/ssd_split_table_batched_embeddings_test.py +++ b/fbgemm_gpu/test/ssd_split_table_batched_embeddings_test.py @@ -16,7 +16,7 @@ import numpy as np import torch -from fbgemm_gpu.split_embedding_configs import SparseType +from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType, SparseType from fbgemm_gpu.split_embedding_utils import ( b_indices, fake_quantize_embs, @@ -34,7 +34,7 @@ SSDTableBatchedEmbeddingBags, ) -from hypothesis import given, settings, Verbosity +from hypothesis import assume, given, settings, Verbosity MAX_EXAMPLES = 40 @@ -116,6 +116,9 @@ def generate_ssd_tbes( lr: float = 0.01, # from SSDTableBatchedEmbeddingBags eps: float = 1.0e-8, # from SSDTableBatchedEmbeddingBags ssd_shards: int = 1, # from SSDTableBatchedEmbeddingBags + optimizer: OptimType = OptimType.EXACT_ROWWISE_ADAGRAD, + cache_set_scale: float = 1.0, + pooling_mode: bool = PoolingMode.SUM, ) -> Tuple[SSDTableBatchedEmbeddingBags, List[torch.nn.EmbeddingBag]]: """ Generate embedding modules (i,e., SSDTableBatchedEmbeddingBags and @@ -130,23 +133,45 @@ def generate_ssd_tbes( Es = [E] * T feature_table_map = list(range(T)) + if pooling_mode == PoolingMode.SUM: + mode = "sum" + do_pooling = True + elif pooling_mode == PoolingMode.MEAN: + mode = "mean" + do_pooling = True + elif pooling_mode == PoolingMode.NONE: + mode = "sum" + do_pooling = False + else: + # This proves that we have exhaustively checked all PoolingModes + raise RuntimeError("Unknown PoolingMode!") + # Generate torch EmbeddingBag - emb_ref = [ - torch.nn.EmbeddingBag(E, D, mode="sum", sparse=True).cuda() - for (E, D) in zip(Es, Ds) - ] + if do_pooling: + emb_ref = [ + torch.nn.EmbeddingBag(E, D, mode=mode, sparse=True).cuda() + for (E, D) in zip(Es, Ds) + ] + else: + emb_ref = [ + torch.nn.Embedding(E, D, sparse=True).cuda() for (E, D) in zip(Es, Ds) + ] + + cache_sets = max(int(max(T * B * L, 1) * cache_set_scale), 1) # Generate TBE SSD emb = SSDTableBatchedEmbeddingBags( embedding_specs=[(E, D) for (E, D) in zip(Es, Ds)], feature_table_map=feature_table_map, ssd_storage_directory=tempfile.mkdtemp(), - cache_sets=max(T * B * L, 1), + cache_sets=cache_sets, ssd_uniform_init_lower=-0.1, ssd_uniform_init_upper=0.1, learning_rate=lr, eps=eps, ssd_shards=ssd_shards, + optimizer=optimizer, + pooling_mode=pooling_mode, ).cuda() # Initialize TBE SSD weights @@ -160,6 +185,17 @@ def generate_ssd_tbes( return emb, emb_ref + def concat_ref_tensors( + self, + tensors: List[torch.Tensor], + do_pooling: bool, + B: int, + D: int, + ) -> torch.Tensor: + if do_pooling: + return torch.cat([t.view(B, -1) for t in tensors], dim=1) + return torch.cat(tensors, dim=0).view(-1, D) + def execute_ssd_forward_( self, emb: SSDTableBatchedEmbeddingBags, @@ -172,23 +208,40 @@ def execute_ssd_forward_( B: int, L: int, weighted: bool, + i: int = 0, ) -> Tuple[List[torch.Tensor], torch.Tensor]: """ Execute the forward functions of SSDTableBatchedEmbeddingBags and 
torch.nn.EmbeddingBag and compare outputs """ + assert len(emb_ref) == len(indices_list) + do_pooling = emb.pooling_mode != PoolingMode.NONE # Execute torch EmbeddingBag forward output_ref_list = ( - [b_indices(emb_, indices) for (emb_, indices) in zip(emb_ref, indices_list)] + [ + b_indices(emb_, indices, do_pooling=do_pooling) + for (emb_, indices) in zip(emb_ref, indices_list) + ] if not weighted else [ - b_indices(emb_, indices, per_sample_weights=per_sample_weights.view(-1)) + b_indices( + emb_, + indices, + per_sample_weights=per_sample_weights.view(-1), + do_pooling=do_pooling, + ) for (emb_, indices, per_sample_weights) in zip( emb_ref, indices_list, per_sample_weights_list ) ] ) - output_ref = torch.cat([out.view(B, -1) for out in output_ref_list], dim=1) + + output_ref = self.concat_ref_tensors( + output_ref_list, + do_pooling, + B, + emb.embedding_specs[0][1], + ) # Execute TBE SSD forward output = ( @@ -213,11 +266,25 @@ def execute_ssd_forward_( log_E=st.integers(min_value=3, max_value=5), L=st.integers(min_value=0, max_value=20), weighted=st.booleans(), + cache_set_scale=st.sampled_from([0.0, 0.005, 1]), + pooling_mode=st.sampled_from( + [PoolingMode.NONE, PoolingMode.SUM, PoolingMode.MEAN] + ), ) @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) def test_ssd_forward( - self, T: int, D: int, B: int, log_E: int, L: int, weighted: bool + self, + T: int, + D: int, + B: int, + log_E: int, + L: int, + weighted: bool, + cache_set_scale: float, + pooling_mode: PoolingMode, ) -> None: + assume(not weighted or pooling_mode == PoolingMode.SUM) + # Generate embedding modules ( emb, @@ -229,6 +296,8 @@ def test_ssd_forward( log_E, L, weighted, + cache_set_scale=cache_set_scale, + pooling_mode=pooling_mode, ) # Generate inputs @@ -262,11 +331,25 @@ def test_ssd_forward( log_E=st.integers(min_value=3, max_value=5), L=st.integers(min_value=0, max_value=20), weighted=st.booleans(), + cache_set_scale=st.sampled_from([0.0, 0.005, 1]), + pooling_mode=st.sampled_from( + [PoolingMode.NONE, PoolingMode.SUM, PoolingMode.MEAN] + ), ) @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) def test_ssd_backward_adagrad( - self, T: int, D: int, B: int, log_E: int, L: int, weighted: bool + self, + T: int, + D: int, + B: int, + log_E: int, + L: int, + weighted: bool, + cache_set_scale: float, + pooling_mode: PoolingMode, ) -> None: + assume(not weighted or pooling_mode == PoolingMode.SUM) + # Constants lr = 0.5 eps = 0.2 @@ -286,6 +369,8 @@ def test_ssd_backward_adagrad( lr=lr, eps=eps, ssd_shards=ssd_shards, + cache_set_scale=cache_set_scale, + pooling_mode=pooling_mode, ) Es = [emb.embedding_specs[t][0] for t in range(T)] @@ -317,10 +402,17 @@ def test_ssd_backward_adagrad( # Execute torch EmbeddingBag backward [out.backward(grad) for (out, grad) in zip(output_ref_list, output_grad_list)] - # Execute TBE SSD backward - output.backward( - torch.cat([grad.view(B, -1) for grad in output_grad_list], dim=1) + do_pooling = pooling_mode != PoolingMode.NONE + grad_test = self.concat_ref_tensors( + output_grad_list, + do_pooling, + B, + D * 4, ) + + # Execute TBE SSD backward + output.backward(grad_test) + # Compare optimizer states split_optimizer_states = [s for (s,) in emb.debug_split_optimizer_states()] for t in range(T): @@ -359,18 +451,30 @@ def test_ssd_backward_adagrad( log_E=st.integers(min_value=3, max_value=5), L=st.integers(min_value=0, max_value=20), weighted=st.booleans(), + cache_set_scale=st.sampled_from([0.0, 0.005, 1]), + 
pooling_mode=st.sampled_from( + [PoolingMode.NONE, PoolingMode.SUM, PoolingMode.MEAN] + ), ) @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) def test_ssd_cache( - self, T: int, D: int, B: int, log_E: int, L: int, weighted: bool + self, + T: int, + D: int, + B: int, + log_E: int, + L: int, + weighted: bool, + cache_set_scale: float, + pooling_mode: PoolingMode, ) -> None: - # T=2 - # D=2 - # B=9 - # log_E=3 - # L=14 - # weighted=False + assume(not weighted or pooling_mode == PoolingMode.SUM) + + lr = 0.5 + eps = 0.2 + ssd_shards = 2 torch.manual_seed(42) + # Generate embedding modules ( emb, @@ -382,8 +486,17 @@ def test_ssd_cache( log_E, L, weighted, + lr=lr, + eps=eps, + ssd_shards=ssd_shards, + cache_set_scale=cache_set_scale, + pooling_mode=pooling_mode, ) + optimizer_states_ref = [ + s.clone().float() for (s,) in emb.debug_split_optimizer_states() + ] + Es = [emb.embedding_specs[t][0] for t in range(T)] for i in range(10): @@ -422,31 +535,9 @@ def test_ssd_cache( 0, # prefetch_dist emb.lru_state, ) - assert actions_count_gpu.item() == 0 - lxu_cache_locations = torch.ops.fbgemm.lxu_cache_lookup( - linear_cache_indices, - emb.lxu_cache_state, - emb.hash_size_cumsum[-1], - ) - lru_state_cpu = emb.lru_state.cpu() - lxu_cache_state_cpu = emb.lxu_cache_state.cpu() - - NOT_FOUND = np.iinfo(np.int32).max - ASSOC = 32 - - for loc, linear_idx in zip( - lxu_cache_locations.cpu().numpy().tolist(), - linear_cache_indices.cpu().numpy().tolist(), - ): - assert loc != NOT_FOUND - # if we have a hit, check the cache is consistent - loc_set = loc // ASSOC - loc_slot = loc % ASSOC - assert lru_state_cpu[loc_set, loc_slot] == emb.timestep - assert lxu_cache_state_cpu[loc_set, loc_slot] == linear_idx - - self.execute_ssd_forward_( + # Execute forward + output_ref_list, output = self.execute_ssd_forward_( emb, emb_ref, indices_list, @@ -457,6 +548,64 @@ def test_ssd_cache( B, L, weighted, + i=i, + ) + + # Generate output gradient + output_grad_list = [torch.randn_like(out) for out in output_ref_list] + + # Execute torch EmbeddingBag backward + for t, (out, grad) in enumerate(zip(output_ref_list, output_grad_list)): + # Zero out weight grad + emb_ref[t].weight.grad = None + out.backward(grad) + + do_pooling = pooling_mode != PoolingMode.NONE + grad_test = self.concat_ref_tensors( + output_grad_list, + do_pooling, + B, + D * 4, + ) + + # Execute TBE SSD backward + output.backward(grad_test) + + # Compare optimizer states + split_optimizer_states = [s for (s,) in emb.debug_split_optimizer_states()] + for t in range(T): + # pyre-fixme[16]: Optional type has no attribute `float`. + optimizer_states_ref[t].add_( + emb_ref[t].weight.grad.float().to_dense().pow(2).mean(dim=1) + ) + torch.testing.assert_close( + split_optimizer_states[t].float(), + optimizer_states_ref[t], + atol=1.0e-4, + rtol=1.0e-4, + ) + + emb_ref[t].weight.data.copy_( + torch.addcdiv( + emb_ref[t].weight.float(), + value=-lr, + tensor1=emb_ref[t].weight.grad.float().to_dense(), + tensor2=split_optimizer_states[t] + .float() + .sqrt() + .add(eps) + .view(Es[t], 1), + ) + ) + + # Compare weights + emb.flush() + for t in range(T): + torch.testing.assert_close( + emb.debug_split_embedding_weights()[t].float().cuda(), + emb_ref[t].weight.float(), + atol=1.0e-4, + rtol=1.0e-4, )
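
The eviction path factored out above (evict plus the to_pinned_cpu helper) stages rows through page-locked host memory on a dedicated CUDA stream, fenced by the ssd_set_start/ssd_set_end events so the writeback never races with the main stream. Below is a minimal sketch of that pattern in plain PyTorch, assuming a made-up async_writeback wrapper and a placeholder comment where ssd_db.set_cuda would be called; it is not the FBGEMM implementation:

    import torch
    from typing import Tuple

    def to_pinned_cpu(t: torch.Tensor) -> torch.Tensor:
        # Page-locked host buffer so the device-to-host copy can overlap compute.
        t_cpu = torch.empty(t.shape, pin_memory=True, dtype=t.dtype)
        t_cpu.copy_(t, non_blocking=True)
        return t_cpu

    def async_writeback(
        rows: torch.Tensor,
        indices: torch.Tensor,
        side_stream: torch.cuda.Stream,
        start_evt: torch.cuda.Event,
        end_evt: torch.cuda.Event,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # The main stream marks the point the writeback must wait for.
        torch.cuda.current_stream().record_event(start_evt)
        with torch.cuda.stream(side_stream):
            side_stream.wait_event(start_evt)
            rows_cpu = to_pinned_cpu(rows)
            indices_cpu = to_pinned_cpu(indices)
            # Keep the caching allocator from recycling these GPU buffers
            # while the side stream is still reading them.
            rows.record_stream(side_stream)
            indices.record_stream(side_stream)
            # ... hand rows_cpu / indices_cpu to the storage backend here ...
            side_stream.record_event(end_evt)
        return rows_cpu, indices_cpu

    if torch.cuda.is_available():
        side_stream = torch.cuda.Stream()
        start_evt, end_evt = torch.cuda.Event(), torch.cuda.Event()
        rows = torch.randn(128, 64, device="cuda")
        indices = torch.arange(128, device="cuda")
        async_writeback(rows, indices, side_stream, start_evt, end_evt)
        # The next iteration's reads must wait until the writeback has finished:
        torch.cuda.current_stream().wait_event(end_evt)

The record_stream calls are what keep the evicted GPU rows alive until the side stream is done with them, mirroring the evicted_rows.record_stream(self.ssd_stream) call in the patch.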
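
_evict_from_scratch_pad is registered as a gradient hook on placeholder_autograd_tensor, so the conflict-missed rows parked in the scratch pad are flushed only once backward has produced a gradient for the placeholder, which in this stack happens after the fused backward + optimizer op that consumes the scratch pad has run for that iteration. A toy sketch of this hook-driven, FIFO-ordered deferral, with a hypothetical flush_to_storage standing in for ssd_db.set_cuda:

    import torch
    from typing import List

    class ScratchPadEvictor:
        def __init__(self) -> None:
            # Tiny leaf tensor whose only job is to carry the backward hook
            # (SSD TBE uses a zero-sized placeholder_autograd_tensor for this).
            self.placeholder = torch.zeros(1, requires_grad=True)
            self.placeholder.register_hook(self._evict_from_scratch_pad)
            self.scratch_pads: List[torch.Tensor] = []  # FIFO, one entry per prefetch

        def _evict_from_scratch_pad(self, grad: torch.Tensor) -> torch.Tensor:
            assert len(self.scratch_pads) > 0, "prefetch must have queued a scratch pad"
            pad = self.scratch_pads.pop(0)
            self.flush_to_storage(pad)  # stand-in for ssd_db.set_cuda(...)
            return grad

        def flush_to_storage(self, pad: torch.Tensor) -> None:
            print(f"flushing scratch pad with shape {tuple(pad.shape)}")

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # Prefetch would populate this with the conflict-missed rows.
            self.scratch_pads.append(torch.randn(4, 8))
            # Fold the placeholder into the graph so its hook fires during backward.
            return x.sum() + self.placeholder.sum()

    evictor = ScratchPadEvictor()
    loss = evictor.forward(torch.randn(3, requires_grad=True))
    loss.backward()  # the hook drains the oldest scratch pad once gradients exist

Popping from the front of the list keeps evictions in the same order in which prefetch queued the scratch pads, which is what keeps the handoff consistent if more than one iteration is prefetched ahead.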
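
_compute_cache_ptrs asks ssd_generate_row_addrs for one row address per lookup, pointing either into lxu_cache_weights (cache hit) or into inserted_rows_gpu, the scratch pad holding conflict-missed rows, which is why the forward and backward kernels can treat the two sources identically. The sketch below only mimics that selection logic with a CPU-side gather under assumed conventions (-1 as the miss sentinel, index-based rather than pointer-based addressing); the real op emits device pointers:

    import torch

    def gather_rows(
        cache_locations: torch.Tensor,  # cache slot per lookup; -1 marks a miss here
        scratch_slots: torch.Tensor,    # scratch-pad row per lookup (used when missed)
        cache_weights: torch.Tensor,    # (num_cache_slots, D), i.e. the cache
        scratch_pad: torch.Tensor,      # (num_conflict_missed, D), i.e. the scratch pad
    ) -> torch.Tensor:
        hit = cache_locations >= 0
        out = torch.empty(cache_locations.numel(), cache_weights.shape[1])
        out[hit] = cache_weights[cache_locations[hit]]
        out[~hit] = scratch_pad[scratch_slots[~hit]]
        return out

    cache_weights = torch.randn(4, 8)
    scratch_pad = torch.randn(2, 8)
    cache_locations = torch.tensor([0, -1, 3, -1])
    scratch_slots = torch.tensor([0, 0, 0, 1])  # consulted only where the lookup missed
    print(gather_rows(cache_locations, scratch_slots, cache_weights, scratch_pad).shape)
    # torch.Size([4, 8])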
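
On the test side, concat_ref_tensors builds the reference output for both pooled and sequence lookups: with pooling, each table contributes a (B, D) block concatenated along the feature axis; with PoolingMode.NONE, the per-lookup rows from all tables are stacked into a single (N, D) tensor. A small shape check that reproduces the helper outside the test class:

    import torch
    from typing import List

    def concat_ref_tensors(
        tensors: List[torch.Tensor], do_pooling: bool, B: int, D: int
    ) -> torch.Tensor:
        if do_pooling:
            # Pooled: each table yields (B, D); concatenate along the feature axis.
            return torch.cat([t.view(B, -1) for t in tensors], dim=1)
        # Sequence (PoolingMode.NONE): stack every looked-up row into (N, D).
        return torch.cat(tensors, dim=0).view(-1, D)

    B, L, D, T = 3, 4, 8, 2
    pooled = [torch.randn(B, D) for _ in range(T)]
    print(concat_ref_tensors(pooled, True, B, D).shape)     # torch.Size([3, 16])
    sequence = [torch.randn(B * L, D) for _ in range(T)]
    print(concat_ref_tensors(sequence, False, B, D).shape)  # torch.Size([24, 8])

The other new test knob, cache_set_scale, shrinks cache_sets far below the T * B * L working set (the sampled scales go down to 0.005 and 0.0, the latter leaving a single set), which is what forces the conflict-miss path these tests are meant to exercise.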
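
The multi-iteration test_ssd_cache now tracks a reference optimizer state and applies the same row-wise Adagrad update the TBE kernel is expected to perform: the per-row state accumulates the mean of the squared gradient, and the weight update divides by sqrt(state) + eps. A compact restatement of that reference step, assuming dense (E, D) gradients as produced by to_dense() in the test:

    import torch

    def rowwise_adagrad_reference(
        weight: torch.Tensor,  # (E, D) dense reference weights, updated in place
        grad: torch.Tensor,    # (E, D) dense gradient for this iteration
        state: torch.Tensor,   # (E,)  accumulated mean-squared gradient per row
        lr: float,
        eps: float,
    ) -> None:
        # One optimizer value per row: mean of squared gradients over the embedding dim.
        state.add_(grad.pow(2).mean(dim=1))
        # w_row <- w_row - lr * g_row / (sqrt(state_row) + eps)
        weight.copy_(
            torch.addcdiv(
                weight,
                value=-lr,
                tensor1=grad,
                tensor2=state.sqrt().add(eps).view(-1, 1),
            )
        )

    E, D, lr, eps = 10, 4, 0.5, 0.2
    weight, grad, state = torch.randn(E, D), torch.randn(E, D), torch.zeros(E)
    rowwise_adagrad_reference(weight, grad, state, lr, eps)

Because the comparison runs on every iteration and the module is flushed before the weight check, cache hits, conflict misses, and scratch-pad writebacks are all validated against the same reference.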