34 #ifndef NEKTAR_LIB_UTILITIES_LINEAR_ALGEBRA_DGEMM_OVERRIDE_HPP 
   35 #define NEKTAR_LIB_UTILITIES_LINEAR_ALGEBRA_DGEMM_OVERRIDE_HPP 
   37 #ifdef NEKTAR_USE_EXPRESSION_TEMPLATES 
   39 #include <ExpressionTemplates/ExpressionTemplates.hpp> 
   40 #include <boost/utility/enable_if.hpp> 
   41 #include <boost/type_traits.hpp> 
   // Dgemm overload without a C term. From what is visible, it calls
   // Blas::Dgemm with:
   //   - the transpose flags of A and B;
   //   - M = A.GetRows(), N = B.GetColumns(), K = A.GetColumns();
   //   - scalar A.Scale()*B.Scale()*alpha;
   //   - BLAS beta 0.0, so result's previous contents are not accumulated;
   //   - result.GetRows() as the output leading dimension.
   // NOTE(review): this listing is missing original lines 52, 54-59, 63-64,
   // 66-70 and 72-75. That covers the braces, the body of the non-eFULL
   // branch, and the declarations/adjustments of LDA and LDB used below.
   // It does not compile as shown; restore those lines from the upstream header.
   48     template<
typename ADataType, 
typename BDataType, 

   49              typename AMatrixType, 
typename BMatrixType>

   50     void Dgemm(NekMatrix<double, StandardMatrixTag>& result, 

   51                double alpha, 
const NekMatrix<ADataType, AMatrixType>& A, 
const NekMatrix<BDataType, BMatrixType>& B)

        // NOTE(review): body of this branch (orig. 54-59) is missing; presumably
        // it handles non-eFULL operands without BLAS -- TODO confirm.
   53         if( A.GetType() != 
eFULL || B.GetType() != 
eFULL )

   60         unsigned int M = A.GetRows();

   61         unsigned int N = B.GetColumns();

   62         unsigned int K = A.GetColumns();

        // NOTE(review): body missing (orig. 66-70); presumably adjusts LDA for a
        // transposed A -- TODO confirm.
   65         if( A.GetTransposeFlag() == 
'T' )

        // NOTE(review): body missing (orig. 72-75); presumably adjusts LDB for a
        // transposed B -- TODO confirm.
   71         if( B.GetTransposeFlag() == 
'T' )

   76         Blas::Dgemm(A.GetTransposeFlag(), B.GetTransposeFlag(), M, N, K,

   77             A.Scale()*B.Scale()*alpha, A.GetRawPtr(), LDA,

   78             B.GetRawPtr(), LDB, 0.0,

   79                     result.GetRawPtr(), result.GetRows());
 
   82     template<
typename ADataType, 
typename BDataType, 
 
   83              typename AMatrixType, 
typename BMatrixType>
 
   84     void Dgemm(NekMatrix<double, StandardMatrixTag>& result, 
 
   85                const NekMatrix<ADataType, AMatrixType>& A, 
double alpha, 
const NekMatrix<BDataType, BMatrixType>& B)
 
   87         Dgemm(result, alpha, A, B);
 
   90     template<
typename ADataType, 
typename BDataType, 
 
   91              typename AMatrixType, 
typename BMatrixType>
 
   92     void Dgemm(NekMatrix<double, StandardMatrixTag>& result, 
 
   93                const NekMatrix<ADataType, AMatrixType>& A, 
const NekMatrix<BDataType, BMatrixType>& B, 
double alpha)
 
   95         Dgemm(result, alpha, A, B);
 
   // Dgemm overload with a C term (intended: result = alpha*A*B + beta*C).
   // Visible behaviour:
   //   - If any of A, B, C is not eFULL, a temporary beta*C is formed; the rest
   //     of that branch is missing.
   //   - Otherwise, when C is flagged 'T', result is resized to C's dimensions
   //     and filled element by element from C(i,j).
   //   - Blas::Dgemm is then called with scalar A.Scale()*B.Scale()*alpha and
   //     BLAS beta beta*result.Scale(), so it accumulates into result in place
   //     (leading dimension result.GetRows()).
   // NOTE(review): original lines 103, 105-107, 109-112, 114-116, 118, 120, 122,
   // 124-131, 135, 138-141 and 144-147 are missing from this listing. That
   // covers braces, the remainder of the non-eFULL branch, the else-branch for a
   // non-transposed C, and the transpose adjustments of LDA/LDB. It does not
   // compile as shown.
   98     template<
