Skip to content

Commit

Permalink
Feature: Porting abacus to DSP hardware (mtblas part) (#5301)
Browse files Browse the repository at this point in the history
* Link mtblas library

* Add mtblas gemm kernel usage

* Finish memory_op on dsp

* Update CMakeLists

* Add compilation script

* Fix warnings

* Fix install script

* Initialize DSP hardware

* Replace gemm in math_kernel

* Fix CMakeLists Bug

* Fix bugs #1

* Fix bug 2

* Fix link to shared library error

* Stop using gemm_mt globally

* Modify op usage

* Fix bug

* Fix template usage

* Fix compilation

* Replace all dav_subspace gemm kernels

---------

Co-authored-by: Mohan Chen <mohanchen@pku.edu.cn>
  • Loading branch information
Critsium-xy and mohanchen authored Oct 23, 2024
1 parent 2af2095 commit f039250
Show file tree
Hide file tree
Showing 10 changed files with 251 additions and 11 deletions.
12 changes: 12 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ option(ENABLE_RAPIDJSON "Enable rapid-json usage." OFF)
option(ENABLE_CNPY "Enable cnpy usage." OFF)
option(ENABLE_PEXSI "Enable support for PEXSI." OFF)
option(ENABLE_CUSOLVERMP "Enable cusolvermp." OFF)
option(USE_DSP "Enable DSP usage." OFF)

# enable json support
if(ENABLE_RAPIDJSON)
Expand Down Expand Up @@ -119,6 +120,12 @@ elseif(ENABLE_LCAO AND NOT ENABLE_MPI)
set(ABACUS_BIN_NAME abacus_serial)
endif()

# When USE_DSP is enabled, force USE_ELPA and ENABLE_LCAO to OFF for this
# configure run and name the resulting executable "abacus_dsp".
if (USE_DSP)
set(USE_ELPA OFF)
set(ENABLE_LCAO OFF)
set(ABACUS_BIN_NAME abacus_dsp)
endif()

list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake)

if(ENABLE_COVERAGE)
Expand Down Expand Up @@ -240,6 +247,11 @@ if(ENABLE_MPI)
list(APPEND math_libs MPI::MPI_CXX)
endif()

# DSP builds link the executable against the mtblas library whose path is
# supplied through DIR_MTBLAS_LIBRARY, and compile with __DSP defined so the
# DSP-specific code paths are enabled.
if (USE_DSP)
  # Fail at configure time with a clear message instead of at link time with
  # undefined mtblas symbols when the library path was not provided.
  if (NOT DIR_MTBLAS_LIBRARY)
    message(FATAL_ERROR
      "USE_DSP=ON requires DIR_MTBLAS_LIBRARY to be set to the mtblas library, "
      "e.g. -DDIR_MTBLAS_LIBRARY=/path/to/libmtblas_abacus.so")
  endif()
  target_link_libraries(${ABACUS_BIN_NAME} ${DIR_MTBLAS_LIBRARY})
  add_compile_definitions(__DSP)
endif()

find_package(Threads REQUIRED)
target_link_libraries(${ABACUS_BIN_NAME} Threads::Threads)

Expand Down
10 changes: 10 additions & 0 deletions install_dsp.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#!/bin/sh
# Configure an ABACUS build with DSP (mtblas) support into ./build.
#
# This script only runs the CMake configure step; it does not build or install.
#
# Library locations default to the paths this script was originally written
# for. Override them from the environment when building on another machine:
#   ABACUS_FFTW_DIR    FFTW installation prefix (must contain lib/libfftw3*.so)
#   ABACUS_LAPACK_DIR  directory passed to CMake as LAPACK_DIR
#   ABACUS_MTBLAS_LIB  full path of the mtblas shared library
set -eu

FFTW_DIR="${ABACUS_FFTW_DIR:-/vol8/appsoftware/fftw}"
LAPACK_DIR="${ABACUS_LAPACK_DIR:-/vol8/appsoftware/openblas/0.3.21/lib}"
MTBLAS_LIB="${ABACUS_MTBLAS_LIB:-/vol8/home/dptech_zyz1/develop/packages/libmtblas_abacus.so}"

CXX=mpicxx \
cmake -B build \
    -DUSE_DSP=ON \
    -DENABLE_LCAO=OFF \
    -DFFTW3_DIR="${FFTW_DIR}/" \
    -DFFTW3_LIBRARY="${FFTW_DIR}/lib/libfftw3.so" \
    -DFFTW3_OMP_LIBRARY="${FFTW_DIR}/lib/libfftw3_omp.so" \
    -DFFTW3_FLOAT_LIBRARY="${FFTW_DIR}/lib/libfftw3f.so" \
    -DLAPACK_DIR="${LAPACK_DIR}" \
    -DDIR_MTBLAS_LIBRARY="${MTBLAS_LIB}"
