Skip to content

Commit

Permalink
coll: use MPI_Aint for v-collective array parameters
Browse files Browse the repository at this point in the history
Also added python code to generate MPI_Aint impl prototypes and generate
code to do counts array swap before calling MPIR_Xxx
  • Loading branch information
hzhou committed Feb 12, 2021
1 parent 578ea61 commit fe78ecd
Show file tree
Hide file tree
Showing 129 changed files with 1,625 additions and 1,347 deletions.
104 changes: 102 additions & 2 deletions maint/local_python/binding_c.py
Original file line number Diff line number Diff line change
Expand Up @@ -888,6 +888,7 @@ def push_impl_decl(func, impl_name=None):
if func['dir'] == 'coll':
# All collective impl function use MPI_Aint counts
params = re.sub(r' int (count|sendcount|recvcount),', r' MPI_Aint \1,', params)
params = re.sub(r' int (sendcounts|recvcounts|[sr]?displs)\[', r' MPI_Aint \1[', params)
# block collective use an extra errflag
if not RE.match(r'MPI_(I|Neighbor)', func['name']):
params = params + ", MPIR_Errflag_t *errflag"
Expand Down Expand Up @@ -916,10 +917,106 @@ def dump_CHECKENUM(var, errname, t, type="ENUM"):
dump_if_close()

def dump_body_coll(func):
# --- routines for swaping counts array due to int <-> MPI_Aint
def allocate_tmp_array(n):
G.out.append("MPI_Aint *tmp_array = MPL_malloc(%s * sizeof(MPI_Aint), MPL_MEM_OTHER);" % n)
def swap_one(n, counts):
G.out.append("for (int i = 0; i < %s; i++) {" % n)
G.out.append(" tmp_array[i] = %s[i];" % counts)
G.out.append("}")
def swap_next(base, n, counts):
G.out.append("for (int i = 0; i < %s; i++) {" % n)
G.out.append(" tmp_array[%s + i] = %s[i];" % (base, counts))
G.out.append("}")

def dump_v_swap(func):
args = ", ".join(func['impl_arg_list'])
if RE.match(r'mpi_i?neighbor_', func['name'], re.IGNORECASE):
# neighborhood collectives
G.out.append("int indegree, outdegree, weighted;")
G.out.append("mpi_errno = MPIR_Topo_canon_nhb_count(comm_ptr, &indegree, &outdegree, &weighted);")
if RE.search(r'allgatherv', func['name'], re.IGNORECASE):
allocate_tmp_array("indegree * 2")
swap_one("indegree", "recvcounts")
swap_next("indegree", "indegree", "displs")
args = re.sub(r'recvcounts', 'tmp_array', args)
args = re.sub(r'displs', 'tmp_array + indegree', args)
elif RE.search(r'alltoallv', func['name'], re.IGNORECASE):
allocate_tmp_array("(outdegree + indegree) * 2")
swap_one("outdegree", "sendcounts")
swap_next("outdegree", "outdegree", "sdispls")
swap_next("outdegree * 2", "indegree", "recvcounts")
swap_next("outdegree * 2 + indegree", "indegree", "rdispls")
args = re.sub(r'sendcounts', 'tmp_array', args)
args = re.sub(r'sdispls', 'tmp_array + outdegree', args)
args = re.sub(r'recvcounts', 'tmp_array + outdegree * 2', args)
args = re.sub(r'rdispls', 'tmp_array + outdegree * 2 + indegree', args)
else: # neighbor_alltoallw
allocate_tmp_array("(outdegree + indegree)")
swap_one("indegree", "sendcounts")
swap_next("indegree", "outdegree", "recvcounts")
args = re.sub(r'sendcounts', 'tmp_array', args)
args = re.sub(r'recvcounts', 'tmp_array + indegree', args)
# classical collectives
elif RE.match(r'mpi_i?reduce_scatter\b', func['name'], re.IGNORECASE):
G.out.append("int n = comm_ptr->local_size;")
allocate_tmp_array("n")
swap_one("n", "recvcounts")
args = re.sub(r'recvcounts', 'tmp_array', args)
else:
cond = "(comm_ptr->comm_kind == MPIR_COMM_KIND__INTERCOMM)"
G.out.append("int n = %s ? comm_ptr->remote_size : comm_ptr->local_size;" % cond)
if RE.search(r'alltoall[vw]', func['name'], re.IGNORECASE):
allocate_tmp_array("n * 4")
dump_if_open("sendbuf != MPI_IN_PLACE")
swap_one("n", "sendcounts")
swap_next("n", "n", "sdispls")
dump_if_close()
swap_next("n * 2", "n", "recvcounts")
swap_next("n * 3", "n", "rdispls")
args = re.sub(r'sendcounts', 'tmp_array', args)
args = re.sub(r'sdispls', 'tmp_array + n', args)
args = re.sub(r'recvcounts', 'tmp_array + n * 2', args)
args = re.sub(r'rdispls', 'tmp_array + n * 3', args)
elif RE.search(r'allgatherv', func['name'], re.IGNORECASE):
allocate_tmp_array("n * 2")
swap_one("n", "recvcounts")
swap_next("n", "n", "displs")
args = re.sub(r'recvcounts', 'tmp_array', args)
args = re.sub(r'displs', 'tmp_array + n', args)
else:
allocate_tmp_array("n * 2")
if RE.search(r'scatterv', func['name'], re.IGNORECASE):
counts = "sendcounts"
else: # gatherv
counts = "recvcounts"
# only root need the v-array
cond_intra = "comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM"
cond_inter = "comm_ptr->comm_kind == MPIR_COMM_KIND__INTERCOMM"
cond_a = cond_intra + " && comm_ptr->rank == root"
cond_b = cond_inter + " && root == MPI_ROOT"

