-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathpredict.py
105 lines (77 loc) · 2.38 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import logging
import os
import warnings
from datetime import datetime
import torch
from munch import munchify, unmunchify
from tqdm import tqdm
from yaml import safe_load
import wandb
from loggers.logs import log_predictions
from models.unet import UNET
from utils.utils import (
get_device,
get_loaders,
get_loss_function,
get_metrics,
get_transforms,
load_checkpoint,
)
# Silence every warning category process-wide as soon as this module is imported.
warnings.simplefilter("ignore")
def predict_fn(test_loader, model, loss_fn, global_metrics, label_metrics, config):
    """Evaluate ``model`` on every batch of ``test_loader`` and log the predictions.

    For each ``(data, targets)`` batch: move it to the device chosen by
    ``get_device(config)``, run the model and ``loss_fn`` without gradient
    tracking, show the loss on the tqdm progress bar, and hand the batch to
    ``log_predictions`` together with the metric objects, ``config`` and the
    batch index.

    Args:
        test_loader: Iterable yielding ``(data, targets)`` batches.
        model: Model to evaluate; it is switched to eval mode here.
        loss_fn: Callable ``loss_fn(predictions, targets)`` returning a
            single-element tensor (``.item()`` is called on it).
        global_metrics: Forwarded unchanged to ``log_predictions``.
        label_metrics: Forwarded unchanged to ``log_predictions``.
        config: Project config; used to select the device and forwarded to
            ``log_predictions``.

    Returns:
        The loss of the last batch as a Python float, or ``float("nan")`` when
        ``test_loader`` yields no batches.
    """
    device = get_device(config)
    loop = tqdm(test_loader)
    model.eval()

    # Pre-initialised so an empty loader returns NaN instead of raising
    # UnboundLocalError on the final return.
    last_loss = float("nan")
    num_batches = 0

    for idx, (data, targets) in enumerate(loop):
        data = data.to(device)
        # Cast to int64 before the loss; presumably class-index labels — confirm
        # against the configured loss function.
        targets = targets.long().to(device)

        with torch.no_grad():
            predictions = model(data)
            loss = loss_fn(predictions, targets)

        # One .item() per batch: each call forces a device -> host sync.
        last_loss = loss.item()
        num_batches += 1

        # update tqdm loop
        loop.set_postfix(loss=last_loss)

        log_predictions(
            data,
            targets,
            predictions,
            global_metrics,
            label_metrics,
            config,
            idx,
        )

    if num_batches == 0:
        logging.warning("predict_fn: test_loader yielded no batches; returning NaN loss")

    return last_loss
def main(config):
    """Run prediction with a saved UNET checkpoint and log the results to W&B.

    Starts a W&B run seeded with ``config.hyperparameters``, then replaces
    ``config.hyperparameters`` with ``munchify(wandb.config)``. It builds the
    test loader and loads the UNET weights from ``config.load.path``. It then
    runs ``predict_fn`` over the test set and finally closes the W&B run.

    Args:
        config: Munch config. This function reads ``wandb.project_name``,
            ``wandb.project_team``, ``hyperparameters`` and ``load.path``; the
            whole object is also forwarded to the project helpers.
    """
    device = get_device(config)
    logging.info("predict")

    wandb.init(
        project=config.wandb.project_name,
        entity=config.wandb.project_team,
        config=unmunchify(config.hyperparameters),
    )
    # Read the hyperparameters back from W&B; presumably so values injected by
    # W&B (e.g. a sweep) take effect — confirm.
    config.hyperparameters = munchify(wandb.config)

    _, _, test_loader = get_loaders(config, *get_transforms(config))

    model = UNET(config).to(device)
    # Wrapped before loading; presumably the checkpoint's state-dict keys carry
    # the DataParallel "module." prefix — confirm against the training script.
    model = torch.nn.DataParallel(model)

    # map_location stages the checkpoint tensors on the current device, so a
    # checkpoint saved from GPU tensors can still be loaded on a CPU-only host.
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    checkpoint = torch.load(config.load.path, map_location=device)
    load_checkpoint(checkpoint, model, optimizer=None, scheduler=None)

    loss_fn = get_loss_function(config)
    global_metrics, label_metrics = get_metrics(config)

    predict_fn(
        test_loader,
        model,
        loss_fn,
        global_metrics,
        label_metrics,
        config,
    )

    wandb.finish()
if __name__ == "__main__":
    # Script entry point: configure logging, load config_prediction.yaml from
    # the current working directory, select the W&B mode and run prediction.
    logging.getLogger().setLevel(logging.INFO)
    torch.cuda.empty_cache()
    # Anomaly detection is intentionally not enabled: it is a debug-only
    # autograd instrumentation, and this script never runs a backward pass.

    # Explicit encoding so the YAML is parsed the same regardless of the locale.
    with open("config_prediction.yaml", encoding="utf-8") as f:
        config = munchify(safe_load(f))

    # Set before main() calls wandb.init(), so the chosen mode applies to the run.
    os.environ["WANDB_MODE"] = "online" if config.wandb.online else "offline"
    config.project.time = datetime.now().strftime("%Y%m%d_%H%M%S")

    main(config)