-
Notifications
You must be signed in to change notification settings - Fork 12
/
main.py
53 lines (46 loc) · 2.6 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from data import DEFAULT_VOCAB_PATHS, create_dataloaders
from evaluators import get_evaluator_cls
from evaluators.visualizers import RecallVisualizer
from language import vocabulary_factory
from loggers.file_loggers import BestModelTracker
from loggers.wandb_loggers import WandbSimplePrinter, WandbSummaryPrinter
from losses import loss_factory
from models import create_models
from optimizers import create_optimizers, create_lr_schedulers
from options import get_experiment_config
from set_up import setup_experiment
from trainers import get_trainer_cls
from transforms import image_transform_factory, text_transform_factory
def main():
configs = get_experiment_config()
export_root, configs = setup_experiment(configs)
vocabulary = vocabulary_factory(config={
'vocab_path': configs['vocab_path'] if configs['vocab_path'] else DEFAULT_VOCAB_PATHS[configs['dataset']],
'vocab_threshold': configs['vocab_threshold']
})
image_transform = image_transform_factory(config=configs)
text_transform = text_transform_factory(config={'vocabulary': vocabulary})
train_dataloader, test_dataloaders, train_val_dataloaders = create_dataloaders(image_transform, text_transform,
configs)
criterions = loss_factory(configs)
models = create_models(configs, vocabulary)
optimizers = create_optimizers(models=models, config=configs)
lr_schedulers = create_lr_schedulers(optimizers, config=configs)
train_loggers = [WandbSimplePrinter('train/')]
summary_keys = get_summary_keys(configs)
best_metric_key = [key for key in summary_keys if '10' in key][0]
val_loggers = [WandbSimplePrinter('val/'), WandbSummaryPrinter('best_', summary_keys),
BestModelTracker(export_root, metric_key=best_metric_key)]
visualizer = RecallVisualizer(test_dataloaders)
evaluator = get_evaluator_cls(configs)(models, test_dataloaders, top_k=configs['topk'], visualizer=visualizer)
train_evaluator = get_evaluator_cls(configs)(models, train_val_dataloaders, top_k=configs['topk'],
visualizer=None)
trainer = get_trainer_cls(configs)(models, train_dataloader, criterions, optimizers, lr_schedulers,
configs['epoch'], train_loggers, val_loggers, evaluator, train_evaluator,
start_epoch=0)
trainer.run()
def get_summary_keys(configs):
summary_keys = ['recall_@{}'.format(k) for k in configs['topk'].split(",")]
return summary_keys
if __name__ == '__main__':
main()