Skip to content

Commit

Permalink
Merge pull request open-mpi#2 from kunpengcompute/topic/non_contig_da…
Browse files Browse the repository at this point in the history
…tatypes

Support for non-contiguous datatypes
  • Loading branch information
nsosnsos authored Nov 16, 2020
2 parents 8b48192 + 4980a36 commit 3fc5793
Show file tree
Hide file tree
Showing 7 changed files with 416 additions and 28 deletions.
2 changes: 2 additions & 0 deletions ompi/mca/coll/ucx/Makefile.am
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@ AM_CPPFLAGS = $(coll_ucx_CPPFLAGS) -DCOLL_UCX_HOME=\"$(coll_ucx_HOME)\" $(coll_u
coll_ucx_sources = \
coll_ucx.h \
coll_ucx_request.h \
coll_ucx_datatype.h \
coll_ucx_freelist.h \
coll_ucx_op.c \
coll_ucx_module.c \
coll_ucx_request.c \
coll_ucx_datatype.c \
coll_ucx_component.c

# Make the output library in this directory, and name it either
Expand Down
9 changes: 8 additions & 1 deletion ompi/mca/coll/ucx/coll_ucx.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@
#include "ompi/communicator/communicator.h"
#include "ompi/datatype/ompi_datatype.h"
#include "ompi/attribute/attribute.h"
#include "ompi/op/op.h"

#include "orte/runtime/orte_globals.h"
#include "ompi/datatype/ompi_datatype_internal.h"
#include "opal/mca/common/ucx/common_ucx.h"

#include "ucg/api/ucg_mpi.h"
Expand Down Expand Up @@ -71,6 +71,13 @@ typedef struct mca_coll_ucx_component {
mca_coll_ucx_freelist_t persistent_ops;
ompi_request_t completed_send_req;
size_t request_size;

/* Datatypes */
int datatype_attr_keyval;
ucp_datatype_t predefined_types[OMPI_DATATYPE_MPI_MAX_PREDEFINED];

/* Converters pool */
mca_coll_ucx_freelist_t convs;
} mca_coll_ucx_component_t;
OMPI_MODULE_DECLSPEC extern mca_coll_ucx_component_t mca_coll_ucx_component;

Expand Down
29 changes: 20 additions & 9 deletions ompi/mca/coll/ucx/coll_ucx_component.c
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

#include "coll_ucx.h"
#include "coll_ucx_request.h"
#include "coll_ucx_datatype.h"


/*
Expand Down Expand Up @@ -266,6 +267,12 @@ int mca_coll_ucx_open(void)
goto out;
}

int i;
mca_coll_ucx_component.datatype_attr_keyval = MPI_KEYVAL_INVALID;
for (i = 0; i < OMPI_DATATYPE_MAX_PREDEFINED; ++i) {
mca_coll_ucx_component.predefined_types[i] = COLL_UCX_DATATYPE_INVALID;
}

ucs_list_head_init(&mca_coll_ucx_component.group_head);
return OMPI_SUCCESS;

Expand All @@ -279,6 +286,14 @@ int mca_coll_ucx_close(void)
{
COLL_UCX_VERBOSE(1, "mca_coll_ucx_close");

int i;
for (i = 0; i < OMPI_DATATYPE_MAX_PREDEFINED; ++i) {
if (mca_coll_ucx_component.predefined_types[i] != COLL_UCX_DATATYPE_INVALID) {
ucp_dt_destroy(mca_coll_ucx_component.predefined_types[i]);
mca_coll_ucx_component.predefined_types[i] = COLL_UCX_DATATYPE_INVALID;
}
}

if (mca_coll_ucx_component.ucg_worker != NULL) {
mca_coll_ucx_cleanup();
mca_coll_ucx_component.ucg_worker = NULL;
Expand Down Expand Up @@ -355,11 +370,10 @@ int mca_coll_ucx_init(void)
}

/* Initialize the free lists */
OBJ_CONSTRUCT(&mca_coll_ucx_component.persistent_ops, mca_coll_ucx_freelist_t);

