TensorPatch.h
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_PATCH_H
11 #define EIGEN_CXX11_TENSOR_TENSOR_PATCH_H
12 
13 namespace Eigen {
14 
22 namespace internal {
23 template<typename PatchDim, typename XprType>
24 struct traits<TensorPatchOp<PatchDim, XprType> > : public traits<XprType>
25 {
26  typedef typename XprType::Scalar Scalar;
27  typedef traits<XprType> XprTraits;
28  typedef typename packet_traits<Scalar>::type Packet;
29  typedef typename XprTraits::StorageKind StorageKind;
30  typedef typename XprTraits::Index Index;
31  typedef typename XprType::Nested Nested;
32  typedef typename remove_reference<Nested>::type _Nested;
33  static const int NumDimensions = XprTraits::NumDimensions + 1;
34  static const int Layout = XprTraits::Layout;
35 };
36 
37 template<typename PatchDim, typename XprType>
38 struct eval<TensorPatchOp<PatchDim, XprType>, Eigen::Dense>
39 {
40  typedef const TensorPatchOp<PatchDim, XprType>& type;
41 };
42 
43 template<typename PatchDim, typename XprType>
44 struct nested<TensorPatchOp<PatchDim, XprType>, 1, typename eval<TensorPatchOp<PatchDim, XprType> >::type>
45 {
46  typedef TensorPatchOp<PatchDim, XprType> type;
47 };
48 
49 } // end namespace internal
50 
51 
52 
53 template<typename PatchDim, typename XprType>
54 class TensorPatchOp : public TensorBase<TensorPatchOp<PatchDim, XprType>, ReadOnlyAccessors>
55 {
56  public:
57  typedef typename Eigen::internal::traits<TensorPatchOp>::Scalar Scalar;
58  typedef typename Eigen::internal::traits<TensorPatchOp>::Packet Packet;
59  typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
60  typedef typename XprType::CoeffReturnType CoeffReturnType;
61  typedef typename XprType::PacketReturnType PacketReturnType;
62  typedef typename Eigen::internal::nested<TensorPatchOp>::type Nested;
63  typedef typename Eigen::internal::traits<TensorPatchOp>::StorageKind StorageKind;
64  typedef typename Eigen::internal::traits<TensorPatchOp>::Index Index;
65 
66  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorPatchOp(const XprType& expr, const PatchDim& patch_dims)
67  : m_xpr(expr), m_patch_dims(patch_dims) {}
68 
69  EIGEN_DEVICE_FUNC
70  const PatchDim& patch_dims() const { return m_patch_dims; }
71 
72  EIGEN_DEVICE_FUNC
73  const typename internal::remove_all<typename XprType::Nested>::type&
74  expression() const { return m_xpr; }
75 
76  protected:
77  typename XprType::Nested m_xpr;
78  const PatchDim m_patch_dims;
79 };
80 
81 
82 // Eval as rvalue
83 template<typename PatchDim, typename ArgType, typename Device>
84 struct TensorEvaluator<const TensorPatchOp<PatchDim, ArgType>, Device>
85 {
86  typedef TensorPatchOp<PatchDim, ArgType> XprType;
87  typedef typename XprType::Index Index;
88  static const int NumDims = internal::array_size<typename TensorEvaluator<ArgType, Device>::Dimensions>::value + 1;
89  typedef DSizes<Index, NumDims> Dimensions;
90  typedef typename XprType::Scalar Scalar;
91 
92  enum {
93  IsAligned = false,
94  PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess,
95  Layout = TensorEvaluator<ArgType, Device>::Layout,
96  CoordAccess = true,
97  };
98 
99  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
100  : m_impl(op.expression(), device)
101  {
102  Index num_patches = 1;
103  const typename TensorEvaluator<ArgType, Device>::Dimensions& input_dims = m_impl.dimensions();
104  const PatchDim& patch_dims = op.patch_dims();
105  if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
106  for (int i = 0; i < NumDims-1; ++i) {
107  m_dimensions[i] = patch_dims[i];
108  num_patches *= (input_dims[i] - patch_dims[i] + 1);
109  }
110  m_dimensions[NumDims-1] = num_patches;
111 
112  m_inputStrides[0] = 1;
113  m_patchStrides[0] = 1;
114  for (int i = 1; i < NumDims-1; ++i) {
115  m_inputStrides[i] = m_inputStrides[i-1] * input_dims[i-1];
116  m_patchStrides[i] = m_patchStrides[i-1] * (input_dims[i-1] - patch_dims[i-1] + 1);
117  }
118  m_outputStrides[0] = 1;
119  for (int i = 1; i < NumDims; ++i) {
120  m_outputStrides[i] = m_outputStrides[i-1] * m_dimensions[i-1];
121  }
122  } else {
123  for (int i = 0; i < NumDims-1; ++i) {
124  m_dimensions[i+1] = patch_dims[i];
125  num_patches *= (input_dims[i] - patch_dims[i] + 1);
126  }
127  m_dimensions[0] = num_patches;
128 
129  m_inputStrides[NumDims-2] = 1;
130  m_patchStrides[NumDims-2] = 1;
131  for (int i = NumDims-3; i >= 0; --i) {
132  m_inputStrides[i] = m_inputStrides[i+1] * input_dims[i+1];
133  m_patchStrides[i] = m_patchStrides[i+1] * (input_dims[i+1] - patch_dims[i+1] + 1);
134  }
135  m_outputStrides[NumDims-1] = 1;
136  for (int i = NumDims-2; i >= 0; --i) {
137  m_outputStrides[i] = m_outputStrides[i+1] * m_dimensions[i+1];
138  }
139  }
140  }
141 
142  typedef typename XprType::CoeffReturnType CoeffReturnType;
143  typedef typename XprType::PacketReturnType PacketReturnType;
144 
145  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
146 
147  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* /*data*/) {
148  m_impl.evalSubExprsIfNeeded(NULL);
149  return true;
150  }
151 
152  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
153  m_impl.cleanup();
154  }
155 
156  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
157  {
158  Index output_stride_index = (static_cast<int>(Layout) == static_cast<int>(ColMajor)) ? NumDims - 1 : 0;
159  // Find the location of the first element of the patch.
160  Index patchIndex = index / m_outputStrides[output_stride_index];
161  // Find the offset of the element wrt the location of the first element.
162  Index patchOffset = index - patchIndex * m_outputStrides[output_stride_index];
163  Index inputIndex = 0;
164  if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
165  for (int i = NumDims - 2; i > 0; --i) {
166  const Index patchIdx = patchIndex / m_patchStrides[i];
167  patchIndex -= patchIdx * m_patchStrides[i];
168  const Index offsetIdx = patchOffset / m_outputStrides[i];
169  patchOffset -= offsetIdx * m_outputStrides[i];
170  inputIndex += (patchIdx + offsetIdx) * m_inputStrides[i];
171  }
172  } else {
173  for (int i = 0; i < NumDims - 2; ++i) {
174  const Index patchIdx = patchIndex / m_patchStrides[i];
175  patchIndex -= patchIdx * m_patchStrides[i];
176  const Index offsetIdx = patchOffset / m_outputStrides[i+1];
177  patchOffset -= offsetIdx * m_outputStrides[i+1];
178  inputIndex += (patchIdx + offsetIdx) * m_inputStrides[i];
179  }
180  }
181  inputIndex += (patchIndex + patchOffset);
182  return m_impl.coeff(inputIndex);
183  }
184 
185  template<int LoadMode>
186  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
187  {
188  const int packetSize = internal::unpacket_traits<PacketReturnType>::size;
189  EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
190  eigen_assert(index+packetSize-1 < dimensions().TotalSize());
191 
192  Index output_stride_index = (static_cast<int>(Layout) == static_cast<int>(ColMajor)) ? NumDims - 1 : 0;
193  Index indices[2] = {index, index + packetSize - 1};
194  Index patchIndices[2] = {indices[0] / m_outputStrides[output_stride_index],
195  indices[1] / m_outputStrides[output_stride_index]};
196  Index patchOffsets[2] = {indices[0] - patchIndices[0] * m_outputStrides[output_stride_index],
197  indices[1] - patchIndices[1] * m_outputStrides[output_stride_index]};
198 
199  Index inputIndices[2] = {0, 0};
200  if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
201  for (int i = NumDims - 2; i > 0; --i) {
202  const Index patchIdx[2] = {patchIndices[0] / m_patchStrides[i],
203  patchIndices[1] / m_patchStrides[i]};
204  patchIndices[0] -= patchIdx[0] * m_patchStrides[i];
205  patchIndices[1] -= patchIdx[1] * m_patchStrides[i];
206 
207  const Index offsetIdx[2] = {patchOffsets[0] / m_outputStrides[i],
208  patchOffsets[1] / m_outputStrides[i]};
209  patchOffsets[0] -= offsetIdx[0] * m_outputStrides[i];
210  patchOffsets[1] -= offsetIdx[1] * m_outputStrides[i];
211 
212  inputIndices[0] += (patchIdx[0] + offsetIdx[0]) * m_inputStrides[i];
213  inputIndices[1] += (patchIdx[1] + offsetIdx[1]) * m_inputStrides[i];
214  }
215  } else {
216  for (int i = 0; i < NumDims - 2; ++i) {
217  const Index patchIdx[2] = {patchIndices[0] / m_patchStrides[i],
218  patchIndices[1] / m_patchStrides[i]};
219  patchIndices[0] -= patchIdx[0] * m_patchStrides[i];
220  patchIndices[1] -= patchIdx[1] * m_patchStrides[i];
221 
222  const Index offsetIdx[2] = {patchOffsets[0] / m_outputStrides[i+1],
223  patchOffsets[1] / m_outputStrides[i+1]};
224  patchOffsets[0] -= offsetIdx[0] * m_outputStrides[i+1];
225  patchOffsets[1] -= offsetIdx[1] * m_outputStrides[i+1];
226 
227  inputIndices[0] += (patchIdx[0] + offsetIdx[0]) * m_inputStrides[i];
228  inputIndices[1] += (patchIdx[1] + offsetIdx[1]) * m_inputStrides[i];
229  }
230  }
231  inputIndices[0] += (patchIndices[0] + patchOffsets[0]);
232  inputIndices[1] += (patchIndices[1] + patchOffsets[1]);
233 
234  if (inputIndices[1] - inputIndices[0] == packetSize - 1) {
235  PacketReturnType rslt = m_impl.template packet<Unaligned>(inputIndices[0]);
236  return rslt;
237  }
238  else {
239  EIGEN_ALIGN_MAX CoeffReturnType values[packetSize];
240  values[0] = m_impl.coeff(inputIndices[0]);
241  values[packetSize-1] = m_impl.coeff(inputIndices[1]);
242  for (int i = 1; i < packetSize-1; ++i) {
243  values[i] = coeff(index+i);
244  }
245  PacketReturnType rslt = internal::pload<PacketReturnType>(values);
246  return rslt;
247  }
248  }
249 
250  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(const array<Index, NumDims>& coords) const
251  {
252  Index patch_coord_idx = Layout == ColMajor ? NumDims - 1 : 0;
253  // Location of the first element of the patch.
254  const Index patchIndex = coords[patch_coord_idx];
255 
256  if (TensorEvaluator<ArgType, Device>::CoordAccess) {
257  array<Index, NumDims-1> inputCoords;
258  if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
259  for (int i = NumDims - 2; i > 0; --i) {
260  const Index patchIdx = patchIndex / m_patchStrides[i];
261  patchIndex -= patchIdx * m_patchStrides[i];
262  const Index offsetIdx = coords[i];
263  inputCoords[i] = coords[i] + patchIdx;
264  }
265  } else {
266  for (int i = 0; i < NumDims - 2; ++i) {
267  const Index patchIdx = patchIndex / m_patchStrides[i];
268  patchIndex -= patchIdx * m_patchStrides[i];
269  const Index offsetIdx = coords[i+1];
270  inputCoords[i] = coords[i+1] + patchIdx;
271  }
272  }
273  Index coords_idx = Layout == ColMajor ? 0 : NumDims - 1;
274  inputCoords[0] = (patchIndex + coords[coords_idx]);
275  return m_impl.coeff(inputCoords);
276  }
277  else {
278  Index inputIndex = 0;
279  if (Layout == ColMajor) {
280  for (int i = NumDims - 2; i > 0; --i) {
281  const Index patchIdx = patchIndex / m_patchStrides[i];
282  patchIndex -= patchIdx * m_patchStrides[i];
283  const Index offsetIdx = coords[i];
284  inputIndex += (patchIdx + offsetIdx) * m_inputStrides[i];
285  }
286  } else {
287  for (int i = 0; i < NumDims - 2; ++i) {
288  const Index patchIdx = patchIndex / m_patchStrides[i];
289  patchIndex -= patchIdx * m_patchStrides[i];
290  const Index offsetIdx = coords[i+1];
291  inputIndex += (patchIdx + offsetIdx) * m_inputStrides[i];
292  }
293  }
294  Index coords_idx = Layout == ColMajor ? 0 : NumDims - 1;
295  inputIndex += (patchIndex + coords[coords_idx]);
296  return m_impl.coeff(inputIndex);
297  }
298  }
299 
300  EIGEN_DEVICE_FUNC Scalar* data() const { return NULL; }
301 
302  protected:
303  Dimensions m_dimensions;
304  array<Index, NumDims> m_outputStrides;
305  array<Index, NumDims-1> m_inputStrides;
306  array<Index, NumDims-1> m_patchStrides;
307 
308  TensorEvaluator<ArgType, Device> m_impl;
309 };
310 
311 } // end namespace Eigen
312 
313 #endif // EIGEN_CXX11_TENSOR_TENSOR_PATCH_H
Namespace containing all symbols from the Eigen library.
Definition: CXX11Meta.h:13