// TensorBroadcasting.h
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CXX11_TENSOR_TENSOR_BROADCASTING_H
#define EIGEN_CXX11_TENSOR_TENSOR_BROADCASTING_H

namespace Eigen {

22 namespace internal {
23 template<typename Broadcast, typename XprType>
24 struct traits<TensorBroadcastingOp<Broadcast, XprType> > : public traits<XprType>
25 {
26  typedef typename XprType::Scalar Scalar;
27  typedef traits<XprType> XprTraits;
28  typedef typename packet_traits<Scalar>::type Packet;
29  typedef typename XprTraits::StorageKind StorageKind;
30  typedef typename XprTraits::Index Index;
31  typedef typename XprType::Nested Nested;
32  typedef typename remove_reference<Nested>::type _Nested;
33  static const int NumDimensions = XprTraits::NumDimensions;
34  static const int Layout = XprTraits::Layout;
35 };
36 
37 template<typename Broadcast, typename XprType>
38 struct eval<TensorBroadcastingOp<Broadcast, XprType>, Eigen::Dense>
39 {
40  typedef const TensorBroadcastingOp<Broadcast, XprType>& type;
41 };
42 
43 template<typename Broadcast, typename XprType>
44 struct nested<TensorBroadcastingOp<Broadcast, XprType>, 1, typename eval<TensorBroadcastingOp<Broadcast, XprType> >::type>
45 {
46  typedef TensorBroadcastingOp<Broadcast, XprType> type;
47 };
48 
49 } // end namespace internal


53 template<typename Broadcast, typename XprType>
54 class TensorBroadcastingOp : public TensorBase<TensorBroadcastingOp<Broadcast, XprType>, ReadOnlyAccessors>
55 {
56  public:
57  typedef typename Eigen::internal::traits<TensorBroadcastingOp>::Scalar Scalar;
58  typedef typename Eigen::internal::traits<TensorBroadcastingOp>::Packet Packet;
59  typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
60  typedef typename XprType::CoeffReturnType CoeffReturnType;
61  typedef typename XprType::PacketReturnType PacketReturnType;
62  typedef typename Eigen::internal::nested<TensorBroadcastingOp>::type Nested;
63  typedef typename Eigen::internal::traits<TensorBroadcastingOp>::StorageKind StorageKind;
64  typedef typename Eigen::internal::traits<TensorBroadcastingOp>::Index Index;
65 
66  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBroadcastingOp(const XprType& expr, const Broadcast& broadcast)
67  : m_xpr(expr), m_broadcast(broadcast) {}
68 
69  EIGEN_DEVICE_FUNC
70  const Broadcast& broadcast() const { return m_broadcast; }
71 
72  EIGEN_DEVICE_FUNC
73  const typename internal::remove_all<typename XprType::Nested>::type&
74  expression() const { return m_xpr; }
75 
76  protected:
77  typename XprType::Nested m_xpr;
78  const Broadcast m_broadcast;
79 };


// Eval as rvalue
// Read-only evaluator for a broadcast expression. The output tensor has
// dimensions input_dims[i] * broadcast[i]; each output coefficient is mapped
// back to an input coefficient by wrapping every output coordinate modulo the
// corresponding input dimension.
template<typename Broadcast, typename ArgType, typename Device>
struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device>
{
  typedef TensorBroadcastingOp<Broadcast, ArgType> XprType;
  typedef typename XprType::Index Index;
  static const int NumDims = internal::array_size<typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
  typedef DSizes<Index, NumDims> Dimensions;
  typedef typename XprType::Scalar Scalar;
  typedef typename TensorEvaluator<ArgType, Device>::Dimensions InputDimensions;

  enum {
    // A packet may straddle a broadcast boundary, so alignment of the output
    // cannot be guaranteed even when the input evaluator is aligned.
    IsAligned = false,
    PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess,
    Layout = TensorEvaluator<ArgType, Device>::Layout,
  };

  // Precomputes the broadcast output dimensions and the input/output strides
  // used by coeff*/packet* to translate linear output indices into input
  // indices, honoring the argument evaluator's layout.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
      : m_impl(op.expression(), device)
  {
    const typename TensorEvaluator<ArgType, Device>::Dimensions& input_dims = m_impl.dimensions();
    const Broadcast& broadcast = op.broadcast();
    for (int i = 0; i < NumDims; ++i) {
      // Zero-sized input dimensions cannot be broadcast.
      eigen_assert(input_dims[i] > 0);
      m_dimensions[i] = input_dims[i] * broadcast[i];
    }

    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      // Column-major: dimension 0 is innermost (stride 1).
      m_inputStrides[0] = 1;
      m_outputStrides[0] = 1;
      for (int i = 1; i < NumDims; ++i) {
        m_inputStrides[i] = m_inputStrides[i-1] * input_dims[i-1];
        m_outputStrides[i] = m_outputStrides[i-1] * m_dimensions[i-1];
      }
    } else {
      // Row-major: dimension NumDims-1 is innermost (stride 1).
      m_inputStrides[NumDims-1] = 1;
      m_outputStrides[NumDims-1] = 1;
      for (int i = NumDims-2; i >= 0; --i) {
        m_inputStrides[i] = m_inputStrides[i+1] * input_dims[i+1];
        m_outputStrides[i] = m_outputStrides[i+1] * m_dimensions[i+1];
      }
    }
  }

  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename XprType::PacketReturnType PacketReturnType;

