FEAT 3
Finite Element Analysis Toolbox
Loading...
Searching...
No Matches
product_matmat.cu
1// FEAT3: Finite Element Analysis Toolbox, Version 3
2// Copyright (C) 2010 by Stefan Turek & the FEAT group
3// FEAT3 is released under the GNU General Public License version 3,
4// see the file 'copyright.txt' in the top level directory for details.
5
6// includes, FEAT
7#include <kernel/base_header.hpp>
8#include <kernel/lafem/arch/product_matmat.hpp>
9#include <kernel/util/exception.hpp>
10#include <kernel/util/memory_pool.hpp>
11#include <kernel/util/half.hpp>
12
13#include <cublas_v2.h>
14#include <cublasLt.h>
15#include <cusparse_v2.h>
16
17using namespace FEAT;
18using namespace FEAT::LAFEM;
19using namespace FEAT::LAFEM::Arch;
20
/// \brief Computes the dense matrix product r <- alpha*x*y + beta*z via cublasLt.
///
/// All matrices are stored densely in row-major order. The in-place case r==z is
/// supported; any other aliasing between the operands is rejected.
///
/// \param r result matrix (rows x columns)
/// \param alpha scaling factor for the product x*y
/// \param beta scaling factor for the additive input z
/// \param x left factor (rows x inner)
/// \param y right factor (inner x columns)
/// \param z additive input matrix (rows x columns); may alias r
/// \param rows number of rows of r, x and z
/// \param columns number of columns of r, y and z
/// \param inner inner dimension, i.e. columns of x and rows of y
template <typename DT_>
void ProductMatMat::dense_cuda(DT_ * r, const DT_ alpha, const DT_ beta, const DT_ * const x, const DT_ * const y, const DT_ * const z, const Index rows, const Index columns, const Index inner)
{
  if (r==y || r==x || x==y || z==x || z==y)
    throw InternalError(__func__, __FILE__, __LINE__, "cuda ProductMatMat does not allow r==y or r==x or x==y or z==x or z==y!");

  cublasStatus_t status;

  // inspired by https://github.com/NVIDIA/CUDALibrarySamples/blob/master/cuBLASLt/LtSgemm/sample_cublasLt_LtSgemm.cu

  cublasLtMatmulDesc_t operationDesc = NULL;
  cublasLtMatrixLayout_t Rdesc = NULL, Adesc = NULL, Bdesc = NULL, Cdesc = NULL;
  cublasLtMatmulPreference_t preference = NULL;

  // index into the per-type / per-shape cache of previously selected matmul algorithms
  int algo_selector = -1;

  cudaDataType dt;
  cublasComputeType_t ct;
  if (typeid(DT_) == typeid(double))
  {
    dt = CUDA_R_64F;
    ct = CUBLAS_COMPUTE_64F;
    algo_selector = (rows > 1 && columns > 1 && inner > 1) ? 0 : 1;
  }
  else if (typeid(DT_) == typeid(float))
  {
    dt = CUDA_R_32F;
    // __CUDA_ARCH__ is undefined while host code is compiled, so a preprocessor test
    // on it in this host function would always pick the non-TF32 fallback. Query the
    // compute capability of the active device at runtime instead and enable TF32
    // acceleration on Ampere (SM 8.0) or newer.
    int device(0), cc_major(0);
    cudaGetDevice(&device);
    cudaDeviceGetAttribute(&cc_major, cudaDevAttrComputeCapabilityMajor, device);
    ct = (cc_major >= 8) ? CUBLAS_COMPUTE_32F_FAST_TF32 : CUBLAS_COMPUTE_32F;
    algo_selector = (rows > 1 && columns > 1 && inner > 1) ? 2 : 3;
  }
#ifdef FEAT_HAVE_HALFMATH
  else if (typeid(DT_) == typeid(Half))
  {
    dt = CUDA_R_16F;
    ct = CUBLAS_COMPUTE_16F;
    algo_selector = (rows > 1 && columns > 1 && inner > 1) ? 4 : 5;
  }
#endif
  else
    throw InternalError(__func__, __FILE__, __LINE__, "unsupported data type!");

  status = cublasLtMatmulDescCreate(&operationDesc, ct, dt);
  if (status != CUBLAS_STATUS_SUCCESS)
    throw InternalError(__func__, __FILE__, __LINE__, "cuda error: " + stringify(cublasGetStatusString(status)));

  // all matrices are row-major, which has to be set explicitly on every layout
  cublasLtOrder_t matrix_order = CUBLASLT_ORDER_ROW;
  status = cublasLtMatrixLayoutCreate(&Rdesc, dt, rows, columns, columns);
  if (status != CUBLAS_STATUS_SUCCESS)
    throw InternalError(__func__, __FILE__, __LINE__, "cuda error: " + stringify(cublasGetStatusString(status)));
  status = cublasLtMatrixLayoutSetAttribute(Rdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &matrix_order, sizeof(cublasLtOrder_t));
  if (status != CUBLAS_STATUS_SUCCESS)
    throw InternalError(__func__, __FILE__, __LINE__, "cuda error: " + stringify(cublasGetStatusString(status)));
  cublasLtMatrixLayoutCreate(&Adesc, dt, rows, inner, inner);
  cublasLtMatrixLayoutSetAttribute(Adesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &matrix_order, sizeof(cublasLtOrder_t));
  cublasLtMatrixLayoutCreate(&Bdesc, dt, inner, columns, columns);
  cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &matrix_order, sizeof(cublasLtOrder_t));
  if (r!=z)
  {
    cublasLtMatrixLayoutCreate(&Cdesc, dt, rows, columns, columns);
    cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &matrix_order, sizeof(cublasLtOrder_t));
  }
  else // r==z -> in-place multiplication
  {
    Cdesc = Rdesc;
  }

  // reuse the cached algorithm if one exists and it still fits the current problem;
  // otherwise run the heuristic once and cache its result for subsequent calls
  cublasLtMatmulHeuristicResult_t algo_check_result;
  if (! FEAT::Util::Intern::cublas_lt_algo_matmat_initialized[algo_selector] ||
      CUBLAS_STATUS_SUCCESS != cublasLtMatmulAlgoCheck((cublasLtHandle_t)Util::Intern::cublas_handle, operationDesc, Adesc, Bdesc, Cdesc, Rdesc, &(FEAT::Util::Intern::cublas_lt_algo_matmat[algo_selector]), &algo_check_result))
  {
    int num_algos = 0;
    cublasLtMatmulHeuristicResult_t heuristic_algos = {};

    status = cublasLtMatmulPreferenceCreate(&preference);
    if (status != CUBLAS_STATUS_SUCCESS)
      throw InternalError(__func__, __FILE__, __LINE__, "cuda error: " + stringify(cublasGetStatusString(status)));
    //status = cublasLtMatmulPreferenceSetAttribute(preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &(FEAT::Util::Intern::cuda_workspace_size), sizeof(FEAT::Util::Intern::cuda_workspace_size));

    status = cublasLtMatmulAlgoGetHeuristic((cublasLtHandle_t)Util::Intern::cublas_handle, operationDesc, Adesc, Bdesc, Cdesc, Rdesc, preference, 1, &heuristic_algos, &num_algos);
    if (status != CUBLAS_STATUS_SUCCESS)
      throw InternalError(__func__, __FILE__, __LINE__, "cuda error: " + stringify(cublasGetStatusString(status)));

    if (num_algos == 0)
      throw InternalError(__func__, __FILE__, __LINE__, "no algo supports our matrices!");

    FEAT::Util::Intern::cublas_lt_algo_matmat[algo_selector] = heuristic_algos.algo;
    FEAT::Util::Intern::cublas_lt_algo_matmat_initialized[algo_selector] = true;
  }

  cublasLtMatmulAlgo_t * algo = &(FEAT::Util::Intern::cublas_lt_algo_matmat[algo_selector]);

  //status = cublasLtMatmul((cublasLtHandle_t)Util::Intern::cublas_handle, operationDesc, &alpha, x, Adesc, y, Bdesc, &beta, z, Cdesc, r, Rdesc, algo, FEAT::Util::Intern::cuda_workspace, FEAT::Util::Intern::cuda_workspace_size, 0);
  status = cublasLtMatmul((cublasLtHandle_t)Util::Intern::cublas_handle, operationDesc, &alpha, x, Adesc, y, Bdesc, &beta, z, Cdesc, r, Rdesc, algo, NULL, 0, 0);
  if (status != CUBLAS_STATUS_SUCCESS)
    throw InternalError(__func__, __FILE__, __LINE__, "cuda error: " + stringify(cublasGetStatusString(status)));

  // release the cublasLt resources again; these previously leaked on every call.
  // note that Cdesc aliases Rdesc in the in-place case and must not be freed twice.
  if (preference)
    cublasLtMatmulPreferenceDestroy(preference);
  if (Cdesc != Rdesc)
    cublasLtMatrixLayoutDestroy(Cdesc);
  cublasLtMatrixLayoutDestroy(Bdesc);
  cublasLtMatrixLayoutDestroy(Adesc);
  cublasLtMatrixLayoutDestroy(Rdesc);
  cublasLtMatmulDescDestroy(operationDesc);

  cudaDeviceSynchronize();
