Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: Using gemm instead of einsum in bpcg #5318

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions source/module_hsolver/diago_bpcg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ void DiagoBPCG<T, Device>::init_iter(const psi::Psi<T, Device> &psi_in) {
// Specify the problem size n_basis, n_band, while lda is n_basis
this->n_band = psi_in.get_nbands();
this->n_basis = psi_in.get_nbasis();
this->n_dim = psi_in.get_current_nbas();

// All column major tensors

Expand Down Expand Up @@ -96,6 +97,21 @@ void DiagoBPCG<T, Device>::orth_cholesky(
ct::EinsumOption option(
/*conj_x=*/false, /*conj_y=*/true, /*alpha=*/1.0, /*beta=*/0.0, /*Tensor out=*/&hsub_out);
hsub_out = ct::op::einsum("ij,kj->ik", psi_out, psi_out, option);
// using gemm instead einsum for different leading dimension and nbasis
// gemm_op()(this->ctx,
// 'N',
// 'C',
// this->n_band,
// this->n_band,
// this->n_dim,
// this->one,
// psi_out.data<T>(),
// this->n_basis,
// psi_out.data<T>(),
// this->n_basis,
// this->zero,
// hsub_out.data<T>(),
// this->n_band);

// set hsub matrix to lower format;
ct::kernels::set_matrix<T, ct_Device>()(
Expand Down
11 changes: 10 additions & 1 deletion source/module_hsolver/diago_bpcg.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,15 +67,23 @@ class DiagoBPCG


private:
/// ctx is nothing but the devices used in gemm_op (Device * ctx = nullptr;),
Device * ctx = {};
/// the number of rows of the input psi
int n_band = 0;
/// the number of cols of the input psi
/// the number of cols of the input psi, leading dimension
int n_basis = 0;
/// the real-time column size of the input psi
int n_dim = 0;
/// max iter steps for all-band cg loop
int nline = 4;
/// cg convergence thr
Real all_band_cg_thr = 1E-5;

// Pointer to objects of 1 and 0 for gemm
const T *one = nullptr, *zero = nullptr, *neg_one = nullptr;
const T one_ = static_cast<T>(1.0), zero_ = static_cast<T>(0.0), neg_one_ = static_cast<T>(-1.0);

ct::DataType r_type = ct::DataType::DT_INVALID;
ct::DataType t_type = ct::DataType::DT_INVALID;
ct::DeviceType device_type = ct::DeviceType::UnKnown;
Expand Down Expand Up @@ -330,6 +338,7 @@ class DiagoBPCG

using calc_grad_with_block_op = hsolver::calc_grad_with_block_op<T, Device>;
using line_minimize_with_block_op = hsolver::line_minimize_with_block_op<T, Device>;
using gemm_op = hsolver::gemm_op<T, Device>;

};

Expand Down
Loading