train.py
import argparse
import os
from dataclasses import dataclass

import torch
from torch import nn
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

from model import Model, default_model_config
from utils.data import build_datasets_and_tokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
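

# Hyperparameters for a single training run.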
@dataclass
class TrainConfig:
    batch_size: int
    learning_rate: float
    epochs: int
    block_size: int

default_train_config = TrainConfig(
    batch_size=1024,
    learning_rate=3e-4,
    epochs=2,
    block_size=128,
)
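

# Training loop: one AdamW step per batch; prints the mean batch loss after each epoch.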
def train(model: Model, dataset: Dataset, config: TrainConfig):
    dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=False)
    optimizer = AdamW(model.transformer.parameters(), lr=config.learning_rate)
    loss_fn = nn.CrossEntropyLoss()
    for epoch in range(config.epochs):
        i = 0
        epoch_loss = 0.0
        for X, y in tqdm(dataloader):
            optimizer.zero_grad()
            y_pred = model.transformer(X)
            # Logits are (batch, seq, vocab); CrossEntropyLoss expects the class
            # dimension second, so permute to (batch, vocab, seq) before scoring.
            loss = loss_fn(y_pred.permute(0, 2, 1), y)
            loss.backward()
            optimizer.step()
            i += 1
            epoch_loss += loss.item()
        print(f"Epoch {epoch + 1} / {config.epochs} - Average Loss: {epoch_loss / i}")
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data_dir", type=str, default="data", help="Directory containing the dataset"
    )
    parser.add_argument(
        "--model_dir", type=str, default="models", help="Directory containing the model"
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="outputs",
        help="Directory to write generated outputs to",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=default_train_config.batch_size,
        help="Batch size for training",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=default_train_config.learning_rate,
        help="Learning rate for training",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=default_train_config.epochs,
        help="Number of training epochs",
    )
    parser.add_argument(
        "--block_size",
        type=int,
        default=default_train_config.block_size,
        help="Context (block) size for training",
    )
    parser.add_argument(
        "--from_checkpoint",
        type=str,
        default=None,
        help="Path to a model checkpoint to resume training from",
    )
    args = parser.parse_args()
    train_config = TrainConfig(
        batch_size=args.batch_size,
        learning_rate=args.learning_rate,
        epochs=args.epochs,
        block_size=args.block_size,
    )

    print(f"Data directory: {args.data_dir}")
    print(f"Model directory: {args.model_dir}")
    print(f"Output directory: {args.output_dir}")
    os.makedirs(args.data_dir, exist_ok=True)
    os.makedirs(args.model_dir, exist_ok=True)
    os.makedirs(args.output_dir, exist_ok=True)
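
    # Build train/test splits and the tokenizer from the raw data directory.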
    train_ds, test_ds, tokenizer = build_datasets_and_tokenizer(
        args.data_dir, train_config.block_size, device=device
    )
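
    # Resume from a full-model checkpoint if one was given; otherwise start fresh.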
    if args.from_checkpoint:
        model = torch.load(args.from_checkpoint)
        assert tokenizer == model.tokenizer, "Tokenizer changed since last checkpoint"
    else:
        model = Model(tokenizer, default_model_config)
    model.transformer.to(device=device)

    train(model, train_ds, train_config)

    model_save_path = os.path.join(args.model_dir, "model.pt")
    torch.save(model, model_save_path)
    print(f"Model saved to {model_save_path}")
    while True:
        input_text = input("Enter a prompt: ")
        output_text = model.generate(input_text)
        print(f"Completion: {output_text}")
        print()