typename ADataType, 
typename BDataType, 
typename CDataType, 

   99              typename AMatrixType, 
typename BMatrixType, 
typename CMatrixType>

  100     void Dgemm(NekMatrix<double, StandardMatrixTag>& result, 

  101                double alpha, 
const NekMatrix<ADataType, AMatrixType>& A, 
const NekMatrix<BDataType, BMatrixType>& B,

  102                double beta, 
const NekMatrix<CDataType, CMatrixType>& C)

  104         if( A.GetType() != 
eFULL || B.GetType() != 
eFULL || C.GetType() != 
eFULL )

  108             NekMatrix<double> temp = beta*C;

        // A transposed C is materialised into result explicitly (via C(i,j)),
        // so BLAS sees result as an ordinary column-major C operand.
  113         if( C.GetTransposeFlag() == 
'T' )

  117             result.SetSize(C.GetRows(), C.GetColumns());

  119             for(
unsigned int i = 0; i < C.GetRows(); ++i)

  121                 for(
unsigned int j = 0; j < C.GetColumns(); ++j)

  123                     result(i,j) = C(i,j);

  132         unsigned int M = A.GetRows();

  133         unsigned int N = B.GetColumns();

  134         unsigned int K = A.GetColumns();

  136         unsigned int LDA = M;

  137         if( A.GetTransposeFlag() == 
'T' )

  142         unsigned int LDB = K;

  143         if( B.GetTransposeFlag() == 
'T' )

        // result already holds C here, so its own scale is folded into BLAS beta.
  148         Blas::Dgemm(A.GetTransposeFlag(), B.GetTransposeFlag(), M, N, K,

  149             A.Scale()*B.Scale()*alpha, A.GetRawPtr(), LDA,

  150             B.GetRawPtr(), LDB, beta*result.Scale(),

  151                     result.GetRawPtr(), result.GetRows());
 
  154     template<
typename ADataType, 
typename BDataType, 
typename CDataType, 
 
  155              typename AMatrixType, 
typename BMatrixType, 
typename CMatrixType>
 
  156     void Dgemm(NekMatrix<double, StandardMatrixTag>& result, 
 
  157                const NekMatrix<ADataType, AMatrixType>& A, 
double alpha, 
const NekMatrix<BDataType, BMatrixType>& B,
 
  158                double beta, 
const NekMatrix<CDataType, CMatrixType>& C)
 
  160         Dgemm(result, alpha, A, B, beta, C);
 
  163     template<
typename ADataType, 
typename BDataType, 
typename CDataType, 
 
  164              typename AMatrixType, 
typename BMatrixType, 
typename CMatrixType>
 
  165     void Dgemm(NekMatrix<double, StandardMatrixTag>& result, 
 
  166                const NekMatrix<ADataType, AMatrixType>& A, 
const NekMatrix<BDataType, BMatrixType>& B, 
double alpha, 
 
  167                double beta, 
const NekMatrix<CDataType, CMatrixType>& C)
 
  169         Dgemm(result, alpha, A, B, beta, C);
 
  180         template<
typename NodeType, 
typename IndicesType, 
unsigned int index>
 
  181         struct DgemmNodeEvaluator
 
  183             typedef EvaluateNodeWithTemporaryIfNeeded<NodeType, IndicesType, index> Evaluator;
 
  184             typedef typename Evaluator::ResultType Type;
 
  185             static double GetScale() { 
return 1.0; }
 
  188         template<
typename LhsType, 
typename IndicesType, 
unsigned int index>
 
  189         struct DgemmNodeEvaluator<Node<LhsType, NegateOp>, IndicesType, index>
 
  191             typedef EvaluateNodeWithTemporaryIfNeeded<LhsType, IndicesType, index> Evaluator;
 
  192             typedef typename Evaluator::ResultType Type;
 
  193             static double GetScale() { 
return -1.0; }
 
  200         template<
typename OpType>
 
  203             static double GetScale() { 
return 1.0; }
 
  207         struct BetaScale<SubtractOp>
 
  209             static double GetScale() { 
return -1.0; }
 
  213         struct IsDouble : 