45 changes: 41 additions & 4 deletions source/module_base/blas_connector.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
#include "blas_connector.h"

#ifdef __DSP
#include "module_base/kernels/dsp/dsp_connector.h"
#endif

void BlasConnector::axpy( const int n, const float alpha, const float *X, const int incX, float *Y, const int incY, base_device::AbacusDevice_t device_type)
{
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
Expand Down Expand Up @@ -64,13 +68,15 @@ float BlasConnector::dot( const int n, const float *X, const int incX, const flo
{
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
return sdot_(&n, X, &incX, Y, &incY);
return sdot_(&n, X, &incX, Y, &incY);
}
}

/// Dot product of two double-precision vectors.
///
/// @param n           number of elements taken from each vector
/// @param X           first vector
/// @param incX        stride between consecutive elements of X
/// @param Y           second vector
/// @param incY        stride between consecutive elements of Y
/// @param device_type backend selector; only CpuDevice is handled here
/// @return the value returned by ddot_ for CpuDevice; 0.0 for any other device
double BlasConnector::dot( const int n, const double *X, const int incX, const double *Y, const int incY, base_device::AbacusDevice_t device_type)
{
    if (device_type == base_device::AbacusDevice_t::CpuDevice) {
        return ddot_(&n, X, &incX, Y, &incY);
    }
    // No backend is implemented for other devices. Return a defined value:
    // flowing off the end of a non-void function is undefined behaviour.
    return 0.0;
}

Expand All @@ -83,7 +89,14 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
sgemm_(&transb, &transa, &n, &m, &k,
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc);
}
}
#ifdef __DSP
else if (device_type == base_device::AbacusDevice_t::DspDevice){
sgemm_mt_(&transb, &transa, &n, &m, &k,
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc);
}
#endif
}

void BlasConnector::gemm(const char transa, const char transb, const int m, const int n, const int k,
Expand All @@ -94,7 +107,14 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
dgemm_(&transb, &transa, &n, &m, &k,
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc);
}
}
#ifdef __DSP
else if (device_type == base_device::AbacusDevice_t::DspDevice){
dgemm_mt_(&transb, &transa, &n, &m, &k,
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc);
}
#endif
}

void BlasConnector::gemm(const char transa, const char transb, const int m, const int n, const int k,
Expand All @@ -105,7 +125,14 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
cgemm_(&transb, &transa, &n, &m, &k,
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc);
}
}
#ifdef __DSP
else if (device_type == base_device::AbacusDevice_t::DspDevice) {
cgemm_mt_(&transb, &transa, &n, &m, &k,
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc);
}
#endif
}

void BlasConnector::gemm(const char transa, const char transb, const int m, const int n, const int k,
Expand All @@ -116,7 +143,14 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
zgemm_(&transb, &transa, &n, &m, &k,
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc);
}
}
#ifdef __DSP
else if (device_type == base_device::AbacusDevice_t::DspDevice) {
zgemm_mt_(&transb, &transa, &n, &m, &k,
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc);
}
#endif
}

