Skip to content

Commit

Permalink
Merge pull request openucx#19 from hoopoepg/topic/tag-send-flags-eage…
Browse files Browse the repository at this point in the history
…r-rndv

UCP/TAG/SEND: added eager/rndv flags to tag_send op
  • Loading branch information
yosefe authored Sep 24, 2021
2 parents c1a0486 + f6a9275 commit aa500aa
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 14 deletions.
21 changes: 20 additions & 1 deletion src/ucp/api/ucp.h
Original file line number Diff line number Diff line change
Expand Up @@ -709,6 +709,21 @@ typedef enum {
} ucp_op_attr_t;


/**
* @ingroup UCP_COMM
* @brief UCP tag send operation flags
*
* Flags that select the protocol used by the @ref ucp_tag_send_nbx and
* @ref ucp_tag_send_sync_nbx routines. They are passed in
* @ref ucp_request_param_t.flags.
*
* The two flags are mutually exclusive: when parameter checking is enabled,
* a send request that sets both of them fails with UCS_ERR_INVALID_PARAM.
*/
typedef enum {
UCP_EP_TAG_SEND_FLAG_EAGER = UCS_BIT(0), /**< Force the use of the eager
protocol to transfer the data */
UCP_EP_TAG_SEND_FLAG_RNDV = UCS_BIT(1) /**< Force the use of the rendezvous
(rndv) protocol to transfer the data */
} ucp_ep_tag_send_flags_t;


/**
* @ingroup UCP_COMM
* @brief UCP request query attributes
Expand Down Expand Up @@ -3484,7 +3499,11 @@ ucs_status_ptr_t ucp_tag_send_sync_nb(ucp_ep_h ep, const void *buffer, size_t co
* @param [in] buffer Pointer to the message buffer (payload).
* @param [in] count Number of elements to send
* @param [in] tag Message tag.
* @param [in] param Operation parameters, see @ref ucp_request_param_t
* @param [in] param Operation parameters, see @ref ucp_request_param_t.
* This operation supports specific flags, which can be
* passed in @a param by @ref ucp_request_param_t.flags.
* The exact set of flags is defined
* by @ref ucp_ep_tag_send_flags_t.
*
* @return UCS_OK - The send operation was completed immediately.
* @return UCS_PTR_IS_ERR(_ptr) - The send operation failed.
Expand Down
36 changes: 31 additions & 5 deletions src/ucp/tag/tag_send.c
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,18 @@
static UCS_F_ALWAYS_INLINE size_t
ucp_tag_get_rndv_threshold(const ucp_request_t *req, size_t count,
size_t max_iov, size_t rndv_rma_thresh,
size_t rndv_am_thresh)
size_t rndv_am_thresh, uint32_t flags)
{
/* Eager protocol requested - set rndv threshold to max */
if (flags & UCP_EP_TAG_SEND_FLAG_EAGER) {
return SIZE_MAX;
}

/* RNDV protocol requested - set rndv threshold to 0 */
if (flags & UCP_EP_TAG_SEND_FLAG_RNDV) {
return 0;
}

