TensorDimensions.h
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_DIMENSIONS_H
11 #define EIGEN_CXX11_TENSOR_TENSOR_DIMENSIONS_H
12 
13 
14 namespace Eigen {
15 
32 // Can't use std::pair on cuda devices
33 template <typename Index> struct IndexPair {
34  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE IndexPair() : first(0), second(0) { }
35  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE IndexPair(Index f, Index s) : first(f), second(s) { }
36  Index first;
37  Index second;
38 };
39 
40 // Boilerplate code
41 namespace internal {
42 
43 template<std::size_t n, typename Dimension> struct dget {
44  static const std::size_t value = get<n, Dimension>::value;
45 };
46 
47 
48 template<typename Index, std::size_t NumIndices, std::size_t n, bool RowMajor>
49 struct fixed_size_tensor_index_linearization_helper
50 {
51  template <typename Dimensions> EIGEN_DEVICE_FUNC
52  static inline Index run(array<Index, NumIndices> const& indices,
53  const Dimensions& dimensions)
54  {
55  return array_get<RowMajor ? n : (NumIndices - n - 1)>(indices) +
56  dget<RowMajor ? n : (NumIndices - n - 1), Dimensions>::value *
57  fixed_size_tensor_index_linearization_helper<Index, NumIndices, n - 1, RowMajor>::run(indices, dimensions);
58  }
59 };
60 
61 template<typename Index, std::size_t NumIndices, bool RowMajor>
62 struct fixed_size_tensor_index_linearization_helper<Index, NumIndices, 0, RowMajor>
63 {
64  template <typename Dimensions> EIGEN_DEVICE_FUNC
65  static inline Index run(array<Index, NumIndices> const& indices,
66  const Dimensions&)
67  {
68  return array_get<RowMajor ? 0 : NumIndices - 1>(indices);
69  }
70 };
71 
72 template<typename Index, std::size_t n>
73 struct fixed_size_tensor_index_extraction_helper
74 {
75  template <typename Dimensions> EIGEN_DEVICE_FUNC
76  static inline Index run(const Index index,
77  const Dimensions& dimensions)
78  {
79  const Index mult = (index == n) ? 1 : 0;
80  return array_get<n>(dimensions) * mult +
81  fixed_size_tensor_index_extraction_helper<Index, n - 1>::run(index, dimensions);
82  }
83 };
84 
85 template<typename Index>
86 struct fixed_size_tensor_index_extraction_helper<Index, 0>
87 {
88  template <typename Dimensions> EIGEN_DEVICE_FUNC
89  static inline Index run(const Index index,
90  const Dimensions& dimensions)
91  {
92  const Index mult = (index == 0) ? 1 : 0;
93  return array_get<0>(dimensions) * mult;
94  }
95 };
96 
97 } // end namespace internal
98 
99 
100 // Fixed size
101 #ifndef EIGEN_EMULATE_CXX11_META_H
102 template <typename std::ptrdiff_t... Indices>
103 struct Sizes : internal::numeric_list<std::ptrdiff_t, Indices...> {
104  typedef internal::numeric_list<std::ptrdiff_t, Indices...> Base;
105  static const std::ptrdiff_t total_size = internal::arg_prod(Indices...);
106 
107  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t rank() const {
108  return Base::count;
109  }
110 
111  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t TotalSize() {
112  return internal::arg_prod(Indices...);
113  }
114 
115  Sizes() { }
116  template <typename DenseIndex>
117  explicit Sizes(const array<DenseIndex, Base::count>& /*indices*/) {
118  // todo: add assertion
119  }
120 #ifdef EIGEN_HAS_VARIADIC_TEMPLATES
121  template <typename... DenseIndex> Sizes(DenseIndex...) { }
122  explicit Sizes(std::initializer_list<std::ptrdiff_t> /*l*/) {
123  // todo: add assertion
124  }
125 #endif
126 
127  template <typename T> Sizes& operator = (const T& /*other*/) {
128  // add assertion failure if the size of other is different
129  return *this;
130  }
131 
132  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t operator[] (const std::size_t index) const {
133  return internal::fixed_size_tensor_index_extraction_helper<std::ptrdiff_t, Base::count - 1>::run(index, *this);
134  }
135 
136  template <typename DenseIndex> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
137  size_t IndexOfColMajor(const array<DenseIndex, Base::count>& indices) const {
138  return internal::fixed_size_tensor_index_linearization_helper<DenseIndex, Base::count, Base::count - 1, false>::run(indices, *static_cast<const Base*>(this));
139  }
140  template <typename DenseIndex> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
141  size_t IndexOfRowMajor(const array<DenseIndex, Base::count>& indices) const {
142  return internal::fixed_size_tensor_index_linearization_helper<DenseIndex, Base::count, Base::count - 1, true>::run(indices, *static_cast<const Base*>(this));
143  }
144 };
145 
146 namespace internal {
147 template <typename std::ptrdiff_t... Indices>
148 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t array_prod(const Sizes<Indices...>&) {
149  return Sizes<Indices...>::total_size;
150 }
151 }
152 
153 #else
154 
155 template <std::size_t n>
156 struct non_zero_size {
157  typedef internal::type2val<std::size_t, n> type;
158 };
159 template <>
160 struct non_zero_size<0> {
161  typedef internal::null_type type;
162 };
163 
164 template <std::size_t V1=0, std::size_t V2=0, std::size_t V3=0, std::size_t V4=0, std::size_t V5=0> struct Sizes {
165  typedef typename internal::make_type_list<typename non_zero_size<V1>::type, typename non_zero_size<V2>::type, typename non_zero_size<V3>::type, typename non_zero_size<V4>::type, typename non_zero_size<V5>::type >::type Base;
166  static const size_t count = Base::count;
167  static const std::size_t total_size = internal::arg_prod<Base>::value;
168 
169  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE size_t rank() const {
170  return count;
171  }
172 
173  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE size_t TotalSize() {
174  return internal::arg_prod<Base>::value;
175  }
176 
177  Sizes() { }
178  template <typename DenseIndex>
179  explicit Sizes(const array<DenseIndex, Base::count>& /*indices*/) {
180  // todo: add assertion
181  }
182  template <typename T> Sizes& operator = (const T& /*other*/) {
183  // add assertion failure if the size of other is different
184  return *this;
185  }
186 
187 #ifdef EIGEN_HAS_VARIADIC_TEMPLATES
188  template <typename... DenseIndex> Sizes(DenseIndex... /*indices*/) { }
189  explicit Sizes(std::initializer_list<std::size_t>) {
190  // todo: add assertion
191  }
192 #else
193  EIGEN_DEVICE_FUNC explicit Sizes(const DenseIndex) {
194  }
195  EIGEN_DEVICE_FUNC explicit Sizes(const DenseIndex, const DenseIndex) {
196  }
197  EIGEN_DEVICE_FUNC explicit Sizes(const DenseIndex, const DenseIndex, const DenseIndex) {
198  }
199  EIGEN_DEVICE_FUNC explicit Sizes(const DenseIndex, const DenseIndex, const DenseIndex, const DenseIndex) {
200  }
201  EIGEN_DEVICE_FUNC explicit Sizes(const DenseIndex, const DenseIndex, const DenseIndex, const DenseIndex, const DenseIndex) {
202  }
203 #endif
204 
205  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE DenseIndex operator[] (const int index) const {
206  switch (index) {
207  case 0:
208  return internal::get<0, Base>::value;
209  case 1:
210  return internal::get<1, Base>::value;
211  case 2:
212  return internal::get<2, Base>::value;
213  case 3:
214  return internal::get<3, Base>::value;
215  case 4:
216  return internal::get<4, Base>::value;
217  default:
218  eigen_assert(false && "index overflow");
219  return static_cast<std::size_t>(-1);
220  }
221  }
222 
223  template <typename DenseIndex> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
224  size_t IndexOfColMajor(const array<DenseIndex, Base::count>& indices) const {
225  return internal::fixed_size_tensor_index_linearization_helper<DenseIndex, Base::count, Base::count - 1, false>::run(indices, *reinterpret_cast<const Base*>(this));
226  }
227  template <typename DenseIndex> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
228  size_t IndexOfRowMajor(const array<DenseIndex, Base::count>& indices) const {
229  return internal::fixed_size_tensor_index_linearization_helper<DenseIndex, Base::count, Base::count - 1, true>::run(indices, *reinterpret_cast<const Base*>(this));
230  }
231 };
232 
233 namespace internal {
234 template <std::size_t V1, std::size_t V2, std::size_t V3, std::size_t V4, std::size_t V5>
235 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::size_t array_prod(const Sizes<V1, V2, V3, V4, V5>&) {
236  return Sizes<V1, V2, V3, V4, V5>::total_size;
237 }
238 }
239 
240 #endif
241 
242 // Boilerplate
243 namespace internal {
244 template<typename Index, std::size_t NumIndices, std::size_t n, bool RowMajor>
245 struct tensor_index_linearization_helper
246 {
247  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
248  Index run(array<Index, NumIndices> const& indices, array<Index, NumIndices> const& dimensions)
249  {
250  return array_get<RowMajor ? n : (NumIndices - n - 1)>(indices) +
251  array_get<RowMajor ? n : (NumIndices - n - 1)>(dimensions) *
252  tensor_index_linearization_helper<Index, NumIndices, n - 1, RowMajor>::run(indices, dimensions);
253  }
254 };
255 
256 template<typename Index, std::size_t NumIndices, bool RowMajor>
257 struct tensor_index_linearization_helper<Index, NumIndices, 0, RowMajor>
258 {
259  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
260  Index run(array<Index, NumIndices> const& indices, array<Index, NumIndices> const&)
261  {
262  return array_get<RowMajor ? 0 : NumIndices - 1>(indices);
263  }
264 };
265 } // end namespace internal
266 
267 
268 
269 // Dynamic size
270 template <typename DenseIndex, std::size_t NumDims>
271 struct DSizes : array<DenseIndex, NumDims> {
272  typedef array<DenseIndex, NumDims> Base;
273  static const std::size_t count = NumDims;
274 
275  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE size_t rank() const {
276  return NumDims;
277  }
278 
279  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE DenseIndex TotalSize() const {
280  return internal::array_prod(*static_cast<const Base*>(this));
281  }
282 
283  EIGEN_DEVICE_FUNC DSizes() {
284  for (std::size_t i = 0 ; i < NumDims; ++i) {
285  (*this)[i] = 0;
286  }
287  }
288  EIGEN_DEVICE_FUNC explicit DSizes(const array<DenseIndex, NumDims>& a) : Base(a) { }
289 
290 #ifdef EIGEN_HAS_VARIADIC_TEMPLATES
291  template<typename... IndexTypes> EIGEN_DEVICE_FUNC
292  EIGEN_STRONG_INLINE explicit DSizes(DenseIndex firstDimension, IndexTypes... otherDimensions) {
293  EIGEN_STATIC_ASSERT(sizeof...(otherDimensions) + 1 == NumDims, YOU_MADE_A_PROGRAMMING_MISTAKE)
294  (*this) = array<DenseIndex, NumDims>{{firstDimension, otherDimensions...}};
295  }
296 #else
297  EIGEN_DEVICE_FUNC explicit DSizes(const DenseIndex i0) {
298  eigen_assert(NumDims == 1);
299  (*this)[0] = i0;
300  }
301  EIGEN_DEVICE_FUNC explicit DSizes(const DenseIndex i0, const DenseIndex i1) {
302  eigen_assert(NumDims == 2);
303  (*this)[0] = i0;
304  (*this)[1] = i1;
305  }
306  EIGEN_DEVICE_FUNC explicit DSizes(const DenseIndex i0, const DenseIndex i1, const DenseIndex i2) {
307  eigen_assert(NumDims == 3);
308  (*this)[0] = i0;
309  (*this)[1] = i1;
310  (*this)[2] = i2;
311  }
312  EIGEN_DEVICE_FUNC explicit DSizes(const DenseIndex i0, const DenseIndex i1, const DenseIndex i2, const DenseIndex i3) {
313  eigen_assert(NumDims == 4);
314  (*this)[0] = i0;
315  (*this)[1] = i1;
316  (*this)[2] = i2;
317  (*this)[3] = i3;
318  }
319  EIGEN_DEVICE_FUNC explicit DSizes(const DenseIndex i0, const DenseIndex i1, const DenseIndex i2, const DenseIndex i3, const DenseIndex i4) {
320  eigen_assert(NumDims == 5);
321  (*this)[0] = i0;
322  (*this)[1] = i1;
323  (*this)[2] = i2;
324  (*this)[3] = i3;
325  (*this)[4] = i4;
326  }
327 #endif
328 
329  EIGEN_DEVICE_FUNC DSizes& operator = (const array<DenseIndex, NumDims>& other) {
330  *static_cast<Base*>(this) = other;
331  return *this;
332  }
333 
334  // A constexpr would be so much better here
335  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE DenseIndex IndexOfColMajor(const array<DenseIndex, NumDims>& indices) const {
336  return internal::tensor_index_linearization_helper<DenseIndex, NumDims, NumDims - 1, false>::run(indices, *static_cast<const Base*>(this));
337  }
338  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE DenseIndex IndexOfRowMajor(const array<DenseIndex, NumDims>& indices) const {
339  return internal::tensor_index_linearization_helper<DenseIndex, NumDims, NumDims - 1, true>::run(indices, *static_cast<const Base*>(this));
340  }
341 };
342 
343 
344 
345 
346 // Boilerplate
347 namespace internal {
348 template<typename Index, std::size_t NumIndices, std::size_t n, bool RowMajor>
349 struct tensor_vsize_index_linearization_helper
350 {
351  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
352  Index run(array<Index, NumIndices> const& indices, std::vector<DenseIndex> const& dimensions)
353  {
354  return array_get<RowMajor ? n : (NumIndices - n - 1)>(indices) +
355  array_get<RowMajor ? n : (NumIndices - n - 1)>(dimensions) *
356  tensor_vsize_index_linearization_helper<Index, NumIndices, n - 1, RowMajor>::run(indices, dimensions);
357  }
358 };
359 
360 template<typename Index, std::size_t NumIndices, bool RowMajor>
361 struct tensor_vsize_index_linearization_helper<Index, NumIndices, 0, RowMajor>
362 {
363  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
364  Index run(array<Index, NumIndices> const& indices, std::vector<DenseIndex> const&)
365  {
366  return array_get<RowMajor ? 0 : NumIndices - 1>(indices);
367  }
368 };
369 } // end namespace internal
370 
371 
372 namespace internal {
373 
374 template <typename DenseIndex, std::size_t NumDims> struct array_size<const DSizes<DenseIndex, NumDims> > {
375  static const size_t value = NumDims;
376 };
377 template <typename DenseIndex, std::size_t NumDims> struct array_size<DSizes<DenseIndex, NumDims> > {
378  static const size_t value = NumDims;
379 };
380 #ifndef EIGEN_EMULATE_CXX11_META_H
381 template <typename std::ptrdiff_t... Indices> struct array_size<const Sizes<Indices...> > {
382 static const std::ptrdiff_t value = Sizes<Indices...>::count;
383 };
384 template <typename std::ptrdiff_t... Indices> struct array_size<Sizes<Indices...> > {
385 static const std::ptrdiff_t value = Sizes<Indices...>::count;
386 };
387 template <std::ptrdiff_t n, typename std::ptrdiff_t... Indices> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t array_get(const Sizes<Indices...>&) {
388  return get<n, internal::numeric_list<std::size_t, Indices...> >::value;
389 }
390 #else
391 template <std::size_t V1, std::size_t V2, std::size_t V3, std::size_t V4, std::size_t V5> struct array_size<const Sizes<V1,V2,V3,V4,V5> > {
392  static const size_t value = Sizes<V1,V2,V3,V4,V5>::count;
393 };
394 template <std::size_t V1, std::size_t V2, std::size_t V3, std::size_t V4, std::size_t V5> struct array_size<Sizes<V1,V2,V3,V4,V5> > {
395  static const size_t value = Sizes<V1,V2,V3,V4,V5>::count;
396 };
397 template <std::size_t n, std::size_t V1, std::size_t V2, std::size_t V3, std::size_t V4, std::size_t V5> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::size_t array_get(const Sizes<V1,V2,V3,V4,V5>&) {
398  return get<n, typename Sizes<V1,V2,V3,V4,V5>::Base>::value;
399 }
400 
401 #endif
402 
403 
404 template <typename Dims1, typename Dims2, size_t n, size_t m>
405 struct sizes_match_up_to_dim {
406  static inline bool run(Dims1&, Dims2&) {
407  return false;
408  }
409 };
410 template <typename Dims1, typename Dims2, size_t n>
411 struct sizes_match_up_to_dim<Dims1, Dims2, n, n> {
412  static inline bool run(Dims1& dims1, Dims2& dims2) {
413  return (array_get<n>(dims1) == array_get<n>(dims2)) &
414  sizes_match_up_to_dim<Dims1, Dims2, n-1, n-1>::run(dims1, dims2);
415  }
416 };
417 template <typename Dims1, typename Dims2>
418 struct sizes_match_up_to_dim<Dims1, Dims2, 0, 0> {
419  static inline bool run(Dims1& dims1, Dims2& dims2) {
420  return (array_get<0>(dims1) == array_get<0>(dims2));
421  }
422 };
423 
424 } // end namespace internal
425 
426 
427 template <typename Dims1, typename Dims2>
428 bool dimensions_match(Dims1& dims1, Dims2& dims2) {
429  return internal::sizes_match_up_to_dim<Dims1, Dims2, internal::array_size<Dims1>::value-1, internal::array_size<Dims2>::value-1>::run(dims1, dims2);
430 }
431 
432 } // end namespace Eigen
433 
434 #endif // EIGEN_CXX11_TENSOR_TENSOR_DIMENSIONS_H
Namespace containing all symbols from the Eigen library.
Definition: CXX11Meta.h:13