Nektar++
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
NekLinSys.hpp
Go to the documentation of this file.
1 ///////////////////////////////////////////////////////////////////////////////
2 //
3 // File: NekLinSys.hpp
4 //
5 // For more information, please see: http://www.nektar.info
6 //
7 // The MIT License
8 //
9 // Copyright (c) 2006 Division of Applied Mathematics, Brown University (USA),
10 // Department of Aeronautics, Imperial College London (UK), and Scientific
11 // Computing and Imaging Institute, University of Utah (USA).
12 //
13 // License for the specific language governing rights and limitations under
14 // Permission is hereby granted, free of charge, to any person obtaining a
15 // copy of this software and associated documentation files (the "Software"),
16 // to deal in the Software without restriction, including without limitation
17 // the rights to use, copy, modify, merge, publish, distribute, sublicense,
18 // and/or sell copies of the Software, and to permit persons to whom the
19 // Software is furnished to do so, subject to the following conditions:
20 //
21 // The above copyright notice and this permission notice shall be included
22 // in all copies or substantial portions of the Software.
23 //
24 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
25 // OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
26 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
27 // THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
28 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
29 // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
30 // DEALINGS IN THE SOFTWARE.
31 //
32 // Description:
33 //
34 ///////////////////////////////////////////////////////////////////////////////
35 
36 #ifndef NEKTAR_LIB_UTILITIES_LINEAR_ALGEBRA_NEK_LINSYS_HPP
37 #define NEKTAR_LIB_UTILITIES_LINEAR_ALGEBRA_NEK_LINSYS_HPP
38 
47 #include <iostream>
48 
49 #include <boost/shared_ptr.hpp>
50 #include <boost/utility/enable_if.hpp>
51 #include <boost/type_traits.hpp>
52 
53 #ifdef max
54 #undef max
55 #endif
56 
57 #ifdef min
58 #undef min
59 #endif
60 
61 namespace Nektar
62 {
63  template<typename DataType>
64  struct IsSharedPointer : public boost::false_type {};
65 
66  template<typename DataType>
67  struct IsSharedPointer<boost::shared_ptr<DataType> > : public boost::true_type {};
68 
69  // The solving of the linear system is located in this class instead of in the LinearSystem
70  // class because XCode gcc 4.2 didn't compile it correctly when it was moved to the
71  // LinearSystem class.
73  {
74  template<typename BVectorType, typename XVectorType>
75  static void Solve(const BVectorType& b, XVectorType& x, MatrixStorage m_matrixType,
76  const Array<OneD, const int>& m_ipivot, unsigned int n,
77  const Array<OneD, const double>& A,
78  char m_transposeFlag, unsigned int m_numberOfSubDiagonals,
79  unsigned int m_numberOfSuperDiagonals)
80  {
81  switch(m_matrixType)
82  {
83  case eFULL:
84  {
85  x = b;
86  int info = 0;
87  Lapack::Dgetrs('N',n,1,A.get(),n,(int *)m_ipivot.get(),x.GetRawPtr(),n,info);
88  if( info < 0 )
89  {
90  std::string message = "ERROR: The " + boost::lexical_cast<std::string>(-info) + "th parameter had an illegal parameter for dgetrs";
91  ASSERTL0(false, message.c_str());
92  }
93 
94  }
95  break;
96  case eDIAGONAL:
97  for(unsigned int i = 0; i < A.num_elements(); ++i)
98  {
99  x[i] = b[i]*A[i];
100  }
101  break;
102  case eUPPER_TRIANGULAR:
103  {
104  x = b;
105  int info = 0;
106  Lapack::Dtptrs('U', m_transposeFlag, 'N', n, 1, A.get(), x.GetRawPtr(), n, info);
107 
108  if( info < 0 )
109  {
110  std::string message = "ERROR: The " + boost::lexical_cast<std::string>(-info) + "th parameter had an illegal parameter for dtrtrs";
111  ASSERTL0(false, message.c_str());
112  }
113  else if( info > 0 )
114  {
115  std::string message = "ERROR: The " + boost::lexical_cast<std::string>(-info) + "th diagonal element of A is 0 for dtrtrs";
116  ASSERTL0(false, message.c_str());
117  }
118  }
119  break;
120  case eLOWER_TRIANGULAR:
121  {
122  x = b;
123  int info = 0;
124  Lapack::Dtptrs('L', m_transposeFlag, 'N', n, 1, A.get(), x.GetRawPtr(), n, info);
125 
126  if( info < 0 )
127  {
128  std::string message = "ERROR: The " + boost::lexical_cast<std::string>(-info) + "th parameter had an illegal parameter for dtrtrs";
129  ASSERTL0(false, message.c_str());
130  }
131  else if( info > 0 )
132  {
133  std::string message = "ERROR: The " + boost::lexical_cast<std::string>(-info) + "th diagonal element of A is 0 for dtrtrs";
134  ASSERTL0(false, message.c_str());
135  }
136  }
137  break;
138  case eSYMMETRIC:
139  {
140  x = b;
141  int info = 0;
142  Lapack::Dsptrs('U', n, 1, A.get(), m_ipivot.get(), x.GetRawPtr(), x.GetRows(), info);
143  if( info < 0 )
144  {
145  std::string message = "ERROR: The " + boost::lexical_cast<std::string>(-info) + "th parameter had an illegal parameter for dsptrs";
146  ASSERTL0(false, message.c_str());
147  }
148  }
149  break;
151  {
152  x = b;
153  int info = 0;
154  Lapack::Dpptrs('U', n, 1, A.get(), x.GetRawPtr(), x.GetRows(), info);
155  if( info < 0 )
156  {
157  std::string message = "ERROR: The " + boost::lexical_cast<std::string>(-info) + "th parameter had an illegal parameter for dpptrs";
158  ASSERTL0(false, message.c_str());
159  }
160  }
161  break;
162  case eBANDED:
163  {
164  x = b;
165  int KL = m_numberOfSubDiagonals;
166  int KU = m_numberOfSuperDiagonals;
167  int info = 0;
168 
169  Lapack::Dgbtrs(m_transposeFlag, n, KL, KU, 1, A.get(), 2*KL+KU+1, m_ipivot.get(), x.GetRawPtr(), n, info);
170 
171  if( info < 0 )
172  {
173  std::string message = "ERROR: The " + boost::lexical_cast<std::string>(-info) + "th parameter had an illegal parameter for dgbtrs";
174  ASSERTL0(false, message.c_str());
175  }
176  }
177  break;
179  {
180  x = b;
181  int KU = m_numberOfSuperDiagonals;
182  int info = 0;
183 
184  Lapack::Dpbtrs('U', n, KU, 1, A.get(), KU+1, x.GetRawPtr(), n, info);
185 
186  if( info < 0 )
187  {
188  std::string message = "ERROR: The " + boost::lexical_cast<std::string>(-info) + "th parameter had an illegal parameter for dpbtrs";
189  ASSERTL0(false, message.c_str());
190  }
191  }
192  break;
193  case eSYMMETRIC_BANDED:
194  NEKERROR(ErrorUtil::efatal, "Unhandled matrix type");
195  break;
197  NEKERROR(ErrorUtil::efatal, "Unhandled matrix type");
198  break;
200  NEKERROR(ErrorUtil::efatal, "Unhandled matrix type");
201  break;
202 
203  default:
204  NEKERROR(ErrorUtil::efatal, "Unhandled matrix type");
205  }
206 
207  }
208 
209  template<typename BVectorType, typename XVectorType>
210  static void SolveTranspose(const BVectorType& b, XVectorType& x, MatrixStorage m_matrixType,
211  const Array<OneD, const int>& m_ipivot, unsigned int n,
212  const Array<OneD, const double>& A,
213  char m_transposeFlag, unsigned int m_numberOfSubDiagonals,
214  unsigned int m_numberOfSuperDiagonals)
215  {
216  switch(m_matrixType)
217  {
218  case eFULL:
219  {
220  x = b;
221  int info = 0;
222  Lapack::Dgetrs('T',n,1,A.get(),n,(int *)m_ipivot.get(),x.GetRawPtr(), n,info);
223 
224  if( info < 0 )
225  {
226  std::string message = "ERROR: The " + boost::lexical_cast<std::string>(-info) + "th parameter had an illegal parameter for dgetrs";
227  ASSERTL0(false, message.c_str());
228  }
229  }
230 
231  break;
232  case eDIAGONAL:
233  Solve(b, x, m_matrixType, m_ipivot, n, A, m_transposeFlag, m_numberOfSubDiagonals, m_numberOfSuperDiagonals);
234  break;
235  case eUPPER_TRIANGULAR:
236  {
237  char trans = m_transposeFlag;
238  if( trans == 'N' )
239  {
240  trans = 'T';
241  }
242  else
243  {
244  trans = 'N';
245  }
246 
247  x = b;
248  int info = 0;
249  Lapack::Dtptrs('U', trans, 'N', n, 1, A.get(), x.GetRawPtr(), n, info);
250 
251  if( info < 0 )
252  {
253  std::string message = "ERROR: The " + boost::lexical_cast<std::string>(-info) + "th parameter had an illegal parameter for dtrtrs";
254  ASSERTL0(false, message.c_str());
255  }
256  else if( info > 0 )
257  {
258  std::string message = "ERROR: The " + boost::lexical_cast<std::string>(-info) + "th diagonal element of A is 0 for dtrtrs";
259  ASSERTL0(false, message.c_str());
260  }
261  }
262 
263  break;
264  case eLOWER_TRIANGULAR:
265  {
266  char trans = m_transposeFlag;
267  if( trans == 'N' )
268  {
269  trans = 'T';
270  }
271  else
272  {
273  trans = 'N';
274  }
275  x = b;
276  int info = 0;
277  Lapack::Dtptrs('L', trans, 'N', n, 1, A.get(), x.GetRawPtr(), n, info);
278 
279  if( info < 0 )
280  {
281  std::string message = "ERROR: The " + boost::lexical_cast<std::string>(-info) + "th parameter had an illegal parameter for dtrtrs";
282  ASSERTL0(false, message.c_str());
283  }
284  else if( info > 0 )
285  {
286  std::string message = "ERROR: The " + boost::lexical_cast<std::string>(-info) + "th diagonal element of A is 0 for dtrtrs";
287  ASSERTL0(false, message.c_str());
288  }
289  }
290  break;
291  case eSYMMETRIC:
294  Solve(b, x, m_matrixType, m_ipivot, n, A, m_transposeFlag, m_numberOfSubDiagonals, m_numberOfSuperDiagonals);
295  break;
296  case eBANDED:
297  {
298  x = b;
299  int KL = m_numberOfSubDiagonals;
300  int KU = m_numberOfSuperDiagonals;
301  int info = 0;
302 
303  Lapack::Dgbtrs(m_transposeFlag, n, KL, KU, 1, A.get(), 2*KL+KU+1, m_ipivot.get(), x.GetRawPtr(), n, info);
304 
305  if( info < 0 )
306  {
307  std::string message = "ERROR: The " + boost::lexical_cast<std::string>(-info) + "th parameter had an illegal parameter for dgbtrs";
308  ASSERTL0(false, message.c_str());
309  }
310  }
311  break;
312  case eSYMMETRIC_BANDED:
313  NEKERROR(ErrorUtil::efatal, "Unhandled matrix type");
314  break;
316  NEKERROR(ErrorUtil::efatal, "Unhandled matrix type");
317  break;
319  NEKERROR(ErrorUtil::efatal, "Unhandled matrix type");
320  break;
321 
322  default:
323  NEKERROR(ErrorUtil::efatal, "Unhandled matrix type");
324  }
325  }
326  };
327 
328 
330  {
331  public:
332  template<typename MatrixType>
333  explicit LinearSystem(const boost::shared_ptr<MatrixType> &theA, PointerWrapper wrapperType = eCopy) :
334  n(theA->GetRows()),
335  A(theA->GetPtr(), eVECTOR_WRAPPER),
336  m_ipivot(),
337  m_numberOfSubDiagonals(theA->GetNumberOfSubDiagonals()),
338  m_numberOfSuperDiagonals(theA->GetNumberOfSuperDiagonals()),
339  m_matrixType(theA->GetType()),
340  m_transposeFlag(theA->GetTransposeFlag())
341  {
342  // At some point we should fix this. We should upate the copy of
343  // A to be transposd for this to work.
344  ASSERTL0(theA->GetTransposeFlag() == 'N', "LinearSystem requires a non-transposed matrix.");
345  ASSERTL0( (wrapperType == eWrapper && theA->GetType() != eBANDED) || wrapperType == eCopy , "Banded matrices can't be wrapped");
346 
347  if( wrapperType == eCopy )
348  {
349  A = Array<OneD, double>(theA->GetPtr().num_elements());
350  CopyArray(theA->GetPtr(), A);
351  }
352 
353  FactorMatrix(*theA);
354  }
355 
356  template<typename MatrixType>
357  explicit LinearSystem(const MatrixType& theA, PointerWrapper wrapperType = eCopy) :
358  n(theA.GetRows()),
359  A(theA.GetPtr(), eVECTOR_WRAPPER),
360  m_ipivot(),
361  m_numberOfSubDiagonals(theA.GetNumberOfSubDiagonals()),
362  m_numberOfSuperDiagonals(theA.GetNumberOfSuperDiagonals()),
363  m_matrixType(theA.GetType()),
364  m_transposeFlag(theA.GetTransposeFlag())
365  {
366  // At some point we should fix this. We should upate the copy of
367  // A to be transposd for this to work.
368  ASSERTL0(theA.GetTransposeFlag() == 'N', "LinearSystem requires a non-transposed matrix.");
369  ASSERTL0( (wrapperType == eWrapper && theA.GetType() != eBANDED) || wrapperType == eCopy, "Banded matrices can't be wrapped" );
370 
371  if( wrapperType == eCopy )
372  {
373  A = Array<OneD, double>(theA.GetPtr().num_elements());
374  CopyArray(theA.GetPtr(), A);
375  }
376 
377  FactorMatrix(theA);
378  }
379 
// Copy constructor. Its signature line (original listing line 380) is not
// present in this copy. It copies n, A and m_ipivot from rhs; the remaining
// member initialisers (original lines 384-387) are also missing here.
// NOTE(review): whether A(rhs.A) shares or duplicates the underlying storage
// depends on Array's copy semantics. Confirm before relying on independence.
381  n(rhs.n),
382  A(rhs.A),
383  m_ipivot(rhs.m_ipivot),
388  {
389  }
390 
// Copy assignment (signature line, original 391, not present in this copy),
// written as copy-and-swap: build a temporary copy of rhs, then swap it into
// *this.
392  {
393  LinearSystem temp(rhs);
394  swap(temp);
395  return *this;
396  }
397 
399 
400  // In the following calls to Solve, VectorType must be a NekVector.
401  // Anything else won't compile.
// Solves A x = b and returns the solution by value.
// NOTE(review): the body lines (original 405-407) are missing from this copy.
// `x` is presumably a local of type RawType<VectorType>::type filled by the
// two-argument overload. Confirm against the upstream source.
402  template<typename VectorType>
403  typename RawType<VectorType>::type Solve(const VectorType& b)
404  {
408  return x;
409  }
410 
// Solves A x = b, writing the solution into the caller-supplied x.
// NOTE(review): the body lines (original 414-416) are missing from this copy.
// Presumably they forward to LinearSystemSolver::Solve with this object's
// members. Confirm.
411  template<typename BType, typename XType>
412  void Solve(const BType& b, XType& x) const
413  {
417  }
418 
419  // Transpose variant of solve
// Solves A^T x = b and returns the solution by value.
// NOTE(review): the body lines (original 423-425) are missing from this copy.
420  template<typename VectorType>
421  typename RawType<VectorType>::type SolveTranspose(const VectorType& b)
422  {
426  return x;
427  }
428 
// Solves A^T x = b, writing the solution into the caller-supplied x.
// NOTE(review): the body lines (original 432-434) are missing from this copy.
// Presumably they forward to LinearSystemSolver::SolveTranspose. Confirm.
429  template<typename BType, typename XType>
430  void SolveTranspose(const BType& b, XType& x) const
431  {
435  }
436 
437  unsigned int GetRows() const { return n; }
438  unsigned int GetColumns() const { return n; }
439 
440  private:
// Factorises the matrix data held in A in place, using the LAPACK routine
// that matches m_matrixType, and fills m_ipivot when that routine produces
// pivots. When A wraps the caller's storage (eWrapper in the constructors),
// this presumably overwrites the caller's matrix data.
//
// theA  the original matrix; supplies the dimensions, the raw band data
//       (banded case) and the diagonal entries (diagonal case).
//
// NOTE(review): several lines of this function are missing from this copy
// (original listing 498, 556, 558, 581, 584). The blocks after those gaps have
// lost their case labels, and one has lost its ASSERTL0( opening, so this text
// does not compile as shown. Restore them from the upstream source.
// NOTE(review): the "Element u_" messages concatenate info twice (e.g. "u_33"),
// which is ambiguous once info has more than one digit.
441  template<typename MatrixType>
442  void FactorMatrix(const MatrixType& theA)
443  {
444  switch(m_matrixType)
445  {
446  case eFULL:
447  {
// LU factorisation with partial pivoting (dgetrf).
// NOTE(review): the local `n` below shadows the member `n`.
448  int m = theA.GetRows();
449  int n = theA.GetColumns();
450 
451  int pivotSize = std::max(1, std::min(m, n));
452  int info = 0;
453  m_ipivot = Array<OneD, int>(pivotSize);
454 
455  Lapack::Dgetrf(m, n, A.get(), m, m_ipivot.get(), info);
456 
457  if( info < 0 )
458  {
459  std::string message = "ERROR: The " + boost::lexical_cast<std::string>(-info) + "th parameter had an illegal parameter for dgetrf";
460  ASSERTL0(false, message.c_str());
461  }
462  else if( info > 0 )
463  {
464  std::string message = "ERROR: Element u_" + boost::lexical_cast<std::string>(info) + boost::lexical_cast<std::string>(info) + " is 0 from dgetrf";
465  ASSERTL0(false, message.c_str());
466  }
467  }
468  break;
469  case eDIAGONAL:
// Store reciprocals so the solve can multiply instead of divide.
// NOTE(review): a zero diagonal entry is not checked for here.
470  for(unsigned int i = 0; i < theA.GetColumns(); ++i)
471  {
472  A[i] = 1.0/theA(i,i);
473  }
474  break;
475  case eUPPER_TRIANGULAR:
476  case eLOWER_TRIANGULAR:
// Triangular matrices are solved directly (dtptrs); nothing to factorise.
477  break;
478  case eSYMMETRIC:
479  {
// Symmetric indefinite factorisation of the packed upper triangle (dsptrf).
480  int info = 0;
481  int pivotSize = theA.GetRows();
482  m_ipivot = Array<OneD, int>(pivotSize);
483 
484  Lapack::Dsptrf('U', theA.GetRows(), A.get(), m_ipivot.get(), info);
485 
486  if( info < 0 )
487  {
488  std::string message = "ERROR: The " + boost::lexical_cast<std::string>(-info) + "th parameter had an illegal parameter for dsptrf";
489  ASSERTL0(false, message.c_str());
490  }
491  else if( info > 0 )
492  {
493  std::string message = "ERROR: Element u_" + boost::lexical_cast<std::string>(info) + boost::lexical_cast<std::string>(info) + " is 0 from dsptrf";
494  ASSERTL0(false, message.c_str());
495  }
496  }
497  break;
// NOTE(review): the case label for this block (original line 498) is missing.
// Cholesky factorisation of a packed positive-definite matrix (dpptrf).
499  {
500  int info = 0;
501  Lapack::Dpptrf('U', theA.GetRows(), A.get(), info);
502 
503  if( info < 0 )
504  {
505  std::string message = "ERROR: The " + boost::lexical_cast<std::string>(-info) + "th parameter had an illegal parameter for dpptrf";
506  ASSERTL0(false, message.c_str());
507  }
508  else if( info > 0 )
509  {
510  std::string message = "ERROR: The leading minor of order " + boost::lexical_cast<std::string>(info) + " is not positive definite from dpptrf";
511  ASSERTL0(false, message.c_str());
512  }
513  }
514  break;
515  case eBANDED:
516  {
517  int M = n;
518  int N = n;
519  int KL = m_numberOfSubDiagonals;
520  int KU = m_numberOfSuperDiagonals;
521 
522  // The array we pass in to dgbtrf must have enough space for KL
523  // subdiagonals and KL+KU superdiagonals (see lapack users guide,
524  // in the section discussing band storage.
525  unsigned int requiredStorageSize = BandedMatrixFuncs::
526  GetRequiredStorageSize(n, n, KL, KL+KU);
527 
528  unsigned int rawRows = KL+KU+1;
529  A = Array<OneD, double>(requiredStorageSize);
530 
531  // Put the extra elements up front.
// Column i (rawRows = KL+KU+1 stored values) is copied to offset
// i*(2*KL+KU+1) + KL, i.e. below KL leading rows that dgbtrf uses for fill-in.
532  for(unsigned int i = 0; i < theA.GetColumns(); ++i)
533  {
534  std::copy(theA.GetRawPtr() + i*rawRows, theA.GetRawPtr() + (i+1)*rawRows,
535  A.get() + (i+1)*KL + i*rawRows);
536  }
537 
538  int info = 0;
539  int pivotSize = theA.GetRows();
540  m_ipivot = Array<OneD, int>(pivotSize);
541 
// The leading dimension is the expanded band height, 2*KL+KU+1.
542  Lapack::Dgbtrf(M, N, KL, KU, A.get(), 2*KL+KU+1, m_ipivot.get(), info);
543 
544  if( info < 0 )
545  {
546  std::string message = "ERROR: The " + boost::lexical_cast<std::string>(-info) + "th parameter had an illegal parameter for dgbtrf";
547  ASSERTL0(false, message.c_str());
548  }
549  else if( info > 0 )
550  {
551  std::string message = "ERROR: Element u_" + boost::lexical_cast<std::string>(info) + boost::lexical_cast<std::string>(info) + " is 0 from dgbtrf";
552  ASSERTL0(false, message.c_str());
553  }
554  }
555  break;
// NOTE(review): the case label (original line 556) and the opening of the
// ASSERTL0 call (original line 558) for this block are missing from this copy.
// Cholesky factorisation of a band matrix stored as its upper KU diagonals (dpbtrf).
557  {
559  std::string("Number of sub- and superdiagonals should ") +
560  std::string("be equal for a symmetric banded matrix"));
561 
562  int KU = m_numberOfSuperDiagonals;
563  int info = 0;
564  Lapack::Dpbtrf('U', theA.GetRows(), KU, A.get(), KU+1, info);
565 
566  if( info < 0 )
567  {
568  std::string message = "ERROR: The " + boost::lexical_cast<std::string>(-info) + "th parameter had an illegal parameter for dpbtrf";
569  ASSERTL0(false, message.c_str());
570  }
571  else if( info > 0 )
572  {
573  std::string message = "ERROR: The leading minor of order " + boost::lexical_cast<std::string>(info) + " is not positive definite from dpbtrf";
574  ASSERTL0(false, message.c_str());
575  }
576  }
577  break;
578  case eSYMMETRIC_BANDED:
579  NEKERROR(ErrorUtil::efatal, "Unhandled matrix type");
580  break;
// NOTE(review): case label missing here (original line 581).
582  NEKERROR(ErrorUtil::efatal, "Unhandled matrix type");
583  break;
// NOTE(review): case label missing here (original line 584).
585  NEKERROR(ErrorUtil::efatal, "Unhandled matrix type");
586  break;
587 
588  default:
589  NEKERROR(ErrorUtil::efatal, "Unhandled matrix type");
590  }
591  }
592 
// Exchanges this system's state with rhs; operator= uses it for copy-and-swap.
// NOTE(review): original lines 597-601 are missing from this copy. Presumably
// they swap the remaining members; confirm against the upstream source.
593  void swap(LinearSystem& rhs)
594  {
595  std::swap(n, rhs.n);
596  std::swap(A, rhs.A);
602  }
603 
// Order of the square system (returned by both GetRows and GetColumns).
604  unsigned int n;
// Matrix data. FactorMatrix overwrites it with the factorisation, or with the
// reciprocal diagonal for diagonal storage.
605  Array<OneD, double> A;
// Pivot indices from dgetrf / dsptrf / dgbtrf; left default-constructed for
// storage types whose factorisation produces no pivots.
606  Array<OneD, int> m_ipivot;
611  };
612 }
613 
614 #endif //NEKTAR_LIB_UTILITIES_LINEAR_ALGEBRA_NEK_LINSYS_HPP