Skip to content

Commit

Permalink
Refactor: move cal_edm_tddft to module_dm (#5485)
Browse files Browse the repository at this point in the history
* Refactor: move cal_edm_tddft to module_dm

* update head file

* add lapack_connector.h in cal_edm_tddft.cpp
  • Loading branch information
YuLiu98 authored Nov 14, 2024
1 parent 8e8cad2 commit a95eb5b
Show file tree
Hide file tree
Showing 7 changed files with 70 additions and 90 deletions.
2 changes: 1 addition & 1 deletion source/Makefile.Objects
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ OBJS_ELECSTAT_LCAO=elecstate_lcao.o\
density_matrix.o\
density_matrix_io.o\
cal_dm_psi.o\
cal_edm_tddft.o\

OBJS_ESOLVER=esolver.o\
esolver_ks.o\
Expand All @@ -259,7 +260,6 @@ OBJS_ESOLVER_LCAO=esolver_ks_lcao.o\
lcao_others.o\
lcao_init_after_vc.o\
lcao_fun.o\
cal_edm_tddft.o\

OBJS_GINT=gint.o\
gint_gamma_env.o\
Expand Down
1 change: 1 addition & 0 deletions source/module_elecstate/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ if(ENABLE_LCAO)
module_dm/density_matrix.cpp
module_dm/density_matrix_io.cpp
module_dm/cal_dm_psi.cpp
module_dm/cal_edm_tddft.cpp
)
endif()

Expand Down
Original file line number Diff line number Diff line change
@@ -1,81 +1,46 @@
#include "esolver_ks_lcao_tddft.h"

#include "module_io/cal_r_overlap_R.h"
#include "module_io/dipole_io.h"
#include "module_io/td_current_io.h"
#include "module_io/write_HS.h"
#include "module_io/write_HS_R.h"
#include "module_io/write_wfc_nao.h"

//--------------temporary----------------------------
#include "module_base/blas_connector.h"
#include "module_base/global_function.h"
#include "module_base/scalapack_connector.h"
#include "cal_edm_tddft.h"

#include "module_base/lapack_connector.h"
#include "module_elecstate/module_charge/symmetry_rho.h"
#include "module_elecstate/occupy.h"
#include "module_hamilt_lcao/hamilt_lcaodft/LCAO_domain.h" // need divide_HS_in_frag
#include "module_hamilt_lcao/module_tddft/evolve_elec.h"
#include "module_hamilt_lcao/module_tddft/td_velocity.h"
#include "module_hamilt_pw/hamilt_pwdft/global.h"
#include "module_io/print_info.h"

//-----HSolver ElecState Hamilt--------
#include "module_elecstate/elecstate_lcao.h"
#include "module_elecstate/elecstate_lcao_tddft.h"
#include "module_hamilt_lcao/hamilt_lcaodft/hamilt_lcao.h"
#include "module_hsolver/hsolver_lcao.h"
#include "module_parameter/parameter.h"
#include "module_psi/psi.h"

//-----force& stress-------------------
#include "module_hamilt_lcao/hamilt_lcaodft/FORCE_STRESS.h"

//---------------------------------------------------

