Skip to content

Commit

Permalink
Java Support for Decimal 128 (#9485)
Browse files Browse the repository at this point in the history
This depends on #9483

There may be a few more changes coming to this, but it should be fairly complete
  • Loading branch information
revans2 authored Nov 17, 2021
1 parent c1f20c7 commit 9aefbc2
Show file tree
Hide file tree
Showing 16 changed files with 784 additions and 149 deletions.
33 changes: 15 additions & 18 deletions java/src/main/java/ai/rapids/cudf/BinaryOperable.java
Original file line number Diff line number Diff line change
Expand Up @@ -80,25 +80,22 @@ static DType implicitConversion(BinaryOp op, BinaryOperable lhs, BinaryOperable
return DType.BOOL8;
}
if (a.isDecimalType() && b.isDecimalType()) {
// Here scale is created with value 0 as `scale` is required to create DType of
// decimal type. Dtype is discarded for binary operations for decimal types in cudf as a new
// DType is created for output type with new scale. New scale for output depends upon operator.
int scale = 0;
if (a.typeId == DType.DTypeEnum.DECIMAL32) {
if (b.typeId == DType.DTypeEnum.DECIMAL32) {
return DType.create(DType.DTypeEnum.DECIMAL32,
ColumnView.getFixedPointOutputScale(op, lhs.getType(), rhs.getType()));
} else {
throw new IllegalArgumentException("Both columns must be of the same fixed_point type");
}
} else if (a.typeId == DType.DTypeEnum.DECIMAL64) {
if (b.typeId == DType.DTypeEnum.DECIMAL64) {
return DType.create(DType.DTypeEnum.DECIMAL64,
ColumnView.getFixedPointOutputScale(op, lhs.getType(), rhs.getType()));
} else {
throw new IllegalArgumentException("Both columns must be of the same fixed_point type");
}
if (a.typeId != b.typeId) {
throw new IllegalArgumentException("Both columns must be of the same fixed_point type");
}
final int scale = ColumnView.getFixedPointOutputScale(op, lhs.getType(), rhs.getType());
// The output precision/size should be at least as large as the input.
// It may be larger if room is needed for it based off of the output scale.
final DType.DTypeEnum outputEnum;
if (scale <= DType.DECIMAL32_MAX_PRECISION && a.typeId == DType.DTypeEnum.DECIMAL32) {
outputEnum = DType.DTypeEnum.DECIMAL32;
} else if (scale <= DType.DECIMAL64_MAX_PRECISION &&
(a.typeId == DType.DTypeEnum.DECIMAL32 || a.typeId == DType.DTypeEnum.DECIMAL64)) {
outputEnum = DType.DTypeEnum.DECIMAL64;
} else {
outputEnum = DType.DTypeEnum.DECIMAL128;
}
return DType.create(outputEnum, scale);
}
throw new IllegalArgumentException("Unsupported types " + a + " and " + b);
}
Expand Down
13 changes: 13 additions & 0 deletions java/src/main/java/ai/rapids/cudf/ColumnVector.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.slf4j.LoggerFactory;

import java.math.BigDecimal;
import java.math.BigInteger;
import java.math.RoundingMode;
import java.nio.ByteBuffer;
import java.util.ArrayList;
Expand Down Expand Up @@ -1391,6 +1392,18 @@ public static ColumnVector decimalFromDoubles(DType type, RoundingMode mode, dou
}
}


/**
* Create a new decimal vector from BigIntegers
* Compared with scale of [[java.math.BigDecimal]], the scale here represents the opposite meaning.
*/
public static ColumnVector decimalFromBigInt(int scale, BigInteger... values) {
try (HostColumnVector host = HostColumnVector.decimalFromBigIntegers(scale, values)) {
ColumnVector columnVector = host.copyToDevice();
return columnVector;
}
}

