Skip to content
This repository has been archived by the owner on Oct 12, 2022. It is now read-only.

Commit

Permalink
Enable calling Java methods that take arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
paulproteus committed Sep 11, 2020
1 parent a588490 commit 107f4e6
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 3 deletions.
24 changes: 24 additions & 0 deletions org/beeware/rubicon/test/Example.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,30 @@ static public long get_static_long_field() {
return static_long_field;
}

static public int sum_all_ints(int[] numbers) {
int sum = 0;
for (int number : numbers) {
sum += number;
}
return sum;
}

static public double sum_all_doubles(double[] numbers) {
double sum = 0;
for (double number : numbers) {
sum += number;
}
return sum;
}

static public float sum_all_floats(float[] numbers) {
float sum = 0;
for (float number : numbers) {
sum += number;
}
return sum;
}

/* An inner enumerated type */
public enum Stuff {
FOO, BAR, WHIZ;
Expand Down
23 changes: 23 additions & 0 deletions rubicon/java/api.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections.abc import Sequence
import itertools

from .jni import cast, java, reflect
Expand Down Expand Up @@ -104,6 +105,19 @@ def convert_args(args, type_names):
jarg = java.NewByteArray(len(arg))
java.SetByteArrayRegion(jarg, 0, len(arg), arg)
converted.append(jarg)
elif isinstance(arg, Sequence) and type_name[0] == ord(b'['):
if type_name[1] == ord(b'I'):
jarg = java.NewIntArray(len(arg))
java.SetIntArrayRegion(jarg, 0, len(arg), (jint * len(arg))(*arg))
converted.append(jarg)
elif type_name[1] == ord(b'F'):
jarg = java.NewFloatArray(len(arg))
java.SetFloatArrayRegion(jarg, 0, len(arg), (jfloat * len(arg))(*arg))
converted.append(jarg)
elif type_name[1] == ord(b'D'):
jarg = java.NewDoubleArray(len(arg))
java.SetDoubleArrayRegion(jarg, 0, len(arg), (jdouble * len(arg))(*arg))
converted.append(jarg)
elif isinstance(arg, str):
converted.append(java.NewStringUTF(arg.encode('utf-8')))
elif isinstance(arg, (JavaInstance, JavaProxy)):
Expand Down Expand Up @@ -170,6 +184,15 @@ def select_polymorph(polymorphs, args):
b"Ljava/lang/CharSequence;",
b"Ljava/lang/Object;",
])
elif isinstance(arg, Sequence) and len(arg) > 0:
# If arg is an iterable of all the same basic numeric type, then
# an array of that Java type can work.
if isinstance(arg[0], (int, jint)):
if all((isinstance(item, (int, jint)) for item in arg)):
arg_types.append([b'[I'])
elif isinstance(arg[0], (float, jfloat, jdouble)):
if all((isinstance(item, (float, jfloat, jdouble)) for item in arg)):
arg_types.append([b'[D', b'[F'])
elif isinstance(arg, (JavaInstance, JavaProxy)):
arg_types.append(arg.__class__.__dict__['_alternates'])
else:
Expand Down
21 changes: 20 additions & 1 deletion rubicon/java/jni.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@

from .types import (
jarray, jboolean, jboolean_p, jbyte, jbyte_p, jbyteArray, jchar, jclass,
jdouble, jfieldID, jfloat, jint, jlong, jmethodID, jobject, jobjectArray,
jdouble, jdouble_p, jdoubleArray, jfieldID, jfloat, jfloat_p, jfloatArray,
jint, jint_p, jintArray, jlong, jmethodID, jobject, jobjectArray,
jshort, jsize, jstring,
)

Expand Down Expand Up @@ -170,6 +171,24 @@
java.SetByteArrayRegion.restype = None
java.SetByteArrayRegion.argtypes = [jbyteArray, jsize, jsize, jbyte_p]

java.NewDoubleArray.restype = jdoubleArray
java.NewDoubleArray.argtypes = [jsize]

java.SetDoubleArrayRegion.restype = None
java.SetDoubleArrayRegion.argtypes = [jdoubleArray, jsize, jsize, jdouble_p]

java.NewIntArray.restype = jintArray
java.NewIntArray.argtypes = [jsize]

java.SetIntArrayRegion.restype = None
java.SetIntArrayRegion.argtypes = [jintArray, jsize, jsize, jint_p]

java.NewFloatArray.restype = jfloatArray
java.NewFloatArray.argtypes = [jsize]

java.SetFloatArrayRegion.restype = None
java.SetFloatArrayRegion.argtypes = [jfloatArray, jsize, jsize, jfloat_p]


class _ReflectionAPI(object):
"A lazy-loading proxy for the key classes and methods in the Java reflection API"
Expand Down
4 changes: 2 additions & 2 deletions rubicon/java/types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from ctypes import (
POINTER, Structure, c_bool, c_byte, c_char_p, c_double, c_float, c_int,
POINTER, Structure, c_bool, c_byte, c_char_p, c_double, c_float, c_uint32,
c_longlong, c_short, c_void_p, c_wchar,
)

Expand All @@ -19,7 +19,7 @@
jbyte = c_byte
jchar = c_wchar
jshort = c_short
jint = c_int
jint = c_uint32
jlong = c_longlong
jfloat = c_float
jdouble = c_double
Expand Down
15 changes: 15 additions & 0 deletions tests/test_rubicon.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,21 @@ def test_polymorphic_static_method(self):
with self.assertRaises(ValueError):
Example.tripler(1.234)

def test_pass_int_array(self):
"""A list of Python ints can be passed as a Java int array."""
Example = JavaClass("org/beeware/rubicon/test/Example")
self.assertEqual(3, Example.sum_all_ints([1, 2]))

def test_pass_double_array(self):
"""A list of Python floats can be passed as a Java double array."""
Example = JavaClass("org/beeware/rubicon/test/Example")
self.assertEqual(3, Example.sum_all_doubles([1.0, 2.0]))

def test_pass_float_array(self):
"""A list of Python floats can be passed as a Java float array."""
Example = JavaClass("org/beeware/rubicon/test/Example")
self.assertEqual(3, Example.sum_all_floats([1.0, 2.0]))

def test_static_access_non_static(self):
"An instance field/method cannot be accessed from the static context"
Example = JavaClass('org/beeware/rubicon/test/Example')
Expand Down

0 comments on commit 107f4e6

Please sign in to comment.