namespace ModuleESolver
#include "module_base/scalapack_connector.h"
namespace elecstate
{

// use the original formula (Hamiltonian matrix) to calculate energy density
// matrix
void ESolver_KS_LCAO_TDDFT::cal_edm_tddft()
// use the original formula (Hamiltonian matrix) to calculate energy density matrix
void cal_edm_tddft(Parallel_Orbitals& pv,
elecstate::ElecState* pelec,
K_Vectors& kv,
hamilt::Hamilt<std::complex<double>>* p_hamilt)
{
// mohan add 2024-03-27
const int nlocal = PARAM.globalv.nlocal;
assert(nlocal >= 0);

dynamic_cast<elecstate::ElecStateLCAO<std::complex<double>>*>(this->pelec)
->get_DM()
->EDMK.resize(kv.get_nks());
for (int ik = 0; ik < kv.get_nks(); ++ik) {
auto _pelec = dynamic_cast<elecstate::ElecStateLCAO<std::complex<double>>*>(pelec);

p_hamilt->updateHk(ik);
_pelec->get_DM()->EDMK.resize(kv.get_nks());

std::complex<double>* tmp_dmk
= dynamic_cast<elecstate::ElecStateLCAO<std::complex<double>>*>(this->pelec)->get_DM()->get_DMK_pointer(ik);

ModuleBase::ComplexMatrix& tmp_edmk
= dynamic_cast<elecstate::ElecStateLCAO<std::complex<double>>*>(this->pelec)->get_DM()->EDMK[ik];

const Parallel_Orbitals* tmp_pv
= dynamic_cast<elecstate::ElecStateLCAO<std::complex<double>>*>(this->pelec)->get_DM()->get_paraV_pointer();
for (int ik = 0; ik < kv.get_nks(); ++ik)
{
p_hamilt->updateHk(ik);
std::complex<double>* tmp_dmk = _pelec->get_DM()->get_DMK_pointer(ik);
ModuleBase::ComplexMatrix& tmp_edmk = _pelec->get_DM()->EDMK[ik];

#ifdef __MPI

// mohan add 2024-03-27
//! be careful, the type of nloc is 'long'
//! whether the long type is safe, needs more discussion
const long nloc = this->pv.nloc;
const int ncol = this->pv.ncol;
const int nrow = this->pv.nrow;
const long nloc = pv.nloc;
const int ncol = pv.ncol;
const int nrow = pv.nrow;

tmp_edmk.create(ncol, nrow);
complex<double>* Htmp = new complex<double>[nloc];
complex<double>* Sinv = new complex<double>[nloc];
complex<double>* tmp1 = new complex<double>[nloc];
complex<double>* tmp2 = new complex<double>[nloc];
complex<double>* tmp3 = new complex<double>[nloc];
complex<double>* tmp4 = new complex<double>[nloc];
std::complex<double>* Htmp = new std::complex<double>[nloc];
std::complex<double>* Sinv = new std::complex<double>[nloc];
std::complex<double>* tmp1 = new std::complex<double>[nloc];
std::complex<double>* tmp2 = new std::complex<double>[nloc];
std::complex<double>* tmp3 = new std::complex<double>[nloc];
std::complex<double>* tmp4 = new std::complex<double>[nloc];

ModuleBase::GlobalFunc::ZEROS(Htmp, nloc);
ModuleBase::GlobalFunc::ZEROS(Sinv, nloc);
Expand All @@ -86,8 +51,8 @@ void ESolver_KS_LCAO_TDDFT::cal_edm_tddft()

const int inc = 1;

hamilt::MatrixBlock<complex<double>> h_mat;
hamilt::MatrixBlock<complex<double>> s_mat;
hamilt::MatrixBlock<std::complex<double>> h_mat;
hamilt::MatrixBlock<std::complex<double>> s_mat;

p_hamilt->matrix(h_mat, s_mat);
zcopy_(&nloc, h_mat.p, &inc, Htmp, &inc);
Expand All @@ -97,7 +62,7 @@ void ESolver_KS_LCAO_TDDFT::cal_edm_tddft()
int info = 0;
const int one_int = 1;

pzgetrf_(&nlocal, &nlocal, Sinv, &one_int, &one_int, this->pv.desc, ipiv.data(), &info);
pzgetrf_(&nlocal, &nlocal, Sinv, &one_int, &one_int, pv.desc, ipiv.data(), &info);

int lwork = -1;
int liwork = -1;
Expand All @@ -112,7 +77,7 @@ void ESolver_KS_LCAO_TDDFT::cal_edm_tddft()
Sinv,
&one_int,
&one_int,
this->pv.desc,
pv.desc,
ipiv.data(),
work.data(),
&lwork,
Expand All @@ -129,7 +94,7 @@ void ESolver_KS_LCAO_TDDFT::cal_edm_tddft()
Sinv,
&one_int,
&one_int,
this->pv.desc,
pv.desc,
ipiv.data(),
work.data(),
&lwork,
Expand All @@ -139,9 +104,9 @@ void ESolver_KS_LCAO_TDDFT::cal_edm_tddft()

const char N_char = 'N';
const char T_char = 'T';
const complex<double> one_float = {1.0, 0.0};
const complex<double> zero_float = {0.0, 0.0};
const complex<double> half_float = {0.5, 0.0};
const std::complex<double> one_float = {1.0, 0.0};
const std::complex<double> zero_float = {0.0, 0.0};
const std::complex<double> half_float = {0.5, 0.0};

pzgemm_(&N_char,
&N_char,
Expand All @@ -152,16 +117,16 @@ void ESolver_KS_LCAO_TDDFT::cal_edm_tddft()
Htmp,
&one_int,
&one_int,
this->pv.desc,
pv.desc,
Sinv,
&one_int,
&one_int,
this->pv.desc,
pv.desc,
&zero_float,
tmp1,
&one_int,
&one_int,
this->pv.desc);
pv.desc);

pzgemm_(&T_char,
&N_char,
Expand All @@ -172,16 +137,16 @@ void ESolver_KS_LCAO_TDDFT::cal_edm_tddft()
tmp1,
&one_int,
&one_int,
this->pv.desc,
pv.desc,
tmp_dmk,
&one_int,
&one_int,
this->pv.desc,
pv.desc,
&zero_float,
tmp2,
&one_int,
&one_int,
this->pv.desc);
pv.desc);

pzgemm_(&N_char,
&N_char,
Expand All @@ -192,16 +157,16 @@ void ESolver_KS_LCAO_TDDFT::cal_edm_tddft()
Sinv,
&one_int,
&one_int,
this->pv.desc,
pv.desc,
Htmp,
&one_int,
&one_int,
this->pv.desc,
pv.desc,
&zero_float,
tmp3,
&one_int,
&one_int,
this->pv.desc);
pv.desc);

pzgemm_(&N_char,
&T_char,
Expand All @@ -212,16 +177,16 @@ void ESolver_KS_LCAO_TDDFT::cal_edm_tddft()
tmp_dmk,
&one_int,
&one_int,
this->pv.desc,
pv.desc,
tmp3,
&one_int,
&one_int,
this->pv.desc,
pv.desc,
&zero_float,
tmp4,
&one_int,
&one_int,
this->pv.desc);
pv.desc);

