From 4fe203b14fe5931822a03809ab2520a60f233fb7 Mon Sep 17 00:00:00 2001 From: Sarunya Pumma Date: Mon, 20 May 2024 03:32:42 -0700 Subject: [PATCH] Add cache conflict miss support (backend) (#2596) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/2596 Differential Revision: D55998215 --- fbgemm_gpu/FbgemmGpu.cmake | 50 ++- .../genscript/generate_backward_split.py | 150 ++++++--- .../genscript/generate_forward_split.py | 79 +++-- .../genscript/generate_index_select.py | 4 +- .../codegen/genscript/jinja_environment.py | 17 +- fbgemm_gpu/codegen/genscript/optimizers.py | 17 + ...embedding_backward_split_host_template.cpp | 302 ++++++++++++------ ..._backward_split_indice_weights_template.cu | 70 ++-- ...ding_backward_split_kernel_cta_template.cu | 31 +- ...ing_backward_split_kernel_warp_template.cu | 29 +- ...embedding_backward_split_meta_template.cpp | 10 +- .../embedding_backward_split_template.cu | 109 ++++--- ...rward_split_kernel_nobag_small_template.cu | 54 +++- ...embedding_forward_split_kernel_template.cu | 73 +++-- ...edding_forward_split_kernel_v2_template.cu | 2 +- .../embedding_forward_split_meta_template.cpp | 16 +- .../embedding_forward_split_template.cu | 76 +++-- ...optimizer_split_device_kernel_template.cuh | 20 +- .../codegen/training/python/__init__.template | 48 ++- .../{lookup_args.py => lookup_args.template} | 5 +- ..._embedding_codegen_lookup_invoker.template | 25 +- .../ssd_split_table_batched_embeddings_ops.py | 202 +++++++++--- .../include/fbgemm_gpu/fbgemm_cuda_utils.cuh | 13 +- ...ssd_split_table_batched_embeddings_test.py | 248 ++++++++++++-- 24 files changed, 1173 insertions(+), 477 deletions(-) rename fbgemm_gpu/codegen/training/python/{lookup_args.py => lookup_args.template} (94%) diff --git a/fbgemm_gpu/FbgemmGpu.cmake b/fbgemm_gpu/FbgemmGpu.cmake index a6824b57c1..e02f71d9cb 100644 --- a/fbgemm_gpu/FbgemmGpu.cmake +++ b/fbgemm_gpu/FbgemmGpu.cmake @@ -97,6 +97,11 @@ set(GWD_OPTIMIZERS set(DEFUSED_OPTIMIZERS rowwise_adagrad) +# Optimizers with the SSD support +set(SSD_OPTIMIZERS + rowwise_adagrad + sgd) + set(WEIGHT_OPTIONS weighted unweighted_nobag @@ -143,6 +148,7 @@ set(gen_gpu_kernel_source_files "gen_embedding_forward_split_unweighted_codegen_cuda.cu" "gen_embedding_backward_dense_indice_weights_codegen_cuda.cu" "gen_embedding_backward_split_indice_weights_codegen_cuda.cu" + "gen_embedding_backward_ssd_indice_weights_codegen_cuda.cu" "gen_embedding_forward_split_weighted_vbe_codegen_cuda.cu" "gen_embedding_forward_split_unweighted_vbe_codegen_cuda.cu" "gen_batch_index_select_dim0_forward_codegen_cuda.cu" @@ -153,10 +159,13 @@ set(gen_gpu_kernel_source_files "gen_batch_index_select_dim0_backward_kernel_warp.cu" "gen_embedding_backward_split_grad_embedding_ops.cu" "gen_embedding_backward_split_grad_index_select.cu" - "gen_embedding_backward_common_split_device_kernel.cuh" - "gen_embedding_backward_batch_index_select_split_device_kernel.cuh" + "gen_embedding_backward_split_common_device_kernel.cuh" + "gen_embedding_backward_split_batch_index_select_device_kernel.cuh" "gen_embedding_forward_split_weighted_gwd_codegen_cuda.cu" "gen_embedding_forward_split_unweighted_gwd_codegen_cuda.cu" + "gen_embedding_forward_ssd_weighted_codegen_cuda.cu" + "gen_embedding_forward_ssd_unweighted_codegen_cuda.cu" + "gen_embedding_forward_ssd_unweighted_nobag_kernel_small.cu" ) if(NOT USE_ROCM) @@ -179,7 +188,8 @@ foreach(wdesc ${WEIGHT_OPTIONS}) "gen_embedding_backward_dense_split_${wdesc}_kernel_cta.cu" 
"gen_embedding_backward_dense_split_${wdesc}_kernel_warp.cu" "gen_embedding_forward_split_${wdesc}_kernel.cu" - "gen_embedding_backward_${wdesc}_split_device_kernel.cuh") + "gen_embedding_forward_ssd_${wdesc}_kernel.cu" + "gen_embedding_backward_split_${wdesc}_device_kernel.cuh") foreach(etype fp32 fp16 fp8 int8 int4 int2) list(APPEND gen_gpu_kernel_source_files @@ -191,7 +201,7 @@ endforeach() foreach(wdesc weighted unweighted) list(APPEND gen_gpu_kernel_source_files "gen_embedding_forward_split_${wdesc}_vbe_kernel.cu" - "gen_embedding_backward_${wdesc}_vbe_split_device_kernel.cuh") + "gen_embedding_backward_split_${wdesc}_vbe_device_kernel.cuh") endforeach() # Generate GWD files @@ -207,22 +217,31 @@ set(gen_cpu_source_files set(gen_python_source_files ${CMAKE_BINARY_DIR}/__init__.py - ${CMAKE_BINARY_DIR}/lookup_args.py) + ${CMAKE_BINARY_DIR}/lookup_args.py + ${CMAKE_BINARY_DIR}/lookup_args_ssd.py +) # For each of the optimizers, generate the backward split variant by adding # the Python, CPU-only, GPU host, and GPU kernel source files -# Generate the Python functions only if there is the backend support +# Generate the Python functions only if there is the backend support (for all +# optimizers) foreach(optimizer ${COMMON_OPTIMIZERS} ${CPU_ONLY_OPTIMIZERS} ${GPU_ONLY_OPTIMIZERS}) list(APPEND gen_python_source_files - "${CMAKE_BINARY_DIR}/lookup_${optimizer}.py") - list(APPEND gen_python_source_files + "${CMAKE_BINARY_DIR}/lookup_${optimizer}.py" "${CMAKE_BINARY_DIR}/lookup_${optimizer}_pt2.py") endforeach() +# Generate the Python functions only if there is the backend support (for SSD +# optimizers) +foreach(optimizer ${SSD_OPTIMIZERS}) + list(APPEND gen_python_source_files + "${CMAKE_BINARY_DIR}/lookup_${optimizer}_ssd.py") +endforeach() + # Generate the backend API for all optimizers to preserve the backward # compatibility list(APPEND gen_cpu_source_files @@ -285,6 +304,21 @@ foreach(optimizer ${DEFUSED_OPTIMIZERS}) "${CMAKE_BINARY_DIR}/split_embedding_optimizer_${optimizer}.py") endforeach() +foreach(optimizer ${SSD_OPTIMIZERS}) + list(APPEND gen_gpu_kernel_source_files + "gen_embedding_optimizer_${optimizer}_ssd_device_kernel.cuh" + "gen_embedding_backward_ssd_${optimizer}.cpp" + ) + + foreach(wdesc weighted unweighted unweighted_nobag) + list(APPEND gen_gpu_kernel_source_files + "gen_embedding_backward_${optimizer}_ssd_${wdesc}_cuda.cu" + "gen_embedding_backward_${optimizer}_ssd_${wdesc}_kernel_cta.cu" + "gen_embedding_backward_${optimizer}_ssd_${wdesc}_kernel_warp.cu") + endforeach() + +endforeach() + list(APPEND gen_defused_optim_py_files ${CMAKE_BINARY_DIR}/optimizer_args.py) diff --git a/fbgemm_gpu/codegen/genscript/generate_backward_split.py b/fbgemm_gpu/codegen/genscript/generate_backward_split.py index 9ccc23f576..80ef79d81e 100644 --- a/fbgemm_gpu/codegen/genscript/generate_backward_split.py +++ b/fbgemm_gpu/codegen/genscript/generate_backward_split.py @@ -8,6 +8,7 @@ # pyre-strict # flake8: noqa F401 +import itertools import sys try: @@ -39,28 +40,44 @@ def render_backward_templates( ) -> None: if not kwargs.get("has_gpu_support"): return + + weighted_options = [True, False] + nobag_options = [True, False] if (not is_gwd) else [False] vbe_options = ( [True, False] if (kwargs.get("has_vbe_support") and not is_gwd) else [False] ) + ssd_options = [True, False] if kwargs.get("has_ssd_support") else [False] template = CodeTemplate.load(template_filepath) - for weighted in [True, False]: - for nobag in [True, False] if (not is_gwd) else [False]: - for vbe in vbe_options: - if 
(not nobag or (not weighted and not vbe)) and ( - not kwargs.get("dense") or not vbe - ): - wdesc = f"{ 'weighted' if weighted else 'unweighted' }{ '_nobag' if nobag else '' }{ '_vbe' if vbe else '' }" - template.write( - filename_format.format(optimizer, wdesc), - weighted=weighted, - nobag=nobag, - vbe=vbe, - is_index_select=False, - kdesc=wdesc, - **kwargs, - is_gwd=is_gwd, - ) + for weighted, nobag, vbe, ssd in itertools.product( + weighted_options, nobag_options, vbe_options, ssd_options + ): + if nobag and (weighted or vbe): + continue + if kwargs.get("dense") and (vbe or ssd): + continue + if ssd and (vbe or is_gwd): + continue + + kdesc = "".join( + [ + f"{ 'weighted' if weighted else 'unweighted' }", + f"{ '_nobag' if nobag else '' }", + f"{ '_vbe' if vbe else '' }", + ] + ) + desc = "_".join([f"{ 'ssd' if ssd else 'split' }", kdesc]) + template.write( + filename_format.format(optimizer, desc), + weighted=weighted, + nobag=nobag, + vbe=vbe, + is_index_select=False, + kdesc=kdesc, + is_gwd=is_gwd, + ssd=ssd, + **kwargs, + ) @staticmethod def generate_backward_split_gpu(**kwargs: Any) -> None: @@ -73,19 +90,19 @@ def generate_backward_split_gpu(**kwargs: Any) -> None: for template_filepath, filename_format in [ ( "training/backward/embedding_backward_split_template.cu", - "gen_embedding_backward_{}_split_{}_cuda.cu", + "gen_embedding_backward_{}_{}_cuda.cu", ), ( "training/backward/embedding_backward_split_meta_template.cpp", - "gen_embedding_backward_{}_split_{}_meta.cpp", + "gen_embedding_backward_{}_{}_meta.cpp", ), ( "training/backward/embedding_backward_split_kernel_cta_template.cu", - "gen_embedding_backward_{}_split_{}_kernel_cta.cu", + "gen_embedding_backward_{}_{}_kernel_cta.cu", ), ( "training/backward/embedding_backward_split_kernel_warp_template.cu", - "gen_embedding_backward_{}_split_{}_kernel_warp.cu", + "gen_embedding_backward_{}_{}_kernel_warp.cu", ), ]: BackwardSplitGenerator.render_backward_templates( @@ -94,20 +111,21 @@ def generate_backward_split_gpu(**kwargs: Any) -> None: filename_format, kwargs, ) + # Generate the global weight decay CUDA kernels if kwargs.get("has_global_weight_decay_support"): for template_filepath, filename_format in [ ( "training/backward/embedding_backward_split_kernel_cta_template.cu", - "gen_embedding_backward_{}_split_{}_gwd_kernel_cta.cu", + "gen_embedding_backward_{}_{}_gwd_kernel_cta.cu", ), ( "training/backward/embedding_backward_split_kernel_warp_template.cu", - "gen_embedding_backward_{}_split_{}_gwd_kernel_warp.cu", + "gen_embedding_backward_{}_{}_gwd_kernel_warp.cu", ), ( "training/backward/embedding_backward_split_template.cu", - "gen_embedding_backward_{}_split_{}_gwd_cuda.cu", + "gen_embedding_backward_{}_{}_gwd_cuda.cu", ), ]: BackwardSplitGenerator.render_backward_templates( @@ -118,23 +136,38 @@ def generate_backward_split_gpu(**kwargs: Any) -> None: is_gwd=True, ) - # Generate optimizer kernel - CodeTemplate.load( - "training/optimizer/embedding_optimizer_split_device_kernel_template.cuh" - ).write( - f"gen_embedding_optimizer_{optimizer}_split_device_kernel.cuh", **kwargs - ) + for ssd in ( + [True, False] + if kwargs.get("has_ssd_support") and not kwargs.get("dense") + else [False] + ): + desc = f"{ 'ssd' if ssd else 'split' }" + # Generate optimizer kernel + CodeTemplate.load( + "training/optimizer/embedding_optimizer_split_device_kernel_template.cuh" + ).write( + f"gen_embedding_optimizer_{optimizer}_{desc}_device_kernel.cuh", + ssd=ssd, + **kwargs, + ) # Generate the backward splits (non-dense) # We generate only 
the API to preserve the backward compatibility if # has_gpu_support=True if not kwargs.get("dense"): - # Generate CUDA autograd, PT2 unified autograd, and PT2 backward wrapper + # Generate CUDA autograd + template_filepath = ( + "training/backward/embedding_backward_split_host_template.cpp" + ) + for ssd in [True, False]: + desc = "ssd" if ssd else "split" + filename = f"gen_embedding_backward_{desc}_{optimizer}.cpp" + CodeTemplate.load(template_filepath).write( + filename, is_forward=False, ssd=ssd, **kwargs + ) + + # Generate PT2 unified autograd, and PT2 backward wrapper for template_filepath, filename in [ - ( - "training/backward/embedding_backward_split_host_template.cpp", - f"gen_embedding_backward_split_{optimizer}.cpp", - ), ( "training/pt2/embedding_split_host_pt2_autograd_template.cpp", f"gen_embedding_split_{optimizer}_pt2_autograd.cpp", @@ -153,11 +186,15 @@ def generate_backward_split_gpu(**kwargs: Any) -> None: template = CodeTemplate.load( "training/python/split_embedding_codegen_lookup_invoker.template" ) - for filename in [ - f"lookup_{optimizer}.py", - f"lookup_{optimizer}_pt2.py", - ]: - template.write(filename, is_fbcode=args.is_fbcode, **kwargs) + for ssd in [True, False]: + sdesc = "_ssd" if ssd else "" + for filename in [ + f"lookup_{optimizer}{sdesc}.py", + f"lookup_{optimizer}{sdesc}_pt2.py", + ]: + template.write( + filename, is_fbcode=args.is_fbcode, ssd=ssd, **kwargs + ) @staticmethod def generate_backward_split_cpu(**kwargs: Any) -> None: @@ -213,10 +250,11 @@ def generate_backward_device() -> None: BackwardSplitGenerator.render_backward_templates( template_filepath, "", - "{}gen_embedding_backward_{}_split_device_kernel.cuh", + "{}gen_embedding_backward_{}_device_kernel.cuh", { "has_gpu_support": True, "has_vbe_support": True, + "has_ssd_support": True, "dense": False, "gen_once": False, }, @@ -224,7 +262,7 @@ def generate_backward_device() -> None: # Generate common backward device kernels (generate only once) CodeTemplate.load(template_filepath).write( - "gen_embedding_backward_common_split_device_kernel.cuh", + "gen_embedding_backward_split_common_device_kernel.cuh", gen_once=True, ) @@ -242,16 +280,27 @@ def generate_backward_indices() -> None: template = CodeTemplate.load( "training/backward/embedding_backward_split_indice_weights_template.cu" ) - for dense in [True, False]: + dense_options = [True, False] + ssd_options = [True, False] + for dense, ssd in itertools.product(dense_options, ssd_options): + if dense and ssd: + continue + desc = "dense" if dense else ("ssd" if ssd else "split") template.write( - f"gen_embedding_backward_{'dense' if dense else 'split'}_indice_weights_codegen_cuda.cu", + f"gen_embedding_backward_{ desc }_indice_weights_codegen_cuda.cu", dense=dense, + ssd=ssd, ) @staticmethod def generate_python_sources() -> None: CodeTemplate.load("training/python/__init__.template").write("__init__.py") - CodeTemplate.copy_to_root("training/python/lookup_args.py") + + template = CodeTemplate.load("training/python/lookup_args.template") + for ssd in [True, False]: + sdesc = "_ssd" if ssd else "" + filename = f"lookup_args{sdesc}.py" + template.write(filename, ssd=ssd) @staticmethod def generate() -> None: @@ -276,8 +325,17 @@ def generate() -> None: none_optimizer(), ] + ssd_tensors = [ + "row_addrs", + "inserted_rows", + "post_bwd_evicted_indices", + "actions_count", + ] + for optimizer in optimizers: - BackwardSplitGenerator.generate_backward_split(**optimizer) + BackwardSplitGenerator.generate_backward_split( + ssd_tensors=ssd_tensors, 
**optimizer + ) # Generate common device kernels for backwards BackwardSplitGenerator.generate_backward_device() diff --git a/fbgemm_gpu/codegen/genscript/generate_forward_split.py b/fbgemm_gpu/codegen/genscript/generate_forward_split.py index d907a11ff1..bb6488b54a 100644 --- a/fbgemm_gpu/codegen/genscript/generate_forward_split.py +++ b/fbgemm_gpu/codegen/genscript/generate_forward_split.py @@ -9,6 +9,7 @@ # flake8: noqa F401 import argparse +import itertools import sys from typing import List @@ -27,36 +28,42 @@ def render_forward_templates( dense_options: List[bool], nobag_options: List[bool], vbe_options: List[bool], + ssd_options: List[bool], is_gwd: bool = False, ) -> None: template = CodeTemplate.load(template_filepath) - for dense in dense_options: - for weighted in [True, False]: - for nobag in nobag_options: - for vbe in vbe_options: - if (not nobag or (not weighted and not vbe)) and ( - not dense or not vbe - ): - dense_desc = f"{ 'dense' if dense else 'split'}" - weight_desc = ( - f"{ 'weighted' if weighted else 'unweighted' }" - ) - nobag_desc = f"{ '_nobag' if nobag else '' }" - vbe_desc = f"{ '_vbe' if vbe else '' }" - - template.write( - filename_format.format( - f"{ dense_desc }_{ weight_desc }{ nobag_desc }{ vbe_desc }" - ), - dense=dense, - weighted=weighted, - nobag=nobag, - vbe=vbe, - is_index_select=False, - is_gwd=is_gwd, - ) + weighted_options = [True, False] + + for dense, weighted, nobag, vbe, ssd in itertools.product( + dense_options, weighted_options, nobag_options, vbe_options, ssd_options + ): + if nobag and (weighted or vbe): + continue + if dense and (vbe or ssd): + continue + if ssd and (vbe or is_gwd): + continue + + desc = "".join( + [ + f"{ 'dense' if dense else ('ssd' if ssd else 'split') }", + f"{ '_weighted' if weighted else '_unweighted' }", + f"{ '_nobag' if nobag else '' }", + f"{ '_vbe' if vbe else '' }", + ] + ) + fname = filename_format.format(desc) + template.write( + fname, + dense=dense, + weighted=weighted, + nobag=nobag, + vbe=vbe, + ssd=ssd, + is_index_select=False, + is_gwd=is_gwd, + ) - @staticmethod def generate_pt2_wrappers() -> None: # Generate PT2 forward wrapper (CUDA) CodeTemplate.load( @@ -84,12 +91,14 @@ def generate_small_kernels() -> None: "training/forward/embedding_forward_split_kernel_nobag_small_template.cu" ) for dense in [True, False]: - wdesc = f"{ 'dense' if dense else 'split' }" - template.write( - f"gen_embedding_forward_{wdesc}_unweighted_nobag_kernel_small.cu", - dense=dense, - is_index_select=False, - ) + for ssd in [True, False]: + ddesc = f"{ 'dense' if dense else ('ssd' if ssd else 'split') }" + template.write( + f"gen_embedding_forward_{ ddesc }_unweighted_nobag_kernel_small.cu", + dense=dense, + ssd=ssd, + is_index_select=False, + ) @staticmethod def generate_kernels() -> None: @@ -100,6 +109,7 @@ def generate_kernels() -> None: dense_options=[True, False], nobag_options=[False], # nobag is not used vbe_options=[True, False], + ssd_options=[True, False], ) # Generate the CUDA host code for global weight decay ForwardSplitGenerator.render_forward_templates( @@ -109,6 +119,7 @@ def generate_kernels() -> None: nobag_options=[False], # nobag is not used vbe_options=[False], is_gwd=True, + ssd_options=[False], ) # Generate the meta kernels @@ -118,6 +129,7 @@ def generate_kernels() -> None: dense_options=[True, False], nobag_options=[False], # nobag is not used vbe_options=[True, False], + ssd_options=[True, False], ) # Generate the CUDA kernels @@ -127,6 +139,7 @@ def generate_kernels() -> None: 
dense_options=[True, False], nobag_options=[True, False], vbe_options=[True, False], + ssd_options=[True, False], ) # Generate the global weight decay CUDA kernels ForwardSplitGenerator.render_forward_templates( @@ -135,6 +148,7 @@ def generate_kernels() -> None: dense_options=[False], nobag_options=[False], vbe_options=[False], + ssd_options=[False], is_gwd=True, ) @@ -145,6 +159,7 @@ def generate_kernels() -> None: dense_options=[False], # dense is not supported nobag_options=[False], # nobag is not supported vbe_options=[False], # vbe is not supported + ssd_options=[False], # ssd is not supported ) @staticmethod diff --git a/fbgemm_gpu/codegen/genscript/generate_index_select.py b/fbgemm_gpu/codegen/genscript/generate_index_select.py index c28ab9b60a..22dc672806 100644 --- a/fbgemm_gpu/codegen/genscript/generate_index_select.py +++ b/fbgemm_gpu/codegen/genscript/generate_index_select.py @@ -58,7 +58,7 @@ def generate() -> None: ), ( "training/backward/embedding_backward_split_device_kernel_template.cuh", - "gen_embedding_backward_batch_index_select_split_device_kernel.cuh", + "gen_embedding_backward_split_batch_index_select_device_kernel.cuh", ), ]: CodeTemplate.load(template_file).write( @@ -84,7 +84,7 @@ def generate() -> None: CodeTemplate.load( "training/backward/embedding_backward_split_device_kernel_template.cuh" ).write( - "gen_embedding_backward_common_split_device_kernel.cuh", + "gen_embedding_backward_split_common_device_kernel.cuh", gen_once=True, ) diff --git a/fbgemm_gpu/codegen/genscript/jinja_environment.py b/fbgemm_gpu/codegen/genscript/jinja_environment.py index 5e6d96a4a7..1ac6cd705d 100644 --- a/fbgemm_gpu/codegen/genscript/jinja_environment.py +++ b/fbgemm_gpu/codegen/genscript/jinja_environment.py @@ -289,13 +289,20 @@ def is_valid_forward_config( def has_experimental_support( - dense: bool, nobag: bool, vbe: bool, is_index_select: bool, is_rocm: bool + dense: bool, nobag: bool, vbe: bool, is_index_select: bool, is_rocm: bool, ssd: bool ) -> bool: """ Check if the given combination of configs has TBE v2 support - - TBE v2 does not support dense, nobag, vbe, is_index_select, and is_rocm + - TBE v2 does not support dense, nobag, vbe, is_index_select, is_rocm, and ssd """ - return not dense and not nobag and not vbe and not is_index_select and not is_rocm + return ( + not dense + and not nobag + and not vbe + and not is_index_select + and not is_rocm + and not ssd + ) def is_valid_gwd_config( @@ -303,7 +310,8 @@ def is_valid_gwd_config( nobag: bool, vbe: bool, is_index_select: bool, - has_global_weight_decay_support: bool = True, + has_global_weight_decay_support: bool, + ssd: bool, ) -> bool: """ Check if the given combination of configs is valid for global weight decay support @@ -318,6 +326,7 @@ def is_valid_gwd_config( and not vbe and not is_index_select and has_global_weight_decay_support + and not ssd ) diff --git a/fbgemm_gpu/codegen/genscript/optimizers.py b/fbgemm_gpu/codegen/genscript/optimizers.py index 179b80f745..325a771502 100644 --- a/fbgemm_gpu/codegen/genscript/optimizers.py +++ b/fbgemm_gpu/codegen/genscript/optimizers.py @@ -42,6 +42,7 @@ def dense() -> Dict[str, Any]: "has_gpu_support": True, "has_vbe_support": False, "has_global_weight_decay_support": False, + "has_ssd_support": False, } @@ -86,6 +87,7 @@ def adagrad() -> Dict[str, Any]: "has_gpu_support": True, "has_vbe_support": False, "has_global_weight_decay_support": False, + "has_ssd_support": False, } @@ -255,6 +257,7 @@ def rowwise_adagrad() -> Dict[str, Any]: "has_gpu_support": True, 
"has_vbe_support": True, "has_global_weight_decay_support": True, + "has_ssd_support": True, } @@ -286,6 +289,7 @@ def approx_rowwise_adagrad() -> Dict[str, Any]: "has_gpu_support": False, "has_vbe_support": False, "has_global_weight_decay_support": False, + "has_ssd_support": False, } @@ -392,6 +396,7 @@ def rowwise_adagrad_with_weight_decay() -> Dict[str, Any]: "has_gpu_support": False, "has_vbe_support": False, "has_global_weight_decay_support": False, + "has_ssd_support": False, } @@ -428,6 +433,7 @@ def approx_rowwise_adagrad_with_weight_decay() -> Dict[str, Any]: "has_gpu_support": False, "has_vbe_support": False, "has_global_weight_decay_support": False, + "has_ssd_support": False, } @@ -599,6 +605,7 @@ def rowwise_adagrad_with_counter() -> Dict[str, Any]: "has_gpu_support": True, "has_vbe_support": True, "has_global_weight_decay_support": False, + "has_ssd_support": False, } @@ -648,6 +655,7 @@ def approx_rowwise_adagrad_with_counter() -> Dict[str, Any]: "has_gpu_support": False, "has_vbe_support": False, "has_global_weight_decay_support": False, + "has_ssd_support": False, } @@ -726,6 +734,7 @@ def rowwise_weighted_adagrad() -> Dict[str, Any]: "has_gpu_support": False, "has_vbe_support": False, "has_global_weight_decay_support": False, + "has_ssd_support": False, } @@ -750,6 +759,7 @@ def sgd() -> Dict[str, Any]: "has_gpu_support": True, "has_vbe_support": True, "has_global_weight_decay_support": False, + "has_ssd_support": True, } @@ -774,6 +784,7 @@ def approx_sgd() -> Dict[str, Any]: "has_gpu_support": False, "has_vbe_support": False, "has_global_weight_decay_support": False, + "has_ssd_support": False, } @@ -852,6 +863,7 @@ def lamb() -> Dict[str, Any]: "has_gpu_support": True, "has_vbe_support": False, "has_global_weight_decay_support": False, + "has_ssd_support": False, } @@ -944,6 +956,7 @@ def partial_rowwise_lamb() -> Dict[str, Any]: "has_gpu_support": True, "has_vbe_support": False, "has_global_weight_decay_support": False, + "has_ssd_support": False, } @@ -1000,6 +1013,7 @@ def adam() -> Dict[str, Any]: "has_gpu_support": True, "has_vbe_support": False, "has_global_weight_decay_support": False, + "has_ssd_support": False, } @@ -1075,6 +1089,7 @@ def partial_rowwise_adam() -> Dict[str, Any]: "has_gpu_support": True, "has_vbe_support": False, "has_global_weight_decay_support": False, + "has_ssd_support": False, } @@ -1140,6 +1155,7 @@ def lars_sgd() -> Dict[str, Any]: "has_gpu_support": True, "has_vbe_support": False, "has_global_weight_decay_support": False, + "has_ssd_support": False, } @@ -1158,4 +1174,5 @@ def none_optimizer() -> Dict[str, Any]: "has_gpu_support": True, "has_vbe_support": False, "has_global_weight_decay_support": False, + "has_ssd_support": False, } diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp index 8fcc77cce3..137a093027 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp @@ -20,6 +20,17 @@ using Tensor = at::Tensor; using namespace fbgemm_gpu; +{#/* Module description */#} +{%- set mdesc = "ssd" if ssd else "split" %} + +{%- if ssd %} +enum SSDTensor { + {%- for tensor in ssd_tensors %} + {{ tensor | upper }} = {{ loop.index - 1 }}, + {%- endfor %} +}; +{%- endif %} + //////////////////////////////////////////////////////////////////////////////// // Macro Helper Functions 
//////////////////////////////////////////////////////////////////////////////// @@ -31,22 +42,21 @@ using namespace fbgemm_gpu; */ #} {%- macro call_forward_op_dispatch(nobag, weighted, vbe, is_gwd) %} - {%- set vdesc = "_vbe" if vbe else "" %} - {%- set gwddesc = "_gwd" if is_gwd else "" %} - {%- set wdesc = "weighted" if weighted else "unweighted" %} - {%- set nobag = "_nobag" if nobag else "" %} - - {%- set forward_op = "split_embedding{}_codegen_forward_{}{}{}_cuda".format( - nobag, wdesc, vdesc, gwddesc + {%- set forward_op = "{}_embedding{}_codegen_forward_{}{}{}_cuda".format( + mdesc, + "_nobag" if nobag else "", + "weighted" if weighted else "unweighted", + "_vbe" if vbe else "", + "_gwd" if is_gwd else "", ) %} - static auto split_embedding_codegen_forward_op = + static auto embedding_codegen_forward_op = torch::Dispatcher::singleton() .findSchemaOrThrow("fbgemm::{{ forward_op }}", "") .typed(); return { - split_embedding_codegen_forward_op.call( + embedding_codegen_forward_op.call( flatten_dev_weights, uvm_weights, lxu_cache_weights, @@ -67,7 +77,11 @@ using namespace fbgemm_gpu; *indice_weights, {%- endif %} {%- endif %} {# /* if not nobag */ #} + {%- if ssd %} + ssd_tensors[SSDTensor::ROW_ADDRS], + {%- else %} lxu_cache_locations, + {%- endif %} uvm_cache_stats_, output_dtype, {%- if not nobag %} @@ -99,20 +113,22 @@ using namespace fbgemm_gpu; unweighted backward op via Pytorch dispatcher */ {%- macro call_backward_op_dispatch(nobag, weighted, vbe, is_gwd) %} - {%- set nobag = "_nobag" if nobag else "" %} - {%- set vdesc = "_vbe" if vbe else "" %} - {%- set gwddesc = "_gwd" if is_gwd else "" %} - {%- set wdesc = "weighted" if weighted else "unweighted" %} - {%- set backward_op = "split_embedding{}_backward_codegen_{}_{}_exact{}{}_cuda".format( - nobag, optimizer, wdesc, vdesc, gwddesc + {%- set wdesc = "_weighted" if weighted else "_unweighted" %} + {%- set backward_op = "{}_embedding{}_backward_codegen_{}{}{}{}_exact_cuda".format( + mdesc, + "_nobag" if nobag else "", + optimizer, + wdesc, + "_vbe" if vbe else "", + "_gwd" if is_gwd else "", ) %} - static auto split_embedding_codegen_{{ wdesc }}_backward_op = + static auto embedding_codegen_{{ wdesc }}_backward_op = torch::Dispatcher::singleton() .findSchemaOrThrow("fbgemm::{{ backward_op }}", "") .typed(); - grad_dev_weights = split_embedding_codegen_{{ wdesc }}_backward_op.call( + grad_dev_weights = embedding_codegen_{{ wdesc }}_backward_op.call( grad_output, dev_weights, uvm_weights, @@ -135,7 +151,11 @@ using namespace fbgemm_gpu; indice_weights, {%- endif %} {%- endif %} {# /* if not nobag */ #} + {%- if ssd %} + ssd_row_addrs, + {%- else %} lxu_cache_locations, + {%- endif %} BT_block_size, max_segment_length_per_warp, {%- if optimizer != "none" %} @@ -209,6 +229,11 @@ using namespace fbgemm_gpu; Variable(), // iter {%- endif %} {%- endif %} + {%- if ssd %} + {%- for tensor in ssd_tensors %} + Variable(), // {{ tensor }} + {%- endfor %} + {%- endif %} {{ args.split_variables | join(", ") }} }; {%- endmacro %} @@ -217,7 +242,8 @@ using namespace fbgemm_gpu; from lookup_function */ {%- macro call_autograd(nobag, vbe, is_gwd) %} - {%- set autograd_func = "Split{}{}{}LookupFunction_{}_Op".format( + {%- set autograd_func = "{}{}{}{}LookupFunction_{}_Op".format( + "SSD" if ssd else "Split", "NoBag" if nobag else "", "VBE" if vbe else "", "GWD" if is_gwd else "", @@ -274,6 +300,9 @@ using namespace fbgemm_gpu; iter, {%- endif %} {%- endif %} + {%- if ssd %} + ssd_tensors.value(), + {%- endif %} {{ 
args.split_function_arg_names | join(", ") }})[0]; {%- endmacro %} @@ -287,9 +316,34 @@ using namespace fbgemm_gpu; {%- for vbe in ([True, False] if has_vbe_support else [False]) %} {%- set vdesc = "_vbe" if vbe else "" %} -{%- for weighted in [True, False] %} +Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda( + const Tensor& grad_output, + const Tensor& dev_weights, + const Tensor& uvm_weights, + const Tensor& lxu_cache_weights, + const Tensor& weights_placements, + const Tensor& weights_offsets, + const Tensor& D_offsets, + const c10::SymInt max_D, + const Tensor& indices, + const Tensor& offsets, + {%- if ssd %} + const Tensor& ssd_row_addrs, + {%- else %} + const Tensor& lxu_cache_locations, + {%- endif %} + {%- if vbe %} + const Tensor& feature_requires_grad, + const Tensor& vbe_row_output_offsets, + const Tensor& vbe_b_t_map, + const int64_t info_B_num_bits, + const int64_t info_B_mask_int64 + {%- else %} + const Tensor& feature_requires_grad + {%- endif %} +); + {%- for nobag in ([False] if (weighted or vbe) else [True, False]) %} -{%- set wdesc = "weighted" if weighted else "unweighted" %} {%- set ndesc = "_nobag" if nobag else "" %} {%- for is_gwd in ([True, False] @@ -298,11 +352,17 @@ using namespace fbgemm_gpu; nobag, vbe, is_index_select, - has_global_weight_decay_support - ) + has_global_weight_decay_support, + ssd) else [False]) %} {%- set gwddesc = "_gwd" if is_gwd else "" %} -Tensor split_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ vdesc }}{{ gwddesc }}_cuda( + +{%- for weighted in [True, False] %} +{%- set wdesc = "_weighted" if weighted else "_unweighted" %} + +{%- set desc_suffix = wdesc + vdesc + gwddesc %} + +Tensor {{ mdesc }}_embedding{{ ndesc }}_codegen_forward{{ desc_suffix }}_cuda( const Tensor& dev_weights, const Tensor& uvm_weights, const Tensor& lxu_cache_weights, @@ -323,7 +383,11 @@ Tensor split_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ vdesc }}{{ gwdde {%- if weighted %} const Tensor& indice_weights, {%- endif %} + {%- if ssd %} + const Tensor& ssd_row_addrs, + {%- else %} const Tensor& lxu_cache_locations, + {%- endif %} const Tensor& uvm_cache_stats, const int64_t output_dtype, {%- if vbe %} @@ -343,9 +407,10 @@ Tensor split_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ vdesc }}{{ gwdde {%- else %} const bool is_experimental {%- endif %} - ); +); -Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_exact{{ vdesc }}{{ gwddesc }}_cuda( +Tensor +{{ mdesc }}_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}{{ desc_suffix }}_exact_cuda( const Tensor& grad_output, const Tensor& dev_weights, const Tensor& uvm_weights, @@ -368,7 +433,11 @@ Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_e {%- if weighted %} const Tensor& indice_weights, {%- endif %} + {%- if ssd %} + const Tensor& ssd_row_addrs, + {%- else %} const Tensor& lxu_cache_locations, + {%- endif %} const int64_t BT_block_size, const int64_t max_segment_length_per_warp, {%- if optimizer != "none" %} @@ -392,51 +461,15 @@ Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_e {%- endif %} {%- endif %} {{ args.split_function_args | join(", ") }}); - -{%- endfor %} {#-/*for is_gwd*/#} -{%- endfor %} {#-/*for nobag*/#} -{%- endfor %} {#-/*for weighted*/#} - -Tensor split_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda( - const Tensor& grad_output, - const Tensor& dev_weights, - const Tensor& uvm_weights, - const Tensor& lxu_cache_weights, - const Tensor& 
weights_placements, - const Tensor& weights_offsets, - const Tensor& D_offsets, - const c10::SymInt max_D, - const Tensor& indices, - const Tensor& offsets, - const Tensor& lxu_cache_locations, - {%- if vbe %} - const Tensor& feature_requires_grad, - const Tensor& vbe_row_output_offsets, - const Tensor& vbe_b_t_map, - const int64_t info_B_num_bits, - const int64_t info_B_mask_int64 - {%- else %} - const Tensor& feature_requires_grad - {%- endif %} -); +{%- endfor %} {#-/* for weighted*/#} //////////////////////////////////////////////////////////////////////////////// // Autograd Function Declarations //////////////////////////////////////////////////////////////////////////////// -{%- for nobag in [True, False] %} -{%- if not nobag or not vbe %} {#-/* nobag does not support vbe */#} {#- /* Generate a separate autograd function for global weight decay */ #} -{%- for is_gwd in ([True, False] - if is_valid_gwd_config( - dense, - nobag, - vbe, - is_index_select, - has_global_weight_decay_support - ) - else [False]) %} -{%- set autograd_func = "Split{}{}{}LookupFunction_{}_Op".format( +{%- set autograd_func = "{}{}{}{}LookupFunction_{}_Op".format( + "SSD" if ssd else "Split", "NoBag" if nobag else "", "VBE" if vbe else "", "GWD" if is_gwd else "", @@ -444,7 +477,6 @@ Tensor split_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda( ) %} - class {{ autograd_func }} : public torch::autograd::Function<{{ autograd_func }}> { public: @@ -499,6 +531,9 @@ class {{ autograd_func }} : const int64_t iter, {%- endif %} {%- endif %} + {%- if ssd %} + const at::TensorList& ssd_tensors, + {%- endif %} {{ args.split_function_args | join(", ") }}) { const auto T = weights_offsets.sym_numel(); @@ -590,6 +625,11 @@ class {{ autograd_func }} : {%- if is_gwd and "prev_iter_dev" not in args.split_function_arg_names %} prev_iter_dev_, {%- endif %} + {%- if ssd %} + {%- for tensor in ssd_tensors %} + ssd_tensors[SSDTensor::{{ tensor | upper }}], + {%- endfor %} + {%- endif %} {{ args.split_saved_tensors | join(", ") }} }); @@ -627,12 +667,33 @@ class {{ autograd_func }} : {%- endif %} {%- if nobag %} - {{ call_forward_op_dispatch(nobag=True, weighted=False, vbe=vbe, is_gwd=is_gwd) }} + {{ + call_forward_op_dispatch( + nobag=True, + weighted=False, + vbe=vbe, + is_gwd=is_gwd, + ) + }} {%- else %} if (indice_weights) { - {{ call_forward_op_dispatch(nobag=False, weighted=True, vbe=vbe, is_gwd=is_gwd) }} + {{ + call_forward_op_dispatch( + nobag=False, + weighted=True, + vbe=vbe, + is_gwd=is_gwd, + ) + }} } - {{ call_forward_op_dispatch(nobag=False, weighted=False, vbe=vbe, is_gwd=is_gwd) }} + {{ + call_forward_op_dispatch( + nobag=False, + weighted=False, + vbe=vbe, + is_gwd=is_gwd, + ) + }} {%- endif %} {#-/* if not nobag */ #} } @@ -666,6 +727,11 @@ class {{ autograd_func }} : {%- if is_gwd and "prev_iter_dev" not in args.split_function_arg_names %} auto prev_iter_dev = *savedItr++; {%- endif %} + {%- if ssd %} + {%- for tensor in ssd_tensors %} + auto ssd_{{ tensor }} = *savedItr++; + {%- endfor %} + {%- endif %} {%- for tensor in args.split_saved_tensors %} auto {{ tensor }} = *savedItr++; @@ -719,21 +785,21 @@ class {{ autograd_func }} : {%- if not nobag %} {%- if optimizer == "none" %} // Flatten (dev_weights is used in - // split_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda) + // {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda) dev_weights = dev_weights.flatten(); {%- endif %} {%- set grad_indice_weights_op = - "split_embedding_codegen_grad_indice_weights{}_cuda".format(vdesc) + 
"{}_embedding_codegen_grad_indice_weights{}_cuda".format(mdesc, vdesc) %} - static auto split_embedding_codegen_grad_indice_weights_op = + static auto {{ mdesc }}_embedding_codegen_grad_indice_weights_op = torch::Dispatcher::singleton() .findSchemaOrThrow("fbgemm::{{ grad_indice_weights_op }}", "") .typed(); const auto grad_indice_weights = !indice_weights.defined() ? Variable() : - split_embedding_codegen_grad_indice_weights_op.call( + {{ mdesc }}_embedding_codegen_grad_indice_weights_op.call( grad_output, dev_weights, uvm_weights, @@ -744,7 +810,11 @@ class {{ autograd_func }} : max_D, indices, offsets, + {%- if ssd %} + ssd_row_addrs, + {%- else %} lxu_cache_locations, + {%- endif %} {%- if vbe %} feature_requires_grad, vbe_row_output_offsets, @@ -758,24 +828,43 @@ class {{ autograd_func }} : Tensor grad_dev_weights; if (indice_weights.defined()) { - {{ call_backward_op_dispatch(nobag=False, weighted=True, vbe=vbe, is_gwd=is_gwd) }} + {{ + call_backward_op_dispatch( + nobag=False, + weighted=True, + vbe=vbe, + is_gwd=is_gwd, + ) + }} } - {{ call_backward_op_dispatch(nobag=False, weighted=False, vbe=vbe, is_gwd=is_gwd) }} - + {{ + call_backward_op_dispatch( + nobag=False, + weighted=False, + vbe=vbe, + is_gwd=is_gwd, + ) + }} {%- else %} Tensor grad_dev_weights; - {{ call_backward_op_dispatch(nobag=True, weighted=False, vbe=vbe, is_gwd=is_gwd) }} + {{ + call_backward_op_dispatch( + nobag=True, + weighted=False, + vbe=vbe, + is_gwd=is_gwd, + ) + }} {%- endif %} } }; {%- endfor %} {#-/* for is_gwd */#} -{%- endif %} {#-/* if not nobag or not vbe */#} -{%- endfor %} {#-/* for nobag in [True, False] */#} -{%- endfor %} {#-/* for vbe in [True, False] */#} +{%- endfor %} {#-/* for nobag */#} +{%- endfor %} {#-/* for vbe */#} {%- endif %} {#-/* if has_gpu_support */#} ///@ingroup embedding-cuda -Tensor split_embedding_codegen_lookup_{{ optimizer }}_function( +Tensor {{ mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function( const Tensor& placeholder_autograd_tensor, const Tensor& dev_weights, const Tensor& uvm_weights, @@ -816,15 +905,17 @@ Tensor split_embedding_codegen_lookup_{{ optimizer }}_function( {%- if "iter" not in args.split_function_arg_names %} const int64_t iter = 0, {%- endif %} + {%- if ssd %} + const bool apply_global_weight_decay = false, + const c10::optional& ssd_tensors = c10::nullopt + {%- else %} const bool apply_global_weight_decay = false + {%- endif %} ) { // TODO: refactor into macro {%- if has_gpu_support %} - if (static_cast(pooling_mode) == PoolingMode::NONE) { - // no bag - {{ call_autograd(nobag=True, vbe=False, is_gwd=False) }} - } + {%- if not ssd %} {%- if has_vbe_support %} // has vbe support if (B_offsets.has_value()) { @@ -840,11 +931,26 @@ Tensor split_embedding_codegen_lookup_{{ optimizer }}_function( {{ call_autograd(nobag=False, vbe=False, is_gwd=True) }} } {%- endif %} {#-/* if has_global_weight_decay_support */ #} + {%- endif %} {#-/* if not ssd */#} - {{ call_autograd(nobag=False, vbe=False, is_gwd=False) }} + {%- if ssd %} + TORCH_CHECK( + ssd_tensors.value().size() == {{ ssd_tensors | length }}, + "SSD TBE expects {{ ssd_tensors | length }} in ssd_tensors"); + {%- endif %} + if (static_cast(pooling_mode) == PoolingMode::NONE) { + // no bag + {{ call_autograd(nobag=True, vbe=False, is_gwd=False) }} + } + else { + {{ call_autograd(nobag=False, vbe=False, is_gwd=False) }} + } {%- else %} - TORCH_CHECK(false, "split_embedding_codegen_lookup_{{ optimizer }}_function is deprecated. 
Please see https://github.com/pytorch/FBGEMM/discussions/1727 for more detail."); + TORCH_CHECK( + false, + "{{ mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function is deprecated. Please see https://github.com/pytorch/FBGEMM/discussions/1727 for more detail." + ); return Tensor(); {%- endif %} {#-/* if has_gpu_support */#} } @@ -852,7 +958,8 @@ Tensor split_embedding_codegen_lookup_{{ optimizer }}_function( // Deprecated for fb namespace! Please use fbgemm namespace instead! {%- for lib_name in ["fb", "fbgemm"] %} TORCH_LIBRARY_FRAGMENT({{ lib_name }}, m) { - m.def("split_embedding_codegen_lookup_{{ optimizer }}_function(" + {%- set op_name = "{}_embedding_codegen_lookup_{}_function".format(mdesc, optimizer) %} + m.def("{{ op_name }}(" " Tensor placeholder_autograd_tensor, " " Tensor dev_weights, " " Tensor uvm_weights, " @@ -893,7 +1000,12 @@ TORCH_LIBRARY_FRAGMENT({{ lib_name }}, m) { {%- if "iter" not in args.split_function_arg_names %} " int iter=0, " {%- endif %} - " bool apply_global_weight_decay=False " + {%- if ssd %} + " bool apply_global_weight_decay=False, " + " Tensor[]? ssd_tensors=None" + {%- else %} + " bool apply_global_weight_decay=False" + {%- endif %} ") -> Tensor", {PT2_COMPLIANT_TAG}); // We're playing a funny trick here: we're using the autograd @@ -902,18 +1014,18 @@ TORCH_LIBRARY_FRAGMENT({{ lib_name }}, m) { // no autograd enabled, and all of the internal implementations redispatch // appropriately m.impl( - "split_embedding_codegen_lookup_{{ optimizer }}_function", + "{{ op_name }}", torch::dispatch( c10::DispatchKey::Autograd, - TORCH_FN(split_embedding_codegen_lookup_{{ optimizer }}_function))); + TORCH_FN({{ op_name }}))); m.impl( - "split_embedding_codegen_lookup_{{ optimizer }}_function", + "{{ op_name }}", torch::dispatch( c10::DispatchKey::Meta, - TORCH_FN(split_embedding_codegen_lookup_{{ optimizer }}_function))); + TORCH_FN({{ op_name }}))); DISPATCH_TO_CUDA( - "split_embedding_codegen_lookup_{{ optimizer }}_function", - split_embedding_codegen_lookup_{{ optimizer }}_function); + "{{ op_name }}", + {{ op_name }}); } {%- endfor %} {#-/* for lib_name */#} - // clang-format on + // clang-format on diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu index 501137b61d..8a1c1b052b 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu @@ -8,7 +8,11 @@ // clang-format off -{%- set ddesc = "dense" if dense else "split" %} +{%- set mdesc = "dense" if dense else ("ssd" if ssd else "split") %} + +{%- set locs_or_addrs_tensor = "ssd_row_addrs" if ssd else "lxu_cache_locations" %} +{%- set locs_or_addrs_type = "int64_t" if ssd else "int32_t" %} +{%- set locs_or_addrs_idx = "row_idx" if ssd else "cache_idx" %} //////////////////////////////////////////////////////////////////////////////// // Required for op registrations @@ -41,9 +45,8 @@ using namespace fbgemm_gpu; -}} }() -{%- for vbe in [True, False] %} +{%- for vbe in ([True, False] if (not dense and not ssd) else [False]) %} {%- set vdesc = "_vbe" if vbe else "" %} -{%- if not dense or not vbe %} {#- /* Generate different kernels for different kUseVecBlocking using Jinja @@ -64,7 +67,7 @@ template < int32_t kFixedMaxVecsPerThread > __global__ __launch_bounds__(kForwardMaxThreads) void -{{ ddesc 
}}_embedding_codegen_grad_indice_weights{{ vdesc }}_{{ vbdesc }}kernel( +{{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_{{ vbdesc }}kernel( // [\sum_t E_t x D_t] const pta::PackedTensorAccessor64 grad_output, pta::PackedTensorAccessor64 dev_weights, @@ -78,7 +81,7 @@ __global__ __launch_bounds__(kForwardMaxThreads) void const pta::PackedTensorAccessor32 indices, // [N = \sum_{b,t} L_{b,t} total indices, i.e. flattened [B][T][L] const pta::PackedTensorAccessor32 offsets, // [B x T + 1] {%- if not dense %} - const pta::PackedTensorAccessor32 lxu_cache_locations, + const pta::PackedTensorAccessor32<{{ locs_or_addrs_type }}, 1, at::RestrictPtrTraits> {{ locs_or_addrs_tensor }}, {%- endif %} pta::PackedTensorAccessor32 feature_requires_grad, // [T], pta::PackedTensorAccessor32, 1, at::RestrictPtrTraits> grad_indice_weights, @@ -172,14 +175,14 @@ __global__ __launch_bounds__(kForwardMaxThreads) void int32_t l = l_start + threadIdx.x; int64_t idx = l < L ? indices[indices_start + l] : 0; {%- if not dense %} - int32_t cache_idx = + const auto {{ locs_or_addrs_idx }} = (placement == PlacementType::MANAGED_CACHING && l < L) - ? lxu_cache_locations[indices_start + l] : 0; + ? {{ locs_or_addrs_tensor }}[indices_start + l] : 0; {%- endif %} for (auto j = 0; j < kWarpSize && l_start + j < L; ++j) { int64_t idx_j = shfl_sync(idx, j); {%- if not dense %} - int32_t cache_idx_j = shfl_sync(cache_idx, j); + const auto {{ locs_or_addrs_idx }}_j = shfl_sync({{ locs_or_addrs_idx }}, j); {%- endif %} at::acc_type grad_indice_weight = 0.0; @@ -189,8 +192,20 @@ __global__ __launch_bounds__(kForwardMaxThreads) void ++vec) { const int32_t d = {{ d }}; {%- if not dense %} - if (placement == PlacementType::MANAGED_CACHING && cache_idx_j != kCacheLocationMissing) { - Vec4T weight(&lxu_cache_weights[cache_idx_j][d]); + if ({{ "true || " if ssd else "" }} + ( + placement == PlacementType::MANAGED_CACHING + && ({{ locs_or_addrs_idx }}_j != kCacheLocationMissing) + ) + ) { + const cache_t* cache_weights = + {%- if ssd %} + reinterpret_cast( + *reinterpret_cast(&{{ locs_or_addrs_idx }}_j)); + {%- else %} + &lxu_cache_weights[{{ locs_or_addrs_idx }}_j][d]; + {%- endif %} + Vec4T weight(cache_weights); grad_indice_weight += weight.acc.x * grad_out[vec].acc.x + weight.acc.y * grad_out[vec].acc.y + weight.acc.z * grad_out[vec].acc.z + @@ -261,7 +276,7 @@ __global__ __launch_bounds__(kForwardMaxThreads) void } {%- endfor %} {# /* for use_vec_blocking */ #} -Tensor {{ ddesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda( +Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda( const Tensor& grad_output, const Tensor& dev_weights, {%- if not dense %} @@ -275,7 +290,7 @@ Tensor {{ ddesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda( const Tensor& indices, const Tensor& offsets, {%- if not dense %} - const Tensor& lxu_cache_locations, + const Tensor& {{ locs_or_addrs_tensor }}, {%- endif %} {%- if vbe %} const Tensor& feature_requires_grad, @@ -300,7 +315,7 @@ Tensor {{ ddesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda( indices, offsets, {%- if not dense %} - lxu_cache_locations, + {{ locs_or_addrs_tensor }}, {%- endif %} {%- if vbe %} vbe_row_output_offsets, @@ -347,7 +362,7 @@ Tensor {{ ddesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda( {%- else %} dev_weights.scalar_type(), {%- endif %} - "split_embedding_codegen_grad_indice_weights_kernel", + "split_embedding_codegen_grad_indice_weights{{ vdesc }}_kernel", [&] { {%- if vbe %} const auto& 
grad_output_reshaped = aligned_grad_output.reshape({1, -1}); @@ -361,7 +376,7 @@ Tensor {{ ddesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda( DISPATCH_{{ dpdesc }}VEC_BLOCKING_KERNEL(max_D, [&] { {%- set kernel_name = "{}_embedding_codegen_grad_indice_weights{}_{}kernel".format( - ddesc, vdesc, vbdesc) + mdesc, vdesc, vbdesc) %} #ifdef FBGEMM_GPU_MEMCHECK const auto func_name = @@ -388,7 +403,7 @@ Tensor {{ ddesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda( MAKE_PTA_WITH_NAME(func_name, indices, int64_t, 1, 32), MAKE_PTA_WITH_NAME(func_name, offsets, int64_t, 1, 32), {%- if not dense %} - MAKE_PTA_WITH_NAME(func_name, lxu_cache_locations, int32_t, 1, 32), + MAKE_PTA_WITH_NAME(func_name, {{ locs_or_addrs_tensor }}, {{ locs_or_addrs_type }}, 1, 32), {%- endif %} MAKE_PTA_WITH_NAME(func_name, feature_requires_grad_, int32_t, 1, 32), MAKE_PTA_ACC_WITH_NAME(func_name, grad_indice_weights, grad_t, 1, 32), @@ -412,7 +427,7 @@ Tensor {{ ddesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda( } -Tensor {{ ddesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_meta( +Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_meta( const Tensor& grad_output, const Tensor& dev_weights, {%- if not dense %} @@ -426,7 +441,7 @@ Tensor {{ ddesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_meta( const Tensor& indices, const Tensor& offsets, {%- if not dense %} - const Tensor& lxu_cache_locations, + const Tensor& {{ locs_or_addrs_tensor }}, {%- endif %} {%- if vbe %} const Tensor& feature_requires_grad, @@ -457,11 +472,12 @@ Tensor {{ ddesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_meta( //////////////////////////////////////////////////////////////////////////////// TORCH_LIBRARY_FRAGMENT(fbgemm, m) { {%- set embedding_codegen_grad_indice_weights_op = - "{}_embedding_codegen_grad_indice_weights{}_cuda".format( - ddesc, vdesc + "{}_embedding_codegen_grad_indice_weights{}".format( + mdesc, vdesc ) %} - m.def("{{ embedding_codegen_grad_indice_weights_op }}(" + {%- set embedding_codegen_grad_indice_weights_op_cuda = embedding_codegen_grad_indice_weights_op + "_cuda" %} + m.def("{{ embedding_codegen_grad_indice_weights_op_cuda }}(" " Tensor grad_output, " " Tensor dev_weights, " {%- if not dense %} @@ -475,7 +491,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { " Tensor indices, " " Tensor offsets, " {%- if not dense %} - " Tensor lxu_cache_locations, " + " Tensor {{ locs_or_addrs_tensor }}, " {%- endif %} {%- if vbe %} " Tensor feature_requires_grad, " @@ -488,12 +504,12 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { {%- endif %} ") -> Tensor"); DISPATCH_TO_CUDA( - "{{ embedding_codegen_grad_indice_weights_op }}", - {{ embedding_codegen_grad_indice_weights_op }} + "{{ embedding_codegen_grad_indice_weights_op_cuda }}", + {{ embedding_codegen_grad_indice_weights_op_cuda }} ); - m.impl("{{ embedding_codegen_grad_indice_weights_op }}", - torch::dispatch(c10::DispatchKey::Meta, TORCH_FN({{ ddesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_meta))); + m.impl("{{ embedding_codegen_grad_indice_weights_op_cuda }}", + torch::dispatch(c10::DispatchKey::Meta, + TORCH_FN({{ embedding_codegen_grad_indice_weights_op }}_meta))); } -{%- endif %} {#-/* if not dense or not vbe */#} {%- endfor %} {#-/* for vbe */#} // clang-format on diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_cta_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_cta_template.cu index f56bb818d8..77590f06d1 100644 --- 
a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_cta_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_cta_template.cu @@ -7,6 +7,7 @@ */ // clang-format off +{%- set mdesc = "ssd" if ssd else "split" %} {%- set wdesc = "weighted" if weighted else "unweighted" %} {%- set ndesc = "_nobag" if nobag else "" %} {%- set vdesc = "_vbe" if vbe else "" %} @@ -23,16 +24,23 @@ nobag, vbe, is_index_select, - has_global_weight_decay_support) %} + has_global_weight_decay_support, + ssd) %} +{%- set gwddesc = "_gwd" if is_gwd_kernel else "" %} + +{%- set desc_suffix = wdesc + vdesc + gwddesc %} + +{%- set locs_or_addrs_tensor = "ssd_row_addrs" if ssd else "lxu_cache_locations" %} +{%- set locs_or_addrs_type = "int64_t" if ssd else "int32_t" %} #include "fbgemm_gpu/embedding_backward_template_helpers.cuh" #include "fbgemm_gpu/fbgemm_tensor_accessor.h" #include "fbgemm_gpu/split_embeddings_utils.cuh" {%- if optimizer != "none" and not dense %} -#include "gen_embedding_optimizer_{{ optimizer }}_split_device_kernel.cuh" +#include "gen_embedding_optimizer_{{ optimizer }}_{{ mdesc }}_device_kernel.cuh" {%- endif %} -#include "gen_embedding_backward_{{ kdesc }}_split_device_kernel.cuh" -#include "gen_embedding_backward_common_split_device_kernel.cuh" +#include "gen_embedding_backward_split_{{ kdesc }}_device_kernel.cuh" +#include "gen_embedding_backward_split_common_device_kernel.cuh" using Tensor = at::Tensor; using namespace fbgemm_gpu; @@ -65,7 +73,6 @@ using namespace fbgemm_gpu; thus reduce the kernel occupancy, which can degrade the kernel performance. This increases the binary size, but the increase is minimal. */ #} -{%- set gwddesc = "_gwd" if is_gwd_kernel else "" %} template < typename emb_t, typename grad_t, @@ -80,7 +87,7 @@ __global__ __launch_bounds__(kMaxThreads) void {%- if is_index_select %} batch_index_select_dim0_codegen_backward_kernel_cta_per_row( {%- else %} -split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}{{ gwddesc }}_kernel_cta_per_row_1( +{{ mdesc }}_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ desc_suffix }}_kernel_cta_per_row_1( {%- endif %} const pta::PackedTensorAccessor64 grad_output, {%- if optimizer != "none" %} @@ -108,7 +115,7 @@ split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc const pta::PackedTensorAccessor32 sorted_infos, {%- endif %} {%- if not dense %} - const pta::PackedTensorAccessor32 sorted_lxu_cache_locations, + const pta::PackedTensorAccessor32<{{ locs_or_addrs_type }}, 1, at::RestrictPtrTraits> sorted_{{ locs_or_addrs_tensor }}, const bool use_uniq_cache_locations, const pta::PackedTensorAccessor32 table_unique_indices_offsets, {%- endif %} @@ -341,7 +348,7 @@ split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc {{ compute_global_weight_decay(is_gwd_kernel) }} {%- if not dense and optimizer != "none" %} - split_{{ optimizer }}_table_update_kernel< + {{ mdesc }}_{{ optimizer }}_table_update_kernel< emb_t, cache_t, {%- for ph_name in args.placeholder_tensor_names %} @@ -356,7 +363,7 @@ split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc lxu_cache_weights, weights_placements, weights_offsets, - sorted_lxu_cache_locations, + sorted_{{ locs_or_addrs_tensor }}, grad_sum, kUseVecBlocking ? smem_grad_sum : nullptr, kIsInt8 ? 
smem_grad_sum : nullptr, @@ -433,7 +440,7 @@ template __global__ __launch_bounds__(kMaxThreads) void {%- if is_index_select %} batch_index_select_dim0_codegen_backward_kernel_cta_per_row {%- else %} -split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}{{ gwddesc }}_kernel_cta_per_row_1 +{{ mdesc }}_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ desc_suffix }}_kernel_cta_per_row_1 {%- endif %} < {{ emb_type }}, {{ grad_type }}, @@ -472,8 +479,8 @@ split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc const pta::PackedTensorAccessor32 sorted_infos, {%- endif %} {%- if not dense %} - const pta::PackedTensorAccessor32 - sorted_lxu_cache_locations, + const pta::PackedTensorAccessor32<{{ locs_or_addrs_type }}, 1, at::RestrictPtrTraits> + sorted_{{ locs_or_addrs_tensor }}, const bool use_uniq_cache_locations, const pta::PackedTensorAccessor32 table_unique_indices_offsets, {%- endif %} diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu index e91296e586..9ff187dd8b 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu @@ -7,9 +7,11 @@ */ // clang-format off +{%- set mdesc = "ssd" if ssd else "split" %} {%- set wdesc = "weighted" if weighted else "unweighted" %} {%- set ndesc = "_nobag" if nobag else "" %} {%- set vdesc = "_vbe" if vbe else "" %} + {# /* `has_global_weight_decay_support` tells whether the optimizer has support for global weight decay (gwd) @@ -22,16 +24,23 @@ nobag, vbe, is_index_select, - has_global_weight_decay_support) %} + has_global_weight_decay_support, + ssd) %} +{%- set gwddesc = "_gwd" if is_gwd_kernel else "" %} + +{%- set desc_suffix = wdesc + vdesc + gwddesc %} + +{%- set locs_or_addrs_tensor = "ssd_row_addrs" if ssd else "lxu_cache_locations" %} +{%- set locs_or_addrs_type = "int64_t" if ssd else "int32_t" %} #include "fbgemm_gpu/embedding_backward_template_helpers.cuh" #include "fbgemm_gpu/fbgemm_tensor_accessor.h" #include "fbgemm_gpu/split_embeddings_utils.cuh" {%- if optimizer != "none" and not dense %} -#include "gen_embedding_optimizer_{{ optimizer }}_split_device_kernel.cuh" +#include "gen_embedding_optimizer_{{ optimizer }}_{{ mdesc }}_device_kernel.cuh" {%- endif %} -#include "gen_embedding_backward_{{ kdesc }}_split_device_kernel.cuh" -#include "gen_embedding_backward_common_split_device_kernel.cuh" +#include "gen_embedding_backward_split_{{ kdesc }}_device_kernel.cuh" +#include "gen_embedding_backward_split_common_device_kernel.cuh" using Tensor = at::Tensor; using namespace fbgemm_gpu; @@ -63,7 +72,7 @@ __global__ __launch_bounds__(kBackwardMaxThreads) void {%- if is_index_select %} batch_index_select_dim0_codegen_backward_kernel_warp_per_row( {%- else %} -split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}{{ gwddesc }}_kernel_warp_per_row_1( +{{ mdesc }}_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ desc_suffix }}_kernel_warp_per_row_1( {%- endif %} const pta::PackedTensorAccessor64 grad_output, {%- if optimizer != "none" %} @@ -89,7 +98,7 @@ split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc const pta::PackedTensorAccessor32 sorted_infos, {%- endif %} {%- if not dense %} - const pta::PackedTensorAccessor32 sorted_lxu_cache_locations, + const 
pta::PackedTensorAccessor32<{{ locs_or_addrs_type }}, 1, at::RestrictPtrTraits> sorted_{{ locs_or_addrs_tensor }}, const bool use_uniq_cache_locations, const pta::PackedTensorAccessor32 table_unique_indices_offsets, {%- endif %} @@ -249,7 +258,7 @@ split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc kUseVecBlocking ? max_vecs_per_thread : kFixedMaxVecsPerThread; {%- if not dense and optimizer != "none" %} - split_{{ optimizer }}_table_update_kernel< + {{ mdesc }}_{{ optimizer }}_table_update_kernel< emb_t, cache_t, {%- for ph_name in args.placeholder_tensor_names %} @@ -264,7 +273,7 @@ split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc lxu_cache_weights, weights_placements, weights_offsets, - sorted_lxu_cache_locations, + sorted_{{ locs_or_addrs_tensor }}, grad_sum, smem_grad_sum, smem_grad_sum, // shared_weight_update_row (reuse smem_grad_sum) @@ -343,7 +352,7 @@ template __global__ __launch_bounds__(kBackwardMaxThreads) void {%- if is_index_select %} batch_index_select_dim0_codegen_backward_kernel_warp_per_row {%- else %} -split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}{{ gwddesc }}_kernel_warp_per_row_1 +{{ mdesc }}_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ desc_suffix }}_kernel_warp_per_row_1 {%- endif %} < {{ emb_type }}, {{ grad_type }}, @@ -379,7 +388,7 @@ split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc const pta::PackedTensorAccessor32 sorted_infos, {%- endif %} {%- if not dense %} - const pta::PackedTensorAccessor32 sorted_lxu_cache_locations, + const pta::PackedTensorAccessor32<{{ locs_or_addrs_type }}, 1, at::RestrictPtrTraits> sorted_{{ locs_or_addrs_tensor }}, const bool use_uniq_cache_locations, const pta::PackedTensorAccessor32 table_unique_indices_offsets, {%- endif %} diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_meta_template.cpp b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_meta_template.cpp index 3aded4a1ed..28eb82ca57 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_meta_template.cpp +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_meta_template.cpp @@ -17,10 +17,13 @@ // Companion template is embedding_backward_split_template.cu +{%- set mdesc = "ssd" if ssd else "split" %} {%- set wdesc = "weighted" if weighted else "unweighted" %} {%- set vdesc = "_vbe" if vbe else "" %} {%- set ndesc = "_nobag" if nobag else "" %} +{%- set locs_or_addrs_tensor = "ssd_row_addrs" if ssd else "lxu_cache_locations" %} + //////////////////////////////////////////////////////////////////////////////// // Required for op registrations #include "fbgemm_gpu/embedding_op_registration.h" @@ -42,14 +45,15 @@ using Tensor = at::Tensor; nobag, vbe, is_index_select, - has_global_weight_decay_support + has_global_weight_decay_support, + ssd=False ) else [False]) %} {%- set gwddesc = "_gwd" if is_gwd else "" %} {%- if is_index_select %} Tensor batch_index_select_dim0_codegen_backward_meta( {%- else %} -Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_exact{{ vdesc }}{{ gwddesc }}_meta( +Tensor {{ mdesc }}_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_exact{{ vdesc }}{{ gwddesc }}_meta( {%- endif %} const Tensor& grad_output, const Tensor& dev_weights, @@ -78,7 +82,7 @@ Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_e const Tensor& indice_weights, {%- endif %} {%- if not dense %} - const 
Tensor& lxu_cache_locations, + const Tensor& {{ locs_or_addrs_tensor }}, {%- endif %} {%- if not is_index_select %} const int64_t unused_, diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index a31bc8c399..b1fa132948 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -8,8 +8,17 @@ // clang-format off {%- set wdesc = "weighted" if weighted else "unweighted" %} -{%- set vdesc = "_vbe" if vbe else "" %} {%- set ndesc = "_nobag" if nobag else "" %} +{%- set vdesc = "_vbe" if vbe else "" %} +{%- set mdesc = "ssd" if ssd else "split" %} + +{%- macro get_desc_suffix(gwd) %} +{%- set gwddesc = "_gwd" if gwd else "" %} +{{- wdesc + vdesc + gwddesc }} +{%- endmacro %} + +{%- set locs_or_addrs_tensor = "ssd_row_addrs" if ssd else "lxu_cache_locations" %} +{%- set locs_or_addrs_type = "int64_t" if ssd else "int32_t" %} {%- if not is_index_select %} //////////////////////////////////////////////////////////////////////////////// @@ -35,8 +44,9 @@ using namespace fbgemm_gpu; nobag, vbe, is_index_select, - has_global_weight_decay_support) %} -{%- set gwddesc = "_gwd" if is_gwd_kernel else "" %} + has_global_weight_decay_support, + ssd) %} +{%- set desc_suffix = get_desc_suffix(is_gwd_kernel) %} template < typename emb_t, typename grad_t, @@ -51,7 +61,7 @@ __global__ __launch_bounds__(kMaxThreads) void {%- if is_index_select %} batch_index_select_dim0_codegen_backward_kernel_cta_per_row( {%- else %} -split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}{{ gwddesc }}_kernel_cta_per_row_1( +{{ mdesc }}_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ desc_suffix }}_kernel_cta_per_row_1( {%- endif %} const pta::PackedTensorAccessor64 grad_output, {%- if optimizer != "none" %} @@ -79,7 +89,7 @@ split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc const pta::PackedTensorAccessor32 sorted_infos, {%- endif %} {%- if not dense %} - const pta::PackedTensorAccessor32 sorted_lxu_cache_locations, + const pta::PackedTensorAccessor32<{{ locs_or_addrs_type }}, 1, at::RestrictPtrTraits> sorted_{{ locs_or_addrs_tensor }}, const bool use_uniq_cache_locations, const pta::PackedTensorAccessor32 table_unique_indices_offsets, {%- endif %} @@ -139,7 +149,7 @@ __global__ __launch_bounds__(kBackwardMaxThreads) void {%- if is_index_select %} batch_index_select_dim0_codegen_backward_kernel_warp_per_row( {%- else %} -split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}{{ gwddesc }}_kernel_warp_per_row_1( +{{ mdesc }}_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ desc_suffix }}_kernel_warp_per_row_1( {%- endif %} const pta::PackedTensorAccessor64 grad_output, {%- if optimizer != "none" %} @@ -165,7 +175,7 @@ split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc const pta::PackedTensorAccessor32 sorted_infos, {%- endif %} {%- if not dense %} - const pta::PackedTensorAccessor32 sorted_lxu_cache_locations, + const pta::PackedTensorAccessor32<{{ locs_or_addrs_type }}, 1, at::RestrictPtrTraits> sorted_{{ locs_or_addrs_tensor }}, const bool use_uniq_cache_locations, const pta::PackedTensorAccessor32 table_unique_indices_offsets, {%- endif %} @@ -417,22 +427,21 @@ int32_t compute_num_groups_and_dynamic_smem_bytes( nobag, vbe, is_index_select, - has_global_weight_decay_support) %} 
-{%- set gwddesc = "_gwd" if is_gwd_kernel else "" %} - -{%- set func_name0 = "split_embedding{}_backward_codegen_{}_{}_exact{}{}_cuda".format( + has_global_weight_decay_support, + ssd) %} +{%- set desc_suffix = get_desc_suffix(is_gwd_kernel) %} + +{%- set embedding_cuda_op = + "batch_index_select_dim0_codegen_backward_cuda" + if is_index_select + else "{}_embedding{}_backward_codegen_{}_{}_exact_cuda".format( + mdesc, ndesc, optimizer, - wdesc, - vdesc, - gwddesc) + desc_suffix) %} -{%- if is_index_select %} -Tensor batch_index_select_dim0_codegen_backward_cuda( -{%- else %} -Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_exact{{ vdesc }}{{ gwddesc }}_cuda( -{%- endif %} +Tensor {{ embedding_cuda_op }}( const Tensor& grad_output, const Tensor& dev_weights, {%- if not dense %} @@ -460,7 +469,7 @@ Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_e const Tensor& indice_weights, {%- endif %} {%- if not dense %} - const Tensor& lxu_cache_locations, + const Tensor& {{ locs_or_addrs_tensor }}, {%- endif %} {%- if not is_index_select %} const int64_t unused_, @@ -538,7 +547,7 @@ Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_e indice_weights, {%- endif %} {%- if not dense %} - lxu_cache_locations, + {{ locs_or_addrs_tensor }}, {%- endif %} {%- if is_gwd_kernel %} prev_iter_dev, @@ -676,24 +685,24 @@ Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_e ); {%- if not dense %} - Tensor lxu_cache_locations_sorted = lxu_cache_locations; + Tensor {{ locs_or_addrs_tensor }}_sorted = {{ locs_or_addrs_tensor }}; Tensor table_unique_indices_offsets; - if (lxu_cache_locations.size(0) > 0) { + if ({{ locs_or_addrs_tensor }}.size(0) > 0) { if (use_uniq_cache_locations) { if (!use_homogeneous_placements) { - // When use_uniq_cache_locations=true, lxu_cache_locations are unique + // When use_uniq_cache_locations=true, {{ locs_or_addrs_tensor }} are unique // and sorted in an ascending order based on the linear cache indices. // Linear cache indices of tables that are not placed in cache are set // to a sentinel value (i.e., the sum of hash sizes of all embedding // tables). Since the sentinel value is larger than the max linear - // cache index value, the lxu_cache_locations can be sorted differently + // cache index value, the {{ locs_or_addrs_tensor }} can be sorted differently // than the sorted_linear_indices. // // For this reason, the run ids of sorted and unique - // lxu_cache_locations can be different from those of the + // {{ locs_or_addrs_tensor }} can be different from those of the // sorted_linear_indices. We need the following code to compute // table_unique_indices_offsets which contains the differences between - // lxu_cache_locations run ids and sorted_linear_indices run ids. + // {{ locs_or_addrs_tensor }} run ids and sorted_linear_indices run ids. 
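The run-id mismatch described in the comment above is easy to reproduce on the host. The sketch below is illustrative only (plain C++ standing in for the CUB radix_sort_pairs call used by this template, with invented hash sizes and indices): rows of an uncached table are keyed by the sentinel value, so sorting by the sentinel-masked cache keys permutes the rows differently than sorting by the raw linear indices, which is why the two sets of run ids, and therefore table_unique_indices_offsets, can differ. The code that follows computes exactly those offsets.

// Host-side sketch (not generated code): two tables with hash size 4 each,
// where only the second table is cached. Rows of the uncached table get the
// sentinel key (sum of all hash sizes), so the two sort orders diverge.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  const int64_t sentinel = 8;  // sum of hash sizes of all embedding tables
  const std::vector<int64_t> linear_indices = {1, 5, 2, 6};  // rows 1 and 2 belong to the uncached table
  const std::vector<int64_t> cache_keys = {sentinel, 5, sentinel, 6};

  std::vector<int> by_linear(linear_indices.size());
  std::vector<int> by_cache(cache_keys.size());
  std::iota(by_linear.begin(), by_linear.end(), 0);
  std::iota(by_cache.begin(), by_cache.end(), 0);
  std::sort(by_linear.begin(), by_linear.end(),
            [&](int a, int b) { return linear_indices[a] < linear_indices[b]; });
  std::sort(by_cache.begin(), by_cache.end(),
            [&](int a, int b) { return cache_keys[a] < cache_keys[b]; });

  for (size_t i = 0; i < by_linear.size(); ++i) {
    // Prints by_linear = 0 2 1 3 versus by_cache = 1 3 0 2: the run orders differ.
    std::printf("run %zu: by_linear=%d by_cache=%d\n", i, by_linear[i], by_cache[i]);
  }
  return 0;
}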
auto dev_or_uvm_unique_indices = at::zeros_like(weights_placements); #ifdef FBGEMM_GPU_MEMCHECK @@ -728,15 +737,15 @@ Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_e } } else { - lxu_cache_locations_sorted = at::empty_like(lxu_cache_locations); + {{ locs_or_addrs_tensor }}_sorted = at::empty_like({{ locs_or_addrs_tensor }}); size_t temp_storage_bytes = 0; AT_CUDA_CHECK(radix_sort_pairs( nullptr, temp_storage_bytes, linear_indices.data_ptr(), linear_indices_sorted.data_ptr(), - lxu_cache_locations.data_ptr(), - lxu_cache_locations_sorted.data_ptr(), + {{ locs_or_addrs_tensor }}.data_ptr<{{ locs_or_addrs_type }}>(), + {{ locs_or_addrs_tensor }}_sorted.data_ptr<{{ locs_or_addrs_type }}>(), linear_indices.numel(), 0, total_hash_size_bits, @@ -749,8 +758,8 @@ Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_e temp_storage_bytes, linear_indices.data_ptr(), linear_indices_sorted.data_ptr(), - lxu_cache_locations.data_ptr(), - lxu_cache_locations_sorted.data_ptr(), + {{ locs_or_addrs_tensor }}.data_ptr<{{ locs_or_addrs_type }}>(), + {{ locs_or_addrs_tensor }}_sorted.data_ptr<{{ locs_or_addrs_type }}>(), linear_indices.numel(), 0, total_hash_size_bits, @@ -758,7 +767,7 @@ Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_e } } - if (lxu_cache_locations.size(0) == 0 || !use_uniq_cache_locations || use_homogeneous_placements) { + if ({{ locs_or_addrs_tensor }}.size(0) == 0 || !use_uniq_cache_locations || use_homogeneous_placements) { table_unique_indices_offsets = at::zeros_like(weights_placements); } {%- endif %} @@ -771,7 +780,7 @@ Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_e {%- else %} dev_weights.scalar_type(), {%- endif %} - "split_embedding_backward_{{ optimizer }}_exact_kernel", + "{{ embedding_cuda_op }}", [&] { {%- if weighted %} auto indice_weights_sorted = at::empty_like(indice_weights); @@ -816,7 +825,7 @@ Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_e {%- endif %} auto grad_output_accessor = MAKE_PTA_WITH_NAME( - "{{ func_name0 }}.1", + "{{ embedding_cuda_op }}.1", grad_output_reshaped, grad_t, {{ "1" if is_index_select else "2" }}, 64 @@ -855,7 +864,7 @@ Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_e C10_CUDA_KERNEL_LAUNCH_CHECK(); {%- endif %} // if not dense or not vbe - grad_output_accessor = MAKE_PTA_WITH_NAME("{{ func_name0 }}.2", grad_output_mean, grad_t, 2, 64); + grad_output_accessor = MAKE_PTA_WITH_NAME("{{ embedding_cuda_op }}.2", grad_output_mean, grad_t, 2, 64); } {%- endif %} @@ -924,18 +933,17 @@ Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_e {%- for ph_name in args.placeholder_tensor_names %} {{ ph_name + "_dev" }}.scalar_type(), {%- endfor %} - "split_embedding_backward_{{ optimizer }}_exact_placeholder_type_kernel", + "{{ mdesc }}_embedding_backward_{{ optimizer }}_exact_placeholder_type_kernel", [&] { {%- set cta_kernel = "batch_index_select_dim0_codegen_backward_kernel_cta_per_row" if is_index_select else - "split_embedding{}_backward_codegen_{}_{}{}{}_kernel_cta_per_row_1".format( + "{}_embedding{}_backward_codegen_{}_{}_kernel_cta_per_row_1".format( + mdesc, ndesc, optimizer, - wdesc, - vdesc, - gwddesc + desc_suffix, ) %} @@ -1003,7 +1011,7 @@ Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_e MAKE_PTA_WITH_NAME(func_name3, infos_sorted, int64_t, 1, 32), {%- endif %} {%- if not dense %} - 
MAKE_PTA_WITH_NAME(func_name3, lxu_cache_locations_sorted, int32_t, 1, 32), + MAKE_PTA_WITH_NAME(func_name3, {{ locs_or_addrs_tensor }}_sorted, {{ locs_or_addrs_type }}, 1, 32), use_uniq_cache_locations, MAKE_PTA_WITH_NAME(func_name3, table_unique_indices_offsets, int32_t, 1, 32), {%- endif %} @@ -1051,12 +1059,11 @@ Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_e {%- set warp_kernel = "batch_index_select_dim0_codegen_backward_kernel_warp_per_row" if is_index_select else - "split_embedding{}_backward_codegen_{}_{}{}{}_kernel_warp_per_row_1".format( + "{}_embedding{}_backward_codegen_{}_{}_kernel_warp_per_row_1".format( + mdesc, ndesc, optimizer, - wdesc, - vdesc, - gwddesc + desc_suffix, ) %} const auto backward_warp_per_row_kernel = @@ -1125,7 +1132,7 @@ Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_e MAKE_PTA_WITH_NAME(func_name4, infos_sorted, int64_t, 1, 32), {%- endif %} {%- if not dense %} - MAKE_PTA_WITH_NAME(func_name4, lxu_cache_locations_sorted, int32_t, 1, 32), + MAKE_PTA_WITH_NAME(func_name4, {{ locs_or_addrs_tensor }}_sorted, {{ locs_or_addrs_type }}, 1, 32), use_uniq_cache_locations, MAKE_PTA_WITH_NAME(func_name4, table_unique_indices_offsets, int32_t, 1, 32), {%- endif %} @@ -1191,8 +1198,8 @@ Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_e {%- if not is_index_select %} TORCH_LIBRARY_FRAGMENT(fbgemm, m) { {%- set embedding_codegen_backward_op = - "split_embedding{}_backward_codegen_{}_{}_exact{}{}_cuda".format( - ndesc, optimizer, wdesc, vdesc, gwddesc + "{}_embedding{}_backward_codegen_{}_{}_exact_cuda".format( + mdesc, ndesc, optimizer, desc_suffix ) %} m.def("{{ embedding_codegen_backward_op }}(" @@ -1223,7 +1230,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { " Tensor indice_weights, " {%- endif %} {%- if not dense %} - " Tensor lxu_cache_locations, " + " Tensor {{ locs_or_addrs_tensor }}, " {%- endif %} {%- if not is_index_select %} " int unused_, " @@ -1261,4 +1268,4 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { ); } {%- endif %} {#-/* if not is_index_select */#} - // clang-format on +// clang-format on diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_nobag_small_template.cu b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_nobag_small_template.cu index c1b876a0b6..d818d3b253 100644 --- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_nobag_small_template.cu +++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_nobag_small_template.cu @@ -15,6 +15,11 @@ // See https://fburl.com/dw9ljh4h #} +{%- set mdesc = "dense" if dense else ("ssd" if ssd else "split") %} +{%- set locs_or_addrs_tensor = "ssd_row_addrs" if ssd else "lxu_cache_locations" %} +{%- set locs_or_addrs_type = "int64_t" if ssd else "int32_t" %} +{%- set locs_or_addrs_idx = "row_idx" if ssd else "cache_idx" %} + #include "fbgemm_gpu/embedding_forward_template_helpers.cuh" using Tensor = at::Tensor; @@ -31,7 +36,7 @@ __launch_bounds__(kForwardMaxThreads) __global__ void {%- if is_index_select %} batch_index_select_dim0_codegen_forward_small_kernel( {%- else %} -{{ "dense" if dense else "split" }}_embedding_nobag_codegen_forward_unweighted_small_kernel( +{{ mdesc }}_embedding_nobag_codegen_forward_unweighted_small_kernel( {%- endif %} const pta::PackedTensorAccessor64 dev_weights, {%- if not dense %} @@ -51,7 +56,7 @@ batch_index_select_dim0_codegen_forward_small_kernel( const pta::PackedTensorAccessor32 offsets, {%- endif %} {%- if not 
dense %} - const pta::PackedTensorAccessor32 lxu_cache_locations, + const pta::PackedTensorAccessor32<{{ locs_or_addrs_type }}, 1, at::RestrictPtrTraits> {{ locs_or_addrs_tensor }}, {%- endif %} {%- if is_index_select %} const at::PackedTensorAccessor32 output_offsets, @@ -112,7 +117,7 @@ batch_index_select_dim0_codegen_forward_small_kernel( if (placement == PlacementType::DEVICE) { weights = &dev_weights[weights_offset]; } else { - weights = &uvm_weights[weights_offset]; + weights = {{ "nullptr" if ssd else "&uvm_weights[weights_offset]" }}; } {%- else %} weights = &dev_weights[weights_offset]; @@ -131,7 +136,9 @@ batch_index_select_dim0_codegen_forward_small_kernel( int32_t l = l_start + threadIdx.x; int64_t idx = l < L ? indices[indices_start + l] : 0; {%- if not dense %} - int32_t cache_idx = (placement == PlacementType::MANAGED_CACHING && l < L) ? lxu_cache_locations[indices_start + l] : 0; + const {{ locs_or_addrs_type }} {{ locs_or_addrs_idx }} = + (placement == PlacementType::MANAGED_CACHING && l < L) + ? {{ locs_or_addrs_tensor }}[indices_start + l] : 0; {%- endif %} for (auto j = group_start; j < group_end && l_start + j < L; ++j) { int64_t idx_j = shfl_sync(idx, j); @@ -141,7 +148,8 @@ batch_index_select_dim0_codegen_forward_small_kernel( int64_t output_j = indices_start + l_start + j; {%- endif %} {%- if not dense %} - int32_t cache_idx_j = shfl_sync(cache_idx, j); + const {{ locs_or_addrs_type }} {{ locs_or_addrs_idx }}_j = + shfl_sync({{ locs_or_addrs_idx }}, j); {%- endif %} {%- if not dense %} @@ -150,10 +158,13 @@ batch_index_select_dim0_codegen_forward_small_kernel( float2 qparams_cache = make_float2(0.0f, 0.0f); {%- endif %} - auto weight_row_emb = WeightRow( - const_cast(&weights[idx_j * D_emb]), + auto weight_row_emb = WeightRowAccessor< + emb_t, cache_t, cache_t, false + >( + &weights[idx_j * D_emb], nullptr, - D); + D + ); [[maybe_unused]] float2 qparams_emb; if (std::is_same::value) { qparams_emb = weight_row_emb.load_qparams(); @@ -161,11 +172,24 @@ batch_index_select_dim0_codegen_forward_small_kernel( if (d < D) { {%- if not dense %} - if (placement == PlacementType::MANAGED_CACHING && cache_idx_j != kCacheLocationMissing) { - auto weight_row_cache = WeightRow( - const_cast(&weights[idx_j * D_emb]), - const_cast(&lxu_cache_weights[cache_idx_j][0]), - D); + if (placement == PlacementType::MANAGED_CACHING && + {{ locs_or_addrs_idx }}_j != kCacheLocationMissing) { + const cache_t* cache_weights; + {%- if ssd %} + cache_weights = reinterpret_cast( + *reinterpret_cast(&{{ locs_or_addrs_idx }}_j)); + {%- else %} + cache_weights = reinterpret_cast( + &lxu_cache_weights[{{ locs_or_addrs_idx }}_j][0]); + {%- endif %} + + auto weight_row_cache = WeightRowAccessor< + emb_t, cache_t, cache_t, true + >( + &weights[idx_j * D_emb], + cache_weights, + D + ); Vec4T weight = weight_row_cache.load(d, qparams_cache); weight.store(&output[output_j][d]); } else { @@ -203,7 +227,7 @@ template __launch_bounds__(kForwardMaxThreads) __global__ void {%- if is_index_select %} batch_index_select_dim0_codegen_forward_small_kernel {%- else %} -{{ "dense" if dense else "split" }}_embedding_nobag_codegen_forward_unweighted_small_kernel +{{ mdesc }}_embedding_nobag_codegen_forward_unweighted_small_kernel {%- endif %} < {{ emb_type }}, @@ -230,7 +254,7 @@ batch_index_select_dim0_codegen_forward_small_kernel const pta::PackedTensorAccessor32<{{ index_type }}, 1, at::RestrictPtrTraits> offsets, {%- endif %} {%- if not dense %} - const pta::PackedTensorAccessor32 lxu_cache_locations, + const 
pta::PackedTensorAccessor32<{{ locs_or_addrs_type }}, 1, at::RestrictPtrTraits> {{ locs_or_addrs_tensor }}, {%- endif %} {%- if is_index_select %} const pta::PackedTensorAccessor32 output_offsets, diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu index d7ce2c7414..b80e04c010 100644 --- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu +++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu @@ -15,7 +15,7 @@ // See https://fburl.com/dw9ljh4h #} -{%- set ddesc = "dense" if dense else "split" %} +{%- set mdesc = "dense" if dense else ("ssd" if ssd else "split") %} {%- set wdesc = "weighted" if weighted else "unweighted" %} {%- set ndesc = "_nobag" if nobag else "" %} {%- set vdesc = "_vbe" if vbe else "" %} @@ -23,7 +23,16 @@ dense, nobag, vbe, - is_index_select) %} + is_index_select, + has_global_weight_decay_support=True, + ssd=ssd) %} +{%- set gwddesc = "_gwd" if is_gwd_kernel else "" %} +{%- set desc_suffix = wdesc + vdesc + gwddesc %} + +{%- set locs_or_addrs_tensor = "ssd_row_addrs" if ssd else "lxu_cache_locations" %} +{%- set locs_or_addrs_type = "int64_t" if ssd else "int32_t" %} +{%- set locs_or_addrs_idx = "row_idx" if ssd else "cache_idx" %} + #include "fbgemm_gpu/embedding_forward_template_helpers.cuh" #include "fbgemm_gpu/split_embeddings_cache_cuda.cuh" @@ -50,7 +59,7 @@ using namespace fbgemm_gpu; idx_j D_emb lxu_cache_weights - cache_idx_j + {{ locs_or_addrs_idx }}_j idx_weight_j VEC_WIDTH D @@ -58,6 +67,16 @@ using namespace fbgemm_gpu; output_j */#} {%- macro load_and_accumulate(from_cache) %} + {%- if from_cache %} + const cache_t* cache_weights; + {%- if ssd %} + cache_weights = reinterpret_cast( + *reinterpret_cast(&{{ locs_or_addrs_idx }}_j)); + {%- else %} + cache_weights = reinterpret_cast( + &lxu_cache_weights[{{ locs_or_addrs_idx }}_j][0]); + {%- endif %} + {%- endif %} {#-/* Set the weights row */#} const auto weights_row = WeightRowAccessor < @@ -75,15 +94,14 @@ using namespace fbgemm_gpu; // memory into the registers as a side effect nullptr, // Load from the cache - const_cast(&lxu_cache_weights[cache_idx_j][0]), + cache_weights, {%- else %} // Load from the embedding table - const_cast(&weights[idx_j * D_emb]), + &weights[idx_j * D_emb], // Pass nullptr bc we are loading from the embedding table nullptr, {%- endif %} - D, - nullptr); + D); {#-/* Set the quantization params */#} {%- if from_cache %} @@ -160,7 +178,7 @@ using namespace fbgemm_gpu; {%- if not dense and lxu_miss_rate != "cache_conflict_miss_rate::all" %} // Cooperatively load the cache's indices - [[maybe_unused]] int32_t cache_idx = (use_lxu_cache && placement == PlacementType::MANAGED_CACHING && l < L) ? lxu_cache_locations[indices_start + l] : 0; + [[maybe_unused]] {{ locs_or_addrs_type }} {{ locs_or_addrs_idx }} = (use_lxu_cache && placement == PlacementType::MANAGED_CACHING && l < L) ? {{ locs_or_addrs_tensor }}[indices_start + l] : 0; {%- endif %} {%- if lxu_miss_rate == "cache_conflict_miss_rate::zero" and is_gwd_kernel %} int64_t idx = l < L ? indices[indices_start + l] : 0; // only used for accessing prev_iter @@ -191,7 +209,8 @@ using namespace fbgemm_gpu; {%- if not dense and lxu_miss_rate != "cache_conflict_miss_rate::all" %} // Load cache's index from thread j in the group - [[maybe_unused]] int32_t cache_idx_j = use_lxu_cache ? 
SHFL_SYNC(cache_idx, j) : 0; + [[maybe_unused]] {{ locs_or_addrs_type }} {{ locs_or_addrs_idx }}_j + = use_lxu_cache ? SHFL_SYNC({{ locs_or_addrs_idx }}, j) : 0; {%- endif %} {%- if weighted %} @@ -223,7 +242,9 @@ using namespace fbgemm_gpu; {{ load_and_accumulate(true) }} {%- else %} {#-/* Else we defer to run-time selection */#} - if (placement == PlacementType::MANAGED_CACHING && cache_idx_j != kCacheLocationMissing) { + if (placement == PlacementType::MANAGED_CACHING + && {{ locs_or_addrs_idx }}_j != kCacheLocationMissing + ) { {#-/* If the row is available in the cache, fetch from the cache */#} {{ load_and_accumulate(true) }} } else { @@ -246,7 +267,6 @@ using namespace fbgemm_gpu; with global_weight_decay, otherwise, only generate regular kernel. */ #} -{%- set gwddesc = "_gwd" if is_gwd_kernel else "" %} template < typename emb_t, typename cache_t, @@ -258,12 +278,12 @@ template < {%- if not nobag %} size_t kMaxVecsPerThread, {%- endif %} - size_t kThreadGroupSize > + size_t kThreadGroupSize> __launch_bounds__(kForwardMaxThreads) __global__ void {%- if is_index_select %} batch_index_select_dim0_codegen_forward_kernel( {%- else %} -{{ ddesc }}_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ vdesc }}{{ gwddesc }}_kernel( +{{ mdesc }}_embedding{{ ndesc }}_codegen_forward_{{ desc_suffix }}_kernel( {%- endif %} const pta::PackedTensorAccessor64 dev_weights, {%- if not dense %} @@ -296,7 +316,7 @@ batch_index_select_dim0_codegen_forward_kernel( pta::PackedTensorAccessor32, 1, at::RestrictPtrTraits> indice_weights, {%- endif %} {%- if not dense %} - const pta::PackedTensorAccessor32 lxu_cache_locations, + const pta::PackedTensorAccessor32<{{ locs_or_addrs_type }}, 1, at::RestrictPtrTraits> {{ locs_or_addrs_tensor }}, /* NOTE: We pass in `lxu_cache_conflict_misses = uvm_cache_stats[uvm_cache_stats_index::num_conflict_unique_misses]` as a @@ -533,13 +553,19 @@ batch_index_select_dim0_codegen_forward_kernel( embedding_forward_split_template.cu */ -{%- macro template_instantiation(emb_type, cache_type, output_type, use_cache, kMaxVecsPerThread, kThreadGroupSize) %} -{%- set gwddesc = "_gwd" if is_gwd_kernel else "" %} +{%- macro template_instantiation( + emb_type, + cache_type, + output_type, + use_cache, + kMaxVecsPerThread, + kThreadGroupSize) +%} template __launch_bounds__(kForwardMaxThreads) __global__ void {%- if is_index_select %} batch_index_select_dim0_codegen_forward_kernel {%- else %} -{{ ddesc }}_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ vdesc }}{{ gwddesc }}_kernel +{{ mdesc }}_embedding{{ ndesc }}_codegen_forward_{{ desc_suffix }}_kernel {%- endif %} < {{ emb_type }}, @@ -585,7 +611,7 @@ batch_index_select_dim0_codegen_forward_kernel pta::PackedTensorAccessor32, 1, at::RestrictPtrTraits> indice_weights, {%- endif %} {%- if not dense %} - const pta::PackedTensorAccessor32 lxu_cache_locations, + const pta::PackedTensorAccessor32<{{ locs_or_addrs_type }}, 1, at::RestrictPtrTraits> {{ locs_or_addrs_tensor }}, const int32_t* lxu_cache_conflict_misses, {%- endif %} {%- if is_index_select %} @@ -608,7 +634,14 @@ batch_index_select_dim0_codegen_forward_kernel {%- for emb_type in ['float', 'at::Half'] %} {%- for cache_type in ['float', 'at::Half'] %} {%- for output_type in ['float', 'at::Half', 'at::BFloat16'] %} - {{ template_instantiation(emb_type, cache_type, output_type, use_cache, kMaxVecsPerThread, kThreadGroupSize) }} + {{ template_instantiation( + emb_type, + cache_type, + output_type, + use_cache, + kMaxVecsPerThread, + kThreadGroupSize) + }} {%- endfor %} {%- endfor 
%} {%- endfor %} @@ -616,7 +649,7 @@ batch_index_select_dim0_codegen_forward_kernel {%- macro instantiate_templates(use_subwarp_shuffle) %} {%- set has_experimental = - has_experimental_support(dense, nobag, vbe, is_index_select, is_rocm) + has_experimental_support(dense, nobag, vbe, is_index_select, is_rocm, ssd) %} {%- set max_forward_embedding_dim = legacy_max_embedding_dim if has_experimental else max_embedding_dim diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_v2_template.cu b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_v2_template.cu index 48c1d5e9b9..47e38b621e 100644 --- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_v2_template.cu +++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_v2_template.cu @@ -1008,7 +1008,7 @@ __global__ void split_embedding_codegen_forward_{{ wdesc }}_v2_kernel const int64_t* __restrict__ const offsets, const uint32_t* __restrict__ const D_offsets, const int64_t* __restrict__ const weights_offsets, - const int32_t* __restrict__ const lxu_cache_locations, + const int32_t*__restrict__ const lxu_cache_locations, {{ output_type }}* __restrict__ const output); {%- endfor %} diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_meta_template.cpp b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_meta_template.cpp index 15611fb0c7..1e0860d006 100644 --- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_meta_template.cpp +++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_meta_template.cpp @@ -17,10 +17,12 @@ // Companion template is embedding_forward_split_template.cu -{%- set ddesc = "dense" if dense else "split" %} +{%- set mdesc = "dense" if dense else ("ssd" if ssd else "split") %} {%- set wdesc = "weighted" if weighted else "unweighted" %} {%- set vdesc = "_vbe" if vbe else "" %} +{%- set locs_or_addrs_tensor = "ssd_row_addrs" if ssd else "lxu_cache_locations" %} + //////////////////////////////////////////////////////////////////////////////// // Required for op registrations #include "fbgemm_gpu/embedding_op_registration.h" @@ -47,12 +49,14 @@ static constexpr float kINT8QparamsBytes = 8; dense, nobag, vbe, - is_index_select + is_index_select, + has_global_weight_decay_support=True, + ssd=False, ) else [False]) %} {%- set gwddesc = "_gwd" if is_gwd else "" %} Tensor -{{ ddesc }}_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ vdesc }}{{ gwddesc }}_meta( +{{ mdesc }}_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ vdesc }}{{ gwddesc }}_meta( const Tensor& dev_weights, {%- if not dense %} const Tensor& uvm_weights, @@ -80,7 +84,7 @@ Tensor const Tensor& indice_weights, {%- endif %} {%- if not dense %} - const Tensor& lxu_cache_locations, + const Tensor& {{ locs_or_addrs_tensor }}, const Tensor& uvm_cache_stats, {%- endif %} const int64_t output_dtype, @@ -200,10 +204,10 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { // NB: yes cuda here {%- set embedding_codegen_forward_op = "{}_embedding{}_codegen_forward_{}{}{}_cuda".format( - ddesc, ndesc, wdesc, vdesc, gwddesc + mdesc, ndesc, wdesc, vdesc, gwddesc ) %} - m.impl("{{ embedding_codegen_forward_op }}", torch::dispatch(c10::DispatchKey::Meta, TORCH_FN({{ ddesc }}_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ vdesc }}{{ gwddesc }}_meta))); + m.impl("{{ embedding_codegen_forward_op }}", torch::dispatch(c10::DispatchKey::Meta, TORCH_FN({{ mdesc }}_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ vdesc }}{{ gwddesc }}_meta))); } {%- endfor %} {#-/* for 
is_gwd */#} {%- endif %} {#/* if (not nobag or (not weighted and not vbe)) */#} diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu index 42a59272d1..3205d34a2a 100644 --- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu +++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu @@ -15,9 +15,17 @@ // See https://fburl.com/dw9ljh4h #} -{%- set ddesc = "dense" if dense else "split" %} +{%- set mdesc = "dense" if dense else ("ssd" if ssd else "split") %} {%- set wdesc = "weighted" if weighted else "unweighted" %} + +{%- macro get_desc_suffix(gwd) %} {%- set vdesc = "_vbe" if vbe else "" %} +{%- set gwddesc = "_gwd" if gwd else "" %} +{{- wdesc + vdesc + gwddesc }} +{%- endmacro %} + +{%- set locs_or_addrs_tensor = "ssd_row_addrs" if ssd else "lxu_cache_locations" %} +{%- set locs_or_addrs_type = "int64_t" if ssd else "int32_t" %} {%- if not dense and not nobag and not vbe %} #include "fbgemm_gpu/dispatch_macros.h" @@ -50,7 +58,7 @@ __launch_bounds__(kForwardMaxThreads) __global__ void {%- if is_index_select %} batch_index_select_dim0_codegen_forward_small_kernel( {%- else %} -{{ ddesc }}_embedding_nobag_codegen_forward_unweighted_small_kernel( +{{ mdesc }}_embedding_nobag_codegen_forward_unweighted_small_kernel( {%- endif %} const pta::PackedTensorAccessor64 dev_weights, {%- if not dense %} @@ -70,7 +78,7 @@ batch_index_select_dim0_codegen_forward_small_kernel( const pta::PackedTensorAccessor32 offsets, {%- endif %} {%- if not dense %} - const pta::PackedTensorAccessor32 lxu_cache_locations, + const pta::PackedTensorAccessor32<{{ locs_or_addrs_type }}, 1, at::RestrictPtrTraits> {{ locs_or_addrs_tensor }}, {%- endif %} {%- if is_index_select %} const pta::PackedTensorAccessor32 output_offsets, @@ -119,14 +127,15 @@ __global__ void split_embedding_codegen_forward_{{ wdesc }}_v2_kernel( {%- for nobag in ([True, False] if (not is_gwd) else [False]) %} {%- set ndesc = "_nobag" if nobag else "" %} {%- if is_valid_forward_config(nobag, weighted, vbe, is_index_select) %} -{%- set has_experimental = has_experimental_support(dense, nobag, vbe, is_index_select, is_rocm) %} +{%- set has_experimental = has_experimental_support(dense, nobag, vbe, is_index_select, is_rocm, ssd) %} {%- set is_gwd_kernel = is_gwd and is_valid_gwd_config( dense, nobag, vbe, - is_index_select) %} -{%- set gwddesc = "_gwd" if is_gwd_kernel else "" %} + is_index_select, + has_global_weight_decay_support=True, + ssd=ssd) %} template < typename emb_t, typename cache_t, @@ -144,7 +153,7 @@ __launch_bounds__(kForwardMaxThreads) __global__ void {%- if is_index_select %} batch_index_select_dim0_codegen_forward_kernel( {%- else %} -{{ ddesc }}_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ vdesc }}{{ gwddesc }}_kernel( +{{ mdesc }}_embedding{{ ndesc }}_codegen_forward_{{ get_desc_suffix(is_gwd_kernel) }}_kernel( {%- endif %} const pta::PackedTensorAccessor64 dev_weights, {%- if not dense %} @@ -177,7 +186,7 @@ batch_index_select_dim0_codegen_forward_kernel( pta::PackedTensorAccessor32, 1, at::RestrictPtrTraits> indice_weights, {%- endif %} {%- if not dense %} - const pta::PackedTensorAccessor32 lxu_cache_locations, + const pta::PackedTensorAccessor32<{{ locs_or_addrs_type }}, 1, at::RestrictPtrTraits> {{ locs_or_addrs_tensor }}, const int32_t* lxu_cache_conflict_misses, {%- endif %} {%- if is_index_select %} @@ -307,20 +316,22 @@ batch_index_select_dim0_codegen_forward_kernel( {%- for 
nobag in ([True, False] if (not is_gwd) else [False]) %} {%- set ndesc = "_nobag" if nobag else "" %} {%- if is_valid_forward_config(nobag, weighted, vbe, is_index_select) %} -{%- set has_experimental = has_experimental_support(dense, nobag, vbe, is_index_select, is_rocm) %} +{%- set has_experimental = has_experimental_support(dense, nobag, vbe, is_index_select, is_rocm, ssd) %} {#- /* Generate a separate cuda host to enable global weight decay using Jinja */ #} {%- set is_gwd_kernel = is_gwd and is_valid_gwd_config( dense, nobag, vbe, - is_index_select) %} -{%- set gwddesc = "_gwd" if is_gwd_kernel else "" %} + is_index_select, + has_global_weight_decay_support=True, + ssd=ssd) %} +{%- set desc_suffix = get_desc_suffix(is_gwd_kernel) %} Tensor {%- if is_index_select %} batch_index_select_dim0_codegen_forward_cuda( {%- else %} -{{ ddesc }}_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ vdesc }}{{ gwddesc }}_cuda( +{{ mdesc }}_embedding{{ ndesc }}_codegen_forward_{{ desc_suffix }}_cuda( {%- endif %} const Tensor& dev_weights, {%- if not dense %} @@ -351,7 +362,7 @@ batch_index_select_dim0_codegen_forward_cuda( const Tensor& indice_weights, {%- endif %} {%- if not dense %} - const Tensor& lxu_cache_locations, + const Tensor& {{ locs_or_addrs_tensor }}, const Tensor& uvm_cache_stats, {%- endif %} const int64_t output_dtype, @@ -391,7 +402,6 @@ batch_index_select_dim0_codegen_forward_cuda( const int64_t total_D = total_D_.guard_int(__FILE__, __LINE__); {%- endif %} - {%- if not nobag or is_index_select %} const int64_t max_D = max_D_.guard_int(__FILE__, __LINE__); {%- endif %} @@ -417,7 +427,7 @@ batch_index_select_dim0_codegen_forward_cuda( indice_weights, {%- endif %} {%- if not dense %} - lxu_cache_locations, + {{ locs_or_addrs_tensor }}, {%- endif %} {%- if vbe %} vbe_row_output_offsets, @@ -576,12 +586,17 @@ batch_index_select_dim0_codegen_forward_cuda( {%- set nobag_small_kernel = "batch_index_select_dim0_codegen_forward_small_kernel" if is_index_select else - "{}_embedding_nobag_codegen_forward_unweighted_small_kernel".format(ddesc) + "{}_embedding_nobag_codegen_forward_unweighted_small_kernel".format(mdesc) %} #ifdef FBGEMM_GPU_MEMCHECK const auto func_name = "{{ nobag_small_kernel }}"; #endif - {{ nobag_small_kernel }} + {{ nobag_small_kernel }}< + emb_t, + cache_t, + output_t, + int64_t, + kEmbeddingSize / 4> <<< div_round_up(total_B, kForwardMaxThreads / kWarpSize), dim3(kWarpSize, kForwardMaxThreads / kWarpSize), @@ -606,7 +621,7 @@ batch_index_select_dim0_codegen_forward_cuda( MAKE_PTA_WITH_NAME(func_name, offsets, int64_t, 1, 32), {%- endif %} {%- if not dense %} - MAKE_PTA_WITH_NAME(func_name, lxu_cache_locations, int32_t, 1, 32), + MAKE_PTA_WITH_NAME(func_name, {{ locs_or_addrs_tensor }}, {{ locs_or_addrs_type }}, 1, 32), {%- endif %} {%- if is_index_select %} MAKE_PTA_WITH_NAME(func_name, output_offsets, int64_t, 1, 32), @@ -617,6 +632,7 @@ batch_index_select_dim0_codegen_forward_cuda( MAKE_PTA_WITH_NAME(func_name, output, output_t, {{ "1" if is_index_select else "2" }}, 64) ); + C10_CUDA_KERNEL_LAUNCH_CHECK(); return; }); @@ -625,7 +641,7 @@ batch_index_select_dim0_codegen_forward_cuda( {%- set nobag_kernel = "batch_index_select_dim0_codegen_forward_kernel" if is_index_select else - "{}_embedding_nobag_codegen_forward_unweighted_kernel".format(ddesc) + "{}_embedding_nobag_codegen_forward_unweighted_kernel".format(mdesc) %} #ifdef FBGEMM_GPU_MEMCHECK const auto func_name = "{{ nobag_kernel }}"; @@ -661,7 +677,7 @@ batch_index_select_dim0_codegen_forward_cuda( 
MAKE_PTA_WITH_NAME(func_name, offsets, int64_t, 1, 32), {%- endif %} {%- if not dense %} - MAKE_PTA_WITH_NAME(func_name, lxu_cache_locations, int32_t, 1, 32), + MAKE_PTA_WITH_NAME(func_name, {{ locs_or_addrs_tensor }}, {{ locs_or_addrs_type }}, 1, 32), uvm_cache_stats.size(0) == 0 ? nullptr : (uvm_cache_stats.data_ptr() + uvm_cache_stats_index::num_conflict_unique_misses), @@ -674,7 +690,6 @@ batch_index_select_dim0_codegen_forward_cuda( {%- endif %} MAKE_PTA_WITH_NAME(func_name, output, output_t, {{ "1" if is_index_select else "2" }}, 64) ); - C10_CUDA_KERNEL_LAUNCH_CHECK(); return; }); @@ -692,7 +707,7 @@ batch_index_select_dim0_codegen_forward_cuda( {{ dispatcher }}(max_D, [&] { #ifdef FBGEMM_GPU_MEMCHECK - const auto func_name = "{{ ddesc }}_embedding_codegen_forward_{{ wdesc }}{{ vdesc }}{{ gwddesc }}_kernel"; + const auto func_name = "{{ mdesc }}_embedding_codegen_forward_{{ desc_suffix }}_kernel"; #endif // Other components in TBE (backward, backward_indice_weights) use // kFixedMaxVecsPerThread. Thus, the codegen generates @@ -700,7 +715,7 @@ batch_index_select_dim0_codegen_forward_cuda( // kMaxVecsPerThread and kFixedMaxVecsPerThread are the same // forward constexpr auto kMaxVecsPerThread = kFixedMaxVecsPerThread; - {{ ddesc }}_embedding_codegen_forward_{{ wdesc }}{{ vdesc }}{{ gwddesc }}_kernel + {{ mdesc }}_embedding_codegen_forward_{{ desc_suffix }}_kernel () + uvm_cache_stats_index::num_conflict_unique_misses), @@ -753,7 +768,6 @@ batch_index_select_dim0_codegen_forward_cuda( {%- endif %} // if not dense MAKE_PTA_WITH_NAME(func_name, output, output_t, 2, 64) ); - C10_CUDA_KERNEL_LAUNCH_CHECK(); {%- if vbe %} output = output.reshape({-1}); @@ -777,8 +791,10 @@ batch_index_select_dim0_codegen_forward_cuda( const uint32_t num_warps_per_threadblock = kForwardMaxThreads / kWarpSize; const auto kernel_func = - (use_lxu_cache ? split_embedding_codegen_forward_{{ wdesc }}_v2_kernel - : split_embedding_codegen_forward_{{ wdesc }}_v2_kernel); + (use_lxu_cache ? split_embedding_codegen_forward_{{ wdesc }}_v2_kernel< + emb_t, cache_t, output_t, int64_t, true> + : split_embedding_codegen_forward_{{ wdesc }}_v2_kernel< + emb_t, cache_t, output_t, int64_t, false>); kernel_func <<< @@ -823,8 +839,8 @@ batch_index_select_dim0_codegen_forward_cuda( {%- if not is_index_select %} TORCH_LIBRARY_FRAGMENT(fbgemm, m) { {%- set embedding_codegen_forward_op = - "{}_embedding{}_codegen_forward_{}{}{}_cuda".format( - ddesc, ndesc, wdesc, vdesc, gwddesc + "{}_embedding{}_codegen_forward_{}_cuda".format( + mdesc, ndesc, desc_suffix ) %} m.def("{{ embedding_codegen_forward_op }}(" @@ -851,7 +867,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { " Tensor indice_weights, " {%- endif %} {%- if not dense %} - " Tensor lxu_cache_locations, " + " Tensor {{ locs_or_addrs_tensor }}, " " Tensor uvm_cache_stats, " {%- endif %} " int output_dtype, " diff --git a/fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_device_kernel_template.cuh b/fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_device_kernel_template.cuh index d4cbb962a4..cb3977cabe 100644 --- a/fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_device_kernel_template.cuh +++ b/fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_device_kernel_template.cuh @@ -14,6 +14,11 @@ #define GROUP_REDUCE_ALL_SUM(val, ...) 
\ warpReduceAllSum<__VA_ARGS__, kThreadGroupSize>(val, shfl_sync_mask) +{%- set mdesc = "ssd" if ssd else "split" %} +{%- set locs_or_addrs_tensor = "ssd_row_addrs" if ssd else "lxu_cache_locations" %} +{%- set locs_or_addrs_type = "int64_t" if ssd else "int32_t" %} +{%- set locs_or_addrs_idx = "row_idx" if ssd else "cache_idx" %} + using namespace fbgemm_gpu; template < @@ -28,13 +33,13 @@ template < int32_t VEC_WIDTH, bool kUseVecBlocking > -DEVICE_INLINE void split_{{ optimizer }}_table_update_kernel( +DEVICE_INLINE void {{ mdesc }}_{{ optimizer }}_table_update_kernel( pta::PackedTensorAccessor64& dev_weights, pta::PackedTensorAccessor64& uvm_weights, pta::PackedTensorAccessor64& lxu_cache_weights, const pta::PackedTensorAccessor32& weights_placements, const pta::PackedTensorAccessor32& weights_offsets, - const pta::PackedTensorAccessor32& sorted_lxu_cache_locations, + const pta::PackedTensorAccessor32<{{ locs_or_addrs_type }}, 1, at::RestrictPtrTraits>& sorted_{{ locs_or_addrs_tensor }}, Vec4TAcc* grad_sum, Vec4TAcc* smem_grad_sum, Vec4TAcc* shared_weight_update_row, @@ -71,10 +76,15 @@ DEVICE_INLINE void split_{{ optimizer }}_table_update_kernel( weights = &uvm_weights[weights_offset + idx * D_emb]; } if (weights_placement == PlacementType::MANAGED_CACHING) { - const int32_t cache_idx = sorted_lxu_cache_locations[cache_loc_run_id]; - if (cache_idx != kCacheLocationMissing) { - cache_weights = &lxu_cache_weights[cache_idx][0]; + const auto {{ locs_or_addrs_idx }} = sorted_{{ locs_or_addrs_tensor }}[cache_loc_run_id]; + {%- if ssd %} + cache_weights = reinterpret_cast( + *reinterpret_cast(&{{ locs_or_addrs_idx }})); + {%- else %} + if ({{ locs_or_addrs_idx }} != kCacheLocationMissing) { + cache_weights = &lxu_cache_weights[{{ locs_or_addrs_idx }}][0]; } + {%- endif %} } {%- for tensor in args.split_tensors %} {{ args.split_tensor_types[tensor] }}* __restrict__ {{ tensor }}; diff --git a/fbgemm_gpu/codegen/training/python/__init__.template b/fbgemm_gpu/codegen/training/python/__init__.template index 42f49ee3c1..4b7c9a8122 100644 --- a/fbgemm_gpu/codegen/training/python/__init__.template +++ b/fbgemm_gpu/codegen/training/python/__init__.template @@ -6,14 +6,42 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
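One detail of the ssd branch in the {{ mdesc }}_{{ optimizer }}_table_update_kernel above is worth spelling out: in the SSD path, the 64-bit entries of sorted_ssd_row_addrs are not slot indices into lxu_cache_weights but the raw addresses of the staged rows, so the kernel recovers a row pointer by reinterpreting those bits. Below is a minimal host-side sketch of that round trip, illustrative only, assuming 64-bit pointers and using a plain float array in place of the real cache rows.

#include <cstdint>
#include <cstdio>

int main() {
  float staged_row[4] = {1.0f, 2.0f, 3.0f, 4.0f};
  // Pack the row's address into the int64_t slot that replaces the int32_t
  // cache location in the SSD path.
  int64_t row_addr = reinterpret_cast<int64_t>(&staged_row[0]);
  // Unpack it the same way the ssd branch of the kernel does.
  auto* cache_weights =
      reinterpret_cast<float*>(*reinterpret_cast<uint64_t*>(&row_addr));
  std::printf("%f\n", cache_weights[2]);  // prints 3.000000
  return 0;
}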
-import fbgemm_gpu.split_embedding_codegen_lookup_invokers.lookup_adagrad as lookup_adagrad # noqa: F401
-import fbgemm_gpu.split_embedding_codegen_lookup_invokers.lookup_adam as lookup_adam # noqa: F401
+{%- set all_optims = [
+    "adagrad",
+    "adam",
+    "lamb",
+    "lars_sgd",
+    "partial_rowwise_adam",
+    "partial_rowwise_lamb",
+    "rowwise_adagrad",
+    "rowwise_adagrad_with_counter",
+    "sgd",
+    "none"
+  ]
+%}
+
+{%- set ssd_optims = [
+    "sgd",
+    "rowwise_adagrad"
+  ]
+%}
+
+# All optimizers
 import fbgemm_gpu.split_embedding_codegen_lookup_invokers.lookup_args as lookup_args # noqa: F401
-import fbgemm_gpu.split_embedding_codegen_lookup_invokers.lookup_lamb as lookup_lamb # noqa: F401
-import fbgemm_gpu.split_embedding_codegen_lookup_invokers.lookup_lars_sgd as lookup_lars_sgd # noqa: F401
-import fbgemm_gpu.split_embedding_codegen_lookup_invokers.lookup_partial_rowwise_adam as lookup_partial_rowwise_adam # noqa: F401
-import fbgemm_gpu.split_embedding_codegen_lookup_invokers.lookup_partial_rowwise_lamb as lookup_partial_rowwise_lamb # noqa: F401
-import fbgemm_gpu.split_embedding_codegen_lookup_invokers.lookup_rowwise_adagrad as lookup_rowwise_adagrad # noqa: F401
-import fbgemm_gpu.split_embedding_codegen_lookup_invokers.lookup_rowwise_adagrad_with_counter as lookup_rowwise_adagrad_with_counter # noqa: F401
-import fbgemm_gpu.split_embedding_codegen_lookup_invokers.lookup_sgd as lookup_sgd # noqa: F401
-import fbgemm_gpu.split_embedding_codegen_lookup_invokers.lookup_none as lookup_none # noqa: F401
+{%- for optim in all_optims %}
+import fbgemm_gpu.split_embedding_codegen_lookup_invokers.lookup_{{ optim }} as lookup_{{optim}} # noqa: F401
+{%- endfor %}
+
+# SSD optimizers (putting them under try-except for BC as they are
+# experimental ops which can be removed/updated in the future)
+try:
+    import fbgemm_gpu.split_embedding_codegen_lookup_invokers.lookup_args_ssd as lookup_args_ssd
+    {%- for optim in ssd_optims %}
+    import fbgemm_gpu.split_embedding_codegen_lookup_invokers.lookup_{{ optim }}_ssd as lookup_{{ optim }}_ssd
+    {%- endfor %}
+except Exception:
+    import logging
+    logging.warning("fbgemm_gpu.split_embedding_codegen_lookup_invokers.lookup_args_ssd import failed")
+    {%- for optim in ssd_optims %}
+    logging.warning("fbgemm_gpu.split_embedding_codegen_lookup_invokers.lookup_{{ optim }}_ssd import failed")
+    {%- endfor %}
diff --git a/fbgemm_gpu/codegen/training/python/lookup_args.py b/fbgemm_gpu/codegen/training/python/lookup_args.template
similarity index 94%
rename from fbgemm_gpu/codegen/training/python/lookup_args.py
rename to fbgemm_gpu/codegen/training/python/lookup_args.template
index eb752ee6a0..357aad622a 100644
--- a/fbgemm_gpu/codegen/training/python/lookup_args.py
+++ b/fbgemm_gpu/codegen/training/python/lookup_args.template
@@ -7,7 +7,7 @@
 
 # pyre-strict
 
-from typing import NamedTuple, Optional
+from typing import NamedTuple, Optional, Dict
 
 import torch
 
@@ -46,6 +46,9 @@ class CommonArgs(NamedTuple):
     is_experimental: bool
     use_uniq_cache_locations_bwd: bool
     use_homogeneous_placements: bool
+    {%- if ssd %}
+    ssd_tensors: Dict[str, torch.Tensor]
+    {%- endif %}
 
 
 class OptimizerArgs(NamedTuple):
diff --git a/fbgemm_gpu/codegen/training/python/split_embedding_codegen_lookup_invoker.template b/fbgemm_gpu/codegen/training/python/split_embedding_codegen_lookup_invoker.template
index cedc0c41ce..2de471eea6 100644
--- a/fbgemm_gpu/codegen/training/python/split_embedding_codegen_lookup_invoker.template
+++ b/fbgemm_gpu/codegen/training/python/split_embedding_codegen_lookup_invoker.template
@@ -7,12
+7,14 @@ # pyre-ignore-all-errors +{%- set mdesc = "ssd" if ssd else "split" %} +{%- set sdesc = "_ssd" if ssd else "" %} + import torch {%- if is_experimental_optimizer %} import warnings {%- endif %} -from .lookup_args import * - +from .lookup_args{{ sdesc }} import * {%- if is_fbcode %} @@ -73,7 +75,7 @@ def invoke( ) {%- endif %} - {%- if has_cpu_support %} + {%- if has_cpu_support and not ssd %} if (common_args.host_weights.numel() > 0): return torch.ops.fbgemm.split_embedding_codegen_lookup_{{ optimizer }}_function_cpu( # common_args @@ -192,7 +194,19 @@ def invoke( {%- if has_gpu_support %} vbe_metadata = common_args.vbe_metadata - return torch.ops.fbgemm.split_embedding_codegen_lookup_{{ optimizer }}_function( + + {%- if ssd %} + ssd_tensors = [] + {%- for tensor in ssd_tensors %} + assert "{{ tensor }}" in common_args.ssd_tensors, ( + "{{ tensor }} must be in common_args.ssd_tensors. " + "Please check the backend version" + ) + ssd_tensors.append(common_args.ssd_tensors["{{ tensor }}"]) + {%- endfor %} + {%- endif %} + + return torch.ops.fbgemm.{{ mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function( # common_args {%- if not dense %} placeholder_autograd_tensor=common_args.placeholder_autograd_tensor, @@ -214,6 +228,9 @@ def invoke( feature_requires_grad=common_args.feature_requires_grad, lxu_cache_locations=common_args.lxu_cache_locations, uvm_cache_stats=common_args.uvm_cache_stats, + {%- if ssd %} + ssd_tensors=ssd_tensors, + {%- endif %} # VBE metadata B_offsets=vbe_metadata.B_offsets, vbe_output_offsets_feature_rank=vbe_metadata.output_offsets_feature_rank, diff --git a/fbgemm_gpu/fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py b/fbgemm_gpu/fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py index 64316bc6c8..46e9395f87 100644 --- a/fbgemm_gpu/fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +++ b/fbgemm_gpu/fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py @@ -17,7 +17,7 @@ import torch # usort:skip import fbgemm_gpu.split_embedding_codegen_lookup_invokers as invokers -from fbgemm_gpu.split_embedding_configs import SparseType +from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType, SparseType from fbgemm_gpu.split_table_batched_embeddings_ops_common import ( CacheAlgorithm, @@ -92,6 +92,7 @@ def __init__( ssd_cache_location: EmbeddingLocation = EmbeddingLocation.MANAGED, ssd_uniform_init_lower: float = -0.01, ssd_uniform_init_upper: float = 0.01, + optimizer: OptimType = OptimType.EXACT_ROWWISE_ADAGRAD, # General Optimizer args stochastic_rounding: bool = True, gradient_clipping: bool = False, @@ -240,6 +241,9 @@ def __init__( self.ssd_set_start = torch.cuda.Event() self.ssd_set_end = torch.cuda.Event() self.timesteps_prefetched: List[int] = [] + self.ssd_scratch_pads: List[Tuple[Tensor, Tensor, Tensor]] = [] + # TODO: add type annotation + self.ssd_prefetch_data = [] if weight_decay_mode == WeightDecayMode.COUNTER or counter_based_regularization: raise AssertionError( @@ -253,7 +257,7 @@ def __init__( ) cowclip_regularization = CowClipDefinition() - self.optimizer_args = invokers.lookup_args.OptimizerArgs( + self.optimizer_args = invokers.lookup_args_ssd.OptimizerArgs( stochastic_rounding=stochastic_rounding, gradient_clipping=gradient_clipping, max_gradient=max_gradient, @@ -299,9 +303,10 @@ def __init__( dtype=torch.int32, ), ) - weights_offsets = [0] + list( - itertools.accumulate([row * dim for (row, dim) in zip(rows, dims)]) - ) + # weights_offsets = [0] + list( + # itertools.accumulate([row * dim for (row, dim) in zip(rows, dims)]) 
+ # ) + weights_offsets = [0] * (len(rows) + 1) self.register_buffer( "weights_offsets", torch.tensor( @@ -350,8 +355,93 @@ def __init__( torch.zeros(0, device=self.current_device, dtype=torch.float) ) + # Register backward hook for evicting rows from a scratch pad to SSD + # post backward + self.placeholder_autograd_tensor.register_hook(self._evict_from_scratch_pad) + + assert optimizer in ( + OptimType.EXACT_SGD, + OptimType.EXACT_ROWWISE_ADAGRAD, + ), f"Optimizer {optimizer} is not supported by SSDTableBatchedEmbeddingBags" + self.optimizer = optimizer + + def to_pinned_cpu(self, t: torch.Tensor) -> torch.Tensor: + t_cpu = torch.empty(t.shape, pin_memory=True, dtype=t.dtype) + t_cpu.copy_(t, non_blocking=True) + return t_cpu + + def evict( + self, evicted_rows: Tensor, evicted_indices: Tensor, actions_count_cpu: Tensor + ) -> None: + """ + Evict data from the given input tensors to SSD via RocksDB + """ + with torch.cuda.stream(self.ssd_stream): + self.ssd_stream.wait_event(self.ssd_set_start) + evicted_rows_cpu = self.to_pinned_cpu(evicted_rows) + evicted_indices_cpu = self.to_pinned_cpu(evicted_indices) + evicted_rows.record_stream(self.ssd_stream) + evicted_indices.record_stream(self.ssd_stream) + self.ssd_db.set_cuda( + evicted_indices_cpu, evicted_rows_cpu, actions_count_cpu, self.timestep + ) + # TODO: is this needed? + # Need a way to synchronize + # actions_count_cpu.record_stream(self.ssd_stream) + self.ssd_stream.record_event(self.ssd_set_end) + + def _evict_from_scratch_pad(self, grad: Tensor) -> None: + assert len(self.ssd_scratch_pads) > 0, "There must be at least one scratch pad" + (inserted_rows_gpu, post_bwd_evicted_indices, actions_count_cpu) = ( + self.ssd_scratch_pads.pop(0) + ) + self.evict(inserted_rows_gpu, post_bwd_evicted_indices, actions_count_cpu) + + def _compute_cache_ptrs( + self, + linear_cache_indices: torch.Tensor, + assigned_cache_slots: torch.Tensor, + linear_index_inverse_indices: torch.Tensor, + unique_indices_count_cumsum: torch.Tensor, + cache_set_inverse_indices: torch.Tensor, + inserted_rows_gpu: torch.Tensor, + unique_indices_length: torch.Tensor, + inserted_indices: torch.Tensor, + actions_count_cpu: torch.Tensor, + ) -> torch.Tensor: + lxu_cache_locations = torch.ops.fbgemm.lxu_cache_lookup( + linear_cache_indices, + self.lxu_cache_state, + self.hash_size_cumsum[-1].item(), + ) + lxu_cache_ptrs, post_bwd_evicted_indices = ( + torch.ops.fbgemm.ssd_generate_row_addrs( + lxu_cache_locations, + assigned_cache_slots, + linear_index_inverse_indices, + unique_indices_count_cumsum, + cache_set_inverse_indices, + self.lxu_cache_weights, + inserted_rows_gpu, + unique_indices_length, + inserted_indices, + ) + ) + # Store scratch pad info for post backward eviction + self.ssd_scratch_pads.append( + (inserted_rows_gpu, post_bwd_evicted_indices, actions_count_cpu) + ) + assert lxu_cache_ptrs[lxu_cache_ptrs == 0].numel() == 0 + return ( + lxu_cache_ptrs, + inserted_rows_gpu, + post_bwd_evicted_indices, + actions_count_cpu, + ) + def prefetch(self, indices: Tensor, offsets: Tensor) -> Optional[Tensor]: (indices, offsets) = indices.long(), offsets.long() + linear_cache_indices = torch.ops.fbgemm.linearize_cache_indices( self.hash_size_cumsum, indices, @@ -364,10 +454,10 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> Optional[Tensor]: evicted_indices, assigned_cache_slots, actions_count_gpu, - _, - _, - _, - _, + linear_index_inverse_indices, + unique_indices_count_cumsum, + cache_set_inverse_indices, + unique_indices_length, ) = 
torch.ops.fbgemm.ssd_cache_populate_actions( linear_cache_indices, self.total_hash_size, @@ -377,15 +467,10 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> Optional[Tensor]: self.lru_state, ) - def to_pinned_cpu(t: torch.Tensor) -> torch.Tensor: - t_cpu = torch.empty(t.shape, pin_memory=True, dtype=t.dtype) - t_cpu.copy_(t, non_blocking=True) - return t_cpu - - actions_count_cpu = to_pinned_cpu(actions_count_gpu) + actions_count_cpu = self.to_pinned_cpu(actions_count_gpu) assigned_cache_slots = assigned_cache_slots.long() evicted_rows = self.lxu_cache_weights[ - assigned_cache_slots.clamp_(min=0).long(), : + assigned_cache_slots.clamp(min=0).long(), : ] inserted_rows = torch.empty( evicted_rows.shape, @@ -398,14 +483,13 @@ def to_pinned_cpu(t: torch.Tensor) -> torch.Tensor: # Ensure the previous iterations l3_db.set(..) has completed. current_stream.wait_event(self.ssd_set_end) self.ssd_db.get_cuda( - to_pinned_cpu(inserted_indices), inserted_rows, actions_count_cpu + self.to_pinned_cpu(inserted_indices), inserted_rows, actions_count_cpu ) current_stream.record_event(self.ssd_set_start) # TODO: T123943415 T123943414 this is a big copy that is (mostly) unnecessary with a decent cache hit rate. # Should we allocate on HBM? inserted_rows_gpu = inserted_rows.cuda(non_blocking=True) - # self.lxu_cache_weights[assigned_cache_slots, :] = inserted_rows.cuda(non_blocking=True) torch.ops.fbgemm.masked_index_put( self.lxu_cache_weights, assigned_cache_slots, @@ -413,20 +497,23 @@ def to_pinned_cpu(t: torch.Tensor) -> torch.Tensor: actions_count_gpu, ) - with torch.cuda.stream(self.ssd_stream): - self.ssd_stream.wait_event(self.ssd_set_start) - evicted_rows_cpu = to_pinned_cpu(evicted_rows) - evicted_indices_cpu = to_pinned_cpu(evicted_indices) - evicted_rows.record_stream(self.ssd_stream) - evicted_indices.record_stream(self.ssd_stream) - self.ssd_db.set_cuda( - evicted_indices_cpu, evicted_rows_cpu, actions_count_cpu, self.timestep + # Evict rows from cache to SSD + self.evict(evicted_rows, evicted_indices, actions_count_cpu) + + # TODO: keep only necessary tensors + self.ssd_prefetch_data.append( + ( + linear_cache_indices, + assigned_cache_slots, + linear_index_inverse_indices, + unique_indices_count_cumsum, + cache_set_inverse_indices, + inserted_rows_gpu, + unique_indices_length, + inserted_indices, + actions_count_cpu, ) - # TODO: is this needed? 
- # Need a way to synchronize - # actions_count_cpu.record_stream(self.ssd_stream) - self.ssd_stream.record_event(self.ssd_set_end) - return linear_cache_indices + ) def forward( self, @@ -441,19 +528,18 @@ def forward( per_sample_weights = per_sample_weights.float() if len(self.timesteps_prefetched) == 0: with record_function("## prefetch ##"): - linear_cache_indices = self.prefetch(indices, offsets) - else: - linear_cache_indices = torch.ops.fbgemm.linearize_cache_indices( - self.hash_size_cumsum, - indices, - offsets, - ) - lxu_cache_locations = torch.ops.fbgemm.lxu_cache_lookup( - linear_cache_indices, - self.lxu_cache_state, - self.hash_size_cumsum[-1].item(), - ) - common_args = invokers.lookup_args.CommonArgs( + self.prefetch(indices, offsets) + assert len(self.ssd_prefetch_data) > 0 + + prefetch_data = self.ssd_prefetch_data.pop(0) + ( + lxu_cache_ptrs, + inserted_rows_gpu, + post_bwd_evicted_indices, + actions_count_cpu, + ) = self._compute_cache_ptrs(*prefetch_data) + + common_args = invokers.lookup_args_ssd.CommonArgs( placeholder_autograd_tensor=self.placeholder_autograd_tensor, output_dtype=SparseType.FP32.as_int(), dev_weights=self.weights_dev, @@ -472,9 +558,9 @@ def forward( pooling_mode=self.pooling_mode, indice_weights=per_sample_weights, feature_requires_grad=feature_requires_grad, - lxu_cache_locations=lxu_cache_locations, + lxu_cache_locations=lxu_cache_ptrs, uvm_cache_stats=None, - vbe_metadata=invokers.lookup_args.VBEMetadata( + vbe_metadata=invokers.lookup_args_ssd.VBEMetadata( B_offsets=None, output_offsets_feature_rank=None, B_offsets_rank_per_feature=None, @@ -486,9 +572,22 @@ def forward( is_experimental=False, use_uniq_cache_locations_bwd=False, use_homogeneous_placements=True, + # The keys for ssd_tensors are controlled by ssd_tensors in + # codegen/genscript/optimizer_args.py + ssd_tensors={ + "row_addrs": lxu_cache_ptrs, + "inserted_rows": inserted_rows_gpu, + "post_bwd_evicted_indices": post_bwd_evicted_indices, + "actions_count": actions_count_cpu, + }, ) - momentum1 = invokers.lookup_args.Momentum( + self.timesteps_prefetched.pop(0) + + if self.optimizer == OptimType.EXACT_SGD: + return invokers.lookup_sgd_ssd.invoke(common_args, self.optimizer_args) + + momentum1 = invokers.lookup_args_ssd.Momentum( dev=self.momentum1_dev, host=self.momentum1_host, uvm=self.momentum1_uvm, @@ -496,10 +595,10 @@ def forward( placements=self.momentum1_placements, ) - self.timesteps_prefetched.pop(0) - return invokers.lookup_rowwise_adagrad.invoke( - common_args, self.optimizer_args, momentum1 - ) + if self.optimizer == OptimType.EXACT_ROWWISE_ADAGRAD: + return invokers.lookup_rowwise_adagrad_ssd.invoke( + common_args, self.optimizer_args, momentum1 + ) @torch.jit.ignore def debug_split_optimizer_states(self) -> List[Tuple[torch.Tensor]]: @@ -890,7 +989,6 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> Tensor: 1, # for now assume prefetch_dist == 1 self.lru_state, ) - actions_count_cpu = torch.empty( actions_count_gpu.shape, pin_memory=True, dtype=actions_count_gpu.dtype ) diff --git a/fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh b/fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh index 46fd14c9e5..ab68ee8da8 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh @@ -1279,15 +1279,12 @@ struct WeightRow { template struct WeightRowAccessor { - emb_t* row_; - cache_t* cache_row_; - int dim_; + const emb_t* row_; + const cache_t* cache_row_; + const int dim_; - DEVICE_INLINE 
WeightRowAccessor( - emb_t* row, - cache_t* cache_row, - int dim, - StochasticRoundingRNGState* stoc_rounding_state) + DEVICE_INLINE + WeightRowAccessor(const emb_t* row, const cache_t* cache_row, const int dim) : row_(row), cache_row_(cache_row), dim_(dim) {} DEVICE_INLINE Vec4T load(const int32_t d, const float2 qparams) const { diff --git a/fbgemm_gpu/test/ssd_split_table_batched_embeddings_test.py b/fbgemm_gpu/test/ssd_split_table_batched_embeddings_test.py index 5f3c7ae337..c0f49ade8a 100644 --- a/fbgemm_gpu/test/ssd_split_table_batched_embeddings_test.py +++ b/fbgemm_gpu/test/ssd_split_table_batched_embeddings_test.py @@ -16,7 +16,7 @@ import numpy as np import torch -from fbgemm_gpu.split_embedding_configs import SparseType +from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType, SparseType from fbgemm_gpu.split_embedding_utils import ( b_indices, fake_quantize_embs, @@ -34,7 +34,7 @@ SSDTableBatchedEmbeddingBags, ) -from hypothesis import given, settings, Verbosity +from hypothesis import assume, given, settings, Verbosity MAX_EXAMPLES = 40 @@ -116,6 +116,9 @@ def generate_ssd_tbes( lr: float = 0.01, # from SSDTableBatchedEmbeddingBags eps: float = 1.0e-8, # from SSDTableBatchedEmbeddingBags ssd_shards: int = 1, # from SSDTableBatchedEmbeddingBags + optimizer: OptimType = OptimType.EXACT_ROWWISE_ADAGRAD, + cache_set_scale: float = 1.0, + pooling_mode: bool = PoolingMode.SUM, ) -> Tuple[SSDTableBatchedEmbeddingBags, List[torch.nn.EmbeddingBag]]: """ Generate embedding modules (i,e., SSDTableBatchedEmbeddingBags and @@ -130,23 +133,45 @@ def generate_ssd_tbes( Es = [E] * T feature_table_map = list(range(T)) + if pooling_mode == PoolingMode.SUM: + mode = "sum" + do_pooling = True + elif pooling_mode == PoolingMode.MEAN: + mode = "mean" + do_pooling = True + elif pooling_mode == PoolingMode.NONE: + mode = "sum" + do_pooling = False + else: + # This proves that we have exhaustively checked all PoolingModes + raise RuntimeError("Unknown PoolingMode!") + # Generate torch EmbeddingBag - emb_ref = [ - torch.nn.EmbeddingBag(E, D, mode="sum", sparse=True).cuda() - for (E, D) in zip(Es, Ds) - ] + if do_pooling: + emb_ref = [ + torch.nn.EmbeddingBag(E, D, mode=mode, sparse=True).cuda() + for (E, D) in zip(Es, Ds) + ] + else: + emb_ref = [ + torch.nn.Embedding(E, D, sparse=True).cuda() for (E, D) in zip(Es, Ds) + ] + + cache_sets = max(int(max(T * B * L, 1) * cache_set_scale), 1) # Generate TBE SSD emb = SSDTableBatchedEmbeddingBags( embedding_specs=[(E, D) for (E, D) in zip(Es, Ds)], feature_table_map=feature_table_map, ssd_storage_directory=tempfile.mkdtemp(), - cache_sets=max(T * B * L, 1), + cache_sets=cache_sets, ssd_uniform_init_lower=-0.1, ssd_uniform_init_upper=0.1, learning_rate=lr, eps=eps, ssd_shards=ssd_shards, + optimizer=optimizer, + pooling_mode=pooling_mode, ).cuda() # Initialize TBE SSD weights @@ -160,6 +185,17 @@ def generate_ssd_tbes( return emb, emb_ref + def concat_ref_tensors( + self, + tensors: List[torch.Tensor], + do_pooling: bool, + B: int, + D: int, + ) -> torch.Tensor: + if do_pooling: + return torch.cat([t.view(B, -1) for t in tensors], dim=1) + return torch.cat(tensors, dim=0).view(-1, D) + def execute_ssd_forward_( self, emb: SSDTableBatchedEmbeddingBags, @@ -172,23 +208,40 @@ def execute_ssd_forward_( B: int, L: int, weighted: bool, + i: int = 0, ) -> Tuple[List[torch.Tensor], torch.Tensor]: """ Execute the forward functions of SSDTableBatchedEmbeddingBags and torch.nn.EmbeddingBag and compare outputs """ + assert len(emb_ref) == 
len(indices_list) + do_pooling = emb.pooling_mode != PoolingMode.NONE # Execute torch EmbeddingBag forward output_ref_list = ( - [b_indices(emb_, indices) for (emb_, indices) in zip(emb_ref, indices_list)] + [ + b_indices(emb_, indices, do_pooling=do_pooling) + for (emb_, indices) in zip(emb_ref, indices_list) + ] if not weighted else [ - b_indices(emb_, indices, per_sample_weights=per_sample_weights.view(-1)) + b_indices( + emb_, + indices, + per_sample_weights=per_sample_weights.view(-1), + do_pooling=do_pooling, + ) for (emb_, indices, per_sample_weights) in zip( emb_ref, indices_list, per_sample_weights_list ) ] ) - output_ref = torch.cat([out.view(B, -1) for out in output_ref_list], dim=1) + + output_ref = self.concat_ref_tensors( + output_ref_list, + do_pooling, + B, + emb.embedding_specs[0][1], + ) # Execute TBE SSD forward output = ( @@ -213,11 +266,25 @@ def execute_ssd_forward_( log_E=st.integers(min_value=3, max_value=5), L=st.integers(min_value=0, max_value=20), weighted=st.booleans(), + cache_set_scale=st.sampled_from([0.0, 0.005, 1]), + pooling_mode=st.sampled_from( + [PoolingMode.NONE, PoolingMode.SUM, PoolingMode.MEAN] + ), ) @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) def test_ssd_forward( - self, T: int, D: int, B: int, log_E: int, L: int, weighted: bool + self, + T: int, + D: int, + B: int, + log_E: int, + L: int, + weighted: bool, + cache_set_scale: float, + pooling_mode: PoolingMode, ) -> None: + assume(not weighted or pooling_mode == PoolingMode.SUM) + # Generate embedding modules ( emb, @@ -229,6 +296,8 @@ def test_ssd_forward( log_E, L, weighted, + cache_set_scale=cache_set_scale, + pooling_mode=pooling_mode, ) # Generate inputs @@ -262,11 +331,25 @@ def test_ssd_forward( log_E=st.integers(min_value=3, max_value=5), L=st.integers(min_value=0, max_value=20), weighted=st.booleans(), + cache_set_scale=st.sampled_from([0.0, 0.005, 1]), + pooling_mode=st.sampled_from( + [PoolingMode.NONE, PoolingMode.SUM, PoolingMode.MEAN] + ), ) @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) def test_ssd_backward_adagrad( - self, T: int, D: int, B: int, log_E: int, L: int, weighted: bool + self, + T: int, + D: int, + B: int, + log_E: int, + L: int, + weighted: bool, + cache_set_scale: float, + pooling_mode: PoolingMode, ) -> None: + assume(not weighted or pooling_mode == PoolingMode.SUM) + # Constants lr = 0.5 eps = 0.2 @@ -286,6 +369,8 @@ def test_ssd_backward_adagrad( lr=lr, eps=eps, ssd_shards=ssd_shards, + cache_set_scale=cache_set_scale, + pooling_mode=pooling_mode, ) Es = [emb.embedding_specs[t][0] for t in range(T)] @@ -317,10 +402,17 @@ def test_ssd_backward_adagrad( # Execute torch EmbeddingBag backward [out.backward(grad) for (out, grad) in zip(output_ref_list, output_grad_list)] - # Execute TBE SSD backward - output.backward( - torch.cat([grad.view(B, -1) for grad in output_grad_list], dim=1) + do_pooling = pooling_mode != PoolingMode.NONE + grad_test = self.concat_ref_tensors( + output_grad_list, + do_pooling, + B, + D * 4, ) + + # Execute TBE SSD backward + output.backward(grad_test) + # Compare optimizer states split_optimizer_states = [s for (s,) in emb.debug_split_optimizer_states()] for t in range(T): @@ -359,18 +451,31 @@ def test_ssd_backward_adagrad( log_E=st.integers(min_value=3, max_value=5), L=st.integers(min_value=0, max_value=20), weighted=st.booleans(), + cache_set_scale=st.sampled_from([0.0, 0.005, 1]), + pooling_mode=st.sampled_from( + [PoolingMode.NONE, PoolingMode.SUM, PoolingMode.MEAN] 
+ ), ) @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) def test_ssd_cache( - self, T: int, D: int, B: int, log_E: int, L: int, weighted: bool + self, + T: int, + D: int, + B: int, + log_E: int, + L: int, + weighted: bool, + cache_set_scale: float, + pooling_mode: PoolingMode, ) -> None: - # T=2 - # D=2 - # B=9 - # log_E=3 - # L=14 - # weighted=False + assume(not weighted or pooling_mode == PoolingMode.SUM) + + lr = 0.5 + eps = 0.2 + ssd_shards = 2 torch.manual_seed(42) + optimizer = OptimType.EXACT_ROWWISE_ADAGRAD + # Generate embedding modules ( emb, @@ -382,8 +487,20 @@ def test_ssd_cache( log_E, L, weighted, + lr=lr, + eps=eps, + ssd_shards=ssd_shards, + optimizer=optimizer, + cache_set_scale=cache_set_scale, + pooling_mode=pooling_mode, ) + optimizer_states_ref = [] + if optimizer == OptimType.EXACT_ROWWISE_ADAGRAD: + optimizer_states_ref = [ + s.clone().float() for (s,) in emb.debug_split_optimizer_states() + ] + Es = [emb.embedding_specs[t][0] for t in range(T)] for i in range(10): @@ -422,7 +539,6 @@ def test_ssd_cache( 0, # prefetch_dist emb.lru_state, ) - assert actions_count_gpu.item() == 0 lxu_cache_locations = torch.ops.fbgemm.lxu_cache_lookup( linear_cache_indices, @@ -435,18 +551,8 @@ def test_ssd_cache( NOT_FOUND = np.iinfo(np.int32).max ASSOC = 32 - for loc, linear_idx in zip( - lxu_cache_locations.cpu().numpy().tolist(), - linear_cache_indices.cpu().numpy().tolist(), - ): - assert loc != NOT_FOUND - # if we have a hit, check the cache is consistent - loc_set = loc // ASSOC - loc_slot = loc % ASSOC - assert lru_state_cpu[loc_set, loc_slot] == emb.timestep - assert lxu_cache_state_cpu[loc_set, loc_slot] == linear_idx - - self.execute_ssd_forward_( + # Execute forward + output_ref_list, output = self.execute_ssd_forward_( emb, emb_ref, indices_list, @@ -457,6 +563,78 @@ def test_ssd_cache( B, L, weighted, + i=i, + ) + + # Generate output gradient + output_grad_list = [torch.randn_like(out) for out in output_ref_list] + + # Execute torch EmbeddingBag backward + for t, (out, grad) in enumerate(zip(output_ref_list, output_grad_list)): + # Zero out weight grad + emb_ref[t].weight.grad = None + out.backward(grad) + + do_pooling = pooling_mode != PoolingMode.NONE + grad_test = self.concat_ref_tensors( + output_grad_list, + do_pooling, + B, + D * 4, + ) + + # Execute TBE SSD backward + output.backward(grad_test) + + if optimizer == OptimType.EXACT_SGD: + # Update embedding weight + for t in range(T): + emb_ref[t].weight.data.copy_( + torch.add( + # Perform accumulation in float + emb_ref[t].weight.data.float(), + emb_ref[t].weight.grad.float(), + alpha=lr * -1, + ) + ) + else: + # Compare optimizer states + split_optimizer_states = [ + s for (s,) in emb.debug_split_optimizer_states() + ] + for t in range(T): + # pyre-fixme[16]: Optional type has no attribute `float`. + optimizer_states_ref[t].add_( + emb_ref[t].weight.grad.float().to_dense().pow(2).mean(dim=1) + ) + torch.testing.assert_close( + split_optimizer_states[t].float(), + optimizer_states_ref[t], + atol=1.0e-4, + rtol=1.0e-4, + ) + + emb_ref[t].weight.data.copy_( + torch.addcdiv( + emb_ref[t].weight.float(), + value=-lr, + tensor1=emb_ref[t].weight.grad.float().to_dense(), + tensor2=split_optimizer_states[t] + .float() + .sqrt() + .add(eps) + .view(Es[t], 1), + ) + ) + + # Compare weights + emb.flush() + for t in range(T): + torch.testing.assert_close( + emb.debug_split_embedding_weights()[t].float().cuda(), + emb_ref[t].weight.float(), + atol=1.0e-4, + rtol=1.0e-4, )
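
Note on the prefetch changes above: the local to_pinned_cpu closure is promoted to a method (self.to_pinned_cpu), presumably so the same staging logic can be reused by the new evict path. A minimal standalone sketch of that helper, assuming a CUDA-enabled PyTorch build (the free-standing function name here is illustrative only):

    import torch

    def to_pinned_cpu(t: torch.Tensor) -> torch.Tensor:
        # Allocate page-locked host memory and issue an asynchronous
        # device-to-host copy; the SSD backend (ssd_db.get_cuda / set_cuda)
        # can then consume the tensor without a full device synchronization.
        t_cpu = torch.empty(t.shape, pin_memory=True, dtype=t.dtype)
        t_cpu.copy_(t, non_blocking=True)
        return t_cpu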
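
In the test changes, generate_ssd_tbes now sizes the LXU cache as cache_sets = max(int(max(T * B * L, 1) * cache_set_scale), 1). Scales well below 1.0 shrink the cache below the working set, which lets the tests exercise the cache conflict misses this patch adds support for. A small sketch of that sizing rule with hypothetical inputs:

    def scaled_cache_sets(T: int, B: int, L: int, cache_set_scale: float) -> int:
        # Mirrors the computation in generate_ssd_tbes; at least one cache set
        # is always kept so the module can still be constructed.
        return max(int(max(T * B * L, 1) * cache_set_scale), 1)

    # With the smallest sampled scales (0.0 and 0.005) the cache collapses to a
    # single set, so most lookups become conflict misses.
    assert scaled_cache_sets(T=4, B=8, L=10, cache_set_scale=0.005) == 1
    assert scaled_cache_sets(T=4, B=8, L=10, cache_set_scale=1.0) == 320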
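
The new concat_ref_tensors helper captures the two reference-output layouts the tests now cover: pooled modes (SUM/MEAN) concatenate per-table (B, D) outputs feature-wise into (B, total_D), while PoolingMode.NONE stacks per-table sequence outputs row-wise into (N, D). A tiny shape check with hypothetical sizes:

    import torch

    B, D, T = 3, 4, 2

    # Pooled: per-table (B, D) outputs concatenated along the feature dimension.
    pooled = [torch.randn(B, D) for _ in range(T)]
    assert torch.cat([t.view(B, -1) for t in pooled], dim=1).shape == (B, T * D)

    # Sequence (PoolingMode.NONE): per-table outputs stacked along dim 0.
    seq = [torch.randn(5, D), torch.randn(7, D)]
    assert torch.cat(seq, dim=0).view(-1, D).shape == (5 + 7, D)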
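
test_ssd_cache now replays the optimizer update on a dense host-side reference: for EXACT_ROWWISE_ADAGRAD it accumulates the per-row mean of squared gradients and applies an addcdiv step, and for EXACT_SGD it applies weight -= lr * grad with float accumulation. A self-contained sketch of those reference updates (function names and sizes below are illustrative):

    import torch

    def rowwise_adagrad_ref(
        weight: torch.Tensor,  # (E, D) dense reference table
        grad: torch.Tensor,    # (E, D) dense gradient
        state: torch.Tensor,   # (E,) running sum of row-wise mean squared grads
        lr: float,
        eps: float,
    ) -> None:
        # state += mean(grad^2) per row, matching
        # grad.float().to_dense().pow(2).mean(dim=1) in the test.
        state.add_(grad.pow(2).mean(dim=1))
        # weight -= lr * grad / (sqrt(state) + eps), broadcast row-wise.
        weight.addcdiv_(grad, state.sqrt().add(eps).unsqueeze(1), value=-lr)

    def sgd_ref(weight: torch.Tensor, grad: torch.Tensor, lr: float) -> None:
        # Matches the EXACT_SGD branch: weight -= lr * grad.
        weight.add_(grad, alpha=-lr)

    # Hypothetical sizes for a quick smoke test.
    E, D, lr, eps = 4, 8, 0.5, 0.2
    w, g, s = torch.randn(E, D), torch.randn(E, D), torch.zeros(E)
    rowwise_adagrad_ref(w, g, s, lr, eps)
    sgd_ref(w, g, lr)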