Skip to content

Commit

Permalink
Refactor: refactor HsolverPW & HsolverPW_SDFT func (#5094)
Browse files Browse the repository at this point in the history
* refactor HsolverPW & HsolverPW_SDFT func

* remove some useless code

* refactor hsolver_pw & hsolver_pw_sdft

* fix build bug

* fix build bug

* add nspin value in hsolver_pw

* fix build bug

* fix build bug
  • Loading branch information
haozhihan authored Sep 14, 2024
1 parent c020e20 commit 7ef0f50
Show file tree
Hide file tree
Showing 9 changed files with 172 additions and 192 deletions.
36 changes: 21 additions & 15 deletions source/module_esolver/esolver_ks_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,7 @@ void ESolver_KS_PW<T, Device>::hamilt2density(const int istep, const int iter, c
// be careful that istep start from 0 and iter start from 1
// if (iter == 1)
hsolver::DiagoIterAssist<T, Device>::need_subspace = ((istep == 0 || istep == 1) && iter == 1) ? false : true;

hsolver::DiagoIterAssist<T, Device>::SCF_ITER = iter;
hsolver::DiagoIterAssist<T, Device>::PW_DIAG_THR = ethr;
hsolver::DiagoIterAssist<T, Device>::PW_DIAG_NMAX = PARAM.inp.pw_diag_nmax;
Expand All @@ -361,25 +362,30 @@ void ESolver_KS_PW<T, Device>::hamilt2density(const int istep, const int iter, c
this->kspw_psi->get_nbands(),
PARAM.inp.diago_full_acc);

hsolver::HSolverPW<T, Device> hsolver_pw_obj(this->pw_wfc, &this->wf, this->init_psi);
hsolver_pw_obj.solve(this->p_hamilt, // hamilt::Hamilt<T, Device>* pHamilt,
this->kspw_psi[0], // psi::Psi<T, Device>& psi,
this->pelec, // elecstate::ElecState<T, Device>* pelec,
hsolver::HSolverPW<T, Device> hsolver_pw_obj(this->pw_wfc,
&this->wf,

PARAM.inp.calculation,
PARAM.inp.basis_type,
PARAM.inp.ks_solver,
PARAM.inp.use_paw,
GlobalV::use_uspp,
GlobalV::NSPIN,

hsolver::DiagoIterAssist<T, Device>::SCF_ITER,
hsolver::DiagoIterAssist<T, Device>::PW_DIAG_NMAX,
hsolver::DiagoIterAssist<T, Device>::PW_DIAG_THR,

hsolver::DiagoIterAssist<T, Device>::need_subspace,
this->init_psi);

hsolver_pw_obj.solve(this->p_hamilt,
this->kspw_psi[0],
this->pelec,
this->pelec->ekb.c,
is_occupied,
PARAM.inp.ks_solver,
PARAM.inp.calculation,
PARAM.inp.basis_type,
PARAM.inp.use_paw,
GlobalV::use_uspp,
GlobalV::RANK_IN_POOL,
GlobalV::NPROC_IN_POOL,

hsolver::DiagoIterAssist<T, Device>::SCF_ITER,
hsolver::DiagoIterAssist<T, Device>::need_subspace,
hsolver::DiagoIterAssist<T, Device>::PW_DIAG_NMAX,
hsolver::DiagoIterAssist<T, Device>::PW_DIAG_THR,

false);

this->init_psi = true;
Expand Down
23 changes: 17 additions & 6 deletions source/module_esolver/esolver_sdft_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,23 @@ void ESolver_SDFT_PW::hamilt2density(int istep, int iter, double ethr)
hsolver::DiagoIterAssist<std::complex<double>>::PW_DIAG_NMAX = PARAM.inp.pw_diag_nmax;

// hsolver only exists in this function
hsolver::HSolverPW_SDFT hsolver_pw_sdft_obj(&this->kv, this->pw_wfc, &this->wf, this->stowf, this->stoche, this->init_psi);
hsolver::HSolverPW_SDFT hsolver_pw_sdft_obj(
&this->kv,
this->pw_wfc,
&this->wf,
this->stowf,
this->stoche,
PARAM.inp.calculation,
PARAM.inp.basis_type,
PARAM.inp.ks_solver,
PARAM.inp.use_paw,
GlobalV::use_uspp,
GlobalV::NSPIN,
hsolver::DiagoIterAssist<std::complex<double>>::SCF_ITER,
hsolver::DiagoIterAssist<std::complex<double>>::PW_DIAG_NMAX,
hsolver::DiagoIterAssist<std::complex<double>>::PW_DIAG_THR,
hsolver::DiagoIterAssist<std::complex<double>>::need_subspace,
this->init_psi);

hsolver_pw_sdft_obj.solve(this->p_hamilt,
this->psi[0],
Expand All @@ -181,11 +197,6 @@ void ESolver_SDFT_PW::hamilt2density(int istep, int iter, double ethr)
this->stowf,
istep,
iter,
GlobalV::KS_SOLVER,
hsolver::DiagoIterAssist<std::complex<double>>::SCF_ITER,
hsolver::DiagoIterAssist<std::complex<double>>::need_subspace,
hsolver::DiagoIterAssist<std::complex<double>>::PW_DIAG_NMAX,
hsolver::DiagoIterAssist<std::complex<double>>::PW_DIAG_THR,
false);
this->init_psi = true;

Expand Down
40 changes: 23 additions & 17 deletions source/module_esolver/pw_fun.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,25 +80,31 @@ void ESolver_KS_PW<T, Device>::hamilt2estates(const double ethr)
this->kspw_psi->get_nbands(),
PARAM.inp.diago_full_acc);

hsolver::HSolverPW<T, Device> hsolver_pw_obj(this->pw_wfc, &this->wf, this->init_psi);
hsolver::HSolverPW<T, Device> hsolver_pw_obj(this->pw_wfc,
&this->wf,

PARAM.inp.calculation,
PARAM.inp.basis_type,
PARAM.inp.ks_solver,
PARAM.inp.use_paw,
GlobalV::use_uspp,
GlobalV::NSPIN,

hsolver::DiagoIterAssist<T, Device>::SCF_ITER,
hsolver::DiagoIterAssist<T, Device>::PW_DIAG_NMAX,
hsolver::DiagoIterAssist<T, Device>::PW_DIAG_THR,

hsolver::DiagoIterAssist<T, Device>::need_subspace,
this->init_psi);

hsolver_pw_obj.solve(this->p_hamilt,
this->kspw_psi[0],
this->pelec,
this->pelec->ekb.c,
is_occupied,
PARAM.inp.ks_solver,
PARAM.inp.calculation,
PARAM.inp.basis_type,
PARAM.inp.use_paw,
GlobalV::use_uspp,
GlobalV::RANK_IN_POOL,
GlobalV::NPROC_IN_POOL,
hsolver::DiagoIterAssist<T, Device>::SCF_ITER,
hsolver::DiagoIterAssist<T, Device>::need_subspace,
hsolver::DiagoIterAssist<T, Device>::PW_DIAG_NMAX,
hsolver::DiagoIterAssist<T, Device>::PW_DIAG_THR,
true);
this->kspw_psi[0],
this->pelec,
this->pelec->ekb.c,
is_occupied,
GlobalV::RANK_IN_POOL,
GlobalV::NPROC_IN_POOL,
true);

this->init_psi = true;
}
Expand Down
45 changes: 6 additions & 39 deletions source/module_hsolver/hsolver_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -208,55 +208,22 @@ void HSolverPW<T, Device>::paw_func_after_kloop(psi::Psi<T, Device>& psi, elecst

#endif

template <typename T, typename Device>
HSolverPW<T, Device>::HSolverPW(ModulePW::PW_Basis_K* wfc_basis_in,
wavefunc* pwf_in,
const bool initialed_psi_in)
{
this->wfc_basis = wfc_basis_in;
this->pwf = pwf_in;

this->initialed_psi = initialed_psi_in;
}

template <typename T, typename Device>
void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
psi::Psi<T, Device>& psi,
elecstate::ElecState* pes,
double* out_eigenvalues,
const std::vector<bool>& is_occupied_in,
const std::string method_in,
const std::string calculation_type_in,
const std::string basis_type_in,
const bool use_paw_in,
const bool use_uspp_in,
const int rank_in_pool_in,
const int nproc_in_pool_in,
const int scf_iter_in,
const bool need_subspace_in,
const int diag_iter_max_in,
const double iter_diag_thr_in,
const bool skip_charge)
{
ModuleBase::TITLE("HSolverPW", "solve");
ModuleBase::timer::tick("HSolverPW", "solve");

// select the method of diagonalization
this->method = method_in;
this->calculation_type = calculation_type_in;
this->basis_type = basis_type_in;

this->use_paw = use_paw_in;
this->use_uspp = use_uspp_in;

this->rank_in_pool = rank_in_pool_in;
this->nproc_in_pool = nproc_in_pool_in;

this->scf_iter = scf_iter_in;
this->need_subspace = need_subspace_in;
this->diag_iter_max = diag_iter_max_in;
this->iter_diag_thr = iter_diag_thr_in;

// report if the specified diagonalization method is not supported
const std::initializer_list<std::string> _methods = {"cg", "dav", "dav_subspace", "bpcg"};
if (std::find(std::begin(_methods), std::end(_methods), this->method) == std::end(_methods))
Expand Down Expand Up @@ -295,7 +262,7 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
{
GlobalV::ofs_running << "Average iterative diagonalization steps for k-points " << ik
<< " is: " << DiagoIterAssist<T, Device>::avg_iter
<< " ; where current threshold is: " << this->iter_diag_thr
<< " ; where current threshold is: " << this->diag_thr
<< " . " << std::endl;
DiagoIterAssist<T, Device>::avg_iter = 0.0;
}
Expand Down Expand Up @@ -384,7 +351,7 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
this->calculation_type,
this->need_subspace,
subspace_func,
this->iter_diag_thr,
this->diag_thr,
this->diag_iter_max,
this->nproc_in_pool);

