diff --git a/docs/JAX FP8 matmul tutorial.ipynb b/docs/JAX FP8 matmul tutorial.ipynb index 8c89195..dffd56d 100644 --- a/docs/JAX FP8 matmul tutorial.ipynb +++ b/docs/JAX FP8 matmul tutorial.ipynb @@ -115,7 +115,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 60, "id": "9be90f27-5520-45f6-a42d-b309572e6e91", "metadata": {}, "outputs": [ @@ -140,10 +140,10 @@ "print(\"E4M3 @ E4M3 FP8 matmul output:\", c.aval)\n", "\n", "# E4M3/E5M2 mixed matrix multiplication (NOTE: transpose to reduce on last axis).\n", - "b = jax.random.normal(key, (128, 64), jnp.float8_e5m2)\n", - "c = jax.lax.dot(a, b.T)\n", + "c = jax.random.normal(key, (128, 64), jnp.float8_e5m2)\n", + "d = jax.lax.dot(a, c.T)\n", "# Note: default output dtype is E5M2.\n", - "print(\"E4M3 @ E5M2 FP8 matmul output:\", c.aval)" + "print(\"E4M3 @ E5M2 FP8 matmul output:\", d.aval)" ] }, { @@ -158,7 +158,7 @@ }, { "cell_type": "code", - "execution_count": 47, + "execution_count": 61, "id": "7edfa758-bf4e-49fa-8c5d-5dc9c0c2c346", "metadata": {}, "outputs": [ @@ -166,14 +166,14 @@ "name": "stdout", "output_type": "stream", "text": [ - "HloModule jit_matmul_fn, is_scheduled=true, entry_computation_layout={(f8e4m3fn[32,64]{1,0}, f8e5m2[128,64]{1,0})->f8e5m2[32,128]{1,0}}, allow_spmd_sharding_propagation_to_parameters={true,true}, allow_spmd_sharding_propagation_to_output={true}, frontend_attributes={fingerprint_before_lhs=\"7015b3550208fe306d90cad5cc7c304f\"}\n", + "HloModule jit_matmul_fn, is_scheduled=true, entry_computation_layout={(f8e4m3fn[32,64]{1,0}, f8e4m3fn[128,64]{1,0})->f8e4m3fn[32,128]{1,0}}, allow_spmd_sharding_propagation_to_parameters={true,true}, allow_spmd_sharding_propagation_to_output={true}, frontend_attributes={fingerprint_before_lhs=\"0de13023ed9307a3f8d7f4d22d81a638\"}\n", "\n", - "ENTRY %main.5 (Arg_0.1.0: f8e4m3fn[32,64], Arg_1.2.0: f8e5m2[128,64]) -> f8e5m2[32,128] {\n", + "ENTRY %main.5 (Arg_0.1.0: f8e4m3fn[32,64], Arg_1.2.0: f8e4m3fn[128,64]) -> f8e4m3fn[32,128] {\n", " %constant_1 = f32[] constant(1)\n", - " %Arg_1.2.0 = f8e5m2[128,64]{1,0} parameter(1)\n", + " %Arg_1.2.0 = f8e4m3fn[128,64]{1,0} parameter(1)\n", " %Arg_0.1.0 = f8e4m3fn[32,64]{1,0} parameter(0)\n", - " %cublas-gemm.1.0 = (f8e5m2[32,128]{1,0}, s8[33554432]{0}) custom-call(f8e4m3fn[32,64]{1,0} %Arg_0.1.0, f8e5m2[128,64]{1,0} %Arg_1.2.0, f32[] %constant_1, f32[] %constant_1, f32[] %constant_1, /*index=5*/f32[] %constant_1), custom_call_target=\"__cublas$lt$matmul$f8\"\n", - " ROOT %get-tuple-element.1 = f8e5m2[32,128]{1,0} get-tuple-element((f8e5m2[32,128]{1,0}, s8[33554432]{0}) %cublas-gemm.1.0), index=0\n", + " %cublas-gemm.1.0 = (f8e4m3fn[32,128]{1,0}, s8[33554432]{0}) custom-call(f8e4m3fn[32,64]{1,0} %Arg_0.1.0, f8e4m3fn[128,64]{1,0} %Arg_1.2.0, f32[] %constant_1, f32[] %constant_1, f32[] %constant_1, /*index=5*/f32[] %constant_1), custom_call_target=\"__cublas$lt$matmul$f8\"\n", + " ROOT %get-tuple-element.1 = f8e4m3fn[32,128]{1,0} get-tuple-element((f8e4m3fn[32,128]{1,0}, s8[33554432]{0}) %cublas-gemm.1.0), index=0\n", "}\n", "\n", "\n" @@ -205,7 +205,7 @@ }, { "cell_type": "code", - "execution_count": 48, + "execution_count": 62, "id": "72d805ea-89b6-457d-9558-ff31fdd23d35", "metadata": {}, "outputs": [ @@ -240,7 +240,7 @@ ] }, "rhs_stride": "8192", - "selected_algorithm": "0" + "selected_algorithm": "2" }, "operation_queue_id": "0", "wait_on_operation_queues": [] @@ -249,7 +249,7 @@ "" ] }, - "execution_count": 48, + "execution_count": 62, "metadata": { "application/json": { "expanded": true, @@ -291,60 +291,142 @@ 
"source": [ "## Fused FP8 matmul in JAX: from simple to complicated!\n", "\n", - "As presented in the literature, using FP8 for matrix multiplication in machine learning models is not as simple as FP16 and BF16. Due to the smaller dynamic range, maintaining model accuracy requires inputs to be properly scaled to avoid overflow or underflow." + "As presented above, the FP8 XLA custom target `__cublas$lt$matmul$f8` has an extended API & config allowing *fusing** multiple operations in the GEMM kernel. More specifically:\n", + "* Scaling of input & output tensors;\n", + "* Capturing absolute-maximum of the output (usually called `damax`);\n", + "* Post-matmul bias or/and non-linearity;\n", + "\n", + "We present below how to generate the proper fused matmul call directly from JAX (and checking in the compiled HLO!). Starting with inputs & outputs scaling, following the interface of `__cublas$lt$matmul$f8`." ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 83, "id": "1ed9d08e-b18a-4fe7-bcba-72b95ddf6e68", "metadata": {}, - "outputs": [], - "source": [] + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "<<< JAX compilation error >>>\n", + "Input dtypes ('float8_e4m3fn', 'float32') have no available implicit dtype promotion path. To avoid unintended promotion, 8-bit floats do not support implicit promotion. If you'd like your inputs to be promoted to another type, you can do so explicitly using e.g. x.astype('float32')\n" + ] + } + ], + "source": [ + "def matmul_fn_with_scale(a_fp8, b_fp8, a_scale, b_scale, c_scale):\n", + " # First try: can we just scale the input with an FP32 scalar?\n", + " a_fp8 = a_fp8 * a_scale\n", + " out = jax.lax.dot(a_fp8, b_fp8.T)\n", + " return out\n", + "\n", + "# `__cublas$lt$matmul$f8` expecting FP32 scales.\n", + "scale_aval = jax.core.ShapedArray((), jnp.float32)\n", + "try:\n", + " fn_compiled = jax.jit(matmul_fn_with_scale).lower(a, b, scale_aval, scale_aval, scale_aval).compile()\n", + "except Exception as e:\n", + " # Issue: JAX does not support implicit mixed-multiplication FP8 x FP32\n", + " print(f\"<<< JAX compilation error >>>\\n{e}\")" + ] }, { - "cell_type": "code", - "execution_count": 42, - "id": "e238ba4d-d749-477b-9ce9-f457a17a75a7", + "cell_type": "markdown", + "id": "57c2a7de-2e89-40b2-b38c-9b8876472d8a", "metadata": {}, - "outputs": [], "source": [ - "import jax\n", - "import jax.numpy as jnp\n" + "### FP8 matmul with scaled inputs & outputs\n", + "\n", + "JAX and XLA do not allow implicit conversion between FP8 and FP32, meaning that we need to write something a bit more explicit for the XLA compiler to pattern match and generate the fused call. More specifically, as presented in XLA FP8 RFC, one needs to adopt a dequantization/quantization type of semantics:\n", + "* Upcast inputs to `float32` and then scale;\n", + "* Scale output, clamp to `float8` range (not optional!) and then downcast to `float8`;\n", + "\n", + "As presented below, when using this pattern, the XLA compiler is able to fuse all the operations into a single call of `__cublas$lt$matmul$f8`." 
] }, { "cell_type": "code", - "execution_count": 46, - "id": "2893288d-7f2a-42e1-8541-afeed1d63a85", + "execution_count": 82, + "id": "b9a608d7-6cf8-457b-8275-bdcacc9b06fe", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "HloModule jit_matmul_fn_with_scale, is_scheduled=true, entry_computation_layout={(f8e4m3fn[32,64]{1,0}, f8e4m3fn[128,64]{1,0}, f32[], f32[], f32[])->f8e4m3fn[32,128]{1,0}}, allow_spmd_sharding_propagation_to_parameters={true,true,true,true,true}, allow_spmd_sharding_propagation_to_output={true}, frontend_attributes={fingerprint_before_lhs=\"0ae55bc5ea38f0523b45347dc49424c4\"}\n", + "\n", + "ENTRY %main.22 (Arg_0.1.0: f8e4m3fn[32,64], Arg_1.2.0: f8e4m3fn[128,64], Arg_2.3.0: f32[], Arg_3.4.0: f32[], Arg_4.5.0: f32[]) -> f8e4m3fn[32,128] {\n", + " %constant_1 = f32[] constant(1)\n", + " %Arg_4.5.0 = f32[] parameter(4)\n", + " %Arg_3.4.0 = f32[] parameter(3)\n", + " %Arg_2.3.0 = f32[] parameter(2)\n", + " %Arg_1.2.0 = f8e4m3fn[128,64]{1,0} parameter(1)\n", + " %Arg_0.1.0 = f8e4m3fn[32,64]{1,0} parameter(0)\n", + " %cublas-gemm.clone.1.0 = (f8e4m3fn[32,128]{1,0}, s8[33554432]{0}) custom-call(f8e4m3fn[32,64]{1,0} %Arg_0.1.0, f8e4m3fn[128,64]{1,0} %Arg_1.2.0, f32[] %Arg_2.3.0, f32[] %Arg_3.4.0, f32[] %constant_1, /*index=5*/f32[] %Arg_4.5.0), custom_call_target=\"__cublas$lt$matmul$f8\"\n", + " ROOT %get-tuple-element.1 = f8e4m3fn[32,128]{1,0} get-tuple-element((f8e4m3fn[32,128]{1,0}, s8[33554432]{0}) %cublas-gemm.clone.1.0), index=0\n", + "}\n", + "\n", + "\n" + ] + } + ], "source": [ - "# Starting with the most simple matmul!\n", - "def matmul_fn(a_fp8, b_fp8):\n", - " # FP8 x FP8 -> FP8 matmul\n", - " return jax.lax.dot(a_fp8, b_fp8)" + "e4m3_max = ml_dtypes.finfo(jnp.float8_e4m3fn).max\n", + "\n", + "# XLA requires a \"dequantize/quantize\" pattern to properly support scaled FP8 inputs/outputs. \n", + "def matmul_fn_with_scale(a_fp8, b_fp8, a_scale, b_scale, d_scale):\n", + " # Dequantize x and y\n", + " a_fp32 = a_fp8.astype(jnp.float32) * a_scale\n", + " b_fp32 = b_fp8.astype(jnp.float32) * b_scale\n", + " \n", + " # Do the matmul (NOTE: adding transpose to simplify HLO).\n", + " d_fp32 = jax.lax.dot(a_fp32, b_fp32.T)\n", + " \n", + " # Rescale & clamp to -max/+max FP8 E4M3 values.\n", + " d_fp32 = d_fp32 * d_scale\n", + " # NOTE: clamping is NOT optional for proper pattern matching!\n", + " d_fp32 = jax.lax.clamp(jnp.float32(-e4m3_max), d_fp32, jnp.float32(e4m3_max))\n", + " # (Re)Quantize the scaled matmul output.\n", + " return d_fp32.astype(jnp.float8_e4m3fn)\n", + "\n", + "# AOT compilation with JAX, inspecting the (final) HLO module generated.\n", + "fn_compiled = jax.jit(matmul_fn_with_scale).lower(a, b, scale_aval, scale_aval, scale_aval).compile()\n", + "# (Human readable) optimized Hlo module generated by XLA (ignoring GEMM backend config)\n", + "print_hlo_module(fn_compiled, backend_cfg=False)" + ] + }, + { + "cell_type": "markdown", + "id": "437fefcf-bfb2-42aa-a899-0a57416a6a5e", + "metadata": {}, + "source": [ + "### Adding `relu` to the FP8 matmul\n", + "\n", + "Can we get XLA to fuse the post non-linearity function as well?" 
] }, { "cell_type": "code", - "execution_count": 7, - "id": "a6142f8d-08ee-4fa6-962f-2b85a1bcecb6", + "execution_count": 93, + "id": "44f28bbb-d4c6-4170-a736-76d667d73f97", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "HloModule jit_matmul_fn, is_scheduled=true, entry_computation_layout={(f8e4m3fn[32,64]{1,0}, f8e4m3fn[64,128]{1,0})->f8e4m3fn[32,128]{1,0}}, allow_spmd_sharding_propagation_to_parameters={true,true}, allow_spmd_sharding_propagation_to_output={true}\n", + "HloModule jit_matmul_fn_with_scale, is_scheduled=true, entry_computation_layout={(f8e4m3fn[32,64]{1,0}, f8e4m3fn[128,64]{1,0}, f32[], f32[], f32[])->f8e4m3fn[32,128]{1,0}}, allow_spmd_sharding_propagation_to_parameters={true,true,true,true,true}, allow_spmd_sharding_propagation_to_output={true}, frontend_attributes={fingerprint_before_lhs=\"f1fb5db9dad54941d7d17e04fdbe9515\"}\n", "\n", - "ENTRY %main.4 (Arg_0.1: f8e4m3fn[32,64], Arg_1.2: f8e4m3fn[64,128]) -> f8e4m3fn[32,128] {\n", - " %Arg_1.2 = f8e4m3fn[64,128]{1,0} parameter(1)\n", - " %convert.4 = f32[64,128]{1,0} convert(f8e4m3fn[64,128]{1,0} %Arg_1.2)\n", - " %Arg_0.1 = f8e4m3fn[32,64]{1,0} parameter(0)\n", - " %convert.3 = f32[32,64]{1,0} convert(f8e4m3fn[32,64]{1,0} %Arg_0.1)\n", - " %dot.0 = f32[32,128]{1,0} dot(f32[32,64]{1,0} %convert.3, f32[64,128]{1,0} %convert.4), lhs_contracting_dims={1}, rhs_contracting_dims={0}\n", - " ROOT %convert.2 = f8e4m3fn[32,128]{1,0} convert(f32[32,128]{1,0} %dot.0)\n", + "ENTRY %main.28 (Arg_0.1.0: f8e4m3fn[32,64], Arg_1.2.0: f8e4m3fn[128,64], Arg_2.3.0: f32[], Arg_3.4.0: f32[], Arg_4.5.0: f32[]) -> f8e4m3fn[32,128] {\n", + " %constant_1_0 = f32[] constant(1)\n", + " %Arg_4.5.0 = f32[] parameter(4)\n", + " %Arg_3.4.0 = f32[] parameter(3)\n", + " %Arg_2.3.0 = f32[] parameter(2)\n", + " %Arg_1.2.0 = f8e4m3fn[128,64]{1,0} parameter(1)\n", + " %Arg_0.1.0 = f8e4m3fn[32,64]{1,0} parameter(0)\n", + " %cublas-gemm.2.clone.1.0 = (f8e4m3fn[32,128]{1,0}, s8[33554432]{0}) custom-call(f8e4m3fn[32,64]{1,0} %Arg_0.1.0, f8e4m3fn[128,64]{1,0} %Arg_1.2.0, f32[] %Arg_2.3.0, f32[] %Arg_3.4.0, f32[] %constant_1_0, /*index=5*/f32[] %Arg_4.5.0), custom_call_target=\"__cublas$lt$matmul$f8\"\n", + " ROOT %get-tuple-element.1 = f8e4m3fn[32,128]{1,0} get-tuple-element((f8e4m3fn[32,128]{1,0}, s8[33554432]{0}) %cublas-gemm.2.clone.1.0), index=0\n", "}\n", "\n", "\n" @@ -352,15 +434,111 @@ } ], "source": [ - "a_aval = jax.core.ShapedArray((32, 64), jnp.float8_e4m3fn)\n", - "b_aval = jax.core.ShapedArray((64, 128), jnp.float8_e4m3fn)\n", + "e4m3_max = ml_dtypes.finfo(jnp.float8_e4m3fn).max\n", + "\n", + "# XLA requires a \"dequantize/quantize\" pattern to properly support scaled FP8 inputs/outputs. \n", + "def matmul_fn_with_scale(a_fp8, b_fp8, a_scale, b_scale, d_scale):\n", + " # Dequantize x and y\n", + " a_fp32 = a_fp8.astype(jnp.float32) * a_scale\n", + " b_fp32 = b_fp8.astype(jnp.float32) * b_scale\n", + " \n", + " # Do the matmul (NOTE: adding transpose to simplify HLO).\n", + " d_fp32 = jax.lax.dot(a_fp32, b_fp32.T)\n", + " # ReLU non-linearity. 
NOTE: it needs to come before the output scaling.\n", + " d_fp32 = jax.nn.relu(d_fp32)\n", + " \n", + " # Rescale & clamp to -max/+max FP8 E4M3 values.\n", + " d_fp32 = d_fp32 * d_scale\n", + " # NOTE: clamping is NOT optional for proper pattern matching!\n", + " d_fp32 = jax.lax.clamp(jnp.float32(-e4m3_max), d_fp32, jnp.float32(e4m3_max))\n", + " # (Re)Quantize the scaled matmul output.\n", + " return d_fp32.astype(jnp.float8_e4m3fn)\n", "\n", "# AOT compilation with JAX, inspecting the (final) HLO module generated.\n", - "fn_compiled = jax.jit(matmul_fn).lower(a_aval, b_aval).compile()\n", + "fn_compiled = jax.jit(matmul_fn_with_scale).lower(a, b, scale_aval, scale_aval, scale_aval).compile()\n", "# (Human readable) optimized Hlo module generated by XLA (ignoring GEMM backend config)\n", "print_hlo_module(fn_compiled, backend_cfg=False)" ] }, + { + "cell_type": "code", + "execution_count": 95, + "id": "2ca21eae-8b0c-454b-b670-1ef0d5935a5c", + "metadata": {}, + "outputs": [ + { + "data": { + "application/json": { + "force_earliest_schedule": false, + "gemm_backend_config": { + "alpha_imag": 0, + "alpha_real": 1, + "beta": 0, + "damax_output": false, + "dot_dimension_numbers": { + "lhs_batch_dimensions": [], + "lhs_contracting_dimensions": [ + "1" + ], + "rhs_batch_dimensions": [], + "rhs_contracting_dimensions": [ + "1" + ] + }, + "epilogue": "RELU", + "grad_x": false, + "grad_y": false, + "lhs_stride": "2048", + "precision_config": { + "algorithm": "ALG_UNSET", + "operand_precision": [ + "DEFAULT", + "DEFAULT" + ] + }, + "rhs_stride": "8192", + "selected_algorithm": "2" + }, + "operation_queue_id": "0", + "wait_on_operation_queues": [] + }, + "text/plain": [ + "" + ] + }, + "execution_count": 95, + "metadata": { + "application/json": { + "expanded": true, + "root": "root" + } + }, + "output_type": "execute_result" + } + ], + "source": [ + "hlo_module = parse_hlo_module(fn_compiled)\n", + "backend_config = next((m.backend_config for m in hlo_module if \"__cublas$lt$matmul$f8\" in m.cmd))\n", + "# Check that the epilogue is set to `RELU`.\n", + "JSON(backend_config, expanded=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "id": "2893288d-7f2a-42e1-8541-afeed1d63a85", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a6142f8d-08ee-4fa6-962f-2b85a1bcecb6", + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "markdown", "id": "89cc3c24-70bb-4b03-b207-4ac304621579", "metadata": {}, @@ -406,27 +584,31 @@ }, { "cell_type": "code", - "execution_count": 85, + "execution_count": 69, "id": "6abbed6c-2a06-4942-becb-da9c7b7b79e7", "metadata": {}, - "outputs": [], - "source": [ - "# XLA requires a \"dequantize/quantize\" pattern to properly support scaled FP8 inputs/outputs. 
\n", - "def matmul_fn_with_scale(a_fp8, b_fp8, a_scale, b_scale, d_scale):\n", - " # Dequantize x and y\n", - " a_fp32 = a_fp8.astype(jnp.float32) * a_scale\n", - " b_fp32 = b_fp8.astype(jnp.float32) * b_scale\n", - " \n", - " # Do the matmul (NOTE: adding transpose to simplify HLO).\n", - " d_fp32 = jax.lax.dot(a_fp32, b_fp32.transpose())\n", - " \n", - " # Rescale & clamp to -max/+max FP8 E4M3 values.\n", - " d_fp32 = d_fp32 * d_scale\n", - " # NOTE: clamping is NOT optional for proper pattern matching!\n", - " d_fp32 = jax.lax.clamp(jnp.float32(-448), d_fp32, jnp.float32(448))\n", - " # (Re)Quantize the scaled matmul output.\n", - " return d_fp32.astype(jnp.float8_e4m3fn)" - ] + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "HloModule jit_matmul_fn_with_scale, is_scheduled=true, entry_computation_layout={(f8e4m3fn[32,64]{1,0}, f8e4m3fn[128,64]{1,0}, f32[], f32[])->f8e4m3fn[32,128]{1,0}}, allow_spmd_sharding_propagation_to_parameters={true,true,true,true}, allow_spmd_sharding_propagation_to_output={true}, frontend_attributes={fingerprint_before_lhs=\"8d4518a08891b2bbb34d71d34d902ec1\"}\n", + "\n", + "ENTRY %main.19 (Arg_0.1.0: f8e4m3fn[32,64], Arg_1.2.0: f8e4m3fn[128,64], Arg_2.3.0: f32[], Arg_3.4.0: f32[]) -> f8e4m3fn[32,128] {\n", + " %constant_1 = f32[] constant(1)\n", + " %Arg_3.4.0 = f32[] parameter(3)\n", + " %Arg_2.3.0 = f32[] parameter(2)\n", + " %Arg_1.2.0 = f8e4m3fn[128,64]{1,0} parameter(1)\n", + " %Arg_0.1.0 = f8e4m3fn[32,64]{1,0} parameter(0)\n", + " %cublas-gemm.clone.1.0 = (f8e4m3fn[32,128]{1,0}, s8[33554432]{0}) custom-call(f8e4m3fn[32,64]{1,0} %Arg_0.1.0, f8e4m3fn[128,64]{1,0} %Arg_1.2.0, f32[] %Arg_2.3.0, f32[] %Arg_3.4.0, f32[] %constant_1, /*index=5*/f32[] %constant_1), custom_call_target=\"__cublas$lt$matmul$f8\"\n", + " ROOT %get-tuple-element.1 = f8e4m3fn[32,128]{1,0} get-tuple-element((f8e4m3fn[32,128]{1,0}, s8[33554432]{0}) %cublas-gemm.clone.1.0), index=0\n", + "}\n", + "\n", + "\n" + ] + } + ], + "source": [] }, { "cell_type": "code",