Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
use encoder only
  • Loading branch information
Ngone51 committed Mar 19, 2020
commit 842d6fa7453d0cd34a41ebf2eb13c93c899ad83d
Original file line number Diff line number Diff line change
Expand Up @@ -2707,13 +2707,13 @@ class Analyzer(

case p => p transformExpressionsUp {

case udf @ ScalaUDF(_, _, inputs, inputPrimitives, _, _, _, _, _)
if inputPrimitives.contains(true) =>
case udf @ ScalaUDF(_, _, inputs, _, _, _, _)
if udf.inputPrimitives.contains(true) =>
// Otherwise, add special handling of null for fields that can't accept null.
// The result of operations like this, when passed null, is generally to return null.
assert(inputPrimitives.length == inputs.length)
assert(udf.inputPrimitives.length == inputs.length)

val inputPrimitivesPair = inputPrimitives.zip(inputs)
val inputPrimitivesPair = udf.inputPrimitives.zip(inputs)
val inputNullCheck = inputPrimitivesPair.collect {
case (isPrimitive, input) if isPrimitive && input.nullable =>
IsNull(input)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types.{AbstractDataType, DataType}
import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, DataType}

/**
* User-defined function.
Expand All @@ -34,17 +34,9 @@ import org.apache.spark.sql.types.{AbstractDataType, DataType}
* null. Use boxed type or [[Option]] if you want to do the null-handling yourself.
* @param dataType Return type of function.
* @param children The input expressions of this UDF.
* @param inputPrimitives The analyzer should be aware of Scala primitive types so as to make the
* UDF return null if there is any null input value of these types. On the
* other hand, Java UDFs can only have boxed types, thus this parameter will
* always be all false.
* @param inputEncoders ExpressionEncoder for each input parameter. An input parameter that is
* serialized as a struct uses its encoder instead of CatalystTypeConverters to
* convert the internal value to a Scala value.
* @param inputTypes The expected input types of this UDF, used to perform type coercion. If we do
* not want to perform coercion, simply use "Nil". Note that it would've been
* better to use Option of Seq[DataType] so we can use "None" as the case for no
* type coercion. However, that would require more refactoring of the codebase.
* @param udfName The user-specified name of this UDF.
* @param nullable True if the UDF can return null value.
* @param udfDeterministic True if the UDF is deterministic. Deterministic UDF returns same result
Expand All @@ -54,9 +46,7 @@ case class ScalaUDF(
function: AnyRef,
dataType: DataType,
children: Seq[Expression],
inputPrimitives: Seq[Boolean],
inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Nil,
inputTypes: Seq[AbstractDataType] = Nil,
udfName: Option[String] = None,
nullable: Boolean = true,
udfDeterministic: Boolean = true)
Expand All @@ -68,6 +58,52 @@ case class ScalaUDF(

override def toString: String = s"${udfName.getOrElse("UDF")}(${children.mkString(", ")})"

/**
* The analyzer should be aware of Scala primitive types so as to make the
* UDF return null if there is any null input value of these types. On the
* other hand, Java UDFs can only have boxed types, thus this parameter will
* always be all false.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to make sure the comment is accurate. Java UDFs can only have boxed types, thus this parameter will always be all false. This is wrong now.

I agree that Nil is fine in this case, but the comment needs to be updated.

*/
def inputPrimitives: Seq[Boolean] = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can this be Nil?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we need to return children.map(_ => false) if inputEncoders is empty.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It can be Nil. Previously, a Java UDF returned children.map(_ => false), which indeed has the same effect as Nil. Also, an untyped Scala UDF always passes Nil.

But a typed Scala UDF will always have inputPrimitives and inputTypes.

inputEncoders.map { encoderOpt =>
// It's possible that some of the inputs don't have a specific encoder(e.g. `Any`)
if (encoderOpt.isDefined) {
val encoder = encoderOpt.get
if (encoder.isSerializedAsStruct) {
// struct type is not primitive
false
} else {
// `nullable` is false iff the type is primitive
!encoder.schema.head.nullable
}
} else {
// Any type is not primitive
false
}
}
}

/**
* The expected input types of this UDF, used to perform type coercion. If we do
* not want to perform coercion, simply use "Nil". Note that it would've been
* better to use Option of Seq[DataType] so we can use "None" as the case for no
* type coercion. However, that would require more refactoring of the codebase.
*/
def inputTypes: Seq[AbstractDataType] = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unless we guarantee that inputEncoders always has the same length as children.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similarly, the input types of Java UDF and untyped Scala UDF are always Nil.

inputEncoders.map { encoderOpt =>
if (encoderOpt.isDefined) {
val encoder = encoderOpt.get
if (encoder.isSerializedAsStruct) {
encoder.schema
} else {
encoder.schema.head.dataType
}
} else {
AnyDataType
}
}
}

