7#ifndef KERNEL_LAFEM_ARCH_PRODUCT_MATMAT_GENERIC_HPP
8#define KERNEL_LAFEM_ARCH_PRODUCT_MATMAT_GENERIC_HPP 1
10#ifndef KERNEL_LAFEM_ARCH_PRODUCT_MATMAT_HPP
11#error "Do not include this implementation-only header file directly!"
14#include <kernel/util/math.hpp>
24 template <
typename DT_>
25 void ProductMatMat::dense_generic(DT_ * r,
const DT_ alpha,
const DT_ beta,
const DT_ *
const x,
const DT_ *
const y,
const DT_ *
const z,
const Index rows,
const Index columns,
const Index inner)
29 FEAT_PRAGMA_OMP(parallel
for)
30 for (
Index i = 0 ; i < rows ; ++i)
32 for (
Index j = 0 ; j < columns ; ++j)
35 Index xindex(i * inner);
37 for (
Index xcol(0) ; xcol < inner ; ++xcol)
39 sum = sum + x[xindex + xcol] * y[yindex + xcol * columns];
41 r[i * columns + j] = alpha * sum;
47 FEAT_PRAGMA_OMP(parallel
for)
48 for (
Index i = 0 ; i < rows ; ++i)
50 for (
Index j = 0 ; j < columns ; ++j)
53 Index xindex(i * inner);
55 for (
Index xcol(0) ; xcol < inner ; ++xcol)
57 sum = sum + x[xindex + xcol] * y[yindex + xcol * columns];
59 r[i * columns + j] = beta * z[i * columns + j] + alpha * sum;
65 template <
typename DT_,
typename IT_>
66 void ProductMatMat::dsd_generic(DT_ * r,
const DT_ alpha,
const DT_ beta,
const DT_ *
const val,
const IT_ *
const col_ind,
const IT_ *
const row_ptr,
const Index ,
67 const DT_ *
const y,
const Index rows,
const Index columns,
const Index )
71 FEAT_PRAGMA_OMP(parallel
for)
72 for (
Index i = 0 ; i < rows ; ++i)
74 for (
Index j = 0 ; j < columns ; ++j)
77 Index xindex = row_ptr[i];
79 for (
Index tmp = xindex ; tmp < row_ptr[i+1] ; ++tmp)
81 sum = sum + val[tmp] * y[yindex + col_ind[tmp] * columns];
83 r[i * columns + j] = alpha * sum;
89 FEAT_PRAGMA_OMP(parallel
for)
90 for (
Index i = 0 ; i < rows ; ++i)
92 for (
Index j = 0 ; j < columns ; ++j)
95 Index xindex = row_ptr[i];
97 for (
Index tmp = xindex ; tmp < row_ptr[i+1] ; ++tmp)
99 sum = sum + val[tmp] * y[yindex + col_ind[tmp] * columns];
101 r[i * columns + j] = beta * r[i * columns + j] + alpha * sum;
// Doxygen tooltip residue (symbol summaries), preserved as a comment:
//   T_ abs(T_ x)         -- Returns the absolute value.
//   std::uint64_t Index  -- Index data type.