TensorMap.h
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_MAP_H
11 #define EIGEN_CXX11_TENSOR_TENSOR_MAP_H
12 
13 namespace Eigen {
14 
22 template<typename PlainObjectType, int Options_> class TensorMap : public TensorBase<TensorMap<PlainObjectType, Options_> >
23 {
24  public:
25  typedef TensorMap<PlainObjectType, Options_> Self;
26  typedef typename PlainObjectType::Base Base;
27  typedef typename Eigen::internal::nested<Self>::type Nested;
28  typedef typename internal::traits<PlainObjectType>::StorageKind StorageKind;
29  typedef typename internal::traits<PlainObjectType>::Index Index;
30  typedef typename internal::traits<PlainObjectType>::Scalar Scalar;
31  typedef typename internal::packet_traits<Scalar>::type Packet;
32  typedef typename NumTraits<Scalar>::Real RealScalar;
33  typedef typename Base::CoeffReturnType CoeffReturnType;
34 
35  /* typedef typename internal::conditional<
36  bool(internal::is_lvalue<PlainObjectType>::value),
37  Scalar *,
38  const Scalar *>::type
39  PointerType;*/
40  typedef Scalar* PointerType;
41  typedef PointerType PointerArgType;
42 
43  static const int Options = Options_;
44 
45  static const Index NumIndices = PlainObjectType::NumIndices;
46  typedef typename PlainObjectType::Dimensions Dimensions;
47 
48  enum {
49  IsAligned = ((int(Options_)&Aligned)==Aligned),
50  PacketAccess = (internal::packet_traits<Scalar>::size > 1),
51  Layout = PlainObjectType::Layout,
52  CoordAccess = true,
53  };
54 
55 #ifdef EIGEN_HAS_VARIADIC_TEMPLATES
56  template<typename... IndexTypes> EIGEN_DEVICE_FUNC
57  EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index firstDimension, IndexTypes... otherDimensions) : m_data(dataPtr), m_dimensions(firstDimension, otherDimensions...) {
58  // The number of dimensions used to construct a tensor must be equal to the rank of the tensor.
59  EIGEN_STATIC_ASSERT((sizeof...(otherDimensions) + 1 == NumIndices || NumIndices == Dynamic), YOU_MADE_A_PROGRAMMING_MISTAKE)
60  }
61 #else
62  EIGEN_DEVICE_FUNC
63  EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index firstDimension) : m_data(dataPtr), m_dimensions(firstDimension) {
64  // The number of dimensions used to construct a tensor must be equal to the rank of the tensor.
65  EIGEN_STATIC_ASSERT((1 == NumIndices || NumIndices == Dynamic), YOU_MADE_A_PROGRAMMING_MISTAKE)
66  }
67  EIGEN_DEVICE_FUNC
68  EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index dim1, Index dim2) : m_data(dataPtr), m_dimensions(dim1, dim2) {
69  EIGEN_STATIC_ASSERT(2 == NumIndices || NumIndices == Dynamic, YOU_MADE_A_PROGRAMMING_MISTAKE)
70  }
71  EIGEN_DEVICE_FUNC
72  EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index dim1, Index dim2, Index dim3) : m_data(dataPtr), m_dimensions(dim1, dim2, dim3) {
73  EIGEN_STATIC_ASSERT(3 == NumIndices || NumIndices == Dynamic, YOU_MADE_A_PROGRAMMING_MISTAKE)
74  }
75  EIGEN_DEVICE_FUNC
76  EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index dim1, Index dim2, Index dim3, Index dim4) : m_data(dataPtr), m_dimensions(dim1, dim2, dim3, dim4) {
77  EIGEN_STATIC_ASSERT(4 == NumIndices || NumIndices == Dynamic, YOU_MADE_A_PROGRAMMING_MISTAKE)
78  }
79  EIGEN_DEVICE_FUNC
80  EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index dim1, Index dim2, Index dim3, Index dim4, Index dim5) : m_data(dataPtr), m_dimensions(dim1, dim2, dim3, dim4, dim5) {
81  EIGEN_STATIC_ASSERT(5 == NumIndices || NumIndices == Dynamic, YOU_MADE_A_PROGRAMMING_MISTAKE)
82  }
83 #endif
84 
85  inline TensorMap(PointerArgType dataPtr, const array<Index, NumIndices>& dimensions)
86  : m_data(dataPtr), m_dimensions(dimensions)
87  { }
88 
89  template <typename Dimensions>
90  EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, const Dimensions& dimensions)
91  : m_data(dataPtr), m_dimensions(dimensions)
92  { }
93 
94  EIGEN_DEVICE_FUNC
95  EIGEN_STRONG_INLINE Index rank() const { return m_dimensions.rank(); }
96  EIGEN_DEVICE_FUNC
97  EIGEN_STRONG_INLINE Index dimension(Index n) const { return m_dimensions[n]; }
98  EIGEN_DEVICE_FUNC
99  EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
100  EIGEN_DEVICE_FUNC
101  EIGEN_STRONG_INLINE Index size() const { return m_dimensions.TotalSize(); }
102  EIGEN_DEVICE_FUNC
103  EIGEN_STRONG_INLINE Scalar* data() { return m_data; }
104  EIGEN_DEVICE_FUNC
105  EIGEN_STRONG_INLINE const Scalar* data() const { return m_data; }
106 
107  EIGEN_DEVICE_FUNC
108  EIGEN_STRONG_INLINE const Scalar& operator()(const array<Index, NumIndices>& indices) const
109  {
110  // eigen_assert(checkIndexRange(indices));
111  if (PlainObjectType::Options&RowMajor) {
112  const Index index = m_dimensions.IndexOfRowMajor(indices);
113  return m_data[index];
114  } else {
115  const Index index = m_dimensions.IndexOfColMajor(indices);
116  return m_data[index];
117  }
118  }
119 
120 #ifdef EIGEN_HAS_VARIADIC_TEMPLATES
121  template<typename... IndexTypes> EIGEN_DEVICE_FUNC
122  EIGEN_STRONG_INLINE const Scalar& operator()(Index firstIndex, IndexTypes... otherIndices) const
123  {
124  static_assert(sizeof...(otherIndices) + 1 == NumIndices, "Number of indices used to access a tensor coefficient must be equal to the rank of the tensor.");
125  if (PlainObjectType::Options&RowMajor) {
126  const Index index = m_dimensions.IndexOfRowMajor(array<Index, NumIndices>{{firstIndex, otherIndices...}});
127  return m_data[index];
128  } else {
129  const Index index = m_dimensions.IndexOfColMajor(array<Index, NumIndices>{{firstIndex, otherIndices...}});
130  return m_data[index];
131  }
132  }
133 #else
134  EIGEN_DEVICE_FUNC
135  EIGEN_STRONG_INLINE const Scalar& operator()(Index index) const
136  {
137  eigen_internal_assert(index >= 0 && index < size());
138  return m_data[index];
139  }
140  EIGEN_DEVICE_FUNC
141  EIGEN_STRONG_INLINE const Scalar& operator()(Index i0, Index i1) const
142  {
143  if (PlainObjectType::Options&RowMajor) {
144  const Index index = i1 + i0 * m_dimensions[0];
145  return m_data[index];
146  } else {
147  const Index index = i0 + i1 * m_dimensions[0];
148  return m_data[index];
149  }
150  }
151  EIGEN_DEVICE_FUNC
152  EIGEN_STRONG_INLINE const Scalar& operator()(Index i0, Index i1, Index i2) const
153  {
154  if (PlainObjectType::Options&RowMajor) {
155  const Index index = i2 + m_dimensions[1] * (i1 + m_dimensions[0] * i0);
156  return m_data[index];
157  } else {
158  const Index index = i0 + m_dimensions[0] * (i1 + m_dimensions[1] * i2);
159  return m_data[index];
160  }
161  }
162  EIGEN_DEVICE_FUNC
163  EIGEN_STRONG_INLINE const Scalar& operator()(Index i0, Index i1, Index i2, Index i3) const
164  {
165  if (PlainObjectType::Options&RowMajor) {
166  const Index index = i3 + m_dimensions[3] * (i2 + m_dimensions[2] * (i1 + m_dimensions[1] * i0));
167  return m_data[index];
168  } else {
169  const Index index = i0 + m_dimensions[0] * (i1 + m_dimensions[1] * (i2 + m_dimensions[2] * i3));
170  return m_data[index];
171  }
172  }
173  EIGEN_DEVICE_FUNC
174  EIGEN_STRONG_INLINE const Scalar& operator()(Index i0, Index i1, Index i2, Index i3, Index i4) const
175  {
176  if (PlainObjectType::Options&RowMajor) {
177  const Index index = i4 + m_dimensions[4] * (i3 + m_dimensions[3] * (i2 + m_dimensions[2] * (i1 + m_dimensions[1] * i0)));
178  return m_data[index];
179  } else {
180  const Index index = i0 + m_dimensions[0] * (i1 + m_dimensions[1] * (i2 + m_dimensions[2] * (i3 + m_dimensions[3] * i4)));
181  return m_data[index];
182  }
183  }
184 #endif
185 
186  EIGEN_DEVICE_FUNC
187  EIGEN_STRONG_INLINE Scalar& operator()(const array<Index, NumIndices>& indices)
188  {
189  // eigen_assert(checkIndexRange(indices));
190  if (PlainObjectType::Options&RowMajor) {
191  const Index index = m_dimensions.IndexOfRowMajor(indices);
192  return m_data[index];
193  } else {
194  const Index index = m_dimensions.IndexOfColMajor(indices);
195  return m_data[index];
196  }
197  }
198 
199 #ifdef EIGEN_HAS_VARIADIC_TEMPLATES
200  template<typename... IndexTypes> EIGEN_DEVICE_FUNC
201  EIGEN_STRONG_INLINE Scalar& operator()(Index firstIndex, IndexTypes... otherIndices)
202  {
203  static_assert(sizeof...(otherIndices) + 1 == NumIndices || NumIndices == Dynamic, "Number of indices used to access a tensor coefficient must be equal to the rank of the tensor.");
204  const std::size_t NumDims = sizeof...(otherIndices) + 1;
205  if (PlainObjectType::Options&RowMajor) {
206  const Index index = m_dimensions.IndexOfRowMajor(array<Index, NumDims>{{firstIndex, otherIndices...}});
207  return m_data[index];
208  } else {
209  const Index index = m_dimensions.IndexOfColMajor(array<Index, NumDims>{{firstIndex, otherIndices...}});
210  return m_data[index];
211  }
212  }
213 #else
214  EIGEN_DEVICE_FUNC
215  EIGEN_STRONG_INLINE Scalar& operator()(Index index)
216  {
217  eigen_internal_assert(index >= 0 && index < size());
218  return m_data[index];
219  }
220  EIGEN_DEVICE_FUNC
221  EIGEN_STRONG_INLINE Scalar& operator()(Index i0, Index i1)
222  {
223  if (PlainObjectType::Options&RowMajor) {
224  const Index index = i1 + i0 * m_dimensions[0];
225  return m_data[index];
226  } else {
227  const Index index = i0 + i1 * m_dimensions[0];
228  return m_data[index];
229  }
230  }
231  EIGEN_DEVICE_FUNC
232  EIGEN_STRONG_INLINE Scalar& operator()(Index i0, Index i1, Index i2)
233  {
234  if (PlainObjectType::Options&RowMajor) {
235  const Index index = i2 + m_dimensions[1] * (i1 + m_dimensions[0] * i0);
236  return m_data[index];
237  } else {
238  const Index index = i0 + m_dimensions[0] * (i1 + m_dimensions[1] * i2);
239  return m_data[index];
240  }
241  }
242  EIGEN_DEVICE_FUNC
243  EIGEN_STRONG_INLINE Scalar& operator()(Index i0, Index i1, Index i2, Index i3)
244  {
245  if (PlainObjectType::Options&RowMajor) {
246  const Index index = i3 + m_dimensions[3] * (i2 + m_dimensions[2] * (i1 + m_dimensions[1] * i0));
247  return m_data[index];
248  } else {
249  const Index index = i0 + m_dimensions[0] * (i1 + m_dimensions[1] * (i2 + m_dimensions[2] * i3));
250  return m_data[index];
251  }
252  }
253  EIGEN_DEVICE_FUNC
254  EIGEN_STRONG_INLINE Scalar& operator()(Index i0, Index i1, Index i2, Index i3, Index i4)
255  {
256  if (PlainObjectType::Options&RowMajor) {
257  const Index index = i4 + m_dimensions[4] * (i3 + m_dimensions[3] * (i2 + m_dimensions[2] * (i1 + m_dimensions[1] * i0)));
258  return m_data[index];
259  } else {
260  const Index index = i0 + m_dimensions[0] * (i1 + m_dimensions[1] * (i2 + m_dimensions[2] * (i3 + m_dimensions[3] * i4)));
261  return m_data[index];
262  }
263  }
264 #endif
265 
266  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Self& operator=(const Self& other)
267  {
268  typedef TensorAssignOp<Self, const Self> Assign;
269  Assign assign(*this, other);
270  internal::TensorExecutor<const Assign, DefaultDevice>::run(assign, DefaultDevice());
271  return *this;
272  }
273 
274  template<typename OtherDerived>
275  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
276  Self& operator=(const OtherDerived& other)
277  {
278  typedef TensorAssignOp<Self, const OtherDerived> Assign;
279  Assign assign(*this, other);
280  internal::TensorExecutor<const Assign, DefaultDevice>::run(assign, DefaultDevice());
281  return *this;
282  }
283 
284  private:
285  Scalar* m_data;
286  Dimensions m_dimensions;
287 };
288 
289 } // end namespace Eigen
290 
291 #endif // EIGEN_CXX11_TENSOR_TENSOR_MAP_H
Namespace containing all symbols from the Eigen library.
Definition: CXX11Meta.h:13