Skip to content

Commit f6ff7d0

Browse files
Ngone51cloud-fan
authored andcommitted
[SPARK-30127][SQL] Support case class parameter for typed Scala UDF
### What changes were proposed in this pull request? To support case class parameter for typed Scala UDF, e.g. ``` case class TestData(key: Int, value: String) val f = (d: TestData) => d.key * d.value.toInt val myUdf = udf(f) val df = Seq(("data", TestData(50, "2"))).toDF("col1", "col2") checkAnswer(df.select(myUdf(Column("col2"))), Row(100) :: Nil) ``` ### Why are the changes needed? Currently, Spark UDF can only work on data types like java.lang.String, o.a.s.sql.Row, Seq[_], etc. This is inconvenient if user want to apply an operation on one column, and the column is struct type. You must access data from a Row object, instead of domain object like Dataset operations. It will be great if UDF can work on types that are supported by Dataset, e.g. case class. And here's benchmark result of using case class comparing to row: ```scala // case class: 58ms 65ms 59ms 64ms 61ms // row: 59ms 64ms 73ms 84ms 69ms val f1 = (d: TestData) => s"${d.key}, ${d.value}" val f2 = (r: Row) => s"${r.getInt(0)}, ${r.getString(1)}" val udf1 = udf(f1) // set spark.sql.legacy.allowUntypedScalaUDF=true val udf2 = udf(f2, StringType) val df = spark.range(100000).selectExpr("cast (id as int) as id") .select(struct('id, lit("str")).as("col")) df.cache().collect() // warmup to exclude some extra influence df.select(udf1('col)).write.mode(SaveMode.Overwrite).format("noop").save() df.select(udf2('col)).write.mode(SaveMode.Overwrite).format("noop").save() start = System.currentTimeMillis() df.select(udf1('col)).write.mode(SaveMode.Overwrite).format("noop").save() println(System.currentTimeMillis() - start) start = System.currentTimeMillis() df.select(udf2('col)).write.mode(SaveMode.Overwrite).format("noop").save() println(System.currentTimeMillis() - start) ``` ### Does this PR introduce any user-facing change? Yes. User now could be able to use typed Scala UDF with case class as input parameter. ### How was this patch tested? Added unit tests. Closes #27937 from Ngone51/udf_caseclass_support. Authored-by: yi.wu <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent 1fd4607 commit f6ff7d0

File tree

12 files changed

+519
-410
lines changed

12 files changed

+519
-410
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2707,13 +2707,13 @@ class Analyzer(
27072707

27082708
case p => p transformExpressionsUp {
27092709

2710-
case udf @ ScalaUDF(_, _, inputs, inputPrimitives, _, _, _, _)
2711-
if inputPrimitives.contains(true) =>
2710+
case udf @ ScalaUDF(_, _, inputs, _, _, _, _)
2711+
if udf.inputPrimitives.contains(true) =>
27122712
// Otherwise, add special handling of null for fields that can't accept null.
27132713
// The result of operations like this, when passed null, is generally to return null.
2714-
assert(inputPrimitives.length == inputs.length)
2714+
assert(udf.inputPrimitives.length == inputs.length)
27152715

2716-
val inputPrimitivesPair = inputPrimitives.zip(inputs)
2716+
val inputPrimitivesPair = udf.inputPrimitives.zip(inputs)
27172717
val inputNullCheck = inputPrimitivesPair.collect {
27182718
case (isPrimitive, input) if isPrimitive && input.nullable =>
27192719
IsNull(input)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala

Lines changed: 327 additions & 268 deletions
Large diffs are not rendered by default.

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier
2929
import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType, InMemoryCatalog, SessionCatalog}
3030
import org.apache.spark.sql.catalyst.dsl.expressions._
3131
import org.apache.spark.sql.catalyst.dsl.plans._
32+
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
3233
import org.apache.spark.sql.catalyst.errors.TreeNodeException
3334
import org.apache.spark.sql.catalyst.expressions._
3435
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Count, Sum}
@@ -326,20 +327,21 @@ class AnalysisSuite extends AnalysisTest with Matchers {
326327
}
327328

328329
// non-primitive parameters do not need special null handling
329-
val udf1 = ScalaUDF((s: String) => "x", StringType, string :: Nil, false :: Nil)
330+
val udf1 = ScalaUDF((s: String) => "x", StringType, string :: Nil,
331+
Option(ExpressionEncoder[String]()) :: Nil)
330332
val expected1 = udf1
331333
checkUDF(udf1, expected1)
332334

333335
// only primitive parameter needs special null handling
334336
val udf2 = ScalaUDF((s: String, d: Double) => "x", StringType, string :: double :: Nil,
335-
false :: true :: Nil)
337+
Option(ExpressionEncoder[String]()) :: Option(ExpressionEncoder[Double]()) :: Nil)
336338
val expected2 =
337339
If(IsNull(double), nullResult, udf2.copy(children = string :: KnownNotNull(double) :: Nil))
338340
checkUDF(udf2, expected2)
339341

