7#ifndef KERNEL_LAFEM_ARCH_SCALE_ROW_COL_GENERIC_HPP
8#define KERNEL_LAFEM_ARCH_SCALE_ROW_COL_GENERIC_HPP 1
10#ifndef KERNEL_LAFEM_ARCH_SCALE_ROW_COL_HPP
11#error "Do not include this implementation-only header file directly!"
14#include <kernel/util/tiny_algebra.hpp>
/**
 * \brief Generic CSR kernel of ScaleRows.
 *
 * Walks the CSR structure row by row: the stored entries of row `row`
 * occupy the index range [row_ptr[row], row_ptr[row + 1]).
 *
 * \tparam DT_ Value type of r, a and x.
 * \tparam IT_ Index type of the CSR index arrays.
 *
 * \param r       Result value array (the only non-const pointer argument).
 * \param a       Input CSR value array.
 * \param row_ptr CSR row-pointer array.
 * \param x       Scaling vector.
 *
 * The third parameter is unnamed, so this kernel cannot reference it.
 *
 * NOTE(review): this listing is incomplete. The declaration of `rows`
 * (presumably a parameter following x) and the statement inside the inner
 * loop are not visible here. Presumably that statement is
 * `r[i] = a[i] * x[row];`, which would scale every entry of row `row` by
 * x[row]. Confirm against the repository before relying on this.
 */
