diff --git a/prov/shm/src/smr_av.c b/prov/shm/src/smr_av.c index 355d3bcad64..aa93877841a 100644 --- a/prov/shm/src/smr_av.c +++ b/prov/shm/src/smr_av.c @@ -67,10 +67,13 @@ static int smr_map_init(const struct fi_provider *prov, struct smr_map *map, static void smr_map_cleanup(struct smr_map *map) { - int64_t i; + int ret; - for (i = 0; i < SMR_MAX_PEERS; i++) - smr_map_del(map, i); + ret = ofi_rbmap_foreach(&map->rbmap, map->rbmap.root, smr_map_unmap, + NULL); + if (ret) + FI_WARN(&smr_prov, FI_LOG_AV, + "Failed to remove all entries from the map\n"); ofi_rbmap_cleanup(&map->rbmap); } @@ -148,8 +151,13 @@ static int smr_av_insert(struct fid_av *av_fid, const void *addr, size_t count, if (ret) { if (fi_addr) fi_addr[i] = util_addr; - if (shm_id >= 0) - smr_map_del(&smr_av->smr_map, shm_id); + if (shm_id >= 0) { + ret = smr_map_del(&smr_av->smr_map, shm_id); + if (ret) + FI_WARN(&smr_prov, FI_LOG_AV, + "Failed to remove shm_id %ld\n", + shm_id); + } continue; } @@ -207,11 +215,15 @@ static int smr_av_remove(struct fid_av *av_fid, fi_addr_t *fi_addr, size_t count break; } - smr_map_del(&smr_av->smr_map, id); + ret = smr_map_del(&smr_av->smr_map, id); + if (ret) { + FI_WARN(&smr_prov, FI_LOG_AV, + "Failed to remove shm_id %ld\n", id); + break; + } dlist_foreach(&util_av->ep_list, av_entry) { util_ep = container_of(av_entry, struct util_ep, av_entry); smr_ep = container_of(util_ep, struct smr_ep, util_ep); - smr_unmap_from_endpoint(smr_ep->region, id); if (smr_av->smr_map.num_peers > 0) smr_ep->region->max_sar_buf_per_peer = SMR_MAX_PEERS / diff --git a/prov/shm/src/smr_ep.c b/prov/shm/src/smr_ep.c index 8803495e382..41eccca8608 100644 --- a/prov/shm/src/smr_ep.c +++ b/prov/shm/src/smr_ep.c @@ -223,7 +223,9 @@ int64_t smr_verify_peer(struct smr_ep *ep, fi_addr_t fi_addr) return id; if (!ep->region->map->peers[id].region) { + ofi_spin_lock(&ep->region->map->lock); ret = smr_map_to_region(&smr_prov, ep->region->map, id); + ofi_spin_unlock(&ep->region->map->lock); if (ret) return -1; } diff --git a/prov/shm/src/smr_progress.c b/prov/shm/src/smr_progress.c index 141826b9bba..14e9a42136d 100644 --- a/prov/shm/src/smr_progress.c +++ b/prov/shm/src/smr_progress.c @@ -878,7 +878,9 @@ static void smr_progress_connreq(struct smr_ep *ep, struct smr_cmd *cmd) peer_smr = smr_peer_region(ep->region, idx); if (!peer_smr) { + ofi_spin_lock(&ep->region->map->lock); ret = smr_map_to_region(&smr_prov, ep->region->map, idx); + ofi_spin_unlock(&ep->region->map->lock); if (ret) { FI_WARN(&smr_prov, FI_LOG_EP_CTRL, "Could not map peer region\n"); @@ -891,14 +893,11 @@ static void smr_progress_connreq(struct smr_ep *ep, struct smr_cmd *cmd) if (peer_smr->pid != (int) cmd->msg.hdr.data) { /* TODO track and update/complete in error any transfers * to or from old mapping - * - * TODO create smr_unmap_region - * this needs to close peer_smr->map->peers[idx].pid_fd - * This case will also return an unmapped region because the idx - * is valid but the region was unmapped */ - munmap(peer_smr, peer_smr->total_size); + ofi_spin_lock(&ep->region->map->lock); + smr_unmap_region(&smr_prov, ep->region->map, idx); smr_map_to_region(&smr_prov, ep->region->map, idx); + ofi_spin_unlock(&ep->region->map->lock); peer_smr = smr_peer_region(ep->region, idx); } diff --git a/prov/shm/src/smr_util.c b/prov/shm/src/smr_util.c index 2924ddaa6f2..3f1ab6a7f6d 100644 --- a/prov/shm/src/smr_util.c +++ b/prov/shm/src/smr_util.c @@ -367,16 +367,15 @@ int smr_map_to_region(const struct fi_provider *prov, struct smr_map *map, } pthread_mutex_unlock(&ep_list_lock); - ofi_spin_lock(&map->lock); if (peer_buf->region) - goto unlock; + return FI_SUCCESS; + ofi_spin_held(&map->lock); fd = shm_open(name, O_RDWR, S_IRUSR | S_IWUSR); if (fd < 0) { - ret = -errno; FI_WARN_ONCE(prov, FI_LOG_AV, "shm_open error: name %s errno %d\n", name, errno); - goto unlock; + return -errno; } memset(tmp, 0, sizeof(tmp)); @@ -437,8 +436,6 @@ int smr_map_to_region(const struct fi_provider *prov, struct smr_map *map, out: close(fd); -unlock: - ofi_spin_unlock(&map->lock); return ret; } @@ -448,6 +445,7 @@ void smr_map_to_endpoint(struct smr_region *region, int64_t id) struct smr_region *peer_smr; struct smr_peer_data *local_peers; + ofi_spin_held(&map->lock); peer_smr = smr_peer_region(region, id); if (region->map->peers[id].peer.id < 0 || !peer_smr) return; @@ -479,24 +477,63 @@ void smr_map_to_endpoint(struct smr_region *region, int64_t id) return; } +void smr_unmap_region(const struct fi_provider *prov, struct smr_map *map, + int64_t peer_id) +{ + struct smr_region *peer_region; + struct smr_peer *peer; + struct util_ep *util_ep; + struct smr_ep *smr_ep; + struct smr_av *av; + int ret = 0; + + ofi_spin_held(&map->lock); + peer_region = map->peers[peer_id].region; + if (!peer_region) + return; + + peer = &map->peers[peer_id]; + av = container_of(map, struct smr_av, smr_map); + dlist_foreach_container(&av->util_av.ep_list, struct util_ep, util_ep, + av_entry) { + smr_ep = container_of(util_ep, struct smr_ep, util_ep); + smr_unmap_from_endpoint(smr_ep->region, peer_id); + } + + if (map->flags & SMR_FLAG_HMEM_ENABLED) { + ret = ofi_hmem_host_unregister(peer_region); + if (ret) + FI_WARN(prov, FI_LOG_EP_CTRL, + "unable to unregister shm with iface\n"); + + if (peer->pid_fd != -1) { + close(peer->pid_fd); + peer->pid_fd = -1; + } + } + + munmap(peer_region, peer_region->total_size); + peer->region = NULL; +} + void smr_unmap_from_endpoint(struct smr_region *region, int64_t id) { struct smr_region *peer_smr; struct smr_peer_data *local_peers, *peer_peers; int64_t peer_id; - local_peers = smr_peer_data(region); if (region->map->peers[id].peer.id < 0) return; peer_smr = smr_peer_region(region, id); - peer_id = smr_peer_data(region)[id].addr.id; - + assert(peer_smr); peer_peers = smr_peer_data(peer_smr); + peer_id = smr_peer_data(region)[id].addr.id; peer_peers[peer_id].addr.id = -1; peer_peers[peer_id].name_sent = 0; + local_peers = smr_peer_data(region); ofi_xpmem_release(&local_peers[peer_id].xpmem); } @@ -544,40 +581,36 @@ int smr_map_add(const struct fi_provider *prov, struct smr_map *map, return FI_SUCCESS; } -void smr_map_del(struct smr_map *map, int64_t id) +int smr_map_unmap(struct ofi_rbmap *rbmap, struct ofi_rbnode *node, + void *context) { - struct dlist_entry *entry; + struct smr_map *map = container_of(rbmap, struct smr_map, rbmap); + int64_t id = (uintptr_t) node->data; assert(id >= 0 && id < SMR_MAX_PEERS); - - pthread_mutex_lock(&ep_list_lock); - entry = dlist_find_first_match(&ep_name_list, smr_match_name, - smr_no_prefix(map->peers[id].peer.name)); - pthread_mutex_unlock(&ep_list_lock); - - ofi_spin_lock(&map->lock); - (void) ofi_rbmap_find_delete(&map->rbmap, - (void *) map->peers[id].peer.name); - + smr_unmap_region(&smr_prov, map, id); map->peers[id].fiaddr = FI_ADDR_NOTAVAIL; map->peers[id].peer.id = -1; map->num_peers--; - if (!map->peers[id].region) - goto unlock; + return FI_SUCCESS; +} - if (!entry) { - if (map->flags & SMR_FLAG_HMEM_ENABLED) { - if (map->peers[id].pid_fd != -1) - close(map->peers[id].pid_fd); +int smr_map_del(struct smr_map *map, int64_t shm_id) +{ + struct ofi_rbnode *node; - (void) ofi_hmem_host_unregister(map->peers[id].region); - } - munmap(map->peers[id].region, map->peers[id].region->total_size); - map->peers[id].region = NULL; + ofi_spin_lock(&map->lock); + node = ofi_rbmap_find(&map->rbmap, map->peers[shm_id].peer.name); + if (!node) { + ofi_spin_unlock(&map->lock); + return -FI_ENOENT; } -unlock: + smr_map_unmap(&map->rbmap, node, NULL); + ofi_rbmap_delete(&map->rbmap, node); ofi_spin_unlock(&map->lock); + + return FI_SUCCESS; } struct smr_region *smr_map_get(struct smr_map *map, int64_t id) diff --git a/prov/shm/src/smr_util.h b/prov/shm/src/smr_util.h index c5bf8124873..6da51f47d0d 100644 --- a/prov/shm/src/smr_util.h +++ b/prov/shm/src/smr_util.h @@ -356,11 +356,15 @@ void smr_cleanup(void); int smr_map_to_region(const struct fi_provider *prov, struct smr_map *map, int64_t id); void smr_map_to_endpoint(struct smr_region *region, int64_t id); +void smr_unmap_region(const struct fi_provider *prov, struct smr_map *map, + int64_t id); void smr_unmap_from_endpoint(struct smr_region *region, int64_t id); void smr_exchange_all_peers(struct smr_region *region); int smr_map_add(const struct fi_provider *prov, struct smr_map *map, const char *name, int64_t *id); -void smr_map_del(struct smr_map *map, int64_t id); +int smr_map_unmap(struct ofi_rbmap *rbmap, struct ofi_rbnode *node, + void *context); +int smr_map_del(struct smr_map *map, int64_t shm_id); struct smr_region *smr_map_get(struct smr_map *map, int64_t id);