Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FFT for OMP backend (via 2decomp&fft) #113

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from
13 changes: 13 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,18 @@ cmake_minimum_required(VERSION 3.10)
project(x3d2 LANGUAGES Fortran)
enable_testing()

#
# Set the Poisson solver choice
#
set(POISSON_SOLVER FFT CACHE STRING
"Select the Poisson solver: FFT or ITER")

if(${CMAKE_Fortran_COMPILER_ID} STREQUAL "PGI" OR
${CMAKE_Fortran_COMPILER_ID} STREQUAL "NVHPC")
set(BACKEND CUDA)
else()
set(BACKEND OMP)
endif()

add_subdirectory(src)
add_subdirectory(tests)
105 changes: 105 additions & 0 deletions src/2decompfft/decomp.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
submodule(m_decomp) m_decomp_2decompfft

use mpi
use m_decomp, only: decomp_t
implicit none

type, extends(decomp_t) :: decomp_2decompfft_t
contains
procedure :: decomposition => decomposition_2decompfft
end type

contains

module subroutine decomposition_2decompfft(self, grid, par)
!! Supports 1D, 2D, and 3D domain decomposition.
!!
!! Current implementation allows only constant sub-domain size across a
!! given direction.
use m_grid, only: grid_t
use m_par, only: par_t
use decomp_2d, only: decomp_2d_init, DECOMP_2D_COMM_CART_X, xsize, xstart

class(decomp_2decompfft_t) :: self
class(grid_t), intent(inout) :: grid
class(par_t), intent(inout) :: par
integer :: p_col, p_row
integer, allocatable, dimension(:, :, :) :: global_ranks
integer, allocatable, dimension(:) :: global_ranks_lin
integer :: nproc
integer, dimension(3) :: subd_pos, subd_pos_prev, subd_pos_next
logical, dimension(3) :: periodic_bc
integer :: dir
logical :: is_last_domain
integer :: nx, ny, nz
integer :: ierr
integer :: cart_rank
integer, dimension(2) :: coords

if (par%is_root()) then
print*, "Domain decomposition by 2decomp&fft"
end if

nx = grid%global_cell_dims(1)
ny = grid%global_cell_dims(2)
nz = grid%global_cell_dims(3)

p_row = par%nproc_dir(2)
p_col = par%nproc_dir(3)
periodic_bc(:) = grid%periodic_BC(:)
call decomp_2d_init(nx, ny, nz, p_row, p_col, periodic_bc)

! Get global_ranks
allocate(global_ranks(1, p_row, p_col))
allocate(global_ranks_lin(p_row*p_col))
global_ranks_lin(:) = 0

call MPI_Comm_rank(DECOMP_2D_COMM_CART_X, cart_rank, ierr)
call MPI_Cart_coords(DECOMP_2D_COMM_CART_X, cart_rank, 2, coords, ierr)

global_ranks_lin(coords(1)+1 + p_row*(coords(2))) = par%nrank

call MPI_Allreduce(MPI_IN_PLACE, global_ranks_lin, p_row*p_col, MPI_INTEGER, MPI_SUM, MPI_COMM_WORLD, ierr)

global_ranks = reshape(global_ranks_lin, shape=[1, p_row, p_col])

! subdomain position in the global domain
subd_pos = findloc(global_ranks, par%nrank)

! local/directional position of the subdomain
par%nrank_dir(:) = subd_pos(:) - 1

! Get local domain size and offset from 2decomp
grid%cell_dims(:) = xsize(:)
par%n_offset(:) = xstart(:)

! compute vert_dims from cell_dims
do dir = 1, 3
is_last_domain = (par%nrank_dir(dir) + 1 == par%nproc_dir(dir))
if (is_last_domain .and. (.not. grid%periodic_BC(dir))) then
grid%vert_dims(dir) = grid%cell_dims(dir) +1
else
grid%vert_dims(dir) = grid%cell_dims(dir)
end if
end do

! Get neighbour ranks
do dir = 1, 3
nproc = par%nproc_dir(dir)
subd_pos_prev(:) = subd_pos(:)
subd_pos_prev(dir) = modulo(subd_pos(dir) - 2, nproc) + 1
par%pprev(dir) = global_ranks(subd_pos_prev(1), &
subd_pos_prev(2), &
subd_pos_prev(3))

subd_pos_next(:) = subd_pos(:)
subd_pos_next(dir) = modulo(subd_pos(dir) - nproc, nproc) + 1
par%pnext(dir) = global_ranks(subd_pos_next(1), &
subd_pos_next(2), &
subd_pos_next(3))
end do

