Skip to content

Commit

Permalink
UCT/CUDA_IPC: add rcache instance to support md_invalidate
Browse files Browse the repository at this point in the history
  • Loading branch information
Akshay-Venkatesh committed Oct 8, 2021
1 parent f9f4622 commit 47c96c8
Show file tree
Hide file tree
Showing 3 changed files with 177 additions and 29 deletions.
129 changes: 127 additions & 2 deletions src/uct/cuda/cuda_ipc/cuda_ipc_md.c
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,26 @@
#include <ucs/debug/memtrack_int.h>
#include <ucs/type/class.h>
#include <ucs/profile/profile.h>
#include <ucm/api/ucm.h>
#include <sys/types.h>
#include <unistd.h>

static ucs_config_field_t uct_cuda_ipc_md_config_table[] = {
{"", "", NULL,
ucs_offsetof(uct_cuda_ipc_md_config_t, super), UCS_CONFIG_TYPE_TABLE(uct_md_config_table)},

{"", "", NULL,
ucs_offsetof(uct_cuda_ipc_md_config_t, rcache),
UCS_CONFIG_TYPE_TABLE(uct_md_config_rcache_table)},

{NULL}
};

static ucs_status_t uct_cuda_ipc_md_query(uct_md_h md, uct_md_attr_t *md_attr)
{
md_attr->cap.flags = UCT_MD_FLAG_REG |
UCT_MD_FLAG_NEED_RKEY;
md_attr->cap.flags = UCT_MD_FLAG_REG |
UCT_MD_FLAG_NEED_RKEY |
UCT_MD_FLAG_INVALIDATE;
md_attr->cap.reg_mem_types = UCS_BIT(UCS_MEMORY_TYPE_CUDA);
md_attr->cap.alloc_mem_types = 0;
md_attr->cap.access_mem_types = UCS_BIT(UCS_MEMORY_TYPE_CUDA);
Expand Down Expand Up @@ -291,6 +297,100 @@ static void uct_cuda_ipc_md_close(uct_md_h uct_md)
ucs_free(md);
}

static inline uct_cuda_ipc_rcache_region_t*
uct_cuda_ipc_rcache_region_from_memh(uct_mem_h memh)
{
return ucs_container_of(memh, uct_cuda_ipc_rcache_region_t, key);
}

static void uct_cuda_ipc_mem_region_invalidate_cb(void *arg)
{
uct_completion_t *comp = arg;

uct_invoke_completion(comp, UCS_OK);
}

static ucs_status_t uct_cuda_ipc_mem_rcache_reg(uct_md_h uct_md, void *address,
size_t length, unsigned flags,
uct_mem_h *memh_p)
{
uct_cuda_ipc_md_t *md = ucs_derived_of(uct_md, uct_cuda_ipc_md_t);
ucs_rcache_region_t *rregion;
ucs_status_t status;

status = ucs_rcache_get(md->rcache, address, length, PROT_READ|PROT_WRITE,
&flags, &rregion);
if (status != UCS_OK) {
return status;
}

ucs_assert(rregion->refcount > 0);
*memh_p = &ucs_derived_of(rregion, uct_cuda_ipc_rcache_region_t)->key;
return UCS_OK;
}

static ucs_status_t
uct_cuda_ipc_mem_rcache_dereg(uct_md_h uct_md,
const uct_md_mem_dereg_params_t *params)
{
uct_cuda_ipc_md_t *md = ucs_derived_of(uct_md, uct_cuda_ipc_md_t);
uct_cuda_ipc_rcache_region_t *region;

UCT_MD_MEM_DEREG_CHECK_PARAMS(params, 1);

region = uct_cuda_ipc_rcache_region_from_memh(params->memh);
if (UCT_MD_MEM_DEREG_FIELD_VALUE(params, flags, FIELD_FLAGS, 0) &
UCT_MD_MEM_DEREG_FLAG_INVALIDATE) {
ucs_rcache_region_invalidate(md->rcache, &region->super,
uct_cuda_ipc_mem_region_invalidate_cb,
params->comp);
}

ucs_rcache_region_put(md->rcache, &region->super);
return UCS_OK;
}

static uct_md_ops_t uct_cuda_ipc_md_rcache_ops = {
.close = uct_cuda_ipc_md_close,
.query = uct_cuda_ipc_md_query,
.mkey_pack = uct_cuda_ipc_mkey_pack,
.mem_reg = uct_cuda_ipc_mem_rcache_reg,
.mem_dereg = uct_cuda_ipc_mem_rcache_dereg,
.is_sockaddr_accessible = ucs_empty_function_return_zero_int,
.detect_memory_type = ucs_empty_function_return_unsupported,
};

