13 #ifndef EIGEN_PRODUCTEVALUATORS_H
14 #define EIGEN_PRODUCTEVALUATORS_H
28 template<
typename Lhs,
typename Rhs,
int Options>
29 struct evaluator<Product<Lhs, Rhs, Options> >
30 :
public product_evaluator<Product<Lhs, Rhs, Options> >
32 typedef Product<Lhs, Rhs, Options> XprType;
33 typedef product_evaluator<XprType> Base;
35 EIGEN_DEVICE_FUNC
explicit evaluator(
const XprType& xpr) : Base(xpr) {}
40 template<
typename Lhs,
typename Rhs,
typename Scalar>
41 struct evaluator<CwiseUnaryOp<
internal::scalar_multiple_op<Scalar>, const Product<Lhs, Rhs, DefaultProduct> > >
42 :
public evaluator<Product<CwiseUnaryOp<internal::scalar_multiple_op<Scalar>,const Lhs>, Rhs, DefaultProduct> >
44 typedef CwiseUnaryOp<internal::scalar_multiple_op<Scalar>,
const Product<Lhs, Rhs, DefaultProduct> > XprType;
45 typedef evaluator<Product<CwiseUnaryOp<internal::scalar_multiple_op<Scalar>,
const Lhs>, Rhs, DefaultProduct> > Base;
47 EIGEN_DEVICE_FUNC
explicit evaluator(
const XprType& xpr)
48 : Base(xpr.functor().m_other * xpr.nestedExpression().lhs() * xpr.nestedExpression().rhs())
53 template<
typename Lhs,
typename Rhs,
int DiagIndex>
54 struct evaluator<Diagonal<const Product<Lhs, Rhs, DefaultProduct>, DiagIndex> >
55 :
public evaluator<Diagonal<const Product<Lhs, Rhs, LazyProduct>, DiagIndex> >
57 typedef Diagonal<const Product<Lhs, Rhs, DefaultProduct>, DiagIndex> XprType;
58 typedef evaluator<Diagonal<const Product<Lhs, Rhs, LazyProduct>, DiagIndex> > Base;
60 EIGEN_DEVICE_FUNC
explicit evaluator(
const XprType& xpr)
61 : Base(Diagonal<const Product<Lhs, Rhs, LazyProduct>, DiagIndex>(
62 Product<Lhs, Rhs, LazyProduct>(xpr.nestedExpression().lhs(), xpr.nestedExpression().rhs()),
71 template<
typename Lhs,
typename Rhs,
72 typename LhsShape =
typename evaluator_traits<Lhs>::Shape,
73 typename RhsShape =
typename evaluator_traits<Rhs>::Shape,
74 int ProductType = internal::product_type<Lhs,Rhs>::value>
75 struct generic_product_impl;
77 template<
typename Lhs,
typename Rhs>
78 struct evaluator_traits<Product<Lhs, Rhs, DefaultProduct> >
79 : evaluator_traits_base<Product<Lhs, Rhs, DefaultProduct> >
81 enum { AssumeAliasing = 1 };
84 template<
typename Lhs,
typename Rhs>
85 struct evaluator_traits<Product<Lhs, Rhs, AliasFreeProduct> >
86 : evaluator_traits_base<Product<Lhs, Rhs, AliasFreeProduct> >
88 enum { AssumeAliasing = 0 };
93 template<
typename Lhs,
typename Rhs,
int Options,
int ProductTag,
typename LhsShape,
typename RhsShape>
94 struct product_evaluator<Product<Lhs, Rhs, Options>, ProductTag, LhsShape, RhsShape, typename traits<Lhs>::Scalar, typename traits<Rhs>::Scalar,
95 EnableIf<(Options==DefaultProduct || Options==AliasFreeProduct)> >
96 :
public evaluator<typename Product<Lhs, Rhs, Options>::PlainObject>
98 typedef Product<Lhs, Rhs, Options> XprType;
99 typedef typename XprType::PlainObject PlainObject;
100 typedef evaluator<PlainObject> Base;
105 EIGEN_DEVICE_FUNC
explicit product_evaluator(
const XprType& xpr)
106 : m_result(xpr.rows(), xpr.cols())
108 ::new (static_cast<Base*>(
this)) Base(m_result);
122 generic_product_impl<Lhs, Rhs, LhsShape, RhsShape, ProductTag>::evalTo(m_result, xpr.lhs(), xpr.rhs());
126 PlainObject m_result;
130 template< typename DstXprType, typename Lhs, typename Rhs,
int Options, typename Scalar>
131 struct Assignment<DstXprType, Product<Lhs,Rhs,Options>,
internal::assign_op<Scalar>, Dense2Dense,
132 typename enable_if<(Options==DefaultProduct || Options==AliasFreeProduct),Scalar>::type>
134 typedef Product<Lhs,Rhs,Options> SrcXprType;
135 static void run(DstXprType &dst,
const SrcXprType &src,
const internal::assign_op<Scalar> &)
138 generic_product_impl<Lhs, Rhs>::evalTo(dst, src.lhs(), src.rhs());
143 template<
typename DstXprType,
typename Lhs,
typename Rhs,
int Options,
typename Scalar>
144 struct Assignment<DstXprType, Product<Lhs,Rhs,Options>,
internal::add_assign_op<Scalar>, Dense2Dense,
145 typename enable_if<(Options==DefaultProduct || Options==AliasFreeProduct),Scalar>::type>
147 typedef Product<Lhs,Rhs,Options> SrcXprType;
148 static void run(DstXprType &dst,
const SrcXprType &src,
const internal::add_assign_op<Scalar> &)
151 generic_product_impl<Lhs, Rhs>::addTo(dst, src.lhs(), src.rhs());
156 template<
typename DstXprType,
typename Lhs,
typename Rhs,
int Options,
typename Scalar>
157 struct Assignment<DstXprType, Product<Lhs,Rhs,Options>,
internal::sub_assign_op<Scalar>, Dense2Dense,
158 typename enable_if<(Options==DefaultProduct || Options==AliasFreeProduct),Scalar>::type>
160 typedef Product<Lhs,Rhs,Options> SrcXprType;
161 static void run(DstXprType &dst,
const SrcXprType &src,
const internal::sub_assign_op<Scalar> &)
164 generic_product_impl<Lhs, Rhs>::subTo(dst, src.lhs(), src.rhs());
172 template<
typename DstXprType,
typename Lhs,
typename Rhs,
typename AssignFunc,
typename Scalar,
typename ScalarBis>
173 struct Assignment<DstXprType, CwiseUnaryOp<
internal::scalar_multiple_op<ScalarBis>,
174 const Product<Lhs,Rhs,DefaultProduct> >, AssignFunc, Dense2Dense, Scalar>
176 typedef CwiseUnaryOp<internal::scalar_multiple_op<ScalarBis>,
177 const Product<Lhs,Rhs,DefaultProduct> > SrcXprType;
178 static void run(DstXprType &dst,
const SrcXprType &src,
const AssignFunc& func)
181 call_assignment(dst.noalias(), prod(src.functor().m_other * src.nestedExpression().lhs(), src.nestedExpression().rhs()), func);
186 template<
typename Lhs,
typename Rhs>
187 struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,InnerProduct>
189 template<
typename Dst>
190 static inline void evalTo(Dst& dst,
const Lhs& lhs,
const Rhs& rhs)
192 dst.coeffRef(0,0) = (lhs.transpose().cwiseProduct(rhs)).sum();
195 template<
typename Dst>
196 static inline void addTo(Dst& dst,
const Lhs& lhs,
const Rhs& rhs)
198 dst.coeffRef(0,0) += (lhs.transpose().cwiseProduct(rhs)).sum();
201 template<
typename Dst>
202 static void subTo(Dst& dst,
const Lhs& lhs,
const Rhs& rhs)
203 { dst.coeffRef(0,0) -= (lhs.transpose().cwiseProduct(rhs)).sum(); }
212 template<
typename Dst,
typename Lhs,
typename Rhs,
typename Func>
213 EIGEN_DONT_INLINE
void outer_product_selector_run(Dst& dst,
const Lhs &lhs,
const Rhs &rhs,
const Func& func,
const false_type&)
215 evaluator<Rhs> rhsEval(rhs);
219 const Index cols = dst.cols();
220 for (Index j=0; j<cols; ++j)
221 func(dst.col(j), rhsEval.coeff(0,j) * lhs);
225 template<
typename Dst,
typename Lhs,
typename Rhs,
typename Func>
226 EIGEN_DONT_INLINE
void outer_product_selector_run(Dst& dst,
const Lhs &lhs,
const Rhs &rhs,
const Func& func,
const true_type&)
228 evaluator<Lhs> lhsEval(lhs);
232 const Index rows = dst.rows();
233 for (Index i=0; i<rows; ++i)
234 func(dst.row(i), lhsEval.coeff(i,0) * rhs);
237 template<
typename Lhs,
typename Rhs>
238 struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,OuterProduct>
240 template<
typename T>
struct IsRowMajor : internal::conditional<(int(T::Flags)&RowMajorBit), internal::true_type, internal::false_type>::type {};
241 typedef typename Product<Lhs,Rhs>::Scalar Scalar;
244 struct set {
template<
typename Dst,
typename Src>
void operator()(
const Dst& dst,
const Src& src)
const { dst.const_cast_derived() = src; } };
245 struct add {
template<
typename Dst,
typename Src>
void operator()(
const Dst& dst,
const Src& src)
const { dst.const_cast_derived() += src; } };
246 struct sub {
template<
typename Dst,
typename Src>
void operator()(
const Dst& dst,
const Src& src)
const { dst.const_cast_derived() -= src; } };
249 explicit adds(
const Scalar& s) : m_scale(s) {}
250 template<
typename Dst,
typename Src>
void operator()(
const Dst& dst,
const Src& src)
const {
251 dst.const_cast_derived() += m_scale * src;
255 template<
typename Dst>
256 static inline void evalTo(Dst& dst,
const Lhs& lhs,
const Rhs& rhs)
258 internal::outer_product_selector_run(dst, lhs, rhs, set(), IsRowMajor<Dst>());
261 template<
typename Dst>
262 static inline void addTo(Dst& dst,
const Lhs& lhs,
const Rhs& rhs)
264 internal::outer_product_selector_run(dst, lhs, rhs, add(), IsRowMajor<Dst>());
267 template<
typename Dst>
268 static inline void subTo(Dst& dst,
const Lhs& lhs,
const Rhs& rhs)
270 internal::outer_product_selector_run(dst, lhs, rhs, sub(), IsRowMajor<Dst>());
273 template<
typename Dst>
274 static inline void scaleAndAddTo(Dst& dst,
const Lhs& lhs,
const Rhs& rhs,
const Scalar& alpha)
276 internal::outer_product_selector_run(dst, lhs, rhs, adds(alpha), IsRowMajor<Dst>());
283 template<
typename Lhs,
typename Rhs,
typename Derived>
284 struct generic_product_impl_base
286 typedef typename Product<Lhs,Rhs>::Scalar Scalar;
288 template<
typename Dst>
289 static void evalTo(Dst& dst,
const Lhs& lhs,
const Rhs& rhs)
290 { dst.setZero(); scaleAndAddTo(dst, lhs, rhs, Scalar(1)); }
292 template<
typename Dst>
293 static void addTo(Dst& dst,
const Lhs& lhs,
const Rhs& rhs)
294 { scaleAndAddTo(dst,lhs, rhs, Scalar(1)); }
296 template<
typename Dst>
297 static void subTo(Dst& dst,
const Lhs& lhs,
const Rhs& rhs)
298 { scaleAndAddTo(dst, lhs, rhs, Scalar(-1)); }
300 template<
typename Dst>
301 static void scaleAndAddTo(Dst& dst,
const Lhs& lhs,
const Rhs& rhs,
const Scalar& alpha)
302 { Derived::scaleAndAddTo(dst,lhs,rhs,alpha); }
306 template<
typename Lhs,
typename Rhs>
307 struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,GemvProduct>
308 : generic_product_impl_base<Lhs,Rhs,generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,GemvProduct> >
310 typedef typename Product<Lhs,Rhs>::Scalar Scalar;
312 typedef typename internal::conditional<int(Side)==OnTheRight,Lhs,Rhs>::type MatrixType;
314 template<
typename Dest>
315 static void scaleAndAddTo(Dest& dst,
const Lhs& lhs,
const Rhs& rhs,
const Scalar& alpha)
317 internal::gemv_dense_sense_selector<Side,
319 bool(internal::blas_traits<MatrixType>::HasUsableDirectAccess)
320 >::run(lhs, rhs, dst, alpha);
324 template<
typename Lhs,
typename Rhs>
325 struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,CoeffBasedProductMode>
327 typedef typename Product<Lhs,Rhs>::Scalar Scalar;
329 template<
typename Dst>
330 static inline void evalTo(Dst& dst,
const Lhs& lhs,
const Rhs& rhs)
334 call_assignment(dst, lazyprod(lhs,rhs), internal::assign_op<Scalar>());
337 template<
typename Dst>
338 static inline void addTo(Dst& dst,
const Lhs& lhs,
const Rhs& rhs)
341 call_assignment(dst, lazyprod(lhs,rhs), internal::add_assign_op<Scalar>());
344 template<
typename Dst>
345 static inline void subTo(Dst& dst,
const Lhs& lhs,
const Rhs& rhs)
348 call_assignment(dst, lazyprod(lhs,rhs), internal::sub_assign_op<Scalar>());
357 template<
typename Lhs,
typename Rhs>
358 struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,LazyCoeffBasedProductMode>
359 : generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,CoeffBasedProductMode> {};
367 template<
int Traversal,
int UnrollingIndex,
typename Lhs,
typename Rhs,
typename RetScalar>
368 struct etor_product_coeff_impl;
370 template<
int StorageOrder,
int UnrollingIndex,
typename Lhs,
typename Rhs,
typename Packet,
int LoadMode>
371 struct etor_product_packet_impl;
373 template<
typename Lhs,
typename Rhs,
int ProductTag>
374 struct product_evaluator<Product<Lhs, Rhs, LazyProduct>, ProductTag, DenseShape, DenseShape, typename Lhs::Scalar, typename Rhs::Scalar >
375 : evaluator_base<Product<Lhs, Rhs, LazyProduct> >
377 typedef Product<Lhs, Rhs, LazyProduct> XprType;
378 typedef typename XprType::Scalar Scalar;
379 typedef typename XprType::CoeffReturnType CoeffReturnType;
380 typedef typename XprType::PacketScalar PacketScalar;
381 typedef typename XprType::PacketReturnType PacketReturnType;
383 EIGEN_DEVICE_FUNC
explicit product_evaluator(
const XprType& xpr)
389 m_innerDim(xpr.lhs().cols())
394 typedef typename internal::nested_eval<Lhs,Rhs::ColsAtCompileTime>::type LhsNested;
395 typedef typename internal::nested_eval<Rhs,Lhs::RowsAtCompileTime>::type RhsNested;
397 typedef typename internal::remove_all<LhsNested>::type LhsNestedCleaned;
398 typedef typename internal::remove_all<RhsNested>::type RhsNestedCleaned;
400 typedef evaluator<LhsNestedCleaned> LhsEtorType;
401 typedef evaluator<RhsNestedCleaned> RhsEtorType;
404 RowsAtCompileTime = LhsNestedCleaned::RowsAtCompileTime,
405 ColsAtCompileTime = RhsNestedCleaned::ColsAtCompileTime,
406 InnerSize = EIGEN_SIZE_MIN_PREFER_FIXED(LhsNestedCleaned::ColsAtCompileTime, RhsNestedCleaned::RowsAtCompileTime),
407 MaxRowsAtCompileTime = LhsNestedCleaned::MaxRowsAtCompileTime,
408 MaxColsAtCompileTime = RhsNestedCleaned::MaxColsAtCompileTime,
410 PacketSize = packet_traits<Scalar>::size,
412 LhsCoeffReadCost = LhsEtorType::CoeffReadCost,
413 RhsCoeffReadCost = RhsEtorType::CoeffReadCost,
414 CoeffReadCost = InnerSize==0 ? NumTraits<Scalar>::ReadCost
415 : (InnerSize == Dynamic || LhsCoeffReadCost==Dynamic || RhsCoeffReadCost==Dynamic || NumTraits<Scalar>::AddCost==Dynamic || NumTraits<Scalar>::MulCost==Dynamic) ? Dynamic
416 : InnerSize * (NumTraits<Scalar>::MulCost + LhsCoeffReadCost + RhsCoeffReadCost)
417 + (InnerSize - 1) * NumTraits<Scalar>::AddCost,
419 Unroll = CoeffReadCost != Dynamic && CoeffReadCost <= EIGEN_UNROLLING_LIMIT,
421 LhsFlags = LhsEtorType::Flags,
422 RhsFlags = RhsEtorType::Flags,
424 LhsAlignment = LhsEtorType::Alignment,
425 RhsAlignment = RhsEtorType::Alignment,
427 LhsIsAligned =
int(LhsAlignment) >= int(unpacket_traits<PacketScalar>::alignment),
428 RhsIsAligned = int(RhsAlignment) >= int(unpacket_traits<PacketScalar>::alignment),
433 SameType = is_same<typename LhsNestedCleaned::Scalar,typename RhsNestedCleaned::Scalar>::value,
436 && (ColsAtCompileTime == Dynamic || ( (ColsAtCompileTime % PacketSize) == 0 && RhsIsAligned ) ),
439 && (RowsAtCompileTime == Dynamic || ( (RowsAtCompileTime % PacketSize) == 0 && LhsIsAligned ) ),
441 EvalToRowMajor = (MaxRowsAtCompileTime==1&&MaxColsAtCompileTime!=1) ? 1
442 : (MaxColsAtCompileTime==1&&MaxRowsAtCompileTime!=1) ? 0
443 : (RhsRowMajor && !CanVectorizeLhs),
445 Flags = ((
unsigned int)(LhsFlags | RhsFlags) & HereditaryBits & ~RowMajorBit)
446 | (EvalToRowMajor ? RowMajorBit : 0)
448 | (SameType && (CanVectorizeLhs || CanVectorizeRhs) ?
PacketAccessBit : 0),
450 Alignment = CanVectorizeLhs ? LhsAlignment
451 : CanVectorizeRhs ? RhsAlignment
459 CanVectorizeInner = SameType
463 && (LhsIsAligned && RhsIsAligned)
464 && (InnerSize % packet_traits<Scalar>::size == 0)
467 EIGEN_DEVICE_FUNC
const CoeffReturnType coeff(Index row, Index col)
const
470 return (m_lhs.row(row).transpose().cwiseProduct( m_rhs.col(col) )).sum();
477 EIGEN_DEVICE_FUNC
const CoeffReturnType coeff(Index index)
const
479 const Index row = RowsAtCompileTime == 1 ? 0 : index;
480 const Index col = RowsAtCompileTime == 1 ? index : 0;
482 return (m_lhs.row(row).transpose().cwiseProduct( m_rhs.col(col) )).sum();
485 template<
int LoadMode,
typename PacketType>
486 const PacketType packet(Index row, Index col)
const
490 Unroll ? InnerSize : Dynamic,
491 LhsEtorType, RhsEtorType, PacketType, LoadMode> PacketImpl;
493 PacketImpl::run(row, col, m_lhsImpl, m_rhsImpl, m_innerDim, res);
498 const LhsNested m_lhs;
499 const RhsNested m_rhs;
501 LhsEtorType m_lhsImpl;
502 RhsEtorType m_rhsImpl;
508 template<
typename Lhs,
typename Rhs>
509 struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, LazyCoeffBasedProductMode, DenseShape, DenseShape, typename traits<Lhs>::Scalar, typename traits<Rhs>::Scalar >
510 : product_evaluator<Product<Lhs, Rhs, LazyProduct>, CoeffBasedProductMode, DenseShape, DenseShape, typename traits<Lhs>::Scalar, typename traits<Rhs>::Scalar >
512 typedef Product<Lhs, Rhs, DefaultProduct> XprType;
513 typedef Product<Lhs, Rhs, LazyProduct> BaseProduct;
514 typedef product_evaluator<BaseProduct, CoeffBasedProductMode, DenseShape, DenseShape, typename Lhs::Scalar, typename Rhs::Scalar > Base;
518 EIGEN_DEVICE_FUNC
explicit product_evaluator(
const XprType& xpr)
519 : Base(BaseProduct(xpr.lhs(),xpr.rhs()))
527 template<
int UnrollingIndex,
typename Lhs,
typename Rhs,
typename Packet,
int LoadMode>
528 struct etor_product_packet_impl<
RowMajor, UnrollingIndex, Lhs, Rhs, Packet, LoadMode>
530 static EIGEN_STRONG_INLINE
void run(Index row, Index col,
const Lhs& lhs,
const Rhs& rhs, Index innerDim, Packet &res)
532 etor_product_packet_impl<RowMajor, UnrollingIndex-1, Lhs, Rhs, Packet, LoadMode>::run(row, col, lhs, rhs, innerDim, res);
533 res = pmadd(pset1<Packet>(lhs.coeff(row, UnrollingIndex-1)), rhs.template packet<LoadMode,Packet>(UnrollingIndex-1, col), res);
537 template<
int UnrollingIndex,
typename Lhs,
typename Rhs,
typename Packet,
int LoadMode>
538 struct etor_product_packet_impl<
ColMajor, UnrollingIndex, Lhs, Rhs, Packet, LoadMode>
540 static EIGEN_STRONG_INLINE
void run(Index row, Index col,
const Lhs& lhs,
const Rhs& rhs, Index innerDim, Packet &res)
542 etor_product_packet_impl<ColMajor, UnrollingIndex-1, Lhs, Rhs, Packet, LoadMode>::run(row, col, lhs, rhs, innerDim, res);
543 res = pmadd(lhs.template packet<LoadMode,Packet>(row, UnrollingIndex-1), pset1<Packet>(rhs.coeff(UnrollingIndex-1, col)), res);
547 template<
typename Lhs,
typename Rhs,
typename Packet,
int LoadMode>
548 struct etor_product_packet_impl<
RowMajor, 1, Lhs, Rhs, Packet, LoadMode>
550 static EIGEN_STRONG_INLINE
void run(Index row, Index col,
const Lhs& lhs,
const Rhs& rhs, Index , Packet &res)
552 res = pmul(pset1<Packet>(lhs.coeff(row, 0)),rhs.template packet<LoadMode,Packet>(0, col));
556 template<
typename Lhs,
typename Rhs,
typename Packet,
int LoadMode>
557 struct etor_product_packet_impl<
ColMajor, 1, Lhs, Rhs, Packet, LoadMode>
559 static EIGEN_STRONG_INLINE
void run(Index row, Index col,
const Lhs& lhs,
const Rhs& rhs, Index , Packet &res)
561 res = pmul(lhs.template packet<LoadMode,Packet>(row, 0), pset1<Packet>(rhs.coeff(0, col)));
565 template<
typename Lhs,
typename Rhs,
typename Packet,
int LoadMode>
566 struct etor_product_packet_impl<
RowMajor, 0, Lhs, Rhs, Packet, LoadMode>
568 static EIGEN_STRONG_INLINE
void run(Index , Index ,
const Lhs& ,
const Rhs& , Index , Packet &res)
570 res = pset1<Packet>(0);
574 template<
typename Lhs,
typename Rhs,
typename Packet,
int LoadMode>
575 struct etor_product_packet_impl<
ColMajor, 0, Lhs, Rhs, Packet, LoadMode>
577 static EIGEN_STRONG_INLINE
void run(Index , Index ,
const Lhs& ,
const Rhs& , Index , Packet &res)
579 res = pset1<Packet>(0);
583 template<
typename Lhs,
typename Rhs,
typename Packet,
int LoadMode>
584 struct etor_product_packet_impl<
RowMajor, Dynamic, Lhs, Rhs, Packet, LoadMode>
586 static EIGEN_STRONG_INLINE
void run(Index row, Index col,
const Lhs& lhs,
const Rhs& rhs, Index innerDim, Packet& res)
588 res = pset1<Packet>(0);
589 for(Index i = 0; i < innerDim; ++i)
590 res = pmadd(pset1<Packet>(lhs.coeff(row, i)), rhs.template packet<LoadMode,Packet>(i, col), res);
594 template<
typename Lhs,
typename Rhs,
typename Packet,
int LoadMode>
595 struct etor_product_packet_impl<
ColMajor, Dynamic, Lhs, Rhs, Packet, LoadMode>
597 static EIGEN_STRONG_INLINE
void run(Index row, Index col,
const Lhs& lhs,
const Rhs& rhs, Index innerDim, Packet& res)
599 res = pset1<Packet>(0);
600 for(Index i = 0; i < innerDim; ++i)
601 res = pmadd(lhs.template packet<LoadMode,Packet>(row, i), pset1<Packet>(rhs.coeff(i, col)), res);
609 template<
int Mode,
bool LhsIsTriangular,
610 typename Lhs,
bool LhsIsVector,
611 typename Rhs,
bool RhsIsVector>
612 struct triangular_product_impl;
614 template<
typename Lhs,
typename Rhs,
int ProductTag>
615 struct generic_product_impl<Lhs,Rhs,TriangularShape,DenseShape,ProductTag>
616 : generic_product_impl_base<Lhs,Rhs,generic_product_impl<Lhs,Rhs,TriangularShape,DenseShape,ProductTag> >
618 typedef typename Product<Lhs,Rhs>::Scalar Scalar;
620 template<
typename Dest>
621 static void scaleAndAddTo(Dest& dst,
const Lhs& lhs,
const Rhs& rhs,
const Scalar& alpha)
623 triangular_product_impl<Lhs::Mode,true,typename Lhs::MatrixType,false,Rhs, Rhs::ColsAtCompileTime==1>
624 ::run(dst, lhs.nestedExpression(), rhs, alpha);
628 template<
typename Lhs,
typename Rhs,
int ProductTag>
629 struct generic_product_impl<Lhs,Rhs,DenseShape,TriangularShape,ProductTag>
630 : generic_product_impl_base<Lhs,Rhs,generic_product_impl<Lhs,Rhs,DenseShape,TriangularShape,ProductTag> >
632 typedef typename Product<Lhs,Rhs>::Scalar Scalar;
634 template<
typename Dest>
635 static void scaleAndAddTo(Dest& dst,
const Lhs& lhs,
const Rhs& rhs,
const Scalar& alpha)
637 triangular_product_impl<Rhs::Mode,false,Lhs,Lhs::RowsAtCompileTime==1, typename Rhs::MatrixType, false>::run(dst, lhs, rhs.nestedExpression(), alpha);
645 template <
typename Lhs,
int LhsMode,
bool LhsIsVector,
646 typename Rhs,
int RhsMode,
bool RhsIsVector>
647 struct selfadjoint_product_impl;
649 template<
typename Lhs,
typename Rhs,
int ProductTag>
650 struct generic_product_impl<Lhs,Rhs,SelfAdjointShape,DenseShape,ProductTag>
651 : generic_product_impl_base<Lhs,Rhs,generic_product_impl<Lhs,Rhs,SelfAdjointShape,DenseShape,ProductTag> >
653 typedef typename Product<Lhs,Rhs>::Scalar Scalar;
655 template<
typename Dest>
656 static void scaleAndAddTo(Dest& dst,
const Lhs& lhs,
const Rhs& rhs,
const Scalar& alpha)
658 selfadjoint_product_impl<typename Lhs::MatrixType,Lhs::Mode,false,Rhs,0,Rhs::IsVectorAtCompileTime>::run(dst, lhs.nestedExpression(), rhs, alpha);
662 template<
typename Lhs,
typename Rhs,
int ProductTag>
663 struct generic_product_impl<Lhs,Rhs,DenseShape,SelfAdjointShape,ProductTag>
664 : generic_product_impl_base<Lhs,Rhs,generic_product_impl<Lhs,Rhs,DenseShape,SelfAdjointShape,ProductTag> >
666 typedef typename Product<Lhs,Rhs>::Scalar Scalar;
668 template<
typename Dest>
669 static void scaleAndAddTo(Dest& dst,
const Lhs& lhs,
const Rhs& rhs,
const Scalar& alpha)
671 selfadjoint_product_impl<Lhs,0,Lhs::IsVectorAtCompileTime,typename Rhs::MatrixType,Rhs::Mode,false>::run(dst, lhs, rhs.nestedExpression(), alpha);
680 template<
typename MatrixType,
typename DiagonalType,
typename Derived,
int ProductOrder>
681 struct diagonal_product_evaluator_base
682 : evaluator_base<Derived>
684 typedef typename scalar_product_traits<typename MatrixType::Scalar, typename DiagonalType::Scalar>::ReturnType Scalar;
687 CoeffReadCost = NumTraits<Scalar>::MulCost + evaluator<MatrixType>::CoeffReadCost + evaluator<DiagonalType>::CoeffReadCost,
689 MatrixFlags = evaluator<MatrixType>::Flags,
690 DiagFlags = evaluator<DiagonalType>::Flags,
692 _ScalarAccessOnDiag = !((int(_StorageOrder) ==
ColMajor && int(ProductOrder) ==
OnTheLeft)
694 _SameTypes = is_same<typename MatrixType::Scalar, typename DiagonalType::Scalar>::value,
698 _LinearAccessMask = (MatrixType::RowsAtCompileTime==1 || MatrixType::ColsAtCompileTime==1) ?
LinearAccessBit : 0,
699 Flags = ((HereditaryBits|_LinearAccessMask) & (
unsigned int)(MatrixFlags)) | (_Vectorizable ?
PacketAccessBit : 0),
700 Alignment = evaluator<MatrixType>::Alignment
703 diagonal_product_evaluator_base(
const MatrixType &mat,
const DiagonalType &diag)
704 : m_diagImpl(diag), m_matImpl(mat)
708 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const Scalar coeff(Index idx)
const
710 return m_diagImpl.coeff(idx) * m_matImpl.coeff(idx);
714 template<
int LoadMode,
typename PacketType>
715 EIGEN_STRONG_INLINE PacketType packet_impl(Index row, Index col, Index
id, internal::true_type)
const
717 return internal::pmul(m_matImpl.template packet<LoadMode,PacketType>(row, col),
718 internal::pset1<PacketType>(m_diagImpl.coeff(
id)));
721 template<
int LoadMode,
typename PacketType>
722 EIGEN_STRONG_INLINE PacketType packet_impl(Index row, Index col, Index
id, internal::false_type)
const
725 InnerSize = (MatrixType::Flags &
RowMajorBit) ? MatrixType::ColsAtCompileTime : MatrixType::RowsAtCompileTime,
726 DiagonalPacketLoadMode = EIGEN_PLAIN_ENUM_MIN(LoadMode,((InnerSize%16) == 0) ? int(
Aligned16) : int(evaluator<DiagonalType>::Alignment))
728 return internal::pmul(m_matImpl.template packet<LoadMode,PacketType>(row, col),
729 m_diagImpl.template packet<DiagonalPacketLoadMode,PacketType>(
id));
732 evaluator<DiagonalType> m_diagImpl;
733 evaluator<MatrixType> m_matImpl;
737 template<
typename Lhs,
typename Rhs,
int ProductKind,
int ProductTag>
738 struct product_evaluator<Product<Lhs, Rhs, ProductKind>, ProductTag, DiagonalShape, DenseShape, typename Lhs::Scalar, typename Rhs::Scalar>
739 : diagonal_product_evaluator_base<Rhs, typename Lhs::DiagonalVectorType, Product<Lhs, Rhs, LazyProduct>, OnTheLeft>
741 typedef diagonal_product_evaluator_base<Rhs, typename Lhs::DiagonalVectorType, Product<Lhs, Rhs, LazyProduct>,
OnTheLeft> Base;
742 using Base::m_diagImpl;
743 using Base::m_matImpl;
745 typedef typename Base::Scalar Scalar;
747 typedef Product<Lhs, Rhs, ProductKind> XprType;
748 typedef typename XprType::PlainObject PlainObject;
754 EIGEN_DEVICE_FUNC
explicit product_evaluator(
const XprType& xpr)
755 : Base(xpr.rhs(), xpr.lhs().diagonal())
759 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const Scalar coeff(Index row, Index col)
const
761 return m_diagImpl.coeff(row) * m_matImpl.coeff(row, col);
765 template<
int LoadMode,
typename PacketType>
766 EIGEN_STRONG_INLINE PacketType packet(Index row, Index col)
const
770 return this->
template packet_impl<LoadMode,PacketType>(row,col, row,
771 typename internal::conditional<int(StorageOrder)==RowMajor, internal::true_type, internal::false_type>::type());
774 template<
int LoadMode,
typename PacketType>
775 EIGEN_STRONG_INLINE PacketType packet(Index idx)
const
777 return packet<LoadMode,PacketType>(int(StorageOrder)==
ColMajor?idx:0,int(StorageOrder)==
ColMajor?0:idx);
783 template<
typename Lhs,
typename Rhs,
int ProductKind,
int ProductTag>
784 struct product_evaluator<Product<Lhs, Rhs, ProductKind>, ProductTag, DenseShape, DiagonalShape,
typename Lhs::Scalar,
typename Rhs::Scalar>
785 : diagonal_product_evaluator_base<Lhs, typename Rhs::DiagonalVectorType, Product<Lhs, Rhs, LazyProduct>,
OnTheRight>
787 typedef diagonal_product_evaluator_base<Lhs, typename Rhs::DiagonalVectorType, Product<Lhs, Rhs, LazyProduct>,
OnTheRight> Base;
788 using Base::m_diagImpl;
789 using Base::m_matImpl;
791 typedef typename Base::Scalar Scalar;
793 typedef Product<Lhs, Rhs, ProductKind> XprType;
794 typedef typename XprType::PlainObject PlainObject;
796 enum { StorageOrder = int(Lhs::Flags) & RowMajorBit ?
RowMajor :
ColMajor };
798 EIGEN_DEVICE_FUNC
explicit product_evaluator(
const XprType& xpr)
799 : Base(xpr.lhs(), xpr.rhs().diagonal())
803 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const Scalar coeff(Index row, Index col)
const
805 return m_matImpl.coeff(row, col) * m_diagImpl.coeff(col);
809 template<
int LoadMode,
typename PacketType>
810 EIGEN_STRONG_INLINE PacketType packet(Index row, Index col)
const
812 return this->
template packet_impl<LoadMode,PacketType>(row,col, col,
813 typename internal::conditional<int(StorageOrder)==ColMajor, internal::true_type, internal::false_type>::type());
816 template<
int LoadMode,
typename PacketType>
817 EIGEN_STRONG_INLINE PacketType packet(Index idx)
const
819 return packet<LoadMode,PacketType>(int(StorageOrder)==
ColMajor?idx:0,int(StorageOrder)==
ColMajor?0:idx);
833 template<
typename ExpressionType,
int S
ide,
bool Transposed,
typename ExpressionShape>
834 struct permutation_matrix_product;
836 template<
typename ExpressionType,
int S
ide,
bool Transposed>
837 struct permutation_matrix_product<ExpressionType, Side, Transposed, DenseShape>
839 typedef typename nested_eval<ExpressionType, 1>::type MatrixType;
840 typedef typename remove_all<MatrixType>::type MatrixTypeCleaned;
842 template<
typename Dest,
typename PermutationType>
843 static inline void run(Dest& dst,
const PermutationType& perm,
const ExpressionType& xpr)
846 const Index n = Side==
OnTheLeft ? mat.rows() : mat.cols();
850 if(is_same_dense(dst, mat))
853 Matrix<bool,PermutationType::RowsAtCompileTime,1,0,PermutationType::MaxRowsAtCompileTime> mask(perm.size());
856 while(r < perm.size())
859 while(r<perm.size() && mask[r]) r++;
865 mask.coeffRef(k0) =
true;
866 for(Index k=perm.indices().coeff(k0); k!=k0; k=perm.indices().coeff(k))
868 Block<Dest, Side==OnTheLeft ? 1 : Dest::RowsAtCompileTime, Side==OnTheRight ? 1 : Dest::ColsAtCompileTime>(dst, k)
869 .swap(Block<Dest, Side==OnTheLeft ? 1 : Dest::RowsAtCompileTime, Side==OnTheRight ? 1 : Dest::ColsAtCompileTime>
870 (dst,((Side==
OnTheLeft) ^ Transposed) ? k0 : kPrev));
872 mask.coeffRef(k) =
true;
879 for(Index i = 0; i < n; ++i)
881 Block<Dest, Side==OnTheLeft ? 1 : Dest::RowsAtCompileTime, Side==OnTheRight ? 1 : Dest::ColsAtCompileTime>
882 (dst, ((Side==
OnTheLeft) ^ Transposed) ? perm.indices().coeff(i) : i)
886 Block<const MatrixTypeCleaned,Side==OnTheLeft ? 1 : MatrixTypeCleaned::RowsAtCompileTime,Side==OnTheRight ? 1 : MatrixTypeCleaned::ColsAtCompileTime>
887 (mat, ((Side==
OnTheRight) ^ Transposed) ? perm.indices().coeff(i) : i);
893 template<
typename Lhs,
typename Rhs,
int ProductTag,
typename MatrixShape>
894 struct generic_product_impl<Lhs, Rhs, PermutationShape, MatrixShape, ProductTag>
896 template<
typename Dest>
897 static void evalTo(Dest& dst,
const Lhs& lhs,
const Rhs& rhs)
899 permutation_matrix_product<Rhs, OnTheLeft, false, MatrixShape>::run(dst, lhs, rhs);
903 template<
typename Lhs,
typename Rhs,
int ProductTag,
typename MatrixShape>
904 struct generic_product_impl<Lhs, Rhs, MatrixShape, PermutationShape, ProductTag>
906 template<
typename Dest>
907 static void evalTo(Dest& dst,
const Lhs& lhs,
const Rhs& rhs)
909 permutation_matrix_product<Lhs, OnTheRight, false, MatrixShape>::run(dst, rhs, lhs);
913 template<
typename Lhs,
typename Rhs,
int ProductTag,
typename MatrixShape>
914 struct generic_product_impl<Transpose<Lhs>, Rhs, PermutationShape, MatrixShape, ProductTag>
916 template<
typename Dest>
917 static void evalTo(Dest& dst,
const Transpose<Lhs>& lhs,
const Rhs& rhs)
919 permutation_matrix_product<Rhs, OnTheLeft, true, MatrixShape>::run(dst, lhs.nestedExpression(), rhs);
923 template<
typename Lhs,
typename Rhs,
int ProductTag,
typename MatrixShape>
924 struct generic_product_impl<Lhs, Transpose<Rhs>, MatrixShape, PermutationShape, ProductTag>
926 template<
typename Dest>
927 static void evalTo(Dest& dst,
const Lhs& lhs,
const Transpose<Rhs>& rhs)
929 permutation_matrix_product<Lhs, OnTheRight, true, MatrixShape>::run(dst, rhs.nestedExpression(), lhs);
944 template<
typename ExpressionType,
int S
ide,
bool Transposed,
typename ExpressionShape>
945 struct transposition_matrix_product
947 typedef typename nested_eval<ExpressionType, 1>::type MatrixType;
948 typedef typename remove_all<MatrixType>::type MatrixTypeCleaned;
950 template<
typename Dest,
typename TranspositionType>
951 static inline void run(Dest& dst,
const TranspositionType& tr,
const ExpressionType& xpr)
954 typedef typename TranspositionType::StorageIndex StorageIndex;
955 const Index size = tr.size();
958 if(!(is_same<MatrixTypeCleaned,Dest>::value && extract_data(dst) == extract_data(mat)))
961 for(Index k=(Transposed?size-1:0) ; Transposed?k>=0:k<size ; Transposed?--k:++k)
962 if(Index(j=tr.coeff(k))!=k)
964 if(Side==
OnTheLeft) dst.row(k).swap(dst.row(j));
965 else if(Side==
OnTheRight) dst.col(k).swap(dst.col(j));
970 template<
typename Lhs,
typename Rhs,
int ProductTag,
typename MatrixShape>
971 struct generic_product_impl<Lhs, Rhs, TranspositionsShape, MatrixShape, ProductTag>
973 template<
typename Dest>
974 static void evalTo(Dest& dst,
const Lhs& lhs,
const Rhs& rhs)
976 transposition_matrix_product<Rhs, OnTheLeft, false, MatrixShape>::run(dst, lhs, rhs);
980 template<
typename Lhs,
typename Rhs,
int ProductTag,
typename MatrixShape>
981 struct generic_product_impl<Lhs, Rhs, MatrixShape, TranspositionsShape, ProductTag>
983 template<
typename Dest>
984 static void evalTo(Dest& dst,
const Lhs& lhs,
const Rhs& rhs)
986 transposition_matrix_product<Lhs, OnTheRight, false, MatrixShape>::run(dst, rhs, lhs);
991 template<
typename Lhs,
typename Rhs,
int ProductTag,
typename MatrixShape>
992 struct generic_product_impl<Transpose<Lhs>, Rhs, TranspositionsShape, MatrixShape, ProductTag>
994 template<
typename Dest>
995 static void evalTo(Dest& dst,
const Transpose<Lhs>& lhs,
const Rhs& rhs)
997 transposition_matrix_product<Rhs, OnTheLeft, true, MatrixShape>::run(dst, lhs.nestedExpression(), rhs);
1001 template<
typename Lhs,
typename Rhs,
int ProductTag,
typename MatrixShape>
1002 struct generic_product_impl<Lhs, Transpose<Rhs>, MatrixShape, TranspositionsShape, ProductTag>
1004 template<
typename Dest>
1005 static void evalTo(Dest& dst,
const Lhs& lhs,
const Transpose<Rhs>& rhs)
1007 transposition_matrix_product<Lhs, OnTheRight, true, MatrixShape>::run(dst, rhs.nestedExpression(), lhs);
1015 #endif // EIGEN_PRODUCT_EVALUATORS_H
Definition: Constants.h:314
Definition: Constants.h:327
const unsigned int RowMajorBit
Definition: Constants.h:53
const unsigned int PacketAccessBit
Definition: Constants.h:80
Definition: Constants.h:222
Definition: Eigen_Colamd.h:54
Definition: Constants.h:312
const unsigned int EvalBeforeNestingBit
Definition: Constants.h:57
Definition: Constants.h:325
const unsigned int ActualPacketAccessBit
Definition: Constants.h:91
const unsigned int LinearAccessBit
Definition: Constants.h:116