/* Create a completed request to be returned from isend */
OBJ_CONSTRUCT(&mca_coll_ucx_component.completed_send_req, ompi_request_t);
mca_coll_ucx_completed_request_init(&mca_coll_ucx_component.completed_send_req);
OBJ_CONSTRUCT(&mca_coll_ucx_component.convs, mca_coll_ucx_freelist_t);
COLL_UCX_FREELIST_INIT(&mca_coll_ucx_component.convs,
mca_coll_ucx_convertor_t,
128, -1, 128);

rc = opal_progress_register(mca_coll_ucx_progress);
if (OPAL_SUCCESS != rc) {
Expand All @@ -384,10 +398,7 @@ void mca_coll_ucx_cleanup(void)

opal_progress_unregister(mca_coll_ucx_progress);

mca_coll_ucx_component.completed_send_req.req_state = OMPI_REQUEST_INVALID;
OMPI_REQUEST_FINI(&mca_coll_ucx_component.completed_send_req);
OBJ_DESTRUCT(&mca_coll_ucx_component.completed_send_req);
OBJ_DESTRUCT(&mca_coll_ucx_component.persistent_ops);
OBJ_DESTRUCT(&mca_coll_ucx_component.convs);

if (mca_coll_ucx_component.ucg_worker) {
ucg_worker_destroy(mca_coll_ucx_component.ucg_worker);
Expand Down
271 changes: 271 additions & 0 deletions ompi/mca/coll/ucx/coll_ucx_datatype.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,271 @@
/*
* Copyright (C) Mellanox Technologies Ltd. 2001-2011. ALL RIGHTS RESERVED.
* Copyright (c) 2019 Research Organization for Information Science
* and Technology (RIST). All rights reserved.
* Copyright (c) 2020 The University of Tennessee and The University
* of Tennessee Research Foundation. All rights
* reserved.
*
* Copyright (c) 2020 Huawei Technologies Co., Ltd. All rights reserved.
* $COPYRIGHT$
*
* Additional copyrights may follow
*
* $HEADER$
*/

#include "coll_ucx_datatype.h"
#include "coll_ucx_request.h"

#include "ompi/runtime/mpiruntime.h"
#include "ompi/attribute/attribute.h"

#include <inttypes.h>
#include <math.h>

/*
 * "start_pack" callback of the generic UCX datatype.
 *
 * Takes a convertor wrapper from the component's convertor free list, retains
 * the OMPI datatype for as long as the wrapper points at it, and prepares the
 * embedded OPAL convertor for sending `count` elements out of `buffer`.
 *
 * context - the ompi_datatype_t that was handed to ucp_dt_create_generic()
 * buffer  - user buffer the data will be packed from
 * count   - number of datatype elements in `buffer`
 *
 * Returns the wrapper; the other callbacks in this file receive it back as
 * their opaque `state` argument.
 *
 * NOTE(review): the free-list result is used without a NULL check.
 */
static void* coll_ucx_generic_datatype_start_pack(void *context, const void *buffer,
                                                  size_t count)
{
    ompi_datatype_t *dtype = context;
    mca_coll_ucx_convertor_t *pack_state =
        (mca_coll_ucx_convertor_t *)COLL_UCX_FREELIST_GET(&mca_coll_ucx_component.convs);

    /* The wrapper keeps its own reference; it is dropped in the finish callback. */
    pack_state->datatype = dtype;
    OMPI_DATATYPE_RETAIN(dtype);

    opal_convertor_copy_and_prepare_for_send(ompi_proc_local_proc->super.proc_convertor,
                                             &dtype->super, count, buffer, 0,
                                             &pack_state->opal_conv);
    return pack_state;
}

/*
 * "start_unpack" callback of the generic UCX datatype.
 *
 * Mirrors the start_pack callback for the receive side: grabs a convertor
 * wrapper from the free list, retains the OMPI datatype, resets the expected
 * in-order offset to zero and prepares the embedded OPAL convertor for
 * receiving `count` elements into `buffer`.
 *
 * context - the ompi_datatype_t that was handed to ucp_dt_create_generic()
 * buffer  - user buffer the data will be unpacked into
 * count   - number of datatype elements in `buffer`
 *
 * Returns the wrapper, later passed back as the opaque `state` argument.
 *
 * NOTE(review): the free-list result is used without a NULL check.
 */
static void* coll_ucx_generic_datatype_start_unpack(void *context, void *buffer,
                                                    size_t count)
{
    ompi_datatype_t *dtype = context;
    mca_coll_ucx_convertor_t *unpack_state =
        (mca_coll_ucx_convertor_t *)COLL_UCX_FREELIST_GET(&mca_coll_ucx_component.convs);

    /* Next fragment is expected at byte offset 0 of the packed stream. */
    unpack_state->offset = 0;

    /* The wrapper keeps its own reference; it is dropped in the finish callback. */
    unpack_state->datatype = dtype;
    OMPI_DATATYPE_RETAIN(dtype);

    opal_convertor_copy_and_prepare_for_recv(ompi_proc_local_proc->super.proc_convertor,
                                             &dtype->super, count, buffer, 0,
                                             &unpack_state->opal_conv);
    return unpack_state;
}

/*
 * "packed_size" callback of the generic UCX datatype.
 *
 * state - convertor wrapper returned by one of the start_* callbacks
 *
 * Returns the size reported by opal_convertor_get_packed_size() for the
 * wrapper's OPAL convertor.
 */
static size_t coll_ucx_generic_datatype_packed_size(void *state)
{
    mca_coll_ucx_convertor_t *conv_state = state;
    size_t packed_bytes;

    opal_convertor_get_packed_size(&conv_state->opal_conv, &packed_bytes);

    return packed_bytes;
}

/*
 * "pack" callback of the generic UCX datatype.
 *
 * Repositions the wrapper's OPAL convertor at `offset` and packs into a
 * single-entry iovec that describes `dest`.
 *
 * state      - convertor wrapper returned by the start_pack callback
 * offset     - position handed to opal_convertor_set_position()
 * dest       - destination buffer for the packed bytes
 * max_length - capacity of `dest` in bytes
 *
 * Returns the length value left behind by opal_convertor_pack() (seeded with
 * `max_length`). The return codes of the OPAL calls are not inspected.
 */
static size_t coll_ucx_generic_datatype_pack(void *state, size_t offset,
                                             void *dest, size_t max_length)
{
    mca_coll_ucx_convertor_t *conv_state = state;
    struct iovec out_vec = { .iov_base = dest, .iov_len = max_length };
    uint32_t vec_count = 1;
    size_t packed = max_length;

    opal_convertor_set_position(&conv_state->opal_conv, &offset);
    opal_convertor_pack(&conv_state->opal_conv, &out_vec, &vec_count, &packed);

    return packed;
}

/*
 * "unpack" callback of the generic UCX datatype.
 *
 * state  - convertor wrapper returned by the start_unpack callback
 * offset - position of this fragment within the packed stream
 * src    - packed bytes to unpack
 * length - number of bytes available at `src`
 *
 * When `offset` matches the offset the wrapper expects next, the fragment is
 * fed to the wrapper's own convertor and the expected offset advances by the
 * length reported back from opal_convertor_unpack(). Otherwise a temporary
 * convertor is built for the same datatype/count/base buffer, positioned at
 * `offset`, used for this one fragment and destroyed again.
 *
 * Always returns UCS_OK; the return codes of the OPAL calls are not inspected.
 */
static ucs_status_t coll_ucx_generic_datatype_unpack(void *state, size_t offset,
                                                     const void *src, size_t length)
{
    mca_coll_ucx_convertor_t *conv_state = state;
    struct iovec in_vec = { .iov_base = (void*)src, .iov_len = length };
    uint32_t vec_count = 1;
    opal_convertor_t scratch_conv;

    /* Fast path: the fragment arrived in order. */
    if (offset == conv_state->offset) {
        opal_convertor_unpack(&conv_state->opal_conv, &in_vec, &vec_count, &length);
        conv_state->offset += length;
        return UCS_OK;
    }

    /* An unordered fragment arrived - unpack it through a separate,
     * short-lived convertor so the main one is left untouched. */
    OBJ_CONSTRUCT(&scratch_conv, opal_convertor_t);
    opal_convertor_copy_and_prepare_for_recv(ompi_proc_local_proc->super.proc_convertor,
                                             &conv_state->datatype->super,
                                             conv_state->opal_conv.count,
                                             conv_state->opal_conv.pBaseBuf, 0,
                                             &scratch_conv);
    opal_convertor_set_position(&scratch_conv, &offset);
    opal_convertor_unpack(&scratch_conv, &in_vec, &vec_count, &length);
    opal_convertor_cleanup(&scratch_conv);
    OBJ_DESTRUCT(&scratch_conv);

    /* permanently switch to un-ordered mode (the expected offset is reset) */
    conv_state->offset = 0;

    return UCS_OK;
}

/*
 * "finish" callback of the generic UCX datatype.
 *
 * Undoes what the start_* callbacks set up, in this order: cleans up the
 * embedded OPAL convertor, drops the datatype reference taken at start, and
 * hands the wrapper back to the component's convertor free list.
 *
 * state - convertor wrapper returned by one of the start_* callbacks
 */
static void coll_ucx_generic_datatype_finish(void *state)
{
    mca_coll_ucx_convertor_t *conv_state = state;

    opal_convertor_cleanup(&conv_state->opal_conv);

    /* Balances the OMPI_DATATYPE_RETAIN done when the state was created. */
    OMPI_DATATYPE_RELEASE(conv_state->datatype);

    COLL_UCX_FREELIST_RETURN(&mca_coll_ucx_component.convs, &conv_state->super);
}

/*
 * Callback table passed to ucp_dt_create_generic() for non-contiguous OMPI
 * datatypes. Entries are grouped by direction; designated initializers make
 * the order irrelevant to the resulting object.
 */
static ucp_generic_dt_ops_t coll_ucx_generic_datatype_ops = {
    /* send side */
    .start_pack   = coll_ucx_generic_datatype_start_pack,
    .packed_size  = coll_ucx_generic_datatype_packed_size,
    .pack         = coll_ucx_generic_datatype_pack,
    /* receive side */
    .start_unpack = coll_ucx_generic_datatype_start_unpack,
    .unpack       = coll_ucx_generic_datatype_unpack,
    /* common teardown */
    .finish       = coll_ucx_generic_datatype_finish
};

/*
 * Releases the UCX resources attached to an OMPI datatype.
 *
 * NOTE(review): presumably registered as the delete callback of
 * mca_coll_ucx_component.datatype_attr_keyval - the registration is not
 * visible in this file; confirm against the keyval creation site.
 *
 * datatype - OMPI datatype whose cached UCX data is being dropped
 * keyval   - unused
 * attr_val - the UCX datatype handle, stored as the attribute value
 * extra    - unused
 *
 * With HAVE_UCP_REQUEST_PARAM_T, `pml_data` holds a heap pointer that is
 * freed here; otherwise it is asserted to equal the UCX handle itself. In both
 * cases the UCX datatype is destroyed and `pml_data` is reset to
 * COLL_UCX_DATATYPE_INVALID.
 *
 * Always returns OMPI_SUCCESS.
 */
int mca_coll_ucx_datatype_attr_del_fn(ompi_datatype_t* datatype, int keyval,
                                      void *attr_val, void *extra)
{
    ucp_datatype_t ucx_handle = (ucp_datatype_t)attr_val;

#ifdef HAVE_UCP_REQUEST_PARAM_T
    free((void*)datatype->pml_data);
#else
    COLL_UCX_ASSERT((uint64_t)ucx_handle == datatype->pml_data);
#endif

    ucp_dt_destroy(ucx_handle);
    datatype->pml_data = COLL_UCX_DATATYPE_INVALID;

    return OMPI_SUCCESS;
}

/*
 * Tells whether `datatype` can be described to UCX as a plain contiguous
 * buffer: it must carry both the CONTIGUOUS and the NO_GAPS flag and have a
 * lower bound of zero.
 *
 * Returns 1 when all three conditions hold, 0 otherwise.
 */
__opal_attribute_always_inline__
static inline int mca_coll_ucx_datatype_is_contig(ompi_datatype_t *datatype)
{
    ptrdiff_t lower_bound;

    ompi_datatype_type_lb(datatype, &lower_bound);

    if (!(datatype->super.flags & OPAL_DATATYPE_FLAG_CONTIGUOUS)) {
        return 0;
    }
    if (!(datatype->super.flags & OPAL_DATATYPE_FLAG_NO_GAPS)) {
        return 0;
    }
    return lower_bound == 0;
}

#ifdef HAVE_UCP_REQUEST_PARAM_T
/*
 * Allocates and fills the per-datatype descriptor that is stored in
 * `datatype->pml_data` when HAVE_UCP_REQUEST_PARAM_T is defined.
 *
 * datatype     - OMPI datatype being described
 * ucp_datatype - UCX datatype handle already created for `datatype`
 * size         - element size in bytes (0 when the caller did not compute it)
 *
 * The descriptor records the UCX handle and a `size_shift`:
 *   - log2(size) when the datatype is contiguous and `size` is a power of two;
 *   - 0 in every other case.
 *
 * Returns the heap-allocated descriptor; it is released with free() by the
 * datatype attribute delete callback. Allocation failure is reported through
 * the fatal MPI error handler.
 */
__opal_attribute_always_inline__ static inline
coll_ucx_datatype_t *mca_coll_ucx_init_nbx_datatype(ompi_datatype_t *datatype,
                                                    ucp_datatype_t ucp_datatype,
                                                    size_t size)
{
    coll_ucx_datatype_t *pml_datatype;
    int is_contig_pow2;
    int shift = 0;

    pml_datatype = malloc(sizeof(*pml_datatype));
    if (pml_datatype == NULL) {
        int err = MPI_ERR_INTERN;
        COLL_UCX_ERROR("Failed to allocate datatype structure");
        /* TODO: this error should return to the caller and invoke an error
         * handler from the MPI API call.
         * For now, it is fatal. */
        ompi_mpi_errors_are_fatal_comm_handler(NULL, &err, "Failed to allocate datatype structure");
    }

    pml_datatype->datatype = ucp_datatype;

    is_contig_pow2 = mca_coll_ucx_datatype_is_contig(datatype) &&
                     (size && !(size & (size - 1))); /* is_pow2(size) */
    if (is_contig_pow2) {
        /* Exact integer log2: count how many times the single set bit can be
         * shifted right. Truncating log(size) / log(2.0) instead can yield a
         * value one too small when the floating-point quotient lands just
         * below the integer. */
        size_t remaining;

        for (remaining = size; remaining > 1; remaining >>= 1) {
            ++shift;
        }
    }
    pml_datatype->size_shift = shift;

    return pml_datatype;
}
#endif

/*
 * Creates the UCX datatype that corresponds to an OMPI datatype and caches it
 * in `datatype->pml_data`.
 *
 * datatype - OMPI datatype to translate
 *
 * Contiguous datatypes (see mca_coll_ucx_datatype_is_contig()) map to
 * ucp_dt_make_contig(size) and need no cleanup bookkeeping. Everything else
 * becomes a generic UCX datatype driven by coll_ucx_generic_datatype_ops, and
 * its handle is remembered so it can be destroyed later:
 *   - predefined datatypes: in mca_coll_ucx_component.predefined_types[];
 *   - user datatypes: as an attribute under datatype_attr_keyval.
 *
 * Returns the UCX datatype handle. Failures are reported through the fatal
 * MPI error handler.
 */
ucp_datatype_t mca_coll_ucx_init_datatype(ompi_datatype_t *datatype)
{
    size_t size = 0; /* init to suppress compiler warning */
    ucp_datatype_t ucp_datatype;
    ucs_status_t status;
    int ret;

    if (mca_coll_ucx_datatype_is_contig(datatype)) {
        ompi_datatype_type_size(datatype, &size);
        ucp_datatype = ucp_dt_make_contig(size);
        goto out;
    }

    status = ucp_dt_create_generic(&coll_ucx_generic_datatype_ops,
                                   datatype, &ucp_datatype);
    if (status != UCS_OK) {
        int err = MPI_ERR_INTERN;
        COLL_UCX_ERROR("Failed to create UCX datatype for %s", datatype->name);
        /* TODO: this error should return to the caller and invoke an error
         * handler from the MPI API call.
         * For now, it is fatal. */
        ompi_mpi_errors_are_fatal_comm_handler(NULL, &err, "Failed to create UCX datatype");
    }

    /* Add custom attribute, to clean up UCX resources when OMPI datatype is
     * released.
     */
    if (ompi_datatype_is_predefined(datatype)) {
        /* Bound the index by the array actually being written, rather than by
         * a constant that may differ from the one used in its declaration. */
        COLL_UCX_ASSERT((size_t)datatype->id <
                        (sizeof(mca_coll_ucx_component.predefined_types) /
                         sizeof(mca_coll_ucx_component.predefined_types[0])));
        mca_coll_ucx_component.predefined_types[datatype->id] = ucp_datatype;
    } else {
        ret = ompi_attr_set_c(TYPE_ATTR, datatype, &datatype->d_keyhash,
                              mca_coll_ucx_component.datatype_attr_keyval,
                              (void*)ucp_datatype, false);
        if (ret != OMPI_SUCCESS) {
            int err = MPI_ERR_INTERN;
            COLL_UCX_ERROR("Failed to add UCX datatype attribute for %s (%p): %d",
                           datatype->name, (void*)datatype, ret);
            /* TODO: this error should return to the caller and invoke an error
             * handler from the MPI API call.
             * For now, it is fatal. */
            ompi_mpi_errors_are_fatal_comm_handler(NULL, &err, "Failed to add UCX datatype attribute");
        }
    }
out:
    /* Reached for both the contiguous and the generic case. */
    COLL_UCX_VERBOSE(7, "created generic UCX datatype 0x%"PRIx64, ucp_datatype);

#ifdef HAVE_UCP_REQUEST_PARAM_T
    /* pml_data must be wide enough to carry a descriptor pointer. */
    UCS_STATIC_ASSERT(sizeof(datatype->pml_data) >= sizeof(coll_ucx_datatype_t*));
    datatype->pml_data = (uint64_t)mca_coll_ucx_init_nbx_datatype(datatype,
                                                                  ucp_datatype,
                                                                  size);
#else
    datatype->pml_data = ucp_datatype;
#endif

    return ucp_datatype;
}

/*
 * Class constructor for mca_coll_ucx_convertor_t: constructs the OPAL
 * convertor embedded in the wrapper.
 */
static void mca_coll_ucx_convertor_construct(mca_coll_ucx_convertor_t *convertor)
{
    opal_convertor_t *embedded = &convertor->opal_conv;

    OBJ_CONSTRUCT(embedded, opal_convertor_t);
}

/*
 * Class destructor for mca_coll_ucx_convertor_t: destructs the OPAL convertor
 * embedded in the wrapper.
 */
static void mca_coll_ucx_convertor_destruct(mca_coll_ucx_convertor_t *convertor)
{
    opal_convertor_t *embedded = &convertor->opal_conv;

    OBJ_DESTRUCT(embedded);
}

/*
 * Class definition for the convertor wrapper: derives from
 * opal_free_list_item_t and uses the construct/destruct functions above.
 */
OBJ_CLASS_INSTANCE(mca_coll_ucx_convertor_t,
opal_free_list_item_t,
mca_coll_ucx_convertor_construct,
mca_coll_ucx_convertor_destruct);
Loading

0 comments on commit 3fc5793

Please sign in to comment.