Skip to content

Commit

Permalink
Add missing large count array changes
Browse files Browse the repository at this point in the history
  • Loading branch information
hzhou committed Feb 3, 2021
1 parent a2a700d commit df18568
Show file tree
Hide file tree
Showing 10 changed files with 29 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@

int MPIR_Allgatherv_inter_remote_gather_local_bcast(const void *sendbuf, MPI_Aint sendcount,
MPI_Datatype sendtype, void *recvbuf,
const MPI_Aint * recvcounts, const int
*displs, MPI_Datatype recvtype,
const MPI_Aint * recvcounts, const MPI_Aint
* displs, MPI_Datatype recvtype,
MPIR_Comm * comm_ptr, MPIR_Errflag_t * errflag)
{
int remote_size, mpi_errno, root, rank;
Expand Down Expand Up @@ -94,7 +94,7 @@ int MPIR_Allgatherv_inter_remote_gather_local_bcast(const void *sendbuf, MPI_Ain

newcomm_ptr = comm_ptr->local_comm;

mpi_errno = MPIR_Type_indexed_impl(remote_size, recvcounts, displs, recvtype, &newtype);
mpi_errno = MPIR_Type_indexed_c_impl(remote_size, recvcounts, displs, recvtype, &newtype);
MPIR_ERR_CHECK(mpi_errno);

mpi_errno = MPIR_Type_commit_impl(&newtype);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ int MPIR_Iallgatherv_inter_sched_remote_gather_local_bcast(const void *sendbuf,

newcomm_ptr = comm_ptr->local_comm;

mpi_errno = MPIR_Type_indexed_impl(remote_size, recvcounts, displs, recvtype, &newtype);
mpi_errno = MPIR_Type_indexed_c_impl(remote_size, recvcounts, displs, recvtype, &newtype);
MPIR_ERR_CHECK(mpi_errno);

mpi_errno = MPIR_Type_commit_impl(&newtype);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ int MPIR_TSP_Iallreduce_sched_intra_recexch_reduce_scatter_recexch_allgatherv(co
int p_of_k, log_pofk, T;
int per_nbr_buffer = 0, rem = 0;
int nvtcs, sink_id, *recv_id = NULL, *vtcs = NULL;
int *send_id = NULL, *reduce_id = NULL, *cnts = NULL, *displs = NULL;
int *send_id = NULL, *reduce_id = NULL;
MPI_Aint *cnts = NULL, *displs = NULL;
bool in_step2 = false;
void *tmp_recvbuf;
void **step1_recvbuf = NULL;
Expand Down Expand Up @@ -108,8 +109,10 @@ int MPIR_TSP_Iallreduce_sched_intra_recexch_reduce_scatter_recexch_allgatherv(co
MPL_DBG_MSG_FMT(MPIR_DBG_COLL, VERBOSE, (MPL_DBG_FDEST, "Start Step2"));

if (in_step2) {
MPIR_CHKLMEM_MALLOC(cnts, int *, sizeof(int) * nranks, mpi_errno, "cnts", MPL_MEM_COLL);
MPIR_CHKLMEM_MALLOC(displs, int *, sizeof(int) * nranks, mpi_errno, "displs", MPL_MEM_COLL);
MPIR_CHKLMEM_MALLOC(cnts, MPI_Aint *, sizeof(MPI_Aint) * nranks, mpi_errno, "cnts",
MPL_MEM_COLL);
MPIR_CHKLMEM_MALLOC(displs, MPI_Aint *, sizeof(MPI_Aint) * nranks, mpi_errno, "displs",
MPL_MEM_COLL);
int idx = 0;
rem = nranks - p_of_k;

Expand Down
9 changes: 6 additions & 3 deletions src/mpi/coll/iallreduce/iallreduce_tsp_ring_algos.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ int MPIR_TSP_Iallreduce_sched_intra_ring(const void *sendbuf, void *recvbuf, MPI
int nranks, is_inplace, rank;
size_t extent;
MPI_Aint lb, true_extent;
int *cnts, *displs, recv_id, *reduce_id, nvtcs, vtcs;
MPI_Aint *cnts, *displs;
int recv_id, *reduce_id, nvtcs, vtcs;
int send_rank, recv_rank, total_count;
void *tmpbuf;
int tag;
Expand All @@ -41,8 +42,10 @@ int MPIR_TSP_Iallreduce_sched_intra_ring(const void *sendbuf, void *recvbuf, MPI
MPIR_Type_get_true_extent_impl(datatype, &lb, &true_extent);
extent = MPL_MAX(extent, true_extent);

MPIR_CHKLMEM_MALLOC(cnts, int *, nranks * sizeof(int), mpi_errno, "cnts", MPL_MEM_COLL);
MPIR_CHKLMEM_MALLOC(displs, int *, nranks * sizeof(int), mpi_errno, "displs", MPL_MEM_COLL);
MPIR_CHKLMEM_MALLOC(cnts, MPI_Aint *, nranks * sizeof(MPI_Aint), mpi_errno, "cnts",
MPL_MEM_COLL);
MPIR_CHKLMEM_MALLOC(displs, MPI_Aint *, nranks * sizeof(MPI_Aint), mpi_errno, "displs",
MPL_MEM_COLL);

for (i = 0; i < nranks; i++)
cnts[i] = 0;
Expand Down
6 changes: 3 additions & 3 deletions src/mpi/coll/ibcast/ibcast_tsp_scatterv_allgatherv_algos.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ int MPIR_TSP_Ibcast_sched_intra_scatterv_allgatherv(void *buffer, MPI_Aint count
int size, rank, tag;
int i, j, x, is_contig;
void *tmp_buf = NULL;
int *cnts, *displs;
MPI_Aint *cnts, *displs;
size_t nbytes;
int tree_type;
MPIR_Treealgo_tree_t my_tree, parents_tree;
Expand Down Expand Up @@ -56,8 +56,8 @@ int MPIR_TSP_Ibcast_sched_intra_scatterv_allgatherv(void *buffer, MPI_Aint count
extent = MPL_MAX(extent, true_extent);

nbytes = type_size * count;
MPIR_CHKLMEM_MALLOC(cnts, int *, sizeof(int) * size, mpi_errno, "cnts", MPL_MEM_COLL); /* to store counts of each rank */
MPIR_CHKLMEM_MALLOC(displs, int *, sizeof(int) * size, mpi_errno, "displs", MPL_MEM_COLL); /* to store displs of each rank */
MPIR_CHKLMEM_MALLOC(cnts, MPI_Aint *, sizeof(MPI_Aint) * size, mpi_errno, "cnts", MPL_MEM_COLL); /* to store counts of each rank */
MPIR_CHKLMEM_MALLOC(displs, MPI_Aint *, sizeof(MPI_Aint) * size, mpi_errno, "displs", MPL_MEM_COLL); /* to store displs of each rank */

total_count = 0;
for (i = 0; i < size; i++)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ int MPIR_Ireduce_scatter_inter_sched_remote_reduce_local_scatterv(const void *se
int rank, root, local_size, total_count, i;
MPI_Aint true_extent, true_lb = 0, extent;
void *tmp_buf = NULL;
int *disps = NULL;
MPI_Aint *disps = NULL;
MPIR_Comm *newcomm_ptr = NULL;
MPIR_SCHED_CHKPMEM_DECL(2);

Expand All @@ -38,8 +38,8 @@ int MPIR_Ireduce_scatter_inter_sched_remote_reduce_local_scatterv(const void *se
/* In each group, rank 0 allocates a temp. buffer for the
* reduce */

MPIR_SCHED_CHKPMEM_MALLOC(disps, int *, local_size * sizeof(int), mpi_errno, "disps",
MPL_MEM_BUFFER);
MPIR_SCHED_CHKPMEM_MALLOC(disps, MPI_Aint *, local_size * sizeof(MPI_Aint), mpi_errno,
"disps", MPL_MEM_BUFFER);

total_count = 0;
for (i = 0; i < local_size; i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ int MPIR_TSP_Ireduce_scatter_sched_intra_recexch(const void *sendbuf, void *recv
int dtcopy_id = -1, recv_id = -1, reduce_id = -1, sink_id = -1;
int nvtcs, vtcs[2];
void *tmp_recvbuf = NULL, *tmp_results = NULL;
int *displs;
MPI_Aint *displs;
int tag;
MPIR_CHKLMEM_DECL(1);

Expand Down Expand Up @@ -174,7 +174,7 @@ int MPIR_TSP_Ireduce_scatter_sched_intra_recexch(const void *sendbuf, void *recv
return mpi_errno;
}

MPIR_CHKLMEM_MALLOC(displs, int *, nranks * sizeof(int),
MPIR_CHKLMEM_MALLOC(displs, MPI_Aint *, nranks * sizeof(MPI_Aint),
mpi_errno, "displs buffer", MPL_MEM_COLL);
displs[0] = 0;
for (i = 1; i < nranks; i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ int MPIR_Reduce_scatter_inter_remote_reduce_local_scatter(const void *sendbuf, v
int mpi_errno_ret = MPI_SUCCESS;
MPI_Aint true_extent, true_lb = 0, extent;
void *tmp_buf = NULL;
int *disps = NULL;
MPI_Aint *disps = NULL;
MPIR_Comm *newcomm_ptr = NULL;
MPIR_CHKLMEM_DECL(2);

Expand All @@ -37,7 +37,7 @@ int MPIR_Reduce_scatter_inter_remote_reduce_local_scatter(const void *sendbuf, v
/* In each group, rank 0 allocates a temp. buffer for the
* reduce */

MPIR_CHKLMEM_MALLOC(disps, int *, local_size * sizeof(int), mpi_errno, "disps",
MPIR_CHKLMEM_MALLOC(disps, MPI_Aint *, local_size * sizeof(MPI_Aint), mpi_errno, "disps",
MPL_MEM_BUFFER);

total_count = 0;
Expand Down
2 changes: 1 addition & 1 deletion src/mpi/coll/src/csel.c
Original file line number Diff line number Diff line change
Expand Up @@ -831,7 +831,7 @@ static inline bool is_block_regular(MPIR_Csel_coll_sig_s coll_info)
{
bool is_regular = true;
int i = 0;
const int *recvcounts = NULL;
const MPI_Aint *recvcounts = NULL;

switch (coll_info.coll_type) {
case MPIR_CSEL_COLL_TYPE__REDUCE_SCATTER:
Expand Down
4 changes: 2 additions & 2 deletions src/mpid/common/bc/mpidu_bc.c
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ int MPIDU_bc_allgather(MPIR_Comm * allgather_comm, void *bc, int bc_len, int sam
return mpi_errno;
}

int *recv_cnts = MPL_calloc(num_nodes, sizeof(int), MPL_MEM_OTHER);
int *recv_offs = MPL_calloc(num_nodes, sizeof(int), MPL_MEM_OTHER);
MPI_Aint *recv_cnts = MPL_calloc(num_nodes, sizeof(MPI_Aint), MPL_MEM_OTHER);
MPI_Aint *recv_offs = MPL_calloc(num_nodes, sizeof(MPI_Aint), MPL_MEM_OTHER);
for (i = 0; i < size; i++) {
int node_id = MPIR_Process.node_map[i];
recv_cnts[node_id]++;
Expand Down

0 comments on commit df18568

Please sign in to comment.