Skip to content

Commit

Permalink
UCP/TAG: Handle canceling rendezvous requests
Browse files Browse the repository at this point in the history
  • Loading branch information
yosefe committed Apr 22, 2020
1 parent c652998 commit a010441
Show file tree
Hide file tree
Showing 9 changed files with 174 additions and 13 deletions.
2 changes: 2 additions & 0 deletions src/ucp/core/ucp_ep.c
Original file line number Diff line number Diff line change
Expand Up @@ -912,6 +912,8 @@ ucs_status_ptr_t ucp_ep_close_nb(ucp_ep_h ep, unsigned mode)

UCS_ASYNC_BLOCK(&worker->async);

ucp_ep_complete_rndv_reqs(ep);

ep->flags |= UCP_EP_FLAG_CLOSED;
request = ucp_ep_flush_internal(ep,
(mode == UCP_EP_CLOSE_MODE_FLUSH) ?
Expand Down
1 change: 1 addition & 0 deletions src/ucp/core/ucp_ep.inl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "ucp_ep.h"
#include "ucp_worker.h"
#include "ucp_context.h"
#include "ucp_request.h"

#include <ucp/wireup/wireup.h>
#include <ucs/arch/bitops.h>
Expand Down
8 changes: 8 additions & 0 deletions src/ucp/core/ucp_request.c
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "ucp_request.inl"

#include <ucp/proto/proto_am.h>
#include <ucp/tag/rndv.h>

#include <ucs/datastruct/mpool.inl>
#include <ucs/debug/debug.h>
Expand Down Expand Up @@ -127,6 +128,13 @@ UCS_PROFILE_FUNC_VOID(ucp_request_cancel, (worker, request),
}

UCP_WORKER_THREAD_CS_EXIT_CONDITIONAL(worker);

return;
}

if (req->flags & UCP_REQUEST_FLAG_SEND_RNDV) {
ucp_tag_rndv_cancel(req);
return;
}
}

