Skip to content

Commit

Permalink
Merge pull request #94 from arcaneframework/dev/gg-add-hypre-gpu-support
Browse files Browse the repository at this point in the history
Add support for GPU backend of Hypre
  • Loading branch information
grospelliergilles authored Jan 17, 2024
2 parents b777b2b + 248a859 commit 38d4e03
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 23 deletions.
4 changes: 4 additions & 0 deletions femutils/AlephDoFLinearSystem.cc
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,8 @@ class AlephDoFLinearSystemImpl
}

bool hasSetCSRValues() const { return false; }
void setRunner(Runner* r) override { m_runner = r; }
Runner* runner() const { return m_runner; }

private:

Expand Down Expand Up @@ -406,6 +408,8 @@ class AlephDoFLinearSystemImpl
//! True if we need to manually destroy the matrix/vector
bool m_need_destroy_matrix_and_vector = true;

Runner* m_runner = nullptr;

private:

void _fillMatrix();
Expand Down
20 changes: 17 additions & 3 deletions femutils/DoFLinearSystem.cc
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
// -*- tab-width: 2; indent-tabs-mode: nil; coding: utf-8-with-signature -*-
//-----------------------------------------------------------------------------
// Copyright 2000-2023 CEA (www.cea.fr) IFPEN (www.ifpenergiesnouvelles.com)
// Copyright 2000-2024 CEA (www.cea.fr) IFPEN (www.ifpenergiesnouvelles.com)
// See the top-level COPYRIGHT file for details.
// SPDX-License-Identifier: Apache-2.0
//-----------------------------------------------------------------------------
/*---------------------------------------------------------------------------*/
/* DoFLinearSystem.cc (C) 2022-2023 */
/* DoFLinearSystem.cc (C) 2022-2024 */
/* */
/* Linear system: Matrix A + Vector x + Vector b for Ax=b. */
/*---------------------------------------------------------------------------*/
Expand Down Expand Up @@ -186,6 +186,8 @@ class SequentialDoFLinearSystemImpl
ARCANE_THROW(NotImplementedException, "");
}
bool hasSetCSRValues() const override { return false; }
//! Store the runner associated with this linear system (\a r may be null).
void setRunner(Runner* r) override { m_runner = r; }
//! Runner previously given to setRunner(), or null if none was set.
Runner* runner() const override { return m_runner; }

public:

Expand All @@ -206,6 +208,8 @@ class SequentialDoFLinearSystemImpl
Real m_epsilon = 1.0e-15;
eInternalSolverMethod m_solver_method = eInternalSolverMethod::Auto;

Runner* m_runner = nullptr;

private:

void _fillRHSVector()
Expand Down Expand Up @@ -308,7 +312,7 @@ _checkInit() const
/*---------------------------------------------------------------------------*/

void DoFLinearSystem::
initialize(ISubDomain* sd, IItemFamily* dof_family, const String& solver_name)
initialize(ISubDomain* sd, Runner* runner, IItemFamily* dof_family, const String& solver_name)
{
ARCANE_CHECK_POINTER(sd);
ARCANE_CHECK_POINTER(dof_family);
Expand All @@ -323,6 +327,16 @@ initialize(ISubDomain* sd, IItemFamily* dof_family, const String& solver_name)
}
m_item_family = dof_family;
m_p = m_linear_system_factory->createInstance(sd, dof_family, solver_name);
m_p->setRunner(runner);
}

/*---------------------------------------------------------------------------*/
/*---------------------------------------------------------------------------*/

/*!
 * \brief Initialize the instance without any runner.
 *
 * Forwards to the overload taking a runner, passing a null runner.
 */
void DoFLinearSystem::
initialize(ISubDomain* sd, IItemFamily* dof_family, const String& solver_name)
{
  Runner* const no_runner = nullptr;
  initialize(sd, no_runner, dof_family, solver_name);
}

/*---------------------------------------------------------------------------*/
Expand Down
12 changes: 9 additions & 3 deletions femutils/DoFLinearSystem.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ class DoFLinearSystemImpl
virtual void clearValues() = 0;
virtual void setCSRValues(const CSRFormatView& csr_view) = 0;
virtual bool hasSetCSRValues() const = 0;
//! Set the runner associated with this linear system (\a r may be null).
virtual void setRunner(Runner* r) =0;
//! Runner associated with this linear system (may be null).
virtual Runner* runner() const =0;
};

/*---------------------------------------------------------------------------*/
Expand Down Expand Up @@ -117,12 +119,16 @@ class DoFLinearSystem

