Skip to content

Commit

Permalink
Merge pull request openucx#7716 from yosefe/topic/ucp-test-api-to-pas…
Browse files Browse the repository at this point in the history
…s-pre-registered

UCP/TEST: API to pass pre-registered memory handle to UCP operations
  • Loading branch information
yosefe authored Dec 6, 2021
2 parents 3818f46 + 3e26ae3 commit 03b935e
Show file tree
Hide file tree
Showing 5 changed files with 216 additions and 80 deletions.
11 changes: 11 additions & 0 deletions src/ucp/api/ucp.h
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,7 @@ typedef enum {
UCP_OP_ATTR_FIELD_REPLY_BUFFER = UCS_BIT(5), /**< reply_buffer field */
UCP_OP_ATTR_FIELD_MEMORY_TYPE = UCS_BIT(6), /**< memory type field */
UCP_OP_ATTR_FIELD_RECV_INFO = UCS_BIT(7), /**< recv_info field */
UCP_OP_ATTR_FIELD_MEMH = UCS_BIT(8), /**< memory handle field */

UCP_OP_ATTR_FLAG_NO_IMM_CMPL = UCS_BIT(16), /**< deny immediate completion */
UCP_OP_ATTR_FLAG_FAST_CMPL = UCS_BIT(17), /**< expedite local completion,
Expand Down Expand Up @@ -1716,6 +1717,16 @@ typedef struct {
Relevant for @a ucp_tag_recv_nbx
function. */
} recv_info;

/**
* Memory handle for pre-registered buffer.
* If the handle is provided, protocols that require registered memory can
* skip the registration step. As a result, the communication request
* overhead can be reduced and the request can be completed faster.
* The memory handle should be obtained by calling @ref ucp_mem_map.
*/
ucp_mem_h memh;

} ucp_request_param_t;


