Skip to content

Commit

Permalink
UCP/TAG/TEST: Fix tag_msg_probe and unite flow for receive funcs.
Browse files Browse the repository at this point in the history
  • Loading branch information
yosefe committed May 23, 2017
1 parent 5b551dc commit cc693c4
Show file tree
Hide file tree
Showing 5 changed files with 117 additions and 128 deletions.
2 changes: 1 addition & 1 deletion src/ucp/api/ucp_def.h
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ typedef uint64_t ucp_tag_t;
* @ref ucp_tag_probe_nb. This handle can be passed to @ref ucp_tag_msg_recv_nb
* in order to receive the message data to a specific buffer.
*/
typedef struct ucp_recv_desc *ucp_tag_message_h;
typedef void *ucp_tag_message_h;


/**
Expand Down
15 changes: 9 additions & 6 deletions src/ucp/tag/probe.c
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
#include <ucs/datastruct/queue.h>


static UCS_F_ALWAYS_INLINE ucp_recv_desc_t*
static UCS_F_ALWAYS_INLINE ucs_queue_iter_t
ucp_tag_probe_search(ucp_context_h context, ucp_tag_t tag, uint64_t tag_mask,
ucp_tag_recv_info_t *info, int remove)
{
Expand Down Expand Up @@ -44,9 +44,12 @@ ucp_tag_probe_search(ucp_context_h context, ucp_tag_t tag, uint64_t tag_mask,
}

if (remove) {
ucs_queue_del_iter(&context->tm.unexpected, iter);
/* Prevent the receive descriptor, and any fragments after it,
* from being matched by receive requests.
*/
rdesc->flags &= ~UCP_RECV_DESC_FLAG_FIRST;
}
return rdesc;
return iter;
}
}

