8#include <kernel/backend.hpp>
9#include <kernel/solver/direct_sparse_solver.hpp>
10#include <kernel/util/omp_util.hpp>
// Communicator used for the gather / send / recv exchanges between ranks.
36 const Dist::Comm& comm;
// Problem dimensions, stored in the MUMPS integer type: global DOF count, index of
// the first owned DOF, number of owned DOFs, and owned / global non-zero counts.
44 const MUMPS_INT num_global_dofs, dof_offset, num_owned_dofs, num_owned_nzes, num_global_nzes;
// Initialized to the local owned DOF count by the constructor.
// NOTE(review): any later update to a global maximum is not visible in this excerpt.
45 MUMPS_INT max_owned_dofs;
// row_ptr / col_idx: 64-bit row-pointer and column-index arrays of the owned matrix rows
// (row_ptr[i]..row_ptr[i+1] delimits the entries of owned row i).
// vec_ptr: one entry per rank plus one; vec_ptr[i]..vec_ptr[i+1] delimits rank i's slice of 'vector'.
48 std::vector<std::int64_t> row_ptr, col_idx, vec_ptr;
// Row / column indices of every owned non-zero in the form handed to MUMPS (irn_loc / jcn_loc).
50 std::vector<MUMPS_INT> mumps_row_idx, mumps_col_idx;
// Values of the owned non-zeros, handed to MUMPS as a_loc.
52 std::vector<double> mat_val;
// Right-hand-side buffer handed to MUMPS as rhs.
54 std::vector<double> vector;
// Constructs the core object: stores the problem dimensions, allocates the local
// matrix arrays and zero-initializes the MUMPS control structure 'id'.
//
// comm_            communicator the solver runs on
// num_global_dofs_ total number of DOFs over all ranks
// dof_offset_      global index of the first DOF owned by this rank
// num_owned_dofs_  number of DOFs owned by this rank
// num_owned_nzes_  number of matrix non-zeros owned by this rank
// num_global_nzes_ total number of matrix non-zeros over all ranks
//
// NOTE(review): the Index arguments are converted to MUMPS_INT members here; if Index is
// wider than MUMPS_INT this narrows silently -- confirm the sizes are range-checked elsewhere.
56 explicit MUMPS_Core(
const Dist::Comm& comm_, Index num_global_dofs_, Index dof_offset_,
57 Index num_owned_dofs_, Index num_owned_nzes_, Index num_global_nzes_) :
59 num_global_dofs(num_global_dofs_),
60 dof_offset(dof_offset_),
61 num_owned_dofs(num_owned_dofs_),
62 num_owned_nzes(num_owned_nzes_),
63 num_global_nzes(num_global_nzes_),
// starts out as the local count
64 max_owned_dofs(num_owned_dofs_),
// one row pointer per owned row plus the terminating entry
65 row_ptr(num_owned_dofs_+1),
// one column index, one MUMPS row/column index pair and one value per owned non-zero
66 col_idx(num_owned_nzes_),
67 mumps_row_idx(num_owned_nzes_),
68 mumps_col_idx(num_owned_nzes_),
69 mat_val(num_owned_nzes_),
// clear the whole MUMPS structure before any field is set
73 memset(&
id, 0,
sizeof(DMUMPS_STRUC_C));
// MUMPS takes a Fortran communicator handle, so the C handle is converted first
75 id.comm_fortran = (MUMPS_INT)MPI_Comm_c2f(comm.mpi_comm());
// Builds the per-rank layout of the right-hand-side buffer.
// NOTE(review): the enclosing function header and several lines (branch conditions,
// braces) are missing from this excerpt; the comments below cover only what is visible.
92 std::size_t num_procs = std::size_t(comm.size());
// one entry per rank plus one, all zero
95 vec_ptr.resize(num_procs + 1u, 0);
// contribute this rank's owned DOF count as a 64-bit value
98 std::int64_t loc_size(num_owned_dofs);
// collect one count per rank into vec_ptr on rank 0
99 comm.gather(&loc_size, std::size_t(1), vec_ptr.data(), std::size_t(1), 0);
// the last entry must equal the global DOF count
// (presumably after a prefix sum over the counts on a line not shown here -- TODO confirm)
107 XASSERT(vec_ptr.back() == std::int64_t(num_global_dofs));
// NOTE(review): the conditions selecting between the following resizes are not visible here;
// the buffer is sized either for all global DOFs or for the owned DOFs only.
110 vector.resize(num_global_dofs);
114 vector.resize(num_owned_dofs);
117 vector.resize(num_global_dofs);
// Converts the owned rows from row-pointer/column-index form into the coordinate form
// handed to MUMPS, configures the MUMPS input and reports a failed analysis.
// NOTE(review): presumably the body of init_symbolic(); the function header, the MUMPS call
// itself and the error condition are not visible in this excerpt.
131 FEAT_PRAGMA_OMP(parallel
for)
132 for(std::int64_t i = 0; i < num_owned_dofs; ++i)
134 for(std::int64_t j = row_ptr[i]; j < row_ptr[i+1]; ++j)
// row index: local row i shifted by this rank's DOF offset, plus 1 (MUMPS indices are 1-based)
136 mumps_row_idx[j] =
static_cast<MUMPS_INT
>(dof_offset + i + 1);
// column index: stored column index plus 1 (presumably col_idx already holds global columns -- TODO confirm)
137 mumps_col_idx[j] =
static_cast<MUMPS_INT
>(col_idx[j] + 1);
// Control parameters; the ICNTL(k) of the MUMPS documentation is icntl[k-1] in C.
// Meanings below are taken from the MUMPS user guide -- verify against the MUMPS version in use.
// ICNTL(5) = 0: matrix is given in assembled format
144 id.icntl[ 5 - 1] = 0;
// ICNTL(6) = 0: no column permutation is computed
145 id.icntl[ 6 - 1] = 0;
// ICNTL(18) = 3: distributed assembled matrix, each rank supplies its local entries
146 id.icntl[18 - 1] = 3;
// ICNTL(20) = 0: dense right-hand side
147 id.icntl[20 - 1] = 0;
// ICNTL(21) = 0: solution is returned centralized
148 id.icntl[21 - 1] = 0;
// global system size and this rank's local triplets
149 id.n = (MUMPS_INT)num_global_dofs;
150 id.nnz_loc = (MUMPS_INT)num_owned_nzes;
151 id.irn_loc = mumps_row_idx.data();
152 id.jcn_loc = mumps_col_idx.data();
153 id.a_loc = mat_val.data();
// right-hand side: leading dimension and buffer
155 id.lrhs = (MUMPS_INT)num_global_dofs;
156 id.rhs = vector.data();
// report the MUMPS status codes INFO(1) and INFO(2) in the exception message
162 throw DirectSparseSolverException(
"mumps",
163 "MUMPS Symbolic factorization error with INFO(1) = " +
stringify(
id.info[0]) +
" and INFO(2) = " +
stringify(
id.info[1]));
// Reports a failed numeric factorization together with the MUMPS status codes
// INFO(1) and INFO(2).
// NOTE(review): presumably part of init_numeric(); the guarding condition and the
// MUMPS call that precede this throw are not visible in this excerpt.
175 throw DirectSparseSolverException(
"mumps",
176 "MUMPS Numeric factorization error with INFO(1) = " +
stringify(
id.info[0]) +
" and INFO(2) = " +
stringify(
id.info[1]));
// Moves the vector between rank 0 and the other ranks, before and after the solve.
// NOTE(review): presumably part of solve(); the function header, the rank branches,
// the request waits and the MUMPS call itself are not visible in this excerpt.
// Collect phase: post one non-blocking receive per rank i >= 1 into that
// rank's slice [vec_ptr[i], vec_ptr[i+1]) of the buffer; reqs[0] stays unused.
195 Dist::RequestVector reqs(comm.size());
196 for(
int i = 1; i < comm.size(); ++i)
197 reqs[i] = comm.irecv(&vector[vec_ptr[i]], std::size_t(vec_ptr[i+1] - vec_ptr[i]), i);
// counterpart: send the whole local buffer to rank 0
205 comm.send(vector.data(), vector.size(), 0);
// Distribute phase: post one non-blocking send per rank i >= 1 from that rank's slice.
221 Dist::RequestVector reqs(comm.size());
222 for(
int i = 1; i < comm.size(); ++i)
223 reqs[i] = comm.isend(&vector[vec_ptr[i]], std::size_t(vec_ptr[i+1] - vec_ptr[i]), i);
// counterpart: receive the whole local buffer from rank 0
231 comm.recv(vector.data(), vector.size(), 0);
238 void* create_mumps_core(
const Dist::Comm* comm, Index num_global_dofs, Index dof_offset,
239 Index num_owned_dofs, Index num_owned_nzes, Index num_global_nzes)
241 return new MUMPS_Core(*comm, num_global_dofs, dof_offset, num_owned_dofs, num_owned_nzes, num_global_nzes);
244 void destroy_mumps_core(
void* core)
247 delete reinterpret_cast<MUMPS_Core*
>(core);
250 MUMPS_IT* get_mumps_row_ptr(
void* core)
253 return reinterpret_cast<MUMPS_Core*
>(core)->row_ptr.data();
256 MUMPS_IT* get_mumps_col_idx(
void* core)
259 return reinterpret_cast<MUMPS_Core*
>(core)->col_idx.data();
262 MUMPS_DT* get_mumps_mat_val(
void* core)
265 return reinterpret_cast<MUMPS_Core*
>(core)->mat_val.data();
268 MUMPS_DT* get_mumps_vector(
void* core)
271 return reinterpret_cast<MUMPS_Core*
>(core)->vector.data();
274 void init_mumps_symbolic(
void* core)
277 reinterpret_cast<MUMPS_Core*
>(core)->init_symbolic();
280 void init_mumps_numeric(
void* core)
283 reinterpret_cast<MUMPS_Core*
>(core)->init_numeric();
286 void done_mumps_numeric(
void* core)
289 reinterpret_cast<MUMPS_Core*
>(core)->done_numeric();
292 void solve_mumps(
void* core)
295 reinterpret_cast<MUMPS_Core*
>(core)->
solve();
// NOTE(review): presumably a placeholder function that keeps this translation unit
// non-empty when MUMPS support is compiled out; its body is not visible in this excerpt -- TODO confirm.
303void feat_direct_sparse_solver_mumps_dummy()
#define XASSERT(expr)
Assertion macro definition.
Status solve(SolverBase< Vector_ > &solver, Vector_ &vec_sol, const Vector_ &vec_rhs, const Matrix_ &matrix, const Filter_ &filter)
Solve linear system with initial solution guess.
void feat_omp_ex_scan(std::size_t n, const T_ x[], T_ y[])
Computes an OpenMP-parallel exclusive scan (a.k.a. a prefix sum) of an array, i.e. y[0] = 0 and y[k] = x[0] + ... + x[k-1] for k > 0.
String stringify(const T_ &item)
Converts an item into a String.