1// FEAT3: Finite Element Analysis Toolbox, Version 3
2// Copyright (C) 2010 by Stefan Turek & the FEAT group
3// FEAT3 is released under the GNU General Public License version 3,
4// see the file 'copyright.txt' in the top level directory for details.
7#include <kernel/base_header.hpp>
8#include <kernel/lafem/arch/product_matmat.hpp>
9#include <kernel/util/exception.hpp>
10#include <kernel/util/memory_pool.hpp>
11#include <kernel/util/half.hpp>
15#include <cusparse_v2.h>
18using namespace FEAT::LAFEM;
19using namespace FEAT::LAFEM::Arch;
namespace
{
  /// \brief Throws an InternalError carrying the cuBLAS status string unless \p status signals success.
  ///
  /// \param[in] status Return value of a cuBLAS / cuBLASLt API call.
  /// \param[in] func, file, line Call-site information forwarded to the InternalError.
  inline void dense_cuda_check_status(const cublasStatus_t status, const char * const func, const char * const file, const long line)
  {
    if (status != CUBLAS_STATUS_SUCCESS)
      throw InternalError(func, file, line, "cuda error: " + stringify(cublasGetStatusString(status)));
  }

  /// \brief Creates a row-major cuBLASLt layout for a dense rows x cols matrix with leading dimension cols.
  ///
  /// \returns The first non-successful status, or CUBLAS_STATUS_SUCCESS.
  /// If the creation succeeds but setting the order attribute fails, \p layout is still set,
  /// so that its owner can release it.
  inline cublasStatus_t dense_cuda_row_major_layout(cublasLtMatrixLayout_t & layout, const cudaDataType dt, const Index rows, const Index cols)
  {
    const cublasStatus_t status = cublasLtMatrixLayoutCreate(&layout, dt, rows, cols, cols);
    if (status != CUBLAS_STATUS_SUCCESS)
      return status;
    const cublasLtOrder_t matrix_order = CUBLASLT_ORDER_ROW;
    return cublasLtMatrixLayoutSetAttribute(layout, CUBLASLT_MATRIX_LAYOUT_ORDER, &matrix_order, sizeof(cublasLtOrder_t));
  }

  /// \brief Owns the cuBLASLt descriptors created by ProductMatMat::dense_cuda.
  ///
  /// The destructor releases every descriptor that has been created. Nothing leaks on the
  /// normal path, and nothing leaks when an InternalError is thrown after a partial setup.
  struct DenseCudaLtDescriptors
  {
    cublasLtMatmulDesc_t operation = nullptr;
    cublasLtMatrixLayout_t r = nullptr, a = nullptr, b = nullptr, c = nullptr;
    cublasLtMatmulPreference_t preference = nullptr;

    DenseCudaLtDescriptors() = default;
    DenseCudaLtDescriptors(const DenseCudaLtDescriptors &) = delete;
    DenseCudaLtDescriptors & operator=(const DenseCudaLtDescriptors &) = delete;

    ~DenseCudaLtDescriptors()
    {
      // destructors must not throw: the return codes of the destroy calls are deliberately ignored
      if (preference != nullptr)
        cublasLtMatmulPreferenceDestroy(preference);
      // for in-place products (r == z) the C layout aliases the R layout; release it only once
      if (c != nullptr && c != r)
        cublasLtMatrixLayoutDestroy(c);
      if (b != nullptr)
        cublasLtMatrixLayoutDestroy(b);
      if (a != nullptr)
        cublasLtMatrixLayoutDestroy(a);
      if (r != nullptr)
        cublasLtMatrixLayoutDestroy(r);
      if (operation != nullptr)
        cublasLtMatmulDescDestroy(operation);
    }
  };
} // anonymous namespace

