diff --git a/src/ucp/core/ucp_am.c b/src/ucp/core/ucp_am.c index d3cd1782fed..0a194a321ab 100644 --- a/src/ucp/core/ucp_am.c +++ b/src/ucp/core/ucp_am.c @@ -77,18 +77,24 @@ void ucp_am_ep_cleanup(ucp_ep_h ep) size_t ucp_am_max_header_size(ucp_worker_h worker) { + ucp_context_h context = worker->context; uct_iface_attr_t *if_attr; ucp_rsc_index_t iface_id; size_t max_am_header, max_uct_fragment; + size_t max_rts_size, max_ucp_header; - if (!(worker->context->config.features & UCP_FEATURE_AM)) { + if (!(context->config.features & UCP_FEATURE_AM)) { return 0ul; } - max_am_header = SIZE_MAX; + max_am_header = SIZE_MAX; + max_rts_size = sizeof(ucp_am_rndv_rts_hdr_t) + + ucp_rkey_packed_size(context, UCS_MASK(context->num_mds)); + max_ucp_header = ucs_max(max_rts_size, sizeof(ucp_am_first_hdr_t)); - /* TODO: Make sure maximal AM header can fit into one bcopy fragment - * together with RTS */ + /* Make sure maximal AM header can fit into one bcopy fragment + * together with RTS or first eager header (whatever is bigger) + */ for (iface_id = 0; iface_id < worker->num_ifaces; ++iface_id) { if_attr = &worker->ifaces[iface_id]->attr; @@ -103,9 +109,8 @@ size_t ucp_am_max_header_size(ucp_worker_h worker) */ if (if_attr->cap.flags & UCT_IFACE_FLAG_AM_BCOPY) { max_uct_fragment = ucs_max(if_attr->cap.am.max_bcopy, - sizeof(ucp_am_first_hdr_t) - 1) - - sizeof(ucp_am_first_hdr_t) - 1; - max_am_header = ucs_min(max_am_header, max_uct_fragment); + max_ucp_header - 1) - max_ucp_header - 1; + max_am_header = ucs_min(max_am_header, max_uct_fragment); } } @@ -828,12 +833,19 @@ ucp_am_send_req(ucp_request_t *req, size_t count, * TODO: Consider other ways to send user header, like packing together * with UCT AM header, direct registration of user header buffer, etc. */ - zcopy_thresh = SIZE_MAX; + zcopy_thresh = rndv_thresh; } else { zcopy_thresh = ucp_proto_get_zcopy_threshold(req, msg_config, count, rndv_thresh); } + ucs_trace_req("select am request(%p) progress algorithm datatype=0x%"PRIx64 + " buffer=%p length=%zu header_length=%u max_short=%zd" + " rndv_thresh=%zu zcopy_thresh=%zu", + req, req->send.datatype, req->send.buffer, req->send.length, + req->send.msg_proto.am.header_length, max_short, rndv_thresh, + zcopy_thresh); + status = ucp_request_send_start(req, max_short, zcopy_thresh, rndv_thresh, count, !!user_header_length, ucp_am_send_req_total_size(req), @@ -843,7 +855,7 @@ ucp_am_send_req(ucp_request_t *req, size_t count, return UCS_STATUS_PTR(status); } - ucs_assert(req->send.length >= rndv_thresh); + ucs_assert(ucp_am_send_req_total_size(req) >= rndv_thresh); status = ucp_am_send_start_rndv(req); if (status != UCS_OK) { @@ -881,7 +893,7 @@ ucp_am_try_send_short(ucp_ep_h ep, uint16_t id, uint32_t flags, { if (ucs_unlikely(((length != 0) && (header_length != 0)) || ((ssize_t)(length + header_length) > - ucp_ep_config(ep)->am.max_short)) || + ucp_ep_config(ep)->am_u.max_eager_short)) || (flags & UCP_AM_SEND_FLAG_RNDV)) { goto out; } @@ -959,12 +971,12 @@ UCS_PROFILE_FUNC(ucs_status_ptr_t, ucp_am_send_nbx, if (flags & UCP_AM_SEND_REPLY) { ret = ucp_am_send_req(req, count, &ucp_ep_config(ep)->am, param, ucp_ep_config(ep)->am_u.reply_proto, - ucs_min(ucp_ep_config(ep)->am.max_short, + ucs_min(ucp_ep_config(ep)->am_u.max_eager_short, UCP_AM_SHORT_REPLY_MAX_SIZE), flags); } else { ret = ucp_am_send_req(req, count, &ucp_ep_config(ep)->am, param, ucp_ep_config(ep)->am_u.proto, - ucp_ep_config(ep)->am.max_short, flags); + ucp_ep_config(ep)->am_u.max_eager_short, flags); } out: diff --git a/src/ucp/core/ucp_ep.c b/src/ucp/core/ucp_ep.c index 91b8d277192..d84b9cd2f37 100644 --- a/src/ucp/core/ucp_ep.c +++ b/src/ucp/core/ucp_ep.c @@ -1291,7 +1291,8 @@ static void ucp_ep_config_set_am_rndv_thresh(ucp_worker_h worker, ucp_ep_config_t *config, size_t min_rndv_thresh, size_t max_rndv_thresh, - ucp_rndv_thresh_t *thresh) + ucp_rndv_thresh_t *thresh, + ssize_t *max_short_to_adjust) { ucp_context_h context = worker->context; size_t rndv_thresh, rndv_local_thresh, min_thresh; @@ -1318,8 +1319,7 @@ static void ucp_ep_config_set_am_rndv_thresh(ucp_worker_h worker, rndv_local_thresh = context->config.ext.rndv_thresh; /* adjust max_short if rndv_thresh is set externally */ - ucp_ep_config_adjust_max_short(&config->tag.eager.max_short, - rndv_thresh); + ucp_ep_config_adjust_max_short(max_short_to_adjust, rndv_thresh); } min_thresh = ucs_max(iface_attr->cap.am.min_zcopy, min_rndv_thresh); @@ -1335,7 +1335,8 @@ static void ucp_ep_config_set_rndv_thresh(ucp_worker_t *worker, ucp_lane_index_t *lanes, size_t min_rndv_thresh, size_t max_rndv_thresh, - ucp_rndv_thresh_t *thresh) + ucp_rndv_thresh_t *thresh, + ssize_t *max_short_to_adjust) { ucp_context_t *context = worker->context; ucp_lane_index_t lane = lanes[0]; @@ -1368,8 +1369,7 @@ static void ucp_ep_config_set_rndv_thresh(ucp_worker_t *worker, rndv_local_thresh = context->config.ext.rndv_thresh; /* adjust max_short if rndv_thresh is set externally */ - ucp_ep_config_adjust_max_short(&config->tag.eager.max_short, - rndv_thresh); + ucp_ep_config_adjust_max_short(max_short_to_adjust, rndv_thresh); } min_thresh = ucs_max(iface_attr->cap.get.min_zcopy, min_rndv_thresh); @@ -1710,13 +1710,15 @@ ucs_status_t ucp_ep_config_init(ucp_worker_h worker, ucp_ep_config_t *config, tag_lanes[0] = lane; ucp_ep_config_set_rndv_thresh(worker, config, tag_lanes, min_rndv_thresh, max_rndv_thresh, - &config->tag.rndv.rma_thresh); + &config->tag.rndv.rma_thresh, + &config->tag.eager.max_short); md_attr = &context->tl_mds[context->tl_rscs[rsc_index].md_index].attr; ucp_ep_config_set_am_rndv_thresh(worker, iface_attr, md_attr, config, min_am_rndv_thresh, max_am_rndv_thresh, - &config->tag.rndv.am_thresh); + &config->tag.rndv.am_thresh, + &config->tag.eager.max_short); } /* Max Eager short has to be set after Zcopy and RNDV thresholds */ @@ -1743,6 +1745,7 @@ ucs_status_t ucp_ep_config_init(ucp_worker_h worker, ucp_ep_config_t *config, UCT_IFACE_FLAG_AM_BCOPY, UCT_IFACE_FLAG_AM_ZCOPY, sizeof(ucp_eager_hdr_t), SIZE_MAX); + config->am_u.max_eager_short = config->am.max_short; /* Calculate rendezvous thresholds which may be used by UCP AM * protocol. */ @@ -1750,17 +1753,20 @@ ucs_status_t ucp_ep_config_init(ucp_worker_h worker, ucp_ep_config_t *config, rkey_ptr_lanes[0] = config->key.rkey_ptr_lane; ucp_ep_config_set_rndv_thresh(worker, config, rkey_ptr_lanes, iface_attr->cap.get.min_zcopy, - SIZE_MAX, &config->rndv.rma_thresh); + SIZE_MAX, &config->rndv.rma_thresh, + &config->am_u.max_eager_short); } else { ucp_ep_config_set_rndv_thresh(worker, config, config->key.rma_bw_lanes, iface_attr->cap.get.min_zcopy, - SIZE_MAX, &config->rndv.rma_thresh); + SIZE_MAX, &config->rndv.rma_thresh, + &config->am_u.max_eager_short); } ucp_ep_config_set_am_rndv_thresh(worker, iface_attr, md_attr, config, iface_attr->cap.am.min_zcopy, - SIZE_MAX, &config->rndv.am_thresh); + SIZE_MAX, &config->rndv.am_thresh, + &config->am_u.max_eager_short); /* All keys must fit in RNDV packet. * TODO remove some MDs if they don't @@ -1773,18 +1779,11 @@ ucs_status_t ucp_ep_config_init(ucp_worker_h worker, ucp_ep_config_t *config, /* TODO: set threshold level based on all available lanes */ config->tag.eager = config->am; + config->tag.eager.max_short = config->am_u.max_eager_short; config->tag.lane = lane; config->tag.rndv.am_thresh = config->rndv.am_thresh; config->tag.rndv.rma_thresh = config->rndv.rma_thresh; - if (context->config.ext.rndv_thresh != UCS_MEMUNITS_AUTO) { - /* adjust max_short if rndv_thresh is set externally */ - min_rndv_thresh = ucs_min(config->tag.rndv.rma_thresh.remote, - config->tag.rndv.am_thresh.remote); - ucp_ep_config_adjust_max_short(&config->tag.eager.max_short, - min_rndv_thresh); - } - /* Max Eager short has to be set after Zcopy and RNDV thresholds */ ucp_ep_config_set_memtype_thresh(&config->tag.max_eager_short, config->tag.eager.max_short, diff --git a/src/ucp/core/ucp_ep.h b/src/ucp/core/ucp_ep.h index 5fffacf887b..d967877732a 100644 --- a/src/ucp/core/ucp_ep.h +++ b/src/ucp/core/ucp_ep.h @@ -334,6 +334,9 @@ struct ucp_ep_config { /* Protocols used for am operations */ const ucp_request_send_proto_t *proto; const ucp_request_send_proto_t *reply_proto; + + /* Maximal size for eager short */ + ssize_t max_eager_short; } am_u; /* Protocol selection data */ diff --git a/test/gtest/ucp/test_ucp_am.cc b/test/gtest/ucp/test_ucp_am.cc index e246c497d48..d2f01047af1 100644 --- a/test/gtest/ucp/test_ucp_am.cc +++ b/test/gtest/ucp/test_ucp_am.cc @@ -694,6 +694,21 @@ UCS_TEST_P(test_ucp_am_nbx_rndv, rndv_flag_send, "RNDV_THRESH=inf") test_am_send_recv(64 * UCS_KBYTE, 0, UCP_AM_SEND_FLAG_RNDV); } +UCS_TEST_P(test_ucp_am_nbx_rndv, rndv_zero_send, "RNDV_THRESH=0") +{ + test_am_send_recv(0); +} + +UCS_TEST_P(test_ucp_am_nbx_rndv, just_header_rndv, "RNDV_THRESH=1") +{ + test_am_send_recv(0, max_am_hdr()); +} + +UCS_TEST_P(test_ucp_am_nbx_rndv, header_and_data_rndv, "RNDV_THRESH=128") +{ + test_am_send_recv(127, 1); +} + UCS_TEST_P(test_ucp_am_nbx_rndv, reject_rndv) { skip_loopback();