Skip to content

Commit

Permalink
UCP/CORE/RNDV/GTEST: Handle status from AM/TAG RNDV RTS/data correctly
Browse files Browse the repository at this point in the history
  • Loading branch information
dmitrygx committed Jan 25, 2021
1 parent 49fbd8e commit 1407016
Show file tree
Hide file tree
Showing 9 changed files with 160 additions and 89 deletions.
6 changes: 4 additions & 2 deletions src/ucp/core/ucp_am.c
Original file line number Diff line number Diff line change
Expand Up @@ -738,14 +738,16 @@ UCS_PROFILE_FUNC(ucs_status_t, ucp_proto_progress_am_rndv_rts, (self),
{
ucp_request_t *sreq = ucs_container_of(self, ucp_request_t, send.uct);
size_t max_rts_size;
ucs_status_t status;

/* RTS consists of: AM RTS header, packed rkeys and user header */
max_rts_size = sizeof(ucp_am_rndv_rts_hdr_t) +
ucp_ep_config(sreq->send.ep)->rndv.rkey_size +
sreq->send.msg_proto.am.header_length;

return ucp_do_am_single(self, UCP_AM_ID_RNDV_RTS, ucp_am_rndv_rts_pack,
max_rts_size);
status = ucp_do_am_single(self, UCP_AM_ID_RNDV_RTS, ucp_am_rndv_rts_pack,
max_rts_size);
return ucp_rndv_rts_handle_status_from_pending(sreq, status);
}

static ucs_status_t ucp_am_send_start_rndv(ucp_request_t *sreq)
Expand Down
16 changes: 10 additions & 6 deletions src/ucp/proto/proto_am.c
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,17 @@ ucp_do_am_single(uct_pending_req_t *self, uint8_t am_id,
ucs_status_t ucp_proto_progress_am_single(uct_pending_req_t *self)
{
ucp_request_t *req = ucs_container_of(self, ucp_request_t, send.uct);
ucs_status_t status = ucp_do_am_single(self, req->send.proto.am_id,
ucp_proto_pack,
ucp_proto_max_packed_size());
if (status == UCS_OK) {
req->send.proto.comp_cb(req);
ucs_status_t status;

status = ucp_do_am_single(self, req->send.proto.am_id, ucp_proto_pack,
ucp_proto_max_packed_size());
if (ucs_unlikely(status == UCS_ERR_NO_RESOURCE)) {
return UCS_ERR_NO_RESOURCE;
}
return status;

/* TODO: handle failure */
req->send.proto.comp_cb(req);
return UCS_OK;
}

void ucp_proto_am_zcopy_req_complete(ucp_request_t *req, ucs_status_t status)
Expand Down
34 changes: 19 additions & 15 deletions src/ucp/proto/proto_am.inl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,24 @@

#define UCP_STATUS_PENDING_SWITCH (UCS_ERR_LAST - 1)

/*
 * Early-exit handling of the status produced by an AM bcopy send attempt.
 *
 * NOTE: this macro contains `return` statements, so it may leave the
 * *enclosing function* - it must only be used inside a function returning
 * ucs_status_t.
 *
 * _multi  - non-zero when the send was done by the multi-fragment path
 * _status - status of the send attempt; it is expanded several times, so the
 *           argument must be free of side effects
 *
 * Behavior:
 *  - _multi != 0 and _status == UCS_INPROGRESS            -> return UCS_INPROGRESS
 *  - _multi != 0 and _status == UCP_STATUS_PENDING_SWITCH -> return UCS_OK
 *  - _multi == 0: asserts that _status is not UCS_INPROGRESS
 *  - _status == UCS_ERR_NO_RESOURCE                       -> return UCS_ERR_NO_RESOURCE
 *  - any other status falls through to the code following the macro
 *
 * Every use of _status is parenthesized so that an expression argument
 * (e.g. a conditional expression) is compared as a whole.
 */
#define UCP_AM_BCOPY_HANDLE_STATUS(_multi, _status) \
    do { \
        if (_multi) { \
            if ((_status) == UCS_INPROGRESS) { \
                return UCS_INPROGRESS; \
            } else if (ucs_unlikely((_status) == UCP_STATUS_PENDING_SWITCH)) { \
                return UCS_OK; \
            } \
        } else { \
            ucs_assert((_status) != UCS_INPROGRESS); \
        } \
        \
        if (ucs_unlikely((_status) == UCS_ERR_NO_RESOURCE)) { \
            return UCS_ERR_NO_RESOURCE; \
        } \
    } while (0)


typedef void (*ucp_req_complete_func_t)(ucp_request_t *req, ucs_status_t status);


Expand Down Expand Up @@ -552,21 +570,7 @@ ucp_am_bcopy_handle_status_from_pending(uct_pending_req_t *self, int multi,
{
ucp_request_t *req = ucs_container_of(self, ucp_request_t, send.uct);

if (multi) {
if (status == UCS_INPROGRESS) {
return UCS_INPROGRESS;
}

if (ucs_unlikely(status == UCP_STATUS_PENDING_SWITCH)) {
return UCS_OK;
}
} else {
ucs_assert(status != UCS_INPROGRESS);
}

if (ucs_unlikely(status == UCS_ERR_NO_RESOURCE)) {
return UCS_ERR_NO_RESOURCE;
}
UCP_AM_BCOPY_HANDLE_STATUS(multi, status);

ucp_request_send_generic_dt_finish(req);
if (tag_sync) {
Expand Down
48 changes: 35 additions & 13 deletions src/ucp/rndv/rndv.c
Original file line number Diff line number Diff line change
Expand Up @@ -176,14 +176,15 @@ UCS_PROFILE_FUNC(ucs_status_t, ucp_proto_progress_rndv_rtr, (self),

/* send the RTR. the pack_cb will pack all the necessary fields in the RTR */
packed_rkey_size = ucp_ep_config(rndv_req->send.ep)->rndv.rkey_size;
status = ucp_do_am_single(self, UCP_AM_ID_RNDV_RTR, ucp_rndv_rtr_pack,
sizeof(ucp_rndv_rtr_hdr_t) + packed_rkey_size);
if (status == UCS_OK) {
/* release rndv request */
ucp_request_put(rndv_req);
status = ucp_do_am_single(self, UCP_AM_ID_RNDV_RTR, ucp_rndv_rtr_pack,
sizeof(ucp_rndv_rtr_hdr_t) + packed_rkey_size);
if (ucs_unlikely(status == UCS_ERR_NO_RESOURCE)) {
return UCS_ERR_NO_RESOURCE;
}

return status;
/* release rndv request */
ucp_request_put(rndv_req);
return UCS_OK;
}

ucs_status_t ucp_rndv_reg_send_buffer(ucp_request_t *sreq)
Expand Down Expand Up @@ -1379,6 +1380,26 @@ UCS_PROFILE_FUNC(ucs_status_t, ucp_rndv_ats_handler,
return UCS_OK;
}

/*
 * Post-process the status of an RTS send attempt.
 *
 * sreq   - send request whose RTS was just attempted
 * status - status returned by the send attempt
 *
 * Returns UCS_ERR_NO_RESOURCE unchanged when the send attempt reported it.
 * For every other status UCS_OK is returned; if the status was an error, the
 * request id is removed with ucp_worker_del_request_id() and the request is
 * completed with that error through ucp_rndv_complete_send() beforehand.
 */
ucs_status_t ucp_rndv_rts_handle_status_from_pending(ucp_request_t *sreq,
                                                     ucs_status_t status)
{
    /* the RTS is expected to never go through the AM Bcopy multi path, so
     * the statuses specific to that path must not be seen here */
    ucs_assert((status != UCP_STATUS_PENDING_SWITCH) &&
               (status != UCS_INPROGRESS));

    /* common case: nothing to clean up */
    if (!ucs_unlikely(status != UCS_OK)) {
        return UCS_OK;
    }

    /* out of resources: hand the status back untouched */
    if (status == UCS_ERR_NO_RESOURCE) {
        return UCS_ERR_NO_RESOURCE;
    }

    /* any other failure: drop the request id and complete the send request
     * with the error status */
    ucp_worker_del_request_id(sreq->send.ep->worker, sreq,
                              sreq->send.msg_proto.sreq_id);
    ucp_rndv_complete_send(sreq, status);
    return UCS_OK;
}

static size_t ucp_rndv_pack_data(void *dest, void *arg)
{
ucp_rndv_data_hdr_t *hdr = dest;
Expand All @@ -1401,9 +1422,11 @@ UCS_PROFILE_FUNC(ucs_status_t, ucp_rndv_progress_am_bcopy, (self),
{
ucp_request_t *sreq = ucs_container_of(self, ucp_request_t, send.uct);
ucp_ep_t *ep = sreq->send.ep;
int single = (sreq->send.length + sizeof(ucp_rndv_data_hdr_t)) <=
ucp_ep_config(ep)->am.max_bcopy;
ucs_status_t status;

if (sreq->send.length <= ucp_ep_config(ep)->am.max_bcopy - sizeof(ucp_rndv_data_hdr_t)) {
if (single) {
/* send a single bcopy message */
status = ucp_do_am_bcopy_single(self, UCP_AM_ID_RNDV_DATA,
ucp_rndv_pack_data);
Expand All @@ -1413,13 +1436,12 @@ UCS_PROFILE_FUNC(ucs_status_t, ucp_rndv_progress_am_bcopy, (self),
ucp_rndv_pack_data,
ucp_rndv_pack_data, 1);
}
if (status == UCS_OK) {
ucp_rndv_complete_send(sreq, UCS_OK);
} else if (status == UCP_STATUS_PENDING_SWITCH) {
status = UCS_OK;
}

return status;
UCP_AM_BCOPY_HANDLE_STATUS(!single, status);

ucp_rndv_complete_send(sreq, status);

return UCS_OK;
}

UCS_PROFILE_FUNC(ucs_status_t, ucp_rndv_progress_rma_put_zcopy, (self),
Expand Down
3 changes: 3 additions & 0 deletions src/ucp/rndv/rndv.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,7 @@ void ucp_rndv_receive(ucp_worker_h worker, ucp_request_t *rreq,
void ucp_rndv_req_send_ats(ucp_request_t *rndv_req, ucp_request_t *rreq,
ucs_ptr_map_key_t remote_req_id, ucs_status_t status);

ucs_status_t ucp_rndv_rts_handle_status_from_pending(ucp_request_t *sreq,
ucs_status_t status);

#endif
10 changes: 7 additions & 3 deletions src/ucp/tag/tag_rndv.c
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,15 @@ UCS_PROFILE_FUNC(ucs_status_t, ucp_proto_progress_rndv_rts, (self),
{
ucp_request_t *sreq = ucs_container_of(self, ucp_request_t, send.uct);
size_t packed_rkey_size;
ucs_status_t status;

/* send the RTS. the pack_cb will pack all the necessary fields in the RTS */
/* send the RTS. the pack_cb packs all the necessary fields in the RTS */
packed_rkey_size = ucp_ep_config(sreq->send.ep)->rndv.rkey_size;
return ucp_do_am_single(self, UCP_AM_ID_RNDV_RTS, ucp_tag_rndv_rts_pack,
sizeof(ucp_tag_rndv_rts_hdr_t) + packed_rkey_size);

status = ucp_do_am_single(self, UCP_AM_ID_RNDV_RTS, ucp_tag_rndv_rts_pack,
sizeof(ucp_tag_rndv_rts_hdr_t) +
packed_rkey_size);
return ucp_rndv_rts_handle_status_from_pending(sreq, status);
}

ucs_status_t ucp_tag_send_start_rndv(ucp_request_t *sreq)
Expand Down
81 changes: 46 additions & 35 deletions test/gtest/ucp/test_ucp_sockaddr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1100,11 +1100,8 @@ class test_ucp_sockaddr_protocols : public test_ucp_sockaddr {

sreq_mem_dereg(sreq);

if (recv_stop) {
sender().disconnect_nb(0, 0, UCP_EP_CLOSE_MODE_FORCE);
} else {
if (send_stop) {
disconnect(*this, sender());
receiver().disconnect_nb(0, 0, UCP_EP_CLOSE_MODE_FORCE);
}
}

Expand Down Expand Up @@ -1145,6 +1142,16 @@ class test_ucp_sockaddr_protocols : public test_ucp_sockaddr {
ucp_request_release(sreq);
}

/*
 * Post one more tag send (tag 0) of send_buf on the endpoint of entity `e`,
 * wait for that request with request_wait(), and then close the entity's
 * endpoint(s) with UCP_EP_CLOSE_MODE_FORCE.
 */
void extra_send_before_disconnect(entity &e, const std::string &send_buf,
                                  const ucp_request_param_t &send_param)
{
    const void *payload = send_buf.data();
    const size_t length = send_buf.size();

    void *req = ucp_tag_send_nbx(e.ep(), payload, length, 0, &send_param);
    request_wait(req);

    e.disconnect_nb(0, 0, UCP_EP_CLOSE_MODE_FORCE);
}

void test_tag_send_recv(size_t size, bool is_exp, bool is_sync = false,
bool send_stop = false, bool recv_stop = false)
{
Expand Down Expand Up @@ -1205,6 +1212,14 @@ class test_ucp_sockaddr_protocols : public test_ucp_sockaddr {

if (!err_handling_test) {
compare_buffers(send_buf, recv_buf);
} else {
wait_for_flag(&m_err_count);

if (send_stop == false) {
extra_send_before_disconnect(sender(), send_buf, send_param);
} else if (recv_stop == false) {
extra_send_before_disconnect(receiver(), send_buf, send_param);
}
}
}
}
Expand Down Expand Up @@ -1550,17 +1565,16 @@ UCS_TEST_P(test_ucp_sockaddr_protocols, am_zcopy_64k,
}



/* For DC case, allow fallback to UD if DC is not supported */
#define UCP_INSTANTIATE_CM_TEST_CASE(_test_case) \
UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, dcudx, "dc_x,ud") \
UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, ud, "ud_v") \
UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, udx, "ud_x") \
UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, rc, "rc_v") \
UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, rcx, "rc_x") \
UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, ib, "ib") \
UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, tcp, "tcp") \
UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, all, "all")
UCP_INSTANTIATE_TEST_CASE_TLS_GPU_AWARE(_test_case, dcudx, "dc_x,ud") \
UCP_INSTANTIATE_TEST_CASE_TLS_GPU_AWARE(_test_case, ud, "ud_v") \
UCP_INSTANTIATE_TEST_CASE_TLS_GPU_AWARE(_test_case, udx, "ud_x") \
UCP_INSTANTIATE_TEST_CASE_TLS_GPU_AWARE(_test_case, rc, "rc_v") \
UCP_INSTANTIATE_TEST_CASE_TLS_GPU_AWARE(_test_case, rcx, "rc_x") \
UCP_INSTANTIATE_TEST_CASE_TLS_GPU_AWARE(_test_case, ib, "ib") \
UCP_INSTANTIATE_TEST_CASE_TLS_GPU_AWARE(_test_case, tcp, "tcp") \
UCP_INSTANTIATE_TEST_CASE_TLS_GPU_AWARE(_test_case, all, "all")

UCP_INSTANTIATE_CM_TEST_CASE(test_ucp_sockaddr_protocols)

Expand All @@ -1582,10 +1596,6 @@ class test_ucp_sockaddr_protocols_err : public test_ucp_sockaddr_protocols {
set_tl_timeouts(m_env);
}

/* Delegates directly to test_ucp_sockaddr_protocols::init(); this class adds
 * no work of its own here. */
void init() {
test_ucp_sockaddr_protocols::init();
}

void test_tag_send_recv(size_t size, bool is_exp,
bool is_sync = false) {
/* warmup */
Expand All @@ -1598,28 +1608,17 @@ class test_ucp_sockaddr_protocols_err : public test_ucp_sockaddr_protocols {
variants & RECV_STOP);
}

/* Delegates directly to test_ucp_sockaddr_protocols::cleanup(); this class
 * adds no work of its own here. */
void cleanup() {
test_ucp_sockaddr_protocols::cleanup();
}

static void err_handler_cb(void *arg, ucp_ep_h ep, ucs_status_t status) {
test_ucp_sockaddr::err_handler_cb(arg, ep, status);

test_ucp_sockaddr_protocols *test =
static_cast<test_ucp_sockaddr_protocols*>(arg);
if (test->sender().ep() == ep) {
test->sender().disconnect_nb(0, 0, UCP_EP_CLOSE_MODE_FORCE);
} else {
ASSERT_EQ(test->receiver().ep(), ep);
test->receiver().disconnect_nb(0, 0, UCP_EP_CLOSE_MODE_FORCE);
}
}

protected:
ucs::ptr_vector<ucs::scoped_setenv> m_env;
};


/* Calls test_tag_send_recv() with size = 32 bytes, is_exp = false and
 * is_sync = false. The strings "ZCOPY_THRESH=inf" and "RNDV_THRESH=inf" are
 * passed to UCS_TEST_P - presumably applied as configuration for this test
 * case (confirm against the UCS_TEST_P definition). */
UCS_TEST_P(test_ucp_sockaddr_protocols_err, tag_eager_32_unexp,
"ZCOPY_THRESH=inf", "RNDV_THRESH=inf")
{
test_tag_send_recv(32, false, false);
}

UCS_TEST_P(test_ucp_sockaddr_protocols_err, tag_zcopy_4k_unexp,
"ZCOPY_THRESH=2k", "RNDV_THRESH=inf")
{
Expand All @@ -1633,6 +1632,12 @@ UCS_TEST_P(test_ucp_sockaddr_protocols_err, tag_zcopy_64k_unexp,
test_tag_send_recv(64 * UCS_KBYTE, false, false);
}

/* Calls test_tag_send_recv() with size = 32 bytes, is_exp = false and
 * is_sync = true. The strings "ZCOPY_THRESH=inf" and "RNDV_THRESH=inf" are
 * passed to UCS_TEST_P - presumably applied as configuration for this test
 * case (confirm against the UCS_TEST_P definition). */
UCS_TEST_P(test_ucp_sockaddr_protocols_err, tag_eager_32_unexp_sync,
"ZCOPY_THRESH=inf", "RNDV_THRESH=inf")
{
test_tag_send_recv(32, false, true);
}

UCS_TEST_P(test_ucp_sockaddr_protocols_err, tag_zcopy_4k_unexp_sync,
"ZCOPY_THRESH=2k", "RNDV_THRESH=inf")
{
Expand All @@ -1646,7 +1651,13 @@ UCS_TEST_P(test_ucp_sockaddr_protocols_err, tag_zcopy_64k_unexp_sync,
}

UCS_TEST_P(test_ucp_sockaddr_protocols_err, tag_rndv_unexp,
"RNDV_THRESH=0")
"RNDV_THRESH=0", "RNDV_SCHEME=auto")
{
test_tag_send_recv(64 * UCS_KBYTE, false, false);
}

/* Calls test_tag_send_recv() with size = 64 * UCS_KBYTE, is_exp = false and
 * is_sync = false. The strings "RNDV_THRESH=0" and "RNDV_SCHEME=get_zcopy"
 * are passed to UCS_TEST_P - presumably applied as configuration for this
 * test case (confirm against the UCS_TEST_P definition). */
UCS_TEST_P(test_ucp_sockaddr_protocols_err, tag_rndv_unexp_get_scheme,
"RNDV_THRESH=0", "RNDV_SCHEME=get_zcopy")
{
test_tag_send_recv(64 * UCS_KBYTE, false, false);
}
Expand Down
8 changes: 4 additions & 4 deletions test/gtest/ucp/test_ucp_tag_offload.cc
Original file line number Diff line number Diff line change
Expand Up @@ -612,8 +612,8 @@ UCS_TEST_P(test_ucp_tag_offload_gpu, rx_scatter_to_cqe, "TM_THRESH=1")
wait_and_validate(sreq);
}

UCP_INSTANTIATE_TEST_CASE_TLS(test_ucp_tag_offload_gpu, rc_dc_gpu,
"dc_x,rc_x," UCP_TEST_GPU_COPY_TLS)
UCP_INSTANTIATE_TEST_CASE_TLS_GPU_AWARE(test_ucp_tag_offload_gpu, rc_dc_gpu,
"dc_x,rc_x")

class test_ucp_tag_offload_status : public test_ucp_tag {
public:
Expand Down Expand Up @@ -864,7 +864,7 @@ UCS_TEST_P(test_ucp_tag_offload_stats_gpu, block_gpu_no_gpu_direct,
req_cancel(receiver(), rreq);
}

UCP_INSTANTIATE_TEST_CASE_TLS(test_ucp_tag_offload_stats_gpu, rc_dc_gpu,
"dc_x,rc_x," UCP_TEST_GPU_COPY_TLS)
UCP_INSTANTIATE_TEST_CASE_TLS_GPU_AWARE(test_ucp_tag_offload_stats_gpu,
rc_dc_gpu, "dc_x,rc_x")

#endif
Loading

0 comments on commit 1407016

Please sign in to comment.