TensorDevice.h
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_DEVICE_H
11 #define EIGEN_CXX11_TENSOR_TENSOR_DEVICE_H
12 
13 namespace Eigen {
14 
27 template <typename ExpressionType, typename DeviceType> class TensorDevice {
28  public:
29  TensorDevice(const DeviceType& device, ExpressionType& expression) : m_device(device), m_expression(expression) {}
30 
31  template<typename OtherDerived>
32  EIGEN_STRONG_INLINE TensorDevice& operator=(const OtherDerived& other) {
33  typedef TensorAssignOp<ExpressionType, const OtherDerived> Assign;
34  Assign assign(m_expression, other);
35  internal::TensorExecutor<const Assign, DeviceType>::run(assign, m_device);
36  return *this;
37  }
38 
39  template<typename OtherDerived>
40  EIGEN_STRONG_INLINE TensorDevice& operator+=(const OtherDerived& other) {
41  typedef typename OtherDerived::Scalar Scalar;
42  typedef TensorCwiseBinaryOp<internal::scalar_sum_op<Scalar>, const ExpressionType, const OtherDerived> Sum;
43  Sum sum(m_expression, other);
44  typedef TensorAssignOp<ExpressionType, const Sum> Assign;
45  Assign assign(m_expression, sum);
46  internal::TensorExecutor<const Assign, DeviceType>::run(assign, m_device);
47  return *this;
48  }
49 
50  template<typename OtherDerived>
51  EIGEN_STRONG_INLINE TensorDevice& operator-=(const OtherDerived& other) {
52  typedef typename OtherDerived::Scalar Scalar;
53  typedef TensorCwiseBinaryOp<internal::scalar_difference_op<Scalar>, const ExpressionType, const OtherDerived> Difference;
54  Difference difference(m_expression, other);
55  typedef TensorAssignOp<ExpressionType, const Difference> Assign;
56  Assign assign(m_expression, difference);
57  internal::TensorExecutor<const Assign, DeviceType>::run(assign, m_device);
58  return *this;
59  }
60 
61  protected:
62  const DeviceType& m_device;
63  ExpressionType& m_expression;
64 };
65 
66 
// No TensorDevice partial specialization is provided for ThreadPoolDevice.
// The primary TensorDevice<ExpressionType, DeviceType> template in this file
// forwards every assignment (=, +=, -=) to
// internal::TensorExecutor<const Assign, DeviceType>::run(assign, m_device).
// Instantiating it with DeviceType = ThreadPoolDevice therefore already
// dispatches to internal::TensorExecutor<..., ThreadPoolDevice>.
// A hand-written copy would only duplicate that code and drift out of sync.
107 
108 
// No TensorDevice partial specialization is provided for GpuDevice.
// The primary TensorDevice<ExpressionType, DeviceType> template in this file
// forwards every assignment (=, +=, -=) to
// internal::TensorExecutor<const Assign, DeviceType>::run(assign, m_device).
// Instantiating it with DeviceType = GpuDevice therefore already dispatches to
// internal::TensorExecutor<..., GpuDevice>.
// A hand-written copy would only duplicate that code and drift out of sync.
150 
151 
152 } // end namespace Eigen
153 
154 #endif // EIGEN_CXX11_TENSOR_TENSOR_DEVICE_H
Namespace containing all symbols from the Eigen library.
Definition: CXX11Meta.h:13
Pseudo expression providing an operator= that evaluates its argument on the specified computing device (e.g. a thread pool or a GPU).
Definition: TensorDevice.h:27