3 #ifndef DUNE_COMMON_PARALLEL_MPICOMMUNICATION_HH
4 #define DUNE_COMMON_PARALLEL_MPICOMMUNICATION_HH
37 template<
typename Type,
typename BinaryFunction,
typename Enable=
void>
46 op = std::make_unique<MPI_Op>();
51 MPI_Op_create((
void (*)(
void*,
void*,
int*, MPI_Datatype*))&operation,
true,op.get());
56 static void operation (Type *in, Type *inout,
int *len, MPI_Datatype*)
60 for (
int i=0; i< *len; ++i, ++in, ++inout) {
62 temp = func(*in, *inout);
67 Generic_MPI_Op (
const Generic_MPI_Op& ) {}
68 static std::unique_ptr<MPI_Op> op;
72 template<
typename Type,
typename BinaryFunction,
typename Enable>
73 std::unique_ptr<MPI_Op> Generic_MPI_Op<Type,BinaryFunction, Enable>::op;
75 #define ComposeMPIOp(func,op) \
76 template<class T, class S> \
77 class Generic_MPI_Op<T, func<S>, std::enable_if_t<MPITraits<S>::is_intrinsic> >{ \
79 static MPI_Op get(){ \
83 Generic_MPI_Op () {} \
84 Generic_MPI_Op (const Generic_MPI_Op & ) {} \
88 ComposeMPIOp(std::plus, MPI_SUM);
89 ComposeMPIOp(std::multiplies, MPI_PROD);
90 ComposeMPIOp(Min, MPI_MIN);
91 ComposeMPIOp(Max, MPI_MAX);
105 class Communication<MPI_Comm>
112 if(communicator!=MPI_COMM_NULL) {
114 MPI_Initialized(&initialized);
116 DUNE_THROW(ParallelError,
"You must call MPIHelper::instance(argc,argv) in your main() function before using the MPI Communication!");
117 MPI_Comm_rank(communicator,&me);
118 MPI_Comm_size(communicator,&procs);
139 int send(
const T& data,
int dest_rank,
int tag)
const
141 auto mpi_data = getMPIData(data);
142 return MPI_Send(mpi_data.ptr(), mpi_data.size(), mpi_data.type(),
143 dest_rank, tag, communicator);
148 MPIFuture<const T>
isend(
const T&& data,
int dest_rank,
int tag)
const
150 MPIFuture<const T> future(std::forward<const T>(data));
151 auto mpidata = future.get_mpidata();
152 MPI_Isend(mpidata.ptr(), mpidata.size(), mpidata.type(),
153 dest_rank, tag, communicator, &future.req_);
159 T
recv(T&& data,
int source_rank,
int tag, MPI_Status* status = MPI_STATUS_IGNORE)
const
161 T lvalue_data(std::forward<T>(data));
162 auto mpi_data = getMPIData(lvalue_data);
163 MPI_Recv(mpi_data.ptr(), mpi_data.size(), mpi_data.type(),
164 source_rank, tag, communicator, status);
170 MPIFuture<T>
irecv(T&& data,
int source_rank,
int tag)
const
172 MPIFuture<T> future(std::forward<T>(data));
173 auto mpidata = future.get_mpidata();
174 MPI_Irecv(mpidata.ptr(), mpidata.size(), mpidata.type(),
175 source_rank, tag, communicator, &future.req_);
180 T
rrecv(T&& data,
int source_rank,
int tag, MPI_Status* status = MPI_STATUS_IGNORE)
const
183 MPI_Message _message;
184 T lvalue_data(std::forward<T>(data));
185 auto mpi_data = getMPIData(lvalue_data);
186 static_assert(!mpi_data.static_size,
"rrecv work only for non-static-sized types.");
187 if(status == MPI_STATUS_IGNORE)
189 MPI_Mprobe(source_rank, tag, communicator, &_message, status);
191 MPI_Get_count(status, mpi_data.type(), &
size);
192 mpi_data.resize(
size);
193 MPI_Mrecv(mpi_data.ptr(), mpi_data.size(), mpi_data.type(), &_message, status);
199 T
sum (
const T& in)
const
202 allreduce<std::plus<T> >(&in,&out,1);
208 int sum (T* inout,
int len)
const
210 return allreduce<std::plus<T> >(inout,len);
215 T
prod (
const T& in)
const
218 allreduce<std::multiplies<T> >(&in,&out,1);
224 int prod (T* inout,
int len)
const
226 return allreduce<std::multiplies<T> >(inout,len);
231 T
min (
const T& in)
const
234 allreduce<Min<T> >(&in,&out,1);
240 int min (T* inout,
int len)
const
242 return allreduce<Min<T> >(inout,len);
248 T
max (
const T& in)
const
251 allreduce<Max<T> >(&in,&out,1);
257 int max (T* inout,
int len)
const
259 return allreduce<Max<T> >(inout,len);
265 return MPI_Barrier(communicator);
271 MPIFuture<void> future(
true);
272 MPI_Ibarrier(communicator, &future.req_);
279 int broadcast (T* inout,
int len,
int root)
const
281 return MPI_Bcast(inout,len,MPITraits<T>::getType(),root,communicator);
286 MPIFuture<T>
ibroadcast(T&& data,
int root)
const{
287 MPIFuture<T> future(std::forward<T>(data));
288 auto mpidata = future.get_mpidata();
289 MPI_Ibcast(mpidata.ptr(),
301 int gather (
const T* in, T* out,
int len,
int root)
const
303 return MPI_Gather(
const_cast<T*
>(in),len,MPITraits<T>::getType(),
304 out,len,MPITraits<T>::getType(),
309 template<
class TIN,
class TOUT = std::vector<TIN>>
310 MPIFuture<TOUT, TIN>
igather(TIN&& data_in, TOUT&& data_out,
int root){
311 MPIFuture<TOUT, TIN> future(std::forward<TOUT>(data_out), std::forward<TIN>(data_in));
312 auto mpidata_in = future.get_send_mpidata();
313 auto mpidata_out = future.get_mpidata();
314 assert(root != me || mpidata_in.size()*procs <= mpidata_out.size());
315 int outlen = me==root * mpidata_in.size();
316 MPI_Igather(mpidata_in.ptr(), mpidata_in.size(), mpidata_in.type(),
317 mpidata_out.ptr(), outlen, mpidata_out.type(),
318 root, communicator, &future.req_);
324 int gatherv (
const T* in,
int sendlen, T* out,
int* recvlen,
int* displ,
int root)
const
326 return MPI_Gatherv(
const_cast<T*
>(in),sendlen,MPITraits<T>::getType(),
327 out,recvlen,displ,MPITraits<T>::getType(),
336 return MPI_Scatter(
const_cast<T*
>(
send),len,MPITraits<T>::getType(),
337 recv,len,MPITraits<T>::getType(),
342 template<
class TIN,
class TOUT = TIN>
343 MPIFuture<TOUT, TIN>
iscatter(TIN&& data_in, TOUT&& data_out,
int root)
const
345 MPIFuture<TOUT, TIN> future(std::forward<TOUT>(data_out), std::forward<TIN>(data_in));
346 auto mpidata_in = future.get_send_mpidata();
347 auto mpidata_out = future.get_mpidata();
348 int inlen = me==root * mpidata_in.size();
349 MPI_Iscatter(mpidata_in.ptr(), inlen, mpidata_in.type(),
350 mpidata_out.ptr(), mpidata_out.size(), mpidata_out.type(),
351 root, communicator, &future.req_);
357 int scatterv (
const T*
send,
int* sendlen,
int* displ, T*
recv,
int recvlen,
int root)
const
359 return MPI_Scatterv(
const_cast<T*
>(
send),sendlen,displ,MPITraits<T>::getType(),
360 recv,recvlen,MPITraits<T>::getType(),
365 operator MPI_Comm ()
const
371 template<
typename T,
typename T1>
372 int allgather(
const T* sbuf,
int count, T1* rbuf)
const
374 return MPI_Allgather(
const_cast<T*
>(sbuf), count, MPITraits<T>::getType(),
375 rbuf, count, MPITraits<T1>::getType(),
380 template<
class TIN,
class TOUT = TIN>
381 MPIFuture<TOUT, TIN>
iallgather(TIN&& data_in, TOUT&& data_out)
const
383 MPIFuture<TOUT, TIN> future(std::forward<TOUT>(data_out), std::forward<TIN>(data_in));
384 auto mpidata_in = future.get_send_mpidata();
385 auto mpidata_out = future.get_mpidata();
386 assert(mpidata_in.size()*procs <= mpidata_out.size());
387 int outlen = mpidata_in.size();
388 MPI_Iallgather(mpidata_in.ptr(), mpidata_in.size(), mpidata_in.type(),
389 mpidata_out.ptr(), outlen, mpidata_out.type(),
390 communicator, &future.req_);
396 int allgatherv (
const T* in,
int sendlen, T* out,
int* recvlen,
int* displ)
const
398 return MPI_Allgatherv(
const_cast<T*
>(in),sendlen,MPITraits<T>::getType(),
399 out,recvlen,displ,MPITraits<T>::getType(),
404 template<
typename BinaryFunction,
typename Type>
405 int allreduce(Type* inout,
int len)
const
407 Type* out =
new Type[len];
408 int ret = allreduce<BinaryFunction>(inout,out,len);
409 std::copy(out, out+len, inout);
414 template<
typename BinaryFunction,
typename Type>
416 Type lvalue_data = std::forward<Type>(in);
417 auto data = getMPIData(lvalue_data);
418 MPI_Allreduce(MPI_IN_PLACE, data.ptr(), data.size(), data.type(),
425 template<
class BinaryFunction,
class TIN,
class TOUT = TIN>
426 MPIFuture<TOUT, TIN>
iallreduce(TIN&& data_in, TOUT&& data_out)
const {
427 MPIFuture<TOUT, TIN> future(std::forward<TOUT>(data_out), std::forward<TIN>(data_in));
428 auto mpidata_in = future.get_send_mpidata();
429 auto mpidata_out = future.get_mpidata();
430 assert(mpidata_out.size() == mpidata_in.size());
431 assert(mpidata_out.type() == mpidata_in.type());
432 MPI_Iallreduce(mpidata_in.ptr(), mpidata_out.ptr(),
433 mpidata_out.size(), mpidata_out.type(),
435 communicator, &future.req_);
440 template<
class BinaryFunction,
class T>
442 MPIFuture<T> future(std::forward<T>(data));
443 auto mpidata = future.get_mpidata();
444 MPI_Iallreduce(MPI_IN_PLACE, mpidata.ptr(),
445 mpidata.size(), mpidata.type(),
447 communicator, &future.req_);
452 template<
typename BinaryFunction,
typename Type>
453 int allreduce(
const Type* in, Type* out,
int len)
const
455 return MPI_Allreduce(
const_cast<Type*
>(in), out, len, MPITraits<Type>::getType(),
// The raw MPI communicator wrapped by this Communication object.
460 MPI_Comm communicator;