diff --git a/src/uct/base/uct_md.h b/src/uct/base/uct_md.h index 3036354ecb9..e1276880bc0 100644 --- a/src/uct/base/uct_md.h +++ b/src/uct/base/uct_md.h @@ -144,6 +144,7 @@ struct uct_md_ops { int (*is_sockaddr_accessible)(uct_md_h md, const ucs_sock_addr_t *sockaddr, uct_sockaddr_accessibility_t mode); + int (*is_mem_type_owned)(uct_md_h md, void *addr, size_t length); }; diff --git a/src/uct/ib/base/ib_md.c b/src/uct/ib/base/ib_md.c index 82e60c8038b..791933e12ea 100644 --- a/src/uct/ib/base/ib_md.c +++ b/src/uct/ib/base/ib_md.c @@ -844,15 +844,15 @@ static ucs_status_t uct_ib_rkey_unpack(uct_md_component_t *mdc, static void uct_ib_md_close(uct_md_h md); static uct_md_ops_t uct_ib_md_ops = { - .close = uct_ib_md_close, - .query = uct_ib_md_query, - .mem_alloc = uct_ib_mem_alloc, - .mem_free = uct_ib_mem_free, - .mem_reg = uct_ib_mem_reg, - .mem_dereg = uct_ib_mem_dereg, - .mem_advise = uct_ib_mem_advise, - .mkey_pack = uct_ib_mkey_pack, - .is_mem_type_owned = (void *)ucs_empty_function_return_zero, + .close = uct_ib_md_close, + .query = uct_ib_md_query, + .mem_alloc = uct_ib_mem_alloc, + .mem_free = uct_ib_mem_free, + .mem_reg = uct_ib_mem_reg, + .mem_dereg = uct_ib_mem_dereg, + .mem_advise = uct_ib_mem_advise, + .mkey_pack = uct_ib_mkey_pack, + .is_mem_type_owned = (void*)ucs_empty_function_return_zero, }; static inline uct_ib_rcache_region_t* uct_ib_rcache_region_from_memh(uct_mem_h memh) @@ -898,17 +898,17 @@ static ucs_status_t uct_ib_mem_rcache_dereg(uct_md_h uct_md, uct_mem_h memh) } static uct_md_ops_t uct_ib_md_rcache_ops = { - .close = uct_ib_md_close, - .query = uct_ib_md_query, - .mem_alloc = uct_ib_mem_alloc, - .mem_free = uct_ib_mem_free, - .mem_reg = uct_ib_mem_rcache_reg, - .mem_dereg = uct_ib_mem_rcache_dereg, - .mem_advise = uct_ib_mem_advise, - .mkey_pack = uct_ib_mkey_pack, + .close = uct_ib_md_close, + .query = uct_ib_md_query, + .mem_alloc = uct_ib_mem_alloc, + .mem_free = uct_ib_mem_free, + .mem_reg = uct_ib_mem_rcache_reg, + .mem_dereg = uct_ib_mem_rcache_dereg, + .mem_advise = uct_ib_mem_advise, + .mkey_pack = uct_ib_mkey_pack, + .is_mem_type_owned = (void*)ucs_empty_function_return_zero, }; - static ucs_status_t uct_ib_rcache_mem_reg_cb(void *context, ucs_rcache_t *rcache, void *arg, ucs_rcache_region_t *rregion) { @@ -957,6 +957,20 @@ static ucs_rcache_ops_t uct_ib_rcache_ops = { .dump_region = uct_ib_rcache_dump_region_cb }; +static ucs_status_t uct_ib_md_odp_query(uct_md_h uct_md, uct_md_attr_t *md_attr) +{ + ucs_status_t status; + + status = uct_ib_md_query(uct_md, md_attr); + if (status != UCS_OK) { + return status; + } + + /* ODP supports only host memory */ + md_attr->cap.reg_mem_types &= UCS_BIT(UCT_MD_MEM_TYPE_HOST); + return UCS_OK; +} + static ucs_status_t uct_ib_mem_global_odp_reg(uct_md_h uct_md, void *address, size_t length, unsigned flags, uct_mem_h *memh_p) @@ -987,14 +1001,15 @@ static ucs_status_t uct_ib_mem_global_odp_dereg(uct_md_h uct_md, uct_mem_h memh) } static uct_md_ops_t UCS_V_UNUSED uct_ib_md_global_odp_ops = { - .close = uct_ib_md_close, - .query = uct_ib_md_query, - .mem_alloc = uct_ib_mem_alloc, - .mem_free = uct_ib_mem_free, - .mem_reg = uct_ib_mem_global_odp_reg, - .mem_dereg = uct_ib_mem_global_odp_dereg, - .mem_advise = uct_ib_mem_advise, - .mkey_pack = uct_ib_mkey_pack, + .close = uct_ib_md_close, + .query = uct_ib_md_odp_query, + .mem_alloc = uct_ib_mem_alloc, + .mem_free = uct_ib_mem_free, + .mem_reg = uct_ib_mem_global_odp_reg, + .mem_dereg = uct_ib_mem_global_odp_dereg, + .mem_advise = uct_ib_mem_advise, + .mkey_pack = uct_ib_mkey_pack, + .is_mem_type_owned = (void*)ucs_empty_function_return_zero, }; static void uct_ib_make_md_name(char md_name[UCT_MD_NAME_MAX], struct ibv_device *device)