10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_THREAD_POOL_H
11 #define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_THREAD_POOL_H
14 #ifdef EIGEN_USE_THREADS
// NOTE(review): extraction artifact — this template parameter list is orphaned;
// the struct it introduces was lost (original lines are missing here). Judging by
// the instantiation `internal::packLhsArg<LhsScalar, LhsMapper, Index>` used for
// the LHS-packing tasks further below, this is the header of `struct packLhsArg`.
// The stray "19" below is the original file's line number fused in by extraction.
19 template<
typename LhsScalar,
typename LhsMapper,
typename Index>
// Argument bundle for the combined "pack RHS block, then run the GEBP kernels"
// task enqueued on the thread pool (see packRhsAndKernel() below).
// NOTE(review): several members and the closing `};` of this struct are missing
// from this extract (original line numbering jumps 31->41, 46->49); the stray
// leading numbers on each line are extraction residue, not code.
29 template<
typename LhsScalar,
typename RhsScalar,
typename RhsMapper,
typename OutputMapper,
typename Index>
30 struct packRhsAndKernelArg {
// Packed LHS panels, one buffer per slot; the kernel reads (*blockAs)[blockAId].
31 const std::vector<LhsScalar*>* blockAs;
// Pool size and how many LHS m-blocks this task must multiply against.
41 const Index num_threads;
42 const Index num_blockAs;
// Position of this task in the (k, m, n) blocking grid.
44 const Index k_block_idx;
45 const Index m_block_idx;
46 const Index n_block_idx;
// Signalled by this task when each (blockA, n-block) kernel finishes.
49 std::vector<Notification*>* kernel_notifications;
// Waited on before consuming a packed LHS buffer (producer: packLhs tasks).
50 const std::vector<Notification*>* lhs_notifications;
// True only for the first m-block iteration, so each RHS block is packed once.
51 const bool need_to_pack;
// Thread-pool specialization of the tensor-contraction evaluator: inherits the
// shared machinery from TensorContractionEvaluatorBase (CRTP) and supplies the
// multi-threaded evalProduct/evalGemm paths below.
// NOTE(review): the leading numbers fused into these lines are the original
// file's line numbers, left over from extraction; interior lines are missing
// wherever that numbering jumps.
57 template<
typename Indices,
typename LeftArgType,
typename RightArgType>
58 struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, ThreadPoolDevice> :
59 public TensorContractionEvaluatorBase<TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, ThreadPoolDevice> > {
61 typedef ThreadPoolDevice Device;
63 typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> Self;
64 typedef TensorContractionEvaluatorBase<Self> Base;
66 typedef TensorContractionOp<Indices, LeftArgType, RightArgType> XprType;
67 typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
68 typedef typename XprType::Packet Packet;
69 typedef typename XprType::Index Index;
70 typedef typename XprType::CoeffReturnType CoeffReturnType;
71 typedef typename XprType::PacketReturnType PacketReturnType;
// NOTE(review): the `enum {` opener for this Layout member was lost in extraction.
74 Layout = TensorEvaluator<LeftArgType, Device>::Layout,
// For RowMajor layouts the two operands are swapped so the code below can
// always work in ColMajor terms (Eval{Left,Right}ArgType select accordingly).
81 typedef typename internal::conditional<
82 static_cast<int>(Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType>::type EvalLeftArgType;
83 typedef typename internal::conditional<
84 static_cast<int>(Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType>::type EvalRightArgType;
// Ranks of the two operands and the number of contracted index pairs.
86 static const int LDims =
87 internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
88 static const int RDims =
89 internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;
90 static const int ContractDims = internal::array_size<Indices>::value;
92 typedef array<Index, LDims> left_dim_mapper_t;
93 typedef array<Index, RDims> right_dim_mapper_t;
95 typedef array<Index, ContractDims> contract_t;
// max_n_1 clamps to 1 so fully-contracted operands still get a usable array type.
96 typedef array<Index, max_n_1<LDims - ContractDims>::size> left_nocontract_t;
97 typedef array<Index, max_n_1<RDims - ContractDims>::size> right_nocontract_t;
// Rank of the result: every contracted pair removes one dim from each side.
99 static const int NumDims = max_n_1<LDims + RDims - 2 * ContractDims>::size;
101 typedef DSizes<Index, NumDims> Dimensions;
104 typedef typename internal::remove_const<typename EvalLeftArgType::Scalar>::type LhsScalar;
105 typedef typename internal::remove_const<typename EvalRightArgType::Scalar>::type RhsScalar;
// gebp_traits supplies the register-blocking parameters (mr, nr, LhsProgress).
106 typedef typename internal::gebp_traits<LhsScalar, RhsScalar> Traits;
108 typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
109 typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;
// Constructor: presumably forwards op/device to Base — its initializer list
// and body were lost in extraction, so this cannot be confirmed from here.
111 TensorEvaluator(
const XprType& op,
const Device& device) :
// Dispatch: a contraction whose output has a single column (m_j_size == 1) is
// a matrix-vector product and goes to evalGemv; everything else uses the
// blocked multi-threaded GEMM path. NOTE(review): the braces/else of this
// if/else and the closing brace of the function were lost in extraction.
114 template <
bool lhs_inner_dim_contiguous,
bool rhs_inner_dim_contiguous,
bool rhs_inner_dim_reordered,
int Alignment>
115 void evalProduct(Scalar* buffer)
const {
116 if (this->m_j_size == 1) {
117 this->
template evalGemv<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer);
121 evalGemm<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer);
// Multi-threaded blocked GEMM driver for the contraction, expressed as a
// 2-D matrix product of size (m x k) * (k x n) into `buffer`.
// Pipeline: LHS panels are packed asynchronously (packLhs tasks, tracked by
// lhs_notifications), then combined pack-RHS + GEBP-kernel tasks are enqueued
// (tracked by kernel_notifications); the result accumulates across k-blocks.
// NOTE(review): many interior lines were lost in extraction (closing braces,
// the aggregate initializers for the packLArg/packRKArg argument structs, the
// declarations of kc/mc/nc, etc.); the stray leading numbers on each line are
// the original file's line numbers, not code.
124 template <
bool lhs_inner_dim_contiguous,
bool rhs_inner_dim_contiguous,
bool rhs_inner_dim_reordered,
int Alignment>
125 void evalGemm(Scalar* buffer)
const {
// Contraction reduced to matrix dims: k = contracted size, m x n = output.
127 const Index k = this->m_k_size;
130 const Index m = this->m_i_size;
133 const Index n = this->m_j_size;
// Output must start at zero because kernels accumulate into it per k-block.
136 this->m_device.memset(buffer, 0, m * n *
sizeof(Scalar));
139 const int lhs_packet_size = internal::packet_traits<LhsScalar>::size;
140 const int rhs_packet_size = internal::packet_traits<RhsScalar>::size;
// Input mappers present the (possibly strided/reordered) tensor operands as
// the 2-D matrices the pack/kernel routines expect.
142 typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs,
143 LeftEvaluator, left_nocontract_t,
144 contract_t, lhs_packet_size,
145 lhs_inner_dim_contiguous,
146 false, Unaligned> LhsMapper;
148 typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs,
149 RightEvaluator, right_nocontract_t,
150 contract_t, rhs_packet_size,
151 rhs_inner_dim_contiguous,
152 rhs_inner_dim_reordered, Unaligned> RhsMapper;
154 typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;
// Standard Eigen GEMM building blocks: panel packers + GEBP micro-kernel.
157 typedef internal::gemm_pack_lhs<LhsScalar, Index,
typename LhsMapper::SubMapper, Traits::mr,
158 Traits::LhsProgress, ColMajor> LhsPacker;
159 typedef internal::gemm_pack_rhs<RhsScalar, Index, typename RhsMapper::SubMapper, Traits::nr, ColMajor> RhsPacker;
162 typedef internal::gebp_kernel<LhsScalar, RhsScalar, Index, OutputMapper,
163 Traits::mr, Traits::nr,
false,
false> GebpKernel;
// Argument bundles for the two task kinds enqueued on the thread pool.
165 typedef internal::packLhsArg<LhsScalar, LhsMapper, Index> packLArg;
166 typedef internal::packRhsAndKernelArg<LhsScalar, RhsScalar, RhsMapper, OutputMapper, Index> packRKArg;
169 LhsMapper lhs(this->m_leftImpl, this->m_left_nocontract_strides, this->m_i_strides,
170 this->m_left_contracting_strides, this->m_k_strides);
172 RhsMapper rhs(this->m_rightImpl, this->m_right_nocontract_strides, this->m_j_strides,
173 this->m_right_contracting_strides, this->m_k_strides);
// Column-major output with leading dimension m.
175 OutputMapper output(buffer, m);
178 const Index num_threads = this->m_device.numThreads();
// Cache-aware block sizes kc/mc/nc (declarations lost in extraction).
182 internal::computeProductBlockingSizes<LhsScalar,RhsScalar,1>(kc, mc, nc, num_threads);
183 eigen_assert(mc <= m);
184 eigen_assert(nc <= n);
185 eigen_assert(kc <= k);
187 #define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
188 const Index k_blocks = CEIL_DIV(k, kc);
189 const Index n_blocks = CEIL_DIV(n, nc);
190 const Index m_blocks = CEIL_DIV(m, mc);
// Per-panel buffer sizes (in scalars) for packed LHS / RHS blocks.
191 const Index sizeA = mc * kc;
192 const Index sizeB = kc * nc;
// Number of LHS m-blocks packed concurrently; device-allocated buffers are
// reused round-robin across iterations (indexed modulo num_threads below).
205 const Index numBlockAs = numext::mini(num_threads, m_blocks);
206 std::vector<LhsScalar *> blockAs;
207 blockAs.reserve(num_threads);
208 for (
int i = 0; i < num_threads; i++) {
209 blockAs.push_back(static_cast<LhsScalar *>(this->m_device.allocate(sizeA *
sizeof(LhsScalar))));
// One packed-RHS buffer per n-block (packed once, reused for every m-block).
216 std::vector<RhsScalar *> blockBs;
217 blockBs.reserve(n_blocks);
218 for (
int i = 0; i < n_blocks; i++) {
219 blockBs.push_back(static_cast<RhsScalar *>(this->m_device.allocate(sizeB *
sizeof(RhsScalar))));
// Completion signals: one per LHS pack task slot, and one per
// (blockA slot, n-block) kernel pair.
223 std::vector<Notification*> lhs_notifications(num_threads,
nullptr);
226 const Index num_kernel_notifications = num_threads * n_blocks;
227 std::vector<Notification*> kernel_notifications(num_kernel_notifications,
// Outer loop over k-blocks: each pass accumulates a rank-kc update into the
// output, which is why `buffer` was zeroed up front.
230 for (Index k_block_idx = 0; k_block_idx < k_blocks; k_block_idx++) {
231 const Index k_start = k_block_idx * kc;
233 const Index actual_kc = numext::mini(k_start + kc, k) - k_start;
235 for (Index m_block_idx = 0; m_block_idx < m_blocks; m_block_idx += numBlockAs) {
236 const Index num_blocks = numext::mini(m_blocks-m_block_idx, numBlockAs);
// Enqueue one LHS pack task per m-block in this batch.
238 for (Index mt_block_idx = m_block_idx; mt_block_idx < m_block_idx+num_blocks; mt_block_idx++) {
239 const Index m_start = mt_block_idx * mc;
240 const Index actual_mc = numext::mini(m_start + mc, m) - m_start;
241 eigen_assert(actual_mc > 0);
// Round-robin buffer slot for this packed LHS panel.
243 Index blockAId = (k_block_idx * m_blocks + mt_block_idx) % num_threads;
// Before reusing slot blockAId, wait for every kernel still reading its
// previous contents, then recreate the notifications for this round.
245 for (
int i = 0; i < n_blocks; ++i) {
246 Index notification_id = (blockAId * n_blocks + i);
249 if (kernel_notifications[notification_id]) {
250 wait_until_ready(kernel_notifications[notification_id]);
251 delete kernel_notifications[notification_id];
253 kernel_notifications[notification_id] =
new Notification();
// packLArg initializer lost in extraction (original lines 255-266).
255 const packLArg arg = {
267 delete lhs_notifications[blockAId];
268 lhs_notifications[blockAId] =
269 this->m_device.enqueue(&Self::packLhs<packLArg, LhsPacker>, arg);
273 const Index m_base_start = m_block_idx * mc;
// RHS blocks are packed only on the first m-batch of each k-block.
274 const bool need_to_pack = m_block_idx == 0;
276 for (Index n_block_idx = 0; n_block_idx < n_blocks; n_block_idx++) {
277 const Index n_start = n_block_idx * nc;
278 const Index actual_nc = numext::mini(n_start + nc, n) - n_start;
// Drain kernels from earlier rounds that still hold blockA slots this
// batch did not refill (slots num_blocks..num_threads-1).
283 for (Index i = num_blocks; i < num_threads; ++i) {
284 Index blockAId = (k_block_idx * m_blocks + i + m_block_idx) % num_threads;
285 Index future_id = (blockAId * n_blocks + n_block_idx);
286 wait_until_ready(kernel_notifications[future_id]);
// packRKArg initializer mostly lost in extraction; only two of its
// fields survived below (original lines 288-316 are incomplete).
292 blockBs[n_block_idx],
309 &kernel_notifications,
317 this->m_device.enqueueNoNotification(&Self::packRhsAndKernel<packRKArg, RhsPacker, GebpKernel>, arg);
// Drain all outstanding kernel tasks, then release notifications and the
// device-allocated pack buffers.
323 for (
size_t i = 0; i < kernel_notifications.size(); ++i) {
324 wait_until_ready(kernel_notifications[i]);
325 delete kernel_notifications[i];
330 for (
size_t i = 0; i < lhs_notifications.size(); ++i) {
331 delete lhs_notifications[i];
335 for (
size_t i = 0; i < blockAs.size(); i++) {
336 this->m_device.deallocate(blockAs[i]);
338 for (
size_t i = 0; i < blockBs.size(); i++) {
339 this->m_device.deallocate(blockBs[i]);
// Thread-pool task: pack one LHS panel (arg.mc rows x arg.kc cols starting at
// (arg.m_start, arg.k_start)) into arg.blockA using Eigen's gemm_pack_lhs.
// NOTE(review): the declaration of the local `pack_lhs` functor and the closing
// brace were lost in extraction.
350 template <
typename packLArg,
typename LhsPacker>
351 static void packLhs(
const packLArg arg) {
354 pack_lhs(arg.blockA, arg.lhs.getSubMapper(arg.m_start, arg.k_start), arg.kc, arg.mc);
// Thread-pool task: optionally pack one RHS block, then run the GEBP kernel
// against each packed LHS panel assigned to this task, notifying the matching
// kernel_notifications entry as each product finishes.
// NOTE(review): the local pack_rhs/gebp functor declarations and several
// closing braces were lost in extraction; the leading numbers are residue.
366 template <
typename packRKArg,
typename RhsPacker,
typename GebpKernel>
367 static void packRhsAndKernel(packRKArg arg) {
// RHS is packed only once per (k-block, n-block); later m-batches reuse it.
368 if (arg.need_to_pack) {
370 pack_rhs(arg.blockB, arg.rhs.getSubMapper(arg.k, arg.n), arg.kc, arg.nc);
374 for (Index mt_block_idx = 0; mt_block_idx < arg.num_blockAs; mt_block_idx++) {
375 const Index m_base_start = arg.m + arg.mc*mt_block_idx;
// Skip tail iterations that fall past the end of the m dimension.
376 if (m_base_start < arg.max_m) {
// Same round-robin slot computation as the producer side in evalGemm.
377 Index blockAId = (arg.k_block_idx * arg.m_blocks + mt_block_idx + arg.m_block_idx) % arg.num_threads;
// Block until the packLhs task for this slot has finished.
378 wait_until_ready((*arg.lhs_notifications)[blockAId]);
379 const Index actual_mc = numext::mini(m_base_start + arg.mc, arg.max_m) - m_base_start;
// Accumulate (alpha = 1.0) the panel product into the output sub-block.
380 gebp(arg.output.getSubMapper(m_base_start, arg.n),
381 (*arg.blockAs)[blockAId], arg.blockB,
382 actual_mc, arg.kc, arg.nc, 1.0, -1, -1, 0, 0);
// Signal that this (blockA slot, n-block) kernel is done so evalGemm may
// reuse the packed-LHS buffer.
385 const Index set_idx = blockAId * arg.n_blocks + arg.n_block_idx;
386 (*arg.kernel_notifications)[set_idx]->Notify();
394 #endif // EIGEN_USE_THREADS
395 #endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_THREAD_POOL_H
// Doxygen-extraction residue (tooltip text for the Eigen namespace), kept as a comment:
// Namespace containing all symbols from the Eigen library.
// Definition: CXX11Meta.h:13