Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AutoScale quickstart notebook. #90

Merged
merged 1 commit into from
Jan 17, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
269 changes: 269 additions & 0 deletions examples/autoscale-quickstart.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,269 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "7c85dead-5274-487c-91ff-7137fbaca393",
"metadata": {},
"source": [
"# JAX Scaled Arithmetics / AutoScale quickstart\n",
"\n",
"**JAX Scaled Arithmetics** is a thin library implementing numerically stable scaled arithmetics, allowing easy training and inference of\n",
"deep neural networks in low precision (BF16, FP16, FP8) with full scale propagation."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "30940677-4296-40fa-b418-351fcfb62098",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import jax\n",
"import jax_scaled_arithmetics as jsa"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e0e729aa-7a81-4001-8a34-9a00ec822948",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 2,
"id": "f374e654-97e4-43ef-902a-a890d36a52b9",
"metadata": {},
"outputs": [],
"source": [
"# `autoscale` interpreter is tracing the graph, adding scale propagation where necessary.\n",
"@jsa.autoscale\n",
"def fn(a, b):\n",
" return a + b"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "8c59245d-27e5-41a7-bfef-f40849a7b550",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"INPUTS: [1. 2.] [3. 6.]\n",
"OUTPUT: [4. 8.] <class 'jaxlib.xla_extension.DeviceArray'>\n"
]
}
],
"source": [
"# Let's start with standard JAX inputs\n",
"a = np.array([1, 2], np.float16)\n",
"b = np.array([3, 6], np.float16)\n",
"out = fn(a, b)\n",
"\n",
"print(\"INPUTS:\", a, b)\n",
"# No scaled arithmetics => \"normal\" JAX mode.\n",
"print(\"OUTPUT:\", out, type(out))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e60cf138-d92d-4ab9-89d4-bacc9e28c39f",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 4,
"id": "e7efaa2e-00a1-40e8-9bbb-685edc975636",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"SCALED inputs: ScaledArray(data=array([1., 2.], dtype=float16), scale=1.0) ScaledArray(data=array([1.5, 3. ], dtype=float16), scale=2.0)\n",
"UNSCALED inputs: [1. 2.] [3. 6.]\n"
]
}
],
"source": [
"# Let's create input scaled arrays.\n",
    "# NOTE: the scale dtype does not have to match the core data dtype.\n",
"sa = jsa.as_scaled_array(a, scale=np.float32(1))\n",
"sb = jsa.as_scaled_array(b, scale=np.float32(2))\n",
"\n",
"print(\"SCALED inputs:\", sa, sb)\n",
    "# `as_scaled_array` does not change the value of the tensor:\n",
"print(\"UNSCALED inputs:\", np.asarray(sa), np.asarray(sb))"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "1f457243-a0b8-4e4d-b45d-7444d0566b37",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"SCALED OUTPUT: ScaledArray(data=DeviceArray([2., 4.], dtype=float16), scale=DeviceArray(2., dtype=float32))\n",
"No scale rounding: ScaledArray(data=DeviceArray([1.789, 3.578], dtype=float16), scale=DeviceArray(2.236068, dtype=float32))\n"
]
}
],
"source": [
    "# Running `fn` on scaled arrays triggers the `autoscale` graph transformation.\n",
"sout = fn(sa, sb)\n",
    "# NOTE: by default, scale propagation uses power-of-2 rounding.\n",
"print(\"SCALED OUTPUT:\", sout)\n",
"\n",
"# To choose a different scale rounding:\n",
"with jsa.AutoScaleConfig(rounding_mode=jsa.Pow2RoundMode.NONE):\n",
" print(\"No scale rounding:\", fn(sa, sb))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c2429c10-00d9-44f8-b0b6-a1fdf13ed264",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 6,
"id": "307ee27d-6ed2-4ab6-a152-83947dbf47fd",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"RESCALED OUTPUT: ScaledArray(data=DeviceArray([0.5, 1. ], dtype=float16), scale=DeviceArray(8., dtype=float32))\n"
]
},
{
"data": {
"text/plain": [
"functools.partial(<jax._src.custom_derivatives.custom_vjp object at 0x7fc7b337c4c0>, <function dynamic_rescale_l1_base at 0x7fc7b3380430>)"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
    "# JAX Scaled Arithmetics offers basic dynamic rescaling methods, e.g. max, l1 and l2.\n",
"sout_rescaled = jsa.ops.dynamic_rescale_max(sout)\n",
"print(\"RESCALED OUTPUT:\", sout_rescaled)\n",
"\n",
"# Equivalent methods are available to dynamically rescale gradients:\n",
"jsa.ops.dynamic_rescale_l1_grad"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "32930d15-d7ff-41d1-85be-eee958bb0741",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# NOTE: in normal JAX mode, these rescale operations are no-ops:\n",
"jsa.ops.dynamic_rescale_max(a) is a"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ea5942e7-0279-4dc4-a720-b8c7323ab6a1",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 8,
"id": "9920f44a-26e2-4e20-89c3-4e2b2548239f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ScaledArray(data=DeviceArray([16., 20.], dtype=float32), scale=1.0)\n"
]
}
],
"source": [
"import ml_dtypes\n",
    "# Minimal simulated FP8 support is provided using jax.lax.reduce_precision and ml_dtypes.\n",
    "# Similarly to `dynamic_rescale`, `cast_ml_dtype(_grad)` methods are available to cast in the forward and backward passes.\n",
"sc = jsa.as_scaled_array(np.array([17., 19.]), scale=np.float32(1))\n",
"\n",
"@jsa.autoscale\n",
"def cast_fn(v):\n",
" return jsa.ops.cast_ml_dtype(v, ml_dtypes.float8_e4m3fn)\n",
"\n",
"sc_fp8 = cast_fn(sc)\n",
"print(sc_fp8)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1bd7c1d5-4ea2-4ded-a066-818d9146b8a6",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}