-
Notifications
You must be signed in to change notification settings - Fork 2
/
test.py
70 lines (52 loc) · 1.98 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import torch
from torch.utils.benchmark import Timer
from torch import nn
from cublas_ops import CublasLinear
def time_fn(fn, A, B, bias, warmups=20, iters=50):
    """Benchmark the call ``fn(A, B, bias)`` with ``torch.utils.benchmark.Timer``.

    The call is first executed ``warmups`` times without timing; the timer
    then measures ``iters`` further calls.

    Args:
        fn: Callable invoked as ``fn(A, B, bias)``.
        A: First positional argument forwarded to ``fn``.
        B: Second positional argument forwarded to ``fn``.
        bias: Third positional argument forwarded to ``fn``.
        warmups: Number of untimed calls made before measuring.
        iters: Number of calls passed to ``Timer.timeit``.

    Returns:
        The ``Measurement`` returned by ``Timer.timeit(iters)``.
    """
    def invoke():
        return fn(A, B, bias)

    for _warmup in range(warmups):
        invoke()

    # The statement string only references ``fn_``, so that is the single
    # name the timer's namespace has to provide.
    return Timer(stmt="fn_()", globals={"fn_": invoke}).timeit(iters)
def time_linear(linear, input, warmups=20, iters=50):
    """Benchmark a single forward call ``linear(input)``.

    ``linear`` is invoked ``warmups`` times without timing, after which a
    ``torch.utils.benchmark.Timer`` measures ``iters`` more calls.

    Args:
        linear: Any one-argument callable (for example an ``nn.Module``).
        input: The argument passed to ``linear`` on every call.
        warmups: Number of untimed calls made before measuring.
        iters: Number of calls passed to ``Timer.timeit``.

    Returns:
        The ``Measurement`` returned by ``Timer.timeit(iters)``.
    """
    def forward():
        return linear(input)

    for _step in range(warmups):
        forward()

    bench = Timer(stmt="fn_()", globals={"fn_": forward})
    return bench.timeit(iters)
# Benchmark a compiled fp16 nn.Linear against CublasLinear on an
# (M x K) @ (K x N) problem, after checking that both produce comparable
# outputs for identical weights and inputs.
if not torch.cuda.is_available():
    raise SystemExit("This benchmark requires a CUDA device.")

with torch.no_grad():
    device = torch.device("cuda", 0)
    torch.manual_seed(0)  # reproducible weights and inputs across runs

    M = 2048 * 2
    K = 2048 * 2
    N = 2048 * 2

    base_module = nn.Linear(in_features=K, out_features=N, bias=True).cuda(device).half()
    cublas_module = (
        CublasLinear(in_features=K, out_features=N, bias=True).cuda(device).half()
    )
    # Give both layers identical parameters so their outputs are comparable.
    # In-place copy_ (we are under no_grad) instead of rebinding `.data`.
    # NOTE(review): assumes CublasLinear stores weight/bias with the same
    # shapes as nn.Linear — confirm if its layout ever changes.
    cublas_module.weight.copy_(base_module.weight)
    cublas_module.bias.copy_(base_module.bias)
    base_module.compile()

    # Create the input directly on the GPU in fp16 (no CPU fp32 round trip).
    input_t = torch.randn(M, K, device=device, dtype=torch.float16) / 4
    out = base_module(input_t)
    out_cublas = cublas_module(input_t)
    print("Output From nn.Linear (compiled):\n", out)
    print("Output From CublasLinear:\n", out_cublas)
    # The printed tensors are truncated; report agreement numerically.
    # Upcast first so the difference itself is not rounded to fp16.
    max_abs_diff = (out.float() - out_cublas.float()).abs().max().item()
    print(f"Max abs difference: {max_abs_diff:.6f}")

    # 2 FLOPs (multiply + add) per output element per K step; bias add ignored.
    FLOPS = 2 * M * N * K
    time_lin = time_linear(base_module, input_t, warmups=50, iters=100)
    time_cublas = time_linear(cublas_module, input_t, warmups=50, iters=100)

    for label, measurement in (
        ("torch f16 W/ f32 acc (compiled): ", time_lin),
        ("cublas f16 W/ f16 acc: ", time_cublas),
    ):
        seconds = measurement.mean  # mean seconds per call
        print(
            label.upper().replace(" ", "_"),
            f"{seconds * 1000 * 1000:.2f} us",
            f"{(FLOPS / seconds) / 1e12:.2f} TFLOPS",
        )