Skip to content

Commit

Permalink
Merge pull request #84 from Maikashan/main
Browse files Browse the repository at this point in the history
GPU integration for assembly phase for poisson
  • Loading branch information
mohd-afeef-badri authored Nov 30, 2023
2 parents 90c9a4b + 8bdf70e commit 12a5b81
Show file tree
Hide file tree
Showing 10 changed files with 2,163 additions and 49 deletions.
1 change: 0 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
cmake_minimum_required(VERSION 3.21)
project(FemTest1Main LANGUAGES NONE)
enable_testing()

set(MSH_DIR ${CMAKE_SOURCE_DIR}/meshes/msh)

add_subdirectory(femutils)
Expand Down
6 changes: 6 additions & 0 deletions femutils/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,17 @@ project(FemTest1 LANGUAGES C CXX)

find_package(Arcane REQUIRED)

option(ENABLE_DEBUG_MATRIX "Enable Debug matrix instead of a sparse one" OFF)

add_library(FemUtils
FemUtils.h
FemUtils.cc
NodeLinearSystem.h
NodeLinearSystem.cc
DoFLinearSystem.h
DoFLinearSystem.cc
CooFormatMatrix.h
CsrFormatMatrix.h
FemDoFsOnNodes.h
FemDoFsOnNodes.cc
AlephNodeLinearSystem.cc
Expand All @@ -21,6 +25,8 @@ add_library(FemUtils
arcane_generate_axl(AlephDoFLinearSystemFactory)
arcane_generate_axl(SequentialBasicDoFLinearSystemFactory)

target_compile_definitions(FemUtils PRIVATE $<$<BOOL:${ENABLE_DEBUG_MATRIX}>:ENABLE_DEBUG_MATRIX>)

target_include_directories(FemUtils PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})
target_include_directories(FemUtils PRIVATE ${CMAKE_CURRENT_BINARY_DIR})

Expand Down
299 changes: 299 additions & 0 deletions femutils/CooFormatMatrix.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,299 @@
// -*- tab-width: 2; indent-tabs-mode: nil; coding: utf-8-with-signature -*-
//-----------------------------------------------------------------------------
// Copyright 2000-2023 CEA (www.cea.fr) IFPEN (www.ifpenergiesnouvelles.com)
// See the top-level COPYRIGHT file for details.
// SPDX-License-Identifier: Apache-2.0
//-----------------------------------------------------------------------------
/*---------------------------------------------------------------------------*/
/* CooFormatMatrix.h (C) 2022-2023 */
/* */
/*---------------------------------------------------------------------------*/
/*---------------------------------------------------------------------------*/

#include <arcane/utils/FatalErrorException.h>
#include <arcane/utils/NumArray.h>

#include <arcane/VariableTypes.h>
#include <arcane/IItemFamily.h>

#include <arcane/aleph/AlephTypesSolver.h>
#include <arcane/aleph/Aleph.h>

#include "FemUtils.h"
#include "DoFLinearSystem.h"
#include "arcane_version.h"

#include <iostream>
#include <fstream>

#include "arcane/accelerator/NumArrayViews.h"

namespace Arcane::FemUtils
{
using namespace Arcane;

namespace ax = Arcane::Accelerator;
class CooFormat : TraceAccessor
{
public:

CooFormat(ISubDomain* sd)
: TraceAccessor(sd->traceMng())
{
info() << "Creating CSR Matrix";
}

void initialize(IItemFamily* dof_family, Int32 nnz)
{
m_matrix_row.resize(nnz);
m_matrix_column.resize(nnz);
m_matrix_value.resize(nnz);
m_matrix_row.fill(0);
m_matrix_column.fill(0);
m_matrix_value.fill(0);
m_dof_family = dof_family;
m_last_value = 0;
m_nnz = nnz;
info() << "Filling COO Matrix with zeros";
}

/**
* @brief
*
* @param row
* @param column
* @param value
*/
void matrixAddValue(DoFLocalId row, DoFLocalId column, Real value)
{
if (row.isNull())
ARCANE_FATAL("Row is null");
if (column.isNull())
ARCANE_FATAL("Column is null");
if (value == 0.0)
return;
m_matrix_value(indexValue(row, column)) += value;
}

/**
* @brief Translate to Arcane linear system
*
* @param linear_system
*/
void translateToLinearSystem(DoFLinearSystem& linear_system)
{
for (Int32 i = 0; i < m_nnz; i++) {
linear_system.matrixSetValue(DoFLocalId(m_matrix_row(i)), DoFLocalId(m_matrix_column(i)), m_matrix_value(i));
}
}

/**
* @brief function to print the current content of the csr matrix
*
* @param fileName
* @param nonzero if set to true, print only the non zero values
*/
void
printMatrix(std::string fileName, bool nonzero)
{
ofstream file(fileName);
file << "size :" << m_matrix_row.extent0() << "\n";
for (auto i = 0; i < m_matrix_row.extent0(); i++) {
if (nonzero && m_matrix_value(i) == 0)
continue;
file << m_matrix_row(i) << " ";
}
file << "\n";
for (auto i = 0; i < m_nnz; i++) {
if (nonzero && m_matrix_value(i) == 0)
continue;
file << m_matrix_column(i) << " ";
}
file << "\n";
for (auto i = 0; i < m_nnz; i++) {
if (nonzero && m_matrix_value(i) == 0)
continue;
file << m_matrix_value(i) << " ";
}
file.close();
}

void setCoordinates(DoFLocalId row, DoFLocalId column)
{
m_matrix_row(m_last_value) = row.localId();
m_matrix_column(m_last_value) = column.localId();
m_last_value++;
}

void sort()
{
sortMatrix(true, 0, m_matrix_row.extent0() - 1);
Int32 begin = 0;
for (Int32 i = 0; i < m_matrix_row.extent0(); i++) {
if (i + 1 == m_matrix_row.extent0() || m_matrix_row(i + 1) != m_matrix_row(begin)) {
sortMatrix(false, begin, i);
begin = i + 1;
}
}
}

public:

Int32 m_nnz;
// To become parallelizable, have all the index
// inside a queue that would gradually pop ?
// or link the idnex to the index of the core ?
Int32 m_last_value;
NumArray<Int32, MDDim1> m_matrix_row;
NumArray<Int32, MDDim1> m_matrix_column;
NumArray<Real, MDDim1> m_matrix_value;
IItemFamily* m_dof_family = nullptr;

/*
getValue return the Value at the (row, column) coordinates.
*/
Int32 getValue(DoFLocalId row, DoFLocalId column)
{
return m_matrix_value(indexValue(row, column));
}

/**
* @brief binSearchRow is a binary search through the row to get the
* leftmost corresponding index.
*
* @param row
* @return Int32
*/
Int32 binSearchRow(Int32 row)
{
Int32 begin = 0;
Int32 end = m_matrix_row.totalNbElement() - 1;
while (begin <= end) {
Int32 mid = begin + (end - begin) / 2;
if (row == m_matrix_row(mid)) {
while (mid - 1 >= 0 && m_matrix_row(mid - 1) == row) {
mid--;
}
return mid;
}
if (row > m_matrix_row(mid)) {
begin = mid + 1;
}
if (row < m_matrix_row(mid)) {
end = mid - 1;
}
}
return -1;
}

/**
* @brief indexValue is a Binsearch through the row and then the column
* to get the index of the corresponding value.
*
* @param row
* @param column
* @return Int32
*/
Int32 indexValue(Int32 row, Int32 column)
{

Int32 i = binSearchRow(row);
while (i != m_matrix_row.totalNbElement() && m_matrix_row(i) == row) {
if (m_matrix_column(i) == column)
return i;
i++;
}
//binsearch only on the row and iterate through the column
/*
while (begin <= end) {
/*
Int32 mid = begin + (end - begin) / 2;
if (column == m_matrix_column(mid)) {
return mid;
}
if (column > m_matrix_column(mid)) {
begin = mid + 1;
}
if (column < m_matrix_column(mid)) {
end = mid - 1;
}
}
*/
return -1;
}

/**
* @brief Quicksort algorithm for the CSR Matrix
*
* @param is_row
* @param start
* @param end
*/
void
sortMatrix(bool is_row, Int32 start, Int32 end)
{
if (start >= end) {
return;
}

int pivot = partition(is_row, start, end);

sortMatrix(is_row, start, pivot - 1);

sortMatrix(is_row, pivot + 1, end);
}

/**
* @brief Partition helper for the quickSort
*
* @param is_row
* @param start
* @param end
* @return Int32
*/
Int32 partition(bool is_row, Int32 start, Int32 end)
{
Int32 pivot;
if (is_row)
pivot = m_matrix_row[end];
else
pivot = m_matrix_column[end];

Int32 pIndex = start;

for (Int32 i = start; i < end; i++) {
if ((is_row && m_matrix_row[i] <= pivot) || (!is_row && m_matrix_column[i] <= pivot)) {

swap(is_row, i, pIndex);
pIndex++;
}
}

swap(is_row, pIndex, end);

return pIndex;
}

/**
* @brief Swap helper for the quickSort
*
* @param is_row
* @param i
* @param j
*/
void swap(bool is_row, Int32 i, Int32 j)
{
if (is_row) {
Int32 tmp = m_matrix_row(i);
m_matrix_row(i) = m_matrix_row(j);
m_matrix_row(j) = tmp;
}
Int32 tmp = m_matrix_column(i);
m_matrix_column(i) = m_matrix_column(j);
m_matrix_column(j) = tmp;
Real tmp_val = m_matrix_value(i);
m_matrix_value(i) = m_matrix_value(j);
m_matrix_value(j) = tmp_val;
}
};
} // namespace Arcane::FemUtils
Loading

0 comments on commit 12a5b81

Please sign in to comment.