36 #ifndef NEKTAR_LIB_UTILITIES_LINEAR_ALGEBRA_MATRIX_OPERATIONS_DECLARATIONS_HPP
37 #define NEKTAR_LIB_UTILITIES_LINEAR_ALGEBRA_MATRIX_OPERATIONS_DECLARATIONS_HPP
53 #include <boost/utility/enable_if.hpp>
54 #include <boost/type_traits.hpp>
// Matrix-vector Multiply overload taking a matrix `lhs` and a vector `rhs`,
// declared without an output parameter.
// NOTE(review): this declaration is damaged. Listing numbers (63, 65, 66) are
// fused into the code, and listing line 64 is missing. By analogy with the
// value-returning overloads below, that line presumably carried the return
// type, so `Multiply(` currently has none. Restore from the upstream header.
63 template<
typename DataType,
typename LhsDataType,
typename MatrixType>
65 Multiply(
const NekMatrix<LhsDataType, MatrixType>&
lhs,
66 const NekVector<DataType>& rhs);
68 template<
typename DataType,
typename LhsDataType,
typename MatrixType>
69 void Multiply(NekVector<DataType>& result,
70 const NekMatrix<LhsDataType, MatrixType>&
lhs,
71 const NekVector<DataType>& rhs);
73 template<
typename DataType,
typename LhsInnerMatrixType>
74 void Multiply(NekVector<DataType>& result,
75 const NekMatrix<LhsInnerMatrixType, BlockMatrixTag>&
lhs,
76 const NekVector<DataType>& rhs);
81 template<
typename ResultDataType,
typename LhsDataType,
typename LhsMatrixType>
82 void Multiply(NekMatrix<ResultDataType, StandardMatrixTag>& result,
83 const NekMatrix<LhsDataType, LhsMatrixType>&
lhs,
84 const ResultDataType& rhs);
// Multiply overload returning a new StandardMatrixTag matrix whose element
// type is the NumberType of the `lhs` matrix.
// NOTE(review): this declaration is damaged. Listing lines 89-90 (presumably
// the remaining parameter, which would use DataType, and the closing `);`)
// are missing, so it is unterminated. The fused listing numbers also make it
// ill-formed. Restore from the upstream header.
86 template<
typename DataType,
typename LhsDataType,
typename LhsMatrixType>
87 NekMatrix<typename NekMatrix<LhsDataType, LhsMatrixType>::NumberType, StandardMatrixTag>
88 Multiply(
const NekMatrix<LhsDataType, LhsMatrixType>&
lhs,
91 template<
typename RhsDataType,
typename RhsMatrixType,
typename ResultDataType>
92 void Multiply(NekMatrix<ResultDataType, StandardMatrixTag>& result,
93 const ResultDataType&
lhs,
94 const NekMatrix<RhsDataType, RhsMatrixType>& rhs);
// Value-returning overload whose result is a StandardMatrixTag matrix with
// the NumberType of the `rhs` matrix.
// NOTE(review): this declaration is damaged. Listing line 98 (presumably the
// function name and a leading DataType parameter) is missing, and listing
// numbers are fused into the code. Restore from the upstream header.
96 template<
typename DataType,
typename RhsDataType,
typename RhsMatrixType>
97 NekMatrix<typename NekMatrix<RhsDataType, RhsMatrixType>::NumberType, StandardMatrixTag>
99 const NekMatrix<RhsDataType, RhsMatrixType>& rhs);
// Only the template header and the final parameter of this declaration
// survive. `rhs` is passed as boost::call_traits<LhsDataType>::const_reference
// (i.e. const LhsDataType&), so it has the lhs data type.
// NOTE(review): listing line 102 (return type, name and any leading
// parameters) is missing, and listing numbers are fused into the code.
// Restore from the upstream header.
101 template<
typename LhsDataType>
103 typename boost::call_traits<LhsDataType>::const_reference rhs);
// BLAS-backed matrix-matrix Multiply. It writes lhs * rhs, scaled by
// lhs.Scale()*rhs.Scale(), into `result` with a single Blas::Dgemm call
// (beta = 0.0, so the previous contents of result are not used).
// Both operands must be eFULL; this is checked by ASSERTL1 below.
// NOTE(review): this definition is incomplete in this file. The following
// are missing:
//   - the template header, name and `result` parameter (before listing line 113);
//   - the remainder of the enable_if parameter (116-123);
//   - the statements guarded by both transpose-flag `if`s (132-135, 138-141);
//   - all braces.
// Listing numbers are also fused into the code. Restore it from upstream.
113 const NekMatrix<LhsDataType, LhsMatrixType>&
lhs,
114 const NekMatrix<RhsDataType, RhsMatrixType>& rhs,
115 typename boost::enable_if
124 ASSERTL1(lhs.GetType() ==
eFULL && rhs.GetType() ==
eFULL,
"Only full matrices are supported.");
126 unsigned int M = lhs.GetRows();
127 unsigned int N = rhs.GetColumns();
128 unsigned int K = lhs.GetColumns();
// Leading dimensions for Dgemm. The statements these `if`s guard are missing.
// For transposed storage they are presumably LDA = K and LDB = N -- confirm
// upstream.
130 unsigned int LDA = M;
131 if( lhs.GetTransposeFlag() ==
'T' )
136 unsigned int LDB = K;
137 if( rhs.GetTransposeFlag() ==
'T' )
// result.GetRawPtr() is passed directly as Dgemm's output C. `result` must
// therefore not alias `lhs` or `rhs`, because Dgemm reads A/B while writing C.
142 Blas::Dgemm(lhs.GetTransposeFlag(), rhs.GetTransposeFlag(), M, N, K,
143 lhs.Scale()*rhs.Scale(), lhs.GetRawPtr(), LDA,
144 rhs.GetRawPtr(), LDB, 0.0,
145 result.GetRawPtr(), result.GetRows());
// General matrix-matrix Multiply overload writing into a caller-supplied
// StandardMatrixTag matrix.
// NOTE(review): the template header that should declare DataType,
// LhsDataType, LhsMatrixType, RhsDataType and RhsMatrixType (before listing
// line 152) is missing. Listing numbers are also fused into the code.
// Restore from the upstream header.
152 void Multiply(NekMatrix<DataType, StandardMatrixTag>& result,
153 const NekMatrix<LhsDataType, LhsMatrixType>&
lhs,
154 const NekMatrix<RhsDataType, RhsMatrixType>& rhs);
// In-place product result := result * rhs, computed with Blas::Dgemm
// (alpha = rhs.Scale(), beta = 0.0). This is the double-precision path: the
// visible part of the enable_if requires rhs's raw NumberType to be double.
// Dgemm writes into result.GetTempSpace() rather than into result itself,
// because `result` is also the A operand. Afterwards result is resized to
// GetRows() x rhs.GetColumns() and its temp and data buffers are swapped.
// NOTE(review): assumes GetTempSpace() holds at least
// GetRows()*rhs.GetColumns() values -- confirm in NekMatrix.
// NOTE(review): this definition is incomplete in this file. The following
// are missing:
//   - the name line with the `result` parameter (157);
//   - the rest of the enable_if (160-162, 164-167);
//   - the statements guarded by both transpose-flag `if`s (175-178, 181-183);
//   - all braces.
// Listing numbers are also fused into the code. Restore from upstream.
156 template<
typename RhsInnerType,
typename RhsMatrixType>
158 const NekMatrix<RhsInnerType, RhsMatrixType>& rhs,
159 typename boost::enable_if
163 boost::is_same<
typename RawType<
typename NekMatrix<RhsInnerType, RhsMatrixType>::NumberType>::type,
double>,
168 ASSERTL0(result.GetType() ==
eFULL && rhs.GetType() ==
eFULL,
"Only full matrices supported.");
169 unsigned int M = result.GetRows();
170 unsigned int N = rhs.GetColumns();
171 unsigned int K = result.GetColumns();
173 unsigned int LDA = M;
174 if( result.GetTransposeFlag() ==
'T' )
179 unsigned int LDB = K;
180 if( rhs.GetTransposeFlag() ==
'T' )
184 double scale = rhs.Scale();
185 Array<OneD, double>& buf = result.GetTempSpace();
186 Blas::Dgemm(result.GetTransposeFlag(), rhs.GetTransposeFlag(), M, N, K,
187 scale, result.GetRawPtr(), LDA, rhs.GetRawPtr(), LDB, 0.0,
188 buf.data(), result.GetRows());
// The product has rhs.GetColumns() columns, so resize before swapping the
// freshly computed buffer in.
189 result.SetSize(result.GetRows(), rhs.GetColumns());
190 result.SwapTempAndDataBuffers();
// In-place product result := result * rhs using a naive triple loop.
// The visible enable_if terms negate is_same<..., double> and
// CanGetRawPtr<...>; their combinator is on a missing line. Presumably this is
// the non-BLAS fallback -- confirm.
// NOTE(review): the `j` loop runs to result.GetColumns(), but the product has
// rhs.GetColumns() columns (the double path resizes result to
// GetRows() x rhs.GetColumns()). It is therefore only correct when rhs is
// square.
// NOTE(review): the loop reads result(i,k) while producing new values. Where
// `t` is stored is on missing lines (221+). Confirm it goes to a temporary,
// not straight back into result, or later reads see overwritten data.
// NOTE(review): this definition is incomplete in this file. The following
// are missing:
//   - the name line with the `result` parameter (194);
//   - the rest of the enable_if (197-199, 202-204);
//   - the loop braces;
//   - everything after listing line 220.
// Listing numbers are also fused into the code.
193 template<
typename DataType,
typename RhsInnerType,
typename RhsMatrixType>
195 const NekMatrix<RhsInnerType, RhsMatrixType>& rhs,
196 typename boost::enable_if
200 boost::mpl::not_<boost::is_same<
typename RawType<
typename NekMatrix<RhsInnerType, RhsMatrixType>::NumberType>::type,
double> >,
201 boost::mpl::not_<
CanGetRawPtr<NekMatrix<RhsInnerType, RhsMatrixType> > >
205 ASSERTL1(result.
GetColumns() == rhs.GetRows(), std::string(
"A left side matrix with column count ") +
206 boost::lexical_cast<std::string>(result.
GetColumns()) +
207 std::string(
" and a right side matrix with row count ") +
208 boost::lexical_cast<std::string>(rhs.GetRows()) + std::string(
" can't be multiplied."));
211 for(
unsigned int i = 0; i < result.
GetRows(); ++i)
213 for(
unsigned int j = 0; j < result.
GetColumns(); ++j)
215 DataType t = DataType(0);
// Inner dimension: result.GetColumns() == rhs.GetRows() by the assert above.
218 for(
unsigned int k = 0; k < result.
GetColumns(); ++k)
220 t += result(i,k)*rhs(k,j);
// Value-returning matrix-matrix product. It allocates a StandardMatrixTag
// result of lhs.GetRows() x rhs.GetColumns() whose element type is lhs's
// NumberType with const removed.
// NOTE(review): this definition is incomplete in this file. The following
// are missing:
//   - the template header;
//   - listing line 232 (presumably the name and `lhs` parameter);
//   - listing line 234;
//   - the rest of the body after listing line 236.
// Listing numbers are also fused into the code. Restore from upstream.
231 NekMatrix<typename boost::remove_const<typename NekMatrix<LhsDataType, LhsMatrixType>::NumberType>::type, StandardMatrixTag>
233 const NekMatrix<RhsDataType, RhsMatrixType>& rhs)
235 typedef typename boost::remove_const<typename NekMatrix<LhsDataType, LhsMatrixType>::NumberType>::type NumberType;
236 NekMatrix<NumberType, StandardMatrixTag> result(lhs.GetRows(), rhs.GetColumns());
247 template<
typename DataType,
typename RhsDataType,
typename RhsMatrixType>
248 void AddEqual(NekMatrix<DataType, StandardMatrixTag>& result,
249 const NekMatrix<RhsDataType, RhsMatrixType>& rhs);
252 template<
typename DataType,
typename LhsDataType,
typename LhsMatrixType,
typename RhsDataType,
typename RhsMatrixType>
253 void Add(NekMatrix<DataType, StandardMatrixTag>& result,
254 const NekMatrix<LhsDataType, LhsMatrixType>&
lhs,
255 const NekMatrix<RhsDataType, RhsMatrixType>& rhs);
258 template<
typename LhsDataType,
typename LhsMatrixType,
typename RhsDataType,
typename RhsMatrixType>
259 NekMatrix<typename NekMatrix<LhsDataType, LhsMatrixType>::NumberType, StandardMatrixTag>
260 Add(
const NekMatrix<LhsDataType, LhsMatrixType>&
lhs,
261 const NekMatrix<RhsDataType, RhsMatrixType>& rhs);
263 template<
typename DataType,
typename LhsDataType,
typename LhsMatrixType,
typename RhsDataType,
typename RhsMatrixType>
264 void AddNegatedLhs(NekMatrix<DataType, StandardMatrixTag>& result,
265 const NekMatrix<LhsDataType, LhsMatrixType>&
lhs,
266 const NekMatrix<RhsDataType, RhsMatrixType>& rhs);
// Declaration taking a matrix `rhs` of any matrix type. It sits between
// AddNegatedLhs and Subtract in this header.
// NOTE(review): listing line 269 (presumably the return type, name and a
// DataType `result` parameter) is missing, and listing numbers are fused into
// the code. Restore from the upstream header.
268 template<
typename DataType,
typename RhsDataType,
typename RhsMatrixType>
270 const NekMatrix<RhsDataType, RhsMatrixType>& rhs);
276 template<
typename DataType,
typename LhsDataType,
typename LhsMatrixType,
typename RhsDataType,
typename RhsMatrixType>
277 void Subtract(NekMatrix<DataType, StandardMatrixTag>& result,
278 const NekMatrix<LhsDataType, LhsMatrixType>&
lhs,
279 const NekMatrix<RhsDataType, RhsMatrixType>& rhs);
// Declaration taking matrices `lhs` and `rhs`, each of any matrix type, with
// the same template parameters as the Subtract(result, lhs, rhs) overload.
// NOTE(review): listing line 282 (presumably the return type, name and a
// DataType `result` parameter) is missing, and listing numbers are fused into
// the code. Restore from the upstream header.
281 template<
typename DataType,
typename LhsDataType,
typename LhsMatrixType,
typename RhsDataType,
typename RhsMatrixType>
283 const NekMatrix<LhsDataType, LhsMatrixType>&
lhs,
284 const NekMatrix<RhsDataType, RhsMatrixType>& rhs);
286 template<
typename DataType,
typename RhsDataType,
typename RhsMatrixType>
287 void SubtractEqual(NekMatrix<DataType, StandardMatrixTag>& result,
288 const NekMatrix<RhsDataType, RhsMatrixType>& rhs);
// Declaration taking a matrix `rhs` of any matrix type, with the same
// template parameters as SubtractEqual.
// NOTE(review): listing line 291 (presumably the return type, name and a
// DataType `result` parameter) is missing, and listing numbers are fused into
// the code. Restore from the upstream header.
290 template<
typename DataType,
typename RhsDataType,
typename RhsMatrixType>
292 const NekMatrix<RhsDataType, RhsMatrixType>& rhs);
294 template<
typename LhsDataType,
typename LhsMatrixType,
typename RhsDataType,
typename RhsMatrixType>
295 NekMatrix<typename NekMatrix<LhsDataType, LhsMatrixType>::NumberType, StandardMatrixTag>
296 Subtract(
const NekMatrix<LhsDataType, LhsMatrixType>&
lhs,
297 const NekMatrix<RhsDataType, RhsMatrixType>& rhs);