Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
balancap committed Jan 18, 2024
1 parent 3c21e05 commit 556641b
Showing 1 changed file with 335 additions and 0 deletions.
335 changes: 335 additions & 0 deletions notebooks/log-softmax-analysis.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,335 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 36,
"id": "40b36be1-307a-437b-a401-8411407993f0",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"import jax\n",
"import jax.lax as lax\n",
"import jax.nn\n",
"import jax.numpy as jnp\n",
"import jax_scaled_arithmetics as jsa"
]
},
{
"cell_type": "code",
"execution_count": 52,
"id": "fddd3b72-ce99-4e96-bae4-b4a884e749e7",
"metadata": {},
"outputs": [],
"source": [
"# Shape of the random test input: B rows by N columns.\n",
"B, N = 128, 10"
]
},
{
"cell_type": "code",
"execution_count": 53,
"id": "9462a6fe-5ad4-459d-8be5-b24f0f8fe7af",
"metadata": {},
"outputs": [],
"source": [
"# Seeded generator so that Restart & Run All reproduces the same activations.\n",
"act = np.random.default_rng(0).standard_normal((B, N)).astype(np.float32)"
]
},
{
"cell_type": "code",
"execution_count": 63,
"id": "80c21df4-2715-4404-ab5d-3ddceb34f4e3",
"metadata": {},
"outputs": [],
"source": [
"def logsumexp(a, axis=None, keepdims=True):\n",
"    \"\"\"Numerically stable log(sum(exp(a))) built from `lax` primitives.\n",
"\n",
"    Args:\n",
"        a: input array.\n",
"        axis: int, tuple/list of ints, or None to reduce over every axis.\n",
"        keepdims: keep the reduced axes as size-1 dimensions in the result.\n",
"\n",
"    Returns:\n",
"        log(sum(exp(a))) reduced over `axis`.\n",
"    \"\"\"\n",
"    # `(axis,)` is only valid for an int: None must stay None (reduce everything)\n",
"    # and a sequence of axes must not be nested inside another tuple.\n",
"    if axis is None:\n",
"        dims = None\n",
"    elif isinstance(axis, (tuple, list)):\n",
"        dims = tuple(axis)\n",
"    else:\n",
"        dims = (axis,)\n",
"    # Always keep the reduced dims on the max: lax.sub needs operands of equal rank.\n",
"    amax = jnp.max(a, axis=dims, keepdims=True)\n",
"    # FIXME: not proper scale propagation, introducing NaNs\n",
"    # amax = lax.stop_gradient(lax.select(jnp.isfinite(amax), amax, lax.full_like(amax, 0)))\n",
"    amax = lax.stop_gradient(amax)\n",
"    out = lax.sub(a, amax)\n",
"    out = lax.exp(out)\n",
"    out = lax.log(jnp.sum(out, axis=dims, keepdims=keepdims))\n",
"    # Drop the kept dims from the max only when the caller asked for a reduced result.\n",
"    amax_out = amax if keepdims else jnp.squeeze(amax, axis=dims)\n",
"    return lax.add(out, amax_out)\n"
]
},
{
"cell_type": "code",
"execution_count": 64,
"id": "84349e45-5ce7-4fd2-9449-9d30c48de291",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(128, 1)"
]
},
"execution_count": 64,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def fn(act):\n",
"    \"\"\"Log-sum-exp of `act` along axis 1.\"\"\"\n",
"    row_lse = logsumexp(act, axis=1)\n",
"    return row_lse\n",
"\n",
"# Display the shape of the result.\n",
"fn(act).shape"
]
},
{
"cell_type": "code",
"execution_count": 58,
"id": "44aeddb0-948f-4da7-b55e-fa90f3548a97",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Traced<ShapedArray(float32[128,10])>with<DynamicJaxprTrace(level=1/0)>\n"
]
},
{
"data": {
"text/plain": [
"{ lambda ; a:f32[128,10]. let\n",
" b:f32[128] = reduce_max[axes=(1,)] a\n",
" c:f32[128,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(128, 1)] b\n",
" d:f32[128,1] = stop_gradient c\n",
" e:f32[128,10] = sub a d\n",
" f:f32[128,10] = exp e\n",
" g:f32[128] = reduce_sum[axes=(1,)] f\n",
" h:f32[128,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(128, 1)] g\n",
" i:f32[128,1] = log h\n",
" j:f32[128,1] = add i d\n",
" in (j,) }"
]
},
"execution_count": 58,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Trace `fn` on `act` and display the jaxpr of the forward pass.\n",
"jax.make_jaxpr(fn)(act)"
]
},
{
"cell_type": "code",
"execution_count": 47,
"id": "15cf4d4e-e367-4bc5-836c-d9e5b13ea3c9",
"metadata": {},
"outputs": [],
"source": [
"# jax.vjp returns the forward output of `fn` at `act` together with the function that applies the VJP.\n",
"out, fn_vjp = jax.vjp(fn, act)"
]
},
{
"cell_type": "code",
"execution_count": 48,
"id": "6f0d8366-1e44-4795-babc-2e009ed111d7",
"metadata": {},
"outputs": [],
"source": [
"def fn_with_grad(in_act, out_grad):\n",
"    \"\"\"Run `fn` forward on `in_act` and pull `out_grad` back through it.\n",
"\n",
"    Returns the forward output together with whatever the VJP function\n",
"    returns for `out_grad`.\n",
"    \"\"\"\n",
"    primal_out, pullback = jax.vjp(fn, in_act)\n",
"    in_grads = pullback(out_grad)\n",
"    return primal_out, in_grads"
]
},
{
"cell_type": "code",
"execution_count": 66,
"id": "097a08b4-1a61-4dd6-84f1-f4105a53d9e2",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{ lambda ; a:f32[128,10] b:f32[128,1]. let\n",
" c:f32[128] = reduce_max[axes=(1,)] a\n",
" d:f32[128,1] = reshape[dimensions=None new_sizes=(128, 1)] c\n",
" e:bool[128,10] = eq a d\n",
" f:f32[128,10] = convert_element_type[new_dtype=float32 weak_type=False] e\n",
" _:f32[128] = reduce_sum[axes=(1,)] f\n",
" g:f32[128,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(128, 1)] c\n",
" h:f32[128,1] = stop_gradient g\n",
" i:f32[128,10] = sub a h\n",
" j:f32[128,10] = exp i\n",
" k:f32[128] = reduce_sum[axes=(1,)] j\n",
" l:f32[128,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(128, 1)] k\n",
" m:f32[128,1] = log l\n",
" n:f32[128,1] = add m h\n",
" o:f32[128,1] = div b l\n",
" p:f32[128] = reduce_sum[axes=(1,)] o\n",
" q:f32[128,10] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(128, 10)] p\n",
" r:f32[128,10] = mul q j\n",
" in (n, r) }"
]
},
"execution_count": 66,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# The second argument is passed as `out_grad`: a single-column slice of `act`.\n",
"jax.make_jaxpr(fn_with_grad)(act, act[:, :1])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "07782137-deb5-4a34-805c-209d68f86880",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "1a709ffe-3bca-4a5b-b96c-50cebf8f4dd1",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 72,
"id": "084098be-2729-4080-b484-fc98fc2febd9",
"metadata": {},
"outputs": [],
"source": [
"def fn2(x, y):\n",
"    \"\"\"Return the product of `x` and `y`.\"\"\"\n",
"    product = x * y\n",
"    return product\n",
"\n",
"\n",
"def fn2_with_grad(in_act, out_grad):\n",
"    \"\"\"Forward and backward pass of `fn2`, with `in_act` fed to both of its arguments.\"\"\"\n",
"    primal_out, pullback = jax.vjp(fn2, in_act, in_act)\n",
"    arg_grads = pullback(out_grad)\n",
"    return primal_out, arg_grads"
]
},
{
"cell_type": "code",
"execution_count": 71,
"id": "c922d288-e938-4de0-a46a-edb48df6d3c0",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{ lambda ; a:f32[128,10] b:f32[128,10]. let\n",
" c:f32[128,10] = mul a a\n",
" d:f32[128,10] = mul a b\n",
" e:f32[128,10] = mul b a\n",
" in (c, e, d) }"
]
},
"execution_count": 71,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Display the jaxpr of `fn2_with_grad`; `act` is passed as both the input and the output gradient.\n",
"jax.make_jaxpr(fn2_with_grad)(act, act)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9c2f08c3-05b5-4661-81a4-1abfc1a4e625",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 79,
"id": "bd63a6ea-2ab0-4f0f-81cd-8ce4e25ef3b9",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"((128, 10), dtype('float32'))"
]
},
"execution_count": 79,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Display the shape and dtype of the input array.\n",
"act.shape, act.dtype"
]
},
{
"cell_type": "code",
"execution_count": 75,
"id": "4e8e0182-2bf6-415b-b45e-fa12ac355db5",
"metadata": {},
"outputs": [],
"source": [
"def fn3(x):\n",
"    \"\"\"Gradient of mean(x) with respect to x.\"\"\"\n",
"    mean_grad = jax.grad(jnp.mean)\n",
"    return mean_grad(x)"
]
},
{
"cell_type": "code",
"execution_count": 81,
"id": "34dcb443-e0e6-4ded-9ce9-b4bb6660d553",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{ lambda ; a:f32[128,10]. let\n",
" b:f32[] = reduce_sum[axes=(0, 1)] a\n",
" _:f32[] = div b 1280.0\n",
" c:f32[] = div 1.0 1280.0\n",
" d:f32[128,10] = broadcast_in_dim[broadcast_dimensions=() shape=(128, 10)] c\n",
" in (d,) }"
]
},
"execution_count": 81,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Trace `fn3` on `act` and display the resulting jaxpr.\n",
"jax.make_jaxpr(fn3)(act)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c5228dc2-8058-4c9a-9089-3e44a9b1eeba",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

0 comments on commit 556641b

Please sign in to comment.