/**
* Create a new string vector from the given values. This API
* supports inline nulls. This is really intended to be used only for testing as
Expand Down
17 changes: 15 additions & 2 deletions java/src/main/java/ai/rapids/cudf/DType.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ public final class DType {

public static final int DECIMAL32_MAX_PRECISION = 9;
public static final int DECIMAL64_MAX_PRECISION = 18;
public static final int DECIMAL128_MAX_PRECISION = 38;

/* enum representing various types. Whenever a new non-decimal type is added please make sure
below sections are updated as well:
Expand Down Expand Up @@ -77,7 +78,8 @@ public enum DTypeEnum {
LIST(0, 24),
DECIMAL32(4, 25),
DECIMAL64(8, 26),
STRUCT(0, 27);
DECIMAL128(16, 27),
STRUCT(0, 28);

final int sizeInBytes;
final int nativeId;
Expand Down Expand Up @@ -167,6 +169,7 @@ private DType(DTypeEnum id, int decimalScale) {
LIST,
null, // DECIMAL32
null, // DECIMAL64
null, // DECIMAL128
STRUCT
};

Expand Down Expand Up @@ -276,6 +279,13 @@ public static DType fromNative(int nativeId, int scale) {
}
return new DType(DTypeEnum.DECIMAL64, scale);
}
if (nativeId == DTypeEnum.DECIMAL128.nativeId) {
if (-scale > DECIMAL128_MAX_PRECISION) {
throw new IllegalArgumentException(
"Scale " + (-scale) + " exceeds DECIMAL128_MAX_PRECISION " + DECIMAL128_MAX_PRECISION);
}
return new DType(DTypeEnum.DECIMAL128, scale);
}
}
throw new IllegalArgumentException("Could not translate " + nativeId + " into a DType");
}
Expand All @@ -293,6 +303,8 @@ public static DType fromJavaBigDecimal(BigDecimal dec) {
return new DType(DTypeEnum.DECIMAL32, -dec.scale());
} else if (dec.precision() <= DECIMAL64_MAX_PRECISION) {
return new DType(DTypeEnum.DECIMAL64, -dec.scale());
} else if (dec.precision() <= DECIMAL128_MAX_PRECISION) {
return new DType(DTypeEnum.DECIMAL128, -dec.scale());
}
throw new IllegalArgumentException("Precision " + dec.precision() +
" exceeds max precision cuDF can support " + DECIMAL64_MAX_PRECISION);
Expand Down Expand Up @@ -450,7 +462,8 @@ public boolean hasOffsets() {

private static final EnumSet<DTypeEnum> DECIMALS = EnumSet.of(
DTypeEnum.DECIMAL32,
DTypeEnum.DECIMAL64
DTypeEnum.DECIMAL64,
DTypeEnum.DECIMAL128
);

private static final EnumSet<DTypeEnum> NESTED_TYPE = EnumSet.of(
Expand Down
40 changes: 38 additions & 2 deletions java/src/main/java/ai/rapids/cudf/HostColumnVector.java
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,23 @@ public static HostColumnVector decimalFromBoxedLongs(int scale, Long... values)
});
}

/**
* Create a new decimal vector from unscaled values (BigInteger array) and scale.
* The created vector is of type DType.DECIMAL128.
* Compared with scale of [[java.math.BigDecimal]], the scale here represents the opposite meaning.
*/
public static HostColumnVector decimalFromBigIntegers(int scale, BigInteger... values) {
return build(DType.create(DType.DTypeEnum.DECIMAL128, scale), values.length, (b) -> {
for (BigInteger v : values) {
if (v == null) {
b.appendNull();
} else {
b.appendUnscaledDecimal(v);
}
}
});
}

