diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 68edf851bf2a..30399f17dbb4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -256,7 +256,8 @@ class Analyzer(
     Batch("Nondeterministic", Once,
       PullOutNondeterministic),
     Batch("UDF", Once,
-      HandleNullInputsForUDF),
+      HandleNullInputsForUDF,
+      ResolveEncodersInUDF),
     Batch("UpdateNullability", Once,
       UpdateAttributeNullability),
     Batch("Subquery", Once,
@@ -2847,6 +2848,45 @@ class Analyzer(
     }
   }
 
+  /**
+   * Resolve the encoders for the UDF by explicitly giving the attributes. We give the
+   * attributes explicitly so that we can handle the case where the data type of the input
+   * value is not the same as the internal schema of the encoder, which could cause
+   * data loss. For example, the encoder should not cast the input value to Decimal(38, 18)
+   * if the actual data type is Decimal(30, 0).
+   *
+   * The resolved encoders will then be used to deserialize the internal row to a Scala value.
+   */
+  object ResolveEncodersInUDF extends Rule[LogicalPlan] {
+    override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
+      case p if !p.resolved => p // Skip unresolved nodes.
+
+      case p => p transformExpressionsUp {
+
+        case udf: ScalaUDF if udf.inputEncoders.nonEmpty =>
+          val boundEncoders = udf.inputEncoders.zipWithIndex.map { case (encOpt, i) =>
+            val dataType = udf.children(i).dataType
+            if (dataType.existsRecursively(_.isInstanceOf[UserDefinedType[_]])) {
+              // for UDT, we use `CatalystTypeConverters`
+              None
+            } else {
+              encOpt.map { enc =>
+                val attrs = if (enc.isSerializedAsStructForTopLevel) {
+                  dataType.asInstanceOf[StructType].toAttributes
+                } else {
+                  // the field name doesn't matter here, so we use
+                  // a simple literal to avoid any overhead
+                  new StructType().add("input", dataType).toAttributes
+                }
+                enc.resolveAndBind(attrs)
+              }
+            }
+          }
+          udf.copy(inputEncoders = boundEncoders)
+      }
+    }
+  }
+
   /**
    * Check and add proper window frames for all window functions.
    */
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
index e80f03ea8475..1e3e6d90b850 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
@@ -17,14 +17,13 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import scala.collection.mutable
-
 import org.apache.spark.SparkException
-import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
+import org.apache.spark.sql.catalyst.CatalystTypeConverters.{createToCatalystConverter, createToScalaConverter => catalystCreateToScalaConverter, isPrimitive}
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
-import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, DataType}
+import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, DataType, UserDefinedType}
 
 /**
  * User-defined function.
@@ -103,21 +102,46 @@ case class ScalaUDF(
     }
   }
 
-  private def createToScalaConverter(i: Int, dataType: DataType): Any => Any = {
-    if (inputEncoders.isEmpty) {
-      // for untyped Scala UDF
-      CatalystTypeConverters.createToScalaConverter(dataType)
-    } else {
-      val encoder = inputEncoders(i)
-      if (encoder.isDefined && encoder.get.isSerializedAsStructForTopLevel) {
-        val fromRow = encoder.get.resolveAndBind().createDeserializer()
+  /**
+   * Create the converter which converts the Catalyst data type to the Scala data type.
+   * We use `CatalystTypeConverters` to create the converter for:
+   *   - a UDF which doesn't provide inputEncoders, e.g., untyped Scala UDF and Java UDF
+   *   - a type which isn't supported by `ExpressionEncoder`, e.g., Any
+   *   - primitive types, in order to use `identity` for better performance
+   *   - UserDefinedType which isn't fully supported by `ExpressionEncoder`
+   * For other cases, e.g., case class and Option[T], we use `ExpressionEncoder` instead,
+   * since `CatalystTypeConverters` doesn't support these data types.
+   *
+   * @param i the index of the child
+   * @param dataType the output data type of the i-th child
+   * @return the converter and a boolean value to indicate whether the converter is
+   *         created by using `ExpressionEncoder`.
+   */
+  private def scalaConverter(i: Int, dataType: DataType): (Any => Any, Boolean) = {
+    val useEncoder =
+      !(inputEncoders.isEmpty || // for untyped Scala UDF and Java UDF
+        inputEncoders(i).isEmpty || // for types that aren't supported by the encoder, e.g., Any
+        inputPrimitives(i) || // for primitive types
+        dataType.existsRecursively(_.isInstanceOf[UserDefinedType[_]]))
+
+    if (useEncoder) {
+      val enc = inputEncoders(i).get
+      val fromRow = enc.createDeserializer()
+      val converter = if (enc.isSerializedAsStructForTopLevel) {
         row: Any => fromRow(row.asInstanceOf[InternalRow])
       } else {
-        CatalystTypeConverters.createToScalaConverter(dataType)
+        val inputRow = new GenericInternalRow(1)
+        value: Any => inputRow.update(0, value); fromRow(inputRow)
       }
+      (converter, true)
+    } else { // use CatalystTypeConverters
+      (catalystCreateToScalaConverter(dataType), false)
     }
   }
 
