Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: make force and stress of sDFT support GPU #5487

Merged
merged 9 commits into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions source/Makefile.Objects
Original file line number Diff line number Diff line change
Expand Up @@ -618,6 +618,7 @@ OBJS_PARALLEL=parallel_common.o\
parallel_grid.o\
parallel_kpoints.o\
parallel_reduce.o\
parallel_device.o

OBJS_SRCPW=H_Ewald_pw.o\
dnrm2.o\
Expand All @@ -641,6 +642,7 @@ OBJS_SRCPW=H_Ewald_pw.o\
forces_cc.o\
forces_scc.o\
fs_nonlocal_tools.o\
fs_kin_tools.o\
force_op.o\
stress_op.o\
wf_op.o\
Expand Down
1 change: 1 addition & 0 deletions source/module_base/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ add_library(
parallel_global.cpp
parallel_comm.cpp
parallel_reduce.cpp
parallel_device.cpp
spherical_bessel_transformer.cpp
cubic_spline.cpp
module_mixing/mixing_data.cpp
Expand Down
31 changes: 17 additions & 14 deletions source/module_base/module_device/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,27 +191,30 @@ else { return "cpu";
}
}

int get_device_kpar(const int &kpar) {
int get_device_kpar(const int& kpar, const int& bndpar)
{
#if __MPI && (__CUDA || __ROCM)
int temp_nproc;
MPI_Comm_size(MPI_COMM_WORLD, &temp_nproc);
if (temp_nproc != kpar) {
ModuleBase::WARNING("Input_conv",
"None kpar set in INPUT file, auto set kpar value.");
}
// GlobalV::KPAR = temp_nproc;
// band the CPU processor to the devices
int node_rank = base_device::information::get_node_rank();
int temp_nproc = 0;
int new_kpar = kpar;
MPI_Comm_size(MPI_COMM_WORLD, &temp_nproc);
if (temp_nproc != kpar * bndpar)
{
new_kpar = temp_nproc / bndpar;
ModuleBase::WARNING("Input_conv", "kpar is not compatible with the number of processors, auto set kpar value.");
}

// get the CPU rank of current node
int node_rank = base_device::information::get_node_rank();

int device_num = -1;
int device_num = -1;
#if defined(__CUDA)
cudaGetDeviceCount(&device_num);
cudaSetDevice(node_rank % device_num);
cudaGetDeviceCount(&device_num); // get the number of GPU devices of current node
cudaSetDevice(node_rank % device_num); // band the CPU processor to the devices
#elif defined(__ROCM)
hipGetDeviceCount(&device_num);
hipSetDevice(node_rank % device_num);
#endif
return temp_nproc;
return new_kpar;
#endif
return kpar;
}
Expand Down
8 changes: 7 additions & 1 deletion source/module_base/module_device/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ std::string get_device_info(std::string device_flag);
* @brief Get the device kpar object
* for module_io GlobalV::KPAR
*/
int get_device_kpar(const int& kpar);
int get_device_kpar(const int& kpar, const int& bndpar);

/**
* @brief Get the device flag object
Expand All @@ -50,6 +50,12 @@ std::string get_device_flag(const std::string& device,
const std::string& basis_type);

#if __MPI
/**
* @brief Get the rank of current node
 * Note that GPU can only be bound with CPU in the same node
*
* @return int
*/
int get_node_rank();
int get_node_rank_with_mpi_shared(const MPI_Comm mpi_comm = MPI_COMM_WORLD);
int stringCmp(const void* a, const void* b);
Expand Down
13 changes: 7 additions & 6 deletions source/module_base/parallel_common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,16 @@ void Parallel_Common::bcast_string(std::string& object) // Peize Lin fix bug 201
{
int size = object.size();
MPI_Bcast(&size, 1, MPI_INT, 0, MPI_COMM_WORLD);
char* swap = new char[size + 1];

int my_rank;
MPI_Comm_rank(MPI_COMM_WORLD, &my_rank);
if (0 == my_rank)
strcpy(swap, object.c_str());
MPI_Bcast(swap, size + 1, MPI_CHAR, 0, MPI_COMM_WORLD);

if (0 != my_rank)
object = static_cast<std::string>(swap);
delete[] swap;
{
object.resize(size);
}

MPI_Bcast(&object[0], size, MPI_CHAR, 0, MPI_COMM_WORLD);
return;
}

Expand Down
38 changes: 38 additions & 0 deletions source/module_base/parallel_device.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#include "parallel_device.h"
#ifdef __MPI
namespace Parallel_Common
{
// --- Broadcast helpers -----------------------------------------------------
// Each overload broadcasts `count` elements from rank 0 of `comm` to every
// rank. Complex arrays are transferred as 2*count contiguous real values,
// which is layout-compatible with std::complex<T>.

void bcast_data(double* buf, const int& count, const MPI_Comm& comm)
{
    MPI_Bcast(buf, count, MPI_DOUBLE, 0, comm);
}

void bcast_data(float* buf, const int& count, const MPI_Comm& comm)
{
    MPI_Bcast(buf, count, MPI_FLOAT, 0, comm);
}

void bcast_data(std::complex<double>* buf, const int& count, const MPI_Comm& comm)
{
    // One std::complex<double> == two contiguous MPI_DOUBLEs.
    MPI_Bcast(buf, 2 * count, MPI_DOUBLE, 0, comm);
}

void bcast_data(std::complex<float>* buf, const int& count, const MPI_Comm& comm)
{
    // One std::complex<float> == two contiguous MPI_FLOATs.
    MPI_Bcast(buf, 2 * count, MPI_FLOAT, 0, comm);
}

// --- Reduction helpers -----------------------------------------------------
// Each overload sums `count` elements across all ranks of `comm` in place
// (MPI_IN_PLACE Allreduce); every rank ends up with the summed result.
// Summing a complex array element-wise equals summing its real/imag parts,
// so the 2*count real-typed reduction is exact.

void reduce_data(double* buf, const int& count, const MPI_Comm& comm)
{
    MPI_Allreduce(MPI_IN_PLACE, buf, count, MPI_DOUBLE, MPI_SUM, comm);
}

void reduce_data(float* buf, const int& count, const MPI_Comm& comm)
{
    MPI_Allreduce(MPI_IN_PLACE, buf, count, MPI_FLOAT, MPI_SUM, comm);
}

void reduce_data(std::complex<double>* buf, const int& count, const MPI_Comm& comm)
{
    MPI_Allreduce(MPI_IN_PLACE, buf, 2 * count, MPI_DOUBLE, MPI_SUM, comm);
}

void reduce_data(std::complex<float>* buf, const int& count, const MPI_Comm& comm)
{
    MPI_Allreduce(MPI_IN_PLACE, buf, 2 * count, MPI_FLOAT, MPI_SUM, comm);
}
} // namespace Parallel_Common
#endif
45 changes: 21 additions & 24 deletions source/module_base/parallel_device.h
Original file line number Diff line number Diff line change
@@ -1,39 +1,34 @@
#ifndef __PARALLEL_DEVICE_H__
#define __PARALLEL_DEVICE_H__
#ifdef __MPI
#include "mpi.h"
#include "module_base/module_device/device.h"
#include "module_base/module_device/memory_op.h"
#include <complex>
#include <string>
#include <vector>
namespace Parallel_Common
{
void bcast_complex(std::complex<double>* object, const int& n, const MPI_Comm& comm)
{
MPI_Bcast(object, n * 2, MPI_DOUBLE, 0, comm);
}
void bcast_complex(std::complex<float>* object, const int& n, const MPI_Comm& comm)
{
MPI_Bcast(object, n * 2, MPI_FLOAT, 0, comm);
}
void bcast_real(double* object, const int& n, const MPI_Comm& comm)
{
MPI_Bcast(object, n, MPI_DOUBLE, 0, comm);
}
void bcast_real(float* object, const int& n, const MPI_Comm& comm)
{
MPI_Bcast(object, n, MPI_FLOAT, 0, comm);
}
void bcast_data(std::complex<double>* object, const int& n, const MPI_Comm& comm);
void bcast_data(std::complex<float>* object, const int& n, const MPI_Comm& comm);
void bcast_data(double* object, const int& n, const MPI_Comm& comm);
void bcast_data(float* object, const int& n, const MPI_Comm& comm);
void reduce_data(std::complex<double>* object, const int& n, const MPI_Comm& comm);
void reduce_data(std::complex<float>* object, const int& n, const MPI_Comm& comm);
void reduce_data(double* object, const int& n, const MPI_Comm& comm);
void reduce_data(float* object, const int& n, const MPI_Comm& comm);

template <typename T, typename Device>
/**
* @brief bcast complex in Device
* @brief bcast data in Device
*
* @tparam T: float, double, std::complex<float>, std::complex<double>
* @tparam Device
* @param ctx Device ctx
* @param object complex arrays in Device
* @param n the size of complex arrays
* @param comm MPI_Comm
* @param tmp_space tmp space in CPU
*/
void bcast_complex(const Device* ctx, T* object, const int& n, const MPI_Comm& comm, T* tmp_space = nullptr)
template <typename T, typename Device>
void bcast_dev(const Device* ctx, T* object, const int& n, const MPI_Comm& comm, T* tmp_space = nullptr)
{
const base_device::DEVICE_CPU* cpu_ctx = {};
T* object_cpu = nullptr;
Expand All @@ -56,7 +51,7 @@ void bcast_complex(const Device* ctx, T* object, const int& n, const MPI_Comm& c
object_cpu = object;
}

bcast_complex(object_cpu, n, comm);
bcast_data(object_cpu, n, comm);

if (base_device::get_device_type<Device>(ctx) == base_device::GpuDevice)
{
Expand All @@ -70,7 +65,7 @@ void bcast_complex(const Device* ctx, T* object, const int& n, const MPI_Comm& c
}

template <typename T, typename Device>
void bcast_real(const Device* ctx, T* object, const int& n, const MPI_Comm& comm, T* tmp_space = nullptr)
void reduce_dev(const Device* ctx, T* object, const int& n, const MPI_Comm& comm, T* tmp_space = nullptr)
{
const base_device::DEVICE_CPU* cpu_ctx = {};
T* object_cpu = nullptr;
Expand All @@ -93,7 +88,7 @@ void bcast_real(const Device* ctx, T* object, const int& n, const MPI_Comm& comm
object_cpu = object;
}

bcast_real(object_cpu, n, comm);
reduce_data(object_cpu, n, comm);

if (base_device::get_device_type<Device>(ctx) == base_device::GpuDevice)
{
Expand All @@ -105,7 +100,9 @@ void bcast_real(const Device* ctx, T* object, const int& n, const MPI_Comm& comm
}
return;
}

}


#endif
#endif
9 changes: 4 additions & 5 deletions source/module_elecstate/elecstate_pw_sdft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,13 @@ void ElecStatePW_SDFT<T, Device>::psiToRho(const psi::Psi<T, Device>& psi)
ModuleBase::TITLE(this->classname, "psiToRho");
ModuleBase::timer::tick(this->classname, "psiToRho");
const int nspin = PARAM.inp.nspin;
for (int is = 0; is < nspin; is++)
{
setmem_var_op()(this->ctx, this->rho[is], 0, this->charge->nrxx);
}

if (GlobalV::MY_STOGROUP == 0)
{
for (int is = 0; is < nspin; is++)
{
setmem_var_op()(this->ctx, this->rho[is], 0, this->charge->nrxx);
}

for (int ik = 0; ik < psi.get_nk(); ++ik)
{
psi.fix_k(ik);
Expand Down
17 changes: 14 additions & 3 deletions source/module_elecstate/module_charge/charge_extra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,20 @@ void Charge_Extra::Init_CE(const int& nspin, const int& natom, const int& nrxx,

if (pot_order > 0)
{
delta_rho1.resize(this->nspin, std::vector<double>(nrxx, 0.0));
delta_rho2.resize(this->nspin, std::vector<double>(nrxx, 0.0));
delta_rho3.resize(this->nspin, std::vector<double>(nrxx, 0.0));
// delta_rho1.resize(this->nspin, std::vector<double>(nrxx, 0.0));
// delta_rho2.resize(this->nspin, std::vector<double>(nrxx, 0.0));
// delta_rho3.resize(this->nspin, std::vector<double>(nrxx, 0.0));
// qianrui replace the above code with the following code.
// The above code cannot passed valgrind tests, which has an invalid read of size 32.
delta_rho1.resize(this->nspin);
delta_rho2.resize(this->nspin);
delta_rho3.resize(this->nspin);
for (int is = 0; is < this->nspin; is++)
{
delta_rho1[is].resize(nrxx, 0.0);
delta_rho2[is].resize(nrxx, 0.0);
delta_rho3[is].resize(nrxx, 0.0);
}
}

if(pot_order == 3)
Expand Down
1 change: 0 additions & 1 deletion source/module_esolver/esolver_ks_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -623,7 +623,6 @@ void ESolver_KS_PW<T, Device>::cal_stress(ModuleBase::matrix& stress)
&this->sf,
&this->kv,
this->pw_wfc,
this->psi,
this->__kspw_psi);

// external stress
Expand Down
34 changes: 14 additions & 20 deletions source/module_esolver/esolver_sdft_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,9 @@ void ESolver_SDFT_PW<T, Device>::after_scf(const int istep)
template <typename T, typename Device>
void ESolver_SDFT_PW<T, Device>::hamilt2density_single(int istep, int iter, double ethr)
{
ModuleBase::TITLE("ESolver_SDFT_PW", "hamilt2density");
ModuleBase::timer::tick("ESolver_SDFT_PW", "hamilt2density");

// reset energy
this->pelec->f_en.eband = 0.0;
this->pelec->f_en.demet = 0.0;
Expand Down Expand Up @@ -241,6 +244,7 @@ void ESolver_SDFT_PW<T, Device>::hamilt2density_single(int istep, int iter, doub
#ifdef __MPI
MPI_Bcast(&(this->pelec->f_en.deband), 1, MPI_DOUBLE, 0, PARAPW_WORLD);
#endif
ModuleBase::timer::tick("ESolver_SDFT_PW", "hamilt2density");
}

template <typename T, typename Device>
Expand All @@ -249,10 +253,10 @@ double ESolver_SDFT_PW<T, Device>::cal_energy()
return this->pelec->f_en.etot;
}

template <>
void ESolver_SDFT_PW<std::complex<double>, base_device::DEVICE_CPU>::cal_force(ModuleBase::matrix& force)
template <typename T, typename Device>
void ESolver_SDFT_PW<T, Device>::cal_force(ModuleBase::matrix& force)
{
Sto_Forces ff(GlobalC::ucell.nat);
Sto_Forces<double, Device> ff(GlobalC::ucell.nat);

ff.cal_stoforce(force,
*this->pelec,
Expand All @@ -261,40 +265,30 @@ void ESolver_SDFT_PW<std::complex<double>, base_device::DEVICE_CPU>::cal_force(M
&this->sf,
&this->kv,
this->pw_wfc,
this->psi,
GlobalC::ppcell,
GlobalC::ucell,
*this->kspw_psi,
this->stowf);
}

template <>
void ESolver_SDFT_PW<std::complex<double>, base_device::DEVICE_GPU>::cal_force(ModuleBase::matrix& force)
{
ModuleBase::WARNING_QUIT("ESolver_SDFT_PW<T, Device>::cal_force", "DEVICE_GPU is not supported");
}

template <>
void ESolver_SDFT_PW<std::complex<double>, base_device::DEVICE_CPU>::cal_stress(ModuleBase::matrix& stress)
template <typename T, typename Device>
void ESolver_SDFT_PW<T, Device>::cal_stress(ModuleBase::matrix& stress)
{
Sto_Stress_PW ss;
Sto_Stress_PW<double, Device> ss;
ss.cal_stress(stress,
*this->pelec,
this->pw_rho,
&GlobalC::ucell.symm,
&this->sf,
&this->kv,
this->pw_wfc,
this->psi,
*this->kspw_psi,
this->stowf,
this->pelec->charge,
&GlobalC::ppcell,
GlobalC::ucell);
}

template <>
void ESolver_SDFT_PW<std::complex<double>, base_device::DEVICE_GPU>::cal_stress(ModuleBase::matrix& stress)
{
ModuleBase::WARNING_QUIT("ESolver_SDFT_PW<T, Device>::cal_stress", "DEVICE_GPU is not supported");
}

template <typename T, typename Device>
void ESolver_SDFT_PW<T, Device>::after_all_runners()
{
Expand Down
1 change: 1 addition & 0 deletions source/module_hamilt_pw/hamilt_pwdft/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ list(APPEND objects
parallel_grid.cpp
elecond.cpp
fs_nonlocal_tools.cpp
fs_kin_tools.cpp
radial_proj.cpp
)

Expand Down
Loading
Loading