Commit: update
Ngone51 committed Jun 3, 2020
commit 1afbdf5c70e7b693bafb831ef3a2f3596896f7d3
@@ -257,7 +257,7 @@ class Analyzer(
        PullOutNondeterministic),
      Batch("UDF", Once,
        HandleNullInputsForUDF,
-       PrepareDeserializerForUDF),
+       ResolveEncodersInUDF),
      Batch("UpdateNullability", Once,
        UpdateAttributeNullability),
      Batch("Subquery", Once,
@@ -2814,7 +2814,7 @@ class Analyzer(

      case p => p transformExpressionsUp {

-        case udf @ ScalaUDF(_, _, inputs, _, _, _, _, _)
+        case udf @ ScalaUDF(_, _, inputs, _, _, _, _)
            if udf.inputPrimitives.contains(true) =>
          // Otherwise, add special handling of null for fields that can't accept null.
          // The result of operations like this, when passed null, is generally to return null.
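
For illustration, a plain-Scala analogue of the rewrite this rule performs (a minimal sketch, not Spark code; nullSafe and the Double parameter are hypothetical stand-ins): the wrapped UDF short-circuits to null instead of ever invoking the function with a boxed null.

    object NullHandlingSketch {
      // Hypothetical stand-in for a UDF over a primitive Double parameter.
      def nullSafe(f: Double => String): java.lang.Double => String =
        (d: java.lang.Double) => if (d == null) null else f(d)

      def main(args: Array[String]): Unit = {
        val udf = nullSafe(d => s"value=$d")
        println(udf(1.5))  // value=1.5
        println(udf(null)) // null; the function body never runs
      }
    }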
@@ -2848,20 +2848,17 @@ class Analyzer(
    }
  }

-  object PrepareDeserializerForUDF extends Rule[LogicalPlan] {
+  object ResolveEncodersInUDF extends Rule[LogicalPlan] {
    override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
      case p if !p.resolved => p // Skip unresolved nodes.

      case p => p transformExpressionsUp {

-        case udf @ ScalaUDF(_, _, inputs, encoders, _, _, _, desers)
-            if encoders.nonEmpty && desers.isEmpty =>
-          val deserializers = encoders.zipWithIndex.map { case (encOpt, i) =>
+        case udf @ ScalaUDF(_, _, inputs, encoders, _, _, _) if encoders.nonEmpty =>
> Member: Shall we avoid argument matching? It's actually an anti-pattern - https://github.com/databricks/scala-style-guide#pattern-matching
>
> Member Author: I see, thanks!
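
A minimal self-contained sketch of the reviewer's point (the case class Udf here is a hypothetical stand-in, not the real ScalaUDF): a positional match must list every field and breaks whenever the arity changes, while matching on the type and reading fields by name does not.

    // Hypothetical 7-field stand-in for ScalaUDF; not Spark code.
    case class Udf(f: Any, dt: String, inputs: Seq[Int], encoders: Seq[Option[String]],
                   name: Option[String], nullable: Boolean, deterministic: Boolean)

    object MatchStyleSketch {
      def main(args: Array[String]): Unit = {
        val u = Udf((x: Any) => x, "string", Seq(1), Seq(Some("enc")), None, true, true)
        // Positional match: breaks when a field is added or removed.
        val byPosition = u match { case Udf(_, _, in, _, _, _, _) => in }
        // Type match + named field access: survives signature changes.
        val byName = u match { case udf: Udf => udf.inputs }
        assert(byPosition == byName)
      }
    }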

+          val resolvedEncoders = encoders.zipWithIndex.map { case (encOpt, i) =>

> Contributor: maybe to call it boundEncoders.

            val dataType = inputs(i).dataType
-            if (CatalystTypeConverters.isPrimitive(dataType) ||
-                dataType.isInstanceOf[UserDefinedType[_]]) {
-              // primitive/UDT data types use `CatalystTypeConverters` to
-              // convert internal data to external data.
+            if (dataType.isInstanceOf[UserDefinedType[_]]) {
> Contributor: what about struct/array/map of UDT?

+              // for UDT, we use `CatalystTypeConverters`
> Contributor: encoder does support UDT. We can figure it out later.
>
> Member Author: It does, but just doesn't support upcast from the subclass to the parent class. So, when the input data type from the child is the subclass of the input parameter data type of the udf, resolveAndBind can fail.
>
> I think this may need a separate fix.
              None
            } else {
              encOpt.map { enc =>
@@ -2870,13 +2867,13 @@ class Analyzer(
                } else {
                  // the field name doesn't matter here, so we use
                  // a simple literal to avoid any overhead
-                  new StructType().add(s"input", dataType).toAttributes
+                  new StructType().add("input", dataType).toAttributes
                }
-                enc.resolveAndBind(attrs).createDeserializer()
+                enc.resolveAndBind(attrs)
              }
            }
          }
-          udf.copy(inputDeserializers = deserializers)
+          udf.copy(inputEncoders = resolvedEncoders)
      }
    }
  }
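
For reference, a minimal sketch of what the rule does per argument (illustrative only; enc and attrs are example names, the APIs are Spark 3.0 catalyst APIs, and StructType.toAttributes is package-private to org.apache.spark.sql, so this only compiles inside Spark's own sql packages):

    import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
    import org.apache.spark.sql.types.{StringType, StructType}

    // Resolve an encoder against a single-field "input" schema, as
    // ResolveEncodersInUDF does for each non-primitive UDF argument.
    val enc = ExpressionEncoder[String]()
    val attrs = new StructType().add("input", StringType).toAttributes
    val resolved = enc.resolveAndBind(attrs)
    // ScalaUDF later calls resolved.createDeserializer() at execution time.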
@@ -21,10 +21,9 @@ import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.CatalystTypeConverters.{createToCatalystConverter, createToScalaConverter, isPrimitive}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
-import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.Deserializer
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
-import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, DataType}
+import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, DataType, UserDefinedType}

 /**
  * User-defined function.
@@ -41,9 +40,6 @@ import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, DataType}
  * @param nullable True if the UDF can return null value.
  * @param udfDeterministic True if the UDF is deterministic. Deterministic UDF returns same result
  *                         each time it is invoked with a particular input.
- * @param inputDeserializers deserializers used to convert internal inputs from children
- *                           to external data types, and they will only be instantiated
- *                           by `PrepareInputDeserializerForUDF`.
  */
 case class ScalaUDF(
     function: AnyRef,
@@ -52,19 +48,11 @@ case class ScalaUDF(
     inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Nil,
     udfName: Option[String] = None,
     nullable: Boolean = true,
-    udfDeterministic: Boolean = true,
-    inputDeserializers: Seq[Option[Deserializer[_]]] = Nil)
+    udfDeterministic: Boolean = true)
   extends Expression with NonSQLExpression with UserDefinedExpression {

   override lazy val deterministic: Boolean = udfDeterministic && children.forall(_.deterministic)

-  override lazy val canonicalized: Expression = {
-    val canonicalizedChildren = children.map(_.canonicalized)
-    // if the canonicalized children and inputEncoders are equal,
-    // then we must have equal inputDeserializers as well.
-    this.copy(inputDeserializers = Nil).withNewChildren(canonicalizedChildren)
-  }
-
   override def toString: String = s"${udfName.getOrElse("UDF")}(${children.mkString(", ")})"

   /**
@@ -115,27 +103,25 @@ case class ScalaUDF(
   }

   private def scalaConverter(i: Int, dataType: DataType): Any => Any = {
> Contributor: can we keep the name unchanged? If we keep using createToScalaConverter, many diff can be avoided.

-    if (inputEncoders.isEmpty) {
-      // for untyped Scala UDF
+    if (inputEncoders.isEmpty || // for untyped Scala UDF
> Contributor: also for Java UDF

+        dataType.isInstanceOf[UserDefinedType[_]]) {
       createToScalaConverter(dataType)
-    } else if (isPrimitive(dataType)) {
-      identity
     } else {
-      val encoder = inputEncoders(i)
-      encoder match {
-        case Some(enc) =>
-          val fromRow = inputDeserializers(i).get
-          if (enc.isSerializedAsStructForTopLevel) {
-            row: Any => fromRow(row.asInstanceOf[InternalRow])
-          } else {
-            value: Any =>
-              val row = new GenericInternalRow(1)
-              row.update(0, value)
-              fromRow(row)
-          }
-
-        // e.g. for UDF types
-        case _ => createToScalaConverter(dataType)
+      val encoderOpt = inputEncoders(i)
+      assert(encoderOpt.isDefined, s"ScalaUDF expects an encoder of ${i}th child, but got empty.")
+      val enc = encoderOpt.get
+      if (isPrimitive(dataType) && !enc.schema.head.nullable) {
+        createToScalaConverter(dataType)
+      } else {
+        val fromRow = enc.createDeserializer()
+        if (enc.isSerializedAsStructForTopLevel) {
+          row: Any => fromRow(row.asInstanceOf[InternalRow])
+        } else {
+          value: Any =>
+            val row = new GenericInternalRow(1)
+            row.update(0, value)
+            fromRow(row)
+        }
+      }
     }
   }
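
A runnable sketch of the single-value wrapping shown above (assumes Spark 3.0 catalyst APIs on the classpath; "hello" and the variable names are illustrative). A String encoder is not serialized as a struct at top level, so the raw value is placed in a one-field row before the deserializer runs:

    import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
    import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
    import org.apache.spark.unsafe.types.UTF8String

    val enc = ExpressionEncoder[String]().resolveAndBind()
    val fromRow = enc.createDeserializer()   // InternalRow => String
    val row = new GenericInternalRow(1)
    row.update(0, UTF8String.fromString("hello"))
    assert(fromRow(row) == "hello")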
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.analysis
 import java.util.{Locale, TimeZone}

 import scala.reflect.ClassTag
+import scala.reflect.runtime.universe.TypeTag

 import org.apache.log4j.Level
 import org.scalatest.Matchers
@@ -307,6 +308,10 @@ class AnalysisSuite extends AnalysisTest with Matchers {
   }

   test("SPARK-11725: correctly handle null inputs for ScalaUDF") {
+    def resolvedEncoder[T : TypeTag](): ExpressionEncoder[T] = {
+      ExpressionEncoder[T]().resolveAndBind()
+    }
+
     val testRelation = LocalRelation(
       AttributeReference("a", StringType)(),
       AttributeReference("b", DoubleType)(),
@@ -326,44 +331,38 @@ class AnalysisSuite extends AnalysisTest with Matchers {
     )
   }

-    val strEnc = ExpressionEncoder[String]()
-    val strDeser = strEnc.resolveAndBind().createDeserializer()
-
     // non-primitive parameters do not need special null handling
     val udf1 = ScalaUDF((s: String) => "x", StringType, string :: Nil,
-      Option(strEnc) :: Nil, inputDeserializers = Option(strDeser) :: Nil)
+      Option(resolvedEncoder[String]()) :: Nil)
     val expected1 = udf1
     checkUDF(udf1, expected1)

     // only primitive parameter needs special null handling
     val udf2 = ScalaUDF((s: String, d: Double) => "x", StringType, string :: double :: Nil,
-      Option(strEnc) :: Option(ExpressionEncoder[Double]()) :: Nil,
-      inputDeserializers = Option(strDeser) :: None :: Nil)
+      Option(resolvedEncoder[String]()) :: Option(resolvedEncoder[Double]()) :: Nil)
     val expected2 =
       If(IsNull(double), nullResult, udf2.copy(children = string :: KnownNotNull(double) :: Nil))
     checkUDF(udf2, expected2)

     // special null handling should apply to all primitive parameters
     val udf3 = ScalaUDF((s: Short, d: Double) => "x", StringType, short :: double :: Nil,
-      Option(ExpressionEncoder[Short]()) :: Option(ExpressionEncoder[Double]()) :: Nil)
+      Option(resolvedEncoder[Short]()) :: Option(resolvedEncoder[Double]()) :: Nil)
     val expected3 = If(
       IsNull(short) || IsNull(double),
       nullResult,
-      udf3.copy(children = KnownNotNull(short) :: KnownNotNull(double) :: Nil,
-        inputDeserializers = None :: None :: Nil))
+      udf3.copy(children = KnownNotNull(short) :: KnownNotNull(double) :: Nil))
     checkUDF(udf3, expected3)

     // we can skip special null handling for primitive parameters that are not nullable
     val udf4 = ScalaUDF(
       (s: Short, d: Double) => "x",
       StringType,
       short :: nonNullableDouble :: Nil,
-      Option(ExpressionEncoder[Short]()) :: Option(ExpressionEncoder[Double]()) :: Nil)
+      Option(resolvedEncoder[Short]()) :: Option(resolvedEncoder[Double]()) :: Nil)
     val expected4 = If(
       IsNull(short),
       nullResult,
-      udf4.copy(children = KnownNotNull(short) :: nonNullableDouble :: Nil,
-        inputDeserializers = None :: None :: Nil))
+      udf4.copy(children = KnownNotNull(short) :: nonNullableDouble :: Nil))
     checkUDF(udf4, expected4)
   }
@@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.expressions

 import java.util.Locale

+import scala.reflect.runtime.universe.TypeTag
+
 import org.apache.spark.{SparkException, SparkFunSuite}
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
@@ -27,16 +29,17 @@ import org.apache.spark.sql.types.{DecimalType, IntegerType, StringType}

 class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {

-  private val strEnc = ExpressionEncoder[String]()
-  private val strDeser = strEnc.resolveAndBind().createDeserializer()
+  private def resolvedEncoder[T : TypeTag](): ExpressionEncoder[T] = {
+    ExpressionEncoder[T]().resolveAndBind()
+  }

   test("basic") {
     val intUdf = ScalaUDF((i: Int) => i + 1, IntegerType, Literal(1) :: Nil,
-      Option(ExpressionEncoder[Int]()) :: Nil)
+      Option(resolvedEncoder[Int]()) :: Nil)
     checkEvaluation(intUdf, 2)

     val stringUdf = ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil,
-      Option(strEnc) :: Nil, inputDeserializers = Option(strDeser) :: Nil)
+      Option(resolvedEncoder[String]()) :: Nil)
     checkEvaluation(stringUdf, "ax")
   }

@@ -45,7 +48,7 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {
       (s: String) => s.toLowerCase(Locale.ROOT),
       StringType,
       Literal.create(null, StringType) :: Nil,
-      Option(strEnc) :: Nil, inputDeserializers = Option(strDeser) :: Nil)
+      Option(resolvedEncoder[String]()) :: Nil)

     val e1 = intercept[SparkException](udf.eval())
     assert(e1.getMessage.contains("Failed to execute user defined function"))
@@ -59,19 +62,17 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {
   test("SPARK-22695: ScalaUDF should not use global variables") {
     val ctx = new CodegenContext
     ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil,
-      Option(strEnc) :: Nil, inputDeserializers = Option(strDeser) :: Nil).genCode(ctx)
+      Option(resolvedEncoder[String]()) :: Nil).genCode(ctx)
     assert(ctx.inlinedMutableStates.isEmpty)
   }

test("SPARK-28369: honor nullOnOverflow config for ScalaUDF") {
val decimalEnc = ExpressionEncoder[java.math.BigDecimal]()
val decimalDeser = decimalEnc.resolveAndBind().createDeserializer()
withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
val udf = ScalaUDF(
(a: java.math.BigDecimal) => a.multiply(new java.math.BigDecimal(100)),
DecimalType.SYSTEM_DEFAULT,
Literal(BigDecimal("12345678901234567890.123")) :: Nil,
Option(decimalEnc) :: Nil, inputDeserializers = Option(decimalDeser) :: Nil)
Option(resolvedEncoder[java.math.BigDecimal]()) :: Nil)
val e1 = intercept[ArithmeticException](udf.eval())
assert(e1.getMessage.contains("cannot be represented as Decimal"))
val e2 = intercept[SparkException] {
Expand All @@ -84,7 +85,7 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {
(a: java.math.BigDecimal) => a.multiply(new java.math.BigDecimal(100)),
DecimalType.SYSTEM_DEFAULT,
Literal(BigDecimal("12345678901234567890.123")) :: Nil,
Option(decimalEnc) :: Nil, inputDeserializers = Option(decimalDeser) :: Nil)
Option(resolvedEncoder[java.math.BigDecimal]()) :: Nil)
checkEvaluation(udf, null)
}
}