FEAT 3
Finite Element Analysis Toolbox
Source listing: product_matmat.hpp
1// FEAT3: Finite Element Analysis Toolbox, Version 3
2// Copyright (C) 2010 by Stefan Turek & the FEAT group
3// FEAT3 is released under the GNU General Public License version 3,
4// see the file 'copyright.txt' in the top level directory for details.
5
6#pragma once
7#ifndef KERNEL_LAFEM_ARCH_PRODUCT_MATMAT_HPP
8#define KERNEL_LAFEM_ARCH_PRODUCT_MATMAT_HPP 1
9
10// includes, FEAT
12#include <kernel/backend.hpp>
13#include <kernel/util/half.hpp>
14
15namespace FEAT
16{
17 namespace LAFEM
18 {
19 namespace Arch
20 {
22 {
23 template <typename DT_>
24 static void dense(DT_ * r, const DT_ alpha, const DT_ beta, const DT_ * const x, const DT_ * const y, const DT_ * const z, const Index rows, const Index columns, const Index inner)
25 {
26 dense_generic(r, alpha, beta, x, y, z, rows, columns, inner);
27 }
28
#ifdef FEAT_HAVE_HALFMATH
/// Half-precision overload, compiled only when FEAT_HAVE_HALFMATH is defined.
///
/// Dispatch goes through BACKEND_SKELETON_VOID with dense_cuda as the CUDA
/// candidate. dense_generic fills both remaining slots because dense_mkl is
/// declared only for float and double.
static void dense(Half * r, const Half alpha, const Half beta,
  const Half * const x, const Half * const y, const Half * const z,
  const Index rows, const Index columns, const Index inner)
{
  BACKEND_SKELETON_VOID(dense_cuda, dense_generic, dense_generic,
    r, alpha, beta, x, y, z, rows, columns, inner)
}
#endif // FEAT_HAVE_HALFMATH
35
36 static void dense(float * r, const float alpha, const float beta, const float * const x, const float * const y, const float * const z, const Index rows, const Index columns, const Index inner)
37 {
38 BACKEND_SKELETON_VOID(dense_cuda, dense_mkl, dense_generic, r, alpha, beta, x, y, z, rows, columns, inner)
39 }
40
41 static void dense(double * r, const double alpha, const double beta, const double * const x, const double * const y, const double * const z, const Index rows, const Index columns, const Index inner)
42 {
43 BACKEND_SKELETON_VOID(dense_cuda, dense_mkl, dense_generic, r, alpha, beta, x, y, z, rows, columns, inner)
44 }
45
46 template <typename DT_, typename IT_>
47 static void dsd(DT_ * r, const DT_ alpha, const DT_ beta, const DT_ * const val, const IT_ * const col_ind, const IT_ * const row_ptr, const Index used_elements,
48 const DT_ * const y, const Index rows, const Index columns, const Index inner)
49 {
50 dsd_generic(r, alpha, beta, val, col_ind, row_ptr, used_elements, y, rows, columns, inner);
51 }
52
#ifdef FEAT_HAVE_HALFMATH
/// Half-precision / 64-bit index overload (requires FEAT_HAVE_HALFMATH).
///
/// dsd_cuda is the CUDA candidate; dsd_generic occupies both other slots
/// of BACKEND_SKELETON_VOID.
static void dsd(Half * r, const Half alpha, const Half beta,
  const Half * const val,
  const std::uint64_t * const col_ind, const std::uint64_t * const row_ptr,
  const Index used_elements,
  const Half * const y,
  const Index rows, const Index columns, const Index inner)
{
  BACKEND_SKELETON_VOID(dsd_cuda, dsd_generic, dsd_generic,
    r, alpha, beta, val, col_ind, row_ptr, used_elements,
    y, rows, columns, inner)
}
#endif // FEAT_HAVE_HALFMATH
60
61 static void dsd(float * r, const float alpha, const float beta, const float * const val, const std::uint64_t * const col_ind, const std::uint64_t * const row_ptr, const Index used_elements,
62 const float * const y, const Index rows, const Index columns, const Index inner)
63 {
64 BACKEND_SKELETON_VOID(dsd_cuda, dsd_generic, dsd_generic, r, alpha, beta, val, col_ind, row_ptr, used_elements, y, rows, columns, inner)
65 }
66
67 static void dsd(double * r, const double alpha, const double beta, const double * const val, const std::uint64_t * const col_ind, const std::uint64_t * const row_ptr, const Index used_elements,
68 const double * const y, const Index rows, const Index columns, const Index inner)
69 {
70 BACKEND_SKELETON_VOID(dsd_cuda, dsd_generic, dsd_generic, r, alpha, beta, val, col_ind, row_ptr, used_elements, y, rows, columns, inner)
71 }
72
#ifdef FEAT_HAVE_HALFMATH
/// Half-precision / 32-bit index overload (requires FEAT_HAVE_HALFMATH).
///
/// dsd_cuda is the CUDA candidate; dsd_generic occupies both other slots
/// of BACKEND_SKELETON_VOID.
static void dsd(Half * r, const Half alpha, const Half beta,
  const Half * const val,
  const std::uint32_t * const col_ind, const std::uint32_t * const row_ptr,
  const Index used_elements,
  const Half * const y,
  const Index rows, const Index columns, const Index inner)
{
  BACKEND_SKELETON_VOID(dsd_cuda, dsd_generic, dsd_generic,
    r, alpha, beta, val, col_ind, row_ptr, used_elements,
    y, rows, columns, inner)
}
#endif // FEAT_HAVE_HALFMATH
80
81 static void dsd(float * r, const float alpha, const float beta, const float * const val, const std::uint32_t * const col_ind, const std::uint32_t * const row_ptr, const Index used_elements,
82 const float * const y, const Index rows, const Index columns, const Index inner)
83 {
84 BACKEND_SKELETON_VOID(dsd_cuda, dsd_generic, dsd_generic, r, alpha, beta, val, col_ind, row_ptr, used_elements, y, rows, columns, inner)
85 }
86
87 static void dsd(double * r, const double alpha, const double beta, const double * const val, const std::uint32_t * const col_ind, const std::uint32_t * const row_ptr, const Index used_elements,
88 const double * const y, const Index rows, const Index columns, const Index inner)
89 {
90 BACKEND_SKELETON_VOID(dsd_cuda, dsd_generic, dsd_generic, r, alpha, beta, val, col_ind, row_ptr, used_elements, y, rows, columns, inner)
91 }
92
93 template <typename DT_>
94 static void dense_generic(DT_ * r, const DT_ alpha, const DT_ beta, const DT_ * const x, const DT_ * const y, const DT_ * const z, const Index rows, const Index columns, const Index inner);
95
96 static void dense_mkl(float * r, const float alpha, const float beta, const float * const x, const float * const y, const float * const z, const Index rows, const Index columns, const Index inner);
97 static void dense_mkl(double * r, const double alpha, const double beta, const double * const x, const double * const y, const double * const z, const Index rows, const Index columns, const Index inner);
98
99 template <typename DT_>
100 static void dense_cuda(DT_ * r, const DT_ alpha, const DT_ beta, const DT_ * const x, const DT_ * const y, const DT_ * const z, const Index rows, const Index columns, const Index inner);
101
102 template <typename DT_, typename IT_>
103 static void dsd_generic(DT_ * r, const DT_ alpha, const DT_ beta, const DT_ * const val, const IT_ * const col_ind, const IT_ * const row_ptr, const Index used_elements,
104 const DT_ * const y, const Index rows, const Index columns, const Index inner);
105
106 template <typename DT_, typename IT_>
107 static void dsd_cuda(DT_ * r, const DT_ alpha, const DT_ beta, const DT_ * const val, const IT_ * const col_ind, const IT_ * const row_ptr, const Index used_elements,
108 const DT_ * const y, const Index rows, const Index columns, const Index inner);
109
110 };
111
#ifdef FEAT_EICKT
// Explicit instantiation declarations for the float/double generic kernels.
// With FEAT_EICKT defined, including translation units do not instantiate these
// specialisations themselves. The matching explicit instantiation definitions
// are presumably compiled elsewhere (not visible in this header).

// dense_generic: float, double
extern template void ProductMatMat::dense_generic<float>(
  float *, const float, const float,
  const float * const, const float * const, const float * const,
  const Index, const Index, const Index);
extern template void ProductMatMat::dense_generic<double>(
  double *, const double, const double,
  const double * const, const double * const, const double * const,
  const Index, const Index, const Index);

// dsd_generic: 64-bit indices
extern template void ProductMatMat::dsd_generic<float, std::uint64_t>(
  float *, const float, const float, const float * const,
  const std::uint64_t * const, const std::uint64_t * const, const Index,
  const float * const, const Index, const Index, const Index);
extern template void ProductMatMat::dsd_generic<double, std::uint64_t>(
  double *, const double, const double, const double * const,
  const std::uint64_t * const, const std::uint64_t * const, const Index,
  const double * const, const Index, const Index, const Index);

// dsd_generic: 32-bit indices
extern template void ProductMatMat::dsd_generic<float, std::uint32_t>(
  float *, const float, const float, const float * const,
  const std::uint32_t * const, const std::uint32_t * const, const Index,
  const float * const, const Index, const Index, const Index);
extern template void ProductMatMat::dsd_generic<double, std::uint32_t>(
  double *, const double, const double, const double * const,
  const std::uint32_t * const, const std::uint32_t * const, const Index,
  const double * const, const Index, const Index, const Index);
#endif // FEAT_EICKT
122
123 } // namespace Arch
124 } // namespace LAFEM
125} // namespace FEAT
126
127#ifndef __CUDACC__
128#include <kernel/lafem/arch/product_matmat_generic.hpp>
129#endif
130#endif // KERNEL_LAFEM_ARCH_PRODUCT_MATMAT_HPP
Referenced symbols:
- FEAT Kernel base header.
- FEAT: FEAT namespace (definition: adjactor.hpp:12).
- Half (alias of __half): Half data type (definition: half.hpp:25).
- Index (alias of std::uint64_t): Index data type.