Expand Down
16 changes: 12 additions & 4 deletions src/ucp/core/ucp_request.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,13 @@ enum {
UCP_REQUEST_FLAG_STREAM_RECV_WAITALL = UCS_BIT(12),
UCP_REQUEST_FLAG_SEND_AM = UCS_BIT(13),
UCP_REQUEST_FLAG_SEND_TAG = UCS_BIT(14),
UCP_REQUEST_FLAG_SEND_RNDV = UCS_BIT(15),
UCP_REQUEST_FLAG_RNDV_RTS_SENT = UCS_BIT(16),
UCP_REQUEST_FLAG_CANCELED = UCS_BIT(17),
#if UCS_ENABLE_ASSERT
UCP_REQUEST_FLAG_STREAM_RECV = UCS_BIT(16),
UCP_REQUEST_DEBUG_FLAG_EXTERNAL = UCS_BIT(17),
UCP_REQUEST_DEBUG_RNDV_FRAG = UCS_BIT(18)
UCP_REQUEST_FLAG_STREAM_RECV = UCS_BIT(29),
UCP_REQUEST_DEBUG_FLAG_EXTERNAL = UCS_BIT(30),
UCP_REQUEST_DEBUG_RNDV_FRAG = UCS_BIT(31)
#else
UCP_REQUEST_FLAG_STREAM_RECV = 0,
UCP_REQUEST_DEBUG_FLAG_EXTERNAL = 0,
Expand Down Expand Up @@ -243,7 +246,12 @@ struct ucp_request {
ucp_lane_index_t pending_lane; /* Lane on which request was moved
* to pending state */
ucp_lane_index_t lane; /* Lane on which this request is being sent */
uct_pending_req_t uct; /* UCT pending request */

union {
uct_pending_req_t uct; /* UCT pending request */
ucs_list_link_t list; /* UCP list */
};

ucp_mem_desc_t *mdesc;
} send;

Expand Down
1 change: 1 addition & 0 deletions src/ucp/core/ucp_worker.c
Original file line number Diff line number Diff line change
Expand Up @@ -1703,6 +1703,7 @@ ucs_status_t ucp_worker_create(ucp_context_h context,
ucs_list_head_init(&worker->stream_ready_eps);
ucs_list_head_init(&worker->all_eps);
ucp_ep_match_init(&worker->ep_match_ctx);
ucs_list_head_init(&worker->rndv_reqs_list);

UCS_STATIC_ASSERT(sizeof(ucp_ep_ext_gen_t) <= sizeof(ucp_ep_t));
if (context->config.features & (UCP_FEATURE_STREAM | UCP_FEATURE_AM)) {
Expand Down
1 change: 1 addition & 0 deletions src/ucp/core/ucp_worker.h
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,7 @@ typedef struct ucp_worker {
* are in-progress */
uct_worker_cb_id_t rkey_ptr_cb_id;/* RKEY PTR worker callback queue ID */
ucp_tag_match_t tm; /* Tag-matching queues and offload info */
ucs_list_link_t rndv_reqs_list;
uint64_t am_message_id; /* For matching long am's */
ucp_ep_h mem_type_ep[UCS_MEMORY_TYPE_LAST];/* memory type eps */

Expand Down
1 change: 1 addition & 0 deletions src/ucp/tag/offload.c
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ UCS_PROFILE_FUNC(ucs_status_t, ucp_tag_offload_unexp_rndv,
dummy_rts->sreq.reqptr = rndv_hdr->reqptr;
dummy_rts->address = remote_addr;
dummy_rts->size = length;
dummy_rts->status = UCS_OK;

ucp_rkey_packed_copy(worker->context, UCS_BIT(md_index),
UCS_MEMORY_TYPE_HOST, dummy_rts + 1, uct_rkeys);
Expand Down
152 changes: 143 additions & 9 deletions src/ucp/tag/rndv.c
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,17 @@ ucp_rndv_req_get_zcopy_rma_lane(ucp_request_t *rndv_req, ucp_lane_map_t ignore,
rndv_req->send.rndv_get.rkey, ignore, uct_rkey_p);
}

static void ucp_rndv_complete_send(ucp_request_t *sreq, ucs_status_t status)
{
ucp_request_send_generic_dt_finish(sreq);
ucp_request_send_buffer_dereg(sreq);
if (sreq->flags & UCP_REQUEST_FLAG_CANCELED) {
ucs_warn("cancel-completing rnv request %p", sreq);
ucs_list_del(&sreq->send.list);
}
ucp_request_complete_send(sreq, status);
}

size_t ucp_tag_rndv_rts_pack(void *dest, void *arg)
{
ucp_request_t *sreq = arg; /* send request */
Expand All @@ -62,6 +73,7 @@ size_t ucp_tag_rndv_rts_pack(void *dest, void *arg)
rndv_rts_hdr->sreq.reqptr = (uintptr_t)sreq;
rndv_rts_hdr->sreq.ep_ptr = ucp_request_get_dest_ep_ptr(sreq);
rndv_rts_hdr->size = sreq->send.length;
rndv_rts_hdr->status = UCS_OK;

/* Pack remote keys (which can be empty list) */
if (UCP_DT_IS_CONTIG(sreq->send.datatype) &&
Expand Down Expand Up @@ -94,11 +106,56 @@ UCS_PROFILE_FUNC(ucs_status_t, ucp_proto_progress_rndv_rts, (self),
{
ucp_request_t *sreq = ucs_container_of(self, ucp_request_t, send.uct);
size_t packed_rkey_size;
ucs_status_t status;

/* send the RTS. the pack_cb will pack all the necessary fields in the RTS */
packed_rkey_size = ucp_ep_config(sreq->send.ep)->tag.rndv.rkey_size;
return ucp_do_am_single(self, UCP_AM_ID_RNDV_RTS, ucp_tag_rndv_rts_pack,
sizeof(ucp_rndv_rts_hdr_t) + packed_rkey_size);
status = ucp_do_am_single(self, UCP_AM_ID_RNDV_RTS, ucp_tag_rndv_rts_pack,
sizeof(ucp_rndv_rts_hdr_t) + packed_rkey_size);
if (status == UCS_OK) {
sreq->flags |= UCP_REQUEST_FLAG_RNDV_RTS_SENT;
return status;
} else if (status == UCS_ERR_NO_RESOURCE) {
return UCS_ERR_NO_RESOURCE;
} else {
ucs_assert(UCS_STATUS_IS_ERR(status));
ucp_rndv_complete_send(sreq, status);
return UCS_OK;
}
}

UCS_PROFILE_FUNC(ucs_status_t, ucp_proto_progress_rndv_cancel, (self),
uct_pending_req_t *self)
{
ucp_request_t *sreq = ucs_container_of(self, ucp_request_t, send.uct);
ucp_ep_h ep = sreq->send.ep;
ucp_memcpy_pack_context_t ctx;;
ucp_rndv_rts_hdr_t rndv_rts_hdr;
ssize_t packed_len;

sreq->send.lane = ucp_ep_get_am_lane(ep);

rndv_rts_hdr.super.tag = sreq->send.msg_proto.tag.tag;
rndv_rts_hdr.sreq.reqptr = (uintptr_t)sreq;
rndv_rts_hdr.sreq.ep_ptr = ucp_request_get_dest_ep_ptr(sreq);
rndv_rts_hdr.size = sreq->send.length;
rndv_rts_hdr.status = UCS_ERR_CANCELED;

ctx.src = &rndv_rts_hdr;
ctx.length = sizeof(rndv_rts_hdr);

packed_len = uct_ep_am_bcopy(ep->uct_eps[sreq->send.lane],
UCP_AM_ID_RNDV_RTS, ucp_memcpy_pack, &ctx, 0);
if (packed_len >= 0) {
sreq->flags |= UCP_REQUEST_FLAG_CANCELED;
ucs_list_add_tail(&ep->worker->rndv_reqs_list, &sreq->send.list);
return UCS_OK;
} else if (packed_len == UCS_ERR_NO_RESOURCE) {
return UCS_ERR_NO_RESOURCE;
} else {
ucp_rndv_complete_send(sreq, (ucs_status_t)packed_len);
return UCS_OK;
}
}

static size_t ucp_tag_rndv_rtr_pack(void *dest, void *arg)
Expand Down Expand Up @@ -191,6 +248,8 @@ ucs_status_t ucp_tag_send_start_rndv(ucp_request_t *sreq)
sreq->send.length);
UCS_PROFILE_REQUEST_EVENT(sreq, "start_rndv", sreq->send.length);

sreq->flags |= UCP_REQUEST_FLAG_SEND_RNDV;

status = ucp_ep_resolve_dest_ep_ptr(ep, sreq->send.lane);
if (status != UCS_OK) {
return status;
Expand All @@ -207,18 +266,38 @@ ucs_status_t ucp_tag_send_start_rndv(ucp_request_t *sreq)
return status;
}

static void ucp_rndv_complete_send(ucp_request_t *sreq, ucs_status_t status)
void ucp_tag_rndv_cancel(ucp_request_t *sreq)
{
ucp_request_send_generic_dt_finish(sreq);
ucp_request_send_buffer_dereg(sreq);
ucp_request_complete_send(sreq, status);
if (!(sreq->send.ep->flags & UCP_EP_FLAG_REMOTE_CONNECTED)) {
ucp_rndv_complete_send(sreq, UCS_ERR_CANCELED);
} else {
sreq->send.uct.func = ucp_proto_progress_rndv_cancel;
if (sreq->flags & UCP_REQUEST_FLAG_RNDV_RTS_SENT) {
ucp_request_send(sreq, 0);
}
}
}

void ucp_ep_complete_rndv_reqs(ucp_ep_h ep)
{
ucp_worker_h worker = ep->worker;
ucp_request_t *sreq, *tmp;

ucs_list_for_each_safe(sreq, tmp, &worker->rndv_reqs_list, send.list) {
if (sreq->send.ep == ep) {
ucp_rndv_complete_send(sreq, UCS_ERR_CANCELED);
}
}
}

static void ucp_rndv_req_send_ats(ucp_request_t *rndv_req, ucp_request_t *rreq,
uintptr_t remote_request, ucs_status_t status)
{
ucp_trace_req(rndv_req, "send ats remote_request 0x%lx", remote_request);
UCS_PROFILE_REQUEST_EVENT(rreq, "send_ats", 0);

if (rreq != NULL) {
UCS_PROFILE_REQUEST_EVENT(rreq, "send_ats", 0);
}

rndv_req->send.lane = ucp_ep_get_am_lane(rndv_req->send.ep);
rndv_req->send.uct.func = ucp_proto_progress_am_single;
Expand Down Expand Up @@ -911,6 +990,56 @@ UCS_PROFILE_FUNC_VOID(ucp_rndv_matched, (worker, rreq, rndv_rts_hdr),
UCS_ASYNC_UNBLOCK(&worker->async);
}

static void ucp_rdnv_send_cancel_ack(ucp_worker_h worker,
ucp_rndv_rts_hdr_t *rndv_rts_hdr)
{
ucp_request_t *req;

req = ucp_request_get(worker);
if (req == NULL) {
return;
}

req->send.ep = ucp_worker_get_ep_by_ptr(worker,
rndv_rts_hdr->sreq.ep_ptr);
req->flags = 0;
req->send.mdesc = NULL;
req->send.pending_lane = UCP_NULL_LANE;

ucp_rndv_req_send_ats(req, NULL, rndv_rts_hdr->sreq.reqptr,
UCS_ERR_CANCELED);
}

static void ucp_rndv_unexp_cancel(ucp_worker_h worker,
ucp_rndv_rts_hdr_t *rndv_rts_hdr)
{
const ucp_rndv_rts_hdr_t *rdesc_rts_hdr;

ucp_recv_desc_t *rdesc;
ucs_list_link_t *list;

ucs_warn("rndv cancel remote request 0x%lx ep 0x%lx",
rndv_rts_hdr->sreq.reqptr, rndv_rts_hdr->sreq.ep_ptr);

list = ucp_tag_unexp_get_list_for_tag(&worker->tm, rndv_rts_hdr->super.tag);
ucs_list_for_each(rdesc, list, tag_list[UCP_RDESC_HASH_LIST]) {
rdesc_rts_hdr = (const void*)(rdesc + 1);
if ((rdesc->flags & UCP_RECV_DESC_FLAG_RNDV) &&
(ucp_rdesc_get_tag(rdesc) == rndv_rts_hdr->super.tag) &&
(rdesc_rts_hdr->sreq.ep_ptr == rndv_rts_hdr->sreq.ep_ptr) &&
(rdesc_rts_hdr->sreq.reqptr == rndv_rts_hdr->sreq.reqptr))
{
ucs_trace_req("canceling unexp rdesc " UCP_RECV_DESC_FMT " with "
"tag %"PRIx64, UCP_RECV_DESC_ARG(rdesc),
ucp_rdesc_get_tag(rdesc));
ucp_tag_unexp_remove(rdesc);
ucp_rdnv_send_cancel_ack(worker, rndv_rts_hdr);
ucp_recv_desc_release(rdesc);
return;
}
}
}

ucs_status_t ucp_rndv_process_rts(void *arg, void *data, size_t length,
unsigned tl_flags)
{
Expand All @@ -920,6 +1049,11 @@ ucs_status_t ucp_rndv_process_rts(void *arg, void *data, size_t length,
ucp_request_t *rreq;
ucs_status_t status;

if (rndv_rts_hdr->status == UCS_ERR_CANCELED) {
ucp_rndv_unexp_cancel(worker, rndv_rts_hdr);
return UCS_OK;
}

rreq = ucp_tag_exp_search(&worker->tm, rndv_rts_hdr->super.tag);
if (rreq != NULL) {
ucp_rndv_matched(worker, rreq, rndv_rts_hdr);
Expand Down Expand Up @@ -1528,9 +1662,9 @@ static void ucp_rndv_dump(ucp_worker_h worker, uct_am_trace_type_t type,
case UCP_AM_ID_RNDV_RTS:
ucs_assert(rndv_rts_hdr->sreq.ep_ptr != 0);
snprintf(buffer, max, "RNDV_RTS tag %"PRIx64" ep_ptr %lx sreq 0x%lx "
"address 0x%"PRIx64" size %zu", rndv_rts_hdr->super.tag,
"address 0x%"PRIx64" size %zu status %d", rndv_rts_hdr->super.tag,
rndv_rts_hdr->sreq.ep_ptr, rndv_rts_hdr->sreq.reqptr,
rndv_rts_hdr->address, rndv_rts_hdr->size);
rndv_rts_hdr->address, rndv_rts_hdr->size, rndv_rts_hdr->status);
if (rndv_rts_hdr->address) {
ucp_rndv_dump_rkey(rndv_rts_hdr + 1, buffer + strlen(buffer),
max - strlen(buffer));
Expand Down
5 changes: 5 additions & 0 deletions src/ucp/tag/rndv.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ typedef struct {
ucp_request_hdr_t sreq; /* send request on the rndv initiator side */
uint64_t address; /* holds the address of the data buffer on the sender's side */
size_t size; /* size of the data for sending */
ucs_status_t status;
/* packed rkeys follow */
} UCS_S_PACKED ucp_rndv_rts_hdr_t;

Expand All @@ -48,6 +49,8 @@ typedef struct {

ucs_status_t ucp_tag_send_start_rndv(ucp_request_t *req);

void ucp_tag_rndv_cancel(ucp_request_t *sreq);

void ucp_rndv_matched(ucp_worker_h worker, ucp_request_t *req,
const ucp_rndv_rts_hdr_t *rndv_rts_hdr);

Expand All @@ -60,6 +63,8 @@ size_t ucp_tag_rndv_rts_pack(void *dest, void *arg);

ucs_status_t ucp_tag_rndv_reg_send_buffer(ucp_request_t *sreq);

void ucp_ep_complete_rndv_reqs(ucp_ep_h ep);

static UCS_F_ALWAYS_INLINE int ucp_rndv_is_get_zcopy(ucs_memory_type_t mem_type,
ucp_rndv_mode_t rndv_mode)
{
Expand Down

0 comments on commit a010441

Please sign in to comment.