forked from yejin109/NeRF-pytorch-imple
-
Notifications
You must be signed in to change notification settings - Fork 0
/
functionals.py
85 lines (71 loc) · 2.29 KB
/
functionals.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import copy
import datetime
import functools
import glob
import os
import time

import numpy as np
import torch
def total_grad_norm(parameters, norm_type=2):
    """Return the global p-norm of the gradients of ``parameters``.

    The per-parameter gradient norms are combined as
    ``(sum_i ||g_i||_p ** p) ** (1 / p)``, which equals the p-norm of all
    gradients concatenated into one vector.  Parameters whose ``.grad`` is
    ``None`` are skipped.

    Args:
        parameters: iterable of tensors, e.g. ``model.parameters()``.
        norm_type: order ``p`` of the norm.  ``float('inf')`` yields the
            largest absolute gradient entry across all parameters.

    Returns:
        The norm as a Python float; 0.0 when no parameter has a gradient.
    """
    grads = [p.grad.detach() for p in parameters if p.grad is not None]
    if not grads:
        return 0.0
    if norm_type == float("inf"):
        # x ** inf and total ** (1 / inf) degenerate to 0/1/inf and 1.0, so the
        # inf-norm must be combined with max instead of a power sum.
        return max(g.norm(norm_type).item() for g in grads)
    total_norm = sum(g.norm(norm_type).item() ** norm_type for g in grads)
    return total_norm ** (1.0 / norm_type)
def log_train(*args):
    """Append one comma-separated row of training values to ``$LOG_DIR/loss.txt``.

    Expected positional order: iter_i, loss, psnr, grad_norm.  Any
    ``torch.Tensor`` (including subclasses such as ``nn.Parameter``) is
    converted with ``.item()``, so it must hold exactly one element.  Every
    other value is formatted with ``str()``.

    Raises:
        KeyError: if the ``LOG_DIR`` environment variable is not set.
    """
    # isinstance (not type(...) ==) so Tensor subclasses are also unwrapped to
    # a scalar instead of writing a multi-line repr into the CSV row.
    fields = [str(a.item()) if isinstance(a, torch.Tensor) else str(a) for a in args]
    with open(f'{os.environ["LOG_DIR"]}/loss.txt', 'a') as f:
        f.write(f'{",".join(fields)}\n')
def msg_convert(v):
    """Return a compact summary of bulky values for the config log.

    - ``list`` -> ``"<type>:<length>"``
    - ``np.ndarray`` -> ``"<type>:<shape tuple>"``
    - ``torch.Tensor`` -> ``"<type>:<torch.Size>"``

    Any other value is returned unchanged (it is not converted to a string).
    """
    # The three type checks are mutually exclusive, so early returns are
    # equivalent to the sequential overwrites and avoid a duplicated branch.
    if isinstance(v, list):
        return f"{type(v)}:{len(v)}"
    if isinstance(v, np.ndarray):
        return f"{type(v)}:{v.shape}"
    if isinstance(v, torch.Tensor):
        return f"{type(v)}:{v.size()}"
    return v
def log_cfg(func):
    """Decorator: log each call's arguments to ``$LOG_DIR/cfg.txt``, then call *func*.

    Every call appends a 100-character ``=`` separator, the function name, the
    keyword arguments and the positional arguments.  Values are summarised
    with ``msg_convert``: lists become ``type:len``, and arrays/tensors become
    ``type:shape``.  For dict-valued keyword arguments, each value of the dict
    is summarised that way.

    Raises:
        KeyError: if the ``LOG_DIR`` environment variable is not set.
    """
    @functools.wraps(func)
    def wrap(*args, **kwargs):
        with open(f'{os.environ["LOG_DIR"]}/cfg.txt', 'a') as f:
            f.write('='*100+'\n')
            f.write(f"{func.__name__}:\n")
            f.write(f"\tKwargs\n")
            for k, v in kwargs.items():
                if isinstance(v, dict):
                    # A shallow copy is enough because every value is replaced.
                    # deepcopy was wasted work and raises on objects that
                    # don't support it (e.g. non-leaf tensors).
                    msg = copy.copy(v)
                    for key, value in v.items():
                        msg[key] = msg_convert(value)
                else:
                    msg = msg_convert(v)
                f.write(f"\t\t{k}:{msg}\n")
            f.write(f"\targs\n")
            for v in args:
                # Summarise positional args the same way as kwargs, using
                # type(v) rather than the builtin `type` itself.
                f.write(f"\t\t{msg_convert(v)}\n")
        return func(*args, **kwargs)
    return wrap
def log_time(func):
    """Decorator: print the wall-clock duration (seconds) of each call to *func*.

    The line printed is ``"<name>: <seconds to 4 decimals>"``.  The wrapped
    function's return value is passed through unchanged.
    """
    @functools.wraps(func)
    def wrap(*args, **kwargs):
        # perf_counter is monotonic and high-resolution; time.time() can jump
        # when the system clock is adjusted.
        start = time.perf_counter()
        res = func(*args, **kwargs)
        print(f"{func.__name__}: {time.perf_counter()-start: .4f}")
        return res
    return wrap
def log_internal(msg):
    """Append ``[HH:MM:SS]<msg>`` as a line to ``$LOG_DIR/internal.txt``.

    After the file is written, *msg* is also printed to stdout when the
    ``VERBOSE`` environment variable equals the string ``"0"``.

    Raises:
        KeyError: if ``LOG_DIR`` is unset (nothing is written), or if
            ``VERBOSE`` is unset (raised after the line has been written).
    """
    log_path = f'{os.environ["LOG_DIR"]}/internal.txt'
    with open(log_path, 'a') as sink:
        now = datetime.datetime.now()
        sink.write(f'[{now.strftime("%H:%M:%S")}]{msg}\n')
    # NOTE(review): echoing only when VERBOSE == "0" looks inverted relative to
    # the usual meaning of a verbosity flag; confirm intent with callers.
    echo = os.environ['VERBOSE'] == "0"
    if echo:
        print(msg)