#ifndef EIGEN_CXX11_TENSOR_TENSOR_BROADCASTING_H
#define EIGEN_CXX11_TENSOR_TENSOR_BROADCASTING_H
23 template<
typename Broadcast,
typename XprType>
24 struct traits<TensorBroadcastingOp<Broadcast, XprType> > :
public traits<XprType>
26 typedef typename XprType::Scalar Scalar;
27 typedef traits<XprType> XprTraits;
28 typedef typename packet_traits<Scalar>::type Packet;
29 typedef typename XprTraits::StorageKind StorageKind;
30 typedef typename XprTraits::Index Index;
31 typedef typename XprType::Nested Nested;
32 typedef typename remove_reference<Nested>::type _Nested;
33 static const int NumDimensions = XprTraits::NumDimensions;
34 static const int Layout = XprTraits::Layout;
37 template<
typename Broadcast,
typename XprType>
38 struct eval<TensorBroadcastingOp<Broadcast, XprType>,
Eigen::Dense>
40 typedef const TensorBroadcastingOp<Broadcast, XprType>& type;
43 template<
typename Broadcast,
typename XprType>
44 struct nested<TensorBroadcastingOp<Broadcast, XprType>, 1, typename eval<TensorBroadcastingOp<Broadcast, XprType> >::type>
46 typedef TensorBroadcastingOp<Broadcast, XprType> type;
53 template<
typename Broadcast,
typename XprType>
54 class TensorBroadcastingOp :
public TensorBase<TensorBroadcastingOp<Broadcast, XprType>, ReadOnlyAccessors>
57 typedef typename Eigen::internal::traits<TensorBroadcastingOp>::Scalar Scalar;
58 typedef typename Eigen::internal::traits<TensorBroadcastingOp>::Packet Packet;
59 typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
60 typedef typename XprType::CoeffReturnType CoeffReturnType;
61 typedef typename XprType::PacketReturnType PacketReturnType;
62 typedef typename Eigen::internal::nested<TensorBroadcastingOp>::type Nested;
63 typedef typename Eigen::internal::traits<TensorBroadcastingOp>::StorageKind StorageKind;
64 typedef typename Eigen::internal::traits<TensorBroadcastingOp>::Index Index;
66 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBroadcastingOp(
const XprType& expr,
const Broadcast& broadcast)
67 : m_xpr(expr), m_broadcast(broadcast) {}
70 const Broadcast& broadcast()
const {
return m_broadcast; }
73 const typename internal::remove_all<typename XprType::Nested>::type&
74 expression()
const {
return m_xpr; }
77 typename XprType::Nested m_xpr;
78 const Broadcast m_broadcast;
83 template<
typename Broadcast,
typename ArgType,
typename Device>
84 struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device>
86 typedef TensorBroadcastingOp<Broadcast, ArgType> XprType;
87 typedef typename XprType::Index Index;
88 static const int NumDims = internal::array_size<typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
89 typedef DSizes<Index, NumDims> Dimensions;
90 typedef typename XprType::Scalar Scalar;
91 typedef typename TensorEvaluator<ArgType, Device>::Dimensions InputDimensions;
95 PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess,
96 Layout = TensorEvaluator<ArgType, Device>::Layout,
99 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(
const XprType& op,
const Device& device)
100 : m_impl(op.expression(), device)
102 const typename TensorEvaluator<ArgType, Device>::Dimensions& input_dims = m_impl.dimensions();
103 const Broadcast& broadcast = op.broadcast();
104 for (
int i = 0; i < NumDims; ++i) {
105 eigen_assert(input_dims[i] > 0);
106 m_dimensions[i] = input_dims[i] * broadcast[i];
109 if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
110 m_inputStrides[0] = 1;
111 m_outputStrides[0] = 1;
112 for (
int i = 1; i < NumDims; ++i) {
113 m_inputStrides[i] = m_inputStrides[i-1] * input_dims[i-1];
114 m_outputStrides[i] = m_outputStrides[i-1] * m_dimensions[i-1];
117 m_inputStrides[NumDims-1] = 1;
118 m_outputStrides[NumDims-1] = 1;
119 for (
int i = NumDims-2; i >= 0; --i) {
120 m_inputStrides[i] = m_inputStrides[i+1] * input_dims[i+1];
121 m_outputStrides[i] = m_outputStrides[i+1] * m_dimensions[i+1];
126 typedef typename XprType::CoeffReturnType CoeffReturnType;
127 typedef typename XprType::PacketReturnType PacketReturnType;
129 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const Dimensions& dimensions()
const {
return m_dimensions; }
131 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
bool evalSubExprsIfNeeded(Scalar* ) {
132 m_impl.evalSubExprsIfNeeded(NULL);
136 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
void cleanup() {
140 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE CoeffReturnType coeff(Index index)
const
142 if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
143 return coeffColMajor(index);
145 return coeffRowMajor(index);
150 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeffColMajor(Index index)
const
152 Index inputIndex = 0;
153 for (
int i = NumDims - 1; i > 0; --i) {
154 const Index idx = index / m_outputStrides[i];
155 if (internal::index_statically_eq<Broadcast>()(i, 1)) {
156 eigen_assert(idx < m_impl.dimensions()[i]);
157 inputIndex += idx * m_inputStrides[i];
159 if (internal::index_statically_eq<InputDimensions>()(i, 1)) {
160 eigen_assert(idx % m_impl.dimensions()[i] == 0);
162 inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
165 index -= idx * m_outputStrides[i];
167 if (internal::index_statically_eq<Broadcast>()(0, 1)) {
168 eigen_assert(index < m_impl.dimensions()[0]);
171 if (internal::index_statically_eq<InputDimensions>()(0, 1)) {
172 eigen_assert(index % m_impl.dimensions()[0] == 0);
174 inputIndex += (index % m_impl.dimensions()[0]);
177 return m_impl.coeff(inputIndex);
180 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeffRowMajor(Index index)
const
182 Index inputIndex = 0;
183 for (
int i = 0; i < NumDims - 1; ++i) {
184 const Index idx = index / m_outputStrides[i];
185 if (internal::index_statically_eq<Broadcast>()(i, 1)) {
186 eigen_assert(idx < m_impl.dimensions()[i]);
187 inputIndex += idx * m_inputStrides[i];
189 if (internal::index_statically_eq<InputDimensions>()(i, 1)) {
190 eigen_assert(idx % m_impl.dimensions()[i] == 0);
192 inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
195 index -= idx * m_outputStrides[i];
197 if (internal::index_statically_eq<Broadcast>()(NumDims-1, 1)) {
198 eigen_assert(index < m_impl.dimensions()[NumDims-1]);
201 if (internal::index_statically_eq<InputDimensions>()(NumDims-1, 1)) {
202 eigen_assert(index % m_impl.dimensions()[NumDims-1] == 0);
204 inputIndex += (index % m_impl.dimensions()[NumDims-1]);
207 return m_impl.coeff(inputIndex);
210 template<
int LoadMode>
211 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketReturnType packet(Index index)
const
213 if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
214 return packetColMajor<LoadMode>(index);
216 return packetRowMajor<LoadMode>(index);
222 template<
int LoadMode>
223 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetColMajor(Index index)
const
225 const int packetSize = internal::unpacket_traits<PacketReturnType>::size;
226 EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
227 eigen_assert(index+packetSize-1 < dimensions().TotalSize());
229 const Index originalIndex = index;
231 Index inputIndex = 0;
232 for (
int i = NumDims - 1; i > 0; --i) {
233 const Index idx = index / m_outputStrides[i];
234 if (internal::index_statically_eq<Broadcast>()(i, 1)) {
235 eigen_assert(idx < m_impl.dimensions()[i]);
236 inputIndex += idx * m_inputStrides[i];
238 if (internal::index_statically_eq<InputDimensions>()(i, 1)) {
239 eigen_assert(idx % m_impl.dimensions()[i] == 0);
241 inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
244 index -= idx * m_outputStrides[i];
247 if (internal::index_statically_eq<Broadcast>()(0, 1)) {
248 eigen_assert(index < m_impl.dimensions()[0]);
249 innermostLoc = index;
251 if (internal::index_statically_eq<InputDimensions>()(0, 1)) {
252 eigen_assert(index % m_impl.dimensions()[0] == 0);
255 innermostLoc = index % m_impl.dimensions()[0];
258 inputIndex += innermostLoc;
262 if (innermostLoc + packetSize <= m_impl.dimensions()[0]) {
263 return m_impl.template packet<Unaligned>(inputIndex);
265 EIGEN_ALIGN_MAX
typename internal::remove_const<CoeffReturnType>::type values[packetSize];
266 values[0] = m_impl.coeff(inputIndex);
267 for (
int i = 1; i < packetSize; ++i) {
268 values[i] = coeffColMajor(originalIndex+i);
270 PacketReturnType rslt = internal::pload<PacketReturnType>(values);
275 template<
int LoadMode>
276 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetRowMajor(Index index)
const
278 const int packetSize = internal::unpacket_traits<PacketReturnType>::size;
279 EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
280 eigen_assert(index+packetSize-1 < dimensions().TotalSize());
282 const Index originalIndex = index;
284 Index inputIndex = 0;
285 for (
int i = 0; i < NumDims - 1; ++i) {
286 const Index idx = index / m_outputStrides[i];
287 if (internal::index_statically_eq<Broadcast>()(i, 1)) {
288 eigen_assert(idx < m_impl.dimensions()[i]);
289 inputIndex += idx * m_inputStrides[i];
291 if (internal::index_statically_eq<InputDimensions>()(i, 1)) {
292 eigen_assert(idx % m_impl.dimensions()[i] == 0);
294 inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
297 index -= idx * m_outputStrides[i];
300 if (internal::index_statically_eq<Broadcast>()(NumDims-1, 1)) {
301 eigen_assert(index < m_impl.dimensions()[NumDims-1]);
302 innermostLoc = index;
304 if (internal::index_statically_eq<InputDimensions>()(NumDims-1, 1)) {
305 eigen_assert(index % m_impl.dimensions()[NumDims-1] == 0);
308 innermostLoc = index % m_impl.dimensions()[NumDims-1];
311 inputIndex += innermostLoc;
315 if (innermostLoc + packetSize <= m_impl.dimensions()[NumDims-1]) {
316 return m_impl.template packet<Unaligned>(inputIndex);
318 EIGEN_ALIGN_MAX
typename internal::remove_const<CoeffReturnType>::type values[packetSize];
319 values[0] = m_impl.coeff(inputIndex);
320 for (
int i = 1; i < packetSize; ++i) {
321 values[i] = coeffRowMajor(originalIndex+i);
323 PacketReturnType rslt = internal::pload<PacketReturnType>(values);
329 EIGEN_DEVICE_FUNC Scalar* data()
const {
return NULL; }
332 Dimensions m_dimensions;
333 array<Index, NumDims> m_outputStrides;
334 array<Index, NumDims> m_inputStrides;
335 TensorEvaluator<ArgType, Device> m_impl;
#endif // EIGEN_CXX11_TENSOR_TENSOR_BROADCASTING_H
// Eigen: namespace containing all symbols from the Eigen library
// (declared in CXX11Meta.h).