-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
131 lines (112 loc) · 4.59 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import os
import sys
import pytorch_lightning as pl
import torch
import torch.utils.data
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import DataLoader
from data.dirtycollate import dirty_collate
from data.unsplashlite_clip import UnsplashLiteDataset
from model.CringeCLIP import CringeCLIPModel
from model.CringeDenoiser import CringeDenoiserModel
from model.CringeVAE import CringeVAEModel
from utils import RegularCheckpoint, train_save_checkpoint
def train_denoiser(device='gpu'):
    """Train the denoiser model, restarting the fit in an endless loop.

    Builds an 80/20 train/validation split of ``UnsplashLiteDataset``, restores
    the CLIP and VAE sub-models from ``checkpoints/clip/model.ckpt`` and
    ``checkpoints/vae/model.ckpt`` when those files exist, wraps them in a
    ``CringeDenoiserModel`` and fits it. Each pass resumes from
    ``checkpoints/ldm/model.ckpt`` when that file exists.

    Args:
        device: Value forwarded to ``pl.Trainer(accelerator=...)``.

    Note:
        This function never returns normally: the fit is re-run forever and
        any ``Exception`` raised by it is reported and then retried. Only a
        ``BaseException`` that is not an ``Exception`` (for example
        ``SystemExit``) leaves the loop.
    """
    # Local import: only the retry loop's error report needs it.
    import traceback

    # hparams while i'm working on it
    img_dim = 512

    # data
    dataset = UnsplashLiteDataset(
        root_dir='/mnt/e/Source/unsplash-lite-corpus-preprocess/db',
        img_dim=img_dim,
    )
    # The validation set takes the remainder so the two sizes always sum to
    # len(dataset), which random_split requires.
    train_size = int(len(dataset) * 0.8)
    val_size = len(dataset) - train_size
    training_set, validation_set = torch.utils.data.random_split(
        dataset, [train_size, val_size]
    )
    train_loader = DataLoader(training_set, batch_size=10, collate_fn=dirty_collate)
    val_loader = DataLoader(validation_set, batch_size=10, collate_fn=dirty_collate)

    # Load CLIP checkpoint if it exists.
    # map_location="cpu" lets a checkpoint saved from CUDA tensors load on a
    # machine without CUDA; the model itself is constructed on the CPU here.
    clip_model = CringeCLIPModel()
    clip_ckpt = "checkpoints/clip/model.ckpt"
    if os.path.exists(clip_ckpt):
        clip_model.load_state_dict(
            torch.load(clip_ckpt, map_location="cpu")["state_dict"]
        )

    # Load VAE checkpoint if it exists
    vae_model = CringeVAEModel(dimensions=[16, 32, 64, 128])
    vae_ckpt = "checkpoints/vae/model.ckpt"
    if os.path.exists(vae_ckpt):
        vae_model.load_state_dict(
            torch.load(vae_ckpt, map_location="cpu")["state_dict"]
        )

    # The denoiser wraps the (possibly restored) VAE and CLIP models.
    denoiser_model = CringeDenoiserModel(
        vae_model=vae_model,
        clip_model=clip_model,
        diffuser_shapes=[32, 64, 128, 256],
        img_dim=img_dim,
    )

    # Logger
    denoiser_logger = TensorBoardLogger("tb_logs", name="cringeldm")
    denoiser_trainer = pl.Trainer(
        accelerator=device,
        precision=16,
        limit_train_batches=0.5,
        callbacks=[
            RegularCheckpoint(
                model=denoiser_model,
                period=10,
                base_dir="checkpoints/ldm",
                do_q=True,
                do_img=False,
            ),
        ],
        accumulate_grad_batches=20,
        logger=denoiser_logger,
    )

    ldm_ckpt = "checkpoints/ldm/model.ckpt"
    while True:
        try:
            # Re-checked on every pass: the checkpoint may have been written
            # since the previous attempt. None means "start without resuming".
            resume_from = ldm_ckpt if os.path.exists(ldm_ckpt) else None
            denoiser_trainer.fit(
                denoiser_model, train_loader, val_loader, ckpt_path=resume_from
            )
        except Exception:
            # Keep the run alive, but report the full traceback; printing the
            # exception object alone only shows its message.
            traceback.print_exc(file=sys.stdout)