private def createToScalaConverter(i: Int, dataType: DataType): Any => Any = {
if (inputEncoders.isEmpty) {
// for untyped Scala UDF
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType, InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Count, Sum}
Expand Down Expand Up @@ -326,20 +327,21 @@ class AnalysisSuite extends AnalysisTest with Matchers {
}

// non-primitive parameters do not need special null handling
val udf1 = ScalaUDF((s: String) => "x", StringType, string :: Nil, false :: Nil)
val udf1 = ScalaUDF((s: String) => "x", StringType, string :: Nil,
Option(ExpressionEncoder[String]()) :: Nil)
val expected1 = udf1
checkUDF(udf1, expected1)

// only primitive parameter needs special null handling
val udf2 = ScalaUDF((s: String, d: Double) => "x", StringType, string :: double :: Nil,
false :: true :: Nil)
Option(ExpressionEncoder[String]()) :: Option(ExpressionEncoder[Double]()) :: Nil)
val expected2 =
If(IsNull(double), nullResult, udf2.copy(children = string :: KnownNotNull(double) :: Nil))
checkUDF(udf2, expected2)

// special null handling should apply to all primitive parameters
val udf3 = ScalaUDF((s: Short, d: Double) => "x", StringType, short :: double :: Nil,
true :: true :: Nil)
Option(ExpressionEncoder[Short]()) :: Option(ExpressionEncoder[Double]()) :: Nil)
val expected3 = If(
IsNull(short) || IsNull(double),
nullResult,
Expand All @@ -351,7 +353,7 @@ class AnalysisSuite extends AnalysisTest with Matchers {
(s: Short, d: Double) => "x",
StringType,
short :: nonNullableDouble :: Nil,
true :: true :: Nil)
Option(ExpressionEncoder[Short]()) :: Option(ExpressionEncoder[Double]()) :: Nil)
val expected4 = If(
IsNull(short),
nullResult,
Expand All @@ -362,8 +364,12 @@ class AnalysisSuite extends AnalysisTest with Matchers {
test("SPARK-24891 Fix HandleNullInputsForUDF rule") {
val a = testRelation.output(0)
val func = (x: Int, y: Int) => x + y
val udf1 = ScalaUDF(func, IntegerType, a :: a :: Nil, false :: false :: Nil)
val udf2 = ScalaUDF(func, IntegerType, a :: udf1 :: Nil, false :: false :: Nil)
val udf1 = ScalaUDF(func, IntegerType, a :: a :: Nil,
Option(ExpressionEncoder[java.lang.Integer]()) ::
Option(ExpressionEncoder[java.lang.Integer]()) :: Nil)
val udf2 = ScalaUDF(func, IntegerType, a :: udf1 :: Nil,
Option(ExpressionEncoder[java.lang.Integer]()) ::
Option(ExpressionEncoder[java.lang.Integer]()) :: Nil)
val plan = Project(Alias(udf2, "")() :: Nil, testRelation)
comparePlans(plan.analyze, plan.analyze.analyze)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,20 @@ package org.apache.spark.sql.catalyst.expressions
import java.util.Locale

import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DecimalType, IntegerType, StringType}

class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {

test("basic") {
val intUdf = ScalaUDF((i: Int) => i + 1, IntegerType, Literal(1) :: Nil, true :: Nil)
val intUdf = ScalaUDF((i: Int) => i + 1, IntegerType, Literal(1) :: Nil,
Option(ExpressionEncoder[Int]()) :: Nil)
checkEvaluation(intUdf, 2)

val stringUdf = ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil, false :: Nil)
val stringUdf = ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil,
Option(ExpressionEncoder[String]()) :: Nil)
checkEvaluation(stringUdf, "ax")
}

Expand All @@ -39,7 +42,7 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {
(s: String) => s.toLowerCase(Locale.ROOT),
StringType,
Literal.create(null, StringType) :: Nil,
false :: Nil)
Option(ExpressionEncoder[String]()) :: Nil)

