diff --git a/src/ucp/core/ucp_proxy_ep.c b/src/ucp/core/ucp_proxy_ep.c
index bc612ea5311..114988ee78a 100644
--- a/src/ucp/core/ucp_proxy_ep.c
+++ b/src/ucp/core/ucp_proxy_ep.c
@@ -216,6 +216,7 @@ void ucp_proxy_ep_replace(ucp_proxy_ep_t *proxy_ep)
     ucp_ep_h ucp_ep = proxy_ep->ucp_ep;
     ucp_lane_index_t lane;
     uct_ep_h tl_ep = NULL;
+    ucs_status_t status;
 
     ucs_assert(proxy_ep->uct_ep != NULL);
     for (lane = 0; lane < ucp_ep_num_lanes(ucp_ep); ++lane) {
@@ -224,6 +225,11 @@ void ucp_proxy_ep_replace(ucp_proxy_ep_t *proxy_ep)
             ucp_ep->uct_eps[lane] = proxy_ep->uct_ep;
             tl_ep = ucp_ep->uct_eps[lane];
             proxy_ep->uct_ep = NULL;
+            status = uct_ep_enable_keep_alive(tl_ep, 1);
+            if (status != UCS_OK) {
+                ucs_diag("ep %p: uct_ep_enable_keep_alive(tl_ep=%p, 1) failed %s",
+                         ucp_ep, tl_ep, ucs_status_string(status));
+            }
         }
     }
 
