From 04fb75c066d2eafae6a3338d48f4770e3e5d9a67 Mon Sep 17 00:00:00 2001 From: Peter Rudenko Date: Mon, 2 Dec 2019 18:35:49 +0200 Subject: [PATCH] JUCX: ucp_request functionality. --- bindings/java/pom.xml.in | 1 + .../java/org/openucx/jucx/UcxCallback.java | 4 +- .../java/org/openucx/jucx/UcxRequest.java | 22 --------- .../examples/UcxReadBWBenchmarkReceiver.java | 10 ++--- .../examples/UcxReadBWBenchmarkSender.java | 6 +-- .../org/openucx/jucx/ucp/UcpEndpoint.java | 26 +++++------ .../java/org/openucx/jucx/ucp/UcpRequest.java | 45 +++++++++++++++++++ .../java/org/openucx/jucx/ucp/UcpWorker.java | 26 ++++++++--- bindings/java/src/main/native/Makefile.am | 2 + .../java/src/main/native/jucx_common_def.cc | 15 ++++--- bindings/java/src/main/native/request.cc | 23 ++++++++++ bindings/java/src/main/native/worker.cc | 8 ++++ .../org/openucx/jucx/UcpEndpointTest.java | 26 +++++------ .../java/org/openucx/jucx/UcpRequestTest.java | 31 +++++++++++++ .../java/org/openucx/jucx/UcpWorkerTest.java | 4 +- 15 files changed, 177 insertions(+), 72 deletions(-) delete mode 100644 bindings/java/src/main/java/org/openucx/jucx/UcxRequest.java create mode 100644 bindings/java/src/main/java/org/openucx/jucx/ucp/UcpRequest.java create mode 100644 bindings/java/src/main/native/request.cc create mode 100644 bindings/java/src/test/java/org/openucx/jucx/UcpRequestTest.java diff --git a/bindings/java/pom.xml.in b/bindings/java/pom.xml.in index cd497f7dbd6..1052d5aaac7 100644 --- a/bindings/java/pom.xml.in +++ b/bindings/java/pom.xml.in @@ -279,6 +279,7 @@ org.openucx.jucx.ucp.UcpEndpoint org.openucx.jucx.ucp.UcpListener org.openucx.jucx.ucp.UcpMemory + org.openucx.jucx.ucp.UcpRequest org.openucx.jucx.ucp.UcpRemoteKey org.openucx.jucx.ucp.UcpWorker org.openucx.jucx.ucs.UcsConstants diff --git a/bindings/java/src/main/java/org/openucx/jucx/UcxCallback.java b/bindings/java/src/main/java/org/openucx/jucx/UcxCallback.java index 60262f16a2e..a75cb766e8d 100644 --- a/bindings/java/src/main/java/org/openucx/jucx/UcxCallback.java +++ b/bindings/java/src/main/java/org/openucx/jucx/UcxCallback.java @@ -5,12 +5,14 @@ package org.openucx.jucx; +import org.openucx.jucx.ucp.UcpRequest; + /** * Callback wrapper to notify successful or failure events from JNI. */ public class UcxCallback { - public void onSuccess(UcxRequest request) {} + public void onSuccess(UcpRequest request) {} public void onError(int ucsStatus, String errorMsg) { throw new UcxException(errorMsg); diff --git a/bindings/java/src/main/java/org/openucx/jucx/UcxRequest.java b/bindings/java/src/main/java/org/openucx/jucx/UcxRequest.java deleted file mode 100644 index 623c13795ef..00000000000 --- a/bindings/java/src/main/java/org/openucx/jucx/UcxRequest.java +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. - * See file LICENSE for terms. - */ - -package org.openucx.jucx; - -/** - * Request object, that returns by ucp operations (GET, PUT, SEND, etc.). - * Call {@link UcxRequest#isCompleted()} to monitor completion of request. - */ -public class UcxRequest { - - private boolean completed; - - /** - * @return whether this request is completed. - */ - public boolean isCompleted() { - return completed; - } -} diff --git a/bindings/java/src/main/java/org/openucx/jucx/examples/UcxReadBWBenchmarkReceiver.java b/bindings/java/src/main/java/org/openucx/jucx/examples/UcxReadBWBenchmarkReceiver.java index 5933471c1f9..8bdbb2e57c9 100644 --- a/bindings/java/src/main/java/org/openucx/jucx/examples/UcxReadBWBenchmarkReceiver.java +++ b/bindings/java/src/main/java/org/openucx/jucx/examples/UcxReadBWBenchmarkReceiver.java @@ -6,7 +6,7 @@ package org.openucx.jucx.examples; import org.openucx.jucx.UcxCallback; -import org.openucx.jucx.UcxRequest; +import org.openucx.jucx.ucp.UcpRequest; import org.openucx.jucx.UcxUtils; import org.openucx.jucx.ucp.*; @@ -30,7 +30,7 @@ public static void main(String[] args) throws Exception { resources.push(listener); ByteBuffer recvBuffer = ByteBuffer.allocateDirect(4096); - UcxRequest recvRequest = worker.recvTaggedNonBlocking(recvBuffer, null); + UcpRequest recvRequest = worker.recvTaggedNonBlocking(recvBuffer, null); System.out.println("Waiting for connections on " + sockaddr + " ..."); @@ -66,13 +66,13 @@ public static void main(String[] args) throws Exception { (int)Math.min(Integer.MAX_VALUE, totalSize)); for (int i = 0; i < numIterations; i++) { final int iterNum = i; - UcxRequest getRequest = endpoint.getNonBlocking(remoteAddress, remoteKey, + UcpRequest getRequest = endpoint.getNonBlocking(remoteAddress, remoteKey, recvMemory.getAddress(), totalSize, new UcxCallback() { long startTime = System.nanoTime(); @Override - public void onSuccess(UcxRequest request) { + public void onSuccess(UcpRequest request) { long finishTime = System.nanoTime(); data.clear(); assert data.hashCode() == remoteHashCode; @@ -90,7 +90,7 @@ public void onSuccess(UcxRequest request) { ByteBuffer sendBuffer = ByteBuffer.allocateDirect(100); sendBuffer.asCharBuffer().put("DONE"); - UcxRequest sent = endpoint.sendTaggedNonBlocking(sendBuffer, null); + UcpRequest sent = endpoint.sendTaggedNonBlocking(sendBuffer, null); while (!sent.isCompleted()) { worker.progress(); diff --git a/bindings/java/src/main/java/org/openucx/jucx/examples/UcxReadBWBenchmarkSender.java b/bindings/java/src/main/java/org/openucx/jucx/examples/UcxReadBWBenchmarkSender.java index 2b89945f3c9..fe2bbbdb916 100644 --- a/bindings/java/src/main/java/org/openucx/jucx/examples/UcxReadBWBenchmarkSender.java +++ b/bindings/java/src/main/java/org/openucx/jucx/examples/UcxReadBWBenchmarkSender.java @@ -6,7 +6,7 @@ package org.openucx.jucx.examples; import org.openucx.jucx.UcxCallback; -import org.openucx.jucx.UcxRequest; +import org.openucx.jucx.ucp.UcpRequest; import org.openucx.jucx.UcxUtils; import org.openucx.jucx.ucp.UcpEndpoint; import org.openucx.jucx.ucp.UcpEndpointParams; @@ -51,10 +51,10 @@ public static void main(String[] args) throws Exception { endpoint.sendTaggedNonBlocking(sendData, null); ByteBuffer recvBuffer = ByteBuffer.allocateDirect(4096); - UcxRequest recvRequest = worker.recvTaggedNonBlocking(recvBuffer, + UcpRequest recvRequest = worker.recvTaggedNonBlocking(recvBuffer, new UcxCallback() { @Override - public void onSuccess(UcxRequest request) { + public void onSuccess(UcpRequest request) { System.out.println("Received a message:"); System.out.println(recvBuffer.asCharBuffer().toString()); } diff --git a/bindings/java/src/main/java/org/openucx/jucx/ucp/UcpEndpoint.java b/bindings/java/src/main/java/org/openucx/jucx/ucp/UcpEndpoint.java index ae68176c3e9..6fb41534c61 100644 --- a/bindings/java/src/main/java/org/openucx/jucx/ucp/UcpEndpoint.java +++ b/bindings/java/src/main/java/org/openucx/jucx/ucp/UcpEndpoint.java @@ -51,7 +51,7 @@ private void checkRemoteAccessParams(ByteBuffer buf, UcpRemoteKey remoteKey) { * guarantee re-usability of the source {@code data} buffer. * {@code callback} is invoked on completion of this operation. */ - public UcxRequest putNonBlocking(ByteBuffer src, long remoteAddress, UcpRemoteKey remoteKey, + public UcpRequest putNonBlocking(ByteBuffer src, long remoteAddress, UcpRemoteKey remoteKey, UcxCallback callback) { checkRemoteAccessParams(src, remoteKey); @@ -60,7 +60,7 @@ public UcxRequest putNonBlocking(ByteBuffer src, long remoteAddress, UcpRemoteKe remoteKey, callback); } - public UcxRequest putNonBlocking(long localAddress, long size, + public UcpRequest putNonBlocking(long localAddress, long size, long remoteAddress, UcpRemoteKey remoteKey, UcxCallback callback) { @@ -104,9 +104,9 @@ public void putNonBlockingImplicit(long localAddress, long size, * not guarantee that remote data is loaded and stored under the local {@code dst} buffer * starting of it's {@code dst.position()} and size {@code dst.remaining()}. * {@code callback} is invoked on completion of this operation. - * @return {@link UcxRequest} object that can be monitored for completion. + * @return {@link UcpRequest} object that can be monitored for completion. */ - public UcxRequest getNonBlocking(long remoteAddress, UcpRemoteKey remoteKey, + public UcpRequest getNonBlocking(long remoteAddress, UcpRemoteKey remoteKey, ByteBuffer dst, UcxCallback callback) { checkRemoteAccessParams(dst, remoteKey); @@ -115,7 +115,7 @@ public UcxRequest getNonBlocking(long remoteAddress, UcpRemoteKey remoteKey, dst.remaining(), callback); } - public UcxRequest getNonBlocking(long remoteAddress, UcpRemoteKey remoteKey, + public UcpRequest getNonBlocking(long remoteAddress, UcpRemoteKey remoteKey, long localAddress, long size, UcxCallback callback) { return getNonBlockingNative(getNativeId(), remoteAddress, remoteKey.getNativeId(), @@ -165,7 +165,7 @@ public void getNonBlockingImplicit(long remoteAddress, UcpRemoteKey remoteKey, * The send operation is considered completed when it is safe to reuse the source * {@code data} buffer. {@code callback} is invoked on completion of this operation. */ - public UcxRequest sendTaggedNonBlocking(ByteBuffer sendBuffer, long tag, UcxCallback callback) { + public UcpRequest sendTaggedNonBlocking(ByteBuffer sendBuffer, long tag, UcxCallback callback) { if (!sendBuffer.isDirect()) { throw new UcxException("Send buffer must be direct."); } @@ -173,7 +173,7 @@ public UcxRequest sendTaggedNonBlocking(ByteBuffer sendBuffer, long tag, UcxCall sendBuffer.remaining(), tag, callback); } - public UcxRequest sendTaggedNonBlocking(long localAddress, long size, + public UcpRequest sendTaggedNonBlocking(long localAddress, long size, long tag, UcxCallback callback) { return sendTaggedNonBlockingNative(getNativeId(), @@ -185,7 +185,7 @@ public UcxRequest sendTaggedNonBlocking(long localAddress, long size, * Non blocking send operation. Invokes * {@link UcpEndpoint#sendTaggedNonBlocking(ByteBuffer, long, UcxCallback)} with default 0 tag. */ - public UcxRequest sendTaggedNonBlocking(ByteBuffer sendBuffer, UcxCallback callback) { + public UcpRequest sendTaggedNonBlocking(ByteBuffer sendBuffer, UcxCallback callback) { return sendTaggedNonBlocking(sendBuffer, 0, callback); } @@ -194,7 +194,7 @@ public UcxRequest sendTaggedNonBlocking(ByteBuffer sendBuffer, UcxCallback callb * All the AMO and RMA operations issued on this endpoint prior to this call * are completed both at the origin and at the target. */ - public UcxRequest flushNonBlocking(UcxCallback callback) { + public UcpRequest flushNonBlocking(UcxCallback callback) { return flushNonBlockingNative(getNativeId(), callback); } @@ -204,7 +204,7 @@ public UcxRequest flushNonBlocking(UcxCallback callback) { private static native UcpRemoteKey unpackRemoteKey(long epId, long rkeyAddress); - private static native UcxRequest putNonBlockingNative(long enpointId, long localAddress, + private static native UcpRequest putNonBlockingNative(long enpointId, long localAddress, long size, long remoteAddr, long ucpRkeyId, UcxCallback callback); @@ -212,7 +212,7 @@ private static native void putNonBlockingImplicitNative(long enpointId, long loc long size, long remoteAddr, long ucpRkeyId); - private static native UcxRequest getNonBlockingNative(long enpointId, long remoteAddress, + private static native UcpRequest getNonBlockingNative(long enpointId, long remoteAddress, long ucpRkeyId, long localAddress, long size, UcxCallback callback); @@ -220,9 +220,9 @@ private static native void getNonBlockingImplicitNative(long enpointId, long rem long ucpRkeyId, long localAddress, long size); - private static native UcxRequest sendTaggedNonBlockingNative(long enpointId, long localAddress, + private static native UcpRequest sendTaggedNonBlockingNative(long enpointId, long localAddress, long size, long tag, UcxCallback callback); - private static native UcxRequest flushNonBlockingNative(long enpointId, UcxCallback callback); + private static native UcpRequest flushNonBlockingNative(long enpointId, UcxCallback callback); } diff --git a/bindings/java/src/main/java/org/openucx/jucx/ucp/UcpRequest.java b/bindings/java/src/main/java/org/openucx/jucx/ucp/UcpRequest.java new file mode 100644 index 00000000000..831a33599b4 --- /dev/null +++ b/bindings/java/src/main/java/org/openucx/jucx/ucp/UcpRequest.java @@ -0,0 +1,45 @@ +/* + * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. + * See file LICENSE for terms. + */ + +package org.openucx.jucx.ucp; + +import org.openucx.jucx.UcxNativeStruct; + +import java.io.Closeable; + +/** + * Request object, that returns by ucp operations (GET, PUT, SEND, etc.). + * Call {@link UcpRequest#isCompleted()} to monitor completion of request. + */ +public class UcpRequest extends UcxNativeStruct implements Closeable { + + private UcpRequest(long nativeId) { + setNativeId(nativeId); + } + + /** + * @return whether this request is completed. + */ + public boolean isCompleted() { + return (getNativeId() == null) || isCompletedNative(getNativeId()); + } + + /** + * This routine releases the non-blocking request back to the library, regardless + * of its current state. Communications operations associated with this request + * will make progress internally, however no further notifications or callbacks + * will be invoked for this request. + */ + @Override + public void close() { + if (getNativeId() != null) { + closeRequestNative(getNativeId()); + } + } + + private static native boolean isCompletedNative(long ucpRequest); + + private static native void closeRequestNative(long ucpRequest); +} diff --git a/bindings/java/src/main/java/org/openucx/jucx/ucp/UcpWorker.java b/bindings/java/src/main/java/org/openucx/jucx/ucp/UcpWorker.java index c612722cd63..0cc279e12f4 100644 --- a/bindings/java/src/main/java/org/openucx/jucx/ucp/UcpWorker.java +++ b/bindings/java/src/main/java/org/openucx/jucx/ucp/UcpWorker.java @@ -68,7 +68,7 @@ public int progress() { * this worker. All the AMO and RMA operations issued on this worker prior to this call * are completed both at the origin and at the target when this call returns. */ - public UcxRequest flushNonBlocking(UcxCallback callback) { + public UcpRequest flushNonBlocking(UcxCallback callback) { return flushNonBlockingNative(getNativeId(), callback); } @@ -114,7 +114,7 @@ public void signal() { * @param tagMask - bit mask that indicates the bits that are used for the matching of the * incoming tag against the expected tag. */ - public UcxRequest recvTaggedNonBlocking(ByteBuffer recvBuffer, long tag, long tagMask, + public UcpRequest recvTaggedNonBlocking(ByteBuffer recvBuffer, long tag, long tagMask, UcxCallback callback) { if (!recvBuffer.isDirect()) { throw new UcxException("Recv buffer must be direct."); @@ -123,7 +123,7 @@ public UcxRequest recvTaggedNonBlocking(ByteBuffer recvBuffer, long tag, long ta recvBuffer.remaining(), tag, tagMask, callback); } - public UcxRequest recvTaggedNonBlocking(long localAddress, long size, long tag, long tagMask, + public UcpRequest recvTaggedNonBlocking(long localAddress, long size, long tag, long tagMask, UcxCallback callback) { return recvTaggedNonBlockingNative(getNativeId(), localAddress, size, tag, tagMask, callback); @@ -134,10 +134,22 @@ public UcxRequest recvTaggedNonBlocking(long localAddress, long size, long tag, * {@link UcpWorker#recvTaggedNonBlocking(ByteBuffer, long, long, UcxCallback)} * with default tag=0 and tagMask=0. */ - public UcxRequest recvTaggedNonBlocking(ByteBuffer recvBuffer, UcxCallback callback) { + public UcpRequest recvTaggedNonBlocking(ByteBuffer recvBuffer, UcxCallback callback) { return recvTaggedNonBlocking(recvBuffer, 0, 0, callback); } + /** + * This routine tries to cancels an outstanding communication request. After + * calling this routine, the request will be in completed or canceled (but + * not both) state regardless of the status of the target endpoint associated + * with the communication request. If the request is completed successfully, + * the "send" or the "receive" completion callbacks (based on the type of the request) will be + * called with the status argument of the callback set to UCS_OK, and in a + * case it is canceled the status argument is set to UCS_ERR_CANCELED. + */ + public void cancelRequest(UcpRequest request) { + cancelRequestNative(getNativeId(), request.getNativeId()); + } /** * This routine returns the address of the worker object. This address can be @@ -167,13 +179,15 @@ public ByteBuffer getAddress() { private static native int progressWorkerNative(long workerId); - private static native UcxRequest flushNonBlockingNative(long workerId, UcxCallback callback); + private static native UcpRequest flushNonBlockingNative(long workerId, UcxCallback callback); private static native void waitWorkerNative(long workerId); private static native void signalWorkerNative(long workerId); - private static native UcxRequest recvTaggedNonBlockingNative(long workerId, long localAddress, + private static native UcpRequest recvTaggedNonBlockingNative(long workerId, long localAddress, long size, long tag, long tagMask, UcxCallback callback); + + private static native void cancelRequestNative(long workerId, long requestId); } diff --git a/bindings/java/src/main/native/Makefile.am b/bindings/java/src/main/native/Makefile.am index 8229540cab4..09b0ffa9e70 100644 --- a/bindings/java/src/main/native/Makefile.am +++ b/bindings/java/src/main/native/Makefile.am @@ -16,6 +16,7 @@ JUCX_GENERATED_H_FILES = org_openucx_jucx_ucp_UcpConstants.h \ org_openucx_jucx_ucp_UcpEndpoint.h \ org_openucx_jucx_ucp_UcpListener.h \ org_openucx_jucx_ucp_UcpMemory.h \ + org_openucx_jucx_ucp_UcpRequest.h \ org_openucx_jucx_ucp_UcpRemoteKey.h \ org_openucx_jucx_ucp_UcpWorker.h \ org_openucx_jucx_ucs_UcsConstants_ThreadMode.h \ @@ -47,6 +48,7 @@ libjucx_la_SOURCES = context.cc \ jucx_common_def.cc \ listener.cc \ memory.cc \ + request.cc \ ucp_constants.cc \ ucs_constants.cc \ worker.cc diff --git a/bindings/java/src/main/native/jucx_common_def.cc b/bindings/java/src/main/native/jucx_common_def.cc index 421f40e8529..74d6f9d837b 100644 --- a/bindings/java/src/main/native/jucx_common_def.cc +++ b/bindings/java/src/main/native/jucx_common_def.cc @@ -17,7 +17,7 @@ extern "C" { static JavaVM *jvm_global; static jclass jucx_request_cls; -static jfieldID completed_field; +static jfieldID native_id_field; static jmethodID on_success; static jmethodID jucx_request_constructor; @@ -29,13 +29,13 @@ extern "C" JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *jvm, void* reserved) { return JNI_ERR; } - jclass jucx_request_cls_local = env->FindClass("org/openucx/jucx/UcxRequest"); + jclass jucx_request_cls_local = env->FindClass("org/openucx/jucx/ucp/UcpRequest"); jucx_request_cls = (jclass) env->NewGlobalRef(jucx_request_cls_local); jclass jucx_callback_cls = env->FindClass("org/openucx/jucx/UcxCallback"); - completed_field = env->GetFieldID(jucx_request_cls, "completed", "Z"); + native_id_field = env->GetFieldID(jucx_request_cls, "nativeId", "Ljava/lang/Long;"); on_success = env->GetMethodID(jucx_callback_cls, "onSuccess", - "(Lorg/openucx/jucx/UcxRequest;)V"); - jucx_request_constructor = env->GetMethodID(jucx_request_cls, "", "()V"); + "(Lorg/openucx/jucx/ucp/UcpRequest;)V"); + jucx_request_constructor = env->GetMethodID(jucx_request_cls, "", "(J)V"); return JNI_VERSION_1_1; } @@ -155,7 +155,7 @@ JNIEnv* get_jni_env() static inline void set_jucx_request_completed(JNIEnv *env, jobject jucx_request) { - env->SetBooleanField(jucx_request, completed_field, true); + env->SetObjectField(jucx_request, native_id_field, NULL); } static inline void call_on_success(jobject callback, jobject request) @@ -219,7 +219,8 @@ void recv_callback(void *request, ucs_status_t status, ucp_tag_recv_info_t *info UCS_PROFILE_FUNC(jobject, process_request, (request, callback), void *request, jobject callback) { JNIEnv *env = get_jni_env(); - jobject jucx_request = env->NewObject(jucx_request_cls, jucx_request_constructor); + jobject jucx_request = env->NewObject(jucx_request_cls, jucx_request_constructor, + (native_ptr)request); if (UCS_PTR_IS_PTR(request)) { struct jucx_context *ctx = (struct jucx_context *)request; diff --git a/bindings/java/src/main/native/request.cc b/bindings/java/src/main/native/request.cc new file mode 100644 index 00000000000..d65619b922e --- /dev/null +++ b/bindings/java/src/main/native/request.cc @@ -0,0 +1,23 @@ +/* + * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. + * See file LICENSE for terms. + */ + +#include "org_openucx_jucx_ucp_UcpRequest.h" + +#include +#include + +JNIEXPORT jboolean JNICALL +Java_org_openucx_jucx_ucp_UcpRequest_isCompletedNative(JNIEnv *env, jclass cls, + jlong ucp_req_ptr) +{ + return ucp_request_check_status((void *)ucp_req_ptr) != UCS_INPROGRESS; +} + +JNIEXPORT void JNICALL +Java_org_openucx_jucx_ucp_UcpRequest_closeRequestNative(JNIEnv *env, jclass cls, + jlong ucp_req_ptr) +{ + ucp_request_free((void *)ucp_req_ptr); +} diff --git a/bindings/java/src/main/native/worker.cc b/bindings/java/src/main/native/worker.cc index 6c908b531b8..6e36a62cdcb 100644 --- a/bindings/java/src/main/native/worker.cc +++ b/bindings/java/src/main/native/worker.cc @@ -156,3 +156,11 @@ Java_org_openucx_jucx_ucp_UcpWorker_recvTaggedNonBlockingNative(JNIEnv *env, jcl return process_request(request, callback); } + +JNIEXPORT void JNICALL +Java_org_openucx_jucx_ucp_UcpWorker_cancelRequestNative(JNIEnv *env, jclass cls, + jlong ucp_worker_ptr, + jlong ucp_request_ptr) +{ + ucp_request_cancel((ucp_worker_h)ucp_worker_ptr, (void *)ucp_request_ptr); +} diff --git a/bindings/java/src/test/java/org/openucx/jucx/UcpEndpointTest.java b/bindings/java/src/test/java/org/openucx/jucx/UcpEndpointTest.java index 615ed1c8bd2..92e0d1b43a2 100644 --- a/bindings/java/src/test/java/org/openucx/jucx/UcpEndpointTest.java +++ b/bindings/java/src/test/java/org/openucx/jucx/UcpEndpointTest.java @@ -104,10 +104,10 @@ public void testGetNB() { UcpRemoteKey rkey2 = endpoint.unpackRemoteKey(memory2.getRemoteKeyBuffer()); AtomicInteger numCompletedRequests = new AtomicInteger(0); - HashMap requestToData = new HashMap<>(); + HashMap requestToData = new HashMap<>(); UcxCallback callback = new UcxCallback() { @Override - public void onSuccess(UcxRequest request) { + public void onSuccess(UcpRequest request) { // Here thread safety is guaranteed since worker progress is called after // request added to map. In multithreaded environment could be an issue that // callback is called, but request wasn't added yet to map. @@ -124,8 +124,8 @@ public void onSuccess(UcxRequest request) { }; // Submit 2 get requests - UcxRequest request1 = endpoint.getNonBlocking(memory1.getAddress(), rkey1, dst1, callback); - UcxRequest request2 = endpoint.getNonBlocking(memory2.getAddress(), rkey2, dst2, callback); + UcpRequest request1 = endpoint.getNonBlocking(memory1.getAddress(), rkey1, dst1, callback); + UcpRequest request2 = endpoint.getNonBlocking(memory2.getAddress(), rkey2, dst2, callback); // Map each request to corresponding data buffer. requestToData.put(request1, dst1); @@ -167,10 +167,10 @@ public void testPutNB() { worker1.newEndpoint(new UcpEndpointParams().setUcpAddress(worker2.getAddress())); UcpRemoteKey rkey = ep.unpackRemoteKey(memory.getRemoteKeyBuffer()); - UcxRequest request = ep.putNonBlocking(src, memory.getAddress(), rkey, + UcpRequest request = ep.putNonBlocking(src, memory.getAddress(), rkey, new UcxCallback() { @Override - public void onSuccess(UcxRequest request) { + public void onSuccess(UcpRequest request) { rkey.close(); memory.deregister(); } @@ -214,7 +214,7 @@ public void testSendRecv() throws Exception { AtomicInteger receivedMessages = new AtomicInteger(0); worker2.recvTaggedNonBlocking(dst1, 0, 0, new UcxCallback() { @Override - public void onSuccess(UcxRequest request) { + public void onSuccess(UcpRequest request) { assertEquals(dst1, src1); receivedMessages.incrementAndGet(); } @@ -222,7 +222,7 @@ public void onSuccess(UcxRequest request) { worker2.recvTaggedNonBlocking(dst2, 1, -1, new UcxCallback() { @Override - public void onSuccess(UcxRequest request) { + public void onSuccess(UcpRequest request) { assertEquals(dst2, src2); receivedMessages.incrementAndGet(); } @@ -291,7 +291,7 @@ public void run() { worker2.recvTaggedNonBlocking(dst1, 0, -1, new UcxCallback() { @Override - public void onSuccess(UcxRequest request) { + public void onSuccess(UcpRequest request) { success.set(true); } }); @@ -337,7 +337,7 @@ public void testBufferOffset() { ByteBuffer bigSendBuffer = ByteBuffer.allocateDirect(UcpMemoryTest.MEM_SIZE); bigRecvBuffer.position(offset).limit(offset + msgSize); - UcxRequest recv = worker1.recvTaggedNonBlocking(bigRecvBuffer, 0, + UcpRequest recv = worker1.recvTaggedNonBlocking(bigRecvBuffer, 0, 0, null); UcpEndpoint ep = worker2.newEndpoint(new UcpEndpointParams() @@ -352,7 +352,7 @@ public void testBufferOffset() { bigSendBuffer.put(msg); bigSendBuffer.position(offset); - UcxRequest sent = ep.sendTaggedNonBlocking(bigSendBuffer, 0, null); + UcpRequest sent = ep.sendTaggedNonBlocking(bigSendBuffer, 0, null); while (!sent.isCompleted() || !recv.isCompleted()) { worker1.progress(); @@ -399,9 +399,9 @@ public void testFlushEp() { UcxUtils.getAddress(dst) + i * blockSize, blockSize); } - UcxRequest request = ep.flushNonBlocking(new UcxCallback() { + UcpRequest request = ep.flushNonBlocking(new UcxCallback() { @Override - public void onSuccess(UcxRequest request) { + public void onSuccess(UcpRequest request) { rkey.close(); memory.deregister(); assertEquals(dst.asCharBuffer().toString().trim(), UcpMemoryTest.RANDOM_TEXT); diff --git a/bindings/java/src/test/java/org/openucx/jucx/UcpRequestTest.java b/bindings/java/src/test/java/org/openucx/jucx/UcpRequestTest.java new file mode 100644 index 00000000000..0ac1fc6327c --- /dev/null +++ b/bindings/java/src/test/java/org/openucx/jucx/UcpRequestTest.java @@ -0,0 +1,31 @@ +/* + * Copyright (C) Mellanox Technologies Ltd. 2001-2019. ALL RIGHTS RESERVED. + * See file LICENSE for terms. + */ +package org.openucx.jucx; + +import org.junit.Test; +import org.openucx.jucx.ucp.*; + +import java.nio.ByteBuffer; +import static org.junit.Assert.*; + +public class UcpRequestTest { + @Test + public void testCancelRequest() { + UcpContext context = new UcpContext(new UcpParams().requestTagFeature()); + UcpWorker worker = context.newWorker(new UcpWorkerParams()); + UcpRequest recv = worker.recvTaggedNonBlocking(ByteBuffer.allocateDirect(100), null); + worker.cancelRequest(recv); + + while (!recv.isCompleted()) { + worker.progress(); + } + + assertTrue(recv.isCompleted()); + assertNull(recv.getNativeId()); + + worker.close(); + context.close(); + } +} diff --git a/bindings/java/src/test/java/org/openucx/jucx/UcpWorkerTest.java b/bindings/java/src/test/java/org/openucx/jucx/UcpWorkerTest.java index d2b6ec2726f..9b0616f4d20 100644 --- a/bindings/java/src/test/java/org/openucx/jucx/UcpWorkerTest.java +++ b/bindings/java/src/test/java/org/openucx/jucx/UcpWorkerTest.java @@ -145,9 +145,9 @@ public void testFlushWorker() { blockSize, memory.getAddress() + i * blockSize, rkey); } - UcxRequest request = worker1.flushNonBlocking(new UcxCallback() { + UcpRequest request = worker1.flushNonBlocking(new UcxCallback() { @Override - public void onSuccess(UcxRequest request) { + public void onSuccess(UcpRequest request) { rkey.close(); memory.deregister(); assertEquals(dst.asCharBuffer().toString().trim(), UcpMemoryTest.RANDOM_TEXT);