diff --git a/src/rdma_transport.h b/src/rdma_transport.h index 92e939e9..522b7a15 100644 --- a/src/rdma_transport.h +++ b/src/rdma_transport.h @@ -127,7 +127,8 @@ struct Endpoint { CHECK(buf); struct ibv_mr *mr = ibv_reg_mr(pd, buf, kMempoolChunkSize, IBV_ACCESS_LOCAL_WRITE); - CHECK(mr); + CHECK(mr)<< "ibv_reg_mr failed: " << strerror(errno) + << ", i=" << i <<", kMempoolChunkSize="<< kMempoolChunkSize; rx_ctx[i].type = kReceiveContext; rx_ctx[i].buffer = mr; diff --git a/src/rdma_van.h b/src/rdma_van.h index ccae4f42..d6952ec1 100755 --- a/src/rdma_van.h +++ b/src/rdma_van.h @@ -76,7 +76,11 @@ class RDMAVan : public Van { PS_VLOG(1) << "Clearing endpoints."; incoming_.clear(); - endpoints_.clear(); + { + std::lock_guard lk(endpoints_mu_); + endpoints_.clear(); + } + PS_VLOG(1) << "Destroying cq and pd."; CHECK(!ibv_destroy_cq(cq_)) << "Failed to destroy CQ"; @@ -139,6 +143,7 @@ class RDMAVan : public Van { } if (node.id != Node::kEmpty) { + endpoints_mu_.lock(); auto it = endpoints_.find(node.id); // if there is an endpoint with pending connection @@ -149,6 +154,7 @@ class RDMAVan : public Van { Endpoint *endpoint; endpoints_[node.id] = std::make_unique(); endpoint = endpoints_[node.id].get(); + endpoints_mu_.unlock(); endpoint->SetNodeID(node.id); @@ -197,18 +203,19 @@ class RDMAVan : public Van { if (endpoint->status == Endpoint::CONNECTED) break; std::this_thread::sleep_for(std::chrono::milliseconds(500)); } + + bool is_local_node = disable_ipc_ ? false : + (node.hostname == my_node_.hostname ? true : false); + { + std::lock_guard lk(local_mu_); + is_local_[node.id] = is_local_node; + } + + LOG(INFO) << "Connect to Node " << node.id + << " with Transport=" << (is_local_node ? "IPC" : "RDMA"); - local_mu_.lock(); - if (disable_ipc_) { - is_local_[node.id] = false; - } else { - is_local_[node.id] = (node.hostname == my_node_.hostname) ? true : false; - } - LOG(INFO) << "Connect to Node " << node.id - << " with Transport=" << (is_local_[node.id]?"IPC" : "RDMA"); - local_mu_.unlock(); - std::shared_ptr t = is_local_[node.id] ? + std::shared_ptr t = is_local_node ? std::make_shared(endpoint, mem_allocator_.get()) : std::make_shared(endpoint, mem_allocator_.get()); endpoint->SetTransport(t); @@ -220,8 +227,11 @@ class RDMAVan : public Van { int SendMsg(Message &msg) override { int remote_id = msg.meta.recver; CHECK_NE(remote_id, Meta::kEmpty); + + endpoints_mu_.lock(); CHECK_NE(endpoints_.find(remote_id), endpoints_.end()); Endpoint *endpoint = endpoints_[remote_id].get(); + endpoints_mu_.unlock(); int meta_len = GetPackMetaLen(msg.meta); size_t data_len = msg.meta.data_size; @@ -753,9 +763,12 @@ class RDMAVan : public Van { void OnRejected(struct rdma_cm_event *event) { struct rdma_cm_id *id = event->id; Endpoint *endpoint = reinterpret_cast(id->context); - + + endpoints_mu_.lock(); auto it = endpoints_.find(endpoint->node_id); CHECK(it != endpoints_.end()) << "Connection not ready."; + endpoints_mu_.unlock(); + CHECK_EQ(endpoint->status, Endpoint::CONNECTING); CHECK_EQ(endpoint->cm_id, id); @@ -792,17 +805,17 @@ class RDMAVan : public Van { endpoint->Init(cq_, pd_); - local_mu_.lock(); - if (disable_ipc_) { - is_local_[remote_ctx->node] = false; - } else { - is_local_[remote_ctx->node] = (std::string(remote_ctx->hostname) == my_node_.hostname) ? true : false; + + bool is_local_node = disable_ipc_ ? false : + (std::string(remote_ctx->hostname) == my_node_.hostname ? true : false); + { + std::lock_guard lk(local_mu_); + is_local_[remote_ctx->node] = is_local_node; } LOG(INFO) << "OnConnect to Node " << remote_ctx->node - << " with Transport=" << (is_local_[remote_ctx->node]?"IPC" : "RDMA"); - local_mu_.unlock(); + << " with Transport=" << (is_local_node ? "IPC" : "RDMA"); - std::shared_ptr t = is_local_[remote_ctx->node] ? + std::shared_ptr t = is_local_node ? std::make_shared(endpoint, mem_allocator_.get()) : std::make_shared(endpoint, mem_allocator_.get()); endpoint->SetTransport(t); @@ -898,6 +911,7 @@ class RDMAVan : public Van { struct rdma_cm_id *listener_ = nullptr; std::atomic should_stop_; + std::mutex endpoints_mu_; std::unordered_map> endpoints_; std::unordered_set> incoming_;