FEAT 3
Finite Element Analysis Toolbox
Loading...
Searching...
No Matches
direct_sparse_solver_mumps.cpp
1// FEAT3: Finite Element Analysis Toolbox, Version 3
2// Copyright (C) 2010 by Stefan Turek & the FEAT group
3// FEAT3 is released under the GNU General Public License version 3,
4// see the file 'copyright.txt' in the top level directory for details.
5
6// includes, FEAT
8#include <kernel/backend.hpp>
9#include <kernel/solver/direct_sparse_solver.hpp>
10#include <kernel/util/omp_util.hpp>
11
12#ifdef FEAT_HAVE_MUMPS
13
14#ifdef FEAT_HAVE_MPI
15#include <mpi.h>
16#endif // FEAT_HAVE_MPI
17
18#include <dmumps_c.h>
19
20#include <vector>
21
22namespace FEAT
23{
24 namespace Solver
25 {
26 namespace DSS
27 {
33 class MUMPS_Core
34 {
35 public:
36 const Dist::Comm& comm;
37#ifdef FEAT_HAVE_MPI
39 MPI_Comm _comm;
40#endif
42 DMUMPS_STRUC_C id;
43
44 const MUMPS_INT num_global_dofs, dof_offset, num_owned_dofs, num_owned_nzes, num_global_nzes;
45 MUMPS_INT max_owned_dofs;
46
48 std::vector<std::int64_t> row_ptr, col_idx, vec_ptr;
50 std::vector<MUMPS_INT> mumps_row_idx, mumps_col_idx;
52 std::vector<double> mat_val;
54 std::vector<double> vector;
55
56 explicit MUMPS_Core(const Dist::Comm& comm_, Index num_global_dofs_, Index dof_offset_,
57 Index num_owned_dofs_, Index num_owned_nzes_, Index num_global_nzes_) :
58 comm(comm_),
59 num_global_dofs(num_global_dofs_),
60 dof_offset(dof_offset_),
61 num_owned_dofs(num_owned_dofs_),
62 num_owned_nzes(num_owned_nzes_),
63 num_global_nzes(num_global_nzes_),
64 max_owned_dofs(num_owned_dofs_),
65 row_ptr(num_owned_dofs_+1),
66 col_idx(num_owned_nzes_),
67 mumps_row_idx(num_owned_nzes_),
68 mumps_col_idx(num_owned_nzes_),
69 mat_val(num_owned_nzes_),
70 vector()
71 {
72 // initialize the MUMPS data structure
73 memset(&id, 0, sizeof(DMUMPS_STRUC_C));
74#ifdef FEAT_HAVE_MPI
75 id.comm_fortran = (MUMPS_INT)MPI_Comm_c2f(comm.mpi_comm());
76#endif
77 id.sym = 0; // unsymmetric matrix
78 id.par = 1; // host working
79
80 // initialize MUMPS
81 id.job = -1;
82 dmumps_c(&id);
83
84 // print errors in debug mode and always keep quiet in opt mode
85#ifdef DEBUG
86 id.icntl[4-1] = 1; // only error messages
87#else
88 id.icntl[4-1] = 0; // no messages
89#endif
90
91#ifdef FEAT_HAVE_MPI
92 std::size_t num_procs = std::size_t(comm.size());
93 // allocate vector pointer on rank 0
94 if(comm.rank() == 0)
95 vec_ptr.resize(num_procs + 1u, 0);
96
97 // gather local vector sizes on rank 0
98 std::int64_t loc_size(num_owned_dofs);
99 comm.gather(&loc_size, std::size_t(1), vec_ptr.data(), std::size_t(1), 0);
100
101 if(comm.rank() == 0)
102 {
103 // perform exclusive scan to obtain vector pointer array
104 feat_omp_ex_scan(num_procs + 1u, vec_ptr.data(), vec_ptr.data());
105
106 // ensure that the number of global dofs adds up
107 XASSERT(vec_ptr.back() == std::int64_t(num_global_dofs));
108
109 // allocate global vector only on rank 0
110 vector.resize(num_global_dofs);
111 }
112 else
113 {
114 vector.resize(num_owned_dofs);
115 }
116#else // no FEAT_HAVE_MPI
117 vector.resize(num_global_dofs);
118#endif // FEAT_HAVE_MPI
119 }
120
121 ~MUMPS_Core()
122 {
123 // release MUMPS
124 id.job = -2;
125 dmumps_c(&id);
126 }
127
128 void init_symbolic()
129 {
130 // convert 0-based CSR index arrays to 1-based COO index arrays
131 FEAT_PRAGMA_OMP(parallel for)
132 for(std::int64_t i = 0; i < num_owned_dofs; ++i)
133 {
134 for(std::int64_t j = row_ptr[i]; j < row_ptr[i+1]; ++j)
135 {
136 mumps_row_idx[j] = static_cast<MUMPS_INT>(dof_offset + i + 1);
137 mumps_col_idx[j] = static_cast<MUMPS_INT>(col_idx[j] + 1);
138 }
139 }
140
141 // analyze matrix
142 id.job = 1;
143 // note: the -1 is to account for the 1-based indexing in the MUMPS documentation
144 id.icntl[ 5 - 1] = 0; // assembled matrix format
145 id.icntl[ 6 - 1] = 0; // column permutation not available for distributed matrices
146 id.icntl[18 - 1] = 3; // matrix distributed by user
147 id.icntl[20 - 1] = 0; // centratized RHS
148 id.icntl[21 - 1] = 0; // centralized solution
149 id.n = (MUMPS_INT)num_global_dofs;
150 id.nnz_loc = (MUMPS_INT)num_owned_nzes;
151 id.irn_loc = mumps_row_idx.data();
152 id.jcn_loc = mumps_col_idx.data();
153 id.a_loc = mat_val.data();
154 id.nrhs = 1;
155 id.lrhs = (MUMPS_INT)num_global_dofs;
156 id.rhs = vector.data();
157 dmumps_c(&id);
158
159 if(id.infog[0] == 0)
160 return;
161
162 throw DirectSparseSolverException("mumps",
163 "MUMPS Symbolic factorization error with INFO(1) = " + stringify(id.info[0]) + " and INFO(2) = " + stringify(id.info[1]));
164 }
165
166 void init_numeric()
167 {
168 // factorize matrix
169 id.job = 2;
170 dmumps_c(&id);
171
172 if(id.infog[0] == 0)
173 return;
174
175 throw DirectSparseSolverException("mumps",
176 "MUMPS Numeric factorization error with INFO(1) = " + stringify(id.info[0]) + " and INFO(2) = " + stringify(id.info[1]));
177 }
178
179 void done_numeric()
180 {
181 // release MUMPS factorization
182 id.job = -4;
183 dmumps_c(&id);
184 }
185
186 void solve()
187 {
188#ifdef FEAT_HAVE_MPI
189 // gather vector on rank 0
190 if(comm.size() > 1)
191 {
192 if(comm.rank() == 0)
193 {
194 // post receives
195 Dist::RequestVector reqs(comm.size());
196 for(int i = 1; i < comm.size(); ++i)
197 reqs[i] = comm.irecv(&vector[vec_ptr[i]], std::size_t(vec_ptr[i+1] - vec_ptr[i]), i);
198
199 // process all receives
200 reqs.wait_all();
201 }
202 else
203 {
204 // send our local vector to rank 0
205 comm.send(vector.data(), vector.size(), 0);
206 }
207 }
208#endif // FEAT_HAVE_MPI
209
210 // solve system
211 id.job = 3;
212 dmumps_c(&id);
213
214#ifdef FEAT_HAVE_MPI
215 // scatter vector from rank 0
216 if(comm.size() > 1)
217 {
218 if(comm.rank() == 0)
219 {
220 // post sends
221 Dist::RequestVector reqs(comm.size());
222 for(int i = 1; i < comm.size(); ++i)
223 reqs[i] = comm.isend(&vector[vec_ptr[i]], std::size_t(vec_ptr[i+1] - vec_ptr[i]), i);
224
225 // process all sens
226 reqs.wait_all();
227 }
228 else
229 {
230 // receive our local vector from rank 0
231 comm.recv(vector.data(), vector.size(), 0);
232 }
233 }
234#endif // FEAT_HAVE_MPI
235 }
236 }; // class MUMPS_Core
237
238 void* create_mumps_core(const Dist::Comm* comm, Index num_global_dofs, Index dof_offset,
239 Index num_owned_dofs, Index num_owned_nzes, Index num_global_nzes)
240 {
241 return new MUMPS_Core(*comm, num_global_dofs, dof_offset, num_owned_dofs, num_owned_nzes, num_global_nzes);
242 }
243
244 void destroy_mumps_core(void* core)
245 {
246 XASSERT(core != nullptr);
247 delete reinterpret_cast<MUMPS_Core*>(core);
248 }
249
250 MUMPS_IT* get_mumps_row_ptr(void* core)
251 {
252 XASSERT(core != nullptr);
253 return reinterpret_cast<MUMPS_Core*>(core)->row_ptr.data();
254 }
255
256 MUMPS_IT* get_mumps_col_idx(void* core)
257 {
258 XASSERT(core != nullptr);
259 return reinterpret_cast<MUMPS_Core*>(core)->col_idx.data();
260 }
261
262 MUMPS_DT* get_mumps_mat_val(void* core)
263 {
264 XASSERT(core != nullptr);
265 return reinterpret_cast<MUMPS_Core*>(core)->mat_val.data();
266 }
267
268 MUMPS_DT* get_mumps_vector(void* core)
269 {
270 XASSERT(core != nullptr);
271 return reinterpret_cast<MUMPS_Core*>(core)->vector.data();
272 }
273
274 void init_mumps_symbolic(void* core)
275 {
276 XASSERT(core != nullptr);
277 reinterpret_cast<MUMPS_Core*>(core)->init_symbolic();
278 }
279
280 void init_mumps_numeric(void* core)
281 {
282 XASSERT(core != nullptr);
283 reinterpret_cast<MUMPS_Core*>(core)->init_numeric();
284 }
285
286 void done_mumps_numeric(void* core)
287 {
288 XASSERT(core != nullptr);
289 reinterpret_cast<MUMPS_Core*>(core)->done_numeric();
290 }
291
292 void solve_mumps(void* core)
293 {
294 XASSERT(core != nullptr);
295 reinterpret_cast<MUMPS_Core*>(core)->solve();
296 }
297 } // namespace DSS
298 } // namespace Solver
299} // namespace FEAT
300
301#else // no FEAT_HAVE_MUMPS
302
/// Empty placeholder function; this is the only definition in this file if FEAT_HAVE_MUMPS is not defined.
void feat_direct_sparse_solver_mumps_dummy()
{
  // intentionally left empty
}
306
307#endif // FEAT_HAVE_MUMPS
#define XASSERT(expr)
Assertion macro definition.
Definition: assertion.hpp:262
FEAT Kernel base header.
Status solve(SolverBase< Vector_ > &solver, Vector_ &vec_sol, const Vector_ &vec_rhs, const Matrix_ &matrix, const Filter_ &filter)
Solve linear system with initial solution guess.
Definition: base.hpp:347
FEAT namespace.
Definition: adjactor.hpp:12
void feat_omp_ex_scan(std::size_t n, const T_ x[], T_ y[])
Computes an OpenMP-parallel exclusive scan (a.k.a. prefix sum) of an array.
Definition: omp_util.hpp:153
String stringify(const T_ &item)
Converts an item into a String.
Definition: string.hpp:993