Skip to content

Commit

Permalink
rdma_van: fix multi-thread races and add logging (#29)
Browse files: browse the repository at this point in the history
Co-authored-by: tanguofu <tanguofu>
  • Loading branch information
tanguofu authored Mar 7, 2020
1 parent 7d2ed94 commit 2f40b17
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 21 deletions.
3 changes: 2 additions & 1 deletion src/rdma_transport.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,8 @@ struct Endpoint {
CHECK(buf);
struct ibv_mr *mr =
ibv_reg_mr(pd, buf, kMempoolChunkSize, IBV_ACCESS_LOCAL_WRITE);
CHECK(mr);
CHECK(mr)<< "ibv_reg_mr failed: " << strerror(errno)
<< ", i=" << i <<", kMempoolChunkSize="<< kMempoolChunkSize;

rx_ctx[i].type = kReceiveContext;
rx_ctx[i].buffer = mr;
Expand Down
54 changes: 34 additions & 20 deletions src/rdma_van.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,11 @@ class RDMAVan : public Van {

PS_VLOG(1) << "Clearing endpoints.";
incoming_.clear();
endpoints_.clear();
{
std::lock_guard<std::mutex> lk(endpoints_mu_);
endpoints_.clear();
}


PS_VLOG(1) << "Destroying cq and pd.";
CHECK(!ibv_destroy_cq(cq_)) << "Failed to destroy CQ";
Expand Down Expand Up @@ -139,6 +143,7 @@ class RDMAVan : public Van {
}

if (node.id != Node::kEmpty) {
endpoints_mu_.lock();
auto it = endpoints_.find(node.id);

// if there is an endpoint with pending connection
Expand All @@ -149,6 +154,7 @@ class RDMAVan : public Van {
Endpoint *endpoint;
endpoints_[node.id] = std::make_unique<Endpoint>();
endpoint = endpoints_[node.id].get();
endpoints_mu_.unlock();

endpoint->SetNodeID(node.id);

Expand Down Expand Up @@ -197,18 +203,19 @@ class RDMAVan : public Van {
if (endpoint->status == Endpoint::CONNECTED) break;
std::this_thread::sleep_for(std::chrono::milliseconds(500));
}

bool is_local_node = disable_ipc_ ? false :
(node.hostname == my_node_.hostname ? true : false);
{
std::lock_guard<std::mutex> lk(local_mu_);
is_local_[node.id] = is_local_node;
}

LOG(INFO) << "Connect to Node " << node.id
<< " with Transport=" << (is_local_node ? "IPC" : "RDMA");

local_mu_.lock();
if (disable_ipc_) {
is_local_[node.id] = false;
} else {
is_local_[node.id] = (node.hostname == my_node_.hostname) ? true : false;
}
LOG(INFO) << "Connect to Node " << node.id
<< " with Transport=" << (is_local_[node.id]?"IPC" : "RDMA");
local_mu_.unlock();

std::shared_ptr<Transport> t = is_local_[node.id] ?
std::shared_ptr<Transport> t = is_local_node ?
std::make_shared<IPCTransport>(endpoint, mem_allocator_.get()) :
std::make_shared<RDMATransport>(endpoint, mem_allocator_.get());
endpoint->SetTransport(t);
Expand All @@ -220,8 +227,11 @@ class RDMAVan : public Van {
int SendMsg(Message &msg) override {
int remote_id = msg.meta.recver;
CHECK_NE(remote_id, Meta::kEmpty);

endpoints_mu_.lock();
CHECK_NE(endpoints_.find(remote_id), endpoints_.end());
Endpoint *endpoint = endpoints_[remote_id].get();
endpoints_mu_.unlock();

int meta_len = GetPackMetaLen(msg.meta);
size_t data_len = msg.meta.data_size;
Expand Down Expand Up @@ -753,9 +763,12 @@ class RDMAVan : public Van {
void OnRejected(struct rdma_cm_event *event) {
struct rdma_cm_id *id = event->id;
Endpoint *endpoint = reinterpret_cast<Endpoint *>(id->context);


endpoints_mu_.lock();
auto it = endpoints_.find(endpoint->node_id);
CHECK(it != endpoints_.end()) << "Connection not ready.";
endpoints_mu_.unlock();

CHECK_EQ(endpoint->status, Endpoint::CONNECTING);
CHECK_EQ(endpoint->cm_id, id);

Expand Down Expand Up @@ -792,17 +805,17 @@ class RDMAVan : public Van {

endpoint->Init(cq_, pd_);

local_mu_.lock();
if (disable_ipc_) {
is_local_[remote_ctx->node] = false;
} else {
is_local_[remote_ctx->node] = (std::string(remote_ctx->hostname) == my_node_.hostname) ? true : false;

bool is_local_node = disable_ipc_ ? false :
(std::string(remote_ctx->hostname) == my_node_.hostname ? true : false);
{
std::lock_guard<std::mutex> lk(local_mu_);
is_local_[remote_ctx->node] = is_local_node;
}
LOG(INFO) << "OnConnect to Node " << remote_ctx->node
<< " with Transport=" << (is_local_[remote_ctx->node]?"IPC" : "RDMA");
local_mu_.unlock();
<< " with Transport=" << (is_local_node ? "IPC" : "RDMA");

std::shared_ptr<Transport> t = is_local_[remote_ctx->node] ?
std::shared_ptr<Transport> t = is_local_node ?
std::make_shared<IPCTransport>(endpoint, mem_allocator_.get()) :
std::make_shared<RDMATransport>(endpoint, mem_allocator_.get());
endpoint->SetTransport(t);
Expand Down Expand Up @@ -898,6 +911,7 @@ class RDMAVan : public Van {
struct rdma_cm_id *listener_ = nullptr;
std::atomic<bool> should_stop_;

std::mutex endpoints_mu_;
std::unordered_map<int, std::unique_ptr<Endpoint>> endpoints_;
std::unordered_set<std::unique_ptr<Endpoint>> incoming_;

Expand Down

0 comments on commit 2f40b17

Please sign in to comment.