diff --git a/src/ucp/api/ucp_def.h b/src/ucp/api/ucp_def.h index e33cd084e54..e3b807017d1 100644 --- a/src/ucp/api/ucp_def.h +++ b/src/ucp/api/ucp_def.h @@ -219,7 +219,7 @@ typedef uint64_t ucp_tag_t; * @ref ucp_tag_probe_nb. This handle can be passed to @ref ucp_tag_msg_recv_nb * in order to receive the message data to a specific buffer. */ -typedef struct ucp_recv_desc *ucp_tag_message_h; +typedef void *ucp_tag_message_h; /** diff --git a/src/ucp/tag/probe.c b/src/ucp/tag/probe.c index 6f2b1e08201..0f9b07cf855 100644 --- a/src/ucp/tag/probe.c +++ b/src/ucp/tag/probe.c @@ -13,7 +13,7 @@ #include -static UCS_F_ALWAYS_INLINE ucp_recv_desc_t* +static UCS_F_ALWAYS_INLINE ucs_queue_iter_t ucp_tag_probe_search(ucp_context_h context, ucp_tag_t tag, uint64_t tag_mask, ucp_tag_recv_info_t *info, int remove) { @@ -44,9 +44,12 @@ ucp_tag_probe_search(ucp_context_h context, ucp_tag_t tag, uint64_t tag_mask, } if (remove) { - ucs_queue_del_iter(&context->tm.unexpected, iter); + /* Prevent the receive descriptor, and any fragments after it, + * from being matched by receive requests. + */ + rdesc->flags &= ~UCP_RECV_DESC_FLAG_FIRST; } - return rdesc; + return iter; } } @@ -58,16 +61,16 @@ ucp_tag_message_h ucp_tag_probe_nb(ucp_worker_h worker, ucp_tag_t tag, ucp_tag_recv_info_t *info) { ucp_context_h context = worker->context; - ucp_recv_desc_t *ret; + ucs_queue_iter_t message; UCP_THREAD_CS_ENTER_CONDITIONAL(&worker->mt_lock); UCP_THREAD_CS_ENTER_CONDITIONAL(&context->mt_lock); ucs_trace_req("probe_nb tag %"PRIx64"/%"PRIx64, tag, tag_mask); - ret = ucp_tag_probe_search(context, tag, tag_mask, info, remove); + message = ucp_tag_probe_search(context, tag, tag_mask, info, remove); UCP_THREAD_CS_EXIT_CONDITIONAL(&context->mt_lock); UCP_THREAD_CS_EXIT_CONDITIONAL(&worker->mt_lock); - return ret; + return message; } diff --git a/src/ucp/tag/tag_match.inl b/src/ucp/tag/tag_match.inl index f206786831c..c8833f05406 100644 --- a/src/ucp/tag/tag_match.inl +++ b/src/ucp/tag/tag_match.inl @@ -164,6 +164,7 @@ ucp_tag_unexp_recv(ucp_tag_match_t *tm, ucp_worker_h worker, void *data, static UCS_F_ALWAYS_INLINE void ucp_tag_unexp_desc_release(ucp_recv_desc_t *rdesc) { + ucs_trace_req("release receive descriptor %p", rdesc); if (ucs_unlikely(rdesc->flags & UCP_RECV_DESC_FLAG_UCT_DESC)) { uct_iface_release_desc(rdesc); /* uct desc is slowpath */ } else { diff --git a/src/ucp/tag/tag_recv.c b/src/ucp/tag/tag_recv.c index 94f2f96d759..ca0f704a004 100644 --- a/src/ucp/tag/tag_recv.c +++ b/src/ucp/tag/tag_recv.c @@ -18,7 +18,8 @@ static UCS_F_ALWAYS_INLINE ucs_status_t ucp_tag_search_unexp(ucp_worker_h worker, void *buffer, size_t buffer_size, ucp_datatype_t datatype, ucp_tag_t tag, uint64_t tag_mask, ucp_request_t *req, ucp_tag_recv_info_t *info, - ucp_tag_recv_callback_t cb, unsigned *save_rreq) + ucp_tag_recv_callback_t cb, ucs_queue_iter_t first_iter, + unsigned *save_rreq) { ucp_context_h context = worker->context; ucp_recv_desc_t *rdesc; @@ -27,7 +28,14 @@ ucp_tag_search_unexp(ucp_worker_h worker, void *buffer, size_t buffer_size, ucp_tag_t recv_tag; unsigned flags; - ucs_queue_for_each_safe(rdesc, iter, &context->tm.unexpected, queue) { + if (first_iter == NULL) { + iter = ucs_queue_iter_begin(&context->tm.unexpected); + } else { + iter = first_iter; + } + + while (!ucs_queue_iter_end(&context->tm.unexpected, iter)) { + rdesc = ucs_queue_iter_elem(rdesc, iter, queue); recv_tag = ucp_rdesc_get_tag(rdesc); flags = rdesc->flags; ucs_trace_req("searching for %"PRIx64"/%"PRIx64"/%"PRIx64" offset %zu, " @@ -66,6 +74,8 @@ ucp_tag_search_unexp(ucp_worker_h worker, void *buffer, size_t buffer_size, UCP_WORKER_STAT_RNDV(worker, UNEXP); return UCS_INPROGRESS; } + } else { + iter = ucs_queue_iter_next(iter); } } @@ -108,22 +118,6 @@ ucp_tag_recv_request_init(ucp_request_t *req, ucp_worker_h worker, void* buffer, } } -static UCS_F_ALWAYS_INLINE ucp_request_t* -ucp_tag_recv_request_get(ucp_worker_h worker, void* buffer, size_t count, - ucp_datatype_t datatype) -{ - ucp_request_t *req; - - req = ucp_request_get(worker); - if (ucs_unlikely(req == NULL)) { - return NULL; - } - - ucp_tag_recv_request_init(req, worker, buffer, count, datatype, - UCP_REQUEST_FLAG_CALLBACK); - return req; -} - static UCS_F_ALWAYS_INLINE void ucp_tag_recv_request_completed(ucp_request_t *req, ucs_status_t status, ucp_tag_recv_info_t *info, const char *function) @@ -138,22 +132,30 @@ ucp_tag_recv_request_completed(ucp_request_t *req, ucs_status_t status, } static UCS_F_ALWAYS_INLINE ucs_status_t -ucp_tag_recv_common(ucp_worker_h worker, void *buffer, size_t buffer_size, +ucp_tag_recv_common(ucp_worker_h worker, void *buffer, size_t count, uintptr_t datatype, ucp_tag_t tag, ucp_tag_t tag_mask, - ucp_request_t *req, ucp_tag_recv_callback_t cb, - const char *debug_name) + ucp_request_t *req, uint16_t req_flags, ucp_tag_recv_callback_t cb, + ucs_queue_iter_t iter, const char *debug_name) { ucs_status_t status; unsigned save_rreq = 1; + size_t buffer_size; + + ucp_tag_recv_request_init(req, worker, buffer, count, datatype, req_flags); + buffer_size = ucp_dt_length(datatype, count, buffer, &req->recv.state); ucs_trace_req("%s buffer %p buffer_size %zu tag %"PRIx64"/%"PRIx64, debug_name, buffer, buffer_size, tag, tag_mask); /* First, search in unexpected list */ status = ucp_tag_search_unexp(worker, buffer, buffer_size, datatype, tag, - tag_mask, req, &req->recv.info, cb, &save_rreq); + tag_mask, req, &req->recv.info, cb, iter, + &save_rreq); if (status != UCS_INPROGRESS) { - return status; + if (req_flags & UCP_REQUEST_FLAG_CALLBACK) { + cb(req + 1, status, &req->recv.info); + } + ucp_tag_recv_request_completed(req, status, &req->recv.info, debug_name); } else if (save_rreq) { /* If not found on unexpected, wait until it arrives. * If was found but need this receive request for later completion, save it */ @@ -177,23 +179,15 @@ UCS_PROFILE_FUNC(ucs_status_t, ucp_tag_recv_nbr, uintptr_t datatype, ucp_tag_t tag, ucp_tag_t tag_mask, void *request) { - ucp_request_t *req = (ucp_request_t *)request - 1; + ucp_request_t *req = (ucp_request_t *)request - 1; ucs_status_t status; - size_t buffer_size; UCP_THREAD_CS_ENTER_CONDITIONAL(&worker->mt_lock); UCP_THREAD_CS_ENTER_CONDITIONAL(&worker->context->mt_lock); - ucp_tag_recv_request_init(req, worker, buffer, count, datatype, - UCP_REQUEST_DEBUG_FLAG_EXTERNAL); - - buffer_size = ucp_dt_length(datatype, count, buffer, &req->recv.state); - - status = ucp_tag_recv_common(worker, buffer, buffer_size, datatype, tag, - tag_mask, req, NULL, "recv_nbr"); - if (status != UCS_INPROGRESS) { - ucp_tag_recv_request_completed(req, status, &req->recv.info, "recv_nbr"); - } + status = ucp_tag_recv_common(worker, buffer, count, datatype, tag, tag_mask, + req, UCP_REQUEST_DEBUG_FLAG_EXTERNAL, NULL, NULL, + "recv_nbr"); UCP_THREAD_CS_EXIT_CONDITIONAL(&worker->context->mt_lock); UCP_THREAD_CS_EXIT_CONDITIONAL(&worker->mt_lock); @@ -206,32 +200,21 @@ UCS_PROFILE_FUNC(ucs_status_ptr_t, ucp_tag_recv_nb, uintptr_t datatype, ucp_tag_t tag, ucp_tag_t tag_mask, ucp_tag_recv_callback_t cb) { - ucp_request_t *req; - ucs_status_t status; ucs_status_ptr_t ret; - size_t buffer_size; + ucp_request_t *req; UCP_THREAD_CS_ENTER_CONDITIONAL(&worker->mt_lock); UCP_THREAD_CS_ENTER_CONDITIONAL(&worker->context->mt_lock); - req = ucp_tag_recv_request_get(worker, buffer, count, datatype); - if (ucs_unlikely(req == NULL)) { + req = ucp_request_get(worker); + if (ucs_likely(req != NULL)) { + ucp_tag_recv_common(worker, buffer, count, datatype, tag, tag_mask, req, + UCP_REQUEST_FLAG_CALLBACK, cb, NULL, "recv_nb"); + ret = req + 1; + } else { ret = UCS_STATUS_PTR(UCS_ERR_NO_MEMORY); - goto out; } - buffer_size = ucp_dt_length(datatype, count, buffer, &req->recv.state); - - status = ucp_tag_recv_common(worker, buffer, buffer_size, datatype, tag, - tag_mask, req, cb, "recv_nb"); - - if (status != UCS_INPROGRESS) { - cb(req + 1, status, &req->recv.info); - ucp_tag_recv_request_completed(req, status, &req->recv.info, "recv_nb"); - } - - ret = req + 1; -out: UCP_THREAD_CS_EXIT_CONDITIONAL(&worker->context->mt_lock); UCP_THREAD_CS_EXIT_CONDITIONAL(&worker->mt_lock); return ret; @@ -243,82 +226,26 @@ UCS_PROFILE_FUNC(ucs_status_ptr_t, ucp_tag_msg_recv_nb, uintptr_t datatype, ucp_tag_message_h message, ucp_tag_recv_callback_t cb) { - ucp_recv_desc_t *rdesc = message; - ucs_status_t status; - ucp_request_t *req; - ucp_tag_t tag; - unsigned save_rreq = 1; + ucs_queue_iter_t iter = message; + ucp_recv_desc_t *rdesc; ucs_status_ptr_t ret; - size_t buffer_size; + ucp_request_t *req; UCP_THREAD_CS_ENTER_CONDITIONAL(&worker->mt_lock); UCP_THREAD_CS_ENTER_CONDITIONAL(&worker->context->mt_lock); - ucs_trace_req("msg_recv_nb buffer %p count %zu message %p", buffer, count, - message); - - req = ucp_tag_recv_request_get(worker, buffer, count, datatype); - if (req == NULL) { - ret = UCS_STATUS_PTR(UCS_ERR_NO_MEMORY); - goto out; - } - - buffer_size = ucp_dt_length(datatype, count, buffer, &req->recv.state); - - /* First, handle the first packet that was already matched */ - if (rdesc->flags & UCP_RECV_DESC_FLAG_EAGER) { - tag = ((ucp_tag_hdr_t*)(rdesc + 1))->tag; - UCS_PROFILE_REQUEST_EVENT(req, "eager_match", 0); - status = ucp_eager_unexp_match(worker, rdesc, tag, rdesc->flags, - buffer, buffer_size, datatype, - &req->recv.state, &req->recv.info); - ucs_trace_req("release receive descriptor %p", rdesc); - ucp_tag_unexp_desc_release(rdesc); - } else if (rdesc->flags & UCP_RECV_DESC_FLAG_RNDV) { - req->recv.buffer = buffer; - req->recv.length = buffer_size; - req->recv.datatype = datatype; - req->recv.cb = cb; - ucp_rndv_matched(worker, req, (void*)(rdesc + 1)); - ucp_tag_unexp_desc_release(rdesc); - status = UCS_INPROGRESS; - save_rreq = 0; - UCP_WORKER_STAT_RNDV(worker, UNEXP); + req = ucp_request_get(worker); + if (ucs_likely(req != NULL)) { + rdesc = ucs_queue_iter_elem(rdesc, iter, queue); + rdesc->flags |= UCP_RECV_DESC_FLAG_FIRST; + ucp_tag_recv_common(worker, buffer, count, datatype, + ucp_rdesc_get_tag(rdesc), UCP_TAG_MASK_FULL, req, + UCP_REQUEST_FLAG_CALLBACK, cb, iter, "msg_recv_nb"); + ret = req + 1; } else { - ucp_request_put(req); - ret = UCS_STATUS_PTR(UCS_ERR_INVALID_PARAM); - goto out; - } - - /* Since the message contains only the first fragment, we might want - * to receive additional fragments. - */ - if (status == UCS_INPROGRESS) { - status = ucp_tag_search_unexp(worker, buffer, buffer_size, datatype, 0, - -1, req, &req->recv.info, cb, &save_rreq); - } - - if (status != UCS_INPROGRESS) { - cb(req + 1, status, &req->recv.info); - ucp_tag_recv_request_completed(req, status, &req->recv.info, - "msg_recv_nb"); - } else if (save_rreq) { - ucs_trace_req("msg_recv_nb returning inprogress request %p (%p)", - req, req + 1); - /* For eager - need to put the recv_req in expected since more packets - * will follow. For rndv - don't need to keep the recv_req in the expected queue - * as the match to the RTS already happened. */ - req->recv.buffer = buffer; - req->recv.length = buffer_size; - req->recv.datatype = datatype; - req->recv.cb = cb; - req->recv.tag = req->recv.info.sender_tag; - req->recv.tag_mask = UCP_TAG_MASK_FULL; - ucp_tag_exp_add(&worker->context->tm, req); + ret = UCS_STATUS_PTR(UCS_ERR_NO_MEMORY); } - ret = req + 1; -out: UCP_THREAD_CS_EXIT_CONDITIONAL(&worker->context->mt_lock); UCP_THREAD_CS_EXIT_CONDITIONAL(&worker->mt_lock); return ret; diff --git a/test/gtest/ucp/test_ucp_tag_probe.cc b/test/gtest/ucp/test_ucp_tag_probe.cc index 6294a92b5d0..9d66ad73d80 100644 --- a/test/gtest/ucp/test_ucp_tag_probe.cc +++ b/test/gtest/ucp/test_ucp_tag_probe.cc @@ -179,6 +179,64 @@ UCS_TEST_P(test_ucp_tag_probe, send_rndv_msg_probe, "RNDV_THRESH=1048576") { request_release(my_recv_req); } +UCS_TEST_P(test_ucp_tag_probe, send_2_msg_probe, "RNDV_THRESH=inf") { + const ucp_datatype_t DT_INT = ucp_dt_make_contig(sizeof(int)); + const ucp_tag_t TAG = 0xaaa; + const size_t COUNT = 20000; + + /* + * send in order: 1, 2 + */ + std::vector sdata1(COUNT, 1); + std::vector sdata2(COUNT, 2); + send_b(&sdata1[0], COUNT, DT_INT, TAG); + send_b(&sdata2[0], COUNT, DT_INT, TAG); + + /* + * probe in order: 1, 2 + */ + ucp_tag_message_h message1, message2; + ucp_tag_recv_info info; + do { + progress(); + message1 = ucp_tag_probe_nb(receiver().worker(), TAG, 0xffff, 1, &info); + } while (message1 == NULL); + do { + progress(); + message2 = ucp_tag_probe_nb(receiver().worker(), TAG, 0xffff, 1, &info); + } while (message2 == NULL); + + /* + * receive in **reverse** order: 2, 1 + */ + std::vector rdata2(COUNT); + request *rreq2 = (request*)ucp_tag_msg_recv_nb(receiver().worker(), &rdata2[0], + COUNT, DT_INT, message2, + recv_callback); + ASSERT_TRUE(!UCS_PTR_IS_ERR(rreq2)); + wait(rreq2); + + std::vector rdata1(COUNT); + request *rreq1 = (request*)ucp_tag_msg_recv_nb(receiver().worker(), &rdata1[0], + COUNT, DT_INT, message1, + recv_callback); + ASSERT_TRUE(!UCS_PTR_IS_ERR(rreq1)); + wait(rreq1); + + /* + * expect data to arrive in probe order (rather than recv order) + */ + EXPECT_TRUE(rreq1->completed); + EXPECT_TRUE(rreq2->completed); + EXPECT_EQ(UCS_OK, rreq1->status); + EXPECT_EQ(UCS_OK, rreq2->status); + EXPECT_EQ(sdata1, rdata1); + EXPECT_EQ(sdata2, rdata2); + + request_release(rreq1); + request_release(rreq2); +} + UCS_TEST_P(test_ucp_tag_probe, limited_probe_size) { static const int COUNT = 1000; std::string sendbuf, recvbuf;