Skip to content

Commit

Permalink
Merge pull request #25 from semi-h/feature
Browse files Browse the repository at this point in the history
Add an interface for solving a single tridiagonal system.
  • Loading branch information
semi-h authored Dec 20, 2023
2 parents f18fe02 + c44b395 commit fe39864
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 3 deletions.
22 changes: 22 additions & 0 deletions src/backend.f90
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ module m_base_backend
procedure(transeq_ders), deferred :: transeq_x
procedure(transeq_ders), deferred :: transeq_y
procedure(transeq_ders), deferred :: transeq_z
procedure(tds_solve), deferred :: tds_solve
procedure(transposer), deferred :: trans_x2y
procedure(transposer), deferred :: trans_x2z
procedure(sum9into3), deferred :: sum_yzintox
Expand Down Expand Up @@ -54,6 +55,27 @@ subroutine transeq_ders(self, du, dv, dw, u, v, w, dirps)
end subroutine transeq_ders
end interface

abstract interface
   subroutine tds_solve(self, du, u, dirps, tdsops)
      !! Interface for solving a single tridiagonal system on one field.
      !!
      !! The tridiagonal operator described by `tdsops` is applied to the
      !! field `u` along the direction described by `dirps`, and the
      !! outcome is returned through `du`. The exact algorithm used to
      !! solve the system is decided at runtime; backend implementations
      !! are responsible for directing calls to tds_solve into the
      !! correct algorithm.
      import :: base_backend_t
      import :: field_t
      import :: dirps_t
      import :: tdsops_t
      implicit none

      class(base_backend_t) :: self
      class(field_t), intent(inout) :: du !! Field updated by the solve
      class(field_t), intent(in) :: u !! Input field
      type(dirps_t), intent(in) :: dirps !! Direction/decomposition data
      class(tdsops_t), intent(in) :: tdsops !! Tridiagonal operator to use
   end subroutine tds_solve
end interface

abstract interface
subroutine transposer(self, u_, v_, w_, u, v, w)
!! transposer subroutines are straightforward, they rearrange
Expand Down
64 changes: 62 additions & 2 deletions src/cuda/backend.f90
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ module m_cuda_backend

use m_cuda_allocator, only: cuda_allocator_t, cuda_field_t
use m_cuda_common, only: SZ
use m_cuda_exec_dist, only: exec_dist_transeq_3fused
use m_cuda_sendrecv, only: sendrecv_3fields
use m_cuda_exec_dist, only: exec_dist_transeq_3fused, exec_dist_tds_compact
use m_cuda_sendrecv, only: sendrecv_fields, sendrecv_3fields
use m_cuda_tdsops, only: cuda_tdsops_t
use m_cuda_kernels_dist, only: transeq_3fused_dist, transeq_3fused_subs

Expand All @@ -31,13 +31,15 @@ module m_cuda_backend
procedure :: transeq_x => transeq_x_cuda
procedure :: transeq_y => transeq_y_cuda
procedure :: transeq_z => transeq_z_cuda
procedure :: tds_solve => tds_solve_cuda
procedure :: trans_x2y => trans_x2y_cuda
procedure :: trans_x2z => trans_x2z_cuda
procedure :: sum_yzintox => sum_yzintox_cuda
procedure :: set_fields => set_fields_cuda
procedure :: get_fields => get_fields_cuda
procedure :: transeq_cuda_dist
procedure :: transeq_cuda_thom
procedure :: tds_solve_dist
end type cuda_backend_t

interface cuda_backend_t
Expand Down Expand Up @@ -343,6 +345,64 @@ subroutine transeq_cuda_thom(self, du, dv, dw, u, v, w, dirps)

end subroutine transeq_cuda_thom

subroutine tds_solve_cuda(self, du, u, dirps, tdsops)
   !! CUDA implementation of the `tds_solve` binding.
   !!
   !! Builds the two `dim3` values handed down as `blocks` and `threads`
   !! (dirps%n_blocks along x, and SZ along x respectively) and forwards
   !! everything to tds_solve_dist.
   implicit none

   class(cuda_backend_t) :: self
   class(field_t), intent(inout) :: du
   class(field_t), intent(in) :: u
   type(dirps_t), intent(in) :: dirps
   class(tdsops_t), intent(in) :: tdsops

   type(dim3) :: grid_dim, block_dim

   ! Only the first dimension is used; the remaining two are unit-sized.
   grid_dim = dim3(dirps%n_blocks, 1, 1)
   block_dim = dim3(SZ, 1, 1)

   call tds_solve_dist(self, du, u, dirps, tdsops, grid_dim, block_dim)

end subroutine tds_solve_cuda

