// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_H
#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_H

namespace Eigen {

namespace internal {

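// Tags used as the `side` template parameter of the mappers below to select whether a mapper
// walks the left or the right operand of the contraction. Note that Rhs is 0 and Lhs is 1.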
enum {
  Rhs = 0,
  Lhs = 1,
};

/*
 * Implementation of the Eigen blas_data_mapper class for tensors.
 */
template<typename Scalar, typename Index, int side,
         typename Tensor,
         typename nocontract_t, typename contract_t,
         int packet_size, bool inner_dim_contiguous>
class BaseTensorContractionMapper {
  public:
  EIGEN_DEVICE_FUNC
  BaseTensorContractionMapper(const Tensor& tensor,
                              const nocontract_t& nocontract_strides,
                              const nocontract_t& ij_strides,
                              const contract_t& contract_strides,
                              const contract_t& k_strides) :
      m_tensor(tensor),
      m_nocontract_strides(nocontract_strides),
      m_ij_strides(ij_strides),
      m_contract_strides(contract_strides),
      m_k_strides(k_strides) { }

  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE void prefetch(Index /*i*/) { }

  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE Scalar operator()(Index row) const {
    // column major assumption
    return operator()(row, 0);
  }

  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE Scalar operator()(Index row, Index col) const {
    return m_tensor.coeff(computeIndex(row, col));
  }

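  // Maps a coordinate of the flattened contraction matrix back to a linear index in the
  // underlying tensor: the non-contracting coordinate (row for the Lhs side, col for the
  // Rhs side) is decomposed along m_ij_strides and the contracting coordinate along
  // m_k_strides, and each piece is rescaled by the corresponding tensor stride.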
  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE Index computeIndex(Index row, Index col) const {
    const bool left = (side == Lhs);
    Index nocontract_val = left ? row : col;
    Index linidx = 0;
    for (int i = static_cast<int>(array_size<nocontract_t>::value) - 1; i > 0; i--) {
      const Index idx = nocontract_val / m_ij_strides[i];
      linidx += idx * m_nocontract_strides[i];
      nocontract_val -= idx * m_ij_strides[i];
    }
    if (array_size<typename Tensor::Dimensions>::value > array_size<contract_t>::value) {
      if (side == Lhs && inner_dim_contiguous) {
        eigen_assert(m_nocontract_strides[0] == 1);
        linidx += nocontract_val;
      } else {
        linidx += nocontract_val * m_nocontract_strides[0];
      }
    }

    Index contract_val = left ? col : row;
    for (int i = static_cast<int>(array_size<contract_t>::value) - 1; i > 0; i--) {
      const Index idx = contract_val / m_k_strides[i];
      linidx += idx * m_contract_strides[i];
      contract_val -= idx * m_k_strides[i];
    }

    if (array_size<contract_t>::value > 0) {
      if (side == Rhs && inner_dim_contiguous) {
        eigen_assert(m_contract_strides[0] == 1);
        linidx += contract_val;
      } else {
        linidx += contract_val * m_contract_strides[0];
      }
    }

    return linidx;
  }

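  // Same decomposition as computeIndex, but performed for two coefficients at once: the one
  // at (row, col) and the one `distance` further along the inner dimension. The resulting pair
  // of linear indices lets loadPacket detect whether a whole packet is contiguous in memory.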
  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE IndexPair<Index> computeIndexPair(Index row, Index col, const Index distance) const {
    const bool left = (side == Lhs);
    Index nocontract_val[2] = {left ? row : col, left ? row + distance : col};
    Index linidx[2] = {0, 0};
    for (int i = static_cast<int>(array_size<nocontract_t>::value) - 1; i > 0; i--) {
      const Index idx0 = nocontract_val[0] / m_ij_strides[i];
      const Index idx1 = nocontract_val[1] / m_ij_strides[i];
      linidx[0] += idx0 * m_nocontract_strides[i];
      linidx[1] += idx1 * m_nocontract_strides[i];
      nocontract_val[0] -= idx0 * m_ij_strides[i];
      nocontract_val[1] -= idx1 * m_ij_strides[i];
    }
    if (array_size<typename Tensor::Dimensions>::value > array_size<contract_t>::value) {
      if (side == Lhs && inner_dim_contiguous) {
        eigen_assert(m_nocontract_strides[0] == 1);
        linidx[0] += nocontract_val[0];
        linidx[1] += nocontract_val[1];
      } else {
        linidx[0] += nocontract_val[0] * m_nocontract_strides[0];
        linidx[1] += nocontract_val[1] * m_nocontract_strides[0];
      }
    }

    Index contract_val[2] = {left ? col : row, left ? col : row + distance};
    for (int i = static_cast<int>(array_size<contract_t>::value) - 1; i > 0; i--) {
      const Index idx0 = contract_val[0] / m_k_strides[i];
      const Index idx1 = contract_val[1] / m_k_strides[i];
      linidx[0] += idx0 * m_contract_strides[i];
      linidx[1] += idx1 * m_contract_strides[i];
      contract_val[0] -= idx0 * m_k_strides[i];
      contract_val[1] -= idx1 * m_k_strides[i];
    }

    if (side == Rhs && inner_dim_contiguous) {
      eigen_assert(m_contract_strides[0] == 1);
      linidx[0] += contract_val[0];
      linidx[1] += contract_val[1];
    } else {
      linidx[0] += contract_val[0] * m_contract_strides[0];
      linidx[1] += contract_val[1] * m_contract_strides[0];
    }
    return IndexPair<Index>(linidx[0], linidx[1]);
  }

  Index firstAligned(Index size) const {
    return size;
  }
  Index stride() const {
    return 1;
  }

  protected:
  const Tensor m_tensor;
  const nocontract_t m_nocontract_strides;
  const nocontract_t m_ij_strides;
  const contract_t m_contract_strides;
  const contract_t m_k_strides;
};


template<typename Scalar, typename Index, int side,
         typename Tensor,
         typename nocontract_t, typename contract_t,
         int packet_size,
         bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
class TensorContractionInputMapper;

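// A view of a parent TensorContractionInputMapper shifted by a fixed (vertical, horizontal)
// offset. The GEBP packing routines traverse the operands through these sub-mappers, so all
// accesses below simply forward to the parent mapper with the offsets added.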
template<typename Scalar, typename Index, int side,
         typename Tensor,
         typename nocontract_t, typename contract_t,
         int packet_size,
         bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
class TensorContractionSubMapper {
  public:
  typedef typename packet_traits<Scalar>::type Packet;
  typedef typename packet_traits<Scalar>::half HalfPacket;

  typedef TensorContractionInputMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> ParentMapper;
  typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> Self;
  typedef Self LinearMapper;

  EIGEN_DEVICE_FUNC TensorContractionSubMapper(const ParentMapper& base_mapper, Index vert_offset, Index horiz_offset)
      : m_base_mapper(base_mapper), m_vert_offset(vert_offset), m_horiz_offset(horiz_offset) { }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const {
    return m_base_mapper(i + m_vert_offset, m_horiz_offset);
  }
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i, Index j) const {
    return m_base_mapper(i + m_vert_offset, j + m_horiz_offset);
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const {
    return m_base_mapper.loadPacket(i + m_vert_offset, m_horiz_offset);
  }
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i, Index j) const {
    return m_base_mapper.loadPacket(i + m_vert_offset, j + m_horiz_offset);
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE HalfPacket loadHalfPacket(Index i) const {
    return m_base_mapper.loadHalfPacket(i + m_vert_offset, m_horiz_offset);
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacket(Index i, Packet p) const {
    m_base_mapper.storePacket(i + m_vert_offset, m_horiz_offset, p);
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
    return LinearMapper(m_base_mapper, i + m_vert_offset, j + m_horiz_offset);
  }

  template <typename PacketT, int AlignmentType>
  EIGEN_ALWAYS_INLINE PacketT load(Index i) const {
    EIGEN_STATIC_ASSERT((internal::is_same<PacketT, Packet>::value), YOU_MADE_A_PROGRAMMING_MISTAKE);
    EIGEN_STATIC_ASSERT((AlignmentType == Aligned || Alignment == Unaligned), YOU_MADE_A_PROGRAMMING_MISTAKE);
    return loadPacket(i);
  }

  template <typename Packet>
  bool aligned(Index /*i*/) const {
    return false;
  }

  private:
  const ParentMapper& m_base_mapper;
  const Index m_vert_offset;
  const Index m_horiz_offset;
};

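// Presents a tensor evaluator as a two-dimensional blas_data_mapper so that Eigen's GEMM and
// GEMV kernels can consume it. Packet loads are forwarded to the tensor when the inner
// dimension is contiguous and not reordered; otherwise coefficients are gathered one by one.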
template<typename Scalar, typename Index, int side,
         typename Tensor,
         typename nocontract_t, typename contract_t,
         int packet_size = (Tensor::PacketAccess ? packet_traits<Scalar>::size : 1),
         bool inner_dim_contiguous = false, bool inner_dim_reordered = (side != Lhs), int Alignment = Unaligned>
class TensorContractionInputMapper
    : public BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous> {

  public:
  typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous> Base;
  typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> SubMapper;
  typedef SubMapper VectorMapper;

  TensorContractionInputMapper(const Tensor& tensor,
                               const nocontract_t& nocontract_strides,
                               const nocontract_t& ij_strides,
                               const contract_t& contract_strides,
                               const contract_t& k_strides)
      : Base(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { }

  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE SubMapper getSubMapper(Index i, Index j) const {
    return SubMapper(*this, i, j);
  }

  EIGEN_ALWAYS_INLINE VectorMapper getVectorMapper(Index i, Index j) const {
    return VectorMapper(*this, i, j);
  }

  typedef typename packet_traits<Scalar>::type Packet;
  typedef typename packet_traits<Scalar>::half HalfPacket;

  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE Packet loadPacket(Index i, Index j) const {
    // whole method makes column major assumption

    // don't need to add offsets for now (because operator handles that)
    // current code assumes packet size must be a multiple of 2
    EIGEN_STATIC_ASSERT(packet_size % 2 == 0, YOU_MADE_A_PROGRAMMING_MISTAKE);

    if (Tensor::PacketAccess && inner_dim_contiguous && !inner_dim_reordered) {
      const Index index = this->computeIndex(i, j);
      eigen_assert(this->computeIndex(i+packet_size-1, j) == index + packet_size-1);
      return this->m_tensor.template packet<Alignment>(index);
    }

    const IndexPair<Index> indexPair = this->computeIndexPair(i, j, packet_size - 1);
    const Index first = indexPair.first;
    const Index last = indexPair.second;

    // We can always do optimized packet reads from left hand side right now, because
    // the vertical matrix dimension on the left hand side is never contracting.
    // On the right hand side we need to check if the contracting dimensions may have
    // been shuffled first.
    if (Tensor::PacketAccess &&
        (side == Lhs || internal::array_size<contract_t>::value <= 1 || !inner_dim_reordered) &&
        (last - first) == (packet_size - 1)) {

      return this->m_tensor.template packet<Alignment>(first);
    }

    EIGEN_ALIGN_MAX Scalar data[packet_size];

    data[0] = this->m_tensor.coeff(first);
    for (Index k = 1; k < packet_size - 1; k += 2) {
      const IndexPair<Index> internal_pair = this->computeIndexPair(i + k, j, 1);
      data[k] = this->m_tensor.coeff(internal_pair.first);
      data[k + 1] = this->m_tensor.coeff(internal_pair.second);
    }
    data[packet_size - 1] = this->m_tensor.coeff(last);

    return pload<Packet>(data);
  }

  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE HalfPacket loadHalfPacket(Index i, Index j) const {
    // whole method makes column major assumption

    // don't need to add offsets for now (because operator handles that)
    const Index half_packet_size = unpacket_traits<HalfPacket>::size;
    if (half_packet_size == packet_size) {
      return loadPacket(i, j);
    }
    EIGEN_ALIGN_MAX Scalar data[half_packet_size];
    for (Index k = 0; k < half_packet_size; k++) {
      data[k] = operator()(i + k, j);
    }
    return pload<HalfPacket>(data);
  }
};

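// Specialization for packet_size == 1: no vectorization is possible, so packet loads degrade
// to a single coefficient read.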
template<typename Scalar, typename Index, int side,
         typename Tensor,
         typename nocontract_t, typename contract_t,
         bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
class TensorContractionInputMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, Alignment>
    : public BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous> {

  public:
  typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous> Base;
  typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, Alignment> SubMapper;
  typedef SubMapper VectorMapper;

  TensorContractionInputMapper(const Tensor& tensor,
                               const nocontract_t& nocontract_strides,
                               const nocontract_t& ij_strides,
                               const contract_t& contract_strides,
                               const contract_t& k_strides)
      : Base(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { }

  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE SubMapper getSubMapper(Index i, Index j) const {
    return SubMapper(*this, i, j);
  }

  EIGEN_ALWAYS_INLINE VectorMapper getVectorMapper(Index i, Index j) const {
    return VectorMapper(*this, i, j);
  }

  typedef typename packet_traits<Scalar>::type Packet;
  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE Packet loadPacket(Index i, Index j) const {
    EIGEN_ALIGN_MAX Scalar data[1];
    data[0] = this->m_tensor.coeff(this->computeIndex(i, j));
    return pload<typename packet_traits<Scalar>::type>(data);
  }
  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE Packet loadHalfPacket(Index i, Index j) const {
    return loadPacket(i, j);
  }
};

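// The traits, eval and nested specializations below register TensorContractionOp with Eigen's
// expression-template machinery (scalar/index type promotion, nesting behaviour, and the
// number of output dimensions).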
template<typename Dimensions, typename LhsXprType, typename RhsXprType>
struct traits<TensorContractionOp<Dimensions, LhsXprType, RhsXprType> >
{
  // Type promotion to handle the case where the types of the lhs and the rhs are different.
  typedef typename internal::promote_storage_type<typename LhsXprType::Scalar,
                                                   typename RhsXprType::Scalar>::ret Scalar;
  typedef typename internal::packet_traits<Scalar>::type Packet;
  typedef typename promote_storage_type<typename traits<LhsXprType>::StorageKind,
                                         typename traits<RhsXprType>::StorageKind>::ret StorageKind;
  typedef typename promote_index_type<typename traits<LhsXprType>::Index,
                                      typename traits<RhsXprType>::Index>::type Index;
  typedef typename LhsXprType::Nested LhsNested;
  typedef typename RhsXprType::Nested RhsNested;
  typedef typename remove_reference<LhsNested>::type _LhsNested;
  typedef typename remove_reference<RhsNested>::type _RhsNested;

  // From NumDims below.
  static const int NumDimensions = max_n_1<traits<LhsXprType>::NumDimensions + traits<RhsXprType>::NumDimensions - 2 * array_size<Dimensions>::value>::size;
  static const int Layout = traits<LhsXprType>::Layout;

  enum {
    Flags = 0,
  };
};

template<typename Dimensions, typename LhsXprType, typename RhsXprType>
struct eval<TensorContractionOp<Dimensions, LhsXprType, RhsXprType>, Eigen::Dense>
{
  typedef const TensorContractionOp<Dimensions, LhsXprType, RhsXprType>& type;
};

template<typename Dimensions, typename LhsXprType, typename RhsXprType>
struct nested<TensorContractionOp<Dimensions, LhsXprType, RhsXprType>, 1, typename eval<TensorContractionOp<Dimensions, LhsXprType, RhsXprType> >::type>
{
  typedef TensorContractionOp<Dimensions, LhsXprType, RhsXprType> type;
};

template<typename Indices_, typename LeftArgType_, typename RightArgType_, typename Device_>
struct traits<TensorEvaluator<const TensorContractionOp<Indices_, LeftArgType_, RightArgType_>, Device_> > {
  typedef Indices_ Indices;
  typedef LeftArgType_ LeftArgType;
  typedef RightArgType_ RightArgType;
  typedef Device_ Device;

  // From NumDims below.
  static const int NumDimensions = max_n_1<traits<LeftArgType_>::NumDimensions + traits<RightArgType_>::NumDimensions - 2 * array_size<Indices_>::value>::size;
};

} // end namespace internal

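// TensorContractionOp represents the contraction of two tensor expressions along the index
// pairs given in `Indices`. It is normally not built directly but obtained through
// TensorBase::contract(). A minimal usage sketch (the tensor names and sizes are purely
// illustrative, not part of this file):
//
//   Eigen::Tensor<float, 3> A(2, 3, 4);
//   Eigen::Tensor<float, 2> B(3, 5);
//   // contract dimension 1 of A (size 3) with dimension 0 of B (size 3)
//   Eigen::array<Eigen::IndexPair<int>, 1> dims = {{Eigen::IndexPair<int>(1, 0)}};
//   Eigen::Tensor<float, 3> C = A.contract(B, dims);  // C has dimensions (2, 4, 5)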
template<typename Indices, typename LhsXprType, typename RhsXprType>
class TensorContractionOp : public TensorBase<TensorContractionOp<Indices, LhsXprType, RhsXprType>, ReadOnlyAccessors>
{
  public:
  typedef typename Eigen::internal::traits<TensorContractionOp>::Scalar Scalar;
  typedef typename Eigen::internal::traits<TensorContractionOp>::Packet Packet;
  typedef typename internal::promote_storage_type<typename LhsXprType::CoeffReturnType,
                                                   typename RhsXprType::CoeffReturnType>::ret CoeffReturnType;
  typedef typename internal::promote_storage_type<typename LhsXprType::PacketReturnType,
                                                   typename RhsXprType::PacketReturnType>::ret PacketReturnType;
  typedef typename Eigen::internal::nested<TensorContractionOp>::type Nested;
  typedef typename Eigen::internal::traits<TensorContractionOp>::StorageKind StorageKind;
  typedef typename Eigen::internal::traits<TensorContractionOp>::Index Index;

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionOp(
      const LhsXprType& lhs, const RhsXprType& rhs, const Indices& dims)
      : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_indices(dims) {}

  EIGEN_DEVICE_FUNC
  const Indices& indices() const { return m_indices; }

  EIGEN_DEVICE_FUNC
  const typename internal::remove_all<typename LhsXprType::Nested>::type&
  lhsExpression() const { return m_lhs_xpr; }

  EIGEN_DEVICE_FUNC
  const typename internal::remove_all<typename RhsXprType::Nested>::type&
  rhsExpression() const { return m_rhs_xpr; }

  protected:
  typename LhsXprType::Nested m_lhs_xpr;
  typename RhsXprType::Nested m_rhs_xpr;
  const Indices m_indices;
};

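// CRTP base shared by all contraction evaluators. It normalizes the operands to a ColMajor
// view, computes the output dimensions and the strides needed by the input mappers, and
// provides the GEMV fallback; the derived, device-specific evaluator supplies evalProduct().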
template<typename Derived>
struct TensorContractionEvaluatorBase
{
  typedef typename internal::traits<Derived>::Indices Indices;
  typedef typename internal::traits<Derived>::LeftArgType LeftArgType;
  typedef typename internal::traits<Derived>::RightArgType RightArgType;
  typedef typename internal::traits<Derived>::Device Device;

  typedef TensorContractionOp<Indices, LeftArgType, RightArgType> XprType;
  typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
  typedef typename XprType::Packet Packet;
  typedef typename XprType::Index Index;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename XprType::PacketReturnType PacketReturnType;

  enum {
    IsAligned = true,
    PacketAccess = (internal::packet_traits<Scalar>::size > 1),
    Layout = TensorEvaluator<LeftArgType, Device>::Layout,
    CoordAccess = false,  // to be implemented
  };

  // Most of the code is assuming that both input tensors are ColMajor. If the
  // inputs are RowMajor, we will "cheat" by swapping the LHS and RHS:
  // If we want to compute A * B = C, where A is LHS and B is RHS, the code
  // will pretend B is LHS and A is RHS.
  typedef typename internal::conditional<
      static_cast<int>(Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType>::type EvalLeftArgType;
  typedef typename internal::conditional<
      static_cast<int>(Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType>::type EvalRightArgType;

  static const int LDims =
      internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
  static const int RDims =
      internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;
  static const unsigned int ContractDims = internal::array_size<Indices>::value;
  static const int NumDims = max_n_1<LDims + RDims - 2 * ContractDims>::size;

  typedef array<Index, LDims> left_dim_mapper_t;
  typedef array<Index, RDims> right_dim_mapper_t;
  typedef array<Index, ContractDims> contract_t;
  typedef array<Index, max_n_1<LDims - ContractDims>::size> left_nocontract_t;
  typedef array<Index, max_n_1<RDims - ContractDims>::size> right_nocontract_t;

  typedef DSizes<Index, NumDims> Dimensions;

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  TensorContractionEvaluatorBase(const XprType& op, const Device& device)
      : m_leftImpl(choose(Cond<static_cast<int>(Layout) == static_cast<int>(ColMajor)>(),
                          op.lhsExpression(), op.rhsExpression()), device),
        m_rightImpl(choose(Cond<static_cast<int>(Layout) == static_cast<int>(ColMajor)>(),
                           op.rhsExpression(), op.lhsExpression()), device),
        m_device(device),
        m_result(NULL) {
    EIGEN_STATIC_ASSERT((static_cast<int>(TensorEvaluator<LeftArgType, Device>::Layout) ==
                         static_cast<int>(TensorEvaluator<RightArgType, Device>::Layout)),
                        YOU_MADE_A_PROGRAMMING_MISTAKE);

    DSizes<Index, LDims> eval_left_dims;
    DSizes<Index, RDims> eval_right_dims;
    array<IndexPair<Index>, ContractDims> eval_op_indices;
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      // For ColMajor, we keep using the existing dimensions
      for (int i = 0; i < LDims; i++) {
        eval_left_dims[i] = m_leftImpl.dimensions()[i];
      }
      for (int i = 0; i < RDims; i++) {
        eval_right_dims[i] = m_rightImpl.dimensions()[i];
      }
      // We keep the pairs of contracting indices.
      for (unsigned int i = 0; i < ContractDims; i++) {
        eval_op_indices[i].first = op.indices()[i].first;
        eval_op_indices[i].second = op.indices()[i].second;
      }
    } else {
      // For RowMajor, we need to reverse the existing dimensions
      for (int i = 0; i < LDims; i++) {
        eval_left_dims[i] = m_leftImpl.dimensions()[LDims - i - 1];
      }
      for (int i = 0; i < RDims; i++) {
        eval_right_dims[i] = m_rightImpl.dimensions()[RDims - i - 1];
      }
      // We need to flip all the pairs of contracting indices as well as
      // reversing the dimensions.
      for (unsigned int i = 0; i < ContractDims; i++) {
        eval_op_indices[i].first = LDims - 1 - op.indices()[i].second;
        eval_op_indices[i].second = RDims - 1 - op.indices()[i].first;
      }
    }

    array<Index, LDims> lhs_strides;
    lhs_strides[0] = 1;
    for (int i = 0; i < LDims-1; ++i) {
      lhs_strides[i+1] = lhs_strides[i] * eval_left_dims[i];
    }

    array<Index, RDims> rhs_strides;
    rhs_strides[0] = 1;
    for (int i = 0; i < RDims-1; ++i) {
      rhs_strides[i+1] = rhs_strides[i] * eval_right_dims[i];
    }

    m_i_strides[0] = 1;
    m_j_strides[0] = 1;
    if (ContractDims) {
      m_k_strides[0] = 1;
    }

    m_i_size = 1;
    m_j_size = 1;
    m_k_size = 1;

    // To compute the dimension, we simply concatenate the non-contracting
    // dimensions of the left and then the right tensor. Additionally, we also
    // compute the strides corresponding to the left non-contracting
    // dimensions and right non-contracting dimensions.
    m_lhs_inner_dim_contiguous = true;
    int dim_idx = 0;
    unsigned int nocontract_idx = 0;

    for (int i = 0; i < LDims; i++) {
      // find if we are contracting on index i of left tensor
      bool contracting = false;
      for (unsigned int j = 0; j < ContractDims; j++) {
        if (eval_op_indices[j].first == i) {
          contracting = true;
          break;
        }
      }
      if (!contracting) {
        // add dimension size to output dimensions
        m_dimensions[dim_idx] = eval_left_dims[i];
        m_left_nocontract_strides[nocontract_idx] = lhs_strides[i];
        if (dim_idx != i) {
          m_lhs_inner_dim_contiguous = false;
        }
        if (nocontract_idx+1 < internal::array_size<left_nocontract_t>::value) {
          m_i_strides[nocontract_idx+1] =
              m_i_strides[nocontract_idx] * eval_left_dims[i];
        } else {
          m_i_size = m_i_strides[nocontract_idx] * eval_left_dims[i];
        }
        dim_idx++;
        nocontract_idx++;
      }
    }

    nocontract_idx = 0;
    for (int i = 0; i < RDims; i++) {
      bool contracting = false;
      // find if we are contracting on index i of right tensor
      for (unsigned int j = 0; j < ContractDims; j++) {
        if (eval_op_indices[j].second == i) {
          contracting = true;
          break;
        }
      }
      if (!contracting) {
        m_dimensions[dim_idx] = eval_right_dims[i];
        if (nocontract_idx+1 < internal::array_size<right_nocontract_t>::value) {
          m_j_strides[nocontract_idx+1] =
              m_j_strides[nocontract_idx] * eval_right_dims[i];
        } else {
          m_j_size = m_j_strides[nocontract_idx] * eval_right_dims[i];
        }
        m_right_nocontract_strides[nocontract_idx] = rhs_strides[i];
        dim_idx++;
        nocontract_idx++;
      }
    }

    // Now compute the strides corresponding to the contracting dimensions. We
    // assumed above that non-contracting axes are represented in the same order
    // in the matrix as they are in the tensor. This is not the case for
    // contracting axes. As the contracting axes must be of the same size in
    // each tensor, we'll only look at the first tensor here.
    m_rhs_inner_dim_contiguous = true;
    m_rhs_inner_dim_reordered = false;
    for (unsigned int i = 0; i < ContractDims; i++) {
      Index left = eval_op_indices[i].first;
      Index right = eval_op_indices[i].second;

      Index size = eval_left_dims[left];
      eigen_assert(size == eval_right_dims[right] &&
                   "Contraction axes must be same size");

      if (i+1 < internal::array_size<contract_t>::value) {
        m_k_strides[i+1] = m_k_strides[i] * size;
      } else {
        m_k_size = m_k_strides[i] * size;
      }
      m_left_contracting_strides[i] = lhs_strides[left];
      m_right_contracting_strides[i] = rhs_strides[right];

      if (i > 0 && right < eval_op_indices[i-1].second) {
        m_rhs_inner_dim_reordered = true;
      }
      if (right != i) {
        m_rhs_inner_dim_contiguous = false;
      }
    }

    // Scalar case. We represent the result as a 1d tensor of size 1.
    if (LDims + RDims == 2 * ContractDims) {
      m_dimensions[0] = 1;
    }

    // If the layout is RowMajor, we need to reverse the m_dimensions
    if (static_cast<int>(Layout) == static_cast<int>(RowMajor)) {
      for (int i = 0, j = NumDims - 1; i < j; i++, j--) {
        numext::swap(m_dimensions[i], m_dimensions[j]);
      }
    }
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* data) {
    m_leftImpl.evalSubExprsIfNeeded(NULL);
    m_rightImpl.evalSubExprsIfNeeded(NULL);
    if (data) {
      evalTo(data);
      return false;
    } else {
      m_result = static_cast<Scalar *>(m_device.allocate(dimensions().TotalSize() * sizeof(Scalar)));
      evalTo(m_result);
      return true;
    }
  }

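  // Dispatches to evalProduct with the contiguity/reordering properties of the operands baked
  // in as template parameters, so the input mappers can be specialized at compile time.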
  EIGEN_DEVICE_FUNC void evalTo(Scalar* buffer) const {
    if (this->m_lhs_inner_dim_contiguous) {
      if (this->m_rhs_inner_dim_contiguous) {
        if (this->m_rhs_inner_dim_reordered) {
          static_cast<const Derived*>(this)->template evalProduct<true, true, true, Unaligned>(buffer);
        }
        else {
          static_cast<const Derived*>(this)->template evalProduct<true, true, false, Unaligned>(buffer);
        }
      }
      else {
        if (this->m_rhs_inner_dim_reordered) {
          static_cast<const Derived*>(this)->template evalProduct<true, false, true, Unaligned>(buffer);
        }
        else {
          static_cast<const Derived*>(this)->template evalProduct<true, false, false, Unaligned>(buffer);
        }
      }
    }
    else {
      if (this->m_rhs_inner_dim_contiguous) {
        if (this->m_rhs_inner_dim_reordered) {
          static_cast<const Derived*>(this)->template evalProduct<false, true, true, Unaligned>(buffer);
        }
        else {
          static_cast<const Derived*>(this)->template evalProduct<false, true, false, Unaligned>(buffer);
        }
      }
      else {
        if (this->m_rhs_inner_dim_reordered) {
          static_cast<const Derived*>(this)->template evalProduct<false, false, true, Unaligned>(buffer);
        }
        else {
          static_cast<const Derived*>(this)->template evalProduct<false, false, false, Unaligned>(buffer);
        }
      }
    }
  }

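  // Used when the output has a single column (m_j_size == 1): the contraction then reduces to
  // a matrix-vector product, delegated to Eigen's general_matrix_vector_product kernel.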
  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
  void evalGemv(Scalar* buffer) const {
    const Index rows = m_i_size;
    const Index cols = m_k_size;

    typedef typename internal::remove_const<typename EvalLeftArgType::Scalar>::type LhsScalar;
    typedef typename internal::remove_const<typename EvalRightArgType::Scalar>::type RhsScalar;
    typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
    typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;
    const Index lhs_packet_size = internal::packet_traits<LhsScalar>::size;
    const Index rhs_packet_size = internal::packet_traits<RhsScalar>::size;
    typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs,
                                                   LeftEvaluator, left_nocontract_t,
                                                   contract_t, lhs_packet_size,
                                                   lhs_inner_dim_contiguous,
                                                   false, Unaligned> LhsMapper;

    typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs,
                                                   RightEvaluator, right_nocontract_t,
                                                   contract_t, rhs_packet_size,
                                                   rhs_inner_dim_contiguous,
                                                   rhs_inner_dim_reordered, Unaligned> RhsMapper;

    LhsMapper lhs(m_leftImpl, m_left_nocontract_strides, m_i_strides,
                  m_left_contracting_strides, m_k_strides);
    RhsMapper rhs(m_rightImpl, m_right_nocontract_strides, m_j_strides,
                  m_right_contracting_strides, m_k_strides);

    const Scalar alpha(1);
    const Index resIncr(1);

    // zero out the result buffer (which must be of size at least rows * sizeof(Scalar))
    m_device.memset(buffer, 0, rows * sizeof(Scalar));

    internal::general_matrix_vector_product<Index,LhsScalar,LhsMapper,ColMajor,false,RhsScalar,RhsMapper,false>::run(
        rows, cols, lhs, rhs,
        buffer, resIncr, alpha);
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
    m_leftImpl.cleanup();
    m_rightImpl.cleanup();

    if (m_result != NULL) {
      m_device.deallocate(m_result);
      m_result = NULL;
    }
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
    return m_result[index];
  }

  template<int LoadMode>
  EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const {
    return internal::ploadt<Packet, LoadMode>(m_result + index);
  }

  EIGEN_DEVICE_FUNC Scalar* data() const { return NULL; }

  protected:
  // Prevent assignment
  TensorContractionEvaluatorBase& operator = (const TensorContractionEvaluatorBase&);
  Dimensions m_dimensions;

  contract_t m_k_strides;
  contract_t m_left_contracting_strides;
  contract_t m_right_contracting_strides;

  bool m_lhs_inner_dim_contiguous;
  bool m_rhs_inner_dim_contiguous;
  bool m_rhs_inner_dim_reordered;

  left_nocontract_t m_i_strides;
  right_nocontract_t m_j_strides;
  left_nocontract_t m_left_nocontract_strides;
  right_nocontract_t m_right_nocontract_strides;

  Index m_i_size;
  Index m_j_size;
  Index m_k_size;

  TensorEvaluator<EvalLeftArgType, Device> m_leftImpl;
  TensorEvaluator<EvalRightArgType, Device> m_rightImpl;
  const Device& m_device;
  Scalar* m_result;
};


// evaluator for default device
template<typename Indices, typename LeftArgType, typename RightArgType, typename Device>
struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> :
    public TensorContractionEvaluatorBase<
        TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> > {
  typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> Self;
  typedef TensorContractionEvaluatorBase<Self> Base;

  typedef TensorContractionOp<Indices, LeftArgType, RightArgType> XprType;
  typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
  typedef typename XprType::Packet Packet;
  typedef typename XprType::Index Index;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename XprType::PacketReturnType PacketReturnType;

  enum {
    Layout = TensorEvaluator<LeftArgType, Device>::Layout,
  };

  // Most of the code is assuming that both input tensors are ColMajor. If the
  // inputs are RowMajor, we will "cheat" by swapping the LHS and RHS:
  // If we want to compute A * B = C, where A is LHS and B is RHS, the code
  // will pretend B is LHS and A is RHS.
  typedef typename internal::conditional<
      static_cast<int>(Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType>::type EvalLeftArgType;
  typedef typename internal::conditional<
      static_cast<int>(Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType>::type EvalRightArgType;

  static const int LDims =
      internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
  static const int RDims =
      internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;
  static const int ContractDims = internal::array_size<Indices>::value;

  typedef array<Index, LDims> left_dim_mapper_t;
  typedef array<Index, RDims> right_dim_mapper_t;

  typedef array<Index, ContractDims> contract_t;
  typedef array<Index, max_n_1<LDims - ContractDims>::size> left_nocontract_t;
  typedef array<Index, max_n_1<RDims - ContractDims>::size> right_nocontract_t;

  static const int NumDims = max_n_1<LDims + RDims - 2 * ContractDims>::size;

  // Could we use NumDimensions here?
  typedef DSizes<Index, NumDims> Dimensions;


  EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device) :
      Base(op, device) { }

  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
  void evalProduct(Scalar* buffer) const {
    if (this->m_j_size == 1) {
      this->template evalGemv<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer);
      return;
    }

    evalGemm<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer);
  }

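  // General case: a blocked GEMM following Eigen's usual scheme (see the Goto paper referenced
  // below). Panels of the mapped operands are packed into contiguous buffers and multiplied
  // with the GEBP micro-kernel.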
  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
  EIGEN_DEVICE_FUNC void evalGemm(Scalar* buffer) const {
    // columns in left side, rows in right side
    const Index k = this->m_k_size;

    // rows in left side
    const Index m = this->m_i_size;

    // columns in right side
    const Index n = this->m_j_size;

    // zero out the result buffer (which must be of size at least m * n * sizeof(Scalar))
    this->m_device.memset(buffer, 0, m * n * sizeof(Scalar));

    // define mr, nr, and all of my data mapper types
    typedef typename internal::remove_const<typename EvalLeftArgType::Scalar>::type LhsScalar;
    typedef typename internal::remove_const<typename EvalRightArgType::Scalar>::type RhsScalar;
    typedef typename internal::gebp_traits<LhsScalar, RhsScalar> Traits;

    const Index nr = Traits::nr;
    const Index mr = Traits::mr;

    typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
    typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;

    const Index lhs_packet_size = internal::packet_traits<LhsScalar>::size;
    const Index rhs_packet_size = internal::packet_traits<RhsScalar>::size;

    typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs,
                                                   LeftEvaluator, left_nocontract_t,
                                                   contract_t, lhs_packet_size,
                                                   lhs_inner_dim_contiguous,
                                                   false, Unaligned> LhsMapper;

    typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs,
                                                   RightEvaluator, right_nocontract_t,
                                                   contract_t, rhs_packet_size,
                                                   rhs_inner_dim_contiguous,
                                                   rhs_inner_dim_reordered, Unaligned> RhsMapper;

    typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;

    // Declare GEBP packing and kernel structs
    internal::gemm_pack_lhs<LhsScalar, Index, typename LhsMapper::SubMapper, mr, Traits::LhsProgress, ColMajor> pack_lhs;
    internal::gemm_pack_rhs<RhsScalar, Index, typename RhsMapper::SubMapper, nr, ColMajor> pack_rhs;

    internal::gebp_kernel<LhsScalar, RhsScalar, Index, OutputMapper, mr, nr, false, false> gebp;

    // initialize data mappers
    LhsMapper lhs(this->m_leftImpl, this->m_left_nocontract_strides, this->m_i_strides,
                  this->m_left_contracting_strides, this->m_k_strides);

    RhsMapper rhs(this->m_rightImpl, this->m_right_nocontract_strides, this->m_j_strides,
                  this->m_right_contracting_strides, this->m_k_strides);

    OutputMapper output(buffer, m);

    typedef typename internal::gemm_blocking_space<ColMajor, LhsScalar, RhsScalar, Dynamic, Dynamic, Dynamic> BlockingType;

    // Sizes of the blocks to load in cache. See the Goto paper for details.
    BlockingType blocking(m, n, k, 1, true);
    const Index kc = blocking.kc();
    const Index mc = numext::mini(m, blocking.mc());
    const Index nc = numext::mini(n, blocking.nc());
    const Index sizeA = mc * kc;
    const Index sizeB = kc * nc;

    LhsScalar* blockA = static_cast<LhsScalar *>(this->m_device.allocate(sizeA * sizeof(LhsScalar)));
    RhsScalar* blockB = static_cast<RhsScalar *>(this->m_device.allocate(sizeB * sizeof(RhsScalar)));

    for (Index i2 = 0; i2 < m; i2 += mc) {
      const Index actual_mc = numext::mini(i2 + mc, m) - i2;
      for (Index k2 = 0; k2 < k; k2 += kc) {
        // make sure we don't overshoot right edge of left matrix, then pack vertical panel
        const Index actual_kc = numext::mini(k2 + kc, k) - k2;
        pack_lhs(blockA, lhs.getSubMapper(i2, k2), actual_kc, actual_mc, 0, 0);

        // series of horizontal blocks
        for (Index j2 = 0; j2 < n; j2 += nc) {
          // make sure we don't overshoot right edge of right matrix, then pack block
          const Index actual_nc = numext::mini(j2 + nc, n) - j2;
          pack_rhs(blockB, rhs.getSubMapper(k2, j2), actual_kc, actual_nc, 0, 0);

          // call gebp (matrix kernel)
          // The parameters here are copied from Eigen's GEMM implementation
          gebp(output.getSubMapper(i2, j2), blockA, blockB, actual_mc, actual_kc, actual_nc, 1.0, -1, -1, 0, 0);
        }
      }
    }

    this->m_device.deallocate(blockA);
    this->m_device.deallocate(blockB);
  }
};

} // end namespace Eigen

#endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_H