Skip to content

Commit

Permalink
diag-precondition
Browse files Browse the repository at this point in the history
  • Loading branch information
maki49 committed Nov 12, 2024
1 parent 2357507 commit 2a35de7
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 10 deletions.
14 changes: 12 additions & 2 deletions source/module_lr/esolver_lrtd_lcao.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "module_base/scalapack_connector.h"
#include "module_parameter/parameter.h"
#include "module_lr/ri_benchmark/ri_benchmark.h"
#include "module_lr/operator_casida/operator_lr_diag.h" // for precondition

#ifdef __EXX
template<>
Expand Down Expand Up @@ -444,20 +445,29 @@ void LR::ESolver_LR<T, TR>::runner(int istep, UnitCell& cell)
if (GlobalV::MY_RANK == 0) { assert(nst == LR_Util::write_value(efile(label), prec, e, nst)); }
assert(nst * dim == LR_Util::write_value(vfile(label), prec, v, nst, dim));
};
std::vector<double> precondition(this->input.lr_solver == "lapack" ? 0 : nloc_per_band, 1.0);
// allocate and initialize A matrix and density matrix
if (openshell)
{
for (int is : {0, 1})
{
const int offset_is = is * this->paraX_[0].get_local_size();
OperatorLRDiag<double> pre_op(this->eig_ks.c + is * nk * (nocc[0] + nvirt[0]), this->paraX_[is], this->nk, this->nocc[is], this->nvirt[is]);
if (input.lr_solver != "lapack") { pre_op.act(1, offset_is, 1, precondition.data() + offset_is, precondition.data() + offset_is); }
}
std::cout << "Solving spin-conserving excitation for open-shell system." << std::endl;
HamiltULR<T> hulr(xc_kernel, nspin, this->nbasis, this->nocc, this->nvirt, this->ucell, orb_cutoff_, GlobalC::GridD, *this->psi_ks, this->eig_ks,
#ifdef __EXX
this->exx_lri, this->exx_info.info_global.hybrid_alpha,
#endif
this->gint_, this->pot, this->kv, this->paraX_, this->paraC_, this->paraMat_);
LR::HSolver::solve(hulr, this->X[0].template data<T>(), nloc_per_band, nstates, this->pelec->ekb.c, this->input.lr_solver, this->input.lr_thr);
LR::HSolver::solve(hulr, this->X[0].template data<T>(), nloc_per_band, nstates, this->pelec->ekb.c, this->input.lr_solver, this->input.lr_thr, precondition);
if (input.out_wfc_lr) { write_states("openshell", this->pelec->ekb.c, this->X[0].template data<T>(), nloc_per_band, nstates); }
}
else
{
OperatorLRDiag<double> pre_op(this->eig_ks.c, this->paraX_[0], this->nk, this->nocc[0], this->nvirt[0]);
if (input.lr_solver != "lapack") { pre_op.act(1, nloc_per_band, 1, precondition.data(), precondition.data()); }
auto spin_types = std::vector<std::string>({ "singlet", "triplet" });
for (int is = 0;is < nspin;++is)
{
Expand All @@ -470,7 +480,7 @@ void LR::ESolver_LR<T, TR>::runner(int istep, UnitCell& cell)
spin_types[is], input.ri_hartree_benchmark, (input.ri_hartree_benchmark == "aims" ? input.aims_nbasis : std::vector<int>({})));
// solve the Casida equation
LR::HSolver::solve(hlr, this->X[is].template data<T>(), nloc_per_band, nstates,
this->pelec->ekb.c + is * nstates, this->input.lr_solver, this->input.lr_thr/*,
this->pelec->ekb.c + is * nstates, this->input.lr_solver, this->input.lr_thr, precondition/*,
!std::set<std::string>({ "hf", "hse" }).count(this->xc_kernel)*/); //whether the kernel is Hermitian
if (input.out_wfc_lr) { write_states(spin_types[is], this->pelec->ekb.c + is * nstates, this->X[is].template data<T>(), nloc_per_band, nstates); }
}
Expand Down
11 changes: 5 additions & 6 deletions source/module_lr/hsolver_lrtd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,15 @@ namespace LR
double* eig,
const std::string method,
const Real<T>& diag_ethr, ///< threshold for diagonalization
const std::vector<Real<T>>& precondition,
const bool hermitian = true)
{
ModuleBase::TITLE("HSolverLR", "solve");
const std::vector<std::string> spin_types = { "singlet", "triplet" };
// note: if not TDA, the eigenvalues will be complex
// then we will need a new constructor of DiagoDavid

// 1. allocate precondition and eigenvalue
std::vector<Real<T>> precondition(dim);
// 1. allocate eigenvalue
std::vector<Real<T>> eigenvalue(nband); //nstates
// 2. select the method
#ifdef __MPI
Expand All @@ -67,9 +67,7 @@ namespace LR
}
else
{
// 3. set precondition and diagethr
for (int i = 0; i < dim; ++i) { precondition[i] = static_cast<Real<T>>(1.0); }

// 3. set maxiter and funcs
const int maxiter = hsolver::DiagoIterAssist<T>::PW_DIAG_NMAX;

auto hpsi_func = [&hm](T* psi_in, T* hpsi, const int ld_psi, const int nvec) {hm.hPsi(psi_in, hpsi, ld_psi, nvec);};
Expand Down Expand Up @@ -139,7 +137,8 @@ namespace LR

auto psi_tensor = ct::TensorMap(psi, ct::DataTypeToEnum<T>::value, ct::DeviceType::CpuDevice, ct::TensorShape({ nband, dim }));
auto eigen_tensor = ct::TensorMap(eigenvalue.data(), ct::DataTypeToEnum<Real<T>>::value, ct::DeviceType::CpuDevice, ct::TensorShape({ nband }));
auto precon_tensor = ct::TensorMap(precondition.data(), ct::DataTypeToEnum<Real<T>>::value, ct::DeviceType::CpuDevice, ct::TensorShape({ dim }));
std::vector<Real<T>> precondition_(precondition); //since TensorMap does not support const pointer
auto precon_tensor = ct::TensorMap(precondition_.data(), ct::DataTypeToEnum<Real<T>>::value, ct::DeviceType::CpuDevice, ct::TensorShape({ dim }));
auto hpsi_func = [&hm](const ct::Tensor& psi_in, ct::Tensor& hpsi) {hm.hPsi(psi_in.data<T>(), hpsi.data<T>(), psi_in.shape().dim_size(0) /*nbasis_local*/, 1/*band-by-band*/);};
auto spsi_func = [&hm](const ct::Tensor& psi_in, ct::Tensor& spsi)
{ std::memcpy(spsi.data<T>(), psi_in.data<T>(), sizeof(T) * psi_in.NumElements()); };
Expand Down
3 changes: 1 addition & 2 deletions source/module_lr/operator_casida/operator_lr_diag.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,8 @@ namespace LR
const bool is_first_node = false)const override
{
ModuleBase::TITLE("OperatorLRDiag", "act");
const int nlocal_ph = nk * pX.get_local_size(); // local size of particle-hole basis
hsolver::vector_mul_vector_op<T, Device>()(this->ctx,
nk * pX.get_local_size(),
nk * pX.get_local_size(), // local size of particle-hole basis
hpsi,
psi_in,
this->eig_ks_diff.c);
Expand Down

0 comments on commit 2a35de7

Please sign in to comment.