32 #ifndef NEKTAR_LIBUTILITIES_LINEAR_ALGEBRA_MATRIX_SIZE_HPP
33 #define NEKTAR_LIBUTILITIES_LINEAR_ALGEBRA_MATRIX_SIZE_HPP
35 #ifdef NEKTAR_USE_EXPRESSION_TEMPLATES
37 #include <boost/utility/enable_if.hpp>
38 #include <boost/mpl/or.hpp>
39 #include <boost/tuple/tuple.hpp>
41 #include <ExpressionTemplates/Node.hpp>
42 #include <ExpressionTemplates/Operators.hpp>
// Primary template for MatrixSize: computes the size information for an
// expression-template node type (NodeType). Indices is an MPL sequence that maps
// node positions to argument positions, and index is the position of this node's
// first argument. The specialisations for leaf, unary and binary expt::Node types
// each provide a static GetRequiredSize(args).
// NOTE(review): the declaration line following this template header is not
// visible in this extract.
49 template<
typename NodeType,
typename Indices,
unsigned int index>
// Leaf specialisation: the node wraps a single argument of type T (no operator,
// no right child). Its size is read directly from the matching entry of the
// argument sequence.
53 template<
typename T,
typename Indices,
unsigned int index>
54 struct MatrixSize<expt::Node<T, void, void>, Indices, index>
// Position in the argument sequence that corresponds to this leaf, obtained by
// looking up element `index` of the Indices MPL sequence.
56 static const unsigned int MappedIndex = boost::mpl::at_c<Indices, index>::type::value;
// Row count of a matrix argument (overload selected when IsMatrix<R> holds).
// NOTE(review): the `template<typename R>` header for these overloads is not
// visible in this extract.
59 static unsigned int GetRequiredRowsFromMatrix(
const R& matrix,
60 typename boost::enable_if<IsMatrix<R> >::type* dummy = 0)
62 return matrix.GetRows();
// Column count of a matrix argument (overload selected when IsMatrix<R> holds).
66 static unsigned int GetRequiredColumnsFromMatrix(
const R& matrix,
67 typename boost::enable_if<IsMatrix<R> >::type* dummy = 0)
69 return matrix.GetColumns();
// Row count of a vector argument (overload selected when IsVector<R> holds);
// also taken from GetRows().
73 static unsigned int GetRequiredRowsFromMatrix(
const R& matrix,
74 typename boost::enable_if<IsVector<R> >::type* dummy = 0)
76 return matrix.GetRows();
// Column count of a vector argument (IsVector<R>).
// NOTE(review): the body is not visible here; presumably it returns 1 for a
// column vector. TODO confirm.
80 static unsigned int GetRequiredColumnsFromMatrix(
const R& matrix,
81 typename boost::enable_if<IsVector<R> >::type* dummy = 0)
// Fallback row overload, removed via disable_if for some set of types. The
// condition is truncated in this extract; presumably it covers arguments that
// are neither matrices nor vectors (e.g. scalars). TODO confirm.
87 static unsigned int GetRequiredRowsFromMatrix(
const R& matrix,
88 typename boost::disable_if
// Fallback column overload; the condition is truncated in this extract
// (see the note on the row fallback directly above).
102 static unsigned int GetRequiredColumnsFromMatrix(
const R& matrix,
103 typename boost::disable_if
// Returns (rows, columns, rows*columns) for the argument at MappedIndex.
// The third element is the element count of that argument.
115 template<
typename ArgumentVectorType>
116 static boost::tuple<unsigned int, unsigned int, unsigned int>
117 GetRequiredSize(
const ArgumentVectorType& args)
119 unsigned int rows = GetRequiredRowsFromMatrix(boost::fusion::at_c<MappedIndex>(args));
120 unsigned int columns = GetRequiredColumnsFromMatrix(boost::fusion::at_c<MappedIndex>(args));
122 return boost::make_tuple(rows, columns, rows*columns);
// Unary-node specialisation: a node with one child and an operator Op.
// The size is simply the size of the child expression.
// NOTE(review): the line declaring the `Op` template parameter (original
// line 130) is missing from this extract.
129 template<
typename ChildType,
131 typename Indices,
unsigned int index>
132 struct MatrixSize<expt::Node<ChildType, Op, void>, Indices, index>
// Forwards to the child node's MatrixSize at the same starting index.
134 template<
typename ArgumentVectorType>
135 static boost::tuple<unsigned int, unsigned int, unsigned int>
136 GetRequiredSize(
const ArgumentVectorType& args)
138 return MatrixSize<ChildType, Indices, index>::GetRequiredSize(args);
// Convenience accessor: the first element (row count) of GetRequiredSize(args).
// NOTE(review): the return-type line (original line 142) is not visible here.
141 template<
typename ArgumentVectorType>
143 GetRequiredRows(
const ArgumentVectorType& args)
145 boost::tuple<unsigned int, unsigned int, unsigned int> values = GetRequiredSize(args);
146 return values.get<0>();
// Size computation for a product-like combination of two sub-expressions.
// Returns (leftRows, rightColumns, bufferSize), where bufferSize is the largest
// of: the left operand's buffer size, the right operand's buffer size, and
// leftRows*rightColumns.
153 template<
typename LeftNodeType,
typename RightNodeType,
typename Indices,
unsigned int index>
154 struct CalculateLargestRequiredSize
156 template<
typename ArgumentVectorType>
157 static boost::tuple<unsigned int, unsigned int, unsigned int>
158 GetRequiredSize(
const ArgumentVectorType& args)
160 boost::tuple<unsigned int, unsigned int, unsigned int> lhsSizes =
161 MatrixSize<LeftNodeType, Indices, index>::GetRequiredSize(args);
// The right operand's arguments start after the left operand's, hence the
// offset by LeftNodeType::TotalCount (presumably the number of leaves in
// the left subtree; TODO confirm).
163 boost::tuple<unsigned int, unsigned int, unsigned int> rhsSizes =
164 MatrixSize<RightNodeType, Indices, index + LeftNodeType::TotalCount>::GetRequiredSize(args);
// Result shape: rows of the left operand by columns of the right operand.
166 unsigned int leftRows = lhsSizes.get<0>();
167 unsigned int rightColumns = rhsSizes.get<1>();
169 unsigned int matrixSize = leftRows*rightColumns;
// Keep the largest requirement so one buffer can serve either operand's
// evaluation as well as the result.
170 unsigned int bufferSize = std::max(std::max(lhsSizes.get<2>(), rhsSizes.get<2>()), matrixSize);
172 return boost::make_tuple(leftRows, rightColumns, bufferSize);
// Primary declaration (no definition) of the binary-node size evaluator.
// The defaulted `enabled` parameter exists so that the partial specialisations
// can be selected with boost::enable_if conditions on the operand and operator
// types.
177 template<
typename LhsType,
typename OpType,
typename RhsType,
typename Indices,
unsigned int index,
typename enabled=
void>
178 struct BinaryMatrixSizeEvaluator;
// Specialisation for addition and subtraction: enabled when OpType is
// expt::AddOp or expt::SubtractOp (the combining metafunction is on a line
// not visible in this extract).
181 template<
typename LhsType,
typename OpType,
typename RhsType,
typename Indices,
unsigned int index>
182 struct BinaryMatrixSizeEvaluator<LhsType, OpType, RhsType, Indices, index,
183 typename boost::enable_if
187 boost::is_same<OpType, expt::AddOp>,
188 boost::is_same<OpType, expt::SubtractOp>
// Computes the left operand's sizes.
// NOTE(review): the rest of the body is not visible here; presumably it
// returns lhsSizes, since both operands of +/- share a shape. TODO confirm.
192 template<
typename ArgumentVectorType>
193 static boost::tuple<unsigned int, unsigned int, unsigned int>
194 GetRequiredSize(
const ArgumentVectorType& args)
196 boost::tuple<unsigned int, unsigned int, unsigned int> lhsSizes =
197 MatrixSize<LhsType, Indices, index>::GetRequiredSize(args);
// Specialisation for a left operand that is a constant node whose ResultType is
// double (a scalar on the left).
// NOTE(review): further enable_if conditions (e.g. on OpType) sit on lines not
// visible in this extract.
203 template<
typename LhsType,
typename OpType,
typename RhsType,
typename Indices,
unsigned int index>
204 struct BinaryMatrixSizeEvaluator<LhsType, OpType, RhsType, Indices, index,
205 typename boost::enable_if
209 expt::IsConstantNode<LhsType>,
210 boost::is_same<typename LhsType::ResultType, double>
// Sizes come from the right operand, whose arguments start after the left
// operand's (offset LhsType::TotalCount).
// NOTE(review): the rest of the body is not visible; presumably it returns
// rhsSizes. TODO confirm.
214 template<
typename ArgumentVectorType>
215 static boost::tuple<unsigned int, unsigned int, unsigned int>
216 GetRequiredSize(
const ArgumentVectorType& args)
218 boost::tuple<unsigned int, unsigned int, unsigned int> rhsSizes =
219 MatrixSize<RhsType, Indices, index + LhsType::TotalCount>::GetRequiredSize(args);
// Specialisation for a right operand that is a constant node whose ResultType
// is double (a scalar on the right).
// NOTE(review): further enable_if conditions sit on lines not visible in this
// extract.
226 template<
typename LhsType,
typename OpType,
typename RhsType,
typename Indices,
unsigned int index>
227 struct BinaryMatrixSizeEvaluator<LhsType, OpType, RhsType, Indices, index,
228 typename boost::enable_if
232 expt::IsConstantNode<RhsType>,
233 boost::is_same<typename RhsType::ResultType, double>
// Sizes come from the left operand at the current index.
// NOTE(review): the rest of the body is not visible; presumably it returns
// lhsSizes. TODO confirm.
237 template<
typename ArgumentVectorType>
238 static boost::tuple<unsigned int, unsigned int, unsigned int>
239 GetRequiredSize(
const ArgumentVectorType& args)
241 boost::tuple<unsigned int, unsigned int, unsigned int> lhsSizes =
242 MatrixSize<LhsType, Indices, index>::GetRequiredSize(args);
// Specialisation for expt::MultiplyOp. The visible conditions negate both
// "constant node with non-double ResultType" and "constant node with double
// ResultType" for each operand, so this appears to be the general case
// (neither side is a constant scalar or a constant non-double node).
// NOTE(review): the metafunctions combining these conditions are on lines not
// visible in this extract.
248 template<
typename LhsType,
typename RhsType,
typename Indices,
unsigned int index>
249 struct BinaryMatrixSizeEvaluator<LhsType, expt::MultiplyOp, RhsType, Indices, index,
250 typename boost::enable_if
256 boost::mpl::not_<boost::mpl::and_
260 expt::IsConstantNode<RhsType>,
261 boost::mpl::not_<boost::is_same<typename RhsType::ResultType, double> >
265 expt::IsConstantNode<LhsType>,
266 boost::mpl::not_<boost::is_same<typename LhsType::ResultType, double> >
271 boost::mpl::not_<boost::mpl::and_
273 expt::IsConstantNode<RhsType>,
274 boost::is_same<typename RhsType::ResultType, double>
276 boost::mpl::not_<boost::mpl::and_
278 expt::IsConstantNode<LhsType>,
279 boost::is_same<typename LhsType::ResultType, double>
// Delegates to CalculateLargestRequiredSize: result is leftRows x rightColumns,
// with a buffer big enough for either operand's evaluation or the result.
284 template<
typename ArgumentVectorType>
285 static boost::tuple<unsigned int, unsigned int, unsigned int>
286 GetRequiredSize(
const ArgumentVectorType& args)
288 return CalculateLargestRequiredSize<LhsType, RhsType, Indices, index>::GetRequiredSize(args);
// Specialisation for expt::MultiplyOp where an operand is a constant node whose
// ResultType is not double (the visible conditions test this for both RhsType
// and LhsType; how they are combined is on lines not visible here).
292 template<
typename LhsType,
typename RhsType,
typename Indices,
unsigned int index>
293 struct BinaryMatrixSizeEvaluator<LhsType, expt::MultiplyOp, RhsType, Indices, index,
294 typename boost::enable_if
301 expt::IsConstantNode<RhsType>,
302 boost::mpl::not_<boost::is_same<typename RhsType::ResultType, double> >
306 expt::IsConstantNode<LhsType>,
307 boost::mpl::not_<boost::is_same<typename LhsType::ResultType, double> >
// Returns (leftRows, rightColumns, leftRows*rightColumns).
// NOTE(review): unlike CalculateLargestRequiredSize, the buffer size here is
// not max'ed with the operands' own buffer sizes. Presumably this is because
// constant operands need no temporary storage; confirm this holds when the
// non-constant side is itself a compound expression.
312 template<
typename ArgumentVectorType>
313 static boost::tuple<unsigned int, unsigned int, unsigned int>
314 GetRequiredSize(
const ArgumentVectorType& args)
316 boost::tuple<unsigned int, unsigned int, unsigned int> lhsSizes =
317 MatrixSize<LhsType, Indices, index>::GetRequiredSize(args);
319 boost::tuple<unsigned int, unsigned int, unsigned int> rhsSizes =
320 MatrixSize<RhsType, Indices, index + LhsType::TotalCount>::GetRequiredSize(args);
322 unsigned int leftRows = lhsSizes.get<0>();
323 unsigned int rightColumns = rhsSizes.get<1>();
325 unsigned int bufferSize = leftRows*rightColumns;
327 return boost::make_tuple(leftRows, rightColumns, bufferSize);
// Binary-node specialisation: both children are themselves expt::Node types.
// Size computation is dispatched to the BinaryMatrixSizeEvaluator
// specialisation selected by (LhsType, Op, RhsType).
// NOTE(review): the line declaring the `Op` template parameter (original
// line 561) is missing from this extract.
560 template<
typename L1,
typename LOp,
typename L2,
562 typename R1,
typename ROp,
typename R2,
563 typename Indices,
unsigned int index>
564 struct MatrixSize<expt::Node<expt::Node<L1, LOp, L2>, Op, expt::Node<R1, ROp, R2> >, Indices, index>
566 typedef expt::Node<L1, LOp, L2> LhsType;
567 typedef expt::Node<R1, ROp, R2> RhsType;
// Returns the (rows, columns, buffer size) tuple computed by the matching
// BinaryMatrixSizeEvaluator.
569 template<
typename ArgumentVectorType>
570 static boost::tuple<unsigned int, unsigned int, unsigned int>
571 GetRequiredSize(
const ArgumentVectorType& args)
573 return BinaryMatrixSizeEvaluator<LhsType, Op, RhsType, Indices, index>::GetRequiredSize(args);
// Convenience accessor: the first element (row count) of GetRequiredSize(args).
// NOTE(review): the return-type line (original line 577) is not visible here.
576 template<
typename ArgumentVectorType>
578 GetRequiredRows(
const ArgumentVectorType& args)
580 boost::tuple<unsigned int, unsigned int, unsigned int> values = GetRequiredSize(args);
581 return values.get<0>();
587 #endif //NEKTAR_USE_EXPRESSION_TEMPLATES
588 #endif //NEKTAR_LIBUTILITIES_LINEAR_ALGEBRA_MATRIX_SIZE_HPP