34 #ifndef NEKTAR_LIB_UTILITIES_LINEAR_ALGEBRA_DGEMM_OVERRIDE_HPP
35 #define NEKTAR_LIB_UTILITIES_LINEAR_ALGEBRA_DGEMM_OVERRIDE_HPP
37 #ifdef NEKTAR_USE_EXPRESSION_TEMPLATES
39 #include <ExpressionTemplates/ExpressionTemplates.hpp>
40 #include <boost/utility/enable_if.hpp>
41 #include <boost/type_traits.hpp>
// Multiplies A by B through Blas::Dgemm and writes the product into `result`.
//
// The scalar handed to BLAS is A.Scale()*B.Scale()*alpha, the beta argument
// is the constant 0.0, and `result` is passed as the output buffer with
// result.GetRows() as its leading dimension.
//
// NOTE(review): in this listing the statements guarded by the `if` lines and
// the declarations of LDA/LDB are not visible; confirm them against the full
// file before relying on the comments below.
48 template<
typename ADataType,
typename BDataType,
49 typename AMatrixType,
typename BMatrixType>
50 void Dgemm(NekMatrix<double, StandardMatrixTag>& result,
51 double alpha,
const NekMatrix<ADataType, AMatrixType>& A,
const NekMatrix<BDataType, BMatrixType>& B)
// Detects the case where at least one operand is not stored as eFULL.
// Presumably this selects a non-BLAS fallback -- TODO confirm, the guarded
// statements are not shown here.
53 if( A.GetType() !=
eFULL || B.GetType() !=
eFULL )
// Problem sizes in the order Blas::Dgemm receives them below:
// rows of A, columns of B, and the shared inner dimension (columns of A).
60 unsigned int M = A.GetRows();
61 unsigned int N = B.GetColumns();
62 unsigned int K = A.GetColumns();
// Transposed-A check; presumably adjusts LDA -- TODO confirm.
65 if( A.GetTransposeFlag() ==
'T' )
// Transposed-B check; presumably adjusts LDB -- TODO confirm.
71 if( B.GetTransposeFlag() ==
'T' )
// The operands' own transpose flags are forwarded to BLAS unchanged.
76 Blas::Dgemm(A.GetTransposeFlag(), B.GetTransposeFlag(), M, N, K,
77 A.Scale()*B.Scale()*alpha, A.GetRawPtr(), LDA,
78 B.GetRawPtr(), LDB, 0.0,
79 result.GetRawPtr(), result.GetRows());
// Convenience overload that takes the scalar between the two matrices
// (result, A, alpha, B) and forwards to the (result, alpha, A, B) overload.
82 template<
typename ADataType,
typename BDataType,
83 typename AMatrixType,
typename BMatrixType>
84 void Dgemm(NekMatrix<double, StandardMatrixTag>& result,
85 const NekMatrix<ADataType, AMatrixType>& A,
double alpha,
const NekMatrix<BDataType, BMatrixType>& B)
// Only the argument order differs; no additional work is done here.
87 Dgemm(result, alpha, A, B);
// Convenience overload that takes the scalar after both matrices
// (result, A, B, alpha) and forwards to the (result, alpha, A, B) overload.
90 template<
typename ADataType,
typename BDataType,
91 typename AMatrixType,
typename BMatrixType>
92 void Dgemm(NekMatrix<double, StandardMatrixTag>& result,
93 const NekMatrix<ADataType, AMatrixType>& A,
const NekMatrix<BDataType, BMatrixType>& B,
double alpha)
// Only the argument order differs; no additional work is done here.
95 Dgemm(result, alpha, A, B);
// Six-argument form: combines the product of A and B with the matrix C through
// a single Blas::Dgemm call that uses `result` as the in/out buffer.
//
// The scalar for the product is A.Scale()*B.Scale()*alpha and the beta
// argument handed to BLAS is beta*result.Scale(). Before the BLAS call the
// visible code copies C into `result`.
//
// NOTE(review): brace lines and some statements are missing from this
// listing, so the exact branch structure cannot be read off here; comments
// about which statement belongs to which branch are assumptions to confirm.
98 template<
typename ADataType,
typename BDataType,
typename CDataType,
99 typename AMatrixType,
typename BMatrixType,
typename CMatrixType>
100 void Dgemm(NekMatrix<double, StandardMatrixTag>& result,
101 double alpha,
const NekMatrix<ADataType, AMatrixType>& A,
const NekMatrix<BDataType, BMatrixType>& B,
102 double beta,
const NekMatrix<CDataType, CMatrixType>& C)
// Detects the case where at least one of A, B, C is not stored as eFULL.
104 if( A.GetType() !=
eFULL || B.GetType() !=
eFULL || C.GetType() !=
eFULL )
// Presumably part of the non-eFULL fallback path -- TODO confirm; the rest of
// that path is not visible.
108 NekMatrix<double> temp = beta*C;
// Transposed-C check. The resize plus element-wise copy below presumably runs
// only in this case; the alternative branch is not visible -- TODO confirm.
113 if( C.GetTransposeFlag() ==
'T' )
117 result.SetSize(C.GetRows(), C.GetColumns());
119 for(
unsigned int i = 0; i < C.GetRows(); ++i)
121 for(
unsigned int j = 0; j < C.GetColumns(); ++j)
// Copies through operator() rather than through the raw pointers.
123 result(i,j) = C(i,j);
// Problem sizes in the order Blas::Dgemm receives them below.
132 unsigned int M = A.GetRows();
133 unsigned int N = B.GetColumns();
134 unsigned int K = A.GetColumns();
// LDA starts as the row count of A; the statement guarded by the transposed-A
// check is not visible here.
136 unsigned int LDA = M;
137 if( A.GetTransposeFlag() ==
'T' )
// LDB starts as the inner dimension; the statement guarded by the
// transposed-B check is not visible here.
142 unsigned int LDB = K;
143 if( B.GetTransposeFlag() ==
'T' )
// `result` already holds the copy of C, so BLAS accumulates into it in place.
148 Blas::Dgemm(A.GetTransposeFlag(), B.GetTransposeFlag(), M, N, K,
149 A.Scale()*B.Scale()*alpha, A.GetRawPtr(), LDA,
150 B.GetRawPtr(), LDB, beta*result.Scale(),
151 result.GetRawPtr(), result.GetRows());
// Convenience overload taking (result, A, alpha, B, beta, C); forwards to the
// (result, alpha, A, B, beta, C) overload.
154 template<
typename ADataType,
typename BDataType,
typename CDataType,
155 typename AMatrixType,
typename BMatrixType,
typename CMatrixType>
156 void Dgemm(NekMatrix<double, StandardMatrixTag>& result,
157 const NekMatrix<ADataType, AMatrixType>& A,
double alpha,
const NekMatrix<BDataType, BMatrixType>& B,
158 double beta,
const NekMatrix<CDataType, CMatrixType>& C)
// Only the position of alpha differs; no additional work is done here.
160 Dgemm(result, alpha, A, B, beta, C);
// Convenience overload taking (result, A, B, alpha, beta, C); forwards to the
// (result, alpha, A, B, beta, C) overload.
163 template<
typename ADataType,
typename BDataType,
typename CDataType,
164 typename AMatrixType,
typename BMatrixType,
typename CMatrixType>
165 void Dgemm(NekMatrix<double, StandardMatrixTag>& result,
166 const NekMatrix<ADataType, AMatrixType>& A,
const NekMatrix<BDataType, BMatrixType>& B,
double alpha,
167 double beta,
const NekMatrix<CDataType, CMatrixType>& C)
// Only the position of alpha differs; no additional work is done here.
169 Dgemm(result, alpha, A, B, beta, C);
// Pairs an expression node with the evaluator used to obtain its value and
// with a scalar factor to fold into the BLAS coefficients.
//
// General case: the node is evaluated as-is via
// EvaluateNodeWithTemporaryIfNeeded and contributes a factor of 1.0.
180 template<
typename NodeType,
typename IndicesType,
unsigned int index>
181 struct DgemmNodeEvaluator
183 typedef EvaluateNodeWithTemporaryIfNeeded<NodeType, IndicesType, index> Evaluator;
// Result type reported by the chosen evaluator.
184 typedef typename Evaluator::ResultType Type;
// Sign/scale factor for this node; 1.0 because nothing was stripped from it.
185 static double GetScale() {
return 1.0; }
// Specialization for a negated node, Node<LhsType, NegateOp>.
//
// The evaluator is built for the operand LhsType rather than for the negation
// node, and the minus sign is reported through GetScale() instead, so callers
// can fold it into a scalar coefficient.
188 template<
typename LhsType,
typename IndicesType,
unsigned int index>
189 struct DgemmNodeEvaluator<Node<LhsType, NegateOp>, IndicesType, index>
// Note: evaluates LhsType, i.e. the expression underneath the negation.
191 typedef EvaluateNodeWithTemporaryIfNeeded<LhsType, IndicesType, index> Evaluator;
192 typedef typename Evaluator::ResultType Type;
// The stripped negation, expressed as a multiplicative factor.
193 static double GetScale() {
return -1.0; }
// Generic scale factor associated with an operator type: 1.0.
// NOTE(review): the line declaring the enclosing struct is not visible in this
// listing; from the specialization that follows it is presumably
// `struct BetaScale` -- TODO confirm.
200 template<
typename OpType>
203 static double GetScale() {
return 1.0; }
// Scale factor for SubtractOp: -1.0, so the right-hand term of a subtraction
// can be applied as a negated coefficient.
// NOTE(review): the `template<>` introducer is not visible in this listing.
207 struct BetaScale<SubtractOp>
209 static double GetScale() {
return -1.0; }
// Type predicate, general case: derives from boost::false_type.
// NOTE(review): the template parameter list is not visible in this listing.
213 struct IsDouble :
public boost::false_type{};
// Type predicate, specialization for double: derives from boost::true_type.
// NOTE(review): the `template<>` introducer is not visible in this listing.
216 struct IsDouble<double> :
public boost::true_type {};
// Primary template for recognising the "alpha * A * B" portion of an
// expression. It derives from boost::false_type, so a node that matches none
// of the specializations is reported as "not a match".
// The trailing `enabled` parameter exists so specializations can be switched
// on with boost::enable_if.
222 template<
typename NodeType,
typename IndicesType,
unsigned int index,
typename enabled=
void>
223 struct AlphaABParameterAccessImpl :
public boost::false_type {};
// Specialization for a three-operand product node (A1 * A2) * A3 in which the
// predicates passed to Test3ArgumentAssociativeNode are, in order, IsDouble,
// Nektar::CanGetRawPtr, Nektar::CanGetRawPtr. Accordingly A1 is treated as the
// scalar alpha, A2 as matrix A and A3 as matrix B.
// NOTE(review): the exact semantics of Test3ArgumentAssociativeNode and the
// tokens between `enable_if` and the test are not visible here.
225 template<
typename A1,
typename A2,
typename A3,
typename IndicesType,
unsigned int index>
226 struct AlphaABParameterAccessImpl< Node<Node<A1, MultiplyOp, A2>, MultiplyOp, A3>, IndicesType, index,
227 typename boost::enable_if
229 Test3ArgumentAssociativeNode<Node<Node<A1, MultiplyOp, A2>, MultiplyOp, A3>, IsDouble, MultiplyOp,
230 Nektar::CanGetRawPtr, MultiplyOp, Nektar::CanGetRawPtr>
237 >::type> :
public boost::true_type
// First operand: the scalar.
239 typedef DgemmNodeEvaluator<A1, IndicesType, index> AlphaWrappedEvaluator;
240 typedef typename AlphaWrappedEvaluator::Evaluator AlphaEvaluator;
// Each following operand's index is offset by the TotalCount of the operands
// before it.
242 static const unsigned int A2Index = index + A1::TotalCount;
243 typedef DgemmNodeEvaluator<A2, IndicesType, A2Index> AWrappedEvaluator;
244 typedef typename AWrappedEvaluator::Evaluator AEvaluator;
246 static const unsigned int A3Index = A2Index + A2::TotalCount;
247 typedef DgemmNodeEvaluator<A3, IndicesType, A3Index> BWrappedEvaluator;
248 typedef typename BWrappedEvaluator::Evaluator BEvaluator;
// Returns the evaluated scalar multiplied by the scale factors of the A and B
// wrappers (which carry any negation stripped from those operands).
250 template<
typename ArgumentVectorType>
251 static double GetAlpha(
const ArgumentVectorType& args)
253 return AlphaEvaluator::Evaluate(args) * AWrappedEvaluator::GetScale() * BWrappedEvaluator::GetScale();
// Specialization for a three-operand product node (A1 * A2) * A3 in which the
// predicates passed to Test3ArgumentAssociativeNode are, in order,
// Nektar::CanGetRawPtr, IsDouble, Nektar::CanGetRawPtr. Accordingly A1 is
// treated as matrix A, A2 as the scalar alpha and A3 as matrix B.
// NOTE(review): the exact semantics of Test3ArgumentAssociativeNode and the
// tokens between `enable_if` and the test are not visible here.
257 template<
typename A1,
typename A2,
typename A3,
typename IndicesType,
unsigned int index>
258 struct AlphaABParameterAccessImpl< Node<Node<A1, MultiplyOp, A2>, MultiplyOp, A3>, IndicesType, index,
259 typename boost::enable_if
261 Test3ArgumentAssociativeNode<Node<Node<A1, MultiplyOp, A2>, MultiplyOp, A3>, Nektar::CanGetRawPtr, MultiplyOp,
262 IsDouble, MultiplyOp, Nektar::CanGetRawPtr>
269 >::type> :
public boost::true_type
// First operand: matrix A.
271 typedef DgemmNodeEvaluator<A1, IndicesType, index> AWrappedEvaluator;
272 typedef typename AWrappedEvaluator::Evaluator AEvaluator;
// Second operand: the scalar, indexed after everything counted in A1.
274 static const unsigned int A2Index = index + A1::TotalCount;
275 typedef DgemmNodeEvaluator<A2, IndicesType, A2Index> AlphaWrappedEvaluator;
276 typedef typename AlphaWrappedEvaluator::Evaluator AlphaEvaluator;
// Third operand: matrix B, indexed after A1 and A2.
278 static const unsigned int A3Index = A2Index + A2::TotalCount;
279 typedef DgemmNodeEvaluator<A3, IndicesType, A3Index> BWrappedEvaluator;
280 typedef typename BWrappedEvaluator::Evaluator BEvaluator;
// Returns the evaluated scalar multiplied by the scale factors of the A and B
// wrappers.
282 template<
typename ArgumentVectorType>
283 static double GetAlpha(
const ArgumentVectorType& args)
285 return AlphaEvaluator::Evaluate(args) * AWrappedEvaluator::GetScale() * BWrappedEvaluator::GetScale();
// Specialization for a three-operand product node (A1 * A2) * A3 in which the
// predicates passed to Test3ArgumentAssociativeNode are, in order,
// Nektar::CanGetRawPtr, Nektar::CanGetRawPtr, IsDouble. Accordingly A1 is
// treated as matrix A, A2 as matrix B and A3 as the scalar alpha.
// NOTE(review): the exact semantics of Test3ArgumentAssociativeNode and the
// tokens between `enable_if` and the test are not visible here.
289 template<
typename A1,
typename A2,
typename A3,
typename IndicesType,
unsigned int index>
290 struct AlphaABParameterAccessImpl< Node<Node<A1, MultiplyOp, A2>, MultiplyOp, A3>, IndicesType, index,
291 typename boost::enable_if
293 Test3ArgumentAssociativeNode<Node<Node<A1, MultiplyOp, A2>, MultiplyOp, A3>, Nektar::CanGetRawPtr, MultiplyOp,
294 Nektar::CanGetRawPtr, MultiplyOp, IsDouble>
301 >::type> :
public boost::true_type
// First operand: matrix A.
303 typedef DgemmNodeEvaluator<A1, IndicesType, index> AWrappedEvaluator;
304 typedef typename AWrappedEvaluator::Evaluator AEvaluator;
// Second operand: matrix B, indexed after everything counted in A1.
306 static const unsigned int A2Index = index + A1::TotalCount;
307 typedef DgemmNodeEvaluator<A2, IndicesType, A2Index> BWrappedEvaluator;
308 typedef typename BWrappedEvaluator::Evaluator BEvaluator;
// Third operand: the scalar, indexed after A1 and A2.
310 static const unsigned int A3Index = A2Index + A2::TotalCount;
311 typedef DgemmNodeEvaluator<A3, IndicesType, A3Index> AlphaWrappedEvaluator;
312 typedef typename AlphaWrappedEvaluator::Evaluator AlphaEvaluator;
// Returns the evaluated scalar multiplied by the scale factors of the A and B
// wrappers.
314 template<
typename ArgumentVectorType>
315 static double GetAlpha(
const ArgumentVectorType& args)
317 return AlphaEvaluator::Evaluate(args) * AWrappedEvaluator::GetScale() * BWrappedEvaluator::GetScale();
// Specialization for a plain two-operand product A1 * A2 with no explicit
// scalar: A1 is treated as matrix A and A2 as matrix B.
//
// The enabling condition lists a TestBinaryNode check (both operands satisfy
// Nektar::CanGetRawPtr) together with the negations of the three
// scalar-carrying patterns, presumably so that this specialization never
// competes with them.
// NOTE(review): the metafunction that combines the four conditions is not
// visible in this listing -- TODO confirm.
322 template<
typename A1,
typename A2,
typename IndicesType,
unsigned int index>
323 struct AlphaABParameterAccessImpl< Node<A1, MultiplyOp, A2>, IndicesType, index,
324 typename boost::enable_if
328 TestBinaryNode<Node<A1, MultiplyOp, A2>, Nektar::CanGetRawPtr, MultiplyOp, Nektar::CanGetRawPtr>,
329 boost::mpl::not_<Test3ArgumentAssociativeNode<Node<A1, MultiplyOp, A2>, Nektar::CanGetRawPtr, MultiplyOp, Nektar::CanGetRawPtr, MultiplyOp, IsDouble> >,
330 boost::mpl::not_<Test3ArgumentAssociativeNode<Node<A1, MultiplyOp, A2>, Nektar::CanGetRawPtr, MultiplyOp, IsDouble, MultiplyOp, Nektar::CanGetRawPtr> >,
331 boost::mpl::not_<Test3ArgumentAssociativeNode<Node<A1, MultiplyOp, A2>, IsDouble, MultiplyOp, Nektar::CanGetRawPtr, MultiplyOp, Nektar::CanGetRawPtr> >
339 >::type> :
public boost::true_type
// Left operand: matrix A.
341 typedef DgemmNodeEvaluator<A1, IndicesType, index> AWrappedEvaluator;
342 typedef typename AWrappedEvaluator::Evaluator AEvaluator;
// Right operand: matrix B, indexed after everything counted in A1.
344 static const unsigned int A2Index = index + A1::TotalCount;
345 typedef DgemmNodeEvaluator<A2, IndicesType, A2Index> BWrappedEvaluator;
346 typedef typename BWrappedEvaluator::Evaluator BEvaluator;
// With no scalar operand the coefficient is 1.0 times the scale factors of the
// A and B wrappers; `args` is not read.
348 template<
typename ArgumentVectorType>
349 static double GetAlpha(
const ArgumentVectorType& args)
351 return 1.0 * AWrappedEvaluator::GetScale() * BWrappedEvaluator::GetScale();
// Public entry point for the alpha/A/B pattern match. In the general case it
// simply inherits everything (match result, evaluator typedefs, GetAlpha) from
// AlphaABParameterAccessImpl for the same node.
356 template<
typename NodeType,
typename IndicesType,
unsigned int index>
357 struct AlphaABParameterAccess :
public AlphaABParameterAccessImpl<NodeType, IndicesType, index> {};
// Specialization for a negated node, Node<T, NegateOp, void>.
// The pattern match is delegated to the operand T (the base class is
// AlphaABParameterAccessImpl<T, ...>), and the negation is applied to the
// coefficient returned by GetAlpha.
359 template<
typename T,
typename IndicesType,
unsigned int index>
360 struct AlphaABParameterAccess<Node<T, NegateOp,
void>, IndicesType, index> :
public AlphaABParameterAccessImpl<T, IndicesType, index>
362 typedef Node<T, NegateOp, void> NodeType;
// Hides the base-class GetAlpha and returns its value with the sign flipped.
364 template<
typename ArgumentVectorType>
365 static double GetAlpha(
const ArgumentVectorType& args)
367 return -AlphaABParameterAccessImpl<T, IndicesType, index>::GetAlpha(args);
// Primary template for recognising the "beta * C" portion of an expression.
// It derives from boost::false_type, so a node that matches none of the
// specializations is reported as "not a match".
// The trailing `enabled` parameter exists so specializations can be switched
// on with boost::enable_if.
372 template<
typename NodeType,
typename IndicesType,
unsigned int index,
typename enabled=
void>
373 struct BetaCParameterAccessImpl :
public boost::false_type {};
// Specialization for a node of the form Node<T, void, void> (no operator and
// no second operand): the node itself is taken as C.
// Unlike the other specializations there is no enable_if condition, so any
// such node matches.
375 template<
typename T,
typename IndicesType,
unsigned int index>
376 struct BetaCParameterAccessImpl< Node<T,
void,
void>, IndicesType, index> :
public boost::true_type
378 typedef Node<T, void, void> NodeType;
379 typedef DgemmNodeEvaluator<NodeType, IndicesType, index> CWrappedEvaluator;
380 typedef typename CWrappedEvaluator::Evaluator CEvaluator;
// NOTE(review): the body of GetBeta is not visible in this listing. With no
// scalar operand it presumably returns 1.0 -- TODO confirm.
382 template<
typename ArgumentVectorType>
383 static double GetBeta(
const ArgumentVectorType& args)
// Specialization for a product L * R where L's result type satisfies
// Nektar::CanGetRawPtr and R's result type is double: L is treated as matrix C
// and R as the scalar beta.
// NOTE(review): the metafunction combining the two conditions is not visible
// in this listing -- TODO confirm.
389 template<
typename L,
typename R,
typename IndicesType,
unsigned int index>
390 struct BetaCParameterAccessImpl< Node<L, expt::MultiplyOp, R>, IndicesType, index,
391 typename boost::enable_if
395 Nektar::CanGetRawPtr<typename L::ResultType>,
396 boost::is_same<typename R::ResultType, double>
398 >::type> :
public boost::true_type
// Left operand: matrix C.
400 typedef DgemmNodeEvaluator<L, IndicesType, index> CWrappedEvaluator;
401 typedef typename CWrappedEvaluator::Evaluator CEvaluator;
// Right operand: the scalar, indexed after everything counted in L.
403 static const unsigned int nextIndex = index + L::TotalCount;
404 typedef DgemmNodeEvaluator<R, IndicesType, nextIndex> BetaWrappedEvaluator;
405 typedef typename BetaWrappedEvaluator::Evaluator BetaEvaluator;
// Returns the evaluated scalar multiplied by the scale factor of the C wrapper.
407 template<
typename ArgumentVectorType>
408 static double GetBeta(
const ArgumentVectorType& args)
410 return BetaEvaluator::Evaluate(args) * CWrappedEvaluator::GetScale();
// Specialization for a product L * R where R's result type satisfies
// Nektar::CanGetRawPtr and L's result type is double: L is treated as the
// scalar beta and R as matrix C.
// NOTE(review): the metafunction combining the two conditions is not visible
// in this listing -- TODO confirm.
414 template<
typename L,
typename R,
typename IndicesType,
unsigned int index>
415 struct BetaCParameterAccessImpl< Node<L, expt::MultiplyOp, R>, IndicesType, index,
416 typename boost::enable_if
420 Nektar::CanGetRawPtr<typename R::ResultType>,
421 boost::is_same<typename L::ResultType, double>
423 >::type> :
public boost::true_type
// Left operand: the scalar.
425 typedef DgemmNodeEvaluator<L, IndicesType, index> BetaWrappedEvaluator;
426 typedef typename BetaWrappedEvaluator::Evaluator BetaEvaluator;
// Right operand: matrix C, indexed after everything counted in L.
428 static const unsigned int nextIndex = index + L::TotalCount;
429 typedef DgemmNodeEvaluator<R, IndicesType, nextIndex> CWrappedEvaluator;
430 typedef typename CWrappedEvaluator::Evaluator CEvaluator;
// Returns the evaluated scalar multiplied by the scale factor of the C wrapper.
432 template<
typename ArgumentVectorType>
433 static double GetBeta(
const ArgumentVectorType& args)
435 return BetaEvaluator::Evaluate(args) * CWrappedEvaluator::GetScale();
// Public entry point for the beta/C pattern match. In the general case it
// simply inherits everything (match result, evaluator typedefs, GetBeta) from
// BetaCParameterAccessImpl for the same node.
439 template<
typename NodeType,
typename IndicesType,
unsigned int index>
440 struct BetaCParameterAccess :
public BetaCParameterAccessImpl<NodeType, IndicesType, index> {};
// Specialization for a negated node, Node<T, NegateOp, void>.
// The pattern match is delegated to the operand T (the base class is
// BetaCParameterAccessImpl<T, ...>), and the negation is applied to the
// coefficient returned by GetBeta.
442 template<
typename T,
typename IndicesType,
unsigned int index>
443 struct BetaCParameterAccess<Node<T, NegateOp,
void>, IndicesType, index> :
public BetaCParameterAccessImpl<T, IndicesType, index>
// Hides the base-class GetBeta and returns its value with the sign flipped.
445 template<
typename ArgumentVectorType>
446 static double GetBeta(
const ArgumentVectorType& args)
448 return -BetaCParameterAccessImpl<T, IndicesType, index>::GetBeta(args);
// Evaluation override for a binary node L <op> R where L matches the
// alpha/A/B pattern and R matches the beta/C pattern, with <op> satisfying
// IsAdditiveOperator. The whole node is evaluated with one call to the
// six-argument Nektar::Dgemm.
// NOTE(review): the metafunction combining the three enabling conditions is
// not visible in this listing -- TODO confirm.
464 template<
typename L,
typename OpType,
typename R,
typename IndicesType,
unsigned int index>
465 struct BinaryBinaryEvaluateNodeOverride<L, OpType, R, IndicesType, index,
466 typename boost::enable_if
470 IsAdditiveOperator<OpType>,
471 impl::AlphaABParameterAccess<L, IndicesType, index>,
472 impl::BetaCParameterAccess<R, IndicesType, index + L::TotalCount>
476 >::type> :
public boost::true_type
478 typedef Node<L, OpType, R> NodeType;
// Left side supplies alpha, A and B.
479 typedef impl::AlphaABParameterAccess<L, IndicesType, index> LhsAccess;
// Right side supplies beta and C; its index is offset by L::TotalCount.
480 static const unsigned int rhsIndex = index + L::TotalCount;
481 typedef impl::BetaCParameterAccess<R, IndicesType, rhsIndex> RhsAccess;
482 typedef typename LhsAccess::AEvaluator AEvaluator;
483 typedef typename LhsAccess::BEvaluator BEvaluator;
484 typedef typename RhsAccess::CEvaluator CEvaluator;
// Evaluates the three matrix operands and the two coefficients from `args`,
// then writes the combined result into `accumulator`.
486 template<
typename ResultType,
typename ArgumentVectorType>
487 static void Evaluate(ResultType& accumulator,
const ArgumentVectorType& args)
489 typename AEvaluator::ResultType a = AEvaluator::Evaluate(args);
490 typename BEvaluator::ResultType b = BEvaluator::Evaluate(args);
491 typename CEvaluator::ResultType c = CEvaluator::Evaluate(args);
493 double alpha = LhsAccess::GetAlpha(args);
494 double beta = RhsAccess::GetBeta(args);
// The C term is the right-hand operand, so the operator's sign
// (impl::BetaScale<OpType>) is folded into beta.
496 beta *= impl::BetaScale<OpType>::GetScale();
497 Nektar::Dgemm(accumulator, alpha, a, b, beta, c);
// Evaluation override for a binary node L <op> R in the mirrored arrangement:
// L matches the beta/C pattern and R matches the alpha/A/B pattern, with <op>
// satisfying IsAdditiveOperator. The extra condition requires that L does NOT
// itself match the alpha/A/B pattern, presumably to keep this specialization
// from competing with the one that handles a product on the left.
// NOTE(review): the metafunction combining the enabling conditions is not
// visible in this listing -- TODO confirm.
504 template<
typename L,
typename OpType,
typename R,
typename IndicesType,
unsigned int index>
505 struct BinaryBinaryEvaluateNodeOverride<L, OpType, R, IndicesType, index,
506 typename boost::enable_if
510 IsAdditiveOperator<OpType>,
511 impl::AlphaABParameterAccess<R, IndicesType, index + L::TotalCount>,
512 impl::BetaCParameterAccess<L, IndicesType, index>,
513 boost::mpl::not_<impl::AlphaABParameterAccess<L, IndicesType, index> >
518 >::type> :
public boost::true_type
520 typedef Node<L, OpType, R> NodeType;
// Left side supplies beta and C.
522 typedef impl::BetaCParameterAccess<L, IndicesType, index> LhsAccess;
// Right side supplies alpha, A and B; its index is offset by L::TotalCount.
523 static const unsigned int rhsIndex = index + L::TotalCount;
524 typedef impl::AlphaABParameterAccess<R, IndicesType, rhsIndex> RhsAccess;
526 typedef typename RhsAccess::AEvaluator AEvaluator;
527 typedef typename RhsAccess::BEvaluator BEvaluator;
528 typedef typename LhsAccess::CEvaluator CEvaluator;
// Evaluates the three matrix operands and the two coefficients from `args`,
// then writes the combined result into `accumulator`.
530 template<
typename ResultType,
typename ArgumentVectorType>
531 static void Evaluate(ResultType& accumulator,
const ArgumentVectorType& args)
533 typename AEvaluator::ResultType a = AEvaluator::Evaluate(args);
534 typename BEvaluator::ResultType b = BEvaluator::Evaluate(args);
535 typename CEvaluator::ResultType c = CEvaluator::Evaluate(args);
537 double alpha = RhsAccess::GetAlpha(args);
538 double beta = LhsAccess::GetBeta(args);
// Here the product is the right-hand operand, so the operator's sign
// (impl::BetaScale<OpType>) is folded into alpha rather than beta.
540 alpha *= impl::BetaScale<OpType>::GetScale();
541 Nektar::Dgemm(accumulator, alpha, a, b, beta, c);
// Evaluation override for a node that, taken as a whole, matches the
// alpha/A/B pattern (a product with no additive C term). It is evaluated with
// one call to the four-argument Nektar::Dgemm.
548 template<
typename L,
typename OpType,
typename R,
typename IndicesType,
unsigned int index>
549 struct BinaryBinaryEvaluateNodeOverride<L, OpType, R, IndicesType, index,
550 typename boost::enable_if
553 impl::AlphaABParameterAccess<Node<L, OpType, R>, IndicesType, index>
554 >::type> :
public boost::true_type
556 typedef Node<L, OpType, R> NodeType;
// The complete node (not just its left child) is matched, hence NodeType here.
557 typedef impl::AlphaABParameterAccess<NodeType, IndicesType, index> LhsAccess;
558 typedef typename LhsAccess::AEvaluator AEvaluator;
559 typedef typename LhsAccess::BEvaluator BEvaluator;
// Evaluates both matrix operands and the coefficient from `args`, then writes
// the product into `accumulator`.
561 template<
typename ResultType,
typename ArgumentVectorType>
562 static void Evaluate(ResultType& accumulator,
const ArgumentVectorType& args)
564 typename AEvaluator::ResultType a = AEvaluator::Evaluate(args);
565 typename BEvaluator::ResultType b = BEvaluator::Evaluate(args);
567 double alpha = LhsAccess::GetAlpha(args);
568 Nektar::Dgemm(accumulator, alpha, a, b);