Eigen  3.2.91
BlasUtil.h
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2009-2010 Gael Guennebaud <gael.guennebaud@inria.fr>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_BLASUTIL_H
11 #define EIGEN_BLASUTIL_H
12 
13 // This file contains many lightweight helper classes used to
14 // implement and control fast level 2 and level 3 BLAS-like routines.
15 
16 namespace Eigen {
17 
18 namespace internal {
19 
20 // forward declarations
21 template<typename LhsScalar, typename RhsScalar, typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs=false, bool ConjugateRhs=false>
22 struct gebp_kernel;
23 
24 template<typename Scalar, typename Index, typename DataMapper, int nr, int StorageOrder, bool Conjugate = false, bool PanelMode=false>
25 struct gemm_pack_rhs;
26 
27 template<typename Scalar, typename Index, typename DataMapper, int Pack1, int Pack2, int StorageOrder, bool Conjugate = false, bool PanelMode = false>
28 struct gemm_pack_lhs;
29 
30 template<
31  typename Index,
32  typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs,
33  typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs,
34  int ResStorageOrder>
35 struct general_matrix_matrix_product;
36 
37 template<typename Index,
38  typename LhsScalar, typename LhsMapper, int LhsStorageOrder, bool ConjugateLhs,
39  typename RhsScalar, typename RhsMapper, bool ConjugateRhs, int Version=Specialized>
40 struct general_matrix_vector_product;
41 
42 
43 template<bool Conjugate> struct conj_if;
44 
45 template<> struct conj_if<true> {
46  template<typename T>
47  inline T operator()(const T& x) { return numext::conj(x); }
48  template<typename T>
49  inline T pconj(const T& x) { return internal::pconj(x); }
50 };
51 
52 template<> struct conj_if<false> {
53  template<typename T>
54  inline const T& operator()(const T& x) { return x; }
55  template<typename T>
56  inline const T& pconj(const T& x) { return x; }
57 };
58 
59 template<typename Scalar> struct conj_helper<Scalar,Scalar,false,false>
60 {
61  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const Scalar& y, const Scalar& c) const { return internal::pmadd(x,y,c); }
62  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const { return internal::pmul(x,y); }
63 };
64 
65 template<typename RealScalar> struct conj_helper<std::complex<RealScalar>, std::complex<RealScalar>, false,true>
66 {
67  typedef std::complex<RealScalar> Scalar;
68  EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const Scalar& y, const Scalar& c) const
69  { return c + pmul(x,y); }
70 
71  EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const
72  { return Scalar(numext::real(x)*numext::real(y) + numext::imag(x)*numext::imag(y), numext::imag(x)*numext::real(y) - numext::real(x)*numext::imag(y)); }
73 };
74 
75 template<typename RealScalar> struct conj_helper<std::complex<RealScalar>, std::complex<RealScalar>, true,false>
76 {
77  typedef std::complex<RealScalar> Scalar;
78  EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const Scalar& y, const Scalar& c) const
79  { return c + pmul(x,y); }
80 
81  EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const
82  { return Scalar(numext::real(x)*numext::real(y) + numext::imag(x)*numext::imag(y), numext::real(x)*numext::imag(y) - numext::imag(x)*numext::real(y)); }
83 };
84 
85 template<typename RealScalar> struct conj_helper<std::complex<RealScalar>, std::complex<RealScalar>, true,true>
86 {
87  typedef std::complex<RealScalar> Scalar;
88  EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const Scalar& y, const Scalar& c) const
89  { return c + pmul(x,y); }
90 
91  EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const
92  { return Scalar(numext::real(x)*numext::real(y) - numext::imag(x)*numext::imag(y), - numext::real(x)*numext::imag(y) - numext::imag(x)*numext::real(y)); }
93 };
94 
95 template<typename RealScalar,bool Conj> struct conj_helper<std::complex<RealScalar>, RealScalar, Conj,false>
96 {
97  typedef std::complex<RealScalar> Scalar;
98  EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const RealScalar& y, const Scalar& c) const
99  { return padd(c, pmul(x,y)); }
100  EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const RealScalar& y) const
101  { return conj_if<Conj>()(x)*y; }
102 };
103 
104 template<typename RealScalar,bool Conj> struct conj_helper<RealScalar, std::complex<RealScalar>, false,Conj>
105 {
106  typedef std::complex<RealScalar> Scalar;
107  EIGEN_STRONG_INLINE Scalar pmadd(const RealScalar& x, const Scalar& y, const Scalar& c) const
108  { return padd(c, pmul(x,y)); }
109  EIGEN_STRONG_INLINE Scalar pmul(const RealScalar& x, const Scalar& y) const
110  { return x*conj_if<Conj>()(y); }
111 };
112 
113 template<typename From,typename To> struct get_factor {
114  EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE To run(const From& x) { return x; }
115 };
116 
117 template<typename Scalar> struct get_factor<Scalar,typename NumTraits<Scalar>::Real> {
118  EIGEN_DEVICE_FUNC
119  static EIGEN_STRONG_INLINE typename NumTraits<Scalar>::Real run(const Scalar& x) { return numext::real(x); }
120 };
121 
122 
123 template<typename Scalar, typename Index>
124 class BlasVectorMapper {
125  public:
126  EIGEN_ALWAYS_INLINE BlasVectorMapper(Scalar *data) : m_data(data) {}
127 
128  EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const {
129  return m_data[i];
130  }
131  template <typename Packet, int AlignmentType>
132  EIGEN_ALWAYS_INLINE Packet load(Index i) const {
133  return ploadt<Packet, AlignmentType>(m_data + i);
134  }
135 
136  template <typename Packet>
137  bool aligned(Index i) const {
138  return (size_t(m_data+i)%sizeof(Packet))==0;
139  }
140 
141  protected:
142  Scalar* m_data;
143 };
144 
145 template<typename Scalar, typename Index, int AlignmentType>
146 class BlasLinearMapper {
147  public:
148  typedef typename packet_traits<Scalar>::type Packet;
149  typedef typename packet_traits<Scalar>::half HalfPacket;
150 
151  EIGEN_ALWAYS_INLINE BlasLinearMapper(Scalar *data) : m_data(data) {}
152 
153  EIGEN_ALWAYS_INLINE void prefetch(int i) const {
154  internal::prefetch(&operator()(i));
155  }
156 
157  EIGEN_ALWAYS_INLINE Scalar& operator()(Index i) const {
158  return m_data[i];
159  }
160 
161  EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const {
162  return ploadt<Packet, AlignmentType>(m_data + i);
163  }
164 
165  EIGEN_ALWAYS_INLINE HalfPacket loadHalfPacket(Index i) const {
166  return ploadt<HalfPacket, AlignmentType>(m_data + i);
167  }
168 
169  EIGEN_ALWAYS_INLINE void storePacket(Index i, const Packet &p) const {
170  pstoret<Scalar, Packet, AlignmentType>(m_data + i, p);
171  }
172 
173  protected:
174  Scalar *m_data;
175 };
176 
177 // Lightweight helper class to access matrix coefficients.
178 template<typename Scalar, typename Index, int StorageOrder, int AlignmentType = Unaligned>
179 class blas_data_mapper {
180  public:
181  typedef typename packet_traits<Scalar>::type Packet;
182  typedef typename packet_traits<Scalar>::half HalfPacket;
183 
184  typedef BlasLinearMapper<Scalar, Index, AlignmentType> LinearMapper;
185  typedef BlasVectorMapper<Scalar, Index> VectorMapper;
186 
187  EIGEN_ALWAYS_INLINE blas_data_mapper(Scalar* data, Index stride) : m_data(data), m_stride(stride) {}
188 
189  EIGEN_ALWAYS_INLINE blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType>
190  getSubMapper(Index i, Index j) const {
191  return blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType>(&operator()(i, j), m_stride);
192  }
193 
194  EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
195  return LinearMapper(&operator()(i, j));
196  }
197 
198  EIGEN_ALWAYS_INLINE VectorMapper getVectorMapper(Index i, Index j) const {
199  return VectorMapper(&operator()(i, j));
200  }
201 
202 
203  EIGEN_DEVICE_FUNC
204  EIGEN_ALWAYS_INLINE Scalar& operator()(Index i, Index j) const {
205  return m_data[StorageOrder==RowMajor ? j + i*m_stride : i + j*m_stride];
206  }
207 
208  EIGEN_ALWAYS_INLINE Packet loadPacket(Index i, Index j) const {
209  return ploadt<Packet, AlignmentType>(&operator()(i, j));
210  }
211 
212  EIGEN_ALWAYS_INLINE HalfPacket loadHalfPacket(Index i, Index j) const {
213  return ploadt<HalfPacket, AlignmentType>(&operator()(i, j));
214  }
215 
216  template<typename SubPacket>
217  EIGEN_ALWAYS_INLINE void scatterPacket(Index i, Index j, const SubPacket &p) const {
218  pscatter<Scalar, SubPacket>(&operator()(i, j), p, m_stride);
219  }
220 
221  template<typename SubPacket>
222  EIGEN_ALWAYS_INLINE SubPacket gatherPacket(Index i, Index j) const {
223  return pgather<Scalar, SubPacket>(&operator()(i, j), m_stride);
224  }
225 
226  const Index stride() const { return m_stride; }
227  const Scalar* data() const { return m_data; }
228 
229  Index firstAligned(Index size) const {
230  if (size_t(m_data)%sizeof(Scalar)) {
231  return -1;
232  }
233  return internal::first_default_aligned(m_data, size);
234  }
235 
236  protected:
237  Scalar* EIGEN_RESTRICT m_data;
238  const Index m_stride;
239 };
240 
241 // lightweight helper class to access matrix coefficients (const version)
242 template<typename Scalar, typename Index, int StorageOrder>
243 class const_blas_data_mapper : public blas_data_mapper<const Scalar, Index, StorageOrder> {
244  public:
245  EIGEN_ALWAYS_INLINE const_blas_data_mapper(const Scalar *data, Index stride) : blas_data_mapper<const Scalar, Index, StorageOrder>(data, stride) {}
246 
247  EIGEN_ALWAYS_INLINE const_blas_data_mapper<Scalar, Index, StorageOrder> getSubMapper(Index i, Index j) const {
248  return const_blas_data_mapper<Scalar, Index, StorageOrder>(&(this->operator()(i, j)), this->m_stride);
249  }
250 };
251 
252 
253 /* Helper class to analyze the factors of a Product expression.
254  * In particular it allows to pop out operator-, scalar multiples,
255  * and conjugate */
256 template<typename XprType> struct blas_traits
257 {
258  typedef typename traits<XprType>::Scalar Scalar;
259  typedef const XprType& ExtractType;
260  typedef XprType _ExtractType;
261  enum {
262  IsComplex = NumTraits<Scalar>::IsComplex,
263  IsTransposed = false,
264  NeedToConjugate = false,
265  HasUsableDirectAccess = ( (int(XprType::Flags)&DirectAccessBit)
266  && ( bool(XprType::IsVectorAtCompileTime)
267  || int(inner_stride_at_compile_time<XprType>::ret) == 1)
268  ) ? 1 : 0
269  };
270  typedef typename conditional<bool(HasUsableDirectAccess),
271  ExtractType,
272  typename _ExtractType::PlainObject
273  >::type DirectLinearAccessType;
274  static inline ExtractType extract(const XprType& x) { return x; }
275  static inline const Scalar extractScalarFactor(const XprType&) { return Scalar(1); }
276 };
277 
278 // pop conjugate
279 template<typename Scalar, typename NestedXpr>
280 struct blas_traits<CwiseUnaryOp<scalar_conjugate_op<Scalar>, NestedXpr> >
281  : blas_traits<NestedXpr>
282 {
283  typedef blas_traits<NestedXpr> Base;
284  typedef CwiseUnaryOp<scalar_conjugate_op<Scalar>, NestedXpr> XprType;
285  typedef typename Base::ExtractType ExtractType;
286 
287  enum {
288  IsComplex = NumTraits<Scalar>::IsComplex,
289  NeedToConjugate = Base::NeedToConjugate ? 0 : IsComplex
290  };
291  static inline ExtractType extract(const XprType& x) { return Base::extract(x.nestedExpression()); }
292  static inline Scalar extractScalarFactor(const XprType& x) { return conj(Base::extractScalarFactor(x.nestedExpression())); }
293 };
294 
295 // pop scalar multiple
296 template<typename Scalar, typename NestedXpr>
297 struct blas_traits<CwiseUnaryOp<scalar_multiple_op<Scalar>, NestedXpr> >
298  : blas_traits<NestedXpr>
299 {
300  typedef blas_traits<NestedXpr> Base;
301  typedef CwiseUnaryOp<scalar_multiple_op<Scalar>, NestedXpr> XprType;
302  typedef typename Base::ExtractType ExtractType;
303  static inline ExtractType extract(const XprType& x) { return Base::extract(x.nestedExpression()); }
304  static inline Scalar extractScalarFactor(const XprType& x)
305  { return x.functor().m_other * Base::extractScalarFactor(x.nestedExpression()); }
306 };
307 
308 // pop opposite
309 template<typename Scalar, typename NestedXpr>
310 struct blas_traits<CwiseUnaryOp<scalar_opposite_op<Scalar>, NestedXpr> >
311  : blas_traits<NestedXpr>
312 {
313  typedef blas_traits<NestedXpr> Base;
314  typedef CwiseUnaryOp<scalar_opposite_op<Scalar>, NestedXpr> XprType;
315  typedef typename Base::ExtractType ExtractType;
316  static inline ExtractType extract(const XprType& x) { return Base::extract(x.nestedExpression()); }
317  static inline Scalar extractScalarFactor(const XprType& x)
318  { return - Base::extractScalarFactor(x.nestedExpression()); }
319 };
320 
321 // pop/push transpose
322 template<typename NestedXpr>
323 struct blas_traits<Transpose<NestedXpr> >
324  : blas_traits<NestedXpr>
325 {
326  typedef typename NestedXpr::Scalar Scalar;
327  typedef blas_traits<NestedXpr> Base;
328  typedef Transpose<NestedXpr> XprType;
329  typedef Transpose<const typename Base::_ExtractType> ExtractType; // const to get rid of a compile error; anyway blas traits are only used on the RHS
330  typedef Transpose<const typename Base::_ExtractType> _ExtractType;
331  typedef typename conditional<bool(Base::HasUsableDirectAccess),
332  ExtractType,
333  typename ExtractType::PlainObject
334  >::type DirectLinearAccessType;
335  enum {
336  IsTransposed = Base::IsTransposed ? 0 : 1
337  };
338  static inline ExtractType extract(const XprType& x) { return ExtractType(Base::extract(x.nestedExpression())); }
339  static inline Scalar extractScalarFactor(const XprType& x) { return Base::extractScalarFactor(x.nestedExpression()); }
340 };
341 
342 template<typename T>
343 struct blas_traits<const T>
344  : blas_traits<T>
345 {};
346 
347 template<typename T, bool HasUsableDirectAccess=blas_traits<T>::HasUsableDirectAccess>
348 struct extract_data_selector {
349  static const typename T::Scalar* run(const T& m)
350  {
351  return blas_traits<T>::extract(m).data();
352  }
353 };
354 
355 template<typename T>
356 struct extract_data_selector<T,false> {
357  static typename T::Scalar* run(const T&) { return 0; }
358 };
359 
360 template<typename T> const typename T::Scalar* extract_data(const T& m)
361 {
362  return extract_data_selector<T>::run(m);
363 }
364 
365 } // end namespace internal
366 
367 } // end namespace Eigen
368 
369 #endif // EIGEN_BLASUTIL_H
Definition: Constants.h:314
const unsigned int DirectAccessBit
Definition: Constants.h:141
Definition: LDLT.h:16
Definition: StdDeque.h:58
Definition: Eigen_Colamd.h:54