Nektar++
CommMpi.cpp
1///////////////////////////////////////////////////////////////////////////////
2//
3// File: CommMpi.cpp
4//
5// For more information, please see: http://www.nektar.info
6//
7// The MIT License
8//
9// Copyright (c) 2006 Division of Applied Mathematics, Brown University (USA),
10// Department of Aeronautics, Imperial College London (UK), and Scientific
11// Computing and Imaging Institute, University of Utah (USA).
12//
13// Permission is hereby granted, free of charge, to any person obtaining a
14// copy of this software and associated documentation files (the "Software"),
15// to deal in the Software without restriction, including without limitation
16// the rights to use, copy, modify, merge, publish, distribute, sublicense,
17// and/or sell copies of the Software, and to permit persons to whom the
18// Software is furnished to do so, subject to the following conditions:
19//
20// The above copyright notice and this permission notice shall be included
21// in all copies or substantial portions of the Software.
22//
23// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
24// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
25// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
26// THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
27// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
28// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
29// DEALINGS IN THE SOFTWARE.
30//
31// Description: MPI communication implementation
32//
33///////////////////////////////////////////////////////////////////////////////
34
35#ifdef NEKTAR_USING_PETSC
36#include "petscsys.h"
37#endif
38
39#include <LibUtilities/Communication/CommMpi.h>
40
41namespace Nektar::LibUtilities
42{
43
44std::string CommMpi::className = GetCommFactory().RegisterCreatorFunction(
45 "ParallelMPI", CommMpi::create, "Parallel communication using MPI.");
46
47/**
48 *
49 */
50CommMpi::CommMpi(int narg, char *arg[]) : Comm(narg, arg)
51{
52 int init = 0;
53 MPI_Initialized(&init);
54
55 if (!init)
56 {
57 int thread_support = 0;
58 if (MPI_Init_thread(&narg, &arg, MPI_THREAD_MULTIPLE,
59 &thread_support) != MPI_SUCCESS)
60 {
61            NEKERROR(
62                ErrorUtil::ewarning,
63                "Initializing MPI using MPI_Init, if scotch version > 6 and is "
64                "compiled with multi-threading, it might cause deadlocks.")
65 ASSERTL0(MPI_Init(&narg, &arg) == MPI_SUCCESS,
66 "Failed to initialise MPI");
67 }
68 // store bool to indicate that Nektar++ is in charge of finalizing MPI.
69 m_controls_mpi = true;
70 }
71 else
72 {
73        // Another application is in charge of finalizing MPI; this is not
74        // the responsibility of Nektar++.
75 m_controls_mpi = false;
76 }
77
78 m_comm = MPI_COMM_WORLD;
79 MPI_Comm_size(m_comm, &m_size);
80 MPI_Comm_rank(m_comm, &m_rank);
81
82#ifdef NEKTAR_USING_PETSC
83 PetscInitializeNoArguments();
84#endif
85
86 m_type = "Parallel MPI";
87}
88
89/**
90 *
91 */
92CommMpi::CommMpi(MPI_Comm pComm) : Comm()
93{
94 m_comm = pComm;
95 MPI_Comm_size(m_comm, &m_size);
96 MPI_Comm_rank(m_comm, &m_rank);
97
98 m_type = "Parallel MPI";
99}
100
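The two constructors above cover two ownership models: CommMpi(int, char *[]) initialises MPI and records, via m_controls_mpi, that Nektar++ must later finalise it, while CommMpi(MPI_Comm) wraps a communicator owned by an embedding application and leaves finalisation to that application. A minimal sketch of the first, factory-driven path, assuming the usual NekFactory::CreateInstance interface and the public Comm::Finalise() wrapper declared outside this file:

#include <LibUtilities/Communication/Comm.h>

int main(int argc, char *argv[])
{
    using namespace Nektar::LibUtilities;

    // "ParallelMPI" is the key registered by CommMpi::className above; the
    // factory forwards (argc, argv) to CommMpi(int narg, char *arg[]), which
    // calls MPI_Init_thread and records via m_controls_mpi that Nektar++
    // owns the MPI lifetime.
    CommSharedPtr comm =
        GetCommFactory().CreateInstance("ParallelMPI", argc, argv);

    // ... parallel work through the Comm interface ...

    // Dispatches to v_Finalise(); MPI_Finalize() runs only because Nektar++
    // initialised MPI itself.
    comm->Finalise();
    return 0;
}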
101/**
102 *
103 */
104CommMpi::~CommMpi()
105{
106 int flag;
107 MPI_Finalized(&flag);
108 if (!flag && m_comm != MPI_COMM_WORLD)
109 {
110 MPI_Comm_free(&m_comm);
111 }
112}
113
114/**
115 *
116 */
117MPI_Comm CommMpi::GetComm()
118{
119 return m_comm;
120}
121
122/**
123 *
124 */
125void CommMpi::v_Finalise()
126{
127#ifdef NEKTAR_USING_PETSC
128 PetscFinalize();
129#endif
130 int flag;
131 MPI_Finalized(&flag);
132 if ((!flag) && m_controls_mpi)
133 {
134 MPI_Finalize();
135 }
136}
137
138/**
139 *
140 */
141int CommMpi::v_GetRank()
142{
143 return m_rank;
144}
145
146/**
147 *
148 */
149bool CommMpi::v_TreatAsRankZero()
150{
151 return m_rank == 0;
152}
153
154/**
155 *
156 */
157bool CommMpi::v_IsSerial()
158{
159 return m_size == 1;
160}
161
162/**
163 *
164 */
165std::tuple<int, int, int> CommMpi::v_GetVersion()
166{
167 int version, subversion;
168 int retval = MPI_Get_version(&version, &subversion);
169
170 ASSERTL0(retval == MPI_SUCCESS, "MPI error performing GetVersion.");
171
172 return std::make_tuple(version, subversion, 0);
173}
174
175/**
176 *
177 */
178void CommMpi::v_Block()
179{
180 MPI_Barrier(m_comm);
181}
182
183/**
184 *
185 */
186NekDouble CommMpi::v_Wtime()
187{
188 return MPI_Wtime();
189}
190
191/**
192 *
193 */
194void CommMpi::v_Send(void *buf, int count, CommDataType dt, int dest)
195{
196 if (MPISYNC)
197 {
198 MPI_Ssend(buf, count, dt, dest, 0, m_comm);
199 }
200 else
201 {
202 MPI_Send(buf, count, dt, dest, 0, m_comm);
203 }
204}
205
206/**
207 *
208 */
209void CommMpi::v_Recv(void *buf, int count, CommDataType dt, int source)
210{
211 MPI_Recv(buf, count, dt, source, 0, m_comm, MPI_STATUS_IGNORE);
212}
213
214/**
215 *
216 */
217void CommMpi::v_SendRecv(void *sendbuf, int sendcount, CommDataType sendtype,
218 int dest, void *recvbuf, int recvcount,
219 CommDataType recvtype, int source)
220{
221 MPI_Status status;
222 int retval = MPI_Sendrecv(sendbuf, sendcount, sendtype, dest, 0, recvbuf,
223 recvcount, recvtype, source, 0, m_comm, &status);
224
225 ASSERTL0(retval == MPI_SUCCESS,
226 "MPI error performing send-receive of data.");
227}
228
229/**
230 *
231 */
232void CommMpi::v_AllReduce(void *buf, int count, CommDataType dt,
233 enum ReduceOperator pOp)
234{
235 if (GetSize() == 1)
236 {
237 return;
238 }
239
240 MPI_Op vOp;
241 switch (pOp)
242 {
243 case ReduceMax:
244 vOp = MPI_MAX;
245 break;
246 case ReduceMin:
247 vOp = MPI_MIN;
248 break;
249 case ReduceSum:
250 default:
251 vOp = MPI_SUM;
252 break;
253 }
254 int retval = MPI_Allreduce(MPI_IN_PLACE, buf, count, dt, vOp, m_comm);
255
256 ASSERTL0(retval == MPI_SUCCESS, "MPI error performing All-reduce.");
257}
258
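v_AllReduce above maps the ReduceOperator enum onto MPI_MAX/MPI_MIN/MPI_SUM and reduces in place. A stand-alone sketch of the same in-place pattern in plain MPI; the helper name is illustrative only:

#include <mpi.h>
#include <vector>

// Stand-alone equivalent of v_AllReduce(buf, count, dt, ReduceSum): each
// rank contributes its local values and receives the global sums in the
// same buffer, which is what MPI_IN_PLACE expresses.
void allreduce_sum_inplace(std::vector<double> &local, MPI_Comm comm)
{
    int size = 1;
    MPI_Comm_size(comm, &size);
    if (size == 1)
    {
        return; // mirrors the GetSize() == 1 early return above
    }
    MPI_Allreduce(MPI_IN_PLACE, local.data(), static_cast<int>(local.size()),
                  MPI_DOUBLE, MPI_SUM, comm);
}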
259/**
260 *
261 */
262void CommMpi::v_AlltoAll(void *sendbuf, int sendcount, CommDataType sendtype,
263 void *recvbuf, int recvcount, CommDataType recvtype)
264{
265 int retval = MPI_Alltoall(sendbuf, sendcount, sendtype, recvbuf, recvcount,
266 recvtype, m_comm);
267
268 ASSERTL0(retval == MPI_SUCCESS, "MPI error performing All-to-All.");
269}
270
271/**
272 *
273 */
274void CommMpi::v_AlltoAllv(void *sendbuf, int sendcounts[], int sdispls[],
275 CommDataType sendtype, void *recvbuf,
276 int recvcounts[], int rdispls[],
277 CommDataType recvtype)
278{
279 int retval = MPI_Alltoallv(sendbuf, sendcounts, sdispls, sendtype, recvbuf,
280 recvcounts, rdispls, recvtype, m_comm);
281
282 ASSERTL0(retval == MPI_SUCCESS, "MPI error performing All-to-All-v.");
283}
284
285/**
286 *
287 */
288void CommMpi::v_AllGather(void *sendbuf, int sendcount, CommDataType sendtype,
289 void *recvbuf, int recvcount, CommDataType recvtype)
290{
291 int retval = MPI_Allgather(sendbuf, sendcount, sendtype, recvbuf, recvcount,
292 recvtype, m_comm);
293
294 ASSERTL0(retval == MPI_SUCCESS, "MPI error performing Allgather.");
295}
296
297/**
298 *
299 */
300void CommMpi::v_AllGatherv(void *sendbuf, int sendcount, CommDataType sendtype,
301 void *recvbuf, int recvcounts[], int rdispls[],
302 CommDataType recvtype)
303{
304 int retval = MPI_Allgatherv(sendbuf, sendcount, sendtype, recvbuf,
305 recvcounts, rdispls, recvtype, m_comm);
306
307 ASSERTL0(retval == MPI_SUCCESS, "MPI error performing Allgatherv.");
308}
309
310/**
311 *
312 */
313void CommMpi::v_AllGatherv(void *recvbuf, int recvcounts[], int rdispls[],
314 CommDataType recvtype)
315{
316 int retval = MPI_Allgatherv(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, recvbuf,
317 recvcounts, rdispls, recvtype, m_comm);
318
319 ASSERTL0(retval == MPI_SUCCESS, "MPI error performing Allgatherv.");
320}
321
322/**
323 *
324 */
325void CommMpi::v_Bcast(void *buffer, int count, CommDataType dt, int root)
326{
327 int retval = MPI_Bcast(buffer, count, dt, root, m_comm);
328
329    ASSERTL0(retval == MPI_SUCCESS, "MPI error performing Bcast.");
330}
331
332/**
333 *
334 */
335void CommMpi::v_Gather(void *sendbuf, int sendcount, CommDataType sendtype,
336 void *recvbuf, int recvcount, CommDataType recvtype,
337 int root)
338{
339 int retval = MPI_Gather(sendbuf, sendcount, sendtype, recvbuf, recvcount,
340 recvtype, root, m_comm);
341
342 ASSERTL0(retval == MPI_SUCCESS, "MPI error performing Gather.");
343}
344
345/**
346 *
347 */
348void CommMpi::v_Scatter(void *sendbuf, int sendcount, CommDataType sendtype,
349 void *recvbuf, int recvcount, CommDataType recvtype,
350 int root)
351{
352 int retval = MPI_Scatter(sendbuf, sendcount, sendtype, recvbuf, recvcount,
353 recvtype, root, m_comm);
354
355 ASSERTL0(retval == MPI_SUCCESS, "MPI error performing Scatter.");
356}
357
358/**
359 *
360 */
361void CommMpi::v_DistGraphCreateAdjacent(
362    [[maybe_unused]] int indegree, [[maybe_unused]] const int sources[],
363 [[maybe_unused]] const int sourceweights[], [[maybe_unused]] int reorder)
364{
365#if MPI_VERSION < 3
366 ASSERTL0(false, "MPI_Dist_graph_create_adjacent is not supported in your "
367 "installed MPI version.");
368#else
369 int retval = MPI_Dist_graph_create_adjacent(
370 m_comm, indegree, sources, sourceweights, indegree, sources,
371 sourceweights, MPI_INFO_NULL, reorder, &m_comm);
372
373 ASSERTL0(retval == MPI_SUCCESS,
374 "MPI error performing Dist_graph_create_adjacent.")
375#endif
376}
377
378/**
379 *
380 */
381void CommMpi::v_NeighborAlltoAllv(
382    [[maybe_unused]] void *sendbuf, [[maybe_unused]] int sendcounts[],
383 [[maybe_unused]] int sdispls[], [[maybe_unused]] CommDataType sendtype,
384 [[maybe_unused]] void *recvbuf, [[maybe_unused]] int recvcounts[],
385 [[maybe_unused]] int rdispls[], [[maybe_unused]] CommDataType recvtype)
386{
387#if MPI_VERSION < 3
388 ASSERTL0(false, "MPI_Neighbor_alltoallv is not supported in your "
389 "installed MPI version.");
390#else
391 int retval =
392 MPI_Neighbor_alltoallv(sendbuf, sendcounts, sdispls, sendtype, recvbuf,
393 recvcounts, rdispls, recvtype, m_comm);
394
395 ASSERTL0(retval == MPI_SUCCESS, "MPI error performing NeighborAllToAllV.");
396#endif
397}
398
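The two MPI-3 methods above are normally used together: v_DistGraphCreateAdjacent fixes each rank's neighbour list, and v_NeighborAlltoAllv then exchanges data only along those edges. A plain-MPI sketch for a periodic ring; the helper name is illustrative only, and unlike v_DistGraphCreateAdjacent it keeps the graph communicator separate instead of replacing m_comm:

#include <mpi.h>
#include <vector>

// Build an adjacent distributed graph (a periodic ring, so in- and
// out-neighbours coincide), then exchange one value per neighbour with a
// neighbourhood all-to-all. Requires MPI-3.
void ring_exchange(MPI_Comm comm)
{
    int rank = 0, size = 0;
    MPI_Comm_rank(comm, &rank);
    MPI_Comm_size(comm, &size);

    const int nbrs[2]    = {(rank - 1 + size) % size, (rank + 1) % size};
    const int weights[2] = {1, 1};

    MPI_Comm graphComm;
    MPI_Dist_graph_create_adjacent(comm, 2, nbrs, weights, 2, nbrs, weights,
                                   MPI_INFO_NULL, 1, &graphComm);

    std::vector<double> sendbuf(2, static_cast<double>(rank)), recvbuf(2, 0.0);
    const int counts[2] = {1, 1};
    const int displs[2] = {0, 1};
    MPI_Neighbor_alltoallv(sendbuf.data(), counts, displs, MPI_DOUBLE,
                           recvbuf.data(), counts, displs, MPI_DOUBLE,
                           graphComm);

    MPI_Comm_free(&graphComm);
}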
399/**
400 *
401 */
402void CommMpi::v_Irsend(void *buf, int count, CommDataType dt, int dest,
403 CommRequestSharedPtr request, int loc)
404{
405    CommRequestMpiSharedPtr req =
406        std::static_pointer_cast<CommRequestMpi>(request);
407 MPI_Irsend(buf, count, dt, dest, 0, m_comm, req->GetRequest(loc));
408}
409
410/**
411 *
412 */
413void CommMpi::v_Isend(void *buf, int count, CommDataType dt, int dest,
414 CommRequestSharedPtr request, int loc)
415{
416    CommRequestMpiSharedPtr req =
417        std::static_pointer_cast<CommRequestMpi>(request);
418 MPI_Isend(buf, count, dt, dest, 0, m_comm, req->GetRequest(loc));
419}
420
421/**
422 *
423 */
424void CommMpi::v_SendInit(void *buf, int count, CommDataType dt, int dest,
425 CommRequestSharedPtr request, int loc)
426{
427    CommRequestMpiSharedPtr req =
428        std::static_pointer_cast<CommRequestMpi>(request);
429 MPI_Send_init(buf, count, dt, dest, 0, m_comm, req->GetRequest(loc));
430}
431
432/**
433 *
434 */
435void CommMpi::v_Irecv(void *buf, int count, CommDataType dt, int source,
436 CommRequestSharedPtr request, int loc)
437{
438    CommRequestMpiSharedPtr req =
439        std::static_pointer_cast<CommRequestMpi>(request);
440 MPI_Irecv(buf, count, dt, source, 0, m_comm, req->GetRequest(loc));
441}
442
443/**
444 *
445 */
446void CommMpi::v_RecvInit(void *buf, int count, CommDataType dt, int source,
447 CommRequestSharedPtr request, int loc)
448{
449    CommRequestMpiSharedPtr req =
450        std::static_pointer_cast<CommRequestMpi>(request);
451 MPI_Recv_init(buf, count, dt, source, 0, m_comm, req->GetRequest(loc));
452}
453
454/**
455 *
456 */
457void CommMpi::v_StartAll(CommRequestSharedPtr request)
458{
459    CommRequestMpiSharedPtr req =
460        std::static_pointer_cast<CommRequestMpi>(request);
461 if (req->GetNumRequest() != 0)
462 {
463 MPI_Startall(req->GetNumRequest(), req->GetRequest(0));
464 }
465}
466
467/**
468 *
469 */
470void CommMpi::v_WaitAll(CommRequestSharedPtr request)
471{
472    CommRequestMpiSharedPtr req =
473        std::static_pointer_cast<CommRequestMpi>(request);
474 if (req->GetNumRequest() != 0)
475 {
476 MPI_Waitall(req->GetNumRequest(), req->GetRequest(0),
477 MPI_STATUSES_IGNORE);
478 }
479}
480
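v_SendInit and v_RecvInit create persistent requests that v_StartAll and v_WaitAll then start and complete on every exchange, following the standard MPI persistent-communication lifecycle. A stand-alone sketch between ranks 0 and 1; the helper name is illustrative only:

#include <mpi.h>

// Set the requests up once (Send_init/Recv_init), then start and complete
// them repeatedly (Startall/Waitall), exactly as callers combine
// v_SendInit/v_RecvInit with v_StartAll and v_WaitAll above.
void persistent_ping(MPI_Comm comm)
{
    int rank = 0, size = 0;
    MPI_Comm_rank(comm, &rank);
    MPI_Comm_size(comm, &size);
    if (size < 2 || rank > 1)
    {
        return; // only ranks 0 and 1 take part in this sketch
    }

    const int partner = 1 - rank;
    double sendval = static_cast<double>(rank), recvval = -1.0;

    MPI_Request reqs[2];
    MPI_Send_init(&sendval, 1, MPI_DOUBLE, partner, 0, comm, &reqs[0]);
    MPI_Recv_init(&recvval, 1, MPI_DOUBLE, partner, 0, comm, &reqs[1]);

    for (int step = 0; step < 3; ++step)
    {
        MPI_Startall(2, reqs); // start both persistent requests
        MPI_Waitall(2, reqs, MPI_STATUSES_IGNORE);
    }

    MPI_Request_free(&reqs[0]);
    MPI_Request_free(&reqs[1]);
}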
481/**
482 *
483 */
484CommRequestSharedPtr CommMpi::v_CreateRequest(int num)
485{
486 return std::shared_ptr<CommRequest>(new CommRequestMpi(num));
487}
488
489/**
490 * Processes are considered as a grid of size pRows*pColumns. Comm
491 * objects are created corresponding to the rows and columns of this
492 * grid. The row and column to which this process belongs is stored in
493 * #m_commRow and #m_commColumn.
494 */
495void CommMpi::v_SplitComm(int pRows, int pColumns, int pTime)
496{
497 ASSERTL0(pRows * pColumns * pTime == m_size,
498 "Rows/Columns/Time do not match comm size.");
499
500 MPI_Comm newComm;
501 MPI_Comm gridComm;
502 if (pTime == 1)
503 {
504 // Compute row and column in grid.
505 int myCol = m_rank % pColumns;
506 int myRow = (m_rank - myCol) / pColumns;
507
508 // Split Comm into rows - all processes with same myRow are put in
509 // the same communicator. The rank within this communicator is the
510 // column index.
511 MPI_Comm_split(m_comm, myRow, myCol, &newComm);
512 m_commRow = std::shared_ptr<Comm>(new CommMpi(newComm));
513
514 // Split Comm into columns - all processes with same myCol are put
515 // in the same communicator. The rank within this communicator is
516 // the row index.
517 MPI_Comm_split(m_comm, myCol, myRow, &newComm);
518 m_commColumn = std::shared_ptr<Comm>(new CommMpi(newComm));
519 }
520 else
521 {
522 constexpr int dims = 3;
523 const int sizes[dims] = {pRows, pColumns, pTime};
524 const int periods[dims] = {0, 0, 0};
525 constexpr int reorder = 1;
526
527 MPI_Cart_create(m_comm, dims, sizes, periods, reorder, &gridComm);
528
529 constexpr int keepRow[dims] = {0, 1, 0};
530 MPI_Cart_sub(gridComm, keepRow, &newComm);
531 m_commRow = std::shared_ptr<Comm>(new CommMpi(newComm));
532
533 constexpr int keepCol[dims] = {1, 0, 0};
534 MPI_Cart_sub(gridComm, keepCol, &newComm);
535 m_commColumn = std::shared_ptr<Comm>(new CommMpi(newComm));
536
537 constexpr int keepTime[dims] = {0, 0, 1};
538 MPI_Cart_sub(gridComm, keepTime, &newComm);
539 m_commTime = std::shared_ptr<Comm>(new CommMpi(newComm));
540
541 constexpr int keepSpace[dims] = {1, 1, 0};
542 MPI_Cart_sub(gridComm, keepSpace, &newComm);
543 m_commSpace = std::shared_ptr<Comm>(new CommMpi(newComm));
544 }
545}
546
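For the pTime == 1 branch, the row/column arithmetic and the two MPI_Comm_split calls can be checked with a small stand-alone example: with pRows = 2 and pColumns = 3 on 6 ranks, rank r lands in row r / 3 and column r % 3, so ranks {0, 1, 2} share a row communicator and ranks {0, 3} a column communicator. The helper name and fixed 2 x 3 layout below are illustrative only:

#include <mpi.h>

// Stand-alone version of the pTime == 1 branch above for a 2 x 3 grid.
void split_2x3(MPI_Comm comm, MPI_Comm *rowComm, MPI_Comm *colComm)
{
    const int pColumns = 3;
    int rank = 0;
    MPI_Comm_rank(comm, &rank);

    const int myCol = rank % pColumns;
    const int myRow = (rank - myCol) / pColumns;

    // Same colour => same sub-communicator; the key fixes the rank order
    // inside it (column index within a row, row index within a column).
    MPI_Comm_split(comm, myRow, myCol, rowComm);
    MPI_Comm_split(comm, myCol, myRow, colComm);
}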
547/**
548 * Create a new communicator if the flag is non-zero.
549 */
550CommSharedPtr CommMpi::v_CommCreateIf(int flag)
551{
552 MPI_Comm newComm;
553    // color == MPI_UNDEFINED => not in the new communicator
554    // key == 0 on all => use rank to order them. Open MPI, at least,
555    // suggests this is faster than ordering them ourselves.
556 MPI_Comm_split(m_comm, flag ? flag : MPI_UNDEFINED, 0, &newComm);
557
558 if (flag == 0)
559 {
560 // flag == 0 => get back MPI_COMM_NULL, return a null ptr instead.
561 return std::shared_ptr<Comm>();
562 }
563 else
564 {
565 // Return a real communicator
566 return std::shared_ptr<Comm>(new CommMpi(newComm));
567 }
568}
569
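The colour convention used in v_CommCreateIf can be summarised as: flag == 0 maps to MPI_UNDEFINED, so those ranks get MPI_COMM_NULL (and a null CommSharedPtr above), while ranks passing the same non-zero flag share one new communicator. A plain-MPI sketch; the helper name is illustrative only:

#include <mpi.h>

// Plain-MPI counterpart of v_CommCreateIf.
MPI_Comm comm_create_if(MPI_Comm comm, int flag)
{
    MPI_Comm newComm = MPI_COMM_NULL;
    MPI_Comm_split(comm, flag ? flag : MPI_UNDEFINED, 0, &newComm);
    return newComm; // MPI_COMM_NULL on ranks where flag == 0
}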
570/**
571 *
572 */
573std::pair<CommSharedPtr, CommSharedPtr> CommMpi::v_SplitCommNode()
574{
575 std::pair<CommSharedPtr, CommSharedPtr> ret;
576
577#if MPI_VERSION < 3
578 ASSERTL0(false, "Not implemented for non-MPI-3 versions.");
579#else
580 // Create an intra-node communicator.
581 MPI_Comm nodeComm;
582 MPI_Comm_split_type(MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, m_rank,
583 MPI_INFO_NULL, &nodeComm);
584
585 // For rank 0 of the intra-node communicator, split the main
586 // communicator. Everyone else will get a null communicator.
587 ret.first = std::shared_ptr<Comm>(new CommMpi(nodeComm));
588 ret.second = CommMpi::v_CommCreateIf(ret.first->GetRank() == 0);
589 if (ret.first->GetRank() == 0)
590 {
591 ret.second->SplitComm(1, ret.second->GetSize());
592 }
593#endif
594
595 return ret;
596}
597
598} // namespace Nektar::LibUtilities
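v_SplitCommNode combines a shared-memory (intra-node) split with a second split that keeps only each node's local rank 0. The same two-step pattern in plain MPI, requiring MPI-3; the helper name is illustrative only and is not part of this file:

#include <mpi.h>

// First an intra-node communicator via MPI_Comm_split_type, then an
// inter-node "leader" communicator containing only each node's local rank 0.
void split_node(MPI_Comm comm, MPI_Comm *nodeComm, MPI_Comm *leaderComm)
{
    int rank = 0;
    MPI_Comm_rank(comm, &rank);

    MPI_Comm_split_type(comm, MPI_COMM_TYPE_SHARED, rank, MPI_INFO_NULL,
                        nodeComm);

    int nodeRank = 0;
    MPI_Comm_rank(*nodeComm, &nodeRank);

    // Local rank 0 on each node joins the leader communicator; everyone
    // else receives MPI_COMM_NULL.
    MPI_Comm_split(comm, nodeRank == 0 ? 0 : MPI_UNDEFINED, rank, leaderComm);
}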
#define MPISYNC
Definition: CommMpi.h:45
#define ASSERTL0(condition, msg)
Definition: ErrorUtil.hpp:208
#define NEKERROR(type, msg)
Assert Level 0 – Fundamental assert which is used whether in FULLDEBUG, DEBUG or OPT compilation mode...
Definition: ErrorUtil.hpp:202
Base communications class.
Definition: Comm.h:88
CommSharedPtr m_commColumn
Column communicator.
Definition: Comm.h:178
CommSharedPtr m_commRow
Row communicator.
Definition: Comm.h:177
int GetSize() const
Returns number of processes.
Definition: Comm.h:264
CommSharedPtr m_commTime
Definition: Comm.h:179
int m_size
Number of processes.
Definition: Comm.h:175
std::string m_type
Type of communication.
Definition: Comm.h:176
CommSharedPtr m_commSpace
Definition: Comm.h:180
void v_WaitAll(CommRequestSharedPtr request) final
Definition: CommMpi.cpp:470
void v_AlltoAll(void *sendbuf, int sendcount, CommDataType sendtype, void *recvbuf, int recvcount, CommDataType recvtype) final
Definition: CommMpi.cpp:262
CommRequestSharedPtr v_CreateRequest(int num) final
Definition: CommMpi.cpp:484
CommMpi(int narg, char *arg[])
Definition: CommMpi.cpp:50
void v_DistGraphCreateAdjacent(int indegree, const int sources[], const int sourceweights[], int reorder) final
Definition: CommMpi.cpp:361
void v_Bcast(void *buffer, int count, CommDataType dt, int root) final
Definition: CommMpi.cpp:325
void v_AllGatherv(void *sendbuf, int sendcount, CommDataType sendtype, void *recvbuf, int recvcounts[], int rdispls[], CommDataType recvtype) final
Definition: CommMpi.cpp:300
CommSharedPtr v_CommCreateIf(int flag) final
Definition: CommMpi.cpp:550
std::pair< CommSharedPtr, CommSharedPtr > v_SplitCommNode() final
Definition: CommMpi.cpp:573
void v_SendRecv(void *sendbuf, int sendcount, CommDataType sendtype, int dest, void *recvbuf, int recvcount, CommDataType recvtype, int source) final
Definition: CommMpi.cpp:217
void v_NeighborAlltoAllv(void *sendbuf, int sendcounts[], int sensdispls[], CommDataType sendtype, void *recvbuf, int recvcounts[], int rdispls[], CommDataType recvtype) final
Definition: CommMpi.cpp:381
void v_AllGather(void *sendbuf, int sendcount, CommDataType sendtype, void *recvbuf, int recvcount, CommDataType recvtype) final
Definition: CommMpi.cpp:288
void v_Recv(void *buf, int count, CommDataType dt, int source) final
Definition: CommMpi.cpp:209
void v_StartAll(CommRequestSharedPtr request) final
Definition: CommMpi.cpp:457
void v_AllReduce(void *buf, int count, CommDataType dt, enum ReduceOperator pOp) final
Definition: CommMpi.cpp:232
void v_Finalise() override
Definition: CommMpi.cpp:125
void v_Irecv(void *buf, int count, CommDataType dt, int source, CommRequestSharedPtr request, int loc) final
Definition: CommMpi.cpp:435
void v_Irsend(void *buf, int count, CommDataType dt, int dest, CommRequestSharedPtr request, int loc) final
Definition: CommMpi.cpp:402
void v_Gather(void *sendbuf, int sendcount, CommDataType sendtype, void *recvbuf, int recvcount, CommDataType recvtype, int root) final
Definition: CommMpi.cpp:335
static std::string className
Name of class.
Definition: CommMpi.h:99
void v_RecvInit(void *buf, int count, CommDataType dt, int source, CommRequestSharedPtr request, int loc) final
Definition: CommMpi.cpp:446
static CommSharedPtr create(int narg, char *arg[])
Creates an instance of this class.
Definition: CommMpi.h:93
void v_Isend(void *buf, int count, CommDataType dt, int dest, CommRequestSharedPtr request, int loc) final
Definition: CommMpi.cpp:413
void v_AlltoAllv(void *sendbuf, int sendcounts[], int sensdispls[], CommDataType sendtype, void *recvbuf, int recvcounts[], int rdispls[], CommDataType recvtype) final
Definition: CommMpi.cpp:274
bool v_TreatAsRankZero() final
Definition: CommMpi.cpp:149
void v_Send(void *buf, int count, CommDataType dt, int dest) final
Definition: CommMpi.cpp:194
void v_SplitComm(int pRows, int pColumns, int pTime) override
Definition: CommMpi.cpp:495
void v_Scatter(void *sendbuf, int sendcount, CommDataType sendtype, void *recvbuf, int recvcount, CommDataType recvtype, int root) final
Definition: CommMpi.cpp:348
std::tuple< int, int, int > v_GetVersion() final
Definition: CommMpi.cpp:165
void v_SendInit(void *buf, int count, CommDataType dt, int dest, CommRequestSharedPtr request, int loc) final
Definition: CommMpi.cpp:424
Class for communicator request type.
Definition: CommMpi.h:60
tKey RegisterCreatorFunction(tKey idKey, CreatorFunction classCreator, std::string pDesc="")
Register a class with the factory.
array buffer
Definition: GsLib.hpp:81
unsigned int CommDataType
Definition: CommDataType.h:65
std::shared_ptr< CommRequest > CommRequestSharedPtr
Definition: Comm.h:84
std::shared_ptr< CommRequestMpi > CommRequestMpiSharedPtr
Definition: CommMpi.h:86
CommFactory & GetCommFactory()
ReduceOperator
Type of operation to perform in AllReduce.
Definition: Comm.h:65
std::shared_ptr< Comm > CommSharedPtr
Pointer to a Communicator object.
Definition: Comm.h:55