23 template <
typename DT_,
typename IT_>
24 void ScaleRows::csr_generic(DT_ * r,
const DT_ *
const a,
// third parameter: unnamed, hence unused by this kernel
const IT_ *
const ,
25 const IT_ *
const row_ptr,
const DT_ *
const x,
// Outer loop over rows. FEAT_PRAGMA_OMP(parallel for) presumably expands to an
// OpenMP parallel-for when OpenMP is enabled -- confirm in the macro definition.
28 FEAT_PRAGMA_OMP(parallel
for)
29 for (
Index row = 0 ; row < rows ; ++row)
// `end` is one past the last stored entry of this row.
31 const IT_ end(row_ptr[row + 1]);
32 for (IT_ i = row_ptr[row] ; i < end ; ++i)
// NOTE(review): the per-entry statement of this loop is missing from the listing (see header note).
/**
 * \brief Generic BCSR kernel of ScaleRows for bh_ x bw_ blocks.
 *
 * r and a are reinterpreted as arrays of Tiny::Matrix<DT_, bh_, bw_> blocks.
 * x is reinterpreted as an array of Tiny::Vector<DT_, bh_>, one vector per
 * block row. For every stored block i of block row `row`, each entry in
 * local row irow is set to ba[i][irow][icol] * bx[row][irow]. That is, every
 * scalar row of the matrix is scaled by its matching component of x.
 *
 * \tparam bh_ Block height (number of local rows per block).
 * \tparam bw_ Block width (number of local columns per block).
 * \tparam DT_ Value type of r, a and x.
 * \tparam IT_ Index type of the BCSR index arrays.
 *
 * \param r       Result block value array (written block by block).
 * \param a       Input block value array.
 * \param row_ptr Block row-pointer array. Block row `row` owns blocks
 *                [row_ptr[row], row_ptr[row + 1]).
 * \param x       Scaling vector, read in chunks of bh_ values.
 *
 * The third parameter is unnamed, so this kernel cannot reference it.
 *
 * NOTE(review): the reinterpret_casts assume that the raw DT_ arrays have
 * exactly the layout of contiguous Tiny::Matrix / Tiny::Vector objects.
 * Confirm this assumption. The declaration of `rows` (presumably a parameter
 * after x) is not visible in this listing.
 */
39 template <
int bh_,
int bw_,
typename DT_,
typename IT_>
40 void ScaleRows::bcsr_generic(DT_ * r,
const DT_ *
const a,
// third parameter: unnamed, hence unused by this kernel
const IT_ *
const ,
41 const IT_ *
const row_ptr,
const DT_ *
const x,
// Block views of the raw value arrays.
45 Tiny::Matrix<DT_, bh_, bw_> *
const br(
reinterpret_cast<Tiny::Matrix<DT_, bh_, bw_> *
>(r));
46 const Tiny::Matrix<DT_, bh_, bw_> *
const ba(
reinterpret_cast<const Tiny::Matrix<DT_, bh_, bw_> *
>(a));
// x is viewed in chunks of bh_ values, one chunk per block row.
47 const Tiny::Vector<DT_, bh_> *
const bx(
reinterpret_cast<const Tiny::Vector<DT_, bh_> *
>(x));
// Each outer iteration writes only the blocks of its own block row, so rows are
// processed independently (assuming row_ptr is non-decreasing).
// FEAT_PRAGMA_OMP presumably expands to an OpenMP pragma when enabled -- confirm.
48 FEAT_PRAGMA_OMP(parallel
for)
49 for (
Index row = 0 ; row < rows ; ++row)
51 const IT_ end(row_ptr[row + 1]);
52 for (IT_ i = row_ptr[row] ; i < end ; ++i)
54 for (
int irow(0); irow < bh_; ++ irow )
56 for (
int icol(0); icol < bw_; ++icol)
// Scale by the x component belonging to this scalar row (block row `row`, local row irow).
58 br[i][irow][icol] = ba[i][irow][icol] * bx[row][irow];
/**
 * \brief Generic CSR kernel of ScaleCols.
 *
 * For every stored entry i of every row, sets r[i] = a[i] * x[col_ind[i]].
 * Each entry is scaled by the component of x that belongs to its column.
 *
 * \tparam DT_ Value type of r, a and x.
 * \tparam IT_ Index type of the CSR index arrays.
 *
 * \param r       Result value array, indexed like a.
 * \param a       Input CSR value array.
 * \param col_ind CSR column-index array.
 * \param row_ptr CSR row-pointer array. Row `row` owns entries
 *                [row_ptr[row], row_ptr[row + 1]).
 * \param x       Scaling vector, indexed by column.
 *
 * NOTE(review): the declaration of `rows` (presumably a parameter after x)
 * is not visible in this listing.
 */
67 template <
typename DT_,
typename IT_>
68 void ScaleCols::csr_generic(DT_ * r,
const DT_ *
const a,
const IT_ *
const col_ind,
69 const IT_ *
const row_ptr,
const DT_ *
const x,
// Each outer iteration writes only r entries in its own row range, so rows are
// processed independently (assuming row_ptr is non-decreasing).
// FEAT_PRAGMA_OMP presumably expands to an OpenMP pragma when enabled -- confirm.
72 FEAT_PRAGMA_OMP(parallel
for)
73 for (
Index row = 0 ; row < rows ; ++row)
// `end` is one past the last stored entry of this row.
75 const IT_ end(row_ptr[row + 1]);
76 for (IT_ i = row_ptr[row] ; i < end ; ++i)
78 r[i] = a[i] * x[col_ind[i]];
/**
 * \brief Generic BCSR kernel of ScaleCols for bh_ x bw_ blocks.
 *
 * r and a are reinterpreted as arrays of Tiny::Matrix<DT_, bh_, bw_> blocks.
 * x is reinterpreted as an array of Tiny::Vector<DT_, bw_>, one vector per
 * block column. For every stored block i, each entry in local column icol is
 * set to ba[i][irow][icol] * bx[col_ind[i]][icol]. That is, every scalar
 * column of the matrix is scaled by its matching component of x.
 *
 * \tparam bh_ Block height (number of local rows per block).
 * \tparam bw_ Block width (number of local columns per block).
 * \tparam DT_ Value type of r, a and x.
 * \tparam IT_ Index type of the BCSR index arrays.
 *
 * \param r       Result block value array (written block by block).
 * \param a       Input block value array.
 * \param col_ind Block column index of each stored block.
 * \param row_ptr Block row-pointer array. Block row `row` owns blocks
 *                [row_ptr[row], row_ptr[row + 1]).
 * \param x       Scaling vector, read in chunks of bw_ values.
 *
 * NOTE(review): the reinterpret_casts assume that the raw DT_ arrays have
 * exactly the layout of contiguous Tiny::Matrix / Tiny::Vector objects.
 * Confirm this assumption. The declaration of `rows` (presumably a parameter
 * after x) is not visible in this listing.
 */
83 template <
int bh_,
int bw_,
typename DT_,
typename IT_>
84 void ScaleCols::bcsr_generic(DT_ * r,
const DT_ *
const a,
const IT_ *
const col_ind,
85 const IT_ *
const row_ptr,
const DT_ *
const x,
// Block views of the raw value arrays.
89 Tiny::Matrix<DT_, bh_, bw_> *
const br(
reinterpret_cast<Tiny::Matrix<DT_, bh_, bw_> *
>(r));
90 const Tiny::Matrix<DT_, bh_, bw_> *
const ba(
reinterpret_cast<const Tiny::Matrix<DT_, bh_, bw_> *
>(a));
// x is viewed in chunks of bw_ values, one chunk per block column.
91 const Tiny::Vector<DT_, bw_> *
const bx(
reinterpret_cast<const Tiny::Vector<DT_, bw_> *
>(x));
// Each outer iteration writes only the blocks of its own block row, so rows are
// processed independently (assuming row_ptr is non-decreasing).
// FEAT_PRAGMA_OMP presumably expands to an OpenMP pragma when enabled -- confirm.
92 FEAT_PRAGMA_OMP(parallel
for)
93 for (
Index row = 0 ; row < rows ; ++row)
95 const IT_ end(row_ptr[row + 1]);
96 for (IT_ i = row_ptr[row] ; i < end ; ++i)
98 for (
int irow = 0; irow < bh_; ++irow)
100 for (
int icol = 0; icol < bw_; ++icol)
// Scale by the x component belonging to this scalar column (block column col_ind[i], local column icol).
102 br[i][irow][icol] = ba[i][irow][icol] * bx[col_ind[i]][icol];
std::uint64_t Index
Index data type.