Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IODEMO: non blocking disconnect #139

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 67 additions & 24 deletions test/apps/iodemo/io_demo.cc
Original file line number Diff line number Diff line change
Expand Up @@ -705,11 +705,10 @@ class DemoServer : public P2pDemoCommon {
}

virtual void dispatch_connection_error(UcxConnection *conn) {
LOG << "deleting connection with status "
LOG << "disconnecting connection with status "
<< ucs_status_string(conn->ucx_status());
assert(conn->is_established());
--_curr_state.active_conns;
delete conn;
conn->disconnect(new UcxDisconnectCallback(*conn));
}

virtual void dispatch_io_message(UcxConnection* conn, const void *buffer,
Expand Down Expand Up @@ -763,6 +762,37 @@ class DemoServer : public P2pDemoCommon {


class DemoClient : public P2pDemoCommon {
private:
class DisconnectCallback : public UcxCallback {
public:
DisconnectCallback(DemoClient &client, UcxConnection &conn) :
_client(client), _conn(&conn) {
}

virtual ~DisconnectCallback() {
delete _conn;
}

virtual void operator()(ucs_status_t status) {
server_info_t &server_info = _client.get_server_info(_conn);

_client._num_sent -= get_num_uncompleted(server_info);

// Remove connection pointer
_client._server_index_lookup.erase(_conn);

// Remove active servers entry
_client.active_servers_remove(server_info.active_index);

reset_server_info(server_info);
delete this;
}

private:
DemoClient &_client;
UcxConnection *_conn;
};

public:
typedef struct {
UcxConnection* conn;
Expand Down Expand Up @@ -884,6 +914,13 @@ class DemoClient : public P2pDemoCommon {
i->second;
}

server_info_t &get_server_info(const UcxConnection *conn) {
const size_t server_index = get_server_index(conn);

assert(server_index < _server_info.size());
return _server_info[server_index];
}

void commit_operation(size_t server_index) {
server_info_t& server_info = _server_info[server_index];

Expand Down Expand Up @@ -959,7 +996,7 @@ class DemoClient : public P2pDemoCommon {
return data_size;
}

void close_uncompleted_servers(const char *reason) {
void disconnect_uncompleted_servers(const char *reason) {
std::vector<size_t> server_idxs;
server_idxs.reserve(_active_servers.size());

Expand All @@ -970,7 +1007,7 @@ class DemoClient : public P2pDemoCommon {
}

while (!server_idxs.empty()) {
close_server(server_idxs.back(), reason);
disconnect_server(server_idxs.back(), reason);
server_idxs.pop_back();
}
}
Expand All @@ -997,13 +1034,14 @@ class DemoClient : public P2pDemoCommon {
}
}

long get_num_uncompleted(const server_info_t& server_info) const {
static long get_num_uncompleted(const server_info_t& server_info) {
return server_info.num_sent -
(server_info.num_completed[IO_READ] +
server_info.num_completed[IO_WRITE]);
}

long get_num_uncompleted(size_t server_index) const {
assert(server_index < _server_info.size());
return get_num_uncompleted(_server_info[server_index]);
}

Expand All @@ -1019,27 +1057,24 @@ class DemoClient : public P2pDemoCommon {
virtual void dispatch_connection_error(UcxConnection *conn) {
size_t server_index = get_server_index(conn);
if (server_index < _server_info.size()) {
close_server(server_index, ucs_status_string(conn->ucx_status()));
disconnect_server(server_index,
ucs_status_string(conn->ucx_status()));
}
}

void close_server(size_t server_index, const char *reason) {
void disconnect_server(size_t server_index, const char *reason) {
server_info_t& server_info = _server_info[server_index];

LOG << "terminate connection " << server_info.conn << " due to "
<< reason;
if (server_info.conn->is_disconnecting()) {
return;
}

// Remove connection pointer
_server_index_lookup.erase(server_info.conn);
LOG << "disconnecting connection " << server_info.conn << " due to "
<< reason;

// Destroying the connection will complete its outstanding operations
delete server_info.conn;

// Don't wait for any more completions on this connection
_num_sent -= get_num_uncompleted(server_info);

active_servers_remove(server_info.active_index);
reset_server_info(server_info);
server_info.conn->disconnect(new DisconnectCallback(*this,
*server_info.conn));
}

void wait_for_responses(long max_outstanding) {
Expand Down Expand Up @@ -1068,7 +1103,7 @@ class DemoClient : public P2pDemoCommon {
if (elapsed_time > _test_opts.client_timeout) {
LOG << "timeout waiting for " << (_num_sent - _num_completed)
<< " replies";
close_uncompleted_servers("timeout for replies");
disconnect_uncompleted_servers("timeout for replies");
timer_finished = true;
}
check_time_limit(curr_time);
Expand Down Expand Up @@ -1317,11 +1352,19 @@ class DemoClient : public P2pDemoCommon {
for (size_t server_index = 0; server_index < _server_info.size();
++server_index) {
LOG << "Disconnecting from " << server_name(server_index);
delete _server_info[server_index].conn;
_server_info[server_index].conn = NULL;
UcxConnection& conn = *_server_info[server_index].conn;
conn.disconnect(new DisconnectCallback(*this, conn));
}
_server_index_lookup.clear();
_active_servers.clear();

if (!_active_servers.empty()) {
LOG << "Waiting for " << _active_servers.size()
<< " disconnects to complete";
do {
progress();
} while (!_active_servers.empty());
}

assert(_server_index_lookup.empty());

return _status;
}
Expand Down
127 changes: 99 additions & 28 deletions test/apps/iodemo/ucx_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <sys/types.h>

#include <unistd.h>
#include <algorithm>
#include <limits>

struct ucx_request {
Expand Down Expand Up @@ -90,6 +91,21 @@ void UcxContext::UcxAcceptCallback::operator()(ucs_status_t status)
delete this;
}

UcxContext::UcxDisconnectCallback::UcxDisconnectCallback(UcxConnection &conn)
: _conn(&conn)
{
}

UcxContext::UcxDisconnectCallback::~UcxDisconnectCallback()
{
delete _conn;
}

void UcxContext::UcxDisconnectCallback::operator()(ucs_status_t status)
{
delete this;
}

UcxContext::UcxContext(size_t iomsg_size, double connect_timeout) :
_context(NULL), _worker(NULL), _listener(NULL), _iomsg_recv_request(NULL),
_iomsg_buffer(iomsg_size, '\0'), _connect_timeout(connect_timeout)
Expand Down Expand Up @@ -178,6 +194,7 @@ void UcxContext::progress()
progress_timed_out_conns();
progress_conn_requests();
progress_failed_connections();
progress_disconnected_connections();
}

void UcxContext::memory_pin_stats(memory_pin_stats_t *stats)
Expand Down Expand Up @@ -371,6 +388,19 @@ void UcxContext::progress_failed_connections()
}
}

void UcxContext::progress_disconnected_connections()
{
std::list<UcxConnection *>::iterator it = _disconnecting_conns.begin();
while (it != _disconnecting_conns.end()) {
UcxConnection *conn = *it;
if (conn->disconnect_progress()) {
it = _disconnecting_conns.erase(it);
} else {
++it;
}
}
}

UcxContext::wait_status_t
UcxContext::wait_completion(ucs_status_ptr_t status_ptr, const char *title,
double timeout)
Expand Down Expand Up @@ -446,6 +476,14 @@ void UcxContext::remove_connection_inprogress(UcxConnection *conn)
}
}

void UcxContext::move_connection_to_disconnecting(UcxConnection *conn)
{
remove_connection(conn);
assert(std::find(_disconnecting_conns.begin(), _disconnecting_conns.end(),
conn) == _disconnecting_conns.end());
_disconnecting_conns.push_back(conn);
}

void UcxContext::dispatch_connection_accepted(UcxConnection* conn)
{
}
Expand All @@ -467,14 +505,21 @@ void UcxContext::destroy_connections()
}

while (!_conns_in_progress.empty()) {
delete _conns_in_progress.begin()->second;
UcxConnection &conn = *_conns_in_progress.begin()->second;
_conns_in_progress.erase(_conns_in_progress.begin());
conn.disconnect(new UcxDisconnectCallback(conn));
}

UCX_LOG << "destroy_connections";
while (!_conns.empty()) {
// UcxConnection destructor removes itself from _conns map
delete _conns.begin()->second;
UcxConnection &conn = *_conns.begin()->second;
_conns.erase(_conns.begin());
conn.disconnect(new UcxDisconnectCallback(conn));
}

while (!_disconnecting_conns.empty()) {
ucp_worker_progress(_worker);
progress_disconnected_connections();
}
}

Expand Down Expand Up @@ -513,7 +558,8 @@ unsigned UcxConnection::_num_instances = 0;
UcxConnection::UcxConnection(UcxContext &context) :
_context(context),
_establish_cb(NULL),
_conn_id(UcxContext::get_next_conn_id()),
_disconnect_cb(NULL),
_conn_id(context.get_next_conn_id()),
_remote_conn_id(0),
_ep(NULL),
_close_request(NULL),
Expand All @@ -529,33 +575,13 @@ UcxConnection::UcxConnection(UcxContext &context) :

UcxConnection::~UcxConnection()
{
UCX_CONN_LOG << "destroying, ep is " << _ep;
print_addresses();

_context.remove_connection(this);
cancel_all();

// if _ep is NULL, connection was closed and removed by error handler
if (_ep != NULL) {
ep_close(UCP_EP_CLOSE_MODE_FORCE);
}

if (_close_request) {
_context.wait_completion(_close_request, "ep close");
}

/* establish cb must be destroyed earlier since it accesses
* the connection */
assert(_establish_cb == NULL);

// wait until all requests are completed
if (!ucs_list_is_empty(&_all_requests)) {
UCX_CONN_LOG << "waiting for " << ucs_list_length(&_all_requests) <<
" uncompleted requests";
}
while (!ucs_list_is_empty(&_all_requests)) {
ucp_worker_progress(_context.worker());
}
assert(_disconnect_cb == NULL);
assert(_ep == NULL);
assert(ucs_list_is_empty(&_all_requests));
assert(!UCS_PTR_IS_PTR(_close_request));

UCX_CONN_LOG << "released";
--_num_instances;
Expand Down Expand Up @@ -599,6 +625,48 @@ void UcxConnection::accept(ucp_conn_request_h conn_req, UcxCallback *callback)
connect_common(ep_params, callback);
}

void UcxConnection::disconnect(UcxCallback *callback)
{
/* establish cb must be destroyed earlier since it accesses
* the connection */
assert(_establish_cb == NULL);
assert(_disconnect_cb == NULL);
assert(_ep != NULL);

UCX_CONN_LOG << "destroying, ep is " << _ep;

_disconnect_cb = callback;
if (ucs_list_is_empty(&_all_requests)) {
ep_close(UCP_EP_CLOSE_MODE_FORCE);
_context.move_connection_to_disconnecting(this);
} else {
cancel_all();
ep_close(UCP_EP_CLOSE_MODE_FORCE);
}
}

bool UcxConnection::disconnect_progress()
{
assert(_ep == NULL);
assert(_disconnect_cb != NULL);

if (UCS_PTR_IS_PTR(_close_request)) {
if (ucp_request_check_status(_close_request) == UCS_INPROGRESS) {
return false;
} else {
ucp_request_free(_close_request);
_close_request = NULL;
}
}

assert(ucs_list_is_empty(&_all_requests));
UcxCallback *cb = _disconnect_cb;
_disconnect_cb = NULL;
// invoke last since it can delete this object
(*cb)(UCS_OK);
return true;
}

bool UcxConnection::send_io_message(const void *buffer, size_t length,
UcxCallback* callback)
{
Expand Down Expand Up @@ -839,6 +907,9 @@ void UcxConnection::request_completed(ucx_request *r)
{
assert(r->conn == this);
ucs_list_del(&r->pos);
if (ucs_list_is_empty(&_all_requests) && (_disconnect_cb != NULL)) {
_context.move_connection_to_disconnecting(this);
}
}

void UcxConnection::handle_connection_error(ucs_status_t status)
Expand Down
Loading