void BlasConnector::gemv(const char trans, const int m, const int n,
Expand Down Expand Up @@ -152,6 +186,7 @@ float BlasConnector::nrm2( const int n, const float *X, const int incX, base_dev
{
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
return snrm2_( &n, X, &incX );
return snrm2_( &n, X, &incX );
}
}

Expand All @@ -160,6 +195,7 @@ double BlasConnector::nrm2( const int n, const double *X, const int incX, base_d
{
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
return dnrm2_( &n, X, &incX );
return dnrm2_( &n, X, &incX );
}
}

Expand All @@ -168,6 +204,7 @@ double BlasConnector::nrm2( const int n, const std::complex<double> *X, const in
{
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
return dznrm2_( &n, X, &incX );
return dznrm2_( &n, X, &incX );
}
}

Expand Down
66 changes: 66 additions & 0 deletions source/module_base/kernels/dsp/dsp_connector.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
#ifndef DSP_CONNECTOR_H
#define DSP_CONNECTOR_H
#ifdef __DSP

// Declarations for the DSP runtime and the mtblas GEMM kernels. Everything in
// this header is only visible when the build defines __DSP.

// Included here so the header compiles on its own: the declarations below use
// std::size_t and std::complex.
#include <complex>
#include <cstddef>

// Base dsp functions

// Set up the DSP handle.
// NOTE(review): the meaning of `id` is not visible in this header; presumably
// it selects a DSP device/cluster -- confirm against the mtblas library.
void dspInitHandle(int id);
// Tear down the DSP handle. (The spelling "Destory" is the name of the
// external symbol and must not be changed here.)
void dspDestoryHandle();
// Allocate `bytes` bytes; the returned pointer is released with free_ht().
void *malloc_ht(std::size_t bytes);
// Release a pointer obtained from malloc_ht().
void free_ht(void* ptr);


// mtblas functions
//
// The gemm routines below take the same argument list as the Fortran-style
// BLAS ?gemm_ routines: every argument is passed by pointer, in the order
// (transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc).
// Prefixes: s = float, d = double, c = std::complex<float>,
// z = std::complex<double>.

void sgemm_mt_(const char *transa, const char *transb,
               const int *m, const int *n, const int *k,
               const float *alpha, const float *a, const int *lda,
               const float *b, const int *ldb, const float *beta,
               float *c, const int *ldc);

void dgemm_mt_(const char *transa, const char *transb,
               const int *m, const int *n, const int *k,
               const double *alpha,const double *a, const int *lda,
               const double *b, const int *ldb, const double *beta,
               double *c, const int *ldc);

void zgemm_mt_(const char *transa, const char *transb,
               const int *m, const int *n, const int *k,
               const std::complex<double> *alpha, const std::complex<double> *a, const int *lda,
               const std::complex<double> *b, const int *ldb, const std::complex<double> *beta,
               std::complex<double> *c, const int *ldc);

void cgemm_mt_(const char *transa, const char *transb,
               const int *m, const int *n, const int *k,
               const std::complex<float> *alpha, const std::complex<float> *a, const int *lda,
               const std::complex<float> *b, const int *ldb, const std::complex<float> *beta,
               std::complex<float> *c, const int *ldc);


// "_mth_" variants with the same argument lists as the "_mt_" routines above.
// NOTE(review): how they differ from the "_mt_" routines is not visible in
// this header -- confirm against the mtblas documentation.

void sgemm_mth_(const char *transa, const char *transb,
                const int *m, const int *n, const int *k,
                const float *alpha, const float *a, const int *lda,
                const float *b, const int *ldb, const float *beta,
                float *c, const int *ldc);

void dgemm_mth_(const char *transa, const char *transb,
                const int *m, const int *n, const int *k,
                const double *alpha,const double *a, const int *lda,
                const double *b, const int *ldb, const double *beta,
                double *c, const int *ldc);

void zgemm_mth_(const char *transa, const char *transb,
                const int *m, const int *n, const int *k,
                const std::complex<double> *alpha, const std::complex<double> *a, const int *lda,
                const std::complex<double> *b, const int *ldb, const std::complex<double> *beta,
                std::complex<double> *c, const int *ldc);

void cgemm_mth_(const char *transa, const char *transb,
                const int *m, const int *n, const int *k,
                const std::complex<float> *alpha, const std::complex<float> *a, const int *lda,
                const std::complex<float> *b, const int *ldb, const std::complex<float> *beta,
                std::complex<float> *c, const int *ldc);

// Intentionally left disabled: a global macro redirect of zgemm_.
//#define zgemm_ zgemm_mt

#endif
#endif
15 changes: 15 additions & 0 deletions source/module_base/module_device/memory_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

#include "module_base/memory.h"
#include "module_base/tool_threading.h"
#ifdef __DSP
#include "module_base/kernels/dsp/dsp_connector.h"
#endif

#include <complex>
#include <cstring>
Expand All @@ -18,9 +21,17 @@ struct resize_memory_op<FPTYPE, base_device::DEVICE_CPU>
{
if (arr != nullptr)
{
#ifdef __DSP
free_ht(arr);
#else
free(arr);
#endif
}
#ifdef __DSP
arr = (FPTYPE*)malloc_ht(sizeof(FPTYPE) * size);
#else
arr = (FPTYPE*)malloc(sizeof(FPTYPE) * size);
#endif
std::string record_string;
if (record_in != nullptr)
{
Expand Down Expand Up @@ -92,7 +103,11 @@ struct delete_memory_op<FPTYPE, base_device::DEVICE_CPU>
{
/// Release the buffer pointed to by `arr`.
///
/// @param dev unused by this implementation
/// @param arr buffer to release; passed to free_ht() when the build defines
///            __DSP, otherwise to free()
void operator()(const base_device::DEVICE_CPU* dev, FPTYPE* arr)
{
#ifndef __DSP
    free(arr);
#else
    free_ht(arr);
#endif
}
};

Expand Down
1 change: 1 addition & 0 deletions source/module_base/module_device/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ enum AbacusDevice_t
UnKnown,
CpuDevice,
GpuDevice,
DspDevice
};

} // namespace base_device
Expand Down
13 changes: 12 additions & 1 deletion source/module_esolver/esolver_ks_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@
#include <ATen/kernels/blas.h>
#include <ATen/kernels/lapack.h>

#ifdef __DSP
#include "module_base/kernels/dsp/dsp_connector.h"
#endif

namespace ModuleESolver
{

Expand All @@ -67,6 +71,10 @@ ESolver_KS_PW<T, Device>::ESolver_KS_PW()
container::kernels::createGpuSolverHandle();
}
#endif
#ifdef __DSP
std::cout << " ** Initializing DSP Hardware..." << std::endl;
dspInitHandle(GlobalV::MY_RANK % 4);
#endif
}