public boost::false_type{};
 
  216         struct IsDouble<double> : 
public boost::true_type {};
 
  222         template<
typename NodeType, 
typename IndicesType, 
unsigned int index, 
typename enabled=
void>
 
  223         struct AlphaABParameterAccessImpl : 
public boost::false_type {};
 
  // Matches (alpha * A) * B, where:
  //   - alpha satisfies IsDouble and A, B satisfy Nektar::CanGetRawPtr
  //     (per the Test3ArgumentAssociativeNode arguments);
  //   - operand indices advance by each child's TotalCount.
  // GetAlpha returns alpha times the sign factors (DgemmNodeEvaluator::GetScale)
  // of A and B.
  // NOTE(review): original lines 228 and 231-236 (part of the enable_if
  // condition) and the struct braces are missing from this listing.
  225         template<
typename A1, 
typename A2, 
typename A3, 
typename IndicesType, 
unsigned int index>

  226         struct AlphaABParameterAccessImpl< Node<Node<A1, MultiplyOp, A2>, MultiplyOp, A3>, IndicesType, index,

  227             typename boost::enable_if

  229                 Test3ArgumentAssociativeNode<Node<Node<A1, MultiplyOp, A2>, MultiplyOp, A3>, IsDouble, MultiplyOp,

  230                                              Nektar::CanGetRawPtr, MultiplyOp, Nektar::CanGetRawPtr>

  237             >::type> : 
public boost::true_type

  239             typedef DgemmNodeEvaluator<A1, IndicesType, index> AlphaWrappedEvaluator;

  240             typedef typename AlphaWrappedEvaluator::Evaluator AlphaEvaluator;

  242             static const unsigned int A2Index = index + A1::TotalCount;

  243             typedef DgemmNodeEvaluator<A2, IndicesType, A2Index> AWrappedEvaluator;

  244             typedef typename AWrappedEvaluator::Evaluator AEvaluator;

  246             static const unsigned int A3Index = A2Index + A2::TotalCount;

  247             typedef DgemmNodeEvaluator<A3, IndicesType, A3Index> BWrappedEvaluator;

  248             typedef typename BWrappedEvaluator::Evaluator BEvaluator;

  250             template<
typename ArgumentVectorType>

  251             static double GetAlpha(
const ArgumentVectorType& args)

  253                 return AlphaEvaluator::Evaluate(args) * AWrappedEvaluator::GetScale() * BWrappedEvaluator::GetScale();
 
  // Matches (A * alpha) * B, with the scalar in the middle position. Here A and
  // B satisfy Nektar::CanGetRawPtr and alpha satisfies IsDouble.
  // GetAlpha returns alpha times the sign factors of A and B.
  // NOTE(review): original lines 260 and 263-268 (part of the enable_if
  // condition) and the struct braces are missing from this listing.
  257         template<
typename A1, 
typename A2, 
typename A3, 
typename IndicesType, 
unsigned int index>

  258         struct AlphaABParameterAccessImpl< Node<Node<A1, MultiplyOp, A2>, MultiplyOp, A3>, IndicesType, index,

  259             typename boost::enable_if

  261                 Test3ArgumentAssociativeNode<Node<Node<A1, MultiplyOp, A2>, MultiplyOp, A3>, Nektar::CanGetRawPtr, MultiplyOp,

  262                                              IsDouble, MultiplyOp, Nektar::CanGetRawPtr>

  269             >::type> : 
public boost::true_type

  271             typedef DgemmNodeEvaluator<A1, IndicesType, index> AWrappedEvaluator;

  272             typedef typename AWrappedEvaluator::Evaluator AEvaluator;

  274             static const unsigned int A2Index = index + A1::TotalCount;

  275             typedef DgemmNodeEvaluator<A2, IndicesType, A2Index> AlphaWrappedEvaluator;

  276             typedef typename AlphaWrappedEvaluator::Evaluator AlphaEvaluator;

  278             static const unsigned int A3Index = A2Index + A2::TotalCount;

  279             typedef DgemmNodeEvaluator<A3, IndicesType, A3Index> BWrappedEvaluator;

  280             typedef typename BWrappedEvaluator::Evaluator BEvaluator;

  282             template<
