35#ifndef NEKTAR_LIB_UTILITIES_LINEAR_ALGEBRA_MATRIX_OPERATIONS_DECLARATIONS_HPP
36#define NEKTAR_LIB_UTILITIES_LINEAR_ALGEBRA_MATRIX_OPERATIONS_DECLARATIONS_HPP
/// Declaration of the value-returning Multiply overload that takes a matrix
/// on the left and a vector on the right.
///
/// @tparam DataType     Element type of the vector operand and of the result.
/// @tparam LhsDataType  First template argument of the left-hand NekMatrix.
/// @tparam MatrixType   Second template argument of the left-hand NekMatrix.
/// @param lhs  Matrix operand, passed by const reference.
/// @param rhs  Vector operand, passed by const reference.
/// @return A NekVector<DataType> by value (presumably the product lhs * rhs;
///         the definition is not visible here -- confirm).
// NOTE(review): the leading digits on some lines (57, 58, 59) look like
// listing line numbers fused into the text by an extraction step -- confirm.
57template <
typename DataType,
typename LhsDataType,
typename MatrixType>
58NekVector<DataType>
Multiply(
const NekMatrix<LhsDataType, MatrixType> &lhs,
59 const NekVector<DataType> &rhs);
/// Declaration of the Multiply overload that writes into a caller-supplied
/// vector instead of returning one.
///
/// @tparam DataType     Element type of both vectors.
/// @tparam LhsDataType  First template argument of the left-hand NekMatrix.
/// @tparam MatrixType   Second template argument of the left-hand NekMatrix.
/// @param result  Output vector, passed by non-const reference.
/// @param lhs     Matrix operand, passed by const reference.
/// @param rhs     Vector operand, passed by const reference.
// NOTE(review): presumably result receives lhs * rhs; the definition is not
// visible in this listing -- confirm. Leading digits (61-64) look like fused
// listing line numbers.
61template <
typename DataType,
typename LhsDataType,
typename MatrixType>
62void Multiply(NekVector<DataType> &result,
63 const NekMatrix<LhsDataType, MatrixType> &lhs,
64 const NekVector<DataType> &rhs);
/// Declaration of the Multiply overload whose left operand is a NekMatrix
/// tagged with BlockMatrixTag.
///
/// @tparam DataType            Element type of both vectors.
/// @tparam LhsInnerMatrixType  First template argument of the block matrix
///                             (presumably the type of its blocks -- confirm).
/// @param result  Output vector, passed by non-const reference.
/// @param lhs     Block-matrix operand, passed by const reference.
/// @param rhs     Vector operand, passed by const reference.
// NOTE(review): leading digits (66-69) look like fused listing line numbers.
66template <
typename DataType,
typename LhsInnerMatrixType>
67void Multiply(NekVector<DataType> &result,
68 const NekMatrix<LhsInnerMatrixType, BlockMatrixTag> &lhs,
69 const NekVector<DataType> &rhs);
// NOTE(review): this span is not a complete declaration. The embedded
// numbering jumps from 71 to 79, 81, 83 and 88, so the lines in between are
// missing from this listing. The template header below has no visible
// declarator; it may belong to a matrix * vector operator* -- confirm.
71template <
typename DataType,
typename LhsDataType,
typename MatrixType>
// NOTE(review): the remaining lines look like pieces of the
// DiagonalBlockFullScalMatrixMultiply declarations: a NekVector<double>
// result, a nested scaled-matrix type for lhs, and a NekVector<double> rhs.
// The line numbered 88 uses NekSingle, which suggests a second,
// single-precision declaration whose other lines are also missing -- confirm.
79 NekVector<double> &result,
81 NekMatrix<NekMatrix<NekDouble, StandardMatrixTag>, ScaledMatrixTag>,
83 const NekVector<double> &rhs);
88 NekMatrix<NekMatrix<NekSingle, StandardMatrixTag>, ScaledMatrixTag>,
/// Declaration of the Multiply overload taking a matrix on the left and a
/// scalar on the right, writing into a caller-supplied matrix.
///
/// @tparam ResultDataType  Element type of the result matrix; also the type
///                         of the scalar operand.
/// @tparam LhsDataType     First template argument of the left-hand NekMatrix.
/// @tparam LhsMatrixType   Second template argument of the left-hand NekMatrix.
/// @param result  Output matrix (StandardMatrixTag), by non-const reference.
/// @param lhs     Matrix operand, passed by const reference.
/// @param rhs     Scalar operand, passed by const reference.
// NOTE(review): presumably result receives lhs scaled by rhs; the definition
// is not visible here -- confirm. Leading digits look like fused line numbers.
95template <
typename ResultDataType,
typename LhsDataType,
typename LhsMatrixType>
96void Multiply(NekMatrix<ResultDataType, StandardMatrixTag> &result,
97 const NekMatrix<LhsDataType, LhsMatrixType> &lhs,
98 const ResultDataType &rhs);
/// Declaration of the value-returning Multiply overload taking a matrix on
/// the left and a scalar on the right.
///
/// @tparam DataType       Type of the scalar operand.
/// @tparam LhsDataType    First template argument of the left-hand NekMatrix.
/// @tparam LhsMatrixType  Second template argument of the left-hand NekMatrix.
/// @param lhs  Matrix operand, passed by const reference.
/// @param rhs  Scalar operand, passed by const reference.
/// @return A NekMatrix whose element type is the NumberType of the lhs
///         matrix type.
// NOTE(review): the embedded numbering jumps from 101 to 103, so the line
// that closes the return type (its second template argument) is missing from
// this listing -- restore it from the original header.
100template <
typename DataType,
typename LhsDataType,
typename LhsMatrixType>
101NekMatrix<typename NekMatrix<LhsDataType, LhsMatrixType>::NumberType,
103Multiply(
const NekMatrix<LhsDataType, LhsMatrixType> &lhs,
const DataType &rhs);
/// Declaration of the Multiply overload taking a scalar on the left and a
/// matrix on the right, writing into a caller-supplied matrix.
///
/// @tparam RhsDataType     First template argument of the right-hand NekMatrix.
/// @tparam RhsMatrixType   Second template argument of the right-hand NekMatrix.
/// @tparam ResultDataType  Element type of the result matrix; also the type
///                         of the scalar operand.
/// @param result  Output matrix (StandardMatrixTag), by non-const reference.
/// @param lhs     Scalar operand, passed by const reference.
/// @param rhs     Matrix operand, passed by const reference.
// NOTE(review): presumably result receives rhs scaled by lhs; the definition
// is not visible here -- confirm. Leading digits look like fused line numbers.
105template <
typename RhsDataType,
typename RhsMatrixType,
typename ResultDataType>
106void Multiply(NekMatrix<ResultDataType, StandardMatrixTag> &result,
107 const ResultDataType &lhs,
108 const NekMatrix<RhsDataType, RhsMatrixType> &rhs);
/// Declaration of the value-returning Multiply overload taking a scalar on
/// the left and a matrix on the right.
///
/// @tparam DataType       Type of the scalar operand.
/// @tparam RhsDataType    First template argument of the right-hand NekMatrix.
/// @tparam RhsMatrixType  Second template argument of the right-hand NekMatrix.
/// @param lhs  Scalar operand, passed by const reference.
/// @param rhs  Matrix operand, passed by const reference.
/// @return A NekMatrix whose element type is the NumberType of the rhs
///         matrix type.
// NOTE(review): the embedded numbering jumps from 111 to 113, so the line
// that closes the return type is missing from this listing -- restore it
// from the original header.
110template <
typename DataType,
typename RhsDataType,
typename RhsMatrixType>
111NekMatrix<typename NekMatrix<RhsDataType, RhsMatrixType>::NumberType,
113Multiply(
const DataType &lhs,
const NekMatrix<RhsDataType, RhsMatrixType> &rhs);
// NOTE(review): fragment only. A template header and the first line of a
// NekMatrix return type are visible; the function name, parameters and body
// (embedded lines 117 onward) are missing from this listing. Presumably an
// operator* taking a scalar and a matrix -- confirm against the original.
115template <
typename DataType,
typename RhsDataType,
typename RhsMatrixType>
116NekMatrix<typename NekMatrix<RhsDataType, RhsMatrixType>::NumberType,
// NOTE(review): fragment only. A template header and the first line of a
// NekMatrix return type are visible; the function name, parameters and body
// (embedded lines 125 onward) are missing from this listing. Presumably the
// matrix-and-scalar counterpart of the preceding operator* -- confirm.
123template <
typename DataType,
typename RhsDataType,
typename RhsMatrixType>
124NekMatrix<typename NekMatrix<RhsDataType, RhsMatrixType>::NumberType,
/// Declaration taking a StandardMatrixTag matrix by non-const reference and
/// a scalar of the matrix's element type.
///
/// @tparam LhsDataType  Element type of the matrix and of the scalar.
/// @param lhs  Matrix operand, passed by non-const reference (presumably
///             updated in place -- confirm).
/// @param rhs  Scalar operand, passed as boost::call_traits const_reference.
// NOTE(review): embedded line 132, which carries the return type and function
// name, is missing from this listing. The cross-reference text near the end
// of the file shows "void MultiplyEqual(" with these parameters.
131template <
typename LhsDataType>
133 NekMatrix<LhsDataType, StandardMatrixTag> &lhs,
134 typename boost::call_traits<LhsDataType>::const_reference rhs);
// NOTE(review): fragments of an inline function definition. The embedded
// numbering skips 142-145, 147-151, 160-163 and 166-169, so the signature,
// most of the enable_if constraint, the assertion condition, both if-bodies
// and all braces are missing. The cross-reference text near the end of the
// file names a NekMultiplyFullMatrixFullMatrix(result, lhs, rhs) that this
// appears to be -- confirm against the original header.
140template <
typename LhsDataType,
typename RhsDataType,
typename LhsMatrixType,
141 typename RhsMatrixType>
146 [[maybe_unused]]
typename std::enable_if<
152 "Only full matrices are supported.");
// Gemm problem sizes, read from the operands: M = rows of lhs, N = columns of
// rhs, K = columns of lhs.
154 unsigned int M = lhs.GetRows();
155 unsigned int N = rhs.GetColumns();
156 unsigned int K = lhs.GetColumns();
// LDA starts as M. The body of the 'T' branch below is missing here;
// presumably it changes LDA for a transposed lhs -- confirm.
158 unsigned int LDA = M;
159 if (lhs.GetTransposeFlag() ==
'T')
// LDB starts as K. The body of the 'T' branch below is missing here as well.
164 unsigned int LDB = K;
165 if (rhs.GetTransposeFlag() ==
'T')
// Matrix-matrix multiply through Blas::Gemm: alpha is the product of the two
// operands' Scale() values, beta is 0.0, and the output goes to result's raw
// storage with result.GetRows() as its leading dimension.
170 Blas::Gemm(lhs.GetTransposeFlag(), rhs.GetTransposeFlag(), M, N, K,
171 lhs.Scale() * rhs.Scale(), lhs.GetRawPtr(), LDA, rhs.GetRawPtr(),
172 LDB, 0.0, result.GetRawPtr(), result.GetRows());
/// Declaration of the Multiply overload taking two matrices and writing into
/// a caller-supplied matrix.
///
/// @tparam LhsDataType    First template argument of the left-hand NekMatrix.
/// @tparam RhsDataType    First template argument of the right-hand NekMatrix.
/// @tparam DataType       Element type of the result matrix.
/// @tparam LhsMatrixType  Second template argument of the left-hand NekMatrix.
/// @tparam RhsMatrixType  Second template argument of the right-hand NekMatrix.
/// @param result  Output matrix (StandardMatrixTag), by non-const reference.
/// @param lhs     Left matrix operand, passed by const reference.
/// @param rhs     Right matrix operand, passed by const reference.
// NOTE(review): presumably result receives lhs * rhs; the definition is not
// visible here -- confirm. Leading digits look like fused line numbers.
175template <
typename LhsDataType,
typename RhsDataType,
typename DataType,
176 typename LhsMatrixType,
typename RhsMatrixType>
177void Multiply(NekMatrix<DataType, StandardMatrixTag> &result,
178 const NekMatrix<LhsDataType, LhsMatrixType> &lhs,
179 const NekMatrix<RhsDataType, RhsMatrixType> &rhs);
// NOTE(review): fragments of an inline function definition. The embedded
// numbering skips 182-184, 186, 189-192, 200-203, 206-208 and 210, so the
// signature, parts of the enable_if constraint, the assertion condition, both
// if-bodies, the declaration of `buf` and all braces are missing. Presumably
// the Gemm-based MultiplyEqual(result, rhs) -- confirm against the original.
181template <
typename RhsInnerType,
typename RhsMatrixType>
185 [[maybe_unused]]
typename std::enable_if<
187 RhsMatrixType>::NumberType>,
188 RhsInnerType>::value &&
193 "Only full matrices supported.");
// Gemm problem sizes: here result is also the left operand, so M and K come
// from result and N from rhs.
194 unsigned int M = result.GetRows();
195 unsigned int N = rhs.GetColumns();
196 unsigned int K = result.GetColumns();
// LDA starts as M; the body of the 'T' branch below is missing here.
198 unsigned int LDA = M;
199 if (result.GetTransposeFlag() ==
'T')
// LDB starts as K; the body of the 'T' branch below is missing here.
204 unsigned int LDB = K;
205 if (rhs.GetTransposeFlag() ==
'T')
// Only the scale of rhs is passed to Gemm as alpha.
209 RhsInnerType scale = rhs.Scale();
// Gemm reads result as its first input matrix and writes the output to
// buf.data(), not to result's own storage. The declaration of buf is not
// visible in this listing.
211 Blas::Gemm(result.GetTransposeFlag(), rhs.GetTransposeFlag(), M, N, K,
212 scale, result.GetRawPtr(), LDA, rhs.GetRawPtr(), LDB, 0.0,
213 buf.data(), result.GetRows());
// Resize result to (rows of result) x (columns of rhs), then call
// SwapTempAndDataBuffers(). Presumably this installs the buffer Gemm just
// filled as result's data -- confirm how buf relates to the temp buffer.
214 result.SetSize(result.GetRows(), rhs.GetColumns());
215 result.SwapTempAndDataBuffers();
// NOTE(review): fragments of an inline function definition. The embedded
// numbering skips 219-221, 223, 225-229, 231, 235-237 and several later
// lines, so the signature, most of the enable_if constraint, the assertion
// condition, the loop braces and whatever stores `t` are missing. Presumably
// the generic (non-Gemm) MultiplyEqual(result, rhs) -- confirm.
218template <
typename DataType,
typename RhsInnerType,
typename RhsMatrixType>
222 [[maybe_unused]]
typename std::enable_if<
224 RhsMatrixType>::NumberType>,
230 std::string(
"A left side matrix with column count ") +
232 std::string(
" and a right side matrix with row count ") +
233 std::to_string(rhs.GetRows()) +
234 std::string(
" can't be multiplied."));
// Loop nest over i (rows of result) and j (columns of result). For each
// (i, j), t starts at zero and accumulates result(i, k) * rhs(k, j) for k
// over the columns of result. Where t is stored afterwards is not visible.
238 for (
unsigned int i = 0; i < result.
GetRows(); ++i)
240 for (
unsigned int j = 0; j < result.
GetColumns(); ++j)
242 DataType t = DataType(0);
245 for (
unsigned int k = 0; k < result.
GetColumns(); ++k)
247 t += result(i, k) * rhs(k, j);
// NOTE(review): fragment only. Visible: a four-parameter template header, the
// start of a NekMatrix return type whose element type is the lhs NumberType
// with const removed, and the first line of a typedef that also uses
// std::remove_const. The function name, parameters and the rest of the body
// (embedded lines 260-263 and 265 onward) are missing from this listing.
256template <
typename LhsDataType,
typename RhsDataType,
typename LhsMatrixType,
257 typename RhsMatrixType>
258NekMatrix<
typename std::remove_const<
259 typename NekMatrix<LhsDataType, LhsMatrixType>::NumberType>::type,
264 typedef typename std::remove_const<
// NOTE(review): fragment only. Visible: a four-parameter template header and
// the start of a NekMatrix return type whose element type is the lhs
// NumberType with const removed. The function name, parameters and body
// (embedded lines 277 onward) are missing from this listing. Presumably a
// matrix * matrix operator* -- confirm against the original header.
273template <
typename LhsDataType,
typename RhsDataType,
typename LhsMatrixType,
274 typename RhsMatrixType>
275NekMatrix<
typename std::remove_const<
276 typename NekMatrix<LhsDataType, LhsMatrixType>::NumberType>::type,
/// Declaration of AddEqual, which takes a matrix to update and a matrix to
/// read.
///
/// @tparam DataType       Element type of the result matrix.
/// @tparam RhsDataType    First template argument of the right-hand NekMatrix.
/// @tparam RhsMatrixType  Second template argument of the right-hand NekMatrix.
/// @param result  Matrix (StandardMatrixTag) passed by non-const reference.
/// @param rhs     Matrix operand, passed by const reference.
// NOTE(review): presumably performs result += rhs; the definition is not
// visible here -- confirm. Leading digits look like fused line numbers.
288template <
typename DataType,
typename RhsDataType,
typename RhsMatrixType>
289void AddEqual(NekMatrix<DataType, StandardMatrixTag> &result,
290 const NekMatrix<RhsDataType, RhsMatrixType> &rhs);
/// Declaration of the Add overload that writes into a caller-supplied matrix.
///
/// @tparam DataType       Element type of the result matrix.
/// @tparam LhsDataType    First template argument of the left-hand NekMatrix.
/// @tparam LhsMatrixType  Second template argument of the left-hand NekMatrix.
/// @tparam RhsDataType    First template argument of the right-hand NekMatrix.
/// @tparam RhsMatrixType  Second template argument of the right-hand NekMatrix.
/// @param result  Output matrix (StandardMatrixTag), by non-const reference.
/// @param lhs     Left matrix operand, passed by const reference.
/// @param rhs     Right matrix operand, passed by const reference.
// NOTE(review): presumably result receives lhs + rhs; the definition is not
// visible here -- confirm. Leading digits look like fused line numbers.
292template <
typename DataType,
typename LhsDataType,
typename LhsMatrixType,
293 typename RhsDataType,
typename RhsMatrixType>
294void Add(NekMatrix<DataType, StandardMatrixTag> &result,
295 const NekMatrix<LhsDataType, LhsMatrixType> &lhs,
296 const NekMatrix<RhsDataType, RhsMatrixType> &rhs);
/// Declaration of the value-returning Add overload taking two matrices.
///
/// @tparam LhsDataType    First template argument of the left-hand NekMatrix.
/// @tparam LhsMatrixType  Second template argument of the left-hand NekMatrix.
/// @tparam RhsDataType    First template argument of the right-hand NekMatrix.
/// @tparam RhsMatrixType  Second template argument of the right-hand NekMatrix.
/// @param lhs  Left matrix operand, passed by const reference.
/// @param rhs  Right matrix operand, passed by const reference.
/// @return A NekMatrix whose element type is the NumberType of the lhs
///         matrix type.
// NOTE(review): the embedded numbering jumps from 300 to 302, so the line
// that closes the return type is missing from this listing -- restore it
// from the original header.
298template <
typename LhsDataType,
typename LhsMatrixType,
typename RhsDataType,
299 typename RhsMatrixType>
300NekMatrix<typename NekMatrix<LhsDataType, LhsMatrixType>::NumberType,
302Add(
const NekMatrix<LhsDataType, LhsMatrixType> &lhs,
303 const NekMatrix<RhsDataType, RhsMatrixType> &rhs);
// NOTE(review): fragment of an inline definition. The embedded numbering
// skips 308-311, so the rest of the return type, the function name and the
// parameter list are missing. The one visible statement forwards to
// Add(lhs, rhs) and returns its value. Presumably this is the matrix
// operator+ -- confirm against the original header.
305template <
typename LhsDataType,
typename LhsMatrixType,
typename RhsDataType,
306 typename RhsMatrixType>
307NekMatrix<typename NekMatrix<LhsDataType, LhsMatrixType>::NumberType,
312 return Add(lhs, rhs);
/// Declaration of AddNegatedLhs, which takes two matrix operands and writes
/// into a caller-supplied matrix.
///
/// @tparam DataType       Element type of the result matrix.
/// @tparam LhsDataType    First template argument of the left-hand NekMatrix.
/// @tparam LhsMatrixType  Second template argument of the left-hand NekMatrix.
/// @tparam RhsDataType    First template argument of the right-hand NekMatrix.
/// @tparam RhsMatrixType  Second template argument of the right-hand NekMatrix.
/// @param result  Output matrix (StandardMatrixTag), by non-const reference.
/// @param lhs     Left matrix operand, passed by const reference.
/// @param rhs     Right matrix operand, passed by const reference.
// NOTE(review): judging by the name, presumably result = -lhs + rhs; the
// definition is not visible here -- confirm.
315template <
typename DataType,
typename LhsDataType,
typename LhsMatrixType,
316 typename RhsDataType,
typename RhsMatrixType>
317void AddNegatedLhs(NekMatrix<DataType, StandardMatrixTag> &result,
318 const NekMatrix<LhsDataType, LhsMatrixType> &lhs,
319 const NekMatrix<RhsDataType, RhsMatrixType> &rhs);
// NOTE(review): incomplete declaration. Embedded line 322, which carries the
// return type, function name and first parameter, is missing from this
// listing. The cross-reference text near the end of the file shows
// "void AddEqualNegatedLhs(NekMatrix<DataType, StandardMatrixTag> &result,
// const NekMatrix<RhsDataType, RhsMatrixType> &rhs)", which matches this
// template header and trailing parameter. Presumably result = -result + rhs
// -- confirm against the definition.
321template <
typename DataType,
typename RhsDataType,
typename RhsMatrixType>
323 const NekMatrix<RhsDataType, RhsMatrixType> &rhs);
/// Declaration of the Subtract overload that writes into a caller-supplied
/// matrix.
///
/// @tparam DataType       Element type of the result matrix.
/// @tparam LhsDataType    First template argument of the left-hand NekMatrix.
/// @tparam LhsMatrixType  Second template argument of the left-hand NekMatrix.
/// @tparam RhsDataType    First template argument of the right-hand NekMatrix.
/// @tparam RhsMatrixType  Second template argument of the right-hand NekMatrix.
/// @param result  Output matrix (StandardMatrixTag), by non-const reference.
/// @param lhs     Left matrix operand, passed by const reference.
/// @param rhs     Right matrix operand, passed by const reference.
// NOTE(review): presumably result receives lhs - rhs; the definition is not
// visible here -- confirm. Leading digits look like fused line numbers.
328template <
typename DataType,
typename LhsDataType,
typename LhsMatrixType,
329 typename RhsDataType,
typename RhsMatrixType>
330void Subtract(NekMatrix<DataType, StandardMatrixTag> &result,
331 const NekMatrix<LhsDataType, LhsMatrixType> &lhs,
332 const NekMatrix<RhsDataType, RhsMatrixType> &rhs);
// NOTE(review): incomplete declaration. Embedded line 336, which carries the
// return type, function name and first parameter, is missing from this
// listing. The cross-reference text near the end of the file shows
// "void SubtractNegatedLhs(NekMatrix<DataType, StandardMatrixTag> &result,
// const NekMatrix<LhsDataType, LhsMatrixType> &lhs, const
// NekMatrix<RhsDataType, RhsMatrixType> &rhs)", which matches this template
// header and the two trailing parameters. Presumably result = -lhs - rhs
// -- confirm against the definition.
334template <
typename DataType,
typename LhsDataType,
typename LhsMatrixType,
335 typename RhsDataType,
typename RhsMatrixType>
337 const NekMatrix<LhsDataType, LhsMatrixType> &lhs,
338 const NekMatrix<RhsDataType, RhsMatrixType> &rhs);
/// Declaration of SubtractEqual, which takes a matrix to update and a matrix
/// to read.
///
/// @tparam DataType       Element type of the result matrix.
/// @tparam RhsDataType    First template argument of the right-hand NekMatrix.
/// @tparam RhsMatrixType  Second template argument of the right-hand NekMatrix.
/// @param result  Matrix (StandardMatrixTag) passed by non-const reference.
/// @param rhs     Matrix operand, passed by const reference.
// NOTE(review): presumably performs result -= rhs; the definition is not
// visible here -- confirm. Leading digits look like fused line numbers.
340template <
typename DataType,
typename RhsDataType,
typename RhsMatrixType>
341void SubtractEqual(NekMatrix<DataType, StandardMatrixTag> &result,
342 const NekMatrix<RhsDataType, RhsMatrixType> &rhs);
// NOTE(review): incomplete declaration. Embedded line 345, which carries the
// return type, function name and first parameter, is missing from this
// listing. The cross-reference text near the end of the file shows
// "void SubtractEqualNegatedLhs(NekMatrix<DataType, StandardMatrixTag>
// &result, const NekMatrix<RhsDataType, RhsMatrixType> &rhs)", which matches
// this template header and trailing parameter. Presumably
// result = -result - rhs -- confirm against the definition.
344template <
typename DataType,
typename RhsDataType,
typename RhsMatrixType>
346 const NekMatrix<RhsDataType, RhsMatrixType> &rhs);
/// Declaration of the value-returning Subtract overload taking two matrices.
///
/// @tparam LhsDataType    First template argument of the left-hand NekMatrix.
/// @tparam LhsMatrixType  Second template argument of the left-hand NekMatrix.
/// @tparam RhsDataType    First template argument of the right-hand NekMatrix.
/// @tparam RhsMatrixType  Second template argument of the right-hand NekMatrix.
/// @param lhs  Left matrix operand, passed by const reference.
/// @param rhs  Right matrix operand, passed by const reference.
/// @return A NekMatrix whose element type is the NumberType of the lhs
///         matrix type.
// NOTE(review): the embedded numbering jumps from 350 to 352, so the line
// that closes the return type is missing from this listing -- restore it
// from the original header.
348template <
typename LhsDataType,
typename LhsMatrixType,
typename RhsDataType,
349 typename RhsMatrixType>
350NekMatrix<typename NekMatrix<LhsDataType, LhsMatrixType>::NumberType,
352Subtract(
const NekMatrix<LhsDataType, LhsMatrixType> &lhs,
353 const NekMatrix<RhsDataType, RhsMatrixType> &rhs);
// NOTE(review): fragment only. A four-parameter template header and the first
// line of a NekMatrix return type are visible; the function name, parameters
// and body (embedded lines 358 onward) are missing from this listing. The
// cross-reference text near the end of the file shows an operator-(lhs, rhs)
// on two matrices with this return type, which this appears to be -- confirm.
355template <
typename LhsDataType,
typename LhsMatrixType,
typename RhsDataType,
356 typename RhsMatrixType>
357NekMatrix<typename NekMatrix<LhsDataType, LhsMatrixType>::NumberType,
#define ASSERTL0(condition, msg)
#define ASSERTL1(condition, msg)
Assert Level 1 – debugging assertion that is active in both FULLDEBUG and DEBUG compilation modes....
#define LIB_UTILITIES_EXPORT
unsigned int GetRows() const
unsigned int GetColumns() const
static void Gemm(const char &transa, const char &transb, const int &m, const int &n, const int &k, const double &alpha, const double *a, const int &lda, const double *b, const int &ldb, const double &beta, double *c, const int &ldc)
BLAS level 3: Matrix-matrix multiply C = A x B where op(A)[m x k], op(B)[k x n], C[m x n] DGEMM perfo...
SNekMat SNekMat void SubtractEqual(NekMatrix< DataType, StandardMatrixTag > &result, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)
NekMatrix< typename NekMatrix< LhsDataType, LhsMatrixType >::NumberType, StandardMatrixTag > operator-(const NekMatrix< LhsDataType, LhsMatrixType > &lhs, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)
void AddEqualNegatedLhs(NekMatrix< DataType, StandardMatrixTag > &result, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)
SNekMat void AddEqual(NekMatrix< DataType, StandardMatrixTag > &result, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)
void Subtract(NekMatrix< DataType, StandardMatrixTag > &result, const NekMatrix< LhsDataType, LhsMatrixType > &lhs, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)
Array< OneD, DataType > operator+(const Array< OneD, DataType > &lhs, typename Array< OneD, DataType >::size_type offset)
void Multiply(NekMatrix< ResultDataType, StandardMatrixTag > &result, const NekMatrix< LhsDataType, LhsMatrixType > &lhs, const ResultDataType &rhs)
void NekMultiplyFullMatrixFullMatrix(NekMatrix< ResultType, StandardMatrixTag > &result, const NekMatrix< LhsDataType, LhsMatrixType > &lhs, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)
NekVector< DataType > operator*(const NekMatrix< LhsDataType, MatrixType > &lhs, const NekVector< DataType > &rhs)
void AddNegatedLhs(NekMatrix< DataType, StandardMatrixTag > &result, const NekMatrix< LhsDataType, LhsMatrixType > &lhs, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)
void DiagonalBlockFullScalMatrixMultiply(NekVector< double > &result, const NekMatrix< NekMatrix< NekMatrix< NekDouble, StandardMatrixTag >, ScaledMatrixTag >, BlockMatrixTag > &lhs, const NekVector< double > &rhs)
void SubtractEqualNegatedLhs(NekMatrix< DataType, StandardMatrixTag > &result, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)
typename RawType< T >::type RawType_t
const NekSingle void MultiplyEqual(NekMatrix< LhsDataType, StandardMatrixTag > &lhs, typename boost::call_traits< LhsDataType >::const_reference rhs)
void SubtractNegatedLhs(NekMatrix< DataType, StandardMatrixTag > &result, const NekMatrix< LhsDataType, LhsMatrixType > &lhs, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)
SNekMat SNekMat void Add(NekMatrix< DataType, StandardMatrixTag > &result, const NekMatrix< LhsDataType, LhsMatrixType > &lhs, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)