diff --git a/src/ucp/proto/proto_am.inl b/src/ucp/proto/proto_am.inl index 3c2d18ea31b..779fcbea032 100644 --- a/src/ucp/proto/proto_am.inl +++ b/src/ucp/proto/proto_am.inl @@ -345,11 +345,14 @@ ucs_status_t ucp_do_am_zcopy_multi(uct_pending_req_t *self, uint8_t am_id_first, if (enable_am_bw && (req->send.state.dt.offset != 0)) { req->send.lane = ucp_send_request_get_am_bw_lane(req); - ucp_send_request_add_reg_lane(req, req->send.lane); } else { req->send.lane = ucp_ep_get_am_lane(ep); } + if (enable_am_bw || (req->send.state.dt.offset == 0)) { + ucp_send_request_add_reg_lane(req, req->send.lane); + } + uct_ep = ep->uct_eps[req->send.lane]; max_middle = ucp_ep_get_max_zcopy(ep, req->send.lane) - hdr_size_middle; max_iov = ucp_ep_get_max_iov(ep, req->send.lane); diff --git a/src/ucs/datastruct/callbackq.c b/src/ucs/datastruct/callbackq.c index 7a9c97d0fc1..4d5e558c042 100644 --- a/src/ucs/datastruct/callbackq.c +++ b/src/ucs/datastruct/callbackq.c @@ -363,18 +363,20 @@ static unsigned ucs_callbackq_slow_proxy(void *arg) { ucs_callbackq_t *cbq = arg; ucs_callbackq_priv_t *priv = ucs_callbackq_priv(cbq); + unsigned num_slow_elems = priv->num_slow_elems; + unsigned count = 0; ucs_callbackq_elem_t *elem; unsigned UCS_V_UNUSED removed_idx; unsigned slow_idx, fast_idx; ucs_callbackq_elem_t tmp_elem; - unsigned count = 0; ucs_trace_poll("cbq=%p", cbq); ucs_callbackq_enter(cbq); - /* Execute and update slow-path callbacks */ - for (slow_idx = 0; slow_idx < priv->num_slow_elems; ++slow_idx) { + /* Execute and update slow-path callbacks by num_slow_elems copy to avoid + * infinite loop if callback adds another one */ + for (slow_idx = 0; slow_idx < num_slow_elems; ++slow_idx) { elem = &priv->slow_elems[slow_idx]; if (elem->id == UCS_CALLBACKQ_ID_NULL) { continue; diff --git a/test/gtest/ucs/test_callbackq.cc b/test/gtest/ucs/test_callbackq.cc index 705dbdc9134..c8e840d4616 100644 --- a/test/gtest/ucs/test_callbackq.cc +++ b/test/gtest/ucs/test_callbackq.cc @@ -32,6 +32,7 @@ class test_callbackq : uint32_t count; int command; callback_ctx *to_add; + unsigned flags; int key; }; @@ -81,9 +82,11 @@ class test_callbackq : void init_ctx(callback_ctx *ctx, int key = 0) { ctx->test = this; + ctx->callback_id = UCS_CALLBACKQ_ID_NULL; ctx->count = 0; ctx->command = COMMAND_NONE; - ctx->callback_id = UCS_CALLBACKQ_ID_NULL; + ctx->to_add = NULL; + ctx->flags = 0; ctx->key = key; } @@ -95,7 +98,7 @@ class test_callbackq : { ctx->callback_id = ucs_callbackq_add(&m_cbq, callback_proxy, reinterpret_cast(ctx), - cb_flags() | flags); + ctx->flags | cb_flags() | flags); } void remove(int callback_id) @@ -215,19 +218,22 @@ UCS_TEST_P(test_callbackq, add_another) { ctx.command = COMMAND_NONE; unsigned count = ctx.count; + if (cb_flags() & UCS_CALLBACKQ_FLAG_FAST) { + count++; /* fast CBs are executed immediately after "add" */ + } dispatch(); EXPECT_EQ(2u, ctx.count); - EXPECT_EQ(count + 1, ctx2.count); + EXPECT_EQ(count, ctx2.count); remove(&ctx); dispatch(); EXPECT_EQ(2u, ctx.count); - EXPECT_EQ(count + 2, ctx2.count); + EXPECT_EQ(count + 1, ctx2.count); remove(&ctx2); dispatch(); - EXPECT_EQ(count + 2, ctx2.count); + EXPECT_EQ(count + 1, ctx2.count); } UCS_MT_TEST_P(test_callbackq, threads, 10) { @@ -337,6 +343,24 @@ UCS_TEST_F(test_callbackq_noflags, oneshot) { EXPECT_EQ(1u, ctx.count); } +UCS_TEST_F(test_callbackq_noflags, oneshot_recursive) { + callback_ctx ctx; + + init_ctx(&ctx); + ctx.command = COMMAND_ADD_ANOTHER; + ctx.flags = UCS_CALLBACKQ_FLAG_ONESHOT; + ctx.to_add = &ctx; + + add(&ctx); + + for (unsigned i = 0; i < 10; ++i) { + dispatch(1); + EXPECT_LE(i + 1, ctx.count); + } + + remove(ctx.callback_id); +} + UCS_TEST_F(test_callbackq_noflags, remove_if) { const size_t count = 1000; const int num_keys = 10;