diff --git a/src/ucp/core/ucp_am.c b/src/ucp/core/ucp_am.c index 326b766e74ad..ef47482ccc1f 100644 --- a/src/ucp/core/ucp_am.c +++ b/src/ucp/core/ucp_am.c @@ -138,14 +138,17 @@ static void ucp_am_rndv_send_ats(ucp_worker_h worker, ucs_status_t status) { ucp_request_t *req; + ucp_ep_h ep; + ep = UCP_WORKER_GET_EP_BY_ID(worker, rts->super.sreq.ep_id, 1, return, + "AM RNDV ATS"); req = ucp_request_get(worker); if (ucs_unlikely(req == NULL)) { ucs_error("failed to allocate request for AM RNDV ATS"); return; } - req->send.ep = ucp_worker_get_ep_by_id(worker, rts->super.sreq.ep_id); + req->send.ep = ep; req->flags = 0; ucp_rndv_req_send_ats(req, NULL, rts->super.sreq.req_id, status); @@ -1176,8 +1179,8 @@ UCS_PROFILE_FUNC(ucs_status_t, ucp_am_handler_reply, ucp_worker_h worker = (ucp_worker_h)am_arg; ucp_ep_h reply_ep; - reply_ep = UCP_WORKER_GET_EP_BY_ID(worker, hdr->ep_id, "AM (reply proto)", - return UCS_OK); + reply_ep = UCP_WORKER_GET_EP_BY_ID(worker, hdr->ep_id, 1, return UCS_OK, + "AM (reply proto)"); return ucp_am_handler_common(worker, &hdr->super, sizeof(*hdr), am_length, reply_ep, am_flags, @@ -1234,6 +1237,7 @@ ucp_am_hdr_reply_ep(ucp_worker_h worker, uint16_t flags, ucp_ep_h ep, *reply_ep_p = NULL; +out: return 0ul; } @@ -1305,8 +1309,8 @@ UCS_PROFILE_FUNC(ucs_status_t, ucp_am_long_first_handler, size_t remaining; uint64_t recv_flags; - ep = UCP_WORKER_GET_EP_BY_ID(worker, first_hdr->super.ep_id, - "AM first fragment", return UCS_OK); + ep = UCP_WORKER_GET_EP_BY_ID(worker, first_hdr->super.ep_id, 1, + return UCS_OK, "AM first fragment"); remaining = first_hdr->total_size - (am_length - sizeof(*first_hdr)); if (ucs_unlikely(remaining == 0)) { @@ -1378,8 +1382,8 @@ UCS_PROFILE_FUNC(ucs_status_t, ucp_am_long_middle_handler, ucp_ep_h ep; ucs_status_t status; - ep = UCP_WORKER_GET_EP_BY_ID(worker, mid_hdr->ep_id, - "AM middle fragment", return UCS_OK); + ep = UCP_WORKER_GET_EP_BY_ID(worker, mid_hdr->ep_id, 1, + return UCS_OK, "AM middle fragment"); ep_ext = ucp_ep_ext_proto(ep); first_rdesc = ucp_am_find_first_rdesc(worker, ep_ext, msg_id); if (first_rdesc != NULL) { @@ -1421,10 +1425,11 @@ ucs_status_t ucp_am_rndv_process_rts(void *arg, void *data, size_t length, ucs_status_t status, desc_status; void *hdr; - ep = UCP_WORKER_GET_EP_BY_ID(worker, rts->super.sreq.ep_id, "AM RTS", + ep = UCP_WORKER_GET_EP_BY_ID(worker, rts->super.sreq.ep_id, 1, { status = UCS_ERR_ENDPOINT_TIMEOUT; goto out_send_ats; - }); + }, + "AM RTS"); if (ucs_unlikely(!ucp_am_recv_check_id(worker, am_id))) { status = UCS_ERR_INVALID_PARAM; diff --git a/src/ucp/core/ucp_request.inl b/src/ucp/core/ucp_request.inl index 386dc68fd665..340077914b71 100644 --- a/src/ucp/core/ucp_request.inl +++ b/src/ucp/core/ucp_request.inl @@ -794,4 +794,22 @@ ucp_request_invoke_uct_completion(ucp_request_t *req, ucs_status_t status) ucp_invoke_uct_completion(&req->send.state.uct_comp, status); } +static UCS_F_ALWAYS_INLINE void +ucp_request_complete_recv_rndv(ucp_request_t *req, ucs_status_t status) +{ + if (req->flags & UCP_REQUEST_FLAG_RECV_AM) { + ucp_request_complete_am_recv(req, status); + } else { + ucs_assert(req->flags & UCP_REQUEST_FLAG_RECV_TAG); + ucp_request_complete_tag_recv(req, status); + } +} + +static UCS_F_ALWAYS_INLINE void +ucp_request_complete_recv_rndv_common(ucp_request_t *rreq, ucs_status_t status) +{ + ucp_request_recv_buffer_dereg(rreq); + ucp_request_complete_recv_rndv(rreq, status); +} + #endif diff --git a/src/ucp/core/ucp_worker.inl b/src/ucp/core/ucp_worker.inl index 8cfe7afcf947..97d4b0a83a7c 100644 --- a/src/ucp/core/ucp_worker.inl +++ b/src/ucp/core/ucp_worker.inl @@ -78,11 +78,7 @@ ucp_worker_get_request_id(ucp_worker_h worker, ucp_request_t *req, int indirect) static UCS_F_ALWAYS_INLINE ucp_request_t* ucp_worker_get_request_by_id(ucp_worker_h worker, ucs_ptr_map_key_t id) { - ucp_request_t* request; - - request = (ucp_request_t*)ucs_ptr_map_get(&worker->ptr_map, id); - ucs_assert(request != NULL); - return request; + return (ucp_request_t*)ucs_ptr_map_get(&worker->ptr_map, id); } static UCS_F_ALWAYS_INLINE void @@ -247,14 +243,16 @@ ucp_worker_get_rkey_config(ucp_worker_h worker, const ucp_rkey_config_key_t *key return ucp_worker_add_rkey_config(worker, key, cfg_index_p); } -#define UCP_WORKER_GET_EP_BY_ID(_worker, _ep_id, _str, _action) \ +#define UCP_WORKER_GET_EP_BY_ID(_worker, _ep_id, _check_closed_ep, \ + _action, _fmt_str, ...) \ ({ \ ucp_ep_h _ep = ucp_worker_get_ep_by_id(_worker, _ep_id); \ if (ucs_unlikely((_ep == NULL) || \ - ((_ep)->flags & (UCP_EP_FLAG_CLOSED | \ - UCP_EP_FLAG_FAILED)))) { \ - ucs_trace_data("worker %p: drop %s on closed/failed ep %p", \ - _worker, _str, _ep); \ + (_check_closed_ep && \ + ((_ep)->flags & (UCP_EP_FLAG_CLOSED | \ + UCP_EP_FLAG_FAILED))))) { \ + ucs_diag("worker %p: ep id 0x%" PRIx64 " and closed/failed ep %p," \ + " drop " _fmt_str, _worker, _ep_id, _ep, ##__VA_ARGS__); \ _action; \ } \ _ep; \ diff --git a/src/ucp/proto/proto_am.inl b/src/ucp/proto/proto_am.inl index 6440c7740030..a9014aa134ef 100644 --- a/src/ucp/proto/proto_am.inl +++ b/src/ucp/proto/proto_am.inl @@ -521,14 +521,18 @@ static UCS_F_ALWAYS_INLINE ucp_request_t* ucp_proto_ssend_ack_request_alloc(ucp_worker_h worker, ucs_ptr_map_key_t ep_id) { ucp_request_t *req; + ucp_ep_h ep; + ep = UCP_WORKER_GET_EP_BY_ID(worker, ep_id, 1, return NULL, + "ACK for sync-send"); req = ucp_request_get(worker); if (req == NULL) { + ucs_error("failed to allocate UCP request"); return NULL; } req->flags = 0; - req->send.ep = ucp_worker_get_ep_by_id(worker, ep_id); + req->send.ep = ep; req->send.uct.func = ucp_proto_progress_am_single; req->send.proto.comp_cb = ucp_request_put; req->send.proto.status = UCS_OK; diff --git a/src/ucp/rma/amo_sw.c b/src/ucp/rma/amo_sw.c index 12f01dd5e646..ad059992fb20 100644 --- a/src/ucp/rma/amo_sw.c +++ b/src/ucp/rma/amo_sw.c @@ -191,11 +191,12 @@ UCS_PROFILE_FUNC(ucs_status_t, ucp_atomic_req_handler, (arg, data, length, am_fl { ucp_atomic_req_hdr_t *atomicreqh = data; ucp_worker_h worker = arg; - ucp_ep_h ep = ucp_worker_get_ep_by_id(worker, - atomicreqh->req.ep_id); ucp_rsc_index_t amo_rsc_idx = ucs_ffs64_safe(worker->atomic_tls); ucp_request_t *req; + ucp_ep_h ep; + ep = UCP_WORKER_GET_EP_BY_ID(worker, atomicreqh->req.ep_id, 1, + return UCS_OK, "SW AMO request"); if (ucs_unlikely((amo_rsc_idx != UCP_MAX_RESOURCES) && (ucp_worker_iface_get_attr(worker, amo_rsc_idx)->cap.flags & diff --git a/src/ucp/rma/rma_sw.c b/src/ucp/rma/rma_sw.c index 1696d66b2596..996d49501534 100644 --- a/src/ucp/rma/rma_sw.c +++ b/src/ucp/rma/rma_sw.c @@ -142,10 +142,13 @@ UCS_PROFILE_FUNC(ucs_status_t, ucp_put_handler, (arg, data, length, am_flags), { ucp_put_hdr_t *puth = data; ucp_worker_h worker = arg; + ucp_ep_h ep; + ep = UCP_WORKER_GET_EP_BY_ID(worker, puth->ep_id, 1, return UCS_OK, + "SW PUT request"); ucp_dt_contig_unpack(worker, (void*)puth->address, puth + 1, length - sizeof(*puth), puth->mem_type); - ucp_rma_sw_send_cmpl(ucp_worker_get_ep_by_id(worker, puth->ep_id)); + ucp_rma_sw_send_cmpl(ep); return UCS_OK; } @@ -154,8 +157,10 @@ UCS_PROFILE_FUNC(ucs_status_t, ucp_rma_cmpl_handler, (arg, data, length, am_flag { ucp_cmpl_hdr_t *putackh = data; ucp_worker_h worker = arg; - ucp_ep_h ep = ucp_worker_get_ep_by_id(worker, putackh->ep_id); + ucp_ep_h ep; + ep = UCP_WORKER_GET_EP_BY_ID(worker, putackh->ep_id, 1, return UCS_OK, + "SW RMA completion"); ucp_ep_rma_remote_request_completed(ep); return UCS_OK; } @@ -208,10 +213,11 @@ UCS_PROFILE_FUNC(ucs_status_t, ucp_get_req_handler, (arg, data, length, am_flags { ucp_get_req_hdr_t *getreqh = data; ucp_worker_h worker = arg; - ucp_ep_h ep = ucp_worker_get_ep_by_id(worker, - getreqh->req.ep_id); + ucp_ep_h ep; ucp_request_t *req; + ep = UCP_WORKER_GET_EP_BY_ID(worker, getreqh->req.ep_id, 1, return UCS_OK, + "SW GET request"); req = ucp_request_get(worker); if (req == NULL) { ucs_error("failed to allocate get reply"); @@ -239,10 +245,17 @@ UCS_PROFILE_FUNC(ucs_status_t, ucp_get_rep_handler, (arg, data, length, am_flags ucp_worker_h worker = arg; ucp_rma_rep_hdr_t *getreph = data; size_t frag_length = length - sizeof(*getreph); - ucp_request_t *req = ucp_worker_get_request_by_id(worker, - getreph->req_id); - ucp_ep_h ep = req->send.ep; + ucp_request_t *req; + ucp_ep_h ep; + + req = ucp_worker_get_request_by_id(worker, getreph->req_id); + if (ucs_unlikely(req == NULL)) { + ucs_diag("unable to get request from GET reply data %p for non-existing" + " ep_id 0x%"PRIx64, getreph, getreph->req_id); + return UCS_OK; + } + ep = req->send.ep; if (ep->worker->context->config.ext.proto_enable) { // TODO use dt_iter.inl unpack ucp_dt_contig_unpack(ep->worker, diff --git a/src/ucp/rndv/rndv.c b/src/ucp/rndv/rndv.c index 70eedda48603..985ad19030d7 100644 --- a/src/ucp/rndv/rndv.c +++ b/src/ucp/rndv/rndv.c @@ -734,6 +734,12 @@ UCS_PROFILE_FUNC_VOID(ucp_rndv_recv_frag_put_completion, (self), /* rndv_req is NULL in case of put protocol */ if (!is_put_proto) { rndv_req = ucp_worker_get_request_by_id(worker, rreq_remote_id); + if (ucs_unlikely(rndv_req == NULL)) { + ucs_diag("unable to get request from fragmented PUT request %p for" + " non-existing ep_id 0x%"PRIx64, freq, rreq_remote_id); + return; + } + /* pipeline recv get protocol */ rndv_req->send.state.dt.offset += freq->send.length; @@ -1217,17 +1223,19 @@ UCS_PROFILE_FUNC_VOID(ucp_rndv_receive, (worker, rreq, rndv_rts_hdr, rkey_buf), UCS_PROFILE_REQUEST_EVENT(rreq, "rndv_receive", 0); + ep = UCP_WORKER_GET_EP_BY_ID(worker, rndv_rts_hdr->sreq.ep_id, 1, goto err, + "RNDV rts"); + /* the internal send request allocated on receiver side (to perform a "get" * operation, send "ATS" and "RTR") */ rndv_req = ucp_request_get(worker); if (rndv_req == NULL) { ucs_error("failed to allocate rendezvous reply"); - goto out; + goto err; } - rndv_req->send.ep = ucp_worker_get_ep_by_id(worker, - rndv_rts_hdr->sreq.ep_id); rndv_req->flags = 0; + rndv_req->send.ep = ep; rndv_req->send.mdesc = NULL; rndv_req->send.pending_lane = UCP_NULL_LANE; is_get_zcopy_failed = 0; @@ -1248,7 +1256,6 @@ UCS_PROFILE_FUNC_VOID(ucp_rndv_receive, (worker, rreq, rndv_rts_hdr, rkey_buf), } /* if the receive side is not connected yet then the RTS was received on a stub ep */ - ep = rndv_req->send.ep; ep_config = ucp_ep_config(ep); rndv_mode = worker->context->config.ext.rndv_mode; @@ -1320,6 +1327,11 @@ UCS_PROFILE_FUNC_VOID(ucp_rndv_receive, (worker, rreq, rndv_rts_hdr, rkey_buf), out: UCS_ASYNC_UNBLOCK(&worker->async); + return; + +err: + ucp_request_complete_recv_rndv_common(rreq, UCS_ERR_CONNECTION_RESET); + goto out; } UCS_PROFILE_FUNC(ucs_status_t, ucp_rndv_rts_handler, @@ -1678,6 +1690,12 @@ UCS_PROFILE_FUNC(ucs_status_t, ucp_rndv_atp_handler, ucp_request_t *req = ucp_worker_get_request_by_id(arg, rep_hdr->req_id); + if (ucs_unlikely(req == NULL)) { + ucs_diag("unable to get request from RNDV ATP %p for non-existing" + " ep_id 0x%"PRIx64, rep_hdr, rep_hdr->req_id); + return UCS_OK; + } + if (req->flags & UCP_REQUEST_FLAG_RNDV_FRAG) { /* received ATP for frag RTR request */ ucs_assert(req->super_req != NULL); @@ -1699,14 +1717,24 @@ UCS_PROFILE_FUNC(ucs_status_t, ucp_rndv_rtr_handler, void *arg, void *data, size_t length, unsigned flags) { ucp_rndv_rtr_hdr_t *rndv_rtr_hdr = data; - ucp_request_t *sreq = ucp_worker_get_request_by_id(arg, - rndv_rtr_hdr->sreq_id); - ucp_ep_h ep = sreq->send.ep; - ucp_ep_config_t *ep_config = ucp_ep_config(ep); - ucp_context_h context = ep->worker->context; + ucp_request_t *sreq; + ucp_ep_h ep; + ucp_ep_config_t *ep_config; + ucp_context_h context; ucs_status_t status; int is_pipeline_rndv; + sreq = ucp_worker_get_request_by_id(arg, rndv_rtr_hdr->sreq_id); + if (ucs_unlikely(sreq == NULL)) { + ucs_diag("unable to get request from RNDV RTR %p for non-existing" + " ep_id 0x%"PRIx64, rndv_rtr_hdr, rndv_rtr_hdr->sreq_id); + return UCS_OK; + } + + ep = sreq->send.ep; + ep_config = ucp_ep_config(ep); + context = ep->worker->context; + ucp_trace_req(sreq, "received rtr address 0x%"PRIx64" remote rreq_id" "0x%"PRIx64, rndv_rtr_hdr->address, rndv_rtr_hdr->rreq_id); UCS_PROFILE_REQUEST_EVENT(sreq, "rndv_rtr_recv", 0); @@ -1815,6 +1843,12 @@ UCS_PROFILE_FUNC(ucs_status_t, ucp_rndv_data_handler, size_t recv_len; rreq = ucp_worker_get_request_by_id(worker, rndv_data_hdr->rreq_id); + if (ucs_unlikely(rreq == NULL)) { + ucs_diag("unable to get request from RNDV data %p for non-existing" + " ep_id 0x%"PRIx64, rndv_data_hdr, rndv_data_hdr->rreq_id); + return UCS_OK; + } + ucs_assert(!(rreq->flags & UCP_REQUEST_FLAG_RNDV_FRAG) && (rreq->flags & (UCP_REQUEST_FLAG_RECV_AM | UCP_REQUEST_FLAG_RECV_TAG))); diff --git a/src/ucp/stream/stream_recv.c b/src/ucp/stream/stream_recv.c index 7836b3fe2566..5261f690b86e 100644 --- a/src/ucp/stream/stream_recv.c +++ b/src/ucp/stream/stream_recv.c @@ -528,7 +528,8 @@ ucp_stream_am_handler(void *am_arg, void *am_data, size_t am_length, ucs_assert(am_length >= sizeof(ucp_stream_am_hdr_t)); - ep = ucp_worker_get_ep_by_id(worker, data->hdr.ep_id); + ep = UCP_WORKER_GET_EP_BY_ID(worker, data->hdr.ep_id, 1, return UCS_OK, + "stream data"); ep_ext = ucp_ep_ext_proto(ep); if (ucs_unlikely(ep->flags & (UCP_EP_FLAG_CLOSED | diff --git a/src/ucp/tag/eager_snd.c b/src/ucp/tag/eager_snd.c index 110af2030e7a..3bc8a14dbe64 100644 --- a/src/ucp/tag/eager_snd.c +++ b/src/ucp/tag/eager_snd.c @@ -318,10 +318,11 @@ void ucp_tag_eager_sync_send_ack(ucp_worker_h worker, void *hdr, uint16_t recv_f ucs_assert(reqhdr->req_id != UCP_REQUEST_ID_INVALID); req = ucp_proto_ssend_ack_request_alloc(worker, reqhdr->ep_id); if (req == NULL) { - ucs_fatal("could not allocate request"); + /* drop the packet */ + return; } - req->send.proto.am_id = UCP_AM_ID_EAGER_SYNC_ACK; + req->send.proto.am_id = UCP_AM_ID_EAGER_SYNC_ACK; req->send.proto.remote_req_id = reqhdr->req_id; ucs_trace_req("send_sync_ack req %p ep %p", req, req->send.ep); diff --git a/src/ucp/wireup/wireup.c b/src/ucp/wireup/wireup.c index 7f34412bbf8d..229e7118940a 100644 --- a/src/ucp/wireup/wireup.c +++ b/src/ucp/wireup/wireup.c @@ -387,8 +387,9 @@ ucp_wireup_process_pre_request(ucp_worker_h worker, const ucp_wireup_msg_t *msg, /* wireup pre_request for a specific ep */ ep = ucp_worker_get_ep_by_id(worker, msg->dst_ep_id); - ucs_assert((ep->flags & UCP_EP_FLAG_SOCKADDR_PARTIAL_ADDR) || - ucp_ep_has_cm_lane(ep)); + ucs_assert((ep != NULL) && + ((ep->flags & UCP_EP_FLAG_SOCKADDR_PARTIAL_ADDR) || + ucp_ep_has_cm_lane(ep))); ucp_ep_update_remote_id(ep, msg->src_ep_id); ucp_ep_flush_state_reset(ep); @@ -437,6 +438,7 @@ ucp_wireup_process_request(ucp_worker_h worker, const ucp_wireup_msg_t *msg, if (msg->dst_ep_id != UCP_EP_ID_INVALID) { /* wireup request for a specific ep */ ep = ucp_worker_get_ep_by_id(worker, msg->dst_ep_id); + ucs_assert(ep != NULL); ucp_ep_update_remote_id(ep, msg->src_ep_id); if (!(ep->flags & UCP_EP_FLAG_LISTENER)) { /* Reset flush state only if it's not a client-server wireup on @@ -601,6 +603,7 @@ ucp_wireup_process_reply(ucp_worker_h worker, const ucp_wireup_msg_t *msg, int ack; ep = ucp_worker_get_ep_by_id(worker, msg->dst_ep_id); + ucs_assert(ep != NULL); ucs_assert(msg->type == UCP_WIREUP_MSG_REPLY); ucs_assert((!(ep->flags & UCP_EP_FLAG_LISTENER))); @@ -649,6 +652,7 @@ void ucp_wireup_process_ack(ucp_worker_h worker, const ucp_wireup_msg_t *msg) ucp_ep_h ep; ep = ucp_worker_get_ep_by_id(worker, msg->dst_ep_id); + ucs_assert(ep != NULL); ucs_assert(msg->type == UCP_WIREUP_MSG_ACK); ucs_trace("ep %p: got wireup ack", ep); @@ -680,20 +684,15 @@ static ucs_status_t ucp_wireup_msg_handler(void *arg, void *data, ucp_worker_h worker = arg; ucp_wireup_msg_t *msg = data; ucp_unpacked_address_t remote_address; - ucp_ep_h ep UCS_V_UNUSED; ucs_status_t status; UCS_ASYNC_BLOCK(&worker->async); if (msg->dst_ep_id != UCP_EP_ID_INVALID) { - ep = ucp_worker_get_ep_by_id(worker, msg->dst_ep_id); - if (ep == NULL) { - ucs_diag("got wireup msg %d src_ep_id 0x%"PRIx64" for" - " non-existing dst_ep_id 0x%"PRIx64" sn %d," - " ignoring it", - msg->type, msg->src_ep_id, msg->dst_ep_id, msg->conn_sn); - goto out; - } + UCP_WORKER_GET_EP_BY_ID(worker, msg->dst_ep_id, 0, goto out, + "WIREUP message (%d src_ep_id 0x%"PRIx64 + " sn %d)", msg->type, msg->src_ep_id, + msg->conn_sn); } status = ucp_address_unpack(worker, msg + 1, diff --git a/test/gtest/ucp/test_ucp_sockaddr.cc b/test/gtest/ucp/test_ucp_sockaddr.cc index ec01a41adafd..71661d80a80b 100644 --- a/test/gtest/ucp/test_ucp_sockaddr.cc +++ b/test/gtest/ucp/test_ucp_sockaddr.cc @@ -17,6 +17,7 @@ extern "C" { #include #include #include +#include #include } @@ -59,11 +60,13 @@ class test_ucp_sockaddr : public ucp_test { } send_recv_type_t; ucs::sock_addr_storage m_test_addr; + static unsigned m_err_count; void init() { if (get_variant_value() & TEST_MODIFIER_CM) { modify_config("SOCKADDR_CM_ENABLE", "yes"); } + m_err_count = 0; get_sockaddr(); ucp_test::init(); skip_loopback(); @@ -215,6 +218,15 @@ class test_ucp_sockaddr : public ucp_test { UCS_TEST_MESSAGE << "server listening on " << m_test_addr.to_str(); } + static void complete_err_handling_status_verify(ucs_status_t status) + { + EXPECT_TRUE(/* was successful */ + (status == UCS_OK) || + /* completed from error handling for EP */ + (status == UCS_ERR_ENDPOINT_TIMEOUT) || + (status == UCS_ERR_CONNECTION_RESET)); + } + static void scomplete_cb(void *req, ucs_status_t status) { if ((status == UCS_OK) || @@ -225,12 +237,23 @@ class test_ucp_sockaddr : public ucp_test { UCS_TEST_ABORT("Error: " << ucs_status_string(status)); } + static void scomplete_err_handling_cb(void *req, ucs_status_t status) + { + complete_err_handling_status_verify(status); + } + static void rtag_complete_cb(void *req, ucs_status_t status, ucp_tag_recv_info_t *info) { EXPECT_UCS_OK(status); } + static void rtag_complete_err_handling_cb(void *req, ucs_status_t status, + ucp_tag_recv_info_t *info) + { + complete_err_handling_status_verify(status); + } + static void rstream_complete_cb(void *req, ucs_status_t status, size_t length) { @@ -512,6 +535,8 @@ class test_ucp_sockaddr : public ucp_test { static void err_handler_cb(void *arg, ucp_ep_h ep, ucs_status_t status) { ucp_test::err_handler_cb(arg, ep, status); + ++m_err_count; + /* The current expected errors are only from the err_handle test * and from transports where the worker address is too long but ud/ud_x * are not present, or ud/ud_x are present but their addresses are too @@ -523,6 +548,7 @@ class test_ucp_sockaddr : public ucp_test { case UCS_ERR_UNREACHABLE: case UCS_ERR_CONNECTION_RESET: case UCS_ERR_NOT_CONNECTED: + case UCS_ERR_ENDPOINT_TIMEOUT: UCS_TEST_MESSAGE << "ignoring error " << ucs_status_string(status) << " on endpoint " << ep; return; @@ -568,6 +594,8 @@ class test_ucp_sockaddr : public ucp_test { } }; +unsigned test_ucp_sockaddr::m_err_count = 0; + UCS_TEST_SKIP_COND_P(test_ucp_sockaddr, listen, no_close_protocol()) { listen_and_communicate(false, 0); } @@ -786,7 +814,7 @@ UCS_TEST_SKIP_COND_P(test_ucp_sockaddr, compare_cm_and_wireup_configs, UCP_INSTANTIATE_ALL_TEST_CASE(test_ucp_sockaddr) -class test_ucp_sockaddr_destroy_ep_on_err : public test_ucp_sockaddr { +class test_ucp_sockaddr_destroy_ep_on_err : virtual public test_ucp_sockaddr { public: test_ucp_sockaddr_destroy_ep_on_err() { /* Set small TL timeouts to reduce testing time */ @@ -795,6 +823,8 @@ class test_ucp_sockaddr_destroy_ep_on_err : public test_ucp_sockaddr { m_env.push_back(new ucs::scoped_setenv("UCX_RC_RETRY_COUNT", "2")); } + virtual ~test_ucp_sockaddr_destroy_ep_on_err() { } + virtual ucp_ep_params_t get_server_ep_params() { ucp_ep_params_t params = test_ucp_sockaddr::get_server_ep_params(); @@ -978,8 +1008,10 @@ UCS_TEST_P(test_ucp_sockaddr_with_rma_atomic, wireup) { UCP_INSTANTIATE_ALL_TEST_CASE(test_ucp_sockaddr_with_rma_atomic) -class test_ucp_sockaddr_protocols : public test_ucp_sockaddr { +class test_ucp_sockaddr_protocols : virtual public test_ucp_sockaddr { public: + virtual ~test_ucp_sockaddr_protocols() { } + static void get_test_variants(std::vector& variants) { /* Atomics not supported for now because need to emulate the case * of using different device than the one selected by default on the @@ -1022,13 +1054,28 @@ class test_ucp_sockaddr_protocols : public test_ucp_sockaddr { << "recv_buf: '" << ucs::compact_string(recv_buf, 20) << "'"; } - void test_tag_send_recv(size_t size, bool is_exp, bool is_sync = false) + typedef void (*stop_cb_t)(void *arg); + + virtual void test_tag_send_recv(size_t size, bool is_exp, + bool is_sync = false, + stop_cb_t send_stop = NULL, + stop_cb_t recv_stop = NULL, + void *arg = NULL) { std::string send_buf(size, 'x'); std::string recv_buf(size, 'y'); void *rreq = NULL, *sreq = NULL; + bool err_handling = ((recv_stop != NULL) || + (send_stop != NULL)); + + scoped_log_handler *slh = NULL; + if (err_handling) { + slh = new scoped_log_handler(wrap_errors_logger); + ASSERT_TRUE(slh != NULL); + } + if (is_exp) { rreq = ucp_tag_recv_nb(receiver().worker(), &recv_buf[0], size, ucp_dt_make_contig(1), 0, 0, rtag_complete_cb); @@ -1036,22 +1083,79 @@ class test_ucp_sockaddr_protocols : public test_ucp_sockaddr { if (is_sync) { sreq = ucp_tag_send_sync_nb(sender().ep(), &send_buf[0], size, - ucp_dt_make_contig(1), 0, scomplete_cb); + ucp_dt_make_contig(1), 0, + !err_handling ? scomplete_cb : + scomplete_err_handling_cb); } else { sreq = ucp_tag_send_nb(sender().ep(), &send_buf[0], size, - ucp_dt_make_contig(1), 0, scomplete_cb); + ucp_dt_make_contig(1), 0, + !err_handling ? scomplete_cb : + scomplete_err_handling_cb); } if (!is_exp) { - short_progress_loop(); - rreq = ucp_tag_recv_nb(receiver().worker(), &recv_buf[0], size, - ucp_dt_make_contig(1), 0, 0, rtag_complete_cb); + ucp_tag_recv_info_t recv_info = {}; + ucp_tag_message_h message; + do { + short_progress_loop(); + message = ucp_tag_probe_nb(receiver().worker(), + 0, 0, 1, &recv_info); + } while (message == NULL); + + EXPECT_EQ(size, recv_info.length); + EXPECT_EQ(0, recv_info.sender_tag); + + if (err_handling) { + if (recv_stop != NULL) { + recv_stop(arg); + } else { + /* If send request is NULL, it means that send operation + * completed immediately */ + if (sreq != NULL) { + /* TODO: remove memory deregistration, when UCP + * requests tracking is added */ + ucp_request_t *req = static_cast + (sreq) - 1; + ucp_request_memory_dereg(req->send.ep->worker->context, + req->send.datatype, + &req->send.state.dt, req); + ucp_request_free(sreq); + } + send_stop(arg); + } + } + + rreq = ucp_tag_msg_recv_nb(receiver().worker(), &recv_buf[0], size, + ucp_dt_make_contig(1), message, + !err_handling ? rtag_complete_cb : + rtag_complete_err_handling_cb); } - request_wait(sreq); - request_wait(rreq); + if (!err_handling) { + request_wait(sreq); + request_wait(rreq); + } else if (recv_stop != NULL) { + /* TODO: add waiting for send request completion and remove + * memory deregistration, when UCP requests tracking is added */ + request_wait(rreq); + + if (sreq != NULL) { + ucp_request_t *req = static_cast(sreq) - 1; + ucp_request_memory_dereg(req->send.ep->worker->context, + req->send.datatype, + &req->send.state.dt, req); + } + sender().disconnect_nb(0, 0, UCP_EP_CLOSE_MODE_FORCE); + } else if (send_stop != NULL) { + request_wait(rreq); + receiver().disconnect_nb(0, 0, UCP_EP_CLOSE_MODE_FORCE); + } - compare_buffers(send_buf, recv_buf); + if (!err_handling) { + compare_buffers(send_buf, recv_buf); + } else { + delete slh; + } } void wait_for_server_ep() @@ -1371,3 +1475,177 @@ UCS_TEST_P(test_ucp_sockaddr_protocols, am_zcopy_64k, UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, all, "all") UCP_INSTANTIATE_CM_TEST_CASE(test_ucp_sockaddr_protocols) + + +class test_ucp_sockaddr_protocols_err : public test_ucp_sockaddr_protocols, + test_ucp_sockaddr_destroy_ep_on_err { +protected: + static void recv_stop(void *arg) { + test_ucp_sockaddr_protocols_err *test = + static_cast(arg); + test->one_sided_disconnect(test->receiver(), UCP_EP_CLOSE_MODE_FORCE); + while (m_err_count == 0) { + test->short_progress_loop(); + } + } + + static void send_stop(void *arg) { + test_ucp_sockaddr_protocols_err *test = + static_cast(arg); + test->one_sided_disconnect(test->sender(), UCP_EP_CLOSE_MODE_FORCE); + while (m_err_count == 0) { + test->short_progress_loop(); + } + } + + void test_tag_send_recv(size_t size, bool is_exp, + bool is_sync = false, + stop_cb_t send_stop = NULL, + stop_cb_t recv_stop = NULL, + void *arg = NULL) { + test_ucp_sockaddr_protocols::test_tag_send_recv(size, is_exp, is_sync); + test_ucp_sockaddr_protocols::test_tag_send_recv(size, is_exp, is_sync, + send_stop, recv_stop, + arg); + } + + static ucs_log_func_rc_t + warn_unreleased_ptr_map_leak_handler(const char *file, unsigned line, + const char *function, + ucs_log_level_t level, + const ucs_log_component_config_t + *comp_conf, const char *message, + va_list ap) { + if (level == UCS_LOG_LEVEL_WARN) { + std::string err_str = format_message(message, ap); + + if (/* TODO: remove when ptr map leaks are fixed */ + (err_str.find("ptr map ") != std::string::npos) || + /* TODO: remove when tracking of UCP requests is added */ + (err_str.find("was not returned to mpool ucp_requests") != + std::string::npos)) { + return UCS_LOG_FUNC_RC_STOP; + } + } + + return UCS_LOG_FUNC_RC_CONTINUE; + } + + void cleanup() { + scoped_log_handler slh(warn_unreleased_ptr_map_leak_handler); + test_ucp_sockaddr_protocols::cleanup(); + } +}; + + +UCS_TEST_P(test_ucp_sockaddr_protocols_err, tag_zcopy_4k_unexp_stop_recv, + "ZCOPY_THRESH=2k", "RNDV_THRESH=inf") +{ + test_tag_send_recv(4 * UCS_KBYTE, false, false, + NULL /* stops sender */, + recv_stop /* stops receiver */, + this /* argument to stop functions */); + +} + +UCS_TEST_P(test_ucp_sockaddr_protocols_err, tag_zcopy_4k_unexp_stop_send, + "ZCOPY_THRESH=2k", "RNDV_THRESH=inf") +{ + test_tag_send_recv(4 * UCS_KBYTE, false, false, + send_stop /* stops sender */, + NULL /* stops receiver */, + this /* argument to stop functions */); +} + +UCS_TEST_P(test_ucp_sockaddr_protocols_err, tag_zcopy_64k_unexp_stop_recv, + "ZCOPY_THRESH=2k", "RNDV_THRESH=inf") +{ + test_tag_send_recv(64 * UCS_KBYTE, false, false, + NULL /* stops sender */, + recv_stop /* stops receiver */, + this /* argument to stop functions */); +} + +UCS_TEST_P(test_ucp_sockaddr_protocols_err, tag_zcopy_64k_unexp_stop_send, + "ZCOPY_THRESH=2k", "RNDV_THRESH=inf") +{ + test_tag_send_recv(64 * UCS_KBYTE, false, false, + send_stop /* stops sender */, + NULL /* stops receiver */, + this /* argument to stop functions */); +} + +UCS_TEST_P(test_ucp_sockaddr_protocols_err, tag_zcopy_4k_unexp_sync_stop_recv, + "ZCOPY_THRESH=2k", "RNDV_THRESH=inf") +{ + test_tag_send_recv(4 * UCS_KBYTE, false, true, + NULL /* stops sender */, + recv_stop /* stops receiver */, + this /* argument to stop functions */); +} + +UCS_TEST_P(test_ucp_sockaddr_protocols_err, tag_zcopy_4k_unexp_sync_stop_send, + "ZCOPY_THRESH=2k", "RNDV_THRESH=inf") +{ + test_tag_send_recv(4 * UCS_KBYTE, false, true, + send_stop /* stops sender */, + NULL /* stops receiver */, + this /* argument to stop functions */); +} + +UCS_TEST_P(test_ucp_sockaddr_protocols_err, tag_zcopy_64k_unexp_sync_stop_recv, + "ZCOPY_THRESH=2k", "RNDV_THRESH=inf") +{ + test_tag_send_recv(64 * UCS_KBYTE, false, true, + NULL /* stops sender */, + recv_stop /* stops receiver */, + this /* argument to stop functions */); +} + +UCS_TEST_P(test_ucp_sockaddr_protocols_err, tag_zcopy_64k_unexp_sync_stop_send, + "ZCOPY_THRESH=2k", "RNDV_THRESH=inf") +{ + test_tag_send_recv(64 * UCS_KBYTE, false, true, + send_stop /* stops sender */, + NULL /* stops receiver */, + this /* argument to stop functions */); +} + +UCS_TEST_P(test_ucp_sockaddr_protocols_err, tag_rndv_unexp_stop_recv, + "RNDV_THRESH=0") +{ + test_tag_send_recv(64 * UCS_KBYTE, false, false, + NULL /* stops sender */, + recv_stop /* stops receiver */, + this /* argument to stop functions */); +} + +UCS_TEST_P(test_ucp_sockaddr_protocols_err, tag_rndv_unexp_stop_recv_put_scheme, + "RNDV_THRESH=0", "RNDV_SCHEME=put_zcopy") +{ + test_tag_send_recv(64 * UCS_KBYTE, false, false, + NULL /* stops sender */, + recv_stop /* stops receiver */, + this /* argument to stop functions */); +} + +UCS_TEST_P(test_ucp_sockaddr_protocols_err, tag_rndv_unexp_stop_send, + "RNDV_THRESH=0") +{ + test_tag_send_recv(64 * UCS_KBYTE, false, false, + send_stop /* stops sender */, + NULL /* stops receiver */, + this /* argument to stop functions */); +} + +UCS_TEST_P(test_ucp_sockaddr_protocols_err, tag_rndv_unexp_stop_send_put_scheme, + "RNDV_THRESH=0", "RNDV_SCHEME=put_zcopy") +{ + test_tag_send_recv(64 * UCS_KBYTE, false, false, + send_stop /* stops sender */, + NULL /* stops receiver */, + this /* argument to stop functions */); +} + + +UCP_INSTANTIATE_CM_TEST_CASE(test_ucp_sockaddr_protocols_err)