static ucs_status_t uct_cuda_ipc_rcache_mem_reg_cb(void *context, ucs_rcache_t *rcache,
void *arg, ucs_rcache_region_t *rregion,
uint16_t rcache_mem_reg_flags)
{
uct_cuda_ipc_rcache_region_t *region = ucs_derived_of(rregion, uct_cuda_ipc_rcache_region_t);
uct_cuda_ipc_md_t *md = context;
int *flags = arg;

return uct_cuda_ipc_mem_reg_internal(&md->super,
(void*)region->super.super.start,
region->super.super.end - region->super.super.start,
*flags, &region->key);
}

static void uct_cuda_ipc_rcache_dump_region_cb(void *context, ucs_rcache_t *rcache,
ucs_rcache_region_t *rregion, char *buf,
size_t max)
{
uct_cuda_ipc_rcache_region_t *region = ucs_derived_of(rregion, uct_cuda_ipc_rcache_region_t);
uct_cuda_ipc_key_t *key = &region->key;

snprintf(buf, max, "dev_num %d base_addr %p alloc_length %ld",
key->dev_num, (void*)key->d_bptr, key->b_len);
}

static ucs_rcache_ops_t uct_cuda_ipc_rcache_ops = {
.mem_reg = uct_cuda_ipc_rcache_mem_reg_cb,
.mem_dereg = ucs_empty_function,
.dump_region = uct_cuda_ipc_rcache_dump_region_cb
};

static ucs_status_t
uct_cuda_ipc_md_open(uct_component_t *component, const char *md_name,
const uct_md_config_t *config, uct_md_h *md_p)
Expand All @@ -305,9 +405,12 @@ uct_cuda_ipc_md_open(uct_component_t *component, const char *md_name,
.detect_memory_type = ucs_empty_function_return_unsupported
};

const uct_cuda_ipc_md_config_t *md_config;
int num_devices;
uct_cuda_ipc_md_t* md;
uct_cuda_ipc_component_t* com;
ucs_rcache_params_t rcache_params;
ucs_status_t status;

UCT_CUDA_IPC_DEVICE_GET_COUNT(num_devices);

