Commit 8d65a52

Rowwise F8F8BF16 GEMMs - Auto-generate kernel library, auto-generated heuristics cache, add to FBGEMM quantize_bench

# Summary
- Auto-generated F8F8BF16 rowwise scaled kernels.
- Auto-generated heuristics cache.
- Add the new rowwise op to quantize_bench (a minimal usage sketch follows this list).
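
For context, a minimal sketch of the rowwise path this diff exercises, mirroring the quantize/compute split in quantize_ops.py below. The quantize_fp8_row import path and the tensor shapes are assumptions, and torch.ops.cutlass_extensions.f8f8bf16_rowwise is only available once the cutlass_extensions library from this stack is built and loaded:

```python
import torch

# Assumed import path for the rowwise FP8 quantization helper used in
# quantize_ops.py; adjust to wherever quantize_fp8_row lives in your build.
from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import quantize_fp8_row

# Hypothetical problem size: (M, K) activations against (N, K) weights.
x = torch.randn(2048, 4096, device="cuda", dtype=torch.bfloat16)
w = torch.randn(8192, 4096, device="cuda", dtype=torch.bfloat16)

# Rowwise FP8 quantization: one scale per row of each operand.
xq, x_scale = quantize_fp8_row(x)
wq, w_scale = quantize_fp8_row(w)

# F8F8BF16 rowwise-scaled GEMM from the auto-generated kernel library;
# the output is bf16 of shape (M, N).
y = torch.ops.cutlass_extensions.f8f8bf16_rowwise(xq, wq, x_scale, w_scale)
assert y.shape == (2048, 8192) and y.dtype == torch.bfloat16
```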

# Performance Improvements

## DisaggBench

Prefill, B=1 (baseline Cutlass vs. the auto-generated Cutlass extensions kernels):

| T    | Cutlass elapsed | Cutlass TFLOP/s | Cutlass extensions elapsed | Cutlass extensions TFLOP/s |
| ---: | ---:            | ---:            | ---:                       | ---:                       |
| 2048 | 109.13 ms       | 333.74          | 108.83 ms                  | 334.66                     |
| 4928 | 272.55 ms       | 338.62          | 260.46 ms                  | 354.34                     |
| 6336 | 354.93 ms       | 342.55          | 336.39 ms                  | 361.43                     |
| 8192 | 468.64 ms       | 346.06          | 442.64 ms                  | 366.39                     |

The extensions kernels improve prefill throughput at every sequence length, by up to ~5.9% at T=8192.

Differential Revision: D63744054
manishucsd authored and facebook-github-bot committed Oct 2, 2024
1 parent 893769e commit 8d65a52
Showing 1 changed file with 42 additions and 0 deletions.
fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py (+42, -0)
@@ -500,7 +500,11 @@ def cuda(self) -> bool:
        return True


####################################################################################################
# CUTLASS kernel v2
####################################################################################################


@register_quantize_op
class CutlassFP8TensorwiseGemm_v2(QuantizeOpBase):
"""
Expand Down Expand Up @@ -535,6 +539,44 @@ def cuda(self) -> bool:
return True


# CUTLASS rowwise kernel v2
@register_quantize_op
class CutlassFP8RowwiseGemm_v2(QuantizeOpBase):
    """
    FP8 matmul with rowwise scaling.
    """

    def quantize(self, x, w):
        # Quantize both input tensors.
        xq, x_scale = quantize_fp8_row(x)
        wq, w_scale = quantize_fp8_row(w)
        return xq, wq, x_scale, w_scale

    def compute(self, xq, wq, x_scale, w_scale):
        return torch.ops.cutlass_extensions.f8f8bf16_rowwise(xq, wq, x_scale, w_scale)

    def quantize_and_compute(self, x, w):
        xq, wq, x_scale, w_scale = self.quantize(x, w)
        return self.compute(xq, wq, x_scale, w_scale)

    @property
    def name(self) -> str:
        return "cutlass_rowwise_v2"

    @property
    def hip(self) -> bool:
        # Need to add support for a better quantize kernel.
        # May also have an issue with CUDA graphs.
        return False

    @property
    def cuda(self) -> bool:
        return True


####################################################################################################


@register_quantize_op
class F8I4RowwiseGemm(QuantizeOpBase):
"""

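For reference, a hedged sketch of driving the newly registered op directly rather than through the quantize_bench CLI. The module path and the tensor setup here are assumptions; only the class and its methods come from the diff above:

```python
import torch

# Hypothetical import of the benchmark op registered above.
from fbgemm_gpu.experimental.gen_ai.bench.quantize_ops import (
    CutlassFP8RowwiseGemm_v2,
)

op = CutlassFP8RowwiseGemm_v2()
assert op.cuda and not op.hip  # CUDA-only for now, per the class above.

x = torch.randn(2048, 4096, device="cuda", dtype=torch.bfloat16)
w = torch.randn(8192, 4096, device="cuda", dtype=torch.bfloat16)

# Quantize both operands rowwise, then run the F8F8BF16 rowwise GEMM.
y = op.quantize_and_compute(x, w)
print(y.shape, y.dtype)  # torch.Size([2048, 8192]) torch.bfloat16
```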