Skip to content

Commit

Permalink
Merge pull request openucx#2 from edgargabriel/topic/memtype_detection_fix
Browse files Browse the repository at this point in the history

UCT/ROCM: fix memory type detection
  • Loading branch information
edgargabriel authored May 3, 2022
2 parents adb0649 + 6efdbb4 commit f19386b
Show file tree
Hide file tree
Showing 33 changed files with 1,407 additions and 1,298 deletions.
7 changes: 1 addition & 6 deletions examples/ucp_hello_world.c
Original file line number Diff line number Diff line change
Expand Up @@ -350,17 +350,12 @@ static int run_ucx_client(ucp_worker_h ucp_worker)
return ret;
}

/*
 * Send-completion callback with an intentionally empty body: it ignores the
 * request, the completion status and the user data.
 * NOTE(review): presumably installed only so that a callback is set for the
 * flush request issued via ucp_ep_flush_nbx() - confirm against the caller.
 */
static void flush_callback(void *request, ucs_status_t status, void *user_data)
{
}

static ucs_status_t flush_ep(ucp_worker_h worker, ucp_ep_h ep)
{
ucp_request_param_t param;
void *request;

param.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK;
param.cb.send = flush_callback;
param.op_attr_mask = 0;
request = ucp_ep_flush_nbx(ep, &param);
if (request == NULL) {
return UCS_OK;
Expand Down
829 changes: 13 additions & 816 deletions src/ucp/api/ucp.h

Large diffs are not rendered by default.

862 changes: 862 additions & 0 deletions src/ucp/api/ucp_compat.h

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions src/ucp/core/ucp_am.c
Original file line number Diff line number Diff line change
Expand Up @@ -1129,6 +1129,7 @@ UCS_PROFILE_FUNC(ucs_status_ptr_t, ucp_am_recv_data_nbx,
req->recv.length = ucp_dt_length(datatype, count, buffer,
&req->recv.state);
req->recv.mem_type = mem_type;
req->recv.op_attr = param->op_attr_mask;
req->recv.am.desc = desc;
rts = data_desc;

Expand Down
61 changes: 35 additions & 26 deletions src/ucp/core/ucp_mm.c
Original file line number Diff line number Diff line change
Expand Up @@ -286,13 +286,18 @@ static inline int ucp_mem_map_is_allocate(const ucp_mem_map_params_t *params)
(params->flags & UCP_MEM_MAP_ALLOCATE);
}

void ucp_memh_dereg(ucp_context_h context, ucp_mem_h memh, ucp_md_map_t md_map)
static void ucp_memh_dereg(ucp_context_h context, ucp_mem_h memh,
ucp_md_map_t md_map)
{
ucp_md_index_t md_index;
ucs_status_t status;

/* Unregister from all memory domains */
ucs_for_each_bit(md_index, md_map) {
ucs_assertv(md_index != memh->alloc_md_index,
"memh %p: md_index %u alloc_md_index %u", memh, md_index,
memh->alloc_md_index);

ucs_trace("de-registering memh[%d]=%p", md_index, memh->uct[md_index]);
ucs_assert(context->tl_mds[md_index].attr.cap.flags & UCT_MD_FLAG_REG);
status = uct_md_mem_dereg(context->tl_mds[md_index].md,
Expand All @@ -307,6 +312,33 @@ void ucp_memh_dereg(ucp_context_h context, ucp_mem_h memh, ucp_md_map_t md_map)
}
}

/*
 * Tear down a UCP memory handle: de-register it from the memory domains in
 * @a md_map and, if the memory was allocated as part of the mapping, free it.
 *
 * @param context  UCP context that owns the memory domains (tl_mds).
 * @param memh     Memory handle to unmap. The handle structure itself is not
 *                 released by this function.
 * @param md_map   Bitmap of memory domain indexes to de-register from. When the
 *                 memory was allocated by a memory domain, that domain's bit is
 *                 cleared from the map and its UCT handle is handed to
 *                 uct_mem_free() instead.
 */
void ucp_memh_unmap(ucp_context_h context, ucp_mem_h memh, ucp_md_map_t md_map)
{
    uct_allocated_memory_t alloc_desc;
    ucs_status_t free_status;

    /* Describe the allocation up front, before any de-registration runs */
    alloc_desc.method  = memh->alloc_method;
    alloc_desc.address = ucp_memh_address(memh);
    alloc_desc.length  = ucp_memh_length(memh);

    if (alloc_desc.method == UCT_ALLOC_METHOD_MD) {
        ucs_assert(memh->alloc_md_index != UCP_NULL_RESOURCE);
        /* The allocating MD is excluded from the de-registration map; its
         * handle travels with the allocation descriptor */
        md_map          &= ~UCS_BIT(memh->alloc_md_index);
        alloc_desc.md    = context->tl_mds[memh->alloc_md_index].md;
        alloc_desc.memh  = memh->uct[memh->alloc_md_index];
    }

    ucp_memh_dereg(context, memh, md_map);

    /* Nothing to free unless the memory was also allocated */
    if (memh->alloc_method == UCT_ALLOC_METHOD_LAST) {
        return;
    }

    free_status = uct_mem_free(&alloc_desc);
    if (free_status != UCS_OK) {
        ucs_warn("failed to free: %s", ucs_status_string(free_status));
    }
}

static ucs_status_t ucp_memh_register(ucp_context_h context, ucp_mem_h memh,
ucp_md_map_t md_map, void *address,
size_t length, unsigned uct_flags)
Expand Down Expand Up @@ -950,32 +982,9 @@ static ucs_status_t ucp_mem_rcache_mem_reg_cb(void *context, ucs_rcache_t *rcach
/*
 * Region de-registration callback: recovers the ucp_mem_t that embeds
 * @a rregion (via ucs_derived_of), treats @a ctx as the ucp_context_h, and
 * releases the handle's memory domain registrations and, if the memory was
 * also allocated, the memory itself.
 * NOTE(review): this span appears to overlay the pre-change inline body and
 * the post-change single call to ucp_memh_unmap() (memh is declared twice);
 * confirm which lines belong to the final version before relying on it.
 */
static void ucp_mem_rcache_mem_dereg_cb(void *ctx, ucs_rcache_t *rcache,
ucs_rcache_region_t *rregion)
{
ucp_mem_h memh = ucs_derived_of(rregion, ucp_mem_t);
ucp_md_map_t md_map = memh->md_map;
ucp_context_h context = ctx;
uct_allocated_memory_t mem;
ucs_status_t status;

mem.address = ucp_memh_address(memh);
mem.length = ucp_memh_length(memh);
mem.method = memh->alloc_method;

if (mem.method == UCT_ALLOC_METHOD_MD) {
ucs_assert(memh->alloc_md_index != UCP_NULL_RESOURCE);
mem.md = context->tl_mds[memh->alloc_md_index].md;
mem.memh = memh->uct[memh->alloc_md_index];
md_map &= ~UCS_BIT(memh->alloc_md_index);
}

ucp_memh_dereg(context, memh, md_map);
ucp_mem_h memh = ucs_derived_of(rregion, ucp_mem_t);

/* If the memory was also allocated, release it */
if (memh->alloc_method != UCT_ALLOC_METHOD_LAST) {
status = uct_mem_free(&mem);
if (status != UCS_OK) {
ucs_warn("failed to free: %s", ucs_status_string(status));
}
}
ucp_memh_unmap((ucp_context_h)ctx, memh, memh->md_map);
}

static void ucp_mem_rcache_dump_region_cb(void *rcontext, ucs_rcache_t *rcache,
Expand Down
3 changes: 2 additions & 1 deletion src/ucp/core/ucp_mm.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,8 @@ ucs_status_t ucp_memh_get_slow(ucp_context_h context, void *address,
ucp_md_map_t reg_md_map, unsigned uct_flags,
ucp_mem_h *memh_p);

void ucp_memh_dereg(ucp_context_h context, ucp_mem_h memh, ucp_md_map_t md_map);
void ucp_memh_unmap(ucp_context_h context, ucp_mem_h memh,
ucp_md_map_t md_map);

ucs_status_t ucp_mem_rcache_init(ucp_context_h context);

Expand Down
2 changes: 1 addition & 1 deletion src/ucp/core/ucp_mm.inl
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ ucp_memh_put(ucp_context_h context, ucp_mem_h memh, int invalidate)
}

if (ucs_unlikely(context->rcache == NULL)) {
ucp_memh_dereg(context, memh, memh->md_map);
ucp_memh_unmap(context, memh, memh->md_map);
ucs_free(memh);
return;
}
Expand Down
1 change: 1 addition & 0 deletions src/ucp/core/ucp_request.h
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,7 @@ struct ucp_request {
ucp_datatype_t datatype; /* Receive type */
size_t length; /* Total length, in bytes */
ucs_memory_type_t mem_type; /* Memory type */
uint32_t op_attr; /* Operation attributes */
ucp_dt_state_t state;
ucp_worker_t *worker;
uct_tag_context_t uct_ctx; /* Transport offload context */
Expand Down
37 changes: 24 additions & 13 deletions src/ucp/rndv/proto_rndv.c
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ ucp_proto_rndv_ctrl_init(const ucp_proto_rndv_ctrl_init_params_t *params)
ucp_memory_info_t mem_info;
ucs_status_t status;
double ctrl_latency;
uint16_t op_flags;

ucs_assert(params->super.flags & UCP_PROTO_COMMON_INIT_FLAG_RESPONSE);
ucs_assert(!(params->super.flags & UCP_PROTO_COMMON_INIT_FLAG_SINGLE_FRAG));
Expand All @@ -199,12 +200,16 @@ ucp_proto_rndv_ctrl_init(const ucp_proto_rndv_ctrl_init_params_t *params)
return UCS_ERR_NO_ELEM;
}

op_flags = UCP_PROTO_SELECT_OP_FLAG_INTERNAL |
(select_param->op_flags &
ucp_proto_select_op_attr_to_flags(UCP_OP_ATTR_FLAG_MULTI_SEND));

/* Construct select parameter for the remote protocol */
if (params->super.super.rkey_config_key == NULL) {
/* Remote buffer is unknown, assume same params as local */
remote_select_param = *select_param;
remote_select_param.op_id = params->remote_op_id;
remote_select_param.op_flags = UCP_PROTO_SELECT_OP_FLAG_INTERNAL;
remote_select_param.op_flags = op_flags;
} else {
/* If we know the remote buffer parameters, these are actually the local
* parameters for the remote protocol
Expand Down Expand Up @@ -462,8 +467,9 @@ void ucp_proto_rndv_bulk_query(const ucp_proto_query_params_t *params,
static ucs_status_t
ucp_proto_rndv_send_reply(ucp_worker_h worker, ucp_request_t *req,
ucp_operation_id_t op_id, uint32_t op_attr_mask,
size_t length, const void *rkey_buffer,
size_t rkey_length, uint8_t sg_count)
uint16_t op_flags, size_t length,
const void *rkey_buffer, size_t rkey_length,
uint8_t sg_count)
{
ucp_ep_h ep = req->send.ep;
ucp_worker_cfg_index_t rkey_cfg_index;
Expand Down Expand Up @@ -492,7 +498,7 @@ ucp_proto_rndv_send_reply(ucp_worker_h worker, ucp_request_t *req,
rkey = NULL;
}

ucp_proto_select_param_init(&sel_param, op_id, op_attr_mask, 0,
ucp_proto_select_param_init(&sel_param, op_id, op_attr_mask, op_flags,
req->send.state.dt_iter.dt_class,
&req->send.state.dt_iter.mem_info, sg_count);

Expand Down Expand Up @@ -575,7 +581,8 @@ void ucp_proto_rndv_receive_start(ucp_worker_h worker, ucp_request_t *recv_req,
&sg_count);
}

status = ucp_proto_rndv_send_reply(worker, req, op_id, 0, rts->size,
status = ucp_proto_rndv_send_reply(worker, req, op_id,
recv_req->recv.op_attr, 0, rts->size,
rkey_buffer, rkey_length, sg_count);
if (status != UCS_OK) {
ucp_datatype_iter_cleanup(&req->send.state.dt_iter, UCP_DT_MASK_ALL);
Expand All @@ -592,8 +599,9 @@ void ucp_proto_rndv_receive_start(ucp_worker_h worker, ucp_request_t *recv_req,

static ucs_status_t
ucp_proto_rndv_send_start(ucp_worker_h worker, ucp_request_t *req,
uint32_t op_attr_mask, const ucp_rndv_rtr_hdr_t *rtr,
size_t header_length, uint8_t sg_count)
uint32_t op_attr_mask, uint32_t op_flags,
const ucp_rndv_rtr_hdr_t *rtr, size_t header_length,
uint8_t sg_count)
{
ucs_status_t status;
size_t rkey_length;
Expand All @@ -608,8 +616,8 @@ ucp_proto_rndv_send_start(ucp_worker_h worker, ucp_request_t *req,

ucs_assert(rtr->size == req->send.state.dt_iter.length);
status = ucp_proto_rndv_send_reply(worker, req, UCP_OP_ID_RNDV_SEND,
op_attr_mask, rtr->size, rtr + 1,
rkey_length, sg_count);
op_attr_mask, op_flags, rtr->size,
rtr + 1, rkey_length, sg_count);
if (status != UCS_OK) {
return status;
}
Expand Down Expand Up @@ -641,6 +649,7 @@ ucp_proto_rndv_handle_rtr(void *arg, void *data, size_t length, unsigned flags)
const ucp_rndv_rtr_hdr_t *rtr = data;
ucp_request_t *req, *freq;
ucs_status_t status;
uint32_t op_flags;
uint8_t sg_count;

UCP_SEND_REQUEST_GET_BY_ID(&req, worker, rtr->sreq_id, 0, return UCS_OK,
Expand All @@ -652,6 +661,8 @@ ucp_proto_rndv_handle_rtr(void *arg, void *data, size_t length, unsigned flags)
/* RTR covers the whole send request - use the send request directly */
ucs_assert(req->flags & UCP_REQUEST_FLAG_PROTO_INITIALIZED);

op_flags = req->send.proto_config->select_param.op_flags;

if (rtr->size == req->send.state.dt_iter.length) {
/* RTR covers the whole send request - use the send request directly */
ucs_assert(rtr->offset == 0);
Expand All @@ -662,8 +673,8 @@ ucp_proto_rndv_handle_rtr(void *arg, void *data, size_t length, unsigned flags)
req->flags &= ~UCP_REQUEST_FLAG_PROTO_INITIALIZED;

sg_count = req->send.proto_config->select_param.sg_count;
status = ucp_proto_rndv_send_start(worker, req, 0, rtr, length,
sg_count);
status = ucp_proto_rndv_send_start(worker, req, 0, op_flags, rtr,
length, sg_count);
if (status != UCS_OK) {
goto err_request_fail;
}
Expand All @@ -688,8 +699,8 @@ ucp_proto_rndv_handle_rtr(void *arg, void *data, size_t length, unsigned flags)
* TODO can rndv/ppln be selected here (and not just single frag)?
*/
status = ucp_proto_rndv_send_start(worker, freq,
UCP_OP_ATTR_FLAG_MULTI_SEND, rtr,
length, sg_count);
UCP_OP_ATTR_FLAG_MULTI_SEND,
op_flags, rtr, length, sg_count);
if (status != UCS_OK) {
goto err_put_freq;
}
Expand Down
2 changes: 1 addition & 1 deletion src/ucp/tag/tag_recv.c
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ ucp_tag_recv_common(ucp_worker_h worker, void *buffer, size_t count,
&req->recv.state);
req->recv.mem_type = ucp_request_get_memory_type(worker->context, buffer,
req->recv.length, param);

req->recv.op_attr = param->op_attr_mask;
req->recv.tag.tag = tag;
req->recv.tag.tag_mask = tag_mask;
if (param->op_attr_mask & UCP_OP_ATTR_FIELD_CALLBACK) {
Expand Down
19 changes: 8 additions & 11 deletions src/ucs/memory/rcache.c
Original file line number Diff line number Diff line change
Expand Up @@ -364,12 +364,10 @@ void ucs_mem_region_destroy_internal(ucs_rcache_t *rcache,

if (region->flags & UCS_RCACHE_REGION_FLAG_REGISTERED) {
UCS_STATS_UPDATE_COUNTER(rcache->stats, UCS_RCACHE_DEREGS, 1);
{
UCS_PROFILE_CODE("mem_dereg") {
rcache->params.ops->mem_dereg(rcache->params.context, rcache,
region);
}
}
UCS_PROFILE_NAMED_CALL_VOID_ALWAYS("mem_dereg",
rcache->params.ops->mem_dereg,
rcache->params.context, rcache,
region);
}

if (!(rcache->params.flags & UCS_RCACHE_FLAG_NO_PFN_CHECK) &&
Expand Down Expand Up @@ -733,7 +731,7 @@ ucs_rcache_check_overlap(ucs_rcache_t *rcache, ucs_pgt_addr_t *start,
* TODO: currently rcache is optimized for the case where most of
* the regions have same protection.
*/
mem_prot = UCS_PROFILE_CALL(ucs_get_mem_prot, *start, *end);
mem_prot = UCS_PROFILE_CALL_ALWAYS(ucs_get_mem_prot, *start, *end);
if (!ucs_test_all_flags(mem_prot, *prot)) {
ucs_rcache_region_trace(rcache, region,
"do not merge "UCS_RCACHE_PROT_FMT
Expand Down Expand Up @@ -897,10 +895,9 @@ ucs_rcache_create_region(ucs_rcache_t *rcache, void *address, size_t length,
++distribution_bin->count;
distribution_bin->total_size += region_size;

region->status = status =
UCS_PROFILE_NAMED_CALL("mem_reg", rcache->params.ops->mem_reg,
rcache->params.context, rcache, arg, region,
merged ? UCS_RCACHE_MEM_REG_HIDE_ERRORS : 0);
region->status = status = UCS_PROFILE_NAMED_CALL_ALWAYS(
"mem_reg", rcache->params.ops->mem_reg, rcache->params.context,
rcache, arg, region, merged ? UCS_RCACHE_MEM_REG_HIDE_ERRORS : 0);
if (status != UCS_OK) {
if (merged) {
/* failure may be due to merge, because memory of the merged
Expand Down
Loading

0 comments on commit f19386b

Please sign in to comment.