diff --git a/src/ucp/core/ucp_am.c b/src/ucp/core/ucp_am.c
index 11d47277aa5..1ec239ae217 100644
--- a/src/ucp/core/ucp_am.c
+++ b/src/ucp/core/ucp_am.c
@@ -138,14 +138,17 @@ static void ucp_am_rndv_send_ats(ucp_worker_h worker,
                                  ucs_status_t status)
 {
     ucp_request_t *req;
+    ucp_ep_h ep;
 
+    ep  = UCP_WORKER_GET_EP_BY_ID(worker, rts->super.sreq.ep_id, return,
+                                  "AM RNDV ATS");
     req = ucp_request_get(worker);
     if (ucs_unlikely(req == NULL)) {
         ucs_error("failed to allocate request for AM RNDV ATS");
         return;
     }
 
-    req->send.ep = ucp_worker_get_ep_by_id(worker, rts->super.sreq.ep_id);
+    req->send.ep = ep;
     req->flags   = 0;
 
     ucp_rndv_req_send_ats(req, NULL, rts->super.sreq.req_id, status);
@@ -1227,8 +1230,8 @@ UCS_PROFILE_FUNC(ucs_status_t, ucp_am_handler_reply,
     ucp_worker_h worker = (ucp_worker_h)am_arg;
     ucp_ep_h reply_ep;
 
-    reply_ep = UCP_WORKER_GET_EP_BY_ID(worker, hdr->ep_id, "AM (reply proto)",
-                                       return UCS_OK);
+    reply_ep = UCP_WORKER_GET_VALID_EP_BY_ID(worker, hdr->ep_id, return UCS_OK,
+                                             "AM (reply proto)");
 
     return ucp_am_handler_common(worker, &hdr->super, sizeof(*hdr),
                                  am_length, reply_ep, am_flags,
@@ -1356,8 +1359,8 @@ UCS_PROFILE_FUNC(ucs_status_t, ucp_am_long_first_handler,
     size_t remaining;
     uint64_t recv_flags;
 
-    ep        = UCP_WORKER_GET_EP_BY_ID(worker, first_hdr->super.ep_id,
-                                        "AM first fragment", return UCS_OK);
+    ep        = UCP_WORKER_GET_VALID_EP_BY_ID(worker, first_hdr->super.ep_id,
+                                              return UCS_OK, "AM first fragment");
     remaining = first_hdr->total_size - (am_length - sizeof(*first_hdr));
 
     if (ucs_unlikely(remaining == 0)) {
@@ -1429,8 +1432,8 @@ UCS_PROFILE_FUNC(ucs_status_t, ucp_am_long_middle_handler,
     ucp_ep_h ep;
     ucs_status_t status;
 
-    ep          = UCP_WORKER_GET_EP_BY_ID(worker, mid_hdr->ep_id,
-                                          "AM middle fragment", return UCS_OK);
+    ep          = UCP_WORKER_GET_VALID_EP_BY_ID(worker, mid_hdr->ep_id,
+                                                return UCS_OK, "AM middle fragment");
     ep_ext      = ucp_ep_ext_proto(ep);
     first_rdesc = ucp_am_find_first_rdesc(worker, ep_ext, msg_id);
     if (first_rdesc != NULL) {
@@ -1472,10 +1475,11 @@ ucs_status_t ucp_am_rndv_process_rts(void *arg, void *data, size_t length,
     ucs_status_t status, desc_status;
     void *hdr;
 
-    ep = UCP_WORKER_GET_EP_BY_ID(worker, rts->super.sreq.ep_id, "AM RTS",
-                                 { status = UCS_ERR_ENDPOINT_TIMEOUT;
-                                   goto out_send_ats;
-                                 });
+    ep = UCP_WORKER_GET_VALID_EP_BY_ID(worker, rts->super.sreq.ep_id,
+                                       { status = UCS_ERR_CANCELED;
+                                         goto out_send_ats;
+                                       },
+                                       "AM RTS");
 
     if (ucs_unlikely(!ucp_am_recv_check_id(worker, am_id))) {
         status = UCS_ERR_INVALID_PARAM;
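Note: the ucp_am.c hunks above all move the endpoint lookup to the top of the
handler, before any request is allocated, so that a message carrying a stale
EP ID is dropped before resources are taken. A minimal self-contained C sketch
of this lookup-before-allocate ordering (lookup_ep/send_ack are illustrative
names, not UCX API):

#include <stdio.h>
#include <stdlib.h>

typedef struct { long id; } ep_t;
typedef struct { ep_t *ep; } req_t;

/* returns NULL for an unknown ID, e.g. one whose EP was already released */
static ep_t *lookup_ep(long id)
{
    static ep_t ep;
    if (id != 42) {
        return NULL;
    }
    ep.id = id;
    return &ep;
}

static void send_ack(long ep_id)
{
    ep_t *ep = lookup_ep(ep_id);
    if (ep == NULL) {
        return; /* drop: nothing was allocated yet, nothing leaks */
    }

    req_t *req = malloc(sizeof(*req));
    if (req == NULL) {
        fprintf(stderr, "failed to allocate request\n");
        return;
    }
    req->ep = ep; /* the EP is resolved exactly once */
    /* ... send the ACK ... */
    free(req);
}

int main(void)
{
    send_ack(42); /* EP found, ACK is sent */
    send_ack(7);  /* stale ID, message dropped up front */
    return 0;
}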
diff --git a/src/ucp/core/ucp_ep.c b/src/ucp/core/ucp_ep.c
index 18afecd4bbb..9359c39f6a1 100644
--- a/src/ucp/core/ucp_ep.c
+++ b/src/ucp/core/ucp_ep.c
@@ -211,8 +211,6 @@ ucs_status_t ucp_worker_create_ep(ucp_worker_h worker, unsigned ep_init_flags,
 
 static void ucp_ep_delete(ucp_ep_h ep)
 {
-    ucs_status_t status;
-
     ucs_callbackq_remove_if(&ep->worker->uct->progress_q,
                             ucp_wireup_msg_ack_cb_pred, ep);
     if (!(ep->flags & UCP_EP_FLAG_INTERNAL)) {
@@ -220,13 +218,24 @@ static void ucp_ep_delete(ucp_ep_h ep)
     }
     ucs_list_del(&ucp_ep_ext_gen(ep)->ep_list);
 
+    if (!(ep->flags & UCP_EP_FLAG_FAILED)) {
+        ucp_ep_release_id(ep);
+    }
+
+    ucp_ep_destroy_base(ep);
+}
+
+void ucp_ep_release_id(ucp_ep_h ep)
+{
+    ucs_status_t status;
+
+    ucs_assert(!(ep->flags & UCP_EP_FLAG_FAILED));
+
     status = ucs_ptr_map_del(&ep->worker->ptr_map, ucp_ep_local_id(ep));
     if (status != UCS_OK) {
-        ucs_warn("ep %p local id 0x%"PRIxPTR": ucs_ptr_map_del failed with status %s",
+        ucs_warn("ep %p local id 0x%" PRIxPTR ": ucs_ptr_map_del failed: %s",
                  ep, ucp_ep_local_id(ep),
                  ucs_status_string(status));
     }
-
-    ucp_ep_destroy_base(ep);
 }
 
 void ucp_ep_config_key_set_err_mode(ucp_ep_config_key_t *key,
diff --git a/src/ucp/core/ucp_ep.h b/src/ucp/core/ucp_ep.h
index 6508a82b7f5..0a936f59645 100644
--- a/src/ucp/core/ucp_ep.h
+++ b/src/ucp/core/ucp_ep.h
@@ -496,6 +496,8 @@ ucs_status_t ucp_worker_create_ep(ucp_worker_h worker, unsigned ep_init_flags,
                                   const char *peer_name, const char *message,
                                   ucp_ep_h *ep_p);
 
+void ucp_ep_release_id(ucp_ep_h ep);
+
 ucs_status_t ucp_ep_init_create_wireup(ucp_ep_h ep, unsigned ep_init_flags,
                                        ucp_wireup_ep_t **wireup_ep);
 
diff --git a/src/ucp/core/ucp_worker.c b/src/ucp/core/ucp_worker.c
index 79c40d959c7..c5e495cc923 100644
--- a/src/ucp/core/ucp_worker.c
+++ b/src/ucp/core/ucp_worker.c
@@ -502,6 +502,7 @@ ucs_status_t ucp_worker_set_ep_failed(ucp_worker_h worker, ucp_ep_h ucp_ep,
         goto out_ok;
     }
 
+    ucp_ep_release_id(ucp_ep);
     ucp_ep->flags |= UCP_EP_FLAG_FAILED;
 
     if (ucp_ep_config(ucp_ep)->key.err_mode == UCP_ERR_HANDLING_MODE_NONE) {
diff --git a/src/ucp/core/ucp_worker.inl b/src/ucp/core/ucp_worker.inl
index a7ef6d269e6..c29b9cbc2e6 100644
--- a/src/ucp/core/ucp_worker.inl
+++ b/src/ucp/core/ucp_worker.inl
@@ -83,11 +83,7 @@ ucp_worker_get_request_id(ucp_worker_h worker, ucp_request_t *req, int indirect)
 static UCS_F_ALWAYS_INLINE ucp_request_t*
 ucp_worker_get_request_by_id(ucp_worker_h worker, ucs_ptr_map_key_t id)
 {
-    ucp_request_t* request;
-
-    request = (ucp_request_t*)ucs_ptr_map_get(&worker->ptr_map, id);
-    ucs_assert(request != NULL);
-    return request;
+    return (ucp_request_t*)ucs_ptr_map_get(&worker->ptr_map, id);
 }
 
 static UCS_F_ALWAYS_INLINE void
@@ -262,17 +258,41 @@ ucp_worker_get_rkey_config(ucp_worker_h worker, const ucp_rkey_config_key_t *key
     return ucp_worker_add_rkey_config(worker, key, cfg_index_p);
 }
 
-#define UCP_WORKER_GET_EP_BY_ID(_worker, _ep_id, _str, _action) \
+#define UCP_WORKER_GET_EP_BY_ID(_worker, _ep_id, _action, _fmt_str, ...) \
+    ({ \
+        ucp_ep_h __ep = ucp_worker_get_ep_by_id(_worker, _ep_id); \
+        if (ucs_unlikely(__ep == NULL)) { \
+            ucs_trace_data("worker %p: ep id 0x%" PRIx64 " was not found, drop " \
+                           _fmt_str, _worker, _ep_id, ##__VA_ARGS__); \
+            _action; \
+        } \
+        __ep; \
+    })
+
+#define UCP_WORKER_GET_VALID_EP_BY_ID(_worker, _ep_id, _action, _fmt_str, ...) \
+    ({ \
+        ucp_ep_h ___ep = UCP_WORKER_GET_EP_BY_ID(_worker, _ep_id, _action, \
+                                                 _fmt_str, ##__VA_ARGS__); \
+        if (ucs_unlikely((___ep != NULL) && \
+                         (___ep->flags & UCP_EP_FLAG_CLOSED))) { \
+            ucs_trace_data("worker %p: ep id 0x%" PRIx64 " was already closed," \
+                           " ep %p, drop " _fmt_str, _worker, _ep_id, ___ep, \
+                           ##__VA_ARGS__); \
+            _action; \
+        } \
+        ___ep; \
+    })
+
+#define UCP_WORKER_GET_REQ_BY_ID(_worker, _req_id, _action, _fmt_str, ...) \
     ({ \
-        ucp_ep_h _ep = ucp_worker_get_ep_by_id(_worker, _ep_id); \
-        if (ucs_unlikely((_ep == NULL) || \
-                         ((_ep)->flags & (UCP_EP_FLAG_CLOSED | \
-                                          UCP_EP_FLAG_FAILED)))) { \
-            ucs_trace_data("worker %p: drop %s on closed/failed ep %p", \
-                           _worker, _str, _ep); \
+        ucp_request_t *_req = ucp_worker_get_request_by_id(_worker, _req_id); \
+        if (ucs_unlikely(_req == NULL)) { \
+            ucs_trace_data("worker %p: req id 0x%" PRIx64 " doesn't exist," \
+                           " drop " _fmt_str, _worker, _req_id, \
+                           ##__VA_ARGS__); \
+            _action; \
+        } \
-        _ep; \
+        _req; \
     })
 
 #endif
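Note: all three macros above rely on the GCC/Clang statement-expression
extension: the ({ ... }) block evaluates to its last expression (the looked-up
pointer), while _action is an arbitrary statement ("return UCS_OK", "goto out",
or a compound block) that executes in the calling function's scope when the
lookup fails. A minimal self-contained C sketch of the pattern (generic names,
not the UCX implementation):

#include <stdio.h>
#include <stddef.h>

#define TABLE_SIZE 4
static const char *table[TABLE_SIZE] = {"a", NULL, "c", NULL};

/* evaluates to a non-NULL entry, or runs _action in the caller's scope */
#define GET_ENTRY(_idx, _action, _fmt, ...) \
    ({ \
        const char *__e = ((_idx) < TABLE_SIZE) ? table[_idx] : NULL; \
        if (__e == NULL) { \
            printf("drop " _fmt "\n", ##__VA_ARGS__); \
            _action; \
        } \
        __e; \
    })

static int handler(size_t idx)
{
    /* on failure, "return -1" executes inside handler() itself */
    const char *entry = GET_ENTRY(idx, return -1, "(idx %zu)", idx);

    printf("got '%s'\n", entry);
    return 0;
}

int main(void)
{
    handler(0); /* prints: got 'a' */
    handler(1); /* prints: drop (idx 1) */
    return 0;
}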
diff --git a/src/ucp/proto/proto_am.inl b/src/ucp/proto/proto_am.inl
index 3c2d18ea31b..03aa880405c 100644
--- a/src/ucp/proto/proto_am.inl
+++ b/src/ucp/proto/proto_am.inl
@@ -527,17 +527,16 @@ ucp_proto_is_inline(ucp_ep_h ep, const ucp_memtype_thresh_t *max_eager_short,
 }
 
 static UCS_F_ALWAYS_INLINE ucp_request_t*
-ucp_proto_ssend_ack_request_alloc(ucp_worker_h worker, ucs_ptr_map_key_t ep_id)
+ucp_proto_ssend_ack_request_alloc(ucp_worker_h worker, ucp_ep_h ep)
 {
-    ucp_request_t *req;
-
-    req = ucp_request_get(worker);
+    ucp_request_t *req = ucp_request_get(worker);
     if (req == NULL) {
+        ucs_error("failed to allocate UCP request");
         return NULL;
     }
 
     req->flags              = 0;
-    req->send.ep            = ucp_worker_get_ep_by_id(worker, ep_id);
+    req->send.ep            = ep;
     req->send.uct.func      = ucp_proto_progress_am_single;
     req->send.proto.comp_cb = ucp_request_put;
     req->send.proto.status  = UCS_OK;
diff --git a/src/ucp/rma/amo_sw.c b/src/ucp/rma/amo_sw.c
index 35a8f876047..e8aff5bbdb0 100644
--- a/src/ucp/rma/amo_sw.c
+++ b/src/ucp/rma/amo_sw.c
@@ -202,11 +202,15 @@ UCS_PROFILE_FUNC(ucs_status_t, ucp_atomic_req_handler, (arg, data, length, am_fl
 {
     ucp_atomic_req_hdr_t *atomicreqh = data;
     ucp_worker_h worker              = arg;
-    ucp_ep_h ep                      = ucp_worker_get_ep_by_id(worker,
-                                                               atomicreqh->req.ep_id);
     ucp_rsc_index_t amo_rsc_idx      = ucs_ffs64_safe(worker->atomic_tls);
     ucp_request_t *req;
+    ucp_ep_h ep;
 
+    /* allow getting a closed EP here: it may still be used to send a
+     * completion or AMO data that lets the peer complete its flush
+     */
+    ep = UCP_WORKER_GET_EP_BY_ID(worker, atomicreqh->req.ep_id, return UCS_OK,
+                                 "SW AMO request");
     if (ucs_unlikely((amo_rsc_idx != UCP_MAX_RESOURCES) &&
                      (ucp_worker_iface_get_attr(worker,
                                                 amo_rsc_idx)->cap.flags &
diff --git a/src/ucp/rma/rma_sw.c b/src/ucp/rma/rma_sw.c
index 0a606c35a75..0b2b7d1ca70 100644
--- a/src/ucp/rma/rma_sw.c
+++ b/src/ucp/rma/rma_sw.c
@@ -148,10 +148,16 @@ UCS_PROFILE_FUNC(ucs_status_t, ucp_put_handler, (arg, data, length, am_flags),
 {
     ucp_put_hdr_t *puth = data;
     ucp_worker_h worker = arg;
+    ucp_ep_h ep;
 
+    /* allow getting a closed EP here: it may still be used to send the
+     * completion that lets the peer complete its flush
+     */
+    ep = UCP_WORKER_GET_EP_BY_ID(worker, puth->ep_id, return UCS_OK,
+                                 "SW PUT request");
     ucp_dt_contig_unpack(worker, (void*)puth->address, puth + 1,
                          length - sizeof(*puth), puth->mem_type);
-    ucp_rma_sw_send_cmpl(ucp_worker_get_ep_by_id(worker, puth->ep_id));
+    ucp_rma_sw_send_cmpl(ep);
     return UCS_OK;
 }
 
@@ -160,8 +166,13 @@ UCS_PROFILE_FUNC(ucs_status_t, ucp_rma_cmpl_handler, (arg, data, length, am_flag
 {
     ucp_cmpl_hdr_t *putackh = data;
     ucp_worker_h worker     = arg;
-    ucp_ep_h ep             = ucp_worker_get_ep_by_id(worker, putackh->ep_id);
+    ucp_ep_h ep;
 
+    /* allow getting a closed EP here: it may still need to handle a
+     * completion to enable flush on the peer
+     */
+    ep = UCP_WORKER_GET_EP_BY_ID(worker, putackh->ep_id, return UCS_OK,
+                                 "SW RMA completion");
     ucp_ep_rma_remote_request_completed(ep);
     return UCS_OK;
 }
@@ -214,10 +225,14 @@ UCS_PROFILE_FUNC(ucs_status_t, ucp_get_req_handler, (arg, data, length, am_flags
 {
     ucp_get_req_hdr_t *getreqh = data;
     ucp_worker_h worker        = arg;
-    ucp_ep_h ep                = ucp_worker_get_ep_by_id(worker,
-                                                         getreqh->req.ep_id);
+    ucp_ep_h ep;
     ucp_request_t *req;
 
+    /* allow getting a closed EP here: it may still be used to send the GET
+     * operation data that lets the peer complete its flush
+     */
+    ep  = UCP_WORKER_GET_EP_BY_ID(worker, getreqh->req.ep_id, return UCS_OK,
+                                  "SW GET request");
     req = ucp_request_get(worker);
     if (req == NULL) {
         ucs_error("failed to allocate get reply");
@@ -246,10 +261,13 @@ UCS_PROFILE_FUNC(ucs_status_t, ucp_get_rep_handler, (arg, data, length, am_flags
     ucp_worker_h worker        = arg;
     ucp_rma_rep_hdr_t *getreph = data;
     size_t frag_length         = length - sizeof(*getreph);
-    ucp_request_t *req         = ucp_worker_get_request_by_id(worker,
-                                                              getreph->req_id);
-    ucp_ep_h ep                = req->send.ep;
+    ucp_request_t *req;
+    ucp_ep_h ep;
 
+    req = UCP_WORKER_GET_REQ_BY_ID(worker, getreph->req_id,
+                                   return UCS_OK,
+                                   "GET reply data %p", getreph);
+    ep  = req->send.ep;
     if (ep->worker->context->config.ext.proto_enable) {
         // TODO use dt_iter.inl unpack
         ucp_dt_contig_unpack(ep->worker,
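Note: the SW RMA/AMO handlers above deliberately use UCP_WORKER_GET_EP_BY_ID,
which only drops a message when the ID resolves to NULL (the EP failed and its
ID was released from the worker's ptr map). An EP that is merely closed must
remain usable for sending the completion that lets the peer finish its flush,
whereas the AM/stream/tag paths use UCP_WORKER_GET_VALID_EP_BY_ID, which also
drops messages for closed EPs. A compact sketch of the two validity levels
(illustrative types, not UCX code):

#include <stdbool.h>
#include <stddef.h>

enum { EP_FLAG_CLOSED = 0x1 };
typedef struct { unsigned flags; } ep_t;

/* SW RMA/AMO path: only an unknown ID (NULL) drops the message; a closed
 * EP is still good enough to emit the peer's flush completion */
static bool usable_for_completion(const ep_t *ep)
{
    return ep != NULL;
}

/* AM/stream/tag path: a closed EP must not accept new user data */
static bool usable_for_user_data(const ep_t *ep)
{
    return (ep != NULL) && !(ep->flags & EP_FLAG_CLOSED);
}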
diff --git a/src/ucp/rndv/rndv.c b/src/ucp/rndv/rndv.c
index 3121e6c44ea..edb9b0e514a 100644
--- a/src/ucp/rndv/rndv.c
+++ b/src/ucp/rndv/rndv.c
@@ -740,7 +740,10 @@ UCS_PROFILE_FUNC_VOID(ucp_rndv_recv_frag_put_completion, (self),
 
     /* rndv_req is NULL in case of put protocol */
     if (!is_put_proto) {
+        /* local operation: the request is expected to always be valid */
         rndv_req = ucp_worker_get_request_by_id(worker, rreq_remote_id);
+        ucs_assert(rndv_req != NULL);
+
         /* pipeline recv get protocol */
         rndv_req->send.state.dt.offset += freq->send.length;
 
@@ -1227,17 +1230,24 @@ UCS_PROFILE_FUNC_VOID(ucp_rndv_receive, (worker, rreq, rndv_rts_hdr, rkey_buf),
 
     UCS_PROFILE_REQUEST_EVENT(rreq, "rndv_receive", 0);
 
+    /* if the message was received on an already closed endpoint, stop processing */
+    ep = UCP_WORKER_GET_VALID_EP_BY_ID(worker, rndv_rts_hdr->sreq.ep_id,
+                                       { status = UCS_ERR_CANCELED;
+                                         goto err;
+                                       },
+                                       "RNDV rts");
+
     /* the internal send request allocated on receiver side (to perform a "get"
      * operation, send "ATS" and "RTR") */
     rndv_req = ucp_request_get(worker);
     if (rndv_req == NULL) {
         ucs_error("failed to allocate rendezvous reply");
-        goto out;
+        status = UCS_ERR_NO_MEMORY;
+        goto err;
     }
 
-    rndv_req->send.ep = ucp_worker_get_ep_by_id(worker,
-                                                rndv_rts_hdr->sreq.ep_id);
     rndv_req->flags             = 0;
+    rndv_req->send.ep           = ep;
     rndv_req->send.mdesc        = NULL;
     rndv_req->send.pending_lane = UCP_NULL_LANE;
     is_get_zcopy_failed         = 0;
@@ -1258,7 +1268,6 @@ UCS_PROFILE_FUNC_VOID(ucp_rndv_receive, (worker, rreq, rndv_rts_hdr, rkey_buf),
     }
 
     /* if the receive side is not connected yet then the RTS was received on a stub ep */
-    ep        = rndv_req->send.ep;
     ep_config = ucp_ep_config(ep);
     rndv_mode = worker->context->config.ext.rndv_mode;
 
@@ -1330,6 +1339,11 @@ UCS_PROFILE_FUNC_VOID(ucp_rndv_receive, (worker, rreq, rndv_rts_hdr, rkey_buf),
 
 out:
     UCS_ASYNC_UNBLOCK(&worker->async);
+    return;
+
+err:
+    ucp_rndv_recv_req_complete(rreq, status);
+    goto out;
 }
 
 UCS_PROFILE_FUNC(ucs_status_t, ucp_rndv_rts_handler,
@@ -1687,8 +1701,9 @@ UCS_PROFILE_FUNC(ucs_status_t, ucp_rndv_atp_handler,
                  void *arg, void *data, size_t length, unsigned flags)
 {
     ucp_reply_hdr_t *rep_hdr = data;
-    ucp_request_t *req       = ucp_worker_get_request_by_id(arg,
-                                                            rep_hdr->req_id);
+    ucp_request_t *req       = UCP_WORKER_GET_REQ_BY_ID(arg, rep_hdr->req_id,
+                                                        return UCS_OK,
+                                                        "RNDV ATP %p", rep_hdr);
 
     if (req->flags & UCP_REQUEST_FLAG_RNDV_FRAG) {
         /* received ATP for frag RTR request */
@@ -1711,15 +1726,20 @@ UCS_PROFILE_FUNC(ucs_status_t, ucp_rndv_rtr_handler,
                  void *arg, void *data, size_t length, unsigned flags)
 {
     ucp_worker_h worker              = arg;
+    ucp_context_h context            = worker->context;
     ucp_rndv_rtr_hdr_t *rndv_rtr_hdr = data;
-    ucp_request_t *sreq              = ucp_worker_get_request_by_id(worker,
-                                                                    rndv_rtr_hdr->sreq_id);
-    ucp_ep_h ep                      = sreq->send.ep;
-    ucp_ep_config_t *ep_config       = ucp_ep_config(ep);
-    ucp_context_h context            = ep->worker->context;
+    ucp_request_t *sreq;
+    ucp_ep_h ep;
+    ucp_ep_config_t *ep_config;
     ucs_status_t status;
     int is_pipeline_rndv;
 
+    sreq      = UCP_WORKER_GET_REQ_BY_ID(arg, rndv_rtr_hdr->sreq_id,
+                                         return UCS_OK, "RNDV RTR %p",
+                                         rndv_rtr_hdr);
+    ep        = sreq->send.ep;
+    ep_config = ucp_ep_config(ep);
+
     ucs_assertv(rndv_rtr_hdr->sreq_id == sreq->send.msg_proto.sreq_id,
                 "received local sreq_id 0x%"PRIx64" is not equal to expected sreq_id"
                 " 0x%"PRIx64, rndv_rtr_hdr->sreq_id, sreq->send.msg_proto.sreq_id);
@@ -1746,7 +1766,7 @@ UCS_PROFILE_FUNC(ucs_status_t, ucp_rndv_rtr_handler,
                            (sreq->send.length != rndv_rtr_hdr->size)) &&
                           (context->config.ext.rndv_mode != UCP_RNDV_MODE_PUT_ZCOPY));
 
-    sreq->send.lane = ucp_rkey_find_rma_lane(ep->worker->context, ep_config,
+    sreq->send.lane = ucp_rkey_find_rma_lane(context, ep_config,
                                              (is_pipeline_rndv ?
                                               sreq->send.rndv_put.rkey->mem_type :
                                               sreq->send.mem_type),
@@ -1832,7 +1852,10 @@ UCS_PROFILE_FUNC(ucs_status_t, ucp_rndv_data_handler,
     ucp_request_t *rreq;
     size_t recv_len;
 
-    rreq = ucp_worker_get_request_by_id(worker, rndv_data_hdr->rreq_id);
+    rreq = UCP_WORKER_GET_REQ_BY_ID(worker, rndv_data_hdr->rreq_id,
+                                    return UCS_OK, "RNDV data %p",
+                                    rndv_data_hdr);
+
     ucs_assert(!(rreq->flags & UCP_REQUEST_FLAG_RNDV_FRAG) &&
                 (rreq->flags & (UCP_REQUEST_FLAG_RECV_AM |
                                 UCP_REQUEST_FLAG_RECV_TAG)));
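Note: ucp_rndv_receive() above gains a dedicated err label: each failure path
records a status, completes the receive request with it, and then reuses the
common unlock path via "goto out". A self-contained C sketch of this
error-label idiom (generic names and statuses, not UCX code):

#include <stdio.h>

typedef enum { ST_OK = 0, ST_CANCELED = -1, ST_NO_MEMORY = -2 } status_t;

static void lock_async(void)               { printf("lock\n"); }
static void unlock_async(void)             { printf("unlock\n"); }
static void complete_rreq(status_t status) { printf("rreq done: %d\n", status); }

static void receive(int ep_ok, int mem_ok)
{
    status_t status;

    lock_async();
    if (!ep_ok) {               /* EP already closed: drop the RTS */
        status = ST_CANCELED;
        goto err;
    }
    if (!mem_ok) {              /* request pool exhausted */
        status = ST_NO_MEMORY;
        goto err;
    }
    /* ... normal rendezvous processing ... */
out:
    unlock_async();             /* single unlock point for every path */
    return;
err:
    complete_rreq(status);      /* the receive request finishes with an error */
    goto out;
}

int main(void)
{
    receive(1, 1); /* normal path   */
    receive(0, 1); /* canceled path */
    return 0;
}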
diff --git a/src/ucp/stream/stream_recv.c b/src/ucp/stream/stream_recv.c
index 7836b3fe256..2dd2cd97efe 100644
--- a/src/ucp/stream/stream_recv.c
+++ b/src/ucp/stream/stream_recv.c
@@ -528,7 +528,8 @@ ucp_stream_am_handler(void *am_arg, void *am_data, size_t am_length,
 
     ucs_assert(am_length >= sizeof(ucp_stream_am_hdr_t));
 
-    ep     = ucp_worker_get_ep_by_id(worker, data->hdr.ep_id);
+    ep     = UCP_WORKER_GET_VALID_EP_BY_ID(worker, data->hdr.ep_id, return UCS_OK,
+                                           "stream data");
     ep_ext = ucp_ep_ext_proto(ep);
 
     if (ucs_unlikely(ep->flags & (UCP_EP_FLAG_CLOSED |
diff --git a/src/ucp/tag/eager_snd.c b/src/ucp/tag/eager_snd.c
index d764a862a3d..59a3a4a1fc3 100644
--- a/src/ucp/tag/eager_snd.c
+++ b/src/ucp/tag/eager_snd.c
@@ -300,6 +300,7 @@ void ucp_tag_eager_sync_send_ack(ucp_worker_h worker, void *hdr, uint16_t recv_f
 {
     ucp_request_hdr_t *reqhdr;
     ucp_request_t *req;
+    ucp_ep_h ep;
 
     ucs_assert(recv_flags & UCP_RECV_DESC_FLAG_EAGER_SYNC);
 
@@ -317,7 +318,10 @@ void ucp_tag_eager_sync_send_ack(ucp_worker_h worker, void *hdr, uint16_t recv_f
     }
 
     ucs_assert(reqhdr->req_id != UCP_REQUEST_ID_INVALID);
-    req = ucp_proto_ssend_ack_request_alloc(worker, reqhdr->ep_id);
+    ep  = UCP_WORKER_GET_VALID_EP_BY_ID(worker, reqhdr->ep_id, return,
+                                        "ACK for sync-send");
+
+    req = ucp_proto_ssend_ack_request_alloc(worker, ep);
     if (req == NULL) {
         ucs_fatal("could not allocate request");
     }
diff --git a/src/ucp/tag/offload.c b/src/ucp/tag/offload.c
index bd7103d4606..85f8343ba91 100644
--- a/src/ucp/tag/offload.c
+++ b/src/ucp/tag/offload.c
@@ -724,10 +724,14 @@ void ucp_tag_offload_sync_send_ack(ucp_worker_h worker, ucs_ptr_map_key_t ep_id,
                                    ucp_tag_t stag, uint16_t recv_flags)
 {
     ucp_request_t *req;
+    ucp_ep_h ep;
 
     ucs_assert(recv_flags & UCP_RECV_DESC_FLAG_EAGER_OFFLOAD);
 
-    req = ucp_proto_ssend_ack_request_alloc(worker, ep_id);
+    ep  = UCP_WORKER_GET_VALID_EP_BY_ID(worker, ep_id, return,
+                                        "ACK for sync-send");
+
+    req = ucp_proto_ssend_ack_request_alloc(worker, ep);
     if (req == NULL) {
         ucs_fatal("could not allocate request");
     }
diff --git a/src/ucp/wireup/wireup.c b/src/ucp/wireup/wireup.c
index 6e127eda084..51b6cc595af 100644
--- a/src/ucp/wireup/wireup.c
+++ b/src/ucp/wireup/wireup.c
@@ -372,26 +372,24 @@ ucp_wireup_init_lanes_by_request(ucp_worker_h worker, ucp_ep_h ep,
 }
 
 static UCS_F_NOINLINE void
-ucp_wireup_process_pre_request(ucp_worker_h worker, const ucp_wireup_msg_t *msg,
+ucp_wireup_process_pre_request(ucp_worker_h worker, ucp_ep_h ep,
+                               const ucp_wireup_msg_t *msg,
                                const ucp_unpacked_address_t *remote_address)
 {
     unsigned ep_init_flags = UCP_EP_INIT_CREATE_AM_LANE;
     unsigned addr_indices[UCP_MAX_LANES];
     ucs_status_t status;
-    ucp_ep_h ep;
 
     ucs_assert(msg->type == UCP_WIREUP_MSG_PRE_REQUEST);
     ucs_assert(msg->dst_ep_id != UCP_EP_ID_INVALID);
+    ucs_assert(ep != NULL);
+    ucs_assert((ep->flags & UCP_EP_FLAG_SOCKADDR_PARTIAL_ADDR) ||
+               ucp_ep_has_cm_lane(ep));
 
     ucs_trace("got wireup pre_request from 0x%"PRIx64" src_ep_id 0x%"PRIx64
               " dst_ep_id 0x%"PRIx64" conn_sn %u",
               remote_address->uuid, msg->src_ep_id, msg->dst_ep_id, msg->conn_sn);
 
-    /* wireup pre_request for a specific ep */
-    ep = ucp_worker_get_ep_by_id(worker, msg->dst_ep_id);
-    ucs_assert((ep->flags & UCP_EP_FLAG_SOCKADDR_PARTIAL_ADDR) ||
-               ucp_ep_has_cm_lane(ep));
-
     ucp_ep_update_remote_id(ep, msg->src_ep_id);
     ucp_ep_flush_state_reset(ep);
 
@@ -417,7 +415,8 @@ ucp_wireup_process_pre_request(ucp_worker_h worker, const ucp_wireup_msg_t *msg,
 }
 
 static UCS_F_NOINLINE void
-ucp_wireup_process_request(ucp_worker_h worker, const ucp_wireup_msg_t *msg,
+ucp_wireup_process_request(ucp_worker_h worker, ucp_ep_h ep,
+                           const ucp_wireup_msg_t *msg,
                            const ucp_unpacked_address_t *remote_address)
 {
     uint64_t remote_uuid = remote_address->uuid;
@@ -427,7 +426,6 @@ ucp_wireup_process_request(ucp_worker_h worker, const ucp_wireup_msg_t *msg,
     ucp_rsc_index_t lanes2remote[UCP_MAX_LANES];
     unsigned addr_indices[UCP_MAX_LANES];
     ucs_status_t status;
-    ucp_ep_h ep;
     int has_cm_lane;
 
     ucs_assert(msg->type == UCP_WIREUP_MSG_REQUEST);
@@ -435,13 +433,13 @@ ucp_wireup_process_request(ucp_worker_h worker, const ucp_wireup_msg_t *msg,
               " dst_ep_id 0x%"PRIx64" conn_sn %d",
               remote_address->uuid, msg->src_ep_id, msg->dst_ep_id, msg->conn_sn);
 
-    if (msg->dst_ep_id != UCP_EP_ID_INVALID) {
-        /* wireup request for a specific ep */
-        ep = ucp_worker_get_ep_by_id(worker, msg->dst_ep_id);
+    if (ep != NULL) {
+        ucs_assert(msg->dst_ep_id != UCP_EP_ID_INVALID);
         ucp_ep_update_remote_id(ep, msg->src_ep_id);
         ucp_ep_flush_state_reset(ep);
         ep_init_flags |= UCP_EP_INIT_CREATE_AM_LANE;
     } else {
+        ucs_assert(msg->dst_ep_id == UCP_EP_ID_INVALID);
         ep = ucp_ep_match_retrieve(worker, remote_uuid,
                                    msg->conn_sn ^
                                    (remote_uuid == worker->uuid),
@@ -560,17 +558,16 @@ int ucp_wireup_msg_ack_cb_pred(const ucs_callbackq_elem_t *elem, void *arg)
 }
 
 static UCS_F_NOINLINE void
-ucp_wireup_process_reply(ucp_worker_h worker, const ucp_wireup_msg_t *msg,
+ucp_wireup_process_reply(ucp_worker_h worker, ucp_ep_h ep,
+                         const ucp_wireup_msg_t *msg,
                          const ucp_unpacked_address_t *remote_address)
 {
     uct_worker_cb_id_t cb_id = UCS_CALLBACKQ_ID_NULL;
     ucs_status_t status;
-    ucp_ep_h ep;
    int ack;
 
-    ep = ucp_worker_get_ep_by_id(worker, msg->dst_ep_id);
-
     ucs_assert(msg->type == UCP_WIREUP_MSG_REPLY);
+    ucs_assert(ep != NULL);
 
     ucs_trace("ep %p: got wireup reply src_ep_id 0x%"PRIx64
               " dst_ep_id 0x%"PRIx64" sn %d", ep, msg->src_ep_id,
               msg->dst_ep_id, msg->conn_sn);
 
@@ -611,11 +608,10 @@ ucp_wireup_process_reply(ucp_worker_h worker, const ucp_wireup_msg_t *msg,
 }
 
 static UCS_F_NOINLINE
-void ucp_wireup_process_ack(ucp_worker_h worker, const ucp_wireup_msg_t *msg)
+void ucp_wireup_process_ack(ucp_worker_h worker, ucp_ep_h ep,
+                            const ucp_wireup_msg_t *msg)
 {
-    ucp_ep_h ep;
-
-    ep = ucp_worker_get_ep_by_id(worker, msg->dst_ep_id);
+    ucs_assert(ep != NULL);
     ucs_assert(msg->type == UCP_WIREUP_MSG_ACK);
 
     ucs_trace("ep %p: got wireup ack", ep);
@@ -639,21 +635,17 @@ static ucs_status_t ucp_wireup_msg_handler(void *arg, void *data,
 {
     ucp_worker_h worker   = arg;
     ucp_wireup_msg_t *msg = data;
+    ucp_ep_h ep           = NULL;
     ucp_unpacked_address_t remote_address;
-    ucp_ep_h ep UCS_V_UNUSED;
     ucs_status_t status;
 
     UCS_ASYNC_BLOCK(&worker->async);
 
     if (msg->dst_ep_id != UCP_EP_ID_INVALID) {
-        ep = ucp_worker_get_ep_by_id(worker, msg->dst_ep_id);
-        if (ep == NULL) {
-            ucs_diag("got wireup msg %d src_ep_id 0x%"PRIx64" for"
-                     " non-existing dst_ep_id 0x%"PRIx64" sn %d,"
-                     " ignoring it",
-                     msg->type, msg->src_ep_id, msg->dst_ep_id, msg->conn_sn);
-            goto out;
-        }
+        ep = UCP_WORKER_GET_EP_BY_ID(worker, msg->dst_ep_id, goto out,
+                                     "WIREUP message (%d src_ep_id 0x%"PRIx64
+                                     " sn %d)", msg->type, msg->src_ep_id,
+                                     msg->conn_sn);
     }
 
     status = ucp_address_unpack(worker, msg + 1,
@@ -666,13 +658,13 @@ static ucs_status_t ucp_wireup_msg_handler(void *arg, void *data,
 
     if (msg->type == UCP_WIREUP_MSG_ACK) {
         ucs_assert(remote_address.address_count == 0);
-        ucp_wireup_process_ack(worker, msg);
+        ucp_wireup_process_ack(worker, ep, msg);
     } else if (msg->type == UCP_WIREUP_MSG_PRE_REQUEST) {
-        ucp_wireup_process_pre_request(worker, msg, &remote_address);
+        ucp_wireup_process_pre_request(worker, ep, msg, &remote_address);
     } else if (msg->type == UCP_WIREUP_MSG_REQUEST) {
-        ucp_wireup_process_request(worker, msg, &remote_address);
+        ucp_wireup_process_request(worker, ep, msg, &remote_address);
     } else if (msg->type == UCP_WIREUP_MSG_REPLY) {
-        ucp_wireup_process_reply(worker, msg, &remote_address);
+        ucp_wireup_process_reply(worker, ep, msg, &remote_address);
     } else {
         ucs_bug("invalid wireup message");
     }
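Note: in wireup.c above, the per-message-type functions no longer perform their
own EP lookups; ucp_wireup_msg_handler() resolves the EP once (leaving it NULL
when dst_ep_id is UCP_EP_ID_INVALID) and passes it down, so each callee only
asserts the invariant it expects. A minimal C sketch of this hoisted-lookup
dispatch (hypothetical names, not the UCX implementation):

#include <assert.h>
#include <stddef.h>

#define EP_ID_INVALID -1L
enum { MSG_REQUEST, MSG_REPLY };

typedef struct { int type; long dst_ep_id; } msg_t;
typedef struct { long id; } ep_t;

static ep_t *lookup_ep(long id)
{
    static ep_t ep;
    ep.id = id;
    return &ep; /* a real table would return NULL for unknown IDs */
}

/* a REQUEST may arrive before the EP exists, so ep may be NULL here */
static void process_request(ep_t *ep, const msg_t *msg)
{
    if (ep == NULL) {
        assert(msg->dst_ep_id == EP_ID_INVALID);
        /* match or create the EP here */
    }
    (void)msg;
}

/* a REPLY always targets an existing EP */
static void process_reply(ep_t *ep, const msg_t *msg)
{
    assert(ep != NULL);
    (void)msg;
}

static void msg_handler(const msg_t *msg)
{
    ep_t *ep = NULL;

    if (msg->dst_ep_id != EP_ID_INVALID) {
        ep = lookup_ep(msg->dst_ep_id); /* single lookup for all types */
        if (ep == NULL) {
            return;                     /* unknown EP: drop the message */
        }
    }

    if (msg->type == MSG_REQUEST) {
        process_request(ep, msg);
    } else {
        process_reply(ep, msg);
    }
}

int main(void)
{
    msg_t msg = {MSG_REQUEST, EP_ID_INVALID};
    msg_handler(&msg);
    return 0;
}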
diff --git a/test/gtest/ucp/test_ucp_am.cc b/test/gtest/ucp/test_ucp_am.cc
index 4b2cc22280f..afe7f860e39 100644
--- a/test/gtest/ucp/test_ucp_am.cc
+++ b/test/gtest/ucp/test_ucp_am.cc
@@ -304,6 +304,7 @@ class test_ucp_am_nbx : public test_ucp_am_base {
         m_am_received = false;
     }
 
+protected:
     size_t max_am_hdr()
     {
         ucp_worker_attr_t attr;
@@ -420,53 +421,6 @@ class test_ucp_am_nbx : public test_ucp_am_base {
         }
     }
 
-    void test_recv_on_closed_ep(size_t size, unsigned flags = 0,
-                                bool poke_rx_progress = false,
-                                bool rx_expected = false)
-    {
-        skip_loopback();
-        test_am_send_recv(0, max_am_hdr()); // warmup wireup
-
-        m_am_received = false;
-        std::vector<char> sbuf(size, 'd');
-        ucp::data_type_desc_t sdt_desc(m_dt, &sbuf[0], size);
-
-        set_am_data_handler(receiver(), TEST_AM_NBX_ID, am_rx_check_cb, this);
-
-        ucs_status_ptr_t sreq = send_am(sdt_desc, flags);
-
-        sender().progress();
-        if (poke_rx_progress) {
-            receiver().progress();
-            if (m_am_received) {
-                request_wait(sreq);
-                UCS_TEST_SKIP_R("received all AMs before ep closed");
-            }
-        }
-
-        void *close_req = receiver().disconnect_nb(0, 0,
-                                                   UCP_EP_CLOSE_MODE_FLUSH);
-        ucs_time_t deadline = ucs::get_deadline(10);
-        while (!is_request_completed(close_req) &&
-               (ucs_get_time() < deadline)) {
-            progress();
-        };
-
-        receiver().close_ep_req_free(close_req);
-
-        if (rx_expected) {
-            request_wait(sreq);
-            wait_for_flag(&m_am_received);
-        } else {
-            // Send request may complete with error
-            // (rndv should complete with EP_TIMEOUT)
-            scoped_log_handler wrap_err(wrap_errors_logger);
-            request_wait(sreq);
-        }
-
-        EXPECT_EQ(rx_expected, m_am_received);
-    }
-
     virtual ucs_status_t am_data_handler(const void *header,
                                          size_t header_length,
                                          void *data, size_t length,
@@ -595,7 +549,92 @@ UCS_TEST_P(test_ucp_am_nbx, zero_send)
     test_am_send_recv(0, max_am_hdr());
 }
 
-UCS_TEST_P(test_ucp_am_nbx, rx_short_am_on_closed_ep, "RNDV_THRESH=inf")
+UCS_TEST_P(test_ucp_am_nbx, rx_persistent_data)
+{
+    void *rx_data = NULL;
+    char data     = 'd';
+
+    set_am_data_handler(receiver(), TEST_AM_NBX_ID, am_data_hold_cb, &rx_data,
+                        UCP_AM_FLAG_PERSISTENT_DATA);
+
+    ucp_request_param_t param;
+
+    param.op_attr_mask    = 0ul;
+    ucs_status_ptr_t sptr = ucp_am_send_nbx(sender().ep(), TEST_AM_NBX_ID, NULL,
+                                            0ul, &data, sizeof(data), &param);
+    wait_for_flag(&rx_data);
+    EXPECT_TRUE(rx_data != NULL);
+    EXPECT_EQ(data, *reinterpret_cast<char*>(rx_data));
+
+    ucp_am_data_release(receiver().worker(), rx_data);
+    EXPECT_EQ(UCS_OK, request_wait(sptr));
+}
+
+UCP_INSTANTIATE_TEST_CASE(test_ucp_am_nbx)
+
+
+class test_ucp_am_nbx_closed_ep : public test_ucp_am_nbx {
+protected:
+    virtual ucp_ep_params_t get_ep_params()
+    {
+        ucp_ep_params_t ep_params = test_ucp_am_nbx::get_ep_params();
+        ep_params.field_mask     |= UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE;
+        /* peer error handling mode is required, because the test covers the
+         * case when a receiver tries to fetch data on a closed EP */
+        ep_params.err_mode        = UCP_ERR_HANDLING_MODE_PEER;
+        return ep_params;
+    }
+
+    void test_recv_on_closed_ep(size_t size, unsigned flags = 0,
+                                bool poke_rx_progress = false,
+                                bool rx_expected = false)
+    {
+        skip_loopback();
+        test_am_send_recv(0, max_am_hdr()); // warmup wireup
+
+        m_am_received = false;
+        std::vector<char> sbuf(size, 'd');
+        ucp::data_type_desc_t sdt_desc(m_dt, &sbuf[0], size);
+
+        set_am_data_handler(receiver(), TEST_AM_NBX_ID, am_rx_check_cb, this);
+
+        ucs_status_ptr_t sreq = send_am(sdt_desc, flags);
+
+        sender().progress();
+        if (poke_rx_progress) {
+            receiver().progress();
+            if (m_am_received) {
+                request_wait(sreq);
+                UCS_TEST_SKIP_R("received all AMs before ep closed");
+            }
+        }
+
+        void *close_req = receiver().disconnect_nb(0, 0,
+                                                   UCP_EP_CLOSE_MODE_FLUSH);
+        ucs_time_t deadline = ucs::get_deadline(10);
+        while (!is_request_completed(close_req) &&
+               (ucs_get_time() < deadline)) {
+            progress();
+        };
+
+        receiver().close_ep_req_free(close_req);
+
+        if (rx_expected) {
+            request_wait(sreq);
+            wait_for_flag(&m_am_received);
+        } else {
+            // Send request may complete with error
+            // (rndv should complete with EP_TIMEOUT)
+            scoped_log_handler wrap_err(wrap_errors_logger);
+            request_wait(sreq);
+        }
+
+        EXPECT_EQ(rx_expected, m_am_received);
+    }
+};
+
+
+UCS_TEST_P(test_ucp_am_nbx_closed_ep, rx_short_am_on_closed_ep, "RNDV_THRESH=inf")
 {
     // Single fragment message sent without REPLY flag is expected
     // to be received even if remote side closes its ep
@@ -604,53 +643,32 @@ UCS_TEST_P(test_ucp_am_nbx, rx_short_am_on_closed_ep, "RNDV_THRESH=inf")
 
 // All the following type of AM messages are expected to be dropped on the
 // receiver side, when its ep is closed
-UCS_TEST_P(test_ucp_am_nbx, rx_short_reply_am_on_closed_ep, "RNDV_THRESH=inf")
+UCS_TEST_P(test_ucp_am_nbx_closed_ep, rx_short_reply_am_on_closed_ep, "RNDV_THRESH=inf")
 {
     test_recv_on_closed_ep(8, UCP_AM_SEND_REPLY);
 }
 
-UCS_TEST_P(test_ucp_am_nbx, rx_long_am_on_closed_ep, "RNDV_THRESH=inf")
+UCS_TEST_P(test_ucp_am_nbx_closed_ep, rx_long_am_on_closed_ep, "RNDV_THRESH=inf")
 {
     test_recv_on_closed_ep(64 * UCS_KBYTE, 0, true);
 }
 
-UCS_TEST_P(test_ucp_am_nbx, rx_long_reply_am_on_closed_ep, "RNDV_THRESH=inf")
+UCS_TEST_P(test_ucp_am_nbx_closed_ep, rx_long_reply_am_on_closed_ep, "RNDV_THRESH=inf")
 {
     test_recv_on_closed_ep(64 * UCS_KBYTE, UCP_AM_SEND_REPLY, true);
 }
 
-UCS_TEST_P(test_ucp_am_nbx, rx_rts_am_on_closed_ep, "RNDV_THRESH=32K")
+UCS_TEST_P(test_ucp_am_nbx_closed_ep, rx_rts_am_on_closed_ep, "RNDV_THRESH=32K")
 {
     test_recv_on_closed_ep(64 * UCS_KBYTE, 0);
 }
 
-UCS_TEST_P(test_ucp_am_nbx, rx_rts_reply_am_on_closed_ep, "RNDV_THRESH=32K")
+UCS_TEST_P(test_ucp_am_nbx_closed_ep, rx_rts_reply_am_on_closed_ep, "RNDV_THRESH=32K")
 {
     test_recv_on_closed_ep(64 * UCS_KBYTE, UCP_AM_SEND_REPLY);
 }
 
-UCS_TEST_P(test_ucp_am_nbx, rx_persistent_data)
-{
-    void *rx_data = NULL;
-    char data     = 'd';
-
-    set_am_data_handler(receiver(), TEST_AM_NBX_ID, am_data_hold_cb, &rx_data,
-                        UCP_AM_FLAG_PERSISTENT_DATA);
-
-    ucp_request_param_t param;
-
-    param.op_attr_mask    = 0ul;
-    ucs_status_ptr_t sptr = ucp_am_send_nbx(sender().ep(), TEST_AM_NBX_ID, NULL,
-                                            0ul, &data, sizeof(data), &param);
-    wait_for_flag(&rx_data);
-    EXPECT_TRUE(rx_data != NULL);
-    EXPECT_EQ(data, *reinterpret_cast<char*>(rx_data));
-
-    ucp_am_data_release(receiver().worker(), rx_data);
-    EXPECT_EQ(UCS_OK, request_wait(sptr));
-}
-
-UCP_INSTANTIATE_TEST_CASE(test_ucp_am_nbx)
+UCP_INSTANTIATE_TEST_CASE(test_ucp_am_nbx_closed_ep)
 
 
 class test_ucp_am_nbx_eager_memtype : public test_ucp_am_nbx {
diff --git a/test/gtest/ucp/test_ucp_peer_failure.cc b/test/gtest/ucp/test_ucp_peer_failure.cc
index 6ee8e126d87..a21b0f7b01a 100644
--- a/test/gtest/ucp/test_ucp_peer_failure.cc
+++ b/test/gtest/ucp/test_ucp_peer_failure.cc
@@ -103,10 +103,7 @@ ucp_ep_params_t test_ucp_peer_failure::get_ep_params() {
 }
 
 void test_ucp_peer_failure::set_timeouts() {
-    /* Set small TL timeouts to reduce testing time */
-    m_env.push_back(new ucs::scoped_setenv("UCX_RC_TIMEOUT", "10ms"));
-    m_env.push_back(new ucs::scoped_setenv("UCX_RC_RNR_TIMEOUT", "10ms"));
-    m_env.push_back(new ucs::scoped_setenv("UCX_RC_RETRY_COUNT", "2"));
+    set_tl_timeouts(m_env);
 }
 
 void test_ucp_peer_failure::err_cb(void *arg, ucp_ep_h ep, ucs_status_t status) {
diff --git a/test/gtest/ucp/test_ucp_sockaddr.cc b/test/gtest/ucp/test_ucp_sockaddr.cc
index 4d32f0ae688..459c257bd99 100644
--- a/test/gtest/ucp/test_ucp_sockaddr.cc
+++ b/test/gtest/ucp/test_ucp_sockaddr.cc
@@ -17,6 +17,7 @@ extern "C" {
 #include 
 #include 
 #include 
+#include 
 #include 
 }
 
@@ -204,6 +205,16 @@ class test_ucp_sockaddr : public ucp_test {
         UCS_TEST_MESSAGE << "server listening on " << m_test_addr.to_str();
     }
 
+    static void complete_err_handling_status_verify(ucs_status_t status)
+    {
+        EXPECT_TRUE(/* was successful */
+                    (status == UCS_OK) ||
+                    /* completed from error handling for EP */
+                    (status == UCS_ERR_ENDPOINT_TIMEOUT) ||
+                    (status == UCS_ERR_CONNECTION_RESET) ||
+                    (status == UCS_ERR_CANCELED));
+    }
+
     static void scomplete_cb(void *req, ucs_status_t status)
     {
         if ((status == UCS_OK) ||
@@ -214,12 +225,23 @@ class test_ucp_sockaddr : public ucp_test {
         UCS_TEST_ABORT("Error: " << ucs_status_string(status));
     }
 
+    static void scomplete_err_handling_cb(void *req, ucs_status_t status)
+    {
+        complete_err_handling_status_verify(status);
+    }
+
     static void rtag_complete_cb(void *req, ucs_status_t status,
                                  ucp_tag_recv_info_t *info)
     {
         EXPECT_UCS_OK(status);
     }
 
+    static void rtag_complete_err_handling_cb(void *req, ucs_status_t status,
+                                              ucp_tag_recv_info_t *info)
+    {
+        complete_err_handling_status_verify(status);
+    }
+
     static void rstream_complete_cb(void *req, ucs_status_t status,
                                     size_t length)
     {
@@ -394,7 +416,7 @@ class test_ucp_sockaddr : public ucp_test {
          * handle a large worker address but neither ud nor ud_x are present */
         ep_params.err_mode         = UCP_ERR_HANDLING_MODE_PEER;
         ep_params.err_handler.cb   = err_handler_cb;
-        ep_params.err_handler.arg  = NULL;
+        ep_params.err_handler.arg  = this;
         return ep_params;
     }
 
@@ -521,6 +543,7 @@ class test_ucp_sockaddr : public ucp_test {
         case UCS_ERR_UNREACHABLE:
         case UCS_ERR_CONNECTION_RESET:
         case UCS_ERR_NOT_CONNECTED:
+        case UCS_ERR_ENDPOINT_TIMEOUT:
             UCS_TEST_MESSAGE << "ignoring error " << ucs_status_string(status)
                              << " on endpoint " << ep;
             return;
@@ -813,10 +836,7 @@ UCP_INSTANTIATE_ALL_TEST_CASE(test_ucp_sockaddr)
 class test_ucp_sockaddr_destroy_ep_on_err : public test_ucp_sockaddr {
 public:
     test_ucp_sockaddr_destroy_ep_on_err() {
-        /* Set small TL timeouts to reduce testing time */
-        m_env.push_back(new ucs::scoped_setenv("UCX_RC_TIMEOUT", "10ms"));
-        m_env.push_back(new ucs::scoped_setenv("UCX_RC_RNR_TIMEOUT", "10ms"));
-        m_env.push_back(new ucs::scoped_setenv("UCX_RC_RETRY_COUNT", "2"));
+        set_tl_timeouts(m_env);
     }
 
     virtual ucp_ep_params_t get_server_ep_params() {
@@ -990,6 +1010,8 @@ UCP_INSTANTIATE_ALL_TEST_CASE(test_ucp_sockaddr_with_rma_atomic)
 
 class test_ucp_sockaddr_protocols : public test_ucp_sockaddr {
 public:
+    virtual ~test_ucp_sockaddr_protocols() { }
+
     static void get_test_variants(std::vector<ucp_test_variant>& variants) {
         /* Atomics not supported for now because need to emulate the case
          * of using different device than the one selected by default on the
@@ -1033,42 +1055,149 @@ class test_ucp_sockaddr_protocols : public test_ucp_sockaddr {
                 << "recv_buf: '" << ucs::compact_string(recv_buf, 20) << "'";
     }
 
-    void test_tag_send_recv(size_t size, bool is_exp, bool is_sync = false)
+    typedef void (*stop_cb_t)(void *arg);
+
+    void sreq_mem_dereg(void *sreq) {
+        if (sreq == NULL) {
+            /* if the send request is NULL, the send operation completed
+             * immediately */
+            return;
+        }
+
+        /* TODO: remove memory deregistration, when UCP requests tracking is added */
+        ucp_request_t *req = static_cast<ucp_request_t*>(sreq) - 1;
+        EXPECT_EQ(sender().ucph(), req->send.ep->worker->context);
+        ucp_request_memory_dereg(sender().ucph(), req->send.datatype,
+                                 &req->send.state.dt, req);
+    }
+
+    void* do_unexp_recv(std::string &recv_buf, size_t size, void *sreq,
+                        bool err_handling, bool send_stop, bool recv_stop) {
+        ucp_tag_recv_info_t recv_info = {};
+        ucp_tag_message_h message;
+
+        do {
+            short_progress_loop();
+            message = ucp_tag_probe_nb(receiver().worker(),
+                                       0, 0, 1, &recv_info);
+        } while (message == NULL);
+
+        EXPECT_EQ(size, recv_info.length);
+        EXPECT_EQ(0, recv_info.sender_tag);
+
+        if (err_handling) {
+            if (recv_stop) {
+                disconnect(*this, receiver());
+            }
+
+            sreq_mem_dereg(sreq);
+
+            if (recv_stop) {
+                sender().disconnect_nb(0, 0, UCP_EP_CLOSE_MODE_FORCE);
+            } else {
+                disconnect(*this, sender());
+                receiver().disconnect_nb(0, 0, UCP_EP_CLOSE_MODE_FORCE);
+            }
+        }
+
+        ucp_request_param_t recv_param = {};
+        recv_param.op_attr_mask        = UCP_OP_ATTR_FIELD_CALLBACK;
+        /* TODO: remove casting when changed to using NBX API */
+        recv_param.cb.recv = reinterpret_cast<ucp_tag_recv_nbx_callback_t>(
+                !err_handling ? rtag_complete_cb :
+                rtag_complete_err_handling_cb);
+        return ucp_tag_msg_recv_nbx(receiver().worker(), &recv_buf[0], size, message,
+                                    &recv_param);
+    }
+
+    void sreq_release(void *sreq) {
+        if ((sreq == NULL) || !UCS_PTR_IS_PTR(sreq)) {
+            return;
+        }
+
+        if (ucp_request_check_status(sreq) == UCS_INPROGRESS) {
+            ucp_request_t *req = (ucp_request_t*)sreq - 1;
+            req->flags        |= UCP_REQUEST_FLAG_COMPLETED;
+
+            ucp_request_t *req_from_id =
+                    ucp_worker_get_request_by_id(sender().worker(),
+                                                 req->send.msg_proto.sreq_id);
+            if (req_from_id != NULL) {
+                EXPECT_EQ(req, req_from_id);
+                /* check the PTR MAP flag only this way, since it is a
+                 * debug-only flag and has the value 0 in release mode */
+                EXPECT_TRUE((req->flags & UCP_REQUEST_FLAG_IN_PTR_MAP) ==
+                            UCP_REQUEST_FLAG_IN_PTR_MAP);
+                ucp_worker_del_request_id(sender().worker(), req,
+                                          req->send.msg_proto.sreq_id);
+            }
+        }
+
+        ucp_request_release(sreq);
+    }
+
+    void test_tag_send_recv(size_t size, bool is_exp, bool is_sync = false,
+                            bool send_stop = false, bool recv_stop = false)
     {
+        bool err_handling_test = send_stop || recv_stop;
+        unsigned num_iters     = err_handling_test ? 1 : m_num_iters;
+
         /* send multiple messages to test the protocol both before and after
          * connection establishment */
-        for (int i = 0; i < m_num_iters; i++) {
+        for (int i = 0; i < num_iters; i++) {
             std::string send_buf(size, 'x');
             std::string recv_buf(size, 'y');
 
             void *rreq = NULL, *sreq = NULL;
+            std::vector<void*> reqs;
+
+            ucs::auto_ptr<scoped_log_handler> slh;
+            if (err_handling_test) {
+                slh.reset(new scoped_log_handler(wrap_errors_logger));
+            }
 
             if (is_exp) {
-                rreq = ucp_tag_recv_nb(receiver().worker(), &recv_buf[0],
-                                       size, ucp_dt_make_contig(1), 0, 0,
+                rreq = ucp_tag_recv_nb(receiver().worker(), &recv_buf[0], size,
+                                       ucp_dt_make_contig(1), 0, 0,
                                        rtag_complete_cb);
+                reqs.push_back(rreq);
             }
 
+            ucp_request_param_t send_param = {};
+            send_param.op_attr_mask        = UCP_OP_ATTR_FIELD_CALLBACK;
+            /* TODO: remove casting when changed to using NBX API */
+            send_param.cb.send = reinterpret_cast<ucp_send_nbx_callback_t>(
+                    !err_handling_test ? scomplete_cb :
+                    scomplete_err_handling_cb);
             if (is_sync) {
-                sreq = ucp_tag_send_sync_nb(sender().ep(), &send_buf[0], size,
-                                            ucp_dt_make_contig(1), 0,
-                                            scomplete_cb);
+                sreq = ucp_tag_send_sync_nbx(sender().ep(), &send_buf[0], size, 0,
+                                             &send_param);
             } else {
-                sreq = ucp_tag_send_nb(sender().ep(), &send_buf[0], size,
-                                       ucp_dt_make_contig(1), 0, scomplete_cb);
+                sreq = ucp_tag_send_nbx(sender().ep(), &send_buf[0], size, 0,
+                                        &send_param);
             }
+            reqs.push_back(sreq);
 
             if (!is_exp) {
-                short_progress_loop();
-                rreq = ucp_tag_recv_nb(receiver().worker(), &recv_buf[0], size,
-                                       ucp_dt_make_contig(1), 0, 0,
-                                       rtag_complete_cb);
+                rreq = do_unexp_recv(recv_buf, size, sreq, err_handling_test,
+                                     send_stop, recv_stop);
+                reqs.push_back(rreq);
             }
 
-            request_wait(sreq);
-            request_wait(rreq);
+            if (!err_handling_test) {
+                requests_wait(reqs);
+            } else {
+                /* TODO: add waiting for send request completion, when UCP requests
+                 * tracking is added */
+                sreq_release(sreq);
+                request_wait(rreq);
+            }
 
-            compare_buffers(send_buf, recv_buf);
+            if (!err_handling_test) {
+                compare_buffers(send_buf, recv_buf);
+            }
         }
     }
 
@@ -1222,6 +1351,19 @@ class test_ucp_sockaddr_protocols : public test_ucp_sockaddr {
         ASSERT_UCS_OK(ucp_worker_set_am_recv_handler(e.worker(), &param));
     }
 
+protected:
+    enum {
+        SEND_STOP = UCS_BIT(0),
+        RECV_STOP = UCS_BIT(1)
+    };
+
+    static void disconnect(test_ucp_sockaddr_protocols &test, entity &e) {
+        test.one_sided_disconnect(e, UCP_EP_CLOSE_MODE_FORCE);
+        while (m_err_count == 0) {
+            test.short_progress_loop();
+        }
+    }
+
 private:
     static const unsigned m_num_iters;
 };
@@ -1413,3 +1555,98 @@ UCS_TEST_P(test_ucp_sockaddr_protocols, am_zcopy_64k,
 UCP_INSTANTIATE_TEST_CASE_TLS(_test_case, all, "all")
 
 UCP_INSTANTIATE_CM_TEST_CASE(test_ucp_sockaddr_protocols)
+
+
+class test_ucp_sockaddr_protocols_err : public test_ucp_sockaddr_protocols {
+public:
+    static void get_test_variants(std::vector<ucp_test_variant>& variants) {
+        uint64_t features = UCP_FEATURE_TAG;
+        test_ucp_sockaddr::get_test_variants_mt(variants, features, SEND_STOP,
+                                                "send_stop");
+        test_ucp_sockaddr::get_test_variants_mt(variants, features, RECV_STOP,
+                                                "recv_stop");
+        test_ucp_sockaddr::get_test_variants_mt(variants, features,
+                                                SEND_STOP | RECV_STOP, "bidi_stop");
+    }
+
+protected:
+    test_ucp_sockaddr_protocols_err() {
+        set_tl_timeouts(m_env);
+    }
+
+    void init() {
+        test_ucp_sockaddr_protocols::init();
+    }
+
+    void test_tag_send_recv(size_t size, bool is_exp,
+                            bool is_sync = false) {
+        /* warmup */
+        test_ucp_sockaddr_protocols::test_tag_send_recv(size, is_exp, is_sync);
+
+        /* run error-handling test */
+        int variants = get_variant_value();
+        test_ucp_sockaddr_protocols::test_tag_send_recv(size, is_exp, is_sync,
+                                                        variants & SEND_STOP,
+                                                        variants & RECV_STOP);
+    }
+
+    void cleanup() {
+        test_ucp_sockaddr_protocols::cleanup();
+    }
+
+    static void err_handler_cb(void *arg, ucp_ep_h ep, ucs_status_t status) {
+        test_ucp_sockaddr::err_handler_cb(arg, ep, status);
+
+        test_ucp_sockaddr_protocols *test =
+                static_cast<test_ucp_sockaddr_protocols*>(arg);
+        if (test->sender().ep() == ep) {
+            test->sender().disconnect_nb(0, 0, UCP_EP_CLOSE_MODE_FORCE);
+        } else {
+            ASSERT_EQ(test->receiver().ep(), ep);
+            test->receiver().disconnect_nb(0, 0, UCP_EP_CLOSE_MODE_FORCE);
+        }
+    }
+
+protected:
+    ucs::ptr_vector<ucs::scoped_setenv> m_env;
+};
+
+
+UCS_TEST_P(test_ucp_sockaddr_protocols_err, tag_zcopy_4k_unexp,
+           "ZCOPY_THRESH=2k", "RNDV_THRESH=inf")
+{
+    test_tag_send_recv(4 * UCS_KBYTE, false, false);
+}
+
+UCS_TEST_P(test_ucp_sockaddr_protocols_err, tag_zcopy_64k_unexp,
+           "ZCOPY_THRESH=2k", "RNDV_THRESH=inf")
+{
+    test_tag_send_recv(64 * UCS_KBYTE, false, false);
+}
+
+UCS_TEST_P(test_ucp_sockaddr_protocols_err, tag_zcopy_4k_unexp_sync,
+           "ZCOPY_THRESH=2k", "RNDV_THRESH=inf")
+{
+    test_tag_send_recv(4 * UCS_KBYTE, false, true);
+}
+
+UCS_TEST_P(test_ucp_sockaddr_protocols_err, tag_zcopy_64k_unexp_sync,
+           "ZCOPY_THRESH=2k", "RNDV_THRESH=inf")
+{
+    test_tag_send_recv(64 * UCS_KBYTE, false, true);
+}
+
+UCS_TEST_P(test_ucp_sockaddr_protocols_err, tag_rndv_unexp,
+           "RNDV_THRESH=0")
+{
+    test_tag_send_recv(64 * UCS_KBYTE, false, false);
+}
+
+UCS_TEST_P(test_ucp_sockaddr_protocols_err, tag_rndv_unexp_put_scheme,
+           "RNDV_THRESH=0", "RNDV_SCHEME=put_zcopy")
+{
+    test_tag_send_recv(64 * UCS_KBYTE, false, false);
+}
+
+UCP_INSTANTIATE_CM_TEST_CASE(test_ucp_sockaddr_protocols_err)
diff --git a/test/gtest/ucp/ucp_test.cc b/test/gtest/ucp/ucp_test.cc
index 69cf3360b45..09e564624b4 100644
--- a/test/gtest/ucp/ucp_test.cc
+++ b/test/gtest/ucp/ucp_test.cc
@@ -240,6 +240,22 @@ ucs_status_t ucp_test::request_wait(void *req, int worker_index)
     return request_process(req, worker_index, true);
 }
 
+ucs_status_t ucp_test::requests_wait(const std::vector<void*> &reqs,
+                                     int worker_index)
+{
+    ucs_status_t ret_status = UCS_OK;
+
+    for (std::vector<void*>::const_iterator it = reqs.begin(); it != reqs.end();
+         ++it) {
+        ucs_status_t status = request_process(*it, worker_index, true);
+        if (status != UCS_OK) {
+            ret_status = status;
+        }
+    }
+
+    return ret_status;
+}
+
 void ucp_test::request_release(void *req)
 {
     request_process(req, 0, false);
@@ -253,6 +269,14 @@ int ucp_test::max_connections() {
     }
 }
 
+void ucp_test::set_tl_timeouts(ucs::ptr_vector<ucs::scoped_setenv> &env)
+{
+    /* Set small TL timeouts to reduce testing time */
+    env.push_back(new ucs::scoped_setenv("UCX_RC_TIMEOUT", "10ms"));
+    env.push_back(new ucs::scoped_setenv("UCX_RC_RNR_TIMEOUT", "10ms"));
+    env.push_back(new ucs::scoped_setenv("UCX_RC_RETRY_COUNT", "2"));
+}
+
 void ucp_test::set_ucp_config(ucp_config_t *config, const std::string& tls)
 {
     ucs_status_t status;
diff --git a/test/gtest/ucp/ucp_test.h b/test/gtest/ucp/ucp_test.h
index d91b262cceb..7bc01118aa8 100644
--- a/test/gtest/ucp/ucp_test.h
+++ b/test/gtest/ucp/ucp_test.h
@@ -229,8 +229,10 @@ class ucp_test : public ucp_test_base,
     void flush_worker(const entity &e, int worker_index = 0);
     void disconnect(entity& entity);
     ucs_status_t request_wait(void *req, int worker_index = 0);
+    ucs_status_t requests_wait(const std::vector<void*> &reqs, int worker_index = 0);
     void request_release(void *req);
     int max_connections();
+    void set_tl_timeouts(ucs::ptr_vector<ucs::scoped_setenv> &env);
 
     // Add test variant without values, with given context params
     static ucp_test_variant&