1// FEAT3: Finite Element Analysis Toolbox, Version 3
 
    2// Copyright (C) 2010 by Stefan Turek & the FEAT group
 
    3// FEAT3 is released under the GNU General Public License version 3,
 
    4// see the file 'copyright.txt' in the top level directory for details.
 
    7#include <kernel/base_header.hpp>
 
    8#include <kernel/lafem/arch/axpy.hpp>
 
    9#include <kernel/util/exception.hpp>
 
   10#include <kernel/util/cuda_util.hpp>
 
   11#include <kernel/util/math.hpp>
 
   12#include <kernel/util/half.hpp>
 
   15using namespace FEAT::LAFEM;
 
   16using namespace FEAT::LAFEM::Arch;
 
   // Forwards an axpy operation on raw arrays to cuBLAS via cublasAxpyEx.
   //
   // Parameters (as used by the visible code):
   //   r    - passed to cublasAxpyEx in the y (in/out) position with stride 1
   //   a    - scalar; its address is passed in the alpha position
   //   x    - passed to cublasAxpyEx in the x (input) position with stride 1
   //   size - element count; converted with int(size) for the cuBLAS call
   //
   // Throws InternalError when DT_ is not one of the handled types, when cuBLAS
   // returns a status other than CUBLAS_STATUS_SUCCESS, or when cudaGetLastError()
   // reports an error after the device synchronization.
   //
   // NOTE(review): presumably r and x are device pointers, as cuBLAS vector
   // arguments normally are - confirm against the callers.
   // NOTE(review): this excerpt omits several original lines (the embedded line
   // numbers jump, e.g. 19 -> 23 and 34 -> 41). The braces, the declarations and
   // per-branch assignments of `dt` and `et`, and the matching #endif / else are
   // not visible here - confirm against the full file before changing any logic.
   18template <typename DT_>
 
   19void Axpy::value_cuda(DT_ * r, const DT_ a, const DT_ * const x, const Index size)
 
   // Runtime type dispatch on DT_: double, float and (optionally) Half.
   // Presumably each branch sets the cuBLAS type tags `dt` and `et` that are
   // used in the call below; those assignments are not visible in this excerpt.
   23  if (typeid(DT_) == typeid(double))
 
   28  else if (typeid(DT_) == typeid(float))
 
   // The Half branch is compiled only when FEAT_HAVE_HALFMATH is defined.
   33#ifdef FEAT_HAVE_HALFMATH
 
   34  else if (typeid(DT_) == typeid(Half))
 
   // Presumably the final else branch: any other DT_ is rejected.
   41    throw InternalError(__func__, __FILE__, __LINE__, "unsupported data type!");
 
   43  cublasStatus_t status;
 
   // Argument order: handle, n, alpha, alpha type, x, x type, incx, y, y type,
   // incy, execution type. Both strides are 1; `et` is used for both the alpha
   // type and the execution type, `dt` for both vector types.
   // NOTE(review): int(size) narrows if Index is wider than int - confirm that
   // callers never pass sizes beyond the int range.
   44  status = cublasAxpyEx(Util::Intern::cublas_handle, int(size), &a, et, x, dt, 1, r, dt, 1, et);
 
   // Turn a cuBLAS failure into an exception carrying the cuBLAS status string.
   46  if (status != CUBLAS_STATUS_SUCCESS)
 
   47    throw InternalError(__func__, __FILE__, __LINE__, "cuda error: " + stringify(cublasGetStatusString(status)));
 
   // Wait for the device to finish; the return value of this call is not inspected.
   49  cudaDeviceSynchronize();
 
   // Fetch the last CUDA runtime error (cudaGetLastError also resets it) and
   // raise it as an exception if it is not cudaSuccess.
   51  cudaError_t last_error(cudaGetLastError());
 
   52  if (cudaSuccess != last_error)
 
   53    throw InternalError(__func__, __FILE__, __LINE__, "CUDA error occurred in execution!\n" + stringify(cudaGetErrorString(last_error)));
 
   // Explicit template instantiations of Axpy::value_cuda for Half, float and double.
   // The Half instantiation follows the FEAT_HAVE_HALFMATH guard.
   // NOTE(review): the #endif closing this guard is not visible in this excerpt
   // (the embedded numbering skips 58) - presumably it sits between the Half and
   // float instantiations; confirm against the full file.
   56#ifdef FEAT_HAVE_HALFMATH
 
   57template void Axpy::value_cuda(Half *, const Half, const Half * const, const Index);
 
   59template void Axpy::value_cuda(float *, const float, const float * const, const Index);
 
   60template void Axpy::value_cuda(double *, const double, const double * const, const Index);