/// \brief Computes r = alpha * x * y + beta * z for dense row-major matrices on the GPU via cuBLASLt.
///
/// \param[out] r       Result matrix (rows x columns). It may coincide with z (in-place update).
/// \param[in] alpha    Scaling factor of the product x * y.
/// \param[in] beta     Scaling factor of z.
/// \param[in] x        Left factor (rows x inner).
/// \param[in] y        Right factor (inner x columns).
/// \param[in] z        Summand (rows x columns).
/// \param[in] rows, columns, inner Matrix dimensions.
///
/// The algorithm chosen by the cuBLASLt heuristic is cached per value type, and per
/// "all dimensions > 1" flag, in Util::Intern::cublas_lt_algo_matmat. On each call it is
/// re-validated with cublasLtMatmulAlgoCheck. The function blocks in cudaDeviceSynchronize()
/// before returning.
///
/// \throws InternalError on forbidden pointer aliasing, on an unsupported DT_, if no cuBLASLt
/// algorithm supports the problem, or if any cuBLASLt call fails.
template <typename DT_>
void ProductMatMat::dense_cuda(DT_ * r, const DT_ alpha, const DT_ beta, const DT_ * const x, const DT_ * const y, const DT_ * const z, const Index rows, const Index columns, const Index inner)
{
  if (r==y || r==x || x==y || z==x || z==y)
    throw InternalError(__func__, __FILE__, __LINE__, "cuda ProductMatMat does not allow r==y or r==x or x==y or z==x or z==y!");

  // inspired by https://github.com/NVIDIA/CUDALibrarySamples/blob/master/cuBLASLt/LtSgemm/sample_cublasLt_LtSgemm.cu

  // owns all cuBLASLt descriptors of this call and releases them on every exit path
  DenseCudaLtDescriptors desc;

  // slot in the algorithm cache: even = all dimensions > 1, odd = at least one dimension <= 1
  const bool general_shape = rows > 1 && columns > 1 && inner > 1;
  int algo_selector = -1;

  cudaDataType dt;        // data type of all matrices and of alpha/beta
  cublasComputeType_t ct; // internal compute type
  if (typeid(DT_) == typeid(double))
  {
    dt = CUDA_R_64F;
    ct = CUBLAS_COMPUTE_64F;
    algo_selector = general_shape ? 0 : 1;
  }
  else if (typeid(DT_) == typeid(float))
  {
    dt = CUDA_R_32F;
    // NOTE(review): this is host code. In nvcc's host pass __CUDA_ARCH__ is undefined and is
    // treated as 0 by #if, so CUBLAS_COMPUTE_32F is always selected. Selecting TF32 would need
    // a runtime device query. Confirm whether TF32 was intended.
#if __CUDA_ARCH__ < 800
    ct = CUBLAS_COMPUTE_32F;
#else
    ct = CUBLAS_COMPUTE_32F_FAST_TF32;
#endif
    algo_selector = general_shape ? 2 : 3;
  }
#ifdef FEAT_HAVE_HALFMATH
  else if (typeid(DT_) == typeid(Half))
  {
    dt = CUDA_R_16F;
    ct = CUBLAS_COMPUTE_16F;
    algo_selector = general_shape ? 4 : 5;
  }
#endif
  else
    throw InternalError(__func__, __FILE__, __LINE__, "unsupported data type!");

  dense_cuda_check_status(cublasLtMatmulDescCreate(&desc.operation, ct, dt), __func__, __FILE__, __LINE__);

  // all matrices are dense and row-major: R, C are rows x columns, A (= x) is rows x inner, B (= y) is inner x columns
  dense_cuda_check_status(dense_cuda_row_major_layout(desc.r, dt, rows, columns), __func__, __FILE__, __LINE__);
  dense_cuda_check_status(dense_cuda_row_major_layout(desc.a, dt, rows, inner), __func__, __FILE__, __LINE__);
  dense_cuda_check_status(dense_cuda_row_major_layout(desc.b, dt, inner, columns), __func__, __FILE__, __LINE__);
  if (r != z)
    dense_cuda_check_status(dense_cuda_row_major_layout(desc.c, dt, rows, columns), __func__, __FILE__, __LINE__);
  else // r==z -> in-place multiplication, C and R share one (identical) layout
    desc.c = desc.r;

  const cublasLtHandle_t lt_handle = (cublasLtHandle_t)Util::Intern::cublas_handle;
  cublasLtMatmulAlgo_t * algo = &(FEAT::Util::Intern::cublas_lt_algo_matmat[algo_selector]);

  // reuse the cached algorithm if it is initialised and still supports this problem, otherwise ask the heuristic
  cublasLtMatmulHeuristicResult_t algo_check_result;
  if (! FEAT::Util::Intern::cublas_lt_algo_matmat_initialized[algo_selector] ||
    CUBLAS_STATUS_SUCCESS != cublasLtMatmulAlgoCheck(lt_handle, desc.operation, desc.a, desc.b, desc.c, desc.r, algo, &algo_check_result))
  {
    // the workspace limit is left at its default, since cublasLtMatmul below is called without a workspace
    dense_cuda_check_status(cublasLtMatmulPreferenceCreate(&desc.preference), __func__, __FILE__, __LINE__);

    cublasLtMatmulHeuristicResult_t heuristic_algos = {};
    int num_algos = 0;
    dense_cuda_check_status(cublasLtMatmulAlgoGetHeuristic(lt_handle, desc.operation, desc.a, desc.b, desc.c, desc.r, desc.preference, 1, &heuristic_algos, &num_algos),
      __func__, __FILE__, __LINE__);
    if (num_algos < 1)
      throw InternalError(__func__, __FILE__, __LINE__, "no algo supports our matrices!");

    *algo = heuristic_algos.algo;
    FEAT::Util::Intern::cublas_lt_algo_matmat_initialized[algo_selector] = true;
  }

  // D (= r) = alpha * A (= x) * B (= y) + beta * C (= z)
  dense_cuda_check_status(cublasLtMatmul(lt_handle, desc.operation, &alpha, x, desc.a, y, desc.b, &beta, z, desc.c, r, desc.r, algo, NULL, 0, 0),
    __func__, __FILE__, __LINE__);

  cudaDeviceSynchronize();
#ifdef FEAT_DEBUG_MODE
  cudaError_t last_error(cudaGetLastError());
  if (cudaSuccess != last_error)
    throw InternalError(__func__, __FILE__, __LINE__, "CUDA error occurred in execution!\n" + stringify(cudaGetErrorString(last_error)));
#endif
}
// explicit instantiations of ProductMatMat::dense_cuda for every supported value type
#ifdef FEAT_HAVE_HALFMATH
template void ProductMatMat::dense_cuda<Half>(Half *, const Half, const Half, const Half * const, const Half * const, const Half * const, const Index, const Index, const Index);
#endif // FEAT_HAVE_HALFMATH
template void ProductMatMat::dense_cuda<float>(float *, const float, const float, const float * const, const float * const, const float * const, const Index, const Index, const Index);
template void ProductMatMat::dense_cuda<double>(double *, const double, const double, const double * const, const double * const, const double * const, const Index, const Index, const Index);
namespace
{
  /// \brief Owns the cuSPARSE descriptors created by ProductMatMat::dsd_cuda.
  ///
  /// Every descriptor that has been created is destroyed when the owner leaves scope. This
  /// covers the normal path and the case where an InternalError is thrown after a partial setup.
  struct DsdCudaSpMMDescriptors
  {
    cusparseDnMatDescr_t r = nullptr; // dense result matrix
    cusparseSpMatDescr_t x = nullptr; // CSR matrix
    cusparseDnMatDescr_t y = nullptr; // dense right factor

