36 #ifndef NEKTAR_LIB_UTILITIES_LINEAR_ALGEBRA_MATRIX_OPERATIONS_DECLARATIONS_HPP 
   37 #define NEKTAR_LIB_UTILITIES_LINEAR_ALGEBRA_MATRIX_OPERATIONS_DECLARATIONS_HPP 
   53 #include <boost/utility/enable_if.hpp> 
   54 #include <boost/type_traits.hpp> 
   63     template<
typename DataType, 
typename LhsDataType, 
typename MatrixType>
 
   65     Multiply(
const NekMatrix<LhsDataType, MatrixType>& 
lhs,
 
   66              const NekVector<DataType>& rhs);
 
   68     template<
typename DataType, 
typename LhsDataType, 
typename MatrixType>
 
   69     void Multiply(NekVector<DataType>& result,
 
   70                   const NekMatrix<LhsDataType, MatrixType>& 
lhs,
 
   71                   const NekVector<DataType>& rhs);
 
   73     template<
typename DataType, 
typename LhsInnerMatrixType>
 
   74     void Multiply(NekVector<DataType>& result,
 
   75                   const NekMatrix<LhsInnerMatrixType, BlockMatrixTag>& 
lhs,
 
   76                   const NekVector<DataType>& rhs);
 
   81     template<
typename ResultDataType, 
typename LhsDataType, 
typename LhsMatrixType>
 
   82     void Multiply(NekMatrix<ResultDataType, StandardMatrixTag>& result,
 
   83                      const NekMatrix<LhsDataType, LhsMatrixType>& 
lhs,
 
   84                      const ResultDataType& rhs);
 
   86     template<
typename DataType, 
typename LhsDataType, 
typename LhsMatrixType>
 
   87     NekMatrix<typename NekMatrix<LhsDataType, LhsMatrixType>::NumberType, StandardMatrixTag>
 
   88     Multiply(
const NekMatrix<LhsDataType, LhsMatrixType>& 
lhs,
 
   91     template<
typename RhsDataType, 
typename RhsMatrixType, 
typename ResultDataType>
 
   92     void Multiply(NekMatrix<ResultDataType, StandardMatrixTag>& result,
 
   93                      const ResultDataType& 
lhs,
 
   94                      const NekMatrix<RhsDataType, RhsMatrixType>& rhs);
 
   96     template<
typename DataType, 
typename RhsDataType, 
typename RhsMatrixType>
 
   97     NekMatrix<typename NekMatrix<RhsDataType, RhsMatrixType>::NumberType, StandardMatrixTag>
 
   99                 const NekMatrix<RhsDataType, RhsMatrixType>& rhs);
 
  101     template<
typename LhsDataType>
 
  103                        typename boost::call_traits<LhsDataType>::const_reference rhs);
 
  115                                          typename boost::enable_if
 
  124         ASSERTL1(lhs.GetType() == 
eFULL && rhs.GetType() == 
eFULL, 
"Only full matrices are supported.");
 
  126         unsigned int M = lhs.GetRows();
 
  127         unsigned int N = rhs.GetColumns();
 
  128         unsigned int K = lhs.GetColumns();
 
  130         unsigned int LDA = M;
 
  131         if( lhs.GetTransposeFlag() == 
'T' )
 
  136         unsigned int LDB = K;
 
  137         if( rhs.GetTransposeFlag() == 
'T' )
 
  142         Blas::Dgemm(lhs.GetTransposeFlag(), rhs.GetTransposeFlag(), M, N, K,
 
  143                     lhs.Scale()*rhs.Scale(), lhs.GetRawPtr(), LDA,
 
  144                     rhs.GetRawPtr(), LDB, 0.0,
 
  145                     result.GetRawPtr(), result.GetRows());
 
  152     void Multiply(NekMatrix<DataType, StandardMatrixTag>& result,
 
  153                      const NekMatrix<LhsDataType, LhsMatrixType>& 
lhs,
 
  154                      const NekMatrix<RhsDataType, RhsMatrixType>& rhs);
 
  156     template<
