Skip to content

Commit

Permalink
Make concurrent
Browse files Browse the repository at this point in the history
  • Loading branch information
Lior Paz committed Jan 10, 2021
1 parent 6f884f6 commit 600142c
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 13 deletions.
20 changes: 12 additions & 8 deletions src/team_lib/mhba/xccl_mhba_collective.c
Original file line number Diff line number Diff line change
Expand Up @@ -227,11 +227,12 @@ static xccl_status_t xccl_mhba_asr_barrier_start(xccl_coll_task_t *task)

task->state = XCCL_TASK_STATE_COMPLETED;

team->inter_node_barrier[team->net.sbgp->group_rank] = request->seq_num;
int index = team->net.sbgp->group_rank+ team->net.sbgp->group_size*SEQ_INDEX(request->seq_num);
team->inter_node_barrier[index] = request->seq_num;
for(i=0; i<team->net.net_size;i++){
xccl_status_t status = send_block_data(team->net.qps[i], (uintptr_t)team->inter_node_barrier_mr->addr+team->net.sbgp->group_rank*sizeof(int) , sizeof(int),
xccl_status_t status = send_block_data(team->net.qps[i], (uintptr_t)team->inter_node_barrier_mr->addr+index*sizeof(int), sizeof(int),
team->inter_node_barrier_mr->lkey,
team->net.remote_ctrl[i].barrier_addr+sizeof(int)*team->net.sbgp->group_rank, team->net.remote_ctrl[i].barrier_rkey, 0, 0);
(uintptr_t)team->net.remote_ctrl[i].barrier_addr+sizeof(int)*index, team->net.remote_ctrl[i].barrier_rkey, 0, 0);
if (status != XCCL_OK) {
xccl_mhba_error("Failed sending barrier notice");
return status;
Expand Down Expand Up @@ -331,6 +332,7 @@ xccl_mhba_send_blocks_start_with_transpose(xccl_coll_task_t *task)
int block_size = request->block_size;
int col_msgsize = len * block_size * node_size;
int block_msgsize = SQUARED(block_size) * len;
int barr_index = index*MAX_OUTSTANDING_OPS;
int i, j, k, dest_rank, rank, n_compl, ret;
uint64_t src_addr, remote_addr;
struct ibv_wc transpose_completion[1];
Expand All @@ -343,8 +345,9 @@ xccl_mhba_send_blocks_start_with_transpose(xccl_coll_task_t *task)

while(counter < net_size) {
for (i = 0; i < net_size; i++) {
if (team->inter_node_barrier[i] == request->seq_num && !team->inter_node_barrier_flag[i]) {
team->inter_node_barrier_flag[i] = 1;
if (team->inter_node_barrier[i+barr_index] == request->seq_num &&
!team->inter_node_barrier_flag[i+barr_index]) {
team->inter_node_barrier_flag[i+barr_index] = 1;
dest_rank = team->net.rank_map[i];
//send all blocks from curr node to some ARR
for (j = 0; j < xccl_round_up(node_size, block_size); j++) {
Expand Down Expand Up @@ -428,6 +431,7 @@ static xccl_status_t xccl_mhba_send_blocks_start(xccl_coll_task_t *task)
int block_size = request->block_size;
int col_msgsize = len * block_size * node_size;
int block_msgsize = SQUARED(block_size) * len;
int barr_index = index*MAX_OUTSTANDING_OPS;
int i, j, k, dest_rank, rank;
int counter = 0;
uint64_t src_addr, remote_addr;
Expand All @@ -439,8 +443,8 @@ static xccl_status_t xccl_mhba_send_blocks_start(xccl_coll_task_t *task)

while(counter < net_size) {
for (i = 0; i < net_size; i++) {
if (team->inter_node_barrier[i] == request->seq_num && !team->inter_node_barrier_flag[i]) {
team->inter_node_barrier_flag[i] = 1;
if (team->inter_node_barrier[i +barr_index] == request->seq_num && !team->inter_node_barrier_flag[i +barr_index]) {
team->inter_node_barrier_flag[i+barr_index] = 1;
dest_rank = team->net.rank_map[i];
//send all blocks from curr node to some ARR
for (j = 0; j < xccl_round_up(node_size, block_size); j++) {
Expand Down Expand Up @@ -590,7 +594,7 @@ xccl_status_t xccl_mhba_alltoall_init(xccl_coll_op_args_t *coll_args,
xccl_mhba_fanout_start;
request->tasks[1].super.progress = xccl_mhba_fanout_progress;
} else {
memset(team->inter_node_barrier_flag,0,sizeof(int)*team->net.net_size);
memset(&team->inter_node_barrier_flag[MAX_OUTSTANDING_OPS*SEQ_INDEX(request->seq_num)],0,sizeof(int)*team->net.net_size);
request->tasks[1].super.handlers[XCCL_EVENT_COMPLETED] =
xccl_mhba_asr_barrier_start;
request->tasks[1].super.progress = xccl_mhba_asr_barrier_progress;
Expand Down
10 changes: 5 additions & 5 deletions src/team_lib/mhba/xccl_mhba_team.c
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ xccl_status_t xccl_mhba_team_create_post(xccl_tl_context_t *context,
}
// for each ASR - qp num, in addition to port lid, ctrl segment rkey and address, recieve mkey rkey
local_data_size = (net_size * sizeof(uint32_t)) + sizeof(uint32_t) +
3 * sizeof(uint32_t) + 2*sizeof(void *); //todo make concurrent
3 * sizeof(uint32_t) + 2*sizeof(void *);
local_data = malloc(local_data_size);
if (!local_data) {
xccl_mhba_error("failed to allocate local data");
Expand Down Expand Up @@ -361,17 +361,17 @@ xccl_status_t xccl_mhba_team_create_post(xccl_tl_context_t *context,

local_data[net_size + 4] = mhba_team->node.team_recv_mkey->rkey;

mhba_team->inter_node_barrier = (int*) malloc(sizeof(int)*net_size);
mhba_team->inter_node_barrier_flag = (int*) malloc(sizeof(int)*net_size);
mhba_team->inter_node_barrier = (int*) malloc(sizeof(int)*net_size*MAX_OUTSTANDING_OPS);
mhba_team->inter_node_barrier_flag = (int*) malloc(sizeof(int)*net_size*MAX_OUTSTANDING_OPS);
if(!mhba_team->inter_node_barrier || !mhba_team->inter_node_barrier_flag){
xccl_mhba_error("Failed to malloc");
goto barrier_alloc_failure;
}
for(i=0;i<net_size;i++){
for(i=0;i<net_size*MAX_OUTSTANDING_OPS;i++){
mhba_team->inter_node_barrier[i] = -1;
}
mhba_team->inter_node_barrier_mr = ibv_reg_mr(mhba_team->node.shared_pd, mhba_team->inter_node_barrier,
sizeof(uint32_t)*net_size, IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE);
sizeof(int)*net_size*MAX_OUTSTANDING_OPS, IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE);
if (!mhba_team->inter_node_barrier_mr) {
xccl_mhba_error("Failed to register memory");
goto barrier_alloc_failure;
Expand Down

0 comments on commit 600142c

Please sign in to comment.