Skip to content

Commit

Permalink
Merge pull request #4529 from petro-rudenko/jucx/request-close
Browse files Browse the repository at this point in the history
JUCX: ucp_request functionality.
  • Loading branch information
yosefe authored Dec 4, 2019
2 parents 99c7a73 + 04fb75c commit b991d31
Show file tree
Hide file tree
Showing 15 changed files with 177 additions and 72 deletions.
1 change: 1 addition & 0 deletions bindings/java/pom.xml.in
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,7 @@
<javahClassName>org.openucx.jucx.ucp.UcpEndpoint</javahClassName>
<javahClassName>org.openucx.jucx.ucp.UcpListener</javahClassName>
<javahClassName>org.openucx.jucx.ucp.UcpMemory</javahClassName>
<javahClassName>org.openucx.jucx.ucp.UcpRequest</javahClassName>
<javahClassName>org.openucx.jucx.ucp.UcpRemoteKey</javahClassName>
<javahClassName>org.openucx.jucx.ucp.UcpWorker</javahClassName>
<javahClassName>org.openucx.jucx.ucs.UcsConstants</javahClassName>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@

package org.openucx.jucx;

import org.openucx.jucx.ucp.UcpRequest;

/**
* Callback wrapper to notify successful or failure events from JNI.
*/

