Skip to content

Commit

Permalink
Merge pull request openucx#9 from shizhibao/huawei
Browse files Browse the repository at this point in the history
UT: fix ucg datatype convert function for ut
  • Loading branch information
nsosnsos authored Nov 21, 2020
2 parents dccae87 + 72849cb commit dd0f7c0
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 4 deletions.
1 change: 1 addition & 0 deletions test/gtest/ucg/test_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ ucg_plan_t *ucg_op_test::create_plan(unsigned phs_cnt, ucg_collective_params_t *
builtin_plan->resend = NULL;
builtin_plan->slots = NULL;
builtin_plan->phs_cnt = phs_cnt;
builtin_plan->convert_f = mca_coll_ucg_datatype_convert_for_ut;

ucs_mpool_ops_t ops = {
ucs_mpool_chunk_malloc,
Expand Down
19 changes: 15 additions & 4 deletions test/gtest/ucg/ucg_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,16 @@ ucg_collective_params_t *ucg_test::create_bcast_params() const {
0, send_buf, count, recv_buf, sizeof(int), NULL, NULL);
}

int mca_coll_ucg_datatype_convert_for_ut(void *mpi_dt, ucp_datatype_t *ucp_dt)
{
if (mpi_dt != NULL) {
ucs_info("mca_coll_ucg_datatype_convert_for_ut");
}

*ucp_dt = UCP_DATATYPE_CONTIG;
return 0;
}

ucg_builtin_config_t *ucg_resource_factory::create_config(
unsigned bcast_alg, unsigned allreduce_alg, unsigned barrier_alg)
{
Expand Down Expand Up @@ -174,6 +184,7 @@ ucg_group_params_t *ucg_resource_factory::create_group_params(
args->release_address_f = NULL;
args->cb_group_obj = NULL;
args->op_is_commute_f = ompi_op_is_commute;
args->mpi_dt_convert = mca_coll_ucg_datatype_convert_for_ut;
args->distance = (ucg_group_member_distance *) malloc(args->member_count * sizeof(*args->distance));
args->node_index = (uint16_t *) malloc(args->member_count * sizeof(*args->node_index));

Expand Down Expand Up @@ -217,10 +228,10 @@ ucg_collective_params_t *ucg_resource_factory::create_collective_params(
params->send.op_ext = op_ext;

params->recv.buf = recv_buffer;
params->send.count = count;
params->send.dt_len = dt_len;
params->send.dt_ext = dt_ext;
params->send.op_ext = op_ext;
params->recv.count = count;
params->recv.dt_len = dt_len;
params->recv.dt_ext = dt_ext;
params->recv.op_ext = op_ext;

return params;
}
Expand Down
2 changes: 2 additions & 0 deletions test/gtest/ucg/ucg_test.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ struct ucg_ompi_op {
bool commutative;
};

int mca_coll_ucg_datatype_convert_for_ut(void *mpi_dt, ucp_datatype_t *ucp_dt);

class ucg_resource_factory {
public:
ucg_builtin_config_t *create_config(unsigned bcast_alg, unsigned allreduce_alg, unsigned barrier_alg);
Expand Down

0 comments on commit dd0f7c0

Please sign in to comment.