Skip to content

Commit

Permalink
Merge pull request openucx#3 from Akshay-Venkatesh/use-coroutines
Browse files Browse the repository at this point in the history
Use coroutines
  • Loading branch information
Akshay-Venkatesh authored Nov 1, 2018
2 parents 695e8ae + 7519df6 commit 9f0750d
Show file tree
Hide file tree
Showing 3 changed files with 148 additions and 62 deletions.
68 changes: 67 additions & 1 deletion pybind/call_myucp.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# See file LICENSE for terms.

import concurrent.futures
import asyncio
from weakref import WeakValueDictionary

cdef extern from "myucp.h":
Expand Down Expand Up @@ -49,6 +50,51 @@ class CommFuture(concurrent.futures.Future):
print("releasing " + str(id(self)))
self.ucp_msg.free_mem()

def __await__(self):
if True == self.done_state:
return self.result_state
else:
while False == self.done_state:
if True == self.done():
return self.result_state
else:
yield

class ServerFuture(concurrent.futures.Future):

_instances = WeakValueDictionary()

def __init__(self, cb):
self.done_state = False
self.result_state = None
self.cb = cb
self._instances[id(self)] = self
super(ServerFuture, self).__init__()

def done(self):
if False == self.done_state:
ucp_py_worker_progress()
return self.done_state

def result(self):
while False == self.done_state:
self.done()
return self.result_state

def __del__(self):
print("releasing " + str(id(self)))

def __await__(self):
if True == self.done_state:
return self.result_state
else:
while False == self.done_state:
if True == self.done():
return self.result_state
else:
yield


cdef class ucp_py_ep:
cdef ucp_ep_h* ucp_ep
cdef int ptr_set
Expand Down Expand Up @@ -218,17 +264,37 @@ cdef class ucp_msg:
def get_comm_len(self):
return self.comm_len

accept_cb_is_coroutine = False

cdef void accept_callback(ucp_ep_h *client_ep_ptr, void *f):
global accept_cb_is_coroutine
client_ep = ucp_py_ep()
client_ep.ucp_ep = client_ep_ptr
(<object>f)(client_ep) #sign py_func(ucp_py_ep()) expected
if not accept_cb_is_coroutine:
print('A')
(<object>f)(client_ep) #sign py_func(ucp_py_ep()) expected
else:
print('B')
current_loop = asyncio.get_running_loop()
current_loop.create_task((<object>f)(client_ep))

def init():
return ucp_py_init()

def listen(py_func, server_port = -1):
return ucp_py_listen(accept_callback, <void *>py_func, server_port)

def start_server(py_func, server_port = -1, is_coroutine = False):
global accept_cb_is_coroutine
accept_cb_is_coroutine = is_coroutine
sf = ServerFuture(py_func)
async def async_start_server():
await sf
if 0 == ucp_py_listen(accept_callback, <void *>py_func, server_port):
return async_start_server()
else:
return -1

def fin():
return ucp_py_finalize()

Expand Down
40 changes: 0 additions & 40 deletions pybind/myucp.c
Original file line number Diff line number Diff line change
Expand Up @@ -291,46 +291,6 @@ static void wait(ucp_worker_h ucp_worker, struct ucx_context *context)
}
}

static ucs_status_t test_poll_wait(ucp_worker_h ucp_worker)
{
int ret = -1, err = 0;
ucs_status_t status;
int epoll_fd_local = 0, epoll_fd = 0;
struct epoll_event ev;
ev.data.u64 = 0;

status = ucp_worker_get_efd(ucp_worker, &epoll_fd);
CHKERR_JUMP(UCS_OK != status, "ucp_worker_get_efd", err);

/* It is recommended to copy original fd */
epoll_fd_local = epoll_create(1);

ev.data.fd = epoll_fd;
ev.events = EPOLLIN;
err = epoll_ctl(epoll_fd_local, EPOLL_CTL_ADD, epoll_fd, &ev);
CHKERR_JUMP(err < 0, "add original socket to the new epoll\n", err_fd);

/* Need to prepare ucp_worker before epoll_wait */
status = ucp_worker_arm(ucp_worker);
if (status == UCS_ERR_BUSY) { /* some events are arrived already */
ret = UCS_OK;
goto err_fd;
}
CHKERR_JUMP(status != UCS_OK, "ucp_worker_arm\n", err_fd);

do {
ret = epoll_wait(epoll_fd_local, &ev, 1, -1);
} while ((ret == -1) && (errno == EINTR));

ret = UCS_OK;

err_fd:
close(epoll_fd_local);

err:
return ret;
}

