Nektar++
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
MatrixOperationsDeclarations.hpp
Go to the documentation of this file.
1 ///////////////////////////////////////////////////////////////////////////////
2 //
// File: MatrixOperationsDeclarations.hpp
4 //
5 // For more information, please see: http://www.nektar.info
6 //
7 // The MIT License
8 //
9 // Copyright (c) 2006 Division of Applied Mathematics, Brown University (USA),
10 // Department of Aeronautics, Imperial College London (UK), and Scientific
11 // Computing and Imaging Institute, University of Utah (USA).
12 //
13 // License for the specific language governing rights and limitations under
14 // Permission is hereby granted, free of charge, to any person obtaining a
15 // copy of this software and associated documentation files (the "Software"),
16 // to deal in the Software without restriction, including without limitation
17 // the rights to use, copy, modify, merge, publish, distribute, sublicense,
18 // and/or sell copies of the Software, and to permit persons to whom the
19 // Software is furnished to do so, subject to the following conditions:
20 //
21 // The above copyright notice and this permission notice shall be included
22 // in all copies or substantial portions of the Software.
23 //
24 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
25 // OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
26 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
27 // THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
28 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
29 // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
30 // DEALINGS IN THE SOFTWARE.
31 //
32 // Description: Defines the global functions needed for matrix operations.
33 //
34 ///////////////////////////////////////////////////////////////////////////////
35 
36 #ifndef NEKTAR_LIB_UTILITIES_LINEAR_ALGEBRA_MATRIX_OPERATIONS_DECLARATIONS_HPP
37 #define NEKTAR_LIB_UTILITIES_LINEAR_ALGEBRA_MATRIX_OPERATIONS_DECLARATIONS_HPP
38 
// Since this file declares the operations for every combination of matrix types,
// all matrix specializations have to be included first.
52 
53 #include <boost/utility/enable_if.hpp>
54 #include <boost/type_traits.hpp>
55 
56 #include <string>
57 
58 namespace Nektar
59 {
60  ////////////////////////////////////////////////////////////////////////////////////
61  // Matrix-Vector Multiplication
62  ////////////////////////////////////////////////////////////////////////////////////
63  template<typename DataType, typename LhsDataType, typename MatrixType>
64  NekVector<DataType>
65  Multiply(const NekMatrix<LhsDataType, MatrixType>& lhs,
66  const NekVector<DataType>& rhs);
67 
68  template<typename DataType, typename LhsDataType, typename MatrixType>
69  void Multiply(NekVector<DataType>& result,
70  const NekMatrix<LhsDataType, MatrixType>& lhs,
71  const NekVector<DataType>& rhs);
72 
73  template<typename DataType, typename LhsInnerMatrixType>
74  void Multiply(NekVector<DataType>& result,
75  const NekMatrix<LhsInnerMatrixType, BlockMatrixTag>& lhs,
76  const NekVector<DataType>& rhs);
77 
78  ////////////////////////////////////////////////////////////////////////////////////
79  // Matrix-Constant Multiplication
80  ////////////////////////////////////////////////////////////////////////////////////
81  template<typename ResultDataType, typename LhsDataType, typename LhsMatrixType>
82  void Multiply(NekMatrix<ResultDataType, StandardMatrixTag>& result,
83  const NekMatrix<LhsDataType, LhsMatrixType>& lhs,
84  const ResultDataType& rhs);
85 
86  template<typename DataType, typename LhsDataType, typename LhsMatrixType>
87  NekMatrix<typename NekMatrix<LhsDataType, LhsMatrixType>::NumberType, StandardMatrixTag>
88  Multiply(const NekMatrix<LhsDataType, LhsMatrixType>& lhs,
89  const DataType& rhs);
90 
91  template<typename RhsDataType, typename RhsMatrixType, typename ResultDataType>
92  void Multiply(NekMatrix<ResultDataType, StandardMatrixTag>& result,
93  const ResultDataType& lhs,
94  const NekMatrix<RhsDataType, RhsMatrixType>& rhs);
95 
96  template<typename DataType, typename RhsDataType, typename RhsMatrixType>
97  NekMatrix<typename NekMatrix<RhsDataType, RhsMatrixType>::NumberType, StandardMatrixTag>
98  Multiply(const DataType& lhs,
99  const NekMatrix<RhsDataType, RhsMatrixType>& rhs);
100 
101  template<typename LhsDataType>
102  void MultiplyEqual(NekMatrix<LhsDataType, StandardMatrixTag>& lhs,
103  typename boost::call_traits<LhsDataType>::const_reference rhs);
104 
105 
106  ///////////////////////////////////////////////////////////////////
// Matrix-Matrix Multiplication
108  //////////////////////////////////////////////////////////////////
109 
110  template<typename LhsDataType, typename RhsDataType,
111  typename LhsMatrixType, typename RhsMatrixType>
112  void NekMultiplyFullMatrixFullMatrix(NekMatrix<NekDouble, StandardMatrixTag>& result,
113  const NekMatrix<LhsDataType, LhsMatrixType>& lhs,
114  const NekMatrix<RhsDataType, RhsMatrixType>& rhs,
115  typename boost::enable_if
116  <
117  boost::mpl::and_
118  <
119  CanGetRawPtr<NekMatrix<LhsDataType, LhsMatrixType> >,
120  CanGetRawPtr<NekMatrix<RhsDataType, RhsMatrixType> >
121  >
122  >::type* p = 0)
123  {
124  ASSERTL1(lhs.GetType() == eFULL && rhs.GetType() == eFULL, "Only full matrices are supported.");
125 
126  unsigned int M = lhs.GetRows();
127  unsigned int N = rhs.GetColumns();
128  unsigned int K = lhs.GetColumns();
129 
130  unsigned int LDA = M;
131  if( lhs.GetTransposeFlag() == 'T' )
132  {
133  LDA = K;
134  }
135 
136  unsigned int LDB = K;
137  if( rhs.GetTransposeFlag() == 'T' )
138  {
139  LDB = N;
140  }
141 
142  Blas::Dgemm(lhs.GetTransposeFlag(), rhs.GetTransposeFlag(), M, N, K,
143  lhs.Scale()*rhs.Scale(), lhs.GetRawPtr(), LDA,
144  rhs.GetRawPtr(), LDB, 0.0,
145  result.GetRawPtr(), result.GetRows());
146  }
147 
148 
149 
150  template<typename LhsDataType, typename RhsDataType, typename DataType,
151  typename LhsMatrixType, typename RhsMatrixType>
152  void Multiply(NekMatrix<DataType, StandardMatrixTag>& result,
153  const NekMatrix<LhsDataType, LhsMatrixType>& lhs,
154  const NekMatrix<RhsDataType, RhsMatrixType>& rhs);
155 
156  template<typename RhsInnerType, typename RhsMatrixType>
157  void MultiplyEqual(NekMatrix<double, StandardMatrixTag>& result,
158  const NekMatrix<RhsInnerType, RhsMatrixType>& rhs,
159  typename boost::enable_if
160  <
161  boost::mpl::and_
162  <
163  boost::is_same<typename RawType<typename NekMatrix<RhsInnerType, RhsMatrixType>::NumberType>::type, double>,
164  CanGetRawPtr<NekMatrix<RhsInnerType, RhsMatrixType> >
165  >
166  >::type* t = 0)
167  {
168  ASSERTL0(result.GetType() == eFULL && rhs.GetType() == eFULL, "Only full matrices supported.");
169  unsigned int M = result.GetRows();
170  unsigned int N = rhs.GetColumns();
171  unsigned int K = result.GetColumns();
172 
173  unsigned int LDA = M;
174  if( result.GetTransposeFlag() == 'T' )
175  {
176  LDA = K;
177  }
178 
179  unsigned int LDB = K;
180  if( rhs.GetTransposeFlag() == 'T' )
181  {
182  LDB = N;
183  }
184  double scale = rhs.Scale();
185  Array<OneD, double>& buf = result.GetTempSpace();
186  Blas::Dgemm(result.GetTransposeFlag(), rhs.GetTransposeFlag(), M, N, K,
187  scale, result.GetRawPtr(), LDA, rhs.GetRawPtr(), LDB, 0.0,
188  buf.data(), result.GetRows());
189  result.SetSize(result.GetRows(), rhs.GetColumns());
190  result.SwapTempAndDataBuffers();
191  }
192 
193  template<typename DataType, typename RhsInnerType, typename RhsMatrixType>
195  const NekMatrix<RhsInnerType, RhsMatrixType>& rhs,
196  typename boost::enable_if
197  <
198  boost::mpl::or_
199  <
200  boost::mpl::not_<boost::is_same<typename RawType<typename NekMatrix<RhsInnerType, RhsMatrixType>::NumberType>::type, double> >,
201  boost::mpl::not_<CanGetRawPtr<NekMatrix<RhsInnerType, RhsMatrixType> > >
202  >
203  >::type* t = 0)
204  {
205  ASSERTL1(result.GetColumns() == rhs.GetRows(), std::string("A left side matrix with column count ") +
206  boost::lexical_cast<std::string>(result.GetColumns()) +
207  std::string(" and a right side matrix with row count ") +
208  boost::lexical_cast<std::string>(rhs.GetRows()) + std::string(" can't be multiplied."));
210 
211  for(unsigned int i = 0; i < result.GetRows(); ++i)
212  {
213  for(unsigned int j = 0; j < result.GetColumns(); ++j)
214  {
215  DataType t = DataType(0);
216 
217  // Set the result(i,j) element.
218  for(unsigned int k = 0; k < result.GetColumns(); ++k)
219  {
220  t += result(i,k)*rhs(k,j);
221  }
222  temp(i,j) = t;
223  }
224  }
225 
226  result = temp;
227  }
228 
229  template<typename LhsDataType, typename RhsDataType,
230  typename LhsMatrixType, typename RhsMatrixType>
231  NekMatrix<typename boost::remove_const<typename NekMatrix<LhsDataType, LhsMatrixType>::NumberType>::type, StandardMatrixTag>
232  Multiply(const NekMatrix<LhsDataType, LhsMatrixType>& lhs,
233  const NekMatrix<RhsDataType, RhsMatrixType>& rhs)
234  {
235  typedef typename boost::remove_const<typename NekMatrix<LhsDataType, LhsMatrixType>::NumberType>::type NumberType;
236  NekMatrix<NumberType, StandardMatrixTag> result(lhs.GetRows(), rhs.GetColumns());
237  Multiply(result, lhs, rhs);
238  return result;
239  }
240 
241 
242 
243  ///////////////////////////////////////////////////////////////////
244  // Addition
245  ///////////////////////////////////////////////////////////////////
246 
247  template<typename DataType, typename RhsDataType, typename RhsMatrixType>
248  void AddEqual(NekMatrix<DataType, StandardMatrixTag>& result,
249  const NekMatrix<RhsDataType, RhsMatrixType>& rhs);
250 
251 
252  template<typename DataType, typename LhsDataType, typename LhsMatrixType, typename RhsDataType, typename RhsMatrixType>
253  void Add(NekMatrix<DataType, StandardMatrixTag>& result,
254  const NekMatrix<LhsDataType, LhsMatrixType>& lhs,
255  const NekMatrix<RhsDataType, RhsMatrixType>& rhs);
256 
257 
258  template<typename LhsDataType, typename LhsMatrixType, typename RhsDataType, typename RhsMatrixType>
259  NekMatrix<typename NekMatrix<LhsDataType, LhsMatrixType>::NumberType, StandardMatrixTag>
260  Add(const NekMatrix<LhsDataType, LhsMatrixType>& lhs,
261  const NekMatrix<RhsDataType, RhsMatrixType>& rhs);
262 
263  template<typename DataType, typename LhsDataType, typename LhsMatrixType, typename RhsDataType, typename RhsMatrixType>
264  void AddNegatedLhs(NekMatrix<DataType, StandardMatrixTag>& result,
265  const NekMatrix<LhsDataType, LhsMatrixType>& lhs,
266  const NekMatrix<RhsDataType, RhsMatrixType>& rhs);
267 
268  template<typename DataType, typename RhsDataType, typename RhsMatrixType>
269  void AddEqualNegatedLhs(NekMatrix<DataType, StandardMatrixTag>& result,
270  const NekMatrix<RhsDataType, RhsMatrixType>& rhs);
271 
272 
273  ////////////////////////////////////////////////////////////////////////////////////
274  // Subtraction
275  ////////////////////////////////////////////////////////////////////////////////////
276  template<typename DataType, typename LhsDataType, typename LhsMatrixType, typename RhsDataType, typename RhsMatrixType>
277  void Subtract(NekMatrix<DataType, StandardMatrixTag>& result,
278  const NekMatrix<LhsDataType, LhsMatrixType>& lhs,
279  const NekMatrix<RhsDataType, RhsMatrixType>& rhs);
280 
281  template<typename DataType, typename LhsDataType, typename LhsMatrixType, typename RhsDataType, typename RhsMatrixType>
282  void SubtractNegatedLhs(NekMatrix<DataType, StandardMatrixTag>& result,
283  const NekMatrix<LhsDataType, LhsMatrixType>& lhs,
284  const NekMatrix<RhsDataType, RhsMatrixType>& rhs);
285 
286  template<typename DataType, typename RhsDataType, typename RhsMatrixType>
287  void SubtractEqual(NekMatrix<DataType, StandardMatrixTag>& result,
288  const NekMatrix<RhsDataType, RhsMatrixType>& rhs);
289 
290  template<typename DataType, typename RhsDataType, typename RhsMatrixType>
291  void SubtractEqualNegatedLhs(NekMatrix<DataType, StandardMatrixTag>& result,
292  const NekMatrix<RhsDataType, RhsMatrixType>& rhs);
293 
294  template<typename LhsDataType, typename LhsMatrixType, typename RhsDataType, typename RhsMatrixType>
295  NekMatrix<typename NekMatrix<LhsDataType, LhsMatrixType>::NumberType, StandardMatrixTag>
296  Subtract(const NekMatrix<LhsDataType, LhsMatrixType>& lhs,
297  const NekMatrix<RhsDataType, RhsMatrixType>& rhs);
298 }
299 
300 
301 #endif
302