FEAT 3
Finite Element Analysis Toolbox
Loading...
Searching...
No Matches
scale.cu
1// FEAT3: Finite Element Analysis Toolbox, Version 3
2// Copyright (C) 2010 by Stefan Turek & the FEAT group
3// FEAT3 is released under the GNU General Public License version 3,
4// see the file 'copyright.txt' in the top level directory for details.
5
6// includes, FEAT
7#include <kernel/base_header.hpp>
8#include <kernel/lafem/arch/scale.hpp>
9#include <kernel/util/exception.hpp>
10#include <kernel/util/cuda_util.hpp>
11#include <kernel/util/half.hpp>
12
13using namespace FEAT;
14using namespace FEAT::LAFEM;
15using namespace FEAT::LAFEM::Arch;
16
/**
 * \brief Computes r = s * x using cuBLAS.
 *
 * If \p r and \p x differ, \p x is first copied into \p r with cudaMemcpy
 * (cudaMemcpyDefault, i.e. the copy direction is inferred from the pointers).
 * The scaling itself is then performed in place on \p r via cublasScalEx with unit stride.
 * Afterwards the device is synchronized. In FEAT_DEBUG_MODE the last CUDA error is also checked.
 *
 * Supported data types are double, float and, with FEAT_HAVE_HALFMATH, Half.
 * Half data is processed with single precision execution. In that case cuBLAS expects
 * the scalar in the execution type (float), so \p s is converted before the call.
 *
 * \param[out] r    result vector with \p size entries (presumably device memory -- confirm with callers)
 * \param[in]  x    input vector with \p size entries; may be identical to \p r
 * \param[in]  s    scaling factor
 * \param[in]  size number of entries; must be representable as int, because cuBLAS takes an int length
 *
 * \throws InternalError for unsupported data types, sizes exceeding the int range,
 *         a failing copy or a cuBLAS error
 */
template <typename DT_>
void Scale::value_cuda(DT_ * r, const DT_ * const x, const DT_ s, const Index size)
{
  cudaDataType dt;
  cudaDataType et;
  // cuBLAS interprets alpha with the execution type 'et'; for double/float that is DT_ itself
  const void * alpha = &s;
  // single precision copy of s, required when the execution type is wider than DT_ (Half)
  float s_single(0.0f);

  if (typeid(DT_) == typeid(double))
  {
    dt = CUDA_R_64F;
    et = CUDA_R_64F;
  }
  else if (typeid(DT_) == typeid(float))
  {
    dt = CUDA_R_32F;
    et = CUDA_R_32F;
  }
#ifdef FEAT_HAVE_HALFMATH
  else if (typeid(DT_) == typeid(Half))
  {
    dt = CUDA_R_16F;
    et = CUDA_R_32F;
    // alpha must be handed over as CUDA_R_32F: passing &s (a 2 byte Half) would make
    // cuBLAS read 4 bytes and use a garbage scaling factor
    s_single = float(s);
    alpha = &s_single;
  }
#endif
  else
    throw InternalError(__func__, __FILE__, __LINE__, "unsupported data type!");

  // cuBLAS takes the vector length as int; reject sizes that would be silently truncated
  const int n = int(size);
  if (n < 0 || Index(n) != size)
    throw InternalError(__func__, __FILE__, __LINE__, "vector size exceeds cuBLAS int range!");

  if (r != x)
  {
    ///\todo use cublasCopyEx when available
    cudaError_t copy_status = cudaMemcpy(r, x, size * sizeof(DT_), cudaMemcpyDefault);
    if (copy_status != cudaSuccess)
      throw InternalError(__func__, __FILE__, __LINE__, "cuda error: " + stringify(cudaGetErrorString(copy_status)));
  }

  cublasStatus_t status = cublasScalEx(Util::Intern::cublas_handle, n, alpha, et, r, dt, 1, et);
  if (status != CUBLAS_STATUS_SUCCESS)
    throw InternalError(__func__, __FILE__, __LINE__, "cuda error: " + stringify(cublasGetStatusString(status)));

  cudaDeviceSynchronize();
#ifdef FEAT_DEBUG_MODE
  cudaError_t last_error(cudaGetLastError());
  if (cudaSuccess != last_error)
    throw InternalError(__func__, __FILE__, __LINE__, "CUDA error occurred in execution!\n" + stringify(cudaGetErrorString(last_error)));
#endif
}
// explicit instantiations of Scale::value_cuda for every supported data type
template void Scale::value_cuda<double>(double *, const double * const, const double, const Index);
template void Scale::value_cuda<float>(float *, const float * const, const float, const Index);
#ifdef FEAT_HAVE_HALFMATH
template void Scale::value_cuda<Half>(Half *, const Half * const, const Half, const Index);
#endif