Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

UCP/TAG/TEST: Fix tag_msg_probe and unite flow for receive funcs (v1.2) #1537

Merged
merged 1 commit into from
May 23, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/ucp/api/ucp_def.h
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ typedef uint64_t ucp_tag_t;
* @ref ucp_tag_probe_nb. This handle can be passed to @ref ucp_tag_msg_recv_nb
* in order to receive the message data to a specific buffer.
*/
typedef struct ucp_recv_desc *ucp_tag_message_h;
typedef void *ucp_tag_message_h;


/**
Expand Down
15 changes: 9 additions & 6 deletions src/ucp/tag/probe.c
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
#include <ucs/datastruct/queue.h>


static UCS_F_ALWAYS_INLINE ucp_recv_desc_t*
static UCS_F_ALWAYS_INLINE ucs_queue_iter_t
ucp_tag_probe_search(ucp_context_h context, ucp_tag_t tag, uint64_t tag_mask,
ucp_tag_recv_info_t *info, int remove)
{
Expand Down Expand Up @@ -44,9 +44,12 @@ ucp_tag_probe_search(ucp_context_h context, ucp_tag_t tag, uint64_t tag_mask,
}

if (remove) {
ucs_queue_del_iter(&context->tm.unexpected, iter);
/* Prevent the receive descriptor, and any fragments after it,
* from being matched by receive requests.
*/
rdesc->flags &= ~UCP_RECV_DESC_FLAG_FIRST;
}
return rdesc;
return iter;
}
}

