From 7a75492afdf88cd521287105f7e759492a1fa96f Mon Sep 17 00:00:00 2001 From: Sarunya Pumma Date: Thu, 26 Sep 2024 20:37:27 -0700 Subject: [PATCH 01/48] Add quantization op docstring (#3177) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/273 As title Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3177 Test Plan: See example https://deploy-preview-3177--pytorch-fbgemm-docs.netlify.app/fbgemm_gpu-python-api/quantize_ops Reviewed By: shintaro-iwasaki Differential Revision: D63445241 Pulled By: sryap fbshipit-source-id: 019b2dedfa0f31c487974fb91271d82c661520cc --- .../fbgemm_gpu-python-api/quantize_ops.rst | 6 +++ fbgemm_gpu/docs/src/index.rst | 1 + fbgemm_gpu/fbgemm_gpu/docs/__init__.py | 1 + fbgemm_gpu/fbgemm_gpu/docs/quantize_ops.py | 41 +++++++++++++++++++ 4 files changed, 49 insertions(+) create mode 100644 fbgemm_gpu/docs/src/fbgemm_gpu-python-api/quantize_ops.rst create mode 100644 fbgemm_gpu/fbgemm_gpu/docs/quantize_ops.py diff --git a/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/quantize_ops.rst b/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/quantize_ops.rst new file mode 100644 index 000000000..df2a6c2d7 --- /dev/null +++ b/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/quantize_ops.rst @@ -0,0 +1,6 @@ +Quantization Operators +====================== + +.. automodule:: fbgemm_gpu + +.. autofunction:: torch.ops.fbgemm.FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf diff --git a/fbgemm_gpu/docs/src/index.rst b/fbgemm_gpu/docs/src/index.rst index ba0d8ba6b..1669bf22f 100644 --- a/fbgemm_gpu/docs/src/index.rst +++ b/fbgemm_gpu/docs/src/index.rst @@ -91,6 +91,7 @@ Table of Contents fbgemm_gpu-python-api/jagged_tensor_ops.rst fbgemm_gpu-python-api/pooled_embedding_ops.rst + fbgemm_gpu-python-api/quantize_ops.rst .. _fbgemm-gpu.toc.api.python.modules: diff --git a/fbgemm_gpu/fbgemm_gpu/docs/__init__.py b/fbgemm_gpu/fbgemm_gpu/docs/__init__.py index 4b621cbe3..e531e1254 100644 --- a/fbgemm_gpu/fbgemm_gpu/docs/__init__.py +++ b/fbgemm_gpu/fbgemm_gpu/docs/__init__.py @@ -11,6 +11,7 @@ jagged_tensor_ops, merge_pooled_embedding_ops, permute_pooled_embedding_ops, + quantize_ops, ) except Exception: pass diff --git a/fbgemm_gpu/fbgemm_gpu/docs/quantize_ops.py b/fbgemm_gpu/fbgemm_gpu/docs/quantize_ops.py new file mode 100644 index 000000000..3662b12c7 --- /dev/null +++ b/fbgemm_gpu/fbgemm_gpu/docs/quantize_ops.py @@ -0,0 +1,41 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from .common import add_docs + +add_docs( + torch.ops.fbgemm.FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf, + """ +FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf(input, bit_rate) -> Tensor + +Convert FP32/16 to INT8/4/2 using rowwise quantization. + +Args: + input (Tensor): An input tensor. Must be either FP32 (`torch.float`) + or FP16 (`torch.half`) and must be 2 dimensions. + + bit_rate (int): Quantized bit rate (2 for INT2, 4 for INT4, or 8 for + INT8) + +Returns: + Quantized output (Tensor). 
Data type is `torch.uint8` (byte type) + +**Example:** + + >>> # Randomize input + >>> input = torch.randn(2, 4, dtype=torch.float32, device="cuda") + >>> print(input) + tensor([[ 0.8247, 0.0031, -1.0068, -1.2081], + [ 0.5427, 1.5772, 1.0291, -0.7626]], device='cuda:0') + >>> # Quantize + >>> output = torch.ops.fbgemm.FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf(input, bit_rate=4) + >>> print(output) + tensor([[159, 1, 86, 48, 213, 188], + [248, 11, 254, 48, 26, 186]], device='cuda:0', dtype=torch.uint8) + """, +) From 642b895e678bce2b593af89af57e0986306d1698 Mon Sep 17 00:00:00 2001 From: Sarunya Pumma Date: Thu, 26 Sep 2024 20:37:27 -0700 Subject: [PATCH 02/48] Add permute sparse data docstrings (#3178) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/274 As title Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3178 Test Plan: See example https://deploy-preview-3178--pytorch-fbgemm-docs.netlify.app/fbgemm_gpu-python-api/sparse_ops Reviewed By: shintaro-iwasaki Differential Revision: D63458583 Pulled By: sryap fbshipit-source-id: 3beb73e65e242c103428000ff26185335194035b --- .../src/fbgemm_gpu-python-api/sparse_ops.rst | 8 ++ fbgemm_gpu/docs/src/index.rst | 3 +- fbgemm_gpu/fbgemm_gpu/docs/__init__.py | 1 + fbgemm_gpu/fbgemm_gpu/docs/sparse_ops.py | 121 ++++++++++++++++++ 4 files changed, 132 insertions(+), 1 deletion(-) create mode 100644 fbgemm_gpu/docs/src/fbgemm_gpu-python-api/sparse_ops.rst create mode 100644 fbgemm_gpu/fbgemm_gpu/docs/sparse_ops.py diff --git a/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/sparse_ops.rst b/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/sparse_ops.rst new file mode 100644 index 000000000..e22812586 --- /dev/null +++ b/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/sparse_ops.rst @@ -0,0 +1,8 @@ +Sparse Operators +================ + +.. automodule:: fbgemm_gpu + +.. autofunction:: torch.ops.fbgemm.permute_2D_sparse_data + +.. autofunction:: torch.ops.fbgemm.permute_1D_sparse_data diff --git a/fbgemm_gpu/docs/src/index.rst b/fbgemm_gpu/docs/src/index.rst index 1669bf22f..f5b4cfb07 100644 --- a/fbgemm_gpu/docs/src/index.rst +++ b/fbgemm_gpu/docs/src/index.rst @@ -89,9 +89,10 @@ Table of Contents :maxdepth: 1 :caption: FBGEMM_GPU Python Operators API - fbgemm_gpu-python-api/jagged_tensor_ops.rst + fbgemm_gpu-python-api/sparse_ops.rst fbgemm_gpu-python-api/pooled_embedding_ops.rst fbgemm_gpu-python-api/quantize_ops.rst + fbgemm_gpu-python-api/jagged_tensor_ops.rst .. _fbgemm-gpu.toc.api.python.modules: diff --git a/fbgemm_gpu/fbgemm_gpu/docs/__init__.py b/fbgemm_gpu/fbgemm_gpu/docs/__init__.py index e531e1254..8d696532a 100644 --- a/fbgemm_gpu/fbgemm_gpu/docs/__init__.py +++ b/fbgemm_gpu/fbgemm_gpu/docs/__init__.py @@ -12,6 +12,7 @@ merge_pooled_embedding_ops, permute_pooled_embedding_ops, quantize_ops, + sparse_ops, ) except Exception: pass diff --git a/fbgemm_gpu/fbgemm_gpu/docs/sparse_ops.py b/fbgemm_gpu/fbgemm_gpu/docs/sparse_ops.py new file mode 100644 index 000000000..c95588207 --- /dev/null +++ b/fbgemm_gpu/fbgemm_gpu/docs/sparse_ops.py @@ -0,0 +1,121 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import torch + +from .common import add_docs + +add_docs( + torch.ops.fbgemm.permute_2D_sparse_data, + """ +permute_2D_sparse_data(permute, lengths, values, weights=None, permuted_lengths_sum=None) -> Tuple[Tensor, Tensor, Optional[Tensor]] + +Permute 2D sparse data along the first dimension (dim 0). Note that 2D +refers to the number of dense dimensions. The input data is actually 3D +where the first two dimensions are dense and the last dimension is +jagged (sparse). The data to permute over can be less or more and with or +without repetitions. + +Args: + permute (Tensor): A 1D-tensor that describes how data is permuted along dim + 0. `permute[i]` indicates that data at position `permute[i]` is moved + to position `i`. The length of this tensor is the total amount of data + in dim 0 to be permuted. The values in `permute` must be >= 0 and < + `lengths.shape[0]` + + lengths (Tensor): A 2D-tensor that contains jagged shapes corresponding to + the other two dense dimensions. For example, in the case of the + embedding input, the 3D shape is (num features, batch size, bag size). + `lengths[t][b]` represents the bag size of feature `t` and sample `b`. + + values (Tensor): A 1D-input-tensor to be permuted. The length of this + tensor must be equal to `lengths.sum()`. This tensor can be of any data + type. + + weights (Optional[Tensor] = None): An optional 1D-float-tensor. It must + have the same length as `values`. It will be permuted the same way as + values + + permuted_lengths_sum (Optional[int] = None): An optional value that + represents the total number of elements in the permuted data (output + shape). If not provided, the operator will compute this data which may + cause a device-host synchronization (if using GPU). Thus, it is + recommended to supply this value to avoid such the synchronization. + +Returns: + A tuple of permuted lengths, permuted indices and permuted weights + +**Example:** + + >>> permute = torch.tensor([1, 0, 2], dtype=torch.int32, device="cuda") + >>> lengths = torch.tensor([[2, 3, 4, 5], [1, 2, 4, 8], [0, 3, 2, 3]], dtype=torch.int64, device="cuda") + >>> values = torch.randint(low=0, high=100, size=(lengths.sum().item(),), dtype=torch.int64, device="cuda") + >>> print(values) + tensor([29, 12, 61, 98, 56, 94, 5, 89, 65, 48, 71, 54, 40, 33, 78, 68, 42, 21, + 60, 51, 15, 47, 48, 68, 52, 19, 38, 30, 38, 97, 97, 98, 18, 40, 42, 89, + 66], device='cuda:0') + >>> torch.ops.fbgemm.permute_2D_sparse_data(permute, lengths, values) + (tensor([[1, 2, 4, 8], + [2, 3, 4, 5], + [0, 3, 2, 3]], device='cuda:0'), + tensor([78, 68, 42, 21, 60, 51, 15, 47, 48, 68, 52, 19, 38, 30, 38, 29, 12, 61, + 98, 56, 94, 5, 89, 65, 48, 71, 54, 40, 33, 97, 97, 98, 18, 40, 42, 89, + 66], device='cuda:0'), + None) + """, +) + +add_docs( + torch.ops.fbgemm.permute_1D_sparse_data, + """ +permute_1D_sparse_data(permute, lengths, values, weights=None, permuted_lengths_sum=None) -> Tuple[Tensor, Tensor, Optional[Tensor]] + +Permute 1D sparse data. Note that 1D referrs to the number of dense dimensions. +The input data is actually 2D where the first dimension is dense and the second +dimension is jagged (sparse). The data to permute over can be less or more and +withh or without repetitions. + +Args: + permute (Tensor): A 1D-tensor that describes how data is permuted along dim + 0. `permute[i]` indicates that data at position `permute[i]` is moved + to position `i`. The length of this tensor is the total amount of data + in dim 0 to be permuted. 
The values in `permute` must be >= 0 and < + `lengths.numel()` + + lengths (Tensor): A 1D-tensor that contains jagged shapes corresponding to + the other dense dimension. `lengths[i]` represents the jagged shape of + data at position `i` in dim 0 + + values (Tensor): A 1D-input-tensor to be permuted. The length of this + tensor must be equal to `lengths.sum()`. This tensor can be of any data + type. + + weights (Optional[Tensor] = None): An optional 1D-float-tensor. It must + have the same length as `values`. It will be permuted the same way as + values + + permuted_lengths_sum (Optional[int] = None): An optional value that + represents the total number of elements in the permuted data (output + shape). If not provided, the operator will compute this data which may + cause a device-host synchronization (if using GPU). Thus, it is + recommended to supply this value to avoid such the synchronization. + +Returns: + A tuple of permuted lengths, permuted indices and permuted weights + +**Example:** + >>> permute = torch.tensor([1, 0, 3, 0], dtype=torch.int32, device="cuda") + >>> lengths = torch.tensor([2, 3, 4, 5], dtype=torch.int64, device="cuda") + >>> values = torch.randint(low=0, high=100, size=(lengths.sum().item(),), dtype=torch.int64, device="cuda") + >>> print(values) + tensor([ 1, 76, 24, 84, 94, 25, 15, 23, 31, 46, 9, 23, 34, 3], + device='cuda:0') + >>> torch.ops.fbgemm.permute_1D_sparse_data(permute, lengths, values) + (tensor([3, 2, 5, 2], device='cuda:0'), + tensor([24, 84, 94, 1, 76, 46, 9, 23, 34, 3, 1, 76], device='cuda:0'), + None) + """, +) From a9b7ae8a3b1fdc4f89410df02fac0a4892eaa08a Mon Sep 17 00:00:00 2001 From: Sarunya Pumma Date: Thu, 26 Sep 2024 20:37:27 -0700 Subject: [PATCH 03/48] Add docstrings for sparse ops (#3179) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/275 Add a docstring for - torch.ops.fbgemm.expand_into_jagged_permute - torch.ops.fbgemm.asynchronous_complete_cumsum - torch.ops.fbgemm.offsets_range Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3179 Test Plan: See example https://deploy-preview-3179--pytorch-fbgemm-docs.netlify.app/fbgemm_gpu-python-api/sparse_ops Reviewed By: shintaro-iwasaki Differential Revision: D63464136 Pulled By: sryap fbshipit-source-id: 5f566993c030b9c67998cf4c7a8cfd41fdee1a97 --- .../src/fbgemm_gpu-python-api/sparse_ops.rst | 6 ++ fbgemm_gpu/fbgemm_gpu/docs/sparse_ops.py | 85 +++++++++++++++++++ 2 files changed, 91 insertions(+) diff --git a/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/sparse_ops.rst b/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/sparse_ops.rst index e22812586..afc38a450 100644 --- a/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/sparse_ops.rst +++ b/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/sparse_ops.rst @@ -6,3 +6,9 @@ Sparse Operators .. autofunction:: torch.ops.fbgemm.permute_2D_sparse_data .. autofunction:: torch.ops.fbgemm.permute_1D_sparse_data + +.. autofunction:: torch.ops.fbgemm.expand_into_jagged_permute + +.. autofunction:: torch.ops.fbgemm.asynchronous_complete_cumsum + +.. 
autofunction:: torch.ops.fbgemm.offsets_range diff --git a/fbgemm_gpu/fbgemm_gpu/docs/sparse_ops.py b/fbgemm_gpu/fbgemm_gpu/docs/sparse_ops.py index c95588207..ae307dc8f 100644 --- a/fbgemm_gpu/fbgemm_gpu/docs/sparse_ops.py +++ b/fbgemm_gpu/fbgemm_gpu/docs/sparse_ops.py @@ -119,3 +119,88 @@ None) """, ) + +add_docs( + torch.ops.fbgemm.expand_into_jagged_permute, + """ +expand_into_jagged_permute(permute, input_offset, output_offset, output_size) -> Tensor + +Expand the sparse data permute index from feature dimension to batch dimension, +for cases where the sparse features has different batch sizes across ranks. + +The op expands the permute from feature level to batch level by contiguously +mapping each bag of its corresponding features to the position the batch sits +on after feature permute. The op will automatically derive offset array of +feature and batch to compute the output permute. + +Args: + permute (Tensor): The feature level permute index. + + input_offset (Tensor): The exclusive offsets of feature-level length. + + output_offsets (Tensor): The exclusive offsets of feature-level permuted + length. + + output_size (int): The number of elements in the output tensor + +Returns: + The output follows the following formula + + >>> output_permute[feature_offset[permute[feature]] + batch] <- bag_offset[batch] + """, +) + +add_docs( + torch.ops.fbgemm.asynchronous_complete_cumsum, + """ +asynchronous_complete_cumsum(t_in) -> Tensor + +Compute complete cumulative sum. For the GPU operator, the operator is +nonblocking asynchronous. For the CPU operator, it is a blocking operator. + +Args: + t_in (Tensor): An input tensor + +Returns: + The complete cumulative sum of `t_in`. Shape is `t_in.numel() + 1` + +**Example:** + + >>> t_in = torch.tensor([7, 8, 2, 1, 0, 9, 4], dtype=torch.int64, device="cuda") + >>> torch.ops.fbgemm.asynchronous_complete_cumsum(t_in) + tensor([ 0, 7, 15, 17, 18, 18, 27, 31], device='cuda:0') + """, +) + +add_docs( + torch.ops.fbgemm.offsets_range, + """ +offsets_range(offsets, range_size) -> Tensor + +Generate an integer sequence from 0 to `(offsets[i+1] - offsets[i])` for every +`i`, where `0 <= i < offsets.numel()` + +Args: + offsets (Tensor): The offsets (complete cumulative sum values) + + range_size (int): The output size (the total sum) + +Returns: + A tensor that contains offsets range + +**Example:** + >>> # Generate example inputs + >>> lengths = torch.tensor([3, 4, 1, 9, 3, 7], dtype=torch.int64, device="cuda") + >>> offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths) + >>> range_size = offsets[-1].item() + >>> print(range_size) + 27 + >>> offsets = offsets[:-1] + >>> print(offsets) + tensor([ 0, 3, 7, 8, 17, 20], device='cuda:0') + >>> # Invoke + >>> torch.ops.fbgemm.offsets_range(offsets, range_size) + tensor([0, 1, 2, 0, 1, 2, 3, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 1, 2, 0, 1, 2, 3, + 4, 5, 6], device='cuda:0') + """, +) From d056aa3689380f7decad83c90bfc36f5dcf04195 Mon Sep 17 00:00:00 2001 From: Supadchaya Puangpontip Date: Thu, 26 Sep 2024 23:18:38 -0700 Subject: [PATCH 04/48] Add generate_vbe_metadata CPU fallback (#3183) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/279 Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3183 Add generate_vbe_metadata for CPU Reviewed By: q10 Differential Revision: D63494418 fbshipit-source-id: 3936e2546ccf4fb89632df5c49148141adbabe71 --- .../fbgemm_gpu/split_embeddings_utils.cuh | 26 +------------- .../fbgemm_gpu/split_embeddings_utils.h | 36 +++++++++++++++++++ 
.../split_embeddings_utils.cpp | 17 +++++++++ 3 files changed, 54 insertions(+), 25 deletions(-) create mode 100644 fbgemm_gpu/include/fbgemm_gpu/split_embeddings_utils.h diff --git a/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_utils.cuh b/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_utils.cuh index 42fe5eb4c..8351e046c 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_utils.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_utils.cuh @@ -12,14 +12,7 @@ #include #include #include "fbgemm_gpu/embedding_common.h" - -// These values are adjusted in backward based on B and T -constexpr int DEFAULT_INFO_NUM_BITS = 32; -constexpr int DEFAULT_INFO_B_NUM_BITS = 26; -constexpr uint32_t DEFAULT_INFO_B_MASK = (1u << DEFAULT_INFO_B_NUM_BITS) - 1; -constexpr uint32_t MAX_T = - (1u << (DEFAULT_INFO_NUM_BITS - DEFAULT_INFO_B_NUM_BITS)) - 1; -constexpr uint32_t MAX_B = (1u << DEFAULT_INFO_B_NUM_BITS) - 1; +#include "fbgemm_gpu/split_embeddings_utils.h" /** * "Transpose" embedding inputs by sorting indices by their values. @@ -50,11 +43,6 @@ transpose_embedding_input( const int64_t fixed_L_per_warp = 0, const int64_t num_warps_per_feature = 0); -std::tuple -get_infos_metadata(at::Tensor unused, int64_t B, int64_t T); - -std::tuple adjust_info_B_num_bits(int32_t B, int32_t T); - // Use these functions instead of directly calling cub functions // to reduce code size and compilation time. // Arguments are the same as cub::DeviceRadixSort::SortPairs @@ -77,15 +65,3 @@ DECL_RADIX_SORT_PAIRS_FN(int64_t, int64_t); DECL_RADIX_SORT_PAIRS_FN(int64_t, int32_t); #undef DECL_RADIX_SORT_PAIRS_FN - -std::tuple -generate_vbe_metadata( - const at::Tensor& B_offsets, - const at::Tensor& B_offsets_rank_per_feature, - const at::Tensor& output_offsets_feature_rank, - const at::Tensor& D_offsets, - const int64_t D, - const bool nobag, - const int64_t max_B_feature_rank, - const int64_t info_B_num_bits, - const int64_t total_B); diff --git a/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_utils.h b/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_utils.h new file mode 100644 index 000000000..b41681012 --- /dev/null +++ b/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_utils.h @@ -0,0 +1,36 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include + +// These values are adjusted in backward based on B and T +constexpr int DEFAULT_INFO_NUM_BITS = 32; +constexpr int DEFAULT_INFO_B_NUM_BITS = 26; +constexpr uint32_t DEFAULT_INFO_B_MASK = (1u << DEFAULT_INFO_B_NUM_BITS) - 1; +constexpr uint32_t MAX_T = + (1u << (DEFAULT_INFO_NUM_BITS - DEFAULT_INFO_B_NUM_BITS)) - 1; +constexpr uint32_t MAX_B = (1u << DEFAULT_INFO_B_NUM_BITS) - 1; + +std::tuple +get_infos_metadata(at::Tensor unused, int64_t B, int64_t T); + +std::tuple adjust_info_B_num_bits(int32_t B, int32_t T); + +std::tuple +generate_vbe_metadata( + const at::Tensor& B_offsets, + const at::Tensor& B_offsets_rank_per_feature, + const at::Tensor& output_offsets_feature_rank, + const at::Tensor& D_offsets, + const int64_t D, + const bool nobag, + const int64_t max_B_feature_rank, + const int64_t info_B_num_bits, + const int64_t total_B); diff --git a/fbgemm_gpu/src/split_embeddings_utils/split_embeddings_utils.cpp b/fbgemm_gpu/src/split_embeddings_utils/split_embeddings_utils.cpp index 4e8407fb1..4ae9ae0f7 100644 --- a/fbgemm_gpu/src/split_embeddings_utils/split_embeddings_utils.cpp +++ b/fbgemm_gpu/src/split_embeddings_utils/split_embeddings_utils.cpp @@ -33,6 +33,22 @@ generate_vbe_metadata_meta( return {row_output_offsets, b_t_map}; } +std::tuple +generate_vbe_metadata_cpu( + const Tensor& B_offsets, + const Tensor& B_offsets_rank_per_feature, + const Tensor& output_offsets_feature_rank, + const Tensor& D_offsets, + const int64_t D, + const bool nobag, + const c10::SymInt max_B_feature_rank, + const int64_t info_B_num_bits, + const c10::SymInt total_B) { + Tensor row_output_offsets = output_offsets_feature_rank; + Tensor b_t_map = B_offsets_rank_per_feature; + return {row_output_offsets, b_t_map}; +} + } // namespace TORCH_LIBRARY_FRAGMENT(fbgemm, m) { @@ -68,6 +84,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { DISPATCH_TO_CUDA("transpose_embedding_input", transpose_embedding_input); DISPATCH_TO_CUDA("get_infos_metadata", get_infos_metadata); DISPATCH_TO_CUDA("generate_vbe_metadata", generate_vbe_metadata); + DISPATCH_TO_CPU("generate_vbe_metadata", generate_vbe_metadata_cpu); } TORCH_LIBRARY_IMPL(fbgemm, Meta, m) { From 5582d23699758348a271e5a0d51fc17cc6824541 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Fri, 27 Sep 2024 14:15:12 -0700 Subject: [PATCH 05/48] Allow bounds_check_mode to be set by environment variable (#3187) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3187 X-link: https://github.com/facebookresearch/FBGEMM/pull/282 This very simple diff allows TBE's bounds_check_mode to be overwritten by the environment variable `FBGEMM_TBE_BOUNDS_CHECK_MODE`. When set to `3`, this will disable bounds checking. 
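A minimal, illustrative sketch of the new behavior (not part of this diff; the class name is the TBE module touched below): the override is read inside `__init__`, so the environment variable must be set before the module is constructed.

    import os

    # Per the summary above, a value of "3" disables bounds checking.
    # This must run before SplitTableBatchedEmbeddingBagsCodegen(...) is
    # instantiated; any bounds_check_mode argument passed to the constructor
    # is then overridden by the environment variable.
    os.environ["FBGEMM_TBE_BOUNDS_CHECK_MODE"] = "3"

If the variable is unset, the `bounds_check_mode` constructor argument is used as before.
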
Reviewed By: renganxu, qchip Differential Revision: D63549883 fbshipit-source-id: 4b5f3180e0e04acece9a4c36ba620da748f31107 --- .../split_table_batched_embeddings_ops_training.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py index 554bd0b00..70e1dac68 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py @@ -623,7 +623,10 @@ def __init__( # noqa C901 self.uuid = str(uuid.uuid4()) self.logging_table_name: str = self.get_table_name_for_logging(table_names) self.pooling_mode = pooling_mode - self.bounds_check_mode_int: int = bounds_check_mode.value + # If environment variable is set, it overwrites the default bounds check mode. + self.bounds_check_mode_int: int = int( + os.environ.get("FBGEMM_TBE_BOUNDS_CHECK_MODE", bounds_check_mode.value) + ) self.weights_precision = weights_precision self.output_dtype: int = output_dtype.as_int() assert ( From 8f6d96df19bcc5a3bb64400e1069469b07447978 Mon Sep 17 00:00:00 2001 From: Joe Wang Date: Fri, 27 Sep 2024 20:17:11 -0700 Subject: [PATCH 06/48] fix for CPU ooming (#3186) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3186 X-link: https://github.com/facebookresearch/FBGEMM/pull/281 issue: when we have prefetch_pipeline disabled, we still keep filling ssd_scratch_pads but didn't pop it. Reviewed By: chrisxcai Differential Revision: D63547432 fbshipit-source-id: 08a6363a22d3dbd3f80f2d044926ca98bb38c3ba --- fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py index 2c02db4b9..8d91c9597 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py @@ -499,8 +499,6 @@ def __init__( # pyre-fixme[4]: Attribute must be annotated. 
self.ssd_prefetch_data = [] - # Scratch pad value queue - self.ssd_scratch_pads: List[Tuple[Tensor, Tensor, Tensor]] = [] # Scratch pad eviction data queue self.ssd_scratch_pad_eviction_data: List[ Tuple[Tensor, Tensor, Tensor, bool] @@ -508,6 +506,9 @@ def __init__( self.ssd_location_update_data: List[Tuple[Tensor, Tensor]] = [] if self.prefetch_pipeline: + # Scratch pad value queue + self.ssd_scratch_pads: List[Tuple[Tensor, Tensor, Tensor]] = [] + # pyre-ignore[4] # Scratch pad index queue self.scratch_pad_idx_queue = torch.classes.fbgemm.SSDScratchPadIndicesQueue( @@ -1407,15 +1408,16 @@ def prefetch( # noqa C901 if t.is_cuda: t.record_stream(forward_stream) - # Store scratch pad info for the lookup in the next iteration - # prefetch - self.ssd_scratch_pads.append( - ( - inserted_rows, - post_bwd_evicted_indices_cpu, - actions_count_cpu, + if self.prefetch_pipeline: + # Store scratch pad info for the lookup in the next iteration + # prefetch + self.ssd_scratch_pads.append( + ( + inserted_rows, + post_bwd_evicted_indices_cpu, + actions_count_cpu, + ) ) - ) # Store scratch pad info for post backward eviction self.ssd_scratch_pad_eviction_data.append( From e57361074823611a3002de2f1aae54fa2bbfde28 Mon Sep 17 00:00:00 2001 From: Sarunya Pumma Date: Fri, 27 Sep 2024 23:30:28 -0700 Subject: [PATCH 07/48] Add docstrings for sparse ops (2) (#3185) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/280 Add a docstring for - torch.ops.fbgemm.segment_sum_csr - torch.ops.fbgemm.keyed_jagged_index_select_dim1 Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3185 Reviewed By: shintaro-iwasaki Differential Revision: D63520029 Pulled By: sryap fbshipit-source-id: 737cd62de83d5c31992aba8898231803545c393a --- .../src/fbgemm_gpu-python-api/sparse_ops.rst | 4 + fbgemm_gpu/fbgemm_gpu/docs/sparse_ops.py | 115 ++++++++++++++++++ 2 files changed, 119 insertions(+) diff --git a/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/sparse_ops.rst b/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/sparse_ops.rst index afc38a450..44d9e34ce 100644 --- a/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/sparse_ops.rst +++ b/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/sparse_ops.rst @@ -12,3 +12,7 @@ Sparse Operators .. autofunction:: torch.ops.fbgemm.asynchronous_complete_cumsum .. autofunction:: torch.ops.fbgemm.offsets_range + +.. autofunction:: torch.ops.fbgemm.segment_sum_csr + +.. autofunction:: torch.ops.fbgemm.keyed_jagged_index_select_dim1 diff --git a/fbgemm_gpu/fbgemm_gpu/docs/sparse_ops.py b/fbgemm_gpu/fbgemm_gpu/docs/sparse_ops.py index ae307dc8f..5dffc308f 100644 --- a/fbgemm_gpu/fbgemm_gpu/docs/sparse_ops.py +++ b/fbgemm_gpu/fbgemm_gpu/docs/sparse_ops.py @@ -204,3 +204,118 @@ 4, 5, 6], device='cuda:0') """, ) + +add_docs( + torch.ops.fbgemm.segment_sum_csr, + """ +segment_sum_csr(batch_size, csr_seg, values) -> Tensor + +Sum values within each segment on the given CSR data where each row has the +same number of non-zero elements. + +Args: + batch_size (int): The row stride (number of non-zero elements in each row) + + csr_seg (Tensor): The complete cumulative sum of segment lengths. A segment + length is the number of rows within each segment. The shape of the + `csr_seg` tensor is `num_segments + 1` where `num_segments` is the + number of segments. + + values (Tensor): The values tensor to be segment summed. The number of + elements in the tensor must be multiple of `batch_size` + +Returns: + A tensor containing the segment sum results. Shape is the number of + segments. 
+ +**Example:** + + >>> batch_size = 2 + >>> # Randomize inputs + >>> lengths = torch.tensor([3, 4, 1], dtype=torch.int, device="cuda") + >>> offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths) + >>> print(offsets) + tensor([0, 3, 7, 8], device='cuda:0', dtype=torch.int32) + >>> values = torch.randn(lengths.sum().item() * batch_size, dtype=torch.float32, device="cuda") + >>> print(values) + tensor([-2.8642e-01, 1.6451e+00, 1.1322e-01, 1.7335e+00, -8.4700e-02, + -1.2756e+00, 1.1206e+00, 9.6385e-01, 6.2122e-02, 1.3104e-03, + 2.2667e-01, 2.3113e+00, -1.1948e+00, -1.5463e-01, -1.0031e+00, + -3.5531e-01], device='cuda:0') + >>> # Invoke + >>> torch.ops.fbgemm.segment_sum_csr(batch_size, offsets, values) + tensor([ 1.8451, 3.3365, -1.3584], device='cuda:0') + """, +) + +add_docs( + torch.ops.fbgemm.keyed_jagged_index_select_dim1, + """ +keyed_jagged_index_select_dim1(values, lengths, offsets, indices, batch_size, weights=None, selected_lengths_sum=None) -> List[Tensor] + +Perform an index select operation on the batch dimension (dim 1) of the given +keyed jagged tensor (KJT) input. The same samples in the batch of every key +will be selected. Note that each KJT has 3 dimensions: (`num_keys`, `batch_size`, +jagged dim), where `num_keys` is the number of keys, and `batch_size` is the +batch size. This operator is similar to a permute operator. + +Args: + values (Tensor): The KJT values tensor which contains concatenated data of + every key + + lengths (Tensor): The KJT lengths tensor which contains the jagged shapes + of every key (dim 0) and sample (dim 1). Shape is `num_keys * + batch_size` + + offsets (Tensor): The KJT offsets tensor which is the complete cumulative + sum of `lengths`. Shape is `num_keys * batch_size + 1` + + indices (Tensor): The indices to select, i.e., samples in the batch to + select. The values of `indices` must be >= 0 and < `batch_size` + + batch_size (int): The batch size (dim 1 of KJT) + + weights (Optional[Tensor] = None): An optional float tensor which will be + selected the same way as `values`. Thus, it must have the same shape as + `values` + + selected_lengths_sum (Optional[int] = None): An optional value that + represents the total number of elements in the index select data + (output shape). If not provided, the operator will compute this data + which may cause a device-host synchronization (if using GPU). Thus, it + is recommended to supply this value to avoid such the synchronization. 
+ +Returns: + The index-select KJT tensor (as a list of values, lengths, and weights if + `weights` is not None) + +**Example:** + + >>> num_keys = 2 + >>> batch_size = 4 + >>> output_size = 3 + >>> # Randomize inputs + >>> lengths = torch.randint(low=0, high=10, size=(batch_size * num_keys,), dtype=torch.int64, device="cuda") + >>> print(lengths) + tensor([8, 5, 1, 4, 2, 7, 5, 9], device='cuda:0') + >>> offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths) + >>> print(offsets) + tensor([ 0, 8, 13, 14, 18, 20, 27, 32, 41], device='cuda:0') + >>> indices = torch.randint(low=0, high=batch_size, size=(output_size,), dtype=torch.int64, device="cuda") + >>> print(indices) + tensor([3, 3, 1], device='cuda:0') + >>> # Use torch.arange instead of torch.randn to simplify the example + >>> values = torch.arange(lengths.sum().item(), dtype=torch.float32, device="cuda") + >>> print(values) + tensor([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., + 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., + 28., 29., 30., 31., 32., 33., 34., 35., 36., 37., 38., 39., 40.], + device='cuda:0') + >>> # Invoke. Output = (output, lengths) + >>> torch.ops.fbgemm.keyed_jagged_index_select_dim1(values, lengths, offsets, indices, batch_size) + [tensor([14., 15., 16., 17., 14., 15., 16., 17., 8., 9., 10., 11., 12., 32., + 33., 34., 35., 36., 37., 38., 39., 40., 32., 33., 34., 35., 36., 37., + 38., 39., 40., 20., 21., 22., 23., 24., 25., 26.], device='cuda:0'), + tensor([4, 4, 5, 9, 9, 7], device='cuda:0')] + """, +) From 00f2fd5c6504023c3015ebaee9a61634f292dc9a Mon Sep 17 00:00:00 2001 From: Minhua Chen Date: Sat, 28 Sep 2024 18:34:57 -0700 Subject: [PATCH 08/48] decouple ema and adagrad (fbgemm) (#3180) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/276 Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3180 decouple ema and adagrad (fbgemm) Reviewed By: q10 Differential Revision: D63458836 fbshipit-source-id: c868db48137e86beee7bf0c741e9dece4e6e1fa0 --- fbgemm_gpu/FbgemmGpu.cmake | 2 - .../genscript/generate_backward_split.py | 1 - fbgemm_gpu/codegen/genscript/optimizers.py | 121 ------------------ .../training/python/lookup_args.template | 4 - ..._embedding_codegen_lookup_invoker.template | 24 ---- ...split_embedding_optimizer_codegen.template | 38 +----- ...t_table_batched_embeddings_ops_training.py | 101 +++++++-------- fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py | 9 -- .../tbe/training/backward_optimizers_test.py | 50 +++++--- 9 files changed, 81 insertions(+), 269 deletions(-) diff --git a/fbgemm_gpu/FbgemmGpu.cmake b/fbgemm_gpu/FbgemmGpu.cmake index 27b4ec884..2ae401ea1 100644 --- a/fbgemm_gpu/FbgemmGpu.cmake +++ b/fbgemm_gpu/FbgemmGpu.cmake @@ -60,7 +60,6 @@ set(GPU_ONLY_OPTIMIZERS lamb partial_rowwise_adam partial_rowwise_lamb - ensemble_rowwise_adagrad lars_sgd none rowwise_adagrad_with_counter) @@ -87,7 +86,6 @@ set(GPU_OPTIMIZERS ${COMMON_OPTIMIZERS} ${GPU_ONLY_OPTIMIZERS}) set(VBE_OPTIMIZERS rowwise_adagrad rowwise_adagrad_with_counter - ensemble_rowwise_adagrad sgd dense) diff --git a/fbgemm_gpu/codegen/genscript/generate_backward_split.py b/fbgemm_gpu/codegen/genscript/generate_backward_split.py index afdcb8b3c..ac37444ff 100644 --- a/fbgemm_gpu/codegen/genscript/generate_backward_split.py +++ b/fbgemm_gpu/codegen/genscript/generate_backward_split.py @@ -335,7 +335,6 @@ def generate() -> None: lars_sgd(), partial_rowwise_adam(), partial_rowwise_lamb(), - ensemble_rowwise_adagrad(), rowwise_adagrad(), approx_rowwise_adagrad(), 
rowwise_adagrad_with_weight_decay(), diff --git a/fbgemm_gpu/codegen/genscript/optimizers.py b/fbgemm_gpu/codegen/genscript/optimizers.py index acf2af31f..15c100ed5 100644 --- a/fbgemm_gpu/codegen/genscript/optimizers.py +++ b/fbgemm_gpu/codegen/genscript/optimizers.py @@ -1020,127 +1020,6 @@ def adam() -> Dict[str, Any]: } -def ensemble_rowwise_adagrad() -> Dict[str, Any]: - split_precomputation = """ - at::acc_type g_local_sum_square = 0.0; - """ - split_precomputation += generate_optimized_grad_sum_loop_access( - """ - const float4* grad = &{grad_vec}.acc; - auto gx = grad->x; - auto gy = grad->y; - auto gz = grad->z; - auto gw = grad->w; - g_local_sum_square += gx * gx + gy * gy + gz * gz + gw * gw; - """ - ) - split_precomputation += """ - const at::acc_type g_avg_square = - GROUP_REDUCE_ALL_SUM(g_local_sum_square, at::acc_type) / D; - - at::acc_type multiplier; - at::acc_type coef_ema; - at::acc_type should_ema; - at::acc_type should_swap; - if (threadIdx.x == 0) { - at::acc_type new_sum_square_grads = momentum2[idx] + g_avg_square; - momentum2[idx] = new_sum_square_grads; - multiplier = learning_rate / (sqrtf(new_sum_square_grads) + eps); - - coef_ema = (row_counter[idx] > step_start) ? (momentum*1.0) : 0.0; - if (step_mode == 1) { - // row_counter[idx] tracks the number of appearances of this ID - row_counter[idx] += 1.0; - should_ema = floorf(row_counter[idx] / step_ema) - floorf((row_counter[idx]-1.0) / step_ema); - should_swap = floorf(row_counter[idx] / step_swap) - floorf((row_counter[idx]-1.0) / step_swap); - } else if (step_mode == 2) { - should_ema = floorf((iter*1.0 - row_counter[idx]) / step_ema); - should_swap = floorf((iter*1.0 - prev_iter[idx]) / step_swap); - // row_counter[idx] records the step of last ema - if (should_ema > 0.5) { - coef_ema = powf(coef_ema, (iter*1.0 - row_counter[idx]) / step_ema); - row_counter[idx] = iter*1.0; - } - // prev_iter[idx] records the step of last swap - if (should_swap > 0.5) { - prev_iter[idx] = iter*1.0; - } - } else { - should_ema = 0.0; - should_swap = 0.0; - } - } - multiplier = SHFL_SYNC(multiplier, 0); - coef_ema = SHFL_SYNC(coef_ema, 0); - should_ema = SHFL_SYNC(should_ema, 0); - should_swap = SHFL_SYNC(should_swap, 0); - """ - - split_weight_update = """ - weight_new.acc.x = weight_new.acc.x - multiplier * grad.acc.x; - weight_new.acc.y = weight_new.acc.y - multiplier * grad.acc.y; - weight_new.acc.z = weight_new.acc.z - multiplier * grad.acc.z; - weight_new.acc.w = weight_new.acc.w - multiplier * grad.acc.w; - - if (should_ema > 0.5) { // slow table ema - Vec4T m_t(&momentum1[idx * D + d]); - m_t.acc.x = (1.0 - coef_ema) * weight_new.acc.x + coef_ema * m_t.acc.x + (momentum - coef_ema) * multiplier * grad.acc.x; - m_t.acc.y = (1.0 - coef_ema) * weight_new.acc.y + coef_ema * m_t.acc.y + (momentum - coef_ema) * multiplier * grad.acc.y; - m_t.acc.z = (1.0 - coef_ema) * weight_new.acc.z + coef_ema * m_t.acc.z + (momentum - coef_ema) * multiplier * grad.acc.z; - m_t.acc.w = (1.0 - coef_ema) * weight_new.acc.w + coef_ema * m_t.acc.w + (momentum - coef_ema) * multiplier * grad.acc.w; - m_t.store(&momentum1[idx * D + d]); - } - - if (should_swap > 0.5) { // slow-to-fast swap - Vec4T m_t(&momentum1[idx * D + d]); - weight_new.acc.x = m_t.acc.x * 1.0; - weight_new.acc.y = m_t.acc.y * 1.0; - weight_new.acc.z = m_t.acc.z * 1.0; - weight_new.acc.w = m_t.acc.w * 1.0; - } - """ - - split_weight_update_cpu = "" # TODO - - return { - "optimizer": "ensemble_rowwise_adagrad", - "is_prototype_optimizer": True, - "args": 
OptimizerArgsSet.create( - [ - OptimItem( - ArgType.PLACEHOLDER_TENSOR, - "momentum1", - ph_tys=[ArgType.FLOAT_TENSOR, ArgType.BFLOAT16_TENSOR], - ), - OptimItem( - ArgType.PLACEHOLDER_TENSOR, - "momentum2", - ph_tys=[ArgType.FLOAT_TENSOR, ArgType.BFLOAT16_TENSOR], - ), - OptimItem(ArgType.TENSOR, "prev_iter"), - OptimItem(ArgType.TENSOR, "row_counter"), - OptimItem(ArgType.FLOAT, "learning_rate"), - OptimItem(ArgType.FLOAT, "eps"), - OptimItem(ArgType.FLOAT, "step_ema"), - OptimItem(ArgType.FLOAT, "step_swap"), - OptimItem(ArgType.FLOAT, "step_start"), - OptimItem(ArgType.FLOAT, "momentum"), - OptimItem(ArgType.INT, "iter"), - OptimItem(ArgType.INT, "step_mode"), - ] - ), - "split_precomputation": split_precomputation, - "split_weight_update": split_weight_update, - "split_post_update": "", - "split_weight_update_cpu": split_weight_update_cpu, - "has_cpu_support": False, - "has_gpu_support": True, - "has_vbe_support": True, - "has_global_weight_decay_support": False, - "has_ssd_support": False, - } - - def partial_rowwise_adam() -> Dict[str, Any]: split_precomputation = """ at::acc_type g_local_sum_square = 0.0; diff --git a/fbgemm_gpu/codegen/training/python/lookup_args.template b/fbgemm_gpu/codegen/training/python/lookup_args.template index 54fa11177..357aad622 100644 --- a/fbgemm_gpu/codegen/training/python/lookup_args.template +++ b/fbgemm_gpu/codegen/training/python/lookup_args.template @@ -60,10 +60,6 @@ class OptimizerArgs(NamedTuple): eps: float beta1: float beta2: float - step_ema: float - step_swap: float - step_start: float - step_mode: int weight_decay: float weight_decay_mode: int eta: float diff --git a/fbgemm_gpu/codegen/training/python/split_embedding_codegen_lookup_invoker.template b/fbgemm_gpu/codegen/training/python/split_embedding_codegen_lookup_invoker.template index e03a879cb..2f14b27de 100644 --- a/fbgemm_gpu/codegen/training/python/split_embedding_codegen_lookup_invoker.template +++ b/fbgemm_gpu/codegen/training/python/split_embedding_codegen_lookup_invoker.template @@ -145,18 +145,6 @@ def invoke( {%- if "beta2" in args.split_function_arg_names %} beta2=optimizer_args.beta2, {%- endif %} - {%- if "step_ema" in args.split_function_arg_names %} - step_ema=optimizer_args.step_ema, - {%- endif %} - {%- if "step_swap" in args.split_function_arg_names %} - step_swap=optimizer_args.step_swap, - {%- endif %} - {%- if "step_start" in args.split_function_arg_names %} - step_start=optimizer_args.step_start, - {%- endif %} - {%- if "step_mode" in args.split_function_arg_names %} - step_mode=optimizer_args.step_mode, - {%- endif %} {%- if "weight_decay" in args.split_function_arg_names %} weight_decay=optimizer_args.weight_decay, {%- endif %} @@ -327,18 +315,6 @@ def invoke( {%- if "beta2" in args.split_function_arg_names %} beta2=optimizer_args.beta2, {%- endif %} - {%- if "step_ema" in args.split_function_arg_names %} - step_ema=optimizer_args.step_ema, - {%- endif %} - {%- if "step_swap" in args.split_function_arg_names %} - step_swap=optimizer_args.step_swap, - {%- endif %} - {%- if "step_start" in args.split_function_arg_names %} - step_start=optimizer_args.step_start, - {%- endif %} - {%- if "step_mode" in args.split_function_arg_names %} - step_mode=optimizer_args.step_mode, - {%- endif %} {%- if "weight_decay" in args.split_function_arg_names %} weight_decay=optimizer_args.weight_decay, {%- endif %} diff --git a/fbgemm_gpu/codegen/training/python/split_embedding_optimizer_codegen.template b/fbgemm_gpu/codegen/training/python/split_embedding_optimizer_codegen.template 
index b9be5cd4c..6c2380e7c 100644 --- a/fbgemm_gpu/codegen/training/python/split_embedding_optimizer_codegen.template +++ b/fbgemm_gpu/codegen/training/python/split_embedding_optimizer_codegen.template @@ -90,18 +90,6 @@ class SplitEmbedding{{ optimizer_class_name }}(Optimizer): {%- if "beta2" in args.split_function_arg_names %} beta2: float = 0.999, {%- endif %} - {%- if "step_ema" in args.split_function_arg_names %} - step_ema: float = 10000, - {%- endif %} - {%- if "step_swap" in args.split_function_arg_names %} - step_swap: float = 10000, - {%- endif %} - {%- if "step_start" in args.split_function_arg_names %} - step_start: float = 0, - {%- endif %} - {%- if "step_mode" in args.split_function_arg_names %} - step_mode: int = 2, - {%- endif %} {%- if "weight_decay" in args.split_function_arg_names %} weight_decay: float = 0.0, {%- endif %} @@ -130,18 +118,6 @@ class SplitEmbedding{{ optimizer_class_name }}(Optimizer): {%- if "beta2" in args.split_function_arg_names %} beta2=beta2, {%- endif %} - {%- if "step_ema" in args.split_function_arg_names %} - step_ema=step_ema, - {%- endif %} - {%- if "step_swap" in args.split_function_arg_names %} - step_swap=step_swap, - {%- endif %} - {%- if "step_start" in args.split_function_arg_names %} - step_start=step_start, - {%- endif %} - {%- if "step_mode" in args.split_function_arg_names %} - step_mode=step_mode, - {%- endif %} {%- if "weight_decay" in args.split_function_arg_names %} weight_decay=weight_decay, {%- endif %} @@ -186,7 +162,7 @@ class SplitEmbedding{{ optimizer_class_name }}(Optimizer): rowwise = False {% endif %} {% elif state_tensor == "momentum2" %} - {% if optimizer in ["partial_rowwise_adam", "partial_rowwise_lamb", "ensemble_rowwise_adagrad"] %} + {% if optimizer in ["partial_rowwise_adam", "partial_rowwise_lamb"] %} rowwise = True {% else %} rowwise = False @@ -236,18 +212,6 @@ class SplitEmbedding{{ optimizer_class_name }}(Optimizer): {%- if "beta2" in args.split_function_arg_names %} self.beta2 = beta2 {%- endif %} - {%- if "step_ema" in args.split_function_arg_names %} - self.step_ema = step_ema - {%- endif %} - {%- if "step_swap" in args.split_function_arg_names %} - self.step_swap = step_swap - {%- endif %} - {%- if "step_start" in args.split_function_arg_names %} - self.step_start = step_start - {%- endif %} - {%- if "step_mode" in args.split_function_arg_names %} - self.step_mode = step_mode - {%- endif %} {%- if "weight_decay" in args.split_function_arg_names %} self.weight_decay = weight_decay {%- endif %} diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py index 70e1dac68..3556d83a4 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py @@ -145,6 +145,14 @@ class GlobalWeightDecayDefinition: lower_bound: float = 0.0 +@dataclass(frozen=True) +class EnsembleModeDefinition: + step_ema: float = 10000 + step_swap: float = 10000 + step_start: float = 0 + step_mode: StepMode = StepMode.USE_ITER + + # Keep in sync with fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache_cuda.cuh class UVMCacheStatsIndex(enum.IntEnum): num_calls = 0 @@ -473,14 +481,8 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module): beta2 (float = 0.999): The beta2 value used by LAMB and ADAM - step_ema (float = 10000): Used by ENSEMBLE_ROWWISE_ADAGRAD - - step_swap (float = 10000): Used by ENSEMBLE_ROWWISE_ADAGRAD - - step_start (float = 0.0): Used 
by ENSEMBLE_ROWWISE_ADAGRAD - - step_mode: (StepMode = StepMode.USE_ITER): Used by - ENSEMBLE_ROWWISE_ADAGRAD + ensemble_mode (Optional[EnsembleModeDefinition] = None): + Used by Ensemble Rowwise Adagrad counter_based_regularization (Optional[CounterBasedRegularizationDefinition] = None): Used by Rowwise Adagrad @@ -598,10 +600,7 @@ def __init__( # noqa C901 eta: float = 0.001, beta1: float = 0.9, beta2: float = 0.999, - step_ema: float = 10000, - step_swap: float = 10000, - step_start: float = 0, - step_mode: StepMode = StepMode.USE_ITER, + ensemble_mode: Optional[EnsembleModeDefinition] = None, counter_based_regularization: Optional[ CounterBasedRegularizationDefinition ] = None, @@ -923,6 +922,15 @@ def __init__( # noqa C901 self.gwd_start_iter: int = global_weight_decay.start_iter self.gwd_lower_bound: float = global_weight_decay.lower_bound + if ensemble_mode is None: + ensemble_mode = EnsembleModeDefinition() + self._ensemble_mode: Dict[str, int] = { + "step_ema": int(ensemble_mode.step_ema), + "step_swap": int(ensemble_mode.step_swap), + "step_start": int(ensemble_mode.step_start), + "step_mode": int(ensemble_mode.step_mode.value), + } + if counter_based_regularization is None: counter_based_regularization = CounterBasedRegularizationDefinition() if cowclip_regularization is None: @@ -960,10 +968,6 @@ def __init__( # noqa C901 eps=eps, beta1=beta1, beta2=beta2, - step_ema=step_ema, - step_swap=step_swap, - step_start=step_start, - step_mode=step_mode.value, weight_decay=weight_decay, weight_decay_mode=opt_arg_weight_decay_mode.value, eta=eta, @@ -1003,6 +1007,7 @@ def __init__( # noqa C901 ) rowwise = optimizer in [ OptimType.EXACT_ROWWISE_ADAGRAD, + OptimType.ENSEMBLE_ROWWISE_ADAGRAD, ] self._apply_split( construct_split_state( @@ -1032,7 +1037,6 @@ def __init__( # noqa C901 rowwise = optimizer in ( OptimType.PARTIAL_ROWWISE_ADAM, OptimType.PARTIAL_ROWWISE_LAMB, - OptimType.ENSEMBLE_ROWWISE_ADAGRAD, ) momentum2_dtype = ( torch.float32 @@ -1062,9 +1066,7 @@ def __init__( # noqa C901 else: # NOTE: make TorchScript work! self._register_nonpersistent_buffers("momentum2") - if self._used_rowwise_adagrad_with_counter or ( - optimizer == OptimType.ENSEMBLE_ROWWISE_ADAGRAD - ): + if self._used_rowwise_adagrad_with_counter: self._apply_split( construct_split_state( embedding_specs, @@ -1868,18 +1870,15 @@ def forward( # noqa: C901 assert self._feature_is_enabled( FeatureGateName.TBE_ENSEMBLE_ROWWISE_ADAGRAD ), "ENSEMBLE_ROWWISE_ADAGRAD is an inactive or deprecated feature!" 
+ with torch.no_grad(): + if self.training: + self.ensemble_and_swap(self._ensemble_mode) return self._report_io_size_count( "fwd_output", - invokers.lookup_ensemble_rowwise_adagrad.invoke( + invokers.lookup_rowwise_adagrad.invoke( common_args, self.optimizer_args, momentum1, - momentum2, - prev_iter, - row_counter, - iter=int(self.iter.item()), - apply_global_weight_decay=False, - gwd_lower_bound=0.0, ), ) @@ -1938,6 +1937,26 @@ def forward( # noqa: C901 raise ValueError(f"Invalid OptimType: {self.optimizer}") + def ensemble_and_swap(self, ensemble_mode: Dict[str, int]) -> None: + should_ema = self.iter.item() % int(ensemble_mode["step_ema"]) == 0 + should_swap = self.iter.item() % int(ensemble_mode["step_swap"]) == 0 + if should_ema or should_swap: + weights = self.split_embedding_weights() + states = self.split_optimizer_states() + for i in range(len(self.embedding_specs)): + if should_ema: + coef_ema = ( + self.optimizer_args.momentum + if self.iter.item() > int(ensemble_mode["step_start"]) + else 0.0 + ) + weights_cpu = weights[i].to( + dtype=states[i][1].dtype, device=states[i][1].device + ) + states[i][1].lerp_(weights_cpu, 1.0 - coef_ema) + if should_swap: + weights[i].copy_(states[i][1], non_blocking=True) + def reset_uvm_cache_stats(self) -> None: assert ( self.gather_uvm_cache_stats @@ -2346,9 +2365,7 @@ def get_optimizer_state(self) -> List[Dict[str, torch.Tensor]]: list_of_state_dict = [ { "sum": states[0], - "exp_avg": states[1], - "prev_iter": states[2], - "row_counter": states[3], + "sparse_ema": states[1], } for states in split_optimizer_states ] @@ -2393,8 +2410,7 @@ def split_optimizer_states( (8) `PARTIAL_ROWWISE_LAMB`: `momentum1`, `momentum2` (rowwise) - (9) `ENSEMBLE_ROWWISE_ADAGRAD`: `momentum2` (rowwise), `momentum1`, - `prev_iter` (rowwise), `row_counter` (rowwise) + (9) `ENSEMBLE_ROWWISE_ADAGRAD`: `momentum1` (rowwise), `momentum2` (10) `NONE`: no states (throwing an error) @@ -2431,19 +2447,6 @@ def get_optimizer_states( return splits states: List[List[torch.Tensor]] = [] - # For ensemble_rowwise_adagrad, momentum2 ("sum") should go first, - # as it is the default optimizer state for embedding pruning later. 
- if self.optimizer == OptimType.ENSEMBLE_ROWWISE_ADAGRAD: - states.append( - get_optimizer_states( - self.momentum2_dev, - self.momentum2_host, - self.momentum2_uvm, - self.momentum2_physical_offsets, - self.momentum2_physical_placements, - rowwise=True, - ) - ) if self.optimizer not in (OptimType.EXACT_SGD,): states.append( get_optimizer_states( @@ -2455,6 +2458,7 @@ def get_optimizer_states( rowwise=self.optimizer in [ OptimType.EXACT_ROWWISE_ADAGRAD, + OptimType.ENSEMBLE_ROWWISE_ADAGRAD, ], ) ) @@ -2463,6 +2467,7 @@ def get_optimizer_states( OptimType.PARTIAL_ROWWISE_ADAM, OptimType.LAMB, OptimType.PARTIAL_ROWWISE_LAMB, + OptimType.ENSEMBLE_ROWWISE_ADAGRAD, ): states.append( get_optimizer_states( @@ -2478,7 +2483,6 @@ def get_optimizer_states( if ( self._used_rowwise_adagrad_with_counter or self._used_rowwise_adagrad_with_global_weight_decay - or self.optimizer == OptimType.ENSEMBLE_ROWWISE_ADAGRAD ): states.append( get_optimizer_states( @@ -2490,10 +2494,7 @@ def get_optimizer_states( rowwise=True, ) ) - if ( - self._used_rowwise_adagrad_with_counter - or self.optimizer == OptimType.ENSEMBLE_ROWWISE_ADAGRAD - ): + if self._used_rowwise_adagrad_with_counter: states.append( get_optimizer_states( self.row_counter_dev, diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py index 8d91c9597..c1f13bd74 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py @@ -37,7 +37,6 @@ apply_split_helper, CounterBasedRegularizationDefinition, CowClipDefinition, - StepMode, UVMCacheStatsIndex, WeightDecayMode, ) @@ -116,10 +115,6 @@ def __init__( eta: float = 0.001, # used by LARS-SGD, beta1: float = 0.9, # used by LAMB and ADAM beta2: float = 0.999, # used by LAMB and ADAM - step_ema: float = 10000, # used by ENSEMBLE_ROWWISE_ADAGRAD - step_swap: float = 10000, # used by ENSEMBLE_ROWWISE_ADAGRAD - step_start: float = 0, # used by ENSEMBLE_ROWWISE_ADAGRAD - step_mode: StepMode = StepMode.USE_ITER, # used by ENSEMBLE_ROWWISE_ADAGRAD counter_based_regularization: Optional[ CounterBasedRegularizationDefinition ] = None, # used by Rowwise Adagrad @@ -536,10 +531,6 @@ def __init__( eps=eps, beta1=beta1, beta2=beta2, - step_ema=step_ema, - step_swap=step_swap, - step_start=step_start, - step_mode=step_mode.value, weight_decay=weight_decay, weight_decay_mode=weight_decay_mode.value, eta=eta, diff --git a/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py b/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py index adb96daaa..b1809dc1b 100644 --- a/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py +++ b/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py @@ -33,6 +33,14 @@ TailIdThreshold, WeightDecayMode, ) + +try: + from fbgemm_gpu.split_table_batched_embeddings_ops_training import ( + EnsembleModeDefinition, + ) +except ImportError: + EnsembleModeDefinition = None + from fbgemm_gpu.tbe.utils import ( b_indices, get_table_batched_offsets_from_dense, @@ -311,15 +319,17 @@ def execute_backward_optimizers_( # noqa C901 1e-4, 1.0, 1.0, - -1.0, + 0.0, StepMode.USE_ITER, 0.8, ) optimizer_kwargs["eps"] = eps - optimizer_kwargs["step_ema"] = step_ema - optimizer_kwargs["step_swap"] = step_swap - optimizer_kwargs["step_start"] = step_start - optimizer_kwargs["step_mode"] = step_mode + optimizer_kwargs["ensemble_mode"] = EnsembleModeDefinition( + step_ema=step_ema, + step_swap=step_swap, + step_start=step_start, + step_mode=step_mode, + ) optimizer_kwargs["momentum"] = momentum 
optimizer_kwargs["optimizer_state_dtypes"] = optimizer_state_dtypes @@ -555,14 +565,14 @@ def execute_backward_optimizers_( # noqa C901 if optimizer == OptimType.ENSEMBLE_ROWWISE_ADAGRAD: for t in range(T): iter_ = cc.iter.item() - (m2, m1, prev_iter, row_counter) = split_optimizer_states[t] + (m1, m2) = split_optimizer_states[t] if (m1.dtype == torch.float) and (m2.dtype == torch.float): tol = 1.0e-4 else: tol = 1.0e-2 # Some optimizers have non-float momentums - dense_cpu_grad = bs[t].weight.grad.cpu().to_dense() - m2_ref = dense_cpu_grad.pow(2).mean(dim=1) + m2_ref = torch.mul(bs[t].weight.cpu(), 1.0 - momentum) + weights_ref = m2_ref.mul(1.0) torch.testing.assert_close( m2.float().cpu().index_select(dim=0, index=xs[t].view(-1).cpu()), m2_ref.float() @@ -571,15 +581,8 @@ def execute_backward_optimizers_( # noqa C901 atol=tol, rtol=tol, ) - v_hat_t = m2_ref.view(m2_ref.numel(), 1) - weights_new = split_weights[t] - weights_ref = torch.addcdiv( - bs[t].weight.cpu(), - value=-lr, - tensor1=dense_cpu_grad, - tensor2=v_hat_t.sqrt_().add_(eps), - ) - m1_ref = torch.mul(weights_ref, 1.0 - momentum) + dense_cpu_grad = bs[t].weight.grad.cpu().to_dense() + m1_ref = dense_cpu_grad.pow(2).mean(dim=1) torch.testing.assert_close( m1.float().cpu().index_select(dim=0, index=xs[t].view(-1).cpu()), m1_ref.float() @@ -588,7 +591,14 @@ def execute_backward_optimizers_( # noqa C901 atol=tol, rtol=tol, ) - weights_ref = m1_ref.div(1.0) + v_hat_t = m1_ref.view(m1_ref.numel(), 1) + weights_new = split_weights[t] + weights_ref = torch.addcdiv( + weights_ref, + value=-lr, + tensor1=dense_cpu_grad, + tensor2=v_hat_t.sqrt_().add_(eps), + ) torch.testing.assert_close( weights_new.index_select(dim=0, index=xs[t].view(-1)).cpu(), weights_ref.index_select(dim=0, index=xs[t].view(-1).cpu()), @@ -599,9 +609,7 @@ def execute_backward_optimizers_( # noqa C901 optimizer_states_dict = get_optimizer_states[t] assert set(optimizer_states_dict.keys()) == { "sum", - "exp_avg", - "prev_iter", - "row_counter", + "sparse_ema", } if optimizer in (OptimType.PARTIAL_ROWWISE_LAMB, OptimType.LAMB): From 93dcc07a7ce09ae13d2731877e5f4e559ee1f36b Mon Sep 17 00:00:00 2001 From: Wang Zhou Date: Sun, 29 Sep 2024 20:08:58 -0700 Subject: [PATCH 09/48] use iter to compare with gwd_start_iter Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/292 `self.step` is used for logging purposes only and is not properly checkpointed/reloaded. 
Instead, `iter` should be used as it's reloaded from external https://fburl.com/code/48caflz4 Reviewed By: spcyppt Differential Revision: D63616621 fbshipit-source-id: fb7aacf0ad57088595b8b0d39bee3076b308d394 --- .../split_table_batched_embeddings_ops_training.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py index 3556d83a4..71c9325cf 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py @@ -1910,8 +1910,9 @@ def forward( # noqa: C901 ), ) elif self._used_rowwise_adagrad_with_global_weight_decay: + iter_ = int(self.iter.item()) apply_global_weight_decay = ( - self.step >= self.gwd_start_iter and self.training + iter_ >= self.gwd_start_iter and self.training ) return self._report_io_size_count( "fwd_output", @@ -1919,7 +1920,7 @@ def forward( # noqa: C901 common_args, self.optimizer_args, momentum1, - iter=int(self.iter.item()), + iter=iter_, apply_global_weight_decay=apply_global_weight_decay, prev_iter_dev=self.prev_iter_dev, gwd_lower_bound=self.gwd_lower_bound, From 4d168921867726c280d4709cc49fb65323a8c756 Mon Sep 17 00:00:00 2001 From: Benson Ma Date: Mon, 30 Sep 2024 11:34:14 -0700 Subject: [PATCH 10/48] Redefine FBGEMM targets with gpu_cpp_library [18/N] (#3190) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3190 X-link: https://github.com/facebookresearch/FBGEMM/pull/285 - Redefine sparse ops targets using `gpu_cpp_library` Reviewed By: spcyppt Differential Revision: D63424455 fbshipit-source-id: 928d92da85f5ae69c8d1f0ee2479ada66bf26380 --- fbgemm_gpu/FbgemmGpu.cmake | 1 + fbgemm_gpu/src/sparse_ops/common.h | 28 ++++ .../src/sparse_ops/sparse_async_cumsum.cpp | 151 ++++++++++++++++++ fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp | 138 +--------------- 4 files changed, 181 insertions(+), 137 deletions(-) create mode 100644 fbgemm_gpu/src/sparse_ops/common.h create mode 100644 fbgemm_gpu/src/sparse_ops/sparse_async_cumsum.cpp diff --git a/fbgemm_gpu/FbgemmGpu.cmake b/fbgemm_gpu/FbgemmGpu.cmake index 2ae401ea1..ccf4805cd 100644 --- a/fbgemm_gpu/FbgemmGpu.cmake +++ b/fbgemm_gpu/FbgemmGpu.cmake @@ -471,6 +471,7 @@ set(fbgemm_gpu_sources_static_cpu src/layout_transform_ops/layout_transform_ops_cpu.cpp src/quantize_ops/quantize_ops_cpu.cpp src/quantize_ops/quantize_ops_meta.cpp + src/sparse_ops/sparse_async_cumsum.cpp src/sparse_ops/sparse_ops_cpu.cpp src/sparse_ops/sparse_ops_meta.cpp src/embedding_inplace_ops/embedding_inplace_update_cpu.cpp diff --git a/fbgemm_gpu/src/sparse_ops/common.h b/fbgemm_gpu/src/sparse_ops/common.h new file mode 100644 index 000000000..1cdd8ce9e --- /dev/null +++ b/fbgemm_gpu/src/sparse_ops/common.h @@ -0,0 +1,28 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +using Tensor = at::Tensor; + +namespace fbgemm_gpu { + +namespace { +inline Tensor native_empty_like(const Tensor& self) { + return at::native::empty_like( + self, + c10::optTypeMetaToScalarType(self.options().dtype_opt()), + self.options().layout_opt(), + self.options().device_opt(), + self.options().pinned_memory_opt(), + c10::nullopt); +} + +} // namespace + +}; // namespace fbgemm_gpu diff --git a/fbgemm_gpu/src/sparse_ops/sparse_async_cumsum.cpp b/fbgemm_gpu/src/sparse_ops/sparse_async_cumsum.cpp new file mode 100644 index 000000000..e3f04b58e --- /dev/null +++ b/fbgemm_gpu/src/sparse_ops/sparse_async_cumsum.cpp @@ -0,0 +1,151 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include "common.h" +#include "fbgemm_gpu/sparse_ops.h" +#include "fbgemm_gpu/utils/dispatch_macros.h" +#include "fbgemm_gpu/utils/ops_utils.h" +#include "fbgemm_gpu/utils/tensor_utils.h" + +using Tensor = at::Tensor; + +namespace fbgemm_gpu { + +// 1D exclusive scan: output[i] = input[i-1] + input[i-2] + input[i-3] +// Used as a helper to several functions below. +template +U exclusive_scan_ptrs_cpu( + const int64_t N, + const T* const input, + U* const output) { + U cumsum = 0; + for (const auto i : c10::irange(N)) { + output[i] = cumsum; + cumsum += input[i]; + } + return cumsum; +} + +void asynchronous_exclusive_cumsum_cpu_out(Tensor& t_out, const Tensor& t_in) { + TENSOR_ON_CPU(t_in); + TENSOR_ON_CPU(t_out); + + const auto t_in_contig = t_in.expect_contiguous(); + at::native::resize_(t_out, t_in_contig->sizes(), c10::nullopt); + + FBGEMM_DISPATCH_ALL_TYPES( + t_in_contig->scalar_type(), + "asynchronous_exclusive_cumsum_cpu_kernel", + [&] { + exclusive_scan_ptrs_cpu( + t_in_contig->numel(), + t_in_contig->data_ptr(), + t_out.data_ptr()); + }); +} + +Tensor asynchronous_exclusive_cumsum_cpu(const Tensor& t_in) { + TENSOR_ON_CPU(t_in); + + const auto t_in_contig = t_in.expect_contiguous(); + auto output = native_empty_like(*t_in_contig); + asynchronous_exclusive_cumsum_cpu_out(output, *t_in_contig); + return output; +} + +Tensor asynchronous_inclusive_cumsum_cpu(const Tensor& t_in) { + TENSOR_ON_CPU(t_in); + + const auto t_in_contig = t_in.expect_contiguous(); + auto output = native_empty_like(*t_in_contig); + FBGEMM_DISPATCH_ALL_TYPES( + t_in_contig->scalar_type(), + "asynchronous_inclusive_cumsum_cpu_kernel", + [&] { + scalar_t cumsum = 0; + const auto* input_ptr = t_in_contig->data_ptr(); + const auto N = t_in_contig->numel(); + auto* output_ptr = output.data_ptr(); + + for (const auto i : c10::irange(N)) { + cumsum += input_ptr[i]; + output_ptr[i] = cumsum; + } + }); + return output; +} + +Tensor asynchronous_complete_cumsum_cpu_out(Tensor& t_out, const Tensor& t_in) { + TENSOR_ON_CPU(t_in); + TENSOR_ON_CPU(t_out); + const auto num_dims = t_in.dim(); + TORCH_CHECK(num_dims == 1 || num_dims == 2); + const auto t_in_contig = t_in.expect_contiguous(); + const auto t_out_contig = t_out.expect_contiguous(); + + FBGEMM_DISPATCH_ALL_TYPES( + t_in_contig->scalar_type(), + "asynchronous_complete_cumsum_cpu_kernel", + [&] { + if (num_dims == 1) { + const auto N = t_in_contig->numel(); + t_out.data_ptr()[N] = exclusive_scan_ptrs_cpu( + N, t_in_contig->data_ptr(), t_out.data_ptr()); + } else { + const auto num_vecs = t_in_contig->size(0); + const auto N = t_in_contig->size(1); + at::parallel_for(0, 
num_vecs, 1, [&](int64_t start, int64_t end) { + for (const auto i : c10::irange(start, end)) { + scalar_t* out_ptr = t_out.data_ptr() + i * (N + 1); + out_ptr[N] = exclusive_scan_ptrs_cpu( + N, t_in_contig->data_ptr() + i * N, out_ptr); + } + }); + } + }); + return t_out; +} + +Tensor asynchronous_complete_cumsum_cpu(const Tensor& t_in) { + const auto num_dims = t_in.dim(); + TORCH_CHECK(num_dims == 1 || num_dims == 2); + auto output = num_dims == 1 + ? at::empty({t_in.numel() + 1}, t_in.options()) + : at::empty({t_in.size(0), t_in.size(1) + 1}, t_in.options()); + + return asynchronous_complete_cumsum_cpu_out(output, t_in); +} + +} // namespace fbgemm_gpu + +TORCH_LIBRARY_FRAGMENT(fbgemm, m) { + m.def( + "asynchronous_exclusive_cumsum(Tensor t_in) -> Tensor", + {PT2_COMPLIANT_TAG}); + m.def( + "asynchronous_inclusive_cumsum(Tensor t_in) -> Tensor", + {PT2_COMPLIANT_TAG}); + m.def( + "asynchronous_complete_cumsum(Tensor t_in) -> Tensor", + {PT2_COMPLIANT_TAG}); +} + +TORCH_LIBRARY_IMPL(fbgemm, CPU, m) { + DISPATCH_TO_CPU( + "asynchronous_exclusive_cumsum", + fbgemm_gpu::asynchronous_exclusive_cumsum_cpu); + DISPATCH_TO_CPU( + "asynchronous_inclusive_cumsum", + fbgemm_gpu::asynchronous_inclusive_cumsum_cpu); + DISPATCH_TO_CPU( + "asynchronous_complete_cumsum", + fbgemm_gpu::asynchronous_complete_cumsum_cpu); +} diff --git a/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp b/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp index a80eea05e..7734cc69a 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp +++ b/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp @@ -20,6 +20,7 @@ #include #include +#include "common.h" #include "fbgemm_gpu/sparse_ops.h" #include "fbgemm_gpu/utils/dispatch_macros.h" #include "fbgemm_gpu/utils/ops_utils.h" @@ -128,16 +129,6 @@ Tensor pack_segments_autograd( return PackSegments::apply(t_in, lengths, max_length)[0]; } -Tensor native_empty_like(const Tensor& self) { - return at::native::empty_like( - self, - c10::optTypeMetaToScalarType(self.options().dtype_opt()), - self.options().layout_opt(), - self.options().device_opt(), - self.options().pinned_memory_opt(), - c10::nullopt); -} - template void prefix_sum(const int length, const T* const array, T* const presum) { presum[0] = 0; @@ -1317,115 +1308,6 @@ bucketize_sparse_features_cpu( return {new_lengths, new_indices, new_weights, new_pos}; } -// 1D exclusive scan: output[i] = input[i-1] + input[i-2] + input[i-3] -// Used as a helper to several functions below. 
-template -U exclusive_scan_ptrs_cpu( - const int64_t N, - const T* const input, - U* const output) { - U cumsum = 0; - for (const auto i : c10::irange(N)) { - output[i] = cumsum; - cumsum += input[i]; - } - return cumsum; -} - -void asynchronous_exclusive_cumsum_cpu_out( - at::Tensor& t_out, - const Tensor& t_in) { - TENSOR_ON_CPU(t_in); - TENSOR_ON_CPU(t_out); - - const auto t_in_contig = t_in.expect_contiguous(); - at::native::resize_(t_out, t_in_contig->sizes(), c10::nullopt); - - FBGEMM_DISPATCH_ALL_TYPES( - t_in_contig->scalar_type(), - "asynchronous_exclusive_cumsum_cpu_kernel", - [&] { - exclusive_scan_ptrs_cpu( - t_in_contig->numel(), - t_in_contig->data_ptr(), - t_out.data_ptr()); - }); -} - -Tensor asynchronous_exclusive_cumsum_cpu(const Tensor& t_in) { - TENSOR_ON_CPU(t_in); - - const auto t_in_contig = t_in.expect_contiguous(); - auto output = native_empty_like(*t_in_contig); - asynchronous_exclusive_cumsum_cpu_out(output, *t_in_contig); - return output; -} - -Tensor asynchronous_inclusive_cumsum_cpu(const Tensor& t_in) { - TENSOR_ON_CPU(t_in); - - const auto t_in_contig = t_in.expect_contiguous(); - auto output = native_empty_like(*t_in_contig); - FBGEMM_DISPATCH_ALL_TYPES( - t_in_contig->scalar_type(), - "asynchronous_inclusive_cumsum_cpu_kernel", - [&] { - scalar_t cumsum = 0; - const auto* input_ptr = t_in_contig->data_ptr(); - const auto N = t_in_contig->numel(); - auto* output_ptr = output.data_ptr(); - - for (const auto i : c10::irange(N)) { - cumsum += input_ptr[i]; - output_ptr[i] = cumsum; - } - }); - return output; -} - -at::Tensor asynchronous_complete_cumsum_cpu_out( - at::Tensor& t_out, - const at::Tensor& t_in) { - TENSOR_ON_CPU(t_in); - TENSOR_ON_CPU(t_out); - const auto num_dims = t_in.dim(); - TORCH_CHECK(num_dims == 1 || num_dims == 2); - const auto t_in_contig = t_in.expect_contiguous(); - const auto t_out_contig = t_out.expect_contiguous(); - - FBGEMM_DISPATCH_ALL_TYPES( - t_in_contig->scalar_type(), - "asynchronous_complete_cumsum_cpu_kernel", - [&] { - if (num_dims == 1) { - const auto N = t_in_contig->numel(); - t_out.data_ptr()[N] = exclusive_scan_ptrs_cpu( - N, t_in_contig->data_ptr(), t_out.data_ptr()); - } else { - const auto num_vecs = t_in_contig->size(0); - const auto N = t_in_contig->size(1); - at::parallel_for(0, num_vecs, 1, [&](int64_t start, int64_t end) { - for (const auto i : c10::irange(start, end)) { - scalar_t* out_ptr = t_out.data_ptr() + i * (N + 1); - out_ptr[N] = exclusive_scan_ptrs_cpu( - N, t_in_contig->data_ptr() + i * N, out_ptr); - } - }); - } - }); - return t_out; -} - -Tensor asynchronous_complete_cumsum_cpu(const Tensor& t_in) { - const auto num_dims = t_in.dim(); - TORCH_CHECK(num_dims == 1 || num_dims == 2); - auto output = num_dims == 1 - ? at::empty({t_in.numel() + 1}, t_in.options()) - : at::empty({t_in.size(0), t_in.size(1) + 1}, t_in.options()); - - return asynchronous_complete_cumsum_cpu_out(output, t_in); -} - template void reorder_batched_ad_lengths_( const Tensor& cat_ad_lengths, @@ -3100,15 +2982,6 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { "block_bucketize_sparse_features_inference(Tensor lengths, Tensor indices, bool bucketize_pos, bool sequence, Tensor block_sizes, SymInt my_size, Tensor? weights=None, Tensor? batch_size_per_feature=None, SymInt max_B= -1, Tensor[]? 
block_bucketize_pos=None, bool return_bucket_mapping=False, bool keep_orig_idx=False) -> (Tensor, Tensor, Tensor?, Tensor?, Tensor?, Tensor?)"); m.def( "bucketize_sparse_features(Tensor lengths, Tensor indices, bool bucketize_pos, SymInt my_size, Tensor? weights=None) -> (Tensor, Tensor, Tensor?, Tensor?)"); - m.def( - "asynchronous_exclusive_cumsum(Tensor t_in) -> Tensor", - {PT2_COMPLIANT_TAG}); - m.def( - "asynchronous_inclusive_cumsum(Tensor t_in) -> Tensor", - {PT2_COMPLIANT_TAG}); - m.def( - "asynchronous_complete_cumsum(Tensor t_in) -> Tensor", - {PT2_COMPLIANT_TAG}); m.def( "reorder_batched_sequence_embeddings(Tensor cat_sequence_embeddings_offsets, Tensor cat_sequence_embeddings, Tensor reordered_cat_sequence_embeddings_offsets, Tensor batch_offsets, SymInt num_items_in_batch) -> Tensor"); m.def( @@ -3214,15 +3087,6 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) { fbgemm_gpu::block_bucketize_sparse_features_inference_cpu); DISPATCH_TO_CPU( "bucketize_sparse_features", fbgemm_gpu::bucketize_sparse_features_cpu); - DISPATCH_TO_CPU( - "asynchronous_exclusive_cumsum", - fbgemm_gpu::asynchronous_exclusive_cumsum_cpu); - DISPATCH_TO_CPU( - "asynchronous_inclusive_cumsum", - fbgemm_gpu::asynchronous_inclusive_cumsum_cpu); - DISPATCH_TO_CPU( - "asynchronous_complete_cumsum", - fbgemm_gpu::asynchronous_complete_cumsum_cpu); DISPATCH_TO_CPU( "reorder_batched_ad_lengths", fbgemm_gpu::reorder_batched_ad_lengths_cpu); DISPATCH_TO_CPU( From f2a015602f053efffbebbca287d66d4440986490 Mon Sep 17 00:00:00 2001 From: Dan Zimmerman Date: Mon, 30 Sep 2024 11:40:18 -0700 Subject: [PATCH 11/48] Try to use triton.language.extra.libdevice when possible (#3196) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3196 X-link: https://github.com/pytorch/pytorch/pull/136997 X-link: https://github.com/facebookresearch/generative-recommenders/pull/90 X-link: https://github.com/facebookresearch/FBGEMM/pull/294 In view of https://github.com/triton-lang/triton/pull/3825 we should try to use `triton.language.extra.libdevice` instead of `triton.language.extra.cuda.libdevice`. Reviewed By: bertmaher, karthik-man Differential Revision: D63583965 fbshipit-source-id: d32f35f7524d45c1e7c95c095144ad27f16eaa5a --- .../experimental/gen_ai/test/kv_cache/rope_padded.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/experimental/gen_ai/test/kv_cache/rope_padded.py b/fbgemm_gpu/experimental/gen_ai/test/kv_cache/rope_padded.py index ab728db6f..4a7211c68 100644 --- a/fbgemm_gpu/experimental/gen_ai/test/kv_cache/rope_padded.py +++ b/fbgemm_gpu/experimental/gen_ai/test/kv_cache/rope_padded.py @@ -38,8 +38,12 @@ # pyre-fixme[21]: Could not find name `pow` in `triton.language.math`. 
from triton.language.math import pow except ImportError: - # @manual=//triton:triton - from triton.language.extra.cuda.libdevice import pow + try: + # @manual=//triton:triton + from triton.language.extra.libdevice import pow + except ImportError: + # @manual=//triton:triton + from triton.language.extra.cuda.libdevice import pow _INTERNAL_DTYPE_MAP: Dict[str, int] = {"": 0, "f32": 1, "f64": 2} From 0d5acd1ea92a127f64446d1b308f9bb24bad0f95 Mon Sep 17 00:00:00 2001 From: sryap <17482891+sryap@users.noreply.github.com> Date: Mon, 30 Sep 2024 12:06:34 -0700 Subject: [PATCH 12/48] Add the block_bucketize_sparse_features docstring (#3191) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/287 A title Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3191 Reviewed By: q10, iamzainhuda, spcyppt Differential Revision: D63583026 Pulled By: sryap fbshipit-source-id: 436ea2060e86c4722aadff07a98062b918005606 --- .../src/fbgemm_gpu-python-api/sparse_ops.rst | 2 + fbgemm_gpu/fbgemm_gpu/docs/sparse_ops.py | 147 ++++++++++++++++++ 2 files changed, 149 insertions(+) diff --git a/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/sparse_ops.rst b/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/sparse_ops.rst index 44d9e34ce..b95b6dda4 100644 --- a/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/sparse_ops.rst +++ b/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/sparse_ops.rst @@ -16,3 +16,5 @@ Sparse Operators .. autofunction:: torch.ops.fbgemm.segment_sum_csr .. autofunction:: torch.ops.fbgemm.keyed_jagged_index_select_dim1 + +.. autofunction:: torch.ops.fbgemm.block_bucketize_sparse_features \ No newline at end of file diff --git a/fbgemm_gpu/fbgemm_gpu/docs/sparse_ops.py b/fbgemm_gpu/fbgemm_gpu/docs/sparse_ops.py index 5dffc308f..333d5e5da 100644 --- a/fbgemm_gpu/fbgemm_gpu/docs/sparse_ops.py +++ b/fbgemm_gpu/fbgemm_gpu/docs/sparse_ops.py @@ -319,3 +319,150 @@ tensor([4, 4, 5, 9, 9, 7], device='cuda:0')] """, ) + +add_docs( + torch.ops.fbgemm.block_bucketize_sparse_features, + """ +block_bucketize_sparse_features(lengths, indices, bucketize_pos, sequence, block_sizes, my_size, weights=None, batch_size_per_feature=None, max_B= -1, block_bucketize_pos=None, keep_orig_idx=False) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]] + +Preprocess sparse features by partitioning sparse features into multiple +buckets. Every feature is split into the same number of buckets, but the bucket +sizes (widths) for the different features can be different. Moreover, the +bucket sizes within each feature can be different. + +Args: + lengths (Tensor): The lengths of the sparse features. The tensor contains + the lengths of each sample in a batch and each feature. Shape is `B * + T` where `B` is the batch size and `T` is the number of features + + indices (Tensor): The sparse data. Only support integer types. Shape is the + sum of `lengths` + + bucketize_pos (bool): If True, return the original relative indices within + a sample. For example, `indices = [9, 8, 2, 1, 0, 8, 9]` and `lengths = + [3, 4]`. The original relative indices within a sample for the indices + are `[0, 1, 2, 0, 1, 2, 3]` + + sequence (bool): If True, return the new indices positions in the original + indices positions (the tensor is called `unbucketize_permute_data`). + + block_sizes (Tensor): This tensor is used for the case where the bucket + size within a feature is uniform (i.e., when + `block_bucketize_pos=None`). The tensor contains bucket sizes (i.e., + bucket widths) for each feature. 
`block_sizes[t]` represents the + bucket size of feature `t`. Shape is the number of features. + + my_size (int): The number of buckets for each feature. Note that every + feature has the same number of buckets. + + weights (Optional[Tensor] = None): An optional float tensor that will be + bucketized the same way as `indices`. This tensor must have the same + shape as `indices` + + batch_size_per_feature (Optional[Tensor] = None): An optional tensor that + contains batch sizes for different features. If not None, batch sizes + are not uniform among features. Otherwise, the operator will assume + that the batch size is uniform and infer it from the `lengths` and + `block_sizes` tensors + + max_B (int = -1): The max batch size. Must be set if + `batch_size_per_feature` is not None + + block_bucketize_pos (Optional[List[Tensor]] = None): The input is used for + non-uniform bucket sizes within a feature. `block_bucketize_pos` is a + list of tensors. Each tensor contains the range offsets of buckets for + each feature. These range offsets are equivalent to the complete + cumulative sum of the bucket sizes. For example, `[0, 4, 20]` represents + two buckets. The first bucket size is `(4 - 0) = 4`, and the second + bucket size is `(20 - 4) = 16`. The length of `block_bucketize_pos` + must be equal to the number of features. + + keep_orig_idx (bool = False): If True, return original indices instead of + the relative indices within each bucket + +Return: + A tuple of tensors containing + + (1) Bucketized lengths. Shape is `lengths.num() * my_size`. + + (2) Bucketized indices. Same shape as `indices`. + + (3) Bucketized weights or None if `weights` is None. Same shape as + `indices`. + + (4) Bucketized positions or None if `bucketize_pos=False`. Same shape as + `indices`. + + (5) `unbucketize_permute` or None if `sequence=False`. Same shape as + `indices` + +**Example**: + + >>> # Generate input example. Batch size = 2. 
Number of features = 4 + >>> lengths = torch.tensor([0, 2, 1, 3, 2, 3, 3, 1], dtype=torch.int, device="cuda") + >>> indices = torch.tensor([3, 4, 15, 11, 28, 29, 1, 10, 11, 12, 13, 11, 22, 20, 20], dtype=torch.int, device="cuda") + >>> block_sizes = torch.tensor([[5, 15, 10, 20]], dtype=torch.int, device="cuda") + >>> my_size = 2 # Number of buckets + >>> # Invoke with keep_orig_idx=False, bucketize_pos=False, and + >>> # sequence=False + >>> torch.ops.fbgemm.block_bucketize_sparse_features( + >>> lengths, + >>> indices, + >>> bucketize_pos=False, + >>> sequence=False, + >>> block_sizes=block_sizes, + >>> my_size=my_size, + >>> keep_orig_idx=False) + >>> # The first 8 values in the returned lengths are the lengths for bucket + >>> # 0 and the rests are the legths for bucket 1 + (tensor([0, 2, 0, 1, 1, 0, 1, 0, 0, 0, 1, 2, 1, 3, 2, 1], device='cuda:0', + dtype=torch.int32), + tensor([ 3, 4, 11, 1, 11, 0, 13, 14, 0, 1, 2, 3, 2, 0, 0], + device='cuda:0', dtype=torch.int32), + None, + None, + None) + >>> # Invoke with keep_orig_idx=True, bucketize_pos=True, and + >>> # sequence=True + >>> torch.ops.fbgemm.block_bucketize_sparse_features( + >>> lengths, + >>> indices, + >>> bucketize_pos=True, + >>> sequence=True, + >>> block_sizes=block_sizes, + >>> my_size=my_size, + >>> keep_orig_idx=True) + (tensor([0, 2, 0, 1, 1, 0, 1, 0, 0, 0, 1, 2, 1, 3, 2, 1], device='cuda:0', + dtype=torch.int32), + tensor([ 3, 4, 11, 1, 11, 15, 28, 29, 10, 11, 12, 13, 22, 20, 20], + device='cuda:0', dtype=torch.int32), + None, + tensor([0, 1, 0, 0, 0, 0, 1, 2, 1, 0, 1, 2, 1, 2, 0], device='cuda:0', + dtype=torch.int32), + tensor([ 0, 1, 5, 2, 6, 7, 3, 8, 9, 10, 11, 4, 12, 13, 14], + device='cuda:0', dtype=torch.int32)) + >>> # Invoke with block_bucketize_pos + >>> block_bucketize_pos = [ + >>> torch.tensor([0, 2, 8], dtype=torch.int), + >>> torch.tensor([0, 5, 10], dtype=torch.int), + >>> torch.tensor([0, 7, 12], dtype=torch.int), + >>> torch.tensor([0, 2, 16], dtype=torch.int), + >>> ] + >>> torch.ops.fbgemm.block_bucketize_sparse_features( + >>> lengths, + >>> indices, + >>> bucketize_pos=False, + >>> sequence=False, + >>> block_sizes=block_sizes, + >>> my_size=my_size, + >>> block_bucketize_pos=block_bucketize_pos, + >>> keep_orig_idx=False) + (tensor([0, 0, 0, 1, 1, 1, 2, 1, 0, 2, 1, 2, 1, 2, 1, 0], device='cuda:0', + dtype=torch.int32), + tensor([14, 1, 6, 11, 10, 10, 1, 2, 7, 5, 14, 3, 4, 6, 9], + device='cuda:0', dtype=torch.int32), + None, + None, + None) + """, +) From e4384802dca6d6732df15b1c01e576ee7c9a89eb Mon Sep 17 00:00:00 2001 From: Benson Ma Date: Mon, 30 Sep 2024 13:14:57 -0700 Subject: [PATCH 13/48] Redefine FBGEMM targets with gpu_cpp_library [19/N] (#3192) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/288 Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3192 - Redefine layout transform ops targets using `gpu_cpp_library` Reviewed By: spcyppt Differential Revision: D63591232 fbshipit-source-id: 829cf15a7b90bdaabb9b431c49cb1e91d4b63f4b --- fbgemm_gpu/src/layout_transform_ops/layout_transform_ops.cu | 4 ++-- .../src/layout_transform_ops/layout_transform_ops_cpu.cpp | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/fbgemm_gpu/src/layout_transform_ops/layout_transform_ops.cu b/fbgemm_gpu/src/layout_transform_ops/layout_transform_ops.cu index e3c88b101..6870cdfb9 100644 --- a/fbgemm_gpu/src/layout_transform_ops/layout_transform_ops.cu +++ b/fbgemm_gpu/src/layout_transform_ops/layout_transform_ops.cu @@ -9,19 +9,19 @@ // clang-format off #include 
"fbgemm_gpu/utils/cub_namespace_prefix.cuh" #include -#include "fbgemm_gpu/utils/dispatch_macros.h" #include "fbgemm_gpu/utils/cub_namespace_postfix.cuh" // clang-format on #include +#include #include #include #include #include #include -#include "ATen/Parallel.h" #include "fbgemm_gpu/layout_transform_ops.cuh" #include "fbgemm_gpu/sparse_ops.h" +#include "fbgemm_gpu/utils/dispatch_macros.h" #include "fbgemm_gpu/utils/ops_utils.h" #include "fbgemm_gpu/utils/tensor_utils.h" diff --git a/fbgemm_gpu/src/layout_transform_ops/layout_transform_ops_cpu.cpp b/fbgemm_gpu/src/layout_transform_ops/layout_transform_ops_cpu.cpp index e9159f37a..a01de3a67 100644 --- a/fbgemm_gpu/src/layout_transform_ops/layout_transform_ops_cpu.cpp +++ b/fbgemm_gpu/src/layout_transform_ops/layout_transform_ops_cpu.cpp @@ -7,9 +7,9 @@ */ #include +#include #include #include -#include "ATen/Parallel.h" #include "fbgemm_gpu/utils/dispatch_macros.h" #include "fbgemm_gpu/utils/ops_utils.h" From 7a881f298a8fe23ac4e18151b1b32e90cc0fb212 Mon Sep 17 00:00:00 2001 From: Nicholas Ormrod Date: Mon, 30 Sep 2024 13:55:50 -0700 Subject: [PATCH 14/48] Deshim coro in fbcode/deeplearning (#3198) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/296 Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3198 The following rules were deshimmed: ``` //folly/experimental/coro:accumulate -> //folly/coro:accumulate //folly/experimental/coro:async_generator -> //folly/coro:async_generator //folly/experimental/coro:async_pipe -> //folly/coro:async_pipe //folly/experimental/coro:async_scope -> //folly/coro:async_scope //folly/experimental/coro:async_stack -> //folly/coro:async_stack //folly/experimental/coro:baton -> //folly/coro:baton //folly/experimental/coro:blocking_wait -> //folly/coro:blocking_wait //folly/experimental/coro:collect -> //folly/coro:collect //folly/experimental/coro:concat -> //folly/coro:concat //folly/experimental/coro:coroutine -> //folly/coro:coroutine //folly/experimental/coro:current_executor -> //folly/coro:current_executor //folly/experimental/coro:detach_on_cancel -> //folly/coro:detach_on_cancel //folly/experimental/coro:detail_barrier -> //folly/coro:detail_barrier //folly/experimental/coro:detail_barrier_task -> //folly/coro:detail_barrier_task //folly/experimental/coro:detail_current_async_frame -> //folly/coro:detail_current_async_frame //folly/experimental/coro:detail_helpers -> //folly/coro:detail_helpers //folly/experimental/coro:detail_malloc -> //folly/coro:detail_malloc //folly/experimental/coro:detail_manual_lifetime -> //folly/coro:detail_manual_lifetime //folly/experimental/coro:detail_traits -> //folly/coro:detail_traits //folly/experimental/coro:filter -> //folly/coro:filter //folly/experimental/coro:future_util -> //folly/coro:future_util //folly/experimental/coro:generator -> //folly/coro:generator //folly/experimental/coro:gmock_helpers -> //folly/coro:gmock_helpers //folly/experimental/coro:gtest_helpers -> //folly/coro:gtest_helpers //folly/experimental/coro:inline_task -> //folly/coro:inline_task //folly/experimental/coro:invoke -> //folly/coro:invoke //folly/experimental/coro:merge -> //folly/coro:merge //folly/experimental/coro:mutex -> //folly/coro:mutex //folly/experimental/coro:promise -> //folly/coro:promise //folly/experimental/coro:result -> //folly/coro:result //folly/experimental/coro:retry -> //folly/coro:retry //folly/experimental/coro:rust_adaptors -> //folly/coro:rust_adaptors //folly/experimental/coro:scope_exit -> //folly/coro:scope_exit 
//folly/experimental/coro:shared_lock -> //folly/coro:shared_lock //folly/experimental/coro:shared_mutex -> //folly/coro:shared_mutex //folly/experimental/coro:sleep -> //folly/coro:sleep //folly/experimental/coro:small_unbounded_queue -> //folly/coro:small_unbounded_queue //folly/experimental/coro:task -> //folly/coro:task //folly/experimental/coro:timed_wait -> //folly/coro:timed_wait //folly/experimental/coro:timeout -> //folly/coro:timeout //folly/experimental/coro:traits -> //folly/coro:traits //folly/experimental/coro:transform -> //folly/coro:transform //folly/experimental/coro:unbounded_queue -> //folly/coro:unbounded_queue //folly/experimental/coro:via_if_async -> //folly/coro:via_if_async //folly/experimental/coro:with_async_stack -> //folly/coro:with_async_stack //folly/experimental/coro:with_cancellation -> //folly/coro:with_cancellation //folly/experimental/coro:bounded_queue -> //folly/coro:bounded_queue //folly/experimental/coro:shared_promise -> //folly/coro:shared_promise //folly/experimental/coro:cleanup -> //folly/coro:cleanup //folly/experimental/coro:auto_cleanup_fwd -> //folly/coro:auto_cleanup_fwd //folly/experimental/coro:auto_cleanup -> //folly/coro:auto_cleanup ``` The following headers were deshimmed: ``` folly/experimental/coro/Accumulate.h -> folly/coro/Accumulate.h folly/experimental/coro/Accumulate-inl.h -> folly/coro/Accumulate-inl.h folly/experimental/coro/AsyncGenerator.h -> folly/coro/AsyncGenerator.h folly/experimental/coro/AsyncPipe.h -> folly/coro/AsyncPipe.h folly/experimental/coro/AsyncScope.h -> folly/coro/AsyncScope.h folly/experimental/coro/AsyncStack.h -> folly/coro/AsyncStack.h folly/experimental/coro/Baton.h -> folly/coro/Baton.h folly/experimental/coro/BlockingWait.h -> folly/coro/BlockingWait.h folly/experimental/coro/Collect.h -> folly/coro/Collect.h folly/experimental/coro/Collect-inl.h -> folly/coro/Collect-inl.h folly/experimental/coro/Concat.h -> folly/coro/Concat.h folly/experimental/coro/Concat-inl.h -> folly/coro/Concat-inl.h folly/experimental/coro/Coroutine.h -> folly/coro/Coroutine.h folly/experimental/coro/CurrentExecutor.h -> folly/coro/CurrentExecutor.h folly/experimental/coro/DetachOnCancel.h -> folly/coro/DetachOnCancel.h folly/experimental/coro/detail/Barrier.h -> folly/coro/detail/Barrier.h folly/experimental/coro/detail/BarrierTask.h -> folly/coro/detail/BarrierTask.h folly/experimental/coro/detail/CurrentAsyncFrame.h -> folly/coro/detail/CurrentAsyncFrame.h folly/experimental/coro/detail/Helpers.h -> folly/coro/detail/Helpers.h folly/experimental/coro/detail/Malloc.h -> folly/coro/detail/Malloc.h folly/experimental/coro/detail/ManualLifetime.h -> folly/coro/detail/ManualLifetime.h folly/experimental/coro/detail/Traits.h -> folly/coro/detail/Traits.h folly/experimental/coro/Filter.h -> folly/coro/Filter.h folly/experimental/coro/Filter-inl.h -> folly/coro/Filter-inl.h folly/experimental/coro/FutureUtil.h -> folly/coro/FutureUtil.h folly/experimental/coro/Generator.h -> folly/coro/Generator.h folly/experimental/coro/GmockHelpers.h -> folly/coro/GmockHelpers.h folly/experimental/coro/GtestHelpers.h -> folly/coro/GtestHelpers.h folly/experimental/coro/detail/InlineTask.h -> folly/coro/detail/InlineTask.h folly/experimental/coro/Invoke.h -> folly/coro/Invoke.h folly/experimental/coro/Merge.h -> folly/coro/Merge.h folly/experimental/coro/Merge-inl.h -> folly/coro/Merge-inl.h folly/experimental/coro/Mutex.h -> folly/coro/Mutex.h folly/experimental/coro/Promise.h -> folly/coro/Promise.h folly/experimental/coro/Result.h -> 
folly/coro/Result.h folly/experimental/coro/Retry.h -> folly/coro/Retry.h folly/experimental/coro/RustAdaptors.h -> folly/coro/RustAdaptors.h folly/experimental/coro/ScopeExit.h -> folly/coro/ScopeExit.h folly/experimental/coro/SharedLock.h -> folly/coro/SharedLock.h folly/experimental/coro/SharedMutex.h -> folly/coro/SharedMutex.h folly/experimental/coro/Sleep.h -> folly/coro/Sleep.h folly/experimental/coro/Sleep-inl.h -> folly/coro/Sleep-inl.h folly/experimental/coro/SmallUnboundedQueue.h -> folly/coro/SmallUnboundedQueue.h folly/experimental/coro/Task.h -> folly/coro/Task.h folly/experimental/coro/TimedWait.h -> folly/coro/TimedWait.h folly/experimental/coro/Timeout.h -> folly/coro/Timeout.h folly/experimental/coro/Timeout-inl.h -> folly/coro/Timeout-inl.h folly/experimental/coro/Traits.h -> folly/coro/Traits.h folly/experimental/coro/Transform.h -> folly/coro/Transform.h folly/experimental/coro/Transform-inl.h -> folly/coro/Transform-inl.h folly/experimental/coro/UnboundedQueue.h -> folly/coro/UnboundedQueue.h folly/experimental/coro/ViaIfAsync.h -> folly/coro/ViaIfAsync.h folly/experimental/coro/WithAsyncStack.h -> folly/coro/WithAsyncStack.h folly/experimental/coro/WithCancellation.h -> folly/coro/WithCancellation.h folly/experimental/coro/BoundedQueue.h -> folly/coro/BoundedQueue.h folly/experimental/coro/SharedPromise.h -> folly/coro/SharedPromise.h folly/experimental/coro/Cleanup.h -> folly/coro/Cleanup.h folly/experimental/coro/AutoCleanup-fwd.h -> folly/coro/AutoCleanup-fwd.h folly/experimental/coro/AutoCleanup.h -> folly/coro/AutoCleanup.h ``` This is a codemod. It was automatically generated and will be landed once it is approved and tests are passing in sandcastle. You have been added as a reviewer by Sentinel or Butterfly. Autodiff project: dcoro Autodiff partition: fbcode.deeplearning Autodiff bookmark: ad.dcoro.fbcode.deeplearning Reviewed By: dtolnay Differential Revision: D62684379 fbshipit-source-id: b1b19e549838296b68889909880c3f7dba3a1a50 --- .../ps_split_embeddings_cache/ps_table_batched_embeddings.h | 4 ++-- .../kv_db_table_batched_embeddings.cpp | 2 +- .../kv_db_table_batched_embeddings.h | 2 +- .../ssd_table_batched_embeddings.h | 6 +++--- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/fbgemm_gpu/src/ps_split_embeddings_cache/ps_table_batched_embeddings.h b/fbgemm_gpu/src/ps_split_embeddings_cache/ps_table_batched_embeddings.h index 0498bda96..9e091f1c1 100644 --- a/fbgemm_gpu/src/ps_split_embeddings_cache/ps_table_batched_embeddings.h +++ b/fbgemm_gpu/src/ps_split_embeddings_cache/ps_table_batched_embeddings.h @@ -9,8 +9,8 @@ #pragma once #include "../ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h" -#include -#include +#include +#include #include "mvai_infra/experimental/ps_training/tps_client/TrainingParameterServiceClient.h" namespace ps { diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp index fdc91b0e9..ca5d3b6cb 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp @@ -7,8 +7,8 @@ */ #include "kv_db_table_batched_embeddings.h" +#include #include -#include #include #include "common/time/Time.h" #include "kv_db_cuda_utils.h" diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h index 
93495f2da..d6e0c9180 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h @@ -37,7 +37,7 @@ #include #include -#include +#include #include "fbgemm_gpu/split_embeddings_cache/cachelib_cache.h" #include "fbgemm_gpu/utils/dispatch_macros.h" diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h index 23b6f1e89..87924eefe 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h @@ -11,10 +11,10 @@ #include #include +#include +#include +#include #include -#include -#include -#include #include #ifdef FBGEMM_FBCODE #include "common/strings/UUID.h" From 07b36b78f54fa778d11910205ef0d315e65a4511 Mon Sep 17 00:00:00 2001 From: Minhua Chen Date: Mon, 30 Sep 2024 15:35:26 -0700 Subject: [PATCH 15/48] ensemble_mode consolidation (fbgemm) (#3197) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3197 X-link: https://github.com/facebookresearch/FBGEMM/pull/295 ensemble_mode consolidation (fbgemm) Reviewed By: q10, csmiler Differential Revision: D63634421 fbshipit-source-id: f01ca98b1c157fddb1592a9f7c3a0ae45722e542 --- ...plit_table_batched_embeddings_ops_training.py | 16 +++++++--------- .../tbe/training/backward_optimizers_test.py | 15 ++++----------- 2 files changed, 11 insertions(+), 20 deletions(-) diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py index 71c9325cf..f41a49d19 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py @@ -150,6 +150,7 @@ class EnsembleModeDefinition: step_ema: float = 10000 step_swap: float = 10000 step_start: float = 0 + step_ema_coef: float = 0.6 step_mode: StepMode = StepMode.USE_ITER @@ -457,8 +458,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module): Adam. Note that default is different from torch.nn.optim.Adagrad default of 1e-10 - momentum (float = 0.9): Momentum used by LARS-SGD and - ENSEMBLE_ROWWISE_ADAGRAD + momentum (float = 0.9): Momentum used by LARS-SGD weight_decay (float = 0.0): Weight decay used by LARS-SGD, LAMB, ADAM, and rowwise-Adagrad. 
@@ -924,11 +924,8 @@ def __init__( # noqa C901 if ensemble_mode is None: ensemble_mode = EnsembleModeDefinition() - self._ensemble_mode: Dict[str, int] = { - "step_ema": int(ensemble_mode.step_ema), - "step_swap": int(ensemble_mode.step_swap), - "step_start": int(ensemble_mode.step_start), - "step_mode": int(ensemble_mode.step_mode.value), + self._ensemble_mode: Dict[str, float] = { + key: float(fval) for key, fval in ensemble_mode.__dict__.items() } if counter_based_regularization is None: @@ -1002,6 +999,7 @@ def __init__( # noqa C901 if ( optimizer_state_dtypes is None or "momentum1" not in optimizer_state_dtypes + or optimizer == OptimType.ENSEMBLE_ROWWISE_ADAGRAD ) else optimizer_state_dtypes["momentum1"].as_dtype() ) @@ -1938,7 +1936,7 @@ def forward( # noqa: C901 raise ValueError(f"Invalid OptimType: {self.optimizer}") - def ensemble_and_swap(self, ensemble_mode: Dict[str, int]) -> None: + def ensemble_and_swap(self, ensemble_mode: Dict[str, float]) -> None: should_ema = self.iter.item() % int(ensemble_mode["step_ema"]) == 0 should_swap = self.iter.item() % int(ensemble_mode["step_swap"]) == 0 if should_ema or should_swap: @@ -1947,7 +1945,7 @@ def ensemble_and_swap(self, ensemble_mode: Dict[str, int]) -> None: for i in range(len(self.embedding_specs)): if should_ema: coef_ema = ( - self.optimizer_args.momentum + ensemble_mode["step_ema_coef"] if self.iter.item() > int(ensemble_mode["step_start"]) else 0.0 ) diff --git a/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py b/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py index b1809dc1b..2db48594d 100644 --- a/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py +++ b/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py @@ -26,6 +26,7 @@ CounterBasedRegularizationDefinition, CounterWeightDecayMode, CowClipDefinition, + EnsembleModeDefinition, GradSumDecay, LearningRateMode, SplitTableBatchedEmbeddingBagsCodegen, @@ -34,13 +35,6 @@ WeightDecayMode, ) -try: - from fbgemm_gpu.split_table_batched_embeddings_ops_training import ( - EnsembleModeDefinition, - ) -except ImportError: - EnsembleModeDefinition = None - from fbgemm_gpu.tbe.utils import ( b_indices, get_table_batched_offsets_from_dense, @@ -315,23 +309,22 @@ def execute_backward_optimizers_( # noqa C901 optimizer_kwargs["eta"] = eta if optimizer == OptimType.ENSEMBLE_ROWWISE_ADAGRAD: - (eps, step_ema, step_swap, step_start, step_mode, momentum) = ( + (eps, step_ema, step_swap, step_start, step_mode) = ( 1e-4, 1.0, 1.0, 0.0, StepMode.USE_ITER, - 0.8, ) optimizer_kwargs["eps"] = eps + optimizer_kwargs["optimizer_state_dtypes"] = optimizer_state_dtypes optimizer_kwargs["ensemble_mode"] = EnsembleModeDefinition( step_ema=step_ema, step_swap=step_swap, step_start=step_start, + step_ema_coef=momentum, step_mode=step_mode, ) - optimizer_kwargs["momentum"] = momentum - optimizer_kwargs["optimizer_state_dtypes"] = optimizer_state_dtypes cc = emb_op( embedding_specs=[ From d27acbd3611827ff457e70d809c6c167992d2104 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Mon, 30 Sep 2024 15:52:32 -0700 Subject: [PATCH 16/48] Make some fbgemm fp8 triton ops pt2 friendly (#3188) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/283 Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3188 Make some fbgemm fp8 triton ops pt2 friendly.. # What this diff tries to do * stop using TensorWrapper and tl.reinterpret * Remove the use of triton_heuristics for _kernel_matmul_fp8_row # What this diff won't help: * triton_herustics use cases of EVEN_K. 
One option is to just merge that into the autotuning configs # need to do in the future: * Update other ops, like quantize_fp8_row. * Update documentation. Feels pretty outdated, and some still reference to TensorWrapper. Reviewed By: henrylhtsang Differential Revision: D63560103 fbshipit-source-id: ca56b6fa3c041d130f79945b0c3d114a4e90e685 --- .../experimental/gemm/triton_gemm/fp8_gemm.py | 90 +++++++++++-------- 1 file changed, 51 insertions(+), 39 deletions(-) diff --git a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py index 07765fa21..9df479e45 100644 --- a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +++ b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py @@ -6,7 +6,7 @@ # pyre-unsafe import logging -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union import torch import triton # @manual @@ -43,7 +43,7 @@ def get_fp8_constants() -> Tuple[torch.dtype, tl.dtype, float, float]: return pt_fp8_dtype, tl_fp8_dtype, torch.finfo(pt_fp8_dtype).max, 1e-12 -def convert_fp8_type(tensor, dtype) -> triton.TensorWrapper: +def reinterpret_fp8_type(tensor: torch.Tensor, dtype: tl.dtype) -> TensorWrapper: """ Converts tensor to triton fp8 type. @@ -213,11 +213,6 @@ def get_configs_io_bound() -> List[Config]: "k_key", ], ) -@triton.heuristics( - { - "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0, - } -) @triton.jit def _kernel_matmul_fp8_row( A_ptr, @@ -246,7 +241,6 @@ def _kernel_matmul_fp8_row( BLOCK_K: tl.constexpr, GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, - EVEN_K: tl.constexpr, USE_BIAS: tl.constexpr, AB_DTYPE: tl.constexpr, NUM_SMS: tl.constexpr, @@ -964,7 +958,7 @@ def get_tma_descriptor_kernel_param(self, name): return self.cuda_descriptors[name] -@torch.library.custom_op("triton::matmul_fp8_row", mutates_args=()) +@torch._library.triton_op("triton::matmul_fp8_row", mutates_args=()) def matmul_fp8_row( a: torch.Tensor, b: torch.Tensor, @@ -995,15 +989,15 @@ def matmul_fp8_row( torch.Tensor: [M, N] Output tensor a @ b / (a_scale[:, None] * b_scale[None, :]) """ # Get datatypes and constants to use. - _, tl_dtype, _, _ = get_fp8_constants() + pt_fp8_dtype, _, _, _ = get_fp8_constants() # Handle 3D+ a shape a_shape = a.shape a = a.view(-1, a.size(-1)) - # Reinterpret inputs into proper triton fp8 dtype. - a_tl = convert_fp8_type(a, tl_dtype) - b_tl = convert_fp8_type(b, tl_dtype) + # View inputs into proper torch fp8 dtype. 
+ assert a.dtype == pt_fp8_dtype + assert b.dtype == pt_fp8_dtype M, N, K, m_key, n_key, k_key, c, c_dtype_triton, dot_out_dtype_triton, device = ( - prep_matmul(a_tl, b_tl, dot_out_dtype) + prep_matmul(a, b, dot_out_dtype) ) output_shape = a_shape[:-1] + (N,) @@ -1049,22 +1043,22 @@ def persistent_grid_tma(META): nonlocal desc_helper desc_helper.fill_2d_tma_descriptor( "a", - a_tl.data_ptr(), + a.data_ptr(), M, K, META["BLOCK_M"], META["BLOCK_K"], - a_tl.element_size(), + a.element_size(), ) desc_helper.fill_2d_tma_descriptor( "b", - b_tl.data_ptr(), + b.data_ptr(), N, K, META["BLOCK_N"], META["BLOCK_K"], - b_tl.element_size(), + b.element_size(), ) desc_helper.fill_2d_tma_descriptor( "c", @@ -1111,8 +1105,10 @@ def persistent_grid_tma(META): desc_b_scale = desc_helper.get_tma_descriptor_kernel_param("b_scale") desc_bias = desc_helper.get_tma_descriptor_kernel_param("bias") - # pyre-ignore[28]: - _kernel_matmul_fp8_row_tma_persistent[persistent_grid_tma]( + # pyre-ignore + torch._library.capture_triton(_kernel_matmul_fp8_row_tma_persistent)[ + persistent_grid_tma + ]( desc_a, desc_b, desc_c, @@ -1141,9 +1137,9 @@ def persistent_grid_tma(META): USE_BIAS=bias is not None, ) elif imprecise_acc: - _kernel_matmul_fp8_row_imprecise_acc[grid]( - a_tl, - b_tl, + torch._library.capture_triton(_kernel_matmul_fp8_row_imprecise_acc)[grid]( + a, + b, c, M, N, @@ -1168,9 +1164,9 @@ def persistent_grid_tma(META): AB_DTYPE=False, ) elif fp8_fast_accum: - _kernel_matmul_fp8_row[persistent_grid]( - a_tl, - b_tl, + torch._library.capture_triton(_kernel_matmul_fp8_row)[persistent_grid]( + a, + b, c, M, N, @@ -1196,9 +1192,11 @@ def persistent_grid_tma(META): NUM_SMS=NUM_SMS, ) else: - _kernel_matmul_fp8_row_no_fast_acc[persistent_grid]( - a_tl, - b_tl, + torch._library.capture_triton(_kernel_matmul_fp8_row_no_fast_acc)[ + persistent_grid + ]( + a, + b, c, M, N, @@ -1659,13 +1657,13 @@ def matmul_fp8_block( Tensor: [M, N] output tensor, (a / a_scale) @ (b / b_scale) """ # Get datatypes and constants to use. - _, tl_dtype, _, _ = get_fp8_constants() + _, tl_fp8_dtype, _, _ = get_fp8_constants() # Handle 3D+ a shape a_shape = a.shape a = a.view(-1, a.size(-1)) - # Reinterpret inputs into proper triton fp8 dtype. - a_tl = convert_fp8_type(a, tl_dtype) - b_tl = convert_fp8_type(b, tl_dtype) + # View inputs into proper triton fp8 dtype. + a_tl = reinterpret_fp8_type(a, tl_fp8_dtype) + b_tl = reinterpret_fp8_type(b, tl_fp8_dtype) M, N, K, m_key, n_key, k_key, c, _, dot_out_dtype_triton, device = prep_matmul( a_tl, b_tl, dot_out_dtype @@ -1794,14 +1792,18 @@ def get_matmul_tune(M: int, N: int, K: int) -> Tuple[int, int, int]: def prep_matmul( - a: TensorWrapper, b: TensorWrapper, dot_out_dtype: Optional[torch.dtype] -) -> Tuple[int, int, int, int, int, int, torch.Tensor, str, str, torch.device]: + a: Union[TensorWrapper, torch.Tensor], + b: Union[TensorWrapper, torch.Tensor], + dot_out_dtype: Optional[torch.dtype], +) -> Tuple[ + int, int, int, int, int, int, torch.Tensor, tl.dtype, tl.dtype, torch.device +]: """ Shared bookkeeping for a @ b.T matmul. Args: - a (TensorWrapper): [M, K] input tensor. - b (TensorWrapper): [N, K] input tensor. + a (torch.Tensor): [M, K] input tensor. + b (torch.Tensor): [N, K] input tensor. dot_out_dtype (tl.dtype): Output type of tensor core. Returns: @@ -1812,7 +1814,8 @@ def prep_matmul( n_key (int): Autotuning key for N dim. k_key (int): Autotuning key for K dim. c (Tensor): [M, N] output tensor. - dot_out_dtype (torch.dtype): Output type of tensor core. 
+ c_dtype_triton (tl.dtype): Type of output tensor. + dot_out_dtype (tl.dtype): Output type of tensor core. device (torch.device): Device of output tensor. """ device = a.device @@ -1827,11 +1830,20 @@ def prep_matmul( # allocates output assert a.dtype in [ + torch.float8_e4m3fn, + torch.float8_e5m2, + torch.float8_e4m3fnuz, + torch.float8_e5m2fnuz, tl.float8e4nv, tl.float8e4b15, tl.float8e5, tl.float8e4b8, - ] and b.dtype in [ + ] + assert b.dtype in [ + torch.float8_e4m3fn, + torch.float8_e5m2, + torch.float8_e4m3fnuz, + torch.float8_e5m2fnuz, tl.float8e4nv, tl.float8e4b15, tl.float8e5, From 0a751a04443233411fe00bb7465e2b720a43e403 Mon Sep 17 00:00:00 2001 From: Benson Ma Date: Mon, 30 Sep 2024 17:43:53 -0700 Subject: [PATCH 17/48] Redefine FBGEMM targets with gpu_cpp_library [20/N] (#3193) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3193 X-link: https://github.com/facebookresearch/FBGEMM/pull/289 - Combine `gpu_cpp_library` targets into `fbgemm_gpu:sparse_ops` Reviewed By: spcyppt Differential Revision: D63591546 fbshipit-source-id: fdcae2c07cfb993595d421f7a183673f0bacf8b1 --- fbgemm_gpu/bench/batched_unary_embeddings_benchmark.py | 6 +----- .../bench/histogram_binning_calibration_benchmark.py | 6 +----- fbgemm_gpu/bench/jagged_tensor_benchmark.py | 6 +----- fbgemm_gpu/bench/quantize_ops_benchmark.py | 6 +----- fbgemm_gpu/bench/stride_gemm_benchmark.py | 6 +----- fbgemm_gpu/fbgemm_gpu/batched_unary_embeddings_ops.py | 9 +++------ fbgemm_gpu/fbgemm_gpu/sparse_ops.py | 4 +--- fbgemm_gpu/test/batched_unary_embeddings_test.py | 8 ++------ fbgemm_gpu/test/jagged/common.py | 8 +------- fbgemm_gpu/test/layout_transform_ops_test.py | 8 ++------ fbgemm_gpu/test/quantize/common.py | 7 +------ fbgemm_gpu/test/sparse/common.py | 7 +------ 12 files changed, 16 insertions(+), 65 deletions(-) diff --git a/fbgemm_gpu/bench/batched_unary_embeddings_benchmark.py b/fbgemm_gpu/bench/batched_unary_embeddings_benchmark.py index 6091bcc8e..6d354900a 100644 --- a/fbgemm_gpu/bench/batched_unary_embeddings_benchmark.py +++ b/fbgemm_gpu/bench/batched_unary_embeddings_benchmark.py @@ -25,11 +25,7 @@ else: from fbgemm_gpu.bench.bench_utils import benchmark_torch_function - if torch.version.hip: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip") - else: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") def generate_unary_feature( diff --git a/fbgemm_gpu/bench/histogram_binning_calibration_benchmark.py b/fbgemm_gpu/bench/histogram_binning_calibration_benchmark.py index c919199ee..e43106b8c 100644 --- a/fbgemm_gpu/bench/histogram_binning_calibration_benchmark.py +++ b/fbgemm_gpu/bench/histogram_binning_calibration_benchmark.py @@ -21,11 +21,7 @@ # pyre-ignore[21] from fbgemm_gpu import open_source # noqa: F401 except Exception: - if torch.version.hip: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip") - else: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") def benchmark_hbc_function( diff --git a/fbgemm_gpu/bench/jagged_tensor_benchmark.py b/fbgemm_gpu/bench/jagged_tensor_benchmark.py index 46337701e..814f950b0 100644 --- a/fbgemm_gpu/bench/jagged_tensor_benchmark.py +++ 
b/fbgemm_gpu/bench/jagged_tensor_benchmark.py @@ -31,10 +31,7 @@ else: from fbgemm_gpu.bench.bench_utils import benchmark_torch_function - if torch.version.hip: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip") - else: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") torch.ops.load_library( "//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_cpu" ) @@ -47,7 +44,6 @@ torch.ops.load_library( "//deeplearning/fbgemm/fbgemm_gpu:permute_multi_embedding_ops_gpu" ) - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") @click.group() diff --git a/fbgemm_gpu/bench/quantize_ops_benchmark.py b/fbgemm_gpu/bench/quantize_ops_benchmark.py index 81eb07bea..54755fff6 100644 --- a/fbgemm_gpu/bench/quantize_ops_benchmark.py +++ b/fbgemm_gpu/bench/quantize_ops_benchmark.py @@ -34,11 +34,7 @@ else: from fbgemm_gpu.bench.bench_utils import benchmark_torch_function - if torch.version.hip: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip") - else: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") @click.group() diff --git a/fbgemm_gpu/bench/stride_gemm_benchmark.py b/fbgemm_gpu/bench/stride_gemm_benchmark.py index 2609f7fbf..bca34b8a9 100644 --- a/fbgemm_gpu/bench/stride_gemm_benchmark.py +++ b/fbgemm_gpu/bench/stride_gemm_benchmark.py @@ -20,11 +20,7 @@ # pyre-ignore[21] from fbgemm_gpu import open_source # noqa: F401 except Exception: - if torch.version.hip: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip") - else: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") @click.group() diff --git a/fbgemm_gpu/fbgemm_gpu/batched_unary_embeddings_ops.py b/fbgemm_gpu/fbgemm_gpu/batched_unary_embeddings_ops.py index 5ef0f1c32..db2260df4 100644 --- a/fbgemm_gpu/fbgemm_gpu/batched_unary_embeddings_ops.py +++ b/fbgemm_gpu/fbgemm_gpu/batched_unary_embeddings_ops.py @@ -13,16 +13,13 @@ import torch +from fbgemm_gpu.utils.loader import load_torch_module + try: # pyre-ignore[21] from fbgemm_gpu import open_source # noqa: F401 except Exception: - if torch.version.hip: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip") - else: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") - - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") + load_torch_module("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") def wrap_weight_to_parameter(weights: List[torch.Tensor]) -> List[torch.Tensor]: diff --git a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py index de9a21ef9..71e0e2ccc 100644 --- a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py +++ b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py @@ -20,20 +20,18 @@ from fbgemm_gpu import open_source # noqa: F401 except Exception: load_torch_module("//deeplearning/fbgemm/fbgemm_gpu:merge_pooled_embeddings") + load_torch_module("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") if torch.version.hip: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip") torch.ops.load_library( "//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_hip" ) else: - 
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops") torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:input_combine") - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") torch.ops.load_library( "//deeplearning/fbgemm/fbgemm_gpu:merge_pooled_embeddings_cpu" ) diff --git a/fbgemm_gpu/test/batched_unary_embeddings_test.py b/fbgemm_gpu/test/batched_unary_embeddings_test.py index 3f1124eff..37d0aaf7b 100644 --- a/fbgemm_gpu/test/batched_unary_embeddings_test.py +++ b/fbgemm_gpu/test/batched_unary_embeddings_test.py @@ -26,14 +26,10 @@ from test_utils import gpu_unavailable except Exception: - if torch.version.hip: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip") - else: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") - - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") from fbgemm_gpu.test.test_utils import gpu_unavailable + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + # Relative tolerances # pyre-fixme[5]: Global expression must be annotated. diff --git a/fbgemm_gpu/test/jagged/common.py b/fbgemm_gpu/test/jagged/common.py index 3bdcacf98..491176f76 100644 --- a/fbgemm_gpu/test/jagged/common.py +++ b/fbgemm_gpu/test/jagged/common.py @@ -24,13 +24,7 @@ open_source: bool = getattr(fbgemm_gpu, "open_source", False) if not open_source: - if torch.version.hip: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip") - else: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") - - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") - + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") suppressed_list: List[HealthCheck] = ( [HealthCheck.differing_executors] diff --git a/fbgemm_gpu/test/layout_transform_ops_test.py b/fbgemm_gpu/test/layout_transform_ops_test.py index 658d773f3..57a7b0263 100644 --- a/fbgemm_gpu/test/layout_transform_ops_test.py +++ b/fbgemm_gpu/test/layout_transform_ops_test.py @@ -22,14 +22,10 @@ from test_utils import gpu_unavailable except Exception: - if torch.version.hip: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip") - else: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") - - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") from fbgemm_gpu.test.test_utils import gpu_unavailable + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + MAX_EXAMPLES = 20 diff --git a/fbgemm_gpu/test/quantize/common.py b/fbgemm_gpu/test/quantize/common.py index 5333cc893..6a720a174 100644 --- a/fbgemm_gpu/test/quantize/common.py +++ b/fbgemm_gpu/test/quantize/common.py @@ -23,12 +23,7 @@ from fbgemm_gpu import open_source # noqa: F401 except Exception: - if torch.version.hip: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip") - else: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") - - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") # Eigen/Python round 0.5 away from 0, Numpy rounds to even round_to_nearest: Callable[[npt.NDArray], npt.NDArray] = np.vectorize(round) diff --git a/fbgemm_gpu/test/sparse/common.py b/fbgemm_gpu/test/sparse/common.py index 69b6e3477..8abddca75 100644 --- a/fbgemm_gpu/test/sparse/common.py +++ b/fbgemm_gpu/test/sparse/common.py 
@@ -23,12 +23,7 @@ open_source: bool = getattr(fbgemm_gpu, "open_source", False) if not open_source: - if torch.version.hip: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip") - else: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") - - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:index_select_ops") suppressed_list: List[HealthCheck] = ( From 9655d8fe4f25a0650a04f0875b9e6e8c14ea10f9 Mon Sep 17 00:00:00 2001 From: Benson Ma Date: Mon, 30 Sep 2024 18:25:43 -0700 Subject: [PATCH 18/48] Add support for int64_t indices in TBE inference [1/N] (#3041) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3041 X-link: https://github.com/facebookresearch/FBGEMM/pull/138 - Add support for int64_t indices in TBE inference [1/N] Reviewed By: jianyuh Differential Revision: D61813383 fbshipit-source-id: a4a7b9379f47217e340bc0902f3c2d6fd514bf92 --- ...ward_quantized_split_nbit_host_template.cu | 233 +++++++++++++----- .../include/fbgemm_gpu/utils/tensor_utils.h | 40 +++ 2 files changed, 207 insertions(+), 66 deletions(-) diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_host_template.cu b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_host_template.cu index bc4e7ba74..5dd5c30b1 100644 --- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_host_template.cu +++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_host_template.cu @@ -7,7 +7,7 @@ */ // clang-format off -{% set wdesc = "weighted" if weighted else "unweighted" %} +{%- set wdesc = "weighted" if weighted else "unweighted" %} #include "fbgemm_gpu/embedding_forward_template_helpers.cuh" #include "fbgemm_gpu/utils/tensor_accessor.h" @@ -22,7 +22,7 @@ namespace nbit { `Tensor int_nbit_split_embedding*_codegen_forward_*_cuda(...)` later in the same generated source file. 
*/ -{% for emb_weight_type in ["FP32", "FP16", "FP8", "INT8", "INT4", "INT2"] %} +{%- for emb_weight_type in ["FP32", "FP16", "FP8", "INT8", "INT4", "INT2"] %} template __launch_bounds__(WarpsPerBlock * kWarpSize) __global__ void {{ type_map[emb_weight_type].enum_name }}_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{{ wdesc }}_kernel_small_L( @@ -31,30 +31,30 @@ __global__ void {{ type_map[emb_weight_type].enum_name }}_split_embedding{{ "_no const pta::PackedTensorAccessor32 weights_placements, const pta::PackedTensorAccessor32 weights_offsets, const pta::PackedTensorAccessor32 weights_tys, - {% if not nobag %} + {%- if not nobag %} const pta::PackedTensorAccessor32 D_offsets, - {% else %} + {%- else %} const int64_t D, - {% endif %} + {%- endif %} FixedDivisor fd_B, // FixedDivisor(div_round_up(B, OutputRowsPerThread)) const pta::PackedTensorAccessor32 indices, const pta::PackedTensorAccessor32 offsets, - {% if not nobag %} + {%- if not nobag %} const int64_t pooling_mode, - {% endif %} + {%- endif %} const int64_t row_alignment, - {% if weighted %} + {%- if weighted %} pta::PackedTensorAccessor32 indice_weights, - {% endif %} - {% if type_map[emb_weight_type].enum_name == "FP8" %} + {%- endif %} + {%- if type_map[emb_weight_type].enum_name == "FP8" %} const int fp8_exponent_bits, const int fp8_exponent_bias, - {% endif %} + {%- endif %} pta::PackedTensorAccessor32 output, // [B][total_D], const pta::PackedTensorAccessor64 lxu_cache_weights, const pta::PackedTensorAccessor32 lxu_cache_locations ); -{% endfor %} // for emb_weight_type in ["FP32", "FP16", "FP8", "INT8", "INT4", "INT2"] +{%- endfor %} // for emb_weight_type in ["FP32", "FP16", "FP8", "INT8", "INT4", "INT2"] } @@ -107,58 +107,7 @@ __global__ void {{ type_map[emb_weight_type].enum_name }}_split_embedding{{ "_no C10_CUDA_KERNEL_LAUNCH_CHECK(); \ {%- endmacro %} - -Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{{ wdesc }}_cuda( - Tensor dev_weights, - Tensor uvm_weights, - Tensor weights_placements, - Tensor weights_offsets, - Tensor weights_tys, - {% if not nobag %} - Tensor D_offsets, - const int64_t total_D, - {% else %} - const int64_t D, - {% endif %} - const int64_t max_int2_D, - const int64_t max_int4_D, - const int64_t max_int8_D, - const int64_t max_float16_D, - const int64_t max_float32_D, - Tensor indices, - Tensor offsets, - {% if not nobag %} - const int64_t pooling_mode, - {% endif %} - const int64_t row_alignment, - {% if weighted %} - Tensor indice_weights, - {% endif %} - const int64_t output_dtype, - Tensor lxu_cache_weights, - Tensor lxu_cache_locations, - const int64_t max_float8_D, - const int64_t fp8_exponent_bits, - const int64_t fp8_exponent_bias -) { - TENSOR_ON_CUDA_GPU(dev_weights); - TENSORS_ON_SAME_DEVICE(uvm_weights, dev_weights); - TENSORS_ON_SAME_DEVICE(weights_placements, dev_weights); - TENSORS_ON_SAME_DEVICE(weights_offsets, dev_weights); - TENSORS_ON_SAME_DEVICE(weights_tys, dev_weights); - {% if not nobag %} - TENSORS_ON_SAME_DEVICE(D_offsets, dev_weights); - {% endif %} - TENSORS_ON_SAME_DEVICE(indices, dev_weights); - TENSORS_ON_SAME_DEVICE(offsets, dev_weights); - {% if weighted %} - TENSORS_EMPTY_OR_ON_SAME_DEVICE(indice_weights, dev_weights); - {% endif %} - TENSORS_EMPTY_OR_ON_SAME_DEVICE(lxu_cache_weights, dev_weights); - TENSORS_EMPTY_OR_ON_SAME_DEVICE(lxu_cache_locations, dev_weights); - - CUDA_DEVICE_GUARD(dev_weights); - +{%- macro construct_and_return_output_tensor() %} // kernels assume indices are contiguous. 
indices = indices.contiguous(); @@ -180,8 +129,10 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ TORCH_CHECK(D > 0); {%- endif %} + // Construct output tensor Tensor output; const int kINT8QparamsBytes = 8; + SparseType o_dtype = static_cast(output_dtype); TORCH_CHECK(o_dtype == SparseType::FP32 || o_dtype == SparseType::FP16 || o_dtype == SparseType::BF16 || o_dtype == SparseType::INT8); @@ -216,11 +167,63 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ if (B == 0 || indices.numel() == 0) { return output; } +{%- endmacro %} - using index_t = int32_t; +template +Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{{ wdesc }}_cuda_impl( + Tensor dev_weights, + Tensor uvm_weights, + Tensor weights_placements, + Tensor weights_offsets, + Tensor weights_tys, + {%- if not nobag %} + Tensor D_offsets, + const int64_t total_D, + {%- else %} + const int64_t D, + {%- endif %} + const int64_t max_int2_D, + const int64_t max_int4_D, + const int64_t max_int8_D, + const int64_t max_float16_D, + const int64_t max_float32_D, + Tensor indices, + Tensor offsets, + {%- if not nobag %} + const int64_t pooling_mode, + {%- endif %} + const int64_t row_alignment, + {%- if weighted %} + Tensor indice_weights, + {%- endif %} + const int64_t output_dtype, + Tensor lxu_cache_weights, + Tensor lxu_cache_locations, + const int64_t max_float8_D, + const int64_t fp8_exponent_bits, + const int64_t fp8_exponent_bias +) { + TENSOR_ON_CUDA_GPU(dev_weights); + TENSORS_ON_SAME_DEVICE(uvm_weights, dev_weights); + TENSORS_ON_SAME_DEVICE(weights_placements, dev_weights); + TENSORS_ON_SAME_DEVICE(weights_offsets, dev_weights); + TENSORS_ON_SAME_DEVICE(weights_tys, dev_weights); + {%- if not nobag %} + TENSORS_ON_SAME_DEVICE(D_offsets, dev_weights); + {%- endif %} + TENSORS_ON_SAME_DEVICE(indices, dev_weights); + TENSORS_ON_SAME_DEVICE(offsets, dev_weights); + {%- if weighted %} + TENSORS_EMPTY_OR_ON_SAME_DEVICE(indice_weights, dev_weights); + {%- endif %} + TENSORS_EMPTY_OR_ON_SAME_DEVICE(lxu_cache_weights, dev_weights); + TENSORS_EMPTY_OR_ON_SAME_DEVICE(lxu_cache_locations, dev_weights); - constexpr int32_t kWarpsPerBlock = 4; + CUDA_DEVICE_GUARD(dev_weights); + + {{- construct_and_return_output_tensor() }} + constexpr int32_t kWarpsPerBlock = 4; const auto device_only = lxu_cache_weights.numel() == 0 && uvm_weights.numel() == 0; #define Y(...) 
\ if (device_only) { \ @@ -397,6 +400,104 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ })); #undef X + return output; +} + +Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{{ wdesc }}_cuda( + Tensor dev_weights, + Tensor uvm_weights, + Tensor weights_placements, + Tensor weights_offsets, + Tensor weights_tys, + {%- if not nobag %} + Tensor D_offsets, + const int64_t total_D, + {%- else %} + const int64_t D, + {%- endif %} + const int64_t max_int2_D, + const int64_t max_int4_D, + const int64_t max_int8_D, + const int64_t max_float16_D, + const int64_t max_float32_D, + Tensor indices, + Tensor offsets, + {%- if not nobag %} + const int64_t pooling_mode, + {%- endif %} + const int64_t row_alignment, + {%- if weighted %} + Tensor indice_weights, + {%- endif %} + const int64_t output_dtype, + Tensor lxu_cache_weights, + Tensor lxu_cache_locations, + const int64_t max_float8_D, + const int64_t fp8_exponent_bits, + const int64_t fp8_exponent_bias +) { + // All argument tensors need to be on the same CUDA device + TENSOR_ON_CUDA_GPU(dev_weights); + TENSORS_ON_SAME_DEVICE(uvm_weights, dev_weights); + TENSORS_ON_SAME_DEVICE(weights_placements, dev_weights); + TENSORS_ON_SAME_DEVICE(weights_offsets, dev_weights); + TENSORS_ON_SAME_DEVICE(weights_tys, dev_weights); + {%- if not nobag %} + TENSORS_ON_SAME_DEVICE(D_offsets, dev_weights); + {%- endif %} + TENSORS_ON_SAME_DEVICE(indices, dev_weights); + TENSORS_ON_SAME_DEVICE(offsets, dev_weights); + {%- if weighted %} + TENSORS_EMPTY_OR_ON_SAME_DEVICE(indice_weights, dev_weights); + {%- endif %} + TENSORS_EMPTY_OR_ON_SAME_DEVICE(lxu_cache_weights, dev_weights); + TENSORS_EMPTY_OR_ON_SAME_DEVICE(lxu_cache_locations, dev_weights); + + // indices and offsets need to have the same scalar type + TENSORS_HAVE_SAME_TYPE(indices, offsets); + // Only int32_t and int64_t indices are supported at the moment + TENSOR_SCALAR_TYPE_IS_ONE_OF(indices, at::ScalarType::Long, at::ScalarType::Int); + + CUDA_DEVICE_GUARD(dev_weights); + + // Create output tensor ref + Tensor output; + + AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "{{ 'int_nbit_split_embedding' + ('_nobag' if nobag else '') + '_codegen_forward_' + wdesc + '_cuda' }}", [&] { + output = int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{{ wdesc }}_cuda_impl( + dev_weights, + uvm_weights, + weights_placements, + weights_offsets, + weights_tys, + {%- if not nobag %} + D_offsets, + total_D, + {%- else %} + D, + {%- endif %} + max_int2_D, + max_int4_D, + max_int8_D, + max_float16_D, + max_float32_D, + indices, + offsets, + {%- if not nobag %} + pooling_mode, + {%- endif %} + row_alignment, + {%- if weighted %} + indice_weights, + {%- endif %} + output_dtype, + lxu_cache_weights, + lxu_cache_locations, + max_float8_D, + fp8_exponent_bits, + fp8_exponent_bias); + }); + return output; } diff --git a/fbgemm_gpu/include/fbgemm_gpu/utils/tensor_utils.h b/fbgemm_gpu/include/fbgemm_gpu/utils/tensor_utils.h index b1ab0306c..f64205b7e 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/utils/tensor_utils.h +++ b/fbgemm_gpu/include/fbgemm_gpu/utils/tensor_utils.h @@ -299,3 +299,43 @@ inline at::Tensor aligned_grad_output_tensor_for_cuda_backwards( } return aligned_grad_output; } + +template +std::string tensor_scalar_type_is_one_of( + const at::Tensor& ten, + const ScalarTypes&... ttypes) { + auto has_match = false; + + // Collect the GPU index of the first non-empty optional tensor and make sure + // that all tensors are on this same index. 
+ ( + [&](const auto& ttype) { + if (ten.scalar_type() == ttype) { + has_match = true; + } + }(ttypes), + ...); + + if (has_match) { + return ""; + } + + std::string msg = "Tensor's scalar type ("; + msg.append(toString(ten.scalar_type())); + msg.append(") did not match any one of the following types: ["); + ( + [&](const auto& ttype) { + msg.append(toString(ttype)); + msg.append(", "); + }(ttypes), + ...); + + msg.append("]"); + return msg; +} + +#define TENSOR_SCALAR_TYPE_IS_ONE_OF(...) \ + do { \ + const auto has_match = tensor_scalar_type_is_one_of(__VA_ARGS__); \ + TORCH_CHECK(has_match.empty(), has_match); \ + } while (false) From fc8350f1cc9f71de73df81e3963a71ea7ed0ec62 Mon Sep 17 00:00:00 2001 From: Gagan Jain Date: Mon, 30 Sep 2024 20:34:41 -0700 Subject: [PATCH 19/48] Revert D63591546: Redefine FBGEMM targets with gpu_cpp_library [20/N] Differential Revision: D63591546 Original commit changeset: fdcae2c07cfb Original Phabricator Diff: D63591546 fbshipit-source-id: c753dd92342bedf7803cc4045100438be91fbf49 --- fbgemm_gpu/bench/batched_unary_embeddings_benchmark.py | 6 +++++- .../bench/histogram_binning_calibration_benchmark.py | 6 +++++- fbgemm_gpu/bench/jagged_tensor_benchmark.py | 6 +++++- fbgemm_gpu/bench/quantize_ops_benchmark.py | 6 +++++- fbgemm_gpu/bench/stride_gemm_benchmark.py | 6 +++++- fbgemm_gpu/fbgemm_gpu/batched_unary_embeddings_ops.py | 9 ++++++--- fbgemm_gpu/fbgemm_gpu/sparse_ops.py | 4 +++- fbgemm_gpu/test/batched_unary_embeddings_test.py | 8 ++++++-- fbgemm_gpu/test/jagged/common.py | 8 +++++++- fbgemm_gpu/test/layout_transform_ops_test.py | 8 ++++++-- fbgemm_gpu/test/quantize/common.py | 7 ++++++- fbgemm_gpu/test/sparse/common.py | 7 ++++++- 12 files changed, 65 insertions(+), 16 deletions(-) diff --git a/fbgemm_gpu/bench/batched_unary_embeddings_benchmark.py b/fbgemm_gpu/bench/batched_unary_embeddings_benchmark.py index 6d354900a..6091bcc8e 100644 --- a/fbgemm_gpu/bench/batched_unary_embeddings_benchmark.py +++ b/fbgemm_gpu/bench/batched_unary_embeddings_benchmark.py @@ -25,7 +25,11 @@ else: from fbgemm_gpu.bench.bench_utils import benchmark_torch_function - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + if torch.version.hip: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip") + else: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") def generate_unary_feature( diff --git a/fbgemm_gpu/bench/histogram_binning_calibration_benchmark.py b/fbgemm_gpu/bench/histogram_binning_calibration_benchmark.py index e43106b8c..c919199ee 100644 --- a/fbgemm_gpu/bench/histogram_binning_calibration_benchmark.py +++ b/fbgemm_gpu/bench/histogram_binning_calibration_benchmark.py @@ -21,7 +21,11 @@ # pyre-ignore[21] from fbgemm_gpu import open_source # noqa: F401 except Exception: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + if torch.version.hip: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip") + else: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") def benchmark_hbc_function( diff --git a/fbgemm_gpu/bench/jagged_tensor_benchmark.py b/fbgemm_gpu/bench/jagged_tensor_benchmark.py index 814f950b0..46337701e 100644 --- a/fbgemm_gpu/bench/jagged_tensor_benchmark.py +++ b/fbgemm_gpu/bench/jagged_tensor_benchmark.py @@ -31,7 +31,10 @@ else: from fbgemm_gpu.bench.bench_utils import 
benchmark_torch_function - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + if torch.version.hip: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip") + else: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") torch.ops.load_library( "//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_cpu" ) @@ -44,6 +47,7 @@ torch.ops.load_library( "//deeplearning/fbgemm/fbgemm_gpu:permute_multi_embedding_ops_gpu" ) + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") @click.group() diff --git a/fbgemm_gpu/bench/quantize_ops_benchmark.py b/fbgemm_gpu/bench/quantize_ops_benchmark.py index 54755fff6..81eb07bea 100644 --- a/fbgemm_gpu/bench/quantize_ops_benchmark.py +++ b/fbgemm_gpu/bench/quantize_ops_benchmark.py @@ -34,7 +34,11 @@ else: from fbgemm_gpu.bench.bench_utils import benchmark_torch_function - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + if torch.version.hip: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip") + else: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") @click.group() diff --git a/fbgemm_gpu/bench/stride_gemm_benchmark.py b/fbgemm_gpu/bench/stride_gemm_benchmark.py index bca34b8a9..2609f7fbf 100644 --- a/fbgemm_gpu/bench/stride_gemm_benchmark.py +++ b/fbgemm_gpu/bench/stride_gemm_benchmark.py @@ -20,7 +20,11 @@ # pyre-ignore[21] from fbgemm_gpu import open_source # noqa: F401 except Exception: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + if torch.version.hip: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip") + else: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") @click.group() diff --git a/fbgemm_gpu/fbgemm_gpu/batched_unary_embeddings_ops.py b/fbgemm_gpu/fbgemm_gpu/batched_unary_embeddings_ops.py index db2260df4..5ef0f1c32 100644 --- a/fbgemm_gpu/fbgemm_gpu/batched_unary_embeddings_ops.py +++ b/fbgemm_gpu/fbgemm_gpu/batched_unary_embeddings_ops.py @@ -13,13 +13,16 @@ import torch -from fbgemm_gpu.utils.loader import load_torch_module - try: # pyre-ignore[21] from fbgemm_gpu import open_source # noqa: F401 except Exception: - load_torch_module("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + if torch.version.hip: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip") + else: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") def wrap_weight_to_parameter(weights: List[torch.Tensor]) -> List[torch.Tensor]: diff --git a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py index 71e0e2ccc..de9a21ef9 100644 --- a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py +++ b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py @@ -20,18 +20,20 @@ from fbgemm_gpu import open_source # noqa: F401 except Exception: load_torch_module("//deeplearning/fbgemm/fbgemm_gpu:merge_pooled_embeddings") - load_torch_module("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") if torch.version.hip: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip") torch.ops.load_library( "//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_hip" ) else: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") 
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops") torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:input_combine") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") torch.ops.load_library( "//deeplearning/fbgemm/fbgemm_gpu:merge_pooled_embeddings_cpu" ) diff --git a/fbgemm_gpu/test/batched_unary_embeddings_test.py b/fbgemm_gpu/test/batched_unary_embeddings_test.py index 37d0aaf7b..3f1124eff 100644 --- a/fbgemm_gpu/test/batched_unary_embeddings_test.py +++ b/fbgemm_gpu/test/batched_unary_embeddings_test.py @@ -26,9 +26,13 @@ from test_utils import gpu_unavailable except Exception: - from fbgemm_gpu.test.test_utils import gpu_unavailable + if torch.version.hip: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip") + else: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") + from fbgemm_gpu.test.test_utils import gpu_unavailable # Relative tolerances diff --git a/fbgemm_gpu/test/jagged/common.py b/fbgemm_gpu/test/jagged/common.py index 491176f76..3bdcacf98 100644 --- a/fbgemm_gpu/test/jagged/common.py +++ b/fbgemm_gpu/test/jagged/common.py @@ -24,7 +24,13 @@ open_source: bool = getattr(fbgemm_gpu, "open_source", False) if not open_source: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + if torch.version.hip: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip") + else: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") + suppressed_list: List[HealthCheck] = ( [HealthCheck.differing_executors] diff --git a/fbgemm_gpu/test/layout_transform_ops_test.py b/fbgemm_gpu/test/layout_transform_ops_test.py index 57a7b0263..658d773f3 100644 --- a/fbgemm_gpu/test/layout_transform_ops_test.py +++ b/fbgemm_gpu/test/layout_transform_ops_test.py @@ -22,9 +22,13 @@ from test_utils import gpu_unavailable except Exception: - from fbgemm_gpu.test.test_utils import gpu_unavailable + if torch.version.hip: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip") + else: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") + from fbgemm_gpu.test.test_utils import gpu_unavailable MAX_EXAMPLES = 20 diff --git a/fbgemm_gpu/test/quantize/common.py b/fbgemm_gpu/test/quantize/common.py index 6a720a174..5333cc893 100644 --- a/fbgemm_gpu/test/quantize/common.py +++ b/fbgemm_gpu/test/quantize/common.py @@ -23,7 +23,12 @@ from fbgemm_gpu import open_source # noqa: F401 except Exception: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + if torch.version.hip: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip") + else: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") # Eigen/Python round 0.5 away from 0, Numpy rounds to even round_to_nearest: Callable[[npt.NDArray], npt.NDArray] = np.vectorize(round) diff --git a/fbgemm_gpu/test/sparse/common.py b/fbgemm_gpu/test/sparse/common.py index 8abddca75..69b6e3477 100644 --- a/fbgemm_gpu/test/sparse/common.py +++ b/fbgemm_gpu/test/sparse/common.py @@ -23,7 +23,12 
@@ open_source: bool = getattr(fbgemm_gpu, "open_source", False) if not open_source: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + if torch.version.hip: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip") + else: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:index_select_ops") suppressed_list: List[HealthCheck] = ( From f323833aaadcace97225dfd9aa8fafe663930872 Mon Sep 17 00:00:00 2001 From: sryap <17482891+sryap@users.noreply.github.com> Date: Tue, 1 Oct 2024 05:08:44 -0700 Subject: [PATCH 20/48] Add stable API doc (#3194) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/290 As title Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3194 Reviewed By: shintaro-iwasaki Differential Revision: D63601834 Pulled By: sryap fbshipit-source-id: ef5bb4eac602027d6051a2c53e5b59ff00ab2afd --- .../jagged_tensor_ops.rst | 12 +++++- .../pooled_embedding_modules.rst | 8 ++++ .../pooled_embedding_ops.rst | 8 ++++ .../fbgemm_gpu-python-api/quantize_ops.rst | 8 ++++ .../src/fbgemm_gpu-python-api/sparse_ops.rst | 9 ++++- .../table_batched_embedding_ops.rst | 8 ++++ .../src/fbgemm_gpu-stable-api/python_api.rst | 37 +++++++++++++++++++ fbgemm_gpu/docs/src/index.rst | 8 ++++ 8 files changed, 95 insertions(+), 3 deletions(-) create mode 100644 fbgemm_gpu/docs/src/fbgemm_gpu-stable-api/python_api.rst diff --git a/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/jagged_tensor_ops.rst b/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/jagged_tensor_ops.rst index 92e8f1148..a85168bfc 100644 --- a/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/jagged_tensor_ops.rst +++ b/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/jagged_tensor_ops.rst @@ -3,14 +3,22 @@ Jagged Tensor Operators .. automodule:: fbgemm_gpu +.. _jagged-tensor-ops-stable-api: + +Stable API +---------- + +.. autofunction:: torch.ops.fbgemm.jagged_to_padded_dense + +Other API +--------- + .. autofunction:: torch.ops.fbgemm.jagged_2d_to_dense .. autofunction:: torch.ops.fbgemm.jagged_1d_to_dense .. autofunction:: torch.ops.fbgemm.dense_to_jagged -.. autofunction:: torch.ops.fbgemm.jagged_to_padded_dense - .. autofunction:: torch.ops.fbgemm.jagged_dense_elementwise_add .. autofunction:: torch.ops.fbgemm.jagged_dense_elementwise_add_jagged_output diff --git a/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/pooled_embedding_modules.rst b/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/pooled_embedding_modules.rst index 654373f40..7970ce6f9 100644 --- a/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/pooled_embedding_modules.rst +++ b/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/pooled_embedding_modules.rst @@ -3,5 +3,13 @@ Pooled Embedding Modules .. automodule:: fbgemm_gpu +.. _pooled-embedding-modules-stable-api: + +Stable API +---------- + .. autoclass:: fbgemm_gpu.permute_pooled_embedding_modules.PermutePooledEmbeddings :members: __call__ + +Other API +--------- diff --git a/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/pooled_embedding_ops.rst b/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/pooled_embedding_ops.rst index 52e2fd47d..9e9d545d7 100644 --- a/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/pooled_embedding_ops.rst +++ b/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/pooled_embedding_ops.rst @@ -3,6 +3,14 @@ Pooled Embedding Operators .. automodule:: fbgemm_gpu +.. _pooled-embedding-operators-stable-api: + +Stable API +---------- + .. 
autofunction:: torch.ops.fbgemm.merge_pooled_embeddings .. autofunction:: torch.ops.fbgemm.permute_pooled_embs + +Other API +--------- diff --git a/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/quantize_ops.rst b/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/quantize_ops.rst index df2a6c2d7..3b47f8bcd 100644 --- a/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/quantize_ops.rst +++ b/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/quantize_ops.rst @@ -3,4 +3,12 @@ Quantization Operators .. automodule:: fbgemm_gpu +.. _quantize-ops-stable-api: + +Stable API +---------- + .. autofunction:: torch.ops.fbgemm.FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf + +Other API +--------- diff --git a/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/sparse_ops.rst b/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/sparse_ops.rst index b95b6dda4..e5a4213f7 100644 --- a/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/sparse_ops.rst +++ b/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/sparse_ops.rst @@ -3,6 +3,11 @@ Sparse Operators .. automodule:: fbgemm_gpu +.. _sparse-ops-stable-api: + +Stable API +---------- + .. autofunction:: torch.ops.fbgemm.permute_2D_sparse_data .. autofunction:: torch.ops.fbgemm.permute_1D_sparse_data @@ -17,4 +22,6 @@ Sparse Operators .. autofunction:: torch.ops.fbgemm.keyed_jagged_index_select_dim1 -.. autofunction:: torch.ops.fbgemm.block_bucketize_sparse_features \ No newline at end of file +.. autofunction:: torch.ops.fbgemm.block_bucketize_sparse_features + +Other API diff --git a/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/table_batched_embedding_ops.rst b/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/table_batched_embedding_ops.rst index bbd39d873..9b5453786 100644 --- a/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/table_batched_embedding_ops.rst +++ b/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/table_batched_embedding_ops.rst @@ -1,6 +1,11 @@ Table Batched Embedding (TBE) Training Module ============================================= +.. _table-batched-embedding-ops-stable-api: + +Stable API +---------- + .. autoclass:: fbgemm_gpu.split_table_batched_embeddings_ops_training.SplitTableBatchedEmbeddingBagsCodegen :members: forward, split_embedding_weights, @@ -8,3 +13,6 @@ Table Batched Embedding (TBE) Training Module set_learning_rate, update_hyper_parameters, set_optimizer_step + +Other API +--------- diff --git a/fbgemm_gpu/docs/src/fbgemm_gpu-stable-api/python_api.rst b/fbgemm_gpu/docs/src/fbgemm_gpu-stable-api/python_api.rst new file mode 100644 index 000000000..54b4a6baa --- /dev/null +++ b/fbgemm_gpu/docs/src/fbgemm_gpu-stable-api/python_api.rst @@ -0,0 +1,37 @@ +FBGEMM_GPU Stable Python API +============================ + +We provide the stable API support starting from FBGEMM_GPU v1.0. The following +outlines our supports: + +- API backward compatibility guarantees via thorough testing. We guarantee that + our stable APIs will be backward compatible within a major version, meaning + that the stable APIs for v1.0.0 will be compatible with every future release + unless explicitly announced in advance + +- Enhanced documentation, ensuring that every stable API has comprehensive and + up-to-date documentation. + +- Functionality guarantees are only provided through unit testing framework. + We do NOT guarantee any functionalities that are NOT explicitly tested and + documented in our unit tests. + +- No performance guarantees. However, we are committed to providing support on + a best-effort basis. 
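All of the stable operators listed below are reached through the `torch.ops.fbgemm` namespace once the `fbgemm_gpu` package has been imported, since importing the package is what registers the operators with PyTorch. A minimal usage sketch for one of them follows; the tensor contents are illustrative only, and the argument order is assumed from the `permute_2D_sparse_data` schema rather than taken from this patch:

```
import torch
import fbgemm_gpu  # noqa: F401  -- importing the package registers torch.ops.fbgemm.*

# Two features (T=2), two samples (B=2): lengths[t][b] is the number of values
# belonging to feature t in sample b, and `values` is their flat concatenation.
lengths = torch.tensor([[2, 1], [1, 2]], dtype=torch.int64)
values = torch.tensor([10, 11, 12, 20, 21, 22], dtype=torch.int64)
permute = torch.tensor([1, 0], dtype=torch.int64)  # output feature 0 <- input feature 1, etc.

result = torch.ops.fbgemm.permute_2D_sparse_data(permute, lengths, values)
out_lengths, out_values = result[0], result[1]  # third element is the optional permuted weights
```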
+ +Stable APIs +----------- + +Our stable APIs can be found via the links below: + +- :ref:`Table batched embedding (TBE) modules` + +- :ref:`Pooled embedding operators` + +- :ref:`Pooled embedding modules` + +- :ref:`Sparse operators` + +- :ref:`Jagged tensor operators` + +- :ref:`Quantization operators` diff --git a/fbgemm_gpu/docs/src/index.rst b/fbgemm_gpu/docs/src/index.rst index f5b4cfb07..2b92f0d3d 100644 --- a/fbgemm_gpu/docs/src/index.rst +++ b/fbgemm_gpu/docs/src/index.rst @@ -56,6 +56,14 @@ Table of Contents fbgemm_gpu-overview/jagged-tensor-ops/JaggedTensorOps.rst +.. _fbgemm.toc.api.stable: + +.. toctree:: + :maxdepth: 1 + :caption: FBGEMM Stable API + + fbgemm_gpu-stable-api/python_api.rst + .. _fbgemm.toc.api.cpp: .. toctree:: From a5628cbac1eefb21dce3e7e995c8f83d8cbcdacf Mon Sep 17 00:00:00 2001 From: Benson Ma Date: Tue, 1 Oct 2024 11:12:35 -0700 Subject: [PATCH 21/48] Add support for int64_t indices in TBE inference [2/N] (#3125) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/214 Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3125 - Add support for int64_t indices in TBE inference [2/N] - Convert `pruned_array_lookup_cuda` to use index_t Reviewed By: jianyuh Differential Revision: D62271409 fbshipit-source-id: 2b8e639eaac031c37e017c1be9a92697f17d6e6e --- ...mbedding_forward_quantized_split_lookup.cu | 55 ++++++++++--------- .../include/fbgemm_gpu/utils/tensor_utils.h | 38 ++++++++++++- 2 files changed, 66 insertions(+), 27 deletions(-) diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_lookup.cu b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_lookup.cu index 7d4eebcce..52f2a49dd 100644 --- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_lookup.cu +++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_lookup.cu @@ -89,19 +89,20 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pru } } +template __global__ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pruned_array_lookup_kernel( - const pta::PackedTensorAccessor32 + const pta::PackedTensorAccessor32 indices, - const pta::PackedTensorAccessor32 + const pta::PackedTensorAccessor32 offsets, - const pta::PackedTensorAccessor32 + const pta::PackedTensorAccessor32 index_remappings, const pta::PackedTensorAccessor32 index_remappings_offsets, const int32_t B, const int32_t T, - pta::PackedTensorAccessor32 + pta::PackedTensorAccessor32 dense_indices) { const int32_t b_t = blockIdx.x * blockDim.y + threadIdx.y; const int32_t t = b_t / B; @@ -109,22 +110,22 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pru if (b_t >= B * T) { return; } - const int32_t indices_start = offsets[t * B + b]; - const int32_t indices_end = offsets[t * B + b + 1]; - const int32_t L = indices_end - indices_start; + const index_t indices_start = offsets[t * B + b]; + const index_t indices_end = offsets[t * B + b + 1]; + const index_t L = indices_end - indices_start; const int64_t index_remappings_start = index_remappings_offsets[t]; const int64_t index_remappings_end = index_remappings_offsets[t + 1]; const int64_t capacity = index_remappings_end - index_remappings_start; if (capacity > 0) { - for (int32_t l = threadIdx.x; l < L; l += blockDim.x) { - int32_t idx = indices[indices_start + l]; + for (index_t l = threadIdx.x; l < L; l += blockDim.x) { + index_t idx = indices[indices_start + l]; dense_indices[indices_start + l] = 
index_remappings[index_remappings_start + idx]; } } else { - for (int32_t l = threadIdx.x; l < L; l += blockDim.x) { + for (index_t l = threadIdx.x; l < L; l += blockDim.x) { dense_indices[indices_start + l] = indices[indices_start + l]; } } @@ -178,6 +179,7 @@ Tensor pruned_array_lookup_cuda( Tensor index_remappings_offsets) { TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL( indices, offsets, index_remappings, index_remappings_offsets); + TENSORS_HAVE_SAME_SCALAR_TYPE(indices, offsets, index_remappings); CUDA_DEVICE_GUARD(indices); @@ -204,23 +206,26 @@ Tensor pruned_array_lookup_cuda( TORCH_CHECK(dense_indices.dim() == 1, "Tensor dim: ", dense_indices.dim()); constexpr size_t kForwardMaxThreads = 256; + AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_array_lookup", [&] { #ifdef FBGEMM_GPU_MEMCHECK - const auto func_name = - "int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel"; + const auto func_name = + "int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel"; #endif - nbit::int_nbit_split_embedding_codegen_forward_pruned_array_lookup_kernel<<< - nbit::div_round_up(offsets.size(0), kForwardMaxThreads / kWarpSize), - dim3(kWarpSize, kForwardMaxThreads / kWarpSize), - 0, - at::cuda::getCurrentCUDAStream()>>>( - MAKE_PTA_WITH_NAME(func_name, indices, int32_t, 1, 32), - MAKE_PTA_WITH_NAME(func_name, offsets, int32_t, 1, 32), - MAKE_PTA_WITH_NAME(func_name, index_remappings, int32_t, 1, 32), - MAKE_PTA_WITH_NAME(func_name, index_remappings_offsets, int64_t, 1, 32), - B, - T, - MAKE_PTA_WITH_NAME(func_name, dense_indices, int32_t, 1, 32)); - C10_CUDA_KERNEL_LAUNCH_CHECK(); + nbit::int_nbit_split_embedding_codegen_forward_pruned_array_lookup_kernel<<< + nbit::div_round_up(offsets.size(0), kForwardMaxThreads / kWarpSize), + dim3(kWarpSize, kForwardMaxThreads / kWarpSize), + 0, + at::cuda::getCurrentCUDAStream()>>>( + MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32), + MAKE_PTA_WITH_NAME(func_name, offsets, index_t, 1, 32), + MAKE_PTA_WITH_NAME(func_name, index_remappings, index_t, 1, 32), + MAKE_PTA_WITH_NAME(func_name, index_remappings_offsets, int64_t, 1, 32), + B, + T, + MAKE_PTA_WITH_NAME(func_name, dense_indices, index_t, 1, 32)); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + return dense_indices; } diff --git a/fbgemm_gpu/include/fbgemm_gpu/utils/tensor_utils.h b/fbgemm_gpu/include/fbgemm_gpu/utils/tensor_utils.h index f64205b7e..60cca19ef 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/utils/tensor_utils.h +++ b/fbgemm_gpu/include/fbgemm_gpu/utils/tensor_utils.h @@ -306,8 +306,6 @@ std::string tensor_scalar_type_is_one_of( const ScalarTypes&... ttypes) { auto has_match = false; - // Collect the GPU index of the first non-empty optional tensor and make sure - // that all tensors are on this same index. ( [&](const auto& ttype) { if (ten.scalar_type() == ttype) { @@ -339,3 +337,39 @@ std::string tensor_scalar_type_is_one_of( const auto has_match = tensor_scalar_type_is_one_of(__VA_ARGS__); \ TORCH_CHECK(has_match.empty(), has_match); \ } while (false) + +template +std::string tensors_have_same_scalar_type(const Tensors&... 
tensors) { + std::optional dtype; + bool have_same_type = true; + + ( + [&](const auto& tensor) { + if (!dtype) { + dtype = tensor.scalar_type(); + } else if (*dtype != tensor.scalar_type()) { + have_same_type = false; + } + }(tensors), + ...); + + if (have_same_type) { + return ""; + } + + std::string msg = "Tensors' scalar types ("; + ( + [&](const auto& tensor) { + msg.append(toString(tensor.scalar_type())); + msg.append(", "); + }(tensors), + ...); + msg.append(") are not one and the same!"); + return msg; +} + +#define TENSORS_HAVE_SAME_SCALAR_TYPE(...) \ + do { \ + const auto have_same_type = tensors_have_same_scalar_type(__VA_ARGS__); \ + TORCH_CHECK(have_same_type.empty(), have_same_type); \ + } while (false) From a44317d75218fa14b8d52ef1932a779346552a2d Mon Sep 17 00:00:00 2001 From: Benson Ma Date: Tue, 1 Oct 2024 11:51:24 -0700 Subject: [PATCH 22/48] Add support for int64_t indices and offsets in TBE inference [3/N] (#3124) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3124 X-link: https://github.com/facebookresearch/FBGEMM/pull/213 - Convert `pruned_hashmap_lookup_cuda` to use index_t Reviewed By: jianyuh Differential Revision: D62277673 fbshipit-source-id: 1f49b5df39cd1406eb7d2669acdd7f6a8cc2e3f9 --- ...mbedding_forward_quantized_split_lookup.cu | 68 +++++++++++-------- .../embedding_forward_template_helpers.cuh | 12 ++++ 2 files changed, 52 insertions(+), 28 deletions(-) diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_lookup.cu b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_lookup.cu index 52f2a49dd..86165bb39 100644 --- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_lookup.cu +++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_lookup.cu @@ -14,19 +14,20 @@ using Tensor = at::Tensor; namespace nbit { +template __global__ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel( - const pta::PackedTensorAccessor32 + const pta::PackedTensorAccessor32 indices, - const pta::PackedTensorAccessor32 + const pta::PackedTensorAccessor32 offsets, - const pta::PackedTensorAccessor64 + const pta::PackedTensorAccessor64 hash_table, const pta::PackedTensorAccessor32 hash_table_offsets, const int32_t B, const int32_t T, - pta::PackedTensorAccessor32 + pta::PackedTensorAccessor32 dense_indices) { // uint32_t capacity = hash_table.size(0); const int32_t b_t = blockIdx.x * blockDim.y + threadIdx.y; @@ -35,9 +36,9 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pru if (b_t >= B * T) { return; } - const int32_t indices_start = offsets[t * B + b]; - const int32_t indices_end = offsets[t * B + b + 1]; - const int32_t L = indices_end - indices_start; + const index_t indices_start = offsets[t * B + b]; + const index_t indices_end = offsets[t * B + b + 1]; + const index_t L = indices_end - indices_start; const int64_t table_start = hash_table_offsets[t]; const int64_t table_end = hash_table_offsets[t + 1]; @@ -51,6 +52,9 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pru return; } + using hash_t = + std::conditional_t, uint64_t, uint32_t>; + const uint32_t subwarp_id = threadIdx.x / 4; const uint32_t subwarp_tid = threadIdx.x % 4; #ifdef USE_ROCM @@ -58,13 +62,15 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pru #else const uint32_t subwarp_mask = static_cast(0xF) << (4 * subwarp_id); #endif + for (int32_t l_start = 0; l_start + subwarp_id < L; 
l_start += kWarpSize / 4) { - const int32_t idx = indices[indices_start + l_start + subwarp_id]; - uint32_t slot_start = - pruned_hash_function(static_cast(idx)) % capacity; + const index_t idx = indices[indices_start + l_start + subwarp_id]; + hash_t slot_start = + pruned_hash_function(static_cast(idx)) % capacity; + while (true) { - const uint32_t slot = (slot_start + subwarp_tid) % capacity; + const hash_t slot = (slot_start + subwarp_tid) % capacity; const int2 val = *reinterpret_cast( &hash_table[table_start + static_cast(slot)][0]); const int32_t slot_sparse_idx = val.x; @@ -78,6 +84,7 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pru found = true; dense_indices[indices_start + l_start + subwarp_id] = slot_dense_idx; } + if (__any_sync(subwarp_mask, found)) { break; } else if (__any_sync(subwarp_mask, empty)) { @@ -133,6 +140,8 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pru } // namespace nbit +using namespace nbit; + Tensor pruned_hashmap_lookup_cuda( Tensor indices, Tensor offsets, @@ -140,6 +149,7 @@ Tensor pruned_hashmap_lookup_cuda( Tensor hash_table_offsets) { TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL( indices, offsets, hash_table, hash_table_offsets); + TENSORS_HAVE_SAME_SCALAR_TYPE(indices, offsets, hash_table); CUDA_DEVICE_GUARD(indices); @@ -150,23 +160,25 @@ Tensor pruned_hashmap_lookup_cuda( TORCH_CHECK(hash_table.size(0) < std::numeric_limits::max()); constexpr size_t kForwardMaxThreads = 256; + AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_hashmap_lookup", [&] { #ifdef FBGEMM_GPU_MEMCHECK - const auto func_name = - "int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel"; + const auto func_name = + "int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel"; #endif - nbit::int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel<<< - nbit::div_round_up(B * T + 1, kForwardMaxThreads / kWarpSize), - dim3(kWarpSize, kForwardMaxThreads / kWarpSize), - 0, - at::cuda::getCurrentCUDAStream()>>>( - MAKE_PTA_WITH_NAME(func_name, indices, int32_t, 1, 32), - MAKE_PTA_WITH_NAME(func_name, offsets, int32_t, 1, 32), - MAKE_PTA_WITH_NAME(func_name, hash_table, int32_t, 2, 64), - MAKE_PTA_WITH_NAME(func_name, hash_table_offsets, int64_t, 1, 32), - B, - T, - MAKE_PTA_WITH_NAME(func_name, dense_indices, int32_t, 1, 32)); + int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel<<< + nbit::div_round_up(B * T + 1, kForwardMaxThreads / kWarpSize), + dim3(kWarpSize, kForwardMaxThreads / kWarpSize), + 0, + at::cuda::getCurrentCUDAStream()>>>( + MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32), + MAKE_PTA_WITH_NAME(func_name, offsets, index_t, 1, 32), + MAKE_PTA_WITH_NAME(func_name, hash_table, index_t, 2, 64), + MAKE_PTA_WITH_NAME(func_name, hash_table_offsets, int64_t, 1, 32), + B, + T, + MAKE_PTA_WITH_NAME(func_name, dense_indices, index_t, 1, 32)); + }); C10_CUDA_KERNEL_LAUNCH_CHECK(); return dense_indices; @@ -209,10 +221,10 @@ Tensor pruned_array_lookup_cuda( AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_array_lookup", [&] { #ifdef FBGEMM_GPU_MEMCHECK const auto func_name = - "int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel"; + "int_nbit_split_embedding_codegen_forward_pruned_array_lookup_kernel"; #endif - nbit::int_nbit_split_embedding_codegen_forward_pruned_array_lookup_kernel<<< + int_nbit_split_embedding_codegen_forward_pruned_array_lookup_kernel<<< nbit::div_round_up(offsets.size(0), kForwardMaxThreads / kWarpSize), 
dim3(kWarpSize, kForwardMaxThreads / kWarpSize), 0, @@ -224,8 +236,8 @@ Tensor pruned_array_lookup_cuda( B, T, MAKE_PTA_WITH_NAME(func_name, dense_indices, index_t, 1, 32)); - C10_CUDA_KERNEL_LAUNCH_CHECK(); }); + C10_CUDA_KERNEL_LAUNCH_CHECK(); return dense_indices; } diff --git a/fbgemm_gpu/include/fbgemm_gpu/embedding_forward_template_helpers.cuh b/fbgemm_gpu/include/fbgemm_gpu/embedding_forward_template_helpers.cuh index 97353e03c..2164afd3e 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/embedding_forward_template_helpers.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/embedding_forward_template_helpers.cuh @@ -88,6 +88,7 @@ __device__ inline int32_t padded_D( __device__ inline uint32_t pruned_hash_function(uint32_t h) { // MurmorHash3 32-bit mixing function. + // https://github.com/aappleby/smhasher/blob/master/src/MurmurHash3.cpp h ^= h >> 16; h *= 0x85ebca6b; h ^= h >> 13; @@ -96,6 +97,17 @@ __device__ inline uint32_t pruned_hash_function(uint32_t h) { return h; } +__device__ inline uint64_t pruned_hash_function(uint64_t k) { + // MurmorHash3 64-bit mixing function. + // https://github.com/aappleby/smhasher/blob/master/src/MurmurHash3.cpp + k ^= k >> 33; + k *= (0xff51afd7ed558ccd); + k ^= k >> 33; + k *= (0xc4ceb9fe1a85ec53); + k ^= k >> 33; + return k; +} + // ---------------------- START cp.async helpers, copied from CUTLASS /// CUTLASS helper to get SMEM pointer From b11adf60ab79183a50aa1c27fb4b76076296acd7 Mon Sep 17 00:00:00 2001 From: Benson Ma Date: Tue, 1 Oct 2024 12:26:03 -0700 Subject: [PATCH 23/48] Add support for int64_t indices and offsets in TBE inference [4/N] (#3128) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3128 X-link: https://github.com/facebookresearch/FBGEMM/pull/215 - Convert `pruned_array_lookup_cpu` to use `index_t` Reviewed By: jianyuh Differential Revision: D62470736 fbshipit-source-id: 759efa0ddd08059018ec6b429672ba267c98f3de --- ...bedding_forward_quantized_cpu_template.cpp | 57 +++++++++++-------- 1 file changed, 33 insertions(+), 24 deletions(-) diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_cpu_template.cpp b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_cpu_template.cpp index 2b126c96d..55f37eb16 100644 --- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_cpu_template.cpp +++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_cpu_template.cpp @@ -456,6 +456,7 @@ for (const auto l : c10::irange(L)) { } {% if not weighted %} + Tensor pruned_array_lookup_cpu( Tensor indices, Tensor offsets, @@ -469,33 +470,41 @@ Tensor pruned_array_lookup_cpu( int32_t T = index_remappings_offsets.size(0) - 1; int32_t B = (offsets.size(0) - 1) / T; TORCH_CHECK(B > 0); + auto dense_indices = empty_like(indices); - const auto* indices_acc = indices.data_ptr(); - auto* dense_indices_acc = dense_indices.data_ptr(); - const auto* offsets_acc = offsets.data_ptr(); - const auto index_remappings_acc = index_remappings.data_ptr(); - const auto index_remappings_offsets_acc = index_remappings_offsets.data_ptr(); - at::parallel_for(0, T, 1, [&](int64_t begin, int64_t end) { - for (const auto t : c10::irange(begin, end)) { - int64_t index_remappings_start = index_remappings_offsets_acc[t]; - int64_t index_remappings_end = index_remappings_offsets_acc[t + 1]; - int64_t capacity = index_remappings_end - index_remappings_start; - int32_t indices_start = offsets_acc[t * B]; - int32_t indices_end = offsets_acc[(t + 1) * B]; - if (capacity > 0) { - for (const auto i : c10::irange(indices_start,indices_end)) { - int32_t 
idx = indices_acc[i]; - dense_indices_acc[i] = index_remappings_acc[index_remappings_start + idx]; - } - } else { - std::memcpy( - dense_indices_acc + indices_start, - indices_acc + indices_start, - (indices_end - indices_start) * sizeof(int32_t)); - } - } + AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_array_lookup_cpu", [&] { + const auto* indices_acc = indices.data_ptr(); + auto* dense_indices_acc = dense_indices.data_ptr(); + const auto* offsets_acc = offsets.data_ptr(); + + const auto index_remappings_acc = index_remappings.data_ptr(); + const auto index_remappings_offsets_acc = index_remappings_offsets.data_ptr(); + + at::parallel_for(0, T, 1, [&](int64_t begin, int64_t end) { + for (const auto t : c10::irange(begin, end)) { + const auto index_remappings_start = index_remappings_offsets_acc[t]; + const auto index_remappings_end = index_remappings_offsets_acc[t + 1]; + const auto capacity = index_remappings_end - index_remappings_start; + + const auto indices_start = offsets_acc[t * B]; + const auto indices_end = offsets_acc[(t + 1) * B]; + + if (capacity > 0) { + for (const auto i : c10::irange(indices_start, indices_end)) { + auto idx = indices_acc[i]; + dense_indices_acc[i] = index_remappings_acc[index_remappings_start + idx]; + } + } else { + std::memcpy( + dense_indices_acc + indices_start, + indices_acc + indices_start, + (indices_end - indices_start) * sizeof(index_t)); + } + } + }); }); + return dense_indices; } From 8ca42646075f1a80acbf88fba6e104576d66d95f Mon Sep 17 00:00:00 2001 From: Bert Maher Date: Tue, 1 Oct 2024 14:49:45 -0700 Subject: [PATCH 24/48] Print TMA benchmark info to stderr (#3202) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/301 Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3202 Printing warnings to stdout mucks up the output of various tools/benchmarks Reviewed By: xuzhao9, htyu Differential Revision: D63643615 fbshipit-source-id: 1f34508a7fd36f5aa421e11bddd5ce77fc13038a --- fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py index 9df479e45..330ab51b8 100644 --- a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +++ b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py @@ -6,6 +6,7 @@ # pyre-unsafe import logging +import sys from typing import List, Optional, Tuple, Union import torch @@ -874,10 +875,14 @@ def _kernel_matmul_fp8_row_tma_persistent( if HAS_TMA_DESC: print( - "TMA benchmarks will be running with experimental grid constant TMA descriptor." 
+ "TMA benchmarks will be running with experimental grid constant TMA descriptor.", + file=sys.stderr, ) else: - print("TMA benchmarks will be running without grid constant TMA descriptor.") + print( + "TMA benchmarks will be running without grid constant TMA descriptor.", + file=sys.stderr, + ) class TmaAutoTuneHelper: From 92d18a9d719fe812c461acff398a71df6cc909c1 Mon Sep 17 00:00:00 2001 From: Supadchaya Puangpontip Date: Tue, 1 Oct 2024 15:08:53 -0700 Subject: [PATCH 25/48] Add split_embeddings_utils_cpu (#3205) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/304 Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3205 Add split_embeddings_utils_cpu to enable adjust_info_num_bits and generate_vbe_metadata on CPU build Reviewed By: q10 Differential Revision: D63711688 fbshipit-source-id: e0384c2d1ddd89f175356a14d5d83e74698522f9 --- fbgemm_gpu/FbgemmGpu.cmake | 1 + .../get_infos_metadata.cu | 46 ------- .../split_embeddings_utils.cpp | 46 ------- .../split_embeddings_utils_cpu.cpp | 119 ++++++++++++++++++ 4 files changed, 120 insertions(+), 92 deletions(-) create mode 100644 fbgemm_gpu/src/split_embeddings_utils/split_embeddings_utils_cpu.cpp diff --git a/fbgemm_gpu/FbgemmGpu.cmake b/fbgemm_gpu/FbgemmGpu.cmake index ccf4805cd..dd23dca85 100644 --- a/fbgemm_gpu/FbgemmGpu.cmake +++ b/fbgemm_gpu/FbgemmGpu.cmake @@ -480,6 +480,7 @@ set(fbgemm_gpu_sources_static_cpu src/split_embeddings_cache/lru_cache_populate_byte.cpp src/split_embeddings_cache/lxu_cache.cpp src/split_embeddings_cache/split_embeddings_cache_ops.cpp + src/split_embeddings_utils/split_embeddings_utils_cpu.cpp codegen/training/index_select/batch_index_select_dim0_ops.cpp codegen/training/index_select/batch_index_select_dim0_cpu_host.cpp) diff --git a/fbgemm_gpu/src/split_embeddings_utils/get_infos_metadata.cu b/fbgemm_gpu/src/split_embeddings_utils/get_infos_metadata.cu index c3eb40819..a4efd4c21 100644 --- a/fbgemm_gpu/src/split_embeddings_utils/get_infos_metadata.cu +++ b/fbgemm_gpu/src/split_embeddings_utils/get_infos_metadata.cu @@ -13,52 +13,6 @@ using Tensor = at::Tensor; using namespace fbgemm_gpu; -DLL_PUBLIC std::tuple adjust_info_B_num_bits( - int32_t B, - int32_t T) { - int32_t info_B_num_bits = DEFAULT_INFO_B_NUM_BITS; - uint32_t info_B_mask = DEFAULT_INFO_B_MASK; - uint32_t max_T = MAX_T; - uint32_t max_B = MAX_B; - bool invalid_T = T > max_T; - bool invalid_B = B > max_B; - - TORCH_CHECK( - !(invalid_T && invalid_B), - "Not enough infos bits to accommodate T and B. Default num bits = ", - DEFAULT_INFO_NUM_BITS); - - if (invalid_T) { - // Reduce info_B_num_bits - while (invalid_T && !invalid_B && info_B_num_bits > 0) { - info_B_num_bits--; - max_T = ((max_T + 1) << 1) - 1; - max_B = ((max_B + 1) >> 1) - 1; - invalid_T = T > max_T; - invalid_B = B > max_B; - } - } else if (invalid_B) { - // Increase info_B_num_bits - while (!invalid_T && invalid_B && info_B_num_bits < DEFAULT_INFO_NUM_BITS) { - info_B_num_bits++; - max_T = ((max_T + 1) >> 1) - 1; - max_B = ((max_B + 1) << 1) - 1; - invalid_T = T > max_T; - invalid_B = B > max_B; - } - } - - TORCH_CHECK( - !invalid_T && !invalid_B, - "Not enough infos bits to accommodate T and B. 
Default num bits = ", - DEFAULT_INFO_NUM_BITS); - - // Recompute info_B_mask using new info_B_num_bits - info_B_mask = (1u << info_B_num_bits) - 1; - - return {info_B_num_bits, info_B_mask}; -} - DLL_PUBLIC std::tuple get_infos_metadata(Tensor unused, int64_t B, int64_t T) { return adjust_info_B_num_bits(B, T); diff --git a/fbgemm_gpu/src/split_embeddings_utils/split_embeddings_utils.cpp b/fbgemm_gpu/src/split_embeddings_utils/split_embeddings_utils.cpp index 4ae9ae0f7..8902e1c44 100644 --- a/fbgemm_gpu/src/split_embeddings_utils/split_embeddings_utils.cpp +++ b/fbgemm_gpu/src/split_embeddings_utils/split_embeddings_utils.cpp @@ -33,58 +33,12 @@ generate_vbe_metadata_meta( return {row_output_offsets, b_t_map}; } -std::tuple -generate_vbe_metadata_cpu( - const Tensor& B_offsets, - const Tensor& B_offsets_rank_per_feature, - const Tensor& output_offsets_feature_rank, - const Tensor& D_offsets, - const int64_t D, - const bool nobag, - const c10::SymInt max_B_feature_rank, - const int64_t info_B_num_bits, - const c10::SymInt total_B) { - Tensor row_output_offsets = output_offsets_feature_rank; - Tensor b_t_map = B_offsets_rank_per_feature; - return {row_output_offsets, b_t_map}; -} - } // namespace TORCH_LIBRARY_FRAGMENT(fbgemm, m) { - m.def( - "transpose_embedding_input(" - " Tensor hash_size_cumsum, " - " int total_hash_size_bits, " - " Tensor indices, " - " Tensor offsets, " - " bool nobag=False, " - " Tensor? vbe_b_t_map=None, " - " int info_B_num_bits=26, " - " int info_B_mask=0x2FFFFFF, " - " int total_unique_indices=-1, " - " bool is_index_select=False, " - " Tensor? total_L_offsets=None, " - " int fixed_L_per_warp=0, " - " int num_warps_per_feature=0" - ") -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)"); - m.def("get_infos_metadata(Tensor unused, int B, int T) -> (int, int)"); - m.def( - "generate_vbe_metadata(" - " Tensor B_offsets, " - " Tensor B_offsets_rank_per_feature, " - " Tensor output_offsets_feature_rank, " - " Tensor D_offsets, " - " int D, " - " bool nobag, " - " SymInt max_B_feature_rank, " - " int info_B_num_bits, " - " SymInt total_B" - ") -> (Tensor, Tensor)"); DISPATCH_TO_CUDA("transpose_embedding_input", transpose_embedding_input); DISPATCH_TO_CUDA("get_infos_metadata", get_infos_metadata); DISPATCH_TO_CUDA("generate_vbe_metadata", generate_vbe_metadata); - DISPATCH_TO_CPU("generate_vbe_metadata", generate_vbe_metadata_cpu); } TORCH_LIBRARY_IMPL(fbgemm, Meta, m) { diff --git a/fbgemm_gpu/src/split_embeddings_utils/split_embeddings_utils_cpu.cpp b/fbgemm_gpu/src/split_embeddings_utils/split_embeddings_utils_cpu.cpp new file mode 100644 index 000000000..654a3c3ed --- /dev/null +++ b/fbgemm_gpu/src/split_embeddings_utils/split_embeddings_utils_cpu.cpp @@ -0,0 +1,119 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include "fbgemm_gpu/split_embeddings_utils.h" +#include "fbgemm_gpu/utils/ops_utils.h" + +using Tensor = at::Tensor; + +DLL_PUBLIC std::tuple adjust_info_B_num_bits( + int32_t B, + int32_t T) { + int32_t info_B_num_bits = DEFAULT_INFO_B_NUM_BITS; + uint32_t info_B_mask = DEFAULT_INFO_B_MASK; + uint32_t max_T = MAX_T; + uint32_t max_B = MAX_B; + bool invalid_T = T > max_T; + bool invalid_B = B > max_B; + + TORCH_CHECK( + !(invalid_T && invalid_B), + "Not enough infos bits to accommodate T and B. 
Default num bits = ", + DEFAULT_INFO_NUM_BITS); + + if (invalid_T) { + // Reduce info_B_num_bits + while (invalid_T && !invalid_B && info_B_num_bits > 0) { + info_B_num_bits--; + max_T = ((max_T + 1) << 1) - 1; + max_B = ((max_B + 1) >> 1) - 1; + invalid_T = T > max_T; + invalid_B = B > max_B; + } + } else if (invalid_B) { + // Increase info_B_num_bits + while (!invalid_T && invalid_B && info_B_num_bits < DEFAULT_INFO_NUM_BITS) { + info_B_num_bits++; + max_T = ((max_T + 1) >> 1) - 1; + max_B = ((max_B + 1) << 1) - 1; + invalid_T = T > max_T; + invalid_B = B > max_B; + } + } + + TORCH_CHECK( + !invalid_T && !invalid_B, + "Not enough infos bits to accommodate T and B. Default num bits = ", + DEFAULT_INFO_NUM_BITS); + + // Recompute info_B_mask using new info_B_num_bits + info_B_mask = (1u << info_B_num_bits) - 1; + + return {info_B_num_bits, info_B_mask}; +} + +namespace { + +std::tuple +generate_vbe_metadata_cpu( + const Tensor& B_offsets, + const Tensor& B_offsets_rank_per_feature, + const Tensor& output_offsets_feature_rank, + const Tensor& D_offsets, + const int64_t D, + const bool nobag, + const c10::SymInt max_B_feature_rank, + const int64_t info_B_num_bits, + const c10::SymInt total_B) { + Tensor row_output_offsets = output_offsets_feature_rank; + Tensor b_t_map = B_offsets_rank_per_feature; + return {row_output_offsets, b_t_map}; +} + +std::tuple +get_infos_metadata_cpu(Tensor unused, int64_t B, int64_t T) { + return adjust_info_B_num_bits(B, T); +} + +} // namespace + +TORCH_LIBRARY_FRAGMENT(fbgemm, m) { + m.def( + "transpose_embedding_input(" + " Tensor hash_size_cumsum, " + " int total_hash_size_bits, " + " Tensor indices, " + " Tensor offsets, " + " bool nobag=False, " + " Tensor? vbe_b_t_map=None, " + " int info_B_num_bits=26, " + " int info_B_mask=0x2FFFFFF, " + " int total_unique_indices=-1, " + " bool is_index_select=False, " + " Tensor? 
total_L_offsets=None, " + " int fixed_L_per_warp=0, " + " int num_warps_per_feature=0" + ") -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)"); + m.def("get_infos_metadata(Tensor unused, int B, int T) -> (int, int)"); + m.def( + "generate_vbe_metadata(" + " Tensor B_offsets, " + " Tensor B_offsets_rank_per_feature, " + " Tensor output_offsets_feature_rank, " + " Tensor D_offsets, " + " int D, " + " bool nobag, " + " SymInt max_B_feature_rank, " + " int info_B_num_bits, " + " SymInt total_B" + ") -> (Tensor, Tensor)"); + DISPATCH_TO_CPU("generate_vbe_metadata", generate_vbe_metadata_cpu); + DISPATCH_TO_CPU("get_infos_metadata", get_infos_metadata_cpu); +} From 6f1e7b620aac02b26417af0d97809ca6b97974e0 Mon Sep 17 00:00:00 2001 From: Benson Ma Date: Tue, 1 Oct 2024 15:15:01 -0700 Subject: [PATCH 26/48] Add support for int64_t indices and offsets in TBE inference [5/N] (#3129) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3129 X-link: https://github.com/facebookresearch/FBGEMM/pull/216 - Convert `pruned_hashmap_lookup_cpu` to use `index_t` Reviewed By: spcyppt Differential Revision: D62472965 fbshipit-source-id: c1a9e8bc3796d253bc7f8af73b8ab21821efe72d --- ...bedding_forward_quantized_cpu_template.cpp | 96 ++++++++++++------- ...mbedding_forward_quantized_split_lookup.cu | 4 +- 2 files changed, 62 insertions(+), 38 deletions(-) diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_cpu_template.cpp b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_cpu_template.cpp index 55f37eb16..92eff015f 100644 --- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_cpu_template.cpp +++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_cpu_template.cpp @@ -41,6 +41,16 @@ inline uint32_t pruned_hash_function(uint32_t h) { return h; } +inline uint64_t pruned_hash_function(uint64_t k) { + // MurmorHash3 64-bit mixing function. 
+ k ^= k >> 33; + k *= (0xff51afd7ed558ccd); + k ^= k >> 33; + k *= (0xc4ceb9fe1a85ec53); + k ^= k >> 33; + return k; +} + } // namespace void pruned_hashmap_insert_{{ wdesc }}_cpu( @@ -404,54 +414,67 @@ Tensor pruned_hashmap_lookup_{{ wdesc }}_cpu( TENSOR_ON_CPU(offsets); TENSOR_ON_CPU(hash_table); TENSOR_ON_CPU(hash_table_offsets); + TENSORS_HAVE_SAME_SCALAR_TYPE(indices, offsets, hash_table); int32_t T = hash_table_offsets.size(0) - 1; int32_t B = (offsets.size(0) - 1) / T; TORCH_CHECK(B > 0); + auto dense_indices = empty_like(indices); - const auto* indices_acc = indices.data_ptr(); - auto* dense_indices_acc = dense_indices.data_ptr(); - const auto* offsets_acc = offsets.data_ptr(); - const auto hash_table_acc = hash_table.accessor(); - const auto hash_table_offsets_acc = hash_table_offsets.accessor(); -for (const auto t : c10::irange(T)) { - int64_t table_start = hash_table_offsets_acc[t]; - int64_t table_end = hash_table_offsets_acc[t + 1]; - int64_t capacity = table_end - table_start; -for (const auto b : c10::irange(B)) { - int32_t indices_start = offsets_acc[t * B + b]; - int32_t indices_end = offsets_acc[t * B + b + 1]; - int32_t L = indices_end - indices_start; + AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_hashmap_lookup_{{ wdesc }}_cpu", [&] { + using hash_t = + std::conditional_t, uint64_t, uint32_t>; - if (table_start == table_end) { -for (const auto l : c10::irange(L)) { - dense_indices_acc[indices_start + l] = indices_acc[indices_start + l]; - } - } else { -for (const auto l : c10::irange(L)) { - int32_t idx = indices_acc[indices_start + l]; - uint32_t slot = pruned_hash_function(static_cast(idx)) % capacity; - while (true) { - int32_t slot_sparse_idx = hash_table_acc[table_start + static_cast(slot)][0]; - - // empty slot - if (slot_sparse_idx == -1) { - dense_indices_acc[indices_start + l] = -1; - break; - } - // already exists - if (slot_sparse_idx == idx) { - dense_indices_acc[indices_start + l] = hash_table_acc[table_start + static_cast(slot)][1]; - break; + const auto* indices_acc = indices.data_ptr(); + auto* dense_indices_acc = dense_indices.data_ptr(); + + const auto* offsets_acc = offsets.data_ptr(); + const auto hash_table_acc = hash_table.accessor(); + const auto hash_table_offsets_acc = hash_table_offsets.accessor(); + + for (const auto t : c10::irange(T)) { + const auto table_start = hash_table_offsets_acc[t]; + const auto table_end = hash_table_offsets_acc[t + 1]; + const auto capacity = table_end - table_start; + + for (const auto b : c10::irange(B)) { + const auto indices_start = offsets_acc[t * B + b]; + const auto indices_end = offsets_acc[t * B + b + 1]; + const auto L = indices_end - indices_start; + + if (table_start == table_end) { + for (const auto l : c10::irange(L)) { + dense_indices_acc[indices_start + l] = indices_acc[indices_start + l]; + } + + } else { + for (const auto l : c10::irange(L)) { + const auto idx = indices_acc[indices_start + l]; + auto slot = pruned_hash_function(static_cast(idx)) % capacity; + + while (true) { + const auto slot_sparse_idx = hash_table_acc[table_start + static_cast(slot)][0]; + + // empty slot + if (slot_sparse_idx == -1) { + dense_indices_acc[indices_start + l] = -1; + break; + } + // already exists + if (slot_sparse_idx == idx) { + dense_indices_acc[indices_start + l] = hash_table_acc[table_start + static_cast(slot)][1]; + break; + } + // linear probe + slot = (slot + 1) % capacity; } - // linear probe - slot = (slot + 1) % capacity; } } } } - } + }); + return dense_indices; } @@ -466,6 +489,7 @@ Tensor 
pruned_array_lookup_cpu( TENSOR_ON_CPU(offsets); TENSOR_ON_CPU(index_remappings); TENSOR_ON_CPU(index_remappings_offsets); + TENSORS_HAVE_SAME_SCALAR_TYPE(indices, offsets, index_remappings); int32_t T = index_remappings_offsets.size(0) - 1; int32_t B = (offsets.size(0) - 1) / T; diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_lookup.cu b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_lookup.cu index 86165bb39..846cd4763 100644 --- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_lookup.cu +++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_lookup.cu @@ -160,7 +160,7 @@ Tensor pruned_hashmap_lookup_cuda( TORCH_CHECK(hash_table.size(0) < std::numeric_limits::max()); constexpr size_t kForwardMaxThreads = 256; - AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_hashmap_lookup", [&] { + AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_hashmap_lookup_cuda", [&] { #ifdef FBGEMM_GPU_MEMCHECK const auto func_name = "int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel"; @@ -218,7 +218,7 @@ Tensor pruned_array_lookup_cuda( TORCH_CHECK(dense_indices.dim() == 1, "Tensor dim: ", dense_indices.dim()); constexpr size_t kForwardMaxThreads = 256; - AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_array_lookup", [&] { + AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_array_lookup_cuda", [&] { #ifdef FBGEMM_GPU_MEMCHECK const auto func_name = "int_nbit_split_embedding_codegen_forward_pruned_array_lookup_kernel"; From a639dc0f8f12de77744ca82410b6d4c805a7caf6 Mon Sep 17 00:00:00 2001 From: Jun Luo Date: Tue, 1 Oct 2024 15:52:23 -0700 Subject: [PATCH 27/48] Refactor the GIS to reuse same autograd function for all backends (#3200) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/298 Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3200 This diff refactors the GIS impl to share the AutogradFunc for all device backends. 
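For reference, the operator's semantics are a per-tensor `at::index_select` along dim 0; the following minimal sketch mirrors the `group_index_select_dim0_decomposed` fallback added in this diff (the free-standing function name here is hypothetical, used only for illustration):
```
#include <ATen/ATen.h>
#include <vector>

// Reference semantics only: output[i] = input_group[i] gathered along dim 0 by
// indices_group[i]. Device backends provide a fused kernel with the same
// contract and register it behind the two ops listed next.
std::vector<at::Tensor> group_index_select_dim0_reference(
    at::TensorList input_group,
    at::TensorList indices_group) {
  TORCH_CHECK(input_group.size() == indices_group.size());
  std::vector<at::Tensor> outputs;
  outputs.reserve(input_group.size());
  for (size_t i = 0; i < input_group.size(); ++i) {
    outputs.push_back(at::index_select(input_group[i], 0, indices_group[i]));
  }
  return outputs;
}
```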
Device backends only need to impl and register following two ops ``` torch::autograd::variable_list group_index_select_dim0_forward_impl_cpu( at::TensorList all_indices_input, const int64_t group_size) { throw std::runtime_error( "group_index_select_dim0_forward_impl is not implemented for CPU"); } torch::autograd::variable_list group_index_select_dim0_backward_impl_cpu( at::TensorList all_inputs, c10::SymIntArrayRef output_shape_group_ref) { throw std::runtime_error( "group_index_select_dim0_backward_impl is not implemented for CPU"); } ``` ``` DISPATCH_TO_CUDA( "group_index_select_dim0_forward_impl", fbgemm_gpu::forward_impl); DISPATCH_TO_CUDA( "group_index_select_dim0_backward_impl", fbgemm_gpu::backward_impl); ``` Reviewed By: spcyppt Differential Revision: D63335491 fbshipit-source-id: 8eb4be2c8bdcdad3bca4b4e157913a4fac93387f --- fbgemm_gpu/fbgemm_gpu/sparse_ops.py | 4 +- fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h | 104 +++ fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp | 122 ++- fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp | 774 ++++++++----------- 4 files changed, 521 insertions(+), 483 deletions(-) diff --git a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py index de9a21ef9..3234d6668 100644 --- a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py +++ b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py @@ -1075,11 +1075,11 @@ def impl_autograd(op_name, fn, setup_context: Optional[Callable] = None) -> None ) impl_abstract("fbgemm::bounds_check_indices", bounds_check_indices_abstract) impl_abstract( - "fbgemm::group_index_select_dim0_gpu_impl", + "fbgemm::group_index_select_dim0_forward_impl", group_index_select_dim0_gpu_impl_abstract, ) impl_abstract( - "fbgemm::group_index_select_dim0_gpu_backward", + "fbgemm::group_index_select_dim0_backward_impl", group_index_select_dim0_gpu_backward_abstract, ) impl_abstract( diff --git a/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h b/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h index 9bea430ef..b3607a5b1 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h +++ b/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h @@ -9,8 +9,11 @@ #pragma once #include +#include #include #include +#include + #include namespace fbgemm_gpu { @@ -924,6 +927,107 @@ at::Tensor index_add_with_unique_indices_cuda( const int consecutive_range_start, const int consecutive_range_length); +torch::autograd::variable_list group_index_select_dim0_decomposed( + at::TensorList input_group, + at::TensorList indices_group); + +torch::autograd::variable_list group_index_select_dim0_autograd_impl( + at::TensorList all_indices_input, + const int64_t group_size); + +torch::autograd::variable_list group_index_select_dim0( + at::TensorList input_group, + at::TensorList indices_group); + +torch::autograd::variable_list group_index_select_dim0_forward_impl_cpu( + at::TensorList all_indices_input, + const int64_t group_size); + +torch::autograd::variable_list group_index_select_dim0_backward_impl_cpu( + at::TensorList all_inputs, + c10::SymIntArrayRef output_shape_group_ref); + +std::pair, std::vector> +group_index_select_dim0_unpack( + at::TensorList all_indices_input, + const int64_t group_size); + +class GroupIndexSelectDim0Op + : public torch::autograd::Function { + public: + static torch::autograd::variable_list forward( + torch::autograd::AutogradContext* ctx, + at::TensorList all_indices_input, + const int64_t group_size) { + at::AutoDispatchBelowADInplaceOrView guard; + static auto forward_op = + at::Dispatcher::singleton() + .findSchemaOrThrow( + "fbgemm::group_index_select_dim0_forward_impl", 
"") + .typed(); + auto result = forward_op.call(all_indices_input, group_size); + TORCH_CHECK(static_cast(result.size()) == group_size + 2); + + auto [input_group, indices_group] = + group_index_select_dim0_unpack(all_indices_input, group_size); + const auto input_dim = input_group[0].dim(); + std::vector input_shape_group; + input_shape_group.reserve(group_size * input_dim); + + for (const auto i : c10::irange(group_size)) { + const auto& input = input_group[i]; + // Copy input shape + auto input_shape = input.sym_sizes().vec(); + input_shape_group.insert( + input_shape_group.end(), input_shape.begin(), input_shape.end()); + } + + // save indices, args_tensor, saved_data + auto saved_tensors = std::vector(indices_group); + saved_tensors.insert( + saved_tensors.end(), result.cbegin() + group_size, result.cend()); + saved_tensors.push_back(input_group[0]); + ctx->save_for_backward(saved_tensors); + ctx->saved_data["input_shape_group"] = input_shape_group; + + return result; + } + + static torch::autograd::variable_list backward( + torch::autograd::AutogradContext* ctx, + torch::autograd::variable_list grad_output_group) { + TORCH_CHECK(grad_output_group.size() >= 2); + if (grad_output_group.size() == 2) { + // empty outputs + return torch::autograd::variable_list(1); + } + // remove redundant grads + auto group_size = grad_output_group.size() - 2; + grad_output_group.resize(group_size); + + auto saved_tensors = ctx->get_saved_variables(); + TORCH_CHECK(saved_tensors.size() == group_size + 3); + auto output_shape_group = + ctx->saved_data["input_shape_group"].toSymIntVector(); + grad_output_group.insert( + grad_output_group.end(), saved_tensors.begin(), saved_tensors.end()); + static auto backward_op = + at::Dispatcher::singleton() + .findSchemaOrThrow( + "fbgemm::group_index_select_dim0_backward_impl", "") + .typed(); + auto res = backward_op.call(grad_output_group, output_shape_group); + // 1) Add group_size Variable()'s for indices + // Replace all empty tensors with Variable(). This must be done after the + // op.call to make __torch_dispatch__ work for the backward op. 
+ std::fill( + res.begin(), res.begin() + group_size, torch::autograd::Variable()); + // 3) Add 1 Variable() for group_size + res.push_back({}); + return res; + } +}; + ///@ingroup sparse-data-cuda void group_index_select_or_add_cuda( const int64_t* input_ptrs, diff --git a/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp b/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp index 7734cc69a..01dde3394 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp +++ b/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp @@ -2851,41 +2851,103 @@ Tensor pack_segments_cpu( const int64_t max_length) { return pack_segments_forward_cpu(t_in, lengths, max_length); } -namespace { -Tensor index_select_dim0( - const Tensor& input, - const Tensor& indices, - std::optional /*consecutive_range_start*/, - std::optional /*consecutive_range_length*/, - std::optional /*skip_indices_sorting_fwd*/) { - return at::index_select(input, 0, indices); + +torch::autograd::variable_list group_index_select_dim0_autograd_impl( + at::TensorList all_indices_input, + const int64_t group_size) { + return GroupIndexSelectDim0Op::apply(all_indices_input, group_size); +} + +std::pair, std::vector> +group_index_select_dim0_unpack( + at::TensorList all_indices_input, + const int64_t group_size) { + std::vector indices_group; + std::vector input_group; + + indices_group.reserve(group_size); + input_group.reserve(group_size); + + for (const auto i : c10::irange(group_size)) { + indices_group.push_back(all_indices_input[i]); + input_group.push_back(all_indices_input[group_size + i]); + } + + TORCH_CHECK(group_size == static_cast(indices_group.size())); + + return std::make_pair(input_group, indices_group); } torch::autograd::variable_list group_index_select_dim0( at::TensorList input_group, at::TensorList indices_group) { - int num_groups = input_group.size(); - TORCH_CHECK(num_groups == (int)indices_group.size()) + const auto group_size = indices_group.size(); std::vector output_group; - for (const auto i : c10::irange(num_groups)) { - output_group.push_back( - at::index_select(input_group[i], 0, indices_group[i])); + + if (group_size == 0) { + return std::vector(); } - return output_group; + + // Pack input_group and indices_group into TensorList + std::vector all_indices_input_vec; + all_indices_input_vec.reserve(group_size * 2); + + for (const Tensor& index : indices_group) { + all_indices_input_vec.push_back(index); + } + for (const Tensor& input : input_group) { + all_indices_input_vec.push_back(input); + } + + at::TensorList all_indices_input_tensor = all_indices_input_vec; + + static auto forward_op = + at::Dispatcher::singleton() + .findSchemaOrThrow( + "fbgemm::group_index_select_dim0_autograd_impl", "") + .typed(); + auto res = forward_op.call(all_indices_input_tensor, group_size); + TORCH_CHECK(res.size() == group_size + 2); + // only return the outputs (the first group_size elements) + res.resize(group_size); + return res; } -torch::autograd::variable_list group_index_select_dim0_gpu_impl_cpu( +torch::autograd::variable_list group_index_select_dim0_forward_impl_cpu( at::TensorList all_indices_input, const int64_t group_size) { throw std::runtime_error( - "group_index_select_dim0_gpu_impl is not implemented for CPU"); + "group_index_select_dim0_forward_impl is not implemented for CPU"); } -torch::autograd::variable_list group_index_select_dim0_gpu_backward_cpu( +torch::autograd::variable_list group_index_select_dim0_backward_impl_cpu( at::TensorList all_inputs, c10::SymIntArrayRef output_shape_group_ref) { throw std::runtime_error( - 
"group_index_select_dim0_gpu_backward is not implemented for CPU"); + "group_index_select_dim0_backward_impl is not implemented for CPU"); +} + +torch::autograd::variable_list group_index_select_dim0_decomposed( + at::TensorList input_group, + at::TensorList indices_group) { + int num_groups = input_group.size(); + TORCH_CHECK(num_groups == (int)indices_group.size()) + std::vector output_group; + for (const auto i : c10::irange(num_groups)) { + output_group.push_back( + at::index_select(input_group[i], 0, indices_group[i])); + } + return output_group; +} + +namespace { +Tensor index_select_dim0( + const Tensor& input, + const Tensor& indices, + std::optional /*consecutive_range_start*/, + std::optional /*consecutive_range_length*/, + std::optional /*skip_indices_sorting_fwd*/) { + return at::index_select(input, 0, indices); } Tensor bottom_k_per_row( @@ -3046,9 +3108,11 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { {PT2_COMPLIANT_TAG}); // group_index_select_dim0_gpu helper functions - not defined for CPU! m.def( - "group_index_select_dim0_gpu_impl(Tensor[] inputs, int group_size) -> Tensor[]"); + "group_index_select_dim0_autograd_impl(Tensor[] inputs, int group_size) -> Tensor[]"); + m.def( + "group_index_select_dim0_forward_impl(Tensor[] inputs, int group_size) -> Tensor[]"); m.def( - "group_index_select_dim0_gpu_backward(Tensor[] inputs, SymInt[] output_shape_group) -> Tensor[]"); + "group_index_select_dim0_backward_impl(Tensor[] inputs, SymInt[] output_shape_group) -> Tensor[]"); // This is an one-off op to be used in split_embedding_utils.py for zipf // generation w/o replacement along dim=-1. If requires_unique=True, find // smallest unique k. If the number of unique elements is less than k, @@ -3132,13 +3196,14 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) { "pack_segments_backward", fbgemm_gpu::pack_segments_backward_cpu); DISPATCH_TO_CPU("index_select_dim0", fbgemm_gpu::index_select_dim0); DISPATCH_TO_CPU( - "group_index_select_dim0", fbgemm_gpu::group_index_select_dim0); + "group_index_select_dim0", + fbgemm_gpu::group_index_select_dim0_decomposed); DISPATCH_TO_CPU( - "group_index_select_dim0_gpu_impl", - fbgemm_gpu::group_index_select_dim0_gpu_impl_cpu); + "group_index_select_dim0_forward_impl", + fbgemm_gpu::group_index_select_dim0_forward_impl_cpu); DISPATCH_TO_CPU( - "group_index_select_dim0_gpu_backward", - fbgemm_gpu::group_index_select_dim0_gpu_backward_cpu); + "group_index_select_dim0_backward_impl", + fbgemm_gpu::group_index_select_dim0_backward_impl_cpu); DISPATCH_TO_CPU("bottom_k_per_row", fbgemm_gpu::bottom_k_per_row); } @@ -3147,11 +3212,14 @@ TORCH_LIBRARY_IMPL(fbgemm, Autograd, m) { } TORCH_LIBRARY_IMPL(fbgemm, AutogradCPU, m) { - m.impl("group_index_select_dim0", &fbgemm_gpu::group_index_select_dim0); + m.impl( + "group_index_select_dim0", + &fbgemm_gpu::group_index_select_dim0_decomposed); } TORCH_LIBRARY_IMPL(fbgemm, Meta, m) { // CPU group_index_select_dim0 is decomposable m.impl( - "group_index_select_dim0", TORCH_FN(fbgemm_gpu::group_index_select_dim0)); + "group_index_select_dim0", + TORCH_FN(fbgemm_gpu::group_index_select_dim0_decomposed)); } diff --git a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp index 6325017e8..77b7f7785 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp +++ b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp @@ -193,442 +193,346 @@ class IndexSelectDim0GPUOp } }; -std::pair, std::vector> -group_index_select_dim0_unpack( +// need to combine input_group and indices_group into one tensor list +// to get 
this working with autograd. +static torch::autograd::variable_list forward_impl( at::TensorList all_indices_input, const int64_t group_size) { - std::vector indices_group; - std::vector input_group; + // Unpack from TensorList + auto [input_group, indices_group] = + group_index_select_dim0_unpack(all_indices_input, group_size); + + // args_tensor stores kernel arguments: + // input_ptrs (group_size int64_t elements) + // output_ptrs (group_size int64_t elements) + // indices_ptrs (group_size int64_t elements) + // warp_offsets_group (group_size + 1 int64_t elements) + // num_cols_group (group_size int32_t elements) + int64_t args_ptrs_offsets[NUM_ARGS + 1]; + + const int64_t numels_num_cols_group_64 = + compute_num_int64s(group_size); + + // Initialize offsets + args_ptrs_offsets[P_input_ptrs] = group_size; + args_ptrs_offsets[P_output_ptrs] = group_size; + args_ptrs_offsets[P_indices_ptrs] = group_size; + args_ptrs_offsets[P_warp_offsets_group_ptrs] = group_size + 1; + args_ptrs_offsets[P_num_cols_group_ptrs] = numels_num_cols_group_64; + + // Compute offsets + int64_t offset = 0; + auto next = args_ptrs_offsets[0]; + for (const auto i : c10::irange(NUM_ARGS)) { + args_ptrs_offsets[i] = offset; + offset += next; + next = args_ptrs_offsets[i + 1]; + } + // Total number of int64_t elements required + args_ptrs_offsets[NUM_ARGS] = offset; + + // Allocate memory for GroupIndexSelectArgs + at::Tensor args_tensor = at::empty( + {static_cast(args_ptrs_offsets[NUM_ARGS] * sizeof(int64_t))}, + at::TensorOptions().dtype(at::kByte).pinned_memory(true)); + + // Ensure that args_tensor is contiguous + TORCH_CHECK(args_tensor.is_contiguous()); + + // Initialize raw pointers to point to Tensor args_tensor + int64_t* input_ptrs = nullptr; + int64_t* output_ptrs = nullptr; + int64_t* indices_ptrs = nullptr; + int64_t* warp_offsets_group = nullptr; + int32_t* num_cols_group = nullptr; + + // Offset host pointers + offset_args( + &input_ptrs, + &output_ptrs, + &indices_ptrs, + &warp_offsets_group, + &num_cols_group, + reinterpret_cast(args_tensor.data_ptr()), + args_ptrs_offsets); + + auto& first_input = input_group[0]; + auto& first_indices = indices_group[0]; + + const int input_dim = first_input.dim(); + const int num_output_rows = first_indices.size(0); + const int num_input_rows = first_input.size(0); + Tensor input_reshaped = first_input.reshape({num_input_rows, -1}); + const int num_cols = input_reshaped.size(1); + const int cols_per_warp = get_group_index_select_cols_per_warp(); + int64_t warp_offset = 0; + bool use_var_cols = false; + + // Allocate memory for output_group + std::vector output_group; + output_group.reserve(group_size + 2); - indices_group.reserve(group_size); - input_group.reserve(group_size); + // We need to store contiguous inputs and indices outside the for-loop to + // guarantee that the contiguous tensors will outlive the kernel + // computation + std::vector> input_contigs; + std::vector> index_contigs; + input_contigs.reserve(group_size); + index_contigs.reserve(group_size); + // For each group, copy input to output for (const auto i : c10::irange(group_size)) { - indices_group.push_back(all_indices_input[i]); - input_group.push_back(all_indices_input[group_size + i]); - } + const auto& input = input_group[i]; + const auto& indices = indices_group[i]; - TORCH_CHECK(group_size == static_cast(indices_group.size())); + // Verify that all input tensors have the same number of dimensions + TORCH_CHECK( + input_dim == input.dim(), + "All inputs in group_index_select must have the 
same number of dimensions"); - return std::make_pair(input_group, indices_group); -} + // Verify that all tensors are on the same GPU + TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(input, indices); -class GroupIndexSelectDim0GPUOp - : public torch::autograd::Function { - public: - // need to combine input_group and indices_group into one tensor list - // to get this working with autograd. - static torch::autograd::variable_list forward_impl( - at::TensorList all_indices_input, - const int64_t group_size) { - // Unpack from TensorList - auto [input_group, indices_group] = - group_index_select_dim0_unpack(all_indices_input, group_size); - - // args_tensor stores kernel arguments: - // input_ptrs (group_size int64_t elements) - // output_ptrs (group_size int64_t elements) - // indices_ptrs (group_size int64_t elements) - // warp_offsets_group (group_size + 1 int64_t elements) - // num_cols_group (group_size int32_t elements) - int64_t args_ptrs_offsets[NUM_ARGS + 1]; - - const int64_t numels_num_cols_group_64 = - compute_num_int64s(group_size); - - // Initialize offsets - args_ptrs_offsets[P_input_ptrs] = group_size; - args_ptrs_offsets[P_output_ptrs] = group_size; - args_ptrs_offsets[P_indices_ptrs] = group_size; - args_ptrs_offsets[P_warp_offsets_group_ptrs] = group_size + 1; - args_ptrs_offsets[P_num_cols_group_ptrs] = numels_num_cols_group_64; - - // Compute offsets - int64_t offset = 0; - auto next = args_ptrs_offsets[0]; - for (const auto i : c10::irange(NUM_ARGS)) { - args_ptrs_offsets[i] = offset; - offset += next; - next = args_ptrs_offsets[i + 1]; + auto num_output_rows_ = indices.size(0); + + // Verify that all input tensors have the same shape[0] + TORCH_CHECK( + num_output_rows == num_output_rows_, + "The number of indices to be selected must be the same for the entire group"); + const auto input_reshaped_ = input.reshape({input.size(0), -1}); + + // Number of columns can be different + auto num_cols_ = input_reshaped_.size(1); + auto warps_per_row = (num_cols_ + cols_per_warp - 1) / cols_per_warp; + + if (num_cols != num_cols_) { + use_var_cols = true; } - // Total number of int64_t elements required - args_ptrs_offsets[NUM_ARGS] = offset; - - // Allocate memory for GroupIndexSelectArgs - at::Tensor args_tensor = at::empty( - {static_cast(args_ptrs_offsets[NUM_ARGS] * sizeof(int64_t))}, - at::TensorOptions().dtype(at::kByte).pinned_memory(true)); - - // Ensure that args_tensor is contiguous - TORCH_CHECK(args_tensor.is_contiguous()); - - // Initialize raw pointers to point to Tensor args_tensor - int64_t* input_ptrs = nullptr; - int64_t* output_ptrs = nullptr; - int64_t* indices_ptrs = nullptr; - int64_t* warp_offsets_group = nullptr; - int32_t* num_cols_group = nullptr; - - // Offset host pointers - offset_args( - &input_ptrs, - &output_ptrs, - &indices_ptrs, - &warp_offsets_group, - &num_cols_group, - reinterpret_cast(args_tensor.data_ptr()), - args_ptrs_offsets); - - auto& first_input = input_group[0]; - auto& first_indices = indices_group[0]; - - const int input_dim = first_input.dim(); - const int num_output_rows = first_indices.size(0); - const int num_input_rows = first_input.size(0); - Tensor input_reshaped = first_input.reshape({num_input_rows, -1}); - const int num_cols = input_reshaped.size(1); - const int cols_per_warp = get_group_index_select_cols_per_warp(); - int64_t warp_offset = 0; - bool use_var_cols = false; - - // Allocate memory for output_group - std::vector output_group; - output_group.reserve(group_size + 2); - - // We need to store contiguous inputs and 
indices outside the for-loop to - // guarantee that the contiguous tensors will outlive the kernel + + // Create output pointers + auto input_shape = input.sizes().vec(); + input_shape[0] = num_output_rows_; + Tensor output = at::empty(input_shape, input.options()); + // Ensure that the allocated output is contiguous + TORCH_CHECK(output.is_contiguous()) + output_group.push_back(output); + + // Store input and indices contigs to keep them alive during the kernel // computation - std::vector> input_contigs; - std::vector> index_contigs; - input_contigs.reserve(group_size); - index_contigs.reserve(group_size); - - // For each group, copy input to output - for (const auto i : c10::irange(group_size)) { - const auto& input = input_group[i]; - const auto& indices = indices_group[i]; - - // Verify that all input tensors have the same number of dimensions - TORCH_CHECK( - input_dim == input.dim(), - "All inputs in group_index_select must have the same number of dimensions"); - - // Verify that all tensors are on the same GPU - TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(input, indices); - - auto num_output_rows_ = indices.size(0); - - // Verify that all input tensors have the same shape[0] - TORCH_CHECK( - num_output_rows == num_output_rows_, - "The number of indices to be selected must be the same for the entire group"); - const auto input_reshaped_ = input.reshape({input.size(0), -1}); - - // Number of columns can be different - auto num_cols_ = input_reshaped_.size(1); - auto warps_per_row = (num_cols_ + cols_per_warp - 1) / cols_per_warp; - - if (num_cols != num_cols_) { - use_var_cols = true; - } - - // Create output pointers - auto input_shape = input.sizes().vec(); - input_shape[0] = num_output_rows_; - Tensor output = at::empty(input_shape, input.options()); - // Ensure that the allocated output is contiguous - TORCH_CHECK(output.is_contiguous()) - output_group.push_back(output); - - // Store input and indices contigs to keep them alive during the kernel - // computation - input_contigs.push_back(input.expect_contiguous()); - index_contigs.push_back(indices.expect_contiguous()); - - // Store args - input_ptrs[i] = reinterpret_cast(input_contigs[i]->data_ptr()); - output_ptrs[i] = reinterpret_cast(output.data_ptr()); - indices_ptrs[i] = reinterpret_cast(index_contigs[i]->data_ptr()); - warp_offsets_group[i] = warp_offset; - num_cols_group[i] = num_cols_; - - warp_offset += warps_per_row * num_output_rows; - } + input_contigs.push_back(input.expect_contiguous()); + index_contigs.push_back(indices.expect_contiguous()); - // Store the last offset - warp_offsets_group[group_size] = warp_offset; - - // Transfer args tensor to GPU - args_tensor = args_tensor.to( - first_input.device(), - /*non_blocking=*/true); - - // Offset raw ptrs in GPU memory - offset_args( - &input_ptrs, - &output_ptrs, - &indices_ptrs, - &warp_offsets_group, - &num_cols_group, - reinterpret_cast(args_tensor.data_ptr()), - args_ptrs_offsets); - - int64_t saved_data[] = { - static_cast(group_size), - use_var_cols, - reinterpret_cast(warp_offsets_group), - reinterpret_cast(num_cols_group), - warp_offset, - }; - auto saved_data_t = at::empty( - {sizeof(saved_data) / sizeof(int64_t)}, - at::TensorOptions().dtype(at::kLong)); - TORCH_CHECK(saved_data_t.is_contiguous()); - memcpy(saved_data_t.data_ptr(), saved_data, sizeof(saved_data)); - - group_index_select_or_add_cuda( - input_ptrs, - output_ptrs, - indices_ptrs, - warp_offsets_group, - num_cols_group, - first_input.scalar_type(), - first_indices.scalar_type(), - 
first_input.device().index(), - num_output_rows, - /*total_num_warps=*/warp_offset, - group_size, - /*use_index_select=*/true, - use_var_cols); - - output_group.push_back(args_tensor); - output_group.push_back(saved_data_t); - - // return format: - // (group_size outputs, 1 args_tensor, 1 saved_data) - return output_group; + // Store args + input_ptrs[i] = reinterpret_cast(input_contigs[i]->data_ptr()); + output_ptrs[i] = reinterpret_cast(output.data_ptr()); + indices_ptrs[i] = reinterpret_cast(index_contigs[i]->data_ptr()); + warp_offsets_group[i] = warp_offset; + num_cols_group[i] = num_cols_; + + warp_offset += warps_per_row * num_output_rows; } - static torch::autograd::variable_list forward( - torch::autograd::AutogradContext* ctx, - at::TensorList all_indices_input, - const int64_t group_size) { - at::AutoDispatchBelowADInplaceOrView guard; - static auto forward_op = - torch::Dispatcher::singleton() - .findSchemaOrThrow("fbgemm::group_index_select_dim0_gpu_impl", "") - .typed(); - auto result = forward_op.call(all_indices_input, group_size); - TORCH_CHECK(static_cast(result.size()) == group_size + 2); - - auto [input_group, indices_group] = - group_index_select_dim0_unpack(all_indices_input, group_size); - const auto input_dim = input_group[0].dim(); - std::vector input_shape_group; - input_shape_group.reserve(group_size * input_dim); - - for (const auto i : c10::irange(group_size)) { - const auto& input = input_group[i]; - // Copy input shape - auto input_shape = input.sym_sizes().vec(); - input_shape_group.insert( - input_shape_group.end(), input_shape.begin(), input_shape.end()); - } + // Store the last offset + warp_offsets_group[group_size] = warp_offset; + + // Transfer args tensor to GPU + args_tensor = args_tensor.to( + first_input.device(), + /*non_blocking=*/true); + + // Offset raw ptrs in GPU memory + offset_args( + &input_ptrs, + &output_ptrs, + &indices_ptrs, + &warp_offsets_group, + &num_cols_group, + reinterpret_cast(args_tensor.data_ptr()), + args_ptrs_offsets); + + int64_t saved_data[] = { + static_cast(group_size), + use_var_cols, + reinterpret_cast(warp_offsets_group), + reinterpret_cast(num_cols_group), + warp_offset, + }; + auto saved_data_t = at::empty( + {sizeof(saved_data) / sizeof(int64_t)}, + at::TensorOptions().dtype(at::kLong)); + TORCH_CHECK(saved_data_t.is_contiguous()); + memcpy(saved_data_t.data_ptr(), saved_data, sizeof(saved_data)); + + group_index_select_or_add_cuda( + input_ptrs, + output_ptrs, + indices_ptrs, + warp_offsets_group, + num_cols_group, + first_input.scalar_type(), + first_indices.scalar_type(), + first_input.device().index(), + num_output_rows, + /*total_num_warps=*/warp_offset, + group_size, + /*use_index_select=*/true, + use_var_cols); + + output_group.push_back(args_tensor); + output_group.push_back(saved_data_t); + + // return format: + // (group_size outputs, 1 args_tensor, 1 saved_data) + return output_group; +} - // save indices, args_tensor, saved_data - auto saved_tensors = std::vector(indices_group); - saved_tensors.insert( - saved_tensors.end(), result.cbegin() + group_size, result.cend()); - saved_tensors.push_back(input_group[0]); - ctx->save_for_backward(saved_tensors); - ctx->saved_data["input_shape_group"] = input_shape_group; +static torch::autograd::variable_list backward_impl( + at::TensorList all_inputs, + c10::SymIntArrayRef output_shape_group_ref) { + TORCH_CHECK(all_inputs.size() > 2); + + const int64_t group_size = (all_inputs.size() - 3) / 2; + + Tensor fwd_input = all_inputs[2 * group_size + 2]; + const 
int64_t output_dim = fwd_input.dim(); + Tensor saved_data = all_inputs[2 * group_size + 1]; + Tensor args_tensor_old = all_inputs[2 * group_size]; + Tensor first_indices = all_inputs[group_size]; + + auto grad_output_group = std::vector( + all_inputs.cbegin(), all_inputs.cbegin() + group_size); + std::vector output_shape_group; + output_shape_group.reserve(output_shape_group_ref.size()); + for (const auto& i : output_shape_group_ref) { + output_shape_group.push_back(i.as_int_unchecked()); + } - return result; + auto indices_group = std::vector( + all_inputs.cbegin() + group_size, all_inputs.cbegin() + 2 * group_size); + + // Retrieve saved data + TORCH_CHECK(saved_data.device() == at::kCPU); + TORCH_CHECK(saved_data.is_contiguous()); + int64_t* saved_data_ptr = saved_data.data_ptr(); + // Check that the size is the same + TORCH_CHECK(saved_data_ptr[0] == group_size); + const bool use_var_cols = saved_data_ptr[1]; + int64_t* warp_offsets_group = reinterpret_cast(saved_data_ptr[2]); + int32_t* num_cols_group = reinterpret_cast(saved_data_ptr[3]); + int64_t total_num_warps = saved_data_ptr[4]; + + // We checked in forward that all output rows are the same for all member + // in the group + const int num_input_rows = grad_output_group[0].size(0); + + std::vector outputs; + // Returning 3 outputs: + // 1) group_size Variable()'s for indices + // 2) group_size gradients for inputs + // 3) 1 Variable() for group_size + outputs.reserve(group_size * 2 + 1); + + // 1) Add group_size Variable()'s for indices + // c10::irange cannot be used in here as it + // triggers a build error of i being an unused variable. + // Add empty tensor with zero size here to make __torch_dispatch__ work for + // the backward op. Those empty tensors will be replaced with + // torch::autograd::Variable() outside of the op call. 
+ for (auto i = 0; i < group_size; i++) { + outputs.push_back(at::empty({0}, at::TensorOptions().dtype(at::kLong))); } - static torch::autograd::variable_list backward_impl( - at::TensorList all_inputs, - c10::SymIntArrayRef output_shape_group_ref) { - TORCH_CHECK(all_inputs.size() > 2); - - const int64_t group_size = (all_inputs.size() - 3) / 2; - - Tensor fwd_input = all_inputs[2 * group_size + 2]; - const int64_t output_dim = fwd_input.dim(); - Tensor saved_data = all_inputs[2 * group_size + 1]; - Tensor args_tensor_old = all_inputs[2 * group_size]; - Tensor first_indices = all_inputs[group_size]; - - auto grad_output_group = std::vector( - all_inputs.cbegin(), all_inputs.cbegin() + group_size); - std::vector output_shape_group; - output_shape_group.reserve(output_shape_group_ref.size()); - for (const auto& i : output_shape_group_ref) { - output_shape_group.push_back(i.as_int_unchecked()); - } + // Allocate Tensor for ptrs of grad output and input, and indices + Tensor args_tensor = at::empty( + {group_size * 3}, + at::TensorOptions().dtype(at::kLong).pinned_memory(true)); + // Ensure that args_tensor is contiguous + TORCH_CHECK(args_tensor.is_contiguous()); + int64_t* grad_output_ptrs = args_tensor.data_ptr(); + int64_t* grad_input_ptrs = args_tensor.data_ptr() + group_size; + int64_t* indices_ptrs = args_tensor.data_ptr() + 2 * group_size; + + int64_t group_grad_input_numel = 0; + std::vector grad_input_numels; + grad_input_numels.reserve(group_size); + + // We need to store contiguous gradients outside the for-loop to guarantee + // that the contiguous tensors will outlive the kernel computation + std::vector> grad_output_contigs; + grad_output_contigs.reserve(group_size); - auto indices_group = std::vector( - all_inputs.cbegin() + group_size, all_inputs.cbegin() + 2 * group_size); - - // Retrieve saved data - TORCH_CHECK(saved_data.device() == at::kCPU); - TORCH_CHECK(saved_data.is_contiguous()); - int64_t* saved_data_ptr = saved_data.data_ptr(); - // Check that the size is the same - TORCH_CHECK(saved_data_ptr[0] == group_size); - const bool use_var_cols = saved_data_ptr[1]; - int64_t* warp_offsets_group = reinterpret_cast(saved_data_ptr[2]); - int32_t* num_cols_group = reinterpret_cast(saved_data_ptr[3]); - int64_t total_num_warps = saved_data_ptr[4]; - - // We checked in forward that all output rows are the same for all member - // in the group - const int num_input_rows = grad_output_group[0].size(0); - - std::vector outputs; - // Returning 3 outputs: - // 1) group_size Variable()'s for indices - // 2) group_size gradients for inputs - // 3) 1 Variable() for group_size - outputs.reserve(group_size * 2 + 1); - - // 1) Add group_size Variable()'s for indices - // c10::irange cannot be used in here as it - // triggers a build error of i being an unused variable. - // Add empty tensor with zero size here to make __torch_dispatch__ work for - // the backward op. Those empty tensors will be replaced with - // torch::autograd::Variable() outside of the op call. 
- for (auto i = 0; i < group_size; i++) { - outputs.push_back(at::empty({0}, at::TensorOptions().dtype(at::kLong))); - } + for (const auto i : c10::irange(group_size)) { + const auto& grad = grad_output_group[i]; + TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(grad, first_indices); - // Allocate Tensor for ptrs of grad output and input, and indices - Tensor args_tensor = at::empty( - {group_size * 3}, - at::TensorOptions().dtype(at::kLong).pinned_memory(true)); - // Ensure that args_tensor is contiguous - TORCH_CHECK(args_tensor.is_contiguous()); - int64_t* grad_output_ptrs = args_tensor.data_ptr(); - int64_t* grad_input_ptrs = args_tensor.data_ptr() + group_size; - int64_t* indices_ptrs = args_tensor.data_ptr() + 2 * group_size; - - int64_t group_grad_input_numel = 0; - std::vector grad_input_numels; - grad_input_numels.reserve(group_size); - - // We need to store contiguous gradients outside the for-loop to guarantee - // that the contiguous tensors will outlive the kernel computation - std::vector> grad_output_contigs; - grad_output_contigs.reserve(group_size); - - for (const auto i : c10::irange(group_size)) { - const auto& grad = grad_output_group[i]; - TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(grad, first_indices); - - // Store grad contigs to keep them alive during the kernel computation - grad_output_contigs.push_back(grad.expect_contiguous()); - - // Compute the total number of elements for all grad_inputs - int64_t grad_input_numel = output_shape_group[i * output_dim]; - for (auto j = (i * output_dim) + 1; j < (i + 1) * output_dim; j++) { - grad_input_numel *= output_shape_group[j]; - } - grad_input_numels.push_back(grad_input_numel); - group_grad_input_numel += grad_input_numel; - - // Put all grad output/input pointers in an array - grad_output_ptrs[i] = - reinterpret_cast(grad_output_contigs[i]->data_ptr()); - } + // Store grad contigs to keep them alive during the kernel computation + grad_output_contigs.push_back(grad.expect_contiguous()); - // Allocate a big tensor to avoid calling many small elementwise kernels - const auto group_grad_input = - at::zeros({group_grad_input_numel}, fwd_input.options()); - TORCH_CHECK(group_grad_input.is_contiguous()); + // Compute the total number of elements for all grad_inputs + int64_t grad_input_numel = output_shape_group[i * output_dim]; + for (auto j = (i * output_dim) + 1; j < (i + 1) * output_dim; j++) { + grad_input_numel *= output_shape_group[j]; + } + grad_input_numels.push_back(grad_input_numel); + group_grad_input_numel += grad_input_numel; - // Split to output_group - auto output_group = group_grad_input.split(grad_input_numels, 0); + // Put all grad output/input pointers in an array + grad_output_ptrs[i] = + reinterpret_cast(grad_output_contigs[i]->data_ptr()); + } - TORCH_CHECK(output_group.size() == static_cast(group_size)); + // Allocate a big tensor to avoid calling many small elementwise kernels + const auto group_grad_input = + at::zeros({group_grad_input_numel}, fwd_input.options()); + TORCH_CHECK(group_grad_input.is_contiguous()); - // Reshape grad inputs and obtain their pointers - for (int i = 0; i < group_size; i++) { - const auto grad_input_shape = std::vector( - output_shape_group.begin() + i * output_dim, - output_shape_group.begin() + (i + 1) * output_dim); - output_group[i] = output_group[i].reshape(grad_input_shape); - TORCH_CHECK(output_group[i].is_contiguous()); - grad_input_ptrs[i] = - reinterpret_cast(output_group[i].data_ptr()); + // Split to output_group + auto output_group = 
group_grad_input.split(grad_input_numels, 0); - // 2) Add group_size gradients for inputs - outputs.push_back(output_group[i]); - } + TORCH_CHECK(output_group.size() == static_cast(group_size)); - // Calculate indices_ptrs - std::vector> index_contigs; - index_contigs.reserve(group_size); - for (const auto i : c10::irange(group_size)) { - const auto& indices = indices_group[i]; - index_contigs.push_back(indices.expect_contiguous()); - indices_ptrs[i] = reinterpret_cast(index_contigs[i]->data_ptr()); - } + // Reshape grad inputs and obtain their pointers + for (int i = 0; i < group_size; i++) { + const auto grad_input_shape = std::vector( + output_shape_group.begin() + i * output_dim, + output_shape_group.begin() + (i + 1) * output_dim); + output_group[i] = output_group[i].reshape(grad_input_shape); + TORCH_CHECK(output_group[i].is_contiguous()); + grad_input_ptrs[i] = reinterpret_cast(output_group[i].data_ptr()); - // Transfer grad output pointers to GPU - args_tensor = args_tensor.to(first_indices.device(), /*non_blocking=*/true); - - group_index_select_or_add_cuda( - args_tensor.data_ptr(), - args_tensor.data_ptr() + group_size, - args_tensor.data_ptr() + 2 * group_size, - warp_offsets_group, - num_cols_group, - fwd_input.scalar_type(), - first_indices.scalar_type(), - fwd_input.device().index(), - num_input_rows, - total_num_warps, - group_size, - /*use_index_select=*/false, - use_var_cols); - - return outputs; + // 2) Add group_size gradients for inputs + outputs.push_back(output_group[i]); } - static torch::autograd::variable_list backward( - torch::autograd::AutogradContext* ctx, - torch::autograd::variable_list grad_output_group) { - TORCH_CHECK(grad_output_group.size() >= 2); - if (grad_output_group.size() == 2) { - // empty outputs - return torch::autograd::variable_list(1); - } - // remove redundant grads - auto group_size = grad_output_group.size() - 2; - grad_output_group.resize(group_size); - - auto saved_tensors = ctx->get_saved_variables(); - TORCH_CHECK(saved_tensors.size() == group_size + 3); - auto output_shape_group = - ctx->saved_data["input_shape_group"].toSymIntVector(); - grad_output_group.insert( - grad_output_group.end(), saved_tensors.begin(), saved_tensors.end()); - static auto backward_op = - torch::Dispatcher::singleton() - .findSchemaOrThrow( - "fbgemm::group_index_select_dim0_gpu_backward", "") - .typed(); - auto res = backward_op.call(grad_output_group, output_shape_group); - // 1) Add group_size Variable()'s for indices - // Replace all empty tensors with Variable(). This must be done after the - // op.call to make __torch_dispatch__ work for the backward op. 
- std::fill( - res.begin(), res.begin() + group_size, torch::autograd::Variable()); - // 3) Add 1 Variable() for group_size - res.push_back({}); - return res; + // Calculate indices_ptrs + std::vector> index_contigs; + index_contigs.reserve(group_size); + for (const auto i : c10::irange(group_size)) { + const auto& indices = indices_group[i]; + index_contigs.push_back(indices.expect_contiguous()); + indices_ptrs[i] = reinterpret_cast(index_contigs[i]->data_ptr()); } -}; + + // Transfer grad output pointers to GPU + args_tensor = args_tensor.to(first_indices.device(), /*non_blocking=*/true); + + group_index_select_or_add_cuda( + args_tensor.data_ptr(), + args_tensor.data_ptr() + group_size, + args_tensor.data_ptr() + 2 * group_size, + warp_offsets_group, + num_cols_group, + fwd_input.scalar_type(), + first_indices.scalar_type(), + fwd_input.device().index(), + num_input_rows, + total_num_warps, + group_size, + /*use_index_select=*/false, + use_var_cols); + + return outputs; +} Tensor pack_segments_cuda( const Tensor& t_in, @@ -654,45 +558,6 @@ Tensor index_select_dim0_gpu( user_skip_indices_sorting_fwd && !c10::InferenceMode::is_enabled())[0]; } -torch::autograd::variable_list group_index_select_dim0_gpu_impl( - at::TensorList all_indices_input, - const int64_t group_size) { - return GroupIndexSelectDim0GPUOp::apply(all_indices_input, group_size); -} - -torch::autograd::variable_list group_index_select_dim0_gpu( - at::TensorList input_group, - at::TensorList indices_group) { - const auto group_size = indices_group.size(); - std::vector output_group; - - if (group_size == 0) { - return std::vector(); - } - - // Pack input_group and indices_group into TensorList - std::vector all_indices_input_vec; - all_indices_input_vec.reserve(group_size * 2); - - for (const Tensor& index : indices_group) { - all_indices_input_vec.push_back(index); - } - for (const Tensor& input : input_group) { - all_indices_input_vec.push_back(input); - } - - at::TensorList all_indices_input_tensor = all_indices_input_vec; - - static auto forward_op = - torch::Dispatcher::singleton() - .findSchemaOrThrow("fbgemm::group_index_select_dim0_gpu_impl", "") - .typed(); - auto res = forward_op.call(all_indices_input_tensor, group_size); - TORCH_CHECK(res.size() == group_size + 2); - // only return the outputs (the first group_size elements) - res.resize(group_size); - return res; -} } // namespace fbgemm_gpu TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) { @@ -720,18 +585,19 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) { "pack_segments_backward", fbgemm_gpu::pack_segments_backward_cuda); DISPATCH_TO_CUDA("index_select_dim0", fbgemm_gpu::index_select_dim0_gpu); DISPATCH_TO_CUDA( - "group_index_select_dim0_gpu_impl", - fbgemm_gpu::GroupIndexSelectDim0GPUOp::forward_impl); + "group_index_select_dim0_forward_impl", fbgemm_gpu::forward_impl); + DISPATCH_TO_CUDA( + "group_index_select_dim0_backward_impl", fbgemm_gpu::backward_impl); DISPATCH_TO_CUDA( - "group_index_select_dim0_gpu_backward", - fbgemm_gpu::GroupIndexSelectDim0GPUOp::backward_impl); + "group_index_select_dim0", fbgemm_gpu::group_index_select_dim0); DISPATCH_TO_CUDA( - "group_index_select_dim0", fbgemm_gpu::group_index_select_dim0_gpu); + "group_index_select_dim0_autograd_impl", + &fbgemm_gpu::group_index_select_dim0_autograd_impl); } TORCH_LIBRARY_IMPL(fbgemm, AutogradCUDA, m) { - m.impl("group_index_select_dim0", &fbgemm_gpu::group_index_select_dim0_gpu); + m.impl("group_index_select_dim0", &fbgemm_gpu::group_index_select_dim0); m.impl( - "group_index_select_dim0_gpu_impl", - 
&fbgemm_gpu::group_index_select_dim0_gpu_impl); + "group_index_select_dim0_autograd_impl", + &fbgemm_gpu::group_index_select_dim0_autograd_impl); } From b907f3215724ab5d0679206f2a07d633ea0503fb Mon Sep 17 00:00:00 2001 From: Benson Ma Date: Tue, 1 Oct 2024 15:58:41 -0700 Subject: [PATCH 28/48] Add support for int64_t indices and offsets in TBE inference [6/N] (#3182) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3182 X-link: https://github.com/facebookresearch/FBGEMM/pull/278 - Convert `PrunedMapCPU::lookup` to use `index_t` Reviewed By: spcyppt Differential Revision: D62602764 fbshipit-source-id: 1367b5b846a01c29d08193868c732819729aa87a --- .../embedding_forward_quantized_host_cpu.cpp | 39 ++++++++++++------- .../codegen/utils/embedding_bounds_check.cu | 38 +++++++++--------- .../utils/embedding_bounds_check_host_cpu.cpp | 2 +- 3 files changed, 45 insertions(+), 34 deletions(-) diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_host_cpu.cpp b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_host_cpu.cpp index 41fd137dd..b6f55b961 100644 --- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_host_cpu.cpp +++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_host_cpu.cpp @@ -21,6 +21,7 @@ #include "fbgemm_gpu/embedding_common.h" #include "fbgemm_gpu/utils/dispatch_macros.h" #include "fbgemm_gpu/utils/ops_utils.h" +#include "fbgemm_gpu/utils/tensor_utils.h" using Tensor = at::Tensor; using namespace fbgemm_gpu; @@ -374,29 +375,37 @@ class PrunedMapCPU : public torch::jit::CustomClassHolder { } Tensor lookup(Tensor indices, Tensor offsets) const { + TENSORS_HAVE_SAME_SCALAR_TYPE(indices, offsets); + int32_t T = maps_.size(); TORCH_CHECK(T > 0); int32_t B = (offsets.size(0) - 1) / T; TORCH_CHECK(B > 0); TORCH_CHECK(maps_.size() == T); + auto dense_indices = empty_like(indices); - const auto* indices_acc = indices.data_ptr(); - auto* dense_indices_acc = dense_indices.data_ptr(); - const auto* offsets_acc = offsets.data_ptr(); - for (const auto t : c10::irange(T)) { - auto& map = maps_[t]; - for (const auto b : c10::irange(B)) { - int32_t indices_start = offsets_acc[t * B + b]; - int32_t indices_end = offsets_acc[t * B + b + 1]; - int32_t L = indices_end - indices_start; - for (const auto l : c10::irange(L)) { - int32_t slot_sparse_index = indices_acc[indices_start + l]; - auto it = map.find(slot_sparse_index); - dense_indices_acc[indices_start + l] = - it != map.end() ? it->second : -1; + + AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "PrunedMapCPU::lookup", [&] { + const auto* indices_acc = indices.data_ptr(); + auto* dense_indices_acc = dense_indices.data_ptr(); + const auto* offsets_acc = offsets.data_ptr(); + + for (const auto t : c10::irange(T)) { + auto& map = maps_[t]; + for (const auto b : c10::irange(B)) { + const auto indices_start = offsets_acc[t * B + b]; + const auto indices_end = offsets_acc[t * B + b + 1]; + const auto L = indices_end - indices_start; + for (const auto l : c10::irange(L)) { + const auto slot_sparse_index = indices_acc[indices_start + l]; + const auto it = map.find(slot_sparse_index); + dense_indices_acc[indices_start + l] = + it != map.end() ? 
it->second : -1; + } } } - } + }); + return dense_indices; } diff --git a/fbgemm_gpu/codegen/utils/embedding_bounds_check.cu b/fbgemm_gpu/codegen/utils/embedding_bounds_check.cu index 08e22baa9..8d8ee6ab5 100644 --- a/fbgemm_gpu/codegen/utils/embedding_bounds_check.cu +++ b/fbgemm_gpu/codegen/utils/embedding_bounds_check.cu @@ -233,22 +233,24 @@ void bounds_check_indices_cuda( constexpr size_t kNumThreads = 256; const auto max_B_ = vbe ? max_B : B; - AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "bounds_check_indices", [&] { - const auto bounds_check_kernel = - (vbe ? bounds_check_indices_kernel - : bounds_check_indices_kernel); - TORCH_DSA_KERNEL_LAUNCH( - bounds_check_kernel, - div_round_up(max_B_ * T, kNumThreads / fbgemm_gpu::kWarpSize), - dim3(fbgemm_gpu::kWarpSize, kNumThreads / fbgemm_gpu::kWarpSize), - 0, - at::cuda::getCurrentCUDAStream(), - rows_per_table.packed_accessor32(), - indices.packed_accessor32(), - offsets.packed_accessor32(), - vbe ? B_offsets.value().data_ptr() : nullptr, - bounds_check_mode_, - warning.packed_accessor32(), - FixedDivisor(max_B_)); - }); + AT_DISPATCH_INDEX_TYPES( + indices.scalar_type(), "bounds_check_indices_cuda", [&] { + const auto bounds_check_kernel = + (vbe ? bounds_check_indices_kernel + : bounds_check_indices_kernel); + TORCH_DSA_KERNEL_LAUNCH( + bounds_check_kernel, + div_round_up(max_B_ * T, kNumThreads / fbgemm_gpu::kWarpSize), + dim3(fbgemm_gpu::kWarpSize, kNumThreads / fbgemm_gpu::kWarpSize), + 0, + at::cuda::getCurrentCUDAStream(), + rows_per_table + .packed_accessor32(), + indices.packed_accessor32(), + offsets.packed_accessor32(), + vbe ? B_offsets.value().data_ptr() : nullptr, + bounds_check_mode_, + warning.packed_accessor32(), + FixedDivisor(max_B_)); + }); } diff --git a/fbgemm_gpu/codegen/utils/embedding_bounds_check_host_cpu.cpp b/fbgemm_gpu/codegen/utils/embedding_bounds_check_host_cpu.cpp index 1098378d0..1d0cd1348 100644 --- a/fbgemm_gpu/codegen/utils/embedding_bounds_check_host_cpu.cpp +++ b/fbgemm_gpu/codegen/utils/embedding_bounds_check_host_cpu.cpp @@ -70,7 +70,7 @@ void bounds_check_indices_cpu( const auto rows_per_table_acc = rows_per_table.accessor(); auto warning_acc = warning.data_ptr(); - AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "bounds_check_indices", [&] { + AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "bounds_check_indices_cpu", [&] { auto offsets_acc = offsets.accessor(); auto indices_acc = indices.accessor(); auto num_indices = indices.numel(); From f9de20903af236e1d5688770148068fd825d6d07 Mon Sep 17 00:00:00 2001 From: Supadchaya Puangpontip Date: Tue, 1 Oct 2024 16:45:20 -0700 Subject: [PATCH 29/48] Enable VBE support on CPU (#3174) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/286 Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3174 Previous VBE on CPU was enabled in lookup_{{ optimizer }}.py. To support MTIA ops, VBE should be done after torch.ops.fbgemm.{{ mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function_pt2. This diff follows the same implementation but enables it C++ so that it goes through the same PT2 pipeline (i.e., lookup -> VBE autograd -> cpu wrapper (*do vbe here*) -> cpu kernel). 
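As an illustration of the layout the new CPU wrapper produces, here is a toy, self-contained sketch (not part of this diff; the feature count, embedding dims, and per-rank batch sizes are made up, and the rank-major segment order is inferred from the wrapper's `r * T + t` indexing):
```
#include <cstdint>
#include <iostream>

// Toy VBE layout: T = 2 features with dims D = {4, 8}, R = 2 ranks with
// per-(feature, rank) batch sizes B = {{2, 3}, {2, 3}}. Segments of the flat
// output are emitted rank-major, then feature-major.
int main() {
  constexpr int T = 2;
  constexpr int R = 2;
  const int D[T] = {4, 8};
  const int B[T][R] = {{2, 3}, {2, 3}};

  std::int64_t offset = 0;
  for (int r = 0; r < R; ++r) {
    for (int t = 0; t < T; ++t) {
      // Each (rank, feature) segment holds B[t][r] pooled rows of width D[t].
      const std::int64_t segment = static_cast<std::int64_t>(B[t][r]) * D[t];
      std::cout << "rank " << r << ", feature " << t << " -> output["
                << offset << ", " << offset + segment << ")\n";
      offset += segment;
    }
  }
  std::cout << "vbe_output_size = " << offset << "\n";  // 60 for this toy case
  return 0;
}
```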
the call is done Reviewed By: q10 Differential Revision: D63410944 fbshipit-source-id: 7b1f1aa44429fc68abf9e9c39d6857292baa948e --- fbgemm_gpu/FbgemmGpu.cmake | 5 +- .../genscript/generate_forward_split.py | 1 + ...dding_split_host_pt2_autograd_template.cpp | 20 ++- ...ng_split_host_pt2_cpu_wrapper_template.cpp | 127 +++++++++++++----- .../training/pt2/pt2_autograd_utils.cpp | 62 +++++++++ .../fbgemm_gpu/utils/pt2_autograd_utils.h | 31 +++++ 6 files changed, 207 insertions(+), 39 deletions(-) create mode 100644 fbgemm_gpu/codegen/training/pt2/pt2_autograd_utils.cpp create mode 100644 fbgemm_gpu/include/fbgemm_gpu/utils/pt2_autograd_utils.h diff --git a/fbgemm_gpu/FbgemmGpu.cmake b/fbgemm_gpu/FbgemmGpu.cmake index dd23dca85..99d5e5b0c 100644 --- a/fbgemm_gpu/FbgemmGpu.cmake +++ b/fbgemm_gpu/FbgemmGpu.cmake @@ -263,10 +263,10 @@ list(APPEND gen_gpu_host_source_files foreach(optimizer ${ALL_OPTIMIZERS}) list(APPEND gen_cpu_source_files "gen_embedding_backward_split_${optimizer}_cpu.cpp" - "gen_embedding_backward_split_${optimizer}_pt2_cpu_wrapper.cpp") + "gen_embedding_backward_split_${optimizer}_pt2_cpu_wrapper.cpp" + "gen_embedding_split_${optimizer}_pt2_autograd.cpp") list(APPEND gen_gpu_host_source_files "gen_embedding_backward_split_${optimizer}.cpp" - "gen_embedding_split_${optimizer}_pt2_autograd.cpp" "gen_embedding_backward_split_${optimizer}_pt2_cuda_wrapper.cpp") endforeach() @@ -454,6 +454,7 @@ set(fbgemm_gpu_sources_static_cpu codegen/training/forward/embedding_forward_split_cpu.cpp codegen/inference/embedding_forward_quantized_host_cpu.cpp codegen/training/backward/embedding_backward_dense_host_cpu.cpp + codegen/training/pt2/pt2_autograd_utils.cpp codegen/utils/embedding_bounds_check_host_cpu.cpp src/config/feature_gates.cpp src/memory_utils/memory_utils.cpp diff --git a/fbgemm_gpu/codegen/genscript/generate_forward_split.py b/fbgemm_gpu/codegen/genscript/generate_forward_split.py index 285cf9a55..894ce104c 100644 --- a/fbgemm_gpu/codegen/genscript/generate_forward_split.py +++ b/fbgemm_gpu/codegen/genscript/generate_forward_split.py @@ -83,6 +83,7 @@ def generate_pt2_wrappers() -> None: f"gen_embedding_forward_split_pt2_cpu_wrapper.cpp", has_cpu_support=True, is_forward=True, + has_vbe_support=True, ) # Generate PT2 forward wrapper (CUDA) diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp index b623b92d0..3666de5b9 100644 --- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp +++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp @@ -35,8 +35,12 @@ #include "fbgemm_gpu/utils/ops_utils.h" #include #include "fbgemm_gpu/utils/dispatch_macros.h" -#include "fbgemm_gpu/split_embeddings_utils.cuh" +#include "fbgemm_gpu/embedding_common.h" +#include "fbgemm_gpu/split_embeddings_utils.h" #include "fbgemm_gpu/config/feature_gates.h" +{%- if has_vbe_support %} +#include "fbgemm_gpu/utils/pt2_autograd_utils.h" +{%- endif %} using Tensor = at::Tensor; @@ -236,9 +240,9 @@ enum SSDTensor { const Tensor& /*prev_iter_dev*/, {%- endif %} {%- if "iter" not in args_pt2.split_function_arg_names %} - const int64_t iter, + const int64_t /*iter*/, {%- endif %} - const double gwd_lower_bound, + const double /*gwd_lower_bound*/, {%- endif %} {# /* if is_gwd */ #} {%- for arg_type in args_pt2.split_function_args %} {{ arg_type.split(' ')[0]}}{%- if not loop.last %}{{ "," }}{%- endif %} @@ -617,7 +621,6 @@ class {{ 
autograd_func }} : const c10::SymInt, const int64_t, const c10::SymInt)>(); - auto [ vbe_row_output_offsets, vbe_b_t_map @@ -850,6 +853,11 @@ static torch::autograd::variable_list backward( // {{ fwd_mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_pt2_cuda) weights_dev = weights_dev.flatten(); {%- endif %} + {%- if vbe %} + if (weights_host.numel() > 1){ + grad_output = reshape_vbe_output(grad_output, B_offsets, vbe_b_t_map, D_offsets); + } + {%- endif %} {%- set grad_indice_weights_op = "{}_embedding_codegen_grad_indice_weights{}_pt2_wrapper".format(fwd_mdesc, vdesc) @@ -883,7 +891,7 @@ static torch::autograd::variable_list backward( {%- else %} const Tensor& /*feature_requires_grad*/ {%- endif %} - )>(); + )>(); const auto grad_indice_weights = !indice_weights.defined() ? Variable() : @@ -1014,7 +1022,7 @@ Tensor {{ bwd_mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function_pt2( {%- if not ssd %} {%- if has_vbe_support %} // has vbe support and on gpu - if (B_offsets.has_value() && !(weights[0].numel() > 0)) { + if (B_offsets.has_value()) { {%- if has_global_weight_decay_support %} // vbe and has gwd support if (apply_global_weight_decay && weight_decay > 0) { diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cpu_wrapper_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cpu_wrapper_template.cpp index c74355207..5b2b066fe 100644 --- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cpu_wrapper_template.cpp +++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cpu_wrapper_template.cpp @@ -30,9 +30,12 @@ using Tensor = at::Tensor; using namespace fbgemm_gpu; +{%- for vbe in ([True, False] if has_vbe_support else [False]) %} +{%- set vdesc = "_vbe" if vbe else "" %} + {%- if is_forward %} {#-/* PT2 wrapper function for backward grad_indice_weights CPU */#} -Tensor split_embedding_codegen_grad_indice_weights_pt2_cpu_wrapper( +Tensor split_embedding_codegen_grad_indice_weights{{ vdesc }}_pt2_cpu_wrapper( const Tensor& grad_output, const Tensor& host_weights, const Tensor& /*dev_weights*/, @@ -45,7 +48,16 @@ Tensor split_embedding_codegen_grad_indice_weights_pt2_cpu_wrapper( const Tensor& indices, const Tensor& offsets, const Tensor& /*lxu_cache_locations*/, - const Tensor& feature_requires_grad) { + {%- if vbe %} + const Tensor& feature_requires_grad, + const Tensor& vbe_row_output_offsets, + const Tensor& vbe_b_t_map, + const int64_t info_B_num_bits, + const int64_t info_B_mask_int64 + {%- else %} + const Tensor& feature_requires_grad + {%- endif %} +) { static auto op = torch::Dispatcher::singleton() .findSchemaOrThrow( @@ -67,7 +79,7 @@ Tensor split_embedding_codegen_grad_indice_weights_pt2_cpu_wrapper( {% if is_forward %} {#-/* PT2 wrapper function for forward CPU */#} -Tensor split_embedding_codegen_forward_{{ wdesc }}_pt2_cpu_wrapper( +Tensor split_embedding_codegen_forward_{{ wdesc }}{{ vdesc }}_pt2_cpu_wrapper( const Tensor& host_weights, const Tensor& /*dev_weights*/, const Tensor& /*uvm_weights*/, @@ -84,30 +96,77 @@ Tensor split_embedding_codegen_forward_{{ wdesc }}_pt2_cpu_wrapper( const Tensor& indice_weights, const Tensor& /*lxu_cache_locations*/, const Tensor& /*uvm_cache_stats*/, + {%- if vbe %} + const Tensor& vbe_row_output_offsets, /*vbe_output_offsets_feature_rank*/ + const Tensor& vbe_b_t_map, /*vbe_B_offsets_rank_per_feature*/ + const c10::SymInt vbe_output_size, + const int64_t info_B_num_bits, + const int64_t info_B_mask_int64, + {%- endif %} const bool /*is_experimental = false*/, const int64_t 
output_dtype = static_cast(SparseType::FP32)) { - static auto op = - torch::Dispatcher::singleton() - .findSchemaOrThrow("fbgemm::split_embedding_codegen_forward_cpu", "") - .typed(); + static auto op = + torch::Dispatcher::singleton() + .findSchemaOrThrow("fbgemm::split_embedding_codegen_forward_cpu", "") + .typed(); + {%- if vbe %} + // TODO: remove this after vbe is implemented for CPU kernel + Tensor vbe_B_offsets_rank_per_feature = vbe_b_t_map; + Tensor vbe_output_offsets_feature_rank = vbe_row_output_offsets; + const auto output = op.call( + host_weights, + weights_offsets, + D_offsets, + total_D, + hash_size_cumsum, + indices, + offsets, + pooling_mode, + indice_weights, + output_dtype); + auto options = at::TensorOptions() + .dtype(output.options().dtype()) + .device(host_weights.options().device()); + const int64_t vbe_output_size_ = vbe_output_size.guard_int(__FILE__, __LINE__); + Tensor output_new = at::empty({vbe_output_size_}, options); + const int32_t T = D_offsets.numel() - 1; + const int32_t R = vbe_B_offsets_rank_per_feature.size(1) - 1; - return op.call( - host_weights, - weights_offsets, - D_offsets, - total_D, - hash_size_cumsum, - indices, - offsets, - pooling_mode, - indice_weights, - output_dtype); -} + for (int32_t r = 0; r < R; r++){ + auto D_offset = 0; + for (int32_t t = 0; t < T; t++){ + const int32_t o_begin = vbe_output_offsets_feature_rank[r * T + t].item(); + const int32_t o_end = vbe_output_offsets_feature_rank[r * T + t + 1].item(); + const int32_t D = D_offsets[t + 1].item() - D_offsets[t].item(); + const int32_t b_begin = vbe_B_offsets_rank_per_feature[t][r].item(); + const int32_t b_end = vbe_B_offsets_rank_per_feature[t][r + 1].item(); + + TORCH_CHECK((o_end - o_begin) == ((b_end - b_begin) * D)); + auto values = output.index({torch::indexing::Slice(b_begin, b_end), torch::indexing::Slice(D_offset, D_offset + D)}).flatten(); + output_new.index_put_({torch::indexing::Slice(o_begin, o_end)}, values); + D_offset += D; + } + } + return output_new; + {%- else %} + return op.call( + host_weights, + weights_offsets, + D_offsets, + total_D, + hash_size_cumsum, + indices, + offsets, + pooling_mode, + indice_weights, + output_dtype); + {%- endif %} + } {% else %} {#-/* PT2 wrapper function for backward CPU */#} -Tensor split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}_pt2_cpu_wrapper( +Tensor split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_pt2_cpu_wrapper( const Tensor& grad_output, const Tensor& host_weights, const Tensor& /*dev_weights*/, @@ -127,8 +186,13 @@ Tensor split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}_pt2_cpu_wrap const int64_t /*BT_block_size*/, const int64_t /*max_segment_length_per_warp*/, const bool stochastic_rounding, - const int64_t /*info_B_num_bits*/, - const int64_t /*info_B_mask_int64*/, + const int64_t info_B_num_bits, + const int64_t info_B_mask_int64, + {%- if vbe %} + const Tensor& B_offsets, + const Tensor& vbe_row_output_offsets, + const Tensor& vbe_b_t_map, + {%- endif %} const bool /*use_uniq_cache_locations*/, const bool /*use_homogeneous_placements*/, {{ args_pt2.split_function_args | join(", ") }} @@ -194,29 +258,30 @@ namespace { TORCH_LIBRARY_FRAGMENT(fbgemm, m) { {%- if is_forward %} DISPATCH_TO_CPU( - "split_embedding_codegen_grad_indice_weights_pt2_wrapper", - split_embedding_codegen_grad_indice_weights_pt2_cpu_wrapper); + "split_embedding_codegen_grad_indice_weights{{ vdesc }}_pt2_wrapper", + split_embedding_codegen_grad_indice_weights{{ vdesc }}_pt2_cpu_wrapper); {%- endif 
%} {%- for weighted in [True, False] %} {%- set wdesc = "weighted" if weighted else "unweighted" %} {%- if is_forward %} - {%- set embedding_codegen_forward_op = "split_embedding_codegen_forward_{}_pt2".format( - wdesc + {%- set embedding_codegen_forward_op = "split_embedding_codegen_forward_{}{}_pt2".format( + wdesc, vdesc ) %} DISPATCH_TO_CPU("{{ embedding_codegen_forward_op }}_wrapper", {{ embedding_codegen_forward_op }}_cpu_wrapper); {%- else %} - {%- set embedding_codegen_backward_op = "split_embedding_backward_codegen_{}_{}_pt2".format( - optimizer, wdesc + {%- set embedding_codegen_backward_op = "split_embedding_backward_codegen_{}_{}{}_pt2".format( + optimizer, wdesc, vdesc ) %} DISPATCH_TO_CPU("{{ embedding_codegen_backward_op }}_wrapper", {{ embedding_codegen_backward_op }}_cpu_wrapper); {%- endif %} {%- endfor %} {#-/*for weighted*/#} } - } // namespace +{%- endfor %} {#-/* for vbe in [True, False] */#} + {% endif %} // if has_cpu_support // clang-format on diff --git a/fbgemm_gpu/codegen/training/pt2/pt2_autograd_utils.cpp b/fbgemm_gpu/codegen/training/pt2/pt2_autograd_utils.cpp new file mode 100644 index 000000000..071acf90a --- /dev/null +++ b/fbgemm_gpu/codegen/training/pt2/pt2_autograd_utils.cpp @@ -0,0 +1,62 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +using Tensor = at::Tensor; + +namespace fbgemm_gpu { + +//////////////////////////////////////////////////////////////////////////////// +// Helper Functions +//////////////////////////////////////////////////////////////////////////////// + +Tensor reshape_vbe_output( + const Tensor& grad_output, + const Tensor& B_offsets, + const Tensor& B_offsets_rank_per_feature, + const Tensor& D_offsets) { + /* FOR CPU VBE to use the same backend */ + const auto T = D_offsets.numel() - 1; + int32_t max_B = 0; + int32_t total_D = 0; + // find max_B, total_D to create output [max_B, total_D] + for (int32_t t = 0; t < T; t++) { + auto b = B_offsets[t + 1].item() - B_offsets[t].item(); + max_B = std::max(max_B, b); + total_D += D_offsets[t + 1].item() - D_offsets[t].item(); + } + auto grad_output_ = at::empty({max_B, total_D}, grad_output.options()); + // for each feature + auto offset = 0; + + const int32_t R = B_offsets_rank_per_feature.size(1) - 1; + for (int32_t r = 0; r < R; r++) { + auto D_offset = 0; + for (int32_t t = 0; t < T; t++) { + const int32_t b_begin = B_offsets_rank_per_feature[t][r].item(); + const int32_t b_end = + B_offsets_rank_per_feature[t][r + 1].item(); + const int32_t D = + D_offsets[t + 1].item() - D_offsets[t].item(); + const int32_t b = b_end - b_begin; + const int32_t num_elm = b * D; + auto values = grad_output.slice(0, offset, offset + num_elm); + values = values.reshape({b, D}); + grad_output_.index_put_( + {at::indexing::Slice(b_begin, b_end), + at::indexing::Slice(D_offset, D_offset + D)}, + values); + D_offset += D; + offset += num_elm; + } + } + return grad_output_; +} +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/include/fbgemm_gpu/utils/pt2_autograd_utils.h b/fbgemm_gpu/include/fbgemm_gpu/utils/pt2_autograd_utils.h new file mode 100644 index 000000000..3aff58c9a --- /dev/null +++ b/fbgemm_gpu/include/fbgemm_gpu/utils/pt2_autograd_utils.h @@ -0,0 +1,31 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +// #include +// #include +// #include "fbgemm_gpu/embedding_common.h" +// #include "fbgemm_gpu/utils/dispatch_macros.h" +// #include "fbgemm_gpu/utils/ops_utils.h" +// #include "fbgemm_gpu/utils/tensor_utils.h" + +using Tensor = at::Tensor; + +namespace fbgemm_gpu { + +//////////////////////////////////////////////////////////////////////////////// +// Helper Functions +//////////////////////////////////////////////////////////////////////////////// + +Tensor reshape_vbe_output( + const Tensor& grad_output, + const Tensor& B_offsets, + const Tensor& B_offsets_rank_per_feature, + const Tensor& D_offsets); +} // namespace fbgemm_gpu From c24a72d3a02b3e0227f7048b3ea2bde0983b4227 Mon Sep 17 00:00:00 2001 From: Joe Wang Date: Tue, 1 Oct 2024 17:12:59 -0700 Subject: [PATCH 30/48] fix rocksdb snapshot read (#3206) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/305 Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3206 as title Reviewed By: xunnanxu Differential Revision: D63717490 fbshipit-source-id: 3a3009d8ea2b777463964be66165d746fcb4b255 --- .../ssd_table_batched_embeddings.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h index 87924eefe..2a8f101c5 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h @@ -668,9 +668,11 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB { snapshot_ptr_t snapshot = snapshot_handle == nullptr ? nullptr : snapshot_handle->get_snapshot_for_shard(shard); + auto local_ro = ro_; + local_ro.snapshot = snapshot; tasks.emplace_back( folly::coro::co_invoke( - [this, &indices, &weights, count_, shard, snapshot]() mutable + [this, &indices, &weights, count_, shard, local_ro]() mutable -> folly::coro::Task { FBGEMM_DISPATCH_FLOAT_HALF_AND_BYTE( weights.scalar_type(), "ssd_get", [&] { @@ -734,10 +736,8 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB { values.resize(keys.size()); statuses.resize(keys.size()); - // Set a snapshot if it is available - ro_.snapshot = snapshot; dbs_[shard]->MultiGet( - ro_, + local_ro, keys.size(), cfs.data(), keys.data(), From 7c2bfb8f7c99be8b897fdf9186e84ef1d2445a49 Mon Sep 17 00:00:00 2001 From: James Donald Date: Tue, 1 Oct 2024 18:42:03 -0700 Subject: [PATCH 31/48] Remove IS_MSVC bool and build_host_info().compiler usage (#3204) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/303 Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3204 `build_host_info()` checks come from buck1, as in buck1 we didn't have a concept of a 'host select()', whereas on buck2 it is preferred to use `select()` after doing a proper configuration transition to the execution platform. Replace the call to `build_host_info().compiler` here with a select(). Note we still retain function arguments like `msvc` and `buck` as these are apparently used in the CMake build flow. 
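As a rough sketch of the pattern (not the exact defs.bzl change; `get_srcs` and the source file names are placeholders for illustration), the host-compiler check is replaced by a constraint-keyed `select()` that buck2 resolves against the execution platform's configuration instead of the host:

    def get_srcs(msvc = False, buck = False):
        # Placeholder source lists; the real lists live in defs.bzl.
        asm_srcs = ["ukernels_asm.cc"]
        intrinsics_srcs = ["ukernels_intrinsics.cc"]
        if buck:
            # buck2: choose sources via the configured compiler/cpu constraints
            # rather than querying build_host_info().compiler at load time.
            return select({
                "DEFAULT": asm_srcs,
                "ovr_config//compiler:cl": intrinsics_srcs,
                "ovr_config//cpu:arm64": intrinsics_srcs,
            })
        # The CMake flow still threads the msvc flag through explicitly.
        return asm_srcs if not msvc else intrinsics_srcs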
Reviewed By: 8Keep Differential Revision: D63710016 fbshipit-source-id: 21864acf1a5ad4eb7a1a71b55723f6667aaae362 --- defs.bzl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/defs.bzl b/defs.bzl index 43d17b13e..9da6f6920 100644 --- a/defs.bzl +++ b/defs.bzl @@ -105,7 +105,8 @@ def get_fbgemm_inline_avx2_srcs(msvc = False, buck = False): asm_srcs = ["src/FbgemmFP16UKernelsAvx2.cc"] if buck: return select({ - "DEFAULT": asm_srcs if not msvc else intrinsics_srcs, + "DEFAULT": asm_srcs, + "ovr_config//compiler:cl": intrinsics_srcs, "ovr_config//cpu:arm64": intrinsics_srcs, }) return asm_srcs if not msvc else intrinsics_srcs @@ -135,7 +136,8 @@ def get_fbgemm_inline_avx512_srcs(msvc = False, buck = False): ] if buck: return select({ - "DEFAULT": asm_srcs if not msvc else intrinsics_srcs, + "DEFAULT": asm_srcs, + "ovr_config//compiler:cl": intrinsics_srcs, "ovr_config//cpu:arm64": intrinsics_srcs, }) return asm_srcs if not msvc else intrinsics_srcs From 893769e53567f2b92aa12b5a8434960c61f8950a Mon Sep 17 00:00:00 2001 From: Minhua Chen Date: Tue, 1 Oct 2024 21:27:34 -0700 Subject: [PATCH 32/48] add step_mode for ensemble_rowwise_adagrad (#3203) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3203 X-link: https://github.com/facebookresearch/FBGEMM/pull/300 add step_mode for ensemble_rowwise_adagrad Reviewed By: q10, minddrummer Differential Revision: D63681997 fbshipit-source-id: 31dffe7e6e375a47872b9d3cac51cbfd7f675f96 --- .../split_table_batched_embeddings_ops_training.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py index f41a49d19..1730dbedc 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py @@ -1944,10 +1944,17 @@ def ensemble_and_swap(self, ensemble_mode: Dict[str, float]) -> None: states = self.split_optimizer_states() for i in range(len(self.embedding_specs)): if should_ema: + step_start = int(ensemble_mode["step_start"]) + if int(ensemble_mode["step_mode"]) == 1: + should_ema_reset = self.iter.item() % step_start == 0 + elif int(ensemble_mode["step_mode"]) == 2: + should_ema_reset = self.iter.item() <= step_start + else: + should_ema_reset = (self.iter.item() <= step_start) or ( + self.iter.item() % step_start == 0 + ) coef_ema = ( - ensemble_mode["step_ema_coef"] - if self.iter.item() > int(ensemble_mode["step_start"]) - else 0.0 + 0.0 if should_ema_reset else ensemble_mode["step_ema_coef"] ) weights_cpu = weights[i].to( dtype=states[i][1].dtype, device=states[i][1].device From 62e7226ab4be53f9e0c07da9b8e0148cd48411ec Mon Sep 17 00:00:00 2001 From: Benson Ma Date: Wed, 2 Oct 2024 10:47:59 -0700 Subject: [PATCH 33/48] Redefine FBGEMM targets with gpu_cpp_library [20B/N] (#3201) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3201 X-link: https://github.com/facebookresearch/FBGEMM/pull/299 - Combine `gpu_cpp_library` targets into `fbgemm_gpu:sparse_ops` Reviewed By: gag1jain Differential Revision: D63679169 fbshipit-source-id: 26d62efef86f5de0d45fd8fd4a2a5a34df0c9f3e --- fbgemm_gpu/bench/batched_unary_embeddings_benchmark.py | 6 +----- .../bench/histogram_binning_calibration_benchmark.py | 6 +----- fbgemm_gpu/bench/jagged_tensor_benchmark.py | 6 +----- fbgemm_gpu/bench/quantize_ops_benchmark.py | 6 +----- 
fbgemm_gpu/bench/stride_gemm_benchmark.py | 6 +----- fbgemm_gpu/fbgemm_gpu/batched_unary_embeddings_ops.py | 9 +++------ fbgemm_gpu/fbgemm_gpu/sparse_ops.py | 4 +--- fbgemm_gpu/test/batched_unary_embeddings_test.py | 8 ++------ fbgemm_gpu/test/jagged/common.py | 8 +------- fbgemm_gpu/test/layout_transform_ops_test.py | 8 ++------ fbgemm_gpu/test/quantize/common.py | 7 +------ fbgemm_gpu/test/sparse/common.py | 7 +------ 12 files changed, 16 insertions(+), 65 deletions(-) diff --git a/fbgemm_gpu/bench/batched_unary_embeddings_benchmark.py b/fbgemm_gpu/bench/batched_unary_embeddings_benchmark.py index 6091bcc8e..6d354900a 100644 --- a/fbgemm_gpu/bench/batched_unary_embeddings_benchmark.py +++ b/fbgemm_gpu/bench/batched_unary_embeddings_benchmark.py @@ -25,11 +25,7 @@ else: from fbgemm_gpu.bench.bench_utils import benchmark_torch_function - if torch.version.hip: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip") - else: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") def generate_unary_feature( diff --git a/fbgemm_gpu/bench/histogram_binning_calibration_benchmark.py b/fbgemm_gpu/bench/histogram_binning_calibration_benchmark.py index c919199ee..e43106b8c 100644 --- a/fbgemm_gpu/bench/histogram_binning_calibration_benchmark.py +++ b/fbgemm_gpu/bench/histogram_binning_calibration_benchmark.py @@ -21,11 +21,7 @@ # pyre-ignore[21] from fbgemm_gpu import open_source # noqa: F401 except Exception: - if torch.version.hip: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip") - else: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") def benchmark_hbc_function( diff --git a/fbgemm_gpu/bench/jagged_tensor_benchmark.py b/fbgemm_gpu/bench/jagged_tensor_benchmark.py index 46337701e..814f950b0 100644 --- a/fbgemm_gpu/bench/jagged_tensor_benchmark.py +++ b/fbgemm_gpu/bench/jagged_tensor_benchmark.py @@ -31,10 +31,7 @@ else: from fbgemm_gpu.bench.bench_utils import benchmark_torch_function - if torch.version.hip: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip") - else: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") torch.ops.load_library( "//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_cpu" ) @@ -47,7 +44,6 @@ torch.ops.load_library( "//deeplearning/fbgemm/fbgemm_gpu:permute_multi_embedding_ops_gpu" ) - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") @click.group() diff --git a/fbgemm_gpu/bench/quantize_ops_benchmark.py b/fbgemm_gpu/bench/quantize_ops_benchmark.py index 81eb07bea..54755fff6 100644 --- a/fbgemm_gpu/bench/quantize_ops_benchmark.py +++ b/fbgemm_gpu/bench/quantize_ops_benchmark.py @@ -34,11 +34,7 @@ else: from fbgemm_gpu.bench.bench_utils import benchmark_torch_function - if torch.version.hip: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip") - else: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") @click.group() diff --git 
a/fbgemm_gpu/bench/stride_gemm_benchmark.py b/fbgemm_gpu/bench/stride_gemm_benchmark.py index 2609f7fbf..bca34b8a9 100644 --- a/fbgemm_gpu/bench/stride_gemm_benchmark.py +++ b/fbgemm_gpu/bench/stride_gemm_benchmark.py @@ -20,11 +20,7 @@ # pyre-ignore[21] from fbgemm_gpu import open_source # noqa: F401 except Exception: - if torch.version.hip: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip") - else: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") @click.group() diff --git a/fbgemm_gpu/fbgemm_gpu/batched_unary_embeddings_ops.py b/fbgemm_gpu/fbgemm_gpu/batched_unary_embeddings_ops.py index 5ef0f1c32..db2260df4 100644 --- a/fbgemm_gpu/fbgemm_gpu/batched_unary_embeddings_ops.py +++ b/fbgemm_gpu/fbgemm_gpu/batched_unary_embeddings_ops.py @@ -13,16 +13,13 @@ import torch +from fbgemm_gpu.utils.loader import load_torch_module + try: # pyre-ignore[21] from fbgemm_gpu import open_source # noqa: F401 except Exception: - if torch.version.hip: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip") - else: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") - - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") + load_torch_module("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") def wrap_weight_to_parameter(weights: List[torch.Tensor]) -> List[torch.Tensor]: diff --git a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py index 3234d6668..8ada0a89f 100644 --- a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py +++ b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py @@ -20,20 +20,18 @@ from fbgemm_gpu import open_source # noqa: F401 except Exception: load_torch_module("//deeplearning/fbgemm/fbgemm_gpu:merge_pooled_embeddings") + load_torch_module("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") if torch.version.hip: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip") torch.ops.load_library( "//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_hip" ) else: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops") torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:input_combine") - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") torch.ops.load_library( "//deeplearning/fbgemm/fbgemm_gpu:merge_pooled_embeddings_cpu" ) diff --git a/fbgemm_gpu/test/batched_unary_embeddings_test.py b/fbgemm_gpu/test/batched_unary_embeddings_test.py index 3f1124eff..37d0aaf7b 100644 --- a/fbgemm_gpu/test/batched_unary_embeddings_test.py +++ b/fbgemm_gpu/test/batched_unary_embeddings_test.py @@ -26,14 +26,10 @@ from test_utils import gpu_unavailable except Exception: - if torch.version.hip: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip") - else: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") - - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") from fbgemm_gpu.test.test_utils import gpu_unavailable + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + # Relative tolerances # pyre-fixme[5]: Global expression must be annotated. 
diff --git a/fbgemm_gpu/test/jagged/common.py b/fbgemm_gpu/test/jagged/common.py index 3bdcacf98..491176f76 100644 --- a/fbgemm_gpu/test/jagged/common.py +++ b/fbgemm_gpu/test/jagged/common.py @@ -24,13 +24,7 @@ open_source: bool = getattr(fbgemm_gpu, "open_source", False) if not open_source: - if torch.version.hip: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip") - else: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") - - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") - + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") suppressed_list: List[HealthCheck] = ( [HealthCheck.differing_executors] diff --git a/fbgemm_gpu/test/layout_transform_ops_test.py b/fbgemm_gpu/test/layout_transform_ops_test.py index 658d773f3..57a7b0263 100644 --- a/fbgemm_gpu/test/layout_transform_ops_test.py +++ b/fbgemm_gpu/test/layout_transform_ops_test.py @@ -22,14 +22,10 @@ from test_utils import gpu_unavailable except Exception: - if torch.version.hip: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip") - else: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") - - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") from fbgemm_gpu.test.test_utils import gpu_unavailable + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + MAX_EXAMPLES = 20 diff --git a/fbgemm_gpu/test/quantize/common.py b/fbgemm_gpu/test/quantize/common.py index 5333cc893..6a720a174 100644 --- a/fbgemm_gpu/test/quantize/common.py +++ b/fbgemm_gpu/test/quantize/common.py @@ -23,12 +23,7 @@ from fbgemm_gpu import open_source # noqa: F401 except Exception: - if torch.version.hip: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip") - else: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") - - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") # Eigen/Python round 0.5 away from 0, Numpy rounds to even round_to_nearest: Callable[[npt.NDArray], npt.NDArray] = np.vectorize(round) diff --git a/fbgemm_gpu/test/sparse/common.py b/fbgemm_gpu/test/sparse/common.py index 69b6e3477..8abddca75 100644 --- a/fbgemm_gpu/test/sparse/common.py +++ b/fbgemm_gpu/test/sparse/common.py @@ -23,12 +23,7 @@ open_source: bool = getattr(fbgemm_gpu, "open_source", False) if not open_source: - if torch.version.hip: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_hip") - else: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") - - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:index_select_ops") suppressed_list: List[HealthCheck] = ( From c39d7e87200458f608d6e288506cfedc6fe49799 Mon Sep 17 00:00:00 2001 From: Richard Barnes Date: Wed, 2 Oct 2024 11:09:42 -0700 Subject: [PATCH 34/48] c10::optional -> std::optional Summary: `c10::optional` is an alias for `std::optional`. Let's remove the alias and use the real thing. 
Reviewed By: palmje Differential Revision: D63409425 fbshipit-source-id: 91d88c77c4b85d29091de7f9e22b5ee87d6eb171 --- .../ssd_split_table_batched_embeddings.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp index b4ed00d31..9723ddfea 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp @@ -30,7 +30,7 @@ ssd_cache_populate_actions_cuda( bool gather_cache_stats, std::optional ssd_cache_stats, const bool lock_cache_line, - const c10::optional& lxu_cache_locking_counter); + const std::optional& lxu_cache_locking_counter); /// @ingroup embedding-ssd /// From 38639fd3df6b0b309042f2dc673063cd7061416b Mon Sep 17 00:00:00 2001 From: Richard Barnes Date: Wed, 2 Oct 2024 11:16:53 -0700 Subject: [PATCH 35/48] c10::optional -> std::optional Summary: `c10::optional` is an alias for `std::optional`. Let's remove the alias and use the real thing. Reviewed By: palmje Differential Revision: D63409373 fbshipit-source-id: 6a7da6da68de13f134c2989e018809cce4cc2c05 --- .../ssd_split_embeddings_cache_cuda.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_embeddings_cache_cuda.cu b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_embeddings_cache_cuda.cu index d371c2845..340d616e6 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_embeddings_cache_cuda.cu +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_embeddings_cache_cuda.cu @@ -343,7 +343,7 @@ ssd_cache_populate_actions_cuda( bool gather_cache_stats, std::optional ssd_cache_stats, const bool lock_cache_line, - const c10::optional& lxu_cache_locking_counter) { + const std::optional& lxu_cache_locking_counter) { TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL( linear_indices, lxu_cache_state, lru_state, lxu_cache_locking_counter); From c50b5c8aa13d652cbc3151ed11a914563a37e16d Mon Sep 17 00:00:00 2001 From: Jiawen Liu Date: Wed, 2 Oct 2024 13:06:45 -0700 Subject: [PATCH 36/48] MoE BMM FP8 rowwise (#3207) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3207 X-link: https://github.com/facebookresearch/FBGEMM/pull/306 Enable MoE BMM FP8 rowwise: - MoE BMM FP8 rowwise achieves **up to 4.5x (2.1x on average) speedup compared to BF16 BMM** - In E2E with MoE 16b x 16, FP8 with BMM achieves **1.2x speedup than BF16** - Integrated in E2E and verified correctness which matches BF16 generations - More results are in this [data sheet](https://docs.google.com/spreadsheets/d/1OLdz4MlzWS9pdgTBq4Jjy0-9_nPn-NmdrMolY0jZOXE/edit?gid=0#gid=0) {F1903039975} Reviewed By: jianyuh Differential Revision: D63681109 fbshipit-source-id: 6ba8a04a1f717a2832cce6637628448e82c3d0f9 --- fbgemm_gpu/experimental/gen_ai/CMakeLists.txt | 1 + .../f8f8bf16_rowwise_batched.cu | 502 ++++++++++++++++++ .../cutlass_extensions/include/kernel_mode.h | 21 + .../gen_ai/src/quantize/quantize.cpp | 27 + .../gen_ai/test/quantize/quantize_test.py | 10 +- 5 files changed, 560 insertions(+), 1 deletion(-) create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_batched.cu diff --git a/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt b/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt index a2916c9e5..5accb9c53 100644 --- 
a/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt +++ b/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt @@ -43,6 +43,7 @@ else() src/quantize/cutlass_extensions/f8f8bf16_blockwise.cu src/quantize/cutlass_extensions/f8f8bf16_cublas.cu src/quantize/cutlass_extensions/f8f8bf16_rowwise.cu + src/quantize/cutlass_extensions/f8f8bf16_rowwise_batched.cu src/quantize/cutlass_extensions/f8f8bf16_tensorwise.cu src/quantize/cutlass_extensions/i8i8bf16.cu src/quantize/cutlass_extensions/f8i4bf16_rowwise.cu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_batched.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_batched.cu new file mode 100644 index 000000000..313c81298 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_batched.cu @@ -0,0 +1,502 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include + +// clang-format off +// The fixed ordering of the headers is required for CUTLASS 3.2+ +#include +#include // @manual +#include // @manual +#include // @manual +// clang-format on + +#include "cutlass_extensions/include/kernel_mode.h" + +namespace fbgemm_gpu { + +#if CUDART_VERSION >= 12000 + +// Cutlass rowwise batched kernel +template < + int TB_M, + int TB_N, + int TB_K, + int TBS_M, + int TBS_N, + int TBS_K, + bool PONG, + bool FAST_ACCUM, + bool USE_BIAS, + typename INPUT_DTYPE, + typename BIAS_DTYPE> +at::Tensor f8f8bf16_rowwise_batched_impl( + at::Tensor XQ, // FP8 + at::Tensor WQ, // FP8 + at::Tensor x_scale, + at::Tensor w_scale, + std::optional bias, + std::optional output) { + // XQ: B x M x K + // WQ: B x N x K + // output: B x M x N + int B = XQ.size(0); + int M = XQ.size(1); + int N = WQ.size(1); + int K = WQ.size(2); + TORCH_CHECK(XQ.size(-1) == K); + auto out_sizes = XQ.sizes().vec(); + out_sizes.back() = N; + + TORCH_CHECK(XQ.is_cuda() && XQ.is_contiguous()); + TORCH_CHECK(WQ.is_cuda() && WQ.is_contiguous()); + + at::Tensor Y; + if (output.has_value()) { + Y = output.value(); + // Make sure the provided output has the proper shape and dtype. 
+ TORCH_CHECK(Y.sizes().vec() == out_sizes); + TORCH_CHECK(Y.dtype() == at::kBFloat16); + } else { + Y = at::empty(out_sizes, XQ.options().dtype(at::kBFloat16)); + } + + using ElementInputA = INPUT_DTYPE; + using LayoutInputA = cutlass::layout::RowMajor; + constexpr int AlignmentInputA = 16 / sizeof(ElementInputA); + + using ElementInputB = cutlass::float_e4m3_t; + using LayoutInputB = cutlass::layout::ColumnMajor; + constexpr int AlignmentInputB = 16 / sizeof(ElementInputB); + + using ElementBias = BIAS_DTYPE; + + using ElementOutput = cutlass::bfloat16_t; + using LayoutOutput = cutlass::layout::RowMajor; + constexpr int AlignmentOutput = 16 / sizeof(ElementOutput); + + using ElementAccumulator = float; + using ElementComputeEpilogue = float; + using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that + // supports the intended feature + using OperatorClass = cutlass::arch::OpClassTensorOp; + using TileShape = cute::Shape< + cute::Int, + cute::Int, + cute::Int>; // Threadblock-level + // tile size + using ClusterShape = cute::Shape< + cute::Int, + cute::Int, + cute::Int>; // Shape of the + // threadblocks in a + // cluster + using StageCountType = + cutlass::gemm::collective::StageCountAuto; // Stage count maximized + // based on the tile size + using KernelSchedule = cutlass::gemm::collective:: + KernelScheduleAuto; // Kernel to launch based on the default setting in + // the Collective Builder + + // Implement rowwise scaling epilogue. + using XScale = cutlass::epilogue::fusion::Sm90ColBroadcast< + 0, + TileShape, + ElementComputeEpilogue, + cute::Stride, cute::Int<0>, int32_t>>; + + using WScale = cutlass::epilogue::fusion::Sm90RowBroadcast< + PONG ? 2 : 1, + TileShape, + ElementComputeEpilogue, + cute::Stride, cute::Int<1>, int32_t>>; + + using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast< + PONG ? 2 : 1, + TileShape, + ElementBias, + cute::Stride, cute::Int<1>, int32_t>>; + + using Accum = cutlass::epilogue::fusion::Sm90AccFetch; + + using Compute0 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, + ElementComputeEpilogue, // First stage output type. + ElementComputeEpilogue, // First stage input types. + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::fusion::Sm90EVT; + + using Compute1 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, + cute::conditional_t< // Second stage output type. + USE_BIAS, + ElementBias, + ElementOutput>, + ElementComputeEpilogue, // Second stage input types. + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute1 = + cutlass::epilogue::fusion::Sm90EVT; + + using ComputeBias = cutlass::epilogue::fusion::Sm90Compute< + cutlass::plus, + ElementOutput, // Final (optional) stage output type. + ElementBias, // Final stage input types. 
+ cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeBias = + cutlass::epilogue::fusion::Sm90EVT; + + using EpilogueEVT = + cute::conditional_t; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, + cutlass::arch::OpClassTensorOp, + TileShape, + ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, + ElementComputeEpilogue, + ElementOutput, + LayoutOutput, + AlignmentOutput, + ElementOutput, + LayoutOutput, + AlignmentOutput, + cutlass::epilogue::TmaWarpSpecialized, + EpilogueEVT>::CollectiveOp; + + using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecialized; + using PongSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong; + using FastDefaultSchedule = + cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum; + using FastPongSchedule = + cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; + using SlowAccum = cute::conditional_t; + using FastAccum = + cute::conditional_t; + using MainLoopSchedule = + cute::conditional_t; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + ElementInputA, + LayoutInputA, + AlignmentInputA, + ElementInputB, + LayoutInputB, + AlignmentInputB, + ElementAccumulator, + TileShape, + ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainLoopSchedule>::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + using StrideInputA = typename Gemm::GemmKernel::StrideA; + using StrideInputB = typename Gemm::GemmKernel::StrideB; + using StrideOutput = typename Gemm::GemmKernel::StrideC; + + StrideInputA stride_a = cutlass::make_cute_packed_stride( + StrideInputA{}, cute::make_shape(M, K, B)); + StrideInputB stride_b = cutlass::make_cute_packed_stride( + StrideInputB{}, cute::make_shape(N, K, B)); + StrideOutput stride_output = cutlass::make_cute_packed_stride( + StrideOutput{}, cute::make_shape(M, N, B)); + + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {M, N, K, B}, + {reinterpret_cast(XQ.data_ptr()), + stride_a, + reinterpret_cast(WQ.data_ptr()), + stride_b}, + {{}, // Epilogue thread we populate below. 
+ (ElementOutput*)Y.data_ptr(), + stride_output, + (ElementOutput*)Y.data_ptr(), + stride_output}}; + + if constexpr (USE_BIAS) { + arguments.epilogue.thread = { + {reinterpret_cast(bias.value().data_ptr()), + ElementBias(0), + {cute::Int<0>(), cute::Int<1>(), int32_t(N)}}, // bias + // compute_1 + { + {reinterpret_cast(x_scale.data_ptr()), + ElementComputeEpilogue(0), + {cute::Int<1>(), cute::Int<0>(), int32_t(M)}}, // x_scale + // compute_0 + { + {reinterpret_cast(w_scale.data_ptr()), + ElementComputeEpilogue(0), + {cute::Int<0>(), cute::Int<1>(), int32_t(N)}}, // w_scale + {}, // Accumulator + {} // Multiplies + }, + {}, // Multiplies + }, + {}, // Plus + }; + } else { + arguments.epilogue.thread = { + {reinterpret_cast(x_scale.data_ptr()), + ElementComputeEpilogue(0), + {cute::Int<1>(), cute::Int<0>(), int32_t(M)}}, // x_scale + // compute_0 + { + {reinterpret_cast(w_scale.data_ptr()), + ElementComputeEpilogue(0), + {cute::Int<0>(), cute::Int<1>(), int32_t(N)}}, // w_scale + {}, // Accumulator + {} // Multiplies + }, + {}, // Multiplies + }; + } + + Gemm gemm; + + // Using the arguments, query for extra workspace required for matrix + // multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check the problem size is supported or not + cutlass::Status status = gemm.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot implement"); + } + + // Initialize CUTLASS kernel with arguments and workspace pointer + status = gemm.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot initialize"); + } + + status = gemm(at::cuda::getCurrentCUDAStream()); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error( + std::string("cutlass cannot run") + + cutlass::cutlassGetStatusString(status)); + } + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + return Y; +} + +// FP8 Rowwise batched Cutlass kernel dispatch. +template +at::Tensor dispatch_fp8_rowwise_batched_kernel( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + std::optional bias, + std::optional output) { + KernelMode kernel = get_batched_kernel_mode(XQ, WQ); + TORCH_CHECK( + (XQ.dim() == 3 && WQ.dim() == 3), + "FP8 rowwise batched GEMM only supports 3D inputs"); + if (kernel == KernelMode::Small) { + return f8f8bf16_rowwise_batched_impl< + 64, + 128, + 128, + 2, + 1, + 1, + false, + FastAccum, + UseBias, + InputDType, + BiasDType>(XQ, WQ, x_scale, w_scale, bias, output); + } else if (kernel == KernelMode::Large) { + return f8f8bf16_rowwise_batched_impl< + 128, + 128, + 128, + 2, + 1, + 1, + true, + FastAccum, + UseBias, + InputDType, + BiasDType>(XQ, WQ, x_scale, w_scale, bias, output); + } else { + return f8f8bf16_rowwise_batched_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + FastAccum, + UseBias, + InputDType, + BiasDType>(XQ, WQ, x_scale, w_scale, bias, output); + } +} + +at::Tensor f8f8bf16_rowwise_batched( + at::Tensor XQ, // FP8 + at::Tensor WQ, // FP8 + at::Tensor x_scale, // FP32 + at::Tensor w_scale, // FP32 + std::optional bias = c10::nullopt, + bool use_fast_accum = true, + std::optional output = c10::nullopt) { + // Check datatypes. 
+ TORCH_CHECK( + x_scale.dtype() == at::kFloat && w_scale.dtype() == at::kFloat, + "Scale tensors must be float32."); + if (bias.has_value()) { + TORCH_CHECK( + bias.value().dtype() == at::kFloat || + bias.value().dtype() == at::kBFloat16, + "Bias type must be bfloat16 or float32 if provided."); + } + bool use_bias = bias.has_value(); + bool bf16_bias = use_bias && bias.value().dtype() == at::kBFloat16; + + // Templatize based on input dtype. + bool use_e5m2 = XQ.dtype() == at::kFloat8_e5m2; + + if (use_bias) { + if (bf16_bias) { + if (use_fast_accum) { + if (use_e5m2) { + return dispatch_fp8_rowwise_batched_kernel< + cutlass::float_e5m2_t, + true, + true, + cutlass::bfloat16_t>(XQ, WQ, x_scale, w_scale, bias, output); + } else { + return dispatch_fp8_rowwise_batched_kernel< + cutlass::float_e4m3_t, + true, + true, + cutlass::bfloat16_t>(XQ, WQ, x_scale, w_scale, bias, output); + } + } else { + if (use_e5m2) { + return dispatch_fp8_rowwise_batched_kernel< + cutlass::float_e5m2_t, + false, + true, + cutlass::bfloat16_t>(XQ, WQ, x_scale, w_scale, bias, output); + } else { + return dispatch_fp8_rowwise_batched_kernel< + cutlass::float_e4m3_t, + false, + true, + cutlass::bfloat16_t>(XQ, WQ, x_scale, w_scale, bias, output); + } + } + } else { + if (use_fast_accum) { + if (use_e5m2) { + return dispatch_fp8_rowwise_batched_kernel< + cutlass::float_e5m2_t, + true, + true, + float>(XQ, WQ, x_scale, w_scale, bias, output); + } else { + return dispatch_fp8_rowwise_batched_kernel< + cutlass::float_e4m3_t, + true, + true, + float>(XQ, WQ, x_scale, w_scale, bias, output); + } + } else { + if (use_e5m2) { + return dispatch_fp8_rowwise_batched_kernel< + cutlass::float_e5m2_t, + false, + true, + float>(XQ, WQ, x_scale, w_scale, bias, output); + } else { + return dispatch_fp8_rowwise_batched_kernel< + cutlass::float_e4m3_t, + false, + true, + float>(XQ, WQ, x_scale, w_scale, bias, output); + } + } + } + } else { + if (use_fast_accum) { + if (use_e5m2) { + return dispatch_fp8_rowwise_batched_kernel< + cutlass::float_e5m2_t, + true, + false, + float>(XQ, WQ, x_scale, w_scale, bias, output); + } else { + return dispatch_fp8_rowwise_batched_kernel< + cutlass::float_e4m3_t, + true, + false, + float>(XQ, WQ, x_scale, w_scale, bias, output); + } + } else { + if (use_e5m2) { + return dispatch_fp8_rowwise_batched_kernel< + cutlass::float_e5m2_t, + false, + false, + float>(XQ, WQ, x_scale, w_scale, bias, output); + } else { + return dispatch_fp8_rowwise_batched_kernel< + cutlass::float_e4m3_t, + false, + false, + float>(XQ, WQ, x_scale, w_scale, bias, output); + } + } + } +} + +#else + +at::Tensor f8f8bf16_rowwise_batched( + at::Tensor XQ, // FP8 + at::Tensor WQ, // FP8 + at::Tensor x_scale, + at::Tensor w_scale, + std::optional bias = c10::nullopt, + bool use_fast_accum = true, + std::optional output = c10::nullopt) { + throw std::runtime_error( + "CUDA version is older than 12.0"); // requires CUDA>=12 +} + +#endif + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/include/kernel_mode.h b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/include/kernel_mode.h index 94a68096d..93b96fb04 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/include/kernel_mode.h +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/include/kernel_mode.h @@ -31,4 +31,25 @@ inline KernelMode get_kernel_mode(at::Tensor XQ, at::Tensor WQ) { } } +inline KernelMode get_batched_kernel_mode(at::Tensor XQ, at::Tensor WQ) { + auto B = 
XQ.size(0); + auto M = XQ.size(1); + auto K = XQ.size(2); + auto N = WQ.size(1); + auto BM = B * M; + auto BN = B * N; + auto BK = B * K; + // Use a large kernel if at least two shapes are large.... + bool use_large_kernel = + ((BM >= 2048 && BK >= 2048) || (BM >= 2048 && BK >= 2048) || + (BK >= 2048 && BN >= 2048)); + if (BM <= 128 || BN <= 128) { + return KernelMode::Small; + } else if (use_large_kernel) { + return KernelMode::Large; + } else { + return KernelMode::Default; + } +} + } // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp index 39084712c..ff5c66766 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp @@ -57,6 +57,14 @@ at::Tensor f8f8bf16_rowwise( std::optional bias = c10::nullopt, bool use_fast_accum = true, std::optional output = c10::nullopt); +at::Tensor f8f8bf16_rowwise_batched( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + std::optional bias = c10::nullopt, + bool use_fast_accum = true, + std::optional output = c10::nullopt); at::Tensor f8f8bf16_blockwise( at::Tensor XQ, at::Tensor WQ, @@ -132,6 +140,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { m.def( "f8i4bf16_rowwise(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, Tensor w_zp) -> Tensor"); m.impl("f8i4bf16_rowwise", f8i4bf16_rowwise); + m.def( + "f8f8bf16_rowwise_batched(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, Tensor? bias=None, bool use_fast_accum=True, Tensor(a!)? output=None) -> Tensor"); m.def( "bf16i4bf16_rowwise(Tensor X, Tensor WQ, Tensor w_scale, Tensor w_zp) -> Tensor"); @@ -188,6 +198,7 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) { m.impl("i8i8bf16", i8i8bf16); m.impl("f8f8bf16", f8f8bf16); m.impl("f8f8bf16_cublas", f8f8bf16_cublas); + m.impl("f8f8bf16_rowwise_batched", f8f8bf16_rowwise_batched); #endif } @@ -216,6 +227,21 @@ at::Tensor f8f8bf16_rowwise_meta( return Y; } +at::Tensor f8f8bf16_rowwise_batched_meta( + at::Tensor XQ, // FP8 + at::Tensor WQ, // FP8 + at::Tensor /* x_scale */, + at::Tensor /* w_scale */, + std::optional /* bias = c10::nullopt */, + bool /* use_fast_accum = true */, + std::optional /* output = c10::nullopt */) { + int B = XQ.size(0); + int M = XQ.size(1); + int N = WQ.size(1); + auto Y = at::empty({B, M, N}, XQ.options().dtype(at::kBFloat16)); + return Y; +} + at::Tensor f8f8bf16_blockwise_meta( at::Tensor XQ, // FP8 at::Tensor WQ, // FP8 @@ -331,6 +357,7 @@ TORCH_LIBRARY_IMPL(fbgemm, Meta, m) { m.impl("i8i8bf16", i8i8bf16_meta); m.impl("f8f8bf16", f8f8bf16_meta); m.impl("f8f8bf16_cublas", f8f8bf16_cublas_meta); + m.impl("f8f8bf16_rowwise_batched", f8f8bf16_rowwise_batched_meta); m.impl("f8i4bf16_rowwise", f8i4bf16_rowwise_meta); m.impl("bf16i4bf16_rowwise", bf16i4bf16_rowwise_meta); #endif diff --git a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py index fdd038e2c..c21c1713a 100644 --- a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py +++ b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py @@ -614,12 +614,16 @@ def test_quantize_fp8_per_tensor_with_ub( zq_ref = (x @ w.T).to(torch.bfloat16) torch.testing.assert_close(zq, zq_ref, atol=1.0e-3, rtol=1.0e-3) + @unittest.skipIf( + not torch.version.cuda, "Skip on AMD: BMM ops are not yet suported." 
+ ) @settings(deadline=None) @given( B=st.sampled_from([1, 4]), M=st.sampled_from([2048, 4096]), N=st.sampled_from([128, 256]), K=st.sampled_from([256, 512]), + use_loopover=st.sampled_from([True, False]), ) def test_fp8_batched_gemm( self, @@ -627,6 +631,7 @@ def test_fp8_batched_gemm( M: int, N: int, K: int, + use_loopover: bool, ) -> None: x = torch.rand(size=(B, M, K), dtype=torch.bfloat16, device="cuda") * 0.1 w = torch.rand(size=(B, N, K), dtype=torch.bfloat16, device="cuda") * 0.01 @@ -655,7 +660,10 @@ def fp8_loopover_bmm( return y y_ref = torch.bmm(x, w.transpose(1, 2)) - y_fp8 = fp8_loopover_bmm(xq, wq, x_scale, w_scale) + if use_loopover: + y_fp8 = fp8_loopover_bmm(xq, wq, x_scale, w_scale) + else: + y_fp8 = torch.ops.fbgemm.f8f8bf16_rowwise_batched(xq, wq, x_scale, w_scale) torch.testing.assert_close(y_ref, y_fp8, atol=8.0e-2, rtol=8.0e-2) @unittest.skipIf(torch.version.hip, "Skip on AMD: Marlin not yet suported.") From 4a4d187987fe1e41a5988511d835b3eb3f3d6958 Mon Sep 17 00:00:00 2001 From: Hongtao Yu Date: Wed, 2 Oct 2024 14:47:08 -0700 Subject: [PATCH 37/48] Fix bias dtype issue for the TMA kernel (#3199) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/297 Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3199 Passing bias dtype through a const arg instead of hardcoding it in the kernel. Addressing https://fb.workplace.com/groups/fbgemmusers/permalink/8689681817779189/ Reviewed By: sijiac Differential Revision: D63569991 fbshipit-source-id: 46b5621bb668493369b1752512eb9fe86a8340df --- .../experimental/gemm/triton_gemm/fp8_gemm.py | 36 +++++++++++++++---- 1 file changed, 29 insertions(+), 7 deletions(-) diff --git a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py index 330ab51b8..b82cde50a 100644 --- a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +++ b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py @@ -58,6 +58,28 @@ def reinterpret_fp8_type(tensor: torch.Tensor, dtype: tl.dtype) -> TensorWrapper return tl_reinterpret(tensor, dtype=dtype) +def map_dtype_to_triton(dtype: torch.dtype) -> tl.dtype: + """ + Maps torch dtype to triton dtype. + + Args: + dtype (torch.dtype): input dtype. + + Returns: + tl.dtype: triton dtype. 
+ """ + if dtype == torch.float16: + return tl.float16 + elif dtype == torch.bfloat16: + return tl.bfloat16 + elif dtype == torch.float32: + return tl.float32 + elif dtype == torch.int32: + return tl.int32 + else: + raise ValueError(f"Unsupported dtype {dtype}") + + def init_to_zero(name): return lambda nargs: nargs[name].zero_() @@ -746,6 +768,7 @@ def _kernel_matmul_fp8_row_tma_persistent( stride_cn, dot_out_dtype: tl.constexpr, c_dtype: tl.constexpr, + bias_dtype: tl.constexpr, allow_tf32: tl.constexpr, fp8_fast_accum: tl.constexpr, BLOCK_M: tl.constexpr, @@ -813,7 +836,6 @@ def _kernel_matmul_fp8_row_tma_persistent( dtype_fp8 = tl.float8e4nv scale_dtype = tl.float32 - bias_dtype = tl.float32 for _ in range(0, k_tiles * tiles_per_SM): ki = tl.where(ki == k_tiles - 1, 0, ki + 1) @@ -1110,6 +1132,10 @@ def persistent_grid_tma(META): desc_b_scale = desc_helper.get_tma_descriptor_kernel_param("b_scale") desc_bias = desc_helper.get_tma_descriptor_kernel_param("bias") + bias_dtype_triton = None + if bias is not None: + bias_dtype_triton = map_dtype_to_triton(bias.dtype) + # pyre-ignore torch._library.capture_triton(_kernel_matmul_fp8_row_tma_persistent)[ persistent_grid_tma @@ -1134,6 +1160,7 @@ def persistent_grid_tma(META): c.stride(1), dot_out_dtype=dot_out_dtype_triton, c_dtype=c_dtype_triton, + bias_dtype=bias_dtype_triton, allow_tf32=allow_tf32, fp8_fast_accum=fp8_fast_accum, GROUP_M=8, @@ -1864,12 +1891,7 @@ def prep_matmul( assert isinstance( dot_out_dtype, torch.dtype ), f"dot_out_dtype type {type(dot_out_dtype)} must be a torch.dtype" - if dot_out_dtype == torch.bfloat16: - dot_out_dtype_triton = tl.bfloat16 - elif dot_out_dtype == torch.float32: - dot_out_dtype_triton = tl.float32 - else: - dot_out_dtype_triton = tl.int32 + dot_out_dtype_triton = map_dtype_to_triton(dot_out_dtype) return M, N, K, m_key, n_key, k_key, c, c_dtype_triton, dot_out_dtype_triton, device From d1c40a90dd34a2bced8e2adc8f59270b906cf077 Mon Sep 17 00:00:00 2001 From: generatedunixname89002005232357 Date: Wed, 2 Oct 2024 18:00:17 -0700 Subject: [PATCH 38/48] Revert D63335491: Multisect successfully blamed "D63335491: Refactor the GIS to reuse same autograd function for all backends" for one test failure (#3214) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3214 X-link: https://github.com/facebookresearch/FBGEMM/pull/311 This diff reverts D63335491 D63335491: Refactor the GIS to reuse same autograd function for all backends by egienvalue causes the following test failure: Tests affected: - [cogwheel:cogwheel_model_import_inference_ads_v0_test#test_ads_v0_inference_model_import](https://www.internalfb.com/intern/test/281475118337004/) Here's the Multisect link: https://www.internalfb.com/multisect/11357230 Here are the tasks that are relevant to this breakage: T203477897: Test cogwheel:cogwheel_model_import_inference_ads_v0_test#test_ads_v0_inference_model_import failing for ai_test_validation The backout may land if someone accepts it. If this diff has been generated in error, you can Commandeer and Abandon it. 
Reviewed By: egienvalue Differential Revision: D63757977 fbshipit-source-id: 04f93655469ce6f732b453d8a4e109b308c94e7d --- fbgemm_gpu/fbgemm_gpu/sparse_ops.py | 4 +- fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h | 104 --- fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp | 122 +-- fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp | 774 +++++++++++-------- 4 files changed, 483 insertions(+), 521 deletions(-) diff --git a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py index 8ada0a89f..71e0e2ccc 100644 --- a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py +++ b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py @@ -1073,11 +1073,11 @@ def impl_autograd(op_name, fn, setup_context: Optional[Callable] = None) -> None ) impl_abstract("fbgemm::bounds_check_indices", bounds_check_indices_abstract) impl_abstract( - "fbgemm::group_index_select_dim0_forward_impl", + "fbgemm::group_index_select_dim0_gpu_impl", group_index_select_dim0_gpu_impl_abstract, ) impl_abstract( - "fbgemm::group_index_select_dim0_backward_impl", + "fbgemm::group_index_select_dim0_gpu_backward", group_index_select_dim0_gpu_backward_abstract, ) impl_abstract( diff --git a/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h b/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h index b3607a5b1..9bea430ef 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h +++ b/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h @@ -9,11 +9,8 @@ #pragma once #include -#include #include #include -#include - #include namespace fbgemm_gpu { @@ -927,107 +924,6 @@ at::Tensor index_add_with_unique_indices_cuda( const int consecutive_range_start, const int consecutive_range_length); -torch::autograd::variable_list group_index_select_dim0_decomposed( - at::TensorList input_group, - at::TensorList indices_group); - -torch::autograd::variable_list group_index_select_dim0_autograd_impl( - at::TensorList all_indices_input, - const int64_t group_size); - -torch::autograd::variable_list group_index_select_dim0( - at::TensorList input_group, - at::TensorList indices_group); - -torch::autograd::variable_list group_index_select_dim0_forward_impl_cpu( - at::TensorList all_indices_input, - const int64_t group_size); - -torch::autograd::variable_list group_index_select_dim0_backward_impl_cpu( - at::TensorList all_inputs, - c10::SymIntArrayRef output_shape_group_ref); - -std::pair, std::vector> -group_index_select_dim0_unpack( - at::TensorList all_indices_input, - const int64_t group_size); - -class GroupIndexSelectDim0Op - : public torch::autograd::Function { - public: - static torch::autograd::variable_list forward( - torch::autograd::AutogradContext* ctx, - at::TensorList all_indices_input, - const int64_t group_size) { - at::AutoDispatchBelowADInplaceOrView guard; - static auto forward_op = - at::Dispatcher::singleton() - .findSchemaOrThrow( - "fbgemm::group_index_select_dim0_forward_impl", "") - .typed(); - auto result = forward_op.call(all_indices_input, group_size); - TORCH_CHECK(static_cast(result.size()) == group_size + 2); - - auto [input_group, indices_group] = - group_index_select_dim0_unpack(all_indices_input, group_size); - const auto input_dim = input_group[0].dim(); - std::vector input_shape_group; - input_shape_group.reserve(group_size * input_dim); - - for (const auto i : c10::irange(group_size)) { - const auto& input = input_group[i]; - // Copy input shape - auto input_shape = input.sym_sizes().vec(); - input_shape_group.insert( - input_shape_group.end(), input_shape.begin(), input_shape.end()); - } - - // save indices, args_tensor, saved_data - auto saved_tensors = 
std::vector(indices_group); - saved_tensors.insert( - saved_tensors.end(), result.cbegin() + group_size, result.cend()); - saved_tensors.push_back(input_group[0]); - ctx->save_for_backward(saved_tensors); - ctx->saved_data["input_shape_group"] = input_shape_group; - - return result; - } - - static torch::autograd::variable_list backward( - torch::autograd::AutogradContext* ctx, - torch::autograd::variable_list grad_output_group) { - TORCH_CHECK(grad_output_group.size() >= 2); - if (grad_output_group.size() == 2) { - // empty outputs - return torch::autograd::variable_list(1); - } - // remove redundant grads - auto group_size = grad_output_group.size() - 2; - grad_output_group.resize(group_size); - - auto saved_tensors = ctx->get_saved_variables(); - TORCH_CHECK(saved_tensors.size() == group_size + 3); - auto output_shape_group = - ctx->saved_data["input_shape_group"].toSymIntVector(); - grad_output_group.insert( - grad_output_group.end(), saved_tensors.begin(), saved_tensors.end()); - static auto backward_op = - at::Dispatcher::singleton() - .findSchemaOrThrow( - "fbgemm::group_index_select_dim0_backward_impl", "") - .typed(); - auto res = backward_op.call(grad_output_group, output_shape_group); - // 1) Add group_size Variable()'s for indices - // Replace all empty tensors with Variable(). This must be done after the - // op.call to make __torch_dispatch__ work for the backward op. - std::fill( - res.begin(), res.begin() + group_size, torch::autograd::Variable()); - // 3) Add 1 Variable() for group_size - res.push_back({}); - return res; - } -}; - ///@ingroup sparse-data-cuda void group_index_select_or_add_cuda( const int64_t* input_ptrs, diff --git a/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp b/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp index 01dde3394..7734cc69a 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp +++ b/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp @@ -2851,103 +2851,41 @@ Tensor pack_segments_cpu( const int64_t max_length) { return pack_segments_forward_cpu(t_in, lengths, max_length); } - -torch::autograd::variable_list group_index_select_dim0_autograd_impl( - at::TensorList all_indices_input, - const int64_t group_size) { - return GroupIndexSelectDim0Op::apply(all_indices_input, group_size); -} - -std::pair, std::vector> -group_index_select_dim0_unpack( - at::TensorList all_indices_input, - const int64_t group_size) { - std::vector indices_group; - std::vector input_group; - - indices_group.reserve(group_size); - input_group.reserve(group_size); - - for (const auto i : c10::irange(group_size)) { - indices_group.push_back(all_indices_input[i]); - input_group.push_back(all_indices_input[group_size + i]); - } - - TORCH_CHECK(group_size == static_cast(indices_group.size())); - - return std::make_pair(input_group, indices_group); +namespace { +Tensor index_select_dim0( + const Tensor& input, + const Tensor& indices, + std::optional /*consecutive_range_start*/, + std::optional /*consecutive_range_length*/, + std::optional /*skip_indices_sorting_fwd*/) { + return at::index_select(input, 0, indices); } torch::autograd::variable_list group_index_select_dim0( at::TensorList input_group, at::TensorList indices_group) { - const auto group_size = indices_group.size(); + int num_groups = input_group.size(); + TORCH_CHECK(num_groups == (int)indices_group.size()) std::vector output_group; - - if (group_size == 0) { - return std::vector(); - } - - // Pack input_group and indices_group into TensorList - std::vector all_indices_input_vec; - 
all_indices_input_vec.reserve(group_size * 2); - - for (const Tensor& index : indices_group) { - all_indices_input_vec.push_back(index); - } - for (const Tensor& input : input_group) { - all_indices_input_vec.push_back(input); + for (const auto i : c10::irange(num_groups)) { + output_group.push_back( + at::index_select(input_group[i], 0, indices_group[i])); } - - at::TensorList all_indices_input_tensor = all_indices_input_vec; - - static auto forward_op = - at::Dispatcher::singleton() - .findSchemaOrThrow( - "fbgemm::group_index_select_dim0_autograd_impl", "") - .typed(); - auto res = forward_op.call(all_indices_input_tensor, group_size); - TORCH_CHECK(res.size() == group_size + 2); - // only return the outputs (the first group_size elements) - res.resize(group_size); - return res; + return output_group; } -torch::autograd::variable_list group_index_select_dim0_forward_impl_cpu( +torch::autograd::variable_list group_index_select_dim0_gpu_impl_cpu( at::TensorList all_indices_input, const int64_t group_size) { throw std::runtime_error( - "group_index_select_dim0_forward_impl is not implemented for CPU"); + "group_index_select_dim0_gpu_impl is not implemented for CPU"); } -torch::autograd::variable_list group_index_select_dim0_backward_impl_cpu( +torch::autograd::variable_list group_index_select_dim0_gpu_backward_cpu( at::TensorList all_inputs, c10::SymIntArrayRef output_shape_group_ref) { throw std::runtime_error( - "group_index_select_dim0_backward_impl is not implemented for CPU"); -} - -torch::autograd::variable_list group_index_select_dim0_decomposed( - at::TensorList input_group, - at::TensorList indices_group) { - int num_groups = input_group.size(); - TORCH_CHECK(num_groups == (int)indices_group.size()) - std::vector output_group; - for (const auto i : c10::irange(num_groups)) { - output_group.push_back( - at::index_select(input_group[i], 0, indices_group[i])); - } - return output_group; -} - -namespace { -Tensor index_select_dim0( - const Tensor& input, - const Tensor& indices, - std::optional /*consecutive_range_start*/, - std::optional /*consecutive_range_length*/, - std::optional /*skip_indices_sorting_fwd*/) { - return at::index_select(input, 0, indices); + "group_index_select_dim0_gpu_backward is not implemented for CPU"); } Tensor bottom_k_per_row( @@ -3108,11 +3046,9 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { {PT2_COMPLIANT_TAG}); // group_index_select_dim0_gpu helper functions - not defined for CPU! m.def( - "group_index_select_dim0_autograd_impl(Tensor[] inputs, int group_size) -> Tensor[]"); - m.def( - "group_index_select_dim0_forward_impl(Tensor[] inputs, int group_size) -> Tensor[]"); + "group_index_select_dim0_gpu_impl(Tensor[] inputs, int group_size) -> Tensor[]"); m.def( - "group_index_select_dim0_backward_impl(Tensor[] inputs, SymInt[] output_shape_group) -> Tensor[]"); + "group_index_select_dim0_gpu_backward(Tensor[] inputs, SymInt[] output_shape_group) -> Tensor[]"); // This is an one-off op to be used in split_embedding_utils.py for zipf // generation w/o replacement along dim=-1. If requires_unique=True, find // smallest unique k. 
If the number of unique elements is less than k, @@ -3196,14 +3132,13 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) { "pack_segments_backward", fbgemm_gpu::pack_segments_backward_cpu); DISPATCH_TO_CPU("index_select_dim0", fbgemm_gpu::index_select_dim0); DISPATCH_TO_CPU( - "group_index_select_dim0", - fbgemm_gpu::group_index_select_dim0_decomposed); + "group_index_select_dim0", fbgemm_gpu::group_index_select_dim0); DISPATCH_TO_CPU( - "group_index_select_dim0_forward_impl", - fbgemm_gpu::group_index_select_dim0_forward_impl_cpu); + "group_index_select_dim0_gpu_impl", + fbgemm_gpu::group_index_select_dim0_gpu_impl_cpu); DISPATCH_TO_CPU( - "group_index_select_dim0_backward_impl", - fbgemm_gpu::group_index_select_dim0_backward_impl_cpu); + "group_index_select_dim0_gpu_backward", + fbgemm_gpu::group_index_select_dim0_gpu_backward_cpu); DISPATCH_TO_CPU("bottom_k_per_row", fbgemm_gpu::bottom_k_per_row); } @@ -3212,14 +3147,11 @@ TORCH_LIBRARY_IMPL(fbgemm, Autograd, m) { } TORCH_LIBRARY_IMPL(fbgemm, AutogradCPU, m) { - m.impl( - "group_index_select_dim0", - &fbgemm_gpu::group_index_select_dim0_decomposed); + m.impl("group_index_select_dim0", &fbgemm_gpu::group_index_select_dim0); } TORCH_LIBRARY_IMPL(fbgemm, Meta, m) { // CPU group_index_select_dim0 is decomposable m.impl( - "group_index_select_dim0", - TORCH_FN(fbgemm_gpu::group_index_select_dim0_decomposed)); + "group_index_select_dim0", TORCH_FN(fbgemm_gpu::group_index_select_dim0)); } diff --git a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp index 77b7f7785..6325017e8 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp +++ b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp @@ -193,346 +193,442 @@ class IndexSelectDim0GPUOp } }; -// need to combine input_group and indices_group into one tensor list -// to get this working with autograd. 
-static torch::autograd::variable_list forward_impl( +std::pair, std::vector> +group_index_select_dim0_unpack( at::TensorList all_indices_input, const int64_t group_size) { - // Unpack from TensorList - auto [input_group, indices_group] = - group_index_select_dim0_unpack(all_indices_input, group_size); - - // args_tensor stores kernel arguments: - // input_ptrs (group_size int64_t elements) - // output_ptrs (group_size int64_t elements) - // indices_ptrs (group_size int64_t elements) - // warp_offsets_group (group_size + 1 int64_t elements) - // num_cols_group (group_size int32_t elements) - int64_t args_ptrs_offsets[NUM_ARGS + 1]; - - const int64_t numels_num_cols_group_64 = - compute_num_int64s(group_size); - - // Initialize offsets - args_ptrs_offsets[P_input_ptrs] = group_size; - args_ptrs_offsets[P_output_ptrs] = group_size; - args_ptrs_offsets[P_indices_ptrs] = group_size; - args_ptrs_offsets[P_warp_offsets_group_ptrs] = group_size + 1; - args_ptrs_offsets[P_num_cols_group_ptrs] = numels_num_cols_group_64; - - // Compute offsets - int64_t offset = 0; - auto next = args_ptrs_offsets[0]; - for (const auto i : c10::irange(NUM_ARGS)) { - args_ptrs_offsets[i] = offset; - offset += next; - next = args_ptrs_offsets[i + 1]; - } - // Total number of int64_t elements required - args_ptrs_offsets[NUM_ARGS] = offset; - - // Allocate memory for GroupIndexSelectArgs - at::Tensor args_tensor = at::empty( - {static_cast(args_ptrs_offsets[NUM_ARGS] * sizeof(int64_t))}, - at::TensorOptions().dtype(at::kByte).pinned_memory(true)); - - // Ensure that args_tensor is contiguous - TORCH_CHECK(args_tensor.is_contiguous()); - - // Initialize raw pointers to point to Tensor args_tensor - int64_t* input_ptrs = nullptr; - int64_t* output_ptrs = nullptr; - int64_t* indices_ptrs = nullptr; - int64_t* warp_offsets_group = nullptr; - int32_t* num_cols_group = nullptr; - - // Offset host pointers - offset_args( - &input_ptrs, - &output_ptrs, - &indices_ptrs, - &warp_offsets_group, - &num_cols_group, - reinterpret_cast(args_tensor.data_ptr()), - args_ptrs_offsets); - - auto& first_input = input_group[0]; - auto& first_indices = indices_group[0]; - - const int input_dim = first_input.dim(); - const int num_output_rows = first_indices.size(0); - const int num_input_rows = first_input.size(0); - Tensor input_reshaped = first_input.reshape({num_input_rows, -1}); - const int num_cols = input_reshaped.size(1); - const int cols_per_warp = get_group_index_select_cols_per_warp(); - int64_t warp_offset = 0; - bool use_var_cols = false; - - // Allocate memory for output_group - std::vector output_group; - output_group.reserve(group_size + 2); + std::vector indices_group; + std::vector input_group; - // We need to store contiguous inputs and indices outside the for-loop to - // guarantee that the contiguous tensors will outlive the kernel - // computation - std::vector> input_contigs; - std::vector> index_contigs; - input_contigs.reserve(group_size); - index_contigs.reserve(group_size); + indices_group.reserve(group_size); + input_group.reserve(group_size); - // For each group, copy input to output for (const auto i : c10::irange(group_size)) { - const auto& input = input_group[i]; - const auto& indices = indices_group[i]; - - // Verify that all input tensors have the same number of dimensions - TORCH_CHECK( - input_dim == input.dim(), - "All inputs in group_index_select must have the same number of dimensions"); - - // Verify that all tensors are on the same GPU - TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(input, indices); - - 
auto num_output_rows_ = indices.size(0); + indices_group.push_back(all_indices_input[i]); + input_group.push_back(all_indices_input[group_size + i]); + } - // Verify that all input tensors have the same shape[0] - TORCH_CHECK( - num_output_rows == num_output_rows_, - "The number of indices to be selected must be the same for the entire group"); - const auto input_reshaped_ = input.reshape({input.size(0), -1}); + TORCH_CHECK(group_size == static_cast(indices_group.size())); - // Number of columns can be different - auto num_cols_ = input_reshaped_.size(1); - auto warps_per_row = (num_cols_ + cols_per_warp - 1) / cols_per_warp; + return std::make_pair(input_group, indices_group); +} - if (num_cols != num_cols_) { - use_var_cols = true; +class GroupIndexSelectDim0GPUOp + : public torch::autograd::Function { + public: + // need to combine input_group and indices_group into one tensor list + // to get this working with autograd. + static torch::autograd::variable_list forward_impl( + at::TensorList all_indices_input, + const int64_t group_size) { + // Unpack from TensorList + auto [input_group, indices_group] = + group_index_select_dim0_unpack(all_indices_input, group_size); + + // args_tensor stores kernel arguments: + // input_ptrs (group_size int64_t elements) + // output_ptrs (group_size int64_t elements) + // indices_ptrs (group_size int64_t elements) + // warp_offsets_group (group_size + 1 int64_t elements) + // num_cols_group (group_size int32_t elements) + int64_t args_ptrs_offsets[NUM_ARGS + 1]; + + const int64_t numels_num_cols_group_64 = + compute_num_int64s(group_size); + + // Initialize offsets + args_ptrs_offsets[P_input_ptrs] = group_size; + args_ptrs_offsets[P_output_ptrs] = group_size; + args_ptrs_offsets[P_indices_ptrs] = group_size; + args_ptrs_offsets[P_warp_offsets_group_ptrs] = group_size + 1; + args_ptrs_offsets[P_num_cols_group_ptrs] = numels_num_cols_group_64; + + // Compute offsets + int64_t offset = 0; + auto next = args_ptrs_offsets[0]; + for (const auto i : c10::irange(NUM_ARGS)) { + args_ptrs_offsets[i] = offset; + offset += next; + next = args_ptrs_offsets[i + 1]; } - - // Create output pointers - auto input_shape = input.sizes().vec(); - input_shape[0] = num_output_rows_; - Tensor output = at::empty(input_shape, input.options()); - // Ensure that the allocated output is contiguous - TORCH_CHECK(output.is_contiguous()) - output_group.push_back(output); - - // Store input and indices contigs to keep them alive during the kernel + // Total number of int64_t elements required + args_ptrs_offsets[NUM_ARGS] = offset; + + // Allocate memory for GroupIndexSelectArgs + at::Tensor args_tensor = at::empty( + {static_cast(args_ptrs_offsets[NUM_ARGS] * sizeof(int64_t))}, + at::TensorOptions().dtype(at::kByte).pinned_memory(true)); + + // Ensure that args_tensor is contiguous + TORCH_CHECK(args_tensor.is_contiguous()); + + // Initialize raw pointers to point to Tensor args_tensor + int64_t* input_ptrs = nullptr; + int64_t* output_ptrs = nullptr; + int64_t* indices_ptrs = nullptr; + int64_t* warp_offsets_group = nullptr; + int32_t* num_cols_group = nullptr; + + // Offset host pointers + offset_args( + &input_ptrs, + &output_ptrs, + &indices_ptrs, + &warp_offsets_group, + &num_cols_group, + reinterpret_cast(args_tensor.data_ptr()), + args_ptrs_offsets); + + auto& first_input = input_group[0]; + auto& first_indices = indices_group[0]; + + const int input_dim = first_input.dim(); + const int num_output_rows = first_indices.size(0); + const int num_input_rows = first_input.size(0); 
+ Tensor input_reshaped = first_input.reshape({num_input_rows, -1}); + const int num_cols = input_reshaped.size(1); + const int cols_per_warp = get_group_index_select_cols_per_warp(); + int64_t warp_offset = 0; + bool use_var_cols = false; + + // Allocate memory for output_group + std::vector output_group; + output_group.reserve(group_size + 2); + + // We need to store contiguous inputs and indices outside the for-loop to + // guarantee that the contiguous tensors will outlive the kernel // computation - input_contigs.push_back(input.expect_contiguous()); - index_contigs.push_back(indices.expect_contiguous()); - - // Store args - input_ptrs[i] = reinterpret_cast(input_contigs[i]->data_ptr()); - output_ptrs[i] = reinterpret_cast(output.data_ptr()); - indices_ptrs[i] = reinterpret_cast(index_contigs[i]->data_ptr()); - warp_offsets_group[i] = warp_offset; - num_cols_group[i] = num_cols_; + std::vector> input_contigs; + std::vector> index_contigs; + input_contigs.reserve(group_size); + index_contigs.reserve(group_size); + + // For each group, copy input to output + for (const auto i : c10::irange(group_size)) { + const auto& input = input_group[i]; + const auto& indices = indices_group[i]; + + // Verify that all input tensors have the same number of dimensions + TORCH_CHECK( + input_dim == input.dim(), + "All inputs in group_index_select must have the same number of dimensions"); + + // Verify that all tensors are on the same GPU + TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(input, indices); + + auto num_output_rows_ = indices.size(0); + + // Verify that all input tensors have the same shape[0] + TORCH_CHECK( + num_output_rows == num_output_rows_, + "The number of indices to be selected must be the same for the entire group"); + const auto input_reshaped_ = input.reshape({input.size(0), -1}); + + // Number of columns can be different + auto num_cols_ = input_reshaped_.size(1); + auto warps_per_row = (num_cols_ + cols_per_warp - 1) / cols_per_warp; + + if (num_cols != num_cols_) { + use_var_cols = true; + } + + // Create output pointers + auto input_shape = input.sizes().vec(); + input_shape[0] = num_output_rows_; + Tensor output = at::empty(input_shape, input.options()); + // Ensure that the allocated output is contiguous + TORCH_CHECK(output.is_contiguous()) + output_group.push_back(output); + + // Store input and indices contigs to keep them alive during the kernel + // computation + input_contigs.push_back(input.expect_contiguous()); + index_contigs.push_back(indices.expect_contiguous()); + + // Store args + input_ptrs[i] = reinterpret_cast(input_contigs[i]->data_ptr()); + output_ptrs[i] = reinterpret_cast(output.data_ptr()); + indices_ptrs[i] = reinterpret_cast(index_contigs[i]->data_ptr()); + warp_offsets_group[i] = warp_offset; + num_cols_group[i] = num_cols_; + + warp_offset += warps_per_row * num_output_rows; + } - warp_offset += warps_per_row * num_output_rows; + // Store the last offset + warp_offsets_group[group_size] = warp_offset; + + // Transfer args tensor to GPU + args_tensor = args_tensor.to( + first_input.device(), + /*non_blocking=*/true); + + // Offset raw ptrs in GPU memory + offset_args( + &input_ptrs, + &output_ptrs, + &indices_ptrs, + &warp_offsets_group, + &num_cols_group, + reinterpret_cast(args_tensor.data_ptr()), + args_ptrs_offsets); + + int64_t saved_data[] = { + static_cast(group_size), + use_var_cols, + reinterpret_cast(warp_offsets_group), + reinterpret_cast(num_cols_group), + warp_offset, + }; + auto saved_data_t = at::empty( + {sizeof(saved_data) / 
sizeof(int64_t)}, + at::TensorOptions().dtype(at::kLong)); + TORCH_CHECK(saved_data_t.is_contiguous()); + memcpy(saved_data_t.data_ptr(), saved_data, sizeof(saved_data)); + + group_index_select_or_add_cuda( + input_ptrs, + output_ptrs, + indices_ptrs, + warp_offsets_group, + num_cols_group, + first_input.scalar_type(), + first_indices.scalar_type(), + first_input.device().index(), + num_output_rows, + /*total_num_warps=*/warp_offset, + group_size, + /*use_index_select=*/true, + use_var_cols); + + output_group.push_back(args_tensor); + output_group.push_back(saved_data_t); + + // return format: + // (group_size outputs, 1 args_tensor, 1 saved_data) + return output_group; } - // Store the last offset - warp_offsets_group[group_size] = warp_offset; - - // Transfer args tensor to GPU - args_tensor = args_tensor.to( - first_input.device(), - /*non_blocking=*/true); - - // Offset raw ptrs in GPU memory - offset_args( - &input_ptrs, - &output_ptrs, - &indices_ptrs, - &warp_offsets_group, - &num_cols_group, - reinterpret_cast(args_tensor.data_ptr()), - args_ptrs_offsets); - - int64_t saved_data[] = { - static_cast(group_size), - use_var_cols, - reinterpret_cast(warp_offsets_group), - reinterpret_cast(num_cols_group), - warp_offset, - }; - auto saved_data_t = at::empty( - {sizeof(saved_data) / sizeof(int64_t)}, - at::TensorOptions().dtype(at::kLong)); - TORCH_CHECK(saved_data_t.is_contiguous()); - memcpy(saved_data_t.data_ptr(), saved_data, sizeof(saved_data)); - - group_index_select_or_add_cuda( - input_ptrs, - output_ptrs, - indices_ptrs, - warp_offsets_group, - num_cols_group, - first_input.scalar_type(), - first_indices.scalar_type(), - first_input.device().index(), - num_output_rows, - /*total_num_warps=*/warp_offset, - group_size, - /*use_index_select=*/true, - use_var_cols); - - output_group.push_back(args_tensor); - output_group.push_back(saved_data_t); - - // return format: - // (group_size outputs, 1 args_tensor, 1 saved_data) - return output_group; -} + static torch::autograd::variable_list forward( + torch::autograd::AutogradContext* ctx, + at::TensorList all_indices_input, + const int64_t group_size) { + at::AutoDispatchBelowADInplaceOrView guard; + static auto forward_op = + torch::Dispatcher::singleton() + .findSchemaOrThrow("fbgemm::group_index_select_dim0_gpu_impl", "") + .typed(); + auto result = forward_op.call(all_indices_input, group_size); + TORCH_CHECK(static_cast(result.size()) == group_size + 2); + + auto [input_group, indices_group] = + group_index_select_dim0_unpack(all_indices_input, group_size); + const auto input_dim = input_group[0].dim(); + std::vector input_shape_group; + input_shape_group.reserve(group_size * input_dim); + + for (const auto i : c10::irange(group_size)) { + const auto& input = input_group[i]; + // Copy input shape + auto input_shape = input.sym_sizes().vec(); + input_shape_group.insert( + input_shape_group.end(), input_shape.begin(), input_shape.end()); + } -static torch::autograd::variable_list backward_impl( - at::TensorList all_inputs, - c10::SymIntArrayRef output_shape_group_ref) { - TORCH_CHECK(all_inputs.size() > 2); - - const int64_t group_size = (all_inputs.size() - 3) / 2; - - Tensor fwd_input = all_inputs[2 * group_size + 2]; - const int64_t output_dim = fwd_input.dim(); - Tensor saved_data = all_inputs[2 * group_size + 1]; - Tensor args_tensor_old = all_inputs[2 * group_size]; - Tensor first_indices = all_inputs[group_size]; - - auto grad_output_group = std::vector( - all_inputs.cbegin(), all_inputs.cbegin() + group_size); - std::vector 
output_shape_group; - output_shape_group.reserve(output_shape_group_ref.size()); - for (const auto& i : output_shape_group_ref) { - output_shape_group.push_back(i.as_int_unchecked()); - } + // save indices, args_tensor, saved_data + auto saved_tensors = std::vector(indices_group); + saved_tensors.insert( + saved_tensors.end(), result.cbegin() + group_size, result.cend()); + saved_tensors.push_back(input_group[0]); + ctx->save_for_backward(saved_tensors); + ctx->saved_data["input_shape_group"] = input_shape_group; - auto indices_group = std::vector( - all_inputs.cbegin() + group_size, all_inputs.cbegin() + 2 * group_size); - - // Retrieve saved data - TORCH_CHECK(saved_data.device() == at::kCPU); - TORCH_CHECK(saved_data.is_contiguous()); - int64_t* saved_data_ptr = saved_data.data_ptr(); - // Check that the size is the same - TORCH_CHECK(saved_data_ptr[0] == group_size); - const bool use_var_cols = saved_data_ptr[1]; - int64_t* warp_offsets_group = reinterpret_cast(saved_data_ptr[2]); - int32_t* num_cols_group = reinterpret_cast(saved_data_ptr[3]); - int64_t total_num_warps = saved_data_ptr[4]; - - // We checked in forward that all output rows are the same for all member - // in the group - const int num_input_rows = grad_output_group[0].size(0); - - std::vector outputs; - // Returning 3 outputs: - // 1) group_size Variable()'s for indices - // 2) group_size gradients for inputs - // 3) 1 Variable() for group_size - outputs.reserve(group_size * 2 + 1); - - // 1) Add group_size Variable()'s for indices - // c10::irange cannot be used in here as it - // triggers a build error of i being an unused variable. - // Add empty tensor with zero size here to make __torch_dispatch__ work for - // the backward op. Those empty tensors will be replaced with - // torch::autograd::Variable() outside of the op call. 
- for (auto i = 0; i < group_size; i++) { - outputs.push_back(at::empty({0}, at::TensorOptions().dtype(at::kLong))); + return result; } - // Allocate Tensor for ptrs of grad output and input, and indices - Tensor args_tensor = at::empty( - {group_size * 3}, - at::TensorOptions().dtype(at::kLong).pinned_memory(true)); - // Ensure that args_tensor is contiguous - TORCH_CHECK(args_tensor.is_contiguous()); - int64_t* grad_output_ptrs = args_tensor.data_ptr(); - int64_t* grad_input_ptrs = args_tensor.data_ptr() + group_size; - int64_t* indices_ptrs = args_tensor.data_ptr() + 2 * group_size; - - int64_t group_grad_input_numel = 0; - std::vector grad_input_numels; - grad_input_numels.reserve(group_size); - - // We need to store contiguous gradients outside the for-loop to guarantee - // that the contiguous tensors will outlive the kernel computation - std::vector> grad_output_contigs; - grad_output_contigs.reserve(group_size); - - for (const auto i : c10::irange(group_size)) { - const auto& grad = grad_output_group[i]; - TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(grad, first_indices); + static torch::autograd::variable_list backward_impl( + at::TensorList all_inputs, + c10::SymIntArrayRef output_shape_group_ref) { + TORCH_CHECK(all_inputs.size() > 2); + + const int64_t group_size = (all_inputs.size() - 3) / 2; + + Tensor fwd_input = all_inputs[2 * group_size + 2]; + const int64_t output_dim = fwd_input.dim(); + Tensor saved_data = all_inputs[2 * group_size + 1]; + Tensor args_tensor_old = all_inputs[2 * group_size]; + Tensor first_indices = all_inputs[group_size]; + + auto grad_output_group = std::vector( + all_inputs.cbegin(), all_inputs.cbegin() + group_size); + std::vector output_shape_group; + output_shape_group.reserve(output_shape_group_ref.size()); + for (const auto& i : output_shape_group_ref) { + output_shape_group.push_back(i.as_int_unchecked()); + } - // Store grad contigs to keep them alive during the kernel computation - grad_output_contigs.push_back(grad.expect_contiguous()); + auto indices_group = std::vector( + all_inputs.cbegin() + group_size, all_inputs.cbegin() + 2 * group_size); + + // Retrieve saved data + TORCH_CHECK(saved_data.device() == at::kCPU); + TORCH_CHECK(saved_data.is_contiguous()); + int64_t* saved_data_ptr = saved_data.data_ptr(); + // Check that the size is the same + TORCH_CHECK(saved_data_ptr[0] == group_size); + const bool use_var_cols = saved_data_ptr[1]; + int64_t* warp_offsets_group = reinterpret_cast(saved_data_ptr[2]); + int32_t* num_cols_group = reinterpret_cast(saved_data_ptr[3]); + int64_t total_num_warps = saved_data_ptr[4]; + + // We checked in forward that all output rows are the same for all member + // in the group + const int num_input_rows = grad_output_group[0].size(0); + + std::vector outputs; + // Returning 3 outputs: + // 1) group_size Variable()'s for indices + // 2) group_size gradients for inputs + // 3) 1 Variable() for group_size + outputs.reserve(group_size * 2 + 1); + + // 1) Add group_size Variable()'s for indices + // c10::irange cannot be used in here as it + // triggers a build error of i being an unused variable. + // Add empty tensor with zero size here to make __torch_dispatch__ work for + // the backward op. Those empty tensors will be replaced with + // torch::autograd::Variable() outside of the op call. 
+ for (auto i = 0; i < group_size; i++) { + outputs.push_back(at::empty({0}, at::TensorOptions().dtype(at::kLong))); + } - // Compute the total number of elements for all grad_inputs - int64_t grad_input_numel = output_shape_group[i * output_dim]; - for (auto j = (i * output_dim) + 1; j < (i + 1) * output_dim; j++) { - grad_input_numel *= output_shape_group[j]; + // Allocate Tensor for ptrs of grad output and input, and indices + Tensor args_tensor = at::empty( + {group_size * 3}, + at::TensorOptions().dtype(at::kLong).pinned_memory(true)); + // Ensure that args_tensor is contiguous + TORCH_CHECK(args_tensor.is_contiguous()); + int64_t* grad_output_ptrs = args_tensor.data_ptr(); + int64_t* grad_input_ptrs = args_tensor.data_ptr() + group_size; + int64_t* indices_ptrs = args_tensor.data_ptr() + 2 * group_size; + + int64_t group_grad_input_numel = 0; + std::vector grad_input_numels; + grad_input_numels.reserve(group_size); + + // We need to store contiguous gradients outside the for-loop to guarantee + // that the contiguous tensors will outlive the kernel computation + std::vector> grad_output_contigs; + grad_output_contigs.reserve(group_size); + + for (const auto i : c10::irange(group_size)) { + const auto& grad = grad_output_group[i]; + TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(grad, first_indices); + + // Store grad contigs to keep them alive during the kernel computation + grad_output_contigs.push_back(grad.expect_contiguous()); + + // Compute the total number of elements for all grad_inputs + int64_t grad_input_numel = output_shape_group[i * output_dim]; + for (auto j = (i * output_dim) + 1; j < (i + 1) * output_dim; j++) { + grad_input_numel *= output_shape_group[j]; + } + grad_input_numels.push_back(grad_input_numel); + group_grad_input_numel += grad_input_numel; + + // Put all grad output/input pointers in an array + grad_output_ptrs[i] = + reinterpret_cast(grad_output_contigs[i]->data_ptr()); } - grad_input_numels.push_back(grad_input_numel); - group_grad_input_numel += grad_input_numel; - // Put all grad output/input pointers in an array - grad_output_ptrs[i] = - reinterpret_cast(grad_output_contigs[i]->data_ptr()); - } + // Allocate a big tensor to avoid calling many small elementwise kernels + const auto group_grad_input = + at::zeros({group_grad_input_numel}, fwd_input.options()); + TORCH_CHECK(group_grad_input.is_contiguous()); - // Allocate a big tensor to avoid calling many small elementwise kernels - const auto group_grad_input = - at::zeros({group_grad_input_numel}, fwd_input.options()); - TORCH_CHECK(group_grad_input.is_contiguous()); + // Split to output_group + auto output_group = group_grad_input.split(grad_input_numels, 0); - // Split to output_group - auto output_group = group_grad_input.split(grad_input_numels, 0); + TORCH_CHECK(output_group.size() == static_cast(group_size)); - TORCH_CHECK(output_group.size() == static_cast(group_size)); + // Reshape grad inputs and obtain their pointers + for (int i = 0; i < group_size; i++) { + const auto grad_input_shape = std::vector( + output_shape_group.begin() + i * output_dim, + output_shape_group.begin() + (i + 1) * output_dim); + output_group[i] = output_group[i].reshape(grad_input_shape); + TORCH_CHECK(output_group[i].is_contiguous()); + grad_input_ptrs[i] = + reinterpret_cast(output_group[i].data_ptr()); - // Reshape grad inputs and obtain their pointers - for (int i = 0; i < group_size; i++) { - const auto grad_input_shape = std::vector( - output_shape_group.begin() + i * output_dim, - output_shape_group.begin() + (i + 
1) * output_dim); - output_group[i] = output_group[i].reshape(grad_input_shape); - TORCH_CHECK(output_group[i].is_contiguous()); - grad_input_ptrs[i] = reinterpret_cast(output_group[i].data_ptr()); + // 2) Add group_size gradients for inputs + outputs.push_back(output_group[i]); + } - // 2) Add group_size gradients for inputs - outputs.push_back(output_group[i]); - } + // Calculate indices_ptrs + std::vector> index_contigs; + index_contigs.reserve(group_size); + for (const auto i : c10::irange(group_size)) { + const auto& indices = indices_group[i]; + index_contigs.push_back(indices.expect_contiguous()); + indices_ptrs[i] = reinterpret_cast(index_contigs[i]->data_ptr()); + } - // Calculate indices_ptrs - std::vector> index_contigs; - index_contigs.reserve(group_size); - for (const auto i : c10::irange(group_size)) { - const auto& indices = indices_group[i]; - index_contigs.push_back(indices.expect_contiguous()); - indices_ptrs[i] = reinterpret_cast(index_contigs[i]->data_ptr()); + // Transfer grad output pointers to GPU + args_tensor = args_tensor.to(first_indices.device(), /*non_blocking=*/true); + + group_index_select_or_add_cuda( + args_tensor.data_ptr(), + args_tensor.data_ptr() + group_size, + args_tensor.data_ptr() + 2 * group_size, + warp_offsets_group, + num_cols_group, + fwd_input.scalar_type(), + first_indices.scalar_type(), + fwd_input.device().index(), + num_input_rows, + total_num_warps, + group_size, + /*use_index_select=*/false, + use_var_cols); + + return outputs; } - // Transfer grad output pointers to GPU - args_tensor = args_tensor.to(first_indices.device(), /*non_blocking=*/true); - - group_index_select_or_add_cuda( - args_tensor.data_ptr(), - args_tensor.data_ptr() + group_size, - args_tensor.data_ptr() + 2 * group_size, - warp_offsets_group, - num_cols_group, - fwd_input.scalar_type(), - first_indices.scalar_type(), - fwd_input.device().index(), - num_input_rows, - total_num_warps, - group_size, - /*use_index_select=*/false, - use_var_cols); - - return outputs; -} + static torch::autograd::variable_list backward( + torch::autograd::AutogradContext* ctx, + torch::autograd::variable_list grad_output_group) { + TORCH_CHECK(grad_output_group.size() >= 2); + if (grad_output_group.size() == 2) { + // empty outputs + return torch::autograd::variable_list(1); + } + // remove redundant grads + auto group_size = grad_output_group.size() - 2; + grad_output_group.resize(group_size); + + auto saved_tensors = ctx->get_saved_variables(); + TORCH_CHECK(saved_tensors.size() == group_size + 3); + auto output_shape_group = + ctx->saved_data["input_shape_group"].toSymIntVector(); + grad_output_group.insert( + grad_output_group.end(), saved_tensors.begin(), saved_tensors.end()); + static auto backward_op = + torch::Dispatcher::singleton() + .findSchemaOrThrow( + "fbgemm::group_index_select_dim0_gpu_backward", "") + .typed(); + auto res = backward_op.call(grad_output_group, output_shape_group); + // 1) Add group_size Variable()'s for indices + // Replace all empty tensors with Variable(). This must be done after the + // op.call to make __torch_dispatch__ work for the backward op. 
+ std::fill( + res.begin(), res.begin() + group_size, torch::autograd::Variable()); + // 3) Add 1 Variable() for group_size + res.push_back({}); + return res; + } +}; Tensor pack_segments_cuda( const Tensor& t_in, @@ -558,6 +654,45 @@ Tensor index_select_dim0_gpu( user_skip_indices_sorting_fwd && !c10::InferenceMode::is_enabled())[0]; } +torch::autograd::variable_list group_index_select_dim0_gpu_impl( + at::TensorList all_indices_input, + const int64_t group_size) { + return GroupIndexSelectDim0GPUOp::apply(all_indices_input, group_size); +} + +torch::autograd::variable_list group_index_select_dim0_gpu( + at::TensorList input_group, + at::TensorList indices_group) { + const auto group_size = indices_group.size(); + std::vector output_group; + + if (group_size == 0) { + return std::vector(); + } + + // Pack input_group and indices_group into TensorList + std::vector all_indices_input_vec; + all_indices_input_vec.reserve(group_size * 2); + + for (const Tensor& index : indices_group) { + all_indices_input_vec.push_back(index); + } + for (const Tensor& input : input_group) { + all_indices_input_vec.push_back(input); + } + + at::TensorList all_indices_input_tensor = all_indices_input_vec; + + static auto forward_op = + torch::Dispatcher::singleton() + .findSchemaOrThrow("fbgemm::group_index_select_dim0_gpu_impl", "") + .typed(); + auto res = forward_op.call(all_indices_input_tensor, group_size); + TORCH_CHECK(res.size() == group_size + 2); + // only return the outputs (the first group_size elements) + res.resize(group_size); + return res; +} } // namespace fbgemm_gpu TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) { @@ -585,19 +720,18 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) { "pack_segments_backward", fbgemm_gpu::pack_segments_backward_cuda); DISPATCH_TO_CUDA("index_select_dim0", fbgemm_gpu::index_select_dim0_gpu); DISPATCH_TO_CUDA( - "group_index_select_dim0_forward_impl", fbgemm_gpu::forward_impl); - DISPATCH_TO_CUDA( - "group_index_select_dim0_backward_impl", fbgemm_gpu::backward_impl); + "group_index_select_dim0_gpu_impl", + fbgemm_gpu::GroupIndexSelectDim0GPUOp::forward_impl); DISPATCH_TO_CUDA( - "group_index_select_dim0", fbgemm_gpu::group_index_select_dim0); + "group_index_select_dim0_gpu_backward", + fbgemm_gpu::GroupIndexSelectDim0GPUOp::backward_impl); DISPATCH_TO_CUDA( - "group_index_select_dim0_autograd_impl", - &fbgemm_gpu::group_index_select_dim0_autograd_impl); + "group_index_select_dim0", fbgemm_gpu::group_index_select_dim0_gpu); } TORCH_LIBRARY_IMPL(fbgemm, AutogradCUDA, m) { - m.impl("group_index_select_dim0", &fbgemm_gpu::group_index_select_dim0); + m.impl("group_index_select_dim0", &fbgemm_gpu::group_index_select_dim0_gpu); m.impl( - "group_index_select_dim0_autograd_impl", - &fbgemm_gpu::group_index_select_dim0_autograd_impl); + "group_index_select_dim0_gpu_impl", + &fbgemm_gpu::group_index_select_dim0_gpu_impl); } From 9a845cc573026a13c6d18c6c65835c2c9e4eb1a6 Mon Sep 17 00:00:00 2001 From: Yulu Jia Date: Wed, 2 Oct 2024 21:07:01 -0700 Subject: [PATCH 39/48] add set_range (#3213) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3213 X-link: https://github.com/facebookresearch/FBGEMM/pull/310 add set_range in `KVTensorWrapper`: - add `set_range()` method to write data into the backing storage - make the constructor parameter `snapshot_handle` optional, because it's not needed for writing to the backing storage Reviewed By: xunnanxu Differential Revision: D63702602 fbshipit-source-id: 3b064c1bb12bca0c85f77739458e94a2722b3749 --- 
.../ssd_split_table_batched_embeddings.cpp | 36 +++++++--- .../ssd_table_batched_embeddings.h | 14 +++- .../test/tbe/ssd/kv_tensor_wrapper_test.py | 71 ++++++++++++++++++- 3 files changed, 107 insertions(+), 14 deletions(-) diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp index 9723ddfea..73e6973c6 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp @@ -396,24 +396,26 @@ class KVTensorWrapper : public torch::jit::CustomClassHolder { public: explicit KVTensorWrapper( c10::intrusive_ptr db, - c10::intrusive_ptr snapshot_handle, std::vector shape, int64_t dtype, - int64_t row_offset) - : db_(db->impl_), - snapshot_handle_(std::move(snapshot_handle)), - shape_(std::move(shape)), - row_offset_(row_offset) { + int64_t row_offset, + std::optional> + snapshot_handle) + : db_(db->impl_), shape_(std::move(shape)), row_offset_(row_offset) { CHECK_EQ(shape_.size(), 2) << "Only 2D emb tensors are supported"; options_ = at::TensorOptions() .dtype(static_cast(dtype)) .device(at::kCPU) .layout(at::kStrided); + if (snapshot_handle.has_value()) { + snapshot_handle_ = std::move(snapshot_handle.value()); + } } at::Tensor narrow(int64_t dim, int64_t start, int64_t length) { CHECK_EQ(dim, 0) << "Only narrow on dim 0 is supported"; CHECK_EQ(db_->get_max_D(), shape_[1]); + CHECK_TRUE(snapshot_handle_ != nullptr); auto t = at::empty(c10::IntArrayRef({length, db_->get_max_D()}), options_); db_->get_range_from_snapshot( t, start + row_offset_, length, snapshot_handle_->handle); @@ -422,6 +424,16 @@ class KVTensorWrapper : public torch::jit::CustomClassHolder { return t.narrow(1, 0, shape_[1]); } + void set_range( + int64_t dim, + const int64_t start, + const int64_t length, + const at::Tensor& weights) { + CHECK_EQ(dim, 0) << "Only set_range on dim 0 is supported"; + CHECK_EQ(db_->get_max_D(), shape_[1]); + db_->set_range(weights, start + row_offset_, length); + } + c10::IntArrayRef size() { return shape_; } @@ -537,21 +549,25 @@ static auto kv_tensor_wrapper = .def( torch::init< c10::intrusive_ptr, - c10::intrusive_ptr, std::vector, int64_t, - int64_t>(), + int64_t, + std::optional< + c10::intrusive_ptr>>(), "", {torch::arg("db"), - torch::arg("snapshot_handle"), torch::arg("shape"), torch::arg("dtype"), - torch::arg("row_offset")}) + torch::arg("row_offset"), + // snapshot must be provided for reading + // not needed for writing + torch::arg("snapshot_handle") = std::nullopt}) .def( "narrow", &KVTensorWrapper::narrow, "", {torch::arg("dim"), torch::arg("start"), torch::arg("length")}) + .def("set_range", &KVTensorWrapper::set_range) .def_property("dtype_str", &KVTensorWrapper::dtype_str) .def_property( "shape", diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h index 2a8f101c5..f14897854 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h @@ -528,13 +528,21 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB { const SnapshotHandle* snapshot_handle) { const auto seq_indices = at::arange(start, start + length, at::TensorOptions().dtype(at::kLong)); - int64_t* count_ = new int64_t[1]; - count_[0] = length; - const auto count = 
at::from_blob(count_, {1}, at::kLong); + const auto count = at::tensor({length}, at::ScalarType::Long); folly::coro::blockingWait( get_kv_db_async_impl(seq_indices, weights, count, snapshot_handle)); } + void set_range( + const at::Tensor& weights, + const int64_t start, + const int64_t length) { + const auto seq_indices = + at::arange(start, start + length, at::TensorOptions().dtype(at::kLong)); + const auto count = at::tensor({length}, at::ScalarType::Long); + folly::coro::blockingWait(set_kv_db_async(seq_indices, weights, count)); + } + int64_t get_max_D() { return max_D_; } diff --git a/fbgemm_gpu/test/tbe/ssd/kv_tensor_wrapper_test.py b/fbgemm_gpu/test/tbe/ssd/kv_tensor_wrapper_test.py index 1530f1e75..f1277ccac 100644 --- a/fbgemm_gpu/test/tbe/ssd/kv_tensor_wrapper_test.py +++ b/fbgemm_gpu/test/tbe/ssd/kv_tensor_wrapper_test.py @@ -76,7 +76,7 @@ def test_read_tensor_using_wrapper_from_db(self) -> None: # create a view tensor wrapper snapshot = ssd_db.create_snapshot() tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper( - ssd_db, snapshot, [E, D], weights.dtype, 0 + ssd_db, [E, D], weights.dtype, 0, snapshot ) self.assertEqual(tensor_wrapper.shape, [E, D]) @@ -100,3 +100,72 @@ def test_read_tensor_using_wrapper_from_db(self) -> None: del tensor_wrapper del snapshot self.assertEqual(ssd_db.get_snapshot_count(), 0) + + def test_write_tensor_to_db(self) -> None: + E = int(1e4) # num total rows + D = 128 # emb dimension + N = 1000 # window size + weights_precision = SparseType.FP32 + weights_dtype = weights_precision.as_dtype() + + with tempfile.TemporaryDirectory() as ssd_directory: + # pyre-fixme[16]: Module `classes` has no attribute `fbgemm`. + ssd_db = torch.classes.fbgemm.EmbeddingRocksDBWrapper( + ssd_directory, + 8, # num_shards + 8, # num_threads + 0, # ssd_memtable_flush_period, + 0, # ssd_memtable_flush_offset, + 4, # ssd_l0_files_per_compact, + D, # embedding_dim + 0, # ssd_rate_limit_mbps, + 1, # ssd_size_ratio, + 8, # ssd_compaction_trigger, + 536870912, # 512MB ssd_write_buffer_size, + 8, # ssd_max_write_buffer_num, + -0.01, # ssd_uniform_init_lower + 0.01, # ssd_uniform_init_upper + 32, # row_storage_bitwidth + 10 * (2**20), # block cache size + ) + + weights = torch.arange(N * D, dtype=weights_dtype).view(N, D) + output_weights = torch.empty_like(weights) + + # no snapshot needed for writing to rocksdb + tensor_wrapper0 = torch.classes.fbgemm.KVTensorWrapper( + ssd_db, [E, D], weights.dtype, 0 + ) + step = N + for i in range(0, E, step): + tensor_wrapper0.set_range(0, i, step, weights) + + # force waiting for set to complete + indices = torch.arange(step) + for i in range(0, E, step): + ssd_db.get(i + indices, output_weights, torch.tensor(indices.shape[0])) + + # create a view tensor wrapper + snapshot = ssd_db.create_snapshot() + tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper( + ssd_db, [E, D], weights.dtype, 0, snapshot + ) + self.assertEqual(tensor_wrapper.shape, [E, D]) + + # table has a total of E rows + # load 1000 rows at a time + step = 1000 + for i in range(0, E, step): + narrowed = tensor_wrapper.narrow(0, i, step) + self.assertTrue( + torch.equal(narrowed, weights), + msg=( + f"Tensor value mismatch :\n" + f"actual\n{narrowed}\n\nexpected\n{weights}" + ), + ) + + del tensor_wrapper0 + del tensor_wrapper + del snapshot + self.assertEqual(ssd_db.get_snapshot_count(), 0) From 8e2d4a0d556f944c87d0d91c222511d74d154097 Mon Sep 17 00:00:00 2001 From: Karthik Manivannan Date: Wed, 2 Oct 2024 21:12:32 -0700 Subject: [PATCH 40/48] Add non-persistent fp8 
triton_rowwise kernel (#3212) Summary: X-link: https://github.com/pytorch/benchmark/pull/2484 Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3212 X-link: https://github.com/facebookresearch/FBGEMM/pull/308 triton_rowwise persistent kernel performs poorly on MI300 compared to the non-persistent kernel, when both are run with exhaustive AMD-specific tuning. Reviewed By: htyu Differential Revision: D63741099 fbshipit-source-id: c276415ddf8f5d24ffeba70b8ee6493011b393e1 --- .../experimental/gemm/triton_gemm/fp8_gemm.py | 244 +++++++++++++++++- 1 file changed, 243 insertions(+), 1 deletion(-) diff --git a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py index b82cde50a..d90f12e0a 100644 --- a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +++ b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py @@ -997,6 +997,7 @@ def matmul_fp8_row( fp8_fast_accum: bool = True, imprecise_acc: bool = False, tma_persistent: bool = True, + no_use_persistent: bool = False, ) -> torch.Tensor: """ Performs matmul on [M, K] and [N, K] fp8 matrices with row-wise scalings [M], [N]. @@ -1056,7 +1057,38 @@ def persistent_grid(META): ), ) - if tma_persistent: + if no_use_persistent: + logger.info("Using non-persistent kernel") + if bias is not None: + raise AssertionError("bias is not supported in non-persistent kernel") + # pyre-ignore + torch._library.capture_triton(_kernel_matmul_fp8_row_non_persistent)[grid]( + a, + b, + c, + M, + N, + K, + m_key, + n_key, + k_key, + a_scale, + b_scale, + # bias, + a.stride(0), + a.stride(1), + b.stride(0), + b.stride(1), + c.stride(0), + c.stride(1), + dot_out_dtype=dot_out_dtype_triton, + allow_tf32=allow_tf32, + fp8_fast_accum=fp8_fast_accum, + # GROUP_M=8, + # USE_BIAS=bias is not None, + AB_DTYPE=False, + ) + elif tma_persistent: # used by TMA persistent kernel desc_helper = TmaAutoTuneHelper() desc_helper.init_tma_descriptor("a") @@ -2422,3 +2454,213 @@ def quantize_fp8_block( x_scale = x_scale.to(output_device) # pyre-ignore del x, x_padded return x_fp8.view(x_shape), 1 / x_scale # pyre-ignore + + +def need_split_k(SIZE_M, SIZE_N, SIZE_K): + return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 1024 + + +def prune_configs(configs, named_args, **kwargs): + + SIZE_M = named_args["A"].shape[0] + SIZE_N = named_args["B"].shape[1] + SIZE_K = named_args["C"].shape[1] + + pruned_configs = [] + for config in configs: + kw = config.kwargs + BLOCK_SIZE_M, BLOCK_SIZE_N, _ = ( + kw["BLOCK_M"], + kw["BLOCK_N"], + kw["BLOCK_K"], + ) + SPLIT_K = kw["SPLIT_K"] + if SIZE_M <= 32 and BLOCK_SIZE_M != 32: + continue + if SIZE_N <= 32 and BLOCK_SIZE_N != 32: + continue + # skip large split_k when not necessary + if SPLIT_K != 1 and not need_split_k(SIZE_M, SIZE_N, SIZE_K): + continue + pruned_configs.append(config) + logging.info(f"pruned_configs: config len{len(pruned_configs)}") + return pruned_configs + + +def get_full_non_persistent_tuning_space(use_split_k): + if torch.version.hip is None: + logger.warning("Using HIP configs on CUDA device, this may be slow.") + configs = [] + block_mn_range = [32, 64, 128, 256] + block_k_range = [32, 64, 128] + split_k_range = [1] + num_warps_range = [1, 2, 4, 8, 16] + group_m_range = [1, 4, 8] + num_stage_range = [0] + + for block_m in block_mn_range: + for block_n in block_mn_range: + for block_k in block_k_range: + for num_warps in num_warps_range: + for group_m in group_m_range: + for split_k in split_k_range: + for num_stages in num_stage_range: + configs.append( + triton.Config( 
+ { + "BLOCK_M": block_m, + "BLOCK_N": block_n, + "BLOCK_K": block_k, + "GROUP_M": group_m, + "SPLIT_K": split_k, + }, + num_stages=num_stages, + num_warps=num_warps, + ) + ) + + return configs + + +MATMUL_CONFIGS: List[Config] = get_full_non_persistent_tuning_space(True) + + +@triton.autotune( + configs=MATMUL_CONFIGS, + key=["M", "N", "K"], + prune_configs_by={ + "early_config_prune": prune_configs, + "perf_model": None, + "top_k": None, + }, +) +@triton.heuristics( + { + "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0, + } +) +@triton.jit +def _kernel_matmul_fp8_row_non_persistent( + A, + B, + C, + M, + N, + K, + m_key, + n_key, + k_key, + A_scale, + B_scale, + stride_am, + stride_ak, + stride_bn, + stride_bk, + stride_cm, + stride_cn, + dot_out_dtype: tl.constexpr, + allow_tf32: tl.constexpr, + fp8_fast_accum: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, + SPLIT_K: tl.constexpr, + EVEN_K: tl.constexpr, + AB_DTYPE: tl.constexpr, +) -> None: + """Matmul kernel of [M, K] @ [N, K] with row-wise scales + + performs swizzled matmul in [BLOCK_M, BLOCK_K] with [BLOCK_K, BLOCK_N] tiles. + + Args: + A (TensorWrapper): [M, K] input tensor. + B (TensorWrapper): [N, K] input tensor. + C (TensorWrapper): [M, N] output tensor. + M (int): M dimension of input tensor. + N (int): N dimension of input tensor. + K (int): K dimension of input tensor. + m_key (int): Autotuning key for M dimension of input tensor. + n_key (int): Autotuning key for N dimension of input tensor. + k_key (int): Autotuning key for K dimension of input tensor. + A_scale (TensorWrapper): [M] reciprocal scale tensor per row. A * A_scale = original A + B_scale (TensorWrapper): [N] reciprocal scale tensor per row. B * B_scale = original B + stride_am (int): Stride of M dimension of A. + stride_ak (int): Stride of K dimension of A. + stride_bn (int): Stride of N dimension of B. + stride_bk (int): Stride of K dimension of B. + stride_cm (int): Stride of M dimension of C. + stride_cn (int): Stride of N dimension of C. + dot_out_dtype (torch.dtype): Output type of tensor core. + allow_tf32 (bool): Whether to use TF32 for tensor core. + fp8_fast_accum (bool): Whether to use fast accumulation for tensor core. + BLOCK_M (int): Block size for M dimension. + BLOCK_N (int): Block size for N dimension. + BLOCK_K (int): Block size for K dimension. + GROUP_M (int): Number of groups for M dimension swizzle. + SPLIT_K (int): Number of SM's to launch per row. + EVEN_K (bool): Whether K is evenly divisible by BLOCK_K * SPLIT_K. + AB_DTYPE (bool): Wether to cast A and B to C.dtype before tensor core. + """ + # Matrix multiplication. + pid = tl.program_id(0) + pid_z = tl.program_id(1) + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + # Re-order program ID for better L2 performance (swizzle). + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + # Do matrix multiplication. + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K) + # Pointers. 
+ A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype) + + for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): + if EVEN_K: + a = tl.load(A) + b = tl.load(B) + else: + k_remaining = K - k * (BLOCK_K * SPLIT_K) + _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty) + a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0) + b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0) + if AB_DTYPE: + a = a.to(C.dtype.element_ty) + b = b.to(C.dtype.element_ty) + if fp8_fast_accum: + acc = tl.dot(a, b, acc, out_dtype=dot_out_dtype, allow_tf32=allow_tf32) + else: + acc += tl.dot(a, b, out_dtype=dot_out_dtype, allow_tf32=allow_tf32) + + A += BLOCK_K * SPLIT_K * stride_ak + B += BLOCK_K * SPLIT_K * stride_bk + + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + # Invert scaling. + a_scale = tl.load(A_scale + rm, mask=rm < M) + b_scale = tl.load(B_scale + rn, mask=rn < N) + # Invert vector, then multiply on matrix for speed. + # pyre-ignore[16]: Undefined attribute [16]: `float` has no attribute `__getitem__`. + scale = a_scale[:, None] * b_scale[None, :] + acc *= scale + + acc = acc.to(C.dtype.element_ty) + C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) + mask = (rm < M)[:, None] & (rn < N)[None, :] + # Handles write-back with reduction-splitting + if SPLIT_K == 1: + tl.store(C, acc, mask=mask) + else: + tl.atomic_add(C, acc, mask=mask) From d9ae5c498f95b86a3272ce21ac647ce6027adc0b Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Thu, 3 Oct 2024 11:23:51 -0700 Subject: [PATCH 41/48] Add i-cache flush for AMD GPUs into FBGEMM (#3208) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3208 X-link: https://github.com/facebookresearch/FBGEMM/pull/307 - Add a function into FBGEMM to flush i-cache Reviewed By: zixi-qi Differential Revision: D63296513 fbshipit-source-id: aa215abd1b020623ebd55083b9cc1b6a34373e47 --- .../src/quantize/ck_extensions/ck_utility.hip | 44 +++++++++++++++++++ .../gen_ai/src/quantize/quantize.cpp | 10 +++++ 2 files changed, 54 insertions(+) create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/ck_utility.hip diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/ck_utility.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/ck_utility.hip new file mode 100644 index 000000000..25f532ad3 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/ck_utility.hip @@ -0,0 +1,44 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#if defined(USE_ROCM) + +#include "ck/ck.hpp" +#include "ck/stream_config.hpp" +#include "ck/host_utility/hip_check_error.hpp" +#include "ck/utility/flush_icache.hpp" + +namespace fbgemm_gpu { + +void flush_icache_ck() +{ + hipDeviceProp_t deviceProps; + hip_check_error(hipGetDeviceProperties(&deviceProps, 0)); + int32_t gpu_block3 = deviceProps.multiProcessorCount * 60; + + auto stream = at::cuda::getCurrentHIPStream().stream(); + + ck::flush_icache<<>>(); + hip_check_error(hipGetLastError()); +} + +} // namespace fbgemm_gpu + +#endif // defined(USE_ROCM) diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp index ff5c66766..101a5cba1 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp @@ -27,6 +27,11 @@ namespace fbgemm_gpu { +#ifdef USE_ROCM +// flush icache +void flush_icache_ck(); +#endif + // SmoothQuant kernels at::Tensor i8i8bf16(at::Tensor XQ, at::Tensor WQ, double scale, int64_t split_k); @@ -185,6 +190,11 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { m.impl( "quantize_fp8_per_tensor_fixed_scale", quantize_fp8_per_tensor_fixed_scale); + +#ifdef USE_ROCM + m.def("flush_icache_hip() -> ()"); + m.impl("flush_icache_hip", flush_icache_ck); +#endif } TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) { From 788cd2acf68d56a742f94ecaeacc50bd1d88916a Mon Sep 17 00:00:00 2001 From: Jun Luo Date: Thu, 3 Oct 2024 11:45:37 -0700 Subject: [PATCH 42/48] Refactor the GIS to reuse same autograd function for all backends (#3216) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3216 X-link: https://github.com/facebookresearch/FBGEMM/pull/313 This diff refactors the GIS impl to share the AutogradFunc for all device backends. Device backends only need to impl and register following two ops. For backward compatibility, we keep the op signature the same. 
``` torch::autograd::variable_list group_index_select_dim0_forward_impl_cpu( at::TensorList all_indices_input, const int64_t group_size) { throw std::runtime_error( "group_index_select_dim0_forward_impl is not implemented for CPU"); } torch::autograd::variable_list group_index_select_dim0_backward_impl_cpu( at::TensorList all_inputs, c10::SymIntArrayRef output_shape_group_ref) { throw std::runtime_error( "group_index_select_dim0_backward_impl is not implemented for CPU"); } ``` ``` DISPATCH_TO_CPU( "group_index_select_dim0_gpu_impl", fbgemm_gpu::group_index_select_dim0_forward_impl_cpu); DISPATCH_TO_CPU( "group_index_select_dim0_gpu_backward", fbgemm_gpu::group_index_select_dim0_forward_impl_cpu); ``` Reviewed By: spcyppt Differential Revision: D63809121 fbshipit-source-id: 4f6ca5bdf1810241730a77ab125d6aef31b0cd5b --- fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h | 41 + fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp | 176 ++++- fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp | 767 ++++++++----------- 3 files changed, 512 insertions(+), 472 deletions(-) diff --git a/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h b/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h index 9bea430ef..41ba190fc 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h +++ b/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h @@ -9,8 +9,11 @@ #pragma once #include +#include #include #include +#include + #include namespace fbgemm_gpu { @@ -924,6 +927,44 @@ at::Tensor index_add_with_unique_indices_cuda( const int consecutive_range_start, const int consecutive_range_length); +torch::autograd::variable_list group_index_select_dim0_decomposed( + at::TensorList input_group, + at::TensorList indices_group); + +torch::autograd::variable_list group_index_select_dim0_autograd_impl( + at::TensorList all_indices_input, + const int64_t group_size); + +torch::autograd::variable_list group_index_select_dim0( + at::TensorList input_group, + at::TensorList indices_group); + +torch::autograd::variable_list group_index_select_dim0_forward_impl_cpu( + at::TensorList all_indices_input, + const int64_t group_size); + +torch::autograd::variable_list group_index_select_dim0_backward_impl_cpu( + at::TensorList all_inputs, + c10::SymIntArrayRef output_shape_group_ref); + +std::pair, std::vector> +group_index_select_dim0_unpack( + at::TensorList all_indices_input, + const int64_t group_size); + +class GroupIndexSelectDim0Op + : public torch::autograd::Function { + public: + static torch::autograd::variable_list forward( + torch::autograd::AutogradContext* ctx, + at::TensorList all_indices_input, + const int64_t group_size); + + static torch::autograd::variable_list backward( + torch::autograd::AutogradContext* ctx, + torch::autograd::variable_list grad_output_group); +}; + ///@ingroup sparse-data-cuda void group_index_select_or_add_cuda( const int64_t* input_ptrs, diff --git a/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp b/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp index 7734cc69a..88d9ef2e6 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp +++ b/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp @@ -2851,19 +2851,84 @@ Tensor pack_segments_cpu( const int64_t max_length) { return pack_segments_forward_cpu(t_in, lengths, max_length); } -namespace { -Tensor index_select_dim0( - const Tensor& input, - const Tensor& indices, - std::optional /*consecutive_range_start*/, - std::optional /*consecutive_range_length*/, - std::optional /*skip_indices_sorting_fwd*/) { - return at::index_select(input, 0, indices); + +torch::autograd::variable_list 
group_index_select_dim0_autograd_impl( + at::TensorList all_indices_input, + const int64_t group_size) { + return GroupIndexSelectDim0Op::apply(all_indices_input, group_size); +} + +std::pair, std::vector> +group_index_select_dim0_unpack( + at::TensorList all_indices_input, + const int64_t group_size) { + std::vector indices_group; + std::vector input_group; + + indices_group.reserve(group_size); + input_group.reserve(group_size); + + for (const auto i : c10::irange(group_size)) { + indices_group.push_back(all_indices_input[i]); + input_group.push_back(all_indices_input[group_size + i]); + } + + TORCH_CHECK(group_size == static_cast(indices_group.size())); + + return std::make_pair(input_group, indices_group); } torch::autograd::variable_list group_index_select_dim0( at::TensorList input_group, at::TensorList indices_group) { + const auto group_size = indices_group.size(); + std::vector output_group; + + if (group_size == 0) { + return std::vector(); + } + + // Pack input_group and indices_group into TensorList + std::vector all_indices_input_vec; + all_indices_input_vec.reserve(group_size * 2); + + for (const Tensor& index : indices_group) { + all_indices_input_vec.push_back(index); + } + for (const Tensor& input : input_group) { + all_indices_input_vec.push_back(input); + } + + at::TensorList all_indices_input_tensor = all_indices_input_vec; + + static auto forward_op = + at::Dispatcher::singleton() + .findSchemaOrThrow("fbgemm::group_index_select_dim0_gpu_impl", "") + .typed(); + auto res = forward_op.call(all_indices_input_tensor, group_size); + TORCH_CHECK(res.size() == group_size + 2); + // only return the outputs (the first group_size elements) + res.resize(group_size); + return res; +} + +torch::autograd::variable_list group_index_select_dim0_forward_impl_cpu( + at::TensorList all_indices_input, + const int64_t group_size) { + throw std::runtime_error( + "group_index_select_dim0_forward_impl is not implemented for CPU"); +} + +torch::autograd::variable_list group_index_select_dim0_backward_impl_cpu( + at::TensorList all_inputs, + c10::SymIntArrayRef output_shape_group_ref) { + throw std::runtime_error( + "group_index_select_dim0_backward_impl is not implemented for CPU"); +} + +torch::autograd::variable_list group_index_select_dim0_decomposed( + at::TensorList input_group, + at::TensorList indices_group) { int num_groups = input_group.size(); TORCH_CHECK(num_groups == (int)indices_group.size()) std::vector output_group; @@ -2874,18 +2939,83 @@ torch::autograd::variable_list group_index_select_dim0( return output_group; } -torch::autograd::variable_list group_index_select_dim0_gpu_impl_cpu( +torch::autograd::variable_list GroupIndexSelectDim0Op::forward( + torch::autograd::AutogradContext* ctx, at::TensorList all_indices_input, const int64_t group_size) { - throw std::runtime_error( - "group_index_select_dim0_gpu_impl is not implemented for CPU"); + at::AutoDispatchBelowADInplaceOrView guard; + static auto forward_op = + at::Dispatcher::singleton() + .findSchemaOrThrow("fbgemm::group_index_select_dim0_gpu_impl", "") + .typed(); + auto result = forward_op.call(all_indices_input, group_size); + TORCH_CHECK(static_cast(result.size()) == group_size + 2); + + auto [input_group, indices_group] = + group_index_select_dim0_unpack(all_indices_input, group_size); + const auto input_dim = input_group[0].dim(); + std::vector input_shape_group; + input_shape_group.reserve(group_size * input_dim); + + for (const auto i : c10::irange(group_size)) { + const auto& input = input_group[i]; + // Copy 
input shape + auto input_shape = input.sym_sizes().vec(); + input_shape_group.insert( + input_shape_group.end(), input_shape.begin(), input_shape.end()); + } + + // save indices, args_tensor, saved_data + auto saved_tensors = std::vector(indices_group); + saved_tensors.insert( + saved_tensors.end(), result.cbegin() + group_size, result.cend()); + saved_tensors.push_back(input_group[0]); + ctx->save_for_backward(saved_tensors); + ctx->saved_data["input_shape_group"] = input_shape_group; + + return result; } -torch::autograd::variable_list group_index_select_dim0_gpu_backward_cpu( - at::TensorList all_inputs, - c10::SymIntArrayRef output_shape_group_ref) { - throw std::runtime_error( - "group_index_select_dim0_gpu_backward is not implemented for CPU"); +torch::autograd::variable_list GroupIndexSelectDim0Op::backward( + torch::autograd::AutogradContext* ctx, + torch::autograd::variable_list grad_output_group) { + TORCH_CHECK(grad_output_group.size() >= 2); + if (grad_output_group.size() == 2) { + // empty outputs + return torch::autograd::variable_list(1); + } + // remove redundant grads + auto group_size = grad_output_group.size() - 2; + grad_output_group.resize(group_size); + + auto saved_tensors = ctx->get_saved_variables(); + TORCH_CHECK(saved_tensors.size() == group_size + 3); + auto output_shape_group = + ctx->saved_data["input_shape_group"].toSymIntVector(); + grad_output_group.insert( + grad_output_group.end(), saved_tensors.begin(), saved_tensors.end()); + static auto backward_op = + at::Dispatcher::singleton() + .findSchemaOrThrow("fbgemm::group_index_select_dim0_gpu_backward", "") + .typed(); + auto res = backward_op.call(grad_output_group, output_shape_group); + // 1) Add group_size Variable()'s for indices + // Replace all empty tensors with Variable(). This must be done after the + // op.call to make __torch_dispatch__ work for the backward op. 
+ std::fill(res.begin(), res.begin() + group_size, torch::autograd::Variable()); + // 3) Add 1 Variable() for group_size + res.push_back({}); + return res; +} + +namespace { +Tensor index_select_dim0( + const Tensor& input, + const Tensor& indices, + std::optional /*consecutive_range_start*/, + std::optional /*consecutive_range_length*/, + std::optional /*skip_indices_sorting_fwd*/) { + return at::index_select(input, 0, indices); } Tensor bottom_k_per_row( @@ -3132,13 +3262,14 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) { "pack_segments_backward", fbgemm_gpu::pack_segments_backward_cpu); DISPATCH_TO_CPU("index_select_dim0", fbgemm_gpu::index_select_dim0); DISPATCH_TO_CPU( - "group_index_select_dim0", fbgemm_gpu::group_index_select_dim0); + "group_index_select_dim0", + fbgemm_gpu::group_index_select_dim0_decomposed); DISPATCH_TO_CPU( "group_index_select_dim0_gpu_impl", - fbgemm_gpu::group_index_select_dim0_gpu_impl_cpu); + fbgemm_gpu::group_index_select_dim0_forward_impl_cpu); DISPATCH_TO_CPU( "group_index_select_dim0_gpu_backward", - fbgemm_gpu::group_index_select_dim0_gpu_backward_cpu); + fbgemm_gpu::group_index_select_dim0_backward_impl_cpu); DISPATCH_TO_CPU("bottom_k_per_row", fbgemm_gpu::bottom_k_per_row); } @@ -3147,11 +3278,14 @@ TORCH_LIBRARY_IMPL(fbgemm, Autograd, m) { } TORCH_LIBRARY_IMPL(fbgemm, AutogradCPU, m) { - m.impl("group_index_select_dim0", &fbgemm_gpu::group_index_select_dim0); + m.impl( + "group_index_select_dim0", + &fbgemm_gpu::group_index_select_dim0_decomposed); } TORCH_LIBRARY_IMPL(fbgemm, Meta, m) { // CPU group_index_select_dim0 is decomposable m.impl( - "group_index_select_dim0", TORCH_FN(fbgemm_gpu::group_index_select_dim0)); + "group_index_select_dim0", + TORCH_FN(fbgemm_gpu::group_index_select_dim0_decomposed)); } diff --git a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp index 6325017e8..0c3966fc3 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp +++ b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp @@ -193,442 +193,346 @@ class IndexSelectDim0GPUOp } }; -std::pair, std::vector> -group_index_select_dim0_unpack( +// need to combine input_group and indices_group into one tensor list +// to get this working with autograd. 
+static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu( at::TensorList all_indices_input, const int64_t group_size) { - std::vector indices_group; - std::vector input_group; + // Unpack from TensorList + auto [input_group, indices_group] = + group_index_select_dim0_unpack(all_indices_input, group_size); + + // args_tensor stores kernel arguments: + // input_ptrs (group_size int64_t elements) + // output_ptrs (group_size int64_t elements) + // indices_ptrs (group_size int64_t elements) + // warp_offsets_group (group_size + 1 int64_t elements) + // num_cols_group (group_size int32_t elements) + int64_t args_ptrs_offsets[NUM_ARGS + 1]; + + const int64_t numels_num_cols_group_64 = + compute_num_int64s(group_size); + + // Initialize offsets + args_ptrs_offsets[P_input_ptrs] = group_size; + args_ptrs_offsets[P_output_ptrs] = group_size; + args_ptrs_offsets[P_indices_ptrs] = group_size; + args_ptrs_offsets[P_warp_offsets_group_ptrs] = group_size + 1; + args_ptrs_offsets[P_num_cols_group_ptrs] = numels_num_cols_group_64; + + // Compute offsets + int64_t offset = 0; + auto next = args_ptrs_offsets[0]; + for (const auto i : c10::irange(NUM_ARGS)) { + args_ptrs_offsets[i] = offset; + offset += next; + next = args_ptrs_offsets[i + 1]; + } + // Total number of int64_t elements required + args_ptrs_offsets[NUM_ARGS] = offset; + + // Allocate memory for GroupIndexSelectArgs + at::Tensor args_tensor = at::empty( + {static_cast(args_ptrs_offsets[NUM_ARGS] * sizeof(int64_t))}, + at::TensorOptions().dtype(at::kByte).pinned_memory(true)); + + // Ensure that args_tensor is contiguous + TORCH_CHECK(args_tensor.is_contiguous()); + + // Initialize raw pointers to point to Tensor args_tensor + int64_t* input_ptrs = nullptr; + int64_t* output_ptrs = nullptr; + int64_t* indices_ptrs = nullptr; + int64_t* warp_offsets_group = nullptr; + int32_t* num_cols_group = nullptr; + + // Offset host pointers + offset_args( + &input_ptrs, + &output_ptrs, + &indices_ptrs, + &warp_offsets_group, + &num_cols_group, + reinterpret_cast(args_tensor.data_ptr()), + args_ptrs_offsets); + + auto& first_input = input_group[0]; + auto& first_indices = indices_group[0]; + + const int input_dim = first_input.dim(); + const int num_output_rows = first_indices.size(0); + const int num_input_rows = first_input.size(0); + Tensor input_reshaped = first_input.reshape({num_input_rows, -1}); + const int num_cols = input_reshaped.size(1); + const int cols_per_warp = get_group_index_select_cols_per_warp(); + int64_t warp_offset = 0; + bool use_var_cols = false; + + // Allocate memory for output_group + std::vector output_group; + output_group.reserve(group_size + 2); - indices_group.reserve(group_size); - input_group.reserve(group_size); + // We need to store contiguous inputs and indices outside the for-loop to + // guarantee that the contiguous tensors will outlive the kernel + // computation + std::vector> input_contigs; + std::vector> index_contigs; + input_contigs.reserve(group_size); + index_contigs.reserve(group_size); + // For each group, copy input to output for (const auto i : c10::irange(group_size)) { - indices_group.push_back(all_indices_input[i]); - input_group.push_back(all_indices_input[group_size + i]); - } + const auto& input = input_group[i]; + const auto& indices = indices_group[i]; - TORCH_CHECK(group_size == static_cast(indices_group.size())); + // Verify that all input tensors have the same number of dimensions + TORCH_CHECK( + input_dim == input.dim(), + "All inputs in group_index_select must have the 
same number of dimensions"); - return std::make_pair(input_group, indices_group); -} + // Verify that all tensors are on the same GPU + TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(input, indices); -class GroupIndexSelectDim0GPUOp - : public torch::autograd::Function { - public: - // need to combine input_group and indices_group into one tensor list - // to get this working with autograd. - static torch::autograd::variable_list forward_impl( - at::TensorList all_indices_input, - const int64_t group_size) { - // Unpack from TensorList - auto [input_group, indices_group] = - group_index_select_dim0_unpack(all_indices_input, group_size); - - // args_tensor stores kernel arguments: - // input_ptrs (group_size int64_t elements) - // output_ptrs (group_size int64_t elements) - // indices_ptrs (group_size int64_t elements) - // warp_offsets_group (group_size + 1 int64_t elements) - // num_cols_group (group_size int32_t elements) - int64_t args_ptrs_offsets[NUM_ARGS + 1]; - - const int64_t numels_num_cols_group_64 = - compute_num_int64s(group_size); - - // Initialize offsets - args_ptrs_offsets[P_input_ptrs] = group_size; - args_ptrs_offsets[P_output_ptrs] = group_size; - args_ptrs_offsets[P_indices_ptrs] = group_size; - args_ptrs_offsets[P_warp_offsets_group_ptrs] = group_size + 1; - args_ptrs_offsets[P_num_cols_group_ptrs] = numels_num_cols_group_64; - - // Compute offsets - int64_t offset = 0; - auto next = args_ptrs_offsets[0]; - for (const auto i : c10::irange(NUM_ARGS)) { - args_ptrs_offsets[i] = offset; - offset += next; - next = args_ptrs_offsets[i + 1]; + auto num_output_rows_ = indices.size(0); + + // Verify that all input tensors have the same shape[0] + TORCH_CHECK( + num_output_rows == num_output_rows_, + "The number of indices to be selected must be the same for the entire group"); + const auto input_reshaped_ = input.reshape({input.size(0), -1}); + + // Number of columns can be different + auto num_cols_ = input_reshaped_.size(1); + auto warps_per_row = (num_cols_ + cols_per_warp - 1) / cols_per_warp; + + if (num_cols != num_cols_) { + use_var_cols = true; } - // Total number of int64_t elements required - args_ptrs_offsets[NUM_ARGS] = offset; - - // Allocate memory for GroupIndexSelectArgs - at::Tensor args_tensor = at::empty( - {static_cast(args_ptrs_offsets[NUM_ARGS] * sizeof(int64_t))}, - at::TensorOptions().dtype(at::kByte).pinned_memory(true)); - - // Ensure that args_tensor is contiguous - TORCH_CHECK(args_tensor.is_contiguous()); - - // Initialize raw pointers to point to Tensor args_tensor - int64_t* input_ptrs = nullptr; - int64_t* output_ptrs = nullptr; - int64_t* indices_ptrs = nullptr; - int64_t* warp_offsets_group = nullptr; - int32_t* num_cols_group = nullptr; - - // Offset host pointers - offset_args( - &input_ptrs, - &output_ptrs, - &indices_ptrs, - &warp_offsets_group, - &num_cols_group, - reinterpret_cast(args_tensor.data_ptr()), - args_ptrs_offsets); - - auto& first_input = input_group[0]; - auto& first_indices = indices_group[0]; - - const int input_dim = first_input.dim(); - const int num_output_rows = first_indices.size(0); - const int num_input_rows = first_input.size(0); - Tensor input_reshaped = first_input.reshape({num_input_rows, -1}); - const int num_cols = input_reshaped.size(1); - const int cols_per_warp = get_group_index_select_cols_per_warp(); - int64_t warp_offset = 0; - bool use_var_cols = false; - - // Allocate memory for output_group - std::vector output_group; - output_group.reserve(group_size + 2); - - // We need to store contiguous inputs and 
indices outside the for-loop to - // guarantee that the contiguous tensors will outlive the kernel + + // Create output pointers + auto input_shape = input.sizes().vec(); + input_shape[0] = num_output_rows_; + Tensor output = at::empty(input_shape, input.options()); + // Ensure that the allocated output is contiguous + TORCH_CHECK(output.is_contiguous()) + output_group.push_back(output); + + // Store input and indices contigs to keep them alive during the kernel // computation - std::vector> input_contigs; - std::vector> index_contigs; - input_contigs.reserve(group_size); - index_contigs.reserve(group_size); - - // For each group, copy input to output - for (const auto i : c10::irange(group_size)) { - const auto& input = input_group[i]; - const auto& indices = indices_group[i]; - - // Verify that all input tensors have the same number of dimensions - TORCH_CHECK( - input_dim == input.dim(), - "All inputs in group_index_select must have the same number of dimensions"); - - // Verify that all tensors are on the same GPU - TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(input, indices); - - auto num_output_rows_ = indices.size(0); - - // Verify that all input tensors have the same shape[0] - TORCH_CHECK( - num_output_rows == num_output_rows_, - "The number of indices to be selected must be the same for the entire group"); - const auto input_reshaped_ = input.reshape({input.size(0), -1}); - - // Number of columns can be different - auto num_cols_ = input_reshaped_.size(1); - auto warps_per_row = (num_cols_ + cols_per_warp - 1) / cols_per_warp; - - if (num_cols != num_cols_) { - use_var_cols = true; - } - - // Create output pointers - auto input_shape = input.sizes().vec(); - input_shape[0] = num_output_rows_; - Tensor output = at::empty(input_shape, input.options()); - // Ensure that the allocated output is contiguous - TORCH_CHECK(output.is_contiguous()) - output_group.push_back(output); - - // Store input and indices contigs to keep them alive during the kernel - // computation - input_contigs.push_back(input.expect_contiguous()); - index_contigs.push_back(indices.expect_contiguous()); - - // Store args - input_ptrs[i] = reinterpret_cast(input_contigs[i]->data_ptr()); - output_ptrs[i] = reinterpret_cast(output.data_ptr()); - indices_ptrs[i] = reinterpret_cast(index_contigs[i]->data_ptr()); - warp_offsets_group[i] = warp_offset; - num_cols_group[i] = num_cols_; - - warp_offset += warps_per_row * num_output_rows; - } + input_contigs.push_back(input.expect_contiguous()); + index_contigs.push_back(indices.expect_contiguous()); - // Store the last offset - warp_offsets_group[group_size] = warp_offset; - - // Transfer args tensor to GPU - args_tensor = args_tensor.to( - first_input.device(), - /*non_blocking=*/true); - - // Offset raw ptrs in GPU memory - offset_args( - &input_ptrs, - &output_ptrs, - &indices_ptrs, - &warp_offsets_group, - &num_cols_group, - reinterpret_cast(args_tensor.data_ptr()), - args_ptrs_offsets); - - int64_t saved_data[] = { - static_cast(group_size), - use_var_cols, - reinterpret_cast(warp_offsets_group), - reinterpret_cast(num_cols_group), - warp_offset, - }; - auto saved_data_t = at::empty( - {sizeof(saved_data) / sizeof(int64_t)}, - at::TensorOptions().dtype(at::kLong)); - TORCH_CHECK(saved_data_t.is_contiguous()); - memcpy(saved_data_t.data_ptr(), saved_data, sizeof(saved_data)); - - group_index_select_or_add_cuda( - input_ptrs, - output_ptrs, - indices_ptrs, - warp_offsets_group, - num_cols_group, - first_input.scalar_type(), - first_indices.scalar_type(), - 
first_input.device().index(), - num_output_rows, - /*total_num_warps=*/warp_offset, - group_size, - /*use_index_select=*/true, - use_var_cols); - - output_group.push_back(args_tensor); - output_group.push_back(saved_data_t); - - // return format: - // (group_size outputs, 1 args_tensor, 1 saved_data) - return output_group; + // Store args + input_ptrs[i] = reinterpret_cast(input_contigs[i]->data_ptr()); + output_ptrs[i] = reinterpret_cast(output.data_ptr()); + indices_ptrs[i] = reinterpret_cast(index_contigs[i]->data_ptr()); + warp_offsets_group[i] = warp_offset; + num_cols_group[i] = num_cols_; + + warp_offset += warps_per_row * num_output_rows; } - static torch::autograd::variable_list forward( - torch::autograd::AutogradContext* ctx, - at::TensorList all_indices_input, - const int64_t group_size) { - at::AutoDispatchBelowADInplaceOrView guard; - static auto forward_op = - torch::Dispatcher::singleton() - .findSchemaOrThrow("fbgemm::group_index_select_dim0_gpu_impl", "") - .typed(); - auto result = forward_op.call(all_indices_input, group_size); - TORCH_CHECK(static_cast(result.size()) == group_size + 2); - - auto [input_group, indices_group] = - group_index_select_dim0_unpack(all_indices_input, group_size); - const auto input_dim = input_group[0].dim(); - std::vector input_shape_group; - input_shape_group.reserve(group_size * input_dim); - - for (const auto i : c10::irange(group_size)) { - const auto& input = input_group[i]; - // Copy input shape - auto input_shape = input.sym_sizes().vec(); - input_shape_group.insert( - input_shape_group.end(), input_shape.begin(), input_shape.end()); - } + // Store the last offset + warp_offsets_group[group_size] = warp_offset; + + // Transfer args tensor to GPU + args_tensor = args_tensor.to( + first_input.device(), + /*non_blocking=*/true); + + // Offset raw ptrs in GPU memory + offset_args( + &input_ptrs, + &output_ptrs, + &indices_ptrs, + &warp_offsets_group, + &num_cols_group, + reinterpret_cast(args_tensor.data_ptr()), + args_ptrs_offsets); + + int64_t saved_data[] = { + static_cast(group_size), + use_var_cols, + reinterpret_cast(warp_offsets_group), + reinterpret_cast(num_cols_group), + warp_offset, + }; + auto saved_data_t = at::empty( + {sizeof(saved_data) / sizeof(int64_t)}, + at::TensorOptions().dtype(at::kLong)); + TORCH_CHECK(saved_data_t.is_contiguous()); + memcpy(saved_data_t.data_ptr(), saved_data, sizeof(saved_data)); + + group_index_select_or_add_cuda( + input_ptrs, + output_ptrs, + indices_ptrs, + warp_offsets_group, + num_cols_group, + first_input.scalar_type(), + first_indices.scalar_type(), + first_input.device().index(), + num_output_rows, + /*total_num_warps=*/warp_offset, + group_size, + /*use_index_select=*/true, + use_var_cols); + + output_group.push_back(args_tensor); + output_group.push_back(saved_data_t); + + // return format: + // (group_size outputs, 1 args_tensor, 1 saved_data) + return output_group; +} - // save indices, args_tensor, saved_data - auto saved_tensors = std::vector(indices_group); - saved_tensors.insert( - saved_tensors.end(), result.cbegin() + group_size, result.cend()); - saved_tensors.push_back(input_group[0]); - ctx->save_for_backward(saved_tensors); - ctx->saved_data["input_shape_group"] = input_shape_group; +static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu( + at::TensorList all_inputs, + c10::SymIntArrayRef output_shape_group_ref) { + TORCH_CHECK(all_inputs.size() > 2); + + const int64_t group_size = (all_inputs.size() - 3) / 2; + + Tensor fwd_input = all_inputs[2 * 
group_size + 2]; + const int64_t output_dim = fwd_input.dim(); + Tensor saved_data = all_inputs[2 * group_size + 1]; + Tensor args_tensor_old = all_inputs[2 * group_size]; + Tensor first_indices = all_inputs[group_size]; + + auto grad_output_group = std::vector( + all_inputs.cbegin(), all_inputs.cbegin() + group_size); + std::vector output_shape_group; + output_shape_group.reserve(output_shape_group_ref.size()); + for (const auto& i : output_shape_group_ref) { + output_shape_group.push_back(i.as_int_unchecked()); + } - return result; + auto indices_group = std::vector( + all_inputs.cbegin() + group_size, all_inputs.cbegin() + 2 * group_size); + + // Retrieve saved data + TORCH_CHECK(saved_data.device() == at::kCPU); + TORCH_CHECK(saved_data.is_contiguous()); + int64_t* saved_data_ptr = saved_data.data_ptr(); + // Check that the size is the same + TORCH_CHECK(saved_data_ptr[0] == group_size); + const bool use_var_cols = saved_data_ptr[1]; + int64_t* warp_offsets_group = reinterpret_cast(saved_data_ptr[2]); + int32_t* num_cols_group = reinterpret_cast(saved_data_ptr[3]); + int64_t total_num_warps = saved_data_ptr[4]; + + // We checked in forward that all output rows are the same for all member + // in the group + const int num_input_rows = grad_output_group[0].size(0); + + std::vector outputs; + // Returning 3 outputs: + // 1) group_size Variable()'s for indices + // 2) group_size gradients for inputs + // 3) 1 Variable() for group_size + outputs.reserve(group_size * 2 + 1); + + // 1) Add group_size Variable()'s for indices + // c10::irange cannot be used in here as it + // triggers a build error of i being an unused variable. + // Add empty tensor with zero size here to make __torch_dispatch__ work for + // the backward op. Those empty tensors will be replaced with + // torch::autograd::Variable() outside of the op call. 
+ for (auto i = 0; i < group_size; i++) { + outputs.push_back(at::empty({0}, at::TensorOptions().dtype(at::kLong))); } - static torch::autograd::variable_list backward_impl( - at::TensorList all_inputs, - c10::SymIntArrayRef output_shape_group_ref) { - TORCH_CHECK(all_inputs.size() > 2); - - const int64_t group_size = (all_inputs.size() - 3) / 2; - - Tensor fwd_input = all_inputs[2 * group_size + 2]; - const int64_t output_dim = fwd_input.dim(); - Tensor saved_data = all_inputs[2 * group_size + 1]; - Tensor args_tensor_old = all_inputs[2 * group_size]; - Tensor first_indices = all_inputs[group_size]; - - auto grad_output_group = std::vector( - all_inputs.cbegin(), all_inputs.cbegin() + group_size); - std::vector output_shape_group; - output_shape_group.reserve(output_shape_group_ref.size()); - for (const auto& i : output_shape_group_ref) { - output_shape_group.push_back(i.as_int_unchecked()); - } + // Allocate Tensor for ptrs of grad output and input, and indices + Tensor args_tensor = at::empty( + {group_size * 3}, + at::TensorOptions().dtype(at::kLong).pinned_memory(true)); + // Ensure that args_tensor is contiguous + TORCH_CHECK(args_tensor.is_contiguous()); + int64_t* grad_output_ptrs = args_tensor.data_ptr(); + int64_t* grad_input_ptrs = args_tensor.data_ptr() + group_size; + int64_t* indices_ptrs = args_tensor.data_ptr() + 2 * group_size; + + int64_t group_grad_input_numel = 0; + std::vector grad_input_numels; + grad_input_numels.reserve(group_size); + + // We need to store contiguous gradients outside the for-loop to guarantee + // that the contiguous tensors will outlive the kernel computation + std::vector> grad_output_contigs; + grad_output_contigs.reserve(group_size); - auto indices_group = std::vector( - all_inputs.cbegin() + group_size, all_inputs.cbegin() + 2 * group_size); - - // Retrieve saved data - TORCH_CHECK(saved_data.device() == at::kCPU); - TORCH_CHECK(saved_data.is_contiguous()); - int64_t* saved_data_ptr = saved_data.data_ptr(); - // Check that the size is the same - TORCH_CHECK(saved_data_ptr[0] == group_size); - const bool use_var_cols = saved_data_ptr[1]; - int64_t* warp_offsets_group = reinterpret_cast(saved_data_ptr[2]); - int32_t* num_cols_group = reinterpret_cast(saved_data_ptr[3]); - int64_t total_num_warps = saved_data_ptr[4]; - - // We checked in forward that all output rows are the same for all member - // in the group - const int num_input_rows = grad_output_group[0].size(0); - - std::vector outputs; - // Returning 3 outputs: - // 1) group_size Variable()'s for indices - // 2) group_size gradients for inputs - // 3) 1 Variable() for group_size - outputs.reserve(group_size * 2 + 1); - - // 1) Add group_size Variable()'s for indices - // c10::irange cannot be used in here as it - // triggers a build error of i being an unused variable. - // Add empty tensor with zero size here to make __torch_dispatch__ work for - // the backward op. Those empty tensors will be replaced with - // torch::autograd::Variable() outside of the op call. 
- for (auto i = 0; i < group_size; i++) { - outputs.push_back(at::empty({0}, at::TensorOptions().dtype(at::kLong))); - } + for (const auto i : c10::irange(group_size)) { + const auto& grad = grad_output_group[i]; + TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(grad, first_indices); - // Allocate Tensor for ptrs of grad output and input, and indices - Tensor args_tensor = at::empty( - {group_size * 3}, - at::TensorOptions().dtype(at::kLong).pinned_memory(true)); - // Ensure that args_tensor is contiguous - TORCH_CHECK(args_tensor.is_contiguous()); - int64_t* grad_output_ptrs = args_tensor.data_ptr(); - int64_t* grad_input_ptrs = args_tensor.data_ptr() + group_size; - int64_t* indices_ptrs = args_tensor.data_ptr() + 2 * group_size; - - int64_t group_grad_input_numel = 0; - std::vector grad_input_numels; - grad_input_numels.reserve(group_size); - - // We need to store contiguous gradients outside the for-loop to guarantee - // that the contiguous tensors will outlive the kernel computation - std::vector> grad_output_contigs; - grad_output_contigs.reserve(group_size); - - for (const auto i : c10::irange(group_size)) { - const auto& grad = grad_output_group[i]; - TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(grad, first_indices); - - // Store grad contigs to keep them alive during the kernel computation - grad_output_contigs.push_back(grad.expect_contiguous()); - - // Compute the total number of elements for all grad_inputs - int64_t grad_input_numel = output_shape_group[i * output_dim]; - for (auto j = (i * output_dim) + 1; j < (i + 1) * output_dim; j++) { - grad_input_numel *= output_shape_group[j]; - } - grad_input_numels.push_back(grad_input_numel); - group_grad_input_numel += grad_input_numel; - - // Put all grad output/input pointers in an array - grad_output_ptrs[i] = - reinterpret_cast(grad_output_contigs[i]->data_ptr()); - } + // Store grad contigs to keep them alive during the kernel computation + grad_output_contigs.push_back(grad.expect_contiguous()); - // Allocate a big tensor to avoid calling many small elementwise kernels - const auto group_grad_input = - at::zeros({group_grad_input_numel}, fwd_input.options()); - TORCH_CHECK(group_grad_input.is_contiguous()); + // Compute the total number of elements for all grad_inputs + int64_t grad_input_numel = output_shape_group[i * output_dim]; + for (auto j = (i * output_dim) + 1; j < (i + 1) * output_dim; j++) { + grad_input_numel *= output_shape_group[j]; + } + grad_input_numels.push_back(grad_input_numel); + group_grad_input_numel += grad_input_numel; - // Split to output_group - auto output_group = group_grad_input.split(grad_input_numels, 0); + // Put all grad output/input pointers in an array + grad_output_ptrs[i] = + reinterpret_cast(grad_output_contigs[i]->data_ptr()); + } - TORCH_CHECK(output_group.size() == static_cast(group_size)); + // Allocate a big tensor to avoid calling many small elementwise kernels + const auto group_grad_input = + at::zeros({group_grad_input_numel}, fwd_input.options()); + TORCH_CHECK(group_grad_input.is_contiguous()); - // Reshape grad inputs and obtain their pointers - for (int i = 0; i < group_size; i++) { - const auto grad_input_shape = std::vector( - output_shape_group.begin() + i * output_dim, - output_shape_group.begin() + (i + 1) * output_dim); - output_group[i] = output_group[i].reshape(grad_input_shape); - TORCH_CHECK(output_group[i].is_contiguous()); - grad_input_ptrs[i] = - reinterpret_cast(output_group[i].data_ptr()); + // Split to output_group + auto output_group = 
group_grad_input.split(grad_input_numels, 0); - // 2) Add group_size gradients for inputs - outputs.push_back(output_group[i]); - } + TORCH_CHECK(output_group.size() == static_cast(group_size)); - // Calculate indices_ptrs - std::vector> index_contigs; - index_contigs.reserve(group_size); - for (const auto i : c10::irange(group_size)) { - const auto& indices = indices_group[i]; - index_contigs.push_back(indices.expect_contiguous()); - indices_ptrs[i] = reinterpret_cast(index_contigs[i]->data_ptr()); - } + // Reshape grad inputs and obtain their pointers + for (int i = 0; i < group_size; i++) { + const auto grad_input_shape = std::vector( + output_shape_group.begin() + i * output_dim, + output_shape_group.begin() + (i + 1) * output_dim); + output_group[i] = output_group[i].reshape(grad_input_shape); + TORCH_CHECK(output_group[i].is_contiguous()); + grad_input_ptrs[i] = reinterpret_cast(output_group[i].data_ptr()); - // Transfer grad output pointers to GPU - args_tensor = args_tensor.to(first_indices.device(), /*non_blocking=*/true); - - group_index_select_or_add_cuda( - args_tensor.data_ptr(), - args_tensor.data_ptr() + group_size, - args_tensor.data_ptr() + 2 * group_size, - warp_offsets_group, - num_cols_group, - fwd_input.scalar_type(), - first_indices.scalar_type(), - fwd_input.device().index(), - num_input_rows, - total_num_warps, - group_size, - /*use_index_select=*/false, - use_var_cols); - - return outputs; + // 2) Add group_size gradients for inputs + outputs.push_back(output_group[i]); } - static torch::autograd::variable_list backward( - torch::autograd::AutogradContext* ctx, - torch::autograd::variable_list grad_output_group) { - TORCH_CHECK(grad_output_group.size() >= 2); - if (grad_output_group.size() == 2) { - // empty outputs - return torch::autograd::variable_list(1); - } - // remove redundant grads - auto group_size = grad_output_group.size() - 2; - grad_output_group.resize(group_size); - - auto saved_tensors = ctx->get_saved_variables(); - TORCH_CHECK(saved_tensors.size() == group_size + 3); - auto output_shape_group = - ctx->saved_data["input_shape_group"].toSymIntVector(); - grad_output_group.insert( - grad_output_group.end(), saved_tensors.begin(), saved_tensors.end()); - static auto backward_op = - torch::Dispatcher::singleton() - .findSchemaOrThrow( - "fbgemm::group_index_select_dim0_gpu_backward", "") - .typed(); - auto res = backward_op.call(grad_output_group, output_shape_group); - // 1) Add group_size Variable()'s for indices - // Replace all empty tensors with Variable(). This must be done after the - // op.call to make __torch_dispatch__ work for the backward op. 
- std::fill( - res.begin(), res.begin() + group_size, torch::autograd::Variable()); - // 3) Add 1 Variable() for group_size - res.push_back({}); - return res; + // Calculate indices_ptrs + std::vector> index_contigs; + index_contigs.reserve(group_size); + for (const auto i : c10::irange(group_size)) { + const auto& indices = indices_group[i]; + index_contigs.push_back(indices.expect_contiguous()); + indices_ptrs[i] = reinterpret_cast(index_contigs[i]->data_ptr()); } -}; + + // Transfer grad output pointers to GPU + args_tensor = args_tensor.to(first_indices.device(), /*non_blocking=*/true); + + group_index_select_or_add_cuda( + args_tensor.data_ptr(), + args_tensor.data_ptr() + group_size, + args_tensor.data_ptr() + 2 * group_size, + warp_offsets_group, + num_cols_group, + fwd_input.scalar_type(), + first_indices.scalar_type(), + fwd_input.device().index(), + num_input_rows, + total_num_warps, + group_size, + /*use_index_select=*/false, + use_var_cols); + + return outputs; +} Tensor pack_segments_cuda( const Tensor& t_in, @@ -654,45 +558,6 @@ Tensor index_select_dim0_gpu( user_skip_indices_sorting_fwd && !c10::InferenceMode::is_enabled())[0]; } -torch::autograd::variable_list group_index_select_dim0_gpu_impl( - at::TensorList all_indices_input, - const int64_t group_size) { - return GroupIndexSelectDim0GPUOp::apply(all_indices_input, group_size); -} - -torch::autograd::variable_list group_index_select_dim0_gpu( - at::TensorList input_group, - at::TensorList indices_group) { - const auto group_size = indices_group.size(); - std::vector output_group; - - if (group_size == 0) { - return std::vector(); - } - - // Pack input_group and indices_group into TensorList - std::vector all_indices_input_vec; - all_indices_input_vec.reserve(group_size * 2); - - for (const Tensor& index : indices_group) { - all_indices_input_vec.push_back(index); - } - for (const Tensor& input : input_group) { - all_indices_input_vec.push_back(input); - } - - at::TensorList all_indices_input_tensor = all_indices_input_vec; - - static auto forward_op = - torch::Dispatcher::singleton() - .findSchemaOrThrow("fbgemm::group_index_select_dim0_gpu_impl", "") - .typed(); - auto res = forward_op.call(all_indices_input_tensor, group_size); - TORCH_CHECK(res.size() == group_size + 2); - // only return the outputs (the first group_size elements) - res.resize(group_size); - return res; -} } // namespace fbgemm_gpu TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) { @@ -721,17 +586,17 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) { DISPATCH_TO_CUDA("index_select_dim0", fbgemm_gpu::index_select_dim0_gpu); DISPATCH_TO_CUDA( "group_index_select_dim0_gpu_impl", - fbgemm_gpu::GroupIndexSelectDim0GPUOp::forward_impl); + fbgemm_gpu::group_index_select_dim0_forward_impl_gpu); DISPATCH_TO_CUDA( "group_index_select_dim0_gpu_backward", - fbgemm_gpu::GroupIndexSelectDim0GPUOp::backward_impl); + fbgemm_gpu::group_index_select_dim0_backward_impl_gpu); DISPATCH_TO_CUDA( - "group_index_select_dim0", fbgemm_gpu::group_index_select_dim0_gpu); + "group_index_select_dim0", fbgemm_gpu::group_index_select_dim0); } TORCH_LIBRARY_IMPL(fbgemm, AutogradCUDA, m) { - m.impl("group_index_select_dim0", &fbgemm_gpu::group_index_select_dim0_gpu); + m.impl("group_index_select_dim0", &fbgemm_gpu::group_index_select_dim0); m.impl( "group_index_select_dim0_gpu_impl", - &fbgemm_gpu::group_index_select_dim0_gpu_impl); + &fbgemm_gpu::group_index_select_dim0_autograd_impl); } From f4710c1b3b477838498cc71c7e6d3cf296f3ad82 Mon Sep 17 00:00:00 2001 From: Aya Ibrahim Date: Thu, 3 Oct 2024 12:36:19 
-0700 Subject: [PATCH 43/48] FP8 KV + Disagg unit test (#3218) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3218 X-link: https://github.com/facebookresearch/FBGEMM/pull/315 Adding the fp8 kv cache to disagg test for mp2. Changes include changing the model to 7b llama model. The small model has D_H of 64, which is not working with dequantization kernel (will check the issue in another diff). TODO: add Fp8 kv cache + paged kv to the test Reviewed By: jianyuh Differential Revision: D62772678 fbshipit-source-id: 775f572e2c345354844e24d80e2481284ac6f1a3 --- fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu b/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu index 787c0547c..00974a9fe 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu @@ -1437,6 +1437,7 @@ __global__ void dequantize_fp8_cache_kernel( auto MAX_T = cache_K.size(1); auto D_H = cache_K_dq.size(3); auto D_H_q = cache_K.size(3); + // TODO: support D_H < 128 for small model used in testing. CUDA_KERNEL_ASSERT(D_H == 128); auto b = blockIdx.x; From a0966e8d3de304b09c5bbe4cdb0fdd890cdd7660 Mon Sep 17 00:00:00 2001 From: Benson Ma Date: Thu, 3 Oct 2024 12:54:46 -0700 Subject: [PATCH 44/48] Redefine FBGEMM targets with gpu_cpp_library [25/N] (#3217) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3217 X-link: https://github.com/facebookresearch/FBGEMM/pull/314 - Add `default_compiler_flags` to codegen targets - Fix inclusion of `feature_gates_fb.h` Reviewed By: jianyuh Differential Revision: D63818462 fbshipit-source-id: d73fcc69e4f72291793daf22b5faa04467b3c8af --- fbgemm_gpu/include/fbgemm_gpu/config/feature_gates.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fbgemm_gpu/include/fbgemm_gpu/config/feature_gates.h b/fbgemm_gpu/include/fbgemm_gpu/config/feature_gates.h index 551e83d7c..f81276a9d 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/config/feature_gates.h +++ b/fbgemm_gpu/include/fbgemm_gpu/config/feature_gates.h @@ -11,7 +11,7 @@ #include #ifdef FBGEMM_FBCODE -#include "fbgemm_gpu/config/feature_gates_fb.h" +#include "deeplearning/fbgemm/fbgemm_gpu/fb/include/fbgemm_gpu/config/feature_gates_fb.h" #endif /// @defgroup fbgemm-gpu-config FBGEMM_GPU Configuration From d3eae1da1b5b9a9ea6e24fcf8e0b4d71a64710e8 Mon Sep 17 00:00:00 2001 From: Jiawen Liu Date: Fri, 4 Oct 2024 09:42:54 -0700 Subject: [PATCH 45/48] MoE BMM INT4 rowwise weight-only (#3219) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3219 X-link: https://github.com/facebookresearch/FBGEMM/pull/316 Marlin int4 weight-only with loopover for bmm performs great (**up to 7x faster** compared to bf16 bmm) when dim M is small to medium size (e.g., < 256) in decode; For larger dim M, we could leverage this bmm int4 rowwise weight-only kernel in prefill that is around **1.5x faster** than marlin int4 loopover and maintain the same accuracy More results can be found in this [data sheet](https://docs.google.com/spreadsheets/d/12JWt3SqX_1GSLKwjGyt0KQl9SMWDF0r0C63MMKsE9JM/edit?usp=sharing) Reviewed By: jianyuh Differential Revision: D63818529 fbshipit-source-id: 127e841fa7c6c1ce810b6e8b6e35907eeaecafd6 --- fbgemm_gpu/experimental/gen_ai/CMakeLists.txt | 1 + .../bf16i4bf16_rowwise_batched.cu | 298 ++++++++++++++++++ .../gen_ai/src/quantize/quantize.cpp | 28 +- .../gen_ai/test/quantize/quantize_test.py | 60 
++-- 4 files changed, 366 insertions(+), 21 deletions(-) create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16i4bf16_rowwise_batched.cu diff --git a/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt b/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt index 5accb9c53..dd6f165ed 100644 --- a/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt +++ b/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt @@ -49,6 +49,7 @@ else() src/quantize/cutlass_extensions/f8i4bf16_rowwise.cu src/quantize/cutlass_extensions/i8i8bf16_dynamic.cu src/quantize/cutlass_extensions/bf16i4bf16_rowwise.cu + src/quantize/cutlass_extensions/bf16i4bf16_rowwise_batched.cu src/quantize/quantize.cu src/quantize/quantize.cpp) endif() diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16i4bf16_rowwise_batched.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16i4bf16_rowwise_batched.cu new file mode 100644 index 000000000..871543a2f --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16i4bf16_rowwise_batched.cu @@ -0,0 +1,298 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include + +// clang-format off +// The fixed ordering of the headers is required for CUTLASS 3.2+ +#include +#include // @manual +#include // @manual +#include // @manual +// clang-format on + +#include "cutlass_extensions/include/kernel_mode.h" + +namespace fbgemm_gpu { + +#if CUDART_VERSION >= 12000 + +template < + int TB_M, + int TB_N, + int TB_K, + int TBS_M, + int TBS_N, + int TBS_K, + bool PONG, + typename WEIGHT_SCALE_DTYPE> +at::Tensor bf16i4bf16_rowwise_batched_impl( + at::Tensor X, // BF16 + at::Tensor WQ, // INT4 + at::Tensor w_scale, + at::Tensor w_zp) { + // XQ: B x M x K + // WQ: B x N x K + // output: B x M x N + int B = X.size(0); + int M = X.size(1); + int N = WQ.size(1); + int K = X.size(2); + + int num_groups = w_scale.size(0) / B; + + TORCH_CHECK(X.is_cuda() && X.is_contiguous()); + TORCH_CHECK(WQ.is_cuda() && WQ.is_contiguous()); + TORCH_CHECK(w_scale.is_cuda() && w_scale.is_contiguous()); + TORCH_CHECK(w_zp.is_cuda() && w_zp.is_contiguous()); + TORCH_CHECK(K >= num_groups && K % num_groups == 0); + + int group_size = K / num_groups; + + auto Y = at::empty({B, M, N}, X.options().dtype(at::kBFloat16)); + + using ElementInputA = cutlass::bfloat16_t; + using LayoutInputA = cutlass::layout::ColumnMajor; + constexpr int AlignmentInputA = + 128 / + cutlass::sizeof_bits< + ElementInputA>::value; // Memory access granularity/alignment of A + // matrix in units of elements (up to 16 bytes) + + using ElementInputB = cutlass::int4b_t; + using LayoutInputB = cutlass::layout::RowMajor; + constexpr int AlignmentInputB = + 128 / + cutlass::sizeof_bits< + ElementInputB>::value; // Memory access granularity/alignment of B + // matrix in units of elements (up to 16 bytes) + + using ElementScale = WEIGHT_SCALE_DTYPE; + using ElementZeroPoint = WEIGHT_SCALE_DTYPE; + using ElementComputeEpilogue = float; + using ElementAccumulator = float; + + using ElementOutput = cutlass::bfloat16_t; + using LayoutOutput = cutlass::layout::ColumnMajor; + constexpr int AlignmentOutput = + 128 / + cutlass::sizeof_bits< + ElementOutput>::value; // Memory access granularity/alignment of C + // matrix in units of elements (up to 16 bytes) + + using ArchTag = 
cutlass::arch::Sm90; // Tag indicating the minimum SM that + // supports the intended feature + using OperatorClass = cutlass::arch::OpClassTensorOp; + using TileShape = cute::Shape< + cute::Int, + cute::Int, + cute::Int>; // Threadblock-level + // tile size + using ClusterShape = cute::Shape< + cute::Int, + cute::Int, + cute::Int>; // Shape of the + // threadblocks in a + // cluster + using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecializedMixedInput; + using PongSchedule = + cutlass::gemm::KernelTmaWarpSpecializedPingpongMixedInput; + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; + using MainLoopSchedule = + cute::conditional_t; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, + cutlass::arch::OpClassTensorOp, + TileShape, + ClusterShape, + EpilogueTileType, + ElementAccumulator, + ElementAccumulator, + ElementOutput, + LayoutOutput, + AlignmentOutput, + ElementOutput, + LayoutOutput, + AlignmentOutput, + EpilogueSchedule>::CollectiveOp; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + cute::tuple, + LayoutInputB, + AlignmentInputB, + ElementInputA, + LayoutInputA, + AlignmentInputA, + ElementAccumulator, + TileShape, + ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainLoopSchedule>::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + using StrideInputA = typename Gemm::GemmKernel::StrideA; + using StrideInputB = typename Gemm::GemmKernel::StrideB; + using StrideOutput = typename Gemm::GemmKernel::StrideC; + using StrideS = typename CollectiveMainloop::StrideScale; + + StrideInputA stride_a = cutlass::make_cute_packed_stride( + StrideInputA{}, cute::make_shape(M, K, B)); + StrideInputB stride_b = cutlass::make_cute_packed_stride( + StrideInputB{}, cute::make_shape(N, K, B)); + StrideOutput stride_output = cutlass::make_cute_packed_stride( + StrideOutput{}, cute::make_shape(N, M, B)); + StrideS stride_S = cutlass::make_cute_packed_stride( + StrideS{}, cute::make_shape(N, num_groups, B)); + + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {N, M, K, B}, + {reinterpret_cast(WQ.data_ptr()), + stride_b, + reinterpret_cast(X.data_ptr()), + stride_a, + reinterpret_cast(w_scale.data_ptr()), + stride_S, + group_size, + reinterpret_cast(w_zp.data_ptr())}, + {{1.0, 0.0}, + (ElementOutput*)Y.data_ptr(), + stride_output, + (ElementOutput*)Y.data_ptr(), + stride_output}}; + + Gemm gemm; + + // Using the arguments, query for extra workspace required for matrix + // multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check the problem size is supported or not + cutlass::Status status = gemm.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot implement"); + } + + // Initialize CUTLASS kernel with arguments and workspace pointer + status = gemm.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot initialize"); + } + + status = 
gemm(at::cuda::getCurrentCUDAStream()); + + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error( + std::string("cutlass cannot run") + + cutlass::cutlassGetStatusString(status)); + } + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + return Y; +} + +template +at::Tensor dispatch_bf16i4bf16_rowwise_batched_kernel( + at::Tensor X, // BF16 + at::Tensor WQ, // INT4 + at::Tensor w_scale, + at::Tensor w_zp) { + KernelMode kernel = get_batched_kernel_mode(X, WQ); + if (kernel == KernelMode::Small) { + return bf16i4bf16_rowwise_batched_impl< + 64, + 128, + 64, + 2, + 1, + 1, + true, + WEIGHT_SCALE_DTYPE>(X, WQ, w_scale, w_zp); + } else if (kernel == KernelMode::Large) { + return bf16i4bf16_rowwise_batched_impl< + 128, + 128, + 64, + 2, + 1, + 1, + true, + WEIGHT_SCALE_DTYPE>(X, WQ, w_scale, w_zp); + } else { + return bf16i4bf16_rowwise_batched_impl< + 128, + 128, + 64, + 2, + 1, + 1, + false, + WEIGHT_SCALE_DTYPE>(X, WQ, w_scale, w_zp); + } +} + +at::Tensor bf16i4bf16_rowwise_batched( + at::Tensor X, // BF16 + at::Tensor WQ, // INT4 + at::Tensor w_scale, + at::Tensor w_zp) { + // Check datatypes. + TORCH_CHECK( + (w_scale.dtype() == at::kFloat && w_zp.dtype() == at::kFloat) || + (w_scale.dtype() == at::kHalf && w_zp.dtype() == at::kHalf) || + (w_scale.dtype() == at::kBFloat16 && w_zp.dtype() == at::kBFloat16), + "Weight scale and zero point tensors must be float32, bfloat16, or float16, and dtype of weight scale and zero point tensors must be the same ."); + + if (w_scale.dtype() == at::kFloat) { + return dispatch_bf16i4bf16_rowwise_batched_kernel( + X, WQ, w_scale, w_zp); + } else if (w_scale.dtype() == at::kHalf) { + return dispatch_bf16i4bf16_rowwise_batched_kernel( + X, WQ, w_scale, w_zp); + } else if (w_scale.dtype() == at::kBFloat16) { + return dispatch_bf16i4bf16_rowwise_batched_kernel( + X, WQ, w_scale, w_zp); + } else { + throw std::runtime_error( + "Weight scale and zero point data type not supported in bf16i4bf16_rowwise_batched"); + } +} + +#else + +at::Tensor bf16i4bf16_rowwise_batched( + at::Tensor X, // BF16 + at::Tensor WQ, // INT4 + at::Tensor w_scale, + at::Tensor w_zp) { + throw std::runtime_error( + "CUDA version is older than 12.0"); // requires CUDA>=12 +} + +#endif + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp index 101a5cba1..1abf8fb40 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp @@ -96,6 +96,11 @@ at::Tensor bf16i4bf16_rowwise( at::Tensor WQ, at::Tensor w_scale, at::Tensor w_zp); +at::Tensor bf16i4bf16_rowwise_batched( + at::Tensor X, + at::Tensor WQ, + at::Tensor w_scale, + at::Tensor w_zp); at::Tensor per_tensor_quantize_i8(at::Tensor X, double scale); std::tuple per_tensor_dynamic_quantize_i8(at::Tensor X); @@ -152,6 +157,10 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { "bf16i4bf16_rowwise(Tensor X, Tensor WQ, Tensor w_scale, Tensor w_zp) -> Tensor"); m.impl("bf16i4bf16_rowwise", bf16i4bf16_rowwise); + m.def( + "bf16i4bf16_rowwise_batched(Tensor X, Tensor WQ, Tensor w_scale, Tensor w_zp) -> Tensor"); + m.impl("bf16i4bf16_rowwise_batched", bf16i4bf16_rowwise_batched); + m.def( "i8i8bf16_dynamic(Tensor XQ, Tensor WQ, Tensor scale, int split_k=1) -> Tensor"); m.impl("i8i8bf16_dynamic", i8i8bf16_dynamic); @@ -326,14 +335,28 @@ at::Tensor f8i4bf16_rowwise_meta( at::Tensor bf16i4bf16_rowwise_meta( at::Tensor X, // BF16 at::Tensor WQ, // INT4 - at::Tensor w_scale, - at::Tensor w_zp) 
{ + at::Tensor /* w_scale */, + at::Tensor /* w_zp */ +) { int M = X.size(0); int N = WQ.size(0); auto Y = at::empty({M, N}, X.options().dtype(at::kBFloat16)); return Y; } +at::Tensor bf16i4bf16_rowwise_batched_meta( + at::Tensor X, // BF16 + at::Tensor WQ, // INT4 + at::Tensor /* w_scale */, + at::Tensor /* w_zp */ +) { + int B = X.size(0); + int M = X.size(1); + int N = WQ.size(1); + auto Y = at::empty({B, M, N}, X.options().dtype(at::kBFloat16)); + return Y; +} + std::vector quantize_fp8_per_row_meta( at::Tensor input, std::optional bs, @@ -370,6 +393,7 @@ TORCH_LIBRARY_IMPL(fbgemm, Meta, m) { m.impl("f8f8bf16_rowwise_batched", f8f8bf16_rowwise_batched_meta); m.impl("f8i4bf16_rowwise", f8i4bf16_rowwise_meta); m.impl("bf16i4bf16_rowwise", bf16i4bf16_rowwise_meta); + m.impl("bf16i4bf16_rowwise_batched", bf16i4bf16_rowwise_batched_meta); #endif } diff --git a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py index c21c1713a..38a09f360 100644 --- a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py +++ b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py @@ -673,6 +673,7 @@ def fp8_loopover_bmm( M=st.sampled_from([2048, 4096]), N=st.sampled_from([256, 512]), K=st.sampled_from([256, 512]), + use_loopover=st.sampled_from([True, False]), ) def test_int4_batched_gemm( self, @@ -680,6 +681,7 @@ def test_int4_batched_gemm( M: int, N: int, K: int, + use_loopover: bool, ) -> None: if not MARLIN_ENABLED: return @@ -689,28 +691,48 @@ def test_int4_batched_gemm( wq = [] w_scale = [] group_size = 128 - for i in range(B): - _, wq_, w_scale_ = marlin_quantize(w[i].cuda().t().contiguous(), group_size) - wq.append(wq_) - w_scale.append(w_scale_) - wq = torch.stack(wq) - w_scale = torch.stack(w_scale) - - def int4_loopover_bmm( - x: torch.Tensor, - wq: torch.Tensor, - w_scale: torch.Tensor, - ) -> torch.Tensor: - B = x.shape[0] - M = x.shape[1] - N = w_scale.shape[2] - y = torch.empty((B, M, N), dtype=torch.bfloat16, device=x[0].device) + + if use_loopover: for i in range(B): - y[i] = torch.ops.marlin.marlin_gemm(x[i], wq[i], w_scale[i]) - return y + _, wq_, w_scale_ = marlin_quantize( + w[i].cuda().t().contiguous(), group_size + ) + wq.append(wq_) + w_scale.append(w_scale_) + wq = torch.stack(wq) + w_scale = torch.stack(w_scale) + + def int4_loopover_bmm( + x: torch.Tensor, + wq: torch.Tensor, + w_scale: torch.Tensor, + ) -> torch.Tensor: + B = x.shape[0] + M = x.shape[1] + N = w_scale.shape[2] + y = torch.empty((B, M, N), dtype=torch.bfloat16, device=x[0].device) + for i in range(B): + y[i] = torch.ops.marlin.marlin_gemm(x[i], wq[i], w_scale[i]) + return y + + y_int4 = int4_loopover_bmm(x, wq, w_scale) + else: + w_zp = [] + for i in range(B): + wq_, w_scale_, w_zp_ = int4_row_quantize(w[i], group_size) + + wq_ = pack_int4(wq_).contiguous().to(device="cuda") + w_scale_ = w_scale_.contiguous().to(device="cuda") + w_zp_ = w_zp_.contiguous().to(device="cuda") + wq.append(wq_) + w_scale.append(w_scale_) + w_zp.append(w_zp_) + wq = torch.stack(wq) + w_scale = torch.stack(w_scale).view(-1, N) + w_zp = torch.stack(w_zp).view(-1, N) + y_int4 = torch.ops.fbgemm.bf16i4bf16_rowwise_batched(x, wq, w_scale, w_zp) y_ref = torch.bmm(x, w.transpose(1, 2)) - y_int4 = int4_loopover_bmm(x, wq, w_scale) torch.testing.assert_close(y_ref, y_int4, atol=8.0e-2, rtol=8.0e-2) From 342e8d2846ce8aac658decf4a5d13f11001e8eac Mon Sep 17 00:00:00 2001 From: generatedunixname89002005307016 Date: Fri, 4 Oct 2024 13:31:47 -0700 Subject: [PATCH 
46/48] Add missing Pyre mode headers] [batch:12/1006] [shard:8/N] [B] [B] [A] Differential Revision: D63900700 fbshipit-source-id: 29f02010bf205046c50c2a3c1004911e11421ecf --- fbgemm_gpu/test/release/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/fbgemm_gpu/test/release/utils.py b/fbgemm_gpu/test/release/utils.py index 005cf38b5..4218ebc95 100644 --- a/fbgemm_gpu/test/release/utils.py +++ b/fbgemm_gpu/test/release/utils.py @@ -4,6 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe + import inspect import typing from typing import Iterable, List, Optional, Sequence, Union # noqa: F401 From 88ef5f9d6c54eac8e11cd629e1f4a511ee6fb457 Mon Sep 17 00:00:00 2001 From: Supadchaya Puangpontip Date: Fri, 4 Oct 2024 13:33:33 -0700 Subject: [PATCH 47/48] Fix pack_segments backward when grad is non-contig (#3222) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3222 X-link: https://github.com/facebookresearch/FBGEMM/pull/320 Original commit changeset: c1fe80d75fb4 Original Phabricator Diff: D61694017 from D61694017 Reviewed By: q10, brad-mengchi Differential Revision: D63424805 fbshipit-source-id: 42e44383b48a577610f00ad6b8c2cd48bf734a2b --- .../sparse_pack_segments_backward.cu | 16 +- fbgemm_gpu/test/sparse/failures_dict.json | 10 + fbgemm_gpu/test/sparse/pack_segments_test.py | 174 +++++++++++++++++- 3 files changed, 192 insertions(+), 8 deletions(-) diff --git a/fbgemm_gpu/src/sparse_ops/sparse_pack_segments_backward.cu b/fbgemm_gpu/src/sparse_ops/sparse_pack_segments_backward.cu index 9037b7c09..c899bbf9b 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_pack_segments_backward.cu +++ b/fbgemm_gpu/src/sparse_ops/sparse_pack_segments_backward.cu @@ -62,18 +62,21 @@ DLL_PUBLIC Tensor pack_segments_backward_cuda( CUDA_DEVICE_GUARD(data); + const auto data_contig = data.expect_contiguous(); + Tensor unpacked_tensor; // The output tensor AT_DISPATCH_INDEX_TYPES(lengths.scalar_type(), "unpack_segments_cuda", [&] { const auto* const lengths_data = lengths.data_ptr(); // Create output tensor of appropriate dimensions - auto shape = data.sizes().vec(); + auto shape = data_contig->sizes().vec(); shape.erase(shape.begin()); shape[0] = total_length; - unpacked_tensor = at::empty(shape, data.options()); + unpacked_tensor = at::empty(shape, data_contig->options()); - if (!(data.size(0) && data.size(1))) { // TODO: What does this mean? + if (!(data_contig->size(0) && + data_contig->size(1))) { // TODO: What does this mean? 
return; } @@ -82,10 +85,11 @@ DLL_PUBLIC Tensor pack_segments_backward_cuda( auto lps_data = lengths_prefix_sum.data_ptr(); FBGEMM_DISPATCH_ALL_TYPES( - data.scalar_type(), "unpack_segments_cuda-unpacking", [&] { + data_contig->scalar_type(), "unpack_segments_cuda-unpacking", [&] { const auto num_seq = lengths.size(0); - const auto cell_size = data.numel() / (data.size(0) * data.size(1)); - const auto* const data_ptr = data.data_ptr(); + const auto cell_size = data_contig->numel() / + (data_contig->size(0) * data_contig->size(1)); + const auto* const data_ptr = data_contig->data_ptr(); auto* const out_data = unpacked_tensor.data_ptr(); unpack_segments_cuda_kernel diff --git a/fbgemm_gpu/test/sparse/failures_dict.json b/fbgemm_gpu/test/sparse/failures_dict.json index 40cfacc06..fb2fcf85e 100644 --- a/fbgemm_gpu/test/sparse/failures_dict.json +++ b/fbgemm_gpu/test/sparse/failures_dict.json @@ -2,6 +2,16 @@ "_description": "This is a dict containing failures for tests autogenerated by generate_opcheck_tests. For more details, please see https://docs.google.com/document/d/1Pj5HRZvdOq3xpFpbEjUZp2hBovhy7Wnxw14m6lF2154/edit", "_version": 1, "data": { + "fb::pack_segments": { + "PackedSegmentsTest.test_aot_dispatch_dynamic__test_pack_segments_noncontig": { + "comment": "", + "status": "xfail" + }, + "PackedSegmentsTest.test_faketensor__test_pack_segments_noncontig": { + "comment": "", + "status": "xfail" + } + }, "fbgemm::asynchronous_complete_cumsum": {}, "fbgemm::asynchronous_exclusive_cumsum": {}, "fbgemm::asynchronous_inclusive_cumsum": {}, diff --git a/fbgemm_gpu/test/sparse/pack_segments_test.py b/fbgemm_gpu/test/sparse/pack_segments_test.py index dd5319277..095ea4377 100644 --- a/fbgemm_gpu/test/sparse/pack_segments_test.py +++ b/fbgemm_gpu/test/sparse/pack_segments_test.py @@ -23,9 +23,9 @@ if open_source: # pyre-ignore[21] - from test_utils import gpu_available + from test_utils import gpu_available, gpu_unavailable else: - from fbgemm_gpu.test.test_utils import gpu_available + from fbgemm_gpu.test.test_utils import gpu_available, gpu_unavailable def get_n_rand_num_summing_to_k(n: int, k: int) -> npt.NDArray: @@ -47,6 +47,15 @@ def get_n_rand_num_summing_to_k(n: int, k: int) -> npt.NDArray: # pyre-fixme[2] # pyre-fixme[24] def torch_compiled(model: Callable, **kwargs) -> Callable: + """A helper function to apply torch.compile if python < 3.12. + + Args: + model: The model to be compiled. + kwargs: The arguments to be passed to torch.compile. + + Returns: + The model. + """ if sys.version_info < (3, 12, 0): return torch.compile(model, **kwargs) else: @@ -60,6 +69,17 @@ def _pack_segments_ref( tensor: torch.Tensor, max_length: Optional[int] = None, ) -> npt.NDArray: + """ + This function is a reference implementation of pack_segments. + + Args: + lengths (Tensor): The lengths of tensor. + tensor (Tensor): The tensor to be packed. + max_length (Optional[int]): The maximum length of the packed tensor. + + Returns: + The packed tensor. + """ lengths = lengths.numpy() sections = np.split(tensor, np.cumsum(lengths)) max_length = np.max(lengths, initial=0) if max_length is None else max_length @@ -106,6 +126,22 @@ def test_pack_segments( dtype: torch.dtype, torch_compile: bool, ) -> None: + """ + This function tests pack_segments ops compared to the reference implementation. + Both CPU and GPU (if available) are tested. 
+
+        Args:
+            n - The number of rows in the input tensor
+            k - The number of columns in the input tensor
+            batch_size - The number of batches of the input tensor
+            divisions - The number of segments to be packed
+            dtype - The data type
+            torch_compile - Whether to use torch.compile
+
+        Returns:
+            None
+        """
+
         input_raw = np.random.rand(batch_size, n, k)
         input_data = torch.tensor(input_raw, dtype=dtype, requires_grad=True)
         lengths = torch.tensor(
@@ -209,6 +245,23 @@ def test_pack_segments_smaller_max_len(
         dtype: torch.dtype,
         torch_compile: bool,
     ) -> None:
+        """
+        This function tests pack_segments ops with set max_length
+        Both CPU and GPU (if available) are tested.
+
+        Args:
+            n - The number of rows in the input tensor
+            k - The number of columns in the input tensor
+            batch_size - The number of batches of the input tensor
+            divisions - The number of segments to be packed
+            max_length - The maximum length of the packed tensor
+            dtype - The data type
+            torch_compile - Whether to use torch.compile
+
+        Returns:
+            None
+        """
+
         input_raw = np.random.rand(batch_size, n, k)
         input_data = torch.tensor(input_raw, dtype=dtype)
         lengths = torch.tensor(
@@ -264,6 +317,20 @@ def test_pack_segments_meta_backend(
         divisions: int,
         dtype: torch.dtype,
     ) -> None:
+        """
+        This function tests pack_segments ops with meta backend.
+
+        Args:
+            n - The number of rows in the input tensor
+            k - The number of columns in the input tensor
+            batch_size - The number of batches of the input tensor
+            divisions - The number of segments to be packed
+            dtype - The data type
+
+        Returns:
+            None
+        """
+
         input_raw = np.random.rand(batch_size, n, k)
         input_data = torch.tensor(
             input_raw, dtype=torch.float32, requires_grad=True
@@ -281,6 +348,109 @@ def test_pack_segments_meta_backend(
         # verify forward
         assert packed_tensor.size() == torch.Tensor(packed_ref).size()
 
+    @unittest.skipIf(*gpu_unavailable)
+    @given(
+        n=st.integers(2, 10),
+        k=st.integers(2, 10),
+        batch_size=st.integers(1, 30),
+        divisions=st.integers(1, 10),
+        dtype=st.sampled_from(
+            [
+                torch.float,
+                torch.half,
+            ]
+        ),
+        torch_compile=st.booleans(),
+        use_cpu=st.booleans(),
+    )
+    @settings(deadline=None)
+    def test_pack_segments_noncontig(
+        self,
+        n: int,
+        k: int,
+        batch_size: int,
+        divisions: int,
+        dtype: torch.dtype,
+        torch_compile: bool,
+        use_cpu: bool,
+    ) -> None:
+        """
+        This function tests pack_segments ops when input gradients to backward are non-contiguous.
+
+        Args:
+            n - The number of rows in the input tensor
+            k - The number of columns in the input tensor
+            batch_size - The number of batches of the input tensor
+            divisions - The number of segments to be packed
+            dtype - The data type
+            torch_compile - Whether to use torch.compile
+            use_cpu - Whether to use CPU or GPU
+
+        Returns:
+            None
+        """
+
+        input_raw = np.random.rand(batch_size, n, k)
+        # create input
+        input_data_ref = torch.tensor(input_raw, dtype=dtype, requires_grad=True)
+        input_data = torch.tensor(input_raw, dtype=dtype, requires_grad=True).cuda()
+        # retain grad to compare gradients of the inputs later
+        input_data.retain_grad()
+        input_data_ref.retain_grad()
+
+        # set lengths
+        lengths = torch.tensor(
+            get_n_rand_num_summing_to_k(divisions, batch_size),
+            dtype=torch.int,
+        )
+        max_length = lengths.max().item()
+
+        packed_ref = torch.ops.fbgemm.pack_segments(
+            t_in=input_data_ref, lengths=lengths, max_length=max_length
+        )
+        packed_ref.retain_grad()
+
+        # pack segments using fbgemm and fb
+        packed_tensor = torch.ops.fbgemm.pack_segments(
+            t_in=input_data, lengths=lengths.cuda(), max_length=max_length
+        )
+        packed_tensor.retain_grad()
+
+        # verify forward
+        self.assertTrue(torch.equal(packed_tensor.cpu(), packed_ref))
+
+        # create non-contiguous grad
+        shape = tuple(x * 2 for x in packed_ref.shape)
+        grads = torch.tensor(
+            np.random.uniform(low=0.01, high=0.5, size=shape).astype(np.float32)
+        ).to(dtype)
+        grad_noncontig_cpu = grads.as_strided(packed_ref.shape, grads.stride())
+        grad_noncontig_cuda = grads.cuda().as_strided(packed_ref.shape, grads.stride())
+
+        self.assertTrue(
+            not (
+                grad_noncontig_cpu.is_contiguous()
+                and grad_noncontig_cuda.is_contiguous()
+            ),
+            msg="Expected grads to be non-contiguous but they are contiguous",
+        )
+
+        # verify backward
+        packed_ref.backward(grad_noncontig_cpu)
+        packed_tensor.backward(grad_noncontig_cuda)
+        self.assertTrue(
+            torch.equal(packed_tensor.cpu(), packed_ref),
+            msg="Expected packed tensors to be equal but they are not",
+        )
+
+        # verify backward input gradients
+        self.assertTrue(
+            # pyre-fixme[16]: Optional type has no attribute `cpu`.
+            # pyre-fixme[6]: For 2nd param expected `Tensor` but got `Optional[Tensor]`.
+            torch.equal(input_data.grad.cpu(), input_data_ref.grad.cpu()),
+            msg="Expected input gradients to be equal but they are not",
+        )
+
 
 extend_test_class(PackedSegmentsTest)
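To reproduce the problem that PATCH 47/48 fixes outside of the test harness, the following is a minimal sketch distilled from `test_pack_segments_noncontig` above. It is not part of the patch: the shapes and `lengths` values are arbitrary illustrations, and it assumes a CUDA build of `fbgemm_gpu` so that `torch.ops.fbgemm.pack_segments` is registered.

    import torch
    import fbgemm_gpu  # noqa: F401  # registers the torch.ops.fbgemm operators

    # Two segments of lengths 2 and 3 over a (5, 4) input; values are arbitrary.
    lengths = torch.tensor([2, 3], dtype=torch.int, device="cuda")
    data = torch.randn(5, 4, device="cuda", requires_grad=True)

    packed = torch.ops.fbgemm.pack_segments(
        t_in=data, lengths=lengths, max_length=int(lengths.max().item())
    )  # shape: (num_segments, max_length, 4)

    # Build a correctly shaped but non-contiguous gradient by viewing into a
    # buffer with doubled dimensions, as the new test does with as_strided().
    big = torch.randn(tuple(2 * s for s in packed.shape), device="cuda")
    grad = big.as_strided(packed.shape, big.stride())
    assert not grad.is_contiguous()

    # Before this change the CUDA backward indexed the raw (possibly strided)
    # gradient storage as if it were dense; with data.expect_contiguous() it
    # now operates on a contiguous copy.
    packed.backward(grad)
    print(data.grad.shape)  # torch.Size([5, 4])

Using `expect_contiguous()` rather than `.contiguous()` avoids an extra copy when the incoming gradient is already contiguous, since it only materializes a contiguous tensor in the non-contiguous case.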
From 7a4472a8d14912ab7a3b7ca12bca030448f8fec8 Mon Sep 17 00:00:00 2001
From: Jiawen Liu
Date: Fri, 4 Oct 2024 16:08:13 -0700
Subject: [PATCH 48/48] Fine-tune FP8 BMM performance (#3224)

Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3224

X-link: https://github.com/facebookresearch/FBGEMM/pull/322

Fine-tune FP8 BMM performance to get an additional 20% performance gain.
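The change adds a medium kernel configuration for batched FP8 rowwise GEMMs and retunes the shape heuristic that selects it. As a reading aid, here is a rough Python restatement of the updated `get_batched_kernel_mode()` logic from `kernel_mode.h` below; the thresholds are copied from the C++ diff, while the `Default` fallback and the interpretation of `XQ` as (B, M, K) and `WQ` as (B, N, K) are inferred from the surrounding code rather than spelled out in this patch.

    from enum import Enum


    class KernelMode(Enum):
        SMALL = "Small"
        MEDIUM = "Medium"
        LARGE = "Large"
        DEFAULT = "Default"


    def get_batched_kernel_mode(b: int, m: int, n: int, k: int) -> KernelMode:
        # Total number of output rows across the whole batch.
        bm = b * m
        # Thresholds mirror the C++ heuristic introduced in this diff.
        use_medium = bm <= 512 and ((n <= 8192 and k < 8192) or (n < 8192 and k <= 8192))
        use_large = bm > 512 and (n >= 1024 or k >= 1024)
        if bm <= 128 or n <= 128:
            return KernelMode.SMALL
        if use_medium:
            return KernelMode.MEDIUM
        if use_large:
            return KernelMode.LARGE
        return KernelMode.DEFAULT  # assumed fallback, as in the pre-existing code

Shapes that land in the medium bucket are dispatched to the new `f8f8bf16_rowwise_batched_impl<64, 128, 128, 1, 2, 1, ...>` instantiation added in `f8f8bf16_rowwise_batched.cu`.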
Reviewed By: jianyuh

Differential Revision: D63882833

fbshipit-source-id: 04fd5d38e8e127edd2d8771681a180757eaf7321
---
 .../f8f8bf16_rowwise_batched.cu               | 13 +++++++++++++
 .../cutlass_extensions/include/kernel_mode.h  | 16 ++++++++--------
 2 files changed, 21 insertions(+), 8 deletions(-)

diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_batched.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_batched.cu
index 313c81298..a34c694e0 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_batched.cu
+++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_batched.cu
@@ -335,6 +335,19 @@ at::Tensor dispatch_fp8_rowwise_batched_kernel(
         UseBias,
         InputDType,
         BiasDType>(XQ, WQ, x_scale, w_scale, bias, output);
+  } else if (kernel == KernelMode::Medium) {
+    return f8f8bf16_rowwise_batched_impl<
+        64,
+        128,
+        128,
+        1,
+        2,
+        1,
+        true,
+        FastAccum,
+        UseBias,
+        InputDType,
+        BiasDType>(XQ, WQ, x_scale, w_scale, bias, output);
   } else if (kernel == KernelMode::Large) {
     return f8f8bf16_rowwise_batched_impl<
         128,
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/include/kernel_mode.h b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/include/kernel_mode.h
index 93b96fb04..9a267193a 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/include/kernel_mode.h
+++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/include/kernel_mode.h
@@ -12,7 +12,7 @@
 
 namespace fbgemm_gpu {
 
-enum class KernelMode { Small, Large, Default };
+enum class KernelMode { Small, Medium, Large, Default };
 
 inline KernelMode get_kernel_mode(at::Tensor XQ, at::Tensor WQ) {
   auto M = XQ.size(0);
@@ -37,14 +37,14 @@ inline KernelMode get_batched_kernel_mode(at::Tensor XQ, at::Tensor WQ) {
   auto K = XQ.size(2);
   auto N = WQ.size(1);
   auto BM = B * M;
-  auto BN = B * N;
-  auto BK = B * K;
-  // Use a large kernel if at least two shapes are large....
-  bool use_large_kernel =
-      ((BM >= 2048 && BK >= 2048) || (BM >= 2048 && BK >= 2048) ||
-       (BK >= 2048 && BN >= 2048));
-  if (BM <= 128 || BN <= 128) {
+  // Heuristic to determine kernel mode
+  bool use_medium_kernel =
+      ((BM <= 512 && ((N <= 8192 && K < 8192) || (N < 8192 && K <= 8192))));
+  bool use_large_kernel = ((BM > 512 && (N >= 1024 || K >= 1024)));
+  if (BM <= 128 || N <= 128) {
     return KernelMode::Small;
+  } else if (use_medium_kernel) {
+    return KernelMode::Medium;
   } else if (use_large_kernel) {
     return KernelMode::Large;
   } else {