-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
112 lines (87 loc) · 4 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import os
import json
import glob
import argparse
from typing import Optional
import torch
import torchaudio
import tqdm
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from avocodo.lightning_module import Avocodo
from avocodo.data.collate import MelCollate
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.profiler import SimpleProfiler, AdvancedProfiler
from avocodo.hparams import HParams
from avocodo.data.dataset import MelDataset, MelDataset
from avocodo.utils import load_state_dict
def get_hparams(config_path: str) -> HParams:
with open(config_path, "r") as f:
data = f.read()
config = json.loads(data)
hparams = HParams(**config)
return hparams
def last_checkpoint(path: str) -> Optional[str]:
ckpt_path = None
if os.path.exists(os.path.join(path, "lightning_logs")):
versions = glob.glob(os.path.join(path, "lightning_logs", "version_*"))
if len(list(versions)) > 0:
last_ver = sorted(list(versions), key=lambda p: int(p.split("_")[-1]))[-1]
last_ckpt = os.path.join(last_ver, "checkpoints/last.ckpt")
if os.path.exists(last_ckpt):
ckpt_path = last_ckpt
return ckpt_path
def get_train_params(args, hparams):
devices = [int(n.strip()) for n in args.device.split(",")]
checkpoint_callback = ModelCheckpoint(
dirpath=None, save_last=True, every_n_train_steps=2000, save_weights_only=False,
monitor="valid/loss_mel_epoch", mode="min", save_top_k=5
)
earlystop_callback = EarlyStopping(monitor="valid/loss_mel_epoch", mode="min", patience=13)
trainer_params = {
"accelerator": args.accelerator,
"callbacks": [checkpoint_callback, earlystop_callback],
}
if args.accelerator != "cpu":
trainer_params["devices"] = devices
if len(devices) > 1:
trainer_params["strategy"] = "ddp"
trainer_params.update(hparams.trainer)
if hparams.train.fp16_run:
trainer_params["amp_backend"] = "native"
trainer_params["precision"] = 16
trainer_params["num_nodes"] = args.num_nodes
return trainer_params
def main():
parser = argparse.ArgumentParser()
parser.add_argument('-c', '--config', type=str, default="./configs/48k.json", help='JSON file for configuration')
parser.add_argument('-a', '--accelerator', type=str, default="gpu", help='training device')
parser.add_argument('-d', '--device', type=str, default="6", help='training device ids')
parser.add_argument('-n', '--num-nodes', type=int, default=1, help='training node number')
args = parser.parse_args()
hparams = get_hparams(args.config)
pl.utilities.seed.seed_everything(hparams.train.seed)
devices = [int(n.strip()) for n in args.device.split(",")]
trainer_params = get_train_params(args, hparams)
# data
train_dataset = MelDataset(hparams.data.training_files, hparams.data)
valid_dataset = MelDataset(hparams.data.validation_files, hparams.data)
collate_fn = MelCollate()
if "strategy" in trainer_params and trainer_params["strategy"] == "ddp":
batch_per_gpu = hparams.train.batch_size // len(devices)
else:
batch_per_gpu = hparams.train.batch_size
train_loader = DataLoader(train_dataset, batch_size=batch_per_gpu, num_workers=8, shuffle=True, pin_memory=True, collate_fn=collate_fn)
valid_loader = DataLoader(valid_dataset, batch_size=4, num_workers=4, shuffle=False, pin_memory=True, collate_fn=collate_fn)
# model
model = Avocodo(**hparams)
# profiler = AdvancedProfiler(filename="profile.txt")
trainer = pl.Trainer(**trainer_params) # , profiler=profiler, max_steps=200
# resume training
ckpt_path = last_checkpoint(hparams.trainer.default_root_dir)
trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=valid_loader, ckpt_path=ckpt_path)
if __name__ == "__main__":
main()