# pytorch_square.py
import torch

a = torch.tensor([1., 2., 3.])

# Three numerically equivalent ways to square a tensor
print(torch.square(a))
print(a ** 2)
print(a * a)
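
# All three lines above print tensor([1., 4., 9.]); the question the rest of
# this script answers is which implementation runs fastest on the GPU.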

def time_pytorch_function(func, input):
    # CUDA is asynchronous, so CPU-side timing with Python's time module would
    # be misleading; use CUDA events to measure elapsed time on the GPU itself.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    # Warmup runs to exclude one-time startup costs
    for _ in range(5):
        func(input)

    start.record()
    func(input)
    end.record()
    torch.cuda.synchronize()  # block until the GPU has finished
    return start.elapsed_time(end)  # elapsed time in milliseconds
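
# A single timed call can be noisy; averaging over many iterations, or using
# torch.utils.benchmark (sketched below), gives more stable numbers.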

b = torch.randn(10000, 10000).cuda()

def square_2(a):
    return a * a

def square_3(a):
    return a ** 2

# Time each implementation on a large CUDA tensor
time_pytorch_function(torch.square, b)
time_pytorch_function(square_2, b)
time_pytorch_function(square_3, b)
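
# Minimal sketch, an addition over the original (which discards the timings):
# print the measured times, and cross-check with torch.utils.benchmark, which
# handles warmup and synchronization automatically.
print("torch.square:", time_pytorch_function(torch.square, b), "ms")
print("a * a:", time_pytorch_function(square_2, b), "ms")
print("a ** 2:", time_pytorch_function(square_3, b), "ms")

from torch.utils.benchmark import Timer
print(Timer(stmt="torch.square(b)", globals={"b": b}).blocked_autorange())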

print("=============")
print("Profiling torch.square")
print("=============")

# Now profile each function using the PyTorch profiler
with torch.profiler.profile() as prof:
    torch.square(b)

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

print("=============")
print("Profiling a * a")
print("=============")

with torch.profiler.profile() as prof:
    square_2(b)

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

print("=============")
print("Profiling a ** 2")
print("=============")

with torch.profiler.profile() as prof:
    square_3(b)

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
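
# A minimal sketch going beyond the original script: the same profiler object
# can export a Chrome trace, so the kernels in the tables above can also be
# inspected on a timeline (open the file in chrome://tracing or Perfetto).
prof.export_chrome_trace("square_trace.json")  # filename is an arbitrary choice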