diff --git a/NEWS b/NEWS index 3fdc9921105..3767ea918cb 100644 --- a/NEWS +++ b/NEWS @@ -12,6 +12,7 @@ * Fix Infiniband port speed detection for HDR100 * Fix build issues in gtest-all.cc and sock.c with GCC11 * Fix performance degradation with cuda memory on self endpoint +* Fix bug in JUCX listener connection handler. ## 1.10.0 (March 9, 2021) ### Features: diff --git a/bindings/java/src/main/java/org/openucx/jucx/ucp/UcpListener.java b/bindings/java/src/main/java/org/openucx/jucx/ucp/UcpListener.java index 63c0ac003b1..00ea35eda4e 100644 --- a/bindings/java/src/main/java/org/openucx/jucx/ucp/UcpListener.java +++ b/bindings/java/src/main/java/org/openucx/jucx/ucp/UcpListener.java @@ -17,13 +17,18 @@ public class UcpListener extends UcxNativeStruct implements Closeable { private InetSocketAddress address; + private UcpListenerConnectionHandler connectionHandler; public UcpListener(UcpWorker worker, UcpListenerParams params) { if (params.getSockAddr() == null) { throw new UcxException("UcpListenerParams.sockAddr must be non-null."); } + if (params.connectionHandler == null) { + throw new UcxException("Connection handler must be set"); + } + this.connectionHandler = params.connectionHandler; + this.address = params.getSockAddr(); setNativeId(createUcpListener(params, worker.getNativeId())); - address = params.getSockAddr(); } /** diff --git a/bindings/java/src/main/java/org/openucx/jucx/ucp/UcpListenerParams.java b/bindings/java/src/main/java/org/openucx/jucx/ucp/UcpListenerParams.java index 28153a0772d..94fdc8c96ad 100644 --- a/bindings/java/src/main/java/org/openucx/jucx/ucp/UcpListenerParams.java +++ b/bindings/java/src/main/java/org/openucx/jucx/ucp/UcpListenerParams.java @@ -14,12 +14,13 @@ public class UcpListenerParams extends UcxParams { public UcpListenerParams clear() { super.clear(); sockAddr = null; + connectionHandler = null; return this; } private InetSocketAddress sockAddr; - private UcpListenerConnectionHandler connectionHandler; + UcpListenerConnectionHandler connectionHandler; /** * An address, on which {@link UcpListener} would bind. diff --git a/bindings/java/src/main/native/jucx_common_def.cc b/bindings/java/src/main/native/jucx_common_def.cc index 913edc66dfd..0965b91b8bf 100644 --- a/bindings/java/src/main/native/jucx_common_def.cc +++ b/bindings/java/src/main/native/jucx_common_def.cc @@ -338,7 +338,6 @@ void jucx_connection_handler(ucp_conn_request_h conn_request, void *arg) jmethodID on_conn_request = env->GetMethodID(jucx_conn_hndl_cls, "onConnectionRequest", "(Lorg/openucx/jucx/ucp/UcpConnectionRequest;)V"); env->CallVoidMethod(jucx_conn_handler, on_conn_request, jucx_conn_request); - env->DeleteGlobalRef(jucx_conn_handler); } diff --git a/bindings/java/src/main/native/listener.cc b/bindings/java/src/main/native/listener.cc index 3114e71488f..062b08028f2 100644 --- a/bindings/java/src/main/native/listener.cc +++ b/bindings/java/src/main/native/listener.cc @@ -44,7 +44,7 @@ Java_org_openucx_jucx_ucp_UcpListener_createUcpListener(JNIEnv *env, jclass cls, field = env->GetFieldID(jucx_listener_param_class, "connectionHandler", "Lorg/openucx/jucx/ucp/UcpListenerConnectionHandler;"); jobject jucx_conn_handler = env->GetObjectField(ucp_listener_params, field); - params.conn_handler.arg = env->NewGlobalRef(jucx_conn_handler); + params.conn_handler.arg = env->NewWeakGlobalRef(jucx_conn_handler); params.conn_handler.cb = jucx_connection_handler; } diff --git a/bindings/java/src/test/java/org/openucx/jucx/UcpListenerTest.java b/bindings/java/src/test/java/org/openucx/jucx/UcpListenerTest.java index f17d2b2940d..d1e787505af 100644 --- a/bindings/java/src/test/java/org/openucx/jucx/UcpListenerTest.java +++ b/bindings/java/src/test/java/org/openucx/jucx/UcpListenerTest.java @@ -114,14 +114,32 @@ public void testConnectionHandler() throws Exception { UcpEndpoint serverToClient = serverWorker2.newEndpoint( new UcpEndpointParams().setConnectionRequest(conRequest.get())); - // Temporary workaround until new connection establishment protocol in UCX. + // Test connection handler persists for (int i = 0; i < 10; i++) { - serverWorker1.progress(); - serverWorker2.progress(); - clientWorker.progress(); - try { - Thread.sleep(10); - } catch (Exception ignored) { } + conRequest.set(null); + UcpEndpoint tmpEp = clientWorker.newEndpoint(new UcpEndpointParams() + .setSocketAddress(listener.getAddress()).setPeerErrorHandlingMode() + .setErrorHandler((ep, status, errorMsg) -> { + + })); + + while (conRequest.get() == null) { + serverWorker1.progress(); + serverWorker2.progress(); + clientWorker.progress(); + } + + UcpEndpoint tmpEp2 = serverWorker2.newEndpoint( + new UcpEndpointParams().setConnectionRequest(conRequest.get())); + + UcpRequest close1 = tmpEp.closeNonBlockingFlush(); + UcpRequest close2 = tmpEp2.closeNonBlockingFlush(); + + while (!close1.isCompleted() || !close2.isCompleted()) { + serverWorker1.progress(); + serverWorker2.progress(); + clientWorker.progress(); + } } UcpRequest sent = serverToClient.sendStreamNonBlocking(