Skip to content

Commit

Permalink
Add better colours to space stretching
Browse files Browse the repository at this point in the history
Closes #847
  • Loading branch information
Atcold committed Jan 30, 2024
1 parent ea3aacc commit 575a9b5
Show file tree
Hide file tree
Showing 3 changed files with 145 additions and 30 deletions.
143 changes: 123 additions & 20 deletions 02-space_stretching.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
"import torch\n",
"import torch.nn as nn\n",
"from res.plot_lib import set_default, show_scatterplot, plot_bases\n",
"from matplotlib.pyplot import plot, title, axis"
"from matplotlib.pyplot import plot, title, axis, figure, gca, gcf\n",
"from numpy import clip"
]
},
{
Expand All @@ -20,7 +21,9 @@
"outputs": [],
"source": [
"# Set style (needs to be in a new cell)\n",
"set_default()"
"%matplotlib inline\n",
"set_default()\n",
"torch.manual_seed(0)"
]
},
{
Expand All @@ -39,10 +42,31 @@
"outputs": [],
"source": [
"# generate some points in 2-D space\n",
"n_points = 1000\n",
"X = torch.randn(n_points, 2).to(device)\n",
"colors = X[:, 0]\n",
"\n",
"n_points = 1_000\n",
"X = torch.randn(n_points, 2).to(device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# colors [0 – 511]^2\n",
"x_min = -1.5 #X.min(0)[0] #+ 1\n",
"x_max = +1.5 #X.max(0)[0] #- 1\n",
"colors = (X - x_min) / (x_max - x_min)\n",
"colors = (colors * 511).short().numpy()\n",
"colors = clip(colors, 0, 511)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"figure().add_axes([0, 0, 1, 1])\n",
"show_scatterplot(X, colors, title='X')\n",
"OI = torch.cat((torch.zeros(2, 2), torch.eye(2))).to(device)\n",
"plot_bases(OI)"
Expand Down Expand Up @@ -75,13 +99,16 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"show_scatterplot(X, colors, title='X')\n",
"plot_bases(OI)\n",
"\n",
"for i in range(10):\n",
" figure()\n",
" # create a random matrix\n",
" W = torch.randn(2, 2).to(device)\n",
" # transform points\n",
Expand All @@ -91,7 +118,7 @@
" # plot transformed points\n",
" show_scatterplot(Y, colors, title='y = Wx, singular values : [{:.3f}, {:.3f}]'.format(S[0], S[1]))\n",
" # transform the basis\n",
" new_OI = OI @ W.t()\n",
" new_OI = OI @ W\n",
" # plot old and new basis\n",
" plot_bases(OI)\n",
"# plot_bases(new_OI)"
Expand All @@ -107,7 +134,9 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"model = nn.Sequential(\n",
Expand All @@ -116,6 +145,7 @@
"model.to(device)\n",
"with torch.no_grad():\n",
" Y = model(X)\n",
" figure()\n",
" show_scatterplot(Y, colors)\n",
" plot_bases(model(OI))"
]
Expand Down Expand Up @@ -156,7 +186,9 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"show_scatterplot(X, colors, title='X')\n",
Expand All @@ -170,6 +202,7 @@
"model.to(device)\n",
"\n",
"for s in range(1, 6):\n",
" figure()\n",
" W = s * torch.eye(2)\n",
" model[0].weight.data.copy_(W)\n",
" Y = model(X).data\n",
Expand All @@ -187,15 +220,19 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"show_scatterplot(X, colors, title='x')\n",
"n_hidden = 5\n",
"\n",
"# NL = nn.ReLU()\n",
"# NL = nn.ReLU() # ()^+\n",
"NL = nn.Tanh()\n",
"\n",
"models = list()\n",
"\n",
"for i in range(5):\n",
" # create 1-layer neural networks with random weights\n",
" model = nn.Sequential(\n",
Expand All @@ -204,24 +241,28 @@
" nn.Linear(n_hidden, 2)\n",
" )\n",
" model.to(device)\n",
" models.append(model)\n",
" with torch.no_grad():\n",
" Y = model(X)\n",
" figure()\n",
" show_scatterplot(Y, colors, title='f(x)')\n",
"# plot_bases(OI)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"# deeper network with random weights\n",
"show_scatterplot(X, colors, title='x')\n",
"n_hidden = 5\n",
"\n",
"# NL = nn.ReLU()\n",
"NL = nn.Tanh()\n",
"NL = nn.ReLU()\n",
"# NL = nn.Tanh()\n",
"\n",
"for i in range(5):\n",
" model = nn.Sequential(\n",
Expand All @@ -238,22 +279,84 @@
" model.to(device)\n",
" with torch.no_grad():\n",
" Y = model(X).detach()\n",
" show_scatterplot(Y, colors, title='f(x)')"
" figure()\n",
" show_scatterplot(Y, colors, title='f(x)', axis=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
"source": [
"show_scatterplot(X, colors, title='x')\n",
"with torch.no_grad():\n",
" Y = models[2](X)\n",
"figure()\n",
"show_scatterplot(Y, colors, title='f(x)')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def interpolate(X_in, X_out, steps, p=1/50, plotting_grid=False, ratio='1:1'):\n",
" N = 1000\n",
" for t in range(steps):\n",
" # a = (t / (steps - 1)) ** p\n",
" a = ((p + 1)**(t / (steps - 1)) - 1) / p\n",
" gca().cla()\n",
"# plt.text(0, 5, action, color='w', horizontalalignment='center', verticalalignment='center')\n",
" show_scatterplot(a * X_out + (1 - a) * X_in, colors, title='f(x)')\n",
"\n",
" if plotting_grid: plot_grid(a * X_out[N:] + (1 - a) * X_in[N:])\n",
" gcf().canvas.draw()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib notebook"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# fig = figure(figsize=(19.2, 10.8)) # resolution is 100 px per inch => 1920 x 1080\n",
"fig = figure(figsize=(10, 5))\n",
"ax = fig.add_axes([0, 0, 1, 1]) # stretched the plot area to the whole figure\n",
"axis('off');"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Animate input/output\n",
"steps = 150\n",
"# steps = 1500\n",
"interpolate(X, Y, steps, p=.001)"
]
}
],
"metadata": {
"jupytext": {
"formats": "ipynb,py:percent"
},
"kernelspec": {
"display_name": "Python [conda env:dl-minicourse] *",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "conda-env-dl-minicourse-py"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -265,7 +368,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
"version": "3.11.5"
}
},
"nbformat": 4,
Expand Down
32 changes: 22 additions & 10 deletions res/plot_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,22 +37,34 @@ def plot_model(X, y, model):
plot_data(X, y)


