-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathinfer.py
168 lines (136 loc) · 4.55 KB
/
infer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import argparse
from typing import List
import os
import torch
from torch.utils.data import DataLoader
from torch.nn import Module
from config.default import get_cfg_from_file
from dataset import get_dataloader
from models import get_model
from models.models_utils import (
rename_ordered_dict_from_parallel,
rename_ordered_dict_to_parallel,
)
from train_utils import load_checkpoint
from utils.utilities import get_gpu_count
from utils.infer_utils import generate_outputs, prepare_raster_for_inference
from utils.io_utils import get_lines_from_txt
def parser(argv=None):
    """Parse the command-line arguments of the inference script.

    Args:
        argv (List[str], optional): Arguments to parse. When ``None`` (the
            default) argparse falls back to ``sys.argv[1:]``, which matches
            the original behaviour of this function.

    Returns:
        argparse.Namespace: Parsed arguments with the attributes ``cfg``,
        ``checkpoint``, ``samples_list``, ``destination`` and ``outputs``.
    """
    # Named ``arg_parser`` so the local does not shadow this function's name.
    arg_parser = argparse.ArgumentParser(
        description="Run inference with a trained model"
    )
    arg_parser.add_argument(
        "--cfg",
        help="Path to the config file defining testing",
        type=str,
        default="/data/land_cover_tracking/config/weighted_loss.yml",
    )
    arg_parser.add_argument(
        "--checkpoint",
        help="Path to the model checkpoint (weights) to load",
        type=str,
        default="/data/land_cover_tracking/weights/cfg_weighted_loss_best_f1.pth",
    )
    arg_parser.add_argument(
        "--samples_list",
        help="Path to a text file listing rasters for inference, "
        "or one of 'train', 'val', 'test'",
        type=str,
        default="test",
    )
    arg_parser.add_argument(
        "--destination",
        help="Path for saving results",
        type=str,
        default="/data/seg_data/inference",
    )
    arg_parser.add_argument(
        "--outputs",
        nargs="+",
        default=["raster"],
        help="What kind of outputs to generate "
        + "from ['alphablend','raster','alphablended_raster', 'raw_raster']",
    )
    return arg_parser.parse_args(argv)
def infer(
    model: Module,
    dataloader: DataLoader,
    output_types: List[str],
    destination: str,
):
    """Run the model over every batch and hand each prediction to ``generate_outputs``.

    Gradients are disabled and the model is switched to eval mode. Each batch
    is read as a mapping with an ``"input"`` entry and a ``"name"`` entry. The
    predicted mask is the argmax over dim 1 of the model's ``"out"`` output.
    Every (input image, mask, name) triple is then passed to
    ``generate_outputs``, together with the output types, the destination,
    the dataset's ``mask_config`` and the dataloader.

    Args:
        model (Module): Model to use for inference. Calling it must return a
            mapping that has an ``"out"`` entry.
        dataloader (DataLoader): Dataloader for inference. Its ``dataset``
            must expose a ``mask_config`` attribute.
        output_types (List[str]): Output types forwarded to
            ``generate_outputs``. The CLI advertises 'alphablend', 'raster',
            'alphablended_raster' and 'raw_raster'.
        destination (str): Path where results are saved; forwarded to
            ``generate_outputs``.

    Returns:
        None: all results are produced by ``generate_outputs``.
    """
    with torch.no_grad():
        model.eval()
        mask_config = dataloader.dataset.mask_config
        for batch in dataloader:
            inputs, names = batch["input"], batch["name"]
            # NOTE(review): inputs are not moved to the model's device here;
            # presumably the dataset/dataloader already places them — confirm.
            outputs = model(inputs)["out"]
            # Assumes outputs are (batch, classes, H, W), so dim 1 is the
            # class axis — TODO confirm against the model definitions.
            masks = torch.argmax(outputs, dim=1)
            for input_img, mask, name in zip(inputs, masks, names):
                generate_outputs(
                    output_types,
                    destination,
                    input_img,
                    mask,
                    name,
                    mask_config,
                    dataloader,
                )
def _load_model(cfg, checkpoint: str, device):
    """Build the model from ``cfg`` and load the weights stored in ``checkpoint``."""
    _, weights, _, _, _ = load_checkpoint(checkpoint, device)
    model = get_model(cfg, device)
    # The rename helpers adapt state-dict keys between the parallel-wrapped
    # and plain model layouts. Assumes the model is parallel-wrapped exactly
    # when more than one GPU is used (so 0 GPUs, i.e. CPU, counts as plain)
    # — confirm against get_model.
    train_parallel = get_gpu_count(cfg, mode="train") > 1
    test_parallel = get_gpu_count(cfg, mode="test") > 1
    if train_parallel and not test_parallel:
        weights = rename_ordered_dict_from_parallel(weights)
    elif test_parallel and not train_parallel:
        weights = rename_ordered_dict_to_parallel(weights)
    model.load_state_dict(weights)
    return model


def _prepare_samples_list(cfg, samples_list_path: str) -> str:
    """Return the samples-list argument for ``get_dataloader``, cropping custom rasters first.

    The split names 'train', 'val' and 'test' are returned unchanged. Any
    other value is read as a text file of raster paths. Each raster is passed
    to ``prepare_raster_for_inference`` with ``crop_size=[256, 256]``. The
    returned crop paths are written, one per line, to
    ``cfg.TEST.INFER_SAMPLES_LIST_PATH``, and that path is returned.
    """
    if samples_list_path in ("train", "val", "test"):
        return samples_list_path
    samples_to_infer = []
    for sample_path in get_lines_from_txt(samples_list_path):
        samples_to_infer.extend(
            prepare_raster_for_inference(sample_path, crop_size=[256, 256])
        )
    infer_list_path = cfg.TEST.INFER_SAMPLES_LIST_PATH
    with open(infer_list_path, "w") as f:
        f.writelines(f"{path}\n" for path in samples_to_infer)
    return infer_list_path


def run_infer(
    cfg_path: str,
    checkpoint: str,
    samples_list_path: str,
    destination: str,
    output_types: List[str],
):
    """Load a trained model from a checkpoint and run inference on a set of samples.

    Args:
        cfg_path (str): Path to the config file; ``cfg.TEST`` supplies the
            device, the worker count and the path of the generated samples list.
        checkpoint (str): Path to the checkpoint holding the model weights.
        samples_list_path (str): One of 'train', 'val', 'test' (passed straight
            to ``get_dataloader``), or a path to a text file listing rasters.
            Those rasters are cropped and their crop list is written to
            ``cfg.TEST.INFER_SAMPLES_LIST_PATH``.
        destination (str): Directory for results; created if it is missing.
        output_types (List[str]): Output types forwarded to ``infer``.
    """
    cfg = get_cfg_from_file(cfg_path)
    device = cfg.TEST.DEVICE
    if cfg.TEST.WORKERS > 0:
        # force=True overrides a start method that may already be set;
        # presumably 'spawn' is needed because workers touch CUDA.
        torch.multiprocessing.set_start_method("spawn", force=True)
    model = _load_model(cfg, checkpoint, device)
    dataloader = get_dataloader(cfg, _prepare_samples_list(cfg, samples_list_path))
    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(destination, exist_ok=True)
    infer(
        model,
        dataloader,
        output_types,
        destination,
    )
if __name__ == "__main__":
    # Script entry point: read the CLI flags and forward them to run_infer.
    cli_args = parser()
    run_infer(
        cfg_path=cli_args.cfg,
        checkpoint=cli_args.checkpoint,
        samples_list_path=cli_args.samples_list,
        destination=cli_args.destination,
        output_types=cli_args.outputs,
    )