Skip to content

Commit

Permalink
Merge pull request openucx#21 from Akshay-Venkatesh/receive-separatio…
Browse files Browse the repository at this point in the history
…n-via-ep

Separate receive operations based on the Endpoint it used with
  • Loading branch information
Akshay-Venkatesh authored Jan 25, 2019
2 parents c4124fa + 237ca49 commit ecc0d72
Show file tree
Hide file tree
Showing 5 changed files with 216 additions and 48 deletions.
21 changes: 11 additions & 10 deletions pybind/ucp_py.pyx
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
# See file LICENSE for terms.

# cython: language_level=3
import concurrent.futures
import asyncio
import time
from weakref import WeakValueDictionary

cdef extern from "ucp_py_ucp_fxns.h":
ctypedef void (*listener_accept_cb_func)(ucp_ep_h *client_ep_ptr, void *user_data)
ctypedef void (*listener_accept_cb_func)(void *client_ep_ptr, void *user_data)

cdef extern from "ucp/api/ucp.h":
ctypedef struct ucp_ep_h:
Expand Down Expand Up @@ -107,7 +107,7 @@ cdef class ucp_py_ep:
"""A class that represents an endpoint connected to a peer
"""

cdef ucp_ep_h* ucp_ep
cdef void* ucp_ep
cdef int ptr_set

def __cinit__(self):
Expand All @@ -121,6 +121,7 @@ cdef class ucp_py_ep:
"""Blind receive operation"""

recv_msg = ucp_msg(None)
recv_msg.ucp_ep = self.ucp_ep
recv_future = CommFuture(recv_msg)
ucp_py_ep_post_probe()
return recv_future
Expand All @@ -133,7 +134,7 @@ cdef class ucp_py_ep:
CommFuture object
"""

msg.ctx_ptr = ucp_py_recv_nb(msg.buf, len)
msg.ctx_ptr = ucp_py_recv_nb(self.ucp_ep, msg.buf, len)
return msg.get_future(len)

def send(self, ucp_msg msg, len):
Expand All @@ -154,7 +155,7 @@ cdef class ucp_py_ep:
-------
ucp_comm_request object
"""
msg.ctx_ptr = ucp_py_recv_nb(msg.buf, len)
msg.ctx_ptr = ucp_py_recv_nb(self.ucp_ep, msg.buf, len)
return msg.get_comm_request(len)

def send_fast(self, ucp_msg msg, len):
Expand All @@ -180,7 +181,7 @@ cdef class ucp_py_ep:
buf_reg.populate_ptr(msg)
buf_reg.is_cuda = 0 # for now but it does not matter
internal_msg = ucp_msg(buf_reg)
internal_msg.ctx_ptr = ucp_py_recv_nb(internal_msg.buf, len)
internal_msg.ctx_ptr = ucp_py_recv_nb(self.ucp_ep, internal_msg.buf, len)
return internal_msg.get_comm_request(len)

def send_obj(self, msg, len):
Expand Down Expand Up @@ -209,7 +210,7 @@ cdef class ucp_msg:
cdef ucx_context* ctx_ptr
cdef int ctx_ptr_set
cdef data_buf* buf
cdef ucp_ep_h* ep_ptr
cdef void* ucp_ep
cdef int is_cuda
cdef int alloc_len
cdef int comm_len
Expand Down Expand Up @@ -274,11 +275,11 @@ cdef class ucp_msg:
if 1 == self.ctx_ptr_set:
return ucp_py_query_request(self.ctx_ptr)
else:
len = ucp_py_probe_query()
len = ucp_py_probe_query(self.ucp_ep)
if -1 != len:
self.alloc_host(len)
self.internally_allocated = 1
self.ctx_ptr = ucp_py_recv_nb(self.buf, len)
self.ctx_ptr = ucp_py_recv_nb(self.ucp_ep, self.buf, len)
self.comm_len = len
self.ctx_ptr_set = 1
return 0
Expand Down Expand Up @@ -330,7 +331,7 @@ cdef class ucp_comm_request:
accept_cb_is_coroutine = False
sf_instance = None

cdef void accept_callback(ucp_ep_h *client_ep_ptr, void *f):
cdef void accept_callback(void *client_ep_ptr, void *f):
global accept_cb_is_coroutine
client_ep = ucp_py_ep()
client_ep.ucp_ep = client_ep_ptr
Expand Down
Loading

0 comments on commit ecc0d72

Please sign in to comment.