Skip to content

Commit

Permalink
perf(cuda): Improve performance of the fused dist transeq kernel.
Browse files Browse the repository at this point in the history
  • Loading branch information
semi-h committed Nov 10, 2023
1 parent d074884 commit cca29e1
Showing 1 changed file with 59 additions and 43 deletions.
102 changes: 59 additions & 43 deletions src/cuda/kernels_dist.f90
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ attributes(global) subroutine transeq_3fused_dist( &
real(dp), device, intent(in) :: d2_fw(:), d2_bw(:), d2_af(:)

! Local variables
integer :: i, j, b, k
integer :: i, j, b

real(dp) :: d1_c_m4, d1_c_m3, d1_c_m2, d1_c_m1, d1_c_j, &
d1_c_p1, d1_c_p2, d1_c_p3, d1_c_p4, &
Expand All @@ -226,23 +226,14 @@ attributes(global) subroutine transeq_3fused_dist( &
d2_c_p1, d2_c_p2, d2_c_p3, d2_c_p4, &
d2_alpha, d2_last_r
real(dp) :: temp_du, temp_dud, temp_d2u
real(dp) :: u_m4, u_m3, u_m2, u_m1, u_j, u_p1, u_p2, u_p3, u_p4
real(dp) :: v_m4, v_m3, v_m2, v_m1, v_j, v_p1, v_p2, v_p3, v_p4
real(dp) :: old_du, old_dud, old_d2u

i = threadIdx%x
b = blockIdx%x

! store bulk coeffs in the registers
d1_c_m4 = d1_coeffs(1); d1_c_m3 = d1_coeffs(2)
d1_c_m2 = d1_coeffs(3); d1_c_m1 = d1_coeffs(4)
d1_c_j = d1_coeffs(5)
d1_c_p1 = d1_coeffs(6); d1_c_p2 = d1_coeffs(7)
d1_c_p3 = d1_coeffs(8); d1_c_p4 = d1_coeffs(9)
d1_last_r = d1_fw(1)

d2_c_m4 = d2_coeffs(1); d2_c_m3 = d2_coeffs(2)
d2_c_m2 = d2_coeffs(3); d2_c_m1 = d2_coeffs(4)
d2_c_j = d2_coeffs(5)
d2_c_p1 = d2_coeffs(6); d2_c_p2 = d2_coeffs(7)
d2_c_p3 = d2_coeffs(8); d2_c_p4 = d2_coeffs(9)
d2_last_r = d2_fw(1)

! j = 1
Expand Down Expand Up @@ -373,40 +364,65 @@ attributes(global) subroutine transeq_3fused_dist( &
d1_alpha = d1_af(5)
d2_alpha = d2_af(5)

! store bulk coeffs in the registers
d1_c_m4 = d1_coeffs(1); d1_c_m3 = d1_coeffs(2)
d1_c_m2 = d1_coeffs(3); d1_c_m1 = d1_coeffs(4)
d1_c_j = d1_coeffs(5)
d1_c_p1 = d1_coeffs(6); d1_c_p2 = d1_coeffs(7)
d1_c_p3 = d1_coeffs(8); d1_c_p4 = d1_coeffs(9)

d2_c_m4 = d2_coeffs(1); d2_c_m3 = d2_coeffs(2)
d2_c_m2 = d2_coeffs(3); d2_c_m1 = d2_coeffs(4)
d2_c_j = d2_coeffs(5)
d2_c_p1 = d2_coeffs(6); d2_c_p2 = d2_coeffs(7)
d2_c_p3 = d2_coeffs(8); d2_c_p4 = d2_coeffs(9)

! It is better to access d?(i, j - 1, b) via old_d?
old_du = du(i, 4, b)
old_dud = dud(i, 4, b)
old_d2u = d2u(i, 4, b)

! Populate registers with the u and v stencils
u_m4 = u(i, 1, b); u_m3 = u(i, 2, b)
u_m2 = u(i, 3, b); u_m1 = u(i, 4, b)
u_j = u(i, 5, b); u_p1 = u(i, 6, b)
u_p2 = u(i, 7, b); u_p3 = u(i, 8, b)
v_m4 = v(i, 1, b); v_m3 = v(i, 2, b)
v_m2 = v(i, 3, b); v_m1 = v(i, 4, b)
v_j = v(i, 5, b); v_p1 = v(i, 6, b)
v_p2 = v(i, 7, b); v_p3 = v(i, 8, b)

do j = 5, n - 4
u_p4 = u(i, j+4, b); v_p4 = v(i, j+4, b)

! du
temp_du = d1_c_m4*u(i, j - 4, b) &
+ d1_c_m3*u(i, j - 3, b) &
+ d1_c_m2*u(i, j - 2, b) &
+ d1_c_m1*u(i, j - 1, b) &
+ d1_c_j*u(i, j, b) &
+ d1_c_p1*u(i, j + 1, b) &
+ d1_c_p2*u(i, j + 2, b) &
+ d1_c_p3*u(i, j + 3, b) &
+ d1_c_p4*u(i, j + 4, b)
du(i, j, b) = d1_fw(j)*(temp_du - d1_alpha*du(i, j - 1, b))
temp_du = d1_c_m4*u_m4 + d1_c_m3*u_m3 + d1_c_m2*u_m2 + d1_c_m1*u_m1 &
+ d1_c_j*u_j &
+ d1_c_p1*u_p1 + d1_c_p2*u_p2 + d1_c_p3*u_p3 + d1_c_p4*u_p4
du(i, j, b) = d1_fw(j)*(temp_du - d1_alpha*old_du)
old_du = du(i, j, b)

! dud
temp_dud = d1_c_m4*u(i, j - 4, b)*v(i, j - 4, b) &
+ d1_c_m3*u(i, j - 3, b)*v(i, j - 3, b) &
+ d1_c_m2*u(i, j - 2, b)*v(i, j - 2, b) &
+ d1_c_m1*u(i, j - 1, b)*v(i, j - 1, b) &
+ d1_c_j*u(i, j, b)*v(i, j, b) &
+ d1_c_p1*u(i, j + 1, b)*v(i, j + 1, b) &
+ d1_c_p2*u(i, j + 2, b)*v(i, j + 2, b) &
+ d1_c_p3*u(i, j + 3, b)*v(i, j + 3, b) &
+ d1_c_p4*u(i, j + 4, b)*v(i, j + 4, b)
dud(i, j, b) = d1_fw(j)*(temp_dud - d1_alpha*dud(i, j - 1, b))
temp_dud = d1_c_m4*u_m4*v_m4 + d1_c_m3*u_m3*v_m3 &
+ d1_c_m2*u_m2*v_m2 + d1_c_m1*u_m1*v_m1 &
+ d1_c_j*u_j*v_j &
+ d1_c_p1*u_p1*v_p1 + d1_c_p2*u_p2*v_p2 &
+ d1_c_p3*u_p3*v_p3 + d1_c_p4*u_p4*v_p4
dud(i, j, b) = d1_fw(j)*(temp_dud - d1_alpha*old_dud)
old_dud = dud(i, j, b)

! d2u
temp_d2u = d2_c_m4*u(i, j - 4, b) &
+ d2_c_m3*u(i, j - 3, b) &
+ d2_c_m2*u(i, j - 2, b) &
+ d2_c_m1*u(i, j - 1, b) &
+ d2_c_j*u(i, j, b) &
+ d2_c_p1*u(i, j + 1, b) &
+ d2_c_p2*u(i, j + 2, b) &
+ d2_c_p3*u(i, j + 3, b) &
+ d2_c_p4*u(i, j + 4, b)
d2u(i, j, b) = d2_fw(j)*(temp_d2u - d2_alpha*d2u(i, j - 1, b))
temp_d2u = d2_c_m4*u_m4 + d2_c_m3*u_m3 + d2_c_m2*u_m2 + d2_c_m1*u_m1 &
+ d2_c_j*u_j &
+ d2_c_p1*u_p1 + d2_c_p2*u_p2 + d2_c_p3*u_p3 + d2_c_p4*u_p4
d2u(i, j, b) = d2_fw(j)*(temp_d2u - d2_alpha*old_d2u)
old_d2u = d2u(i, j, b)

! Prepare registers for the next step
u_m4 = u_m3; u_m3 = u_m2; u_m2 = u_m1; u_m1 = u_j
u_j = u_p1; u_p1 = u_p2; u_p2 = u_p3; u_p3 = u_p4
v_m4 = v_m3; v_m3 = v_m2; v_m2 = v_m1; v_m1 = v_j
v_j = v_p1; v_p1 = v_p2; v_p2 = v_p3; v_p3 = v_p4
end do

j = n - 3
Expand Down

0 comments on commit cca29e1

Please sign in to comment.