Skip to content

Commit

Permalink
apply to dav
Browse files Browse the repository at this point in the history
  • Loading branch information
maki49 committed Nov 14, 2024
1 parent a8bb37f commit f146204
Show file tree
Hide file tree
Showing 11 changed files with 420 additions and 128 deletions.
5 changes: 3 additions & 2 deletions python/pyabacus/src/hsolver/py_diago_dav_subspace.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,10 @@ class PyDiagoDavSubspace
std::copy(hpsi_ptr, hpsi_ptr + nvec * ld_psi, hpsi_out);
};

hsolver::PreOP<std::complex<double>> pre_op(precond_vec, hsolver::transfunc::qe_pw<double>);
hsolver::PreOP<std::complex<double>, base_device::DEVICE_CPU, hsolver::fvec::DivTransMinusEigKernel<std::complex<double>, base_device::DEVICE_CPU>>
pre_op(precond_vec, hsolver::fvec::div_trans_prevec_minus_eigen<std::complex<double>>, hsolver::fval::qe_pw<double>);
obj = std::make_unique<hsolver::Diago_DavSubspace<std::complex<double>, base_device::DEVICE_CPU>>(
hsolver::bind_pre_op(pre_op),
pre_op.get(),
nband,
nbasis,
dav_ndim,
Expand Down
3 changes: 2 additions & 1 deletion python/pyabacus/src/hsolver/py_diago_david.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,9 @@ class PyDiagoDavid
syncmem_op()(this->ctx, this->ctx, spsi_out, psi_in, static_cast<size_t>(nbands * nrow));
};

