// Include guard for this header, followed by forward declarations of templates
// whose definitions are not part of this listing.
// NOTE(review): the integers that start many lines (10, 11, 21, ...) look like
// listing line numbers fused into the text -- confirm against the pristine header.
10 #ifndef EIGEN_BLASUTIL_H 11 #define EIGEN_BLASUTIL_H 21 template<
typename LhsScalar,
typename RhsScalar,
typename Index,
typename DataMapper,
int mr,
int nr,
bool ConjugateLhs=false,
bool ConjugateRhs=false>
// NOTE(review): the declarator introduced by the header above is not visible here.
// Second template header: one scalar type, an index type, a data mapper, an
// integer `nr`, a storage order, and two flags that both default to false.
24 template<
typename Scalar,
typename Index,
typename DataMapper,
int nr,
int StorageOrder,
bool Conjugate = false,
bool PanelMode=false>
// NOTE(review): the declarator introduced by the header above is not visible here.
// Third template header: same shape as the previous one but with two integer
// parameters (Pack1, Pack2) instead of `nr`.
27 template<
typename Scalar,
typename Index,
typename DataMapper,
int Pack1,
int Pack2,
int StorageOrder,
bool Conjugate = false,
bool PanelMode = false>
// NOTE(review): the declarator introduced by the header above is not visible here.
// Tail of the template parameter list of general_matrix_matrix_product: for each
// of the LHS and RHS a scalar type, a storage order and a conjugation flag.
// NOTE(review): the opening `template<` line and any further parameters are not
// visible in this listing.
32 typename LhsScalar,
int LhsStorageOrder,
bool ConjugateLhs,
33 typename RhsScalar,
int RhsStorageOrder,
bool ConjugateRhs,
// Forward declaration only; no definition appears in this file section.
35 struct general_matrix_matrix_product;
// Forward declaration of general_matrix_vector_product.
// Template parameters: an index type; for the LHS a scalar type, a mapper type,
// a storage order and a conjugation flag; for the RHS a scalar type, a mapper
// type and a conjugation flag; and an integer `Version` defaulting to Specialized.
37 template<
typename Index,
38 typename LhsScalar,
typename LhsMapper,
int LhsStorageOrder,
bool ConjugateLhs,
39 typename RhsScalar,
typename RhsMapper,
bool ConjugateRhs,
int Version=Specialized>
40 struct general_matrix_vector_product;
// Primary declaration of conj_if, keyed on a compile-time boolean.
// Only declared here; the <true> and <false> specializations that follow
// supply the actual members.
43 template<
bool Conjugate>
struct conj_if;
45 template<>
struct conj_if<true> {
47 inline T operator()(
const T& x) {
return numext::conj(x); }
49 inline T pconj(
const T& x) {
return internal::pconj(x); }
52 template<>
struct conj_if<false> {
54 inline const T& operator()(
const T& x) {
return x; }
56 inline const T& pconj(
const T& x) {
return x; }
// Specialization of conj_helper for two operands of the same Scalar type with
// both boolean flags false. Both members simply forward to the internal::
// primitives of the same name.
// NOTE(review): the opening/closing braces of the struct are not visible in
// this listing.
59 template<
typename Scalar>
struct conj_helper<Scalar,Scalar,false,false>
// Forwards (x, y, c) to internal::pmadd and returns its result.
61 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar pmadd(
const Scalar& x,
const Scalar& y,
const Scalar& c)
const {
return internal::pmadd(x,y,c); }
// Forwards (x, y) to internal::pmul and returns its result.
62 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar pmul(
const Scalar& x,
const Scalar& y)
const {
return internal::pmul(x,y); }
// Specialization of conj_helper for two std::complex<RealScalar> operands with
// flags (false, true): the arithmetic below conjugates the second operand only.
// NOTE(review): the opening/closing braces of the struct are not visible in
// this listing.
65 template<
typename RealScalar>
struct conj_helper<
std::complex<RealScalar>, std::complex<RealScalar>, false,true>
67 typedef std::complex<RealScalar> Scalar;
// Returns c + pmul(x, y).
68 EIGEN_STRONG_INLINE Scalar pmadd(
const Scalar& x,
const Scalar& y,
const Scalar& c)
const 69 {
return c + pmul(x,y); }
// Returns x * conj(y), written out on real and imaginary parts:
//   real = re(x)*re(y) + im(x)*im(y)
//   imag = im(x)*re(y) - re(x)*im(y)
71 EIGEN_STRONG_INLINE Scalar pmul(
const Scalar& x,
const Scalar& y)
const 72 {
return Scalar(numext::real(x)*numext::real(y) + numext::imag(x)*numext::imag(y), numext::imag(x)*numext::real(y) - numext::real(x)*numext::imag(y)); }
// Specialization of conj_helper for two std::complex<RealScalar> operands with
// flags (true, false): the arithmetic below conjugates the first operand only.
// NOTE(review): the opening/closing braces of the struct are not visible in
// this listing.
75 template<
typename RealScalar>
struct conj_helper<
std::complex<RealScalar>, std::complex<RealScalar>, true,false>
77 typedef std::complex<RealScalar> Scalar;
// Returns c + pmul(x, y).
78 EIGEN_STRONG_INLINE Scalar pmadd(
const Scalar& x,
const Scalar& y,
const Scalar& c)
const 79 {
return c + pmul(x,y); }
// Returns conj(x) * y, written out on real and imaginary parts:
//   real = re(x)*re(y) + im(x)*im(y)
//   imag = re(x)*im(y) - im(x)*re(y)
81 EIGEN_STRONG_INLINE Scalar pmul(
const Scalar& x,
const Scalar& y)
const 82 {
return Scalar(numext::real(x)*numext::real(y) + numext::imag(x)*numext::imag(y), numext::real(x)*numext::imag(y) - numext::imag(x)*numext::real(y)); }
// Specialization of conj_helper for two std::complex<RealScalar> operands with
// flags (true, true): the arithmetic below conjugates both operands.
// NOTE(review): the opening/closing braces of the struct are not visible in
// this listing.
85 template<
typename RealScalar>
struct conj_helper<
std::complex<RealScalar>, std::complex<RealScalar>, true,true>
87 typedef std::complex<RealScalar> Scalar;
// Returns c + pmul(x, y).
88 EIGEN_STRONG_INLINE Scalar pmadd(
const Scalar& x,
const Scalar& y,
const Scalar& c)
const 89 {
return c + pmul(x,y); }
// Returns conj(x) * conj(y), written out on real and imaginary parts:
//   real =  re(x)*re(y) - im(x)*im(y)
//   imag = -re(x)*im(y) - im(x)*re(y)
91 EIGEN_STRONG_INLINE Scalar pmul(
const Scalar& x,
const Scalar& y)
const 92 {
return Scalar(numext::real(x)*numext::real(y) - numext::imag(x)*numext::imag(y), - numext::real(x)*numext::imag(y) - numext::imag(x)*numext::real(y)); }
// Specialization of conj_helper for a std::complex<RealScalar> first operand
// and a RealScalar second operand. `Conj` controls whether the complex operand
// is conjugated; the flag for the real operand is fixed to false.
// NOTE(review): the opening/closing braces of the struct are not visible in
// this listing.
95 template<
typename RealScalar,
bool Conj>
struct conj_helper<
std::complex<RealScalar>, RealScalar, Conj,false>
97 typedef std::complex<RealScalar> Scalar;
// Returns padd(c, pmul(x, y)).
98 EIGEN_STRONG_INLINE Scalar pmadd(
const Scalar& x,
const RealScalar& y,
const Scalar& c)
const 99 {
return padd(c, pmul(x,y)); }
// Returns conj_if<Conj>()(x) * y, i.e. x is passed through conj_if before
// being scaled by the real value y.
100 EIGEN_STRONG_INLINE Scalar pmul(
const Scalar& x,
const RealScalar& y)
const 101 {
return conj_if<Conj>()(x)*y; }
// Specialization of conj_helper for a RealScalar first operand and a
// std::complex<RealScalar> second operand. `Conj` controls whether the complex
// operand is conjugated; the flag for the real operand is fixed to false.
// NOTE(review): the opening/closing braces of the struct are not visible in
// this listing.
104 template<
typename RealScalar,
bool Conj>
struct conj_helper<RealScalar,
std::complex<RealScalar>, false,Conj>
106 typedef std::complex<RealScalar> Scalar;
// Returns padd(c, pmul(x, y)).
107 EIGEN_STRONG_INLINE Scalar pmadd(
const RealScalar& x,
const Scalar& y,
const Scalar& c)
const 108 {
return padd(c, pmul(x,y)); }
// Returns x * conj_if<Conj>()(y), i.e. y is passed through conj_if before
// being scaled by the real value x.
109 EIGEN_STRONG_INLINE Scalar pmul(
const RealScalar& x,
const Scalar& y)
const 110 {
return x*conj_if<Conj>()(y); }
// Generic converter from type From to type To.
// run(x) returns To(x), i.e. an explicit conversion of x to the target type.
// NOTE(review): the closing brace of the struct is not visible in this listing.
113 template<
typename From,
typename To>
struct get_factor {
114 EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE To run(
const From& x) {
return To(x); }
// Partial specialization of get_factor used when the target type is
// NumTraits<Scalar>::Real: run(x) returns numext::real(x) instead of using a
// constructor-style conversion.
// NOTE(review): the closing brace of the struct is not visible in this listing.
117 template<
typename Scalar>
struct get_factor<Scalar,typename NumTraits<Scalar>::Real> {
119 static EIGEN_STRONG_INLINE
typename NumTraits<Scalar>::Real run(
const Scalar& x) {
return numext::real(x); }
// Thin wrapper around a raw Scalar pointer that is addressed with a single index.
// NOTE(review): several lines of this class (access specifiers, the body of
// operator(), closing braces, member declarations) are not visible in this listing.
123 template<
typename Scalar,
typename Index>
124 class BlasVectorMapper {
// Stores the given pointer in m_data.
126 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE BlasVectorMapper(Scalar *data) : m_data(data) {}
// Element access by index, returning by value.
// NOTE(review): body not visible here.
128 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i)
const {
// Loads a Packet starting at m_data + i via ploadt with the requested alignment.
131 template <
typename Packet,
int AlignmentType>
132 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet load(Index i)
const {
133 return ploadt<Packet, AlignmentType>(m_data + i);
// Returns true when the address m_data + i is a multiple of sizeof(Packet).
136 template <
typename Packet>
137 EIGEN_DEVICE_FUNC
bool aligned(Index i)
const {
138 return (UIntPtr(m_data+i)%
sizeof(Packet))==0;
// Wrapper around a raw Scalar pointer addressed with a single index, offering
// packet loads/stores that use the class-level AlignmentType.
// NOTE(review): several lines of this class (access specifiers, the body of
// operator(), closing braces, member declarations) are not visible in this listing.
145 template<
typename Scalar,
typename Index,
int AlignmentType>
146 class BlasLinearMapper {
// Full-width and half-width packet types associated with Scalar.
148 typedef typename packet_traits<Scalar>::type Packet;
149 typedef typename packet_traits<Scalar>::half HalfPacket;
// Stores the given pointer in m_data.
151 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE BlasLinearMapper(Scalar *data) : m_data(data) {}
// Calls internal::prefetch on the address of element i.
// NOTE(review): the index is an `int` here while every other accessor of this
// class takes `Index` -- confirm that narrowing is intended.
153 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
void prefetch(
int i)
const {
154 internal::prefetch(&
operator()(i));
// Element access by index, returning a reference.
// NOTE(review): body not visible here.
157 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar& operator()(Index i)
const {
// Loads a full Packet starting at m_data + i.
161 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i)
const {
162 return ploadt<Packet, AlignmentType>(m_data + i);
// Loads a HalfPacket starting at m_data + i.
165 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE HalfPacket loadHalfPacket(Index i)
const {
166 return ploadt<HalfPacket, AlignmentType>(m_data + i);
// Stores packet p starting at m_data + i.
169 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
void storePacket(Index i,
const Packet &p)
const {
170 pstoret<Scalar, Packet, AlignmentType>(m_data + i, p);
// Two-index view over a raw Scalar buffer with a fixed stride. StorageOrder
// selects how (i, j) is turned into a linear offset (see operator()), and
// AlignmentType (default Unaligned) is forwarded to the packet load helpers.
// NOTE(review): several lines of this class (access specifiers, some bodies,
// closing braces) are not visible in this listing.
178 template<
typename Scalar,
typename Index,
int StorageOrder,
int AlignmentType = Unaligned>
179 class blas_data_mapper {
// Full-width and half-width packet types associated with Scalar.
181 typedef typename packet_traits<Scalar>::type Packet;
182 typedef typename packet_traits<Scalar>::half HalfPacket;
// One-index mapper types that can be created from a position in this mapper.
184 typedef BlasLinearMapper<Scalar, Index, AlignmentType> LinearMapper;
185 typedef BlasVectorMapper<Scalar, Index> VectorMapper;
// Stores the base pointer and the stride.
187 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper(Scalar* data, Index stride) : m_data(data), m_stride(stride) {}
// Returns a mapper of the same type whose origin is element (i, j) of this
// one; the stride is carried over unchanged.
189 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType>
190 getSubMapper(Index i, Index j)
const {
191 return blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType>(&operator()(i, j), m_stride);
// Returns a LinearMapper starting at the address of element (i, j).
194 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j)
const {
195 return LinearMapper(&
operator()(i, j));
// Returns a VectorMapper starting at the address of element (i, j).
198 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE VectorMapper getVectorMapper(Index i, Index j)
const {
199 return VectorMapper(&
operator()(i, j));
// Element access: the linear offset is j + i*m_stride when StorageOrder is
// RowMajor, and i + j*m_stride otherwise.
204 EIGEN_ALWAYS_INLINE Scalar& operator()(Index i, Index j)
const {
205 return m_data[StorageOrder==
RowMajor ? j + i*m_stride : i + j*m_stride];
// Loads a full Packet starting at the address of element (i, j).
208 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i, Index j)
const {
209 return ploadt<Packet, AlignmentType>(&operator()(i, j));
// Loads a HalfPacket starting at the address of element (i, j).
212 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE HalfPacket loadHalfPacket(Index i, Index j)
const {
213 return ploadt<HalfPacket, AlignmentType>(&operator()(i, j));
// Scatters packet p starting at element (i, j), passing m_stride to pscatter.
216 template<
typename SubPacket>
217 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
void scatterPacket(Index i, Index j,
const SubPacket &p)
const {
218 pscatter<Scalar, SubPacket>(&operator()(i, j), p, m_stride);
// Gathers a SubPacket starting at element (i, j), passing m_stride to pgather.
221 template<
typename SubPacket>
222 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE SubPacket gatherPacket(Index i, Index j)
const {
223 return pgather<Scalar, SubPacket>(&operator()(i, j), m_stride);
// Accessor for the stride.
// NOTE(review): the top-level `const` on a by-value return type has no effect.
226 EIGEN_DEVICE_FUNC
const Index stride()
const {
return m_stride; }
// Accessor for the base pointer (read-only view).
227 EIGEN_DEVICE_FUNC
const Scalar* data()
const {
return m_data; }
// When the base address is a multiple of sizeof(Scalar), returns
// internal::first_default_aligned(m_data, size).
// NOTE(review): the body of the `if` branch taken when the address is NOT a
// multiple of sizeof(Scalar) is not visible in this listing.
229 EIGEN_DEVICE_FUNC Index firstAligned(Index size)
const {
230 if (UIntPtr(m_data)%
sizeof(Scalar)) {
233 return internal::first_default_aligned(m_data, size);
// Base pointer (declared with EIGEN_RESTRICT) and the immutable stride.
237 Scalar* EIGEN_RESTRICT m_data;
238 const Index m_stride;
// Read-only variant of blas_data_mapper: derives publicly from
// blas_data_mapper<const Scalar, Index, StorageOrder>.
// NOTE(review): access specifiers and closing braces are not visible in this
// listing.
242 template<
typename Scalar,
typename Index,
int StorageOrder>
243 class const_blas_data_mapper :
public blas_data_mapper<const Scalar, Index, StorageOrder> {
// Forwards the pointer and stride to the base-class constructor.
245 EIGEN_ALWAYS_INLINE const_blas_data_mapper(
const Scalar *data, Index stride) : blas_data_mapper<const Scalar, Index, StorageOrder>(data, stride) {}
// Returns a const_blas_data_mapper whose origin is element (i, j) of this one,
// keeping the same stride.
247 EIGEN_ALWAYS_INLINE const_blas_data_mapper<Scalar, Index, StorageOrder> getSubMapper(Index i, Index j)
const {
248 return const_blas_data_mapper<Scalar, Index, StorageOrder>(&(this->operator()(i, j)), this->m_stride);
// Default blas_traits for an arbitrary expression type: the expression is
// extracted as-is and its scalar factor is 1.
// NOTE(review): the struct braces, the line opening the list of compile-time
// constants and the start of the HasUsableDirectAccess expression are not
// visible in this listing.
256 template<
typename XprType>
struct blas_traits
258 typedef typename traits<XprType>::Scalar Scalar;
// extract() hands back a const reference to the expression itself.
259 typedef const XprType& ExtractType;
260 typedef XprType _ExtractType;
// Compile-time flags: the default expression is neither transposed nor in
// need of conjugation.
262 IsComplex = NumTraits<Scalar>::IsComplex,
263 IsTransposed =
false,
264 NeedToConjugate =
false,
// Trailing part of a condition: the expression is a compile-time vector, or
// its compile-time inner stride equals 1.
266 && (
bool(XprType::IsVectorAtCompileTime)
267 || int(inner_stride_at_compile_time<XprType>::ret) == 1)
// Type selected on HasUsableDirectAccess.
// NOTE(review): the first alternative of the conditional is not visible here;
// the second is _ExtractType::PlainObject.
270 typedef typename conditional<bool(HasUsableDirectAccess),
272 typename _ExtractType::PlainObject
273 >::type DirectLinearAccessType;
// Returns x unchanged.
274 static inline ExtractType extract(
const XprType& x) {
return x; }
// Returns Scalar(1): a plain expression carries no scalar multiplier.
275 static inline const Scalar extractScalarFactor(
const XprType&) {
return Scalar(1); }
// blas_traits for a conjugate expression (CwiseUnaryOp with
// scalar_conjugate_op). Inherits from the traits of the nested expression and
// adjusts the NeedToConjugate flag.
// NOTE(review): the struct braces and the line opening the list of
// compile-time constants are not visible in this listing.
279 template<
typename Scalar,
typename NestedXpr>
280 struct blas_traits<CwiseUnaryOp<scalar_conjugate_op<Scalar>, NestedXpr> >
281 : blas_traits<NestedXpr>
283 typedef blas_traits<NestedXpr> Base;
284 typedef CwiseUnaryOp<scalar_conjugate_op<Scalar>, NestedXpr> XprType;
285 typedef typename Base::ExtractType ExtractType;
288 IsComplex = NumTraits<Scalar>::IsComplex,
// 0 when the nested expression already needs conjugation (the two cancel);
// otherwise equal to IsComplex.
289 NeedToConjugate = Base::NeedToConjugate ? 0 : IsComplex
// Skips the conjugate node and extracts from the nested expression.
291 static inline ExtractType extract(
const XprType& x) {
return Base::extract(x.nestedExpression()); }
// Returns conj() of the nested expression's scalar factor.
292 static inline Scalar extractScalarFactor(
const XprType& x) {
return conj(Base::extractScalarFactor(x.nestedExpression())); }
// blas_traits for a product whose LEFT operand is a scalar constant
// (CwiseNullaryOp with scalar_constant_op) and whose right operand is the
// nested expression. Inherits from the traits of the nested expression.
// NOTE(review): the struct braces are not visible in this listing.
296 template<
typename Scalar,
typename NestedXpr,
typename Plain>
297 struct blas_traits<CwiseBinaryOp<scalar_product_op<Scalar>, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain>, NestedXpr> >
298 : blas_traits<NestedXpr>
300 typedef blas_traits<NestedXpr> Base;
301 typedef CwiseBinaryOp<scalar_product_op<Scalar>,
const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain>, NestedXpr> XprType;
302 typedef typename Base::ExtractType ExtractType;
// Extracts from the right operand, dropping the constant.
303 static inline ExtractType extract(
const XprType& x) {
return Base::extract(x.rhs()); }
// Returns the constant (functor().m_other of the left operand) multiplied by
// the scalar factor of the right operand, in that order.
304 static inline Scalar extractScalarFactor(
const XprType& x)
305 {
return x.lhs().functor().m_other * Base::extractScalarFactor(x.rhs()); }
// blas_traits for a product whose RIGHT operand is a scalar constant
// (CwiseNullaryOp with scalar_constant_op) and whose left operand is the
// nested expression. Inherits from the traits of the nested expression.
// NOTE(review): the struct braces are not visible in this listing.
307 template<
typename Scalar,
typename NestedXpr,
typename Plain>
308 struct blas_traits<CwiseBinaryOp<scalar_product_op<Scalar>, NestedXpr, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain> > >
309 : blas_traits<NestedXpr>
311 typedef blas_traits<NestedXpr> Base;
312 typedef CwiseBinaryOp<scalar_product_op<Scalar>, NestedXpr,
const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain> > XprType;
313 typedef typename Base::ExtractType ExtractType;
// Extracts from the left operand, dropping the constant.
314 static inline ExtractType extract(
const XprType& x) {
return Base::extract(x.lhs()); }
// Returns the scalar factor of the left operand multiplied by the constant
// (functor().m_other of the right operand), in that order.
315 static inline Scalar extractScalarFactor(
const XprType& x)
316 {
return Base::extractScalarFactor(x.lhs()) * x.rhs().functor().m_other; }
// blas_traits for a negated expression (CwiseUnaryOp with scalar_opposite_op).
// Inherits from the traits of the nested expression and folds the sign into
// the scalar factor.
// NOTE(review): the struct braces are not visible in this listing.
320 template<
typename Scalar,
typename NestedXpr>
321 struct blas_traits<CwiseUnaryOp<scalar_opposite_op<Scalar>, NestedXpr> >
322 : blas_traits<NestedXpr>
324 typedef blas_traits<NestedXpr> Base;
325 typedef CwiseUnaryOp<scalar_opposite_op<Scalar>, NestedXpr> XprType;
326 typedef typename Base::ExtractType ExtractType;
// Skips the negation node and extracts from the nested expression.
327 static inline ExtractType extract(
const XprType& x) {
return Base::extract(x.nestedExpression()); }
// Returns the negated scalar factor of the nested expression.
328 static inline Scalar extractScalarFactor(
const XprType& x)
329 {
return - Base::extractScalarFactor(x.nestedExpression()); }
// blas_traits for a Transpose expression. Inherits from the traits of the
// nested expression, flips IsTransposed, and re-wraps the extracted
// expression in a Transpose.
// NOTE(review): the struct braces and the line opening the list of
// compile-time constants are not visible in this listing.
333 template<
typename NestedXpr>
334 struct blas_traits<Transpose<NestedXpr> >
335 : blas_traits<NestedXpr>
337 typedef typename NestedXpr::Scalar Scalar;
338 typedef blas_traits<NestedXpr> Base;
339 typedef Transpose<NestedXpr> XprType;
// Both extract types are a Transpose over the base's (const) _ExtractType.
340 typedef Transpose<const typename Base::_ExtractType> ExtractType;
341 typedef Transpose<const typename Base::_ExtractType> _ExtractType;
// Type selected on the base's HasUsableDirectAccess.
// NOTE(review): the first alternative of the conditional is not visible here;
// the second is ExtractType::PlainObject.
342 typedef typename conditional<bool(Base::HasUsableDirectAccess),
344 typename ExtractType::PlainObject
345 >::type DirectLinearAccessType;
// Toggles the flag inherited from the nested expression.
347 IsTransposed = Base::IsTransposed ? 0 : 1
// Extracts from the nested expression and wraps the result in ExtractType.
349 static inline ExtractType extract(
const XprType& x) {
return ExtractType(Base::extract(x.nestedExpression())); }
// The scalar factor is that of the nested expression, unchanged.
350 static inline Scalar extractScalarFactor(
const XprType& x) {
return Base::extractScalarFactor(x.nestedExpression()); }
// Specialization of blas_traits for a const-qualified type T.
// NOTE(review): its template header and body are not visible in this listing.
354 struct blas_traits<const T>
// Helper that obtains a read-only data pointer from an expression. The second
// template argument defaults to blas_traits<T>::HasUsableDirectAccess; this
// primary template extracts the expression through blas_traits<T> and returns
// the result's data() pointer.
// NOTE(review): the function and struct braces are not visible in this listing.
358 template<typename T, bool HasUsableDirectAccess=blas_traits<T>::HasUsableDirectAccess>
359 struct extract_data_selector {
360 static const typename T::Scalar* run(
const T& m)
362 return blas_traits<T>::extract(m).data();
// Specialization used when the boolean argument is false: no data pointer is
// produced and run() returns 0 (a null pointer).
// NOTE(review): the return type here is a non-const Scalar*, whereas the
// primary template returns const Scalar* -- confirm whether that is intended.
// NOTE(review): the template header and closing brace are not visible in this
// listing.
367 struct extract_data_selector<T,false> {
368 static typename T::Scalar* run(
const T&) {
return 0; }
// Convenience entry point: returns extract_data_selector<T>::run(m), i.e. the
// selector's default boolean argument decides which implementation is used.
// NOTE(review): the function braces are not visible in this listing.
371 template<
typename T>
const typename T::Scalar* extract_data(
const T& m)
373 return extract_data_selector<T>::run(m);
380 #endif // EIGEN_BLASUTIL_H
const unsigned int DirectAccessBit
Definition: Constants.h:150
const Eigen::CwiseUnaryOp< Eigen::internal::scalar_conjugate_op< typename Derived::Scalar >, const Derived > conj(const Eigen::ArrayBase< Derived > &x)
Namespace containing all symbols from the Eigen library.
Definition: Core:271
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: XprHelper.h:35
Definition: Eigen_Colamd.h:50
Definition: Constants.h:322