typename ArgumentVectorType>

  283             static double GetAlpha(
const ArgumentVectorType& args)

  285                 return AlphaEvaluator::Evaluate(args) * AWrappedEvaluator::GetScale() * BWrappedEvaluator::GetScale();
 
  // Matches (A * B) * alpha, with the scalar in the last position. Here A and B
  // satisfy Nektar::CanGetRawPtr and alpha satisfies IsDouble.
  // GetAlpha returns alpha times the sign factors of A and B.
  // NOTE(review): original lines 292 and 295-300 (part of the enable_if
  // condition) and the struct braces are missing from this listing.
  289         template<
typename A1, 
typename A2, 
typename A3, 
typename IndicesType, 
unsigned int index>

  290         struct AlphaABParameterAccessImpl< Node<Node<A1, MultiplyOp, A2>, MultiplyOp, A3>, IndicesType, index,

  291             typename boost::enable_if

  293                 Test3ArgumentAssociativeNode<Node<Node<A1, MultiplyOp, A2>, MultiplyOp, A3>, Nektar::CanGetRawPtr, MultiplyOp,

  294                                              Nektar::CanGetRawPtr, MultiplyOp, IsDouble>

  301             >::type> : 
public boost::true_type

  303             typedef DgemmNodeEvaluator<A1, IndicesType, index> AWrappedEvaluator;

  304             typedef typename AWrappedEvaluator::Evaluator AEvaluator;

  306             static const unsigned int A2Index = index + A1::TotalCount;

  307             typedef DgemmNodeEvaluator<A2, IndicesType, A2Index> BWrappedEvaluator;

  308             typedef typename BWrappedEvaluator::Evaluator BEvaluator;

  310             static const unsigned int A3Index = A2Index + A2::TotalCount;

  311             typedef DgemmNodeEvaluator<A3, IndicesType, A3Index> AlphaWrappedEvaluator;

  312             typedef typename AlphaWrappedEvaluator::Evaluator AlphaEvaluator;

  314             template<
typename ArgumentVectorType>

  315             static double GetAlpha(
const ArgumentVectorType& args)

  317                 return AlphaEvaluator::Evaluate(args) * AWrappedEvaluator::GetScale() * BWrappedEvaluator::GetScale();
 
  // Matches a plain A * B where both operands satisfy Nektar::CanGetRawPtr.
  // The three boost::mpl::not_ conditions exclude the shapes in which one side
  // is itself a scalar-times-matrix product, keeping this specialization
  // disjoint from the three-argument ones.
  // GetAlpha ignores args and returns only the sign factors of A and B.
  // NOTE(review): original lines 325-327 and 332-338 (likely the enable_if /
  // mpl::and_ wrapper) and the struct braces are missing from this listing.
  322         template<
typename A1, 
typename A2, 
typename IndicesType, 
unsigned int index>

  323         struct AlphaABParameterAccessImpl< Node<A1, MultiplyOp, A2>, IndicesType, index,

  324             typename boost::enable_if

  328                     TestBinaryNode<Node<A1, MultiplyOp, A2>, Nektar::CanGetRawPtr, MultiplyOp, Nektar::CanGetRawPtr>,

  329                     boost::mpl::not_<Test3ArgumentAssociativeNode<Node<A1, MultiplyOp, A2>, Nektar::CanGetRawPtr, MultiplyOp, Nektar::CanGetRawPtr, MultiplyOp, IsDouble> >,

  330                     boost::mpl::not_<Test3ArgumentAssociativeNode<Node<A1, MultiplyOp, A2>, Nektar::CanGetRawPtr, MultiplyOp, IsDouble, MultiplyOp, Nektar::CanGetRawPtr> >,

  331                     boost::mpl::not_<Test3ArgumentAssociativeNode<Node<A1, MultiplyOp, A2>, IsDouble, MultiplyOp, Nektar::CanGetRawPtr, MultiplyOp, Nektar::CanGetRawPtr> >

  339             >::type> : 
public boost::true_type

  341             typedef DgemmNodeEvaluator<A1, IndicesType, index> AWrappedEvaluator;

  342             typedef typename AWrappedEvaluator::Evaluator AEvaluator;

  344             static const unsigned int A2Index = index + A1::TotalCount;

  345             typedef DgemmNodeEvaluator<A2, IndicesType, A2Index> BWrappedEvaluator;

  346             typedef typename BWrappedEvaluator::Evaluator BEvaluator;

  348             template<
