added the id back for struct children (#4998)
Signed-off-by: Raza Jafri <rjafri@nvidia.com>

Co-authored-by: Raza Jafri <rjafri@nvidia.com>
razajafri authored Mar 22, 2022
1 parent 606b4ca commit ad2cc79
Showing 2 changed files with 18 additions and 5 deletions.
11 changes: 11 additions & 0 deletions integration_tests/src/main/python/cache_test.py

@@ -115,6 +115,17 @@ def partial_return_cache(spark):
     assert_gpu_and_cpu_are_equal_collect(partial_return(f.col("a")), conf=enable_vectorized_conf)
     assert_gpu_and_cpu_are_equal_collect(partial_return(f.col("b")), conf=enable_vectorized_conf)
 
+@pytest.mark.parametrize('enable_vectorized_conf', enable_vectorized_confs, ids=idfn)
+@allow_non_gpu('CollectLimitExec')
+def test_cache_reverse_order(enable_vectorized_conf):
+    col0 = StructGen([['child0', StructGen([['child1', byte_gen]])]])
+    col1 = StructGen([['child0', byte_gen]])
+    def partial_return():
+        def partial_return_cache(spark):
+            return two_col_df(spark, col0, col1).select(f.col("a"), f.col("b")).cache().limit(50).select(f.col("b"), f.col("a"))
+        return partial_return_cache
+    assert_gpu_and_cpu_are_equal_collect(partial_return(), conf=enable_vectorized_conf)
+
 @allow_non_gpu('CollectLimitExec')
 def test_cache_diff_req_order(spark_tmp_path):
     def n_fold(spark):
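
The new test caches a DataFrame of struct columns and then selects them back in the opposite order through a limit. Below is a standalone PySpark sketch of the same scenario, with an assumed SparkSession, a hand-written schema, and a single illustrative row in place of the plugin's StructGen/two_col_df harness and its CPU-versus-GPU comparison:

# Standalone sketch of the cached reverse-order read; the schema string,
# sample row, and session setup are assumptions, not the test's own data.
from pyspark.sql import SparkSession
import pyspark.sql.functions as f

spark = SparkSession.builder.getOrCreate()

df = spark.createDataFrame(
    [(((1,),), (2,))],
    "a struct<child0: struct<child1: tinyint>>, b struct<child0: tinyint>")

# Cache with the columns in one order, then re-select them reversed; the
# cached batches must map the renamed struct children back correctly.
result = (df.select(f.col("a"), f.col("b"))
            .cache()
            .limit(50)
            .select(f.col("b"), f.col("a")))
print(result.collect())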

12 changes: 7 additions & 5 deletions

@@ -16,8 +16,6 @@
 
 package org.apache.spark.sql.rapids
 
-import java.util.concurrent.atomic.AtomicLong
-
 import com.nvidia.spark.rapids.shims.GpuTypeShims
 
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}

@@ -64,6 +62,7 @@ object PCBSSchemaHelper {
    * long type.
    */
   def getSupportedDataType(dataType: DataType): DataType = {
+    var curId = 0
     dataType match {
       case CalendarIntervalType =>
         calendarIntervalStructType

@@ -72,9 +71,11 @@
       case s: StructType =>
         val newStructType = StructType(
           s.indices.map { index =>
-            StructField(s.fields(index).name,
+            val field = StructField(s"_col$curId",
               getSupportedDataType(s.fields(index).dataType),
               s.fields(index).nullable, s.fields(index).metadata)
+            curId += 1
+            field
           })
         newStructType
       case _@ArrayType(elementType, nullable) =>
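
These hunks give struct children positional names: getSupportedDataType now carries a curId counter declared fresh on every call, so each nesting level numbers its children _col0, _col1, ... independently. A hedged Python sketch of just that renaming rule follows, using a hypothetical helper name and leaving out the CalendarIntervalType/NullType/array/map handling the real method also performs:

# rename_struct_children is an invented stand-in name for illustration; it
# mirrors only the struct branch of the Scala getSupportedDataType above.
from pyspark.sql.types import DataType, StructField, StructType

def rename_struct_children(data_type: DataType) -> DataType:
    if isinstance(data_type, StructType):
        # Children are renamed positionally; the recursive call gets a fresh
        # counter, so numbering restarts at _col0 at every nesting level.
        return StructType([
            StructField(f"_col{i}",
                        rename_struct_children(field.dataType),
                        field.nullable,
                        field.metadata)
            for i, field in enumerate(data_type.fields)])
    return data_type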

@@ -106,10 +107,11 @@
    */
   def getSupportedSchemaFromUnsupported(cachedAttributes: Seq[Attribute]): Seq[Attribute] = {
     // We convert CalendarIntervalType, UDT and NullType ATM convert it to a supported type
-    val curId = new AtomicLong()
+    var curId = 0
     cachedAttributes.map {
       attribute =>
-        val name = s"_col${curId.getAndIncrement()}"
+        val name = s"_col$curId"
+        curId += 1
         attribute.dataType match {
           case CalendarIntervalType =>
             AttributeReference(name, calendarIntervalStructType,
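
getSupportedSchemaFromUnsupported applies the same positional naming to the top-level attributes, replacing the AtomicLong counter with a plain var. Reusing rename_struct_children from the sketch above on two struct columns shaped like the new test's generators shows the fully renamed schema; the expected output below is this sketch's own, not captured from the plugin:

# Both the top-level attributes and every nested child end up positional.
from pyspark.sql.types import ByteType, StructField, StructType

schema = StructType([
    StructField("a", StructType([
        StructField("child0", StructType([
            StructField("child1", ByteType())]))])),
    StructField("b", StructType([
        StructField("child0", ByteType())])),
])

renamed = StructType([
    StructField(f"_col{i}", rename_struct_children(fld.dataType),
                fld.nullable, fld.metadata)
    for i, fld in enumerate(schema.fields)])

print(renamed.simpleString())
# struct<_col0:struct<_col0:struct<_col0:tinyint>>,_col1:struct<_col0:tinyint>>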
