-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
75 lines (61 loc) · 1.74 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import torch
from torch.utils.data import DataLoader
from models.memeclip import MemeCLIP
from data.dataset import MemeDataset
from data.collator import MemeCollator
from train import Trainer
from configs import cfg
def _build_loaders(cfg):
    """Create the train, validation and test DataLoaders described by ``cfg``.

    Args:
        cfg: Project configuration object. This helper reads ``root_dir`` and
            ``batch_size`` from it and passes it to ``MemeDataset`` and
            ``MemeCollator``.

    Returns:
        A ``(train_loader, val_loader, test_loader)`` tuple. Only the training
        loader shuffles; all three share a single ``MemeCollator`` instance.
    """
    # Build every dataset before the collator (same order as the original
    # inline code) in case construction draws from a seeded RNG.
    datasets = {
        split: MemeDataset(cfg, cfg.root_dir, split=split)
        for split in ('train', 'val', 'test')
    }
    collator = MemeCollator(cfg)

    def make_loader(split, shuffle=False):
        # One place for the shared loader settings.
        return DataLoader(
            datasets[split],
            batch_size=cfg.batch_size,
            shuffle=shuffle,
            collate_fn=collator,
        )

    return (
        make_loader('train', shuffle=True),
        make_loader('val'),
        make_loader('test'),
    )


def main():
    """Train MemeCLIP (unless ``cfg.test_only``), then evaluate it on the test split.

    All settings come from the module-level ``cfg`` imported from ``configs``.
    When ``cfg.test_only`` is set, training is skipped. Weights are instead
    restored from the ``'model_state_dict'`` entry of ``cfg.checkpoint_file``
    before the test-set evaluation.
    """
    # Seed torch's RNG before anything that may draw from it
    # (e.g. weight initialisation, DataLoader shuffling).
    torch.manual_seed(cfg.seed)

    train_loader, val_loader, test_loader = _build_loaders(cfg)

    model = MemeCLIP(cfg).to(cfg.device)
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=cfg.lr,
        weight_decay=cfg.weight_decay,
    )
    trainer = Trainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        optimizer=optimizer,
        cfg=cfg,
    )

    if cfg.test_only:
        # map_location remaps stored tensors onto cfg.device, so a checkpoint
        # saved on a GPU can still be loaded on a CPU-only machine.
        checkpoint = torch.load(cfg.checkpoint_file, map_location=cfg.device)
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        # NOTE(review): the test below then uses whatever weights train()
        # leaves in `model`. If Trainer writes a separate "best" checkpoint,
        # consider reloading it before testing; confirm against train.py.
        trainer.train()

    # Evaluate on the test split in both modes, so a training run also
    # reports test metrics.
    trainer.validate_loader(test_loader, "Test")
if __name__ == "__main__":
    # Launch the pipeline only when run as a script, not when imported.
    main()