Fix collect_set on struct type (#4996)
Correct the wrong NullEquality for collect_set, specifically for collect_set on struct type.

Signed-off-by: sperlingxx <lovedreamf@gmail.com>
sperlingxx authored Mar 22, 2022
1 parent ad2cc79 commit 7e0e66d
Showing 3 changed files with 9 additions and 25 deletions.
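
The user-visible bug, as the commit message suggests: if nulls compare as unequal during deduplication, collect_set over a struct column can keep duplicate structs whose fields are null. A minimal PySpark sketch of the scenario (hypothetical, not taken from this commit; it assumes the spark-rapids plugin is enabled and the session, schema, and data are made up):

```python
from pyspark.sql import SparkSession
import pyspark.sql.functions as f

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [(1, (None, 1)), (1, (None, 1)), (1, ('x', 2))],
    'a INT, b STRUCT<c0: STRING, c1: INT>')

# The two (NULL, 1) structs must collapse into one set element; with the
# wrong NullEquality, the GPU result could keep both copies.
df.groupby('a').agg(f.sort_array(f.collect_set('b'))).show(truncate=False)
```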
24 changes: 4 additions & 20 deletions integration_tests/src/main/python/hash_aggregate_test.py
@@ -578,6 +578,7 @@ def test_hash_reduction_pivot_without_nans(data_gen, conf):
 _repeat_agg_column_for_collect_list_op = [
     RepeatSeqGen(ArrayGen(int_gen), length=15),
     RepeatSeqGen(all_basic_struct_gen, length=15),
+    RepeatSeqGen(StructGen([['c0', all_basic_struct_gen]]), length=15),
     RepeatSeqGen(simple_string_to_string_map_gen, length=15)]

 _gen_data_for_collect_list_op = _full_gen_data_for_collect_op + [[
@@ -586,11 +587,8 @@ def test_hash_reduction_pivot_without_nans(data_gen, conf):

 _repeat_agg_column_for_collect_set_op = [
     RepeatSeqGen(all_basic_struct_gen, length=15),
-    RepeatSeqGen(StructGen([['child0', all_basic_struct_gen]]), length=15)]
-
-_gen_data_for_collect_set_op_for_unique_group_by_key = [[
-    ('a', LongRangeGen()),
-    ('b', value_gen)] for value_gen in _repeat_agg_column_for_collect_set_op]
+    RepeatSeqGen(StructGen([
+        ['c0', all_basic_struct_gen], ['c1', int_gen]]), length=15)]

 _gen_data_for_collect_set_op = [[
     ('a', RepeatSeqGen(LongGen(), length=20)),
@@ -654,25 +652,11 @@ def test_hash_groupby_collect_set(data_gen):
 @ignore_order(local=True)
 @incompat
 @pytest.mark.parametrize('data_gen', _gen_data_for_collect_set_op, ids=idfn)
-@pytest.mark.xfail(reason="the result order from collect-set can not be ensured for CPU and GPU."
-                          " We need to enable this after SortArray has supported on nested types."
-                          " See https://github.com/NVIDIA/spark-rapids/issues/3715")
 def test_hash_groupby_collect_set_on_nested_type(data_gen):
     assert_gpu_and_cpu_are_equal_collect(
         lambda spark: gen_df(spark, data_gen, length=100)
             .groupby('a')
-            .agg(f.sort_array(f.collect_set('b')), f.count('b')))
-
-# After https://github.com/NVIDIA/spark-rapids/issues/3715 is fixed, we should remove this test case
-@approximate_float
-@ignore_order(local=True)
-@incompat
-@pytest.mark.parametrize('data_gen', _gen_data_for_collect_set_op_for_unique_group_by_key, ids=idfn)
-def test_hash_groupby_collect_set_on_nested_type_for_unique_group_by(data_gen):
-    assert_gpu_and_cpu_are_equal_collect(
-        lambda spark: gen_df(spark, data_gen, length=100)
-            .groupby('a')
-            .agg(f.collect_set('b')))
+            .agg(f.sort_array(f.collect_set('b'))))

 @approximate_float
 @ignore_order(local=True)
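Both deletions above follow from SortArray gaining nested-type support (the issue cited in the removed xfail). collect_set guarantees nothing about element order, so the test must sort the collected array before comparing CPU and GPU output; now that the sort works on nested types, the xfail goes away. The unique-group-by-key workaround test is deleted for the same reason: its LongRangeGen key made every group a single row, so each collected set was a singleton and ordering never mattered, a dodge that is no longer needed.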
4 changes: 2 additions & 2 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -3222,8 +3222,8 @@ object GpuOverrides extends Logging {
     expr[CollectSet](
       "Collect a set of unique elements, not supported in reduction",
       // GpuCollectSet is not yet supported in Reduction context.
-      // Compared to CollectList, StructType is NOT in GpuCollectSet because underlying
-      // method drop_list_duplicates doesn't support nested types.
+      // Compared to CollectList, ArrayType and MapType are NOT supported in GpuCollectSet
+      // because underlying cuDF operator drop_list_duplicates doesn't support LIST type.
       ExprChecks.aggNotReduction(
         TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 +
           TypeSig.NULL + TypeSig.STRUCT),
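In practice, the TypeSig above means collect_set over a STRUCT child can now stay on the GPU, while an ARRAY or MAP child leaves the aggregate to Spark's CPU implementation. A hedged sketch (session, schema, and column names are made up; it assumes the plugin is enabled):

```python
from pyspark.sql import SparkSession
import pyspark.sql.functions as f

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [(1, ('x',), ['x'])],
    'a INT, struct_col STRUCT<c0: STRING>, array_col ARRAY<STRING>')

# STRUCT child is inside the advertised TypeSig, so this aggregate is
# GPU-eligible under the plugin.
df.groupby('a').agg(f.collect_set('struct_col'))

# ARRAY children are outside the TypeSig, so the plugin is expected to
# fall back to Spark's CPU CollectSet for this expression.
df.groupby('a').agg(f.collect_set('array_col'))
```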
@@ -17,7 +17,7 @@
 package org.apache.spark.sql.rapids

 import ai.rapids.cudf
-import ai.rapids.cudf.{Aggregation128Utils, BinaryOp, ColumnVector, DType, GroupByAggregation, GroupByScanAggregation, NullPolicy, ReductionAggregation, ReplacePolicy, RollingAggregation, RollingAggregationOnColumn, Scalar, ScanAggregation}
+import ai.rapids.cudf.{Aggregation128Utils, BinaryOp, ColumnVector, DType, GroupByAggregation, GroupByScanAggregation, NaNEquality, NullEquality, NullPolicy, ReductionAggregation, ReplacePolicy, RollingAggregation, RollingAggregationOnColumn, Scalar, ScanAggregation}
 import com.nvidia.spark.rapids._
 import com.nvidia.spark.rapids.shims.{GpuDeterministicFirstLastCollectShim, ShimExpression, ShimUnaryExpression}

@@ -372,15 +372,15 @@ class CudfCollectSet(override val dataType: DataType) extends CudfAggregate {
   override lazy val reductionAggregate: cudf.ColumnVector => cudf.Scalar = _ =>
     throw new UnsupportedOperationException("CollectSet is not yet supported in reduction")
   override lazy val groupByAggregate: GroupByAggregation =
-    GroupByAggregation.collectSet()
+    GroupByAggregation.collectSet(NullPolicy.EXCLUDE, NullEquality.EQUAL, NaNEquality.UNEQUAL)
   override val name: String = "CudfCollectSet"
 }

 class CudfMergeSets(override val dataType: DataType) extends CudfAggregate {
   override lazy val reductionAggregate: cudf.ColumnVector => cudf.Scalar = _ =>
     throw new UnsupportedOperationException("CudfMergeSets is not yet supported in reduction")
   override lazy val groupByAggregate: GroupByAggregation =
-    GroupByAggregation.mergeSets()
+    GroupByAggregation.mergeSets(NullEquality.EQUAL, NaNEquality.UNEQUAL)
   override val name: String = "CudfMergeSets"
 }

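The three arguments pin the deduplication semantics down explicitly instead of relying on cuDF defaults. A rough Python model of what they mean, purely illustrative (the plugin computes nothing this way):

```python
import math

def collect_set_model(values):
    """Toy model of collectSet(NullPolicy.EXCLUDE, NullEquality.EQUAL,
    NaNEquality.UNEQUAL) over one group; illustrative only."""
    out = []
    for v in values:
        if v is None:
            continue  # NullPolicy.EXCLUDE: top-level nulls never enter the set
        if isinstance(v, float) and math.isnan(v):
            out.append(v)  # NaNEquality.UNEQUAL: every NaN counts as distinct
        elif v not in out:
            # Python's None == None is True, which mirrors NullEquality.EQUAL
            # for nulls nested inside struct values: duplicates collapse.
            out.append(v)
    return out

# Two struct-like rows with an identical null field dedup to one element;
# the bare None is dropped entirely:
print(collect_set_model([(None, 1), (None, 1), ('x', 2), None]))
# -> [(None, 1), ('x', 2)]
```

CudfMergeSets takes only the two equality flags, presumably because top-level nulls were already excluded when the partial sets were built.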