Expand All @@ -58,16 +61,16 @@ ucp_tag_message_h ucp_tag_probe_nb(ucp_worker_h worker, ucp_tag_t tag,
ucp_tag_recv_info_t *info)
{
ucp_context_h context = worker->context;
ucp_recv_desc_t *ret;
ucs_queue_iter_t message;

UCP_THREAD_CS_ENTER_CONDITIONAL(&worker->mt_lock);
UCP_THREAD_CS_ENTER_CONDITIONAL(&context->mt_lock);

ucs_trace_req("probe_nb tag %"PRIx64"/%"PRIx64, tag, tag_mask);
ret = ucp_tag_probe_search(context, tag, tag_mask, info, remove);
message = ucp_tag_probe_search(context, tag, tag_mask, info, remove);

UCP_THREAD_CS_EXIT_CONDITIONAL(&context->mt_lock);
UCP_THREAD_CS_EXIT_CONDITIONAL(&worker->mt_lock);

return ret;
return message;
}
1 change: 1 addition & 0 deletions src/ucp/tag/tag_match.inl
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ ucp_tag_unexp_recv(ucp_tag_match_t *tm, ucp_worker_h worker, void *data,
static UCS_F_ALWAYS_INLINE void
ucp_tag_unexp_desc_release(ucp_recv_desc_t *rdesc)
{
ucs_trace_req("release receive descriptor %p", rdesc);
if (ucs_unlikely(rdesc->flags & UCP_RECV_DESC_FLAG_UCT_DESC)) {
uct_iface_release_desc(rdesc); /* uct desc is slowpath */
} else {
Expand Down
169 changes: 48 additions & 121 deletions src/ucp/tag/tag_recv.c
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ static UCS_F_ALWAYS_INLINE ucs_status_t
ucp_tag_search_unexp(ucp_worker_h worker, void *buffer, size_t buffer_size,
ucp_datatype_t datatype, ucp_tag_t tag, uint64_t tag_mask,
ucp_request_t *req, ucp_tag_recv_info_t *info,
ucp_tag_recv_callback_t cb, unsigned *save_rreq)
ucp_tag_recv_callback_t cb, ucs_queue_iter_t first_iter,
unsigned *save_rreq)
{
ucp_context_h context = worker->context;
ucp_recv_desc_t *rdesc;
Expand All @@ -27,7 +28,14 @@ ucp_tag_search_unexp(ucp_worker_h worker, void *buffer, size_t buffer_size,
ucp_tag_t recv_tag;
unsigned flags;

ucs_queue_for_each_safe(rdesc, iter, &context->tm.unexpected, queue) {
if (first_iter == NULL) {
iter = ucs_queue_iter_begin(&context->tm.unexpected);
} else {
iter = first_iter;
}

while (!ucs_queue_iter_end(&context->tm.unexpected, iter)) {
rdesc = ucs_queue_iter_elem(rdesc, iter, queue);
recv_tag = ucp_rdesc_get_tag(rdesc);
flags = rdesc->flags;
ucs_trace_req("searching for %"PRIx64"/%"PRIx64"/%"PRIx64" offset %zu, "
Expand Down Expand Up @@ -66,6 +74,8 @@ ucp_tag_search_unexp(ucp_worker_h worker, void *buffer, size_t buffer_size,
UCP_WORKER_STAT_RNDV(worker, UNEXP);
return UCS_INPROGRESS;
}
} else {
iter = ucs_queue_iter_next(iter);
}
}

Expand Down Expand Up @@ -108,22 +118,6 @@ ucp_tag_recv_request_init(ucp_request_t *req, ucp_worker_h worker, void* buffer,
}
}

/**
 * Obtain a request from @a worker and prepare it as a callback-style tagged
 * receive request.
 *
 * @param worker    Worker the request is taken from (via ucp_request_get()).
 * @param buffer    Receive buffer, forwarded to ucp_tag_recv_request_init().
 * @param count     Element count, forwarded to ucp_tag_recv_request_init().
 * @param datatype  Datatype, forwarded to ucp_tag_recv_request_init().
 *
 * @return The initialized request, or NULL when ucp_request_get() returned
 *         NULL (in which case no initialization is attempted).
 */
static UCS_F_ALWAYS_INLINE ucp_request_t*
ucp_tag_recv_request_get(ucp_worker_h worker, void* buffer, size_t count,
                         ucp_datatype_t datatype)
{
    ucp_request_t *rreq = ucp_request_get(worker);

    /* Single exit: only a successfully obtained request gets initialized;
     * a NULL request is handed back to the caller untouched. */
    if (!ucs_unlikely(rreq == NULL)) {
        ucp_tag_recv_request_init(rreq, worker, buffer, count, datatype,
                                  UCP_REQUEST_FLAG_CALLBACK);
    }

    return rreq;
}

static UCS_F_ALWAYS_INLINE void
ucp_tag_recv_request_completed(ucp_request_t *req, ucs_status_t status,
ucp_tag_recv_info_t *info, const char *function)
Expand All @@ -138,22 +132,30 @@ ucp_tag_recv_request_completed(ucp_request_t *req, ucs_status_t status,
}

static UCS_F_ALWAYS_INLINE ucs_status_t
ucp_tag_recv_common(ucp_worker_h worker, void *buffer, size_t buffer_size,
ucp_tag_recv_common(ucp_worker_h worker, void *buffer, size_t count,
uintptr_t datatype, ucp_tag_t tag, ucp_tag_t tag_mask,
ucp_request_t *req, ucp_tag_recv_callback_t cb,
const char *debug_name)
ucp_request_t *req, uint16_t req_flags, ucp_tag_recv_callback_t cb,
ucs_queue_iter_t iter, const char *debug_name)
{
ucs_status_t status;
unsigned save_rreq = 1;
size_t buffer_size;

ucp_tag_recv_request_init(req, worker, buffer, count, datatype, req_flags);
buffer_size = ucp_dt_length(datatype, count, buffer, &req->recv.state);

ucs_trace_req("%s buffer %p buffer_size %zu tag %"PRIx64"/%"PRIx64, debug_name,
buffer, buffer_size, tag, tag_mask);

/* First, search in unexpected list */
status = ucp_tag_search_unexp(worker, buffer, buffer_size, datatype, tag,
tag_mask, req, &req->recv.info, cb, &save_rreq);
tag_mask, req, &req->recv.info, cb, iter,
&save_rreq);
if (status != UCS_INPROGRESS) {
return status;
if (req_flags & UCP_REQUEST_FLAG_CALLBACK) {
cb(req + 1, status, &req->recv.info);
}
ucp_tag_recv_request_completed(req, status, &req->recv.info, debug_name);
} else if (save_rreq) {
/* If not found on unexpected, wait until it arrives.
* If it was found but this receive request is still needed for later completion, save it */
Expand All @@ -177,23 +179,15 @@ UCS_PROFILE_FUNC(ucs_status_t, ucp_tag_recv_nbr,
uintptr_t datatype, ucp_tag_t tag, ucp_tag_t tag_mask,
void *request)
{
ucp_request_t *req = (ucp_request_t *)request - 1;
ucp_request_t *req = (ucp_request_t *)request - 1;
ucs_status_t status;
size_t buffer_size;

UCP_THREAD_CS_ENTER_CONDITIONAL(&worker->mt_lock);
UCP_THREAD_CS_ENTER_CONDITIONAL(&worker->context->mt_lock);

ucp_tag_recv_request_init(req, worker, buffer, count, datatype,
UCP_REQUEST_DEBUG_FLAG_EXTERNAL);

buffer_size = ucp_dt_length(datatype, count, buffer, &req->recv.state);

status = ucp_tag_recv_common(worker, buffer, buffer_size, datatype, tag,
tag_mask, req, NULL, "recv_nbr");
if (status != UCS_INPROGRESS) {
ucp_tag_recv_request_completed(req, status, &req->recv.info, "recv_nbr");
}
status = ucp_tag_recv_common(worker, buffer, count, datatype, tag, tag_mask,
req, UCP_REQUEST_DEBUG_FLAG_EXTERNAL, NULL, NULL,
"recv_nbr");

UCP_THREAD_CS_EXIT_CONDITIONAL(&worker->context->mt_lock);
UCP_THREAD_CS_EXIT_CONDITIONAL(&worker->mt_lock);
Expand All @@ -206,32 +200,21 @@ UCS_PROFILE_FUNC(ucs_status_ptr_t, ucp_tag_recv_nb,
uintptr_t datatype, ucp_tag_t tag, ucp_tag_t tag_mask,
ucp_tag_recv_callback_t cb)
{
ucp_request_t *req;
ucs_status_t status;
ucs_status_ptr_t ret;
size_t buffer_size;
ucp_request_t *req;

UCP_THREAD_CS_ENTER_CONDITIONAL(&worker->mt_lock);
UCP_THREAD_CS_ENTER_CONDITIONAL(&worker->context->mt_lock);

req = ucp_tag_recv_request_get(worker, buffer, count, datatype);
if (ucs_unlikely(req == NULL)) {
req = ucp_request_get(worker);
if (ucs_likely(req != NULL)) {
ucp_tag_recv_common(worker, buffer, count, datatype, tag, tag_mask, req,
UCP_REQUEST_FLAG_CALLBACK, cb, NULL, "recv_nb");
ret = req + 1;
} else {
ret = UCS_STATUS_PTR(UCS_ERR_NO_MEMORY);
goto out;
}

buffer_size = ucp_dt_length(datatype, count, buffer, &req->recv.state);

status = ucp_tag_recv_common(worker, buffer, buffer_size, datatype, tag,
tag_mask, req, cb, "recv_nb");

if (status != UCS_INPROGRESS) {
cb(req + 1, status, &req->recv.info);
ucp_tag_recv_request_completed(req, status, &req->recv.info, "recv_nb");
}

ret = req + 1;
out:
UCP_THREAD_CS_EXIT_CONDITIONAL(&worker->context->mt_lock);
UCP_THREAD_CS_EXIT_CONDITIONAL(&worker->mt_lock);
return ret;
Expand All @@ -243,82 +226,26 @@ UCS_PROFILE_FUNC(ucs_status_ptr_t, ucp_tag_msg_recv_nb,
uintptr_t datatype, ucp_tag_message_h message,
ucp_tag_recv_callback_t cb)
{
ucp_recv_desc_t *rdesc = message;
ucs_status_t status;
ucp_request_t *req;
ucp_tag_t tag;
unsigned save_rreq = 1;
ucs_queue_iter_t iter = message;
ucp_recv_desc_t *rdesc;
ucs_status_ptr_t ret;
size_t buffer_size;
ucp_request_t *req;

UCP_THREAD_CS_ENTER_CONDITIONAL(&worker->mt_lock);
UCP_THREAD_CS_ENTER_CONDITIONAL(&worker->context->mt_lock);

ucs_trace_req("msg_recv_nb buffer %p count %zu message %p", buffer, count,
message);

req = ucp_tag_recv_request_get(worker, buffer, count, datatype);
if (req == NULL) {
ret = UCS_STATUS_PTR(UCS_ERR_NO_MEMORY);
goto out;
}

buffer_size = ucp_dt_length(datatype, count, buffer, &req->recv.state);

/* First, handle the first packet that was already matched */
if (rdesc->flags & UCP_RECV_DESC_FLAG_EAGER) {
tag = ((ucp_tag_hdr_t*)(rdesc + 1))->tag;
UCS_PROFILE_REQUEST_EVENT(req, "eager_match", 0);
status = ucp_eager_unexp_match(worker, rdesc, tag, rdesc->flags,
buffer, buffer_size, datatype,
&req->recv.state, &req->recv.info);
ucs_trace_req("release receive descriptor %p", rdesc);
ucp_tag_unexp_desc_release(rdesc);
} else if (rdesc->flags & UCP_RECV_DESC_FLAG_RNDV) {
req->recv.buffer = buffer;
req->recv.length = buffer_size;
req->recv.datatype = datatype;
req->recv.cb = cb;
ucp_rndv_matched(worker, req, (void*)(rdesc + 1));
ucp_tag_unexp_desc_release(rdesc);
status = UCS_INPROGRESS;
save_rreq = 0;
UCP_WORKER_STAT_RNDV(worker, UNEXP);
req = ucp_request_get(worker);
if (ucs_likely(req != NULL)) {
rdesc = ucs_queue_iter_elem(rdesc, iter, queue);
rdesc->flags |= UCP_RECV_DESC_FLAG_FIRST;
ucp_tag_recv_common(worker, buffer, count, datatype,
ucp_rdesc_get_tag(rdesc), UCP_TAG_MASK_FULL, req,
UCP_REQUEST_FLAG_CALLBACK, cb, iter, "msg_recv_nb");
ret = req + 1;
} else {
ucp_request_put(req);
ret = UCS_STATUS_PTR(UCS_ERR_INVALID_PARAM);
goto out;
}

/* Since the message contains only the first fragment, we might want
* to receive additional fragments.
*/
if (status == UCS_INPROGRESS) {
status = ucp_tag_search_unexp(worker, buffer, buffer_size, datatype, 0,
-1, req, &req->recv.info, cb, &save_rreq);
}

if (status != UCS_INPROGRESS) {
cb(req + 1, status, &req->recv.info);
ucp_tag_recv_request_completed(req, status, &req->recv.info,
"msg_recv_nb");
} else if (save_rreq) {
ucs_trace_req("msg_recv_nb returning inprogress request %p (%p)",
req, req + 1);
/* For eager - need to put the recv_req in expected since more packets
* will follow. For rndv - don't need to keep the recv_req in the expected queue
* as the match to the RTS already happened. */
req->recv.buffer = buffer;
req->recv.length = buffer_size;
req->recv.datatype = datatype;
req->recv.cb = cb;
req->recv.tag = req->recv.info.sender_tag;
req->recv.tag_mask = UCP_TAG_MASK_FULL;
ucp_tag_exp_add(&worker->context->tm, req);
ret = UCS_STATUS_PTR(UCS_ERR_NO_MEMORY);
}

ret = req + 1;
out:
UCP_THREAD_CS_EXIT_CONDITIONAL(&worker->context->mt_lock);
UCP_THREAD_CS_EXIT_CONDITIONAL(&worker->mt_lock);
return ret;
Expand Down
58 changes: 58 additions & 0 deletions test/gtest/ucp/test_ucp_tag_probe.cc
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,64 @@ UCS_TEST_P(test_ucp_tag_probe, send_rndv_msg_probe, "RNDV_THRESH=1048576") {
request_release(my_recv_req);
}

/*
* send_2_msg_probe: send two messages with the same tag, probe both of them
* to obtain two message handles, then receive through the handles in the
* opposite order and verify that each handle delivers the payload of the
* message it was probed for.
*
* NOTE(review): the "RNDV_THRESH=inf" config string presumably keeps both
* sends on the eager protocol - confirm against the UCX config semantics.
*/
UCS_TEST_P(test_ucp_tag_probe, send_2_msg_probe, "RNDV_THRESH=inf") {
const ucp_datatype_t DT_INT = ucp_dt_make_contig(sizeof(int));
const ucp_tag_t TAG = 0xaaa;
/* 20000 ints per message; presumably large enough to span more than one
* eager fragment - TODO confirm against the transport's fragment size */
const size_t COUNT = 20000;

/*
* send in order: 1, 2
* (payloads are filled with 1s and 2s so they can be told apart later)
*/
std::vector<int> sdata1(COUNT, 1);
std::vector<int> sdata2(COUNT, 2);
send_b(&sdata1[0], COUNT, DT_INT, TAG);
send_b(&sdata2[0], COUNT, DT_INT, TAG);

/*
* probe in order: 1, 2
* Each loop drives progress() until the probe returns a non-NULL handle.
* The 4th probe argument (1) is presumably the 'remove' flag, so the second
* probe returns a different message than the first - verify against the
* ucp_tag_probe_nb() declaration.
*/
ucp_tag_message_h message1, message2;
ucp_tag_recv_info info;
do {
progress();
message1 = ucp_tag_probe_nb(receiver().worker(), TAG, 0xffff, 1, &info);
} while (message1 == NULL);
do {
progress();
message2 = ucp_tag_probe_nb(receiver().worker(), TAG, 0xffff, 1, &info);
} while (message2 == NULL);

/*
* receive in **reverse** order: 2, 1
* The receive for message2 is posted and waited on before the receive for
* message1 is even posted.
*/
std::vector<int> rdata2(COUNT);
request *rreq2 = (request*)ucp_tag_msg_recv_nb(receiver().worker(), &rdata2[0],
COUNT, DT_INT, message2,
recv_callback);
ASSERT_TRUE(!UCS_PTR_IS_ERR(rreq2));
wait(rreq2);

std::vector<int> rdata1(COUNT);
request *rreq1 = (request*)ucp_tag_msg_recv_nb(receiver().worker(), &rdata1[0],
COUNT, DT_INT, message1,
recv_callback);
ASSERT_TRUE(!UCS_PTR_IS_ERR(rreq1));
wait(rreq1);

/*
* expect each handle to deliver the data of the message it was probed for
* (probe order), regardless of the order in which the receives were posted
*/
EXPECT_TRUE(rreq1->completed);
EXPECT_TRUE(rreq2->completed);
EXPECT_EQ(UCS_OK, rreq1->status);
EXPECT_EQ(UCS_OK, rreq2->status);
EXPECT_EQ(sdata1, rdata1);
EXPECT_EQ(sdata2, rdata2);

request_release(rreq1);
request_release(rreq2);
}

UCS_TEST_P(test_ucp_tag_probe, limited_probe_size) {
static const int COUNT = 1000;
std::string sendbuf, recvbuf;
Expand Down

0 comments on commit cc693c4

Please sign in to comment.