From 7ef0f5086882f010c44468cab52d8cecaed87437 Mon Sep 17 00:00:00 2001 From: Haozhi Han Date: Sat, 14 Sep 2024 10:19:14 +0800 Subject: [PATCH] Refactor: refactor HsolverPW & HsolverPW_SDFT func (#5094) * refactor HsolverPW & HsolverPW_SDFT func * remove some useless code * refactor hsolver_pw & hsolver_pw_sdft * fix build bug * fix build bug * add nspin value in hsolver_pw * fix build bug * fix build bug --- source/module_esolver/esolver_ks_pw.cpp | 36 ++++++----- source/module_esolver/esolver_sdft_pw.cpp | 23 +++++-- source/module_esolver/pw_fun.cpp | 40 ++++++------ source/module_hsolver/hsolver_pw.cpp | 45 ++------------ source/module_hsolver/hsolver_pw.h | 61 ++++++++++--------- source/module_hsolver/hsolver_pw_sdft.cpp | 24 +------- source/module_hsolver/hsolver_pw_sdft.h | 33 ++++++++-- .../module_hsolver/test/test_hsolver_pw.cpp | 59 ++++++++---------- .../module_hsolver/test/test_hsolver_sdft.cpp | 43 ++++++------- 9 files changed, 172 insertions(+), 192 deletions(-) diff --git a/source/module_esolver/esolver_ks_pw.cpp b/source/module_esolver/esolver_ks_pw.cpp index 22d466d93f..b037c43ef9 100644 --- a/source/module_esolver/esolver_ks_pw.cpp +++ b/source/module_esolver/esolver_ks_pw.cpp @@ -348,6 +348,7 @@ void ESolver_KS_PW::hamilt2density(const int istep, const int iter, c // be careful that istep start from 0 and iter start from 1 // if (iter == 1) hsolver::DiagoIterAssist::need_subspace = ((istep == 0 || istep == 1) && iter == 1) ? false : true; + hsolver::DiagoIterAssist::SCF_ITER = iter; hsolver::DiagoIterAssist::PW_DIAG_THR = ethr; hsolver::DiagoIterAssist::PW_DIAG_NMAX = PARAM.inp.pw_diag_nmax; @@ -361,25 +362,30 @@ void ESolver_KS_PW::hamilt2density(const int istep, const int iter, c this->kspw_psi->get_nbands(), PARAM.inp.diago_full_acc); - hsolver::HSolverPW hsolver_pw_obj(this->pw_wfc, &this->wf, this->init_psi); - hsolver_pw_obj.solve(this->p_hamilt, // hamilt::Hamilt* pHamilt, - this->kspw_psi[0], // psi::Psi& psi, - this->pelec, // elecstate::ElecState* pelec, + hsolver::HSolverPW hsolver_pw_obj(this->pw_wfc, + &this->wf, + + PARAM.inp.calculation, + PARAM.inp.basis_type, + PARAM.inp.ks_solver, + PARAM.inp.use_paw, + GlobalV::use_uspp, + GlobalV::NSPIN, + + hsolver::DiagoIterAssist::SCF_ITER, + hsolver::DiagoIterAssist::PW_DIAG_NMAX, + hsolver::DiagoIterAssist::PW_DIAG_THR, + + hsolver::DiagoIterAssist::need_subspace, + this->init_psi); + + hsolver_pw_obj.solve(this->p_hamilt, + this->kspw_psi[0], + this->pelec, this->pelec->ekb.c, is_occupied, - PARAM.inp.ks_solver, - PARAM.inp.calculation, - PARAM.inp.basis_type, - PARAM.inp.use_paw, - GlobalV::use_uspp, GlobalV::RANK_IN_POOL, GlobalV::NPROC_IN_POOL, - - hsolver::DiagoIterAssist::SCF_ITER, - hsolver::DiagoIterAssist::need_subspace, - hsolver::DiagoIterAssist::PW_DIAG_NMAX, - hsolver::DiagoIterAssist::PW_DIAG_THR, - false); this->init_psi = true; diff --git a/source/module_esolver/esolver_sdft_pw.cpp b/source/module_esolver/esolver_sdft_pw.cpp index 89dfe97445..ac82ae1a9a 100644 --- a/source/module_esolver/esolver_sdft_pw.cpp +++ b/source/module_esolver/esolver_sdft_pw.cpp @@ -172,7 +172,23 @@ void ESolver_SDFT_PW::hamilt2density(int istep, int iter, double ethr) hsolver::DiagoIterAssist>::PW_DIAG_NMAX = PARAM.inp.pw_diag_nmax; // hsolver only exists in this function - hsolver::HSolverPW_SDFT hsolver_pw_sdft_obj(&this->kv, this->pw_wfc, &this->wf, this->stowf, this->stoche, this->init_psi); + hsolver::HSolverPW_SDFT hsolver_pw_sdft_obj( + &this->kv, + this->pw_wfc, + &this->wf, + this->stowf, + this->stoche, + PARAM.inp.calculation, + PARAM.inp.basis_type, + PARAM.inp.ks_solver, + PARAM.inp.use_paw, + GlobalV::use_uspp, + GlobalV::NSPIN, + hsolver::DiagoIterAssist>::SCF_ITER, + hsolver::DiagoIterAssist>::PW_DIAG_NMAX, + hsolver::DiagoIterAssist>::PW_DIAG_THR, + hsolver::DiagoIterAssist>::need_subspace, + this->init_psi); hsolver_pw_sdft_obj.solve(this->p_hamilt, this->psi[0], @@ -181,11 +197,6 @@ void ESolver_SDFT_PW::hamilt2density(int istep, int iter, double ethr) this->stowf, istep, iter, - GlobalV::KS_SOLVER, - hsolver::DiagoIterAssist>::SCF_ITER, - hsolver::DiagoIterAssist>::need_subspace, - hsolver::DiagoIterAssist>::PW_DIAG_NMAX, - hsolver::DiagoIterAssist>::PW_DIAG_THR, false); this->init_psi = true; diff --git a/source/module_esolver/pw_fun.cpp b/source/module_esolver/pw_fun.cpp index c24d426f46..7af9734316 100644 --- a/source/module_esolver/pw_fun.cpp +++ b/source/module_esolver/pw_fun.cpp @@ -80,25 +80,31 @@ void ESolver_KS_PW::hamilt2estates(const double ethr) this->kspw_psi->get_nbands(), PARAM.inp.diago_full_acc); - hsolver::HSolverPW hsolver_pw_obj(this->pw_wfc, &this->wf, this->init_psi); + hsolver::HSolverPW hsolver_pw_obj(this->pw_wfc, + &this->wf, + + PARAM.inp.calculation, + PARAM.inp.basis_type, + PARAM.inp.ks_solver, + PARAM.inp.use_paw, + GlobalV::use_uspp, + GlobalV::NSPIN, + + hsolver::DiagoIterAssist::SCF_ITER, + hsolver::DiagoIterAssist::PW_DIAG_NMAX, + hsolver::DiagoIterAssist::PW_DIAG_THR, + + hsolver::DiagoIterAssist::need_subspace, + this->init_psi); hsolver_pw_obj.solve(this->p_hamilt, - this->kspw_psi[0], - this->pelec, - this->pelec->ekb.c, - is_occupied, - PARAM.inp.ks_solver, - PARAM.inp.calculation, - PARAM.inp.basis_type, - PARAM.inp.use_paw, - GlobalV::use_uspp, - GlobalV::RANK_IN_POOL, - GlobalV::NPROC_IN_POOL, - hsolver::DiagoIterAssist::SCF_ITER, - hsolver::DiagoIterAssist::need_subspace, - hsolver::DiagoIterAssist::PW_DIAG_NMAX, - hsolver::DiagoIterAssist::PW_DIAG_THR, - true); + this->kspw_psi[0], + this->pelec, + this->pelec->ekb.c, + is_occupied, + GlobalV::RANK_IN_POOL, + GlobalV::NPROC_IN_POOL, + true); this->init_psi = true; } diff --git a/source/module_hsolver/hsolver_pw.cpp b/source/module_hsolver/hsolver_pw.cpp index 7ed99e0af6..b032215b91 100644 --- a/source/module_hsolver/hsolver_pw.cpp +++ b/source/module_hsolver/hsolver_pw.cpp @@ -208,55 +208,22 @@ void HSolverPW::paw_func_after_kloop(psi::Psi& psi, elecst #endif -template -HSolverPW::HSolverPW(ModulePW::PW_Basis_K* wfc_basis_in, - wavefunc* pwf_in, - const bool initialed_psi_in) -{ - this->wfc_basis = wfc_basis_in; - this->pwf = pwf_in; - - this->initialed_psi = initialed_psi_in; -} - template void HSolverPW::solve(hamilt::Hamilt* pHamilt, psi::Psi& psi, elecstate::ElecState* pes, double* out_eigenvalues, const std::vector& is_occupied_in, - const std::string method_in, - const std::string calculation_type_in, - const std::string basis_type_in, - const bool use_paw_in, - const bool use_uspp_in, const int rank_in_pool_in, const int nproc_in_pool_in, - const int scf_iter_in, - const bool need_subspace_in, - const int diag_iter_max_in, - const double iter_diag_thr_in, const bool skip_charge) { ModuleBase::TITLE("HSolverPW", "solve"); ModuleBase::timer::tick("HSolverPW", "solve"); - // select the method of diagonalization - this->method = method_in; - this->calculation_type = calculation_type_in; - this->basis_type = basis_type_in; - - this->use_paw = use_paw_in; - this->use_uspp = use_uspp_in; - this->rank_in_pool = rank_in_pool_in; this->nproc_in_pool = nproc_in_pool_in; - this->scf_iter = scf_iter_in; - this->need_subspace = need_subspace_in; - this->diag_iter_max = diag_iter_max_in; - this->iter_diag_thr = iter_diag_thr_in; - // report if the specified diagonalization method is not supported const std::initializer_list _methods = {"cg", "dav", "dav_subspace", "bpcg"}; if (std::find(std::begin(_methods), std::end(_methods), this->method) == std::end(_methods)) @@ -295,7 +262,7 @@ void HSolverPW::solve(hamilt::Hamilt* pHamilt, { GlobalV::ofs_running << "Average iterative diagonalization steps for k-points " << ik << " is: " << DiagoIterAssist::avg_iter - << " ; where current threshold is: " << this->iter_diag_thr + << " ; where current threshold is: " << this->diag_thr << " . " << std::endl; DiagoIterAssist::avg_iter = 0.0; } @@ -384,7 +351,7 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, this->calculation_type, this->need_subspace, subspace_func, - this->iter_diag_thr, + this->diag_thr, this->diag_iter_max, this->nproc_in_pool); @@ -494,7 +461,7 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, psi.get_k_first() ? psi.get_current_nbas() : psi.get_nk() * psi.get_nbasis(), GlobalV::PW_DIAG_NDIM, - this->iter_diag_thr, + this->diag_thr, this->diag_iter_max, this->need_subspace, comm_info); @@ -512,7 +479,7 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, /// allow 5 eigenvecs to be NOT converged. const int notconv_max = ("nscf" == this->calculation_type) ? 0 : 5; /// convergence threshold - const Real david_diag_thr = this->iter_diag_thr; + const Real david_diag_thr = this->diag_thr; /// maximum iterations const int david_maxiter = this->diag_iter_max; @@ -615,7 +582,7 @@ void HSolverPW::update_precondition(std::vector& h_diag, const } } } - if (GlobalV::NSPIN == 4) + if (this->nspin == 4) { const int size = h_diag.size(); for (int ig = 0; ig < npw; ig++) @@ -633,7 +600,7 @@ void HSolverPW::output_iterInfo() { GlobalV::ofs_running << "Average iterative diagonalization steps: " << DiagoIterAssist::avg_iter / this->wfc_basis->nks - << " ; where current threshold is: " << this->iter_diag_thr << " . " + << " ; where current threshold is: " << this->diag_thr << " . " << std::endl; // reset avg_iter DiagoIterAssist::avg_iter = 0.0; diff --git a/source/module_hsolver/hsolver_pw.h b/source/module_hsolver/hsolver_pw.h index 43fcc3b001..1641c04ed4 100644 --- a/source/module_hsolver/hsolver_pw.h +++ b/source/module_hsolver/hsolver_pw.h @@ -21,7 +21,25 @@ class HSolverPW public: HSolverPW(ModulePW::PW_Basis_K* wfc_basis_in, wavefunc* pwf_in, - const bool initialed_psi_in); + + const std::string calculation_type_in, + const std::string basis_type_in, + const std::string method_in, + const bool use_paw_in, + const bool use_uspp_in, + const int nspin_in, + + const int scf_iter_in, + const int diag_iter_max_in, + const double diag_thr_in, + const bool need_subspace_in, + const bool initialed_psi_in) + + : wfc_basis(wfc_basis_in), pwf(pwf_in), + calculation_type(calculation_type_in), basis_type(basis_type_in), method(method_in), + use_paw(use_paw_in), use_uspp(use_uspp_in), nspin(nspin_in), + scf_iter(scf_iter_in), diag_iter_max(diag_iter_max_in), diag_thr(diag_thr_in), + need_subspace(need_subspace_in), initialed_psi(initialed_psi_in) {}; /// @brief solve function for pw /// @param pHamilt interface to hamilt @@ -34,19 +52,10 @@ class HSolverPW elecstate::ElecState* pes, double* out_eigenvalues, const std::vector& is_occupied_in, - const std::string method_in, - const std::string calculation_type_in, - const std::string basis_type_in, - const bool use_paw_in, - const bool use_uspp_in, const int rank_in_pool_in, const int nproc_in_pool_in, - const int scf_iter_in, - const bool need_subspace_in, - const int diag_iter_max_in, - const double iter_diag_thr_in, const bool skip_charge); - + protected: // diago caller void hamiltSolvePsiK(hamilt::Hamilt* hm, @@ -62,33 +71,29 @@ class HSolverPW void output_iterInfo(); - bool initialed_psi = false; - ModulePW::PW_Basis_K* wfc_basis = nullptr; - wavefunc* pwf = nullptr; - int scf_iter = 1; // Start from 1 - bool need_subspace = false; - int diag_iter_max = 50; - double iter_diag_thr = 1.0e-2; // threshold for diagonalization + const std::string calculation_type; + const std::string basis_type; + const std::string method; + const bool use_paw; + const bool use_uspp; + const int nspin; + + const int scf_iter; // Start from 1 + const int diag_iter_max; // max iter times for diagonalization + const double diag_thr; // threshold for diagonalization - std::string method = "none"; + const bool need_subspace; // for cg or dav_subspace + const bool initialed_psi; private: Device* ctx = {}; - std::string calculation_type = "scf"; - std::string basis_type = "pw"; - - bool use_paw = false; - bool use_uspp = false; - int rank_in_pool = 0; int nproc_in_pool = 1; - int nspin = 1; - #ifdef USE_PAW void paw_func_in_kloop(const int ik); @@ -96,10 +101,8 @@ class HSolverPW void paw_func_after_kloop(psi::Psi& psi, elecstate::ElecState* pes); #endif - }; - } // namespace hsolver #endif \ No newline at end of file diff --git a/source/module_hsolver/hsolver_pw_sdft.cpp b/source/module_hsolver/hsolver_pw_sdft.cpp index 83d83fc0f5..0f15879bb1 100644 --- a/source/module_hsolver/hsolver_pw_sdft.cpp +++ b/source/module_hsolver/hsolver_pw_sdft.cpp @@ -17,13 +17,6 @@ void HSolverPW_SDFT::solve(hamilt::Hamilt>* pHamilt, Stochastic_WF& stowf, const int istep, const int iter, - const std::string method_in, - - const int scf_iter_in, - const bool need_subspace_in, - const int diag_iter_max_in, - const double iter_diag_thr_in, - const bool skip_charge) { ModuleBase::TITLE("HSolverPW_SDFT", "solve"); @@ -33,16 +26,9 @@ void HSolverPW_SDFT::solve(hamilt::Hamilt>* pHamilt, const int nbands = psi.get_nbands(); const int nks = psi.get_nk(); - this->scf_iter = scf_iter_in; - this->need_subspace = need_subspace_in; - this->diag_iter_max = diag_iter_max_in; - this->iter_diag_thr = iter_diag_thr_in; - // prepare for the precondition of diagonalization std::vector precondition(psi.get_nbasis(), 0.0); - // select the method of diagonalization - this->method = method_in; // report if the specified diagonalization method is not supported const std::initializer_list _methods = {"cg", "dav", "dav_subspace", "bpcg"}; if (std::find(std::begin(_methods), std::end(_methods), this->method) == std::end(_methods)) @@ -78,13 +64,7 @@ void HSolverPW_SDFT::solve(hamilt::Hamilt>* pHamilt, } this->output_iterInfo(); - - // psi only should be initialed once for PW - if (!this->initialed_psi) - { - this->initialed_psi = true; - } - + for (int ik = 0; ik < nks; ik++) { // init k @@ -114,7 +94,7 @@ void HSolverPW_SDFT::solve(hamilt::Hamilt>* pHamilt, } else { - for (int is = 0; is < GlobalV::NSPIN; is++) + for (int is = 0; is < this->nspin; is++) { ModuleBase::GlobalFunc::ZEROS(pes->charge->rho[is], pes->charge->nrxx); } diff --git a/source/module_hsolver/hsolver_pw_sdft.h b/source/module_hsolver/hsolver_pw_sdft.h index 14a45627d1..aa342b03e1 100644 --- a/source/module_hsolver/hsolver_pw_sdft.h +++ b/source/module_hsolver/hsolver_pw_sdft.h @@ -12,8 +12,34 @@ class HSolverPW_SDFT : public HSolverPW> wavefunc* pwf_in, Stochastic_WF& stowf, StoChe& stoche, + + const std::string calculation_type_in, + const std::string basis_type_in, + const std::string method_in, + const bool use_paw_in, + const bool use_uspp_in, + const int nspin_in, + + const int scf_iter_in, + const int diag_iter_max_in, + const double diag_thr_in, + + const bool need_subspace_in, const bool initialed_psi_in) - : HSolverPW(wfc_basis_in, pwf_in, initialed_psi_in) + + : HSolverPW(wfc_basis_in, + pwf_in, + calculation_type_in, + basis_type_in, + method_in, + use_paw_in, + use_uspp_in, + nspin_in, + scf_iter_in, + diag_iter_max_in, + diag_thr_in, + need_subspace_in, + initialed_psi_in) { stoiter.init(pkv, wfc_basis_in, stowf, stoche); } @@ -25,11 +51,6 @@ class HSolverPW_SDFT : public HSolverPW> Stochastic_WF& stowf, const int istep, const int iter, - const std::string method_in, - const int scf_iter_in, - const bool need_subspace_in, - const int diag_iter_max_in, - const double pw_diag_thr_in, const bool skip_charge); Stochastic_Iter stoiter; diff --git a/source/module_hsolver/test/test_hsolver_pw.cpp b/source/module_hsolver/test/test_hsolver_pw.cpp index 0a6698e570..62d49a35d1 100644 --- a/source/module_hsolver/test/test_hsolver_pw.cpp +++ b/source/module_hsolver/test/test_hsolver_pw.cpp @@ -38,11 +38,33 @@ class TestHSolverPW : public ::testing::Test { hsolver::HSolverPW, base_device::DEVICE_CPU> hs_f = hsolver::HSolverPW, base_device::DEVICE_CPU>(&pwbk, nullptr, - false); + + "scf", + "pw", + "cg", + false, + GlobalV::use_uspp, + GlobalV::NSPIN, + hsolver::DiagoIterAssist, base_device::DEVICE_CPU>::SCF_ITER, + hsolver::DiagoIterAssist, base_device::DEVICE_CPU>::PW_DIAG_NMAX, + hsolver::DiagoIterAssist, base_device::DEVICE_CPU>::PW_DIAG_THR, + hsolver::DiagoIterAssist, base_device::DEVICE_CPU>::need_subspace, + false); hsolver::HSolverPW, base_device::DEVICE_CPU> hs_d = hsolver::HSolverPW, base_device::DEVICE_CPU>(&pwbk, nullptr, - false); + + "scf", + "pw", + "cg", + false, + GlobalV::use_uspp, + GlobalV::NSPIN, + hsolver::DiagoIterAssist, base_device::DEVICE_CPU>::SCF_ITER, + hsolver::DiagoIterAssist, base_device::DEVICE_CPU>::PW_DIAG_NMAX, + hsolver::DiagoIterAssist, base_device::DEVICE_CPU>::PW_DIAG_THR, + hsolver::DiagoIterAssist, base_device::DEVICE_CPU>::need_subspace, + false); hamilt::Hamilt> hamilt_test_d; hamilt::Hamilt> hamilt_test_f; @@ -78,19 +100,10 @@ TEST_F(TestHSolverPW, solve) { &elecstate_test, elecstate_test.ekb.c, is_occupied, - method_test, - "scf", - "pw", - false, - GlobalV::use_uspp, + GlobalV::RANK_IN_POOL, GlobalV::NPROC_IN_POOL, - hsolver::DiagoIterAssist, base_device::DEVICE_CPU>::SCF_ITER, - hsolver::DiagoIterAssist, base_device::DEVICE_CPU>::need_subspace, - hsolver::DiagoIterAssist, base_device::DEVICE_CPU>::PW_DIAG_NMAX, - hsolver::DiagoIterAssist, base_device::DEVICE_CPU>::PW_DIAG_THR, - true); // EXPECT_EQ(this->hs_f.initialed_psi, true); for (int i = 0; i < psi_test_cf.size(); i++) { @@ -106,19 +119,10 @@ TEST_F(TestHSolverPW, solve) { &elecstate_test, elecstate_test.ekb.c, is_occupied, - method_test, - "scf", - "pw", - false, - GlobalV::use_uspp, + GlobalV::RANK_IN_POOL, GlobalV::NPROC_IN_POOL, - hsolver::DiagoIterAssist, base_device::DEVICE_CPU>::SCF_ITER, - hsolver::DiagoIterAssist, base_device::DEVICE_CPU>::need_subspace, - hsolver::DiagoIterAssist, base_device::DEVICE_CPU>::PW_DIAG_NMAX, - hsolver::DiagoIterAssist, base_device::DEVICE_CPU>::PW_DIAG_THR, - true); // EXPECT_EQ(this->hs_d.initialed_psi, true); @@ -130,17 +134,6 @@ TEST_F(TestHSolverPW, solve) { EXPECT_DOUBLE_EQ(elecstate_test.ekb.c[0], 4.0); EXPECT_DOUBLE_EQ(elecstate_test.ekb.c[1], 7.0); - // check initDiagh() - this->hs_f.method = "dav"; - this->hs_d.method = "dav"; - this->hs_f.initialed_psi = false; - this->hs_d.initialed_psi = false; - // this->hs_f.initDiagh(psi_test_cf); - // this->hs_d.initDiagh(psi_test_cd); - // will not change state of initialed_psi in initDiagh - EXPECT_EQ(this->hs_f.initialed_psi, false); - EXPECT_EQ(this->hs_d.initialed_psi, false); - // // check hamiltSolvePsiK() // this->hs_f.hamiltSolvePsiK(&hamilt_test_f, psi_test_cf, this->hs_f.precondition, ekb_f.data()); // this->hs_d.hamiltSolvePsiK(&hamilt_test_d, diff --git a/source/module_hsolver/test/test_hsolver_sdft.cpp b/source/module_hsolver/test/test_hsolver_sdft.cpp index b32718ca04..dbeeac396e 100644 --- a/source/module_hsolver/test/test_hsolver_sdft.cpp +++ b/source/module_hsolver/test/test_hsolver_sdft.cpp @@ -133,7 +133,23 @@ class TestHSolverPW_SDFT : public ::testing::Test K_Vectors kv; wavefunc wf; StoChe stoche; - hsolver::HSolverPW_SDFT hs_d = hsolver::HSolverPW_SDFT(&kv, &pwbk, &wf, stowf, stoche, false); + hsolver::HSolverPW_SDFT hs_d = hsolver::HSolverPW_SDFT(&kv, + &pwbk, + &wf, + stowf, + stoche, + + "scf", + "pw", + "cg", + false, + GlobalV::use_uspp, + GlobalV::NSPIN, + hsolver::DiagoIterAssist>::SCF_ITER, + hsolver::DiagoIterAssist>::PW_DIAG_NMAX, + hsolver::DiagoIterAssist>::PW_DIAG_THR, + hsolver::DiagoIterAssist>::need_subspace, + false); hamilt::Hamilt> hamilt_test_d; @@ -162,9 +178,6 @@ TEST_F(TestHSolverPW_SDFT, solve) int istep = 0; int iter = 0; - //check solve() - EXPECT_EQ(this->hs_d.initialed_psi, false); - this->hs_d.solve(&hamilt_test_d, psi_test_cd, &elecstate_test, @@ -172,14 +185,7 @@ TEST_F(TestHSolverPW_SDFT, solve) stowf, istep, iter, - method_test, - hsolver::DiagoIterAssist>::SCF_ITER, - hsolver::DiagoIterAssist>::need_subspace, - hsolver::DiagoIterAssist>::PW_DIAG_NMAX, - hsolver::DiagoIterAssist>::PW_DIAG_THR, - false - ); - EXPECT_EQ(this->hs_d.initialed_psi, true); + false); EXPECT_DOUBLE_EQ(hsolver::DiagoIterAssist>::avg_iter, 0.0); EXPECT_DOUBLE_EQ(elecstate_test.ekb.c[0], 4.0); EXPECT_DOUBLE_EQ(elecstate_test.ekb.c[1], 7.0); @@ -221,9 +227,6 @@ TEST_F(TestHSolverPW_SDFT, solve_noband_skipcharge) elecstate_test.charge->nrxx = 10; int istep = 0; int iter = 0; - - //check solve() - hs_d.initialed_psi = true; this->hs_d.solve(&hamilt_test_d, psi_test_no, @@ -232,11 +235,6 @@ TEST_F(TestHSolverPW_SDFT, solve_noband_skipcharge) stowf, istep, iter, - method_test, - hsolver::DiagoIterAssist>::SCF_ITER, - hsolver::DiagoIterAssist>::need_subspace, - hsolver::DiagoIterAssist>::PW_DIAG_NMAX, - hsolver::DiagoIterAssist>::PW_DIAG_THR, false ); EXPECT_DOUBLE_EQ(hsolver::DiagoIterAssist>::avg_iter, 0.0); @@ -259,11 +257,6 @@ TEST_F(TestHSolverPW_SDFT, solve_noband_skipcharge) stowf, istep, iter, - method_test, - hsolver::DiagoIterAssist>::SCF_ITER, - hsolver::DiagoIterAssist>::need_subspace, - hsolver::DiagoIterAssist>::PW_DIAG_NMAX, - hsolver::DiagoIterAssist>::PW_DIAG_THR, true ); EXPECT_EQ(stowf.nbands_diag, 4);