switch (req->send.datatype & UCP_DATATYPE_CLASS_MASK) {
case UCP_DATATYPE_IOV:
if ((count > max_iov) &&
Expand Down Expand Up @@ -54,6 +64,7 @@ ucp_tag_send_req(ucp_request_t *req, size_t dt_count,
{
ssize_t max_short = ucp_proto_get_short_max(req, msg_config);
ucp_ep_config_t *ep_config = ucp_ep_config(req->send.ep);
uint32_t flags = ucp_request_param_flags(param);
ucs_status_t status;
size_t zcopy_thresh;
size_t rndv_thresh;
Expand All @@ -65,7 +76,8 @@ ucp_tag_send_req(ucp_request_t *req, size_t dt_count,
&rndv_rma_thresh, &rndv_am_thresh);

rndv_thresh = ucp_tag_get_rndv_threshold(req, dt_count, msg_config->max_iov,
rndv_rma_thresh, rndv_am_thresh);
rndv_rma_thresh, rndv_am_thresh,
flags);

if (!(param->op_attr_mask & UCP_OP_ATTR_FLAG_FAST_CMPL) ||
ucs_unlikely(!UCP_MEM_IS_HOST(req->send.mem_type))) {
Expand Down Expand Up @@ -225,7 +237,8 @@ UCS_PROFILE_FUNC(ucs_status_ptr_t, ucp_tag_send_nbx,
ucp_ep_h ep, const void *buffer, size_t count,
ucp_tag_t tag, const ucp_request_param_t *param)
{
size_t contig_length = 0;
size_t contig_length = 0;
uint32_t UCS_V_UNUSED flags = ucp_request_param_flags(param);
ucs_status_t status;
ucp_request_t *req;
ucs_status_ptr_t ret;
Expand All @@ -237,6 +250,12 @@ UCS_PROFILE_FUNC(ucs_status_ptr_t, ucp_tag_send_nbx,
return UCS_STATUS_PTR(UCS_ERR_INVALID_PARAM));
UCP_REQUEST_CHECK_PARAM(param);

if (ENABLE_PARAMS_CHECK &&
ucs_test_all_flags(flags, UCP_EP_TAG_SEND_FLAG_EAGER |
UCP_EP_TAG_SEND_FLAG_RNDV)) {
return UCS_STATUS_PTR(UCS_ERR_INVALID_PARAM);
}

UCP_WORKER_THREAD_CS_ENTER_CONDITIONAL(ep->worker);

ucs_trace_req("send_nbx buffer %p count %zu tag %"PRIx64" to %s",
Expand Down Expand Up @@ -296,8 +315,9 @@ UCS_PROFILE_FUNC(ucs_status_ptr_t, ucp_tag_send_sync_nbx,
ucp_ep_h ep, const void *buffer, size_t count,
ucp_tag_t tag, const ucp_request_param_t *param)
{
ucp_worker_h worker = ep->worker;
size_t contig_length = 0;
ucp_worker_h worker = ep->worker;
size_t contig_length = 0;
uint32_t UCS_V_UNUSED flags = ucp_request_param_flags(param);
ucs_status_t status;
ucp_request_t *req;
ucs_status_ptr_t ret;
Expand All @@ -308,6 +328,12 @@ UCS_PROFILE_FUNC(ucs_status_ptr_t, ucp_tag_send_sync_nbx,
UCS_ERR_INVALID_PARAM));
UCP_REQUEST_CHECK_PARAM(param);

if (ENABLE_PARAMS_CHECK &&
ucs_test_all_flags(flags, UCP_EP_TAG_SEND_FLAG_EAGER |
UCP_EP_TAG_SEND_FLAG_RNDV)) {
return UCS_STATUS_PTR(UCS_ERR_INVALID_PARAM);
}

UCP_WORKER_THREAD_CS_ENTER_CONDITIONAL(worker);

ucs_trace_req("send_sync_nbx buffer %p count %zu tag %"PRIx64" to %s",
Expand Down
14 changes: 12 additions & 2 deletions test/apps/iodemo/io_demo.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ typedef struct {
bool debug_timeout;
bool use_epoll;
ucs_memory_type_t memory_type;
size_t rndv_thresh;
} options_t;

#define LOG_PREFIX "[DEMO]"
Expand Down Expand Up @@ -724,7 +725,8 @@ class P2pDemoCommon : public UcxContext {

P2pDemoCommon(const options_t &test_opts) :
UcxContext(test_opts.iomsg_size, test_opts.connect_timeout,
test_opts.use_am, test_opts.use_epoll),
test_opts.use_am, test_opts.rndv_thresh,
test_opts.use_epoll),
_test_opts(test_opts),
_io_msg_pool(test_opts.iomsg_size, "io messages"),
_send_callback_pool(0, "send callbacks"),
Expand Down Expand Up @@ -2430,9 +2432,10 @@ static int parse_args(int argc, char **argv, options_t *test_opts)
test_opts->debug_timeout = false;
test_opts->use_epoll = false;
test_opts->memory_type = UCS_MEMORY_TYPE_HOST;
test_opts->rndv_thresh = UcxContext::rndv_thresh_auto;

while ((c = getopt(argc, argv,
"p:c:r:d:b:i:w:a:k:o:t:n:l:s:y:vqeADHP:m:")) != -1) {
"p:c:r:d:b:i:w:a:k:o:t:n:l:s:y:vqeADHP:m:R:")) != -1) {
switch (c) {
case 'p':
test_opts->port_num = atoi(optarg);
Expand Down Expand Up @@ -2574,6 +2577,9 @@ static int parse_args(int argc, char **argv, options_t *test_opts)
return -1;
}
break;
case 'R': {
    /* Explicit rendezvous threshold. Reject anything that is not a complete,
     * non-negative number: an unchecked strtol() result would turn a typo
     * into threshold 0 (rendezvous for every message), and a negative value
     * would wrap when stored in size_t ("-2" would even alias
     * UcxContext::rndv_thresh_auto). */
    char *end  = NULL;
    long value = strtol(optarg, &end, 0);
    if ((end == optarg) || (*end != '\0') || (value < 0)) {
        std::cout << "invalid rendezvous threshold '" << optarg << "'"
                  << std::endl;
        return -1;
    }
    test_opts->rndv_thresh = static_cast<size_t>(value);
    break;
}
case 'h':
default:
std::cout << "Usage: io_demo [options] [server_address]" << std::endl;
Expand Down Expand Up @@ -2607,6 +2613,10 @@ static int parse_args(int argc, char **argv, options_t *test_opts)
std::cout << " -D Enable debugging mode for IO operation timeouts" << std::endl;
std::cout << " -H Use human-readable timestamps" << std::endl;
std::cout << " -P <interval> Set report printing interval" << std::endl;
std::cout << " -R <threshold> Always use rendezvous protocol for messages starting" << std::endl;
std::cout << " from this size, and eager protocol for" << std::endl;
std::cout << " messages lower than this size. If not set," << std::endl;
std::cout << " the threshold is selected automatically by UCX" << std::endl;
std::cout << "" << std::endl;
std::cout << " -m <memory_type> Memory type to use. Possible values: host"
#ifdef HAVE_CUDA
Expand Down
22 changes: 17 additions & 5 deletions test/apps/iodemo/ucx_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,10 @@ void UcxContext::UcxDisconnectCallback::operator()(ucs_status_t status)
}

UcxContext::UcxContext(size_t iomsg_size, double connect_timeout, bool use_am,
bool use_epoll) :
size_t rndv_thresh, bool use_epoll) :
_context(NULL), _worker(NULL), _listener(NULL), _iomsg_recv_request(NULL),
_iomsg_buffer(iomsg_size, '\0'), _connect_timeout(connect_timeout),
_use_am(use_am), _worker_fd(-1), _epoll_fd(-1)
_use_am(use_am), _worker_fd(-1), _epoll_fd(-1), _rndv_thresh(rndv_thresh)
{
if (use_epoll) {
_epoll_fd = epoll_create(1);
Expand Down Expand Up @@ -1152,16 +1152,28 @@ void UcxConnection::established(ucs_status_t status)
/**
 * Post a tagged send of a contiguous buffer on this connection.
 *
 * If a rendezvous threshold was configured on the context (anything other
 * than UcxContext::rndv_thresh_auto), the protocol is forced per message:
 * rendezvous when @a length is at least the threshold, eager otherwise.
 * With the automatic setting no protocol flag is passed.
 *
 * @param buffer    Data to send.
 * @param length    Number of elements passed to ucp_tag_send_nbx().
 * @param tag       Tag to send the message with.
 * @param callback  Called with UCS_ERR_CANCELED when the connection has no
 *                  endpoint; otherwise handed to process_request() together
 *                  with the send request.
 *
 * @return false when the connection has no endpoint, otherwise the result
 *         of process_request().
 */
bool UcxConnection::send_common(const void *buffer, size_t length, ucp_tag_t tag,
                                UcxCallback* callback)
{
    ucp_request_param_t params;

    if (_ep == NULL) {
        (*callback)(UCS_ERR_CANCELED);
        return false;
    }

    assert(_ucx_status == UCS_OK);

    /* NOTE(review): common_request_callback is cast to the nbx callback
     * type - confirm its signature is compatible with
     * ucp_send_nbx_callback_t */
    params.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK;
    params.cb.send      = (ucp_send_nbx_callback_t)common_request_callback;
    /* suppress coverity false-positive: the datatype field is not enabled in
     * op_attr_mask, it is assigned only so that it is not left uninitialized */
    params.datatype     = ucp_dt_make_contig(1);
    if (_context.rndv_thresh() != UcxContext::rndv_thresh_auto) {
        /* An explicit threshold was requested - force the protocol */
        params.op_attr_mask |= UCP_OP_ATTR_FIELD_FLAGS;
        params.flags         = (length >= _context.rndv_thresh()) ?
                               UCP_EP_TAG_SEND_FLAG_RNDV :
                               UCP_EP_TAG_SEND_FLAG_EAGER;
    }

    ucs_status_ptr_t ptr_status = ucp_tag_send_nbx(_ep, buffer, length, tag,
                                                   &params);
    /* Report the operation under the name of the routine actually called */
    return process_request("ucp_tag_send_nbx", ptr_status, callback);
}

Expand Down
10 changes: 9 additions & 1 deletion test/apps/iodemo/ucx_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,10 @@ class UcxContext {
};

public:
/* Sentinel rendezvous threshold meaning "not set": when rndv_thresh() returns
 * this value, UcxConnection::send_common() requests no specific protocol */
static const size_t rndv_thresh_auto = (size_t)-2;

UcxContext(size_t iomsg_size, double connect_timeout, bool use_am,
bool use_epoll = false);
size_t rndv_thresh, bool use_epoll = false);

virtual ~UcxContext();

Expand Down Expand Up @@ -249,6 +251,11 @@ class UcxContext {

void destroy_worker();

/* Returns the rendezvous threshold this context was constructed with */
size_t rndv_thresh() const
{
return _rndv_thresh;
}

void set_am_handler(ucp_am_recv_callback_t cb, void *arg);

ucp_context_h _context;
Expand All @@ -265,6 +272,7 @@ class UcxContext {
bool _use_am;
int _worker_fd;
int _epoll_fd;
size_t _rndv_thresh;
};


Expand Down

0 comments on commit aa500aa

Please sign in to comment.