Expand Down Expand Up @@ -494,7 +461,7 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
psi.get_k_first() ? psi.get_current_nbas()
: psi.get_nk() * psi.get_nbasis(),
GlobalV::PW_DIAG_NDIM,
this->iter_diag_thr,
this->diag_thr,
this->diag_iter_max,
this->need_subspace,
comm_info);
Expand All @@ -512,7 +479,7 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
/// allow 5 eigenvecs to be NOT converged.
const int notconv_max = ("nscf" == this->calculation_type) ? 0 : 5;
/// convergence threshold
const Real david_diag_thr = this->iter_diag_thr;
const Real david_diag_thr = this->diag_thr;
/// maximum iterations
const int david_maxiter = this->diag_iter_max;

Expand Down Expand Up @@ -615,7 +582,7 @@ void HSolverPW<T, Device>::update_precondition(std::vector<Real>& h_diag, const
}
}
}
if (GlobalV::NSPIN == 4)
if (this->nspin == 4)
{
const int size = h_diag.size();
for (int ig = 0; ig < npw; ig++)
Expand All @@ -633,7 +600,7 @@ void HSolverPW<T, Device>::output_iterInfo()
{
GlobalV::ofs_running << "Average iterative diagonalization steps: "
<< DiagoIterAssist<T, Device>::avg_iter / this->wfc_basis->nks
<< " ; where current threshold is: " << this->iter_diag_thr << " . "
<< " ; where current threshold is: " << this->diag_thr << " . "
<< std::endl;
// reset avg_iter
DiagoIterAssist<T, Device>::avg_iter = 0.0;
Expand Down
61 changes: 32 additions & 29 deletions source/module_hsolver/hsolver_pw.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,25 @@ class HSolverPW
public:
HSolverPW(ModulePW::PW_Basis_K* wfc_basis_in,
wavefunc* pwf_in,
const bool initialed_psi_in);

const std::string calculation_type_in,
const std::string basis_type_in,
const std::string method_in,
const bool use_paw_in,
const bool use_uspp_in,
const int nspin_in,

const int scf_iter_in,
const int diag_iter_max_in,
const double diag_thr_in,
const bool need_subspace_in,
const bool initialed_psi_in)

: wfc_basis(wfc_basis_in), pwf(pwf_in),
calculation_type(calculation_type_in), basis_type(basis_type_in), method(method_in),
use_paw(use_paw_in), use_uspp(use_uspp_in), nspin(nspin_in),
scf_iter(scf_iter_in), diag_iter_max(diag_iter_max_in), diag_thr(diag_thr_in),
need_subspace(need_subspace_in), initialed_psi(initialed_psi_in) {};

/// @brief solve function for pw
/// @param pHamilt interface to hamilt
Expand All @@ -34,19 +52,10 @@ class HSolverPW
elecstate::ElecState* pes,
double* out_eigenvalues,
const std::vector<bool>& is_occupied_in,
const std::string method_in,
const std::string calculation_type_in,
const std::string basis_type_in,
const bool use_paw_in,
const bool use_uspp_in,
const int rank_in_pool_in,
const int nproc_in_pool_in,
const int scf_iter_in,
const bool need_subspace_in,
const int diag_iter_max_in,
const double iter_diag_thr_in,
const bool skip_charge);