340342
// special null handling should apply to all primitive parameters
341343
val udf3 = ScalaUDF((s: Short, d: Double) => "x", StringType, short :: double :: Nil,
342-
true :: true :: Nil)
344+
Option(ExpressionEncoder[Short]()) :: Option(ExpressionEncoder[Double]()) :: Nil)
343345
val expected3 = If(
344346
IsNull(short) || IsNull(double),
345347
nullResult,
@@ -351,7 +353,7 @@ class AnalysisSuite extends AnalysisTest with Matchers {
351353
(s: Short, d: Double) => "x",
352354
StringType,
353355
short :: nonNullableDouble :: Nil,
354-
true :: true :: Nil)
356+
Option(ExpressionEncoder[Short]()) :: Option(ExpressionEncoder[Double]()) :: Nil)
355357
val expected4 = If(
356358
IsNull(short),
357359
nullResult,
@@ -362,8 +364,12 @@ class AnalysisSuite extends AnalysisTest with Matchers {
362364
test("SPARK-24891 Fix HandleNullInputsForUDF rule") {
363365
val a = testRelation.output(0)
364366
val func = (x: Int, y: Int) => x + y
365-
val udf1 = ScalaUDF(func, IntegerType, a :: a :: Nil, false :: false :: Nil)
366-
val udf2 = ScalaUDF(func, IntegerType, a :: udf1 :: Nil, false :: false :: Nil)
367+
val udf1 = ScalaUDF(func, IntegerType, a :: a :: Nil,
368+
Option(ExpressionEncoder[java.lang.Integer]()) ::
369+
Option(ExpressionEncoder[java.lang.Integer]()) :: Nil)
370+
val udf2 = ScalaUDF(func, IntegerType, a :: udf1 :: Nil,
371+
Option(ExpressionEncoder[java.lang.Integer]()) ::
372+
Option(ExpressionEncoder[java.lang.Integer]()) :: Nil)
367373
val plan = Project(Alias(udf2, "")() :: Nil, testRelation)
368374
comparePlans(plan.analyze, plan.analyze.analyze)
369375
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,17 +20,20 @@ package org.apache.spark.sql.catalyst.expressions
2020
import java.util.Locale
2121

2222
import org.apache.spark.{SparkException, SparkFunSuite}
23+
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
2324
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
2425
import org.apache.spark.sql.internal.SQLConf
2526
import org.apache.spark.sql.types.{DecimalType, IntegerType, StringType}
2627

2728
class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {
2829

2930
test("basic") {
30-
val intUdf = ScalaUDF((i: Int) => i + 1, IntegerType, Literal(1) :: Nil, true :: Nil)
31+
val intUdf = ScalaUDF((i: Int) => i + 1, IntegerType, Literal(1) :: Nil,
32+
Option(ExpressionEncoder[Int]()) :: Nil)
3133
checkEvaluation(intUdf, 2)
3234

33-
val stringUdf = ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil, false :: Nil)
35+
val stringUdf = ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil,
36+
Option(ExpressionEncoder[String]()) :: Nil)
3437
checkEvaluation(stringUdf, "ax")
3538
}
3639

@@ -39,7 +42,7 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {
3942
(s: String) => s.toLowerCase(Locale.ROOT),
4043
StringType,
4144
Literal.create(null, StringType) :: Nil,
42-
false :: Nil)
45+
Option(ExpressionEncoder[String]()) :: Nil)
4346

4447
val e1 = intercept[SparkException](udf.eval())
4548
assert(e1.getMessage.contains("Failed to execute user defined function"))
@@ -52,7 +55,8 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {
5255

5356
test("SPARK-22695: ScalaUDF should not use global variables") {
5457
val ctx = new CodegenContext
55-
ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil, false :: Nil).genCode(ctx)
58+
ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil,
59+
Option(ExpressionEncoder[String]()) :: Nil).genCode(ctx)
5660
assert(ctx.inlinedMutableStates.isEmpty)
5761
}
5862

