diff --git a/src/ucp/wireup/select.c b/src/ucp/wireup/select.c index ee9e20f2d97..a661f496e4e 100644 --- a/src/ucp/wireup/select.c +++ b/src/ucp/wireup/select.c @@ -1231,9 +1231,7 @@ ucp_wireup_iface_avail_bandwidth(const ucp_worker_iface_t *wiface, double eps = 1e-3; double local_bw, remote_bw; - local_bw = ucp_wireup_iface_bw_distance(wiface) * - ucp_tl_iface_bandwidth_ratio(context, local_dev_count[dev_index], - wiface->attr.dev_num_paths); + local_bw = ucp_wireup_iface_bw_distance(wiface); if (remote_addr->iface_attr.addr_version == UCP_OBJECT_VERSION_V2) { /* FP8 is a lossy compression method, so in order to create a symmetric @@ -1241,6 +1239,13 @@ ucp_wireup_iface_avail_bandwidth(const ucp_worker_iface_t *wiface, local_bw = ucp_wireup_fp8_pack_unpack_bw(local_bw); } + /* Apply dev num paths ratio after fp8 pack/unpack to make sure it is not + * neglected because of fp8 inaccuracy + */ + local_bw *= ucp_tl_iface_bandwidth_ratio( + context, local_dev_count[dev_index], + wiface->attr.dev_num_paths); + remote_bw = remote_addr->iface_attr.bandwidth * ucp_tl_iface_bandwidth_ratio( context, remote_dev_count[remote_addr->dev_index], @@ -1412,15 +1417,23 @@ ucp_wireup_is_md_map_count_valid(ucp_context_h context, ucp_md_map_t md_map) static double ucp_wireup_get_lane_bw(ucp_worker_h worker, const ucp_wireup_select_info_t *sinfo, - const ucp_address_entry_t *address) + const ucp_address_entry_t *address_list) { ucp_context_h context = worker->context; const uct_iface_attr_t *iface_attr; + const ucp_address_entry_t *address; double bw_local, bw_remote; iface_attr = ucp_worker_iface_get_attr(worker, sinfo->rsc_index); bw_local = ucp_tl_iface_bandwidth(context, &iface_attr->bandwidth); - bw_remote = address[sinfo->addr_index].iface_attr.bandwidth; + address = &address_list[sinfo->addr_index]; + bw_remote = address->iface_attr.bandwidth; + + if (address->iface_attr.addr_version == UCP_OBJECT_VERSION_V2) { + /* FP8 is a lossy compression method, so in order to create a symmetric + * calculation we pack/unpack the local bandwidth as well */ + bw_local = ucp_wireup_fp8_pack_unpack_bw(bw_local); + } return ucs_min(bw_local, bw_remote); }