From 107f4e6d541f017cbc3f804791c64520b400c44a Mon Sep 17 00:00:00 2001 From: Asheesh Laroia Date: Tue, 4 Aug 2020 20:56:39 -0700 Subject: [PATCH] Enable calling Java methods that take arrays --- org/beeware/rubicon/test/Example.java | 24 ++++++++++++++++++++++++ rubicon/java/api.py | 23 +++++++++++++++++++++++ rubicon/java/jni.py | 21 ++++++++++++++++++++- rubicon/java/types.py | 4 ++-- tests/test_rubicon.py | 15 +++++++++++++++ 5 files changed, 84 insertions(+), 3 deletions(-) diff --git a/org/beeware/rubicon/test/Example.java b/org/beeware/rubicon/test/Example.java index eb08632..1af2953 100644 --- a/org/beeware/rubicon/test/Example.java +++ b/org/beeware/rubicon/test/Example.java @@ -28,6 +28,30 @@ static public long get_static_long_field() { return static_long_field; } + static public int sum_all_ints(int[] numbers) { + int sum = 0; + for (int number : numbers) { + sum += number; + } + return sum; + } + + static public double sum_all_doubles(double[] numbers) { + double sum = 0; + for (double number : numbers) { + sum += number; + } + return sum; + } + + static public float sum_all_floats(float[] numbers) { + float sum = 0; + for (float number : numbers) { + sum += number; + } + return sum; + } + /* An inner enumerated type */ public enum Stuff { FOO, BAR, WHIZ; diff --git a/rubicon/java/api.py b/rubicon/java/api.py index 2be18ea..e582545 100644 --- a/rubicon/java/api.py +++ b/rubicon/java/api.py @@ -1,3 +1,4 @@ +from collections.abc import Sequence import itertools from .jni import cast, java, reflect @@ -104,6 +105,19 @@ def convert_args(args, type_names): jarg = java.NewByteArray(len(arg)) java.SetByteArrayRegion(jarg, 0, len(arg), arg) converted.append(jarg) + elif isinstance(arg, Sequence) and type_name[0] == ord(b'['): + if type_name[1] == ord(b'I'): + jarg = java.NewIntArray(len(arg)) + java.SetIntArrayRegion(jarg, 0, len(arg), (jint * len(arg))(*arg)) + converted.append(jarg) + elif type_name[1] == ord(b'F'): + jarg = java.NewFloatArray(len(arg)) + java.SetFloatArrayRegion(jarg, 0, len(arg), (jfloat * len(arg))(*arg)) + converted.append(jarg) + elif type_name[1] == ord(b'D'): + jarg = java.NewDoubleArray(len(arg)) + java.SetDoubleArrayRegion(jarg, 0, len(arg), (jdouble * len(arg))(*arg)) + converted.append(jarg) elif isinstance(arg, str): converted.append(java.NewStringUTF(arg.encode('utf-8'))) elif isinstance(arg, (JavaInstance, JavaProxy)): @@ -170,6 +184,15 @@ def select_polymorph(polymorphs, args): b"Ljava/lang/CharSequence;", b"Ljava/lang/Object;", ]) + elif isinstance(arg, Sequence) and len(arg) > 0: + # If arg is an iterable of all the same basic numeric type, then + # an array of that Java type can work. + if isinstance(arg[0], (int, jint)): + if all((isinstance(item, (int, jint)) for item in arg)): + arg_types.append([b'[I']) + elif isinstance(arg[0], (float, jfloat, jdouble)): + if all((isinstance(item, (float, jfloat, jdouble)) for item in arg)): + arg_types.append([b'[D', b'[F']) elif isinstance(arg, (JavaInstance, JavaProxy)): arg_types.append(arg.__class__.__dict__['_alternates']) else: diff --git a/rubicon/java/jni.py b/rubicon/java/jni.py index 1c56fce..e3a8898 100644 --- a/rubicon/java/jni.py +++ b/rubicon/java/jni.py @@ -3,7 +3,8 @@ from .types import ( jarray, jboolean, jboolean_p, jbyte, jbyte_p, jbyteArray, jchar, jclass, - jdouble, jfieldID, jfloat, jint, jlong, jmethodID, jobject, jobjectArray, + jdouble, jdouble_p, jdoubleArray, jfieldID, jfloat, jfloat_p, jfloatArray, + jint, jint_p, jintArray, jlong, jmethodID, jobject, jobjectArray, jshort, jsize, jstring, ) @@ -170,6 +171,24 @@ java.SetByteArrayRegion.restype = None java.SetByteArrayRegion.argtypes = [jbyteArray, jsize, jsize, jbyte_p] +java.NewDoubleArray.restype = jdoubleArray +java.NewDoubleArray.argtypes = [jsize] + +java.SetDoubleArrayRegion.restype = None +java.SetDoubleArrayRegion.argtypes = [jdoubleArray, jsize, jsize, jdouble_p] + +java.NewIntArray.restype = jintArray +java.NewIntArray.argtypes = [jsize] + +java.SetIntArrayRegion.restype = None +java.SetIntArrayRegion.argtypes = [jintArray, jsize, jsize, jint_p] + +java.NewFloatArray.restype = jfloatArray +java.NewFloatArray.argtypes = [jsize] + +java.SetFloatArrayRegion.restype = None +java.SetFloatArrayRegion.argtypes = [jfloatArray, jsize, jsize, jfloat_p] + class _ReflectionAPI(object): "A lazy-loading proxy for the key classes and methods in the Java reflection API" diff --git a/rubicon/java/types.py b/rubicon/java/types.py index 1b03fd5..6dd07ad 100644 --- a/rubicon/java/types.py +++ b/rubicon/java/types.py @@ -1,5 +1,5 @@ from ctypes import ( - POINTER, Structure, c_bool, c_byte, c_char_p, c_double, c_float, c_int, + POINTER, Structure, c_bool, c_byte, c_char_p, c_double, c_float, c_uint32, c_longlong, c_short, c_void_p, c_wchar, ) @@ -19,7 +19,7 @@ jbyte = c_byte jchar = c_wchar jshort = c_short -jint = c_int +jint = c_uint32 jlong = c_longlong jfloat = c_float jdouble = c_double diff --git a/tests/test_rubicon.py b/tests/test_rubicon.py index 0e24ba2..4dcf564 100644 --- a/tests/test_rubicon.py +++ b/tests/test_rubicon.py @@ -238,6 +238,21 @@ def test_polymorphic_static_method(self): with self.assertRaises(ValueError): Example.tripler(1.234) + def test_pass_int_array(self): + """A list of Python ints can be passed as a Java int array.""" + Example = JavaClass("org/beeware/rubicon/test/Example") + self.assertEqual(3, Example.sum_all_ints([1, 2])) + + def test_pass_double_array(self): + """A list of Python floats can be passed as a Java double array.""" + Example = JavaClass("org/beeware/rubicon/test/Example") + self.assertEqual(3, Example.sum_all_doubles([1.0, 2.0])) + + def test_pass_float_array(self): + """A list of Python floats can be passed as a Java float array.""" + Example = JavaClass("org/beeware/rubicon/test/Example") + self.assertEqual(3, Example.sum_all_floats([1.0, 2.0])) + def test_static_access_non_static(self): "An instance field/method cannot be accessed from the static context" Example = JavaClass('org/beeware/rubicon/test/Example')