Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor AggregationJni to support collectSet [skip ci] #8057

Merged
merged 9 commits into from
May 5, 2021
174 changes: 142 additions & 32 deletions java/src/main/java/ai/rapids/cudf/Aggregation.java
Original file line number Diff line number Diff line change
Expand Up @@ -40,25 +40,27 @@ enum Kind {
PRODUCT(1),
MIN(2),
MAX(3),
COUNT(4),
ANY(5),
ALL(6),
SUM_OF_SQUARES(7),
MEAN(8),
VARIANCE(9), // This can take a delta degrees of freedom
STD(10), // This can take a delta degrees of freedom
MEDIAN(11),
QUANTILE(12),
ARGMAX(13),
ARGMIN(14),
NUNIQUE(15),
NTH_ELEMENT(16),
ROW_NUMBER(17),
COLLECT(18),
LEAD(19),
LAG(20),
PTX(21),
CUDA(22);
COUNT_VALID(4),
COUNT_ALL(5),
ANY(6),
ALL(7),
SUM_OF_SQUARES(8),
MEAN(9),
VARIANCE(10), // This can take a delta degrees of freedom
STD(11), // This can take a delta degrees of freedom
MEDIAN(12),
QUANTILE(13),
ARGMAX(14),
ARGMIN(15),
NUNIQUE(16),
NTH_ELEMENT(17),
ROW_NUMBER(18),
COLLECT_LIST(19),
COLLECT_SET(20),
LEAD(21),
LAG(22),
PTX(23),
CUDA(24);

final int nativeId;

Expand All @@ -77,6 +79,30 @@ public enum NullPolicy {
final boolean includeNulls;
}

/**
 * Java-side counterpart of the native 'null_equality' setting.
 */
public enum NullEquality {
  /** Carries {@code nullsEqual == false}. */
  UNEQUAL(false),
  /** Carries {@code nullsEqual == true}. */
  EQUAL(true);

  /** Boolean form of this constant. */
  final boolean nullsEqual;

  NullEquality(boolean areNullsEqual) {
    this.nullsEqual = areNullsEqual;
  }
}

/**
 * Java-side counterpart of the native 'nan_equality' setting.
 */
public enum NaNEquality {
  /** Carries {@code nansEqual == false}. */
  UNEQUAL(false),
  /** Carries {@code nansEqual == true}. */
  ALL_EQUAL(true);

  /** Boolean form of this constant. */
  final boolean nansEqual;

  NaNEquality(boolean areNansEqual) {
    this.nansEqual = areNansEqual;
  }
}

/**
* An Aggregation that only needs a kind and nothing else.
*/
Expand Down Expand Up @@ -280,17 +306,17 @@ long getDefaultOutput() {
}
}