public class UcxCallback {
public void onSuccess(UcxRequest request) {}
public void onSuccess(UcpRequest request) {}

public void onError(int ucsStatus, String errorMsg) {
throw new UcxException(errorMsg);
Expand Down
22 changes: 0 additions & 22 deletions bindings/java/src/main/java/org/openucx/jucx/UcxRequest.java

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
package org.openucx.jucx.examples;

import org.openucx.jucx.UcxCallback;
import org.openucx.jucx.UcxRequest;
import org.openucx.jucx.ucp.UcpRequest;
import org.openucx.jucx.UcxUtils;
import org.openucx.jucx.ucp.*;

Expand All @@ -30,7 +30,7 @@ public static void main(String[] args) throws Exception {
resources.push(listener);

ByteBuffer recvBuffer = ByteBuffer.allocateDirect(4096);
UcxRequest recvRequest = worker.recvTaggedNonBlocking(recvBuffer, null);
UcpRequest recvRequest = worker.recvTaggedNonBlocking(recvBuffer, null);

System.out.println("Waiting for connections on " + sockaddr + " ...");

Expand Down Expand Up @@ -66,13 +66,13 @@ public static void main(String[] args) throws Exception {
(int)Math.min(Integer.MAX_VALUE, totalSize));
for (int i = 0; i < numIterations; i++) {
final int iterNum = i;
UcxRequest getRequest = endpoint.getNonBlocking(remoteAddress, remoteKey,
UcpRequest getRequest = endpoint.getNonBlocking(remoteAddress, remoteKey,
recvMemory.getAddress(), totalSize,
new UcxCallback() {
long startTime = System.nanoTime();

@Override
public void onSuccess(UcxRequest request) {
public void onSuccess(UcpRequest request) {
long finishTime = System.nanoTime();
data.clear();
assert data.hashCode() == remoteHashCode;
Expand All @@ -90,7 +90,7 @@ public void onSuccess(UcxRequest request) {

ByteBuffer sendBuffer = ByteBuffer.allocateDirect(100);
sendBuffer.asCharBuffer().put("DONE");
UcxRequest sent = endpoint.sendTaggedNonBlocking(sendBuffer, null);
UcpRequest sent = endpoint.sendTaggedNonBlocking(sendBuffer, null);

while (!sent.isCompleted()) {
worker.progress();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
package org.openucx.jucx.examples;

import org.openucx.jucx.UcxCallback;
import org.openucx.jucx.UcxRequest;
import org.openucx.jucx.ucp.UcpRequest;
import org.openucx.jucx.UcxUtils;
import org.openucx.jucx.ucp.UcpEndpoint;
import org.openucx.jucx.ucp.UcpEndpointParams;
Expand Down Expand Up @@ -51,10 +51,10 @@ public static void main(String[] args) throws Exception {
endpoint.sendTaggedNonBlocking(sendData, null);

ByteBuffer recvBuffer = ByteBuffer.allocateDirect(4096);
UcxRequest recvRequest = worker.recvTaggedNonBlocking(recvBuffer,
UcpRequest recvRequest = worker.recvTaggedNonBlocking(recvBuffer,
new UcxCallback() {
@Override
public void onSuccess(UcxRequest request) {
public void onSuccess(UcpRequest request) {
System.out.println("Received a message:");
System.out.println(recvBuffer.asCharBuffer().toString());
}
Expand Down
26 changes: 13 additions & 13 deletions bindings/java/src/main/java/org/openucx/jucx/ucp/UcpEndpoint.java
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ private void checkRemoteAccessParams(ByteBuffer buf, UcpRemoteKey remoteKey) {
* guarantee re-usability of the source {@code data} buffer.
* {@code callback} is invoked on completion of this operation.
*/
public UcxRequest putNonBlocking(ByteBuffer src, long remoteAddress, UcpRemoteKey remoteKey,
public UcpRequest putNonBlocking(ByteBuffer src, long remoteAddress, UcpRemoteKey remoteKey,
UcxCallback callback) {

checkRemoteAccessParams(src, remoteKey);
Expand All @@ -60,7 +60,7 @@ public UcxRequest putNonBlocking(ByteBuffer src, long remoteAddress, UcpRemoteKe
remoteKey, callback);
}

public UcxRequest putNonBlocking(long localAddress, long size,
public UcpRequest putNonBlocking(long localAddress, long size,
long remoteAddress, UcpRemoteKey remoteKey,
UcxCallback callback) {

Expand Down Expand Up @@ -104,9 +104,9 @@ public void putNonBlockingImplicit(long localAddress, long size,
* not guarantee that remote data is loaded and stored under the local {@code dst} buffer
* starting of it's {@code dst.position()} and size {@code dst.remaining()}.
* {@code callback} is invoked on completion of this operation.
* @return {@link UcxRequest} object that can be monitored for completion.
* @return {@link UcpRequest} object that can be monitored for completion.
*/
public UcxRequest getNonBlocking(long remoteAddress, UcpRemoteKey remoteKey,
public UcpRequest getNonBlocking(long remoteAddress, UcpRemoteKey remoteKey,
ByteBuffer dst, UcxCallback callback) {

checkRemoteAccessParams(dst, remoteKey);
Expand All @@ -115,7 +115,7 @@ public UcxRequest getNonBlocking(long remoteAddress, UcpRemoteKey remoteKey,
dst.remaining(), callback);
}

public UcxRequest getNonBlocking(long remoteAddress, UcpRemoteKey remoteKey,
public UcpRequest getNonBlocking(long remoteAddress, UcpRemoteKey remoteKey,
long localAddress, long size, UcxCallback callback) {

return getNonBlockingNative(getNativeId(), remoteAddress, remoteKey.getNativeId(),
Expand Down Expand Up @@ -165,15 +165,15 @@ public void getNonBlockingImplicit(long remoteAddress, UcpRemoteKey remoteKey,
* The send operation is considered completed when it is safe to reuse the source
* {@code data} buffer. {@code callback} is invoked on completion of this operation.
*/
public UcxRequest sendTaggedNonBlocking(ByteBuffer sendBuffer, long tag, UcxCallback callback) {
public UcpRequest sendTaggedNonBlocking(ByteBuffer sendBuffer, long tag, UcxCallback callback) {
if (!sendBuffer.isDirect()) {
throw new UcxException("Send buffer must be direct.");
}
return sendTaggedNonBlocking(UcxUtils.getAddress(sendBuffer),
sendBuffer.remaining(), tag, callback);
}

public UcxRequest sendTaggedNonBlocking(long localAddress, long size,
public UcpRequest sendTaggedNonBlocking(long localAddress, long size,
long tag, UcxCallback callback) {

return sendTaggedNonBlockingNative(getNativeId(),
Expand All @@ -185,7 +185,7 @@ public UcxRequest sendTaggedNonBlocking(long localAddress, long size,
* Non blocking send operation. Invokes
* {@link UcpEndpoint#sendTaggedNonBlocking(ByteBuffer, long, UcxCallback)} with default 0 tag.
*/
public UcxRequest sendTaggedNonBlocking(ByteBuffer sendBuffer, UcxCallback callback) {
public UcpRequest sendTaggedNonBlocking(ByteBuffer sendBuffer, UcxCallback callback) {
return sendTaggedNonBlocking(sendBuffer, 0, callback);
}

Expand All @@ -194,7 +194,7 @@ public UcxRequest sendTaggedNonBlocking(ByteBuffer sendBuffer, UcxCallback callb
* All the AMO and RMA operations issued on this endpoint prior to this call
* are completed both at the origin and at the target.
*/
public UcxRequest flushNonBlocking(UcxCallback callback) {
public UcpRequest flushNonBlocking(UcxCallback callback) {
return flushNonBlockingNative(getNativeId(), callback);
}

Expand All @@ -204,25 +204,25 @@ public UcxRequest flushNonBlocking(UcxCallback callback) {

private static native UcpRemoteKey unpackRemoteKey(long epId, long rkeyAddress);

private static native UcxRequest putNonBlockingNative(long enpointId, long localAddress,
private static native UcpRequest putNonBlockingNative(long enpointId, long localAddress,
long size, long remoteAddr,
long ucpRkeyId, UcxCallback callback);

private static native void putNonBlockingImplicitNative(long enpointId, long localAddress,
long size, long remoteAddr,
long ucpRkeyId);

private static native UcxRequest getNonBlockingNative(long enpointId, long remoteAddress,
private static native UcpRequest getNonBlockingNative(long enpointId, long remoteAddress,
long ucpRkeyId, long localAddress,
long size, UcxCallback callback);

private static native void getNonBlockingImplicitNative(long enpointId, long remoteAddress,
long ucpRkeyId, long localAddress,
long size);

private static native UcxRequest sendTaggedNonBlockingNative(long enpointId, long localAddress,
private static native UcpRequest sendTaggedNonBlockingNative(long enpointId, long localAddress,
long size, long tag,
UcxCallback callback);

private static native UcxRequest flushNonBlockingNative(long enpointId, UcxCallback callback);
private static native UcpRequest flushNonBlockingNative(long enpointId, UcxCallback callback);
}
45 changes: 45 additions & 0 deletions bindings/java/src/main/java/org/openucx/jucx/ucp/UcpRequest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
* See file LICENSE for terms.
*/

package org.openucx.jucx.ucp;

import org.openucx.jucx.UcxNativeStruct;

import java.io.Closeable;

/**
* Request object, that returns by ucp operations (GET, PUT, SEND, etc.).
* Call {@link UcpRequest#isCompleted()} to monitor completion of request.
*/
public class UcpRequest extends UcxNativeStruct implements Closeable {

private UcpRequest(long nativeId) {
setNativeId(nativeId);
}

/**
* @return whether this request is completed.
*/
public boolean isCompleted() {
return (getNativeId() == null) || isCompletedNative(getNativeId());
}

/**
* This routine releases the non-blocking request back to the library, regardless
* of its current state. Communications operations associated with this request
* will make progress internally, however no further notifications or callbacks
* will be invoked for this request.
*/
@Override
public void close() {
if (getNativeId() != null) {
closeRequestNative(getNativeId());
}
}

private static native boolean isCompletedNative(long ucpRequest);

private static native void closeRequestNative(long ucpRequest);
}
26 changes: 20 additions & 6 deletions bindings/java/src/main/java/org/openucx/jucx/ucp/UcpWorker.java
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ public int progress() {
* this worker. All the AMO and RMA operations issued on this worker prior to this call
* are completed both at the origin and at the target when this call returns.
*/
public UcxRequest flushNonBlocking(UcxCallback callback) {
public UcpRequest flushNonBlocking(UcxCallback callback) {
return flushNonBlockingNative(getNativeId(), callback);
}

Expand Down Expand Up @@ -114,7 +114,7 @@ public void signal() {
* @param tagMask - bit mask that indicates the bits that are used for the matching of the
* incoming tag against the expected tag.
*/
public UcxRequest recvTaggedNonBlocking(ByteBuffer recvBuffer, long tag, long tagMask,
public UcpRequest recvTaggedNonBlocking(ByteBuffer recvBuffer, long tag, long tagMask,
UcxCallback callback) {
if (!recvBuffer.isDirect()) {
throw new UcxException("Recv buffer must be direct.");
Expand All @@ -123,7 +123,7 @@ public UcxRequest recvTaggedNonBlocking(ByteBuffer recvBuffer, long tag, long ta
recvBuffer.remaining(), tag, tagMask, callback);
}

public UcxRequest recvTaggedNonBlocking(long localAddress, long size, long tag, long tagMask,
public UcpRequest recvTaggedNonBlocking(long localAddress, long size, long tag, long tagMask,
UcxCallback callback) {
return recvTaggedNonBlockingNative(getNativeId(), localAddress, size,
tag, tagMask, callback);
Expand All @@ -134,10 +134,22 @@ public UcxRequest recvTaggedNonBlocking(long localAddress, long size, long tag,
* {@link UcpWorker#recvTaggedNonBlocking(ByteBuffer, long, long, UcxCallback)}
* with default tag=0 and tagMask=0.
*/
public UcxRequest recvTaggedNonBlocking(ByteBuffer recvBuffer, UcxCallback callback) {
public UcpRequest recvTaggedNonBlocking(ByteBuffer recvBuffer, UcxCallback callback) {
return recvTaggedNonBlocking(recvBuffer, 0, 0, callback);
}

/**
* This routine tries to cancels an outstanding communication request. After
* calling this routine, the request will be in completed or canceled (but
* not both) state regardless of the status of the target endpoint associated
* with the communication request. If the request is completed successfully,
* the "send" or the "receive" completion callbacks (based on the type of the request) will be
* called with the status argument of the callback set to UCS_OK, and in a
* case it is canceled the status argument is set to UCS_ERR_CANCELED.
*/
public void cancelRequest(UcpRequest request) {
cancelRequestNative(getNativeId(), request.getNativeId());
}

/**
* This routine returns the address of the worker object. This address can be
Expand Down Expand Up @@ -167,13 +179,15 @@ public ByteBuffer getAddress() {

private static native int progressWorkerNative(long workerId);

private static native UcxRequest flushNonBlockingNative(long workerId, UcxCallback callback);
private static native UcpRequest flushNonBlockingNative(long workerId, UcxCallback callback);

private static native void waitWorkerNative(long workerId);

private static native void signalWorkerNative(long workerId);

private static native UcxRequest recvTaggedNonBlockingNative(long workerId, long localAddress,
private static native UcpRequest recvTaggedNonBlockingNative(long workerId, long localAddress,
long size, long tag, long tagMask,
UcxCallback callback);

private static native void cancelRequestNative(long workerId, long requestId);
}
2 changes: 2 additions & 0 deletions bindings/java/src/main/native/Makefile.am
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ JUCX_GENERATED_H_FILES = org_openucx_jucx_ucp_UcpConstants.h \
org_openucx_jucx_ucp_UcpEndpoint.h \
org_openucx_jucx_ucp_UcpListener.h \
org_openucx_jucx_ucp_UcpMemory.h \
org_openucx_jucx_ucp_UcpRequest.h \
org_openucx_jucx_ucp_UcpRemoteKey.h \
org_openucx_jucx_ucp_UcpWorker.h \
org_openucx_jucx_ucs_UcsConstants_ThreadMode.h \
Expand Down Expand Up @@ -47,6 +48,7 @@ libjucx_la_SOURCES = context.cc \
jucx_common_def.cc \
listener.cc \
memory.cc \
request.cc \
ucp_constants.cc \
ucs_constants.cc \
worker.cc
Expand Down
15 changes: 8 additions & 7 deletions bindings/java/src/main/native/jucx_common_def.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ extern "C" {

static JavaVM *jvm_global;
static jclass jucx_request_cls;
static jfieldID completed_field;
static jfieldID native_id_field;
static jmethodID on_success;
static jmethodID jucx_request_constructor;

Expand All @@ -29,13 +29,13 @@ extern "C" JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *jvm, void* reserved) {
return JNI_ERR;
}

jclass jucx_request_cls_local = env->FindClass("org/openucx/jucx/UcxRequest");
jclass jucx_request_cls_local = env->FindClass("org/openucx/jucx/ucp/UcpRequest");
jucx_request_cls = (jclass) env->NewGlobalRef(jucx_request_cls_local);
jclass jucx_callback_cls = env->FindClass("org/openucx/jucx/UcxCallback");
completed_field = env->GetFieldID(jucx_request_cls, "completed", "Z");
native_id_field = env->GetFieldID(jucx_request_cls, "nativeId", "Ljava/lang/Long;");
on_success = env->GetMethodID(jucx_callback_cls, "onSuccess",
"(Lorg/openucx/jucx/UcxRequest;)V");
jucx_request_constructor = env->GetMethodID(jucx_request_cls, "<init>", "()V");
"(Lorg/openucx/jucx/ucp/UcpRequest;)V");
jucx_request_constructor = env->GetMethodID(jucx_request_cls, "<init>", "(J)V");
return JNI_VERSION_1_1;
}

Expand Down Expand Up @@ -155,7 +155,7 @@ JNIEnv* get_jni_env()

static inline void set_jucx_request_completed(JNIEnv *env, jobject jucx_request)
{
env->SetBooleanField(jucx_request, completed_field, true);
env->SetObjectField(jucx_request, native_id_field, NULL);
}

static inline void call_on_success(jobject callback, jobject request)
Expand Down Expand Up @@ -219,7 +219,8 @@ void recv_callback(void *request, ucs_status_t status, ucp_tag_recv_info_t *info
UCS_PROFILE_FUNC(jobject, process_request, (request, callback), void *request, jobject callback)
{
JNIEnv *env = get_jni_env();
jobject jucx_request = env->NewObject(jucx_request_cls, jucx_request_constructor);
jobject jucx_request = env->NewObject(jucx_request_cls, jucx_request_constructor,
(native_ptr)request);

if (UCS_PTR_IS_PTR(request)) {
struct jucx_context *ctx = (struct jucx_context *)request;
Expand Down
Loading

0 comments on commit b991d31

Please sign in to comment.