-
Notifications
You must be signed in to change notification settings - Fork 4
/
factory.py
126 lines (93 loc) · 3.52 KB
/
factory.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import torch
from torch.utils.data import DataLoader, Subset, SubsetRandomSampler

import albumentations as albu
import numpy as np

from utils.utils import object_from_dict
def create_backbone(cfg):
    """Instantiate the backbone network described by the backbone config."""
    return object_from_dict(cfg)
def create_model(cfg):
    """Build the full model: backbone plus wrapper, pooled to the configured input size."""
    backbone = create_backbone(cfg.model.backbone)
    return object_from_dict(cfg.model,
                            backbone=backbone,
                            pool_shape=cfg.data.input_size)
def create_optimizer(cfg, model: torch.nn.Module):
    """Create an optimizer over the model's trainable (requires_grad) parameters only."""
    trainable_params = (p for p in model.parameters() if p.requires_grad)
    return object_from_dict(cfg.optimizer, params=trainable_params)
def create_scheduler(cfg, optimizer: torch.optim.Optimizer):
    """Instantiate the learning-rate scheduler bound to *optimizer*."""
    return object_from_dict(cfg.scheduler, optimizer=optimizer)
def create_loss(cfg):
    """Instantiate the loss function described by cfg.loss."""
    return object_from_dict(cfg.loss)
def create_train_dataloader(cfg):
    """Build one dataloader entry per configured training dataset, keyed by its name."""
    loaders = {}
    for ds_cfg in cfg.data.train_dataset:
        entry = create_dataloader(ds_cfg, create_dataset(ds_cfg))
        loaders[entry['name']] = entry
    return loaders
def create_val_dataloader(cfg):
    """Build one dataloader entry per configured validation dataset, keyed by its name."""
    loaders = {}
    for ds_cfg in cfg.data.validation_dataset:
        entry = create_dataloader(ds_cfg, create_dataset(ds_cfg))
        loaders[entry['name']] = entry
    return loaders
def create_dataset(cfg):
    """Assemble a few-shot dataset (query + support branches) from its config.

    Collects constructor kwargs from *cfg* and instantiates the dataset class
    named by cfg.type via object_from_dict.
    """
    params = {
        'type': cfg.type,
        'q_root': cfg.query.root,
        's_root': cfg.support.root,
        'q_ann_filename': cfg.query.annotations,
        's_ann_filename': cfg.support.annotations,
        'k_shot': cfg.k_shot,
        'q_img_size': cfg.input_size,
        'backbone_stride': cfg.backbone_stride,
        # two separately constructed (identically configured) pipelines,
        # one for query images and one for support images
        'q_transform': create_augmentations(cfg.transforms),
        's_transform': create_augmentations(cfg.transforms),
    }
    return object_from_dict(params)
def create_dataloader(cfg, dataset):
    """Wrap *dataset* in a DataLoader according to *cfg*.

    When cfg.len is truthy the epoch is limited to cfg.len samples:
    a random subset (without replacement) if cfg.shuffle is set, otherwise
    the first cfg.len samples in their original order.

    Bug fixed: the original fed a sequential arange into SubsetRandomSampler
    when cfg.shuffle was False, so a "non-shuffled" truncated dataset was
    still iterated in random order. Sequential truncation now uses Subset.

    Returns a dict with keys 'name', 'dataloader', 'draw'.
    """
    shuffle = cfg.shuffle
    sampler = None
    if cfg.len:
        if shuffle:
            # Random subset without replacement. DataLoader forbids
            # shuffle=True together with a sampler, so disable shuffle here.
            idx = np.random.choice(len(dataset), cfg.len, replace=False)
            sampler = SubsetRandomSampler(indices=idx)
            shuffle = False
        else:
            # Deterministic prefix of the dataset, kept in sequential order.
            dataset = Subset(dataset, range(cfg.len))
    collate_fn = object_from_dict(cfg.collate_fn)
    dataloader = DataLoader(dataset, batch_size=cfg.bs, shuffle=shuffle,
                            sampler=sampler, collate_fn=collate_fn)
    return {
        'name': cfg.name,
        'dataloader': dataloader,
        'draw': cfg.draw,
    }
def create_metrics(cfg):
    """Instantiate every metric listed in cfg.metrics."""
    return [object_from_dict(metric_cfg) for metric_cfg in cfg.metrics]
def create_device(cfg):
    """Resolve the configured compute device string (e.g. 'cpu', 'cuda:0')."""
    return torch.device(cfg.device)
def create_callbacks(cfg, trainer):
    """Instantiate each configured hook and register it on *trainer*."""
    for hook_cfg in cfg.hooks:
        trainer.register_callback(object_from_dict(hook_cfg))
def create_augmentations(cfg):
    """Compose the configured albumentations transforms.

    Bounding boxes are handled in COCO format, with their category labels
    carried in the 'bboxes_cats' field.
    """
    augmentations = [object_from_dict(augm_cfg) for augm_cfg in cfg]
    bbox_params = albu.BboxParams(format='coco', label_fields=['bboxes_cats'])
    return albu.Compose(augmentations, bbox_params=bbox_params)