Expand All @@ -324,6 +427,28 @@ uct_cuda_ipc_md_open(uct_component_t *component, const char *md_name,
md->uuid_map_capacity = 0;
md->uuid_map = NULL;
md->peer_accessible_cache = NULL;
md->rcache = NULL;
md_config = ucs_derived_of(config, uct_cuda_ipc_md_config_t);

uct_md_set_rcache_params(&rcache_params, &md_config->rcache);
rcache_params.region_struct_size = sizeof(uct_cuda_ipc_rcache_region_t);
rcache_params.max_alignment = ucs_get_page_size();
rcache_params.ucm_events = UCM_EVENT_MEM_TYPE_FREE;
rcache_params.context = md;
rcache_params.ops = &uct_cuda_ipc_rcache_ops;
rcache_params.flags = 0;

status = ucs_rcache_create(&rcache_params, "cuda_ipc", ucs_stats_get_root(),
&md->rcache);
if (status == UCS_OK) {
md->super.ops = &uct_cuda_ipc_md_rcache_ops;
md->reg_cost = ucs_linear_func_make(md_config->rcache.overhead, 0);
} else {
ucs_assert(md->rcache == NULL);
ucs_error("Failed to create registration cache: %s", ucs_status_string(status));
uct_cuda_ipc_md_close(&md->super);
return status;
}

com = ucs_derived_of(md->super.component, uct_cuda_ipc_component_t);
com->md = md;
Expand Down
17 changes: 15 additions & 2 deletions src/uct/cuda/cuda_ipc/cuda_ipc_md.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <uct/base/uct_md.h>
#include <uct/cuda/base/cuda_md.h>
#include <uct/cuda/base/cuda_iface.h>
#include <ucs/memory/rcache.h>
#include <ucs/type/spinlock.h>
#include <ucs/config/types.h>

Expand All @@ -18,11 +19,13 @@
* @brief cuda ipc MD descriptor
*/
typedef struct uct_cuda_ipc_md {
struct uct_md super; /**< Domain info */
struct uct_md super; /**< Domain info */
CUuuid* uuid_map;
ucs_ternary_auto_value_t *peer_accessible_cache;
int uuid_map_size;
int uuid_map_capacity;
ucs_linear_func_t reg_cost; /**< Memory registration cost */
ucs_rcache_t *rcache; /**< Needed to support MD_INVALIDATE */
} uct_cuda_ipc_md_t;

/**
Expand All @@ -39,7 +42,8 @@ extern uct_cuda_ipc_component_t uct_cuda_ipc_component;
* @brief cuda ipc domain configuration.
*/
typedef struct uct_cuda_ipc_md_config {
uct_md_config_t super;
uct_md_config_t super;
uct_md_rcache_config_t rcache;
} uct_cuda_ipc_md_config_t;


Expand All @@ -56,6 +60,15 @@ typedef struct uct_cuda_ipc_key {
} uct_cuda_ipc_key_t;


/**
* CUDA-IPC memory handle in the registration cache.
*/
typedef struct uct_cuda_ipc_rcache_region {
ucs_rcache_region_t super;
uct_cuda_ipc_key_t key;
} uct_cuda_ipc_rcache_region_t;


#define UCT_CUDA_IPC_GET_DEVICE(_cu_device) \
do { \
if (UCS_OK != \
Expand Down
60 changes: 35 additions & 25 deletions test/gtest/uct/test_md.cc
Original file line number Diff line number Diff line change
Expand Up @@ -662,38 +662,48 @@ UCS_TEST_SKIP_COND_P(test_md, invalidate, !check_caps(UCT_MD_FLAG_INVALIDATE))
comp().comp.status = UCS_OK;
comp().self = this;
params.comp = &comp().comp;
ptr = malloc(size);
for (mem_reg_count = 1; mem_reg_count < 100; mem_reg_count++) {
comp().comp.count = (mem_reg_count + 1) / 2;
m_comp_count = 0;
for (iter = 0; iter < mem_reg_count; iter++) {
status = uct_md_mem_reg(md(), ptr, size, UCT_MD_MEM_ACCESS_ALL,
&memh);
ASSERT_UCS_OK(status);
memhs.push_back(memh);

for (size_t i = 0; i < mem_buffer::supported_mem_types().size(); ++i) {
ucs_memory_type_t mem_type = mem_buffer::supported_mem_types()[i];

if (!check_reg_mem_type(mem_type)) {
continue;
}

for (iter = 0; iter < mem_reg_count; iter++) {
/* mix dereg and dereg(invalidate) operations */
ASSERT_EQ(0, m_comp_count);
memh = memhs.back();
if ((iter & 1) == 0) { /* on even iteration invalidate handle */
params.flags = UCT_MD_MEM_DEREG_FLAG_INVALIDATE;
} else {
params.flags = 0;
alloc_memory(&ptr, size, NULL, mem_type);

for (mem_reg_count = 1; mem_reg_count < 100; mem_reg_count++) {
comp().comp.count = (mem_reg_count + 1) / 2;
m_comp_count = 0;
for (iter = 0; iter < mem_reg_count; iter++) {
status = uct_md_mem_reg(md(), ptr, size, UCT_MD_MEM_ACCESS_ALL,
&memh);
ASSERT_UCS_OK(status);
memhs.push_back(memh);
}

params.memh = memh;
status = uct_md_mem_dereg_v2(md(), &params);
ASSERT_UCS_OK(status);
memhs.pop_back();
for (iter = 0; iter < mem_reg_count; iter++) {
/* mix dereg and dereg(invalidate) operations */
ASSERT_EQ(0, m_comp_count);
memh = memhs.back();
if ((iter & 1) == 0) { /* on even iteration invalidate handle */
params.flags = UCT_MD_MEM_DEREG_FLAG_INVALIDATE;
} else {
params.flags = 0;
}

params.memh = memh;
status = uct_md_mem_dereg_v2(md(), &params);
ASSERT_UCS_OK(status);
memhs.pop_back();
}

ASSERT_TRUE(memhs.empty());
EXPECT_EQ(1, m_comp_count);
}

ASSERT_TRUE(memhs.empty());
EXPECT_EQ(1, m_comp_count);
free_memory(ptr, mem_type);
}

free(ptr);
}

UCS_TEST_SKIP_COND_P(test_md, dereg_bad_arg,
Expand Down

0 comments on commit 47c96c8

Please sign in to comment.