pzgeadd_(&N_char,
&nlocal,
Expand All @@ -230,12 +195,12 @@ void ESolver_KS_LCAO_TDDFT::cal_edm_tddft()
tmp2,
&one_int,
&one_int,
this->pv.desc,
pv.desc,
&half_float,
tmp4,
&one_int,
&one_int,
this->pv.desc);
pv.desc);

zcopy_(&nloc, tmp4, &inc, tmp_edmk.c, &inc);

Expand All @@ -247,12 +212,12 @@ void ESolver_KS_LCAO_TDDFT::cal_edm_tddft()
delete[] tmp4;
#else
// for serial version
tmp_edmk.create(this->pv.ncol, this->pv.nrow);
tmp_edmk.create(pv.ncol, pv.nrow);
ModuleBase::ComplexMatrix Sinv(nlocal, nlocal);
ModuleBase::ComplexMatrix Htmp(nlocal, nlocal);

hamilt::MatrixBlock<complex<double>> h_mat;
hamilt::MatrixBlock<complex<double>> s_mat;
hamilt::MatrixBlock<std::complex<double>> h_mat;
hamilt::MatrixBlock<std::complex<double>> s_mat;

p_hamilt->matrix(h_mat, s_mat);
// cout<<"hmat "<<h_mat.p[0]<<endl;
Expand Down
16 changes: 16 additions & 0 deletions source/module_elecstate/module_dm/cal_edm_tddft.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#ifndef CAL_EDM_TDDFT_H
#define CAL_EDM_TDDFT_H

#include "module_basis/module_ao/parallel_orbitals.h"
#include "module_cell/klist.h"
#include "module_elecstate/elecstate_lcao.h"
#include "module_hamilt_general/hamilt.h"

namespace elecstate
{
void cal_edm_tddft(Parallel_Orbitals& pv,
elecstate::ElecState* pelec,
K_Vectors& kv,
hamilt::Hamilt<std::complex<double>>* p_hamilt);
} // namespace elecstate
#endif // CAL_EDM_TDDFT_H
1 change: 0 additions & 1 deletion source/module_esolver/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ if(ENABLE_LCAO)
lcao_others.cpp
lcao_init_after_vc.cpp
lcao_fun.cpp
cal_edm_tddft.cpp
)
endif()

Expand Down
3 changes: 2 additions & 1 deletion source/module_esolver/esolver_ks_lcao_tddft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "module_base/lapack_connector.h"
#include "module_base/scalapack_connector.h"
#include "module_elecstate/module_charge/symmetry_rho.h"
#include "module_elecstate/module_dm/cal_edm_tddft.h"
#include "module_elecstate/occupy.h"
#include "module_hamilt_lcao/hamilt_lcaodft/LCAO_domain.h" // need divide_HS_in_frag
#include "module_hamilt_lcao/module_tddft/evolve_elec.h"
Expand Down Expand Up @@ -358,7 +359,7 @@ void ESolver_KS_LCAO_TDDFT::update_pot(const int istep, const int iter)
// calculate energy density matrix for tddft
if (istep >= (wf.init_wfc == "file" ? 0 : 2) && module_tddft::Evolve_elec::td_edm == 0)
{
this->cal_edm_tddft();
elecstate::cal_edm_tddft(this->pv, this->pelec, this->kv, this->p_hamilt);
}
}

Expand Down
2 changes: 0 additions & 2 deletions source/module_esolver/esolver_ks_lcao_tddft.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@ class ESolver_KS_LCAO_TDDFT : public ESolver_KS_LCAO<std::complex<double>, doubl
virtual void iter_finish(const int istep, int& iter) override;

virtual void after_scf(const int istep) override;

void cal_edm_tddft();
};

} // namespace ModuleESolver
Expand Down

0 comments on commit a95eb5b

Please sign in to comment.