diff --git a/src/ucp/core/ucp_ep.inl b/src/ucp/core/ucp_ep.inl
index a538a3c547b..e91702d8fe3 100644
--- a/src/ucp/core/ucp_ep.inl
+++ b/src/ucp/core/ucp_ep.inl
@@ -51,8 +51,7 @@ static inline int ucp_ep_is_rndv_lane_present(ucp_ep_h ep, int idx)
 
 static inline int ucp_ep_is_rndv_mrail_present(ucp_ep_h ep)
 {
-    return (ucp_ep_config(ep)->key.rndv_lanes[0] != UCP_NULL_LANE &&
-            ucp_ep_config(ep)->key.rndv_lanes[1] != UCP_NULL_LANE);
+    return ucp_ep_config(ep)->key.rndv_lanes[1] != UCP_NULL_LANE;
 }
 
 static inline int ucp_ep_is_tag_offload_enabled(ucp_ep_config_t *config)
diff --git a/src/ucp/core/ucp_request.c b/src/ucp/core/ucp_request.c
index c35ad58d350..3fcc4a10d6e 100644
--- a/src/ucp/core/ucp_request.c
+++ b/src/ucp/core/ucp_request.c
@@ -134,6 +134,13 @@ ucs_mpool_ops_t ucp_request_mpool_ops = {
     .obj_cleanup   = ucp_worker_request_fini_proxy
 };
 
+ucs_mpool_ops_t ucp_mrail_mpool_ops = {
+    .chunk_alloc   = ucs_mpool_hugetlb_malloc,
+    .chunk_release = ucs_mpool_hugetlb_free,
+    .obj_init      = NULL,
+    .obj_cleanup   = NULL
+};
+
 int ucp_request_pending_add(ucp_request_t *req, ucs_status_t *req_status)
 {
     ucs_status_t status;
diff --git a/src/ucp/core/ucp_request.h b/src/ucp/core/ucp_request.h
index 6830b7193cd..b4e35e590bc 100644
--- a/src/ucp/core/ucp_request.h
+++ b/src/ucp/core/ucp_request.h
@@ -36,6 +36,7 @@ enum {
     UCP_REQUEST_FLAG_RNDV           = UCS_BIT(9),
     UCP_REQUEST_FLAG_OFFLOADED      = UCS_BIT(10),
     UCP_REQUEST_FLAG_BLOCK_OFFLOAD  = UCS_BIT(11),
+    UCP_REQUEST_FLAG_RNDV_MRAIL     = UCS_BIT(12),
 
 #if ENABLE_ASSERT
     UCP_REQUEST_DEBUG_FLAG_EXTERNAL = UCS_BIT(15)
@@ -111,16 +112,11 @@ struct ucp_request {
             } proxy;
 
             struct {
-                uint64_t          remote_address; /* address of the sender's data buffer */
-                uintptr_t         remote_request; /* pointer to the sender's send request */
-                uct_rkey_bundle_t rkey_bundle;
-                ucp_request_t     *rreq;          /* receive request on the recv side */
-                unsigned          use_mrail;
-                unsigned          rail_idx;
-                struct {
-                    ucp_lane_index_t  lane;
-                    uct_rkey_bundle_t rkey_bundle;
-                } mrail[UCP_MAX_RAILS];
+                uint64_t          remote_address; /* address of the sender's data buffer */
+                uintptr_t         remote_request; /* pointer to the sender's send request */
+                uct_rkey_bundle_t rkey_bundle;
+                ucp_request_t     *rreq;          /* receive request on the recv side */
+                struct ucp_rndv_get_mrail *mrail; /* multirail info, from worker mrail_mp */
             } rndv_get;
 
             struct {
@@ -191,7 +187,20 @@ typedef struct ucp_recv_desc {
 } ucp_recv_desc_t;
 
 
+/**
+ * Multirail rendezvous-get info, allocated from the worker's mrail_mp
+ * memory pool only when the multirail protocol is used.
+ */
+typedef struct ucp_rndv_get_mrail {
+    unsigned rail_idx;
+    struct {
+        ucp_lane_index_t  lane;
+        uct_rkey_bundle_t rkey_bundle;
+    } rail[UCP_MAX_RAILS];
+} ucp_rndv_get_mrail_t;
+
+
 extern ucs_mpool_ops_t ucp_request_mpool_ops;
+extern ucs_mpool_ops_t ucp_mrail_mpool_ops;
 
 int ucp_request_pending_add(ucp_request_t *req, ucs_status_t *req_status);
diff --git a/src/ucp/core/ucp_request.inl b/src/ucp/core/ucp_request.inl
index 839722621c9..84a8659bb95 100644
--- a/src/ucp/core/ucp_request.inl
+++ b/src/ucp/core/ucp_request.inl
@@ -16,7 +16,7 @@
 
 
 #define UCP_REQUEST_FLAGS_FMT \
-    "%c%c%c%c%c%c%c%c%c"
+    "%c%c%c%c%c%c%c%c%c%c"
 
 #define UCP_REQUEST_FLAGS_ARG(_flags) \
     (((_flags) & UCP_REQUEST_FLAG_COMPLETED)  ? 'd' : '-'), \
@@ -27,7 +27,8 @@
     (((_flags) & UCP_REQUEST_FLAG_CALLBACK)   ? 'c' : '-'), \
     (((_flags) & UCP_REQUEST_FLAG_RECV)       ? 'r' : '-'), \
     (((_flags) & UCP_REQUEST_FLAG_SYNC)       ? 's' : '-'), \
-    (((_flags) & UCP_REQUEST_FLAG_RNDV)       ? 'v' : '-')
+    (((_flags) & UCP_REQUEST_FLAG_RNDV)       ? 'v' : '-'), \
+    (((_flags) & UCP_REQUEST_FLAG_RNDV_MRAIL) ? 'R' : '-')
 
 
 /* defined as a macro to print the call site */
@@ -68,6 +69,7 @@ ucp_request_put(ucp_request_t *req)
 {
     ucs_trace_req("put request %p", req);
     UCS_PROFILE_REQUEST_FREE(req);
+    ucs_mpool_put_inline(req);
 }
 
 
@@ -78,6 +80,11 @@ ucp_request_complete_send(ucp_request_t *req, ucs_status_t status)
                   req, req + 1, UCP_REQUEST_FLAGS_ARG(req->flags),
                   ucs_status_string(status));
    UCS_PROFILE_REQUEST_EVENT(req, "complete_send", status);
+
+    if (ucs_unlikely(req->flags & UCP_REQUEST_FLAG_RNDV_MRAIL)) {
+        ucs_mpool_put_inline(req->send.rndv_get.mrail);
+    }
+
     ucp_request_complete(req, send.cb, status);
 }
 
@@ -93,6 +100,11 @@ ucp_request_complete_recv(ucp_request_t *req, ucs_status_t status)
     if (req->flags & UCP_REQUEST_FLAG_BLOCK_OFFLOAD) {
         --req->recv.worker->context->tm.offload.sw_req_count;
     }
+
+    if (ucs_unlikely(req->flags & UCP_REQUEST_FLAG_RNDV_MRAIL)) {
+        ucs_mpool_put_inline(req->send.rndv_get.mrail);
+    }
+
     ucp_request_complete(req, recv.cb, status, &req->recv.info);
 }
 
@@ -188,6 +200,14 @@ static UCS_F_ALWAYS_INLINE void ucp_request_send_stat(ucp_request_t *req)
     }
 }
 
+static UCS_F_ALWAYS_INLINE void
+ucp_request_mrail_create(ucp_request_t *req)
+{
+    ucs_trace_req("mrail create request %p", req);
+    req->send.rndv_get.mrail = ucs_mpool_get_inline(&req->send.ep->worker->mrail_mp);
+    req->flags |= UCP_REQUEST_FLAG_RNDV_MRAIL;
+}
+
 static UCS_F_ALWAYS_INLINE void ucp_request_clear_rails(ucp_dt_state_t *state)
 {
     int i;
@@ -199,8 +219,7 @@ ucp_request_clear_rails(ucp_dt_state_t *state)
 static UCS_F_ALWAYS_INLINE int
 ucp_request_is_empty_rail(ucp_dt_state_t *state, int rail)
 {
-    return state->dt.mrail[rail].memh == UCT_MEM_HANDLE_NULL ||
-           state->dt.mrail[rail].lane == UCP_NULL_LANE;
+    return state->dt.mrail[rail].lane == UCP_NULL_LANE;
 }
 
 static UCS_F_ALWAYS_INLINE int
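The ucp_request.h change above trades the per-rail array that every request used to carry for a single pointer into a dedicated pool, so only requests that actually take the multirail rendezvous-get path pay for the rail table. A standalone sketch of the size tradeoff, using stand-in types whose layout is illustrative rather than copied from the real UCT headers:

#include <stdio.h>
#include <stdint.h>

/* Stand-ins for the UCX types involved; the field layout is illustrative,
 * not copied from the real headers. */
typedef uint8_t ucp_lane_index_t;
typedef struct { uintptr_t rkey; void *handle; void *type; } uct_rkey_bundle_t;

#define UCP_MAX_RAILS 4

/* Mirrors ucp_rndv_get_mrail_t from the patch above. */
typedef struct ucp_rndv_get_mrail {
    unsigned rail_idx;
    struct {
        ucp_lane_index_t  lane;
        uct_rkey_bundle_t rkey_bundle;
    } rail[UCP_MAX_RAILS];
} ucp_rndv_get_mrail_t;

int main(void)
{
    /* Before: every request embedded the rail table, used or not.
     * After: every request carries one pointer, and the table itself
     * comes from worker->mrail_mp only on the multirail path. */
    printf("embedded rail table: %zu bytes per request\n",
           sizeof(ucp_rndv_get_mrail_t));
    printf("pooled pointer:      %zu bytes per request\n",
           sizeof(ucp_rndv_get_mrail_t *));
    return 0;
}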
diff --git a/src/ucp/core/ucp_types.h b/src/ucp/core/ucp_types.h
index 9a9a56d01bc..17754b48a50 100644
--- a/src/ucp/core/ucp_types.h
+++ b/src/ucp/core/ucp_types.h
@@ -29,8 +29,8 @@ typedef ucp_rsc_index_t ucp_md_index_t;
 UCP_UINT_TYPE(UCP_MD_INDEX_BITS) ucp_md_map_t;
 
 /* Lanes */
-#define UCP_MAX_LANES                16
-#define UCP_MAX_RAILS                8
+#define UCP_MAX_LANES                8
+#define UCP_MAX_RAILS                4
 #define UCP_NULL_LANE                ((ucp_lane_index_t)-1)
 typedef uint8_t                      ucp_lane_index_t;
 UCP_UINT_TYPE(UCP_MAX_LANES)         ucp_lane_map_t;
diff --git a/src/ucp/core/ucp_worker.c b/src/ucp/core/ucp_worker.c
index 96f8e48093b..f3a728c4523 100644
--- a/src/ucp/core/ucp_worker.c
+++ b/src/ucp/core/ucp_worker.c
@@ -1071,10 +1071,19 @@ ucs_status_t ucp_worker_create(ucp_context_h context,
         goto err_destroy_uct_worker;
     }
 
+    /* Create memory pool for multirail info */
+    status = ucs_mpool_init(&worker->mrail_mp, 0,
+                            sizeof(ucp_rndv_get_mrail_t),
+                            0, UCS_SYS_CACHE_LINE_SIZE, 128, UINT_MAX,
+                            &ucp_mrail_mpool_ops, "ucp_multirail");
+    if (status != UCS_OK) {
+        goto err_req_mp_cleanup;
+    }
+
     /* Create epoll set which combines events from all transports */
     status = ucp_worker_wakeup_init(worker, params);
     if (status != UCS_OK) {
-        goto err_req_mp_cleanup;
+        goto err_mrail_mp_cleanup;
     }
 
     if (params->field_mask & UCP_WORKER_PARAM_FIELD_CPU_MASK) {
@@ -1107,6 +1116,8 @@ ucs_status_t ucp_worker_create(ucp_context_h context,
 err_close_ifaces:
     ucp_worker_close_ifaces(worker);
     ucp_worker_wakeup_cleanup(worker);
+err_mrail_mp_cleanup:
+    ucs_mpool_cleanup(&worker->mrail_mp, 1);
 err_req_mp_cleanup:
     ucs_mpool_cleanup(&worker->req_mp, 1);
 err_destroy_uct_worker:
@@ -1141,6 +1152,7 @@ void ucp_worker_destroy(ucp_worker_h worker)
     ucs_mpool_cleanup(&worker->reg_mp, 1);
     ucp_worker_close_ifaces(worker);
     ucp_worker_wakeup_cleanup(worker);
+    ucs_mpool_cleanup(&worker->mrail_mp, 1);
     ucs_mpool_cleanup(&worker->req_mp, 1);
     uct_worker_destroy(worker->uct);
     ucs_async_context_cleanup(&worker->async);
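ucp_worker_create above follows the file's existing unwind convention: each resource gets a matching err_* label, and a failure jumps to the label that releases everything acquired so far, in reverse acquisition order. A minimal standalone illustration of the same idiom, with hypothetical init/cleanup functions standing in for the mpool and wakeup calls:

#include <stdio.h>
#include <stdlib.h>

/* Hypothetical resources standing in for req_mp / mrail_mp / wakeup. */
static int  init_req_mp(void)      { puts("init req_mp");      return 0; }
static int  init_mrail_mp(void)    { puts("init mrail_mp");    return 0; }
static int  init_wakeup(void)      { puts("init wakeup");      return -1; /* simulate failure */ }
static void cleanup_req_mp(void)   { puts("cleanup req_mp"); }
static void cleanup_mrail_mp(void) { puts("cleanup mrail_mp"); }

static int worker_create(void)
{
    if (init_req_mp() != 0) {
        goto err;
    }
    if (init_mrail_mp() != 0) {
        goto err_req_mp_cleanup;   /* undo req_mp only */
    }
    if (init_wakeup() != 0) {
        goto err_mrail_mp_cleanup; /* undo mrail_mp, then fall through to req_mp */
    }
    return 0;

err_mrail_mp_cleanup:
    cleanup_mrail_mp();
err_req_mp_cleanup:
    cleanup_req_mp();
err:
    return -1;
}

int main(void)
{
    return (worker_create() == 0) ? EXIT_SUCCESS : EXIT_FAILURE;
}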
diff --git a/src/ucp/core/ucp_worker.h b/src/ucp/core/ucp_worker.h
index 916755910a5..9910ae680ef 100644
--- a/src/ucp/core/ucp_worker.h
+++ b/src/ucp/core/ucp_worker.h
@@ -127,6 +127,7 @@ typedef struct ucp_worker {
     uint64_t                      uuid;       /* Unique ID for wireup */
     uct_worker_h                  uct;        /* UCT worker handle */
     ucs_mpool_t                   req_mp;     /* Memory pool for requests */
+    ucs_mpool_t                   mrail_mp;   /* Memory pool for multirail info */
     uint64_t                      atomic_tls; /* Which resources can be used for atomics */
 
     int                           inprogress;
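The new mrail_mp pool pairs with the UCP_REQUEST_FLAG_RNDV_MRAIL flag: setting the flag when an element is taken lets the completion paths know the request owns one and must return it. A self-contained sketch of that ownership protocol, with malloc/free standing in for ucs_mpool_get_inline/ucs_mpool_put_inline:

#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>

#define FLAG_RNDV_MRAIL (1u << 12)  /* mirrors UCP_REQUEST_FLAG_RNDV_MRAIL */

typedef struct { unsigned rail_idx; } mrail_info_t; /* stand-in descriptor */

typedef struct {
    uint32_t      flags;
    mrail_info_t *mrail;
} request_t;

/* Acquire a descriptor and mark ownership, as ucp_request_mrail_create does. */
static void request_mrail_create(request_t *req)
{
    req->mrail  = malloc(sizeof(*req->mrail)); /* ucs_mpool_get_inline() in UCX */
    req->flags |= FLAG_RNDV_MRAIL;
}

/* Completion path: release only when the flag says this request owns an
 * element, as the patched ucp_request_complete_send does. */
static void request_complete(request_t *req)
{
    if (req->flags & FLAG_RNDV_MRAIL) {
        free(req->mrail);                      /* ucs_mpool_put_inline() in UCX */
        req->mrail = NULL;
    }
}

int main(void)
{
    request_t req = { .flags = 0, .mrail = NULL };

    request_mrail_create(&req);
    request_complete(&req);
    printf("flags=0x%x mrail=%p\n", (unsigned)req.flags, (void *)req.mrail);
    return 0;
}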
diff --git a/src/ucp/tag/rndv.c b/src/ucp/tag/rndv.c
index 6c3bc0077c9..197c3541dcf 100644
--- a/src/ucp/tag/rndv.c
+++ b/src/ucp/tag/rndv.c
@@ -108,15 +108,17 @@ static void ucp_tag_rndv_unpack_mrail_rkeys(ucp_request_t *req, void *rkey_buf)
     ucs_assert(UCP_DT_IS_CONTIG(req->send.datatype));
     ucs_assert(ucp_ep_is_rndv_mrail_present(ep));
 
-    for (i = 0; ucp_ep_is_rndv_lane_present(ep, i) && i < UCP_MAX_RAILS; i++) {
+    ucp_request_mrail_create(req);
+
+    for (i = 0; (i < UCP_MAX_RAILS) && ucp_ep_is_rndv_lane_present(ep, i); i++) {
         lane = ucp_ep_get_rndv_get_lane(ep, i);
         if (ucp_ep_rndv_md_flags(ep, lane) & UCT_MD_FLAG_NEED_RKEY) {
             UCS_PROFILE_CALL(uct_rkey_unpack, rkey_buf + packet,
-                             &req->send.rndv_get.mrail[i].rkey_bundle);
-            req->send.rndv_get.mrail[i].lane = lane;
+                             &req->send.rndv_get.mrail->rail[i].rkey_bundle);
+            req->send.rndv_get.mrail->rail[i].lane = lane;
             packet += ucp_ep_md_attr(ep, lane)->rkey_packed_size;
         }
     }
 }
 
 static size_t ucp_tag_rndv_rts_pack(void *dest, void *arg)
@@ -308,7 +312,7 @@ UCS_PROFILE_FUNC(ucs_status_t, ucp_proto_progress_rndv_get_zcopy, (self),
 
     if (!(ucp_tag_rndv_is_get_op_possible(rndv_req->send.ep, rndv_req->send.lane,
                                           rndv_req->send.rndv_get.rkey_bundle.rkey)) &&
-        !(rndv_req->send.rndv_get.use_mrail)) {
+        !(rndv_req->send.rndv_get.mrail)) {
         /* can't perform get_zcopy - switch to AM rndv */
         if (rndv_req->send.rndv_get.rkey_bundle.rkey != UCT_INVALID_RKEY) {
             uct_rkey_release(&rndv_req->send.rndv_get.rkey_bundle);
@@ -328,16 +332,16 @@ UCS_PROFILE_FUNC(ucs_status_t, ucp_proto_progress_rndv_get_zcopy, (self),
               rndv_req->send.ep, rndv_req, rndv_req->send.lane);
 
     /* rndv_req is the internal request to perform the get operation */
-    if (!rndv_req->send.rndv_get.use_mrail &&
+    if (!rndv_req->send.rndv_get.mrail &&
         (rndv_req->send.state.dt.contig.memh == UCT_MEM_HANDLE_NULL)) {
         /* TODO Not all UCTs need registration on the recv side */
         UCS_PROFILE_REQUEST_EVENT(rndv_req->send.rndv_get.rreq, "rndv_recv_reg", 0);
         status = ucp_request_send_buffer_reg(rndv_req, rndv_req->send.lane);
         ucs_assert_always(status == UCS_OK);
-    } else if(rndv_req->send.rndv_get.use_mrail &&
+    } else if (rndv_req->send.rndv_get.mrail &&
               ucp_request_is_empty_rail(&rndv_req->send.state, 0)) {
         ucp_request_mrail_reg(rndv_req);
-        rndv_req->send.rndv_get.rail_idx = 0;
+        rndv_req->send.rndv_get.mrail->rail_idx = 0;
     }
 
     offset = rndv_req->send.state.offset;
@@ -354,26 +358,27 @@ UCS_PROFILE_FUNC(ucs_status_t, ucp_proto_progress_rndv_get_zcopy, (self),
                    offset, (uintptr_t)rndv_req->send.buffer % align,
                    (void*)rndv_req->send.buffer + offset, length);
 
-    if ((rndv_req->send.rndv_get.use_mrail) &&
-        ((ucp_request_is_empty_rail(&rndv_req->send.state, rndv_req->send.rndv_get.rail_idx)) ||
-         (rndv_req->send.rndv_get.rail_idx >= UCP_MAX_RAILS))) {
-        rndv_req->send.rndv_get.rail_idx = 0;
+    if ((rndv_req->send.rndv_get.mrail) &&
+        ((rndv_req->send.rndv_get.mrail->rail_idx >= UCP_MAX_RAILS) ||
+         (ucp_request_is_empty_rail(&rndv_req->send.state,
+                                    rndv_req->send.rndv_get.mrail->rail_idx)))) {
+        rndv_req->send.rndv_get.mrail->rail_idx = 0;
     }
 
     iov[0].buffer = (void*)rndv_req->send.buffer + offset;
     iov[0].length = length;
     iov[0].count  = 1;
     iov[0].stride = 0;
-    if (!rndv_req->send.rndv_get.use_mrail) {
+    if (!rndv_req->send.rndv_get.mrail) {
         iov[0].memh = rndv_req->send.state.dt.contig.memh;
         lane        = rndv_req->send.lane;
         rkey        = rndv_req->send.rndv_get.rkey_bundle.rkey;
     } else {
-        rail_idx    = rndv_req->send.rndv_get.rail_idx;
+        rail_idx    = rndv_req->send.rndv_get.mrail->rail_idx;
         iov[0].memh = rndv_req->send.state.dt.mrail[rail_idx].memh;
-        lane        = rndv_req->send.rndv_get.mrail[rail_idx].lane;
-        rkey        = rndv_req->send.rndv_get.mrail[rail_idx].rkey_bundle.rkey;
-        rndv_req->send.rndv_get.rail_idx++;
+        lane        = rndv_req->send.rndv_get.mrail->rail[rail_idx].lane;
+        rkey        = rndv_req->send.rndv_get.mrail->rail[rail_idx].rkey_bundle.rkey;
+        rndv_req->send.rndv_get.mrail->rail_idx++;
     }
 
     rndv_req->send.uct_comp.count++;
@@ -433,10 +438,9 @@ static void ucp_rndv_handle_recv_contig(ucp_request_t *rndv_req, ucp_request_t *
         rndv_req->send.proto.remote_request = rndv_rts_hdr->sreq.reqptr;
         rndv_req->send.proto.rreq_ptr       = (uintptr_t) rreq;
     } else {
-        rndv_req->send.rndv_get.use_mrail = 0;
+        rndv_req->send.rndv_get.mrail = NULL;
         if (rndv_rts_hdr->flags & UCP_RNDV_RTS_FLAG_PACKED_MRAIL_RKEY) {
             ucp_tag_rndv_unpack_mrail_rkeys(rndv_req, rndv_rts_hdr + 1);
-            rndv_req->send.rndv_get.use_mrail = 1;
         } else if (rndv_rts_hdr->flags & UCP_RNDV_RTS_FLAG_PACKED_RKEY) {
             UCS_PROFILE_CALL(uct_rkey_unpack, rndv_rts_hdr + 1,
                              &rndv_req->send.rndv_get.rkey_bundle);
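The progress function above stripes successive get-zcopy fragments across rails by advancing rail_idx after each post and wrapping it when it runs past UCP_MAX_RAILS or reaches an unconfigured rail. A standalone sketch of that round-robin selection, with a hypothetical lane table (the empty-rail test mirrors ucp_request_is_empty_rail checking for UCP_NULL_LANE):

#include <stdio.h>

#define UCP_MAX_RAILS 4
#define NULL_LANE     0xff          /* stands in for UCP_NULL_LANE */

/* Configured lanes; rails past the first empty slot are unused,
 * matching ucp_request_is_empty_rail() testing lane == UCP_NULL_LANE. */
static const unsigned char lanes[UCP_MAX_RAILS] = { 1, 2, NULL_LANE, NULL_LANE };

static int is_empty_rail(unsigned idx)
{
    return lanes[idx] == NULL_LANE;
}

int main(void)
{
    unsigned rail_idx = 0;
    int      frag;

    for (frag = 0; frag < 6; frag++) {
        /* Wrap before indexing: bounds check first, then the empty-rail probe. */
        if ((rail_idx >= UCP_MAX_RAILS) || is_empty_rail(rail_idx)) {
            rail_idx = 0;
        }
        printf("fragment %d -> rail %u (lane %u)\n", frag, rail_idx,
               (unsigned)lanes[rail_idx]);
        rail_idx++;                 /* advance for the next fragment */
    }
    return 0;
}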