/*
* \brief Initialize the instance.
*
* The variable dof_variable will be filled with the solution value after
* the call to the method solve().
*/
void initialize(ISubDomain* sd, IItemFamily* dof_family, const String& solver_name);

/*!
 * \brief Initialize the instance.
 *
 * \a runner may be null. It is forwarded to the linear system
 * implementation through setRunner().
 */
void initialize(ISubDomain* sd, Runner* runner, IItemFamily* dof_family, const String& solver_name);

//! Indicate if method initialize() has been called
bool isInitialized() const;

Expand Down
64 changes: 50 additions & 14 deletions femutils/HypreDoFLinearSystem.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,15 @@

#include <arcane/utils/FatalErrorException.h>
#include <arcane/utils/NumArray.h>
#include <arcane/utils/PlatformUtils.h>

#include <arcane/core/VariableTypes.h>
#include <arcane/core/IItemFamily.h>
#include <arcane/core/BasicService.h>
#include <arcane/core/ServiceFactory.h>

#include <arcane/accelerator/core/Runner.h>

#include "FemUtils.h"
#include "IDoFLinearSystemFactory.h"

Expand Down Expand Up @@ -147,6 +150,9 @@ class HypreDoFLinearSystemImpl
}
bool hasSetCSRValues() const override { return true; }

//! Store the runner used to select the Hypre backend in solve() (\a r may be null).
void setRunner(Runner* r) override { m_runner = r; }
//! Runner previously given to setRunner(), or null if none was set.
Runner* runner() const override { return m_runner; }

private:

IItemFamily* m_dof_family = nullptr;
Expand All @@ -155,6 +161,7 @@ class HypreDoFLinearSystemImpl
VariableDoFInt32 m_dof_matrix_indexes;
VariableDoFByte m_dof_elimination_info;
VariableDoFReal m_dof_elimination_value;
Runner* m_runner = nullptr;

