diff --git a/docs/PyTorch FP8 matmul tutorial.ipynb b/docs/PyTorch FP8 matmul tutorial.ipynb new file mode 100644 index 0000000..b373f39 --- /dev/null +++ b/docs/PyTorch FP8 matmul tutorial.ipynb @@ -0,0 +1,144 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "7ae3e6c9-01d2-4a34-a4a8-88d36c7e9b3f", + "metadata": {}, + "source": [ + "# PyTorch FP8 (fused) matmul tutorial" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "4c9500fc-648d-46d3-95ea-e74a0ee43fe6", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(device(type='cuda', index=0), 'NVIDIA H100 PCIe')" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import numpy as np\n", + "import torch\n", + "\n", + "# Local GPU device\n", + "torch.device(0), torch.cuda.get_device_name(0)" + ] + }, + { + "cell_type": "markdown", + "id": "bdbfe673-a8d8-4ba5-afb5-3c4f7eb5b0e7", + "metadata": {}, + "source": [ + "### `scaled_mm` FP8 matmul wrapper\n", + "\n", + "PyTorch `scaled_mm` defintion: https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/cuda/Blas.cpp#L1176C1-L1176C16\n", + "\n", + "`cublasLtMatmul` not supported `E5M2 @ E5M2` matmuls: https://docs.nvidia.com/cuda/cublas/index.html?highlight=fp8#cublasltmatmul " + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "87bcf537-3c09-4241-8ab7-f5c2a55c3ed2", + "metadata": {}, + "outputs": [ + { + "ename": "RuntimeError", + "evalue": "Multiplication of two Float8_e5m2 matrices is not supported", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[31], line 18\u001b[0m\n\u001b[1;32m 15\u001b[0m b_scale \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mones((), dtype\u001b[38;5;241m=\u001b[39mtorch\u001b[38;5;241m.\u001b[39mfloat32, device\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mcuda\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 17\u001b[0m \u001b[38;5;66;03m# FP8 matmul\u001b[39;00m\n\u001b[0;32m---> 18\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_scaled_mm\u001b[49m\u001b[43m(\u001b[49m\u001b[43ma_fp8\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mb_fp8\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\n\u001b[1;32m 19\u001b[0m \u001b[43m \u001b[49m\u001b[43mout_dtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfloat16\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 20\u001b[0m \u001b[43m \u001b[49m\u001b[43mscale_a\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43ma_scale\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 21\u001b[0m \u001b[43m \u001b[49m\u001b[43mscale_b\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mb_scale\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 22\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_fast_accum\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 23\u001b[0m \u001b[43m \u001b[49m\u001b[43mbias\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 24\u001b[0m \u001b[43m \u001b[49m\u001b[43mscale_result\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\n", + "\u001b[0;31mRuntimeError\u001b[0m: Multiplication of two Float8_e5m2 matrices is not supported" + ] + } + ], + "source": [ + "M, N, K = 128, 64, 256\n", + "\n", + "a = torch.randn((M, K), dtype=torch.float16, device='cuda')\n", + "# Transpose as cuBLASLt requires column major on `rhs`\n", + "b = torch.randn((N, K), dtype=torch.float16, device='cuda').t()\n", + "\n", + "# FP8 inputs & scales\n", + "# a_fp8 = a.to(torch.float8_e4m3fn)\n", + "# b_fp8 = b.to(torch.float8_e4m3fn)\n", + "\n", + "a_fp8 = a.to(torch.float8_e5m2)\n", + "b_fp8 = b.to(torch.float8_e5m2)\n", + "\n", + "a_scale = torch.ones((), dtype=torch.float32, device='cuda')\n", + "b_scale = torch.ones((), dtype=torch.float32, device='cuda')\n", + "\n", + "# FP8 matmul\n", + "out = torch._scaled_mm(a_fp8, b_fp8, \n", + " out_dtype=torch.float16,\n", + " scale_a=a_scale,\n", + " scale_b=b_scale,\n", + " use_fast_accum=True,\n", + " bias=None,\n", + " scale_result=None)" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "50a320ec-769e-4dc8-b933-29610918d395", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(torch.Size([128, 64]), torch.float16)" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "out.shape, out.dtype" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6d3ef1a6-4322-4f87-901a-7e54185cd4f5", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}