def train_vae(device='gpu'):
    """Train the VAE model, restarting the fit in an endless loop.

    Builds an 80/20 train/validation split of ``UnsplashLiteDataset`` and fits
    a ``CringeVAEModel``. Each pass resumes from ``checkpoints/vae/model.ckpt``
    when that file exists.

    Args:
        device: Value forwarded to ``pl.Trainer(accelerator=...)``. Defaults to
            ``'gpu'``, the value that was previously hard-coded.

    Note:
        The fit is re-run forever; an exception raised by ``fit`` is not caught
        here and propagates to the caller.
    """
    # hparams while i'm working on it
    img_dim = 512

    # data
    dataset = UnsplashLiteDataset(
        root_dir='/mnt/e/Source/unsplash-lite-corpus-preprocess/db',
        img_dim=img_dim,
    )
    # The validation set takes the remainder instead of int(len * 0.2): the two
    # truncated products can sum to less than len(dataset) (e.g. 7 -> 5 + 1),
    # which random_split rejects with a ValueError.
    train_size = int(len(dataset) * 0.8)
    val_size = len(dataset) - train_size
    training_set, validation_set = torch.utils.data.random_split(
        dataset, [train_size, val_size]
    )
    train_loader = DataLoader(training_set, batch_size=8, collate_fn=dirty_collate)
    val_loader = DataLoader(validation_set, batch_size=8, collate_fn=dirty_collate)

    # Fresh model; weights are only restored through ckpt_path in the loop below.
    vae_model = CringeVAEModel(dimensions=[16, 32, 64, 128]).to(
        "cuda:0" if torch.cuda.is_available() else "cpu"
    )

    # Logger
    vae_logger = TensorBoardLogger("tb_logs", name="cringeldmvae")
    vae_trainer = pl.Trainer(
        accelerator=device,
        precision=16,
        limit_train_batches=0.5,
        callbacks=[
            RegularCheckpoint(
                model=vae_model,
                period=5000,
                base_dir="checkpoints/vae",
                do_img=True,
                do_q=False,
            ),
        ],
        logger=vae_logger,
    )

    vae_ckpt = "checkpoints/vae/model.ckpt"
    while True:
        # Re-checked on every pass: the checkpoint may have been written since
        # the previous fit. None means "start without resuming".
        resume_from = vae_ckpt if os.path.exists(vae_ckpt) else None
        vae_trainer.fit(vae_model, train_loader, val_loader, ckpt_path=resume_from)
def train():
    """Dispatch to a training routine based on the command-line arguments.

    Usage, as read from ``sys.argv[1:]``:

    * no arguments          -> ``train_denoiser()`` with its default device
    * ``denoiser``          -> ``train_denoiser()`` with its default device
    * ``denoiser <device>`` -> ``train_denoiser(<device>)``
    * ``vae``               -> ``train_vae()``
    * anything else         -> prints ``Invalid model specified.``
    """
    args = sys.argv[1:]

    if not args:
        # Running with no arguments has always started denoiser training; the
        # "please specify a model" message that used to follow this call could
        # never be reached, so the default is now stated explicitly.
        train_denoiser()
        return

    model = args[0]
    if model == "denoiser":
        if len(args) == 2:
            # Second argument selects the accelerator.
            train_denoiser(args[1])
        else:
            # Previously this call also ran after the explicit-device call
            # above; the two are alternatives, not a sequence.
            train_denoiser()
    elif model == "vae":
        train_vae()
    else:
        print("Invalid model specified.")
# Script entry point: only dispatch when executed directly, not when imported.
if "__main__" == __name__:
    train()