-
Notifications
You must be signed in to change notification settings - Fork 8
/
joint_model.py
153 lines (130 loc) · 5.27 KB
/
joint_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
#!/usr/bin/python
'''
Defines PyTorch joint models that chain the network encoders and decoders
for forward and backward propagation.
'''
import torch
from utils import *
from Generators import *
from data_loader import *
from params import *
'''
Class definition for the joint model with one encoder and one decoder.
'''
class Enc_Dec_Network():
    """Joint model that chains one encoder into one decoder.

    The encoder's output becomes the decoder's input. Only the decoder's
    ``backward`` is invoked. Whether gradients reach the encoder depends on the
    sub-network wrappers (presumably through the shared autograd graph --
    TODO confirm). Call ``initialize`` before any other method.
    """

    def initialize(self, opt, encoder, decoder, frozen_dec=False, frozen_enc=False, gpu_ids='1'):
        """Attach the sub-networks and training options.

        Args:
            opt: options object; ``opt.isTrain`` is read and stored.
            encoder: wrapper used via ``set_data``, ``forward``, ``test``,
                ``output``, ``optimizer``, ``save_networks``, ``save_outputs``
                and ``update_learning_rate``.
            decoder: wrapper used via ``set_data``, ``forward``, ``test``,
                ``backward``, ``optimizer``, ``save_networks``,
                ``save_outputs`` and ``update_learning_rate``.
            frozen_dec: if True, the decoder optimizer is never stepped.
            frozen_enc: if True, the encoder optimizer is never stepped.
            gpu_ids: indexable of GPU ids; only ``gpu_ids[0]`` is used. With the
                string default ``'1'`` this selects ``cuda:1``.
        """
        self.opt = opt
        self.isTrain = opt.isTrain
        self.encoder = encoder
        self.decoder = decoder
        self.frozen_dec = frozen_dec
        self.frozen_enc = frozen_enc
        # Always a CUDA device (no CPU fallback). The sub-networks are not moved
        # here; they are assumed to already live on this device -- TODO confirm.
        self.device = torch.device('cuda:{}'.format(gpu_ids[0]))

    def set_input(self, input, target, convert_enc=True, shuffle_channel=True):
        """Move a batch to ``self.device`` and hand the input to the encoder.

        The encoder receives ``input`` as both its data and its target.
        ``target`` is stored and bound to the decoder in ``forward``, ``test``
        and ``eval``. ``convert_enc`` and ``shuffle_channel`` are forwarded to
        ``encoder.set_data`` as ``convert`` and ``shuffle_channel``.
        """
        self.input = input.to(self.device)
        self.target = target.to(self.device)
        self.encoder.set_data(self.input, self.input, convert=convert_enc, shuffle_channel=shuffle_channel)

    def save_networks(self, epoch):
        """Save both sub-networks, tagged with ``epoch``."""
        self.encoder.save_networks(epoch)
        self.decoder.save_networks(epoch)

    def save_outputs(self):
        """Ask both sub-networks to save their current outputs."""
        self.encoder.save_outputs()
        self.decoder.save_outputs()

    def update_learning_rate(self):
        """Advance the learning-rate schedules of both sub-networks.

        The decoder's schedule is advanced as well as the encoder's. This is
        consistent with ``Enc_2Dec_Network.update_learning_rate``, and keeps
        the decoder (whose optimizer is stepped in ``optimize_parameters``)
        from training at a fixed initial rate forever.
        """
        self.encoder.update_learning_rate()
        self.decoder.update_learning_rate()

    def forward(self):
        """Run the encoder, then the decoder on the encoder's output and ``self.target``."""
        self.encoder.forward()
        self.decoder.set_data(self.encoder.output, self.target)
        self.decoder.forward()

    def test(self):
        """Like ``forward`` but via the sub-networks' ``test`` methods."""
        self.encoder.test()
        self.decoder.set_data(self.encoder.output, self.target)
        self.decoder.test()

    def backward(self):
        """Back-propagate the decoder's loss (the encoder has no backward call of its own)."""
        self.decoder.backward()

    def optimize_parameters(self):
        """Run one training step: forward, backward, optimizer steps, then reset gradients."""
        self.forward()
        self.backward()
        if not self.frozen_enc:
            self.encoder.optimizer.step()
        if not self.frozen_dec:
            self.decoder.optimizer.step()
        # Gradients are cleared after stepping, for frozen sub-networks too,
        # so nothing accumulates across iterations.
        self.encoder.optimizer.zero_grad()
        self.decoder.optimizer.zero_grad()

    def eval(self):
        """Run a forward pass (identical to ``forward``).

        NOTE: unlike ``Enc_2Dec_Network.eval``, this does NOT switch the
        sub-networks into evaluation mode; callers rely on it producing
        outputs, so that behavior is kept.
        """
        self.forward()
