Nektar++
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
DgemvOverride.hpp
Go to the documentation of this file.
1 ///////////////////////////////////////////////////////////////////////////////
2 //
3 // For more information, please see: http://www.nektar.info
4 //
5 // The MIT License
6 //
7 // Copyright (c) 2006 Division of Applied Mathematics, Brown University (USA),
8 // Department of Aeronautics, Imperial College London (UK), and Scientific
9 // Computing and Imaging Institute, University of Utah (USA).
10 //
11 // License for the specific language governing rights and limitations under
12 // Permission is hereby granted, free of charge, to any person obtaining a
13 // copy of this software and associated documentation files (the "Software"),
14 // to deal in the Software without restriction, including without limitation
15 // the rights to use, copy, modify, merge, publish, distribute, sublicense,
16 // and/or sell copies of the Software, and to permit persons to whom the
17 // Software is furnished to do so, subject to the following conditions:
18 //
19 // The above copyright notice and this permission notice shall be included
20 // in all copies or substantial portions of the Software.
21 //
22 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
23 // OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
24 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
25 // THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
26 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
27 // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
28 // DEALINGS IN THE SOFTWARE.
29 //
30 // Description: Defines the global functions needed for matrix operations.
31 //
32 ///////////////////////////////////////////////////////////////////////////////
33 
34 #ifndef NEKTAR_LIB_UTILITIES_LINEAR_ALGEBRA_DGEMV_OVERRIDE_HPP
35 #define NEKTAR_LIB_UTILITIES_LINEAR_ALGEBRA_DGEMV_OVERRIDE_HPP
36 
37 #ifdef NEKTAR_USE_EXPRESSION_TEMPLATES
38 
39 #include <ExpressionTemplates/ExpressionTemplates.hpp>
40 #include <boost/utility/enable_if.hpp>
41 #include <boost/type_traits.hpp>
45 
46 namespace Nektar
47 {
48  template<typename ADataType, typename AMatrixType>
49  void Dgemv(NekVector<double>& result,
50  double alpha, const NekMatrix<ADataType, AMatrixType>& A, const NekVector<double>& x)
51  {
52  if( A.GetType() != eFULL )
53  {
54  Multiply(result, A, x);
55  MultiplyEqual(result, alpha);
56  return;
57  }
58 
59  unsigned int M = A.GetRows();
60  unsigned int N = A.GetColumns();
61 
62  char t = A.GetTransposeFlag();
63  if( t == 'T' )
64  {
65  std::swap(M,N);
66  }
67 
68  int lda = M;
69 
70  Blas::Dgemv(t, M, N, alpha*A.Scale(), A.GetRawPtr(), lda, x.GetRawPtr(), 1, 0.0, result.GetRawPtr(), 1);
71  }
72 
73  template<typename ADataType, typename AMatrixType>
74  void Dgemv(NekVector<double>& result,
75  double alpha, const NekMatrix<ADataType, AMatrixType>& A, const NekVector<double>& x,
76  double beta, const NekVector<double>& y)
77  {
78  if( A.GetType() != eFULL)
79  {
80  Multiply(result, A, x);
81  MultiplyEqual(result, alpha);
82  NekVector<double> temp = beta*y;
83  AddEqual(result, temp);
84  return;
85  }
86 
87  result = y;
88  unsigned int M = A.GetRows();
89  unsigned int N = A.GetColumns();
90 
91  char t = A.GetTransposeFlag();
92  if( t == 'T' )
93  {
94  std::swap(M,N);
95  }
96 
97  int lda = M;
98  Blas::Dgemv(t, M, N, alpha*A.Scale(), A.GetRawPtr(), lda, x.GetRawPtr(), 1, beta, result.GetRawPtr(), 1);
99  }
100 }
101 
102 namespace expt
103 {
104  namespace impl
105  {
106  // The DgemmNodeEvaluator creates an evaluator for the requested node, or the
107  // child of a negate node. The negation operation will occur as part of the dgemm
108  // call, so we do not need to do it up front.
109  template<typename NodeType, typename IndicesType, unsigned int index>
110  struct DgemvNodeEvaluator
111  {
112  typedef EvaluateNodeWithTemporaryIfNeeded<NodeType, IndicesType, index> Evaluator;
113  typedef typename Evaluator::ResultType Type;
114  static double GetScale() { return 1.0; }
115  };
116 
117  template<typename LhsType, typename IndicesType, unsigned int index>
118  struct DgemvNodeEvaluator<Node<LhsType, NegateOp>, IndicesType, index>
119  {
120  typedef EvaluateNodeWithTemporaryIfNeeded<LhsType, IndicesType, index> Evaluator;
121  typedef typename Evaluator::ResultType Type;
122  static double GetScale() { return -1.0; }
123  };
124  }
125 
126  namespace dgemv_impl
127  {
128  // To handle alpha*A*B - beta*C, this class provides the scale depending on the operator.
129  template<typename OpType>
130  struct BetaScale
131  {
132  static double GetScale() { return 1.0; }
133  };
134 
135  template<>
136  struct BetaScale<SubtractOp>
137  {
138  static double GetScale() { return -1.0; }
139  };
140  }
141 
142  namespace impl
143  {
144  template<typename NodeType, typename IndicesType, unsigned int index, typename enabled=void>
145  struct AlphaAXParameterAccessImpl : public boost::false_type {};
146 
147  template<typename A1, typename A2, typename A3, typename IndicesType, unsigned int index>
148  struct AlphaAXParameterAccessImpl< Node<Node<A1, MultiplyOp, A2>, MultiplyOp, A3>, IndicesType, index,
149  typename boost::enable_if
150  <
151  Test3ArgumentAssociativeNode<Node<Node<A1, MultiplyOp, A2>, MultiplyOp, A3>, IsDouble, MultiplyOp,
152  Nektar::CanGetRawPtr, MultiplyOp, Nektar::IsVector>
153 // boost::mpl::and_
154 // <
155 // boost::is_same<typename A1::ResultType, double>,
156 // CanGetRawPtr<typename A2::ResultType>,
157 // boost::is_same<typename A3::ResultType, NekVector<double> >
158 // >
159  >::type> : public boost::true_type
160  {
161  typedef DgemvNodeEvaluator<A1, IndicesType, index> AlphaWrappedEvaluator;
162  typedef typename AlphaWrappedEvaluator::Evaluator AlphaEvaluator;
163 
164  static const unsigned int A2Index = index + A1::TotalCount;
165  typedef DgemvNodeEvaluator<A2, IndicesType, A2Index> AWrappedEvaluator;
166  typedef typename AWrappedEvaluator::Evaluator AEvaluator;
167 
168  static const unsigned int A3Index = A2Index + A2::TotalCount;
169  typedef DgemvNodeEvaluator<A3, IndicesType, A3Index> XWrappedEvaluator;
170  typedef typename XWrappedEvaluator::Evaluator BEvaluator;
171 
172  template<typename ArgumentVectorType>
173  static double GetAlpha(const ArgumentVectorType& args)
174  {
175  return AlphaEvaluator::Evaluate(args) * AWrappedEvaluator::GetScale() * XWrappedEvaluator::GetScale();
176  }
177  };
178 
179  template<typename A1, typename A2, typename A3, typename IndicesType, unsigned int index>
180  struct AlphaAXParameterAccessImpl< Node<Node<A1, MultiplyOp, A2>, MultiplyOp, A3>, IndicesType, index,
181  typename boost::enable_if
182  <
183  Test3ArgumentAssociativeNode<Node<Node<A1, MultiplyOp, A2>, MultiplyOp, A3>, Nektar::CanGetRawPtr, MultiplyOp,
184  IsDouble, MultiplyOp, Nektar::IsVector>
185 // boost::mpl::and_
186 // <
187 // CanGetRawPtr<typename A1::ResultType>,
188 // boost::is_same<typename A2::ResultType, double>,
189 // boost::is_same<typename A3::ResultType, NekVector<double> >
190 // >
191  >::type> : public boost::true_type
192  {
193  typedef DgemvNodeEvaluator<A1, IndicesType, index> AWrappedEvaluator;
194  typedef typename AWrappedEvaluator::Evaluator AEvaluator;
195 
196  static const unsigned int A2Index = index + A1::TotalCount;
197  typedef DgemvNodeEvaluator<A2, IndicesType, A2Index> AlphaWrappedEvaluator;
198  typedef typename AlphaWrappedEvaluator::Evaluator AlphaEvaluator;
199 
200  static const unsigned int A3Index = A2Index + A2::TotalCount;
201  typedef DgemvNodeEvaluator<A3, IndicesType, A3Index> XWrappedEvaluator;
202  typedef typename XWrappedEvaluator::Evaluator BEvaluator;
203 
204  template<typename ArgumentVectorType>
205  static double GetAlpha(const ArgumentVectorType& args)
206  {
207  return AlphaEvaluator::Evaluate(args) * AWrappedEvaluator::GetScale() * XWrappedEvaluator::GetScale();
208  }
209  };
210 
211  template<typename A1, typename A2, typename A3, typename IndicesType, unsigned int index>
212  struct AlphaAXParameterAccessImpl< Node<Node<A1, MultiplyOp, A2>, MultiplyOp, A3>, IndicesType, index,
213  typename boost::enable_if
214  <
215  Test3ArgumentAssociativeNode<Node<Node<A1, MultiplyOp, A2>, MultiplyOp, A3>, Nektar::CanGetRawPtr, MultiplyOp,
216  Nektar::IsVector, MultiplyOp, IsDouble>
217 // boost::mpl::and_
218 // <
219 // CanGetRawPtr<typename A1::ResultType>,
220 // boost::is_same<typename A2::ResultType, NekVector<double> >,
221 // boost::is_same<typename A3::ResultType, double>
222 // >
223  >::type> : public boost::true_type
224  {
225  typedef DgemvNodeEvaluator<A1, IndicesType, index> AWrappedEvaluator;
226  typedef typename AWrappedEvaluator::Evaluator AEvaluator;
227 
228  static const unsigned int A2Index = index + A1::TotalCount;
229  typedef DgemvNodeEvaluator<A2, IndicesType, A2Index> XWrappedEvaluator;
230  typedef typename XWrappedEvaluator::Evaluator XEvaluator;
231 
232  static const unsigned int A3Index = A2Index + A2::TotalCount;
233  typedef DgemvNodeEvaluator<A3, IndicesType, A3Index> AlphaWrappedEvaluator;
234  typedef typename AlphaWrappedEvaluator::Evaluator AlphaEvaluator;
235 
236  template<typename ArgumentVectorType>
237  static double GetAlpha(const ArgumentVectorType& args)
238  {
239  return AlphaEvaluator::Evaluate(args) * AWrappedEvaluator::GetScale() * XWrappedEvaluator::GetScale();
240  }
241  };
242 
243  // A*B alone.
244  template<typename A1, typename A2, typename IndicesType, unsigned int index>
245  struct AlphaAXParameterAccessImpl< Node<A1, MultiplyOp, A2>, IndicesType, index,
246  typename boost::enable_if
247  <
248  boost::mpl::and_
249  <
250  TestBinaryNode<Node<A1, MultiplyOp, A2>, Nektar::CanGetRawPtr, MultiplyOp, Nektar::IsVector>,
251  boost::mpl::not_<Test3ArgumentAssociativeNode<Node<A1, MultiplyOp, A2>, IsDouble, MultiplyOp, Nektar::CanGetRawPtr, MultiplyOp, Nektar::IsVector> >,
252  boost::mpl::not_<Test3ArgumentAssociativeNode<Node<A1, MultiplyOp, A2>, Nektar::CanGetRawPtr, MultiplyOp, IsDouble, MultiplyOp, Nektar::IsVector> >,
253  boost::mpl::not_<Test3ArgumentAssociativeNode<Node<A1, MultiplyOp, A2>, Nektar::CanGetRawPtr, MultiplyOp, Nektar::IsVector, MultiplyOp, IsDouble> >
254 // CanGetRawPtr<typename A1::ResultType>,
255 // boost::is_same<typename A2::ResultType, NekVector<double> >
256  >
257  >::type> : public boost::true_type
258  {
259  typedef DgemvNodeEvaluator<A1, IndicesType, index> AWrappedEvaluator;
260  typedef typename AWrappedEvaluator::Evaluator AEvaluator;
261 
262  static const unsigned int A2Index = index + A1::TotalCount;
263  typedef DgemvNodeEvaluator<A2, IndicesType, A2Index> XWrappedEvaluator;
264  typedef typename XWrappedEvaluator::Evaluator XEvaluator;
265 
266  template<typename ArgumentVectorType>
267  static double GetAlpha(const ArgumentVectorType& args)
268  {
269  return 1.0 * AWrappedEvaluator::GetScale() * XWrappedEvaluator::GetScale();
270  }
271  };
272 
273  template<typename NodeType, typename IndicesType, unsigned int index>
274  struct AlphaAXParameterAccess : public AlphaAXParameterAccessImpl<NodeType, IndicesType, index> {};
275 
276  template<typename T, typename IndicesType, unsigned int index>
277  struct AlphaAXParameterAccess<Node<T, NegateOp, void>, IndicesType, index> : public AlphaAXParameterAccessImpl<T, IndicesType, index>
278  {
279  typedef Node<T, NegateOp, void> NodeType;
280 
281  template<typename ArgumentVectorType>
282  static double GetAlpha(const ArgumentVectorType& args)
283  {
284  return -AlphaAXParameterAccessImpl<T, IndicesType, index>::GetAlpha(args);
285  }
286  };
287 
288 
289  template<typename NodeType, typename IndicesType, unsigned int index, typename enabled=void>
290  struct BetaYParameterAccessImpl : public boost::false_type {};
291 
292  //template<typename T, typename IndicesType, unsigned int index>
293  //struct BetaYParameterAccessImpl< Node<T, void, void>, IndicesType, index> : public boost::true_type
294  //{
295  // typedef Node<T, void, void> NodeType;
296  // typedef DgemvNodeEvaluator<NodeType, IndicesType, index> YWrappedEvaluator;
297  // typedef typename YWrappedEvaluator::Evaluator YEvaluator;
298 
299  // template<typename ArgumentVectorType>
300  // static double GetBeta(const ArgumentVectorType& args)
301  // {
302  // return 1.0;
303  // }
304  //};
305 
306  template<typename L, typename R, typename IndicesType, unsigned int index>
307  struct BetaYParameterAccessImpl< Node<L, expt::MultiplyOp, R>, IndicesType, index,
308  typename boost::enable_if
309  <
310  boost::mpl::and_
311  <
312  boost::is_same<typename L::ResultType, Nektar::NekVector<double> >,
313  boost::is_same<typename R::ResultType, double>
314  >
315  >::type> : public boost::true_type
316  {
317  typedef DgemvNodeEvaluator<L, IndicesType, index> YWrappedEvaluator;
318  typedef typename YWrappedEvaluator::Evaluator YEvaluator;
319 
320  static const unsigned int nextIndex = index + L::TotalCount;
321  typedef DgemvNodeEvaluator<R, IndicesType, nextIndex> BetaWrappedEvaluator;
322  typedef typename BetaWrappedEvaluator::Evaluator BetaEvaluator;
323 
324  template<typename ArgumentVectorType>
325  static double GetBeta(const ArgumentVectorType& args)
326  {
327  return BetaEvaluator::Evaluate(args) * YWrappedEvaluator::GetScale();
328  }
329  };
330 
331  template<typename L, typename R, typename IndicesType, unsigned int index>
332  struct BetaYParameterAccessImpl< Node<L, expt::MultiplyOp, R>, IndicesType, index,
333  typename boost::enable_if
334  <
335  boost::mpl::and_
336  <
337  boost::is_same<typename R::ResultType, Nektar::NekVector<double> >,
338  boost::is_same<typename L::ResultType, double>
339  >
340  >::type> : public boost::true_type
341  {
342  typedef DgemvNodeEvaluator<L, IndicesType, index> BetaWrappedEvaluator;
343  typedef typename BetaWrappedEvaluator::Evaluator BetaEvaluator;
344 
345  static const unsigned int nextIndex = index + L::TotalCount;
346  typedef DgemvNodeEvaluator<R, IndicesType, nextIndex> YWrappedEvaluator;
347  typedef typename YWrappedEvaluator::Evaluator YEvaluator;
348 
349  template<typename ArgumentVectorType>
350  static double GetBeta(const ArgumentVectorType& args)
351  {
352  return BetaEvaluator::Evaluate(args) * YWrappedEvaluator::GetScale();
353  }
354  };
355 
356  template<typename NodeType, typename IndicesType, unsigned int index>
357  struct BetaYParameterAccessImpl< NodeType, IndicesType, index,
358  typename boost::enable_if
359  <
360  boost::is_same<typename NodeType::ResultType, Nektar::NekVector<double> >
361  >::type> : public boost::true_type
362  {
363  typedef DgemvNodeEvaluator<NodeType, IndicesType, index> YWrappedEvaluator;
364  typedef typename YWrappedEvaluator::Evaluator YEvaluator;
365 
366  template<typename ArgumentVectorType>
367  static double GetBeta(const ArgumentVectorType& args)
368  {
369  return 1.0 * YWrappedEvaluator::GetScale();
370  }
371  };
372 
373  template<typename NodeType, typename IndicesType, unsigned int index>
374  struct BetaYParameterAccess : public BetaYParameterAccessImpl<NodeType, IndicesType, index> {};
375 
376  template<typename T, typename IndicesType, unsigned int index>
377  struct BetaYParameterAccess<Node<T, NegateOp, void>, IndicesType, index> : public BetaYParameterAccessImpl<T, IndicesType, index>
378  {
379  template<typename ArgumentVectorType>
380  static double GetBeta(const ArgumentVectorType& args)
381  {
382  return -BetaYParameterAccessImpl<T, IndicesType, index>::GetBeta(args);
383  }
384  };
385 
386  }
387 
388  // Cases to handle
389  // 1. alpha*A*x +- beta*y
390  // 2. beta*y +- alpha*A*x
391  // 3. alpha*A*x + beta*B*y
392  // 4. alpha*A*x
393 
// Four cases to handle (A and B are matrices, x and y are vectors).
// 1. alpha*A*x +/- beta*y
// 2. beta*y +/- alpha*A*x - Implement through commutative transform.
// 3. alpha*A*x +/- beta*B*y - This has to be considered separately because it matches both 1 and 2.
// 4. alpha*A*x
399 
400  ///////////////////////////////////////////////////////////
401  // Case 1: alpha*A*B +/- beta*C
402  // Case 3 as well.
403  ///////////////////////////////////////////////////////////
404  template<typename L, typename OpType, typename R, typename IndicesType, unsigned int index>
405  struct BinaryBinaryEvaluateNodeOverride<L, OpType, R, IndicesType, index,
406  typename boost::enable_if
407  <
408  boost::mpl::and_
409  <
410  impl::AlphaAXParameterAccess<L, IndicesType, index>,
411  impl::BetaYParameterAccess<R, IndicesType, index + L::TotalCount>
412  //IsAlphaABNode<L>,
413  //IsBetaCNode<R>
414  >
415  >::type> : public boost::true_type
416  {
417  typedef Node<L, OpType, R> NodeType;
418  typedef impl::AlphaAXParameterAccess<L, IndicesType, index> LhsAccess;
419  static const unsigned int rhsIndex = index + L::TotalCount;
420  typedef impl::BetaYParameterAccess<R, IndicesType, rhsIndex> RhsAccess;
421  typedef typename LhsAccess::AEvaluator AEvaluator;
422  typedef typename LhsAccess::XEvaluator XEvaluator;
423  typedef typename RhsAccess::YEvaluator YEvaluator;
424 
425  template<typename ResultType, typename ArgumentVectorType>
426  static void Evaluate(ResultType& accumulator, const ArgumentVectorType& args)
427  {
428  typename AEvaluator::ResultType a = AEvaluator::Evaluate(args);
429  typename XEvaluator::ResultType x = XEvaluator::Evaluate(args);
430  typename YEvaluator::ResultType y = YEvaluator::Evaluate(args);
431 
432  double alpha = LhsAccess::GetAlpha(args);
433  double beta = RhsAccess::GetBeta(args);
434 
435  beta *= impl::BetaScale<OpType>::GetScale();
436  Nektar::Dgemv(accumulator, alpha, a, x, beta, y);
437  }
438  };
439 
440  ////////////////////////////////////////////////////////////
441  // Case 2 - beta*B +/- alpha*A*X
442  ////////////////////////////////////////////////////////////
443  template<typename L, typename OpType, typename R, typename IndicesType, unsigned int index>
444  struct BinaryBinaryEvaluateNodeOverride<L, OpType, R, IndicesType, index,
445  typename boost::enable_if
446  <
447  boost::mpl::and_
448  <
449  impl::AlphaAXParameterAccess<R, IndicesType, index + L::TotalCount>,
450  impl::BetaYParameterAccess<L, IndicesType, index>,
451  boost::mpl::not_<impl::AlphaAXParameterAccess<L, IndicesType, index> >
452  //IsAlphaABNode<R>,
453  //IsBetaCNode<L>,
454  //boost::mpl::not_<IsAlphaABNode<L> >
455  >
456  >::type> : public boost::true_type
457  {
458  typedef Node<L, OpType, R> NodeType;
459 
460  typedef impl::BetaYParameterAccess<L, IndicesType, index> LhsAccess;
461  static const unsigned int rhsIndex = index + L::TotalCount;
462  typedef impl::AlphaAXParameterAccess<R, IndicesType, rhsIndex> RhsAccess;
463 
464  typedef typename RhsAccess::AEvaluator AEvaluator;
465  typedef typename RhsAccess::XEvaluator XEvaluator;
466  typedef typename LhsAccess::YEvaluator YEvaluator;
467 
468  template<typename ResultType, typename ArgumentVectorType>
469  static void Evaluate(ResultType& accumulator, const ArgumentVectorType& args)
470  {
471  typename AEvaluator::ResultType a = AEvaluator::Evaluate(args);
472  typename XEvaluator::ResultType x = XEvaluator::Evaluate(args);
473  typename YEvaluator::ResultType y = YEvaluator::Evaluate(args);
474 
475  double alpha = RhsAccess::GetAlpha(args);
476  double beta = LhsAccess::GetBeta(args);
477 
478  alpha *= impl::BetaScale<OpType>::GetScale();
479  Nektar::Dgemv(accumulator, alpha, a, x, beta, y);
480  }
481  };
482 
483  ////////////////////////////////////////////////////////////
484  // Case 4 - alpha*A*B
485  ////////////////////////////////////////////////////////////
486  template<typename L, typename OpType, typename R, typename IndicesType, unsigned int index>
487  struct BinaryBinaryEvaluateNodeOverride<L, OpType, R, IndicesType, index,
488  typename boost::enable_if
489  <
490  impl::AlphaAXParameterAccess<Node<L, OpType, R>, IndicesType, index>
491  //IsAlphaABNode<Node<L, OpType, R> >
492  >::type> : public boost::true_type
493  {
494  typedef Node<L, OpType, R> NodeType;
495  typedef impl::AlphaAXParameterAccess<NodeType, IndicesType, index> LhsAccess;
496  typedef typename LhsAccess::AEvaluator AEvaluator;
497  typedef typename LhsAccess::XEvaluator XEvaluator;
498 
499  template<typename ResultType, typename ArgumentVectorType>
500  static void Evaluate(ResultType& accumulator, const ArgumentVectorType& args)
501  {
502  typename AEvaluator::ResultType a = AEvaluator::Evaluate(args);
503  typename XEvaluator::ResultType x = XEvaluator::Evaluate(args);
504 
505  double alpha = LhsAccess::GetAlpha(args);
506  Nektar::Dgemm(accumulator, alpha, a, x);
507  }
508  };
509 
510 
511 }
512 
513 #endif
514 #endif