template <typename T, typename Device>
Expand All @@ -92,7 +100,10 @@ ESolver_KS_PW<T, Device>::~ESolver_KS_PW()
#endif
delete reinterpret_cast<psi::Psi<T, Device>*>(this->kspw_psi);
}

#ifdef __DSP
std::cout << " ** Closing DSP Hardware..." << std::endl;
dspDestoryHandle();
#endif
if (PARAM.inp.precision == "single")
{
delete reinterpret_cast<psi::Psi<std::complex<double>, Device>*>(this->__kspw_psi);
Expand Down
42 changes: 36 additions & 6 deletions source/module_hsolver/diago_dav_subspace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,12 @@ int Diago_DavSubspace<T, Device>::diag_once(const HPsiFunc& hpsi_func,
// update eigenvectors of Hamiltonian
setmem_complex_op()(this->ctx, psi_in, 0, n_band * psi_in_dmax);

gemm_op<T, Device>()(this->ctx,
#ifdef __DSP
gemm_op_mt<T, Device>()
#else
gemm_op<T, Device>()
#endif
(this->ctx,
'N',
'N',
this->dim,
Expand Down Expand Up @@ -262,7 +267,12 @@ void Diago_DavSubspace<T, Device>::cal_grad(const HPsiFunc& hpsi_func,
}
}

gemm_op<T, Device>()(this->ctx,
#ifdef __DSP
gemm_op_mt<T, Device>()
#else
gemm_op<T, Device>()
#endif
(this->ctx,
'N',
'N',
this->dim,
Expand Down Expand Up @@ -302,7 +312,12 @@ void Diago_DavSubspace<T, Device>::cal_grad(const HPsiFunc& hpsi_func,
delmem_real_op()(this->ctx, e_temp_hd);
}

gemm_op<T, Device>()(this->ctx,
#ifdef __DSP
gemm_op_mt<T, Device>()
#else
gemm_op<T, Device>()
#endif
(this->ctx,
'N',
'N',
this->dim,
Expand Down Expand Up @@ -386,7 +401,12 @@ void Diago_DavSubspace<T, Device>::cal_elem(const int& dim,
{
ModuleBase::timer::tick("Diago_DavSubspace", "cal_elem");

gemm_op<T, Device>()(this->ctx,
#ifdef __DSP
gemm_op_mt<T, Device>()
#else
gemm_op<T, Device>()
#endif
(this->ctx,
'C',
'N',
nbase + notconv,
Expand All @@ -401,7 +421,12 @@ void Diago_DavSubspace<T, Device>::cal_elem(const int& dim,
&hcc[nbase * this->nbase_x],
this->nbase_x);

gemm_op<T, Device>()(this->ctx,
#ifdef __DSP
gemm_op_mt<T, Device>()
#else
gemm_op<T, Device>()
#endif
(this->ctx,
'C',
'N',
nbase + notconv,
Expand Down Expand Up @@ -603,7 +628,12 @@ void Diago_DavSubspace<T, Device>::refresh(const int& dim,
{
ModuleBase::timer::tick("Diago_DavSubspace", "refresh");

gemm_op<T, Device>()(this->ctx,
#ifdef __DSP
gemm_op_mt<T, Device>()
#else
gemm_op<T, Device>()
#endif
(this->ctx,
'N',
'N',
this->dim,
Expand Down
Loading

0 comments on commit f039250

Please sign in to comment.