# util.py
import logging
import threading

import torch
import torch.nn as nn
from torch._utils import ExceptionWrapper

def get_a_var(obj):
    """Return the first torch.Tensor found in ``obj`` (searching lists,
    tuples, and dicts recursively), or None if it contains no tensor."""
    if isinstance(obj, torch.Tensor):
        return obj
    if isinstance(obj, (list, tuple)):
        for result in map(get_a_var, obj):
            if isinstance(result, torch.Tensor):
                return result
    if isinstance(obj, dict):
        for result in map(get_a_var, obj.items()):
            if isinstance(result, torch.Tensor):
                return result
    return None

def parallel_apply(fct, model, inputs, device_ids):
    """Replicate ``model`` onto ``device_ids`` and run ``fct(replica, *input)``
    for each replica in its own thread, returning the outputs in input order."""
    modules = nn.parallel.replicate(model, device_ids)
    assert len(modules) == len(inputs)
    lock = threading.Lock()
    results = {}
    grad_enabled = torch.is_grad_enabled()

    def _worker(i, module, input):
        # Propagate the caller's grad mode into the worker thread.
        torch.set_grad_enabled(grad_enabled)
        device = get_a_var(input).get_device()
        try:
            with torch.cuda.device(device):
                # this also avoids accidental slicing of `input` if it is a Tensor
                if not isinstance(input, (list, tuple)):
                    input = (input,)
                output = fct(module, *input)
            with lock:
                results[i] = output
        except Exception:
            with lock:
                results[i] = ExceptionWrapper(
                    where="in replica {} on device {}".format(i, device))

    if len(modules) > 1:
        threads = [threading.Thread(target=_worker, args=(i, module, input))
                   for i, (module, input) in enumerate(zip(modules, inputs))]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        _worker(0, modules[0], inputs[0])

    outputs = []
    for i in range(len(inputs)):
        output = results[i]
        if isinstance(output, ExceptionWrapper):
            # Re-raise an exception captured in a worker thread.
            output.reraise()
        outputs.append(output)
    return outputs
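
# A minimal usage sketch (not part of the original utilities): with two GPUs,
# a call might look like the lines below, where `model` and the inputs are
# hypothetical placeholders rather than names defined in this file.
#
#   model = nn.Linear(4, 2).cuda(0)
#   inputs = [torch.randn(3, 4).cuda(0), torch.randn(3, 4).cuda(1)]
#   outputs = parallel_apply(lambda m, x: m(x), model, inputs, [0, 1])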

logger_initialized = {}


def get_logger(filename=None):
    """Create, or return the cached, "seg" logger. Prefers mmcv's get_logger
    with colored console output and falls back to the standard logging module
    if mmcv or termcolor is unavailable."""
    if "seg" in logger_initialized:
        return logger_initialized["seg"]

    assert filename is not None
    try:
        from mmcv.utils import get_logger as get_mmcv_logger
        from termcolor import colored
        fmt = '[%(asctime)s %(name)s] (%(filename)s %(lineno)d): %(levelname)s %(message)s'
        color_fmt = colored('[%(asctime)s %(name)s]', 'green') \
            + colored(' (%(filename)s %(lineno)d)', 'yellow') + ': %(levelname)s %(message)s'
        logger = get_mmcv_logger("seg", log_file=filename, log_level=logging.INFO, file_mode='a')
        # Colored format on the console, plain format in the log file.
        for handler in logger.handlers:
            if isinstance(handler, logging.StreamHandler):
                handler.setFormatter(logging.Formatter(fmt=color_fmt, datefmt='%Y-%m-%d %H:%M:%S'))
            if isinstance(handler, logging.FileHandler):
                handler.setFormatter(logging.Formatter(fmt=fmt, datefmt='%Y-%m-%d %H:%M:%S'))
    except Exception:
        # Fallback: plain stdlib logging to the console plus an optional file handler.
        logger = logging.getLogger("seg")
        logger.setLevel(logging.DEBUG)
        logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s',
                            datefmt='%m/%d/%Y %H:%M:%S',
                            level=logging.INFO)
        if filename is not None:
            handler = logging.FileHandler(filename)
            handler.setLevel(logging.DEBUG)
            handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s'))
            logging.getLogger().addHandler(handler)

    logger_initialized["seg"] = logger
    return logger
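

if __name__ == "__main__":
    # Hedged smoke test, not part of the original module: the file name
    # "example.log" and the tiny Linear model below are illustrative only.
    log = get_logger("example.log")
    log.info("logger initialized")

    # parallel_apply expects one input per device; only try it when at least
    # two CUDA devices are visible.
    if torch.cuda.device_count() >= 2:
        device_ids = [0, 1]
        model = nn.Linear(4, 2).cuda(device_ids[0])
        inputs = [torch.randn(3, 4).cuda(d) for d in device_ids]
        outputs = parallel_apply(lambda m, x: m(x), model, inputs, device_ids)
        log.info("parallel_apply returned %d outputs", len(outputs))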