Adding PyTorch FP8 matmul test notebook. (#137)
balancap authored Oct 30, 2024
1 parent 365e5d9 commit f9aa9af
Showing 1 changed file with 146 additions and 0 deletions: docs/PyTorch FP8 matmul tutorial.ipynb
@@ -0,0 +1,146 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "7ae3e6c9-01d2-4a34-a4a8-88d36c7e9b3f",
"metadata": {},
"source": [
"# PyTorch FP8 (fused) matmul tutorial"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "4c9500fc-648d-46d3-95ea-e74a0ee43fe6",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(device(type='cuda', index=0), 'NVIDIA H100 PCIe')"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import numpy as np\n",
"import torch\n",
"\n",
"# Local GPU device\n",
"torch.device(0), torch.cuda.get_device_name(0)"
]
},
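{
"cell_type": "markdown",
"id": "1f2e3d4c-5b6a-4789-9abc-def012345678",
"metadata": {},
"source": [
"Quick sanity check: cuBLASLt FP8 kernels require an Ada or Hopper class GPU (compute capability 8.9 or higher), so FP8 `torch._scaled_mm` calls are expected to fail on older devices. A minimal probe:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2a3b4c5d-6e7f-4890-ab12-34567890cdef",
"metadata": {},
"outputs": [],
"source": [
"# FP8 matmuls need compute capability >= 8.9 (Ada / Hopper).\n",
"capability = torch.cuda.get_device_capability(0)\n",
"capability, capability >= (8, 9)"
]
},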
{
"cell_type": "markdown",
"id": "bdbfe673-a8d8-4ba5-afb5-3c4f7eb5b0e7",
"metadata": {},
"source": [
"### `_scaled_mm` FP8 matmul wrapper\n",
"\n",
"PyTorch `_scaled_mm` defintion: https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/cuda/Blas.cpp#L1176C1-L1176C16\n",
"\n",
"`cublasLtMatmul` not supported `E5M2 @ E5M2` matmuls: https://docs.nvidia.com/cuda/cublas/index.html?highlight=fp8#cublasltmatmul \n",
"\n",
"TorchAO is using `_scaled_mm` function for FP8 integration: https://github.com/pytorch/ao/blob/main/torchao/float8/float8_python_api.py"
]
},
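{
"cell_type": "markdown",
"id": "3c4d5e6f-7a8b-4901-bc23-456789abcdef",
"metadata": {},
"source": [
"A small sketch (shapes and variable names here are arbitrary) probing which FP8 dtype pairings the local PyTorch / cuBLASLt build accepts: per the cuBLASLt table linked above, only the `E5M2 @ E5M2` combination should be rejected."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4d5e6f7a-8b9c-4012-cd34-56789abcdef0",
"metadata": {},
"outputs": [],
"source": [
"# Probe which FP8 dtype pairings torch._scaled_mm accepts on this build.\n",
"fp8_dtypes = [torch.float8_e4m3fn, torch.float8_e5m2]\n",
"scale = torch.ones((), dtype=torch.float32, device='cuda')\n",
"for dt_a in fp8_dtypes:\n",
"    for dt_b in fp8_dtypes:\n",
"        x = torch.randn((16, 32), dtype=torch.float16, device='cuda').to(dt_a)\n",
"        # Column major rhs, as required by cuBLASLt.\n",
"        y = torch.randn((16, 32), dtype=torch.float16, device='cuda').t().to(dt_b)\n",
"        try:\n",
"            torch._scaled_mm(x, y, scale_a=scale, scale_b=scale, out_dtype=torch.float16)\n",
"            print(dt_a, '@', dt_b, '=> supported')\n",
"        except RuntimeError as err:\n",
"            print(dt_a, '@', dt_b, '=>', err)"
]
},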
{
"cell_type": "code",
"execution_count": 31,
"id": "87bcf537-3c09-4241-8ab7-f5c2a55c3ed2",
"metadata": {},
"outputs": [
{
"ename": "RuntimeError",
"evalue": "Multiplication of two Float8_e5m2 matrices is not supported",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[31], line 18\u001b[0m\n\u001b[1;32m 15\u001b[0m b_scale \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mones((), dtype\u001b[38;5;241m=\u001b[39mtorch\u001b[38;5;241m.\u001b[39mfloat32, device\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mcuda\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 17\u001b[0m \u001b[38;5;66;03m# FP8 matmul\u001b[39;00m\n\u001b[0;32m---> 18\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_scaled_mm\u001b[49m\u001b[43m(\u001b[49m\u001b[43ma_fp8\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mb_fp8\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\n\u001b[1;32m 19\u001b[0m \u001b[43m \u001b[49m\u001b[43mout_dtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfloat16\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 20\u001b[0m \u001b[43m \u001b[49m\u001b[43mscale_a\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43ma_scale\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 21\u001b[0m \u001b[43m \u001b[49m\u001b[43mscale_b\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mb_scale\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 22\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_fast_accum\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 23\u001b[0m \u001b[43m \u001b[49m\u001b[43mbias\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 24\u001b[0m \u001b[43m \u001b[49m\u001b[43mscale_result\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\n",
"\u001b[0;31mRuntimeError\u001b[0m: Multiplication of two Float8_e5m2 matrices is not supported"
]
}
],
"source": [
"M, N, K = 128, 64, 256\n",
"\n",
"a = torch.randn((M, K), dtype=torch.float16, device='cuda')\n",
"# Transpose as cuBLASLt requires column major on `rhs`\n",
"b = torch.randn((N, K), dtype=torch.float16, device='cuda').t()\n",
"\n",
"# FP8 inputs & scales\n",
"# a_fp8 = a.to(torch.float8_e4m3fn)\n",
"# b_fp8 = b.to(torch.float8_e4m3fn)\n",
"\n",
"a_fp8 = a.to(torch.float8_e5m2)\n",
"b_fp8 = b.to(torch.float8_e5m2)\n",
"\n",
"a_scale = torch.ones((), dtype=torch.float32, device='cuda')\n",
"b_scale = torch.ones((), dtype=torch.float32, device='cuda')\n",
"\n",
"# FP8 matmul\n",
"out = torch._scaled_mm(a_fp8, b_fp8, \n",
" out_dtype=torch.float16,\n",
" scale_a=a_scale,\n",
" scale_b=b_scale,\n",
" use_fast_accum=True,\n",
" bias=None,\n",
" scale_result=None)"
]
},
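{
"cell_type": "markdown",
"id": "5e6f7a8b-9c0d-4123-de45-6789abcdef01",
"metadata": {},
"source": [
"As documented above, the failure is specific to the `E5M2 @ E5M2` pairing. A sketch of the supported path: cast both operands to `float8_e4m3fn` instead, then compare against the `float16` reference matmul for a rough idea of the quantization error (the max-abs difference below is illustrative, not a guaranteed bound)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6f7a8b9c-0d1e-4234-ef56-789abcdef012",
"metadata": {},
"outputs": [],
"source": [
"# E4M3 @ E4M3 is a supported pairing on cuBLASLt.\n",
"a_fp8 = a.to(torch.float8_e4m3fn)\n",
"b_fp8 = b.to(torch.float8_e4m3fn)\n",
"\n",
"out = torch._scaled_mm(a_fp8, b_fp8,\n",
"                       out_dtype=torch.float16,\n",
"                       scale_a=a_scale,\n",
"                       scale_b=b_scale,\n",
"                       use_fast_accum=True)\n",
"\n",
"# Rough quantization error vs. the float16 reference.\n",
"expected = a @ b\n",
"torch.max(torch.abs(out - expected)).item()"
]
},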
{
"cell_type": "code",
"execution_count": 28,
"id": "50a320ec-769e-4dc8-b933-29610918d395",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([128, 64]), torch.float16)"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"out.shape, out.dtype"
]
},
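{
"cell_type": "markdown",
"id": "7a8b9c0d-1e2f-4345-f067-89abcdef0123",
"metadata": {},
"source": [
"The unit scales used above are just placeholders. A common recipe, roughly what TorchAO's float8 scaling does, derives a per-tensor scale from the tensor's `amax` so that values fill the E4M3 range. The `fp8_scale` helper below is a hypothetical sketch, not a TorchAO API:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8b9c0d1e-2f3a-4456-0178-9abcdef01234",
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical helper: per-tensor dequantization scale from amax.\n",
"FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0\n",
"\n",
"def fp8_scale(t):\n",
"    # Scale such that t / scale spans the representable E4M3 range.\n",
"    amax = t.abs().max().to(torch.float32)\n",
"    return torch.clamp(amax, min=1e-12) / FP8_E4M3_MAX\n",
"\n",
"a_scale = fp8_scale(a)\n",
"b_scale = fp8_scale(b)\n",
"a_fp8 = (a / a_scale).to(torch.float8_e4m3fn)\n",
"# Scale before transposing so the rhs stays column major.\n",
"b_fp8 = (b.t() / b_scale).to(torch.float8_e4m3fn).t()\n",
"\n",
"out = torch._scaled_mm(a_fp8, b_fp8, out_dtype=torch.float16,\n",
"                       scale_a=a_scale, scale_b=b_scale)\n",
"out.shape, out.dtype"
]
},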
{
"cell_type": "code",
"execution_count": null,
"id": "6d3ef1a6-4322-4f87-901a-7e54185cd4f5",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
