FEAT 3
Finite Element Analysis Toolbox
Loading...
Searching...
No Matches
dot_product.cu
1// FEAT3: Finite Element Analysis Toolbox, Version 3
2// Copyright (C) 2010 by Stefan Turek & the FEAT group
3// FEAT3 is released under the GNU General Public License version 3,
4// see the file 'copyright.txt' in the top level directory for details.
5
6// includes, FEAT
7#include <kernel/base_header.hpp>
8#include <kernel/lafem/arch/dot_product.hpp>
9#include <kernel/lafem/arch/component_product.hpp>
10#include <kernel/util/exception.hpp>
11#include <kernel/util/cuda_util.hpp>
12#include <kernel/util/half.hpp>
13
14#include <kernel/util/memory_pool.hpp>
15
16// includes, CUDA
17#include <cublas_v2.h>
18
19using namespace FEAT;
20using namespace FEAT::LAFEM;
21using namespace FEAT::LAFEM::Arch;
22
/**
 * \brief Computes the dot product of two arrays by calling cublasDotEx.
 *
 * The cuBLAS data type and execution type are selected from DT_:
 * double -> CUDA_R_64F / CUDA_R_64F, float -> CUDA_R_32F / CUDA_R_32F and,
 * if FEAT_HAVE_HALFMATH is defined, Half -> CUDA_R_16F / CUDA_R_32F.
 *
 * \param[in] x
 * First input array; passed to cublasDotEx unchanged with increment 1.
 *
 * \param[in] y
 * Second input array; passed to cublasDotEx unchanged with increment 1.
 *
 * \param[in] size
 * Number of entries. cublasDotEx takes the length as an \c int, so the value
 * must be representable as a non-negative \c int.
 *
 * \returns The value that cublasDotEx wrote into the local result variable.
 *
 * \throws InternalError
 * if DT_ is not supported, if \p size does not fit into an \c int, if cuBLAS
 * reports a failure, or (FEAT_DEBUG_MODE only) if cudaGetLastError reports an
 * error after the device synchronization.
 */
template <typename DT_>
DT_ DotProduct::value_cuda(const DT_ * const x, const DT_ * const y, const Index size)
{
  cudaDataType dt;
  cudaDataType et;
  if (typeid(DT_) == typeid(double))
  {
    dt = CUDA_R_64F;
    et = CUDA_R_64F;
  }
  else if (typeid(DT_) == typeid(float))
  {
    dt = CUDA_R_32F;
    et = CUDA_R_32F;
  }
#ifdef FEAT_HAVE_HALFMATH
  else if (typeid(DT_) == typeid(Half))
  {
    // half precision storage, but single precision execution type
    dt = CUDA_R_16F;
    et = CUDA_R_32F;
  }
#endif
  else
    throw InternalError(__func__, __FILE__, __LINE__, "unsupported data type!");

  // cublasDotEx takes the vector length as 'int'. Reject sizes that do not
  // survive the narrowing conversion instead of silently passing a truncated
  // or negative length to cuBLAS.
  const int n = int(size);
  if ((n < 0) || (Index(n) != size))
    throw InternalError(__func__, __FILE__, __LINE__, "vector size exceeds the int range supported by cublasDotEx!");

  cublasStatus_t status;
  DT_ result(0.);

  // the result is written to a host-side local variable
  // NOTE(review): this presumably relies on the cuBLAS handle being in host pointer mode -- confirm
  status = cublasDotEx(Util::Intern::cublas_handle, n, x, dt, 1, y, dt, 1, &result, dt, et);
  if (status != CUBLAS_STATUS_SUCCESS)
    throw InternalError(__func__, __FILE__, __LINE__, "cuda error: " + stringify(cublasGetStatusString(status)));

  cudaDeviceSynchronize();
#ifdef FEAT_DEBUG_MODE
  // debug builds only: surface any error reported by the CUDA runtime
  cudaError_t last_error(cudaGetLastError());
  if (cudaSuccess != last_error)
    throw InternalError(__func__, __FILE__, __LINE__, "CUDA error occurred in execution!\n" + stringify(cudaGetErrorString(last_error)));
#endif

  return result;
}
// explicit instantiations for the supported data types
#ifdef FEAT_HAVE_HALFMATH
template Half DotProduct::value_cuda(const Half * const, const Half * const, const Index);
#endif
template float DotProduct::value_cuda(const float * const, const float * const, const Index);
template double DotProduct::value_cuda(const double * const, const double * const, const Index);
69
/**
 * \brief Computes a triple dot product of three arrays in two steps.
 *
 * First, ComponentProduct::value_cuda is called with a scratch buffer as
 * output and \p y, \p z as inputs (presumably storing their component-wise
 * product -- see ComponentProduct). Second, DotProduct::value_cuda is called
 * on \p x and the scratch buffer, and its result is returned.
 *
 * The scratch buffer is requested from Util::cuda_get_static_memory with
 * size * sizeof(DT_) bytes; it is not released by this function.
 *
 * \param[in] x
 * First input array; forwarded to DotProduct::value_cuda.
 *
 * \param[in] y
 * Second input array; forwarded to ComponentProduct::value_cuda.
 *
 * \param[in] z
 * Third input array; forwarded to ComponentProduct::value_cuda.
 *
 * \param[in] size
 * Number of entries of each array.
 *
 * \returns The value returned by DotProduct::value_cuda.
 *
 * \throws InternalError
 * (FEAT_DEBUG_MODE only) if cudaGetLastError reports an error after the
 * device synchronization; errors thrown by the called functions propagate.
 */
template <typename DT_>
DT_ TripleDotProduct::value_cuda(const DT_ * const x, const DT_ * const y, const DT_ * const z, const Index size)
{
  // scratch buffer with room for 'size' entries of type DT_
  DT_ * scratch = (DT_*)Util::cuda_get_static_memory(size * sizeof(DT_));

  // step 1: combine y and z into the scratch buffer
  ComponentProduct::value_cuda(scratch, y, z, size);

  // step 2: dot product of x with the combined buffer
  DT_ dot_value = DotProduct::value_cuda(x, scratch, size);

  cudaDeviceSynchronize();
#ifdef FEAT_DEBUG_MODE
  // debug builds only: surface any error reported by the CUDA runtime
  const cudaError_t sync_error = cudaGetLastError();
  if (sync_error != cudaSuccess)
    throw InternalError(__func__, __FILE__, __LINE__, "CUDA error occurred in execution!\n" + stringify(cudaGetErrorString(sync_error)));
#endif

  return dot_value;
}
// explicit instantiations for the supported data types
#ifdef FEAT_HAVE_HALFMATH
template Half TripleDotProduct::value_cuda(const Half * const, const Half * const, const Half * const, const Index);
#endif
template float TripleDotProduct::value_cuda(const float * const, const float * const, const float * const, const Index);
template double TripleDotProduct::value_cuda(const double * const, const double * const, const double * const, const Index);