FEAT 3
Finite Element Analysis Toolbox
Loading...
Searching...
No Matches
direct_sparse_solver_mkl.cpp
1// FEAT3: Finite Element Analysis Toolbox, Version 3
2// Copyright (C) 2010 by Stefan Turek & the FEAT group
3// FEAT3 is released under the GNU General Public License version 3,
4// see the file 'copyright.txt' in the top level directory for details.
5
6// includes, FEAT
8#include <kernel/backend.hpp>
9#include <kernel/solver/direct_sparse_solver.hpp>
10
11#ifdef FEAT_HAVE_MKL
12
13FEAT_DISABLE_WARNINGS
14#include <mkl_dss.h>
15#include <mkl_cluster_sparse_solver.h>
16FEAT_RESTORE_WARNINGS
17
18namespace FEAT
19{
20 namespace Solver
21 {
22 namespace DSS
23 {
      // Guard: the opaque index type exposed in the public header must have the same size as
      // MKL's integer type, otherwise the raw-array accessors below (get_mkldss_row_ptr /
      // get_mkldss_col_idx) would hand out misinterpreted memory.
      static_assert(sizeof(MKLDSS_IT) == sizeof(MKL_INT), "DirectSparseSolver: MKL-DSS: index type size mismatch!");
25
31 class MKLDSS_Core
32 {
33 public:
34#ifdef FEAT_HAVE_MPI
36 int mpi_comm;
38 void* css_handle[64];
40 MKL_INT css_iparm[64];
41#else // no MPI
43 void* dss_handle;
44#endif // FEAT_HAVE_MPI
46 MKL_INT matrix_type;
48 MKL_INT num_rhs;
50 MKL_INT num_global_dofs, dof_offset, num_owned_dofs, num_owned_nzes;
52 std::int64_t peak_mem_sym, peak_mem_num;
54 std::vector<MKL_INT> row_ptr, col_idx;
56 std::vector<double> mat_val, rhs_val, sol_val;
57
58 explicit MKLDSS_Core(const Dist::Comm& comm, Index num_global_dofs_, Index dof_offset_,
59 Index num_owned_dofs_, Index num_owned_nzes_, Index DOXY(num_global_nzes_)) :
60#ifdef FEAT_HAVE_MPI
61 mpi_comm(MPI_Comm_c2f(comm.mpi_comm())),
62#else
63 dss_handle(nullptr),
64#endif // FEAT_HAVE_MPI
65 matrix_type(11), // real unsymmmetric
66 num_rhs(1),
67 num_global_dofs(static_cast<MKL_INT>(num_global_dofs_)),
68 dof_offset(static_cast<MKL_INT>(dof_offset_)),
69 num_owned_dofs(static_cast<MKL_INT>(num_owned_dofs_)),
70 num_owned_nzes(static_cast<MKL_INT>(num_owned_nzes_)),
71 peak_mem_sym(0),
72 peak_mem_num(0)
73 {
74#ifdef FEAT_HAVE_MPI
75 memset(css_handle, 0, sizeof(void*)*64);
76 memset(css_iparm, 0, sizeof(MKL_INT)*64);
77
78 // set parameters; see
79 // https://www.intel.com/content/www/us/en/docs/onemkl/developer-reference-c/2025-2/cluster-sparse-solver-iparm-parameter.html
80 css_iparm[ 0] = 1; // override default parameters
81 css_iparm[ 1] = 10; // use MPI-parallel factorization
82 css_iparm[ 4] = 0; // let MKL handle the permutation
83 css_iparm[ 5] = 0; // write solution into sol vector
84 css_iparm[ 7] = 0; // default number of iterative refinement steps
85 css_iparm[ 9] = 13; // perturb small pivots by 10^{-13} (default value)
86 css_iparm[10] = 0; // disable diagonal scaling (would require values in symbolic asm)
87 css_iparm[11] = 0; // solve A*X=B
88 css_iparm[12] = 0; // disable weighted matching (would require values in symbolic asm)
89 css_iparm[17] = -1; // enable reporting of non-zero elements in pivots
90 css_iparm[26] = 1; // check matrix for errors
91 css_iparm[27] = 0; // data type is double precision
92 css_iparm[30] = 0; // no partial solve
93 css_iparm[34] = 1; // zero-based indexing (yes, '1' REALLY stands for zero-based and '0' for one-based indexing)
94 css_iparm[35] = 0; // something about Schur-complements...
95 css_iparm[36] = 0; // matrix storage is CSR
96 css_iparm[39] = 2; // distributed CSR and distributed vectors
97 css_iparm[40] = dof_offset;
98 css_iparm[41] = dof_offset + num_owned_dofs - 1;
99 css_iparm[59] = 0; // in-core mode
100#else // no MPI
101 (void)comm;
102 // create DSS handle
103 MKL_INT opt = MKL_DSS_TERM_LVL_ERROR + MKL_DSS_ZERO_BASED_INDEXING;
104#ifdef DEBUG
105 opt += MKL_DSS_MSG_LVL_WARNING;
106#endif
107 MKL_INT ret = ::dss_create(dss_handle, opt);
108
109 switch(ret)
110 {
111 case MKL_DSS_SUCCESS:
112 break;
113 case MKL_DSS_OUT_OF_MEMORY:
114 throw DirectSparseSolverException("MKL-DSS", "out of memory");
115 default:
116 throw DirectSparseSolverException("MKL-DSS", "unknown error");
117 }
118#endif // FEAT_HAVE_MPI
119
120 // allocate matrix and vector arrays
121 row_ptr.resize(std::size_t(num_owned_dofs+1), 0u);
122 col_idx.resize(std::size_t(num_owned_nzes), 0u);
123 mat_val.resize(std::size_t(num_owned_nzes), 0.0);
124 rhs_val.resize(std::size_t(num_owned_dofs), 0.0);
125 sol_val.resize(std::size_t(num_owned_dofs), 0.0);
126 }
127
128 ~MKLDSS_Core()
129 {
130#ifdef FEAT_HAVE_MPI
131 MKL_INT maxfct = 1; // must be set to 1 (is ignored)
132 MKL_INT mnum = 1; // must be set to 1 (is ignored)
133 MKL_INT phase = -1; // cleanup phase
134 MKL_INT msglvl = 0; // print statistics
135 MKL_INT error = 0; // error code output variable
136
137 // call solver routine to perform cleanup
138 cluster_sparse_solver(
139 css_handle,
140 &maxfct,
141 &mnum,
142 &matrix_type,
143 &phase,
144 &num_global_dofs,
145 mat_val.data(),
146 row_ptr.data(),
147 col_idx.data(),
148 nullptr,
149 &num_rhs,
150 css_iparm,
151 &msglvl,
152 rhs_val.data(),
153 sol_val.data(),
154 &mpi_comm,
155 &error);
156#else // no MPI
157 if(dss_handle)
158 {
159 MKL_INT opt = MKL_DSS_TERM_LVL_ERROR;
160#ifdef DEBUG
161 opt += MKL_DSS_MSG_LVL_WARNING;
162#endif
163 ::dss_delete(dss_handle, opt);
164 dss_handle = nullptr;
165 }
166#endif // FEAT_HAVE_MPI
167 }
168
169 void init_symbolic()
170 {
171#ifdef FEAT_HAVE_MPI
172 MKL_INT maxfct = 1; // must be set to 1 (is ignored)
173 MKL_INT mnum = 1; // must be set to 1 (is ignored)
174 MKL_INT phase = 11; // symbolical factorization phase
175 MKL_INT msglvl = 0; // print statistics
176 MKL_INT error = 0; // error code output variable
177
178 // call factorization routine
179 cluster_sparse_solver(
180 css_handle,
181 &maxfct,
182 &mnum,
183 &matrix_type,
184 &phase,
185 &num_global_dofs,
186 mat_val.data(),
187 row_ptr.data(),
188 col_idx.data(),
189 nullptr,
190 &num_rhs,
191 css_iparm,
192 &msglvl,
193 rhs_val.data(),
194 sol_val.data(),
195 &mpi_comm,
196 &error);
197
198 switch(error)
199 {
200 case 0: // no error
201 return;
202
203 case -2: // out of memory
204 throw DirectSparseSolverException("MKL-DSS", "out of memory");
205
206 default:
207 throw DirectSparseSolverException("MKL-DSS", "unknown symbolic factorization error");
208 }
209
210 // collect statistics
211 peak_mem_sym = css_iparm[14] * 1024ll; // size is given in kB
212 peak_mem_num = css_iparm[16] * 1024ll; // ditto
213#else // no MPI
214 // set matrix structure
215 MKL_INT opt = MKL_DSS_NON_SYMMETRIC;
216 MKL_INT ret = ::dss_define_structure(dss_handle, opt, row_ptr.data(), num_owned_dofs,
217 num_owned_dofs, col_idx.data(), num_owned_nzes);
218
219 switch(ret)
220 {
221 case MKL_DSS_SUCCESS:
222 break;
223 case MKL_DSS_OUT_OF_MEMORY:
224 throw DirectSparseSolverException("MKL-DSS", "out of memory");
225 case MKL_DSS_STRUCTURE_ERR:
226 throw DirectSparseSolverException("MKL-DSS", "invalid matrix structure");
227 default:
228 throw DirectSparseSolverException("MKL-DSS", "unknown error");
229 }
230
231 // reorder matrix
232 opt = MKL_DSS_AUTO_ORDER;
233 ret = ::dss_reorder(dss_handle, opt, nullptr);
234
235 switch(ret)
236 {
237 case MKL_DSS_SUCCESS:
238 break;
239 case MKL_DSS_OUT_OF_MEMORY:
240 throw DirectSparseSolverException("MKL-DSS", "out of memory");
241 default:
242 throw DirectSparseSolverException("MKL-DSS", "unknown error");
243 }
244
245 // collect statistics
246 opt = 0;
247 double stats[3] = {0.0, 0.0, 0.0};
248 dss_statistics(dss_handle, opt, "Peakmem,Factormem,Solvemem", stats);
249 peak_mem_sym = std::int64_t(stats[0] * 1024.0);
250 peak_mem_num = std::int64_t((stats[1]+stats[2]) * 1024.0);
251 //peak_mem_sym
252#endif // FEAT_HAVE_MPI
253 }
254
255 void init_numeric()
256 {
257#ifdef FEAT_HAVE_MPI
258 MKL_INT maxfct = 1; // must be set to 1 (is ignored)
259 MKL_INT mnum = 1; // must be set to 1 (is ignored)
260 MKL_INT phase = 22; // symbolical factorization phase
261 MKL_INT msglvl = 0; // print statistics
262 MKL_INT error = 0; // error code output variable
263
264 // call factorization routine
265 cluster_sparse_solver(
266 css_handle,
267 &maxfct,
268 &mnum,
269 &matrix_type,
270 &phase,
271 &num_global_dofs,
272 mat_val.data(),
273 row_ptr.data(),
274 col_idx.data(),
275 nullptr,
276 &num_rhs,
277 css_iparm,
278 &msglvl,
279 rhs_val.data(),
280 sol_val.data(),
281 &mpi_comm,
282 &error);
283
284 switch(error)
285 {
286 case 0: // no error
287 return;
288
289 case -2: // out of memory
290 throw DirectSparseSolverException("MKL-DSS", "out of memory");
291
292 case -4: // zero pivot
293 throw DirectSparseSolverException("MKL-DSS", "zero pivot");
294
295 default:
296 throw DirectSparseSolverException("MKL-DSS", "unknown numeric factorization error");
297 }
298#else // no MPI
299 MKL_INT opt = MKL_DSS_INDEFINITE;
300 MKL_INT ret = ::dss_factor_real(dss_handle, opt, mat_val.data());
301
302 switch(ret)
303 {
304 case MKL_DSS_SUCCESS:
305 break;
306 case MKL_DSS_OUT_OF_MEMORY:
307 throw DirectSparseSolverException("MKL-DSS", "out of memory");
308 case MKL_DSS_ZERO_PIVOT:
309 throw DirectSparseSolverException("MKL-DSS", "zero pivot");
310 default:
311 throw DirectSparseSolverException("MKL-DSS", "unknown error");
312 }
313#endif // FEAT_HAVE_MPI
314 }
315
316 void solve()
317 {
318#ifdef FEAT_HAVE_MPI
319 MKL_INT maxfct = 1; // must be set to 1 (is ignored)
320 MKL_INT mnum = 1; // must be set to 1 (is ignored)
321 MKL_INT phase = 33; // solution phase
322 MKL_INT msglvl = 0; // print statistics
323 MKL_INT error = 0; // error code output variable
324
325 // call factorization routine
326 cluster_sparse_solver(
327 css_handle,
328 &maxfct,
329 &mnum,
330 &matrix_type,
331 &phase,
332 &num_global_dofs,
333 mat_val.data(),
334 row_ptr.data(),
335 col_idx.data(),
336 nullptr,
337 &num_rhs,
338 css_iparm,
339 &msglvl,
340 rhs_val.data(),
341 sol_val.data(),
342 &mpi_comm,
343 &error);
344
345 switch(error)
346 {
347 case 0: // no error
348 return;
349
350 case -2: // out of memory
351 throw DirectSparseSolverException("MKL-DSS", "out of memory");
352
353 case -4: // zero pivot
354 throw DirectSparseSolverException("MKL-DSS", "zero pivot");
355
356 default:
357 throw DirectSparseSolverException("MKL-DSS", "unknown solve error");
358 }
359#else
360 MKL_INT opt = 0;
361 MKL_INT ret = ::dss_solve_real(dss_handle, opt, rhs_val.data(), num_rhs, sol_val.data());
362 switch(ret)
363 {
364 case MKL_DSS_SUCCESS:
365 break;
366 default:
367 throw DirectSparseSolverException("MKL-DSS", "unknown solve error");
368 }
369#endif // FEAT_HAVE_MPI
370 }
371
372 std::int64_t get_peak_mem() const
373 {
374 return Math::max(peak_mem_num, peak_mem_sym);
375 }
376 }; // class MKLDSS_Core
377
378 void* create_mkldss_core(const Dist::Comm* comm, Index num_global_dofs, Index dof_offset,
379 Index num_owned_dofs, Index num_owned_nzes, Index num_global_nzes)
380 {
381 return new MKLDSS_Core(*comm, num_global_dofs, dof_offset, num_owned_dofs, num_owned_nzes, num_global_nzes);
382 }
383
384 void destroy_mkldss_core(void* core)
385 {
386 XASSERT(core != nullptr);
387 delete reinterpret_cast<MKLDSS_Core*>(core);
388 }
389
390 MKLDSS_IT* get_mkldss_row_ptr(void* core)
391 {
392 XASSERT(core != nullptr);
393 return reinterpret_cast<MKLDSS_Core*>(core)->row_ptr.data();
394 }
395
396 MKLDSS_IT* get_mkldss_col_idx(void* core)
397 {
398 XASSERT(core != nullptr);
399 return reinterpret_cast<MKLDSS_Core*>(core)->col_idx.data();
400 }
401
402 MKLDSS_DT* get_mkldss_mat_val(void* core)
403 {
404 XASSERT(core != nullptr);
405 return reinterpret_cast<MKLDSS_Core*>(core)->mat_val.data();
406 }
407
408 MKLDSS_DT* get_mkldss_rhs_val(void* core)
409 {
410 XASSERT(core != nullptr);
411 return reinterpret_cast<MKLDSS_Core*>(core)->rhs_val.data();
412 }
413
414 MKLDSS_DT* get_mkldss_sol_val(void* core)
415 {
416 XASSERT(core != nullptr);
417 return reinterpret_cast<MKLDSS_Core*>(core)->sol_val.data();
418 }
419
420 void init_mkldss_symbolic(void* core)
421 {
422 XASSERT(core != nullptr);
423 reinterpret_cast<MKLDSS_Core*>(core)->init_symbolic();
424 }
425
426 void init_mkldss_numeric(void* core)
427 {
428 XASSERT(core != nullptr);
429 reinterpret_cast<MKLDSS_Core*>(core)->init_numeric();
430 }
431
432 void solve_mkldss(void* core)
433 {
434 XASSERT(core != nullptr);
435 reinterpret_cast<MKLDSS_Core*>(core)->solve();
436 }
437
438 std::int64_t get_peak_mem_mkldss(void* core)
439 {
440 XASSERT(core != nullptr);
441 return reinterpret_cast<MKLDSS_Core*>(core)->get_peak_mem();
442 }
443 } // namespace DSS
444 } // namespace Solver
445} // namespace FEAT
446
447#else // no FEAT_HAVE_MKL
448
// Placeholder symbol for builds without MKL support
// (presumably so this translation unit is never empty -- confirm against build system).
void feat_direct_sparse_solver_mkldss_dummy()
{
  // intentionally empty
}
452
453#endif // FEAT_HAVE_MKL
#define XASSERT(expr)
Assertion macro definition.
Definition: assertion.hpp:262
FEAT Kernel base header.
Status solve(SolverBase< Vector_ > &solver, Vector_ &vec_sol, const Vector_ &vec_rhs, const Matrix_ &matrix, const Filter_ &filter)
Solve linear system with initial solution guess.
Definition: base.hpp:347
FEAT namespace.
Definition: adjactor.hpp:12