diff --git a/src/ucp/tag/eager_rcv.c b/src/ucp/tag/eager_rcv.c index a9f4b3ab1a7..88ecd8a865b 100644 --- a/src/ucp/tag/eager_rcv.c +++ b/src/ucp/tag/eager_rcv.c @@ -75,6 +75,8 @@ ucp_eager_handler(void *arg, void *data, size_t length, unsigned am_flags, * because it arrived as unexpected */ if (flags & UCP_RECV_DESC_FLAG_OFFLOAD) { ucp_tag_offload_cancel(context, req, 1); + } else { + ucs_assert(!(req->flags & UCP_REQUEST_FLAG_OFFLOADED)); } if (flags & UCP_RECV_DESC_FLAG_LAST) { diff --git a/src/ucp/tag/offload.c b/src/ucp/tag/offload.c index edf92b9577d..e8d8cc779d0 100644 --- a/src/ucp/tag/offload.c +++ b/src/ucp/tag/offload.c @@ -74,34 +74,40 @@ ucs_status_t ucp_tag_offload_unexp_rndv(void *arg, unsigned flags, uint64_t remote_addr, size_t length, const void *rkey_buf) { - ucp_rndv_rts_hdr_t *rts = (ucp_rndv_rts_hdr_t*)(((ucp_tag_hdr_t*)hdr) - 1); - ucp_worker_t *worker = arg; - void *rkey = rts + 1; - size_t len = sizeof(*rts); - ucp_ep_t *ep = ucp_worker_get_reply_ep(worker, rts->sreq.sender_uuid); - const uct_md_attr_t *md_attrs; - size_t rkey_size; - - /* rts.req should be alredy in place - it is sent by the sender. - * Fill the rest of rts header and pass to common rts handler */ - if (rkey_buf) { - /* Copy rkey before to fill rts, to avoid overriding rkey */ - md_attrs = ucp_ep_md_attr(ep, ucp_ep_get_tag_lane(ep)); - rkey_size = md_attrs->rkey_packed_size; - memcpy(rkey, rkey_buf, rkey_size); - len += rkey_size; - rts->flags = UCP_RNDV_RTS_FLAG_PACKED_RKEY | UCP_RNDV_RTS_FLAG_OFFLOAD; + ucp_worker_t *worker = arg; + ucp_request_hdr_t *rndv_hdr = (ucp_request_hdr_t*)hdr; + ucp_ep_t *ep = ucp_worker_get_reply_ep(worker, rndv_hdr->sender_uuid); + const uct_md_attr_t *md_attr = ucp_ep_md_attr(ep, ucp_ep_get_tag_lane(ep)); + size_t rkey_size = rkey_buf ? md_attr->rkey_packed_size : 0; + size_t len = sizeof(ucp_rndv_rts_hdr_t) + rkey_size; + ucp_rndv_rts_hdr_t *rts = ucs_alloca(len); + ucp_sw_rndv_hdr_t *sw_rndv_hdr; + + /* Fill RTS to emulate SW RNDV flow. */ + rts->super.tag = stag; + rts->sreq = *rndv_hdr; + rts->address = remote_addr; + + if (remote_addr) { rts->size = length; + rts->flags = UCP_RNDV_RTS_FLAG_OFFLOAD; + if (rkey_buf) { + memcpy(rts + 1, rkey_buf, rkey_size); + len += rkey_size; + rts->flags |= UCP_RNDV_RTS_FLAG_PACKED_RKEY; + } } else { - ucs_assert(remote_addr == 0ul); /* This must be SW RNDV request. Take length from its header. */ - rts->size = ((ucp_sw_rndv_hdr_t*)hdr)->length; + sw_rndv_hdr = ucs_derived_of(hdr, ucp_sw_rndv_hdr_t); + rts->size = sw_rndv_hdr->length; + rts->flags = 0; } - rts->super.tag = stag; - rts->address = remote_addr; + /* Pass 0 as tl flags, because RTS needs to be stored in UCP pool. */ + ucp_rndv_process_rts(arg, rts, len, 0); - return ucp_rndv_rts_handler(arg, rts, len, flags, UCP_RECV_DESC_FLAG_OFFLOAD); + /* Always return UCS_OK, since RNDV hdr should be stored in UCP mpool. */ + return UCS_OK; } void ucp_tag_offload_cancel(ucp_context_t *ctx, ucp_request_t *req, int force) @@ -109,16 +115,17 @@ void ucp_tag_offload_cancel(ucp_context_t *ctx, ucp_request_t *req, int force) ucp_worker_iface_t *ucp_iface; ucs_status_t status; - ucs_assert(req->flags & UCP_REQUEST_FLAG_OFFLOADED); - - ucp_iface = ucs_queue_head_elem_non_empty(&ctx->tm.offload_ifaces, - ucp_worker_iface_t, queue); - ucp_request_memory_dereg(ctx, ucp_iface->rsc_index, req->recv.datatype, - &req->recv.state); - status = uct_iface_tag_recv_cancel(ucp_iface->iface, &req->recv.uct_ctx, force); - if (status != UCS_OK) { - ucs_error("Failed to cancel recv in the transport: %s", - ucs_status_string(status)); + if (req->flags & UCP_REQUEST_FLAG_OFFLOADED) { + ucp_iface = ucs_queue_head_elem_non_empty(&ctx->tm.offload_ifaces, + ucp_worker_iface_t, queue); + ucp_request_memory_dereg(ctx, ucp_iface->rsc_index, req->recv.datatype, + &req->recv.state); + status = uct_iface_tag_recv_cancel(ucp_iface->iface, &req->recv.uct_ctx, + force); + if (status != UCS_OK) { + ucs_error("Failed to cancel recv in the transport: %s", + ucs_status_string(status)); + } } } diff --git a/src/ucp/tag/rndv.c b/src/ucp/tag/rndv.c index e2b2a568370..d92992fe61c 100644 --- a/src/ucp/tag/rndv.c +++ b/src/ucp/tag/rndv.c @@ -431,15 +431,13 @@ UCS_PROFILE_FUNC_VOID(ucp_rndv_matched, (worker, rreq, rndv_rts_hdr), UCS_ASYNC_UNBLOCK(&worker->async); } -UCS_PROFILE_FUNC(ucs_status_t, ucp_rndv_rts_handler, - (arg, data, length, tl_flags, desc_flags), - void *arg, void *data, size_t length, unsigned tl_flags, - unsigned desc_flags) +UCS_PROFILE_FUNC(ucs_status_t, ucp_rndv_process_rts, + (arg, data, length, tl_flags), + void *arg, void *data, size_t length, unsigned tl_flags) { const unsigned recv_flags = UCP_RECV_DESC_FLAG_FIRST | UCP_RECV_DESC_FLAG_LAST | - UCP_RECV_DESC_FLAG_RNDV | - desc_flags; + UCP_RECV_DESC_FLAG_RNDV; ucp_worker_h worker = arg; ucp_rndv_rts_hdr_t *rndv_rts_hdr = data; ucp_context_h context = worker->context; @@ -455,9 +453,7 @@ UCS_PROFILE_FUNC(ucs_status_t, ucp_rndv_rts_handler, /* Cancel req in transport if it was offloaded, because it arrived as unexpected */ - if (recv_flags & UCP_RECV_DESC_FLAG_OFFLOAD) { - ucp_tag_offload_cancel(context, rreq, 1); - } + ucp_tag_offload_cancel(context, rreq, 1); UCP_WORKER_STAT_RNDV(worker, EXP); status = UCS_OK; @@ -469,10 +465,11 @@ UCS_PROFILE_FUNC(ucs_status_t, ucp_rndv_rts_handler, UCP_THREAD_CS_EXIT_CONDITIONAL(&context->mt_lock); return status; } -ucs_status_t ucp_rndv_rts_handler_wrap(void *arg, void *data, size_t length, - unsigned tl_flags) + +ucs_status_t ucp_rndv_rts_handler(void *arg, void *data, size_t length, + unsigned tl_flags) { - return ucp_rndv_rts_handler(arg, data, length, tl_flags, 0); + return ucp_rndv_process_rts(arg, data, length, tl_flags); } UCS_PROFILE_FUNC(ucs_status_t, ucp_rndv_ats_handler, @@ -790,7 +787,7 @@ static void ucp_rndv_dump(ucp_worker_h worker, uct_am_trace_type_t type, } } -UCP_DEFINE_AM(UCP_FEATURE_TAG, UCP_AM_ID_RNDV_RTS, ucp_rndv_rts_handler_wrap, +UCP_DEFINE_AM(UCP_FEATURE_TAG, UCP_AM_ID_RNDV_RTS, ucp_rndv_rts_handler, ucp_rndv_dump, UCT_AM_CB_FLAG_SYNC); UCP_DEFINE_AM(UCP_FEATURE_TAG, UCP_AM_ID_RNDV_ATS, ucp_rndv_ats_handler, ucp_rndv_dump, UCT_AM_CB_FLAG_SYNC); diff --git a/src/ucp/tag/rndv.h b/src/ucp/tag/rndv.h index db04ad58b8d..525e60d8a03 100644 --- a/src/ucp/tag/rndv.h +++ b/src/ucp/tag/rndv.h @@ -53,9 +53,8 @@ void ucp_rndv_matched(ucp_worker_h worker, ucp_request_t *req, ucs_status_t ucp_proto_progress_rndv_get_zcopy(uct_pending_req_t *self); -ucs_status_t -ucp_rndv_rts_handler(void *arg, void *data, size_t length, unsigned tl_flags, - unsigned desc_flags); +ucs_status_t ucp_rndv_process_rts(void *arg, void *data, size_t length, + unsigned tl_flags); static inline size_t ucp_rndv_total_len(ucp_rndv_rts_hdr_t *hdr) diff --git a/src/ucp/tag/tag_match.inl b/src/ucp/tag/tag_match.inl index e370910d102..f1b3929a979 100644 --- a/src/ucp/tag/tag_match.inl +++ b/src/ucp/tag/tag_match.inl @@ -156,14 +156,15 @@ ucp_tag_unexp_recv(ucp_tag_match_t *tm, ucp_worker_h worker, void *data, size_t length, unsigned am_flags, uint16_t hdr_len, uint16_t flags) { - ucp_recv_desc_t *rdesc = (ucp_recv_desc_t *)data - 1; + ucp_recv_desc_t *rdesc; ucs_list_link_t *hash_list; ucs_status_t status; if (ucs_unlikely(am_flags & UCT_CB_FLAG_DESC)) { - /* desc==data is slowpath */ + /* slowpath */ + rdesc = (ucp_recv_desc_t *)data - 1; rdesc->flags = flags | UCP_RECV_DESC_FLAG_UCT_DESC; - status = UCS_INPROGRESS; + status = UCS_INPROGRESS; } else { rdesc = (ucp_recv_desc_t*)ucs_mpool_get_inline(&worker->am_mp); if (rdesc == NULL) {