-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtriton_utils.py
57 lines (47 loc) · 1.12 KB
/
triton_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import triton
import triton.language as tl
@triton.jit
def get_tid():
    """Return the PTX ``%tid`` special registers as an ``(x, y, z)`` triple.

    Each component is moved into its own 32-bit output register by a small
    inline-assembly snippet and surfaced as a ``tl.uint32`` value.
    """
    return tl.inline_asm_elementwise(
        # One ``mov`` per component; ``$0``..``$2`` are the three outputs.
        asm=(
            "\n"
            "mov.u32 $0, %tid.x;\n"
            "mov.u32 $1, %tid.y;\n"
            "mov.u32 $2, %tid.z;\n"
        ),
        # Three write-only 32-bit register outputs, no inputs.
        constraints="=r,=r,=r",
        args=[],
        dtype=(tl.uint32, tl.uint32, tl.uint32),
        # Only reads special registers, so it is declared side-effect free.
        is_pure=True,
        pack=1,
    )
@triton.jit
def get_ntid():
    """Return the PTX ``%ntid`` special registers as an ``(x, y, z)`` triple.

    Each component is moved into its own 32-bit output register by a small
    inline-assembly snippet and surfaced as a ``tl.uint32`` value.
    """
    return tl.inline_asm_elementwise(
        # One ``mov`` per component; ``$0``..``$2`` are the three outputs.
        asm=(
            "\n"
            "mov.u32 $0, %ntid.x;\n"
            "mov.u32 $1, %ntid.y;\n"
            "mov.u32 $2, %ntid.z;\n"
        ),
        # Three write-only 32-bit register outputs, no inputs.
        constraints="=r,=r,=r",
        args=[],
        dtype=(tl.uint32, tl.uint32, tl.uint32),
        # Only reads special registers, so it is declared side-effect free.
        is_pure=True,
        pack=1,
    )
@triton.jit
def get_flat_tid():
    """Collapse the 3-component ``%tid`` into a single linear index.

    The x component varies fastest, then y, then z:
    ``z * ntid.y * ntid.x + y * ntid.x + x``.
    """
    x, y, z = get_tid()
    # The z extent is not needed to linearize, so it is discarded.
    extent_x, extent_y, _ = get_ntid()
    z_offset = z * extent_y * extent_x
    y_offset = y * extent_x
    return z_offset + y_offset + x
@triton.jit
def get_flat_bid():
    """Collapse the 3-axis program id into a single linear index.

    Axis 0 varies fastest, then axis 1, then axis 2:
    ``pid2 * n1 * n0 + pid1 * n0 + pid0`` where ``n<k>`` is
    ``tl.num_programs(k)``.
    """
    pid_0 = tl.program_id(0)
    pid_1 = tl.program_id(1)
    pid_2 = tl.program_id(2)
    # The extent along axis 2 is not needed to linearize.
    count_0 = tl.num_programs(0)
    count_1 = tl.num_programs(1)
    return pid_2 * count_1 * count_0 + pid_1 * count_0 + pid_0
@triton.jit
def sync_threads():
    """Emit the PTX ``bar.sync 0;`` instruction via inline assembly.

    The value produced by the inline-asm call is discarded; the function
    exists only for the emitted instruction.
    """
    # NOTE(review): a single "=r" output is declared but the asm never writes
    # $0 -- presumably only there because the API wants a result dtype;
    # confirm before relying on the returned value anywhere.
    tl.inline_asm_elementwise(
        asm="bar.sync 0;",
        constraints="=r",
        args=[],
        dtype=tl.int32,
        # Declared impure so the compiler does not assume the asm has no
        # side effects.
        is_pure=False,
        pack=1,
    )