diff --git a/contrib/test_jenkins.sh b/contrib/test_jenkins.sh index c120d9043d8..34593eb3778 100755 --- a/contrib/test_jenkins.sh +++ b/contrib/test_jenkins.sh @@ -593,6 +593,7 @@ run_ucx_perftest_mpi() { $MPIRUN -np 2 -x UCX_TLS=rc,cuda_copy,gdr_copy -x UCX_MEMTYPE_CACHE=y $AFFINITY $UCX_PERFTEST $MPIRUN -np 2 -x UCX_TLS=rc,cuda_copy,gdr_copy -x UCX_MEMTYPE_CACHE=n $AFFINITY $UCX_PERFTEST $MPIRUN -np 2 -x UCX_TLS=rc,cuda_copy $AFFINITY $UCX_PERFTEST + $MPIRUN -np 2 -x UCX_TLS=self,mm,cma,cuda_copy $AFFINITY $UCX_PERFTEST $MPIRUN -np 2 $AFFINITY $UCX_PERFTEST unset CUDA_VISIBLE_DEVICES fi diff --git a/src/ucp/core/ucp_mm.h b/src/ucp/core/ucp_mm.h index 1af08d59363..7b29972847c 100644 --- a/src/ucp/core/ucp_mm.h +++ b/src/ucp/core/ucp_mm.h @@ -42,6 +42,7 @@ typedef struct ucp_rkey { ucp_rma_proto_t *rma_proto; /* Protocol for RMAs */ } cache; ucp_md_map_t md_map; /* Which *remote* MDs have valid memory handles */ + uct_memory_type_t mem_type;/* Memory type of remote key memory */ uct_rkey_bundle_t uct[0]; /* Remote key for every MD */ } ucp_rkey_t; @@ -122,10 +123,12 @@ ucs_status_t ucp_mem_rereg_mds(ucp_context_h context, ucp_md_map_t reg_md_map, size_t ucp_rkey_packed_size(ucp_context_h context, ucp_md_map_t md_map); void ucp_rkey_packed_copy(ucp_context_h context, ucp_md_map_t md_map, - void *rkey_buffer, const void* uct_rkeys[]); + uct_memory_type_t mem_type, void *rkey_buffer, + const void* uct_rkeys[]); ssize_t ucp_rkey_pack_uct(ucp_context_h context, ucp_md_map_t md_map, - const uct_mem_h *memh, void *rkey_buffer); + const uct_mem_h *memh, uct_memory_type_t mem_type, + void *rkey_buffer); void ucp_rkey_dump_packed(const void *rkey_buffer, char *buffer, size_t max); diff --git a/src/ucp/core/ucp_rkey.c b/src/ucp/core/ucp_rkey.c index 3eeb70df717..6443707f8e0 100644 --- a/src/ucp/core/ucp_rkey.c +++ b/src/ucp/core/ucp_rkey.c @@ -13,7 +13,10 @@ #include -static ucp_md_map_t ucp_mem_dummy_buffer = 0; +static struct { + ucp_md_map_t md_map; + uint8_t mem_type; +} UCS_S_PACKED ucp_mem_dummy_buffer = {0, UCT_MD_MEM_TYPE_HOST}; size_t ucp_rkey_packed_size(ucp_context_h context, ucp_md_map_t md_map) @@ -22,6 +25,7 @@ size_t ucp_rkey_packed_size(ucp_context_h context, ucp_md_map_t md_map) unsigned md_index; size = sizeof(ucp_md_map_t); + size += sizeof(uint8_t); ucs_for_each_bit (md_index, md_map) { md_size = context->tl_mds[md_index].attr.rkey_packed_size; ucs_assert_always(md_size <= UINT8_MAX); @@ -31,7 +35,8 @@ size_t ucp_rkey_packed_size(ucp_context_h context, ucp_md_map_t md_map) } void ucp_rkey_packed_copy(ucp_context_h context, ucp_md_map_t md_map, - void *rkey_buffer, const void* uct_rkeys[]) + uct_memory_type_t mem_type, void *rkey_buffer, + const void* uct_rkeys[]) { void *p = rkey_buffer; unsigned md_index; @@ -40,6 +45,8 @@ void ucp_rkey_packed_copy(ucp_context_h context, ucp_md_map_t md_map, *(ucp_md_map_t*)p = md_map; p += sizeof(ucp_md_map_t); + *((uint8_t *)p++) = mem_type; + ucs_for_each_bit(md_index, md_map) { md_size = context->tl_mds[md_index].attr.rkey_packed_size; ucs_assert_always(md_size <= UINT8_MAX); @@ -51,7 +58,8 @@ void ucp_rkey_packed_copy(ucp_context_h context, ucp_md_map_t md_map, } ssize_t ucp_rkey_pack_uct(ucp_context_h context, ucp_md_map_t md_map, - const uct_mem_h *memh, void *rkey_buffer) + const uct_mem_h *memh, uct_memory_type_t mem_type, + void *rkey_buffer) { void *p = rkey_buffer; ucs_status_t status = UCS_OK; @@ -66,6 +74,10 @@ ssize_t ucp_rkey_pack_uct(ucp_context_h context, ucp_md_map_t md_map, *(ucp_md_map_t*)p = md_map; p += sizeof(ucp_md_map_t); + /* Write memory type */ + UCS_STATIC_ASSERT(UCT_MD_MEM_TYPE_LAST <= 255); + *((uint8_t*)p++) = mem_type; + /* Write both size and rkey_buffer for each UCT rkey */ uct_memh_index = 0; ucs_for_each_bit (md_index, md_map) { @@ -122,7 +134,8 @@ ucs_status_t ucp_rkey_pack(ucp_context_h context, ucp_mem_h memh, p = rkey_buffer; - packed_size = ucp_rkey_pack_uct(context, memh->md_map, memh->uct, p); + packed_size = ucp_rkey_pack_uct(context, memh->md_map, memh->uct, + memh->mem_type, p); if (packed_size < 0) { status = (ucs_status_t)packed_size; goto err_destroy; @@ -161,6 +174,7 @@ ucs_status_t ucp_ep_rkey_unpack(ucp_ep_h ep, const void *rkey_buffer, unsigned md_count; ucs_status_t status; ucp_rkey_h rkey; + uct_memory_type_t mem_type; uint8_t md_size; const void *p; @@ -193,7 +207,11 @@ ucs_status_t ucp_ep_rkey_unpack(ucp_ep_h ep, const void *rkey_buffer, goto err; } + /* Read memory type */ + mem_type = *((uint8_t*)p++); + rkey->md_map = md_map; + rkey->mem_type = mem_type; /* Unpack rkey of each UCT MD */ remote_md_index = 0; /* Index of remote MD */ @@ -259,6 +277,8 @@ void ucp_rkey_dump_packed(const void *rkey_buffer, char *buffer, size_t max) md_map = *(ucp_md_map_t*)(rkey_buffer); rkey_buffer += sizeof(ucp_md_map_t); + rkey_buffer += sizeof(uint8_t); + first = 1; ucs_for_each_bit(md_index, md_map) { md_size = *((uint8_t*)rkey_buffer); @@ -337,6 +357,7 @@ static ucp_lane_index_t ucp_config_find_rma_lane(ucp_context_h context, ucp_lane_index_t lane; ucp_md_map_t dst_md_mask; ucp_md_index_t md_index; + uct_md_attr_t *md_attr; uint8_t rkey_index; int prio; @@ -349,16 +370,21 @@ static ucp_lane_index_t ucp_config_find_rma_lane(ucp_context_h context, } md_index = config->md_index[lane]; + md_attr = &context->tl_mds[md_index].attr; + if ((md_index != UCP_NULL_RESOURCE) && - (!(context->tl_mds[md_index].attr.cap.flags & UCT_MD_FLAG_NEED_RKEY))) + (!(md_attr->cap.flags & UCT_MD_FLAG_NEED_RKEY))) { /* Lane does not need rkey, can use the lane with invalid rkey */ - *uct_rkey_p = UCT_INVALID_RKEY; - return lane; + if (!rkey || ((mem_type == md_attr->cap.mem_type) && + (mem_type == rkey->mem_type))) { + *uct_rkey_p = UCT_INVALID_RKEY; + return lane; + } } if ((md_index != UCP_NULL_RESOURCE) && - (!(context->tl_mds[md_index].attr.cap.reg_mem_types & UCS_BIT(mem_type)))) { + (!(md_attr->cap.reg_mem_types & UCS_BIT(mem_type)))) { continue; } diff --git a/src/ucp/tag/offload.c b/src/ucp/tag/offload.c index 7ea8308a994..3fe7fe7f961 100644 --- a/src/ucp/tag/offload.c +++ b/src/ucp/tag/offload.c @@ -186,8 +186,8 @@ UCS_PROFILE_FUNC(ucs_status_t, ucp_tag_offload_unexp_rndv, dummy_rts->address = remote_addr; dummy_rts->size = length; - ucp_rkey_packed_copy(worker->context, UCS_BIT(md_index), dummy_rts + 1, - uct_rkeys); + ucp_rkey_packed_copy(worker->context, UCS_BIT(md_index), + UCT_MD_MEM_TYPE_HOST, dummy_rts + 1, uct_rkeys); UCP_WORKER_STAT_TAG_OFFLOAD(worker, RX_UNEXP_RNDV); ucp_rndv_process_rts(worker, dummy_rts, dummy_rts_size, 0); diff --git a/src/ucp/tag/rndv.c b/src/ucp/tag/rndv.c index 606777f4165..1461357827b 100644 --- a/src/ucp/tag/rndv.c +++ b/src/ucp/tag/rndv.c @@ -59,6 +59,7 @@ size_t ucp_tag_rndv_rts_pack(void *dest, void *arg) packed_rkey_size = ucp_rkey_pack_uct(worker->context, sreq->send.state.dt.dt.contig.md_map, sreq->send.state.dt.dt.contig.memh, + sreq->send.mem_type, rndv_rts_hdr + 1); if (packed_rkey_size < 0) { ucs_fatal("failed to pack rendezvous remote key: %s", @@ -95,6 +96,7 @@ static size_t ucp_tag_rndv_rtr_pack(void *dest, void *arg) packed_rkey_size = ucp_rkey_pack_uct(rndv_req->send.ep->worker->context, rreq->recv.state.dt.contig.md_map, rreq->recv.state.dt.contig.memh, + rreq->recv.mem_type, rndv_rtr_hdr + 1); if (packed_rkey_size < 0) { return packed_rkey_size; @@ -870,10 +872,12 @@ static ucs_status_t ucp_rndv_pipeline(ucp_request_t *sreq, ucp_rndv_rtr_hdr_t *r frag_req->send.ep = pipeline_ep; frag_req->send.buffer = mdesc + 1; frag_req->send.datatype = ucp_dt_make_contig(1); + frag_req->send.mem_type = sreq->send.mem_type; frag_req->send.state.dt.dt.contig.memh[0]= ucp_memh2uct(mdesc->memh, md_index); frag_req->send.state.dt.dt.contig.md_map = UCS_BIT(md_index); frag_req->send.length = length; frag_req->send.uct.func = ucp_rndv_progress_rma_get_zcopy; + frag_req->send.rndv_get.rkey = NULL; frag_req->send.rndv_get.remote_address = (uint64_t)(sreq->send.buffer + offset); frag_req->send.rndv_get.lanes_map = 0; frag_req->send.rndv_get.lane_count = 0;