hw3_helper update
wuphilipp committed Feb 23, 2024
1 parent 2369f60 commit fa74e50
Showing 4 changed files with 52,599 additions and 8 deletions.
14 changes: 6 additions & 8 deletions deepul/hw3_helper.py
@@ -47,7 +47,7 @@ def q1_gan_plot(data, samples, xs, ys, title, fname):

 def q1_data(n=20000):
     assert n % 2 == 0
-    gaussian1 = np.random.normal(loc=-1.5, scale=0.22, size=(n//2,))
+    gaussian1 = np.random.normal(loc=-1.5, scale=0.35, size=(n//2,))
     gaussian2 = np.random.normal(loc=0.2, scale=0.6, size=(n//2,))
     data = (np.concatenate([gaussian1, gaussian2]) + 1).reshape([-1, 1])
     scaled_data = (data - np.min(data)) / (np.max(data) - np.min(data) + 1e-8)
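
For reference, a quick sanity check of the updated q1_data (a sketch; the import path is assumed). It draws from the two-Gaussian mixture above and is min-max scaled to [0, 1]:

from deepul.hw3_helper import q1_data  # assumed import path

data = q1_data(n=20000)          # two-Gaussian mixture, min-max scaled
assert data.shape == (20000, 1)
assert data.min() >= 0.0 and data.max() <= 1.0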
@@ -160,16 +160,14 @@ def save_plot(

 def q3_save_results(fn, part):
     train_data, test_data = load_q3_data()
-    gan_losses, optional_lpips_losses, l2_train_losses, l2_val_losses, recon_show, recon_is = fn(train_data, test_data, test_data[:100])
+    gan_losses, lpips_losses, l2_train_losses, l2_val_losses, recon_show = fn(train_data, test_data, test_data[:100])

-    plot_gan_training(gan_losses, f'Q3{part} Losses', f'results/q3{part}_gan_losses.png')
+    plot_gan_training(gan_losses, f'Q3{part} Discriminator Losses', f'results/q3{part}_gan_losses.png')
     save_plot(l2_train_losses, l2_val_losses, f'Q3{part} L2 Losses', f'results/q3{part}_l2_losses.png')
-    if optional_lpips_losses is not None:
-        save_plot(optional_lpips_losses, None, f'Q3{part} LPIPS Losses', f'results/q3{part}_lpips_losses.png')
+    save_plot(lpips_losses, None, f'Q3{part} LPIPS Losses', f'results/q3{part}_lpips_losses.png')
     show_samples(test_data[:100].transpose(0, 2, 3, 1) * 255.0, nrow=20, fname=f'results/q3{part}_data_samples.png', title=f'Q3{part} CIFAR10 val samples')
     show_samples(recon_show * 255.0, nrow=20, fname=f'results/q3{part}_reconstructions.png', title=f'Q3{part} VQGAN reconstructions')
-    print('inception score:', calculate_is(recon_is.transpose([0, 2, 3, 1])))
-    print('final_reconstruction_loss:', l2_val_losses[-1])
+    print('final_val_reconstruction_loss:', l2_val_losses[-1])
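
With this change, the fn passed to q3_save_results returns five values instead of six: the LPIPS losses are now required and the inception-score batch is gone. A minimal conforming stub, with hypothetical placeholder losses (not part of the commit):

import numpy as np

def fn(train_data, test_data, recon_batch):
    # Hypothetical stub: train a VQGAN here, then return the five values
    # q3_save_results unpacks.
    n_iters = 100
    gan_losses = np.zeros(n_iters)       # discriminator loss per iteration
    lpips_losses = np.zeros(n_iters)     # LPIPS reconstruction loss per iteration
    l2_train_losses = np.zeros(n_iters)  # L2 reconstruction loss (train)
    l2_val_losses = np.zeros(n_iters)    # L2 loss (val); last entry is printed
    recon_show = recon_batch             # (100, C, H, W) reconstructions in [0, 1]
    return gan_losses, lpips_losses, l2_train_losses, l2_val_losses, recon_show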

######################
##### Question 4 #####
@@ -178,7 +176,7 @@ def q3_save_results(fn, part):
 def get_colored_mnist(data):
     # from https://www.wouterbulten.nl/blog/tech/getting-started-with-gans-2-colorful-mnist/
     # Read Lena image
-    lena = PILImage.open('deepul/deepul/hw4_utils/lena.jpg')
+    lena = PILImage.open('deepul/deepul/hw3_utils/lena.jpg')

     # Resize
     batch_resized = np.asarray([scipy.ndimage.zoom(image, (2.3, 2.3, 1), order=1) for image in data])
285 changes: 285 additions & 0 deletions deepul/hw3_utils/lpips.py
@@ -0,0 +1,285 @@
"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models
Taken from https://github.com/CompVis/taming-transformers/blob/3ba01b241669f5ade541ce990f7650a3b8f65318/taming/modules/losses/lpips.py#L11
"""

import hashlib
import os
from collections import namedtuple

import requests
import torch
import torch.nn as nn
from torchvision import models
from tqdm import tqdm

URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"}

CKPT_MAP = {"vgg_lpips": "vgg.pth"}

MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"}


def download(url, local_path, chunk_size=1024):
    os.makedirs(os.path.split(local_path)[0], exist_ok=True)
    with requests.get(url, stream=True) as r:
        total_size = int(r.headers.get("content-length", 0))
        with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
            with open(local_path, "wb") as f:
                for data in r.iter_content(chunk_size=chunk_size):
                    if data:
                        f.write(data)
                        # count the bytes actually written; the final chunk
                        # may be shorter than chunk_size
                        pbar.update(len(data))


def md5_hash(path):
    with open(path, "rb") as f:
        content = f.read()
    return hashlib.md5(content).hexdigest()


def get_ckpt_path(name, root, check=False):
    assert name in URL_MAP
    path = os.path.join(root, CKPT_MAP[name])
    if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
        print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
        download(URL_MAP[name], path)
        md5 = md5_hash(path)
        assert md5 == MD5_MAP[name], md5
    return path
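# Usage sketch (hypothetical cache directory): the first call downloads
# vgg.pth and verifies its MD5; later calls just return the cached path.
#   ckpt = get_ckpt_path("vgg_lpips", root="checkpoints/lpips", check=True)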


class KeyNotFoundError(Exception):
    def __init__(self, cause, keys=None, visited=None):
        self.cause = cause
        self.keys = keys
        self.visited = visited
        messages = list()
        if keys is not None:
            messages.append("Key not found: {}".format(keys))
        if visited is not None:
            messages.append("Visited: {}".format(visited))
        messages.append("Cause:\n{}".format(cause))
        message = "\n".join(messages)
        super().__init__(message)


def retrieve(
    list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
):
    """Given a nested list or dict return the desired value at key expanding
    callable nodes if necessary and :attr:`expand` is ``True``. The expansion
    is done in-place.
    Parameters
    ----------
    list_or_dict : list or dict
        Possibly nested list or dictionary.
    key : str
        key/to/value, path-like string describing all keys necessary to
        consider to get to the desired value. List indices can also be
        passed here.
    splitval : str
        String that defines the delimiter between keys of the
        different depth levels in `key`.
    default : obj
        Value returned if :attr:`key` is not found.
    expand : bool
        Whether to expand callable nodes on the path or not.
    Returns
    -------
    The desired value or if :attr:`default` is not ``None`` and the
    :attr:`key` is not found returns ``default``.
    Raises
    ------
    Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is
    ``None``.
    """

    keys = key.split(splitval)

    success = True
    try:
        visited = []
        parent = None
        last_key = None
        for key in keys:
            if callable(list_or_dict):
                if not expand:
                    raise KeyNotFoundError(
                        ValueError(
                            "Trying to get past callable node with expand=False."
                        ),
                        keys=keys,
                        visited=visited,
                    )
                list_or_dict = list_or_dict()
                parent[last_key] = list_or_dict

            last_key = key
            parent = list_or_dict

            try:
                if isinstance(list_or_dict, dict):
                    list_or_dict = list_or_dict[key]
                else:
                    list_or_dict = list_or_dict[int(key)]
            except (KeyError, IndexError, ValueError) as e:
                raise KeyNotFoundError(e, keys=keys, visited=visited)

            visited += [key]
        # final expansion of retrieved value
        if expand and callable(list_or_dict):
            list_or_dict = list_or_dict()
            parent[last_key] = list_or_dict
    except KeyNotFoundError as e:
        if default is None:
            raise e
        else:
            list_or_dict = default
            success = False

    if not pass_success:
        return list_or_dict
    else:
        return list_or_dict, success
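# Usage sketch (illustrative values, not part of the upstream module):
#   cfg = {"model": {"layers": [{"dim": 64}]}}
#   retrieve(cfg, "model/layers/0/dim")           # -> 64
#   retrieve(cfg, "model/missing", default=128)   # -> 128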


class LPIPS(nn.Module):
    # Learned perceptual metric
    def __init__(self, use_dropout=True):
        super().__init__()
        self.scaling_layer = ScalingLayer()
        self.chns = [64, 128, 256, 512, 512]  # vgg16 feature channel widths
        self.net = vgg16(pretrained=True, requires_grad=False)
        self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
        self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
        self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
        self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
        self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
        self.load_from_pretrained()
        for param in self.parameters():
            param.requires_grad = False

    def load_from_pretrained(self, name="vgg_lpips"):
        ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips")
        self.load_state_dict(
            torch.load(ckpt, map_location=torch.device("cpu")), strict=False
        )
        print("loaded pretrained LPIPS loss from {}".format(ckpt))

    @classmethod
    def from_pretrained(cls, name="vgg_lpips"):
        if name != "vgg_lpips":
            raise NotImplementedError
        model = cls()
        # get_ckpt_path has no default root; use the same cache dir as
        # load_from_pretrained
        ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips")
        model.load_state_dict(
            torch.load(ckpt, map_location=torch.device("cpu")), strict=False
        )
        return model

    def forward(self, input, target):
        in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
        outs0, outs1 = self.net(in0_input), self.net(in1_input)
        feats0, feats1, diffs = {}, {}, {}
        lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
        for kk in range(len(self.chns)):
            feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(
                outs1[kk]
            )
            diffs[kk] = (feats0[kk] - feats1[kk]) ** 2

        res = [
            spatial_average(lins[kk].model(diffs[kk]), keepdim=True)
            for kk in range(len(self.chns))
        ]
        val = res[0]
        for l in range(1, len(self.chns)):
            val += res[l]
        return val


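# The shift/scale constants below come from the original LPIPS code; they are
# equivalent to standard ImageNet normalization for inputs already scaled to
# [-1, 1] (shift = 2 * mean - 1, scale = 2 * std).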
class ScalingLayer(nn.Module):
    def __init__(self):
        super(ScalingLayer, self).__init__()
        self.register_buffer(
            "shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None]
        )
        self.register_buffer(
            "scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None]
        )

    def forward(self, inp):
        return (inp - self.shift) / self.scale


class NetLinLayer(nn.Module):
    """A single linear layer which does a 1x1 conv"""

    def __init__(self, chn_in, chn_out=1, use_dropout=False):
        super(NetLinLayer, self).__init__()
        layers = (
            [
                nn.Dropout(),
            ]
            if (use_dropout)
            else []
        )
        layers += [
            nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),
        ]
        self.model = nn.Sequential(*layers)


class vgg16(torch.nn.Module):
    def __init__(self, requires_grad=False, pretrained=True):
        super(vgg16, self).__init__()
        vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        self.N_slices = 5
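        # Boundaries 4/9/16/23/30 split torchvision's VGG16 `features` stack
        # immediately after relu1_2, relu2_2, relu3_3, relu4_3, and relu5_3.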
        for x in range(4):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(4, 9):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(9, 16):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(16, 23):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(23, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h = self.slice1(X)
        h_relu1_2 = h
        h = self.slice2(h)
        h_relu2_2 = h
        h = self.slice3(h)
        h_relu3_3 = h
        h = self.slice4(h)
        h_relu4_3 = h
        h = self.slice5(h)
        h_relu5_3 = h
        vgg_outputs = namedtuple(
            "VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"]
        )
        out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
        return out


def normalize_tensor(x, eps=1e-10):
    norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
    return x / (norm_factor + eps)


def spatial_average(x, keepdim=True):
    return x.mean([2, 3], keepdim=keepdim)
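
A minimal usage sketch of the module above (assumes it is importable as deepul.hw3_utils.lpips; inputs are (N, 3, H, W) tensors scaled to [-1, 1], and the VGG weights download on first use):

import torch
from deepul.hw3_utils.lpips import LPIPS  # assumed import path

lpips_fn = LPIPS().eval()

# Two image batches in [-1, 1].
x = torch.rand(4, 3, 32, 32) * 2 - 1
y = torch.rand(4, 3, 32, 32) * 2 - 1

with torch.no_grad():
    dist = lpips_fn(x, y)  # (N, 1, 1, 1): per-image perceptual distance
print(dist.flatten())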