35 #ifndef NEKTAR_LIB_UTILITIES_LINEAR_ALGEBRA_MATRIX_OPERATIONS_DECLARATIONS_HPP
36 #define NEKTAR_LIB_UTILITIES_LINEAR_ALGEBRA_MATRIX_OPERATIONS_DECLARATIONS_HPP
38 #include <boost/core/ignore_unused.hpp>
52 #include <type_traits>
// Declares the Multiply overload that takes a NekMatrix<LhsDataType,
// MatrixType> `lhs` and a NekVector<DataType> `rhs` and returns a new
// NekVector<DataType>; presumably the matrix-vector product lhs * rhs
// (the definition is not part of this listing).
// NOTE(review): this listing looks like a source-browser extraction: original
// line numbers are fused into the code (the leading `60`, `61` below) and
// some lines are missing throughout. Restore from the upstream header before
// compiling.
template <
typename DataType,
typename LhsDataType,
typename MatrixType>
60 NekVector<DataType>
Multiply(
const NekMatrix<LhsDataType, MatrixType> &lhs,
61 const NekVector<DataType> &rhs);
// Declares the out-parameter form of matrix * vector. It reads the
// NekMatrix<LhsDataType, MatrixType> `lhs` and the NekVector<DataType> `rhs`,
// writes into the caller-supplied `result`, and returns nothing. Presumably
// result = lhs * rhs; the definition is not shown here.
template <
typename DataType,
typename LhsDataType,
typename MatrixType>
64 void Multiply(NekVector<DataType> &result,
65 const NekMatrix<LhsDataType, MatrixType> &lhs,
66 const NekVector<DataType> &rhs);
// Declares the Multiply overload restricted (through the BlockMatrixTag tag)
// to a block matrix `lhs` whose blocks have type LhsInnerMatrixType. It is
// multiplied by the NekVector<DataType> `rhs` and the result is written into
// `result`.
template <
typename DataType,
typename LhsInnerMatrixType>
69 void Multiply(NekVector<DataType> &result,
70 const NekMatrix<LhsInnerMatrixType, BlockMatrixTag> &lhs,
71 const NekVector<DataType> &rhs);
// NOTE(review): incomplete span. The template header below (original line 73)
// has lost the declaration it introduces (original lines 74-80).
template <
typename DataType,
typename LhsDataType,
typename MatrixType>
// Fragments of the DiagonalBlockFullScalMatrixMultiply declarations. The first
// takes a block matrix of scaled NekDouble standard matrices and a
// NekVector<double>, writing into `result`. Original line 90 is presumably
// the NekSingle counterpart. The function-name lines and the remaining
// parameter lines are missing from this listing.
81 NekVector<double> &result,
83 NekMatrix<NekMatrix<NekDouble, StandardMatrixTag>, ScaledMatrixTag>,
85 const NekVector<double> &rhs);
90 NekMatrix<NekMatrix<NekSingle, StandardMatrixTag>, ScaledMatrixTag>,
// Declares Multiply of a NekMatrix<LhsDataType, LhsMatrixType> `lhs` by a
// value `rhs` of the result's element type ResultDataType. The output goes
// into the dense StandardMatrixTag matrix `result`; presumably
// result = lhs * rhs, i.e. lhs scaled by the scalar.
template <
typename ResultDataType,
typename LhsDataType,
typename LhsMatrixType>
98 void Multiply(NekMatrix<ResultDataType, StandardMatrixTag> &result,
99 const NekMatrix<LhsDataType, LhsMatrixType> &lhs,
100 const ResultDataType &rhs);
// Declares the value-returning form of matrix * scalar. `lhs` is a
// NekMatrix<LhsDataType, LhsMatrixType` and `rhs` is a value of type DataType.
// The returned matrix uses lhs's NumberType as its element type.
// NOTE(review): original line 104 (the rest of the return type) is missing.
template <
typename DataType,
typename LhsDataType,
typename LhsMatrixType>
103 NekMatrix<typename NekMatrix<LhsDataType, LhsMatrixType>::NumberType,
105 Multiply(
const NekMatrix<LhsDataType, LhsMatrixType> &lhs,
const DataType &rhs);
// Declares Multiply of a scalar `lhs` (of the result's element type
// ResultDataType) by the NekMatrix<RhsDataType, RhsMatrixType> `rhs`. The
// output goes into the dense StandardMatrixTag matrix `result`; presumably
// result = lhs * rhs.
template <
typename RhsDataType,
typename RhsMatrixType,
typename ResultDataType>
108 void Multiply(NekMatrix<ResultDataType, StandardMatrixTag> &result,
109 const ResultDataType &lhs,
110 const NekMatrix<RhsDataType, RhsMatrixType> &rhs);
// Declares the value-returning form of scalar * matrix. `lhs` is a value of
// type DataType and `rhs` is a NekMatrix<RhsDataType, RhsMatrixType>. The
// returned matrix uses rhs's NumberType as its element type.
// NOTE(review): original line 114 (the rest of the return type) is missing.
template <
typename DataType,
typename RhsDataType,
typename RhsMatrixType>
113 NekMatrix<typename NekMatrix<RhsDataType, RhsMatrixType>::NumberType,
115 Multiply(
const DataType &lhs,
const NekMatrix<RhsDataType, RhsMatrixType> &rhs);
// NOTE(review): fragment only. This is the template header and first
// return-type line of a declaration at original line 118, whose return type
// uses rhs's NumberType. Lines 119-124 (name and parameters) are missing from
// this listing.
template <
typename DataType,
typename RhsDataType,
typename RhsMatrixType>
118 NekMatrix<typename NekMatrix<RhsDataType, RhsMatrixType>::NumberType,
// NOTE(review): fragment only. This is the template header and first
// return-type line of a declaration at original line 126, whose return type
// uses rhs's NumberType. Lines 127-132 (name, parameters and any body) are
// missing from this listing.
template <
typename DataType,
typename RhsDataType,
typename RhsMatrixType>
126 NekMatrix<typename NekMatrix<RhsDataType, RhsMatrixType>::NumberType,
// Fragment of the in-place MultiplyEqual(lhs, rhs) declaration. The dense
// StandardMatrixTag matrix `lhs` is modified using `rhs`, which is passed as
// boost::call_traits<LhsDataType>::const_reference; presumably lhs *= rhs.
// NOTE(review): the line holding the return type and function name (original
// line 134) is missing here.
template <
typename LhsDataType>
135 NekMatrix<LhsDataType, StandardMatrixTag> &lhs,
136 typename boost::call_traits<LhsDataType>::const_reference rhs);
// Fragment of (presumably) NekMultiplyFullMatrixFullMatrix, a dense
// matrix-matrix product dispatched to BLAS Gemm:
//   result = (lhs.Scale() * rhs.Scale()) * op(lhs) * op(rhs)
// where op() is the transpose selected by each operand's GetTransposeFlag().
// NOTE(review): several lines are missing from this listing: the signature
// and enable_if condition, the bodies of the 'T' branches, and the last
// Gemm argument.
template <
typename LhsDataType,
typename RhsDataType,
typename LhsMatrixType,
143 typename RhsMatrixType>
148 typename std::enable_if<
// `p` is referenced only to silence unused-parameter warnings; presumably it
// is the defaulted enable_if pointer parameter that constrains this overload.
153 boost::ignore_unused(
p);
156 "Only full matrices are supported.");
// Gemm sizes as passed below: m = M, n = N, k = K, so op(lhs) is M x K,
// op(rhs) is K x N and result is M x N.
158 unsigned int M = lhs.GetRows();
159 unsigned int N = rhs.GetColumns();
160 unsigned int K = lhs.GetColumns();
// Default leading dimension of lhs; the body of the transposed ('T') branch
// is missing from this listing.
162 unsigned int LDA = M;
163 if (lhs.GetTransposeFlag() ==
'T')
// Default leading dimension of rhs; the 'T' branch body is likewise missing.
168 unsigned int LDB = K;
169 if (rhs.GetTransposeFlag() ==
'T')
// alpha combines both operands' scale factors. beta = 0.0, so the previous
// contents of `result` do not contribute to the output.
174 Blas::Gemm(lhs.GetTransposeFlag(), rhs.GetTransposeFlag(), M, N, K,
175 lhs.Scale() * rhs.Scale(), lhs.GetRawPtr(), LDA,
176 rhs.GetRawPtr(), LDB, 0.0, result.GetRawPtr(),
// Declares the out-parameter matrix * matrix product. It reads `lhs`
// (NekMatrix<LhsDataType, LhsMatrixType>) and `rhs`
// (NekMatrix<RhsDataType, RhsMatrixType>) and writes into the dense
// StandardMatrixTag matrix `result`; presumably result = lhs * rhs.
template <
typename LhsDataType,
typename RhsDataType,
typename DataType,
181 typename LhsMatrixType,
typename RhsMatrixType>
182 void Multiply(NekMatrix<DataType, StandardMatrixTag> &result,
183 const NekMatrix<LhsDataType, LhsMatrixType> &lhs,
184 const NekMatrix<RhsDataType, RhsMatrixType> &rhs);
// Fragment of an in-place matrix product computed with BLAS Gemm:
// result = result * rhs, scaled by rhs.Scale(). This is presumably a
// MultiplyEqual overload. Its name and most of its enable_if condition
// (original lines 187-196) are missing from this listing. The visible part of
// the condition appears to compare RhsInnerType with the raw type of rhs's
// NumberType.
template <
typename RhsInnerType,
typename RhsMatrixType>
190 typename std::enable_if<
192 RhsInnerType, RhsMatrixType>::NumberType>,
193 RhsInnerType>::value &&
// `t` is referenced only to silence unused-parameter warnings.
197 boost::ignore_unused(t);
199 "Only full matrices supported.");
// Gemm sizes with A = result and B = rhs: (M x K) * (K x N).
200 unsigned int M = result.GetRows();
201 unsigned int N = rhs.GetColumns();
202 unsigned int K = result.GetColumns();
// Default leading dimensions; the bodies of the 'T' branches are missing
// from this listing.
204 unsigned int LDA = M;
205 if (result.GetTransposeFlag() ==
'T')
210 unsigned int LDB = K;
211 if (rhs.GetTransposeFlag() ==
'T')
215 RhsInnerType scale = rhs.Scale();
// `result` is also the A operand, so the product goes into a separate buffer
// `buf` (declared on a line missing here). `result` is then resized to
// result.GetRows() x rhs.GetColumns() (i.e. M x N) and the buffers are swapped.
217 Blas::Gemm(result.GetTransposeFlag(), rhs.GetTransposeFlag(), M, N, K,
218 scale, result.GetRawPtr(), LDA, rhs.GetRawPtr(), LDB, 0.0,
219 buf.data(), result.GetRows());
220 result.SetSize(result.GetRows(), rhs.GetColumns());
221 result.SwapTempAndDataBuffers();
// Fragment of the generic (non-BLAS) in-place matrix product
// result = result * rhs, computed with an explicit triple loop. This is
// presumably a MultiplyEqual overload. Its name and most of its enable_if
// condition (original lines 225-234) are missing from this listing.
template <
typename DataType,
typename RhsInnerType,
typename RhsMatrixType>
228 typename std::enable_if<
230 RhsInnerType, RhsMatrixType>::NumberType>,
// `t` is referenced only to silence unused-parameter warnings.
235 boost::ignore_unused(t);
237 std::string(
"A left side matrix with column count ") +
239 std::string(
" and a right side matrix with row count ") +
240 std::to_string(rhs.GetRows()) +
241 std::string(
" can't be multiplied."));
// NOTE(review): the j loop is bounded by result.GetColumns(). The product of
// an R x C matrix with a C x P matrix has P = rhs.GetColumns() columns, so
// this looks correct only when `rhs` is square. Also, result(i, k) is read
// inside the loop, so sums must be stored somewhere other than `result`
// (presumably a temporary on missing lines 256-262). Confirm both points
// against the full definition.
245 for (
unsigned int i = 0; i < result.
GetRows(); ++i)
247 for (
unsigned int j = 0; j < result.
GetColumns(); ++j)
// NOTE(review): this local `t` shadows the `t` passed to
// boost::ignore_unused(t) above (presumably the enable_if parameter).
// Consider renaming it, e.g. to `sum`.
249 DataType t = DataType(0);
// Dot product of row i of `result` with column j of `rhs`; k runs over the
// shared inner dimension (result's column count).
252 for (
unsigned int k = 0; k < result.
GetColumns(); ++k)
254 t += result(i, k) * rhs(k, j);
// Fragment of a value-returning operation on two matrices (presumably
// matrix * matrix). Its return type is a NekMatrix whose element type is lhs's
// NumberType with const removed. The name and parameters (original lines
// 267-270) and most of the body are missing from this listing.
template <
typename LhsDataType,
typename RhsDataType,
typename LhsMatrixType,
264 typename RhsMatrixType>
265 NekMatrix<
typename std::remove_const<
266 typename NekMatrix<LhsDataType, LhsMatrixType>::NumberType>::type,
271 typedef typename std::remove_const<
// Fragment of a second value-returning operation on two matrices with the
// same return type: a NekMatrix whose element type is lhs's NumberType with
// const removed. The name, parameters and body (original lines 284-294) are
// missing from this listing.
template <
typename LhsDataType,
typename RhsDataType,
typename LhsMatrixType,
281 typename RhsMatrixType>
282 NekMatrix<
typename std::remove_const<
283 typename NekMatrix<LhsDataType, LhsMatrixType>::NumberType>::type,
// Declares the in-place AddEqual. It combines the NekMatrix<RhsDataType,
// RhsMatrixType> `rhs` into the dense StandardMatrixTag matrix `result`;
// presumably result += rhs.
template <
typename DataType,
typename RhsDataType,
typename RhsMatrixType>
296 void AddEqual(NekMatrix<DataType, StandardMatrixTag> &result,
297 const NekMatrix<RhsDataType, RhsMatrixType> &rhs);
// Declares the out-parameter Add. It reads `lhs` and `rhs` and writes into the
// dense StandardMatrixTag matrix `result`; presumably result = lhs + rhs.
template <
typename DataType,
typename LhsDataType,
typename LhsMatrixType,
300 typename RhsDataType,
typename RhsMatrixType>
301 void Add(NekMatrix<DataType, StandardMatrixTag> &result,
302 const NekMatrix<LhsDataType, LhsMatrixType> &lhs,
303 const NekMatrix<RhsDataType, RhsMatrixType> &rhs);
// Declares the value-returning Add(lhs, rhs). The returned NekMatrix uses
// lhs's NumberType as its element type.
// NOTE(review): original line 308 (the rest of the return type) is missing.
template <
typename LhsDataType,
typename LhsMatrixType,
typename RhsDataType,
306 typename RhsMatrixType>
307 NekMatrix<typename NekMatrix<LhsDataType, LhsMatrixType>::NumberType,
309 Add(
const NekMatrix<LhsDataType, LhsMatrixType> &lhs,
310 const NekMatrix<RhsDataType, RhsMatrixType> &rhs);
// Fragment of a value-returning binary operation on two matrices that simply
// forwards to Add(lhs, rhs); presumably operator+. Its name and parameter
// lines (original lines 315-318) are missing from this listing.
template <
typename LhsDataType,
typename LhsMatrixType,
typename RhsDataType,
313 typename RhsMatrixType>
314 NekMatrix<typename NekMatrix<LhsDataType, LhsMatrixType>::NumberType,
319 return Add(lhs, rhs);
// Declares AddNegatedLhs, which writes into `result` from `lhs` and `rhs`.
// Judging by its name, presumably result = -lhs + rhs; the definition is not
// shown here.
template <
typename DataType,
typename LhsDataType,
typename LhsMatrixType,
323 typename RhsDataType,
typename RhsMatrixType>
324 void AddNegatedLhs(NekMatrix<DataType, StandardMatrixTag> &result,
325 const NekMatrix<LhsDataType, LhsMatrixType> &lhs,
326 const NekMatrix<RhsDataType, RhsMatrixType> &rhs);
// Fragment of the AddEqualNegatedLhs(result, rhs) declaration. The line that
// holds the return type, the name and the `result` parameter (original line
// 329) is missing here. Judging by its name, presumably
// result = -result + rhs.
template <
typename DataType,
typename RhsDataType,
typename RhsMatrixType>
330 const NekMatrix<RhsDataType, RhsMatrixType> &rhs);
// Declares the out-parameter Subtract. It reads `lhs` and `rhs` and writes
// into the dense StandardMatrixTag matrix `result`; presumably
// result = lhs - rhs.
template <
typename DataType,
typename LhsDataType,
typename LhsMatrixType,
336 typename RhsDataType,
typename RhsMatrixType>
337 void Subtract(NekMatrix<DataType, StandardMatrixTag> &result,
338 const NekMatrix<LhsDataType, LhsMatrixType> &lhs,
339 const NekMatrix<RhsDataType, RhsMatrixType> &rhs);
// Fragment of the SubtractNegatedLhs(result, lhs, rhs) declaration. The line
// that holds the return type, the name and the `result` parameter (original
// line 343) is missing here. Judging by its name, presumably
// result = -lhs - rhs.
template <
typename DataType,
typename LhsDataType,
typename LhsMatrixType,
342 typename RhsDataType,
typename RhsMatrixType>
344 const NekMatrix<LhsDataType, LhsMatrixType> &lhs,
345 const NekMatrix<RhsDataType, RhsMatrixType> &rhs);
// Declares the in-place SubtractEqual. It updates the dense StandardMatrixTag
// matrix `result` using the NekMatrix<RhsDataType, RhsMatrixType> `rhs`;
// presumably result -= rhs.
template <
typename DataType,
typename RhsDataType,
typename RhsMatrixType>
348 void SubtractEqual(NekMatrix<DataType, StandardMatrixTag> &result,
349 const NekMatrix<RhsDataType, RhsMatrixType> &rhs);
// Fragment of the SubtractEqualNegatedLhs(result, rhs) declaration. The line
// that holds the return type, the name and the `result` parameter (original
// line 352) is missing here. Judging by its name, presumably
// result = -result - rhs.
template <
typename DataType,
typename RhsDataType,
typename RhsMatrixType>
353 const NekMatrix<RhsDataType, RhsMatrixType> &rhs);
// Declares the value-returning Subtract(lhs, rhs). The returned NekMatrix uses
// lhs's NumberType as its element type.
// NOTE(review): original line 358 (the rest of the return type) is missing.
template <
typename LhsDataType,
typename LhsMatrixType,
typename RhsDataType,
356 typename RhsMatrixType>
357 NekMatrix<typename NekMatrix<LhsDataType, LhsMatrixType>::NumberType,
359 Subtract(
const NekMatrix<LhsDataType, LhsMatrixType> &lhs,
360 const NekMatrix<RhsDataType, RhsMatrixType> &rhs);
// Fragment of a value-returning binary operation on two matrices whose return
// type uses lhs's NumberType; presumably operator-(lhs, rhs). Everything after
// the first return-type line (original line 364) is missing from this
// listing.
template <
typename LhsDataType,
typename LhsMatrixType,
typename RhsDataType,
363 typename RhsMatrixType>
364 NekMatrix<typename NekMatrix<LhsDataType, LhsMatrixType>::NumberType,
#define ASSERTL0(condition, msg)
#define ASSERTL1(condition, msg)
Assert Level 1 – debugging assertion that is active in both FULLDEBUG and DEBUG compilation modes....
#define LIB_UTILITIES_EXPORT
unsigned int GetRows() const
unsigned int GetColumns() const
static void Gemm(const char &transa, const char &transb, const int &m, const int &n, const int &k, const double &alpha, const double *a, const int &lda, const double *b, const int &ldb, const double &beta, double *c, const int &ldc)
BLAS level 3: Matrix-matrix multiply C = A x B where op(A)[m x k], op(B)[k x n], C[m x n] DGEMM perfo...
The above copyright notice and this permission notice shall be included.
SNekMat SNekMat void SubtractEqual(NekMatrix< DataType, StandardMatrixTag > &result, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)
NekVector< DataType > operator*(const NekMatrix< LhsDataType, MatrixType > &lhs, const NekVector< DataType > &rhs)
void AddEqualNegatedLhs(NekMatrix< DataType, StandardMatrixTag > &result, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)
SNekMat void AddEqual(NekMatrix< DataType, StandardMatrixTag > &result, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)
void Subtract(NekMatrix< DataType, StandardMatrixTag > &result, const NekMatrix< LhsDataType, LhsMatrixType > &lhs, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)
Array< OneD, DataType > operator+(const Array< OneD, DataType > &lhs, typename Array< OneD, DataType >::size_type offset)
void Multiply(NekMatrix< ResultDataType, StandardMatrixTag > &result, const NekMatrix< LhsDataType, LhsMatrixType > &lhs, const ResultDataType &rhs)
void NekMultiplyFullMatrixFullMatrix(NekMatrix< ResultType, StandardMatrixTag > &result, const NekMatrix< LhsDataType, LhsMatrixType > &lhs, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)
void AddNegatedLhs(NekMatrix< DataType, StandardMatrixTag > &result, const NekMatrix< LhsDataType, LhsMatrixType > &lhs, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)
void DiagonalBlockFullScalMatrixMultiply(NekVector< double > &result, const NekMatrix< NekMatrix< NekMatrix< NekDouble, StandardMatrixTag >, ScaledMatrixTag >, BlockMatrixTag > &lhs, const NekVector< double > &rhs)
void SubtractEqualNegatedLhs(NekMatrix< DataType, StandardMatrixTag > &result, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)
typename RawType< T >::type RawType_t
const NekSingle void MultiplyEqual(NekMatrix< LhsDataType, StandardMatrixTag > &lhs, typename boost::call_traits< LhsDataType >::const_reference rhs)
void SubtractNegatedLhs(NekMatrix< DataType, StandardMatrixTag > &result, const NekMatrix< LhsDataType, LhsMatrixType > &lhs, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)
NekMatrix< typename NekMatrix< LhsDataType, LhsMatrixType >::NumberType, StandardMatrixTag > operator-(const NekMatrix< LhsDataType, LhsMatrixType > &lhs, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)
SNekMat SNekMat void Add(NekMatrix< DataType, StandardMatrixTag > &result, const NekMatrix< LhsDataType, LhsMatrixType > &lhs, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)