Skip to content

Commit

Permalink
Merge pull request #7791 from yosefe/topic/ucp-ucs-uct-fix-memtype-ca…
Browse files Browse the repository at this point in the history
…che-region

UCP/UCS/UCT: Fix memtype_cache region info after merge
  • Loading branch information
yosefe authored Jan 27, 2022
2 parents 5f7879d + 647242a commit cda6aae
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 84 deletions.
17 changes: 14 additions & 3 deletions src/ucp/core/ucp_mm.c
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ ucs_status_t ucp_mem_rereg_mds(ucp_context_h context, ucp_md_map_t reg_md_map,
uct_mem_h *prev_uct_memh;
ucp_md_map_t new_md_map;
const uct_md_attr_t *md_attr;
void *end_address UCS_V_UNUSED;
unsigned prev_num_memh;
unsigned md_index;
ucs_status_t status;
Expand Down Expand Up @@ -119,13 +120,23 @@ ucs_status_t ucp_mem_rereg_mds(ucp_context_h context, ucp_md_map_t reg_md_map,
continue;
}

base_address = address;
reg_length = length;

if (context->config.ext.reg_whole_alloc_bitmap & UCS_BIT(mem_type)) {
ucp_memory_detect_internal(context, address, length, &mem_info);
base_address = mem_info.base_address;
reg_length = mem_info.alloc_length;
end_address = UCS_PTR_BYTE_OFFSET(base_address, reg_length);
ucs_trace("extending %p..%p to %p..%p", address,
UCS_PTR_BYTE_OFFSET(address, length), base_address,
end_address);
ucs_assertv(base_address <= address,
"base_address=%p address=%p", base_address,
address);
ucs_assertv(end_address >= UCS_PTR_BYTE_OFFSET(address, length),
"end_address=%p address+length=%p", end_address,
UCS_PTR_BYTE_OFFSET(address, length));
} else {
base_address = address;
reg_length = length;
}

/* MD supports registration, register new memh on it */
Expand Down
123 changes: 66 additions & 57 deletions src/ucs/memory/memtype_cache.c
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,25 @@ static ucs_spinlock_t ucs_memtype_cache_global_instance_lock;
ucs_memtype_cache_t *ucs_memtype_cache_global_instance = NULL;


#define UCS_MEMTYPE_CACHE_REGION_FMT UCS_PGT_REGION_FMT " %s dev %s"
#define UCS_MEMTYPE_CACHE_REGION_ARG(_region) \
UCS_PGT_REGION_ARG(&(_region)->super), \
ucs_memory_type_names[(_region)->mem_type], \
ucs_topo_sys_device_get_name((_region)->sys_dev)

typedef enum {
UCS_MEMTYPE_CACHE_ACTION_SET_MEMTYPE,
UCS_MEMTYPE_CACHE_ACTION_REMOVE
} ucs_memtype_cache_action_t;

struct ucs_memtype_cache_region {
ucs_pgt_region_t super; /**< Base class - page table region */
ucs_list_link_t list; /**< List element */
ucs_memory_type_t mem_type; /**< Memory type, use uint8 for compact size */
ucs_sys_device_t sys_dev; /**< System device index */
};


static UCS_CLASS_INIT_FUNC(ucs_memtype_cache_t);
static UCS_CLASS_CLEANUP_FUNC(ucs_memtype_cache_t);

Expand Down Expand Up @@ -79,14 +93,6 @@ ucs_memory_info_set_unknown(ucs_memory_info_t *mem_info)
mem_info->alloc_length = -1;
}

void ucs_memory_info_set_host(ucs_memory_info_t *mem_info)
{
mem_info->type = UCS_MEMORY_TYPE_HOST;
mem_info->sys_dev = UCS_SYS_DEVICE_ID_UNKNOWN;
mem_info->base_address = NULL;
mem_info->alloc_length = -1;
}

