Skip to content

Commit

Permalink
Merge pull request openucx#5920 from evgeny-leksikov/ucp_cm_enable_y
Browse files Browse the repository at this point in the history
UCP: SOCKADDR_CM_ENABLE=y by default
  • Loading branch information
yosefe authored Nov 26, 2020
2 parents 16c554f + 5ca2a32 commit ab9c2f4
Show file tree
Hide file tree
Showing 13 changed files with 162 additions and 64 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,6 @@ public static void main(String[] args) throws Exception {
.setConnectionRequest(connRequest.get())
.setPeerErrorHandlingMode());

// Temporary workaround until new connection establishment protocol in UCX.
for (int i = 0; i < 10; i++) {
worker.progress();
try {
Thread.sleep(10);
} catch (Exception ignored) { }
}

ByteBuffer recvBuffer = ByteBuffer.allocateDirect(4096);
UcpRequest recvRequest = worker.recvTaggedNonBlocking(recvBuffer, null);

Expand Down Expand Up @@ -95,12 +87,6 @@ public void onSuccess(UcpRequest request) {
data.put(0, (byte)1);
}

ByteBuffer sendBuffer = ByteBuffer.allocateDirect(100);
sendBuffer.asCharBuffer().put("DONE");

UcpRequest sent = endpoint.sendTaggedNonBlocking(sendBuffer, null);
worker.progressRequest(sent);

UcpRequest closeRequest = endpoint.closeNonBlockingFlush();
worker.progressRequest(closeRequest);
// Close request won't be returned to the pool automatically, since there's no callback.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,12 @@

package org.openucx.jucx.examples;

import org.openucx.jucx.UcxCallback;
import org.openucx.jucx.ucp.UcpRequest;
import org.openucx.jucx.UcxException;
import org.openucx.jucx.ucp.*;
import org.openucx.jucx.UcxUtils;
import org.openucx.jucx.ucp.UcpEndpoint;
import org.openucx.jucx.ucp.UcpEndpointParams;
import org.openucx.jucx.ucp.UcpMemory;
import org.openucx.jucx.ucs.UcsConstants;

import java.net.ConnectException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;