def show_scatterplot(X, colors, title=''):
colors = colors.cpu().numpy()
X = X.cpu().numpy()
plt.figure()
zieger = plt.imread('res/ziegler.png')

def show_scatterplot(X, colors, title='', axis=True):
colors = zieger[colors[:,0], colors[:,1]]
X = X.numpy()
# plt.figure()
plt.axis('equal')
plt.scatter(X[:, 0], X[:, 1], c=colors, s=30)
# plt.grid(True)
plt.title(title)
plt.axis('off')
_m, _c = 0, '.15'
if axis:
plt.axvline(0, ymin=_m, color=_c, lw=1, zorder=0)
plt.axhline(0, xmin=_m, color=_c, lw=1, zorder=0)


def plot_bases(bases, width=0.04):
bases = bases.cpu()
def plot_bases(bases, plotting=True, width=0.04):
bases[2:] -= bases[:2]
plt.arrow(*bases[0], *bases[2], width=width, color=(1,0,0), zorder=10, alpha=1., length_includes_head=True)
plt.arrow(*bases[1], *bases[3], width=width, color=(0,1,0), zorder=10, alpha=1., length_includes_head=True)
# if plot_bases.a: plot_bases.a.set_visible(False)
# if plot_bases.b: plot_bases.b.set_visible(False)
if plotting:
plot_bases.a = plt.arrow(*bases[0], *bases[2], width=width, color='r', zorder=10, alpha=1., length_includes_head=True)
plot_bases.b = plt.arrow(*bases[1], *bases[3], width=width, color='g', zorder=10, alpha=1., length_includes_head=True)


plot_bases.a = None
plot_bases.b = None


def show_mat(mat, vect, prod, threshold=-1):
Expand All @@ -72,7 +84,7 @@ def show_mat(mat, vect, prod, threshold=-1):
# Remove xticks for vectors
ax2.set_xticks(tuple())
ax3.set_xticks(tuple())

# Plot colourbars
fig.colorbar(cax1, ax=ax2)
fig.colorbar(cax3, ax=ax3)
Expand Down Expand Up @@ -138,4 +150,4 @@ def plot_state(data, state, b, decoder):
seq_len_w_pad = len(state)
for s in range(state.size(2)):
states = torch.sigmoid(state[:, b, s])
_visualise_values(states[seq_len_w_pad - seq_len:], list(actual_data))
_visualise_values(states[seq_len_w_pad - seq_len:], list(actual_data))
Binary file added res/ziegler.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit 575a9b5

Please sign in to comment.