// FEAT3: Finite Element Analysis Toolbox, Version 3
// Copyright (C) 2010 by Stefan Turek & the FEAT group
// FEAT3 is released under the GNU General Public License version 3,
// see the file 'copyright.txt' in the top level directory for details.
 
    7#include <kernel/base_header.hpp>
 
    8#include <kernel/lafem/arch/component_invert.hpp>
 
    9#include <kernel/util/exception.hpp>
 
   10#include <kernel/util/memory_pool.hpp>
 
   11#include <kernel/util/half.hpp>
 
   // Generic element-wise kernel: each thread handles the element at one flat index.
   // NOTE(review): this listing shows only the signature and the index computation; the
   // braces, any bounds check against `count` and the store into `r` are not visible here
   // (presumably r[idx] = s / x[idx], as in the Half overload below) -- confirm against VCS.
   19      template <typename DT_>

   20      __global__ void cuda_component_invert(DT_ * r, const DT_ * x, const DT_ s, const Index count)

   // Flat global thread index for a 1-D launch. NOTE(review): threadIdx/blockDim/blockIdx
   // components are unsigned int, so blockDim.x * blockIdx.x is evaluated in 32-bit unsigned
   // arithmetic before conversion to Index and would wrap for 2^32 or more threads.
   22        Index idx = threadIdx.x + blockDim.x * blockIdx.x;

   // Half-precision overload, compiled only when FEAT_HAVE_HALFMATH is defined.
   28#ifdef FEAT_HAVE_HALFMATH

   29      __global__ void cuda_component_invert(Half * r, const Half * x, const Half s, const Index count)

   31        Index idx = threadIdx.x + blockDim.x * blockIdx.x;

   // NOTE(review): no bounds check of idx against count is visible in this listing; one
   // must precede the store below, because the host launcher rounds the grid size up.
   // The quotient s / x[idx] is converted to float with __half2float and converted back
   // to Half when assigned to r[idx] (r is Half *). Every half value is exactly
   // representable as a float, so this round trip should not change the stored result.
   34        ///\todo skip conversion step

   35        r[idx] = __half2float(s / x[idx]);
 
   44using namespace FEAT::LAFEM;
 
   45using namespace FEAT::LAFEM::Arch;
 
   /**
    * \brief Launches the component-invert kernel on the CUDA device.
    *
    * Every index in [0, size) is covered by one thread of a 1-D grid whose block size is
    * Util::cuda_blocksize_axpy. The per-element operation lives in
    * FEAT::LAFEM::Intern::cuda_component_invert.
    * NOTE(review): the visible Half overload of that kernel stores s / x[idx]; presumably
    * the result is r[i] = s / x[i] for all element types -- confirm.
    *
    * \param[out] r    Output array of size entries; must be device-accessible (it is passed to a kernel).
    * \param[in]  x    Input array of size entries; must be device-accessible.
    * \param[in]  s    Scalar forwarded unchanged to the kernel.
    * \param[in]  size Number of entries to process; 0 is a no-op.
    *
    * The call blocks until the device has finished (cudaDeviceSynchronize). It throws
    * InternalError if the CUDA runtime reports an error afterwards.
    */
   template <typename DT_>
   void ComponentInvert::value_cuda(DT_ * r, const DT_ * const x, const DT_ s, const Index size)
   {
     // An empty range needs no work. Launching a grid with zero blocks is an invalid
     // launch configuration and would be reported as a CUDA error (and thrown) below.
     if (size == Index(0))
       return;

     Index blocksize = Util::cuda_blocksize_axpy;
     dim3 grid;
     dim3 block;
     block.x = (unsigned)blocksize;
     // Integer ceil-division: just enough blocks so that grid.x * block.x >= size.
     grid.x = (unsigned)((size + Index(block.x) - Index(1)) / Index(block.x));

     FEAT::LAFEM::Intern::cuda_component_invert<<<grid, block>>>(r, x, s, size);

     // Wait for the kernel so that asynchronous execution errors are also reported below.
     cudaDeviceSynchronize();

     cudaError_t last_error(cudaGetLastError());
     if (cudaSuccess != last_error)
       throw InternalError(__func__, __FILE__, __LINE__, "CUDA error occurred in execution!\n" + stringify(cudaGetErrorString(last_error)));
   }
 
   // Explicit instantiations of ComponentInvert::value_cuda for the supported element types.
   // The Half variant exists only when FEAT_HAVE_HALFMATH is defined; float and double are
   // always instantiated, so the conditional section is closed right after the Half line.
   #ifdef FEAT_HAVE_HALFMATH
   template void ComponentInvert::value_cuda(Half *, const Half * const, const Half, const Index);
   #endif
   template void ComponentInvert::value_cuda(float *, const float * const, const float, const Index);
   template void ComponentInvert::value_cuda(double *, const double * const, const double, const Index);