36 #ifndef NEKTAR_LIB_UTILITIES_LINEAR_ALGEBRA_MATRIX_OPERATIONS_DECLARATIONS_HPP
37 #define NEKTAR_LIB_UTILITIES_LINEAR_ALGEBRA_MATRIX_OPERATIONS_DECLARATIONS_HPP
53 #include <boost/utility/enable_if.hpp>
54 #include <boost/type_traits.hpp>
63 template<
typename DataType,
typename LhsDataType,
typename MatrixType>
65 Multiply(
const NekMatrix<LhsDataType, MatrixType>&
lhs,
66 const NekVector<DataType>& rhs);
/// Declaration of a Multiply overload that writes its output into a
/// caller-supplied vector (definition provided elsewhere).
/// @param result Output vector, taken by non-const reference.
/// @param lhs    Matrix operand; element type LhsDataType and storage/kind
///               tag MatrixType are template parameters.
/// @param rhs    Vector operand with the same element type as result.
/// NOTE(review): presumably computes result = lhs * rhs — confirm against
/// the definition.
68 template<
typename DataType,
typename LhsDataType,
typename MatrixType>
69 void Multiply(NekVector<DataType>& result,
70 const NekMatrix<LhsDataType, MatrixType>&
lhs,
71 const NekVector<DataType>& rhs);
/// Declaration of a Multiply overload whose matrix operand is tagged
/// BlockMatrixTag (its blocks have type LhsInnerMatrixType).
/// @param result Output vector, taken by non-const reference.
/// @param lhs    Block matrix operand.
/// @param rhs    Vector operand with the same element type as result.
/// NOTE(review): presumably computes result = lhs * rhs block by block —
/// confirm against the definition.
73 template<
typename DataType,
typename LhsInnerMatrixType>
74 void Multiply(NekVector<DataType>& result,
75 const NekMatrix<LhsInnerMatrixType, BlockMatrixTag>&
lhs,
76 const NekVector<DataType>& rhs);
79 const NekMatrix<NekMatrix<NekMatrix<NekDouble, StandardMatrixTag>, ScaledMatrixTag>, BlockMatrixTag>&
lhs,
80 const NekVector<double>& rhs);
/// Declaration of a Multiply overload taking a matrix and a scalar, writing
/// into a StandardMatrixTag result matrix.
/// @param result Output matrix, taken by non-const reference.
/// @param lhs    Matrix operand of any element type / matrix tag.
/// @param rhs    Scalar whose type is the result's element type.
/// NOTE(review): presumably computes result = lhs * rhs (scaling) — confirm
/// against the definition.
85 template<
typename ResultDataType,
typename LhsDataType,
typename LhsMatrixType>
86 void Multiply(NekMatrix<ResultDataType, StandardMatrixTag>& result,
87 const NekMatrix<LhsDataType, LhsMatrixType>&
lhs,
88 const ResultDataType& rhs);
90 template<
typename DataType,
typename LhsDataType,
typename LhsMatrixType>
91 NekMatrix<typename NekMatrix<LhsDataType, LhsMatrixType>::NumberType, StandardMatrixTag>
92 Multiply(
const NekMatrix<LhsDataType, LhsMatrixType>&
lhs,
/// Declaration of a Multiply overload taking a scalar on the left and a
/// matrix on the right, writing into a StandardMatrixTag result matrix.
/// @param result Output matrix, taken by non-const reference.
/// @param lhs    Scalar whose type is the result's element type.
/// @param rhs    Matrix operand of any element type / matrix tag.
/// NOTE(review): presumably computes result = lhs * rhs (scaling) — confirm
/// against the definition.
95 template<
typename RhsDataType,
typename RhsMatrixType,
typename ResultDataType>
96 void Multiply(NekMatrix<ResultDataType, StandardMatrixTag>& result,
97 const ResultDataType&
lhs,
98 const NekMatrix<RhsDataType, RhsMatrixType>& rhs);
100 template<
typename DataType,
typename RhsDataType,
typename RhsMatrixType>
101 NekMatrix<typename NekMatrix<RhsDataType, RhsMatrixType>::NumberType, StandardMatrixTag>
103 const NekMatrix<RhsDataType, RhsMatrixType>& rhs);
105 template<
typename LhsDataType>
107 typename boost::call_traits<LhsDataType>::const_reference rhs);
119 typename boost::enable_if
128 ASSERTL1(lhs.GetType() ==
eFULL && rhs.GetType() ==
eFULL,
"Only full matrices are supported.");
130 unsigned int M = lhs.GetRows();
131 unsigned int N = rhs.GetColumns();
132 unsigned int K = lhs.GetColumns();
134 unsigned int LDA = M;
135 if( lhs.GetTransposeFlag() ==
'T' )
140 unsigned int LDB = K;
141 if( rhs.GetTransposeFlag() ==
'T' )
146 Blas::Dgemm(lhs.GetTransposeFlag(), rhs.GetTransposeFlag(), M, N, K,
147 lhs.Scale()*rhs.Scale(), lhs.GetRawPtr(), LDA,
148 rhs.GetRawPtr(), LDB, 0.0,
149 result.GetRawPtr(), result.GetRows());
156 void Multiply(NekMatrix<DataType, StandardMatrixTag>& result,
157 const NekMatrix<LhsDataType, LhsMatrixType>&
lhs,
158 const NekMatrix<RhsDataType, RhsMatrixType>& rhs);
160 template<
typename RhsInnerType,
typename RhsMatrixType>
163 typename boost::enable_if
172 ASSERTL0(result.GetType() ==
eFULL && rhs.GetType() ==
eFULL,
"Only full matrices supported.");
173 unsigned int M = result.GetRows();
174 unsigned int N = rhs.GetColumns();
175 unsigned int K = result.GetColumns();
177 unsigned int LDA = M;
178 if( result.GetTransposeFlag() ==
'T' )
183 unsigned int LDB = K;
184 if( rhs.GetTransposeFlag() ==
'T' )
188 double scale = rhs.Scale();
190 Blas::Dgemm(result.GetTransposeFlag(), rhs.GetTransposeFlag(), M, N, K,
191 scale, result.GetRawPtr(), LDA, rhs.GetRawPtr(), LDB, 0.0,
192 buf.data(), result.GetRows());
193 result.SetSize(result.GetRows(), rhs.GetColumns());
194 result.SwapTempAndDataBuffers();
197 template<
typename DataType,
typename RhsInnerType,
typename RhsMatrixType>
200 typename boost::enable_if
209 ASSERTL1(result.
GetColumns() == rhs.GetRows(), std::string(
"A left side matrix with column count ") +
210 boost::lexical_cast<std::string>(result.
GetColumns()) +
211 std::string(
" and a right side matrix with row count ") +
212 boost::lexical_cast<std::string>(rhs.GetRows()) + std::string(
" can't be multiplied."));
215 for(
unsigned int i = 0; i < result.
GetRows(); ++i)
217 for(
unsigned int j = 0; j < result.
GetColumns(); ++j)
219 DataType t = DataType(0);
222 for(
unsigned int k = 0; k < result.
GetColumns(); ++k)
224 t += result(i,k)*rhs(k,j);
235 NekMatrix<typename boost::remove_const<typename NekMatrix<LhsDataType, LhsMatrixType>::NumberType>::type, StandardMatrixTag>
239 typedef typename boost::remove_const<typename NekMatrix<LhsDataType, LhsMatrixType>::NumberType>::type NumberType;
/// Declaration of AddEqual: modifies a StandardMatrixTag matrix in place
/// using a second matrix of any element type / matrix tag.
/// @param result Matrix updated in place (non-const reference).
/// @param rhs    Matrix operand.
/// NOTE(review): presumably result += rhs — confirm against the definition.
251 template<
typename DataType,
typename RhsDataType,
typename RhsMatrixType>
252 void AddEqual(NekMatrix<DataType, StandardMatrixTag>& result,
253 const NekMatrix<RhsDataType, RhsMatrixType>& rhs);
/// Declaration of the out-parameter form of Add for two matrices, writing
/// into a StandardMatrixTag result matrix.
/// @param result Output matrix (non-const reference).
/// @param lhs    Left matrix operand (any element type / matrix tag).
/// @param rhs    Right matrix operand (any element type / matrix tag).
/// NOTE(review): presumably result = lhs + rhs — confirm against the
/// definition.
256 template<
typename DataType,
typename LhsDataType,
typename LhsMatrixType,
typename RhsDataType,
typename RhsMatrixType>
257 void Add(NekMatrix<DataType, StandardMatrixTag>& result,
258 const NekMatrix<LhsDataType, LhsMatrixType>&
lhs,
259 const NekMatrix<RhsDataType, RhsMatrixType>& rhs);
/// Declaration of the value-returning form of Add for two matrices.
/// The returned matrix is StandardMatrixTag, and its element type is the
/// lhs matrix's NumberType.
/// @param lhs Left matrix operand (any element type / matrix tag).
/// @param rhs Right matrix operand (any element type / matrix tag).
/// @return A new StandardMatrixTag matrix.
/// NOTE(review): presumably returns lhs + rhs — confirm against the
/// definition.
262 template<
typename LhsDataType,
typename LhsMatrixType,
typename RhsDataType,
typename RhsMatrixType>
263 NekMatrix<typename NekMatrix<LhsDataType, LhsMatrixType>::NumberType, StandardMatrixTag>
264 Add(
const NekMatrix<LhsDataType, LhsMatrixType>&
lhs,
265 const NekMatrix<RhsDataType, RhsMatrixType>& rhs);
/// Declaration of AddNegatedLhs, writing into a StandardMatrixTag result
/// matrix.
/// @param result Output matrix (non-const reference).
/// @param lhs    Left matrix operand (any element type / matrix tag).
/// @param rhs    Right matrix operand (any element type / matrix tag).
/// NOTE(review): the name suggests result = -lhs + rhs — confirm against the
/// definition.
267 template<
typename DataType,
typename LhsDataType,
typename LhsMatrixType,
typename RhsDataType,
typename RhsMatrixType>
268 void AddNegatedLhs(NekMatrix<DataType, StandardMatrixTag>& result,
269 const NekMatrix<LhsDataType, LhsMatrixType>&
lhs,
270 const NekMatrix<RhsDataType, RhsMatrixType>& rhs);
272 template<
typename DataType,
typename RhsDataType,
typename RhsMatrixType>
274 const NekMatrix<RhsDataType, RhsMatrixType>& rhs);
/// Declaration of the out-parameter form of Subtract for two matrices,
/// writing into a StandardMatrixTag result matrix.
/// @param result Output matrix (non-const reference).
/// @param lhs    Left matrix operand (any element type / matrix tag).
/// @param rhs    Right matrix operand (any element type / matrix tag).
/// NOTE(review): presumably result = lhs - rhs — confirm against the
/// definition.
280 template<
typename DataType,
typename LhsDataType,
typename LhsMatrixType,
typename RhsDataType,
typename RhsMatrixType>
281 void Subtract(NekMatrix<DataType, StandardMatrixTag>& result,
282 const NekMatrix<LhsDataType, LhsMatrixType>&
lhs,
283 const NekMatrix<RhsDataType, RhsMatrixType>& rhs);
285 template<
typename DataType,
typename LhsDataType,
typename LhsMatrixType,
typename RhsDataType,
typename RhsMatrixType>
287 const NekMatrix<LhsDataType, LhsMatrixType>&
lhs,
288 const NekMatrix<RhsDataType, RhsMatrixType>& rhs);
/// Declaration of SubtractEqual: modifies a StandardMatrixTag matrix in
/// place using a second matrix of any element type / matrix tag.
/// @param result Matrix updated in place (non-const reference).
/// @param rhs    Matrix operand.
/// NOTE(review): presumably result -= rhs — confirm against the definition.
290 template<
typename DataType,
typename RhsDataType,
typename RhsMatrixType>
291 void SubtractEqual(NekMatrix<DataType, StandardMatrixTag>& result,
292 const NekMatrix<RhsDataType, RhsMatrixType>& rhs);
294 template<
typename DataType,
typename RhsDataType,
typename RhsMatrixType>
296 const NekMatrix<RhsDataType, RhsMatrixType>& rhs);
/// Declaration of the value-returning form of Subtract for two matrices.
/// The returned matrix is StandardMatrixTag, and its element type is the
/// lhs matrix's NumberType.
/// @param lhs Left matrix operand (any element type / matrix tag).
/// @param rhs Right matrix operand (any element type / matrix tag).
/// @return A new StandardMatrixTag matrix.
/// NOTE(review): presumably returns lhs - rhs — confirm against the
/// definition.
298 template<
typename LhsDataType,
typename LhsMatrixType,
typename RhsDataType,
typename RhsMatrixType>
299 NekMatrix<typename NekMatrix<LhsDataType, LhsMatrixType>::NumberType, StandardMatrixTag>
300 Subtract(
const NekMatrix<LhsDataType, LhsMatrixType>&
lhs,
301 const NekMatrix<RhsDataType, RhsMatrixType>& rhs);
void SubtractNegatedLhs(NekMatrix< DataType, StandardMatrixTag > &result, const NekMatrix< LhsDataType, LhsMatrixType > &lhs, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)
#define ASSERTL0(condition, msg)
void DiagonalBlockFullScalMatrixMultiply(NekVector< double > &result, const NekMatrix< NekMatrix< NekMatrix< NekDouble, StandardMatrixTag >, ScaledMatrixTag >, BlockMatrixTag > &lhs, const NekVector< double > &rhs)
void MultiplyEqual(NekMatrix< LhsDataType, StandardMatrixTag > &lhs, typename boost::call_traits< LhsDataType >::const_reference rhs)
void NekMultiplyFullMatrixFullMatrix(NekMatrix< ResultType, StandardMatrixTag > &result, const NekMatrix< LhsDataType, LhsMatrixType > &lhs, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)
RhsMatrixType void AddEqual(NekMatrix< DataType, StandardMatrixTag > &result, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)
void Multiply(NekMatrix< ResultDataType, StandardMatrixTag > &result, const NekMatrix< LhsDataType, LhsMatrixType > &lhs, const ResultDataType &rhs)
unsigned int GetColumns() const
void AddEqualNegatedLhs(NekMatrix< DataType, StandardMatrixTag > &result, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)
DNekMat void SubtractEqual(NekMatrix< DataType, StandardMatrixTag > &result, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)
void SubtractEqualNegatedLhs(NekMatrix< DataType, StandardMatrixTag > &result, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)
#define LIB_UTILITIES_EXPORT
void AddNegatedLhs(NekMatrix< DataType, StandardMatrixTag > &result, const NekMatrix< LhsDataType, LhsMatrixType > &lhs, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)
unsigned int GetRows() const
DNekMat void Add(NekMatrix< DataType, StandardMatrixTag > &result, const NekMatrix< LhsDataType, LhsMatrixType > &lhs, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)
RhsMatrixType void Subtract(NekMatrix< DataType, StandardMatrixTag > &result, const NekMatrix< LhsDataType, LhsMatrixType > &lhs, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)
#define ASSERTL1(condition, msg)
Assert Level 1 – Debugging which is used whether in FULLDEBUG or DEBUG compilation mode...