Nektar++
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Pages
MatrixOperationsDeclarations.hpp
Go to the documentation of this file.
1 ///////////////////////////////////////////////////////////////////////////////
2 //
3 // File: MatrixOperationsDeclarations.hpp
4 //
5 // For more information, please see: http://www.nektar.info
6 //
7 // The MIT License
8 //
9 // Copyright (c) 2006 Division of Applied Mathematics, Brown University (USA),
10 // Department of Aeronautics, Imperial College London (UK), and Scientific
11 // Computing and Imaging Institute, University of Utah (USA).
12 //
13 // License for the specific language governing rights and limitations under
14 // Permission is hereby granted, free of charge, to any person obtaining a
15 // copy of this software and associated documentation files (the "Software"),
16 // to deal in the Software without restriction, including without limitation
17 // the rights to use, copy, modify, merge, publish, distribute, sublicense,
18 // and/or sell copies of the Software, and to permit persons to whom the
19 // Software is furnished to do so, subject to the following conditions:
20 //
21 // The above copyright notice and this permission notice shall be included
22 // in all copies or substantial portions of the Software.
23 //
24 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
25 // OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
26 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
27 // THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
28 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
29 // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
30 // DEALINGS IN THE SOFTWARE.
31 //
32 // Description: Defines the global functions needed for matrix operations.
33 //
34 ///////////////////////////////////////////////////////////////////////////////
35 
36 #ifndef NEKTAR_LIB_UTILITIES_LINEAR_ALGEBRA_MATRIX_OPERATIONS_DECLARATIONS_HPP
37 #define NEKTAR_LIB_UTILITIES_LINEAR_ALGEBRA_MATRIX_OPERATIONS_DECLARATIONS_HPP
38 
39 // Since this file defines all of the operations for all combinations of matrix types,
40 // we have to include all matrix specializations first.
52 
53 #include <boost/utility/enable_if.hpp>
54 #include <boost/type_traits.hpp>
55 
56 #include <string>
57 
58 namespace Nektar
59 {
60  ////////////////////////////////////////////////////////////////////////////////////
61  // Matrix-Vector Multiplication
62  ////////////////////////////////////////////////////////////////////////////////////
// Matrix-vector Multiply declarations. Only the signatures are visible here;
// the definitions are not shown in this listing.
//
// Value-returning form: takes a NekMatrix `lhs` and a NekVector `rhs` by
// const reference and returns a new NekVector<DataType> by value.
// NOTE(review): presumably computes lhs * rhs -- confirm against the definition.
63  template<typename DataType, typename LhsDataType, typename MatrixType>
64  NekVector<DataType>
65  Multiply(const NekMatrix<LhsDataType, MatrixType>& lhs,
66  const NekVector<DataType>& rhs);
67 
// Out-parameter form: writes into `result` (non-const reference) and
// returns void. NOTE(review): whether `result` may alias `rhs` is not
// visible here -- confirm before passing the same vector for both.
68  template<typename DataType, typename LhsDataType, typename MatrixType>
69  void Multiply(NekVector<DataType>& result,
70  const NekMatrix<LhsDataType, MatrixType>& lhs,
71  const NekVector<DataType>& rhs);
72 
// Out-parameter overload for block matrices
// (NekMatrix<LhsInnerMatrixType, BlockMatrixTag>). Because this signature is
// more specialised than the generic one above, template partial ordering
// selects it for block-matrix arguments.
73  template<typename DataType, typename LhsInnerMatrixType>
74  void Multiply(NekVector<DataType>& result,
75  const NekMatrix<LhsInnerMatrixType, BlockMatrixTag>& lhs,
76  const NekVector<DataType>& rhs);
77 
// Exported, non-template routine for a block matrix whose blocks are scaled
// standard NekDouble matrices. NOTE(review): the name suggests only the
// diagonal blocks are used -- confirm. The vectors are spelled `double`
// while the matrix element type is spelled `NekDouble`; this assumes
// NekDouble is a typedef for double.
78  LIB_UTILITIES_EXPORT void DiagonalBlockFullScalMatrixMultiply(NekVector<double>& result,
79  const NekMatrix<NekMatrix<NekMatrix<NekDouble, StandardMatrixTag>, ScaledMatrixTag>, BlockMatrixTag>& lhs,
80  const NekVector<double>& rhs);
81 
82  ////////////////////////////////////////////////////////////////////////////////////
83  // Matrix-Constant Multiplication
84  ////////////////////////////////////////////////////////////////////////////////////
// Matrix * scalar, out-parameter form. `ResultDataType` is deduced from both
// `result` and `rhs`, so the scalar must have exactly the result's element
// type for this overload to be viable (e.g. an `int` literal does not match
// a double-valued result matrix).
85  template<typename ResultDataType, typename LhsDataType, typename LhsMatrixType>
86  void Multiply(NekMatrix<ResultDataType, StandardMatrixTag>& result,
87  const NekMatrix<LhsDataType, LhsMatrixType>& lhs,
88  const ResultDataType& rhs);
89 
// Matrix * scalar, value-returning form. The scalar type `DataType` is
// deduced independently of the matrix. The returned standard matrix takes
// the lhs matrix's NumberType as its element type.
// NOTE(review): NumberType is used as-is here, whereas the value-returning
// matrix-matrix Multiply in this file wraps it in boost::remove_const --
// confirm NumberType can never be const-qualified for these matrix types.
90  template<typename DataType, typename LhsDataType, typename LhsMatrixType>
91  NekMatrix<typename NekMatrix<LhsDataType, LhsMatrixType>::NumberType, StandardMatrixTag>
92  Multiply(const NekMatrix<LhsDataType, LhsMatrixType>& lhs,
93  const DataType& rhs);
94 
// Scalar * matrix, out-parameter form (scalar on the left). As with the
// matrix * scalar form, `ResultDataType` is deduced from both `result` and
// `lhs`, so the two types must agree exactly.
95  template<typename RhsDataType, typename RhsMatrixType, typename ResultDataType>
96  void Multiply(NekMatrix<ResultDataType, StandardMatrixTag>& result,
97  const ResultDataType& lhs,
98  const NekMatrix<RhsDataType, RhsMatrixType>& rhs);
99 
// Scalar * matrix, value-returning form. The returned standard matrix takes
// the rhs matrix's NumberType as its element type.
100  template<typename DataType, typename RhsDataType, typename RhsMatrixType>
101  NekMatrix<typename NekMatrix<RhsDataType, RhsMatrixType>::NumberType, StandardMatrixTag>
102  Multiply(const DataType& lhs,
103  const NekMatrix<RhsDataType, RhsMatrixType>& rhs);
104 
// In-place form: modifies `lhs` (taken by non-const reference).
// NOTE(review): presumably lhs *= rhs -- confirm against the definition.
// `rhs` is declared via boost::call_traits<LhsDataType>::const_reference,
// which is a non-deduced context. LhsDataType is therefore deduced from
// `lhs` alone, and the scalar argument is implicitly converted to it.
105  template<typename LhsDataType>
106  void MultiplyEqual(NekMatrix<LhsDataType, StandardMatrixTag>& lhs,
107  typename boost::call_traits<LhsDataType>::const_reference rhs);
108 
109 
110  ///////////////////////////////////////////////////////////////////
111  // Matrix-Matrix Multiplication
112  //////////////////////////////////////////////////////////////////
113 
114  template<typename LhsDataType, typename RhsDataType,
115  typename LhsMatrixType, typename RhsMatrixType>
119  typename boost::enable_if
120  <
121  boost::mpl::and_
122  <
125  >
126  >::type* p = 0)
127  {
128  ASSERTL1(lhs.GetType() == eFULL && rhs.GetType() == eFULL, "Only full matrices are supported.");
129 
130  unsigned int M = lhs.GetRows();
131  unsigned int N = rhs.GetColumns();
132  unsigned int K = lhs.GetColumns();
133 
134  unsigned int LDA = M;
135  if( lhs.GetTransposeFlag() == 'T' )
136  {
137  LDA = K;
138  }
139 
140  unsigned int LDB = K;
141  if( rhs.GetTransposeFlag() == 'T' )
142  {
143  LDB = N;
144  }
145 
146  Blas::Dgemm(lhs.GetTransposeFlag(), rhs.GetTransposeFlag(), M, N, K,
147  lhs.Scale()*rhs.Scale(), lhs.GetRawPtr(), LDA,
148  rhs.GetRawPtr(), LDB, 0.0,
149  result.GetRawPtr(), result.GetRows());
150  }
151 
152 
153 
// Matrix * matrix, out-parameter form for arbitrary lhs/rhs matrix types.
// The result's element type `DataType` is a template parameter separate
// from the operands' element types.
// NOTE(review): the value-returning matrix-matrix Multiply in this file
// constructs a result of size lhs.GetRows() x rhs.GetColumns() and then calls
// Multiply(result, lhs, rhs). Presumably that call resolves to this
// declaration -- confirm.
154  template<typename LhsDataType, typename RhsDataType, typename DataType,
155  typename LhsMatrixType, typename RhsMatrixType>
156  void Multiply(NekMatrix<DataType, StandardMatrixTag>& result,
157  const NekMatrix<LhsDataType, LhsMatrixType>& lhs,
158  const NekMatrix<RhsDataType, RhsMatrixType>& rhs);
159 
160  template<typename RhsInnerType, typename RhsMatrixType>
163  typename boost::enable_if
164  <
165  boost::mpl::and_
166  <
167  boost::is_same<typename RawType<typename NekMatrix<RhsInnerType, RhsMatrixType>::NumberType>::type, double>,
169  >
170  >::type* t = 0)
171  {
172  ASSERTL0(result.GetType() == eFULL && rhs.GetType() == eFULL, "Only full matrices supported.");
173  unsigned int M = result.GetRows();
174  unsigned int N = rhs.GetColumns();
175  unsigned int K = result.GetColumns();
176 
177  unsigned int LDA = M;
178  if( result.GetTransposeFlag() == 'T' )
179  {
180  LDA = K;
181  }
182 
183  unsigned int LDB = K;
184  if( rhs.GetTransposeFlag() == 'T' )
185  {
186  LDB = N;
187  }
188  double scale = rhs.Scale();
189  Array<OneD, double>& buf = result.GetTempSpace();
190  Blas::Dgemm(result.GetTransposeFlag(), rhs.GetTransposeFlag(), M, N, K,
191  scale, result.GetRawPtr(), LDA, rhs.GetRawPtr(), LDB, 0.0,
192  buf.data(), result.GetRows());
193  result.SetSize(result.GetRows(), rhs.GetColumns());
194  result.SwapTempAndDataBuffers();
195  }
196 
197  template<typename DataType, typename RhsInnerType, typename RhsMatrixType>
200  typename boost::enable_if
201  <
202  boost::mpl::or_
203  <
204  boost::mpl::not_<boost::is_same<typename RawType<typename NekMatrix<RhsInnerType, RhsMatrixType>::NumberType>::type, double> >,
206  >
207  >::type* t = 0)
208  {
209  ASSERTL1(result.GetColumns() == rhs.GetRows(), std::string("A left side matrix with column count ") +
210  boost::lexical_cast<std::string>(result.GetColumns()) +
211  std::string(" and a right side matrix with row count ") +
212  boost::lexical_cast<std::string>(rhs.GetRows()) + std::string(" can't be multiplied."));
214 
215  for(unsigned int i = 0; i < result.GetRows(); ++i)
216  {
217  for(unsigned int j = 0; j < result.GetColumns(); ++j)
218  {
219  DataType t = DataType(0);
220 
221  // Set the result(i,j) element.
222  for(unsigned int k = 0; k < result.GetColumns(); ++k)
223  {
224  t += result(i,k)*rhs(k,j);
225  }
226  temp(i,j) = t;
227  }
228  }
229 
230  result = temp;
231  }
232 
233  template<typename LhsDataType, typename RhsDataType,
234  typename LhsMatrixType, typename RhsMatrixType>
235  NekMatrix<typename boost::remove_const<typename NekMatrix<LhsDataType, LhsMatrixType>::NumberType>::type, StandardMatrixTag>
238  {
239  typedef typename boost::remove_const<typename NekMatrix<LhsDataType, LhsMatrixType>::NumberType>::type NumberType;
240  NekMatrix<NumberType, StandardMatrixTag> result(lhs.GetRows(), rhs.GetColumns());
241  Multiply(result, lhs, rhs);
242  return result;
243  }
244 
245 
246 
247  ///////////////////////////////////////////////////////////////////
248  // Addition
249  ///////////////////////////////////////////////////////////////////
250 
// Addition declarations. Only the signatures are visible here.
//
// In-place form: modifies `result` (non-const reference) using `rhs`.
// NOTE(review): presumably result += rhs -- confirm against the definition.
251  template<typename DataType, typename RhsDataType, typename RhsMatrixType>
252  void AddEqual(NekMatrix<DataType, StandardMatrixTag>& result,
253  const NekMatrix<RhsDataType, RhsMatrixType>& rhs);
254 
255 
// Out-parameter form: writes into `result` from `lhs` and `rhs`. Each of the
// three matrices has its own element-type template parameter.
// NOTE(review): presumably result = lhs + rhs -- confirm.
256  template<typename DataType, typename LhsDataType, typename LhsMatrixType, typename RhsDataType, typename RhsMatrixType>
257  void Add(NekMatrix<DataType, StandardMatrixTag>& result,
258  const NekMatrix<LhsDataType, LhsMatrixType>& lhs,
259  const NekMatrix<RhsDataType, RhsMatrixType>& rhs);
260 
261 
// Value-returning form. The returned standard matrix takes lhs's NumberType
// as its element type.
// NOTE(review): unlike the value-returning matrix-matrix Multiply in this
// file, NumberType is not passed through boost::remove_const here.
262  template<typename LhsDataType, typename LhsMatrixType, typename RhsDataType, typename RhsMatrixType>
263  NekMatrix<typename NekMatrix<LhsDataType, LhsMatrixType>::NumberType, StandardMatrixTag>
264  Add(const NekMatrix<LhsDataType, LhsMatrixType>& lhs,
265  const NekMatrix<RhsDataType, RhsMatrixType>& rhs);
266 
// Out-parameter variant. NOTE(review): the name suggests
// result = -lhs + rhs -- confirm against the definition.
267  template<typename DataType, typename LhsDataType, typename LhsMatrixType, typename RhsDataType, typename RhsMatrixType>
268  void AddNegatedLhs(NekMatrix<DataType, StandardMatrixTag>& result,
269  const NekMatrix<LhsDataType, LhsMatrixType>& lhs,
270  const NekMatrix<RhsDataType, RhsMatrixType>& rhs);
271 
// In-place variant: modifies `result`. NOTE(review): the name suggests
// result = -result + rhs -- confirm against the definition.
272  template<typename DataType, typename RhsDataType, typename RhsMatrixType>
273  void AddEqualNegatedLhs(NekMatrix<DataType, StandardMatrixTag>& result,
274  const NekMatrix<RhsDataType, RhsMatrixType>& rhs);
275 
276 
277  ////////////////////////////////////////////////////////////////////////////////////
278  // Subtraction
279  ////////////////////////////////////////////////////////////////////////////////////
// Subtraction declarations. Only the signatures are visible here.
//
// Out-parameter form: writes into `result` from `lhs` and `rhs`.
// NOTE(review): presumably result = lhs - rhs -- confirm against the definition.
280  template<typename DataType, typename LhsDataType, typename LhsMatrixType, typename RhsDataType, typename RhsMatrixType>
281  void Subtract(NekMatrix<DataType, StandardMatrixTag>& result,
282  const NekMatrix<LhsDataType, LhsMatrixType>& lhs,
283  const NekMatrix<RhsDataType, RhsMatrixType>& rhs);
284 
// Out-parameter variant. NOTE(review): the name suggests
// result = -lhs - rhs -- confirm against the definition.
285  template<typename DataType, typename LhsDataType, typename LhsMatrixType, typename RhsDataType, typename RhsMatrixType>
286  void SubtractNegatedLhs(NekMatrix<DataType, StandardMatrixTag>& result,
287  const NekMatrix<LhsDataType, LhsMatrixType>& lhs,
288  const NekMatrix<RhsDataType, RhsMatrixType>& rhs);
289 
// In-place form: modifies `result` (non-const reference) using `rhs`.
// NOTE(review): presumably result -= rhs -- confirm.
290  template<typename DataType, typename RhsDataType, typename RhsMatrixType>
291  void SubtractEqual(NekMatrix<DataType, StandardMatrixTag>& result,
292  const NekMatrix<RhsDataType, RhsMatrixType>& rhs);
293 
// In-place variant: modifies `result`. NOTE(review): the name suggests
// result = -result - rhs -- confirm against the definition.
294  template<typename DataType, typename RhsDataType, typename RhsMatrixType>
295  void SubtractEqualNegatedLhs(NekMatrix<DataType, StandardMatrixTag>& result,
296  const NekMatrix<RhsDataType, RhsMatrixType>& rhs);
297 
// Value-returning form. The returned standard matrix takes lhs's NumberType
// as its element type.
// NOTE(review): unlike the value-returning matrix-matrix Multiply in this
// file, NumberType is not passed through boost::remove_const here.
298  template<typename LhsDataType, typename LhsMatrixType, typename RhsDataType, typename RhsMatrixType>
299  NekMatrix<typename NekMatrix<LhsDataType, LhsMatrixType>::NumberType, StandardMatrixTag>
300  Subtract(const NekMatrix<LhsDataType, LhsMatrixType>& lhs,
301  const NekMatrix<RhsDataType, RhsMatrixType>& rhs);
302 }
303 
304 
305 #endif
306 
void SubtractNegatedLhs(NekMatrix< DataType, StandardMatrixTag > &result, const NekMatrix< LhsDataType, LhsMatrixType > &lhs, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)
#define ASSERTL0(condition, msg)
Definition: ErrorUtil.hpp:198
void DiagonalBlockFullScalMatrixMultiply(NekVector< double > &result, const NekMatrix< NekMatrix< NekMatrix< NekDouble, StandardMatrixTag >, ScaledMatrixTag >, BlockMatrixTag > &lhs, const NekVector< double > &rhs)
void MultiplyEqual(NekMatrix< LhsDataType, StandardMatrixTag > &lhs, typename boost::call_traits< LhsDataType >::const_reference rhs)
void NekMultiplyFullMatrixFullMatrix(NekMatrix< ResultType, StandardMatrixTag > &result, const NekMatrix< LhsDataType, LhsMatrixType > &lhs, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)
StandardMatrixTag & lhs
RhsMatrixType void AddEqual(NekMatrix< DataType, StandardMatrixTag > &result, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)
void Multiply(NekMatrix< ResultDataType, StandardMatrixTag > &result, const NekMatrix< LhsDataType, LhsMatrixType > &lhs, const ResultDataType &rhs)
unsigned int GetColumns() const
Definition: MatrixBase.cpp:78
void AddEqualNegatedLhs(NekMatrix< DataType, StandardMatrixTag > &result, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)
DNekMat void SubtractEqual(NekMatrix< DataType, StandardMatrixTag > &result, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)
void SubtractEqualNegatedLhs(NekMatrix< DataType, StandardMatrixTag > &result, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)
#define LIB_UTILITIES_EXPORT
void AddNegatedLhs(NekMatrix< DataType, StandardMatrixTag > &result, const NekMatrix< LhsDataType, LhsMatrixType > &lhs, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)
unsigned int GetRows() const
Definition: MatrixBase.cpp:59
DNekMat void Add(NekMatrix< DataType, StandardMatrixTag > &result, const NekMatrix< LhsDataType, LhsMatrixType > &lhs, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)
RhsMatrixType void Subtract(NekMatrix< DataType, StandardMatrixTag > &result, const NekMatrix< LhsDataType, LhsMatrixType > &lhs, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)
#define ASSERTL1(condition, msg)
Assert Level 1 – Debugging which is used whether in FULLDEBUG or DEBUG compilation mode...
Definition: ErrorUtil.hpp:228