typename ArgumentVectorType>

  349             static double GetAlpha(
const ArgumentVectorType& args)

  351                 return 1.0 * AWrappedEvaluator::GetScale() * BWrappedEvaluator::GetScale();
 
  356         template<
typename NodeType, 
typename IndicesType, 
unsigned int index>
 
  357         struct AlphaABParameterAccess : 
public AlphaABParameterAccessImpl<NodeType, IndicesType, index> {};
 
  359         template<
typename T, 
typename IndicesType, 
unsigned int index>
 
  360         struct AlphaABParameterAccess<Node<T, NegateOp, 
void>, IndicesType, index> : 
public AlphaABParameterAccessImpl<T, IndicesType, index> 
 
  362             typedef Node<T, NegateOp, void> NodeType;
 
  364             template<
typename ArgumentVectorType>
 
  365             static double GetAlpha(
const ArgumentVectorType& args)
 
  367                 return -AlphaABParameterAccessImpl<T, IndicesType, index>::GetAlpha(args);
 
  372         template<
typename NodeType, 
typename IndicesType, 
unsigned int index, 
typename enabled=
void>
 
  373         struct BetaCParameterAccessImpl : 
public boost::false_type {};
 
  // Matches a bare leaf C (Node<T, void, void>) used as the C operand of Dgemm.
  // NOTE(review): the body of GetBeta (original lines 384-387) and the struct
  // braces are missing from this listing. For a bare leaf, beta is presumably
  // 1.0 times CWrappedEvaluator::GetScale() -- TODO confirm against upstream.
  375         template<
typename T, 
typename IndicesType, 
unsigned int index>

  376         struct BetaCParameterAccessImpl< Node<T, 
void, 
void>, IndicesType, index> : 
public boost::true_type

  378             typedef Node<T, void, void> NodeType;

  379             typedef DgemmNodeEvaluator<NodeType, IndicesType, index> CWrappedEvaluator;

  380             typedef typename CWrappedEvaluator::Evaluator CEvaluator;

  382             template<
typename ArgumentVectorType>

  383             static double GetBeta(
const ArgumentVectorType& args)
 
  // Matches C * beta, where L's result type satisfies CanGetRawPtr and R's
  // result type is double.
  // GetBeta returns beta times the sign factor of C.
  // NOTE(review): original lines 392-394 and 397 (likely the enable_if /
  // mpl::and_ wrapper) and the struct braces are missing from this listing.
  389         template<
typename L, 
typename R, 
typename IndicesType, 
unsigned int index>

  390         struct BetaCParameterAccessImpl< Node<
L, expt::MultiplyOp, R>, IndicesType, index, 

  391             typename boost::enable_if

  395                 Nektar::CanGetRawPtr<typename L::ResultType>,

  396                     boost::is_same<typename R::ResultType, double>

  398             >::type> : 
public boost::true_type

  400             typedef DgemmNodeEvaluator<L, IndicesType, index> CWrappedEvaluator;

  401             typedef typename CWrappedEvaluator::Evaluator CEvaluator;

  403             static const unsigned int nextIndex = index + L::TotalCount;

  404             typedef DgemmNodeEvaluator<R, IndicesType, nextIndex> BetaWrappedEvaluator;

  405             typedef typename BetaWrappedEvaluator::Evaluator BetaEvaluator;

  407             template<
typename ArgumentVectorType>

  408             static double GetBeta(
const ArgumentVectorType& args)

  410                 return BetaEvaluator::Evaluate(args) * CWrappedEvaluator::GetScale();
 
  // Matches beta * C, where L's result type is double and R's result type
  // satisfies CanGetRawPtr.
  // GetBeta returns beta times the sign factor of C.
  // NOTE(review): original lines 417-419 and 422 (likely the enable_if /
  // mpl::and_ wrapper) and the struct braces are missing from this listing.
  414         template<
typename L, 
typename R, 
typename IndicesType, 
unsigned int index>

  415         struct BetaCParameterAccessImpl< Node<
L, expt::MultiplyOp, R>, IndicesType, index, 

  416             typename boost::enable_if

  420                     Nektar::CanGetRawPtr<typename R::ResultType>,

  421                     boost::is_same<typename L::ResultType, double>

  423             >::type> : 
