Skip to content

Commit

Permalink
Fine-tune FP8 BMM performance (#3224)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #3224

X-link: facebookresearch/FBGEMM#322

Fine-tune FP8 BMM performance to get an additional 20% performance gain.

Reviewed By: jianyuh

Differential Revision: D63882833

fbshipit-source-id: 04fd5d38e8e127edd2d8771681a180757eaf7321
  • Loading branch information
jiawenliu64 authored and facebook-github-bot committed Oct 4, 2024
1 parent 88ef5f9 commit 7a4472a
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,19 @@ at::Tensor dispatch_fp8_rowwise_batched_kernel(
UseBias,
InputDType,
BiasDType>(XQ, WQ, x_scale, w_scale, bias, output);
} else if (kernel == KernelMode::Medium) {
return f8f8bf16_rowwise_batched_impl<
64,
128,
128,
1,
2,
1,
true,
FastAccum,
UseBias,
InputDType,
BiasDType>(XQ, WQ, x_scale, w_scale, bias, output);
} else if (kernel == KernelMode::Large) {
return f8f8bf16_rowwise_batched_impl<
128,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

namespace fbgemm_gpu {

enum class KernelMode { Small, Large, Default };
// Kernel-size categories. The kernel-mode heuristics (get_kernel_mode,
// get_batched_kernel_mode) return one of these values, and the dispatch code
// compares against them to choose which kernel instantiation to launch.
enum class KernelMode { Small, Medium, Large, Default };

inline KernelMode get_kernel_mode(at::Tensor XQ, at::Tensor WQ) {
auto M = XQ.size(0);
Expand All @@ -37,14 +37,14 @@ inline KernelMode get_batched_kernel_mode(at::Tensor XQ, at::Tensor WQ) {
auto K = XQ.size(2);
auto N = WQ.size(1);
auto BM = B * M;
auto BN = B * N;
auto BK = B * K;
// Use a large kernel if at least two shapes are large....
bool use_large_kernel =
((BM >= 2048 && BK >= 2048) || (BM >= 2048 && BK >= 2048) ||
(BK >= 2048 && BN >= 2048));
if (BM <= 128 || BN <= 128) {
// Heuristic to determine kernel mode
bool use_medium_kernel =
((BM <= 512 && ((N <= 8192 && K < 8192) || (N < 8192 && K <= 8192))));
bool use_large_kernel = ((BM > 512 && (N >= 1024 || K >= 1024)));
if (BM <= 128 || N <= 128) {
return KernelMode::Small;
} else if (use_medium_kernel) {
return KernelMode::Medium;
} else if (use_large_kernel) {
return KernelMode::Large;
} else {
Expand Down

0 comments on commit 7a4472a

Please sign in to comment.