TensorForcedEval.h
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_FORCED_EVAL_H
11 #define EIGEN_CXX11_TENSOR_TENSOR_FORCED_EVAL_H
12 
13 namespace Eigen {
14 
22 namespace internal {
23 template<typename XprType>
24 struct traits<TensorForcedEvalOp<XprType> >
25 {
26  // Type promotion to handle the case where the types of the lhs and the rhs are different.
27  typedef typename XprType::Scalar Scalar;
28  typedef traits<XprType> XprTraits;
29  typedef typename packet_traits<Scalar>::type Packet;
30  typedef typename traits<XprType>::StorageKind StorageKind;
31  typedef typename traits<XprType>::Index Index;
32  typedef typename XprType::Nested Nested;
33  typedef typename remove_reference<Nested>::type _Nested;
34  static const int NumDimensions = XprTraits::NumDimensions;
35  static const int Layout = XprTraits::Layout;
36 
37  enum {
38  Flags = 0,
39  };
40 };
41 
42 template<typename XprType>
43 struct eval<TensorForcedEvalOp<XprType>, Eigen::Dense>
44 {
45  typedef const TensorForcedEvalOp<XprType>& type;
46 };
47 
48 template<typename XprType>
49 struct nested<TensorForcedEvalOp<XprType>, 1, typename eval<TensorForcedEvalOp<XprType> >::type>
50 {
51  typedef TensorForcedEvalOp<XprType> type;
52 };
53 
54 } // end namespace internal
55 
56 
57 
58 template<typename XprType>
59 class TensorForcedEvalOp : public TensorBase<TensorForcedEvalOp<XprType> >
60 {
61  public:
62  typedef typename Eigen::internal::traits<TensorForcedEvalOp>::Scalar Scalar;
63  typedef typename Eigen::internal::traits<TensorForcedEvalOp>::Packet Packet;
64  typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
65  typedef typename internal::remove_const<typename XprType::CoeffReturnType>::type CoeffReturnType;
66  typedef typename internal::remove_const<typename XprType::PacketReturnType>::type PacketReturnType;
67  typedef typename Eigen::internal::nested<TensorForcedEvalOp>::type Nested;
68  typedef typename Eigen::internal::traits<TensorForcedEvalOp>::StorageKind StorageKind;
69  typedef typename Eigen::internal::traits<TensorForcedEvalOp>::Index Index;
70 
71  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorForcedEvalOp(const XprType& expr)
72  : m_xpr(expr) {}
73 
74  EIGEN_DEVICE_FUNC
75  const typename internal::remove_all<typename XprType::Nested>::type&
76  expression() const { return m_xpr; }
77 
78  protected:
79  typename XprType::Nested m_xpr;
80 };
81 
82 
83 template<typename ArgType, typename Device>
84 struct TensorEvaluator<const TensorForcedEvalOp<ArgType>, Device>
85 {
86  typedef TensorForcedEvalOp<ArgType> XprType;
87  typedef typename ArgType::Scalar Scalar;
88  typedef typename ArgType::Packet Packet;
89  typedef typename TensorEvaluator<ArgType, Device>::Dimensions Dimensions;
90 
91  enum {
92  IsAligned = true,
93  PacketAccess = (internal::packet_traits<Scalar>::size > 1),
94  Layout = TensorEvaluator<ArgType, Device>::Layout,
95  };
96 
97  EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device)
98  : m_impl(op.expression(), device), m_op(op.expression()), m_device(device), m_buffer(NULL)
99  { }
100 
101  typedef typename XprType::Index Index;
102  typedef typename XprType::CoeffReturnType CoeffReturnType;
103  typedef typename XprType::PacketReturnType PacketReturnType;
104 
105  EIGEN_DEVICE_FUNC const Dimensions& dimensions() const { return m_impl.dimensions(); }
106 
107  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType*) {
108  m_impl.evalSubExprsIfNeeded(NULL);
109  const Index numValues = m_impl.dimensions().TotalSize();
110  m_buffer = (CoeffReturnType*)m_device.allocate(numValues * sizeof(CoeffReturnType));
111  // Should initialize the memory in case we're dealing with non POD types.
112  if (NumTraits<CoeffReturnType>::RequireInitialization) {
113  for (Index i = 0; i < numValues; ++i) {
114  new(m_buffer+i) CoeffReturnType();
115  }
116  }
117  typedef TensorEvalToOp<const ArgType> EvalTo;
118  EvalTo evalToTmp(m_buffer, m_op);
119  const bool PacketAccess = internal::IsVectorizable<Device, ArgType>::value;
120  internal::TensorExecutor<const EvalTo, Device, PacketAccess>::run(evalToTmp, m_device);
121  m_impl.cleanup();
122  return true;
123  }
124  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
125  m_device.deallocate(m_buffer);
126  m_buffer = NULL;
127  }
128 
129  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
130  {
131  return m_buffer[index];
132  }
133 
134  template<int LoadMode>
135  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
136  {
137  return internal::ploadt<Packet, LoadMode>(m_buffer + index);
138  }
139 
140  EIGEN_DEVICE_FUNC Scalar* data() const { return m_buffer; }
141 
142  private:
143  TensorEvaluator<ArgType, Device> m_impl;
144  const ArgType m_op;
145  const Device& m_device;
146  CoeffReturnType* m_buffer;
147 };
148 
149 
150 } // end namespace Eigen
151 
152 #endif // EIGEN_CXX11_TENSOR_TENSOR_FORCED_EVAL_H
Namespace containing all symbols from the Eigen library.
Definition: CXX11Meta.h:13