val e1 = intercept[SparkException](udf.eval())
assert(e1.getMessage.contains("Failed to execute user defined function"))
Expand All @@ -52,7 +55,8 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {

test("SPARK-22695: ScalaUDF should not use global variables") {
val ctx = new CodegenContext
ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil, false :: Nil).genCode(ctx)
ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil,
Option(ExpressionEncoder[String]()) :: Nil).genCode(ctx)
assert(ctx.inlinedMutableStates.isEmpty)
}

Expand All @@ -61,7 +65,8 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {
val udf = ScalaUDF(
(a: java.math.BigDecimal) => a.multiply(new java.math.BigDecimal(100)),
DecimalType.SYSTEM_DEFAULT,
Literal(BigDecimal("12345678901234567890.123")) :: Nil, false :: Nil)
Literal(BigDecimal("12345678901234567890.123")) :: Nil,
Option(ExpressionEncoder[java.math.BigDecimal]()) :: Nil)
val e1 = intercept[ArithmeticException](udf.eval())
assert(e1.getMessage.contains("cannot be represented as Decimal"))
val e2 = intercept[SparkException] {
Expand All @@ -73,7 +78,8 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {
val udf = ScalaUDF(
(a: java.math.BigDecimal) => a.multiply(new java.math.BigDecimal(100)),
DecimalType.SYSTEM_DEFAULT,
Literal(BigDecimal("12345678901234567890.123")) :: Nil, false :: Nil)
Literal(BigDecimal("12345678901234567890.123")) :: Nil,
Option(ExpressionEncoder[java.math.BigDecimal]()) :: Nil)
checkEvaluation(udf, null)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
Expand Down Expand Up @@ -244,7 +245,8 @@ class EliminateSortsSuite extends PlanTest {
}

test("should not remove orderBy in groupBy clause with ScalaUDF as aggs") {
val scalaUdf = ScalaUDF((s: Int) => s, IntegerType, 'a :: Nil, true :: Nil)
val scalaUdf = ScalaUDF((s: Int) => s, IntegerType, 'a :: Nil,
Option(ExpressionEncoder[Int]()) :: Nil)
val projectPlan = testRelation.select('a, 'b)
val orderByPlan = projectPlan.orderBy('a.asc, 'b.desc)
val groupByPlan = orderByPlan.groupBy('a)(scalaUdf)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.{AliasIdentifier, FunctionIdentifier, InternalRow, TableIdentifier}
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.dsl.expressions.DslString
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.plans.{LeftOuter, NaturalJoin, SQLHelper}
Expand Down Expand Up @@ -594,7 +595,8 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
}

test("toJSON should not throws java.lang.StackOverflowError") {
val udf = ScalaUDF(SelfReferenceUDF(), BooleanType, Seq("col1".attr), false :: Nil)
val udf = ScalaUDF(SelfReferenceUDF(), BooleanType, Seq("col1".attr),
Option(ExpressionEncoder[String]()) :: Nil)
// Should not throw java.lang.StackOverflowError
udf.toJSON
}
Expand Down
Loading