diff --git a/src/cuda/backend.f90 b/src/cuda/backend.f90
index 9c59ddd2..2cd042d7 100644
--- a/src/cuda/backend.f90
+++ b/src/cuda/backend.f90
@@ -15,7 +15,7 @@ module m_cuda_backend
   use m_cuda_kernels_dist, only: transeq_3fused_dist, transeq_3fused_subs
   use m_cuda_kernels_reorder, only: &
     reorder_x2y, reorder_x2z, reorder_y2x, reorder_y2z, reorder_z2y, &
-    sum_yintox, sum_zintox
+    sum_yintox, sum_zintox, axpby
 
   implicit none
 
@@ -495,6 +495,22 @@ subroutine vecadd_cuda(self, a, x, b, y)
     real(dp), intent(in) :: b
     class(field_t), intent(inout) :: y
 
+    real(dp), device, pointer, dimension(:, :, :) :: x_d, y_d
+    type(dim3) :: blocks, threads
+    integer :: nx
+
+    ! Point x_d/y_d at the device data of the field arguments; only
+    ! cuda_field_t arguments are handled by these select type constructs.
+    select type(x); type is (cuda_field_t); x_d => x%data_d; end select
+    select type(y); type is (cuda_field_t); y_d => y%data_d; end select
+
+    ! Launch configuration: one block per entry along the third dimension
+    ! of x_d, SZ threads per block; nx is the extent of the second dimension.
+    nx = size(x_d, dim = 2)
+    blocks = dim3(size(x_d, dim = 3), 1, 1)
+    threads = dim3(SZ, 1, 1)
+    call axpby<<<blocks, threads>>>(nx, a, x_d, b, y_d)
+
   end subroutine vecadd_cuda
 
   subroutine copy_into_buffers(u_send_s_dev, u_send_e_dev, u_dev, n)