diff --git a/src/ucp/core/ucp_request.inl b/src/ucp/core/ucp_request.inl index 386dc68fd665..77102b6548d0 100644 --- a/src/ucp/core/ucp_request.inl +++ b/src/ucp/core/ucp_request.inl @@ -64,12 +64,15 @@ #define ucp_request_complete(_req, _cb, _status, ...) \ { \ + /* NOTE: external request can't have RELEASE flag and we */ \ + /* will never put it into mpool */ \ + uint32_t _released = ((_req)->flags |= UCP_REQUEST_FLAG_COMPLETED) & \ + UCP_REQUEST_FLAG_RELEASED; \ (_req)->status = (_status); \ if (ucs_likely((_req)->flags & UCP_REQUEST_FLAG_CALLBACK)) { \ (_req)->_cb((_req) + 1, (_status), ## __VA_ARGS__); \ } \ - if (ucs_unlikely(((_req)->flags |= UCP_REQUEST_FLAG_COMPLETED) & \ - UCP_REQUEST_FLAG_RELEASED)) { \ + if (ucs_unlikely(_released)) { \ ucp_request_put(_req); \ } \ } diff --git a/test/gtest/ucp/test_ucp_tag.cc b/test/gtest/ucp/test_ucp_tag.cc index eaf0dd2ffc93..1440126b77ed 100644 --- a/test/gtest/ucp/test_ucp_tag.cc +++ b/test/gtest/ucp/test_ucp_tag.cc @@ -14,6 +14,7 @@ extern "C" { #include #include #include +#include } #include @@ -462,28 +463,39 @@ UCS_TEST_P(test_ucp_tag_limits, check_max_short_zcopy_thresh_zero, "ZCOPY_THRESH UCP_INSTANTIATE_TEST_CASE(test_ucp_tag_limits) -class test_ucp_tag_fallback : public ucp_test { +class test_ucp_tag_nbx : public test_ucp_tag { public: void init() { /* forbid zcopy access because it will always fail due to read-only * memory pages (will fail to register memory) */ modify_config("ZCOPY_THRESH", "inf"); - ucp_test::init(); - sender().connect(&receiver(), get_ep_params()); - receiver().connect(&sender(), get_ep_params()); - } - - static void get_test_variants(std::vector& variants) { - add_variant(variants, UCP_FEATURE_TAG); + test_ucp_tag::init(); + m_completed = 0; } protected: static const size_t MSG_SIZE; + uint32_t m_completed; + + static void send_callback(void *req, ucs_status_t status, + void *user_data) + { + request_free((request*)req); + ucs_atomic_add32((volatile uint32_t*)user_data, 1); + } + + static void recv_callback(void *req, ucs_status_t status, + const ucp_tag_recv_info_t *info, + void *user_data) + { + request_free((request*)req); + ucs_atomic_add32((volatile uint32_t*)user_data, 1); + } }; -const size_t test_ucp_tag_fallback::MSG_SIZE = 4 * 1024 * ucs_get_page_size(); +const size_t test_ucp_tag_nbx::MSG_SIZE = 4 * UCS_KBYTE * ucs_get_page_size(); -UCS_TEST_P(test_ucp_tag_fallback, fallback) +UCS_TEST_P(test_ucp_tag_nbx, fallback) { ucp_request_param_t param = {0}; @@ -509,4 +521,35 @@ UCS_TEST_P(test_ucp_tag_fallback, fallback) munmap(send_buffer, MSG_SIZE); } -UCP_INSTANTIATE_TEST_CASE(test_ucp_tag_fallback) +UCS_TEST_P(test_ucp_tag_nbx, external_request_free) +{ + ucp_request_param_t send_param; + ucp_request_param_t recv_param; + + send_param.op_attr_mask = recv_param.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | + UCP_OP_ATTR_FIELD_REQUEST | + UCP_OP_ATTR_FLAG_NO_IMM_CMPL | + UCP_OP_ATTR_FIELD_USER_DATA; + send_param.user_data = recv_param.user_data = &m_completed; + send_param.request = request_alloc(); + recv_param.request = request_alloc(); + send_param.cb.send = (ucp_send_nbx_callback_t)send_callback; + recv_param.cb.recv = (ucp_tag_recv_nbx_callback_t)recv_callback; + send_param.user_data = &m_completed; + recv_param.user_data = &m_completed; + + std::vector send_buffer(MSG_SIZE); + std::vector recv_buffer(MSG_SIZE); + + ucs_status_ptr_t recv_req = ucp_tag_recv_nbx(receiver().worker(), + &recv_buffer[0], MSG_SIZE, + 0, 0, &recv_param); + ASSERT_TRUE(UCS_PTR_IS_PTR(recv_req)); + ucs_status_ptr_t send_req = ucp_tag_send_nbx(sender().ep(), &send_buffer[0], + MSG_SIZE, 0, &send_param); + ASSERT_TRUE(UCS_PTR_IS_PTR(send_req)); + + wait_for_value(&m_completed, 2u); +} + +UCP_INSTANTIATE_TEST_CASE(test_ucp_tag_nbx) diff --git a/test/gtest/ucp/ucp_test.h b/test/gtest/ucp/ucp_test.h index 36f0f4811fc2..a5c93ff1656c 100644 --- a/test/gtest/ucp/ucp_test.h +++ b/test/gtest/ucp/ucp_test.h @@ -297,6 +297,15 @@ class ucp_test : public ucp_test_base, } } + template + void wait_for_value(volatile T *var, T value, double timeout = 10.0) const + { + ucs_time_t deadline = ucs_get_time() + + ucs_time_from_sec(timeout) * ucs::test_time_multiplier(); + while ((ucs_get_time() < deadline) && (*var != value)) { + short_progress_loop(); + } + } static const ucp_datatype_t DATATYPE; static const ucp_datatype_t DATATYPE_IOV;