Expand All @@ -28,6 +27,13 @@ public static void main(String[] args) throws Exception {
String serverHost = argsMap.get("s");
UcpEndpoint endpoint = worker.newEndpoint(new UcpEndpointParams()
.setPeerErrorHandlingMode()
.setErrorHandler((ep, status, errorMsg) -> {
if (status == UcsConstants.STATUS.UCS_ERR_CONNECTION_RESET) {
throw new ConnectException(errorMsg);
} else {
throw new UcxException(errorMsg);
}
})
.setSocketAddress(new InetSocketAddress(serverHost, serverPort)));

UcpMemory memory = context.memoryMap(allocationParams);
Expand All @@ -49,22 +55,21 @@ public static void main(String[] args) throws Exception {

// Send memory metadata and wait until the receiver finishes the benchmark.
endpoint.sendTaggedNonBlocking(sendData, null);
ByteBuffer recvBuffer = ByteBuffer.allocateDirect(4096);
UcpRequest recvRequest = worker.recvTaggedNonBlocking(recvBuffer,
new UcxCallback() {
@Override
public void onSuccess(UcpRequest request) {
System.out.println("Received a message:");
System.out.println(recvBuffer.asCharBuffer().toString().trim());
}
});

worker.progressRequest(recvRequest);

UcpRequest closeRequest = endpoint.closeNonBlockingFlush();
worker.progressRequest(closeRequest);
resources.push(closeRequest);

closeResources();
try {
while (true) {
if (worker.progress() == 0) {
worker.waitForEvents();
}
}
} catch (ConnectException ignored) {
} catch (Exception ex) {
System.err.println(ex.getMessage());
} finally {
UcpRequest closeRequest = endpoint.closeNonBlockingForce();
worker.progressRequest(closeRequest);
resources.push(closeRequest);
closeResources();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@ public interface UcpEndpointErrorHandler {
* all subsequent operations on this ep will fail with
* the error code passed in {@code status}.
*/
void onError(UcpEndpoint ep, int status, String errorMsg);
void onError(UcpEndpoint ep, int status, String errorMsg) throws Exception;
}
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,14 @@ public void close() {
* This routine explicitly progresses all communication operations on a worker.
* @return Non-zero if any communication was progressed, zero otherwise.
*/
public int progress() {
public int progress() throws Exception {
return progressWorkerNative(getNativeId());
}

/**
* Blocking progress for the request until it is completed.
*/
public void progressRequest(UcpRequest request) {
public void progressRequest(UcpRequest request) throws Exception {
while (!request.isCompleted()) {
progress();
}
Expand Down Expand Up @@ -251,7 +251,7 @@ public ByteBuffer getAddress() {

private static native void releaseAddressNative(long workerId, ByteBuffer addressId);

private static native int progressWorkerNative(long workerId);
private static native int progressWorkerNative(long workerId) throws Exception;

private static native UcpRequest flushNonBlockingNative(long workerId, UcxCallback callback);

Expand Down
50 changes: 50 additions & 0 deletions bindings/java/src/main/java/org/openucx/jucx/ucs/UcsConstants.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,56 @@ public static class ThreadMode {
public static int UCS_THREAD_MODE_MULTI;
}

/**
 * Integer UCS status codes exposed to Java.
 *
 * <p>None of the fields is assigned in Java code. The static initializer calls the
 * enclosing class's {@code load()}, which loads the native library and then invokes
 * {@code loadConstants()}. The fields are presumably filled in by that native call,
 * which would explain why they are not {@code final} - confirm against the JNI
 * implementation.
 */
public static class STATUS {
// Runs once, when this nested class is first initialized.
static {
load();
}

/** Operation completed successfully. */
public static int UCS_OK;

/** Operation is queued and still in progress. */
public static int UCS_INPROGRESS;

/* Failure codes */
public static int UCS_ERR_NO_MESSAGE;
public static int UCS_ERR_NO_RESOURCE;
public static int UCS_ERR_IO_ERROR;
public static int UCS_ERR_NO_MEMORY;
public static int UCS_ERR_INVALID_PARAM;
public static int UCS_ERR_UNREACHABLE;
public static int UCS_ERR_INVALID_ADDR;
public static int UCS_ERR_NOT_IMPLEMENTED;
public static int UCS_ERR_MESSAGE_TRUNCATED;
public static int UCS_ERR_NO_PROGRESS;
public static int UCS_ERR_BUFFER_TOO_SMALL;
public static int UCS_ERR_NO_ELEM;
public static int UCS_ERR_SOME_CONNECTS_FAILED;
public static int UCS_ERR_NO_DEVICE;
public static int UCS_ERR_BUSY;
public static int UCS_ERR_CANCELED;
public static int UCS_ERR_SHMEM_SEGMENT;
public static int UCS_ERR_ALREADY_EXISTS;
public static int UCS_ERR_OUT_OF_RANGE;
public static int UCS_ERR_TIMED_OUT;
public static int UCS_ERR_EXCEEDS_LIMIT;
public static int UCS_ERR_UNSUPPORTED;
public static int UCS_ERR_REJECTED;
public static int UCS_ERR_NOT_CONNECTED;
public static int UCS_ERR_CONNECTION_RESET;

// NOTE(review): the FIRST_/LAST_ names below look like delimiters of the link-failure
// and endpoint-failure code ranges rather than individual errors - confirm against
// the UCS status header before comparing against them.
public static int UCS_ERR_FIRST_LINK_FAILURE;
public static int UCS_ERR_LAST_LINK_FAILURE;
public static int UCS_ERR_FIRST_ENDPOINT_FAILURE;
public static int UCS_ERR_ENDPOINT_TIMEOUT;
public static int UCS_ERR_LAST_ENDPOINT_FAILURE;

// NOTE(review): presumably a sentinel marking the end of the status-code range - confirm.
public static int UCS_ERR_LAST;
}

private static void load() {
NativeLibs.load();
loadConstants();
Expand Down
42 changes: 41 additions & 1 deletion bindings/java/src/main/native/ucs_constants.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,49 @@
#include <ucs/type/thread_mode.h>

JNIEXPORT void JNICALL
Java_org_openucx_jucx_ucs_UcsConstants_loadConstants(JNIEnv *env, jclass cls)
Java_org_openucx_jucx_ucs_UcsConstants_loadConstants(JNIEnv *env, jclass ucs_class)
{
jclass thread_mode = env->FindClass("org/openucx/jucx/ucs/UcsConstants$ThreadMode");
jfieldID field = env->GetStaticFieldID(thread_mode, "UCS_THREAD_MODE_MULTI", "I");
env->SetStaticIntField(thread_mode, field, UCS_THREAD_MODE_MULTI);

jclass cls = env->FindClass("org/openucx/jucx/ucs/UcsConstants$STATUS");

/* Operation completed successfully */
JUCX_DEFINE_INT_CONSTANT(UCS_OK);

/* Operation is queued and still in progress */
JUCX_DEFINE_INT_CONSTANT(UCS_INPROGRESS);
/* Failure codes */
JUCX_DEFINE_INT_CONSTANT(UCS_ERR_NO_MESSAGE);
JUCX_DEFINE_INT_CONSTANT(UCS_ERR_NO_RESOURCE);
JUCX_DEFINE_INT_CONSTANT(UCS_ERR_IO_ERROR);
JUCX_DEFINE_INT_CONSTANT(UCS_ERR_NO_MEMORY);
JUCX_DEFINE_INT_CONSTANT(UCS_ERR_INVALID_PARAM);
JUCX_DEFINE_INT_CONSTANT(UCS_ERR_UNREACHABLE);
JUCX_DEFINE_INT_CONSTANT(UCS_ERR_INVALID_ADDR);
JUCX_DEFINE_INT_CONSTANT(UCS_ERR_NOT_IMPLEMENTED);
JUCX_DEFINE_INT_CONSTANT(UCS_ERR_MESSAGE_TRUNCATED);
JUCX_DEFINE_INT_CONSTANT(UCS_ERR_NO_PROGRESS);
JUCX_DEFINE_INT_CONSTANT(UCS_ERR_BUFFER_TOO_SMALL);
JUCX_DEFINE_INT_CONSTANT(UCS_ERR_NO_ELEM);
JUCX_DEFINE_INT_CONSTANT(UCS_ERR_SOME_CONNECTS_FAILED);
JUCX_DEFINE_INT_CONSTANT(UCS_ERR_NO_DEVICE);
JUCX_DEFINE_INT_CONSTANT(UCS_ERR_BUSY);
JUCX_DEFINE_INT_CONSTANT(UCS_ERR_CANCELED);
JUCX_DEFINE_INT_CONSTANT(UCS_ERR_SHMEM_SEGMENT);
JUCX_DEFINE_INT_CONSTANT(UCS_ERR_ALREADY_EXISTS);
JUCX_DEFINE_INT_CONSTANT(UCS_ERR_OUT_OF_RANGE);
JUCX_DEFINE_INT_CONSTANT(UCS_ERR_TIMED_OUT);
JUCX_DEFINE_INT_CONSTANT(UCS_ERR_EXCEEDS_LIMIT);
JUCX_DEFINE_INT_CONSTANT(UCS_ERR_UNSUPPORTED);
JUCX_DEFINE_INT_CONSTANT(UCS_ERR_REJECTED);
JUCX_DEFINE_INT_CONSTANT(UCS_ERR_NOT_CONNECTED);
JUCX_DEFINE_INT_CONSTANT(UCS_ERR_CONNECTION_RESET);
JUCX_DEFINE_INT_CONSTANT(UCS_ERR_FIRST_LINK_FAILURE);
JUCX_DEFINE_INT_CONSTANT(UCS_ERR_LAST_LINK_FAILURE);
JUCX_DEFINE_INT_CONSTANT(UCS_ERR_FIRST_ENDPOINT_FAILURE);
JUCX_DEFINE_INT_CONSTANT(UCS_ERR_ENDPOINT_TIMEOUT);
JUCX_DEFINE_INT_CONSTANT(UCS_ERR_LAST_ENDPOINT_FAILURE);
JUCX_DEFINE_INT_CONSTANT(UCS_ERR_LAST);
}
25 changes: 15 additions & 10 deletions bindings/java/src/test/java/org/openucx/jucx/UcpEndpointTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ public void testConnectToListenerByWorkerAddr() {
}

@Test
public void testGetNB() {
public void testGetNB() throws Exception {
// Create 2 contexts + 2 workers
UcpParams params = new UcpParams().requestRmaFeature();
UcpWorkerParams rdmaWorkerParams = new UcpWorkerParams().requestWakeupRMA();
Expand Down Expand Up @@ -102,7 +102,7 @@ public void onSuccess(UcpRequest request) {
}

@Test
public void testPutNB() {
public void testPutNB() throws Exception {
// Create 2 contexts + 2 workers
UcpParams params = new UcpParams().requestRmaFeature();
UcpWorkerParams rdmaWorkerParams = new UcpWorkerParams().requestWakeupRMA();
Expand Down Expand Up @@ -186,7 +186,7 @@ public void onSuccess(UcpRequest request) {
}

@Test
public void testRecvAfterSend() {
public void testRecvAfterSend() throws Exception {
long sendTag = 4L;
// Create 2 contexts + 2 workers
UcpParams params = new UcpParams().requestRmaFeature().requestTagFeature()
Expand All @@ -211,8 +211,13 @@ public void testRecvAfterSend() {
@Override
public void run() {
while (!isInterrupted()) {
worker1.progress();
worker2.progress();
try {
worker1.progress();
worker2.progress();
} catch (Exception ex) {
System.err.println(ex.getMessage());
ex.printStackTrace();
}
}
}
};
Expand Down Expand Up @@ -263,7 +268,7 @@ public void onSuccess(UcpRequest request) {
}

@Test
public void testBufferOffset() {
public void testBufferOffset() throws Exception {
int msgSize = 200;
int offset = 100;
// Create 2 contexts + 2 workers
Expand Down Expand Up @@ -311,7 +316,7 @@ public void testBufferOffset() {
}

@Test
public void testFlushEp() {
public void testFlushEp() throws Exception {
int numRequests = 10;
// Create 2 contexts + 2 workers
UcpParams params = new UcpParams().requestRmaFeature();
Expand Down Expand Up @@ -356,7 +361,7 @@ public void onSuccess(UcpRequest request) {
}

@Test
public void testRecvSize() {
public void testRecvSize() throws Exception {
UcpContext context1 = new UcpContext(new UcpParams().requestTagFeature());
UcpContext context2 = new UcpContext(new UcpParams().requestTagFeature());

Expand Down Expand Up @@ -386,7 +391,7 @@ public void testRecvSize() {
}

@Test
public void testStreamingAPI() {
public void testStreamingAPI() throws Exception {
UcpParams params = new UcpParams().requestStreamFeature().requestRmaFeature();
UcpContext context1 = new UcpContext(params);
UcpContext context2 = new UcpContext(params);
Expand Down Expand Up @@ -537,7 +542,7 @@ public void testIovOperations() throws Exception {
}

@Test
public void testEpErrorHandler() {
public void testEpErrorHandler() throws Exception {
// Create 2 contexts + 2 workers
UcpParams params = new UcpParams().requestTagFeature();
UcpWorkerParams workerParams = new UcpWorkerParams();
Expand Down
14 changes: 11 additions & 3 deletions bindings/java/src/test/java/org/openucx/jucx/UcpListenerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ static UcpListener tryBindListener(UcpWorker worker, UcpListenerParams params) {
}

@Test
public void testConnectionHandler() {
public void testConnectionHandler() throws Exception {
UcpContext context1 = new UcpContext(new UcpParams().requestStreamFeature()
.requestRmaFeature());
UcpContext context2 = new UcpContext(new UcpParams().requestStreamFeature()
Expand All @@ -113,7 +113,7 @@ public void testConnectionHandler() {
// Create endpoint from another worker from pool.
UcpEndpoint serverToClient = serverWorker2.newEndpoint(
new UcpEndpointParams().setConnectionRequest(conRequest.get()));

// Temporary workaround until new connection establishment protocol in UCX.
for (int i = 0; i < 10; i++) {
serverWorker1.progress();
Expand Down Expand Up @@ -147,8 +147,16 @@ public void testConnectionHandler() {

assertEquals(UcpMemoryTest.MEM_SIZE, recv.getRecvSize());

UcpRequest serverClose = serverToClient.closeNonBlockingFlush();
UcpRequest clientClose = clientToServer.closeNonBlockingFlush();

while (!serverClose.isCompleted() || !clientClose.isCompleted()) {
serverWorker2.progress();
clientWorker.progress();
}

Collections.addAll(resources, context2, context1, clientWorker, serverWorker1,
serverWorker2, listener, serverToClient, clientToServer);
serverWorker2, listener);
closeResources();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

public class UcpRequestTest {
@Test
public void testCancelRequest() {
public void testCancelRequest() throws Exception {
UcpContext context = new UcpContext(new UcpParams().requestTagFeature());
UcpWorker worker = context.newWorker(new UcpWorkerParams());
UcpRequest recv = worker.recvTaggedNonBlocking(ByteBuffer.allocateDirect(100), null);
Expand Down
Loading

0 comments on commit ab9c2f4

Please sign in to comment.