public boost::true_type

  425             typedef DgemmNodeEvaluator<L, IndicesType, index> BetaWrappedEvaluator;

  426             typedef typename BetaWrappedEvaluator::Evaluator BetaEvaluator;

  428             static const unsigned int nextIndex = index + L::TotalCount;

  429             typedef DgemmNodeEvaluator<R, IndicesType, nextIndex> CWrappedEvaluator;

  430             typedef typename CWrappedEvaluator::Evaluator CEvaluator;

  432             template<
typename ArgumentVectorType>

  433             static double GetBeta(
const ArgumentVectorType& args)

  435                 return BetaEvaluator::Evaluate(args) * CWrappedEvaluator::GetScale();
 
  439         template<
typename NodeType, 
typename IndicesType, 
unsigned int index>
 
  440         struct BetaCParameterAccess : 
public BetaCParameterAccessImpl<NodeType, IndicesType, index> {};
 
  442         template<
typename T, 
typename IndicesType, 
unsigned int index>
 
  443         struct BetaCParameterAccess<Node<T, NegateOp, 
void>, IndicesType, index> : 
public BetaCParameterAccessImpl<T, IndicesType, index> 
 
  445             template<
typename ArgumentVectorType>
 
  446             static double GetBeta(
const ArgumentVectorType& args)
 
  448                 return -BetaCParameterAccessImpl<T, IndicesType, index>::GetBeta(args);
 
  // Override for (alpha*A*B) +/- (beta*C), with the product on the left.
  // Evaluate does the following:
  //   - evaluates A, B and C once;
  //   - obtains alpha from AlphaABParameterAccess and beta from
  //     BetaCParameterAccess;
  //   - negates beta for SubtractOp (via BetaScale);
  //   - issues a single Nektar::Dgemm into accumulator.
  // NOTE(review): original lines 467-469, 473-475, 477, 485, 488, 492, 495 and
  // 498+ (the enable_if / mpl::and_ wrapper and braces) are missing from this
  // listing.
  464     template<
typename L, 
typename OpType, 
typename R, 
typename IndicesType, 
unsigned int index>

  465     struct BinaryBinaryEvaluateNodeOverride<
L, OpType, R, IndicesType, index,

  466         typename boost::enable_if

  470                 IsAdditiveOperator<OpType>,

  471                 impl::AlphaABParameterAccess<L, IndicesType, index>,

  472                 impl::BetaCParameterAccess<R, IndicesType, index + L::TotalCount>

  476         >::type> : 
public boost::true_type

  478         typedef Node<L, OpType, R> NodeType;

  479         typedef impl::AlphaABParameterAccess<L, IndicesType, index> LhsAccess;

  480         static const unsigned int rhsIndex = index + L::TotalCount;

  481         typedef impl::BetaCParameterAccess<R, IndicesType, rhsIndex> RhsAccess;

  482         typedef typename LhsAccess::AEvaluator AEvaluator;

  483         typedef typename LhsAccess::BEvaluator BEvaluator;

  484         typedef typename RhsAccess::CEvaluator CEvaluator;

  486         template<
typename ResultType, 
typename ArgumentVectorType>

  487         static void Evaluate(ResultType& accumulator, 
const ArgumentVectorType& args)

  489             typename AEvaluator::ResultType a = AEvaluator::Evaluate(args);

  490             typename BEvaluator::ResultType b = BEvaluator::Evaluate(args);

  491             typename CEvaluator::ResultType c = CEvaluator::Evaluate(args);

  493             double alpha = LhsAccess::GetAlpha(args);

  494             double beta = RhsAccess::GetBeta(args);

            // C is the right-hand operand, so subtraction flips beta.
  496             beta *= impl::BetaScale<OpType>::GetScale();

  497             Nektar::Dgemm(accumulator, alpha, a, b, beta, c);
 
  // Override for (beta*C) +/- (alpha*A*B), with the product on the right.
  // The mpl::not_ guard on L keeps this specialization disjoint from the
  // product-on-the-left override when both sides could match.
  // For SubtractOp the product is the subtrahend, so alpha (not beta) is negated
  // via BetaScale.
  // NOTE(review): original lines 507-509, 514-517, 519, 521, 525, 529, 532,
  // 536, 539 and 542+ (the enable_if / mpl::and_ wrapper and braces) are
  // missing from this listing.
  504     template<
