FEAT 3
Finite Element Analysis Toolbox
Loading...
Searching...
No Matches
direct_sparse_solver_cudss.cpp
1// FEAT3: Finite Element Analysis Toolbox, Version 3
2// Copyright (C) 2010 by Stefan Turek & the FEAT group
3// FEAT3 is released under the GNU General Public License version 3,
4// see the file 'copyright.txt' in the top level directory for details.
5
6// includes, FEAT
8#include <kernel/backend.hpp>
9#include <kernel/util/memory_pool.hpp>
10#include <kernel/util/cuda_util.hpp>
11#include <kernel/solver/direct_sparse_solver.hpp>
12
13#ifdef FEAT_HAVE_CUDSS
14
15FEAT_DISABLE_WARNINGS
16#include <cudss.h>
17FEAT_RESTORE_WARNINGS
18
19namespace FEAT
20{
21 namespace Solver
22 {
23 namespace DSS
24 {
25 static_assert(sizeof(CUDSS_IT) == 4, "DirectSparseSolver: cuDSS: index type size mismatch!");
26
32 class CUDSS_Core
33 {
34 public:
35#ifdef FEAT_HAVE_MPI
36 MPI_Comm mpi_comm;
37#endif
38 cudssHandle_t handle;
39 cudssConfig_t config;
40 cudssData_t data;
41 cudssMatrix_t matrix;
42 cudssMatrix_t vec_sol;
43 cudssMatrix_t vec_rhs;
44
45 // peak memory usage estimate
46 std::int64_t memory_estimates[16];
47
48 // system dimensions
49 std::int64_t num_global_dofs, dof_offset, num_owned_dofs, num_owned_nzes, num_global_nzes, num_rhs;
50
51 // system data in host memory
52 std::vector<CUDSS_IT> row_ptr_host, col_idx_host;
53 std::vector<CUDSS_DT> mat_val_host, rhs_val_host, sol_val_host;
54
55 // system data in device memory
56 void *row_ptr_dev, *col_idx_dev;
57 void *mat_val_dev, *rhs_val_dev, *sol_val_dev;
58
59 explicit CUDSS_Core(const Dist::Comm& comm, Index num_global_dofs_, Index dof_offset_,
60 Index num_owned_dofs_, Index num_owned_nzes_, Index num_global_nzes_) :
61#ifdef FEAT_HAVE_MPI
62 mpi_comm(comm.mpi_comm()),
63#endif
64 handle(reinterpret_cast<cudssHandle_t>(Runtime::get_cudss_handle())),
65 config(nullptr),
66 data(nullptr),
67 matrix(nullptr),
68 vec_sol(nullptr),
69 vec_rhs(nullptr),
70 num_global_dofs(std::uint32_t(num_global_dofs_)),
71 dof_offset(std::uint32_t(dof_offset_)),
72 num_owned_dofs(std::uint32_t(num_owned_dofs_)),
73 num_owned_nzes(std::uint32_t(num_owned_nzes_)),
74 num_global_nzes(std::uint32_t(num_global_nzes_)),
75 num_rhs(1),
76 row_ptr_host(),
77 col_idx_host(),
78 mat_val_host(),
79 rhs_val_host(),
80 sol_val_host(),
81 row_ptr_dev(nullptr),
82 col_idx_dev(nullptr),
83 mat_val_dev(nullptr),
84 rhs_val_dev(nullptr),
85 sol_val_dev(nullptr)
86 {
87 // prevent unused parameter warnings
88 (void)comm;
89
90 XASSERTM(handle != nullptr, "Failed to retrieve cuDSS handle!");
91
92 if(CUDSS_STATUS_SUCCESS != cudssConfigCreate(&config))
93 throw InternalError(__func__, __FILE__, __LINE__, "cudssConfigCreate failed!");
94
95 if(CUDSS_STATUS_SUCCESS != cudssDataCreate(handle, &data))
96 throw InternalError(__func__, __FILE__, __LINE__, "cudssDataCreate failed!");
97
98#ifdef FEAT_HAVE_MPI
99 if(CUDSS_STATUS_SUCCESS != cudssDataSet(handle, data, CUDSS_DATA_COMM, &mpi_comm, sizeof(MPI_Comm*)))
100 throw InternalError(__func__, __FILE__, __LINE__, "cudssDataSet for 'CUDSS_DATA_COMM' failed!");
101#endif
102
103 // format memory estimates
104 memset(memory_estimates, 0, sizeof(memory_estimates));
105
106 // allocate matrix and vector arrays in host memory
107 row_ptr_host.resize(std::size_t(num_owned_dofs+1), 0u);
108 col_idx_host.resize(std::size_t(num_owned_nzes), 0u);
109 mat_val_host.resize(std::size_t(num_owned_nzes), 0.0);
110 rhs_val_host.resize(std::size_t(num_owned_dofs * num_rhs), 0.0);
111 sol_val_host.resize(std::size_t(num_owned_dofs * num_rhs), 0.0);
112
113 // allocate matrix and vector arrays in device memory
114 row_ptr_dev = Util::cuda_malloc(sizeof(CUDSS_IT) * std::size_t(num_owned_dofs+1));
115 col_idx_dev = Util::cuda_malloc(sizeof(CUDSS_IT) * std::size_t(num_owned_nzes));
116 mat_val_dev = Util::cuda_malloc(sizeof(CUDSS_DT) * std::size_t(num_owned_nzes));
117 rhs_val_dev = Util::cuda_malloc(sizeof(CUDSS_DT) * std::size_t(num_owned_dofs * num_rhs));
118 sol_val_dev = Util::cuda_malloc(sizeof(CUDSS_DT) * std::size_t(num_owned_dofs * num_rhs));
119 }
120
121 ~CUDSS_Core()
122 {
123 if(vec_sol)
124 cudssMatrixDestroy(vec_sol);
125 if(vec_rhs)
126 cudssMatrixDestroy(vec_rhs);
127 if(matrix)
128 cudssMatrixDestroy(matrix);
129
130 if(sol_val_dev)
131 Util::cuda_free(sol_val_dev);
132 if(rhs_val_dev)
133 Util::cuda_free(rhs_val_dev);
134 if(mat_val_dev)
135 Util::cuda_free(mat_val_dev);
136 if(col_idx_dev)
137 Util::cuda_free(col_idx_dev);
138 if(row_ptr_dev)
139 Util::cuda_free(row_ptr_dev);
140 if(data)
141 cudssDataDestroy(handle, data);
142 if(config)
143 cudssConfigDestroy(config);
144
145 // wait until CUDA is done
146 Util::cuda_synchronize();
147 }
148
149 void init_symbolic()
150 {
151 cudssStatus_t ret = CUDSS_STATUS_INTERNAL_ERROR;
152
153 // first and last DOF owned by this process
154 const std::int64_t last_owned_dof = dof_offset + num_owned_dofs - 1;
155
156 // copy structure host arrays to device
157 Util::cuda_copy_host_to_device(row_ptr_dev, row_ptr_host.data(), sizeof(CUDSS_IT) * std::size_t(num_owned_dofs+1));
158 Util::cuda_copy_host_to_device(col_idx_dev, col_idx_host.data(), sizeof(CUDSS_IT) * std::size_t(num_owned_nzes));
159
160 // set the basic configuration
161 //cudssAlgType_t alg_type = CUDSS_ALG_DEFAULT;//CUDSS_ALG_1; // COLAMD based ordering
162 //ret = cudssConfigSet(config, CUDSS_CONFIG_REORDERING_ALG, &alg_type, sizeof(cudssAlgType_t));
163 //if(ret != CUDSS_STATUS_SUCCESS)
164 //throw DirectSparseSolverException("cuDSS", "cudssConfigSet() for 'CUDSS_CONFIG_REORDERING_ALG' failed!");
165
166 // set matrix data
167 ret = cudssMatrixCreateCsr(
168 &matrix,
169 num_global_dofs,
170 num_global_dofs,
171 num_global_nzes,
172 row_ptr_dev,
173 nullptr,
174 col_idx_dev,
175 mat_val_dev,
176 CUDA_R_32I,
177 CUDA_R_64F,
178 CUDSS_MTYPE_GENERAL,
179 CUDSS_MVIEW_FULL, // is ignored
180 CUDSS_BASE_ZERO);
181 if(ret != CUDSS_STATUS_SUCCESS)
182 throw DirectSparseSolverException("cuDSS", "cudssMatrixCreateCsr() for system matrix failed!");
183
184 // set row distribution of matrix
185 ret = cudssMatrixSetDistributionRow1d(matrix, dof_offset, last_owned_dof);
186 if(ret != CUDSS_STATUS_SUCCESS)
187 throw DirectSparseSolverException("cuDSS", "cudssMatrixSetDistributionRow1d() for solution vector failed!");
188
189 // allocate solution vector
190 ret = cudssMatrixCreateDn(
191 &vec_sol,
192 num_global_dofs,
193 num_rhs,
194 num_global_dofs,
195 nullptr,
196 CUDA_R_64F,
197 CUDSS_LAYOUT_COL_MAJOR);
198 if(ret != CUDSS_STATUS_SUCCESS)
199 throw DirectSparseSolverException("cuDSS", "cudssMatrixCreateDn() for solution vector failed!");
200
201 // set row distribution of solution vector
202 ret = cudssMatrixSetDistributionRow1d(vec_sol, dof_offset, last_owned_dof);
203 if(ret != CUDSS_STATUS_SUCCESS)
204 throw DirectSparseSolverException("cuDSS", "cudssMatrixSetDistributionRow1d() for solution vector failed!");
205
206 // allocate rhs vector
207 ret = cudssMatrixCreateDn(
208 &vec_rhs,
209 num_global_dofs,
210 num_rhs,
211 num_global_dofs,
212 nullptr,
213 CUDA_R_64F,
214 CUDSS_LAYOUT_COL_MAJOR);
215 if(ret != CUDSS_STATUS_SUCCESS)
216 throw DirectSparseSolverException("cuDSS", "cudssMatrixCreateDn() for rhs vector failed!");
217
218 // set row distribution of solution vector
219 ret = cudssMatrixSetDistributionRow1d(vec_rhs, dof_offset, last_owned_dof);
220 if(ret != CUDSS_STATUS_SUCCESS)
221 throw DirectSparseSolverException("cuDSS", "cudssMatrixSetDistributionRow1d() for rhs vector failed!");
222
223 // perform symbolic factorization
224 ret = cudssExecute(
225 handle,
226 CUDSS_PHASE_ANALYSIS,
227 config,
228 data,
229 matrix,
230 vec_sol,
231 vec_rhs);
232 if(ret != CUDSS_STATUS_SUCCESS)
233 throw DirectSparseSolverException("cuDSS", "cudssExecute() for phase 'CUDSS_PHASE_ANALYSIS' failed!");
234
235 // retrieve memory estimates for current device
236 std::size_t bytes_written(0u);
237 ret = cudssDataGet(
238 handle,
239 data,
240 CUDSS_DATA_MEMORY_ESTIMATES,
241 memory_estimates,
242 sizeof(memory_estimates),
243 &bytes_written);
244 if(ret != CUDSS_STATUS_SUCCESS)
245 throw DirectSparseSolverException("cuDSS", "cudssDataGet() for 'CUDSS_DATA_MEMORY_ESTIMATES' failed!");
246
247 // wait until CUDA is done
248 Util::cuda_synchronize();
249 }
250
251 void init_numeric()
252 {
253 // copy value host arrays to device
254 Util::cuda_copy_host_to_device(mat_val_dev, mat_val_host.data(), sizeof(CUDSS_DT) * std::size_t(num_owned_nzes));
255
256 // perform numeric factorization
257 cudssStatus_t ret = cudssExecute(
258 handle,
259 CUDSS_PHASE_FACTORIZATION,
260 config,
261 data,
262 matrix,
263 vec_sol,
264 vec_rhs);
265 if(ret != CUDSS_STATUS_SUCCESS)
266 throw DirectSparseSolverException("cuDSS", "cudssExecute() for phase 'CUDSS_PHASE_FACTORIZATION' failed!");
267
268 // wait until CUDA is done
269 Util::cuda_synchronize();
270 }
271
272 void solve()
273 {
274 cudssStatus_t ret = CUDSS_STATUS_INTERNAL_ERROR;
275
276 // copy RHS value array to device
277 Util::cuda_copy_host_to_device(rhs_val_dev, rhs_val_host.data(), sizeof(CUDSS_DT) * std::size_t(num_owned_dofs * num_rhs));
278
279 // set solution vector data array
280 ret = cudssMatrixSetValues(vec_sol, sol_val_dev);
281 if(ret != CUDSS_STATUS_SUCCESS)
282 throw DirectSparseSolverException("cuDSS", "cudssMatrixSetValues() failed for vec_sol!");
283
284 // set rhs vector data array
285 cudssMatrixSetValues(vec_rhs, rhs_val_dev);
286 if(ret != CUDSS_STATUS_SUCCESS)
287 throw DirectSparseSolverException("cuDSS", "cudssMatrixSetValues() failed for vec_rhs!");
288
289 // solve
290 ret = cudssExecute(
291 handle,
292 CUDSS_PHASE_SOLVE,
293 config,
294 data,
295 matrix,
296 vec_sol,
297 vec_rhs);
298 if(ret != CUDSS_STATUS_SUCCESS)
299 throw DirectSparseSolverException("cuDSS", "cudssExecute() for phase 'CUDSS_PHASE_SOLVE' failed!");
300
301 // copy sol values array to host
302 Util::cuda_copy_device_to_host(sol_val_host.data(), sol_val_dev, sizeof(CUDSS_DT) * std::size_t(num_owned_dofs * num_rhs));
303
304 // wait until CUDA is done
305 Util::cuda_synchronize();
306 }
307
308 std::int64_t get_peak_mem_device() const
309 {
310 return memory_estimates[1];
311 }
312
313 std::int64_t get_peak_mem_host() const
314 {
315 return memory_estimates[3];
316 }
317 }; // class CUDSS_Core
318
319 void* create_cudss_core(const Dist::Comm* comm, Index num_global_dofs, Index dof_offset,
320 Index num_owned_dofs, Index num_owned_nzes, Index num_global_nzes)
321 {
322 return new CUDSS_Core(*comm, num_global_dofs, dof_offset, num_owned_dofs, num_owned_nzes, num_global_nzes);
323 }
324
325 void destroy_cudss_core(void* core)
326 {
327 XASSERT(core != nullptr);
328 delete reinterpret_cast<CUDSS_Core*>(core);
329 }
330
331 CUDSS_IT* get_cudss_row_ptr(void* core)
332 {
333 XASSERT(core != nullptr);
334 return reinterpret_cast<CUDSS_Core*>(core)->row_ptr_host.data();
335 }
336
337 CUDSS_IT* get_cudss_col_idx(void* core)
338 {
339 XASSERT(core != nullptr);
340 return reinterpret_cast<CUDSS_Core*>(core)->col_idx_host.data();
341 }
342
343 CUDSS_DT* get_cudss_mat_val(void* core)
344 {
345 XASSERT(core != nullptr);
346 return reinterpret_cast<CUDSS_Core*>(core)->mat_val_host.data();
347 }
348
349 CUDSS_DT* get_cudss_rhs_val(void* core)
350 {
351 XASSERT(core != nullptr);
352 return reinterpret_cast<CUDSS_Core*>(core)->rhs_val_host.data();
353 }
354
355 CUDSS_DT* get_cudss_sol_val(void* core)
356 {
357 XASSERT(core != nullptr);
358 return reinterpret_cast<CUDSS_Core*>(core)->sol_val_host.data();
359 }
360
361 void init_cudss_symbolic(void* core)
362 {
363 XASSERT(core != nullptr);
364 reinterpret_cast<CUDSS_Core*>(core)->init_symbolic();
365 }
366
367 void init_cudss_numeric(void* core)
368 {
369 XASSERT(core != nullptr);
370 reinterpret_cast<CUDSS_Core*>(core)->init_numeric();
371 }
372
373 void solve_cudss(void* core)
374 {
375 XASSERT(core != nullptr);
376 reinterpret_cast<CUDSS_Core*>(core)->solve();
377 }
378
379 std::int64_t get_peak_mem_cudss_host(void* core)
380 {
381 XASSERT(core != nullptr);
382 return reinterpret_cast<CUDSS_Core*>(core)->get_peak_mem_host();
383 }
384
385 std::int64_t get_peak_mem_cudss_device(void* core)
386 {
387 XASSERT(core != nullptr);
388 return reinterpret_cast<CUDSS_Core*>(core)->get_peak_mem_device();
389 }
390 } // namespace DSS
391 } // namespace Solver
392} // namespace FEAT
393
394#else // no FEAT_HAVE_CUDSS
395
// dummy function that keeps this translation unit non-empty when FEAT is
// built without cuDSS support (some toolchains warn about empty objects)
void feat_direct_sparse_solver_cudss_dummy()
{
}
399
400#endif // FEAT_HAVE_CUDSS
#define XASSERT(expr)
Assertion macro definition.
Definition: assertion.hpp:262
#define XASSERTM(expr, msg)
Assertion macro definition with custom message.
Definition: assertion.hpp:263
FEAT Kernel base header.
Status solve(SolverBase< Vector_ > &solver, Vector_ &vec_sol, const Vector_ &vec_rhs, const Matrix_ &matrix, const Filter_ &filter)
Solve linear system with initial solution guess.
Definition: base.hpp:347
FEAT namespace.
Definition: adjactor.hpp:12