34 #ifndef NEKTAR_LIB_UTILITIES_LINEAR_ALGEBRA_DGEMV_OVERRIDE_HPP
35 #define NEKTAR_LIB_UTILITIES_LINEAR_ALGEBRA_DGEMV_OVERRIDE_HPP
37 #ifdef NEKTAR_USE_EXPRESSION_TEMPLATES
39 #include <ExpressionTemplates/ExpressionTemplates.hpp>
40 #include <boost/utility/enable_if.hpp>
41 #include <boost/type_traits.hpp>
// Matrix-vector product result = alpha * op(A) * x computed with BLAS gemv.
// A's own scale factor (A.Scale()) is folded into alpha, and BLAS beta is
// fixed at 0.0, so the previous contents of result do not contribute.
// op(A) is selected by A.GetTransposeFlag(), which is forwarded as the BLAS
// TRANS argument.
48 template<
typename ADataType,
typename AMatrixType>
49 void Dgemv(NekVector<double>& result,
50 double alpha,
const NekMatrix<ADataType, AMatrixType>& A,
const NekVector<double>& x)
// Only full (dense) storage is handed to BLAS below; matrices of any other
// storage type take a separate path whose body is not shown here.
52 if( A.GetType() !=
eFULL )
// M and N are the row and column counts reported by A.
59 unsigned int M = A.GetRows();
60 unsigned int N = A.GetColumns();
62 char t = A.GetTransposeFlag();
// NOTE(review): lda (the BLAS leading-dimension argument) is not declared in
// the lines shown here; confirm it matches A's stored layout.
70 Blas::Dgemv(t, M, N, alpha*A.Scale(), A.GetRawPtr(), lda, x.GetRawPtr(), 1, 0.0, result.GetRawPtr(), 1);
// Matrix-vector product with accumulation, intended to compute
// result = alpha * op(A) * x + beta * y via BLAS gemv.
// A's own scale factor (A.Scale()) is folded into alpha; op(A) is selected by
// A.GetTransposeFlag(), forwarded as the BLAS TRANS argument.
73 template<
typename ADataType,
typename AMatrixType>
74 void Dgemv(NekVector<double>& result,
75 double alpha,
const NekMatrix<ADataType, AMatrixType>& A,
const NekVector<double>& x,
76 double beta,
const NekVector<double>& y)
// Only full (dense) storage is handed to BLAS below.
78 if( A.GetType() !=
eFULL)
// Presumably the start of the non-eFULL path: beta*y is formed in a
// temporary. The rest of that path is not shown here.
82 NekVector<double> temp = beta*y;
// M and N are the row and column counts reported by A.
88 unsigned int M = A.GetRows();
89 unsigned int N = A.GetColumns();
91 char t = A.GetTransposeFlag();
// BLAS gemv updates its Y argument in place (Y := alpha*op(A)*X + beta*Y), and
// here that argument is result, so beta scales result's existing contents.
// NOTE(review): the y parameter is not passed to this call; confirm y has
// already been copied into result (not visible here). Also confirm that lda
// is declared in the lines not shown.
98 Blas::Dgemv(t, M, N, alpha*A.Scale(), A.GetRawPtr(), lda, x.GetRawPtr(), 1, beta, result.GetRawPtr(), 1);
// Wraps an expression-tree node that is used as a Dgemv operand.
// The primary template evaluates the node as-is through
// EvaluateNodeWithTemporaryIfNeeded and reports a scale of 1.0, i.e. no sign
// correction for callers to fold into alpha/beta.
109 template<
typename NodeType,
typename IndicesType,
unsigned int index>
110 struct DgemvNodeEvaluator
112 typedef EvaluateNodeWithTemporaryIfNeeded<NodeType, IndicesType, index> Evaluator;
113 typedef typename Evaluator::ResultType Type;
114 static double GetScale() {
return 1.0; }
// Specialization for a negated operand (-X).
// It evaluates the un-negated child X and reports a scale of -1.0. Callers
// multiply GetScale() into the BLAS alpha/beta coefficient, so the negation
// is never materialized as a separate object.
117 template<
typename LhsType,
typename IndicesType,
unsigned int index>
118 struct DgemvNodeEvaluator<Node<LhsType, NegateOp>, IndicesType, index>
120 typedef EvaluateNodeWithTemporaryIfNeeded<LhsType, IndicesType, index> Evaluator;
121 typedef typename Evaluator::ResultType Type;
122 static double GetScale() {
return -1.0; }
// BetaScale<OpType>: the sign applied to the coefficient of the second
// operand of a binary node joined by OpType.
// It is 1.0 in general and -1.0 for SubtractOp, so (P - Q) is evaluated as
// P + (-1)*Q inside a single gemv call.
129 template<
typename OpType>
132 static double GetScale() {
return 1.0; }
// Subtraction: negate the coefficient of the right-hand operand.
136 struct BetaScale<SubtractOp>
138 static double GetScale() {
return -1.0; }
// Trait that recognizes a node which can be evaluated as alpha * A * x
// (a scalar, a matrix with raw storage, and a vector).
// The primary template derives from boost::false_type, meaning "not a Dgemv
// candidate". Matching specializations derive from boost::true_type and
// provide the evaluator typedefs and GetAlpha(args). The defaulted `enabled`
// parameter is the boost::enable_if SFINAE hook.
144 template<
typename NodeType,
typename IndicesType,
unsigned int index,
typename enabled=
void>
145 struct AlphaAXParameterAccessImpl :
public boost::false_type {};
147 template<
typename A1,
typename A2,
typename A3,
typename IndicesType,
unsigned int index>
148 struct AlphaAXParameterAccessImpl< Node<Node<A1, MultiplyOp, A2>, MultiplyOp, A3>, IndicesType, index,
149 typename boost::enable_if
151 Test3ArgumentAssociativeNode<Node<Node<A1, MultiplyOp, A2>, MultiplyOp, A3>, IsDouble, MultiplyOp,
152 Nektar::CanGetRawPtr, MultiplyOp, Nektar::IsVector>
159 >::type> :
public boost::true_type
161 typedef DgemvNodeEvaluator<A1, IndicesType, index> AlphaWrappedEvaluator;
162 typedef typename AlphaWrappedEvaluator::Evaluator AlphaEvaluator;
164 static const unsigned int A2Index = index + A1::TotalCount;
165 typedef DgemvNodeEvaluator<A2, IndicesType, A2Index> AWrappedEvaluator;
166 typedef typename AWrappedEvaluator::Evaluator AEvaluator;
168 static const unsigned int A3Index = A2Index + A2::TotalCount;
169 typedef DgemvNodeEvaluator<A3, IndicesType, A3Index> XWrappedEvaluator;
170 typedef typename XWrappedEvaluator::Evaluator BEvaluator;
172 template<
typename ArgumentVectorType>
173 static double GetAlpha(
const ArgumentVectorType& args)
175 return AlphaEvaluator::Evaluate(args) * AWrappedEvaluator::GetScale() * XWrappedEvaluator::GetScale();
179 template<
typename A1,
typename A2,
typename A3,
typename IndicesType,
unsigned int index>
180 struct AlphaAXParameterAccessImpl< Node<Node<A1, MultiplyOp, A2>, MultiplyOp, A3>, IndicesType, index,
181 typename boost::enable_if
183 Test3ArgumentAssociativeNode<Node<Node<A1, MultiplyOp, A2>, MultiplyOp, A3>, Nektar::CanGetRawPtr, MultiplyOp,
184 IsDouble, MultiplyOp, Nektar::IsVector>
191 >::type> :
public boost::true_type
193 typedef DgemvNodeEvaluator<A1, IndicesType, index> AWrappedEvaluator;
194 typedef typename AWrappedEvaluator::Evaluator AEvaluator;
196 static const unsigned int A2Index = index + A1::TotalCount;
197 typedef DgemvNodeEvaluator<A2, IndicesType, A2Index> AlphaWrappedEvaluator;
198 typedef typename AlphaWrappedEvaluator::Evaluator AlphaEvaluator;
200 static const unsigned int A3Index = A2Index + A2::TotalCount;
201 typedef DgemvNodeEvaluator<A3, IndicesType, A3Index> XWrappedEvaluator;
202 typedef typename XWrappedEvaluator::Evaluator BEvaluator;
204 template<
typename ArgumentVectorType>
205 static double GetAlpha(
const ArgumentVectorType& args)
207 return AlphaEvaluator::Evaluate(args) * AWrappedEvaluator::GetScale() * XWrappedEvaluator::GetScale();
// A * x * alpha, grouped as (A1 * A2) * A3.
// A1 must expose raw storage (the matrix), A2 must be a vector and A3 must be
// a double (the coefficient), per the Test3ArgumentAssociativeNode condition.
// Argument slots are laid out left to right starting at `index`.
211 template<
typename A1,
typename A2,
typename A3,
typename IndicesType,
unsigned int index>
212 struct AlphaAXParameterAccessImpl< Node<Node<A1, MultiplyOp, A2>, MultiplyOp, A3>, IndicesType, index,
213 typename boost::enable_if
215 Test3ArgumentAssociativeNode<Node<Node<A1, MultiplyOp, A2>, MultiplyOp, A3>, Nektar::CanGetRawPtr, MultiplyOp,
216 Nektar::IsVector, MultiplyOp, IsDouble>
223 >::type> :
public boost::true_type
225 typedef DgemvNodeEvaluator<A1, IndicesType, index> AWrappedEvaluator;
226 typedef typename AWrappedEvaluator::Evaluator AEvaluator;
228 static const unsigned int A2Index = index + A1::TotalCount;
229 typedef DgemvNodeEvaluator<A2, IndicesType, A2Index> XWrappedEvaluator;
230 typedef typename XWrappedEvaluator::Evaluator XEvaluator;
232 static const unsigned int A3Index = A2Index + A2::TotalCount;
233 typedef DgemvNodeEvaluator<A3, IndicesType, A3Index> AlphaWrappedEvaluator;
234 typedef typename AlphaWrappedEvaluator::Evaluator AlphaEvaluator;
// alpha is the evaluated coefficient times the sign of every wrapped operand.
// That includes the coefficient itself: a negated coefficient is evaluated
// without its negation and reports GetScale() == -1.
236 template<
typename ArgumentVectorType>
237 static double GetAlpha(
const ArgumentVectorType& args)
239 return AlphaEvaluator::Evaluate(args) * AlphaWrappedEvaluator::GetScale() * AWrappedEvaluator::GetScale() * XWrappedEvaluator::GetScale();
// A * x with no explicit coefficient.
// A1 must expose raw storage (the matrix) and A2 must be a vector. The
// not_<...> clauses exclude nodes that also satisfy one of the three-operand
// (coefficient, matrix, vector) patterns, which are handled by the other
// specializations.
244 template<
typename A1,
typename A2,
typename IndicesType,
unsigned int index>
245 struct AlphaAXParameterAccessImpl< Node<A1, MultiplyOp, A2>, IndicesType, index,
246 typename boost::enable_if
250 TestBinaryNode<Node<A1, MultiplyOp, A2>, Nektar::CanGetRawPtr, MultiplyOp, Nektar::IsVector>,
251 boost::mpl::not_<Test3ArgumentAssociativeNode<Node<A1, MultiplyOp, A2>, IsDouble, MultiplyOp, Nektar::CanGetRawPtr, MultiplyOp, Nektar::IsVector> >,
252 boost::mpl::not_<Test3ArgumentAssociativeNode<Node<A1, MultiplyOp, A2>, Nektar::CanGetRawPtr, MultiplyOp, IsDouble, MultiplyOp, Nektar::IsVector> >,
253 boost::mpl::not_<Test3ArgumentAssociativeNode<Node<A1, MultiplyOp, A2>, Nektar::CanGetRawPtr, MultiplyOp, Nektar::IsVector, MultiplyOp, IsDouble> >
257 >::type> :
public boost::true_type
259 typedef DgemvNodeEvaluator<A1, IndicesType, index> AWrappedEvaluator;
260 typedef typename AWrappedEvaluator::Evaluator AEvaluator;
262 static const unsigned int A2Index = index + A1::TotalCount;
263 typedef DgemvNodeEvaluator<A2, IndicesType, A2Index> XWrappedEvaluator;
264 typedef typename XWrappedEvaluator::Evaluator XEvaluator;
// There is no coefficient, so alpha is just the product of the operand signs
// (+1 or -1). args is unused.
266 template<
typename ArgumentVectorType>
267 static double GetAlpha(
const ArgumentVectorType& args)
269 return 1.0 * AWrappedEvaluator::GetScale() * XWrappedEvaluator::GetScale();
// Public entry point of the alpha*A*x trait.
// For non-negated nodes it simply inherits the result (true_type/false_type
// and the accessors) of AlphaAXParameterAccessImpl.
273 template<
typename NodeType,
typename IndicesType,
unsigned int index>
274 struct AlphaAXParameterAccess :
public AlphaAXParameterAccessImpl<NodeType, IndicesType, index> {};
// -(alpha * A * x): reuses the match and evaluators for the inner node T,
// at the same argument index, and negates the resulting alpha.
276 template<
typename T,
typename IndicesType,
unsigned int index>
277 struct AlphaAXParameterAccess<Node<T, NegateOp,
void>, IndicesType, index> :
public AlphaAXParameterAccessImpl<T, IndicesType, index>
279 typedef Node<T, NegateOp, void> NodeType;
// Hides the inherited GetAlpha so the outer negation is applied.
281 template<
typename ArgumentVectorType>
282 static double GetAlpha(
const ArgumentVectorType& args)
284 return -AlphaAXParameterAccessImpl<T, IndicesType, index>::GetAlpha(args);
// Trait that recognizes a node which can act as the "beta * y" term of a
// gemv call.
// The primary template derives from boost::false_type (no match). Matching
// specializations derive from boost::true_type and provide YEvaluator and
// GetBeta(args). The defaulted `enabled` parameter is the boost::enable_if
// SFINAE hook.
289 template<
typename NodeType,
typename IndicesType,
unsigned int index,
typename enabled=
void>
290 struct BetaYParameterAccessImpl :
public boost::false_type {};
// y * beta: L evaluates to NekVector<double> (y) and R evaluates to a double
// (beta). L occupies the argument slots starting at `index`; R follows
// L::TotalCount slots later.
306 template<
typename L,
typename R,
typename IndicesType,
unsigned int index>
307 struct BetaYParameterAccessImpl< Node<L, expt::MultiplyOp, R>, IndicesType, index,
308 typename boost::enable_if
312 boost::is_same<typename L::ResultType, Nektar::NekVector<double> >,
313 boost::is_same<typename R::ResultType, double>
315 >::type> :
public boost::true_type
317 typedef DgemvNodeEvaluator<L, IndicesType, index> YWrappedEvaluator;
318 typedef typename YWrappedEvaluator::Evaluator YEvaluator;
320 static const unsigned int nextIndex = index + L::TotalCount;
321 typedef DgemvNodeEvaluator<R, IndicesType, nextIndex> BetaWrappedEvaluator;
322 typedef typename BetaWrappedEvaluator::Evaluator BetaEvaluator;
// beta is the evaluated coefficient times the sign of both operands.
// A negated coefficient is evaluated without its negation and reports
// GetScale() == -1, so its scale must be applied here as well.
324 template<
typename ArgumentVectorType>
325 static double GetBeta(
const ArgumentVectorType& args)
327 return BetaEvaluator::Evaluate(args) * BetaWrappedEvaluator::GetScale() * YWrappedEvaluator::GetScale();
// beta * y: L evaluates to a double (beta) and R evaluates to
// NekVector<double> (y). L occupies the argument slots starting at `index`;
// R follows L::TotalCount slots later.
331 template<
typename L,
typename R,
typename IndicesType,
unsigned int index>
332 struct BetaYParameterAccessImpl< Node<L, expt::MultiplyOp, R>, IndicesType, index,
333 typename boost::enable_if
337 boost::is_same<typename R::ResultType, Nektar::NekVector<double> >,
338 boost::is_same<typename L::ResultType, double>
340 >::type> :
public boost::true_type
342 typedef DgemvNodeEvaluator<L, IndicesType, index> BetaWrappedEvaluator;
343 typedef typename BetaWrappedEvaluator::Evaluator BetaEvaluator;
345 static const unsigned int nextIndex = index + L::TotalCount;
346 typedef DgemvNodeEvaluator<R, IndicesType, nextIndex> YWrappedEvaluator;
347 typedef typename YWrappedEvaluator::Evaluator YEvaluator;
// beta is the evaluated coefficient times the sign of both operands.
// A negated coefficient is evaluated without its negation and reports
// GetScale() == -1, so its scale must be applied here as well.
349 template<
typename ArgumentVectorType>
350 static double GetBeta(
const ArgumentVectorType& args)
352 return BetaEvaluator::Evaluate(args) * BetaWrappedEvaluator::GetScale() * YWrappedEvaluator::GetScale();
// Fallback "y" term: any node whose ResultType is NekVector<double> is used
// as y with beta = +1 or -1 (the sign comes from DgemvNodeEvaluator's
// negate handling).
// NOTE(review): y*beta and beta*y nodes also satisfy this condition. The
// Node<L, MultiplyOp, R> specializations are relied upon to win partial
// ordering; confirm on the supported compilers.
356 template<
typename NodeType,
typename IndicesType,
unsigned int index>
357 struct BetaYParameterAccessImpl< NodeType, IndicesType, index,
358 typename boost::enable_if
360 boost::is_same<typename NodeType::ResultType, Nektar::NekVector<double> >
361 >::type> :
public boost::true_type
363 typedef DgemvNodeEvaluator<NodeType, IndicesType, index> YWrappedEvaluator;
364 typedef typename YWrappedEvaluator::Evaluator YEvaluator;
// No coefficient: beta is only y's sign. args is unused.
366 template<
typename ArgumentVectorType>
367 static double GetBeta(
const ArgumentVectorType& args)
369 return 1.0 * YWrappedEvaluator::GetScale();
// Public entry point of the beta*y trait.
// For non-negated nodes it simply inherits the result of
// BetaYParameterAccessImpl.
373 template<
typename NodeType,
typename IndicesType,
unsigned int index>
374 struct BetaYParameterAccess :
public BetaYParameterAccessImpl<NodeType, IndicesType, index> {};
// -(beta * y): reuses the match and evaluators for the inner node T, at the
// same argument index, and negates the resulting beta.
376 template<
typename T,
typename IndicesType,
unsigned int index>
377 struct BetaYParameterAccess<Node<T, NegateOp,
void>, IndicesType, index> :
public BetaYParameterAccessImpl<T, IndicesType, index>
// Hides the inherited GetBeta so the outer negation is applied.
379 template<
typename ArgumentVectorType>
380 static double GetBeta(
const ArgumentVectorType& args)
382 return -BetaYParameterAccessImpl<T, IndicesType, index>::GetBeta(args);
// Override for (alpha*A*x) OpType (beta*y).
// Evaluates the whole binary node with a single 6-argument Nektar::Dgemv
// call. L is matched by AlphaAXParameterAccess at `index`; R is matched by
// BetaYParameterAccess at index + L::TotalCount.
404 template<
typename L,
typename OpType,
typename R,
typename IndicesType,
unsigned int index>
405 struct BinaryBinaryEvaluateNodeOverride<L, OpType, R, IndicesType, index,
406 typename boost::enable_if
410 impl::AlphaAXParameterAccess<L, IndicesType, index>,
411 impl::BetaYParameterAccess<R, IndicesType, index + L::TotalCount>
415 >::type> :
public boost::true_type
417 typedef Node<L, OpType, R> NodeType;
418 typedef impl::AlphaAXParameterAccess<L, IndicesType, index> LhsAccess;
419 static const unsigned int rhsIndex = index + L::TotalCount;
420 typedef impl::BetaYParameterAccess<R, IndicesType, rhsIndex> RhsAccess;
421 typedef typename LhsAccess::AEvaluator AEvaluator;
422 typedef typename LhsAccess::XEvaluator XEvaluator;
423 typedef typename RhsAccess::YEvaluator YEvaluator;
// Evaluates A, x and y from args, then writes
// accumulator = alpha*A*x + (sign of OpType)*beta*y.
425 template<
typename ResultType,
typename ArgumentVectorType>
426 static void Evaluate(ResultType& accumulator,
const ArgumentVectorType& args)
428 typename AEvaluator::ResultType a = AEvaluator::Evaluate(args);
429 typename XEvaluator::ResultType x = XEvaluator::Evaluate(args);
430 typename YEvaluator::ResultType y = YEvaluator::Evaluate(args);
432 double alpha = LhsAccess::GetAlpha(args);
433 double beta = RhsAccess::GetBeta(args);
// y is the right-hand operand, so the operator's sign (-1 for SubtractOp,
// +1 otherwise) goes on beta.
435 beta *= impl::BetaScale<OpType>::GetScale();
436 Nektar::Dgemv(accumulator, alpha, a, x, beta, y);
// Override for (beta*y) OpType (alpha*A*x), the mirror image of the previous
// override. L is matched by BetaYParameterAccess at `index`; R is matched by
// AlphaAXParameterAccess at index + L::TotalCount.
// The not_<AlphaAXParameterAccess<L>> clause rejects nodes where L is also an
// A*x product. Such a vector-valued L would satisfy BetaYParameterAccess as
// well, making both overrides match.
443 template<
typename L,
typename OpType,
typename R,
typename IndicesType,
unsigned int index>
444 struct BinaryBinaryEvaluateNodeOverride<L, OpType, R, IndicesType, index,
445 typename boost::enable_if
449 impl::AlphaAXParameterAccess<R, IndicesType, index + L::TotalCount>,
450 impl::BetaYParameterAccess<L, IndicesType, index>,
451 boost::mpl::not_<impl::AlphaAXParameterAccess<L, IndicesType, index> >
456 >::type> :
public boost::true_type
458 typedef Node<L, OpType, R> NodeType;
460 typedef impl::BetaYParameterAccess<L, IndicesType, index> LhsAccess;
461 static const unsigned int rhsIndex = index + L::TotalCount;
462 typedef impl::AlphaAXParameterAccess<R, IndicesType, rhsIndex> RhsAccess;
464 typedef typename RhsAccess::AEvaluator AEvaluator;
465 typedef typename RhsAccess::XEvaluator XEvaluator;
466 typedef typename LhsAccess::YEvaluator YEvaluator;
// Evaluates A, x and y from args, then writes
// accumulator = (sign of OpType)*alpha*A*x + beta*y.
468 template<
typename ResultType,
typename ArgumentVectorType>
469 static void Evaluate(ResultType& accumulator,
const ArgumentVectorType& args)
471 typename AEvaluator::ResultType a = AEvaluator::Evaluate(args);
472 typename XEvaluator::ResultType x = XEvaluator::Evaluate(args);
473 typename YEvaluator::ResultType y = YEvaluator::Evaluate(args);
475 double alpha = RhsAccess::GetAlpha(args);
476 double beta = LhsAccess::GetBeta(args);
// Here the matrix-vector product is the right-hand operand, so the
// operator's sign goes on alpha instead of beta.
478 alpha *= impl::BetaScale<OpType>::GetScale();
479 Nektar::Dgemv(accumulator, alpha, a, x, beta, y);
// Override for a node that is itself alpha*A*x (in any form recognized by
// AlphaAXParameterAccess). It is evaluated directly into the accumulator with
// one matrix-vector BLAS call.
486 template<
typename L,
typename OpType,
typename R,
typename IndicesType,
unsigned int index>
487 struct BinaryBinaryEvaluateNodeOverride<L, OpType, R, IndicesType, index,
488 typename boost::enable_if
490 impl::AlphaAXParameterAccess<Node<L, OpType, R>, IndicesType, index>
492 >::type> :
public boost::true_type
494 typedef Node<L, OpType, R> NodeType;
495 typedef impl::AlphaAXParameterAccess<NodeType, IndicesType, index> LhsAccess;
496 typedef typename LhsAccess::AEvaluator AEvaluator;
497 typedef typename LhsAccess::XEvaluator XEvaluator;
// Evaluates A and x from args, then writes accumulator = alpha*A*x.
499 template<
typename ResultType,
typename ArgumentVectorType>
500 static void Evaluate(ResultType& accumulator,
const ArgumentVectorType& args)
502 typename AEvaluator::ResultType a = AEvaluator::Evaluate(args);
503 typename XEvaluator::ResultType x = XEvaluator::Evaluate(args);
505 double alpha = LhsAccess::GetAlpha(args);
// x is a vector, so this is a matrix-vector product. Dispatch to the
// 4-argument Dgemv overload (result = alpha*op(A)*x), not the matrix-matrix
// Dgemm.
506 Nektar::Dgemv(accumulator, alpha, a, x);