protected:
// diago caller
void hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
Expand All @@ -62,44 +71,38 @@ class HSolverPW

void output_iterInfo();

bool initialed_psi = false;

ModulePW::PW_Basis_K* wfc_basis = nullptr;

wavefunc* pwf = nullptr;

int scf_iter = 1; // Start from 1
bool need_subspace = false;
int diag_iter_max = 50;
double iter_diag_thr = 1.0e-2; // threshold for diagonalization
const std::string calculation_type;
const std::string basis_type;
const std::string method;
const bool use_paw;
const bool use_uspp;
const int nspin;

const int scf_iter; // Start from 1
const int diag_iter_max; // max iter times for diagonalization
const double diag_thr; // threshold for diagonalization

std::string method = "none";
const bool need_subspace; // for cg or dav_subspace
const bool initialed_psi;

private:
Device* ctx = {};

std::string calculation_type = "scf";
std::string basis_type = "pw";

bool use_paw = false;
bool use_uspp = false;

int rank_in_pool = 0;
int nproc_in_pool = 1;

int nspin = 1;

#ifdef USE_PAW
void paw_func_in_kloop(const int ik);

void call_paw_cell_set_currentk(const int ik);

void paw_func_after_kloop(psi::Psi<T, Device>& psi, elecstate::ElecState* pes);
#endif

};


} // namespace hsolver

#endif
Loading

0 comments on commit 7ef0f50

Please sign in to comment.