diff --git a/src/api/xccl_status.h b/src/api/xccl_status.h
index b35b0c2..fda1847 100644
--- a/src/api/xccl_status.h
+++ b/src/api/xccl_status.h
@@ -14,12 +14,15 @@ typedef enum {
     /* Operation is queued and still in progress */
     XCCL_INPROGRESS             = 1,
 
+    /* Operation is queued but has not started yet */
+    XCCL_INITIALIZED            = 2,
+
     /* Failure codes */
     XCCL_ERR_NO_MESSAGE         = -1,
     XCCL_ERR_NO_RESOURCE        = -2,
     XCCL_ERR_NO_MEMORY          = -4,
     XCCL_ERR_INVALID_PARAM      = -5,
-    XCCL_ERR_UNREACHABLE        = -6, 
+    XCCL_ERR_UNREACHABLE        = -6,
     XCCL_ERR_NOT_IMPLEMENTED    = -8,
     XCCL_ERR_MESSAGE_TRUNCATED  = -9,
     XCCL_ERR_NO_PROGRESS        = -10,
diff --git a/src/team_lib/nccl/xccl_nccl_collective.c b/src/team_lib/nccl/xccl_nccl_collective.c
index 62fdb80..227a1c6 100644
--- a/src/team_lib/nccl/xccl_nccl_collective.c
+++ b/src/team_lib/nccl/xccl_nccl_collective.c
@@ -32,7 +32,19 @@ xccl_nccl_collective_init_base(xccl_coll_op_args_t *coll_args,
     (*request)->team      = nccl_team;
     (*request)->super.lib = &xccl_team_lib_nccl.super;
 
-    CUDACHECK(cudaEventCreateWithFlags(&((*request)->completed), cuda_event_flags));
+
+    switch((TEAM_NCCL_CTX_REQ(*request)->completion_sync)) {
+    case XCCL_NCCL_COMPLETION_SYNC_EVENT:
+        ((*request)->completed) = (void*)0x1;
+        break;
+    case XCCL_NCCL_COMPLETION_SYNC_CALLBACK:
+        ((*request)->completed) = NULL;
+        break;
+    default:
+        xccl_nccl_error("wrong completion sync type");
+        free(*request);
+        return XCCL_ERR_INVALID_PARAM;
+    }
 
     return XCCL_OK;
 }
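
Note: after this hunk, (*request)->completed does double duty as the sync-mode
discriminator. NULL selects the callback path, while the placeholder (void*)0x1
marks event mode before any event has been recorded (xccl_nccl_collective_post()
later replaces it through xccl_mc_event_record()). A sketch of the encoding as a
hypothetical helper, purely illustrative and not part of this patch:

    /* Hypothetical helper, not in this patch: documents the sentinel
     * encoding of req->completed introduced above. */
    static inline int nccl_req_uses_event_sync(const xccl_nccl_coll_req_t *req)
    {
        /* NULL       -> callback mode: a host function sets req->status
         * (void*)0x1 -> event mode, no event recorded yet
         * otherwise  -> event mode, live event from xccl_mc_event_record() */
        return req->completed != NULL;
    }
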
diff --git a/src/team_lib/nccl/xccl_nccl_collective.h b/src/team_lib/nccl/xccl_nccl_collective.h
index 012cc1a..c7d0a9b 100644
--- a/src/team_lib/nccl/xccl_nccl_collective.h
+++ b/src/team_lib/nccl/xccl_nccl_collective.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED.
+ * Copyright (C) Mellanox Technologies Ltd. 2020-2021. ALL RIGHTS RESERVED.
  * See file LICENSE for terms.
  */
 
@@ -7,6 +7,7 @@
 #define XCCL_NCCL_COLLECTIVE_H_
 
 #include 
+#include 
 #include 
 
 #define ncclOpUnsupported (ncclNumOps + 1)
@@ -19,7 +20,8 @@ typedef struct xccl_nccl_coll_req {
     xccl_coll_op_args_t            args;
     xccl_nccl_team_t               *team;
     xccl_nccl_collective_start_fn  coll_start;
-    cudaEvent_t                    completed;
+    xccl_mc_event_t                *completed;
+    xccl_status_t                  status;
 } xccl_nccl_coll_req_t;
 
 xccl_status_t
@@ -46,5 +48,12 @@ xccl_status_t xccl_nccl_allgather_init(xccl_coll_op_args_t *coll_args,
                                        xccl_nccl_coll_req_t *request,
                                        xccl_nccl_team_t *team);
 
+xccl_status_t
+xccl_nccl_bcast_init(xccl_coll_op_args_t *coll_args,
+                     xccl_nccl_coll_req_t *request,
+                     xccl_nccl_team_t *team);
+
+#define TEAM_NCCL_CTX_REQ(_req)                                             \
+    (ucs_derived_of((_req)->team->super.ctx, xccl_nccl_context_t))
 
 #endif
diff --git a/src/team_lib/nccl/xccl_nccl_lib.c b/src/team_lib/nccl/xccl_nccl_lib.c
index 952c3a2..7d17142 100644
--- a/src/team_lib/nccl/xccl_nccl_lib.c
+++ b/src/team_lib/nccl/xccl_nccl_lib.c
@@ -81,6 +81,12 @@ static ucs_config_field_t xccl_team_lib_nccl_config_table[] = {
 
     {NULL}
 };
 
+const char* xccl_nccl_sync_names[] = {
+    [XCCL_NCCL_COMPLETION_SYNC_EVENT]    = "event",
+    [XCCL_NCCL_COMPLETION_SYNC_CALLBACK] = "callback",
+};
+
+
 static ucs_config_field_t xccl_tl_nccl_context_config_table[] = {
     {"", "", NULL,
@@ -88,6 +94,12 @@ static ucs_config_field_t xccl_tl_nccl_context_config_table[] = {
      UCS_CONFIG_TYPE_TABLE(xccl_tl_context_config_table)
     },
 
+    {"SYNC", "event",
+     "Determines how XCCL tests completion of NCCL collective",
+     ucs_offsetof(xccl_tl_nccl_context_config_t, completion_sync),
+     UCS_CONFIG_TYPE_ENUM(xccl_nccl_sync_names)
+    },
+
     {NULL}
 };
@@ -131,7 +143,15 @@ xccl_nccl_context_create(xccl_team_lib_h lib, xccl_context_params_t *params,
                          xccl_tl_context_t **context)
 {
     xccl_nccl_context_t *ctx = malloc(sizeof(*ctx));
+    xccl_tl_nccl_context_config_t *tl_config;
 
+    if (ctx == NULL) {
+        xccl_nccl_error("failed to allocate memory for nccl context");
+        return XCCL_ERR_NO_MEMORY;
+    }
+    tl_config = ucs_derived_of(config, xccl_tl_nccl_context_config_t);
+    ctx->completion_sync = tl_config->completion_sync;
+    xccl_nccl_debug("sync type: %s", xccl_nccl_sync_names[ctx->completion_sync]);
 
     XCCL_CONTEXT_SUPER_INIT(ctx->super, lib, params);
     *context = &ctx->super;
@@ -271,63 +291,73 @@ xccl_nccl_collective_init(xccl_coll_op_args_t *coll_args,
     return XCCL_OK;
 }
 
+static void nccl_completion_callback(void *request) {
+    xccl_nccl_coll_req_t *req = ucs_derived_of(request, xccl_nccl_coll_req_t);
+    req->status = XCCL_OK;
+}
+
 static xccl_status_t
 xccl_nccl_collective_post(xccl_tl_coll_req_t *request)
 {
     xccl_nccl_coll_req_t *req = ucs_derived_of(request, xccl_nccl_coll_req_t);
-    cudaStream_t *stream;
     xccl_status_t st;
+    cudaStream_t *stream;
 
     st = req->coll_start(request);
     if (st != XCCL_OK) {
+        xccl_nccl_error("collective start failed %d", st);
         return st;
     }
 
-    stream = (cudaStream_t*)req->args.stream.stream;
-    CUDACHECK(cudaEventRecord(req->completed, *stream));
-    return XCCL_OK;
+    if (req->completed != NULL) {
+        st = xccl_mc_event_record(&req->args.stream, &req->completed);
+    } else {
+        stream = (cudaStream_t*)req->args.stream.stream;
+        req->status = XCCL_INPROGRESS;
+        CUDACHECK(cudaLaunchHostFunc(*stream, nccl_completion_callback, req));
+        st = XCCL_OK;
+    }
+
+    return st;
 }
 
 static xccl_status_t
-xccl_nccl_collective_wait(xccl_tl_coll_req_t *request)
+xccl_nccl_collective_test(xccl_tl_coll_req_t *request)
 {
     xccl_nccl_coll_req_t *req = ucs_derived_of(request, xccl_nccl_coll_req_t);
-    cudaError_t cuda_st;
+    xccl_status_t st;
 
-    CUDACHECK(cudaEventSynchronize(req->completed));
+    if (req->completed != NULL) {
+        /* use event to determine collective status */
+        req->status = xccl_mc_event_query(req->completed);
+        if (req->status != XCCL_INPROGRESS) {
+            st = xccl_mc_event_free(req->completed);
+            req->completed = NULL;
+            if (st != XCCL_OK) {
+                return st;
+            }
+        }
+    }
 
-    return XCCL_OK;
+    return req->status;
 }
 
 static xccl_status_t
-xccl_nccl_collective_test(xccl_tl_coll_req_t *request)
+xccl_nccl_collective_wait(xccl_tl_coll_req_t *request)
 {
-    xccl_nccl_coll_req_t *req = ucs_derived_of(request, xccl_nccl_coll_req_t);
-    cudaError_t cuda_st;
-
-    cuda_st = cudaEventQuery(req->completed);
-    switch(cuda_st) {
-    case cudaSuccess:
-        return XCCL_OK;
-    case cudaErrorNotReady:
-        return XCCL_INPROGRESS;
-    default:
-        return XCCL_ERR_NO_MESSAGE;
-    }
+    xccl_status_t st;
+
+    do {
+        st = xccl_nccl_collective_test(request);
+    } while (st == XCCL_INPROGRESS);
+
+    return st;
 }
 
 static xccl_status_t
 xccl_nccl_collective_finalize(xccl_tl_coll_req_t *request)
 {
-    xccl_nccl_coll_req_t *req = ucs_derived_of(request, xccl_nccl_coll_req_t);
-
-    if (cudaEventQuery(req->completed) != cudaSuccess) {
-        xccl_nccl_error("calling collective finalize before collective is done");
-        return XCCL_ERR_NO_MESSAGE;
-    }
-
-    CUDACHECK(cudaEventDestroy(req->completed));
-    free(req);
+    free(request);
 
     return XCCL_OK;
 }
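
For reference, the callback path above depends on cudaLaunchHostFunc(), which is
available from CUDA 10.0. A minimal, self-contained sketch of the pattern (my
illustration, not code from this patch): the host function runs on a CUDA-internal
thread after all prior work on the stream completes, and must not call CUDA APIs
itself.

    #include <cuda_runtime.h>

    static volatile int done = 0;

    static void CUDART_CB on_complete(void *arg)
    {
        *(volatile int*)arg = 1;   /* flag only; no CUDA calls allowed here */
    }

    int main(void)
    {
        cudaStream_t s;
        cudaStreamCreate(&s);
        /* ... enqueue work (e.g. an NCCL collective) on s ... */
        cudaLaunchHostFunc(s, on_complete, (void*)&done);
        while (!done) { /* poll, as xccl_nccl_collective_test() does */ }
        cudaStreamDestroy(s);
        return 0;
    }
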
diff --git a/src/team_lib/nccl/xccl_nccl_lib.h b/src/team_lib/nccl/xccl_nccl_lib.h
index f35c493..33a695d 100644
--- a/src/team_lib/nccl/xccl_nccl_lib.h
+++ b/src/team_lib/nccl/xccl_nccl_lib.h
@@ -10,18 +10,24 @@
 #include 
 #include "xccl_team_lib.h"
 
+typedef enum xccl_nccl_completion_sync_type {
+    XCCL_NCCL_COMPLETION_SYNC_EVENT,
+    XCCL_NCCL_COMPLETION_SYNC_CALLBACK
+} xccl_nccl_completion_sync_type_t;
+
 typedef struct xccl_team_lib_nccl_config {
-    xccl_team_lib_config_t super;
-    int enable_allreduce;
-    int enable_alltoall;
-    int enable_alltoallv;
-    int enable_allgather;
-    int enable_bcast;
+    xccl_team_lib_config_t super;
+    int                    enable_allreduce;
+    int                    enable_alltoall;
+    int                    enable_alltoallv;
+    int                    enable_allgather;
+    int                    enable_bcast;
 } xccl_team_lib_nccl_config_t;
 
 typedef struct xccl_tl_nccl_context_config {
-    xccl_tl_context_config_t super;
-    char *device;
+    xccl_tl_context_config_t         super;
+    char                             *device;
+    xccl_nccl_completion_sync_type_t completion_sync;
 } xccl_tl_nccl_context_config_t;
 
 typedef struct xccl_team_lib_nccl {
@@ -48,7 +54,8 @@ extern xccl_team_lib_nccl_t xccl_team_lib_nccl;
 #define xccl_nccl_trace_poll(_fmt, ...) \
     xccl_team_nccl_log_component(UCS_LOG_LEVEL_TRACE_POLL, _fmt, ## __VA_ARGS__)
 typedef struct xccl_nccl_context {
-    xccl_tl_context_t super;
+    xccl_tl_context_t                super;
+    xccl_nccl_completion_sync_type_t completion_sync;
 } xccl_nccl_context_t;
 
 typedef struct xccl_nccl_team {
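
A note on the new SYNC knob: UCS_CONFIG_TYPE_ENUM parses the value against
xccl_nccl_sync_names, so the accepted settings are exactly "event" (the default)
and "callback". As with other UCS config-table entries it should also be settable
from the environment, though the exact variable name depends on the table prefix,
which is defined outside this diff.
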
diff --git a/src/team_lib/ucx/xccl_ucx_lib.c b/src/team_lib/ucx/xccl_ucx_lib.c
index beee171..0ea4b84 100644
--- a/src/team_lib/ucx/xccl_ucx_lib.c
+++ b/src/team_lib/ucx/xccl_ucx_lib.c
@@ -511,23 +511,37 @@ static xccl_status_t xccl_ucx_collective_post(xccl_tl_coll_req_t *request)
 
     req->stream_req = NULL;
     if (req->args.field_mask & XCCL_COLL_OP_ARGS_FIELD_STREAM) {
-        st = xccl_mc_event_record(&req->args.stream, &req->ready_to_start);
-        if (st != XCCL_OK) {
-            return st;
-        }
         if (TEAM_UCX_CTX_REQ(req)->block_stream) {
-            xccl_mem_component_start_acitivity(&req->args.stream,
-                                               &req->stream_req);
-        }
-        st = xccl_mc_event_query(req->ready_to_start);
-        if (st == XCCL_INPROGRESS) {
-            /* collective is not ready to start, start it later*/
-            return XCCL_OK;
-        }
-        if (st != XCCL_OK) {
-            return st;
+            st = xccl_mem_component_start_acitivity(&req->args.stream,
+                                                    &req->stream_req);
+            if (st != XCCL_OK) {
+                return st;
+            }
+            st = xccl_mem_component_query_activity(req->stream_req);
+            if (st == XCCL_INITIALIZED) {
+                /* collective is not ready to start, start it later */
+                /* assign dummy value to ready_to_start */
+                req->ready_to_start = (void*)0x1;
+                return XCCL_OK;
+            }
+            if (st != XCCL_INPROGRESS) {
+                return st;
+            }
+        } else {
+            st = xccl_mc_event_record(&req->args.stream, &req->ready_to_start);
+            if (st != XCCL_OK) {
+                return st;
+            }
+            st = xccl_mc_event_query(req->ready_to_start);
+            if (st == XCCL_INPROGRESS) {
+                /* collective is not ready to start, start it later */
+                return XCCL_OK;
+            }
+            if (st != XCCL_OK) {
+                return st;
+            }
+            xccl_mc_event_free(req->ready_to_start);
         }
-        xccl_mc_event_free(req->ready_to_start);
     }
     req->ready_to_start = NULL;
     return req->start(req);
@@ -539,11 +553,23 @@ static xccl_status_t xccl_ucx_collective_test(xccl_tl_coll_req_t *request)
     xccl_status_t status;
 
     if (req->ready_to_start != NULL) {
-        status = xccl_mc_event_query(req->ready_to_start);
-        if (status != XCCL_OK) {
-            return status;
+        if (TEAM_UCX_CTX_REQ(req)->block_stream) {
+            status = xccl_mem_component_query_activity(req->stream_req);
+            /* status can't be XCCL_OK since collective wasn't started */
+            assert(status != XCCL_OK);
+            if (status == XCCL_INITIALIZED) {
+                return XCCL_INPROGRESS;
+            } else if (status != XCCL_INPROGRESS) {
+                /* error */
+                return status;
+            }
+        } else {
+            status = xccl_mc_event_query(req->ready_to_start);
+            if (status != XCCL_OK) {
+                return status;
+            }
+            xccl_mc_event_free(req->ready_to_start);
         }
-        xccl_mc_event_free(req->ready_to_start);
         req->ready_to_start = NULL;
         req->start(req);
     }
@@ -551,11 +577,11 @@ static xccl_status_t xccl_ucx_collective_test(xccl_tl_coll_req_t *request)
         if (XCCL_OK != (status = req->progress(req))) {
             return status;
         };
-        if ((XCCL_INPROGRESS != req->complete) &&
-            (req->stream_req != NULL)) {
-            xccl_mem_component_finish_acitivity(req->stream_req);
-            req->stream_req = NULL;
-        }
+    }
+    if ((XCCL_INPROGRESS != req->complete) &&
+        (req->stream_req != NULL)) {
+        xccl_mem_component_finish_acitivity(req->stream_req);
+        req->stream_req = NULL;
     }
 
     return req->complete;
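
The block_stream branches above rely on a small status protocol between the host
and the spin kernel introduced later in this patch. Summarized as a comment (a
reading of the diff, not a new API):

    /*
     * XCCL_INITIALIZED : dummy_kernel enqueued on the internal stream but
     *                    not running yet; the user stream is not blocked,
     *                    so it is too early to start the collective.
     * XCCL_INPROGRESS  : dummy_kernel is running and the user stream is
     *                    blocked; safe to start the collective.
     * XCCL_OK          : finish_activity asked the kernel to exit.
     *
     * Hence xccl_ucx_collective_test() maps XCCL_INITIALIZED to
     * XCCL_INPROGRESS ("keep polling") and asserts that the status can
     * never be XCCL_OK before the collective has started.
     */
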
diff --git a/src/utils/cuda/cuda_mem_component.c b/src/utils/cuda/cuda_mem_component.c
index 3d56955..5f6b3f4 100644
--- a/src/utils/cuda/cuda_mem_component.c
+++ b/src/utils/cuda/cuda_mem_component.c
@@ -3,14 +3,14 @@ xccl_cuda_mem_component_t
 xccl_cuda_mem_component;
 
-#define NUM_STREAM_REQUESTS 4
+#define NUM_STREAM_REQUESTS 128
 #define NUM_EVENTS NUM_STREAM_REQUESTS
 
 #define CUDACHECK(cmd) do {                                             \
         cudaError_t e = cmd;                                            \
         if( e != cudaSuccess && e != cudaErrorCudartUnloading ) {       \
-            xccl_ucx_error("cuda cmd:%s failed wtih ret:%d(%s)", e,     \
-                           cudaGetErrorString(e));                      \
+            fprintf(stderr, "cuda failed with ret:%d(%s)", e,           \
+                    cudaGetErrorString(e));                             \
             return XCCL_ERR_NO_MESSAGE;                                 \
         }                                                               \
 } while(0)
@@ -57,8 +57,12 @@ static xccl_status_t xccl_cuda_alloc_resources()
     for (i = 0; i < NUM_STREAM_REQUESTS; i++) {
         xccl_cuda_mem_component.stream_requests[i].is_free = 1;
         CUDACHECK(cudaHostGetDevicePointer(
-            (void**)&(xccl_cuda_mem_component.stream_requests[i].dev_stop_request),
-            (void*)&(xccl_cuda_mem_component.stream_requests[i].stop_request),
+            (void**)&(xccl_cuda_mem_component.stream_requests[i].dev_status),
+            (void*)&(xccl_cuda_mem_component.stream_requests[i].status),
+            0));
+        CUDACHECK(cudaHostGetDevicePointer(
+            (void**)&(xccl_cuda_mem_component.stream_requests[i].dev_is_free),
+            (void*)&(xccl_cuda_mem_component.stream_requests[i].is_free),
             0));
         CUDACHECK(cudaEventCreateWithFlags(
             &xccl_cuda_mem_component.stream_requests[i].event,
@@ -134,33 +138,45 @@ xccl_status_t xccl_cuda_reduce_multi(void *sbuf1, void *sbuf2, void *rbuf,
                                      xccl_cuda_mem_component.stream);
 }
 
-cudaError_t xccl_cuda_dummy_kernel(int *stop, cudaStream_t stream);
+cudaError_t xccl_cuda_dummy_kernel(xccl_status_t *status, int *is_free,
+                                   cudaStream_t stream);
 
 xccl_status_t xccl_cuda_start_acitivity(xccl_stream_t *stream,
                                         xccl_mem_component_stream_request_t **req)
 {
     xccl_cuda_mem_component_stream_request_t *request;
-    int *dev_stop_request;
     xccl_status_t st;
     cudaStream_t internal_stream, user_stream;
 
     XCCL_CUDA_INIT_RESOUCES();
     st = xccl_cuda_get_free_stream_request(&request);
     if (st != XCCL_OK) {
+        fprintf(stderr, "cuda mc: failed to get stream req (%d)\n", st);
         return st;
     }
 
-    request->stop_request = 0;
+    request->status = XCCL_INITIALIZED;
    user_stream = *((cudaStream_t*)stream->stream);
     internal_stream = xccl_cuda_mem_component.stream;
-    CUDACHECK(xccl_cuda_dummy_kernel(request->dev_stop_request, internal_stream));
+    CUDACHECK(xccl_cuda_dummy_kernel(request->dev_status, request->dev_is_free,
+                                     internal_stream));
     CUDACHECK(cudaEventRecord(request->event, internal_stream));
     CUDACHECK(cudaStreamWaitEvent(user_stream, request->event, 0));
 
     *req = &request->super;
     return XCCL_OK;
 }
 
+xccl_status_t
+xccl_cuda_query_acitivity(xccl_mem_component_stream_request_t *req)
+{
+
+    xccl_cuda_mem_component_stream_request_t *request;
+
+    request = ucs_derived_of(req, xccl_cuda_mem_component_stream_request_t);
+    return request->status;
+}
+
 
 xccl_status_t
 xccl_cuda_finish_acitivity(xccl_mem_component_stream_request_t *req)
@@ -168,8 +184,8 @@ xccl_cuda_finish_acitivity(xccl_mem_component_stream_request_t *req)
     xccl_cuda_mem_component_stream_request_t *request;
 
     request = ucs_derived_of(req, xccl_cuda_mem_component_stream_request_t);
-    request->stop_request = 1;
-    request->is_free = 1;
+    /* set status to XCCL_OK to request kernel to stop */
+    request->status = XCCL_OK;
 
     return XCCL_OK;
 }
@@ -208,11 +224,11 @@ xccl_status_t xccl_cuda_event_record(xccl_stream_t *stream,
     XCCL_CUDA_INIT_RESOUCES();
     st = xccl_cuda_get_free_event(&et);
     if (st != XCCL_OK) {
+        fprintf(stderr, "cuda mc: failed to get free event (%d)\n", st);
         return st;
     }
 
     user_stream = *((cudaStream_t*)stream->stream);
-    CUDACHECK(cudaEventCreateWithFlags(&et->cuda_event, cudaEventDisableTiming));
     CUDACHECK(cudaEventRecord(et->cuda_event, user_stream));
 
     *event = &et->super;
@@ -270,6 +286,7 @@ xccl_cuda_mem_component_t xccl_cuda_mem_component = {
     xccl_cuda_event_query,
     xccl_cuda_event_free,
     xccl_cuda_start_acitivity,
+    xccl_cuda_query_acitivity,
     xccl_cuda_finish_acitivity,
     xccl_cuda_close
 };
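
cudaHostGetDevicePointer() in the alloc_resources hunk only succeeds when the
request fields live in mapped (zero-copy) host memory; the allocation itself is
outside this diff. A sketch of one way that precondition can be met, assuming a
dynamically allocated request array (the real code may equally use
cudaHostRegister() or a mapped static buffer):

    /* Illustrative only: back the request array with mapped host memory
     * so cudaHostGetDevicePointer() can alias its fields on the device. */
    xccl_cuda_mem_component_stream_request_t *reqs;
    CUDACHECK(cudaHostAlloc((void**)&reqs,
                            NUM_STREAM_REQUESTS * sizeof(*reqs),
                            cudaHostAllocMapped));
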
diff --git a/src/utils/cuda/cuda_mem_component.h b/src/utils/cuda/cuda_mem_component.h
index 1e837ea..09c6b21 100644
--- a/src/utils/cuda/cuda_mem_component.h
+++ b/src/utils/cuda/cuda_mem_component.h
@@ -13,8 +13,9 @@
 typedef struct xccl_cuda_mem_component_stream_request {
     xccl_mem_component_stream_request_t super;
     int                                 is_free;
-    int                                 stop_request;
-    void                                *dev_stop_request;
+    void                                *dev_is_free;
+    xccl_status_t                       status;
+    void                                *dev_status;
     cudaEvent_t                         event;
 } xccl_cuda_mem_component_stream_request_t;
diff --git a/src/utils/cuda/kernels/xccl_cuda_kernel.cu b/src/utils/cuda/kernels/xccl_cuda_kernel.cu
index e65a5c5..8b43b8a 100644
--- a/src/utils/cuda/kernels/xccl_cuda_kernel.cu
+++ b/src/utils/cuda/kernels/xccl_cuda_kernel.cu
@@ -1,11 +1,21 @@
 #include 
 #include 
 
-__global__ void dummy_kernel(volatile int *stop) {
-    int should_stop;
+__global__ void dummy_kernel(volatile xccl_status_t *status, int *is_free) {
+    xccl_status_t st;
+
+    if (*status == XCCL_OK) {
+        /* was requested to stop already */
+        *is_free = 1;
+        return;
+    } else {
+        *status = XCCL_INPROGRESS;
+    }
 
     do {
-        should_stop = *stop;
-    } while(!should_stop);
+        st = *status;
+    } while(st != XCCL_OK);
+
+    *is_free = 1;
     return;
 }
 
@@ -13,9 +23,10 @@ __global__ void dummy_kernel(volatile int *stop) {
 extern "C" {
 #endif
 
-cudaError_t xccl_cuda_dummy_kernel(int *stop, cudaStream_t stream)
+cudaError_t xccl_cuda_dummy_kernel(volatile xccl_status_t *status, int *is_free,
+                                   cudaStream_t stream)
 {
-    dummy_kernel<<<1, 1, 0, stream>>>(stop);
+    dummy_kernel<<<1, 1, 0, stream>>>(status, is_free);
     return cudaGetLastError();
 }
diff --git a/src/utils/mem_component.c b/src/utils/mem_component.c
index cefa548..b038523 100644
--- a/src/utils/mem_component.c
+++ b/src/utils/mem_component.c
@@ -195,6 +195,20 @@ xccl_status_t xccl_mem_component_start_acitivity(xccl_stream_t *stream,
     return st;
 }
 
+xccl_status_t xccl_mem_component_query_activity(xccl_mem_component_stream_request_t *req)
+{
+    int mt = req->mem_type;
+    xccl_status_t st;
+
+    if (mem_components[mt] == NULL) {
+        xccl_error("mem component %s is not available", ucs_memory_type_names[mt]);
+    }
+
+    st = mem_components[mt]->query_stream_activity(req);
+
+    return st;
+}
+
 xccl_status_t xccl_mem_component_finish_acitivity(xccl_mem_component_stream_request_t *req)
 {
     int mt = req->mem_type;
diff --git a/src/utils/mem_component.h b/src/utils/mem_component.h
index e78d720..96437d0 100644
--- a/src/utils/mem_component.h
+++ b/src/utils/mem_component.h
@@ -41,6 +41,7 @@ typedef struct xccl_mem_component {
     xccl_status_t (*event_free)(xccl_mc_event_t *event);
     xccl_status_t (*start_stream_activity)(xccl_stream_t *stream,
                                            xccl_mem_component_stream_request_t **req);
+    xccl_status_t (*query_stream_activity)(xccl_mem_component_stream_request_t *req);
     xccl_status_t (*finish_stream_activity)(xccl_mem_component_stream_request_t *req);
     void          (*close)();
     void          *dlhandle;
@@ -64,6 +65,8 @@ xccl_status_t xccl_mem_component_reduce(void *sbuf1, void *sbuf2, void *target,
 xccl_status_t xccl_mem_component_start_acitivity(xccl_stream_t *stream,
                                                  xccl_mem_component_stream_request_t **req);
 
+xccl_status_t xccl_mem_component_query_activity(xccl_mem_component_stream_request_t *req);
+
 xccl_status_t xccl_mem_component_finish_acitivity(xccl_mem_component_stream_request_t *req);
 
 /*
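
Piecing the new activity API together, the intended call sequence looks like this
(grounded in the hunks above; coll_args stands for an xccl_coll_op_args_t and
error handling is elided):

    xccl_mem_component_stream_request_t *sr;

    /* Enqueue dummy_kernel on the internal stream and make the user
     * stream wait on it; status starts as XCCL_INITIALIZED. */
    xccl_mem_component_start_acitivity(&coll_args->stream, &sr);

    /* The kernel flips status to XCCL_INPROGRESS once it is running,
     * i.e. once the user stream is actually blocked. */
    while (xccl_mem_component_query_activity(sr) == XCCL_INITIALIZED)
        ;
    /* ... run the collective while the user stream is held back ... */

    /* Sets status to XCCL_OK, telling dummy_kernel to exit and mark
     * the request free. */
    xccl_mem_component_finish_acitivity(sr);
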