typename L, 
typename OpType, 
typename R, 
typename IndicesType, 
unsigned int index>

  505     struct BinaryBinaryEvaluateNodeOverride<
L, OpType, R, IndicesType, index,

  506         typename boost::enable_if

  510                 IsAdditiveOperator<OpType>,

  511                 impl::AlphaABParameterAccess<R, IndicesType, index + L::TotalCount>,

  512                 impl::BetaCParameterAccess<L, IndicesType, index>,

  513                 boost::mpl::not_<impl::AlphaABParameterAccess<L, IndicesType, index> >

  518         >::type> : 
public boost::true_type

  520         typedef Node<L, OpType, R> NodeType;

  522         typedef impl::BetaCParameterAccess<L, IndicesType, index> LhsAccess;

  523         static const unsigned int rhsIndex = index + L::TotalCount;

  524         typedef impl::AlphaABParameterAccess<R, IndicesType, rhsIndex> RhsAccess;

  526         typedef typename RhsAccess::AEvaluator AEvaluator;

  527         typedef typename RhsAccess::BEvaluator BEvaluator;

  528         typedef typename LhsAccess::CEvaluator CEvaluator;

  530         template<
typename ResultType, 
typename ArgumentVectorType>

  531         static void Evaluate(ResultType& accumulator, 
const ArgumentVectorType& args)

  533             typename AEvaluator::ResultType a = AEvaluator::Evaluate(args);

  534             typename BEvaluator::ResultType b = BEvaluator::Evaluate(args);

  535             typename CEvaluator::ResultType c = CEvaluator::Evaluate(args);

  537             double alpha = RhsAccess::GetAlpha(args);

  538             double beta = LhsAccess::GetBeta(args);

            // The product is the right-hand operand here, so subtraction flips alpha.
  540             alpha *= impl::BetaScale<OpType>::GetScale();

  541             Nektar::Dgemm(accumulator, alpha, a, b, beta, c);
 
  // Override for a node that is itself an alpha*A*B product (no C term).
  // Evaluate evaluates A and B once and issues
  // Nektar::Dgemm(accumulator, alpha, a, b).
  // NOTE(review): original lines 551-552, 555, 560, 563, 566 and 569+ (the
  // opening of the enable_if argument list and braces) are missing from this
  // listing.
  548     template<
typename L, 
typename OpType, 
typename R, 
typename IndicesType, 
unsigned int index>

  549     struct BinaryBinaryEvaluateNodeOverride<
L, OpType, R, IndicesType, index,

  550         typename boost::enable_if

  553             impl::AlphaABParameterAccess<Node<L, OpType, R>, IndicesType, index>

  554         >::type> : 
public boost::true_type

  556         typedef Node<L, OpType, R> NodeType;

  557         typedef impl::AlphaABParameterAccess<NodeType, IndicesType, index> LhsAccess;

  558         typedef typename LhsAccess::AEvaluator AEvaluator;

  559         typedef typename LhsAccess::BEvaluator BEvaluator;

  561         template<
typename ResultType, 
typename ArgumentVectorType>

  562         static void Evaluate(ResultType& accumulator, 
const ArgumentVectorType& args)

  564             typename AEvaluator::ResultType a = AEvaluator::Evaluate(args);

  565             typename BEvaluator::ResultType b = BEvaluator::Evaluate(args);

  567             double alpha = LhsAccess::GetAlpha(args);

  568             Nektar::Dgemm(accumulator, alpha, a, b);
 
// NOTE(review): the three lines below are bare function signatures with no
// template headers, no bodies and no terminating ';'. The leading
// 'RhsMatrixType' before 'void AddEqual' is a stray token. They appear to be
// index/cross-reference residue rather than source; confirm and either remove
// them or replace them with the real declarations.
void MultiplyEqual(NekMatrix< LhsDataType, StandardMatrixTag > &lhs, typename boost::call_traits< LhsDataType >::const_reference rhs)

RhsMatrixType void AddEqual(NekMatrix< DataType, StandardMatrixTag > &result, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)

void Multiply(NekMatrix< ResultDataType, StandardMatrixTag > &result, const NekMatrix< LhsDataType, LhsMatrixType > &lhs, const ResultDataType &rhs)