Nektar++
CommMpi.cpp
Go to the documentation of this file.
1 ///////////////////////////////////////////////////////////////////////////////
2 //
3 // File CommMpi.cpp
4 //
5 // For more information, please see: http://www.nektar.info
6 //
7 // The MIT License
8 //
9 // Copyright (c) 2006 Division of Applied Mathematics, Brown University (USA),
10 // Department of Aeronautics, Imperial College London (UK), and Scientific
11 // Computing and Imaging Institute, University of Utah (USA).
12 //
13 // Permission is hereby granted, free of charge, to any person obtaining a
14 // copy of this software and associated documentation files (the "Software"),
15 // to deal in the Software without restriction, including without limitation
16 // the rights to use, copy, modify, merge, publish, distribute, sublicense,
17 // and/or sell copies of the Software, and to permit persons to whom the
18 // Software is furnished to do so, subject to the following conditions:
19 //
20 // The above copyright notice and this permission notice shall be included
21 // in all copies or substantial portions of the Software.
22 //
23 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
24 // OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
25 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
26 // THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
27 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
28 // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
29 // DEALINGS IN THE SOFTWARE.
30 //
31 // Description: MPI communication implementation
32 //
33 ///////////////////////////////////////////////////////////////////////////////
34 
35 #ifdef NEKTAR_USING_PETSC
36 #include "petscsys.h"
37 #endif
38 
39 #include <LibUtilities/Communication/CommMpi.h>
40 #include <LibUtilities/BasicUtils/ErrorUtil.hpp>
41 
42 namespace Nektar
43 {
44 namespace LibUtilities
45 {
46 
47 std::string CommMpi::className = GetCommFactory().RegisterCreatorFunction(
48     "ParallelMPI", CommMpi::create, "Parallel communication using MPI.");
49 
50 /**
51  * Initialises the MPI environment and wraps MPI_COMM_WORLD.
52  */
53 CommMpi::CommMpi(int narg, char *arg[]) : Comm(narg, arg)
54 {
55  int init = 0;
56  MPI_Initialized(&init);
57  ASSERTL0(!init, "MPI has already been initialised.");
58 
59  int retval = MPI_Init(&narg, &arg);
60  if (retval != MPI_SUCCESS)
61  {
62  ASSERTL0(false, "Failed to initialise MPI");
63  }
64 
65  m_comm = MPI_COMM_WORLD;
66  MPI_Comm_size(m_comm, &m_size);
67  MPI_Comm_rank(m_comm, &m_rank);
68 
69 #ifdef NEKTAR_USING_PETSC
70  PetscInitializeNoArguments();
71 #endif
72 
73  m_type = "Parallel MPI";
74 }
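/*
 * Illustrative sketch (editorial, not part of CommMpi.cpp): user code normally
 * reaches this class through the "ParallelMPI" key registered above rather than
 * by constructing it directly. Assuming the usual NekFactory interface, creation
 * looks roughly like this:
 *
 *     #include <LibUtilities/Communication/Comm.h>
 *
 *     using namespace Nektar::LibUtilities;
 *
 *     int main(int argc, char *argv[])
 *     {
 *         // Runs the constructor above, which calls MPI_Init.
 *         CommSharedPtr comm =
 *             GetCommFactory().CreateInstance("ParallelMPI", argc, argv);
 *         if (comm->GetRank() == 0)
 *         {
 *             // Only the root rank reaches this block.
 *         }
 *         comm->Finalise();
 *         return 0;
 *     }
 */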
75 
76 /**
77  * Wraps an existing MPI communicator without re-initialising MPI.
78  */
79 CommMpi::CommMpi(MPI_Comm pComm) : Comm()
80 {
81  m_comm = pComm;
82  MPI_Comm_size(m_comm, &m_size);
83  MPI_Comm_rank(m_comm, &m_rank);
84 
85  m_type = "Parallel MPI";
86 }
87 
88 /**
89  * Frees the wrapped communicator unless it is MPI_COMM_WORLD or MPI has already been finalised.
90  */
91 CommMpi::~CommMpi()
92 {
93  int flag;
94  MPI_Finalized(&flag);
95  if (!flag && m_comm != MPI_COMM_WORLD)
96  {
97  MPI_Comm_free(&m_comm);
98  }
99 }
100 
101 /**
102  * Returns the underlying MPI communicator.
103  */
104 MPI_Comm CommMpi::GetComm()
105 {
106  return m_comm;
107 }
108 
109 /**
110  * Finalises PETSc (when enabled) and MPI, unless MPI has already been finalised elsewhere.
111  */
112 void CommMpi::v_Finalise()
113 {
114 #ifdef NEKTAR_USING_PETSC
115  PetscFinalize();
116 #endif
117  int flag;
118  MPI_Finalized(&flag);
119  if (!flag)
120  {
121  MPI_Finalize();
122  }
123 }
124 
125 /**
126  * Returns the rank of this process in the communicator.
127  */
128 int CommMpi::v_GetRank()
129 {
130  return m_rank;
131 }
132 
133 /**
134  * Returns true if this process is rank zero.
135  */
136 bool CommMpi::v_TreatAsRankZero()
137 {
138  return m_rank == 0;
139 }
140 
141 /**
142  * Returns true if the communicator spans a single process.
143  */
144 bool CommMpi::v_IsSerial()
145 {
146  return m_size == 1;
147 }
148 
149 std::tuple<int, int, int> CommMpi::v_GetVersion()
150 {
151  int version, subversion;
152  int retval = MPI_Get_version(&version, &subversion);
153 
154  ASSERTL0(retval == MPI_SUCCESS, "MPI error performing GetVersion.");
155 
156  return std::make_tuple(version, subversion, 0);
157 }
158 
159 /**
160  * Blocks until every process in the communicator has reached this point.
161  */
162 void CommMpi::v_Block()
163 {
164  MPI_Barrier(m_comm);
165 }
166 
167 /**
168  * Returns the MPI wall-clock time in seconds.
169  */
170 double CommMpi::v_Wtime()
171 {
172  return MPI_Wtime();
173 }
174 
175 /**
176  * Sends data to the given destination rank; uses MPI_Ssend when MPISYNC is set.
177  */
178 void CommMpi::v_Send(void *buf, int count, CommDataType dt, int dest)
179 {
180  if (MPISYNC)
181  {
182  MPI_Ssend(buf, count, dt, dest, 0, m_comm);
183  }
184  else
185  {
186  MPI_Send(buf, count, dt, dest, 0, m_comm);
187  }
188 }
189 
190 /**
191  * Receives data from the given source rank (blocking).
192  */
193 void CommMpi::v_Recv(void *buf, int count, CommDataType dt, int source)
194 {
195  MPI_Recv(buf, count, dt, source, 0, m_comm, MPI_STATUS_IGNORE);
196  // ASSERTL0(status.MPI_ERROR == MPI_SUCCESS,
197  // "MPI error receiving data.");
198 }
199 
200 /**
201  * Sends to dest and receives from source in a single combined operation.
202  */
203 void CommMpi::v_SendRecv(void *sendbuf, int sendcount, CommDataType sendtype,
204  int dest, void *recvbuf, int recvcount,
205  CommDataType recvtype, int source)
206 {
207  MPI_Status status;
208  int retval = MPI_Sendrecv(sendbuf, sendcount, sendtype, dest, 0, recvbuf,
209  recvcount, recvtype, source, 0, m_comm, &status);
210 
211  ASSERTL0(retval == MPI_SUCCESS,
212  "MPI error performing send-receive of data.");
213 }
214 
215 /**
216  * Sends the buffer to pRecvProc and overwrites it with data received from pSendProc.
217  */
218 void CommMpi::v_SendRecvReplace(void *buf, int count, CommDataType dt,
219  int pSendProc, int pRecvProc)
220 {
221  MPI_Status status;
222  int retval = MPI_Sendrecv_replace(buf, count, dt, pRecvProc, 0, pSendProc,
223  0, m_comm, &status);
224 
225  ASSERTL0(retval == MPI_SUCCESS,
226  "MPI error performing Send-Receive-Replace of data.");
227 }
228 
229 /**
230  * Reduces the buffer across all processes in place using the requested operator (max, min or sum).
231  */
232 void CommMpi::v_AllReduce(void *buf, int count, CommDataType dt,
233  enum ReduceOperator pOp)
234 {
235  if (GetSize() == 1)
236  {
237  return;
238  }
239 
240  MPI_Op vOp;
241  switch (pOp)
242  {
243  case ReduceMax:
244  vOp = MPI_MAX;
245  break;
246  case ReduceMin:
247  vOp = MPI_MIN;
248  break;
249  case ReduceSum:
250  default:
251  vOp = MPI_SUM;
252  break;
253  }
254  int retval = MPI_Allreduce(MPI_IN_PLACE, buf, count, dt, vOp, m_comm);
255 
256  ASSERTL0(retval == MPI_SUCCESS, "MPI error performing All-reduce.");
257 }
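/*
 * Illustrative caller-side sketch (editorial): given a CommSharedPtr comm,
 * Comm.h is assumed to provide the usual templated AllReduce front end that
 * forwards to v_AllReduce above, so a global sum of a per-rank scalar reduces
 * in place:
 *
 *     NekDouble localNorm = ComputeLocalNorm();   // hypothetical helper
 *     comm->AllReduce(localNorm, LibUtilities::ReduceSum);
 *     // localNorm now holds the sum over all ranks.
 */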
258 
259 /**
260  * Exchanges a fixed-size block of data with every other process.
261  */
262 void CommMpi::v_AlltoAll(void *sendbuf, int sendcount, CommDataType sendtype,
263  void *recvbuf, int recvcount, CommDataType recvtype)
264 {
265  int retval = MPI_Alltoall(sendbuf, sendcount, sendtype, recvbuf, recvcount,
266  recvtype, m_comm);
267 
268  ASSERTL0(retval == MPI_SUCCESS, "MPI error performing All-to-All.");
269 }
270 
271 /**
272  * Exchanges variable-sized blocks of data using per-process counts and displacements.
273  */
274 void CommMpi::v_AlltoAllv(void *sendbuf, int sendcounts[], int sdispls[],
275  CommDataType sendtype, void *recvbuf,
276  int recvcounts[], int rdispls[],
277  CommDataType recvtype)
278 {
279  int retval = MPI_Alltoallv(sendbuf, sendcounts, sdispls, sendtype, recvbuf,
280  recvcounts, rdispls, recvtype, m_comm);
281 
282  ASSERTL0(retval == MPI_SUCCESS, "MPI error performing All-to-All-v.");
283 }
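/*
 * Worked example (editorial): on two ranks, if rank 0 packs one element for
 * itself and two elements for rank 1 into a single send buffer, then
 * sendcounts = {1, 2} and sdispls = {0, 1}; recvcounts/rdispls describe the
 * layout of the receive buffer in the same way, and every count and
 * displacement is measured in elements of the given datatype, not bytes.
 */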
284 
285 /**
286  * Gathers equal-sized contributions from every process and replicates the result on all of them.
287  */
288 void CommMpi::v_AllGather(void *sendbuf, int sendcount, CommDataType sendtype,
289  void *recvbuf, int recvcount, CommDataType recvtype)
290 {
291  int retval = MPI_Allgather(sendbuf, sendcount, sendtype, recvbuf, recvcount,
292  recvtype, m_comm);
293 
294  ASSERTL0(retval == MPI_SUCCESS, "MPI error performing Allgather.");
295 }
296 
297 void CommMpi::v_AllGatherv(void *sendbuf, int sendcount, CommDataType sendtype,
298  void *recvbuf, int recvcounts[], int rdispls[],
299  CommDataType recvtype)
300 {
301  int retval = MPI_Allgatherv(sendbuf, sendcount, sendtype, recvbuf,
302  recvcounts, rdispls, recvtype, m_comm);
303 
304  ASSERTL0(retval == MPI_SUCCESS, "MPI error performing Allgather.");
305 }
306 
307 void CommMpi::v_AllGatherv(void *recvbuf, int recvcounts[], int rdispls[],
308  CommDataType recvtype)
309 {
310  int retval = MPI_Allgatherv(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, recvbuf,
311  recvcounts, rdispls, recvtype, m_comm);
312 
313  ASSERTL0(retval == MPI_SUCCESS, "MPI error performing Allgatherv.");
314 }
315 
316 void CommMpi::v_Bcast(void *buffer, int count, CommDataType dt, int root)
317 {
318  int retval = MPI_Bcast(buffer, count, dt, root, m_comm);
319  ASSERTL0(retval == MPI_SUCCESS, "MPI error performing Bcast-v.");
320 }
321 
322 void CommMpi::v_Exscan(Array<OneD, unsigned long long> &pData,
323                        const enum ReduceOperator pOp,
324                        Array<OneD, unsigned long long> &ans)
325 {
326  int n = pData.size();
327  ASSERTL0(n == ans.size(), "Array sizes differ in Exscan");
328 
329  MPI_Op vOp;
330  switch (pOp)
331  {
332  case ReduceMax:
333  vOp = MPI_MAX;
334  break;
335  case ReduceMin:
336  vOp = MPI_MIN;
337  break;
338  case ReduceSum:
339  default:
340  vOp = MPI_SUM;
341  break;
342  }
343 
344  int retval = MPI_Exscan(pData.get(), ans.get(), n, MPI_UNSIGNED_LONG_LONG,
345  vOp, m_comm);
346  ASSERTL0(retval == MPI_SUCCESS, "MPI error performing Exscan-v.");
347 }
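/*
 * Worked example (editorial): with ReduceSum and single-entry arrays holding
 * 3, 5 and 2 on ranks 0, 1 and 2, MPI_Exscan leaves ans undefined on rank 0
 * (per the MPI standard) and returns 3 on rank 1 and 8 on rank 2, i.e. each
 * rank's offset into a globally concatenated array, which is the usual reason
 * for calling this routine.
 */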
348 
349 void CommMpi::v_Gather(void *sendbuf, int sendcount, CommDataType sendtype,
350  void *recvbuf, int recvcount, CommDataType recvtype,
351  int root)
352 {
353  int retval = MPI_Gather(sendbuf, sendcount, sendtype, recvbuf, recvcount,
354  recvtype, root, m_comm);
355 
356  ASSERTL0(retval == MPI_SUCCESS, "MPI error performing Gather.");
357 }
358 
359 void CommMpi::v_Scatter(void *sendbuf, int sendcount, CommDataType sendtype,
360  void *recvbuf, int recvcount, CommDataType recvtype,
361  int root)
362 {
363  int retval = MPI_Scatter(sendbuf, sendcount, sendtype, recvbuf, recvcount,
364  recvtype, root, m_comm);
365  ASSERTL0(retval == MPI_SUCCESS, "MPI error performing Scatter.");
366 }
367 
368 void CommMpi::v_DistGraphCreateAdjacent(int indegree, const int sources[],
369  const int sourceweights[], int reorder)
370 {
371 #if MPI_VERSION < 3
372  boost::ignore_unused(indegree, sources, sourceweights, reorder);
373  ASSERTL0(false, "MPI_Dist_graph_create_adjacent is not supported in your "
374  "installed MPI version.");
375 #else
376  int retval = MPI_Dist_graph_create_adjacent(
377  m_comm, indegree, sources, sourceweights, indegree, sources,
378  sourceweights, MPI_INFO_NULL, reorder, &m_comm);
379 
380  ASSERTL0(retval == MPI_SUCCESS,
381              "MPI error performing Dist_graph_create_adjacent.");
382 #endif
383 }
384 
385 void CommMpi::v_NeighborAlltoAllv(void *sendbuf, int sendcounts[],
386  int sdispls[], CommDataType sendtype,
387  void *recvbuf, int recvcounts[],
388  int rdispls[], CommDataType recvtype)
389 {
390 #if MPI_VERSION < 3
391  boost::ignore_unused(sendbuf, sendcounts, sdispls, sendtype, recvbuf,
392  recvcounts, rdispls, recvtype);
393  ASSERTL0(false, "MPI_Neighbor_alltoallv is not supported in your "
394  "installed MPI version.");
395 #else
396  int retval =
397  MPI_Neighbor_alltoallv(sendbuf, sendcounts, sdispls, sendtype, recvbuf,
398  recvcounts, rdispls, recvtype, m_comm);
399 
400  ASSERTL0(retval == MPI_SUCCESS, "MPI error performing NeighborAllToAllV.");
401 #endif
402 }
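/*
 * Illustrative raw MPI-3 sketch (editorial): the graph communicator created by
 * v_DistGraphCreateAdjacent above is what MPI_Neighbor_alltoallv exchanges
 * over. Assuming an ordinary communicator comm, two neighbouring ranks left
 * and right, one double per neighbour in each direction, and <mpi.h> and
 * <vector> included:
 *
 *     std::vector<int> neighbours = {left, right};
 *     std::vector<int> weights(neighbours.size(), 1);
 *     MPI_Comm graphComm;
 *     MPI_Dist_graph_create_adjacent(
 *         comm, static_cast<int>(neighbours.size()), neighbours.data(),
 *         weights.data(), static_cast<int>(neighbours.size()),
 *         neighbours.data(), weights.data(), MPI_INFO_NULL, 0, &graphComm);
 *
 *     std::vector<int> counts = {1, 1}, displs = {0, 1};
 *     std::vector<double> sendbuf = {0.1, 0.2}, recvbuf(2);
 *     MPI_Neighbor_alltoallv(sendbuf.data(), counts.data(), displs.data(),
 *                            MPI_DOUBLE, recvbuf.data(), counts.data(),
 *                            displs.data(), MPI_DOUBLE, graphComm);
 */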
403 
404 void CommMpi::v_Irsend(void *buf, int count, CommDataType dt, int dest,
405  CommRequestSharedPtr request, int loc)
406 {
407     CommRequestMpiSharedPtr req =
408         std::static_pointer_cast<CommRequestMpi>(request);
409  MPI_Irsend(buf, count, dt, dest, 0, m_comm, req->GetRequest(loc));
410 }
411 
412 void CommMpi::v_SendInit(void *buf, int count, CommDataType dt, int dest,
413  CommRequestSharedPtr request, int loc)
414 {
415     CommRequestMpiSharedPtr req =
416         std::static_pointer_cast<CommRequestMpi>(request);
417  MPI_Send_init(buf, count, dt, dest, 0, m_comm, req->GetRequest(loc));
418 }
419 
420 void CommMpi::v_Irecv(void *buf, int count, CommDataType dt, int source,
421  CommRequestSharedPtr request, int loc)
422 {
423     CommRequestMpiSharedPtr req =
424         std::static_pointer_cast<CommRequestMpi>(request);
425  MPI_Irecv(buf, count, dt, source, 0, m_comm, req->GetRequest(loc));
426 }
427 
428 void CommMpi::v_RecvInit(void *buf, int count, CommDataType dt, int source,
429  CommRequestSharedPtr request, int loc)
430 {
431     CommRequestMpiSharedPtr req =
432         std::static_pointer_cast<CommRequestMpi>(request);
433  MPI_Recv_init(buf, count, dt, source, 0, m_comm, req->GetRequest(loc));
434 }
435 
436 void CommMpi::v_StartAll(CommRequestSharedPtr request)
437 {
438     CommRequestMpiSharedPtr req =
439         std::static_pointer_cast<CommRequestMpi>(request);
440  MPI_Startall(req->GetNumRequest(), req->GetRequest(0));
441 }
442 
443 void CommMpi::v_WaitAll(CommRequestSharedPtr request)
444 {
445     CommRequestMpiSharedPtr req =
446         std::static_pointer_cast<CommRequestMpi>(request);
447  MPI_Waitall(req->GetNumRequest(), req->GetRequest(0), MPI_STATUSES_IGNORE);
448 }
449 
450 CommRequestSharedPtr CommMpi::v_CreateRequest(int num)
451 {
452  return std::shared_ptr<CommRequest>(new CommRequestMpi(num));
453 }
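/*
 * Editorial note: the request object created here holds num MPI_Request slots.
 * Callers post one operation per slot through v_Irecv / v_Irsend / v_SendInit /
 * v_RecvInit above (the loc argument selects the slot), start any persistent
 * requests with v_StartAll, and complete them all with v_WaitAll.
 */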
454 
455 /**
456  * Processes are considered as a grid of size pRows*pColumns. Comm
457  * objects are created corresponding to the rows and columns of this
458  * grid. The row and column to which this process belongs is stored in
459  * #m_commRow and #m_commColumn.
460  */
461 void CommMpi::v_SplitComm(int pRows, int pColumns)
462 {
463  ASSERTL0(pRows * pColumns == m_size,
464  "Rows/Columns do not match comm size.");
465 
466  MPI_Comm newComm;
467 
468  // Compute row and column in grid.
469  int myCol = m_rank % pColumns;
470  int myRow = (m_rank - myCol) / pColumns;
471 
472  // Split Comm into rows - all processes with same myRow are put in
473  // the same communicator. The rank within this communicator is the
474  // column index.
475  MPI_Comm_split(m_comm, myRow, myCol, &newComm);
476  m_commRow = std::shared_ptr<Comm>(new CommMpi(newComm));
477 
478  // Split Comm into columns - all processes with same myCol are put
479  // in the same communicator. The rank within this communicator is
480  // the row index.
481  MPI_Comm_split(m_comm, myCol, myRow, &newComm);
482  m_commColumn = std::shared_ptr<Comm>(new CommMpi(newComm));
483 }
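/*
 * Worked example (editorial): with six processes split as pRows = 2,
 * pColumns = 3, rank 4 has myCol = 4 % 3 = 1 and myRow = (4 - 1) / 3 = 1, so
 * it joins the row communicator {3, 4, 5} with rank 1 and the column
 * communicator {1, 4} with rank 1; the two sub-communicators are stored in
 * #m_commRow and #m_commColumn as described in the comment above.
 */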
484 
485 /**
486  * Create a new communicator if the flag is non-zero.
487  */
488 CommSharedPtr CommMpi::v_CommCreateIf(int flag)
489 {
490  MPI_Comm newComm;
491  // color == MPI_UNDEFINED => this rank is excluded from the new communicator.
492  // key == 0 on all ranks => order by rank in the parent communicator;
493  // Open MPI, at least, suggests this is faster than ordering them ourselves.
494  MPI_Comm_split(m_comm, flag ? 0 : MPI_UNDEFINED, 0, &newComm);
495 
496  if (flag == 0)
497  {
498  // flag == 0 => get back MPI_COMM_NULL, return a null ptr instead.
499  return std::shared_ptr<Comm>();
500  }
501  else
502  {
503  // Return a real communicator
504  return std::shared_ptr<Comm>(new CommMpi(newComm));
505  }
506 }
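/*
 * Illustrative caller-side sketch (editorial, assuming the usual public
 * wrapper in Comm.h): every rank must make the call, but only ranks passing a
 * non-zero flag receive a communicator back.
 *
 *     CommSharedPtr evenComm = comm->CommCreateIf(comm->GetRank() % 2 == 0);
 *     if (evenComm)
 *     {
 *         // Collective operations here involve only the even ranks.
 *     }
 */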
507 
508 std::pair<CommSharedPtr, CommSharedPtr> CommMpi::v_SplitCommNode()
509 {
510  std::pair<CommSharedPtr, CommSharedPtr> ret;
511 
512 #if MPI_VERSION < 3
513  ASSERTL0(false, "Not implemented for non-MPI-3 versions.");
514 #else
515  // Create an intra-node communicator.
516  MPI_Comm nodeComm;
517  MPI_Comm_split_type(MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, m_rank,
518  MPI_INFO_NULL, &nodeComm);
519 
520  // For rank 0 of the intra-node communicator, split the main
521  // communicator. Everyone else will get a null communicator.
522  ret.first = std::shared_ptr<Comm>(new CommMpi(nodeComm));
523  ret.second = CommMpi::v_CommCreateIf(ret.first->GetRank() == 0);
524     if (ret.first->GetRank() == 0)
525  {
526  ret.second->SplitComm(1, ret.second->GetSize());
527  }
528 #endif
529 
530  return ret;
531 }
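/*
 * Editorial note: the returned pair therefore contains (i) an intra-node
 * communicator covering every rank on this node and (ii), on intra-node rank 0
 * only, a communicator linking the node-leader ranks across nodes; all other
 * ranks receive a null pointer in the second slot.
 */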
532 
533 } // namespace LibUtilities
534 } // namespace Nektar
#define MPI_UNSIGNED_LONG_LONG
Definition: CommDataType.h:97
#define MPISYNC
Definition: CommMpi.h:45
#define ASSERTL0(condition, msg)
Definition: ErrorUtil.hpp:216
Base communications class.
Definition: Comm.h:90
CommSharedPtr m_commColumn
Column communicator.
Definition: Comm.h:178
CommSharedPtr m_commRow
Row communicator.
Definition: Comm.h:177
int GetSize() const
Returns number of processes.
Definition: Comm.h:269
int m_size
Number of processes.
Definition: Comm.h:175
std::string m_type
Type of communication.
Definition: Comm.h:176
virtual void v_Block() final
Definition: CommMpi.cpp:162
virtual void v_WaitAll(CommRequestSharedPtr request) final
Definition: CommMpi.cpp:443
virtual bool v_IsSerial() final
Definition: CommMpi.cpp:144
virtual void v_AlltoAll(void *sendbuf, int sendcount, CommDataType sendtype, void *recvbuf, int recvcount, CommDataType recvtype) final
Definition: CommMpi.cpp:262
virtual CommRequestSharedPtr v_CreateRequest(int num) final
Definition: CommMpi.cpp:450
CommMpi(int narg, char *arg[])
Definition: CommMpi.cpp:53
virtual void v_DistGraphCreateAdjacent(int indegree, const int sources[], const int sourceweights[], int reorder) final
Definition: CommMpi.cpp:368
virtual void v_Bcast(void *buffer, int count, CommDataType dt, int root) final
Definition: CommMpi.cpp:316
virtual void v_AllGatherv(void *sendbuf, int sendcount, CommDataType sendtype, void *recvbuf, int recvcounts[], int rdispls[], CommDataType recvtype) final
Definition: CommMpi.cpp:297
virtual void v_SendRecvReplace(void *buf, int count, CommDataType dt, int pSendProc, int pRecvProc) final
Definition: CommMpi.cpp:218
virtual CommSharedPtr v_CommCreateIf(int flag) final
Definition: CommMpi.cpp:488
virtual std::pair< CommSharedPtr, CommSharedPtr > v_SplitCommNode() final
Definition: CommMpi.cpp:508
virtual void v_SendRecv(void *sendbuf, int sendcount, CommDataType sendtype, int dest, void *recvbuf, int recvcount, CommDataType recvtype, int source) final
Definition: CommMpi.cpp:203
virtual double v_Wtime() final
Definition: CommMpi.cpp:170
virtual void v_NeighborAlltoAllv(void *sendbuf, int sendcounts[], int sensdispls[], CommDataType sendtype, void *recvbuf, int recvcounts[], int rdispls[], CommDataType recvtype) final
Definition: CommMpi.cpp:385
virtual void v_AllGather(void *sendbuf, int sendcount, CommDataType sendtype, void *recvbuf, int recvcount, CommDataType recvtype) final
Definition: CommMpi.cpp:288
virtual void v_Recv(void *buf, int count, CommDataType dt, int source) final
Definition: CommMpi.cpp:193
virtual void v_StartAll(CommRequestSharedPtr request) final
Definition: CommMpi.cpp:436
virtual void v_AllReduce(void *buf, int count, CommDataType dt, enum ReduceOperator pOp) final
Definition: CommMpi.cpp:232
virtual void v_Finalise() override
Definition: CommMpi.cpp:112
virtual void v_Irecv(void *buf, int count, CommDataType dt, int source, CommRequestSharedPtr request, int loc) final
Definition: CommMpi.cpp:420
virtual void v_Irsend(void *buf, int count, CommDataType dt, int dest, CommRequestSharedPtr request, int loc) final
Definition: CommMpi.cpp:404
virtual int v_GetRank() final
Definition: CommMpi.cpp:128
virtual void v_Exscan(Array< OneD, unsigned long long > &pData, enum ReduceOperator pOp, Array< OneD, unsigned long long > &ans) final
Definition: CommMpi.cpp:322
virtual void v_Gather(void *sendbuf, int sendcount, CommDataType sendtype, void *recvbuf, int recvcount, CommDataType recvtype, int root) final
Definition: CommMpi.cpp:349
static std::string className
Name of class.
Definition: CommMpi.h:101
virtual void v_RecvInit(void *buf, int count, CommDataType dt, int source, CommRequestSharedPtr request, int loc) final
Definition: CommMpi.cpp:428
static CommSharedPtr create(int narg, char *arg[])
Creates an instance of this class.
Definition: CommMpi.h:95
virtual void v_SplitComm(int pRows, int pColumns) override
Definition: CommMpi.cpp:461
virtual void v_AlltoAllv(void *sendbuf, int sendcounts[], int sensdispls[], CommDataType sendtype, void *recvbuf, int recvcounts[], int rdispls[], CommDataType recvtype) final
Definition: CommMpi.cpp:274
virtual bool v_TreatAsRankZero() final
Definition: CommMpi.cpp:136
virtual void v_Send(void *buf, int count, CommDataType dt, int dest) final
Definition: CommMpi.cpp:178
virtual void v_Scatter(void *sendbuf, int sendcount, CommDataType sendtype, void *recvbuf, int recvcount, CommDataType recvtype, int root) final
Definition: CommMpi.cpp:359
virtual ~CommMpi() override
Definition: CommMpi.cpp:91
virtual std::tuple< int, int, int > v_GetVersion() final
Definition: CommMpi.cpp:149
virtual void v_SendInit(void *buf, int count, CommDataType dt, int dest, CommRequestSharedPtr request, int loc) final
Definition: CommMpi.cpp:412
Class for communicator request type.
Definition: CommMpi.h:62
tKey RegisterCreatorFunction(tKey idKey, CreatorFunction classCreator, std::string pDesc="")
Register a class with the factory.
Definition: NekFactory.hpp:200
unsigned int CommDataType
Definition: CommDataType.h:70
std::shared_ptr< CommRequest > CommRequestSharedPtr
Definition: Comm.h:86
std::shared_ptr< CommRequestMpi > CommRequestMpiSharedPtr
Definition: CommMpi.h:88
CommFactory & GetCommFactory()
ReduceOperator
Type of operation to perform in AllReduce.
Definition: Comm.h:67
std::shared_ptr< Comm > CommSharedPtr
Pointer to a Communicator object.
Definition: Comm.h:54