From 818bac689d6bce726d868761249a77147084641a Mon Sep 17 00:00:00 2001 From: Yossi Itigin Date: Sat, 11 Dec 2021 15:58:52 +0200 Subject: [PATCH] UCP/UCS/UCT: Fix memtype_cache region info after merge Instead of tracking base_address/alloc_length in each memtype cache region, use start/end fields to track whole-allocation range. This makes sure the region info after merge stays correct. --- NEWS | 1 + src/ucp/core/ucp_mm.c | 17 +++- src/ucs/memory/memtype_cache.c | 123 ++++++++++++++------------- src/ucs/memory/memtype_cache.h | 40 ++++----- src/uct/cuda/base/cuda_md.c | 6 +- test/gtest/ucs/test_memtype_cache.cc | 6 +- 6 files changed, 109 insertions(+), 84 deletions(-) diff --git a/NEWS b/NEWS index ba2178f04d7..72b9e5d7d3b 100644 --- a/NEWS +++ b/NEWS @@ -10,6 +10,7 @@ ## 1.12.1 (TBD) ### Bugfixes * Fixed memory hooks for Cuda 11.5 +* Fixed memory type cache merge ## 1.12.0 (January 12, 2022) ### Features: diff --git a/src/ucp/core/ucp_mm.c b/src/ucp/core/ucp_mm.c index 87477268729..536b5861e06 100644 --- a/src/ucp/core/ucp_mm.c +++ b/src/ucp/core/ucp_mm.c @@ -40,6 +40,7 @@ ucs_status_t ucp_mem_rereg_mds(ucp_context_h context, ucp_md_map_t reg_md_map, uct_mem_h *prev_uct_memh; ucp_md_map_t new_md_map; const uct_md_attr_t *md_attr; + void *end_address UCS_V_UNUSED; unsigned prev_num_memh; unsigned md_index; ucs_status_t status; @@ -119,13 +120,23 @@ ucs_status_t ucp_mem_rereg_mds(ucp_context_h context, ucp_md_map_t reg_md_map, continue; } - base_address = address; - reg_length = length; - if (context->config.ext.reg_whole_alloc_bitmap & UCS_BIT(mem_type)) { ucp_memory_detect_internal(context, address, length, &mem_info); base_address = mem_info.base_address; reg_length = mem_info.alloc_length; + end_address = UCS_PTR_BYTE_OFFSET(base_address, reg_length); + ucs_trace("extending %p..%p to %p..%p", address, + UCS_PTR_BYTE_OFFSET(address, length), base_address, + end_address); + ucs_assertv(base_address <= address, + "base_address=%p address=%p", base_address, + 
address); + ucs_assertv(end_address >= UCS_PTR_BYTE_OFFSET(address, length), + "end_address=%p address+length=%p", end_address, + UCS_PTR_BYTE_OFFSET(address, length)); + } else { + base_address = address; + reg_length = length; } /* MD supports registration, register new memh on it */ diff --git a/src/ucs/memory/memtype_cache.c b/src/ucs/memory/memtype_cache.c index 23c87386661..17988706a06 100644 --- a/src/ucs/memory/memtype_cache.c +++ b/src/ucs/memory/memtype_cache.c @@ -27,11 +27,25 @@ ucs_spinlock_t ucs_memtype_cache_global_instance_lock; ucs_memtype_cache_t *ucs_memtype_cache_global_instance = NULL; +#define UCS_MEMTYPE_CACHE_REGION_FMT UCS_PGT_REGION_FMT " %s dev %s" +#define UCS_MEMTYPE_CACHE_REGION_ARG(_region) \ + UCS_PGT_REGION_ARG(&(_region)->super), \ + ucs_memory_type_names[(_region)->mem_type], \ + ucs_topo_sys_device_get_name((_region)->sys_dev) + typedef enum { UCS_MEMTYPE_CACHE_ACTION_SET_MEMTYPE, UCS_MEMTYPE_CACHE_ACTION_REMOVE } ucs_memtype_cache_action_t; +struct ucs_memtype_cache_region { + ucs_pgt_region_t super; /**< Base class - page table region */ + ucs_list_link_t list; /**< List element */ + ucs_memory_type_t mem_type; /**< Memory type, use uint8 for compact size */ + ucs_sys_device_t sys_dev; /**< System device index */ + }; + + static UCS_CLASS_INIT_FUNC(ucs_memtype_cache_t); static UCS_CLASS_CLEANUP_FUNC(ucs_memtype_cache_t); @@ -79,14 +93,6 @@ ucs_memory_info_set_unknown(ucs_memory_info_t *mem_info) mem_info->alloc_length = -1; } -void ucs_memory_info_set_host(ucs_memory_info_t *mem_info) -{ - mem_info->type = UCS_MEMORY_TYPE_HOST; - mem_info->sys_dev = UCS_SYS_DEVICE_ID_UNKNOWN; - mem_info->base_address = NULL; - mem_info->alloc_length = -1; -} - static ucs_pgt_dir_t *ucs_memtype_cache_pgt_dir_alloc(const ucs_pgtable_t *pgtable) { void *ptr; @@ -110,7 +116,8 @@ static void ucs_memtype_cache_pgt_dir_release(const ucs_pgtable_t *pgtable, */ static void ucs_memtype_cache_insert(ucs_memtype_cache_t *memtype_cache, ucs_pgt_addr_t 
start, ucs_pgt_addr_t end, - const ucs_memory_info_t *mem_info) + ucs_memory_type_t mem_type, + ucs_sys_device_t sys_dev) { ucs_memtype_cache_region_t *region; ucs_status_t status; @@ -131,23 +138,21 @@ static void ucs_memtype_cache_insert(ucs_memtype_cache_t *memtype_cache, region->super.start = start; region->super.end = end; - region->mem_info = *mem_info; + region->mem_type = mem_type; + region->sys_dev = sys_dev; status = UCS_PROFILE_CALL(ucs_pgtable_insert, &memtype_cache->pgtable, &region->super); if (status != UCS_OK) { - ucs_error("failed to insert region " UCS_PGT_REGION_FMT ": %s", - UCS_PGT_REGION_ARG(&region->super), ucs_status_string(status)); + ucs_error("failed to insert " UCS_MEMTYPE_CACHE_REGION_FMT ": %s", + UCS_MEMTYPE_CACHE_REGION_ARG(region), + ucs_status_string(status)); ucs_free(region); return; } - ucs_trace("memtype_cache: insert " UCS_PGT_REGION_FMT " mem_type %s dev %s" - " base_addr %p alloc_length %ld", - UCS_PGT_REGION_ARG(&region->super), - ucs_memory_type_names[mem_info->type], - ucs_topo_sys_device_get_name(mem_info->sys_dev), - mem_info->base_address, mem_info->alloc_length); + ucs_trace("memtype_cache: insert " UCS_MEMTYPE_CACHE_REGION_FMT, + UCS_MEMTYPE_CACHE_REGION_ARG(region)); } static void ucs_memtype_cache_region_collect_callback(const ucs_pgtable_t *pgtable, @@ -161,15 +166,15 @@ static void ucs_memtype_cache_region_collect_callback(const ucs_pgtable_t *pgtab } UCS_PROFILE_FUNC_VOID(ucs_memtype_cache_update_internal, - (memtype_cache, address, size, mem_info, action), - ucs_memtype_cache_t *memtype_cache, - const void *address, size_t size, - const ucs_memory_info_t *mem_info, + (memtype_cache, address, size, mem_type, sys_dev, action), + ucs_memtype_cache_t *memtype_cache, const void *address, + size_t size, ucs_memory_type_t mem_type, + ucs_sys_device_t sys_dev, ucs_memtype_cache_action_t action) { + ucs_pgt_addr_t start, end, search_start, search_end; ucs_memtype_cache_region_t *region, *tmp; UCS_LIST_HEAD(region_list); -
ucs_pgt_addr_t start, end, search_start, search_end; ucs_status_t status; if (!size) { @@ -179,13 +184,11 @@ UCS_PROFILE_FUNC_VOID(ucs_memtype_cache_update_internal, start = ucs_align_down_pow2((uintptr_t)address, UCS_PGT_ADDR_ALIGN); end = ucs_align_up_pow2 ((uintptr_t)address + size, UCS_PGT_ADDR_ALIGN); - ucs_trace("%s: [0x%lx..0x%lx] mem_type %s dev %s" - " base_addr %p alloc_length %ld", + ucs_trace("%s: [0x%lx..0x%lx] mem_type %s dev %s", (action == UCS_MEMTYPE_CACHE_ACTION_SET_MEMTYPE) ? "update" : "remove", - start, end, ucs_memory_type_names[mem_info->type], - ucs_topo_sys_device_get_name(mem_info->sys_dev), - mem_info->base_address, mem_info->alloc_length); + start, end, ucs_memory_type_names[mem_type], + ucs_topo_sys_device_get_name(sys_dev)); search_start = start; search_end = end - 1; @@ -198,11 +201,14 @@ UCS_PROFILE_FUNC_VOID(ucs_memtype_cache_update_internal, &region_list); ucs_list_for_each_safe(region, tmp, &region_list, list) { if (action == UCS_MEMTYPE_CACHE_ACTION_SET_MEMTYPE) { - if (region->mem_info.type == mem_info->type) { + if (region->mem_type == mem_type) { /* merge current region with overlapping or adjacent regions * of same memory type */ start = ucs_min(start, region->super.start); end = ucs_max(end, region->super.end); + ucs_trace("merge with " UCS_MEMTYPE_CACHE_REGION_FMT + ": [0x%lx..0x%lx]", + UCS_MEMTYPE_CACHE_REGION_ARG(region), start, end); } else if ((region->super.end < start) || (region->super.start >= end)) { /* ignore regions which are not really overlapping and can't @@ -214,23 +220,18 @@ UCS_PROFILE_FUNC_VOID(ucs_memtype_cache_update_internal, status = ucs_pgtable_remove(&memtype_cache->pgtable, &region->super); if (status != UCS_OK) { - ucs_error("failed to remove " UCS_PGT_REGION_FMT - " from memtype_cache: %s", - UCS_PGT_REGION_ARG(&region->super), + ucs_error("failed to remove " UCS_MEMTYPE_CACHE_REGION_FMT ": %s", + UCS_MEMTYPE_CACHE_REGION_ARG(region), ucs_status_string(status)); goto out_unlock; } - ucs_trace("memtype_cache: 
removed " UCS_PGT_REGION_FMT " %s dev %s" - " base_addr %p alloc_length %ld", - UCS_PGT_REGION_ARG(&region->super), - ucs_memory_type_names[region->mem_info.type], - ucs_topo_sys_device_get_name(region->mem_info.sys_dev), - mem_info->base_address, mem_info->alloc_length); + ucs_trace("memtype_cache: removed " UCS_MEMTYPE_CACHE_REGION_FMT, + UCS_MEMTYPE_CACHE_REGION_ARG(region)); } if (action == UCS_MEMTYPE_CACHE_ACTION_SET_MEMTYPE) { - ucs_memtype_cache_insert(memtype_cache, start, end, mem_info); + ucs_memtype_cache_insert(memtype_cache, start, end, mem_type, sys_dev); } /* slice old regions by the new region, to preserve the previous memory type @@ -240,12 +241,12 @@ UCS_PROFILE_FUNC_VOID(ucs_memtype_cache_update_internal, if (start > region->super.start) { /* create previous region */ ucs_memtype_cache_insert(memtype_cache, region->super.start, start, - &region->mem_info); + region->mem_type, region->sys_dev); } if (end < region->super.end) { /* create next region */ ucs_memtype_cache_insert(memtype_cache, end, region->super.end, - &region->mem_info); + region->mem_type, region->sys_dev); } ucs_free(region); @@ -256,36 +257,29 @@ UCS_PROFILE_FUNC_VOID(ucs_memtype_cache_update_internal, } void ucs_memtype_cache_update(const void *address, size_t size, - const ucs_memory_info_t *mem_info) + ucs_memory_type_t mem_type, + ucs_sys_device_t sys_dev) { if (ucs_memtype_cache_global_instance == NULL) { return; } ucs_memtype_cache_update_internal(ucs_memtype_cache_global_instance, - address, size, mem_info, + address, size, mem_type, sys_dev, UCS_MEMTYPE_CACHE_ACTION_SET_MEMTYPE); } void ucs_memtype_cache_remove(const void *address, size_t size) { - ucs_memory_info_t mem_info; - - ucs_memory_info_set_unknown(&mem_info); ucs_memtype_cache_update_internal(ucs_memtype_cache_global_instance, - address, size, &mem_info, + address, size, UCS_MEMORY_TYPE_UNKNOWN, + UCS_SYS_DEVICE_ID_UNKNOWN, UCS_MEMTYPE_CACHE_ACTION_REMOVE); } static void 
ucs_memtype_cache_event_callback(ucm_event_type_t
event_type, - ucm_event_t *event, void *arg) + ucm_event_t *event, void *arg) { - ucs_memory_info_t mem_info = { - .type = event->mem_type.mem_type, - .sys_dev = UCS_SYS_DEVICE_ID_UNKNOWN, - .base_address = event->mem_type.address, - .alloc_length = event->mem_type.size, - }; ucs_memtype_cache_action_t action; if (event_type & UCM_EVENT_MEM_TYPE_ALLOC) { @@ -296,8 +290,14 @@ static void ucs_memtype_cache_event_callback(ucm_event_type_t event_type, return; } + ucs_trace("dispatching mem event %d address %p length %zu mem_type %s", + event_type, event->mem_type.address, event->mem_type.size, + ucs_memory_type_names[event->mem_type.mem_type]); + ucs_memtype_cache_update_internal(arg, event->mem_type.address, - event->mem_type.size, &mem_info, action); + event->mem_type.size, + event->mem_type.mem_type, + UCS_SYS_DEVICE_ID_UNKNOWN, action); } static void ucs_memtype_cache_purge(ucs_memtype_cache_t *memtype_cache) @@ -333,14 +333,23 @@ UCS_PROFILE_FUNC(ucs_status_t, ucs_memtype_cache_lookup, pgt_region = UCS_PROFILE_CALL(ucs_pgtable_lookup, &memtype_cache->pgtable, start); if (pgt_region == NULL) { + ucs_trace("address 0x%lx not found", start); status = UCS_ERR_NO_ELEM; goto out_unlock; } + region = ucs_derived_of(pgt_region, ucs_memtype_cache_region_t); if (ucs_likely((start + size) <= pgt_region->end)) { - region = ucs_derived_of(pgt_region, ucs_memtype_cache_region_t); - *mem_info = region->mem_info; + mem_info->base_address = (void*)region->super.start; + mem_info->alloc_length = region->super.end - region->super.start; + mem_info->type = region->mem_type; + mem_info->sys_dev = region->sys_dev; + ucs_trace_data("0x%lx..0x%lx found in " UCS_MEMTYPE_CACHE_REGION_FMT, + start, start + size, + UCS_MEMTYPE_CACHE_REGION_ARG(region)); } else { + ucs_trace("0x%lx..0x%lx not contained in " UCS_MEMTYPE_CACHE_REGION_FMT, + start, start + size, UCS_MEMTYPE_CACHE_REGION_ARG(region)); ucs_memory_info_set_unknown(mem_info); } status = UCS_OK; diff --git 
a/src/ucs/memory/memtype_cache.h b/src/ucs/memory/memtype_cache.h index 55d3bfe9529..37d605b3707 100644 --- a/src/ucs/memory/memtype_cache.h +++ b/src/ucs/memory/memtype_cache.h @@ -29,21 +29,13 @@ extern ucs_memtype_cache_t *ucs_memtype_cache_global_instance; /* Memory information record */ typedef struct ucs_memory_info { - ucs_memory_type_t type; /**< Memory type, use uint8 for compact size */ + ucs_memory_type_t type; /**< Memory type */ ucs_sys_device_t sys_dev; /**< System device index */ void *base_address; /**< Base address of the underlying allocation */ size_t alloc_length; /**< Whole length of the underlying allocation */ } ucs_memory_info_t; -struct ucs_memtype_cache_region { - ucs_pgt_region_t super; /**< Base class - page table region */ - ucs_list_link_t list; /**< List element */ - ucs_memory_info_t mem_info; /**< Memory type and system device the address - belongs to */ -}; - - struct ucs_memtype_cache { pthread_rwlock_t lock; /**< protests the page table */ ucs_pgtable_t pgtable; /**< Page table to hold the regions */ @@ -73,11 +65,14 @@ ucs_status_t ucs_memtype_cache_lookup(const void *address, size_t size, * * @param [in] address Start address to update. * @param [in] size Size of the memory to update. - * @param [in] mem_info Set the memory info of the address range to this + * @param [in] mem_type Set the memory type of the address range to this * value. + * @param [in] sys_dev Set the system device of the address range to + * this value. */ void ucs_memtype_cache_update(const void *address, size_t size, - const ucs_memory_info_t *mem_info); + ucs_memory_type_t mem_type, + ucs_sys_device_t sys_dev); /** @@ -89,14 +84,6 @@ void ucs_memtype_cache_update(const void *address, size_t size, void ucs_memtype_cache_remove(const void *address, size_t size); -/** - * Helper function to set memory info structure to host memory type. - * - * @param [out] mem_info Pointer to memory info structure. 
- */ -void ucs_memory_info_set_host(ucs_memory_info_t *mem_info); - - /** * Find if global memtype_cache is empty. * @@ -108,6 +95,21 @@ static UCS_F_ALWAYS_INLINE int ucs_memtype_cache_is_empty() (ucs_memtype_cache_global_instance->pgtable.num_regions == 0); } + +/** + * Helper function to set memory info structure to host memory type. + * + * @param [out] mem_info Pointer to memory info structure. + */ +static UCS_F_ALWAYS_INLINE void +ucs_memory_info_set_host(ucs_memory_info_t *mem_info) +{ + mem_info->type = UCS_MEMORY_TYPE_HOST; + mem_info->sys_dev = UCS_SYS_DEVICE_ID_UNKNOWN; + mem_info->base_address = NULL; + mem_info->alloc_length = -1; +} + END_C_DECLS #endif diff --git a/src/uct/cuda/base/cuda_md.c b/src/uct/cuda/base/cuda_md.c index 141d7efc2f0..d7a9f6b8daa 100644 --- a/src/uct/cuda/base/cuda_md.c +++ b/src/uct/cuda/base/cuda_md.c @@ -163,6 +163,9 @@ uct_cuda_base_query_attributes(uct_cuda_copy_md_t *md, const void *address, return UCS_ERR_INVALID_ADDR; } + ucs_trace("query address %p: 0x%llx..0x%llx length %zu", address, + base_address, base_address + alloc_length, alloc_length); + if (md->config.alloc_whole_reg == UCS_CONFIG_AUTO) { total_bytes = uct_cuda_base_get_total_device_mem(cuda_device); if (alloc_length > (total_bytes * md->config.max_reg_ratio)) { @@ -241,7 +244,8 @@ ucs_status_t uct_cuda_base_mem_query(uct_md_h tl_md, const void *address, } ucs_memtype_cache_update(addr_mem_info.base_address, - addr_mem_info.alloc_length, &addr_mem_info); + addr_mem_info.alloc_length, addr_mem_info.type, + addr_mem_info.sys_dev); } else { addr_mem_info = default_mem_info; } diff --git a/test/gtest/ucs/test_memtype_cache.cc b/test/gtest/ucs/test_memtype_cache.cc index 0c9e79c398f..e6acaea1951 100644 --- a/test/gtest/ucs/test_memtype_cache.cc +++ b/test/gtest/ucs/test_memtype_cache.cc @@ -276,10 +276,8 @@ class test_memtype_cache : public ucs::test_with_param { return; } - ucs_memory_info_t mem_info; - mem_info.type = mem_type; - mem_info.sys_dev = 
UCS_SYS_DEVICE_ID_UNKNOWN; - ucs_memtype_cache_update(ptr, size, &mem_info); + ucs_memtype_cache_update(ptr, size, mem_type, + UCS_SYS_DEVICE_ID_UNKNOWN); } void memtype_cache_update(const mem_buffer &b) {