diff --git a/include/ofi_util.h b/include/ofi_util.h index e2fc1ad94d2..b173e008a2a 100644 --- a/include/ofi_util.h +++ b/include/ofi_util.h @@ -770,6 +770,8 @@ struct util_av { ofi_mutex_t ep_list_lock; }; +#define OFI_AV_DYN_ADDRLEN (1 << 0) + struct util_av_attr { /* Must be a multiple of 8 bytes */ size_t addrlen; @@ -803,8 +805,6 @@ void ofi_av_write_event(struct util_av *av, uint64_t data, int ofi_ip_av_create(struct fid_domain *domain_fid, struct fi_av_attr *attr, struct fid_av **av, void *context); -int ofi_ip_av_create_flags(struct fid_domain *domain_fid, struct fi_av_attr *attr, - struct fid_av **av, void *context, int flags); void *ofi_av_get_addr(struct util_av *av, fi_addr_t fi_addr); #define ofi_ip_av_get_addr ofi_av_get_addr diff --git a/prov/util/src/util_av.c b/prov/util/src/util_av.c index e7071071c76..81334c7a39a 100644 --- a/prov/util/src/util_av.c +++ b/prov/util/src/util_av.c @@ -438,7 +438,7 @@ size_t ofi_av_size(struct util_av *av) static int util_verify_av_util_attr(struct util_domain *domain, const struct util_av_attr *util_attr) { - if (util_attr->flags) { + if (util_attr->flags & ~(OFI_AV_DYN_ADDRLEN)) { FI_WARN(domain->prov, FI_LOG_AV, "invalid internal flags\n"); return -FI_EINVAL; } @@ -630,9 +630,19 @@ int ofi_ip_av_insertv(struct util_av *av, const void *addr, size_t addrlen, int *sync_err = NULL; size_t i; - assert(av->addrlen >= addrlen); - if (av->addrlen > addrlen) + if (!count) + goto done; + + if (addrlen > av->addrlen) { + FI_WARN(av->prov, FI_LOG_AV, "Address too large for AV\n"); + return -FI_EINVAL; + } + + if (!(av->flags & OFI_AV_DYN_ADDRLEN)) { av->addrlen = addrlen; + av->flags &= ~OFI_AV_DYN_ADDRLEN; + } + assert(av->addrlen == addrlen); FI_DBG(av->prov, FI_LOG_AV, "inserting %zu addresses\n", count); if (flags & FI_SYNC_ERR) { @@ -651,6 +661,7 @@ int ofi_ip_av_insertv(struct util_av *av, const void *addr, size_t addrlen, sync_err[i] = -ret; } +done: FI_DBG(av->prov, FI_LOG_AV, "%d addresses successful\n", success_cnt); if (av->eq) { ofi_av_write_event(av, success_cnt, 0, context); @@ -964,22 +975,24 @@ static struct fi_ops ip_av_fi_ops = { .ops_open = fi_no_ops_open, }; -int ofi_ip_av_create_flags(struct fid_domain *domain_fid, struct fi_av_attr *attr, - struct fid_av **av, void *context, int flags) +int ofi_ip_av_create(struct fid_domain *domain_fid, struct fi_av_attr *attr, + struct fid_av **av, void *context) { struct util_domain *domain; - struct util_av_attr util_attr; + struct util_av_attr util_attr = { 0 }; struct util_av *util_av; int ret; domain = container_of(domain_fid, struct util_domain, domain_fid); - if (domain->addr_format == FI_SOCKADDR_IN) + + if (domain->addr_format == FI_SOCKADDR_IN) { util_attr.addrlen = sizeof(struct sockaddr_in); - else + } else if (domain->addr_format == FI_SOCKADDR_IN6) { util_attr.addrlen = sizeof(struct sockaddr_in6); - - util_attr.flags = flags; - util_attr.context_len = 0; + } else { + util_attr.addrlen = sizeof(struct sockaddr_in6); + util_attr.flags = OFI_AV_DYN_ADDRLEN; + } if (attr->type == FI_AV_UNSPEC) attr->type = FI_AV_MAP; @@ -999,9 +1012,3 @@ int ofi_ip_av_create_flags(struct fid_domain *domain_fid, struct fi_av_attr *att (*av)->ops = &ip_av_ops; return 0; } - -int ofi_ip_av_create(struct fid_domain *domain_fid, struct fi_av_attr *attr, - struct fid_av **av, void *context) -{ - return ofi_ip_av_create_flags(domain_fid, attr, av, context, 0); -}