end subroutine decomposition_2decompfft


end submodule
22 changes: 22 additions & 0 deletions src/omp/poisson_fft.f90 → src/2decompfft/omp/poisson_fft.f90
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
module m_omp_poisson_fft

use decomp_2d_constants, only: PHYSICAL_IN_X
use decomp_2d_fft, only: decomp_2d_fft_init, decomp_2d_fft_3d
use m_allocator, only: field_t
use m_common, only: dp
use m_poisson_fft, only: poisson_fft_t
use m_tdsops, only: dirps_t
use m_mesh, only: mesh_t
use m_omp_spectral, only: process_spectral_div_u

implicit none

Expand Down Expand Up @@ -35,26 +39,44 @@ function init(mesh, xdirps, ydirps, zdirps) result(poisson_fft)

call poisson_fft%base_init(mesh, xdirps, ydirps, zdirps)

if (mesh%par%is_root()) then
print*, "Initialising 2decomp&fft"
end if

call decomp_2d_fft_init(PHYSICAL_IN_X)
allocate(poisson_fft%c_x(poisson_fft%nx_spec, poisson_fft%ny_spec, poisson_fft%nz_spec))

end function init

subroutine fft_forward_omp(self, f_in)
implicit none

class(omp_poisson_fft_t) :: self
class(field_t), intent(in) :: f_in

call decomp_2d_fft_3d(f_in%data, self%c_x)

end subroutine fft_forward_omp

subroutine fft_backward_omp(self, f_out)
implicit none

class(omp_poisson_fft_t) :: self
class(field_t), intent(inout) :: f_out

call decomp_2d_fft_3d(self%c_x, f_out%data)

end subroutine fft_backward_omp

subroutine fft_postprocess_omp(self)
implicit none

class(omp_poisson_fft_t) :: self

call process_spectral_div_u(self%c_x, self%waves, self%nx_spec, self%ny_spec, self%nz_spec, &
self%y_sp_st, self%nx_glob, self%ny_glob, self%nz_glob, &
self%ax, self%bx, self%ay, self%by, self%az, self%bz)

end subroutine fft_postprocess_omp

end module m_omp_poisson_fft
38 changes: 30 additions & 8 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,16 @@ set(SRC
solver.f90
tdsops.f90
time_integrator.f90
ordering.f90
mesh.f90
decomp.f90
field.f90
par_grid.f90
vector_calculus.f90
omp/backend.f90
omp/common.f90
omp/kernels/distributed.f90
omp/poisson_fft.f90
omp/kernels/spectral_processing.f90
omp/sendrecv.f90
omp/exec_dist.f90
)
Expand All @@ -31,21 +36,31 @@ set(CUDASRC
cuda/sendrecv.f90
cuda/tdsops.f90
)
set(2DECOMPFFTSRC
2decompfft/omp/poisson_fft.f90
2decompfft/decomp.f90
)
set(GENERICDECOMPSRC
decomp_generic.f90
)

if(${CMAKE_Fortran_COMPILER_ID} STREQUAL "PGI" OR
${CMAKE_Fortran_COMPILER_ID} STREQUAL "NVHPC")
if(${BACKEND} STREQUAL "CUDA")
list(APPEND SRC ${CUDASRC})
endif()

if (${POISSON_SOLVER} STREQUAL "FFT" AND ${BACKEND} STREQUAL "OMP")
list(APPEND SRC ${2DECOMPFFTSRC})
else()
list(APPEND SRC ${GENERICDECOMPSRC})
endif()

add_library(x3d2 STATIC ${SRC})
target_include_directories(x3d2 INTERFACE ${CMAKE_CURRENT_BINARY_DIR})

add_executable(xcompact xcompact.f90)
target_link_libraries(xcompact PRIVATE x3d2)

if(${CMAKE_Fortran_COMPILER_ID} STREQUAL "PGI" OR
${CMAKE_Fortran_COMPILER_ID} STREQUAL "NVHPC")

if(${BACKEND} STREQUAL "CUDA")
set(CMAKE_Fortran_FLAGS "-cpp -cuda")
set(CMAKE_Fortran_FLAGS_DEBUG "-g -O0 -traceback -Mbounds -Mchkptr -Ktrap=fp")
set(CMAKE_Fortran_FLAGS_RELEASE "-O3 -fast")
Expand All @@ -59,11 +74,18 @@ elseif(${CMAKE_Fortran_COMPILER_ID} STREQUAL "GNU")
set(CMAKE_Fortran_FLAGS_RELEASE "-O3 -ffast-math")
endif()

