# train.py
import torch
import argparse
from omegaconf import OmegaConf
from torch.utils.data.dataloader import DataLoader
from transformers import AutoTokenizer, DataCollatorWithPadding
from datasets import load_from_disk, load_metric
import dataloader as DataProcess
import trainer as Trainer
import model as Model
import torch.optim as optim
import utils.loss as Criterion
import utils.metric as Metric
from utils.check_dir import check_dir
from utils.wandb_setting import wandb_setting
from utils.seed_setting import seed_setting
from utils.AIhub_data_add import AIhub_data_add

def main(config):
    seed_setting(config.train.seed)
    assert torch.cuda.is_available(), "No GPU is available."
    device = torch.device('cuda')

    print('='*50, f'Preprocessing class currently in use: {config.data.preprocess}', '='*50, sep='\n\n')
    tokenizer = AutoTokenizer.from_pretrained(config.model.model_name, use_fast=True)
    # T5 is generative, so its preprocessing also needs the maximum answer length.
    if 't5' in config.model.model_name:
        prepare_features = getattr(DataProcess, config.data.preprocess)(tokenizer, config.train.max_length, config.train.max_answer_length, config.train.stride)
    else:
        prepare_features = getattr(DataProcess, config.data.preprocess)(tokenizer, config.train.max_length, config.train.stride)
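    # For reference, a minimal sketch of the interface this script expects from
    # a preprocess class in dataloader.py (the class name and column names below
    # are illustrative assumptions -- the project's real classes may differ):
    #
    #   class ExamplePreprocess:
    #       def __init__(self, tokenizer, max_length, stride):
    #           self.tokenizer = tokenizer
    #           self.max_length = max_length
    #           self.stride = stride
    #
    #       def train(self, examples):  # called via datasets.map(batched=True)
    #           return self.tokenizer(
    #               examples['question'], examples['context'],
    #               truncation='only_second',
    #               max_length=self.max_length,
    #               stride=self.stride,
    #               return_overflowing_tokens=True,
    #               return_offsets_mapping=True,
    #           )  # start/end position labeling elided
    #
    #       def valid(self, examples):  # same, but also keeps example_id
    #           ...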
    # Data augmentation: optionally extend the train set with AI Hub data.
    if config.data.get('AIhub_data_add'):
        train_data = AIhub_data_add(config.data.train_path)
    else:
        train_data = load_from_disk(config.data.train_path)
    valid_data = load_from_disk(config.data.val_path)

    # Tokenize the raw datasets with the selected preprocessing class.
    train_dataset = train_data.map(
        prepare_features.train,
        batched=True,
        num_proc=4,
        remove_columns=train_data.column_names,
        load_from_cache_file=True,
    )
    valid_dataset = valid_data.map(
        prepare_features.valid,
        batched=True,
        num_proc=4,
        remove_columns=valid_data.column_names,
        load_from_cache_file=True,
    )
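    # The mapped validation features keep example_id and offset_mapping so that
    # predicted token spans can be mapped back to answer text in the raw data;
    # the remaining columns are the tokenized model inputs.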
    # Pass both the original (raw) data and the mapped dataset to the metric class.
    if 't5' in config.model.model_name:
        metric = getattr(Metric, config.model.metric_class)(
            metric=load_metric('squad'),
            dataset=valid_dataset,
            raw_data=valid_data,
            n_best_size=config.train.n_best_size,
            max_answer_length=config.train.max_answer_length,
            save_dir=config.save_dir,
            mode='train',
            tokenizer=tokenizer,
        )
    else:
        metric = getattr(Metric, config.model.metric_class)(
            metric=load_metric('squad'),
            dataset=valid_dataset,
            raw_data=valid_data,
            n_best_size=config.train.n_best_size,
            max_answer_length=config.train.max_answer_length,
            save_dir=config.save_dir,
            mode='train',
        )
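    # For reference, the underlying load_metric('squad') scorer computes exact
    # match and F1 from prediction/reference dicts, e.g.:
    #
    #   squad = load_metric('squad')
    #   squad.compute(
    #       predictions=[{'id': '56', 'prediction_text': 'Seoul'}],
    #       references=[{'id': '56',
    #                    'answers': {'text': ['Seoul'], 'answer_start': [0]}}],
    #   )  # -> {'exact_match': 100.0, 'f1': 100.0}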
    train_dataset.set_format("torch")
    # example_id and offset_mapping cannot be padded by the collator, and the
    # metric object above already keeps the full validation dataset.
    valid_dataset = valid_dataset.remove_columns(["example_id", "offset_mapping"])
    valid_dataset.set_format("torch")

    data_collator = DataCollatorWithPadding(tokenizer)  # pads dynamically per batch
    train_dataloader = DataLoader(train_dataset, batch_size=config.train.batch_size, collate_fn=data_collator, pin_memory=True, shuffle=True)
    valid_dataloader = DataLoader(valid_dataset, batch_size=config.train.batch_size, collate_fn=data_collator, pin_memory=True, shuffle=False)

    # Load the model architecture.
    print('='*50, f'Model class currently in use: {config.model.model_class}', '='*50, sep='\n\n')
    model = getattr(Model, config.model.model_class)(
        model_name=config.model.model_name,
        num_labels=2,
        dropout_rate=config.train.dropout_rate,
    ).to(device)

    criterion = getattr(Criterion, config.model.loss)
    optimizer = getattr(optim, config.model.optimizer)(model.parameters(), lr=config.train.learning_rate)
    lr_scheduler = None
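    # A learning-rate scheduler is intentionally left disabled. A minimal sketch
    # of wiring one in (assumption: the trainer steps the scheduler itself --
    # check trainer.py before enabling):
    #
    #   lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config.train.max_epoch)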
    epochs = config.train.max_epoch
    save_dir = check_dir(config.save_dir)

    print('='*50, f'Trainer class currently in use: {config.model.trainer_class}', '='*50, sep='\n\n')
    # The T5 (generative) trainer additionally needs the tokenizer and the
    # maximum answer length for decoding.
    if 't5' in config.model.model_name:
        trainer = getattr(Trainer, config.model.trainer_class)(
            model=model,
            criterion=criterion,
            metric=metric,
            optimizer=optimizer,
            device=device,
            save_dir=save_dir,
            train_dataloader=train_dataloader,
            valid_dataloader=valid_dataloader,
            lr_scheduler=lr_scheduler,
            epochs=epochs,
            tokenizer=tokenizer,
            max_answer_length=config.train.max_answer_length,
        )
    else:
        trainer = getattr(Trainer, config.model.trainer_class)(
            model=model,
            criterion=criterion,
            metric=metric,
            optimizer=optimizer,
            device=device,
            save_dir=save_dir,
            train_dataloader=train_dataloader,
            valid_dataloader=valid_dataloader,
            lr_scheduler=lr_scheduler,
            epochs=epochs,
        )

    ## Configure wandb here. To run a sweep, set sweep=True.
    ## For detailed sweep settings, edit utils/wandb_setting.py.
    wandb_setting(config)
    trainer.train()

if __name__ == '__main__':
    torch.cuda.empty_cache()
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='baseline')
    args, _ = parser.parse_known_args()

    ## ex) python3 train.py --config baseline
    config = OmegaConf.load(f'./configs/{args.config}.yaml')
    print(f'Number of available GPUs: {torch.cuda.device_count()}')
    main(config)
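
# For reference, a sketch of the config shape this script reads from
# ./configs/<name>.yaml. The keys are exactly the ones accessed above; every
# value is an illustrative assumption, not the project's actual baseline:
#
#   data:
#     preprocess: ...        # class name in dataloader.py
#     train_path: ...
#     val_path: ...
#     AIhub_data_add: false  # optional; enables AI Hub augmentation
#   model:
#     model_name: ...        # HF checkpoint name
#     model_class: ...       # class name in model.py
#     metric_class: ...      # class name in utils/metric.py
#     trainer_class: ...     # class name in trainer.py
#     loss: ...              # function name in utils/loss.py
#     optimizer: AdamW       # class name in torch.optim
#   train:
#     seed: 42
#     max_length: ...
#     max_answer_length: ...
#     stride: ...
#     n_best_size: ...
#     batch_size: ...
#     dropout_rate: ...
#     learning_rate: ...
#     max_epoch: ...
#   save_dir: ...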