+  private def createToScalaConverter(i: Int, dataType: DataType): Any => Any =
+    scalaConverter(i, dataType)._1
+
   // scalastyle:off line.size.limit
 
   /** This method has been generated by this script
@@ -1045,10 +1069,11 @@ case class ScalaUDF(
       ev: ExprCode): ExprCode = {
     val converterClassName = classOf[Any => Any].getName
 
-    // The type converters for inputs and the result.
-    val converters: Array[Any => Any] = children.zipWithIndex.map { case (c, i) =>
-      createToScalaConverter(i, c.dataType)
-    }.toArray :+ CatalystTypeConverters.createToCatalystConverter(dataType)
+    // The type converters for inputs and the result
+    val (converters, useEncoders): (Array[Any => Any], Array[Boolean]) =
+      (children.zipWithIndex.map { case (c, i) =>
+        scalaConverter(i, c.dataType)
+      }.toArray :+ (createToCatalystConverter(dataType), false)).unzip
     val convertersTerm = ctx.addReferenceObj("converters", converters, s"$converterClassName[]")
     val errorMsgTerm = ctx.addReferenceObj("errMsg", udfErrorMessage)
     val resultTerm = ctx.freshName("result")
@@ -1064,12 +1089,26 @@ case class ScalaUDF(
     val (funcArgs, initArgs) = evals.zipWithIndex.zip(children.map(_.dataType)).map {
       case ((eval, i), dt) =>
         val argTerm = ctx.freshName("arg")
-        val initArg = if (CatalystTypeConverters.isPrimitive(dt)) {
+        // Check `inputPrimitives` when it's not empty so that Option types, e.g.,
+        // Option[Int], are treated as non-primitive. Fall back to `isPrimitive` when
+        // `inputPrimitives` is empty, e.g., for Java UDFs and untyped Scala UDFs.
+        val primitive = (inputPrimitives.isEmpty && isPrimitive(dt)) ||
+          (inputPrimitives.nonEmpty && inputPrimitives(i))
+        val initArg = if (primitive) {
           val convertedTerm = ctx.freshName("conv")
           s"""
              |${CodeGenerator.boxedType(dt)} $convertedTerm = ${eval.value};
              |Object $argTerm = ${eval.isNull} ? null : $convertedTerm;
            """.stripMargin
+        } else if (useEncoders(i)) {
+          s"""
+             |Object $argTerm = null;
+             |if (${eval.isNull}) {
+             |  $argTerm = $convertersTerm[$i].apply(null);
+             |} else {
+             |  $argTerm = $convertersTerm[$i].apply(${eval.value});
+             |}
+           """.stripMargin
         } else {
           s"Object $argTerm = ${eval.isNull} ? null : $convertersTerm[$i].apply(${eval.value});"
         }
@@ -1081,7 +1120,7 @@ case class ScalaUDF(
     val resultConverter = s"$convertersTerm[${children.length}]"
     val boxedType = CodeGenerator.boxedType(dataType)
 
-    val funcInvokation = if (CatalystTypeConverters.isPrimitive(dataType)
+    val funcInvokation = if (isPrimitive(dataType)
         // If the output is nullable, the returned value must be unwrapped from the Option
         && !nullable) {
       s"$resultTerm = ($boxedType)$getFuncResult"
@@ -1112,7 +1151,7 @@ case class ScalaUDF(
       """.stripMargin)
   }
 
-  private[this] val resultConverter = CatalystTypeConverters.createToCatalystConverter(dataType)
+  private[this] val resultConverter = createToCatalystConverter(dataType)
 
   lazy val udfErrorMessage = {
     val funcCls = function.getClass.getSimpleName
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 02472e153b09..189152374b0d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.analysis
 import java.util.{Locale, TimeZone}
 
 import scala.reflect.ClassTag
+import scala.reflect.runtime.universe.TypeTag
 
 import org.apache.log4j.Level
 import org.scalatest.Matchers
@@ -307,6 +308,10 @@ class AnalysisSuite extends AnalysisTest with Matchers {
   }
 
   test("SPARK-11725: correctly handle null inputs for ScalaUDF") {
+    def resolvedEncoder[T : TypeTag](): ExpressionEncoder[T] = {
+      ExpressionEncoder[T]().resolveAndBind()
+    }
+
     val testRelation = LocalRelation(
       AttributeReference("a", StringType)(),
       AttributeReference("b", DoubleType)(),
@@ -328,20 +333,20 @@ class AnalysisSuite extends AnalysisTest with Matchers {
 
     // non-primitive parameters do not need special null handling
     val udf1 = ScalaUDF((s: String) => "x", StringType, string :: Nil,
-      Option(ExpressionEncoder[String]()) :: Nil)
+      Option(resolvedEncoder[String]()) :: Nil)
     val expected1 = udf1
     checkUDF(udf1, expected1)
 
     // only primitive parameter needs special null handling
     val udf2 = ScalaUDF((s: String, d: Double) => "x", StringType, string :: double :: Nil,
-      Option(ExpressionEncoder[String]()) :: Option(ExpressionEncoder[Double]()) :: Nil)
+      Option(resolvedEncoder[String]()) :: Option(resolvedEncoder[Double]()) :: Nil)
     val expected2 =
       If(IsNull(double), nullResult, udf2.copy(children = string :: KnownNotNull(double) :: Nil))
     checkUDF(udf2, expected2)
 
     // special null handling should apply to all primitive parameters
     val udf3 = ScalaUDF((s: Short, d: Double) => "x", StringType, short :: double :: Nil,
-      Option(ExpressionEncoder[Short]()) :: Option(ExpressionEncoder[Double]()) :: Nil)
+      Option(resolvedEncoder[Short]()) :: Option(resolvedEncoder[Double]()) :: Nil)
     val expected3 = If(
       IsNull(short) || IsNull(double),
       nullResult,
@@ -353,7 +358,7 @@ class AnalysisSuite extends AnalysisTest with Matchers {
       (s: Short, d: Double) => "x",
       StringType,
       short :: nonNullableDouble :: Nil,
-      Option(ExpressionEncoder[Short]()) :: Option(ExpressionEncoder[Double]()) :: Nil)
+      Option(resolvedEncoder[Short]()) :: Option(resolvedEncoder[Double]()) :: Nil)
     val expected4 = If(
       IsNull(short),
       nullResult,
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala
index 836b2eaa642a..1b40e02aa866 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.expressions
 
 import java.util.Locale
 
+import scala.reflect.runtime.universe.TypeTag
+
 import org.apache.spark.{SparkException, SparkFunSuite}
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
@@ -27,13 +29,17 @@ import org.apache.spark.sql.types.{DecimalType, IntegerType, StringType}
 
 class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {
 
+  private def resolvedEncoder[T : TypeTag](): ExpressionEncoder[T] = {
+    ExpressionEncoder[T]().resolveAndBind()
+  }
+
   test("basic") {
     val intUdf = ScalaUDF((i: Int) => i + 1, IntegerType, Literal(1) :: Nil,
-      Option(ExpressionEncoder[Int]()) :: Nil)
+      Option(resolvedEncoder[Int]()) :: Nil)
     checkEvaluation(intUdf, 2)
 
     val stringUdf = ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil,
-      Option(ExpressionEncoder[String]()) :: Nil)
+      Option(resolvedEncoder[String]()) :: Nil)
     checkEvaluation(stringUdf, "ax")
   }
 
@@ -42,7 +48,7 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {
       (s: String) => s.toLowerCase(Locale.ROOT),
       StringType,
       Literal.create(null, StringType) :: Nil,
-      Option(ExpressionEncoder[String]()) :: Nil)
+      Option(resolvedEncoder[String]()) :: Nil)
 
     val e1 = intercept[SparkException](udf.eval())
     assert(e1.getMessage.contains("Failed to execute user defined function"))
@@ -56,7 +62,7 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {
 
   test("SPARK-22695: ScalaUDF should not use global variables") {
     val ctx = new CodegenContext
     ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil,
-      Option(ExpressionEncoder[String]()) :: Nil).genCode(ctx)
+      Option(resolvedEncoder[String]()) :: Nil).genCode(ctx)
     assert(ctx.inlinedMutableStates.isEmpty)
   }
 
@@ -66,7 +72,7 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {
       (a: java.math.BigDecimal) => a.multiply(new java.math.BigDecimal(100)),
       DecimalType.SYSTEM_DEFAULT,
       Literal(BigDecimal("12345678901234567890.123")) :: Nil,
-      Option(ExpressionEncoder[java.math.BigDecimal]()) :: Nil)
+      Option(resolvedEncoder[java.math.BigDecimal]()) :: Nil)
     val e1 = intercept[ArithmeticException](udf.eval())
     assert(e1.getMessage.contains("cannot be represented as Decimal"))
     val e2 = intercept[SparkException] {
@@ -79,7 +85,7 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {
       (a: java.math.BigDecimal) => a.multiply(new java.math.BigDecimal(100)),
      DecimalType.SYSTEM_DEFAULT,
       Literal(BigDecimal("12345678901234567890.123")) :: Nil,
-      Option(ExpressionEncoder[java.math.BigDecimal]()) :: Nil)
+      Option(resolvedEncoder[java.math.BigDecimal]()) :: Nil)
     checkEvaluation(udf, null)
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index e2747d7db9f3..5c1fe265c15d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.execution.{QueryExecution, SimpleMode}
 import org.apache.spark.sql.execution.columnar.InMemoryRelation
 import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, ExplainCommand}
 import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand
-import org.apache.spark.sql.functions.{lit, udf}
+import org.apache.spark.sql.functions.{lit, struct, udf}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.test.SQLTestData._
@@ -581,4 +581,92 @@ class UDFSuite extends QueryTest with SharedSparkSession {
       .toDF("col1", "col2")
     checkAnswer(df.select(myUdf(Column("col1"), Column("col2"))), Row(2020) :: Nil)
   }
+
+  test("case class as element type of Seq/Array") {
+    val f1 = (s: Seq[TestData]) => s.map(d => d.key * d.value.toInt).sum
+    val myUdf1 = udf(f1)
+    val df1 = Seq(("data", Seq(TestData(50, "2")))).toDF("col1", "col2")
+    checkAnswer(df1.select(myUdf1(Column("col2"))), Row(100) :: Nil)
+
+    val f2 = (s: Array[TestData]) => s.map(d => d.key * d.value.toInt).sum
+    val myUdf2 = udf(f2)
+    val df2 = Seq(("data", Array(TestData(50, "2")))).toDF("col1", "col2")
+    checkAnswer(df2.select(myUdf2(Column("col2"))), Row(100) :: Nil)
+  }
+
+  test("case class as key/value type of Map") {
+    val f1 = (s: Map[TestData, Int]) => s.keys.head.key * s.keys.head.value.toInt
+    val myUdf1 = udf(f1)
+    val df1 = Seq(("data", Map(TestData(50, "2") -> 502))).toDF("col1", "col2")
+    checkAnswer(df1.select(myUdf1(Column("col2"))), Row(100) :: Nil)
+
+    val f2 = (s: Map[Int, TestData]) => s.values.head.key * s.values.head.value.toInt
+    val myUdf2 = udf(f2)
+    val df2 = Seq(("data", Map(502 -> TestData(50, "2")))).toDF("col1", "col2")
+    checkAnswer(df2.select(myUdf2(Column("col2"))), Row(100) :: Nil)
+
+    val f3 = (s: Map[TestData, TestData]) => s.keys.head.key * s.values.head.value.toInt
+    val myUdf3 = udf(f3)
+    val df3 = Seq(("data", Map(TestData(50, "2") -> TestData(50, "2")))).toDF("col1", "col2")
+    checkAnswer(df3.select(myUdf3(Column("col2"))), Row(100) :: Nil)
+  }
+
+  test("case class as element of tuple") {
+    val f = (s: (TestData, Int)) => s._1.key * s._2
+    val myUdf = udf(f)
+    val df = Seq(("data", (TestData(50, "2"), 2))).toDF("col1", "col2")
+    checkAnswer(df.select(myUdf(Column("col2"))), Row(100) :: Nil)
+  }
+
+  test("case class as generic type of Option") {
+    val f = (o: Option[TestData]) => o.map(t => t.key * t.value.toInt)
+    val myUdf = udf(f)
+    val df1 = Seq(("data", Some(TestData(50, "2")))).toDF("col1", "col2")
+    checkAnswer(df1.select(myUdf(Column("col2"))), Row(100) :: Nil)
+    val df2 = Seq(("data", None: Option[TestData])).toDF("col1", "col2")
+    checkAnswer(df2.select(myUdf(Column("col2"))), Row(null) :: Nil)
+  }
+
+  test("more input fields than expected for case class") {
+    val f = (t: TestData2) => t.a * t.b
+    val myUdf = udf(f)
+    val df = spark.range(1)
+      .select(lit(50).as("a"), lit(2).as("b"), lit(2).as("c"))
+      .select(struct("a", "b", "c").as("col"))
+    checkAnswer(df.select(myUdf(Column("col"))), Row(100) :: Nil)
+  }
+
+  test("fewer input fields than expected for case class") {
+    val f = (t: TestData2) => t.a * t.b
+    val myUdf = udf(f)
+    val df = spark.range(1)
+      .select(lit(50).as("a"))
+      .select(struct("a").as("col"))
+    val error = intercept[AnalysisException](df.select(myUdf(Column("col"))))
+    assert(error.getMessage.contains("cannot resolve '`b`' given input columns: [a]"))
+  }
+
+  test("wrong order of input fields for case class") {
+    val f = (t: TestData) => t.key * t.value.toInt
+    val myUdf = udf(f)
+    val df = spark.range(1)
+      .select(lit("2").as("value"), lit(50).as("key"))
+      .select(struct("value", "key").as("col"))
+    checkAnswer(df.select(myUdf(Column("col"))), Row(100) :: Nil)
+  }
+
+  test("top level Option primitive type") {
+    val f = (i: Option[Int]) => i.map(_ * 10)
+    val myUdf = udf(f)
+    val df = Seq(Some(10), None).toDF("col")
+    checkAnswer(df.select(myUdf(Column("col"))), Row(100) :: Row(null) :: Nil)
+  }
+
+  test("array Option") {
+    val f = (i: Array[Option[TestData]]) =>
+      i.map(_.map(t => t.key * t.value.toInt).getOrElse(0)).sum
+    val myUdf = udf(f)
+    val df = Seq(Array(Some(TestData(50, "2")), None)).toDF("col")
+    checkAnswer(df.select(myUdf(Column("col"))), Row(100) :: Nil)
+  }
 }