typename RhsInnerType, 
typename RhsMatrixType>
 
  159                           typename boost::enable_if
 
  168         ASSERTL0(result.GetType() == 
eFULL && rhs.GetType() == 
eFULL, 
"Only full matrices supported.");
 
  169         unsigned int M = result.GetRows();
 
  170         unsigned int N = rhs.GetColumns();
 
  171         unsigned int K = result.GetColumns();
 
  173         unsigned int LDA = M;
 
  174         if( result.GetTransposeFlag() == 
'T' )
 
  179         unsigned int LDB = K;
 
  180         if( rhs.GetTransposeFlag() == 
'T' )
 
  184         double scale = rhs.Scale();
 
  186         Blas::Dgemm(result.GetTransposeFlag(), rhs.GetTransposeFlag(), M, N, K,
 
  187             scale, result.GetRawPtr(), LDA, rhs.GetRawPtr(), LDB, 0.0,
 
  188             buf.data(), result.GetRows());
 
  189         result.SetSize(result.GetRows(), rhs.GetColumns());
 
  190         result.SwapTempAndDataBuffers();
 
  193     template<
typename DataType, 
typename RhsInnerType, 
typename RhsMatrixType>
 
  196                           typename boost::enable_if
 
  205         ASSERTL1(result.
GetColumns() == rhs.GetRows(), std::string(
"A left side matrix with column count ") + 
 
  206             boost::lexical_cast<std::string>(result.
GetColumns()) + 
 
  207             std::string(
" and a right side matrix with row count ") + 
 
  208             boost::lexical_cast<std::string>(rhs.GetRows()) + std::string(
" can't be multiplied."));
 
  211         for(
unsigned int i = 0; i < result.
GetRows(); ++i)
 
  213             for(
unsigned int j = 0; j < result.
GetColumns(); ++j)
 
  215                 DataType t = DataType(0);
 
  218                 for(
unsigned int k = 0; k < result.
GetColumns(); ++k)
 
  220                     t += result(i,k)*rhs(k,j);
 
  231     NekMatrix<typename boost::remove_const<typename NekMatrix<LhsDataType, LhsMatrixType>::NumberType>::type, StandardMatrixTag> 
 
  235         typedef typename boost::remove_const<typename NekMatrix<LhsDataType, LhsMatrixType>::NumberType>::type NumberType;
 
  247     template<
typename DataType, 
typename RhsDataType, 
typename RhsMatrixType>
 
  248     void AddEqual(NekMatrix<DataType, StandardMatrixTag>& result,
 
  249                      const NekMatrix<RhsDataType, RhsMatrixType>& rhs);
 
  252     template<
typename DataType, 
typename LhsDataType, 
typename LhsMatrixType, 
typename RhsDataType, 
typename RhsMatrixType>
 
  253     void Add(NekMatrix<DataType, StandardMatrixTag>& result,
 
  254                 const NekMatrix<LhsDataType, LhsMatrixType>& 
lhs,
 
  255                 const NekMatrix<RhsDataType, RhsMatrixType>& rhs);
 
  258     template<
typename LhsDataType, 
typename LhsMatrixType, 
typename RhsDataType, 
typename RhsMatrixType>
 
  259     NekMatrix<typename NekMatrix<LhsDataType, LhsMatrixType>::NumberType, StandardMatrixTag> 
 
  260     Add(
const NekMatrix<LhsDataType, LhsMatrixType>& 
lhs,
 
  261         const NekMatrix<RhsDataType, RhsMatrixType>& rhs);
 
  263     template<
typename DataType, 
typename LhsDataType, 
typename LhsMatrixType, 
typename RhsDataType, 
typename RhsMatrixType>
 
  264     void AddNegatedLhs(NekMatrix<DataType, StandardMatrixTag>& result,
 
  265                        const NekMatrix<LhsDataType, LhsMatrixType>& 
lhs,
 
  266                        const NekMatrix<RhsDataType, RhsMatrixType>& rhs);
 
  268     template<
