Merge pull request #95 from arcaneframework/dev/gg-add-parallel-for-direct-hypre-backend

Add parallel support for direct hypre backend
grospelliergilles authored Jan 20, 2024
2 parents 38d4e03 + 3394cea commit 67fb5d7
Showing 4 changed files with 214 additions and 43 deletions.
10 changes: 5 additions & 5 deletions femutils/CsrFormatMatrix.cc
@@ -65,14 +65,14 @@ translateToLinearSystem(DoFLinearSystem& linear_system)
if (((i + 1) < nb_row) && (m_matrix_row(i) == m_matrix_row(i + 1)))
continue;
for (Int32 j = m_matrix_row(i); ((i + 1) < nb_row && j < m_matrix_row(i + 1)) || ((i + 1) == nb_row && j < m_matrix_column.dim1Size()); j++) {
if (DoFLocalId(m_matrix_column(j)).isNull())
continue;
//info() << "Add: (" << i << ", " << m_matrix_column(j) << " v=" << m_matrix_value(j);
if (do_set_csr){
++m_matrix_rows_nb_column[i];
continue;
}
else
linear_system.matrixAddValue(DoFLocalId(i), DoFLocalId(m_matrix_column(j)), m_matrix_value(j));
if (DoFLocalId(m_matrix_column(j)).isNull())
continue;
//info() << "Add: (" << i << ", " << m_matrix_column(j) << " v=" << m_matrix_value(j);
linear_system.matrixAddValue(DoFLocalId(i), DoFLocalId(m_matrix_column(j)), m_matrix_value(j));
}
}
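The loop above visits each CSR row by taking the column range between consecutive entries of m_matrix_row, falling back to the end of m_matrix_column for the last row, and skips entries whose DoF is null. As an aside to the diff, here is a minimal standalone sketch of that row-span logic using hypothetical plain std::vector containers instead of the Arcane ones (the names visitCsr, row_offset, etc. are illustrative only and not part of the commit):

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical CSR triplet: per-row offsets into 'columns'/'values'.
// A negative column id plays the role of a null DoFLocalId and is skipped,
// mirroring the isNull() test in translateToLinearSystem() above.
void visitCsr(const std::vector<std::int32_t>& row_offset,
              const std::vector<std::int32_t>& columns,
              const std::vector<double>& values)
{
  const std::int32_t nb_row = static_cast<std::int32_t>(row_offset.size());
  for (std::int32_t i = 0; i < nb_row; ++i) {
    // A row ends where the next row begins; the last row ends at the
    // total number of stored entries.
    const std::int32_t end = (i + 1 < nb_row)
        ? row_offset[i + 1]
        : static_cast<std::int32_t>(columns.size());
    for (std::int32_t j = row_offset[i]; j < end; ++j) {
      if (columns[j] < 0)
        continue; // null column: the entry is not added to the linear system
      std::cout << "Add (" << i << "," << columns[j] << ") v=" << values[j] << "\n";
    }
  }
}

An empty (or unset) row simply yields an empty range, which is what the early `continue` in the original loop handles.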

231 changes: 197 additions & 34 deletions femutils/HypreDoFLinearSystem.cc
@@ -16,11 +16,14 @@
#include <arcane/utils/FatalErrorException.h>
#include <arcane/utils/NumArray.h>
#include <arcane/utils/PlatformUtils.h>
#include <arcane/utils/ITraceMng.h>

#include <arcane/core/VariableTypes.h>
#include <arcane/core/IItemFamily.h>
#include <arcane/core/BasicService.h>
#include <arcane/core/ServiceFactory.h>
#include <arcane/core/IParallelMng.h>
#include <arcane/core/ItemPrinter.h>

#include <arcane/accelerator/core/Runner.h>

@@ -33,6 +36,10 @@
#include <HYPRE_parcsr_ls.h>
#include <krylov.h>

// NOTE:
// The DoF family must be compacted (i.e. maxLocalId()==nbItem()) and sorted
// for this implementation to work.
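A sanity check for this precondition could look like the sketch below; it is purely illustrative and not part of the commit (it assumes the IItemFamily::maxLocalId() and nbItem() accessors named in the NOTE and Arcane's ARCANE_FATAL macro):

#include <arcane/utils/FatalErrorException.h>
#include <arcane/core/IItemFamily.h>

// Hypothetical guard for the precondition stated in the NOTE above.
inline void checkDoFFamilyIsCompacted(Arcane::IItemFamily* dof_family)
{
  if (dof_family->maxLocalId() != dof_family->nbItem())
    ARCANE_FATAL("DoF family must be compacted (maxLocalId()!=nbItem())");
}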

/*---------------------------------------------------------------------------*/
/*---------------------------------------------------------------------------*/

@@ -86,18 +93,22 @@ class HypreDoFLinearSystemImpl
, m_dof_matrix_indexes(VariableBuildInfo(m_dof_family, solver_name + "DoFMatrixIndexes"))
, m_dof_elimination_info(VariableBuildInfo(m_dof_family, solver_name + "DoFEliminationInfo"))
, m_dof_elimination_value(VariableBuildInfo(m_dof_family, solver_name + "DoFEliminationValue"))
, m_dof_matrix_numbering(VariableBuildInfo(dof_family, solver_name + "MatrixNumbering"))
{
info() << "Creating HypreDoFLinearSystemImpl()";
}

~HypreDoFLinearSystemImpl()
{
info() << "Calling HYPRE_Finalize";
HYPRE_Finalize(); /* must be the last HYPRE function call */
}

public:

void build()
{
HYPRE_Init(); /* must be the first HYPRE function call */
}

public:
@@ -161,27 +172,95 @@ class HypreDoFLinearSystemImpl
VariableDoFInt32 m_dof_matrix_indexes;
VariableDoFByte m_dof_elimination_info;
VariableDoFReal m_dof_elimination_value;
VariableDoFInt32 m_dof_matrix_numbering;
NumArray<Int32, MDDim1> m_parallel_columns_index;
NumArray<Int32, MDDim1> m_parallel_rows_index;
//! Work array to store values of solution vector in parallel
NumArray<Real, MDDim1> m_result_work_values;
Runner* m_runner = nullptr;

CSRFormatView m_csr_view;
Int32 m_first_own_row = -1;
Int32 m_nb_own_row = -1;

private:

void _computeMatrixNumerotation();
};

/*---------------------------------------------------------------------------*/
/*---------------------------------------------------------------------------*/

void HypreDoFLinearSystemImpl::
solve()
_computeMatrixNumerotation()
{
HYPRE_Init(); /* must be the first HYPRE function call */
IParallelMng* pm = m_dof_family->parallelMng();
const bool is_parallel = pm->isParallel();
const Int32 nb_rank = pm->commSize();
const Int32 my_rank = pm->commRank();

DoFGroup all_dofs = m_dof_family->allItems();
DoFGroup own_dofs = all_dofs.own();
const Int32 nb_own_row = own_dofs.size();

Int32 own_first_index = 0;

if (is_parallel) {
// TODO: use a Scan once it is available in Arcane
UniqueArray<Int32> parallel_rows_index(nb_rank, 0);
pm->allGather(ConstArrayView<Int32>(1, &nb_own_row), parallel_rows_index);
info() << "ALL_NB_ROW = " << parallel_rows_index;
for (Int32 i = 0; i < my_rank; ++i)
own_first_index += parallel_rows_index[i];
}

info() << "OwnFirstIndex=" << own_first_index << " NbOwnRow=" << nb_own_row;

m_first_own_row = own_first_index;
m_nb_own_row = nb_own_row;

// TODO: do this with the accelerator API
ENUMERATE_DOF (idof, own_dofs) {
DoF dof = *idof;
m_dof_matrix_numbering[idof] = own_first_index + idof.index();
//info() << "Numbering dof_uid=" << dof.uniqueId() << " M=" << m_dof_matrix_numbering[idof];
}
info() << " nb_own_row=" << nb_own_row << " nb_item=" << m_dof_family->nbItem();
m_dof_matrix_numbering.synchronize();

m_parallel_rows_index.resize(nb_own_row);
m_result_work_values.resize(nb_own_row);
}
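_computeMatrixNumerotation() gives each rank a contiguous block of global row indices: the per-rank counts of owned rows are gathered on every rank, and each rank sums the counts of the lower ranks to obtain its first global index (the exclusive prefix sum that the TODO would replace with a Scan). As an illustration outside the diff, the same offset computation with plain MPI (computeFirstOwnRow and its parameters are hypothetical names):

#include <mpi.h>
#include <numeric>
#include <vector>

// Returns the first global row index owned by this rank, assuming every rank
// owns 'nb_own_row' contiguous rows.
int computeFirstOwnRow(MPI_Comm comm, int nb_own_row)
{
  int nb_rank = 0;
  int my_rank = 0;
  MPI_Comm_size(comm, &nb_rank);
  MPI_Comm_rank(comm, &my_rank);

  // Gather the number of owned rows of every rank (same role as pm->allGather()).
  std::vector<int> all_nb_row(nb_rank, 0);
  MPI_Allgather(&nb_own_row, 1, MPI_INT, all_nb_row.data(), 1, MPI_INT, comm);

  // Exclusive prefix sum: my first index is the sum of the counts below my rank.
  return std::accumulate(all_nb_row.begin(), all_nb_row.begin() + my_rank, 0);
}

MPI_Exscan would compute the same offset in a single call, which is essentially the Scan mentioned in the TODO.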

/*---------------------------------------------------------------------------*/
/*---------------------------------------------------------------------------*/

void HypreDoFLinearSystemImpl::
solve()
{
HYPRE_MemoryLocation hypre_memory = HYPRE_MEMORY_HOST;
HYPRE_ExecutionPolicy hypre_exec_policy = HYPRE_EXEC_HOST;

// Get the associated MPI communicator
IParallelMng* pm = m_dof_family->parallelMng();
Parallel::Communicator arcane_comm = pm->communicator();
MPI_Comm mpi_comm = MPI_COMM_WORLD;
if (arcane_comm.isValid())
mpi_comm = static_cast<MPI_Comm>(arcane_comm);

bool is_parallel = pm->isParallel();
const Int32 nb_rank = pm->commSize();
const Int32 my_rank = pm->commRank();

// TODO: only do this once, unless the DoFs change
_computeMatrixNumerotation();

bool is_use_device = false;
if (m_runner) {
is_use_device = isAcceleratorPolicy(m_runner->executionPolicy());
info() << "Runner for Hypre=" << m_runner->executionPolicy() << " is_device=" << is_use_device;
}

// If HYPRE is not compiled with GPU support, use the host.
// (NOTE: this should not actually be needed: if Hypre is not compiled with
// GPU support then HYPRE_MEMORY_DEVICE <=> HYPRE_MEMORY_HOST
@@ -215,70 +294,138 @@ solve()
//HYPRE_SetGPUMemoryPoolSize(bin_growth, min_bin, max_bin, max_bytes);

/* setup IJ matrix A */
// TODO: use the proper communicator in parallel.
MPI_Comm comm = MPI_COMM_WORLD;

HYPRE_IJMatrix ij_A = nullptr;
HYPRE_ParCSRMatrix parcsr_A = nullptr;

const int first_row = 0;
const int nb_row = m_csr_view.rows().size();
info() << "NB_ROW=" << nb_row;
const int last_row = first_row + nb_row - 1;
{
int first_col = first_row;
int last_col = first_col + nb_row - 1;
info() << "CreateMatrix row=" << first_row << ", " << last_row
<< " col=" << first_col << ", " << last_col;
HYPRE_IJMatrixCreate(comm, first_row, last_row, first_col, last_col, &ij_A);
const bool do_debug_print = false;
const bool do_dump_matrix = false;

Span<const Int32> rows_index_span = m_dof_matrix_numbering.asArray();
const Int32 nb_local_row = rows_index_span.size();

if (do_debug_print) {
info() << "ROWS_INDEX=" << rows_index_span;
info() << "ROWS=" << m_csr_view.rows();
info() << "ROWS_NB_COLUMNS=" << m_csr_view.rowsNbColumn();
info() << "COLUMNS=" << m_csr_view.columns();
info() << "VALUE=" << m_csr_view.values();
}

NumArray<Int32, MDDim1> rows_index(nb_row);
for (Int32 i = 0; i < nb_row; ++i)
rows_index[i] = i;
Span<const Int32> rows_index_span = rows_index.to1DSpan();
const int first_row = m_first_own_row;
const int last_row = m_first_own_row + m_nb_own_row - 1;

info() << "CreateMatrix first_row=" << first_row << " last_row " << last_row;
HYPRE_IJMatrixCreate(mpi_comm, first_row, last_row, first_row, last_row, &ij_A);

int* rows_nb_column_data = const_cast<int*>(m_csr_view.rowsNbColumn().data());

Real m1 = platform::getRealTime();
HYPRE_IJMatrixSetObjectType(ij_A, HYPRE_PARCSR);
HYPRE_IJMatrixInitialize_v2(ij_A, hypre_memory);

// m_csr_view.columns() uses matrix coordinates local to the sub-domain.
// We need to translate them to global matrix coordinates.
Span<const Int32> columns_index_span = m_csr_view.columns();
if (is_parallel) {
// TODO: do this on the accelerator, and only once if the structure
// does not change.
Int64 nb_column = columns_index_span.size();
m_parallel_columns_index.resize(nb_column);
for (Int64 i = 0; i < nb_column; ++i) {
DoFLocalId lid(columns_index_span[i]);
//info() << "I=" << i << " index=" << columns_index_span[i];
// If lid corresponds to a null item, the matrix value
// will not be used.
if (!lid.isNull())
m_parallel_columns_index[i] = m_dof_matrix_numbering[lid];
else
m_parallel_columns_index[i] = 0;
}
columns_index_span = m_parallel_columns_index.to1DSpan();
}

if (do_debug_print) {
info() << "FINAL_COLUMNS=" << columns_index_span;
info() << "NbValue=" << m_csr_view.values().size();
}

Span<const Real> matrix_values = m_csr_view.values();
if (do_debug_print) {
ENUMERATE_ (DoF, idof, m_dof_family->allItems()) {
DoF dof = *idof;
Int32 nb_col = m_csr_view.rowsNbColumn()[idof.index()];
Int32 row_csr_index = m_csr_view.rows()[idof.index()];
info() << "DoF dof=" << ItemPrinter(dof) << " nb_col=" << nb_col << " row_csr_index=" << row_csr_index
<< " global_row=" << rows_index_span[idof.index()];
for (Int32 i = 0; i < nb_col; ++i) {
Int32 col_index = m_csr_view.columns()[row_csr_index + i];
if (col_index >= 0)
info() << "COL=" << col_index
<< " T_COL=" << m_dof_matrix_numbering[DoFLocalId(col_index)]
<< " V=" << matrix_values[row_csr_index + i];
else
info() << "COL=" << col_index
<< " X_COL=" << columns_index_span[row_csr_index + i]
<< " V=" << matrix_values[row_csr_index + i];
}
}
}

if (is_parallel) {
// Fill 'm_parallel_rows_index' with only the rows we own
// NOTE: This is only needed if the matrix structure has changed.
Int32 index = 0;
ENUMERATE_ (DoF, idof, m_dof_family->allItems()) {
DoF dof = *idof;
if (!dof.isOwn())
continue;
Int32 nb_col = m_csr_view.rowsNbColumn()[idof.index()];
m_parallel_rows_index[index] = rows_index_span[idof.index()];
++index;
}
}

/* GPU pointers; efficient in large chunks */
HYPRE_IJMatrixSetValues(ij_A,
nb_row,
nb_local_row,
rows_nb_column_data,
rows_index_span.data(),
m_csr_view.columns().data(),
m_csr_view.values().data());
columns_index_span.data(),
matrix_values.data());

HYPRE_IJMatrixAssemble(ij_A);
HYPRE_IJMatrixGetObject(ij_A, (void**)&parcsr_A);
Real m2 = platform::getRealTime();
info() << "Time to create matrix=" << (m2 - m1);

//HYPRE_IJMatrixPrint(ij_A, "dumpA.txt");
if (do_dump_matrix) {
String file_name = String("dumpA.") + String::fromNumber(my_rank) + ".txt";
HYPRE_IJMatrixPrint(ij_A, file_name.localstr());
pm->traceMng()->flush();
pm->barrier();
}

HYPRE_IJVector ij_vector_b = nullptr;
HYPRE_ParVector parvector_b = nullptr;
HYPRE_IJVector ij_vector_x = nullptr;
HYPRE_ParVector parvector_x = nullptr;

hypreCheck("IJVectorCreate", HYPRE_IJVectorCreate(comm, first_row, last_row, &ij_vector_b));
hypreCheck("IJVectorCreate", HYPRE_IJVectorCreate(mpi_comm, first_row, last_row, &ij_vector_b));
hypreCheck("IJVectorSetObjectType", HYPRE_IJVectorSetObjectType(ij_vector_b, HYPRE_PARCSR));
HYPRE_IJVectorInitialize_v2(ij_vector_b, hypre_memory);

hypreCheck("IJVectorCreate", HYPRE_IJVectorCreate(comm, first_row, last_row, &ij_vector_x));
hypreCheck("IJVectorCreate", HYPRE_IJVectorCreate(mpi_comm, first_row, last_row, &ij_vector_x));
hypreCheck("IJVectorSetObjectType", HYPRE_IJVectorSetObjectType(ij_vector_x, HYPRE_PARCSR));
HYPRE_IJVectorInitialize_v2(ij_vector_x, hypre_memory);

Real v1 = platform::getRealTime();
hypreCheck("HYPRE_IJVectorSetValues",
HYPRE_IJVectorSetValues(ij_vector_b, nb_row, rows_index_span.data(),
HYPRE_IJVectorSetValues(ij_vector_b, nb_local_row, rows_index_span.data(),
m_rhs_variable.asArray().data()));

hypreCheck("HYPRE_IJVectorSetValues",
HYPRE_IJVectorSetValues(ij_vector_x, nb_row, rows_index_span.data(),
HYPRE_IJVectorSetValues(ij_vector_x, nb_local_row, rows_index_span.data(),
m_dof_variable.asArray().data()));

hypreCheck("HYPRE_IJVectorAssemble",
@@ -291,13 +438,19 @@
Real v2 = platform::getRealTime();
info() << "Time to create vectors=" << (v2 - v1);

//HYPRE_IJVectorPrint(ij_vector_b, "dumpB.txt");
//HYPRE_IJVectorPrint(ij_vector_x, "dumpX.txt");
if (do_dump_matrix) {
String file_name_b = String("dumpB.") + String::fromNumber(my_rank) + ".txt";
HYPRE_IJVectorPrint(ij_vector_b, file_name_b.localstr());
String file_name_x = String("dumpX.") + String::fromNumber(my_rank) + ".txt";
HYPRE_IJVectorPrint(ij_vector_x, file_name_x.localstr());
pm->traceMng()->flush();
pm->barrier();
}

HYPRE_Solver solver = nullptr;
HYPRE_Solver precond = nullptr;
/* setup AMG */
HYPRE_ParCSRPCGCreate(comm, &solver);
HYPRE_ParCSRPCGCreate(mpi_comm, &solver);

/* Set some parameters (See Reference Manual for more parameters) */
HYPRE_PCGSetMaxIter(solver, 1000); /* max iterations */
Expand Down Expand Up @@ -327,11 +480,21 @@ solve()
Real b1 = platform::getRealTime();
info() << "Time to solve=" << (b1 - a1);

hypreCheck("HYPRE_IJVectorGetValues",
HYPRE_IJVectorGetValues(ij_vector_x, nb_row, rows_index_span.data(),
m_dof_variable.asArray().data()));

HYPRE_Finalize(); /* must be the last HYPRE function call */
if (is_parallel) {
Int32 nb_wanted_row = m_parallel_rows_index.extent0();
hypreCheck("HYPRE_IJVectorGetValues",
HYPRE_IJVectorGetValues(ij_vector_x, nb_wanted_row,
m_parallel_rows_index.to1DSpan().data(),
m_result_work_values.to1DSpan().data()));
ENUMERATE_ (DoF, idof, m_dof_family->allItems().own()) {
m_dof_variable[idof] = m_result_work_values[idof.index()];
}
}
else {
hypreCheck("HYPRE_IJVectorGetValues",
HYPRE_IJVectorGetValues(ij_vector_x, nb_local_row, rows_index_span.data(),
m_dof_variable.asArray().data()));
}
}
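Stripped of the Arcane variables and the device-memory handling, solve() follows HYPRE's usual IJ-interface pattern: create the IJ matrix and vectors over the rank's owned global row range, fill them using global indices, assemble, fetch the ParCSR objects, solve with ParCSR PCG, and read back only the owned rows. Below is a condensed sketch of that sequence, modeled on HYPRE's IJ examples; the function and parameter names are hypothetical, HYPRE is assumed to be built with 32-bit indices and host memory, the BoomerAMG preconditioner set up in the diff is omitted, and HYPRE_Init()/HYPRE_Finalize() are assumed to be called elsewhere, as build() and the destructor do above:

#include <mpi.h>
#include <vector>
#include <HYPRE.h>
#include <HYPRE_IJ_mv.h>
#include <HYPRE_parcsr_ls.h>
#include <HYPRE_krylov.h>

// Hypothetical helper: assemble and solve a distributed system in which this
// rank owns the global rows [first_row, last_row]. 'rows' holds the owned
// global row indices, 'rows_nb_col' the number of entries per owned row, and
// 'cols'/'values' the concatenated global column indices and coefficients.
void assembleAndSolve(MPI_Comm comm, int first_row, int last_row,
                      std::vector<int>& rows, std::vector<int>& rows_nb_col,
                      std::vector<int>& cols, std::vector<double>& values,
                      std::vector<double>& rhs, std::vector<double>& x)
{
  const int nb_local_row = last_row - first_row + 1;

  // IJ matrix over the owned row/column range, as with HYPRE_IJMatrixCreate above.
  HYPRE_IJMatrix ij_A = nullptr;
  HYPRE_ParCSRMatrix parcsr_A = nullptr;
  HYPRE_IJMatrixCreate(comm, first_row, last_row, first_row, last_row, &ij_A);
  HYPRE_IJMatrixSetObjectType(ij_A, HYPRE_PARCSR);
  HYPRE_IJMatrixInitialize(ij_A);
  HYPRE_IJMatrixSetValues(ij_A, nb_local_row, rows_nb_col.data(), rows.data(),
                          cols.data(), values.data());
  HYPRE_IJMatrixAssemble(ij_A);
  HYPRE_IJMatrixGetObject(ij_A, (void**)&parcsr_A);

  // Right-hand side and solution vectors over the same owned row range.
  HYPRE_IJVector ij_b = nullptr;
  HYPRE_IJVector ij_x = nullptr;
  HYPRE_ParVector par_b = nullptr;
  HYPRE_ParVector par_x = nullptr;
  HYPRE_IJVectorCreate(comm, first_row, last_row, &ij_b);
  HYPRE_IJVectorSetObjectType(ij_b, HYPRE_PARCSR);
  HYPRE_IJVectorInitialize(ij_b);
  HYPRE_IJVectorSetValues(ij_b, nb_local_row, rows.data(), rhs.data());
  HYPRE_IJVectorAssemble(ij_b);
  HYPRE_IJVectorGetObject(ij_b, (void**)&par_b);
  HYPRE_IJVectorCreate(comm, first_row, last_row, &ij_x);
  HYPRE_IJVectorSetObjectType(ij_x, HYPRE_PARCSR);
  HYPRE_IJVectorInitialize(ij_x);
  HYPRE_IJVectorSetValues(ij_x, nb_local_row, rows.data(), x.data());
  HYPRE_IJVectorAssemble(ij_x);
  HYPRE_IJVectorGetObject(ij_x, (void**)&par_x);

  // Plain ParCSR PCG on the assembled objects.
  HYPRE_Solver solver = nullptr;
  HYPRE_ParCSRPCGCreate(comm, &solver);
  HYPRE_PCGSetMaxIter(solver, 1000);
  HYPRE_PCGSetTol(solver, 1.0e-8);
  HYPRE_ParCSRPCGSetup(solver, parcsr_A, par_b, par_x);
  HYPRE_ParCSRPCGSolve(solver, parcsr_A, par_b, par_x);

  // Read back only the rows this rank owns, like the parallel branch above.
  HYPRE_IJVectorGetValues(ij_x, nb_local_row, rows.data(), x.data());

  HYPRE_ParCSRPCGDestroy(solver);
  HYPRE_IJVectorDestroy(ij_b);
  HYPRE_IJVectorDestroy(ij_x);
  HYPRE_IJMatrixDestroy(ij_A);
}

In the diff, the global column indices come from m_dof_matrix_numbering and the owned-row indices from m_parallel_rows_index.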

/*---------------------------------------------------------------------------*/