diff --git a/examples/hello_world_util.h b/examples/hello_world_util.h index 51cff49c830..30d0e2d88c7 100644 --- a/examples/hello_world_util.h +++ b/examples/hello_world_util.h @@ -9,6 +9,7 @@ #include +#include #include #include #include @@ -265,8 +266,10 @@ int client_connect(const char *server, uint16_t server_port) return -1; } -static inline int barrier(int oob_sock) +static inline int +barrier(int oob_sock, void (*progress_cb)(void *arg), void *arg) { + struct pollfd pfd = { .fd = oob_sock, .events = POLLIN }; int dummy = 0; ssize_t res; @@ -275,6 +278,11 @@ static inline int barrier(int oob_sock) return res; } + do { + res = poll(&pfd, 1, 1); + progress_cb(arg); + } while (res != 1); + res = recv(oob_sock, &dummy, sizeof(dummy), MSG_WAITALL); /* number of received bytes should be the same as sent */ diff --git a/examples/ucp_hello_world.c b/examples/ucp_hello_world.c index 01a79af533a..c45e5ae2bbb 100644 --- a/examples/ucp_hello_world.c +++ b/examples/ucp_hello_world.c @@ -506,6 +506,11 @@ static int run_test(const char *client_target_name, ucp_worker_h ucp_worker) } } +static void progress_worker(void *arg) +{ + ucp_worker_progress((ucp_worker_h)arg); +} + int main(int argc, char **argv) { /* UCP temporary vars */ @@ -602,7 +607,7 @@ int main(int argc, char **argv) if (!ret && (err_handling_opt.failure_mode == FAILURE_MODE_NONE)) { /* Make sure remote is disconnected before destroying local worker */ - ret = barrier(oob_sock); + ret = barrier(oob_sock, progress_worker, ucp_worker); } close(oob_sock); diff --git a/examples/uct_hello_world.c b/examples/uct_hello_world.c index 66a51d81f39..237d46a09b2 100644 --- a/examples/uct_hello_world.c +++ b/examples/uct_hello_world.c @@ -561,6 +561,11 @@ int sendrecv(int sock, const void *sbuf, size_t slen, void **rbuf) return 0; } +static void progress_worker(void *arg) +{ + uct_worker_progress((uct_worker_h)arg); +} + int main(int argc, char **argv) { uct_device_addr_t *peer_dev = NULL; @@ -660,7 +665,7 @@ int main(int argc, char **argv) /* Connect endpoint to a remote endpoint */ status = uct_ep_connect_to_ep(ep, peer_dev, peer_ep); - if (barrier(oob_sock)) { + if (barrier(oob_sock, progress_worker, if_info.worker)) { status = UCS_ERR_IO_ERROR; goto out_free_ep; } @@ -729,7 +734,7 @@ int main(int argc, char **argv) } } - if (barrier(oob_sock)) { + if (barrier(oob_sock, progress_worker, if_info.worker)) { status = UCS_ERR_IO_ERROR; }