static void flush_callback(void *request, ucs_status_t status)
{
}
Expand Down
102 changes: 81 additions & 21 deletions tests/test-server-listen-accept-future-recv.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,20 @@
import call_myucp as ucp
import time
import argparse
import asyncio
import concurrent.futures

def cb(future):
print('in test callback')

def send_recv(ep, msg_log, is_server, is_cuda):
def send_recv(ep, msg_log, server, is_cuda):
buffer_region = ucp.buffer_region()
if is_cuda:
buffer_region.alloc_cuda(1 << msg_log)
else:
buffer_region.alloc_host(1 << msg_log)
msg = ucp.ucp_msg(buffer_region)
if 1 == is_server:
if server:
msg.set_mem(0, 1 << msg_log)
#send_req = msg.send_ft(ep, 1 << msg_log)
send_req = ep.send(msg, 1 << msg_log)
Expand All @@ -38,58 +40,116 @@ def send_recv(ep, msg_log, is_server, is_cuda):
else:
buffer_region.free_host()

accept_cb_started = 0
accept_cb_started = False
new_client_ep = None
max_msg_log = 23

async def talk_to_client(client_ep):

print("in talk_to_client")
msg_log = max_msg_log

buffer_region = ucp.buffer_region()
buffer_region.alloc_cuda(1 << msg_log)

msg = ucp.ucp_msg(buffer_region)

send_req = await client_ep.send(msg, 1 << msg_log)

recv_req = await client_ep.recv_ft()

buffer_region.free_cuda()

print(42)
return 42

async def talk_to_server(ip, port):

msg_log = max_msg_log

server_ep = ucp.get_endpoint(ip, port)

buffer_region = ucp.buffer_region()
buffer_region.alloc_cuda(1 << msg_log)

msg = ucp.ucp_msg(buffer_region)

recv_req = await server_ep.recv_ft()

send_req = await server_ep.send(msg, 1 << msg_log)

buffer_region.free_cuda()

print(3.14)
return 3.14

def server_accept_callback(client_ep):
global accept_cb_started
global new_client_ep
print("in python accept callback")
new_client_ep = client_ep
accept_cb_started = 1
assert new_client_ep != None
is_cuda = False
send_recv(new_client_ep, max_msg_log, server, is_cuda)
is_cuda = True
send_recv(new_client_ep, max_msg_log, server, is_cuda)
accept_cb_started = True

max_msg_log = 23
parser = argparse.ArgumentParser()
parser.add_argument('-s','--server', help='enter server ip', required=False)
parser.add_argument('-p','--port', help='enter server port number', required=False)
args = parser.parse_args()

## initiate ucp
init_str = ""
is_server = 0
server = False
if args.server is None:
is_server = 1
server = True
else:
is_server = 0
server = False
init_str = args.server

'''
## setup endpoints
ucp.init()
server_ep = None
if 0 == is_server:
if not server:
#connect to server
server_ep = ucp.get_endpoint(init_str.encode(), int(args.port))
is_cuda = False
send_recv(server_ep, max_msg_log, is_server, is_cuda)
send_recv(server_ep, max_msg_log, server, is_cuda)
is_cuda = True
send_recv(server_ep, max_msg_log, is_server, is_cuda)
send_recv(server_ep, max_msg_log, server, is_cuda)
else:
ucp.listen(server_accept_callback)
while 0 == accept_cb_started:
ucp.ucp_progress()
assert new_client_ep != None
is_cuda = False
send_recv(new_client_ep, max_msg_log, is_server, is_cuda)
is_cuda = True
send_recv(new_client_ep, max_msg_log, is_server, is_cuda)
#assert new_client_ep != None
#is_cuda = False
#send_recv(new_client_ep, max_msg_log, server, is_cuda)
#is_cuda = True
#send_recv(new_client_ep, max_msg_log, server, is_cuda)
if 1 == is_server:
if server:
assert new_client_ep != None
ucp.destroy_ep(new_client_ep)
else:
ucp.destroy_ep(server_ep)
'''

if args.server is None:
print("Server Finalized")
ucp.init()
loop = asyncio.get_event_loop()
# coro points to either client or server-side coroutine
if server:
coro = ucp.start_server(talk_to_client, is_coroutine = True)
#coro = talk_to_client()
else:
print("Client Finalized")
coro = talk_to_server(init_str.encode(), int(args.port))

loop.run_until_complete(coro)

try:
loop.run_forever()
except KeyboardInterrupt:
pass

loop.close()

0 comments on commit 9f0750d

Please sign in to comment.