Fix Torch Compile with FP8 Quantization (pytorch#2637)
Summary:
Pull Request resolved: pytorch#2637

Fixes a few incompatibilities between torch.compile and FP8 quantization. Outputs look correct, but we unfortunately still don't see any speedup. At least this allows further performance analysis.

Reviewed By: jianyuh

Differential Revision: D57885347

fbshipit-source-id: dce775069bde1a8d1891cb52f9813eb54d0d162a
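
To exercise the fix end to end, one would wrap the op in torch.compile so that tracing goes through the new Meta kernel while the real CUDA kernel runs at execution time. A minimal sketch, assuming an fbgemm_gpu build that ships the experimental gen_ai ops (the import path below is inferred from the file location in this diff) and an FP8-capable GPU:

    # Sketch only: assumes fbgemm_gpu is built with the experimental gen_ai
    # ops registered and a CUDA device with FP8 (e4m3) support.
    import torch
    import fbgemm_gpu.experimental.gen_ai  # noqa: F401  # registers fbgemm ops

    def rowwise_quantize(x: torch.Tensor):
        # Schema from the diff below:
        # quantize_fp8_per_row(Tensor input, Tensor? bs=None,
        #     Tensor? scale_ub=None, ScalarType? output_dtype=None) -> Tensor[]
        xq, scale = torch.ops.fbgemm.quantize_fp8_per_row(x)
        return xq, scale

    compiled = torch.compile(rowwise_quantize)
    x = torch.randn(16, 128, device="cuda", dtype=torch.bfloat16)
    xq, scale = compiled(x)  # shapes traced via quantize_fp8_per_row_meta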
jwfromm authored and facebook-github-bot committed May 29, 2024
1 parent 7b39db2 commit 35fa7be
Showing 1 changed file with 13 additions and 1 deletion.
14 changes: 13 additions & 1 deletion fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp
@@ -141,7 +141,6 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   m.impl("quantize_fp8_per_tensor", quantize_fp8_per_tensor);
   m.def(
       "quantize_fp8_per_row(Tensor input, Tensor? bs=None, Tensor? scale_ub=None, ScalarType? output_dtype=None) -> Tensor[]");
-  m.impl("quantize_fp8_per_row", quantize_fp8_per_row);
 
 #if CUDART_VERSION >= 12000
   m.def(
@@ -169,6 +168,7 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
   m.impl("quantize_fp8_per_tensor", quantize_fp8_per_tensor);
   m.impl("f8f8bf16", f8f8bf16);
   m.impl("f8f8bf16_cublas", f8f8bf16_cublas);
+  m.impl("quantize_fp8_per_row", quantize_fp8_per_row);
 #endif
 }

@@ -252,6 +252,17 @@ at::Tensor f8i4bf16_rowwise_meta(
   return Y;
 }
 
+std::vector<at::Tensor> quantize_fp8_per_row_meta(
+    at::Tensor input,
+    std::optional<at::Tensor> bs,
+    std::optional<at::Tensor> scale_ub,
+    std::optional<c10::ScalarType> output_dtype) {
+  const at::SymInt M = input.sym_size(0);
+  auto Y = at::empty_like(input, input.options().dtype(at::kFloat8_e4m3fn));
+  auto scale = at::empty_symint({M}, input.options().dtype(at::kFloat));
+  return {Y, scale};
+}
+
 TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
   m.impl("f8f8bf16_tensorwise", f8f8bf16_tensorwise_meta);
   m.impl("f8f8bf16_rowwise", f8f8bf16_rowwise_meta);
@@ -261,6 +272,7 @@ TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
   m.impl("f8f8bf16", f8f8bf16_meta);
   m.impl("f8f8bf16_cublas", f8f8bf16_cublas_meta);
   m.impl("f8i4bf16_rowwise", f8i4bf16_rowwise_meta);
+  m.impl("quantize_fp8_per_row", quantize_fp8_per_row_meta);
 #endif
 }