typename DataType, 
typename RhsDataType, 
typename RhsMatrixType>
 
  270                      const NekMatrix<RhsDataType, RhsMatrixType>& rhs);
 
  276     template<
typename DataType, 
typename LhsDataType, 
typename LhsMatrixType, 
typename RhsDataType, 
typename RhsMatrixType>
 
  277     void Subtract(NekMatrix<DataType, StandardMatrixTag>& result,
 
  278                 const NekMatrix<LhsDataType, LhsMatrixType>& 
lhs,
 
  279                 const NekMatrix<RhsDataType, RhsMatrixType>& rhs);
 
  281     template<
typename DataType, 
typename LhsDataType, 
typename LhsMatrixType, 
typename RhsDataType, 
typename RhsMatrixType>
 
  283                 const NekMatrix<LhsDataType, LhsMatrixType>& 
lhs,
 
  284                 const NekMatrix<RhsDataType, RhsMatrixType>& rhs);
 
  286     template<
typename DataType, 
typename RhsDataType, 
typename RhsMatrixType>
 
  287     void SubtractEqual(NekMatrix<DataType, StandardMatrixTag>& result,
 
  288                           const NekMatrix<RhsDataType, RhsMatrixType>& rhs);
 
  290     template<
typename DataType, 
typename RhsDataType, 
typename RhsMatrixType>
 
  292                           const NekMatrix<RhsDataType, RhsMatrixType>& rhs);
 
  294     template<
typename LhsDataType, 
typename LhsMatrixType, 
typename RhsDataType, 
typename RhsMatrixType>
 
  295     NekMatrix<typename NekMatrix<LhsDataType, LhsMatrixType>::NumberType, StandardMatrixTag> 
 
  296     Subtract(
const NekMatrix<LhsDataType, LhsMatrixType>& 
lhs,
 
  297                 const NekMatrix<RhsDataType, RhsMatrixType>& rhs);
 
void SubtractNegatedLhs(NekMatrix< DataType, StandardMatrixTag > &result, const NekMatrix< LhsDataType, LhsMatrixType > &lhs, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)
 
#define ASSERTL0(condition, msg)
 
void MultiplyEqual(NekMatrix< LhsDataType, StandardMatrixTag > &lhs, typename boost::call_traits< LhsDataType >::const_reference rhs)
 
void NekMultiplyFullMatrixFullMatrix(NekMatrix< ResultType, StandardMatrixTag > &result, const NekMatrix< LhsDataType, LhsMatrixType > &lhs, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)
 
RhsMatrixType void AddEqual(NekMatrix< DataType, StandardMatrixTag > &result, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)
 
void Multiply(NekMatrix< ResultDataType, StandardMatrixTag > &result, const NekMatrix< LhsDataType, LhsMatrixType > &lhs, const ResultDataType &rhs)
 
unsigned int GetColumns() const 
 
void AddEqualNegatedLhs(NekMatrix< DataType, StandardMatrixTag > &result, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)
 
DNekMat void SubtractEqual(NekMatrix< DataType, StandardMatrixTag > &result, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)
 
void SubtractEqualNegatedLhs(NekMatrix< DataType, StandardMatrixTag > &result, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)
 
void AddNegatedLhs(NekMatrix< DataType, StandardMatrixTag > &result, const NekMatrix< LhsDataType, LhsMatrixType > &lhs, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)
 
unsigned int GetRows() const 
 
DNekMat void Add(NekMatrix< DataType, StandardMatrixTag > &result, const NekMatrix< LhsDataType, LhsMatrixType > &lhs, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)
 
RhsMatrixType void Subtract(NekMatrix< DataType, StandardMatrixTag > &result, const NekMatrix< LhsDataType, LhsMatrixType > &lhs, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)
 
#define ASSERTL1(condition, msg)
Assert Level 1 – Debugging which is used whether in FULLDEBUG or DEBUG compilation mode...