FEAT 3
Finite Element Analysis Toolbox
Loading...
Searching...
No Matches
norm.cu
1// FEAT3: Finite Element Analysis Toolbox, Version 3
2// Copyright (C) 2010 by Stefan Turek & the FEAT group
3// FEAT3 is released under the GNU General Public License version 3,
4// see the file 'copyright.txt' in the top level directory for details.
5
// includes, FEAT
#include <kernel/base_header.hpp>
#include <kernel/lafem/arch/norm.hpp>
#include <kernel/util/exception.hpp>
#include <kernel/util/cuda_util.hpp>
#include <kernel/util/half.hpp>

// includes, system
#include <limits>

// includes, CUDA
#include <cublas_v2.h>
15
16using namespace FEAT;
17using namespace FEAT::LAFEM;
18using namespace FEAT::LAFEM::Arch;
19
/**
 * \brief Computes the Euclidean (l2) norm of a vector via cublasNrm2Ex.
 *
 * The cuBLAS data types are chosen from \p DT_:
 * - double: storage and computation in CUDA_R_64F
 * - float:  storage and computation in CUDA_R_32F
 * - Half (only if FEAT_HAVE_HALFMATH is defined): storage in CUDA_R_16F,
 *   computation in CUDA_R_32F
 *
 * \param[in] x
 * Pointer to the vector entries, read with unit stride. It is handed directly
 * to cuBLAS, so it presumably must be a device pointer (confirm with callers).
 *
 * \param[in] size
 * Number of entries in \p x. It must fit into an int, because that is the
 * length type accepted by cublasNrm2Ex.
 *
 * \returns
 * The l2 norm of \p x. cuBLAS writes it to a local host variable, which
 * assumes the shared cuBLAS handle is in host pointer mode (the cuBLAS default).
 *
 * \throws InternalError
 * - if \p DT_ is not one of the supported types,
 * - if \p size exceeds the int range,
 * - if cuBLAS reports a non-success status,
 * - (FEAT_DEBUG_MODE only) if a CUDA error is pending after synchronization.
 */
template <typename DT_>
DT_ Norm2::value_cuda(const DT_ * const x, const Index size)
{
  // cuBLAS type of the vector/result (dt) and of the internal computation (et)
  cudaDataType dt;
  cudaDataType et;
  if (typeid(DT_) == typeid(double))
  {
    dt = CUDA_R_64F;
    et = CUDA_R_64F;
  }
  else if (typeid(DT_) == typeid(float))
  {
    dt = CUDA_R_32F;
    et = CUDA_R_32F;
  }
#ifdef FEAT_HAVE_HALFMATH
  else if (typeid(DT_) == typeid(Half))
  {
    // half precision storage, single precision accumulation
    dt = CUDA_R_16F;
    et = CUDA_R_32F;
  }
#endif
  else
    throw InternalError(__func__, __FILE__, __LINE__, "unsupported data type!");

  // cublasNrm2Ex takes the vector length as an int. A plain int(size) cast
  // would silently wrap for larger vectors and produce a wrong norm instead of
  // an error, so reject such sizes explicitly.
  if (size > Index(std::numeric_limits<int>::max()))
    throw InternalError(__func__, __FILE__, __LINE__, "vector size exceeds the int range supported by cublasNrm2Ex!");

  // placeholder value; overwritten by cuBLAS on success
  DT_ result(42.);

  const cublasStatus_t status = cublasNrm2Ex(Util::Intern::cublas_handle, int(size), x, dt, 1, &result, dt, et);
  if (status != CUBLAS_STATUS_SUCCESS)
    throw InternalError(__func__, __FILE__, __LINE__, "cuda error: " + stringify(cublasGetStatusString(status)));

  cudaDeviceSynchronize();
#ifdef FEAT_DEBUG_MODE
  cudaError_t last_error(cudaGetLastError());
  if (cudaSuccess != last_error)
    throw InternalError(__func__, __FILE__, __LINE__, "CUDA error occurred in execution!\n" + stringify(cudaGetErrorString(last_error)));
#endif
  return result;
}
60
// Explicit instantiations of Norm2::value_cuda for each supported value type.
template double Norm2::value_cuda<double>(const double * const, const Index);
template float Norm2::value_cuda<float>(const float * const, const Index);
#ifdef FEAT_HAVE_HALFMATH
// The Half variant exists only when FEAT is built with half math support.
template Half Norm2::value_cuda<Half>(const Half * const, const Index);
#endif