@@ -256,7 +256,8 @@ class Analyzer(
     Batch("Nondeterministic", Once,
       PullOutNondeterministic),
     Batch("UDF", Once,
-      HandleNullInputsForUDF),
+      HandleNullInputsForUDF,
+      ResolveEncodersInUDF),
     Batch("UpdateNullability", Once,
       UpdateAttributeNullability),
     Batch("Subquery", Once,
@@ -2847,6 +2848,45 @@ class Analyzer(
}
}

+  /**
+   * Resolve the encoders for the UDF by explicitly giving the attributes. We give the
+   * attributes explicitly in order to handle the case where the data type of the input
+   * value is not the same as the internal schema of the encoder, which could cause
+   * data loss. For example, the encoder should not cast the input value to Decimal(38, 18)
+   * if the actual data type is Decimal(30, 0).
+   *
+   * The resolved encoders will then be used to deserialize the internal row to Scala values.
+   */
+  object ResolveEncodersInUDF extends Rule[LogicalPlan] {
+    override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
+      case p if !p.resolved => p // Skip unresolved nodes.
+
+      case p => p transformExpressionsUp {
+
+        case udf: ScalaUDF if udf.inputEncoders.nonEmpty =>
+          val boundEncoders = udf.inputEncoders.zipWithIndex.map { case (encOpt, i) =>
+            val dataType = udf.children(i).dataType
+            if (dataType.existsRecursively(_.isInstanceOf[UserDefinedType[_]])) {
+              // for UDT, we use `CatalystTypeConverters`
[Review comment — Contributor]
The encoder does support UDT. We can figure it out later.

[Reply — Member (author)]
It does, but it doesn't support upcasting from a subclass to the parent class. So, when the input data type from the child is a subclass of the UDF's input parameter data type, resolveAndBind can fail.

I think this may need a separate fix.
+              None
+            } else {
+              encOpt.map { enc =>
+                val attrs = if (enc.isSerializedAsStructForTopLevel) {
+                  dataType.asInstanceOf[StructType].toAttributes
+                } else {
+                  // the field name doesn't matter here, so we use
+                  // a simple literal to avoid any overhead
+                  new StructType().add("input", dataType).toAttributes
+                }
+                enc.resolveAndBind(attrs)
+              }
+            }
+          }
+          udf.copy(inputEncoders = boundEncoders)
+      }
+    }
+  }
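To make the Decimal example above concrete, here is a small standalone sketch (not part of the diff; it assumes Spark's catalyst internals as of this change, and the object and field names are illustrative only):

import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.types._

object ResolveEncodersSketch {
  def main(args: Array[String]): Unit = {
    val enc = ExpressionEncoder[java.math.BigDecimal]()
    // Left alone, the encoder assumes the system default decimal type,
    // typically Decimal(38, 18).
    println(enc.schema)

    // ResolveEncodersInUDF instead builds the attributes from the child's actual
    // data type; for a flat (non-struct) encoder the field name is irrelevant.
    val attrs = new StructType().add("input", DecimalType(30, 0)).toAttributes
    val resolved = enc.resolveAndBind(attrs)
    // `resolved` now deserializes Decimal(30, 0) input without a lossy cast
    // to Decimal(38, 18).
  }
}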

/**
* Check and add proper window frames for all window functions.
*/
@@ -17,14 +17,13 @@

package org.apache.spark.sql.catalyst.expressions

-import scala.collection.mutable
-
 import org.apache.spark.SparkException
-import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
+import org.apache.spark.sql.catalyst.CatalystTypeConverters.{createToCatalystConverter, createToScalaConverter => catalystCreateToScalaConverter, isPrimitive}
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
-import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, DataType}
+import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, DataType, UserDefinedType}

/**
* User-defined function.
@@ -103,21 +102,46 @@ case class ScalaUDF(
}
}

-  private def createToScalaConverter(i: Int, dataType: DataType): Any => Any = {
-    if (inputEncoders.isEmpty) {
-      // for untyped Scala UDF
-      CatalystTypeConverters.createToScalaConverter(dataType)
-    } else {
-      val encoder = inputEncoders(i)
-      if (encoder.isDefined && encoder.get.isSerializedAsStructForTopLevel) {
-        val fromRow = encoder.get.resolveAndBind().createDeserializer()
+  /**
+   * Create the converter that converts the Catalyst data type to the Scala data type.
+   * We use `CatalystTypeConverters` to create the converter for:
+   *   - UDFs which don't provide inputEncoders, e.g., untyped Scala UDFs and Java UDFs
+   *   - types which aren't supported by `ExpressionEncoder`, e.g., Any
+   *   - primitive types, in order to use `identity` for better performance
+   *   - UserDefinedType, which isn't fully supported by `ExpressionEncoder`
+   * For other cases, like case classes and Option[T], we use `ExpressionEncoder` instead,
+   * since `CatalystTypeConverters` doesn't support these data types.
+   *
+   * @param i the index of the child
+   * @param dataType the output data type of the i-th child
+   * @return the converter and a boolean value to indicate whether the converter is
+   *         created by using `ExpressionEncoder`.
+   */
+  private def scalaConverter(i: Int, dataType: DataType): (Any => Any, Boolean) = {
+    val useEncoder =
+      !(inputEncoders.isEmpty || // for untyped Scala UDFs and Java UDFs
+        inputEncoders(i).isEmpty || // for types that aren't supported by the encoder, e.g., Any
+        inputPrimitives(i) || // for primitive types
+        dataType.existsRecursively(_.isInstanceOf[UserDefinedType[_]]))
+
+    if (useEncoder) {
+      val enc = inputEncoders(i).get
+      val fromRow = enc.createDeserializer()
+      val converter = if (enc.isSerializedAsStructForTopLevel) {
         row: Any => fromRow(row.asInstanceOf[InternalRow])
       } else {
-        CatalystTypeConverters.createToScalaConverter(dataType)
+        val inputRow = new GenericInternalRow(1)
+        value: Any => inputRow.update(0, value); fromRow(inputRow)
       }
+      (converter, true)
+    } else { // use CatalystTypeConverters
+      (catalystCreateToScalaConverter(dataType), false)
     }
   }
+
+  private def createToScalaConverter(i: Int, dataType: DataType): Any => Any =
+    scalaConverter(i, dataType)._1
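To make the two deserialization shapes concrete, here is a standalone sketch (not from the PR; it assumes catalyst internals such as ExpressionEncoder.createDeserializer and GenericInternalRow, and the Point class is hypothetical):

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.unsafe.types.UTF8String

// Hypothetical case class standing in for a struct-serialized UDF input.
case class Point(x: Int, y: Int)

object ConverterShapesSketch {
  def main(args: Array[String]): Unit = {
    // Struct-serialized input (case class): the incoming value is already an
    // InternalRow, so the deserializer consumes it directly.
    val structDeser = ExpressionEncoder[Point]().resolveAndBind().createDeserializer()
    println(structDeser(InternalRow(1, 2))) // Point(1,2)

    // Flat input (e.g. String): the single value is first placed into a reused
    // one-slot row, mirroring the `inputRow.update(0, value)` branch above.
    val flatDeser = ExpressionEncoder[String]().resolveAndBind().createDeserializer()
    val inputRow = new GenericInternalRow(1)
    inputRow.update(0, UTF8String.fromString("a"))
    println(flatDeser(inputRow)) // a
  }
}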

// scalastyle:off line.size.limit

/** This method has been generated by this script
@@ -1045,10 +1069,11 @@ case class ScalaUDF(
ev: ExprCode): ExprCode = {
val converterClassName = classOf[Any => Any].getName

-    // The type converters for inputs and the result.
-    val converters: Array[Any => Any] = children.zipWithIndex.map { case (c, i) =>
-      createToScalaConverter(i, c.dataType)
-    }.toArray :+ CatalystTypeConverters.createToCatalystConverter(dataType)
+    // The type converters for inputs and the result
+    val (converters, useEncoders): (Array[Any => Any], Array[Boolean]) =
+      (children.zipWithIndex.map { case (c, i) =>
+        scalaConverter(i, c.dataType)
+      }.toArray :+ (createToCatalystConverter(dataType), false)).unzip
val convertersTerm = ctx.addReferenceObj("converters", converters, s"$converterClassName[]")
val errorMsgTerm = ctx.addReferenceObj("errMsg", udfErrorMessage)
val resultTerm = ctx.freshName("result")
@@ -1064,12 +1089,26 @@ case class ScalaUDF(
val (funcArgs, initArgs) = evals.zipWithIndex.zip(children.map(_.dataType)).map {
case ((eval, i), dt) =>
val argTerm = ctx.freshName("arg")
-      val initArg = if (CatalystTypeConverters.isPrimitive(dt)) {
+      // Check `inputPrimitives` when it's not empty in order to treat Option types as
+      // non-primitive, e.g., Option[Int]. Fall back to `isPrimitive` when
+      // `inputPrimitives` is empty, for the other cases, e.g., Java UDFs and untyped Scala UDFs.
[Review comment — Contributor]
So untyped Scala UDF doesn't support Option?

[Reply — Member (author)]
Yea. We require the encoder to support Option, but untyped Scala UDF can't provide the encoder.
+      val primitive = (inputPrimitives.isEmpty && isPrimitive(dt)) ||
+        (inputPrimitives.nonEmpty && inputPrimitives(i))
+      val initArg = if (primitive) {
         val convertedTerm = ctx.freshName("conv")
         s"""
            |${CodeGenerator.boxedType(dt)} $convertedTerm = ${eval.value};
            |Object $argTerm = ${eval.isNull} ? null : $convertedTerm;
          """.stripMargin
+      } else if (useEncoders(i)) {
+        s"""
+           |Object $argTerm = null;
+           |if (${eval.isNull}) {
+           |  $argTerm = $convertersTerm[$i].apply(null);
+           |} else {
+           |  $argTerm = $convertersTerm[$i].apply(${eval.value});
+           |}
+         """.stripMargin
       } else {
         s"Object $argTerm = ${eval.isNull} ? null : $convertersTerm[$i].apply(${eval.value});"
       }
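The Option discussion above can also be seen outside the generated code: for an Option[Int] parameter, even a SQL NULL must go through the encoder's converter so that it becomes None rather than null, which is why the encoder branch calls apply(null) instead of short-circuiting. A minimal sketch, assuming catalyst internals:

import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow

object OptionNullSketch {
  def main(args: Array[String]): Unit = {
    val deser = ExpressionEncoder[Option[Int]]().resolveAndBind().createDeserializer()
    val row = new GenericInternalRow(1) // one slot, initially null
    println(deser(row)) // None: a null input deserializes to None, not null
    row.update(0, 1)
    println(deser(row)) // Some(1)
  }
}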
@@ -1081,7 +1120,7 @@ case class ScalaUDF(
val resultConverter = s"$convertersTerm[${children.length}]"
val boxedType = CodeGenerator.boxedType(dataType)

-    val funcInvokation = if (CatalystTypeConverters.isPrimitive(dataType)
+    val funcInvokation = if (isPrimitive(dataType)
// If the output is nullable, the returned value must be unwrapped from the Option
&& !nullable) {
s"$resultTerm = ($boxedType)$getFuncResult"
@@ -1112,7 +1151,7 @@ case class ScalaUDF(
""".stripMargin)
}

-  private[this] val resultConverter = CatalystTypeConverters.createToCatalystConverter(dataType)
+  private[this] val resultConverter = createToCatalystConverter(dataType)

lazy val udfErrorMessage = {
val funcCls = function.getClass.getSimpleName
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.analysis
import java.util.{Locale, TimeZone}

import scala.reflect.ClassTag
+import scala.reflect.runtime.universe.TypeTag

import org.apache.log4j.Level
import org.scalatest.Matchers
@@ -307,6 +308,10 @@ class AnalysisSuite extends AnalysisTest with Matchers {
}

test("SPARK-11725: correctly handle null inputs for ScalaUDF") {
+    def resolvedEncoder[T : TypeTag](): ExpressionEncoder[T] = {
+      ExpressionEncoder[T]().resolveAndBind()
+    }

val testRelation = LocalRelation(
AttributeReference("a", StringType)(),
AttributeReference("b", DoubleType)(),
@@ -328,20 +333,20 @@

// non-primitive parameters do not need special null handling
     val udf1 = ScalaUDF((s: String) => "x", StringType, string :: Nil,
-      Option(ExpressionEncoder[String]()) :: Nil)
+      Option(resolvedEncoder[String]()) :: Nil)
val expected1 = udf1
checkUDF(udf1, expected1)

// only primitive parameter needs special null handling
     val udf2 = ScalaUDF((s: String, d: Double) => "x", StringType, string :: double :: Nil,
-      Option(ExpressionEncoder[String]()) :: Option(ExpressionEncoder[Double]()) :: Nil)
+      Option(resolvedEncoder[String]()) :: Option(resolvedEncoder[Double]()) :: Nil)
val expected2 =
If(IsNull(double), nullResult, udf2.copy(children = string :: KnownNotNull(double) :: Nil))
checkUDF(udf2, expected2)

// special null handling should apply to all primitive parameters
     val udf3 = ScalaUDF((s: Short, d: Double) => "x", StringType, short :: double :: Nil,
-      Option(ExpressionEncoder[Short]()) :: Option(ExpressionEncoder[Double]()) :: Nil)
+      Option(resolvedEncoder[Short]()) :: Option(resolvedEncoder[Double]()) :: Nil)
val expected3 = If(
IsNull(short) || IsNull(double),
nullResult,
Expand All @@ -353,7 +358,7 @@ class AnalysisSuite extends AnalysisTest with Matchers {
(s: Short, d: Double) => "x",
StringType,
short :: nonNullableDouble :: Nil,
-      Option(ExpressionEncoder[Short]()) :: Option(ExpressionEncoder[Double]()) :: Nil)
+      Option(resolvedEncoder[Short]()) :: Option(resolvedEncoder[Double]()) :: Nil)
val expected4 = If(
IsNull(short),
nullResult,
@@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.expressions

import java.util.Locale

+import scala.reflect.runtime.universe.TypeTag

import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
@@ -27,13 +29,17 @@ import org.apache.spark.sql.types.{DecimalType, IntegerType, StringType}

class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {

+  private def resolvedEncoder[T : TypeTag](): ExpressionEncoder[T] = {
+    ExpressionEncoder[T]().resolveAndBind()
+  }

test("basic") {
     val intUdf = ScalaUDF((i: Int) => i + 1, IntegerType, Literal(1) :: Nil,
-      Option(ExpressionEncoder[Int]()) :: Nil)
+      Option(resolvedEncoder[Int]()) :: Nil)
checkEvaluation(intUdf, 2)

     val stringUdf = ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil,
-      Option(ExpressionEncoder[String]()) :: Nil)
+      Option(resolvedEncoder[String]()) :: Nil)
checkEvaluation(stringUdf, "ax")
}

@@ -42,7 +48,7 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {
(s: String) => s.toLowerCase(Locale.ROOT),
StringType,
Literal.create(null, StringType) :: Nil,
-      Option(ExpressionEncoder[String]()) :: Nil)
+      Option(resolvedEncoder[String]()) :: Nil)

val e1 = intercept[SparkException](udf.eval())
assert(e1.getMessage.contains("Failed to execute user defined function"))
@@ -56,7 +62,7 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {
test("SPARK-22695: ScalaUDF should not use global variables") {
val ctx = new CodegenContext
ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil,
-      Option(ExpressionEncoder[String]()) :: Nil).genCode(ctx)
+      Option(resolvedEncoder[String]()) :: Nil).genCode(ctx)
assert(ctx.inlinedMutableStates.isEmpty)
}

@@ -66,7 +72,7 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {
(a: java.math.BigDecimal) => a.multiply(new java.math.BigDecimal(100)),
DecimalType.SYSTEM_DEFAULT,
Literal(BigDecimal("12345678901234567890.123")) :: Nil,
-      Option(ExpressionEncoder[java.math.BigDecimal]()) :: Nil)
+      Option(resolvedEncoder[java.math.BigDecimal]()) :: Nil)
val e1 = intercept[ArithmeticException](udf.eval())
assert(e1.getMessage.contains("cannot be represented as Decimal"))
val e2 = intercept[SparkException] {
@@ -79,7 +85,7 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {
(a: java.math.BigDecimal) => a.multiply(new java.math.BigDecimal(100)),
DecimalType.SYSTEM_DEFAULT,
Literal(BigDecimal("12345678901234567890.123")) :: Nil,
-      Option(ExpressionEncoder[java.math.BigDecimal]()) :: Nil)
+      Option(resolvedEncoder[java.math.BigDecimal]()) :: Nil)
checkEvaluation(udf, null)
}
}
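For context, the user-facing behavior this machinery serves, as an end-to-end sketch (not part of the diff; it assumes a local SparkSession, and the Point class and column names are hypothetical): a typed Scala UDF whose parameter is a case class receives a properly deserialized value rather than a Row, because ResolveEncodersInUDF resolves the input encoder against the actual struct attributes during analysis.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{struct, udf}

case class Point(x: Int, y: Int)

object TypedUdfSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("sketch").getOrCreate()
    import spark.implicits._

    // The UDF body sees a Point, not a Row: the encoder for Point is resolved
    // and bound during analysis by ResolveEncodersInUDF.
    val manhattan = udf((p: Point) => math.abs(p.x) + math.abs(p.y))

    val df = Seq((1, -2)).toDF("x", "y")
    df.select(manhattan(struct($"x", $"y"))).show()
    spark.stop()
  }
}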