static ucs_pgt_dir_t *ucs_memtype_cache_pgt_dir_alloc(const ucs_pgtable_t *pgtable)
{
void *ptr;
Expand All @@ -110,7 +116,8 @@ static void ucs_memtype_cache_pgt_dir_release(const ucs_pgtable_t *pgtable,
*/
static void ucs_memtype_cache_insert(ucs_memtype_cache_t *memtype_cache,
ucs_pgt_addr_t start, ucs_pgt_addr_t end,
const ucs_memory_info_t *mem_info)
ucs_memory_type_t mem_type,
ucs_sys_device_t sys_dev)
{
ucs_memtype_cache_region_t *region;
ucs_status_t status;
Expand All @@ -131,23 +138,21 @@ static void ucs_memtype_cache_insert(ucs_memtype_cache_t *memtype_cache,

region->super.start = start;
region->super.end = end;
region->mem_info = *mem_info;
region->mem_type = mem_type;
region->sys_dev = sys_dev;

status = UCS_PROFILE_CALL(ucs_pgtable_insert, &memtype_cache->pgtable,
&region->super);
if (status != UCS_OK) {
ucs_error("failed to insert region " UCS_PGT_REGION_FMT ": %s",
UCS_PGT_REGION_ARG(&region->super), ucs_status_string(status));
ucs_error("failed to insert " UCS_MEMTYPE_CACHE_REGION_FMT ": %s",
UCS_MEMTYPE_CACHE_REGION_ARG(region),
ucs_status_string(status));
ucs_free(region);
return;
}

ucs_trace("memtype_cache: insert " UCS_PGT_REGION_FMT " mem_type %s dev %s"
" base_addr %p alloc_length %ld",
UCS_PGT_REGION_ARG(&region->super),
ucs_memory_type_names[mem_info->type],
ucs_topo_sys_device_get_name(mem_info->sys_dev),
mem_info->base_address, mem_info->alloc_length);
ucs_trace("memtype_cache: insert " UCS_MEMTYPE_CACHE_REGION_FMT,
UCS_MEMTYPE_CACHE_REGION_ARG(region));
}

static void ucs_memtype_cache_region_collect_callback(const ucs_pgtable_t *pgtable,
Expand All @@ -161,15 +166,15 @@ static void ucs_memtype_cache_region_collect_callback(const ucs_pgtable_t *pgtab
}

UCS_PROFILE_FUNC_VOID(ucs_memtype_cache_update_internal,
(memtype_cache, address, size, mem_info, action),
ucs_memtype_cache_t *memtype_cache,
const void *address, size_t size,
const ucs_memory_info_t *mem_info,
(memtype_cache, address, size, mem_type, sys_dev, action),
ucs_memtype_cache_t *memtype_cache, const void *address,
size_t size, ucs_memory_type_t mem_type,
ucs_sys_device_t sys_dev,
ucs_memtype_cache_action_t action)
{
ucs_pgt_addr_t start, end, search_start, search_end;
ucs_memtype_cache_region_t *region, *tmp;
UCS_LIST_HEAD(region_list);
ucs_pgt_addr_t start, end, search_start, search_end;
ucs_status_t status;

if (!size) {
Expand All @@ -179,13 +184,11 @@ UCS_PROFILE_FUNC_VOID(ucs_memtype_cache_update_internal,
start = ucs_align_down_pow2((uintptr_t)address, UCS_PGT_ADDR_ALIGN);
end = ucs_align_up_pow2 ((uintptr_t)address + size, UCS_PGT_ADDR_ALIGN);

ucs_trace("%s: [0x%lx..0x%lx] mem_type %s dev %s"
" base_addr %p alloc_length %ld",
ucs_trace("%s: [0x%lx..0x%lx] mem_type %s dev %s",
(action == UCS_MEMTYPE_CACHE_ACTION_SET_MEMTYPE) ? "update" :
"remove",
start, end, ucs_memory_type_names[mem_info->type],
ucs_topo_sys_device_get_name(mem_info->sys_dev),
mem_info->base_address, mem_info->alloc_length);
start, end, ucs_memory_type_names[mem_type],
ucs_topo_sys_device_get_name(sys_dev));

search_start = start;
search_end = end - 1;
Expand All @@ -198,11 +201,14 @@ UCS_PROFILE_FUNC_VOID(ucs_memtype_cache_update_internal,
&region_list);
ucs_list_for_each_safe(region, tmp, &region_list, list) {
if (action == UCS_MEMTYPE_CACHE_ACTION_SET_MEMTYPE) {
if (region->mem_info.type == mem_info->type) {
if (region->mem_type == mem_type) {
/* merge current region with overlapping or adjacent regions
* of same memory type */
start = ucs_min(start, region->super.start);
end = ucs_max(end, region->super.end);
ucs_trace("merge with " UCS_MEMTYPE_CACHE_REGION_FMT
": [0x%lx..0x%lx]",
UCS_MEMTYPE_CACHE_REGION_ARG(region), start, end);
} else if ((region->super.end < start) ||
(region->super.start >= end)) {
/* ignore regions which are not really overlapping and can't
Expand All @@ -214,23 +220,18 @@ UCS_PROFILE_FUNC_VOID(ucs_memtype_cache_update_internal,

status = ucs_pgtable_remove(&memtype_cache->pgtable, &region->super);
if (status != UCS_OK) {
ucs_error("failed to remove " UCS_PGT_REGION_FMT
" from memtype_cache: %s",
UCS_PGT_REGION_ARG(&region->super),
ucs_error("failed to remove " UCS_MEMTYPE_CACHE_REGION_FMT ": %s",
UCS_MEMTYPE_CACHE_REGION_ARG(region),
ucs_status_string(status));
goto out_unlock;
}

ucs_trace("memtype_cache: removed " UCS_PGT_REGION_FMT " %s dev %s"
" base_addr %p alloc_length %ld",
UCS_PGT_REGION_ARG(&region->super),
ucs_memory_type_names[region->mem_info.type],
ucs_topo_sys_device_get_name(region->mem_info.sys_dev),
mem_info->base_address, mem_info->alloc_length);
ucs_trace("memtype_cache: removed " UCS_MEMTYPE_CACHE_REGION_FMT,
UCS_MEMTYPE_CACHE_REGION_ARG(region));
}

if (action == UCS_MEMTYPE_CACHE_ACTION_SET_MEMTYPE) {
ucs_memtype_cache_insert(memtype_cache, start, end, mem_info);
ucs_memtype_cache_insert(memtype_cache, start, end, mem_type, sys_dev);
}

/* slice old regions by the new region, to preserve the previous memory type
Expand All @@ -240,12 +241,12 @@ UCS_PROFILE_FUNC_VOID(ucs_memtype_cache_update_internal,
if (start > region->super.start) {
/* create previous region */
ucs_memtype_cache_insert(memtype_cache, region->super.start, start,
&region->mem_info);
region->mem_type, region->sys_dev);
}
if (end < region->super.end) {
/* create next region */
ucs_memtype_cache_insert(memtype_cache, end, region->super.end,
&region->mem_info);
region->mem_type, region->sys_dev);
}

ucs_free(region);
Expand All @@ -256,36 +257,29 @@ UCS_PROFILE_FUNC_VOID(ucs_memtype_cache_update_internal,
}

void ucs_memtype_cache_update(const void *address, size_t size,
const ucs_memory_info_t *mem_info)
ucs_memory_type_t mem_type,
ucs_sys_device_t sys_dev)
{
if (ucs_memtype_cache_global_instance == NULL) {
return;
}

ucs_memtype_cache_update_internal(ucs_memtype_cache_global_instance,
address, size, mem_info,
address, size, mem_type, sys_dev,
UCS_MEMTYPE_CACHE_ACTION_SET_MEMTYPE);
}

void ucs_memtype_cache_remove(const void *address, size_t size)
{
ucs_memory_info_t mem_info;

ucs_memory_info_set_unknown(&mem_info);
ucs_memtype_cache_update_internal(ucs_memtype_cache_global_instance,
address, size, &mem_info,
address, size, UCS_MEMORY_TYPE_UNKNOWN,
UCS_SYS_DEVICE_ID_UNKNOWN,
UCS_MEMTYPE_CACHE_ACTION_REMOVE);
}

static void ucs_memtype_cache_event_callback(ucm_event_type_t event_type,
ucm_event_t *event, void *arg)
ucm_event_t *event, void *arg)
{
ucs_memory_info_t mem_info = {
.type = event->mem_type.mem_type,
.sys_dev = UCS_SYS_DEVICE_ID_UNKNOWN,
.base_address = event->mem_type.address,
.alloc_length = event->mem_type.size,
};
ucs_memtype_cache_action_t action;

if (event_type & UCM_EVENT_MEM_TYPE_ALLOC) {
Expand All @@ -296,8 +290,14 @@ static void ucs_memtype_cache_event_callback(ucm_event_type_t event_type,
return;
}

ucs_trace("dispatching mem event %d address %p length %zu mem_type %s",
event_type, event->mem_type.address, event->mem_type.size,
ucs_memory_type_names[event->mem_type.mem_type]);

ucs_memtype_cache_update_internal(arg, event->mem_type.address,
event->mem_type.size, &mem_info, action);
event->mem_type.size,
event->mem_type.mem_type,
UCS_SYS_DEVICE_ID_UNKNOWN, action);
}

static void ucs_memtype_cache_purge(ucs_memtype_cache_t *memtype_cache)
Expand Down Expand Up @@ -333,14 +333,23 @@ UCS_PROFILE_FUNC(ucs_status_t, ucs_memtype_cache_lookup,
pgt_region = UCS_PROFILE_CALL(ucs_pgtable_lookup, &memtype_cache->pgtable,
start);
if (pgt_region == NULL) {
ucs_trace("address 0x%lx not found", start);
status = UCS_ERR_NO_ELEM;
goto out_unlock;
}

region = ucs_derived_of(pgt_region, ucs_memtype_cache_region_t);
if (ucs_likely((start + size) <= pgt_region->end)) {
region = ucs_derived_of(pgt_region, ucs_memtype_cache_region_t);
*mem_info = region->mem_info;
mem_info->base_address = (void*)region->super.start;
mem_info->alloc_length = region->super.end - region->super.start;
mem_info->type = region->mem_type;
mem_info->sys_dev = region->sys_dev;
ucs_trace_data("0x%lx..0x%lx found in " UCS_MEMTYPE_CACHE_REGION_FMT,
start, start + size,
UCS_MEMTYPE_CACHE_REGION_ARG(region));
} else {
ucs_trace("0x%lx..0x%lx not contained in " UCS_MEMTYPE_CACHE_REGION_FMT,
start, start + size, UCS_MEMTYPE_CACHE_REGION_ARG(region));
ucs_memory_info_set_unknown(mem_info);
}
status = UCS_OK;
Expand Down
40 changes: 21 additions & 19 deletions src/ucs/memory/memtype_cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,21 +29,13 @@ extern ucs_memtype_cache_t *ucs_memtype_cache_global_instance;

/* Memory information record */
typedef struct ucs_memory_info {
ucs_memory_type_t type; /**< Memory type, use uint8 for compact size */
ucs_memory_type_t type; /**< Memory type */
ucs_sys_device_t sys_dev; /**< System device index */
void *base_address; /**< Base address of the underlying allocation */
size_t alloc_length; /**< Whole length of the underlying allocation */
} ucs_memory_info_t;


struct ucs_memtype_cache_region {
ucs_pgt_region_t super; /**< Base class - page table region */
ucs_list_link_t list; /**< List element */
ucs_memory_info_t mem_info; /**< Memory type and system device the address
belongs to */
};


struct ucs_memtype_cache {
pthread_rwlock_t lock; /**< protests the page table */
ucs_pgtable_t pgtable; /**< Page table to hold the regions */
Expand Down Expand Up @@ -79,11 +71,14 @@ ucs_status_t ucs_memtype_cache_lookup(const void *address, size_t size,
*
* @param [in] address Start address to update.
* @param [in] size Size of the memory to update.
* @param [in] mem_info Set the memory info of the address range to this
* @param [in] mem_type Set the memory type of the address range to this
* value.
* @param [in] sys_dev Set the system device of the address range to
* this value.
*/
void ucs_memtype_cache_update(const void *address, size_t size,
const ucs_memory_info_t *mem_info);
ucs_memory_type_t mem_type,
ucs_sys_device_t sys_dev);


/**
Expand All @@ -95,14 +90,6 @@ void ucs_memtype_cache_update(const void *address, size_t size,
void ucs_memtype_cache_remove(const void *address, size_t size);


/**
* Helper function to set memory info structure to host memory type.
*
* @param [out] mem_info Pointer to memory info structure.
*/
void ucs_memory_info_set_host(ucs_memory_info_t *mem_info);


/**
* Find if global memtype_cache is empty.
*
Expand All @@ -114,6 +101,21 @@ static UCS_F_ALWAYS_INLINE int ucs_memtype_cache_is_empty()
(ucs_memtype_cache_global_instance->pgtable.num_regions == 0);
}


/**
* Helper function to set memory info structure to host memory type.
*
* @param [out] mem_info Pointer to memory info structure.
*/
static UCS_F_ALWAYS_INLINE void
ucs_memory_info_set_host(ucs_memory_info_t *mem_info)
{
mem_info->type = UCS_MEMORY_TYPE_HOST;
mem_info->sys_dev = UCS_SYS_DEVICE_ID_UNKNOWN;
mem_info->base_address = NULL;
mem_info->alloc_length = -1;
}

END_C_DECLS

#endif
6 changes: 5 additions & 1 deletion src/uct/cuda/base/cuda_md.c
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,9 @@ uct_cuda_base_query_attributes(uct_cuda_copy_md_t *md, const void *address,
return UCS_ERR_INVALID_ADDR;
}

ucs_trace("query address %p: 0x%llx..0x%llx length %zu", address,
base_address, base_address + alloc_length, alloc_length);

if (md->config.alloc_whole_reg == UCS_CONFIG_AUTO) {
total_bytes = uct_cuda_base_get_total_device_mem(cuda_device);
if (alloc_length > (total_bytes * md->config.max_reg_ratio)) {
Expand Down Expand Up @@ -241,7 +244,8 @@ ucs_status_t uct_cuda_base_mem_query(uct_md_h tl_md, const void *address,
}

ucs_memtype_cache_update(addr_mem_info.base_address,
addr_mem_info.alloc_length, &addr_mem_info);
addr_mem_info.alloc_length, addr_mem_info.type,
addr_mem_info.sys_dev);
} else {
addr_mem_info = default_mem_info;
}
Expand Down
Loading

0 comments on commit cda6aae

Please sign in to comment.