From 6247e315ef5edc23ba95fe680de34c0931bd4629 Mon Sep 17 00:00:00 2001 From: Valentin Petrov Date: Thu, 30 Jan 2020 21:41:40 +0200 Subject: [PATCH] ompi/coll/mccl Mellanox collective communication library (MCCL) integration layer --- config/ompi_check_mccl.m4 | 75 ++++++ ompi/mca/coll/base/coll_tags.h | 4 +- ompi/mca/coll/mccl/Makefile.am | 46 ++++ ompi/mca/coll/mccl/coll_mccl.h | 103 ++++++++ ompi/mca/coll/mccl/coll_mccl_component.c | 163 +++++++++++++ ompi/mca/coll/mccl/coll_mccl_debug.h | 30 +++ ompi/mca/coll/mccl/coll_mccl_dtypes.h | 78 ++++++ ompi/mca/coll/mccl/coll_mccl_module.c | 289 +++++++++++++++++++++++ ompi/mca/coll/mccl/coll_mccl_ops.c | 90 +++++++ ompi/mca/coll/mccl/configure.m4 | 38 +++ 10 files changed, 915 insertions(+), 1 deletion(-) create mode 100644 config/ompi_check_mccl.m4 create mode 100644 ompi/mca/coll/mccl/Makefile.am create mode 100644 ompi/mca/coll/mccl/coll_mccl.h create mode 100644 ompi/mca/coll/mccl/coll_mccl_component.c create mode 100644 ompi/mca/coll/mccl/coll_mccl_debug.h create mode 100644 ompi/mca/coll/mccl/coll_mccl_dtypes.h create mode 100644 ompi/mca/coll/mccl/coll_mccl_module.c create mode 100644 ompi/mca/coll/mccl/coll_mccl_ops.c create mode 100644 ompi/mca/coll/mccl/configure.m4 diff --git a/config/ompi_check_mccl.m4 b/config/ompi_check_mccl.m4 new file mode 100644 index 00000000000..b333d726900 --- /dev/null +++ b/config/ompi_check_mccl.m4 @@ -0,0 +1,75 @@ +dnl -*- shell-script -*- +dnl +dnl Copyright (c) 2011 Mellanox Technologies. All rights reserved. +dnl Copyright (c) 2013 Cisco Systems, Inc. All rights reserved. +dnl Copyright (c) 2015 Research Organization for Information Science +dnl and Technology (RIST). All rights reserved. +dnl $COPYRIGHT$ +dnl +dnl Additional copyrights may follow +dnl +dnl $HEADER$ +dnl + +# OMPI_CHECK_MCCL(prefix, [action-if-found], [action-if-not-found]) +# -------------------------------------------------------- +# check if mccl support can be found. 
sets prefix_{CPPFLAGS,
+#                 LDFLAGS, LIBS} as needed and runs action-if-found if there is
+#                 support, otherwise executes action-if-not-found
+AC_DEFUN([OMPI_CHECK_MCCL],[
+    OPAL_VAR_SCOPE_PUSH([ompi_check_mccl_dir ompi_check_mccl_libs ompi_check_mccl_happy CPPFLAGS_save LDFLAGS_save LIBS_save])
+
+    AC_ARG_WITH([mccl],
+        [AC_HELP_STRING([--with-mccl(=DIR)],
+             [Build mccl (Mellanox collective communication library) support, optionally adding
+              DIR/include and DIR/lib or DIR/lib64 to the search path for headers and libraries])])
+
+    AS_IF([test "$with_mccl" != "no"],
+          [ompi_check_mccl_libs=mccl
+           AS_IF([test ! -z "$with_mccl" && test "$with_mccl" != "yes"],
+                 [ompi_check_mccl_dir=$with_mccl])
+
+           CPPFLAGS_save=$CPPFLAGS
+           LDFLAGS_save=$LDFLAGS
+           LIBS_save=$LIBS
+
+           OPAL_LOG_MSG([$1_CPPFLAGS : $$1_CPPFLAGS], 1)
+           OPAL_LOG_MSG([$1_LDFLAGS : $$1_LDFLAGS], 1)
+           OPAL_LOG_MSG([$1_LIBS : $$1_LIBS], 1)
+
+           OPAL_CHECK_PACKAGE([$1],
+                              [api/mccl.h],
+                              [$ompi_check_mccl_libs],
+                              [mccl_init_context],
+                              [],
+                              [$ompi_check_mccl_dir],
+                              [],
+                              [ompi_check_mccl_happy="yes"],
+                              [ompi_check_mccl_happy="no"])
+
+           AS_IF([test "$ompi_check_mccl_happy" = "yes"],
+                 [
+                     # Probe optional entry points with the flags computed for
+                     # the caller-supplied prefix while keeping the user flags.
+                     CPPFLAGS="$$1_CPPFLAGS $CPPFLAGS_save"
+                     LDFLAGS="$$1_LDFLAGS $LDFLAGS_save"
+                     LIBS="$$1_LIBS $LIBS_save"
+                     AC_CHECK_FUNCS(mccl_comm_free, [], [])
+                 ],
+                 [])
+
+           CPPFLAGS=$CPPFLAGS_save
+           LDFLAGS=$LDFLAGS_save
+           LIBS=$LIBS_save],
+          [ompi_check_mccl_happy=no])
+
+    AS_IF([test "$ompi_check_mccl_happy" = "yes" && test "$enable_progress_threads" = "yes"],
+          [AC_MSG_WARN([mccl driver does not currently support progress threads. Disabling MCCL.])
+           ompi_check_mccl_happy="no"])
+
+    AS_IF([test "$ompi_check_mccl_happy" = "yes"],
+          [$2],
+          [AS_IF([test ! -z "$with_mccl" && test "$with_mccl" != "no"],
+                 [AC_MSG_ERROR([MCCL support requested but not found. Aborting])])
+           $3])
+
+    OPAL_VAR_SCOPE_POP
+])
diff --git a/ompi/mca/coll/base/coll_tags.h b/ompi/mca/coll/base/coll_tags.h
index 2bcf2a6cc95..2df50ad1736 100644
--- a/ompi/mca/coll/base/coll_tags.h
+++ b/ompi/mca/coll/base/coll_tags.h
@@ -41,7 +41,9 @@
 #define MCA_COLL_BASE_TAG_SCAN -24
 #define MCA_COLL_BASE_TAG_SCATTER -25
 #define MCA_COLL_BASE_TAG_SCATTERV -26
-#define MCA_COLL_BASE_TAG_NONBLOCKING_BASE -27
+#define MCA_COLL_BASE_TAG_MCCL -27
+
+#define MCA_COLL_BASE_TAG_NONBLOCKING_BASE -28
 #define MCA_COLL_BASE_TAG_NONBLOCKING_END ((-1 * INT_MAX/2) + 1)
 #define MCA_COLL_BASE_TAG_NEIGHBOR_BASE (MCA_COLL_BASE_TAG_NONBLOCKING_END - 1)
 #define MCA_COLL_BASE_TAG_NEIGHBOR_END (MCA_COLL_BASE_TAG_NEIGHBOR_BASE - 1024)
diff --git a/ompi/mca/coll/mccl/Makefile.am b/ompi/mca/coll/mccl/Makefile.am
new file mode 100644
index 00000000000..9a6f492c6b3
--- /dev/null
+++ b/ompi/mca/coll/mccl/Makefile.am
@@ -0,0 +1,46 @@
+# -*- shell-script -*-
+#
+#
+# Copyright (c) 2020 Mellanox Technologies. All rights reserved.
+# $COPYRIGHT$
+#
+# Additional copyrights may follow
+#
+# $HEADER$
+#
+#
+
+AM_CPPFLAGS = $(coll_mccl_CPPFLAGS)
+
+coll_mccl_sources = \
+        coll_mccl.h \
+        coll_mccl_debug.h \
+        coll_mccl_dtypes.h \
+        coll_mccl_module.c \
+        coll_mccl_component.c \
+        coll_mccl_ops.c
+
+# Make the output library in this directory, and name it either
+# mca_<type>_<name>.la (for DSO builds) or libmca_<type>_<name>.la
+# (for static builds).
+
+if MCA_BUILD_ompi_coll_mccl_DSO
+component_noinst =
+component_install = mca_coll_mccl.la
+else
+component_noinst = libmca_coll_mccl.la
+component_install =
+endif
+
+mcacomponentdir = $(ompilibdir)
+mcacomponent_LTLIBRARIES = $(component_install)
+mca_coll_mccl_la_SOURCES = $(coll_mccl_sources)
+mca_coll_mccl_la_LIBADD = $(top_builddir)/ompi/lib@OMPI_LIBMPI_NAME@.la \
+        $(coll_mccl_LIBS)
+mca_coll_mccl_la_LDFLAGS = -module -avoid-version $(coll_mccl_LDFLAGS)
+
+noinst_LTLIBRARIES = $(component_noinst)
+libmca_coll_mccl_la_SOURCES = $(coll_mccl_sources)
+libmca_coll_mccl_la_LIBADD = $(coll_mccl_LIBS)
+libmca_coll_mccl_la_LDFLAGS = -module -avoid-version $(coll_mccl_LDFLAGS)
+
diff --git a/ompi/mca/coll/mccl/coll_mccl.h b/ompi/mca/coll/mccl/coll_mccl.h
new file mode 100644
index 00000000000..7244ca75cb6
--- /dev/null
+++ b/ompi/mca/coll/mccl/coll_mccl.h
@@ -0,0 +1,103 @@
+/**
+  Copyright (c) 2020 Mellanox Technologies. All rights reserved.
+  $COPYRIGHT$
+
+  Additional copyrights may follow
+
+  $HEADER$
+ */
+
+#ifndef MCA_COLL_MCCL_H
+#define MCA_COLL_MCCL_H
+
+#include "ompi_config.h"
+
+#include "mpi.h"
+#include "ompi/mca/mca.h"
+#include "opal/memoryhooks/memory.h"
+#include "opal/mca/memory/base/base.h"
+#include "ompi/mca/coll/coll.h"
+#include "ompi/request/request.h"
+#include "ompi/mca/pml/pml.h"
+#include "ompi/mca/coll/base/coll_tags.h"
+#include "ompi/communicator/communicator.h"
+#include "ompi/attribute/attribute.h"
+#include "ompi/op/op.h"
+
+#include "orte/runtime/orte_globals.h"
+
+#include "api/mccl.h"
+
+#include "coll_mccl_debug.h"
+#ifndef MCCL_VERSION
+#define MCCL_VERSION(major, minor) (((major)<
+
+#include
+#include
+
+#include "coll_mccl.h"
+#include "opal/mca/installdirs/installdirs.h"
+#include "coll_mccl_dtypes.h"
+
+/*
+ * Public string showing the coll ompi_mccl component version number
+ */
+const char *mca_coll_mccl_component_version_string =
+    "Open MPI MCCL collective MCA component version " OMPI_VERSION;
+
+
+static int mccl_open(void);
+static int mccl_close(void);
+static int mccl_register(void);
+/* Output stream id used by MCCL_VERBOSE/MCCL_ERROR; opened in mccl_open(). */
+int mca_coll_mccl_output = -1;
+mca_coll_mccl_component_t mca_coll_mccl_component = {
+    /* First, the mca_component_t struct containing meta information
+       about the component itself */
+    {
+        .collm_version = {
+            MCA_COLL_BASE_VERSION_2_0_0,
+
+            /* Component name and version */
+            .mca_component_name = "mccl",
+            MCA_BASE_MAKE_VERSION(component, OMPI_MAJOR_VERSION, OMPI_MINOR_VERSION,
+                                  OMPI_RELEASE_VERSION),
+
+            /* Component open and close functions */
+            .mca_open_component = mccl_open,
+            .mca_close_component = mccl_close,
+            .mca_register_component_params = mccl_register,
+        },
+        .collm_data = {
+            /* The component is not checkpoint ready */
+            MCA_BASE_METADATA_PARAM_NONE
+        },
+
+        /* Initialization / querying functions */
+        .collm_init_query = mca_coll_mccl_init_query,
+        .collm_comm_query = mca_coll_mccl_comm_query,
+    },
+    /* NOTE(review): positional initializers; the struct layout is not visible
+       here, so confirm the order still matches the field declarations. The
+       values below are overwritten by mccl_register() anyway. */
+    120, /* priority */
+    0, /* verbose level */
+    0, /* mccl_enable */
+    NULL /*mccl version */
+};
+
+/* Validation flags accepted by reg_int(). */
+enum {
+    REGINT_NEG_ONE_OK = 0x01,
+    REGINT_GE_ZERO = 0x02,
+    REGINT_GE_ONE = 0x04,
+    REGINT_NONZERO = 0x08,
+    REGINT_MAX = 0x88
+};
+
+enum {
+    REGSTR_EMPTY_OK = 0x01,
+    REGSTR_MAX = 0x88
+};
+
+
+/*
+ * Utility routine for integer parameter registration.
+ *
+ * Stores default_value in *storage, registers the variable with the MCA
+ * variable system (optionally with a deprecated synonym) and then validates
+ * the value in *storage against the REGINT_* bits in flags.
+ *
+ * Returns OMPI_SUCCESS, or OMPI_ERR_BAD_PARAM when the value violates one of
+ * the requested constraints.
+ */
+static int reg_int(const char* param_name, const char* deprecated_param_name,
+                   const char* param_desc, int default_value, int *storage, int flags)
+{
+    int index;
+    *storage = default_value;
+    index = mca_base_component_var_register(&mca_coll_mccl_component.super.collm_version,
+                                            param_name, param_desc, MCA_BASE_VAR_TYPE_INT,
+                                            NULL, 0, 0,OPAL_INFO_LVL_9,
+                                            MCA_BASE_VAR_SCOPE_READONLY, storage);
+    if (NULL != deprecated_param_name) {
+        (void) mca_base_var_register_synonym(index, "ompi", "coll", "mccl", deprecated_param_name,
+                                             MCA_BASE_VAR_SYN_FLAG_DEPRECATED);
+    }
+
+    /* -1 is accepted as a sentinel when the caller explicitly allows it */
+    if (0 != (flags & REGINT_NEG_ONE_OK) && -1 == *storage) {
+        return OMPI_SUCCESS;
+    }
+
+    if ((0 != (flags & REGINT_GE_ZERO) && *storage < 0) ||
+        (0 != (flags & REGINT_GE_ONE) && *storage < 1) ||
+        (0 != (flags & REGINT_NONZERO) && 0 == *storage)) {
+        opal_output(0, "Bad parameter value for parameter \"%s\"", param_name);
+        return OMPI_ERR_BAD_PARAM;
+    }
+    return OMPI_SUCCESS;
+}
+
+
+/*
+ * Register the component's MCA parameters (priority, verbose, enable, np)
+ * and the two read-only library version strings.
+ *
+ * Every registration is attempted; the status of the last failing reg_int()
+ * call, if any, is returned.
+ */
+static int mccl_register(void)
+{
+    int ret, tmp;
+    ret = OMPI_SUCCESS;
+
+#define CHECK(expr) do { \
+        tmp = (expr); \
+        if (OMPI_SUCCESS != tmp) ret = tmp; \
+    } while (0)
+
+
+    CHECK(reg_int("priority", NULL, "Priority of the mccl coll component",
+                  120, &mca_coll_mccl_component.mccl_priority, 0));
+
+    CHECK(reg_int("verbose", NULL, "Verbose level of the mccl coll component",
+                  0, &mca_coll_mccl_component.mccl_verbose, 0));
+
+    CHECK(reg_int("enable", NULL, "[1|0] Enable/Disable MCCL",
+                  1, &mca_coll_mccl_component.mccl_enable, 0));
+
+    /* the description must quote the default that is actually registered */
+    CHECK(reg_int("np", NULL, "Minimal number of processes in the communicator"
+                  " for the corresponding mccl context to be created (default: 2)",
+                  2, &mca_coll_mccl_component.mccl_np, 0));
+
+    /* mca_coll_mccl_component.compiletime_version = MCCL_VERNO_STRING; */
+
+    mca_base_component_var_register(&mca_coll_mccl_component.super.collm_version,
+                                    MCA_COMPILETIME_VER,
+                                    "Version of the libmccl library with which Open MPI was compiled",
+                                    MCA_BASE_VAR_TYPE_VERSION_STRING, NULL, 0, 0,
+                                    OPAL_INFO_LVL_3, MCA_BASE_VAR_SCOPE_READONLY,
+                                    &mca_coll_mccl_component.compiletime_version);
+    /* mca_coll_mccl_component.runtime_version = mccl_get_version(); */
+    mca_base_component_var_register(&mca_coll_mccl_component.super.collm_version,
+                                    MCA_RUNTIME_VER,
+                                    "Version of the libmccl library with which Open MPI is running",
+                                    MCA_BASE_VAR_TYPE_VERSION_STRING, NULL, 0, 0,
+                                    OPAL_INFO_LVL_3, MCA_BASE_VAR_SCOPE_READONLY,
+                                    &mca_coll_mccl_component.runtime_version);
+    return ret;
+}
+
+/*
+ * Component open: create the output stream used by the logging macros, apply
+ * the configured verbosity and mark the MCCL library as not yet initialized.
+ */
+static int mccl_open(void)
+{
+    mca_coll_mccl_component_t *cm;
+    cm = &mca_coll_mccl_component;
+    mca_coll_mccl_output = opal_output_open(NULL);
+    opal_output_set_verbosity(mca_coll_mccl_output, cm->mccl_verbose);
+    cm->libmccl_initialized = false;
+    return OMPI_SUCCESS;
+}
+
+/* Component close: nothing to tear down here. */
+static int mccl_close(void)
+{
+    return OMPI_SUCCESS;
+}
diff --git a/ompi/mca/coll/mccl/coll_mccl_debug.h b/ompi/mca/coll/mccl/coll_mccl_debug.h
new file mode 100644
index 00000000000..aa37929ca5d
--- /dev/null
+++ b/ompi/mca/coll/mccl/coll_mccl_debug.h
@@ -0,0 +1,30 @@
+/**
+  Copyright (c) 2020 Mellanox Technologies. All rights reserved.
+  $COPYRIGHT$
+
+  Additional copyrights may follow
+
+  $HEADER$
+ */
+
+#ifndef COLL_MCCL_DEBUG_H
+#define COLL_MCCL_DEBUG_H
+#include "ompi_config.h"
+#pragma GCC system_header
+
+#ifdef __BASE_FILE__
+#define __MCCL_FILE__ __BASE_FILE__
+#else
+#define __MCCL_FILE__ __FILE__
+#endif
+
+#define MCCL_VERBOSE(level, format, ...) \
+    opal_output_verbose(level, mca_coll_mccl_output, "%s:%d - %s() " format, \
+                        __MCCL_FILE__, __LINE__, __FUNCTION__, ## __VA_ARGS__)
+
+#define MCCL_ERROR(format, ... ) \
+    opal_output_verbose(0, mca_coll_mccl_output, "Error: %s:%d - %s() " format, \
+                        __MCCL_FILE__, __LINE__, __FUNCTION__, ## __VA_ARGS__)
+
+extern int mca_coll_mccl_output;
+#endif
diff --git a/ompi/mca/coll/mccl/coll_mccl_dtypes.h b/ompi/mca/coll/mccl/coll_mccl_dtypes.h
new file mode 100644
index 00000000000..6fea81aa52b
--- /dev/null
+++ b/ompi/mca/coll/mccl/coll_mccl_dtypes.h
@@ -0,0 +1,78 @@
+#ifndef COLL_MCCL_DTYPES_H
+#define COLL_MCCL_DTYPES_H
+#include "ompi/datatype/ompi_datatype.h"
+#include "ompi/datatype/ompi_datatype_internal.h"
+#include "ompi/mca/op/op.h"
+#include "api/mccl.h"
+
+
+/*
+ * Translation table indexed by the OPAL datatype id (see the id noted on each
+ * row). Lookups are bounded by OPAL_DATATYPE_MAX_PREDEFINED in
+ * ompi_dtype_to_tccl_dtype().
+ */
+static tccl_dt_t ompi_datatype_2_tccl_dt[OMPI_DATATYPE_MAX_PREDEFINED] = {
+    TCCL_DT_UNSUPPORTED, /*OPAL_DATATYPE_LOOP 0 */
+    TCCL_DT_UNSUPPORTED, /*OPAL_DATATYPE_END_LOOP 1 */
+    TCCL_DT_UNSUPPORTED, /*OPAL_DATATYPE_LB 2 */
+    TCCL_DT_UNSUPPORTED, /*OPAL_DATATYPE_UB 3 */
+    TCCL_DT_INT8, /*OPAL_DATATYPE_INT1 4 */
+    TCCL_DT_INT16, /*OPAL_DATATYPE_INT2 5 */
+    TCCL_DT_INT32, /*OPAL_DATATYPE_INT4 6 */
+    TCCL_DT_INT64, /*OPAL_DATATYPE_INT8 7 */
+    TCCL_DT_INT128, /*OPAL_DATATYPE_INT16 8 */
+    TCCL_DT_UINT8, /*OPAL_DATATYPE_UINT1 9 */
+    TCCL_DT_UINT16, /*OPAL_DATATYPE_UINT2 10 */
+    TCCL_DT_UINT32, /*OPAL_DATATYPE_UINT4 11 */
+    TCCL_DT_UINT64, /*OPAL_DATATYPE_UINT8 12 */
+    TCCL_DT_UINT128, /*OPAL_DATATYPE_UINT16 13 */
+    TCCL_DT_FLOAT16, /*OPAL_DATATYPE_FLOAT2 14 */
+    TCCL_DT_FLOAT32, /*OPAL_DATATYPE_FLOAT4 15 */
+    TCCL_DT_FLOAT64, /*OPAL_DATATYPE_FLOAT8 16 */
+    TCCL_DT_UNSUPPORTED, /*OPAL_DATATYPE_FLOAT12 17 */
+    TCCL_DT_UNSUPPORTED, /*OPAL_DATATYPE_FLOAT16 18 */
+    TCCL_DT_UNSUPPORTED, /*OPAL_DATATYPE_SHORT_FLOAT_COMPLEX 19 */
+    TCCL_DT_UNSUPPORTED, /*OPAL_DATATYPE_FLOAT_COMPLEX 20 */
+    TCCL_DT_UNSUPPORTED, /*OPAL_DATATYPE_DOUBLE_COMPLEX 21 */
+    TCCL_DT_UNSUPPORTED, /*OPAL_DATATYPE_LONG_DOUBLE_COMPLEX 22 */
+    TCCL_DT_UNSUPPORTED, /*OPAL_DATATYPE_BOOL 23 */
+    TCCL_DT_UNSUPPORTED, /*OPAL_DATATYPE_WCHAR 24 */
+    TCCL_DT_UNSUPPORTED /*OPAL_DATATYPE_UNAVAILABLE 25 */
+};
+
+/*
+ * Map an Open MPI datatype to its TCCL counterpart.
+ *
+ * Only predefined datatypes whose OPAL id lies strictly between 0 and
+ * OPAL_DATATYPE_MAX_PREDEFINED are looked up; anything else yields
+ * TCCL_DT_UNSUPPORTED.
+ */
+static inline tccl_dt_t ompi_dtype_to_tccl_dtype(ompi_datatype_t *dtype)
+{
+    int ompi_type_id = dtype->id;
+    int opal_type_id = dtype->super.id;
+
+    if (ompi_type_id < OMPI_DATATYPE_MPI_MAX_PREDEFINED &&
+        (dtype->super.flags & OMPI_DATATYPE_FLAG_PREDEFINED)) {
+        if (opal_type_id > 0 && opal_type_id < OPAL_DATATYPE_MAX_PREDEFINED) {
+            return ompi_datatype_2_tccl_dt[opal_type_id];
+        }
+    }
+    return TCCL_DT_UNSUPPORTED;
+}
+
+/* Translation table indexed by the op's Fortran-to-C index (o_f_to_c_index). */
+static tccl_op_t ompi_op_to_tccl_op_map[OMPI_OP_BASE_FORTRAN_OP_MAX + 1] = {
+    TCCL_OP_UNSUPPORTED, /* OMPI_OP_BASE_FORTRAN_NULL = 0 */
+    TCCL_OP_MAX, /* OMPI_OP_BASE_FORTRAN_MAX */
+    TCCL_OP_MIN, /* OMPI_OP_BASE_FORTRAN_MIN */
+    TCCL_OP_SUM, /* OMPI_OP_BASE_FORTRAN_SUM */
+    TCCL_OP_PROD, /* OMPI_OP_BASE_FORTRAN_PROD */
+    TCCL_OP_LAND, /* OMPI_OP_BASE_FORTRAN_LAND */
+    TCCL_OP_BAND, /* OMPI_OP_BASE_FORTRAN_BAND */
+    TCCL_OP_LOR, /* OMPI_OP_BASE_FORTRAN_LOR */
+    TCCL_OP_BOR, /* OMPI_OP_BASE_FORTRAN_BOR */
+    TCCL_OP_LXOR, /* OMPI_OP_BASE_FORTRAN_LXOR */
+    TCCL_OP_BXOR, /* OMPI_OP_BASE_FORTRAN_BXOR */
+    TCCL_OP_UNSUPPORTED, /* OMPI_OP_BASE_FORTRAN_MAXLOC */
+    TCCL_OP_UNSUPPORTED, /* OMPI_OP_BASE_FORTRAN_MINLOC */
+    TCCL_OP_UNSUPPORTED, /* OMPI_OP_BASE_FORTRAN_REPLACE */
+    TCCL_OP_UNSUPPORTED, /* OMPI_OP_BASE_FORTRAN_NO_OP */
+    TCCL_OP_UNSUPPORTED /* OMPI_OP_BASE_FORTRAN_OP_MAX */
+};
+
+/*
+ * Map an Open MPI reduction op to its TCCL counterpart.
+ *
+ * Indices outside [0, OMPI_OP_BASE_FORTRAN_OP_MAX] are reported as
+ * TCCL_OP_UNSUPPORTED so the table is never read out of bounds.
+ */
+static inline tccl_op_t ompi_op_to_tccl_op(ompi_op_t *op) {
+    if (op->o_f_to_c_index < 0 ||
+        op->o_f_to_c_index > OMPI_OP_BASE_FORTRAN_OP_MAX) {
+        return TCCL_OP_UNSUPPORTED;
+    }
+    return ompi_op_to_tccl_op_map[op->o_f_to_c_index];
+}
+
+#endif /* COLL_MCCL_DTYPES_H */
diff --git a/ompi/mca/coll/mccl/coll_mccl_module.c b/ompi/mca/coll/mccl/coll_mccl_module.c
new file mode 100644
index 00000000000..e57d4ec0e2d
--- /dev/null
+++ b/ompi/mca/coll/mccl/coll_mccl_module.c
@@ -0,0 +1,289 @@
+/**
+ * Copyright (c) 2020 Mellanox Technologies. All rights reserved.
+ * $COPYRIGHT$
+ *
+ * Additional copyrights may follow
+ *
+ * $HEADER$
+ */
+
+#include "ompi_config.h"
+#include "coll_mccl.h"
+#include "coll_mccl_dtypes.h"
+
+int mccl_comm_attr_keyval;
+/*
+ * Initial query function that is invoked during MPI_INIT, allowing
+ * this module to indicate what level of thread support it provides.
+ */
+int mca_coll_mccl_init_query(bool enable_progress_threads, bool enable_mpi_threads)
+{
+    return OMPI_SUCCESS;
+}
+
+/*
+ * Reset every handle the module owns: the MCCL communicator plus the saved
+ * fallback collectives and their modules. The module pointers are cleared
+ * too, so the destructor never releases an uninitialized pointer when
+ * saving the fallback handlers stopped part-way.
+ */
+static void mca_coll_mccl_module_clear(mca_coll_mccl_module_t *mccl_module)
+{
+    mccl_module->mccl_comm = NULL;
+    mccl_module->previous_allreduce = NULL;
+    mccl_module->previous_allreduce_module = NULL;
+    mccl_module->previous_barrier = NULL;
+    mccl_module->previous_barrier_module = NULL;
+    mccl_module->previous_bcast = NULL;
+    mccl_module->previous_bcast_module = NULL;
+}
+
+/* Object constructor: start from a fully cleared module. */
+static void mca_coll_mccl_module_construct(mca_coll_mccl_module_t *mccl_module)
+{
+    mca_coll_mccl_module_clear(mccl_module);
+}
+
+/* Wrapped in do/while(0) so the macro behaves as one statement in if/else. */
+#define OBJ_RELEASE_IF_NOT_NULL( obj ) do { \
+        if( NULL != (obj) ) { \
+            OBJ_RELEASE( obj ); \
+        } \
+    } while (0)
+
+/* Progress callback registered with opal_progress: drives the MCCL context. */
+int mca_coll_mccl_progress(void)
+{
+    mccl_progress(mca_coll_mccl_component.mccl_context);
+    return OPAL_SUCCESS;
+}
+
+/*
+ * Object destructor: frees the communicator attribute keyval when the module
+ * belongs to MPI_COMM_WORLD, drops the references taken on the fallback
+ * modules and clears the module.
+ */
+static void mca_coll_mccl_module_destruct(mca_coll_mccl_module_t *mccl_module)
+{
+    if (mccl_module->comm == &ompi_mpi_comm_world.comm){
+        if (OMPI_SUCCESS != ompi_attr_free_keyval(COMM_ATTR, &mccl_comm_attr_keyval, 0)) {
+            MCCL_VERBOSE(1,"mccl ompi_attr_free_keyval failed");
+        }
+    }
+
+    /* If mccl_comm is NULL then we are destroying a module that never
+       initialized its fallback colls/modules: just clear and return.
+       Otherwise release the references held on the fallback modules. */
+
+    if (mccl_module->mccl_comm != NULL){
+        OBJ_RELEASE_IF_NOT_NULL(mccl_module->previous_allreduce_module);
+        OBJ_RELEASE_IF_NOT_NULL(mccl_module->previous_barrier_module);
+        OBJ_RELEASE_IF_NOT_NULL(mccl_module->previous_bcast_module);
+    }
+    mca_coll_mccl_module_clear(mccl_module);
+}
+
+/*
+ * Remember the collective currently installed on the communicator (function
+ * and module) and retain its module. Returns OMPI_ERROR from the enclosing
+ * function when either one is missing.
+ */
+#define SAVE_PREV_COLL_API(__api) do { \
+        mccl_module->previous_ ## __api = comm->c_coll->coll_ ## __api; \
+        mccl_module->previous_ ## __api ## _module = comm->c_coll->coll_ ## __api ## _module; \
+        if (!comm->c_coll->coll_ ## __api || !comm->c_coll->coll_ ## __api ## _module) { \
+            return OMPI_ERROR; \
+        } \
+        OBJ_RETAIN(mccl_module->previous_ ## __api ## _module); \
+    } while(0)
+
+/* Save the fallback allreduce, barrier and bcast of the module's communicator. */
+static int mca_coll_mccl_save_coll_handlers(mca_coll_mccl_module_t *mccl_module)
+{
+    ompi_communicator_t *comm;
+    comm = mccl_module->comm;
+    SAVE_PREV_COLL_API(allreduce);
+    SAVE_PREV_COLL_API(barrier);
+    SAVE_PREV_COLL_API(bcast);
+    return OMPI_SUCCESS;
+}
+
+
+
+/*
+** Communicator free callback: frees the MCCL communicator and, for
+** MPI_COMM_WORLD, finalizes the library and unregisters the progress callback.
+*/
+static int mccl_comm_attr_del_fn(MPI_Comm comm, int keyval, void *attr_val, void *extra)
+{
+
+    mca_coll_mccl_module_t *mccl_module;
+    mccl_module = (mca_coll_mccl_module_t*) attr_val;
+    mccl_comm_free(mccl_module->mccl_comm);
+    if (mccl_module->comm == &ompi_mpi_comm_world.comm) {
+        if (mca_coll_mccl_component.libmccl_initialized) {
+            MCCL_VERBOSE(5,"MCCL FINALIZE");
+            if (TCCL_OK != mccl_finalize(mca_coll_mccl_component.mccl_context)) {
+                MCCL_VERBOSE(1,"MCCL library finalize failed");
+            }
+            opal_progress_unregister(mca_coll_mccl_progress);
+        }
+    }
+    return OMPI_SUCCESS;
+}
+
+
+static int oob_allgather(void *sbuf, void *rbuf, size_t msglen,
+                         int my_rank, int *ranks, int nranks, void *oob_coll_ctx) {
+    ompi_communicator_t *comm = (ompi_communicator_t *)oob_coll_ctx;
+    if (!comm) comm = &ompi_mpi_comm_world.comm;
+    if (ranks == NULL) {
+        comm->c_coll->coll_allgather(sbuf, msglen, MPI_BYTE,
+                                     rbuf, msglen, MPI_BYTE, comm,
+                                     comm->c_coll->coll_allgather_module);
+    } else {
+        if (my_rank == ranks[0]) {
+            int i;
+            memcpy(rbuf, sbuf, msglen);
+            for (i=1; ic_coll->coll_allgather(sbuf, msglen, MPI_BYTE,
+                                 rbuf, msglen, MPI_BYTE, comm,
+                                 comm->c_coll->coll_allgather_module);
+    return 0;
+}
+
+/*
+ * Initialize module on the communicator.
+ *
+ * On first use this also creates the process-wide MCCL context, the
+ * communicator attribute keyval and registers the progress callback. It then
+ * creates the per-communicator MCCL handle, saves the fallback collectives
+ * and attaches the module to the communicator as an attribute.
+ *
+ * Returns OMPI_SUCCESS, or OMPI_ERROR when any of these steps fails.
+ */
+static int mca_coll_mccl_module_enable(mca_coll_base_module_t *module,
+                                       struct ompi_communicator_t *comm)
+{
+    int rc;
+    mca_coll_mccl_module_t *mccl_module;
+    ompi_attribute_fn_ptr_union_t del_fn;
+    ompi_attribute_fn_ptr_union_t copy_fn;
+    mca_coll_mccl_component_t *cm =
+        &mca_coll_mccl_component;
+    if (!cm->libmccl_initialized)
+    {
+        MCCL_VERBOSE(10,"Calling mccl_init_context();");
+        mccl_config_t config = {
+            .flags = 0,
+            .world_size = ompi_comm_size(&ompi_mpi_comm_world.comm),
+            .world_rank = ompi_comm_rank(&ompi_mpi_comm_world.comm),
+            .allgather = oob_allgather_ctx,
+            .oob_coll_ctx = NULL,
+        };
+
+        rc = mccl_init_context(&config, &cm->mccl_context);
+        if (TCCL_OK != rc){
+            cm->mccl_enable = 0;
+            MCCL_ERROR("MCCL library init failed");
+            return OMPI_ERROR;
+        }
+        copy_fn.attr_communicator_copy_fn = (MPI_Comm_internal_copy_attr_function*) MPI_COMM_NULL_COPY_FN;
+        del_fn.attr_communicator_delete_fn = mccl_comm_attr_del_fn;
+        rc = ompi_attr_create_keyval(COMM_ATTR, copy_fn, del_fn, &mccl_comm_attr_keyval, NULL ,0, NULL);
+        if (OMPI_SUCCESS != rc) {
+            cm->mccl_enable = 0;
+            /* the context was created above; do not leak it on this path */
+            if (TCCL_OK != mccl_finalize(cm->mccl_context)) {
+                MCCL_VERBOSE(1,"MCCL library finalize failed");
+            }
+            MCCL_ERROR("MCCL comm keyval create failed");
+            return OMPI_ERROR;
+        }
+        opal_progress_register(mca_coll_mccl_progress);
+        cm->libmccl_initialized = true;
+    }
+
+    MCCL_VERBOSE(10,"Creating mccl_context for comm %p, comm_id %d, comm_size %d",
+                 (void*)comm,comm->c_contextid,ompi_comm_size(comm));
+
+    mccl_comm_config_t comm_config = {
+        .allgather = oob_allgather,
+        .oob_coll_ctx = (void*)comm,
+        .mccl_ctx = cm->mccl_context,
+        .is_world = (comm == &ompi_mpi_comm_world.comm ? 1 : 0),
+        .world_rank = ompi_comm_rank(&ompi_mpi_comm_world.comm),
+        .comm_size = ompi_comm_size(comm),
+        .comm_rank = ompi_comm_rank(comm),
+        .caps.tagged_colls = 0,
+    };
+
+    mccl_module = (mca_coll_mccl_module_t *)module;
+    if (TCCL_OK != mccl_comm_create(&comm_config, &mccl_module->mccl_comm)) {
+        MCCL_VERBOSE(1,"mccl_comm_create failed");
+        /* NOTE(review): the coll framework presumably drops its own reference
+           to the module when coll_module_enable fails; if so this release is
+           a double release - confirm against coll_base_comm_select.c. */
+        OBJ_RELEASE(mccl_module);
+        if (!cm->libmccl_initialized) {
+            cm->mccl_enable = 0;
+            /* mccl_finalize(); */
+            /* opal_progress_unregister(mccl_progress_fn); */
+        }
+        return OMPI_ERROR;
+    }
+
+    if (OMPI_SUCCESS != mca_coll_mccl_save_coll_handlers((mca_coll_mccl_module_t *)module)){
+        MCCL_ERROR("coll_mccl: mca_coll_mccl_save_coll_handlers failed");
+        return OMPI_ERROR;
+    }
+
+    rc = ompi_attr_set_c(COMM_ATTR, comm, &comm->c_keyhash, mccl_comm_attr_keyval, (void *)module, false);
+    if (OMPI_SUCCESS != rc) {
+        MCCL_VERBOSE(1,"mccl ompi_attr_set_c failed");
+        return OMPI_ERROR;
+    }
+
+    return OMPI_SUCCESS;
+}
+
+
+/*
+ * Invoked when there's a new communicator that has been created.
+ * Look at the communicator and decide which set of functions and
+ * priority we want to return.
+ */
+/*
+ * Returns NULL with *priority == 0 when the component is disabled, the
+ * communicator is an inter-communicator, or it has fewer than mccl_np (or
+ * fewer than two) processes. Otherwise returns a new module exposing the MCCL
+ * allreduce, barrier and bcast at the configured priority.
+ */
+mca_coll_base_module_t *
+mca_coll_mccl_comm_query(struct ompi_communicator_t *comm, int *priority)
+{
+    int comm_size;
+    mca_coll_mccl_module_t *mccl_module;
+    mca_coll_mccl_component_t *cm =
+        &mca_coll_mccl_component;
+    *priority = 0;
+
+    if (!cm->mccl_enable){
+        return NULL;
+    }
+
+    if (OMPI_COMM_IS_INTER(comm)) {
+        return NULL;
+    }
+    comm_size = ompi_comm_size(comm);
+    if (comm_size < cm->mccl_np || comm_size < 2){
+        return NULL;
+    }
+
+    mccl_module = OBJ_NEW(mca_coll_mccl_module_t);
+    if (!mccl_module){
+        /* allocation failed before the library was ever brought up:
+           stop offering the component */
+        if (!cm->libmccl_initialized) {
+            cm->mccl_enable = 0;
+            /* mccl_finalize(); */
+            /* opal_progress_unregister(mccl_progress_fn); */
+        }
+        return NULL;
+    }
+
+    mccl_module->comm = comm;
+    mccl_module->super.coll_module_enable = mca_coll_mccl_module_enable;
+    mccl_module->super.coll_allreduce = mca_coll_mccl_allreduce;
+    mccl_module->super.coll_barrier = mca_coll_mccl_barrier;
+    mccl_module->super.coll_bcast = mca_coll_mccl_bcast;
+    *priority = cm->mccl_priority;
+    return &mccl_module->super;
+}
+
+
+OBJ_CLASS_INSTANCE(mca_coll_mccl_module_t,
+                   mca_coll_base_module_t,
+                   mca_coll_mccl_module_construct,
+                   mca_coll_mccl_module_destruct);
+
+
+
diff --git a/ompi/mca/coll/mccl/coll_mccl_ops.c b/ompi/mca/coll/mccl/coll_mccl_ops.c
new file mode 100644
index 00000000000..b673dde04a7
--- /dev/null
+++ b/ompi/mca/coll/mccl/coll_mccl_ops.c
@@ -0,0 +1,90 @@
+/**
+  Copyright (c) 2020 Mellanox Technologies. All rights reserved.
+  $COPYRIGHT$
+
+  Additional copyrights may follow
+
+  $HEADER$
+ */
+
+#include "ompi_config.h"
+#include "ompi/constants.h"
+#include "coll_mccl.h"
+#include "coll_mccl_dtypes.h"
+
+/* Jump to the enclosing function's `fallback` label unless _call returns TCCL_OK. */
+#define COLL_MCCL_CHECK(_call) do { \
+        if (TCCL_OK != (_call)) { \
+            goto fallback; \
+        } \
+    } while(0)
+
+/*
+ * Spin on opal_progress() until the request stops reporting TCCL_INPROGRESS,
+ * then free it and return the status of mccl_request_free().
+ */
+static inline int coll_mccl_req_wait(mccl_request_h req) {
+    while (TCCL_INPROGRESS == mccl_test(req)) {
+        opal_progress();
+    }
+    return mccl_request_free(req);
+}
+
+/*
+ * Allreduce through MCCL. Falls back to the allreduce saved at module enable
+ * time when the datatype or op has no TCCL equivalent or any MCCL call fails.
+ *
+ * NOTE(review): when mccl_start() fails, the request created by the *_init
+ * call is not freed before falling back (same in barrier and bcast) - confirm
+ * whether mccl_request_free() may be called on a request that never started.
+ */
+int mca_coll_mccl_allreduce(const void *sbuf, void *rbuf, int count, struct ompi_datatype_t *dtype,
+                            struct ompi_op_t *op, struct ompi_communicator_t *comm,
+                            mca_coll_base_module_t *module)
+{
+    mccl_request_h req;
+    tccl_dt_t tccl_dt;
+    tccl_op_t tccl_op;
+    mca_coll_mccl_module_t *mccl_module = (mca_coll_mccl_module_t*)module;
+
+    MCCL_VERBOSE(20,"RUNNING MCCL ALLREDUCE");
+    tccl_dt = ompi_dtype_to_tccl_dtype(dtype);
+    tccl_op = ompi_op_to_tccl_op(op);
+    if (OPAL_UNLIKELY(TCCL_DT_UNSUPPORTED == tccl_dt || TCCL_OP_UNSUPPORTED == tccl_op)) {
+        MCCL_VERBOSE(20,"Ompi datatype or op is not supported: dtype = %s; calling fallback allreduce;",
+                     dtype->super.name);
+        goto fallback;
+    }
+
+    COLL_MCCL_CHECK(mccl_allreduce_init((void *)sbuf, rbuf, count, tccl_dt,
+                                        tccl_op, mccl_module->mccl_comm, &req));
+    COLL_MCCL_CHECK(mccl_start(req));
+    COLL_MCCL_CHECK(coll_mccl_req_wait(req));
+    return OMPI_SUCCESS;
+fallback:
+    MCCL_VERBOSE(20,"RUNNING FALLBACK ALLREDUCE");
+    return mccl_module->previous_allreduce(sbuf, rbuf, count, dtype, op,
+                                           comm, mccl_module->previous_allreduce_module);
+}
+
+/*
+ * Barrier through MCCL; falls back to the saved barrier when any MCCL call
+ * fails.
+ */
+int mca_coll_mccl_barrier(struct ompi_communicator_t *comm,
+                          mca_coll_base_module_t *module)
+{
+    mccl_request_h req;
+    mca_coll_mccl_module_t *mccl_module = (mca_coll_mccl_module_t*)module;
+
+    MCCL_VERBOSE(20,"RUNNING MCCL BARRIER");
+    COLL_MCCL_CHECK(mccl_barrier_init(mccl_module->mccl_comm, &req));
+    COLL_MCCL_CHECK(mccl_start(req));
+    COLL_MCCL_CHECK(coll_mccl_req_wait(req));
+    return OMPI_SUCCESS;
+fallback:
+    MCCL_VERBOSE(20,"RUNNING FALLBACK BARRIER");
+    return mccl_module->previous_barrier(comm, mccl_module->previous_barrier_module);
+}
+
+/*
+ * Broadcast through MCCL. Falls back to the bcast saved at module enable time
+ * when the datatype has no TCCL equivalent or any MCCL call fails.
+ */
+int mca_coll_mccl_bcast(void *buf, int count, struct ompi_datatype_t *dtype,
+                        int root, struct ompi_communicator_t *comm,
+                        mca_coll_base_module_t *module)
+{
+    mccl_request_h req;
+    tccl_dt_t tccl_dt;
+    mca_coll_mccl_module_t *mccl_module = (mca_coll_mccl_module_t*)module;
+
+    MCCL_VERBOSE(20,"RUNNING MCCL BCAST");
+    tccl_dt = ompi_dtype_to_tccl_dtype(dtype);
+    /* same guard as allreduce: never hand an unsupported datatype to MCCL */
+    if (OPAL_UNLIKELY(TCCL_DT_UNSUPPORTED == tccl_dt)) {
+        MCCL_VERBOSE(20,"Ompi_datatype is not supported: dtype = %s; calling fallback bcast;",
+                     dtype->super.name);
+        goto fallback;
+    }
+    COLL_MCCL_CHECK(mccl_bcast_init(buf, count, tccl_dt, root, mccl_module->mccl_comm, &req));
+    COLL_MCCL_CHECK(mccl_start(req));
+    COLL_MCCL_CHECK(coll_mccl_req_wait(req));
+    return OMPI_SUCCESS;
+fallback:
+    MCCL_VERBOSE(20,"RUNNING FALLBACK BCAST");
+    /* must be the saved bcast: a barrier would leave buf unfilled on non-roots */
+    return mccl_module->previous_bcast(buf, count, dtype, root, comm,
+                                       mccl_module->previous_bcast_module);
+}
diff --git a/ompi/mca/coll/mccl/configure.m4 b/ompi/mca/coll/mccl/configure.m4
new file mode 100644
index 00000000000..5d985e9e414
--- /dev/null
+++ b/ompi/mca/coll/mccl/configure.m4
@@ -0,0 +1,38 @@
+# -*- shell-script -*-
+#
+#
+# Copyright (c) 2011 Mellanox Technologies. All rights reserved.
+# Copyright (c) 2015 Research Organization for Information Science
+#                    and Technology (RIST). All rights reserved.
+# $COPYRIGHT$
+#
+# Additional copyrights may follow
+#
+# $HEADER$
+#
+
+
+# MCA_ompi_coll_mccl_CONFIG([action-if-can-compile],
+#                           [action-if-cant-compile])
+# ------------------------------------------------
+AC_DEFUN([MCA_ompi_coll_mccl_CONFIG],[
+    AC_CONFIG_FILES([ompi/mca/coll/mccl/Makefile])
+
+    OMPI_CHECK_MCCL([coll_mccl],
+                    [coll_mccl_happy="yes"],
+                    [coll_mccl_happy="no"])
+
+    AS_IF([test "$coll_mccl_happy" = "yes"],
+          [coll_mccl_WRAPPER_EXTRA_LDFLAGS="$coll_mccl_LDFLAGS"
+           coll_mccl_CPPFLAGS="$coll_mccl_CPPFLAGS"
+           coll_mccl_WRAPPER_EXTRA_LIBS="$coll_mccl_LIBS"
+           $1],
+          [$2])
+
+    # substitute in the things needed to build mccl
+    AC_SUBST([coll_mccl_CFLAGS])
+    AC_SUBST([coll_mccl_CPPFLAGS])
+    AC_SUBST([coll_mccl_LDFLAGS])
+    AC_SUBST([coll_mccl_LIBS])
+])dnl
+