Nektar++
MatrixOperations.hpp
Go to the documentation of this file.
1///////////////////////////////////////////////////////////////////////////////
2//
3// File: MatrixOperations.hpp
4//
5// For more information, please see: http://www.nektar.info
6//
7// The MIT License
8//
9// Copyright (c) 2006 Division of Applied Mathematics, Brown University (USA),
10// Department of Aeronautics, Imperial College London (UK), and Scientific
11// Computing and Imaging Institute, University of Utah (USA).
12//
13// Permission is hereby granted, free of charge, to any person obtaining a
14// copy of this software and associated documentation files (the "Software"),
15// to deal in the Software without restriction, including without limitation
16// the rights to use, copy, modify, merge, publish, distribute, sublicense,
17// and/or sell copies of the Software, and to permit persons to whom the
18// Software is furnished to do so, subject to the following conditions:
19//
20// The above copyright notice and this permission notice shall be included
21// in all copies or substantial portions of the Software.
22//
23// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
24// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
25// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
26// THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
27// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
28// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
29// DEALINGS IN THE SOFTWARE.
30//
31// Description: Defines the global functions needed for matrix operations.
32//
33///////////////////////////////////////////////////////////////////////////////
34
35#ifndef NEKTAR_LIB_UTILITIES_LINEAR_ALGEBRA_MATRIX_OPERATIONS_DECLARATIONS_HPP
36#define NEKTAR_LIB_UTILITIES_LINEAR_ALGEBRA_MATRIX_OPERATIONS_DECLARATIONS_HPP
37
38#include <boost/core/ignore_unused.hpp>
39
40// Since this file defines all of the operations for all combinations of matrix
41// types, we have to include all matrix specializations first.
42
50
51#include <string>
52#include <type_traits>
53
54namespace Nektar
55{
56////////////////////////////////////////////////////////////////////////////////
57// Matrix-Vector Multiplication
58////////////////////////////////////////////////////////////////////////////////
59template <typename DataType, typename LhsDataType, typename MatrixType>
60NekVector<DataType> Multiply(const NekMatrix<LhsDataType, MatrixType> &lhs,
61 const NekVector<DataType> &rhs);
62
63template <typename DataType, typename LhsDataType, typename MatrixType>
64void Multiply(NekVector<DataType> &result,
65 const NekMatrix<LhsDataType, MatrixType> &lhs,
66 const NekVector<DataType> &rhs);
67
68template <typename DataType, typename LhsInnerMatrixType>
69void Multiply(NekVector<DataType> &result,
70 const NekMatrix<LhsInnerMatrixType, BlockMatrixTag> &lhs,
71 const NekVector<DataType> &rhs);
72
73template <typename DataType, typename LhsDataType, typename MatrixType>
75 const NekVector<DataType> &rhs)
76{
77 return Multiply(lhs, rhs);
78}
79
81 NekVector<double> &result,
82 const NekMatrix<
83 NekMatrix<NekMatrix<NekDouble, StandardMatrixTag>, ScaledMatrixTag>,
84 BlockMatrixTag> &lhs,
85 const NekVector<double> &rhs);
86
89 const NekMatrix<
90 NekMatrix<NekMatrix<NekSingle, StandardMatrixTag>, ScaledMatrixTag>,
91 BlockMatrixTag> &lhs,
92 const NekVector<NekSingle> &rhs);
93
94////////////////////////////////////////////////////////////////////////////////
95// Matrix-Constant Multiplication
96////////////////////////////////////////////////////////////////////////////////
97template <typename ResultDataType, typename LhsDataType, typename LhsMatrixType>
98void Multiply(NekMatrix<ResultDataType, StandardMatrixTag> &result,
99 const NekMatrix<LhsDataType, LhsMatrixType> &lhs,
100 const ResultDataType &rhs);
101
102template <typename DataType, typename LhsDataType, typename LhsMatrixType>
103NekMatrix<typename NekMatrix<LhsDataType, LhsMatrixType>::NumberType,
104 StandardMatrixTag>
105Multiply(const NekMatrix<LhsDataType, LhsMatrixType> &lhs, const DataType &rhs);
106
107template <typename RhsDataType, typename RhsMatrixType, typename ResultDataType>
108void Multiply(NekMatrix<ResultDataType, StandardMatrixTag> &result,
109 const ResultDataType &lhs,
110 const NekMatrix<RhsDataType, RhsMatrixType> &rhs);
111
112template <typename DataType, typename RhsDataType, typename RhsMatrixType>
113NekMatrix<typename NekMatrix<RhsDataType, RhsMatrixType>::NumberType,
114 StandardMatrixTag>
115Multiply(const DataType &lhs, const NekMatrix<RhsDataType, RhsMatrixType> &rhs);
116
117template <typename DataType, typename RhsDataType, typename RhsMatrixType>
118NekMatrix<typename NekMatrix<RhsDataType, RhsMatrixType>::NumberType,
119 StandardMatrixTag>
120operator*(const DataType &lhs, const NekMatrix<RhsDataType, RhsMatrixType> &rhs)
121{
122 return Multiply(lhs, rhs);
123}
124
125template <typename DataType, typename RhsDataType, typename RhsMatrixType>
126NekMatrix<typename NekMatrix<RhsDataType, RhsMatrixType>::NumberType,
127 StandardMatrixTag>
128operator*(const NekMatrix<RhsDataType, RhsMatrixType> &lhs, const DataType &rhs)
129{
130 return Multiply(lhs, rhs);
131}
132
133template <typename LhsDataType>
134void MultiplyEqual(
135 NekMatrix<LhsDataType, StandardMatrixTag> &lhs,
136 typename boost::call_traits<LhsDataType>::const_reference rhs);
137
138///////////////////////////////////////////////////////////////////
139// Matrix-Matrix Multiplication
140//////////////////////////////////////////////////////////////////
141
// Full-matrix x full-matrix product via a single BLAS Gemm call (presumably
// NekMultiplyFullMatrixFullMatrix, per this page's symbol index).
// Computes result = alpha * op(lhs) * op(rhs), where:
//   - op() applies each operand's transpose flag;
//   - alpha = lhs.Scale() * rhs.Scale().
// beta is passed as 0.0, so BLAS overwrites the previous contents of result.
// Both operands must be eFULL; this is checked only by ASSERTL1.
// NOTE(review): the function-name/parameter lines and part of the enable_if
// condition are missing from this listing. `p` appears to exist only as a
// SFINAE discriminator (it is only passed to ignore_unused) — confirm against
// the full header.
// NOTE(review): no inner-dimension check (lhs.GetColumns() == rhs.GetRows())
// is made here; presumably the caller validates it — confirm.
142template <typename LhsDataType, typename RhsDataType, typename LhsMatrixType,
143 typename RhsMatrixType>
148 typename std::enable_if<
151 0)
152{
153 boost::ignore_unused(p);
154
155 ASSERTL1(lhs.GetType() == eFULL && rhs.GetType() == eFULL,
156 "Only full matrices are supported.");
157
// BLAS shapes: op(lhs) is M x K, op(rhs) is K x N, result is M x N.
158 unsigned int M = lhs.GetRows();
159 unsigned int N = rhs.GetColumns();
160 unsigned int K = lhs.GetColumns();
161
// Leading dimension of lhs as stored: K rows when it is flagged transposed.
162 unsigned int LDA = M;
163 if (lhs.GetTransposeFlag() == 'T')
164 {
165 LDA = K;
166 }
167
// Leading dimension of rhs as stored: N rows when it is flagged transposed.
168 unsigned int LDB = K;
169 if (rhs.GetTransposeFlag() == 'T')
170 {
171 LDB = N;
172 }
173
// Both operands' scale factors are folded into alpha; ldc is result's rows.
174 Blas::Gemm(lhs.GetTransposeFlag(), rhs.GetTransposeFlag(), M, N, K,
175 lhs.Scale() * rhs.Scale(), lhs.GetRawPtr(), LDA, rhs.GetRawPtr(),
176 LDB, 0.0, result.GetRawPtr(), result.GetRows());
177}
178
179template <typename LhsDataType, typename RhsDataType, typename DataType,
180 typename LhsMatrixType, typename RhsMatrixType>
181void Multiply(NekMatrix<DataType, StandardMatrixTag> &result,
182 const NekMatrix<LhsDataType, LhsMatrixType> &lhs,
183 const NekMatrix<RhsDataType, RhsMatrixType> &rhs);
184
185template <typename RhsInnerType, typename RhsMatrixType>
189 typename std::enable_if<
190 std::is_same<RawType_t<typename NekMatrix<RhsInnerType,
191 RhsMatrixType>::NumberType>,
192 RhsInnerType>::value &&
194 0)
195{
196 boost::ignore_unused(t);
197 ASSERTL0(result.GetType() == eFULL && rhs.GetType() == eFULL,
198 "Only full matrices supported.");
199 unsigned int M = result.GetRows();
200 unsigned int N = rhs.GetColumns();
201 unsigned int K = result.GetColumns();
202
203 unsigned int LDA = M;
204 if (result.GetTransposeFlag() == 'T')
205 {
206 LDA = K;
207 }
208
209 unsigned int LDB = K;
210 if (rhs.GetTransposeFlag() == 'T')
211 {
212 LDB = N;
213 }
214 RhsInnerType scale = rhs.Scale();
215 Array<OneD, RhsInnerType> &buf = result.GetTempSpace();
216 Blas::Gemm(result.GetTransposeFlag(), rhs.GetTransposeFlag(), M, N, K,
217 scale, result.GetRawPtr(), LDA, rhs.GetRawPtr(), LDB, 0.0,
218 buf.data(), result.GetRows());
219 result.SetSize(result.GetRows(), rhs.GetColumns());
220 result.SwapTempAndDataBuffers();
221}
222
223template <typename DataType, typename RhsInnerType, typename RhsMatrixType>
227 typename std::enable_if<
228 !std::is_same<RawType_t<typename NekMatrix<RhsInnerType,
229 RhsMatrixType>::NumberType>,
230 DataType>::value ||
232 0)
233{
234 boost::ignore_unused(t);
235 ASSERTL1(result.GetColumns() == rhs.GetRows(),
236 std::string("A left side matrix with column count ") +
237 std::to_string(result.GetColumns()) +
238 std::string(" and a right side matrix with row count ") +
239 std::to_string(rhs.GetRows()) +
240 std::string(" can't be multiplied."));
242 result.GetColumns());
243
244 for (unsigned int i = 0; i < result.GetRows(); ++i)
245 {
246 for (unsigned int j = 0; j < result.GetColumns(); ++j)
247 {
248 DataType t = DataType(0);
249
250 // Set the result(i,j) element.
251 for (unsigned int k = 0; k < result.GetColumns(); ++k)
252 {
253 t += result(i, k) * rhs(k, j);
254 }
255 temp(i, j) = t;
256 }
257 }
258
259 result = temp;
260}
261
262template <typename LhsDataType, typename RhsDataType, typename LhsMatrixType,
263 typename RhsMatrixType>
264NekMatrix<typename std::remove_const<
265 typename NekMatrix<LhsDataType, LhsMatrixType>::NumberType>::type,
266 StandardMatrixTag>
269{
270 typedef typename std::remove_const<
272 NumberType;
273 NekMatrix<NumberType, StandardMatrixTag> result(lhs.GetRows(),
274 rhs.GetColumns());
275 Multiply(result, lhs, rhs);
276 return result;
277}
278
279template <typename LhsDataType, typename RhsDataType, typename LhsMatrixType,
280 typename RhsMatrixType>
281NekMatrix<typename std::remove_const<
282 typename NekMatrix<LhsDataType, LhsMatrixType>::NumberType>::type,
283 StandardMatrixTag>
286{
287 return Multiply(lhs, rhs);
288}
289
290///////////////////////////////////////////////////////////////////
291// Addition
292///////////////////////////////////////////////////////////////////
293
294template <typename DataType, typename RhsDataType, typename RhsMatrixType>
295void AddEqual(NekMatrix<DataType, StandardMatrixTag> &result,
296 const NekMatrix<RhsDataType, RhsMatrixType> &rhs);
297
298template <typename DataType, typename LhsDataType, typename LhsMatrixType,
299 typename RhsDataType, typename RhsMatrixType>
300void Add(NekMatrix<DataType, StandardMatrixTag> &result,
301 const NekMatrix<LhsDataType, LhsMatrixType> &lhs,
302 const NekMatrix<RhsDataType, RhsMatrixType> &rhs);
303
304template <typename LhsDataType, typename LhsMatrixType, typename RhsDataType,
305 typename RhsMatrixType>
306NekMatrix<typename NekMatrix<LhsDataType, LhsMatrixType>::NumberType,
307 StandardMatrixTag>
308Add(const NekMatrix<LhsDataType, LhsMatrixType> &lhs,
309 const NekMatrix<RhsDataType, RhsMatrixType> &rhs);
310
311template <typename LhsDataType, typename LhsMatrixType, typename RhsDataType,
312 typename RhsMatrixType>
313NekMatrix<typename NekMatrix<LhsDataType, LhsMatrixType>::NumberType,
314 StandardMatrixTag>
317{
318 return Add(lhs, rhs);
319}
320
321template <typename DataType, typename LhsDataType, typename LhsMatrixType,
322 typename RhsDataType, typename RhsMatrixType>
323void AddNegatedLhs(NekMatrix<DataType, StandardMatrixTag> &result,
324 const NekMatrix<LhsDataType, LhsMatrixType> &lhs,
325 const NekMatrix<RhsDataType, RhsMatrixType> &rhs);
326
327template <typename DataType, typename RhsDataType, typename RhsMatrixType>
328void AddEqualNegatedLhs(NekMatrix<DataType, StandardMatrixTag> &result,
329 const NekMatrix<RhsDataType, RhsMatrixType> &rhs);
330
331////////////////////////////////////////////////////////////////////////////////
332// Subtraction
333////////////////////////////////////////////////////////////////////////////////
334template <typename DataType, typename LhsDataType, typename LhsMatrixType,
335 typename RhsDataType, typename RhsMatrixType>
336void Subtract(NekMatrix<DataType, StandardMatrixTag> &result,
337 const NekMatrix<LhsDataType, LhsMatrixType> &lhs,
338 const NekMatrix<RhsDataType, RhsMatrixType> &rhs);
339
340template <typename DataType, typename LhsDataType, typename LhsMatrixType,
341 typename RhsDataType, typename RhsMatrixType>
342void SubtractNegatedLhs(NekMatrix<DataType, StandardMatrixTag> &result,
343 const NekMatrix<LhsDataType, LhsMatrixType> &lhs,
344 const NekMatrix<RhsDataType, RhsMatrixType> &rhs);
345
346template <typename DataType, typename RhsDataType, typename RhsMatrixType>
347void SubtractEqual(NekMatrix<DataType, StandardMatrixTag> &result,
348 const NekMatrix<RhsDataType, RhsMatrixType> &rhs);
349
350template <typename DataType, typename RhsDataType, typename RhsMatrixType>
351void SubtractEqualNegatedLhs(NekMatrix<DataType, StandardMatrixTag> &result,
352 const NekMatrix<RhsDataType, RhsMatrixType> &rhs);
353
354template <typename LhsDataType, typename LhsMatrixType, typename RhsDataType,
355 typename RhsMatrixType>
356NekMatrix<typename NekMatrix<LhsDataType, LhsMatrixType>::NumberType,
357 StandardMatrixTag>
358Subtract(const NekMatrix<LhsDataType, LhsMatrixType> &lhs,
359 const NekMatrix<RhsDataType, RhsMatrixType> &rhs);
360
361template <typename LhsDataType, typename LhsMatrixType, typename RhsDataType,
362 typename RhsMatrixType>
363NekMatrix<typename NekMatrix<LhsDataType, LhsMatrixType>::NumberType,
364 StandardMatrixTag>
367{
368 return Subtract(lhs, rhs);
369}
370
371} // namespace Nektar
372
373#endif
#define ASSERTL0(condition, msg)
Definition: ErrorUtil.hpp:215
#define ASSERTL1(condition, msg)
Assert Level 1 – Debugging which is used whether in FULLDEBUG or DEBUG compilation mode....
Definition: ErrorUtil.hpp:249
#define LIB_UTILITIES_EXPORT
unsigned int GetRows() const
Definition: MatrixBase.cpp:65
unsigned int GetColumns() const
Definition: MatrixBase.cpp:84
static void Gemm(const char &transa, const char &transb, const int &m, const int &n, const int &k, const double &alpha, const double *a, const int &lda, const double *b, const int &ldb, const double &beta, double *c, const int &ldc)
BLAS level 3: Matrix-matrix multiply C = A x B where op(A)[m x k], op(B)[k x n], C[m x n] DGEMM perfo...
Definition: Blas.hpp:357
The above copyright notice and this permission notice shall be included.
Definition: CoupledSolver.h:2
SNekMat SNekMat void SubtractEqual(NekMatrix< DataType, StandardMatrixTag > &result, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)
NekMatrix< typename NekMatrix< LhsDataType, LhsMatrixType >::NumberType, StandardMatrixTag > operator-(const NekMatrix< LhsDataType, LhsMatrixType > &lhs, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)
void AddEqualNegatedLhs(NekMatrix< DataType, StandardMatrixTag > &result, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)
SNekMat void AddEqual(NekMatrix< DataType, StandardMatrixTag > &result, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)
void Subtract(NekMatrix< DataType, StandardMatrixTag > &result, const NekMatrix< LhsDataType, LhsMatrixType > &lhs, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)
Array< OneD, DataType > operator+(const Array< OneD, DataType > &lhs, typename Array< OneD, DataType >::size_type offset)
void Multiply(NekMatrix< ResultDataType, StandardMatrixTag > &result, const NekMatrix< LhsDataType, LhsMatrixType > &lhs, const ResultDataType &rhs)
void NekMultiplyFullMatrixFullMatrix(NekMatrix< ResultType, StandardMatrixTag > &result, const NekMatrix< LhsDataType, LhsMatrixType > &lhs, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)
NekVector< DataType > operator*(const NekMatrix< LhsDataType, MatrixType > &lhs, const NekVector< DataType > &rhs)
void AddNegatedLhs(NekMatrix< DataType, StandardMatrixTag > &result, const NekMatrix< LhsDataType, LhsMatrixType > &lhs, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)
void DiagonalBlockFullScalMatrixMultiply(NekVector< double > &result, const NekMatrix< NekMatrix< NekMatrix< NekDouble, StandardMatrixTag >, ScaledMatrixTag >, BlockMatrixTag > &lhs, const NekVector< double > &rhs)
void SubtractEqualNegatedLhs(NekMatrix< DataType, StandardMatrixTag > &result, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)
typename RawType< T >::type RawType_t
Definition: RawType.hpp:70
const NekSingle void MultiplyEqual(NekMatrix< LhsDataType, StandardMatrixTag > &lhs, typename boost::call_traits< LhsDataType >::const_reference rhs)
void SubtractNegatedLhs(NekMatrix< DataType, StandardMatrixTag > &result, const NekMatrix< LhsDataType, LhsMatrixType > &lhs, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)
SNekMat SNekMat void Add(NekMatrix< DataType, StandardMatrixTag > &result, const NekMatrix< LhsDataType, LhsMatrixType > &lhs, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)