CSRFormatView m_csr_view;
};
Expand All @@ -167,27 +174,48 @@ solve()
{
HYPRE_Init(); /* must be the first HYPRE function call */

#if 0
/* AMG in GPU memory (default) */
hypreCheck("HYPRE_SetMemoryLocation", HYPRE_SetMemoryLocation(HYPRE_MEMORY_DEVICE));
/* setup AMG on GPUs */
HYPRE_SetExecutionPolicy(HYPRE_EXEC_DEVICE);
HYPRE_MemoryLocation hypre_memory = HYPRE_MEMORY_HOST;
HYPRE_ExecutionPolicy hypre_exec_policy = HYPRE_EXEC_HOST;

/* use hypre's SpGEMM instead of vendor implementation */
HYPRE_SetSpGemmUseVendor(false);
/* use GPU RNG */
HYPRE_SetUseGpuRand(true);
bool is_use_device = false;
if (m_runner) {
is_use_device = isAcceleratorPolicy(m_runner->executionPolicy());
info() << "Runner for Hypre=" << m_runner->executionPolicy() << " is_device=" << is_use_device;
}
// If Hypre is not compiled with GPU support, then the host is used.
// (NOTE: in principle nothing special is needed for that. If Hypre is not
// compiled with GPU support then HYPRE_MEMORY_DEVICE <=> HYPRE_MEMORY_HOST.)
// TODO: check the consistency between the GPU backend of Hypre and ours
// (i.e. both use CUDA or both use ROCm)
#ifndef HYPRE_USING_GPU
if (is_use_device)
info() << "Hypre is not compiled with GPU support. Using host backend";
#endif

hypreCheck("HYPRE_SetMemoryLocation", HYPRE_SetMemoryLocation(HYPRE_MEMORY_HOST));
if (is_use_device) {
m_runner->setAsCurrentDevice();
hypre_memory = HYPRE_MEMORY_DEVICE;
hypre_exec_policy = HYPRE_EXEC_DEVICE;
}

hypreCheck("HYPRE_SetMemoryLocation", HYPRE_SetMemoryLocation(hypre_memory));
/* setup AMG on GPUs */
HYPRE_SetExecutionPolicy(HYPRE_EXEC_HOST);
hypreCheck("HYPRE_SetExecutionPolicy", HYPRE_SetExecutionPolicy(hypre_exec_policy));

if (is_use_device) {
#if HYPRE_RELEASE_NUMBER >= 22300
/* use hypre's SpGEMM instead of vendor implementation */
HYPRE_SetSpGemmUseVendor(false);
#endif
/* use GPU RNG */
HYPRE_SetUseGpuRand(true);
}

/* use hypre's GPU memory pool */
//HYPRE_SetGPUMemoryPoolSize(bin_growth, min_bin, max_bin, max_bytes);
auto hypre_memory = HYPRE_MEMORY_HOST;

/* setup IJ matrix A */
// TODO: Utiliser le bon communicateur
// TODO: Use the right communicator in parallel.
MPI_Comm comm = MPI_COMM_WORLD;

HYPRE_IJMatrix ij_A = nullptr;
Expand All @@ -212,8 +240,8 @@ solve()

int* rows_nb_column_data = const_cast<int*>(m_csr_view.rowsNbColumn().data());

Real m1 = platform::getRealTime();
HYPRE_IJMatrixSetObjectType(ij_A, HYPRE_PARCSR);
//HYPRE_IJMatrixSetRowSizes(ij_A, rows_nb_column_data);
HYPRE_IJMatrixInitialize_v2(ij_A, hypre_memory);

/* GPU pointers; efficient in large chunks */
Expand All @@ -226,6 +254,8 @@ solve()

HYPRE_IJMatrixAssemble(ij_A);
HYPRE_IJMatrixGetObject(ij_A, (void**)&parcsr_A);
Real m2 = platform::getRealTime();
info() << "Time to create matrix=" << (m2 - m1);

//HYPRE_IJMatrixPrint(ij_A, "dumpA.txt");

Expand All @@ -242,6 +272,7 @@ solve()
hypreCheck("IJVectorSetObjectType", HYPRE_IJVectorSetObjectType(ij_vector_x, HYPRE_PARCSR));
HYPRE_IJVectorInitialize_v2(ij_vector_x, hypre_memory);

Real v1 = platform::getRealTime();
hypreCheck("HYPRE_IJVectorSetValues",
HYPRE_IJVectorSetValues(ij_vector_b, nb_row, rows_index_span.data(),
m_rhs_variable.asArray().data()));
Expand All @@ -257,6 +288,8 @@ solve()
hypreCheck("HYPRE_IJVectorAssemble",
HYPRE_IJVectorAssemble(ij_vector_x));
HYPRE_IJVectorGetObject(ij_vector_x, (void**)&parvector_x);
Real v2 = platform::getRealTime();
info() << "Time to create vectors=" << (v2 - v1);

//HYPRE_IJVectorPrint(ij_vector_b, "dumpB.txt");
//HYPRE_IJVectorPrint(ij_vector_x, "dumpX.txt");
Expand Down Expand Up @@ -288,8 +321,11 @@ solve()
HYPRE_ParCSRPCGSetPrecond(solver, HYPRE_BoomerAMGSolve, HYPRE_BoomerAMGSetup, precond));
hypreCheck("HYPRE_PCGSetup",
HYPRE_ParCSRPCGSetup(solver, parcsr_A, parvector_b, parvector_x));
Real a1 = platform::getRealTime();
hypreCheck("HYPRE_PCGSolve",
HYPRE_ParCSRPCGSolve(solver, parcsr_A, parvector_b, parvector_x));
Real b1 = platform::getRealTime();
info() << "Time to solve=" << (b1 - a1);

hypreCheck("HYPRE_IJVectorGetValues",
HYPRE_IJVectorGetValues(ij_vector_x, nb_row, rows_index_span.data(),
Expand Down
3 changes: 1 addition & 2 deletions poisson/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,7 @@ endif()
if(ARCANE_HAS_ACCELERATOR)
add_test(NAME [poisson]poisson_gpu COMMAND ./Poisson -A,AcceleratorRuntime=cuda Test.poisson.petsc.arc)
if(FEMUTILS_HAS_SOLVER_BACKEND_HYPRE)
# Not yet available
# add_test(NAME [poisson]poisson_hypre_direct_gpu COMMAND ./Poisson -A,AcceleratorRuntime=cuda Test.poisson.hypre_direct.arc)
add_test(NAME [poisson]poisson_hypre_direct_gpu COMMAND ./Poisson -A,AcceleratorRuntime=cuda Test.poisson.hypre_direct.arc)
endif()
endif()

Expand Down
2 changes: 1 addition & 1 deletion poisson/FemModule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ compute()
m_linear_system.reset();
m_linear_system.setLinearSystemFactory(options()->linearSystem());

m_linear_system.initialize(subDomain(), m_dofs_on_nodes.dofFamily(), "Solver");
m_linear_system.initialize(subDomain(), acceleratorMng()->defaultRunner(), m_dofs_on_nodes.dofFamily(), "Solver");
// Test for adding parameters for PETSc.
// This is only used for the first call.
{
Expand Down

0 comments on commit 38d4e03

Please sign in to comment.