diff --git a/src/ucp/wireup/wireup_ep.c b/src/ucp/wireup/wireup_ep.c
index 2a97c188d18..bc3de9122d1 100644
--- a/src/ucp/wireup/wireup_ep.c
+++ b/src/ucp/wireup/wireup_ep.c
@@ -319,6 +319,7 @@ UCS_CLASS_INIT_FUNC(ucp_wireup_ep_t, ucp_ep_h ucp_ep)
 {
     static uct_iface_ops_t ops = {
         .ep_connect_to_ep    = ucp_wireup_ep_connect_to_ep,
+        .ep_enable_keep_alive= (uct_ep_enable_keep_alive_func_t)ucs_empty_function_return_unsupported,
         .ep_flush            = ucp_wireup_ep_flush,
         .ep_query            = (uct_ep_query_func_t)ucs_empty_function_return_not_connected,
         .ep_destroy          = UCS_CLASS_DELETE_FUNC_NAME(ucp_wireup_ep_t),
diff --git a/src/uct/api/tl.h b/src/uct/api/tl.h
index b0dcbe84964..d5226123dfb 100644
--- a/src/uct/api/tl.h
+++ b/src/uct/api/tl.h
@@ -230,6 +230,8 @@ typedef ucs_status_t (*uct_ep_connect_to_ep_func_t)(uct_ep_h ep,
                                                     const uct_device_addr_t *dev_addr,
                                                     const uct_ep_addr_t *ep_addr);
 
+typedef ucs_status_t (*uct_ep_enable_keep_alive_func_t)(uct_ep_h ep, int enable);
+
 typedef ucs_status_t (*uct_iface_accept_func_t)(uct_iface_h iface,
                                                 uct_conn_request_h conn_request);
 
@@ -341,6 +343,7 @@ typedef struct uct_iface_ops {
     uct_ep_destroy_func_t               ep_destroy;
     uct_ep_get_address_func_t           ep_get_address;
     uct_ep_connect_to_ep_func_t         ep_connect_to_ep;
+    uct_ep_enable_keep_alive_func_t     ep_enable_keep_alive;
     uct_iface_accept_func_t             iface_accept;
     uct_iface_reject_func_t             iface_reject;
 
diff --git a/src/uct/api/uct.h b/src/uct/api/uct.h
index 1067467867b..8802e7e4d34 100644
--- a/src/uct/api/uct.h
+++ b/src/uct/api/uct.h
@@ -2064,6 +2064,17 @@ ucs_status_t uct_ep_connect_to_ep(uct_ep_h ep, const uct_device_addr_t *dev_addr
                                   const uct_ep_addr_t *ep_addr);
 
 
+/**
+ * @ingroup UCT_RESOURCE
+ * @brief Enable or disable the keep-alive protocol on an endpoint.
+ *
+ * @param [in] ep      Endpoint on which to enable or disable keep-alive.
+ * @param [in] enable  1 - enable keep-alive, 0 - disable keep-alive.
+ * @return UCS_OK               In case of success.
+ *         UCS_ERR_UNSUPPORTED  If the transport does not support keep-alive.
+ */
+ucs_status_t uct_ep_enable_keep_alive(uct_ep_h ep, int enable);
+
 /**
  * @ingroup UCT_MD
  * @brief Query for memory domain attributes.
diff --git a/src/uct/base/uct_iface.c b/src/uct/base/uct_iface.c
index 8e4e2067ae0..51e792e172b 100644
--- a/src/uct/base/uct_iface.c
+++ b/src/uct/base/uct_iface.c
@@ -367,6 +367,7 @@ ucs_status_t uct_set_ep_failed(ucs_class_t *cls, uct_ep_h tl_ep,
     ops->ep_fence            = (uct_ep_fence_func_t)ucs_empty_function_return_ep_timeout;
     ops->ep_check            = (uct_ep_check_func_t)ucs_empty_function_return_ep_timeout;
     ops->ep_connect_to_ep    = (uct_ep_connect_to_ep_func_t)ucs_empty_function_return_ep_timeout;
+    ops->ep_enable_keep_alive= (uct_ep_enable_keep_alive_func_t)ucs_empty_function_return_ep_timeout;
     ops->ep_query            = (uct_ep_query_func_t)ucs_empty_function_return_ep_timeout;
     ops->ep_destroy          = uct_ep_failed_destroy;
     ops->ep_get_address      = (uct_ep_get_address_func_t)ucs_empty_function_return_ep_timeout;
@@ -558,6 +559,11 @@ ucs_status_t uct_ep_connect_to_ep(uct_ep_h ep, const uct_device_addr_t *dev_addr
     return ep->iface->ops.ep_connect_to_ep(ep, dev_addr, ep_addr);
 }
 
+ucs_status_t uct_ep_enable_keep_alive(uct_ep_h ep, int enable)
+{
+    return ep->iface->ops.ep_enable_keep_alive(ep, enable);
+}
+
 ucs_status_t uct_cm_client_ep_conn_notify(uct_ep_h ep)
 {
     return ep->iface->ops.cm_ep_conn_notify(ep);
diff --git a/src/uct/ib/rc/accel/rc_mlx5.h b/src/uct/ib/rc/accel/rc_mlx5.h
index 388ee2f2148..4eb8ea7ab0a 100644
--- a/src/uct/ib/rc/accel/rc_mlx5.h
+++ b/src/uct/ib/rc/accel/rc_mlx5.h
@@ -125,6 +125,8 @@ ucs_status_t uct_rc_mlx5_ep_connect_to_ep(uct_ep_h tl_ep,
                                           const uct_device_addr_t *dev_addr,
                                           const uct_ep_addr_t *ep_addr);
 
+ucs_status_t uct_rc_mlx5_ep_enable_keep_alive(uct_ep_h tl_ep, int enable);
+
 unsigned uct_rc_mlx5_iface_progress(void *arg);
 
 ucs_status_t uct_rc_mlx5_ep_tag_eager_short(uct_ep_h tl_ep, uct_tag_t tag,
diff --git a/src/uct/ib/rc/accel/rc_mlx5_ep.c b/src/uct/ib/rc/accel/rc_mlx5_ep.c
index 1340039f14e..db6d6b9cd9e 100644
--- a/src/uct/ib/rc/accel/rc_mlx5_ep.c
+++ b/src/uct/ib/rc/accel/rc_mlx5_ep.c
@@ -701,11 +701,18 @@ ucs_status_t uct_rc_mlx5_ep_connect_to_ep(uct_ep_h tl_ep,
     }
 
     ep->atomic_mr_offset = uct_ib_md_atomic_offset(rc_addr->atomic_mr_id);
-    ep->connected        = 1;
 
     return UCS_OK;
 }
 
+ucs_status_t uct_rc_mlx5_ep_enable_keep_alive(uct_ep_h tl_ep, int enable)
+{
+    uct_rc_mlx5_ep_t *rc_mlx5_ep = ucs_derived_of(tl_ep, uct_rc_mlx5_ep_t);
+
+    rc_mlx5_ep->connected = enable;
+    return UCS_OK;
+}
+
 #if IBV_HW_TM
 
 ucs_status_t uct_rc_mlx5_ep_tag_rndv_cancel(uct_ep_h tl_ep, void *op)
diff --git a/src/uct/ib/rc/accel/rc_mlx5_iface.c b/src/uct/ib/rc/accel/rc_mlx5_iface.c
index 7316736c215..0cb3e9a7055 100644
--- a/src/uct/ib/rc/accel/rc_mlx5_iface.c
+++ b/src/uct/ib/rc/accel/rc_mlx5_iface.c
@@ -798,6 +798,7 @@ static uct_rc_iface_ops_t uct_rc_mlx5_iface_ops = {
     .ep_destroy               = UCS_CLASS_DELETE_FUNC_NAME(uct_rc_mlx5_ep_t),
     .ep_get_address           = uct_rc_mlx5_ep_get_address,
     .ep_connect_to_ep         = uct_rc_mlx5_ep_connect_to_ep,
+    .ep_enable_keep_alive     = uct_rc_mlx5_ep_enable_keep_alive,
 #if IBV_HW_TM
     .ep_tag_eager_short       = uct_rc_mlx5_ep_tag_eager_short,
     .ep_tag_eager_bcopy       = uct_rc_mlx5_ep_tag_eager_bcopy,
diff --git a/src/uct/ib/rc/verbs/rc_verbs_iface.c b/src/uct/ib/rc/verbs/rc_verbs_iface.c
index 4d23196d8d9..41a9c79db8c 100644
--- a/src/uct/ib/rc/verbs/rc_verbs_iface.c
+++ b/src/uct/ib/rc/verbs/rc_verbs_iface.c
@@ -410,6 +410,7 @@ static uct_rc_iface_ops_t uct_rc_verbs_iface_ops = {
     .ep_destroy               = UCS_CLASS_DELETE_FUNC_NAME(uct_rc_verbs_ep_t),
     .ep_get_address           = uct_rc_verbs_ep_get_address,
     .ep_connect_to_ep         = uct_rc_verbs_ep_connect_to_ep,
+    .ep_enable_keep_alive     = (uct_ep_enable_keep_alive_func_t)ucs_empty_function_return_unsupported,
     .iface_flush              = uct_rc_iface_flush,
     .iface_fence              = uct_rc_iface_fence,
     .iface_progress_enable    = uct_rc_verbs_iface_common_progress_enable,
diff --git a/src/uct/ib/ud/accel/ud_mlx5.c b/src/uct/ib/ud/accel/ud_mlx5.c
index a86d0b5d51e..3492e5830a0 100644
--- a/src/uct/ib/ud/accel/ud_mlx5.c
+++ b/src/uct/ib/ud/accel/ud_mlx5.c
@@ -756,6 +756,7 @@ static uct_ud_iface_ops_t uct_ud_mlx5_iface_ops = {
     .ep_destroy               = uct_ud_ep_disconnect ,
     .ep_get_address           = uct_ud_ep_get_address,
     .ep_connect_to_ep         = uct_ud_mlx5_ep_connect_to_ep,
+    .ep_enable_keep_alive     = (uct_ep_enable_keep_alive_func_t)ucs_empty_function_return_unsupported,
     .iface_flush              = uct_ud_iface_flush,
     .iface_fence              = uct_base_iface_fence,
     .iface_progress_enable    = uct_ud_iface_progress_enable,
diff --git a/src/uct/ib/ud/verbs/ud_verbs.c b/src/uct/ib/ud/verbs/ud_verbs.c
index d8531a2d064..0e957771fd1 100644
--- a/src/uct/ib/ud/verbs/ud_verbs.c
+++ b/src/uct/ib/ud/verbs/ud_verbs.c
@@ -571,6 +571,7 @@ static uct_ud_iface_ops_t uct_ud_verbs_iface_ops = {
     .ep_destroy               = uct_ud_ep_disconnect,
     .ep_get_address           = uct_ud_ep_get_address,
     .ep_connect_to_ep         = uct_ud_verbs_ep_connect_to_ep,
+    .ep_enable_keep_alive     = (uct_ep_enable_keep_alive_func_t)ucs_empty_function_return_unsupported,
     .iface_flush              = uct_ud_iface_flush,
     .iface_fence              = uct_base_iface_fence,
     .iface_progress_enable    = uct_ud_iface_progress_enable,