dump_if_open("(%s) || (%s)" % (cond_a, cond_b))
swap_one("n", counts)
swap_next("n", "n", "displs")
dump_if_close()

args = re.sub(counts, 'tmp_array', args)
args = re.sub(r'displs', 'tmp_array + n', args)
return args

def dump_v_exit(func):
G.out.append("MPL_free(tmp_array);")

# -------------------------
# collectives call MPIR_Xxx
mpir_name = re.sub(r'^MPIX?_', 'MPIR_', func['name'])

args = ", ".join(func['impl_arg_list'])
if RE.search(r'((all)?gatherv|scatterv|alltoall[vw]|reduce_scatter\b)', func['name'], re.IGNORECASE):
args = dump_v_swap(func)
else:
args = ", ".join(func['impl_arg_list'])

if RE.match(r'mpi_i', func['name'], re.IGNORECASE):
# non-blocking collectives
G.out.append("MPIR_Request *request_ptr = NULL;")
Expand All @@ -936,6 +1033,9 @@ def dump_body_coll(func):
G.out.append("MPIR_Errflag_t errflag = MPIR_ERR_NONE;")
dump_line_with_break("mpi_errno = %s(%s, &errflag);" % (mpir_name, args))

if RE.search(r'((all)?gatherv|scatterv|alltoall[vw]|reduce_scatter\b)', func['name'], re.IGNORECASE):
dump_v_exit(func)

def dump_body_topo_fns(func, method):
comm_ptr = func['_has_comm'] + "_ptr"
dump_if_open("%s->topo_fns && %s->topo_fns->%s" % (comm_ptr, comm_ptr, method))
Expand Down Expand Up @@ -1716,7 +1816,7 @@ def dump_validate_op(op, dt):

def dump_validate_get_comm_size(func):
if '_got_comm_size' not in func:
if RE.match(r'mpi_i?reduce_scatter', func['name'], re.IGNORECASE):
if RE.match(r'mpi_i?reduce_scatter\b', func['name'], re.IGNORECASE):
G.out.append("int comm_size = comm_ptr->local_size;")
else:
G.out.append("int comm_size;")
Expand Down
Loading

0 comments on commit fe78ecd

Please sign in to comment.