if (${POISSON_SOLVER} STREQUAL "FFT" AND ${BACKEND} STREQUAL "OMP")
message(STATUS "Using the FFT poisson solver with 2decomp&fft")
find_package(decomp2d REQUIRED)
include_directories(${decomp2d_INCLUDE_DIRS})
target_link_libraries(decomp2d)
target_link_libraries(x3d2 PRIVATE decomp2d)
endif()

find_package(OpenMP REQUIRED)
target_link_libraries(x3d2 PRIVATE OpenMP::OpenMP_Fortran)
target_link_libraries(xcompact PRIVATE OpenMP::OpenMP_Fortran)

find_package(MPI REQUIRED)
target_link_libraries(x3d2 PRIVATE MPI::MPI_Fortran)
target_link_libraries(xcompact PRIVATE MPI::MPI_Fortran)

target_link_libraries(xcompact PRIVATE MPI::MPI_Fortran)
2 changes: 1 addition & 1 deletion src/allocator.f90
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ module m_allocator
contains

function allocator_init(mesh, sz) result(allocator)
type(mesh_t), target, intent(inout) :: mesh
class(mesh_t), target, intent(inout) :: mesh
integer, intent(in) :: sz
type(allocator_t) :: allocator

Expand Down
26 changes: 26 additions & 0 deletions src/decomp.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
module m_decomp
implicit none

type, abstract :: decomp_t
contains
procedure(decomposition), public, deferred :: decomposition
end type decomp_t

interface
subroutine decomposition(self, grid, par)
use m_grid, only: grid_t
use m_par, only: par_t
import :: decomp_t
class(decomp_t) :: self
class(grid_t), intent(inout) :: grid
class(par_t), intent(inout) :: par
end subroutine
end interface

contains

module subroutine test()
print *, "test"
end subroutine

end module m_decomp
80 changes: 80 additions & 0 deletions src/decomp_generic.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
submodule(m_decomp) m_decomp_generic

use m_decomp, only: decomp_t
implicit none

type, extends(decomp_t) :: decomp_generic_t
contains
procedure :: decomposition => decomposition_generic
end type

contains

module subroutine decomposition_generic(self, grid, par)
use m_grid, only: grid_t
use m_par, only: par_t

class(decomp_generic_t) :: self
class(grid_t), intent(inout) :: grid
class(par_t), intent(inout) :: par
integer, allocatable, dimension(:, :, :) :: global_ranks
integer :: i, nproc_x, nproc_y, nproc_z, nproc
integer, dimension(3) :: subd_pos, subd_pos_prev, subd_pos_next
integer :: dir
logical :: is_last_domain

if (par%is_root()) then
print*, "Domain decomposition by x3d2 (generic)"
end if

! Number of processes on a direction basis
nproc_x = par%nproc_dir(1)
nproc_y = par%nproc_dir(2)
nproc_z = par%nproc_dir(3)

! Define number of cells and vertices in each direction
grid%vert_dims = grid%global_vert_dims/par%nproc_dir

! A 3D array corresponding to each region in the global domain
allocate (global_ranks(nproc_x, nproc_y, nproc_z))

! set the corresponding global rank for each sub-domain
global_ranks = reshape([(i, i=0, par%nproc - 1)], &
shape=[nproc_x, nproc_y, nproc_z])

! subdomain position in the global domain
subd_pos = findloc(global_ranks, par%nrank)

! local/directional position of the subdomain
par%nrank_dir(:) = subd_pos(:) - 1

do dir = 1, 3
is_last_domain = (par%nrank_dir(dir) + 1 == par%nproc_dir(dir))
if (is_last_domain .and. (.not. grid%periodic_BC(dir))) then
grid%cell_dims(dir) = grid%vert_dims(dir) - 1
else
grid%cell_dims(dir) = grid%vert_dims(dir)
end if
end do

par%n_offset(:) = grid%vert_dims(:)*par%nrank_dir(:)

do dir = 1, 3
nproc = par%nproc_dir(dir)
subd_pos_prev(:) = subd_pos(:)
subd_pos_prev(dir) = modulo(subd_pos(dir) - 2, nproc) + 1
par%pprev(dir) = global_ranks(subd_pos_prev(1), &
subd_pos_prev(2), &
subd_pos_prev(3))

subd_pos_next(:) = subd_pos(:)
subd_pos_next(dir) = modulo(subd_pos(dir) - nproc, nproc) + 1
par%pnext(dir) = global_ranks(subd_pos_next(1), &
subd_pos_next(2), &
subd_pos_next(3))
end do

end subroutine decomposition_generic


end submodule
Loading
Loading