FEAT 3
Finite Element Analysis Toolbox
Loading...
Searching...
No Matches
axpy.cu
1// FEAT3: Finite Element Analysis Toolbox, Version 3
2// Copyright (C) 2010 by Stefan Turek & the FEAT group
3// FEAT3 is released under the GNU General Public License version 3,
4// see the file 'copyright.txt' in the top level directory for details.
5
6// includes, FEAT
7#include <kernel/base_header.hpp>
8#include <kernel/lafem/arch/axpy.hpp>
9#include <kernel/util/exception.hpp>
10#include <kernel/util/cuda_util.hpp>
11#include <kernel/util/math.hpp>
12#include <kernel/util/half.hpp>
13
14using namespace FEAT;
15using namespace FEAT::LAFEM;
16using namespace FEAT::LAFEM::Arch;
17
/**
 * \brief Computes r <- a * x + r via cuBLAS (cublasAxpyEx).
 *
 * \param[in,out] r    Array of length \p size that is updated in place; must be accessible by cuBLAS (device memory).
 * \param[in]     a    Scalar factor applied to \p x.
 * \param[in]     x    Array of length \p size; must be accessible by cuBLAS (device memory).
 * \param[in]     size Number of elements to process.
 *
 * Supported element types are double, float and, when FEAT_HAVE_HALFMATH is defined, Half.
 * Half data is processed with a CUDA_R_32F execution type, which is why the scalar is handed
 * to cuBLAS as a float in that case.
 *
 * Throws InternalError for unsupported element types or when cuBLAS reports a failure.
 * The device is synchronized before the function returns.
 */
template <typename DT_>
void Axpy::value_cuda(DT_ * r, const DT_ a, const DT_ * const x, const Index size)
{
  // dt: storage type of x and r; et: execution type, which is also the type announced for alpha
  cudaDataType dt;
  cudaDataType et;

  // Pointer to the scalar, stored in the type announced as et. For double and float this is a
  // itself. For Half the announced scalar type is CUDA_R_32F, so a float copy must be passed:
  // handing over &a would make cuBLAS read 4 bytes from a 2-byte object.
  const void * alpha = &a;
  float alpha_float(0.0f);

  if (typeid(DT_) == typeid(double))
  {
    dt = CUDA_R_64F;
    et = CUDA_R_64F;
  }
  else if (typeid(DT_) == typeid(float))
  {
    dt = CUDA_R_32F;
    et = CUDA_R_32F;
  }
#ifdef FEAT_HAVE_HALFMATH
  else if (typeid(DT_) == typeid(Half))
  {
    dt = CUDA_R_16F;
    et = CUDA_R_32F;
    alpha_float = float(a);
    alpha = &alpha_float;
  }
#endif
  else
    throw InternalError(__func__, __FILE__, __LINE__, "unsupported data type!");

  // cuBLAS takes the vector length as int. Process the vectors in chunks of at most 2^30
  // elements, so that an Index size beyond the int range is not silently truncated.
  const Index max_chunk = Index(1) << 30;
  for (Index offset(0); offset < size; offset += max_chunk)
  {
    const Index remaining = size - offset;
    const int chunk = int(remaining < max_chunk ? remaining : max_chunk);

    cublasStatus_t status = cublasAxpyEx(Util::Intern::cublas_handle, chunk, alpha, et,
      x + offset, dt, 1, r + offset, dt, 1, et);

    if (status != CUBLAS_STATUS_SUCCESS)
      throw InternalError(__func__, __FILE__, __LINE__, "cuda error: " + stringify(cublasGetStatusString(status)));
  }

  cudaDeviceSynchronize();
#ifdef FEAT_DEBUG_MODE
  cudaError_t last_error(cudaGetLastError());
  if (cudaSuccess != last_error)
    throw InternalError(__func__, __FILE__, __LINE__, "CUDA error occurred in execution!\n" + stringify(cudaGetErrorString(last_error)));
#endif
}
// Explicit instantiations of Axpy::value_cuda for double, float and, when FEAT_HAVE_HALFMATH is defined, Half.
template void Axpy::value_cuda(double *, const double, const double * const, const Index);
template void Axpy::value_cuda(float *, const float, const float * const, const Index);
#ifdef FEAT_HAVE_HALFMATH
template void Axpy::value_cuda(Half *, const Half, const Half * const, const Index);
#endif