diff --git a/test/apps/iodemo/io_demo.cc b/test/apps/iodemo/io_demo.cc index 6f3228ae320..d738e7401d4 100644 --- a/test/apps/iodemo/io_demo.cc +++ b/test/apps/iodemo/io_demo.cc @@ -705,11 +705,10 @@ class DemoServer : public P2pDemoCommon { } virtual void dispatch_connection_error(UcxConnection *conn) { - LOG << "deleting connection with status " + LOG << "disconnecting connection with status " << ucs_status_string(conn->ucx_status()); - assert(conn->is_established()); --_curr_state.active_conns; - delete conn; + conn->disconnect(new UcxDisconnectCallback(*conn)); } virtual void dispatch_io_message(UcxConnection* conn, const void *buffer, @@ -763,6 +762,37 @@ class DemoServer : public P2pDemoCommon { class DemoClient : public P2pDemoCommon { +private: + class DisconnectCallback : public UcxCallback { + public: + DisconnectCallback(DemoClient &client, UcxConnection &conn) : + _client(client), _conn(&conn) { + } + + virtual ~DisconnectCallback() { + delete _conn; + } + + virtual void operator()(ucs_status_t status) { + server_info_t &server_info = _client.get_server_info(_conn); + + _client._num_sent -= get_num_uncompleted(server_info); + + // Remove connection pointer + _client._server_index_lookup.erase(_conn); + + // Remove active servers entry + _client.active_servers_remove(server_info.active_index); + + reset_server_info(server_info); + delete this; + } + + private: + DemoClient &_client; + UcxConnection *_conn; + }; + public: typedef struct { UcxConnection* conn; @@ -884,6 +914,13 @@ class DemoClient : public P2pDemoCommon { i->second; } + server_info_t &get_server_info(const UcxConnection *conn) { + const size_t server_index = get_server_index(conn); + + assert(server_index < _server_info.size()); + return _server_info[server_index]; + } + void commit_operation(size_t server_index) { server_info_t& server_info = _server_info[server_index]; @@ -959,7 +996,7 @@ class DemoClient : public P2pDemoCommon { return data_size; } - void close_uncompleted_servers(const char *reason) { + void disconnect_uncompleted_servers(const char *reason) { std::vector server_idxs; server_idxs.reserve(_active_servers.size()); @@ -970,7 +1007,7 @@ class DemoClient : public P2pDemoCommon { } while (!server_idxs.empty()) { - close_server(server_idxs.back(), reason); + disconnect_server(server_idxs.back(), reason); server_idxs.pop_back(); } } @@ -997,13 +1034,14 @@ class DemoClient : public P2pDemoCommon { } } - long get_num_uncompleted(const server_info_t& server_info) const { + static long get_num_uncompleted(const server_info_t& server_info) { return server_info.num_sent - (server_info.num_completed[IO_READ] + server_info.num_completed[IO_WRITE]); } long get_num_uncompleted(size_t server_index) const { + assert(server_index < _server_info.size()); return get_num_uncompleted(_server_info[server_index]); } @@ -1019,27 +1057,24 @@ class DemoClient : public P2pDemoCommon { virtual void dispatch_connection_error(UcxConnection *conn) { size_t server_index = get_server_index(conn); if (server_index < _server_info.size()) { - close_server(server_index, ucs_status_string(conn->ucx_status())); + disconnect_server(server_index, + ucs_status_string(conn->ucx_status())); } } - void close_server(size_t server_index, const char *reason) { + void disconnect_server(size_t server_index, const char *reason) { server_info_t& server_info = _server_info[server_index]; - LOG << "terminate connection " << server_info.conn << " due to " - << reason; + if (server_info.conn->is_disconnecting()) { + return; + } - // Remove connection pointer - _server_index_lookup.erase(server_info.conn); + LOG << "disconnecting connection " << server_info.conn << " due to " + << reason; // Destroying the connection will complete its outstanding operations - delete server_info.conn; - - // Don't wait for any more completions on this connection - _num_sent -= get_num_uncompleted(server_info); - - active_servers_remove(server_info.active_index); - reset_server_info(server_info); + server_info.conn->disconnect(new DisconnectCallback(*this, + *server_info.conn)); } void wait_for_responses(long max_outstanding) { @@ -1068,7 +1103,7 @@ class DemoClient : public P2pDemoCommon { if (elapsed_time > _test_opts.client_timeout) { LOG << "timeout waiting for " << (_num_sent - _num_completed) << " replies"; - close_uncompleted_servers("timeout for replies"); + disconnect_uncompleted_servers("timeout for replies"); timer_finished = true; } check_time_limit(curr_time); @@ -1317,11 +1352,19 @@ class DemoClient : public P2pDemoCommon { for (size_t server_index = 0; server_index < _server_info.size(); ++server_index) { LOG << "Disconnecting from " << server_name(server_index); - delete _server_info[server_index].conn; - _server_info[server_index].conn = NULL; + UcxConnection& conn = *_server_info[server_index].conn; + conn.disconnect(new DisconnectCallback(*this, conn)); } - _server_index_lookup.clear(); - _active_servers.clear(); + + if (!_active_servers.empty()) { + LOG << "Waiting for " << _active_servers.size() + << " disconnects to complete"; + do { + progress(); + } while (!_active_servers.empty()); + } + + assert(_server_index_lookup.empty()); return _status; } diff --git a/test/apps/iodemo/ucx_wrapper.cc b/test/apps/iodemo/ucx_wrapper.cc index d92c63a7e26..0a49ffe719d 100644 --- a/test/apps/iodemo/ucx_wrapper.cc +++ b/test/apps/iodemo/ucx_wrapper.cc @@ -14,6 +14,7 @@ #include #include +#include #include struct ucx_request { @@ -90,6 +91,21 @@ void UcxContext::UcxAcceptCallback::operator()(ucs_status_t status) delete this; } +UcxContext::UcxDisconnectCallback::UcxDisconnectCallback(UcxConnection &conn) + : _conn(&conn) +{ +} + +UcxContext::UcxDisconnectCallback::~UcxDisconnectCallback() +{ + delete _conn; +} + +void UcxContext::UcxDisconnectCallback::operator()(ucs_status_t status) +{ + delete this; +} + UcxContext::UcxContext(size_t iomsg_size, double connect_timeout) : _context(NULL), _worker(NULL), _listener(NULL), _iomsg_recv_request(NULL), _iomsg_buffer(iomsg_size, '\0'), _connect_timeout(connect_timeout) @@ -178,6 +194,7 @@ void UcxContext::progress() progress_timed_out_conns(); progress_conn_requests(); progress_failed_connections(); + progress_disconnected_connections(); } void UcxContext::memory_pin_stats(memory_pin_stats_t *stats) @@ -371,6 +388,19 @@ void UcxContext::progress_failed_connections() } } +void UcxContext::progress_disconnected_connections() +{ + std::list::iterator it = _disconnecting_conns.begin(); + while (it != _disconnecting_conns.end()) { + UcxConnection *conn = *it; + if (conn->disconnect_progress()) { + it = _disconnecting_conns.erase(it); + } else { + ++it; + } + } +} + UcxContext::wait_status_t UcxContext::wait_completion(ucs_status_ptr_t status_ptr, const char *title, double timeout) @@ -446,6 +476,14 @@ void UcxContext::remove_connection_inprogress(UcxConnection *conn) } } +void UcxContext::move_connection_to_disconnecting(UcxConnection *conn) +{ + remove_connection(conn); + assert(std::find(_disconnecting_conns.begin(), _disconnecting_conns.end(), + conn) == _disconnecting_conns.end()); + _disconnecting_conns.push_back(conn); +} + void UcxContext::dispatch_connection_accepted(UcxConnection* conn) { } @@ -467,14 +505,21 @@ void UcxContext::destroy_connections() } while (!_conns_in_progress.empty()) { - delete _conns_in_progress.begin()->second; + UcxConnection &conn = *_conns_in_progress.begin()->second; _conns_in_progress.erase(_conns_in_progress.begin()); + conn.disconnect(new UcxDisconnectCallback(conn)); } UCX_LOG << "destroy_connections"; while (!_conns.empty()) { - // UcxConnection destructor removes itself from _conns map - delete _conns.begin()->second; + UcxConnection &conn = *_conns.begin()->second; + _conns.erase(_conns.begin()); + conn.disconnect(new UcxDisconnectCallback(conn)); + } + + while (!_disconnecting_conns.empty()) { + ucp_worker_progress(_worker); + progress_disconnected_connections(); } } @@ -513,7 +558,8 @@ unsigned UcxConnection::_num_instances = 0; UcxConnection::UcxConnection(UcxContext &context) : _context(context), _establish_cb(NULL), - _conn_id(UcxContext::get_next_conn_id()), + _disconnect_cb(NULL), + _conn_id(context.get_next_conn_id()), _remote_conn_id(0), _ep(NULL), _close_request(NULL), @@ -529,33 +575,13 @@ UcxConnection::UcxConnection(UcxContext &context) : UcxConnection::~UcxConnection() { - UCX_CONN_LOG << "destroying, ep is " << _ep; - print_addresses(); - - _context.remove_connection(this); - cancel_all(); - - // if _ep is NULL, connection was closed and removed by error handler - if (_ep != NULL) { - ep_close(UCP_EP_CLOSE_MODE_FORCE); - } - - if (_close_request) { - _context.wait_completion(_close_request, "ep close"); - } - /* establish cb must be destroyed earlier since it accesses * the connection */ assert(_establish_cb == NULL); - - // wait until all requests are completed - if (!ucs_list_is_empty(&_all_requests)) { - UCX_CONN_LOG << "waiting for " << ucs_list_length(&_all_requests) << - " uncompleted requests"; - } - while (!ucs_list_is_empty(&_all_requests)) { - ucp_worker_progress(_context.worker()); - } + assert(_disconnect_cb == NULL); + assert(_ep == NULL); + assert(ucs_list_is_empty(&_all_requests)); + assert(!UCS_PTR_IS_PTR(_close_request)); UCX_CONN_LOG << "released"; --_num_instances; @@ -599,6 +625,48 @@ void UcxConnection::accept(ucp_conn_request_h conn_req, UcxCallback *callback) connect_common(ep_params, callback); } +void UcxConnection::disconnect(UcxCallback *callback) +{ + /* establish cb must be destroyed earlier since it accesses + * the connection */ + assert(_establish_cb == NULL); + assert(_disconnect_cb == NULL); + assert(_ep != NULL); + + UCX_CONN_LOG << "destroying, ep is " << _ep; + + _disconnect_cb = callback; + if (ucs_list_is_empty(&_all_requests)) { + ep_close(UCP_EP_CLOSE_MODE_FORCE); + _context.move_connection_to_disconnecting(this); + } else { + cancel_all(); + ep_close(UCP_EP_CLOSE_MODE_FORCE); + } +} + +bool UcxConnection::disconnect_progress() +{ + assert(_ep == NULL); + assert(_disconnect_cb != NULL); + + if (UCS_PTR_IS_PTR(_close_request)) { + if (ucp_request_check_status(_close_request) == UCS_INPROGRESS) { + return false; + } else { + ucp_request_free(_close_request); + _close_request = NULL; + } + } + + assert(ucs_list_is_empty(&_all_requests)); + UcxCallback *cb = _disconnect_cb; + _disconnect_cb = NULL; + // invoke last since it can delete this object + (*cb)(UCS_OK); + return true; +} + bool UcxConnection::send_io_message(const void *buffer, size_t length, UcxCallback* callback) { @@ -839,6 +907,9 @@ void UcxConnection::request_completed(ucx_request *r) { assert(r->conn == this); ucs_list_del(&r->pos); + if (ucs_list_is_empty(&_all_requests) && (_disconnect_cb != NULL)) { + _context.move_connection_to_disconnecting(this); + } } void UcxConnection::handle_connection_error(ucs_status_t status) diff --git a/test/apps/iodemo/ucx_wrapper.h b/test/apps/iodemo/ucx_wrapper.h index 18a8d9feb95..b692e427e80 100644 --- a/test/apps/iodemo/ucx_wrapper.h +++ b/test/apps/iodemo/ucx_wrapper.h @@ -9,15 +9,16 @@ #include #include +#include #include #include #include #include +#include #include #include #include #include -#include #define MAX_LOG_PREFIX_SIZE 64 @@ -85,6 +86,19 @@ class UcxContext { UcxConnection &_connection; }; +protected: + class UcxDisconnectCallback : public UcxCallback { + public: + UcxDisconnectCallback(UcxConnection &conn); + + virtual ~UcxDisconnectCallback(); + + virtual void operator()(ucs_status_t status); + + private: + UcxConnection *_conn; + }; + public: typedef struct memory_pin_stats { unsigned long regions; @@ -166,6 +180,8 @@ class UcxContext { void progress_failed_connections(); + void progress_disconnected_connections(); + wait_status_t wait_completion(ucs_status_ptr_t status_ptr, const char *title, double timeout = 1e6); @@ -177,6 +193,8 @@ class UcxContext { void remove_connection_inprogress(UcxConnection *conn); + void move_connection_to_disconnecting(UcxConnection *conn); + void handle_connection_error(UcxConnection *conn); void destroy_listener(); @@ -193,6 +211,7 @@ class UcxContext { std::deque _conn_requests; timeout_conn_t _conns_in_progress; // ordered in time std::deque _failed_conns; + std::list _disconnecting_conns; ucx_request *_iomsg_recv_request; std::string _iomsg_buffer; double _connect_timeout; @@ -210,6 +229,10 @@ class UcxConnection { void accept(ucp_conn_request_h conn_req, UcxCallback *callback); + void disconnect(UcxCallback *callback); + + bool disconnect_progress(); + bool send_io_message(const void *buffer, size_t length, UcxCallback* callback = EmptyCallback::get()); @@ -233,6 +256,10 @@ class UcxConnection { return _establish_cb == NULL; } + bool is_disconnecting() const { + return _disconnect_cb != NULL; + } + void handle_connection_error(ucs_status_t status); private: @@ -278,8 +305,9 @@ class UcxConnection { UcxContext &_context; UcxCallback *_establish_cb; - uint32_t _conn_id; - uint32_t _remote_conn_id; + UcxCallback *_disconnect_cb; + uint64_t _conn_id; + uint64_t _remote_conn_id; char _log_prefix[MAX_LOG_PREFIX_SIZE]; ucp_ep_h _ep; void *_close_request;