Nektar++
CommMpi.cpp
///////////////////////////////////////////////////////////////////////////////
//
// File CommMpi.cpp
//
// For more information, please see: http://www.nektar.info
//
// The MIT License
//
// Copyright (c) 2006 Division of Applied Mathematics, Brown University (USA),
// Department of Aeronautics, Imperial College London (UK), and Scientific
// Computing and Imaging Institute, University of Utah (USA).
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included
// in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
// THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
//
// Description: MPI communication implementation
//
///////////////////////////////////////////////////////////////////////////////

#ifdef NEKTAR_USING_PETSC
#include "petscsys.h"
#endif

#include <LibUtilities/Communication/CommMpi.h>
#include <LibUtilities/BasicUtils/ErrorUtil.hpp>

namespace Nektar
{
namespace LibUtilities
{
std::string CommMpi::className = GetCommFactory().RegisterCreatorFunction(
    "ParallelMPI", CommMpi::create, "Parallel communication using MPI.");

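/*
 * Usage sketch (illustrative, not part of this file's interface): the
 * registration above exposes CommMpi through the communication factory
 * under the key "ParallelMPI", so client code would typically obtain a
 * communicator as below. The argc/argv values are assumed to come from
 * main(), and Finalise() is assumed to be the Comm wrapper around
 * v_Finalise().
 *
 * \code
 * #include <LibUtilities/Communication/Comm.h>
 *
 * using namespace Nektar::LibUtilities;
 *
 * CommSharedPtr comm =
 *     GetCommFactory().CreateInstance("ParallelMPI", argc, argv);
 *
 * // ... parallel work using comm->GetRank(), comm->GetSize(), ...
 *
 * comm->Finalise();
 * \endcode
 */
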
/**
 * Default constructor: initialises MPI (and PETSc, if enabled), wraps
 * MPI_COMM_WORLD and records the size and rank of this process.
 */
CommMpi::CommMpi(int narg, char *arg[]) : Comm(narg, arg)
{
    int init = 0;
    MPI_Initialized(&init);
    ASSERTL0(!init, "MPI has already been initialised.");

    int retval = MPI_Init(&narg, &arg);
    if (retval != MPI_SUCCESS)
    {
        ASSERTL0(false, "Failed to initialise MPI");
    }

    m_comm = MPI_COMM_WORLD;
    MPI_Comm_size(m_comm, &m_size);
    MPI_Comm_rank(m_comm, &m_rank);

#ifdef NEKTAR_USING_PETSC
    PetscInitializeNoArguments();
#endif

    m_type = "Parallel MPI";
}

/**
 * Constructor that wraps an existing MPI communicator without
 * initialising MPI itself.
 */
CommMpi::CommMpi(MPI_Comm pComm) : Comm()
{
    m_comm = pComm;
    MPI_Comm_size(m_comm, &m_size);
    MPI_Comm_rank(m_comm, &m_rank);

    m_type = "Parallel MPI";
}

/**
 * Destructor: frees the wrapped communicator unless it is MPI_COMM_WORLD
 * or MPI has already been finalised.
 */
CommMpi::~CommMpi()
{
    int flag;
    MPI_Finalized(&flag);
    if (!flag && m_comm != MPI_COMM_WORLD)
    {
        MPI_Comm_free(&m_comm);
    }
}

/**
 * Returns the underlying MPI communicator.
 */
MPI_Comm CommMpi::GetComm()
{
    return m_comm;
}

/**
 * Finalises PETSc (if enabled) and MPI, unless MPI has already been
 * finalised.
 */
void CommMpi::v_Finalise()
{
#ifdef NEKTAR_USING_PETSC
    PetscFinalize();
#endif
    int flag;
    MPI_Finalized(&flag);
    if (!flag)
    {
        MPI_Finalize();
    }
}

/**
 * Returns the rank of this process within the communicator.
 */
int CommMpi::v_GetRank()
{
    return m_rank;
}

/**
 * Returns true if this process should be treated as rank zero.
 */
bool CommMpi::v_TreatAsRankZero(void)
{
    return m_rank == 0;
}

/**
 * Returns true if the communicator contains a single process.
 */
bool CommMpi::v_IsSerial(void)
{
    return m_size == 1;
}

/**
 * Blocks until all processes in the communicator have reached this
 * point.
 */
void CommMpi::v_Block()
{
    MPI_Barrier(m_comm);
}

/**
 * Returns the MPI wall-clock time.
 */
double CommMpi::v_Wtime()
{
    return MPI_Wtime();
}

/**
 * Sends data to the given destination rank, using a synchronous send if
 * MPISYNC is enabled.
 */
void CommMpi::v_Send(void *buf, int count, CommDataType dt, int dest)
{
    if (MPISYNC)
    {
        MPI_Ssend(buf, count, dt, dest, 0, m_comm);
    }
    else
    {
        MPI_Send(buf, count, dt, dest, 0, m_comm);
    }
}

/**
 * Receives data from the given source rank.
 */
void CommMpi::v_Recv(void *buf, int count, CommDataType dt, int source)
{
    MPI_Recv(buf, count, dt, source, 0, m_comm, MPI_STATUS_IGNORE);
    // ASSERTL0(status.MPI_ERROR == MPI_SUCCESS,
    //          "MPI error receiving data.");
}
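
/*
 * Usage sketch (illustrative; assumes the templated Send/Recv wrappers
 * declared in Comm.h): a matching point-to-point exchange between ranks
 * 0 and 1 that ends up in v_Send/v_Recv above.
 *
 * \code
 * int token = 42;
 * if (comm->GetRank() == 0)
 * {
 *     comm->Send(1, token); // send to rank 1
 * }
 * else if (comm->GetRank() == 1)
 * {
 *     comm->Recv(0, token); // receive from rank 0 into 'token'
 * }
 * \endcode
 */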

/**
 * Performs a combined send to rank @p dest and receive from rank
 * @p source.
 */
void CommMpi::v_SendRecv(void *sendbuf, int sendcount, CommDataType sendtype,
                         int dest, void *recvbuf, int recvcount,
                         CommDataType recvtype, int source)
{
    MPI_Status status;
    int retval = MPI_Sendrecv(sendbuf, sendcount, sendtype, dest, 0, recvbuf,
                              recvcount, recvtype, source, 0, m_comm, &status);

    ASSERTL0(retval == MPI_SUCCESS,
             "MPI error performing send-receive of data.");
}

/**
 * Sends and receives in a single call, replacing the contents of the
 * supplied buffer with the received data.
 */
void CommMpi::v_SendRecvReplace(void *buf, int count, CommDataType dt,
                                int pSendProc, int pRecvProc)
{
    MPI_Status status;
    int retval = MPI_Sendrecv_replace(buf, count, dt, pRecvProc, 0, pSendProc,
                                      0, m_comm, &status);

    ASSERTL0(retval == MPI_SUCCESS,
             "MPI error performing Send-Receive-Replace of data.");
}

/**
 * Reduces the contents of the buffer across all processes in place,
 * using the requested reduction operator (max, min or sum).
 */
void CommMpi::v_AllReduce(void *buf, int count, CommDataType dt,
                          enum ReduceOperator pOp)
{
    if (GetSize() == 1)
    {
        return;
    }

    MPI_Op vOp;
    switch (pOp)
    {
        case ReduceMax:
            vOp = MPI_MAX;
            break;
        case ReduceMin:
            vOp = MPI_MIN;
            break;
        case ReduceSum:
        default:
            vOp = MPI_SUM;
            break;
    }
    int retval = MPI_Allreduce(MPI_IN_PLACE, buf, count, dt, vOp, m_comm);

    ASSERTL0(retval == MPI_SUCCESS, "MPI error performing All-reduce.");
}
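
/*
 * Usage sketch (illustrative; assumes the templated AllReduce wrapper
 * declared in Comm.h): summing one value per rank in place, which
 * reaches v_AllReduce above with pOp == ReduceSum.
 *
 * \code
 * NekDouble local = comm->GetRank() + 1.0;
 * comm->AllReduce(local, ReduceSum);
 * // On P ranks, 'local' now holds 1 + 2 + ... + P on every process.
 * \endcode
 */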

/**
 * Performs an all-to-all exchange of fixed-size blocks between all
 * processes.
 */
void CommMpi::v_AlltoAll(void *sendbuf, int sendcount, CommDataType sendtype,
                         void *recvbuf, int recvcount, CommDataType recvtype)
{
    int retval = MPI_Alltoall(sendbuf, sendcount, sendtype, recvbuf, recvcount,
                              recvtype, m_comm);

    ASSERTL0(retval == MPI_SUCCESS, "MPI error performing All-to-All.");
}

/**
 * Performs an all-to-all exchange with per-process counts and
 * displacements.
 */
void CommMpi::v_AlltoAllv(void *sendbuf, int sendcounts[], int sdispls[],
                          CommDataType sendtype, void *recvbuf,
                          int recvcounts[], int rdispls[],
                          CommDataType recvtype)
{
    int retval = MPI_Alltoallv(sendbuf, sendcounts, sdispls, sendtype, recvbuf,
                               recvcounts, rdispls, recvtype, m_comm);

    ASSERTL0(retval == MPI_SUCCESS, "MPI error performing All-to-All-v.");
}

/**
 * Gathers a fixed-size block from every process onto every process.
 */
void CommMpi::v_AllGather(void *sendbuf, int sendcount, CommDataType sendtype,
                          void *recvbuf, int recvcount, CommDataType recvtype)
{
    int retval = MPI_Allgather(sendbuf, sendcount, sendtype, recvbuf, recvcount,
                               recvtype, m_comm);

    ASSERTL0(retval == MPI_SUCCESS, "MPI error performing Allgather.");
}

/**
 * Gathers variable-size blocks from every process onto every process.
 */
void CommMpi::v_AllGatherv(void *sendbuf, int sendcount, CommDataType sendtype,
                           void *recvbuf, int recvcounts[], int rdispls[],
                           CommDataType recvtype)
{
    int retval = MPI_Allgatherv(sendbuf, sendcount, sendtype, recvbuf,
                                recvcounts, rdispls, recvtype, m_comm);

    ASSERTL0(retval == MPI_SUCCESS, "MPI error performing Allgatherv.");
}

/**
 * In-place variant: each process contributes its own segment of the
 * receive buffer.
 */
void CommMpi::v_AllGatherv(void *recvbuf, int recvcounts[], int rdispls[],
                           CommDataType recvtype)
{
    int retval = MPI_Allgatherv(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, recvbuf,
                                recvcounts, rdispls, recvtype, m_comm);

    ASSERTL0(retval == MPI_SUCCESS, "MPI error performing Allgatherv.");
}

/**
 * Broadcasts data from the root rank to all processes.
 */
void CommMpi::v_Bcast(void *buffer, int count, CommDataType dt, int root)
{
    int retval = MPI_Bcast(buffer, count, dt, root, m_comm);
    ASSERTL0(retval == MPI_SUCCESS, "MPI error performing Bcast.");
}

/**
 * Exclusive prefix reduction: each rank receives the reduction of the
 * values held by all lower-ranked processes.
 */
void CommMpi::v_Exscan(Array<OneD, unsigned long long> &pData,
                       const enum ReduceOperator pOp,
                       Array<OneD, unsigned long long> &ans)
{
    int n = pData.num_elements();
    ASSERTL0(n == ans.num_elements(), "Array sizes differ in Exscan");

    MPI_Op vOp;
    switch (pOp)
    {
        case ReduceMax:
            vOp = MPI_MAX;
            break;
        case ReduceMin:
            vOp = MPI_MIN;
            break;
        case ReduceSum:
        default:
            vOp = MPI_SUM;
            break;
    }

    int retval = MPI_Exscan(pData.get(), ans.get(), n, MPI_UNSIGNED_LONG_LONG,
                            vOp, m_comm);
    ASSERTL0(retval == MPI_SUCCESS, "MPI error performing Exscan.");
}
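
/*
 * Worked example: with pOp == ReduceSum, Exscan gives each rank the sum
 * of the values held by all lower-ranked processes (an exclusive prefix
 * sum). For per-rank counts {3, 5, 2, 4} on ranks 0..3 the results are
 * {-, 3, 8, 10}; rank 0's output buffer is left undefined by MPI_Exscan
 * and is conventionally treated as 0.
 */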

/**
 * Gathers data from all processes onto the root rank.
 */
void CommMpi::v_Gather(void *sendbuf, int sendcount, CommDataType sendtype,
                       void *recvbuf, int recvcount, CommDataType recvtype,
                       int root)
{
    int retval = MPI_Gather(sendbuf, sendcount, sendtype, recvbuf, recvcount,
                            recvtype, root, m_comm);

    ASSERTL0(retval == MPI_SUCCESS, "MPI error performing Gather.");
}

/**
 * Scatters data from the root rank to all processes.
 */
void CommMpi::v_Scatter(void *sendbuf, int sendcount, CommDataType sendtype,
                        void *recvbuf, int recvcount, CommDataType recvtype,
                        int root)
{
    int retval = MPI_Scatter(sendbuf, sendcount, sendtype, recvbuf, recvcount,
                             recvtype, root, m_comm);
    ASSERTL0(retval == MPI_SUCCESS, "MPI error performing Scatter.");
}

/**
 * Processes are arranged as a grid of size pRows x pColumns. Comm
 * objects are created corresponding to the rows and columns of this
 * grid. The row and column to which this process belongs are stored in
 * #m_commRow and #m_commColumn.
 */
void CommMpi::v_SplitComm(int pRows, int pColumns)
{
    ASSERTL0(pRows * pColumns == m_size,
             "Rows/Columns do not match comm size.");

    MPI_Comm newComm;

    // Compute row and column in grid.
    int myCol = m_rank % pColumns;
    int myRow = (m_rank - myCol) / pColumns;

    // Split Comm into rows - all processes with the same myRow are put in
    // the same communicator. The rank within this communicator is the
    // column index.
    MPI_Comm_split(m_comm, myRow, myCol, &newComm);
    m_commRow = std::shared_ptr<Comm>(new CommMpi(newComm));

    // Split Comm into columns - all processes with the same myCol are put
    // in the same communicator. The rank within this communicator is
    // the row index.
    MPI_Comm_split(m_comm, myCol, myRow, &newComm);
    m_commColumn = std::shared_ptr<Comm>(new CommMpi(newComm));
}
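
/*
 * Worked example: for a 2 x 3 grid (pRows = 2, pColumns = 3) the mapping
 * above arranges the six ranks as
 *
 *            col 0  col 1  col 2
 *   row 0:     0      1      2
 *   row 1:     3      4      5
 *
 * so rank 4 has myCol = 4 % 3 = 1 and myRow = (4 - 1) / 3 = 1, giving it
 * rank 1 within its row communicator and rank 1 within its column
 * communicator.
 */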

/**
 * Creates a new communicator containing this process if the flag is
 * non-zero; otherwise a null pointer is returned.
 */
CommSharedPtr CommMpi::v_CommCreateIf(int flag)
{
    MPI_Comm newComm;
    // color == MPI_UNDEFINED => not in the new communicator.
    // key == 0 on all ranks => order them by their rank in the parent
    // communicator. Open MPI, at least, suggests this is faster than
    // ordering them ourselves.
    MPI_Comm_split(m_comm, flag ? 0 : MPI_UNDEFINED, 0, &newComm);

    if (flag == 0)
    {
        // flag == 0 => we got back MPI_COMM_NULL; return a null pointer
        // instead.
        return std::shared_ptr<Comm>();
    }
    else
    {
        // Return a real communicator.
        return std::shared_ptr<Comm>(new CommMpi(newComm));
    }
}
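
/*
 * Usage sketch (illustrative; assumes the CommCreateIf wrapper declared
 * in Comm.h): building a sub-communicator containing only the
 * even-numbered ranks. Odd ranks receive a null pointer and skip the
 * collective work.
 *
 * \code
 * CommSharedPtr sub = comm->CommCreateIf(comm->GetRank() % 2 == 0 ? 1 : 0);
 * if (sub)
 * {
 *     // Collective operations on 'sub' involve only the even ranks.
 * }
 * \endcode
 */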
} // namespace LibUtilities
} // namespace Nektar