Nektar++
MatrixOperations.hpp
Go to the documentation of this file.
1///////////////////////////////////////////////////////////////////////////////
2//
3// File: MatrixOperations.hpp
4//
5// For more information, please see: http://www.nektar.info
6//
7// The MIT License
8//
9// Copyright (c) 2006 Division of Applied Mathematics, Brown University (USA),
10// Department of Aeronautics, Imperial College London (UK), and Scientific
11// Computing and Imaging Institute, University of Utah (USA).
12//
13// Permission is hereby granted, free of charge, to any person obtaining a
14// copy of this software and associated documentation files (the "Software"),
15// to deal in the Software without restriction, including without limitation
16// the rights to use, copy, modify, merge, publish, distribute, sublicense,
17// and/or sell copies of the Software, and to permit persons to whom the
18// Software is furnished to do so, subject to the following conditions:
19//
20// The above copyright notice and this permission notice shall be included
21// in all copies or substantial portions of the Software.
22//
23// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
24// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
25// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
26// THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
27// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
28// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
29// DEALINGS IN THE SOFTWARE.
30//
31// Description: Defines the global functions needed for matrix operations.
32//
33///////////////////////////////////////////////////////////////////////////////
34
35#ifndef NEKTAR_LIB_UTILITIES_LINEAR_ALGEBRA_MATRIX_OPERATIONS_DECLARATIONS_HPP
36#define NEKTAR_LIB_UTILITIES_LINEAR_ALGEBRA_MATRIX_OPERATIONS_DECLARATIONS_HPP
37
38// Since this file defines all of the operations for all combinations of matrix
39// types, we have to include all matrix specializations first.
40
48
49#include <string>
50#include <type_traits>
51
52namespace Nektar
53{
54////////////////////////////////////////////////////////////////////////////////
55// Matrix-Vector Multiplication
56////////////////////////////////////////////////////////////////////////////////
57template <typename DataType, typename LhsDataType, typename MatrixType>
58NekVector<DataType> Multiply(const NekMatrix<LhsDataType, MatrixType> &lhs,
59 const NekVector<DataType> &rhs);
60
61template <typename DataType, typename LhsDataType, typename MatrixType>
62void Multiply(NekVector<DataType> &result,
63 const NekMatrix<LhsDataType, MatrixType> &lhs,
64 const NekVector<DataType> &rhs);
65
66template <typename DataType, typename LhsInnerMatrixType>
67void Multiply(NekVector<DataType> &result,
68 const NekMatrix<LhsInnerMatrixType, BlockMatrixTag> &lhs,
69 const NekVector<DataType> &rhs);
70
// NOTE(review): source line 72 (the declarator) is missing from this listing;
// this page's symbol list gives the full signature as
//   NekVector<DataType> operator*(const NekMatrix<LhsDataType, MatrixType> &lhs,
//                                 const NekVector<DataType> &rhs)
/// Matrix-vector product `lhs * rhs`: forwards to Multiply(lhs, rhs) and
/// returns the resulting vector by value.
71template <typename DataType, typename LhsDataType, typename MatrixType>
73 const NekVector<DataType> &rhs)
74{
75 return Multiply(lhs, rhs);
76}
77
// NOTE(review): source line 78 (the start of this declaration) is missing from
// this listing; this page's symbol list gives it as
//   void DiagonalBlockFullScalMatrixMultiply(NekVector<double> &result, ...)
// and it is presumably prefixed with LIB_UTILITIES_EXPORT — confirm.
/// Non-template multiply of a block matrix whose blocks are scaled
/// (ScaledMatrixTag) standard NekDouble matrices by a double vector, writing
/// into `result`. The name suggests the block matrix is assumed to be
/// block-diagonal — confirm against the definition, which is not visible here.
79 NekVector<double> &result,
80 const NekMatrix<
81 NekMatrix<NekMatrix<NekDouble, StandardMatrixTag>, ScaledMatrixTag>,
82 BlockMatrixTag> &lhs,
83 const NekVector<double> &rhs);
84
// NOTE(review): source lines 85-86 (the declarator and the `result` parameter)
// are missing from this listing. By analogy with the NekDouble declaration
// just above, this is presumably the NekSingle overload of
// DiagonalBlockFullScalMatrixMultiply — confirm.
/// Single-precision variant: block matrix of scaled standard NekSingle
/// matrices times a NekSingle vector.
87 const NekMatrix<
88 NekMatrix<NekMatrix<NekSingle, StandardMatrixTag>, ScaledMatrixTag>,
89 BlockMatrixTag> &lhs,
90 const NekVector<NekSingle> &rhs);
91
92////////////////////////////////////////////////////////////////////////////////
93// Matrix-Constant Multiplication
94////////////////////////////////////////////////////////////////////////////////
95template <typename ResultDataType, typename LhsDataType, typename LhsMatrixType>
96void Multiply(NekMatrix<ResultDataType, StandardMatrixTag> &result,
97 const NekMatrix<LhsDataType, LhsMatrixType> &lhs,
98 const ResultDataType &rhs);
99
100template <typename DataType, typename LhsDataType, typename LhsMatrixType>
101NekMatrix<typename NekMatrix<LhsDataType, LhsMatrixType>::NumberType,
102 StandardMatrixTag>
103Multiply(const NekMatrix<LhsDataType, LhsMatrixType> &lhs, const DataType &rhs);
104
105template <typename RhsDataType, typename RhsMatrixType, typename ResultDataType>
106void Multiply(NekMatrix<ResultDataType, StandardMatrixTag> &result,
107 const ResultDataType &lhs,
108 const NekMatrix<RhsDataType, RhsMatrixType> &rhs);
109
110template <typename DataType, typename RhsDataType, typename RhsMatrixType>
111NekMatrix<typename NekMatrix<RhsDataType, RhsMatrixType>::NumberType,
112 StandardMatrixTag>
113Multiply(const DataType &lhs, const NekMatrix<RhsDataType, RhsMatrixType> &rhs);
114
115template <typename DataType, typename RhsDataType, typename RhsMatrixType>
116NekMatrix<typename NekMatrix<RhsDataType, RhsMatrixType>::NumberType,
117 StandardMatrixTag>
118operator*(const DataType &lhs, const NekMatrix<RhsDataType, RhsMatrixType> &rhs)
119{
120 return Multiply(lhs, rhs);
121}
122
123template <typename DataType, typename RhsDataType, typename RhsMatrixType>
124NekMatrix<typename NekMatrix<RhsDataType, RhsMatrixType>::NumberType,
125 StandardMatrixTag>
126operator*(const NekMatrix<RhsDataType, RhsMatrixType> &lhs, const DataType &rhs)
127{
128 return Multiply(lhs, rhs);
129}
130
131template <typename LhsDataType>
132void MultiplyEqual(
133 NekMatrix<LhsDataType, StandardMatrixTag> &lhs,
134 typename boost::call_traits<LhsDataType>::const_reference rhs);
135
136///////////////////////////////////////////////////////////////////
137// Matrix-Matrix Multiplication
138//////////////////////////////////////////////////////////////////
139
// NOTE(review): source lines 142-145 (the function name and parameters) and
// 147-148 (the enable_if condition) are missing from this listing. This page's
// symbol list names the function
//   NekMultiplyFullMatrixFullMatrix(NekMatrix<ResultType, StandardMatrixTag> &result,
//                                   const NekMatrix<LhsDataType, LhsMatrixType> &lhs,
//                                   const NekMatrix<RhsDataType, RhsMatrixType> &rhs)
/// Full-matrix x full-matrix product computed with one BLAS gemm call:
///   result = (lhs.Scale() * rhs.Scale()) * op(lhs) * op(rhs),
/// where op() applies each operand's transpose flag. beta is 0.0, so the
/// previous contents of result are overwritten, not accumulated.
140template <typename LhsDataType, typename RhsDataType, typename LhsMatrixType,
141 typename RhsMatrixType>
146 [[maybe_unused]] typename std::enable_if<
149 nullptr)
150{
// Debug-level check only (ASSERTL1): both operands must be eFULL.
// NOTE(review): the type and size of `result`, and
// lhs.GetColumns() == rhs.GetRows(), are not checked here; callers must
// guarantee them.
151 ASSERTL1(lhs.GetType() == eFULL && rhs.GetType() == eFULL,
152 "Only full matrices are supported.");
153
// gemm dimensions: op(lhs) is M x K, op(rhs) is K x N, result is M x N.
// Assumes GetRows()/GetColumns() report the dimensions after the transpose
// flag is applied — confirm in NekMatrix.
154 unsigned int M = lhs.GetRows();
155 unsigned int N = rhs.GetColumns();
156 unsigned int K = lhs.GetColumns();
157
// Leading dimension of lhs's stored array: M untransposed, K when 'T'.
158 unsigned int LDA = M;
159 if (lhs.GetTransposeFlag() == 'T')
160 {
161 LDA = K;
162 }
163
// Leading dimension of rhs's stored array: K untransposed, N when 'T'.
164 unsigned int LDB = K;
165 if (rhs.GetTransposeFlag() == 'T')
166 {
167 LDB = N;
168 }
169
// The operands' scale factors are folded into gemm's alpha; ldc is
// result.GetRows().
170 Blas::Gemm(lhs.GetTransposeFlag(), rhs.GetTransposeFlag(), M, N, K,
171 lhs.Scale() * rhs.Scale(), lhs.GetRawPtr(), LDA, rhs.GetRawPtr(),
172 LDB, 0.0, result.GetRawPtr(), result.GetRows());
173}
174
175template <typename LhsDataType, typename RhsDataType, typename DataType,
176 typename LhsMatrixType, typename RhsMatrixType>
177void Multiply(NekMatrix<DataType, StandardMatrixTag> &result,
178 const NekMatrix<LhsDataType, LhsMatrixType> &lhs,
179 const NekMatrix<RhsDataType, RhsMatrixType> &rhs);
180
// NOTE(review): source lines 182-184 (the function name and parameters) and
// 189 (the rest of the enable_if condition) are missing from this listing.
// The body reads `result` and `rhs`, so this is presumably the in-place
// overload Multiply(result, rhs), i.e. result = result * rhs.
/// In-place product via BLAS.
/// The visible enable_if part requires rhs's raw NumberType to equal
/// RhsInnerType. op(result) * (rhs.Scale() * op(rhs)) is written into
/// result's temporary buffer. result is then resized to
/// result.GetRows() x rhs.GetColumns() and its temporary and data buffers
/// are swapped.
181template <typename RhsInnerType, typename RhsMatrixType>
185 [[maybe_unused]] typename std::enable_if<
186 std::is_same<RawType_t<typename NekMatrix<RhsInnerType,
187 RhsMatrixType>::NumberType>,
188 RhsInnerType>::value &&
190 0)
191{
// NOTE(review): only the storage types are checked; there is no check that
// result.GetColumns() == rhs.GetRows().
192 ASSERTL0(result.GetType() == eFULL && rhs.GetType() == eFULL,
193 "Only full matrices supported.");
// gemm dimensions: op(result) is M x K, op(rhs) is K x N.
194 unsigned int M = result.GetRows();
195 unsigned int N = rhs.GetColumns();
196 unsigned int K = result.GetColumns();
197
198 unsigned int LDA = M;
199 if (result.GetTransposeFlag() == 'T')
200 {
201 LDA = K;
202 }
203
204 unsigned int LDB = K;
205 if (rhs.GetTransposeFlag() == 'T')
206 {
207 LDB = N;
208 }
// NOTE(review): buf receives an M x N product, which is larger than result's
// current M x K storage when N > K — confirm GetTempSpace() is sized for it.
209 RhsInnerType scale = rhs.Scale();
210 Array<OneD, RhsInnerType> &buf = result.GetTempSpace();
// result's transpose flag is applied to operand A, but gemm writes buf
// untransposed. NOTE(review): confirm that SetSize/SwapTempAndDataBuffers
// reset the flag; otherwise a transposed `result` would be reinterpreted
// after the swap.
211 Blas::Gemm(result.GetTransposeFlag(), rhs.GetTransposeFlag(), M, N, K,
212 scale, result.GetRawPtr(), LDA, rhs.GetRawPtr(), LDB, 0.0,
213 buf.data(), result.GetRows());
214 result.SetSize(result.GetRows(), rhs.GetColumns());
215 result.SwapTempAndDataBuffers();
216}
217
// NOTE(review): source lines 219-221 (the function name and parameters), 226
// (the rest of the enable_if condition) and 235 (the start of the `temp`
// declaration) are missing from this listing.
/// Element-wise fallback for the in-place product result = result * rhs.
/// Selected when rhs's raw NumberType differs from DataType, or when the
/// remaining condition on the missing line 226 holds.
/// Each entry is accumulated into a temporary matrix, which is then assigned
/// back to result.
///
/// NOTE(review): BUG — for an (R x C) `result` and a (C x P) `rhs` the product
/// is R x P. However, `temp` is sized with result.GetColumns() (= C) columns,
/// as the visible tail on line 236 shows, and the j loop also runs to
/// result.GetColumns().
/// Unless rhs is square, this either reads rhs(k, j) out of range (P < C) or
/// drops output columns (P > C). Both the `temp` width and the j bound should
/// use rhs.GetColumns().
218template <typename DataType, typename RhsInnerType, typename RhsMatrixType>
222 [[maybe_unused]] typename std::enable_if<
223 !std::is_same<RawType_t<typename NekMatrix<RhsInnerType,
224 RhsMatrixType>::NumberType>,
225 DataType>::value ||
227 0)
228{
229 ASSERTL1(result.GetColumns() == rhs.GetRows(),
230 std::string("A left side matrix with column count ") +
231 std::to_string(result.GetColumns()) +
232 std::string(" and a right side matrix with row count ") +
233 std::to_string(rhs.GetRows()) +
234 std::string(" can't be multiplied."));
236 result.GetColumns());
237
238 for (unsigned int i = 0; i < result.GetRows(); ++i)
239 {
240 for (unsigned int j = 0; j < result.GetColumns(); ++j)
241 {
242 DataType t = DataType(0);
243
244 // Set the result(i,j) element.
// k runs over result.GetColumns(), which the assertion above requires to
// equal rhs.GetRows().
245 for (unsigned int k = 0; k < result.GetColumns(); ++k)
246 {
247 t += result(i, k) * rhs(k, j);
248 }
249 temp(i, j) = t;
250 }
251 }
252
253 result = temp;
254}
255
// NOTE(review): source lines 261-262 (the declarator and parameter list) and
// 265 (the inner type of the NumberType typedef) are missing from this
// listing. From the body, this is presumably Multiply(lhs, rhs) for two
// matrices.
/// Value-returning matrix-matrix product.
/// It allocates a StandardMatrixTag matrix of size
/// lhs.GetRows() x rhs.GetColumns(). Per the return type, its element type is
/// lhs's NumberType with const removed. It then fills that matrix via the
/// out-parameter Multiply(result, lhs, rhs) and returns it.
256template <typename LhsDataType, typename RhsDataType, typename LhsMatrixType,
257 typename RhsMatrixType>
258NekMatrix<typename std::remove_const<
259 typename NekMatrix<LhsDataType, LhsMatrixType>::NumberType>::type,
260 StandardMatrixTag>
263{
264 typedef typename std::remove_const<
266 NumberType;
267 NekMatrix<NumberType, StandardMatrixTag> result(lhs.GetRows(),
268 rhs.GetColumns());
269 Multiply(result, lhs, rhs);
270 return result;
271}
272
// NOTE(review): source lines 278-279 (the declarator and parameters) are
// missing from this listing; this is presumably operator*(lhs, rhs) for two
// NekMatrix operands.
/// Matrix-matrix `*`: forwards to Multiply(lhs, rhs) and returns a new
/// StandardMatrixTag matrix whose element type is lhs's NumberType with const
/// removed.
273template <typename LhsDataType, typename RhsDataType, typename LhsMatrixType,
274 typename RhsMatrixType>
275NekMatrix<typename std::remove_const<
276 typename NekMatrix<LhsDataType, LhsMatrixType>::NumberType>::type,
277 StandardMatrixTag>
280{
281 return Multiply(lhs, rhs);
282}
283
284///////////////////////////////////////////////////////////////////
285// Addition
286///////////////////////////////////////////////////////////////////
287
288template <typename DataType, typename RhsDataType, typename RhsMatrixType>
289void AddEqual(NekMatrix<DataType, StandardMatrixTag> &result,
290 const NekMatrix<RhsDataType, RhsMatrixType> &rhs);
291
292template <typename DataType, typename LhsDataType, typename LhsMatrixType,
293 typename RhsDataType, typename RhsMatrixType>
294void Add(NekMatrix<DataType, StandardMatrixTag> &result,
295 const NekMatrix<LhsDataType, LhsMatrixType> &lhs,
296 const NekMatrix<RhsDataType, RhsMatrixType> &rhs);
297
298template <typename LhsDataType, typename LhsMatrixType, typename RhsDataType,
299 typename RhsMatrixType>
300NekMatrix<typename NekMatrix<LhsDataType, LhsMatrixType>::NumberType,
301 StandardMatrixTag>
302Add(const NekMatrix<LhsDataType, LhsMatrixType> &lhs,
303 const NekMatrix<RhsDataType, RhsMatrixType> &rhs);
304
// NOTE(review): source lines 309-310 (the declarator and parameters) are
// missing from this listing; this is presumably operator+ for two NekMatrix
// operands.
/// Matrix `+`: forwards to Add(lhs, rhs) and returns a new StandardMatrixTag
/// matrix whose element type is lhs's NumberType.
305template <typename LhsDataType, typename LhsMatrixType, typename RhsDataType,
306 typename RhsMatrixType>
307NekMatrix<typename NekMatrix<LhsDataType, LhsMatrixType>::NumberType,
308 StandardMatrixTag>
311{
312 return Add(lhs, rhs);
313}
314
315template <typename DataType, typename LhsDataType, typename LhsMatrixType,
316 typename RhsDataType, typename RhsMatrixType>
317void AddNegatedLhs(NekMatrix<DataType, StandardMatrixTag> &result,
318 const NekMatrix<LhsDataType, LhsMatrixType> &lhs,
319 const NekMatrix<RhsDataType, RhsMatrixType> &rhs);
320
321template <typename DataType, typename RhsDataType, typename RhsMatrixType>
322void AddEqualNegatedLhs(NekMatrix<DataType, StandardMatrixTag> &result,
323 const NekMatrix<RhsDataType, RhsMatrixType> &rhs);
324
325////////////////////////////////////////////////////////////////////////////////
326// Subtraction
327////////////////////////////////////////////////////////////////////////////////
328template <typename DataType, typename LhsDataType, typename LhsMatrixType,
329 typename RhsDataType, typename RhsMatrixType>
330void Subtract(NekMatrix<DataType, StandardMatrixTag> &result,
331 const NekMatrix<LhsDataType, LhsMatrixType> &lhs,
332 const NekMatrix<RhsDataType, RhsMatrixType> &rhs);
333
334template <typename DataType, typename LhsDataType, typename LhsMatrixType,
335 typename RhsDataType, typename RhsMatrixType>
336void SubtractNegatedLhs(NekMatrix<DataType, StandardMatrixTag> &result,
337 const NekMatrix<LhsDataType, LhsMatrixType> &lhs,
338 const NekMatrix<RhsDataType, RhsMatrixType> &rhs);
339
340template <typename DataType, typename RhsDataType, typename RhsMatrixType>
341void SubtractEqual(NekMatrix<DataType, StandardMatrixTag> &result,
342 const NekMatrix<RhsDataType, RhsMatrixType> &rhs);
343
344template <typename DataType, typename RhsDataType, typename RhsMatrixType>
345void SubtractEqualNegatedLhs(NekMatrix<DataType, StandardMatrixTag> &result,
346 const NekMatrix<RhsDataType, RhsMatrixType> &rhs);
347
348template <typename LhsDataType, typename LhsMatrixType, typename RhsDataType,
349 typename RhsMatrixType>
350NekMatrix<typename NekMatrix<LhsDataType, LhsMatrixType>::NumberType,
351 StandardMatrixTag>
352Subtract(const NekMatrix<LhsDataType, LhsMatrixType> &lhs,
353 const NekMatrix<RhsDataType, RhsMatrixType> &rhs);
354
// NOTE(review): source lines 359-360 (the declarator and parameters) are
// missing from this listing; this page's symbol list gives them as
//   operator-(const NekMatrix<LhsDataType, LhsMatrixType> &lhs,
//             const NekMatrix<RhsDataType, RhsMatrixType> &rhs)
/// Matrix `-`: forwards to Subtract(lhs, rhs) and returns a new
/// StandardMatrixTag matrix whose element type is lhs's NumberType.
355template <typename LhsDataType, typename LhsMatrixType, typename RhsDataType,
356 typename RhsMatrixType>
357NekMatrix<typename NekMatrix<LhsDataType, LhsMatrixType>::NumberType,
358 StandardMatrixTag>
361{
362 return Subtract(lhs, rhs);
363}
364
365} // namespace Nektar
366
367#endif
#define ASSERTL0(condition, msg)
Definition: ErrorUtil.hpp:208
#define ASSERTL1(condition, msg)
Assert Level 1 – Debugging which is used whether in FULLDEBUG or DEBUG compilation mode....
Definition: ErrorUtil.hpp:242
#define LIB_UTILITIES_EXPORT
unsigned int GetRows() const
Definition: MatrixBase.cpp:61
unsigned int GetColumns() const
Definition: MatrixBase.cpp:73
static void Gemm(const char &transa, const char &transb, const int &m, const int &n, const int &k, const double &alpha, const double *a, const int &lda, const double *b, const int &ldb, const double &beta, double *c, const int &ldc)
BLAS level 3: Matrix-matrix multiply C = A x B where op(A)[m x k], op(B)[k x n], C[m x n] DGEMM perfo...
Definition: Blas.hpp:355
void SubtractEqual(NekMatrix< DataType, StandardMatrixTag > &result, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)
NekMatrix< typename NekMatrix< LhsDataType, LhsMatrixType >::NumberType, StandardMatrixTag > operator-(const NekMatrix< LhsDataType, LhsMatrixType > &lhs, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)
void AddEqualNegatedLhs(NekMatrix< DataType, StandardMatrixTag > &result, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)
void AddEqual(NekMatrix< DataType, StandardMatrixTag > &result, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)
void Subtract(NekMatrix< DataType, StandardMatrixTag > &result, const NekMatrix< LhsDataType, LhsMatrixType > &lhs, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)
Array< OneD, DataType > operator+(const Array< OneD, DataType > &lhs, typename Array< OneD, DataType >::size_type offset)
void Multiply(NekMatrix< ResultDataType, StandardMatrixTag > &result, const NekMatrix< LhsDataType, LhsMatrixType > &lhs, const ResultDataType &rhs)
void NekMultiplyFullMatrixFullMatrix(NekMatrix< ResultType, StandardMatrixTag > &result, const NekMatrix< LhsDataType, LhsMatrixType > &lhs, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)
NekVector< DataType > operator*(const NekMatrix< LhsDataType, MatrixType > &lhs, const NekVector< DataType > &rhs)
void AddNegatedLhs(NekMatrix< DataType, StandardMatrixTag > &result, const NekMatrix< LhsDataType, LhsMatrixType > &lhs, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)
void DiagonalBlockFullScalMatrixMultiply(NekVector< double > &result, const NekMatrix< NekMatrix< NekMatrix< NekDouble, StandardMatrixTag >, ScaledMatrixTag >, BlockMatrixTag > &lhs, const NekVector< double > &rhs)
void SubtractEqualNegatedLhs(NekMatrix< DataType, StandardMatrixTag > &result, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)
typename RawType< T >::type RawType_t
Definition: RawType.hpp:70
void MultiplyEqual(NekMatrix< LhsDataType, StandardMatrixTag > &lhs, typename boost::call_traits< LhsDataType >::const_reference rhs)
void SubtractNegatedLhs(NekMatrix< DataType, StandardMatrixTag > &result, const NekMatrix< LhsDataType, LhsMatrixType > &lhs, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)
void Add(NekMatrix< DataType, StandardMatrixTag > &result, const NekMatrix< LhsDataType, LhsMatrixType > &lhs, const NekMatrix< RhsDataType, RhsMatrixType > &rhs)