// FEAT3: Finite Element Analysis Toolbox, Version 3
// Copyright (C) 2010 by Stefan Turek & the FEAT group
// FEAT3 is released under the GNU General Public License version 3,
// see the file 'copyright.txt' in the top level directory for details.

// includes, FEAT
#include <kernel/base_header.hpp>
#include <kernel/lafem/arch/product_matmat.hpp>
#include <kernel/util/exception.hpp>
#include <kernel/util/memory_pool.hpp>
#include <kernel/util/half.hpp>

// includes, CUDA
#include <cublas_v2.h>
#include <cublasLt.h>
#include <cusparse_v2.h>
 
using namespace FEAT;
using namespace FEAT::LAFEM;
using namespace FEAT::LAFEM::Arch;
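
// Computes the dense matrix product r = alpha * x * y + beta * z via cublasLtMatmul,
// where x is (rows x inner), y is (inner x columns) and r, z are (rows x columns);
// all operands are device pointers to row-major storage.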
 
template <typename DT_>
void ProductMatMat::dense_cuda(DT_ * r, const DT_ alpha, const DT_ beta, const DT_ * const x, const DT_ * const y, const DT_ * const z, const Index rows, const Index columns, const Index inner)
{
  if (r==y || r==x || x==y || z==x || z==y)
    throw InternalError(__func__, __FILE__, __LINE__, "cuda ProductMatMat does not allow r==y or r==x or x==y or z==x or z==y!");
 
  cublasStatus_t status;

  // inspired by https://github.com/NVIDIA/CUDALibrarySamples/blob/master/cuBLASLt/LtSgemm/sample_cublasLt_LtSgemm.cu

  cublasLtMatmulDesc_t operationDesc = NULL;
  cublasLtMatrixLayout_t Rdesc = NULL, Adesc = NULL, Bdesc = NULL, Cdesc = NULL;
  cublasLtMatmulPreference_t preference = NULL;
 
  int algo_selector = -1;

  cudaDataType dt; //data type
  cublasComputeType_t ct; //compute type
  if (typeid(DT_) == typeid(double))
  {
    dt = CUDA_R_64F;
    ct = CUBLAS_COMPUTE_64F;
    algo_selector = (rows > 1 && columns > 1 && inner > 1) ? 0 : 1;
  }
  else if (typeid(DT_) == typeid(float))
  {
    dt = CUDA_R_32F;
#if __CUDA_ARCH__ < 800
    ct = CUBLAS_COMPUTE_32F;
#else
    ct = CUBLAS_COMPUTE_32F_FAST_TF32;
#endif
    algo_selector = (rows > 1 && columns > 1 && inner > 1) ? 2 : 3;
  }
#ifdef FEAT_HAVE_HALFMATH
  else if (typeid(DT_) == typeid(Half))
  {
    dt = CUDA_R_16F;
    ct = CUBLAS_COMPUTE_16F;
    algo_selector = (rows > 1 && columns > 1 && inner > 1) ? 4 : 5;
  }
#endif
  else
    throw InternalError(__func__, __FILE__, __LINE__, "unsupported data type!");
 
  status = cublasLtMatmulDescCreate(&operationDesc, ct, dt);
  if (status != CUBLAS_STATUS_SUCCESS)
    throw InternalError(__func__, __FILE__, __LINE__, "cuda error: " + stringify(cublasGetStatusString(status)));
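
  // cuBLASLt layouts default to column-major order, so row-major order is set explicitly
  // on every matrix layout descriptor below.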
 
  cublasLtOrder_t matrix_order = CUBLASLT_ORDER_ROW;
  status = cublasLtMatrixLayoutCreate(&Rdesc, dt, rows, columns, columns);
  if (status != CUBLAS_STATUS_SUCCESS)
    throw InternalError(__func__, __FILE__, __LINE__, "cuda error: " + stringify(cublasGetStatusString(status)));
  status = cublasLtMatrixLayoutSetAttribute(Rdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &matrix_order, sizeof(cublasLtOrder_t));
  if (status != CUBLAS_STATUS_SUCCESS)
    throw InternalError(__func__, __FILE__, __LINE__, "cuda error: " + stringify(cublasGetStatusString(status)));
  cublasLtMatrixLayoutCreate(&Adesc, dt, rows, inner, inner);
  cublasLtMatrixLayoutSetAttribute(Adesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &matrix_order, sizeof(cublasLtOrder_t));
  cublasLtMatrixLayoutCreate(&Bdesc, dt, inner, columns, columns);
  cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &matrix_order, sizeof(cublasLtOrder_t));
  if (r != z)
  {
    cublasLtMatrixLayoutCreate(&Cdesc, dt, rows, columns, columns);
    cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &matrix_order, sizeof(cublasLtOrder_t));
  }
  else // r==z -> in-place multiplication
  {
    Cdesc = Rdesc;
  }
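
  // The chosen cuBLASLt algorithm is cached per data type and shape class in Util::Intern::cublas_lt_algo_matmat;
  // the heuristic query below is only issued if no algorithm has been cached for this selector yet or the cached
  // one no longer passes cublasLtMatmulAlgoCheck for the current descriptors.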
 
  cublasLtMatmulHeuristicResult_t algo_check_result;
  if (! FEAT::Util::Intern::cublas_lt_algo_matmat_initialized[algo_selector] ||
      CUBLAS_STATUS_SUCCESS != cublasLtMatmulAlgoCheck((cublasLtHandle_t)Util::Intern::cublas_handle, operationDesc, Adesc, Bdesc, Cdesc, Rdesc, &(FEAT::Util::Intern::cublas_lt_algo_matmat[algo_selector]), &algo_check_result))
  {
    int num_algos = 0;
    cublasLtMatmulHeuristicResult_t heuristic_algos = {};

    status = cublasLtMatmulPreferenceCreate(&preference);
    if (status != CUBLAS_STATUS_SUCCESS)
      throw InternalError(__func__, __FILE__, __LINE__, "cuda error: " + stringify(cublasGetStatusString(status)));
    //status = cublasLtMatmulPreferenceSetAttribute(preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &(FEAT::Util::Intern::cuda_workspace_size), sizeof(FEAT::Util::Intern::cuda_workspace_size));
    if (status != CUBLAS_STATUS_SUCCESS)
      throw InternalError(__func__, __FILE__, __LINE__, "cuda error: " + stringify(cublasGetStatusString(status)));

    status = cublasLtMatmulAlgoGetHeuristic((cublasLtHandle_t)Util::Intern::cublas_handle, operationDesc, Adesc, Bdesc, Cdesc, Rdesc, preference, 1, &heuristic_algos, &num_algos);
    if (status != CUBLAS_STATUS_SUCCESS)
      throw InternalError(__func__, __FILE__, __LINE__, "cuda error: " + stringify(cublasGetStatusString(status)));
    if (num_algos == 0)
      throw InternalError(__func__, __FILE__, __LINE__, "no algo supports our matrices!");

    FEAT::Util::Intern::cublas_lt_algo_matmat[algo_selector] = heuristic_algos.algo;
    FEAT::Util::Intern::cublas_lt_algo_matmat_initialized[algo_selector] = true;
  }
 
  cublasLtMatmulAlgo_t * algo = &(FEAT::Util::Intern::cublas_lt_algo_matmat[algo_selector]);

  //status = cublasLtMatmul((cublasLtHandle_t)Util::Intern::cublas_handle, operationDesc, &alpha, x, Adesc, y, Bdesc, &beta, z, Cdesc, r, Rdesc, algo, FEAT::Util::Intern::cuda_workspace, FEAT::Util::Intern::cuda_workspace_size, 0);
  status = cublasLtMatmul((cublasLtHandle_t)Util::Intern::cublas_handle, operationDesc, &alpha, x, Adesc, y, Bdesc, &beta, z, Cdesc, r, Rdesc, algo, NULL, 0, 0);
  if (status != CUBLAS_STATUS_SUCCESS)
    throw InternalError(__func__, __FILE__, __LINE__, "cuda error: " + stringify(cublasGetStatusString(status)));
 
  cudaDeviceSynchronize();
#ifdef FEAT_DEBUG_MODE
  cudaError_t last_error(cudaGetLastError());
  if (cudaSuccess != last_error)
    throw InternalError(__func__, __FILE__, __LINE__, "CUDA error occurred in execution!\n" + stringify(cudaGetErrorString(last_error)));
#endif
}

#ifdef FEAT_HAVE_HALFMATH
template void ProductMatMat::dense_cuda(Half *, const Half, const Half, const Half * const, const Half * const, const Half * const, const Index, const Index, const Index);
#endif
template void ProductMatMat::dense_cuda(float *, const float, const float, const float * const, const float * const, const float * const, const Index, const Index, const Index);
template void ProductMatMat::dense_cuda(double *, const double, const double, const double * const, const double * const, const double * const, const Index, const Index, const Index);
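
// Computes the dense result r = alpha * X * y + beta * r via cusparseSpMM, where X is the CSR operand
// given by (val, col_ind, row_ptr) of size (rows x inner) and y is a dense, row-major (inner x columns) matrix.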
 
template <typename DT_, typename IT_>
void ProductMatMat::dsd_cuda(DT_ * r, const DT_ alpha, const DT_ beta, const DT_ * const val, const IT_ * const col_ind, const IT_ * const row_ptr, const Index used_elements,
    const DT_ * y, const Index rows, const Index columns, const Index inner)
{
  if (r == y)
    throw InternalError(__func__, __FILE__, __LINE__, "cuda ProductMatMat does not allow r==y!");
 
  cudaDataType dt; //data type
  cudaDataType ct; //compute type
  if (typeid(DT_) == typeid(double))
  {
    dt = CUDA_R_64F;
    ct = CUDA_R_64F;
  }
  else if (typeid(DT_) == typeid(float))
  {
    dt = CUDA_R_32F;
    ct = CUDA_R_32F;
  }
#ifdef FEAT_HAVE_HALFMATH
  else if (typeid(DT_) == typeid(Half))
  {
    dt = CUDA_R_16F;
    ct = CUDA_R_32F; //cusparseSpMM does not support computation in half, yet
  }
#endif
  else
    throw InternalError(__func__, __FILE__, __LINE__, "unsupported data type!");
 
  cusparseIndexType_t it;
  if(sizeof(IT_) == 4u)
    it = CUSPARSE_INDEX_32I;
  else if(sizeof(IT_) == 8u)
    it = CUSPARSE_INDEX_64I;
  else
    throw InternalError(__func__, __FILE__, __LINE__, "unsupported index type!");
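
  // Wrap the operands in cuSPARSE generic-API descriptors: r and y as dense row-major matrices
  // (leading dimension = number of columns), the CSR operand as a sparse matrix descriptor.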
 
  cusparseStatus_t status;

  cusparseDnMatDescr_t descr_r=0;
  status = cusparseCreateDnMat(&descr_r, rows, columns, columns, (void*)r, dt, CUSPARSE_ORDER_ROW);
  if (status != CUSPARSE_STATUS_SUCCESS)
    throw InternalError(__func__, __FILE__, __LINE__, "cuda error: " + stringify(cusparseGetErrorString(status)));

  cusparseSpMatDescr_t descr_x=0;
  status = cusparseCreateCsr(&descr_x, rows, inner, used_elements, (void*)row_ptr, (void*)col_ind, (void*)val, it, it, CUSPARSE_INDEX_BASE_ZERO, dt);
  if (status != CUSPARSE_STATUS_SUCCESS)
    throw InternalError(__func__, __FILE__, __LINE__, "cuda error: " + stringify(cusparseGetErrorString(status)));

  cusparseDnMatDescr_t descr_y=0;
  status = cusparseCreateDnMat(&descr_y, inner, columns, columns, (void*)y, dt, CUSPARSE_ORDER_ROW);
  if (status != CUSPARSE_STATUS_SUCCESS)
    throw InternalError(__func__, __FILE__, __LINE__, "cuda error: " + stringify(cusparseGetErrorString(status)));
 
  cusparseOperation_t trans = CUSPARSE_OPERATION_NON_TRANSPOSE;
  size_t buffer_size(0);
  status = cusparseSpMM_bufferSize(Util::Intern::cusparse_handle, trans, trans, &alpha, descr_x, descr_y, &beta, descr_r, ct, CUSPARSE_SPMM_CSR_ALG2, &buffer_size);
  if (status != CUSPARSE_STATUS_SUCCESS)
    throw InternalError(__func__, __FILE__, __LINE__, "cusparseSpMM_bufferSize failed with status code: " + stringify(cusparseGetErrorString(status)));
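
  // The SpMM work buffer is taken from FEAT's static device buffer via Util::cuda_get_static_memory,
  // which is why it is not freed at the end of this function.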
 
  void* buffer = Util::cuda_get_static_memory(buffer_size);

  status = cusparseSpMM(Util::Intern::cusparse_handle, trans, trans, &alpha, descr_x, descr_y, &beta, descr_r, ct, CUSPARSE_SPMM_CSR_ALG2, buffer);
  if (status != CUSPARSE_STATUS_SUCCESS)
    throw InternalError(__func__, __FILE__, __LINE__, "cusparseSpMM failed with status code: " + stringify(cusparseGetErrorString(status)));
 
  cusparseDestroyDnMat(descr_r);
  cusparseDestroySpMat(descr_x);
  cusparseDestroyDnMat(descr_y);
 
  cudaDeviceSynchronize();
#ifdef FEAT_DEBUG_MODE
  cudaError_t last_error(cudaGetLastError());
  if (cudaSuccess != last_error)
    throw InternalError(__func__, __FILE__, __LINE__, "CUDA error occurred in execution!\n" + stringify(cudaGetErrorString(last_error)));
#endif
}

#ifdef FEAT_HAVE_HALFMATH
template void ProductMatMat::dsd_cuda(Half *, const Half, const Half, const Half * const, const std::uint32_t * const, const std::uint32_t * const, const Index, const Half *, const Index, const Index, const Index);
#endif
template void ProductMatMat::dsd_cuda(float *, const float, const float, const float * const, const std::uint32_t * const, const std::uint32_t * const, const Index, const float *, const Index, const Index, const Index);
template void ProductMatMat::dsd_cuda(double *, const double, const double, const double * const, const std::uint32_t * const, const std::uint32_t * const, const Index, const double *, const Index, const Index, const Index);
#ifdef FEAT_HAVE_HALFMATH
template void ProductMatMat::dsd_cuda(Half *, const Half, const Half, const Half * const, const std::uint64_t * const, const std::uint64_t * const, const Index, const Half *, const Index, const Index, const Index);
#endif
template void ProductMatMat::dsd_cuda(float *, const float, const float, const float * const, const std::uint64_t * const, const std::uint64_t * const, const Index, const float *, const Index, const Index, const Index);
template void ProductMatMat::dsd_cuda(double *, const double, const double, const double * const, const std::uint64_t * const, const std::uint64_t * const, const Index, const double *, const Index, const Index, const Index);