1// FEAT3: Finite Element Analysis Toolbox, Version 3
2// Copyright (C) 2010 by Stefan Turek & the FEAT group
3// FEAT3 is released under the GNU General Public License version 3,
4// see the file 'copyright.txt' in the top level directory for details.
7#include <kernel/base_header.hpp>
8#include <kernel/lafem/arch/dot_product.hpp>
9#include <kernel/lafem/arch/component_product.hpp>
10#include <kernel/util/exception.hpp>
11#include <kernel/util/cuda_util.hpp>
12#include <kernel/util/half.hpp>
14#include <kernel/util/memory_pool.hpp>
20using namespace FEAT::LAFEM;
21using namespace FEAT::LAFEM::Arch;
// DotProduct::value_cuda
// Computes the dot product of the vectors x and y (each with `size` entries) via cuBLAS
// (cublasDotEx) and returns the scalar result as DT_.
//
// The element type is checked at runtime with typeid: double, float and - only if
// FEAT_HAVE_HALFMATH is defined - Half are handled; any other DT_ ends in
// InternalError("unsupported data type!").
//
// \param x, y  input vectors handed to cublasDotEx with stride 1 (cuBLAS expects device memory)
// \param size  number of entries; passed to cuBLAS as int(size)
// \returns     the dot product of x and y
// \throws InternalError if cublasDotEx does not return CUBLAS_STATUS_SUCCESS, or if
//         cudaGetLastError() reports an error after the device synchronization
23template <typename DT_>
24DT_ DotProduct::value_cuda(const DT_ * const x, const DT_ * const y, const Index size)
  // runtime dispatch on the element type; dt / et are the cuBLAS data and execution
  // types that are handed to cublasDotEx below
28 if (typeid(DT_) == typeid(double))
33 else if (typeid(DT_) == typeid(float))
38#ifdef FEAT_HAVE_HALFMATH
39 else if (typeid(DT_) == typeid(Half))
46 throw InternalError(__func__, __FILE__, __LINE__, "unsupported data type!");
48 cublasStatus_t status;
  // cublasDotEx arguments: handle, n, x, xType, incx, y, yType, incy, result, resultType,
  // executionType - i.e. the result is stored with type dt while the reduction runs in type et.
  // NOTE(review): int(size) narrows Index to int; if Index is wider than int, sizes above
  // INT_MAX would be silently truncated - consider a range check before this call.
51 status = cublasDotEx(Util::Intern::cublas_handle, int(size), x, dt, 1, y, dt, 1, &result, dt, et);
52 if (status != CUBLAS_STATUS_SUCCESS)
53 throw InternalError(__func__, __FILE__, __LINE__, "cuda error: " + stringify(cublasGetStatusString(status)));
  // wait for the device so that asynchronous errors surface in the cudaGetLastError() check
  // below (the return value of cudaDeviceSynchronize itself is not inspected)
55 cudaDeviceSynchronize();
57 cudaError_t last_error(cudaGetLastError());
58 if (cudaSuccess != last_error)
59 throw InternalError(__func__, __FILE__, __LINE__, "CUDA error occurred in execution!\n" + stringify(cudaGetErrorString(last_error)));
// Explicit instantiations of DotProduct::value_cuda for the supported element types;
// the Half variant is only compiled when FEAT_HAVE_HALFMATH is defined.
64#ifdef FEAT_HAVE_HALFMATH
65template Half DotProduct::value_cuda(const Half * const, const Half * const, const Index);
67template float DotProduct::value_cuda(const float * const, const float * const, const Index);
68template double DotProduct::value_cuda(const double * const, const double * const, const Index);
// TripleDotProduct::value_cuda
// Computes dot(x, temp) where temp = ComponentProduct(y, z) (presumably the element-wise
// product y .* z - TODO confirm against ComponentProduct). The intermediate vector temp
// (`size` entries of DT_) lives in a buffer obtained from Util::cuda_get_static_memory, and
// the final dot product is delegated to DotProduct::value_cuda.
//
// \param x, y, z  input vectors with `size` entries (presumably device memory, since they are
//                 passed on to the CUDA component-product and cuBLAS dot-product backends)
// \param size     number of entries
// \returns        the triple dot product as DT_
// \throws InternalError propagated from DotProduct::value_cuda, or if cudaGetLastError()
//         reports an error after the final device synchronization
// NOTE(review): the scratch buffer comes from a helper named "static memory" - presumably a
// shared, reused allocation, which would make concurrent calls unsafe; confirm.
70template <typename DT_>
71DT_ TripleDotProduct::value_cuda(const DT_ * const x, const DT_ * const y, const DT_ * const z, const Index size)
73 DT_ * temp = (DT_*)Util::cuda_get_static_memory(size * sizeof(DT_));
74 ComponentProduct::value_cuda(temp, y, z, size);
75 DT_ result = DotProduct::value_cuda(x, temp, size);
  // NOTE(review): DotProduct::value_cuda already calls cudaDeviceSynchronize() and checks
  // cudaGetLastError(), so this second synchronize/check appears redundant.
77 cudaDeviceSynchronize();
79 cudaError_t last_error(cudaGetLastError());
80 if (cudaSuccess != last_error)
81 throw InternalError(__func__, __FILE__, __LINE__, "CUDA error occurred in execution!\n" + stringify(cudaGetErrorString(last_error)));
// Explicit instantiations of TripleDotProduct::value_cuda for the supported element types;
// the Half variant is only compiled when FEAT_HAVE_HALFMATH is defined.
85#ifdef FEAT_HAVE_HALFMATH
86template Half TripleDotProduct::value_cuda(const Half * const x, const Half * const y, const Half * const z, const Index size);
88template float TripleDotProduct::value_cuda(const float * const x, const float * const y, const float * const z, const Index size);
89template double TripleDotProduct::value_cuda(const double * const x, const double * const y, const double * const z, const Index size);