    DsdCudaSpMMDescriptors() = default;
    DsdCudaSpMMDescriptors(const DsdCudaSpMMDescriptors &) = delete;
    DsdCudaSpMMDescriptors & operator=(const DsdCudaSpMMDescriptors &) = delete;

    ~DsdCudaSpMMDescriptors()
    {
      // destructors must not throw: the return codes of the destroy calls are deliberately ignored
      if (y != nullptr)
        cusparseDestroyDnMat(y);
      if (x != nullptr)
        cusparseDestroySpMat(x);
      if (r != nullptr)
        cusparseDestroyDnMat(r);
    }
  };
} // anonymous namespace

/// \brief Computes r = alpha * X * y + beta * r on the GPU via cusparseSpMM, where X is a CSR matrix.
///
/// \param[in,out] r      Dense row-major result matrix (rows x columns). Its previous contents are scaled by beta.
/// \param[in] alpha      Scaling factor of the product X * y.
/// \param[in] beta       Scaling factor of the previous contents of r.
/// \param[in] val, col_ind, row_ptr CSR arrays of X (rows x inner, zero-based indices).
/// \param[in] used_elements Number of stored entries of X.
/// \param[in] y          Dense row-major right factor (inner x columns). It must not alias r.
/// \param[in] rows, columns, inner Matrix dimensions.
///
/// Half-precision data is multiplied with a float compute type. The function blocks in
/// cudaDeviceSynchronize() before returning.
///
/// \throws InternalError if r == y, for an unsupported value or index type, or if a cuSPARSE call fails.
template <typename DT_, typename IT_>
void ProductMatMat::dsd_cuda(DT_ * r, const DT_ alpha, const DT_ beta, const DT_ * const val, const IT_ * const col_ind, const IT_ * const row_ptr, const Index used_elements,
  const DT_ * y, const Index rows, const Index columns, const Index inner)
{
  if (r == y)
    throw InternalError(__func__, __FILE__, __LINE__, "cuda ProductMatMat does not allow r==y!");

  cudaDataType dt; // data type of the matrices
  cudaDataType ct; // compute type (also the type of alpha/beta handed to cuSPARSE)
  if (typeid(DT_) == typeid(double))
  {
    dt = CUDA_R_64F;
    ct = CUDA_R_64F;
  }
  else if (typeid(DT_) == typeid(float))
  {
    dt = CUDA_R_32F;
    ct = CUDA_R_32F;
  }
#ifdef FEAT_HAVE_HALFMATH
  else if (typeid(DT_) == typeid(Half))
  {
    dt = CUDA_R_16F;
    ct = CUDA_R_32F; //cusparseSpMM does not support computation in half, yet
  }
#endif
  else
    throw InternalError(__func__, __FILE__, __LINE__, "unsupported data type!");

  cusparseIndexType_t it;
  if(sizeof(IT_) == 4u)
    it = CUSPARSE_INDEX_32I;
  else if(sizeof(IT_) == 8u)
    it = CUSPARSE_INDEX_64I;
  else
    throw InternalError(__func__, __FILE__, __LINE__, "unsupported index type!");

  // owns all descriptors of this call and releases them on every exit path
  DsdCudaSpMMDescriptors descr;
  cusparseStatus_t status;

  status = cusparseCreateDnMat(&descr.r, rows, columns, columns, (void*)r, dt, CUSPARSE_ORDER_ROW);
  if (status != CUSPARSE_STATUS_SUCCESS)
    throw InternalError(__func__, __FILE__, __LINE__, "cuda error: " + stringify(cusparseGetErrorString(status)));

  status = cusparseCreateCsr(&descr.x, rows, inner, used_elements, (void*)row_ptr, (void*)col_ind, (void*)val, it, it, CUSPARSE_INDEX_BASE_ZERO, dt);
  if (status != CUSPARSE_STATUS_SUCCESS)
    throw InternalError(__func__, __FILE__, __LINE__, "cuda error: " + stringify(cusparseGetErrorString(status)));

  status = cusparseCreateDnMat(&descr.y, inner, columns, columns, (void*)y, dt, CUSPARSE_ORDER_ROW);
  if (status != CUSPARSE_STATUS_SUCCESS)
    throw InternalError(__func__, __FILE__, __LINE__, "cuda error: " + stringify(cusparseGetErrorString(status)));

  cusparseOperation_t trans = CUSPARSE_OPERATION_NON_TRANSPOSE;
  size_t buffer_size(0);

  // alpha/beta must be given in the compute type ct; for half data that is float, so pass converted copies
  const float alpha_tmp = float(alpha);
  const float beta_tmp = float(beta);
  const void * const alpha_ptr = (dt == CUDA_R_16F) ? static_cast<const void *>(&alpha_tmp) : static_cast<const void *>(&alpha);
  const void * const beta_ptr = (dt == CUDA_R_16F) ? static_cast<const void *>(&beta_tmp) : static_cast<const void *>(&beta);

  status = cusparseSpMM_bufferSize(Util::Intern::cusparse_handle, trans, trans, alpha_ptr, descr.x, descr.y, beta_ptr, descr.r, ct, CUSPARSE_SPMM_CSR_ALG2, &buffer_size);
  if (status != CUSPARSE_STATUS_SUCCESS)
    throw InternalError(__func__, __FILE__, __LINE__, "cusparseSpMM_bufferSize failed with status code: " + stringify(cusparseGetErrorString(status)));

  void* buffer = Util::cuda_get_static_memory(buffer_size);

  // r = alpha * X * y + beta * r
  status = cusparseSpMM(Util::Intern::cusparse_handle, trans, trans, alpha_ptr, descr.x, descr.y, beta_ptr, descr.r, ct, CUSPARSE_SPMM_CSR_ALG2, buffer);
  if (status != CUSPARSE_STATUS_SUCCESS)
    throw InternalError(__func__, __FILE__, __LINE__, "cusparseSpMM failed with status code: " + stringify(cusparseGetErrorString(status)));

  cudaDeviceSynchronize();
#ifdef FEAT_DEBUG_MODE
  cudaError_t last_error(cudaGetLastError());
  if (cudaSuccess != last_error)
    throw InternalError(__func__, __FILE__, __LINE__, "CUDA error occurred in execution!\n" + stringify(cudaGetErrorString(last_error)));
#endif
}
// explicit instantiations of ProductMatMat::dsd_cuda for every supported (value type, index type) pair
#ifdef FEAT_HAVE_HALFMATH
template void ProductMatMat::dsd_cuda<Half, std::uint32_t>(Half *, const Half, const Half, const Half * const, const std::uint32_t * const, const std::uint32_t * const, const Index, const Half *, const Index, const Index, const Index);
template void ProductMatMat::dsd_cuda<Half, std::uint64_t>(Half *, const Half, const Half, const Half * const, const std::uint64_t * const, const std::uint64_t * const, const Index, const Half *, const Index, const Index, const Index);
#endif // FEAT_HAVE_HALFMATH
template void ProductMatMat::dsd_cuda<float, std::uint32_t>(float *, const float, const float, const float * const, const std::uint32_t * const, const std::uint32_t * const, const Index, const float *, const Index, const Index, const Index);
template void ProductMatMat::dsd_cuda<double, std::uint32_t>(double *, const double, const double, const double * const, const std::uint32_t * const, const std::uint32_t * const, const Index, const double *, const Index, const Index, const Index);
template void ProductMatMat::dsd_cuda<float, std::uint64_t>(float *, const float, const float, const float * const, const std::uint64_t * const, const std::uint64_t * const, const Index, const float *, const Index, const Index, const Index);
template void ProductMatMat::dsd_cuda<double, std::uint64_t>(double *, const double, const double, const double * const, const std::uint64_t * const, const std::uint64_t * const, const Index, const double *, const Index, const Index, const Index);