-
Notifications
You must be signed in to change notification settings - Fork 0
/
function_discriminator.py
161 lines (116 loc) · 4.53 KB
/
function_discriminator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
""" Definition of a KubeML function to train the LeNet network with the MNIST dataset"""
import logging
import random
from typing import List, Any, Union, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from kubeml import KubeModel, KubeDataset
from torch.optim import SGD, Adam
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class Discriminator(nn.Module):
""" Discriminator
"""
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(784, 1024)
self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features//2)
self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features//2)
self.fc4 = nn.Linear(self.fc3.out_features, 1)
def forward(self, x):
x.view(x.shape[0], -1)
y = F.leaky_relu(self.fc1(x), 0.2)
y = F.dropout(y, 0.3)
y = F.leaky_relu(self.fc2(y), 0.2)
y = F.dropout(y, 0.3)
y = F.leaky_relu(self.fc3(y), 0.2)
y = F.dropout(y, 0.3)
y = torch.sigmoid(self.fc4(y))
return y
class MnistDataset(KubeDataset):
def __init__(self):
super().__init__("mnist_disc")
def __getitem__(self, index):
x_real = self.data[index]
x_fake = self.labels[index]
return x_real, x_fake
def __len__(self):
return len(self.data)
class KubeDiscriminator(KubeModel):
def __init__(self, network: nn.Module, dataset: MnistDataset):
super().__init__(network, dataset)
def configure_optimizers(self) -> torch.optim.Optimizer:
adam = Adam(self.parameters(), lr=self.lr)
return adam
def init(self):
pass
def train(self, batch, batch_index) -> float:
# define the device for training and load the data
criterion = nn.BCELoss()
total_loss = 0
x,_ = batch
bs = x.size(0)
self.optimizer.zero_grad()
# train discriminator on real
x_real, y_real = x[:, :, 0].view(bs,784), torch.ones(bs, 1)
x_real, y_real = Variable(x_real.to(device)), Variable(y_real.to(device))
output = self(x_real)
logging.info(f"Output is {output}")
real_loss = criterion(output, y_real)
logging.info(f"real_loss {real_loss}")
# train discriminator on fake
x_fake, y_fake = x[:,:,1].view(bs,784), torch.zeros(bs, 1)
x_fake, y_fake = Variable(x_fake.to(device)), Variable(y_fake.to(device))
output = self(x_fake)
fake_loss = criterion(output, y_fake)
# gradient backprop & optimize ONLY D's parameters
loss = real_loss + fake_loss
loss.backward()
# step with the optimizer
self.optimizer.step()
total_loss += loss.data.item()
if batch_index % 10 == 0:
logging.info(f"Index {batch_index}, error: {loss.item()}")
return total_loss
def validate(self, batch, _) -> Tuple[float, float]:
criterion = nn.BCELoss()
x,_ = batch
bs = x.size(0)
test_loss = 0
correct = 0
# test discriminator on real
x_real, y_real = x[:, :, 0].view(bs,784), torch.ones(bs, 1)
x_real, y_real = Variable(x_real.to(device)), Variable(y_real.to(device))
output = self(x_real)
real_loss = criterion(output, y_real)
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(y_real.view_as(pred)).sum().item()
# test discriminator on fake
x_fake, y_fake = x[:,:,1].view(bs,784), torch.zeros(bs, 1)
x_fake, y_fake = Variable(x_fake.to(device)), Variable(y_fake.to(device))
output = self(x_fake)
fake_loss = criterion(output, y_fake)
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(y_fake.view_as(pred)).sum().item()
loss = real_loss + fake_loss
test_loss += loss.data.item()
accuracy = 100. * correct / batch[0].shape[0]
# self.logger.debug(f'accuracy {accuracy}')
return accuracy, test_loss
def infer(self,data: List[Any]) -> Union[torch.Tensor, np.ndarray, List[float]]:
samples = data['data']
with torch.no_grad():
x = torch.tensor(samples)
x = x.view(x.size(0), 784)
output = self(x)
return output
def main():
# set the random seeds
torch.manual_seed(42)
random.seed(42)
discriminator = Discriminator()
dataset = MnistDataset()
kubedisc = KubeDiscriminator(discriminator, dataset)
return kubedisc.start()