From a4d2ae26a60473abe0d7b61d4dced0c69ece6d5d Mon Sep 17 00:00:00 2001 From: sperlingxx Date: Mon, 26 Apr 2021 15:54:45 +0800 Subject: [PATCH 1/7] refactor AggregationJni to support collectSet as well as collectList Signed-off-by: sperlingxx --- .../main/java/ai/rapids/cudf/Aggregation.java | 142 +++++++++++++++--- java/src/main/native/src/AggregationJni.cpp | 31 +++- .../test/java/ai/rapids/cudf/TableTest.java | 88 ++++++++++- 3 files changed, 236 insertions(+), 25 deletions(-) diff --git a/java/src/main/java/ai/rapids/cudf/Aggregation.java b/java/src/main/java/ai/rapids/cudf/Aggregation.java index 7d8989571f7..ddfc86c8aaf 100644 --- a/java/src/main/java/ai/rapids/cudf/Aggregation.java +++ b/java/src/main/java/ai/rapids/cudf/Aggregation.java @@ -54,11 +54,12 @@ enum Kind { NUNIQUE(15), NTH_ELEMENT(16), ROW_NUMBER(17), - COLLECT(18), - LEAD(19), - LAG(20), - PTX(21), - CUDA(22); + COLLECT_LIST(18), + COLLECT_SET(19), + LEAD(20), + LAG(21), + PTX(22), + CUDA(23); final int nativeId; @@ -77,6 +78,30 @@ public enum NullPolicy { final boolean includeNulls; } + /* + * This is analogous to the native 'null_equality'. + */ + public enum NullEquality { + UNEQUAL(false), + EQUAL(true); + + NullEquality(boolean nullsEqual) { this.nullsEqual = nullsEqual; } + + final boolean nullsEqual; + } + + /* + * This is analogous to the native 'nan_equality'. + */ + public enum NanEquality { + UNEQUAL(false), + ALL_EQUAL(true); + + NanEquality(boolean nansEqual) { this.nansEqual = nansEqual; } + + final boolean nansEqual; + } + /** * An Aggregation that only needs a kind and nothing else. 
*/ @@ -280,17 +305,17 @@ long getDefaultOutput() { } } - private static final class CollectAggregation extends Aggregation { + private static final class CollectListAggregation extends Aggregation { private final NullPolicy nullPolicy; - public CollectAggregation(NullPolicy nullPolicy) { - super(Kind.COLLECT); + public CollectListAggregation(NullPolicy nullPolicy) { + super(Kind.COLLECT_LIST); this.nullPolicy = nullPolicy; } @Override long createNativeInstance() { - return Aggregation.createCollectAgg(nullPolicy.includeNulls); + return Aggregation.createCollectListAgg(nullPolicy.includeNulls); } @Override @@ -302,14 +327,53 @@ public int hashCode() { public boolean equals(Object other) { if (this == other) { return true; - } else if (other instanceof CollectAggregation) { - CollectAggregation o = (CollectAggregation) other; + } else if (other instanceof CollectListAggregation) { + CollectListAggregation o = (CollectListAggregation) other; return o.nullPolicy == this.nullPolicy; } return false; } } + private static final class CollectSetAggregation extends Aggregation { + private final NullPolicy nullPolicy; + private final NullEquality nullEquality; + private final NanEquality nanEquality; + + public CollectSetAggregation(NullPolicy nullPolicy, NullEquality nullEquality, NanEquality nanEquality) { + super(Kind.COLLECT_SET); + this.nullPolicy = nullPolicy; + this.nullEquality = nullEquality; + this.nanEquality = nanEquality; + } + + @Override + long createNativeInstance() { + return Aggregation.createCollectSetAgg(nullPolicy.includeNulls, + nullEquality.nullsEqual, + nanEquality.nansEqual); + } + + @Override + public int hashCode() { + boolean[] configs = new boolean[]{nullPolicy.includeNulls, nullEquality.nullsEqual, nanEquality.nansEqual}; + return 31 * kind.hashCode() + Arrays.hashCode(configs); + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } else if (other instanceof CollectSetAggregation) { + 
CollectSetAggregation o = (CollectSetAggregation) other; + return o.nullPolicy == this.nullPolicy && + o.nullEquality == this.nullEquality && + o.nanEquality == this.nanEquality; + } + return false; + } + } + protected final Kind kind; protected Aggregation(Kind kind) { @@ -593,18 +657,57 @@ public static Aggregation rowNumber() { /** * Collect the values into a list. nulls will be skipped. + * WARNING: This method is deprecated, please use collectList as instead. */ + @Deprecated public static Aggregation collect() { - return collect(NullPolicy.EXCLUDE); + return collectList(); } /** * Collect the values into a list. - * @param nullPolicy INCLUDE if nulls should be included in the aggregation or EXCLUDE if they - * should be skipped. + * WARNING: This method is deprecated, please use collectList as instead. + * + * @param nullPolicy Indicates whether to include/exclude nulls during collection. */ + @Deprecated public static Aggregation collect(NullPolicy nullPolicy) { - return new CollectAggregation(nullPolicy); + return collectList(nullPolicy); + } + + /** + * Collect the values into a list. nulls will be skipped. + */ + public static Aggregation collectList() { + return collectList(NullPolicy.EXCLUDE); + } + + /** + * Collect the values into a list. + * + * @param nullPolicy Indicates whether to include/exclude nulls during collection. + */ + public static Aggregation collectList(NullPolicy nullPolicy) { + return new CollectListAggregation(nullPolicy); + } + + /** + * Collect the values into a set. All null values will be excluded. And all nan values are regarded as + * unique instances. + */ + public static Aggregation collectSet() { + return new CollectSetAggregation(NullPolicy.EXCLUDE, NullEquality.UNEQUAL, NanEquality.UNEQUAL); + } + + /** + * Collect the values into a set. + * + * @param nullPolicy Indicates whether to include/exclude nulls during collection. 
+ * @param nullEquality Flag to specify whether null entries within each list should be considered equal. + * @param nanEquality Flag to specify whether NaN values in floating point column should be considered equal. + */ + public static Aggregation collectSet(NullPolicy nullPolicy, NullEquality nullEquality, NanEquality nanEquality) { + return new CollectSetAggregation(nullPolicy, nullEquality, nanEquality); } /** @@ -675,7 +778,12 @@ public static Aggregation lag(int offset, ColumnVector defaultOutput) { private static native long createLeadLagAgg(int kind, int offset); /** - * Create a collect aggregation including nulls or not. + * Create a collect list aggregation including nulls or not. + */ + private static native long createCollectListAgg(boolean includeNulls); + + /** + * Create a collect set aggregation. */ - private static native long createCollectAgg(boolean includeNulls); + private static native long createCollectSetAgg(boolean includeNulls, boolean nullsEqual, boolean nansEqual); } diff --git a/java/src/main/native/src/AggregationJni.cpp b/java/src/main/native/src/AggregationJni.cpp index c5184111edf..10839889d46 100644 --- a/java/src/main/native/src/AggregationJni.cpp +++ b/java/src/main/native/src/AggregationJni.cpp @@ -186,10 +186,10 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Aggregation_createLeadLagAgg(JNIEnv std::unique_ptr ret; // These numbers come from Aggregation.java and must stay in sync switch (kind) { - case 19: // LEAD + case 20: // LEAD ret = cudf::make_lead_aggregation(offset); break; - case 20: // LAG + case 21: // LAG ret = cudf::make_lag_aggregation(offset); break; default: throw std::logic_error("Unsupported Lead/Lag Aggregation Operation"); @@ -199,9 +199,9 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Aggregation_createLeadLagAgg(JNIEnv CATCH_STD(env, 0); } -JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Aggregation_createCollectAgg(JNIEnv *env, - jclass class_object, - jboolean include_nulls) { +JNIEXPORT jlong JNICALL 
Java_ai_rapids_cudf_Aggregation_createCollectListAgg(JNIEnv *env, + jclass class_object, + jboolean include_nulls) { try { cudf::jni::auto_set_device(env); cudf::null_policy policy = @@ -212,4 +212,25 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Aggregation_createCollectAgg(JNIEnv CATCH_STD(env, 0); } +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Aggregation_createCollectSetAgg(JNIEnv *env, + jclass class_object, + jboolean include_nulls, + jboolean nulls_equal, + jboolean nans_equal) { + try { + cudf::jni::auto_set_device(env); + cudf::null_policy null_policy = + include_nulls ? cudf::null_policy::INCLUDE : cudf::null_policy::EXCLUDE; + cudf::null_equality null_equality = + nulls_equal ? cudf::null_equality::EQUAL : cudf::null_equality::UNEQUAL; + cudf::nan_equality nan_equality = + nans_equal ? cudf::nan_equality::ALL_EQUAL : cudf::nan_equality::UNEQUAL; + std::unique_ptr ret = cudf::make_collect_set_aggregation(null_policy, + null_equality, + nan_equality); + return reinterpret_cast(ret.release()); + } + CATCH_STD(env, 0); +} + } // extern "C" diff --git a/java/src/test/java/ai/rapids/cudf/TableTest.java b/java/src/test/java/ai/rapids/cudf/TableTest.java index 670ec585da3..ded4b606da2 100644 --- a/java/src/test/java/ai/rapids/cudf/TableTest.java +++ b/java/src/test/java/ai/rapids/cudf/TableTest.java @@ -2945,9 +2945,9 @@ void testWindowingRowNumber() { } @Test - void testWindowingCollect() { - Aggregation aggCollectWithNulls = Aggregation.collect(Aggregation.NullPolicy.INCLUDE); - Aggregation aggCollect = Aggregation.collect(); + void testWindowingCollectList() { + Aggregation aggCollectWithNulls = Aggregation.collectList(Aggregation.NullPolicy.INCLUDE); + Aggregation aggCollect = Aggregation.collectList(); WindowOptions winOpts = WindowOptions.builder() .minPeriods(1) .window(2, 1).build(); @@ -4403,6 +4403,88 @@ void testGroupByContiguousSplitGroups() { } } + @Test + void testGroupByCollectListIncludeNulls() { + try (Table input = new Table.TestBuilder() + 
.column(1, 1, 1, 1, 2, 2, 2, 2, 3, 4) + .column(null, 13, null, 12, 14, null, 15, null, null, 0) + .build(); + Table expected = new Table.TestBuilder() + .column(1, 2, 3, 4) + .column(new ListType(false, new BasicType(true, DType.INT32)), + Arrays.asList(null, 13, null, 12), + Arrays.asList(14, null, 15, null), + Arrays.asList((Integer) null), + Arrays.asList(0)) + .build(); + Table found = input.groupBy(0).aggregate( + Aggregation.collectList(Aggregation.NullPolicy.INCLUDE).onColumn(1))) { + assertTablesAreEqual(expected, found); + } + } + + @Test + void testGroupByCollectSetIncludeNulls() { + // test with null unequal and nan unequal + Aggregation collectSet = Aggregation.collectSet(Aggregation.NullPolicy.INCLUDE, + Aggregation.NullEquality.UNEQUAL, Aggregation.NanEquality.UNEQUAL); + try (Table input = new Table.TestBuilder() + .column(1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4) + .column(null, 13, null, 13, 14, null, 15, null, 4, 1, 1, 4, 0, 0, 0, 0) + .build(); + Table expected = new Table.TestBuilder() + .column(1, 2, 3, 4) + .column(new ListType(false, new BasicType(true, DType.INT32)), + Arrays.asList(13, null, null), Arrays.asList(14, 15, null, null), + Arrays.asList(1, 4), Arrays.asList(0)) + .build(); + Table found = input.groupBy(0).aggregate(collectSet.onColumn(1))) { + assertTablesAreEqual(expected, found); + } + // test with null equal and nan unequal + collectSet = Aggregation.collectSet(Aggregation.NullPolicy.INCLUDE, + Aggregation.NullEquality.EQUAL, Aggregation.NanEquality.UNEQUAL); + try (Table input = new Table.TestBuilder() + .column(1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4) + .column(null, 13.0, null, 13.0, + 14.1, Double.NaN, 13.9, Double.NaN, + Double.NaN, null, 1.0, null, + null, null, null, null) + .build(); + Table expected = new Table.TestBuilder() + .column(1, 2, 3, 4) + .column(new ListType(false, new BasicType(true, DType.FLOAT64)), + Arrays.asList(13.0, null), + Arrays.asList(13.9, 14.1, Double.NaN, Double.NaN), + 
Arrays.asList(1.0, null, Double.NaN), + Arrays.asList((Integer) null)) + .build(); + Table found = input.groupBy(0).aggregate(collectSet.onColumn(1))) { + assertTablesAreEqual(expected, found); + } + // test with null equal and nan equal + collectSet = Aggregation.collectSet(Aggregation.NullPolicy.INCLUDE, + Aggregation.NullEquality.EQUAL, Aggregation.NanEquality.ALL_EQUAL); + try (Table input = new Table.TestBuilder() + .column(1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4) + .column(null, 13.0, null, 13.0, + 14.1, Double.NaN, 13.9, Double.NaN, + 0.0, 0.0, 0.00, 0.0, + Double.NaN, Double.NaN, null, null) + .build(); + Table expected = new Table.TestBuilder() + .column(1, 2, 3, 4) + .column(new ListType(false, new BasicType(true, DType.FLOAT64)), + Arrays.asList(13.0, null), + Arrays.asList(13.9, 14.1, Double.NaN), + Arrays.asList(0.0), + Arrays.asList((Integer) null, Double.NaN)) + .build(); + Table found = input.groupBy(0).aggregate(collectSet.onColumn(1))) { + assertTablesAreEqual(expected, found); + } + } + @Test void testRowBitCount() { try (Table t = new Table.TestBuilder() From 2944c470ddec7d28e1c4ba352318fec98be3e40a Mon Sep 17 00:00:00 2001 From: Alfred Xu Date: Tue, 27 Apr 2021 17:07:14 +0800 Subject: [PATCH 2/7] Update java/src/main/java/ai/rapids/cudf/Aggregation.java Co-authored-by: Nghia Truong --- java/src/main/java/ai/rapids/cudf/Aggregation.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/java/src/main/java/ai/rapids/cudf/Aggregation.java b/java/src/main/java/ai/rapids/cudf/Aggregation.java index ddfc86c8aaf..db40075b197 100644 --- a/java/src/main/java/ai/rapids/cudf/Aggregation.java +++ b/java/src/main/java/ai/rapids/cudf/Aggregation.java @@ -93,7 +93,7 @@ public enum NullEquality { /* * This is analogous to the native 'nan_equality'. 
*/ - public enum NanEquality { + public enum NaNEquality { UNEQUAL(false), ALL_EQUAL(true); From fc62d719bce9860d2dee5665bd7bce8855c2aede Mon Sep 17 00:00:00 2001 From: sperlingxx Date: Tue, 27 Apr 2021 18:10:50 +0800 Subject: [PATCH 3/7] some refinements Signed-off-by: sperlingxx --- .../main/java/ai/rapids/cudf/Aggregation.java | 54 ++++++++------- java/src/main/native/src/AggregationJni.cpp | 69 ++++++++----------- .../test/java/ai/rapids/cudf/TableTest.java | 6 +- 3 files changed, 59 insertions(+), 70 deletions(-) diff --git a/java/src/main/java/ai/rapids/cudf/Aggregation.java b/java/src/main/java/ai/rapids/cudf/Aggregation.java index db40075b197..f02c8e85c5d 100644 --- a/java/src/main/java/ai/rapids/cudf/Aggregation.java +++ b/java/src/main/java/ai/rapids/cudf/Aggregation.java @@ -40,26 +40,27 @@ enum Kind { PRODUCT(1), MIN(2), MAX(3), - COUNT(4), - ANY(5), - ALL(6), - SUM_OF_SQUARES(7), - MEAN(8), - VARIANCE(9), // This can take a delta degrees of freedom - STD(10), // This can take a delta degrees of freedom - MEDIAN(11), - QUANTILE(12), - ARGMAX(13), - ARGMIN(14), - NUNIQUE(15), - NTH_ELEMENT(16), - ROW_NUMBER(17), - COLLECT_LIST(18), - COLLECT_SET(19), - LEAD(20), - LAG(21), - PTX(22), - CUDA(23); + COUNT_VALID(4), + COUNT_ALL(5), + ANY(6), + ALL(7), + SUM_OF_SQUARES(8), + MEAN(9), + VARIANCE(10), // This can take a delta degrees of freedom + STD(11), // This can take a delta degrees of freedom + MEDIAN(12), + QUANTILE(13), + ARGMAX(14), + ARGMIN(15), + NUNIQUE(16), + NTH_ELEMENT(17), + ROW_NUMBER(18), + COLLECT_LIST(19), + COLLECT_SET(20), + LEAD(21), + LAG(22), + PTX(23), + CUDA(24); final int nativeId; @@ -97,7 +98,7 @@ public enum NaNEquality { UNEQUAL(false), ALL_EQUAL(true); - NanEquality(boolean nansEqual) { this.nansEqual = nansEqual; } + NaNEquality(boolean nansEqual) { this.nansEqual = nansEqual; } final boolean nansEqual; } @@ -338,9 +339,9 @@ public boolean equals(Object other) { private static final class CollectSetAggregation extends 
Aggregation { private final NullPolicy nullPolicy; private final NullEquality nullEquality; - private final NanEquality nanEquality; + private final NaNEquality nanEquality; - public CollectSetAggregation(NullPolicy nullPolicy, NullEquality nullEquality, NanEquality nanEquality) { + public CollectSetAggregation(NullPolicy nullPolicy, NullEquality nullEquality, NaNEquality nanEquality) { super(Kind.COLLECT_SET); this.nullPolicy = nullPolicy; this.nullEquality = nullEquality; @@ -477,7 +478,8 @@ public static Aggregation count(boolean includeNulls) { * should be counted. */ public static Aggregation count(NullPolicy nullPolicy) { - return new CountLikeAggregation(Kind.COUNT, nullPolicy); + Aggregation.Kind kind = nullPolicy.includeNulls ? Kind.COUNT_ALL : Kind.COUNT_VALID; + return new CountLikeAggregation(kind, nullPolicy); } /** @@ -696,7 +698,7 @@ public static Aggregation collectList(NullPolicy nullPolicy) { * unique instances. */ public static Aggregation collectSet() { - return new CollectSetAggregation(NullPolicy.EXCLUDE, NullEquality.UNEQUAL, NanEquality.UNEQUAL); + return new CollectSetAggregation(NullPolicy.EXCLUDE, NullEquality.UNEQUAL, NaNEquality.UNEQUAL); } /** @@ -706,7 +708,7 @@ public static Aggregation collectSet() { * @param nullEquality Flag to specify whether null entries within each list should be considered equal. * @param nanEquality Flag to specify whether NaN values in floating point column should be considered equal. 
*/ - public static Aggregation collectSet(NullPolicy nullPolicy, NullEquality nullEquality, NanEquality nanEquality) { + public static Aggregation collectSet(NullPolicy nullPolicy, NullEquality nullEquality, NaNEquality nanEquality) { return new CollectSetAggregation(nullPolicy, nullEquality, nanEquality); } diff --git a/java/src/main/native/src/AggregationJni.cpp b/java/src/main/native/src/AggregationJni.cpp index 10839889d46..18f3975eb1f 100644 --- a/java/src/main/native/src/AggregationJni.cpp +++ b/java/src/main/native/src/AggregationJni.cpp @@ -34,58 +34,47 @@ JNIEXPORT void JNICALL Java_ai_rapids_cudf_Aggregation_close(JNIEnv *env, JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Aggregation_createNoParamAgg(JNIEnv *env, jclass class_object, jint kind) { + try { cudf::jni::auto_set_device(env); std::unique_ptr ret; - // These numbers come from Aggregation.java and must stay in sync - switch (kind) { - case 0: // SUM + switch (static_cast(kind)) { + case cudf::aggregation::SUM: ret = cudf::make_sum_aggregation(); break; - case 1: // PRODUCT + case cudf::aggregation::PRODUCT: ret = cudf::make_product_aggregation(); break; - case 2: // MIN + case cudf::aggregation::MIN: ret = cudf::make_min_aggregation(); break; - case 3: // MAX + case cudf::aggregation::MAX: ret = cudf::make_max_aggregation(); break; - //case 4 COUNT - case 5: // ANY + case cudf::aggregation::ANY: ret = cudf::make_any_aggregation(); break; - case 6: // ALL + case cudf::aggregation::ALL: ret = cudf::make_all_aggregation(); break; - case 7: // SUM_OF_SQUARES + case cudf::aggregation::SUM_OF_SQUARES: ret = cudf::make_sum_of_squares_aggregation(); break; - case 8: // MEAN + case cudf::aggregation::MEAN: ret = cudf::make_mean_aggregation(); break; - // case 9: VARIANCE - // case 10: STD - case 11: // MEDIAN + case cudf::aggregation::MEDIAN: ret = cudf::make_median_aggregation(); break; - // case 12: QUANTILE - case 13: // ARGMAX + case cudf::aggregation::ARGMAX: ret = cudf::make_argmax_aggregation(); 
break; - case 14: // ARGMIN + case cudf::aggregation::ARGMIN: ret = cudf::make_argmin_aggregation(); break; - // case 15: NUNIQUE - // case 16: NTH_ELEMENT - case 17: // ROW_NUMBER + case cudf::aggregation::ROW_NUMBER: ret = cudf::make_row_number_aggregation(); break; - // case 18: COLLECT - // case 19: LEAD - // case 20: LAG - // case 21: PTX - // case 22: CUDA default: throw std::logic_error("Unsupported No Parameter Aggregation Operation"); } @@ -117,12 +106,11 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Aggregation_createDdofAgg(JNIEnv *en cudf::jni::auto_set_device(env); std::unique_ptr ret; - // These numbers come from Aggregation.java and must stay in sync - switch (kind) { - case 9: // VARIANCE + switch (static_cast(kind)) { + case cudf::aggregation::VARIANCE: ret = cudf::make_variance_aggregation(ddof); break; - case 10: // STD + case cudf::aggregation::STD: ret = cudf::make_std_aggregation(ddof); break; default: throw std::logic_error("Unsupported DDOF Aggregation Operation"); @@ -139,16 +127,16 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Aggregation_createCountLikeAgg(JNIEn try { cudf::jni::auto_set_device(env); - cudf::null_policy policy = - include_nulls ? cudf::null_policy::INCLUDE : cudf::null_policy::EXCLUDE; std::unique_ptr ret; - // These numbers come from Aggregation.java and must stay in sync - switch (kind) { - case 4: // COUNT - ret = cudf::make_count_aggregation(policy); + switch (static_cast(kind)) { + case cudf::aggregation::COUNT_VALID: + ret = cudf::make_count_aggregation(cudf::null_policy::EXCLUDE); + break; + case cudf::aggregation::COUNT_ALL: + ret = cudf::make_count_aggregation(cudf::null_policy::INCLUDE); break; - case 15: // NUNIQUE - ret = cudf::make_nunique_aggregation(policy); + case cudf::aggregation::NUNIQUE: + ret = cudf::make_nunique_aggregation(include_nulls ? 
cudf::null_policy::INCLUDE : cudf::null_policy::EXCLUDE); break; default: throw std::logic_error("Unsupported Count Like Aggregation Operation"); } @@ -184,12 +172,11 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Aggregation_createLeadLagAgg(JNIEnv cudf::jni::auto_set_device(env); std::unique_ptr ret; - // These numbers come from Aggregation.java and must stay in sync - switch (kind) { - case 20: // LEAD + switch (static_cast(kind)) { + case cudf::aggregation::LEAD: ret = cudf::make_lead_aggregation(offset); break; - case 21: // LAG + case cudf::aggregation::LAG: ret = cudf::make_lag_aggregation(offset); break; default: throw std::logic_error("Unsupported Lead/Lag Aggregation Operation"); diff --git a/java/src/test/java/ai/rapids/cudf/TableTest.java b/java/src/test/java/ai/rapids/cudf/TableTest.java index ded4b606da2..821b61e729f 100644 --- a/java/src/test/java/ai/rapids/cudf/TableTest.java +++ b/java/src/test/java/ai/rapids/cudf/TableTest.java @@ -4427,7 +4427,7 @@ void testGroupByCollectListIncludeNulls() { void testGroupByCollectSetIncludeNulls() { // test with null unequal and nan unequal Aggregation collectSet = Aggregation.collectSet(Aggregation.NullPolicy.INCLUDE, - Aggregation.NullEquality.UNEQUAL, Aggregation.NanEquality.UNEQUAL); + Aggregation.NullEquality.UNEQUAL, Aggregation.NaNEquality.UNEQUAL); try (Table input = new Table.TestBuilder() .column(1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4) .column(null, 13, null, 13, 14, null, 15, null, 4, 1, 1, 4, 0, 0, 0, 0) @@ -4443,7 +4443,7 @@ void testGroupByCollectSetIncludeNulls() { } // test with null equal and nan unequal collectSet = Aggregation.collectSet(Aggregation.NullPolicy.INCLUDE, - Aggregation.NullEquality.EQUAL, Aggregation.NanEquality.UNEQUAL); + Aggregation.NullEquality.EQUAL, Aggregation.NaNEquality.UNEQUAL); try (Table input = new Table.TestBuilder() .column(1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4) .column(null, 13.0, null, 13.0, @@ -4464,7 +4464,7 @@ void 
testGroupByCollectSetIncludeNulls() { } // test with null equal and nan equal collectSet = Aggregation.collectSet(Aggregation.NullPolicy.INCLUDE, - Aggregation.NullEquality.EQUAL, Aggregation.NanEquality.ALL_EQUAL); + Aggregation.NullEquality.EQUAL, Aggregation.NaNEquality.ALL_EQUAL); try (Table input = new Table.TestBuilder() .column(1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4) .column(null, 13.0, null, 13.0, From 402003d894b1c34dc85b5226f9fbd4e6b3d51fba Mon Sep 17 00:00:00 2001 From: sperlingxx Date: Fri, 30 Apr 2021 14:54:21 +0800 Subject: [PATCH 4/7] Revert "some refinements" This reverts commit fc62d719bce9860d2dee5665bd7bce8855c2aede. --- .../main/java/ai/rapids/cudf/Aggregation.java | 54 +++++++-------- java/src/main/native/src/AggregationJni.cpp | 69 +++++++++++-------- .../test/java/ai/rapids/cudf/TableTest.java | 6 +- 3 files changed, 70 insertions(+), 59 deletions(-) diff --git a/java/src/main/java/ai/rapids/cudf/Aggregation.java b/java/src/main/java/ai/rapids/cudf/Aggregation.java index f02c8e85c5d..db40075b197 100644 --- a/java/src/main/java/ai/rapids/cudf/Aggregation.java +++ b/java/src/main/java/ai/rapids/cudf/Aggregation.java @@ -40,27 +40,26 @@ enum Kind { PRODUCT(1), MIN(2), MAX(3), - COUNT_VALID(4), - COUNT_ALL(5), - ANY(6), - ALL(7), - SUM_OF_SQUARES(8), - MEAN(9), - VARIANCE(10), // This can take a delta degrees of freedom - STD(11), // This can take a delta degrees of freedom - MEDIAN(12), - QUANTILE(13), - ARGMAX(14), - ARGMIN(15), - NUNIQUE(16), - NTH_ELEMENT(17), - ROW_NUMBER(18), - COLLECT_LIST(19), - COLLECT_SET(20), - LEAD(21), - LAG(22), - PTX(23), - CUDA(24); + COUNT(4), + ANY(5), + ALL(6), + SUM_OF_SQUARES(7), + MEAN(8), + VARIANCE(9), // This can take a delta degrees of freedom + STD(10), // This can take a delta degrees of freedom + MEDIAN(11), + QUANTILE(12), + ARGMAX(13), + ARGMIN(14), + NUNIQUE(15), + NTH_ELEMENT(16), + ROW_NUMBER(17), + COLLECT_LIST(18), + COLLECT_SET(19), + LEAD(20), + LAG(21), + PTX(22), + CUDA(23); final 
int nativeId; @@ -98,7 +97,7 @@ public enum NaNEquality { UNEQUAL(false), ALL_EQUAL(true); - NaNEquality(boolean nansEqual) { this.nansEqual = nansEqual; } + NanEquality(boolean nansEqual) { this.nansEqual = nansEqual; } final boolean nansEqual; } @@ -339,9 +338,9 @@ public boolean equals(Object other) { private static final class CollectSetAggregation extends Aggregation { private final NullPolicy nullPolicy; private final NullEquality nullEquality; - private final NaNEquality nanEquality; + private final NanEquality nanEquality; - public CollectSetAggregation(NullPolicy nullPolicy, NullEquality nullEquality, NaNEquality nanEquality) { + public CollectSetAggregation(NullPolicy nullPolicy, NullEquality nullEquality, NanEquality nanEquality) { super(Kind.COLLECT_SET); this.nullPolicy = nullPolicy; this.nullEquality = nullEquality; @@ -478,8 +477,7 @@ public static Aggregation count(boolean includeNulls) { * should be counted. */ public static Aggregation count(NullPolicy nullPolicy) { - Aggregation.Kind kind = nullPolicy.includeNulls ? Kind.COUNT_ALL : Kind.COUNT_VALID; - return new CountLikeAggregation(kind, nullPolicy); + return new CountLikeAggregation(Kind.COUNT, nullPolicy); } /** @@ -698,7 +696,7 @@ public static Aggregation collectList(NullPolicy nullPolicy) { * unique instances. */ public static Aggregation collectSet() { - return new CollectSetAggregation(NullPolicy.EXCLUDE, NullEquality.UNEQUAL, NaNEquality.UNEQUAL); + return new CollectSetAggregation(NullPolicy.EXCLUDE, NullEquality.UNEQUAL, NanEquality.UNEQUAL); } /** @@ -708,7 +706,7 @@ public static Aggregation collectSet() { * @param nullEquality Flag to specify whether null entries within each list should be considered equal. * @param nanEquality Flag to specify whether NaN values in floating point column should be considered equal. 
*/ - public static Aggregation collectSet(NullPolicy nullPolicy, NullEquality nullEquality, NaNEquality nanEquality) { + public static Aggregation collectSet(NullPolicy nullPolicy, NullEquality nullEquality, NanEquality nanEquality) { return new CollectSetAggregation(nullPolicy, nullEquality, nanEquality); } diff --git a/java/src/main/native/src/AggregationJni.cpp b/java/src/main/native/src/AggregationJni.cpp index 18f3975eb1f..10839889d46 100644 --- a/java/src/main/native/src/AggregationJni.cpp +++ b/java/src/main/native/src/AggregationJni.cpp @@ -34,47 +34,58 @@ JNIEXPORT void JNICALL Java_ai_rapids_cudf_Aggregation_close(JNIEnv *env, JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Aggregation_createNoParamAgg(JNIEnv *env, jclass class_object, jint kind) { - try { cudf::jni::auto_set_device(env); std::unique_ptr ret; - switch (static_cast(kind)) { - case cudf::aggregation::SUM: + // These numbers come from Aggregation.java and must stay in sync + switch (kind) { + case 0: // SUM ret = cudf::make_sum_aggregation(); break; - case cudf::aggregation::PRODUCT: + case 1: // PRODUCT ret = cudf::make_product_aggregation(); break; - case cudf::aggregation::MIN: + case 2: // MIN ret = cudf::make_min_aggregation(); break; - case cudf::aggregation::MAX: + case 3: // MAX ret = cudf::make_max_aggregation(); break; - case cudf::aggregation::ANY: + //case 4 COUNT + case 5: // ANY ret = cudf::make_any_aggregation(); break; - case cudf::aggregation::ALL: + case 6: // ALL ret = cudf::make_all_aggregation(); break; - case cudf::aggregation::SUM_OF_SQUARES: + case 7: // SUM_OF_SQUARES ret = cudf::make_sum_of_squares_aggregation(); break; - case cudf::aggregation::MEAN: + case 8: // MEAN ret = cudf::make_mean_aggregation(); break; - case cudf::aggregation::MEDIAN: + // case 9: VARIANCE + // case 10: STD + case 11: // MEDIAN ret = cudf::make_median_aggregation(); break; - case cudf::aggregation::ARGMAX: + // case 12: QUANTILE + case 13: // ARGMAX ret = cudf::make_argmax_aggregation(); 
break; - case cudf::aggregation::ARGMIN: + case 14: // ARGMIN ret = cudf::make_argmin_aggregation(); break; - case cudf::aggregation::ROW_NUMBER: + // case 15: NUNIQUE + // case 16: NTH_ELEMENT + case 17: // ROW_NUMBER ret = cudf::make_row_number_aggregation(); break; + // case 18: COLLECT + // case 19: LEAD + // case 20: LAG + // case 21: PTX + // case 22: CUDA default: throw std::logic_error("Unsupported No Parameter Aggregation Operation"); } @@ -106,11 +117,12 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Aggregation_createDdofAgg(JNIEnv *en cudf::jni::auto_set_device(env); std::unique_ptr ret; - switch (static_cast(kind)) { - case cudf::aggregation::VARIANCE: + // These numbers come from Aggregation.java and must stay in sync + switch (kind) { + case 9: // VARIANCE ret = cudf::make_variance_aggregation(ddof); break; - case cudf::aggregation::STD: + case 10: // STD ret = cudf::make_std_aggregation(ddof); break; default: throw std::logic_error("Unsupported DDOF Aggregation Operation"); @@ -127,16 +139,16 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Aggregation_createCountLikeAgg(JNIEn try { cudf::jni::auto_set_device(env); + cudf::null_policy policy = + include_nulls ? cudf::null_policy::INCLUDE : cudf::null_policy::EXCLUDE; std::unique_ptr ret; - switch (static_cast(kind)) { - case cudf::aggregation::COUNT_VALID: - ret = cudf::make_count_aggregation(cudf::null_policy::EXCLUDE); - break; - case cudf::aggregation::COUNT_ALL: - ret = cudf::make_count_aggregation(cudf::null_policy::INCLUDE); + // These numbers come from Aggregation.java and must stay in sync + switch (kind) { + case 4: // COUNT + ret = cudf::make_count_aggregation(policy); break; - case cudf::aggregation::NUNIQUE: - ret = cudf::make_nunique_aggregation(include_nulls ? 
cudf::null_policy::INCLUDE : cudf::null_policy::EXCLUDE); + case 15: // NUNIQUE + ret = cudf::make_nunique_aggregation(policy); break; default: throw std::logic_error("Unsupported Count Like Aggregation Operation"); } @@ -172,11 +184,12 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Aggregation_createLeadLagAgg(JNIEnv cudf::jni::auto_set_device(env); std::unique_ptr ret; - switch (static_cast(kind)) { - case cudf::aggregation::LEAD: + // These numbers come from Aggregation.java and must stay in sync + switch (kind) { + case 20: // LEAD ret = cudf::make_lead_aggregation(offset); break; - case cudf::aggregation::LAG: + case 21: // LAG ret = cudf::make_lag_aggregation(offset); break; default: throw std::logic_error("Unsupported Lead/Lag Aggregation Operation"); diff --git a/java/src/test/java/ai/rapids/cudf/TableTest.java b/java/src/test/java/ai/rapids/cudf/TableTest.java index 821b61e729f..ded4b606da2 100644 --- a/java/src/test/java/ai/rapids/cudf/TableTest.java +++ b/java/src/test/java/ai/rapids/cudf/TableTest.java @@ -4427,7 +4427,7 @@ void testGroupByCollectListIncludeNulls() { void testGroupByCollectSetIncludeNulls() { // test with null unequal and nan unequal Aggregation collectSet = Aggregation.collectSet(Aggregation.NullPolicy.INCLUDE, - Aggregation.NullEquality.UNEQUAL, Aggregation.NaNEquality.UNEQUAL); + Aggregation.NullEquality.UNEQUAL, Aggregation.NanEquality.UNEQUAL); try (Table input = new Table.TestBuilder() .column(1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4) .column(null, 13, null, 13, 14, null, 15, null, 4, 1, 1, 4, 0, 0, 0, 0) @@ -4443,7 +4443,7 @@ void testGroupByCollectSetIncludeNulls() { } // test with null equal and nan unequal collectSet = Aggregation.collectSet(Aggregation.NullPolicy.INCLUDE, - Aggregation.NullEquality.EQUAL, Aggregation.NaNEquality.UNEQUAL); + Aggregation.NullEquality.EQUAL, Aggregation.NanEquality.UNEQUAL); try (Table input = new Table.TestBuilder() .column(1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4) .column(null, 
13.0, null, 13.0, @@ -4464,7 +4464,7 @@ void testGroupByCollectSetIncludeNulls() { } // test with null equal and nan equal collectSet = Aggregation.collectSet(Aggregation.NullPolicy.INCLUDE, - Aggregation.NullEquality.EQUAL, Aggregation.NaNEquality.ALL_EQUAL); + Aggregation.NullEquality.EQUAL, Aggregation.NanEquality.ALL_EQUAL); try (Table input = new Table.TestBuilder() .column(1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4) .column(null, 13.0, null, 13.0, From c01275a9755fbf743135322b7e248818667aec77 Mon Sep 17 00:00:00 2001 From: sperlingxx Date: Fri, 30 Apr 2021 17:06:00 +0800 Subject: [PATCH 5/7] fix Signed-off-by: sperlingxx --- .../main/java/ai/rapids/cudf/Aggregation.java | 18 +++++++++--------- .../test/java/ai/rapids/cudf/TableTest.java | 10 +++++----- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/java/src/main/java/ai/rapids/cudf/Aggregation.java b/java/src/main/java/ai/rapids/cudf/Aggregation.java index db40075b197..0a0b55e0181 100644 --- a/java/src/main/java/ai/rapids/cudf/Aggregation.java +++ b/java/src/main/java/ai/rapids/cudf/Aggregation.java @@ -97,7 +97,7 @@ public enum NaNEquality { UNEQUAL(false), ALL_EQUAL(true); - NanEquality(boolean nansEqual) { this.nansEqual = nansEqual; } + NaNEquality(boolean nansEqual) { this.nansEqual = nansEqual; } final boolean nansEqual; } @@ -338,9 +338,9 @@ public boolean equals(Object other) { private static final class CollectSetAggregation extends Aggregation { private final NullPolicy nullPolicy; private final NullEquality nullEquality; - private final NanEquality nanEquality; + private final NaNEquality nanEquality; - public CollectSetAggregation(NullPolicy nullPolicy, NullEquality nullEquality, NanEquality nanEquality) { + public CollectSetAggregation(NullPolicy nullPolicy, NullEquality nullEquality, NaNEquality nanEquality) { super(Kind.COLLECT_SET); this.nullPolicy = nullPolicy; this.nullEquality = nullEquality; @@ -657,7 +657,7 @@ public static Aggregation rowNumber() { /** * Collect the 
values into a list. nulls will be skipped. - * WARNING: This method is deprecated, please use collectList as instead. + * @deprecated please use collectList as instead. */ @Deprecated public static Aggregation collect() { @@ -666,7 +666,7 @@ public static Aggregation collect() { /** * Collect the values into a list. - * WARNING: This method is deprecated, please use collectList as instead. + * @deprecated please use collectList as instead. * * @param nullPolicy Indicates whether to include/exclude nulls during collection. */ @@ -676,7 +676,7 @@ public static Aggregation collect(NullPolicy nullPolicy) { } /** - * Collect the values into a list. nulls will be skipped. + * Collect the values into a list. Nulls will be skipped. */ public static Aggregation collectList() { return collectList(NullPolicy.EXCLUDE); @@ -692,11 +692,11 @@ public static Aggregation collectList(NullPolicy nullPolicy) { } /** - * Collect the values into a set. All null values will be excluded. And all nan values are regarded as + * Collect the values into a set. All null values will be excluded, qnd all nan values are regarded as * unique instances. */ public static Aggregation collectSet() { - return new CollectSetAggregation(NullPolicy.EXCLUDE, NullEquality.UNEQUAL, NanEquality.UNEQUAL); + return new CollectSetAggregation(NullPolicy.EXCLUDE, NullEquality.UNEQUAL, NaNEquality.UNEQUAL); } /** @@ -706,7 +706,7 @@ public static Aggregation collectSet() { * @param nullEquality Flag to specify whether null entries within each list should be considered equal. * @param nanEquality Flag to specify whether NaN values in floating point column should be considered equal. 
*/ - public static Aggregation collectSet(NullPolicy nullPolicy, NullEquality nullEquality, NanEquality nanEquality) { + public static Aggregation collectSet(NullPolicy nullPolicy, NullEquality nullEquality, NaNEquality nanEquality) { return new CollectSetAggregation(nullPolicy, nullEquality, nanEquality); } diff --git a/java/src/test/java/ai/rapids/cudf/TableTest.java b/java/src/test/java/ai/rapids/cudf/TableTest.java index ded4b606da2..b398157983c 100644 --- a/java/src/test/java/ai/rapids/cudf/TableTest.java +++ b/java/src/test/java/ai/rapids/cudf/TableTest.java @@ -4427,7 +4427,7 @@ void testGroupByCollectListIncludeNulls() { void testGroupByCollectSetIncludeNulls() { // test with null unequal and nan unequal Aggregation collectSet = Aggregation.collectSet(Aggregation.NullPolicy.INCLUDE, - Aggregation.NullEquality.UNEQUAL, Aggregation.NanEquality.UNEQUAL); + Aggregation.NullEquality.UNEQUAL, Aggregation.NaNEquality.UNEQUAL); try (Table input = new Table.TestBuilder() .column(1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4) .column(null, 13, null, 13, 14, null, 15, null, 4, 1, 1, 4, 0, 0, 0, 0) @@ -4443,7 +4443,7 @@ void testGroupByCollectSetIncludeNulls() { } // test with null equal and nan unequal collectSet = Aggregation.collectSet(Aggregation.NullPolicy.INCLUDE, - Aggregation.NullEquality.EQUAL, Aggregation.NanEquality.UNEQUAL); + Aggregation.NullEquality.EQUAL, Aggregation.NaNEquality.UNEQUAL); try (Table input = new Table.TestBuilder() .column(1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4) .column(null, 13.0, null, 13.0, @@ -4456,7 +4456,7 @@ void testGroupByCollectSetIncludeNulls() { .column(new ListType(false, new BasicType(true, DType.FLOAT64)), Arrays.asList(13.0, null), Arrays.asList(13.9, 14.1, Double.NaN, Double.NaN), - Arrays.asList(1.0, null, Double.NaN), + Arrays.asList(1.0, Double.NaN, null), Arrays.asList((Integer) null)) .build(); Table found = input.groupBy(0).aggregate(collectSet.onColumn(1))) { @@ -4464,7 +4464,7 @@ void 
testGroupByCollectSetIncludeNulls() { } // test with null equal and nan equal collectSet = Aggregation.collectSet(Aggregation.NullPolicy.INCLUDE, - Aggregation.NullEquality.EQUAL, Aggregation.NanEquality.ALL_EQUAL); + Aggregation.NullEquality.EQUAL, Aggregation.NaNEquality.ALL_EQUAL); try (Table input = new Table.TestBuilder() .column(1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4) .column(null, 13.0, null, 13.0, @@ -4478,7 +4478,7 @@ void testGroupByCollectSetIncludeNulls() { Arrays.asList(13.0, null), Arrays.asList(13.9, 14.1, Double.NaN), Arrays.asList(0.0), - Arrays.asList((Integer) null, Double.NaN)) + Arrays.asList(Double.NaN, (Integer) null)) .build(); Table found = input.groupBy(0).aggregate(collectSet.onColumn(1))) { assertTablesAreEqual(expected, found); From c08c0ab223719022805359b1e73d67771d2592f0 Mon Sep 17 00:00:00 2001 From: sperlingxx Date: Fri, 30 Apr 2021 17:23:36 +0800 Subject: [PATCH 6/7] fix nit Signed-off-by: sperlingxx --- java/src/main/java/ai/rapids/cudf/Aggregation.java | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/java/src/main/java/ai/rapids/cudf/Aggregation.java b/java/src/main/java/ai/rapids/cudf/Aggregation.java index 0a0b55e0181..8dbe780a149 100644 --- a/java/src/main/java/ai/rapids/cudf/Aggregation.java +++ b/java/src/main/java/ai/rapids/cudf/Aggregation.java @@ -356,8 +356,10 @@ long createNativeInstance() { @Override public int hashCode() { - boolean[] configs = new boolean[]{nullPolicy.includeNulls, nullEquality.nullsEqual, nanEquality.nansEqual}; - return 31 * kind.hashCode() + Arrays.hashCode(configs); + return 31 * kind.hashCode() + + Boolean.hashCode(nullPolicy.includeNulls) + + Boolean.hashCode(nullEquality.nullsEqual) + + Boolean.hashCode(nanEquality.nansEqual); } @Override From 0d441b1264eadab8bb69f4727736a611a7df6a9c Mon Sep 17 00:00:00 2001 From: sperlingxx Date: Wed, 5 May 2021 11:34:48 +0800 Subject: [PATCH 7/7] fix nit Signed-off-by: sperlingxx --- 
java/src/main/java/ai/rapids/cudf/Aggregation.java | 4 ++-- java/src/main/native/src/AggregationJni.cpp | 11 ++++++----- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/java/src/main/java/ai/rapids/cudf/Aggregation.java b/java/src/main/java/ai/rapids/cudf/Aggregation.java index 8dbe780a149..35510372cc1 100644 --- a/java/src/main/java/ai/rapids/cudf/Aggregation.java +++ b/java/src/main/java/ai/rapids/cudf/Aggregation.java @@ -658,7 +658,7 @@ public static Aggregation rowNumber() { } /** - * Collect the values into a list. nulls will be skipped. + * Collect the values into a list. Nulls will be skipped. * @deprecated please use collectList as instead. */ @Deprecated @@ -694,7 +694,7 @@ public static Aggregation collectList(NullPolicy nullPolicy) { } /** - * Collect the values into a set. All null values will be excluded, qnd all nan values are regarded as + * Collect the values into a set. All null values will be excluded, and all nan values are regarded as * unique instances. */ public static Aggregation collectSet() { diff --git a/java/src/main/native/src/AggregationJni.cpp b/java/src/main/native/src/AggregationJni.cpp index 10839889d46..63c2c33202e 100644 --- a/java/src/main/native/src/AggregationJni.cpp +++ b/java/src/main/native/src/AggregationJni.cpp @@ -81,11 +81,12 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Aggregation_createNoParamAgg(JNIEnv case 17: // ROW_NUMBER ret = cudf::make_row_number_aggregation(); break; - // case 18: COLLECT - // case 19: LEAD - // case 20: LAG - // case 21: PTX - // case 22: CUDA + // case 18: COLLECT_LIST + // case 19: COLLECT_SET + // case 20: LEAD + // case 21: LAG + // case 22: PTX + // case 23: CUDA default: throw std::logic_error("Unsupported No Parameter Aggregation Operation"); }