Expand Down
132 changes: 90 additions & 42 deletions test/gtest/ucp/test_ucp_am.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ extern "C" {
class test_ucp_am_base : public ucp_test {
public:
static void get_test_variants(std::vector<ucp_test_variant>& variants) {
add_variant(variants, UCP_FEATURE_AM);
add_variant_with_value(variants, UCP_FEATURE_AM, 0, "");
add_variant_with_value(variants, UCP_FEATURE_AM, TEST_FLAG_PREREG,
"prereg");
}

virtual void init() {
Expand All @@ -42,6 +44,16 @@ class test_ucp_am_base : public ucp_test {
sender().connect(&receiver(), get_ep_params());
receiver().connect(&sender(), get_ep_params());
}

protected:
enum {
TEST_FLAG_PREREG = UCS_BIT(0)
};

bool prereg() const
{
return get_variant_value(0) & TEST_FLAG_PREREG;
}
};

class test_ucp_am : public test_ucp_am_base {
Expand Down Expand Up @@ -313,11 +325,21 @@ class test_ucp_am_nbx : public test_ucp_am_base {
m_dt = ucp_dt_make_contig(1);
m_am_received = false;
m_rx_dt = ucp_dt_make_contig(1);
m_rx_memtype = UCS_MEMORY_TYPE_HOST;
m_rx_buf = NULL;
m_rx_memh = NULL;
}

protected:
virtual ucs_memory_type_t tx_memtype() const
{
return UCS_MEMORY_TYPE_HOST;
}

virtual ucs_memory_type_t rx_memtype() const
{
return UCS_MEMORY_TYPE_HOST;
}

size_t max_am_hdr()
{
ucp_worker_attr_t attr;
Expand Down Expand Up @@ -345,7 +367,7 @@ class test_ucp_am_nbx : public test_ucp_am_base {
} else if (dt == UCP_DATATYPE_IOV) {
return ucp_dt_make_iov();
} else {
ucs_assert(UCP_DATATYPE_GENERIC == dt);
ucs_assertv(UCP_DATATYPE_GENERIC == dt, "dt=%d", dt);
ucp_datatype_t ucp_dt;
ASSERT_UCS_OK(ucp_dt_create_generic(&ucp::test_dt_copy_ops, NULL,
&ucp_dt));
Expand Down Expand Up @@ -387,9 +409,10 @@ class test_ucp_am_nbx : public test_ucp_am_base {
EXPECT_EQ(check_pattern, m_hdr);
}

ucs_status_ptr_t send_am(const ucp::data_type_desc_t& dt_desc,
ucs_status_ptr_t send_am(const ucp::data_type_desc_t &dt_desc,
unsigned flags = 0, const void *hdr = NULL,
unsigned hdr_length = 0)
unsigned hdr_length = 0,
const ucp_mem_h memh = NULL)
{
ucp_request_param_t param;
param.op_attr_mask = UCP_OP_ATTR_FIELD_DATATYPE;
Expand All @@ -400,33 +423,46 @@ class test_ucp_am_nbx : public test_ucp_am_base {
param.flags = flags;
}

if (memh != NULL) {
param.op_attr_mask |= UCP_OP_ATTR_FIELD_MEMH;
param.memh = memh;
}

ucs_status_ptr_t sptr = ucp_am_send_nbx(sender().ep(), TEST_AM_NBX_ID,
hdr, hdr_length, dt_desc.buf(),
dt_desc.count(), &param);
return sptr;
}

void test_am_send_recv(size_t size, size_t header_size = 0ul,
unsigned flags = 0,
ucs_memory_type_t mem_type = UCS_MEMORY_TYPE_HOST,
unsigned data_cb_flags = 0)
unsigned flags = 0, unsigned data_cb_flags = 0)
{
mem_buffer sbuf(size, mem_type);
mem_buffer::pattern_fill(sbuf.ptr(), size, SEED, mem_type);
mem_buffer sbuf(size, tx_memtype());
sbuf.pattern_fill(SEED);
m_hdr.resize(header_size);
ucs::fill_random(m_hdr);
m_am_received = false;
ucp_mem_h memh = NULL;

set_am_data_handler(receiver(), TEST_AM_NBX_ID, am_data_cb, this,
data_cb_flags);

ucp::data_type_desc_t sdt_desc(m_dt, sbuf.ptr(), size);

if (prereg()) {
memh = sender().mem_map(sbuf.ptr(), size);
}

ucs_status_ptr_t sptr = send_am(sdt_desc, get_send_flag() | flags,
m_hdr.data(), m_hdr.size());
m_hdr.data(), m_hdr.size(), memh);

wait_for_flag(&m_am_received);
request_wait(sptr);

if (prereg()) {
sender().mem_unmap(memh);
}

EXPECT_TRUE(m_am_received);
}

Expand Down Expand Up @@ -476,8 +512,8 @@ class test_ucp_am_nbx : public test_ucp_am_base {
return UCS_OK;
}

m_rx_buf = mem_buffer::allocate(length, m_rx_memtype);
mem_buffer::pattern_fill(m_rx_buf, length, 0ul, m_rx_memtype);
m_rx_buf = mem_buffer::allocate(length, rx_memtype());
mem_buffer::pattern_fill(m_rx_buf, length, 0ul, rx_memtype());

m_rx_dt_desc.make(m_rx_dt, m_rx_buf, length);

Expand All @@ -494,6 +530,13 @@ class test_ucp_am_nbx : public test_ucp_am_base {
params.cb.recv_am = am_data_recv_cb;
params.user_data = this;
params.recv_info.length = &rx_length;

if (prereg()) {
params.op_attr_mask |= UCP_OP_ATTR_FIELD_MEMH;
m_rx_memh = receiver().mem_map(m_rx_buf, length);
params.memh = m_rx_memh;
}

ucs_status_ptr_t sp = ucp_am_recv_data_nbx(receiver().worker(),
data, m_rx_dt_desc.buf(),
m_rx_dt_desc.count(),
Expand All @@ -515,8 +558,12 @@ class test_ucp_am_nbx : public test_ucp_am_base {
{
ASSERT_FALSE(m_am_received);
m_am_received = true;
mem_buffer::pattern_check(m_rx_buf, length, SEED, m_rx_memtype);
mem_buffer::release(m_rx_buf, m_rx_memtype);
mem_buffer::pattern_check(m_rx_buf, length, SEED, rx_memtype());

if (m_rx_memh != NULL) {
receiver().mem_unmap(m_rx_memh);
}
mem_buffer::release(m_rx_buf, rx_memtype());
}

static ucs_status_t am_data_cb(void *arg, const void *header,
Expand Down Expand Up @@ -568,9 +615,9 @@ class test_ucp_am_nbx : public test_ucp_am_base {
volatile bool m_am_received;
std::string m_hdr;
ucp_datatype_t m_rx_dt;
ucs_memory_type_t m_rx_memtype;
ucp::data_type_desc_t m_rx_dt_desc;
void *m_rx_buf;
ucp_mem_h m_rx_memh;
};

UCS_TEST_P(test_ucp_am_nbx, set_invalid_handler)
Expand Down Expand Up @@ -856,7 +903,6 @@ class test_ucp_am_nbx_eager_memtype : public test_ucp_am_nbx {
{
modify_config("RNDV_THRESH", "inf");
test_ucp_am_nbx::init();
m_rx_memtype = static_cast<ucs_memory_type_t>(get_variant_value(1));
}

static void base_test_generator(std::vector<ucp_test_variant> &variants)
Expand All @@ -878,12 +924,22 @@ class test_ucp_am_nbx_eager_memtype : public test_ucp_am_nbx {
add_variant_memtypes(variants, base_test_generator,
std::numeric_limits<uint64_t>::max());
}

private:
virtual ucs_memory_type_t tx_memtype() const
{
return static_cast<ucs_memory_type_t>(get_variant_value(1));
}

virtual ucs_memory_type_t rx_memtype() const
{
return static_cast<ucs_memory_type_t>(get_variant_value(2));
}
};

UCS_TEST_P(test_ucp_am_nbx_eager_memtype, basic)
{
ucs_memory_type_t mt = static_cast<ucs_memory_type_t>(get_variant_value(0));
test_am_send_recv(16 * UCS_KBYTE, 8, 0, mt);
test_am_send_recv(16 * UCS_KBYTE, 8, 0);
}

UCP_INSTANTIATE_TEST_CASE_GPU_AWARE(test_ucp_am_nbx_eager_memtype)
Expand All @@ -904,13 +960,13 @@ class test_ucp_am_nbx_eager_data_release : public test_ucp_am_nbx {
static void get_test_variants(std::vector<ucp_test_variant> &variants)
{
add_variant_values(variants, test_ucp_am_base::get_test_variants, 0);
add_variant_values(variants, test_ucp_am_base::get_test_variants,
1, "proto");
add_variant_values(variants, test_ucp_am_base::get_test_variants, 1,
"proto");
}

virtual unsigned enable_proto()
{
return get_variant_value(0);
return get_variant_value(1);
}

virtual ucs_status_t
Expand All @@ -932,12 +988,10 @@ class test_ucp_am_nbx_eager_data_release : public test_ucp_am_nbx {
void test_data_release(size_t size)
{
size_t hdr_size = ucs_min(max_am_hdr(), 8);
test_am_send_recv(size, 0, 0, UCS_MEMORY_TYPE_HOST,
UCP_AM_FLAG_PERSISTENT_DATA);
test_am_send_recv(size, 0, 0, UCP_AM_FLAG_PERSISTENT_DATA);
ucp_am_data_release(receiver().worker(), m_data_ptr);

test_am_send_recv(size, hdr_size, 0, UCS_MEMORY_TYPE_HOST,
UCP_AM_FLAG_PERSISTENT_DATA);
test_am_send_recv(size, hdr_size, 0, UCP_AM_FLAG_PERSISTENT_DATA);
ucp_am_data_release(receiver().worker(), m_data_ptr);
}

Expand Down Expand Up @@ -986,7 +1040,7 @@ class test_ucp_am_nbx_align : public test_ucp_am_nbx {

virtual unsigned get_send_flag()
{
return get_variant_value(0);
return get_variant_value(1);
}

virtual ucs_status_t
Expand All @@ -1010,14 +1064,12 @@ class test_ucp_am_nbx_align : public test_ucp_am_nbx {

UCS_TEST_P(test_ucp_am_nbx_align, basic)
{
test_am_send_recv(fragment_size() / 2, 0, 0, UCS_MEMORY_TYPE_HOST,
UCP_AM_FLAG_PERSISTENT_DATA);
test_am_send_recv(fragment_size() / 2, 0, 0, UCP_AM_FLAG_PERSISTENT_DATA);
}

UCS_TEST_P(test_ucp_am_nbx_align, multi)
{
test_am_send_recv(fragment_size() * 5, 0, 0, UCS_MEMORY_TYPE_HOST,
UCP_AM_FLAG_PERSISTENT_DATA);
test_am_send_recv(fragment_size() * 5, 0, 0, UCP_AM_FLAG_PERSISTENT_DATA);
}

UCP_INSTANTIATE_TEST_CASE(test_ucp_am_nbx_align)
Expand Down Expand Up @@ -1073,8 +1125,8 @@ class test_ucp_am_nbx_dts : public test_ucp_am_nbx {
{
test_ucp_am_nbx::init();

m_dt = make_dt(get_variant_value(0));
m_rx_dt = make_dt(get_variant_value(1));
m_dt = make_dt(get_variant_value(1));
m_rx_dt = make_dt(get_variant_value(2));
}

void cleanup()
Expand All @@ -1086,7 +1138,7 @@ class test_ucp_am_nbx_dts : public test_ucp_am_nbx {

virtual unsigned get_send_flag()
{
return get_variant_value(2);
return get_variant_value(3);
}

virtual ucs_status_t
Expand Down Expand Up @@ -1354,8 +1406,8 @@ class test_ucp_am_nbx_rndv_dts : public test_ucp_am_nbx_rndv {
{
test_ucp_am_nbx::init();

m_dt = make_dt(get_variant_value(0));
m_rx_dt = make_dt(get_variant_value(1));
m_dt = make_dt(get_variant_value(1));
m_rx_dt = make_dt(get_variant_value(2));
}

void cleanup()
Expand Down Expand Up @@ -1387,16 +1439,13 @@ class test_ucp_am_nbx_rndv_memtype : public test_ucp_am_nbx_rndv {
void init()
{
modify_config("RNDV_THRESH", "128");

test_ucp_am_nbx::init();
m_rx_memtype = static_cast<ucs_memory_type_t>(get_variant_value(1));
}
};

UCS_TEST_P(test_ucp_am_nbx_rndv_memtype, rndv)
{
ucs_memory_type_t mt = static_cast<ucs_memory_type_t>(get_variant_value(0));
test_am_send_recv(64 * UCS_KBYTE, 8, 0, mt);
test_am_send_recv(64 * UCS_KBYTE, 8, 0);
}

UCP_INSTANTIATE_TEST_CASE_GPU_AWARE(test_ucp_am_nbx_rndv_memtype);
Expand Down Expand Up @@ -1424,8 +1473,7 @@ class test_ucp_am_nbx_rndv_memtype_disable_zcopy :
disable_rndv_zcopy_config(sender(), zcopy_caps);
disable_rndv_zcopy_config(receiver(), zcopy_caps);

ucs_memory_type_t mt = static_cast<ucs_memory_type_t>(get_variant_value(0));
test_am_send_recv(64 * UCS_KBYTE, 8, 0, mt);
test_am_send_recv(64 * UCS_KBYTE, 8, 0);
}
};

Expand Down
Loading

0 comments on commit 03b935e

Please sign in to comment.