Nektar++
CommMpi.cpp
1///////////////////////////////////////////////////////////////////////////////
2//
3// File: CommMpi.cpp
4//
5// For more information, please see: http://www.nektar.info
6//
7// The MIT License
8//
9// Copyright (c) 2006 Division of Applied Mathematics, Brown University (USA),
10// Department of Aeronautics, Imperial College London (UK), and Scientific
11// Computing and Imaging Institute, University of Utah (USA).
12//
13// Permission is hereby granted, free of charge, to any person obtaining a
14// copy of this software and associated documentation files (the "Software"),
15// to deal in the Software without restriction, including without limitation
16// the rights to use, copy, modify, merge, publish, distribute, sublicense,
17// and/or sell copies of the Software, and to permit persons to whom the
18// Software is furnished to do so, subject to the following conditions:
19//
20// The above copyright notice and this permission notice shall be included
21// in all copies or substantial portions of the Software.
22//
23// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
24// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
25// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
26// THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
27// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
28// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
29// DEALINGS IN THE SOFTWARE.
30//
31// Description: MPI communication implementation
32//
33///////////////////////////////////////////////////////////////////////////////
34
35#ifdef NEKTAR_USING_PETSC
36#include "petscsys.h"
37#endif
38
39#include <LibUtilities/Communication/CommMpi.h>
40
41namespace Nektar::LibUtilities
42{
43
44std::string CommMpi::className = GetCommFactory().RegisterCreatorFunction(
45    "ParallelMPI", CommMpi::create, "Parallel communication using MPI.");
46
47/**
48 *
49 */
50CommMpi::CommMpi(int narg, char *arg[]) : Comm(narg, arg)
51{
52 int init = 0;
53 MPI_Initialized(&init);
54
55 if (!init)
56 {
57 ASSERTL0(MPI_Init(&narg, &arg) == MPI_SUCCESS,
58 "Failed to initialise MPI");
59 // store bool to indicate that Nektar++ is in charge of finalizing MPI.
60 m_controls_mpi = true;
61 }
62 else
63 {
64 // Another code is in charge of finalizing MPI and this is not the
65        // responsibility of Nektar++
66 m_controls_mpi = false;
67 }
68
69 m_comm = MPI_COMM_WORLD;
70 MPI_Comm_size(m_comm, &m_size);
71 MPI_Comm_rank(m_comm, &m_rank);
72
73#ifdef NEKTAR_USING_PETSC
74 PetscInitializeNoArguments();
75#endif
76
77 m_type = "Parallel MPI";
78}
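
The branch above decides who owns the MPI lifetime: if MPI_Initialized() reports that another code has already called MPI_Init(), m_controls_mpi stays false and v_Finalise() later skips MPI_Finalize(). A minimal sketch of embedding this class in such a host application (the include path and direct construction are assumptions for illustration; Nektar++ normally obtains the communicator through its comm factory):

// Sketch only: a hypothetical host code that owns MPI and embeds Nektar++.
#include <mpi.h>
#include <LibUtilities/Communication/CommMpi.h> // assumed include path

int main(int argc, char *argv[])
{
    // The host initialises MPI first ...
    MPI_Init(&argc, &argv);
    {
        // ... so MPI_Initialized() returns true inside the constructor,
        // m_controls_mpi is set to false, and v_Finalise() will not call
        // MPI_Finalize().
        Nektar::LibUtilities::CommMpi comm(argc, argv);
        // ... use comm for the Nektar++ part of the run ...
    }
    // The host therefore remains responsible for finalisation.
    MPI_Finalize();
    return 0;
}
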
79
80/**
81 *
82 */
83CommMpi::CommMpi(MPI_Comm pComm) : Comm()
84{
85 m_comm = pComm;
86 MPI_Comm_size(m_comm, &m_size);
87 MPI_Comm_rank(m_comm, &m_rank);
88
89 m_type = "Parallel MPI";
90}
91
92/**
93 *
94 */
95CommMpi::~CommMpi()
96{
97 int flag;
98 MPI_Finalized(&flag);
99 if (!flag && m_comm != MPI_COMM_WORLD)
100 {
101 MPI_Comm_free(&m_comm);
102 }
103}
104
105/**
106 *
107 */
108MPI_Comm CommMpi::GetComm()
109{
110 return m_comm;
111}
112
113/**
114 *
115 */
116void CommMpi::v_Finalise()
117{
118#ifdef NEKTAR_USING_PETSC
119 PetscFinalize();
120#endif
121 int flag;
122 MPI_Finalized(&flag);
123 if ((!flag) && m_controls_mpi)
124 {
125 MPI_Finalize();
126 }
127}
128
129/**
130 *
131 */
132int CommMpi::v_GetRank()
133{
134 return m_rank;
135}
136
137/**
138 *
139 */
140bool CommMpi::v_TreatAsRankZero()
141{
142 return m_rank == 0;
143}
144
145/**
146 *
147 */
148bool CommMpi::v_IsSerial()
149{
150 return m_size == 1;
151}
152
153/**
154 *
155 */
156std::tuple<int, int, int> CommMpi::v_GetVersion()
157{
158 int version, subversion;
159 int retval = MPI_Get_version(&version, &subversion);
160
161 ASSERTL0(retval == MPI_SUCCESS, "MPI error performing GetVersion.");
162
163 return std::make_tuple(version, subversion, 0);
164}
165
166/**
167 *
168 */
169void CommMpi::v_Block()
170{
171 MPI_Barrier(m_comm);
172}
173
174/**
175 *
176 */
177double CommMpi::v_Wtime()
178{
179 return MPI_Wtime();
180}
181
182/**
183 *
184 */
185void CommMpi::v_Send(void *buf, int count, CommDataType dt, int dest)
186{
187 if (MPISYNC)
188 {
189 MPI_Ssend(buf, count, dt, dest, 0, m_comm);
190 }
191 else
192 {
193 MPI_Send(buf, count, dt, dest, 0, m_comm);
194 }
195}
196
197/**
198 *
199 */
200void CommMpi::v_Recv(void *buf, int count, CommDataType dt, int source)
201{
202 MPI_Recv(buf, count, dt, source, 0, m_comm, MPI_STATUS_IGNORE);
203}
204
205/**
206 *
207 */
208void CommMpi::v_SendRecv(void *sendbuf, int sendcount, CommDataType sendtype,
209 int dest, void *recvbuf, int recvcount,
210 CommDataType recvtype, int source)
211{
212 MPI_Status status;
213 int retval = MPI_Sendrecv(sendbuf, sendcount, sendtype, dest, 0, recvbuf,
214 recvcount, recvtype, source, 0, m_comm, &status);
215
216 ASSERTL0(retval == MPI_SUCCESS,
217 "MPI error performing send-receive of data.");
218}
219
220/**
221 *
222 */
223void CommMpi::v_AllReduce(void *buf, int count, CommDataType dt,
224 enum ReduceOperator pOp)
225{
226 if (GetSize() == 1)
227 {
228 return;
229 }
230
231 MPI_Op vOp;
232 switch (pOp)
233 {
234 case ReduceMax:
235 vOp = MPI_MAX;
236 break;
237 case ReduceMin:
238 vOp = MPI_MIN;
239 break;
240 case ReduceSum:
241 default:
242 vOp = MPI_SUM;
243 break;
244 }
245 int retval = MPI_Allreduce(MPI_IN_PLACE, buf, count, dt, vOp, m_comm);
246
247 ASSERTL0(retval == MPI_SUCCESS, "MPI error performing All-reduce.");
248}
249
250/**
251 *
252 */
253void CommMpi::v_AlltoAll(void *sendbuf, int sendcount, CommDataType sendtype,
254 void *recvbuf, int recvcount, CommDataType recvtype)
255{
256 int retval = MPI_Alltoall(sendbuf, sendcount, sendtype, recvbuf, recvcount,
257 recvtype, m_comm);
258
259 ASSERTL0(retval == MPI_SUCCESS, "MPI error performing All-to-All.");
260}
261
262/**
263 *
264 */
265void CommMpi::v_AlltoAllv(void *sendbuf, int sendcounts[], int sdispls[],
266 CommDataType sendtype, void *recvbuf,
267 int recvcounts[], int rdispls[],
268 CommDataType recvtype)
269{
270 int retval = MPI_Alltoallv(sendbuf, sendcounts, sdispls, sendtype, recvbuf,
271 recvcounts, rdispls, recvtype, m_comm);
272
273 ASSERTL0(retval == MPI_SUCCESS, "MPI error performing All-to-All-v.");
274}
275
276/**
277 *
278 */
279void CommMpi::v_AllGather(void *sendbuf, int sendcount, CommDataType sendtype,
280 void *recvbuf, int recvcount, CommDataType recvtype)
281{
282 int retval = MPI_Allgather(sendbuf, sendcount, sendtype, recvbuf, recvcount,
283 recvtype, m_comm);
284
285 ASSERTL0(retval == MPI_SUCCESS, "MPI error performing Allgather.");
286}
287
288/**
289 *
290 */
291void CommMpi::v_AllGatherv(void *sendbuf, int sendcount, CommDataType sendtype,
292 void *recvbuf, int recvcounts[], int rdispls[],
293 CommDataType recvtype)
294{
295 int retval = MPI_Allgatherv(sendbuf, sendcount, sendtype, recvbuf,
296 recvcounts, rdispls, recvtype, m_comm);
297
298 ASSERTL0(retval == MPI_SUCCESS, "MPI error performing Allgatherv.");
299}
300
301/**
302 *
303 */
304void CommMpi::v_AllGatherv(void *recvbuf, int recvcounts[], int rdispls[],
305 CommDataType recvtype)
306{
307 int retval = MPI_Allgatherv(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, recvbuf,
308 recvcounts, rdispls, recvtype, m_comm);
309
310 ASSERTL0(retval == MPI_SUCCESS, "MPI error performing Allgatherv.");
311}
312
313/**
314 *
315 */
316void CommMpi::v_Bcast(void *buffer, int count, CommDataType dt, int root)
317{
318 int retval = MPI_Bcast(buffer, count, dt, root, m_comm);
319
320    ASSERTL0(retval == MPI_SUCCESS, "MPI error performing Bcast.");
321}
322
323/**
324 *
325 */
326void CommMpi::v_Gather(void *sendbuf, int sendcount, CommDataType sendtype,
327 void *recvbuf, int recvcount, CommDataType recvtype,
328 int root)
329{
330 int retval = MPI_Gather(sendbuf, sendcount, sendtype, recvbuf, recvcount,
331 recvtype, root, m_comm);
332
333 ASSERTL0(retval == MPI_SUCCESS, "MPI error performing Gather.");
334}
335
336/**
337 *
338 */
339void CommMpi::v_Scatter(void *sendbuf, int sendcount, CommDataType sendtype,
340 void *recvbuf, int recvcount, CommDataType recvtype,
341 int root)
342{
343 int retval = MPI_Scatter(sendbuf, sendcount, sendtype, recvbuf, recvcount,
344 recvtype, root, m_comm);
345
346 ASSERTL0(retval == MPI_SUCCESS, "MPI error performing Scatter.");
347}
348
349/**
350 *
351 */
352void CommMpi::v_DistGraphCreateAdjacent(
353    [[maybe_unused]] int indegree, [[maybe_unused]] const int sources[],
354 [[maybe_unused]] const int sourceweights[], [[maybe_unused]] int reorder)
355{
356#if MPI_VERSION < 3
357 ASSERTL0(false, "MPI_Dist_graph_create_adjacent is not supported in your "
358 "installed MPI version.");
359#else
360 int retval = MPI_Dist_graph_create_adjacent(
361 m_comm, indegree, sources, sourceweights, indegree, sources,
362 sourceweights, MPI_INFO_NULL, reorder, &m_comm);
363
364 ASSERTL0(retval == MPI_SUCCESS,
365 "MPI error performing Dist_graph_create_adjacent.")
366#endif
367}
368
369/**
370 *
371 */
372void CommMpi::v_NeighborAlltoAllv(
373    [[maybe_unused]] void *sendbuf, [[maybe_unused]] int sendcounts[],
374 [[maybe_unused]] int sdispls[], [[maybe_unused]] CommDataType sendtype,
375 [[maybe_unused]] void *recvbuf, [[maybe_unused]] int recvcounts[],
376 [[maybe_unused]] int rdispls[], [[maybe_unused]] CommDataType recvtype)
377{
378#if MPI_VERSION < 3
379 ASSERTL0(false, "MPI_Neighbor_alltoallv is not supported in your "
380 "installed MPI version.");
381#else
382 int retval =
383 MPI_Neighbor_alltoallv(sendbuf, sendcounts, sdispls, sendtype, recvbuf,
384 recvcounts, rdispls, recvtype, m_comm);
385
386 ASSERTL0(retval == MPI_SUCCESS, "MPI error performing NeighborAllToAllV.");
387#endif
388}
389
390/**
391 *
392 */
393void CommMpi::v_Irsend(void *buf, int count, CommDataType dt, int dest,
394 CommRequestSharedPtr request, int loc)
395{
396    CommRequestMpiSharedPtr req =
397        std::static_pointer_cast<CommRequestMpi>(request);
398 MPI_Irsend(buf, count, dt, dest, 0, m_comm, req->GetRequest(loc));
399}
400
401/**
402 *
403 */
404void CommMpi::v_Isend(void *buf, int count, CommDataType dt, int dest,
405 CommRequestSharedPtr request, int loc)
406{
407    CommRequestMpiSharedPtr req =
408        std::static_pointer_cast<CommRequestMpi>(request);
409 MPI_Isend(buf, count, dt, dest, 0, m_comm, req->GetRequest(loc));
410}
411
412/**
413 *
414 */
415void CommMpi::v_SendInit(void *buf, int count, CommDataType dt, int dest,
416 CommRequestSharedPtr request, int loc)
417{
418    CommRequestMpiSharedPtr req =
419        std::static_pointer_cast<CommRequestMpi>(request);
420 MPI_Send_init(buf, count, dt, dest, 0, m_comm, req->GetRequest(loc));
421}
422
423/**
424 *
425 */
426void CommMpi::v_Irecv(void *buf, int count, CommDataType dt, int source,
427 CommRequestSharedPtr request, int loc)
428{
429    CommRequestMpiSharedPtr req =
430        std::static_pointer_cast<CommRequestMpi>(request);
431 MPI_Irecv(buf, count, dt, source, 0, m_comm, req->GetRequest(loc));
432}
433
434/**
435 *
436 */
437void CommMpi::v_RecvInit(void *buf, int count, CommDataType dt, int source,
438 CommRequestSharedPtr request, int loc)
439{
440    CommRequestMpiSharedPtr req =
441        std::static_pointer_cast<CommRequestMpi>(request);
442 MPI_Recv_init(buf, count, dt, source, 0, m_comm, req->GetRequest(loc));
443}
444
445/**
446 *
447 */
448void CommMpi::v_StartAll(CommRequestSharedPtr request)
449{
450    CommRequestMpiSharedPtr req =
451        std::static_pointer_cast<CommRequestMpi>(request);
452 if (req->GetNumRequest() != 0)
453 {
454 MPI_Startall(req->GetNumRequest(), req->GetRequest(0));
455 }
456}
457
458/**
459 *
460 */
461void CommMpi::v_WaitAll(CommRequestSharedPtr request)
462{
463    CommRequestMpiSharedPtr req =
464        std::static_pointer_cast<CommRequestMpi>(request);
465 if (req->GetNumRequest() != 0)
466 {
467 MPI_Waitall(req->GetNumRequest(), req->GetRequest(0),
468 MPI_STATUSES_IGNORE);
469 }
470}
471
472/**
473 *
474 */
475CommRequestSharedPtr CommMpi::v_CreateRequest(int num)
476{
477 return std::shared_ptr<CommRequest>(new CommRequestMpi(num));
478}
479
480/**
481 * Processes are considered as a grid of size pRows*pColumns. Comm
482 * objects are created corresponding to the rows and columns of this
483 * grid. The row and column to which this process belongs is stored in
484 * #m_commRow and #m_commColumn.
485 */
486void CommMpi::v_SplitComm(int pRows, int pColumns, int pTime)
487{
488 ASSERTL0(pRows * pColumns * pTime == m_size,
489 "Rows/Columns/Time do not match comm size.");
490
491 MPI_Comm newComm;
492 MPI_Comm gridComm;
493 if (pTime == 1)
494 {
495        // There is a bug in OpenMPI 3.1.3. This bug causes some cases to fail in
496 // buster-full-build-and-test. Failed cases are:
497 //
498 // IncNavierStokesSolver_ChanFlow_3DH1D_FlowrateExplicit_MVM_par
499 // IncNavierStokesSolver_ChanFlow_3DH1D_FlowrateExplicit_MVM_par_hybrid
500
501 // See: https://github.com/open-mpi/ompi/issues/6522
502
503 // Compute row and column in grid.
504 int myCol = m_rank % pColumns;
505 int myRow = (m_rank - myCol) / pColumns;
506
507 // Split Comm into rows - all processes with same myRow are put in
508 // the same communicator. The rank within this communicator is the
509 // column index.
510 MPI_Comm_split(m_comm, myRow, myCol, &newComm);
511 m_commRow = std::shared_ptr<Comm>(new CommMpi(newComm));
512
513 // Split Comm into columns - all processes with same myCol are put
514 // in the same communicator. The rank within this communicator is
515 // the row index.
516 MPI_Comm_split(m_comm, myCol, myRow, &newComm);
517 m_commColumn = std::shared_ptr<Comm>(new CommMpi(newComm));
518 }
519 else
520 {
521 constexpr int dims = 3;
522 const int sizes[dims] = {pRows, pColumns, pTime};
523 const int periods[dims] = {0, 0, 0};
524 constexpr int reorder = 1;
525
526 MPI_Cart_create(m_comm, dims, sizes, periods, reorder, &gridComm);
527
528 constexpr int keepRow[dims] = {0, 1, 0};
529 MPI_Cart_sub(gridComm, keepRow, &newComm);
530 m_commRow = std::shared_ptr<Comm>(new CommMpi(newComm));
531
532 constexpr int keepCol[dims] = {1, 0, 0};
533 MPI_Cart_sub(gridComm, keepCol, &newComm);
534 m_commColumn = std::shared_ptr<Comm>(new CommMpi(newComm));
535
536 constexpr int keepTime[dims] = {0, 0, 1};
537 MPI_Cart_sub(gridComm, keepTime, &newComm);
538 m_commTime = std::shared_ptr<Comm>(new CommMpi(newComm));
539
540 constexpr int keepSpace[dims] = {1, 1, 0};
541 MPI_Cart_sub(gridComm, keepSpace, &newComm);
542 m_commSpace = std::shared_ptr<Comm>(new CommMpi(newComm));
543 }
544}
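
A worked illustration of the row/column arithmetic used in the pTime == 1 branch above, for a hypothetical 2 x 3 grid (pRows = 2, pColumns = 3). This standalone sketch only reproduces the index computation, not the MPI_Comm_split calls themselves:

// Sketch: how global ranks map onto a 2 x 3 process grid
// (pRows = 2, pColumns = 3, pTime = 1), matching the arithmetic above.
#include <cstdio>

int main()
{
    const int pColumns = 3;
    for (int rank = 0; rank < 6; ++rank)
    {
        int myCol = rank % pColumns;           // rank within its row communicator
        int myRow = (rank - myCol) / pColumns; // rank within its column communicator
        std::printf("rank %d -> row %d, col %d\n", rank, myRow, myCol);
    }
    // Ranks 0,1,2 form row 0 (one m_commRow); ranks 0 and 3 form column 0
    // (one m_commColumn), and so on.
    return 0;
}
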
545
546/**
547 * Create a new communicator if the flag is non-zero.
548 */
549CommSharedPtr CommMpi::v_CommCreateIf(int flag)
550{
551 MPI_Comm newComm;
552    // color == MPI_UNDEFINED => not in the new communicator
553 // key == 0 on all => use rank to order them. OpenMPI, at least,
554 // implies this is faster than ordering them ourselves.
555 MPI_Comm_split(m_comm, flag ? flag : MPI_UNDEFINED, 0, &newComm);
556
557 if (flag == 0)
558 {
559 // flag == 0 => get back MPI_COMM_NULL, return a null ptr instead.
560 return std::shared_ptr<Comm>();
561 }
562 else
563 {
564 // Return a real communicator
565 return std::shared_ptr<Comm>(new CommMpi(newComm));
566 }
567}
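
A hedged usage sketch of the method above: ranks passing a non-zero flag join the derived communicator, the rest get back an empty pointer. The include path and the public Comm::CommCreateIf wrapper are assumptions based on the virtual interface shown here; the helper function is hypothetical.

// Sketch only: place the even-ranked processes in their own communicator.
#include <LibUtilities/Communication/Comm.h> // assumed include path

using namespace Nektar::LibUtilities;

void SplitEvenRanks(CommSharedPtr comm)
{
    // Non-zero flag => member of the new communicator; zero => null pointer.
    CommSharedPtr evenComm = comm->CommCreateIf(comm->GetRank() % 2 == 0 ? 1 : 0);

    if (evenComm)
    {
        // Collectives here involve only the even ranks.
        int nMembers = 1;
        evenComm->AllReduce(nMembers, ReduceSum);
    }
}
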
568
569/**
570 *
571 */
572std::pair<CommSharedPtr, CommSharedPtr> CommMpi::v_SplitCommNode()
573{
574 std::pair<CommSharedPtr, CommSharedPtr> ret;
575
576#if MPI_VERSION < 3
577 ASSERTL0(false, "Not implemented for non-MPI-3 versions.");
578#else
579 // Create an intra-node communicator.
580 MPI_Comm nodeComm;
581 MPI_Comm_split_type(MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, m_rank,
582 MPI_INFO_NULL, &nodeComm);
583
584 // For rank 0 of the intra-node communicator, split the main
585 // communicator. Everyone else will get a null communicator.
586 ret.first = std::shared_ptr<Comm>(new CommMpi(nodeComm));
587 ret.second = CommMpi::v_CommCreateIf(ret.first->GetRank() == 0);
588 if (ret.first->GetRank() == 0)
589 {
590 ret.second->SplitComm(1, ret.second->GetSize());
591 }
592#endif
593
594 return ret;
595}
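
A brief sketch of how the pair returned above could be consumed. The Comm::SplitCommNode wrapper name and the include path are assumptions, and the helper function itself is hypothetical:

// Sketch only: intra-node communicator plus a cross-node "leader" communicator.
#include <LibUtilities/Communication/Comm.h> // assumed include path
#include <utility>

using namespace Nektar::LibUtilities;

void ReportNodeLayout(CommSharedPtr comm)
{
    std::pair<CommSharedPtr, CommSharedPtr> p = comm->SplitCommNode();

    // p.first: every rank sharing this node (MPI_COMM_TYPE_SHARED split).
    int nodeRank = p.first->GetRank();

    // p.second: non-null only on each node's local rank 0 and containing one
    // rank per node, so its size equals the number of nodes.
    if (p.second)
    {
        int nNodes = p.second->GetSize();
        (void)nNodes;
    }
    (void)nodeRank;
}
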
596
597} // namespace Nektar::LibUtilities
#define MPISYNC
Definition: CommMpi.h:45
#define ASSERTL0(condition, msg)
Definition: ErrorUtil.hpp:208
Base communications class.
Definition: Comm.h:88
CommSharedPtr m_commColumn
Column communicator.
Definition: Comm.h:178
CommSharedPtr m_commRow
Row communicator.
Definition: Comm.h:177
int GetSize() const
Returns number of processes.
Definition: Comm.h:264
CommSharedPtr m_commTime
Definition: Comm.h:179
int m_size
Number of processes.
Definition: Comm.h:175
std::string m_type
Type of communication.
Definition: Comm.h:176
CommSharedPtr m_commSpace
Definition: Comm.h:180
void v_WaitAll(CommRequestSharedPtr request) final
Definition: CommMpi.cpp:461
void v_AlltoAll(void *sendbuf, int sendcount, CommDataType sendtype, void *recvbuf, int recvcount, CommDataType recvtype) final
Definition: CommMpi.cpp:253
CommRequestSharedPtr v_CreateRequest(int num) final
Definition: CommMpi.cpp:475
CommMpi(int narg, char *arg[])
Definition: CommMpi.cpp:50
void v_DistGraphCreateAdjacent(int indegree, const int sources[], const int sourceweights[], int reorder) final
Definition: CommMpi.cpp:352
void v_Bcast(void *buffer, int count, CommDataType dt, int root) final
Definition: CommMpi.cpp:316
void v_AllGatherv(void *sendbuf, int sendcount, CommDataType sendtype, void *recvbuf, int recvcounts[], int rdispls[], CommDataType recvtype) final
Definition: CommMpi.cpp:291
CommSharedPtr v_CommCreateIf(int flag) final
Definition: CommMpi.cpp:549
std::pair< CommSharedPtr, CommSharedPtr > v_SplitCommNode() final
Definition: CommMpi.cpp:572
void v_SendRecv(void *sendbuf, int sendcount, CommDataType sendtype, int dest, void *recvbuf, int recvcount, CommDataType recvtype, int source) final
Definition: CommMpi.cpp:208
void v_NeighborAlltoAllv(void *sendbuf, int sendcounts[], int sensdispls[], CommDataType sendtype, void *recvbuf, int recvcounts[], int rdispls[], CommDataType recvtype) final
Definition: CommMpi.cpp:372
void v_AllGather(void *sendbuf, int sendcount, CommDataType sendtype, void *recvbuf, int recvcount, CommDataType recvtype) final
Definition: CommMpi.cpp:279
void v_Recv(void *buf, int count, CommDataType dt, int source) final
Definition: CommMpi.cpp:200
void v_StartAll(CommRequestSharedPtr request) final
Definition: CommMpi.cpp:448
void v_AllReduce(void *buf, int count, CommDataType dt, enum ReduceOperator pOp) final
Definition: CommMpi.cpp:223
void v_Finalise() override
Definition: CommMpi.cpp:116
void v_Irecv(void *buf, int count, CommDataType dt, int source, CommRequestSharedPtr request, int loc) final
Definition: CommMpi.cpp:426
void v_Irsend(void *buf, int count, CommDataType dt, int dest, CommRequestSharedPtr request, int loc) final
Definition: CommMpi.cpp:393
void v_Gather(void *sendbuf, int sendcount, CommDataType sendtype, void *recvbuf, int recvcount, CommDataType recvtype, int root) final
Definition: CommMpi.cpp:326
static std::string className
Name of class.
Definition: CommMpi.h:99
void v_RecvInit(void *buf, int count, CommDataType dt, int source, CommRequestSharedPtr request, int loc) final
Definition: CommMpi.cpp:437
static CommSharedPtr create(int narg, char *arg[])
Creates an instance of this class.
Definition: CommMpi.h:93
void v_Isend(void *buf, int count, CommDataType dt, int dest, CommRequestSharedPtr request, int loc) final
Definition: CommMpi.cpp:404
void v_AlltoAllv(void *sendbuf, int sendcounts[], int sensdispls[], CommDataType sendtype, void *recvbuf, int recvcounts[], int rdispls[], CommDataType recvtype) final
Definition: CommMpi.cpp:265
bool v_TreatAsRankZero() final
Definition: CommMpi.cpp:140
void v_Send(void *buf, int count, CommDataType dt, int dest) final
Definition: CommMpi.cpp:185
void v_SplitComm(int pRows, int pColumns, int pTime) override
Definition: CommMpi.cpp:486
void v_Scatter(void *sendbuf, int sendcount, CommDataType sendtype, void *recvbuf, int recvcount, CommDataType recvtype, int root) final
Definition: CommMpi.cpp:339
std::tuple< int, int, int > v_GetVersion() final
Definition: CommMpi.cpp:156
void v_SendInit(void *buf, int count, CommDataType dt, int dest, CommRequestSharedPtr request, int loc) final
Definition: CommMpi.cpp:415
Class for communicator request type.
Definition: CommMpi.h:60
tKey RegisterCreatorFunction(tKey idKey, CreatorFunction classCreator, std::string pDesc="")
Register a class with the factory.
Definition: NekFactory.hpp:197
array buffer
Definition: GsLib.hpp:81
unsigned int CommDataType
Definition: CommDataType.h:65
std::shared_ptr< CommRequest > CommRequestSharedPtr
Definition: Comm.h:84
std::shared_ptr< CommRequestMpi > CommRequestMpiSharedPtr
Definition: CommMpi.h:86
CommFactory & GetCommFactory()
ReduceOperator
Type of operation to perform in AllReduce.
Definition: Comm.h:65
std::shared_ptr< Comm > CommSharedPtr
Pointer to a Communicator object.
Definition: Comm.h:55
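
Since CommMpi registers itself with the comm factory under the key "ParallelMPI" (see the className initialisation near the top of the file and the RegisterCreatorFunction entry above), a communicator is normally obtained through the factory rather than by constructing CommMpi directly. A minimal, hedged sketch; the include paths are assumptions and error handling is omitted:

// Sketch only: create the MPI communicator via the comm factory.
#include <LibUtilities/Communication/Comm.h>    // assumed include path
#include <LibUtilities/Communication/CommMpi.h> // assumed include path

int main(int argc, char *argv[])
{
    using namespace Nektar::LibUtilities;

    // "ParallelMPI" is the key registered by CommMpi::className above.
    CommSharedPtr comm = GetCommFactory().CreateInstance("ParallelMPI", argc, argv);

    if (comm->TreatAsRankZero())
    {
        // Rank 0 only, e.g. banner output.
    }

    comm->Finalise();
    return 0;
}
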