FEAT 3
Finite Element Analysis Toolbox
Loading...
Searching...
No Matches
direct_sparse_solver_superlu.cpp
1// FEAT3: Finite Element Analysis Toolbox, Version 3
2// Copyright (C) 2010 by Stefan Turek & the FEAT group
3// FEAT3 is released under the GNU General Public License version 3,
4// see the file 'copyright.txt' in the top level directory for details.
5
6// includes, FEAT
8#include <kernel/backend.hpp>
9#include <kernel/solver/direct_sparse_solver.hpp>
10
11#if defined(FEAT_HAVE_SUPERLU_DIST) && defined(FEAT_HAVE_MPI)
12
13FEAT_DISABLE_WARNINGS
14#define VTUNE 0
15#include <superlu_ddefs.h>
16FEAT_RESTORE_WARNINGS
17
18#include <vector>
19
20namespace FEAT
21{
22 namespace Solver
23 {
24 namespace DSS
25 {
26 static_assert(sizeof(SUPERLU_IT) == sizeof(int_t), "DirectSparseSolver: SuperLU: index type size mismatch!");
27
33 class SuperLU_Core
34 {
35 public:
37 MPI_Comm _comm;
39 gridinfo_t slu_grid;
41 superlu_dist_options_t slu_opts;
43 SuperLUStat_t slu_stats;
45 SuperMatrix slu_matrix;
47 NRformat_loc slu_matrix_store;
49 dScalePermstruct_t slu_scale_perm;
51 dLUstruct_t slu_lu_struct;
53 dSOLVEstruct_t slu_solve_struct;
55 int slu_info;
56 const int_t slu_dof_offset, slu_num_owned_dofs, slu_num_global_dofs, slu_num_nonzeros;
57
60 std::vector<int_t> slu_row_ptr, slu_col_idx, slu_col_idx2;
62 std::vector<double> slu_mat_val;
64 std::vector<double> slu_vector;
66 std::vector<double> slu_v_berr;
67
70 bool was_factorized;
71
72 explicit SuperLU_Core(const Dist::Comm& comm, Index num_global_dofs, Index my_dof_offset,
73 Index num_owned_dofs, Index num_owned_nzes, Index DOXY(num_global_nzes)) :
74 _comm(comm.mpi_comm()),
75 slu_info(0),
76 slu_dof_offset(static_cast<int_t>(my_dof_offset)),
77 slu_num_owned_dofs(static_cast<int_t>(num_owned_dofs)),
78 slu_num_global_dofs(static_cast<int_t>(num_global_dofs)),
79 slu_num_nonzeros(static_cast<int_t>(num_owned_nzes)),
80 was_factorized(false)
81 {
82 // first of all, memclear all SuperLU structures just to be safe
83 memset(&slu_grid, 0, sizeof(gridinfo_t));
84 memset(&slu_opts, 0, sizeof(superlu_dist_options_t));
85 memset(&slu_stats, 0, sizeof(SuperLUStat_t));
86 memset(&slu_matrix, 0, sizeof(SuperMatrix));
87 memset(&slu_matrix_store, 0, sizeof(NRformat_loc));
88 memset(&slu_scale_perm, 0, sizeof(dScalePermstruct_t));
89 memset(&slu_lu_struct, 0, sizeof(dLUstruct_t));
90 memset(&slu_solve_struct, 0, sizeof(dSOLVEstruct_t));
91
92 // set up grid
93 int ranks(0);
94 MPI_Comm_size(_comm, &ranks);
95 superlu_gridinit(_comm, ranks, 1, &slu_grid);
96
97 // allocate matrix and vector arrays
98 slu_row_ptr.resize(num_owned_dofs+1u);
99 slu_col_idx.resize(num_owned_nzes);
100 slu_col_idx2.resize(num_owned_nzes);
101 slu_mat_val.resize(num_owned_nzes);
102 slu_vector.resize(num_owned_dofs);
103 slu_v_berr.resize(num_owned_dofs);
104
105 // setup SuperLU matrix
106 slu_matrix.Stype = SLU_NR_loc; // CSR, local
107 slu_matrix.Dtype = SLU_D; // double
108 slu_matrix.Mtype = SLU_GE; // generic matrix
109 slu_matrix.nrow = slu_num_global_dofs;
110 slu_matrix.ncol = slu_num_global_dofs; // always square
111 slu_matrix.Store = &slu_matrix_store;
112
113 // setup SuperLU matrix storage
114 slu_matrix_store.nnz_loc = slu_num_nonzeros;
115 slu_matrix_store.m_loc = slu_num_owned_dofs;
116 slu_matrix_store.fst_row = slu_dof_offset;
117 slu_matrix_store.nzval = slu_mat_val.data();
118 slu_matrix_store.rowptr = slu_row_ptr.data();
119 slu_matrix_store.colind = slu_col_idx2.data(); // use copy here
120 }
121
122 ~SuperLU_Core()
123 {
124 PStatFree(&slu_stats);
125 dLUstructFree(&slu_lu_struct);
126 dScalePermstructFree(&slu_scale_perm);
127 dSolveFinalize(&slu_opts, &slu_solve_struct);
128 superlu_gridexit(&slu_grid);
129 }
130
131 void init_symbolic()
132 {
133 // copy column indices arrays, because SuperLU permutes them
134 memcpy(slu_col_idx2.data(), slu_col_idx.data(), sizeof(int_t)*slu_col_idx.size());
135
136 // set up default options
137 set_default_options_dist(&slu_opts);
138
139 // don't print statistics to cout
140 slu_opts.PrintStat = NO;
141
142 // replace tiny pivots
143 slu_opts.ReplaceTinyPivot = YES;
144
145 // initialize scale perm structure
146 dScalePermstructInit(slu_num_global_dofs, slu_num_global_dofs, &slu_scale_perm);
147
148 // initialize LU structure
149 dLUstructInit(slu_num_global_dofs, &slu_lu_struct);
150
151 // initialize the statistics
152 PStatInit(&slu_stats);
153 }
154
155 void init_numeric()
156 {
157 // reset column-index array because it might have been
158 // overwritten by the previous factorization call
159 memcpy(slu_col_idx2.data(), slu_col_idx.data(), sizeof(int_t)*slu_col_idx.size());
160
161 // have we already factorized before?
162 slu_opts.Fact = (was_factorized ? SamePattern : DOFACT);
163
164 // call the solver routine to factorize the matrix
165 pdgssvx(
166 &slu_opts,
167 &slu_matrix,
168 &slu_scale_perm,
169 nullptr,
170 int(slu_num_owned_dofs),
171 0,
172 &slu_grid,
173 &slu_lu_struct,
174 &slu_solve_struct,
175 slu_v_berr.data(),
176 &slu_stats,
177 &slu_info);
178
179 // factorization successful?
180 if(slu_info == 0)
181 {
182 // okay, remember that we have factorized
183 was_factorized = true;
184 }
185 else if(slu_info < 0)
186 {
187 throw InternalError(__func__, __FILE__, __LINE__, "invalid argument for SuperLU pdgssvx");
188 }
189 else if(slu_info < int(slu_num_global_dofs))
190 {
191 throw DirectSparseSolverException("SuperLU", "zero pivot");
192 }
193 else
194 {
195 throw DirectSparseSolverException("SuperLU", "out of memory");
196 }
197 }
198
199 void done_numeric()
200 {
201 dDestroy_LU(slu_num_global_dofs, &slu_grid, &slu_lu_struct);
202 }
203
204 void solve()
205 {
206 // matrix was already factorized in set_matrix_values() call
207 slu_opts.Fact = FACTORED;
208
209 // call the solver routine
210 pdgssvx(
211 &slu_opts,
212 &slu_matrix,
213 &slu_scale_perm,
214 slu_vector.data(),
215 int(slu_num_owned_dofs),
216 1,
217 &slu_grid,
218 &slu_lu_struct,
219 &slu_solve_struct,
220 slu_v_berr.data(),
221 &slu_stats,
222 &slu_info);
223
224 // check return value
225 if(slu_info == 0)
226 {
227 // ok, everthing's fine
228 }
229 else if(slu_info < 0)
230 {
231 throw InternalError(__func__, __FILE__, __LINE__, "invalid argument for SuperLU pdgssvx");
232 }
233 else if(slu_info < int(slu_num_global_dofs))
234 {
235 throw DirectSparseSolverException("SuperLU", "zero pivot");
236 }
237 else
238 {
239 throw DirectSparseSolverException("SuperLU", "out of memory");
240 }
241 }
242 }; // class SuperLU_Core
243
244 void* create_superlu_core(const Dist::Comm* comm, Index num_global_dofs, Index dof_offset,
245 Index num_owned_dofs, Index num_owned_nzes, Index num_global_nzes)
246 {
247 return new SuperLU_Core(*comm, num_global_dofs, dof_offset, num_owned_dofs, num_owned_nzes, num_global_nzes);
248 }
249
250 void destroy_superlu_core(void* core)
251 {
252 XASSERT(core != nullptr);
253 delete reinterpret_cast<SuperLU_Core*>(core);
254 }
255
256 SUPERLU_IT* get_superlu_row_ptr(void* core)
257 {
258 XASSERT(core != nullptr);
259 return reinterpret_cast<SuperLU_Core*>(core)->slu_row_ptr.data();
260 }
261
262 SUPERLU_IT* get_superlu_col_idx(void* core)
263 {
264 XASSERT(core != nullptr);
265 return reinterpret_cast<SuperLU_Core*>(core)->slu_col_idx.data();
266 }
267
268 SUPERLU_DT* get_superlu_mat_val(void* core)
269 {
270 XASSERT(core != nullptr);
271 return reinterpret_cast<SuperLU_Core*>(core)->slu_mat_val.data();
272 }
273
274 SUPERLU_DT* get_superlu_vector(void* core)
275 {
276 XASSERT(core != nullptr);
277 return reinterpret_cast<SuperLU_Core*>(core)->slu_vector.data();
278 }
279
280 void init_superlu_symbolic(void* core)
281 {
282 XASSERT(core != nullptr);
283 reinterpret_cast<SuperLU_Core*>(core)->init_symbolic();
284 }
285
286 void init_superlu_numeric(void* core)
287 {
288 XASSERT(core != nullptr);
289 reinterpret_cast<SuperLU_Core*>(core)->init_numeric();
290 }
291
292 void done_superlu_numeric(void* core)
293 {
294 XASSERT(core != nullptr);
295 reinterpret_cast<SuperLU_Core*>(core)->done_numeric();
296 }
297
298 void solve_superlu(void* core)
299 {
300 XASSERT(core != nullptr);
301 reinterpret_cast<SuperLU_Core*>(core)->solve();
302 }
303 } // namespace DSS
304 } // namespace Solver
305} // namespace FEAT
306
307#else // no FEAT_HAVE_SUPERLU_DIST
308
// Empty placeholder function for builds without SuperLU_DIST/MPI support
// (presumably present so that this translation unit is not empty -- TODO confirm).
void feat_direct_sparse_solver_superlu_dummy()
{
  // intentionally left empty
}
312
313#endif // FEAT_HAVE_SUPERLU_DIST
#define XASSERT(expr)
Assertion macro definition.
Definition: assertion.hpp:262
FEAT Kernel base header.
Status solve(SolverBase< Vector_ > &solver, Vector_ &vec_sol, const Vector_ &vec_rhs, const Matrix_ &matrix, const Filter_ &filter)
Solve linear system with initial solution guess.
Definition: base.hpp:347
FEAT namespace.
Definition: adjactor.hpp:12