diff --git a/notebooks/log-softmax-analysis.ipynb b/notebooks/log-softmax-analysis.ipynb new file mode 100644 index 0000000..7f9b080 --- /dev/null +++ b/notebooks/log-softmax-analysis.ipynb @@ -0,0 +1,335 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 36, + "id": "40b36be1-307a-437b-a401-8411407993f0", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "\n", + "import jax\n", + "import jax.lax as lax\n", + "import jax.nn\n", + "import jax.numpy as jnp\n", + "import jax_scaled_arithmetics as jsa" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "id": "fddd3b72-ce99-4e96-bae4-b4a884e749e7", + "metadata": {}, + "outputs": [], + "source": [ + "B = 128\n", + "N = 10" + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "id": "9462a6fe-5ad4-459d-8be5-b24f0f8fe7af", + "metadata": {}, + "outputs": [], + "source": [ + "act = np.random.randn(B, N).astype(np.float32)" + ] + }, + { + "cell_type": "code", + "execution_count": 63, + "id": "80c21df4-2715-4404-ab5d-3ddceb34f4e3", + "metadata": {}, + "outputs": [], + "source": [ + "def logsumexp(a, axis=None, keepdims=True):\n", + " dims = (axis,)\n", + " amax = jnp.max(a, axis=dims, keepdims=keepdims)\n", + " # FIXME: not proper scale propagation, introducing NaNs\n", + " # amax = lax.stop_gradient(lax.select(jnp.isfinite(amax), amax, lax.full_like(amax, 0)))\n", + " amax = lax.stop_gradient(amax)\n", + " out = lax.sub(a, amax)\n", + " out = lax.exp(out)\n", + " out = lax.add(lax.log(jnp.sum(out, axis=dims, keepdims=keepdims)), amax)\n", + " return out\n" + ] + }, + { + "cell_type": "code", + "execution_count": 64, + "id": "84349e45-5ce7-4fd2-9449-9d30c48de291", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(128, 1)" + ] + }, + "execution_count": 64, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "def fn(act):\n", + " return logsumexp(act, axis=1)\n", + "\n", + "fn(act).shape" + ] + }, + { + "cell_type": "code", + "execution_count": 58, + "id": "44aeddb0-948f-4da7-b55e-fa90f3548a97", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Tracedwith\n" + ] + }, + { + "data": { + "text/plain": [ + "{ lambda ; a:f32[128,10]. let\n", + " b:f32[128] = reduce_max[axes=(1,)] a\n", + " c:f32[128,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(128, 1)] b\n", + " d:f32[128,1] = stop_gradient c\n", + " e:f32[128,10] = sub a d\n", + " f:f32[128,10] = exp e\n", + " g:f32[128] = reduce_sum[axes=(1,)] f\n", + " h:f32[128,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(128, 1)] g\n", + " i:f32[128,1] = log h\n", + " j:f32[128,1] = add i d\n", + " in (j,) }" + ] + }, + "execution_count": 58, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "jax.make_jaxpr(fn)(act)" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "id": "15cf4d4e-e367-4bc5-836c-d9e5b13ea3c9", + "metadata": {}, + "outputs": [], + "source": [ + "out, fn_vjp = jax.vjp(fn, act)" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "id": "6f0d8366-1e44-4795-babc-2e009ed111d7", + "metadata": {}, + "outputs": [], + "source": [ + "def fn_with_grad(in_act, out_grad):\n", + " out_act, fn_vjp = jax.vjp(fn, in_act)\n", + " return out_act, fn_vjp(out_grad)" + ] + }, + { + "cell_type": "code", + "execution_count": 66, + "id": "097a08b4-1a61-4dd6-84f1-f4105a53d9e2", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{ lambda ; a:f32[128,10] b:f32[128,1]. let\n", + " c:f32[128] = reduce_max[axes=(1,)] a\n", + " d:f32[128,1] = reshape[dimensions=None new_sizes=(128, 1)] c\n", + " e:bool[128,10] = eq a d\n", + " f:f32[128,10] = convert_element_type[new_dtype=float32 weak_type=False] e\n", + " _:f32[128] = reduce_sum[axes=(1,)] f\n", + " g:f32[128,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(128, 1)] c\n", + " h:f32[128,1] = stop_gradient g\n", + " i:f32[128,10] = sub a h\n", + " j:f32[128,10] = exp i\n", + " k:f32[128] = reduce_sum[axes=(1,)] j\n", + " l:f32[128,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(128, 1)] k\n", + " m:f32[128,1] = log l\n", + " n:f32[128,1] = add m h\n", + " o:f32[128,1] = div b l\n", + " p:f32[128] = reduce_sum[axes=(1,)] o\n", + " q:f32[128,10] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(128, 10)] p\n", + " r:f32[128,10] = mul q j\n", + " in (n, r) }" + ] + }, + "execution_count": 66, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "jax.make_jaxpr(fn_with_grad)(act, act[:, :1])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "07782137-deb5-4a34-805c-209d68f86880", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1a709ffe-3bca-4a5b-b96c-50cebf8f4dd1", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 72, + "id": "084098be-2729-4080-b484-fc98fc2febd9", + "metadata": {}, + "outputs": [], + "source": [ + "def fn2(x, y):\n", + " return x * y\n", + "\n", + "\n", + "def fn2_with_grad(in_act, out_grad):\n", + " out_act, fn_vjp = jax.vjp(fn2, in_act, in_act)\n", + " return out_act, fn_vjp(out_grad)" + ] + }, + { + "cell_type": "code", + "execution_count": 71, + "id": "c922d288-e938-4de0-a46a-edb48df6d3c0", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{ lambda ; a:f32[128,10] b:f32[128,10]. let\n", + " c:f32[128,10] = mul a a\n", + " d:f32[128,10] = mul a b\n", + " e:f32[128,10] = mul b a\n", + " in (c, e, d) }" + ] + }, + "execution_count": 71, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "jax.make_jaxpr(fn2_with_grad)(act, act)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9c2f08c3-05b5-4661-81a4-1abfc1a4e625", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 79, + "id": "bd63a6ea-2ab0-4f0f-81cd-8ce4e25ef3b9", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "((128, 10), dtype('float32'))" + ] + }, + "execution_count": 79, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "act.shape, act.dtype" + ] + }, + { + "cell_type": "code", + "execution_count": 75, + "id": "4e8e0182-2bf6-415b-b45e-fa12ac355db5", + "metadata": {}, + "outputs": [], + "source": [ + "def fn3(x):\n", + " return jax.grad(lambda x: jnp.mean(x))(x)" + ] + }, + { + "cell_type": "code", + "execution_count": 81, + "id": "34dcb443-e0e6-4ded-9ce9-b4bb6660d553", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{ lambda ; a:f32[128,10]. let\n", + " b:f32[] = reduce_sum[axes=(0, 1)] a\n", + " _:f32[] = div b 1280.0\n", + " c:f32[] = div 1.0 1280.0\n", + " d:f32[128,10] = broadcast_in_dim[broadcast_dimensions=() shape=(128, 10)] c\n", + " in (d,) }" + ] + }, + "execution_count": 81, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "jax.make_jaxpr(fn3)(act)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c5228dc2-8058-4c9a-9089-3e44a9b1eeba", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}