[VL] Support Spark transform_keys, transform_values functions (#6095)
gaoyangxiaozhu authored Jun 18, 2024
1 parent 3ba3172 · commit 800cadd
Showing 4 changed files with 61 additions and 1 deletion.
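For context, transform_keys(map, (k, v) -> expr) returns a map whose keys are rewritten by the lambda, and transform_values(map, (k, v) -> expr) does the same for the values. A minimal sketch of the Spark-side semantics this commit offloads to Velox (illustrative only; the queries below show plain Spark behavior and do not require Gluten):

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[1]").getOrCreate()

spark.sql("SELECT transform_keys(map('a', 1, 'b', 2), (k, v) -> upper(k))").show(false)
// {A -> 1, B -> 2}

spark.sql("SELECT transform_values(map('a', 1, 'b', 2), (k, v) -> v + v)").show(false)
// {a -> 2, b -> 4}

With this patch, a Velox-backed Gluten plan evaluates both expressions natively inside a ProjectExecTransformer instead of falling back to vanilla Spark.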
@@ -22,6 +22,7 @@ import org.apache.gluten.datasource.ArrowConvertorRule
 import org.apache.gluten.exception.GlutenNotSupportException
 import org.apache.gluten.execution._
 import org.apache.gluten.expression._
+import org.apache.gluten.expression.ExpressionNames.{TRANSFORM_KEYS, TRANSFORM_VALUES}
 import org.apache.gluten.expression.aggregate.{HLLAdapter, VeloxBloomFilterAggregate, VeloxCollectList, VeloxCollectSet}
 import org.apache.gluten.extension._
 import org.apache.gluten.extension.columnar.TransformHints
@@ -854,7 +855,9 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
       Sig[VeloxCollectList](ExpressionNames.COLLECT_LIST),
       Sig[VeloxCollectSet](ExpressionNames.COLLECT_SET),
       Sig[VeloxBloomFilterMightContain](ExpressionNames.MIGHT_CONTAIN),
-      Sig[VeloxBloomFilterAggregate](ExpressionNames.BLOOM_FILTER_AGG)
+      Sig[VeloxBloomFilterAggregate](ExpressionNames.BLOOM_FILTER_AGG),
+      Sig[TransformKeys](TRANSFORM_KEYS),
+      Sig[TransformValues](TRANSFORM_VALUES)
     )
   }

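The Sig entries register the two Catalyst expression classes under their Substrait function names so the validator recognizes them as offloadable. A minimal sketch of the shape such a signature type might take (an assumption for illustration, not Gluten's actual definition):

import scala.reflect.ClassTag
import org.apache.spark.sql.catalyst.expressions.Expression

// Assumed shape: an (expression class, function name) pair that the planner
// consults when deciding whether an expression can run natively.
case class Sig(expClass: Class[_], name: String)

object Sig {
  def apply[T <: Expression: ClassTag](name: String): Sig =
    Sig(implicitly[ClassTag[T]].runtimeClass, name)
}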
@@ -537,6 +537,46 @@ class ScalarFunctionsValidateSuite extends FunctionsValidateTest {
     }
   }

test("test transform_keys function") {
withTempPath {
path =>
Seq(
Map[String, Int]("a" -> 1, "b" -> 2),
Map[String, Int]("a" -> 2, "b" -> 3),
null
)
.toDF("m")
.write
.parquet(path.getCanonicalPath)

spark.read.parquet(path.getCanonicalPath).createOrReplaceTempView("map_tbl")

runQueryAndCompare("select transform_keys(m, (k, v) -> upper(k)) from map_tbl") {
checkGlutenOperatorMatch[ProjectExecTransformer]
}
}
}

test("test transform_values function") {
withTempPath {
path =>
Seq(
Map[String, Int]("a" -> 1, "b" -> 2),
Map[String, Int]("a" -> 2, "b" -> 3),
null
)
.toDF("m")
.write
.parquet(path.getCanonicalPath)

spark.read.parquet(path.getCanonicalPath).createOrReplaceTempView("map_tbl")

runQueryAndCompare("select transform_values(m, (k, v) -> v + 1) from map_tbl") {
checkGlutenOperatorMatch[ProjectExecTransformer]
}
}
}

test("zip_with") {
withTempPath {
path =>
@@ -656,6 +656,21 @@ object ExpressionConverter extends SQLConfHelper with Logging {
           Seq(replaceWithExpressionTransformerInternal(c.child, attributeSeq, expressionsMap)),
           c
         )
+      case t: TransformKeys =>
+        // The default map-key dedup policy is `EXCEPTION`.
+        val mapKeyDedupPolicy = SQLConf.get.getConf(SQLConf.MAP_KEY_DEDUP_POLICY)
+        if (mapKeyDedupPolicy == SQLConf.MapKeyDedupPolicy.LAST_WIN.toString) {
+          // TODO: Remove after the fix is ready for
+          // https://github.com/facebookincubator/velox/issues/10219
+          throw new GlutenNotSupportException(
+            "LAST_WIN policy is not yet supported by the native engine for deduplicating map keys"
+          )
+        }
+        GenericExpressionTransformer(
+          substraitExprName,
+          t.children.map(replaceWithExpressionTransformerInternal(_, attributeSeq, expressionsMap)),
+          t
+        )
       case expr =>
         GenericExpressionTransformer(
           substraitExprName,
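The LAST_WIN guard matters because transform_keys can map distinct input keys to the same output key, and Spark's handling of the resulting duplicates is configurable via spark.sql.mapKeyDedupPolicy. Velox cannot reproduce the LAST_WIN semantics yet (facebookincubator/velox#10219), so the GlutenNotSupportException keeps such queries on the vanilla Spark path. An illustration of the behavior in question (plain Spark, no Gluten required):

// Under LAST_WIN, Spark keeps the value of the last duplicate key; under the
// default EXCEPTION policy the same query fails with a duplicate-key error.
spark.conf.set("spark.sql.mapKeyDedupPolicy", "LAST_WIN")
spark.sql("SELECT transform_keys(map('a', 1, 'A', 2), (k, v) -> upper(k))").show(false)
// {A -> 2}   -- both keys become "A"; the last value wins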
@@ -276,6 +276,8 @@ object ExpressionNames {
   final val MAP_FROM_ARRAYS = "map_from_arrays"
   final val MAP_ENTRIES = "map_entries"
   final val MAP_ZIP_WITH = "map_zip_with"
+  final val TRANSFORM_KEYS = "transform_keys"
+  final val TRANSFORM_VALUES = "transform_values"
   final val STR_TO_MAP = "str_to_map"

   // struct functions