/**
* Create a new decimal vector from double floats with specific DecimalType and RoundingMode.
* All doubles will be rescaled if necessary, according to scale of input DecimalType and RoundingMode.
Expand Down Expand Up @@ -1222,7 +1239,12 @@ public final ColumnBuilder append(BigDecimal value) {
data.setInt(currentIndex * type.getSizeInBytes(), unscaledVal.intValueExact());
} else if (type.typeId == DType.DTypeEnum.DECIMAL64) {
data.setLong(currentIndex * type.getSizeInBytes(), unscaledVal.longValueExact());
} else {
} else if (type.typeId == DType.DTypeEnum.DECIMAL128) {
assert currentIndex < rows;
byte[] unscaledValueBytes = value.unscaledValue().toByteArray();
byte[] result = convertDecimal128FromJavaToCudf(unscaledValueBytes);
data.setBytes(currentIndex*DType.DTypeEnum.DECIMAL128.sizeInBytes, result, 0, result.length);
} else {
throw new IllegalStateException(type + " is not a supported decimal type.");
}
currentIndex++;
Expand Down Expand Up @@ -1450,14 +1472,18 @@ public final Builder append(BigDecimal value) {
*/
public final Builder append(BigDecimal value, RoundingMode roundingMode) {
assert type.isDecimalType();
assert currentIndex < rows;
assert currentIndex < rows: "appended too many values " + currentIndex + " out of total rows " + rows;
BigInteger unscaledValue = value.setScale(-type.getScale(), roundingMode).unscaledValue();
if (type.typeId == DType.DTypeEnum.DECIMAL32) {
assert value.precision() <= DType.DECIMAL32_MAX_PRECISION : "value exceeds maximum precision for DECIMAL32";
data.setInt(currentIndex * type.getSizeInBytes(), unscaledValue.intValueExact());
} else if (type.typeId == DType.DTypeEnum.DECIMAL64) {
assert value.precision() <= DType.DECIMAL64_MAX_PRECISION : "value exceeds maximum precision for DECIMAL64 ";
data.setLong(currentIndex * type.getSizeInBytes(), unscaledValue.longValueExact());
} else if (type.typeId == DType.DTypeEnum.DECIMAL128) {
assert value.precision() <= DType.DECIMAL128_MAX_PRECISION : "value exceeds maximum precision for DECIMAL128 ";
appendUnscaledDecimal(value.unscaledValue());
return this;
} else {
throw new IllegalStateException(type + " is not a supported decimal type.");
}
Expand All @@ -1481,6 +1507,16 @@ public final Builder appendUnscaledDecimal(long value) {
return this;
}

public final Builder appendUnscaledDecimal(BigInteger value) {
assert type.typeId == DType.DTypeEnum.DECIMAL128;
assert currentIndex < rows;
byte[] unscaledValueBytes = value.toByteArray();
byte[] result = convertDecimal128FromJavaToCudf(unscaledValueBytes);
data.setBytes(currentIndex*DType.DTypeEnum.DECIMAL128.sizeInBytes, result, 0, result.length);
currentIndex++;
return this;
}

public Builder append(String value) {
assert value != null : "appendNull must be used to append null strings";
return appendUTF8String(value.getBytes(StandardCharsets.UTF_8));
Expand Down
49 changes: 40 additions & 9 deletions java/src/main/java/ai/rapids/cudf/HostColumnVectorCore.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import org.slf4j.LoggerFactory;

import java.math.BigDecimal;
import java.math.BigInteger;
import java.nio.ByteOrder;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
Expand Down Expand Up @@ -341,6 +343,13 @@ public final BigDecimal getBigDecimal(long index) {
} else if (type.typeId == DType.DTypeEnum.DECIMAL64) {
long unscaledValue = offHeap.data.getLong(index * type.getSizeInBytes());
return BigDecimal.valueOf(unscaledValue, -type.getScale());
} else if (type.typeId == DType.DTypeEnum.DECIMAL128) {
int sizeInBytes = DType.DTypeEnum.DECIMAL128.sizeInBytes;
byte[] dst = new byte[sizeInBytes];
// We need to switch the endianness for decimal128 byte arrays between java and native code.
offHeap.data.getBytes(dst, 0, (index * sizeInBytes), sizeInBytes);
convertInPlaceToBigEndian(dst);
return new BigDecimal(new BigInteger(dst), -type.getScale());
} else {
throw new IllegalStateException(type + " is not a supported decimal type.");
}
Expand Down Expand Up @@ -534,6 +543,34 @@ public String toString() {
'}';
}

protected static byte[] convertDecimal128FromJavaToCudf(byte[] bytes) {
byte[] finalBytes = new byte[DType.DTypeEnum.DECIMAL128.sizeInBytes];
byte lastByte = bytes[0];
//Convert to 2's complement representation and make sure the sign bit is extended correctly
byte setByte = (lastByte & 0x80) > 0 ? (byte)0xff : (byte)0x00;
for(int i = bytes.length; i < finalBytes.length; i++) {
finalBytes[i] = setByte;
}
// After setting the sign bits, reverse the rest of the bytes for endianness
for(int k = 0; k < bytes.length; k++) {
finalBytes[k] = bytes[bytes.length - k - 1];
}
return finalBytes;
}

private void convertInPlaceToBigEndian(byte[] dst) {
assert ByteOrder.nativeOrder().equals(ByteOrder.LITTLE_ENDIAN);
int i =0;
int j = dst.length -1;
while (j > i) {
byte tmp;
tmp = dst[j];
dst[j] = dst[i];
dst[i] = tmp;
j--;
i++;
}
}
/////////////////////////////////////////////////////////////////////////////
// HELPER CLASSES
/////////////////////////////////////////////////////////////////////////////
Expand All @@ -557,15 +594,9 @@ protected synchronized boolean cleanImpl(boolean logErrorIfNotClean) {
boolean neededCleanup = false;
if (data != null || valid != null || offsets != null) {
try {
if (data != null) {
data.close();
}
if (offsets != null) {
offsets.close();
}
if (valid != null) {
valid.close();
}
ColumnVector.closeBuffers(data);
ColumnVector.closeBuffers(offsets);
ColumnVector.closeBuffers(valid);
} finally {
// Always mark the resource as freed even if an exception is thrown.
// We cannot know how far it progressed before the exception, and
Expand Down
31 changes: 30 additions & 1 deletion java/src/main/java/ai/rapids/cudf/ORCOptions.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*
*
* Copyright (c) 2019, NVIDIA CORPORATION.
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -18,6 +18,10 @@

package ai.rapids.cudf;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
* Options for reading a ORC file
*/
Expand All @@ -27,9 +31,11 @@ public class ORCOptions extends ColumnFilterOptions {

private final boolean useNumPyTypes;
private final DType unit;
private final String[] decimal128Columns;

private ORCOptions(Builder builder) {
super(builder);
decimal128Columns = builder.decimal128Columns.toArray(new String[0]);
useNumPyTypes = builder.useNumPyTypes;
unit = builder.unit;
}
Expand All @@ -42,6 +48,10 @@ DType timeUnit() {
return unit;
}

String[] getDecimal128Columns() {
return decimal128Columns;
}

public static Builder builder() {
return new Builder();
}
Expand All @@ -50,6 +60,8 @@ public static class Builder extends ColumnFilterOptions.Builder<Builder> {
private boolean useNumPyTypes = true;
private DType unit = DType.EMPTY;

final List<String> decimal128Columns = new ArrayList<>();

/**
* Specify whether the parser should implicitly promote TIMESTAMP_DAYS
* columns to TIMESTAMP_MILLISECONDS for compatibility with NumPy.
Expand All @@ -73,6 +85,23 @@ public ORCOptions.Builder withTimeUnit(DType unit) {
return this;
}

/**
* Specify decimal columns which shall be read as DECIMAL128. Otherwise, decimal columns
* will be read as DECIMAL64 by default in ORC.
*
* In terms of child columns of nested types, their parents need to be prepended as prefix
* of the column name, in case of ambiguity. For struct columns, the names of child columns
* are formatted as `{struct_col_name}.{child_col_name}`. For list columns, the data(child)
* columns are named as `{list_col_name}.1`.
*
* @param names names of columns which read as DECIMAL128
* @return builder for chaining
*/
public Builder decimal128Column(String... names) {
decimal128Columns.addAll(Arrays.asList(names));
return this;
}

public ORCOptions build() { return new ORCOptions(this); }
}
}
Loading

0 comments on commit 9aefbc2

Please sign in to comment.