Skip to content

Commit

Permalink
TEAM/UCX: don't use event query with dynamic kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
Sergei-Lebedev committed Jan 16, 2021
1 parent f4f2082 commit 4506b08
Show file tree
Hide file tree
Showing 11 changed files with 220 additions and 87 deletions.
5 changes: 4 additions & 1 deletion src/api/xccl_status.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,15 @@ typedef enum {
/* Operation is queued and still in progress */
XCCL_INPROGRESS = 1,

/* Operation is queued but has not started yet*/
XCCL_INITIALIZED = 2,

/* Failure codes */
XCCL_ERR_NO_MESSAGE = -1,
XCCL_ERR_NO_RESOURCE = -2,
XCCL_ERR_NO_MEMORY = -4,
XCCL_ERR_INVALID_PARAM = -5,
XCCL_ERR_UNREACHABLE = -6,
XCCL_ERR_UNREACHABLE = -6,
XCCL_ERR_NOT_IMPLEMENTED = -8,
XCCL_ERR_MESSAGE_TRUNCATED = -9,
XCCL_ERR_NO_PROGRESS = -10,
Expand Down
14 changes: 13 additions & 1 deletion src/team_lib/nccl/xccl_nccl_collective.c
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,19 @@ xccl_nccl_collective_init_base(xccl_coll_op_args_t *coll_args,

(*request)->team = nccl_team;
(*request)->super.lib = &xccl_team_lib_nccl.super;
CUDACHECK(cudaEventCreateWithFlags(&((*request)->completed), cuda_event_flags));

switch((TEAM_NCCL_CTX_REQ(*request)->completion_sync)) {
case XCCL_NCCL_COMPLETION_SYNC_EVENT:
((*request)->completed) = (void*)0x1;
break;
case XCCL_NCCL_COMPLETION_SYNC_CALLBACK:
((*request)->completed) = NULL;
break;
default:
xccl_nccl_error("wrong completion sync type");
free(*request);
return XCCL_ERR_INVALID_PARAM;
}

return XCCL_OK;
}
Expand Down
13 changes: 11 additions & 2 deletions src/team_lib/nccl/xccl_nccl_collective.h
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
/*
* Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED.
* Copyright (C) Mellanox Technologies Ltd. 2020-2021. ALL RIGHTS RESERVED.
* See file LICENSE for terms.
*/

#ifndef XCCL_NCCL_COLLECTIVE_H_
#define XCCL_NCCL_COLLECTIVE_H_

#include <xccl_nccl_lib.h>
#include <utils/mem_component.h>
#include <nccl.h>

#define ncclOpUnsupported (ncclNumOps + 1)
Expand All @@ -19,7 +20,8 @@ typedef struct xccl_nccl_coll_req {
xccl_coll_op_args_t args;
xccl_nccl_team_t *team;
xccl_nccl_collective_start_fn coll_start;
cudaEvent_t completed;
xccl_mc_event_t *completed;
xccl_status_t status;
} xccl_nccl_coll_req_t;

xccl_status_t
Expand All @@ -46,5 +48,12 @@ xccl_status_t
xccl_nccl_allgather_init(xccl_coll_op_args_t *coll_args,
xccl_nccl_coll_req_t *request,
xccl_nccl_team_t *team);
xccl_status_t
xccl_nccl_bcast_init(xccl_coll_op_args_t *coll_args,
xccl_nccl_coll_req_t *request,
xccl_nccl_team_t *team);

#define TEAM_NCCL_CTX_REQ(_req) \
(ucs_derived_of((_req)->team->super.ctx, xccl_nccl_context_t))

#endif
90 changes: 60 additions & 30 deletions src/team_lib/nccl/xccl_nccl_lib.c
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,25 @@ static ucs_config_field_t xccl_team_lib_nccl_config_table[] = {
{NULL}
};

const char* xccl_nccl_sync_names[] = {
[XCCL_NCCL_COMPLETION_SYNC_EVENT] = "event",
[XCCL_NCCL_COMPLETION_SYNC_CALLBACK] = "callback",
};


static ucs_config_field_t xccl_tl_nccl_context_config_table[] = {
{"", "",
NULL,
ucs_offsetof(xccl_tl_nccl_context_config_t, super),
UCS_CONFIG_TYPE_TABLE(xccl_tl_context_config_table)
},

{"SYNC", "event",
"Determines how XCCL tests completion of NCCL collective",
ucs_offsetof(xccl_tl_nccl_context_config_t, completion_sync),
UCS_CONFIG_TYPE_ENUM(xccl_nccl_sync_names)
},

{NULL}
};

Expand Down Expand Up @@ -131,7 +143,15 @@ xccl_nccl_context_create(xccl_team_lib_h lib, xccl_context_params_t *params,
xccl_tl_context_t **context)
{
xccl_nccl_context_t *ctx = malloc(sizeof(*ctx));
xccl_tl_nccl_context_config_t *tl_config;

if (ctx == NULL) {
xccl_nccl_error("failed to allocate memory for nccl context");
return XCCL_ERR_NO_MEMORY;
}
tl_config = ucs_derived_of(config, xccl_tl_nccl_context_config_t);
ctx->completion_sync = tl_config->completion_sync;
xccl_nccl_debug("sync type: %s", xccl_nccl_sync_names[ctx->completion_sync]);
XCCL_CONTEXT_SUPER_INIT(ctx->super, lib, params);
*context = &ctx->super;

Expand Down Expand Up @@ -271,63 +291,73 @@ xccl_nccl_collective_init(xccl_coll_op_args_t *coll_args,
return XCCL_OK;
}

static void nccl_completion_callback(void *request) {
xccl_nccl_coll_req_t *req = ucs_derived_of(request, xccl_nccl_coll_req_t);
req->status = XCCL_OK;
}

static xccl_status_t
xccl_nccl_collective_post(xccl_tl_coll_req_t *request)
{
xccl_nccl_coll_req_t *req = ucs_derived_of(request, xccl_nccl_coll_req_t);
cudaStream_t *stream;
xccl_status_t st;
cudaStream_t *stream;

st = req->coll_start(request);
if (st != XCCL_OK) {
xccl_nccl_error("collective start failed %d", st);
return st;
}
stream = (cudaStream_t*)req->args.stream.stream;
CUDACHECK(cudaEventRecord(req->completed, *stream));

return XCCL_OK;
if (req->completed != NULL) {
st = xccl_mc_event_record(&req->args.stream, &req->completed);
} else {
stream = (cudaStream_t*)req->args.stream.stream;
req->status = XCCL_INPROGRESS;
CUDACHECK(cudaLaunchHostFunc(*stream, nccl_completion_callback, req));
st = XCCL_OK;
}

return st;
}

static xccl_status_t
xccl_nccl_collective_wait(xccl_tl_coll_req_t *request)
xccl_nccl_collective_test(xccl_tl_coll_req_t *request)
{
xccl_nccl_coll_req_t *req = ucs_derived_of(request, xccl_nccl_coll_req_t);
cudaError_t cuda_st;
xccl_status_t st;

CUDACHECK(cudaEventSynchronize(req->completed));
if (req->completed != NULL) {
/* use event to determine collective status */
req->status = xccl_mc_event_query(req->completed);
if (req->status != XCCL_INPROGRESS) {
st = xccl_mc_event_free(req->completed);
req->completed = NULL;
if (st != XCCL_OK) {
return st;
}
}
}

return XCCL_OK;
return req->status;
}

static xccl_status_t
xccl_nccl_collective_test(xccl_tl_coll_req_t *request)
xccl_nccl_collective_wait(xccl_tl_coll_req_t *request)
{
xccl_nccl_coll_req_t *req = ucs_derived_of(request, xccl_nccl_coll_req_t);
cudaError_t cuda_st;

cuda_st = cudaEventQuery(req->completed);
switch(cuda_st) {
case cudaSuccess:
return XCCL_OK;
case cudaErrorNotReady:
return XCCL_INPROGRESS;
default:
return XCCL_ERR_NO_MESSAGE;
}
xccl_status_t st;

do {
st = xccl_nccl_collective_test(request);
} while (st == XCCL_INPROGRESS);

return st;
}

static xccl_status_t
xccl_nccl_collective_finalize(xccl_tl_coll_req_t *request)
{
xccl_nccl_coll_req_t *req = ucs_derived_of(request, xccl_nccl_coll_req_t);

if (cudaEventQuery(req->completed) != cudaSuccess) {
xccl_nccl_error("calling collective finalize before collective is done");
return XCCL_ERR_NO_MESSAGE;
}

CUDACHECK(cudaEventDestroy(req->completed));
free(req);
free(request);

return XCCL_OK;
}
Expand Down
25 changes: 16 additions & 9 deletions src/team_lib/nccl/xccl_nccl_lib.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,24 @@
#include <cuda.h>
#include "xccl_team_lib.h"

typedef enum xccl_nccl_completion_sync_type {
XCCL_NCCL_COMPLETION_SYNC_EVENT,
XCCL_NCCL_COMPLETION_SYNC_CALLBACK
} xccl_nccl_completion_sync_type_t;

typedef struct xccl_team_lib_nccl_config {
xccl_team_lib_config_t super;
int enable_allreduce;
int enable_alltoall;
int enable_alltoallv;
int enable_allgather;
int enable_bcast;
xccl_team_lib_config_t super;
int enable_allreduce;
int enable_alltoall;
int enable_alltoallv;
int enable_allgather;
int enable_bcast;
} xccl_team_lib_nccl_config_t;

typedef struct xccl_tl_nccl_context_config {
xccl_tl_context_config_t super;
char *device;
xccl_tl_context_config_t super;
char *device;
xccl_nccl_completion_sync_type_t completion_sync;
} xccl_tl_nccl_context_config_t;

typedef struct xccl_team_lib_nccl {
Expand All @@ -48,7 +54,8 @@ extern xccl_team_lib_nccl_t xccl_team_lib_nccl;
#define xccl_nccl_trace_poll(_fmt, ...) xccl_team_nccl_log_component(UCS_LOG_LEVEL_TRACE_POLL, _fmt, ## __VA_ARGS__)

typedef struct xccl_nccl_context {
xccl_tl_context_t super;
xccl_tl_context_t super;
xccl_nccl_completion_sync_type_t completion_sync;
} xccl_nccl_context_t;

typedef struct xccl_nccl_team {
Expand Down
74 changes: 50 additions & 24 deletions src/team_lib/ucx/xccl_ucx_lib.c
Original file line number Diff line number Diff line change
Expand Up @@ -511,23 +511,37 @@ static xccl_status_t xccl_ucx_collective_post(xccl_tl_coll_req_t *request)

req->stream_req = NULL;
if (req->args.field_mask & XCCL_COLL_OP_ARGS_FIELD_STREAM) {
st = xccl_mc_event_record(&req->args.stream, &req->ready_to_start);
if (st != XCCL_OK) {
return st;
}
if (TEAM_UCX_CTX_REQ(req)->block_stream) {
xccl_mem_component_start_acitivity(&req->args.stream,
&req->stream_req);
}
st = xccl_mc_event_query(req->ready_to_start);
if (st == XCCL_INPROGRESS) {
/* collective is not ready to start, start it later*/
return XCCL_OK;
}
if (st != XCCL_OK) {
return st;
st = xccl_mem_component_start_acitivity(&req->args.stream,
&req->stream_req);
if (st != XCCL_OK) {
return st;
}
st = xccl_mem_component_query_activity(req->stream_req);
if (st == XCCL_INITIALIZED) {
/* collective is not ready to start, start it later*/
/* assign dummy value to ready_to_start*/
req->ready_to_start = (void*)0x1;
return XCCL_OK;
}
if (st != XCCL_INPROGRESS) {
return st;
}
} else {
st = xccl_mc_event_record(&req->args.stream, &req->ready_to_start);
if (st != XCCL_OK) {
return st;
}
st = xccl_mc_event_query(req->ready_to_start);
if (st == XCCL_INPROGRESS) {
/* collective is not ready to start, start it later*/
return XCCL_OK;
}
if (st != XCCL_OK) {
return st;
}
xccl_mc_event_free(req->ready_to_start);
}
xccl_mc_event_free(req->ready_to_start);
}
req->ready_to_start = NULL;
return req->start(req);
Expand All @@ -539,23 +553,35 @@ static xccl_status_t xccl_ucx_collective_test(xccl_tl_coll_req_t *request)
xccl_status_t status;

if (req->ready_to_start != NULL) {
status = xccl_mc_event_query(req->ready_to_start);
if (status != XCCL_OK) {
return status;
if (TEAM_UCX_CTX_REQ(req)->block_stream) {
status = xccl_mem_component_query_activity(req->stream_req);
/* status can't be XCCL_OK since collective wasn't started*/
assert(status != XCCL_OK);
if (status == XCCL_INITIALIZED) {
return XCCL_INPROGRESS;
} else if (status != XCCL_INPROGRESS) {
/* error */
return status;
}
} else {
status = xccl_mc_event_query(req->ready_to_start);
if (status != XCCL_OK) {
return status;
}
xccl_mc_event_free(req->ready_to_start);
}
xccl_mc_event_free(req->ready_to_start);
req->ready_to_start = NULL;
req->start(req);
}
if (XCCL_INPROGRESS == req->complete) {
if (XCCL_OK != (status = req->progress(req))) {
return status;
};
if ((XCCL_INPROGRESS != req->complete) &&
(req->stream_req != NULL)) {
xccl_mem_component_finish_acitivity(req->stream_req);
req->stream_req = NULL;
}
}
if ((XCCL_INPROGRESS != req->complete) &&
(req->stream_req != NULL)) {
xccl_mem_component_finish_acitivity(req->stream_req);
req->stream_req = NULL;
}

return req->complete;
Expand Down
Loading

0 comments on commit 4506b08

Please sign in to comment.