Skip to content

Commit

Permalink
Revert "Fix racing condition of endpoints in rdma_van" (#28)
Browse files Browse the repository at this point in the history
This reverts commit 3d76ec8.
  • Loading branch information
ymjiang authored Mar 7, 2020
1 parent 3d76ec8 commit 7d2ed94
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 36 deletions.
3 changes: 1 addition & 2 deletions src/rdma_transport.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,7 @@ struct Endpoint {
CHECK(buf);
struct ibv_mr *mr =
ibv_reg_mr(pd, buf, kMempoolChunkSize, IBV_ACCESS_LOCAL_WRITE);
- CHECK(mr)<< "ibv_reg_mr Failed: " << strerror(errno)
- << ", i=" << i <<", kMempoolChunkSize:"<< kMempoolChunkSize;
+ CHECK(mr);

rx_ctx[i].type = kReceiveContext;
rx_ctx[i].buffer = mr;
Expand Down
54 changes: 20 additions & 34 deletions src/rdma_van.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,7 @@ class RDMAVan : public Van {

PS_VLOG(1) << "Clearing endpoints.";
incoming_.clear();
- {
- std::lock_guard<std::mutex> lk(endpoints_mu_);
- endpoints_.clear();
- }

+ endpoints_.clear();

PS_VLOG(1) << "Destroying cq and pd.";
CHECK(!ibv_destroy_cq(cq_)) << "Failed to destroy CQ";
Expand Down Expand Up @@ -143,7 +139,6 @@ class RDMAVan : public Van {
}

if (node.id != Node::kEmpty) {
- endpoints_mu_.lock();
auto it = endpoints_.find(node.id);

// if there is an endpoint with pending connection
Expand All @@ -154,7 +149,6 @@ class RDMAVan : public Van {
Endpoint *endpoint;
endpoints_[node.id] = std::make_unique<Endpoint>();
endpoint = endpoints_[node.id].get();
- endpoints_mu_.unlock();

endpoint->SetNodeID(node.id);

Expand Down Expand Up @@ -203,19 +197,18 @@ class RDMAVan : public Van {
if (endpoint->status == Endpoint::CONNECTED) break;
std::this_thread::sleep_for(std::chrono::milliseconds(500));
}

- bool is_local_node = disable_ipc_ ? false :
- (node.hostname == my_node_.hostname ? true : false);
- {
- std::lock_guard<std::mutex> lk(local_mu_);
- is_local_[node.id] = is_local_node;
- }

- LOG(INFO) << "Connect to Node " << node.id
- << " with Transport=" << (is_local_node ? "IPC" : "RDMA");

+ local_mu_.lock();
+ if (disable_ipc_) {
+ is_local_[node.id] = false;
+ } else {
+ is_local_[node.id] = (node.hostname == my_node_.hostname) ? true : false;
+ }
+ LOG(INFO) << "Connect to Node " << node.id
+ << " with Transport=" << (is_local_[node.id]?"IPC" : "RDMA");
+ local_mu_.unlock();

- std::shared_ptr<Transport> t = is_local_node ?
+ std::shared_ptr<Transport> t = is_local_[node.id] ?
std::make_shared<IPCTransport>(endpoint, mem_allocator_.get()) :
std::make_shared<RDMATransport>(endpoint, mem_allocator_.get());
endpoint->SetTransport(t);
Expand All @@ -227,11 +220,8 @@ class RDMAVan : public Van {
int SendMsg(Message &msg) override {
int remote_id = msg.meta.recver;
CHECK_NE(remote_id, Meta::kEmpty);

- endpoints_mu_.lock();
CHECK_NE(endpoints_.find(remote_id), endpoints_.end());
Endpoint *endpoint = endpoints_[remote_id].get();
- endpoints_mu_.unlock();

int meta_len = GetPackMetaLen(msg.meta);
size_t data_len = msg.meta.data_size;
Expand Down Expand Up @@ -763,12 +753,9 @@ class RDMAVan : public Van {
void OnRejected(struct rdma_cm_event *event) {
struct rdma_cm_id *id = event->id;
Endpoint *endpoint = reinterpret_cast<Endpoint *>(id->context);

- endpoints_mu_.lock();

auto it = endpoints_.find(endpoint->node_id);
CHECK(it != endpoints_.end()) << "Connection not ready.";
- endpoints_mu_.unlock();

CHECK_EQ(endpoint->status, Endpoint::CONNECTING);
CHECK_EQ(endpoint->cm_id, id);

Expand Down Expand Up @@ -805,17 +792,17 @@ class RDMAVan : public Van {

endpoint->Init(cq_, pd_);


- bool is_local_node = disable_ipc_ ? false :
- (std::string(remote_ctx->hostname) == my_node_.hostname ? true : false);
- {
- std::lock_guard<std::mutex> lk(local_mu_);
- is_local_[remote_ctx->node] = is_local_node;
+ local_mu_.lock();
+ if (disable_ipc_) {
+ is_local_[remote_ctx->node] = false;
+ } else {
+ is_local_[remote_ctx->node] = (std::string(remote_ctx->hostname) == my_node_.hostname) ? true : false;
}
LOG(INFO) << "OnConnect to Node " << remote_ctx->node
- << " with Transport=" << (is_local_node ? "IPC" : "RDMA");
+ << " with Transport=" << (is_local_[remote_ctx->node]?"IPC" : "RDMA");
+ local_mu_.unlock();

- std::shared_ptr<Transport> t = is_local_node ?
+ std::shared_ptr<Transport> t = is_local_[remote_ctx->node] ?
std::make_shared<IPCTransport>(endpoint, mem_allocator_.get()) :
std::make_shared<RDMATransport>(endpoint, mem_allocator_.get());
endpoint->SetTransport(t);
Expand Down Expand Up @@ -911,7 +898,6 @@ class RDMAVan : public Van {
struct rdma_cm_id *listener_ = nullptr;
std::atomic<bool> should_stop_;

- std::mutex endpoints_mu_;
std::unordered_map<int, std::unique_ptr<Endpoint>> endpoints_;
std::unordered_set<std::unique_ptr<Endpoint>> incoming_;

Expand Down

0 comments on commit 7d2ed94

Please sign in to comment.