Skip to content

Commit

Permalink
UCT/CUDA-IPC: stub impl for md_invalidate support in cuda-ipc
Browse files Browse the repository at this point in the history
  • Loading branch information
Akshay-Venkatesh committed Oct 27, 2021
1 parent f9f4622 commit 3829297
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 27 deletions.
5 changes: 3 additions & 2 deletions src/uct/cuda/cuda_ipc/cuda_ipc_md.c
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@ static ucs_config_field_t uct_cuda_ipc_md_config_table[] = {

static ucs_status_t uct_cuda_ipc_md_query(uct_md_h md, uct_md_attr_t *md_attr)
{
md_attr->cap.flags = UCT_MD_FLAG_REG |
UCT_MD_FLAG_NEED_RKEY;
md_attr->cap.flags = UCT_MD_FLAG_REG |
UCT_MD_FLAG_NEED_RKEY |
UCT_MD_FLAG_INVALIDATE;
md_attr->cap.reg_mem_types = UCS_BIT(UCS_MEMORY_TYPE_CUDA);
md_attr->cap.alloc_mem_types = 0;
md_attr->cap.access_mem_types = UCS_BIT(UCS_MEMORY_TYPE_CUDA);
Expand Down
64 changes: 39 additions & 25 deletions test/gtest/uct/test_md.cc
Original file line number Diff line number Diff line change
Expand Up @@ -655,45 +655,59 @@ UCS_TEST_SKIP_COND_P(test_md, invalidate, !check_caps(UCT_MD_FLAG_INVALIDATE))
ucs_status_t status;
uct_md_mem_dereg_params_t params;

if (!strcmp(GetParam().md_name.c_str(), "cuda_ipc")) {
UCS_TEST_SKIP_R("test not needed with cuda-ipc");
}

params.field_mask = UCT_MD_MEM_DEREG_FIELD_FLAGS |
UCT_MD_MEM_DEREG_FIELD_MEMH |
UCT_MD_MEM_DEREG_FIELD_COMPLETION;
comp().comp.func = dereg_cb;
comp().comp.status = UCS_OK;
comp().self = this;
params.comp = &comp().comp;
ptr = malloc(size);
for (mem_reg_count = 1; mem_reg_count < 100; mem_reg_count++) {
comp().comp.count = (mem_reg_count + 1) / 2;
m_comp_count = 0;
for (iter = 0; iter < mem_reg_count; iter++) {
status = uct_md_mem_reg(md(), ptr, size, UCT_MD_MEM_ACCESS_ALL,
&memh);
ASSERT_UCS_OK(status);
memhs.push_back(memh);

for (size_t i = 0; i < mem_buffer::supported_mem_types().size(); ++i) {
ucs_memory_type_t mem_type = mem_buffer::supported_mem_types()[i];

if (!check_reg_mem_type(mem_type)) {
continue;
}

for (iter = 0; iter < mem_reg_count; iter++) {
/* mix dereg and dereg(invalidate) operations */
ASSERT_EQ(0, m_comp_count);
memh = memhs.back();
if ((iter & 1) == 0) { /* on even iteration invalidate handle */
params.flags = UCT_MD_MEM_DEREG_FLAG_INVALIDATE;
} else {
params.flags = 0;
alloc_memory(&ptr, size, NULL, mem_type);

for (mem_reg_count = 1; mem_reg_count < 100; mem_reg_count++) {
comp().comp.count = (mem_reg_count + 1) / 2;
m_comp_count = 0;
for (iter = 0; iter < mem_reg_count; iter++) {
status = uct_md_mem_reg(md(), ptr, size, UCT_MD_MEM_ACCESS_ALL,
&memh);
ASSERT_UCS_OK(status);
memhs.push_back(memh);
}

params.memh = memh;
status = uct_md_mem_dereg_v2(md(), &params);
ASSERT_UCS_OK(status);
memhs.pop_back();
for (iter = 0; iter < mem_reg_count; iter++) {
/* mix dereg and dereg(invalidate) operations */
ASSERT_EQ(0, m_comp_count);
memh = memhs.back();
if ((iter & 1) == 0) { /* on even iteration invalidate handle */
params.flags = UCT_MD_MEM_DEREG_FLAG_INVALIDATE;
} else {
params.flags = 0;
}

params.memh = memh;
status = uct_md_mem_dereg_v2(md(), &params);
ASSERT_UCS_OK(status);
memhs.pop_back();
}

ASSERT_TRUE(memhs.empty());
EXPECT_EQ(1, m_comp_count);
}

ASSERT_TRUE(memhs.empty());
EXPECT_EQ(1, m_comp_count);
free_memory(ptr, mem_type);
}

free(ptr);
}

UCS_TEST_SKIP_COND_P(test_md, dereg_bad_arg,
Expand Down

0 comments on commit 3829297

Please sign in to comment.