#ifdef FEAT_DEBUG_MODE
  cudaError_t last_error(cudaGetLastError());
  if (cudaSuccess != last_error)
    throw InternalError(__func__, __FILE__, __LINE__, "CUDA error occurred in execution!\n" + stringify(cudaGetErrorString(last_error)));
#endif
}
// explicit instantiations of dense_cuda for all supported floating point types
// (Half only when half precision support is enabled at configure time)
#ifdef FEAT_HAVE_HALFMATH
template void ProductMatMat::dense_cuda(Half *, const Half, const Half, const Half * const, const Half * const, const Half * const, const Index, const Index, const Index);
#endif
template void ProductMatMat::dense_cuda(float *, const float, const float, const float * const, const float * const, const float * const, const Index, const Index, const Index);
template void ProductMatMat::dense_cuda(double *, const double, const double, const double * const, const double * const, const double * const, const Index, const Index, const Index);
135
/// \brief Computes r <- alpha * A*y + beta*r via cusparseSpMM, where A is a sparse
/// CSR matrix and y, r are dense row-major matrices.
///
/// \param r result matrix (rows x columns); also the additive input scaled by beta
/// \param alpha scaling factor for the product A*y
/// \param beta scaling factor applied to the previous contents of r
/// \param val non-zero values of A in CSR storage
/// \param col_ind CSR column indices of A
/// \param row_ptr CSR row pointer array of A (rows+1 entries)
/// \param used_elements number of non-zero entries of A
/// \param y dense right factor (inner x columns)
/// \param rows number of rows of A and r
/// \param columns number of columns of y and r
/// \param inner number of columns of A and rows of y
template <typename DT_, typename IT_>
void ProductMatMat::dsd_cuda(DT_ * r, const DT_ alpha, const DT_ beta, const DT_ * const val, const IT_ * const col_ind, const IT_ * const row_ptr, const Index used_elements,
    const DT_ * y, const Index rows, const Index columns, const Index inner)
{
  if (r==y)
    throw InternalError(__func__, __FILE__, __LINE__, "cuda ProductMatMat does not allow r==y!");

  cudaDataType dt;
  cudaDataType ct; //compute type
  if (typeid(DT_) == typeid(double))
  {
    dt = CUDA_R_64F;
    ct = CUDA_R_64F;
  }
  else if (typeid(DT_) == typeid(float))
  {
    dt = CUDA_R_32F;
    ct = CUDA_R_32F;
  }
#ifdef FEAT_HAVE_HALFMATH
  else if (typeid(DT_) == typeid(Half))
  {
    dt = CUDA_R_16F;
    ct = CUDA_R_32F; //cusparseSpMM does not support computation in half, yet
  }
#endif
  else
  {
    throw InternalError(__func__, __FILE__, __LINE__, "unsupported data type!");
  }

  cusparseIndexType_t it;
  if(sizeof(IT_) == 4u)
    it = CUSPARSE_INDEX_32I;
  else if(sizeof(IT_) == 8u)
    it = CUSPARSE_INDEX_64I;
  else
  {
    throw InternalError(__func__, __FILE__, __LINE__, "unsupported index type!");
  }

  // cusparseSpMM expects the alpha/beta scalars in the compute type, not the data
  // type; for half precision data the computation runs in float, so the scalars
  // have to be converted before they are handed to cusparse.
  const float alpha_float(static_cast<float>(alpha));
  const float beta_float(static_cast<float>(beta));
  const void * alpha_ptr(&alpha);
  const void * beta_ptr(&beta);
  if (dt != ct)
  {
    alpha_ptr = &alpha_float;
    beta_ptr = &beta_float;
  }

  cusparseStatus_t status;

  cusparseDnMatDescr_t descr_r=0;
  status = cusparseCreateDnMat(&descr_r, rows, columns, columns, (void*)r, dt, CUSPARSE_ORDER_ROW);
  if (status != CUSPARSE_STATUS_SUCCESS)
    throw InternalError(__func__, __FILE__, __LINE__, "cuda error: " + stringify(cusparseGetErrorString(status)));

  cusparseSpMatDescr_t descr_x=0;
  status = cusparseCreateCsr(&descr_x, rows, inner, used_elements, (void*)row_ptr, (void*)col_ind, (void*)val, it, it, CUSPARSE_INDEX_BASE_ZERO, dt);
  if (status != CUSPARSE_STATUS_SUCCESS)
    throw InternalError(__func__, __FILE__, __LINE__, "cuda error: " + stringify(cusparseGetErrorString(status)));

  cusparseDnMatDescr_t descr_y=0;
  status = cusparseCreateDnMat(&descr_y, inner, columns, columns, (void*)y, dt, CUSPARSE_ORDER_ROW);
  if (status != CUSPARSE_STATUS_SUCCESS)
    throw InternalError(__func__, __FILE__, __LINE__, "cuda error: " + stringify(cusparseGetErrorString(status)));

  cusparseOperation_t trans = CUSPARSE_OPERATION_NON_TRANSPOSE;
  size_t buffer_size(0);
  status = cusparseSpMM_bufferSize(Util::Intern::cusparse_handle, trans, trans, alpha_ptr, descr_x, descr_y, beta_ptr, descr_r, ct, CUSPARSE_SPMM_CSR_ALG2, &buffer_size);
  if (status != CUSPARSE_STATUS_SUCCESS)
    throw InternalError(__func__, __FILE__, __LINE__, "cusparseSpMM_bufferSize failed with status code: " + stringify(cusparseGetErrorString(status)));

  // the scratch buffer is owned by the util layer and reused across calls
  void* buffer = Util::cuda_get_static_memory(buffer_size);

  status = cusparseSpMM(Util::Intern::cusparse_handle, trans, trans, alpha_ptr, descr_x, descr_y, beta_ptr, descr_r, ct, CUSPARSE_SPMM_CSR_ALG2, buffer);
  if (status != CUSPARSE_STATUS_SUCCESS)
    throw InternalError(__func__, __FILE__, __LINE__, "cusparseSpMM failed with status code: " + stringify(cusparseGetErrorString(status)));

  cusparseDestroyDnMat(descr_r);
  cusparseDestroySpMat(descr_x);
  cusparseDestroyDnMat(descr_y);

  cudaDeviceSynchronize();
#ifdef FEAT_DEBUG_MODE
  cudaError_t last_error(cudaGetLastError());
  if (cudaSuccess != last_error)
    throw InternalError(__func__, __FILE__, __LINE__, "CUDA error occurred in execution!\n" + stringify(cudaGetErrorString(last_error)));
#endif
}
// explicit instantiations of dsd_cuda for all supported data type / index type
// combinations (32- and 64-bit CSR indices; Half only with half precision support)
#ifdef FEAT_HAVE_HALFMATH
template void ProductMatMat::dsd_cuda(Half *, const Half, const Half, const Half * const, const std::uint32_t * const, const std::uint32_t * const, const Index, const Half *, const Index, const Index, const Index);
#endif
template void ProductMatMat::dsd_cuda(float *, const float, const float, const float * const, const std::uint32_t * const, const std::uint32_t * const, const Index, const float *, const Index, const Index, const Index);
template void ProductMatMat::dsd_cuda(double *, const double, const double, const double * const, const std::uint32_t * const, const std::uint32_t * const, const Index, const double *, const Index, const Index, const Index);
#ifdef FEAT_HAVE_HALFMATH
template void ProductMatMat::dsd_cuda(Half *, const Half, const Half, const Half * const, const std::uint64_t * const, const std::uint64_t * const, const Index, const Half *, const Index, const Index, const Index);
#endif
template void ProductMatMat::dsd_cuda(float *, const float, const float, const float * const, const std::uint64_t * const, const std::uint64_t * const, const Index, const float *, const Index, const Index, const Index);
template void ProductMatMat::dsd_cuda(double *, const double, const double, const double * const, const std::uint64_t * const, const std::uint64_t * const, const Index, const double *, const Index, const Index, const Index);