hsolver::PreOP<std::complex<double>> pre_op(precond_vec);
obj = std::make_unique<hsolver::DiagoDavid<std::complex<double>, base_device::DEVICE_CPU>>(
precond_vec.data(),
pre_op.get(),
nband,
nbasis,
dav_ndim,
Expand Down
4 changes: 2 additions & 2 deletions source/module_hsolver/diago_dav_subspace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@
using namespace hsolver;

template <typename T, typename Device>
Diago_DavSubspace<T, Device>::Diago_DavSubspace(PreFunc<T>&& precondition_in,
Diago_DavSubspace<T, Device>::Diago_DavSubspace(PreFunc&& precondition_in,
const int& nband_in,
const int& nbasis_in,
const int& david_ndim_in,
const double& diag_thr_in,
const int& diag_nmax_in,
const bool& need_subspace_in,
const diag_comm_info& diag_comm_in)
: precondition(std::forward<PreFunc<T>>(precondition_in)), n_band(nband_in), dim(nbasis_in), nbase_x(nband_in* david_ndim_in),
: precondition(std::forward<PreFunc>(precondition_in)), n_band(nband_in), dim(nbasis_in), nbase_x(nband_in* david_ndim_in),
diag_thr(diag_thr_in), iter_nmax(diag_nmax_in), is_subspace(need_subspace_in), diag_comm(diag_comm_in)
{
this->device = base_device::get_device_type<Device>(this->ctx);
Expand Down
7 changes: 4 additions & 3 deletions source/module_hsolver/diago_dav_subspace.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@ class Diago_DavSubspace
// otherwise return the real type of T(complex<float>, complex<double>)
using Real = typename GetTypeReal<T>::type;

public:
Diago_DavSubspace(PreFunc<T>&& precondition_in, /// pass in a function, lambda or PreOP object
using PreFunc = fvec::DivTransMinusEig<T>;
public:
Diago_DavSubspace(PreFunc&& precondition_in, /// pass in a function, lambda or PreOP object
const int& nband_in,
const int& nbasis_in,
const int& david_ndim_in,
Expand Down Expand Up @@ -69,7 +70,7 @@ class Diago_DavSubspace
const int nbase_x = 0;

/// The precondition operation, can be a function, lambda or PreOP object
const PreFunc<T> precondition;
const PreFunc precondition;
// note that lambdas can only passed by value

/// record for how many bands not have convergence eigenvalues
Expand Down
54 changes: 6 additions & 48 deletions source/module_hsolver/diago_david.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,16 @@ using namespace hsolver;
* @note Auxiliary memory is allocated in the constructor and deallocated in the destructor.
*/
template <typename T, typename Device>
DiagoDavid<T, Device>::DiagoDavid(const Real* precondition_in,
DiagoDavid<T, Device>::DiagoDavid(PreFunc&& precondition_in,
const int nband_in,
const int dim_in,
const int david_ndim_in,
const bool use_paw_in,
const diag_comm_info& diag_comm_in)
: nband(nband_in), dim(dim_in), nbase_x(david_ndim_in * nband_in), david_ndim(david_ndim_in), use_paw(use_paw_in), diag_comm(diag_comm_in)
: nband(nband_in), dim(dim_in), nbase_x(david_ndim_in* nband_in), david_ndim(david_ndim_in), use_paw(use_paw_in), diag_comm(diag_comm_in),
precondition(std::forward<PreFunc>(precondition_in))
{
this->device = base_device::get_device_type<Device>(this->ctx);
this->precondition = precondition_in;

this->one = &one_;
this->zero = &zero_;
Expand Down Expand Up @@ -110,15 +110,6 @@ DiagoDavid<T, Device>::DiagoDavid(const Real* precondition_in,
// lagrange_matrix(nband, nband); // for orthogonalization
resmem_complex_op()(this->ctx, this->lagrange_matrix, nband * nband);
setmem_complex_op()(this->ctx, this->lagrange_matrix, 0, nband * nband);

#if defined(__CUDA) || defined(__ROCM)
// device precondition array
if (this->device == base_device::GpuDevice)
{
resmem_var_op()(this->ctx, this->d_precondition, dim);
syncmem_var_h2d_op()(this->ctx, this->cpu_ctx, this->d_precondition, this->precondition, dim);
}
#endif
}

/**
Expand All @@ -139,13 +130,6 @@ DiagoDavid<T, Device>::~DiagoDavid()
delmem_complex_op()(this->ctx, this->vcc);
delmem_complex_op()(this->ctx, this->lagrange_matrix);
base_device::memory::delete_memory_op<Real, base_device::DEVICE_CPU>()(this->cpu_ctx, this->eigenvalue);
// If the device is a GPU device, free the d_precondition array.
#if defined(__CUDA) || defined(__ROCM)
if (this->device == base_device::GpuDevice)
{
delmem_var_op()(this->ctx, this->d_precondition);
}
#endif
}

template <typename T, typename Device>
Expand Down Expand Up @@ -499,40 +483,14 @@ void DiagoDavid<T, Device>::cal_grad(const HPsiFunc& hpsi_func,
dim // LDC: if(N) max(1, m)
);
//<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<

// Preconditioning
// basis[nbase] = T * basis[nbase] = T * (H - lambda * S) * psi
// where T, the preconditioner, is an approximate inverse of H
// T is a diagonal stored in array `precondition`
// to do preconditioning, divide each column of basis by the corresponding element of precondition
for (int m = 0; m < notconv; m++)
{
//<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
if (this->device == base_device::GpuDevice)
{
#if defined(__CUDA) || defined(__ROCM)
vector_div_vector_op<T, Device>()(this->ctx,
dim,
basis + dim*(nbase + m),
basis + dim*(nbase + m),
this->d_precondition);
#endif
}
else
{
vector_div_vector_op<T, Device>()(this->ctx,
dim,
basis + dim*(nbase + m),
basis + dim*(nbase + m),
this->precondition);
}
//<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
// for (int ig = 0; ig < dim; ig++)
// {
// ppsi[ig] /= this->precondition[ig];
// }
//<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
}
//<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
this->precondition(basis + dim * nbase, dim, notconv);
//<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<

// there is a nbase to nbase + notconv band orthogonalise
// plan for SchmidtOrth
Expand Down
12 changes: 6 additions & 6 deletions source/module_hsolver/diago_david.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#include "module_base/module_device/memory_op.h"// base_device::memory

#include "module_hsolver/diag_comm_info.h"

#include "module_hsolver/precondition_funcs.h"
#include <vector>
#include <functional>

Expand All @@ -21,10 +21,11 @@ class DiagoDavid
// return T if T is real type(float, double),
// otherwise return the real type of T(complex<float>, complex<double>)
using Real = typename GetTypeReal<T>::type;

public:

DiagoDavid(const Real* precondition_in,
using PreFunc = fvec::Div<T>;
public:

DiagoDavid(PreFunc&& precondition_in,
const int nband_in,
const int dim_in,
const int david_ndim_in,
Expand Down Expand Up @@ -102,8 +103,7 @@ class DiagoDavid
int notconv = 0;

/// precondition for diag, diagonal approximation of matrix A(i.e. Hamilt)
const Real* precondition = nullptr;
Real* d_precondition = nullptr;
const PreFunc precondition;

/// eigenvalue results
Real* eigenvalue = nullptr;
Expand Down
9 changes: 6 additions & 3 deletions source/module_hsolver/hsolver_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -503,8 +503,9 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
};
bool scf = this->calculation_type == "nscf" ? false : true;

PreOP<T, Device> pre_op(pre_condition, transfunc::qe_pw<Real>);
Diago_DavSubspace<T, Device> dav_subspace(bind_pre_op(pre_op),
// const auto pre_op = make_pre_op(pre_condition, fvec::div_trans_prevec_minus_eigen<T, Device>, fval::qe_pw<Real>);
const PreOP<T, Device, fvec::DivTransMinusEigKernel<T, Device>> pre_op(pre_condition, fvec::div_trans_prevec_minus_eigen<T, Device>, fval::qe_pw<Real>);
Diago_DavSubspace<T, Device> dav_subspace(pre_op.get(),
psi.get_nbands(),
psi.get_k_first() ? psi.get_current_nbas()
: psi.get_nk() * psi.get_nbasis(),
Expand Down Expand Up @@ -573,7 +574,9 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
ModuleBase::timer::tick("David", "spsi_func");
};

DiagoDavid<T, Device> david(pre_condition.data(),
// const auto pre_op = make_pre_op(pre_condition, fvec::div_prevec<T, Device>);
const PreOP<T, Device> pre_op(pre_condition);
DiagoDavid<T, Device> david(pre_op.get(),
nband,
dim,
PARAM.inp.pw_diag_ndim,
Expand Down
Loading

0 comments on commit f146204

Please sign in to comment.