diff --git a/.github/scripts/fbgemm_gpu_build.bash b/.github/scripts/fbgemm_gpu_build.bash
index 381efc2fbc..58c5e705b2 100644
--- a/.github/scripts/fbgemm_gpu_build.bash
+++ b/.github/scripts/fbgemm_gpu_build.bash
@@ -347,6 +347,7 @@ build_fbgemm_gpu_package () {
     --package_name="${package_name}" \
     --python-tag="${python_tag}" \
     --plat-name="${plat_name}" \
+    --verbose \
     "${build_args[@]}"

   # Run checks on the built libraries
diff --git a/.github/workflows/fbgemm_ci.yml b/.github/workflows/fbgemm_ci.yml
index 00ebfa358d..19b515c031 100644
--- a/.github/workflows/fbgemm_ci.yml
+++ b/.github/workflows/fbgemm_ci.yml
@@ -48,7 +48,7 @@ jobs:
         git config --global --add safe.directory '*'

     - name: Checkout the Repository
-      uses: actions/checkout@v3
+      uses: actions/checkout@v4
       with:
         submodules: true

@@ -86,7 +86,7 @@ jobs:
     steps:
     - name: Checkout the Repository
-      uses: actions/checkout@v3
+      uses: actions/checkout@v4
       with:
         submodules: true

@@ -127,7 +127,7 @@ jobs:
         git config --global --add safe.directory '*'

     - name: Checkout the Repository
-      uses: actions/checkout@v3
+      uses: actions/checkout@v4
       with:
         submodules: true

@@ -159,7 +159,7 @@ jobs:
     steps:
     - name: Checkout the Repository
-      uses: actions/checkout@v3
+      uses: actions/checkout@v4
       with:
         submodules: true

diff --git a/.github/workflows/fbgemm_gpu_ci.yml b/.github/workflows/fbgemm_gpu_ci.yml
index ba89c57329..43006e5a3e 100644
--- a/.github/workflows/fbgemm_gpu_ci.yml
+++ b/.github/workflows/fbgemm_gpu_ci.yml
@@ -57,7 +57,7 @@ jobs:
         git config --global --add safe.directory '*'

     - name: Checkout the Repository
-      uses: actions/checkout@v3
+      uses: actions/checkout@v4
       with:
         submodules: true

@@ -126,7 +126,7 @@ jobs:
         git config --global --add safe.directory '*'

     - name: Checkout the Repository
-      uses: actions/checkout@v3
+      uses: actions/checkout@v4
       with:
         submodules: true

@@ -191,7 +191,7 @@ jobs:
         git config --global --add safe.directory '*'

     - name: Checkout the Repository
-      uses: actions/checkout@v3
+      uses: actions/checkout@v4
       with:
         submodules: true

diff --git a/.github/workflows/fbgemm_gpu_cpu_nightly.yml b/.github/workflows/fbgemm_gpu_cpu_nightly.yml
index 5501ee89e3..8c5efd66fe 100644
--- a/.github/workflows/fbgemm_gpu_cpu_nightly.yml
+++ b/.github/workflows/fbgemm_gpu_cpu_nightly.yml
@@ -71,7 +71,7 @@ jobs:
       run: yum update -y; yum install -y binutils findutils git pciutils sudo wget which

     - name: Checkout the Repository
-      uses: actions/checkout@v3
+      uses: actions/checkout@v4
       with:
         submodules: true

@@ -136,7 +136,7 @@ jobs:
       run: yum update -y; yum install -y binutils findutils git pciutils sudo wget which

     - name: Checkout the Repository
-      uses: actions/checkout@v3
+      uses: actions/checkout@v4
       with:
         submodules: true

diff --git a/.github/workflows/fbgemm_gpu_cpu_release.yml b/.github/workflows/fbgemm_gpu_cpu_release.yml
index 1300b4781f..aba87df783 100644
--- a/.github/workflows/fbgemm_gpu_cpu_release.yml
+++ b/.github/workflows/fbgemm_gpu_cpu_release.yml
@@ -68,7 +68,7 @@ jobs:
       run: yum update -y; yum install -y binutils findutils git pciutils sudo wget which

     - name: Checkout the Repository
-      uses: actions/checkout@v3
+      uses: actions/checkout@v4
       with:
         submodules: true

@@ -133,7 +133,7 @@ jobs:
       run: yum update -y; yum install -y binutils findutils git pciutils sudo wget which

     - name: Checkout the Repository
-      uses: actions/checkout@v3
+      uses: actions/checkout@v4
       with:
         submodules: true

diff --git a/.github/workflows/fbgemm_gpu_cuda_nightly.yml b/.github/workflows/fbgemm_gpu_cuda_nightly.yml
index be1fc32fc4..f5ed26aec3 100644
--- a/.github/workflows/fbgemm_gpu_cuda_nightly.yml
+++ b/.github/workflows/fbgemm_gpu_cuda_nightly.yml
@@ -70,7 +70,7 @@ jobs:
       run: yum update -y; yum install -y binutils findutils git pciutils sudo tar wget which

     - name: Checkout the Repository
-      uses: actions/checkout@v3
+      uses: actions/checkout@v4
       with:
         submodules: true

@@ -140,6 +140,7 @@ jobs:
     needs: build_artifact

     steps:
+    # Cannot upgrade to actions/checkout@v4 yet because GLIBC on the instance is too old
     - name: Checkout the Repository
       uses: actions/checkout@v3
       with:
diff --git a/.github/workflows/fbgemm_gpu_cuda_release.yml b/.github/workflows/fbgemm_gpu_cuda_release.yml
index 838be62996..74b79a88dc 100644
--- a/.github/workflows/fbgemm_gpu_cuda_release.yml
+++ b/.github/workflows/fbgemm_gpu_cuda_release.yml
@@ -74,7 +74,7 @@ jobs:
      run: yum update -y; yum install -y binutils findutils git pciutils sudo tar wget which

     - name: Checkout the Repository
-      uses: actions/checkout@v3
+      uses: actions/checkout@v4
       with:
         submodules: true

diff --git a/.github/workflows/fbgemm_gpu_docs.yml b/.github/workflows/fbgemm_gpu_docs.yml
index bcc6a93d2a..d3a69bca5b 100644
--- a/.github/workflows/fbgemm_gpu_docs.yml
+++ b/.github/workflows/fbgemm_gpu_docs.yml
@@ -44,7 +44,7 @@ jobs:
      run: yum update -y; yum install -y binutils findutils git pciutils rsync sudo tar wget which

     - name: Checkout the Repository
-      uses: actions/checkout@v3
+      uses: actions/checkout@v4
       with:
         submodules: true

diff --git a/.github/workflows/fbgemm_gpu_lint.yml b/.github/workflows/fbgemm_gpu_lint.yml
index 41c8a08967..3dccceacca 100644
--- a/.github/workflows/fbgemm_gpu_lint.yml
+++ b/.github/workflows/fbgemm_gpu_lint.yml
@@ -39,7 +39,7 @@ jobs:
     steps:
     - name: Checkout the Repository
-      uses: actions/checkout@v3
+      uses: actions/checkout@v4

     - name: Setup Miniconda
       run: . $PRELUDE; setup_miniconda $HOME/miniconda
diff --git a/.github/workflows/fbgemm_gpu_pip.yml b/.github/workflows/fbgemm_gpu_pip.yml
index f0858138fb..d9f6dc2ff6 100644
--- a/.github/workflows/fbgemm_gpu_pip.yml
+++ b/.github/workflows/fbgemm_gpu_pip.yml
@@ -66,7 +66,7 @@ jobs:
      run: yum update -y; yum install -y binutils findutils git pciutils sudo wget which

     - name: Checkout the Repository
-      uses: actions/checkout@v3
+      uses: actions/checkout@v4

     - name: Display System Info
       run: . $PRELUDE; print_system_info; print_ec2_info
@@ -116,6 +116,7 @@ jobs:
         cuda-version-publish: [ "11.8.0" ]

     steps:
+    # Cannot upgrade to actions/checkout@v4 yet because GLIBC on the instance is too old
     - name: Checkout the Repository
       uses: actions/checkout@v3

@@ -182,7 +183,7 @@ jobs:
         git config --global --add safe.directory '*'

     - name: Checkout the Repository
-      uses: actions/checkout@v3
+      uses: actions/checkout@v4

     - name: Display System Info
       run: . $PRELUDE; print_system_info
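The checkout pins above hinge on the runner's C library: actions/checkout@v4 runs on a Node 20 runtime, which needs a newer glibc than the older container images some of these jobs use (hence the "GLIBC on the instance is too old" carve-outs that stay on v3). A minimal sketch, not part of the CI scripts and with the 2.28 threshold only an assumption, of how an instance could be probed before flipping the pin:

import platform

# Assumption for illustration: Node 20 based actions want glibc >= 2.28 on Linux.
MIN_GLIBC = (2, 28)

libc, version = platform.libc_ver()   # e.g. ("glibc", "2.17") on an older CentOS image
if libc == "glibc" and version:
    ok = tuple(int(x) for x in version.split(".")[:2]) >= MIN_GLIBC
    print(f"glibc {version}: {'can' if ok else 'cannot'} host actions/checkout@v4")
else:
    print("not glibc (or undetectable); keep the checkout@v3 pin to be safe")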
diff --git a/.gitignore b/.gitignore
index 746e08c73e..20163c8dc8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -8,6 +8,10 @@
 # found in:
 # https://github.com/github/gitignore/

+# General
+.DS_Store
+*~
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
diff --git a/fbgemm_gpu/CMakeLists.txt b/fbgemm_gpu/CMakeLists.txt
index 076961a827..67c36f125f 100644
--- a/fbgemm_gpu/CMakeLists.txt
+++ b/fbgemm_gpu/CMakeLists.txt
@@ -432,10 +432,22 @@ else()
     DEPENDS "${optimizer_codegen_dependencies}")
 endif()

+set(AVX2_FLAGS "-mavx2;-mf16c;-mfma;-fopenmp")
+if(NOT FBGEMM_CPU_ONLY AND WSL_MODE)
+  # NVCC in WSL complains about unknown -mavx options
+  # https://github.com/pytorch/FBGEMM/issues/2135
+  set(AVX2_FLAGS "-Xcompiler;-mavx;-Xcompiler;-mavx2;-Xcompiler;-mf16c;-Xcompiler;-mfma;-fopenmp")
+endif()
+
+set(AVX512_FLAGS "-mavx2;-mf16c;-mfma;-mavx512f;-mavx512bw;-mavx512dq;-mavx512vl;-fopenmp")
+if(NOT FBGEMM_CPU_ONLY AND WSL_MODE)
+  set(AVX512_FLAGS "-Xcompiler;-mavx2;-Xcompiler;-mf16c;-Xcompiler;-mfma;-Xcompiler;-mavx512f;-Xcompiler;-mavx512bw;-Xcompiler;-mavx512dq;-Xcompiler;-mavx512vl;-fopenmp")
+endif()
+
 if(CXX_AVX2_FOUND)
   set_source_files_properties(${gen_cpu_source_files}
     PROPERTIES COMPILE_OPTIONS
-    "-mavx2;-mf16c;-mfma;-fopenmp")
+    "${AVX2_FLAGS}")
 else()
   set_source_files_properties(${gen_cpu_source_files}
     PROPERTIES COMPILE_OPTIONS
@@ -504,13 +516,13 @@ set(fbgemm_sources_avx512
 if(CXX_AVX2_FOUND)
   set_source_files_properties(${fbgemm_sources_avx2}
     PROPERTIES COMPILE_OPTIONS
-    "-mavx2;-mf16c;-mfma")
+    "${AVX2_FLAGS}")
 endif()

 if(CXX_AVX512_FOUND)
   set_source_files_properties(${fbgemm_sources_avx512}
     PROPERTIES COMPILE_OPTIONS
-    "-mavx2;-mf16c;-mfma;-mavx512f;-mavx512bw;-mavx512dq;-mavx512vl")
+    "${AVX512_FLAGS}")
 endif()

 set(fbgemm_sources ${fbgemm_sources_normal})
@@ -561,19 +573,20 @@ set(fbgemm_gpu_sources_static_cpu
     codegen/embedding_forward_quantized_host_cpu.cpp
     codegen/embedding_backward_dense_host_cpu.cpp
     codegen/embedding_bounds_check_host_cpu.cpp
+    src/merge_pooled_embedding_ops/merge_pooled_embedding_ops_cpu.cpp
     src/permute_pooled_embedding_ops/permute_pooled_embedding_function.cpp
     src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_cpu.cpp
     src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_cpu.cpp
     src/jagged_tensor_ops/jagged_tensor_ops_autograd.cpp
     src/jagged_tensor_ops/jagged_tensor_ops_meta.cpp
     src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp
-    src/input_combine_cpu.cpp
-    src/layout_transform_ops_cpu.cpp
+    src/input_combine_ops/input_combine_cpu.cpp
+    src/layout_transform_ops/layout_transform_ops_cpu.cpp
     src/quantize_ops/quantize_ops_cpu.cpp
     src/quantize_ops/quantize_ops_meta.cpp
     src/sparse_ops/sparse_ops_cpu.cpp
     src/sparse_ops/sparse_ops_meta.cpp
-    src/embedding_inplace_update_cpu.cpp
+    src/embedding_inplace_ops/embedding_inplace_update_cpu.cpp
     src/split_embeddings_cache/linearize_cache_indices.cpp
     src/split_embeddings_cache/lfu_cache_populate_byte.cpp
     src/split_embeddings_cache/lru_cache_populate_byte.cpp
@@ -588,16 +601,16 @@ if(NOT FBGEMM_CPU_ONLY)
     codegen/embedding_bounds_check_host.cpp
     src/memory_utils/memory_utils.cpp
     src/memory_utils/memory_utils_ops.cpp
-    src/layout_transform_ops_gpu.cpp
+    src/layout_transform_ops/layout_transform_ops_gpu.cpp
     src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_gpu.cpp
     src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_gpu.cpp
     src/quantize_ops/quantize_ops_gpu.cpp
     src/sparse_ops/sparse_ops_gpu.cpp
-    src/split_embeddings_utils.cpp
+    src/split_embeddings_utils/split_embeddings_utils.cpp
     src/split_embeddings_cache/split_embeddings_cache_ops.cu
-    src/metric_ops_host.cpp
-    src/embedding_inplace_update_gpu.cpp
-    src/input_combine_gpu.cpp
+    src/metric_ops/metric_ops_host.cpp
+    src/embedding_inplace_ops/embedding_inplace_update_gpu.cpp
+    src/input_combine_ops/input_combine_gpu.cpp
     codegen/batch_index_select_dim0_host.cpp)

   if(NVML_LIB_PATH)
@@ -607,8 +620,7 @@ if(NOT FBGEMM_CPU_ONLY)
   if(NVML_LIB_PATH OR USE_ROCM)
     message(STATUS "Adding merge_pooled_embeddings sources")
     list(APPEND fbgemm_gpu_sources_static_cpu
-      src/merge_pooled_embeddings_cpu.cpp
-      src/merge_pooled_embeddings_gpu.cpp
+      src/merge_pooled_embedding_ops/merge_pooled_embedding_ops_gpu.cpp
       src/topology_utils.cpp)
   else()
     message(STATUS "Skipping merge_pooled_embeddings sources")
@@ -618,7 +630,7 @@ endif()
 if(CXX_AVX2_FOUND)
   set_source_files_properties(${fbgemm_gpu_sources_static_cpu}
     PROPERTIES COMPILE_OPTIONS
-    "-mavx;-mf16c;-mfma;-mavx2;-fopenmp")
+    "${AVX2_FLAGS}")
 else()
   set_source_files_properties(${fbgemm_gpu_sources_static_cpu}
     PROPERTIES COMPILE_OPTIONS
@@ -631,9 +643,9 @@ if(NOT FBGEMM_CPU_ONLY)
     codegen/embedding_forward_quantized_split_lookup.cu
     src/memory_utils/memory_utils.cu
     src/memory_utils/memory_utils_ops.cu
-    src/embedding_inplace_update.cu
+    src/embedding_inplace_ops/embedding_inplace_update.cu
     src/histogram_binning_calibration_ops.cu
-    src/input_combine.cu
+    src/input_combine_ops/input_combine.cu
     src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_backward.cu
     src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_forward.cu
     src/jagged_tensor_ops/dense_to_jagged_forward.cu
@@ -651,8 +663,8 @@ if(NOT FBGEMM_CPU_ONLY)
     src/jagged_tensor_ops/jagged_to_padded_dense_forward.cu
     src/jagged_tensor_ops/jagged_unique_indices.cu
     src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu
-    src/layout_transform_ops.cu
-    src/metric_ops.cu
+    src/layout_transform_ops/layout_transform_ops.cu
+    src/metric_ops/metric_ops.cu
     src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split.cu
     src/permute_pooled_embedding_ops/permute_pooled_embedding_ops.cu
     src/quantize_ops/quantize_bfloat16.cu
@@ -691,7 +703,10 @@ if(NOT FBGEMM_CPU_ONLY)
     src/split_embeddings_cache/lxu_cache.cu
     src/split_embeddings_cache/linearize_cache_indices.cu
     src/split_embeddings_cache/reset_weight_momentum.cu
-    src/split_embeddings_utils.cu)
+    src/split_embeddings_utils/generate_vbe_metadata.cu
+    src/split_embeddings_utils/get_infos_metadata.cu
+    src/split_embeddings_utils/radix_sort_pairs.cu
+    src/split_embeddings_utils/transpose_embedding_input.cu)

   set_source_files_properties(${fbgemm_gpu_sources_static_gpu}
     PROPERTIES COMPILE_OPTIONS
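The AVX2_FLAGS/AVX512_FLAGS indirection above exists because, in WSL_MODE GPU builds, some of these sources are compiled through nvcc, and nvcc rejects host-compiler SIMD switches such as -mavx2 unless they are forwarded with -Xcompiler (the linked issue #2135). A small Python sketch of that rewriting rule, for illustration only — the real build does this directly in CMake, and keeping -fopenmp bare simply mirrors the hunk above:

from typing import List

def nvcc_forwarded(flags: List[str], passthrough: str = "-fopenmp") -> List[str]:
    # Hypothetical helper: wrap host-only flags in -Xcompiler so nvcc hands them
    # to gcc/clang; leave the passthrough flag exactly as the CMake lists do.
    out: List[str] = []
    for flag in flags:
        if flag == passthrough:
            out.append(flag)                  # kept bare, as in AVX2_FLAGS above
        else:
            out.extend(["-Xcompiler", flag])  # e.g. -mavx2 -> -Xcompiler -mavx2
    return out

print(";".join(nvcc_forwarded(["-mavx", "-mavx2", "-mf16c", "-mfma", "-fopenmp"])))
# -Xcompiler;-mavx;-Xcompiler;-mavx2;-Xcompiler;-mf16c;-Xcompiler;-mfma;-fopenmp
# i.e. the WSL_MODE value of AVX2_FLAGS in the hunk above.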
diff --git a/fbgemm_gpu/bench/batched_unary_embeddings_benchmark.py b/fbgemm_gpu/bench/batched_unary_embeddings_benchmark.py
index d15e6b34ef..2ab90148b0 100644
--- a/fbgemm_gpu/bench/batched_unary_embeddings_benchmark.py
+++ b/fbgemm_gpu/bench/batched_unary_embeddings_benchmark.py
@@ -4,7 +4,6 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
-# pyre-unsafe

 import functools
 from math import sqrt
 from typing import List, Tuple
@@ -29,7 +28,10 @@ def generate_unary_feature(
-    batch_size: int, num_embeddings: int
+    batch_size: int,
+    num_embeddings: int
+    # pyre-fixme[24]: Generic type `list` expects 1 type parameter, use
+    # `typing.List[]` to avoid runtime subscripting errors.
 ) -> Tuple[List, List, List]:
     lengths = []
     offsets = []
@@ -90,6 +92,7 @@ def forward(
 @click.option("--num-tables", default=2)
 @click.option("--num-tasks", default=3)
 @click.option("--repeats", default=100)
+# pyre-fixme[2]: Parameter must be annotated.
 def main(batch_size, num_tables, num_tasks, repeats) -> None:
     device = torch.device("cuda", 0)
     torch.cuda.set_device(device)
diff --git a/fbgemm_gpu/bench/bench_utils.py b/fbgemm_gpu/bench/bench_utils.py
index 02ffda22f6..540b6e01ca 100644
--- a/fbgemm_gpu/bench/bench_utils.py
+++ b/fbgemm_gpu/bench/bench_utils.py
@@ -41,13 +41,13 @@ def benchmark_torch_function(  # noqa: C901
     copy_f_for_multi_thread_test: bool = False,
 ) -> Tuple[float, torch.Tensor]:
     logging.info(f"Start to benchmark {name}...")
-    if device != "" and device != "cuda":
+    if device != "cpu" and device != "" and device != "cuda":
         torch.cuda.set_device(device)
     for _ in range(num_warmups):
         output = f(*args)

     assert num_threads > 0
-    if torch.cuda.is_available() and (num_threads == 1):
+    if device != "cpu" and torch.cuda.is_available() and (num_threads == 1):
         cache = torch.empty(
             int(flush_gpu_cache_size_mb * 1024 * 1024 // 4),
             dtype=torch.float,
@@ -69,7 +69,7 @@ def benchmark_torch_function(  # noqa: C901
             [s.elapsed_time(e) for s, e in zip(start_event, end_event)]
         )
         elapsed_time = torch.mean(times).item() * 1.0e-3
-    elif torch.cuda.is_available() and (num_threads > 1):
+    elif device != "cpu" and torch.cuda.is_available() and (num_threads > 1):
         cache = torch.empty(
             int(flush_gpu_cache_size_mb * 1024 * 1024 // 4),
             dtype=torch.float,
@@ -156,6 +156,10 @@ def benchmark_requests(
 ) -> float:
     times = []

+    # Run at least one warmup iteration to avoid the long cudaLaunchKernel time
+    # for the first kernel
+    num_warmups = num_warmups + 1 if num_warmups >= 0 else 1
+
     if num_warmups > 0:
         indices, offsets, weights = requests[0]
         for _ in range(num_warmups):
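The benchmark_requests change above always forces at least one warmup call so that one-time costs (the first cudaLaunchKernel, allocator setup) stay out of the timed loop. A stripped-down sketch of that pattern, assuming a generic callable and request list rather than the real benchmark helpers:

import time
from typing import Callable, Sequence, Tuple

def benchmark(requests: Sequence[Tuple], func: Callable[..., object], num_warmups: int = 0) -> float:
    # Same bump as above: num_warmups >= 0 gets one extra run, negative values get exactly one.
    num_warmups = num_warmups + 1 if num_warmups >= 0 else 1
    for _ in range(num_warmups):
        func(*requests[0])          # warm up on the first request only
    start = time.perf_counter()
    for req in requests:
        func(*req)
    return (time.perf_counter() - start) / len(requests)

# Example: the first call is excluded from the average even with num_warmups=0.
print(benchmark([(1000,), (2000,)], lambda n: sum(range(n))))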
diff --git a/fbgemm_gpu/bench/merge_embeddings_benchmark.py b/fbgemm_gpu/bench/merge_embeddings_benchmark.py
index d6bd17d20f..d7f574d6f7 100644
--- a/fbgemm_gpu/bench/merge_embeddings_benchmark.py
+++ b/fbgemm_gpu/bench/merge_embeddings_benchmark.py
@@ -5,7 +5,6 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
-# pyre-unsafe

 import logging
 import signal
@@ -44,6 +43,7 @@
 )


+# pyre-fixme[2]: Parameter must be annotated.
 def get_gpu_device(gpu_num) -> torch.device:
     return torch.device(f"cuda:{gpu_num}")

@@ -53,6 +53,7 @@ def get_gpu_device(gpu_num) -> torch.device:
 # Reference: https://fburl.com/code/5ueyfv5j
 def get_table_batched_offsets_from_dense(
     merged_indices: torch.Tensor,
+    # pyre-fixme[2]: Parameter must be annotated.
     gpu_num,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     (T, B, L) = merged_indices.size()
@@ -95,6 +96,7 @@ def generate_requests(
     return rs


+# pyre-fixme[3]: Return type must be annotated.
 def _get_random_tensor(
     num_ads: int,
     embedding_dimension: int,
@@ -140,7 +142,9 @@ def _get_random_tensor(
     return result_tensor


+# pyre-fixme[3]: Return type must be annotated.
 def generate_tbe(
+    # pyre-fixme[2]: Parameter must be annotated.
     batch_indices,
     num_ads: int,
     embedding_dimension: int,
@@ -204,7 +208,14 @@ def generate_tbe(


 def print_p2p_bandwidth(
-    num_gpus, iters, pooled_ad_embeddings, bytes_per_element
+    # pyre-fixme[2]: Parameter must be annotated.
+    num_gpus,
+    # pyre-fixme[2]: Parameter must be annotated.
+    iters,
+    # pyre-fixme[2]: Parameter must be annotated.
+    pooled_ad_embeddings,
+    # pyre-fixme[2]: Parameter must be annotated.
+    bytes_per_element,
 ) -> None:
     print("Pairwise GPU Copy Bandwidth (GB/s)")
     p2p_copy_bw = np.zeros((num_gpus, num_gpus))
@@ -289,12 +300,23 @@ def benchmark(  # noqa C901
     if p2p_bw:
         print_p2p_bandwidth(num_gpus, iters, pooled_ad_embeddings, bytes_per_element)

+    # pyre-fixme[53]: Captured variable `emb` is not annotated.
+    # pyre-fixme[53]: Captured variable `pooled_ad_embeddings` is not annotated.
+    # pyre-fixme[53]: Captured variable `requests` is not annotated.
+    # pyre-fixme[53]: Captured variable `tbe_offset` is not annotated.
+    # pyre-fixme[3]: Return type must be annotated.
     def pool_func_with_quantization(
+        # pyre-fixme[2]: Parameter must be annotated.
         batch_indices,
+        # pyre-fixme[2]: Parameter must be annotated.
         include_quantization,
+        # pyre-fixme[2]: Parameter must be annotated.
         include_tbe,
+        # pyre-fixme[2]: Parameter must be annotated.
         fused_tbe,
+        # pyre-fixme[2]: Parameter must be annotated.
         skip_dequantization,
+        # pyre-fixme[2]: Parameter must be annotated.
         data_type,
     ):
         if include_tbe:
@@ -478,6 +500,8 @@ def main(
     )

     if sweep:
+        # pyre-fixme[3]: Return type must be annotated.
+        # pyre-fixme[2]: Parameter must be annotated.
         def handler(signum, frame):
             logging.error("timeout")
             raise TimeoutError()
diff --git a/fbgemm_gpu/bench/sparse_ops_benchmark.py b/fbgemm_gpu/bench/sparse_ops_benchmark.py
index bf0ad96c1e..a578f3f40a 100644
--- a/fbgemm_gpu/bench/sparse_ops_benchmark.py
+++ b/fbgemm_gpu/bench/sparse_ops_benchmark.py
@@ -7,6 +7,7 @@
 import contextlib
 import functools
 import logging
+import math
 import random
 from typing import List

@@ -868,5 +869,78 @@ def ben(fn, name, ad_indices, ad_lengths, batch_offsets, num_ads_in_batch):
     ben(pass_4, "pass_4", ad_indices, ad_lengths, batch_offsets, num_ads_in_batch)


+@cli.command()
+@click.option("--row-size", default=2560000)
+@click.option("--batch-size", default=4096)
+@click.option("--bucket-num", default=16)
+@click.option("--input-precision", type=str, default="long")
+@click.option("--device", type=click.Choice(["cpu", "cuda"]), default="cpu")
+def block_bucketize_sparse_features_bench(
+    row_size: int, batch_size: int, bucket_num: int, input_precision: str, device: str
+) -> None:
+
+    dtype = torch.int
+    if input_precision == "int":
+        dtype = torch.int
+    elif input_precision == "long":
+        dtype = torch.long
+    else:
+        raise RuntimeError(f"Does not support data type {input_precision}")
+
+    indices = torch.randint(0, row_size, (batch_size,), dtype=dtype)
+    weights = torch.randint(0, row_size, (batch_size,), dtype=torch.float)
+    total = 0
+    lengths = []
+    for _ in range(batch_size):
+        length = random.randint(0, 10)
+        lengths.append(min(length, batch_size - total))
+        total += length
+        if total > batch_size:
+            break
+    lengths = torch.tensor(lengths, dtype=dtype)
+    bucket_size = math.ceil(row_size / bucket_num)
+    block_sizes = torch.tensor([bucket_size] * lengths.numel(), dtype=dtype)
+
+    bucket_pos = [j * bucket_size for j in range(bucket_num + 1)]
+    block_bucketize_pos = [torch.tensor(bucket_pos, device=device)] * lengths.numel()
+    test_param = {"uneven": block_bucketize_pos, "even": None}
+    print(f"device {device}")
+    for name, is_block_bucketize_pos in test_param.items():
+        time, output = benchmark_torch_function(
+            torch.ops.fbgemm.block_bucketize_sparse_features,
+            (
+                lengths if device == "cpu" else lengths.to(device),
+                indices if device == "cpu" else indices.to(device),
+                False,
+                True,
+                block_sizes if device == "cpu" else block_sizes.to(device),
+                bucket_num,
+                weights
+                if device == "cpu"
+                else (weights.to(device) if weights is not None else None),
+                None,
+                -1,  # unused
+                is_block_bucketize_pos
+                if device == "cpu"
+                else (
+                    [i.to(device) for i in is_block_bucketize_pos]
+                    if is_block_bucketize_pos is not None
+                    else None
+                ),
+            ),
+            iters=100,
+            device=device,
+        )
+
+        num_bytes = 0
+        for tensor in [lengths, indices, weights, *block_bucketize_pos, *output]:
+            if isinstance(tensor, torch.Tensor):
+                num_bytes += (tensor.numel()) * tensor.element_size()
+
+        logging.info(
+            f"{name}_block_bucketize_sparse_features forward: {dtype}, {num_bytes} bytes read/write, {time * 1e3} ms, {num_bytes / time / 1e9} GB/s"
+        )
+
+
 if __name__ == "__main__":
     cli()
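For context on the "even" vs. "uneven" cases benchmarked above: the even path uses fixed-width blocks (block_sizes), while block_bucketize_pos supplies explicit, possibly uneven bucket boundaries. A simplified model of just the bucket-assignment step, offered as an assumption for illustration — the real torch.ops.fbgemm.block_bucketize_sparse_features also rewrites lengths, indices, and weights per bucket:

from typing import Optional
import torch

def bucket_of(indices: torch.Tensor, bucket_num: int, block_size: int,
              boundaries: Optional[torch.Tensor] = None) -> torch.Tensor:
    if boundaries is None:
        # "even": fixed-width blocks, as implied by block_sizes above
        return torch.clamp(indices // block_size, max=bucket_num - 1)
    # "uneven": explicit boundaries, analogous to block_bucketize_pos above
    return torch.clamp(torch.bucketize(indices, boundaries, right=True) - 1, 0, bucket_num - 1)

idx = torch.tensor([3, 17, 42])
print(bucket_of(idx, bucket_num=4, block_size=16))                       # tensor([0, 1, 2])
print(bucket_of(idx, bucket_num=4, block_size=16,
                boundaries=torch.tensor([0, 8, 20, 40, 64])))            # tensor([0, 1, 3])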
diff --git a/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py b/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py
index df57b397a6..143b8a0e3d 100644
--- a/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py
+++ b/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py
@@ -297,6 +297,7 @@ def device(  # noqa C901
         flush_gpu_cache_size_mb=flush_gpu_cache_size_mb,
         bwd_only=True,
         grad=grad_output,
+        num_warmups=warmup_runs,
     )
     logging.info(
         f"Backward, B: {B}, E: {E}, T: {T}, D: {D}, L: {L}, "
@@ -313,6 +314,7 @@ def device(  # noqa C901
 @click.option("--weights-precision", type=SparseType, default=SparseType.FP32)
 @click.option("--stoc", is_flag=True, default=False)
 @click.option("--iters", default=100)
+@click.option("--warmup-runs", default=0)
 @click.option("--mixed", is_flag=True, default=False)
 @click.option("--num-embeddings", default=int(1e5))
 @click.option("--num-tables", default=32)
@@ -336,6 +338,7 @@ def uvm(
     weights_precision: SparseType,
     stoc: bool,
     iters: int,
+    warmup_runs: int,
     mixed: bool,
     num_embeddings: int,
     num_tables: int,
@@ -486,6 +489,7 @@ def uvm(
             per_sample_weights,
         ),
         flush_gpu_cache_size_mb=flush_gpu_cache_size_mb,
+        num_warmups=warmup_runs,
     )
     logging.info(
         f"UVM Forward, B: {B}, "
@@ -521,6 +525,7 @@ def uvm(
             per_sample_weights,
         ),
         flush_gpu_cache_size_mb=flush_gpu_cache_size_mb,
+        num_warmups=warmup_runs,
     )
     read_write_bytes_hbm = (
         output_size_multiplier * B * sum(Ds[T_uvm:])
@@ -541,6 +546,7 @@ def uvm(
             per_sample_weights,
         ),
         flush_gpu_cache_size_mb=flush_gpu_cache_size_mb,
+        num_warmups=warmup_runs,
     )
     read_write_bytes_total = read_write_bytes_uvm + read_write_bytes_hbm
     logging.info(
@@ -562,6 +568,7 @@ def uvm(
 @click.option("--stoc", is_flag=True, default=False)
 @click.option("--long-index", is_flag=True, default=False)
 @click.option("--iters", default=100)
+@click.option("--warmup-runs", default=0)
 @click.option("--mixed", is_flag=True, default=False)
 @click.option("--num-embeddings", default=int(1e5))
 @click.option("--num-tables", default=32)
@@ -580,6 +587,7 @@ def cache(  # noqa C901
     weights_precision: SparseType,
     stoc: bool,
     iters: int,
+    warmup_runs: int,
     long_index: bool,
     mixed: bool,
     num_embeddings: int,
@@ -678,6 +686,7 @@ def cache(  # noqa C901
             indices.long(), offsets.long(), per_sample_weights
         ).backward(grad_output),
         flush_gpu_cache_size_mb=flush_gpu_cache_size_mb,
+        num_warmups=warmup_runs,
     )
     logging.info(
         f"ForwardBackward (UVM), B: {B}, E: {E}, T: {T}, D: {D}, L: {L}, "
@@ -713,6 +722,7 @@ def cache(  # noqa C901
         emb.reset_cache_states()
         for indices, offsets, _ in warmup_requests:
             emb.forward(indices, offsets)
+    # TODO: Add warmup_runs
     prefetch_time, forward_backward_time = benchmark_pipelined_requests(
         requests,
         lambda indices, offsets, indices_weights: emb.prefetch(indices, offsets),
@@ -738,9 +748,14 @@ def cache(  # noqa C901

 def benchmark_cpu_requests(
     requests: List[Tuple[torch.IntTensor, torch.IntTensor, Optional[torch.Tensor]]],
     func: Callable[[Tensor, Tensor, Optional[Tensor]], Tensor],
+    num_warmups: int = 0,
 ) -> float:
     import time

+    if num_warmups > 0:
+        for _ in range(num_warmups):
+            func(*requests[0])
+
     start_time = time.perf_counter()
     for indices, offsets, weights in requests:
         func(indices, offsets, weights)
@@ -756,6 +771,7 @@ def benchmark_cpu_requests(
 @click.option("--weights-precision", type=SparseType, default=SparseType.INT4)
 @click.option("--stoc", is_flag=True, default=False)
 @click.option("--iters", default=100)
+@click.option("--warmup-runs", default=0)
 @click.option("--managed", default="device")
 @click.option("--mixed", is_flag=True, default=False)
 @click.option("--num-embeddings", default=int(1e5))
@@ -777,6 +793,7 @@ def nbit_cpu(  # noqa C901
     weights_precision: SparseType,
     stoc: bool,
     iters: int,
+    warmup_runs: int,
     managed: str,
     mixed: bool,
     num_embeddings: int,
@@ -860,6 +877,7 @@ def nbit_cpu(  # noqa C901
             offsets,
             per_sample_weights,
         ),
+        num_warmups=warmup_runs,
     )

     logging.info(
@@ -1425,6 +1443,7 @@ def nbit_device_with_spec(  # noqa C901
 @click.option("--embedding-dim", default=128)
 @click.option("--weights-precision", type=SparseType, default=SparseType.INT4)
 @click.option("--iters", default=100)
+@click.option("--warmup-runs", default=0)
 @click.option("--mixed", is_flag=True, default=False)
 @click.option("--num-embeddings", default=int(1e5))
 @click.option("--num-tables", default=32)
@@ -1448,6 +1467,7 @@ def nbit_uvm(
     embedding_dim: int,
     weights_precision: SparseType,
     iters: int,
+    warmup_runs: int,
     mixed: bool,
     num_embeddings: int,
     num_tables: int,
@@ -1613,6 +1633,7 @@ def nbit_uvm(
             per_sample_weights,
         ),
         flush_gpu_cache_size_mb=flush_gpu_cache_size_mb,
+        num_warmups=warmup_runs,
     )
     logging.info(
         f"UVM NBit Forward, {weights_precision}, B: {B}, "
@@ -1670,6 +1691,7 @@ def nbit_uvm(
             per_sample_weights,
         ),
         flush_gpu_cache_size_mb=flush_gpu_cache_size_mb,
+        num_warmups=warmup_runs,
     )
     read_write_bytes_total = read_write_bytes_uvm + read_write_bytes_hbm
     logging.info(
@@ -1683,6 +1705,7 @@ def nbit_uvm(
         emb_mixed.reset_cache_states()
         for indices, offsets, _ in requests:
             emb_mixed.forward(indices, offsets)
+    # TODO: Add warmup runs
     prefetch_time, forward_time = benchmark_pipelined_requests(
         requests,
         lambda indices, offsets, indices_weights: emb_mixed.prefetch(
@@ -1717,7 +1740,7 @@ def nbit_uvm(
 @click.option("--embedding-dim", default=128)
 @click.option("--weights-precision", type=SparseType, default=SparseType.INT4)
 @click.option("--iters", default=100)
-@click.option("--warmup", default=10)
+@click.option("--warmup_runs", default=10)
 @click.option("--mixed", is_flag=True, default=False)
 @click.option("--num-embeddings", default=int(1e5))
 @click.option("--num-tables", default=32)
@@ -1744,7 +1767,7 @@ def nbit_uvm_compare_direct_mapped(
     embedding_dim: int,
     weights_precision: SparseType,
     iters: int,
-    warmup: int,
+    warmup_runs: int,
     mixed: bool,
     num_embeddings: int,
     num_tables: int,
@@ -1866,7 +1889,7 @@ def bench_uvm_cls(
                 per_sample_weights,
             ),
             flush_gpu_cache_size_mb=flush_gpu_cache_size_mb,
-            num_warmups=warmup,
+            num_warmups=warmup_runs,
             nvtx_range=nvtx_range,
             callback_after_warmup=callback_after_warmup,
         )
@@ -1949,6 +1972,7 @@ def bench_uvm_cls(
 @click.option("--embedding-dim", default=128)
 @click.option("--weights-precision", type=SparseType, default=SparseType.INT4)
 @click.option("--iters", default=100)
+@click.option("--warmup-runs", default=0)
 @click.option("--mixed", is_flag=True, default=False)
@click.option("--num-embeddings", default=int(1e5)) @click.option("--num-tables", default=32) @@ -1972,6 +1996,7 @@ def nbit_cache( # noqa C901 embedding_dim: int, weights_precision: SparseType, iters: int, + warmup_runs: int, mixed: bool, num_embeddings: int, num_tables: int, @@ -2078,6 +2103,7 @@ def nbit_cache( # noqa C901 indices.int(), offsets.int(), per_sample_weights ), flush_gpu_cache_size_mb=flush_gpu_cache_size_mb, + num_warmups=warmup_runs, ) logging.info( f"Forward (UVM) {weights_precision}, B: {B}, E: {E}, T: {T}, D: {D}, L: {L}, " @@ -2156,6 +2182,7 @@ def nbit_cache( # noqa C901 torch.cuda.cudart().cudaProfilerStart() torch.cuda.nvtx.range_push("pipeline") + # TODO: Add warmup_runs prefetch_time, forward_time = benchmark_pipelined_requests( requests, lambda indices, offsets, indices_weights: emb.prefetch( @@ -2189,6 +2216,7 @@ def nbit_cache( # noqa C901 @click.option("--bag-size", default=20) @click.option("--batch-size", default=2048) @click.option("--iters", default=10) +@click.option("--warmup-runs", default=0) @click.option("--num-embeddings", default=int(1e5)) @click.option("--num-tables", default=100) @click.option("--pruning-hash-load-factor", default=0.75) @@ -2200,6 +2228,7 @@ def hashtable( # noqa C901 bag_size: int, batch_size: int, iters: int, + warmup_runs: int, num_embeddings: int, num_tables: int, pruning_hash_load_factor: float, @@ -2278,6 +2307,7 @@ def hashtable( # noqa C901 lambda indices, offsets, _: torch.ops.fbgemm.pruned_hashmap_lookup( indices, offsets, hash_table, hash_table_offsets ), + num_warmups=warmup_runs, ) logging.info( @@ -2292,6 +2322,7 @@ def hashtable( # noqa C901 time_per_iter = benchmark_requests( requests, lambda indices, offsets, _: ht.lookup(indices, offsets), + num_warmups=warmup_runs, ) logging.info( @@ -2304,6 +2335,7 @@ def hashtable( # noqa C901 @click.option("--bag-size", default=20) @click.option("--batch-size", default=2048) @click.option("--iters", default=100) +@click.option("--warmup-runs", default=0) @click.option("--num-embeddings", default=int(1e5)) @click.option("--num-tables", default=100) @click.option("--pruning-ratio", default=0.9) @@ -2314,6 +2346,7 @@ def pruned_array( # noqa C901 bag_size: int, batch_size: int, iters: int, + warmup_runs: int, num_embeddings: int, num_tables: int, pruning_ratio: float, @@ -2362,6 +2395,7 @@ def pruned_array( # noqa C901 index_remappings, index_remappings_offsets, ), + num_warmups=warmup_runs, ) logging.info( @@ -2374,6 +2408,7 @@ def pruned_array( # noqa C901 @click.option("--bag-size", default=20) @click.option("--batch-size", default=512) @click.option("--iters", default=100) +@click.option("--warmup-runs", default=0) @click.option("--num-embeddings", default=int(1e5)) @click.option("--num-tables", default=32) @click.option("--bounds-check-mode", type=int, default=BoundsCheckMode.WARNING.value) @@ -2383,6 +2418,7 @@ def bounds_check_indices( # noqa C901 bag_size: int, batch_size: int, iters: int, + warmup_runs: int, num_embeddings: int, num_tables: int, bounds_check_mode: int, @@ -2414,11 +2450,12 @@ def bounds_check_indices( # noqa C901 requests, lambda indices, offsets, _: torch.ops.fbgemm.bounds_check_indices( rows_per_table, - indices, - offsets, + indices.long(), + offsets.long(), BoundsCheckMode(bounds_check_mode), warning, ), + num_warmups=warmup_runs, ) logging.info( @@ -2437,6 +2474,7 @@ def bounds_check_indices( # noqa C901 @click.option("--weights-precision", type=SparseType, default=SparseType.INT4) @click.option("--output-dtype", type=SparseType, 
default=SparseType.FP16) @click.option("--iters", type=int, default=100) +@click.option("--warmup-runs", default=0) @click.option("--fp8-exponent-bits", type=int, default=None) @click.option("--fp8-exponent-bias", type=int, default=None) def emb_inplace_update( # noqa C901 @@ -2447,6 +2485,7 @@ def emb_inplace_update( # noqa C901 weights_precision: SparseType, output_dtype: SparseType, iters: int, + warmup_runs: int, fp8_exponent_bits: Optional[int], fp8_exponent_bias: Optional[int], ) -> None: @@ -2548,6 +2587,7 @@ def emb_inplace_update( # noqa C901 op.embedding_inplace_update_internal, (update_table_idx, update_row_idx, update_weights), iters=iters, + num_warmups=warmup_runs, ) logging.info( @@ -2601,6 +2641,7 @@ def emb_inplace_update( # noqa C901 16, # row_alignment ), iters=iters, + num_warmups=warmup_runs, ) logging.info( @@ -2865,6 +2906,7 @@ def device_with_spec( # noqa C901 flush_gpu_cache_size_mb=flush_gpu_cache_size_mb, bwd_only=True, grad=grad_output, + num_warmups=warmup_runs, ) logging.info( f"Backward, B: {B}, Es: {Es}, T: {T}, Ds: {Ds}, Ls: {Ls_str}, " @@ -2896,6 +2938,7 @@ def vbe( compressed_tables: int, iters: int, ) -> None: + # TODO: Add warmup_runs torch.manual_seed(42) B = batch_size cB = compressed_batch_size diff --git a/fbgemm_gpu/codegen/embedding_backward_split_grad_template.cu b/fbgemm_gpu/codegen/embedding_backward_split_grad_template.cu index 5367f292b7..1db787760c 100644 --- a/fbgemm_gpu/codegen/embedding_backward_split_grad_template.cu +++ b/fbgemm_gpu/codegen/embedding_backward_split_grad_template.cu @@ -57,6 +57,79 @@ split_embedding_backward_codegen_find_long_segments( } } +template +__global__ __launch_bounds__(kMaxThreads) +void split_embedding_backward_count_unique_indices_kernel( + const pta::PackedTensorAccessor32 + sorted_linear_indices_num_runs, + const pta::PackedTensorAccessor32 + sorted_linear_indices_cumulative_run_lengths, + const pta::PackedTensorAccessor32 + sorted_infos, + const pta::PackedTensorAccessor32 + weights_placements, + pta::PackedTensorAccessor32 + dev_or_uvm_unique_indices, + const int info_B_num_bits +) { + const int32_t num_runs = sorted_linear_indices_num_runs[0]; + const auto T = weights_placements.size(0); + for (auto run_id = blockIdx.x * blockDim.x + threadIdx.x; + run_id < num_runs; + run_id += blockDim.x * gridDim.x) { + // Obtain the associated table id of the run id + const auto segment_start = sorted_linear_indices_cumulative_run_lengths[run_id]; + const auto info = reinterpret_cast(&sorted_infos[0])[segment_start]; + const auto t = nobag ? (info % T) : (info >> info_B_num_bits); + + int32_t t_next = -1; + const auto unique_count_offset = run_id + 1; + if (unique_count_offset < num_runs) { + const auto segment_start_next = sorted_linear_indices_cumulative_run_lengths[unique_count_offset]; + const auto info_next = reinterpret_cast(&sorted_infos[0])[segment_start_next]; + t_next = nobag ? 
(info_next % T) : (info_next >> info_B_num_bits); + } + + if (t != t_next) { + const auto placement = static_cast(weights_placements[t]); + if (placement != PlacementType::MANAGED_CACHING) { + // Record num unique indices for PlacementType::DEVICE from unique_count_offset + gpuAtomicAdd(&dev_or_uvm_unique_indices[t], unique_count_offset); + } + if (t_next != -1) { + const auto placement_next = static_cast(weights_placements[t_next]); + if (placement_next != PlacementType::MANAGED_CACHING) { + // Record num unique indices for PlacementType::DEVICE from unique_count_offset + gpuAtomicAdd(&dev_or_uvm_unique_indices[t_next], -unique_count_offset); + } + } + } + } +} + +{% for nobag in [True, False] %} +{% set info_pta_t = "int64_t" if nobag else "int32_t" %} +template __global__ __launch_bounds__(kMaxThreads) +void split_embedding_backward_count_unique_indices_kernel +< + {{ info_pta_t }}, + {{ "int64_t" if nobag else "uint32_t" }}, + {{ "true" if nobag else "false" }} +> ( + const pta::PackedTensorAccessor32 + sorted_linear_indices_num_runs, + const pta::PackedTensorAccessor32 + sorted_linear_indices_cumulative_run_lengths, + const pta::PackedTensorAccessor32<{{ info_pta_t }}, 1, at::RestrictPtrTraits> + sorted_infos, + const pta::PackedTensorAccessor32 + weights_placements, + pta::PackedTensorAccessor32 + dev_or_uvm_unique_indices, + const int info_B_num_bits +); +{% endfor %} + {% for vbe in [True, False] %} {% set vbe_desc = "_vbe" if vbe else "" %} template @@ -78,7 +151,7 @@ __global__ __launch_bounds__(kMaxThreads) void grad_mean{{ vbe_desc }}_kernel( ) { int32_t T = D_offsets.size(0) - 1; int32_t b_t = blockIdx.x * blockDim.y + threadIdx.y; - int32_t b; + [[maybe_unused]] int32_t b; int32_t t; const auto total_B = offsets.size(0) - 1; diff --git a/fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp b/fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp index 63fdf6e684..c5eb20a620 100644 --- a/fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp +++ b/fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp @@ -98,6 +98,8 @@ Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_e const Tensor& vbe_row_output_offsets, const Tensor& vbe_b_t_map, {%- endif %} + const bool use_uniq_cache_locations, + const bool use_homogeneous_placements, {{ args.split_function_args | join(", ") }}); {%- endfor %} {#-/*for nobag*/#} @@ -177,20 +179,23 @@ class {{ autograd_func }} : const int64_t vbe_output_size, {%- endif %} const bool is_experimental, + const bool use_uniq_cache_locations_bwd, + const bool use_homogeneous_placements, {{ args.split_function_args | join(", ") }}) { - const auto T = weights_offsets.numel(); + const auto T = weights_offsets.sym_numel(); {%- if vbe %} const auto B_offsets_ = B_offsets.value_or(Tensor()); const auto vbe_output_offsets_feature_rank_ = vbe_output_offsets_feature_rank.value_or(Tensor()); const auto vbe_B_offsets_rank_per_feature_ = vbe_B_offsets_rank_per_feature.value_or(Tensor()); - const auto max_B_ = max_B; + const c10::SymInt max_B_ = max_B; {%- else %} - const auto max_B_ = offsets.size(0) / T; + const auto max_B_ = offsets.sym_size(0) / T; {%- endif %} - auto [info_B_num_bits, info_B_mask] = adjust_info_B_num_bits(max_B_, T); + // TODO: don't guard here + auto [info_B_num_bits, info_B_mask] = adjust_info_B_num_bits(max_B_.guard_int(__FILE__, __LINE__), T.guard_int(__FILE__, __LINE__)); {%- if vbe %} static auto generate_vbe_metadata_op = @@ -262,6 +267,8 @@ class {{ autograd_func }} : 
ctx->saved_data["info_B_num_bits"] = info_B_num_bits; const auto info_B_mask_int64 = static_cast(info_B_mask); ctx->saved_data["info_B_mask"] = info_B_mask_int64; + ctx->saved_data["use_uniq_cache_locations_bwd"] = use_uniq_cache_locations_bwd; + ctx->saved_data["use_homogeneous_placements"] = use_homogeneous_placements; {%- for (var, _) in args.saved_data %} ctx->saved_data["{{ var }}"] = {{ var }}; @@ -391,6 +398,10 @@ class {{ autograd_func }} : {%- endif %} {#-/* if optimizer != "none" */#} const int32_t info_B_num_bits = ctx->saved_data["info_B_num_bits"].toInt(); const int64_t info_B_mask_int64 = ctx->saved_data["info_B_mask"].toInt(); + const auto use_uniq_cache_locations_bwd = + ctx->saved_data["use_uniq_cache_locations_bwd"].toBool(); + const auto use_homogeneous_placements = + ctx->saved_data["use_homogeneous_placements"].toBool(); {%- for (var, ivalue_cast) in args.saved_data %} auto {{ var }} = ctx->saved_data["{{ var }}"].{{ ivalue_cast }}(); @@ -509,6 +520,8 @@ class {{ autograd_func }} : vbe_row_output_offsets, vbe_b_t_map, {%- endif %} + use_uniq_cache_locations_bwd, + use_homogeneous_placements, {{ args.split_function_arg_names | join(", ") }} ) {{ ":" if not weighted else ";" }} {%- endfor %} {#-/* for weighted in [False, True] */#} @@ -545,6 +558,8 @@ class {{ autograd_func }} : Variable(), // vbe_output_size {%- endif %} Variable(), // is_experimental + Variable(), // use_uniq_cache_locations_bwd + Variable(), // use_homogeneous_placements {{ args.split_variables | join(", ") }} }; {%- else %} @@ -584,6 +599,8 @@ class {{ autograd_func }} : vbe_row_output_offsets, vbe_b_t_map, {%- endif %} + use_uniq_cache_locations_bwd, + use_homogeneous_placements, {{ args.split_function_arg_names | join(", ") }} ); return { @@ -614,6 +631,8 @@ class {{ autograd_func }} : Variable(), // vbe_output_size {%- endif %} Variable(), // is_experimental + Variable(), // use_uniq_cache_locations_bwd + Variable(), // use_homogeneous_placements {{ args.split_variables | join(", ") }} }; {%- endif %} @@ -656,7 +675,9 @@ Tensor split_embedding_codegen_lookup_{{ optimizer }}_function( const int64_t max_B = -1, const int64_t max_B_feature_rank = -1, const int64_t vbe_output_size = -1, - const bool is_experimental = false + const bool is_experimental = false, + const bool use_uniq_cache_locations_bwd = false, + const bool use_homogeneous_placements = false ) { {%- if has_gpu_support %} {%- for vbe in ([True, False] if has_vbe_support else [False]) %} @@ -720,6 +741,8 @@ Tensor split_embedding_codegen_lookup_{{ optimizer }}_function( vbe_output_size, {%- endif %} is_experimental, + use_uniq_cache_locations_bwd, + use_homogeneous_placements, {{ args.split_function_arg_names | join(", ") }})[0]; } {%- endfor %} {#-/* for nobag */#} @@ -766,7 +789,9 @@ TORCH_LIBRARY_FRAGMENT({{ lib_name }}, m) { " int max_B=-1, " " int max_B_feature_rank=-1, " " int vbe_output_size=-1, " - " bool is_experimental=False) -> Tensor"); + " bool is_experimental=False, " + " bool use_uniq_cache_locations_bwd=False, " + " bool use_homogeneous_placements=False) -> Tensor"); // We're playing a funny trick here: we're using the autograd // implementation of the operator at all the dispatch keys. 
This is OK // because autograd.Function works even in a context where there is diff --git a/fbgemm_gpu/codegen/embedding_backward_split_indice_weights_template.cu b/fbgemm_gpu/codegen/embedding_backward_split_indice_weights_template.cu index bfb6f1fef3..d6e72ccaa4 100644 --- a/fbgemm_gpu/codegen/embedding_backward_split_indice_weights_template.cu +++ b/fbgemm_gpu/codegen/embedding_backward_split_indice_weights_template.cu @@ -66,7 +66,7 @@ __global__ __launch_bounds__(kForwardMaxThreads) void } int32_t t; - int32_t b; + [[maybe_unused]] int32_t b; {%- if vbe %} const auto info = reinterpret_cast(&b_t_map[b_t])[0]; diff --git a/fbgemm_gpu/codegen/embedding_backward_split_kernel_cta_template.cu b/fbgemm_gpu/codegen/embedding_backward_split_kernel_cta_template.cu index ce0a5db17c..1cd572a33c 100644 --- a/fbgemm_gpu/codegen/embedding_backward_split_kernel_cta_template.cu +++ b/fbgemm_gpu/codegen/embedding_backward_split_kernel_cta_template.cu @@ -62,6 +62,8 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ {%- endif %} {%- if not dense %} const pta::PackedTensorAccessor32 sorted_lxu_cache_locations, + const bool use_uniq_cache_locations, + const pta::PackedTensorAccessor32 table_unique_indices_offsets, {%- endif %} {%- if weighted %} const pta::PackedTensorAccessor32, 1, at::RestrictPtrTraits> sorted_indice_weights, @@ -381,10 +383,10 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ stochastic_rounding, stochastic_rounding_philox_args, current_run_id, + use_uniq_cache_locations ? (current_run_id - table_unique_indices_offsets[t_0]) : segment_start, D, t_0, idx, - segment_start, shfl_sync_mask, 0, // shared_weight_offset {{ args.split_function_arg_names | join(", ") }}); @@ -462,6 +464,8 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ {%- if not dense %} const pta::PackedTensorAccessor32 sorted_lxu_cache_locations, + const bool use_uniq_cache_locations, + const pta::PackedTensorAccessor32 table_unique_indices_offsets, {%- endif %} {%- if weighted %} const pta::PackedTensorAccessor32, 1, at::RestrictPtrTraits> sorted_indice_weights, diff --git a/fbgemm_gpu/codegen/embedding_backward_split_kernel_warp_template.cu b/fbgemm_gpu/codegen/embedding_backward_split_kernel_warp_template.cu index ad81b31b1b..176f80bd9d 100644 --- a/fbgemm_gpu/codegen/embedding_backward_split_kernel_warp_template.cu +++ b/fbgemm_gpu/codegen/embedding_backward_split_kernel_warp_template.cu @@ -60,6 +60,8 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ {%- endif %} {%- if not dense %} const pta::PackedTensorAccessor32 sorted_lxu_cache_locations, + const bool use_uniq_cache_locations, + const pta::PackedTensorAccessor32 table_unique_indices_offsets, {%- endif %} {%- if weighted %} const pta::PackedTensorAccessor32, 1, at::RestrictPtrTraits> sorted_indice_weights, @@ -222,10 +224,10 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ stochastic_rounding, stochastic_rounding_philox_args, run_id, + use_uniq_cache_locations ? 
(run_id - table_unique_indices_offsets[t_0]) : segment_start, D, t_0, idx, - segment_start, shfl_sync_mask, threadIdx.y * kMaxVecsPerThread * kThreadGroupSize, // shared_weight_offset {{ args.split_function_arg_names | join(", ") }}); @@ -301,6 +303,8 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ {%- endif %} {%- if not dense %} const pta::PackedTensorAccessor32 sorted_lxu_cache_locations, + const bool use_uniq_cache_locations, + const pta::PackedTensorAccessor32 table_unique_indices_offsets, {%- endif %} {%- if weighted %} const pta::PackedTensorAccessor32, 1, at::RestrictPtrTraits> sorted_indice_weights, diff --git a/fbgemm_gpu/codegen/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/embedding_backward_split_template.cu index c621769bdc..1d928cd17a 100644 --- a/fbgemm_gpu/codegen/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/embedding_backward_split_template.cu @@ -20,6 +20,7 @@ #include "fbgemm_gpu/embedding_backward_template_helpers.cuh" #include "fbgemm_gpu/fbgemm_tensor_accessor.h" #include "fbgemm_gpu/split_embeddings_utils.cuh" +#include "fbgemm_gpu/sparse_ops.h" using Tensor = at::Tensor; using namespace fbgemm_gpu; @@ -67,6 +68,8 @@ split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc {%- endif %} {%- if not dense %} const pta::PackedTensorAccessor32 sorted_lxu_cache_locations, + const bool use_uniq_cache_locations, + const pta::PackedTensorAccessor32 table_unique_indices_offsets, {%- endif %} {%- if weighted %} const pta::PackedTensorAccessor32, 1, at::RestrictPtrTraits> sorted_indice_weights, @@ -139,6 +142,8 @@ split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc {%- endif %} {%- if not dense %} const pta::PackedTensorAccessor32 sorted_lxu_cache_locations, + const bool use_uniq_cache_locations, + const pta::PackedTensorAccessor32 table_unique_indices_offsets, {%- endif %} {%- if weighted %} const pta::PackedTensorAccessor32, 1, at::RestrictPtrTraits> sorted_indice_weights, @@ -201,6 +206,21 @@ grad_mean{{ vdesc }}_kernel( {%- endif %} ); +template +__global__ __launch_bounds__(kMaxThreads) void +split_embedding_backward_count_unique_indices_kernel( + const pta::PackedTensorAccessor32 + sorted_linear_indices_num_runs, + const pta::PackedTensorAccessor32 + sorted_linear_indices_cumulative_run_lengths, + const pta::PackedTensorAccessor32 + sorted_infos, + const pta::PackedTensorAccessor32 + weights_placements, + pta::PackedTensorAccessor32 + dev_or_uvm_unique_indices, + const int info_B_num_bits +); //////////////////////////////////////////////////////////////////////////////// // Utility Macros @@ -332,6 +352,10 @@ Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_e const Tensor& vbe_row_output_offsets, const Tensor& vbe_b_t_map, {%- endif %} + {%- if not is_index_select and not dense %} + const bool use_uniq_cache_locations, + const bool use_homogeneous_placements, + {%- endif %} {%- if is_index_select %} const Tensor& grad_offsets, const Tensor& total_L_offsets, @@ -511,34 +535,90 @@ Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_e ); {%- if not dense %} - auto lxu_cache_locations_sorted = at::empty_like(lxu_cache_locations); + Tensor lxu_cache_locations_sorted = lxu_cache_locations; + Tensor table_unique_indices_offsets; if (lxu_cache_locations.size(0) > 0) { + if (use_uniq_cache_locations) { + if (!use_homogeneous_placements) { + // When use_uniq_cache_locations=true, lxu_cache_locations are unique 
+ // and sorted in an ascending order based on the linear cache indices. + // Linear cache indices of tables that are not placed in cache are set + // to a sentinel value (i.e., the sum of hash sizes of all embedding + // tables). Since the sentinel value is larger than the max linear + // cache index value, the lxu_cache_locations can be sorted differently + // than the sorted_linear_indices. + // + // For this reason, the run ids of sorted and unique + // lxu_cache_locations can be different from those of the + // sorted_linear_indices. We need the following code to compute + // table_unique_indices_offsets which contains the differences between + // lxu_cache_locations run ids and sorted_linear_indices run ids. + auto dev_or_uvm_unique_indices = at::zeros_like(weights_placements); + +#ifdef FBGEMM_GPU_MEMCHECK + const auto func_name = "split_embedding_backward_count_unique_indices_kernel"; +#endif + split_embedding_backward_count_unique_indices_kernel< + {{ "int64_t" if nobag else "int32_t" }}, + {{ "int64_t" if nobag else "uint32_t" }}, + {{ "true" if nobag else "false" }} + ><<< + div_round_up(total_unique_indices, kMaxThreads), + kMaxThreads, + 0, + at::cuda::getCurrentCUDAStream() + >>>( + MAKE_PTA_WITH_NAME( + func_name, sorted_linear_indices_num_runs, int32_t, 1, 32), + MAKE_PTA_WITH_NAME( + func_name, sorted_linear_indices_cumulative_run_lengths, int32_t, 1, 32), + MAKE_PTA_WITH_NAME( + func_name, infos_sorted, {{ "int64_t" if nobag else "int32_t" }}, 1, 32), + MAKE_PTA_WITH_NAME( + func_name, weights_placements, int32_t, 1, 32), + MAKE_PTA_WITH_NAME( + func_name, dev_or_uvm_unique_indices, int32_t, 1, 32), + info_B_num_bits + ); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + table_unique_indices_offsets = + fbgemm_gpu::asynchronous_complete_cumsum_gpu(dev_or_uvm_unique_indices).to(at::kInt); + } + } + else { + lxu_cache_locations_sorted = at::empty_like(lxu_cache_locations); size_t temp_storage_bytes = 0; AT_CUDA_CHECK(radix_sort_pairs( - nullptr, - temp_storage_bytes, - linear_indices.data_ptr(), - linear_indices_sorted.data_ptr(), - lxu_cache_locations.data_ptr(), - lxu_cache_locations_sorted.data_ptr(), - linear_indices.numel(), - 0, - total_hash_size_bits, - at::cuda::getCurrentCUDAStream())); + nullptr, + temp_storage_bytes, + linear_indices.data_ptr(), + linear_indices_sorted.data_ptr(), + lxu_cache_locations.data_ptr(), + lxu_cache_locations_sorted.data_ptr(), + linear_indices.numel(), + 0, + total_hash_size_bits, + at::cuda::getCurrentCUDAStream())); auto temp_storage = at::empty( {static_cast(temp_storage_bytes)}, indices.options().dtype(at::kByte)); AT_CUDA_CHECK(radix_sort_pairs( - temp_storage.data_ptr(), - temp_storage_bytes, - linear_indices.data_ptr(), - linear_indices_sorted.data_ptr(), - lxu_cache_locations.data_ptr(), - lxu_cache_locations_sorted.data_ptr(), - linear_indices.numel(), - 0, - total_hash_size_bits, - at::cuda::getCurrentCUDAStream())); + temp_storage.data_ptr(), + temp_storage_bytes, + linear_indices.data_ptr(), + linear_indices_sorted.data_ptr(), + lxu_cache_locations.data_ptr(), + lxu_cache_locations_sorted.data_ptr(), + linear_indices.numel(), + 0, + total_hash_size_bits, + at::cuda::getCurrentCUDAStream())); + } + } + + if (lxu_cache_locations.size(0) == 0 || !use_uniq_cache_locations || use_homogeneous_placements) { + table_unique_indices_offsets = at::zeros_like(weights_placements); } {%- endif %} @@ -788,6 +868,8 @@ Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_e {%- endif %} {%- if not dense %} 
MAKE_PTA_WITH_NAME(func_name3, lxu_cache_locations_sorted, int32_t, 1, 32), + use_uniq_cache_locations, + MAKE_PTA_WITH_NAME(func_name3, table_unique_indices_offsets, int32_t, 1, 32), {%- endif %} {%- if weighted %} MAKE_PTA_ACC_WITH_NAME(func_name3, indice_weights_sorted, cache_t, 1, 32), @@ -896,6 +978,8 @@ Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_e {%- endif %} {%- if not dense %} MAKE_PTA_WITH_NAME(func_name4, lxu_cache_locations_sorted, int32_t, 1, 32), + use_uniq_cache_locations, + MAKE_PTA_WITH_NAME(func_name4, table_unique_indices_offsets, int32_t, 1, 32), {%- endif %} {%- if weighted %} MAKE_PTA_ACC_WITH_NAME(func_name4, indice_weights_sorted, cache_t, 1, 32), @@ -1002,6 +1086,10 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { " Tensor vbe_row_output_offsets, " " Tensor vbe_b_t_map, " {%- endif %} + {%- if not is_index_select and not dense %} + " bool use_uniq_cache_locations, " + " bool use_homogeneous_placements, " + {%- endif %} " {{ args.split_function_schemas | join(", ") }}" ") -> Tensor"); DISPATCH_TO_CUDA( diff --git a/fbgemm_gpu/codegen/embedding_forward_quantized_host.cpp b/fbgemm_gpu/codegen/embedding_forward_quantized_host.cpp index 722201e4fb..11b514acc0 100644 --- a/fbgemm_gpu/codegen/embedding_forward_quantized_host.cpp +++ b/fbgemm_gpu/codegen/embedding_forward_quantized_host.cpp @@ -497,7 +497,10 @@ Tensor int_nbit_split_embedding_uvm_caching_codegen_lookup_function( lxu_cache_state.value(), total_cache_hash_size.value(), gather_uvm_stats, - uvm_cache_stats); + uvm_cache_stats, + c10::optional(), // num_uniq_cache_indices + c10::optional() // lxu_cache_locations_output + ); #ifdef FBCODE_CAFFE2 if (FLAGS_tbe_uvm_cache_enforced_misses > 0) { diff --git a/fbgemm_gpu/codegen/embedding_forward_quantized_host_cpu.cpp b/fbgemm_gpu/codegen/embedding_forward_quantized_host_cpu.cpp index 66bd594f20..075574224f 100644 --- a/fbgemm_gpu/codegen/embedding_forward_quantized_host_cpu.cpp +++ b/fbgemm_gpu/codegen/embedding_forward_quantized_host_cpu.cpp @@ -244,13 +244,13 @@ Tensor pruned_array_lookup_cpu( TORCH_LIBRARY_FRAGMENT(fbgemm, m) { m.def( - "int_nbit_split_embedding_codegen_lookup_function(Tensor dev_weights, Tensor uvm_weights, Tensor weights_placements, Tensor weights_offsets, Tensor weights_tys, Tensor D_offsets, int total_D, int max_int2_D, int max_int4_D, int max_int8_D, int max_float16_D, int max_float32_D, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, int output_dtype=1, Tensor? lxu_cache_weights=None, Tensor? lxu_cache_locations=None, int? row_alignment = None, int? max_float8_D=0, int? fp8_exponent_bits=-1, int? fp8_exponent_bias=-1) -> Tensor"); + "int_nbit_split_embedding_codegen_lookup_function(Tensor dev_weights, Tensor uvm_weights, Tensor weights_placements, Tensor weights_offsets, Tensor weights_tys, Tensor D_offsets, SymInt total_D, int max_int2_D, int max_int4_D, int max_int8_D, int max_float16_D, int max_float32_D, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, int output_dtype=1, Tensor? lxu_cache_weights=None, Tensor? lxu_cache_locations=None, int? row_alignment = None, int? max_float8_D=0, int? fp8_exponent_bits=-1, int? 
fp8_exponent_bias=-1) -> Tensor"); DISPATCH_TO_CPU( "int_nbit_split_embedding_codegen_lookup_function", int_nbit_split_embedding_codegen_lookup_function_cpu); m.def( - "int_nbit_split_embedding_uvm_caching_codegen_lookup_function(Tensor dev_weights, Tensor uvm_weights, Tensor weights_placements, Tensor weights_offsets, Tensor weights_tys, Tensor D_offsets, int total_D, int max_int2_D, int max_int4_D, int max_int8_D, int max_float16_D, int max_float32_D, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights=None, int output_dtype=1, Tensor? lxu_cache_weights=None, Tensor? lxu_cache_locations=None, int? row_alignment=-1, int? max_float8_D=0, int? fp8_exponent_bits=-1, int? fp8_exponent_bias=-1, Tensor? cache_hash_size_cumsum=None, int? total_cache_hash_size=-1, Tensor? cache_index_table_map=None, Tensor? lxu_cache_state=None, Tensor? lxu_state=None) -> Tensor"); + "int_nbit_split_embedding_uvm_caching_codegen_lookup_function(Tensor dev_weights, Tensor uvm_weights, Tensor weights_placements, Tensor weights_offsets, Tensor weights_tys, Tensor D_offsets, SymInt total_D, int max_int2_D, int max_int4_D, int max_int8_D, int max_float16_D, int max_float32_D, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights=None, int output_dtype=1, Tensor? lxu_cache_weights=None, Tensor? lxu_cache_locations=None, int? row_alignment=-1, int? max_float8_D=0, int? fp8_exponent_bits=-1, int? fp8_exponent_bias=-1, Tensor? cache_hash_size_cumsum=None, int? total_cache_hash_size=-1, Tensor? cache_index_table_map=None, Tensor? lxu_cache_state=None, Tensor? lxu_state=None) -> Tensor"); DISPATCH_TO_CPU( "int_nbit_split_embedding_uvm_caching_codegen_lookup_function", int_nbit_split_embedding_uvm_caching_codegen_lookup_function_cpu); diff --git a/fbgemm_gpu/codegen/embedding_forward_split_kernel_v2_template.cu b/fbgemm_gpu/codegen/embedding_forward_split_kernel_v2_template.cu index 61b941e4d1..923b52baf4 100644 --- a/fbgemm_gpu/codegen/embedding_forward_split_kernel_v2_template.cu +++ b/fbgemm_gpu/codegen/embedding_forward_split_kernel_v2_template.cu @@ -130,7 +130,7 @@ __inline__ __device__ void process_all_indices_no_pooling( const auto total_load_D = static_cast(smem[params_offset + SAVED_PARAMS::P_total_load_D]); // Each thread loads a separate weight ptr - const auto weight_ptrs = reinterpret_cast(&weights[indices[threadIdx.x] * load_D]); + const auto weight_ptrs = reinterpret_cast(&weights[indices[threadIdx.x] * load_D]); // Assuming kWarpSize is a multiple of STEP for (uint32_t l_start = 0; l_start < TOTAL_L; l_start += STEP) { @@ -332,8 +332,8 @@ __noinline__ __device__ void process_all_indices_small_Ls( const cache_t* lxu_cache_weights = reinterpret_cast(smem[params_offset + LXU_CACHE_PARAMS::P_lxu_cache_weights]); SMEM_GENERIC_PTR[threadIdx.x] = cache_idx != kCacheLocationMissing ? - reinterpret_cast(&lxu_cache_weights[cache_idx * max_D_cache]) : - reinterpret_cast(&weights[indices[l] * load_D]); + reinterpret_cast(&lxu_cache_weights[cache_idx * max_D_cache]) : + reinterpret_cast(&weights[indices[l] * load_D]); } if (!std::is_same::value) { cache_look_up_bits = ballot_sync(cache_idx != kCacheLocationMissing); @@ -558,8 +558,8 @@ __noinline__ __device__ void process_all_indices_large_Ls( const auto* lxu_cache_weights = reinterpret_cast(smem[params_offset + LXU_CACHE_PARAMS::P_lxu_cache_weights]); SMEM_GENERIC_PTR[threadIdx.x] = cache_idx != kCacheLocationMissing ? 
- reinterpret_cast(&lxu_cache_weights[cache_idx * max_D_cache]) : - reinterpret_cast(&weights[indices[l] * load_D]); + reinterpret_cast(&lxu_cache_weights[cache_idx * max_D_cache]) : + reinterpret_cast(&weights[indices[l] * load_D]); } if (!std::is_same::value) { cache_look_up_bits = ballot_sync(cache_idx != kCacheLocationMissing); diff --git a/fbgemm_gpu/codegen/embedding_forward_split_meta_template.cpp b/fbgemm_gpu/codegen/embedding_forward_split_meta_template.cpp index f46d56c096..0c8c930ecd 100644 --- a/fbgemm_gpu/codegen/embedding_forward_split_meta_template.cpp +++ b/fbgemm_gpu/codegen/embedding_forward_split_meta_template.cpp @@ -85,22 +85,20 @@ Tensor ) { // NB: omitted the device tests TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL - // TODO: SymIntify - {%- if not nobag %} - int32_t T = D_offsets.numel() - 1; + auto T = D_offsets.sym_numel() - 1; {%- else %} - int32_t total_L = indices.numel(); - int32_t T = weights_offsets.numel(); + auto total_L = indices.sym_numel(); + auto T = weights_offsets.sym_numel(); {%- endif %} TORCH_CHECK_GT(T, 0); // offsets = [B x T + 1] {%- if is_index_select %} const auto total_B = num_warps_per_feature * T; - const int32_t B = num_warps_per_feature; + const auto B = num_warps_per_feature; {%- else %} - const auto total_B = offsets.size(0) - 1; - const int32_t B = total_B / T; + const auto total_B = offsets.sym_size(0) - 1; + const auto B = total_B / T; {%- endif %} TORCH_CHECK_GE(B, 0); {%- if not nobag or is_index_select %} @@ -114,7 +112,7 @@ Tensor TORCH_CHECK_EQ(D % 4, 0); {%- endif %} {%- if vbe %} - TORCH_CHECK_EQ(vbe_row_output_offsets.numel(), total_B); + TORCH_CHECK_EQ(vbe_row_output_offsets.sym_numel(), total_B); TENSORS_HAVE_SAME_NUMEL(vbe_row_output_offsets, vbe_b_t_map); TORCH_CHECK_GE(vbe_output_size, 0); @@ -133,40 +131,41 @@ Tensor TORCH_CHECK_GT(num_warps_per_feature, 0); if (!permute_output_dim_0_1) { TORCH_CHECK_GE(output_size, 0); - TORCH_CHECK_GT(output_offsets.numel(), 0); + TORCH_CHECK_GT(output_offsets.sym_numel(), 0); } // If permute_output_dim_0_1 is true, output shape is (batch_size * total_D) // Else, output shape is (output_size) - output = at::empty({output_size}, dev_weights.options().dtype(getScalarType(o_dtype))); + output = at::empty_symint({output_size}, dev_weights.options().dtype(getScalarType(o_dtype))); {%- else %} TORCH_CHECK(o_dtype == SparseType::FP32 || o_dtype == SparseType::FP16 || o_dtype == SparseType::BF16 || o_dtype == SparseType::INT8); - int64_t adjusted_D = D; + c10::SymInt adjusted_D = D; if (o_dtype == SparseType::INT8) { - adjusted_D += T * kINT8QparamsBytes; + adjusted_D += T * int64_t(kINT8QparamsBytes); } - output = at::empty({total_L, adjusted_D}, dev_weights.options().dtype(getScalarType(o_dtype))); + output = at::empty_symint({total_L, adjusted_D}, dev_weights.options().dtype(getScalarType(o_dtype))); {%- endif %} {%- else %} SparseType o_dtype = static_cast(output_dtype); TORCH_CHECK(o_dtype == SparseType::FP32 || o_dtype == SparseType::FP16 || o_dtype == SparseType::BF16 || o_dtype == SparseType::INT8); - int64_t total_adjusted_D = total_D; + c10::SymInt total_adjusted_D = total_D; if (o_dtype == SparseType::INT8) { - total_adjusted_D += T * kINT8QparamsBytes; + // TODO: Why is kINT8QparamsBytes a float + total_adjusted_D += T * int64_t(kINT8QparamsBytes); } {%- if vbe %} // Use a 2D tensor to make it compatible with 2D PackedTensorsAccessor of other output - output = at::empty( + output = at::empty_symint( {1, vbe_output_size}, dev_weights.options().dtype(getScalarType(o_dtype)) ); {%- 
else %} - output = at::empty( + output = at::empty_symint( {B, total_adjusted_D}, dev_weights.options().dtype(getScalarType(o_dtype)) ); diff --git a/fbgemm_gpu/codegen/embedding_optimizer_split_device_kernel_template.cuh b/fbgemm_gpu/codegen/embedding_optimizer_split_device_kernel_template.cuh index f5d4b68c9d..5bdfa415ce 100644 --- a/fbgemm_gpu/codegen/embedding_optimizer_split_device_kernel_template.cuh +++ b/fbgemm_gpu/codegen/embedding_optimizer_split_device_kernel_template.cuh @@ -31,10 +31,10 @@ DEVICE_INLINE void split_{{ optimizer }}_table_update_kernel( const bool stochastic_rounding, const at::PhiloxCudaState& stochastic_rounding_philox_args, const uint32_t run_id, + const uint32_t cache_loc_run_id, const int32_t D, const int32_t t, const int64_t idx, - const int32_t segment_start, const uint32_t shfl_sync_mask, const int32_t shared_weight_offset, {{ args.split_ref_kernel_args | replace_pta_namespace() | join(",\n ") }} @@ -54,7 +54,7 @@ DEVICE_INLINE void split_{{ optimizer }}_table_update_kernel( weights = &uvm_weights[weights_offset + idx * D_emb]; } if (weights_placement == PlacementType::MANAGED_CACHING) { - const int32_t cache_idx = sorted_lxu_cache_locations[segment_start]; + const int32_t cache_idx = sorted_lxu_cache_locations[cache_loc_run_id]; if (cache_idx != kCacheLocationMissing) { cache_weights = &lxu_cache_weights[cache_idx][0]; } diff --git a/fbgemm_gpu/codegen/embedding_optimizer_split_kernel_template.cu b/fbgemm_gpu/codegen/embedding_optimizer_split_kernel_template.cu index 310a0c6100..73f3721cfb 100644 --- a/fbgemm_gpu/codegen/embedding_optimizer_split_kernel_template.cu +++ b/fbgemm_gpu/codegen/embedding_optimizer_split_kernel_template.cu @@ -71,11 +71,11 @@ void split_{{ optimizer }}_update_kernel( stochastic_rounding, stochastic_rounding_philox_args, run_id, + 0, // segment_start (not used right now because lxu_cache is not + // supported) D, 0, // t grad_dev_indices[run_id], // idx - 0, // segment_start (not used right now because lxu_cache is not - // supported) shfl_sync_mask, 0, // shared_weight_offset (not used because shared memory is not // needed as uint8_t is not supported) diff --git a/fbgemm_gpu/codegen/split_embedding_codegen_lookup_invoker.template b/fbgemm_gpu/codegen/split_embedding_codegen_lookup_invoker.template index 5af6e320f0..3a00d7d1fe 100644 --- a/fbgemm_gpu/codegen/split_embedding_codegen_lookup_invoker.template +++ b/fbgemm_gpu/codegen/split_embedding_codegen_lookup_invoker.template @@ -13,6 +13,7 @@ from .lookup_args import * {%- if is_fbcode %} + # Provide compatibility to downstream packages for eventual migration to the split training / inference packages try: torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_cuda_training") @@ -28,9 +29,6 @@ torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:split_table_batched_emb torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:embedding_inplace_update") torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:embedding_inplace_update_cpu") -{%- else %} -# import os -# torch.ops.load_library(os.path.join(os.path.join(os.path.dirname(os.path.dirname(__file__)), "fbgemm_gpu_py.so"))) {%- endif %} diff --git a/fbgemm_gpu/fbgemm_gpu/__init__.py b/fbgemm_gpu/fbgemm_gpu/__init__.py index f63fde22d2..3ca97566b7 100644 --- a/fbgemm_gpu/fbgemm_gpu/__init__.py +++ b/fbgemm_gpu/fbgemm_gpu/__init__.py @@ -20,7 +20,7 @@ # Re-export docs # Trigger meta registrations -from . import _fbgemm_gpu_docs, sparse_operators # noqa: F401, E402 # noqa: F401, E402 +from . 
import _fbgemm_gpu_docs, sparse_ops # noqa: F401, E402 # noqa: F401, E402 # Re-export the version string from the auto-generated version file from ._fbgemm_gpu_version import __version__ # noqa: F401, E402 diff --git a/fbgemm_gpu/fbgemm_gpu/batched_unary_embeddings_ops.py b/fbgemm_gpu/fbgemm_gpu/batched_unary_embeddings_ops.py index 52b4587088..db9b4c48da 100644 --- a/fbgemm_gpu/fbgemm_gpu/batched_unary_embeddings_ops.py +++ b/fbgemm_gpu/fbgemm_gpu/batched_unary_embeddings_ops.py @@ -5,7 +5,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe from math import sqrt from typing import List @@ -28,6 +27,7 @@ def wrap_weight_to_parameter(weights: List[torch.Tensor]) -> List[torch.Tensor]: class BatchedUnaryEmbeddingBag(torch.nn.Module): + # pyre-fixme[3]: Return type must be annotated. def __init__(self, num_tasks: int, hash_sizes: List[int], long_index: bool = False): super().__init__() self.num_tasks = num_tasks @@ -49,6 +49,7 @@ def __init__(self, num_tasks: int, hash_sizes: List[int], long_index: bool = Fal self.register_buffer("table_offsets_tensor", table_offsets_tensor) self.init_parameters() + # pyre-fixme[3]: Return type must be annotated. def forward(self, offsets: torch.Tensor, input: torch.Tensor): # output is [N][B][T] return torch.ops.fbgemm.batched_unary_embeddings( @@ -59,6 +60,7 @@ def forward(self, offsets: torch.Tensor, input: torch.Tensor): ) @torch.jit.export + # pyre-fixme[3]: Return type must be annotated. def split_embedding_weights(self): embedding_weights = [] for n in range(self.num_tasks): @@ -73,6 +75,7 @@ def split_embedding_weights(self): return embedding_weights @torch.jit.export + # pyre-fixme[3]: Return type must be annotated. def init_parameters(self): for num_emb, param in zip( self.hash_sizes * self.num_tasks, diff --git a/fbgemm_gpu/fbgemm_gpu/permute_pooled_embedding_modules.py b/fbgemm_gpu/fbgemm_gpu/permute_pooled_embedding_modules.py index 4ce1cce9ad..cb66088b9f 100644 --- a/fbgemm_gpu/fbgemm_gpu/permute_pooled_embedding_modules.py +++ b/fbgemm_gpu/fbgemm_gpu/permute_pooled_embedding_modules.py @@ -10,7 +10,6 @@ from typing import List, Optional import torch -from torch import nn try: # pyre-ignore[21] @@ -24,48 +23,37 @@ ) -class PermutePooledEmbeddings(nn.Module): +class PermutePooledEmbeddings: def __init__( self, embs_dims: List[int], permute: List[int], device: Optional[torch.device] = None, ) -> None: - super(PermutePooledEmbeddings, self).__init__() logging.info("Using Permute Pooled Embeddings") - - self.register_buffer( - "_offset_dim_list", - torch.tensor( - [0] + list(accumulate(embs_dims)), device=device, dtype=torch.int64 - ), + self._offset_dim_list: torch.Tensor = torch.tensor( + [0] + list(accumulate(embs_dims)), device=device, dtype=torch.int64 ) - self.register_buffer( - "_permute", torch.tensor(permute, device=device, dtype=torch.int64) + + self._permute: torch.Tensor = torch.tensor( + permute, device=device, dtype=torch.int64 ) inv_permute: List[int] = [0] * len(permute) for i, p in enumerate(permute): inv_permute[p] = i - self.register_buffer( - "_inv_permute", torch.tensor(inv_permute, device=device, dtype=torch.int64) + self._inv_permute: torch.Tensor = torch.tensor( + inv_permute, device=device, dtype=torch.int64 ) - # `Union[BoundMethod[typing.Callable(torch.Tensor.tolist)[[Named(self, - # torch.Tensor)], List[typing.Any]], torch.Tensor], nn.Module, torch.Tensor]` - # is not a function. 
- inv_embs_dims = [embs_dims[i] for i in permute] - self.register_buffer( - "_inv_offset_dim_list", - torch.tensor( - [0] + list(accumulate(inv_embs_dims)), device=device, dtype=torch.int64 - ), + self._inv_offset_dim_list: torch.Tensor = torch.tensor( + [0] + list(accumulate(inv_embs_dims)), device=device, dtype=torch.int64 ) - def forward(self, pooled_embs: torch.Tensor) -> torch.Tensor: + def __call__(self, pooled_embs: torch.Tensor) -> torch.Tensor: result = torch.ops.fbgemm.permute_pooled_embs_auto_grad( pooled_embs, self._offset_dim_list.to(device=pooled_embs.device), diff --git a/fbgemm_gpu/fbgemm_gpu/sparse_operators.py b/fbgemm_gpu/fbgemm_gpu/sparse_operators.py deleted file mode 100644 index d0e3958090..0000000000 --- a/fbgemm_gpu/fbgemm_gpu/sparse_operators.py +++ /dev/null @@ -1,106 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from typing import Callable, Optional, Tuple - -import torch -from torch import Tensor - -try: - # pyre-ignore - from fbgemm_gpu import open_source # noqa: F401 -except Exception: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") - - -if hasattr(torch.library, "impl_abstract"): - impl_abstract = torch.library.impl_abstract -else: - # pyre-ignore - def impl_abstract(schema: str) -> Callable[[Callable], Callable]: - # no-op - # pyre-ignore - def wrapper(f: Callable) -> Callable: - return f - - return wrapper - - -@impl_abstract("fbgemm::permute_2D_sparse_data") -def permute_2D_sparse_data_meta( - permute: Tensor, - lengths: Tensor, - values: Tensor, - weights: Optional[Tensor] = None, - permuted_lengths_sum: Optional[int] = None, -) -> Tuple[Tensor, Tensor, Optional[Tensor]]: - torch._check( - lengths.dim() == 2, lambda: f"expected lengths.dim() == 2, got {lengths.dim()}" - ) - T = permute.numel() - B = lengths.size(1) - indices = values - permuted_lengths = lengths.new_empty([T, B]) - permuted_indices_size = 0 - if permuted_lengths_sum is not None: - permuted_indices_size = permuted_lengths_sum - else: - ctx = torch._custom_op.impl.get_ctx() - permuted_indices_size = ctx.new_dynamic_size() - # pyre-fixme - permuted_indices = indices.new_empty(permuted_indices_size) - permuted_weights = None - if weights is not None: - # pyre-fixme - permuted_weights = weights.new_empty(permuted_indices_size) - return permuted_lengths, permuted_indices, permuted_weights - - -@impl_abstract("fbgemm::permute_1D_sparse_data") -def permute_1D_sparse_data_meta( - permute: Tensor, - lengths: Tensor, - values: Tensor, - weights: Optional[Tensor] = None, - permuted_lengths_sum: Optional[int] = None, -) -> Tuple[Tensor, Tensor, Optional[Tensor]]: - indices = values - permuted_lengths_size = permute.numel() - permuted_lengths = lengths.new_empty([permuted_lengths_size]) - permuted_indices_size = 0 - if permuted_lengths_sum is not None: - permuted_indices_size = permuted_lengths_sum - else: - ctx = torch._custom_op.impl.get_ctx() - permuted_indices_size = ctx.new_dynamic_size() - # pyre-fixme - permuted_indices = indices.new_empty(permuted_indices_size) - permuted_weights = None - if weights is not None: - # pyre-fixme - permuted_weights = weights.new_empty(permuted_indices_size) - return permuted_lengths, permuted_indices, permuted_weights - - -@impl_abstract("fbgemm::expand_into_jagged_permute") -def 
expand_into_jagged_permute_meta( - permute: Tensor, - input_offsets: Tensor, - output_offsets: Tensor, - output_size: Tuple[int, ...], -) -> Tensor: - torch._check(permute.numel() > 0, lambda: "expected {permute.numel} > 0") - torch._check( - permute.numel() == input_offsets.numel() - 1, - lambda: f"expected {permute.numel()} == {input_offsets.numel()} - 1", - ) - torch._check( - permute.numel() == output_offsets.numel() - 1, - lambda: f"expected {permute.numel()} == {output_offsets.numel()} - 1", - ) - output_permute = input_offsets.new_empty(output_size) - return output_permute diff --git a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py new file mode 100644 index 0000000000..0979959a53 --- /dev/null +++ b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py @@ -0,0 +1,373 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Callable, List, Optional, Tuple + +import torch + +from fbgemm_gpu.split_embedding_configs import SparseType +from fbgemm_gpu.split_table_batched_embeddings_ops_common import PoolingMode + +try: + # pyre-ignore + from fbgemm_gpu import open_source # noqa: F401 +except Exception: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:merge_pooled_embeddings") + torch.ops.load_library( + "//deeplearning/fbgemm/fbgemm_gpu:merge_pooled_embeddings_cpu" + ) + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_cpu") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:input_combine") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:input_combine_cpu") + +from torch import Tensor + + +if hasattr(torch.library, "impl_abstract"): + impl_abstract = torch.library.impl_abstract +else: + # pyre-ignore + def impl_abstract(schema: str) -> Callable[[Callable], Callable]: + # no-op + # pyre-ignore + def wrapper(f: Callable) -> Callable: + return f + + return wrapper + + +@impl_abstract("fbgemm::permute_2D_sparse_data") +def permute_2D_sparse_data_meta( + permute: Tensor, + lengths: Tensor, + values: Tensor, + weights: Optional[Tensor] = None, + permuted_lengths_sum: Optional[int] = None, +) -> Tuple[Tensor, Tensor, Optional[Tensor]]: + torch._check( + lengths.dim() == 2, lambda: f"expected lengths.dim() == 2, got {lengths.dim()}" + ) + T = permute.numel() + B = lengths.size(1) + indices = values + permuted_lengths = lengths.new_empty([T, B]) + permuted_indices_size = 0 + if permuted_lengths_sum is not None: + permuted_indices_size = permuted_lengths_sum + else: + ctx = torch._custom_op.impl.get_ctx() + permuted_indices_size = ctx.new_dynamic_size() + # pyre-fixme + permuted_indices = indices.new_empty(permuted_indices_size) + permuted_weights = None + if weights is not None: + # pyre-fixme + permuted_weights = weights.new_empty(permuted_indices_size) + return permuted_lengths, permuted_indices, permuted_weights + + +@impl_abstract("fbgemm::permute_1D_sparse_data") +def permute_1D_sparse_data_meta( + permute: Tensor, + lengths: Tensor, + values: Tensor, + weights: Optional[Tensor] = None, + permuted_lengths_sum: Optional[int] = None, +) -> Tuple[Tensor, Tensor, Optional[Tensor]]: + indices = values + permuted_lengths_size 
= permute.numel() + permuted_lengths = lengths.new_empty([permuted_lengths_size]) + permuted_indices_size = 0 + if permuted_lengths_sum is not None: + permuted_indices_size = permuted_lengths_sum + else: + ctx = torch._custom_op.impl.get_ctx() + permuted_indices_size = ctx.new_dynamic_size() + # pyre-fixme + permuted_indices = indices.new_empty(permuted_indices_size) + permuted_weights = None + if weights is not None: + # pyre-fixme + permuted_weights = weights.new_empty(permuted_indices_size) + return permuted_lengths, permuted_indices, permuted_weights + + +@impl_abstract("fbgemm::masked_select_jagged_1d") +def masked_select_jagged_1d( + values: Tensor, lengths: Tensor, mask: Tensor +) -> Tuple[Tensor, Tensor]: + torch._check(values.dim() == 1) + torch._check(lengths.dim() == 1) + torch._check(values.device == lengths.device) + torch._check(values.device == mask.device) + + s0 = torch.library.get_ctx().new_dynamic_size() + masked_values = values.new_empty([s0]) + masked_lengths = torch.empty_like(lengths) + return masked_values, masked_lengths + + +@impl_abstract("fbgemm::tbe_input_combine") +def tbe_input_combine_abstract( + indices_list: List[Tensor], + offsets_list: List[Tensor], + per_sample_weights: List[Tensor], + include_last_offsets: Tensor, +) -> Tuple[Tensor, Tensor, Tensor]: + torch._check(len(indices_list) > 0) + torch._check(len(indices_list) == len(offsets_list)) + torch._check(len(indices_list) == len(per_sample_weights)) + torch._check(len(indices_list) == include_last_offsets.numel()) + total_indices = 0 + need_weight = False + for index, offset, weight in zip(indices_list, offsets_list, per_sample_weights): + torch._check(index.dtype == torch.int or index.dtype == torch.long) + torch._check(offset.dtype == torch.int or offset.dtype == torch.long) + torch._check(index.dim() == 1) + torch._check(offset.dim() == 1) + torch._check(index.is_contiguous()) + torch._check(offset.is_contiguous()) + total_indices = total_indices + index.numel() + if weight.numel() > 0: + torch._check(weight.dim() == 1) + torch._check(weight.numel() == index.numel()) + torch._check(weight.is_contiguous()) + need_weight = True + total_offsets = torch.library.get_ctx().new_dynamic_size() + combined_indices = indices_list[0].new_empty([total_indices], dtype=torch.int) + combined_offsets = offsets_list[0].new_empty([total_offsets], dtype=torch.int) + if need_weight: + combined_weights = per_sample_weights[0].new_empty( + [total_indices], dtype=torch.float + ) + else: + combined_weights = torch.empty(0) + return combined_indices, combined_offsets, combined_weights + + +@impl_abstract("fbgemm::jagged_index_select_2d_forward_v2") +def jagged_index_select_2d_forward_v2_abstract( + values: Tensor, + indices: Tensor, + input_offsets: Tensor, + output_offsets: Tensor, +) -> Tensor: + torch._check(values.device == indices.device) + torch._check(values.device == input_offsets.device) + torch._check(values.device == output_offsets.device) + torch._check(values.dim() == 2) + num_dense_output_rows = torch.library.get_ctx().new_dynamic_size() + num_cols = values.size(1) + return values.new_empty([num_dense_output_rows, num_cols]) + + +@impl_abstract("fbgemm::jagged_index_add_2d_forward_v2") +def jagged_index_add_2d_forward_v2_abstract( + values: Tensor, + indices: Tensor, + input_offsets: Tensor, + output_offsets: Tensor, + num_output_rows: int, +) -> Tensor: + torch._check(values.device == indices.device) + torch._check(values.device == input_offsets.device) + torch._check(values.device == output_offsets.device) 
+ torch._check(values.dim() == 2) + num_cols = values.size(1) + return values.new_empty([num_output_rows, num_cols]) + + +@impl_abstract("fbgemm::expand_into_jagged_permute") +def expand_into_jagged_permute_meta( + permute: Tensor, + input_offsets: Tensor, + output_offsets: Tensor, + output_size: Tuple[int, ...], +) -> Tensor: + torch._check(permute.numel() > 0, lambda: "expected {permute.numel} > 0") + torch._check( + permute.numel() == input_offsets.numel() - 1, + lambda: f"expected {permute.numel()} == {input_offsets.numel()} - 1", + ) + torch._check( + permute.numel() == output_offsets.numel() - 1, + lambda: f"expected {permute.numel()} == {output_offsets.numel()} - 1", + ) + output_permute = input_offsets.new_empty(output_size) + return output_permute + + +@impl_abstract("fbgemm::int_nbit_split_embedding_codegen_lookup_function") +def int_nbit_split_embedding_codegen_lookup_function_meta( + dev_weights: torch.Tensor, + uvm_weights: torch.Tensor, + weights_placements: torch.Tensor, + weights_offsets: torch.Tensor, + weights_tys: torch.Tensor, + D_offsets: torch.Tensor, + total_D: int, + max_int2_D: int, + max_int4_D: int, + max_int8_D: int, + max_float16_D: int, + max_float32_D: int, + indices: torch.Tensor, + offsets: torch.Tensor, + pooling_mode: int, + indice_weights: Optional[torch.Tensor] = None, + output_dtype_int: Optional[int] = None, + lxu_cache_weights: Optional[torch.Tensor] = None, + lxu_cache_locations: Optional[torch.Tensor] = None, + row_alignment: Optional[int] = None, + max_float8_D: Optional[int] = None, + fp8_exponent_bits: Optional[int] = None, + fp8_exponent_bias: Optional[int] = None, +) -> torch.Tensor: + T = D_offsets.numel() - 1 + B = (offsets.size(0) - 1) // T + output_dtype = torch.float32 + torch._check( + output_dtype_int in (0, 1, 5), + lambda: f"expected output_dtype to be fp32, fp16 or bf16, got {indices.dtype}", + ) + if output_dtype_int == SparseType.FP32.value: + output_dtype = torch.float32 + elif output_dtype_int == SparseType.FP16.value: + output_dtype = torch.float16 + elif output_dtype_int == SparseType.BF16.value: + output_dtype = torch.bfloat16 + + if pooling_mode == PoolingMode.NONE: + # pyre-ignore + offsets_last: int = offsets[-1].item() + total_D_T: int = total_D // T + torch._check_is_size(offsets[-1].item()) + torch._check_is_size(total_D_T) + torch._check_is_size(B) + return dev_weights.new_empty( + [offsets_last, total_D_T], + dtype=output_dtype, + device=torch.device("meta"), + ) + torch._check_is_size(B) + torch._check_is_size(total_D) + return dev_weights.new_empty( + (B, total_D), + dtype=output_dtype, + device=torch.device("meta"), + ) + + +@impl_abstract("fbgemm::block_bucketize_sparse_features") +def block_bucketize_sparse_features_meta( + lengths: torch.Tensor, + indices: torch.Tensor, + bucketize_pos: bool, + sequence: bool, + block_sizes: torch.Tensor, + my_size: int, + weights: Optional[torch.Tensor] = None, + batch_size_per_feature: Optional[torch.Tensor] = None, + max_B: int = -1, +) -> Tuple[ + torch.Tensor, + torch.Tensor, + Optional[torch.Tensor], + Optional[torch.Tensor], + Optional[torch.Tensor], +]: + # Output: lengths, indices, weights", pos?, unbucketize_permute? 
+ num_buckets = my_size + num_features = lengths.size(0) + num_values = indices.size(0) + return ( + lengths.new_empty([num_buckets * num_features]), + indices.new_empty([num_values]), + weights.new_empty(weights.shape) if weights is not None else None, + indices.new_empty([num_values]) if bucketize_pos else None, + indices.new_empty([num_values]), + ) + + +@impl_abstract("fbgemm::merge_pooled_embeddings") +def merge_pooled_embeddings( + pooled_embeddings: List[torch.Tensor], + uncat_dim_size: int, + target_device: torch.device, + cat_dim: int = 1, +) -> torch.Tensor: + if len(pooled_embeddings) == 0: + return torch.empty([], device=target_device) + torch._check_is_size(cat_dim) + torch._check(cat_dim >= 0) + torch._check(cat_dim <= 1) + total_cat_dim_size = 0 + for e in pooled_embeddings: + torch._check(e.dim() == 2) + torch._check(e.size(1 - cat_dim) == uncat_dim_size) + total_cat_dim_size += e.size(cat_dim) + torch._check_is_size(total_cat_dim_size) + e = pooled_embeddings[0] + if cat_dim == 0: + return e.new_empty( + [total_cat_dim_size, e.size(1)], + device=target_device, + ) + + return e.new_empty( + [e.size(0), total_cat_dim_size], + device=target_device, + ) + + +@impl_abstract("fbgemm::bounds_check_indices") +def bounds_check_indices( + rows_per_table: torch.Tensor, + indices: torch.Tensor, + offsets: torch.Tensor, + bounds_check_mode: int, + warning: torch.Tensor, + weights: Optional[torch.Tensor] = None, + B_offsets: Optional[torch.Tensor] = None, + max_B: int = -1, +) -> None: + pass + + +@impl_abstract("fbgemm::permute_sparse_features") +def permute_sparse_features_abstract( + permute: Tensor, lengths: Tensor, indices: Tensor, weights: Optional[Tensor] = None +) -> Tuple[Tensor, Tensor, Optional[Tensor]]: + torch._check(lengths.dtype == indices.dtype) + torch._check(permute.device == lengths.device) + torch._check(permute.device == indices.device) + if weights is not None: + torch._check(permute.device == weights.device) + num_output_features = permute.numel() + B = lengths.size(1) + permuted_lengths = lengths.new_empty(num_output_features, B) + output_size = torch.library.get_ctx().new_dynamic_size() + # pyre-fixme[6]: In call `torch._C.TensorBase.new_empty`, for 1st positional argument, + # expected `Sequence[Union[int, types.SymInt]]` but got `Union[int, torch.SymInt]` + permuted_indices = indices.new_empty(output_size) + permuted_weights = None + if weights is not None: + # pyre-fixme[6]: In call `torch._C.TensorBase.new_empty`, for 1st positional argument, + # expected `Sequence[Union[int, types.SymInt]]` but got `Union[int, torch.SymInt]` + permuted_weights = weights.new_empty(output_size) + return (permuted_lengths, permuted_indices, permuted_weights) + + +@impl_abstract("fbgemm::segment_sum_csr") +def segment_sum_csr_abstract( + batch_size: int, csr_seg: Tensor, values: Tensor +) -> Tensor: + output_size = csr_seg.numel() - 1 + output = values.new_empty(output_size) + return output diff --git a/fbgemm_gpu/fbgemm_gpu/split_embedding_inference_converter.py b/fbgemm_gpu/fbgemm_gpu/split_embedding_inference_converter.py index 86e53afbea..a2dd34ade5 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_embedding_inference_converter.py +++ b/fbgemm_gpu/fbgemm_gpu/split_embedding_inference_converter.py @@ -5,7 +5,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
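Aside on the new fbgemm_gpu/fbgemm_gpu/sparse_ops.py introduced above: the impl_abstract registrations only describe output dtypes and shapes, so their logic can be exercised directly on meta tensors without the CUDA/CPU kernels. The sketch below is illustrative only and is not part of this patch; it assumes an installed fbgemm_gpu build so that fbgemm_gpu.sparse_ops imports cleanly.

    import torch
    from fbgemm_gpu.sparse_ops import merge_pooled_embeddings

    # Three pooled-embedding tensors with 8 rows each and widths 4, 6, and 2.
    embs = [torch.empty(8, d, device="meta") for d in (4, 6, 2)]

    # The meta function concatenates along cat_dim=1, so the result is [8, 4 + 6 + 2].
    out = merge_pooled_embeddings(
        embs, uncat_dim_size=8, target_device=torch.device("meta"), cat_dim=1
    )
    assert out.shape == (8, 12)

Ops whose output size is data-dependent (permute_2D_sparse_data, tbe_input_combine, jagged_index_select_2d_forward_v2, ...) instead allocate a symbolic size via new_dynamic_size(), which only resolves under FakeTensor tracing rather than on plain meta tensors.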
-# pyre-unsafe import logging import math @@ -29,6 +28,7 @@ # TODO: add per-feature based converter option (based on embedding_specs during inference) # TODO: optimize embedding pruning and quantization latency. class SplitEmbInferenceConverter: + # pyre-fixme[3]: Return type must be annotated. def __init__( self, quantize_type: SparseType, @@ -46,6 +46,7 @@ def convert_model(self, model: torch.nn.Module) -> torch.nn.Module: self._process_split_embs(model) return model + # pyre-fixme[2]: Parameter must be annotated. def _prune_by_weights_l2_norm(self, new_num_rows, weights) -> Tuple[Tensor, float]: assert new_num_rows > 0 from numpy.linalg import norm @@ -83,6 +84,8 @@ def _prune_embs( weights, indicators, threshold, torch.int32 ) + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def _get_quantization_config(self, name): quantization_config = self.quantization_config if quantization_config is None: diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py index 301dad837a..d5b0d66d23 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py @@ -505,10 +505,6 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> None: if not self.lxu_cache_weights.numel(): return - # FIXME: check the int32_t range failure in https://fburl.com/gdoc/kcdnrnvg . - # The real failure should be in cache handling in https://fburl.com/ox3f26r0 . - indices, offsets = indices.long(), offsets.long() - linear_cache_indices = torch.ops.fbgemm.linearize_cache_indices( self.cache_hash_size_cumsum, indices, diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py index cfea8e1f88..36bf5f43f9 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py @@ -280,7 +280,6 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module): record_cache_metrics: RecordCacheMetrics uvm_cache_stats: torch.Tensor local_uvm_cache_stats: torch.Tensor - linear_cache_indices_list: List[Tensor] def __init__( # noqa C901 self, @@ -350,8 +349,6 @@ def __init__( # noqa C901 self.embedding_specs = embedding_specs (rows, dims, locations, compute_devices) = zip(*embedding_specs) T_ = len(self.embedding_specs) - # pyre-fixme[8]: Attribute has type `List[int]`; used as - # `Tuple[Union[ComputeDevice, EmbeddingLocation, int]]`. self.dims: List[int] = dims assert T_ > 0 # mixed D is not supported by no bag kernels @@ -701,10 +698,6 @@ def __init__( # noqa C901 persistent=False, ) - # pyre-fixme[6]: For 1st argument expected `List[int]` but got - # `Tuple[Union[ComputeDevice, EmbeddingLocation, int]]`. - # pyre-fixme[6]: For 2nd argument expected `List[EmbeddingLocation]` but got - # `Tuple[Union[ComputeDevice, EmbeddingLocation, int]]`. 
cache_state = construct_cache_state(rows, locations, self.feature_table_map) # Add table-wise cache miss counter @@ -941,6 +934,11 @@ def forward( # noqa: C901 B_offsets=vbe_metadata.B_offsets, max_B=vbe_metadata.max_B, ) + + # Storing indices and offsets for linear_cache_indices recomputation + self._indices = indices + self._offsets = offsets + self.step += 1 if len(self.timesteps_prefetched) == 0: self._prefetch(indices, offsets) @@ -1160,7 +1158,6 @@ def _prefetch(self, indices: Tensor, offsets: Tensor) -> None: if not self.lxu_cache_weights.numel(): return - (indices, offsets) = indices.long(), offsets.long() linear_cache_indices = torch.ops.fbgemm.linearize_cache_indices( self.cache_hash_size_cumsum, indices, @@ -1234,8 +1231,6 @@ def _prefetch(self, indices: Tensor, offsets: Tensor) -> None: ) self.lxu_cache_locations_list.append(lxu_cache_locations) - if self.prefetch_pipeline: - self.linear_cache_indices_list.append(linear_cache_indices) if self.gather_uvm_cache_stats: # Accumulate local_uvm_cache_stats (int32) into uvm_cache_stats (int64). @@ -1255,8 +1250,6 @@ def _prefetch_tensors_record_stream( for t in self.lxu_cache_locations_list: t.record_stream(forward_stream) - for t in self.linear_cache_indices_list: - t.record_stream(forward_stream) def _update_cache_miss_counter( self, @@ -1580,8 +1573,9 @@ def _apply_cache_state( 0, device=self.current_device, dtype=torch.int32 ).fill_(-1) self.lxu_cache_locations = self.lxu_cache_locations_empty + self._indices = self.lxu_cache_locations_empty + self._offsets = self.lxu_cache_locations_empty self.prefetch_stream: Optional[torch.cuda.Stream] = None - self.linear_cache_indices_list = [] self._init_uvm_cache_stats() @@ -1796,7 +1790,12 @@ def _update_cache_counter_and_locations( self.lxu_cache_locations, ) - linear_cache_indices = self.linear_cache_indices_list.pop(0) + # Recompute linear_cache_indices + linear_cache_indices = torch.ops.fbgemm.linearize_cache_indices( + self.cache_hash_size_cumsum, + self._indices, + self._offsets, + ) lxu_cache_locations_new = torch.ops.fbgemm.lxu_cache_lookup( linear_cache_indices, self.lxu_cache_state, diff --git a/fbgemm_gpu/include/fbgemm_gpu/dispatch_macros.h b/fbgemm_gpu/include/fbgemm_gpu/dispatch_macros.h index f268735d04..834a226ce4 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/dispatch_macros.h +++ b/fbgemm_gpu/include/fbgemm_gpu/dispatch_macros.h @@ -201,3 +201,10 @@ #define FBGEMM_DISPATCH_FLOAT_HALF_AND_BFLOAT16(TYPE, NAME, ...) \ AT_DISPATCH_SWITCH( \ TYPE, NAME, FBGEMM_DISPATCH_FLOAT_HALF_AND_BFLOAT16_CASE(__VA_ARGS__)) + +// We can cleanup the following once fbgemm uses PyTorch 2.2 in January 2024. 
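// (Editorial note, not part of this patch.) With the guard that follows, a schema
// registration written elsewhere in this diff as
//   m.def("jagged_2d_to_dense(Tensor values, Tensor offsets, SymInt max_sequence_length) -> Tensor",
//         {PT2_COMPLIANT_TAG});
// picks up at::Tag::pt2_compliant_tag when the build defines HAS_PT2_COMPLIANT_TAG
// (newer PyTorch), and expands to an empty tag list {} otherwise, so call sites can
// pass the tag unconditionally.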
+#ifdef HAS_PT2_COMPLIANT_TAG +#define PT2_COMPLIANT_TAG at::Tag::pt2_compliant_tag +#else +#define PT2_COMPLIANT_TAG +#endif diff --git a/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h b/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h index 162c4b9ecc..1552dd1be2 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h +++ b/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h @@ -149,7 +149,8 @@ block_bucketize_sparse_features_cuda( const int64_t my_size, const c10::optional& weights, const c10::optional& batch_size_per_feature, - const int64_t max_batch_size); + const int64_t max_batch_size, + const c10::optional>& block_bucketize_pos); std::tuple< at::Tensor, @@ -168,7 +169,8 @@ block_bucketize_sparse_features_cpu( const int64_t my_size, const c10::optional& weights, const c10::optional& batch_size_per_feature, - const int64_t max_batch_size); + const int64_t max_batch_size, + const c10::optional>& block_bucketize_pos); std::tuple< at::Tensor, diff --git a/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache_cuda.cuh b/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache_cuda.cuh index 800f238f5e..57cd695d8f 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache_cuda.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache_cuda.cuh @@ -174,7 +174,9 @@ at::Tensor lxu_cache_lookup_cuda( at::Tensor lxu_cache_state, int64_t invalid_index, bool gather_cache_stats, - c10::optional uvm_cache_stats); + c10::optional uvm_cache_stats, + c10::optional num_uniq_cache_indices, + c10::optional lxu_cache_locations_output); at::Tensor emulate_cache_miss( at::Tensor lxu_cache_locations, @@ -240,4 +242,5 @@ void lxu_cache_locking_counter_decrement_cuda( /// and lxu_cache_locations_new[i] >= 0 void lxu_cache_locations_update_cuda( at::Tensor lxu_cache_locations, - at::Tensor lxu_cache_locations_new); + at::Tensor lxu_cache_locations_new, + c10::optional num_uniq_cache_indices); diff --git a/fbgemm_gpu/setup.py b/fbgemm_gpu/setup.py index 0a16e6f625..cb7ab36475 100644 --- a/fbgemm_gpu/setup.py +++ b/fbgemm_gpu/setup.py @@ -26,6 +26,11 @@ def parse_args(argv: List[str]) -> argparse.Namespace: parser = argparse.ArgumentParser(description="fbgemm_gpu setup") + parser.add_argument( + "--verbose", + action="store_true", + help="Print verbose logs during the build.", + ) parser.add_argument( "--package_variant", type=str, @@ -133,8 +138,6 @@ def set_cuda_environment_variables() -> None: def cmake_environment_variables(args) -> None: def _get_cxx11_abi(): try: - import torch - value = int(torch._C._GLIBCXX_USE_CXX11_ABI) except ImportError: value = 0 @@ -143,11 +146,22 @@ def _get_cxx11_abi(): torch_root = os.path.dirname(torch.__file__) os.environ["CMAKE_BUILD_PARALLEL_LEVEL"] = str(os.cpu_count() // 2) - cmake_args = [f"-DCMAKE_PREFIX_PATH={torch_root}", _get_cxx11_abi()] + cmake_args = [ + f"-DCMAKE_PREFIX_PATH={torch_root}", + _get_cxx11_abi(), + ] + + if args.verbose: + print("[SETUP.PY] Building in VERBOSE mode ...") + cmake_args.append("-DCMAKE_VERBOSE_MAKEFILE=1") + if args.package_variant == "cpu": + print("[SETUP.PY] Building the CPU-ONLY variant of FBGEMM_GPU ...") cmake_args.append("-DFBGEMM_CPU_ONLY=ON") + if args.nvml_lib_path: cmake_args.append(f"-DNVML_LIB_PATH={args.nvml_lib_path}") + return cmake_args @@ -183,6 +197,7 @@ def extract_variant_version(cls, variant: str) -> str: if variant == "cpu": variant_version = "+cpu" + elif variant == "cuda": set_cuda_environment_variables() if torch.version.cuda is not None: @@ -192,6 +207,7 @@ def extract_variant_version(cls, variant: str) -> str: sys.exit( 
"[SETUP.PY] Installed PyTorch variant is not CUDA; cannot determine the CUDA version!" ) + elif variant == "rocm": if torch.version.hip is not None: rocm_version = torch.version.hip.split(".") diff --git a/fbgemm_gpu/src/embedding_inplace_update.cu b/fbgemm_gpu/src/embedding_inplace_ops/embedding_inplace_update.cu similarity index 100% rename from fbgemm_gpu/src/embedding_inplace_update.cu rename to fbgemm_gpu/src/embedding_inplace_ops/embedding_inplace_update.cu diff --git a/fbgemm_gpu/src/embedding_inplace_update_cpu.cpp b/fbgemm_gpu/src/embedding_inplace_ops/embedding_inplace_update_cpu.cpp similarity index 96% rename from fbgemm_gpu/src/embedding_inplace_update_cpu.cpp rename to fbgemm_gpu/src/embedding_inplace_ops/embedding_inplace_update_cpu.cpp index 92c7428d71..80ed4db608 100644 --- a/fbgemm_gpu/src/embedding_inplace_update_cpu.cpp +++ b/fbgemm_gpu/src/embedding_inplace_ops/embedding_inplace_update_cpu.cpp @@ -62,9 +62,7 @@ void embedding_inplace_update_cpu_kernel( const uint8_t* __restrict__ update_weight_row = &update_weights[update_weight_offset]; - for (const auto d : c10::irange(D_bytes)) { - weight_row[d] = update_weight_row[d]; - } + memcpy(weight_row, update_weight_row, D_bytes); } } @@ -167,19 +165,11 @@ Tensor pruned_array_lookup_from_row_idx_cpu( TORCH_LIBRARY_FRAGMENT(fbgemm, m) { m.def( "emb_inplace_update(Tensor(a!) dev_weights, Tensor(b!) uvm_weights, Tensor weights_placements, Tensor weights_offsets, Tensor weights_tys, Tensor D_offsets, Tensor update_weights, Tensor update_table_indices, Tensor update_row_indices, Tensor update_offsets, int row_alignment=1, Tensor(c!)? lxu_cache_weights=None, Tensor? lxu_cache_locations=None) -> ()"); -} - -TORCH_LIBRARY_IMPL(fbgemm, CPU, m) { - DISPATCH_TO_CPU( - "emb_inplace_update", fbgemm_gpu::embedding_inplace_update_cpu); -} - -TORCH_LIBRARY_FRAGMENT(fbgemm, m) { m.def( "pruned_array_lookup_from_row_idx(Tensor update_row_indices, Tensor update_table_indices, Tensor index_remappings, Tensor index_remappings_offsets) -> Tensor"); -} -TORCH_LIBRARY_IMPL(fbgemm, CPU, m) { + DISPATCH_TO_CPU( + "emb_inplace_update", fbgemm_gpu::embedding_inplace_update_cpu); DISPATCH_TO_CPU( "pruned_array_lookup_from_row_idx", fbgemm_gpu::pruned_array_lookup_from_row_idx_cpu); diff --git a/fbgemm_gpu/src/embedding_inplace_update_gpu.cpp b/fbgemm_gpu/src/embedding_inplace_ops/embedding_inplace_update_gpu.cpp similarity index 94% rename from fbgemm_gpu/src/embedding_inplace_update_gpu.cpp rename to fbgemm_gpu/src/embedding_inplace_ops/embedding_inplace_update_gpu.cpp index 14ec1229d2..2605827a6d 100644 --- a/fbgemm_gpu/src/embedding_inplace_update_gpu.cpp +++ b/fbgemm_gpu/src/embedding_inplace_ops/embedding_inplace_update_gpu.cpp @@ -15,9 +15,6 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { DISPATCH_TO_CUDA( "emb_inplace_update", fbgemm_gpu::embedding_inplace_update_cuda); -} - -TORCH_LIBRARY_FRAGMENT(fbgemm, m) { DISPATCH_TO_CUDA( "pruned_array_lookup_from_row_idx", fbgemm_gpu::pruned_array_lookup_from_row_idx_cuda); diff --git a/fbgemm_gpu/test/embedding_inplace_update_test.cpp b/fbgemm_gpu/src/embedding_inplace_ops/embedding_inplace_update_test.cpp similarity index 96% rename from fbgemm_gpu/test/embedding_inplace_update_test.cpp rename to fbgemm_gpu/src/embedding_inplace_ops/embedding_inplace_update_test.cpp index 853da8115b..712315f7f7 100644 --- a/fbgemm_gpu/test/embedding_inplace_update_test.cpp +++ b/fbgemm_gpu/src/embedding_inplace_ops/embedding_inplace_update_test.cpp @@ -177,7 +177,10 @@ void test_embedding_inplace_update() { } } 
-TEST(embedding_inplace_update_test, random_update) { - test_embedding_inplace_update(); +TEST(EmbeddingInplaceUpdateTest, random_update) { + // TODO: Skipping test_embedding_inplace_update because it is + // unreliable and crashes occasionally. This should be fixed and re-enabled. + // + // test_embedding_inplace_update(); test_embedding_inplace_update(); } diff --git a/fbgemm_gpu/src/input_combine.cu b/fbgemm_gpu/src/input_combine_ops/input_combine.cu similarity index 100% rename from fbgemm_gpu/src/input_combine.cu rename to fbgemm_gpu/src/input_combine_ops/input_combine.cu diff --git a/fbgemm_gpu/src/input_combine_cpu.cpp b/fbgemm_gpu/src/input_combine_ops/input_combine_cpu.cpp similarity index 98% rename from fbgemm_gpu/src/input_combine_cpu.cpp rename to fbgemm_gpu/src/input_combine_ops/input_combine_cpu.cpp index c836b91d98..ac711655b2 100644 --- a/fbgemm_gpu/src/input_combine_cpu.cpp +++ b/fbgemm_gpu/src/input_combine_ops/input_combine_cpu.cpp @@ -6,6 +6,7 @@ * LICENSE file in the root directory of this source tree. */ +#include "fbgemm_gpu/dispatch_macros.h" #include "fbgemm_gpu/input_combine.h" #include "fbgemm_gpu/sparse_ops_utils.h" @@ -383,8 +384,14 @@ padding_fused_tbe_input_combine_with_length_cpu( } // namespace fbgemm_gpu TORCH_LIBRARY_FRAGMENT(fbgemm, m) { +#ifdef HAS_IMPL_ABSTRACT_PYSTUB + m.impl_abstract_pystub( + "fbgemm_gpu.sparse_ops", + "//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_py"); +#endif m.def( - "tbe_input_combine(Tensor[] indices_list, Tensor[] offsets_list, Tensor[] per_sample_weights, Tensor include_last_offsets) -> (Tensor, Tensor, Tensor)"); + "tbe_input_combine(Tensor[] indices_list, Tensor[] offsets_list, Tensor[] per_sample_weights, Tensor include_last_offsets) -> (Tensor, Tensor, Tensor)", + {PT2_COMPLIANT_TAG}); m.def( "tbe_input_combine_with_length(Tensor[] indices_list, Tensor[] lengths_list, Tensor[] per_sample_weights) -> (Tensor, Tensor, Tensor)"); m.def( diff --git a/fbgemm_gpu/src/input_combine_gpu.cpp b/fbgemm_gpu/src/input_combine_ops/input_combine_gpu.cpp similarity index 100% rename from fbgemm_gpu/src/input_combine_gpu.cpp rename to fbgemm_gpu/src/input_combine_ops/input_combine_gpu.cpp diff --git a/fbgemm_gpu/src/jagged_tensor_ops/jagged_dense_elementwise_mul_backward.cu b/fbgemm_gpu/src/jagged_tensor_ops/jagged_dense_elementwise_mul_backward.cu index fc4a08b1f2..41fb56a899 100644 --- a/fbgemm_gpu/src/jagged_tensor_ops/jagged_dense_elementwise_mul_backward.cu +++ b/fbgemm_gpu/src/jagged_tensor_ops/jagged_dense_elementwise_mul_backward.cu @@ -104,7 +104,8 @@ void jagged_jagged_elementwise_dense_output_( x_offset_ptrs.vals[d] = \ x_offsets_contig[d].template data_ptr(); \ } \ - const auto func_name = "jagged_jagged_elementwise_dense_output_kernel_"; \ + [[maybe_unused]] const auto func_name = \ + "jagged_jagged_elementwise_dense_output_kernel_"; \ jagged_jagged_elementwise_dense_output_kernel_ \ <<>>( \ MAKE_PTA_WITH_NAME(func_name, x_values, scalar_t, 2, 32), \ diff --git a/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops.cu b/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops.cu index 186560caaf..dbccc6fdfd 100644 --- a/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops.cu +++ b/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops.cu @@ -37,8 +37,3 @@ FBGEMM_OP_DISPATCH(CUDA, "jagged_2d_to_dense", fbgemm_gpu::jagged_2d_to_dense); FBGEMM_OP_DISPATCH(CUDA, "jagged_softmax", fbgemm_gpu::jagged_softmax); FBGEMM_OP_DISPATCH(CUDA, "jagged_jagged_bmm", fbgemm_gpu::jagged_jagged_bmm); FBGEMM_OP_DISPATCH(CUDA, "jagged_dense_bmm", 
fbgemm_gpu::jagged_dense_bmm); - -FBGEMM_OP_DISPATCH( - CUDA, - "jagged_index_select", - fbgemm_gpu::jagged_index_select_2d); diff --git a/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_autograd.cpp b/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_autograd.cpp index 30060728d7..89c2c89116 100644 --- a/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_autograd.cpp +++ b/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_autograd.cpp @@ -524,30 +524,20 @@ class JaggedIndexSelect2dOp Tensor output_offsets = output_lengths.cumsum(0); Tensor input_offsets = lengths.cumsum(0); - int64_t num_dense_output_rows = - output_offsets[output_offsets.numel() - 1].item(); - ctx->save_for_backward({indices, output_offsets, input_offsets}); - ctx->saved_data["num_dense_grad_rows"] = num_dense_output_rows; - ctx->saved_data["num_input_rows"] = values.size(0); + ctx->saved_data["num_input_rows"] = values.sym_size(0); static auto op = c10::Dispatcher::singleton() - .findSchemaOrThrow("fbgemm::jagged_index_select_2d_forward", "") + .findSchemaOrThrow("fbgemm::jagged_index_select_2d_forward_v2", "") .typed(); + const Tensor& output_offsets)>(); return { - op.call( - values, - indices, - input_offsets, - output_offsets, - num_dense_output_rows), + op.call(values, indices, input_offsets, output_offsets), output_lengths}; } @@ -565,29 +555,20 @@ class JaggedIndexSelect2dOp TENSORS_ON_SAME_DEVICE(grad, indices); - int64_t num_dense_grad_rows = - ctx->saved_data["num_dense_grad_rows"].toInt(); - int64_t num_output_rows = ctx->saved_data["num_input_rows"].toInt(); + auto num_output_rows = ctx->saved_data["num_input_rows"].toSymInt(); static auto op = c10::Dispatcher::singleton() - .findSchemaOrThrow("fbgemm::jagged_index_add_2d_forward", "") + .findSchemaOrThrow("fbgemm::jagged_index_add_2d_forward_v2", "") .typed(); + c10::SymInt num_output_rows)>(); return { - op.call( - grad, - indices, - grad_offsets, - output_offsets, - num_dense_grad_rows, - num_output_rows), + op.call(grad, indices, grad_offsets, output_offsets, num_output_rows), torch::autograd::Variable(), // lengths torch::autograd::Variable() // indices }; @@ -883,6 +864,9 @@ TORCH_LIBRARY_IMPL(fbgemm, Autograd, m) { m.impl("jagged_softmax", TORCH_FN(fbgemm_gpu::jagged_softmax)); m.impl("jagged_jagged_bmm", TORCH_FN(fbgemm_gpu::jagged_jagged_bmm)); m.impl("jagged_dense_bmm", TORCH_FN(fbgemm_gpu::jagged_dense_bmm)); - m.impl("jagged_index_select", TORCH_FN(fbgemm_gpu::jagged_index_select_2d)); m.impl("jagged_slice", TORCH_FN(fbgemm_gpu::jagged_slice)); } + +TORCH_LIBRARY_IMPL(fbgemm, CompositeImplicitAutograd, m) { + m.impl("jagged_index_select", TORCH_FN(fbgemm_gpu::jagged_index_select_2d)); +} diff --git a/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp b/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp index 7ae207adb8..5a1753b239 100644 --- a/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp +++ b/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp @@ -1136,6 +1136,63 @@ Tensor jagged_index_select_2d_forward_cpu( return output; } +// v2 supports PT2 Dynamic Shapes. +// The problem with v1 is that it accepts a redundant num_dense_output_rows arg +// that we compute by peeking at output_offsets.data. +// PT2 has problems with data access, so we hide the data access inside +// the new operator. 
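// (Editorial sketch, not part of this patch.) From Python, the v2 operator below is
// expected to be called without the host-side scalar, e.g.
//   torch.ops.fbgemm.jagged_index_select_2d_forward_v2(
//       values, indices, input_offsets, output_offsets)
// whereas v1 required num_dense_output_rows = output_offsets[-1].item(), a host
// read of tensor data (and a sync point on GPU) that PT2 tracing handles poorly;
// v2 performs that read inside its CompositeExplicitAutograd implementation.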
+Tensor jagged_index_select_2d_forward_v2_impl( + const Tensor& values, + const Tensor& indices, + const Tensor& input_offsets, + const Tensor& output_offsets) { + int64_t num_dense_output_rows = + output_offsets[output_offsets.numel() - 1].item(); + static auto v1_op = + c10::Dispatcher::singleton() + .findSchemaOrThrow("fbgemm::jagged_index_select_2d_forward", "") + .typed(); + return v1_op.call( + values, indices, input_offsets, output_offsets, num_dense_output_rows); +} + +// v2 supports PT2 Dynamic Shapes. +// The problem with v1 is that it accepts a redundant num_dense_output_rows arg +// that we compute by peeking at input_offsets.data. +// PT2 has problems with data access, so we hide the data access inside +// the new operator. +Tensor jagged_index_add_2d_forward_v2_impl( + const Tensor& values, + const Tensor& indices, + const Tensor& input_offsets, + const Tensor& output_offsets, + const int64_t num_output_rows) { + int64_t num_dense_output_rows = + input_offsets[input_offsets.numel() - 1].item(); + static auto v1_op = + c10::Dispatcher::singleton() + .findSchemaOrThrow("fbgemm::jagged_index_add_2d_forward", "") + .typed(); + return v1_op.call( + values, + indices, + input_offsets, + output_offsets, + num_dense_output_rows, + num_output_rows); +} + template void jagged_index_add_2d_kernel( at::TensorAccessor output, @@ -1569,18 +1626,27 @@ Tensor jagged_slice_forward_cpu( } // namespace fbgemm_gpu TORCH_LIBRARY_FRAGMENT(fbgemm, m) { +#ifdef HAS_IMPL_ABSTRACT_PYSTUB + m.impl_abstract_pystub( + "fbgemm_gpu.sparse_ops", + "//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_py"); +#endif // (dense, offsets) -> jagged. Offsets output is same as input. // SymInt is a new PyTorch 2.0 feature to support dynamic shape. See more // details at https://pytorch.org/get-started/pytorch-2.0/#dynamic-shapes. If // you find it doesn't compile, please pull the new PyTorch 2.0 code m.def( - "dense_to_jagged(Tensor dense, Tensor[] x_offsets, SymInt? total_L=None) -> (Tensor, Tensor[])"); + "dense_to_jagged(Tensor dense, Tensor[] x_offsets, SymInt? total_L=None) -> (Tensor, Tensor[])", + {PT2_COMPLIANT_TAG}); m.def( - "dense_to_jagged_forward(Tensor dense, Tensor[] x_offsets, SymInt? total_L=None) -> Tensor"); + "dense_to_jagged_forward(Tensor dense, Tensor[] x_offsets, SymInt? 
total_L=None) -> Tensor", + {PT2_COMPLIANT_TAG}); m.def( - "jagged_2d_to_dense(Tensor values, Tensor offsets, SymInt max_sequence_length) -> Tensor"); + "jagged_2d_to_dense(Tensor values, Tensor offsets, SymInt max_sequence_length) -> Tensor", + {PT2_COMPLIANT_TAG}); m.def( - "jagged_1d_to_dense(Tensor values, Tensor offsets, SymInt max_sequence_length, int padding_value) -> Tensor"); + "jagged_1d_to_dense(Tensor values, Tensor offsets, SymInt max_sequence_length, int padding_value) -> Tensor", + {PT2_COMPLIANT_TAG}); m.def( "stacked_jagged_2d_to_dense_forward(Tensor values, Tensor lengths, int[] offset_per_key, int[] max_lengths_per_key, int padding_value = 0) -> (Tensor[], Tensor[])"); m.def( @@ -1590,57 +1656,81 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { m.def( "stacked_jagged_2d_to_dense(Tensor values, Tensor lengths, int[] offset_per_key, int[] max_lengths_per_key, int padding_value = 0) -> Tensor[]"); m.def( - "jagged_to_padded_dense(Tensor values, Tensor[] offsets, SymInt[] max_lengths, float padding_value = 0) -> Tensor"); + "jagged_to_padded_dense(Tensor values, Tensor[] offsets, SymInt[] max_lengths, float padding_value = 0) -> Tensor", + {PT2_COMPLIANT_TAG}); m.def( - "jagged_to_padded_dense_forward(Tensor values, Tensor[] offsets, SymInt[] max_lengths, float padding_value = 0) -> Tensor"); + "jagged_to_padded_dense_forward(Tensor values, Tensor[] offsets, SymInt[] max_lengths, float padding_value = 0) -> Tensor", + {PT2_COMPLIANT_TAG}); m.def( - "jagged_to_padded_dense_backward(Tensor grad_output, Tensor[] offsets, SymInt total_L) -> Tensor"); + "jagged_to_padded_dense_backward(Tensor grad_output, Tensor[] offsets, SymInt total_L) -> Tensor", + {PT2_COMPLIANT_TAG}); // jagged + dense -> dense m.def( - "jagged_dense_elementwise_add(Tensor x_values, Tensor[] x_offsets, Tensor y) -> Tensor"); + "jagged_dense_elementwise_add(Tensor x_values, Tensor[] x_offsets, Tensor y) -> Tensor", + {PT2_COMPLIANT_TAG}); // jagged + dense -> jagged (treat "zeros" in the jagged tensor as unknowns. 
// output offsets is same as x_offsets) m.def( - "jagged_dense_elementwise_add_jagged_output(Tensor x_values, Tensor[] x_offsets, Tensor y) -> (Tensor, Tensor[])"); + "jagged_dense_elementwise_add_jagged_output(Tensor x_values, Tensor[] x_offsets, Tensor y) -> (Tensor, Tensor[])", + {PT2_COMPLIANT_TAG}); m.def( - "jagged_dense_dense_elementwise_add_jagged_output_forward(Tensor x_values, Tensor[] x_offsets, Tensor y_0, Tensor y_1) -> Tensor"); + "jagged_dense_dense_elementwise_add_jagged_output_forward(Tensor x_values, Tensor[] x_offsets, Tensor y_0, Tensor y_1) -> Tensor", + {PT2_COMPLIANT_TAG}); m.def( - "jagged_dense_dense_elementwise_add_jagged_output(Tensor x_values, Tensor[] x_offsets, Tensor y_0, Tensor y_1) -> (Tensor, Tensor[])"); + "jagged_dense_dense_elementwise_add_jagged_output(Tensor x_values, Tensor[] x_offsets, Tensor y_0, Tensor y_1) -> (Tensor, Tensor[])", + {PT2_COMPLIANT_TAG}); // jagged * dense -> jagged (its offsets is same as x_offsets) m.def( - "jagged_dense_elementwise_mul(Tensor x_values, Tensor[] x_offsets, Tensor y) -> (Tensor, Tensor[])"); + "jagged_dense_elementwise_mul(Tensor x_values, Tensor[] x_offsets, Tensor y) -> (Tensor, Tensor[])", + {PT2_COMPLIANT_TAG}); m.def( - "jagged_dense_elementwise_mul_forward(Tensor x_values, Tensor[] x_offsets, Tensor y) -> Tensor"); + "jagged_dense_elementwise_mul_forward(Tensor x_values, Tensor[] x_offsets, Tensor y) -> Tensor", + {PT2_COMPLIANT_TAG}); m.def( - "jagged_dense_elementwise_mul_backward(Tensor grad_output, Tensor[] x_offsets, Tensor y, Tensor x_values) -> (Tensor, Tensor)"); + "jagged_dense_elementwise_mul_backward(Tensor grad_output, Tensor[] x_offsets, Tensor y, Tensor x_values) -> (Tensor, Tensor)", + {PT2_COMPLIANT_TAG}); m.def( - "batched_dense_vec_jagged_2d_mul(Tensor v, Tensor a_values, Tensor a_offsets) -> Tensor"); + "batched_dense_vec_jagged_2d_mul(Tensor v, Tensor a_values, Tensor a_offsets) -> Tensor", + {PT2_COMPLIANT_TAG}); m.def( - "batched_dense_vec_jagged_2d_mul_forward(Tensor v, Tensor a_values, Tensor a_offsets) -> Tensor"); + "batched_dense_vec_jagged_2d_mul_forward(Tensor v, Tensor a_values, Tensor a_offsets) -> Tensor", + {PT2_COMPLIANT_TAG}); m.def( - "batched_dense_vec_jagged_2d_mul_backward(Tensor grad_output, Tensor v, Tensor a_values, Tensor a_offsets) -> (Tensor, Tensor)"); + "batched_dense_vec_jagged_2d_mul_backward(Tensor grad_output, Tensor v, Tensor a_values, Tensor a_offsets) -> (Tensor, Tensor)", + {PT2_COMPLIANT_TAG}); m.def( - "jagged_index_select(Tensor values, Tensor lengths, Tensor indices) -> Tensor[]"); + "jagged_index_select(Tensor values, Tensor lengths, Tensor indices) -> Tensor[]", + {PT2_COMPLIANT_TAG}); m.def( "jagged_index_select_2d_forward(Tensor values, Tensor indices, Tensor input_offsets, Tensor output_offsets, int num_dense_output_rows) -> Tensor"); + m.def( + "jagged_index_select_2d_forward_v2(Tensor values, Tensor indices, Tensor input_offsets, Tensor output_offsets) -> Tensor", + {PT2_COMPLIANT_TAG}); m.def( "jagged_index_add_2d_forward(Tensor values, Tensor indices, Tensor input_offsets, Tensor output_offsets, int num_dense_input_rows, int num_output_rows) -> Tensor"); + m.def( + "jagged_index_add_2d_forward_v2(Tensor values, Tensor indices, Tensor input_offsets, Tensor output_offsets, SymInt num_output_rows) -> Tensor", + {PT2_COMPLIANT_TAG}); m.def( "jagged_1d_to_truncated_values(Tensor values, Tensor lengths, int max_truncated_length) -> Tensor"); m.def( - "masked_select_jagged_1d(Tensor values, Tensor lengths, Tensor mask) -> (Tensor, Tensor)"); + 
"masked_select_jagged_1d(Tensor values, Tensor lengths, Tensor mask) -> (Tensor, Tensor)", + {PT2_COMPLIANT_TAG}); m.def( - "jagged_softmax(Tensor values, Tensor x_offsets, int max_L) -> (Tensor, Tensor)"); + "jagged_softmax(Tensor values, Tensor x_offsets, int max_L) -> (Tensor, Tensor)", + {PT2_COMPLIANT_TAG}); m.def( "jagged_softmax_forward(Tensor values, Tensor x_offsets, int max_L) -> Tensor"); m.def( "jagged_softmax_backward(Tensor grad_output, Tensor output, Tensor x_offsets, int max_L) -> Tensor"); m.def( - "jagged_jagged_bmm(Tensor x_values, Tensor y_values, Tensor x_offsets, int max_L) -> Tensor"); + "jagged_jagged_bmm(Tensor x_values, Tensor y_values, Tensor x_offsets, int max_L) -> Tensor", + {PT2_COMPLIANT_TAG}); m.def( "jagged_jagged_bmm_forward(Tensor x_values, Tensor y_values, Tensor x_offsets, int max_L) -> Tensor"); m.def( - "jagged_dense_bmm(Tensor x_values, Tensor x_offsets, Tensor y, int max_L) -> (Tensor, Tensor)"); + "jagged_dense_bmm(Tensor x_values, Tensor x_offsets, Tensor y, int max_L) -> (Tensor, Tensor)", + {PT2_COMPLIANT_TAG}); m.def( "jagged_dense_bmm_forward(Tensor x_values, Tensor x_offsets, Tensor y, int max_L) -> Tensor"); // jagged -> jagged @@ -1702,7 +1792,6 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) { DISPATCH_TO_CPU( "jagged_index_select_2d_forward", fbgemm_gpu::jagged_index_select_2d_forward_cpu); - DISPATCH_TO_CPU("jagged_index_select", fbgemm_gpu::jagged_index_select_2d); DISPATCH_TO_CPU( "jagged_index_add_2d_forward", fbgemm_gpu::jagged_index_add_2d_forward_cpu); @@ -1723,3 +1812,12 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) { "jagged_dense_bmm_forward", fbgemm_gpu::jagged_dense_bmm_forward); DISPATCH_TO_CPU("jagged_slice_forward", fbgemm_gpu::jagged_slice_forward_cpu); } + +TORCH_LIBRARY_IMPL(fbgemm, CompositeExplicitAutograd, m) { + m.impl( + "jagged_index_select_2d_forward_v2", + fbgemm_gpu::jagged_index_select_2d_forward_v2_impl); + m.impl( + "jagged_index_add_2d_forward_v2", + fbgemm_gpu::jagged_index_add_2d_forward_v2_impl); +} diff --git a/fbgemm_gpu/src/layout_transform_ops.cu b/fbgemm_gpu/src/layout_transform_ops/layout_transform_ops.cu similarity index 99% rename from fbgemm_gpu/src/layout_transform_ops.cu rename to fbgemm_gpu/src/layout_transform_ops/layout_transform_ops.cu index 3b76e5b32c..40b796d23f 100644 --- a/fbgemm_gpu/src/layout_transform_ops.cu +++ b/fbgemm_gpu/src/layout_transform_ops/layout_transform_ops.cu @@ -13,19 +13,16 @@ #include "fbgemm_gpu/cub_namespace_postfix.cuh" // clang-format on -#include "fbgemm_gpu/layout_transform_ops.cuh" -#include "fbgemm_gpu/sparse_ops.h" -#include "fbgemm_gpu/sparse_ops_utils.h" - #include #include #include #include #include - #include - #include "ATen/Parallel.h" +#include "fbgemm_gpu/layout_transform_ops.cuh" +#include "fbgemm_gpu/sparse_ops.h" +#include "fbgemm_gpu/sparse_ops_utils.h" using Tensor = at::Tensor; diff --git a/fbgemm_gpu/src/layout_transform_ops_cpu.cpp b/fbgemm_gpu/src/layout_transform_ops/layout_transform_ops_cpu.cpp similarity index 100% rename from fbgemm_gpu/src/layout_transform_ops_cpu.cpp rename to fbgemm_gpu/src/layout_transform_ops/layout_transform_ops_cpu.cpp diff --git a/fbgemm_gpu/src/layout_transform_ops_gpu.cpp b/fbgemm_gpu/src/layout_transform_ops/layout_transform_ops_gpu.cpp similarity index 99% rename from fbgemm_gpu/src/layout_transform_ops_gpu.cpp rename to fbgemm_gpu/src/layout_transform_ops/layout_transform_ops_gpu.cpp index 986b479bd3..19e757875f 100644 --- a/fbgemm_gpu/src/layout_transform_ops_gpu.cpp +++ 
b/fbgemm_gpu/src/layout_transform_ops/layout_transform_ops_gpu.cpp @@ -6,12 +6,11 @@ * LICENSE file in the root directory of this source tree. */ -#include "fbgemm_gpu/sparse_ops.h" -#include "fbgemm_gpu/sparse_ops_utils.h" - #include #include #include +#include "fbgemm_gpu/sparse_ops.h" +#include "fbgemm_gpu/sparse_ops_utils.h" TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) { DISPATCH_TO_CUDA( diff --git a/fbgemm_gpu/src/merge_pooled_embeddings_cpu.cpp b/fbgemm_gpu/src/merge_pooled_embedding_ops/merge_pooled_embedding_ops_cpu.cpp similarity index 63% rename from fbgemm_gpu/src/merge_pooled_embeddings_cpu.cpp rename to fbgemm_gpu/src/merge_pooled_embedding_ops/merge_pooled_embedding_ops_cpu.cpp index f3db379f93..b7dc57ea63 100644 --- a/fbgemm_gpu/src/merge_pooled_embeddings_cpu.cpp +++ b/fbgemm_gpu/src/merge_pooled_embedding_ops/merge_pooled_embedding_ops_cpu.cpp @@ -10,7 +10,8 @@ #include #include #include -#include "fbgemm_gpu/sparse_ops_utils.h" +#include "fbgemm_gpu/dispatch_macros.h" +#include "fbgemm_gpu/ops_utils.h" using Tensor = at::Tensor; @@ -49,7 +50,22 @@ Tensor merge_pooled_embeddings_cpu( } // namespace fbgemm_gpu -TORCH_LIBRARY_IMPL(fbgemm, CPU, m) { - DISPATCH_TO_CPU( - "merge_pooled_embeddings", fbgemm_gpu::merge_pooled_embeddings_cpu); +TORCH_LIBRARY_FRAGMENT(fbgemm, m) { +#ifdef HAS_IMPL_ABSTRACT_PYSTUB + m.impl_abstract_pystub( + "fbgemm_gpu.sparse_ops", + "//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_py"); +#endif + m.def( + "merge_pooled_embeddings(Tensor[] pooled_embeddings, SymInt uncat_dim_size, Device target_device, SymInt cat_dim=1) -> Tensor", + {PT2_COMPLIANT_TAG}); + m.def( + "all_to_one_device(Tensor[] input_tensors, Device target_device) -> Tensor[]"); + m.def( + "sum_reduce_to_one(Tensor[] input_tensors, Device target_device) -> Tensor"); } + +FBGEMM_OP_DISPATCH( + CPU, + "merge_pooled_embeddings", + fbgemm_gpu::merge_pooled_embeddings_cpu); diff --git a/fbgemm_gpu/src/merge_pooled_embeddings_gpu.cpp b/fbgemm_gpu/src/merge_pooled_embedding_ops/merge_pooled_embedding_ops_gpu.cpp similarity index 96% rename from fbgemm_gpu/src/merge_pooled_embeddings_gpu.cpp rename to fbgemm_gpu/src/merge_pooled_embedding_ops/merge_pooled_embedding_ops_gpu.cpp index 0977a0d98c..37d6c5a440 100644 --- a/fbgemm_gpu/src/merge_pooled_embeddings_gpu.cpp +++ b/fbgemm_gpu/src/merge_pooled_embedding_ops/merge_pooled_embedding_ops_gpu.cpp @@ -11,7 +11,7 @@ #include #include #include -#include + #include #include #include @@ -19,8 +19,8 @@ #include #include #include - #include "fbgemm_gpu/merge_pooled_embeddings.h" + #include "fbgemm_gpu/sparse_ops_utils.h" #include "fbgemm_gpu/topology_utils.h" @@ -642,33 +642,11 @@ Tensor sum_reduce_to_one_device( return sum_reduce_to_one(input_tensors, target_device); } -Tensor merge_pooled_embeddings_meta( - std::vector pooled_embeddings, - int64_t uncat_dim_size, - at::Device /*target_device*/, - int64_t cat_dim) { - if (pooled_embeddings.size() == 0) { - return at::empty({0}, at::TensorOptions().device("meta")); - } - - auto [output_shape, cumulative_dims, total_cat_dim] = - cat_dim_2d_output_shape(pooled_embeddings, uncat_dim_size, cat_dim); - - return at::empty(output_shape, pooled_embeddings.front().options()); -} } // namespace fbgemm_gpu TORCH_LIBRARY_FRAGMENT(fbgemm, m) { - m.def( - "merge_pooled_embeddings(Tensor[] pooled_embeddings, int uncat_dim_size, Device target_device, int cat_dim=1) -> Tensor"); DISPATCH_TO_CUDA( "merge_pooled_embeddings", fbgemm_gpu::merge_pooled_embeddings); - m.def( - "all_to_one_device(Tensor[] input_tensors, Device 
target_device) -> Tensor[]"); DISPATCH_TO_CUDA("all_to_one_device", fbgemm_gpu::all_to_one_device); - m.def( - "sum_reduce_to_one(Tensor[] input_tensors, Device target_device) -> Tensor"); DISPATCH_TO_CUDA("sum_reduce_to_one", fbgemm_gpu::sum_reduce_to_one_device); - DISPATCH_TO_META( - "merge_pooled_embeddings", fbgemm_gpu::merge_pooled_embeddings_meta); } diff --git a/fbgemm_gpu/src/metric_ops.cu b/fbgemm_gpu/src/metric_ops/metric_ops.cu similarity index 100% rename from fbgemm_gpu/src/metric_ops.cu rename to fbgemm_gpu/src/metric_ops/metric_ops.cu diff --git a/fbgemm_gpu/src/metric_ops.h b/fbgemm_gpu/src/metric_ops/metric_ops.h similarity index 100% rename from fbgemm_gpu/src/metric_ops.h rename to fbgemm_gpu/src/metric_ops/metric_ops.h diff --git a/fbgemm_gpu/src/metric_ops_host.cpp b/fbgemm_gpu/src/metric_ops/metric_ops_host.cpp similarity index 100% rename from fbgemm_gpu/src/metric_ops_host.cpp rename to fbgemm_gpu/src/metric_ops/metric_ops_host.cpp diff --git a/fbgemm_gpu/src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_cpu.cpp b/fbgemm_gpu/src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_cpu.cpp index 2e884f3a3c..a6ff0d5dce 100644 --- a/fbgemm_gpu/src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_cpu.cpp +++ b/fbgemm_gpu/src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_cpu.cpp @@ -8,6 +8,7 @@ #include #include +#include "fbgemm_gpu/dispatch_macros.h" #include "fbgemm_gpu/permute_pooled_embedding_ops.h" using Tensor = at::Tensor; @@ -149,7 +150,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { m.def( "permute_pooled_embs(Tensor pooled_embs, Tensor offset_dim_list, Tensor permute_list, Tensor inv_offset_dim_list, Tensor inv_permute_list) -> Tensor"); m.def( - "permute_pooled_embs_auto_grad(Tensor pooled_embs, Tensor offset_dim_list, Tensor permute_list, Tensor inv_offset_dim_list, Tensor inv_permute_list) -> Tensor"); + "permute_pooled_embs_auto_grad(Tensor pooled_embs, Tensor offset_dim_list, Tensor permute_list, Tensor inv_offset_dim_list, Tensor inv_permute_list) -> Tensor", + {PT2_COMPLIANT_TAG}); m.def( "permute_duplicate_pooled_embs(Tensor pooled_embs, Tensor offset_dim_list, Tensor permute_list, Tensor inv_offset_dim_list, Tensor inv_permute_list) -> Tensor"); m.def( diff --git a/fbgemm_gpu/src/quantize_ops/quantize_fp8_rowwise.cu b/fbgemm_gpu/src/quantize_ops/quantize_fp8_rowwise.cu index cde36c3d55..169f8d2937 100644 --- a/fbgemm_gpu/src/quantize_ops/quantize_fp8_rowwise.cu +++ b/fbgemm_gpu/src/quantize_ops/quantize_fp8_rowwise.cu @@ -20,11 +20,14 @@ namespace { // FP32/FP16 -> FP8 rowwise kernel template __global__ inline void _float_to_FP8rowwise_cuda_kernel( - const input_t* __restrict__ input, + const at::PackedTensorAccessor64 input, const int64_t nrows, const int64_t ncols, - std::uint8_t* __restrict__ output, + at::PackedTensorAccessor64 output, const bool forward) { + // Assert if index is out of bound + CUDA_KERNEL_ASSERT(nrows * ncols >= 0); + constexpr float kEpsilon = 1e-20f; const int ebit = forward ? 4 : 5; const int bias = forward ? 
15 : 31; @@ -36,17 +39,20 @@ __global__ inline void _float_to_FP8rowwise_cuda_kernel( const int64_t row = blockIdx.x * blockDim.x + threadIdx.x; if (row < nrows) { - const input_t* input_row = input + row * ncols; - std::uint8_t* output_row = output + row * output_columns; + const input_t* input_row = &input[row * ncols]; + std::uint8_t* output_row = &output[row * output_columns]; float* output_row_scale_bias = reinterpret_cast(output_row + ncols_aligned); const float minimum_element = fbgemm_gpu::min(input_row, input_row + ncols); const float maximum_element = fbgemm_gpu::max(input_row, input_row + ncols); - const auto scale = max_pos / (kEpsilon + fmaxf(maximum_element, -minimum_element)); output_row_scale_bias[0] = scale; + // 8 bytes are allocated for scale but only 4 bytes are used + // value of the unassigned 4 bytes are hence indeterministic + // Initialize it to make the output deterministic for PT2 compliance + output_row_scale_bias[1] = 0.0; for (int64_t col = 0; col < ncols; ++col) { output_row[col] = float_to_hfp8(to_float(input_row[col]) * scale, ebit, bias, max_pos); @@ -56,16 +62,18 @@ __global__ inline void _float_to_FP8rowwise_cuda_kernel( template __global__ inline void _get_FP8_qparam_cuda_kernel( - const input_t* __restrict__ input, + const at::PackedTensorAccessor64 input, const int64_t nrows, const int64_t ncols, - uint8_t* __restrict__ output, - float* __restrict__ range_list, + at::PackedTensorAccessor64 output, const bool forward) { + // Assert if index is out of bound + CUDA_KERNEL_ASSERT(nrows * ncols >= 0); const int64_t row = blockIdx.x * blockDim.y + threadIdx.y; const int64_t ncols_aligned = (ncols + 4 - 1) / 4 * 4; const int64_t output_columns = ncols_aligned + 2 * sizeof(float); + float max_pos; if (forward) { max_pos = 0.9375; @@ -82,8 +90,7 @@ __global__ inline void _get_FP8_qparam_cuda_kernel( // March warp-wise through the row, doing thread local min and max reductions. // This loop will only execute once when ncol <= 32 if (row < nrows) { - const input_t* const input_row = input + row * ncols; - + const input_t* input_row = &input[row * ncols]; for (int64_t col = threadIdx.x; col < ncols; col += lane_width) { // Get thread-local minmax. These are the smallest min and max ever seen // by this thread. 
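For reference, the per-row scale these kernels compute is max_pos / (kEpsilon + max(row_max, -row_min)), with max_pos = 0.9375 in the forward direction as shown above. A minimal host-side sketch of that formula (the hfp8 encoding applied afterwards is not reproduced here):

#include <algorithm>
#include <cstdio>
#include <vector>

// Host-side sketch of the rowwise FP8 scale used by the kernels above:
// scale = max_pos / (kEpsilon + max(row_max, -row_min)).
float fp8_rowwise_scale(const std::vector<float>& row, float max_pos) {
  constexpr float kEpsilon = 1e-20f;
  float mn = row[0];
  float mx = row[0];
  for (const float v : row) {
    mn = std::min(mn, v);
    mx = std::max(mx, v);
  }
  return max_pos / (kEpsilon + std::max(mx, -mn));
}

int main() {
  const std::vector<float> row = {0.25f, -1.5f, 3.0f};
  // 0.9375 is the forward-direction max_pos from _get_FP8_qparam_cuda_kernel.
  std::printf("scale = %f\n", fp8_rowwise_scale(row, 0.9375f)); // ~0.3125
  return 0;
}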
@@ -106,19 +113,23 @@ __global__ inline void _get_FP8_qparam_cuda_kernel( return; } float* const output_row_qparams = - reinterpret_cast(output + row * output_columns + ncols_aligned); + reinterpret_cast(&output[row * output_columns + ncols_aligned]); output_row_qparams[0] = max_pos / (kEpsilon + maximum_element); + // Initialize it to make the output deterministic for PT2 compliance + output_row_qparams[1] = 0.0; } template __global__ inline void _compute_FP8_quantize_cuda_kernel( - const input_t* const __restrict__ input, - const float* const __restrict__ range_list, + const at::PackedTensorAccessor64 input, const int64_t nrows, const int64_t ncols, - std::uint8_t* const __restrict__ output, + at::PackedTensorAccessor64 output, const bool forward) { + // Assert if index is out of bound + CUDA_KERNEL_ASSERT(nrows * ncols >= 0); + int ebit; int bias; float max_pos; @@ -139,42 +150,41 @@ __global__ inline void _compute_FP8_quantize_cuda_kernel( const int64_t col = blockIdx.x * blockDim.x + threadIdx.x; const int64_t row_incre = blockDim.y * gridDim.y; for (/*row*/; row < nrows; row += row_incre) { + std::uint8_t* output_row = &output[row * output_columns]; if (col < ncols) { - float* row_qparams = reinterpret_cast( - output + row * output_columns + ncols_aligned); + float* row_qparams = reinterpret_cast(output_row + ncols_aligned); const float scale = row_qparams[0]; - const auto input_idx = row * ncols + col; - uint8_t* output_addr = output + row * output_columns + col; - // TODO: lift range_list into shared memory. However, when nrows is large, - // it might exceed the size of shared memory. - // output_addr[0] = lrintf((input[input_idx] - bias) * inverse_scale); - output_addr[0] = float_to_hfp8( - to_float(input[input_idx]) * scale, ebit, bias, max_pos); + output_row[col] = float_to_hfp8( + to_float(input[row * ncols + col]) * scale, ebit, bias, max_pos); } } } template __global__ inline void _FP8rowwise_to_float_cuda_kernel( - const std::uint8_t* const __restrict__ input, - const int nrows, - const int ncols, - output_t* const __restrict__ output, + at::PackedTensorAccessor64 input, + const int64_t nrows, + const int64_t ncols, + at::PackedTensorAccessor64 output, const bool forward) { - const int output_columns = ncols - 2 * sizeof(float); + // Assert if index is out of bound + CUDA_KERNEL_ASSERT(nrows * ncols >= 0); + + const int64_t output_columns = ncols - 2 * sizeof(float); const int ebit = forward ? 4 : 5; const int bias = forward ? 
15 : 31; - int row = (int)blockIdx.y * blockDim.y + threadIdx.y; - const int col = (int)blockIdx.x * blockDim.x + threadIdx.x; - const int row_incre = blockDim.y * gridDim.y; + int64_t row = static_cast(blockIdx.y) * blockDim.y + threadIdx.y; + const int64_t col = + static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + const int64_t row_incre = blockDim.y * gridDim.y; + for (/*row*/; row < nrows; row += row_incre) { if (col < output_columns) { - const std::uint8_t* input_row = input + row * ncols; + const std::uint8_t* input_row = &input[row * ncols]; + output_t* output_row = &output[row * output_columns]; const float* input_row_scale_bias = reinterpret_cast(input_row + output_columns); - output_t* output_row = output + row * output_columns; - const float output_ = hfp8_to_float(input_row[col], ebit, bias) / input_row_scale_bias[0]; quantize_float_store(&output_row[col], output_); @@ -189,16 +199,15 @@ template Tensor _float_to_FP8rowwise_gpu_t(const Tensor& input, const bool forward) { TENSOR_ON_CUDA_GPU(input); TORCH_CHECK(input.is_contiguous(), "input must be contiguous"); - at::cuda::OptionalCUDAGuard device_guard; device_guard.set_index(input.get_device()); const auto input_sizes = input.sizes(); const auto last_dim = input_sizes.size() - 1; - const int nrows = c10::size_to_dim_(last_dim, input_sizes); - const int ncols = input_sizes[last_dim]; - const int ncols_aligned = (ncols + 4 - 1) / 4 * 4; - const int output_columns = ncols_aligned + 2 * sizeof(float); + const int64_t nrows = c10::size_to_dim_(last_dim, input_sizes); + const int64_t ncols = input_sizes[last_dim]; + const int64_t ncols_aligned = (ncols + 4 - 1) / 4 * 4; + const int64_t output_columns = ncols_aligned + 2 * sizeof(float); // Global memory instructions support reading or writing words of size equal // to 1, 2, 4, 8, or 16 bytes. Any access (via a variable or a pointer) to @@ -208,18 +217,24 @@ Tensor _float_to_FP8rowwise_gpu_t(const Tensor& input, const bool forward) { // that size). 
auto output_dims = input_sizes.vec(); output_dims[last_dim] = output_columns; - auto output = at::empty( - output_dims, // 4 = sizeof(float) - input.options().dtype(at::kByte)); if (nrows == 0 || ncols == 0) { - return output; + return at::zeros( + output_dims, // 4 = sizeof(float) + input.options().dtype(at::kByte)); } + auto output = at::empty( + output_dims, // 4 = sizeof(float) + input.options().dtype(at::kByte)); + constexpr int threads_per_block = 256; const auto num_blocks = cuda_calc_xblock_count(nrows, threads_per_block); // think unsigned as we use 0, 255 + const auto input_1D = input.flatten(); + const auto output_1D = output.flatten(); + if (nrows <= 20) { FBGEMM_DISPATCH_FLOAT_HALF_AND_BFLOAT16( input.scalar_type(), "_float_to_FP8rowwise_cuda_kernel", [&] { @@ -228,10 +243,12 @@ Tensor _float_to_FP8rowwise_gpu_t(const Tensor& input, const bool forward) { threads_per_block, 0, at::cuda::getCurrentCUDAStream()>>>( - input.data_ptr(), + input_1D + .packed_accessor64(), nrows, ncols, - output.data_ptr(), + output_1D + .packed_accessor64(), forward); C10_CUDA_KERNEL_LAUNCH_CHECK(); }); @@ -268,18 +285,22 @@ Tensor _float_to_FP8rowwise_gpu_t(const Tensor& input, const bool forward) { dim3(blockDim_x, rows_per_block), 0, at::cuda::getCurrentCUDAStream()>>>( - input.data_ptr(), + input_1D.packed_accessor64< + scalar_t, + 1, + at::RestrictPtrTraits>(), nrows, ncols, - output.data_ptr(), - range_tensor.data_ptr(), + output_1D + .packed_accessor64(), forward); C10_CUDA_KERNEL_LAUNCH_CHECK(); }); } { - const int blockDim_x = std::min(ncols, threads_per_block); + const int blockDim_x = + std::min(ncols, static_cast(threads_per_block)); dim3 blockDim(blockDim_x, threads_per_block / blockDim_x); const auto gridDim_x = cuda_calc_xblock_count(ncols, blockDim.x); const auto gridDim_y = cuda_calc_block_count(nrows, blockDim.y); @@ -289,11 +310,14 @@ Tensor _float_to_FP8rowwise_gpu_t(const Tensor& input, const bool forward) { input.scalar_type(), "_compute_FP8_quantize_cuda_kernel", [&] { _compute_FP8_quantize_cuda_kernel <<>>( - input.data_ptr(), - range_tensor.data_ptr(), + input_1D.packed_accessor64< + scalar_t, + 1, + at::RestrictPtrTraits>(), nrows, ncols, - output.data_ptr(), + output_1D + .packed_accessor64(), forward); C10_CUDA_KERNEL_LAUNCH_CHECK(); }); @@ -328,10 +352,10 @@ Tensor _FP8rowwise_to_float_gpu_t( const auto input_sizes = input.sizes(); const auto last_dim = input_sizes.size() - 1; - const int nrows = c10::size_to_dim_(last_dim, input_sizes); - const int ncols = input_sizes[last_dim]; - const int ncols_aligned = (ncols + 4 - 1) / 4 * 4; - const int output_columns = ncols_aligned - 2 * sizeof(float); + const int64_t nrows = c10::size_to_dim_(last_dim, input_sizes); + const int64_t ncols = input_sizes[last_dim]; + const int64_t ncols_aligned = (ncols + 4 - 1) / 4 * 4; + const int64_t output_columns = ncols_aligned - 2 * sizeof(float); // Global memory instructions support reading or writing words of size equal // to 1, 2, 4, 8, or 16 bytes. 
Any access (via a variable or a pointer) to @@ -346,31 +370,38 @@ Tensor _FP8rowwise_to_float_gpu_t( output_sdtype == SparseType::FP32 || output_sdtype == SparseType::FP16 || output_sdtype == SparseType::BF16); + if (nrows == 0 || output_columns == 0) { + return at::zeros( + output_dims, // 4 = sizeof(float) + input.options().dtype(getScalarType(output_sdtype))); + } + Tensor output = at::empty( output_dims, // 4 = sizeof(float) input.options().dtype(getScalarType(output_sdtype))); - if (nrows == 0 || output_columns == 0) { - return output; - } - constexpr int threads_per_block = 256; - const int blockDim_x = std::min(threads_per_block, output_columns); + const int blockDim_x = + std::min(static_cast(threads_per_block), output_columns); const dim3 blockDim(blockDim_x, threads_per_block / blockDim_x); const auto gridDim_x = cuda_calc_xblock_count(output_columns, blockDim.x); const auto gridDim_y = cuda_calc_block_count(nrows, blockDim.y); const dim3 gridDim(gridDim_x, gridDim_y); + const auto input_1D = input.flatten(); + const auto output_1D = output.flatten(); + FBGEMM_DISPATCH_FLOAT_HALF_AND_BFLOAT16( output.scalar_type(), "FP8rowwise_to_float_cuda_kernel", [&] { _FP8rowwise_to_float_cuda_kernel <<>>( - input.data_ptr(), + input_1D.packed_accessor64(), nrows, ncols, - output.data_ptr(), + output_1D + .packed_accessor64(), forward); C10_CUDA_KERNEL_LAUNCH_CHECK(); }); diff --git a/fbgemm_gpu/src/quantize_ops/quantize_ops_cpu.cpp b/fbgemm_gpu/src/quantize_ops/quantize_ops_cpu.cpp index 9521f5a4c4..dc5ab72be3 100644 --- a/fbgemm_gpu/src/quantize_ops/quantize_ops_cpu.cpp +++ b/fbgemm_gpu/src/quantize_ops/quantize_ops_cpu.cpp @@ -408,7 +408,9 @@ at::Tensor _hfp8_to_float_cpu( TORCH_LIBRARY_FRAGMENT(fbgemm, m) { m.def("FloatToFused8BitRowwiseQuantized(Tensor t) -> Tensor"); - m.def("FloatToFP8RowwiseQuantized(Tensor t, bool forward) -> Tensor"); + m.def( + "FloatToFP8RowwiseQuantized(Tensor t, bool forward) -> Tensor", + {PT2_COMPLIANT_TAG}); m.def( "FloatToPaddedFP8RowwiseQuantized(Tensor t, bool forward, int row_dim) -> Tensor"); m.def( @@ -417,7 +419,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { m.def("FloatOrHalfToFused8BitRowwiseQuantized(Tensor t) -> Tensor"); m.def("Fused8BitRowwiseQuantizedToFloat(Tensor input) -> Tensor"); m.def( - "FP8RowwiseQuantizedToFloat(Tensor input, bool forward, int output_dtype=0) -> Tensor"); + "FP8RowwiseQuantizedToFloat(Tensor input, bool forward, int output_dtype=0) -> Tensor", + {PT2_COMPLIANT_TAG}); m.def("Fused8BitRowwiseQuantizedToHalf(Tensor input) -> Tensor"); m.def( "Fused8BitRowwiseQuantizedToFloatOrHalf(Tensor input, int output_dtype=0) -> Tensor"); diff --git a/fbgemm_gpu/src/quantize_ops/quantize_ops_meta.cpp b/fbgemm_gpu/src/quantize_ops/quantize_ops_meta.cpp index 471116942c..2cecb4ceac 100644 --- a/fbgemm_gpu/src/quantize_ops/quantize_ops_meta.cpp +++ b/fbgemm_gpu/src/quantize_ops/quantize_ops_meta.cpp @@ -49,10 +49,28 @@ Tensor FP8rowwise_to_float_meta( } } +Tensor FloatToFP8RowwiseQuantized_meta(const Tensor& input, bool forward) { + TORCH_CHECK(input.is_contiguous(), "input must be contiguous"); + + const at::SymIntArrayRef input_sizes = input.sym_sizes(); + + const auto last_dim = input_sizes.size() - 1; + const at::SymInt ncols = input_sizes[last_dim]; + const at::SymInt ncols_aligned = (ncols + 4 - 1) / 4 * 4; + const at::SymInt output_columns = ncols_aligned + 2 * sizeof(float); + + auto output_dims = input_sizes.vec(); + output_dims[last_dim] = output_columns; + return at::empty_symint(output_dims, input.options().dtype(at::kByte)); +} + 
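The quantized row layout that both the CUDA path and the meta function compute is: ncols one-byte values padded up to a multiple of 4, followed by 8 bytes holding the float scale and 4 zero-initialized padding bytes. A small sketch of that shape arithmetic, mirroring the expressions above:

#include <cstdint>
#include <cstdio>

// Shape arithmetic for FP8 rowwise quantization, mirroring
// _float_to_FP8rowwise_gpu_t / FloatToFP8RowwiseQuantized_meta above: each row
// stores ncols 1-byte values padded to a multiple of 4, followed by 8 bytes
// (float scale + 4 zero-initialized padding bytes).
int64_t quantized_row_bytes(int64_t ncols) {
  const int64_t ncols_aligned = (ncols + 4 - 1) / 4 * 4;
  return ncols_aligned + 2 * static_cast<int64_t>(sizeof(float));
}

// Inverse direction, mirroring _FP8rowwise_to_float_gpu_t: the dequantized row
// width is the stored (already aligned) width minus the 8 trailing bytes.
int64_t dequantized_row_cols(int64_t quantized_cols) {
  const int64_t ncols_aligned = (quantized_cols + 4 - 1) / 4 * 4;
  return ncols_aligned - 2 * static_cast<int64_t>(sizeof(float));
}

int main() {
  // Note the round trip pads the column count up to the 4-byte alignment:
  // 250 columns quantize into 260 bytes and dequantize back to 252 columns.
  std::printf("%lld\n", (long long)quantized_row_bytes(250));  // 260
  std::printf("%lld\n", (long long)dequantized_row_cols(260)); // 252
  return 0;
}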
} // namespace fbgemm_gpu TORCH_LIBRARY_IMPL(fbgemm, Meta, m) { m.impl( "FP8RowwiseQuantizedToFloat", TORCH_FN(fbgemm_gpu::FP8rowwise_to_float_meta)); + m.impl( + "FloatToFP8RowwiseQuantized", + TORCH_FN(fbgemm_gpu::FloatToFP8RowwiseQuantized_meta)); } diff --git a/fbgemm_gpu/src/sparse_ops/sparse_block_bucketize_features.cu b/fbgemm_gpu/src/sparse_ops/sparse_block_bucketize_features.cu index d2f281fd49..676ec9f4d8 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_block_bucketize_features.cu +++ b/fbgemm_gpu/src/sparse_ops/sparse_block_bucketize_features.cu @@ -48,13 +48,18 @@ __launch_bounds__(kMaxThreads) void _block_bucketize_sparse_features_cuda_kernel const offset_t* const __restrict__ offsets_data, const index_t* const __restrict__ indices_data, offset_t* const __restrict__ new_lengths_data, - offset_t* __restrict__ length_to_feature_idx) { + offset_t* __restrict__ length_to_feature_idx, + const offset_t* const __restrict__ block_bucketize_pos_concat, + const offset_t* const __restrict__ block_bucketize_pos_offsets, + offset_t* __restrict__ indices_to_lb) { using uindex_t = std::make_unsigned_t; CUDA_KERNEL_LOOP(b_t, lengths_size) { const auto t = length_to_feature_idx ? length_to_feature_idx[b_t] : b_t / B; index_t blk_size = block_sizes_data[t]; offset_t rowstart = (b_t == 0 ? 0 : offsets_data[b_t - 1]); offset_t rowend = offsets_data[b_t]; + const auto use_block_bucketize_pos = + (block_bucketize_pos_concat != nullptr); for (index_t i = rowstart; i < rowend; ++i) { // We have use cases using none-hashed raw indices that can be either // negative or larger than embedding table hash_size (blk_size * @@ -63,7 +68,27 @@ __launch_bounds__(kMaxThreads) void _block_bucketize_sparse_features_cuda_kernel // range of blk_size, we expect the later embedding module to take care // of hashing indices calculation. uindex_t idx = static_cast(indices_data[i]); - uindex_t p = idx < blk_size * my_size ? idx / blk_size : idx % my_size; + uindex_t p = 0; + if (!use_block_bucketize_pos) { + p = idx < blk_size * my_size ? idx / blk_size : idx % my_size; + } else { + index_t first = block_bucketize_pos_offsets[t]; + index_t last = block_bucketize_pos_offsets[t + 1]; + + while (first < last) { + index_t middle = first + ((last - first) / 2); + if (static_cast(block_bucketize_pos_concat[middle]) <= + idx) { + first = ++middle; + } else { + last = middle; + } + } + uindex_t lb = + static_cast(first - block_bucketize_pos_offsets[t] - 1); + indices_to_lb[i] = lb; + p = lb < my_size ? lb : idx % my_size; + } new_lengths_data[p * lengths_size + b_t]++; } } @@ -95,7 +120,10 @@ __launch_bounds__(kMaxThreads) void _block_bucketize_sparse_features_cuda_kernel scalar_t* __restrict__ new_weights_data, index_t* __restrict__ new_pos_data, index_t* const __restrict__ unbucketize_permute_data, - const offset_t* const __restrict__ length_to_feature_idx) { + const offset_t* const __restrict__ length_to_feature_idx, + const offset_t* const __restrict__ block_bucketize_pos_concat, + const offset_t* const __restrict__ block_bucketize_pos_offsets, + const offset_t* const __restrict__ indices_to_lb) { using uindex_t = std::make_unsigned_t; using uoffset_t = std::make_unsigned_t; CUDA_KERNEL_LOOP(b_t, lengths_size) { @@ -103,6 +131,8 @@ __launch_bounds__(kMaxThreads) void _block_bucketize_sparse_features_cuda_kernel index_t blk_size = block_sizes_data[t]; offset_t rowstart = (b_t == 0 ? 
0 : offsets_data[b_t - 1]); offset_t rowend = offsets_data[b_t]; + const auto use_block_bucketize_pos = + (block_bucketize_pos_concat != nullptr); for (index_t i = rowstart; i < rowend; ++i) { // We have use cases using none-hashed raw indices that can be either // negative or larger than embedding table hash_size (blk_size * @@ -111,9 +141,18 @@ __launch_bounds__(kMaxThreads) void _block_bucketize_sparse_features_cuda_kernel // range of blk_size, we expect the later embedding module to take care // of hashing indices calculation. uindex_t idx = static_cast(indices_data[i]); - uindex_t p = idx < blk_size * my_size ? idx / blk_size : idx % my_size; - uindex_t new_idx = - idx < blk_size * my_size ? idx % blk_size : idx / my_size; + uindex_t p = 0; + uindex_t new_idx = 0; + if (!use_block_bucketize_pos) { + p = idx < blk_size * my_size ? idx / blk_size : idx % my_size; + new_idx = idx < blk_size * my_size ? idx % blk_size : idx / my_size; + } else { + uindex_t lb = indices_to_lb[i]; + p = lb < my_size ? lb : idx % my_size; + new_idx = lb < my_size ? idx - + block_bucketize_pos_concat[lb + block_bucketize_pos_offsets[t]] + : idx / my_size; + } uoffset_t pos = new_offsets_data[p * lengths_size + b_t]; new_indices_data[pos] = new_idx; new_offsets_data[p * lengths_size + b_t]++; @@ -147,7 +186,8 @@ block_bucketize_sparse_features_cuda( const int64_t my_size, const c10::optional& weights, const c10::optional& batch_size_per_feature, - const int64_t max_B) { + const int64_t max_B, + const c10::optional>& block_bucketize_pos) { TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(lengths, indices); at::cuda::OptionalCUDAGuard device_guard; @@ -181,6 +221,7 @@ block_bucketize_sparse_features_cuda( } auto length_to_feature_idx = at::empty({lengths_size}, lengths_contig.options()); + auto indices_to_lb = at::empty_like(indices); if (batch_size_per_feature.has_value()) { constexpr auto threads_per_block = 256; const auto num_blocks = @@ -204,6 +245,27 @@ block_bucketize_sparse_features_cuda( }); } + at::Tensor block_bucketize_pos_concat = + at::empty({1}, lengths_contig.options()); + at::Tensor block_bucketize_pos_offsets = + at::empty({1}, lengths_contig.options()); + + if (block_bucketize_pos.has_value()) { + block_bucketize_pos_concat = at::cat(block_bucketize_pos.value(), 0); + std::vector sizes_; + sizes_.reserve(block_bucketize_pos.value().size() + 1); + for (auto const& t : block_bucketize_pos.value()) { + sizes_.push_back(t.numel()); + } + sizes_.push_back(0); + at::Tensor sizes_vec = + at::tensor(sizes_, at::TensorOptions().dtype(lengths_contig.dtype())); + block_bucketize_pos_offsets = asynchronous_exclusive_cumsum_cpu( + sizes_vec); // expect sizes_vec to be a small tensor, using cpu instead + // of gpu for cumsum + block_bucketize_pos_offsets = block_bucketize_pos_offsets.to( + block_bucketize_pos_concat.device(), true); + } constexpr auto threads_per_block = 256; const auto num_blocks = cuda_calc_xblock_count(lengths_size, threads_per_block); @@ -230,6 +292,15 @@ block_bucketize_sparse_features_cuda( new_lengths.data_ptr(), batch_size_per_feature.has_value() ? length_to_feature_idx.data_ptr() + : static_cast(nullptr), + block_bucketize_pos.has_value() + ? block_bucketize_pos_concat.data_ptr() + : static_cast(nullptr), + block_bucketize_pos.has_value() + ? block_bucketize_pos_offsets.data_ptr() + : static_cast(nullptr), + block_bucketize_pos.has_value() + ? 
indices_to_lb.data_ptr() : static_cast(nullptr)); C10_CUDA_KERNEL_LAUNCH_CHECK(); }); @@ -283,6 +354,17 @@ block_bucketize_sparse_features_cuda( unbucketize_permute.data_ptr(), batch_size_per_feature.has_value() ? length_to_feature_idx.data_ptr() + : static_cast(nullptr), + block_bucketize_pos.has_value() + ? block_bucketize_pos_concat + .data_ptr() + : static_cast(nullptr), + block_bucketize_pos.has_value() + ? block_bucketize_pos_offsets + .data_ptr() + : static_cast(nullptr), + block_bucketize_pos.has_value() + ? indices_to_lb.data_ptr() : static_cast(nullptr)); C10_CUDA_KERNEL_LAUNCH_CHECK(); }); @@ -330,6 +412,17 @@ block_bucketize_sparse_features_cuda( unbucketize_permute.data_ptr(), batch_size_per_feature.has_value() ? length_to_feature_idx.data_ptr() + : static_cast(nullptr), + block_bucketize_pos.has_value() + ? block_bucketize_pos_concat + .data_ptr() + : static_cast(nullptr), + block_bucketize_pos.has_value() + ? block_bucketize_pos_offsets + .data_ptr() + : static_cast(nullptr), + block_bucketize_pos.has_value() + ? indices_to_lb.data_ptr() : static_cast(nullptr)); C10_CUDA_KERNEL_LAUNCH_CHECK(); }); @@ -371,6 +464,15 @@ block_bucketize_sparse_features_cuda( unbucketize_permute.data_ptr(), batch_size_per_feature.has_value() ? length_to_feature_idx.data_ptr() + : static_cast(nullptr), + block_bucketize_pos.has_value() + ? block_bucketize_pos_concat.data_ptr() + : static_cast(nullptr), + block_bucketize_pos.has_value() + ? block_bucketize_pos_offsets.data_ptr() + : static_cast(nullptr), + block_bucketize_pos.has_value() + ? indices_to_lb.data_ptr() : static_cast(nullptr)); C10_CUDA_KERNEL_LAUNCH_CHECK(); }); @@ -410,6 +512,15 @@ block_bucketize_sparse_features_cuda( unbucketize_permute.data_ptr(), batch_size_per_feature.has_value() ? length_to_feature_idx.data_ptr() + : static_cast(nullptr), + block_bucketize_pos.has_value() + ? block_bucketize_pos_concat.data_ptr() + : static_cast(nullptr), + block_bucketize_pos.has_value() + ? block_bucketize_pos_offsets.data_ptr() + : static_cast(nullptr), + block_bucketize_pos.has_value() + ? indices_to_lb.data_ptr() : static_cast(nullptr)); C10_CUDA_KERNEL_LAUNCH_CHECK(); }); @@ -459,6 +570,17 @@ block_bucketize_sparse_features_cuda( nullptr, batch_size_per_feature.has_value() ? length_to_feature_idx.data_ptr() + : static_cast(nullptr), + block_bucketize_pos.has_value() + ? block_bucketize_pos_concat + .data_ptr() + : static_cast(nullptr), + block_bucketize_pos.has_value() + ? block_bucketize_pos_offsets + .data_ptr() + : static_cast(nullptr), + block_bucketize_pos.has_value() + ? indices_to_lb.data_ptr() : static_cast(nullptr)); C10_CUDA_KERNEL_LAUNCH_CHECK(); }); @@ -506,6 +628,17 @@ block_bucketize_sparse_features_cuda( nullptr, batch_size_per_feature.has_value() ? length_to_feature_idx.data_ptr() + : static_cast(nullptr), + block_bucketize_pos.has_value() + ? block_bucketize_pos_concat + .data_ptr() + : static_cast(nullptr), + block_bucketize_pos.has_value() + ? block_bucketize_pos_offsets + .data_ptr() + : static_cast(nullptr), + block_bucketize_pos.has_value() + ? indices_to_lb.data_ptr() : static_cast(nullptr)); C10_CUDA_KERNEL_LAUNCH_CHECK(); }); @@ -547,6 +680,15 @@ block_bucketize_sparse_features_cuda( nullptr, batch_size_per_feature.has_value() ? length_to_feature_idx.data_ptr() + : static_cast(nullptr), + block_bucketize_pos.has_value() + ? block_bucketize_pos_concat.data_ptr() + : static_cast(nullptr), + block_bucketize_pos.has_value() + ? 
block_bucketize_pos_offsets.data_ptr() + : static_cast(nullptr), + block_bucketize_pos.has_value() + ? indices_to_lb.data_ptr() : static_cast(nullptr)); C10_CUDA_KERNEL_LAUNCH_CHECK(); }); @@ -586,6 +728,15 @@ block_bucketize_sparse_features_cuda( nullptr, batch_size_per_feature.has_value() ? length_to_feature_idx.data_ptr() + : static_cast(nullptr), + block_bucketize_pos.has_value() + ? block_bucketize_pos_concat.data_ptr() + : static_cast(nullptr), + block_bucketize_pos.has_value() + ? block_bucketize_pos_offsets.data_ptr() + : static_cast(nullptr), + block_bucketize_pos.has_value() + ? indices_to_lb.data_ptr() : static_cast(nullptr)); C10_CUDA_KERNEL_LAUNCH_CHECK(); }); diff --git a/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp b/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp index 7d2a98d5cd..ae17a393c5 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp +++ b/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp @@ -18,6 +18,7 @@ #include #include #include "c10/util/MaybeOwned.h" +#include "fbgemm_gpu/dispatch_macros.h" #include "fbgemm_gpu/sparse_ops.h" #include "fbgemm_gpu/sparse_ops_utils.h" @@ -127,7 +128,7 @@ Tensor pack_segments_autograd( Tensor native_empty_like(const Tensor& self) { return at::native::empty_like( self, - optTypeMetaToScalarType(self.options().dtype_opt()), + c10::optTypeMetaToScalarType(self.options().dtype_opt()), self.options().layout_opt(), self.options().device_opt(), self.options().pinned_memory_opt(), @@ -283,7 +284,8 @@ void _block_bucketize_sparse_features_cpu( c10::optional new_weights, c10::optional new_pos, const c10::optional& unbucketize_permute, - const c10::optional& batch_size_per_feature) { + const c10::optional& batch_size_per_feature, + const c10::optional>& block_bucketize_pos) { // allocate tensors and buffers const auto lengths_size = lengths.numel(); const auto new_lengths_size = lengths_size * my_size; @@ -304,9 +306,11 @@ void _block_bucketize_sparse_features_cpu( const index_t* const block_sizes_data = block_sizes.data_ptr(); offset_t* batch_sizes_data = nullptr; const auto variable_batch_size = batch_size_per_feature.has_value(); - + const auto variable_bucket_sizes = block_bucketize_pos.has_value() && + block_bucketize_pos.value().size() != 0; using uindex_t = std::make_unsigned_t; using uoffset_t = std::make_unsigned_t; + std::vector lower_bounds(indices.numel(), 0); if constexpr (sequence) { unbucketize_permute_data = unbucketize_permute.value().data_ptr(); @@ -330,6 +334,12 @@ void _block_bucketize_sparse_features_cpu( for (const auto t : c10::irange(T)) { const auto blk_size = block_sizes_data[t]; const auto cur_batch_size = variable_batch_size ? batch_sizes_data[t] : B; + const index_t* bucketize_offset = nullptr; + int64_t bucket_size = 0; + if (variable_bucket_sizes) { + bucketize_offset = block_bucketize_pos.value()[t].data_ptr(); + bucket_size = block_bucketize_pos.value()[t].numel(); + } for (const auto b : c10::irange(cur_batch_size)) { const auto b_t = (variable_batch_size ? cur_offset : t * B) + b; const offset_t rowstart = offsets_data[b_t]; @@ -342,10 +352,21 @@ void _block_bucketize_sparse_features_cpu( // range of blk_size, we expect the later embedding module to take care // of hashing indices calculation. uindex_t idx = static_cast(indices_data[i]); - uindex_t p = idx < static_cast(blk_size * my_size) - ? 
idx / blk_size - : idx % my_size; - new_lengths_data[p * lengths_size + b_t]++; + if (variable_bucket_sizes) { + int64_t lb = std::upper_bound( + bucketize_offset, + bucketize_offset + static_cast(bucket_size), + indices_data[i]) - + bucketize_offset - 1; + lower_bounds[i] = lb; + uindex_t p = lb < my_size ? lb : idx % my_size; + new_lengths_data[p * lengths_size + b_t]++; + } else { + uindex_t p = idx < static_cast(blk_size * my_size) + ? idx / blk_size + : idx % my_size; + new_lengths_data[p * lengths_size + b_t]++; + } } } cur_offset += cur_batch_size; @@ -358,6 +379,10 @@ void _block_bucketize_sparse_features_cpu( for (const auto t : c10::irange(T)) { const auto blk_size = block_sizes_data[t]; const auto cur_batch_size = variable_batch_size ? batch_sizes_data[t] : B; + const index_t* bucketize_offset = nullptr; + if (variable_bucket_sizes) { + bucketize_offset = block_bucketize_pos.value()[t].data_ptr(); + } for (const auto b : c10::irange(cur_batch_size)) { const auto b_t = (variable_batch_size ? cur_offset : t * B) + b; const offset_t rowstart = offsets_data[b_t]; @@ -370,12 +395,19 @@ void _block_bucketize_sparse_features_cpu( // range of blk_size, we expect the later embedding module to take care // of hashing indices calculation. const uindex_t idx = static_cast(indices_data[i]); - const uindex_t p = idx < static_cast(blk_size * my_size) - ? idx / blk_size - : idx % my_size; - const uindex_t new_idx = idx < static_cast(blk_size * my_size) - ? idx % blk_size - : idx / my_size; + uindex_t p, new_idx; + if (variable_bucket_sizes) { + int64_t lb = lower_bounds[i]; + p = lb < my_size ? lb : idx % my_size; + new_idx = lb < my_size ? idx - bucketize_offset[lb] : idx / my_size; + + } else { + p = idx < static_cast(blk_size * my_size) ? idx / blk_size + : idx % my_size; + new_idx = idx < static_cast(blk_size * my_size) + ? 
idx % blk_size + : idx / my_size; + } const uoffset_t pos = new_offsets_data[p * lengths_size + b_t]; new_indices_data[pos] = new_idx; if (sequence) { @@ -910,8 +942,8 @@ block_bucketize_sparse_features_cpu( const int64_t my_size, const c10::optional& weights, const c10::optional& batch_size_per_feature, - const int64_t /* max_batch_size */ // Only used in GPU variant -) { + const int64_t /* max_batch_size */, // Only used in GPU variant + const c10::optional>& block_bucketize_pos) { const auto lengths_size = lengths.numel(); const auto new_lengths_size = lengths_size * my_size; auto new_lengths = at::zeros({new_lengths_size}, lengths.options()); @@ -958,7 +990,8 @@ block_bucketize_sparse_features_cpu( new_weights, new_pos, unbucketize_permute, - batch_size_per_feature); + batch_size_per_feature, + block_bucketize_pos); }); }); }); @@ -993,7 +1026,8 @@ block_bucketize_sparse_features_cpu( new_weights, new_pos, unbucketize_permute, - batch_size_per_feature); + batch_size_per_feature, + block_bucketize_pos); }); }); }); @@ -1026,7 +1060,8 @@ block_bucketize_sparse_features_cpu( new_weights, new_pos, unbucketize_permute, - batch_size_per_feature); + batch_size_per_feature, + block_bucketize_pos); }); }); } else { @@ -1054,7 +1089,8 @@ block_bucketize_sparse_features_cpu( new_weights, new_pos, unbucketize_permute, - batch_size_per_feature); + batch_size_per_feature, + block_bucketize_pos); }); }); } @@ -2686,22 +2722,36 @@ Tensor bottom_k_per_row( } // namespace fbgemm_gpu TORCH_LIBRARY_FRAGMENT(fbgemm, m) { +#ifdef HAS_IMPL_ABSTRACT_PYSTUB + m.impl_abstract_pystub( + "fbgemm_gpu.sparse_ops", + "//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_py"); +#endif m.def( "permute_sparse_data(Tensor permute, Tensor lengths, Tensor values, Tensor? weights=None, SymInt? permuted_lengths_sum=None) -> (Tensor, Tensor, Tensor?)"); m.def( - "permute_2D_sparse_data(Tensor permute, Tensor lengths, Tensor values, Tensor? weights=None, SymInt? permuted_lengths_sum=None) -> (Tensor, Tensor, Tensor?)"); + "permute_2D_sparse_data(Tensor permute, Tensor lengths, Tensor values, Tensor? weights=None, SymInt? permuted_lengths_sum=None) -> (Tensor, Tensor, Tensor?)", + {PT2_COMPLIANT_TAG}); m.def( - "permute_1D_sparse_data(Tensor permute, Tensor lengths, Tensor values, Tensor? weights=None, SymInt? permuted_lengths_sum=None) -> (Tensor, Tensor, Tensor?)"); + "permute_1D_sparse_data(Tensor permute, Tensor lengths, Tensor values, Tensor? weights=None, SymInt? permuted_lengths_sum=None) -> (Tensor, Tensor, Tensor?)", + {PT2_COMPLIANT_TAG}); m.def("invert_permute(Tensor permute) -> Tensor"); m.def( - "expand_into_jagged_permute(Tensor permute, Tensor input_offset, Tensor output_offset, SymInt output_size) -> Tensor"); + "expand_into_jagged_permute(Tensor permute, Tensor input_offset, Tensor output_offset, SymInt output_size) -> Tensor", + {PT2_COMPLIANT_TAG}); m.def( - "block_bucketize_sparse_features(Tensor lengths, Tensor indices, bool bucketize_pos, bool sequence, Tensor block_sizes, SymInt my_size, Tensor? weights=None, Tensor? batch_size_per_feature=None, SymInt max_B= -1) -> (Tensor, Tensor, Tensor?, Tensor?, Tensor?)"); + "block_bucketize_sparse_features(Tensor lengths, Tensor indices, bool bucketize_pos, bool sequence, Tensor block_sizes, SymInt my_size, Tensor? weights=None, Tensor? batch_size_per_feature=None, SymInt max_B= -1, Tensor[]? 
block_bucketize_pos=None) -> (Tensor, Tensor, Tensor?, Tensor?, Tensor?)"); m.def( "bucketize_sparse_features(Tensor lengths, Tensor indices, bool bucketize_pos, SymInt my_size, Tensor? weights=None) -> (Tensor, Tensor, Tensor?, Tensor?)"); - m.def("asynchronous_exclusive_cumsum(Tensor t_in) -> Tensor"); - m.def("asynchronous_inclusive_cumsum(Tensor t_in) -> Tensor"); - m.def("asynchronous_complete_cumsum(Tensor t_in) -> Tensor"); + m.def( + "asynchronous_exclusive_cumsum(Tensor t_in) -> Tensor", + {PT2_COMPLIANT_TAG}); + m.def( + "asynchronous_inclusive_cumsum(Tensor t_in) -> Tensor", + {PT2_COMPLIANT_TAG}); + m.def( + "asynchronous_complete_cumsum(Tensor t_in) -> Tensor", + {PT2_COMPLIANT_TAG}); m.def( "reorder_batched_ad_lengths(Tensor cat_ad_lengths, Tensor batch_offsets, SymInt num_ads_in_batch, bool broadcast_lengths=False) -> Tensor"); m.def( @@ -2710,7 +2760,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { "cat_reorder_batched_ad_indices(Tensor cat_ad_offsets, Tensor[] cat_ad_indices, Tensor reordered_cat_ad_offsets, Tensor batch_offsets, SymInt num_ads_in_batch, bool broadcast_indices, SymInt total_num_indices, bool pinned_memory=False) -> Tensor"); m.def("offsets_range(Tensor offsets, SymInt range_size) -> Tensor"); m.def( - "batched_unary_embeddings(Tensor weight, Tensor table_offsets, Tensor offsets, Tensor indices) -> Tensor"); + "batched_unary_embeddings(Tensor weight, Tensor table_offsets, Tensor offsets, Tensor indices) -> Tensor", + {PT2_COMPLIANT_TAG}); m.def( "histogram_binning_calibration(Tensor logit, Tensor bin_num_examples, Tensor bin_num_positives, float positive_weight, float lower_bound, float upper_bound, SymInt bin_ctr_in_use_after, float bin_ctr_weight_value) -> (Tensor, Tensor)"); m.def( @@ -2725,7 +2776,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { m.def( "lengths_range_out(Tensor output, Tensor t_in, SymInt[]? shape=None) -> Tensor"); m.def( - "permute_sparse_features(Tensor permute, Tensor lengths, Tensor indices, Tensor? weights=None) -> (Tensor, Tensor, Tensor?)"); + "permute_sparse_features(Tensor permute, Tensor lengths, Tensor indices, Tensor? weights=None) -> (Tensor, Tensor, Tensor?)", + {PT2_COMPLIANT_TAG}); m.def("Bfloat16QuantizedToFloat(Tensor input) -> Tensor"); m.def("FloatToBfloat16Quantized(Tensor input) -> Tensor"); m.def( @@ -2733,7 +2785,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { m.def( "permute_sequence_embeddings(Tensor permute, Tensor lengths, Tensor embeddings) -> (Tensor, Tensor)"); m.def( - "pack_segments(Tensor t_in, Tensor lengths, SymInt max_length) -> Tensor"); + "pack_segments(Tensor t_in, Tensor lengths, SymInt max_length) -> Tensor", + {PT2_COMPLIANT_TAG}); m.def( "pack_segments_backward(Tensor data, Tensor lengths, SymInt total_length, SymInt max_length) -> Tensor"); // A specialization of at::index_select for selecting dim 0 @@ -2756,7 +2809,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { m.def( "index_select_dim0(Tensor input, Tensor indices, SymInt? consecutive_range_start=0, SymInt? consecutive_range_length=0, bool? skip_indices_sorting_fwd=None) -> Tensor"); m.def( - "group_index_select_dim0(Tensor[] input_group, Tensor[] indices_group) -> Tensor[]"); + "group_index_select_dim0(Tensor[] input_group, Tensor[] indices_group) -> Tensor[]", + {PT2_COMPLIANT_TAG}); // This is an one-off op to be used in split_embedding_utils.py for zipf // generation w/o replacement along dim=-1. If requires_unique=True, find // smallest unique k. 
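A minimal sketch of the per-index bucket assignment used when block_bucketize_pos is supplied, mirroring the CPU std::upper_bound path and the GPU binary search above; the per-feature boundary list is assumed to be sorted with my_size + 1 entries starting at or below the smallest index:

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

// Sketch of variable-bucket-size assignment: lb is the upper_bound position
// minus one into the feature's sorted boundary list; indices that do not land
// in a real bucket (lb >= my_size) fall back to the uniform idx % my_size rule,
// matching the kernels above. Assumes bounds.front() <= idx.
void bucketize_with_pos(
    int64_t idx,
    const std::vector<int64_t>& bounds, // my_size + 1 sorted boundaries
    int64_t my_size,
    int64_t* p,
    int64_t* new_idx) {
  const int64_t lb =
      std::upper_bound(bounds.begin(), bounds.end(), idx) - bounds.begin() - 1;
  *p = lb < my_size ? lb : idx % my_size;
  *new_idx = lb < my_size ? idx - bounds[lb] : idx / my_size;
}

int main() {
  // Hypothetical boundaries for one feature with my_size = 2:
  // bucket 0 covers [0, 10), bucket 1 covers [10, 25).
  const std::vector<int64_t> bounds = {0, 10, 25};
  int64_t p = 0, new_idx = 0;
  bucketize_with_pos(13, bounds, /*my_size=*/2, &p, &new_idx);
  std::printf("bucket=%lld local=%lld\n", (long long)p, (long long)new_idx); // bucket=1 local=3
  return 0;
}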
If the number of unique elements is less than k, diff --git a/fbgemm_gpu/src/sparse_ops/sparse_permute_1d.cu b/fbgemm_gpu/src/sparse_ops/sparse_permute_1d.cu index 2c03165268..145d279c0d 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_permute_1d.cu +++ b/fbgemm_gpu/src/sparse_ops/sparse_permute_1d.cu @@ -85,7 +85,11 @@ permute_1D_sparse_data_cuda( if (permuted_lengths_size == 0 || lengths_size == 0) { // Permutation will not be performed. Return the input tensors - return {lengths.view({-1}), indices, weights}; + return { + lengths.view({-1}).clone(), + indices.clone(), + weights.has_value() ? c10::make_optional(weights->clone()) + : c10::nullopt}; } Tensor permuted_lengths; diff --git a/fbgemm_gpu/src/sparse_ops/sparse_reorder_batched_ad.cu b/fbgemm_gpu/src/sparse_ops/sparse_reorder_batched_ad.cu index 5243876a1b..dc55104229 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_reorder_batched_ad.cu +++ b/fbgemm_gpu/src/sparse_ops/sparse_reorder_batched_ad.cu @@ -243,7 +243,8 @@ DLL_PUBLIC Tensor reorder_batched_ad_indices_gpu( const dim3 threads(32, 32); const dim3 blocks((B * T + 32 - 1) / 32); - AT_DISPATCH_ALL_TYPES( + AT_DISPATCH_ALL_TYPES_AND( + at::ScalarType::BFloat16, cat_ad_indices.scalar_type(), "reorder_batched_ad_indices_gpu_kernel_1", [&] { diff --git a/fbgemm_gpu/src/split_embeddings_cache/common.h b/fbgemm_gpu/src/split_embeddings_cache/common.h index 9c22f02e40..d43e4467c6 100644 --- a/fbgemm_gpu/src/split_embeddings_cache/common.h +++ b/fbgemm_gpu/src/split_embeddings_cache/common.h @@ -89,7 +89,9 @@ Tensor lxu_cache_lookup_cpu( Tensor lxu_cache_state, int64_t invalid_index, bool gather_cache_stats, - c10::optional uvm_cache_stats); + c10::optional uvm_cache_stats, + c10::optional num_uniq_cache_indices, + c10::optional lxu_cache_locations_output); Tensor direct_mapped_lxu_cache_lookup_cpu( Tensor linear_cache_indices, diff --git a/fbgemm_gpu/src/split_embeddings_cache/linearize_cache_indices.cu b/fbgemm_gpu/src/split_embeddings_cache/linearize_cache_indices.cu index 088ba930e5..9bb6fe5c86 100644 --- a/fbgemm_gpu/src/split_embeddings_cache/linearize_cache_indices.cu +++ b/fbgemm_gpu/src/split_embeddings_cache/linearize_cache_indices.cu @@ -13,15 +13,15 @@ using namespace fbgemm_gpu; namespace { -template +template __global__ __launch_bounds__(kMaxThreads) void linearize_cache_indices_kernel( const pta::PackedTensorAccessor32 cache_hash_size_cumsum, const pta::PackedTensorAccessor32 indices, - const pta::PackedTensorAccessor32 + const pta::PackedTensorAccessor32 table_offsets, - pta::PackedTensorAccessor32 + pta::PackedTensorAccessor32 linear_cache_indices) { const index_t index = blockIdx.x * blockDim.x + threadIdx.x; if (index >= indices.size(0)) { @@ -72,31 +72,36 @@ DLL_PUBLIC Tensor linearize_cache_indices_cuda( const auto B = (offsets.size(0) - 1) / T; TORCH_CHECK(B >= 0); - auto linear_cache_indices = at::empty_like(indices); + auto linear_cache_indices = + at::empty(indices.sizes(), indices.options().dtype(at::kLong)); const auto num_indices = indices.numel(); if (B == 0 || num_indices == 0) { return linear_cache_indices; } - auto table_offsets = offsets.slice(0, B, B * T, B); + const auto table_offsets = offsets.slice(0, B, B * T, B); AT_DISPATCH_INDEX_TYPES( - indices.scalar_type(), "linearize_cache_indices_kernel", [&] { + table_offsets.scalar_type(), "linearize_cache_indices_kernel_1", [&] { + using offset_t = index_t; + AT_DISPATCH_INDEX_TYPES( + indices.scalar_type(), "linearize_cache_indices_kernel_2", [&] { #ifdef FBGEMM_GPU_MEMCHECK - const char* func_name = 
"linearize_cache_indices_kernel"; + const char* func_name = "linearize_cache_indices_kernel"; #endif - linearize_cache_indices_kernel<<< - div_round_up(num_indices, kMaxThreads), - kMaxThreads, - 0, - at::cuda::getCurrentCUDAStream()>>>( - MAKE_PTA_WITH_NAME( - func_name, cache_hash_size_cumsum, int64_t, 1, 32), - MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32), - MAKE_PTA_WITH_NAME(func_name, table_offsets, index_t, 1, 32), - MAKE_PTA_WITH_NAME( - func_name, linear_cache_indices, index_t, 1, 32)); - C10_CUDA_KERNEL_LAUNCH_CHECK(); + linearize_cache_indices_kernel<<< + div_round_up(num_indices, kMaxThreads), + kMaxThreads, + 0, + at::cuda::getCurrentCUDAStream()>>>( + MAKE_PTA_WITH_NAME( + func_name, cache_hash_size_cumsum, int64_t, 1, 32), + MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32), + MAKE_PTA_WITH_NAME(func_name, table_offsets, offset_t, 1, 32), + MAKE_PTA_WITH_NAME( + func_name, linear_cache_indices, int64_t, 1, 32)); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); }); return linear_cache_indices; } diff --git a/fbgemm_gpu/src/split_embeddings_cache/lxu_cache.cpp b/fbgemm_gpu/src/split_embeddings_cache/lxu_cache.cpp index 490da1fb7d..296f966417 100644 --- a/fbgemm_gpu/src/split_embeddings_cache/lxu_cache.cpp +++ b/fbgemm_gpu/src/split_embeddings_cache/lxu_cache.cpp @@ -14,12 +14,14 @@ namespace fbgemm_gpu { DLL_PUBLIC Tensor lxu_cache_lookup_cpu( Tensor linear_cache_indices, - Tensor lxu_cache_state, - int64_t invalid_index, - bool gather_cache_stats, - c10::optional uvm_cache_stats) { - return empty_like( - linear_cache_indices, linear_cache_indices.options().dtype(at::kInt)); + Tensor /* lxu_cache_state */, + int64_t /* invalid_index */, + bool /* gather_cache_stats */, + c10::optional /* uvm_cache_stats */, + c10::optional /* num_uniq_cache_indices */, + c10::optional lxu_cache_locations_output) { + return lxu_cache_locations_output.value_or(empty_like( + linear_cache_indices, linear_cache_indices.options().dtype(at::kInt))); } DLL_PUBLIC Tensor direct_mapped_lxu_cache_lookup_cpu( diff --git a/fbgemm_gpu/src/split_embeddings_cache/lxu_cache.cu b/fbgemm_gpu/src/split_embeddings_cache/lxu_cache.cu index 6d43c259b5..3e39eb9e23 100644 --- a/fbgemm_gpu/src/split_embeddings_cache/lxu_cache.cu +++ b/fbgemm_gpu/src/split_embeddings_cache/lxu_cache.cu @@ -254,9 +254,11 @@ __global__ __launch_bounds__(kMaxThreads) void lxu_cache_lookup_kernel( lxu_cache_locations, const bool gather_cache_stats, pta::PackedTensorAccessor32 - uvm_cache_stats) { + uvm_cache_stats, + const int32_t* N_unique) { const int32_t C = lxu_cache_state.size(0); - const int32_t N = linear_cache_indices.size(0); + const int32_t N = + N_unique == nullptr ? linear_cache_indices.size(0) : *N_unique; const int32_t n0 = blockIdx.x * blockDim.y * blockDim.x + threadIdx.y * blockDim.x; if (n0 >= N) { @@ -368,14 +370,56 @@ __launch_bounds__(kMaxThreads) void direct_mapped_lxu_cache_lookup_kernel( } // namespace +/// Lookup the cache locations for each linear cache indices in +/// linear_cache_indices and return lxu_cache_locations +/// +/// lxu_cache_locations A 1D tensor with the same length as +/// linear_cache_indices. It contains the cache locations +/// (the row indices in the cache) of the corresponding +/// indices in linear_cache_indices, i.e., +/// lxu_cache_locations[i] is the cache location for +/// linear_cache_indices[i], where 0 <= i < +/// linear_cache_indices.numel(). 
+/// +/// @param linear_cache_indices Linear cache indices tensor (1D) +/// @param lxu_cache_state LXU cache state tensor (2D tensor of +/// shape (# of cache sets, # of cache +/// slots per set)). It contains linear +/// indices of rows that are in the +/// corresponding cache slots. If the cache +/// slot is empty, a sentinel value is +/// stored. +/// @param invalid_index A sentinel value for linear cache +/// indices. A cache index is skipped if it +/// is a sentinel value. +/// @param gather_cache_stats A flag to enable/disable cache stats +/// collection. +/// @param uvm_cache_stats A tensor for storing cache stats. +/// @param num_uniq_cache_indices An optional GPU tensor that contains the +/// number of unique cache indices. If this +/// tensor is passed, the kernel will only +/// lookup num_uniq_cache_indices number of +/// indices instead of looking up the entire +/// linear_cache_indices. +/// @param lxu_cache_locations_output An optional output tensor. If the +/// tensor is passed, the operator will not +/// allocate a new output tensor and use +/// this tensor as an output tensor. DLL_PUBLIC Tensor lxu_cache_lookup_cuda( - Tensor linear_cache_indices, - Tensor lxu_cache_state, - int64_t invalid_index, - bool gather_cache_stats, - c10::optional uvm_cache_stats) { + const Tensor linear_cache_indices, + const Tensor lxu_cache_state, + const int64_t invalid_index, + const bool gather_cache_stats, + const c10::optional uvm_cache_stats, + const c10::optional num_uniq_cache_indices, + const c10::optional lxu_cache_locations_output) { + const auto uniq_lookup = num_uniq_cache_indices.has_value(); + // TODO: Support gather_cache_stats=true when uniq_lookup=true + TORCH_CHECK( + !uniq_lookup || !gather_cache_stats, + "Unique lxu_cache_locations generation does not support gather_cache_stats=true"); TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL( - linear_cache_indices, lxu_cache_state); + linear_cache_indices, lxu_cache_state, num_uniq_cache_indices); Tensor uvm_cache_stats_ = at::empty({0}, linear_cache_indices.options().dtype(at::kInt)); if (gather_cache_stats) { @@ -386,9 +430,12 @@ DLL_PUBLIC Tensor lxu_cache_lookup_cuda( at::cuda::OptionalCUDAGuard device_guard; device_guard.set_index(linear_cache_indices.get_device()); + const auto lxu_cache_locations = + lxu_cache_locations_output.value_or(empty_like( + linear_cache_indices, + linear_cache_indices.options().dtype(at::kInt))); + const auto N = linear_cache_indices.numel(); - auto lxu_cache_locations = empty_like( - linear_cache_indices, linear_cache_indices.options().dtype(at::kInt)); if (linear_cache_indices.numel() == 0) { // nothing to do return lxu_cache_locations; @@ -412,10 +459,12 @@ DLL_PUBLIC Tensor lxu_cache_lookup_cuda( invalid_index, MAKE_PTA_WITH_NAME(func_name, lxu_cache_locations, int32_t, 1, 32), gather_cache_stats, - MAKE_PTA_WITH_NAME(func_name, uvm_cache_stats_, int32_t, 1, 32)); + MAKE_PTA_WITH_NAME(func_name, uvm_cache_stats_, int32_t, 1, 32), + num_uniq_cache_indices.has_value() + ? num_uniq_cache_indices.value().data_ptr() + : nullptr); C10_CUDA_KERNEL_LAUNCH_CHECK(); }); - return lxu_cache_locations; } @@ -479,11 +528,13 @@ __launch_bounds__(kMaxThreads) void lxu_cache_locations_update_kernel( pta::PackedTensorAccessor32 lxu_cache_locations, const pta::PackedTensorAccessor32 - lxu_cache_locations_new) { - const int32_t N = lxu_cache_locations.size(0); + lxu_cache_locations_new, + const int32_t* N_unique) { + const auto N = N_unique == nullptr ? 
lxu_cache_locations.size(0) : *N_unique; CUDA_KERNEL_LOOP(n, N) { - if (lxu_cache_locations[n] == kCacheLocationMissing && - lxu_cache_locations_new[n] >= 0) { + if (N_unique != nullptr || + (lxu_cache_locations[n] == kCacheLocationMissing && + lxu_cache_locations_new[n] >= 0)) { lxu_cache_locations[n] = lxu_cache_locations_new[n]; } } @@ -493,9 +544,10 @@ __launch_bounds__(kMaxThreads) void lxu_cache_locations_update_kernel( DLL_PUBLIC void lxu_cache_locations_update_cuda( Tensor lxu_cache_locations, - Tensor lxu_cache_locations_new) { + Tensor lxu_cache_locations_new, + c10::optional num_uniq_cache_indices) { TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL( - lxu_cache_locations, lxu_cache_locations_new); + lxu_cache_locations, lxu_cache_locations_new, num_uniq_cache_indices); at::cuda::OptionalCUDAGuard device_guard; device_guard.set_index(lxu_cache_locations.get_device()); @@ -520,7 +572,10 @@ DLL_PUBLIC void lxu_cache_locations_update_cuda( 0, at::cuda::getCurrentCUDAStream()>>>( MAKE_PTA_WITH_NAME(func_name, lxu_cache_locations, int32_t, 1, 32), - MAKE_PTA_WITH_NAME(func_name, lxu_cache_locations_new, int32_t, 1, 32)); + MAKE_PTA_WITH_NAME(func_name, lxu_cache_locations_new, int32_t, 1, 32), + num_uniq_cache_indices.has_value() + ? num_uniq_cache_indices.value().data_ptr() + : nullptr); C10_CUDA_KERNEL_LAUNCH_CHECK(); return; diff --git a/fbgemm_gpu/src/split_embeddings_cache/reset_weight_momentum.cu b/fbgemm_gpu/src/split_embeddings_cache/reset_weight_momentum.cu index 104bf140e1..321c225772 100644 --- a/fbgemm_gpu/src/split_embeddings_cache/reset_weight_momentum.cu +++ b/fbgemm_gpu/src/split_embeddings_cache/reset_weight_momentum.cu @@ -276,7 +276,10 @@ DLL_PUBLIC void reset_weight_momentum_cuda( lxu_cache_state, total_cache_hash_size, false, // gather_cache_stats - uvm_cache_stats); + uvm_cache_stats, + c10::optional(), // num_uniq_cache_indices + c10::optional() // lxu_cache_locations_output + ); } // Reset weight and momentum of pruned rows diff --git a/fbgemm_gpu/src/split_embeddings_cache/split_embeddings_cache_ops.cpp b/fbgemm_gpu/src/split_embeddings_cache/split_embeddings_cache_ops.cpp index 2565ed219c..816f0fdaf5 100644 --- a/fbgemm_gpu/src/split_embeddings_cache/split_embeddings_cache_ops.cpp +++ b/fbgemm_gpu/src/split_embeddings_cache/split_embeddings_cache_ops.cpp @@ -26,7 +26,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { m.def( "lfu_cache_populate_byte(Tensor weights, Tensor cache_hash_size_cumsum, int total_cache_hash_size, Tensor cache_index_table_map, Tensor weights_offsets, Tensor weights_tys, Tensor D_offsets, Tensor linear_cache_indices, Tensor(a!) lxu_cache_state, Tensor(b!) lxu_cache_weights, Tensor(c!) lfu_state, int row_alignment=16) -> ()"); m.def( - "lxu_cache_lookup(Tensor linear_cache_indices, Tensor lxu_cache_state, int invalid_index = -1, bool gather_cache_stats=False, Tensor(a!)? uvm_cache_stats=None) -> Tensor"); + "lxu_cache_lookup(Tensor linear_cache_indices, Tensor lxu_cache_state, int invalid_index = -1, bool gather_cache_stats=False, Tensor(a!)? uvm_cache_stats=None, Tensor? num_uniq_cache_indices=None, Tensor(b!)? lxu_cache_locations_output=None) -> Tensor"); m.def( "direct_mapped_lxu_cache_lookup(Tensor linear_cache_indices, Tensor lxu_cache_state, int invalid_index = -1, bool gather_cache_stats=False, Tensor(a!)? uvm_cache_stats=None) -> Tensor"); m.def( @@ -37,7 +37,9 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { m.def( "lxu_cache_locking_counter_decrement(Tensor(a!) 
lxu_cache_locking_counter, Tensor lxu_cache_locations) -> ()"); m.def( - "lxu_cache_locations_update(Tensor(a!) lxu_cache_locations, Tensor lxu_cache_locations_new) -> ()"); + "lxu_cache_locations_update(Tensor(a!) lxu_cache_locations, Tensor lxu_cache_locations_new, Tensor? num_uniq_cache_indices=None) -> ()"); + m.def( + "get_unique_indices(Tensor linear_indices, int max_indices, bool compute_count) -> (Tensor, Tensor, Tensor?)"); } using namespace fbgemm_gpu; diff --git a/fbgemm_gpu/src/split_embeddings_cache/split_embeddings_cache_ops.cu b/fbgemm_gpu/src/split_embeddings_cache/split_embeddings_cache_ops.cu index 91b4024dc9..fdf5c2ccc7 100644 --- a/fbgemm_gpu/src/split_embeddings_cache/split_embeddings_cache_ops.cu +++ b/fbgemm_gpu/src/split_embeddings_cache/split_embeddings_cache_ops.cu @@ -33,6 +33,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { lxu_cache_locking_counter_decrement_cuda); DISPATCH_TO_CUDA( "lxu_cache_locations_update", lxu_cache_locations_update_cuda); + DISPATCH_TO_CUDA("get_unique_indices", get_unique_indices_cuda); } } // namespace diff --git a/fbgemm_gpu/src/split_embeddings_utils/generate_vbe_metadata.cu b/fbgemm_gpu/src/split_embeddings_utils/generate_vbe_metadata.cu new file mode 100644 index 0000000000..a67cf1065f --- /dev/null +++ b/fbgemm_gpu/src/split_embeddings_utils/generate_vbe_metadata.cu @@ -0,0 +1,160 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "fbgemm_gpu/fbgemm_cuda_utils.cuh" +#include "fbgemm_gpu/ops_utils.h" +#include "fbgemm_gpu/sparse_ops_utils.h" +#include "fbgemm_gpu/split_embeddings_utils.cuh" + +using Tensor = at::Tensor; +using namespace fbgemm_gpu; + +namespace { + +__global__ +__launch_bounds__(kMaxThreads) void generate_vbe_metadata_foreach_sample_kernel( + at::PackedTensorAccessor32 + row_output_offsets, + at::PackedTensorAccessor32 b_t_map, + const at::PackedTensorAccessor32 + B_offsets, + const at::PackedTensorAccessor32 + B_offsets_rank_per_feature, + const at::PackedTensorAccessor32 + output_offsets_feature_rank, + const at::PackedTensorAccessor32 + D_offsets, + const int32_t D, + const bool nobag, + FixedDivisor fd_max_B, + FixedDivisor fd_max_B_T, + const int32_t info_B_num_bits) { + const auto r_b_t = blockIdx.x * blockDim.x + threadIdx.x; + const auto T = B_offsets.size(0) - 1; // Num tables + const auto R = B_offsets_rank_per_feature.size(1) - 1; // Num ranks + + int32_t b_t; + int32_t r; // Rank ID + int32_t t; // Table ID + int32_t b; // Relative sample ID in the rank-table matrix + + fd_max_B_T.DivMod(r_b_t, &r, &b_t); + if (r >= R) { + return; + } + + fd_max_B.DivMod(b_t, &t, &b); + if (t >= T) { + return; + } + + const auto B_start_r_t = B_offsets_rank_per_feature[t][r]; + const auto B_r_t = B_offsets_rank_per_feature[t][r + 1] - B_start_r_t; + if (b >= B_r_t) { + return; + } + + const auto B_start_t = B_offsets[t]; + // Update b_t + b_t = B_start_t + B_start_r_t + b; + const auto D_ = nobag ? D : D_offsets[t + 1] - D_offsets[t]; + row_output_offsets[b_t] = output_offsets_feature_rank[r * T + t] + b * D_; + + // Relative sample ID in the table + const auto b_ = B_start_r_t + b; + // b_t is always positive. 
+ *reinterpret_cast(&b_t_map[b_t]) = + (reinterpret_cast(&t)[0] << info_B_num_bits) | + reinterpret_cast(&b_)[0]; +} + +} // namespace + +/// Generate VBE metadata namely output_offsets and b_t_map +/// +/// row_output_offsets A 1D tensor that contains the output offset of each b +/// (sample) and t (feature/table) pair. The output +/// serializes O_r_t where O_r_t is the local output of rank +/// r and feature/table t (t is the fastest moving index). +/// b_t_map A 1D tensor that contains the b and t information of the +/// linearized b and t (b is the fastest moving index). +/// +/// @param B_offsets Batch size offsets for all features. +/// @param B_offsets_rank_per_feature Batch size offsets for all ranks (GPUs) +/// for each feature. +/// @param output_offsets_feature_rank Output offsets for all features and ranks +/// and features. +/// @param D_offsets Embedding dimension offsets. Required if +/// nobag is false. +/// @param D The embedding dimension. Required if +/// nobag is true. +/// @param nobag A boolean to indicate if TBE is pooled +/// (false) or sequence (true). +/// @param info_B_num_bits The number of bits used to encode a +/// sample ID. (Used for populating b_t_map). +/// @param total_B The total number of samples (i.e., the +/// total number of b and t pairs). +DLL_PUBLIC std::tuple +generate_vbe_metadata( + const Tensor& B_offsets, + const Tensor& B_offsets_rank_per_feature, + const Tensor& output_offsets_feature_rank, + const Tensor& D_offsets, + const int64_t D, + const bool nobag, + const int64_t max_B_feature_rank, + const int64_t info_B_num_bits, + const int64_t total_B) { + TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL( + B_offsets, B_offsets_rank_per_feature, output_offsets_feature_rank); + + TENSOR_NDIM_EQUALS(B_offsets, 1); + TENSOR_NDIM_EQUALS(B_offsets_rank_per_feature, 2); + TENSOR_NDIM_EQUALS(output_offsets_feature_rank, 1); + + const int32_t T = B_offsets.numel() - 1; + if (!nobag) { + TENSOR_ON_CUDA_GPU(D_offsets); + TENSORS_ON_SAME_DEVICE(B_offsets, D_offsets); + TORCH_CHECK(D_offsets.numel() == T + 1) + } + + const auto num_ranks = B_offsets_rank_per_feature.size(1) - 1; + TORCH_CHECK(B_offsets_rank_per_feature.size(0) == T); + TORCH_CHECK(output_offsets_feature_rank.numel() == num_ranks * T + 1); + + at::cuda::OptionalCUDAGuard device_guard; + device_guard.set_index(B_offsets.get_device()); + + Tensor row_output_offsets = + at::empty({total_B}, output_offsets_feature_rank.options()); + Tensor b_t_map = at::empty({total_B}, B_offsets.options()); + + // Over allocate total number of threads to avoid using binary search + generate_vbe_metadata_foreach_sample_kernel<<< + div_round_up(max_B_feature_rank * T * num_ranks, kMaxThreads), + kMaxThreads, + 0, + at::cuda::getCurrentCUDAStream()>>>( + row_output_offsets.packed_accessor32(), + b_t_map.packed_accessor32(), + B_offsets.packed_accessor32(), + B_offsets_rank_per_feature + .packed_accessor32(), + output_offsets_feature_rank + .packed_accessor32(), + D_offsets.packed_accessor32(), + D, + nobag, + FixedDivisor(max_B_feature_rank), + FixedDivisor(max_B_feature_rank * T), + info_B_num_bits); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + return {row_output_offsets, b_t_map}; +} diff --git a/fbgemm_gpu/src/split_embeddings_utils/get_infos_metadata.cu b/fbgemm_gpu/src/split_embeddings_utils/get_infos_metadata.cu new file mode 100644 index 0000000000..d9f40e1f63 --- /dev/null +++ b/fbgemm_gpu/src/split_embeddings_utils/get_infos_metadata.cu @@ -0,0 +1,65 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. 
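Before moving to the next new file, here is a plain Python restatement of what `generate_vbe_metadata_foreach_sample_kernel` above computes for each (rank, table, sample) triple. It is a slow reference sketch for reasoning about the layout, not the kernel itself: it ignores the FixedDivisor thread mapping and the device-side reinterpret casts, and the tensor dtypes are assumed rather than taken from the real call sites.

```python
import torch

def generate_vbe_metadata_reference(
    B_offsets,                    # [T + 1]
    B_offsets_rank_per_feature,   # [T, R + 1]
    output_offsets_feature_rank,  # [R * T + 1]
    D_offsets,                    # [T + 1], used when nobag is False
    D,                            # scalar embedding dim, used when nobag is True
    nobag,
    info_B_num_bits,
    total_B,
):
    T = B_offsets.numel() - 1
    R = B_offsets_rank_per_feature.size(1) - 1
    row_output_offsets = torch.empty(total_B, dtype=output_offsets_feature_rank.dtype)
    b_t_map = torch.empty(total_B, dtype=torch.int32)
    for r in range(R):
        for t in range(T):
            B_start_r_t = int(B_offsets_rank_per_feature[t][r])
            B_r_t = int(B_offsets_rank_per_feature[t][r + 1]) - B_start_r_t
            D_ = D if nobag else int(D_offsets[t + 1] - D_offsets[t])
            for b in range(B_r_t):
                # Linearized sample position (t is the fastest moving index in
                # output_offsets_feature_rank, b is the fastest in b_t_map).
                b_t = int(B_offsets[t]) + B_start_r_t + b
                row_output_offsets[b_t] = (
                    int(output_offsets_feature_rank[r * T + t]) + b * D_
                )
                # Pack the table id above the sample id, mirroring
                # (t << info_B_num_bits) | b_ in the kernel.
                b_t_map[b_t] = (t << info_B_num_bits) | (B_start_r_t + b)
    return row_output_offsets, b_t_map
```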
+ * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "fbgemm_gpu/embedding_backward_template_helpers.cuh" +#include "fbgemm_gpu/ops_utils.h" +#include "fbgemm_gpu/split_embeddings_utils.cuh" + +using Tensor = at::Tensor; +using namespace fbgemm_gpu; + +DLL_PUBLIC std::tuple adjust_info_B_num_bits( + int32_t B, + int32_t T) { + int32_t info_B_num_bits = DEFAULT_INFO_B_NUM_BITS; + uint32_t info_B_mask = DEFAULT_INFO_B_MASK; + uint32_t max_T = MAX_T; + uint32_t max_B = MAX_B; + bool invalid_T = T > max_T; + bool invalid_B = B > max_B; + + TORCH_CHECK( + !(invalid_T && invalid_B), + "Not enough infos bits to accommodate T and B. Default num bits = ", + DEFAULT_INFO_NUM_BITS); + + if (invalid_T) { + // Reduce info_B_num_bits + while (invalid_T && !invalid_B && info_B_num_bits > 0) { + info_B_num_bits--; + max_T = ((max_T + 1) << 1) - 1; + max_B = ((max_B + 1) >> 1) - 1; + invalid_T = T > max_T; + invalid_B = B > max_B; + } + } else if (invalid_B) { + // Increase info_B_num_bits + while (!invalid_T && invalid_B && info_B_num_bits < DEFAULT_INFO_NUM_BITS) { + info_B_num_bits++; + max_T = ((max_T + 1) >> 1) - 1; + max_B = ((max_B + 1) << 1) - 1; + invalid_T = T > max_T; + invalid_B = B > max_B; + } + } + + TORCH_CHECK( + !invalid_T && !invalid_B, + "Not enough infos bits to accommodate T and B. Default num bits = ", + DEFAULT_INFO_NUM_BITS); + + // Recompute info_B_mask using new info_B_num_bits + info_B_mask = (1u << info_B_num_bits) - 1; + + return {info_B_num_bits, info_B_mask}; +} + +DLL_PUBLIC std::tuple +get_infos_metadata(Tensor unused, int64_t B, int64_t T) { + return adjust_info_B_num_bits(B, T); +} diff --git a/fbgemm_gpu/src/split_embeddings_utils/radix_sort_pairs.cu b/fbgemm_gpu/src/split_embeddings_utils/radix_sort_pairs.cu new file mode 100644 index 0000000000..121f66625e --- /dev/null +++ b/fbgemm_gpu/src/split_embeddings_utils/radix_sort_pairs.cu @@ -0,0 +1,83 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
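`adjust_info_B_num_bits` above trades bits between the sample-id (B) and table-id (T) fields of the packed info word until both fit. A compact Python restatement of that bit budgeting follows; the `DEFAULT_*` constants are placeholders for the values defined in the C++ headers, and deriving `max_B`/`max_T` from the bit widths is an assumption that is consistent with the doubling and halving in the loops.

```python
DEFAULT_INFO_NUM_BITS = 32    # placeholder for the header value
DEFAULT_INFO_B_NUM_BITS = 26  # placeholder for the header value

def adjust_info_b_num_bits(B: int, T: int):
    info_b_num_bits = DEFAULT_INFO_B_NUM_BITS
    max_B = (1 << info_b_num_bits) - 1
    max_T = (1 << (DEFAULT_INFO_NUM_BITS - info_b_num_bits)) - 1
    invalid_T, invalid_B = T > max_T, B > max_B
    if invalid_T and invalid_B:
        raise ValueError("Not enough info bits to accommodate T and B")
    if invalid_T:
        # Shrink the sample-id field to give bits to the table-id field.
        while invalid_T and not invalid_B and info_b_num_bits > 0:
            info_b_num_bits -= 1
            max_T = ((max_T + 1) << 1) - 1
            max_B = ((max_B + 1) >> 1) - 1
            invalid_T, invalid_B = T > max_T, B > max_B
    elif invalid_B:
        # Grow the sample-id field at the expense of the table-id field.
        while not invalid_T and invalid_B and info_b_num_bits < DEFAULT_INFO_NUM_BITS:
            info_b_num_bits += 1
            max_T = ((max_T + 1) >> 1) - 1
            max_B = ((max_B + 1) << 1) - 1
            invalid_T, invalid_B = T > max_T, B > max_B
    if invalid_T or invalid_B:
        raise ValueError("Not enough info bits to accommodate T and B")
    info_B_mask = (1 << info_b_num_bits) - 1
    return info_b_num_bits, info_B_mask
```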
+ */ + +#include "fbgemm_gpu/split_embeddings_utils.cuh" + +#include +#include +#include "fbgemm_gpu/embedding_backward_template_helpers.cuh" +#include "fbgemm_gpu/ops_utils.h" + +// clang-format off +#include "fbgemm_gpu/cub_namespace_prefix.cuh" +#include +#include +#include +#include "fbgemm_gpu/cub_namespace_postfix.cuh" +// clang-format on + +using Tensor = at::Tensor; +using namespace fbgemm_gpu; + +#if defined(CUDA_VERSION) && CUDA_VERSION >= 12000 +#define DEF_RADIX_SORT_PAIRS_FN(KeyT, ValueT) \ + DLL_PUBLIC cudaError_t radix_sort_pairs( \ + void* d_temp_storage, \ + size_t& temp_storage_bytes, \ + const KeyT* d_keys_in, \ + KeyT* d_keys_out, \ + const ValueT* d_values_in, \ + ValueT* d_values_out, \ + const int num_items, \ + const int begin_bit, \ + const int end_bit, \ + cudaStream_t stream) { \ + return FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortPairs( \ + d_temp_storage, \ + temp_storage_bytes, \ + d_keys_in, \ + d_keys_out, \ + d_values_in, \ + d_values_out, \ + num_items, \ + begin_bit, \ + end_bit, \ + stream); \ + } +#else +#define DEF_RADIX_SORT_PAIRS_FN(KeyT, ValueT) \ + DLL_PUBLIC cudaError_t radix_sort_pairs( \ + void* d_temp_storage, \ + size_t& temp_storage_bytes, \ + const KeyT* d_keys_in, \ + KeyT* d_keys_out, \ + const ValueT* d_values_in, \ + ValueT* d_values_out, \ + const int num_items, \ + const int begin_bit, \ + const int end_bit, \ + cudaStream_t stream) { \ + return FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortPairs( \ + d_temp_storage, \ + temp_storage_bytes, \ + d_keys_in, \ + d_keys_out, \ + d_values_in, \ + d_values_out, \ + num_items, \ + begin_bit, \ + end_bit, \ + stream, \ + false); \ + } +#endif + +DEF_RADIX_SORT_PAIRS_FN(int64_t, float); +DEF_RADIX_SORT_PAIRS_FN(int64_t, double); +DEF_RADIX_SORT_PAIRS_FN(int64_t, int64_t); +DEF_RADIX_SORT_PAIRS_FN(int64_t, int32_t); diff --git a/fbgemm_gpu/src/split_embeddings_utils.cpp b/fbgemm_gpu/src/split_embeddings_utils/split_embeddings_utils.cpp similarity index 100% rename from fbgemm_gpu/src/split_embeddings_utils.cpp rename to fbgemm_gpu/src/split_embeddings_utils/split_embeddings_utils.cpp diff --git a/fbgemm_gpu/src/split_embeddings_utils.cu b/fbgemm_gpu/src/split_embeddings_utils/transpose_embedding_input.cu similarity index 60% rename from fbgemm_gpu/src/split_embeddings_utils.cu rename to fbgemm_gpu/src/split_embeddings_utils/transpose_embedding_input.cu index dd5c0ec70a..d1eb5e00a1 100644 --- a/fbgemm_gpu/src/split_embeddings_utils.cu +++ b/fbgemm_gpu/src/split_embeddings_utils/transpose_embedding_input.cu @@ -6,12 +6,9 @@ * LICENSE file in the root directory of this source tree. 
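The two expansions of `DEF_RADIX_SORT_PAIRS_FN` above exist only because the `cub::DeviceRadixSort::SortPairs` overload used with CUDA 12 and later no longer takes the trailing `debug_synchronous` flag; both variants sort a key array and apply the same permutation to a value array. A rough Python analogue of that pair sort, ignoring the `begin_bit`/`end_bit` restriction and the two-pass temp-storage protocol of the CUDA wrapper:

```python
import torch

def radix_sort_pairs_reference(keys: torch.Tensor, values: torch.Tensor):
    # stable=True mimics the stability of a radix sort: equal keys keep
    # their relative order, so tied values are not reshuffled.
    sorted_keys, perm = torch.sort(keys, stable=True)
    return sorted_keys, values[perm]

keys = torch.tensor([42, 7, 19, 7], dtype=torch.int64)
vals = torch.tensor([0, 1, 2, 3], dtype=torch.int32)
sorted_keys, sorted_vals = radix_sort_pairs_reference(keys, vals)
# sorted_keys: [7, 7, 19, 42], sorted_vals: [1, 3, 2, 0]
```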
*/ -#include "fbgemm_gpu/split_embeddings_utils.cuh" - -#include -#include #include "fbgemm_gpu/embedding_backward_template_helpers.cuh" #include "fbgemm_gpu/ops_utils.h" +#include "fbgemm_gpu/split_embeddings_utils.cuh" // clang-format off #include "fbgemm_gpu/cub_namespace_prefix.cuh" @@ -21,9 +18,8 @@ #include "fbgemm_gpu/cub_namespace_postfix.cuh" // clang-format on -#ifdef USE_ROCM -#include -#endif +using Tensor = at::Tensor; +using namespace fbgemm_gpu; inline at::Tensor asynchronous_complete_cumsum(at::Tensor t_in) { at::cuda::OptionalCUDAGuard device_guard; @@ -62,10 +58,6 @@ inline at::Tensor asynchronous_complete_cumsum(at::Tensor t_in) { return t_out; } -using Tensor = at::Tensor; - -using namespace fbgemm_gpu; - template __global__ __launch_bounds__(kMaxThreads) void linearize_index_kernel( const at::PackedTensorAccessor32 @@ -394,253 +386,3 @@ transpose_embedding_input( sorted_linear_indices_num_runs, sorted_linear_indices_cumulative_run_lengths}; } - -std::tuple -get_infos_metadata(Tensor unused, int64_t B, int64_t T) { - return adjust_info_B_num_bits(B, T); -} - -DLL_PUBLIC std::tuple adjust_info_B_num_bits( - int32_t B, - int32_t T) { - int32_t info_B_num_bits = DEFAULT_INFO_B_NUM_BITS; - uint32_t info_B_mask = DEFAULT_INFO_B_MASK; - uint32_t max_T = MAX_T; - uint32_t max_B = MAX_B; - bool invalid_T = T > max_T; - bool invalid_B = B > max_B; - - TORCH_CHECK( - !(invalid_T && invalid_B), - "Not enough infos bits to accommodate T and B. Default num bits = ", - DEFAULT_INFO_NUM_BITS); - - if (invalid_T) { - // Reduce info_B_num_bits - while (invalid_T && !invalid_B && info_B_num_bits > 0) { - info_B_num_bits--; - max_T = ((max_T + 1) << 1) - 1; - max_B = ((max_B + 1) >> 1) - 1; - invalid_T = T > max_T; - invalid_B = B > max_B; - } - } else if (invalid_B) { - // Increase info_B_num_bits - while (!invalid_T && invalid_B && info_B_num_bits < DEFAULT_INFO_NUM_BITS) { - info_B_num_bits++; - max_T = ((max_T + 1) >> 1) - 1; - max_B = ((max_B + 1) << 1) - 1; - invalid_T = T > max_T; - invalid_B = B > max_B; - } - } - - TORCH_CHECK( - !invalid_T && !invalid_B, - "Not enough infos bits to accommodate T and B. 
Default num bits = ", - DEFAULT_INFO_NUM_BITS); - - // Recompute info_B_mask using new info_B_num_bits - info_B_mask = (1u << info_B_num_bits) - 1; - - return {info_B_num_bits, info_B_mask}; -} - -#if defined(CUDA_VERSION) && CUDA_VERSION >= 12000 -#define DEF_RADIX_SORT_PAIRS_FN(KeyT, ValueT) \ - DLL_PUBLIC cudaError_t radix_sort_pairs( \ - void* d_temp_storage, \ - size_t& temp_storage_bytes, \ - const KeyT* d_keys_in, \ - KeyT* d_keys_out, \ - const ValueT* d_values_in, \ - ValueT* d_values_out, \ - const int num_items, \ - const int begin_bit, \ - const int end_bit, \ - cudaStream_t stream) { \ - return FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortPairs( \ - d_temp_storage, \ - temp_storage_bytes, \ - d_keys_in, \ - d_keys_out, \ - d_values_in, \ - d_values_out, \ - num_items, \ - begin_bit, \ - end_bit, \ - stream); \ - } -#else -#define DEF_RADIX_SORT_PAIRS_FN(KeyT, ValueT) \ - DLL_PUBLIC cudaError_t radix_sort_pairs( \ - void* d_temp_storage, \ - size_t& temp_storage_bytes, \ - const KeyT* d_keys_in, \ - KeyT* d_keys_out, \ - const ValueT* d_values_in, \ - ValueT* d_values_out, \ - const int num_items, \ - const int begin_bit, \ - const int end_bit, \ - cudaStream_t stream) { \ - return FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortPairs( \ - d_temp_storage, \ - temp_storage_bytes, \ - d_keys_in, \ - d_keys_out, \ - d_values_in, \ - d_values_out, \ - num_items, \ - begin_bit, \ - end_bit, \ - stream, \ - false); \ - } -#endif - -DEF_RADIX_SORT_PAIRS_FN(int64_t, float); -DEF_RADIX_SORT_PAIRS_FN(int64_t, double); -DEF_RADIX_SORT_PAIRS_FN(int64_t, int64_t); -DEF_RADIX_SORT_PAIRS_FN(int64_t, int32_t); - -__global__ -__launch_bounds__(kMaxThreads) void generate_vbe_metadata_foreach_sample_kernel( - at::PackedTensorAccessor32 - row_output_offsets, - at::PackedTensorAccessor32 b_t_map, - const at::PackedTensorAccessor32 - B_offsets, - const at::PackedTensorAccessor32 - B_offsets_rank_per_feature, - const at::PackedTensorAccessor32 - output_offsets_feature_rank, - const at::PackedTensorAccessor32 - D_offsets, - const int32_t D, - const bool nobag, - FixedDivisor fd_max_B, - FixedDivisor fd_max_B_T, - const int32_t info_B_num_bits) { - const auto r_b_t = blockIdx.x * blockDim.x + threadIdx.x; - const auto T = B_offsets.size(0) - 1; // Num tables - const auto R = B_offsets_rank_per_feature.size(1) - 1; // Num ranks - - int32_t b_t; - int32_t r; // Rank ID - int32_t t; // Table ID - int32_t b; // Relative sample ID in the rank-table matrix - - fd_max_B_T.DivMod(r_b_t, &r, &b_t); - if (r >= R) { - return; - } - - fd_max_B.DivMod(b_t, &t, &b); - if (t >= T) { - return; - } - - const auto B_start_r_t = B_offsets_rank_per_feature[t][r]; - const auto B_r_t = B_offsets_rank_per_feature[t][r + 1] - B_start_r_t; - if (b >= B_r_t) { - return; - } - - const auto B_start_t = B_offsets[t]; - // Update b_t - b_t = B_start_t + B_start_r_t + b; - const auto D_ = nobag ? D : D_offsets[t + 1] - D_offsets[t]; - row_output_offsets[b_t] = output_offsets_feature_rank[r * T + t] + b * D_; - - // Relative sample ID in the table - const auto b_ = B_start_r_t + b; - // b_t is always positive. - *reinterpret_cast(&b_t_map[b_t]) = - (reinterpret_cast(&t)[0] << info_B_num_bits) | - reinterpret_cast(&b_)[0]; -} - -/// Generate VBE metadata namely output_offsets and b_t_map -/// -/// row_output_offsets A 1D tensor that contains the output offset of each b -/// (sample) and t (feature/table) pair. 
The output -/// serializes O_r_t where O_r_t is the local output of rank -/// r and feature/table t (t is the fastest moving index). -/// b_t_map A 1D tensor that contains the b and t information of the -/// linearized b and t (b is the fastest moving index). -/// -/// @param B_offsets Batch size offsets for all features. -/// @param B_offsets_rank_per_feature Batch size offsets for all ranks (GPUs) -/// for each feature. -/// @param output_offsets_feature_rank Output offsets for all features and ranks -/// and features. -/// @param D_offsets Embedding dimension offsets. Required if -/// nobag is false. -/// @param D The embedding dimension. Required if -/// nobag is true. -/// @param nobag A boolean to indicate if TBE is pooled -/// (false) or sequence (true). -/// @param info_B_num_bits The number of bits used to encode a -/// sample ID. (Used for populating b_t_map). -/// @param total_B The total number of samples (i.e., the -/// total number of b and t pairs). -DLL_PUBLIC std::tuple -generate_vbe_metadata( - const Tensor& B_offsets, - const Tensor& B_offsets_rank_per_feature, - const Tensor& output_offsets_feature_rank, - const Tensor& D_offsets, - const int64_t D, - const bool nobag, - const int64_t max_B_feature_rank, - const int64_t info_B_num_bits, - const int64_t total_B) { - TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL( - B_offsets, B_offsets_rank_per_feature, output_offsets_feature_rank); - - TENSOR_NDIM_EQUALS(B_offsets, 1); - TENSOR_NDIM_EQUALS(B_offsets_rank_per_feature, 2); - TENSOR_NDIM_EQUALS(output_offsets_feature_rank, 1); - - const int32_t T = B_offsets.numel() - 1; - if (!nobag) { - TENSOR_ON_CUDA_GPU(D_offsets); - TENSORS_ON_SAME_DEVICE(B_offsets, D_offsets); - TORCH_CHECK(D_offsets.numel() == T + 1) - } - - const auto num_ranks = B_offsets_rank_per_feature.size(1) - 1; - TORCH_CHECK(B_offsets_rank_per_feature.size(0) == T); - TORCH_CHECK(output_offsets_feature_rank.numel() == num_ranks * T + 1); - - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(B_offsets.get_device()); - - Tensor row_output_offsets = - at::empty({total_B}, output_offsets_feature_rank.options()); - Tensor b_t_map = at::empty({total_B}, B_offsets.options()); - - // Over allocate total number of threads to avoid using binary search - generate_vbe_metadata_foreach_sample_kernel<<< - div_round_up(max_B_feature_rank * T * num_ranks, kMaxThreads), - kMaxThreads, - 0, - at::cuda::getCurrentCUDAStream()>>>( - row_output_offsets.packed_accessor32(), - b_t_map.packed_accessor32(), - B_offsets.packed_accessor32(), - B_offsets_rank_per_feature - .packed_accessor32(), - output_offsets_feature_rank - .packed_accessor32(), - D_offsets.packed_accessor32(), - D, - nobag, - FixedDivisor(max_B_feature_rank), - FixedDivisor(max_B_feature_rank * T), - info_B_num_bits); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - - return {row_output_offsets, b_t_map}; -} diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache_cuda.cu b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_embeddings_cache_cuda.cu similarity index 100% rename from fbgemm_gpu/src/ssd_split_embeddings_cache_cuda.cu rename to fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_embeddings_cache_cuda.cu diff --git a/fbgemm_gpu/src/ssd_split_table_batched_embeddings.cpp b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp similarity index 100% rename from fbgemm_gpu/src/ssd_split_table_batched_embeddings.cpp rename to fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp diff --git 
a/fbgemm_gpu/src/ssd_table_batched_embeddings.h b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h similarity index 100% rename from fbgemm_gpu/src/ssd_table_batched_embeddings.h rename to fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h diff --git a/fbgemm_gpu/test/batched_unary_embeddings_test.py b/fbgemm_gpu/test/batched_unary_embeddings_test.py index bbd97f294f..3d63aff90f 100644 --- a/fbgemm_gpu/test/batched_unary_embeddings_test.py +++ b/fbgemm_gpu/test/batched_unary_embeddings_test.py @@ -5,7 +5,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe import random import unittest @@ -30,6 +29,7 @@ # Relative tolerances +# pyre-fixme[5]: Global expression must be annotated. TOLERANCE_REL = { torch.float32: 1e-4, torch.float16: 1e-2, @@ -37,6 +37,7 @@ } # Absolute tolerances +# pyre-fixme[5]: Global expression must be annotated. TOLERANCE_ABS = { torch.float32: 1e-4, torch.float16: 1e-2, @@ -81,7 +82,11 @@ def forward( return torch.cat(tt_list).view(self.num_tasks, -1, len(self.hash_sizes)) def _generate_unary_features( - self, batch_size: int, num_embeddings: int + self, + batch_size: int, + num_embeddings: int + # pyre-fixme[24]: Generic type `list` expects 1 type parameter, use + # `typing.List[]` to avoid runtime subscripting errors. ) -> Tuple[List, List, List]: lengths = [] offsets = [] @@ -212,6 +217,8 @@ def _test_main( output = output[1:] output.sum().backward() + # pyre-fixme[56]: Pyre was not able to infer the type of argument + # `test_utils.gpu_unavailable` to decorator factory `unittest.skipIf`. @unittest.skipIf(*gpu_unavailable) def test_gpu(self) -> None: self._test_main(gpu_infer=True) diff --git a/fbgemm_gpu/test/failures_dict.json b/fbgemm_gpu/test/failures_dict.json index 8e53be56ed..43efa968c9 100644 --- a/fbgemm_gpu/test/failures_dict.json +++ b/fbgemm_gpu/test/failures_dict.json @@ -2,6 +2,37 @@ "_description": "This is a dict containing failures for tests autogenerated by generate_opcheck_tests. 
For more details, please see https://docs.google.com/document/d/1Pj5HRZvdOq3xpFpbEjUZp2hBovhy7Wnxw14m6lF2154/edit", "_version": 1, "data": { + "fbgemm::FP8RowwiseQuantizedToFloat": {}, + "fbgemm::FloatToFP8RowwiseQuantized": { + "TestFP8RowwiseQuantizationConversion.test_aot_dispatch_dynamic__test_quantize_and_dequantize_op_fp8_rowwise": { + "comment": "", + "status": "xsuccess" + }, + "TestFP8RowwiseQuantizationConversion.test_faketensor__test_quantize_and_dequantize_op_fp8_rowwise": { + "comment": "", + "status": "xsuccess" + } + }, + "fbgemm::FloatToPaddedFP8RowwiseQuantized": { + "TestFP8RowwiseQuantizationConversion.test_aot_dispatch_dynamic__test_quantize_and_dequantize_op_padded_fp8_rowwise": { + "comment": "", + "status": "xfail" + }, + "TestFP8RowwiseQuantizationConversion.test_faketensor__test_quantize_and_dequantize_op_padded_fp8_rowwise": { + "comment": "", + "status": "xfail" + } + }, + "fbgemm::PaddedFP8RowwiseQuantizedToFloat": { + "TestFP8RowwiseQuantizationConversion.test_aot_dispatch_dynamic__test_quantize_and_dequantize_op_padded_fp8_rowwise": { + "comment": "", + "status": "xfail" + }, + "TestFP8RowwiseQuantizationConversion.test_faketensor__test_quantize_and_dequantize_op_padded_fp8_rowwise": { + "comment": "", + "status": "xfail" + } + }, "fbgemm::asynchronous_complete_cumsum": {}, "fbgemm::asynchronous_exclusive_cumsum": {}, "fbgemm::asynchronous_inclusive_cumsum": {}, @@ -29,6 +60,10 @@ "comment": "", "status": "xfail" }, + "SparseOpsTest.test_aot_dispatch_dynamic__test_block_bucketize_sparse_features_with_block_bucketize_pos": { + "comment": "", + "status": "xfail" + }, "SparseOpsTest.test_aot_dispatch_dynamic__test_block_bucketize_sparse_features_with_variable_batch_sizes": { "comment": "", "status": "xfail" @@ -41,6 +76,10 @@ "comment": "", "status": "xfail" }, + "SparseOpsTest.test_faketensor__test_block_bucketize_sparse_features_with_block_bucketize_pos": { + "comment": "", + "status": "xfail" + }, "SparseOpsTest.test_faketensor__test_block_bucketize_sparse_features_with_variable_batch_sizes": { "comment": "", "status": "xfail" @@ -229,27 +268,27 @@ "fbgemm::jagged_index_select": { "JaggedTensorOpsTest.test_aot_dispatch_dynamic__test_jagged_index_select_2d": { "comment": "", - "status": "xfail" + "status": "xsuccess" }, "JaggedTensorOpsTest.test_aot_dispatch_dynamic__test_jagged_index_select_2d_in_inference": { "comment": "", - "status": "xfail" + "status": "xsuccess" }, "JaggedTensorOpsTest.test_aot_dispatch_dynamic__test_keyed_jagged_index_select_dim1": { "comment": "", - "status": "xfail" + "status": "xsuccess" }, "JaggedTensorOpsTest.test_faketensor__test_jagged_index_select_2d": { "comment": "", - "status": "xfail" + "status": "xsuccess" }, "JaggedTensorOpsTest.test_faketensor__test_jagged_index_select_2d_in_inference": { "comment": "", - "status": "xfail" + "status": "xsuccess" }, "JaggedTensorOpsTest.test_faketensor__test_keyed_jagged_index_select_dim1": { "comment": "", - "status": "xfail" + "status": "xsuccess" } }, "fbgemm::jagged_jagged_bmm": {}, @@ -310,16 +349,7 @@ "status": "xfail" } }, - "fbgemm::masked_select_jagged_1d": { - "JaggedTensorOpsTest.test_aot_dispatch_dynamic__test_masked_select_jagged_1d": { - "comment": "", - "status": "xfail" - }, - "JaggedTensorOpsTest.test_faketensor__test_masked_select_jagged_1d": { - "comment": "", - "status": "xfail" - } - }, + "fbgemm::masked_select_jagged_1d": {}, "fbgemm::offsets_range": { "JaggedTensorOpsTest.test_aot_dispatch_dynamic__test_jagged_1d_to_dense": { "comment": "", @@ -344,28 +374,70 @@ 
"status": "xsuccess" } }, - "fbgemm::permute102_baddbmm_permute102": { - "SparseOpsTest.test_aot_dispatch_dynamic__test_permute102_baddbmm_permute102": { + "fbgemm::padding_fused_tbe_input_combine": { + "InputCombineTest.test_aot_dispatch_dynamic__test_padding_fused_input_combine_int32": { "comment": "", "status": "xfail" }, - "SparseOpsTest.test_faketensor__test_permute102_baddbmm_permute102": { + "InputCombineTest.test_aot_dispatch_dynamic__test_padding_fused_input_combine_int64": { + "comment": "", + "status": "xfail" + }, + "InputCombineTest.test_aot_dispatch_dynamic__test_padding_fused_input_combined_mix": { + "comment": "", + "status": "xfail" + }, + "InputCombineTest.test_faketensor__test_padding_fused_input_combine_int32": { + "comment": "", + "status": "xfail" + }, + "InputCombineTest.test_faketensor__test_padding_fused_input_combine_int64": { + "comment": "", + "status": "xfail" + }, + "InputCombineTest.test_faketensor__test_padding_fused_input_combined_mix": { "comment": "", "status": "xfail" } }, - "fbgemm::permute_1D_sparse_data": { - "SparseOpsTest.test_schema__test_permute_indices": { - "comment": "flaky", - "status": "skip" + "fbgemm::padding_fused_tbe_input_combine_with_length": { + "InputCombineTest.test_aot_dispatch_dynamic__test_padding_fused_input_combine_int32_with_length": { + "comment": "", + "status": "xfail" + }, + "InputCombineTest.test_aot_dispatch_dynamic__test_padding_fused_input_combine_int64_with_length": { + "comment": "", + "status": "xfail" + }, + "InputCombineTest.test_aot_dispatch_dynamic__test_padding_fused_input_combined_mix_with_length": { + "comment": "", + "status": "xfail" + }, + "InputCombineTest.test_faketensor__test_padding_fused_input_combine_int32_with_length": { + "comment": "", + "status": "xfail" + }, + "InputCombineTest.test_faketensor__test_padding_fused_input_combine_int64_with_length": { + "comment": "", + "status": "xfail" + }, + "InputCombineTest.test_faketensor__test_padding_fused_input_combined_mix_with_length": { + "comment": "", + "status": "xfail" } }, - "fbgemm::permute_2D_sparse_data": { - "SparseOpsTest.test_schema__test_permute_indices": { - "comment": "flaky", - "status": "skip" + "fbgemm::permute102_baddbmm_permute102": { + "SparseOpsTest.test_aot_dispatch_dynamic__test_permute102_baddbmm_permute102": { + "comment": "", + "status": "xfail" + }, + "SparseOpsTest.test_faketensor__test_permute102_baddbmm_permute102": { + "comment": "", + "status": "xfail" } }, + "fbgemm::permute_1D_sparse_data": {}, + "fbgemm::permute_2D_sparse_data": {}, "fbgemm::permute_sequence_embeddings": { "SparseOpsTest.test_aot_dispatch_dynamic__test_permute_embeddings": { "comment": "", @@ -439,11 +511,11 @@ "fbgemm::segment_sum_csr": { "SparseOpsTest.test_aot_dispatch_dynamic__test_segment_sum_csr": { "comment": "", - "status": "xfail" + "status": "xsuccess" }, "SparseOpsTest.test_faketensor__test_segment_sum_csr": { "comment": "", - "status": "xfail" + "status": "xsuccess" } }, "fbgemm::stacked_jagged_1d_to_dense": { @@ -469,6 +541,33 @@ "comment": "", "status": "xfail" } + }, + "fbgemm::tbe_input_combine": {}, + "fbgemm::tbe_input_combine_with_length": { + "InputCombineTest.test_aot_dispatch_dynamic__test_input_combine_int32_with_length": { + "comment": "", + "status": "xfail" + }, + "InputCombineTest.test_aot_dispatch_dynamic__test_input_combine_int64_with_length": { + "comment": "", + "status": "xfail" + }, + "InputCombineTest.test_aot_dispatch_dynamic__test_input_combine_mix_with_length": { + "comment": "", + "status": "xfail" + }, + 
"InputCombineTest.test_faketensor__test_input_combine_int32_with_length": { + "comment": "", + "status": "xfail" + }, + "InputCombineTest.test_faketensor__test_input_combine_int64_with_length": { + "comment": "", + "status": "xfail" + }, + "InputCombineTest.test_faketensor__test_input_combine_mix_with_length": { + "comment": "", + "status": "xfail" + } } } } diff --git a/fbgemm_gpu/test/failures_dict_fast.json b/fbgemm_gpu/test/failures_dict_fast.json index a691ff1ad3..cec2a1adc7 100644 --- a/fbgemm_gpu/test/failures_dict_fast.json +++ b/fbgemm_gpu/test/failures_dict_fast.json @@ -238,6 +238,12 @@ "status": "xfail" } }, + "fbgemm::get_unique_indices": { + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_unique_lxu_cache_lookup": { + "comment": "", + "status": "xfail" + } + }, "fbgemm::int_nbit_split_embedding_codegen_lookup_function": { "SplitTableBatchedEmbeddingsTest.test_faketensor__test_int_nbit_split_embedding_uvm_caching_codegen_lookup_function": { "comment": "", @@ -277,15 +283,19 @@ "fbgemm::lfu_cache_populate": { "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_sgd": { "comment": "", - "status": "xfail" + "status": "skip" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_sgd_really_long_segments": { + "comment": "", + "status": "skip" }, "SplitTableBatchedEmbeddingsTest.test_faketensor__test_cache_pipeline": { "comment": "", - "status": "xfail" + "status": "skip" }, "SplitTableBatchedEmbeddingsTest.test_schema__test_cache_pipeline": { "comment": "", - "status": "xfail" + "status": "skip" } }, "fbgemm::lfu_cache_populate_byte": { @@ -297,19 +307,27 @@ "fbgemm::linearize_cache_indices": { "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp16_pmMEAN": { "comment": "", - "status": "xfail" + "status": "skip" }, "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp16_pmNONE": { "comment": "", - "status": "xfail" + "status": "skip" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp16_pmSUM": { + "comment": "", + "status": "skip" }, "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp32_pmMEAN": { "comment": "", - "status": "xfail" + "status": "skip" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp32_pmNONE": { + "comment": "", + "status": "skip" }, "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp32_pmSUM": { "comment": "", - "status": "xfail" + "status": "skip" }, "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_sgd": { "comment": "", @@ -366,6 +384,10 @@ "SplitTableBatchedEmbeddingsTest.test_faketensor__test_stb_uvm_cache_stats": { "comment": "", "status": "xfail" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_unique_lxu_cache_lookup": { + "comment": "", + "status": "xfail" } }, "fbgemm::linearize_cache_indices_from_row_idx": { @@ -377,67 +399,75 @@ "fbgemm::lru_cache_populate": { "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp16_pmMEAN": { "comment": "", - "status": "xfail" + "status": "skip" }, "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp16_pmNONE": { "comment": "", - "status": "xfail" + "status": "skip" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp16_pmSUM": { + "comment": "", + "status": "skip" }, "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp32_pmMEAN": { "comment": "", - "status": "xfail" + "status": "skip" + }, + 
"SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp32_pmNONE": { + "comment": "", + "status": "skip" }, "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp32_pmSUM": { "comment": "", - "status": "xfail" + "status": "skip" }, "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_sgd": { "comment": "", - "status": "xfail" + "status": "skip" }, "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_sgd_really_long_segments": { "comment": "", - "status": "xfail" + "status": "skip" }, "SplitTableBatchedEmbeddingsTest.test_faketensor__test_cache_miss_counter": { "comment": "", - "status": "xfail" + "status": "skip" }, "SplitTableBatchedEmbeddingsTest.test_faketensor__test_cache_pipeline": { "comment": "", - "status": "xfail" + "status": "skip" }, "SplitTableBatchedEmbeddingsTest.test_faketensor__test_cache_prefetch_pipeline": { "comment": "", - "status": "xfail" + "status": "skip" }, "SplitTableBatchedEmbeddingsTest.test_faketensor__test_cache_prefetch_pipeline_stream_1": { "comment": "", - "status": "xfail" + "status": "skip" }, "SplitTableBatchedEmbeddingsTest.test_faketensor__test_cache_prefetch_pipeline_stream_2": { "comment": "", - "status": "xfail" + "status": "skip" }, "SplitTableBatchedEmbeddingsTest.test_faketensor__test_stb_uvm_cache_stats": { "comment": "", - "status": "xfail" + "status": "skip" }, "SplitTableBatchedEmbeddingsTest.test_schema__test_cache_pipeline": { "comment": "", - "status": "xfail" + "status": "skip" }, "SplitTableBatchedEmbeddingsTest.test_schema__test_cache_prefetch_pipeline": { "comment": "", - "status": "xfail" + "status": "skip" }, "SplitTableBatchedEmbeddingsTest.test_schema__test_cache_prefetch_pipeline_stream_1": { "comment": "", - "status": "xfail" + "status": "skip" }, "SplitTableBatchedEmbeddingsTest.test_schema__test_cache_prefetch_pipeline_stream_2": { "comment": "", - "status": "xfail" + "status": "skip" } }, "fbgemm::lru_cache_populate_byte": { @@ -467,10 +497,18 @@ "comment": "", "status": "xfail" }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp16_pmSUM": { + "comment": "", + "status": "xfail" + }, "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp32_pmMEAN": { "comment": "", "status": "xfail" }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp32_pmNONE": { + "comment": "", + "status": "xfail" + }, "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp32_pmSUM": { "comment": "", "status": "xfail" @@ -515,10 +553,18 @@ "comment": "", "status": "xfail" }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp16_pmSUM": { + "comment": "", + "status": "xfail" + }, "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp32_pmMEAN": { "comment": "", "status": "xfail" }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp32_pmNONE": { + "comment": "", + "status": "xfail" + }, "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp32_pmSUM": { "comment": "", "status": "xfail" @@ -574,6 +620,10 @@ "SplitTableBatchedEmbeddingsTest.test_faketensor__test_stb_uvm_cache_stats": { "comment": "", "status": "xfail" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_unique_lxu_cache_lookup": { + "comment": "", + "status": "xfail" } }, "fbgemm::new_managed_tensor": {}, @@ -648,11 +698,11 @@ "fbgemm::reset_weight_momentum": { 
"SplitTableBatchedEmbeddingsTest.test_faketensor__test_reset_embedding_weight_momentum": { "comment": "", - "status": "xfail" + "status": "skip" }, "SplitTableBatchedEmbeddingsTest.test_schema__test_reset_embedding_weight_momentum": { "comment": "", - "status": "xfail" + "status": "skip" } }, "fbgemm::split_embedding_codegen_lookup_adagrad_function": {}, @@ -674,45 +724,90 @@ "fbgemm::split_embedding_codegen_lookup_partial_rowwise_adam_function": {}, "fbgemm::split_embedding_codegen_lookup_partial_rowwise_lamb_function": {}, "fbgemm::split_embedding_codegen_lookup_rowwise_adagrad_function": { + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp16_pmMEAN": { + "comment": "", + "status": "skip" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp16_pmNONE": { + "comment": "", + "status": "skip" + }, "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp16_pmSUM": { "comment": "", - "status": "xfail" + "status": "skip" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp32_pmMEAN": { + "comment": "", + "status": "skip" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp32_pmNONE": { + "comment": "", + "status": "skip" }, "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp32_pmSUM": { "comment": "", - "status": "xfail" + "status": "skip" }, "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_optimizers_adagrad": { "comment": "", - "status": "xfail" + "status": "skip" + } + }, + "fbgemm::split_embedding_codegen_lookup_rowwise_adagrad_function_cpu": { + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp16_pmMEAN": { + "comment": "", + "status": "skip" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp16_pmNONE": { + "comment": "", + "status": "skip" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp16_pmSUM": { + "comment": "", + "status": "skip" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp32_pmMEAN": { + "comment": "", + "status": "skip" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp32_pmNONE": { + "comment": "", + "status": "skip" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_adagrad_fp32_pmSUM": { + "comment": "", + "status": "skip" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_optimizers_adagrad": { + "comment": "", + "status": "skip" } }, - "fbgemm::split_embedding_codegen_lookup_rowwise_adagrad_function_cpu": {}, "fbgemm::split_embedding_codegen_lookup_rowwise_weighted_adagrad_function": {}, "fbgemm::split_embedding_codegen_lookup_sgd_function": { "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_sgd": { "comment": "", - "status": "xfail" + "status": "skip" }, "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_sgd_really_long_segments": { "comment": "", - "status": "xfail" + "status": "skip" }, "SplitTableBatchedEmbeddingsTest.test_faketensor__test_cache_pipeline": { "comment": "", - "status": "xfail" + "status": "skip" }, "SplitTableBatchedEmbeddingsTest.test_faketensor__test_cache_prefetch_pipeline": { "comment": "", - "status": "xfail" + "status": "skip" }, "SplitTableBatchedEmbeddingsTest.test_faketensor__test_cache_prefetch_pipeline_stream_1": { "comment": "", - "status": "xfail" + "status": "skip" }, "SplitTableBatchedEmbeddingsTest.test_faketensor__test_cache_prefetch_pipeline_stream_2": 
{ "comment": "", - "status": "xfail" + "status": "skip" } }, "fbgemm::split_embedding_codegen_lookup_sgd_function_cpu": { diff --git a/fbgemm_gpu/test/input_combine_test.py b/fbgemm_gpu/test/input_combine_test.py index cbd054b607..db1f0e8928 100644 --- a/fbgemm_gpu/test/input_combine_test.py +++ b/fbgemm_gpu/test/input_combine_test.py @@ -5,12 +5,12 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe import unittest -from typing import List, Optional, Tuple +from typing import Callable, Dict, List, Optional, Tuple import torch +from fbgemm_gpu import sparse_ops # noqa: F401 from hypothesis import given, settings try: @@ -18,11 +18,11 @@ from fbgemm_gpu import open_source # noqa: F401 # pyre-ignore[21] - from test_utils import cpu_and_maybe_gpu + from test_utils import cpu_and_maybe_gpu, optests except Exception: torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:input_combine") torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:input_combine_cpu") - from fbgemm_gpu.test.test_utils import cpu_and_maybe_gpu + from fbgemm_gpu.test.test_utils import cpu_and_maybe_gpu, optests DEFAULT_DEVICE = torch.device("cpu") @@ -127,7 +127,30 @@ def forward( # noqa C901 return combined_indices, combined_offsets, per_sample_weights +# e.g. "test_faketensor__test_cumsum": [unittest.expectedFailure] +# Please avoid putting tests here, you should put operator-specific +# skips and failures in deeplearning/fbgemm/fbgemm_gpu/test/failures_dict.json +# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. +additional_decorators: Dict[str, List[Callable]] = { + "test_pt2_compliant_tag_fbgemm_dense_to_jagged": [ + # This operator has been grandfathered in. We need to fix this test failure. + unittest.expectedFailure, + ], + "test_pt2_compliant_tag_fbgemm_jagged_dense_elementwise_add": [ + # This operator has been grandfathered in. We need to fix this test failure. + unittest.expectedFailure, + ], + "test_pt2_compliant_tag_fbgemm_jagged_dense_elementwise_add_jagged_output": [ + # This operator has been grandfathered in. We need to fix this test failure. + unittest.expectedFailure, + ], +} + + +@optests.generate_opcheck_tests(additional_decorators=additional_decorators) class InputCombineTest(unittest.TestCase): + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def _get_inputs(self, dtypes, device=DEFAULT_DEVICE): indices_list = [ torch.tensor([1, 2, 3], dtype=dtypes[0], device=device), @@ -154,6 +177,7 @@ def _get_inputs(self, dtypes, device=DEFAULT_DEVICE): include_last_offsets, ) + # pyre-fixme[2]: Parameter must be annotated. def _run_test(self, dtypes) -> None: ( indices_list, @@ -191,6 +215,7 @@ def _run_test(self, dtypes) -> None: self.assertTrue(outputs[1].dtype == torch.int32) self.assertTrue(outputs[-1].size(0) == 0) + # pyre-fixme[2]: Parameter must be annotated. def _run_padding_fused_test(self, dtypes, batch_size) -> None: ( indices_list, @@ -234,8 +259,17 @@ def _run_padding_fused_test(self, dtypes, batch_size) -> None: self.assertTrue(outputs[1].dtype == torch.int32) self.assertTrue(outputs[-1].size(0) == 0) + # pyre-fixme[3]: Return type must be annotated. def _offsets_to_lengths( - self, offsets, indices, include_last_offsets, device=DEFAULT_DEVICE + self, + # pyre-fixme[2]: Parameter must be annotated. + offsets, + # pyre-fixme[2]: Parameter must be annotated. + indices, + # pyre-fixme[2]: Parameter must be annotated. 
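The JSON edits above and the `additional_decorators` dicts added to the test modules below both feed `optests.generate_opcheck_tests`: the failures dict records operator-specific expectations (`xfail`, `skip`, or `xsuccess`) for the auto-generated `test_faketensor__*`, `test_aot_dispatch_dynamic__*`, and `test_schema__*` variants, while `additional_decorators` attaches extra decorators keyed by generated test name. A small illustrative sketch of the two pieces; the operator and test names here are placeholders, not entries from this PR.

```python
import unittest
from typing import Callable, Dict, List

# Shape of a failures_dict.json entry (illustrative names only):
# "fbgemm::some_op": {
#     "SomeOpsTest.test_faketensor__test_some_op": {
#         "comment": "",
#         "status": "xfail"   # or "skip" / "xsuccess"
#     }
# }

# Decorators keyed by generated test name rather than by operator:
additional_decorators: Dict[str, List[Callable]] = {
    "test_pt2_compliant_tag_fbgemm_some_op": [unittest.expectedFailure],
}
```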
+ include_last_offsets, + # pyre-fixme[2]: Parameter must be annotated. + device=DEFAULT_DEVICE, ): if include_last_offsets: offsets_complete = offsets @@ -248,6 +282,7 @@ def _offsets_to_lengths( ) return offsets_complete[1:] - offsets_complete[:-1] + # pyre-fixme[2]: Parameter must be annotated. def _run_test_with_length(self, dtypes, device=DEFAULT_DEVICE) -> None: ( indices_list, @@ -279,6 +314,7 @@ def _run_test_with_length(self, dtypes, device=DEFAULT_DEVICE) -> None: ref_lengths = self._offsets_to_lengths(ref_outputs[1], ref_outputs[0], True) self.assertTrue(ref_lengths.allclose(outputs[1])) + # pyre-fixme[2]: Parameter must be annotated. def _run_padding_fused_test_with_length(self, dtypes, batch_size) -> None: ( indices_list, @@ -322,16 +358,22 @@ def test_input_combine_int32(self) -> None: def test_input_combined_mix(self) -> None: self._run_test((torch.int64, torch.int32)) + # pyre-fixme[56]: Pyre was not able to infer the type of argument + # `test_utils.cpu_and_maybe_gpu()` to decorator factory `hypothesis.given`. @given(device=cpu_and_maybe_gpu()) @settings(deadline=None) def test_input_combine_int64_with_length(self, device: torch.device) -> None: self._run_test_with_length((torch.int64, torch.int64), device=device) + # pyre-fixme[56]: Pyre was not able to infer the type of argument + # `test_utils.cpu_and_maybe_gpu()` to decorator factory `hypothesis.given`. @given(device=cpu_and_maybe_gpu()) @settings(deadline=None) def test_input_combine_int32_with_length(self, device: torch.device) -> None: self._run_test_with_length((torch.int32, torch.int32), device=device) + # pyre-fixme[56]: Pyre was not able to infer the type of argument + # `test_utils.cpu_and_maybe_gpu()` to decorator factory `hypothesis.given`. @given(device=cpu_and_maybe_gpu()) @settings(deadline=None) def test_input_combine_mix_with_length(self, device: torch.device) -> None: diff --git a/fbgemm_gpu/test/jagged_tensor_ops_test.py b/fbgemm_gpu/test/jagged_tensor_ops_test.py index f83bd941bc..8465490282 100644 --- a/fbgemm_gpu/test/jagged_tensor_ops_test.py +++ b/fbgemm_gpu/test/jagged_tensor_ops_test.py @@ -10,7 +10,7 @@ import itertools import random import unittest -from typing import List, Tuple +from typing import Callable, Dict, List, Tuple import hypothesis.strategies as st import numpy as np @@ -35,7 +35,7 @@ except Exception: torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") - import fbgemm_gpu.sparse_operators # noqa: F401, E402 + import fbgemm_gpu.sparse_ops # noqa: F401, E402 from fbgemm_gpu.test.test_utils import ( gpu_available, gpu_unavailable, @@ -127,7 +127,27 @@ def hash_size_cumsum_to_offsets(hash_size_cum_sum_list: List[int]) -> List[int]: return hash_size_offsets_list -@optests.generate_opcheck_tests +# e.g. "test_faketensor__test_cumsum": [unittest.expectedFailure] +# Please avoid putting tests here, you should put operator-specific +# skips and failures in deeplearning/fbgemm/fbgemm_gpu/test/failures_dict.json +# pyre-ignore[24]: Generic type `Callable` expects 2 type parameters. +additional_decorators: Dict[str, List[Callable]] = { + "test_pt2_compliant_tag_fbgemm_dense_to_jagged": [ + # This operator has been grandfathered in. We need to fix this test failure. + unittest.expectedFailure, + ], + "test_pt2_compliant_tag_fbgemm_jagged_dense_elementwise_add": [ + # This operator has been grandfathered in. We need to fix this test failure. 
+ unittest.expectedFailure, + ], + "test_pt2_compliant_tag_fbgemm_jagged_dense_elementwise_add_jagged_output": [ + # This operator has been grandfathered in. We need to fix this test failure. + unittest.expectedFailure, + ], +} + + +@optests.generate_opcheck_tests(additional_decorators=additional_decorators) class JaggedTensorOpsTest(unittest.TestCase): def setUp(self) -> None: if symint_vector_unsupported()[0]: diff --git a/fbgemm_gpu/test/lint/check_meta_header.py b/fbgemm_gpu/test/lint/check_meta_header.py index fd5178d753..5fd5e41f6c 100644 --- a/fbgemm_gpu/test/lint/check_meta_header.py +++ b/fbgemm_gpu/test/lint/check_meta_header.py @@ -5,7 +5,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe """Check Python source code contains Meta copyright header """ diff --git a/fbgemm_gpu/test/merge_pooled_embeddings_test.py b/fbgemm_gpu/test/merge_pooled_embeddings_test.py index 0e4b7986e7..a686cdb3eb 100644 --- a/fbgemm_gpu/test/merge_pooled_embeddings_test.py +++ b/fbgemm_gpu/test/merge_pooled_embeddings_test.py @@ -5,7 +5,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe import unittest from typing import Tuple @@ -26,6 +25,7 @@ torch.ops.load_library( "//deeplearning/fbgemm/fbgemm_gpu:merge_pooled_embeddings_cpu" ) + import fbgemm_gpu.sparse_ops # noqa: F401, E402 from fbgemm_gpu.test.test_utils import gpu_unavailable open_source = False @@ -36,6 +36,9 @@ @unittest.skipIf(*gpu_unavailable) @unittest.skipIf(open_source, "Not supported in open source yet") class MergePooledEmbeddingsTest(unittest.TestCase): + # pyre-fixme[56]: Pyre was not able to infer the type of argument + # `hypothesis.strategies.integers($parameter$min_value = 1, $parameter$max_value = + # 10)` to decorator factory `hypothesis.given`. @given( num_ads=st.integers(min_value=1, max_value=10), embedding_dimension=st.integers(min_value=1, max_value=32), @@ -49,11 +52,17 @@ class MergePooledEmbeddingsTest(unittest.TestCase): @settings(verbosity=Verbosity.verbose, max_examples=40, deadline=None) def test_merge( self, + # pyre-fixme[2]: Parameter must be annotated. num_ads, + # pyre-fixme[2]: Parameter must be annotated. embedding_dimension, + # pyre-fixme[2]: Parameter must be annotated. ads_tables, + # pyre-fixme[2]: Parameter must be annotated. num_gpus, + # pyre-fixme[2]: Parameter must be annotated. non_default_stream, + # pyre-fixme[2]: Parameter must be annotated. r, dim: int, ) -> None: @@ -82,6 +91,8 @@ def test_merge( pooled_ad_embeddings, uncat_size, batch_indices.device, dim ) + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def ref(pooled_ad_embeddings, batch_indices): return torch.cat([p.cpu() for p in pooled_ad_embeddings], dim=dim) @@ -96,6 +107,9 @@ def ref(pooled_ad_embeddings, batch_indices): torch.testing.assert_close(output_ref, output.cpu()) torch.testing.assert_close(output_ref, output_cpu) + # pyre-fixme[56]: Pyre was not able to infer the type of argument + # `hypothesis.strategies.integers($parameter$min_value = 1, $parameter$max_value = + # 10)` to decorator factory `hypothesis.given`. 
@given( num_inputs=st.integers(min_value=1, max_value=10), num_gpus=st.integers(min_value=1, max_value=torch.cuda.device_count()), @@ -105,8 +119,11 @@ def ref(pooled_ad_embeddings, batch_indices): @settings(verbosity=Verbosity.verbose, max_examples=40, deadline=None) def test_all_to_one_device( self, + # pyre-fixme[2]: Parameter must be annotated. num_inputs, + # pyre-fixme[2]: Parameter must be annotated. num_gpus, + # pyre-fixme[2]: Parameter must be annotated. r, ) -> None: dst_device = torch.device(f"cuda:{r.randint(0, num_gpus - 1)}") @@ -132,6 +149,9 @@ def test_merge_pooled_embeddings_cpu_with_different_target_device(self) -> None: self.assertFalse(output_meta.is_cpu) self.assertTrue(output_meta.is_meta) + # pyre-fixme[56]: Pyre was not able to infer the type of argument + # `hypothesis.strategies.integers($parameter$min_value = 1, $parameter$max_value = + # 10)` to decorator factory `hypothesis.given`. @given( num_inputs=st.integers(min_value=1, max_value=10), num_gpus=st.integers(min_value=1, max_value=torch.cuda.device_count()), @@ -141,8 +161,11 @@ def test_merge_pooled_embeddings_cpu_with_different_target_device(self) -> None: @settings(verbosity=Verbosity.verbose, max_examples=10, deadline=None) def test_sum_reduce_to_one( self, + # pyre-fixme[2]: Parameter must be annotated. num_inputs, + # pyre-fixme[2]: Parameter must be annotated. num_gpus, + # pyre-fixme[2]: Parameter must be annotated. r, ) -> None: dst_device = torch.device(f"cuda:{r.randint(0, num_gpus - 1)}") @@ -167,6 +190,11 @@ def test_merge_pooled_embeddings_meta(self) -> None: cat_dim = 1 pooled_embeddings = [torch.ones(uncat_size, 4), torch.ones(uncat_size, 8)] + # pyre-fixme[53]: Captured variable `cat_dim` is not annotated. + # pyre-fixme[53]: Captured variable `pooled_embeddings` is not annotated. + # pyre-fixme[53]: Captured variable `uncat_size` is not annotated. + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def fbgemm_merge_pooled_embeddings(device): pooled_embeddings_device = [ pooled_embedding.to(device) for pooled_embedding in pooled_embeddings diff --git a/fbgemm_gpu/test/permute_pooled_embedding_test.py b/fbgemm_gpu/test/permute_pooled_embedding_test.py index 3365457c30..3723ee76c4 100644 --- a/fbgemm_gpu/test/permute_pooled_embedding_test.py +++ b/fbgemm_gpu/test/permute_pooled_embedding_test.py @@ -46,7 +46,7 @@ FIXED_EXTERN_API = { "PermutePooledEmbeddings": { "__init__": ["self", "embs_dims", "permute", "device"], - "forward": ["self", "pooled_embs"], + "__call__": ["self", "pooled_embs"], }, } diff --git a/fbgemm_gpu/test/quantize_ops_test.py b/fbgemm_gpu/test/quantize_ops_test.py index b2dac18a7e..493f35f9f8 100644 --- a/fbgemm_gpu/test/quantize_ops_test.py +++ b/fbgemm_gpu/test/quantize_ops_test.py @@ -9,7 +9,7 @@ import random import unittest from ctypes import c_float, c_int32, cast, POINTER, pointer -from typing import Dict, Tuple +from typing import Callable, Dict, List, Tuple import hypothesis.strategies as st import numpy as np @@ -32,6 +32,7 @@ fused_rowwise_nbit_quantize_reference, gpu_available, gpu_unavailable, + optests, symint_vector_unsupported, ) except Exception: @@ -45,6 +46,7 @@ fused_rowwise_nbit_quantize_reference, gpu_available, gpu_unavailable, + optests, symint_vector_unsupported, ) @@ -999,6 +1001,27 @@ def test_quantize_and_dequantize_op_cuda_large_nrows_bf16( torch.testing.assert_close(dequantized_data_gpu.cpu(), dequantized_data) +# e.g. 
"test_faketensor__test_cumsum": [unittest.expectedFailure] +# Please avoid putting tests here, you should put operator-specific +# skips and failures in deeplearning/fbgemm/fbgemm_gpu/test/failures_dict.json +# pyre-ignore[24]: Generic type `Callable` expects 2 type parameters. +additional_decorators: Dict[str, List[Callable]] = { + "test_pt2_compliant_tag_fbgemm_dense_to_jagged": [ + # This operator has been grandfathered in. We need to fix this test failure. + unittest.expectedFailure, + ], + "test_pt2_compliant_tag_fbgemm_jagged_dense_elementwise_add": [ + # This operator has been grandfathered in. We need to fix this test failure. + unittest.expectedFailure, + ], + "test_pt2_compliant_tag_fbgemm_jagged_dense_elementwise_add_jagged_output": [ + # This operator has been grandfathered in. We need to fix this test failure. + unittest.expectedFailure, + ], +} + + +@optests.generate_opcheck_tests(additional_decorators=additional_decorators) class TestFP8RowwiseQuantizationConversion(unittest.TestCase): enable_logging: bool = False max_examples: int = 40 diff --git a/fbgemm_gpu/test/sparse_ops_test.py b/fbgemm_gpu/test/sparse_ops_test.py index f3cbb5dbf1..d06b7988d0 100644 --- a/fbgemm_gpu/test/sparse_ops_test.py +++ b/fbgemm_gpu/test/sparse_ops_test.py @@ -15,7 +15,7 @@ import random import unittest from itertools import accumulate -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, Union import fbgemm_gpu @@ -38,7 +38,7 @@ torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:index_select_ops") - import fbgemm_gpu.sparse_operators # noqa: F401, E402 + import fbgemm_gpu.sparse_ops # noqa: F401, E402 from fbgemm_gpu.test.test_utils import gpu_available, gpu_unavailable, skipIfRocm suppressed_list: List[HealthCheck] = ( @@ -942,6 +942,149 @@ def test_block_bucketize_sparse_features_with_variable_batch_sizes( new_indices_gpu.cpu(), new_indices_ref, rtol=0, atol=0 ) + @given( + index_type=st.sampled_from([torch.int, torch.long]), + has_weight=st.booleans(), + bucketize_pos=st.booleans(), + sequence=st.booleans(), + ) + @settings(verbosity=Verbosity.verbose, max_examples=16, deadline=None) + def test_block_bucketize_sparse_features_with_block_bucketize_pos( + self, + index_type: Optional[torch.dtype], + has_weight: bool, + bucketize_pos: bool, + sequence: bool, + ) -> None: + """ + Test variable bucket size for block bucketize_sparse features for RW sharding. + E.g. Given bucket_sizes_pos as [[0,5,15], [0,10,13]] + For batch 0, indices in [0,5) will be assigned to bucket 0, indices in [5,15) will be assigned to bucket 1. + For batch 1, indices in [0,10) will be assigned to bucket 0, indices in [10,13) will be assigned to bucket 1. + The new index will be original index - bucket_sizes_pos[new_bucket_id-1] + i.e. for batch = 0, index = 12, it will be assigned to bucket 1 and the new index is 12 - 5 = 7. 
+ """ + # For the following test case, we have + # batch 0: 2 (1,7), 1 (2), 1 (6) + # 1: bucket 0, new_idx 1 + # 7: bucket 1, new_idx 5 + # 2: bucket 1, new_idx 0 + # 6: bucket 1, new_idx 4 + + # batch 1: 2 (7,8) + # 7: bucket 1, new_idx 2 + # 8: bucket 1, new_idx 3 + + # batch 2: 0, 2 (8,4) + # 8: bucket 1, new_idx 1 + # 4: bucket 0, new_idx 4 + + # new_lengths for 0: 1, 0, 0, 0, 0, 1 + # new_indices for 0: 1| | | | | 4 + # new_lengths for 1: 1, 1, 1, 2, 0, 1 + # new_indices for 1: 5| 0| 4| 2,3| |1 + lengths = torch.tensor([2, 1, 1, 2, 0, 2], dtype=index_type) + indices = torch.tensor( + [1, 7, 2, 6, 7, 8, 8, 4], + dtype=index_type, + ) + batch_sizes = torch.tensor([3, 1, 2], dtype=index_type) + weights = ( + torch.tensor( + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], + dtype=torch.float, + ) + if has_weight + else None + ) + + block_sizes = torch.tensor([5, 10, 8], dtype=index_type) + my_size = 2 + max_B = batch_sizes.max().item() # unused + + block_bucketize_pos = [ + torch.tensor([0, 2, 8], dtype=index_type), + torch.tensor([0, 5, 10], dtype=index_type), + torch.tensor([0, 7, 12], dtype=index_type), + ] + + new_lengths_ref = torch.tensor( + [1, 0, 0, 0, 0, 1, 1, 1, 1, 2, 0, 1], + dtype=index_type, + ) + new_indices_ref = torch.tensor( + [1, 4, 5, 0, 4, 2, 3, 1], + dtype=index_type, + ) + new_weights_ref = torch.tensor( + [ + 1.0, + 8.0, + 2.0, + 3.0, + 4.0, + 5.0, + 6.0, + 7.0, + ], + dtype=torch.float, + ) + ( + new_lengths_cpu, + new_indices_cpu, + new_weights_cpu, + new_pos_cpu, + unbucketize_permute, + ) = torch.ops.fbgemm.block_bucketize_sparse_features( + lengths, + indices, + bucketize_pos, + sequence, + block_sizes, + my_size, + weights, + batch_sizes, + max_B, + block_bucketize_pos, + ) + torch.testing.assert_close(new_lengths_cpu, new_lengths_ref, rtol=0, atol=0) + torch.testing.assert_close(new_indices_cpu, new_indices_ref, rtol=0, atol=0) + if has_weight: + torch.testing.assert_close(new_weights_cpu, new_weights_ref) + + if gpu_available: + block_bucketize_pos = [ + torch.tensor([0, 2, 8], dtype=index_type, device="cuda"), + torch.tensor([0, 5, 10], dtype=index_type, device="cuda"), + torch.tensor([0, 7, 12], dtype=index_type, device="cuda"), + ] + ( + new_lengths_gpu, + new_indices_gpu, + new_weights_gpu, + new_pos_gpu, + unbucketize_permute, + ) = torch.ops.fbgemm.block_bucketize_sparse_features( + lengths.cuda(), + indices.cuda(), + bucketize_pos, + sequence, + block_sizes.cuda(), + my_size, + weights.cuda() if weights is not None else None, + batch_sizes.cuda(), + max_B, + block_bucketize_pos, + ) + torch.testing.assert_close( + new_lengths_gpu.cpu(), new_lengths_ref, rtol=0, atol=0 + ) + torch.testing.assert_close( + new_indices_gpu.cpu(), new_indices_ref, rtol=0, atol=0 + ) + if has_weight: + torch.testing.assert_close(new_weights_gpu.cpu(), new_weights_ref) + @given( index_type=st.sampled_from([torch.int, torch.long]), has_weight=st.booleans(), @@ -1114,7 +1257,7 @@ def test_reorder_batched_ad_lengths_cpu( T=st.integers(min_value=1, max_value=20), L=st.integers(min_value=2, max_value=20), A=st.integers(min_value=1, max_value=20), - Dtype=st.sampled_from([torch.int32, torch.float, torch.int64]), + Dtype=st.sampled_from([torch.int32, torch.float, torch.int64, torch.bfloat16]), Itype=st.sampled_from([torch.int32, torch.int64]), broadcast_indices=st.booleans(), ) @@ -2402,6 +2545,167 @@ def validate( "grad", ) + def permute_sparse_features_ref_( + self, + lengths: torch.Tensor, + indices: torch.Tensor, + weights: Optional[torch.Tensor], + permute: torch.LongTensor, + ) -> 
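The docstring of `test_block_bucketize_sparse_features_with_block_bucketize_pos` above states the variable-boundary rule in prose; the small sketch below restates it for a single index so the expected lengths and indices in the test are easy to check by hand. `bucketize_one` is a purely illustrative helper, and `bucket_boundaries` stands in for one per-batch row of `block_bucketize_pos`.

```python
import torch

def bucketize_one(index: int, bucket_boundaries: torch.Tensor):
    # bucket_boundaries, e.g. [0, 5, 15]: indices in [0, 5) go to bucket 0,
    # indices in [5, 15) go to bucket 1, and the new index is the original
    # index minus the lower bound of its bucket.
    pos = torch.searchsorted(bucket_boundaries, torch.tensor([index]), right=True)
    bucket = int(pos.item()) - 1
    new_index = index - int(bucket_boundaries[bucket])
    return bucket, new_index

# The docstring's example: boundaries [0, 5, 15], index 12 lands in bucket 1
# with new index 12 - 5 = 7.
assert bucketize_one(12, torch.tensor([0, 5, 15])) == (1, 7)
```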
Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + T = lengths.size(0) + B = lengths.size(1) + permuted_lengths = torch.index_select(lengths.view(T, B), 0, permute) + + original_segment_lengths = lengths.view(T, B).sum(dim=1, dtype=torch.int32) + original_segment_start = torch.ops.fbgemm.asynchronous_exclusive_cumsum( + original_segment_lengths.view(-1) + ) + + permuted_indices = [] + permuted_weights = [] + for i in range(permute.size(0)): + start = original_segment_start[permute[i]] + end = start + original_segment_lengths[permute[i]] + permuted_indices.append(indices[start:end]) + if weights is not None: + permuted_weights.append(weights[start:end]) + + permuted_indices = torch.cat(permuted_indices, dim=0).flatten() + + if weights is None: + permuted_weights = None + else: + permuted_weights = torch.cat(permuted_weights, dim=0).flatten() + + return permuted_lengths, permuted_indices, permuted_weights + + @given( + B=st.integers(min_value=1, max_value=20), + T=st.integers(min_value=1, max_value=20), + L=st.integers(min_value=2, max_value=20), + long_index=st.booleans(), + has_weight=st.booleans(), + ) + @settings(max_examples=20, deadline=None) + def test_permute_sparse_features( + self, B: int, T: int, L: int, long_index: bool, has_weight: bool + ) -> None: + index_dtype = torch.int64 if long_index else torch.int32 + lengths = torch.randint(low=1, high=L, size=(T, B)).type(index_dtype) + weights = torch.rand(int(lengths.sum().item())).float() if has_weight else None + indices = torch.randint( + low=1, + high=int(1e5), + size=cast(Tuple[int, ...], (lengths.sum().item(),)), + ).type(index_dtype) + permute_list = list(range(T)) + random.shuffle(permute_list) + permute = torch.IntTensor(permute_list) + + ( + permuted_lengths_cpu, + permuted_indices_cpu, + permuted_weights_cpu, + ) = torch.ops.fbgemm.permute_sparse_features(permute, lengths, indices, weights) + ( + permuted_lengths_ref, + permuted_indices_ref, + permuted_weights_ref, + # pyre-fixme[6]: For 4th param expected `LongTensor` but got `Tensor`. 
+ ) = self.permute_sparse_features_ref_(lengths, indices, weights, permute.long()) + torch.testing.assert_close(permuted_indices_cpu, permuted_indices_ref) + torch.testing.assert_close(permuted_lengths_cpu, permuted_lengths_ref) + if has_weight: + torch.testing.assert_close(permuted_weights_cpu, permuted_weights_ref) + else: + assert permuted_weights_cpu is None and permuted_weights_ref is None + + if gpu_available: + ( + permuted_lengths_gpu, + permuted_indices_gpu, + permuted_weights_gpu, + ) = torch.ops.fbgemm.permute_sparse_features( + permute.cuda(), + lengths.cuda(), + indices.cuda(), + weights.cuda() if has_weight and weights is not None else None, + ) + torch.testing.assert_close(permuted_indices_gpu.cpu(), permuted_indices_cpu) + torch.testing.assert_close(permuted_lengths_gpu.cpu(), permuted_lengths_cpu) + if has_weight: + torch.testing.assert_close( + permuted_weights_gpu.cpu(), permuted_weights_cpu + ) + else: + assert permuted_weights_gpu is None + + @given( + B=st.integers(min_value=1, max_value=20), + T=st.integers(min_value=1, max_value=20), + L=st.integers(min_value=2, max_value=20), + long_index=st.booleans(), + has_weight=st.booleans(), + ) + @settings(max_examples=20, deadline=None) + def test_permute_sparse_features_with_repeats( + self, B: int, T: int, L: int, long_index: bool, has_weight: bool + ) -> None: + index_dtype = torch.int64 if long_index else torch.int32 + lengths = torch.randint(low=1, high=L, size=(T, B)).type(index_dtype) + weights = torch.rand(int(lengths.sum().item())).float() if has_weight else None + indices = torch.randint( + low=1, + high=int(1e5), + size=cast(Tuple[int, ...], (lengths.sum().item(),)), + ).type(index_dtype) + permute_list = list(range(T)) + + num_repeats = random.randint(0, T) + for _ in range(num_repeats): + permute_list.append(random.randint(0, T - 1)) + + random.shuffle(permute_list) + permute = torch.IntTensor(permute_list) + + ( + permuted_lengths_cpu, + permuted_indices_cpu, + permuted_weights_cpu, + ) = torch.ops.fbgemm.permute_sparse_features(permute, lengths, indices, weights) + ( + permuted_lengths_ref, + permuted_indices_ref, + permuted_weights_ref, + # pyre-fixme[6]: For 4th param expected `LongTensor` but got `Tensor`.
+ ) = self.permute_sparse_features_ref_(lengths, indices, weights, permute.long()) + torch.testing.assert_close(permuted_indices_cpu, permuted_indices_ref) + torch.testing.assert_close(permuted_lengths_cpu, permuted_lengths_ref) + if has_weight: + torch.testing.assert_close(permuted_weights_cpu, permuted_weights_ref) + else: + assert permuted_weights_cpu is None and permuted_weights_ref is None + + if gpu_available: + ( + permuted_lengths_gpu, + permuted_indices_gpu, + permuted_weights_gpu, + ) = torch.ops.fbgemm.permute_sparse_features( + permute.cuda(), + lengths.cuda(), + indices.cuda(), + weights.cuda() if has_weight and weights is not None else None, + ) + torch.testing.assert_close(permuted_indices_gpu.cpu(), permuted_indices_cpu) + torch.testing.assert_close(permuted_lengths_gpu.cpu(), permuted_lengths_cpu) + if has_weight: + torch.testing.assert_close( + permuted_weights_gpu.cpu(), permuted_weights_cpu + ) + else: + assert permuted_weights_gpu is None + failures_dict_path: str = get_file_path_2( "", os.path.dirname(__file__), "failures_dict.json" @@ -2417,6 +2721,18 @@ def validate( "test_faketensor__test_index_select_dim0": [unittest.skip("hangs")], "test_autograd_registration__test_index_select_dim0": [unittest.skip("hangs")], "test_schema__test_index_select_dim0": [unittest.skip("hangs")], + "test_pt2_compliant_tag_fbgemm_dense_to_jagged": [ + # This operator has been grandfathered in. We need to fix this test failure. + unittest.expectedFailure, + ], + "test_pt2_compliant_tag_fbgemm_jagged_dense_elementwise_add": [ + # This operator has been grandfathered in. We need to fix this test failure. + unittest.expectedFailure, + ], + "test_pt2_compliant_tag_fbgemm_jagged_dense_elementwise_add_jagged_output": [ + # This operator has been grandfathered in. We need to fix this test failure. + unittest.expectedFailure, + ], } # only generate tests on nightly pytorch (current release version is 2.1) @@ -2436,7 +2752,6 @@ def validate( "test_schema", "test_autograd_registration", "test_faketensor", - "test_aot_dispatch_static", "test_aot_dispatch_dynamic", ], ) diff --git a/fbgemm_gpu/test/split_embedding_inference_converter_test.py b/fbgemm_gpu/test/split_embedding_inference_converter_test.py index 22ad6ad3ae..126552c2db 100644 --- a/fbgemm_gpu/test/split_embedding_inference_converter_test.py +++ b/fbgemm_gpu/test/split_embedding_inference_converter_test.py @@ -5,7 +5,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe import logging import math @@ -84,9 +83,13 @@ class SparseArch(nn.Module): def __init__( self, + # pyre-fixme[2]: Parameter must be annotated. emb_dim, + # pyre-fixme[2]: Parameter must be annotated. num_tables, + # pyre-fixme[2]: Parameter must be annotated. num_rows, + # pyre-fixme[2]: Parameter must be annotated. use_cpu, ) -> None: super().__init__() @@ -117,11 +120,16 @@ def __init__( -EMB_WEIGHT_UNIFORM_INIT_BOUND, +EMB_WEIGHT_UNIFORM_INIT_BOUND ) + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def forward(self, indices, offsets): return self.emb_module(indices, offsets) class QuantizedSplitEmbeddingsTest(unittest.TestCase): + # pyre-fixme[56]: Pyre was not able to infer the type of argument + # `hypothesis.strategies.integers($parameter$min_value = 1, $parameter$max_value = + # 10)` to decorator factory `hypothesis.given`.
@given( T=st.integers(min_value=1, max_value=10), D=st.integers(min_value=2, max_value=128), @@ -234,6 +242,9 @@ def test_quantize_workflow( ) @unittest.skipIf(*on_arm_platform) + # pyre-fixme[56]: Pyre was not able to infer the type of argument + # `hypothesis.strategies.booleans() if test_utils.gpu_available else + # hypothesis.strategies.just(True)` to decorator factory `hypothesis.given`. @given( use_cpu=st.booleans() if gpu_available else st.just(True), use_array_for_index_remapping=st.booleans(), @@ -319,6 +330,9 @@ def test_l2_norm_pruning_workflow( rtol=1.0e-1, ) + # pyre-fixme[56]: Pyre was not able to infer the type of argument + # `hypothesis.strategies.integers($parameter$min_value = 1, $parameter$max_value = + # 10)` to decorator factory `hypothesis.given`. @given( T=st.integers(min_value=1, max_value=10), D=st.integers(min_value=2, max_value=128), diff --git a/fbgemm_gpu/test/split_embeddings_utils_test.py b/fbgemm_gpu/test/split_embeddings_utils_test.py index 19bb14fff5..bbf0960e5d 100644 --- a/fbgemm_gpu/test/split_embeddings_utils_test.py +++ b/fbgemm_gpu/test/split_embeddings_utils_test.py @@ -13,15 +13,12 @@ import hypothesis.strategies as st import torch - +from fbgemm_gpu import sparse_ops # noqa: F401 from hypothesis import given, HealthCheck, settings try: - # pyre-ignore[21] - from fbgemm_gpu import open_source # noqa: F401 from test_utils import gpu_unavailable # pyre-ignore[21] except Exception: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") from fbgemm_gpu.test.test_utils import gpu_unavailable diff --git a/fbgemm_gpu/test/split_table_batched_embeddings_test.py b/fbgemm_gpu/test/split_table_batched_embeddings_test.py index 641b3c1e31..ce8e41e630 100644 --- a/fbgemm_gpu/test/split_table_batched_embeddings_test.py +++ b/fbgemm_gpu/test/split_table_batched_embeddings_test.py @@ -12,6 +12,7 @@ import pickle import random import unittest + from itertools import accumulate from typing import Any, Callable, Dict, List, Optional, Tuple, Union @@ -96,6 +97,9 @@ MAX_EXAMPLES_LONG_RUNNING = 15 +VERBOSITY: Verbosity = Verbosity.verbose + + settings.register_profile("derandomize", derandomize=True) settings.load_profile("derandomize") @@ -167,19 +171,22 @@ def format_ref_tensors_in_mixed_B_layout( "test_autograd_registration__test_backward_none_with_rowwise_adagrad": [ unittest.skip("Cannot access data pointer of Tensor that doesn't have storage") ], - "test_faketensor__test_cache_prefetch_pipeline_stream_2": [unittest.skip("OOM")], - "test_faketensor__test_cache_prefetch_pipeline": [unittest.skip("OOM")], - "test_faketensor__test_cache_prefetch_pipeline_stream_1": [ - unittest.skip("IMA on exit") + "test_faketensor__test_nbit_forward_uvm_cache": [ + unittest.skip("CUDA Assert"), ], - "test_faketensor__test_cache_pipeline": [ - unittest.skip("OOM when run serially"), + "test_faketensor__test_nbit_uvm_cache_stats": [ + unittest.skip("very slow"), + ], + "test_faketensor__test_nbit_direct_mapped_uvm_cache_stats": [ + unittest.skip("very slow"), ], } @optests.generate_opcheck_tests(fast=True, additional_decorators=additional_decorators) class SplitTableBatchedEmbeddingsTest(unittest.TestCase): + _do_cuda_memory_leak_check = True + def execute_forward_( # noqa C901 self, T: int, @@ -557,7 +564,7 @@ def test_forward_gpu_no_cache_int8( use_experimental_tbe=st.booleans() if not TEST_WITH_ROCM else st.just(False), ) @settings( - verbosity=Verbosity.verbose, + verbosity=VERBOSITY, max_examples=MAX_EXAMPLES_LONG_RUNNING, deadline=None, 
suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large], @@ -620,7 +627,7 @@ def test_forward_gpu_no_cache_fp16( use_experimental_tbe=st.booleans() if not TEST_WITH_ROCM else st.just(False), ) @settings( - verbosity=Verbosity.verbose, + verbosity=VERBOSITY, max_examples=MAX_EXAMPLES_LONG_RUNNING, deadline=None, suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large], @@ -683,7 +690,7 @@ def test_forward_gpu_no_cache_fp32( cache_algorithm=st.sampled_from(CacheAlgorithm), ) @settings( - verbosity=Verbosity.verbose, + verbosity=VERBOSITY, max_examples=MAX_EXAMPLES_LONG_RUNNING, deadline=None, suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large], @@ -748,7 +755,7 @@ def test_forward_gpu_uvm_cache_int8( use_experimental_tbe=st.booleans() if not TEST_WITH_ROCM else st.just(False), ) @settings( - verbosity=Verbosity.verbose, + verbosity=VERBOSITY, max_examples=MAX_EXAMPLES_LONG_RUNNING, deadline=None, suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large], @@ -814,7 +821,7 @@ def test_forward_gpu_uvm_cache_fp16( use_experimental_tbe=st.booleans() if not TEST_WITH_ROCM else st.just(False), ) @settings( - verbosity=Verbosity.verbose, + verbosity=VERBOSITY, max_examples=MAX_EXAMPLES_LONG_RUNNING, deadline=None, suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large], @@ -884,7 +891,7 @@ def test_forward_gpu_uvm_cache_fp32( output_dtype=st.sampled_from([SparseType.FP16, SparseType.INT8]), ) @settings( - verbosity=Verbosity.verbose, + verbosity=VERBOSITY, max_examples=MAX_EXAMPLES_LONG_RUNNING, deadline=None, suppress_health_check=[HealthCheck.filter_too_much], @@ -1019,7 +1026,7 @@ def test_forward_fused_pooled_emb_quant( ), ) @settings( - verbosity=Verbosity.verbose, + verbosity=VERBOSITY, max_examples=MAX_EXAMPLES_LONG_RUNNING, deadline=None, suppress_health_check=[HealthCheck.filter_too_much], @@ -1192,7 +1199,7 @@ def test_nbit_forward_fused_pooled_emb_quant( ), ) @settings( - verbosity=Verbosity.verbose, + verbosity=VERBOSITY, max_examples=MAX_EXAMPLES_LONG_RUNNING, deadline=None, suppress_health_check=[HealthCheck.filter_too_much], @@ -1305,7 +1312,7 @@ def test_nbit_split_embedding_weights_with_scale_and_bias( output_dtype=st.sampled_from([SparseType.FP32, SparseType.FP16]), ) @settings( - verbosity=Verbosity.verbose, + verbosity=VERBOSITY, max_examples=10, deadline=None, suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large], @@ -1555,7 +1562,7 @@ def test_backward_dense( # noqa C901 output_dtype=st.sampled_from([SparseType.FP16, SparseType.FP32]), ) @settings( - verbosity=Verbosity.verbose, + verbosity=VERBOSITY, max_examples=MAX_EXAMPLES, deadline=None, suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large], @@ -1582,7 +1589,7 @@ def test_backward_none(self, **kwargs: Any) -> None: output_dtype=st.sampled_from([SparseType.FP16, SparseType.FP32]), ) @settings( - verbosity=Verbosity.verbose, + verbosity=VERBOSITY, max_examples=MAX_EXAMPLES, deadline=None, suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large], @@ -2150,7 +2157,7 @@ def execute_backward_sgd_( # noqa C901 else st.just(True), ) @settings( - verbosity=Verbosity.verbose, + verbosity=VERBOSITY, max_examples=MAX_EXAMPLES, deadline=None, suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large], @@ -2204,7 +2211,7 @@ def test_backward_sgd( # noqa C901 cache_algorithm=st.sampled_from(CacheAlgorithm), ) @settings( - 
verbosity=Verbosity.verbose, + verbosity=VERBOSITY, max_examples=MAX_EXAMPLES_LONG_RUNNING, deadline=None, suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large], @@ -2648,7 +2655,7 @@ def execute_backward_adagrad_( # noqa C901 output_dtype=st.sampled_from([SparseType.FP32, SparseType.FP16]), ) @settings( - verbosity=Verbosity.verbose, + verbosity=VERBOSITY, max_examples=MAX_EXAMPLES_LONG_RUNNING, deadline=None, suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large], @@ -2718,7 +2725,7 @@ def test_backward_adagrad_fp16_pmSUM( # noqa C901 output_dtype=st.sampled_from([SparseType.FP32, SparseType.FP16]), ) @settings( - verbosity=Verbosity.verbose, + verbosity=VERBOSITY, max_examples=MAX_EXAMPLES_LONG_RUNNING, deadline=None, suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large], @@ -2787,7 +2794,7 @@ def test_backward_adagrad_fp16_pmMEAN( # noqa C901 output_dtype=st.sampled_from([SparseType.FP32, SparseType.FP16]), ) @settings( - verbosity=Verbosity.verbose, + verbosity=VERBOSITY, max_examples=MAX_EXAMPLES_LONG_RUNNING, deadline=None, suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large], @@ -2853,7 +2860,7 @@ def test_backward_adagrad_fp16_pmNONE( # noqa C901 output_dtype=st.sampled_from([SparseType.FP32, SparseType.FP16]), ) @settings( - verbosity=Verbosity.verbose, + verbosity=VERBOSITY, max_examples=MAX_EXAMPLES_LONG_RUNNING, deadline=None, suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large], @@ -2923,7 +2930,7 @@ def test_backward_adagrad_fp32_pmSUM( # noqa C901 output_dtype=st.sampled_from([SparseType.FP32, SparseType.FP16]), ) @settings( - verbosity=Verbosity.verbose, + verbosity=VERBOSITY, max_examples=MAX_EXAMPLES_LONG_RUNNING, deadline=None, suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large], @@ -2992,7 +2999,7 @@ def test_backward_adagrad_fp32_pmMEAN( # noqa C901 output_dtype=st.sampled_from([SparseType.FP32, SparseType.FP16]), ) @settings( - verbosity=Verbosity.verbose, + verbosity=VERBOSITY, max_examples=MAX_EXAMPLES_LONG_RUNNING, deadline=None, suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large], @@ -3116,6 +3123,7 @@ def _generate_cache_tbes( return (cc, cc_ref, min(Es), sum(Ds)) + @optests.dontGenerateOpCheckTests("Serial OOM") @unittest.skipIf(*gpu_unavailable) @given( T=st.integers(min_value=1, max_value=5), @@ -3126,7 +3134,7 @@ def _generate_cache_tbes( mixed=st.booleans(), cache_algorithm=st.sampled_from(CacheAlgorithm), ) - @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) + @settings(verbosity=VERBOSITY, max_examples=MAX_EXAMPLES, deadline=None) def test_cache_pipeline( self, T: int, @@ -3242,6 +3250,7 @@ def _prefetch( torch.testing.assert_close(output, output_ref) self.assertTrue(torch.all(cc.lxu_cache_locking_counter == 0)) + @optests.dontGenerateOpCheckTests("Serial OOM") @unittest.skipIf(*gpu_unavailable) @given( T=st.integers(min_value=1, max_value=5), @@ -3252,7 +3261,7 @@ def _prefetch( mixed=st.booleans(), prefetch_location=st.sampled_from(["before_fwd", "between_fwd_bwd"]), ) - @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) + @settings(verbosity=VERBOSITY, max_examples=MAX_EXAMPLES, deadline=None) def test_cache_prefetch_pipeline( self, T: int, @@ -3274,6 +3283,7 @@ def test_cache_prefetch_pipeline( prefetch_stream=None, ) + @optests.dontGenerateOpCheckTests("Serial OOM") @unittest.skipIf(*gpu_unavailable) 
@given( T=st.integers(min_value=1, max_value=5), @@ -3283,7 +3293,7 @@ def test_cache_prefetch_pipeline( L=st.integers(min_value=1, max_value=20), mixed=st.booleans(), ) - @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) + @settings(verbosity=VERBOSITY, max_examples=MAX_EXAMPLES, deadline=None) def test_cache_prefetch_pipeline_stream_1( self, T: int, @@ -3304,6 +3314,7 @@ def test_cache_prefetch_pipeline_stream_1( prefetch_stream=torch.cuda.Stream(), ) + @optests.dontGenerateOpCheckTests("Serial OOM") @unittest.skipIf(*gpu_unavailable) @given( T=st.integers(min_value=1, max_value=5), @@ -3313,7 +3324,7 @@ def test_cache_prefetch_pipeline_stream_1( L=st.integers(min_value=1, max_value=20), mixed=st.booleans(), ) - @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) + @settings(verbosity=VERBOSITY, max_examples=MAX_EXAMPLES, deadline=None) def test_cache_prefetch_pipeline_stream_2( self, T: int, @@ -3985,7 +3996,7 @@ def get_wts_from_counter_adagrad( uvm_non_rowwise_momentum=st.booleans(), ) @settings( - verbosity=Verbosity.verbose, + verbosity=VERBOSITY, max_examples=MAX_EXAMPLES_LONG_RUNNING, deadline=None, suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large], @@ -4065,7 +4076,7 @@ def test_backward_optimizers_adam( # noqa C901 ), ) @settings( - verbosity=Verbosity.verbose, + verbosity=VERBOSITY, max_examples=MAX_EXAMPLES_LONG_RUNNING, deadline=None, suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large], @@ -4137,7 +4148,7 @@ def test_backward_optimizers_adagrad( # noqa C901 else st.just(True), ) @settings( - verbosity=Verbosity.verbose, + verbosity=VERBOSITY, max_examples=MAX_EXAMPLES_LONG_RUNNING, deadline=None, suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large], @@ -4196,7 +4207,7 @@ def test_backward_optimizers_lamb( # noqa C901 else st.just(True), ) @settings( - verbosity=Verbosity.verbose, + verbosity=VERBOSITY, max_examples=MAX_EXAMPLES_LONG_RUNNING, deadline=None, suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large], @@ -4371,14 +4382,9 @@ def execute_nbit_forward_( # noqa C901 # Initialize and insert Array index remapping based data structure index_remappings_array = [] for t in range(T): - # pyre-fixme[6]: For 1st param expected `dtype` but got `Union[int, - # str]`. indice_t = (indices.view(T, B, L))[t].long().view(-1).to(current_device) dense_indice_t = ( - (dense_indices.view(T, B, L))[t].view(-1) - # pyre-fixme[6]: For 1st param expected `dtype` but got `Union[int, - # str]`. 
- .to(current_device) + (dense_indices.view(T, B, L))[t].view(-1).to(current_device) ) index_remappings_array_t = torch.tensor( [-1] * original_E, @@ -4522,7 +4528,7 @@ def execute_nbit_forward_( # noqa C901 do_pruning=st.booleans(), ) @settings( - verbosity=Verbosity.verbose, + verbosity=VERBOSITY, max_examples=MAX_EXAMPLES_LONG_RUNNING, deadline=None, ) @@ -4609,7 +4615,7 @@ def test_nbit_forward_cpu( do_pruning=st.booleans(), ) @settings( - verbosity=Verbosity.verbose, + verbosity=VERBOSITY, max_examples=MAX_EXAMPLES_LONG_RUNNING, deadline=None, ) @@ -4699,7 +4705,7 @@ def test_nbit_forward_gpu_no_cache_fp8_2048(self) -> None: do_pruning=st.booleans(), ) @settings( - verbosity=Verbosity.verbose, + verbosity=VERBOSITY, max_examples=MAX_EXAMPLES_LONG_RUNNING, deadline=None, ) @@ -4778,7 +4784,7 @@ def test_nbit_forward_gpu_no_cache( ), emulate_pruning=st.booleans(), ) - @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) + @settings(verbosity=VERBOSITY, max_examples=MAX_EXAMPLES, deadline=None) def test_int_nbit_split_embedding_uvm_caching_codegen_lookup_function( self, weights_ty: SparseType, @@ -4963,7 +4969,7 @@ def test_int_nbit_split_embedding_uvm_caching_codegen_lookup_function( do_pruning=st.booleans(), use_array_for_index_remapping=st.booleans(), ) - @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) + @settings(verbosity=VERBOSITY, max_examples=MAX_EXAMPLES, deadline=None) def test_nbit_forward_uvm_cache( self, weights_ty: SparseType, @@ -5080,7 +5086,7 @@ def test_nbit_forward_uvm_cache( use_cpu_hashtable=st.booleans(), use_array_for_index_remapping=st.booleans(), ) - @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) + @settings(verbosity=VERBOSITY, max_examples=MAX_EXAMPLES, deadline=None) def test_pruning( self, T: int, @@ -5144,19 +5150,13 @@ def test_pruning( index_remappings_array_offsets = torch.empty( T + 1, dtype=torch.int64, - # pyre-fixme[6]: For 3rd param expected `Union[None, str, device]` but - # got `Union[int, str]`. device=current_device, ) index_remappings_array_offsets[0] = 0 for t in range(T): - # pyre-fixme[6]: For 1st param expected `dtype` but got `Union[int, str]`. indice_t = (indices.view(T, B, L))[t].long().view(-1).to(current_device) dense_indice_t = ( - (dense_indices.view(T, B, L))[t].view(-1) - # pyre-fixme[6]: For 1st param expected `dtype` but got `Union[int, - # str]`. - .to(current_device) + (dense_indices.view(T, B, L))[t].view(-1).to(current_device) ) selected_indices = torch.add(indice_t, t * original_E)[:E] index_remappings_array[selected_indices] = dense_indice_t @@ -5175,26 +5175,12 @@ def test_pruning( index_remappings_array, index_remappings_array_offsets, ) = ( - # pyre-fixme[6]: For 1st param expected `dtype` but got `Union[int, - # str]`. indices.to(current_device), - # pyre-fixme[6]: For 1st param expected `dtype` but got `Union[int, - # str]`. dense_indices.to(current_device), - # pyre-fixme[6]: For 1st param expected `dtype` but got `Union[int, - # str]`. offsets.to(current_device), - # pyre-fixme[6]: For 1st param expected `dtype` but got `Union[int, - # str]`. hash_table.to(current_device), - # pyre-fixme[6]: For 1st param expected `dtype` but got `Union[int, - # str]`. hash_table_offsets.to(current_device), - # pyre-fixme[6]: For 1st param expected `dtype` but got `Union[int, - # str]`. index_remappings_array.to(current_device), - # pyre-fixme[6]: For 1st param expected `dtype` but got `Union[int, - # str]`. 
index_remappings_array_offsets.to(current_device), ) @@ -5244,7 +5230,7 @@ def test_pruning( H=st.integers(min_value=512, max_value=1024), S=st.integers(min_value=0, max_value=128), ) - @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) + @settings(verbosity=VERBOSITY, max_examples=MAX_EXAMPLES, deadline=None) def test_cache_update_function(self, L: int, H: int, S: int) -> None: # Generate synthetic data linear_cache_indices_cpu = torch.randint(L, H, (S,)) @@ -5295,7 +5281,7 @@ def test_cache_update_function(self, L: int, H: int, S: int) -> None: self.assertLessEqual(cache_miss_forward_count, unique_cache_miss_count) @given(N=st.integers(min_value=1, max_value=8)) - @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) + @settings(verbosity=VERBOSITY, max_examples=MAX_EXAMPLES, deadline=None) def test_cache_miss_counter(self, N: int) -> None: # Create an abstract split table D = 8 @@ -5351,7 +5337,7 @@ def test_cache_miss_counter(self, N: int) -> None: self.assertEqual(tablewise_cache_miss[i], t_tablewise_cache_miss[i]) @given(N=st.integers(min_value=1, max_value=2)) - @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) + @settings(verbosity=VERBOSITY, max_examples=MAX_EXAMPLES, deadline=None) def test_stb_uvm_cache_stats(self, N: int) -> None: # Create an abstract split table D = 8 @@ -5402,7 +5388,7 @@ def test_stb_uvm_cache_stats(self, N: int) -> None: H=st.integers(min_value=512, max_value=1024), S=st.integers(min_value=0, max_value=128), ) - @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) + @settings(verbosity=VERBOSITY, max_examples=MAX_EXAMPLES, deadline=None) def test_nbit_cache_update_function(self, L: int, H: int, S: int) -> None: # Generate synthetic data linear_cache_indices_cpu = torch.randint(L, H, (S,)) @@ -5459,7 +5445,7 @@ def test_nbit_cache_update_function(self, L: int, H: int, S: int) -> None: @unittest.skipIf(*gpu_unavailable) @given(N=st.integers(min_value=1, max_value=8)) - @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) + @settings(verbosity=VERBOSITY, max_examples=MAX_EXAMPLES, deadline=None) def test_nbit_cache_miss_counter(self, N: int) -> None: # Create an abstract split table D = 8 @@ -5514,7 +5500,7 @@ def test_nbit_cache_miss_counter(self, N: int) -> None: N=st.integers(min_value=1, max_value=8), dtype=st.sampled_from([SparseType.INT8, SparseType.INT4, SparseType.INT2]), ) - @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) + @settings(verbosity=VERBOSITY, max_examples=MAX_EXAMPLES, deadline=None) def test_nbit_uvm_cache_stats(self, N: int, dtype: SparseType) -> None: # Create an abstract split table D = 8 @@ -5629,7 +5615,7 @@ def test_nbit_uvm_cache_stats(self, N: int, dtype: SparseType) -> None: N=st.integers(min_value=1, max_value=8), dtype=st.sampled_from([SparseType.INT8, SparseType.INT4, SparseType.INT2]), ) - @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) + @settings(verbosity=VERBOSITY, max_examples=MAX_EXAMPLES, deadline=None) def test_nbit_direct_mapped_uvm_cache_stats( self, N: int, dtype: SparseType ) -> None: @@ -5772,7 +5758,7 @@ def test_nbit_direct_mapped_uvm_cache_stats( ), mixed_B=st.booleans(), ) - @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) + @settings(verbosity=VERBOSITY, max_examples=MAX_EXAMPLES, deadline=None) def test_bounds_check( # noqa C901 self, T: int, @@ -6216,7 +6202,7 @@ def 
test_lxu_cache_lookup(self, associativity: int) -> None: @given( cache_sets=st.integers(min_value=10, max_value=300), ) - @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) + @settings(verbosity=VERBOSITY, max_examples=MAX_EXAMPLES, deadline=None) def test_lxu_cache_locking_counter_decrement( self, cache_sets: int, @@ -6281,7 +6267,7 @@ def test_lxu_cache_locking_counter_decrement( else st.just(True), test_internal=st.booleans(), ) - @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) + @settings(verbosity=VERBOSITY, max_examples=MAX_EXAMPLES, deadline=None) def test_embedding_inplace_update( self, T: int, # num of embedding tables @@ -6439,7 +6425,7 @@ def test_embedding_inplace_update( num_indices_per_table=st.integers(min_value=1, max_value=5), ) @settings( - verbosity=Verbosity.verbose, + verbosity=VERBOSITY, max_examples=MAX_EXAMPLES, deadline=None, ) @@ -6576,6 +6562,164 @@ def check_weight_momentum(v: int) -> None: check_weight_momentum(0) + @unittest.skipIf(*gpu_unavailable) + @given( + T=st.integers(min_value=1, max_value=10), + D=st.integers(min_value=2, max_value=128), + B=st.integers(min_value=1, max_value=128), + log_E=st.integers(min_value=3, max_value=5), + L=st.integers(min_value=0, max_value=20), + ) + @settings(verbosity=VERBOSITY, max_examples=MAX_EXAMPLES, deadline=None) + def test_unique_lxu_cache_lookup( + self, + T: int, + D: int, + B: int, + log_E: int, + L: int, + ) -> None: + E = int(10**log_E) + + indices = to_device( + torch.randint(low=0, high=E, size=(T * L * B,)), + use_cpu=False, + ).long() + offsets = to_device( + torch.tensor([0] + list(accumulate([L] * (T * B)))), + use_cpu=False, + ).long() + + def unique_lookup( + indices: Tensor, + offsets: Tensor, + cache_hash_size_cumsum: Tensor, + total_cache_hash_size: int, + ) -> Tuple[Tensor, Tensor]: + linear_cache_indices = torch.ops.fbgemm.linearize_cache_indices( + cache_hash_size_cumsum, + indices, + offsets, + ) + + uniq_indices, uniq_indices_length, _ = torch.ops.fbgemm.get_unique_indices( + linear_cache_indices, total_cache_hash_size, compute_count=False + ) + + uniq_lxu_cache_locations = torch.ops.fbgemm.lxu_cache_lookup( + uniq_indices, + lxu_cache_state, + total_cache_hash_size, + gather_cache_stats=False, + num_uniq_cache_indices=uniq_indices_length, + ) + + return uniq_lxu_cache_locations, uniq_indices_length + + def duplicate_lookup( + indices: Tensor, + offsets: Tensor, + cache_hash_size_cumsum: Tensor, + total_cache_hash_size: int, + ) -> Tensor: + linear_cache_indices = torch.ops.fbgemm.linearize_cache_indices( + cache_hash_size_cumsum, + indices, + offsets, + ) + + lxu_cache_locations = torch.ops.fbgemm.lxu_cache_lookup( + linear_cache_indices, + lxu_cache_state, + total_cache_hash_size, + gather_cache_stats=False, + ) + return lxu_cache_locations + + cache_sets = int((E * T) * 0.2) + lxu_cache_state = torch.zeros( + cache_sets, + DEFAULT_ASSOC, + device="cuda", + dtype=torch.int64, + ).fill_(-1) + + hash_sizes = torch.tensor([E] * T, dtype=torch.long, device="cuda") + cache_hash_size_cumsum = torch.ops.fbgemm.asynchronous_complete_cumsum( + hash_sizes + ) + total_cache_hash_size = cache_hash_size_cumsum[-1].item() + + linear_cache_indices = torch.ops.fbgemm.linearize_cache_indices( + cache_hash_size_cumsum, + indices, + offsets, + ) + + # Emulate cache population + uniq_indices_cpu = linear_cache_indices.unique().cpu() + index_cache_set_map = uniq_indices_cpu.clone() + index_cache_set_map.apply_( + lambda x:
torch.ops.fbgemm.lxu_cache_slot(x, cache_sets) + ) + index_cache_set_map = index_cache_set_map.tolist() + uniq_indices_cpu = uniq_indices_cpu.tolist() + + slots = {} + for idx, c in zip(uniq_indices_cpu, index_cache_set_map): + if c not in slots: + slots[c] = 0 + slot = slots[c] + if slot < DEFAULT_ASSOC: + lxu_cache_state[c][slot] = idx + slots[c] = slot + 1 + + # Run unique lookup + uniq_lookup_output, uniq_indices_length = unique_lookup( + indices, offsets, cache_hash_size_cumsum, total_cache_hash_size + ) + + # Run duplicate lookup + duplicate_lookup_output = duplicate_lookup( + indices, offsets, cache_hash_size_cumsum, total_cache_hash_size + ) + + # Start running validation + + # Compute unique indices using PyTorch ops + sorted_linear_cache_indices, inverse_sorted_cache_indices = torch.sort( + linear_cache_indices + ) + ref_uniq_cache_indices, cache_indices_counts = torch.unique_consecutive( + sorted_linear_cache_indices, return_inverse=False, return_counts=True + ) + + # Convert to lists + cache_indices_counts = cache_indices_counts.cpu().tolist() + uniq_lookup_output = uniq_lookup_output.cpu().tolist() + + # Validate the number of unique cache indices + ref_num_uniq_indices = ref_uniq_cache_indices.numel() + assert ref_num_uniq_indices == uniq_indices_length.item() + + # Expand + reshaped_uniq_lookup_output = uniq_lookup_output[:ref_num_uniq_indices] + sorted_lxu_cache_locations = to_device( + torch.tensor( + np.repeat(reshaped_uniq_lookup_output, cache_indices_counts), + dtype=duplicate_lookup_output.dtype, + ), + use_cpu=False, + ) + + _, cache_location_indices = torch.sort(inverse_sorted_cache_indices) + + expanded_lxu_cache_locations = torch.index_select( + sorted_lxu_cache_locations, 0, cache_location_indices + ) + + assert torch.equal(expanded_lxu_cache_locations, duplicate_lookup_output) + if __name__ == "__main__": unittest.main() diff --git a/fbgemm_gpu/test/test_utils.py b/fbgemm_gpu/test/test_utils.py index 7f3ad9ae38..cd5dca4105 100644 --- a/fbgemm_gpu/test/test_utils.py +++ b/fbgemm_gpu/test/test_utils.py @@ -18,6 +18,11 @@ import hypothesis.strategies as st import numpy as np import torch +from hypothesis import settings + + +settings.register_profile("derandomize", derandomize=True) +settings.load_profile("derandomize") TEST_WITH_ROCM: bool = os.getenv("FBGEMM_TEST_WITH_ROCM", "0") == "1" diff --git a/src/EmbeddingSpMDM.cc b/src/EmbeddingSpMDM.cc index 144f8c0730..8aedf284dd 100644 --- a/src/EmbeddingSpMDM.cc +++ b/src/EmbeddingSpMDM.cc @@ -1018,9 +1018,9 @@ typename EmbeddingSpMDMKernelSignature:: Type GenerateEmbeddingSpMDMWithStrides( const int64_t block_size, - bool has_weight, + [[maybe_unused]] bool has_weight, bool normalize_by_lengths, - int prefetch, + [[maybe_unused]] int prefetch, bool is_weight_positional, bool use_offsets, int64_t output_stride /*=-1*/, diff --git a/src/QuantUtils.cc b/src/QuantUtils.cc index da20d71ede..7e2fb37264 100644 --- a/src/QuantUtils.cc +++ b/src/QuantUtils.cc @@ -268,7 +268,7 @@ FBGEMM_SPECIALIZED_QUANTIZE(uint8_t, false) const TensorQuantizationParams& qparams, \ int thread_id, \ int num_threads, \ - float noise_ratio) { \ + [[maybe_unused]] float noise_ratio) { \ int64_t i_begin, i_end; \ fbgemmPartition1D(thread_id, num_threads, len, i_begin, i_end); \ for (int64_t i = i_begin; i < i_end; ++i) { \