subroutine tds_solve_dist(self, du, u, dirps, tdsops, blocks, threads)
   !! Solve a single tridiagonal system through the distributed path.
   !!
   !! Steps performed here:
   !!   1. resolve the polymorphic arguments to their CUDA concrete types,
   !!   2. pack the boundary data of `u` into the backend's send buffers,
   !!   3. exchange those buffers with the previous/next ranks,
   !!   4. call exec_dist_tds_compact with the field data, the received
   !!      buffers, the du send/recv buffers and the operator.
   !!
   !! The run is aborted with `error stop` if `du`, `u` or `tdsops` is not
   !! of the CUDA-specific type; previously the corresponding pointer was
   !! silently left undefined in that case and then used.
   implicit none

   class(cuda_backend_t) :: self
   class(field_t), intent(inout) :: du
   class(field_t), intent(in) :: u
   type(dirps_t), intent(in) :: dirps
   class(tdsops_t), intent(in) :: tdsops
   type(dim3), intent(in) :: blocks, threads

   real(dp), device, pointer, dimension(:, :, :) :: du_dev, u_dev

   type(cuda_tdsops_t), pointer :: tdsops_dev

   ! Resolve the device data behind the polymorphic field arguments.
   select type (du)
   type is (cuda_field_t)
      du_dev => du%data_d
   class default
      error stop 'tds_solve_dist: du is not a cuda_field_t'
   end select

   select type (u)
   type is (cuda_field_t)
      u_dev => u%data_d
   class default
      error stop 'tds_solve_dist: u is not a cuda_field_t'
   end select

   ! Resolve the CUDA-specific tridiagonal operator.
   select type (tdsops)
   type is (cuda_tdsops_t)
      tdsops_dev => tdsops
   class default
      error stop 'tds_solve_dist: tdsops is not a cuda_tdsops_t'
   end select

   ! Pack the data of u to be exchanged into the send buffers.
   call copy_into_buffers(self%u_send_s_dev, self%u_send_e_dev, u_dev, &
                          tdsops_dev%n)

   ! Exchange the buffers with the neighbouring ranks.
   ! NOTE(review): the factor 4 in the element count presumably is the
   ! halo depth of the send/recv buffers - confirm against their
   ! allocation.
   call sendrecv_fields(self%u_recv_s_dev, self%u_recv_e_dev, &
                        self%u_send_s_dev, self%u_send_e_dev, &
                        SZ*4*blocks%x, dirps%nproc, &
                        dirps%pprev, dirps%pnext)

   ! Run the distributed compact solve using the received u buffers and
   ! the backend's du send/recv buffers.
   call exec_dist_tds_compact( &
      du_dev, u_dev, &
      self%u_recv_s_dev, self%u_recv_e_dev, &
      self%du_send_s_dev, self%du_send_e_dev, &
      self%du_recv_s_dev, self%du_recv_e_dev, &
      tdsops_dev, dirps%nproc, dirps%pprev, dirps%pnext, &
      blocks, threads &
      )

end subroutine tds_solve_dist

subroutine trans_x2y_cuda(self, u_y, v_y, w_y, u, v, w)
implicit none

Expand Down
14 changes: 14 additions & 0 deletions src/omp/backend.f90
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ module m_omp_backend
procedure :: transeq_x => transeq_x_omp
procedure :: transeq_y => transeq_y_omp
procedure :: transeq_z => transeq_z_omp
procedure :: tds_solve => tds_solve_omp
procedure :: trans_x2y => trans_x2y_omp
procedure :: trans_x2z => trans_x2z_omp
procedure :: sum_yzintox => sum_yzintox_omp
Expand Down Expand Up @@ -138,6 +139,19 @@ subroutine transeq_z_omp(self, du, dv, dw, u, v, w, dirps)

end subroutine transeq_z_omp

subroutine tds_solve_omp(self, du, u, dirps, tdsops)
   !! OpenMP implementation of the `tds_solve` binding.
   !!
   !! Placeholder only: no statement is executed, so `du` is returned
   !! exactly as it was passed in and the other arguments are unused.
   implicit none

   class(omp_backend_t) :: self
   class(tdsops_t), intent(in) :: tdsops
   type(dirps_t), intent(in) :: dirps
   class(field_t), intent(in) :: u
   class(field_t), intent(inout) :: du

   ! TODO: forward to a distributed solver once one exists, e.g.
   !    call self%tds_solve_dist(du, u, dirps, tdsops)

end subroutine tds_solve_omp

subroutine trans_x2y_omp(self, u_, v_, w_, u, v, w)
implicit none

Expand Down
2 changes: 1 addition & 1 deletion src/tdsops.f90
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ module m_tdsops
der2nd, der2nd_sym, &
stagder_v2p, stagder_p2v, &
interpl_v2p, interpl_p2v
integer :: nrank, nproc, pnext, pprev, n
integer :: nrank, nproc, pnext, pprev, n, n_blocks
end type dirps_t

contains
Expand Down
4 changes: 4 additions & 0 deletions src/xcompact.f90
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ program xcompact
ydirps%n = globs%ny_loc
zdirps%n = globs%nz_loc

xdirps%n_blocks = globs%n_groups_x
ydirps%n_blocks = globs%n_groups_y
zdirps%n_blocks = globs%n_groups_z

#ifdef CUDA
cuda_allocator = cuda_allocator_t([SZ, globs%nx_loc, globs%n_groups_x])
allocator => cuda_allocator
Expand Down

0 comments on commit fe39864

Please sign in to comment.