Expand All @@ -58,16 +61,16 @@ ucp_tag_message_h ucp_tag_probe_nb(ucp_worker_h worker, ucp_tag_t tag,
ucp_tag_recv_info_t *info)
{
ucp_context_h context = worker->context;
ucp_recv_desc_t *ret;
ucs_queue_iter_t message;

UCP_THREAD_CS_ENTER_CONDITIONAL(&worker->mt_lock);
UCP_THREAD_CS_ENTER_CONDITIONAL(&context->mt_lock);

ucs_trace_req("probe_nb tag %"PRIx64"/%"PRIx64, tag, tag_mask);
ret = ucp_tag_probe_search(context, tag, tag_mask, info, remove);
message = ucp_tag_probe_search(context, tag, tag_mask, info, remove);

UCP_THREAD_CS_EXIT_CONDITIONAL(&context->mt_lock);
UCP_THREAD_CS_EXIT_CONDITIONAL(&worker->mt_lock);

return ret;
return message;
}
1 change: 1 addition & 0 deletions src/ucp/tag/tag_match.inl
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ ucp_tag_unexp_recv(ucp_tag_match_t *tm, ucp_worker_h worker, void *data,
static UCS_F_ALWAYS_INLINE void
ucp_tag_unexp_desc_release(ucp_recv_desc_t *rdesc)
{
ucs_trace_req("release receive descriptor %p", rdesc);
if (ucs_unlikely(rdesc->flags & UCP_RECV_DESC_FLAG_UCT_DESC)) {
uct_iface_release_desc(rdesc); /* uct desc is slowpath */
} else {
Expand Down
169 changes: 48 additions & 121 deletions src/ucp/tag/tag_recv.c
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ static UCS_F_ALWAYS_INLINE ucs_status_t
ucp_tag_search_unexp(ucp_worker_h worker, void *buffer, size_t buffer_size,
ucp_datatype_t datatype, ucp_tag_t tag, uint64_t tag_mask,
ucp_request_t *req, ucp_tag_recv_info_t *info,
ucp_tag_recv_callback_t cb, unsigned *save_rreq)
ucp_tag_recv_callback_t cb, ucs_queue_iter_t first_iter,
unsigned *save_rreq)
{
ucp_context_h context = worker->context;
ucp_recv_desc_t *rdesc;
Expand All @@ -27,7 +28,14 @@ ucp_tag_search_unexp(ucp_worker_h worker, void *buffer, size_t buffer_size,
ucp_tag_t recv_tag;
unsigned flags;

ucs_queue_for_each_safe(rdesc, iter, &context->tm.unexpected, queue) {
if (first_iter == NULL) {
iter = ucs_queue_iter_begin(&context->tm.unexpected);
} else {
iter = first_iter;
}

while (!ucs_queue_iter_end(&context->tm.unexpected, iter)) {
rdesc = ucs_queue_iter_elem(rdesc, iter, queue);
recv_tag = ucp_rdesc_get_tag(rdesc);
flags = rdesc->flags;
ucs_trace_req("searching for %"PRIx64"/%"PRIx64"/%"PRIx64" offset %zu, "
Expand Down Expand Up @@ -66,6 +74,8 @@ ucp_tag_search_unexp(ucp_worker_h worker, void *buffer, size_t buffer_size,
UCP_WORKER_STAT_RNDV(worker, UNEXP);
return UCS_INPROGRESS;
}
} else {
iter = ucs_queue_iter_next(iter);
}
}

Expand Down Expand Up @@ -108,22 +118,6 @@ ucp_tag_recv_request_init(ucp_request_t *req, ucp_worker_h worker, void* buffer,
}
}

/*
 * Obtain a request from the worker via ucp_request_get() and initialize it as
 * a tag receive request carrying UCP_REQUEST_FLAG_CALLBACK.
 *
 * Returns the initialized request, or NULL when ucp_request_get() returned
 * NULL (in which case nothing is initialized).
 */
static UCS_F_ALWAYS_INLINE ucp_request_t*
ucp_tag_recv_request_get(ucp_worker_h worker, void* buffer, size_t count,
                         ucp_datatype_t datatype)
{
    ucp_request_t *rreq = ucp_request_get(worker);

    /* Only a successfully obtained request is initialized; a NULL result is
     * passed straight back to the caller. */
    if (!ucs_unlikely(rreq == NULL)) {
        ucp_tag_recv_request_init(rreq, worker, buffer, count, datatype,
                                  UCP_REQUEST_FLAG_CALLBACK);
    }

    return rreq;
}

static UCS_F_ALWAYS_INLINE void
ucp_tag_recv_request_completed(ucp_request_t *req, ucs_status_t status,
ucp_tag_recv_info_t *info, const char *function)
Expand All @@ -138,22 +132,30 @@ ucp_tag_recv_request_completed(ucp_request_t *req, ucs_status_t status,
}

static UCS_F_ALWAYS_INLINE ucs_status_t
ucp_tag_recv_common(ucp_worker_h worker, void *buffer, size_t buffer_size,
ucp_tag_recv_common(ucp_worker_h worker, void *buffer, size_t count,
uintptr_t datatype, ucp_tag_t tag, ucp_tag_t tag_mask,
ucp_request_t *req, ucp_tag_recv_callback_t cb,
const char *debug_name)
ucp_request_t *req, uint16_t req_flags, ucp_tag_recv_callback_t cb,
ucs_queue_iter_t iter, const char *debug_name)
{
ucs_status_t status;
unsigned save_rreq = 1;
size_t buffer_size;

ucp_tag_recv_request_init(req, worker, buffer, count, datatype, req_flags);
buffer_size = ucp_dt_length(datatype, count, buffer, &req->recv.state);

ucs_trace_req("%s buffer %p buffer_size %zu tag %"PRIx64"/%"PRIx64, debug_name,
buffer, buffer_size, tag, tag_mask);

/* First, search in unexpected list */
status = ucp_tag_search_unexp(worker, buffer, buffer_size, datatype, tag,
tag_mask, req, &req->recv.info, cb, &save_rreq);
tag_mask, req, &req->recv.info, cb, iter,
&save_rreq);
if (status != UCS_INPROGRESS) {
return status;
if (req_flags & UCP_REQUEST_FLAG_CALLBACK) {
cb(req + 1, status, &req->recv.info);
}
ucp_tag_recv_request_completed(req, status, &req->recv.info, debug_name);
} else if (save_rreq) {
/* If not found in the unexpected queue, wait until it arrives.
* If it was found but this receive request is needed for later completion, save it */
Expand All @@ -177,23 +179,15 @@ UCS_PROFILE_FUNC(ucs_status_t, ucp_tag_recv_nbr,
uintptr_t datatype, ucp_tag_t tag, ucp_tag_t tag_mask,
void *request)
{
ucp_request_t *req = (ucp_request_t *)request - 1;
ucp_request_t *req = (ucp_request_t *)request - 1;
ucs_status_t status;
size_t buffer_size;

UCP_THREAD_CS_ENTER_CONDITIONAL(&worker->mt_lock);
UCP_THREAD_CS_ENTER_CONDITIONAL(&worker->context->mt_lock);

ucp_tag_recv_request_init(req, worker, buffer, count, datatype,
UCP_REQUEST_DEBUG_FLAG_EXTERNAL);

buffer_size = ucp_dt_length(datatype, count, buffer, &req->recv.state);

status = ucp_tag_recv_common(worker, buffer, buffer_size, datatype, tag,
tag_mask, req, NULL, "recv_nbr");
if (status != UCS_INPROGRESS) {
ucp_tag_recv_request_completed(req, status, &req->recv.info, "recv_nbr");
}
status = ucp_tag_recv_common(worker, buffer, count, datatype, tag, tag_mask,
req, UCP_REQUEST_DEBUG_FLAG_EXTERNAL, NULL, NULL,
"recv_nbr");

UCP_THREAD_CS_EXIT_CONDITIONAL(&worker->context->mt_lock);
UCP_THREAD_CS_EXIT_CONDITIONAL(&worker->mt_lock);
Expand All @@ -206,32 +200,21 @@ UCS_PROFILE_FUNC(ucs_status_ptr_t, ucp_tag_recv_nb,
uintptr_t datatype, ucp_tag_t tag, ucp_tag_t tag_mask,
ucp_tag_recv_callback_t cb)
{
ucp_request_t *req;
ucs_status_t status;
ucs_status_ptr_t ret;
size_t buffer_size;
ucp_request_t *req;

UCP_THREAD_CS_ENTER_CONDITIONAL(&worker->mt_lock);
UCP_THREAD_CS_ENTER_CONDITIONAL(&worker->context->mt_lock);

req = ucp_tag_recv_request_get(worker, buffer, count, datatype);
if (ucs_unlikely(req == NULL)) {
req = ucp_request_get(worker);
if (ucs_likely(req != NULL)) {
ucp_tag_recv_common(worker, buffer, count, datatype, tag, tag_mask, req,
UCP_REQUEST_FLAG_CALLBACK, cb, NULL, "recv_nb");
ret = req + 1;
} else {
ret = UCS_STATUS_PTR(UCS_ERR_NO_MEMORY);
goto out;
}

buffer_size = ucp_dt_length(datatype, count, buffer, &req->recv.state);

status = ucp_tag_recv_common(worker, buffer, buffer_size, datatype, tag,
tag_mask, req, cb, "recv_nb");

if (status != UCS_INPROGRESS) {
cb(req + 1, status, &req->recv.info);
ucp_tag_recv_request_completed(req, status, &req->recv.info, "recv_nb");
}

ret = req + 1;
out:
UCP_THREAD_CS_EXIT_CONDITIONAL(&worker->context->mt_lock);
UCP_THREAD_CS_EXIT_CONDITIONAL(&worker->mt_lock);
return ret;
Expand All @@ -243,82 +226,26 @@ UCS_PROFILE_FUNC(ucs_status_ptr_t, ucp_tag_msg_recv_nb,
uintptr_t datatype, ucp_tag_message_h message,
ucp_tag_recv_callback_t cb)
{
ucp_recv_desc_t *rdesc = message;
ucs_status_t status;
ucp_request_t *req;
ucp_tag_t tag;
unsigned save_rreq = 1;
ucs_queue_iter_t iter = message;
ucp_recv_desc_t *rdesc;
ucs_status_ptr_t ret;
size_t buffer_size;
ucp_request_t *req;

UCP_THREAD_CS_ENTER_CONDITIONAL(&worker->mt_lock);
UCP_THREAD_CS_ENTER_CONDITIONAL(&worker->context->mt_lock);

ucs_trace_req("msg_recv_nb buffer %p count %zu message %p", buffer, count,
message);

req = ucp_tag_recv_request_get(worker, buffer, count, datatype);
if (req == NULL) {
ret = UCS_STATUS_PTR(UCS_ERR_NO_MEMORY);
goto out;
}

buffer_size = ucp_dt_length(datatype, count, buffer, &req->recv.state);

/* First, handle the first packet that was already matched */
if (rdesc->flags & UCP_RECV_DESC_FLAG_EAGER) {
tag = ((ucp_tag_hdr_t*)(rdesc + 1))->tag;
UCS_PROFILE_REQUEST_EVENT(req, "eager_match", 0);
status = ucp_eager_unexp_match(worker, rdesc, tag, rdesc->flags,
buffer, buffer_size, datatype,
&req->recv.state, &req->recv.info);
ucs_trace_req("release receive descriptor %p", rdesc);
ucp_tag_unexp_desc_release(rdesc);
} else if (rdesc->flags & UCP_RECV_DESC_FLAG_RNDV) {
req->recv.buffer = buffer;
req->recv.length = buffer_size;
req->recv.datatype = datatype;
req->recv.cb = cb;
ucp_rndv_matched(worker, req, (void*)(rdesc + 1));
ucp_tag_unexp_desc_release(rdesc);
status = UCS_INPROGRESS;
save_rreq = 0;
UCP_WORKER_STAT_RNDV(worker, UNEXP);
req = ucp_request_get(worker);
if (ucs_likely(req != NULL)) {
rdesc = ucs_queue_iter_elem(rdesc, iter, queue);
rdesc->flags |= UCP_RECV_DESC_FLAG_FIRST;
ucp_tag_recv_common(worker, buffer, count, datatype,
ucp_rdesc_get_tag(rdesc), UCP_TAG_MASK_FULL, req,
UCP_REQUEST_FLAG_CALLBACK, cb, iter, "msg_recv_nb");
ret = req + 1;
} else {
ucp_request_put(req);
ret = UCS_STATUS_PTR(UCS_ERR_INVALID_PARAM);
goto out;
}

/* Since the message contains only the first fragment, we might want
* to receive additional fragments.
*/
if (status == UCS_INPROGRESS) {
status = ucp_tag_search_unexp(worker, buffer, buffer_size, datatype, 0,
-1, req, &req->recv.info, cb, &save_rreq);
}

if (status != UCS_INPROGRESS) {
cb(req + 1, status, &req->recv.info);
ucp_tag_recv_request_completed(req, status, &req->recv.info,
"msg_recv_nb");
} else if (save_rreq) {
ucs_trace_req("msg_recv_nb returning inprogress request %p (%p)",
req, req + 1);
/* For eager - need to put the recv_req in expected since more packets
* will follow. For rndv - don't need to keep the recv_req in the expected queue
* as the match to the RTS already happened. */
req->recv.buffer = buffer;
req->recv.length = buffer_size;
req->recv.datatype = datatype;
req->recv.cb = cb;
req->recv.tag = req->recv.info.sender_tag;
req->recv.tag_mask = UCP_TAG_MASK_FULL;
ucp_tag_exp_add(&worker->context->tm, req);
ret = UCS_STATUS_PTR(UCS_ERR_NO_MEMORY);
}

ret = req + 1;
out:
UCP_THREAD_CS_EXIT_CONDITIONAL(&worker->context->mt_lock);
UCP_THREAD_CS_EXIT_CONDITIONAL(&worker->mt_lock);
return ret;
Expand Down
58 changes: 58 additions & 0 deletions test/gtest/ucp/test_ucp_tag_probe.cc
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,64 @@ UCS_TEST_P(test_ucp_tag_probe, send_rndv_msg_probe, "RNDV_THRESH=1048576") {
request_release(my_recv_req);
}

/*
 * Sends two messages with the same tag, probes both of them with the "remove"
 * argument set (first message first), and then receives them through their
 * message handles in the opposite order. Each handle is expected to deliver
 * the payload of the message it was probed for.
 * Runs with the RNDV_THRESH=inf configuration.
 */
UCS_TEST_P(test_ucp_tag_probe, send_2_msg_probe, "RNDV_THRESH=inf") {
    const ucp_datatype_t int_dt    = ucp_dt_make_contig(sizeof(int));
    const ucp_tag_t      msg_tag   = 0xaaa;
    const size_t         num_elems = 20000;

    /* Send buffers are filled with 1s and 2s so the payloads are
     * distinguishable; receive buffers are allocated up front. */
    std::vector<int> sdata1(num_elems, 1);
    std::vector<int> sdata2(num_elems, 2);
    std::vector<int> rdata1(num_elems);
    std::vector<int> rdata2(num_elems);

    /* Send order: message 1, then message 2 */
    send_b(&sdata1[0], num_elems, int_dt, msg_tag);
    send_b(&sdata2[0], num_elems, int_dt, msg_tag);

    /* Probe order: message 1, then message 2. Keep progressing until each
     * probe returns a non-NULL message handle. */
    ucp_tag_recv_info probe_info;
    ucp_tag_message_h first_msg;
    ucp_tag_message_h second_msg;

    for (;;) {
        progress();
        first_msg = ucp_tag_probe_nb(receiver().worker(), msg_tag, 0xffff, 1,
                                     &probe_info);
        if (first_msg != NULL) {
            break;
        }
    }

    for (;;) {
        progress();
        second_msg = ucp_tag_probe_nb(receiver().worker(), msg_tag, 0xffff, 1,
                                      &probe_info);
        if (second_msg != NULL) {
            break;
        }
    }

    /* Receive order is deliberately reversed: message 2 first ... */
    request *rreq2 = (request*)ucp_tag_msg_recv_nb(receiver().worker(),
                                                   &rdata2[0], num_elems,
                                                   int_dt, second_msg,
                                                   recv_callback);
    ASSERT_TRUE(!UCS_PTR_IS_ERR(rreq2));
    wait(rreq2);

    /* ... and message 1 only after the first receive has been waited for */
    request *rreq1 = (request*)ucp_tag_msg_recv_nb(receiver().worker(),
                                                   &rdata1[0], num_elems,
                                                   int_dt, first_msg,
                                                   recv_callback);
    ASSERT_TRUE(!UCS_PTR_IS_ERR(rreq1));
    wait(rreq1);

    /* Data must follow the probe order rather than the receive order */
    EXPECT_TRUE(rreq1->completed);
    EXPECT_TRUE(rreq2->completed);
    EXPECT_EQ(UCS_OK, rreq1->status);
    EXPECT_EQ(UCS_OK, rreq2->status);
    EXPECT_EQ(sdata1, rdata1);
    EXPECT_EQ(sdata2, rdata2);

    request_release(rreq1);
    request_release(rreq2);
}

UCS_TEST_P(test_ucp_tag_probe, limited_probe_size) {
static const int COUNT = 1000;
std::string sendbuf, recvbuf;
Expand Down