private static final class CollectAggregation extends Aggregation {
private static final class CollectListAggregation extends Aggregation {
private final NullPolicy nullPolicy;

public CollectAggregation(NullPolicy nullPolicy) {
super(Kind.COLLECT);
public CollectListAggregation(NullPolicy nullPolicy) {
super(Kind.COLLECT_LIST);
this.nullPolicy = nullPolicy;
}

@Override
long createNativeInstance() {
return Aggregation.createCollectAgg(nullPolicy.includeNulls);
return Aggregation.createCollectListAgg(nullPolicy.includeNulls);
}

@Override
Expand All @@ -302,14 +328,53 @@ public int hashCode() {
public boolean equals(Object other) {
if (this == other) {
return true;
} else if (other instanceof CollectAggregation) {
CollectAggregation o = (CollectAggregation) other;
} else if (other instanceof CollectListAggregation) {
CollectListAggregation o = (CollectListAggregation) other;
return o.nullPolicy == this.nullPolicy;
}
return false;
}
}

private static final class CollectSetAggregation extends Aggregation {
private final NullPolicy nullPolicy;
private final NullEquality nullEquality;
private final NaNEquality nanEquality;

public CollectSetAggregation(NullPolicy nullPolicy, NullEquality nullEquality, NaNEquality nanEquality) {
super(Kind.COLLECT_SET);
this.nullPolicy = nullPolicy;
this.nullEquality = nullEquality;
this.nanEquality = nanEquality;
}

@Override
long createNativeInstance() {
return Aggregation.createCollectSetAgg(nullPolicy.includeNulls,
nullEquality.nullsEqual,
nanEquality.nansEqual);
}

@Override
public int hashCode() {
boolean[] configs = new boolean[]{nullPolicy.includeNulls, nullEquality.nullsEqual, nanEquality.nansEqual};
return 31 * kind.hashCode() + Arrays.hashCode(configs);
jlowe marked this conversation as resolved.
Show resolved Hide resolved
}

@Override
public boolean equals(Object other) {
if (this == other) {
return true;
} else if (other instanceof CollectSetAggregation) {
CollectSetAggregation o = (CollectSetAggregation) other;
return o.nullPolicy == this.nullPolicy &&
o.nullEquality == this.nullEquality &&
o.nanEquality == this.nanEquality;
}
return false;
}
}

protected final Kind kind;

protected Aggregation(Kind kind) {
Expand Down Expand Up @@ -413,7 +478,8 @@ public static Aggregation count(boolean includeNulls) {
* should be counted.
*/
public static Aggregation count(NullPolicy nullPolicy) {
return new CountLikeAggregation(Kind.COUNT, nullPolicy);
Aggregation.Kind kind = nullPolicy.includeNulls ? Kind.COUNT_ALL : Kind.COUNT_VALID;
return new CountLikeAggregation(kind, nullPolicy);
}

/**
Expand Down Expand Up @@ -593,18 +659,57 @@ public static Aggregation rowNumber() {

/**
* Collect the values into a list. nulls will be skipped.
sperlingxx marked this conversation as resolved.
Show resolved Hide resolved
* WARNING: This method is deprecated, please use collectList instead.
jlowe marked this conversation as resolved.
Show resolved Hide resolved
*/
@Deprecated
public static Aggregation collect() {
return collect(NullPolicy.EXCLUDE);
return collectList();
}

/**
* Collect the values into a list.
* @param nullPolicy INCLUDE if nulls should be included in the aggregation or EXCLUDE if they
* should be skipped.
* WARNING: This method is deprecated, please use collectList instead.
jlowe marked this conversation as resolved.
Show resolved Hide resolved
*
* @param nullPolicy Indicates whether to include/exclude nulls during collection.
*/
@Deprecated
public static Aggregation collect(NullPolicy nullPolicy) {
return new CollectAggregation(nullPolicy);
return collectList(nullPolicy);
}

/**
 * Collect the values into a list, skipping nulls.
 *
 * @return the aggregation produced by {@link #collectList(NullPolicy)} for
 *         {@code NullPolicy.EXCLUDE}.
 */
public static Aggregation collectList() {
  final NullPolicy skipNulls = NullPolicy.EXCLUDE;
  return collectList(skipNulls);
}

/**
 * Collect the values into a list.
 *
 * @param nullPolicy Indicates whether nulls are included in or excluded from the collection.
 * @return a new collect-list aggregation carrying the given null policy.
 */
public static Aggregation collectList(NullPolicy nullPolicy) {
  final Aggregation listAggregation = new CollectListAggregation(nullPolicy);
  return listAggregation;
}

/**
 * Collect the values into a set. All null values will be excluded, and all NaN values are
 * regarded as unique instances.
 *
 * @return a collect-set aggregation built with {@code NullPolicy.EXCLUDE},
 *         {@code NullEquality.UNEQUAL} and {@code NaNEquality.UNEQUAL}.
 */
public static Aggregation collectSet() {
  return collectSet(NullPolicy.EXCLUDE, NullEquality.UNEQUAL, NaNEquality.UNEQUAL);
}

/**
 * Collect the values into a set.
 *
 * @param nullPolicy   Indicates whether to include/exclude nulls during collection.
 * @param nullEquality Flag to specify whether null entries within each list should be considered equal.
 * @param nanEquality  Flag to specify whether NaN values in floating point column should be considered equal.
 * @return a new collect-set aggregation carrying the three settings.
 */
public static Aggregation collectSet(NullPolicy nullPolicy, NullEquality nullEquality, NaNEquality nanEquality) {
  final Aggregation setAggregation = new CollectSetAggregation(
      nullPolicy,
      nullEquality,
      nanEquality);
  return setAggregation;
}

/**
Expand Down Expand Up @@ -675,7 +780,12 @@ public static Aggregation lag(int offset, ColumnVector defaultOutput) {
private static native long createLeadLagAgg(int kind, int offset);

/**
* Create a collect aggregation including nulls or not.
* Create a collect list aggregation including nulls or not.
*/
private static native long createCollectListAgg(boolean includeNulls);

/**
* Create a collect set aggregation.
*/
private static native long createCollectAgg(boolean includeNulls);
private static native long createCollectSetAgg(boolean includeNulls, boolean nullsEqual, boolean nansEqual);
}
Loading