diff --git a/pybind/ucp_py.pyx b/pybind/ucp_py.pyx index 77ccad0e92f..e2900008065 100644 --- a/pybind/ucp_py.pyx +++ b/pybind/ucp_py.pyx @@ -1,13 +1,13 @@ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. # See file LICENSE for terms. - +# cython: language_level=3 import concurrent.futures import asyncio import time from weakref import WeakValueDictionary cdef extern from "ucp_py_ucp_fxns.h": - ctypedef void (*listener_accept_cb_func)(ucp_ep_h *client_ep_ptr, void *user_data) + ctypedef void (*listener_accept_cb_func)(void *client_ep_ptr, void *user_data) cdef extern from "ucp/api/ucp.h": ctypedef struct ucp_ep_h: @@ -107,7 +107,7 @@ cdef class ucp_py_ep: """A class that represents an endpoint connected to a peer """ - cdef ucp_ep_h* ucp_ep + cdef void* ucp_ep cdef int ptr_set def __cinit__(self): @@ -121,6 +121,7 @@ cdef class ucp_py_ep: """Blind receive operation""" recv_msg = ucp_msg(None) + recv_msg.ucp_ep = self.ucp_ep recv_future = CommFuture(recv_msg) ucp_py_ep_post_probe() return recv_future @@ -133,7 +134,7 @@ cdef class ucp_py_ep: CommFuture object """ - msg.ctx_ptr = ucp_py_recv_nb(msg.buf, len) + msg.ctx_ptr = ucp_py_recv_nb(self.ucp_ep, msg.buf, len) return msg.get_future(len) def send(self, ucp_msg msg, len): @@ -154,7 +155,7 @@ cdef class ucp_py_ep: ------- ucp_comm_request object """ - msg.ctx_ptr = ucp_py_recv_nb(msg.buf, len) + msg.ctx_ptr = ucp_py_recv_nb(self.ucp_ep, msg.buf, len) return msg.get_comm_request(len) def send_fast(self, ucp_msg msg, len): @@ -180,7 +181,7 @@ cdef class ucp_py_ep: buf_reg.populate_ptr(msg) buf_reg.is_cuda = 0 # for now but it does not matter internal_msg = ucp_msg(buf_reg) - internal_msg.ctx_ptr = ucp_py_recv_nb(internal_msg.buf, len) + internal_msg.ctx_ptr = ucp_py_recv_nb(self.ucp_ep, internal_msg.buf, len) return internal_msg.get_comm_request(len) def send_obj(self, msg, len): @@ -209,7 +210,7 @@ cdef class ucp_msg: cdef ucx_context* ctx_ptr cdef int ctx_ptr_set cdef data_buf* buf - cdef ucp_ep_h* ep_ptr + cdef void* ucp_ep cdef int is_cuda cdef int alloc_len cdef int comm_len @@ -274,11 +275,11 @@ cdef class ucp_msg: if 1 == self.ctx_ptr_set: return ucp_py_query_request(self.ctx_ptr) else: - len = ucp_py_probe_query() + len = ucp_py_probe_query(self.ucp_ep) if -1 != len: self.alloc_host(len) self.internally_allocated = 1 - self.ctx_ptr = ucp_py_recv_nb(self.buf, len) + self.ctx_ptr = ucp_py_recv_nb(self.ucp_ep, self.buf, len) self.comm_len = len self.ctx_ptr_set = 1 return 0 @@ -330,7 +331,7 @@ cdef class ucp_comm_request: accept_cb_is_coroutine = False sf_instance = None -cdef void accept_callback(ucp_ep_h *client_ep_ptr, void *f): +cdef void accept_callback(void *client_ep_ptr, void *f): global accept_cb_is_coroutine client_ep = ucp_py_ep() client_ep.ucp_ep = client_ep_ptr diff --git a/pybind/ucp_py_ucp_fxns.c b/pybind/ucp_py_ucp_fxns.c index 0e236d4f296..ec13133887a 100644 --- a/pybind/ucp_py_ucp_fxns.c +++ b/pybind/ucp_py_ucp_fxns.c @@ -25,6 +25,8 @@ #include #include #include +#include + #define CB_Q_MAX_ENTRIES 256 @@ -62,7 +64,12 @@ ucp_py_ctx_t *ucp_py_ctx_head; /* defaults */ static uint16_t default_listener_port = 13337; static const ucp_tag_t default_tag = 0x1337a880u; +static const ucp_tag_t exch_tag = 0x1342a880u; static const ucp_tag_t default_tag_mask = -1; +static char my_hostname[HNAME_MAX_LEN]; +static pid_t my_pid = -1; +static int connect_ep_counter = 0; +static int accept_ep_counter = 0; static void request_init(void *request) { @@ -104,12 +111,26 @@ static void recv_handle(void *request, ucs_status_t status, info->length); } +unsigned long djb2_hash(unsigned char *str) +{ + unsigned long hash = 5381; + int c; + + while (c = *str++) + hash = ((hash << 5) + hash) + c; /* hash * 33 + c */ + + return hash; +} + static unsigned ucp_ipy_worker_progress(ucp_worker_h ucp_worker) { - unsigned status; void *tmp_py_cb; listener_accept_cb_func tmp_pyx_cb; void *tmp_arg; + ucs_status_t status = 0; + char tmp_str[TAG_STR_MAX_LEN]; + struct ucx_context *request = 0; + ucp_py_internal_ep_t *internal_ep; status = ucp_worker_progress(ucp_worker); while (cb_used_head.tqh_first != NULL) { @@ -127,25 +148,55 @@ static unsigned ucp_ipy_worker_progress(ucp_worker_h ucp_worker) num_cb_free++; assert(num_cb_free <= CB_Q_MAX_ENTRIES); assert(cb_free_head.tqh_first != NULL); - tmp_pyx_cb(tmp_arg, tmp_py_cb); + + // call receive and wait for tag info before callback + internal_ep = (ucp_py_internal_ep_t *) tmp_arg; + request = ucp_tag_recv_nb(ucp_worker, + internal_ep->ep_tag_str, TAG_STR_MAX_LEN, + ucp_dt_make_contig(1), exch_tag, + default_tag_mask, recv_handle); + + if (UCS_PTR_IS_ERR(request)) { + fprintf(stderr, "unable to receive UCX data message (%u)\n", + UCS_PTR_STATUS(request)); + goto err_ep; + } + do { + ucp_worker_progress(ucp_worker); + //TODO: Workout if there are deadlock possibilities here + status = ucp_request_check_status(request); + } while (status == UCS_INPROGRESS); + sprintf(tmp_str, "%s:%d", internal_ep->ep_tag_str, default_listener_port); + internal_ep->send_tag = djb2_hash(tmp_str); + internal_ep->recv_tag = djb2_hash(internal_ep->ep_tag_str); + ucp_request_release(request); + accept_ep_counter++; + + tmp_pyx_cb((void *) tmp_arg, tmp_py_cb); } - return status; + return (unsigned int) status; + err_ep: + printf("listener_accept_cb\n"); + exit(-1); } -struct ucx_context *ucp_py_recv_nb(struct data_buf *recv_buf, int length) +struct ucx_context *ucp_py_recv_nb(void *internal_ep, struct data_buf *recv_buf, int length) { ucs_status_t status; + ucp_tag_t tag; ucp_ep_params_t ep_params; struct ucx_context *request = 0; int errs = 0; int i; + ucp_py_internal_ep_t *int_ep = (ucp_py_internal_ep_t *) internal_ep; DEBUG_PRINT("receiving %p\n", recv_buf->buf); + tag = int_ep->recv_tag; request = ucp_tag_recv_nb(ucp_py_ctx_head->ucp_worker, recv_buf->buf, length, - ucp_dt_make_contig(1), default_tag, - default_tag_mask, recv_handle); + ucp_dt_make_contig(1), tag, default_tag_mask, + recv_handle); DEBUG_PRINT("returning request %p\n", request); @@ -158,6 +209,7 @@ struct ucx_context *ucp_py_recv_nb(struct data_buf *recv_buf, int length) return request; err_ep: + ucp_ep_destroy(*((ucp_ep_h *) int_ep->ep_ptr)); return request; } @@ -166,19 +218,22 @@ int ucp_py_ep_post_probe() return ucp_py_ctx_head->num_probes_outstanding++; } -int ucp_py_ep_probe() +int ucp_py_ep_probe(void *internal_ep) { ucs_status_t status; + ucp_tag_t tag; ucp_ep_params_t ep_params; struct ucx_context *request = 0; int errs = 0; int i; ucp_tag_recv_info_t info_tag; ucp_tag_message_h msg_tag; + ucp_py_internal_ep_t *int_ep = (ucp_py_internal_ep_t *) internal_ep; DEBUG_PRINT("probing..\n"); - msg_tag = ucp_tag_probe_nb(ucp_py_ctx_head->ucp_worker, default_tag, + tag = int_ep->recv_tag; + msg_tag = ucp_tag_probe_nb(ucp_py_ctx_head->ucp_worker, tag, default_tag_mask, 0, &info_tag); if (msg_tag != NULL) { /* Message arrived */ @@ -189,41 +244,44 @@ int ucp_py_ep_probe() return -1; } -int ucp_py_probe_wait() +int ucp_py_probe_wait(void *internal_ep) { int probed_length; do { ucp_ipy_worker_progress(ucp_py_ctx_head->ucp_worker); - probed_length = ucp_py_ep_probe(); + probed_length = ucp_py_ep_probe(internal_ep); } while (-1 == probed_length); return probed_length; } -int ucp_py_probe_query() +int ucp_py_probe_query(void *internal_ep) { int probed_length; ucp_ipy_worker_progress(ucp_py_ctx_head->ucp_worker); - probed_length = ucp_py_ep_probe(); + probed_length = ucp_py_ep_probe(internal_ep); return probed_length; } -struct ucx_context *ucp_py_ep_send_nb(ucp_ep_h *ep_ptr, struct data_buf *send_buf, +struct ucx_context *ucp_py_ep_send_nb(void *internal_ep, struct data_buf *send_buf, int length) { ucs_status_t status; + ucp_tag_t tag; ucp_ep_params_t ep_params; struct ucx_context *request = 0; + ucp_py_internal_ep_t *int_ep = (ucp_py_internal_ep_t *) internal_ep; - DEBUG_PRINT("EP send : %p\n", ep_ptr); + DEBUG_PRINT("EP send : %p\n", int_ep->ep_ptr); DEBUG_PRINT("sending %p\n", send_buf->buf); - request = ucp_tag_send_nb(*ep_ptr, send_buf->buf, length, - ucp_dt_make_contig(1), default_tag, + tag = int_ep->send_tag; + request = ucp_tag_send_nb(*((ucp_ep_h *) int_ep->ep_ptr), send_buf->buf, length, + ucp_dt_make_contig(1), tag, send_handle); if (UCS_PTR_IS_ERR(request)) { fprintf(stderr, "unable to send UCX data message\n"); @@ -239,7 +297,7 @@ struct ucx_context *ucp_py_ep_send_nb(ucp_ep_h *ep_ptr, struct data_buf *send_bu return request; err_ep: - ucp_ep_destroy(*ep_ptr); + ucp_ep_destroy(*((ucp_ep_h *) int_ep->ep_ptr)); return request; } @@ -289,11 +347,14 @@ void set_connect_addr(const char *address_str, struct sockaddr_in *connect_addr, static void listener_accept_cb(ucp_ep_h ep, void *arg) { ucx_listener_ctx_t *context = arg; - struct ucx_context *request = 0; + ucp_py_internal_ep_t *internal_ep; ucp_ep_h *ep_ptr = NULL; + ucs_status_t status; + internal_ep = (ucp_py_internal_ep_t *) malloc(sizeof(ucp_py_internal_ep_t)); ep_ptr = (ucp_ep_h *) malloc(sizeof(ucp_ep_h)); *ep_ptr = ep; + internal_ep->ep_ptr = ep_ptr; if (num_cb_free > 0) { num_cb_free--; @@ -301,7 +362,8 @@ static void listener_accept_cb(ucp_ep_h ep, void *arg) TAILQ_REMOVE(&cb_free_head, np, entries); np->pyx_cb = context->pyx_cb; np->py_cb = context->py_cb; - np->arg = ep_ptr; + //np->arg = ep_ptr; + np->arg = internal_ep; TAILQ_INSERT_TAIL(&cb_used_head, np, entries); num_cb_used++; assert(num_cb_used <= CB_Q_MAX_ENTRIES); @@ -309,8 +371,12 @@ static void listener_accept_cb(ucp_ep_h ep, void *arg) } else { WARN_PRINT("out of free cb entries. Trying in place\n"); - context->pyx_cb(ep_ptr, context->py_cb); + // TODO: Need a receive of tag info here as well + //context->pyx_cb((void *) internal_ep, context->py_cb); + context->pyx_cb((void *) internal_ep, context->py_cb); } + + return; } static int start_listener(ucp_worker_h ucp_worker, ucx_listener_ctx_t *context, @@ -337,13 +403,17 @@ static int start_listener(ucp_worker_h ucp_worker, ucx_listener_ctx_t *context, return status; } -ucp_ep_h *ucp_py_get_ep(char *ip, int listener_port) +void *ucp_py_get_ep(char *ip, int listener_port) { ucp_ep_params_t ep_params; struct sockaddr_in connect_addr; ucs_status_t status; ucp_ep_h *ep_ptr; + ucp_py_internal_ep_t *internal_ep; + struct ucx_context *request = 0; + char tmp_str[TAG_STR_MAX_LEN]; + internal_ep = (ucp_py_internal_ep_t *) malloc(sizeof(ucp_py_internal_ep_t)); ep_ptr = (ucp_ep_h *) malloc(sizeof(ucp_ep_h)); set_connect_addr(ip, &connect_addr, (uint16_t) listener_port); ep_params.field_mask = UCP_EP_PARAM_FIELD_FLAGS | @@ -359,14 +429,47 @@ ucp_ep_h *ucp_py_get_ep(char *ip, int listener_port) DEBUG_PRINT(stderr, "failed to connect to %s (%s)\n", ip, ucs_status_string(status)); } + internal_ep->ep_ptr = ep_ptr; + sprintf(internal_ep->ep_tag_str, "%s:%u:%d", my_hostname, + (unsigned int) my_pid, connect_ep_counter); + internal_ep->send_tag = djb2_hash(internal_ep->ep_tag_str); + sprintf(tmp_str, "%s:%d", internal_ep->ep_tag_str, listener_port); + internal_ep->recv_tag = djb2_hash(tmp_str); + + request = ucp_tag_send_nb(*ep_ptr, internal_ep->ep_tag_str, TAG_STR_MAX_LEN, + ucp_dt_make_contig(1), exch_tag, + send_handle); + if (UCS_PTR_IS_ERR(request)) { + fprintf(stderr, "unable to send UCX data message\n"); + goto err_ep; + } else if (UCS_PTR_STATUS(request) != UCS_OK) { + DEBUG_PRINT("UCX data message was scheduled for send\n"); + do { + ucp_ipy_worker_progress(ucp_py_ctx_head->ucp_worker); + //TODO: Workout if there are deadlock possibilities here + status = ucp_request_check_status(request); + } while (status == UCS_INPROGRESS); + ucp_request_release(request); + } else { + /* request is complete so no need to wait on request */ + } + connect_ep_counter++; + + //return (void *)ep_ptr; + return (void *) internal_ep; - return ep_ptr; +err_ep: + ucp_ep_destroy(*ep_ptr); + exit(-1); } -int ucp_py_put_ep(ucp_ep_h *ep_ptr) +int ucp_py_put_ep(void *internal_ep) { ucs_status_t status; void *close_req; + ucp_ep_h *ep_ptr; + ucp_py_internal_ep_t *int_ep = (ucp_py_internal_ep_t *) internal_ep; + ep_ptr = int_ep->ep_ptr; DEBUG_PRINT("try ep close %p\n", ep_ptr); close_req = ucp_ep_close_nb(*ep_ptr, UCP_EP_CLOSE_MODE_FORCE); @@ -383,6 +486,7 @@ int ucp_py_put_ep(ucp_ep_h *ep_ptr) } free(ep_ptr); + free(internal_ep); DEBUG_PRINT("ep closed\n"); } @@ -394,6 +498,11 @@ int ucp_py_init() ucp_config_t *config; ucs_status_t status; + if (0 != gethostname(my_hostname, HNAME_MAX_LEN)) goto err_py_init; + my_pid = getpid(); + + DEBUG_PRINT("hname: %s pid: %d\n", my_hostname, (int)my_pid); + ucp_py_ctx_head = (ucp_py_ctx_t *) malloc(sizeof(ucp_py_ctx_t)); if (NULL == ucp_py_ctx_head) goto err_py_init; @@ -456,11 +565,12 @@ int ucp_py_listen(listener_accept_cb_func pyx_cb, void *py_cb, int port) ucp_py_ctx_head->listener_context.pyx_cb = pyx_cb; ucp_py_ctx_head->listener_context.py_cb = py_cb; ucp_py_ctx_head->listens = 1; - + default_listener_port = (port == -1 ? default_listener_port : port); + status = start_listener(ucp_py_ctx_head->ucp_worker, &ucp_py_ctx_head->listener_context, &ucp_py_ctx_head->listener, - (port == -1 ? default_listener_port : port)); + default_listener_port); CHKERR_JUMP(UCS_OK != status, "failed to start listener", err_worker); return 0; diff --git a/pybind/ucp_py_ucp_fxns.h b/pybind/ucp_py_ucp_fxns.h index e59f3db59bd..05179cc66e2 100644 --- a/pybind/ucp_py_ucp_fxns.h +++ b/pybind/ucp_py_ucp_fxns.h @@ -4,24 +4,35 @@ */ #include #include +#include +#include #include "common.h" +#define HNAME_MAX_LEN 256 +#define TAG_STR_MAX_LEN 512 -typedef void (*listener_accept_cb_func)(ucp_ep_h *client_ep_ptr, void *user_data); +typedef void (*listener_accept_cb_func)(void *client_ep_ptr, void *user_data); struct ucx_context { int completed; }; +typedef struct ucp_py_internal_ep { + ucp_ep_h *ep_ptr; + char ep_tag_str[TAG_STR_MAX_LEN]; + ucp_tag_t send_tag; + ucp_tag_t recv_tag; +} ucp_py_internal_ep_t; + int ucp_py_init(); int ucp_py_listen(listener_accept_cb_func, void *, int); int ucp_py_finalize(void); -ucp_ep_h *ucp_py_get_ep(char *, int); -int ucp_py_put_ep(ucp_ep_h *); +void *ucp_py_get_ep(char *, int); +int ucp_py_put_ep(void *); void ucp_py_worker_progress(); -struct ucx_context *ucp_py_ep_send_nb(ucp_ep_h *ep_ptr, struct data_buf *send_buf, int length); -struct ucx_context *ucp_py_recv_nb(struct data_buf *buf, int length); +struct ucx_context *ucp_py_ep_send_nb(void *ep_ptr, struct data_buf *send_buf, int length); +struct ucx_context *ucp_py_recv_nb(void *ep_ptr, struct data_buf *buf, int length); int ucp_py_ep_post_probe(); -int ucp_py_probe_query(); -int ucp_py_probe_wait(); +int ucp_py_probe_query(void *ep_ptr); +int ucp_py_probe_wait(void *ep_ptr); int ucp_py_query_request(struct ucx_context *request); diff --git a/pybind/ucp_py_ucp_fxns_wrapper.pyx b/pybind/ucp_py_ucp_fxns_wrapper.pyx index 27e929b1cfe..1caebfc8b75 100644 --- a/pybind/ucp_py_ucp_fxns_wrapper.pyx +++ b/pybind/ucp_py_ucp_fxns_wrapper.pyx @@ -6,11 +6,11 @@ cdef extern from "ucp_py_ucp_fxns.h": int ucp_py_init() int ucp_py_listen(listener_accept_cb_func, void *, int) int ucp_py_finalize() - ucp_ep_h* ucp_py_get_ep(char *, int) - int ucp_py_put_ep(ucp_ep_h *) - ucx_context* ucp_py_ep_send_nb(ucp_ep_h*, data_buf*, int) - ucx_context* ucp_py_recv_nb(data_buf*, int) + void* ucp_py_get_ep(char *, int) + int ucp_py_put_ep(void *) + ucx_context* ucp_py_ep_send_nb(void*, data_buf*, int) + ucx_context* ucp_py_recv_nb(void*, data_buf*, int) int ucp_py_ep_post_probe() - int ucp_py_probe_query() - int ucp_py_probe_wait() + int ucp_py_probe_query(void*) + int ucp_py_probe_wait(void*) int ucp_py_query_request(ucx_context*) diff --git a/tests/single.py b/tests/single.py new file mode 100644 index 00000000000..827b0957d3f --- /dev/null +++ b/tests/single.py @@ -0,0 +1,46 @@ +import asyncio +import sys +import ucp_py as ucp + +client_msg = b'hi' +server_msg = b'\\\//' + + +async def connect(host): + print("3. Starting connect") + ep = ucp.get_endpoint(host, 13337) + print("4. Client send") + await ep.send_obj(client_msg, sys.getsizeof(client_msg)) + resp = await ep.recv_future() + r_msg = ucp.get_obj_from_msg(resp) + print("8. Client got message: {}".format(r_msg.decode())) + print("9. Stopping client") + ucp.destroy_ep(ep) + + +async def serve(ep): + print("5. Starting serve") + msg = await ep.recv_future() + print("6. Server got message") + msg = ucp.get_obj_from_msg(msg) + response = "Got: {}".format(server_msg.decode()).encode() + await ep.send_obj(response, sys.getsizeof(response)) + print('7. Stopping server') + ucp.destroy_ep(ep) + ucp.stop_listener() + +async def main(host): + ucp.init() + print("1. Calling connect") + client = connect(host) + print("2. Calling start_server") + server = ucp.start_listener(serve, is_coroutine=True) + + await asyncio.gather(server, client) + +if __name__ == '__main__': + if len(sys.argv) == 2: + host = sys.argv[1].encode() + else: + host = b"192.168.40.19" + asyncio.run(main(host))