@@ -61,7 +65,8 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {
6165
val udf = ScalaUDF(
6266
(a: java.math.BigDecimal) => a.multiply(new java.math.BigDecimal(100)),
6367
DecimalType.SYSTEM_DEFAULT,
64-
Literal(BigDecimal("12345678901234567890.123")) :: Nil, false :: Nil)
68+
Literal(BigDecimal("12345678901234567890.123")) :: Nil,
69+
Option(ExpressionEncoder[java.math.BigDecimal]()) :: Nil)
6570
val e1 = intercept[ArithmeticException](udf.eval())
6671
assert(e1.getMessage.contains("cannot be represented as Decimal"))
6772
val e2 = intercept[SparkException] {
@@ -73,7 +78,8 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {
7378
val udf = ScalaUDF(
7479
(a: java.math.BigDecimal) => a.multiply(new java.math.BigDecimal(100)),
7580
DecimalType.SYSTEM_DEFAULT,
76-
Literal(BigDecimal("12345678901234567890.123")) :: Nil, false :: Nil)
81+
Literal(BigDecimal("12345678901234567890.123")) :: Nil,
82+
Option(ExpressionEncoder[java.math.BigDecimal]()) :: Nil)
7783
checkEvaluation(udf, null)
7884
}
7985
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry}
2222
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
2323
import org.apache.spark.sql.catalyst.dsl.expressions._
2424
import org.apache.spark.sql.catalyst.dsl.plans._
25+
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
2526
import org.apache.spark.sql.catalyst.expressions._
2627
import org.apache.spark.sql.catalyst.plans._
2728
import org.apache.spark.sql.catalyst.plans.logical._
@@ -244,7 +245,8 @@ class EliminateSortsSuite extends PlanTest {
244245
}
245246

246247
test("should not remove orderBy in groupBy clause with ScalaUDF as aggs") {
247-
val scalaUdf = ScalaUDF((s: Int) => s, IntegerType, 'a :: Nil, true :: Nil)
248+
val scalaUdf = ScalaUDF((s: Int) => s, IntegerType, 'a :: Nil,
249+
Option(ExpressionEncoder[Int]()) :: Nil)
248250
val projectPlan = testRelation.select('a, 'b)
249251
val orderByPlan = projectPlan.orderBy('a.asc, 'b.desc)
250252
val groupByPlan = orderByPlan.groupBy('a)(scalaUdf)

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ import org.apache.spark.SparkFunSuite
3131
import org.apache.spark.sql.catalyst.{AliasIdentifier, FunctionIdentifier, InternalRow, TableIdentifier}
3232
import org.apache.spark.sql.catalyst.catalog._
3333
import org.apache.spark.sql.catalyst.dsl.expressions.DslString
34+
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
3435
import org.apache.spark.sql.catalyst.expressions._
3536
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
3637
import org.apache.spark.sql.catalyst.plans.{LeftOuter, NaturalJoin, SQLHelper}
@@ -594,7 +595,8 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
594595
}
595596

596597
test("toJSON should not throws java.lang.StackOverflowError") {
597-
val udf = ScalaUDF(SelfReferenceUDF(), BooleanType, Seq("col1".attr), false :: Nil)
598+
val udf = ScalaUDF(SelfReferenceUDF(), BooleanType, Seq("col1".attr),
599+
Option(ExpressionEncoder[String]()) :: Nil)
598600
// Should not throw java.lang.StackOverflowError
599601
udf.toJSON
600602
}

0 commit comments

Comments
 (0)