Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update collective framework count/disp arrays for bigcount #12621

Merged
merged 1 commit into from
Jul 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion ompi/communicator/comm.c
Original file line number Diff line number Diff line change
Expand Up @@ -2402,6 +2402,8 @@ int ompi_comm_determine_first ( ompi_communicator_t *intercomm, int high )
int rank, rsize;
int *rcounts;
int *rdisps;
ompi_count_array_t rcounts_desc;
ompi_disp_array_t rdisps_desc;
int scount=0;
int rc;

Expand Down Expand Up @@ -2429,8 +2431,10 @@ int ompi_comm_determine_first ( ompi_communicator_t *intercomm, int high )
scount = 1;
}

OMPI_COUNT_ARRAY_INIT(&rcounts_desc, rcounts);
OMPI_DISP_ARRAY_INIT(&rdisps_desc, rdisps);
rc = intercomm->c_coll->coll_allgatherv(&high, scount, MPI_INT,
&rhigh, rcounts, rdisps,
&rhigh, rcounts_desc, rdisps_desc,
MPI_INT, intercomm,
intercomm->c_coll->coll_allgatherv_module);
if ( NULL != rdisps ) {
Expand Down
153 changes: 94 additions & 59 deletions ompi/mca/coll/base/coll_base_allgatherv.c

Large diffs are not rendered by default.

65 changes: 34 additions & 31 deletions ompi/mca/coll/base/coll_base_alltoallv.c
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
* and count) to send the data to the other.
*/
int
mca_coll_base_alltoallv_intra_basic_inplace(const void *rbuf, const int *rcounts, const int *rdisps,
mca_coll_base_alltoallv_intra_basic_inplace(const void *rbuf, ompi_count_array_t rcounts, ompi_disp_array_t rdisps,
struct ompi_datatype_t *rdtype,
struct ompi_communicator_t *comm,
mca_coll_base_module_t *module)
Expand All @@ -72,7 +72,7 @@ mca_coll_base_alltoallv_intra_basic_inplace(const void *rbuf, const int *rcounts
if (i == rank) {
continue;
}
packed_size = rcounts[i] * type_size;
packed_size = ompi_count_array_get(rcounts, i) * type_size;
max_size = opal_max(packed_size, max_size);
}

Expand Down Expand Up @@ -111,11 +111,11 @@ mca_coll_base_alltoallv_intra_basic_inplace(const void *rbuf, const int *rcounts
right = (rank + i) % size;
left = (rank + size - i) % size;

if( 0 != rcounts[right] ) { /* nothing to exchange with the peer on the right */
if( 0 != ompi_count_array_get(rcounts, right) ) { /* nothing to exchange with the peer on the right */
ompi_proc_t *right_proc = ompi_comm_peer_lookup(comm, right);
opal_convertor_clone(right_proc->super.proc_convertor, &convertor, 0);
opal_convertor_prepare_for_send(&convertor, &rdtype->super, rcounts[right],
(char *) rbuf + rdisps[right] * extent);
opal_convertor_prepare_for_send(&convertor, &rdtype->super, ompi_count_array_get(rcounts, right),
(char *) rbuf + ompi_disp_array_get(rdisps, right) * extent);
packed_size = max_size;
err = opal_convertor_pack(&convertor, &iov, &iov_count, &packed_size);
if (1 != err) {
Expand All @@ -124,17 +124,19 @@ mca_coll_base_alltoallv_intra_basic_inplace(const void *rbuf, const int *rcounts
}

/* Receive data from the right */
err = MCA_PML_CALL(irecv ((char *) rbuf + rdisps[right] * extent, rcounts[right], rdtype,
err = MCA_PML_CALL(irecv ((char *) rbuf + ompi_disp_array_get(rdisps, right) * extent,
ompi_count_array_get(rcounts, right), rdtype,
right, MCA_COLL_BASE_TAG_ALLTOALLV, comm, &req));
if (MPI_SUCCESS != err) {
line = __LINE__;
goto error_hndl;
}
}

if( (left != right) && (0 != rcounts[left]) ) {
if( (left != right) && (0 != ompi_count_array_get(rcounts, left)) ) {
/* Send data to the left */
err = MCA_PML_CALL(send ((char *) rbuf + rdisps[left] * extent, rcounts[left], rdtype,
err = MCA_PML_CALL(send ((char *) rbuf + ompi_disp_array_get(rdisps, left) * extent,
ompi_count_array_get(rcounts, left), rdtype,
left, MCA_COLL_BASE_TAG_ALLTOALLV, MCA_PML_BASE_SEND_STANDARD,
comm));
if (MPI_SUCCESS != err) {
Expand All @@ -149,15 +151,16 @@ mca_coll_base_alltoallv_intra_basic_inplace(const void *rbuf, const int *rcounts
}

/* Receive data from the left */
err = MCA_PML_CALL(irecv ((char *) rbuf + rdisps[left] * extent, rcounts[left], rdtype,
err = MCA_PML_CALL(irecv ((char *) rbuf + ompi_disp_array_get(rdisps, left) * extent,
ompi_count_array_get(rcounts, left), rdtype,
left, MCA_COLL_BASE_TAG_ALLTOALLV, comm, &req));
if (MPI_SUCCESS != err) {
line = __LINE__;
goto error_hndl;
}
}

if( 0 != rcounts[right] ) { /* nothing to exchange with the peer on the right */
if( 0 != ompi_count_array_get(rcounts, right) ) { /* nothing to exchange with the peer on the right */
/* Send data to the right */
err = MCA_PML_CALL(send ((char *) tmp_buffer, packed_size, MPI_PACKED,
right, MCA_COLL_BASE_TAG_ALLTOALLV, MCA_PML_BASE_SEND_STANDARD,
Expand Down Expand Up @@ -191,9 +194,9 @@ mca_coll_base_alltoallv_intra_basic_inplace(const void *rbuf, const int *rcounts
}

int
ompi_coll_base_alltoallv_intra_pairwise(const void *sbuf, const int *scounts, const int *sdisps,
ompi_coll_base_alltoallv_intra_pairwise(const void *sbuf, ompi_count_array_t scounts, ompi_disp_array_t sdisps,
struct ompi_datatype_t *sdtype,
void* rbuf, const int *rcounts, const int *rdisps,
void* rbuf, ompi_count_array_t rcounts, ompi_disp_array_t rdisps,
struct ompi_datatype_t *rdtype,
struct ompi_communicator_t *comm,
mca_coll_base_module_t *module)
Expand Down Expand Up @@ -230,21 +233,21 @@ ompi_coll_base_alltoallv_intra_pairwise(const void *sbuf, const int *scounts, co
recvfrom = (rank + size - step) % size;

/* Determine sending and receiving locations */
psnd = (char*)sbuf + (ptrdiff_t)sdisps[sendto] * sext;
prcv = (char*)rbuf + (ptrdiff_t)rdisps[recvfrom] * rext;
psnd = (char*)sbuf + ompi_disp_array_get(sdisps, sendto) * sext;
prcv = (char*)rbuf + ompi_disp_array_get(rdisps, recvfrom) * rext;

/* send and receive */
if (0 < rcounts[recvfrom] && 0 < rdtype_size) {
err = MCA_PML_CALL(irecv(prcv, rcounts[recvfrom], rdtype, recvfrom,
if (0 < ompi_count_array_get(rcounts, recvfrom) && 0 < rdtype_size) {
err = MCA_PML_CALL(irecv(prcv, ompi_count_array_get(rcounts, recvfrom), rdtype, recvfrom,
MCA_COLL_BASE_TAG_ALLTOALLV, comm, &req));
if (MPI_SUCCESS != err) {
line = __LINE__;
goto err_hndl;
}
}

if (0 < scounts[sendto] && 0 < sdtype_size) {
err = MCA_PML_CALL(send(psnd, scounts[sendto], sdtype, sendto,
if (0 < ompi_count_array_get(scounts, sendto) && 0 < sdtype_size) {
err = MCA_PML_CALL(send(psnd, ompi_count_array_get(scounts, sendto), sdtype, sendto,
MCA_COLL_BASE_TAG_ALLTOALLV, MCA_PML_BASE_SEND_STANDARD, comm));
if (MPI_SUCCESS != err) {
line = __LINE__;
Expand Down Expand Up @@ -280,9 +283,9 @@ ompi_coll_base_alltoallv_intra_pairwise(const void *sbuf, const int *scounts, co
* differently and so will not have to duplicate code.
*/
int
ompi_coll_base_alltoallv_intra_basic_linear(const void *sbuf, const int *scounts, const int *sdisps,
ompi_coll_base_alltoallv_intra_basic_linear(const void *sbuf, ompi_count_array_t scounts, ompi_disp_array_t sdisps,
struct ompi_datatype_t *sdtype,
void *rbuf, const int *rcounts, const int *rdisps,
void *rbuf, ompi_count_array_t rcounts, ompi_disp_array_t rdisps,
struct ompi_datatype_t *rdtype,
struct ompi_communicator_t *comm,
mca_coll_base_module_t *module)
Expand Down Expand Up @@ -313,11 +316,11 @@ ompi_coll_base_alltoallv_intra_basic_linear(const void *sbuf, const int *scounts
ompi_datatype_type_extent(rdtype, &rext);

/* Simple optimization - handle send to self first */
psnd = ((char *) sbuf) + (ptrdiff_t)sdisps[rank] * sext;
prcv = ((char *) rbuf) + (ptrdiff_t)rdisps[rank] * rext;
if (0 < scounts[rank] && 0 < sdtype_size) {
err = ompi_datatype_sndrcv(psnd, scounts[rank], sdtype,
prcv, rcounts[rank], rdtype);
psnd = ((char *) sbuf) + ompi_disp_array_get(sdisps, rank) * sext;
prcv = ((char *) rbuf) + ompi_disp_array_get(rdisps, rank) * rext;
if (0 < ompi_count_array_get(scounts, rank) && 0 < sdtype_size) {
err = ompi_datatype_sndrcv(psnd, ompi_count_array_get(scounts, rank), sdtype,
prcv, ompi_count_array_get(rcounts, rank), rdtype);
if (MPI_SUCCESS != err) {
return err;
}
Expand All @@ -339,10 +342,10 @@ ompi_coll_base_alltoallv_intra_basic_linear(const void *sbuf, const int *scounts
continue;
}

if (0 < rcounts[i] && 0 < rdtype_size) {
if (0 < ompi_count_array_get(rcounts, i) && 0 < rdtype_size) {
++nreqs;
prcv = ((char *) rbuf) + (ptrdiff_t)rdisps[i] * rext;
err = MCA_PML_CALL(irecv_init(prcv, rcounts[i], rdtype,
prcv = ((char *) rbuf) + ompi_disp_array_get(rdisps, i) * rext;
err = MCA_PML_CALL(irecv_init(prcv, ompi_count_array_get(rcounts, i), rdtype,
i, MCA_COLL_BASE_TAG_ALLTOALLV, comm,
preq++));
if (MPI_SUCCESS != err) { goto err_hndl; }
Expand All @@ -355,10 +358,10 @@ ompi_coll_base_alltoallv_intra_basic_linear(const void *sbuf, const int *scounts
continue;
}

if (0 < scounts[i] && 0 < sdtype_size) {
if (0 < ompi_count_array_get(scounts, i) && 0 < sdtype_size) {
++nreqs;
psnd = ((char *) sbuf) + (ptrdiff_t)sdisps[i] * sext;
err = MCA_PML_CALL(isend_init(psnd, scounts[i], sdtype,
psnd = ((char *) sbuf) + ompi_disp_array_get(sdisps, i) * sext;
err = MCA_PML_CALL(isend_init(psnd, ompi_count_array_get(scounts, i), sdtype,
i, MCA_COLL_BASE_TAG_ALLTOALLV,
MCA_PML_BASE_SEND_STANDARD, comm,
preq++));
Expand Down
20 changes: 10 additions & 10 deletions ompi/mca/coll/base/coll_base_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,27 +69,27 @@ typedef enum COLLTYPE {

/* defined arg lists to simply auto inclusion of user overriding decision functions */
#define ALLGATHER_BASE_ARGS const void *sendbuf, size_t sendcount, struct ompi_datatype_t *sendtype, void *recvbuf, size_t recvcount, struct ompi_datatype_t *recvtype, struct ompi_communicator_t *comm
#define ALLGATHERV_BASE_ARGS const void *sendbuf, int sendcount, struct ompi_datatype_t *sendtype, void *recvbuf, const int recvcounts[], const int displs[], struct ompi_datatype_t *recvtype, struct ompi_communicator_t *comm
#define ALLGATHERV_BASE_ARGS const void *sendbuf, size_t sendcount, struct ompi_datatype_t *sendtype, void *recvbuf, ompi_count_array_t recvcounts, ompi_disp_array_t displs, struct ompi_datatype_t *recvtype, struct ompi_communicator_t *comm
#define ALLREDUCE_BASE_ARGS const void *sendbuf, void *recvbuf, size_t count, struct ompi_datatype_t *datatype, struct ompi_op_t *op, struct ompi_communicator_t *comm
#define ALLTOALL_BASE_ARGS const void *sendbuf, size_t sendcount, struct ompi_datatype_t *sendtype, void *recvbuf, size_t recvcount, struct ompi_datatype_t *recvtype, struct ompi_communicator_t *comm
#define ALLTOALLV_BASE_ARGS const void *sendbuf, const int sendcounts[], const int sdispls[], struct ompi_datatype_t *sendtype, void *recvbuf, const int recvcounts[], const int rdispls[], struct ompi_datatype_t *recvtype, struct ompi_communicator_t *comm
#define ALLTOALLW_BASE_ARGS const void *sendbuf, const int sendcounts[], const int sdispls[], struct ompi_datatype_t * const sendtypes[], void *recvbuf, const int recvcounts[], const int rdispls[], struct ompi_datatype_t * const recvtypes[], struct ompi_communicator_t *comm
#define ALLTOALLV_BASE_ARGS const void *sendbuf, ompi_count_array_t sendcounts, ompi_disp_array_t sdispls, struct ompi_datatype_t *sendtype, void *recvbuf, ompi_count_array_t recvcounts, ompi_disp_array_t rdispls, struct ompi_datatype_t *recvtype, struct ompi_communicator_t *comm
#define ALLTOALLW_BASE_ARGS const void *sendbuf, ompi_count_array_t sendcounts, ompi_disp_array_t sdispls, struct ompi_datatype_t * const sendtypes[], void *recvbuf, ompi_count_array_t recvcounts, ompi_disp_array_t rdispls, struct ompi_datatype_t * const recvtypes[], struct ompi_communicator_t *comm
#define BARRIER_BASE_ARGS struct ompi_communicator_t *comm
#define BCAST_BASE_ARGS void *buffer, size_t count, struct ompi_datatype_t *datatype, int root, struct ompi_communicator_t *comm
#define EXSCAN_BASE_ARGS const void *sendbuf, void *recvbuf, size_t count, struct ompi_datatype_t *datatype, struct ompi_op_t *op, struct ompi_communicator_t *comm
#define GATHER_BASE_ARGS const void *sendbuf, size_t sendcount, struct ompi_datatype_t *sendtype, void *recvbuf, size_t recvcount, struct ompi_datatype_t *recvtype, int root, struct ompi_communicator_t *comm
#define GATHERV_BASE_ARGS const void *sendbuf, int sendcount, struct ompi_datatype_t *sendtype, void *recvbuf, const int recvcounts[], const int displs[], struct ompi_datatype_t *recvtype, int root, struct ompi_communicator_t *comm
#define GATHERV_BASE_ARGS const void *sendbuf, size_t sendcount, struct ompi_datatype_t *sendtype, void *recvbuf, ompi_count_array_t recvcounts, ompi_disp_array_t displs, struct ompi_datatype_t *recvtype, int root, struct ompi_communicator_t *comm
#define REDUCE_BASE_ARGS const void *sendbuf, void *recvbuf, size_t count, struct ompi_datatype_t *datatype, struct ompi_op_t *op, int root, struct ompi_communicator_t *comm
#define REDUCESCATTER_BASE_ARGS const void *sendbuf, void *recvbuf, const int recvcounts[], struct ompi_datatype_t *datatype, struct ompi_op_t *op, struct ompi_communicator_t *comm
#define REDUCESCATTER_BASE_ARGS const void *sendbuf, void *recvbuf, ompi_count_array_t recvcounts, struct ompi_datatype_t *datatype, struct ompi_op_t *op, struct ompi_communicator_t *comm
#define REDUCESCATTERBLOCK_BASE_ARGS const void *sendbuf, void *recvbuf, size_t recvcount, struct ompi_datatype_t *datatype, struct ompi_op_t *op, struct ompi_communicator_t *comm
#define SCAN_BASE_ARGS const void *sendbuf, void *recvbuf, size_t count, struct ompi_datatype_t *datatype, struct ompi_op_t *op, struct ompi_communicator_t *comm
#define SCATTER_BASE_ARGS const void *sendbuf, size_t sendcount, struct ompi_datatype_t *sendtype, void *recvbuf, size_t recvcount, struct ompi_datatype_t *recvtype, int root, struct ompi_communicator_t *comm
#define SCATTERV_BASE_ARGS const void *sendbuf, const int sendcounts[], const int displs[], struct ompi_datatype_t *sendtype, void *recvbuf, int recvcount, struct ompi_datatype_t *recvtype, int root, struct ompi_communicator_t *comm
#define SCATTERV_BASE_ARGS const void *sendbuf, ompi_count_array_t sendcounts, ompi_disp_array_t displs, struct ompi_datatype_t *sendtype, void *recvbuf, size_t recvcount, struct ompi_datatype_t *recvtype, int root, struct ompi_communicator_t *comm
#define NEIGHBOR_ALLGATHER_BASE_ARGS const void *sendbuf, size_t sendcount, struct ompi_datatype_t *sendtype, void *recvbuf, size_t recvcount, struct ompi_datatype_t *recvtype, struct ompi_communicator_t *comm
#define NEIGHBOR_ALLGATHERV_BASE_ARGS const void *sendbuf, int sendcount, struct ompi_datatype_t *sendtype, void *recvbuf, const int recvcounts[], const int displs[], struct ompi_datatype_t *recvtype, struct ompi_communicator_t *comm
#define NEIGHBOR_ALLGATHERV_BASE_ARGS const void *sendbuf, size_t sendcount, struct ompi_datatype_t *sendtype, void *recvbuf, ompi_count_array_t recvcounts, ompi_disp_array_t displs, struct ompi_datatype_t *recvtype, struct ompi_communicator_t *comm
#define NEIGHBOR_ALLTOALL_BASE_ARGS const void *sendbuf, size_t sendcount, struct ompi_datatype_t *sendtype, void *recvbuf, size_t recvcount, struct ompi_datatype_t *recvtype, struct ompi_communicator_t *comm
#define NEIGHBOR_ALLTOALLV_BASE_ARGS const void *sendbuf, const int sendcounts[], const int sdispls[], struct ompi_datatype_t *sendtype, void *recvbuf, const int recvcounts[], const int rdispls[], struct ompi_datatype_t *recvtype, struct ompi_communicator_t *comm
#define NEIGHBOR_ALLTOALLW_BASE_ARGS const void *sendbuf, const int sendcounts[], const MPI_Aint sdispls[], struct ompi_datatype_t * const sendtypes[], void *recvbuf, const int recvcounts[], const MPI_Aint rdispls[], struct ompi_datatype_t * const recvtypes[], struct ompi_communicator_t *comm
#define NEIGHBOR_ALLTOALLV_BASE_ARGS const void *sendbuf, ompi_count_array_t sendcounts, ompi_disp_array_t sdispls, struct ompi_datatype_t *sendtype, void *recvbuf, ompi_count_array_t recvcounts, ompi_disp_array_t rdispls, struct ompi_datatype_t *recvtype, struct ompi_communicator_t *comm
#define NEIGHBOR_ALLTOALLW_BASE_ARGS const void *sendbuf, ompi_count_array_t sendcounts, ompi_disp_array_t sdispls, struct ompi_datatype_t * const sendtypes[], void *recvbuf, ompi_count_array_t recvcounts, ompi_disp_array_t rdispls, struct ompi_datatype_t * const recvtypes[], struct ompi_communicator_t *comm

#define ALLGATHER_ARGS ALLGATHER_BASE_ARGS, mca_coll_base_module_t *module
#define ALLGATHERV_ARGS ALLGATHERV_BASE_ARGS, mca_coll_base_module_t *module
Expand Down Expand Up @@ -227,7 +227,7 @@ int mca_coll_base_alltoall_intra_basic_inplace(const void *rbuf, size_t rcount,
/* AlltoAllV */
int ompi_coll_base_alltoallv_intra_pairwise(ALLTOALLV_ARGS);
int ompi_coll_base_alltoallv_intra_basic_linear(ALLTOALLV_ARGS);
int mca_coll_base_alltoallv_intra_basic_inplace(const void *rbuf, const int *rcounts, const int *rdisps,
int mca_coll_base_alltoallv_intra_basic_inplace(const void *rbuf, ompi_count_array_t rcounts, ompi_disp_array_t rdisps,
struct ompi_datatype_t *rdtype,
struct ompi_communicator_t *comm,
mca_coll_base_module_t *module); /* special version for INPLACE */
Expand Down
Loading
Loading