8#include <kernel/backend.hpp>
9#include <kernel/solver/direct_sparse_solver.hpp>
11#if defined(FEAT_HAVE_SUPERLU_DIST) && defined(FEAT_HAVE_MPI)
15#include <superlu_ddefs.h>
// Compile-time guard: the solver-facing index type SUPERLU_IT must have the same
// size as SuperLU's own int_t, because the accessor functions in this file hand
// out the int_t index arrays through SUPERLU_IT* pointers.
26 static_assert(
sizeof(SUPERLU_IT) ==
sizeof(int_t),
"DirectSparseSolver: SuperLU: index type size mismatch!");
// SuperLU option set; filled by set_default_options_dist and then adjusted
// (PrintStat, ReplaceTinyPivot, Fact) by the setup/solve code below.
41 superlu_dist_options_t slu_opts;
// SuperLU statistics object (PStatInit / PStatFree).
43 SuperLUStat_t slu_stats;
// SuperLU matrix descriptor; its Store pointer refers to slu_matrix_store.
45 SuperMatrix slu_matrix;
// Storage descriptor for the locally owned rows (nnz_loc, m_loc, fst_row and
// the three array pointers are set in the constructor).
47 NRformat_loc slu_matrix_store;
// SuperLU scaling/permutation structure (dScalePermstructInit / ...Free).
49 dScalePermstruct_t slu_scale_perm;
// SuperLU LU factor structure (dLUstructInit / dLUstructFree / dDestroy_LU).
51 dLUstruct_t slu_lu_struct;
// SuperLU solve structure (released via dSolveFinalize).
53 dSOLVEstruct_t slu_solve_struct;
// Problem dimensions converted to int_t once in the constructor:
// first owned dof, number of owned dofs, number of global dofs, number of owned non-zeros.
56 const int_t slu_dof_offset, slu_num_owned_dofs, slu_num_global_dofs, slu_num_nonzeros;
// Row pointer and column index arrays of the owned rows. slu_col_idx2 is the
// array actually passed to SuperLU as colind; it is refreshed from slu_col_idx
// via memcpy before use.
// NOTE(review): presumably SuperLU modifies colind in place, hence the
// pristine copy in slu_col_idx -- confirm.
60 std::vector<int_t> slu_row_ptr, slu_col_idx, slu_col_idx2;
// Matrix values of the owned rows (one entry per owned non-zero).
62 std::vector<double> slu_mat_val;
// Vector buffer with one entry per owned dof, exposed via get_superlu_vector.
64 std::vector<double> slu_vector;
// Buffer with one entry per owned dof.
// NOTE(review): name suggests the backward-error output of pdgssvx -- confirm.
66 std::vector<double> slu_v_berr;
// Constructs the SuperLU core: stores the problem dimensions as int_t, zeroes
// all SuperLU structures, creates the process grid, allocates the matrix/vector
// arrays and wires them into the SuperLU matrix descriptor.
//
// comm            - communicator whose MPI handle is used for the process grid
// num_global_dofs - global number of rows/columns of the (square) matrix
// my_dof_offset   - index of the first row owned by this process
// num_owned_dofs  - number of rows owned by this process
// num_owned_nzes  - number of non-zero entries in the owned rows
// num_global_nzes - not used by this constructor (name hidden via DOXY)
72 explicit SuperLU_Core(
const Dist::Comm& comm, Index num_global_dofs, Index my_dof_offset,
73 Index num_owned_dofs, Index num_owned_nzes, Index DOXY(num_global_nzes)) :
74 _comm(comm.mpi_comm()),
// NOTE(review): Index -> int_t conversions are unchecked; values exceeding the
// int_t range would be truncated -- confirm callers guarantee the range.
76 slu_dof_offset(static_cast<int_t>(my_dof_offset)),
77 slu_num_owned_dofs(static_cast<int_t>(num_owned_dofs)),
78 slu_num_global_dofs(static_cast<int_t>(num_global_dofs)),
79 slu_num_nonzeros(static_cast<int_t>(num_owned_nzes)),
// Zero every SuperLU structure so that no field is left uninitialised.
83 memset(&slu_grid, 0,
sizeof(gridinfo_t));
84 memset(&slu_opts, 0,
sizeof(superlu_dist_options_t));
85 memset(&slu_stats, 0,
sizeof(SuperLUStat_t));
86 memset(&slu_matrix, 0,
sizeof(SuperMatrix));
87 memset(&slu_matrix_store, 0,
sizeof(NRformat_loc));
88 memset(&slu_scale_perm, 0,
sizeof(dScalePermstruct_t));
89 memset(&slu_lu_struct, 0,
sizeof(dLUstruct_t));
90 memset(&slu_solve_struct, 0,
sizeof(dSOLVEstruct_t));
// Query the communicator size and set up the SuperLU process grid with
// 'ranks' passed as the row count and 1 as the column count.
94 MPI_Comm_size(_comm, &ranks);
95 superlu_gridinit(_comm, ranks, 1, &slu_grid);
// Allocate the matrix arrays (row pointer has one extra trailing entry) and
// the per-dof vector buffers.
98 slu_row_ptr.resize(num_owned_dofs+1u);
99 slu_col_idx.resize(num_owned_nzes);
100 slu_col_idx2.resize(num_owned_nzes);
101 slu_mat_val.resize(num_owned_nzes);
102 slu_vector.resize(num_owned_dofs);
103 slu_v_berr.resize(num_owned_dofs);
// Describe the matrix to SuperLU: SLU_NR_loc storage, double values (SLU_D),
// general type (SLU_GE), square with the global dof count as dimension.
106 slu_matrix.Stype = SLU_NR_loc;
107 slu_matrix.Dtype = SLU_D;
108 slu_matrix.Mtype = SLU_GE;
109 slu_matrix.nrow = slu_num_global_dofs;
110 slu_matrix.ncol = slu_num_global_dofs;
111 slu_matrix.Store = &slu_matrix_store;
// Local storage description; the array pointers alias the std::vector buffers
// owned by this object. Note that colind points to slu_col_idx2, not slu_col_idx.
114 slu_matrix_store.nnz_loc = slu_num_nonzeros;
115 slu_matrix_store.m_loc = slu_num_owned_dofs;
116 slu_matrix_store.fst_row = slu_dof_offset;
117 slu_matrix_store.nzval = slu_mat_val.data();
118 slu_matrix_store.rowptr = slu_row_ptr.data();
119 slu_matrix_store.colind = slu_col_idx2.data();
// Tear-down sequence: release the statistics object, the LU structure, the
// scale/permutation structure and the solve structure, then leave the process
// grid last.
// NOTE(review): PStatInit, dLUstructInit and dScalePermstructInit are only
// called from the symbolic setup sequence; if this teardown can run without
// that setup having happened, the Free calls operate on memset-zeroed
// structures -- confirm that SuperLU tolerates this.
124 PStatFree(&slu_stats);
125 dLUstructFree(&slu_lu_struct);
126 dScalePermstructFree(&slu_scale_perm);
127 dSolveFinalize(&slu_opts, &slu_solve_struct);
128 superlu_gridexit(&slu_grid);
// Copy the column indices into the secondary array that SuperLU receives as colind.
134 memcpy(slu_col_idx2.data(), slu_col_idx.data(),
sizeof(int_t)*slu_col_idx.size());
// Start from SuperLU's default option set ...
137 set_default_options_dist(&slu_opts);
// ... then disable statistics printing ...
140 slu_opts.PrintStat = NO;
// ... and enable replacement of tiny pivots.
143 slu_opts.ReplaceTinyPivot = YES;
// Initialise the scale/permutation structure for a square system of global size.
146 dScalePermstructInit(slu_num_global_dofs, slu_num_global_dofs, &slu_scale_perm);
// Initialise the LU structure for the global system size.
149 dLUstructInit(slu_num_global_dofs, &slu_lu_struct);
// Initialise the statistics object.
152 PStatInit(&slu_stats);
// Refresh the column-index array handed to SuperLU from the pristine copy
// before (re-)factorising.
159 memcpy(slu_col_idx2.data(), slu_col_idx.data(),
sizeof(int_t)*slu_col_idx.size());
// Request SamePattern if a factorisation has already been computed once,
// otherwise a factorisation from scratch (DOFACT).
162 slu_opts.Fact = (was_factorized ? SamePattern : DOFACT);
170 int(slu_num_owned_dofs),
// Success: remember that a factorisation exists for subsequent calls.
183 was_factorized =
true;
// info < 0: an argument passed to pdgssvx had an illegal value.
185 else if(slu_info < 0)
187 throw InternalError(__func__, __FILE__, __LINE__,
"invalid argument for SuperLU pdgssvx");
// pdgssvx documents 0 < info <= ncol as "U(info,info) is exactly zero", so the
// upper bound is inclusive: info == ncol is a zero pivot in the last column and
// must not fall through to the out-of-memory branch.
189 else if(slu_info <=
int(slu_num_global_dofs))
191 throw DirectSparseSolverException(
"SuperLU",
"zero pivot");
// info > ncol: memory allocation failure reported by pdgssvx.
195 throw DirectSparseSolverException(
"SuperLU",
"out of memory");
// Release the numerical LU factor data held in slu_lu_struct for the global
// system size on this process grid.
201 dDestroy_LU(slu_num_global_dofs, &slu_grid, &slu_lu_struct);
// Tell SuperLU that the matrix is already factorised (Fact = FACTORED).
207 slu_opts.Fact = FACTORED;
215 int(slu_num_owned_dofs),
// info < 0: an argument passed to pdgssvx had an illegal value.
229 else if(slu_info < 0)
231 throw InternalError(__func__, __FILE__, __LINE__,
"invalid argument for SuperLU pdgssvx");
// pdgssvx documents 0 < info <= ncol as "U(info,info) is exactly zero", so the
// upper bound is inclusive: info == ncol is a zero pivot in the last column and
// must not be reported as out of memory.
233 else if(slu_info <=
int(slu_num_global_dofs))
235 throw DirectSparseSolverException(
"SuperLU",
"zero pivot");
// info > ncol: memory allocation failure reported by pdgssvx.
239 throw DirectSparseSolverException(
"SuperLU",
"out of memory");
// Allocates a SuperLU_Core with the given communicator and dimensions and
// returns it as an opaque void* handle. The object is created with 'new';
// destroy_superlu_core deletes such a handle. 'comm' is dereferenced
// unconditionally and therefore must not be null.
244 void* create_superlu_core(
const Dist::Comm* comm, Index num_global_dofs, Index dof_offset,
245 Index num_owned_dofs, Index num_owned_nzes, Index num_global_nzes)
247 return new SuperLU_Core(*comm, num_global_dofs, dof_offset, num_owned_dofs, num_owned_nzes, num_global_nzes);
// Deletes the SuperLU_Core object behind the opaque handle 'core'.
250 void destroy_superlu_core(
void* core)
253 delete reinterpret_cast<SuperLU_Core*
>(core);
// Returns a pointer to the storage of the core's row-pointer array (slu_row_ptr).
// The pointer refers to memory owned by the core object.
256 SUPERLU_IT* get_superlu_row_ptr(
void* core)
259 return reinterpret_cast<SuperLU_Core*
>(core)->slu_row_ptr.data();
// Returns a pointer to the storage of the core's column-index array (slu_col_idx),
// i.e. the pristine copy, not the slu_col_idx2 array passed to SuperLU.
262 SUPERLU_IT* get_superlu_col_idx(
void* core)
265 return reinterpret_cast<SuperLU_Core*
>(core)->slu_col_idx.data();
// Returns a pointer to the storage of the core's matrix value array (slu_mat_val).
268 SUPERLU_DT* get_superlu_mat_val(
void* core)
271 return reinterpret_cast<SuperLU_Core*
>(core)->slu_mat_val.data();
// Returns a pointer to the storage of the core's vector buffer (slu_vector),
// which holds one entry per owned dof.
274 SUPERLU_DT* get_superlu_vector(
void* core)
277 return reinterpret_cast<SuperLU_Core*
>(core)->slu_vector.data();
// Forwards to SuperLU_Core::init_symbolic() on the object behind the opaque handle.
280 void init_superlu_symbolic(
void* core)
283 reinterpret_cast<SuperLU_Core*
>(core)->init_symbolic();
// Forwards to SuperLU_Core::init_numeric() on the object behind the opaque handle.
286 void init_superlu_numeric(
void* core)
289 reinterpret_cast<SuperLU_Core*
>(core)->init_numeric();
// Forwards to SuperLU_Core::done_numeric() on the object behind the opaque handle.
292 void done_superlu_numeric(
void* core)
295 reinterpret_cast<SuperLU_Core*
>(core)->done_numeric();
// Forwards to SuperLU_Core::solve() on the object behind the opaque handle.
298 void solve_superlu(
void* core)
301 reinterpret_cast<SuperLU_Core*
>(core)->
solve();
309void feat_direct_sparse_solver_superlu_dummy()
#define XASSERT(expr)
Assertion macro definition.
Status solve(SolverBase< Vector_ > &solver, Vector_ &vec_sol, const Vector_ &vec_rhs, const Matrix_ &matrix, const Filter_ &filter)
Solve linear system with initial solution guess.