-
Notifications
You must be signed in to change notification settings - Fork 3
/
benchmark.py
73 lines (53 loc) · 1.89 KB
/
benchmark.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import os.path as osp
import os
import sys
import time
import argparse
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.backends.cudnn as cudnn
from torch.nn.parallel import DistributedDataParallel
from tensorboardX import SummaryWriter
from tqdm import tqdm
import os
import numpy as np
import torch
from torchvision.models import resnet18
import time
# from Code.lib.model import SPNet
# from mmcv import Config
from mmcv.cnn import get_model_complexity_info
from Code.models.builder import EncoderDecoder as segmodel
import torch
# from torchvision.models import AlexNet
# from torchviz import make_dot
def main():
    """Print complexity figures for the segmentation model from thop and mmcv.

    The model is built with ``segmodel()``, switched to eval mode, and then
    measured twice:

    1. ``thop.profile`` with a (1, 3, 352, 352) tensor of ones plus a
       single-channel slice of that same tensor as the second input.
    2. ``mmcv.cnn.get_model_complexity_info`` with a (3, 480, 640) input.

    Both results are printed to stdout; nothing is returned.
    """
    # Kept as a local import, as in the original script, so thop is only
    # required when the benchmark is actually executed.
    from thop import profile

    # Fall back to the CPU so the counters still run on a machine without CUDA.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    model = segmodel()
    model.eval()
    model.to(device)

    dump_input = torch.ones(1, 3, 352, 352).to(device)
    # Channel 1 of dump_input, re-expanded to a (1, 1, 352, 352) tensor.
    second_input = dump_input[:, 1, :, :].unsqueeze(1)

    # Optional latency measurement, left disabled by the original author.
    # The first 50 iterations are warm-up and the remaining 1950 are timed.
    # for i in tqdm(range(2000)):
    #     if i == 50:
    #         start = time.time()
    #     outputs = model(dump_input, dump_input)
    #     torch.cuda.synchronize()
    #     if i == 1999:
    #         end = time.time()
    # print('Time:{}ms'.format((end-start)*1000/1950))

    # NOTE(review): thop's README names the first return value "macs"; the
    # label printed below still says "flops" - confirm which unit is wanted.
    flops, params = profile(model, inputs=(dump_input, second_input))
    # Values noted by the original author for "res50" (presumably
    # flops / params): 4111514624.0 25557032.0
    print('the flops is {}G,the params is {}M'.format(
        round(flops / (10 ** 9), 2), round(params / (10 ** 6), 2)))

    # mmcv prepends the batch axis itself, so the shape handed to it must be
    # (C, H, W). The original (1, 3, 480, 640) produced a 5-D input tensor.
    # NOTE(review): mmcv feeds the model a single tensor by default, while the
    # thop call above passes two. If forward() requires the second input, an
    # `input_constructor` must be supplied here - confirm against the model.
    input_shape = (3, 480, 640)
    # Only a forward pass is needed for counting, so skip autograd bookkeeping.
    with torch.no_grad():
        flops, params = get_model_complexity_info(model, input_shape)

    split_line = '=' * 30
    print('{0}\nInput shape: {1}\nFlops: {2}\nParams: {3}\n{0}'.format(
        split_line, input_shape, flops, params))
    print('!!!Please be cautious if you use the results in papers. '
          'You may need to check if all ops are supported and verify that the '
          'flops computation is correct.')


if __name__ == '__main__':
    main()