Skip to content

Commit

Permalink
Adapt example of class wise CVs
Browse files Browse the repository at this point in the history
  • Loading branch information
knikolaou committed May 10, 2024
1 parent fe1382f commit fadae59
Showing 1 changed file with 73 additions and 19 deletions.
92 changes: 73 additions & 19 deletions examples/Class-wise-Collective-Variables.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,16 @@
"inputs = data_generator.train_ds[\"inputs\"]\n",
"targets = data_generator.train_ds[\"targets\"]\n",
"\n",
"selection_idx = np.where(np.argmax(targets, axis=1) < 4)[0]\n",
"selection_idx_012 = np.where(np.argmax(targets, axis=1) < 3)[0]\n",
"selection_idx_0 = np.where(np.argmax(targets, axis=1) == 7)[0]\n",
"selection_idx = np.concatenate((selection_idx_012, selection_idx_0))\n",
"selection_idx = np.random.permutation(selection_idx)\n",
"ntk_inputs = inputs[selection_idx]\n",
"ntk_targets = targets[selection_idx]\n",
"\n",
"# Quick check\n",
"print(f\"Original data set shape: {inputs.shape}\")\n",
"assert np.argmax(ntk_targets, axis=1).max() == 3\n",
"assert np.argmax(ntk_targets, axis=1).max() == 7 #3\n",
"\n",
"# Set new attribute ntk_ds\n",
"ntk_ds = {\n",
Expand Down Expand Up @@ -118,6 +121,7 @@
" flax_module=Architecture(),\n",
" optimizer=optax.sgd(learning_rate=0.01, momentum=0.9),\n",
" input_shape=(1, 28, 28, 1),\n",
" trace_axes=(),\n",
")"
]
},
Expand All @@ -137,6 +141,7 @@
" covariance_entropy=True,\n",
" magnitude_variance=True, \n",
" trace=True,\n",
" eigenvalues=True,\n",
" loss_derivative=True,\n",
" update_rate=1, \n",
" chunk_size=1e5\n",
Expand Down Expand Up @@ -190,7 +195,7 @@
"outputs": [],
"source": [
"# Show labels-combinations used to compute the inter-class entropy contributions. \n",
"ntk_recorder._class_combinations"
"ntk_recorder._class_combinations, ntk_recorder._class_idx"
]
},
{
Expand All @@ -204,7 +209,7 @@
" train_ds=data_generator.train_ds, \n",
" test_ds=data_generator.test_ds,\n",
" batch_size=128,\n",
" epochs=120,\n",
" epochs=150,\n",
")"
]
},
Expand Down Expand Up @@ -259,15 +264,24 @@
"metadata": {},
"outputs": [],
"source": [
"# Read class specific (cs) data from the ntk_recorder\n",
"\n",
"trace_cs = ntk_recorder.read_class_specific_data(ntk_report.trace)\n",
"covariance_entropy_cs = ntk_recorder.read_class_specific_data(ntk_report.covariance_entropy)\n",
"\n",
"# Reading out the eigenvalues works the same way\n",
"eigenvalues_cs = ntk_recorder.read_class_specific_data(ntk_report.eigenvalues) \n",
"\n",
"\n",
"# Plot the class specific entropy and trace \n",
"\n",
"cmap = plt.get_cmap('rainbow')\n",
"\n",
"fig, axs = plt.subplots(1, 2, figsize=(12, 4))\n",
"\n",
"for i in range(4):\n",
" axs[0].plot(ntk_report.covariance_entropy[:, i], '-', mfc='Entropy', label=f\"Train {i}\", color=cmap(i/3))\n",
" axs[1].plot(ntk_report.trace[:, i], '--', mfc='Trace', label=f\"Train {i}\", color=cmap(i/3))\n",
"for i, l in enumerate([0, 1, 2, 7]):\n",
" axs[0].plot(covariance_entropy_cs[l], '-', mfc='Entropy', label=f\"Train {l}\", color=cmap(i/3))\n",
" axs[1].plot(trace_cs[l], '--', mfc='Trace', label=f\"Train {l}\", color=cmap(i/3))\n",
"\n",
"axs[0].set_xlabel(\"Epoch\")\n",
"axs[1].set_xlabel(\"Epoch\")\n",
Expand All @@ -277,9 +291,12 @@
"# Colorbar with integer ticks\n",
"cbar = fig.colorbar(plt.cm.ScalarMappable(cmap=cmap), ax=axs, orientation='vertical')\n",
"cbar.set_ticks([0, 1/3, 2/3, 1])\n",
"cbar.set_ticklabels([0, 1, 2, 3])\n",
"cbar.set_ticklabels([0, 1, 2, 7])\n",
"\n",
"plt.show()\n",
"\n",
"\n",
"plt.show()\n"
"# 7 Is missing!!! Why?"
]
},
{
Expand All @@ -297,35 +314,72 @@
"\n",
"fig, axs = plt.subplots(1, 2, figsize=(12, 4), tight_layout=True)\n",
"\n",
"for i in range(3):\n",
" axs[0].plot(entropies[:, i+4] - entropies[:, i], '-', label=f\"S(0 - {i+1})\", color=cmap(i/2))\n",
" axs[0].plot(entropies[:, i], '--', label=f\"S({i})\", color=cmap(i/2))\n",
" axs[0].plot(entropies[:, i+4], '_', label=f\"S(0 + {i+1})\", color=cmap(i/2), alpha=0.4)\n",
"# for i in range(3):\n",
"# axs[0].plot(entropies[:, i+4] - entropies[:, i], '-', label=f\"S(0 - {i+1})\", color=cmap(i/2))\n",
"# # axs[0].plot(entropies[:, i], '--', label=f\"S({i})\", color=cmap(i/2))\n",
"# # axs[0].plot(entropies[:, i+4], '_', label=f\"S(0 + {i+1})\", color=cmap(i/2), alpha=0.4)\n",
"\n",
"axs[0].set_xlabel(\"Epoch\")\n",
"axs[0].set_ylabel(\"Entropy\")\n",
"# Put legend outside of plot\n",
"axs[0].legend(loc='center left', bbox_to_anchor=(1, 0.5))\n",
"# # Put legend outside of plot\n",
"# axs[0].legend(loc='center left', bbox_to_anchor=(1, 0.5))\n",
"\n",
"axs[1].plot(entropies[:, :4].sum(axis=1), '-', label=r'$S_{sub} = S(1) + S(2) + S(3) + S(4)$')\n",
"axs[1].plot(entropies[:, :4].sum(axis=1)/4, '-', label=r'$S_{sub} = S(1) + S(2) + S(3) + S(4)$')\n",
"axs[1].plot(entropies[:, -1], '--', label=r'$S_{sys} = S(1 + 2 + 3 + 4)$')\n",
"axs[1].plot(entropies[:, :4].sum(axis=1) - entropies[:,-1], '--', label=r'$S_{sub} - S_{sys}$')\n",
"axs[0].plot(entropies[:, :4].sum(axis=1)/4 - entropies[:,-1], '--', label=r'$S_{sub} - S_{sys}$')\n",
"axs[0].legend()\n",
"axs[1].set_xlabel(\"Epoch\") \n",
"axs[1].legend()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "43a365e5",
"id": "9cd531f1",
"metadata": {},
"outputs": [],
"source": [
"# Plot the mutual information between the classes\n",
"\n",
"cmap = plt.get_cmap('rainbow')\n",
"\n",
"fig, axs = plt.subplots(1, 3, figsize=(12, 4), tight_layout=True, sharey=True)\n",
"\n",
"# axs[0].plot(entropies[:, 0] + entropies[:, i] - entropies[:, i+4], '-', label=f\"Mutual Info (0, {i+1})\", color=cmap(i/2))\n",
"\n",
"axs[0].plot((entropies[:, 0] + entropies[:, 1])/2 - entropies[:, 4], '-', label=f\"Mutual Info (0, 1)\", color=cmap(0))\n",
"axs[0].plot((entropies[:, 0] + entropies[:, 2])/2 - entropies[:, 5], '-', label=f\"Mutual Info (0, 2)\", color=cmap(1/3))\n",
"axs[0].plot((entropies[:, 0] + entropies[:, 3])/2 - entropies[:, 6], '-', label=f\"Mutual Info (0, 7)\", color=cmap(2/3))\n",
"axs[0].legend()\n",
"\n",
"axs[1].plot((entropies[:, 1] + entropies[:, 2])/2 - entropies[:, 7], '-', label=f\"Mutual Info (1, 2)\", color=cmap(0))\n",
"axs[1].plot((entropies[:, 1] + entropies[:, 3])/2 - entropies[:, 8], '-', label=f\"Mutual Info (1, 7)\", color=cmap(1/3))\n",
"axs[1].plot((entropies[:, 0] + entropies[:, 1])/2 - entropies[:, 4], '-', label=f\"Mutual Info (0, 1)\", color=cmap(2/3))\n",
"axs[1].legend()\n",
"\n",
"axs[2].plot((entropies[:, 2] + entropies[:, 3])/2 - entropies[:, 9], '-', label=f\"Mutual Info (2, 7)\", color=cmap(0))\n",
"axs[2].plot((entropies[:, 0] + entropies[:, 2])/2 - entropies[:, 5], '-', label=f\"Mutual Info (0, 2)\", color=cmap(1/3))\n",
"axs[2].plot((entropies[:, 1] + entropies[:, 3])/2 - entropies[:, 8], '-', label=f\"Mutual Info (1, 7)\", color=cmap(2/3))\n",
"axs[2].legend()\n",
"\n",
"axs[0].set_xlabel(\"Epoch\")\n",
"axs[1].set_xlabel(\"Epoch\")\n",
"axs[2].set_xlabel(\"Epoch\")\n",
"axs[0].set_ylabel(\"Mutual Information\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dc832d0d",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "9f0039de",
"id": "652a3419",
"metadata": {},
"outputs": [],
"source": []
Expand Down

0 comments on commit fadae59

Please sign in to comment.