'''
Class definition for the final DLoc architecture: a joint model with
one encoder and two decoders.
'''
class Enc_2Dec_Network():
    """Final DLoc joint model: one shared encoder feeding two decoders.

    ``decoder`` is paired with ``target`` and ``offset_decoder`` with
    ``offset_target``; both consume the same ``encoder.output``. Call
    ``initialize`` before any other method.
    """

    def initialize(self, opt, encoder, decoder, offset_decoder, frozen_dec=False, frozen_enc=False, gpu_ids='1'):
        """Attach the three sub-networks and training options.

        Args:
            opt: options object; ``opt.isTrain`` and ``opt.results_dir`` are read.
            encoder: shared encoder wrapper.
            decoder: decoder trained against ``target``.
            offset_decoder: decoder trained against ``offset_target``.
            frozen_dec: if True, neither decoder's optimizer is stepped.
            frozen_enc: if True, the encoder's optimizer is not stepped.
            gpu_ids: indexable of GPU ids; only the first entry selects the CUDA device.
        """
        print('initializing Encoder and 2 Decoders Model')
        self.opt = opt
        self.isTrain = opt.isTrain
        self.encoder, self.decoder, self.offset_decoder = encoder, decoder, offset_decoder
        self.frozen_dec, self.frozen_enc = frozen_dec, frozen_enc
        # CUDA only; sub-networks are not moved here (assumed already on device -- TODO confirm).
        first_gpu = gpu_ids[0]
        self.device = torch.device(f'cuda:{first_gpu}')
        self.results_save_dir = opt.results_dir

    def _each_network(self):
        """Yield the sub-networks lazily, in the order encoder, decoder, offset decoder."""
        for attr_name in ('encoder', 'decoder', 'offset_decoder'):
            yield getattr(self, attr_name)

    def _bind_decoder_inputs(self):
        """Point both decoders at the current encoder output and their own targets."""
        self.decoder.set_data(self.encoder.output, self.target)
        self.offset_decoder.set_data(self.encoder.output, self.offset_target)

    def set_input(self, input, target, offset_target, convert_enc=True, shuffle_channel=True):
        """Move a batch to ``self.device`` and give the input to the encoder.

        Per the original inline note, ``input``, ``target`` and ``offset_target``
        correspond to features_w_offset, labels_gaussian_2d and
        features_wo_offset respectively. The encoder gets ``input`` as both
        data and target.
        """
        device = self.device
        self.input = input.to(device)
        self.target = target.to(device)
        self.offset_target = offset_target.to(device)
        self.encoder.set_data(self.input, self.input, convert=convert_enc, shuffle_channel=shuffle_channel)

    def save_networks(self, epoch):
        """Save every sub-network, tagged with ``epoch``."""
        for net in self._each_network():
            net.save_networks(epoch)

    def save_outputs(self):
        """Ask every sub-network to save its current outputs."""
        for net in self._each_network():
            net.save_outputs()

    def update_learning_rate(self):
        """Advance the learning-rate schedule of every sub-network."""
        for net in self._each_network():
            net.update_learning_rate()

    def forward(self):
        """Encode, bind the encoder output to both decoders, then run both decoders."""
        self.encoder.forward()
        self._bind_decoder_inputs()
        self.decoder.forward()
        self.offset_decoder.forward()

    def test(self):
        """Same pipeline as ``forward`` but through each sub-network's ``test`` method.

        Intended to be called after ``eval`` has switched the sub-networks to evaluation mode.
        """
        self.encoder.test()
        self._bind_decoder_inputs()
        self.decoder.test()
        self.offset_decoder.test()

    def backward(self):
        """Back-propagate both decoders' losses, main decoder first.

        NOTE(review): both losses depend on the same encoder output. If both
        backward passes traverse the encoder graph, the first one presumably
        needs ``retain_graph=True``; confirm in the decoder wrapper.
        """
        self.decoder.backward()
        self.offset_decoder.backward()

    def optimize_parameters(self):
        """One training step: forward, backward, step unfrozen optimizers, clear all gradients."""
        self.forward()
        self.backward()
        if not self.frozen_enc:
            self.encoder.optimizer.step()
        if not self.frozen_dec:
            # ``frozen_dec`` freezes both decoders together.
            self.decoder.optimizer.step()
            self.offset_decoder.optimizer.step()
        for net in self._each_network():
            net.optimizer.zero_grad()

    def eval(self):
        """Switch every sub-network into evaluation mode."""
        for net in self._each_network():
            net.eval()