  // Dimensions of the broadcast (output) tensor.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* /*data*/) {
    m_impl.evalSubExprsIfNeeded(NULL);
    // Broadcasting itself never materializes into the caller's buffer.
    return true;
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
    m_impl.cleanup();
  }

  // Dispatches on layout; the branch is resolvable at compile time since
  // Layout is an enum constant.
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE CoeffReturnType coeff(Index index) const
  {
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      return coeffColMajor(index);
    } else {
      return coeffRowMajor(index);
    }
  }

  // Maps a column-major linear output index to the corresponding input
  // coefficient, peeling off one coordinate per dimension from outermost to
  // innermost (dimension 0 is handled after the loop).
  // TODO: attempt to speed this up. The integer divisions and modulo are slow
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeffColMajor(Index index) const
  {
    Index inputIndex = 0;
    for (int i = NumDims - 1; i > 0; --i) {
      const Index idx = index / m_outputStrides[i];
      if (internal::index_statically_eq<Broadcast>()(i, 1)) {
        // Broadcast factor along dim i is 1 (presumably detected at compile
        // time for static index lists): output coord == input coord.
        eigen_assert(idx < m_impl.dimensions()[i]);
        inputIndex += idx * m_inputStrides[i];
      } else {
        if (internal::index_statically_eq<InputDimensions>()(i, 1)) {
          // Input dimension is statically 1: input coord is always 0,
          // nothing to accumulate.
          eigen_assert(idx % m_impl.dimensions()[i] == 0);
        } else {
          // General case: wrap the output coordinate into the input range.
          inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
        }
      }
      index -= idx * m_outputStrides[i];
    }
    // Remaining index is the coordinate along dimension 0 (stride 1).
    if (internal::index_statically_eq<Broadcast>()(0, 1)) {
      eigen_assert(index < m_impl.dimensions()[0]);
      inputIndex += index;
    } else {
      if (internal::index_statically_eq<InputDimensions>()(0, 1)) {
        eigen_assert(index % m_impl.dimensions()[0] == 0);
      } else {
        inputIndex += (index % m_impl.dimensions()[0]);
      }
    }
    return m_impl.coeff(inputIndex);
  }

  // Row-major counterpart of coeffColMajor: coordinates are peeled from
  // dimension 0 outward, with dimension NumDims-1 (the innermost one)
  // handled after the loop.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeffRowMajor(Index index) const
  {
    Index inputIndex = 0;
    for (int i = 0; i < NumDims - 1; ++i) {
      const Index idx = index / m_outputStrides[i];
      if (internal::index_statically_eq<Broadcast>()(i, 1)) {
        // No broadcasting along dim i: direct coordinate copy.
        eigen_assert(idx < m_impl.dimensions()[i]);
        inputIndex += idx * m_inputStrides[i];
      } else {
        if (internal::index_statically_eq<InputDimensions>()(i, 1)) {
          // Input dimension is statically 1: contributes nothing.
          eigen_assert(idx % m_impl.dimensions()[i] == 0);
        } else {
          inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
        }
      }
      index -= idx * m_outputStrides[i];
    }
    if (internal::index_statically_eq<Broadcast>()(NumDims-1, 1)) {
      eigen_assert(index < m_impl.dimensions()[NumDims-1]);
      inputIndex += index;
    } else {
      if (internal::index_statically_eq<InputDimensions>()(NumDims-1, 1)) {
        eigen_assert(index % m_impl.dimensions()[NumDims-1] == 0);
      } else {
        inputIndex += (index % m_impl.dimensions()[NumDims-1]);
      }
    }
    return m_impl.coeff(inputIndex);
  }

  // Layout dispatch for vectorized access, mirroring coeff().
  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketReturnType packet(Index index) const
  {
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      return packetColMajor<LoadMode>(index);
    } else {
      return packetRowMajor<LoadMode>(index);
    }
  }

  // Ignore the LoadMode and always use unaligned loads since we can't guarantee
  // the alignment at compile time.
  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetColMajor(Index index) const
  {
    const int packetSize = internal::unpacket_traits<PacketReturnType>::size;
    EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(index+packetSize-1 < dimensions().TotalSize());

    const Index originalIndex = index;

    // Same coordinate decomposition as coeffColMajor, but the innermost
    // coordinate is kept separately to decide whether a whole packet fits
    // inside one contiguous run of the input.
    Index inputIndex = 0;
    for (int i = NumDims - 1; i > 0; --i) {
      const Index idx = index / m_outputStrides[i];
      if (internal::index_statically_eq<Broadcast>()(i, 1)) {
        eigen_assert(idx < m_impl.dimensions()[i]);
        inputIndex += idx * m_inputStrides[i];
      } else {
        if (internal::index_statically_eq<InputDimensions>()(i, 1)) {
          eigen_assert(idx % m_impl.dimensions()[i] == 0);
        } else {
          inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
        }
      }
      index -= idx * m_outputStrides[i];
    }
    Index innermostLoc;
    if (internal::index_statically_eq<Broadcast>()(0, 1)) {
      eigen_assert(index < m_impl.dimensions()[0]);
      innermostLoc = index;
    } else {
      if (internal::index_statically_eq<InputDimensions>()(0, 1)) {
        eigen_assert(index % m_impl.dimensions()[0] == 0);
        innermostLoc = 0;
      } else {
        innermostLoc = index % m_impl.dimensions()[0];
      }
    }
    inputIndex += innermostLoc;

    // Todo: this could be extended to the second dimension if we're not
    // broadcasting alongside the first dimension, and so on.
    if (innermostLoc + packetSize <= m_impl.dimensions()[0]) {
      // Fast path: the packet lies entirely within one input run.
      return m_impl.template packet<Unaligned>(inputIndex);
    } else {
      // Slow path: the packet crosses a broadcast boundary; gather the
      // coefficients one by one. values[0] reuses the index already
      // computed; the rest are remapped from scratch.
      EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type values[packetSize];
      values[0] = m_impl.coeff(inputIndex);
      for (int i = 1; i < packetSize; ++i) {
        values[i] = coeffColMajor(originalIndex+i);
      }
      PacketReturnType rslt = internal::pload<PacketReturnType>(values);
      return rslt;
    }
  }

  // Row-major counterpart of packetColMajor; the innermost dimension is
  // NumDims-1. Alignment cannot be guaranteed, so loads are unaligned.
  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetRowMajor(Index index) const
  {
    const int packetSize = internal::unpacket_traits<PacketReturnType>::size;
    EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(index+packetSize-1 < dimensions().TotalSize());

    const Index originalIndex = index;

    Index inputIndex = 0;
    for (int i = 0; i < NumDims - 1; ++i) {
      const Index idx = index / m_outputStrides[i];
      if (internal::index_statically_eq<Broadcast>()(i, 1)) {
        eigen_assert(idx < m_impl.dimensions()[i]);
        inputIndex += idx * m_inputStrides[i];
      } else {
        if (internal::index_statically_eq<InputDimensions>()(i, 1)) {
          eigen_assert(idx % m_impl.dimensions()[i] == 0);
        } else {
          inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
        }
      }
      index -= idx * m_outputStrides[i];
    }
    Index innermostLoc;
    if (internal::index_statically_eq<Broadcast>()(NumDims-1, 1)) {
      eigen_assert(index < m_impl.dimensions()[NumDims-1]);
      innermostLoc = index;
    } else {
      if (internal::index_statically_eq<InputDimensions>()(NumDims-1, 1)) {
        eigen_assert(index % m_impl.dimensions()[NumDims-1] == 0);
        innermostLoc = 0;
      } else {
        innermostLoc = index % m_impl.dimensions()[NumDims-1];
      }
    }
    inputIndex += innermostLoc;

    // Todo: this could be extended to the second dimension if we're not
    // broadcasting alongside the first dimension, and so on.
    if (innermostLoc + packetSize <= m_impl.dimensions()[NumDims-1]) {
      // Fast path: single contiguous unaligned load from the input.
      return m_impl.template packet<Unaligned>(inputIndex);
    } else {
      // Slow path: gather across the broadcast boundary element by element.
      EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type values[packetSize];
      values[0] = m_impl.coeff(inputIndex);
      for (int i = 1; i < packetSize; ++i) {
        values[i] = coeffRowMajor(originalIndex+i);
      }
      PacketReturnType rslt = internal::pload<PacketReturnType>(values);
      return rslt;
    }
  }


  // No directly addressable buffer: results are computed on the fly.
  EIGEN_DEVICE_FUNC Scalar* data() const { return NULL; }

  protected:
  Dimensions m_dimensions;                  // output (broadcast) dimensions
  array<Index, NumDims> m_outputStrides;    // strides of the output tensor
  array<Index, NumDims> m_inputStrides;     // strides of the input tensor
  TensorEvaluator<ArgType, Device> m_impl;  // evaluator for the argument
};

} // end namespace Eigen

#endif // EIGEN_CXX11_TENSOR_TENSOR_BROADCASTING_H