models.py
import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# Weight initialisation: conv weights ~ N(0, 0.02); batch-norm scales ~ N(1, 0.02)
# with zero bias.
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
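
# Usage note: nn.Module.apply walks every submodule recursively, so a whole
# network is initialised with net.apply(weights_init), as VAE_GAN does below.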

class Encoder(nn.Module):
    # in_channels=1 for MNIST, in_channels=3 for CIFAR10; the 256*8*8 flatten
    # below assumes inputs resized to 64x64.
    def __init__(self, in_channels, latent_dim):
        super(Encoder, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 64, 5, padding=2, stride=2)
        self.bn1 = nn.BatchNorm2d(64, momentum=0.9)
        self.conv2 = nn.Conv2d(64, 128, 5, padding=2, stride=2)
        self.bn2 = nn.BatchNorm2d(128, momentum=0.9)
        self.conv3 = nn.Conv2d(128, 256, 5, padding=2, stride=2)
        self.bn3 = nn.BatchNorm2d(256, momentum=0.9)
        self.relu = nn.LeakyReLU(0.2)
        self.fc1 = nn.Linear(256 * 8 * 8, 2048)
        self.bn4 = nn.BatchNorm1d(2048, momentum=0.9)
        self.fc_mean = nn.Linear(2048, latent_dim)
        self.fc_logvar = nn.Linear(2048, latent_dim)  # latent_dim=128 here

    def forward(self, x):
        batch_size = x.size()[0]
        out = self.relu(self.bn1(self.conv1(x)))    # 64x64 -> 32x32
        out = self.relu(self.bn2(self.conv2(out)))  # 32x32 -> 16x16
        out = self.relu(self.bn3(self.conv3(out)))  # 16x16 -> 8x8
        out = out.view(batch_size, -1)
        out = self.relu(self.bn4(self.fc1(out)))
        mean = self.fc_mean(out)
        logvar = self.fc_logvar(out)
        return mean, logvar
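
# Shape sketch (illustrative only, not part of the module; assumes 64x64 RGB
# input and the latent_dim=128 noted above):
#   enc = Encoder(in_channels=3, latent_dim=128)
#   mean, logvar = enc(torch.randn(16, 3, 64, 64))
#   mean.shape    # torch.Size([16, 128])
#   logvar.shape  # torch.Size([16, 128])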

class Decoder(nn.Module):
    def __init__(self, latent_dim, out_channels):
        super(Decoder, self).__init__()
        self.fc1 = nn.Linear(latent_dim, 8 * 8 * 256)
        self.bn1 = nn.BatchNorm1d(8 * 8 * 256, momentum=0.9)
        self.relu = nn.LeakyReLU(0.2)
        # Each kernel-6/stride-2/padding-2 transposed conv exactly doubles the
        # spatial size: (in - 1)*2 - 2*2 + 6 = 2*in.
        self.deconv1 = nn.ConvTranspose2d(256, 256, 6, stride=2, padding=2)
        self.bn2 = nn.BatchNorm2d(256, momentum=0.9)
        self.deconv2 = nn.ConvTranspose2d(256, 128, 6, stride=2, padding=2)
        self.bn3 = nn.BatchNorm2d(128, momentum=0.9)
        self.deconv3 = nn.ConvTranspose2d(128, 32, 6, stride=2, padding=2)
        self.bn4 = nn.BatchNorm2d(32, momentum=0.9)
        self.deconv4 = nn.ConvTranspose2d(32, out_channels, 5, stride=1, padding=2)
        self.tanh = nn.Tanh()

    def forward(self, x):
        x = self.relu(self.bn1(self.fc1(x)))
        x = x.view(-1, 256, 8, 8)
        x = self.relu(self.bn2(self.deconv1(x)))  # 8x8 -> 16x16
        x = self.relu(self.bn3(self.deconv2(x)))  # 16x16 -> 32x32
        x = self.relu(self.bn4(self.deconv3(x)))  # 32x32 -> 64x64
        x = self.tanh(self.deconv4(x))            # outputs in [-1, 1]
        return x
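
# Shape sketch (illustrative only; a latent vector decodes to a 64x64 image):
#   dec = Decoder(latent_dim=128, out_channels=3)
#   x_hat = dec(torch.randn(16, 128))
#   x_hat.shape  # torch.Size([16, 3, 64, 64]), values in [-1, 1] from tanh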

class Discriminator(nn.Module):
    def __init__(self, in_channels):
        super(Discriminator, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 32, 5, padding=2, stride=1)
        self.relu = nn.LeakyReLU(0.2)
        self.conv2 = nn.Conv2d(32, 128, 5, padding=2, stride=2)
        self.bn1 = nn.BatchNorm2d(128, momentum=0.9)
        self.conv3 = nn.Conv2d(128, 256, 5, padding=2, stride=2)
        self.bn2 = nn.BatchNorm2d(256, momentum=0.9)
        self.conv4 = nn.Conv2d(256, 256, 5, padding=2, stride=2)
        self.bn3 = nn.BatchNorm2d(256, momentum=0.9)
        self.fc1 = nn.Linear(8 * 8 * 256, 512)
        self.bn4 = nn.BatchNorm1d(512, momentum=0.9)
        self.fc2 = nn.Linear(512, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.relu(self.conv1(x))              # 64x64, stride 1
        x = self.relu(self.bn1(self.conv2(x)))    # 64x64 -> 32x32
        x = self.relu(self.bn2(self.conv3(x)))    # 32x32 -> 16x16
        x = self.relu(self.bn3(self.conv4(x)))    # 16x16 -> 8x8
        x = x.view(-1, 256 * 8 * 8)
        # Keep the flattened last-conv features alongside the real/fake score.
        features = x
        x = self.relu(self.bn4(self.fc1(x)))
        x = self.sigmoid(self.fc2(x))
        return x, features
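
# Shape sketch (illustrative only): the second return value is the flattened
# 256*8*8 feature map, which VAE/GAN-style training typically uses for a
# feature-wise reconstruction loss instead of a pixel-wise one.
#   disc = Discriminator(in_channels=3)
#   prob, feats = disc(torch.randn(16, 3, 64, 64))
#   prob.shape   # torch.Size([16, 1]), sigmoid probabilities
#   feats.shape  # torch.Size([16, 16384])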

class VAE_GAN(nn.Module):
    def __init__(self, in_channels, out_channels, latent_dim):
        super(VAE_GAN, self).__init__()
        self.encoder = Encoder(in_channels, latent_dim)
        self.decoder = Decoder(latent_dim, out_channels)
        self.discriminator = Discriminator(in_channels)
        self.encoder.apply(weights_init)
        self.decoder.apply(weights_init)
        self.discriminator.apply(weights_init)

    def forward(self, x):
        batch_size = x.size()[0]
        z_mean, z_logvar = self.encoder(x)
        std = z_logvar.mul(0.5).exp()
        latent_dim = z_mean.size()[1]
        # Reparameterisation trick: sample epsilon ~ N(0, I), then shift and
        # scale it so gradients flow through z_mean and z_logvar.
        epsilon = torch.randn(batch_size, latent_dim, device=device)
        z = z_mean + std * epsilon
        x_tilda = self.decoder(z)
        return z_mean, z_logvar, x_tilda
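
if __name__ == '__main__':
    # Minimal smoke test (a sketch, assuming 64x64 RGB inputs and the
    # latent_dim=128 noted above; not part of the original training code).
    model = VAE_GAN(in_channels=3, out_channels=3, latent_dim=128).to(device)
    model.eval()  # use BatchNorm running stats, safe for a tiny batch
    with torch.no_grad():
        x = torch.randn(4, 3, 64, 64, device=device)
        z_mean, z_logvar, x_tilda = model(x)
    print(z_mean.shape, z_logvar.shape, x_tilda.shape)
    # Expected: torch.Size([4, 128]) torch.Size([4, 128]) torch.Size([4, 3, 64, 64])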